from pylab import *
import scipy.io
from os.path import exists
import mytools
import myvenn

# Number of parameter sets stored in each grid-search result file
# (also the step between the sample indices that appear in the file names).
NsampPerGridFile = 100
# Population size: part of the result file names and checked against the
# third token of a command string in changeNperpopAndScale().
Nperpop = 40

# Values of each grid-search axis.
stimAmp1 = [120, 125, 130, 140, 150]
gAMPA1 = [12.5, 15.0, 17.5, 20.0, 22.5]
gAMPA2 = [25.0, 30.0, 35.0, 40.0, 45.0]
gAMPA3 = [80.0, 90.0, 100.0, 110.0, 120.0, 130.0, 140.0]
nmdaAmpaRatio = [0.5, 0.333]
gGABA = [10.0, 15.0, 20.0, 25.0, 30.0, 35.0]
pvs = [0.9, 0.95]
min_gluts = [0.0]
tauNeur1s = [10.0]
tauNeur2s = [200.0, 250.0]

# One simulation ID per combination of the axis values above
# (the grid size is the product of the axis lengths).
_ngridpoints = 1
for _axis in (stimAmp1, gAMPA1, gAMPA2, gAMPA3, nmdaAmpaRatio, gGABA, pvs, min_gluts, tauNeur1s, tauNeur2s):
  _ngridpoints *= len(_axis)
simIDs = list(range(_ngridpoints))

