#import sys
import numpy as np
#import matplotlib.pylab as plt
import json
from argparse import ArgumentParser
import os.path
#####################################
# Command-line options. --distr selects how the per-L4-cell eye weighting
# `q` is drawn further below ('broad': Gaussian spread, 'narrow': constant),
# and is also used in the name of the output file.
parser = ArgumentParser()
parser.add_argument("--distr",dest="distr",
                    choices=['broad', 'narrow'],  # reject typos instead of silently running 'narrow'
                    help="Choose distribution type: broad or narrow",
                    default='broad')
args = parser.parse_args()
distr_mode = args.distr  # 'broad' or 'narrow'
#####################################
###### Functions
#####################################
def f_IEvol_WilCow(y,ipt,exc,iiw,ipw,iew,tauy):
    """Advance inhibitory rates by one forward-Euler step.

    Integrates ``tauy * dy/dt = -y + g(ipw.ipt + iew.exc - iiw.y)`` over the
    module-level ``timeStep``, where ``g`` is ``gainF_relu`` with slope
    ``beta_i``.

    Parameters
    ----------
    y : current inhibitory rates.
    ipt : feedforward input rates, weighted by ``ipw``.
    exc : excitatory rates, weighted by ``iew``.
    iiw : inhibitory -> inhibitory weights (rows: post, columns: pre).
    ipw : feedforward -> inhibitory weights.
    iew : excitatory -> inhibitory weights.
    tauy : time constant [ms].

    Returns
    -------
    The updated rates as a new array (``y`` is not modified in place).
    """
    total_input = np.dot(ipw, ipt) + np.dot(iew, exc)
    total_input = total_input - np.dot(iiw, y)
    target = gainF_relu(total_input, beta_i)
    return y + (timeStep / tauy) * (target - y)
def f_EEvol_WilCow(y,ipt,inh,eew,ipw,eiw,tauy):
    """Advance excitatory rates by one forward-Euler step.

    Integrates ``tauy * dy/dt = -y + g(ipw.ipt + eew.y - eiw.inh)`` over the
    module-level ``timeStep``, where ``g`` is ``gainF_relu`` with slope
    ``beta_e``.

    Parameters
    ----------
    y : current excitatory rates (also the recurrent presynaptic vector).
    ipt : feedforward input rates, weighted by ``ipw``.
    inh : inhibitory rates, weighted by ``eiw``.
    eew : recurrent excitatory -> excitatory weights (rows: post, columns: pre).
    ipw : feedforward -> excitatory weights.
    eiw : inhibitory -> excitatory weights.
    tauy : time constant [ms].

    Returns
    -------
    The updated rates as a new array (``y`` is not modified in place).
    """
    total_input = np.dot(ipw, ipt) + np.dot(eew, y)
    total_input = total_input - np.dot(eiw, inh)
    target = gainF_relu(total_input, beta_e)
    return y + (timeStep / tauy) * (target - y)
def gainF_relu(x,slopef):
    """Threshold-linear gain function.

    Returns ``slopef * (x - theta_v)`` where ``x > theta_v`` and 0 elsewhere;
    the threshold ``theta_v`` is read from module scope.
    """
    above_threshold = x > theta_v
    return (slopef * (x - theta_v)) * above_threshold
def f_excitatory_plasticity(pre=0,post=0,w=0,recurrent='no'):
    """Apply one step of the excitatory plasticity rule and return new weights.

    For each synapse the pre/post co-activity ``c = post * pre`` determines
    the sign of the change:

    * ``c <= theta_ee_L``: no change;
    * ``theta_ee_L < c < theta_ee_H``: depression;
    * ``c > theta_ee_H``: potentiation.

    The raw update is passed through ``tanh(dw / 1e-4)``, so away from the
    thresholds it saturates to roughly +/-1 (an almost sign-only update).
    The step is scaled by ``timeStep * l_rate_ee`` and the weights are then
    clipped to ``[wmin_exc, wmax_exc]``.

    Parameters
    ----------
    pre : presynaptic rates (columns of ``w``).
    post : postsynaptic rates (rows of ``w``).
    w : current weight matrix; not modified in place.
    recurrent : ``'yes'`` zeroes the diagonal (no self-connections); this
        needs a square ``w`` and otherwise only prints a warning.

    Returns
    -------
    The updated weight matrix.
    """
    # Co-activity matrix, computed once and reused by all three factors.
    coactivity = np.outer(post,pre)
    dw = ((coactivity>theta_ee_L)*(coactivity-theta_ee_H)
          *(coactivity-theta_ee_L))
    dw = np.tanh(dw/0.0001)
    # `w + ...` builds a new array, so the in-place clipping below never
    # touches the caller's matrix.
    w = w + timeStep*l_rate_ee*dw
    w[w>wmax_exc] = wmax_exc
    w[w<wmin_exc] = wmin_exc
    if recurrent == 'yes':
        if np.size(w,axis=0) == np.size(w,axis=1):
            w = w - np.diag(np.diag(w))
        else:
            print('Problem with recurrent connections: no square matrix!')
    return w
