import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import statsmodels.api as sm
import numpy.random as rnd
import random
from scipy.interpolate import interp1d
import param
# Network description shared by the helpers in this module, built once at import time.
par = param.create_params()
cell_types, num_types, halves = par.cell_types, par.num_types, par.halves
vect_index, left_index, right_index = par.vect_index, par.left_index, par.right_index
tot_num_x_type, colors, pos = par.tot_num_x_type, par.colors, par.pos
# Positions of the cells of the original network; read by distance_dependent_prob.
pos_old = np.loadtxt("data/pos.txt")
def tmfr(cellist, tw, dtF, I):
    """Population-average spike count in a sliding window.

    For every window start ``ti`` in ``np.arange(I[0], I[1], dtF)`` the spikes of
    all cells lying strictly inside ``(ti, ti + tw)`` are counted and the count
    is divided by the number of cells (it is not divided by ``tw``).

    Parameters
    ----------
    cellist : sequence
        Cells exposing ``cell.record["spk"]`` (a sequence of spike times).
    tw : float
        Window width, in the same units as the spike times.
    dtF : float
        Step between consecutive window starts.
    I : sequence of two numbers
        ``[start, stop)`` range for the window starts.

    Returns
    -------
    tuple
        ``(t, rate)``: ``t`` is the ndarray of window starts and ``rate`` a list
        with one value per window.

    Raises
    ------
    ValueError
        If ``cellist`` is empty (nothing to concatenate).
    """
    spikes = np.sort(np.concatenate([np.asarray(cell.record["spk"]) for cell in cellist]))
    t = np.arange(I[0], I[1], dtF)
    # On sorted spike times, #(ti < s < ti + tw) == #(s < ti + tw) - #(s <= ti);
    # two binary searches replace a full boolean scan per window.
    below_end = np.searchsorted(spikes, t + tw, side='left')
    up_to_start = np.searchsorted(spikes, t, side='right')
    # Clip so a non-positive ``tw`` yields an empty window (count 0), as before.
    counts = np.maximum(below_end - up_to_start, 0)
    rate = (1.0 * counts / len(cellist)).tolist()
    return (t, rate)
def plot_spikes_data(st, I=''):
    """Raster-plot spike trains split by body side, using the module ``pos``.

    ``st`` is a sequence of spike-time arrays (each must provide ``tolist``).
    The group index lists in ``left_index`` go to the bottom panel (labelled
    "Left") and those in ``right_index`` to the top panel ("Right").
    Groups 3 and 4 are not plotted. ``I`` is forwarded to ``plot_spk_train`` as
    the x-range.
    """
    trains = np.array([spk.tolist() for spk in st])
    panels = ((2, left_index, "Left"), (1, right_index, "Right"))
    for row, side_index, label in panels:
        plt.subplot(2, 1, row)
        for group in range(len(side_index)):
            if group in (3, 4):
                continue
            members = side_index[group]
            plot_spk_train(pos[members], trains[members], I, colors=colors[group])
        plt.ylim([500, 2000])
        plt.ylabel(label)
def SpikeTrain(cellist):
    """Return a list with one ndarray of recorded spike times per cell, in order."""
    trains = []
    for cell in cellist:
        trains.append(np.array(cell.record['spk']))
    return trains
def TimeVoltageTrace(cell):
    """Return ``(time, voltage)`` ndarrays from ``cell.record['t']`` and ``cell.record['vm']``."""
    rec = cell.record
    time = np.array(rec['t'])
    volt = np.array(rec['vm'])
    return (time, volt)
def plot_spk_train(pos, st, I='', colors=''):
    """Draw a raster: train ``k`` of ``st`` is plotted as dots at height ``pos[k]``.

    ``colors`` is used for both marker face and edge. When ``I`` is non-empty it
    is applied as the x-limits.
    """
    dot_style = dict(marker='.', markersize=5, markerfacecolor=colors,
                     markeredgecolor=colors, linestyle='None')
    for row, times in enumerate(st):
        heights = pos[row] * np.ones_like(times)
        plt.plot(times, heights, **dot_style)
    if len(I):
        plt.xlim(I)
def inj_current(cells, delay, dur, amp_mean=0, amp_std=0):
    """Configure the ``CC`` current-clamp object of every cell.

    Each clamp gets the same ``delay`` and ``dur``. Its ``amp`` is an independent
    draw from ``numpy.random.normal(amp_mean, amp_std)``, taken in cell order
    from the global RNG.
    """
    for target in cells:
        clamp = target.CC
        clamp.delay = delay
        clamp.dur = dur
        clamp.amp = rnd.normal(amp_mean, amp_std)
