'''
L2 neuron model
evoked activity with/without L1 PW neurons
No variation in functional connectivity
of inhibitory synapses between conditions
(i.e., activation times of inhibitory synapses
are replayed between conditions)

written by Robert Egger (robert.egger@tuebingen.mpg.de)
(c) 2013-2015 Max Planck Society
'''

import sys
import time
import os, os.path
import glob
import neuron
import single_cell_parser as scp
import single_cell_analyzer as sca
import numpy as np
hasMatplotlib = True
try:
    import matplotlib.pyplot as plt
except ImportError:
    hasMatplotlib = False
h = neuron.h

def evoked_activity_replay_inh_synapses(simName, cellName, ongoingUpName, ongoingDownName, evokedUpParamName, evokedDownParamName, nSweeps=2000):
    '''
    pre-stimulus ongoing activity
    and evoked activity
    effect of varying excitatory functional connectivity
    simulated by replaying exact synapse activation times
    of INH synapses while varying EXC functional
    connectivity between conditions

    The mode is selected from simName:
    - 'control1' in simName: every trial creates new ongoing activity
      (NetworkMapper.create_network) and an evoked network
      (NetworkMapper.create_saved_network). Activation data of synapse
      types whose name contains 'L1' are written to *_synapsesL1.csv,
      all other types to *_synapses.csv.
    - 'control2' in simName: non-'L1' synapse types are created as in
      control1, but 'L1' synapse types are reconnected from the
      *_synapsesL1.csv file with the same run index found below the
      matching control1 directory (simName with 'control2' replaced by
      'control1').
    NOTE(review): the docstrings speak of replaying inhibitory synapses,
    but the code selects replayed types by the substring 'L1' in the
    synapse type name -- confirm that these coincide.

    The first half of the sweeps uses the down-state parameter files, the
    second half the up-state ones. A trial is kept only if the somatic Vm
    stays below spikeThresh between tStim+15 ms and tStim+50 ms; otherwise
    the trial is repeated. All traces, Vm summary statistics (see the
    sca.compute_* calls), the parameter files and, if matplotlib is
    available, PDF plots are written to simName/<YYYYmmdd-HHMM>/.

    simName: output base directory; must contain 'control1' or 'control2'
    cellName: neuron parameter file
    ongoingUpName, ongoingDownName: ongoing network parameter files
        (up / down state)
    evokedUpParamName, evokedDownParamName: evoked network parameter
        files (up / down state)
    nSweeps: number of accepted (sub-threshold) trials (default 2000)

    raises ValueError if simName contains neither 'control1' nor
        'control2', or if nSweeps < 1
    raises IOError (control2 mode) if no saved L1 synapse file exists
        for a run index
    '''
    # Fail before any expensive setup: with no recognised mode (or no
    # sweeps) no trial would run and the analysis below would have no
    # time axis to work with.
    if 'control1' not in simName and 'control2' not in simName:
        raise ValueError('simName must contain "control1" or "control2": %s' % simName)
    if nSweeps < 1:
        raise ValueError('nSweeps must be at least 1, got %d' % nSweeps)

    neuronParameters = scp.build_parameters(cellName)
    upParameters = scp.build_parameters(ongoingUpName)
    downParameters = scp.build_parameters(ongoingDownName)
    evokedUpNWParameters = scp.build_parameters(evokedUpParamName)
    evokedDownNWParameters = scp.build_parameters(evokedDownParamName)
    scp.load_NMODL_parameters(neuronParameters)
    scp.load_NMODL_parameters(upParameters)
    scp.load_NMODL_parameters(downParameters)
    scp.load_NMODL_parameters(evokedUpNWParameters)
    # load the down-state evoked mechanisms as well, like every other file
    scp.load_NMODL_parameters(evokedDownNWParameters)
    cellParam = neuronParameters.neuron
    paramUp = upParameters.network
    paramDown = downParameters.network
    paramEvokedUp = evokedUpNWParameters.network
    paramEvokedDown = evokedDownNWParameters.network

    cell = scp.create_cell(cellParam, scaleFunc=dendriteScalingUniform)

    uniqueID = str(os.getpid())
    dirName = simName
    if not simName.endswith('/'):
        dirName += '/'
    dirName += time.strftime('%Y%m%d-%H%M')
    if not os.path.exists(dirName):
        os.makedirs(dirName)

    vTraces = []
    tTraces = []

    tOffset = 100.0 # avoid numerical transients
    tStim = 200.0
    tStop = 250.0
    neuronParameters.sim.tStop = tStop
    dt = neuronParameters.sim.dt
    offsetBin = int(tOffset/dt + 0.5)

    spikeThresh = -38.0 # Petersen, AS

    #===========================================================================
    # control runs: create trials with random activity;
    # save exact synapse activation times for replay;
    # save L1 activation times in separate file for analysis
    #===========================================================================
    if 'control1' in simName:
        nRun = 0
        while nRun < nSweeps:
            # 50% down states, 50% up states
            if nRun < 0.5*nSweeps:
                synParameters = paramDown
                synParametersEvoked = paramEvokedDown
            else:
                synParameters = paramUp
                synParametersEvoked = paramEvokedUp

            ongoingNW = scp.NetworkMapper(cell, synParameters, neuronParameters.sim)
            ongoingNW.create_network()
            evokedNW = scp.NetworkMapper(cell, synParametersEvoked, neuronParameters.sim)
            evokedNW.create_saved_network()

            accepted, t = _run_trial(cell, neuronParameters.sim, dt, nRun, nSweeps, tStim,
                                     spikeThresh, offsetBin, dirName, uniqueID, vTraces, tTraces)
            if accepted:
                nRun += 1

            cell.re_init_cell()
            cell.remove_synapses('all')
            ongoingNW.re_init_network()
            evokedNW.re_init_network()

            print('-------------------------------')

    #===========================================================================
    # Replay runs: L1 synapse activations are reconnected from the
    # *_synapsesL1.csv files of the matching control1 run; all other
    # synapse types are created anew for every trial.
    #===========================================================================
    else:
        controlBasePath = simName.replace('control2', 'control1')
        synInfoNames = []
        scan_directory(controlBasePath, synInfoNames, '_synapsesL1.csv')

        nRun = 0
        while nRun < nSweeps:
            # 50% down states, 50% up states
            if nRun < 0.5*nSweeps:
                synParameters, synParametersEvoked, synParametersEvokedL1 = \
                    _split_replay_parameters(paramDown, paramEvokedDown)
            else:
                synParameters, synParametersEvoked, synParametersEvokedL1 = \
                    _split_replay_parameters(paramUp, paramEvokedUp)

            synInfoName = _find_replay_file(synInfoNames, nRun)

            print('Replaying network activity from file %s' % synInfoName)
            ongoingNW = scp.NetworkMapper(cell, synParameters, neuronParameters.sim)
            ongoingNW.create_network()
            evokedNW = scp.NetworkMapper(cell, synParametersEvoked, neuronParameters.sim)
            evokedNW.create_saved_network()
            replayNW = scp.NetworkMapper(cell, synParametersEvokedL1)
            replayNW.reconnect_saved_synapses(synInfoName)

            accepted, t = _run_trial(cell, neuronParameters.sim, dt, nRun, nSweeps, tStim,
                                     spikeThresh, offsetBin, dirName, uniqueID, vTraces, tTraces)
            if accepted:
                nRun += 1

            cell.re_init_cell()
            cell.remove_synapses('all')
            replayNW.re_init_network()
            ongoingNW.re_init_network()
            evokedNW.re_init_network()

            print('-------------------------------')

    # The loops only terminate on an accepted trial, so t is the time
    # vector of the last accepted trial here.
    vTraces = np.array(vTraces)
    print('computing Vm STD and histogram')
    vStd = np.std(vTraces, axis=0)
    peakWindow, avgPeak = sca.compute_mean_psp_amplitude(vTraces, tStim=tStim-tOffset, dt=neuronParameters.sim.dt)
    windows, avgVmStd = sca.compute_vm_std_windows(vStd, tStim=tStim-tOffset, dt=neuronParameters.sim.dt)
    hist, bins = sca.compute_vm_histogram(vTraces)
    outPrefix = dirName + '/' + uniqueID
    scp.write_all_traces(outPrefix+'_vm_all_traces.csv', t[offsetBin:], vTraces)
    scp.write_sim_results(outPrefix+'_vm_std.csv', t[offsetBin:], vStd)
    scp.write_sim_results(outPrefix+'_vm_avg_psp.csv', peakWindow, avgPeak)
    scp.write_sim_results(outPrefix+'_vm_std_windows.csv', windows, avgVmStd)
    scp.write_sim_results(outPrefix+'_vm_hist.csv', hist, bins[:-1])

    print('writing simulation parameter files')
    neuronParameters.save(outPrefix+'_neuron_model.param')
    upParameters.save(outPrefix+'_network_model_upstate.param')
    downParameters.save(outPrefix+'_network_model_downstate.param')
    evokedUpNWParameters.save(outPrefix+'_network_model_evoked_upstate.param')
    evokedDownNWParameters.save(outPrefix+'_network_model_evoked_downstate.param')

    if hasMatplotlib:
        plt.figure()
        for tTrace, vTrace in zip(tTraces, vTraces):
            plt.plot(tTrace, vTrace, 'k')
        plt.xlabel('t [ms]')
        plt.ylabel('Vm [mV]')
        plt.savefig(outPrefix+'_all_traces.pdf')
        plt.close()
        plt.figure()
        plt.plot(tTraces[0], vStd, 'k')
        plt.xlabel('t [ms]')
        plt.ylabel('Vm STD [mV]')
        plt.savefig(outPrefix+'_vm_std.pdf')
        plt.close()