def f_calc_odi(w_ee_L3b,w_ee_L4to3b,w_ei_L3b,w_ie_ffwb,w_ie_L3_effb,w_ii_L3b):
    """Measure monocular responses used to compute an ocular dominance index.

    The network is simulated twice from rest for ``tttime`` ms with the given
    weights held fixed (no plasticity, no background input): once with only
    the ipsi eye driven and once with only the contra eye driven.

    Parameters
    ----------
    w_ee_L3b : L3 recurrent E->E weights.
    w_ee_L4to3b : effective L4 -> L3 excitatory weights.
    w_ei_L3b : L3 I->E weights.
    w_ie_ffwb : effective L4 -> L3 inhibitory (feedforward) weights.
    w_ie_L3_effb : L3 E->I weights.
    w_ii_L3b : L3 I->I weights.

    Returns
    -------
    respv : ndarray, shape (NrE_L3, 2)
        Mean L3 excitatory rate over the first ``stimLen`` ms of each probe.
        Column 0: ipsi eye only; column 1: contra eye only.
    respv2 : ndarray, shape (NrE_L4, 2)
        The same quantity for the L4 excitatory rates.
    """
    tttime = 50  # duration of each monocular probe [ms]
    nsteps = int(tttime/timeStep)
    ts = int(stimLen/timeStep)  # samples averaged: the first stimulus window
    respv = np.zeros((NrE_L3,2))
    respv2 = np.zeros((NrE_L4,2))
    storV = np.zeros((NrE_L3,nsteps))   # L3 excitatory trace of one probe
    storV4 = np.zeros((NrE_L4,nsteps))  # L4 excitatory trace of one probe
    for ippp in range(2):
        # Each probe starts from rest.
        EvL3 = np.zeros(NrE_L3)
        IvL3 = np.zeros(NrI_L3)
        # Eye gates: ippp == 1 closes the ipsi eye, ippp == 0 the contra eye.
        a, b = (0, 1) if ippp == 1 else (1, 0)
        for tt in np.arange(0,tttime,timeStep):
            if tt%stimLen == 0:
                # Input is reset every stimLen ms; a new stimulus starts
                # every deltaStim ms.
                Iptv = np.zeros(NrE_L4)
                if tt%deltaStim == 0:
                    # Column 0 of L4_connections weights the ipsi eye,
                    # column 1 the contra eye; each gate is applied once.
                    Iptv = (L4_connections[:,0]*(a*vis_amp)
                            + L4_connections[:,1]*(b*vis_amp)) + Iptv
                Iptv = Iptv*(Iptv>0)  # rectify
            EvL4 = beta_e*(Iptv)
            IvL3 = f_IEvol_WilCow(IvL3,layerfactor*EvL4,EvL3,w_ii_L3b,w_ie_ffwb,w_ie_L3_effb,tau_i)
            EvL3 = f_EEvol_WilCow(EvL3,layerfactor*EvL4,IvL3,w_ee_L3b,w_ee_L4to3b,w_ei_L3b,tau_e)
            step = int(tt/timeStep)
            storV[:,step] = EvL3
            storV4[:,step] = EvL4
        # Vectorised per-neuron mean over the first stimulus window.
        respv[:,ippp] = np.mean(storV[:,:ts],axis=1)
        respv2[:,ippp] = np.mean(storV4[:,:ts],axis=1)
    return respv,respv2