def inj_current2(cells, delay, dur, amp_mean=0, amp_std=0):
    """Configure the second current clamp ``CC2`` of every cell.

    Each clamp gets the same ``delay`` and ``dur``. Its ``amp`` is an independent
    draw from ``numpy.random.normal(amp_mean, amp_std)``, taken in cell order
    from the global RNG.
    """
    for target in cells:
        clamp = target.CC2
        clamp.delay = delay
        clamp.dur = dur
        clamp.amp = rnd.normal(amp_mean, amp_std)
def plotLeftRightSpikeTrainType(cellist, cell_type, I):
    """Raster-plot every index group of ``left_index``/``right_index`` by side.

    Spike times and y-positions (``cell.pos``) come from ``cellist``. Left groups
    go on the bottom panel and right groups on the top panel, both with y-range
    [500, 3620]. ``cell_type`` is accepted but not used.
    """
    trains = np.array([spk.tolist() for spk in SpikeTrain(cellist)])
    panels = ((2, left_index, "Left"), (1, right_index, "Right"))
    for row, side_index, label in panels:
        plt.subplot(2, 1, row)
        for group in range(len(side_index)):
            members = side_index[group]
            heights = [cellist[m].pos for m in members]
            plot_spk_train(heights, trains[members], I, colors=colors[group])
        plt.ylim([500, 3620])
        plt.ylabel(label)
def plotLeftRightSpikeTrain(cellist, I):
    """Raster-plot every index group of ``left_index``/``right_index`` by side.

    Spike times and y-positions (``cell.pos``) come from ``cellist``. Group 25 is
    skipped (original note: "no mn plot for alan"). Left groups go on the bottom
    panel and right groups on the top panel, both with y-range [-131, 3620].
    """
    trains = np.array([spk.tolist() for spk in SpikeTrain(cellist)])
    panels = ((2, left_index, "Left"), (1, right_index, "Right"))
    for row, side_index, label in panels:
        plt.subplot(2, 1, row)
        for group in range(len(side_index)):
            if group == 25:  # no mn plot for alan
                continue
            members = side_index[group]
            heights = [cellist[m].pos for m in members]
            plot_spk_train(heights, trains[members], I, colors=colors[group])
        plt.ylim([-131, 3620])
        plt.ylabel(label)
def plotLeftRightVoltageOffset(cellist, I, offset=0.0):
    """Plot stacked voltage traces: the first half of ``cellist`` on the bottom
    ("Left") panel and the rest on the top ("Right") panel.

    Within each panel, trace ``k`` is shifted up by ``offset * k`` so the traces
    do not overlap. Both panels share x-range ``I`` and a y-range that grows
    with the offset.
    """
    n = len(cellist)
    for i, cell in enumerate(cellist):
        on_left = i < n / 2
        plt.subplot(2, 1, 2 if on_left else 1)
        (t, v) = TimeVoltageTrace(cell)
        shift = offset * i if on_left else offset * (i - n / 2)
        plt.plot(t, v + shift, color=cell.color, linewidth=1.0)
    plt.subplot(2, 1, 1)
    plt.ylim([-80, 40 + offset * n / 2])
    plt.xlim(I)
    plt.ylabel("Right")
    plt.subplot(2, 1, 2)
    plt.xlim(I)
    plt.ylabel("Left")
    plt.ylim([-80, 40 + offset * n / 2])
def plotLeftRightVoltage(cellist, I):
    """Plot voltage traces on a 4-row layout, dispatched by ``cell.body_side``.

    Cells with ``body_side == 1`` are drawn in row 4 (labelled "Left") and cells
    with ``body_side == 2`` in row 1 ("Right"). Any other value prints an error
    message and the cell is skipped. Both rows use x-range ``I`` and y-range
    [-80, 40].
    """
    for cell in cellist:
        if cell.body_side == 1:
            panel = 4
        elif cell.body_side == 2:
            panel = 1
        else:
            print("Error no body side declaration")
            continue
        plt.subplot(4, 1, panel)
        (t, v) = TimeVoltageTrace(cell)
        plt.plot(t, v, color=cell.color, linewidth=1.0)
    plt.subplot(4, 1, 4)
    plt.xlim(I)
    plt.ylabel("Left")
    plt.ylim([-80, 40])
    plt.subplot(4, 1, 1)
    plt.xlim(I)
    plt.ylabel("Right")
    plt.ylim([-80, 40])
