"""
MODEL
A large-scale network simulation for exploring traveling waves, stimuli,
and STDP. Built solely in Python, using Izhikevich neurons and with MPI
support. Runs in real-time with over 8000 cells when appropriately
parallelized.
M1 model extended to interface with Plexon-recorded PMd data, a virtual arm,
and reinforcement learning.
Usage:
python model.py # Run simulation, optionally plot a raster
python simmovie.py # Show a movie of the results
python model.py scale=20 # Run simulation, set scale=20
MPI usage:
mpiexec -n 4 nrniv -python -mpi model.py
Version: 2014feb21 by cliffk
2014sep19 modified by salvadord and giljael
"""################################################################################## IMPORT MODULES###############################################################################from pylab import seed, rand, sqrt, exp, transpose, ceil, concatenate, array, zeros, ones, vstack, show, disp, mean, inf, concatenate, unique, delete
from pylab import seed, rand, sqrt, exp, transpose, ceil, concatenate, array, zeros, ones, vstack, show, disp, mean, inf, unique, delete
from time import time, sleep
from datetime import datetime
from scipy.io import savemat, loadmat
import pickle
from neuron import h, init, run # Import NEURON
import shared as s # Import all shared variables and parameters
import analysis
from arm import Arm # Class with arm methods and variables
################################################################################
### Sequences of commands to run full model
################################################################################

def _saveTargetErrors(error0, error1):
    """Print and pickle the two per-target test errors plus the derived fitness.

    Intended to be called on rank 0 only. The fitness is the mean error plus the
    absolute difference between the two target errors, so it penalizes networks
    that reach one target much better than the other. The dict (keys 'error0',
    'error1', 'meanError', 'errorFitness') is written to
    '<s.outfilestem>_target_0_error'; s.targetid is reset to 0 first so the file
    name follows that convention.
    """
    errorMean = (error0 + error1) / 2
    errorFitness = errorMean + abs(error0 - error1)  # fitness penalizes difference between target errors
    errorDic = {}
    errorDic['error0'] = error0
    errorDic['error1'] = error1
    errorDic['meanError'] = errorMean
    errorDic['errorFitness'] = errorFitness
    print('Mean error = %.4f ; Mean error + difference (fitness) = %.4f' % (errorMean, errorFitness))
    s.targetid = 0  # so saves to correct file name (error of both targets saved to single file ending in target_0_error)
    with open('%s_target_%d_error' % (s.outfilestem, s.targetid), 'wb') as f:  # save avg error over targets to outfilestem
        pickle.dump(errorDic, f)


def _finishRun(verystart):
    """Release the MPI workers, report total wall time and quit NEURON for non-interactive runs.

    verystart: value of time() taken at the start of the run.
    """
    s.pc.runworker()  # MPI: Start simulations running on each host
    s.pc.done()  # MPI: Close MPI
    totaltime = time() - verystart  # See how long it took in total
    print(('\nDone; total time = %0.1f s.' % totaltime))
    # Quit extra processes, or everything if plotting wasn't requested (since assume non-interactive)
    if not (s.plotraster or s.plotconn or s.plotweightchanges):
        h.quit()


# training and testing to 2 targets manually
def runTrainTest2targets():
    """Train the network, then test it separately on targets 0 and 1 (interactive demo run).

    Sets demo parameter values on the shared module ``s`` and builds the network.
    It then trains with STDP and exploratory movements (RL off) for s.trainTime ms.
    Next it runs one s.testTime ms test per target with plasticity and
    exploration disabled. On rank 0 the mean arm error per target is printed and
    pickled via _saveTargetErrors. Finally the MPI workers are shut down.
    """
    # optimized values for musculoskeletal arm (here using dummy arm for demo purposes)
    # NOTE(review): unlike runTrainTest2targetsOptim, the derived s.RLrates,
    # s.connprobs[PMd,ER5] and s.connweights[PMd,ER5,AMPA] are not recomputed from
    # RLfactor/PMdconnprob/PMdconnweight here -- confirm whether that is intended.
    s.targetid = 0
    s.trainTime = 2000  # 85000 # using 2 sec for demo purposes
    s.stdpwin = 48.5
    s.eligwin = 117.8
    s.RLfactor = 0.01  # 6
    s.RLinterval = 76.8
    s.backgroundrate = s.backgroundrateTest = 134.5
    s.backgroundrateExplor = 5
    s.cmdmaxrate = 528.8
    s.PMdconnweight = 1.0
    s.PMdconnprob = 2.4
    s.useArm = 'dummyArm'  # 'musculoskeletal'
    s.numTrials = ceil(s.trainTime/1000)
    s.trialTargets = [i % 2 for i in range(int(s.numTrials+1))]  # set target for each trial (alternating 0,1,0,...)
    s.targetid = s.trialTargets[0]
    verystart = time()  # store initial time

    # set plotting params
    s.plotraster = 1
    s.plotconn = 0
    s.plotweightchanges = 0
    s.plot3darch = 0
    s.graphsArm = 1
    s.animArm = 1
    s.savemat = 0  # flag: save .mat data during testing (off)
    s.armMinimalSave = 1  # save only arm related data

    # initialize network
    createNetwork()
    addStimulation()
    addBackground()

    # train
    s.usestdp = 1  # Whether or not to use STDP
    s.useRL = 0  # Whether or not to use RL
    s.explorMovs = 1  # enable exploratory movements
    s.antagInh = 0  # antagInh flag off
    s.duration = s.trainTime  # train time
    setupSim()
    runSim()
    finalizeSim()
    #saveData()
    plotData()

    # test target 0
    s.backgroundrate = s.backgroundrateTest  # 300
    s.cmdmaxrate = s.cmdmaxrateTest  # 15
    addBackground()
    s.usestdp = 0  # Whether or not to use STDP
    s.useRL = 0  # Whether or not to use RL
    s.explorMovs = 0  # disable exploratory movements
    s.duration = s.testTime  # testing time
    s.armMinimalSave = 0  # save all data, not only arm related data
    s.targetid = 0
    setupSim()
    runSim()
    finalizeSim()
    #saveData()
    plotData()
    if s.rank == 0:  # save error to file
        error0 = mean(s.arm.errorAll)
        print('Target error for target ', s.targetid, ' is:', error0)
        s.arm.plotTraj(s.outfilestem+'_t0.png')

    # test target 1
    s.targetid = 1
    setupSim()
    runSim()
    finalizeSim()
    saveData()
    plotData()
    if s.rank == 0:  # save error to file
        error1 = mean(s.arm.errorAll)
        print('Target error for target 0=', error0, '; target 1=', error1)
        s.arm.plotTraj(s.outfilestem+'_t1.png')
        _saveTargetErrors(error0, error1)

    ## Wrapping up
    _finishRun(verystart)


