import numpy as np
import matplotlib.pyplot as plt
import random
import pandas as pd
import sys
import scipy.stats as stats
import matplotlib.cm as cm
import matplotlib as mpl
# Global plotting defaults: short error-bar caps, thin lines, and TrueType
# (type 42) fonts embedded in PDF output.
mpl.rcParams.update({
    "errorbar.capsize": 2,
    "lines.linewidth": 1,
    "pdf.fonttype": 42,
})
# Print arrays in full rather than summarising them with "...".
np.set_printoptions(threshold=999999999999999)
def trevrolls(frates):
    """Return the Treves-Rolls sparsity ``1 - mean(r)**2 / mean(r**2)``.

    The result is 0 when every rate is equal and grows towards 1 as activity
    concentrates in fewer units.

    Parameters
    ----------
    frates : array_like
        Firing rates (or spike counts). Any shape is accepted; the values
        are flattened before the statistic is computed.

    Returns
    -------
    float
        The sparsity. Returns ``nan`` when it is undefined, i.e. for an
        empty input or when every rate is zero (``mean(r**2) == 0``).
    """
    rates = np.asarray(frates, dtype=float).ravel()
    if rates.size == 0:
        return float("nan")
    mean_sq = np.mean(rates ** 2)
    if mean_sq == 0.0:
        # A silent population has no defined sparsity; avoid 0/0.
        return float("nan")
    return float(1.0 - np.mean(rates) ** 2 / mean_sq)
def loadspikesdat(filename, tduration):
    """Load a spike-index text file into a binary (line x time-bin) raster.

    Each line of *filename* becomes one row of the raster. The line holds
    whitespace-separated integer column indices, and each listed column is
    set to 1.

    Parameters
    ----------
    filename : str
        Path of the text file to read.
    tduration : int
        Number of columns (time bins) in the raster. It is converted with
        ``int()``. An index >= tduration raises ``IndexError``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(number_of_lines, tduration)`` containing 0/1.
    """
    with open(filename, 'r') as ff:
        lines = ff.readlines()
    raster = np.zeros((len(lines), int(tduration)))
    for nid, line in enumerate(lines):
        ar = np.fromstring(line, sep=' ', dtype=int)
        raster[nid, ar] = 1
        # Column 0 is always cleared. This is kept from the original
        # "XXX bug" workaround, so a genuine entry at bin 0 is discarded.
        # NOTE(review): the reason for this is not visible here (presumably a
        # placeholder or parse artefact in the data files) - confirm before
        # removing it.
        raster[nid, 0] = 0
    return raster
def printpairstats(stat, name):
    """Print *name* followed by four pairwise one-way ANOVA results.

    *stat* is indexed as ``stat[a][b]``. The pairs compared, in print order,
    are (0,0) vs (1,0), (0,1) vs (1,1), (0,0) vs (0,1), and (1,0) vs (1,1).
    """
    comparisons = (
        ((0, 0), (1, 0)),
        ((0, 1), (1, 1)),
        ((0, 0), (0, 1)),
        ((1, 0), (1, 1)),
    )
    print(name)
    for (a1, b1), (a2, b2) in comparisons:
        result = stats.f_oneway(stat[a1][b1], stat[a2][b2])
        print(result)
def exportcsv(data, name):
    """Write four series of *data* as CSV columns to ``name + ".txt"``.

    The columns, in order, are ``data[0,0]``, ``data[1,0]``, ``data[0,1]``
    and ``data[1,1]``. The file has one row per element of a series, and
    pandas writes its default integer index and header.
    """
    column_order = ((0, 0), (1, 0), (0, 1), (1, 1))
    table = np.array([data[i, j] for i, j in column_order])
    frame = pd.DataFrame(table.T)
    frame.to_csv(name + ".txt")