def plotLeftRightVoltageExt(cellist, I):
    """Overlay voltage traces: the first half of ``cellist`` on the bottom
    ("Left") panel and the rest on the top ("Right") panel.

    Both panels use x-range ``[I[0], I[1]]`` and y-range [-80, 40].
    """
    n = len(cellist)
    for i, cell in enumerate(cellist):
        plt.subplot(2, 1, 2 if i < n / 2 else 1)
        (t, v) = TimeVoltageTrace(cell)
        plt.plot(t, v, color=cell.color, linewidth=1.0)
    window = [I[0], I[1]]
    plt.subplot(2, 1, 2)
    plt.ylabel("Left")
    plt.ylim([-80, 40])
    plt.xlim(window)
    plt.subplot(2, 1, 1)
    plt.xlim(window)
    plt.ylim([-80, 40])
    plt.ylabel("Right")
def probabilistic_model_extension(num):
    """Draw a random 0/1 adjacency matrix for the extended network.

    A connection-probability matrix ``Q`` is filled block by block (``Q[pre, post]``,
    following the "pre->post" comments below). Each entry is then realised with
    one Bernoulli draw. Finally, the top-left block covering the first
    ``sum(tot_num_x_type[0:8])`` cells is overwritten with the stored anatomical
    connectome ``data/A_connectome<num>.npy``.

    Parameters
    ----------
    num : int
        Seed for ``numpy.random`` and suffix of the connectome file loaded at the end.

    Returns
    -------
    A : ndarray of shape (num_types[-1], num_types[-1])
        Adjacency matrix with entries 0 or 1.

    NOTE(review): the result depends on the exact sequence of RNG draws below,
    including draws whose outcome is discarded; reordering statements changes
    the network produced for a given ``num``.
    """
    # Probability matrix of the original network; rows/columns use the *_index_old layout.
    P = np.load('data/P1000.npy')
    rnd.seed(num)
    n = num_types[-1]
    # NOTE(review): Q is sized sum(tot_num_x_type) while A below is n x n with
    # n = num_types[-1]; these presumably coincide -- confirm in param.
    Q = np.zeros((sum(tot_num_x_type),sum(tot_num_x_type)))
    # dla->xIN
    for idx1 in left_index[6]:
        for idx2 in left_index[7]:
            Q[idx1,idx2] = 0.03
    for idx1 in right_index[6]:
        for idx2 in right_index[7]:
            Q[idx1,idx2] = 0.03
    # dlc->xIN (contralateral: left group 1 -> right group 7 and vice versa)
    for idx1 in left_index[1]:
        for idx2 in right_index[7]:
            Q[idx1,idx2] = 0.014
    for idx1 in right_index[1]:
        for idx2 in left_index[7]:
            Q[idx1,idx2] = 0.014
    # xIN->CPG
    prob = 0.05 # decreasing this param reduces sync->rest transitions
    dist_tresh = 1000
    # Targets: cells of groups 4, 2, 3, 5 on each side with pos < dist_tresh;
    # groups without any such cell are left out of the concatenation.
    rostral_cpg_left = np.concatenate([[int(idx) for idx in left_index[numb] if pos[idx]<dist_tresh] for numb in [4,2,3,5] if len([int(idx) for idx in left_index[numb] if pos[idx]<dist_tresh])>0])
    rostral_cpg_right = np.concatenate([[int(idx) for idx in right_index[numb] if pos[idx]<dist_tresh] for numb in [4,2,3,5] if len([int(idx) for idx in right_index[numb] if pos[idx]<dist_tresh])>0])
    for i in left_index[7]:
        for j in rostral_cpg_left:
            Q[i,j] = prob
        for j in rostral_cpg_right:
            # rnd.rand() is never < 0, so no contralateral entry is ever set here.
            # The draw still advances the RNG; removing it would change every
            # later random draw for a given seed.
            if rnd.rand()<0.:
                Q[i,j] = prob
    for i in right_index[7]:
        for j in rostral_cpg_right:
            Q[i,j] = prob
        for j in rostral_cpg_left:
            # Always False (see the left-side loop above); kept for the RNG stream.
            if rnd.rand()<0.:
                Q[i,j] = prob
    # xIN->xIN
    # Template: old-network probabilities within group 5 ("mn"), symmetrised
    # (P + P^T) and scaled by molt.
    cell_num = 5 # mn
    Psmall = P[np.ix_(par.left_index_old[cell_num],par.left_index_old[cell_num])] + P[np.ix_(par.left_index_old[cell_num],par.left_index_old[cell_num])].transpose()
    molt=2.5
    n_xin = len(left_index[7])
    # NOTE(review): assumes group 7 indices are contiguous from left_index[7][0]
    # (and from right_index[7][0]), and that n_xin <= Psmall size -- confirm.
    for i in xrange(n_xin):
        for j in range(n_xin):
            Q[i+left_index[7][0],j+left_index[7][0]] = molt*Psmall[i,j]
            # About a third of the pairs are also copied to the contralateral side.
            if rnd.rand()<0.33:
                Q[i+left_index[7][0],j+right_index[7][0]] = molt*Psmall[i,j]
    n_xin=len(right_index[7])
    for i in xrange(n_xin):
        for j in range(n_xin):
            Q[i+right_index[7][0],j+right_index[7][0]] = molt*Psmall[i,j]
            if rnd.rand()<0.33:
                Q[i+right_index[7][0],j+left_index[7][0]] = molt*Psmall[i,j]
    # tSt->tIN
    # Each entry is one sample (GeneralizeData) drawn from the old probabilities
    # between a random cell of old group 0 and old group 6.
    for i in left_index[8]:
        for j in left_index[9]:
            Q[i,j] = GeneralizeData(P[rnd.choice(par.left_index_old[0]),par.left_index_old[6]])
    for i in right_index[8]:
        for j in right_index[9]:
            Q[i,j] = GeneralizeData(P[rnd.choice(par.right_index_old[0]),par.right_index_old[6]])
    # tSt->rdlc (only group-1 cells with pos < 700)
    for i in left_index[8]:
        for j in [idx for idx in left_index[1] if pos[idx]<700]:
            Q[i,j] = GeneralizeData(P[rnd.choice(par.left_index_old[0]),par.left_index_old[1]])
    for i in right_index[8]:
        for j in [idx for idx in right_index[1] if pos[idx]<700]:
            Q[i,j] = GeneralizeData(P[rnd.choice(par.right_index_old[0]),par.right_index_old[1]])
    # tIN->xIN
    for idx1 in left_index[9]:
        for idx2 in left_index[7]:
            Q[idx1,idx2] = 0.03
    for idx1 in right_index[9]:
        for idx2 in right_index[7]:
            Q[idx1,idx2] = 0.03
    # tIN->dIN
    # range(4,5) visits only post_num = 4. The distance profile is taken from
    # old group 6 (dla) onto old group post_num and evaluated at -(pos[j]-pos[i]).
    for post_num in range(4,5):
        pre = par.left_index_old[6] # dla
        post = par.left_index_old[post_num]
        f = distance_dependent_prob(pre,post)
        for i in left_index[9]:
            for j in left_index[post_num]:
                x = pos[j]-pos[i]
                Q[i,j] = f(-x)
        pre = par.right_index_old[6] # dla
        post = par.right_index_old[post_num]
        f = distance_dependent_prob(pre,post)
        for i in right_index[9]:
            for j in right_index[post_num]:
                x = pos[j]-pos[i]
                Q[i,j] = f(-x)
    # tSp->MHR
    for i in left_index[10]:
        for j in left_index[11]:
            Q[i,j] = GeneralizeData(P[rnd.choice(par.left_index_old[0]),par.left_index_old[6]])
    for i in right_index[10]:
        for j in right_index[11]:
            Q[i,j] = GeneralizeData(P[rnd.choice(par.right_index_old[0]),par.right_index_old[6]])
    # MHR -> CPG groups 2-5: contralateral for every MHR, plus ipsilateral for a
    # random subset (~20%, drawn with replacement) of MHRs.
    for post_num in range(2,6):
        # left MHRs
        # Profile template: the single old cell par.left_index_old[1][-2]
        # (one-element slice) onto contralateral old group post_num.
        pre = par.left_index_old[1][-2:-1]
        post = par.right_index_old[post_num]
        f = distance_dependent_prob(pre,post)
        # contralateral
        for i in left_index[11]:
            for j in right_index[post_num]:
                x = pos[j]-pos[i]
                Q[i,j] = f(-x)
        # 20% ipsilateral (random_integers is inclusive of both bounds; the count
        # uses Python 2 integer division)
        for i in rnd.random_integers(left_index[11][0],left_index[11][-1],len(left_index[11])*2/10):
            for j in left_index[post_num]:
                x = pos[j]-pos[i]
                Q[i,j] = f(-x)
        # right MHRs
        pre = par.right_index_old[1][-2:-1]
        post = par.left_index_old[post_num]
        f = distance_dependent_prob(pre,post)
        # contralateral
        for i in right_index[11]:
            for j in left_index[post_num]:
                x = pos[j]-pos[i]
                Q[i,j] = f(-x)
        # 20% ipsilateral
        for i in rnd.random_integers(right_index[11][0],right_index[11][-1],len(right_index[11])*2/10):
            for j in right_index[post_num]:
                x = pos[j]-pos[i]
                Q[i,j] = f(-x)
    Aconnectome = np.load("data/A_connectome"+str(num)+".npy")
    # One Bernoulli draw per entry: A[i,j] = 1 with probability Q[i,j].
    A = np.zeros((n,n))
    for i in xrange(n):
        for j in xrange(n):
            if rnd.rand()<Q[i,j]:
                A[i,j] = 1
    # anatomical connectome
    # NOTE(review): if groups 0-7 occupy the first sum(tot_num_x_type[0:8])
    # indices, this also overwrites the xIN (group 7) entries drawn above --
    # confirm the index layout in param.
    A[np.ix_(xrange(sum(tot_num_x_type[0:8])),xrange(sum(tot_num_x_type[0:8])))]=Aconnectome;
    return A
