'''
written by Robert Egger (robert.egger@tuebingen.mpg.de)
(c) 2013-2015 Max Planck Society
'''
import os
import time
from collections import Sequence
import numpy as np
from cell import PointCell, SpikeTrain
import reader
import writer
from synapse_mapper import SynapseMapper
#import synapse
from neuron import h
class NetworkMapper:
'''
Handles connectivity of presynaptic populations
to multi-compartmental neuron model.
Also handles activity of presynaptic populations.
'''
def __init__(self, postCell, nwParam, simParam=None):
'''
dictionary holding all presynaptic cells
ordered by cell type
self.cells = {}
dictionary holding indices of
all active presynaptic cells
ordered by cell type
self.connected_cells = {}
reference to postsynaptic (multi-compartment) cell model
self.postCell = postCell
network parameter set (i.e., parameters.network.pre)
self.nwParam = nwParam
'''
self.cells = {}
self.connected_cells = {}
self.postCell = postCell
self.nwParam = nwParam
self.simParam = simParam
def create_network(self, synWeightName=None, change=None):
'''
Public interface
Only call this function, it automatically
takes care of setting up the network correctly.
'''
print '***************************'
print 'creating network'
print '***************************'
self._assign_anatomical_synapses()
self._create_presyn_cells()
self._activate_presyn_cells()
self._connect_functional_synapses()
spikeTrainWeights = None
if synWeightName:
spikeTrainWeights, locations = reader.read_synapse_weight_file(synWeightName)
# awkward temporary implementation of prelease change during simulation time window
self._connect_spike_trains(spikeTrainWeights, change)
print '***************************'
print 'network complete!'
print '***************************'
def create_saved_network(self, synWeightName=None):
'''
Public interface
Used for re-creating network from anatomical
and functional connection files
'''
print '***************************'
print 'creating saved network'
print '***************************'
self._assign_anatomical_synapses()
self._create_presyn_cells()
self._activate_presyn_cells()
weights = None
if synWeightName:
weights, locations = reader.read_synapse_weight_file(synWeightName)
self._map_functional_realization(weights)
self._connect_spike_trains(weights)
print '***************************'
print 'network complete!'
print '***************************'
def reconnect_saved_synapses(self, synInfoName, synWeightName=None):
'''
Public interface
used for setting up saved synapse
locations and activation times
'''
print '***************************'
print 'creating saved network and'
print 'activating synapses with saved times'
print '***************************'
weights = None
locations = None
if synWeightName:
weights, locations = reader.read_synapse_weight_file(synWeightName)
synInfo = reader.read_synapse_activation_file(synInfoName)
synTypes = synInfo.keys()
for synType in synTypes:
print 'Creating synapses and activation times for cell type %s' % synType
synParameters = self.nwParam[synType].synapses
for receptorType in synParameters.receptors.keys():
if synParameters.receptors[receptorType].has_key('weightDistribution'):
weightStr = synParameters.receptors[receptorType].weightDistribution
print '\tAttached %s receptor with weight distribution %s' % (receptorType, weightStr)
else:
print '\tAttached %s receptor with weight distribution uniform' % (receptorType)
# for syn in synInfo[synType]:
for i in range(len(synInfo[synType])):
syn = synInfo[synType][i]
synID, secID, ptID, synTimes, somaDist = syn
# print '\tactivating synapse of type %s' % synType
# print '\tsecID: %d' % secID
# print '\tptID: %d' % ptID
# print '\ttimes: %s' % ','.join([str(t) for t in synTimes])
newCell = PointCell(synTimes)
newCell.play()
if not self.cells.has_key(synType):
self.cells[synType] = []
self.cells[synType].append(newCell)
synx = self.postCell.sections[secID].relPts[ptID]
newSyn = self.postCell.add_synapse(secID, ptID, synx, synType)
if weights:
newSyn.weight = weights[synType][synID]
#===========================================================
# testLoc = locations[synType][synID]
# if testLoc[0] != secID or testLoc[1] != ptID:
# errstr = 'secID %d != secID %d --- ptID %d != ptID %d' % (testLoc[0], secID, testLoc[1], ptID)
# raise RuntimeError(errstr)
#===========================================================
else:
for recepStr in synParameters.receptors.keys():
receptor = synParameters.receptors[recepStr]
self._assign_synapse_weights(receptor, recepStr, newSyn)
activate_functional_synapse(newSyn, self.postCell, newCell, synParameters)
print '***************************'
print 'network complete!'
print '***************************'
def reconnect_network(self):
'''
Public interface
used for re-configuring functional connectivity
'''
print '***************************'
print 're-configuring network'
print '***************************'
self._activate_presyn_cells()
self._connect_functional_synapses()
self._connect_spike_trains()
print '***************************'
print 'network complete!'
print '***************************'
def create_functional_realization(self):
'''
Public interface:
used for creating fixed functional connectivity.
Give this functional realization a (somewhat) unique name!!!!
then save it at the same location as the anatomical realization,
and create a network parameter file with the anatomical and
corresponding functional realizations already in it
IMPORTANT: assumes path names to anatomical realization files
work from the working directory! so should be correct relative, or
preferably absolute paths.
Saves parameter file in working directory.
'''
allParam = self.nwParam
self.nwParam = allParam.network
self._assign_anatomical_synapses()
self._create_presyn_cells()
functionalMap = self._create_functional_connectivity_map()
id1 = time.strftime('%Y%m%d-%H%M')
id2 = str(os.getpid())
for synType in functionalMap:
tmpName = self.nwParam[synType].synapses.distributionFile
splitName = tmpName.split('/')
anatomicalID = splitName[-1]
outName = tmpName[:-4]
outName += '_functional_map_%s_%s.con' % (id1, id2)
writer.write_functional_realization_map(outName, functionalMap[synType], anatomicalID)
allParam.network[synType].synapses.connectionFile = outName
paramName = allParam.info.name
paramName += '_functional_map_%s_%s.param' % (id1, id2)
allParam.info.name += '_functional_map_%s_%s' % (id1, id2)
allParam.save(paramName)
def re_init_network(self):
for synType in self.cells.keys():
for cell in self.cells[synType]:
cell.turn_off()
def _assign_anatomical_synapses(self):
'''
Creates anatomical synapses. This should be done first.
'''
for preType in self.nwParam.keys():
print 'mapping anatomical synapse locations for cell type %s' % preType
synapseFName = self.nwParam[preType].synapses.distributionFile
synDist = reader.read_synapse_realization(synapseFName)
mapper = SynapseMapper(self.postCell, synDist)
mapper.map_synapse_realization()
print '---------------------------'
def _create_presyn_cells(self):
'''
Creates presynaptic cells.
Should be done after creating anatomical synapses.
'''
for synType in self.nwParam.keys():
if self.nwParam[synType].celltype == 'pointcell':
nrOfCells = self.nwParam[synType].cellNr
print 'creating %d PointCells for cell type %s' % (nrOfCells, synType)
self.cells[synType] = [PointCell() for n in xrange(nrOfCells)]
elif self.nwParam[synType].celltype == 'spiketrain':
nrOfSyns = len(self.postCell.synapses[synType])
print 'creating %d SpikeTrains for cell type %s' % (nrOfSyns, synType)
self.cells[synType] = [SpikeTrain() for n in xrange(nrOfSyns)]
else:
errstr = 'Spike source \"%s\" for cell type %s not implemented!' % (self.nwParam[synType].celltype, synType)
raise NotImplementedError(errstr)
for receptorType in self.nwParam[synType].synapses.receptors.keys():
if self.nwParam[synType].synapses.receptors[receptorType].has_key('weightDistribution'):
weightStr = self.nwParam[synType].synapses.receptors[receptorType].weightDistribution
print '\tAttached %s receptor with weight distribution %s' % (receptorType, weightStr)
else:
print '\tAttached %s receptor with weight distribution uniform' % (receptorType)
print '---------------------------'
def _activate_presyn_cells(self):
'''
Activates presynaptic cells.
Should be done after creating presynaptic cells.
TODO: PointCells are only useable with one spike currently.
'''
for synType in self.nwParam.keys():
if self.nwParam[synType].celltype == 'pointcell':
nrOfCells = self.nwParam[synType].cellNr
active, = np.where(np.random.uniform(size=nrOfCells) < self.nwParam[synType].activeFrac)
try:
dist = self.nwParam[synType].distribution
except AttributeError:
print 'WARNING: Could not find attribute \"distribution\" for \"pointcell\" of cell type %s.' % synType
print ' Support of \"pointcell\" without this attribute is deprecated.'
dist = 'normal'
if dist == 'normal':
mean = self.nwParam[synType].spikeT
sigma = self.nwParam[synType].spikeWidth
try:
offset = self.nwParam[synType].offset
except AttributeError:
print 'WARNING: Could not find attribute \"offset\" for \"pointcell\" of cell type %s.' % synType
print ' Support of \"pointcell\" without this attribute is deprecated.'
offset = 10.0
spikeTimes = offset + mean + sigma*np.random.randn(len(active))
elif dist == 'uniform':
window = self.nwParam[synType].window
offset = self.nwParam[synType].offset
spikeTimes = offset + window*np.random.rand(len(active))
elif dist == 'lognormal':
mu = self.nwParam[synType].mu
sigma = self.nwParam[synType].sigma
offset = self.nwParam[synType].offset
spikeTimes = offset + np.random.lognormal(mu, sigma, len(active))
else:
errstr = 'Unknown spike time distribution: %s' % dist
raise RuntimeError(errstr)
print 'initializing spike times for cell type %s' % (synType)
for i in range(len(active)):
if spikeTimes[i] < 0.1:
spikeTimes[i] = 0.1
self.cells[synType][active[i]].append(spikeTimes[i])
self.cells[synType][active[i]].play()
# print 'Presynaptic cell %d active at time %.1f' % (i+1, spikeTimes[i])
elif self.nwParam[synType].celltype == 'spiketrain':
interval = self.nwParam[synType].interval
noise = 1.0
start = 0.0
stop = -1.0
nSpikes = None
try:
noise = self.nwParam[synType].noise
start = self.nwParam[synType].start
except AttributeError:
print 'WARNING: Could not find attributes \"noise\" or \"start\" for \"spiketrain\" of cell type %s.' % synType
print ' Support of \"spiketrains\" without these attributes is deprecated.'
# optional argument: nr. of spikes
try:
nSpikes = self.nwParam[synType].nspikes
except AttributeError:
pass
if self.simParam is not None:
stop = self.simParam.tStop
print 'initializing spike trains with mean rate %.2f Hz for cell type %s' % (1000.0/interval, synType)
for cell in self.cells[synType]:
cell.set_interval(interval)
cell.set_noise(noise)
cell.set_start(start)
cell.set_stop(stop)
cell.play(nSpikes)
print '---------------------------'
def _connect_functional_synapses(self):
'''
Connects anatomical synapses to spike
generators (PointCells) according to physiological
and/or anatomical constraints on connectivity
(i.e., convergence of presynaptic cell type)
'''
synapses = self.postCell.synapses
for synType in self.nwParam.keys():
if not self.nwParam[synType].celltype == 'pointcell':
continue
print 'setting up functional connectivity for cell type %s' % (synType)
activeSyn = 0
connectedCells = set()
nrPreCells = len(self.cells[synType])
convergence = self.nwParam[synType].convergence
# array with indices of presynaptic cells connected to postsynaptic cell
connected = []
# if there are synapses there have to be presynaptic neurons...
while not len(connected):
connected, = np.where(np.random.uniform(size=nrPreCells) < convergence)
# array with indices of presynaptic cell assigned to each synapse
# each connected presynaptic cell has at least 1 synapse by definition
if len(synapses[synType]) < len(connected):
# this should not be the anatomical reality, but for completeness...
connectionIndex = np.random.randint(len(connected), size=len(synapses[synType]))
else:
connectionIndex = list(np.random.permutation(len(connected)))
for i in range(len(connected), len(synapses[synType])):
connectionIndex.append(np.random.randint(len(connected)))
for i in range(len(connectionIndex)):
con = connected[connectionIndex[i]]
preSynCell = self.cells[synType][con]
connectedCells.add(con)
syn = synapses[synType][i]
synParameters = self.nwParam[synType].synapses
for recepStr in synParameters.receptors.keys():
receptor = synParameters.receptors[recepStr]
self._assign_synapse_weights(receptor, recepStr, syn)
if preSynCell.is_active():
if not syn.pruned:
activate_functional_synapse(syn, self.postCell, preSynCell, synParameters)
if syn.is_active():
activeSyn += 1
preSynCell._add_synapse_pointer(syn)
self.connected_cells[synType] = connectedCells
print ' connected cells: %d' % len(connectedCells)
print ' active %s synapses: %d' % (synType, activeSyn)
print '---------------------------'
def _create_functional_connectivity_map(self):
'''
Connects anatomical synapses to spike
generators (PointCells) according to physiological
and/or anatomical constraints on connectivity
(i.e., convergence of presynaptic cell type).
Used to create fixed functional realization.
Returns list of functional connections, where
each functional connection is a tuple
(cell type, presynaptic cell index, synapse index).
cell type - string used for indexing point cells and synapses
presynaptic cell index - index of cell in list self.cells[cell type]
synapse index - index of synapse in list self.postCell.synapses[cell type]
'''
# visTest = {} # dict holding (cell type, cell, synapse) pairs for simple visualization test
functionalMap = {}
synapses = self.postCell.synapses
for synType in self.nwParam.keys():
if not self.nwParam[synType].celltype == 'pointcell':
continue
print 'creating functional connectivity map for cell type %s' % (synType)
nrPreCells = len(self.cells[synType])
convergence = self.nwParam[synType].convergence
# array with indices of presynaptic cells connected to postsynaptic cell
connected = []
# if there are synapses there have to be presynaptic neurons...
while not len(connected):
connected, = np.where(np.random.uniform(size=nrPreCells) < convergence)
# array with indices of presynaptic cell assigned to each synapse
# each connected presynaptic cell has at least 1 synapse by definition
if len(synapses[synType]) < len(connected):
# this should not be the anatomical reality, but for completeness...
connectionIndex = np.random.randint(len(connected), size=len(synapses[synType]))
else:
connectionIndex = list(np.random.permutation(len(connected)))
for i in range(len(connected), len(synapses[synType])):
connectionIndex.append(np.random.randint(len(connected)))
for i in range(len(connectionIndex)):
con = connected[connectionIndex[i]]
funCon = (synType, con, i)
if not functionalMap.has_key(synType):
functionalMap[synType] = []
functionalMap[synType].append(funCon)
# if synType not in visTest.keys():
# visTest[synType] = []
# visTest[synType].append((synType, con, i))
# functional_connectivity_visualization(visTest, self.postCell)
return functionalMap
def _map_functional_realization(self, weights=None):
'''
Connects anatomical synapses to spike
generators (PointCells) according to functional
realization file.
'''
# visTest = {} # dict holding (cell type, cell, synapse) pairs for simple visualization test
synapses = self.postCell.synapses
for synType in self.nwParam.keys():
if not self.nwParam[synType].celltype == 'pointcell':
continue
print 'setting up functional connectivity for cell type %s' % (synType)
activeSyn = 0
connectedCells = set()
funcMapName = self.nwParam[synType].synapses.connectionFile
connections, anatomicalID = reader.read_functional_realization_map(funcMapName)
functionalMap = connections[synType]
anatomicalRealizationName = self.nwParam[synType].synapses.distributionFile.split('/')[-1]
if anatomicalID != anatomicalRealizationName:
errstr = 'Functional mapping does not correspond to anatomical realization %s' % anatomicalRealizationName
raise RuntimeError(errstr)
for con in functionalMap:
cellType, cellID, synID = con
if cellType != synType:
errstr = 'Functional map cell type %s does not correspond to synapse type %s' % (cellType, synType)
raise RuntimeError(errstr)
preSynCell = self.cells[synType][cellID]
connectedCells.add(cellID)
# if cellType not in visTest.keys():
# visTest[cellType] = []
# visTest[cellType].append((cellType, cellID, synID))
syn = synapses[synType][synID]
synParameters = self.nwParam[synType].synapses
if weights:
syn.weight = weights[synType][synID]
else:
for recepStr in synParameters.receptors.keys():
receptor = synParameters.receptors[recepStr]
self._assign_synapse_weights(receptor, recepStr, syn)
if preSynCell.is_active():
if not syn.pruned:
activate_functional_synapse(syn, self.postCell, preSynCell, synParameters)
if syn.is_active():
activeSyn += 1
preSynCell._add_synapse_pointer(syn)
self.connected_cells[synType] = connectedCells
print ' connected cells: %d' % len(connectedCells)
print ' active %s synapses: %d' % (synType, activeSyn)
print '---------------------------'
# functional_connectivity_visualization(visTest, self.postCell)
def _connect_spike_trains(self, weights=None, change=None):
'''
Connects spike generators with given
mean spike rate (SpikeTrains) to synapse locations.
All synapses are independent.
'''
synapses = self.postCell.synapses
if change is not None:
tChange, changeParam = change
for synType in self.nwParam.keys():
if not self.nwParam[synType].celltype == 'spiketrain':
continue
print 'activating spike trains for cell type %s' % (synType)
for i in xrange(len(synapses[synType])):
syn = synapses[synType][i]
synParameters = self.nwParam[synType].synapses
preSynCell = self.cells[synType][i]
if weights:
syn.weight = weights[synType][i]
else:
for recepStr in synParameters.receptors.keys():
receptor = synParameters.receptors[recepStr]
self._assign_synapse_weights(receptor, recepStr, syn)
if change is None:
activate_functional_synapse(syn, self.postCell, preSynCell, synParameters)
else:
activate_functional_synapse(syn, self.postCell, preSynCell, synParameters, tChange, changeParam[synType].synapses)
print '---------------------------'
def _assign_synapse_weights(self, receptor, recepStr, syn):
if syn.weight is None:
syn.weight = {}
if not syn.weight.has_key(recepStr):
syn.weight[recepStr] = []
if receptor.has_key("weightDistribution"):
if receptor["weightDistribution"] == "lognormal":
if isinstance(receptor.weight, Sequence):
for i in range(len(receptor.weight)):
mean = receptor.weight[i]
std = mean**2
sigma = np.sqrt(np.log(1+std**2/mean**2))
mu = np.log(mean) - 0.5*sigma**2
gmax = np.random.lognormal(mu, sigma)
syn.weight[recepStr].append(gmax)
#print ' weight[%d] = %.2f' % (i, syn.weight[recepStr][-1])
else:
mean = receptor.weight
std = mean**2
sigma = np.sqrt(np.log(1+std**2/mean**2))
mu = np.log(mean) - 0.5*sigma**2
gmax = np.random.lognormal(mu, sigma)
syn.weight[recepStr].append(gmax)
#print ' weight = %.2f' % (syn.weight[recepStr][-1])
else:
distStr = receptor["weightDistribution"]
errstr = 'Synaptic weight distribution %s not implemented yet!' % distStr
raise NotImplementedError(errstr)
else:
if isinstance(receptor.weight, Sequence):
for i in range(len(receptor.weight)):
syn.weight[recepStr].append(receptor.weight[i])
else:
syn.weight[recepStr].append(receptor.weight)
def activate_functional_synapse(syn, cell, preSynCell, synParameters, tChange=None, synParametersChange=None):
'''Default method to activate single synapse.
Currently, this implementation expects all presynaptic spike
times to be pre-computed; can thus not be used in recurrent
network models at this point.'''
releaseTimes = []
if synParameters.has_key('releaseProb'):
prel = synParameters.releaseProb
if tChange is not None:
prelChange = synParametersChange.releaseProb
for t in preSynCell.spikeTimes:
if tChange is not None:
if t >= tChange:
if np.random.rand() < prelChange:
releaseTimes.append(t)
continue
if np.random.rand() < prel:
releaseTimes.append(t)
else:
releaseTimes = preSynCell.spikeTimes[:]
if not len(releaseTimes):
return
releaseSite = PointCell(releaseTimes)
releaseSite.play()
receptors = synParameters.receptors
syn.activate_hoc_syn(releaseSite, preSynCell, cell, receptors)
# set properties for all receptors here
for recepStr in receptors.keys():
recep = receptors[recepStr]
for param in recep.parameter.keys():
# try treating parameters as hoc range variables,
# then as hoc global variables
try:
paramStr = 'syn.receptors[\'' + recepStr + '\'].'
paramStr += param + '=' + str(recep.parameter[param])
exec(paramStr)
except LookupError:
paramStr = param + '_' + recepStr + '='
paramStr += str(recep.parameter[param])
h(paramStr)
def functional_connectivity_visualization(functionalMap, cell):
nrL4ssCells = 3168
nrL1Cells = 104
L4origin = np.array([-150,-150,0])
# L4colSpacing = np.array([1,0,0])
# L4rowSpacing = np.array([0,30,0])
L4rowSpacing = np.array([1,0,0])
L4colSpacing = np.array([0,30,0])
L1origin = np.array([-550,-150,700])
L1colSpacing = np.array([30,0,0])
L1rowSpacing = np.array([0,30,0])
rows = 10
L4cols = nrL4ssCells//rows
L1cols = nrL1Cells//rows
L4grid = {}
L1grid = {}
for i in range(nrL4ssCells):
# row = i//rows
# col = i - row*L4cols
col = i//L4cols
row = i - col*L4cols
# print 'row = %d' % row
# print 'col = %d' % col
cellPos = L4origin + row*L4rowSpacing + col*L4colSpacing
L4grid[i] = cellPos
for i in range(nrL1Cells):
row = i//rows
col = i - row*L1cols
cellPos = L1origin + row*L1rowSpacing + col*L1colSpacing
L1grid[i] = cellPos
L4map = {}
L1map = {}
for con in functionalMap['L4ssD2']:
cellType, cellID, synID = con
synPos = cell.synapses[cellType][synID].coordinates
if cellID not in L4map.keys():
L4map[cellID] = []
L4map[cellID].append((L4grid[cellID], synPos))
for i in range(nrL4ssCells):
if i not in L4map.keys():
L4map[i] = [(L4grid[i], L4grid[i])]
for con in functionalMap['L1D1']:
cellType, cellID, synID = con
synPos = cell.synapses[cellType][synID].coordinates
if cellID not in L1map.keys():
L1map[cellID] = []
L1map[cellID].append((L1grid[cellID], synPos))
for i in range(nrL1Cells):
if i not in L1map.keys():
L1map[i] = [(L1grid[i], L1grid[i])]
writer.write_functional_map('L4ss_func_map3.am', L4map)
writer.write_functional_map('L1_func_map3.am', L1map)