def _split_replay_parameters(paramOngoing, paramEvoked):
    '''
    Split ongoing and evoked network parameters for a replay (control2) trial.

    returns (ongoing, evoked, L1) dicts: synapse types whose name contains
    'L1' (from either input) go into the L1 dict; all others keep their
    ongoing/evoked role. Evoked entries are inserted first, ongoing ones
    second, so an ongoing 'L1' type overrides an evoked one of the same name.
    '''
    synParameters = {}
    synParametersEvoked, synParametersL1 = {}, {}
    for synType in paramEvoked.keys():
        if 'L1' in synType:
            synParametersL1[synType] = paramEvoked[synType]
        else:
            synParametersEvoked[synType] = paramEvoked[synType]
    for synType in paramOngoing.keys():
        if 'L1' in synType:
            synParametersL1[synType] = paramOngoing[synType]
        else:
            synParameters[synType] = paramOngoing[synType]
    # Release probability 1 for the replayed L1 types -- presumably so that
    # every saved activation is delivered; confirm. NOTE(review): these
    # entries look like references into the loaded parameter sets, so the
    # change would also show up in the saved *.param files.
    for synType in ('L1D1', 'L1D1_ongoing', 'L1D2_ongoing'):
        synParametersL1[synType].synapses.releaseProb = 1.0
    return synParameters, synParametersEvoked, synParametersL1

