from neuron import h
from neuron import rxd
import mytools
import pickle
import numpy as np
import sys
import time
from pylab import *

# Default run parameters. Optional positional command-line arguments override
# them in this order: number of synapses, location spec, random seed, draw flag.
stim_Nsyn_exc = 500         # number of synapses placed along the dendrites
NTcoeff = 0.013             # from drawfitNTcoeff.py
loc = 'apic650-1300'        # '<tree><start>-<end>'; tree is 'apic', 'dend' or 'both'
myseed = 1
dodraw = 1                  # NOTE(review): parsed but not referenced elsewhere in this script

cli_args = sys.argv[1:]
if len(cli_args) > 0:
  stim_Nsyn_exc = int(float(cli_args[0]))
if len(cli_args) > 1:
  loc = cli_args[1]
if len(cli_args) > 2:
  myseed = int(float(cli_args[2]))
if len(cli_args) > 3:
  dodraw = int(float(cli_args[3]))

# Split the location spec: first four characters name the tree, the rest is
# the '<start>-<end>' distance range used when drawing synapse positions.
loc_tree = loc[:4]
loc_bounds = loc[4:].split('-')
loc_start = int(loc_bounds[0])
loc_end = int(loc_bounds[1])

seed(myseed)

# Build the L5PC model (hay/ files), scatter stim_Nsyn_exc synapse sites over
# the requested tree and distance range, draw the morphology with the synapse
# sites marked, then save the figure (.pdf) and the raw plotting primitives
# (.sav, pickled list) so the figure can be redrawn elsewhere.
for icell in range(0,1):
  morphology_file = "hay/cell"+str(icell+1)+".asc"
  biophys_file = "hay/L5PCbiophys3.hoc"
  template_file = "hay/L5PCtemplate_withsyns.hoc"
  v0 = -80
  ca0 = 0.0001

  distalpoints = [200,400,600,800]
  BACdt = 5.0

  # Load the model and create the HOC-side recording containers.
  h("""
load_file("stdlib.hoc")
load_file("stdrun.hoc")
objref cvode
cvode = new CVode()
cvode.active(1)
cvode.atol(1e-7)
load_file("import3d.hoc")
objref L5PC
load_file(\""""+biophys_file+"""\")
load_file(\""""+template_file+"""\")
L5PC = new L5PCtemplate(\""""+morphology_file+"""\")
access L5PC.soma
objref st1, st2

objref vsoma, vdends, recSite, isoma, cadends, casoma
vsoma = new Vector()
casoma = new Vector()
vdends = new List()
cadends = new List()
objref sl,ns,tvec, vrec, carec, icarec, caspinerec
vrec = new Vector()
carec = new Vector()
icarec = new Vector()
caspinerec = new Vector()
tvec = new Vector()
sl = new List()
double siteVec[2]
""")

  # For each distance in distalpoints, record v and cai at the thickest apical
  # site found at that distance by locateSites.
  for idist in range(0,len(distalpoints)):
    h("""
vdends.append(new Vector())
cadends.append(new Vector())
sl = L5PC.locateSites("apic","""+str(distalpoints[idist])+""")
maxdiam = 0
for(i=0;i<sl.count();i+=1){
  dd1 = sl.o[i].x[1]
  dd = L5PC.apic[sl.o[i].x[0]].diam(dd1)
  if (dd > maxdiam) {
    j = i
    maxdiam = dd
  }
}
siteVec[0] = sl.o[j].x[0]
siteVec[1] = sl.o[j].x[1]
L5PC.apic[siteVec[0]] cvode.record(&v(siteVec[1]),vdends.o[vdends.count()-1],tvec)
L5PC.apic[siteVec[0]] cvode.record(&cai(siteVec[1]),cadends.o[cadends.count()-1],tvec)
""")

  h("""
L5PC.soma cvode.record(&v(0.5),vsoma,tvec)
L5PC.soma cvode.record(&cai(0.5),casoma,tvec)
Napic = 0
Nbasal = 0
forsec L5PC.apical { Napic = Napic + 1 }
forsec L5PC.basal { Nbasal = Nbasal + 1 }
objref diams
""")

  secnames = []  # names of the sections that received at least one synapse
  isecs = []     # per synapse: index into secnames of its section
  xsecs = []     # per synapse: relative position (0..1) along its section
  trees = []     # per synapse: 'dend' or 'apic'

  # When loc_tree == 'both', pick the basal tree with probability proportional
  # to its share of the total dendritic length.
  p_dend = 5133.491987391413/(5133.491987391413+7440.905948081189)

  for isyn in range(0,stim_Nsyn_exc):
    if isyn % 100 == 0:
      print("Setting AMPA/NMDA synapse "+str(isyn)+" / "+str(stim_Nsyn_exc))
    thisdist = loc_start+(loc_end-loc_start)*rand()
    loc_tree_this = loc_tree
    if loc_tree == 'both':
      loc_tree_this = 'dend' if rand() < p_dend else 'apic'
    # Redraw distances beyond the chosen tree's cutoff (282 basal, 1300 apical);
    # presumably these are the trees' maximal distances in this morphology.
    while (loc_tree_this == 'dend' and thisdist > 282) or (loc_tree_this == 'apic' and thisdist > 1300):
      thisdist = loc_start+(loc_end-loc_start)*rand()

    # Collect the diameter at every candidate site. The section index in
    # sl.o[i].x[0] refers to the tree passed to locateSites, so the diameter
    # must be read from that same tree (not always from apic).
    h("""sl = L5PC.locateSites(\""""+loc_tree_this+"""\","""+str(thisdist)+""")
diams = new Vector()
for(i=0;i<sl.count();i+=1){
  dd1 = sl.o[i].x[1]
  dd = L5PC."""+loc_tree_this+"""[sl.o[i].x[0]].diam(dd1)
  diams.append(dd)
}
""")
    # Choose one candidate with probability proportional to its diameter
    # (inverse-CDF sampling: first index whose cumulative probability >= r).
    probs = array(h.diams)/sum(array(h.diams))
    cumprobs = cumsum(probs)
    r = rand()
    myi = int(np.searchsorted(cumprobs, r))
    if myi >= len(cumprobs):  # rounding can leave cumprobs[-1] slightly below r
      myi = len(cumprobs)-1
    h("""
siteVec[0] = sl.o["""+str(myi)+"""].x[0]
siteVec[1] = sl.o["""+str(myi)+"""].x[1]
""")
    secname = loc_tree_this+"["+str(int(h.siteVec[0]))+"]"
    if secname not in secnames:
      secnames.append(secname)
    isecs.append(secnames.index(secname))
    xsecs.append(h.siteVec[1])
    trees.append(loc_tree_this)

  f,axarr = subplots(1,1)
  plotteds = []  # plotting primitives, pickled at the end for re-plotting

  dists = []
  coords_alltrees = []          # per tree, per section: [first 3D point, last 3D point]
  coords_alltrees_highres = []  # per tree, per section: list of [[x0,x1],[y0,y1]] 3D-point segments
  col = "#008800"
  # itree: 0 = basal (dend), 1 = apical (apic), 2 = soma
  for itree in range(0,3):
    if itree == 0:
      nsec = len(h.L5PC.dend)
    elif itree == 1:
      nsec = len(h.L5PC.apic)
    else:
      nsec = 1

    coords_thistree = [[] for _ in range(nsec)]
    coords_thistree_highres = [[] for _ in range(nsec)]
    for j in range(nsec-1,-1,-1):
      # Make section j the currently accessed section so the HOC 3D-point
      # functions below operate on it.
      if itree == 0:
        h("access L5PC.dend["+str(j)+"]")
      elif itree == 1:
        h("access L5PC.apic["+str(j)+"]")
      else:
        h("access L5PC.soma")
      h("tmpvarx = x3d(0)")
      h("tmpvary = y3d(0)")
      h("tmpvarz = z3d(0)")
      h("tmpvarx2 = x3d(n3d()-1)")
      h("tmpvary2 = y3d(n3d()-1)")
      h("tmpvarz2 = z3d(n3d()-1)")
      coord1 = [h.tmpvarx,h.tmpvary,h.tmpvarz]
      coord2 = [h.tmpvarx2,h.tmpvary2,h.tmpvarz2]
      # NOTE(review): no distance() origin is set in this script, so this uses
      # NEURON's default origin.
      dists.append(h.distance(0.5))

      coords_thistree[j] = [coord1,coord2]

      h("""
myn = n3d()
myx0 = x3d(0)
myy0 = y3d(0)
myz0 = z3d(0)
""")
      oldcoord = [h.myx0, h.myy0, h.myz0]
      # Draw the section as consecutive straight pieces between its 3D points,
      # with line width scaled by the section diameter.
      for k in range(1,int(h.myn)):
        h("""
myx0 = x3d("""+str(k)+""")
myy0 = y3d("""+str(k)+""")
myz0 = z3d("""+str(k)+""")
mydiam = diam""")

        # '-' rather than 'k-': the explicit color keyword sets the colour, and
        # combining it with a colour in the fmt string makes matplotlib warn.
        axarr.plot([oldcoord[0],h.myx0],[oldcoord[1],h.myy0],'-',linewidth=h.mydiam*0.25,color=col)
        # The pickled record keeps the original 'k-' entry for consumers of the .sav file.
        plotteds.append([[oldcoord[0],h.myx0],[oldcoord[1],h.myy0],'k-',h.mydiam*0.25,col])
        coords_thistree_highres[j].append([[oldcoord[0],h.myx0],[oldcoord[1],h.myy0]])
        oldcoord = [h.myx0, h.myy0, h.myz0]
    coords_alltrees.append(coords_thistree[:])
    coords_alltrees_highres.append(coords_thistree_highres[:])
  axis("equal")

  # Mark every synapse with an 'x' by interpolating its relative position along
  # the drawn 3D-point segments of its section.
  for isyn in range(0,len(isecs)):
    itree = 0 if trees[isyn] == 'dend' else 1
    secname = secnames[isecs[isyn]]
    isec = int(secname[secname.index('[')+1:-1])
    segs = coords_alltrees_highres[itree][isec]
    # Treat x in [0,1] as spread evenly over all len(segs) 3D segments.
    # (Using len(segs)-1 here would leave the last segment unused and divide
    # by zero for sections with only two 3D points.)
    nseg = len(segs)
    if nseg == 0:  # section has a single 3D point: nothing to interpolate on
      continue
    iseg = int(xsecs[isyn]*nseg)
    if iseg >= nseg:  # x == 1 falls on the distal end of the last segment
      iseg = nseg-1
    xseg = xsecs[isyn]*nseg-iseg
    seg_xs, seg_ys = segs[iseg]
    px = seg_xs[0]+(seg_xs[1]-seg_xs[0])*xseg
    py = seg_ys[0]+(seg_ys[1]-seg_ys[0])*xseg
    axarr.plot(px,py,'kx',lw=0.3,ms=0.5,mew=0.5)
    plotteds.append([px,py,'kx',0.3,0.5,0.5])

  f.savefig("morph_extNTsyns"+str(stim_Nsyn_exc)+"_"+loc+".pdf")
  with open("morph_extNTsyns"+str(stim_Nsyn_exc)+"_"+loc+".sav", 'wb') as savfile:
    pickle.dump(plotteds,savfile)