#cp sim_baclofen_combe_rhythmic_syn_inputs_inmdarec_carec.py drawmorph.py

from neuron import h
from neuron import rxd
import mytools
import pickle
import numpy as np
import sys
import time
from pylab import *


# Parameters carried over from the simulation script this file was copied from
# (see the `cp` note on the first line). In this morphology-drawing script only
# Nsyn and myseed are actually referenced; the others are kept but unused here.
tstop = 21000
T = 20000
Tinputs = 0.1 #100Hz
Ninputs = 100 #1 sec
Nsyn = 50  # number of synapses placed on the apical tree (overridable via argv[1])
myseed = 1  # seed for the synapse-placement RNG (overridable via argv[2])

baclofen_max = 100.0 #uM; the maximal agonist concentration
somaticIscale = 1.0
dodraw = 1
blockedStr = ''
# List of parameter sets, each a comma-separated string of 9 numbers; not used in this script.
params_all = ['107.034,514.651,167.633,11.026,66.931,3.009,1.904,2.241,1507.740','91.805,846.811,67.856,9.858,71.217,5.000,1.817,2.170,3000.000','133.092,528.522,185.581,8.574,25.868,3.186,2.125,2.729,2815.426','108.557,729.183,292.710,2.473,16.923,4.406,2.127,2.973,696.729','111.531,708.687,165.616,10.797,93.218,2.903,1.728,2.235,2797.567','109.869,469.365,114.137,59.053,73.263,3.573,2.385,3.028,2852.993','84.387,57.393,126.013,4.253,57.544,5.000,1.722,2.122,929.204','126.993,135.047,348.945,3.733,18.257,3.800,2.647,3.237,2804.842','105.438,642.913,142.107,19.602,40.482,3.599,2.391,2.974,1386.488','91.403,459.570,97.658,3.443,65.833,5.000,1.540,1.942,2447.114','71.410,380.140,143.521,2.991,45.796,10.000,2.044,2.836,2126.204','146.055,702.846,310.725,3.041,13.076,3.334,2.351,3.166,2112.405','90.417,993.718,98.403,6.879,95.287,3.944,1.572,2.204,1413.654','119.681,591.014,91.960,89.687,88.948,3.453,2.300,2.755,2950.763','123.170,616.560,117.913,74.842,46.625,3.952,2.609,3.283,3000.000']

params_str = 'p0'
GIRKcond_single = 33e-3 #nS #https://www.jneurosci.org/content/jneuro/25/15/3787.full.pdf

# Optional positional command-line overrides:
#   argv[1] -> Nsyn   (number of synapses)
#   argv[2] -> myseed (RNG seed)
# Each value goes through float() first so inputs such as "50.0" or "1e2" work.
if len(sys.argv) >= 2:
  Nsyn = int(float(sys.argv[1]))
if len(sys.argv) >= 3:
  myseed = int(float(sys.argv[2]))