def _find_replay_file(synInfoNames, nRun):
    '''
    return the first saved L1 synapse file belonging to run index nRun;
    raise IOError if none exists
    '''
    # surrounding underscores keep e.g. run1000 from matching run10000
    nRunStr = '_run%04d_' % nRun
    for name in synInfoNames:
        if nRunStr in name:
            return name
    raise IOError('No saved L1 synapse file for run %04d found' % nRun)

def _run_trial(cell, simParam, dt, nRun, nSweeps, tStim, spikeThresh, offsetBin, dirName, uniqueID, vTraces, tTraces):
    '''
    Simulate one trial with the synapses currently attached to cell.

    If the somatic Vm stays below spikeThresh in [tStim+15 ms, tStim+50 ms),
    the trial is accepted: Vm and time traces from offsetBin on are appended
    to vTraces / tTraces, and synapse activation data are written with
    sca.compute_synapse_distances_times -- synapse types containing 'L1' to
    <dirName>/simulation_<uniqueID>_runNNNN_<state>_synapsesL1.csv, all other
    types (sorted) to the matching *_synapses.csv.

    returns (accepted, t) with t the recorded time vector as numpy array
    '''
    synTypes, synTypesL1 = [], []
    for synType in cell.synapses.keys():
        if 'L1' in synType:
            synTypesL1.append(synType)
        else:
            synTypes.append(synType)
    synTypes.sort()

    print('Testing evoked response properties run %d of %d' % (nRun+1, nSweeps))
    tVec = h.Vector()
    tVec.record(h._ref_t)
    startTime = time.time()
    scp.init_neuron_run(simParam)
    stopTime = time.time()
    print('NEURON runtime: %.2f s' % (stopTime - startTime))

    vmSoma = np.array(cell.soma.recVList[0])
    t = np.array(tVec)
    #=======================================================================
    # discard trials above threshold!
    #=======================================================================
    begin = int((tStim+15.0)/dt+0.5)
    end = int((tStim+50.0)/dt+0.5)
    maxV = np.max(vmSoma[begin:end])
    if not maxV < spikeThresh:
        print('Trial above threshold; running new trial')
        return False, t

    vTraces.append(np.array(vmSoma[offsetBin:]))
    tTraces.append(np.array(t[offsetBin:]))

    print('writing simulation results')
    fname = 'simulation_' + uniqueID + '_run%04d' % nRun
    if nRun < 0.5*nSweeps:
        fname += '_down_state'
    else:
        fname += '_up_state'
    synName = dirName + '/' + fname + '_synapses.csv'
    synNameL1 = dirName + '/' + fname + '_synapsesL1.csv'
    print('computing active synapse properties')
    sca.compute_synapse_distances_times(synName, cell, t, synTypes)
    sca.compute_synapse_distances_times(synNameL1, cell, t, synTypesL1)
    return True, t

