# Copied from stdpsynfire_noNMDA_synstim_vartau.py. Otherwise the same, but this version saves the full weight history of every STDP synapse, so the result files require a lot of disk space.
import os
import sys
import time

from brian2 import *
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
#prefs.codegen.target = 'cython'

start_scope()
timenow = time.time()


def _cli_arg(position, default, convert=float):
    """Return sys.argv[position] parsed with *convert*, or *default* if it was not given.

    Positional command-line arguments override the defaults in a fixed order:
    argv[1] = N, argv[2] = Lmax, ..., argv[14] = myseed.
    """
    if len(sys.argv) <= position:
        return default
    return convert(sys.argv[position])


def _as_int(text):
    """Parse *text* as a float first (so '10.0' is accepted), then truncate to int."""
    return int(float(text))


#         name          = _cli_arg(position, default, parser)
N              = _cli_arg(1, 10, _as_int)      # How many neurons per population
Lmax           = _cli_arg(2, 30, _as_int)      # Length of the single long synfire chain
stim_duration  = _cli_arg(3, 12.0)             # Duration of stimulus in seconds
tot_duration   = _cli_arg(4, 16.0)             # Duration of the whole simulation in seconds
stim_freq_Hz   = _cli_arg(5, 2)                # Stimulus rate (Hz)
g_EE           = _cli_arg(6, 12.0)             # Synapse weights E->E
g_EI           = _cli_arg(7, 20.0)             # Synapse weights E->I
g_IE           = _cli_arg(8, 20.0)             # Synapse weights I->E
A_plus_nS      = _cli_arg(9, 0.05)             # STDP step for pre->post pairs as the inter-spike interval approaches 0 (nS)
A_minus_nS     = _cli_arg(10, 0.06)            # STDP step for post->pre pairs as the inter-spike interval approaches 0 (nS)
w_init_nS      = _cli_arg(11, 0.05)            # Initial STDP synaptic weight (nS)
A_tau_ms       = _cli_arg(12, 20, _as_int)     # Decay constant of the STDP traces (ms)
doSave         = _cli_arg(13, 1, _as_int)      # Nonzero: build the weight figure and write the .mat results file
myseed         = _cli_arg(14, 1, _as_int)      # Random number seed

seed(myseed)
# Output-file suffixes, only added when a non-default seed / trace time constant is used.
seed_addition = '' if myseed == 1 else '_seed' + str(myseed)
tau_addition = '' if A_tau_ms == 20 else '_tau' + str(A_tau_ms)
# Weights scale inversely with the population size N.
w_max = g_EE * 10 / N * nS
w_init = w_init_nS * nS

# ----------------------------
# Network parameters
# ----------------------------

stim_freq = stim_freq_Hz * Hz            # Rhythmic drive frequency (2 Hz by default)
T_stim = 1.0 / stim_freq                 # Interval between cycles
sim_duration = tot_duration * second
drive_off_time = stim_duration * second  # Stimulus ends here; marked on the weight plot

# neuron model
C_m = 500 * pF        # membrane capacitance
g_L = 10 * nS         # leak conductance -> tau_m = C_m / g_L = 50 ms
El = -65*mV           # leak reversal potential
Vr = -65*mV           # reset potential
Vth = -50*mV          # spike threshold

# Conductance-based synapse parameters
E_exc = 0*mV
E_inh = -80*mV
tau_exc = 5*ms
tau_inh = 10*ms

eqs = '''
dv/dt = (g_exc*(E_exc - v) + g_inh*(E_inh - v) + g_L*(El - v)) / C_m : volt (unless refractory)
dg_exc/dt = -g_exc / tau_exc : siemens
dg_inh/dt = -g_inh / tau_inh : siemens
'''


