from neuron import h
import matplotlib
matplotlib.use('Agg')
import numpy
from pylab import *
import mytools
import pickle
import time
import sys
import random

# Module-level simulation settings.
# NOTE(review): none of these names is read anywhere in this script; they look
# like parameters shared with related simulation scripts (v0 presumably an
# initial membrane potential, tstop presumably a stop time) -- confirm before
# relying on or removing them.
v0, ca0, BACdt = -80, 0.0001, 5.0
fs, tstop, icell = 8, 13000.0, 0

h("""
load_file("stdlib.hoc")
load_file("stdrun.hoc")
objref cvode           
cvode = new CVode()    
cvode.active(1)        
cvode.atol(0.00005)    
load_file("import3d.hoc")
objref L5PC      
load_file(\"hay/L5PCbiophys3.hoc\")
load_file(\"hay/L5PCtemplate.hoc\")
L5PC = new L5PCtemplate(\"hay/cell1.asc\")
access L5PC.soma
""")

# Radius of the Gaussian colour falloff and of the reference circle, in the
# morphology's x/y units.  The first command-line argument overrides it; it is
# parsed as a float and truncated to int (the value also appears in the output
# file names).
radius = 45
if len(sys.argv) > 1:
  radius = int(float(sys.argv[1]))

# Reference point (x, y) from which every drawn piece's distance is measured.
# NOTE(review): hard-coded as (0, 650); presumably a point on the apical
# dendrite given the output file name -- confirm.
centre_x, centre_y = 0, 650

close("all")
f, axarr = subplots(1, 1)
# Every primitive drawn is also recorded as [xs, ys, fmt, linewidth, colour],
# so the figure can be re-created from the pickled list.
plotteds = []

# Mark the reference point and draw a closed 51-vertex circle around it.
circle_x = [centre_x + radius*cos(2*pi*x/50) for x in range(0, 51)]
circle_y = [centre_y + radius*sin(2*pi*x/50) for x in range(0, 51)]
axarr.plot(centre_x, centre_y, 'bx', lw=0.5, ms=4)
axarr.plot(circle_x, circle_y, 'b-', lw=0.5)
plotteds.append([[centre_x], [centre_y], 'bx', 0.5, '#0000FF'])
plotteds.append([circle_x, circle_y, 'b-', 0.5, '#0000FF'])

dists = []    # h.distance(0.5) of every visited section
dist2s = []   # squared x/y distance from each drawn piece's midpoint to the centre
mynums0 = []  # NOTE(review): not used in this script; kept for compatibility
mynums1 = []  # NOTE(review): not used in this script; kept for compatibility

for itree in range(0, 3):
  # Tree 0: L5PC.dend sections, tree 1: L5PC.apic sections, tree 2: the soma.
  if itree == 0:
    nsec = len(h.L5PC.dend)
  elif itree == 1:
    nsec = len(h.L5PC.apic)
  else:
    nsec = 1

  for j in range(nsec-1, -1, -1):
    # Make section j the currently accessed section; the argument-less hoc
    # calls below (distance, diam, n3d, x3d, ...) act on that section.
    if itree == 0:
      h("access L5PC.dend[%d]" % j)
    elif itree == 1:
      h("access L5PC.apic[%d]" % j)
    else:
      h("access L5PC.soma")
    dists.append(h.distance(0.5))

    # diam does not depend on the 3-D point index, so read it once per section.
    h("mydiam = diam")
    linewidth = h.mydiam*0.25

    npts = int(h.n3d())
    oldcoord = [h.x3d(0), h.y3d(0), h.z3d(0)]
    for k in range(1, npts):
      newcoord = [h.x3d(k), h.y3d(k), h.z3d(k)]

      # Midpoint (x/y only) of the piece between consecutive 3-D points.
      mid_x = 0.5*(newcoord[0] + oldcoord[0])
      mid_y = 0.5*(newcoord[1] + oldcoord[1])
      dist2s.append((mid_x - centre_x)**2 + (mid_y - centre_y)**2)

      # Gaussian falloff exp(-d^2 / (2 r^2)).  Written with -0.5: the original
      # "-1/2" is integer division (== -1) under Python 2.
      coldist = exp(-0.5*dist2s[-1]/radius**2)
      # Green channel 0..255: pieces near the centre are bright green.
      mynum = int(255.999*coldist)
      col = "#00%02x00" % mynum

      xs = [oldcoord[0], newcoord[0]]
      ys = [oldcoord[1], newcoord[1]]
      # Plot with fmt '-' because color= sets the colour; 'k-' together with
      # color= triggers a matplotlib warning.  The recorded entry keeps 'k-' so
      # existing readers of the pickled list see the same data as before.
      axarr.plot(xs, ys, '-', linewidth=linewidth, color=col)
      plotteds.append([xs, ys, 'k-', linewidth, col])
      oldcoord = newcoord

axis("equal")
f.savefig("morph_distfromapicalbranchcolor_radius"+str(radius)+".pdf")
  
file = open('morph_distfromapicalbranchcolor_radius'+str(radius)+'.sav', 'wb')
pickle.dump(plotteds,file)
file.close()