# training and testing to 2 targets via evolutionary optim algorithm (batch, no graphics)
def runTrainTest2targetsOptim():
    """Batch (no graphics) train/test on 2 targets, used by the evolutionary optimizer.

    Derives RL rates and PMd->ER5 connectivity from the optimized parameters
    (s.RLfactor, s.PMdconnprob, s.PMdconnweight). It builds the network and
    trains with STDP + RL + exploratory movements for s.trainTime ms. Rank 0
    saves the trajectory/weight-change figures. It then tests each target with
    plasticity off, saving data, trajectory and raster figures, and finally
    pickles the errors/fitness via _saveTargetErrors on rank 0.
    """
    # evol optimizes the following:
    s.RLrates = s.RLfactor*array([[0.25, -0.25], [0.0, 0.0]])  # RL potentiation/depression rates for E->anything and I->anything, e.g. [0,:] is pot/dep for E cells
    s.connprobs[s.PMd, s.ER5] = s.PMdconnprob
    s.connweights[s.PMd, s.ER5, s.AMPA] = s.PMdconnweight
    s.verbose = 0
    s.useArm = 'musculoskeletal'  # s.useArm = 'dummyArm'
    s.numTrials = ceil(s.trainTime/1000)
    s.trialTargets = [i % 2 for i in range(int(s.numTrials+1))]  # set target for each trial (alternating 0,1,0,...)
    s.targetid = s.trialTargets[0]
    verystart = time()  # store initial time

    # set plotting params (batch run: all graphics off)
    s.plotraster = 0
    s.plotconn = 0
    s.plotweightchanges = 0
    s.plot3darch = 0
    s.graphsArm = 0
    s.animArm = 0
    s.savemat = 0  # flag: save .mat data during testing (off while training)
    s.armMinimalSave = 0  # save all data, not only arm related data

    # train
    s.usestdp = 1  # Whether or not to use STDP
    s.useRL = 1  # Whether or not to use RL
    s.explorMovs = 1  # enable exploratory movements
    s.antagInh = 0  # antagInh flag off
    s.duration = s.trainTime  # train time
    s.timebetweensaves = s.trainTime - 1000

    # initialize network
    createNetwork()
    addStimulation()
    addBackground()

    # run train
    setupSim()
    runSim()
    finalizeSim()
    #saveData()
    #plotData()
    if s.rank == 0:  # save png of traj
        s.arm.plotTraj(s.outfilestem+'_train.png')  # save traj fig to file
        analysis.plotweightchanges(s.outfilestem+'_train_weights.png')

    test = 1
    s.savemat = 1
    if test:
        # test target 0
        #s.backgroundrate=s.backgroundrateTest # 300
        #s.cmdmaxrate=s.cmdmaxrateTest # 15
        addBackground()
        s.usestdp = 0  # Whether or not to use STDP
        s.useRL = 0  # Whether or not to use RL
        s.explorMovs = 0  # disable exploratory movements
        s.duration = s.testTime  # testing time
        s.armMinimalSave = 0  # save all data, not only arm related data
        s.targetid = 0
        setupSim()
        runSim()
        finalizeSim()
        saveData()
        #plotData()
        if s.rank == 0:  # save error to file
            error0 = mean(s.arm.errorAll)
            print('Target error for target ', s.targetid, ' is:', error0)
            s.arm.plotTraj(s.outfilestem+'_t0.png')
            analysis.plotraster(s.outfilestem+'_t0_raster.png')

        # test target 1
        s.targetid = 1
        setupSim()
        runSim()
        finalizeSim()
        saveData()
        #plotData()
        if s.rank == 0:  # save error to file
            error1 = mean(s.arm.errorAll)
            print('Target error for target 0=', error0, '; target 1=', error1)
            s.arm.plotTraj(s.outfilestem+'_t1.png')
            analysis.plotraster(s.outfilestem+'_t1_raster.png')
            _saveTargetErrors(error0, error1)

    ## Wrapping up
    _finishRun(verystart)


################################################################################
### Create Network
################################################################################

def _plasticConnections(plastConnsType):
    """Return the [prePop, postPop] pairs made plastic for plasticity type code 0-7.

    Codes 0-3 cumulatively add spinal cord, L2<->L5, L6 and E->I connections.
    Codes 4-7 follow the same sequence with PMd->ER5 plasticity also added.
    Returns None for any other code so the caller leaves s.plastConns untouched.
    """
    spinal = [[s.ASC, s.ER2], [s.EB5, s.EDSC], [s.EB5, s.IDSC]]  # only spinal cord
    pmd = [[s.PMd, s.ER5]]  # PMd->L5A
    l2l5 = [[s.ER2, s.ER5], [s.ER5, s.EB5], [s.ER2, s.EB5], [s.ER5, s.ER2]]  # + L2-L5
    l6 = [[s.ER5, s.ER6], [s.ER6, s.ER5], [s.ER6, s.EB5]]  # + L6
    inh = [[s.ER2, s.IL2], [s.ER2, s.IF2], [s.ER5, s.IL5], [s.ER5, s.IF5], [s.EB5, s.IL5], [s.EB5, s.IF5]]  # + Inh
    extras = [[], l2l5, l2l5 + l6, l2l5 + l6 + inh]
    for offset, base in ((0, spinal), (4, spinal + pmd)):
        for i, extra in enumerate(extras):
            if plastConnsType == offset + i:  # '==' (not indexing) so float codes from the evol alg still match
                return base + extra
    return None


def _addSTDP(pregid, pstgid, pstid, r):
    """Attach an STDP adjuster (optionally RL-modulated) to receptor r of the newest NetCon in s.connlist.

    pregid/pstgid: global ids of the pre/postsynaptic cells; pstid: local index of
    the postsynaptic cell on this host.
    """
    stdpmech = h.STDP(0, sec=s.dummies[pstid])  # Create STDP adjuster
    stdpmech.hebbwt = s.stdprates[s.EorI[pregid], 0]  # Potentiation rate
    stdpmech.antiwt = s.stdprates[s.EorI[pregid], 1]  # Depression rate
    stdpmech.wmax = s.maxweight  # Maximum synaptic weight
    precon = s.pc.gid_connect(pregid, stdpmech); precon.weight[0] = 1  # Send presynaptic spikes to the STDP adjuster
    pstcon = s.pc.gid_connect(pstgid, stdpmech); pstcon.weight[0] = -1  # Send postsynaptic spikes to the STDP adjuster
    h.setpointer(s.connlist[-1]._ref_weight[r], 'synweight', stdpmech)  # Associate the STDP adjuster with this weight
    s.stdpmechs.append(stdpmech)  # Save STDP adjuster
    s.precons.append(precon)  # Save presynaptic spike source
    s.pstcons.append(pstcon)  # Save postsynaptic spike source
    s.stdpconndata.append([pregid, pstgid, r])  # Store presynaptic cell ID, postsynaptic, and receptor
    if s.verbose: stdpmech.verbose = 1
    if s.useRL:  # using RL
        stdpmech.RLon = 1  # make sure RL is on
        stdpmech.RLhebbwt = s.RLrates[s.EorI[pregid], 0]  # Potentiation rate
        stdpmech.RLantiwt = s.RLrates[s.EorI[pregid], 1]  # Depression rate
        stdpmech.tauhebb = stdpmech.tauanti = s.stdpwin  # stdp time constant(ms)
        # RL eligibility trace window length (ms) for both potentiation and depression
        # (previously RLwindhebb was assigned twice and RLwindanti never set)
        stdpmech.RLwindhebb = stdpmech.RLwindanti = s.eligwin
        stdpmech.useRLexp = s.useRLexp  # RL
        stdpmech.softthresh = s.useRLsoft  # RL soft-thresholding
    else:
        stdpmech.RLon = 0  # make sure RL is off


