from fns.utils import *
import PIL.Image
from io import BytesIO
from IPython.display import clear_output, Image, display
import tensorflow as tf
from tensorflow.python.client import timeline
# Import-time banner: signals on stdout that this module has been loaded.
print("\n".join(["*" * 80, "functionsTF loaded!", "*" * 80]))
'''
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
'''
def diag(W):
    """Return a copy of ``W`` whose main diagonal is set to 0.

    The input is never modified. Unlike the arithmetic form
    ``W - np.diag(np.diag(W))``, the diagonal is overwritten rather than
    subtracted, so non-finite diagonal entries (``inf``/``nan``) also become 0
    instead of turning into NaN, and no dense temporary matrix is allocated.

    Parameters
    ----------
    W : array_like
        Matrix whose diagonal should be cleared.

    Returns
    -------
    numpy.ndarray
        Copy of ``W`` with zeros on the main diagonal.
    """
    out = np.array(W, copy=True, subok=True)
    if out.ndim != 2:
        # Non-matrix input: keep the historical arithmetic behaviour
        # (np.fill_diagonal only accepts arrays with at least 2 dimensions).
        return W - np.diag(np.diag(W))
    np.fill_diagonal(out, 0)
    return out
class TfConnEvolveNet:
    """Two coupled E/I spiking sub-networks with plastic gap junctions (TensorFlow 1.x).

    Neurons are stored along one axis in the order ``[E1 | I1 | I2 | E2]``
    (see ``makeVect`` / ``makeConn``).  Inhibitory neurons follow an
    Izhikevich-type update and excitatory neurons a leaky integrate-and-fire
    update (see ``runTFSimul``).  The whole simulation is expressed as a
    TensorFlow graph that is stepped from Python.
    """

    def __init__(self,
                 T=100,
                 config=None,
                 profiling=False
                 ):
        """Read the model parameters from ``config`` and open a TF session.

        Parameters
        ----------
        T : int
            Number of simulation timesteps (duration in ms is ``T * dt``).
        config : object
            Configuration object; its ``net``, ``sim``, ``mod.I`` and ``plast``
            sections are read below.  Required despite the ``None`` default.
        profiling : bool
            When true, every ``sess.run`` of the main loop is traced and a
            Chrome trace is written at the end of ``runTFSimul``.

        Raises
        ------
        ValueError
            If ``config`` is ``None``.
        """
        # Fail before touching the TF graph: without a config nothing below can work.
        if config is None:
            raise ValueError(
                "TfConnEvolveNet requires a `config` object "
                "(with net, sim, mod and plast sections)")

        # reset the tensorflow graph
        tf.reset_default_graph()

        ## simulation parameters
        # for profiling
        self.profiling = profiling
        # to save raster plot
        self.spikeMonitor = False
        self.debug = 0
        # sampling interval of the weight matrices
        self.weight_step = 10
        # sampling interval for monitoring variables
        self.monitor_step = 1
        # to save individual traces
        self.monitor_single = False

        # default neurons distribution
        self.NE1 = config.net.NE1
        self.NE2 = config.net.NE2
        self.NI1 = config.net.NI1
        self.NI2 = config.net.NI2

        # number of timesteps and timestep in ms (simulation duration in ms = T*dt)
        self.T = T
        self.dt = config.sim.dt

        # time constant for 1st subnet
        self.tauv1 = config.mod.I.tau_v1
        # time constant for 2nd subnet
        self.tauv2 = config.mod.I.tau_v2
        # time constant for the adaptation
        self.tau_u = config.mod.I.tau_u
        # adaptation coupling parameters
        self.u_a = 50
        # IZH neuron parameters
        self.mod_a = config.mod.I.a
        self.mod_b = config.mod.I.b
        self.mod_c = config.mod.I.c
        # neuron reset values
        self.v_r_I = config.mod.I.v_reset
        self.v_r_E = -70
        # neuron threshold values
        self.v_thresh_E = 0
        self.v_thresh_I = 25
        # IAF neuron time constant
        self.tau_v_E = 40
        # IAF neuron resistance
        self.Rm = 0.6
        # v initialisation
        self.v_init_mean = -100
        self.v_init_std = 30

        # synapses parameters
        self.tau_I_I = 10
        self.tau_I_E = 10  # 12

        ## bursting filter
        self.tau_burst = 8.0
        self.burst_thresh = 1.3

        ## plasticity variables
        # plasticity multiplier
        self.FACT = 1/self.dt
        # LTD learning rate
        self.alpha_LTD = 1.569e-5
        # LTP/LTD ratio
        self.ratio = config.plast.ratio
        # time during which the plasticity is turned off after subnetworks are connected
        self.stabTime = config.plast.stabTime
        # time when to stop the plasticity
        self.stopTime = np.inf
        # time at which to connect subnetworks
        self.connectTime = config.net.connectTime

        ## input parameters
        # mean input current to inhibitory neurons
        self.nu = config.net.noise.mean
        self.sigmaNoise = config.net.noise.var
        # input signal fed to one or both subnetworks
        self.both = config.sim.both
        # extra current to excitatory neurons
        self.inE = config.net.inE
        self.kInputE1 = config.net.noise.kE1
        self.kInputE2 = config.net.noise.kE2
        self.kNoiseE1 = config.net.noise.kNE1
        self.kNoiseE2 = config.net.noise.kNE2
        self.noiseScaling = 1 / (1 / (2 * 2 / self.dt)) ** 0.5 * self.sigmaNoise
        # default input signal
        self.input = np.zeros((1, T))

        ## connectivity
        # slope of the WII curve
        self.k = config.net.w.k
        self.wE1E1 = config.net.w.E1E1
        self.wE1E2 = config.net.w.E1E2
        self.wE1I1 = config.net.w.E1I1
        self.wE1I2 = config.net.w.E1I2
        self.wE2E1 = config.net.w.E2E1
        self.wE2E2 = config.net.w.E2E2
        self.wE2I1 = config.net.w.E2I1
        self.wE2I2 = config.net.w.E2I2
        self.wI1E1 = config.net.w.I1E1
        self.wI1E2 = config.net.w.I1E2
        self.wI1I1 = config.net.w.I1I1
        self.wI1I2 = config.net.w.I1I2
        self.wI2E1 = config.net.w.I2E1
        self.wI2E2 = config.net.w.I2E2
        self.wI2I1 = config.net.w.I2I1
        self.wI2I2 = config.net.w.I2I2

        # LTP softbound
        self.g0 = 0
        # LTP rule
        self.ltp_rule = "spiking"
        # Symmetry
        self.sym = True
        # Plasticity direction
        self.plast_dir = True

        ## cc12 = dVcell2 / dVcell1
        ## Haas et al fig 4c: depression of c21 after injection of current in cell 1
        # gap junction conductances
        self.g1 = config.net.g1
        self.g2 = config.net.g2
        # proportion of gap junction to delete
        self.propToDelete = 0
        self.v0 = 1
        # number of shared gap junctions between subnets 1 and 2
        self.sG = config.net.sG
        # random distribution parameters
        self.distrib = config.net.w.distrib
        self.mu = config.net.w.mu
        self.sigma = config.net.w.var

        ## tensorflow session parameters
        # (kept in its own name so the `config` argument is not shadowed)
        gpu_options = tf.GPUOptions(allow_growth=True)
        session_config = tf.ConfigProto(
            log_device_placement=False,
            gpu_options=gpu_options
        )
        session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
        self.sess = tf.InteractiveSession(config=session_config)
        if profiling:
            self.run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            self.run_metadata = tf.RunMetadata()
        else:
            self.run_metadata = None
            self.run_options = None

        ## show progress bar
        self.tqdm = True

    @staticmethod
    def _within(weight, n):
        """Scale a within-population weight by the population size (0 if the population is empty)."""
        return weight / n if n > 0 else 0

    @staticmethod
    def _between(weight, n_a, n_b):
        """Scale a between-population weight by sqrt(n_a * n_b) (0 if either population is empty)."""
        return weight / (n_a * n_b) ** 0.5 if n_a > 0 and n_b > 0 else 0

    def makeVect(self):
        """Build the four (N, 1) population indicator vectors as float32 TF variables.

        Each vector holds 1 for the neurons of its population and 0 elsewhere,
        using the storage order ``[E1 | I1 | I2 | E2]``.

        Returns
        -------
        tuple
            ``(VE1, VE2, VI1, VI2)``.
        """
        NE1, NE2, NI1, NI2 = self.NE1, self.NE2, self.NI1, self.NI2
        vConnE1 = np.concatenate([np.ones((NE1, 1)), np.zeros((NI1 + NE2 + NI2, 1))])
        vConnI1 = np.concatenate([np.zeros((NE1, 1)), np.ones((NI1, 1)), np.zeros((NE2 + NI2, 1))])
        vConnI2 = np.concatenate([np.zeros((NE1 + NI1, 1)), np.ones((NI2, 1)), np.zeros((NE2, 1))])
        vConnE2 = np.concatenate([np.zeros((NE1 + NI1 + NI2, 1)), np.ones((NE2, 1))])
        VE1 = tf.Variable(vConnE1, dtype='float32')
        VE2 = tf.Variable(vConnE2, dtype='float32')
        VI1 = tf.Variable(vConnI1, dtype='float32')
        VI2 = tf.Variable(vConnI2, dtype='float32')
        return VE1, VE2, VI1, VI2

    def add_shared_gap(self, W_, n):
        """Return a copy of ``W_`` with the I1<->I2 shared gap-junction entries set to 1.

        The first ``n`` I2 neurons are linked to every I1 neuron and the last
        ``n`` I1 neurons to every I2 neuron, in both directions; the main
        diagonal of the result is zero.  ``W_`` itself is left untouched.
        """
        W = W_.copy()
        NE1, NI1, NI2 = self.NE1, self.NI1, self.NI2
        N1 = NE1 + NI1
        W[NE1:N1, N1:N1 + n] = 1
        W[N1:N1 + NI2, N1 - n:N1] = 1
        W[N1:N1 + n, NE1:N1] = 1
        W[N1 - n:N1, N1:N1 + NI2] = 1
        # W is already a private copy, so the diagonal can be cleared in place.
        np.fill_diagonal(W, 0)
        return W

    def makeConn(self, TF=True, distrib='lognormal_gap', mu=1, sigma=1, sG=0,
                 we1e1=1, we1e2=1, we1i1=1, we1i2=1,
                 we2e1=1, we2e2=1, we2i1=1, we2i2=1,
                 wi1e1=1, wi1e2=1, wi1i1=1, wi1i2=1,
                 wi2e1=1, wi2e2=1, wi2i1=1, wi2i2=1,
                 g1=1, g2=1, gS=1):
        """Build the (N, N) chemical-synapse and gap-junction matrices.

        Parameters
        ----------
        TF : bool
            If true, every matrix is wrapped in a float32 ``tf.Variable`` and
            the global initializer is run; otherwise NumPy arrays are returned.
        distrib : str
            ``'lognormal'`` / ``'uniform'`` multiply every matrix by random
            draws, ``'lognormal_gap'`` randomises only the gap-junction
            matrices, any other value keeps the constant block values.
        mu, sigma : float
            Parameters passed to ``np.random.lognormal``.
        sG : int
            Number of shared gap junctions, forwarded to ``add_shared_gap``.
        we1e1 ... wi2i2 : float
            Block value for each source/target population pair.
        g1, g2, gS : float
            Block values of the subnet-1, subnet-2 and shared gap junctions.

        Returns
        -------
        tuple
            20 items: the 16 chemical matrices (E1E1, E1E2, E1I1, E1I2, E2E1,
            ..., I2I2), then IIg1, IIg2, IIgS and the symmetric 0/1
            ``connDelete`` mask.
        """
        NE1, NE2, NI1, NI2 = self.NE1, self.NE2, self.NI1, self.NI2
        N1 = NE1 + NI1
        N2 = NE2 + NI2
        # total number of neurons
        N = N1 + N2
        W0 = np.zeros((N, N))

        # Index range of every population.  E2 is written as N - NE2 : N
        # rather than -NE2: because "-0:" selects the WHOLE axis when NE2 == 0.
        pop = {
            'E1': slice(0, NE1),
            'I1': slice(NE1, N1),
            'I2': slice(N1, N1 + NI2),
            'E2': slice(N - NE2, N),
        }

        def block(pre, post, value):
            """Constant block [pre, post] with zero diagonal; transposed for cross-population pairs."""
            mat = W0.copy()
            mat[pop[pre], pop[post]] = value
            np.fill_diagonal(mat, 0)
            return mat if pre == post else mat.T

        ## chemical synapses, grouped by source population
        listmat = [
            block('E1', 'E1', we1e1), block('E1', 'E2', we1e2),
            block('E1', 'I1', we1i1), block('E1', 'I2', we1i2),
            block('E2', 'E1', we2e1), block('E2', 'E2', we2e2),
            block('E2', 'I1', we2i1), block('E2', 'I2', we2i2),
            block('I1', 'E1', wi1e1), block('I1', 'E2', wi1e2),
            block('I1', 'I1', wi1i1), block('I1', 'I2', wi1i2),
            block('I2', 'E1', wi2e1), block('I2', 'E2', wi2e2),
            block('I2', 'I1', wi2i1), block('I2', 'I2', wi2i2),
        ]

        ## gap junctions: subnet 1, subnet 2, shared
        listmatG = [
            block('I1', 'I1', g1),
            block('I2', 'I2', g2),
            self.add_shared_gap(W0, sG) * gS,
        ]
        listmatAll = listmat + listmatG

        # One random draw per matrix, in list order.
        if distrib == 'lognormal':
            connMat = [mat * np.random.lognormal(mu, sigma, (N, N)) for mat in listmatAll]
        elif distrib == 'uniform':
            connMat = [mat * np.random.random((N, N)) for mat in listmatAll]
        elif distrib == 'lognormal_gap':
            connMat = list(listmat)
            connMat += [mat * np.random.lognormal(mu, sigma, (N, N)) for mat in listmatG]
        else:
            connMat = listmatAll

        chem, gaps = connMat[:16], connMat[16:]

        # assign symmetry to gap junctions
        if self.sym:
            print("-- init with symmetric gap junctions")
            gaps = [(mat + mat.T)/2 for mat in gaps]
        else:
            print("-- init with asymmetric gap junctions")

        # get matrix of deleted connections (0s); built from a symmetric random matrix
        A = np.random.rand(N, N)
        A = np.tril(A) + np.tril(A, -1).T
        connDelete = (A > self.propToDelete) * 1

        mats = list(chem) + list(gaps)
        if TF:
            names = ['E1E1', 'E1E2', 'E1I1', 'E1I2',
                     'E2E1', 'E2E2', 'E2I1', 'E2I2',
                     'I1E1', 'I1E2', 'I1I1', 'I1I2',
                     'I2E1', 'I2E2', 'I2I1', 'I2I2',
                     'IIg1', 'IIg2', 'IIgS']
            mats = [tf.Variable(mat, dtype=tf.float32, name=name)
                    for mat, name in zip(mats, names)]
            connDelete = tf.Variable(connDelete, dtype=tf.float32, name='connDelete')
            tf.global_variables_initializer().run()

        return tuple(mats) + (connDelete,)

    def init_float(self, shape, name):
        """Return a zero-initialised ``tf.Variable`` of the given shape and name."""
        return tf.Variable(tf.zeros(shape), name=name)

    def runTFSimul(self):
        """Build the simulation graph, run ``self.T`` steps and store the results on ``self``.

        Initial weights are stored before the loop (``self.WE1E1`` ...,
        ``self.wGap0``); monitored traces and final weights afterwards
        (``self.vvmE1`` ..., ``self.wGapE``, ``self.gm1`` ...).  The session
        opened in ``__init__`` is closed at the end, so an instance can be
        simulated only once.

        Raises
        ------
        ValueError
            If ``self.ltp_rule`` is neither ``'spiking'`` nor ``'passive'``.
        """
        if self.ltp_rule not in ('spiking', 'passive'):
            raise ValueError("ltp_rule must be 'spiking' or 'passive', got %r" % (self.ltp_rule,))

        #################################################################################
        ### INITIALISATION
        #################################################################################
        T = self.T
        dt = self.dt
        NE1, NE2, NI1, NI2 = self.NE1, self.NE2, self.NI1, self.NI2
        N = NE1 + NI1 + NE2 + NI2

        with tf.name_scope('spiking_bursting'):
            LowSp = self.init_float([N, 1], 'bursting')
            vv = self.init_float([N, 1], 'spiking')

        with tf.name_scope('monitoring'):
            ### sampling
            weight_step = self.weight_step
            monitor_step = self.monitor_step
            # Ceiling division: samples are written at i = 0, step, 2*step, ...
            # so T not being a multiple of the step needs one extra slot.
            n_mon = -(-T // monitor_step)
            n_wgt = -(-T // weight_step)

            vvmE1 = self.init_float([n_mon], "vvE1")
            vvmE2 = self.init_float([n_mon], "vvE2")
            vvmI1 = self.init_float([n_mon], "vvI1")
            vvmI2 = self.init_float([n_mon], "vvI2")
            vmE1 = self.init_float([n_mon], "vE1")
            vmE2 = self.init_float([n_mon], "vE2")
            vmI1 = self.init_float([n_mon], "vI1")
            vmI2 = self.init_float([n_mon], "vI2")
            imE1 = self.init_float([n_mon], "i1E1")
            imE2 = self.init_float([n_mon], "i2E2")
            imI1 = self.init_float([n_mon], "i1I1")
            imI2 = self.init_float([n_mon], "imI2")
            pmI1 = self.init_float([n_mon], "pm")
            lowspmI1 = self.init_float([n_mon], "lowspm")
            iGapm = self.init_float([n_mon], "iGap")

            ### debugging
            # NOTE(review): these are indexed by the raw step counter in the
            # 'Debugging' scope below, so debug=1 with monitor_step > 1 would
            # index past the end -- confirm the intended size.
            Am = self.init_float([n_mon], "Am")
            Bm = self.init_float([n_mon], "Bm")
            dwm = self.init_float([n_mon], "dwm")

            WI1I1m = self.init_float([n_wgt], "WI1I1m")
            g1m = self.init_float([n_wgt], "gamma_N1")
            g2m = self.init_float([n_wgt], "gamma_N2")
            gSm = self.init_float([n_wgt], "gamma_NS")

            # Full (T, N) buffers only when the monitor is enabled; otherwise
            # a one-row placeholder keeps the graph construction uniform.
            if self.spikeMonitor:
                spikes = self.init_float([T, N], "spikes")
            else:
                spikes = self.init_float([1, N], "spikes")
            if self.monitor_single:
                iAll = self.init_float([T, N], "iAll")
                iChemAll = self.init_float([T, N], "iChem")
                vAll = self.init_float([T, N], "vAll")
                postAll = self.init_float([T, N], "postSynFilter")
            else:
                iAll = self.init_float([1, N], "iAll")
                iChemAll = self.init_float([1, N], "iChemAll")
                vAll = self.init_float([1, N], "vAll")
                postAll = self.init_float([1, N], "postSynFilter")

        with tf.name_scope('synaptic_connections'):
            # matrices with 1 where a connection exists; only the gap-junction
            # masks and the deletion mask are needed below
            conn_masks = self.makeConn(sG=self.sG, distrib='single_val')
            connIIg1, connIIg2, connIIgS, connDelete = conn_masks[16:]
            vectE1, vectE2, vectI1, vectI2 = self.makeVect()

            # mean synaptic weights: within-population weights are divided by
            # the population size, between-population weights by sqrt(n_a*n_b);
            # anything touching an empty population is 0
            within, between = self._within, self._between
            wE1E1_init = within(self.wE1E1, NE1)
            wE1E2_init = between(self.wE1E2, NE1, NE2)
            wE1I1_init = between(self.wE1I1, NI1, NE1)
            wE1I2_init = between(self.wE1I2, NE1, NI2)
            wE2E1_init = between(self.wE2E1, NE2, NE1)
            wE2E2_init = within(self.wE2E2, NE2)
            wE2I1_init = between(self.wE2I1, NE2, NI1)
            wE2I2_init = between(self.wE2I2, NI2, NE2)
            wI1E1_init = between(self.wI1E1, NI1, NE1)
            wI1E2_init = between(self.wI1E2, NI1, NE2)
            wI1I1_init = within(self.wI1I1, NI1)
            wI1I2_init = between(self.wI1I2, NI2, NI1)
            wI2E1_init = between(self.wI2E1, NI2, NE1)
            wI2E2_init = between(self.wI2E2, NI2, NE2)
            wI2I1_init = between(self.wI2I1, NI2, NI1)
            wI2I2_init = within(self.wI2I2, NI2)

            # LTP soft bound: scaled by NI1, falling back to NI2 when subnet 1
            # has no inhibitory neurons
            if NI1 > 0:
                g0 = self.g0 / NI1
            elif NI2 > 0:
                g0 = self.g0 / NI2
            else:
                g0 = 0
            g1 = within(self.g1, NI1)
            g2 = within(self.g2, NI2)
            gS = (g1 + g2) / 2 if NI2 > 0 else 0

            WE1E1, WE1E2, WE1I1, WE1I2, \
                WE2E1, WE2E2, WE2I1, WE2I2, \
                WI1E1, WI1E2, WI1I1, WI1I2, \
                WI2E1, WI2E2, WI2I1, WI2I2, \
                Wgap1, Wgap2, WIIgS, _ = self.makeConn(
                    distrib=self.distrib, TF=True, mu=self.mu, sigma=self.sigma,
                    we1e1=wE1E1_init / dt, we1e2=wE1E2_init / dt, we1i1=wE1I1_init / dt, we1i2=wE1I2_init / dt,
                    we2e1=wE2E1_init / dt, we2e2=wE2E2_init / dt, we2i1=wE2I1_init / dt, we2i2=wE2I2_init / dt,
                    wi1e1=wI1E1_init / dt, wi1e2=wI1E2_init / dt, wi1i1=wI1I1_init / dt, wi1i2=wI1I2_init / dt,
                    wi2e1=wI2E1_init / dt, wi2e2=wI2E2_init / dt, wi2i1=wI2I1_init / dt, wi2i2=wI2I2_init / dt,
                    g1=g1, g2=g2, gS=gS
                )

            WII0 = WI1I1 + WI2I2
            WchemI = WI1E1 + WI1E2 + WI1I2 + WI2E1 + WI2E2 + WI2I1
            WchemE = WE1E1 + WE1E2 + WE1I1 + WE1I2 + WE2E1 + WE2E2 + WE2I1 + WE2I2
            Wchem = WchemI + WchemE

            wGap = tf.Variable(Wgap1 + Wgap2)
            # delete prop of GJs defined by self.propToDelete
            wGap = tf.multiply(wGap, connDelete)
            tf.global_variables_initializer().run()
            wGap = tf.Variable(wGap, name='wGap')

            # plasticity learning rates (dt is still the Python float here)
            A_LTD_ = self.alpha_LTD * self.FACT
            A_LTP = tf.constant(self.ratio * A_LTD_/dt, name="A_LTP", dtype=tf.float32)
            A_LTD = tf.constant(A_LTD_, name="A_LTD", dtype=tf.float32)

        with tf.name_scope('membrane_var'):
            # Create variables for simulation state
            u = self.init_float([N, 1], 'u')
            v = tf.Variable(self.v0 * tf.random_normal([N, 1], mean=self.v_init_mean, stddev=self.v_init_std, name='v'))
            # currents
            iBack = self.init_float([N, 1], 'iBack')
            iChem = self.init_float([N, 1], 'iChem')
            input_signal = tf.cast(tf.constant(self.input), tf.float32)
            tauvSubnet = tf.Variable(
                self.tauv1 * vectI1 + self.tauv2 * vectI2 + (vectE1 + vectE2),
                name="tauv")

        with tf.name_scope('simulation_params'):
            # stimulation
            inE = tf.Variable(self.inE, dtype=tf.float32)
            kMult = tf.Variable(self.k, dtype=tf.float32)
            TImean = self.nu
            Nmean = TImean * (vectI1 + vectI2) + (TImean + inE) * (self.kNoiseE1 * vectE1 + self.kNoiseE2 * vectE2)
            # timestep as a graph constant (used by every update rule below)
            dt_tf = tf.constant(dt * 1.0, name="timestep")
            # connection and plasticity times
            connectTime = self.connectTime
            stabTime = self.stabTime
            stopTime = self.stopTime
            sim_index = tf.Variable(0.0, name="sim_index", dtype=tf.float32)
            one = tf.Variable(1.0)
            ones = tf.ones((1, N))

        #################################################################################
        ## Computation
        #################################################################################
        # Connect subnetworks
        with tf.name_scope('Connect'):
            # max(..., 1) guards the empty-population case; the matching
            # reduce_mean term is 0 then, because the mask is all zeros
            g0_S = tf.reduce_mean(wGap * connIIg1) * ((NI1 + NE1) / max(NI1, 1)) ** 2 + \
                tf.reduce_mean(wGap * connIIg2) * ((NI2 + NE2) / max(NI2, 1)) ** 2
            wGapS = g0_S * connIIgS
            connect = tf.group(
                wGap.assign(tf.add(wGap, wGapS))
            )

        # Currents
        with tf.name_scope('Currents'):
            WII = WII0 * (1 - 2 * kMult * wGap)
            iChem_ = iChem + dt_tf / self.tau_I_I * (
                -iChem + tf.matmul(Wchem + WII, tf.to_float(vv), name="E/IPSPs"))
            # noisy input current
            # NOTE(review): the E1 term is not scaled by kNoiseE1 while E2 is
            # scaled by kNoiseE2 -- confirm this asymmetry is intended.
            iBack_ = iBack + dt_tf / self.tau_I_I * (
                -iBack + tf.random_normal((N, 1), mean=0.0, stddev=1.0, dtype=tf.float32, name=None)) * (
                vectI1 + vectI2) + \
                dt_tf / self.tau_I_E * (
                    -iBack + tf.random_normal((N, 1), mean=0.0, stddev=1.0, dtype=tf.float32, name=None)) * (
                    vectE1 + self.kNoiseE2 * vectE2)
            # column of the external input for the current step
            input_ = tf.expand_dims(input_signal[:, tf.to_int32(sim_index)], 1)
            iEff_ = iBack_ * self.noiseScaling + input_ * (
                vectI1 + vectI2 + self.kInputE1 * vectE1 + self.kInputE2 * vectE2) + Nmean
            # NOTE(review): the second term uses column sums (axis 0); these
            # equal the row sums only while wGap is symmetric -- confirm for sym=False.
            iGap_ = tf.matmul(wGap, v, name="GJ1") - tf.multiply(
                tf.reshape(tf.reduce_sum(wGap, 0), (N, 1)), v, name="GJ2")
            # sum all currents
            I_ = iGap_ + iChem_ + iEff_

        # Neuron models: IZH for I neurons + IAF for E neurons
        with tf.name_scope('Izhikevich'):
            # voltage
            v_ = (v + dt_tf / tauvSubnet * (
                tf.multiply((v + self.mod_a + self.mod_b), (v + self.mod_a)) - self.tau_u * u + I_)) * (
                vectI1 + vectI2) + \
                (v + dt_tf / self.tau_v_E * (-v + self.Rm * I_)) * (vectE1 + vectE2)
            # adaptation
            u_ = u + dt_tf / self.tau_u * (v_ + self.mod_a + self.mod_c - u)
            # spikes
            vv_ = tf.to_float(tf.greater(v_, self.v_thresh_I)) * (vectI1 + vectI2) + \
                tf.to_float(tf.greater(v_, self.v_thresh_E)) * (vectE1 + vectE2)
            # reset
            v_ = tf.multiply(vv_, self.v_r_I) * (vectI1 + vectI2) + \
                tf.multiply(vv_, self.v_r_E) * (vectE1 + vectE2) + \
                tf.multiply((1 - vv_), v_)
            u_ = u_ + self.u_a * vv_ * (vectI1 + vectI2)

        # Bursting: low-pass filtered spike train thresholded at burst_thresh
        with tf.name_scope('bursting'):
            LowSp_ = LowSp - dt_tf / self.tau_burst * LowSp + vv_
            p_ = tf.to_float(tf.greater(LowSp_, self.burst_thresh))

        # plasticity
        with tf.name_scope('plasticity'):
            A = tf.matmul(p_ * (vectI1 + vectI2), ones, name="bursts")  # bursts
            if self.ltp_rule == 'spiking':
                B = tf.matmul(vv_ * (vectI1 + vectI2), ones, name="spikes")  # spikes
            else:  # 'passive' (validated at the top of the method)
                B = tf.matmul((vectI1 + vectI2), ones, name="spikes")  # spikes

            if self.sym:
                print("-- symmetric plasticity change")
                A_ = tf.add(A, tf.transpose(A, name="tr_bursts"))/2
                B_ = tf.add(B, tf.transpose(B, name="tr_spikes"))/2
            else:
                print("-- asymmetric plasticity change")
                if self.plast_dir:
                    A_ = tf.transpose(A)
                    B_ = tf.transpose(B)
                else:
                    A_ = A
                    B_ = B

            # depression
            dwLTD_ = A_LTD * A_

            # potentiation
            if self.g0 == 0:
                # no bounds
                dwLTP_ = A_LTP * B_
            else:
                # LTP softbound
                dwLTP_ = A_LTP * (tf.multiply(tf.ones([N, N]) - wGap / g0, B_))

            dwGap_ = dt_tf * tf.subtract(dwLTP_, dwLTD_)
            # lower bound is 0
            wGap_ = wGap + dwGap_
            wGap_ = tf.clip_by_value(wGap_, clip_value_min=0, clip_value_max=np.inf)
            wGap_ = tf.multiply(wGap_, connDelete)
            wGap_before_ = tf.multiply(wGap_, connIIg1 + connIIg2)
            wGap_after_ = tf.multiply(wGap_, connIIg1 + connIIg2 + connIIgS)

        ##############################################################################################
        # monitoring
        ##############################################################################################
        with tf.name_scope('Debugging'):
            debug = tf.group(
                tf.scatter_update(Am, tf.to_int32(sim_index), tf.reduce_mean(A)),
                tf.scatter_update(Bm, tf.to_int32(sim_index), tf.reduce_mean(B)),
                tf.scatter_update(dwm, tf.to_int32(sim_index), tf.reduce_mean(dwGap_)))

        with tf.name_scope('Monitoring'):
            # PSTH
            vvmeanE1_ = tf.reduce_sum(vv_ * vectE1)
            vvmeanE2_ = tf.reduce_sum(vv_ * vectE2)
            vvmeanI1_ = tf.reduce_sum(vv_ * vectI1)
            vvmeanI2_ = tf.reduce_sum(vv_ * vectI2)
            # summed voltages per population
            vmeanE1_ = tf.reduce_sum(v_ * vectE1)
            vmeanE2_ = tf.reduce_sum(v_ * vectE2)
            vmeanI1_ = tf.reduce_sum(v_ * vectI1)
            vmeanI2_ = tf.reduce_sum(v_ * vectI2)
            # LFPs
            imeanE1_ = tf.reduce_sum(I_ * vectE1)
            imeanE2_ = tf.reduce_sum(I_ * vectE2)
            imeanI1_ = tf.reduce_sum(I_ * vectI1)
            imeanI2_ = tf.reduce_sum(I_ * vectI2)
            pmeanI1_ = tf.reduce_mean(p_ * vectI1)
            lowspmeanI1_ = tf.reduce_mean(LowSp_[:2])
            iGapm_ = tf.reduce_mean(iGap_ * vectI1)

            # slot of the current step in the down-sampled monitor buffers
            mon_idx = tf.to_int32(sim_index / monitor_step)
            monitored = [
                (vvmE1, vvmeanE1_), (vvmE2, vvmeanE2_), (vvmI1, vvmeanI1_), (vvmI2, vvmeanI2_),
                (vmE1, vmeanE1_), (vmE2, vmeanE2_), (vmI1, vmeanI1_), (vmI2, vmeanI2_),
                (imE1, imeanE1_), (imE2, imeanE2_), (imI1, imeanI1_), (imI2, imeanI2_),
                (pmI1, pmeanI1_), (lowspmI1, lowspmeanI1_), (iGapm, iGapm_),
            ]
            update = tf.group(*[tf.scatter_update(buf, mon_idx, val) for buf, val in monitored])

            update_single = tf.group(
                tf.scatter_update(vAll, tf.to_int32(sim_index), tf.reshape((v_), (N,))),
                tf.scatter_update(postAll, tf.to_int32(sim_index), tf.reshape((LowSp_), (N,))),
                tf.scatter_update(iAll, tf.to_int32(sim_index), tf.reshape((I_), (N,))),
                tf.scatter_update(iChemAll, tf.to_int32(sim_index), tf.reshape((iChem_), (N,))),
            )
            update_sim_index = tf.group(
                sim_index.assign_add(one),
            )

        with tf.name_scope('Weights_monitoring'):
            WI1I1m_ = tf.reduce_sum(WII * connIIg1)
            g1m_ = tf.reduce_sum(wGap * connIIg1)
            g2m_ = tf.reduce_sum(wGap * connIIg2)
            gSm_ = tf.reduce_sum(wGap * connIIgS)
            wgt_idx = tf.to_int32(sim_index / weight_step)
            update_weights = tf.group(
                tf.scatter_update(WI1I1m, wgt_idx, WI1I1m_),
                tf.scatter_update(g1m, wgt_idx, g1m_),
                tf.scatter_update(g2m, wgt_idx, g2m_),
                tf.scatter_update(gSm, wgt_idx, gSm_),
            )

        with tf.name_scope('Raster_Plot'):
            spike_update = tf.group(
                tf.scatter_update(spikes, tf.to_int32(sim_index), tf.reshape((vv_), (N,))),
            )

        # Operation to update the state
        step = tf.group(
            iChem.assign(iChem_),
            iBack.assign(iBack_),
            LowSp.assign(LowSp_),
            v.assign(v_),
            vv.assign(vv_),
            u.assign(u_),
        )

        # plasticity
        plast_before = tf.group(
            wGap.assign(wGap_before_),
        )
        plast_after = tf.group(
            wGap.assign(wGap_after_),
        )

        # initialize the graph
        tf.global_variables_initializer().run()

        ## chemical synapses
        # from E1
        self.WE1E1 = WE1E1.eval()
        self.WE1E2 = WE1E2.eval()
        self.WE1I1 = WE1I1.eval()
        self.WE1I2 = WE1I2.eval()
        # from E2
        self.WE2E1 = WE2E1.eval()
        self.WE2E2 = WE2E2.eval()
        self.WE2I1 = WE2I1.eval()
        self.WE2I2 = WE2I2.eval()
        # from I1
        self.WI1E1 = WI1E1.eval()
        self.WI1E2 = WI1E2.eval()
        self.WI1I1 = WI1I1.eval()
        self.WI1I2 = WI1I2.eval()
        # from I2
        self.WI2E1 = WI2E1.eval()
        self.WI2E2 = WI2E2.eval()
        self.WI2I1 = WI2I1.eval()
        self.WI2I2 = WI2I2.eval()
        self.WII = WII.eval()
        self.WII0 = WII0.eval()

        ## gap junctions connections
        self.connIIgS = connIIgS.eval()
        self.connIIg1 = connIIg1.eval()
        self.connIIg2 = connIIg2.eval()
        self.wGap0 = wGap.eval()

        # Fetch lists per plasticity phase.
        # NOTE(review): 'before' is assembled but never run by the loop below.
        ops = {'before': [step, plast_before],
               'after': [step, plast_after],
               'static': [step]
               }
        if self.monitor_single:
            update = [update, update_single]

        # Monitors that fire on every step ride along with the state update;
        # down-sampled ones are run separately inside the loop.
        extras = []
        if monitor_step == 1:
            extras.append(update)
        if weight_step == 1:
            extras.append(update_weights)
        if self.spikeMonitor:
            extras.append(spike_update)
        if self.debug:
            extras.append(debug)
        ops = {phase: fetches + extras for phase, fetches in ops.items()}

        pbar = trange(T)
        pbar.disable = (not self.tqdm)
        for i in pbar:
            # Step simulation
            if i == connectTime:
                self.sess.run([connect])
            if i < stabTime or i > stopTime:
                self.sess.run(ops['static'],
                              options=self.run_options,
                              run_metadata=self.run_metadata
                              )
            else:
                self.sess.run(ops['after'],
                              options=self.run_options,
                              run_metadata=self.run_metadata
                              )
            if monitor_step != 1 and i % monitor_step == 0:
                self.sess.run([update])
            if weight_step != 1 and i % weight_step == 0:
                self.sess.run([update_weights])
            self.sess.run([update_sim_index])

        # debugging
        self.Am = Am.eval()
        self.Bm = Bm.eval()
        self.dwm = dwm.eval()

        # monitoring variables
        self.wGapE = wGap.eval()
        self.vvmE1 = vvmE1.eval()
        self.vvmE2 = vvmE2.eval()
        self.vvmI1 = vvmI1.eval()
        self.vvmI2 = vvmI2.eval()
        self.vmE1 = vmE1.eval()
        self.vmE2 = vmE2.eval()
        self.vmI1 = vmI1.eval()
        self.vmI2 = vmI2.eval()
        self.imI1 = imI1.eval()
        self.imI2 = imI2.eval()
        self.imE1 = imE1.eval()
        self.imE2 = imE2.eval()
        self.iAll = iAll.eval().T
        self.iChemAll = iChemAll.eval().T
        self.vAll = vAll.eval().T
        self.postAll = postAll.eval().T
        self.pmI1 = pmI1.eval()
        self.lowspmI1 = lowspmI1.eval()
        self.WI1I1m = WI1I1m.eval()
        self.WIIe = WII.eval()
        self.gm1 = g1m.eval()
        self.gm2 = g2m.eval()
        self.gmS = gSm.eval()
        self.iGapm = iGapm.eval()
        self.burstingActivity1 = np.mean(self.pmI1)
        self.spikingActivity1 = np.mean(self.vvmI1)
        self.connDelete = connDelete.eval()
        if self.spikeMonitor:
            self.raster = spikes.eval()

        # profiling information
        # Create the Timeline object, and write it to a json
        if self.profiling:
            tl = timeline.Timeline(self.run_metadata.step_stats)
            ctf = tl.generate_chrome_trace_format()
            with open('timeline.json', 'w') as f:
                f.write(ctf)

        self.sess.close()