def label_diff(ax, i, j, text, X, Y, *, height_factor=1.1, text_offset=7):
    """Draw a bar-shaped bracket between bars *i* and *j* and write *text*.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to annotate.
    i, j : int
        Indices into *X* and *Y* of the two bars to connect.
    text : str
        Label text. It is placed at ``(X[i], y - text_offset)``.
    X, Y : sequence of float
        Bar x positions and heights.
    height_factor : float, keyword-only
        The bracket is drawn at ``y = height_factor * max(Y[i], Y[j])``.
        Defaults to 1.1, as before.
    text_offset : float, keyword-only
        Vertical offset, in data units, of the label below the bracket.
        Defaults to 7, as before.
    """
    y = height_factor * max(Y[i], Y[j])
    props = {'connectionstyle': 'bar', 'arrowstyle': '-',
             'shrinkA': 20, 'shrinkB': 20, 'linewidth': 1}
    ax.annotate(text, xy=(X[i], y - text_offset), zorder=10, transform=ax.transData)
    # An empty-text annotation whose arrow, styled as a 'bar' connector,
    # draws the bracket.
    ax.annotate('', xy=(X[i], y), xytext=(X[j], y), arrowprops=props)
# Column counts used to slice the per-pattern spike arrays. Only the first
# NPYRS + NINH units are analysed; the names presumably mean pyramidal and
# inhibitory units - confirm against the simulator.
NPYRS: int = 400
NINH: int = 100
# A unit counts as "active" in a pattern when its value exceeds THRESHOLD.
# Plot labels call this ">10Hz", presumably 40 spikes over a 4 s window -
# confirm.
THRESHOLD: float = 40.
def _pattern_stats(path, cstart=100, cend=200):
    """Load one run's spikesperpattern file and summarise patterns 0 and 1.

    Only the first NPYRS + NINH columns are used. Returns the tuple
    (overlap %, corrcoef of columns cstart:cend, size A %, size B %,
    mean A, mean B, sparsity A, sparsity B), where A is row 0 and B is row 1.
    """
    nlimit = NPYRS + NINH
    spikes = np.loadtxt(path, dtype=float)[:, 0:nlimit]
    active_a = spikes[0, :] > THRESHOLD
    active_b = spikes[1, :] > THRESHOLD
    activepop = np.sum(np.logical_or(active_b, active_a))
    # The small epsilon avoids a division by zero when no unit is active.
    overlap = 100. * np.sum(np.logical_and(active_b, active_a)) / (activepop + 0.0001)
    corr = np.corrcoef(spikes[0, cstart:cend], spikes[1, cstart:cend])[0, 1]
    return (
        overlap,
        corr,
        100. * np.sum(active_a) / nlimit,  # % of analysed units active in A
        100. * np.sum(active_b) / nlimit,  # % of analysed units active in B
        np.mean(spikes[0, :]),
        np.mean(spikes[1, :]),
        trevrolls(spikes[0, :]),
        trevrolls(spikes[1, :]),
    )


def _mean_sd_bars(groups, colors):
    """Draw bars at x = 1..len(groups) showing each group's mean, with population-SD error bars."""
    means = np.mean(groups, axis=1)
    stds = np.std(groups, axis=1)
    plt.bar(list(range(1, len(groups) + 1)), means, yerr=stds, color=colors)