#####################################
###### Parameters
#####################################
#*** General
timeStep = 1 # integration timestep [ms]
tmax = 10000 # Final time of simulation [ms]
storeDT = timeStep # Interval between stored samples [ms] (a period, not a frequency)
stimLen = 25 # Duration of a visual input [ms]; the input is reset every stimLen ms
deltaStim = 50 # Time between onset of two visual inputs [ms]
deprtime = 500 # Time of deprivation onset [ms]; the ODI is also probed at this time
RedInhibTime = 2000 # Time when the feedforward inhibitory weights are scaled by 0.33 [ms]
#*** NEURONS
NrE_L3 = 1 # number of excitatory neurons in l3
NrE_L4 = 1250 # number of excitatory neurons in l4
NrExcNrns = 125 # nr of l4 -> l3 connections (per eye); only used through Nr4to3
Nr4to3 = 2*NrExcNrns # nr of l4 -> l3 connections (both eyes)
NrI_L3 = 1 # total nr of inhibitory neurons in l3
tau_e = 1 # excitatory time constant [ms]
tau_i = 1 # inhibitory time constant [ms]
beta_e = .3 # excitatory gain function slope; also converts the input into L4 rates
beta_i = .3 # inhibitory gain function slope
theta_v = 0 # activity threshold of the threshold-linear gain function
#*** Activity
vis_amp = 10 # Amplitude of visual input
background_current = 10 #7.5 # Amplitude of background input
# Input scaling factors applied from deprtime onwards:
Contra_depr = 0 # contra-eye input factor (0 silences the contra eye)
Ipsi_depr = 1 # ipsi-eye input factor
bckg_depr = 1 # factor applied once to the background input
# Plasticity thresholds on pre*post co-activity c: no change for c <= theta_ee_L,
# depression for theta_ee_L < c < theta_ee_H, potentiation for c > theta_ee_H.
theta_ee_H = 25 #13
theta_ee_L = 0
layerfactor = 1 # Extra ffw input factor from layer 4 to 3
#*** Plasticity
l_rate_ee = 4e-6 # learning rate of the plastic L4 -> L3 excitatory weights
#*** initial w
# Naming: w_XY is a weight onto population X from population Y (e = exc., i = inh.).
wmax_exc = 7.5/Nr4to3 # upper bound of the plastic excitatory weights
wmin_exc = .05*wmax_exc # lower bound (5% of the upper bound)
wee_ini = 0 # initial L3 recurrent E -> E weight
wee_L4to3_ini = wmax_exc # initial L4 -> L3 excitatory weight (starts at the upper bound)
wie_ini = 0 # initial L3 E -> I weight
w_ie_ffw_ini = 7.5/Nr4to3 # initial L4 -> L3 inhibitory (feedforward) weight
wei_ini = 1.75 # initial L3 I -> E weight
w_ii_ini = 0 # initial L3 I -> I weight
#####################################
###### Variables
#####################################
#*** LAYER 3
# Weight matrices are (postsynaptic, presynaptic): they multiply presynaptic
# rate vectors via np.dot in the update functions.
Ev_L3 = np.zeros(NrE_L3) # vector of excitatory activities
Iv_L3 = np.zeros(NrI_L3) # vector of inhibitory activities
w_ie_ffw = w_ie_ffw_ini*np.ones((NrI_L3,NrE_L4)) # feedforward L4 E -> L3 I weights (masked by L3_connections_inh)
w_ei_L3 = wei_ini*np.ones((NrE_L3,NrI_L3)) # L3 I -> L3 E weights
w_ie_L3 = wie_ini*np.ones((NrI_L3,NrE_L3)) # L3 E -> L3 I weights
w_ii_L3 = w_ii_ini*np.ones((NrI_L3,NrI_L3)) # L3 I -> L3 I weights
w_ee_4to3 = wee_L4to3_ini*np.ones((NrE_L3,NrE_L4)) # plastic L4 E -> L3 E weights (masked by L3_connections)
w_ee_L3 = wee_ini*np.ones((NrE_L3,NrE_L3)) # L3 recurrent E -> E weights
#*** LAYER 4
Ev_L4 = np.zeros(NrE_L4) # vector of excitatory activities
L3_connections = np.zeros((NrE_L3,NrE_L4))  # binary L4 -> L3 excitatory connectivity mask
L3_connections_low = np.zeros((NrE_L3,NrE_L4))
# q: per-L4-cell percentage weighting the ipsi eye; (100 - q) weights the
# contra eye (see L4_connections below).
meancontra = 30
stdvcontra = 25
if distr_mode == 'broad':
    q = (meancontra +np.random.randn(NrE_L4)*stdvcontra)