def createNetwork():
    """Create this host's cells, compute their connectivity and instantiate all connections.

    Reads parameters from the shared module ``s`` and stores results back on it:
    - cells and gid bookkeeping: s.cells, s.gidVec, s.gidDic, s.cellsperhost;
    - spike recorders: s.spikerecorders, s.hostspikevecs;
    - random cell positions: s.xlocs, s.ylocs, s.zlocs;
    - per-connection data in s.conndata (pre gid, post gid, distance, delay, weights);
    - the NetCons in s.connlist and, when s.usestdp is set, the STDP mechanisms.
    Cells are distributed round-robin: this host owns gids rank, rank+nhosts, ...
    """
    ## Print diagnostic information
    if s.rank == 0: print(("\nCreating simulation of %i cells for %0.1f s on %i hosts..." % (sum(s.popnumbers), s.duration/1000., s.nhosts)))
    s.pc.barrier()

    ## Create empty data structures
    s.cells = []  # Create empty list for storing cells
    s.dummies = []  # Create empty list for storing fake sections
    s.gidVec = []  # Empty list for storing GIDs (index = local id; value = gid)
    s.gidDic = {}  # Empty dict for storing GIDs (key = gid; value = local id) -- ~x6 faster than gidVec.index()

    ## Set cell types
    celltypes = []
    for c in range(s.ncells):  # Loop over each cell. ncells is all cells in the network.
        cellclass = s.cellclasses[c]
        if cellclass == 1: celltypes.append(s.RS)  # Append a regular spiking pyramidal cell
        elif cellclass == 2: celltypes.append(s.IB)  # Append an intrinsically bursting pyramidal cell
        elif cellclass == 3: celltypes.append(s.CH)  # Append a chattering cell
        elif cellclass == 4: celltypes.append(s.LTS)  # Append a low-threshold spiking interneuron
        elif cellclass == 5: celltypes.append(s.FS)  # Append a fast-spiking interneuron
        # NOTE(review): TC/RTN were tested with '== 4'/'== 5', which LTS/FS above already
        # matched (unreachable); 6 and 7 are assumed -- confirm against shared.py.
        elif cellclass == 6: celltypes.append(s.TC)  # Append a thalamocortical cell
        elif cellclass == 7: celltypes.append(s.RTN)  # Append a reticular thalamic nucleus cell
        elif cellclass == -1: celltypes.append(s.nsloc)  # Append a nsloc
        else: raise Exception('Undefined cell class "%s"' % cellclass)  # No match? Cause an error

    ## Set positions
    seed(s.id32('%d' % s.randseed))  # Reset random number generator
    s.xlocs = s.modelsize*rand(s.ncells)  # Create random x locations
    s.ylocs = s.modelsize*rand(s.ncells)  # Create random y locations
    s.zlocs = rand(s.ncells)  # Create random z locations
    for c in range(s.ncells):
        # map z into the population's [popyfrac[0], popyfrac[1]] band, scaled by corticalthick
        yfrac = s.popyfrac[s.cellpops[c]]
        s.zlocs[c] = s.corticalthick * (s.zlocs[c]*(yfrac[1]-yfrac[0]) + yfrac[0])

    ## Actually create the cells
    s.spikerecorders = []  # Empty list for storing spike-recording Netcons
    s.hostspikevecs = []  # Empty list for storing host-specific spike vectors
    s.cellsperhost = 0
    if s.PMdinput == 'Plexon': ninnclDic = len(s.innclDic)  # number of PMd created in this worker
    for c in range(int(s.rank), s.ncells, s.nhosts):  # round-robin assignment of gids to hosts
        s.dummies.append(h.Section())  # Create fake sections
        gid = c
        if s.cellnames[gid] == 'PMd':
            if s.PMdinput == 'Plexon':
                cell = celltypes[gid](cellid=gid)  # create an NSLOC
                s.inncl.append(h.NetCon(None, cell))  # This netcon receives external spikes
                # NOTE(review): gid < ncells, so this key is negative; if PMd gids occupy the last
                # numPMd slots (as PMdId below assumes) a 0-based index would be
                # gid - (s.ncells - s.server.numPMd) -- confirm against the consumer of innclDic.
                s.innclDic[gid - s.ncells - s.server.numPMd] = ninnclDic  # This dictionary works in case that PMd's gid starts from 0.
                ninnclDic += 1
            elif s.PMdinput == 'targetSplit':
                cell = celltypes[gid](cellid=gid)  # create an NSLOC
                cell.number = s.backgroundnumber
                cell.interval = s.backgroundrateMin**-1*1e3
                cell.noise = s.PMdNoiseRatio
            elif s.PMdinput == 'spikes':
                cell = h.VecStim()
            else:
                cell = celltypes[gid](cellid=gid)  # create an NSLOC
                cell.number = s.backgroundnumber
                cell.interval = s.backgroundrateMin**-1*1e3
        elif s.cellnames[gid] == 'ASC':
            cell = celltypes[gid](cellid=gid)  # create an NSLOC
        elif s.cellclasses[gid] == 3:
            # Don't use LTS cell, but instead a FS cell with a low threshold
            # NOTE(review): class 3 maps to CH in the table above, not LTS -- confirm which is intended.
            cell = s.fastspiking(s.dummies[s.cellsperhost], vt=-47, cellid=gid)
        else:
            cell = celltypes[gid](s.dummies[s.cellsperhost], cellid=gid)  # Create a new cell of the appropriate type (celltypes[gid]) and store it
        s.cells.append(cell)
        s.gidVec.append(gid)  # index = local id; value = global id
        s.gidDic[gid] = s.cellsperhost  # key = global id; value = local id -- used to get local id because gid.index() too slow!
        s.pc.set_gid2node(gid, s.rank)
        spikevec = h.Vector()
        s.hostspikevecs.append(spikevec)
        spikerecorder = h.NetCon(cell, None)
        spikerecorder.record(spikevec)
        s.spikerecorders.append(spikerecorder)
        s.pc.cell(gid, s.spikerecorders[s.cellsperhost])
        s.cellsperhost += 1  # contains cell numbers per host including PMd
    print((' Number of cells on node %i: %i ' % (s.rank, len(s.cells))))
    s.pc.barrier()

    ## Calculate motor command cell ranges so can be used for EDSC and IDSC connectivity
    nCells = s.motorCmdEndCell - s.motorCmdStartCell
    cellsPerMuscle = int(nCells/s.nMuscles)
    s.motorCmdCellRange = [list(range(s.motorCmdStartCell + cellsPerMuscle*i,
                                      s.motorCmdStartCell + cellsPerMuscle*i + cellsPerMuscle))
                           for i in range(s.nMuscles)]  # motor command cells for each muscle

    ## Calculate distances and probabilities
    if s.rank == 0: print(('Calculating connection probabilities (est. time: %i s)...' % (s.performance*s.cellsperhost**2/3e4)))
    conncalcstart = time()  # See how long connecting the cells takes (was s.time(); every other timer uses time())
    s.nconnpars = 5  # Connection parameters: pre- and post- cell ID, weight, distances, delays
    s.conndata = [[] for i in range(s.nconnpars)]  # List for storing connections
    nPostCells = 0
    EDSCpre = []  # to keep track of EB5->EDSC connection and replicate in EB5->IDSC
    for c in range(s.cellsperhost):  # Loop over all postsynaptic cells on this host (has to be postsynaptic because of gid_connect)
        gid = s.gidVec[c]
        if s.cellnames[gid] == 'PMd' or s.cellnames[gid] == 'ASC':
            continue  # There are no presynaptic connections for PMd or ASC.
        nPostCells += 1
        if s.toroidal:
            xpath = (abs(s.xlocs-s.xlocs[gid]))**2
            xpath2 = (s.modelsize-abs(s.xlocs-s.xlocs[gid]))**2
            xpath[xpath2 < xpath] = xpath2[xpath2 < xpath]
            ypath = (abs(s.ylocs-s.ylocs[gid]))**2
            ypath2 = (s.modelsize-abs(s.ylocs-s.ylocs[gid]))**2
            ypath[ypath2 < ypath] = ypath2[ypath2 < ypath]
            zpath = (abs(s.zlocs-s.zlocs[gid]))**2
            distances = sqrt(xpath + ypath)  # Calculate all pairwise distances
            distances3d = sqrt(xpath + ypath + zpath)  # Calculate all pairwise 3d distances
        else:
            distances = sqrt((s.xlocs-s.xlocs[gid])**2 + (s.ylocs-s.ylocs[gid])**2)  # Calculate all pairwise distances
            distances3d = sqrt((s.xlocs-s.xlocs[gid])**2 + (s.ylocs-s.ylocs[gid])**2 + (s.zlocs-s.zlocs[gid])**2)  # Calculate all pairwise 3d distances
        allconnprobs = s.scaleconnprob[s.EorI, s.EorI[gid]] * s.connprobs[s.cellpops, s.cellpops[gid]] * exp(-distances/s.connfalloff[s.EorI])  # Calculate pairwise probabilities
        allconnprobs[gid] = 0  # Prohibit self-connections using the cell's GID
        seed(s.id32('%d' % (s.randseed+gid)))  # Reset random number generator
        allrands = rand(s.ncells)  # Create an array of random numbers for checking each connection
        if s.PMdinput == 'Plexon':
            # PMd entries get rand=1, so a PMd connection is only made where allconnprobs > 1
            for pmdGid in range(s.popGidStart[s.PMd], s.popGidEnd[s.PMd] + 1):
                allrands[pmdGid] = 1
            if s.cellnames[gid] == 'ER5':  # PMd->ER5 conn (full conn)
                PMdId = (gid % s.server.numPMd) + s.ncells - s.server.numPMd  # CHECK THIS!
                allconnprobs[PMdId] = s.connprobs[s.PMd, s.ER5]  # to make this connected to ER5
                allrands[PMdId] = 0  # to make this connect to ER5
                distances[PMdId] = 300
                distances3d[PMdId] = 300  # delays (conndata[3]) are computed from distances3d
        makethisconnection = allconnprobs > allrands  # Perform test to see whether or not this connection should be made
        preids = array(makethisconnection.nonzero()[0], dtype='int')  # Return True elements of that array for presynaptic cell IDs
        if s.PMdinput == 'targetSplit' and s.cellnames[gid] == 'ER5':
            # PMds 0-47 -> ER5 0-47 ; PMds 48-95 -> ER5 48-95 (one PMd input per ER5 cell)
            halfPMd = s.popnumbers[s.PMd] // 2  # integer division: these values become gids
            prePMd = [(gid - s.popGidStart[s.ER5]) % halfPMd + s.popGidStart[s.PMd]]
            if gid >= s.popGidStart[s.ER5] + s.popnumbers[s.ER5]/2:  # second half of ER5 -> second half of PMd
                prePMd = [prePMd[0] + halfPMd]
            if (array(prePMd) <= s.popGidEnd[s.PMd]).all():  # popGidEnd is inclusive (see Plexon loop above)
                preids = concatenate([preids, prePMd])
        if s.cellnames[gid] == 'EDSC':  # save EDSC presyn cells to replicate in IDSC, and add inputs from IDSC
            EDSCpre.append(array(preids))  # save EDSC presyn cells before adding IDSC input
            invPops = [1, 0, 3, 2]  # each postsyn EDSC cell will receive input from all the antagonistic muscle IDSCs
            IDSCpre = [array(s.motorCmdCellRange[invPops[i]]) - s.popGidStart[s.EDSC] + s.popGidStart[s.IDSC]
                       for i in range(s.nMuscles) if gid in s.motorCmdCellRange[i]][0]
            preids = concatenate([preids, IDSCpre])  # add IDSC presynaptic input to EDSC
        elif s.cellnames[gid] == 'IDSC':  # use same presyn cells as for EDSC (antagonistic inhibition)
            preids = array(EDSCpre.pop(0))
        postids = array(gid+zeros(len(preids)), dtype='int')  # Post-synaptic cell IDs
        s.conndata[0].append(preids)  # Append pre-cell ID
        s.conndata[1].append(postids)  # Append post-cell ID
        s.conndata[2].append(distances[preids])  # Distances
        s.conndata[3].append(s.mindelay + distances3d[preids]/float(s.velocity))  # Calculate the delays
        wt1 = s.scaleconnweight[s.EorI[preids], s.EorI[postids]]  # N weight scale factors -- WARNING, might be flipped
        wt2 = s.connweights[s.cellpops[preids], s.cellpops[postids], :]  # NxM inter-population weights
        wt3 = s.receptorweight[:]  # M receptor weights
        finalweights = transpose(wt1*transpose(wt2*wt3))  # Multiply out population weights with receptor weights to get NxM matrix
        s.conndata[4].append(finalweights)  # Append NxM connection weights
    # Turn the per-postsynaptic-cell lists into flat vectors; a host whose cells are all
    # PMd/ASC has nothing to concatenate (concatenate([]) would raise)
    for pp in range(s.nconnpars):
        s.conndata[pp] = array(concatenate(s.conndata[pp])) if nPostCells > 0 else array([])
    s.nconnections = len(s.conndata[0])  # Find out how many connections we're going to make
    conncalctime = time()-conncalcstart  # See how long it took
    if s.rank == 0: print((' Done; time = %0.1f s' % conncalctime))

    # set plastic connections based on plastConnsType (from evol alg)
    plastConns = _plasticConnections(s.plastConnsType)
    if plastConns is not None:
        s.plastConns = plastConns

    ## Actually make connections
    if s.rank == 0: print(('Making connections (est. time: %i s)...' % (s.performance*s.nconnections/9e2)))
    print((' Number of connections on host %i: %i' % (s.rank, s.nconnections)))
    connstart = time()  # See how long connecting the cells takes
    s.connlist = []  # Create array for storing each of the connections
    s.stdpconndata = []  # Store data on STDP connections
    if s.usestdp:  # STDP enabled?
        s.stdpmechs = []  # Initialize array for STDP mechanisms
        s.precons = []  # Initialize array for presynaptic spike counters
        s.pstcons = []  # Initialize array for postsynaptic spike counters
    for con in range(s.nconnections):  # Loop over each connection
        pregid = s.conndata[0][con]  # GID of presynaptic cell
        pstgid = s.conndata[1][con]  # GID of postsynaptic cell
        pstid = s.gidDic[pstgid]  # Index of postsynaptic cell -- convert from GID to local
        newcon = s.pc.gid_connect(pregid, s.cells[pstid])  # Create a connection
        newcon.delay = s.conndata[3][con]  # Set delay
        for r in range(s.nreceptors): newcon.weight[r] = s.conndata[4][con][r]  # Set weight of connection
        s.connlist.append(newcon)  # Connect the two cells
        # If using STDP and these pops are set to be plastic connections
        if s.usestdp and ([s.cellpops[pregid], s.cellpops[pstgid]] in s.plastConns):
            # Don't create an STDP connection if the learning rates are zero
            if sum(abs(s.stdprates[s.EorI[pregid], :])) > 0 or sum(abs(s.RLrates[s.EorI[pregid], :])) > 0:
                for r in range(s.nreceptors):  # Need a different STDP instance for each receptor
                    if newcon.weight[r] > 0:  # Only make them for nonzero connections
                        _addSTDP(pregid, pstgid, pstid, r)
    s.nstdpconns = len(s.stdpconndata)  # Get number of STDP connections
    conntime = time()-connstart  # See how long it took
    if s.usestdp: print((' Number of STDP connections on host %i: %i' % (s.rank, s.nstdpconns)))
    if s.rank == 0: print((' Done; time = %0.1f s' % conntime))
