#
'''
MSN model used in Lindroos et al., (2018). Frontiers

Robert Lindroos (RL) <robert.lindroos at ki.se>
 
The MSN class and most channels were implemented by 
Alexander Kozlov <akozlov at kth.se>
with updates by RL

Implemented in collaboration with Kai Du <kai.du at ki.se>
'''


from __future__ import print_function, division
from neuron import h
#import mpi4py as MPI
import numpy                as np
import plot_functions       as fun
import MSN_builder          as build
import json


# Load the hoc library files this script relies on (same files, same order).
for hoc_library in ('stdlib.hoc', 'import3d.hoc'):
    h.load_file(hoc_library)



# Module-wide result store. main() fills it as RES[target][run] for the
# modulated runs, and the script section at the bottom of the file saves it.
RES = {}


# Initial and maximal values of the cascade substrates, seen with a DA
# transient of 500 nM peak amplitude and 500 ms time constant.
# main() reads entry [0] as the base value and entry [1] as the maximal value
# of the selected substrate.
with open('./substrates.json') as substrates_file:
    SUBSTRATES = json.load(substrates_file)
    
    
    
    


    
def alpha(tstart, gmax, tau):
    '''
    Evaluates an alpha function at the current simulation time (h.t) and
    returns the resulting "magnitude". Used for [DA] in the cascade.

    tstart  = onset time of the alpha function
    gmax    = peak value, reached when h.t == tstart + tau
    tau     = time constant
    '''

    elapsed = h.t - tstart

    # gmax * (t/tau) * exp(1 - t/tau), with t measured from tstart
    return gmax * elapsed / tau * np.exp(1 - elapsed / tau)
    
    
    
    
    
    
    
def save_vector(t, v, outfile):
    '''
    Writes two sequences to a text file as space separated columns,
    one "time value" pair per line (both formatted with %g).

    t        = sequence of time points
    v        = sequence of values, paired element-wise with t
    outfile  = path of the file to write; an existing file is overwritten

    If the sequences differ in length the output stops at the shorter one.
    '''

    with open(outfile, "w") as handle:
        handle.writelines("%g %g\n" % pair for pair in zip(t, v))
            
            
            
            
            
            
            
def calc_rand_Modulation(mod_list, range_list=False):
    '''
    Draws one random modulation factor per channel in mod_list.

    Each factor is drawn with numpy from a uniform distribution and scaled
    to the interval [0,2]. If a range_list is given, the factor of channel i
    is linearly mapped from [0,2] onto [range_list[i][0], range_list[i][1]].

    mod_list     = list of channels to be modulated (only its length matters)
    range_list   = list of [min, max] pairs, one per channel.
                    Must have same length as mod_list.

    Returns the factors as a list, in the order of mod_list.
    '''

    # interval the raw factors are drawn from
    src_min = 0
    src_max = 2

    drawn = []

    for idx, _ in enumerate(mod_list):

        value = 2.0 * np.random.uniform()

        if range_list:

            low, high = range_list[idx][0], range_list[idx][1]

            # linear map [src_min, src_max] -> [low, high]
            value = (high-low) / (src_max-src_min) * (value-src_min) + low

        drawn.append(value)

    return drawn
        
         