else:
    q = 38*np.ones(NrE_L4)
# Values outside [0, 100] are redrawn once, uniformly from [0, 100).
q[q<0] = 100*np.random.rand(len(q[q<0]))
q[q>100] = 100*np.random.rand(len(q[q>100]))
# Sort ascending so the L4 cell index follows q; make it a column vector.
q = np.expand_dims(np.sort(q),1)
# Column 0 scales ipsi-eye input, column 1 contra-eye input (fractions).
L4_connections = np.append(q,100-q,axis=1)*.01
if os.path.exists('dat/connections.txt'):
    # Reuse the L4 -> L3 wiring saved at the end of an earlier run; only the
    # first entry (L3_connections) is read, L4_connections is redrawn above.
    with open('dat/connections.txt','r') as dat_load:
        dat_conncs = json.load(dat_load)
    L3_connections = np.array(dat_conncs[0])
else:
    # Connect a random subset of Nr4to3 L4 cells to L3 neuron 0 (only row 0
    # is drawn; with NrE_L3 = 1 that is the whole matrix).
    arr = np.array((np.random.permutation(np.arange(NrE_L4))<Nr4to3)*1)
    L3_connections[0,:] = arr
# Feedforward inhibitory mask: the first int(.6*Nr4to3) and the last
# int(.4*Nr4to3) L4 cells (the low-q and high-q ends of the sorted q) project.
L3_connections_inh = np.concatenate((np.ones(int(.6*Nr4to3)),
                                     np.zeros(NrE_L4-Nr4to3),
                                     np.ones(int(.4*Nr4to3))))[np.newaxis,:]
