# Description: This file contains the PVN network model used in the CADEX paper.
# This is a template script for the PVN network model; most functions are built to be overridden or extended.
from brian2 import *
import time as pyt
import logging
import numpy as np
import joblib
import connection_matrix_gen as cmg
import os
# Module-level logger (used e.g. by generate_input_current for debug output).
logger = logging.getLogger(__name__)
# Compile brian2 code with cython.
prefs.codegen.target = 'cython' # weave is not multiprocess-safe!
# Give every process its own cython cache directory (keyed by PID) and enable brian2's
# multiprocess-safe cache handling, so parallel workers do not share a compilation cache.
cache_dir = os.path.expanduser(f'~/.cython/brian-pid-{os.getpid()}')
prefs.codegen.runtime.cython.cache_dir = cache_dir
prefs.codegen.runtime.cython.multiprocess_safe = True
# Silence brian2 code-generation log output.
BrianLogger.suppress_hierarchy('brian2.codegen')
# Silence the warnings about model variables shadowing names from the surrounding namespace;
# presumably because cadex_network injects synapse parameters (e.g. Ee, Ei) into globals() that
# may also be defined as model variables -- confirm.
BrianLogger.suppress_hierarchy('brian2.groups.group.Group.resolve.resolution_conflict')
#GLOBAL DEFAULTS
# Default network arguments. cadex_network overlays any user-supplied network_args on a copy
# of this dict, so every key below is always available inside the simulation.
NETWORK_ARGS = dict(
    N=1e3,                   # total number of neurons (CRH + GABA) in the single NeuronGroup
    p_ei=0.02,               # connection probability CRH -> GABA (random connectivity)
    p_ie=0.02,               # connection probability GABA -> CRH (random connectivity)
    connectivity='random',   # 'random', 'threeway', a connection_matrix_gen function name, or a dict of matrices
    connectivity_params={},  # keyword arguments for the connection_matrix_gen function
    init_run=True,           # run 5 s with clamped synapses before the main run
    ext_to_gaba=False,       # also drive the GABA neurons with Poisson input
    N_CRH=500,               # the first N_CRH neurons are CRH
    N_GABA=500,              # the following N_GABA neurons are GABA
    ECRH_params=None,        # excitatory Poisson -> CRH connectivity (None = one-to-one)
    ICRH_params=None,        # inhibitory Poisson -> CRH connectivity (None = one-to-one)
    EI_params=None,          # not read by cadex_network
    IE_params=None,          # not read by cadex_network
    record_poisson=False,    # not read by cadex_network
    record_all=True,         # not read by cadex_network
)  #default network args
# joblib file holding the sampled CRH cell parameters (one row per sampled cell).
# NOTE: it is loaded at import time and the relative path is resolved against the current
# working directory, not against this file's location.
CRH_PARAM_FILE = "CRH_FILT_sampled_cell.joblib"
_CRH_DEFAULT_PARAMS = joblib.load(CRH_PARAM_FILE)
# Per-cell CRH neuron parameters with brian2 units; every entry holds one value per sampled cell.
# The constant entries (pr, bd) are sized from the loaded data rather than a hard-coded 500, so
# all arrays stay the same length whatever the number of sampled cells.
CRH_PARAMS = {
    'tauw': _CRH_DEFAULT_PARAMS['tauw'].to_numpy()*ms,   # adaptation time constant
    'a': _CRH_DEFAULT_PARAMS['a'].to_numpy()*nS,         # maximal (sigmoid-gated) adaptation conductance
    'b': _CRH_DEFAULT_PARAMS['b'].to_numpy()*nS,         # adaptation increment per spike
    'C': _CRH_DEFAULT_PARAMS['Cm'].to_numpy() * pF,      # membrane capacitance
    'taum': _CRH_DEFAULT_PARAMS['taum'].to_numpy()*ms,   # membrane time constant
    # leak conductance derived from the capacitance and membrane time constant
    'gL': (_CRH_DEFAULT_PARAMS['Cm'].to_numpy() * pF) / (_CRH_DEFAULT_PARAMS['taum'].to_numpy()*ms),
    'DeltaT': _CRH_DEFAULT_PARAMS['DeltaT'].to_numpy() *mV,  # slope factor of the spike exponential
    'DeltaA': _CRH_DEFAULT_PARAMS['DeltaA'].to_numpy() *mV,  # slope of the adaptation activation sigmoid
    'pr': np.full(len(_CRH_DEFAULT_PARAMS['tauw']), 1),     # GABA release probability, starts at 1
    'Va': _CRH_DEFAULT_PARAMS['Va'].to_numpy() * mV,        # half-activation voltage of adaptation
    'EL': _CRH_DEFAULT_PARAMS['EL'].to_numpy() *mV,         # leak reversal potential
    'Ea': _CRH_DEFAULT_PARAMS['Ea'].to_numpy() *mV,         # adaptation reversal potential
    'VT': _CRH_DEFAULT_PARAMS['VT'].to_numpy() *mV,         # exponential threshold
    'VR': _CRH_DEFAULT_PARAMS['VR'].to_numpy() *mV,         # reset potential
    # numerical spike cut-off, 5 slope factors above the exponential threshold
    'Vcut': (_CRH_DEFAULT_PARAMS['VT'].to_numpy() + 5 * _CRH_DEFAULT_PARAMS['DeltaT'].to_numpy()) *mV,
    'bd': np.full(len(_CRH_DEFAULT_PARAMS['tauw']), 0) * nS,  # extra adaptation increment on top of b
}
# Default synaptic and input parameters. cadex_network merges user overrides onto a copy of this
# dict and then injects every entry into the module globals, so brian2 equations / synapse code
# and the simulation code can refer to them by bare name (e.g. wexti, taup, exte_in).
SYNAPSE_PARAMS = dict(
    Ee = 0 * mvolt,  # excitatory reversal potential
    Ei = -80 * mvolt,  # inhibitory (GABA) reversal potential
    taue = 8.901726144 * msecond,  # excitatory synaptic time constant; copied to tauer/taued unless both are given
    taucrh = 350.5817492 * msecond,  # time constant of the CRH conductance (ycrh/CRH); also GABA tauer/taued in non-"mono_exp" mode
    taun = 1 * msecond,  # not referenced in the code visible in this file
    taui = 14.70811669* msecond,  # inhibitory synaptic time constant (ygi/gi)
    wcrh = 0.016626479 * nS,  # CRH -> GABA increment (EI_model)
    we = 6.14418288 * nS,  # excitatory Poisson -> CRH increment (ECRH_model)
    weinh = 6.14418288 * nS,  # excitatory Poisson -> GABA increment (EIEXT_model, ext_to_gaba)
    wi = 10.50052728 * nS,  # GABA -> CRH increment (IE_model)
    wexti = 0.922915983 * nS,  # inhibitory Poisson -> CRH increment (ICRH_model)
    taup = 40* second,  # recovery time constant of the release probability pr
    taubr = 80 * second,  # decay time constant of the added adaptation bd
    exti_in = 8.869675194 * Hz, # rate of each inhibitory Poisson source; todo move to network args
    exte_in = 73.65083711 * Hz, # rate of each excitatory Poisson source; todo move to network args
    Nexc = 500, # number of excitatory Poisson sources; todo move to network args
    Ninh = 500, # number of inhibitory Poisson sources; todo move to network args
    # NOTE: the neuron model also defines a per-neuron p_exti (set to 1 by cadex_network),
    # which takes precedence over this global inside brian2 code.
    p_exti = 1,
    input_mult = 2.246328156,  # multiplier of exte_in for the double_EPSP / triple_EPSP events
    b_change = 5,  # bd = b * b_change in the *_adapt events
    tauw_change = 1,  # not referenced in the code visible in this file
    pr_change_max = 0.1,  # value / upper bound of pr in the disconnect_GABA* events
    # "mono_exp": both populations use tauer/taued; anything else ("inst"): CRH uses taue, GABA uses taucrh
    SYNAPSE_TYPE = "inst"
)
# Synapse model tables: model_key -> brian2 code strings for Synapses(model=, on_pre=, on_post=).
# cadex_network picks a model via fetch_syn_models(network_args, model_key): an entry under the same
# key in network_args overrides these defaults, so the alternative tables below are meant to be
# passed in network_args.
# Conventions shared by all entries:
#   * every increment is clipped to [0, clamp_lim] nS (clamp_lim is a neuron variable in EQS);
#   * "int(rand()<p)" makes each presynaptic spike transmit with probability p;
#   * unsuffixed neuron variables (ge, gi, ycrh, CRH, pr, p_e, ...) refer to the postsynaptic neuron
#     in brian2 synapse code.
DEFAULT_SYNAPSE_MODELS = {
    # CRH -> GABA: increments ycrh, which drives the slow CRH conductance (time constant taucrh)
    "EI_model": {"pre": '''ycrh=clip((ycrh + wcrh), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    # GABA -> CRH: increments gi with probability pr
    "IE_model": {"pre": '''gi=clip((wi*int(rand()<pr)) + gi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    # excitatory Poisson -> CRH: increments ge directly with probability p_e
    "ECRH_model": {"pre": '''ge=clip((ge + we*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    # excitatory Poisson -> GABA (only used when network_args['ext_to_gaba'] is True)
    "EIEXT_model": {"pre": '''ge=clip((ge + weinh*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    # inhibitory Poisson -> CRH: increments gi with probability pr
    "ICRH_model": {"pre": '''gi=clip((wexti*int(rand()<pr)) + gi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    #this is a special case where we have fixed spike injection
    # NOTE(review): cadex_network implements fixed spike injection via network_args['ECRH_model']['spike_idx'];
    # this 'single_neuron' entry is not referenced in the visible code, and v_pre + randn()*5 likely fails
    # brian2's unit check as written (randn()*5 is dimensionless).
    "single_neuron": {"pre": "", "post":"", "model":"v_syn = v_pre + (randn() * 5): volt"},
    #special case for E->E connections
    # used by the 'threeway' connectivity; we_e is not in SYNAPSE_PARAMS and must be supplied by the caller
    "EE_model": {"pre": '''ge=clip((ge + we_e*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
}
# Same as DEFAULT_SYNAPSE_MODELS except EI_model increments the CRH conductance directly
# (instantaneous rise) instead of going through ycrh.
ALL_FAST_SYNAPSE_MODELS = {
    "EI_model": {"pre": '''CRH=clip((CRH+ wcrh), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "IE_model": {"pre": '''gi=clip((wi*int(rand()<pr)) + gi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "ECRH_model": {"pre": '''ge=clip((ge + we*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "EIEXT_model": {"pre": '''ge=clip((ge + weinh*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "ICRH_model": {"pre": '''gi=clip((wexti*int(rand()<pr)) + gi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    #this is a special case where we have fixed spike injection
    "single_neuron": {"pre": "", "post":"", "model":"v_syn = v_pre + (randn() * 5): volt"},
    #special case for E->E connections
    "EE_model": {"pre": '''ge=clip((ge + we_e*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
}
# Excitatory input increments y (which then drives ge through taued) instead of ge directly;
# external inhibition uses p_exti instead of pr. No EIEXT_model / EE_model entries.
MONO_EXP_SYNAPSE_MODELS = {
    "EI_model": {"pre": '''ycrh=clip((ycrh + wcrh), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "IE_model": {"pre": '''gi=clip((wi*int(rand()<pr)) + gi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "ECRH_model": {"pre": '''y=clip((y + we*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "ICRH_model": {"pre": '''gi=clip((wexti*int(rand()<p_exti)) + gi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "single_neuron": {"pre": "", "post":"", "model":"v_syn = v_pre + (randn() * 5): volt"}}
# Inhibition increments ygi (which then drives gi through taui) for a slower GABA conductance;
# EI_model increments CRH directly; external inhibition uses p_exti.
SLOW_GABA_SYNAPSE_MODELS = {
    "EI_model": {"pre": '''CRH=clip((CRH+ wcrh), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "IE_model": {"pre": '''ygi=clip((wi*int(rand()<pr)) + ygi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "ECRH_model": {"pre": '''ge=clip((ge + we*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "ICRH_model": {"pre": '''ygi=clip((wexti*int(rand()<p_exti)) + ygi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "single_neuron": {"pre": "", "post":"", "model":"v_syn = v_pre + (randn() * 5): volt"}}
# Short-term depression of the excitatory input: each synapse has a 'dep' factor that recovers
# towards 1 with time constant taud; each spike scales the increment by dep and then reduces dep
# by the fraction mu. taud and mu are not in SYNAPSE_PARAMS and must be supplied by the caller.
# This is the 'dep' variable recorded when cadex_network(monitor_synapse=True).
_syn_dep = '''ddep/dt =(1-dep)/taud : 1 (clock-driven)
'''
_syn_dep_pre = '''
y=clip((y + (we*dep)*int(rand()<p_e)), 0*nS, clamp_lim*nS);
dep -= mu*dep;
'''
EXC_DEPRESSION_MODEL = {
    "ECRH_model": {"model": _syn_dep, "pre": _syn_dep_pre, "post": ""},
}
# Shared neuron model for the CRH and GABA populations (one NeuronGroup, split into subgroups).
# Exponential integrate-and-fire membrane with a voltage-gated adaptation conductance w
# (reversal Ea), and conductance-based synapses:
#   * ge: driven by y (y decays with tauer, ge relaxes to y with taued) -> rise/decay kinetics;
#   * gi: driven by ygi, both with taui;
#   * CRH: driven by ycrh, both with taucrh (slow CRH-mediated excitation, reversal Ee).
# dv/dt is multiplied by dyn_neur and int(not_refractory), so v is frozen while refractory or when
# dyn_neur == 0.
# pr follows logistic growth pr*(1-pr)/taup, i.e. it relaxes back towards 1 after being lowered
# (pr == 0 exactly is a fixed point and would never recover).
# br = b + bd is the adaptation increment (cadex_network's reset adds br to w); bd decays with taubr.
# NOTE: I (and d_I) are not defined here; cadex_network appends them before building the group.
# NOTE(review): taue, p_w and bool_allow_spike are declared but not read by any equation here, nor
# by the threshold used in cadex_network ('v>Vcut') -- confirm whether they are still needed.
EQS = Equations('''
#membrane equations
dv/dt = ( gL*(EL-v) + gL*DeltaT*exp((v - VT)/DeltaT) + w*(Ea - v) + ge * (Ee-v) + gi * (Ei-v) + CRH * (Ee-v) + I ) * (1./C) * dyn_neur *int(not_refractory) : volt
dw/dt = ( a/(1 + exp((Va - v)/DeltaA)) - w ) / tauw : siemens
#synapse equations, currently in the form of a conductance based synapse
dge/dt = (y-ge)/taued : siemens
dy/dt = -y/tauer : siemens
dygi/dt = -ygi/taui : siemens
dgi/dt = (ygi-gi)/taui : siemens
dycrh/dt = -ycrh/taucrh : siemens
dCRH/dt = (ycrh-CRH)/taucrh : siemens
#parameters for flipping GABA pr in a neuron specific manner
dpr/dt = (1./taup) * pr * (1 - (pr/1)): 1 #probability of release, decays back to 1
p_w = clip(pr, 0.01,1): 1 #not used?
p_e : 1 #excitatory probability
p_exti : 1 #external inhibitory probability
br = b + bd : siemens #adaptation incremented per spike (br) is the sum of the default (b) and the added adaptation (bd)
dbd/dt = -bd*(1./taubr): siemens
taue: second #excitatory decay
tauer : second #excitatory rise
taued : second #excitatory decay
refrac : second #refractory period
clamp_lim : 1 #Limit imposed on synaptic conductance
#neuron parameters
Vcut: volt (constant)
a : siemens
b : siemens
Va : volt (constant)
Ea : volt (constant)
DeltaT : volt (constant)
DeltaA : volt (constant)
VT : volt
VR : volt
tauw : second
C : farad (constant)
taum : second (constant)
gL : siemens (constant)
EL : volt
bool_allow_spike : 1
dyn_neur : 1 (constant) #this is a flag to indicate if the neuron is dynamic clamped, really should not be changed during a run
''')
# The GABA population is homogeneous: every parameter is the value of the first sampled CRH cell...
GABA_PARAMS = dict((name, values[0]) for name, values in CRH_PARAMS.items())
# ...except the leak reversal potential, which differs for the GABA neurons.
GABA_PARAMS['EL'] = -68.38802621 * mV
# Labels of the predefined in silico experiments, keyed by experiment number.
# Fractional keys are variants of the integer experiment of the same family
# (1.5: extended run, 3.5 / 4.5: tonic mode).
EXP_STRS = dict([
    (0, "EXP_0"),             # no input current, simulated EPSPs, no switch
    (1, "EXP_1"),             # no input current, simulated EPSPs, switch @ 15, 35
    (1.5, "EXP_1_EXTENDED"),  # no input current, simulated EPSPs, switch @ 150, 350
    (2, "EXP_2"),             # no input current, in vitro EPSPs, switch @ 15, 35
    (3, "EXP_3"),             # current pulses, simulated EPSPs
    (3.5, "EXP_3.5"),         # current pulses, simulated EPSPs, tonic mode
    (4, "EXP_4"),             # current pulses, in vitro EPSPs
    (4.5, "EXP_4.5"),         # current pulses, in vitro EPSPs, tonic mode
    (5, "EXP_5"),             # increasing EPSPs, simulated EPSPs, burst mode
    (6, "EXP_6"),             # increasing EPSPs, simulated EPSPs, tonic mode
    (7, "EXP_7"),             # increasing EPSPs, simulated EPSPs across a grid
    (8, "EXP_8"),             # increasing EPSPs, simulated EPSPs across a grid
])
# State variables recorded by default by cadex_network's StateMonitor.
VARS_TO_REC = ['v', 'br', 'tauw', 'pr', 'p_e', 'd_I', 'ge', 'gi', "I", "CRH"]
# Functional version of the network, with a nightmare level of parameters; a class version would
# be more readable, but brian2's namespace handling made it hard to get working.
def cadex_network(run_time=60, network_args=None,
                  CRH_param_file=None, CRH_params=None, GABA_param_file=None, GABA_params=None,
                  synapse_params=None, network_events=None, medianem=False,
                  single_neuron=False, input_current=None, Ei_array=None, Ee_array=None,
                  vars_to_rec=VARS_TO_REC, monitor_synapse=False,
                  verbose='text', rseed=50):
    """Run the CRH/GABA network simulation for the specified run time.

    Includes the ability to change the input EPSP rate and the inhibition strength (GABA Pr), and
    to apply manipulations to the network while it is running.

    Takes:
        run_time (int): the time to run the simulation for, in seconds (after the optional 5 s
            initialisation run controlled by network_args['init_run']).
        network_args (dict or None): overrides merged onto a copy of NETWORK_ARGS. It may also hold
            custom synapse models under 'EI_model', 'IE_model', 'ECRH_model', 'ICRH_model',
            'EIEXT_model', 'EE_model' (see fetch_syn_models), and 'p_ee', which is required for
            connectivity='threeway' and has no default.
        CRH_param_file, GABA_param_file (str or None): joblib files of sampled cell parameters,
            see fetch_neuron_parameters.
        CRH_params, GABA_params (dict or None): per-parameter overrides for the CRH / GABA neurons.
        synapse_params (dict or None): overrides merged onto a copy of SYNAPSE_PARAMS. Unless both
            'tauer' and 'taued' are given, both are set to 'taue'. Every entry is injected into the
            module globals so equations and synapse code can refer to it by name.
        network_events (dict or None): manipulations to be applied while running, in the format
            {<time (in seconds)> : <event string>}. Events are applied in chronological order.
            Possible <event strings> include:
            double_EPSP => sets the excitatory Poisson rate to input_mult * exte_in.
            triple_EPSP => sets the excitatory Poisson rate to 3 * input_mult * exte_in.
            disconnect_GABA_increase_adapt => sets the GABA Pr of every neuron to pr_change_max and
                sets the CRH extra adaptation bd to b * b_change (bd then decays with taubr), putting
                the network into the single spiking mode observed in vivo.
            disconnect_GABA => draws a random GABA Pr in [0, pr_change_max) per neuron, without
                changing adaptation.
            increase_adapt => sets the CRH extra adaptation bd to b * b_change, without the loss of
                inhibition.
            eval: <python code> => executes the python code after "eval: ". Useful for custom
                manipulations; only use trusted strings.
        medianem (bool): if True, add a single unit driven by every CRH spike and record its
            'y' and 'crh' (presumably a median eminence CRH readout). 'taumer' and 'taumed' must be
            resolvable by name (e.g. passed in synapse_params); they are not in SYNAPSE_PARAMS.
        single_neuron (bool): currently unused.
        input_current (TimedArray or None): injected current, evaluated as input_current(t, i).
        Ei_array, Ee_array (TimedArray or None): time/neuron dependent inhibitory / excitatory
            reversal potential, evaluated as Ei_array(t, i) / Ee_array(t, i); replaces Ei / Ee.
        vars_to_rec (list of str): state variables recorded for neurons 0-9 and the first 5 GABA
            neurons.
        monitor_synapse (bool): also define and record I_e, I_i and I_CRH, and record the synaptic
            'dep' variable of the excitatory input (requires a depressing ECRH model such as
            EXC_DEPRESSION_MODEL).
        verbose (str or None): passed as ``report`` to brian2's run().
        rseed (int): seed passed to brian2's seed().
    returns:
        dict with
            'spikes' (SpikeMonitor): spikes of all neurons
            'states' (StateMonitor): the recorded state variables
            'EI_connect', 'IE_connect', 'poisson_connect' (ndarray): 2 x n_synapses arrays of
                (presynaptic, postsynaptic) indices
            'poisson_spikes' (SpikeMonitor): spikes of the excitatory Poisson input
            'ME' (StateMonitor): only if medianem
            'syn_mon' (StateMonitor): only if monitor_synapse
    """
    # NOTE: brian2 resolves the external names used in equations / synapse code (taup, wexti,
    # input_current, Ei_array, ...) from the namespace of the code that calls run(), and the magic
    # run() only simulates Brian objects it can find there. That is why every object is built
    # inline and bound to a local name here; 'eval:' events may also refer to these local names.
    seed(rseed)
    start_time = pyt.time()
    start_scope()
    if network_events is None:
        network_events = {}
    #% network parameters: user values override a copy of the defaults
    network_args = {**NETWORK_ARGS, **(network_args or {})}
    #% neuron parameters, drawn from the sampled cell prior
    CRH_params = fetch_neuron_parameters(CRH_param_file, CRH_params if CRH_params is not None else {}, cell="CRH")
    GABA_params = fetch_neuron_parameters(GABA_param_file, GABA_params if GABA_params is not None else {}, cell="GABA")
    #% synapse parameters: user values override a copy of the defaults
    synapse_params = {**SYNAPSE_PARAMS, **(synapse_params or {})}
    # taue is now split into tauer and taued (rise and decay). Unless both are given explicitly,
    # they are both set to taue.
    if 'tauer' in synapse_params and 'taued' in synapse_params:
        print("using split taue")
    elif 'taue' in synapse_params:
        print("taue found, splitting into tauer and taued")
        synapse_params['tauer'] = synapse_params['taue']
        synapse_params['taued'] = synapse_params['taue']
    else:
        # unreachable with the stock SYNAPSE_PARAMS, which always provides taue
        print("WARNING: no taue, tauer, or taued passed in, using default taue")
        synapse_params['tauer'] = synapse_params['taue']
        synapse_params['taued'] = synapse_params['taue']
    # Inject every synapse parameter into the module globals so equations, synapse code and the
    # rest of this function can use them by bare name. Known hack: values persist in the module
    # between calls in the same process.
    globals().update(synapse_params)
    #% model equations
    eqs = EQS
    eqs += Equations('d_I = (ge * (Ee-v) + gi * (Ei-v) + w*(Ea - v)) + CRH * (Ee-v): amp')
    if monitor_synapse:  # expose the individual synaptic currents so they can be recorded
        eqs += Equations('I_e = ge * (Ee-v): amp')
        eqs += Equations('I_i = gi * (Ei-v): amp')
        eqs += Equations('I_CRH = CRH * (Ee-v): amp')
    if input_current is not None:
        eqs += Equations('I = input_current(t, i) : ampere')
    else:
        eqs += Equations('I = 0 * amp : ampere')
    # for EGABA experiments: a TimedArray replaces the constant reversal potential with a
    # time/neuron dependent one (the model variable takes precedence over the global of the same name)
    if isinstance(Ei_array, TimedArray):
        eqs += Equations('Ei = Ei_array(t,i) : volt')
    if isinstance(Ee_array, TimedArray):
        eqs += Equations('Ee = Ee_array(t,i) : volt')
    #% build network
    P = NeuronGroup(network_args['N'], model=eqs, threshold='v>Vcut', reset='v=VR; w+=br', refractory='refrac', method='heun')
    P.pr = 1
    P.p_e = 1
    P.p_exti = 1
    P.refrac = 1 * ms
    # start every neuron from the first sampled CRH cell; the CRH and GABA subgroups are refined below
    P.C = CRH_params['C'][0]
    P.taum = CRH_params['taum'][0]
    P.gL = CRH_params['gL'][0]
    P.DeltaT = CRH_params['DeltaT'][0]
    P.DeltaA = CRH_params['DeltaA'][0]
    P.Va = CRH_params['Va'][0]
    P.Ea = CRH_params['Ea'][0]
    P.VT = CRH_params['VT'][0]
    P.VR = CRH_params['VR'][0]
    P.tauw = CRH_params['tauw'][0]
    P.a = CRH_params['a'][0]
    P.b = CRH_params['b'][0]
    P.EL = -68.38802621*mV
    P.dyn_neur = 1
    #% neuron group construction: the first N_CRH neurons are CRH, the next N_GABA are GABA
    CRH = P[:network_args['N_CRH']]
    GABA = P[network_args['N_CRH']:network_args['N_GABA']+network_args['N_CRH']]
    # set the CRH parameters to the full (per-cell) arrays
    if len(list(CRH_params.keys())) > 1:
        states = CRH.get_states(read_only_variables=False)
        for key, val in CRH_params.items():
            if key in states.keys():
                print(f"setting {key} to {np.nanmean(val)}; for CRH, which is {np.nanmean(states[key])}")
                # scalars (including 0-d unit quantities) cannot be sliced: apply them to every CRH neuron
                states[key] = val if np.ndim(val) == 0 else val[:len(CRH)]
        CRH.set_states(states, units=True)
    # synaptic specific decays; in the default mode CRH uses taue and GABA uses taucrh, ignoring tauer/taued
    if synapse_params['SYNAPSE_TYPE'] == "mono_exp":
        GABA.tauer = tauer
        GABA.taued = taued
        CRH.tauer = tauer
        CRH.taued = taued
    else:
        GABA.tauer = taucrh
        GABA.taued = taucrh
        CRH.tauer = taue
        CRH.taued = taue
    # GABA neuron parameters
    if len(list(GABA_params.keys())) > 1:
        states = GABA.get_states(read_only_variables=False)
        for key, val in GABA_params.items():
            if key in states.keys():
                print(f"setting {key} to {np.nanmean(val)}; for GABA, which is {np.nanmean(states[key])}")
                if isinstance(val, np.ndarray):  # scalar values with units are (0-d) ndarrays too
                    # numpy returns a 0-d array for a single value, and np.isscalar does not work on quantities
                    if len(val.shape) > 0:
                        if val.shape[0] == len(GABA):
                            states[key] = val
                            continue
                        states[key] = val[:len(GABA)]
                    else:
                        states[key] = val  # 0-d: a scalar, set it directly
                else:  # not an ndarray: set it directly
                    states[key] = val
        GABA.set_states(states, units=True)
    #% end neuron group construction
    #% synapse construction
    if network_args['connectivity'] == 'random' or network_args['connectivity'] is None:
        print("using random connectivity")
        EI_model, EI_model_pre, EI_model_post = fetch_syn_models(network_args, model_key='EI_model')
        IE_model, IE_model_pre, IE_model_post = fetch_syn_models(network_args, model_key='IE_model')
        EI = Synapses(CRH, GABA, model=EI_model, on_pre=EI_model_pre, on_post=EI_model_post)
        IE = Synapses(GABA, CRH, model=IE_model, on_pre=IE_model_pre, on_post=IE_model_post)
        EI.connect(True, p=network_args['p_ei'])
        IE.connect(True, p=network_args['p_ie'])
    elif isinstance(network_args['connectivity'], str):
        if network_args['connectivity'] == 'threeway':
            print("using three way connectivity")
            # Split the CRH neurons into two halves and connect one-way:
            # CRH1 -> CRH2 (EE), CRH2 -> GABA (EI), GABA -> CRH1 (IE).
            CRH1 = CRH[:int(len(CRH)/2)]
            CRH2 = CRH[int(len(CRH)/2):]
            EI_model, EI_model_pre, EI_model_post = fetch_syn_models(network_args, model_key='EI_model')
            IE_model, IE_model_pre, IE_model_post = fetch_syn_models(network_args, model_key='IE_model')
            EE_model, EE_model_pre, EE_model_post = fetch_syn_models(network_args, model_key='EE_model')
            EI = Synapses(CRH2, GABA, model=EI_model, on_pre=EI_model_pre, on_post=EI_model_post)
            IE = Synapses(GABA, CRH1, model=IE_model, on_pre=IE_model_pre, on_post=IE_model_post)
            EE = Synapses(CRH1, CRH2, model=EE_model, on_pre=EE_model_pre, on_post=EE_model_post)
            EI.connect(True, p=network_args['p_ei'])
            IE.connect(True, p=network_args['p_ie'])
            EE.connect(True, p=network_args['p_ee'])  # p_ee has no default in NETWORK_ARGS
        else:
            # any other string names a generator function in connection_matrix_gen that returns
            # (EI, IE) arrays of (pre, post) index pairs
            EI_model, EI_model_pre, EI_model_post = fetch_syn_models(network_args, model_key='EI_model')
            IE_model, IE_model_pre, IE_model_post = fetch_syn_models(network_args, model_key='IE_model')
            EI = Synapses(CRH, GABA, model=EI_model, on_pre=EI_model_pre, on_post=EI_model_post)
            IE = Synapses(GABA, CRH, model=IE_model, on_pre=IE_model_pre, on_post=IE_model_post)
            function = getattr(cmg, network_args['connectivity'])
            print(f"using {network_args['connectivity']} connectivity")
            print(f"with params {network_args['connectivity_params']}")
            EI_connectivity, IE_connectivity = function(**network_args['connectivity_params'])
            EI.connect(i=EI_connectivity[:, 0], j=EI_connectivity[:, 1])
            IE.connect(i=IE_connectivity[:, 0], j=IE_connectivity[:, 1])
    elif isinstance(network_args['connectivity'], dict):
        EI_model, EI_model_pre, EI_model_post = fetch_syn_models(network_args, model_key='EI_model')
        IE_model, IE_model_pre, IE_model_post = fetch_syn_models(network_args, model_key='IE_model')
        EI = Synapses(CRH, GABA, model=EI_model, on_pre=EI_model_pre, on_post=EI_model_post)
        IE = Synapses(GABA, CRH, model=IE_model, on_pre=IE_model_pre, on_post=IE_model_post)
        print("using custom connectivity; assuming the user has passed in the connectivity matrices")
        EI_connectivity = network_args['connectivity']['EI_connect']
        IE_connectivity = network_args['connectivity']['IE_connect']
        EI.connect(i=EI_connectivity[:, 0], j=EI_connectivity[:, 1])
        IE.connect(i=IE_connectivity[:, 0], j=IE_connectivity[:, 1])
    else:
        EI_model, EI_model_pre, EI_model_post = fetch_syn_models(network_args, model_key='EI_model')
        IE_model, IE_model_pre, IE_model_post = fetch_syn_models(network_args, model_key='IE_model')
        EI = Synapses(CRH, GABA, model=EI_model, on_pre=EI_model_pre, on_post=EI_model_post)
        IE = Synapses(GABA, CRH, model=IE_model, on_pre=IE_model_pre, on_post=IE_model_post)
        print("connectivity not understood, using random")
        EI.connect(True, p=network_args['p_ei'])
        IE.connect(True, p=network_args['p_ie'])
    if medianem:
        ME = NeuronGroup(1, model='''
        dy/dt = -y*(1./taumer) : 1
        dcrh/dt = (y-crh)/taumed : 1
        ''', method="euler")
        crhrelease = Synapses(CRH, ME, on_pre='y+=1')
        crhrelease.connect()
        ME_mon = StateMonitor(ME, ['y', 'crh'], record=True)
    #% end synapse construction
    #% poisson input to CRH; custom ECRH / ICRH models in network_args take precedence
    ECRH_model, ECRH_model_pre, ECRH_model_post = fetch_syn_models(network_args, model_key='ECRH_model')
    ICRH_model, ICRH_model_pre, ICRH_model_post = fetch_syn_models(network_args, model_key='ICRH_model')
    if "ECRH_model" in network_args.keys():
        if 'spike_idx' in network_args['ECRH_model'].keys():  # special case: fixed spike injection
            print("using fixed spike injection")
            spike_idx = network_args["ECRH_model"]['spike_idx']
            spike_t = network_args["ECRH_model"]['spike_t']  # spike times in seconds
            poisson_exc = SpikeGeneratorGroup(Nexc, spike_idx, spike_t*second)
            ECRH_model, ECRH_model_pre, ECRH_model_post = fetch_syn_models({}, model_key='ECRH_model')
        else:
            poisson_exc = PoissonGroup(Nexc, exte_in)
    else:
        poisson_exc = PoissonGroup(Nexc, exte_in)
    input = Synapses(poisson_exc, CRH, model=ECRH_model, on_pre=ECRH_model_pre, on_post=ECRH_model_post)
    if network_args['ECRH_params'] is None:  # just do one to one
        input.connect(i='j')
    elif isinstance(network_args['ECRH_params'], dict):
        print(f"Found ECRH connectivity params {network_args['ECRH_params']}")
        input.connect(**network_args['ECRH_params'])  # keyword arguments for Synapses.connect
    elif isinstance(network_args['ECRH_params'], np.ndarray):
        # a connectivity matrix of (pre, post) index pairs
        input.connect(i=network_args['ECRH_params'][:, 0], j=network_args['ECRH_params'][:, 1])
    else:
        print("ECRH params not understood, connecting 1 to 1")
        input.connect(i='j')
    # inhibitory input
    poisson_inh = PoissonGroup(Ninh, exti_in)
    input2 = Synapses(poisson_inh, CRH, model=ICRH_model, on_pre=ICRH_model_pre, on_post=ICRH_model_post)
    if network_args['ICRH_params'] is None:  # just do one to one
        input2.connect(i='j')
    elif isinstance(network_args['ICRH_params'], dict):
        input2.connect(**network_args['ICRH_params'])  # keyword arguments for Synapses.connect
    elif isinstance(network_args['ICRH_params'], np.ndarray):
        # a connectivity matrix of (pre, post) index pairs
        input2.connect(i=network_args['ICRH_params'][:, 0], j=network_args['ICRH_params'][:, 1])
    else:
        print("ICRH params not understood, connecting ,1 to 1")
        input2.connect(i='j')
    # optionally also drive the GABA neurons with their own Poisson inputs
    if network_args['ext_to_gaba']:
        print("ext to gaba")
        EIEXT_model, EIEXT_model_pre, EIEXT_model_post = fetch_syn_models(network_args, model_key='EIEXT_model')
        poisson_exc2 = PoissonGroup(Nexc, exte_in)
        poisson_inh2 = PoissonGroup(Ninh, exti_in)
        input3 = Synapses(poisson_inh2, GABA, model=ICRH_model, on_pre=ICRH_model_pre, on_post=ICRH_model_post)
        input3.connect(i='j')
        input4 = Synapses(poisson_exc2, GABA, model=EIEXT_model, on_pre=EIEXT_model_pre, on_post=EIEXT_model_post)
        if network_args['ECRH_params'] is not None:
            input4.connect(**network_args['ECRH_params'])
        else:
            input4.connect(i='j')
    #% end poisson input
    print(f"wexti: {wexti/nS}")
    # init vars and monitors
    P.v = -80*mV
    P.ge = 0
    P.gi = 0
    P.clamp_lim = 999
    M = SpikeMonitor(P)
    _p_spk = SpikeMonitor(poisson_exc)
    if monitor_synapse:
        _syn_mon = StateMonitor(input, ['dep'], record=[0, 2, 3, 4, 5])
        vars_to_rec = np.copy(vars_to_rec).tolist()  # copy so the default list is not modified
        vars_to_rec += ['I_e', 'I_i', 'I_CRH']
    # record neurons 0-9 and the first 5 GABA neurons
    V = StateMonitor(P, vars_to_rec,
                     record=[*np.arange(min(10, network_args['N'])),
                             *np.arange(network_args['N_CRH'], network_args['N_CRH']+min(5, network_args['N_GABA']))])
    # run simulation
    print("=== Net Sim Start ===")
    if network_args.get('init_run'):
        # run with limited synaptic conductance to get the network to a stable state
        P.clamp_lim = 1
        P.bool_allow_spike = 0
        run(5*second, report=verbose)
        P.clamp_lim = 99
        P.bool_allow_spike = 1
    if not network_events:
        # no events: just run as is
        run(run_time*second, report=verbose)
    else:
        # sort events chronologically so every segment length between events is non-negative,
        # whatever the insertion order of the dict
        event_times = sorted(network_events.keys())
        run_time = np.diff(np.hstack([0, *event_times, run_time]))  # seconds between consecutive events
        events = np.hstack(['none', *[network_events[t_ev] for t_ev in event_times]])  # event applied before each segment
        print(events)
        print(run_time)
        for time, event in zip(run_time, events):
            if event == 'double_EPSP':
                print("Increasing EPSP")
                poisson_exc.rates = input_mult*exte_in
            if event == 'triple_EPSP':
                print("Increasing EPSP")
                poisson_exc.rates = 3*input_mult*exte_in
            if event == 'disconnect_GABA_increase_adapt':
                print("Removing GABA and increasing Adaptation")
                pr_ar = synapse_params['pr_change_max']
                b_ar = CRH_params['b'] * synapse_params['b_change']
                P.pr = pr_ar
                CRH.bd = b_ar
            if event == 'disconnect_GABA':
                print("Removing GABA ONLY")
                # one random release probability per neuron in the group (previously hard-coded to 1000)
                pr_ar = np.full(len(P), np.random.rand(len(P))*synapse_params['pr_change_max'])
                P.pr = pr_ar
            if event == 'increase_adapt':
                print("increasing Adaptation")
                b_ar = CRH_params['b'] * synapse_params['b_change']
                CRH.bd = b_ar
            if "eval" in event:
                print("Evaluating")
                # split only on the first ": " so the code itself may contain ": "
                eval_statement = event.split(": ", 1)[1]
                # SECURITY: executes arbitrary code from the event string; only pass trusted events
                exec(eval_statement)
            run(time*second, report=verbose)
    print(f"simulation took {(pyt.time()-start_time)/60}")
    # out dict
    ret_dict = {'spikes': M, 'states': V, 'EI_connect': np.vstack((EI.i, EI.j)),
                'IE_connect': np.vstack((IE.i, IE.j)), 'poisson_spikes': _p_spk,
                'poisson_connect': np.vstack((input.i, input.j))}
    if medianem:
        ret_dict['ME'] = ME_mon
    if monitor_synapse:
        ret_dict['syn_mon'] = _syn_mon
    return ret_dict
def fetch_syn_models(params, model_key='ECRH_model'):
    """Fetch the (model, on_pre, on_post) code strings for one synapse type.

    If ``params`` contains ``model_key``, that user-supplied entry is used. Any of its
    'model' / 'pre' / 'post' fields that are missing fall back to
    DEFAULT_SYNAPSE_MODELS[model_key]. This lets an entry carry only other settings, e.g. an
    'ECRH_model' holding just 'spike_idx' / 'spike_t' for fixed spike injection. Without a
    user entry, the defaults are returned.

    takes:
        params (dict): the params (typically network_args) that may hold a custom model under model_key
        model_key (str): the key to use for the model. Defaults to ECRH_model
    returns:
        model (str): the synapse model string
        pre (str): the on_pre string
        post (str): the on_post string
    raises:
        KeyError: if a field is missing and model_key has no entry in DEFAULT_SYNAPSE_MODELS
    """
    if model_key in params.keys():
        print(f"Found {model_key} in params")
        custom = params[model_key]
        parts = []
        for field in ('model', 'pre', 'post'):
            if field in custom:
                parts.append(custom[field])
            else:
                # only fields the user did not supply come from the defaults
                parts.append(DEFAULT_SYNAPSE_MODELS[model_key][field])
        model, pre, post = parts
    else:
        print(f"Using default {model_key}")
        model = DEFAULT_SYNAPSE_MODELS[model_key]['model']
        pre = DEFAULT_SYNAPSE_MODELS[model_key]['pre']
        post = DEFAULT_SYNAPSE_MODELS[model_key]['post']
    return model, pre, post
def fetch_neuron_parameters(param_file, param_dict, cell="CRH"):
    """Fetch the neuron parameters (with brian2 units) for one cell type as a dictionary.

    Values are layered, later sources overriding earlier ones:
    module defaults (CRH_PARAMS if cell == "CRH", otherwise GABA_PARAMS)
    -> the sampled cells loaded from ``param_file`` -> ``param_dict``.

    takes:
        param_file (str or None): joblib file to load sampled cell parameters from. Each of the
            entries tauw, a, b, Cm, taum, DeltaT, DeltaA, Va, EL, Ea, VT, VR must support
            ``.to_numpy()`` (presumably a pandas DataFrame -- confirm). None skips loading.
        param_dict (dict or None): per-parameter overrides, applied last
        cell (str): "CRH" selects CRH_PARAMS as the base; anything else selects GABA_PARAMS
    returns:
        params (dict): the parameters to use
    """
    params = CRH_PARAMS.copy() if cell == "CRH" else GABA_PARAMS.copy()
    file_params = None
    if param_file is not None:
        sampled_cell = joblib.load(param_file)
        # one value per sampled cell; the constant arrays (pr, bd) are sized to match instead of
        # assuming 500 cells, so every entry has the same length
        n_cells = len(sampled_cell['tauw'])
        file_params = {
            'tauw': sampled_cell['tauw'].to_numpy()*ms,
            'a': sampled_cell['a'].to_numpy()*nS,
            'b': sampled_cell['b'].to_numpy()*nS,
            'C': sampled_cell['Cm'].to_numpy() * pF,
            'taum': sampled_cell['taum'].to_numpy()*ms,
            'gL': (sampled_cell['Cm'].to_numpy() * pF) / (sampled_cell['taum'].to_numpy()*ms),  # C / taum
            'DeltaT': sampled_cell['DeltaT'].to_numpy() *mV,
            'DeltaA': sampled_cell['DeltaA'].to_numpy() *mV,
            'pr': np.full(n_cells, 1),
            'Va': sampled_cell['Va'].to_numpy() * mV,
            'EL': sampled_cell['EL'].to_numpy() *mV,
            'Ea': sampled_cell['Ea'].to_numpy() *mV,
            'VT': sampled_cell['VT'].to_numpy() *mV,
            'VR': sampled_cell['VR'].to_numpy() *mV,
            'Vcut': (sampled_cell['VT'].to_numpy() + 5 * sampled_cell['DeltaT'].to_numpy()) *mV,  # VT + 5*DeltaT
            'bd': np.full(n_cells, 0) * nS,
        }
        print(f"loaded {cell} params from file")
    if file_params is not None:
        params.update(file_params)
    if param_dict:
        # user-supplied values win over both the defaults and the file
        print(f"Overriding default {cell} params with user supplied values")
        params.update(param_dict)
    elif file_params is None:
        print(f"Using default {cell} params")
    return params
#here we pop the params of a scan CSV out into the network_args / synapse_params dicts used by cadex_network
def load_params_from_scan_csv(file):
    """Load one parameter set (the first row) from a parameter-scan CSV file.

    takes:
        file (str): path to a CSV with at least the columns p_ei, p_ie, taue, taui, taucrh, wcrh,
            we, wi, exte_in, exti_in. Time constants are in seconds, wcrh / we / wi in siemens and
            rates in Hz. Optional columns: wexti or exti_ii (nS, default 0), Nexc or N_CRH
            (default 500), ext_to_gaba, connectivity, p.
    returns:
        network_args (dict), synapse_params (dict): overrides to pass to cadex_network
    """
    import pandas as pd  # pd is not among this module's top-level imports

    params = pd.read_csv(file).to_dict(orient='records')[0]
    # the external inhibitory weight may be stored as 'wexti' or 'exti_ii'; default to 0
    if 'wexti' in params:
        params['exti_ii'] = params['wexti']
    elif 'exti_ii' not in params:
        params['exti_ii'] = 0
    if 'Nexc' not in params:
        params['Nexc'] = params['N_CRH'] if 'N_CRH' in params else 500
    network_args = {'p_ei': params['p_ei'], 'p_ie': params['p_ie']}
    synapse_params = {'taue': params['taue']*second, 'taui': params['taui']*second, 'taucrh': params['taucrh']*second,
                      'wcrh': params['wcrh']*siemens, 'we': params['we']*siemens, 'wi': params['wi']*siemens,
                      'exte_in': params['exte_in']*Hz, 'exti_in': params['exti_in']*Hz,
                      'wexti': params['exti_ii']*nS, 'Nexc': params['Nexc']}
    # NOTE(review): cadex_network reads ext_to_gaba from network_args, not synapse_params, so this
    # value does not currently take effect -- confirm before moving it.
    synapse_params['ext_to_gaba'] = params['ext_to_gaba'] if 'ext_to_gaba' in params else False
    # convert the connectivity parameters to the correct format from strings
    if 'connectivity' in params:
        if params['connectivity'] == 'random':
            network_args['connectivity'] = 'random'
            network_args['connectivity_params'] = None
        else:
            # SECURITY: eval() executes arbitrary code stored in the CSV; only load trusted files.
            network_args['connectivity_params'] = eval(params['connectivity'])
            # inter_cluster_num appears to be stored as a fraction of NUM_CLUSTERS; convert it to a count
            network_args['connectivity_params']['inter_cluster_num'] = int(network_args['connectivity_params']['inter_cluster_num']*network_args['connectivity_params']['NUM_CLUSTERS']) + 1
            network_args['connectivity_params']['p_ei'] = params['p_ei']
            network_args['connectivity_params']['p_ie'] = params['p_ie']
            # NOTE(review): network_args['connectivity'] is not set here, so cadex_network falls back to
            # its default ('random') unless the caller sets the generator name -- confirm intended.
    # 'p' is the probability of a CRH neuron being connected to an excitatory input (needs a better key name)
    if 'p' in params:
        network_args['ECRH_params'] = {'p': params['p']}
    else:
        network_args['ECRH_params'] = None
    return network_args, synapse_params
def load_params_from_joblib(path):
    """Load (network_args, synapse_params) saved by save_params_to_joblib.

    If the stored network_args has no 'N_CRH' entry, it is set to 500.

    takes:
        path (str): joblib file holding a dict with 'network_args' and 'synapse_params'
    returns:
        network_args (dict), synapse_params (dict)
    """
    bundle = joblib.load(path)
    network_args, synapse_params = bundle['network_args'], bundle['synapse_params']
    network_args.setdefault('N_CRH', 500)
    return network_args, synapse_params
def load_parameters_from_predefined_exp(parameter_file=None, exp_num=None, tag=None, neuron_num=None, run_time=None, network_args=None):
    """Load, or generate on the fly, the parameters for one predefined in silico experiment.

    Looks for a preconfigured file "<basename>_dyn_clamp_<exp_num>.joblib" next to
    ``parameter_file``. If it exists, its contents are returned. Otherwise the base parameters
    stored in ``parameter_file`` are passed through generate_experimental_dict.

    takes:
        parameter_file (str): joblib file with 'synapse_params' and 'network_args' entries (required)
        exp_num (int or float): experiment number (see EXP_STRS). Integral values (1 or 1.0) are
            normalised to int so the file name reads "_1"; fractional ones (1.5, 3.5, 4.5) are kept.
        tag, neuron_num, network_args: currently unused
        run_time: runtime returned when the preconfigured file has no 'RUNTIME' entry
    returns:
        (synapse_params, network_params, network_events, input_current, RUNTIME)
    raises:
        ValueError: if parameter_file is None
    """
    import copy  # deep copies keep generate_experimental_dict from mutating the loaded dicts

    if parameter_file is None:
        raise ValueError("parameter_file is required to load a predefined experiment")
    basename = os.path.splitext(os.path.basename(parameter_file))[0]
    # fractional experiment numbers (1.5, 3.5, 4.5) are distinct experiments: only normalise
    # integral values, instead of truncating everything but 1.5 with int()
    if float(exp_num).is_integer():
        exp_num = int(exp_num)
    _exp_path = os.path.join(os.path.dirname(parameter_file), f"{basename}_dyn_clamp_{exp_num}.joblib")
    if os.path.exists(_exp_path):
        print(f"loading parameters for {parameter_file} from {_exp_path}")
        _dict = joblib.load(_exp_path)
        print(_dict.keys())
        synapse_params_base = _dict['synapse_params']
        network_params = _dict['network_args']
        network_events = _dict['network_events']
        input_current = _dict['input_current']
        # fall back to the caller's run_time when the file does not record one
        RUNTIME = _dict.get('RUNTIME', run_time)
        if input_current is not None:
            # stored as a plain array in amps (presumably time x neuron); wrap it for brian2
            input_current = TimedArray(input_current*amp, dt=defaultclock.dt)
        return synapse_params_base, network_params, network_events, input_current, RUNTIME
    print(f"preconfigured file not found: {parameter_file} @ {_exp_path}")
    print("Generating the new parameters on the fly")
    _dict = joblib.load(parameter_file)
    print(_dict.keys())
    synapse_params_base = _dict['synapse_params']
    network_params = _dict['network_args']
    synapse_params, network_params, network_events, input_current, RUNTIME = generate_experimental_dict(
        exp_num, copy.deepcopy(synapse_params_base), copy.deepcopy(network_params))
    return synapse_params, network_params, network_events, input_current, RUNTIME
def save_params_to_joblib(network_args, synapse_params, path):
    """Persist network_args and synapse_params together so load_params_from_joblib can read them back.

    takes:
        network_args (dict): network arguments to store under 'network_args'
        synapse_params (dict): synapse parameters to store under 'synapse_params'
        path (str): destination joblib file
    """
    payload = dict(network_args=network_args, synapse_params=synapse_params)
    joblib.dump(payload, path)
def generate_input_current(pattern='step', steps=(0, 20, 40, 60, 80), run_time=60, stim_length=1, episode_length=None, neuron=0, N=1000, dt=1*ms):
    """Generate a TimedArray of injected current for the network (one column per neuron).

    Only the 'step' pattern is supported. The run is split into len(steps) episodes of equal
    length (unless episode_length is given). In episode i, a pulse of amplitude steps[i]
    lasting stim_length seconds is centred in the episode and applied to the chosen neuron(s).

    takes:
        pattern (str): the pattern to generate. Currently only supports 'step'
        steps (sequence): the step amplitudes, in pico amps
        run_time (float): overall runtime of the network in seconds
        stim_length (float): the length of each stimulus in seconds
        episode_length (float or None): the length of each episode in seconds; defaults to
            run_time / len(steps)
        neuron (int or list): the neuron index (or list of indices) to apply the stimulus to
        N (int): number of neurons, i.e. columns of the array
        dt (Quantity): sampling interval of the returned TimedArray (previously ignored; 1 ms was
            always used, which is still the default)
    returns:
        TimedArray: the current in pA units, evaluated as input_current(t, i)
    raises:
        ValueError: if pattern is not supported
    """
    if pattern != 'step':
        raise ValueError(f"unsupported input current pattern: {pattern!r} (only 'step' is implemented)")
    if not isinstance(neuron, list):
        neuron = [neuron]
    if episode_length is None:
        episode_length = run_time / len(steps)
    stim_start = episode_length/2 - stim_length/2
    stim_end = episode_length/2 + stim_length/2
    logger.debug(f"creating a stimuli with episode length {episode_length} with stim start: {stim_start}, stim end: {stim_end}")
    # samples per second at the requested resolution (exactly 1000.0 for the default dt of 1 ms)
    samples_per_second = 1.0 / float(dt / second)
    input_current = np.zeros((int(run_time*samples_per_second), N))
    for i, step in enumerate(steps):
        first = int(i*episode_length*samples_per_second + stim_start*samples_per_second)
        last = int(i*episode_length*samples_per_second + stim_end*samples_per_second)
        for n in neuron:
            input_current[first:last, n] = step
    return TimedArray(input_current*pA, dt=dt)
def generate_experimental_dict(exp_num, synapse_params_base, network_params, neuron_to_replace=0):
"""Generates the experimental dictionary for the network. This is a helper function to make the code more readable.
These in silico experiments are used widely in the paper to test the network under different conditions.
takes:
synapse_params_base (dict): the base synapse parameters to use
network_params (dict): the base network parameters to use
exp_num (int): the experiment number to run
neuron_to_replace (int): the neuron to replace in the input current
returns:
synapse_params_base (dict): the updated synapse parameters
network_params (dict): the updated network parameters
network_events (dict): the events to apply to the network
input_current (TimedArray): the input current to apply to the network
RUNTIME (int): the runtime of the
"""
#default input current is none
input_current = None
#default runtime is 60 seconds
RUNTIME = 60
network_events = {}
#config the experiment dict
if exp_num == 0:
network_events = {}
elif exp_num == 1:
synapse_params_base['taup'] = 5*second
synapse_params_base['taubr'] = 7*second
network_events = {15: 'disconnect_GABA_increase_adapt', 35: 'disconnect_GABA_increase_adapt'}
elif exp_num == 1.5:
RUNTIME = 600
synapse_params_base['taup'] = 50*second
synapse_params_base['taubr'] = 70*second
network_events = {200: 'disconnect_GABA_increase_adapt', 400: 'disconnect_GABA_increase_adapt'}
elif exp_num == 2:
synapse_params_base['taup'] = 5*second
synapse_params_base['taubr'] = 7*second
network_events = {15: 'disconnect_GABA_increase_adapt', 40: 'disconnect_GABA_increase_adapt'}
synapse_params_base['exte_in'] = 0.0*Hz #turn off the EPSPs
elif exp_num == 3:
synapse_params_base['taup'] = 5*second
synapse_params_base['taubr'] = 7*second
input_current = generate_input_current(run_time=RUNTIME, neuron=neuron_to_replace)
elif exp_num == 3.5:
input_current = generate_input_current(run_time=RUNTIME, neuron=neuron_to_replace)
synapse_params_base['taup'] = 9e9*second
synapse_params_base['taubr'] = 9e9*second
network_events = {5:'disconnect_GABA_increase_adapt'}
elif exp_num == 4:
synapse_params_base['taup'] = 5*second
synapse_params_base['taubr'] = 7*second
input_current = generate_input_current(run_time=RUNTIME, neuron=neuron_to_replace)
synapse_params_base['exte_in'] = 0.0*Hz #turn off the EPSPs
elif exp_num == 4.5:
input_current = generate_input_current(run_time=RUNTIME, neuron=neuron_to_replace)
synapse_params_base['exte_in'] = 0.0*Hz #turn off the EPSPs
synapse_params_base['taup'] = 9e9*second
synapse_params_base['taubr'] = 9e9*second
network_events = {10:'disconnect_GABA_increase_adapt'}
elif exp_num == 5:
synapse_params_base['taup'] = 5*second
synapse_params_base['taubr'] = 7*second
network_events = {20: 'double_EPSP', 35: 'triple_EPSP'}
elif exp_num == 6:
synapse_params_base['taup'] = 9e9*second
synapse_params_base['taubr'] = 9e9*second
network_events = {10:'disconnect_GABA_increase_adapt', 20: 'double_EPSP', 35: 'triple_EPSP'}
elif exp_num == 7:
synapse_params_base['taup'] = 5*second
synapse_params_base['taubr'] = 7*second
rates_grid = np.linspace(0, 150, 20)
rates_events = {int(i*10) + 10: f'eval: poisson_exc.rates = {rate}*Hz' for i, rate in enumerate(rates_grid)}
network_events = rates_events
RUNTIME = 200
elif exp_num == 8:
rates_grid = np.linspace(0, 150, 20)
rates_events = {int(i*10) + 10: f'eval: poisson_exc.rates = {rate}*Hz' for i, rate in enumerate(rates_grid)}
network_events = {9:'disconnect_GABA_increase_adapt'}
synapse_params_base['taup'] = 9e9*second
synapse_params_base['taubr'] = 9e9*second
network_events.update(rates_events)
RUNTIME = 200
elif exp_num == 9:
#this one is basically exp_num 7 & 8 stacked together
rates_grid = np.linspace(0, 150, 20)
rates_events = {int(i*10) + 10: f'eval: poisson_exc.rates = {rate}*Hz' for i, rate in enumerate(rates_grid)}
rates_events2 = {int(i*10) + (len(rates_grid) * 10 + 20) : f'eval: poisson_exc.rates = {rate}*Hz' for i, rate in enumerate(rates_grid)}
network_events = {(len(rates_grid) * 10 + 10):'disconnect_GABA_increase_adapt'}
synapse_params_base['taup'] = 9e9*second
synapse_params_base['taubr'] = 9e9*second
rates_events.update(network_events)
rates_events.update(rates_events2)
network_events = rates_events
RUNTIME = 440
return synapse_params_base, network_params, network_events, input_current, RUNTIME
if __name__ == "__main__":
    # Smoke test: run the default network for 60 s and plot its spike raster.
    logger.setLevel(logging.DEBUG)
    # Without any handler, logging's last-resort handler only prints WARNING and
    # above, so the DEBUG level set above would have no visible effect.
    if not logger.hasHandlers():
        logger.addHandler(logging.StreamHandler())
    monitors = cadex_network(run_time=60, network_args={'connectivity': 'random', 'connectivity_params': {}, 'init_run': True})
    eventplot(list(monitors['spikes'].spike_trains().values()), color='k')
    show()