"""
init.py

Entry-point script that builds and runs a NetPyNE-based cortical network model for studying epilepsy.


Usage:
    python init.py # Run simulation, optionally plot a raster


MPI usage:
    mpiexec -n 4 nrniv -python -mpi init.py


Contributors: salvadordura@gmail.com
"""

import matplotlib; matplotlib.use('Agg')  # to avoid graphics error in servers

from netpyne import sim
import random as random
import time
random.seed(1072) 
import os
import pickle
#import math

# read parameter files
# Read the simulation config and network parameters. Defaults are cfg_net.py and
# netParams_net.py; both can be overridden on the command line.
cfg, netParams = sim.readCmdLineArgs(simConfigDefault='cfg_net.py', netParamsDefault='netParams_net.py')

# create network, run simulation and analyze output
#sim.createSimulateAnalyze(netParams, cfg)
# NOTE: the one-shot call above is replaced by the explicit step-by-step
# sequence below. Cell initial voltages and gap-junction weights are modified
# after construction and before the run, presumably the reason for the split.

# NOTE(review): timeout(0) presumably disables NEURON ParallelContext's
# hang-detection timeout so long MPI runs are not aborted — confirm.
sim.pc.timeout(0)
sim.initialize(netParams, cfg)  # create network object and set cfg and net params
sim.net.createPops()                  # instantiate network populations
sim.net.createCells()                 # instantiate network cells based on defined populations
sim.net.defineCellShapes()
sim.net.connectCells()                # create connections between cells based on params
sim.net.addStims()                    # add external stimulation to cells (IClamps etc)
sim.net.addRxD()                      # add reaction-diffusion (RxD)
sim.setupRecording()             # setup variables to record for each cell (spikes, V traces, etc)

# Give every section of every cell on this rank a random initial potential,
# drawn uniformly from [-90.0, +10.0] (mV assumed, NEURON's voltage unit).
# cell.initV() then applies it.
#
# Each cell draws from its own generator, seeded from VINIT_SEED and the
# cell's gid, rather than from the module-level `random` stream. That stream is
# seeded with the same value on every MPI rank. Drawing from it in local-cell
# order would give the k-th cell on every rank identical voltages, and the
# initial state would change whenever the rank count changes. A per-gid seed
# (string seeds are hashed deterministically by `random`) makes each cell's
# voltages independent of how cells are distributed across ranks.
VINIT_SEED = 1072
for cell in sim.net.cells:
    cell_rng = random.Random('{}:{}'.format(VINIT_SEED, cell.gid))
    for sec in cell.secs.values():
        sec['vinit'] = cell_rng.uniform(-90.0, +10.0)
    cell.initV()


# simConfig.filename = 'rectified'

#Ncells_E2_to_E2 = [n for n in range(len(sim.net.cells)) if sim.net.cells[n].tags['pop'] in ['E2','E2']]
#Ncells_E2_to_E5 = [n for n in range(len(sim.net.cells)) if sim.net.cells[n].tags['pop'] in ['E2','E5']]
#Ncells_E5_to_E2 = [n for n in range(len(sim.net.cells)) if sim.net.cells[n].tags['pop'] in ['E5','E2']]
#Ncells_E5_to_E5 = [n for n in range(len(sim.net.cells)) if sim.net.cells[n].tags['pop'] in ['E5','E5']]

#Ncells_E2_to_I2 = [n for n in range(len(sim.net.cells)) if sim.net.cells[n].tags['pop'] in ['E2','I2']]
#Ncells_E5_to_I2 = [n for n in range(len(sim.net.cells)) if sim.net.cells[n].tags['pop'] in ['E5','I2']]

#Ncells_I2_to_I2 = [n for n in range(len(sim.net.cells)) if sim.net.cells[n].tags['pop'] in ['I2','I2']]

#Ncells_I2_to_E2 = [n for n in range(len(sim.net.cells)) if sim.net.cells[n].tags['pop'] in ['I2','E2']]
#Ncells_I2_to_E5 = [n for n in range(len(sim.net.cells)) if sim.net.cells[n].tags['pop'] in ['I2','E5']]