def mkPlots():
    """Compare control and LC-blocked runs and save the summary statistics.

    For each of NRUNS runs this reads ./data/control_<run>/ and
    ./data/blocked_<run>/spikesperpattern.dat. It then draws three figures
    (overlap and active population, mean rate, sparsity), prints two one-way
    ANOVAs, and writes saves.txt. In saves.txt, the rows of ``saves`` become
    columns: 0-1 overlap, 6-9 active-population sizes, 11-14 mean rates,
    18-21 sparsity. The remaining columns are zero.
    """
    ncases = 2
    NRUNS = 10
    saves = np.zeros((30, NRUNS))
    trs = np.zeros((ncases, NRUNS))
    corrs = np.zeros((ncases, NRUNS))  # computed but not currently plotted or saved
    sizeA = np.zeros((ncases, NRUNS))
    sizeB = np.zeros((ncases, NRUNS))
    fratesA = np.zeros((ncases, NRUNS))
    fratesB = np.zeros((ncases, NRUNS))
    trA = np.zeros((ncases, NRUNS))
    trB = np.zeros((ncases, NRUNS))
    for run in range(NRUNS):
        # case 0 = control, case 1 = blocked
        for case, cond in enumerate(('control', 'blocked')):
            (trs[case, run], corrs[case, run],
             sizeA[case, run], sizeB[case, run],
             fratesA[case, run], fratesB[case, run],
             trA[case, run], trB[case, run]) = _pattern_stats(
                './data/%s_%d/spikesperpattern.dat' % (cond, run))

    four_colors = ['indigo', 'indigo', 'purple', 'purple']
    mem_ticks = ['MemA\nControl', 'MemB\nControl', 'MemA\nLC Block', 'MemB\nLC Block']

    # Figure 1, left: overlap of the active populations.
    plt.figure()
    plt.subplot(1, 2, 1)
    plt.ylim((0, 60))
    _mean_sd_bars(trs, ['indigo', 'purple'])
    plt.ylabel('Population Overlap CtxA & CtxB (%)')
    plt.xticks([1, 2], ['Control', 'LC Block'])
    saves[0:2] = np.array(trs)

    # Figure 1, right: size of each active population.
    plt.subplot(1, 2, 2)
    ops = (sizeA[0], sizeB[0], sizeA[1], sizeB[1])
    _mean_sd_bars(ops, four_colors)
    plt.xticks([1, 2, 3, 4], ['CtxA\nControl', 'CtxB\nControl', 'CtxA\nLC Block', 'CtxB\nLC Block'])
    plt.ylabel('Active Population % (ff >10Hz)')
    st = stats.f_oneway(sizeA[0], sizeB[0])  # control A vs B only
    print(st)
    saves[6:10] = np.array(ops)

    # Figure 2: mean firing rate, with a 4-group ANOVA p-value annotation.
    plt.figure()
    ops = (fratesA[0], fratesB[0], fratesA[1], fratesB[1])
    _mean_sd_bars(ops, four_colors)
    st = stats.f_oneway(*ops)
    print(st)
    plt.annotate('One-way ANOVA p= %g' % (st[1]), xy=(0.3, 0.9), xycoords='figure fraction')
    saves[11:15] = np.array(ops)
    plt.ylabel('Mean Firing Rate (Hz)')
    plt.xticks([1, 2, 3, 4], mem_ticks)

    # Figure 3: Treves-Rolls sparsity.
    plt.figure()
    plt.ylabel('Sparsity')
    ops = (trA[0], trB[0], trA[1], trB[1])
    _mean_sd_bars(ops, four_colors)
    plt.xticks([1, 2, 3, 4], mem_ticks)
    saves[18:22] = np.array(ops)
    np.savetxt("saves.txt", saves.T, delimiter=',', fmt='%f')
def mkRampPlot():
    """Plot mean firing rate against input current for two ramp conditions.

    Reads ./data/ramp_data.txt. Row k is plotted at x = 0.01 * k, and the
    axis is labelled 'Input current (nA)'. Columns 0:50 are averaged to give
    the 'Control' curve, and columns 200:250 the 'LC Inhibited' curve.

    NOTE(review): the two curves use different error bars. Control shows
    std / sqrt(10), while LC Inhibited shows the plain std (its sqrt(10)
    divisor was commented out). Each series also averages 50 columns, not 10.
    This function preserves the original output; confirm which statistic is
    intended.
    """
    plt.figure()
    plt.ylabel('Mean Firing Rate (Hz)')
    data = np.loadtxt('./data/ramp_data.txt')
    xlab = [0.01 * v for v in range(data.shape[0])]
    # Each entry is (column range, divisor applied to the std, line colour).
    series = (
        (slice(0, 50), np.sqrt(10), 'black'),
        (slice(200, 250), 1.0, 'royalblue'),
    )
    for columns, divisor, colour in series:
        samples = data[:, columns]
        plt.errorbar(xlab, np.mean(samples, axis=1), np.std(samples, axis=1) / divisor,
                     color=colour)
    plt.legend(['Control', 'LC Inhibited'])
    plt.xlabel('Input current (nA)')
    plt.ylim((0, 25))