def make_random_synapse(ns, nc, Syn, sec, x, n,
                Type='glut',
                NS_start=1,
                NS_interval=1000.0/18.0,
                NS_noise=1.0,
                NS_number=1000,
                S_AN_ratio=1.0,
                S_tau_dep=100,
                S_U=1,
                S_e=-60,
                S_tau1=0.25,
                S_tau2=3.75,
                NC_delay=1,
                NC_conductance=0.6e-3,
                NC_threshold=0.1):
    '''
    Creates a synapse in the segment closest to x in section sec, together
    with a NetStim that drives it and the NetCon that connects the two.
    The three objects are stored in the dicts Syn, ns and nc under one
    shared key.

    ns, nc, Syn  = dicts that receive the NetStim, NetCon and synapse objects
    sec          = section to place the synapse in
    x            = position within the section
    n            = not used by this function (kept for existing callers)
    Type         = 'glut' (h.tmGlut), 'expSyn2' (h.Exp2Syn) or
                    'tmgabaa' (h.tmGabaA)

    NS-arguments are used to define the NetStim object
        (start, interval, noise, number).
    S-arguments are used to specify the synapse mechanism:
        S_AN_ratio       -> nmda_ratio  ('glut' only)
        S_tau_dep, S_U   -> tauR, U     ('glut' and 'tmgabaa')
        S_tau1, S_tau2   -> tau1, tau2  ('expSyn2' only)
        S_e              -> e           ('expSyn2' and 'tmgabaa')
    NC-arguments are used to define the NetCon object
        (delay, weight[0], threshold).

    Keys: a 'glut' synapse is stored under the section object itself, the
    other types under sec.name() + '_gaba'. Callers look the objects up
    with exactly these keys.

    An unknown Type writes a message to stderr and exits with status 1
    (SystemExit) before any object is created.
    '''

    # reject unknown types up front, so nothing is created or stored
    if Type not in ('glut', 'expSyn2', 'tmgabaa'):

        # imported here because the import header of this file does not import sys
        import sys

        sys.stderr.write('\nError: wrong synapse Type (%s). \n\tSynapse not set. Exiting\n' %Type)
        sys.exit(1)


    # create/set synapse in segment x of section
    if Type == 'glut':

        key             = sec
        syn = Syn[key]  = h.tmGlut(x, sec=sec)
        syn.nmda_ratio  = S_AN_ratio
        syn.tauR        = S_tau_dep
        syn.U           = S_U

    else:

        key             = sec.name() + '_gaba'

        if Type == 'expSyn2':
            syn = Syn[key]  = h.Exp2Syn(x, sec=sec)
            syn.tau1        = S_tau1
            syn.tau2        = S_tau2
        else:
            # Type == 'tmgabaa'
            syn = Syn[key]  = h.tmGabaA(x, sec=sec)
            syn.tauR        = S_tau_dep
            syn.U           = S_U

        syn.e           = S_e


    # create NetStim object
    stim            = ns[key] = h.NetStim()
    stim.start      = NS_start
    stim.interval   = NS_interval # mean interval between two spikes in ms
    stim.noise      = NS_noise
    stim.number     = NS_number

    # create NetCon object connecting the NetStim to the synapse stored under
    # the same key
    con             = nc[key] = h.NetCon(stim, syn)
    con.delay       = NC_delay
    con.weight[0]   = NC_conductance
    con.threshold   = NC_threshold
    






def set_rand_synapse(channel_list, base_mod, max_mod, range_list=None):
    '''
    Draws random modulation factors for the synaptic channels in
    channel_list and normalizes them to the range of the pointer substrate.

    channel_list = list of synaptic channels to draw factors for
    base_mod     = base (initial) value of the substrate
    max_mod      = maximal value of the substrate
    range_list   = list of [min, max] pairs, one per channel, handed to
                    calc_rand_Modulation. Defaults to
                    [[0.75,1.5],[0.75,1.5]], i.e. a default that fits
                    two channels.

    Returns (syn_fact, normalized_factors):
        syn_fact           = the drawn factors
        normalized_factors = (factor - 1) / (max_mod - base_mod) for each
                             drawn factor
    '''

    # build the default per call instead of sharing one mutable default list
    if range_list is None:
        range_list = [[0.75,1.5],[0.75,1.5]]

    syn_fact = calc_rand_Modulation(channel_list, range_list=range_list)

    # normalize factors to max-value of pointer substrate
    substrate_span      = max_mod - base_mod
    normalized_factors  = [(factor - 1) / substrate_span for factor in syn_fact]

    return syn_fact, normalized_factors
 



   
#=========================================================================================