################################################################################
### Add stimulation
################################################################################

def addStimulation():
    """Create the time-varying external stimuli described by s.stimpars for this host's cells.

    Does nothing unless s.usestims is set. For each stimulus type ts, every local
    cell is selected with probability ts.fraction, using a per-gid seeded RNG. A
    selected cell must also belong to one of ts.pops and lie inside the ts.loc box.
    Each selected cell gets a noisy NetStim connected to it. The stimulus weight
    profile (from s.makestim) is played into weight[0] of that NetCon. The profile
    is scaled by (falloff[1]/distance)**2, capped at 20.
    NOTE: ts.loc is scaled by s.modelsize in place, so calling this twice rescales it again.
    """
    if not s.usestims:
        return
    s.stimstruct = []  # For saving
    s.stimrands = []  # Random number generators for the stimulus NetStims
    s.stimsources = []  # Create empty list for storing synapses
    s.stimconns = []  # Create input connections
    s.stimtimevecs = []  # Create array for storing time vectors
    s.stimweightvecs = []  # Create array for holding weight vectors
    if s.saveraw:
        s.stimspikevecs = []  # A list for storing stimulus spike times
        s.stimrecorders = []  # And for recording spikes
    maxweightincrease = 20  # Otherwise could get infinitely high, infinitely close to the stimulus
    for stim in range(len(s.stimpars)):  # Loop over each stimulus type
        ts = s.stimpars[stim]  # Stands for "this stimulus"
        ts.loc = ts.loc * s.modelsize  # scale cell locations to model size
        stimvecs = s.makestim(ts.isi, ts.var, ts.width, ts.weight, ts.sta, ts.fin, ts.shape)  # Time-probability vectors
        s.stimstruct.append([ts.name, stimvecs])  # Store for saving later
        s.stimtimevecs.append(h.Vector().from_python(stimvecs[0]))
        for c in range(s.cellsperhost):
            # Global id of local cell c. Cells are assigned round-robin across hosts
            # (see createNetwork), so cellsperhost*rank + c is only right for one host.
            gid = s.gidVec[c]
            seed(s.id32('%d' % (s.randseed+gid)))  # Reset random number generator for this cell
            if ts.fraction > rand():  # Don't do it for every cell necessarily
                if any(s.cellpops[gid] == ts.pops) and s.xlocs[gid] >= ts.loc[0, 0] and s.xlocs[gid] <= ts.loc[0, 1] and s.ylocs[gid] >= ts.loc[1, 0] and s.ylocs[gid] <= ts.loc[1, 1]:
                    distancefromstimulus = sqrt(sum((array([s.xlocs[gid], s.ylocs[gid]])-s.modelsize*ts.falloff[0])**2))
                    fallofffactor = min(maxweightincrease, (ts.falloff[1]/distancefromstimulus)**2)
                    s.stimweightvecs.append(h.Vector().from_python(stimvecs[1]*fallofffactor))  # Scale by the fall-off factor
                    stimrand = h.Random()
                    stimrand.MCellRan4()  # If everything has the same seed, should happen at the same time
                    stimrand.negexp(1)
                    stimrand.seq(s.id32('%d' % (s.randseed+gid))*1e3)  # Set the sequence i.e. seed
                    s.stimrands.append(stimrand)
                    stimsource = h.NetStim()  # Create a NetStim
                    stimsource.interval = ts.rate**-1*1e3  # Interval between spikes
                    stimsource.number = 1e9  # Number of spikes
                    stimsource.noise = ts.noise  # Fractional noise in timing
                    stimsource.noiseFromRandom(stimrand)  # Set it to use this random number generator
                    s.stimsources.append(stimsource)  # Save this NetStim
                    stimconn = h.NetCon(stimsource, s.cells[c])  # Connect this noisy input to a cell
                    for r in range(s.nreceptors): stimconn.weight[r] = 0  # Initialize weights to 0, otherwise get memory leaks
                    s.stimweightvecs[-1].play(stimconn._ref_weight[0], s.stimtimevecs[-1])  # Play most-recently-added vectors into weight
                    stimconn.delay = s.mindelay  # Specify the delay in ms -- shouldn't make a spot of difference
                    s.stimconns.append(stimconn)  # Save this connection
                    if s.saveraw:  # and c <=100:
                        stimspikevec = h.Vector()  # Initialize vector
                        s.stimspikevecs.append(stimspikevec)  # Keep all those vectors
                        stimrecorder = h.NetCon(stimsource, None)
                        stimrecorder.record(stimspikevec)  # Record stimulus spike times
                        s.stimrecorders.append(stimrecorder)
    print((' Number of stimuli created on host %i: %i' % (s.rank, len(s.stimsources))))
