# code by sam neymotin & ernie forzano
from neuron import h
h.load_file("stdrun.hoc")
from pylab import *
import sys
import pickle
import numpy
h.install_vecst() # for samp and other NQS/vecst functions
from conf import *
import os
from scipy.stats.stats import pearsonr
from utils import dtrans
import shutil
ion()
rcParams['lines.markersize'] = 15
rcParams['lines.linewidth'] = 4
tl = tight_layout
useRMP = False # True # use RMP for fitness calculation?
useVoltDiff = False
useISI = False # this is for evaluation of full isi voltage
useISIFeat = False # this is for evaluation of isi voltage features
useISIDepth = False # this is for evaluation of isi voltage depth (min voltage)
useISIDur = False # this is for evaluation of isi voltage duration
useSag = False # whether to use sag for fitness
useSpikeAmp = False # spike amplitude (peak - treshold voltage) - do not need when using SpikeThresh and SpikePeak
useSpikePeak = False # spike peak (absolute voltage)
useSpikeW = False # spike widths at 25% and 50%
useSpikeSlope = False # min,max dv/dt
useSpikeThresh = False # spike threshold voltage
useSpikeShape = False # overall spike shape - uses features (peak,width,slope,thresh)
useSpikeTimes = useSpikeCoinc = False
useSFA = False # spike-frequency adaptation measure
useLVar = False
useInstRate = False
useTTFS = False # use time-to-first-spike for fitness
#
def getfitdims ():
fitdims = []
if useRMP: fitdims.append('RMP')
if useSag: fitdims.append('Sag')
if useFI: fitdims.append('FI')
if useISI: fitdims.append('ISIVolt')
if useISIFeat: fitdims.append('ISIFeat')
if useSFA: fitdims.append('SFA')
if useLVar: fitdims.append('LVar')
if useInstRate: fitdims.append('InstRate')
if useTTFS: fitdims.append('TTFS')
if useSpikeTimes: fitdims.append('SpikeTimes')
if useSpikeCoinc: fitdims.append('SpikeCoinc')
if useSpikeAmp: fitdims.append('SpikeAmp')
if useSpikePeak: fitdims.append('SpikePeak')
if useSpikeW: fitdims.append('SpikeW')
if useSpikeSlope: fitdims.append('SpikeSlope')
if useSpikeThresh: fitdims.append('SpikeThresh')
if useSpikeShape: fitdims.append('SpikeShape')
if useVoltDiff: fitdims.append('VoltDiff')
if useISIDepth: fitdims.append('ISIDepth')
if useISIDur: fitdims.append('ISIDur')
return fitdims
# determine config file name
def setfcfg ():
fcfg = "PTcell.BS0284.cfg" # default config file name
for i in range(len(sys.argv)):
if sys.argv[i].endswith(".cfg") and os.path.exists(sys.argv[i]):
fcfg = sys.argv[i]
#print "config file is " , fcfg
return fcfg
dmod = {}
fcfg=setfcfg() # config file name
dconf = readconf(fcfg)
dprm = dconf['params']
dfixed = dconf['fixed']
sampr = dconf['sampr'] # sampling rate
I = numpy.load(dconf['lstimamp'])
evolts = numpy.load(dconf['evolts']) # experimental voltage traces
tte = linspace(0, 1e3*evolts.shape[0]/sampr, evolts.shape[0])
evolts = numpy.load(dconf['evolts']) # experimental voltage traces
useFI=useInstRate=useISI=useSpikeShape=useVoltDiff=True
fitdims=getfitdims()
#
def geterramp (nqa,row,lc):
err = 0.0
for c in lc:
if nqa.fi(c) != -1:
err += (nqa.getcol(c).x[row] / nqa.getcol(c).mean())**2
return sqrt(err)
#
def adderrampcol (nqa,lc):
nqa.tog('DB')
if nqa.fi('erramp')== -1.0: nqa.resize('erramp'); nqa.pad()
for i in range(int(nqa.v[0].size())): nqa.getcol('erramp').x[i] = geterramp(nqa,i,lc)
nqa.stat('erramp') #
# convert population to NQS
def pop2nq (fpop,fitdims=None):
if fitdims == None: fitdims=getfitdims()
nqa = None
try:
nqa = h.NQS()
except:
h.load_file("nqs.hoc"); #h.load_file("decnqs.hoc")
nqa = h.NQS()
# first setup the fitness dimensions
for s in fitdims: nqa.resize(s)
nqa.clear(len(fpop))
for m in fpop:
fit = m.fitness
for i,val in enumerate(fit): nqa.v[i].append(val)
# then setup the parameter values
for k in list(dprm.keys()): nqa.resize(k)
nqa.pad()
for i,m in enumerate(fpop):
idx = len(fitdims)
prm = m.candidate
jdx = idx; kdx = 0
while jdx < nqa.m[0]:
nqa.v[jdx].x[i] = prm[kdx]
jdx += 1; kdx += 1
adderrampcol(nqa,fitdims)
return nqa
# print out param values (nqa is table, idx is row)
def rowprmstr (nq,idx):
s = ''
for i in range(len(fitdims),int(nq.m[0]),1): s += str(nq.v[i].x[idx]) + ' '
return s
# loads model archive and stores in global ark and nqa objects
def loadark (fn):
global ark,nqa
ark = pickle.load(open(fn))
print(len(ark), ' models in ', fn, ' archive.')
nqa = pop2nq(ark,fitdims)
if fcfg == 'SPI6.cfg': # simplified model
useVoltDiff=useFI=useInstRate=useSpikeW=useSpikeSlope=useSpikeThresh=useSpikePeak=useISI=True
useSpikeShape=False
fitdims=getfitdims() # reset fitness dimensions(fitdims), which differ from detailed model
loadark(os.path.join('data','simparch.pkl')) # load simple model archive
else: # detailed model
loadark(os.path.join('data','detarch.pkl')) # load detailed model archive
# add text to a plot
def addtext (row,col,lgn,ltxt,tx=-0.025,ty=1.03,c='k'):
for gn,txt in zip(lgn,ltxt):
ax = subplot(row,col,gn)
text(tx,ty,txt,fontweight='bold',transform=ax.transAxes,color=c);
def naxbin (ax,nb): ax.locator_params(nbins=nb);
# print full row (fitness and param values) at the given row (idx) from table (nqa)
def rowstr (nq,idx):
s = ''
for i in range(int(nq.m[0])): s += nq.s[i].s + ':' + str(nq.v[i].x[idx]) + "\n"
return s
# print param values at the given row (idx) from table (nqa)
def rowprmvals (nq,idx):
lval = []
for i in range(len(fitdims),int(nq.m[0]),1): lval.append((nq.v[i].x[idx]))
return lval
# find index of f in a (if not there return -1)
def indexof (a,f):
for i,val in enumerate(a):
if abs(val-f) < 0.01: return i
return -1
ISubth = I[0:6] # subthreshold current injections
ISup = I[6:] # current injections for subthresh right before threshold & superthreshold traces
IAll = list(ISubth); IAll.extend(list(ISup))
# draw traces from experiment (uses black color)
def drawexptraces ():
tx,ty=-.05,1.02; offy = amin(tte[0]) - 30
ax=gca(); ax.set_xticks([]); ax.set_yticks([]);
plot([1420,1520],[590,590],'k',linewidth=4)
plot([1520,1520],[580,590],'k',linewidth=4)
ypos = offy
for j,i in enumerate(IAll):
idx = indexof(I,i)
plot(tte,evolts[:,idx] + ypos,'k')
if j > len(ISubth): ypos += 95
else: ypos += 15
cdx=0 # index into color list
# draw traces from the model (cycles through colors)
def drawtraces (model):
global cdx
lclr = ['r','g','b','c','m','y']
tt = numpy.array(dmod[model]['vt'])
tx,ty=-.05,1.02; offy = amin(tt[0]) - 30
if len(get_fignums())==0: drawexptraces()
mdx=0; m=model
ax=gca()
ypos = offy
for j,i in enumerate(IAll):
plot(tt, dmod[m][i] + ypos,lclr[cdx%len(lclr)])
if j > len(ISubth): ypos += 95
else: ypos += 15
ax.set_xticks([]); ax.set_yticks([]);
xlim((400,1600));
ylim((-125,680));
cdx+=1
# run model idx using params in ark/nqa, then load/draw the data
def runmodel (idx):
global lastmodel
# should move pkl file to arch index location so dont have to rerun
fnew = os.path.join('data', fcfg.split('.cfg')[0] + '_' + str(idx) + '.pkl')
if os.path.exists(fnew):
print('model ' + str(idx) + ' already ran, data in', fnew)
else:
cmd = 'python sim.py ' + fcfg + ' ' + rowprmstr(nqa,idx)
print(cmd)
os.system(cmd)
if fcfg.startswith('PTcell'):
shutil.move(os.path.join('data','morph.pkl'),fnew)
else:
shutil.move(os.path.join('data','SPI6.pkl'),fnew)
if not os.path.exists(fnew):
print('ERROR: could not run model!')
return
lastmodel = (fcfg,idx)
dmod[lastmodel] = pickle.load(open(fnew)) # load the data
print('model fitness error/params:', rowstr(nqa,idx))
drawtraces((fcfg,idx))
#
def drtxt (ax,lett,tx=-0.075,ty=1.03,fsz=45): text(tx,ty,lett,fontweight='bold',transform=ax.transAxes,fontsize=fsz)
# draw archive figure showing param values of bottom/top percentiles
def drawarchfig ():
if fcfg == 'SPI6.cfg':
lprm = ['SPI6.gbar_kdmc','SPI6.cal_gcalbar','SPI6.can_gcanbar','SPI6.kBK_gpeak','SPI6.gbar_kap','SPI6.gbar_kdr','SPI6.gbar_nax','SPI6.kBK_caVhminShift','SPI6.cadad_taur','SPI6.cadad_depth','h.vhalfn_kdr','h.vhalfn_kap','h.vhalfl_kap','h.tq_kap']
else:
lprm = ['morph.nax_gbar', 'morph.kdmc_gbar','morph.kdr_gbar','morph.kap_gbar','morph.kBK_gpeak','morph.kBK_caVhminShift','morph.cal_gcalbar','morph.can_gcanbar','morph.cadad_taur','morph.cadad_depth']
draw1dfig(nqa,'erramp',0.01,lprm,nrow=2,ncol=2,gdx=1,stxt='a')
xlim((0.5,10.5)); ylim((-3,4.5))
mbotAMP,mtopAMP = getprct(nqa,'erramp',0.01,lprm)
mcAMP = getprmcors(nqa,'erramp',0.01,lprm)
ax = subplot(2,2,2)
imshow(mcAMP,interpolation='None',origin='lower',aspect='auto',extent=(0,mcAMP.shape[0]-1,0,mcAMP.shape[0]-1))
colorbar(); ax.set_xticks([]); ax.set_yticks([])
mytxt = 'Worst Best'; xlabel(mytxt); ylabel(mytxt);
text(-0.025,1.03,'b',fontweight='bold',transform=ax.transAxes,color='k');
title('Parameter correlations')
draw1dfig(nqa,'FI',0.01,lprm,nrow=2,ncol=2,gdx=3,stxt='c')
xlim((0.5,10.5)); ylim((-3,4.5))
mbotFI,mtopFI = getprct(nqa,'FI',0.01,lprm)
mcFI = getprmcors(nqa,'FI',0.01,lprm)
ax = subplot(2,2,4)
imshow(mcFI,interpolation='None',origin='lower',aspect='auto',extent=(0,mcFI.shape[0]-1,0,mcFI.shape[0]-1))
colorbar(); ax.set_xticks([]); ax.set_yticks([])
mytxt = 'Worst Best'; xlabel(mytxt); ylabel(mytxt);
text(-0.025,1.03,'d',fontweight='bold',transform=ax.transAxes,color='k');
title('Parameter correlations')
subplot(2,2,1); title('Rank by Error Amplitude');
subplot(2,2,3); title('Rank by FI Error')
#
def draw1dfig (nq,scc,prct,lprm,nrow=1,ncol=1,gdx=1,stxt='a'):
tx,ty=-0.025,1.03;
nqt = h.NQS()
nqt.cp(nq)
nqt.sort(scc)
botsidx,boteidx = 0,int(prct*nqt.v[0].size()) # good
topsidx,topeidx = int(nqt.v[0].size()*(1.0-prct)),int(nqt.v[0].size()-1) # bad
ax = subplot(nrow,ncol,gdx)
for pdx,prm in enumerate(lprm):
dat = numpy.array(nqt.getcol(prm).to_python())
dat = dat - mean(dat)
dat = dat / std(dat)
plot([pdx+1 for j in range(boteidx-botsidx)],dat[botsidx:boteidx],'^',markeredgecolor='m',markerfacecolor='none',markersize=60,linewidth=8)
plot([pdx+1 for j in range(topeidx-topsidx)],dat[topsidx:topeidx],'v',markeredgecolor='c',markerfacecolor='none',markersize=60,linewidth=8)
ax.set_xticklabels([dtrans[prm] for prm in lprm])
ax.set_xticks(linspace(1,len(lprm),len(lprm)))
ylabel('Normalized parameter value'); #ylim((-4.2,4.2))
text(tx,ty,stxt,fontweight='bold',transform=ax.transAxes,color='k');
h.nqsdel(nqt)
# get bottom/top percentile from nq using column scc
def getprct (nq,scc,prct,lprm):
nqt = h.NQS()
nqt.cp(nq)
nqt.sort(scc)
botsidx,boteidx = 0,int(prct*nqt.v[0].size()) # good
topsidx,topeidx = int(nqt.v[0].size()*(1.0-prct)),int(nqt.v[0].size()-1) # bad
mtop = zeros((topeidx-topsidx,len(lprm)))
mbot = zeros((boteidx-botsidx,len(lprm)))
for pdx,prm in enumerate(lprm):
dat = numpy.array(nqt.getcol(prm).to_python())
dat = dat - mean(dat)
dat = dat / std(dat)
mbot[:,pdx] = dat[botsidx:boteidx]
mtop[:,pdx] = dat[topsidx:topeidx]
h.nqsdel(nqt)
return mbot,mtop
# get parameter correlations
def getprmcors (nq,scc,prct,lprm):
mbot,mtop = getprct(nq,scc,prct,lprm)
nrow,ncol = mbot.shape
mprct = zeros((nrow*2,ncol))
mprct[0:nrow,:] = mbot
mprct[nrow:,:] = mtop
mc = ones((nrow*2,nrow*2))
for i in range(nrow*2):
for j in range(i+1,nrow*2,1):
mc[i,j]=mc[j,i]=pearsonr(mprct[i,:],mprct[j,:])[0]
return mc