# In the dynamical modulation, the channels are connected to one substrate of the cascade.
# Base modulation (control) is assumed for the base value of the substrate and full modulation
# is assumed when the substrate level reaches its maximal value. Linear scaling is used
# between these points.
def main(par="./params_dMSN.json",
         run=None,
         dynMod=1,
         simDur=2000,
         target=None,
         not2mod=None):
    '''
    Builds one MSN with a D1 cascade in the soma, connects the modulation of
    channels and synapses to one substrate of the cascade, runs a simulation
    and saves the result.

    par      = parameter file handed to build.MSN
    run      = index of the run. Used in the output file names and as key in
                RES[target]. For run 0 the substrate traces are recorded too.
    dynMod   = 1: a DA transient (alpha function) is written to casc.DA once
                h.t has passed 1000 ms, and the run is stored as a
                modulated run.
                Any other value: no transient, the voltage trace is saved
                as a control run.
    simDur   = simulated time [ms]
    target   = name of the cascade variable the modulation follows. Must be
                a key in SUBSTRATES and a variable of the cascade;
                'control' is treated as 'Target1p'.
                NOTE: the default (None) is not a usable value, so callers
                have to pass a target.
    not2mod  = names of mechanisms whose modulation factor is set to 0.
                Defaults to an empty list.

    Side effects:
        dynMod == 1  -> RES[target][run] is set to the spike data returned
                        by fun.getSpikedata_x_y. For target 'Target1p' the
                        voltage trace is also written to ./Results/Dynamic/
                        (plus the substrate traces if run == 0).
        otherwise    -> the voltage trace is written to
                        ./Results/Dynamic/spiking_<run>_control.out
    Returns nothing.
    '''

    # a fresh list per call instead of one shared mutable default
    if not2mod is None:
        not2mod = []

    print('iter:', run, target)


    # initiate cell
    cell    = build.MSN(params=par,
                        morphology='latest_WT-P270-20-14ak.swc')

    # set cascade
    casc    = h.D1_reduced_cascade2_0(0.5, sec=cell.soma)


    # specify pointer
    if target == 'control':
        target = 'Target1p'

    # Plain attribute lookup. The earlier exec('pointer = ...') cannot create
    # a function local under Python 3, which left `pointer` undefined below.
    pointer     = getattr(casc, '_ref_' + target)
    base_mod    = SUBSTRATES[target][0]
    max_mod     = SUBSTRATES[target][1]


    # set edge of soma as reference for distance
    h.distance(1, sec=h.soma[0])


    # record vectors
    tm = h.Vector()
    tm.record(h._ref_t)
    vm = h.Vector()
    vm.record(cell.soma(0.5)._ref_v)


    # record substrate concentrations (first run only)
    if run == 0:
        pka = h.Vector()
        pka.record(casc._ref_Target1p)
        camp = h.Vector()
        camp.record(casc._ref_cAMP)
        gprot = h.Vector()
        gprot.record(casc._ref_D1RDAGolf)
        gbg   = h.Vector()
        gbg.record(casc._ref_Gbgolf)


    # parameters for DA transient
    da_peak   = 500     # concentration     [nM]
    da_tstart = 1000    # stimulation time  [ms]
    da_tau    = 500     # time constant     [ms]

    tstop = simDur      #                   [ms]
    # dt = default value; 0.025 ms (25 us)


    # all channels (with potential) to modulate
    mod_list    = ['naf', 'kas', 'kaf', 'kir', 'cal12', 'cal13', 'can' ]

    # draw mod factors from [min, max] ranges (as percent of control).
    # Channel ranges are in the following order:
    # ['naf', 'kas', 'kaf', 'kir', 'cal12', 'cal13', 'can' ]
    mod_fact = calc_rand_Modulation(mod_list, range_list=[[0.60,0.80],
                                                          [0.65,0.85],
                                                          [0.75,0.85],
                                                          [0.85,1.25],
                                                          [1.0,2.0],
                                                          [1.0,2.0],
                                                          [0.0,1.0]])


    # normalize factors to target values seen in simulation by formula:
    #
    #   f           = (factor - 1) / (max_substrate - initial_substrate)
    #
    #   modulation  = 1   +      f * (substrate     - initial_substrate)
    #
    # so that the base value corresponds to no modulation and the maximal substrate (target)
    # value corresponds to the maximal value (given by the factor).
    substrate_span  = max_mod - base_mod
    factors         = [(fact - 1) / substrate_span for fact in mod_fact]


    # set pointers
    for sec in h.allsec():

        outside_axon = sec.name().find('axon') < 0
        in_soma      = sec.name().find('soma') >= 0

        for seg in sec:

            # naf and kas are distributed to all sections
            h.setpointer(pointer, 'pka', seg.kas )
            h.setpointer(pointer, 'pka', seg.naf )

            if outside_axon:

                # these channels are not in the axon section
                h.setpointer(pointer, 'pka', seg.kaf )
                h.setpointer(pointer, 'pka', seg.cal12 )
                h.setpointer(pointer, 'pka', seg.cal13 )
                h.setpointer(pointer, 'pka', seg.kir )

                if in_soma:

                    # can is only distributed to the soma section
                    h.setpointer(pointer, 'pka', seg.can )


    # synaptic modulation ================================================================

    # draw random modulation factors for synapses from
    #   intervals given by range_list[[min,max]], where 1 is no modulation.
    #   these ranges can be further restricted using the plot functions in
    #       "plot_functions.py"
    glut_f, glut_f_norm     = set_rand_synapse(['amp', 'nmd'], base_mod, max_mod,
                                                range_list=[[0.9,1.6], [0.9,1.6]])

    gaba_f, gaba_f_norm     = set_rand_synapse(['gab'],        base_mod, max_mod,
                                                range_list=[[0.6,1.4]])

    syn_fact = glut_f + gaba_f


    # containers that keep the NetStim, NetCon and synapse objects referenced
    ns  = {}
    nc  = {}
    Syn = {}

    for sec in h.allsec():

        sec_name = sec.name()

        # create one glutamatergic and one gabaergic synapse per dendritic section
        if sec_name.find('dend') >= 0:

            # create a glut synapse (stored under the section object)
            make_random_synapse(ns, nc, Syn, sec, 0.5, 0,
                                    NS_interval=1000.0/28.0,
                                    NC_conductance=0.50e-3)

            # create a gaba synapse (stored under <section name>_gaba)
            make_random_synapse(ns, nc, Syn, sec, 0.0, 0,
                                    Type='tmgabaa',
                                    NS_interval=1000.0/7.0,
                                    NC_conductance=1.50e-3)

            glut_syn = Syn[sec]
            gaba_syn = Syn[sec_name+'_gaba']

            # set pointer(s)
            h.setpointer(pointer, 'pka', glut_syn)
            h.setpointer(pointer, 'pka', gaba_syn)

            # configure
            glut_syn.base       = base_mod
            glut_syn.f_ampa     = glut_f_norm[0]
            glut_syn.f_nmda     = glut_f_norm[1]

            gaba_syn.base       = base_mod
            gaba_syn.f_gaba     = gaba_f_norm[0]

        elif sec_name.find('axon') >= 0:

            # don't modulate segments in the axon
            continue


        # set modulation of channels
        for seg in sec:

            for mech in seg:

                mech_name = mech.name()

                # turn off or set modulation
                if mech_name in not2mod:

                    mech.factor = 0.0
                    print(mech.name(), 'and channel:', not2mod, mech.factor, sec.name())

                elif mech_name in mod_list:

                    mech.base       = base_mod
                    mech.factor     = factors[ mod_list.index(mech_name) ]


    # solver------------------------------------------------------------------------------
    # NOTE: the CVode object is created but not used further in this function;
    # it is kept so that the setup stays as it was.
    cvode = h.CVode()

    h.finitialize(cell.v_init)

    # run simulation
    while h.t < tstop:

        # drive the cascade with the DA transient once its onset has passed
        if dynMod == 1 and h.t > da_tstart:

            casc.DA = alpha(da_tstart, da_peak, da_tau)

        h.fadvance()


    # save output ------------------------------------------------------------------------

    all_factors = mod_fact + syn_fact

    # create "unique" ID: channel name followed by its factor in percent (truncated)
    ID = ''.join( mech + str( int(all_factors[i]*100) )
                  for i,mech in enumerate(mod_list+['amp', 'nmd', 'gab']) )

    if dynMod == 1:

        # DA transient

        if target == 'Target1p':
            save_vector(tm, vm, ''.join(['./Results/Dynamic/spiking_', str(run), '_', ID, '.out']) )

            if run == 0:
                # save substrate concentrations
                names = ['Target1p', 'cAMP', 'Gbgolf', 'D1RDAGolf']

                for name,substrate in zip(names, [pka, camp, gbg, gprot]):
                    save_vector(tm, substrate, './Results/Dynamic/substrate_'+name+'.out' )

        if target not in RES:
            RES[target] = {}

        RES[target][run]    = fun.getSpikedata_x_y(tm,vm)

    else:

        # no DA transient

        save_vector(tm, vm, ''.join(['./Results/Dynamic/spiking_', str(run), '_control.out']) )


    # when run large scale (on supercomputer) all iterations from one core
    # were stored in one dict to decrease the number of files. Only factors and spikes
    # were recorded
    #spikes      = fun.getSpikedata_x_y(tm,vm)
    #RES[run]    = {'factors': mod_fact + syn_fact, 'spikes': spikes}
        
                



