# Description: This file contains the PVN network model used in the CADEX paper.
# This is a template script for the PVN network model; most functions are built to be overridden or extended.
from brian2 import *
import time as pyt
import logging
import numpy as np
import joblib
import connection_matrix_gen as cmg
import os



logger = logging.getLogger(__name__)
# brian2 code generation: compile model code with Cython and give every process its
# own build cache (the directory name contains the PID), presumably so that parallel
# worker processes never share or clobber each other's cached extension modules.
prefs.codegen.target = 'cython'   # weave is not multiprocess-safe!
cache_dir = os.path.expanduser(f'~/.cython/brian-pid-{os.getpid()}')
prefs.codegen.runtime.cython.cache_dir = cache_dir

prefs.codegen.runtime.cython.multiprocess_safe = True
# Silence brian2 code-generation log output and the "resolution conflict" warning that
# brian2 emits when a model variable has the same name as an external variable. For
# example, `taue` is both a per-neuron parameter in EQS and a module global written
# from SYNAPSE_PARAMS by cadex_network.
BrianLogger.suppress_hierarchy('brian2.codegen')
BrianLogger.suppress_hierarchy('brian2.groups.group.Group.resolve.resolution_conflict')

#GLOBAL DEFAULTS
# Default network arguments. cadex_network merges user-supplied network_args over these.
#   N            : size of the single NeuronGroup (brian2 casts it to int). The CRH cells are
#                  the first N_CRH neurons, followed by N_GABA GABA cells.
#   p_ei / p_ie  : connection probabilities for CRH->GABA (EI) and GABA->CRH (IE) synapses
#   connectivity : one of
#                    - 'random' or None
#                    - 'threeway'
#                    - the name of a function in connection_matrix_gen, called with
#                      connectivity_params and returning the (EI, IE) index arrays
#                    - a dict holding 'EI_connect'/'IE_connect' index arrays
#   init_run     : run a 5 s warm-up with clamp_lim = 1 before the main run
#   ext_to_gaba  : also drive the GABA cells with their own external Poisson inputs
#   ECRH_params / ICRH_params : None (one-to-one), kwargs for Synapses.connect, or an
#                               (n, 2) array of (pre, post) indices
#   EI_params, IE_params, record_poisson, record_all : not read by cadex_network
# NOTE(review): 'p_ee' (required by 'threeway' connectivity) has no default here.
NETWORK_ARGS = {'N':1e3, 'p_ei':0.02, 'p_ie':0.02, 
                'connectivity': 'random', 'connectivity_params':{},
                'init_run': True, 'ext_to_gaba': False, 'N_CRH': 500, 'N_GABA': 500, 'ECRH_params': None, 'ICRH_params': None, 'EI_params':None, 'IE_params': None,
                'record_poisson': False, 'record_all': True} #default network args
# Sampled CRH cell parameters. The file is loaded at import time, relative to the current
# working directory. joblib uses pickle, so only point this at trusted files.
CRH_PARAM_FILE = "CRH_FILT_sampled_cell.joblib"
_CRH_DEFAULT_PARAMS = joblib.load(CRH_PARAM_FILE)
# Per-cell arrays with brian2 units:
#   gL   = C / taum
#   Vcut = VT + 5*DeltaT (the spike threshold used by the NeuronGroup in cadex_network)
# 'pr' (release probability, starts at 1) and 'bd' (extra adaptation, starts at 0) are
# hard-coded to 500 entries. This assumes the file holds 500 sampled cells (TODO confirm).
CRH_PARAMS = {
            'tauw': _CRH_DEFAULT_PARAMS['tauw'].to_numpy()*ms, 
            'a': _CRH_DEFAULT_PARAMS['a'].to_numpy()*nS, 
            'b': _CRH_DEFAULT_PARAMS['b'].to_numpy()*nS,
            'C': _CRH_DEFAULT_PARAMS['Cm'].to_numpy() * pF, 
            'taum': _CRH_DEFAULT_PARAMS['taum'].to_numpy()*ms, 
            'gL':(_CRH_DEFAULT_PARAMS['Cm'].to_numpy() * pF) / (_CRH_DEFAULT_PARAMS['taum'].to_numpy()*ms),
            'DeltaT': _CRH_DEFAULT_PARAMS['DeltaT'].to_numpy() *mV, 
            'DeltaA': _CRH_DEFAULT_PARAMS['DeltaA'].to_numpy() *mV, 
            'pr': np.full(500, 1),
            'Va': _CRH_DEFAULT_PARAMS['Va'].to_numpy() * mV, 
            'EL': _CRH_DEFAULT_PARAMS['EL'].to_numpy() *mV,
            'Ea': _CRH_DEFAULT_PARAMS['Ea'].to_numpy() *mV,
            'VT': _CRH_DEFAULT_PARAMS['VT'].to_numpy() *mV, 
            'VR': _CRH_DEFAULT_PARAMS['VR'].to_numpy() *mV, 
            'Vcut': (_CRH_DEFAULT_PARAMS['VT'].to_numpy() + 5 * _CRH_DEFAULT_PARAMS['DeltaT'].to_numpy()) *mV,
            'bd': np.full(500, 0) * nS,
        }

 
# Default synapse and external-input parameters. cadex_network merges user-supplied
# synapse_params over these and then copies every entry into the module globals.
# That is why brian2 model strings and the function body can use the bare names.
SYNAPSE_PARAMS = dict(
    Ee = 0 * mvolt,  # excitatory reversal potential (used with ge and CRH in EQS)
    Ei = -80 * mvolt,  # inhibitory reversal potential (used with gi in EQS)
    taue = 8.901726144 * msecond,  # fallback for tauer/taued when those are not supplied
    taucrh = 350.5817492 * msecond,  # time constant of ycrh/CRH; also the GABA tauer/taued unless SYNAPSE_TYPE == "mono_exp"
    taun = 1 * msecond,  # not referenced by the visible code in this file
    taui = 14.70811669* msecond,  # time constant of ygi/gi
    wcrh = 0.016626479 * nS,  # EI (CRH->GABA) increment per presynaptic spike
    we = 6.14418288 * nS,  # ECRH (external excitation -> CRH) weight
    weinh = 6.14418288 * nS,  # EIEXT (external excitation -> GABA) weight
    wi = 10.50052728 * nS,  # IE (GABA->CRH) weight
    wexti = 0.922915983 * nS,  # ICRH (external inhibition -> CRH, and -> GABA when ext_to_gaba) weight
    taup = 40* second,  # time constant with which pr relaxes back to 1 (dpr/dt in EQS)
    taubr = 80 * second,  # decay time constant of the extra adaptation bd (dbd/dt in EQS)
    exti_in = 8.869675194 * Hz,  #todo move to network args; rate of each external inhibitory Poisson source
    exte_in = 73.65083711 * Hz,  #todo move to network args; rate of each external excitatory Poisson source
    Nexc = 500, #todo move to network args; number of external excitatory Poisson sources
    Ninh = 500, #todo move to network args; number of external inhibitory Poisson sources
    p_exti = 1,  # NOTE(review): EQS also declares a per-neuron p_exti, which presumably shadows this value in synapse code
    input_mult = 2.246328156,  # excitatory-rate multiplier applied by the double_EPSP / triple_EPSP events
    b_change = 5,  # adaptation events set bd = b * b_change
    tauw_change = 1,  # not referenced by the visible code in this file
    pr_change_max = 0.1,  # pr value (or the upper bound of a random pr) used by the GABA-disconnect events
    SYNAPSE_TYPE = "inst"  # "mono_exp": both groups use tauer/taued; otherwise GABA uses taucrh and CRH uses taue
)