# One pass per cell (only cell 0 here): load the NEURON model, draw its
# morphology, then place and mark Nsyn random AMPA/NMDA synapses on the apical tree.
for icell in range(0,1):
  distalpoint = 300  # distance passed to locateSites (presumably um from soma) - TODO confirm

  # Load the cell, configure CVode, add a somatic IClamp, and record v and cai at
  # the soma and at the widest-diameter site among those returned by
  # locateSites("apic", distalpoint) (locateSites is defined outside this file).
  h("""
load_file("stdlib.hoc")
load_file("stdrun.hoc")
objref cvode
cvode = new CVode()
cvode.active(1)
cvode.atol(1e-7)
cvode.maxstep(5)
load_file("loadcell.hoc")
objref st1, vclamp, vclamp_irec
soma st1 = new IClamp(0.5)

objref vsoma, vdend, cadend, casoma
vsoma = new Vector()
casoma = new Vector()
vdend = new Vector()
cadend = new Vector()
objref sl,ns,tvec
tvec = new Vector()
sl = new List()
double siteVec[2]
sl = locateSites("apic","""+str(distalpoint)+""")
maxdiam = 0
for(i=0;i<sl.count();i+=1){
  dd1 = sl.o[i].x[1]
  dd = apic[sl.o[i].x[0]].diam(dd1)
  if (dd > maxdiam) {
    j = i
    maxdiam = dd
  }
}
siteVec[0] = sl.o[j].x[0]
siteVec[1] = sl.o[j].x[1]
apic[siteVec[0]] cvode.record(&v(siteVec[1]),vdend,tvec)
apic[siteVec[0]] cvode.record(&cai(siteVec[1]),cadend,tvec)
soma cvode.record(&v(0.5),vsoma,tvec)
soma cvode.record(&cai(0.5),casoma,tvec)
""")

  # hoc containers for the synapses, their NetCons and their recorded traces.
  h("""
objref syns, precons, nilstim, recsyns, carecsyns
syns = new List()
precons = new List()
recsyns = new List()
carecsyns = new List()
""")

  f,axarr = subplots(1,1)
  # Every drawing command is also stored here so it can be pickled at the end.
  # Entry layout: [xs, ys, fmt, linewidth, color] (+ [markersize, mew] for markers).
  plotteds = []

  # Draw each section as straight segments between consecutive 3-D points (x-y
  # projection, line width 0.25*diam). Trees are visited in the order dend, trunk,
  # apic, soma and sections within a tree from the highest index down; this fixes
  # the order of entries in `plotteds`, so it is kept unchanged.
  segcolor = '#007700'
  for treename in ['dend', 'trunk', 'apic', 'soma']:
    if treename == 'soma':
      sec_hocnames = ['soma']
    else:
      nsec = len(getattr(h, treename))
      sec_hocnames = [treename+'['+str(j)+']' for j in range(nsec-1,-1,-1)]

    for hocname in sec_hocnames:
      h("access "+hocname)
      h("""
myn = n3d()
myx0 = x3d(0)
myy0 = y3d(0)
myz0 = z3d(0)
""")
      oldcoord = [h.myx0, h.myy0, h.myz0]
      for k in range(1,int(h.myn)):
        h("""
myx0 = x3d("""+str(k)+""")
myy0 = y3d("""+str(k)+""")
myz0 = z3d("""+str(k)+""")
mydiam = diam""")
        lw = h.mydiam*0.25
        # Plain '-' format: the explicit color kwarg decides the colour, so a colour
        # letter in the format string would only trigger a matplotlib warning. The
        # pickled entry keeps the original 'k-' so saved files are unchanged.
        axarr.plot([oldcoord[0],h.myx0],[oldcoord[1],h.myy0],'-',linewidth=lw,color=segcolor)
        plotteds.append([[oldcoord[0],h.myx0],[oldcoord[1],h.myy0],'k-',lw,segcolor])
        oldcoord = [h.myx0, h.myy0, h.myz0]

  # Synapses go on the first `napic` apical sections; a section is chosen with
  # probability proportional to its lateral cylinder area L*pi*diam.
  napic = 72  # NOTE(review): hard-coded section count - confirm it matches len(h.apic)
  areas = [h.apic[i].L*pi*h.apic[i].diam for i in range(0,napic)]
  cumps = cumsum(areas)/sum(areas)
  seed(myseed)
  for isyn in range(0,Nsyn):
    r = rand()
    x = rand()
    # First section whose cumulative area fraction exceeds r. The clamp keeps
    # the index valid if rounding leaves cumps[-1] marginally below 1 (the old
    # linear scan would then raise on an empty sequence).
    iapic = int(cumps.searchsorted(r, side='right'))
    if iapic >= napic:
      iapic = napic - 1

    # Insert an AMPANMDA synapse at position x of apic[iapic], drive it through a
    # NetCon with a nil source (weight 1, delay 0), and record its i_NMDA and the
    # local cai at interval 1.0. Also store in myx0/myy0/myz0 the 3-D point at index
    # int(x*n3d()), used below as an approximate marker position for the synapse.
    h("""apic["""+str(iapic)+"""] syns.append(new AMPANMDA("""+str(x)+"""))
        apic["""+str(iapic)+"""] myx0 = x3d(int("""+str(x)+"""*n3d()))
        apic["""+str(iapic)+"""] myy0 = y3d(int("""+str(x)+"""*n3d()))
        apic["""+str(iapic)+"""] myz0 = z3d(int("""+str(x)+"""*n3d()))
        syns.o[syns.count()-1].gAMPAmax = 0.0003
        syns.o[syns.count()-1].gNMDAmax = 0.0003
        syns.o[syns.count()-1].MgCon = 2
        precons.append(new NetCon(nilstim, syns.o[syns.count()-1]))
        precons.o[precons.count()-1].weight = 1
        precons.o[precons.count()-1].delay = 0
        recsyns.append(new Vector())
        recsyns.o[recsyns.count()-1].record(&syns.o[syns.count()-1].i_NMDA, 1.0)
        carecsyns.append(new Vector())
        apic["""+str(iapic)+"""] carecsyns.o[carecsyns.count()-1].record(&cai("""+str(x)+"""), 1.0)
""")
    axarr.plot(h.myx0,h.myy0,'x',color='#00FFFF', linewidth=0.4, ms=1.4, mew=0.7)
    plotteds.append([[h.myx0],[h.myy0],'cx',0.4,'#00FFFF',1.4,0.7])
    
# Final layout: equal aspect ratio, small axes, scale bar, panel label, no frame.
axis("equal")
axarr.set_position([0.1,0.1,0.3,0.3])

# L-shaped scale bar of 100 units in each direction, labelled "100 um".
axarr.plot([100,100,200],[250,150,150],'k-')
# Raw strings: '\m' in a normal literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+); the resulting text is identical.
axarr.text(150,140,r'100 $\mu$m',fontsize=5,va='top',ha='center')
axarr.text(95,200,r'100 $\mu$m',fontsize=5,va='center',ha='right')

# Panel label "A" just above-left of the axes.
pos = axarr.get_position()
f.text(pos.x0 - 0.03, pos.y1 - 0.02, 'A', fontsize=10)

axarr.set_xticks([])
axarr.set_yticks([])
for side in ('left', 'right', 'top', 'bottom'):
  axarr.spines[side].set_visible(False)

# Pickle the recorded plot commands (presumably so the panel can be redrawn by
# another figure script). `with` closes the file even if dumping fails.
with open('morph_Nsyn'+str(Nsyn)+'.sav', 'wb') as savefile:
  pickle.dump(plotteds, savefile)

f.savefig("morph_Nsyn"+str(Nsyn)+".pdf")