def plot_matrix_detailed(Q):
    """Draw a connectivity matrix regrouped by cell type, with grid lines and labels.

    Rows/columns are reordered by the type order in ``order_best``. Entries whose
    row in ``Q`` (presumably the presynaptic cell, as in
    probabilistic_model_extension's ``Q[pre, post]``) belongs to type groups 2, 3
    or 11 (labelled aIN, cIN, MHR) are shown with negated sign. Solid lines
    separate types and dashed lines split each type into halves.

    The input matrix is not modified.
    """
    n = len(Q)
    num_x_type = [num_types[i + 1] - num_types[i] for i in range(len(num_types) - 1)]
    num_x_type.insert(0, 0)
    order_best = [0, 8, 10, 6, 1, 9, 11, 7, 2, 3, 4, 5]
    grid_color1 = "grey"
    grid_color2 = "grey"
    lw = 0.5
    tr = 0.5
    alpha = 1
    # np.array() makes a copy. The original negated columns of Q.transpose(),
    # which is a *view*, so every call flipped signs in the caller's matrix.
    Q = np.array(Q).transpose()
    # v_ind[k] = first matrix index of type k-1 (cumulative type sizes).
    v_ind = np.cumsum(num_x_type).tolist()
    new_idx = []
    for k in order_best:
        if v_ind[k + 1] - v_ind[k] > 0:
            new_idx.extend(range(v_ind[k], v_ind[k + 1]))
    # aIN, cIN, MHR indexes
    idx_inh = (list(range(v_ind[2], v_ind[3])) + list(range(v_ind[3], v_ind[4]))
               + list(range(v_ind[11], v_ind[12])))
    Q[:, idx_inh] = -Q[:, idx_inh]
    Qnew = Q[new_idx, :][:, new_idx]
    colors_new = [colors[k] for k in order_best]
    num_x_type = [num_x_type[k + 1] for k in order_best]
    num_x_type.insert(0, 0)
    # Integer halves, as np.divide on ints produced under Python 2.
    halves = np.asarray(num_x_type) // 2
    types = [cell_types[k] for k in order_best]
    fig, ax = plt.subplots(figsize=(14, 14))
    ax.matshow(Qnew.transpose(), norm=mpl.colors.Normalize(vmin=-1.5, vmax=1.5), cmap='seismic')
    # Outer frame.
    plt.plot([-100, n - 0.5], [n - 0.5, n - 0.5], color=grid_color1, alpha=alpha)
    plt.plot([-100, n - 0.5], [-100, -100], color=grid_color1, alpha=alpha)
    plt.plot([n - 0.5, n - 0.5], [-100, n - 0.5], color=grid_color1, alpha=alpha)
    plt.plot([-100, -100], [-100, n - 0.5], color=grid_color1, alpha=alpha)
    # Solid lines between types, dashed lines at each type's half.
    for i in range(len(num_x_type)):
        line_half = np.sum(num_x_type[0:i]) + halves[i] - tr
        line = np.sum(num_x_type[0:i]) - tr
        plt.plot([0, n - tr], [line_half, line_half], '--', color=grid_color2, linewidth=lw, alpha=alpha)
        plt.plot([-100, n - tr], [line, line], '-', color=grid_color1, linewidth=lw, alpha=alpha)
        plt.plot([line_half, line_half], [0, n - tr], '--', color=grid_color2, linewidth=lw, alpha=alpha)
        plt.plot([line, line], [-100, n - tr], '-', color=grid_color1, linewidth=lw, alpha=alpha)
    # Type labels along the top and the left edge.
    for i in range(len(types)):
        plt.text(np.sum(num_x_type[0:i + 1]) + halves[i + 1] - tr - 25, -40, types[i],
                 fontsize=12, fontweight='bold', color=colors_new[i], rotation=65)
        plt.text(-80, np.sum(num_x_type[0:i + 1]) + halves[i + 1] - tr, types[i],
                 fontsize=12, fontweight='bold', color=colors_new[i], rotation=20)
    plt.xlim([-105, num_types[-1] + 10])
    plt.ylim([-105, num_types[-1] + 10])
    plt.axis("off")