# NOTE(review): iparAttrs has 14 entries whereas parNames has 13 -- confirm the
# intended correspondence between the two before relying on it.
iparAttrs = [idx + 1 for idx in list(range(3, 12)) + [14, 15, 16, 17, 19]]
parNames = [
  'stimAmp1', 'stimAmp2',
  'gAMPA1', 'gAMPA2', 'gAMPA3',
  'gNMDA1', 'gNMDA2', 'gNMDA3',
  'gGABA', 'min_gluts', 'pvs', 'tauNeur1s', 'tauNeur2s',
]
if True:

  # Close any figures that are still open before producing new ones.
  close('all')

  # Metadata of every loaded parameter set, concatenated over the grid files found.
  filenames_all = []
  strs_all = []
  # Results per grid file, indexed as [file][sample][protocol index 0..3].
  inperouts_all = []
  NoutputInsides_all = []
  NoutputOutsides_all = []
  for igrid in range(0,900):
    # Each grid file covers NsampPerGridFile consecutive sample indices,
    # and that index range is part of the file name.
    isampfirst = igrid*NsampPerGridFile
    samprange = str(isampfirst)+'-'+str(isampfirst+NsampPerGridFile)
    gridfile = 'gridsearch'+str(Nperpop)+'_2pm_sep7_'+samprange+'.mat'
    print('Working on '+samprange)
    if not exists(gridfile):
      # Missing files are skipped entirely, so the result lists and the
      # metadata arrays stay aligned with each other.
      print(gridfile+' does not exist')
      continue
    A = scipy.io.loadmat(gridfile)

    inperouts_this = []
    NoutputInsides_this = []
    NoutputOutsides_this = []
    for isamp in range(0,NsampPerGridFile):
      inperouts_thisMMN = []
      NoutputInsides_thisMMN = []
      NoutputOutsides_thisMMN = []
      for iMMN in range(0,4):
        Nbetweens = A['NbetweensUnique_all'][isamp][iMMN] # Read but not used below
        NoutputInsides = A['NoutputInsidesUnique_all'][isamp][iMMN]
        # Exclude the first one since it's allowed to give a "deviant-like" output. Consider leaving out indices 1 and 4?
        NoutputOutsides = [A['NoutputOutsidesUnique_all'][isamp][iMMN][i] for i in [1,2,3,4,5]]
        NoutputOutsides_sum = sum(NoutputOutsides)

        # Score: inside/outside ratio. Zero when there are no inside outputs,
        # and 10+NoutputInsides when there are inside outputs but no outside ones.
        if NoutputInsides == 0:
          inperout = 0
        elif NoutputOutsides_sum > 0:
          inperout = NoutputInsides/NoutputOutsides_sum
        else:
          inperout = 10+NoutputInsides

        inperouts_thisMMN.append(inperout)
        NoutputInsides_thisMMN.append(NoutputInsides)
        NoutputOutsides_thisMMN.append(NoutputOutsides_sum)

      inperouts_this.append(inperouts_thisMMN)
      NoutputInsides_this.append(NoutputInsides_thisMMN)
      NoutputOutsides_this.append(NoutputOutsides_thisMMN)

    inperouts_all.append(inperouts_this)
    NoutputInsides_all.append(NoutputInsides_this)
    NoutputOutsides_all.append(NoutputOutsides_this)

    filenames_all = r_[filenames_all,A['filenames_all'][:]]
    strs_all = r_[strs_all,A['strs_all'][:]]

  # Flatten the per-file lists into one list over all loaded samples: [sample][protocol index].
  inperouts_vec = []
  NoutputInsides_vec = []
  NoutputOutsides_vec = []
  nfiles = len(inperouts_all)
  for ifile, (ratios_file, insides_file, outsides_file) in enumerate(zip(inperouts_all, NoutputInsides_all, NoutputOutsides_all)):
    print('Concatenating '+str(ifile)+'/'+str(nfiles))
    inperouts_vec.extend(ratios_file)
    NoutputInsides_vec.extend(insides_file)
    NoutputOutsides_vec.extend(outsides_file)

  # Samples whose first three protocols all have a score above 0.4 and at least 7 inside outputs.
  goodones = [isample for isample, (ratios, insides) in enumerate(zip(inperouts_vec, NoutputInsides_vec))
              if all(ratios[k] > 0.4 for k in range(3)) and all(insides[k] >= 7 for k in range(3))]
  cols_good = mytools.colorsredtolila(len(goodones),0.7)

  def mystr(x):
    """Return str(x) with floating-point representation noise trimmed.

    Two patterns are cleaned up:
      * a run of five or more zeros in the fractional part: the string is cut
        after the first zero of the run (e.g. '30.000000000000004' -> '30.0',
        '12.500000000000002' -> '12.50');
      * a run of five or more nines: the digit before the run is incremented
        and the run is dropped or, in the integer part, replaced by zeros
        (e.g. '12.499999999999998' -> '12.5', '19.999999999999996' -> '20').

    Zeros in the integer part are significant digits and are left untouched
    (e.g. 100000.0 -> '100000.0'). If the rounding cannot be carried out with
    a single-digit increment (e.g. '9.99999', '-9.99999'), the string is
    returned unchanged.

    Exponent notation is not handled: trimming a zero run drops everything
    after it, including an exponent suffix.
    """
    s = str(x)

    # Only look for the zero run after the decimal point; a run such as the one
    # in '100000.0' or '480000001' is part of the value, not rounding noise.
    idot = s.find('.')
    izeros = s.find('00000', idot+1) if idot != -1 else -1
    if izeros != -1:
      return s[0:izeros+1]

    if '99999' in s:
      i = s.find('99999')
      if i < 2:
        return s
      if s[i-1] == '.':
        # The run starts right after the decimal point: round up the integer
        # part. Walk left over trailing nines of the integer part, which turn
        # into zeros, until a digit that can be incremented is found.
        myi = i-2
        issuccess = False
        n9s = 0
        while myi >= 0:
          if s[myi] != '9':
            # A non-digit here (e.g. a minus sign) cannot be incremented.
            issuccess = s[myi].isdigit()
            break
          myi = myi-1
          n9s = n9s + 1
        if issuccess:
          return s[0:myi]+str(int(s[myi])+1)+'0'*(n9s)
        return s
      # From here on s[i-1] is the character just before the run; it is not
      # '9', because i is the first occurrence of the run.
      if s.find('.') == -1:
        # Integer string: the run and everything after it become zeros.
        return s[0:i-1]+str(int(s[i-1])+1)+'0'*(len(s)-i)
      if s.find('.') > i:
        # Run inside the integer part: zero-fill up to the decimal point and drop the fraction.
        return s[0:i-1]+str(int(s[i-1])+1)+'0'*(s.find('.')-i)
      # Run inside the fractional part: increment the preceding decimal and cut there.
      return s[0:i-1]+str(int(s[i-1])+1)
    return s
              
  def changeNperpopAndScale(command_str, NperpopNew, AMPA_scale = 1.0, NMDA_scale = 1.0, GABA_scale = 1.0):
    """Return command_str with a new population size and rescaled numeric fields.

    command_str is split on single spaces. Field 2 must equal the module-level
    Nperpop and is replaced by NperpopNew. Fields 4 and 5 are multiplied by
    AMPA_scale, fields 6 and 7 by NMDA_scale and field 8 by GABA_scale; each
    product is formatted with mystr(). All other fields are kept verbatim.

    If field 2 does not match Nperpop, an error is printed and '' is returned.
    """
    fields = command_str.split(' ')
    if int(fields[2]) != Nperpop:
      print('Error: splitted[2] != Nperpop')
      return ''
    fields[2] = str(NperpopNew)

    # (field index, multiplier) for every field that gets rescaled
    scale_by_field = ((4, AMPA_scale), (5, AMPA_scale),
                      (6, NMDA_scale), (7, NMDA_scale),
                      (8, GABA_scale))
    for ifield, scale in scale_by_field:
      fields[ifield] = mystr(float(fields[ifield])*scale)
    return ' '.join(fields)


  # Number of parameter sets in each region of the four-set Venn diagram.
  # Character k of a key is '1' if protocol k (in the order of `labels`) is
  # detected and '0' otherwise, so the regions are mutually exclusive.
  region_counts = {
        "0000": 0,
        "1000": 0,
        "0100": 0,
        "0010": 0,
        "0001": 0,
        "1100": 0,
        "1010": 0,
        "1001": 0,
        "0110": 0,
        "0101": 0,
        "0011": 0,
        "1110": 0,
        "1101": 0,
        "1011": 0,
        "0111": 0,
        "1111": 0
  }

  labels = ["Frequency deviant", "Omission", "Duration deviant", "Inverse duration deviant"]

  # Single pass over all parameter sets: a protocol counts as detected when its
  # score is at least 1 and it has at least 32 inside outputs.
  idetected_all = [] # Indices of the parameter sets for which all four protocols are detected
  for ipar in range(0,len(inperouts_vec)):
    isdetected = [bool(inperouts_vec[ipar][k] >= 1 and NoutputInsides_vec[ipar][k] >= 32) for k in range(0,len(labels))]
    region_counts[''.join('1' if flag else '0' for flag in isdetected)] += 1
    if all(isdetected):
      idetected_all.append(ipar)

  print(str(region_counts))
  # myvenn is the Venn-diagram helper imported by this file.
  # NOTE(review): assumed to provide draw_venn_from_counts(counts, labels) -> (figure, axes); confirm.
  vennfig,vennax = myvenn.draw_venn_from_counts(region_counts, labels)

  # List the fully detected parameter sets, first by their parameter strings and then by their file names.
  for ipar in idetected_all:
    print(strs_all[ipar])
  print('#parameters with correct detection of all protocols: '+str(len(idetected_all)))
  for ipar in idetected_all:
    print(filenames_all[ipar])
  print('#parameters with correct detection of all protocols: '+str(len(idetected_all)))
  print('Saving fig_venn.pdf')

  # Place the diagram in the upper part of the figure and add the panel label 'A'.
  vennax.set_position([0.1,0.45,0.8,0.5])
  pos = vennax.get_position()
  vennfig.text(pos.x0 + 0.03, pos.y1 - 0.155, 'A', fontsize=11)
  vennfig.savefig('fig_venn.pdf')