"""
 Network of IF cells - Sustained Activity

 Cortical network consisting of N=500 EX and IN cells, with
 80-20% proportion, and random connectivity
 The excitatory cells also include a proportion of LTS cells

 calculate the nb of spikes for each cell -> "numspikes_cx05_LTS500b.dat"
 calculate spiketimes -> "spiketimes_cx05_LTS500b.dat"
 write the Vm of one cell to "Vm170_cx05_LTS500b.dat" for control
  
 Change history: 
   Refactoring of original Python conversion of Alain's Hoc file, putting more
    of the code into the cell classes.
   Replaced multiStimexp, multiAMPAexp and multiGABAAexp with ExpSyn
   Replaced IF_BG4 with AdExpIF
   Replaced gen mechanism with NetSimFD
   Replaced locally-defined AdExp cell class with BretteGerstnerIF from pyNN.neuron
   Replaced list of cells by pyNN.Population and direct NetCon creation with
      pyNN.connect()
   Now using pyNN.EIF_cond_exp_isfa_ista (standard cell, which uses BretteGerstnerIF)
   Can now run with any PyNN-supported simulator.

  The model is taken from 

  Destexhe, A. Self-sustained asynchronous irregular states and
  Up/Down states in thalamic, cortical and thalamocortical
  networks of nonlinear integrate-and-fire neurons.
  Journal of Computational Neuroscience 27: 493-506, 2009. 

  arXiv preprint: http://arxiv.org/abs/0809.0654

  This file reproduces Fig 7 of that paper

"""

import sys
SIMULATOR = sys.argv[-1]
#SIMULATOR = "neuron"

from math import sqrt, pi
from pyNN.random import NumpyRNG, RandomDistribution

exec("import pyNN.%s as pyNN" % SIMULATOR)
    
#-----------------------------------------------------------------
#  Parameters
#-----------------------------------------------------------------

# General parameters
USE_GUI = False   # } NEURON-specific 
USE_CVODE = False # }

SEED_LTS = 428577
SEED_CONN = 193564
SEED_GEN = 983648

DT = 0.1                                        # (ms) Time step
TSTART  = 0                                     
TSTOP   = 5000
V_INIT  = -60

# Cell parameters

LENGTH          = sqrt(20000/pi)                # in um
DIAMETER        = sqrt(20000/pi)                # in um
AREA            = 1e-8 * pi * LENGTH * DIAMETER # membrane area in cm2
TAU             = 20                            # time constant in ms
CAPACITANCE     = 1                             # capacitance in muF/cm2
G_L             = 1e-3 * CAPACITANCE / TAU      # leak conductance in S/cm2
V_REST          = -60                           # resting potential     

a_RS            = 0.001 
b_RS            = 0.1   # full adaptation
b_RS            = 0.005 # weaker adaptation
a_LTS           = 0.02
b_LTS           = 0.0
a_FS            = 0.001
b_FS            = 0.0

TAU_W           = 600
DELTA           = 2.5

# Spike parameters

VTR             = -50           # threshold in mV
VTOP            = 40            # top voltage during spike in mV
VBOT            = -60           # reset voltage in mV
REFRACTORY      = 5.0/2         # refractory period in ms (correction for a bug in IF_CG4)

# Synapse parameters

scale = 1               # scaling factor (>1 = more synapses)

TAU_E           = 5            # excitation time constant in ms
TAU_I           = 10           # inhibition time constant in ms
V_E             = 0            # excitatory reversal potential
V_I             = -80          # inhibitory reversal potential
AMPA_GMAX       = 0.006/scale
GABA_GMAX       = 0.067/scale
DELAY           = DT

# Network parameters
# Cortex
N_CX = 500              # Number of cortical cells
N_I = int(N_CX/5.0)     # Number of Cx inibitory cells
N_E = N_CX - N_I        # Number of excitatory cells
PROB_CONNECT = 0.02*scale       # Connection probability in cortex
PROB_CONNECT = 0.02*2000/N_CX   # prob renormalized to size

C_I = int(N_I*PROB_CONNECT)     # nb inh synapses per neuron
C_E = int(N_E*PROB_CONNECT)     # nb exc synapses per neuron
N_GEN = N_CX            # total number of cells

PROP = 0.05             # proportion of cortical LTS cells

# Stimulation parameters

N_STIM          = N_CX/5        # number of neurons stimulated
STOPSTIM        = 50            # duration of stimulation (ms)
NSYN_STIM       = 20            # nb of stim (exc) synapses per neuron
STIM_INTERVAL   = 70            # mean interval between stims (ms)

MODEL_ID        = "cx05_LTS500b_%s" % SIMULATOR

NEURONS_TO_PLOT = [0, 10, 20, 30, N_E, N_E+10]
NEURONS_TO_RECORD = [170, 0, N_STIM-1, N_E+10]