def _make_population(size, name):
    """Create a group of *size* conductance-based LIF neurons using the model `eqs`.

    The initial membrane potential of each neuron is drawn independently as
    El + N(0, 1) * 2 mV. The value is a Brian2 string expression, so it is
    evaluated once per neuron. Writing it without quotes would call `randn()`
    once in Python and give every neuron of the group the same start value.
    """
    group = NeuronGroup(size, eqs, threshold='v>Vth', reset='v=Vr',
                        refractory=3*ms, method='euler', name=name)
    group.v = 'El + randn()*2*mV'
    return group


# ----------------------------
# Populations
# ----------------------------
base = _make_population(N, 'base')

# ---------- Inhibitory population for base ----------
# One inhibitory neuron per four excitatory ones (integer count), at least one.
N_inh = max(1, N // 4)
inhbase = _make_population(N_inh, "inhbase")

# Create one long chain: list of excitatory subpopulations, each paired with its
# own inhibitory pool. Creation order (pool, inh pool, pool, ...) is kept fixed.
chain = []
inhchain = []
for i in range(Lmax):
    chain.append(_make_population(N, f'chain_pool_{i}'))
    inhchain.append(_make_population(N_inh, f'inhchain_pool_{i}'))

# ----------------------------
# External rhythmic stimulus
# ----------------------------
# Every one of the N stimulus channels fires at 0, 1/f, 2/f, ... up to stim_duration.
stim_times = np.arange(0.0, float(stim_duration), float(1/stim_freq)) * second
spikes_per_channel = len(stim_times)
stim = SpikeGeneratorGroup(N,
                           indices=np.repeat(np.arange(N), spikes_per_channel),
                           times=np.tile(stim_times, N),
                           name='stim')
A_stim = g_EE * 10/N * nS
S_stim = Synapses(stim, base, on_pre='g_exc_post += A_stim', name='stim_to_base', method='euler', delay=5*ms)
S_stim.connect(p=0.5)

# ----------------------------
# Feedforward connections
# ----------------------------
w_ff = g_EE * 10/N * nS
delay_per_hop = 5 * ms

on_pre_ee = '''
g_exc_post += w_ff
'''

# Hops of the chain in connection order: base -> chain[0], then chain[k] -> chain[k+1].
ff_hops = [(base, chain[0], 'base_to_chain0')]
ff_hops.extend((chain[k], chain[k + 1], f'chain_{k}_to_{k + 1}') for k in range(Lmax - 1))

ff_syns = []
for source, target, hop_name in ff_hops:
    hop = Synapses(source, target, on_pre=on_pre_ee, delay=delay_per_hop, name=hop_name, method='euler')
    hop.connect(p=0.5)
    ff_syns.append(hop)

on_pre_ei = '''
g_exc_post += w_ei
'''

on_pre_ie = '''
g_inh_post += w_ie
'''

# E -> I and I -> E synapses (base <-> inh); both scaled inversely with N.
w_ei = g_EI * 10/N * nS   # excitatory conductance onto inhibitory cells
w_ie = g_IE * 10/N * nS   # inhibitory conductance onto excitatory cells
S_ei = Synapses(base, inhbase, on_pre=on_pre_ei, delay=5*ms, name='E_to_I')
S_ei.connect(p=0.5)
S_ie = Synapses(inhbase, base, on_pre=on_pre_ie, delay=5*ms, name='I_to_E')
S_ie.connect(p=0.8)

# Same local E <-> I loop for every chain pool and its inhibitory partner.
ei_syns, ie_syns = [], []
for k, (exc_pool, inh_pool) in enumerate(zip(chain, inhchain)):
    exc_to_inh = Synapses(exc_pool, inh_pool, on_pre='g_exc_post += w_ei', delay=5*ms, name=f'chain_{k}_to_I', method='euler')
    exc_to_inh.connect(p=0.5)
    ei_syns.append(exc_to_inh)

    inh_to_exc = Synapses(inh_pool, exc_pool, on_pre='g_inh_post += w_ie', delay=5*ms, name=f'I_to_chain_{k}', method='euler')
    inh_to_exc.connect(p=0.8)
    ie_syns.append(inh_to_exc)


# ----------------------------
# STDP feedback connections: from each chain pool -> base
# ----------------------------
tau_pre = A_tau_ms*ms
tau_post = A_tau_ms*ms
A_plus = A_plus_nS * nS
A_minus = A_minus_nS * nS

# Pair-based additive STDP with exponentially decaying traces (units: siemens).
#   pre  trace: +A_plus  on every presynaptic (chain pool) spike
#   post trace: +A_minus on every postsynaptic (base) spike
# On a presynaptic spike the weight is DECREASED by the post trace
# (post-before-pre -> depression); on a postsynaptic spike it is INCREASED by
# the pre trace (pre-before-post -> potentiation). Weights stay in [0, w_max].
# Bug fix: on_pre previously ADDED the (positive) post trace, so A_minus
# potentiated instead of depressing and weights could only grow.
# NOTE: '(clock-driven)' integrates the traces of every synapse at each time
# step; '(event-driven)' would be cheaper but changes the numerics slightly.
stdp_model = '''
w : siemens
dpre/dt = -pre/tau_pre : siemens (clock-driven)
dpost/dt = -post/tau_post : siemens (clock-driven)
'''
stdp_on_pre = '''
g_exc_post += w
pre += A_plus
w = clip(w - post, 0*nS, w_max)
'''
stdp_on_post = '''
post += A_minus
w = clip(w + pre, 0*nS, w_max)
'''

stdp_syns = []
for i, pool in enumerate(chain):
    S = Synapses(pool, base, model=stdp_model, on_pre=stdp_on_pre, on_post=stdp_on_post,
                 name=f'stdp_chain{i}_to_base', method='euler')
    S.connect(p=0.5)
    S.w = w_init
    stdp_syns.append(S)

# ----------------------------
# Monitors
# ----------------------------
# Spike monitors, in order: base, chain pools 0..Lmax-1, inhbase.
spike_monitors = [SpikeMonitor(base, record=True)]
spike_monitors += [SpikeMonitor(pool, record=True) for pool in chain]
spike_monitors.append(SpikeMonitor(inhbase, record=True))  # monitor for inhibitory population

# One weight monitor per STDP group: every synapse, sampled every 50 ms.
weight_monitors = [StateMonitor(syn, 'w', record=True, dt=50*ms) for syn in stdp_syns]

# Voltage of neuron 0, in order: base, chain pools 0..Lmax-1, inhbase.
voltage_monitors = [StateMonitor(base, 'v', record=0)]
voltage_monitors += [StateMonitor(pool, 'v', record=0) for pool in chain]
voltage_monitors.append(StateMonitor(inhbase, 'v', record=0))  # voltage monitor for inh

net = Network(
    base,                    # base population
    *chain,                  # chain populations
    inhbase,                 # base inhibitory population
    *inhchain,               # inhibitory chain populations
    S_ei, S_ie,              # synapses to and from the base inhibitory population
    stim, S_stim,            # stimulus and the stimulus synapses
    *ff_syns,                # feed-forward chain synapses
    *ie_syns,                # chain inhibitory -> excitatory synapses
    *ei_syns,                # chain excitatory -> inhibitory synapses
    *stdp_syns,              # feedback STDP synapses
    *spike_monitors,         # all spike monitors
    *voltage_monitors,       # all v monitors
    *weight_monitors,        # all weight monitors
)
# ----------------------------
# Simulation
# ----------------------------
setup_seconds = time.time() - timenow
print(f"Initialization done in {setup_seconds} seconds")
timenow = time.time()

net.run(sim_duration)
run_seconds = time.time() - timenow
print(f"Whole simulation done in {run_seconds} seconds")

# ----------------------------
# Plot results
# ----------------------------
f, axarr = subplots(2, 1)
base_spikes = spike_monitors[0]

# Top panel: raster of the base population over the whole run.
axarr[0].plot(base_spikes.t/second, base_spikes.i, '.', markersize=2)
axarr[0].set_title('Base population spikes')
axarr[0].set_xlabel('Time (s)')
axarr[0].set_ylabel('Neuron index')

# Insets: base raster zoomed to the first 100 ms, and neuron-0 voltage traces
# over the first second.
axnew = [f.add_axes([0.5, 0.65, 0.2, 0.25]),
         f.add_axes([0.78, 0.65, 0.2, 0.25])]
axnew[0].plot(base_spikes.t/second, base_spikes.i, '.', markersize=2)
axnew[0].set_xlim([0, 0.1])
axnew[1].plot(voltage_monitors[0].t/second, (voltage_monitors[0].v/mV)[0], 'r-', lw=0.3, label='base')
# voltage_monitors is [base, chain pools 0..Lmax-1, inhbase], so index 5 is a
# chain pool only when Lmax >= 5. For shorter chains it would be inhbase or
# raise IndexError after the whole simulation has already run.
if Lmax >= 5:
    axnew[1].plot(voltage_monitors[5].t/second, (voltage_monitors[5].v/mV)[0], 'm-', lw=0.3, label='pool5')
axnew[1].set_xlim([0, 1])

# Bottom panel: base population, then every chain pool, stacked with a
# 2-row gap between consecutive populations.
axarr[1].plot(base_spikes.t/second, base_spikes.i, '.', markersize=1.5)
offset = N + 2
for pool_spikes in spike_monitors[1:Lmax + 1]:  # chain pool monitors only
    axarr[1].plot(pool_spikes.t/second, pool_spikes.i + offset, '.', markersize=1.5)
    offset += N + 2
axarr[1].set_xlabel('Time (s)')
axarr[1].set_ylabel('Chain pools (stacked)')
axarr[1].set_title('Synfire chain activity (subset)')
tight_layout()

if doSave:
    # Mean weight of each chain-pool -> base STDP group over time.
    f, axarr = subplots(1, 1)
    for i, S in enumerate(stdp_syns):
        if len(S.w):  # skip groups where connect(p=0.5) created no synapses
            mean_w = np.mean(weight_monitors[i].w, axis=0)  # average over synapses per sample
            axarr.plot(weight_monitors[i].t/second, mean_w, label=f'Pool {i}')
            try:
                axarr.text(weight_monitors[i].t[-1]/second, mean_w[-1], 'Pool '+str(i), fontsize=6, ha='right', va='bottom')
            except Exception as err:
                # Labelling is best-effort: report it and carry on, so the
                # results below are still saved.
                print('Problem text weight '+str(i)+': '+repr(err))
    axarr.axvline(float(drive_off_time/second), color='k', linestyle='--', label='Drive off')
    axarr.set_xlabel('Time (s)')
    axarr.set_ylabel('Mean feedback weight')
    axarr.set_title('Evolution of feedback synapse weights (pool#base)')

    # Make sure the output directory exists, so a long run is not lost at the very end.
    out_dir = "synfirefiles"
    os.makedirs(out_dir, exist_ok=True)
    out_name = (f"stdpsynfire_synstim_allweights_N{N}_L{Lmax}_T{stim_duration}_{tot_duration}"
                f"_{stim_freq_Hz}Hz_gEE{g_EE}_gEI{g_EI}_gIE{g_IE}"
                f"_A{A_plus_nS}_{A_minus_nS}_{w_init_nS}{tau_addition}{seed_addition}.mat")

    # 'spikes'  = [spike times (s) per monitor, neuron indices per monitor],
    #             monitors ordered as base, chain pools, inhbase.
    # 'weights' = [sample times (s), per-STDP-group weight matrices in pS].
    spike_times = [mon.t/second for mon in spike_monitors]
    spike_indices = [mon.i[:] for mon in spike_monitors]
    weight_traces = [mon.w/psiemens for mon in weight_monitors]
    scipy.io.savemat(os.path.join(out_dir, out_name),
                     {'spikes': [spike_times, spike_indices],
                      'weights': [weight_monitors[0].t/second, weight_traces]})