################################################################################
### Add background inputs
################################################################################

def addBackground():
    """Create a noisy background spike source for every local cell except ASC and PMd cells.

    Source types:
    - EDSC/IDSC cells: an NSLOC at s.backgroundrateMin with noise 0.3, weight
      s.backgroundweightExplor.
    - EB5 cells: an NSLOC at s.backgroundrate, weight s.backgroundweightExplor
      when s.explorMovs == 2, otherwise s.backgroundweight[EorI].
    - All other cells: a NetStim at s.backgroundrate driven by its own Random
      stream, weight s.backgroundweight[EorI].

    Only receptor s.backgroundreceptor gets a nonzero weight. Sources, Randoms,
    NetCons and target gids are stored in the s.background* lists, replacing any
    previous ones. Spike times are also recorded when s.savebackground is set.
    Ends with an MPI barrier.
    """
    if s.rank == 0: print('Creating background inputs...')
    s.backgroundsources = []  # Create empty list for storing synapses
    s.backgroundrands = []  # Create random number generators
    s.backgroundconns = []  # Create input connections
    s.backgroundgid = []  # Target cell gid for each input
    if s.savebackground:
        s.backgroundspikevecs = []  # A list for storing background spike times
        s.backgroundrecorders = []  # And for recording spikes
    for c in range(s.cellsperhost):
        gid = s.gidVec[c]
        cellname = s.cellnames[gid]
        if cellname == 'ASC' or cellname == 'PMd':
            continue  # These pops won't receive background stimulations.
        isSpinal = cellname == 'EDSC' or cellname == 'IDSC'
        backgroundrand = h.Random()
        backgroundrand.MCellRan4(gid, gid*2)
        backgroundrand.negexp(1)
        s.backgroundrands.append(backgroundrand)
        # NOTE(review): indentation was lost in this copy of the file; here 'noiseFromRandom'/'noise'
        # for NetStim belong to the else branch and 'number' is applied to every source -- confirm.
        if isSpinal:
            backgroundsource = h.NSLOC()  # Create a NSLOC
            backgroundsource.interval = s.backgroundrateMin**-1*1e3  # Take inverse of the frequency and then convert from Hz^-1 to ms
            backgroundsource.noise = 0.3  # Fractional noise in timing
        elif cellname == 'EB5':
            backgroundsource = h.NSLOC()  # Create a NSLOC
            backgroundsource.interval = s.backgroundrate**-1*1e3  # Take inverse of the frequency and then convert from Hz^-1 to ms
            backgroundsource.noise = s.backgroundnoise  # Fractional noise in timing
        else:
            backgroundsource = h.NetStim()  # Create a NetStim
            backgroundsource.interval = s.backgroundrate**-1*1e3  # Take inverse of the frequency and then convert from Hz^-1 to ms
            backgroundsource.noiseFromRandom(backgroundrand)  # Set it to use this random number generator
            backgroundsource.noise = s.backgroundnoise  # Fractional noise in timing
        backgroundsource.number = s.backgroundnumber  # Number of spikes
        s.backgroundsources.append(backgroundsource)  # Save this source
        s.backgroundgid.append(gid)  # append cell gid associated to this source
        backgroundconn = h.NetCon(backgroundsource, s.cells[c])  # Connect this noisy input to a cell
        for r in range(s.nreceptors): backgroundconn.weight[r] = 0  # Initialize weights to 0, otherwise get memory leaks
        if isSpinal:
            backgroundconn.weight[s.backgroundreceptor] = s.backgroundweightExplor  # Weight for the EDSC and IDSC background input
        elif cellname == 'EB5' and s.explorMovs == 2:
            backgroundconn.weight[s.backgroundreceptor] = s.backgroundweightExplor  # Weight for EB5 input if explor movs via EB5
        else:
            backgroundconn.weight[s.backgroundreceptor] = s.backgroundweight[s.EorI[gid]]  # Specify the weight by E/I class
        backgroundconn.delay = 2  # Specify the delay in ms -- shouldn't make a spot of difference
        s.backgroundconns.append(backgroundconn)  # Save this connection
        if s.savebackground:
            backgroundspikevec = h.Vector()  # Initialize vector
            s.backgroundspikevecs.append(backgroundspikevec)  # Keep all those vectors
            backgroundrecorder = h.NetCon(backgroundsource, None)
            backgroundrecorder.record(backgroundspikevec)  # Record background spike times
            s.backgroundrecorders.append(backgroundrecorder)
    print((' Number created on host %i: %i' % (s.rank, len(s.backgroundsources))))
    s.pc.barrier()