#-----------------------------------------------------------------
#  Create cells
#-----------------------------------------------------------------

# we now use a standard cell model from PyNN, so there is nothing to do here

#-----------------------------------------------------------------
#  Create Network
#-----------------------------------------------------------------

rLTS = NumpyRNG(seed=SEED_LTS)
nLTS = 0

def netCreate ():
    global nLTS, neurons
    RS_parameters = {
        'cm': 1000*AREA*CAPACITANCE, 'tau_m': TAU, 'v_rest': V_REST,
        'v_thresh': VTR, 'tau_refrac': REFRACTORY+DT,
        'v_reset': VBOT, 'v_spike': VTR+1e-6, 'a': 1000.0*a_RS, 'b': b_RS,
        'tau_w': TAU_W, 'delta_T': DELTA, 'tau_syn_E': TAU_E, 'e_rev_E': V_E,
        'tau_syn_I': TAU_I, 'e_rev_I': V_I, 'v_init': V_INIT, 'w_init': 0.0
    }
    
    LTS_parameters = RS_parameters.copy()
    LTS_parameters.update({'a': 1000.0*a_LTS, 'b': b_LTS}) # 1000 is for uS --> nS
    FS_parameters = RS_parameters.copy()
    FS_parameters.update({'a': 1000.0*a_FS, 'b': b_FS})

    neurons = pyNN.Population(N_CX, pyNN.EIF_cond_exp_isfa_ista, RS_parameters)
    for nbactual in range(0, N_E):      # create cortical cells (excitatory)
        # check if LTS cell
        if rLTS.uniform(0,1) < PROP:
            print "Cell ", nbactual, " is LTS"
            neurons[nbactual].set_parameters(**LTS_parameters)
            nLTS = nLTS + 1

    for nbactual in range(N_E, N_CX):     # create cortical cells (inhibitory)
        neurons[nbactual].set_parameters(**FS_parameters)

#  Connect cells

rCon = NumpyRNG(seed=SEED_CONN)

PRINT = 2        # flag to print; 0=minimal, 1=verbose, 2=summary

ampa_list = []
gabaa_list = []
stimsyn_list = []

def netConnect(): # local i, j, rand, distvert, nbconn
    ne = 0
    ni = 0
    ie = 0
    ii = 0
    print "Calculate connectivity of cortical cells..."
    # scan cortical cells
    for i in range(0, N_CX):
        if PRINT==1:
           if i<N_E:
                print "Cortical EX cell ",i
           else:
                print "Cortical IN cell ",i
        nbconex = 0
        nbconin = 0

        # Insert excitatory inputs
        j = 0
        while (nbconex < C_E) and (j < N_E):
            rand = rCon.uniform(0.0, 1.0)
            if (i != j) and (rand <= PROB_CONNECT):
                nc = pyNN.connect(neurons[j], neurons[i], weight=AMPA_GMAX,
                                  delay=DELAY, synapse_type="excitatory")
                ampa_list.append(nc)
                nbconex = nbconex + 1    
            j = j + 1
        if PRINT==1:
            print " - exc inputs from CX:", nbconex
        ne = ne + nbconex
        ie = ie + 1

        # Insert inhibitory inputs
        j =  N_E
        while (nbconin < C_I) and (j < N_CX):
            rand = rCon.uniform(0.0, 1.0)
            if (i != j) and (rand <= PROB_CONNECT):
                nc = pyNN.connect(neurons[j], neurons[i], weight=GABA_GMAX,
                                  delay=DELAY, synapse_type="inhibitory")
                gabaa_list.append(nc)
                nbconin = nbconin + 1
            j = j + 1
        if PRINT==1:
            print " - inh inputs from CX:", nbconin
        ni= ni + nbconin
        ii = ii + 1

    if PRINT==2:
        print "MEAN SYNAPSES PER NEURON:"
        print "cortical excitatory inputs ", float(ne)/ie
        print "cortical inhibitory inputs ", float(ni)/ii

#-----------------------------------------------------------------
#  External Input
#-----------------------------------------------------------------

nstim = NSYN_STIM
stim = []
rStim = NumpyRNG(seed=SEED_GEN)

def generate_stimulus(start, stop, interval):
    rd = RandomDistribution('exponential', [interval], rng=rStim)
    t = start
    times = []
    while t < stop:
        t += rd.next()
        if t < stop:
            times.append(t)
    return times

def insertStimulation():
    print "Add stimulation of cortical neurons..."
    for i in range(0, N_STIM):
        g = pyNN.Population(nstim, pyNN.SpikeSourceArray)
        for cell in g:
            spike_times = generate_stimulus(TSTART, STOPSTIM, STIM_INTERVAL)
            cell.spike_times = spike_times
        stim.append(g)
        cm = pyNN.connect(g, neurons[i], weight=AMPA_GMAX*scale, delay=DELAY,
                          synapse_type='excitatory')
        stimsyn_list.append(cm)