def scan_directory(path, fnames, suffix):
    '''
    Recursively collect every file below path whose name ends with suffix.

    Directory entries are enumerated with glob pattern '*', so names that
    start with '.' are not visited. Matching paths are appended to the
    list fnames in place; nothing is returned.
    '''
    for entry in glob.glob(os.path.join(path, '*')):
        if not os.path.isdir(entry):
            if entry.endswith(suffix):
                fnames.append(entry)
            continue
        scan_directory(entry, fnames, suffix)

def dendriteScalingUniform(cell):
    '''
    Uniformly shrink dendritic diameters by the factor 1/1.2.

    For every section labelled 'Dendrite' or 'ApicalDendrite', the NEURON
    3D point list is cleared and rebuilt from sec.pts, using the scaled
    diameters. sec.diamList is updated in place, so the stored diameters
    match the rebuilt geometry. Other sections are left untouched.
    '''
    dendScale = 1/1.2
    scaledLabels = ('Dendrite', 'ApicalDendrite')
    for sec in cell.sections:
        if sec.label not in scaledLabels:
            continue
        h.pt3dclear(sec=sec)
        for ptIndex in range(sec.nrOfPts):
            x, y, z = sec.pts[ptIndex]
            sec.diamList[ptIndex] = sec.diamList[ptIndex]*dendScale
            # re-read the stored value so NEURON gets exactly what diamList holds
            h.pt3dadd(x, y, z, sec.diamList[ptIndex], sec=sec)

if __name__ == '__main__':
    # usage: python <this script> simName cellName ongoingUpName
    #        ongoingDownName evokedUpName evokedDownName
    if len(sys.argv) == 7:
        name = sys.argv[1]
        cellName = sys.argv[2]
        ongoingUpName = sys.argv[3]
        ongoingDownName = sys.argv[4]
        evokedUpName = sys.argv[5]
        evokedDownName = sys.argv[6]
        evoked_activity_replay_inh_synapses(name, cellName, ongoingUpName, ongoingDownName, evokedUpName, evokedDownName)
    else:
        print('Error! Number of arguments is %d; should be 6!' % (len(sys.argv)-1))
        # non-zero exit status so shells / batch schedulers see the failure
        sys.exit(1)