def GeneralizeData(input_data):
    """Draw one value by piecewise-linear inversion of the empirical CDF of ``input_data``.

    Sorted unique values u_1 < ... < u_m get the statsmodels ECDF levels
    F_1 < ... < F_m, with a level 0 prepended. A uniform draw ``w = rnd.rand()``
    below F_1 returns u_1. A draw in [F_k, F_{k+1}) is interpolated linearly
    between u_k and u_{k+1}. Exactly one draw is taken from numpy's global RNG.
    """
    ordered = np.sort(input_data)
    ecdf = sm.distributions.ECDF(ordered)
    levels = [0] + np.unique(ecdf(ordered)).tolist()
    support = np.unique(ordered).tolist()
    # Duplicate the smallest value so draws below F_1 map onto it.
    support = [support[0]] + support
    w = rnd.rand()
    r = next(k for k, level in enumerate(levels[1:]) if w < level)
    lo_p, hi_p = levels[r], levels[r + 1]
    lo_x, hi_x = support[r], support[r + 1]
    span = hi_p - lo_p
    if span == 0:
        return lo_x
    return lo_x + (hi_x - lo_x) * (w - lo_p) / span
def probability_visualization(P):
    """Show matrix ``P`` as a reversed-grey image in a new figure, then call plt.show()."""
    plt.figure()
    plt.imshow(P, cmap='Greys_r')
    plt.show()