def adjust_gap_junction_configs(syn_type, cells=None):
    """Zero the weight of the 'pre'-side mechanisms of *syn_type* gap junctions.

    For every cell, collects the indices of entries in ``cell.conns`` whose
    ``synMech`` equals *syn_type* and whose ``gapJunction`` flag is ``'pre'``
    or ``'post'``. It prints both index lists, then sets ``weight = 0.0`` on the
    ``hObj`` of the soma synMech at each 'pre' index. Presumably this silences
    the presynaptic half so only the 'post' side acts — confirm against the
    model.

    Args:
        syn_type: synMech label to match, e.g. ``'gradsynAMPA'``.
        cells: cells to process. Defaults to ``sim.net.cells``.
    """
    if cells is None:
        cells = sim.net.cells

    for n, cell in enumerate(cells):
        conn_as_pre = []
        conn_as_post = []
        # Single pass over the cell's conns. `.get` is used because NetPyNE only
        # stores a 'gapJunction' entry on conns that define one; indexing with
        # ['gapJunction'] would raise KeyError on an ordinary conn that happens
        # to use the same synMech label.
        for nc, conn in enumerate(cell.conns):
            if conn.get('synMech') != syn_type:
                continue
            side = conn.get('gapJunction')
            if side == 'pre':
                conn_as_pre.append(nc)
            elif side == 'post':
                conn_as_post.append(nc)

        print('Cell '+str(n))
        print(conn_as_pre, conn_as_post)

        if conn_as_pre:
            # NOTE(review): `nc` indexes cell.conns but is used to index the
            # soma synMechs list. This is only correct if every conn on the cell
            # created exactly one soma synMech, in the same order. Confirm for
            # this model, or the wrong mechanism is silenced.
            soma_mechs = cell.secs['soma']['synMechs']
            for nc in conn_as_pre:
                soma_mechs[nc]['hObj'].weight = 0.0

        # Disabled alternative: retune the 'post'-side mechanisms instead.
        # esyn, weight, vth and tau must be defined before re-enabling it.
        # for nc in conn_as_post:
        #     mech = cell.secs['soma']['synMechs'][nc]['hObj']
        #     mech.esyn = esyn
        #     mech.vslope = 10.0
        #     mech.weight = weight
        #     mech.vth = vth
        #     mech.tau = tau



# Silence the 'pre'-side mechanism of both graded-synapse gap-junction types,
# AMPA first and then GABA.
for _syn_label in ('gradsynAMPA', 'gradsynGABA'):
    adjust_gap_junction_configs(_syn_label)
del _syn_label  # keep the module namespace as it was

# Interval
#sim.runSimIntervalSaving(100000)
#sim.pc.barrier()
#sim.gatherData()
#sim.pc.barrier()
#sim.saveData()
#sim.pc.barrier()
#sim.analysis.plotData()



# Normal
#sim.simulate()
#sim.pc.barrier()
#sim.analyze() 
#sim.pc.barrier()


# Nodes 
# sim.runSim()
# sim.runSimWithIntervalFunc(1000000, sim.saveDataInNodes)
# Run the simulation, calling sim.saveDataInNodes after every 500000 units of
# simulated time so partial results reach disk during long runs.
# NOTE(review): NetPyNE intervals are presumably in ms — confirm.
sim.runSimWithIntervalFunc(500000, sim.saveDataInNodes)

print("File names!\n")

# Wait for every rank to finish the run before touching the output folder.
sim.pc.barrier()
print("A. sim rank is "+str(sim.rank)+"\n")
# Rank 0 alone moves the interval snapshots aside, presumably so the final
# save below writes into a fresh ./output. This assumes saveDataInNodes writes
# under ./output — confirm against cfg.saveFolder.
# NOTE(review): os.rename raises if the target already exists (any target on
# Windows, a non-empty directory on POSIX). Move results from a previous run
# away before rerunning.
if sim.rank==0:
    os.rename('./output','./node_interval_results')
sim.pc.barrier()
# Final per-node save of the complete run.
sim.saveDataInNodes()
sim.pc.barrier()

# Rank 0 renames the final per-node output before it is gathered.
if sim.rank==0:
    os.rename('./output','./data_net_data')
print("B. sim rank is "+str(sim.rank)+"\n")

# Gather the per-node files. This presumably populates sim.allSimData, which
# rank 0 pickles below.
# NOTE(review): the folder this reads from is not visible here — confirm it
# matches the renames above.
sim.pc.barrier()
sim.gatherDataFromFiles()
sim.pc.barrier()

print("C. sim rank is "+str(sim.rank)+"\n")
# Rank 0 restores the ./output name, then pickles the gathered data as
# {'simData': sim.allSimData} into ./output/data_net_data.pkl.
if sim.rank==0:
    os.rename('./data_net_data','./output')
    x = {'simData':sim.allSimData}
    #print(x['simData'])
    with open("./output/data_net_data.pkl", 'wb') as fileObj:
        pickle.dump(x,fileObj)

# Wait for the pickle to be written, then run the plots configured in cfg.
sim.pc.barrier()
sim.analysis.plotData()
print("D. sim.rank is "+str(sim.rank)+"\n")

#if sim.rank!=0:
#    time.sleep(20)

# NOTE(review): quit() is a helper injected by the `site` module for
# interactive use. sys.exit() is the portable equivalent if `site` may be
# unavailable under nrniv.
quit()