# Synapse model tables. Each entry maps a synapse role to brian2 code:
#   {"model": synaptic equations, "pre": on_pre code, "post": on_post code}
# Roles:
#   EI    = CRH -> GABA
#   IE    = GABA -> CRH
#   ECRH  = external excitation -> CRH
#   ICRH  = external inhibition -> CRH (also reused for external inhibition -> GABA)
#   EIEXT = external excitation -> GABA
#   EE    = CRH1 -> CRH2 ("threeway" connectivity)
# Each increment is clipped to [0, clamp_lim] nS. Unsuffixed names (ge, gi, y, ygi, ycrh,
# CRH, pr, p_e, p_exti, clamp_lim) are presumably resolved by brian2 to the postsynaptic
# neuron's EQS variables. Weights (wcrh, wi, we, weinh, wexti) come from SYNAPSE_PARAMS.
# cadex_network falls back to DEFAULT_SYNAPSE_MODELS (via fetch_syn_models) when
# network_args has no entry for a role. The other tables are not referenced directly;
# presumably callers pass their entries in through network_args.
# Default table:
#   - EI increments the CRH driver ycrh.
#   - IE and ICRH are gated by rand() < pr.
#   - ECRH, EIEXT and EE are gated by rand() < p_e.
# NOTE(review): EE_model uses we_e, which has no default in SYNAPSE_PARAMS. It must be
# supplied via synapse_params.
DEFAULT_SYNAPSE_MODELS = { 
    "EI_model": {"pre": '''ycrh=clip((ycrh + wcrh), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "IE_model": {"pre": '''gi=clip((wi*int(rand()<pr)) + gi, 0*nS,  clamp_lim*nS)''', "post":"", "model":""},
    "ECRH_model": {"pre": '''ge=clip((ge + we*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "EIEXT_model": {"pre": '''ge=clip((ge + weinh*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "ICRH_model": {"pre": '''gi=clip((wexti*int(rand()<pr)) + gi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    #this is a special case where we have fixed spike injection
    # NOTE(review): cadex_network never looks this entry up; its fixed-spike-injection
    # path uses the ECRH_model entry instead.
    "single_neuron": {"pre": "", "post":"", "model":"v_syn = v_pre + (randn() * 5): volt"},
    #special case for E->E connections
    "EE_model": {"pre": '''ge=clip((ge + we_e*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    }

# Same as DEFAULT_SYNAPSE_MODELS, except that EI increments the CRH conductance directly
# (no ycrh rise stage).
ALL_FAST_SYNAPSE_MODELS = { 
    "EI_model": {"pre": '''CRH=clip((CRH+ wcrh), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "IE_model": {"pre": '''gi=clip((wi*int(rand()<pr)) + gi, 0*nS,  clamp_lim*nS)''', "post":"", "model":""},
    "ECRH_model": {"pre": '''ge=clip((ge + we*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "EIEXT_model": {"pre": '''ge=clip((ge + weinh*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "ICRH_model": {"pre": '''gi=clip((wexti*int(rand()<pr)) + gi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    #this is a special case where we have fixed spike injection
    "single_neuron": {"pre": "", "post":"", "model":"v_syn = v_pre + (randn() * 5): volt"},
    #special case for E->E connections
    "EE_model": {"pre": '''ge=clip((ge + we_e*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    }

# Differences from the default table:
#   - ECRH increments the excitatory driver y (which feeds ge in EQS) instead of ge.
#   - ICRH is gated by p_exti instead of pr.
#   - There are no EIEXT or EE entries.
MONO_EXP_SYNAPSE_MODELS = { 
    "EI_model": {"pre": '''ycrh=clip((ycrh + wcrh), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "IE_model": {"pre": '''gi=clip((wi*int(rand()<pr)) + gi, 0*nS,  clamp_lim*nS)''', "post":"", "model":""},
    "ECRH_model": {"pre": '''y=clip((y + we*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "ICRH_model": {"pre": '''gi=clip((wexti*int(rand()<p_exti)) + gi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "single_neuron": {"pre": "", "post":"", "model":"v_syn = v_pre + (randn() * 5): volt"}}

# EI increments CRH directly. IE and ICRH increment the inhibitory driver ygi (which feeds
# gi in EQS), so inhibition gains a rise phase. ICRH is gated by p_exti.
SLOW_GABA_SYNAPSE_MODELS = { 
    "EI_model": {"pre": '''CRH=clip((CRH+ wcrh), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "IE_model": {"pre": '''ygi=clip((wi*int(rand()<pr)) + ygi, 0*nS,  clamp_lim*nS)''', "post":"", "model":""},
    "ECRH_model": {"pre": '''ge=clip((ge + we*int(rand()<p_e)), 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "ICRH_model": {"pre": '''ygi=clip((wexti*int(rand()<p_exti)) + ygi, 0*nS, clamp_lim*nS)''', "post":"", "model":""},
    "single_neuron": {"pre": "", "post":"", "model":"v_syn = v_pre + (randn() * 5): volt"}}


# Depressing ECRH synapse. Each synapse carries a resource variable dep that recovers to 1
# with time constant taud. A presynaptic spike adds we*dep to y and then scales dep by (1 - mu).
# NOTE(review): taud and mu have no defaults in SYNAPSE_PARAMS and must be supplied via
# synapse_params. cadex_network's monitor_synapse option records this dep variable.
_syn_dep = '''ddep/dt =(1-dep)/taud : 1 (clock-driven)
            '''
_syn_dep_pre = '''
                y=clip((y + (we*dep)*int(rand()<p_e)), 0*nS, clamp_lim*nS);
                dep -= mu*dep;
                '''
EXC_DEPRESSION_MODEL = {
    "ECRH_model": {"model": _syn_dep, "pre": _syn_dep_pre, "post": ""},
}

# Neuron model shared by the CRH and GABA cells. There is one NeuronGroup, and parameters
# are set per subgroup in cadex_network.
# - v: exponential integrate-and-fire membrane with conductance-based currents (ge, gi, CRH),
#   adaptation w and an external current I. dv/dt is multiplied by dyn_neur and
#   int(not_refractory), so v is frozen while refractory or when dyn_neur == 0.
# - w: adaptation conductance relaxing to a sigmoid of v (Va, DeltaA), with reversal Ea.
#   The reset in cadex_network adds br = b + bd on every spike.
# - ge, gi and CRH are two-stage conductances driven by y, ygi and ycrh respectively. The
#   synapse model tables increment either the driver or the conductance itself.
# - pr: release probability read by the pr-gated synapse models. It relaxes logistically
#   back towards 1 with time constant taup.
# - bd: extra adaptation set by network events; it decays with time constant taubr.
# - Ee, Ei, taui, taucrh, taup and taubr are external names (module globals copied from
#   SYNAPSE_PARAMS) unless cadex_network adds Ee/Ei equations.
# - I and d_I (and optionally I_e, I_i, I_CRH, Ei, Ee) are appended by cadex_network.
# NOTE(review): p_w and bool_allow_spike are never used by the threshold or reset in this
# file. The per-neuron taue is never set by cadex_network (the module global taue is used instead).
EQS = Equations('''
    #membrane equations
    dv/dt = ( gL*(EL-v) + gL*DeltaT*exp((v - VT)/DeltaT) + w*(Ea - v) + ge * (Ee-v) + gi * (Ei-v) + CRH * (Ee-v) + I ) * (1./C) * dyn_neur *int(not_refractory) : volt 
    dw/dt = ( a/(1 + exp((Va - v)/DeltaA)) - w ) / tauw : siemens
                    
    #synapse equations, currently in the form of a conductance based synapse
    dge/dt = (y-ge)/taued : siemens
    dy/dt = -y/tauer : siemens
    dygi/dt = -ygi/taui : siemens
    dgi/dt = (ygi-gi)/taui : siemens
    dycrh/dt = -ycrh/taucrh : siemens
    dCRH/dt = (ycrh-CRH)/taucrh : siemens

    #parameters for flipping GABA pr in a neuron specific manner
    dpr/dt = (1./taup) * pr * (1 - (pr/1)): 1 #probability of release, decays back to 1
    p_w = clip(pr, 0.01,1): 1 #not used?
    p_e : 1 #excitatory probability
    p_exti : 1 #external inhibitory probability
    br = b + bd : siemens #adaptation incremented per spike (br) is the sum of the default (b) and the added adaptation (bd)
    dbd/dt = -bd*(1./taubr): siemens
                    
    taue: second #excitatory decay
    tauer : second #excitatory rise
    taued : second #excitatory decay
    refrac : second #refractory period
    clamp_lim : 1 #Limit imposed on synaptic conductance
                    
    #neuron parameters
    Vcut: volt (constant)
    a : siemens
    b : siemens
    Va : volt (constant)
    Ea : volt (constant)
    DeltaT : volt (constant)
    DeltaA : volt (constant)
    VT : volt
    VR : volt
    tauw : second
    C : farad (constant)
    taum : second (constant)
    gL : siemens (constant)
    EL : volt
    bool_allow_spike : 1
    dyn_neur : 1 (constant) #this is a flag to indicate if the neuron is dynamic clamped, really should not be changed during a run
    ''')



# GABA cells use one scalar value per parameter, taken from the first sampled CRH cell.
GABA_PARAMS = {key:item[0] for key,item in CRH_PARAMS.items()} #in our case just take the first value of the CRH params
#except for EL, which is different for GABA
GABA_PARAMS['EL'] =-68.38802621 *mV


#experimental strings
# Labels for the experiment protocols, keyed by experiment number. They are not referenced
# by the visible code in this file.
EXP_STRS ={ 0: "EXP_0", #default, no input current, simulated EPSPs, no switch
           
            1: "EXP_1", #default, no input current, simulated EPSPs, switch @ 15, 35
            1.5: "EXP_1_EXTENDED", #default, no input current, simulated EPSPs, switch @ 150, 350
            2: "EXP_2", #default, no input current, in vitro EPSPs, switch @ 15, 35
            3: "EXP_3",  #default, with current pulses, simulated EPSPs
            3.5: "EXP_3.5", #default, with current pulses, simulated EPSPs, in tonic mode
            4: "EXP_4", #default, with current pulses, in vitro EPSPs
            4.5: "EXP_4.5", #default, with current pulses, in vitro EPSPs, in tonic mode
            5: "EXP_5", #default, WITH increasing EPSPs, simulated EPSPs, in burst mode
            6: "EXP_6", #default, WITH increasing EPSPs, simulated EPSPs, in tonic mode
            7: "EXP_7", #default, WITH increasing EPSPs, simulated EPSPs across a grid
            8: "EXP_8" #default, WITH increasing EPSPs, simulated EPSPs across a grid
            }



# Default state variables recorded by cadex_network's StateMonitor. d_I and I are not in
# EQS itself; cadex_network appends them before building the NeuronGroup.
VARS_TO_REC = ['v', 'br', 'tauw', 'pr', 'p_e', 'd_I', 'ge', 'gi', "I", "CRH"]

#make a functional version, with a nightmare level of parameters, but the class version is more readable
def cadex_network(run_time=60, network_args=None,
                  CRH_param_file=None, CRH_params=None, GABA_param_file=None, GABA_params=None,
                  synapse_params=None, network_events=None, medianem=False,
                  single_neuron=False, input_current=None, Ei_array=None, Ee_array=None,
                  vars_to_rec=VARS_TO_REC, monitor_synapse=False,
                  verbose='text', rseed=50):
    """Build and run the CRH/GABA PVN network, optionally applying timed manipulations.

    The network is one NeuronGroup ``P`` of ``network_args['N']`` neurons:
      - the CRH subgroup is the first ``N_CRH`` neurons;
      - the GABA subgroup is the next ``N_GABA`` neurons.
    The subgroups are coupled by CRH->GABA (EI) and GABA->CRH (IE) synapses and driven by
    external Poisson (or fixed spike) inputs.

    Takes:
    run_time (float): simulated time in seconds, excluding the optional 5 s ``init_run`` warm-up.
    network_args (dict or None): overrides for NETWORK_ARGS (connectivity, sizes, custom
        synapse model entries, ...).
    CRH_param_file / GABA_param_file (str or None): joblib file of sampled cell parameters,
        see fetch_neuron_parameters.
    CRH_params / GABA_params (dict or None): per-parameter overrides of the neuron parameters.
    synapse_params (dict or None): overrides for SYNAPSE_PARAMS. The merged values are also
        written into this module's globals, so they persist after the call.
    network_events (dict or None): manipulations applied during the run, in the format
        {<time in seconds>: <event string>}. Events are applied in ascending time order, and
        every time must lie within [0, run_time]. Possible <event strings>:
            double_EPSP => sets the excitatory input rate to input_mult * exte_in.
            triple_EPSP => sets the excitatory input rate to 3 * input_mult * exte_in.
            disconnect_GABA_increase_adapt => sets pr of every neuron to pr_change_max and the
                CRH extra adaptation bd to b * b_change.
            disconnect_GABA => sets pr of every neuron to uniform random values in
                [0, pr_change_max), leaving adaptation unchanged.
            increase_adapt => sets the CRH extra adaptation bd to b * b_change, leaving
                inhibition unchanged.
            eval: <python code> => executes everything after the first ": " with exec().
                Only use trusted input.
    medianem (bool): add a single median-eminence readout unit driven by every CRH spike.
        Requires ``taumer`` and ``taumed`` in synapse_params.
    single_neuron (bool): currently unused.
    input_current (TimedArray or None): if given, I = input_current(t, i); otherwise I = 0.
    Ei_array / Ee_array (TimedArray or None): if given, Ei / Ee = <array>(t, i). A TimedArray
        passed as synapse_params['Ee'] is used as Ee(t).
    vars_to_rec (list of str): state variables recorded by the StateMonitor.
    monitor_synapse (bool): also record I_e, I_i, I_CRH, and the variable ``dep`` of the ECRH
        synapses. This requires an ECRH model that defines dep, e.g. EXC_DEPRESSION_MODEL.
    verbose: passed as ``report`` to brian2 run().
    rseed (int): seed passed to brian2 seed().

    returns:
    dict with the keys
        'spikes': SpikeMonitor over all neurons.
        'states': StateMonitor of vars_to_rec for the first min(10, N) neurons and the first
            min(5, N_GABA) GABA neurons.
        'EI_connect', 'IE_connect', 'poisson_connect': 2 x n_synapses arrays of
            (pre, post) indices.
        'poisson_spikes': SpikeMonitor on the external excitatory input.
        'ME': StateMonitor of the median-eminence unit (only if medianem).
        'syn_mon': StateMonitor of dep on the ECRH synapses (only if monitor_synapse).
    """
    # brian2 resolves external names in model strings (input_current, Ei_array, constants...)
    # and collects the objects to simulate from the frame that calls run(). For that reason:
    #   - run() is only called from this function;
    #   - every Brian object stays in a local variable here, created in a fixed order;
    #   - local names such as `input` are kept, because eval events may reference them.
    seed(rseed)
    start_time = pyt.time()
    start_scope()

    # None means "no overrides" (replaces the former shared mutable {} defaults)
    if CRH_params is None:
        CRH_params = {}
    if GABA_params is None:
        GABA_params = {}
    if synapse_params is None:
        synapse_params = {}
    if network_events is None:
        network_events = {}

    #% network parameters: user values override the defaults; the global dict is left untouched
    network_args = {**NETWORK_ARGS, **(network_args or {})}

    #% neuron parameters, drawn from the sampled cell prior unless overridden
    CRH_params = fetch_neuron_parameters(CRH_param_file, CRH_params, cell="CRH")
    GABA_params = fetch_neuron_parameters(GABA_param_file, GABA_params, cell="GABA")

    #% synapse parameters: user values override the defaults
    synapse_params = {**SYNAPSE_PARAMS, **synapse_params}
    # taue is split into tauer (rise) and taued (decay). Explicitly supplied tauer/taued are
    # kept; whichever one is missing falls back to taue.
    if 'tauer' in synapse_params and 'taued' in synapse_params:
        print("using split taue")
    else:
        print("taue found, splitting into tauer and taued")
        synapse_params.setdefault('tauer', synapse_params['taue'])
        synapse_params.setdefault('taued', synapse_params['taue'])

    # Publish every synapse parameter as a module global. The code below and the brian2
    # model strings refer to them by bare name (Nexc, exte_in, wexti, taucrh, ...).
    # NOTE(review): this mutates module state, which persists across calls.
    globals().update(synapse_params)

    eqs = EQS
    eqs += Equations('d_I = (ge * (Ee-v) + gi * (Ei-v) + w*(Ea - v)) + CRH * (Ee-v): amp')
    if monitor_synapse:  # expose the individual synaptic currents for recording
        eqs += Equations('I_e = ge * (Ee-v): amp')
        eqs += Equations('I_i = gi * (Ei-v): amp')
        eqs += Equations('I_CRH = CRH * (Ee-v): amp')

    if input_current is not None:
        eqs += Equations('I = input_current(t, i) : ampere')
    else:
        eqs += Equations('I = 0 * amp : ampere')

    # time-varying reversal potentials (e.g. for EGABA experiments)
    if isinstance(Ei_array, TimedArray):
        eqs += Equations('Ei = Ei_array(t,i) : volt')
    if isinstance(Ee_array, TimedArray):
        eqs += Equations('Ee = Ee_array(t,i) : volt')
    elif isinstance(synapse_params['Ee'], TimedArray):
        # model strings cannot index a dict, so bind the array to a plain local name
        Ee_timed = synapse_params['Ee']
        eqs += Equations('Ee = Ee_timed(t) : volt')

    # build network
    P = NeuronGroup(network_args['N'], model=eqs, threshold='v>Vcut', reset='v=VR; w+=br',
                    refractory='refrac', method='heun')
    P.pr = 1
    P.p_e = 1
    P.p_exti = 1
    P.refrac = 1 * ms
    # start every neuron from the first sampled CRH cell; the CRH and GABA subgroups are refined below
    for name in ('C', 'taum', 'gL', 'DeltaT', 'DeltaA', 'Va', 'Ea', 'VT', 'VR', 'tauw', 'a', 'b'):
        setattr(P, name, _first_element(CRH_params[name]))
    P.EL = -68.38802621*mV
    P.dyn_neur = 1

    #% neuron group construction: CRH = first N_CRH neurons, GABA = the following N_GABA
    N_CRH = network_args['N_CRH']
    CRH = P[:N_CRH]
    GABA = P[N_CRH:N_CRH + network_args['N_GABA']]

    # per-cell CRH parameters (array values are trimmed to the subgroup size, scalars broadcast)
    if len(CRH_params) > 1:
        states = CRH.get_states(read_only_variables=False)
        for key, val in CRH_params.items():
            if key in states:
                print(f"setting {key} to {np.nanmean(val)}; for CRH, which is {np.nanmean(states[key])}")
                states[key] = _fit_to_group(val, len(CRH))
        CRH.set_states(states, units=True)

    # rise/decay of the excitatory conductance, per group
    if synapse_params['SYNAPSE_TYPE'] == "mono_exp":
        GABA.tauer = synapse_params['tauer']
        GABA.taued = synapse_params['taued']
        CRH.tauer = synapse_params['tauer']
        CRH.taued = synapse_params['taued']
    else:
        GABA.tauer = synapse_params['taucrh']
        GABA.taued = synapse_params['taucrh']
        CRH.tauer = synapse_params['taue']
        CRH.taued = synapse_params['taue']

    # GABA neuron parameters (array values are trimmed to the subgroup size, scalars broadcast)
    if len(GABA_params) > 1:
        states = GABA.get_states(read_only_variables=False)
        for key, val in GABA_params.items():
            if key in states:
                print(f"setting {key} to {np.nanmean(val)}; for GABA, which is {np.nanmean(states[key])}")
                states[key] = _fit_to_group(val, len(GABA))
        GABA.set_states(states, units=True)
    #% end neuron group construction

    #% synapse construction
    EI_model, EI_model_pre, EI_model_post = fetch_syn_models(network_args, model_key='EI_model')
    IE_model, IE_model_pre, IE_model_post = fetch_syn_models(network_args, model_key='IE_model')
    connectivity = network_args['connectivity']
    if isinstance(connectivity, str) and connectivity == 'threeway':
        print("using three way connectivity")
        # Split CRH into two halves and wire them one-way:
        #   CRH1 -> CRH2 (EE), CRH2 -> GABA (EI), GABA -> CRH1 (IE)
        CRH1 = CRH[:int(len(CRH)/2)]
        CRH2 = CRH[int(len(CRH)/2):]
        EE_model, EE_model_pre, EE_model_post = fetch_syn_models(network_args, model_key='EE_model')
        EI = Synapses(CRH2, GABA, model=EI_model, on_pre=EI_model_pre, on_post=EI_model_post)
        IE = Synapses(GABA, CRH1, model=IE_model, on_pre=IE_model_pre, on_post=IE_model_post)
        EE = Synapses(CRH1, CRH2, model=EE_model, on_pre=EE_model_pre, on_post=EE_model_post)
        EI.connect(True, p=network_args['p_ei'])
        IE.connect(True, p=network_args['p_ie'])
        EE.connect(True, p=network_args['p_ee'])  # p_ee has no default in NETWORK_ARGS
    else:
        EI = Synapses(CRH, GABA, model=EI_model, on_pre=EI_model_pre, on_post=EI_model_post)
        IE = Synapses(GABA, CRH, model=IE_model, on_pre=IE_model_pre, on_post=IE_model_post)
        if connectivity is None or (isinstance(connectivity, str) and connectivity == 'random'):
            print("using random connectivity")
            EI.connect(True, p=network_args['p_ei'])
            IE.connect(True, p=network_args['p_ie'])
        elif isinstance(connectivity, str):
            # name of a generator in connection_matrix_gen returning (EI, IE) (pre, post) index arrays
            conn_fn = getattr(cmg, connectivity)
            print(f"using {connectivity} connectivity")
            print(f"with params {network_args['connectivity_params']}")
            EI_connectivity, IE_connectivity = conn_fn(**network_args['connectivity_params'])
            EI.connect(i=EI_connectivity[:, 0], j=EI_connectivity[:, 1])
            IE.connect(i=IE_connectivity[:, 0], j=IE_connectivity[:, 1])
        elif isinstance(connectivity, dict):
            print("using custom connectivity; assuming the user has passed in the connectivity matrices")
            EI_connectivity = connectivity['EI_connect']
            IE_connectivity = connectivity['IE_connect']
            EI.connect(i=EI_connectivity[:, 0], j=EI_connectivity[:, 1])
            IE.connect(i=IE_connectivity[:, 0], j=IE_connectivity[:, 1])
        else:
            print("connectivity not understood, using random")
            EI.connect(True, p=network_args['p_ei'])
            IE.connect(True, p=network_args['p_ie'])

    if medianem:
        # single median-eminence readout unit driven by every CRH spike
        # (taumer/taumed must be supplied via synapse_params)
        ME = NeuronGroup(1, model='''
        dy/dt = -y*(1./taumer) : 1
        dcrh/dt = (y-crh)/taumed : 1
        ''', method="euler")
        crhrelease = Synapses(CRH, ME, on_pre='y+=1')
        crhrelease.connect()
        ME_mon = StateMonitor(ME, ['y', 'crh'], record=True)
    #% end synapse construction

    #% external input to CRH
    ecrh_spec = network_args.get('ECRH_model')
    if isinstance(ecrh_spec, dict) and 'spike_idx' in ecrh_spec:
        # fixed spike injection: replay the given spikes through the default ECRH synapse model
        print("using fixed spike injection")
        spike_idx = ecrh_spec['spike_idx']
        spike_t = ecrh_spec['spike_t']
        poisson_exc = SpikeGeneratorGroup(Nexc, spike_idx, spike_t*second)
        ECRH_model, ECRH_model_pre, ECRH_model_post = fetch_syn_models({}, model_key='ECRH_model')
    else:
        ECRH_model, ECRH_model_pre, ECRH_model_post = fetch_syn_models(network_args, model_key='ECRH_model')
        poisson_exc = PoissonGroup(Nexc, exte_in)
    ICRH_model, ICRH_model_pre, ICRH_model_post = fetch_syn_models(network_args, model_key='ICRH_model')
    # `input` shadows the builtin; the name is kept for compatibility with eval events
    input = Synapses(poisson_exc, CRH, model=ECRH_model, on_pre=ECRH_model_pre, on_post=ECRH_model_post)
    input.connect(**_external_connect_kwargs(network_args['ECRH_params'], 'ECRH'))

    #Inhibitory input
    poisson_inh = PoissonGroup(Ninh, exti_in)
    input2 = Synapses(poisson_inh, CRH, model=ICRH_model, on_pre=ICRH_model_pre, on_post=ICRH_model_post)
    input2.connect(**_external_connect_kwargs(network_args['ICRH_params'], 'ICRH'))

    # optionally give the GABA cells their own external excitation and inhibition
    if network_args['ext_to_gaba']:
        print("ext to gaba")
        EIEXT_model, EIEXT_model_pre, EIEXT_model_post = fetch_syn_models(network_args, model_key='EIEXT_model')
        poisson_exc2 = PoissonGroup(Nexc, exte_in)
        poisson_inh2 = PoissonGroup(Ninh, exti_in)
        input3 = Synapses(poisson_inh2, GABA, model=ICRH_model, on_pre=ICRH_model_pre, on_post=ICRH_model_post)
        input3.connect(i='j')
        input4 = Synapses(poisson_exc2, GABA, model=EIEXT_model, on_pre=EIEXT_model_pre, on_post=EIEXT_model_post)
        input4.connect(**_external_connect_kwargs(network_args['ECRH_params'], 'EIEXT'))
    #% end external input

    print(f"wexti: {wexti/nS}")
    # init vars and monitors
    P.v = -80*mV
    P.ge = 0
    P.gi = 0
    P.clamp_lim = 999
    M = SpikeMonitor(P)
    _p_spk = SpikeMonitor(poisson_exc)

    if monitor_synapse:
        # requires an ECRH synapse model that defines `dep` (e.g. EXC_DEPRESSION_MODEL)
        _syn_mon = StateMonitor(input, ['dep'], record=[0, 2, 3, 4, 5])
        vars_to_rec = [*vars_to_rec, 'I_e', 'I_i', 'I_CRH']

    # record the first min(10, N) neurons and the first min(5, N_GABA) GABA neurons
    V = StateMonitor(P, vars_to_rec, [*np.arange(min(10, network_args['N'])),
                                      *np.arange(N_CRH, N_CRH + min(5, network_args['N_GABA']))])

    # run simulation
    print("=== Net Sim Start ===")

    if network_args.get('init_run'):
        # warm-up with limited synaptic conductance to let the network reach a stable state
        P.clamp_lim = 1
        P.bool_allow_spike = 0
        run(5*second, report=verbose)
        P.clamp_lim = 99
        P.bool_allow_spike = 1

    if not network_events:
        run(run_time*second, report=verbose)
    else:
        # Apply the events in ascending time order. Each segment runs until the next event
        # (the last one until run_time).
        sorted_events = sorted(network_events.items())
        run_time = np.diff(np.hstack([0, *[ev_time for ev_time, _ in sorted_events], run_time]))
        if np.any(run_time < 0):
            raise ValueError(f"network_events times must lie within [0, run_time]; got {list(network_events)}")
        events = np.hstack(['none', *[ev for _, ev in sorted_events]])
        print(events)
        print(run_time)
        for time, event in zip(run_time, events):
            if event == 'double_EPSP':
                print("Increasing EPSP")
                poisson_exc.rates = input_mult*exte_in
            elif event == 'triple_EPSP':
                print("Increasing EPSP")
                poisson_exc.rates = 3*input_mult*exte_in
            elif event == 'disconnect_GABA_increase_adapt':
                print("Removing GABA and increasing Adaptation")
                P.pr = synapse_params['pr_change_max']
                CRH.bd = _fit_to_group(CRH_params['b'], len(CRH)) * synapse_params['b_change']
            elif event == 'disconnect_GABA':
                print("Removing GABA ONLY")
                P.pr = np.random.rand(len(P)) * synapse_params['pr_change_max']
            elif event == 'increase_adapt':
                print("increasing Adaptation")
                CRH.bd = _fit_to_group(CRH_params['b'], len(CRH)) * synapse_params['b_change']
            elif "eval" in event:
                print("Evaluating")
                # NOTE: executes arbitrary code taken from network_events; only use trusted input.
                # Split once so that code containing ": " is not truncated.
                exec(event.split(": ", 1)[1])
            run(time*second, report=verbose)
    print(f"simulation took {(pyt.time()-start_time)/60}")

    #out dict
    ret_dict = {'spikes': M, 'states': V, 'EI_connect': np.vstack((EI.i, EI.j)),
                'IE_connect': np.vstack((IE.i, IE.j)), 'poisson_spikes': _p_spk,
                'poisson_connect': np.vstack((input.i, input.j))}
    if medianem:
        ret_dict['ME'] = ME_mon
    if monitor_synapse:
        ret_dict['syn_mon'] = _syn_mon

    return ret_dict


def _first_element(val):
    """Return val[0] for array-like parameters and val unchanged for scalars."""
    return val[0] if np.ndim(val) > 0 else val


def _fit_to_group(val, n):
    """Trim an array-like parameter to its first n entries; scalars pass through unchanged."""
    return val[:n] if np.ndim(val) > 0 else val


def _external_connect_kwargs(conn_params, label):
    """Translate an ECRH/ICRH-style connectivity spec into Synapses.connect keyword arguments.

    Accepted specs:
      - None: one-to-one
      - dict: connect kwargs, passed through unchanged
      - ndarray: an (n, 2) array of (pre, post) indices
    Anything else falls back to one-to-one. connect() itself is called by the caller, so
    brian2 resolves names from the caller's frame.
    """
    if conn_params is None:
        return {'i': 'j'}
    if isinstance(conn_params, dict):
        print(f"Found {label} connectivity params {conn_params}")
        return dict(conn_params)
    if isinstance(conn_params, np.ndarray):
        return {'i': conn_params[:, 0], 'j': conn_params[:, 1]}
    print(f"{label} params not understood, connecting 1 to 1")
    return {'i': 'j'}


def fetch_syn_models(params, model_key='ECRH_model', defaults=None):
    """Return the (model, pre, post) brian2 code strings for one synapse role.

    takes:
        params (dict): typically network_args. If it contains model_key, the spec
            params[model_key] is used. Missing 'model', 'pre' or 'post' entries default to ''
            (so e.g. a spec with only 'pre' is allowed).
        model_key (str): the synapse role to look up. Defaults to ECRH_model.
        defaults (dict or None): fallback model table used when params has no entry.
            None means DEFAULT_SYNAPSE_MODELS; pass e.g. MONO_EXP_SYNAPSE_MODELS to pick
            another built-in set.
    returns:
        model (str): the synaptic model equations
        pre (str): the on_pre code
        post (str): the on_post code
    raises:
        KeyError: if model_key is in neither params nor the defaults table.
    """
    if model_key in params:
        print(f"Found {model_key} in params")
        spec = params[model_key]
    else:
        # else just use the defaults
        print(f"Using default {model_key}")
        table = DEFAULT_SYNAPSE_MODELS if defaults is None else defaults
        spec = table[model_key]
    return spec.get('model', ''), spec.get('pre', ''), spec.get('post', '')

def fetch_neuron_parameters(param_file, param_dict, cell="CRH"):
    """Assemble the neuron parameter dictionary (values carry brian2 units) for one cell type.

    Precedence, from lowest to highest:
      1. module defaults (CRH_PARAMS when cell == "CRH", otherwise GABA_PARAMS)
      2. values loaded from param_file
      3. param_dict

    takes:
        param_file (str or None): joblib file of sampled cells. It is indexed by column name
            (tauw, a, b, Cm, taum, DeltaT, DeltaA, Va, EL, Ea, VT, VR), and each column must
            support .to_numpy() (presumably a pandas DataFrame; confirm). NOTE: joblib uses
            pickle, so only load trusted files.
        param_dict (dict or None): per-parameter overrides, applied last.
        cell (str): "CRH" selects CRH_PARAMS as the base; anything else selects GABA_PARAMS.
    returns:
        params (dict): a new dict. Neither the module defaults nor param_dict are modified.
    """
    params = (CRH_PARAMS if cell == "CRH" else GABA_PARAMS).copy()

    if param_file is not None:
        params.update(_params_from_sampled_cells(joblib.load(param_file)))
        print(f"loaded {cell} params from file")

    # override values if the user passes in a dict of params
    if param_dict:
        print(f"Overriding {cell} params: {list(param_dict)}")
        params.update(param_dict)
    elif param_file is None:
        print(f"Using default {cell} params")

    return params


def _params_from_sampled_cells(sampled_cell):
    """Convert a table of sampled cell parameters into per-cell arrays with brian2 units.

    'pr' (release probability) starts at 1 and 'bd' (extra adaptation) starts at 0 for every
    sampled cell; both are sized to the number of cells in the table.
    """
    n_cells = len(sampled_cell['tauw'])
    Cm = sampled_cell['Cm'].to_numpy() * pF
    taum = sampled_cell['taum'].to_numpy() * ms
    VT = sampled_cell['VT'].to_numpy()
    DeltaT = sampled_cell['DeltaT'].to_numpy()
    return {
        'tauw': sampled_cell['tauw'].to_numpy() * ms,
        'a': sampled_cell['a'].to_numpy() * nS,
        'b': sampled_cell['b'].to_numpy() * nS,
        'C': Cm,
        'taum': taum,
        'gL': Cm / taum,
        'DeltaT': DeltaT * mV,
        'DeltaA': sampled_cell['DeltaA'].to_numpy() * mV,
        'pr': np.full(n_cells, 1),
        'Va': sampled_cell['Va'].to_numpy() * mV,
        'EL': sampled_cell['EL'].to_numpy() * mV,
        'Ea': sampled_cell['Ea'].to_numpy() * mV,
        'VT': VT * mV,
        'VR': sampled_cell['VR'].to_numpy() * mV,
        # spike threshold used by the NeuronGroup: VT + 5 * DeltaT
        'Vcut': (VT + 5 * DeltaT) * mV,
        'bd': np.full(n_cells, 0) * nS,
    }
        

        

#here we will pop_out the params into their respective variables
def load_params_from_scan_csv(file):
    """Load network and synapse parameters from the first row of a parameter-scan CSV.

    takes:
        file (str or path-like): CSV whose first data row holds the parameters.
            Required columns: p_ei, p_ie, taue, taui, taucrh, wcrh, we, wi, exte_in, exti_in.
            Optional columns: wexti / exti_ii (default 0), Nexc / N_CRH (default 500),
            ext_to_gaba (default False), connectivity, p.
    returns:
        network_args (dict): connection probabilities, connectivity settings and ECRH_params
        synapse_params (dict): time constants (s), weights (S / nS) and input rates (Hz)
            as brian2 quantities, plus 'Nexc' and 'ext_to_gaba'
    """
    import pandas as pd  # local import: pandas is only needed by this loader

    params = pd.read_csv(file).to_dict(orient='records')[0]

    #older scans stored the external->GABA weight as 'wexti'; newer ones use 'exti_ii'
    if 'wexti' in params:
        params['exti_ii'] = params['wexti']
    elif 'exti_ii' not in params:
        params['exti_ii'] = 0
    if 'Nexc' not in params:
        params['Nexc'] = params.get('N_CRH', 500)

    network_args = {'p_ei': params['p_ei'], 'p_ie': params['p_ie']}
    synapse_params = {'taue': params['taue']*second, 'taui': params['taui']*second, 'taucrh': params['taucrh']*second,
                      'wcrh': params['wcrh']*siemens, 'we': params['we']*siemens, 'wi': params['wi']*siemens,
                      'exte_in': params['exte_in']*Hz, 'exti_in': params['exti_in']*Hz,
                      'wexti': params['exti_ii']*nS,
                      'Nexc': params['Nexc']}

    #the key was previously checked as the misspelled 'ext_to_gabe', so a saved
    #'ext_to_gaba' column was always ignored and the flag silently defaulted to False
    synapse_params['ext_to_gaba'] = params.get('ext_to_gaba', False)

    #convert the connectivity parameters to the correct format from strings
    if 'connectivity' in params:
        if params['connectivity'] == 'random':
            network_args['connectivity'] = 'random'
            network_args['connectivity_params'] = None
        else:
            # NOTE(review): eval() executes arbitrary code stored in the CSV -- only load trusted
            # scan files. ast.literal_eval would be safer if the column always holds a dict literal.
            # NOTE(review): network_args['connectivity'] is not set on this branch; presumably the
            # caller's defaults fill it in -- confirm.
            conn_params = eval(params['connectivity'])
            #the scan stores inter_cluster_num as a fraction of NUM_CLUSTERS; convert to a count
            conn_params['inter_cluster_num'] = int(conn_params['inter_cluster_num']*conn_params['NUM_CLUSTERS']) + 1
            conn_params['p_ei'] = params['p_ei']
            conn_params['p_ie'] = params['p_ie']
            network_args['connectivity_params'] = conn_params

    #'p' is the probability of a CRH neuron being connected to an E neuron (needs a better key name)
    network_args['ECRH_params'] = {'p': params['p']} if 'p' in params else None

    return network_args, synapse_params

def load_params_from_joblib(path):
    """Read back the (network_args, synapse_params) pair written by save_params_to_joblib.

    Files saved without an 'N_CRH' entry in network_args get the default of 500.
    """
    stored = joblib.load(path)
    network_args, synapse_params = stored['network_args'], stored['synapse_params']
    network_args.setdefault('N_CRH', 500)
    return network_args, synapse_params


def load_parameters_from_predefined_exp(parameter_file=None, exp_num=None, tag=None, neuron_num=None, run_time=None, network_args=None):
    """Load (or build on the fly) the parameters for one of the predefined in-silico experiments.

    The function looks next to ``parameter_file`` for ``<basename>_dyn_clamp_<exp_num>.joblib``.
    If that file exists, its stored experiment is returned. Otherwise the base parameters are read
    from ``parameter_file`` and the experiment is generated with ``generate_experimental_dict``.

    takes:
        parameter_file (str): joblib file containing 'synapse_params' and 'network_args'
        exp_num (int or float): experiment number; integral values are normalised to int,
            and half-step experiments (e.g. 1.5, 3.5, 4.5) are kept as floats
        run_time: fallback runtime used when the preconfigured file has no 'RUNTIME' entry
        tag, neuron_num, network_args: currently unused, kept for API compatibility
    returns:
        synapse_params, network_params, network_events, input_current, RUNTIME
    raises:
        ValueError: if parameter_file is None
    """
    import copy

    if parameter_file is None:
        #previously this fell straight through to the return and raised UnboundLocalError
        raise ValueError("parameter_file is required to load a predefined experiment")

    #int() used to be applied to everything except 1.5, which silently turned the
    #half-step experiments 3.5 / 4.5 into experiments 3 / 4
    exp_num = float(exp_num)
    if exp_num.is_integer():
        exp_num = int(exp_num)

    basename = os.path.splitext(os.path.basename(parameter_file))[0]
    exp_path = os.path.join(os.path.dirname(parameter_file), f"{basename}_dyn_clamp_{exp_num}.joblib")

    if not os.path.exists(exp_path):
        print(f"preconfigured file not found: {parameter_file} @ {exp_path}")
        print("Generating the new parameters on the fly")
        _dict = joblib.load(parameter_file)
        logger.debug("parameter file keys: %s", list(_dict.keys()))
        #deep copies so generate_experimental_dict's in-place edits don't leak into _dict
        return generate_experimental_dict(exp_num, copy.deepcopy(_dict['synapse_params']), copy.deepcopy(_dict['network_args']))

    print(f"loading parameters for {parameter_file} from {exp_path}")
    _dict = joblib.load(exp_path)
    logger.debug("preconfigured experiment keys: %s", list(_dict.keys()))
    synapse_params_base = _dict['synapse_params']
    network_params = _dict['network_args']
    network_events = _dict['network_events']
    input_current = _dict['input_current']
    #the old `'RUNTIME' not in locals()` fallback could never trigger; apply it explicitly
    RUNTIME = _dict.get('RUNTIME', run_time)

    if input_current is not None:
        #stored as a raw array in amps (presumably (time, neuron) like generate_input_current -- confirm);
        #wrap it in a TimedArray sampled at the default clock step
        input_current = TimedArray(input_current*amp, dt=defaultclock.dt)

    return synapse_params_base, network_params, network_events, input_current, RUNTIME


def save_params_to_joblib(network_args, synapse_params, path):
    """Write the network and synapse parameter dicts to ``path`` with joblib.

    The file stores a dict with keys 'network_args' and 'synapse_params', which is the
    layout load_params_from_joblib expects.
    """
    payload = dict(network_args=network_args, synapse_params=synapse_params)
    joblib.dump(payload, path)
    

def generate_input_current(pattern='step', steps=(0, 20, 40, 60, 80), run_time=60, stim_length=1, episode_length=None, neuron=0, N=1000, dt=1*ms):
    """Generate a TimedArray of injected current for the network. Currently only supports step patterns.

    For the 'step' pattern the run is split into episodes of ``episode_length`` seconds.
    In the middle of episode i, a pulse of ``steps[i]`` pA lasting ``stim_length`` seconds
    is applied to every neuron in ``neuron``.

    takes:
        pattern (str): the pattern to generate. Currently only supports 'step'
        steps (sequence): the step amplitudes, in pico amps
        run_time (float): overall runtime of the network in seconds
        stim_length (float): the length of the stimulus in seconds
        episode_length (float): the length of an episode in seconds (default run_time / len(steps))
        neuron (int or sequence of int): the neuron(s) to apply the stimulus to
        N (int): number of columns (neurons) in the returned array
        dt (Quantity): sampling step of the returned TimedArray (previously ignored; 1 ms by default)
    returns:
        TimedArray in pA with shape (round(run_time / dt), N)
    raises:
        ValueError: if pattern is not 'step'
    """
    if pattern != 'step':
        #previously fell through and raised UnboundLocalError on the return
        raise ValueError(f"unsupported input current pattern: {pattern!r} (only 'step' is supported)")

    #scalar or sequence -> 1-D int index array; an empty sequence stays a valid empty index
    neurons = np.asarray(neuron, dtype=int).ravel()

    if episode_length is None:
        episode_length = run_time / len(steps)
    stim_start = episode_length/2 - stim_length/2
    stim_end = episode_length/2 + stim_length/2
    logger.debug(f"creating a stimuli with episode length {episode_length} with stim start: {stim_start}, stim end: {stim_end}")

    dt_seconds = float(dt / second)

    def _to_sample(t):
        #round instead of truncating so float error (e.g. 289.9999) doesn't drop a sample
        return int(round(t / dt_seconds))

    input_current = np.zeros((_to_sample(run_time), N))
    for i, step in enumerate(steps):
        episode_offset = i*episode_length
        #clip at 0 so a stimulus longer than its episode doesn't wrap to the end of the array
        start = max(0, _to_sample(episode_offset + stim_start))
        end = _to_sample(episode_offset + stim_end)
        input_current[start:end, neurons] = step
    return TimedArray(input_current*pA, dt=dt)


def generate_experimental_dict(exp_num, synapse_params_base, network_params, neuron_to_replace=0):
    """Configure one of the predefined in-silico experiments used in the paper.

    These experiments test the network under different conditions. ``synapse_params_base``
    is modified in place and returned, so pass a copy if the original must be preserved.

    takes:
        exp_num (int or float): the experiment to configure. Recognised values are
            0, 1, 1.5, 2, 3, 3.5, 4, 4.5, 5, 6, 7, 8 and 9. Any other value logs a warning
            and returns the baseline configuration (60 s, no events, no input current).
        synapse_params_base (dict): the base synapse parameters to update
        network_params (dict): the base network parameters (returned unchanged)
        neuron_to_replace (int): the neuron that receives the step input current (experiments 3-4.5)
    returns:
        synapse_params_base (dict): the updated synapse parameters
        network_params (dict): the network parameters
        network_events (dict): events to apply, keyed by time (seconds, presumably -- keys fall
            within RUNTIME). Values are an event name or an 'eval: ...' statement.
        input_current (TimedArray or None): the input current to apply to the network
        RUNTIME (int): the runtime of the experiment in seconds
    """
    #defaults: no injected current, a 60 s run and no events
    input_current = None
    RUNTIME = 60
    network_events = {}

    n_rates = 20  # number of Poisson-rate steps in the sweeps of experiments 7-9

    def _rate_sweep(first_event):
        """Return events setting poisson_exc.rates to n_rates values from 0 to 150 Hz, one every 10 s."""
        rates_grid = np.linspace(0, 150, n_rates)
        return {int(i*10) + first_event: f'eval: poisson_exc.rates = {rate}*Hz' for i, rate in enumerate(rates_grid)}

    if exp_num == 0:
        pass
    elif exp_num == 1:
        synapse_params_base['taup'] = 5*second
        synapse_params_base['taubr'] = 7*second
        network_events = {15: 'disconnect_GABA_increase_adapt', 35: 'disconnect_GABA_increase_adapt'}
    elif exp_num == 1.5:
        #experiment 1 stretched out 10x in time
        RUNTIME = 600
        synapse_params_base['taup'] = 50*second
        synapse_params_base['taubr'] = 70*second
        network_events = {200: 'disconnect_GABA_increase_adapt', 400: 'disconnect_GABA_increase_adapt'}
    elif exp_num == 2:
        synapse_params_base['taup'] = 5*second
        synapse_params_base['taubr'] = 7*second
        network_events = {15: 'disconnect_GABA_increase_adapt', 40: 'disconnect_GABA_increase_adapt'}
        synapse_params_base['exte_in'] = 0.0*Hz #turn off the EPSPs
    elif exp_num == 3:
        synapse_params_base['taup'] = 5*second
        synapse_params_base['taubr'] = 7*second
        input_current = generate_input_current(run_time=RUNTIME, neuron=neuron_to_replace)
    elif exp_num == 3.5:
        input_current = generate_input_current(run_time=RUNTIME, neuron=neuron_to_replace)
        #9e9 s time constants are effectively infinite
        synapse_params_base['taup'] = 9e9*second
        synapse_params_base['taubr'] = 9e9*second
        network_events = {5: 'disconnect_GABA_increase_adapt'}
    elif exp_num == 4:
        synapse_params_base['taup'] = 5*second
        synapse_params_base['taubr'] = 7*second
        input_current = generate_input_current(run_time=RUNTIME, neuron=neuron_to_replace)
        synapse_params_base['exte_in'] = 0.0*Hz #turn off the EPSPs
    elif exp_num == 4.5:
        input_current = generate_input_current(run_time=RUNTIME, neuron=neuron_to_replace)
        synapse_params_base['exte_in'] = 0.0*Hz #turn off the EPSPs
        synapse_params_base['taup'] = 9e9*second
        synapse_params_base['taubr'] = 9e9*second
        network_events = {10: 'disconnect_GABA_increase_adapt'}
    elif exp_num == 5:
        synapse_params_base['taup'] = 5*second
        synapse_params_base['taubr'] = 7*second
        network_events = {20: 'double_EPSP', 35: 'triple_EPSP'}
    elif exp_num == 6:
        synapse_params_base['taup'] = 9e9*second
        synapse_params_base['taubr'] = 9e9*second
        network_events = {10: 'disconnect_GABA_increase_adapt', 20: 'double_EPSP', 35: 'triple_EPSP'}
    elif exp_num == 7:
        synapse_params_base['taup'] = 5*second
        synapse_params_base['taubr'] = 7*second
        network_events = _rate_sweep(10)
        RUNTIME = 200
    elif exp_num == 8:
        synapse_params_base['taup'] = 9e9*second
        synapse_params_base['taubr'] = 9e9*second
        #disconnect at 9 s, then sweep the Poisson rate from 10 s onwards
        network_events = {9: 'disconnect_GABA_increase_adapt'}
        network_events.update(_rate_sweep(10))
        RUNTIME = 200
    elif exp_num == 9:
        #roughly experiments 7 & 8 stacked: a rate sweep, a disconnection, then a second sweep
        synapse_params_base['taup'] = 9e9*second
        synapse_params_base['taubr'] = 9e9*second
        sweep_duration = n_rates*10
        network_events = _rate_sweep(10)
        network_events[sweep_duration + 10] = 'disconnect_GABA_increase_adapt'
        network_events.update(_rate_sweep(sweep_duration + 20))
        RUNTIME = 440
    else:
        #previously an unknown experiment silently ran the baseline configuration
        logger.warning("unknown exp_num %r; using the baseline configuration (no events, %s s)", exp_num, RUNTIME)

    return synapse_params_base, network_params, network_events, input_current, RUNTIME


if __name__ == "__main__":
    # Smoke test: run the random-connectivity network for 60 s and raster-plot the 'spikes' monitor.
    logger.setLevel(logging.DEBUG)
    demo_args = dict(connectivity='random', connectivity_params={}, init_run=True)
    results = cadex_network(run_time=60, network_args=demo_args)
    spike_trains = results['spikes'].spike_trains()
    eventplot(list(spike_trains.values()), color='k')
    show()