from numpy import *
import json, sys, os, time
import lzma as xz
import matplotlib
matplotlib.rcParams["savefig.directory"] = ""
from matplotlib.pyplot import *
from simtoolkit import tree
from simtoolkit import data as stkdata
from optparse import OptionParser
# Command-line interface.  Positional arguments are the *.stkdata.xz input
# files (ignored when -l supplies the file list from a JSON dump).
oprs = OptionParser("USAGE: %prog [flags] file.stkdata.xz [file.stkdata.xz ...]")
oprs.add_option("-s", "--save-json", dest="sjson", default=None,
                help="save into json", type="str")
oprs.add_option("-l", "--load-json", dest="ljson", default=None,
                help="load from json", type="str")
oprs.add_option("-q", "--recompute", dest="recomp", default=False,
                help="Recompute correlation", action="store_true")
# os.cpu_count() returns None when the core count cannot be determined;
# fall back to a single (serial) worker in that case.
oprs.add_option("-B", "--num-of-cores", dest="ncores", default=os.cpu_count() or 1,
                help="Use number of cores", type="int")
oprs.add_option("-R", "--re-normalize", dest="renorm", default=False,
                help="Renormalize g_inh by the final g_exc conductance",
                action="store_true")
oprs.add_option("-X", "--use-max", dest="usemax", default=False,
                help="Use max instead of mean", action="store_true")
oprs.add_option("-M", "--estimate-FR", dest="estFR", default=False,
                help="Estimate FR", action="store_true")
#-------
# NOTE(review): the triple-quoted string below is disabled code stored as a
# bare string literal; it is never executed.  It sketches a separate mode that
# adds a -o option, and for every input .stkdata.xz copies "/hash" plus every
# record whose name contains that hash into one output stkdata file.
# If revived: it calls parse_args() itself, and `print(f,print(hashline),...)`
# prints the hash on its own line and then "<f> None is copied".
"""
oprs.add_option( "-o" , dest="output" , default=None , help="output file" , type ="str" )
opts, args = oprs.parse_args()
with stkdata(opts.output,mode='w') as wsd:
for f in args:
with xz.open(f) as fd:
with stkdata(fd,dtype='stkdata',mode='ro') as sd:
hashline = sd["/hash",-1]
wsd["/hash"] = hashline
wsd[f"/{hashline}/srcfile"] = f
for n in sd:
if not hashline in n: continue
for ch in sd[n]: wsd[n] = ch
print(f,print(hashline),'is copied')
"""
#-------
def getcor(sig2gsyn, trngsyn, trndly, t2r, ncells, spikes):
    """Compute pairwise correlation statistics of filtered spike trains.

    Spikes are binned into 1-ms bins (one time unit per bin) giving a
    (cell, bin) count matrix.  The first 10 bins are zeroed.  Each cell's row
    is filtered with a difference-of-Gaussians kernel spanning +-600 bins:
    exp(-t**2/20**2) minus exp(-t**2/80**2), each normalised to unit sum.
    Pearson correlations are then taken over every distinct cell pair.

    Parameters
    ----------
    sig2gsyn, trngsyn, trndly, t2r :
        Parameter values copied unchanged into the result.
    ncells : int
        Number of cells; spike cell indices must round into [0, ncells).
    spikes : sequence of (time, cell) pairs
        Must be non-empty.  Times and cell indices are rounded to the
        nearest integer (half to even).

    Returns
    -------
    str
        JSON list [sig2gsyn, trngsyn, trndly, t2r, meancor, maxcor], where
        meancor is the mean pairwise correlation (non-finite correlations,
        e.g. involving a silent cell, count as 0) and maxcor is the centre
        of the most populated bin of a 201-bin histogram over [-1, 1].
        With fewer than two cells both values are NaN.
    """
    sp = asarray(spikes, dtype=float)

    # Difference-of-Gaussians ("Mexican hat") kernel.
    lags = arange(-600., 601., 1)
    narrow = exp(-lags**2 / 20.**2)
    narrow /= narrow.sum()
    wide = exp(-lags**2 / 80.**2)
    wide /= wide.sum()
    kernel = narrow - wide

    # Spike counts per (cell, bin).  rint rounds half to even, matching
    # Python's round(); add.at accumulates repeated (cell, bin) hits.
    nbins = int(ceil(amax(sp[:, 0])) + 1)
    counts = zeros((ncells, nbins))
    add.at(counts, (rint(sp[:, 1]).astype(int), rint(sp[:, 0]).astype(int)), 1)
    counts[:, 0:10] = 0.  # drop the first 10 bins (presumably an onset transient)

    # Centred, input-length slice of the full convolution.  This equals
    # numpy's mode='same' for long rows, but numpy's 'same' returns
    # max(len(row), len(kernel)) samples, which would not fit into the row
    # when the recording is shorter than the kernel.
    off = (kernel.shape[0] - 1) // 2
    filtered = zeros(counts.shape)
    for n in range(ncells):
        filtered[n, :] = convolve(counts[n, :], kernel, mode='full')[off:off + nbins]

    # Silent cells have zero variance and yield NaN correlations; these are
    # zeroed below, so the division warnings are suppressed here.
    with errstate(invalid='ignore', divide='ignore'):
        cmat = atleast_2d(corrcoef(filtered))
    # Upper triangle without the diagonal, in (0,1), (0,2), ..., (1,2) order.
    xcor = cmat[triu_indices(ncells, k=1)]
    xcor[~isfinite(xcor)] = 0
    meancor = mean(xcor)

    hist, edges = histogram(xcor, bins=201, range=(-1., 1.))
    hist = hist / hist.sum()
    centres = (edges[:-1] + edges[1:]) / 2.
    maxcor = centres[argmax(hist)]
    return json.dumps([sig2gsyn, trngsyn, trndly, t2r, meancor, maxcor])