def period(spk, tstop=1500, cell_indices=None, bound1=50, bound2=70, min_spikes=20):
    """Median of the last inter-spike interval over active cells.

    For every cell index in ``cell_indices`` (default: module ``vect_index[5]``)
    whose train has more than ``min_spikes`` spikes, the last inter-spike
    interval is kept if it lies strictly between ``bound1`` and ``bound2``.

    Parameters
    ----------
    spk : sequence of sequences
        Spike times per cell, indexable by the entries of ``cell_indices``.
    tstop : float
        Unused; kept for backward compatibility.
    cell_indices : iterable of int, optional
        Cells to inspect; defaults to ``vect_index[5]``.
    bound1, bound2 : float
        Exclusive lower/upper limits on the accepted interval.
    min_spikes : int
        A train needs strictly more spikes than this (must be >= 1).

    Returns
    -------
    float
        Median accepted interval, or NaN when no interval qualifies. NaN is
        returned explicitly rather than via np.median([]), which warns.
    """
    if cell_indices is None:
        cell_indices = vect_index[5]
    intervals = []
    for i in cell_indices:
        train = spk[i]
        if len(train) > min_spikes:
            last_isi = train[-1] - train[-2]
            if bound1 < last_isi < bound2:
                intervals.append(last_isi)
    if not intervals:
        return np.nan
    return np.median(intervals)