storeV = np.zeros((15,int(tmax/storeDT))) # traces of selected variables (filled in the main loop)
storeW = np.zeros((Nr4to3,int(tmax/storeDT))) # traces of the connected L4 -> L3 weights
######################################
#----- Run -----#
######################################
# Buffers holding one inter-stimulus window (deltaStim ms) of activity.
stimWindow = int(deltaStim/timeStep)
strV3 = np.zeros((NrE_L3,stimWindow)) # L3 excitatory rates
strV4 = np.zeros((NrE_L4,stimWindow)) # L4 excitatory rates
strI3 = np.zeros((NrI_L3,stimWindow)) # L3 inhibitory rates
bckgr = background_current  # background input, added together with each stimulus
deprived = 1  # contra-eye input factor (set to Contra_depr at deprtime)
BD_depr = 1   # ipsi-eye input factor (set to Ipsi_depr at deprtime)
# The recorded L4 -> L3 synapses never change, so build their mask once.
conn_mask = L3_connections[0,:]==1
for tt in np.arange(0,int(tmax),timeStep):
    # --- Visual input: reset every stimLen ms, new stimulus every deltaStim ms.
    if tt%stimLen == 0:
        Iptv = np.zeros(NrE_L4)
        if tt%deltaStim == 0:
            Visual_Ipt_gauss_ipsi = L4_connections[:,0]*(BD_depr*vis_amp)
            Visual_Ipt_gauss_contra = L4_connections[:,1]*(deprived*vis_amp)
            Visual_Ipt_gauss = Visual_Ipt_gauss_ipsi + Visual_Ipt_gauss_contra
            Iptv = Iptv+Visual_Ipt_gauss+bckgr
        Iptv = Iptv*(Iptv>0)  # rectify
    # --- Record activity within the current inter-stimulus window.
    storeT = int((tt%deltaStim)/timeStep)
    strV3[:,storeT] = Ev_L3
    strV4[:,storeT] = Ev_L4
    strI3[:,storeT] = Iv_L3
    # --- Effective (masked) feedforward weights.
    ie_ff = L3_connections_inh*w_ie_ffw
    ee_ff = L3_connections*w_ee_4to3
    # --- Scheduled events. These are independent `if`s so that both fire
    # even if deprtime and RedInhibTime coincide.
    if np.abs(tt-deprtime) < timeStep:
        # Probe ocular dominance before deprivation, then switch the inputs.
        rvv_d0,rvv2 = f_calc_odi(w_ee_L3,ee_ff,w_ei_L3,ie_ff,w_ie_L3,w_ii_L3)
        deprived = Contra_depr
        BD_depr = Ipsi_depr
        bckgr = bckgr*bckg_depr
    if np.abs(tt-RedInhibTime) < timeStep:
        # Takes effect from the next step, when ie_ff is recomputed.
        w_ie_ffw = w_ie_ffw*.33
    # --- Network dynamics and plasticity.
    Ev_L4 = beta_e*(Iptv)
    Iv_L3 = f_IEvol_WilCow(Iv_L3,layerfactor*Ev_L4,Ev_L3,w_ii_L3,ie_ff,w_ie_L3,tau_i)
    Ev_L3 = f_EEvol_WilCow(Ev_L3,layerfactor*Ev_L4,Iv_L3,w_ee_L3,ee_ff,w_ei_L3,tau_e)
    w_ee_4to3 = f_excitatory_plasticity(pre=Ev_L4,post=Ev_L3,w=w_ee_4to3,recurrent='no')
    # --- Store traces every storeDT ms.
    if tt%storeDT < timeStep:
        store_idx = int(tt/storeDT)
        storeV[0,store_idx] = Ev_L3[0]
        storeV[1,store_idx] = w_ee_4to3[0,0]
        storeV[2,store_idx] = w_ee_4to3[0,-1]
        storeV[3,store_idx] = Iv_L3[0]
        storeV[4,store_idx] = w_ee_4to3[0,10]
        storeV[5,store_idx] = w_ee_4to3[0,20]
        storeV[6,store_idx] = w_ee_4to3[0,40]
        storeV[7,store_idx] = w_ee_4to3[0,int(NrE_L4*.5)]
        storeV[8,store_idx] = np.mean(Ev_L4[:])
        storeV[9,store_idx] = Ev_L3[0]*Ev_L4[0]
        storeV[10,store_idx] = Ev_L3[0]*Ev_L4[-1]
        storeW[:,store_idx] = w_ee_4to3[0,conn_mask]
    ##################################
    # Display the progress (every 5% of tmax)
    ###################################
    if np.abs(tt-tmax*0.05*np.round(20*tt/tmax)) < 0.5*timeStep:
        print('> '+str(int(np.round(100*tt/tmax)))+'%')
# Probe ocular dominance again with the final weights (rvv2 is overwritten
# with the L4 responses of this final probe).
rvv_d3,rvv2 = f_calc_odi(w_ee_L3,ee_ff,w_ei_L3,ie_ff,w_ie_L3,w_ii_L3)


def _ocular_dominance(resp):
    """Return (contra - ipsi) / (ipsi + contra) per neuron.

    ``resp`` has ipsi-only responses in column 0 and contra-only responses in
    column 1, as returned by f_calc_odi. Neurons silent to both eyes give NaN.
    """
    return (resp[:,1]-resp[:,0])/(resp[:,0]+resp[:,1])


odi_d0 = _ocular_dominance(rvv_d0)  # L3, at deprivation onset
odi_d3 = _ocular_dominance(rvv_d3)  # L3, at the end of the simulation
odi_4 = _ocular_dominance(rvv2)     # L4, at the end of the simulation
dataw = [odi_4.tolist(),odi_d0.tolist(),odi_d3.tolist()]
# Make sure the output directory exists so a long run does not fail here.
os.makedirs('dat', exist_ok=True)
with open("dat/supp_"+distr_mode+'.txt','w') as data3:
    json.dump(dataw,data3)
if not os.path.exists('dat/connections.txt'):
    # Save the wiring so later runs can reuse L3_connections.
    data2 = [L3_connections.tolist(),L4_connections.tolist()]
    with open('dat/connections.txt','w') as datfile:
        json.dump(data2,datfile)