def _last_json_line(path):
    """Return the JSON value on the last non-blank line of the file `path`.

    Raises ValueError when the file has no non-blank line.
    """
    last = None
    with open(path) as fd:
        for line in fd:
            if line.strip():
                last = line
    if last is None:
        raise ValueError(f"{path} contains no JSON record")
    return json.loads(last)


def readFR(f):
    """Return the ratio of mean block-TRN firing rate to mean firing rate.

    For an input ``<name>.stkdata.xz`` this reads the last non-blank JSON
    line of ``<name>-FR.json`` and ``<name>-blockTRN-FR.json``, averages each
    over all its elements, and returns blockTRN_mean / FR_mean (0 when
    FR_mean is 0).

    Returns None (after printing a message) when `f` does not end in
    '.stkdata.xz'.  Raises OSError if a companion file is missing and
    ValueError if it holds no JSON record.
    """
    suffix = '.stkdata.xz'
    if not f.endswith(suffix):
        print(f"input file {f} doesn't seem like xz commpressed stkdata")
        return None
    fname = f[:-len(suffix)]
    meanFR = mean(array(_last_json_line(fname + '-FR.json')))
    meanTRNBLKFR = mean(array(_last_json_line(fname + '-blockTRN-FR.json')))
    if meanFR == 0:
        return 0
    return meanTRNBLKFR / meanFR
# 131->363 (1-141)
def readndata(f, readspikes=False):
    """Read the parameters and results of one xz-compressed stkdata file.

    Entries are read under the record prefix stored in '/hash'.  t2r is
    |TRN gsyn| divided by the mean of the last stored '/gsyn' record.
    meancor and maxcor come from the last '/CorrDist/LGN' histogram: the
    weighted mean of column 0 by column 1, and the column-0 value at the
    largest column-1 weight.

    Returns
    -------
    list or None
        readspikes=True : [sig2gsyn, trngsyn, trndly, t2r, ncells, spikes],
                          with spikes being all '/spikes' chunks
                          concatenated along axis 0, as nested lists.
        otherwise       : [sig2gsyn, trngsyn, trndly, t2r, meancor, maxcor],
                          plus readFR(f) appended when -M (opts.estFR) is set.
        None if the file cannot be opened or any needed entry is missing
        (a message is written to stderr).
    """
    print(f"Reading {f}")
    try:
        with xz.open(f, 'r') as fd, stkdata(fd, mode="ro", dtype='stkdata') as sd:
            hashline = sd['/hash', -1]
            ncells = sd[f'/{hashline}/n-neurons', -1]
            model = tree().imp(sd[f'/{hashline}/model', -1])
            sig2gsyn = model['/network/syn/sig2gsyn']
            trndly = model['/network/trn/delay']
            trngsyn = abs(model['/network/trn/gsyn'])
            wsyn1 = array(sd[f"/{hashline}/gsyn", -1])
            t2r = trngsyn / mean(wsyn1)
            correl = sd[f"/{hashline}/CorrDist/LGN", -1]
            meancor = average(correl[:, 0], weights=correl[:, 1])
            maxcor = correl[argmax(correl[:, 1]), 0]
            if readspikes:
                # Collect all chunks first and join once (linear, unlike
                # repeated append); must happen while the file is open.
                chunks = list(sd[f"/{hashline}/spikes"])
                if not chunks:
                    raise ValueError("no spike records")
                xspikes = concatenate(chunks, axis=0).tolist()
    except Exception as e:
        # Unreadable/incomplete files are skipped; KeyboardInterrupt and
        # SystemExit are deliberately not caught.
        sys.stderr.write(f"Cannot read {f} : {e}\n")
        return None
    if readspikes:
        return [sig2gsyn, trngsyn, trndly, t2r, ncells, xspikes]
    if opts.estFR:
        return [sig2gsyn, trngsyn, trndly, t2r, meancor, maxcor, readFR(f)]
    return [sig2gsyn, trngsyn, trndly, t2r, meancor, maxcor]