def indexes(A):
    """Summarise connection counts between index groups 4 and 3 in adjacency ``A``.

    Returns ``(din_mean, din_std, cin_mean, cin_std)``:

    * ``din_*``: first, the group-4 columns whose column sum over the
      ``vect_index[4] x vect_index[4]`` block is >= ``thresh1`` are selected.
      Their positions, offset by 750, are used as rows. ``din_*`` is the
      mean/std over the group-4 columns of how many of those rows connect to
      each column.
    * ``cin_*``: the same construction with the group-3 columns whose column
      sum over the ``vect_index[4] x vect_index[3]`` block is >= ``thresh2``,
      offset by 366.

    NOTE(review): ``thresh1`` and ``thresh2`` are not defined anywhere in this
    file; the call raises NameError unless they are provided as module globals.
    TODO confirm their source.
    NOTE(review): 750 and 366 look like the first global index of the
    ``vect_index[4]`` / ``vect_index[3]`` groups -- verify against ``param``.
    """
    # Built-in sum() over a 2-D array adds its rows, i.e. yields column sums.
    idx=np.where(sum(A[np.ix_(vect_index[4],vect_index[4])])>=thresh1)
    conn=sum(A[np.ix_(750+idx[0],vect_index[4])])
    din_mean=np.mean(conn)
    din_std=np.std(conn)
    idx=np.where(sum(A[np.ix_(vect_index[4],vect_index[3])])>=thresh2)
    conn=sum(A[np.ix_(366+idx[0],vect_index[4])])
    cin_mean=np.mean(conn)
    cin_std=np.std(conn)
    # Linear-model estimate; computed but neither used nor returned.
    y_hat=din_mean-cin_mean
    y_mdl=79.13-1.33*y_hat
    return (din_mean,din_std,cin_mean,cin_std)
def distance_dependent_prob(pre, post, P=None, positions=None, half_pos=3000, bin_width=40):
    """Interpolated connection probability as a function of pre->post distance.

    For each presynaptic index ``i`` in ``pre``, every ``j`` in ``post`` with
    ``P[i, j] > 0`` is binned by ``h = positions[j] - positions[i]``. Bin edges
    are ``-half_pos, -half_pos + bin_width, ..., half_pos``. Bin ``r`` collects
    ``edges[r] < h <= edges[r + 1]``, and values outside the edges are ignored.
    Each bin stores the mean of ``P[i, j]`` (0 when empty). These per-cell
    profiles are averaged over ``pre``, and the mean is linearly interpolated
    with the bin value placed at the bin's lower edge. Linear extrapolation is
    used outside the edge range.

    Parameters
    ----------
    pre, post : iterables of int
        Row/column indices into ``P`` and ``positions``.
    P : 2-D array, optional
        Probability matrix; defaults to loading ``data/P1000.npy``. Passing it
        avoids re-reading the file on every call.
    positions : 1-D array, optional
        Cell positions; defaults to the module-level ``pos_old``.
    half_pos, bin_width : int
        Histogram range and bin width (defaults match the original 3000 / 40).

    Returns
    -------
    scipy.interpolate.interp1d
        Callable mapping a distance to the averaged probability.
    """
    if P is None:
        P = np.load("data/P1000.npy")
    if positions is None:
        positions = pos_old
    histo = np.arange(-half_pos, half_pos + 1, bin_width)
    n_bins = len(histo)
    all_prob = []
    for i in pre:
        y_histo = np.zeros(n_bins)
        count = np.zeros(n_bins)
        for j in post:
            if P[i, j] > 0:
                h = positions[j] - positions[i]
                # searchsorted(side='left') returns s with histo[s-1] < h <= histo[s],
                # so r = s - 1 is the bin; r outside [0, n_bins-2] means out of range.
                r = int(np.searchsorted(histo, h, side='left')) - 1
                if 0 <= r < n_bins - 1:
                    y_histo[r] += P[i, j]
                    count[r] += 1
        profile = np.zeros(n_bins)
        filled = count != 0
        profile[filled] = y_histo[filled] / count[filled]
        all_prob.append(profile)
    prob_mean = np.mean(all_prob, axis=0)
    return interp1d(histo, prob_mean, kind='linear', fill_value="extrapolate")
def first_firing_times(spk, tstop, min_spikes=3):
    """First spike time of every sufficiently active train, split by side.

    The first ``len(spk) // 2`` trains are treated as the left side and the
    remaining ones as the right side. The original sliced the right side from
    ``n/2 + 1``, silently dropping the first right-side train.

    Parameters
    ----------
    spk : sequence of sequences
        Spike times per cell.
    tstop : float
        Unused; kept for backward compatibility.
    min_spikes : int
        Only trains with at least this many spikes contribute (original value: 3).

    Returns
    -------
    tuple
        ``(start_left, start_right)`` lists of first spike times.
    """
    half = len(spk) // 2
    start_left = [train[0] for train in spk[:half] if len(train) >= min_spikes]
    start_right = [train[0] for train in spk[half:] if len(train) >= min_spikes]
    return (start_left, start_right)