def _plot_sample_row(path, first_panel, vmax, active_rate=10):
    """Draw one row (three panels) of the 2x3 sample figure.

    Loads *path* and divides it by 4. The first 400 columns of rows 0 and 1
    are reshaped to 20x20 maps and shown in panels *first_panel* and
    *first_panel* + 1. Panel *first_panel* + 2 shows the mask of cells above
    *active_rate* in both maps. Returns ``(map of row 0, mask)``.
    """
    # NOTE(review): presumably /4 converts spike counts to Hz over a 4 s
    # window - confirm.
    data = np.loadtxt(path) / 4
    map_a = data[0, 0:400].reshape((20, 20))
    map_b = data[1, 0:400].reshape((20, 20))
    for offset, image in enumerate((map_a, map_b)):
        plt.subplot(2, 3, first_panel + offset)
        plt.imshow(image, vmin=0, vmax=vmax)
        plt.axis('off')
    act = np.logical_and(map_a > active_rate, map_b > active_rate)
    plt.subplot(2, 3, first_panel + 2)
    plt.imshow(act)
    plt.axis('off')
    return map_a, act


def mkSamples():
    """Show 20x20 activity maps for control run 0 (top) and blocked run 3 (bottom).

    Each row shows pattern 0, pattern 1 and their co-active mask. The function
    prints the number of co-active cells in the blocked run, then opens a
    second figure showing the blocked pattern-0 map with a colour bar.
    """
    plt.figure()
    myvmax = 17
    _plot_sample_row('./data/control_0/spikesperpattern.dat', 1, myvmax)
    spikes, act = _plot_sample_row('./data/blocked_3/spikesperpattern.dat', 4, myvmax)
    print(act.sum())
    plt.figure()
    plt.imshow(spikes)
    plt.colorbar()
def mkPairs(cond, run, label):
raster = np.zeros((500, 8000));
lines = open( './data/%s_%d/spikes.dat'%(cond, run),'r').readlines();
ln=0;
for line in lines:
cols = [int(n) for n in line.split()]
raster[ln, cols ] = 1;
ln +=1;
if ln>=500: break
memA = raster[:, 0:4000];
memB = raster[:, 4000:8000];
actA = (1*(memA.sum(axis=1)>40))
actB= (1*(memB.sum(axis=1)>40))
ov = (((actA + actB)>1))
print(ov.sum())
ovA = memA[ov, :]
ovB = memB[ov, :]
corrs = []
for cs in range(40):
sa= (ovA[:, cs:cs+100].sum(axis=1))
sb= (ovB[:, cs:cs+100].sum(axis=1))
co = np.corrcoef(sa, sb)
corrs.append(co[0,1])
plt.figure()
plt.title(label)
plt.bar(range(40), corrs)
plt.xlabel('Time (1/10 sec)')
plt.ylabel('Corr. Coeff ')
plt.ylim( (-0.2,0.5) )
def mkVolt():
    """Plot columns 0 and 1 of ./data/ramp_voltage.txt in two stacked panels (y in mV, x in msec)."""
    traces = np.loadtxt('./data/ramp_voltage.txt')
    plt.figure()
    for panel in (1, 2):
        plt.subplot(2, 1, panel)
        plt.plot(traces[:, panel - 1])
        plt.ylabel("mV")
    plt.xlabel("msec")
# Script entry: build the ramp, sample-map and summary figures, then display them.
for _make_figure in (mkRampPlot, mkSamples, mkPlots):
    _make_figure()
plt.show()