def worker(f):
    """Recompute correlation statistics for one input file (the -q task).

    Reads the spike trains with readndata(..., readspikes=True) and passes
    them to getcor().  With -M (opts.estFR) the readFR() ratio is appended.
    Returns the record as a JSON string, or None if the file was unreadable.
    """
    data = readndata(f, readspikes=True)
    if data is None:
        return None
    encoded = getcor(*data)
    if not opts.estFR:
        return encoded
    record = json.loads(encoded)
    record.append(readFR(f))
    return json.dumps(record)
# Gather one record per input file: from a JSON dump (-l), from the stored
# results in each file, or recomputed from raw spikes (-q).
opts, args = oprs.parse_args()
recs = []
if opts.ljson is not None:
    with open(opts.ljson) as fd:
        j = json.load(fd)
    recs = j['recs']
    args = j['files']
elif not args:
    print("Need input!")
    sys.exit(1)
elif not opts.recomp:
    for f in args:
        r = readndata(f)
        if r is not None:
            recs.append(r)

if opts.recomp:
    # worker() returns a JSON string per file, or None if unreadable.
    ncores = opts.ncores or 1
    if ncores < 2:
        raw = [worker(f) for f in args]
    else:
        import multiprocessing as mp
        # worker() reads the module-level `opts` and this script has no
        # __main__ guard, so prefer 'fork': 'spawn'/'forkserver' children
        # re-import this file and would re-run the whole script.
        ctx = mp.get_context('fork') if 'fork' in mp.get_all_start_methods() else mp
        with ctx.Pool(processes=ncores) as pool:
            raw = pool.map(worker, args, chunksize=1)
    # Parse in both the serial and parallel paths; skip unreadable files.
    recs = [json.loads(r) for r in raw if r is not None]

if opts.sjson is not None:
    stamp = time.strftime("%Y-%m-%d %H-%M-%S") if opts.recomp else None
    with open(opts.sjson, 'w') as fd:
        fd.write("{\n\t\"Recomputed\" : " + json.dumps(stamp) + ",\n")
        fd.write("\t\"files\" : " + json.dumps(args) + ",\n")
        fd.write("\t\"recs\" : " + json.dumps(recs) + "\n")
        fd.write("}\n")

if not recs:
    sys.stderr.write("No usable records to plot\n")
    sys.exit(1)
# Record columns: 0 sig2gsyn, 1 TRN gsyn, 2 TRN delay, 3 t2r,
# 4 mean correlation, 5 histogram-peak correlation, 6 FR ratio (-M only).
recs = array(recs)
u_s2g = unique(recs[:, 0])
u_syn = unique(recs[:, 1])
u_dly = unique(recs[:, 2])
u_t2r = unique(recs[:, 3])
print(u_s2g)
print(u_syn)
print(u_dly)
print(u_t2r)
cmap = get_cmap('coolwarm')  # signed correlations (diverging)
smap = get_cmap('rainbow')   # FR ratios
dmap = get_cmap('Greys')     # standard deviations
# Symmetric colour limit from the correlation column that is actually
# plotted (-X selects column 5, otherwise column 4).
corminmax = amax(abs(recs[:, 5 if opts.usemax else 4]))
# ---------------------------------------------------------------------------
# Plotting.  Record columns: 0 sig2gsyn, 1 TRN gsyn, 2 TRN delay, 3 t2r,
# 4 mean correlation, 5 histogram-peak correlation, 6 FR ratio (-M only).
# ---------------------------------------------------------------------------
corcol = 5 if opts.usemax else 4
# Level contoured on the FR-ratio maps.
# NOTE(review): the meaning of 2.3 is not documented here -- confirm.
FR_CONTOUR_LEVEL = 2.3


def _log_axes():
    """Put both axes of the current subplot on a log scale."""
    yscale('log')
    xscale('log')


def _log_edges(values):
    """Return len(values)+1 pcolormesh cell edges for sorted positive values.

    Edges lie halfway between neighbouring values in log10 space and half a
    spacing beyond the first and last value, so each cell is centred (on a
    log axis) on its value.  A single value gets a cell one decade wide.
    """
    lv = log10(values)
    if lv.shape[0] < 2:
        return 10 ** array([lv[0] - 0.5, lv[0] + 0.5])
    mid = (lv[:-1] + lv[1:]) / 2.
    first = lv[0] - (mid[0] - lv[0])
    last = lv[-1] + (lv[-1] - mid[-1])
    return 10 ** concatenate(([first], mid, [last]))