def average_dIN_volt(v):
    """Return, as a list, the mean of each column of the 2-D array ``v``."""
    n_columns = v.shape[1]
    return [np.mean(v[:, col]) for col in range(n_columns)]
def _first_upward_crossing(trace, thresh):
    """Return the first k with trace[k] < thresh <= trace[k + 1], or None."""
    for k in range(len(trace) - 1):
        if trace[k] < thresh <= trace[k + 1]:
            return k
    return None


def classify_behaviour(spk, hdin_volt, m, t):  # only works with fixed time step integration
    """Classify the network response from rostral dIN spikes and averaged hdIN voltages.

    Parameters
    ----------
    spk : sequence
        Spike trains indexed like the module ``pos`` / ``left_index``.
    hdin_volt : sequence of voltage traces
        The first ``n//2`` traces are averaged as the left side and
        ``hdin_volt[n//2:-1]`` as the right side.
        NOTE(review): the ``-1`` excludes the last trace; this looks like an
        off-by-one unless a trailing non-voltage entry is expected -- confirm.
    m : object
        Unused.
    t : ndarray
        Sample times matching the columns of the traces (fixed time step).

    Returns
    -------
    tuple
        ``(out, tL, tR, tstar)``, where:

        * ``out``: 1 swim left, 2 swim right, 3 sync, 4 one sided, 5 no swim.
        * ``tL`` / ``tR``: time of the first upward crossing of -27 by each side's
          average voltage, or None.
        * ``tstar``: mean first-spike time of the rostral (pos < 1000) group-4
          trains on the side whose median start is earlier. It is 0 when either
          side has no train with more than three spikes.

    Also opens a new figure showing both averages, the crossings and the region
    after ``tstar``.
    """
    din_spks_l = [spk[j] for j in left_index[4] if pos[j] < 1000.0]
    din_spks_r = [spk[j] for j in right_index[4] if pos[j] < 1000.0]
    swim_l_start = np.median([x[0] for x in din_spks_l if len(x) > 3])
    swim_r_start = np.median([x[0] for x in din_spks_r if len(x) > 3])
    if not (np.isnan(swim_l_start) or np.isnan(swim_r_start)):
        leading = din_spks_l if swim_l_start <= swim_r_start else din_spks_r
        tstar = np.mean([x[0] for x in leading if len(x) > 0])
    else:
        tstar = 0
    n = len(hdin_volt)
    adin_l = average_dIN_volt(np.array(hdin_volt[0:n // 2]))
    adin_r = average_dIN_volt(np.array(hdin_volt[n // 2:-1]))
    thresh = -27
    idx_l = _first_upward_crossing(adin_l, thresh)
    idx_r = _first_upward_crossing(adin_r, thresh)
    tL = t[idx_l] if idx_l is not None else None
    tR = t[idx_r] if idx_r is not None else None
    plt.figure()
    plt.plot(t, adin_l, 'b', label='left', linewidth=2.0)
    plt.plot(t, adin_r, 'r', label='right', linewidth=2.0)
    plt.legend(loc='upper left', fontsize=18)
    plt.ylabel('<v$_{hdIN}$>')
    # The original tested `tstar is not 0` (identity against an int literal);
    # a value comparison is what the sentinel check intends.
    if tstar != 0:
        plt.fill_between(t, -60, -10, where=t > tstar, facecolor='green', alpha=0.25)
    plt.ylim([-80, -10])
    if idx_l is not None:
        plt.plot(t[idx_l], adin_l[idx_l], 'bo', markersize=8)
    if idx_r is not None:
        plt.plot(t[idx_r], adin_r[idx_r], 'ro', markersize=8)
    if tL is not None and tR is not None:
        if abs(tL - tR) > 3.0:
            out = 1 if tL < tR else 2  # swim left / swim right
        else:
            out = 3  # sync
    elif tL is not None or tR is not None:
        out = 4  # one sided
    else:
        out = 5  # no swim
    return (out, tL, tR, tstar)
def classification_name(idx):
    """Map a classify_behaviour outcome code to its label; unknown codes give None."""
    labels = (
        (0, "Undetected"),
        (1, "Swim left"),
        (2, "Swim right"),
        (3, "Sync"),
        (4, "One sided"),
        (5, "No swim"),
    )
    for code, label in labels:
        if idx == code:
            return label
    return None