################################################################################
# Source class / functions for simulating networks of spiking neurons
# with different degrees of specific connectivity
#
# Reference: Sadeh, Clopath and Rotter (PLOS ONE, 2015).
# "Processing of Feature Selectivity in Cortical Networks with Specific Connectivity".
#
# Author: Sadra Sadeh <s.sadeh@ucl.ac.uk> // Created: 2015-2016
################################################################################

import numpy as np
import pylab as pl
import time, os, sys, pickle
from scipy.stats import norm
from imp import reload
import defaultParams; reload(defaultParams); from defaultParams import *
import nest

################################################################################
## Class for simulating a network of leaky integrate-and-fire (LIF) neurons
################################################################################
class SpecNet(object):

      # --- initial parameters
      def __init__(self):
          '''
          creates the network object

          nothing is stored on the instance here; the methods of the class
          take what they need as arguments or read module-level parameters
          '''
          return None


      # --- generating random connectivity
      def random_connectivity(self, eps_exc=0.1, eps_inh=0.1,
                              n_exc=None, n_inh=None, seed=1234):
          '''
          generates random connectivity with defined connection probabilities
          and with fixed in-degrees

          Neurons are identified by 1-based ids: excitatory neurons are
          1 .. n_exc, inhibitory neurons are n_exc+1 .. n_exc+n_inh.
          Entry k of each returned list belongs to the neuron with id k+1
          and never contains that id (no self-connections).

          eps_exc: probability of (E --> E,I) connections
          eps_inh: probability of (I --> E,I) connections
          n_exc: number of excitatory neurons (default: module-level NE)
          n_inh: number of inhibitory neurons (default: module-level NI)
          seed: seed given to numpy's global random generator before drawing;
                pass None to leave the generator untouched and obtain a new
                connectivity on each call

          returns: (Cin_exc, Cin_inh), two lists with one entry per neuron;
                   each entry is the list of ids of the excitatory /
                   inhibitory inputs of that neuron
          '''

          print("Generating connectivitiy matrix")

          if n_exc is None:
              n_exc = NE
          if n_inh is None:
              n_inh = NI
          # total population size; the module-level N is assumed to be
          # NE + NI, which is how the populations are created and indexed
          # in stimulate_network
          n_total = n_exc + n_inh

          # exc pre-synaptic input per neuron
          CE = int(eps_exc * n_exc)
          # inh pre-synaptic input per neuron
          CI = int(eps_inh * n_inh)

          # reseeding the global generator makes the connectivity reproducible
          if seed is not None:
              np.random.seed(seed)

          # fixed in-degree
          Cin_exc, Cin_inh = [], []
          # iterate over the 1-based id of the target neuron so that the id
          # removed from the candidate list is really the target itself
          for gid in range(1, n_total + 1):
              target_is_exc = gid <= n_exc

              # excitatory candidates
              if target_is_exc:
                  zex = list(range(1, gid)) + list(range(gid + 1, n_exc + 1))
              else:
                  zex = list(range(1, n_exc + 1))
              np.random.shuffle(zex)
              Cin_exc.append(zex[0:CE])

              # inhibitory candidates
              if target_is_exc:
                  zin = list(range(n_exc + 1, n_total + 1))
              else:
                  zin = (list(range(n_exc + 1, gid))
                         + list(range(gid + 1, n_total + 1)))
              np.random.shuffle(zin)
              Cin_inh.append(zin[0:CI])

          return Cin_exc, Cin_inh


      # --- initiating NEST
      def NEST_initiate(self, dt=.1, n_cores=4):
          '''
          resets the NEST kernel and configures it for a new simulation

          dt: time resolution of simulations
          n_cores: number of local threads requested from NEST
          '''

          # start from a clean kernel
          nest.ResetKernel()

          # seed drawn from numpy's global generator and written to node 0
          # NOTE(review): the seed is set before 'local_num_threads' is
          # changed below; confirm against the NEST documentation that the
          # thread change does not reset it and that a single seed is enough
          seed = int(np.random.uniform(0, 12345))
          seed_status = [{'rng_seeds': [seed]}]
          nest.SetStatus([0], seed_status)

          # kernel-wide settings
          kernel_settings = {"resolution": dt,
                             "print_time": True,
                             "overwrite_files": True,
                             'local_num_threads': n_cores}
          nest.SetKernelStatus(kernel_settings)

          # switch off multapses for RandomConvergentConnect
          nest.sr("/RandomConvergentConnect << /allow_multapses false >> SetOptions")


      # --- stimulating the network in response to one stimulus orientation
      def stimulate_network(self, stim_deg, con_exc=(), con_inh=(), stim_id=0, \
                            mem_pot_rec=1, fs_conn=(0,0,0,0), fs_mod=.1):
          '''
          builds the network in NEST and simulates its response to one
          stimulus orientation; the recording devices write their data to
          files, and a short summary is printed at the end

          stim_deg: stimulus orientation; it enters np.cos together with
                    po_init, so it must be given in the units of po_init
          con_exc: for every neuron (in the order exc, then inh) the list of
                   1-based ids of its excitatory inputs
          con_inh: same as con_exc, for the inhibitory inputs
          stim_id: tag appended to the labels of the recording devices
                   (converted with str, so numbers are accepted)
          mem_pot_rec: if true, voltmeters record from the first n_smpl
                       exc / inh neurons and from the replicate neurons
          fs_conn: defines whether (E,I) to (E,I) synapses are FS (1) or
                   not (0); order: [EtoE, EtoI, ItoE, ItoI]
          fs_mod: modulation depth of the feature-specific weights

          raises: ValueError if con_exc or con_inh has fewer than N entries
          '''

          fs_ee, fs_ei, fs_ie, fs_ii = fs_conn

          # fail before touching NEST instead of with an IndexError halfway
          # through wiring the network
          if len(con_exc) < N or len(con_inh) < N:
              raise ValueError("con_exc and con_inh need one entry per neuron "
                               "(N = %d)" % N)

          # labels are built by string concatenation
          stim_tag = str(stim_id)

          # --- Start
          ti = time.time()

          self.NEST_initiate()

          print("Network ...")

          # spiking neurons
          nest.SetDefaults("iaf_psc_delta", neuron_params_delta)
          nodes_ex = nest.Create("iaf_psc_delta", NE)
          nodes_in = nest.Create("iaf_psc_delta", NI)

          nodes_all = nodes_ex + nodes_in

          # replicate neurons
          nest.SetDefaults("iaf_psc_delta", neuron_params_delta_rep)
          nodes_ex_rep_excInp = nest.Create("iaf_psc_delta", n_smpl)
          nodes_ex_rep_inhInp = nest.Create("iaf_psc_delta", n_smpl)

          nodes_in_rep_excInp = nest.Create("iaf_psc_delta", n_smpl)
          nodes_in_rep_inhInp = nest.Create("iaf_psc_delta", n_smpl)

          nodes_all_rep_excInp = nodes_ex_rep_excInp + nodes_in_rep_excInp
          nodes_all_rep_inhInp = nodes_ex_rep_inhInp + nodes_in_rep_inhInp

          # vector of total input to all neurons
          input_rate = r_base *(1.+ input_mod* np.cos(2*(stim_deg - po_init)) )

          noise = nest.Create("poisson_generator", N)

          # -- define synapses
          synapse_models = [("ext", J_ext),      # external
                            ("ffw", J_ffw),      # feedforward
                            ("ex-ex", J_ee),     # E-to-E
                            ("ex-in", J_ei),     # E-to-I
                            ("in-ex", J_ie),     # I-to-E
                            ("in-in", J_ii)]     # I-to-I
          for syn_name, syn_weight in synapse_models:
              nest.CopyModel("static_synapse", syn_name,
                             {"weight":syn_weight, "delay":delay})

          # -- devices to record

          # spike detector: one pair (transient, stationary) per trial
          # NOTE(review): the recording windows advance by simtime per trial,
          # while every trial below simulates simtime + t_trans; confirm which
          # of the two is the intended trial length
          sp_all_trans = nest.Create("spike_detector", trial_no)
          sp_all = nest.Create("spike_detector", trial_no)
          for trial in range(trial_no):
              # transient
              nest.SetStatus([sp_all_trans[trial]],
                          {"label":"spikes-all-trans-"+stim_tag+'-tr'+str(trial),
                           "withgid":True, "withtime":True,
                           "to_file":True, "to_memory":False,
                           "start": simtime*trial + 0., "stop": simtime*trial + t_trans })
              nest.ConvergentConnect(nodes_all, [sp_all_trans[trial]], model="ext")
              # stationary
              nest.SetStatus([sp_all[trial]],
                        {"label":"spikes-all-"+stim_tag+'-tr'+str(trial),
                         "withgid":True, "withtime":True,
                         "to_file":True, "to_memory":False,
                         "start": simtime*trial + t_trans, "stop": simtime*trial + simtime })
              nest.ConvergentConnect(nodes_all, [sp_all[trial]], model="ext")

          # volt meter
          if mem_pot_rec:
              # the order of this list fixes the creation order of the devices
              vm_plan = [
                  # stationary
                  ("vm-exc-", nodes_ex[0:n_smpl]),
                  ("vm-inh-", nodes_in[0:n_smpl]),
                  # replicates (measuring input)
                  ("vm-exc-excInp-", nodes_ex_rep_excInp[0:n_smpl]),
                  ("vm-exc-inhInp-", nodes_ex_rep_inhInp[0:n_smpl]),
                  ("vm-inh-excInp-", nodes_in_rep_excInp[0:n_smpl]),
                  ("vm-inh-inhInp-", nodes_in_rep_inhInp[0:n_smpl]),
              ]
              for vm_prefix, vm_targets in vm_plan:
                  vm = nest.Create("voltmeter")
                  nest.SetStatus(vm, {"label":vm_prefix+stim_tag,
                              "to_file":True, "to_memory":False})
                  nest.DivergentConnect(vm, vm_targets)

          # -- Connecting the network
          print("Connecting the network ...")

          # every neuron gets its own Poisson source with its own rate
          for idx, node in enumerate(nodes_all):
              nest.SetStatus([noise[idx]], {'rate':input_rate[idx]})
              nest.Connect([noise[idx]], [node], model="ffw")

          ext_inp = nest.Create("poisson_generator")
          nest.SetStatus(ext_inp, {"rate":r_ext})
          nest.DivergentConnect(ext_inp, nodes_all, model="ext")

          # -- feature-specific modulation of the weights
          # con_* hold 1-based ids, po_init is indexed from 0, hence the -1;
          # the delay lists are sized from the weights so that they always
          # match the connectivity that was passed in
          # NOTE(review): only the first n_smpl entries of the replicate
          # lists (the excitatory replicates) are wired; the inhibitory
          # replicates get no connections in this method -- confirm
          for ne in range(NE):
              # exc --> exc
              dth = po_init[np.array(con_exc[ne])-1] - po_init[ne]
              exc_list = (J_ee *(1.+ fs_ee *fs_mod * np.cos(2*dth))).tolist()
              exc_delays = len(exc_list)*[delay]
              nest.ConvergentConnect(con_exc[ne], [nodes_all[ne]], exc_list, exc_delays)
              if ne < n_smpl:
                  nest.ConvergentConnect(con_exc[ne], [nodes_all_rep_excInp[ne]], exc_list, exc_delays)
              # inh --> exc
              dth = po_init[np.array(con_inh[ne])-1] - po_init[ne]
              inh_list = (J_ie *(1+fs_ie*fs_mod *np.cos(2*dth))).tolist()
              inh_delays = len(inh_list)*[delay]
              nest.ConvergentConnect(con_inh[ne], [nodes_all[ne]], inh_list, inh_delays)
              if ne < n_smpl:
                  nest.ConvergentConnect(con_inh[ne], [nodes_all_rep_inhInp[ne]], inh_list, inh_delays)

          for ni in range(NE, N):
              # exc --> inh
              dth = po_init[np.array(con_exc[ni])-1] - po_init[ni]
              exc_list = (J_ei *(1+fs_ei*fs_mod *np.cos(2*dth))).tolist()
              nest.ConvergentConnect(con_exc[ni], [nodes_all[ni]], exc_list, len(exc_list)*[delay])
              # inh --> inh
              dth = po_init[np.array(con_inh[ni])-1] - po_init[ni]
              inh_list = (J_ii *(1+fs_ii*fs_mod *np.cos(2*dth))).tolist()
              nest.ConvergentConnect(con_inh[ni], [nodes_all[ni]], inh_list, len(inh_list)*[delay])

          # --- running simulations

          print("Simulating the network ...")

          for trial in range(trial_no):
              print('# -- Trial # ', str(trial+1))
              nest.Simulate(simtime + t_trans)

          ts = time.time()
          sim_time = ts - ti
          # rate from the stationary detector of the first trial
          # NOTE(review): that detector is active from t_trans to simtime,
          # but the count is normalised by simtime -- confirm the window
          r_avg = nest.GetStatus([sp_all[0]], 'n_events')[0]/ (N*simtime/1000.)

          print('\n########################################')
          print('End of simulation.')
          print("Simulation time   : %.2f s" % sim_time)
          print("Average rate      : %.2f Hz" %  r_avg)
          print('######################################## \n')