# Start the simulation.
# Function needed for HBP compatibility  =================================================
if __name__ == "__main__":

    # A full sweep takes a few minutes to run. Simulation time can be reduced
    # by lowering the number of runs below.

    # True: run the simulations and save the spikes; False: load saved spikes
    simulate    = True

    if not simulate:

        RES = fun.load_obj( './Results/Dynamic/SPIKES.pkl' )

    else:

        print('starting sim')

        n_runs_control      = 5
        n_runs_modulated    = 5
        targets             = ['Target1p', 'cAMP', 'Gbgolf',  'D1RDAGolf']

        # modulated: every target gets runs 0 .. n_runs_modulated-1
        for target in targets:
            for n in range(n_runs_modulated):
                main(par="./params_dMSN.json",
                     run=n,
                     simDur=2000,
                     dynMod=1,
                     target=target,
                     not2mod=[])

        # control (dynMod = 0): run indices continue after the modulated ones
        first_control = n_runs_modulated
        for n in range(first_control, first_control + n_runs_control):
            main(par="./params_dMSN.json",
                 run=n,
                 simDur=2000,
                 dynMod=0,
                 target='control',
                 not2mod=[])

        fun.save_obj( RES, './Results/Dynamic/SPIKES' )

    print('plotting')
    fun.plot_fig6B('./Results/Dynamic/', RES)