################################################################################## Setup Simulation###############################################################################defsetupSim():
## reset time variables
s.timeoflastRL = -inf # Never RL
s.timeoflastsave = -inf # Never saved
s.timeoflastexplor = -inf # time when last exploratory movement was updated# Initialize STDP -- just for recordingif s.usestdp:
s.weightchanges = []
if s.rank==0: print('\nSetting up STDP...')
if s.usestdp:
s.weightchanges = [[] for ps inrange(s.nstdpconns)] # Create an empty list for each STDP connection -- warning, slow with large numbers of connections!for ps inrange(s.nstdpconns): s.weightchanges[ps].append([0, s.stdpmechs[ps].synweight]) # Time of save (0=initial) and the weight## Set up LFP recording
s.lfptime = [] # List of times that the LFP was recorded at
s.nlfps = len(s.lfppops) # Number of distinct LFPs to calculate
s.hostlfps = [] # Voltages for calculating LFP
s.lfpcellids = [[] for pop inrange(s.nlfps)] # Create list of lists of cell IDsfor c inrange(s.cellsperhost): # Loop over each cell and decide which LFP population, if any, it belongs to
gid = s.gidVec[c] # Get this cell's GIDif s.cellnames[gid] == 'ASC'or s.cellnames[gid] == 'PMd': # 'ER2' won't be fired by background stimulations.continuefor pop inrange(s.nlfps): # Loop over each LFP population
thispop = s.cellpops[gid] # Population of this cellifsum(s.lfppops[pop]==thispop)>0: # There's a match
s.lfpcellids[pop].append(gid) # Flag this cell as belonging to this LFP population## Set up raw recording
s.rawrecordings = [] # A list for storing actual cell voltages (WARNING, slow!)if s.saveraw:
if s.rank==0: print('\nSetting up raw recording...')
s.nquantities = 5# Number of variables from each cell to record from# Later this part should be modified because NSLOC doesn't have V, u and I.for c inrange(s.cellsperhost):
gid = s.gidVec[c] # Get this cell's GIDif s.cellnames[gid] == 'ASC'or s.cellnames[gid] == 'PMd': # NSLOC doesn't have V, u and Icontinue
recvecs = [h.Vector() for q inrange(s.nquantities)] # Initialize vectors
recvecs[0].record(h._ref_t) # Record simulation time
recvecs[1].record(s.cells[c]._ref_V) # Record cell voltage
recvecs[2].record(s.cells[c]._ref_u) # Record cell recovery variable
recvecs[3].record(s.cells[c]._ref_I) # Record cell current
recvecs[4].record(s.cells[c]._ref_gAMPA)
# recvecs[5].record(s.cells[c]._ref_gNMDA)# recvecs[6].record(s.cells[c]._ref_gGABAA)# recvecs[7].record(s.cells[c]._ref_gGABAB)# recvecs[8].record(s.cells[c]._ref_gOpsin)
s.rawrecordings.append(recvecs) # Keep all those vectors## Set up virtual armif s.useArm != 'None':
if s.rank==0: print('\nSetting up virtual arm...')
s.arm = Arm(s.useArm, s.animArm, s.graphsArm)
s.arm.targetid = s.targetid
s.arm.setup(s)#duration, loopstep, RLinterval, pc, scale, popnumbers, p)## Communication setup for plexon inputif s.PMdinput == 'Plexon':
h('''
objref cvode
cvode = new CVode()
tstop = 0
''')
if s.isOriginal == 0: # With communication programif s.rank == 0:
#serverManager = s.server.Manager() # isDp in confis.py = 0
s.server.Manager.start() # launch sever processprint("Server process completed and callback function initalized")
e = s.server.Manager.Event() # Queue callback function in the NEURON queue# Wait for external spikes for PMd from Plexonif s.rank == 0:
if s.server.isCommunication == 1:
s.server.getServerInfo() # show parameters of the server processprint("[Waiting for spikes; run the client on Windows machine...]")
while s.server.queue.empty(): # only Rank 0 is waiting for spikes in the queue.pass
s.pc.barrier() # other workers are waiting here.## Play back raw spikes from recorded PMd neurons, into model PMd populationelif s.PMdinput == 'spikes':
# if s.rank == 0:
rawSpikesPMd = loadmat(s.spikesPMdFile)['pmdData'] # load raw data
numrawcells = len(rawSpikesPMd[0][0])
# implement lesionif s.PMdlesion > 0:
numLesionedCells = int(round(s.PMdlesion*numrawcells))
if s.rank==0: print("Lesioning PMd input... removed spike times of %d (%d%%) cells"%(numLesionedCells, s.PMdlesion*100))
for itarget inrange(len(rawSpikesPMd)):
for itrial inrange(len(rawSpikesPMd[0])):
for icell inrange(numLesionedCells):
rawSpikesPMd[itarget][itrial][numrawcells-1-icell]['spkt'] = [array([])] # remove spike times of lesion % of cells# generate spike vectors based on training time, trial duration, and target presentation (eg. alternating trials)if s.duration == 1e3: # if sim duration=1sec, assume its the test trial and select PMd spikes based on targetid
spktPMd = [rawSpikesPMd[s.targetid][s.repeatSingleTrials[s.targetid]][icell]['spkt'][0] for icell inrange(numrawcells)]
elif s.repeatSingleTrials[0] > -1: # use single trials for each target during training
spktPMd = []
for icell inrange(numrawcells):
spkt = []
[spkt.extend(list(rawSpikesPMd[itarget][s.repeatSingleTrials[itarget]][icell]['spkt'][0] + (1000*itrial))) \
for itrial,itarget inenumerate(s.trialTargets[:-1])] # replicate spike times over trials
spktPMd.append(spkt)
else: # use all available trials for each target during trainingpass# play back PMd spikes using VecStims
s.tvecPMdlist = []
gids = [i for i in s.gidVec if i inrange(s.popGidStart[s.PMd], s.popGidEnd[s.PMd])] # calcualate gids in this nodefor icell in gids: # for each unique cell/vecstim
spkcell = spktPMd[icell%numrawcells]
tvecPMd = h.Vector().from_python(spkcell) # find spikes for that vecstim
s.tvecPMdlist.append(tvecPMd) # store vector to avoid runtime error
s.cells[s.gidDic[icell]].play(tvecPMd) # play back sequence of spikes


