import sys
import numpy as np
#import matplotlib.pylab as plt
import json
from argparse import ArgumentParser
import os.path
#####################################
parser = ArgumentParser()
parser.add_argument("--mode",dest="mode",
help="Choose deprivation mode: MD-CL (=contra monocular deprivation), BD (=binocular deprivation), MI (=contra monocular inactivation), MD-IL (=ipsi monocular deprivation).",
default='MD-IL')
args = parser.parse_args()
#####################################
###### Functions
#####################################
def f_IEvol_WilCow(y,ipt,exc,iiw,ipw,iew,tauy):
y = y + (timeStep/tauy)*(-y + gainF_relu( np.dot(ipw,ipt)+np.dot(iew,exc)-np.dot(iiw,y),beta_i ) )
return y
def f_EEvol_WilCow(y,ipt,inh,eew,ipw,eiw,tauy):
y = y + (timeStep/tauy)*(-y + gainF_relu( np.dot(ipw,ipt)+np.dot(eew,y)
-np.dot(eiw,inh),beta_e ) )
return y
def gainF_relu(x,slopef):
out = slopef*(x-theta_v)
out = out*(x>theta_v)
return out
def f_excitatory_plasticity(pre=0,post=0,w=0,recurrent='no'):
dw = ((np.outer(post,pre)>theta_ee_L)*(np.outer(post,pre)-theta_ee_H)
*(np.outer(post,pre)-theta_ee_L))
dw = np.tanh(dw/0.0001)
w = w + timeStep*l_rate_ee*dw
w[w>wmax_exc] = wmax_exc
w[w<wmin_exc] = wmin_exc
if recurrent == 'yes':
if np.size(w,axis=0) == np.size(w,axis=1):
w = w - np.diag(np.diag(w))
else:
print('Problem with recurrent connections: no square matrix!')
return w
def f_calc_odi(w_ee_L3b,w_ee_L4to3b,w_ei_L3b,w_ie_ffwb,w_ie_L3_effb,w_ii_L3b):
# Store firing rates in order to calculate ocular dominance index
tttime=50
respv = np.zeros((NrE_L3,2)) # storage of data
respv2 = np.zeros((NrE_L4,2)) # storage of data
storV = np.zeros((NrE_L3,int(tttime/timeStep))) # storage of data
storVI = np.zeros((NrI_L3,int(tttime/timeStep))) # storage of data
storV4 = np.zeros((NrE_L4,int(tttime/timeStep))) # storage of data
for ippp in range(2):
EvL4 = np.zeros(NrE_L4) # vector of excitatory activities
EvL3 = np.zeros(NrE_L3) # vector of excitatory activities
IvL3 = np.zeros(NrI_L3) # vector of inhibitory activities
if ippp == 1:
# 'ipsi_closed'
a = 0
b = 1
else:
#'contra_closed'
a = 1
b = 0
for tt in np.arange(0,tttime,timeStep):
if tt%stimLen == 0:
Iptv = 0*np.ones(NrE_L4)
if tt%deltaStim == 0:
Visual_Ipt_gauss_ipsi = L4_connections[:,0]*np.array([a*vis_amp]*NrE_L4)
Visual_Ipt_gauss_contra = L4_connections[:,1]*np.array([b*vis_amp]*NrE_L4)
Visual_Ipt_gauss = a*Visual_Ipt_gauss_ipsi + b*Visual_Ipt_gauss_contra
Iptv = Visual_Ipt_gauss + Iptv
Iptv = 1*Iptv*(Iptv>0)
EvL4 = beta_e*(Iptv)
IvL3 = f_IEvol_WilCow(IvL3,layerfactor*EvL4,EvL3,w_ii_L3b,w_ie_ffwb,w_ie_L3_effb,tau_i)
EvL3 = f_EEvol_WilCow(EvL3,layerfactor*EvL4,IvL3,w_ee_L3b,w_ee_L4to3b,w_ei_L3b,tau_e)
storV[:,int(tt/timeStep)] = EvL3
storVI[:,int(tt/timeStep)] = IvL3
storV4[:,int(tt/timeStep)] = EvL4
ts = int(stimLen/timeStep)
for xx in range(np.size(storV,axis=0)):
respv[xx,ippp]= np.mean(storV[xx,:ts],axis=0)
for xx in range(np.size(storV4,axis=0)):
respv2[xx,ippp]= np.mean(storV4[xx,:ts],axis=0)
return respv,respv2
#####################################
###### Parameters
#####################################
#*** General
timeStep = 1 # timestep [ms]
tmax = 15000 # Final time of simulation [ms]
storeDT = timeStep # Frequency for storing data
stimLen = 25 # Duration of a visual input [ms]
deltaStim = 50 # Time between onset of two visual inputs [ms]
deprtime = 500 # Time of deprivation onset [ms]
RedInhibTime = 2000 # Time when inhibition will be reduced [ms]
#*** NEURONS
NrE_L3 = 1 # total population of neurons in l3
NrE_L4 = 1250 # total population of neurons in l4
NrExcNrns = 125 # nr of l4 -> l3 connections (per eye)
Nr4to3 = 2*NrExcNrns # nr of l4 -> l3 connections (both eyes)
NrI_L3 = 1 # total nr of inhibitory neurons in l3
tau_e = 1 # excitatory time constant [ms]
tau_i = 1 # inhibitory time constant [ms]
beta_e = .3 # excitatory gain function slope
beta_i = .3 # inhibitory gain function slope
theta_v = 0 # activity threshold
#*** Activity
vis_amp = 10 # Amplitude of visual input
background_current = 10 #7.5 # Amplitude of background input
if args.mode == "MD-CL":
Contra_depr = 0
Ipsi_depr = 1
bckg_depr = 1
elif args.mode == "BD":
Contra_depr = 0
Ipsi_depr = 0
bckg_depr = 1
elif args.mode == "MI":
Contra_depr = 0
Ipsi_depr = 1
bckg_depr = 0
elif args.mode == "MD-IL":
Contra_depr = 1
Ipsi_depr = 0
bckg_depr = 1
else:
sys.exit("ERROR: wrong mode argument, please choose between MD-CL, BD, MI or MD-IL")
theta_ee_H = 25 #13
theta_ee_L = 12.5 #5
layerfactor = 1 # Extra ffw input factor from layer 4 to 3
#*** Plasticity
l_rate_ee = 4e-6
#*** initial w
wmax_exc = 7.5/Nr4to3
wmin_exc = .05*wmax_exc
wee_ini = 0
wee_L4to3_ini = wmax_exc
wie_ini = 0
w_ie_ffw_ini = 7.5/Nr4to3
wei_ini = 1.75
w_ii_ini = 0*1.5
#####################################
###### Variables
#####################################
#*** LAYER 3
Ev_L3 = np.zeros(NrE_L3) # vector of excitatory activities
Iv_L3 = np.zeros(NrI_L3) # vector of inhibitory activities
w_ie_ffw = w_ie_ffw_ini*np.ones((NrI_L3,NrE_L4)) # ffw weight matrix
w_ei_L3 = wei_ini*np.ones((NrE_L3,NrI_L3)) # ffw weight matrix
w_ie_L3 = wie_ini*np.ones((NrI_L3,NrE_L3))
w_ii_L3 = w_ii_ini*np.ones((NrI_L3,NrI_L3))
w_ee_4to3 = wee_L4to3_ini*np.ones((NrE_L3,NrE_L4))
w_ee_L3 = wee_ini*np.ones((NrE_L3,NrE_L3))
#*** LAYER 4
Ev_L4 = np.zeros(NrE_L4) # vector of excitatory activities
L4_connections = np.ones((NrE_L4,2))
L3_connections = np.zeros((NrE_L3,NrE_L4))
L3_connections_low = np.zeros((NrE_L3,NrE_L4))
meancontra = 30
stdvcontra = 25
q = (meancontra +np.random.randn(NrE_L4)*stdvcontra)
q[q<0] = 100*np.random.rand(len(q[q<0]))
q[q>100] = 100*np.random.rand(len(q[q>100])) #100
q = np.sort(q)
q = np.expand_dims(q,1)
if os.path.exists('dat/connections.txt'):
dat_load = open('dat/connections.txt','r')
dat_conncs = json.load(dat_load)
dat_load.close()
L3_connections = np.array(dat_conncs[0])
L4_connections = np.array(dat_conncs[1])
else:
L4_connections = np.append(q,100-q,axis=1)*.01
arr = np.array((np.random.permutation(np.arange(1250))<250)*1)
L3_connections[0,:] = arr
L3_connections_inh = np.append(np.append(np.ones((1,int(.6*Nr4to3))),np.zeros((1,NrE_L4-Nr4to3))
,axis=1),np.ones((1,int(.4*Nr4to3))),axis=1)
storeV = np.zeros((15,int(tmax/storeDT))) # storage of data
storeW = np.zeros((Nr4to3,int(tmax/storeDT))) # storage of data
######################################
#----- Run -----#
######################################
stimWindow = int(deltaStim/timeStep)
strV3 = np.zeros((NrE_L3,stimWindow)) # storage of data
strV4 = np.zeros((NrE_L4,stimWindow)) # storage of data
strI3 = np.zeros((NrI_L3,stimWindow))
bckgr = background_current
deprived = 1
BD_depr = 1
for tt in np.arange(0,int(tmax),timeStep):
if tt%stimLen == 0:
Iptv = np.ones(NrE_L4)*0
if tt%deltaStim == 0:
Visual_Ipt_gauss_ipsi = L4_connections[:,0]*np.array([BD_depr*vis_amp]*NrE_L4)
Visual_Ipt_gauss_contra = L4_connections[:,1]*np.array([deprived*vis_amp]*NrE_L4)
Visual_Ipt_gauss = Visual_Ipt_gauss_ipsi + Visual_Ipt_gauss_contra
Iptv = Iptv+Visual_Ipt_gauss+bckgr
Iptv = 1*Iptv*(Iptv>0)
storeT = int((tt%deltaStim)/timeStep)
strV3[:,storeT] = Ev_L3
strV4[:,storeT] = Ev_L4
strI3[:,storeT] = Iv_L3
ie_ff = L3_connections_inh*w_ie_ffw #+ L3_connections_inh_low*w_ie_ffw_low
ee_ff = L3_connections*w_ee_4to3 #+ L3_connections_low*w_ee_4to3_low
if np.abs(tt-deprtime) < timeStep:
rvv_d0,rvv2 = f_calc_odi(w_ee_L3,ee_ff,w_ei_L3,ie_ff,w_ie_L3,w_ii_L3)
deprived = Contra_depr
BD_depr = Ipsi_depr #depr_value
bckgr = bckgr*bckg_depr
elif np.abs(tt-RedInhibTime) < timeStep:
if args.mode=="MD-IL":
w_ie_ffw = w_ie_ffw*.75
else:
w_ie_ffw = w_ie_ffw*.35
Ev_L4 = beta_e*(Iptv)
Iv_L3 = f_IEvol_WilCow(Iv_L3,layerfactor*Ev_L4,Ev_L3,w_ii_L3,ie_ff,w_ie_L3,tau_i)
Ev_L3 = f_EEvol_WilCow(Ev_L3,layerfactor*Ev_L4,Iv_L3,w_ee_L3,ee_ff,w_ei_L3,tau_e)
w_ee_4to3 = f_excitatory_plasticity(pre=Ev_L4,post=Ev_L3,w=w_ee_4to3,recurrent='no')
if tt%storeDT < timeStep:
storeV[0,int(tt/storeDT)] = Ev_L3[0]
storeV[1,int(tt/storeDT)] = w_ee_4to3[0,0]
storeV[2,int(tt/storeDT)] = w_ee_4to3[0,-1]
storeV[3,int(tt/storeDT)] = Iv_L3[0]
storeV[4,int(tt/storeDT)] = w_ee_4to3[0,10]
storeV[5,int(tt/storeDT)] = w_ee_4to3[0,20]
storeV[6,int(tt/storeDT)] = w_ee_4to3[0,40]
storeV[7,int(tt/storeDT)] = w_ee_4to3[0,int(NrE_L4*.5)]
storeV[8,int(tt/storeDT)] = np.mean(Ev_L4[:])
storeV[9,int(tt/storeDT)] = Ev_L3[0]*Ev_L4[0]
storeV[10,int(tt/storeDT)] = Ev_L3[0]*Ev_L4[-1]
storeW[:,int(tt/storeDT)] = w_ee_4to3[0,L3_connections[0,:]==1]
##################################
# Display the progress
###################################
if np.abs(tt-tmax*0.05*np.round(20*tt/tmax)) < 0.5*timeStep:
print('> '+str(int(np.round(100*tt/tmax)))+'%')
rvv_d3,rvv2 = f_calc_odi(w_ee_L3,ee_ff,w_ei_L3,ie_ff,w_ie_L3,w_ii_L3)
##################################
# Store Data
##################################
w_evo = storeW[:,0:-1:500]
odi_d0=(rvv_d0[:,1]-rvv_d0[:,0])/(rvv_d0[:,0]+rvv_d0[:,1])
odi_d3=(rvv_d3[:,1]-rvv_d3[:,0])/(rvv_d3[:,0]+rvv_d3[:,1])
dataw = [w_evo.tolist(),storeV[9,:].tolist(),storeV[10,:].tolist(),odi_d0.tolist(),odi_d3.tolist()]
data3 = open("dat/"+args.mode+'.txt','w')
json.dump(dataw,data3)
data3.close()
if not os.path.exists('dat/connections.txt'):
data2 = [L3_connections.tolist(),L4_connections.tolist()]
datfile = open('dat/connections.txt','w')
json.dump(data2,datfile)
datfile.close()
##################################
# Plots
##################################
#
#lw=3
#fs = 18
#
#plt.figure()
#plt.pcolor(w_evo, vmin=wmin_exc, vmax=wmax_exc)
#cbar = plt.colorbar(ticks = [wmin_exc,wmax_exc])
#cbar.ax.set_yticklabels(['min','max'],fontsize=fs)
#plt.xlabel('Time [a.u.]',fontsize=fs)
#plt.ylabel('Synapses [a.u.]',fontsize=fs)
#plt.title('Weight evolution, '+args.mode, fontsize=fs)
#plt.text(2, 125, r'$\star$', fontsize=1.5*fs)
#plt.text(10, 125, r'$\ddag$', fontsize=fs)
#plt.xticks(fontsize=fs)
#plt.yticks(fontsize=fs)
#ax = plt.gca()
#ax.yaxis.set_label_coords(-.15,.5)
#plt.annotate("",
# xy=(-15, 0), xycoords='data',
# xytext=(-15, 250), textcoords='data',
# arrowprops=dict(arrowstyle="<->",
# connectionstyle="arc3",linewidth=lw),annotation_clip=False)
#plt.text(-20, -10, 'Contra', fontsize=fs)
#plt.text(-17, 255, 'Ipsi', fontsize=fs)
#
#plt.figure()
#plt.plot(np.arange(0,int(tmax),timeStep),storeV[3,:],linewidth=4,label='i3')
#plt.plot(np.arange(0,int(tmax),timeStep),storeV[0,:],linewidth=4,label='e3')
#plt.plot(np.arange(0,int(tmax),timeStep),storeV[8,:],linewidth=4,label='e4')
#plt.legend()
#
#
###
##
#
#plt.figure()
#ax = plt.plot(np.arange(0,int(tmax),timeStep),storeV[1,:],'-',linewidth=3,label='closed-eye dominated',color='b')
#plt.plot(np.arange(0,int(tmax),timeStep),storeV[2,:],'--',linewidth=3,label='open-eye dominated',color='r')
#plt.plot(np.arange(0,int(tmax),timeStep),storeV[7,:],'--',linewidth=3,label='open-eye dominated2',color='r')
#plt.legend(loc='best')
##plt.ylim([.2*wmax_exc,wmax_exc+.1*wmax_exc])
#plt.yticks([.1*wmax_exc,wmax_exc],['.1*wmax','wmax'],fontsize=fs)
#plt.xlabel('Time [a.u.]',fontsize=fs)
#plt.title('Weight evolution', fontsize=fs)
#
#
##stimWindow = deltaStim
#range1 = np.arange(0,stimWindow)
#range2 = np.arange(12*stimWindow,13*stimWindow)
#range3 = np.arange(150*stimWindow,151*stimWindow)
##range4 = np.arange(190*stimWindow,199*stimWindow)
#
#CL = np.append(np.append(storeV[9,range1],storeV[9,range2],axis=0),storeV[9,range3],axis=0)#,storeV[9,range4],axis=0)#
#IL = np.append(np.append(storeV[10,range1],storeV[10,range2],axis=0),storeV[10,range3],axis=0)#,storeV[10,range4],axis=0)#
#
#a = beta_e
#b = beta_i
#w_ef = wmax_exc*Nr4to3
#w_if = w_ie_ffw_ini*Nr4to3
#w_ee = wee_ini
#w_ei = wei_ini
#w_ie = wie_ini
#F = a*(vis_amp+0*background_current + background_current)
#
#denom = (1 - a*w_ee + a*b*w_ei*w_ie)
#Theory_prepost = layerfactor*F**2*(a*w_ef - a*b*w_ei*w_if)/denom + 0*F*(a-a*b*w_ei)*background_current/denom
#
#plt.figure()
#plt.plot(np.arange(0,int(3*deltaStim),timeStep),CL,'-',linewidth=3,label='closed-eye dominated',color='b')
#plt.plot(np.arange(0,int(3*deltaStim),timeStep),IL,'--',linewidth=3,label='open-eye dominated',color='r')
#plt.plot([0,int(3*deltaStim)],[theta_ee_H,theta_ee_H],'-',linewidth=2,label=r'$\theta_H$',color='k')
#plt.plot([0,int(3*deltaStim)],[theta_ee_L,theta_ee_L],'--',linewidth=2,label=r'$\theta_L$',color='k')
#plt.plot([0,int(deltaStim)],[Theory_prepost,Theory_prepost],'--',linewidth=2,label=r'theory',color='m')
#plt.legend(loc='best')
#plt.ylabel(r'$\rho_{pre} \rho_{post}$ [Hz$^2$]',fontsize= fs)
#plt.xticks([.05*deltaStim,1.05*deltaStim,2.5*deltaStim],['Normal','MD','MD+reduced inhib.'],fontsize=fs)
##plt.xlabel('Time [a.u.]',fontsize=fs)
#plt.title('Pre * post firing rates', fontsize=fs)
##plt.savefig('img/PrePost_'+args.mode+'.eps',bbox_inches='tight')
#
#bns=7
#
#
#
#tit_var = 'Ocular dominance plasticity'
#
##dataw = [w_ee.tolist(),w_ei.tolist(),w_ie.tolist(),w_ii.tolist(),w_e_ff.tolist(),w_i_ff.tolist()]
##data3 = open(datvar+'_wdata_'+namestr+'.txt','w')
##json.dump(dataw,data3)
##data3.close()
#
#plt.figure()
#plt.plot([0,1],[odi_d0[0],odi_d3[0]],'-',linewidth=lw,color='k')
#plt.plot([0,1],[odi_d0[0],odi_d3[0]],'.',Markersize=25,linewidth=lw,color='k')
#plt.xlim([-.1,1.1])
##plt.ylim([-.05,.3])
#plt.xticks([0,1],['before','after'],rotation=60,fontsize=fs)
#plt.ylabel('ODI',fontsize=fs)
#plt.title(tit_var,fontsize=fs)
##plt.savefig('img/ODI_'+args.mode+'.eps',bbox_inches='tight')
#
#odi_tot_4 = (np.sum(rvv2[:,1])-np.sum(rvv2[:,0]))/(np.sum(rvv2[:,1])+np.sum(rvv2[:,0]))
#odi_4 = (rvv2[:,1]-rvv2[:,0])/(rvv2[:,1]+rvv2[:,0])
#bns=7
#lw=3
#
#plt.figure()
#plt.subplot(1,2,1)
#plt.hist(odi_4,bins=bns,range=(-1,1),facecolor="None",edgecolor='b',label='Before',linewidth=lw)
##plt.legend()
#plt.title('4')
#
#plt.subplot(1,2,2)
#plt.bar([1],[np.sum(rvv2[:,1])],label='contra',color='b')
#plt.bar([2],[np.sum(rvv2[:,0])],label='ipsi',color='r')
#plt.legend()
#plt.title('4')