#-----------------------------------------------------------------
# Simulation settings
#-----------------------------------------------------------------

pyNN.setup(DT, min_delay=DELAY, quit_on_end=False, use_cvode=USE_CVODE,
           rng_seeds_seed=SEED_GEN)#,mappingAnalysis=True)

#-----------------------------------------------------------------
#  Add graphs
#-----------------------------------------------------------------

g = [None]*20
ngraph = 0

def addgraph(v_min, v_max, label, colour):
    global ngraph
    ngraph = ngraph+1
    ii = ngraph-1
    g[ii] = h.Graph()
    g[ii].size(TSTART, TSTOP, v_min, v_max)
    g[ii].xaxis()
    g[ii].yaxis()
    g[ii].addexpr(label, colour, 0)
    g[ii].save_name("graphList[0].")
    h.graphList[0].append(g[ii])
   

print ""
print "======================================================================="
print "            Network of ",N_GEN,"IF neurons in an active state"
print "======================================================================="
print ""

#------------------------------------------------------------------------------
#  creating cells
#------------------------------------------------------------------------------
print "----[ CREATING CELLS ]----"
netCreate()

#------------------------------------------------------------------------------
#  creating network
#------------------------------------------------------------------------------
print "----[ CREATING NETWORK ]----"
netConnect()

#------------------------------------------------------------------------------
#  adding network input
#------------------------------------------------------------------------------
print "----[ ADDING NETWORK INPUT ]----"
insertStimulation()

#------------------------------------------------------------------------------
#  procedures to write spike times
#------------------------------------------------------------------------------

nspikes = []

#def write_numspikes():  
#    f = open("numspikes_%s.dat" % MODEL_ID, 'w')
#    f.write("%g %g\n" % (N_GEN, pyNN.get_current_time()))             # write nb of cells and time
#    sum1 = 0
#    sum2 = 0
#    sum3 = 0
#    sum4 = 0
#    spike_counts = neurons.get_spike_counts()
#    #print spike_counts
#    for i in range(0, N_GEN):
#        nspikes = spike_counts.get(i, 0)
#        f.write("%g\n" % nspikes)       # write tot number of spikes
#        rate = nspikes * 1000.0 / TSTOP
#        if i<N_E:
#            sum1 = sum1 + rate
#            sum2 = sum2 + rate**2
#        else:
#            sum3 = sum3 + rate
#            sum4 = sum4 + rate**2
#    f.close()
#
#    sum1 = float(sum1) / N_E
#    sum2 = sqrt( float(sum2)/N_E - sum1**2 )
#    sum3 = float(sum3) / N_I
#    sum4 = sqrt( float(sum4)/N_I - sum3**2 )
#    return sum1, sum2, sum3, sum4

#-----------------------------------------------------------------
#  Graphs
#-----------------------------------------------------------------

if USE_GUI and SIMULATOR == 'neuron':
    from neuron import h, gui

    h('objref py')
    h.py = h.PythonObject() # lets Hoc access Python
    
    h.nrncontrolmenu()
    # adding graphs
    for id in NEURONS_TO_PLOT:
        addgraph(-80, 40, "py.neurons[%d]._cell.seg.v" % id, 4)

# record spikes
neurons.record()

# record the Vm
neurons_to_record = [neurons[id] for id in NEURONS_TO_RECORD]
neurons.record_v(neurons_to_record)      

#-----------------------------------------------------------------
# Procedure to run simulation and menu
#-----------------------------------------------------------------

def run_sim(with_graphs=False):
    print "----[ RUNNING THE SIMULATION ]----"
    
    if with_graphs:
        h.v_init = V_INIT
        h.init() # not really needed, since h.run() calls finitialize()
        h.tstop = TSTOP
        h.run()
    else:
        
        pyNN.run(TSTOP)

    print "Writing spikes to file..."
    #rate_RS, std_RS, rate_FS, std_FS = write_numspikes()
    neurons.printSpikes("spiketimes_%s.dat" % MODEL_ID, compatible_output=False)
    
#    print "Mean rate per RS cell (Hz) = ", rate_RS
#    print " standard deviation = ", std_RS
#    print "Mean rate per FS cell (Hz) = ", rate_FS
#    print " standard deviation = ", std_FS
    
    neurons.print_v("Vm_%s.dat" % MODEL_ID, compatible_output=False)

def make_Vpanel():                    # make panel
    h.xpanel("Brette-Gerstner network")
    h.xbutton("Run simulation", "py.run_sim(1)")
    h.xpanel()
    
if USE_GUI and SIMULATOR == 'neuron':
    make_Vpanel()
else:
    run_sim()


pyNN.end()