###############################################################################
### Run Simulation
###############################################################################
def runSim():
    """Advance the NEURON simulation up to ``s.duration`` ms.

    The run proceeds in slices of ``s.loopstep`` ms. After each slice, on
    every host, this function:

    * prints progress on rank 0 (every ``s.progupdate`` ms when
      ``s.server.simMode == 0``, otherwise after every slice);
    * when ``s.savelfps``: appends ``h.t`` to ``s.lfptime`` and the summed
      voltage ``V`` of the cells listed in each ``s.lfpcellids[pop]`` to
      ``s.hostlfps``;
    * when ``s.usestdp``: every ``s.timebetweensaves`` ms, appends
      ``[time, synweight]`` to ``s.weightchanges[ps]`` for each STDP mechanism
      whose weight differs from its last stored value;
    * when a virtual arm is used: calls ``s.arm.run``; when ``s.useRL`` and
      ``s.RLinterval`` ms have elapsed, rank 0 computes the critic signal,
      broadcasts it to all hosts, and a non-zero critic is passed to
      ``reward_punish`` on every STDP mechanism;
    * in online Plexon mode (``s.server.simMode == 1``): takes one extra
      ``fadvance`` step with ``h.dt`` set so that ``h.t`` reaches
      ``s.server.newCurrTime.value``, then restores the original ``h.dt``.

    On rank 0 stores the wall-clock run time in ``s.runtime``. Ends with a
    barrier across all hosts.
    """
    import math  # local import: only needed for the verbose LFP sanity check

    if s.rank == 0:
        print('\nRunning...')
        runstart = time()  # See how long the run takes (only used on rank 0)

    # Create a CVode instance and turn cache_efficient on
    h('objref cvode')
    h('cvode = new CVode()')
    h.cvode.cache_efficient(1)
    s.pc.set_maxstep(10)  # MPI: Set the maximum integration time in ms -- not very important
    init()  # Initialize the simulation

    while round(h.t) < s.duration:
        run(min(s.duration, h.t + s.loopstep))  # MPI: Get ready to run the simulation (it isn't actually run until pc.runworker() is called I think)

        # Progress report: throttled in offline mode, every slice otherwise
        if s.server.simMode == 0:
            if s.rank == 0 and (round(h.t) % s.progupdate) == 0:
                print((' t = %0.1f s (%i%%; time consumed: %0.1f s)' % (h.t/1e3, int(h.t/s.duration*100), (time()-runstart))))
        else:
            if s.rank == 0:
                print((' t = %0.1f s (%i%%; time consumed: %0.1f s)' % (h.t/1e3, int(h.t/s.duration*100), (time()-runstart))))

        # Calculate LFP -- WARNING, need to think about how to optimize
        if s.savelfps:
            s.lfptime.append(h.t)  # Append current time
            tmplfps = zeros((s.nlfps))  # Create empty array for storing LFP voltages
            for pop in range(s.nlfps):
                for gid in s.lfpcellids[pop]:
                    idx = s.gidDic[gid]  # Convert GID to the local cell index
                    tmplfps[pop] += s.cells[idx].V  # Add voltage to LFP estimate
                if s.verbose and not math.isfinite(tmplfps[pop]):
                    print("Nan or inf")
            s.hostlfps.append(tmplfps)  # Add voltages

        # Periodic weight saves
        if s.usestdp:
            timesincelastsave = h.t - s.timeoflastsave
            if timesincelastsave >= s.timebetweensaves:
                s.timeoflastsave = h.t
                for ps in range(s.nstdpconns):
                    # Only store connections that changed; [ps] = this connection; [-1] = last entry; [-1] = weight
                    if s.stdpmechs[ps].synweight != s.weightchanges[ps][-1][-1]:
                        s.weightchanges[ps].append([s.timeoflastsave, s.stdpmechs[ps].synweight])

        ## Virtual arm
        if s.useArm != 'None':
            s.arm.run(h.t, s)  # run virtual arm apparatus (calculate command, move arm, feedback)
            if s.useRL and (h.t - s.timeoflastRL >= s.RLinterval):  # if time for next RL
                s.timeoflastRL = h.t
                vec = h.Vector()
                if s.rank == 0:
                    critic = s.arm.RLcritic(h.t)  # get critic signal (-1, 0 or 1)
                    # convert python list to hoc vector for broadcast
                    s.pc.broadcast(vec.from_python([critic]), 0)
                else:  # other workers receive rank 0's critic
                    s.pc.broadcast(vec, 0)
                    critic = vec.to_python()[0]
                if critic != 0:  # if critic signal indicates punishment (-1) or reward (+1)
                    for stdp in s.stdpmechs:  # for all connections in stdp conn list
                        stdp.reward_punish(float(critic))  # run stdp.mod method to update syn weights based on RL

        ## Time adjustment for online mode simulation
        if s.PMdinput == 'Plexon' and s.server.simMode == 1:
            # To avoid izhi cell's over shooting when h.t moves forward because sim is slow.
            for c in range(s.cellsperhost):
                gid = s.gidVec[c]
                if s.cellnames[gid] == 'PMd':  # PMds don't have t0 variable.
                    continue
                s.cells[c].t0 = s.server.newCurrTime.value - h.dt
            dtSave = h.dt  # save original dt
            h.dt = s.server.newCurrTime.value - h.t  # new dt
            active = h.cvode.active()
            if active != 0:
                h.cvode.active(0)
            h.fadvance()  # Integrate with new dt
            if active != 0:
                h.cvode.active(1)
            h.dt = dtSave  # Restore original dt

    if s.rank == 0:
        s.runtime = time()-runstart  # See how long it took
        print((' Done; run time = %0.1f s; real-time ratio: %0.2f.' % (s.runtime, s.duration/1000/s.runtime)))
    s.pc.barrier()  # Wait for all hosts to get to this point


###############################################################################
### Finalize Simulation (gather data from nodes, etc.)
###############################################################################
def finalizeSim():
    """Gather results from all hosts onto rank 0 and report statistics.

    Each host posts a message keyed by its own rank. The message holds these
    items in order:

    0. spike times
    1. spiking cell GIDs
    2. ``s.hostlfps``
    3. ``s.conndata``
    4. ``s.stdpconndata``
    5. ``s.weightchanges``
    6. ``s.rawrecordings``, with each recording converted to an array

    Rank 0 takes every host's message and builds ``s.allspiketimes``,
    ``s.allspikecells``, ``s.lfps`` (summed), ``s.allconnections``,
    ``s.allstdpconndata``, ``s.allweightchanges`` (when STDP is used) and
    ``s.allraw`` (when raw recordings are saved). It also sets the
    ``s.totalspikes``, ``s.totalconnections`` and ``s.totalstdpconns``
    counts. It builds ``s.backgrounddata`` and ``s.stimspikedata`` from its
    own local vectors.

    Afterwards the virtual arm is closed and, in Plexon mode with the
    communication program, the server process is stopped. Rank 0 then prints
    summary statistics; that printout relies on ``s.runtime`` having been
    set by ``runSim``.
    """
    ## Pack data from this host
    if s.rank == 0: print('\nGathering spikes...')
    gatherstart = time()  # See how long gathering takes

    hostspikecells = array([])
    hostspiketimes = array([])
    for c, spikevec in enumerate(s.hostspikevecs):
        thesespikes = array(spikevec.to_python())  # Convert spike times to an array
        hostspiketimes = concatenate((hostspiketimes, thesespikes))  # Add spikes from this cell to the list
        hostspikecells = concatenate((hostspikecells, s.gidVec[c]*ones(len(thesespikes))))  # Add this cell's GID once per spike
    if s.saveraw:
        for recvecs in s.rawrecordings:
            for q in range(len(recvecs)):
                recvecs[q] = array(recvecs[q])
    messageid = s.pc.pack([hostspiketimes, hostspikecells, s.hostlfps, s.conndata, s.stdpconndata, s.weightchanges, s.rawrecordings])  # Create a message ID and store this value
    s.pc.post(s.rank, messageid)  # Post this host's message under its own rank

    ## Unpack data from all hosts
    if s.rank == 0:
        s.allspikecells = array([])
        s.allspiketimes = array([])
        s.lfps = zeros((len(s.lfptime), s.nlfps))  # Create an empty array for summing LFP data
        s.allconnections = [array([]) for _ in range(s.nconnpars)]  # Store all connections
        s.allconnections[s.nconnpars-1] = zeros((0, s.nreceptors))  # Create an empty array for appending connections
        s.allstdpconndata = zeros((0, 3))  # Create an empty array for appending STDP connection data
        if s.usestdp: s.allweightchanges = []  # empty list so weightchanges in this node don't appear twice
        if s.saveraw: s.allraw = []
        for host in range(s.nhosts):
            s.pc.take(host)  # Get the last message
            hostdata = s.pc.upkpyobj()  # Unpack it
            s.allspiketimes = concatenate((s.allspiketimes, hostdata[0]))  # Add spikes from this host to the list
            s.allspikecells = concatenate((s.allspikecells, hostdata[1]))  # Add the spiking cells' IDs to the list
            if s.savelfps: s.lfps += array(hostdata[2])  # Sum LFP voltages
            for pp in range(s.nconnpars):
                s.allconnections[pp] = concatenate((s.allconnections[pp], hostdata[3][pp]))  # Append pre/post synapses
            if s.usestdp and len(hostdata[4]):  # Using STDP and at least one STDP connection
                s.allstdpconndata = concatenate((s.allstdpconndata, hostdata[4]))  # Add data on STDP connections
                for ps in range(len(hostdata[4])):  # "ps" stands for "plastic synapse"
                    s.allweightchanges.append(hostdata[5][ps])
            if s.saveraw:
                s.allraw.extend(hostdata[6])  # Append cell-by-cell
        s.totalspikes = len(s.allspiketimes)  # Total number of spikes
        s.totalconnections = len(s.allconnections[0])  # Total number of connections
        s.totalstdpconns = len(s.allstdpconndata)  # Total number of STDP connections

        # Record background spike data (cliff: only for one node since takes too long to pack for all and just needed for debugging)
        if s.savebackground and s.usebackground:
            s.allbackgroundspikecells = array([])
            s.allbackgroundspiketimes = array([])
            for c in range(len(s.backgroundspikevecs)):
                thesespikes = array(s.backgroundspikevecs[c])
                s.allbackgroundspiketimes = concatenate((s.allbackgroundspiketimes, thesespikes))  # Add spikes from this stimulator to the list
                s.allbackgroundspikecells = concatenate((s.allbackgroundspikecells, c+zeros(len(thesespikes))))  # Add this stimulator's ID to the list
            s.backgrounddata = transpose(vstack([s.allbackgroundspikecells, s.allbackgroundspiketimes]))
        else:
            s.backgrounddata = []  # For saving so no error
        if s.saveraw and s.usestims:
            s.allstimspikecells = array([])
            s.allstimspiketimes = array([])
            for c in range(len(s.stimspikevecs)):
                thesespikes = array(s.stimspikevecs[c])
                s.allstimspiketimes = concatenate((s.allstimspiketimes, thesespikes))  # Add spikes from this stimulator to the list
                s.allstimspikecells = concatenate((s.allstimspikecells, c+zeros(len(thesespikes))))  # Add this stimulator's ID to the list
            s.stimspikedata = transpose(vstack([s.allstimspikecells, s.allstimspiketimes]))
        else:
            s.stimspikedata = []  # For saving so no error

    gathertime = time()-gatherstart  # See how long it took
    if s.rank == 0: print((' Done; gather time = %0.1f s.' % gathertime))
    s.pc.barrier()

    ## Finalize virtual arm (e.g. close pipes, save data)
    if s.useArm != 'None':
        s.arm.close(s)

    # terminate the server process
    if s.PMdinput == 'Plexon' and s.isOriginal == 0:
        s.server.Manager.stop()

    ## Print statistics
    if s.rank == 0:
        print('\nAnalyzing...')
        s.firingrate = float(s.totalspikes)/s.ncells/s.duration*1e3  # Mean firing rate in Hz
        s.connspercell = s.totalconnections/float(s.ncells)  # Calculate the number of connections per cell
        print((' Run time: %0.1f s (%i-s sim; %i scale; %i cells; %i workers)' % (s.runtime, s.duration/1e3, s.scale, s.ncells, s.nhosts)))
        print((' Spikes: %i (%0.2f Hz)' % (s.totalspikes, s.firingrate)))
        print((' Connections: %i (%i STDP; %0.2f per cell)' % (s.totalconnections, s.totalstdpconns, s.connspercell)))
        print((' Mean connection distance: %0.2f um' % mean(s.allconnections[2])))
        print((' Mean connection delay: %0.2f ms' % mean(s.allconnections[3])))