def _scatter_panel(pos, sel, col, vmin, vmax, cm, label):
    """Scatter delay (x) vs t2r (y) for rows `sel`, coloured by column `col`.

    Draws panel `pos` of a 1 x nr grid.
    """
    subplot(1, nr, pos)
    title(label)
    h = scatter(recs[sel, 2], recs[sel, 3], c=recs[sel, col],
                vmin=vmin, vmax=vmax, cmap=cm)
    colorbar(h)
    _log_axes()


def _map_panel(pos, data, vmin, vmax, cm, label=None, cdata=None):
    """Draw `data` (indexed [gsyn, delay]) in panel `pos` of a 4 x nr grid.

    If `cdata` is given it is contoured at FR_CONTOUR_LEVEL over the map
    (skipped when the grid is too small to contour).
    """
    subplot(4, nr, pos)
    if label is not None:
        title(label)
    h = pcolormesh(a_dly, a_syn, data, shading='flat', vmin=vmin, vmax=vmax, cmap=cm)
    colorbar(h)
    if cdata is not None and min(cdata.shape) > 1:
        contour(u_dly, u_syn, cdata, [FR_CONTOUR_LEVEL])
    _log_axes()


figure(1)
if opts.renorm:
    # Scatter view, one column per sig2gsyn value (plus FR columns with -M).
    nsig = u_s2g.shape[0]
    nr = nsig * 2 if opts.estFR else nsig
    for i, c in enumerate(u_s2g):
        _scatter_panel(i + 1, recs[:, 0] == c, corcol, -corminmax, corminmax, cmap, f'{c}')
    if opts.estFR:
        maxfrr = nanmax(recs[:, 6])
        for i, c in enumerate(u_s2g):
            _scatter_panel(i + 1 + nsig, recs[:, 0] == c, 6, 0, maxfrr, smap, f'{c}')
else:
    # Grid view: (TRN gsyn x delay) maps of [mean, std, min, max] over the
    # repeated runs of each cell, one column per sig2gsyn value.
    a_syn = _log_edges(u_syn)
    a_dly = _log_edges(u_dly)
    cormaps = []
    frmaps = []
    for c in u_s2g:
        # Cells without any run stay NaN and are left blank by pcolormesh
        # instead of being shown as a real value of 0.
        cormap = full((u_syn.shape[0], u_dly.shape[0], 4), nan)
        frmap = full(cormap.shape, nan)
        for i, s in enumerate(u_syn):
            for j, d in enumerate(u_dly):
                sel = (recs[:, 0] == c) & (recs[:, 1] == s) & (recs[:, 2] == d)
                if not sel.any():
                    print(f"SIGMA:{c}, Gsyn:{s}, Delay:{d} : Index size zero")
                    continue
                cors = recs[sel, corcol]
                cormap[i, j, :] = [mean(cors), std(cors), amin(cors), amax(cors)]
                if opts.estFR:
                    frs = recs[sel, 6]
                    frmap[i, j, :] = [mean(frs), std(frs), amin(frs), amax(frs)]
        cormaps.append([c, cormap])
        frmaps.append(frmap)
    nr = len(cormaps) * 2 if opts.estFR else len(cormaps)
    # Symmetric limit from the largest |mean correlation|, so vmin <= vmax
    # even when all correlations are negative.
    maxcor = max([nanmax(abs(c[:, :, 0])) for _, c in cormaps])
    maxcst = max([nanmax(c[:, :, 1]) for _, c in cormaps])
    for i, (c, cormap) in enumerate(cormaps):
        _map_panel(i + 1, cormap[:, :, 0], -maxcor, maxcor, cmap, label=f'{c}')
        _map_panel(i + 1 + nr, cormap[:, :, 1], 0, maxcst, dmap)
        _map_panel(i + 1 + nr * 2, cormap[:, :, 2], -maxcor, maxcor, cmap)
        _map_panel(i + 1 + nr * 3, cormap[:, :, 3], -maxcor, maxcor, cmap)
    if opts.estFR:
        maxfrr = nanmax(recs[:, 6])
        maxfst = max([nanmax(f[:, :, 1]) for f in frmaps])
        sps = len(frmaps)
        for i, frmap in enumerate(frmaps):
            # The std panel is overlaid with the contour of the mean map.
            _map_panel(i + 1 + sps, frmap[:, :, 0], 0, maxfrr, smap, cdata=frmap[:, :, 0])
            _map_panel(i + 1 + sps + nr, frmap[:, :, 1], 0, maxfst, dmap, cdata=frmap[:, :, 0])
            _map_panel(i + 1 + sps + nr * 2, frmap[:, :, 2], 0, maxfrr, smap, cdata=frmap[:, :, 2])
            _map_panel(i + 1 + sps + nr * 3, frmap[:, :, 3], 0, maxfrr, smap, cdata=frmap[:, :, 3])
show()