"""
MODEL

A large-scale network simulation for exploring traveling waves, stimuli,
and STDP. Built solely in Python, using Izhikevich neurons, with MPI
support. Runs in real time with over 8000 cells when appropriately
parallelized.
M1 model extended to interface with Plexon-recorded PMd data, a virtual arm,
and reinforcement learning.

Usage:
    python model.py # Run simulation, optionally plot a raster
    python simmovie.py # Show a movie of the results
    python model.py scale=20 # Run simulation, set scale=20

MPI usage:
    mpiexec -n 4 nrniv -python -mpi model.py

Version: 2014feb21 by cliffk
2014sep19 modified by salvadord and giljael
"""


###############################################################################
### IMPORT MODULES
###############################################################################

from pylab import seed, rand, sqrt, exp, transpose, ceil, concatenate, array, zeros, ones, vstack, show, disp, mean, inf, unique, delete
from time import time, sleep
from datetime import datetime
from scipy.io import savemat, loadmat
import pickle

from neuron import h, init, run # Import NEURON
import shared as s # Import all shared variables and parameters
import analysis
from arm import Arm # Class with arm methods and variables


###############################################################################
### Sequences of commands to run full model
###############################################################################

# training and testing to 2 targets manually
def runTrainTest2targets():
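    """
    Manually-parameterized demo: build the network, train on 2 alternating targets
    using STDP with exploratory movements, then test each target separately with
    plasticity and exploration disabled. On rank 0, per-target errors and a combined
    fitness (mean error + |error0 - error1|) are printed and pickled.
    """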

    # optimized values for musculoskeletal arm (here using dummy arm for demo purposes)
    s.targetid = 0
    s.trainTime = 2000 # 85000 # using 2 sec for demo purposes
    s.stdpwin = 48.5
    s.eligwin = 117.8
    s.RLfactor = 0.01 #6
    s.RLinterval = 76.8
    s.backgroundrate = s.backgroundrateTest = 134.5
    s.backgroundrateExplor = 5
    s.cmdmaxrate = 528.8
    s.PMdconnweight = 1.0
    s.PMdconnprob = 2.4

    s.useArm = 'dummyArm' #'musculoskeletal'

    s.numTrials = ceil(s.trainTime/1000)
    s.trialTargets = [i%2 for i in range(int(s.numTrials+1))] # set target for each trial
    s.targetid=s.trialTargets[0]

    verystart=time() # store initial time

    s.plotraster = 1 # set plotting params
    s.plotconn = 0
    s.plotweightchanges = 0
    s.plot3darch = 0
    s.graphsArm = 1
    s.animArm = 1
    s.savemat = 0 # save data during testing
    s.armMinimalSave = 1 # save only arm related data

    # initialize network
    createNetwork()
    addStimulation()
    addBackground()

    # train
    s.usestdp = 1 # Whether or not to use STDP
    s.useRL = 0 # Whether or not to use RL
    s.explorMovs = 1 # enable exploratory movements
    s.antagInh = 0 # whether to use antagonist inhibition
    s.duration = s.trainTime # train time

    setupSim()
    runSim()
    finalizeSim()
    #saveData()
    plotData()

    # test target 0
    s.backgroundrate=s.backgroundrateTest # 300
    s.cmdmaxrate=s.cmdmaxrateTest # 15
    addBackground()
    s.usestdp = 0 # Whether or not to use STDP
    s.useRL = 0 # Whether or not to use RL
    s.explorMovs = 0 # disable exploratory movements
    s.duration = s.testTime # testing time
    s.armMinimalSave = 0 # save only arm related data

    s.targetid = 0
    setupSim()
    runSim()
    finalizeSim()
    #saveData()
    plotData()

    if s.rank == 0: # save error to file
        error0 = mean(s.arm.errorAll)
        print('Target error for target ',s.targetid,' is:', error0)
        s.arm.plotTraj(s.outfilestem+'_t0.png')

    # test target 1
    s.targetid = 1
    setupSim()
    runSim()
    finalizeSim()
    saveData()
    plotData()

    if s.rank == 0: # save error to file
        error1 = mean(s.arm.errorAll)
        print('Target error for target 0=', error0, '; target 1=', error1)
        s.arm.plotTraj(s.outfilestem+'_t1.png')

        errorMean = (error0+error1)/2
        errorFitness = errorMean + abs(error0-error1)  # fitness penalizes difference between target errors
        errorDic = {}
        errorDic['error0'] = error0
        errorDic['error1'] = error1
        errorDic['meanError'] = errorMean
        errorDic['errorFitness'] = errorFitness

        print('Mean error = %.4f ; Mean error + difference (fitness) = %.4f'%(errorMean, errorFitness))

        s.targetid = 0 # so saves to correct file name (error of both targets saved to single file ending in target_0_error)
        with open('%s_target_%d_error'% (s.outfilestem,s.targetid), 'wb') as f: # save avg error over targets to outfilestem
            pickle.dump(errorDic, f)

    ## Wrapping up
    s.pc.runworker() # MPI: Start simulations running on each host
    s.pc.done() # MPI: Close MPI
    totaltime = time()-verystart # See how long it took in total
    print(('\nDone; total time = %0.1f s.' % totaltime))
    if (s.plotraster==False and s.plotconn==False and s.plotweightchanges==False): h.quit() # Quit if no plotting was requested (assumes a non-interactive batch run)


# training and testing to 2 targets via evolutionary optim algorithm (batch, no graphics)
def runTrainTest2targetsOptim():
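    """
    Batch (no-graphics) variant used with the evolutionary optimization algorithm:
    parameters such as s.RLfactor, s.PMdconnprob and s.PMdconnweight are assumed to
    be set externally before this call. Trains with STDP + RL and exploratory
    movements, then tests both targets and pickles the per-target errors and the
    combined fitness on rank 0.
    """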
    # evol optimizes the following:
    s.RLrates = s.RLfactor*array([[0.25, -0.25], [0.0, 0.0]]) # RL potentiation/depression rates for E->anything and I->anything, e.g. [0,:] is pot/dep for E cells
    s.connprobs[s.PMd,s.ER5]=s.PMdconnprob
    s.connweights[s.PMd,s.ER5,s.AMPA]=s.PMdconnweight

    s.verbose=0
    s.useArm = 'musculoskeletal'
    #s.useArm = 'dummyArm'
    s.numTrials = ceil(s.trainTime/1000)
    s.trialTargets = [i%2 for i in range(int(s.numTrials+1))] # set target for each trial
    s.targetid=s.trialTargets[0]

    verystart=time() # store initial time

    s.plotraster = 0 # set plotting params
    s.plotconn = 0
    s.plotweightchanges = 0
    s.plot3darch = 0
    s.graphsArm = 0
    s.animArm = 0
    s.savemat = 0 # save data during testing
    s.armMinimalSave = 0 # save only arm related data

    # train
    s.usestdp = 1 # Whether or not to use STDP
    s.useRL = 1 # Whether or not to use RL
    s.explorMovs = 1 # enable exploratory movements
    s.antagInh = 0 # whether to use antagonist inhibition
    s.duration = s.trainTime # train time
    s.timebetweensaves =  s.trainTime - 1000

    # initialize network
    createNetwork()
    addStimulation()
    addBackground()

    # run train
    setupSim()
    runSim()
    finalizeSim()
    #saveData()
    #plotData()
    if s.rank == 0: # save png of traj
        s.arm.plotTraj(s.outfilestem+'_train.png')  # save traj fig to file
        analysis.plotweightchanges(s.outfilestem+'_train_weights.png')

    test = 1
    s.savemat = 1
    if test:
        # test target 0
        #s.backgroundrate=s.backgroundrateTest # 300
        #s.cmdmaxrate=s.cmdmaxrateTest # 15
        addBackground()
        s.usestdp = 0 # Whether or not to use STDP
        s.useRL = 0 # Whether or not to use RL
        s.explorMovs = 0 # disable exploratory movements
        s.duration = s.testTime # testing time
        s.armMinimalSave = 0 # save only arm related data

        s.targetid = 0
        setupSim()
        runSim()
        finalizeSim()
        saveData()
        #plotData()

        if s.rank == 0: # save error to file
            error0 = mean(s.arm.errorAll)
            print('Target error for target ',s.targetid,' is:', error0)
            s.arm.plotTraj(s.outfilestem+'_t0.png')
            analysis.plotraster(s.outfilestem+'_t0_raster.png')

        # test target 1
        s.targetid = 1
        setupSim()
        runSim()
        finalizeSim()
        saveData()
        #plotData()

        if s.rank == 0: # save error to file
            error1 = mean(s.arm.errorAll)
            print('Target error for target 0=', error0, '; target 1=', error1)
            s.arm.plotTraj(s.outfilestem+'_t1.png')
            analysis.plotraster(s.outfilestem+'_t1_raster.png')

            errorMean = (error0+error1)/2
            errorFitness = errorMean + abs(error0-error1)  # fitness penalizes difference between target errors
            errorDic = {}
            errorDic['error0'] = error0
            errorDic['error1'] = error1
            errorDic['meanError'] = errorMean
            errorDic['errorFitness'] = errorFitness

            print('Mean error = %.4f ; Mean error + difference (fitness) = %.4f'%(errorMean, errorFitness))

            s.targetid = 0 # so saves to correct file name (error of both targets saved to single file ending in target_0_error)
            with open('%s_target_%d_error'% (s.outfilestem,s.targetid), 'wb') as f: # save avg error over targets to outfilestem
                pickle.dump(errorDic, f)

    ## Wrapping up
    s.pc.runworker() # MPI: Start simulations running on each host
    s.pc.done() # MPI: Close MPI
    totaltime = time()-verystart # See how long it took in total
    print(('\nDone; total time = %0.1f s.' % totaltime))
    if (s.plotraster==False and s.plotconn==False and s.plotweightchanges==False): h.quit() # Quit if no plotting was requested (assumes a non-interactive batch run)


###############################################################################
### Create Network
###############################################################################

def createNetwork():
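    """
    Build the network on each MPI host: assign Izhikevich cell types and random
    positions, create the cells round-robin across hosts, compute distance-dependent
    connection probabilities, delays and weights, select which projections are
    plastic (s.plastConnsType), and instantiate the NetCons and STDP mechanisms.
    """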
    ## Print diagnostic information
    if s.rank==0: print(("\nCreating simulation of %i cells for %0.1f s on %i hosts..." % (sum(s.popnumbers),s.duration/1000.,s.nhosts)))
    s.pc.barrier()

    ## Create empty data structures
    s.cells=[] # Create empty list for storing cells
    s.dummies=[] # Create empty list for storing fake sections
    s.gidVec=[] # Empty list for storing GIDs (index = local id; value = gid)
    s.gidDic = {} # Empty dict for storing GIDs (key = gid; value = local id) -- ~x6 faster than gidVec.index()


    ## Set cell types
    celltypes=[]
    for c in range(s.ncells): # Loop over each cell. ncells is all cells in the network.
        if s.cellclasses[c]==1: celltypes.append(s.RS) # Append a regular spiking pyramidal cell
        elif s.cellclasses[c]==2: celltypes.append(s.IB) # Append an intrinsically bursting pyramidal cell
        elif s.cellclasses[c]==3: celltypes.append(s.CH) # Append a chattering cell
        elif s.cellclasses[c]==4: celltypes.append(s.LTS) # Append a low-threshold spiking interneuron
        elif s.cellclasses[c]==5: celltypes.append(s.FS) # Append a fast-spiking interneuron
        elif s.cellclasses[c]==4: celltypes.append(s.TC) # Append a thalamocortical cell (NOTE: unreachable -- class 4 already matched as LTS above)
        elif s.cellclasses[c]==5: celltypes.append(s.RTN) # Append a reticular thalamic nucleus cell (NOTE: unreachable -- class 5 already matched as FS above)
        elif s.cellclasses[c]==-1: celltypes.append(s.nsloc) # Append a nsloc
        else: raise Exception('Undefined cell class "%s"' % s.cellclasses[c]) # No match? Cause an error


    ## Set positions
    seed(s.id32('%d'%s.randseed)) # Reset random number generator
    s.xlocs = s.modelsize*rand(s.ncells) # Create random x locations
    s.ylocs = s.modelsize*rand(s.ncells) # Create random y locations
    s.zlocs = rand(s.ncells) # Create random z locations
    for c in range(s.ncells):
        s.zlocs[c] = s.corticalthick * (s.zlocs[c]*(s.popyfrac[s.cellpops[c]][1]-s.popyfrac[s.cellpops[c]][0]) + s.popyfrac[s.cellpops[c]][0])  # calculate based on yfrac for population and corticalthick


    ## Actually create the cells
    s.spikerecorders = [] # Empty list for storing spike-recording Netcons
    s.hostspikevecs = [] # Empty list for storing host-specific spike vectors
    s.cellsperhost = 0
    if s.PMdinput == 'Plexon': ninnclDic = len(s.innclDic) # number of PMd input NetCons created on this worker
    for c in range(int(s.rank), s.ncells, s.nhosts):
        s.dummies.append(h.Section()) # Create fake sections
        gid = c
        if s.cellnames[gid] == 'PMd':
            if s.PMdinput == 'Plexon':
                cell = celltypes[gid](cellid = gid) # create an NSLOC
                s.inncl.append(h.NetCon(None, cell))  # This netcon receives external spikes
                s.innclDic[gid - s.ncells - s.server.numPMd] = ninnclDic # This dictionary works in case that PMd's gid starts from 0.
                ninnclDic += 1
            elif s.PMdinput == 'targetSplit':
                cell = celltypes[gid](cellid = gid) # create an NSLOC
                cell.number = s.backgroundnumber
                cell.interval = s.backgroundrateMin**-1*1e3
                cell.noise = s.PMdNoiseRatio
            elif s.PMdinput == 'spikes':
                cell = h.VecStim()
            else:
                cell = celltypes[gid](cellid = gid) # create an NSLOC
                cell.number = s.backgroundnumber
                cell.interval = s.backgroundrateMin**-1*1e3

        elif s.cellnames[gid] == 'ASC':
            cell = celltypes[gid](cellid = gid) #create an NSLOC
        else:
            if s.cellclasses[gid]==3:
                cell = s.fastspiking(s.dummies[s.cellsperhost], vt=-47, cellid=gid) # Don't use LTS cell, but instead a FS cell with a low threshold
            else:
                cell = celltypes[gid](s.dummies[s.cellsperhost], cellid=gid) # Create a new cell of the appropriate type (celltypes[gid]) and store it
            #if s.verbose>0: s.cells[-1].useverbose(s.verbose, s.filename+'los.txt') # Turn on diagnostic to file
        s.cells.append(cell)
        s.gidVec.append(gid) # index = local id; value = global id
        s.gidDic[gid] = s.cellsperhost # key = global id; value = local id -- used to get local id because gid.index() too slow!
        s.pc.set_gid2node(gid, s.rank)

        spikevec = h.Vector()
        s.hostspikevecs.append(spikevec)
        spikerecorder = h.NetCon(cell, None)
        spikerecorder.record(spikevec)
        s.spikerecorders.append(spikerecorder)
        s.pc.cell(gid, s.spikerecorders[s.cellsperhost])
        s.cellsperhost += 1 # count of cells instantiated on this host, including PMd
    print(('  Number of cells on node %i: %i ' % (s.rank,len(s.cells))))
    s.pc.barrier()


    ## Calculate motor command cell ranges so can be used for EDSC and IDSC connectivity
    nCells = s.motorCmdEndCell - s.motorCmdStartCell
    s.motorCmdCellRange = []
    for i in range(s.nMuscles):
        s.motorCmdCellRange.append(list(range(s.motorCmdStartCell + int(nCells/s.nMuscles)*i, s.motorCmdStartCell + int(nCells/s.nMuscles)*i + int(nCells/s.nMuscles)))) # cells driving the motor command for muscle i


    ## Calculate distances and probabilities
    if s.rank==0: print(('Calculating connection probabilities (est. time: %i s)...' % (s.performance*s.cellsperhost**2/3e4)))
    conncalcstart = time() # See how long calculating the connections takes
    s.nconnpars = 5 # Connection parameters: pre-cell ID, post-cell ID, distance, delay, and weight
    s.conndata = [[] for i in range(s.nconnpars)] # List for storing connections
    nPostCells = 0
    EDSCpre = [] # to keep track of EB5->EDSC connection and replicate in EB5->IDSC
    for c in range(s.cellsperhost): # Loop over all postsynaptic cells on this host (has to be postsynaptic because of gid_connect)
        gid = s.gidVec[c] # Global identifier of this postsynaptic cell
        if s.cellnames[gid] == 'PMd' or s.cellnames[gid] == 'ASC':
            # There are no presynaptic connections for PMd or ASC.
            continue
        nPostCells += 1
        if s.toroidal:
            xpath=(abs(s.xlocs-s.xlocs[gid]))**2
            xpath2=(s.modelsize-abs(s.xlocs-s.xlocs[gid]))**2
            xpath[xpath2<xpath]=xpath2[xpath2<xpath]
            ypath=(abs(s.ylocs-s.ylocs[gid]))**2
            ypath2=(s.modelsize-abs(s.ylocs-s.ylocs[gid]))**2
            ypath[ypath2<ypath]=ypath2[ypath2<ypath]
            zpath=(abs(s.zlocs-s.zlocs[gid]))**2
            distances = sqrt(xpath + ypath) # Calculate all pairwise distances
            distances3d = sqrt(xpath + ypath + zpath) # Calculate all pairwise 3d distances
        else:
            distances = sqrt((s.xlocs-s.xlocs[gid])**2 + (s.ylocs-s.ylocs[gid])**2) # Calculate all pairwise distances
            distances3d = sqrt((s.xlocs-s.xlocs[gid])**2 + (s.ylocs-s.ylocs[gid])**2 + (s.zlocs-s.zlocs[gid])**2) # Calculate all pairwise distances
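        # Connection probability from presynaptic cell i to this cell:
        #   p_i = scaleconnprob[EorI_i, EorI_gid] * connprobs[pop_i, pop_gid] * exp(-d_i / connfalloff[EorI_i])
        # i.e. a population-pair base probability, scaled by E/I class and decaying with 2D distance.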
        allconnprobs = s.scaleconnprob[s.EorI,s.EorI[gid]] * s.connprobs[s.cellpops,s.cellpops[gid]] * exp(-distances/s.connfalloff[s.EorI]) # Calculate pairwise probabilities
        allconnprobs[gid] = 0 # Prohibit self-connections using the cell's GID
        seed(s.id32('%d'%(s.randseed+gid))) # Reset random number generator
        allrands = rand(s.ncells) # Create an array of random numbers for checking each connection
        if s.PMdinput == 'Plexon':
            for pmdgid in range(s.popGidStart[s.PMd], s.popGidEnd[s.PMd] + 1): # exclude PMd cells from distance-based random connectivity
                allrands[pmdgid] = 1
            if s.cellnames[gid] == 'ER5': # PMd->ER5 conn (full conn)
                PMdId = (gid % s.server.numPMd) + s.ncells - s.server.numPMd #CHECK THIS!
                allconnprobs[PMdId] = s.connprobs[s.PMd,s.ER5] # to make this connected to ER5
                allrands[PMdId] = 0 # to make this connect to ER5
                distances[PMdId] = 300 # to make delay 5 in conndata[3]
        makethisconnection = allconnprobs>allrands # Perform test to see whether or not this connection should be made
        preids = array(makethisconnection.nonzero()[0],dtype='int') # Return True elements of that array for presynaptic cell IDs
        if s.PMdinput == 'targetSplit' and s.cellnames[gid] == 'ER5': # PMds 0-47 -> ER5 0-47 ; PMds 48-95 -> ER5 48-95
            if gid < s.popGidStart[s.ER5] + s.popnumbers[s.ER5]/2:
                prePMd = [int((x - s.popGidStart[s.ER5])%(s.popnumbers[s.PMd]//2) + s.popGidStart[s.PMd]) for x in range(gid, gid+1)] # map this ER5 cell to its corresponding PMd cell (first half of PMd population)
            else:
                prePMd = [int((x - s.popGidStart[s.ER5])%(s.popnumbers[s.PMd]//2) + s.popGidStart[s.PMd] + s.popnumbers[s.PMd]//2) for x in range(gid, gid+1)] # map this ER5 cell to its corresponding PMd cell (second half of PMd population)
            if (array(prePMd) < s.popGidEnd[s.PMd]).all(): # only connect if the mapped PMd gid is within range
                #print 'prePMd=%d to ER5=%d:'%(prePMd[0],gid)
                preids = concatenate([preids, prePMd])
        if s.cellnames[gid] == 'EDSC': # save EDSC presyn cells to replicate in IDSC, and add inputs from IDSC
            EDSCpre.append(array(preids)) # save EDSC presyn cells before adding IDSC input
            invPops = [1, 0, 3, 2] # each postsynaptic EDSC cell will receive input from the IDSCs of its antagonist muscle
            IDSCpre = [s.motorCmdCellRange[invPops[i]] - s.popGidStart[s.EDSC] + s.popGidStart[s.IDSC] for i in range(s.nMuscles) if gid in s.motorCmdCellRange[i]][0]
            preids = concatenate([preids, IDSCpre]) # add IDSC presynaptic input to EDSC
        elif s.cellnames[gid] == 'IDSC': # use same presyn cells as for EDSC (antagonistic inhibition)
            preids = array(EDSCpre.pop(0))
        postids = array(gid+zeros(len(preids)),dtype='int') # Post-synaptic cell IDs
        s.conndata[0].append(preids) # Append pre-cell ID
        s.conndata[1].append(postids) # Append post-cell ID
        s.conndata[2].append(distances[preids]) # Distances
        s.conndata[3].append(s.mindelay + distances3d[preids]/float(s.velocity)) # Calculate the delays
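        # Final weight for receptor r of connection (pre, post):
        #   w_r = scaleconnweight[EorI_pre, EorI_post] * connweights[pop_pre, pop_post, r] * receptorweight[r]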
        wt1 = s.scaleconnweight[s.EorI[preids],s.EorI[postids]] # N weight scale factors -- WARNING, might be flipped
        wt2 = s.connweights[s.cellpops[preids],s.cellpops[postids],:] # NxM inter-population weights
        wt3 = s.receptorweight[:] # M receptor weights
        finalweights = transpose(wt1*transpose(wt2*wt3)) # Multiply out population weights with receptor weights to get NxM matrix
        s.conndata[4].append(finalweights) # Append the NxM weight matrix for these connections
    for pp in range(s.nconnpars): s.conndata[pp] = array(concatenate([s.conndata[pp][c] for c in range(nPostCells)])) # Turn pre- and post- cell IDs lists into vectors
    s.nconnections = len(s.conndata[0]) # Find out how many connections we're going to make
    conncalctime = time()-conncalcstart # See how long it took
    if s.rank==0: print(('  Done; time = %0.1f s' % conncalctime))


    # set plastic connections based on plasConnsType (from evol alg)
    if s.plastConnsType == 0:
        s.plastConns = [[s.ASC,s.ER2], [s.EB5,s.EDSC], [s.EB5,s.IDSC]] # only spinal cord
    elif s.plastConnsType == 1:
        s.plastConns = [[s.ASC,s.ER2], [s.EB5,s.EDSC], [s.EB5,s.IDSC], [s.ER2,s.ER5], [s.ER5,s.EB5], [s.ER2,s.EB5], [s.ER5,s.ER2]] # + L2-L5
    elif s.plastConnsType == 2:
        s.plastConns = [[s.ASC,s.ER2], [s.EB5,s.EDSC], [s.EB5,s.IDSC], [s.ER2,s.ER5], [s.ER5,s.EB5], [s.ER2,s.EB5], [s.ER5,s.ER2],\
        [s.ER5,s.ER6], [s.ER6,s.ER5], [s.ER6,s.EB5]] # + L6
    elif s.plastConnsType == 3:
        s.plastConns = [[s.ASC,s.ER2], [s.EB5,s.EDSC], [s.EB5,s.IDSC], [s.ER2,s.ER5], [s.ER5,s.EB5], [s.ER2,s.EB5], [s.ER5,s.ER2],\
         [s.ER5,s.ER6], [s.ER6,s.ER5], [s.ER6,s.EB5], \
         [s.ER2,s.IL2], [s.ER2,s.IF2], [s.ER5,s.IL5], [s.ER5,s.IF5], [s.EB5,s.IL5], [s.EB5,s.IF5]] # + Inh
    # same with additional plasticity between PMd->L5A
    elif s.plastConnsType == 4:
        s.plastConns = [[s.ASC,s.ER2], [s.EB5,s.EDSC], [s.EB5,s.IDSC], [s.PMd,s.ER5]] # only spinal cord + pmd
    elif s.plastConnsType == 5:
        s.plastConns = [[s.ASC,s.ER2], [s.EB5,s.EDSC], [s.EB5,s.IDSC], [s.PMd,s.ER5], # spinal cord + pmd
         [s.ER2,s.ER5], [s.ER5,s.EB5], [s.ER2,s.EB5], [s.ER5,s.ER2]] # + L2-L5
    elif s.plastConnsType == 6:
        s.plastConns = [[s.ASC,s.ER2], [s.EB5,s.EDSC], [s.EB5,s.IDSC], [s.PMd,s.ER5], # spinal cord + pmd
        [s.ER2,s.ER5], [s.ER5,s.EB5], [s.ER2,s.EB5], [s.ER5,s.ER2], # + L2-L5
        [s.ER5,s.ER6], [s.ER6,s.ER5], [s.ER6,s.EB5]] # + L6
    elif s.plastConnsType == 7:
        s.plastConns = [[s.ASC,s.ER2], [s.EB5,s.EDSC], [s.EB5,s.IDSC], [s.PMd,s.ER5], # spinal cord + pmd
        [s.ER2,s.ER5], [s.ER5,s.EB5], [s.ER2,s.EB5], [s.ER5,s.ER2], # + L2-L5
        [s.ER5,s.ER6], [s.ER6,s.ER5], [s.ER6,s.EB5], # + L6
        [s.ER2,s.IL2], [s.ER2,s.IF2], [s.ER5,s.IL5], [s.ER5,s.IF5], [s.EB5,s.IL5], [s.EB5,s.IF5]]  # + Inh



    ## Actually make connections
    if s.rank==0: print(('Making connections (est. time: %i s)...' % (s.performance*s.nconnections/9e2)))
    print(('  Number of connections on host %i: %i' % (s.rank, s.nconnections)))
    connstart = time() # See how long connecting the cells takes
    s.connlist = [] # Create array for storing each of the connections
    s.stdpconndata = [] # Store data on STDP connections
    if s.usestdp: # STDP enabled?
        s.stdpmechs = [] # Initialize array for STDP mechanisms
        s.precons = [] # Initialize array for presynaptic spike counters
        s.pstcons = [] # Initialize array for postsynaptic spike counters
    for con in range(s.nconnections): # Loop over each connection
        pregid = s.conndata[0][con] # GID of presynaptic cell
        pstgid = s.conndata[1][con] # Index of postsynaptic cell
        pstid = s.gidDic[pstgid] # Index of postsynaptic cell -- convert from GID to local
        newcon = s.pc.gid_connect(pregid, s.cells[pstid]) # Create a connection
        newcon.delay = s.conndata[3][con] # Set delay
        for r in range(s.nreceptors): newcon.weight[r] = s.conndata[4][con][r] # Set weight of connection
        s.connlist.append(newcon) # Connect the two cells
        if s.usestdp and ([s.cellpops[pregid],s.cellpops[pstgid]] in s.plastConns): # If using STDP and these pops are set to be plastic connections
            if sum(abs(s.stdprates[s.EorI[pregid],:]))>0 or sum(abs(s.RLrates[s.EorI[pregid],:]))>0: # Don't create an STDP connection if the learning rates are zero
                for r in range(s.nreceptors): # Need a different STDP instance for each receptor
                    if newcon.weight[r]>0: # Only make them for nonzero connections
                        stdpmech = h.STDP(0,sec=s.dummies[pstid]) # Create STDP adjuster
                        stdpmech.hebbwt = s.stdprates[s.EorI[pregid],0] # Potentiation rate
                        stdpmech.antiwt = s.stdprates[s.EorI[pregid],1] # Depression rate
                        stdpmech.wmax = s.maxweight # Maximum synaptic weight
                        precon = s.pc.gid_connect(pregid,stdpmech); precon.weight[0] = 1 # Send presynaptic spikes to the STDP adjuster
                        pstcon = s.pc.gid_connect(pstgid,stdpmech); pstcon.weight[0] = -1 # Send postsynaptic spikes to the STDP adjuster
                        h.setpointer(s.connlist[-1]._ref_weight[r],'synweight',stdpmech) # Associate the STDP adjuster with this weight
                        s.stdpmechs.append(stdpmech) # Save STDP adjuster
                        s.precons.append(precon) # Save presynaptic spike source
                        s.pstcons.append(pstcon) # Save postsynaptic spike source
                        s.stdpconndata.append([pregid,pstgid,r]) # Store presynaptic cell ID, postsynaptic, and receptor
                        if s.verbose: stdpmech.verbose = 1
                        if s.useRL: # using RL
                            stdpmech.RLon = 1 # make sure RL is on
                            stdpmech.RLhebbwt = s.RLrates[s.EorI[pregid],0] # Potentiation rate
                            stdpmech.RLantiwt = s.RLrates[s.EorI[pregid],1] # Depression rate
                            stdpmech.tauhebb = stdpmech.tauanti = s.stdpwin # STDP time constant (ms)
                            stdpmech.RLwindhebb = s.eligwin # RL eligibility trace window length (ms)
                            stdpmech.useRLexp = s.useRLexp # RL
                            stdpmech.softthresh = s.useRLsoft # RL soft-thresholding
                        else:
                            stdpmech.RLon = 0 # make sure RL is off

    s.nstdpconns = len(s.stdpconndata) # Get number of STDP connections
    conntime = time()-connstart # See how long it took
    if s.usestdp: print(('  Number of STDP connections on host %i: %i' % (s.rank, s.nstdpconns)))
    if s.rank==0: print(('  Done; time = %0.1f s' % conntime))


###############################################################################
### Add stimulation
###############################################################################

def addStimulation():
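    """
    If s.usestims is enabled, create NetStim sources for each stimulus defined in
    s.stimpars, restricted to the requested populations and spatial region, with
    weights scaled by distance from the stimulus center and played into the NetCon
    weights over time via Vector.play().
    """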
    if s.usestims:
        s.stimstruct = [] # For saving
        s.stimrands=[] # Create input connections
        s.stimsources=[] # Create empty list for storing synapses
        s.stimconns=[] # Create input connections
        s.stimtimevecs = [] # Create array for storing time vectors
        s.stimweightvecs = [] # Create array for holding weight vectors
        if s.saveraw:
            s.stimspikevecs=[] # A list for storing stimulus spike times (WARNING, slow!)
            s.stimrecorders=[] # And for recording spikes
        for stim in range(len(s.stimpars)): # Loop over each stimulus type
            ts = s.stimpars[stim] # Stands for "this stimulus"
            ts.loc = ts.loc * s.modelsize # scale cell locations to model size
            stimvecs = s.makestim(ts.isi, ts.var, ts.width, ts.weight, ts.sta, ts.fin, ts.shape) # Time-probability vectors
            s.stimstruct.append([ts.name, stimvecs]) # Store for saving later
            s.stimtimevecs.append(h.Vector().from_python(stimvecs[0]))

            for c in range(s.cellsperhost):
                gid = s.gidVec[c] # Global ID of this cell (cells are distributed round-robin across hosts)
                seed(s.id32('%d'%(s.randseed+gid))) # Reset random number generator for this cell
                if ts.fraction>rand(): # Don't do it for every cell necessarily
                    if any(s.cellpops[gid]==ts.pops) and s.xlocs[gid]>=ts.loc[0,0] and s.xlocs[gid]<=ts.loc[0,1] and s.ylocs[gid]>=ts.loc[1,0] and s.ylocs[gid]<=ts.loc[1,1]:

                        maxweightincrease = 20 # Otherwise could get infinitely high, infinitely close to the stimulus
                        distancefromstimulus = sqrt(sum((array([s.xlocs[gid], s.ylocs[gid]])-s.modelsize*ts.falloff[0])**2))
                        fallofffactor = min(maxweightincrease,(ts.falloff[1]/distancefromstimulus)**2)
                        s.stimweightvecs.append(h.Vector().from_python(stimvecs[1]*fallofffactor)) # Scale by the fall-off factor

                        stimrand = h.Random()
                        stimrand.MCellRan4() # If everything has the same seed, should happen at the same time
                        stimrand.negexp(1)
                        stimrand.seq(s.id32('%d'%(s.randseed+gid))*1e3) # Set the sequence i.e. seed
                        s.stimrands.append(stimrand)

                        stimsource = h.NetStim() # Create a NetStim
                        stimsource.interval = ts.rate**-1*1e3 # Interval between spikes
                        stimsource.number = 1e9 # Number of spikes
                        stimsource.noise = ts.noise # Fractional noise in timing
                        stimsource.noiseFromRandom(stimrand) # Set it to use this random number generator
                        s.stimsources.append(stimsource) # Save this NetStim

                        stimconn = h.NetCon(stimsource, s.cells[c]) # Connect this noisy input to a cell
                        for r in range(s.nreceptors): stimconn.weight[r]=0 # Initialize weights to 0, otherwise get memory leaks
                        s.stimweightvecs[-1].play(stimconn._ref_weight[0], s.stimtimevecs[-1]) # Play most-recently-added vectors into weight
                        stimconn.delay=s.mindelay # Specify the delay in ms -- shouldn't make a spot of difference
                        s.stimconns.append(stimconn) # Save this connection

                        if s.saveraw: # and c <=100:
                            stimspikevec = h.Vector() # Initialize vector
                            s.stimspikevecs.append(stimspikevec) # Keep all those vectors
                            stimrecorder = h.NetCon(stimsource, None)
                            stimrecorder.record(stimspikevec) # Record spike times from this stimulus
                            s.stimrecorders.append(stimrecorder)
        print(('  Number of stimuli created on host %i: %i' % (s.rank, len(s.stimsources))))


###############################################################################
### Add background inputs
###############################################################################
def addBackground():
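    """
    Create an independent background input for every cell except ASC and PMd:
    NSLOC sources for EDSC/IDSC and EB5 cells, NetStim sources for all other cells,
    each connected with a population-specific weight.
    """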
    if s.rank==0: print('Creating background inputs...')
    s.backgroundsources=[] # Create empty list for storing synapses
    s.backgroundrands=[] # Create random number generators
    s.backgroundconns=[] # Create input connections
    s.backgroundgid=[] # Target cell gid for each input
    if s.savebackground:
        s.backgroundspikevecs=[] # A list for storing background spike times (WARNING, slow!)
        s.backgroundrecorders=[] # And for recording spikes
    for c in range(s.cellsperhost):
        gid = s.gidVec[c]
        if s.cellnames[gid] == 'ASC' or s.cellnames[gid] == 'PMd': # These pops don't receive background stimulation.
            pass
        else:
            backgroundrand = h.Random()
            backgroundrand.MCellRan4(gid,gid*2)
            backgroundrand.negexp(1)
            s.backgroundrands.append(backgroundrand)
            if s.cellnames[gid] == 'EDSC' or s.cellnames[gid] == 'IDSC':
                backgroundsource = h.NSLOC() # Create a NSLOC
                backgroundsource.interval = s.backgroundrateMin**-1*1e3 # Take inverse of the frequency and then convert from Hz^-1 to ms
                backgroundsource.noise = 0.3 # Fractional noise in timing
            elif s.cellnames[gid] == 'EB5':
                backgroundsource = h.NSLOC() # Create a NSLOC
                backgroundsource.interval = s.backgroundrate**-1*1e3 # Take inverse of the frequency and then convert from Hz^-1 to ms
                backgroundsource.noise = s.backgroundnoise # Fractional noise in timing
            else:
                backgroundsource = h.NetStim() # Create a NetStim
                backgroundsource.interval = s.backgroundrate**-1*1e3 # Take inverse of the frequency and then convert from Hz^-1 to ms
                backgroundsource.noiseFromRandom(backgroundrand) # Set it to use this random number generator
                backgroundsource.noise = s.backgroundnoise # Fractional noise in timing

            backgroundsource.number = s.backgroundnumber # Number of spikes
            s.backgroundsources.append(backgroundsource) # Save this NetStim
            s.backgroundgid.append(gid) # append cell gid associated to this netstim

            backgroundconn = h.NetCon(backgroundsource, s.cells[c]) # Connect this noisy input to a cell
            for r in range(s.nreceptors): backgroundconn.weight[r]=0 # Initialize weights to 0, otherwise get memory leaks
            if s.cellnames[gid] == 'EDSC' or s.cellnames[gid] == 'IDSC':
                backgroundconn.weight[s.backgroundreceptor] = s.backgroundweightExplor # Specify the weight for the EDSC, IDSC and PMd background input
            elif s.cellnames[gid] == 'EB5' and s.explorMovs == 2:
                backgroundconn.weight[s.backgroundreceptor] = s.backgroundweightExplor # Weight for EB5 input if explor movs via EB5
            else:
                backgroundconn.weight[s.backgroundreceptor] = s.backgroundweight[s.EorI[gid]] # Specify the weight -- 1 is NMDA receptor for smoother, more summative activation
            backgroundconn.delay=2 # Specify the delay in ms -- shouldn't make a spot of difference
            s.backgroundconns.append(backgroundconn) # Save this connection

            if s.savebackground:
                backgroundspikevec = h.Vector() # Initialize vector
                s.backgroundspikevecs.append(backgroundspikevec) # Keep all those vectors
                backgroundrecorder = h.NetCon(backgroundsource, None)
                backgroundrecorder.record(backgroundspikevec) # Record spike times from this background source
                s.backgroundrecorders.append(backgroundrecorder)
    print(('  Number created on host %i: %i' % (s.rank, len(s.backgroundsources))))
    s.pc.barrier()


###############################################################################
### Setup Simulation
###############################################################################
def setupSim():
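    """
    Prepare a run: reset RL/save/exploration timers, set up recording of STDP weight
    changes, LFPs and (optionally) raw cell variables, create the virtual arm, and
    configure the PMd input source (Plexon server, targetSplit NSLOCs, or playback
    of pre-recorded PMd spikes through VecStims).
    """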
    ## reset time variables
    s.timeoflastRL = -inf # Never RL
    s.timeoflastsave = -inf # Never saved
    s.timeoflastexplor = -inf # time when last exploratory movement was updated

    # Initialize STDP -- just for recording
    s.weightchanges = [] # Initialize so it exists even when STDP is disabled
    if s.usestdp:
        if s.rank==0: print('\nSetting up STDP...')
        s.weightchanges = [[] for ps in range(s.nstdpconns)] # Create an empty list for each STDP connection -- warning, slow with large numbers of connections!
        for ps in range(s.nstdpconns): s.weightchanges[ps].append([0, s.stdpmechs[ps].synweight]) # Time of save (0=initial) and the weight


    ## Set up LFP recording
    s.lfptime = [] # List of times that the LFP was recorded at
    s.nlfps = len(s.lfppops) # Number of distinct LFPs to calculate
    s.hostlfps = [] # Voltages for calculating LFP
    s.lfpcellids = [[] for pop in range(s.nlfps)] # Create list of lists of cell IDs
    for c in range(s.cellsperhost): # Loop over each cell and decide which LFP population, if any, it belongs to
        gid = s.gidVec[c] # Get this cell's GID
        if s.cellnames[gid] == 'ASC' or s.cellnames[gid] == 'PMd': # ASC and PMd cells are excluded from the LFP calculation
            continue
        for pop in range(s.nlfps): # Loop over each LFP population
            thispop = s.cellpops[gid] # Population of this cell
            if sum(s.lfppops[pop]==thispop)>0: # There's a match
                s.lfpcellids[pop].append(gid) # Flag this cell as belonging to this LFP population


    ## Set up raw recording
    s.rawrecordings = [] # A list for storing actual cell voltages (WARNING, slow!)
    if s.saveraw:
        if s.rank==0: print('\nSetting up raw recording...')
        s.nquantities = 5 # Number of variables from each cell to record from
        # Later this part should be modified because NSLOC doesn't have V, u and I.
        for c in range(s.cellsperhost):
            gid = s.gidVec[c] # Get this cell's GID
            if s.cellnames[gid] == 'ASC' or s.cellnames[gid] == 'PMd': # NSLOC doesn't have V, u and I
                continue
            recvecs = [h.Vector() for q in range(s.nquantities)] # Initialize vectors
            recvecs[0].record(h._ref_t) # Record simulation time
            recvecs[1].record(s.cells[c]._ref_V) # Record cell voltage
            recvecs[2].record(s.cells[c]._ref_u) # Record cell recovery variable
            recvecs[3].record(s.cells[c]._ref_I) # Record cell current
            recvecs[4].record(s.cells[c]._ref_gAMPA)
            # recvecs[5].record(s.cells[c]._ref_gNMDA)
            # recvecs[6].record(s.cells[c]._ref_gGABAA)
            # recvecs[7].record(s.cells[c]._ref_gGABAB)
            # recvecs[8].record(s.cells[c]._ref_gOpsin)
            s.rawrecordings.append(recvecs) # Keep all those vectors


    ## Set up virtual arm
    if s.useArm != 'None':
        if s.rank==0: print('\nSetting up virtual arm...')
        s.arm = Arm(s.useArm, s.animArm, s.graphsArm)
        s.arm.targetid = s.targetid
        s.arm.setup(s)#duration, loopstep, RLinterval, pc, scale, popnumbers, p)


    ## Communication setup for plexon input
    if s.PMdinput == 'Plexon':
        h('''
            objref cvode
            cvode = new CVode()
            tstop = 0
        ''')

        if s.isOriginal == 0: # With communication program
            if s.rank == 0:
                #serverManager = s.server.Manager() # isDp in confis.py = 0
                s.server.Manager.start() # launch server process
                print("Server process completed and callback function initialized")
            e = s.server.Manager.Event() # Queue callback function in the NEURON queue


            # Wait for external spikes for PMd from Plexon
            if s.rank == 0:
                if s.server.isCommunication == 1:
                    s.server.getServerInfo() # show parameters of the server process
                    print("[Waiting for spikes; run the client on Windows machine...]")
                    while s.server.queue.empty(): # only Rank 0 is waiting for spikes in the queue.
                        pass
            s.pc.barrier() # other workers are waiting here.

    ## Play back raw spikes from recorded PMd neurons, into model PMd population
    elif s.PMdinput == 'spikes':
        # if s.rank == 0:
        rawSpikesPMd = loadmat(s.spikesPMdFile)['pmdData'] # load raw data
        numrawcells = len(rawSpikesPMd[0][0])

        # implement lesion
        if s.PMdlesion > 0:
            numLesionedCells = int(round(s.PMdlesion*numrawcells))
            if s.rank==0: print("Lesioning PMd input... removed spike times of %d (%d%%) cells"%(numLesionedCells, s.PMdlesion*100))
            for itarget in range(len(rawSpikesPMd)):
                for itrial in range(len(rawSpikesPMd[0])):
                    for icell in range(numLesionedCells):
                        rawSpikesPMd[itarget][itrial][numrawcells-1-icell]['spkt'] = [array([])]  # remove spike times of lesion % of cells

        # generate spike vectors based on training time, trial duration, and target presentation (eg. alternating trials)
        if s.duration == 1e3: # if sim duration is 1 sec, assume it's a test trial and select PMd spikes based on targetid
            spktPMd = [rawSpikesPMd[s.targetid][s.repeatSingleTrials[s.targetid]][icell]['spkt'][0] for icell in range(numrawcells)]
        elif s.repeatSingleTrials[0] > -1: # use single trials for each target during training
            spktPMd = []
            for icell in range(numrawcells):
                spkt = []
                [spkt.extend(list(rawSpikesPMd[itarget][s.repeatSingleTrials[itarget]][icell]['spkt'][0] + (1000*itrial))) \
                    for itrial,itarget in enumerate(s.trialTargets[:-1])] # replicate spike times over trials
                spktPMd.append(spkt)
        else: # use all available trials for each target during training -- not implemented (spktPMd would be undefined below)
            pass
        # play back PMd spikes using VecStims
        s.tvecPMdlist = []
        gids = [i for i in s.gidVec if i in range(s.popGidStart[s.PMd], s.popGidEnd[s.PMd])] # calculate the PMd gids on this node
        for icell in gids: # for each unique cell/vecstim
            spkcell = spktPMd[icell%numrawcells]
            tvecPMd = h.Vector().from_python(spkcell) # find spikes for that vecstim
            s.tvecPMdlist.append(tvecPMd)  # store vector to avoid runtime error
            s.cells[s.gidDic[icell]].play(tvecPMd)  # play back sequence of spikes


###############################################################################
### Run Simulation
###############################################################################
def runSim():
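    """
    Main simulation loop: advance the model in chunks of s.loopstep ms, and at each
    step optionally compute LFPs, periodically record STDP weight changes, update
    the virtual arm, and apply reinforcement learning by broadcasting the critic
    signal from rank 0 and calling reward_punish() on every STDP mechanism; also
    handles the time adjustment needed for online (Plexon) mode.
    """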
    if s.rank == 0:
        print('\nRunning...')
        runstart = time() # See how long the run takes

    # set cache_efficient on
    h('objref cvode')
    h('cvode = new CVode()')
    h.cvode.cache_efficient(1)

    s.pc.set_maxstep(10) # MPI: Set the maximum integration time in ms -- not very important
    init() # Initialize the simulation

    while round(h.t) < s.duration:
        run(min(s.duration,h.t+s.loopstep)) # MPI: Get ready to run the simulation (it isn't actually run until pc.runworker() is called I think)
        if s.server.simMode == 0:
            if s.rank==0 and (round(h.t) % s.progupdate)==0: print(('  t = %0.1f s (%i%%; time consumed: %0.1f s)' % (h.t/1e3, int(h.t/s.duration*100), (time()-runstart))))
        else:
            if s.rank==0: print(('  t = %0.1f s (%i%%; time consumed: %0.1f s)' % (h.t/1e3, int(h.t/s.duration*100), (time()-runstart))))

        # Calculate LFP -- WARNING, need to think about how to optimize
        if s.savelfps:
            s.lfptime.append(h.t) # Append current time
            tmplfps = zeros((s.nlfps)) # Create empty array for storing LFP voltages
            for pop in range(s.nlfps):
                for c in range(len(s.lfpcellids[pop])):
                    id = s.gidDic[s.lfpcellids[pop][c]] # Index of postsynaptic cell -- convert from GID to local
                    tmplfps[pop] += s.cells[id].V # Add voltage to LFP estimate
                if s.verbose:
                    if s.server.Manager.ns.isnan(tmplfps[pop]) or s.server.Manager.ns.isinf(tmplfps[pop]):
                        print("Nan or inf")
            s.hostlfps.append(tmplfps) # Add voltages

        # Periodic weight saves
        if s.usestdp:
            timesincelastsave = h.t - s.timeoflastsave
            if timesincelastsave >= s.timebetweensaves:
                s.timeoflastsave = h.t
                #if s.rank == 0: print 'Recording weight changes at time ', h.t
                for ps in range(s.nstdpconns):
                    if s.stdpmechs[ps].synweight != s.weightchanges[ps][-1][-1]: # Only store connections that changed; [ps] = this connection; [-1] = last entry; [-1] = weight
                        s.weightchanges[ps].append([s.timeoflastsave, s.stdpmechs[ps].synweight])

        ## Virtual arm
        if s.useArm != 'None':
            armStart = time()
            s.arm.run(h.t, s) # run virtual arm apparatus (calculate command, move arm, feedback)
            if s.useRL and (h.t - s.timeoflastRL >= s.RLinterval): # if time for next RL
                s.timeoflastRL = h.t
                vec = h.Vector()
                if s.rank == 0:
                    critic = s.arm.RLcritic(h.t) # get critic signal (-1, 0 or 1)
                    s.pc.broadcast(vec.from_python([critic]), 0) # convert python list to hoc vector for broadcast data received from arm
                    #print critic
                else: # other workers
                    s.pc.broadcast(vec, 0)
                    critic = vec.to_python()[0]
                if critic != 0: # if critic signal indicates punishment (-1) or reward (+1)
                    for stdp in s.stdpmechs: # for all connections in stdp conn list
                        #print 'stdp_before: ', stdp.synweight
                        stdp.reward_punish(float(critic)) # run stdp.mod method to update synaptic weights based on RL
                        #print stdp.tlastpre
                        #print stdp.tlastpost
                        #stdp.adjustweight(float(0.5))
                        #sleep(0.001)
                        #print 'stdp_after: ', stdp.synweight
            # Synaptic scaling?

            #print(' Arm time = %0.4f s') % (time() - armStart)

        ## Time adjustment for online mode simulation
        if s.PMdinput == 'Plexon' and s.server.simMode == 1:
            # To avoid the Izhikevich cells overshooting when h.t jumps forward because the sim is slow.
            for c in range(s.cellsperhost):
                gid = s.gidVec[c]
                if s.cellnames[gid] == 'PMd': # PMd cells don't have a t0 variable.
                    continue
                s.cells[c].t0 = s.server.newCurrTime.value - h.dt
            dtSave = h.dt # save original dt
            h.dt = s.server.newCurrTime.value - h.t # new dt
            active = h.cvode.active()
            if active != 0:
                h.cvode.active(0)
            h.fadvance() # Integrate with new dt
            if active != 0:
                h.cvode.active(1)
            h.dt = dtSave # Restore original dt

    if s.rank==0:
        s.runtime = time()-runstart # See how long it took
        print(('  Done; run time = %0.1f s; real-time ratio: %0.2f.' % (s.runtime, s.duration/1000/s.runtime)))
    s.pc.barrier() # Wait for all hosts to get to this point


###############################################################################
### Finalize Simulation  (gather data from nodes, etc.)
###############################################################################
def finalizeSim():
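    """
    Gather results from all hosts via ParallelContext pack/post/take: spike times
    and cell IDs, LFPs, connection data, STDP weight changes and raw recordings are
    collected on rank 0. Also assembles background and stimulus spike data, closes
    the virtual arm and the Plexon server if used, and prints summary statistics.
    """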

    ## Variables to unpack data from all hosts

    ## Pack data from all hosts
    if s.rank==0: print('\nGathering spikes...')
    gatherstart = time() # See how long it takes to plot
    for host in range(s.nhosts): # Loop over hosts
        if host==s.rank: # Only act on a single host
            hostspikecells=array([])
            hostspiketimes=array([])
            for c in range(len(s.hostspikevecs)): # fails when saving raw
                thesespikes = array([s.hostspikevecs[c].x[i] for i in range(s.hostspikevecs[c].size())]) # Convert spike times to an array
                nthesespikes = len(thesespikes) # Find out how many spikes there were for this cell
                hostspiketimes = concatenate((hostspiketimes, thesespikes)) # Add spikes from this cell to the list
                #hostspikecells = concatenate((hostspikecells, (c+host*s.cellsperhost)*ones(nthesespikes))) # Add this cell's ID to the list
                hostspikecells = concatenate((hostspikecells, s.gidVec[c]*ones(nthesespikes))) # Add this cell's ID to the list
            if s.saveraw:
                for c in range(len(s.rawrecordings)):
                    for q in range(len(s.rawrecordings[c])):
                        s.rawrecordings[c][q] = array(s.rawrecordings[c][q])
            messageid=s.pc.pack([hostspiketimes, hostspikecells, s.hostlfps, s.conndata, s.stdpconndata, s.weightchanges, s.rawrecordings]) # Create a message ID and store this value
            s.pc.post(host,messageid) # Post this message


    ## Unpack data from all hosts
    if s.rank==0: # Only act on a single host
        s.allspikecells = array([])
        s.allspiketimes = array([])
        s.lfps = zeros((len(s.lfptime),s.nlfps)) # Create an empty array for appending LFP data; first entry is for time
        s.allconnections = [array([]) for i in range(s.nconnpars)] # Store all connections
        s.allconnections[s.nconnpars-1] = zeros((0,s.nreceptors)) # Create an empty array for appending connections
        s.allstdpconndata = zeros((0,3)) # Create an empty array for appending STDP connection data
        if s.usestdp: s.allweightchanges = [] # empty list so weightchanges in this node don't appear twice
        s.totalspikes = 0 # Keep a running tally of the number of spikes
        s.totalconnections = 0 # Total number of connections
        s.totalstdpconns = 0 # Total number of stdp connections
        if s.saveraw: s.allraw = []
        for host in range(s.nhosts): # Loop over hosts
            s.pc.take(host) # Get the last message
            hostdata = s.pc.upkpyobj() # Unpack them
            s.allspiketimes = concatenate((s.allspiketimes, hostdata[0])) # Add spikes from this cell to the list
            s.allspikecells = concatenate((s.allspikecells, hostdata[1])) # Add this cell's ID to the list
            if s.savelfps: s.lfps += array(hostdata[2]) # Sum LFP voltages
            for pp in range(s.nconnpars): s.allconnections[pp] = concatenate((s.allconnections[pp], hostdata[3][pp])) # Append pre/post synapses
            if s.usestdp and len(hostdata[4]): # Using STDP and at least one STDP connection
                s.allstdpconndata = concatenate((s.allstdpconndata, hostdata[4])) # Add data on STDP connections
                for ps in range(len(hostdata[4])): s.allweightchanges.append(hostdata[5][ps]) # "ps" stands for "plastic synapse"
            if s.saveraw:
                for c in range(len(hostdata[6])): s.allraw.append(hostdata[6][c]) # Append cell-by-cell

        s.totalspikes = len(s.allspiketimes) # Keep a running tally of the number of spikes
        s.totalconnections = len(s.allconnections[0]) # Total number of connections
        s.totalstdpconns = len(s.allstdpconndata) # Total number of STDP connections


    # Record background spike data (cliff: only for one node since takes too long to pack for all and just needed for debugging)
    if s.savebackground and s.usebackground:
        s.allbackgroundspikecells=array([])
        s.allbackgroundspiketimes=array([])
        for c in range(len(s.backgroundspikevecs)):
            thesespikes = array(s.backgroundspikevecs[c])
            s.allbackgroundspiketimes = concatenate((s.allbackgroundspiketimes, thesespikes)) # Add spikes from this stimulator to the list
            s.allbackgroundspikecells = concatenate((s.allbackgroundspikecells, c+zeros(len(thesespikes)))) # Add this cell's ID to the list
        s.backgrounddata = transpose(vstack([s.allbackgroundspikecells,s.allbackgroundspiketimes]))
    else: s.backgrounddata = [] # For saving so no error

    if s.saveraw and s.usestims:
        s.allstimspikecells=array([])
        s.allstimspiketimes=array([])
        for c in range(len(s.stimspikevecs)):
            thesespikes = array(s.stimspikevecs[c])
            s.allstimspiketimes = concatenate((s.allstimspiketimes, thesespikes)) # Add spikes from this stimulator to the list
            s.allstimspikecells = concatenate((s.allstimspikecells, c+zeros(len(thesespikes)))) # Add this cell's ID to the list
        s.stimspikedata = transpose(vstack([s.allstimspikecells,s.allstimspiketimes]))
    else: s.stimspikedata = [] # For saving so no error

    gathertime = time()-gatherstart # See how long it took
    if s.rank==0: print(('  Done; gather time = %0.1f s.' % gathertime))
    s.pc.barrier()

    #mindelay = s.pc.allreduce(s.pc.set_maxstep(10), 2) # flag 2 returns minimum value
    #if s.rank==0: print 'Minimum delay (time-step for queue exchange) is ',mindelay


    ## Finalize virtual arm (e.g. close pipes, save data)
    if s.useArm != 'None':
        s.arm.close(s)


    # terminate the server process
    if s.PMdinput == 'Plexon':
        if s.isOriginal == 0:
            s.server.Manager.stop()


    ## Print statistics
    if s.rank == 0:
        print('\nAnalyzing...')
        s.firingrate = float(s.totalspikes)/s.ncells/s.duration*1e3 # Calculate mean firing rate across all cells (Hz)
        s.connspercell = s.totalconnections/float(s.ncells) # Calculate the number of connections per cell
        print(('  Run time: %0.1f s (%i-s sim; %i scale; %i cells; %i workers)' % (s.runtime, s.duration/1e3, s.scale, s.ncells, s.nhosts)))
        print(('  Spikes: %i (%0.2f Hz)' % (s.totalspikes, s.firingrate)))
        print(('  Connections: %i (%i STDP; %0.2f per cell)' % (s.totalconnections, s.totalstdpconns, s.connspercell)))
        print(('  Mean connection distance: %0.2f um' % mean(s.allconnections[2])))
        print(('  Mean connection delay: %0.2f ms' % mean(s.allconnections[3])))


###############################################################################
### Save data
###############################################################################
def saveData():
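    """
    On rank 0, optionally write spikes (and connections, if verbose) to text files,
    and/or save the main simulation variables (spikes, connectivity, weights, LFPs,
    arm trajectories, etc.) to a .mat file named from s.outfilestem and the target id.
    """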
    if s.rank == 0:
        ## Save to txt file (spikes and conn)
        if s.savetxt:
            filename = '../data/m1ms-spk.txt'
            fd = open(filename, "w")
            for c in range(len(s.allspiketimes)):
                print(int(s.allspikecells[c]), s.allspiketimes[c], s.popNamesDic[s.cellnames[int(s.allspikecells[c])]], file=fd)
            fd.close()
            print("[Spikes are stored in", filename, "]")

            if s.verbose:
                filename = 'm1ms-conn.txt'
                fd = open(filename, "w")
                for c in range(len(s.allconnections[0])):
                    print(int(s.allconnections[0][c]), int(s.allconnections[1][c]), s.allconnections[2][c], s.allconnections[3][c], s.allconnections[4][c], file=fd)
                fd.close()
                print("[Connections are stored in", filename, "]")

        ## Save to mat file
        if s.savemat:
            savestart = time() # See how long it takes to save

            # Save simulation code
            filestosave = [] #'main.py', 'shared.py', 'network.py', 'arm.py', 'arminterface.py', 'server.py', 'izhi.py', 'izhi2007.mod', 'stdp.mod', 'nsloc.py', 'nsloc.mod'] # Files to save
            argv = []
            simcode = [argv, filestosave] # Start off with input parameters, if any, and then the list of files being saved
            for f in range(len(filestosave)): # Loop over each file
                fobj = open(filestosave[f]) # Open it for reading
                simcode.append(fobj.readlines()) # Append to list of code to save
                fobj.close() # Close file object

            # Tidy variables
            spikedata = vstack([s.allspikecells,s.allspiketimes]).T # Put spike data together
            connections = vstack([s.allconnections[0],s.allconnections[1]]).T # Put connection data together
            distances = s.allconnections[2] # Pull out distances
            delays = s.allconnections[3] # Pull out delays
            weights = s.allconnections[4] # Pull out weights
            stdpdata = s.allstdpconndata # STDP connection data
            if s.usestims: stimdata = [vstack(s.stimstruct[c][1]).T for c in range(len(s.stimstruct))] # Only pull out vectors, not text, in stimdata

            # Save variables
            info = {'timestamp':datetime.today().strftime("%d %b %Y %H:%M:%S"), 'runtime':s.runtime, 'popnames':s.popnames, 'popEorI':s.popEorI} # Save date, runtime, and input arguments

            targetPos = s.arm.targetPos
            handPosAll = s.arm.handPosAll
            angAll = s.arm.angAll
            motorCmdAll = s.arm.motorCmdAll
            targetidAll = s.arm.targetidAll
            errorAll = s.arm.errorAll
            criticAll = s.arm.criticAll
            if not hasattr(s, 'phase'): s.phase = ''
            s.filename = s.outfilestem+'_target_'+str(s.arm.targetid)+s.phase
            if s.armMinimalSave: # save only data related to arm reaching (for evol alg)
                variablestosave = ['targetPos', 'angAll', 'motorCmdAll', 'errorAll']
            else:
                variablestosave = ['info', 'targetPos', 'angAll', 'motorCmdAll', 'errorAll', 'simcode', 'spikedata', 's.cellpops', 's.cellnames', 's.cellclasses', 's.xlocs', 's.ylocs', 's.zlocs', 'connections', 'distances', 'delays', 'weights', 's.EorI']

            if s.savelfps:
                variablestosave.extend(['s.lfptime', 's.lfps'])
            if s.usestdp:
                variablestosave.extend(['stdpdata', 's.allweightchanges'])
            if s.savebackground:
                variablestosave.extend(['s.backgrounddata'])
            if s.saveraw:
                variablestosave.extend(['s.stimspikedata', 's.allraw'])
            if s.usestims: variablestosave.extend(['stimdata'])
            savecommand = "savemat(s.filename, {"
            for var in range(len(variablestosave)): savecommand += "'" + variablestosave[var].replace('s.','') + "':" + variablestosave[var] + ", " # Create command out of all the variables
            savecommand = savecommand[:-2] + "}, oned_as='column')" # Omit final comma-space and complete command

            print(('Saving output as %s...' % s.filename))
            exec(savecommand) # Actually perform the save

            savetime = time()-savestart # See how long it took to save
            print(('  Done; time = %0.1f s' % savetime))


###############################################################################
### Plot data
###############################################################################
def plotData():
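    """
    On rank 0, generate the requested figures (raster, PETH, connectivity matrix,
    power spectral density, weight changes, 3D architecture) via the analysis module.
    """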
    ## Plotting
    if s.rank == 0:
        if s.plotraster: # Whether or not to plot
            if (s.totalspikes>s.maxspikestoplot):
                disp('  Too many spikes (%i vs. %i)' % (s.totalspikes, s.maxspikestoplot)) # Plot raster, but only if not too many spikes
            elif s.nhosts>1:
                disp('  Plotting raster despite using too many cores (%i)' % s.nhosts)
                analysis.plotraster()#;allspiketimes, allspikecells, EorI, ncells, connspercell, backgroundweight, firingrate, duration)
            else:
                print('Plotting raster...')
                analysis.plotraster()#allspiketimes, allspikecells, EorI, ncells, connspercell, backgroundweight, firingrate, duration)

        if s.plotpeth:
            print('Plotting PETH...')
            analysis.plotPETH()

        if s.plotconn:
            print('Plotting connectivity matrix...')
            analysis.plotconn()

        if s.plotpsd:
            print('Plotting power spectral density')
            analysis.plotpsd()

        if s.plotweightchanges:
            print('Plotting weight changes...')
            analysis.plotweightchanges()
            #analysis.plotmotorpopchanges()

        if s.plot3darch:
            print('Plotting 3d architecture...')
            analysis.plot3darch()

        show(block=False)