###############################################################################
### Save data
###############################################################################
def saveData():
    """Write simulation results to disk (rank 0 only).

    When ``s.savetxt`` is set, writes one line per spike
    (``gid time population``) to ``../data/m1ms-spk.txt``. When
    ``s.verbose`` is also set, writes one line per connection to
    ``m1ms-conn.txt``.

    When ``s.savemat`` is set, writes a MATLAB file named
    ``s.outfilestem + '_target_' + <arm target id> + s.phase``.
    ``s.phase`` defaults to ``''`` when it is not set. With
    ``s.armMinimalSave`` only the arm-reaching variables are stored.
    Otherwise spikes, cell metadata, connectivity and any optional LFP,
    STDP, background, raw and stimulus data are stored too. Keys match
    the variable names with any ``s.`` prefix dropped.

    Relies on ``s.arm`` existing and, for the full save, ``s.runtime``
    being set by ``runSim``.
    """
    if s.rank != 0:
        return

    ## Save to txt file (spikes and conn)
    if s.savetxt:
        filename = '../data/m1ms-spk.txt'
        with open(filename, "w") as fd:
            for cell, spiketime in zip(s.allspikecells, s.allspiketimes):
                print(int(cell), spiketime, s.popNamesDic[s.cellnames[int(cell)]], file=fd)
        print("[Spikes are stored in", filename, "]")
        if s.verbose:
            filename = 'm1ms-conn.txt'
            with open(filename, "w") as fd:
                for c in range(len(s.allconnections[0])):
                    print(int(s.allconnections[0][c]), int(s.allconnections[1][c]), s.allconnections[2][c], s.allconnections[3][c], s.allconnections[4][c], file=fd)
            print("[Connections are stored in", filename, "]")

    ## Save to mat file
    if s.savemat:
        savestart = time()  # See how long it takes to save

        # Save simulation code: input parameters, the list of files, then each file's lines
        filestosave = []  #'main.py', 'shared.py', 'network.py', 'arm.py', 'arminterface.py', 'server.py', 'izhi.py', 'izhi2007.mod', 'stdp.mod', 'nsloc.py', 'nsloc.mod'] # Files to save
        argv = []
        simcode = [argv, filestosave]
        for fname in filestosave:
            with open(fname) as fobj:
                simcode.append(fobj.readlines())

        # Tidy variables
        spikedata = vstack([s.allspikecells, s.allspiketimes]).T  # Put spike data together
        connections = vstack([s.allconnections[0], s.allconnections[1]]).T  # Put connection data together
        info = {'timestamp': datetime.today().strftime("%d %b %Y %H:%M:%S"), 'runtime': s.runtime, 'popnames': s.popnames, 'popEorI': s.popEorI}  # Save date, runtime, and population info

        if not hasattr(s, 'phase'): s.phase = ''
        s.filename = s.outfilestem+'_target_'+str(s.arm.targetid)+s.phase

        armdata = {
            'targetPos': s.arm.targetPos,
            'angAll': s.arm.angAll,
            'motorCmdAll': s.arm.motorCmdAll,
            'errorAll': s.arm.errorAll,
        }
        if s.armMinimalSave:  # save only data related to arm reaching (for evol alg)
            tosave = armdata
        else:
            tosave = {'info': info}
            tosave.update(armdata)
            tosave.update({
                'simcode': simcode,
                'spikedata': spikedata,
                'cellpops': s.cellpops,
                'cellnames': s.cellnames,
                'cellclasses': s.cellclasses,
                'xlocs': s.xlocs,
                'ylocs': s.ylocs,
                'zlocs': s.zlocs,
                'connections': connections,
                'distances': s.allconnections[2],
                'delays': s.allconnections[3],
                'weights': s.allconnections[4],
                'EorI': s.EorI,
            })
            if s.savelfps:
                tosave['lfptime'] = s.lfptime
                tosave['lfps'] = s.lfps
            if s.usestdp:
                tosave['stdpdata'] = s.allstdpconndata
                tosave['allweightchanges'] = s.allweightchanges
            if s.savebackground:
                tosave['backgrounddata'] = s.backgrounddata
            if s.saveraw:
                tosave['stimspikedata'] = s.stimspikedata
                tosave['allraw'] = s.allraw
            if s.usestims:
                # Only pull out vectors, not text, in stimdata
                tosave['stimdata'] = [vstack(s.stimstruct[c][1]).T for c in range(len(s.stimstruct))]

        print(('Saving output as %s...' % s.filename))
        savemat(s.filename, tosave, oned_as='column')  # Actually perform the save
        savetime = time()-savestart  # See how long it took to save
        print((' Done; time = %0.1f s' % savetime))
###############################################################################
### Plot data
###############################################################################
def plotData():
    """Draw the figures requested by the ``s.plot*`` flags (rank 0 only).

    The raster is skipped when ``s.totalspikes`` exceeds
    ``s.maxspikestoplot``. Otherwise it is drawn, with a notice when more
    than one host was used. The PETH, connectivity matrix, power spectral
    density, weight changes and 3D architecture plots are delegated to
    ``analysis``. The call ends with a non-blocking ``show``.
    """
    if s.rank != 0:
        return

    if s.plotraster:  # Whether or not to plot
        if s.totalspikes > s.maxspikestoplot:
            # Plot raster, but only if not too many spikes
            print(' Too many spikes (%i vs. %i)' % (s.totalspikes, s.maxspikestoplot))
        elif s.nhosts > 1:
            print(' Plotting raster despite using too many cores (%i)' % s.nhosts)
            analysis.plotraster()
        else:
            print('Plotting raster...')
            analysis.plotraster()
    if s.plotpeth:
        print('Plotting PETH...')
        analysis.plotPETH()
    if s.plotconn:
        print('Plotting connectivity matrix...')
        analysis.plotconn()
    if s.plotpsd:
        print('Plotting power spectral density')
        analysis.plotpsd()
    if s.plotweightchanges:
        print('Plotting weight changes...')
        analysis.plotweightchanges()
    if s.plot3darch:
        print('Plotting 3d architecture...')
        analysis.plot3darch()
    show(block=False)