'''

written by Robert Egger (robert.egger@tuebingen.mpg.de)
(c) 2013-2015 Max Planck Society
'''

import os
import time
from collections import Sequence
import numpy as np
from cell import PointCell, SpikeTrain
import reader
import writer
from synapse_mapper import SynapseMapper
#import synapse
from neuron import h

class NetworkMapper:
    '''
    Handles connectivity of presynaptic populations
    to multi-compartmental neuron model.
    Also handles activity of presynaptic populations.
    '''
    
    def __init__(self, postCell, nwParam, simParam=None):
        '''
        dictionary holding all presynaptic cells
        ordered by cell type
        self.cells = {}
        
        dictionary holding indices of
        all active presynaptic cells
        ordered by cell type
        self.connected_cells = {}
        
        reference to postsynaptic (multi-compartment) cell model
        self.postCell = postCell
        
        network parameter set (i.e., parameters.network.pre)
        self.nwParam = nwParam
        '''
        self.cells = {}
        self.connected_cells = {}
        self.postCell = postCell
        self.nwParam = nwParam
        self.simParam = simParam
    
    def create_network(self, synWeightName=None, change=None):
        '''
        Public interface
        Only call this function, it automatically
        takes care of setting up the network correctly.
        
        synWeightName - optional name of a synapse weight file; if given,
            the weights read from it are used for spike train synapses
        change - optional pair (tChange, changeParam) passed on to
            _connect_spike_trains (change of release probability
            during the simulation time window)
        '''
        banner = '***************************'
        print(banner)
        print('creating network')
        print(banner)
        # the order of these steps matters: anatomical synapses first,
        # then presynaptic cells, their activity, and finally connectivity
        setupSteps = (
            self._assign_anatomical_synapses,
            self._create_presyn_cells,
            self._activate_presyn_cells,
            self._connect_functional_synapses,
        )
        for setupStep in setupSteps:
            setupStep()
        if synWeightName:
            # synapse locations stored in the weight file are not needed here
            spikeTrainWeights, _locations = reader.read_synapse_weight_file(synWeightName)
        else:
            spikeTrainWeights = None
        # awkward temporary implementation of prelease change during simulation time window
        self._connect_spike_trains(spikeTrainWeights, change)
        print(banner)
        print('network complete!')
        print(banner)
    
    def create_saved_network(self, synWeightName=None):
        '''
        Public interface
        Used for re-creating network from anatomical
        and functional connection files
        
        synWeightName - optional name of a synapse weight file; if given,
            the weights read from it are used instead of weights
            drawn from the receptor parameters
        '''
        banner = '***************************'
        print(banner)
        print('creating saved network')
        print(banner)
        self._assign_anatomical_synapses()
        self._create_presyn_cells()
        self._activate_presyn_cells()
        if synWeightName:
            # synapse locations stored in the weight file are not needed here
            weights, _locations = reader.read_synapse_weight_file(synWeightName)
        else:
            weights = None
        # the same weights are used for PointCell and SpikeTrain synapses
        self._map_functional_realization(weights)
        self._connect_spike_trains(weights)
        print(banner)
        print('network complete!')
        print(banner)
    
    def reconnect_saved_synapses(self, synInfoName, synWeightName=None):
        '''
        Public interface
        used for setting up saved synapse
        locations and activation times
        
        synInfoName - name of synapse activation file; each entry read from
            it is unpacked as (synID, secID, ptID, synTimes, somaDist)
        synWeightName - optional name of a synapse weight file; if given,
            saved weights are used instead of weights drawn from the
            receptor parameters
        
        For every saved synapse a new PointCell playing the saved
        activation times and a new synapse on the postsynaptic cell
        are created.
        '''
        banner = '***************************'
        print(banner)
        print('creating saved network and')
        print('activating synapses with saved times')
        print(banner)
        if synWeightName:
            # synapse locations stored in the weight file are not needed here
            weights, _locations = reader.read_synapse_weight_file(synWeightName)
        else:
            weights = None
        synInfo = reader.read_synapse_activation_file(synInfoName)
        for synType in synInfo.keys():
            print('Creating synapses and activation times for cell type %s' % synType)
            synParameters = self.nwParam[synType].synapses
            receptors = synParameters.receptors
            for receptorType in receptors.keys():
                if receptors[receptorType].has_key('weightDistribution'):
                    weightStr = receptors[receptorType].weightDistribution
                    print('\tAttached %s receptor with weight distribution %s' % (receptorType, weightStr))
                else:
                    print('\tAttached %s receptor with weight distribution uniform' % (receptorType))
            savedSynapses = synInfo[synType]
            for n in range(len(savedSynapses)):
                # somaDist is part of the saved record but not used here
                synID, secID, ptID, synTimes, somaDist = savedSynapses[n]
                # presynaptic cell replaying the saved activation times
                releaseCell = PointCell(synTimes)
                releaseCell.play()
                self.cells.setdefault(synType, []).append(releaseCell)
                relativePosition = self.postCell.sections[secID].relPts[ptID]
                newSyn = self.postCell.add_synapse(secID, ptID, relativePosition, synType)
                if not weights:
                    # no saved weights: assign weights from the receptor parameters
                    for recepStr in receptors.keys():
                        self._assign_synapse_weights(receptors[recepStr], recepStr, newSyn)
                else:
                    newSyn.weight = weights[synType][synID]
                activate_functional_synapse(newSyn, self.postCell, releaseCell, synParameters)
        print(banner)
        print('network complete!')
        print(banner)
    
    def reconnect_network(self):
        '''
        Public interface
        used for re-configuring functional connectivity
        
        Re-runs activation of the presynaptic cells and both
        connection steps; anatomical synapses and presynaptic
        cells are not created again.
        '''
        banner = '***************************'
        print(banner)
        print('re-configuring network')
        print(banner)
        for reconnectStep in (self._activate_presyn_cells,
                              self._connect_functional_synapses,
                              self._connect_spike_trains):
            reconnectStep()
        print(banner)
        print('network complete!')
        print(banner)
    
    def create_functional_realization(self):
        '''
        Public interface:
        used for creating fixed functional connectivity.
        
        Gives this functional realization a (somewhat) unique name
        (time stamp and process ID), saves it at the same location
        as the anatomical realization, and creates a network parameter
        file with the anatomical and corresponding functional
        realizations already in it.
        IMPORTANT: assumes path names to anatomical realization files
        work from the working directory! so should be correct relative, or
        preferably absolute paths.
        Saves parameter file in working directory.
        
        NOTE: expects self.nwParam to be the complete parameter set
        (with attributes network and info); self.nwParam is replaced
        by its network part as a side effect.
        '''
        allParam = self.nwParam
        self.nwParam = allParam.network
        self._assign_anatomical_synapses()
        self._create_presyn_cells()
        functionalMap = self._create_functional_connectivity_map()
        # (somewhat) unique identifier shared by all files written below
        uniqueSuffix = '_functional_map_%s_%s' % (time.strftime('%Y%m%d-%H%M'), str(os.getpid()))
        for synType in functionalMap:
            distributionName = self.nwParam[synType].synapses.distributionFile
            # file name without directories identifies the anatomical realization
            anatomicalID = distributionName.split('/')[-1]
            # strip the actual file extension instead of assuming
            # that it is exactly four characters long
            outName = os.path.splitext(distributionName)[0] + uniqueSuffix + '.con'
            writer.write_functional_realization_map(outName, functionalMap[synType], anatomicalID)
            allParam.network[synType].synapses.connectionFile = outName
        paramName = allParam.info.name + uniqueSuffix + '.param'
        allParam.info.name += uniqueSuffix
        allParam.save(paramName)
    
    def re_init_network(self):
        for synType in self.cells.keys():
            for cell in self.cells[synType]:
                cell.turn_off()
    
    def _assign_anatomical_synapses(self):
        '''
        Creates anatomical synapses. This should be done first.
        
        For every presynaptic cell type in the network parameters,
        the synapse realization file (synapses.distributionFile) is
        read and mapped onto the postsynaptic cell by a SynapseMapper.
        '''
        for preType in self.nwParam.keys():
            print('mapping anatomical synapse locations for cell type %s' % preType)
            distributionName = self.nwParam[preType].synapses.distributionFile
            synapseDistribution = reader.read_synapse_realization(distributionName)
            SynapseMapper(self.postCell, synapseDistribution).map_synapse_realization()
        print('---------------------------')
    
    def _create_presyn_cells(self):
        '''
        Creates presynaptic cells.
        Should be done after creating anatomical synapses.
        
        'pointcell': creates cellNr PointCells for the cell type.
        'spiketrain': creates one SpikeTrain per synapse of that type
        on the postsynaptic cell.
        Any other spike source raises NotImplementedError.
        '''
        for synType in self.nwParam.keys():
            typeParam = self.nwParam[synType]
            spikeSource = typeParam.celltype
            if spikeSource == 'pointcell':
                nrOfCells = typeParam.cellNr
                print('creating %d PointCells for cell type %s' % (nrOfCells, synType))
                self.cells[synType] = [PointCell() for _ in range(nrOfCells)]
            elif spikeSource == 'spiketrain':
                # spike trains are independent: one generator per synapse
                nrOfSyns = len(self.postCell.synapses[synType])
                print('creating %d SpikeTrains for cell type %s' % (nrOfSyns, synType))
                self.cells[synType] = [SpikeTrain() for _ in range(nrOfSyns)]
            else:
                errstr = 'Spike source \"%s\" for cell type %s not implemented!' % (typeParam.celltype, synType)
                raise NotImplementedError(errstr)
            # report the receptors configured for this cell type
            receptors = typeParam.synapses.receptors
            for receptorType in receptors.keys():
                if receptors[receptorType].has_key('weightDistribution'):
                    weightStr = receptors[receptorType].weightDistribution
                    print('\tAttached %s receptor with weight distribution %s' % (receptorType, weightStr))
                else:
                    print('\tAttached %s receptor with weight distribution uniform' % (receptorType))
        print('---------------------------')
    
    def _activate_presyn_cells(self):
        '''
        Activates presynaptic cells.
        Should be done after creating presynaptic cells.
        TODO: PointCells are only useable with one spike currently.
        
        'pointcell': a random subset of cells (probability activeFrac
        per cell) receives one spike time each, drawn from the
        configured distribution ('normal', 'uniform' or 'lognormal');
        spike times below 0.1 are set to 0.1.
        'spiketrain': every SpikeTrain is configured with interval,
        noise, start and stop time and then played.
        Raises RuntimeError for an unknown spike time distribution.
        '''
        for synType in self.nwParam.keys():
            spikeSource = self.nwParam[synType].celltype
            if spikeSource == 'pointcell':
                typeParam = self.nwParam[synType]
                nrOfCells = typeParam.cellNr
                # indices of all cells that are active in this realization
                active, = np.where(np.random.uniform(size=nrOfCells) < typeParam.activeFrac)
                nrActive = len(active)
                try:
                    dist = typeParam.distribution
                except AttributeError:
                    print('WARNING: Could not find attribute \"distribution\" for \"pointcell\" of cell type %s.' % synType)
                    print('         Support of \"pointcell\" without this attribute is deprecated.')
                    dist = 'normal'
                if dist == 'normal':
                    mean = typeParam.spikeT
                    sigma = typeParam.spikeWidth
                    try:
                        offset = typeParam.offset
                    except AttributeError:
                        print('WARNING: Could not find attribute \"offset\" for \"pointcell\" of cell type %s.' % synType)
                        print('         Support of \"pointcell\" without this attribute is deprecated.')
                        offset = 10.0
                    spikeTimes = offset + mean + sigma*np.random.randn(nrActive)
                elif dist == 'uniform':
                    window = typeParam.window
                    offset = typeParam.offset
                    spikeTimes = offset + window*np.random.rand(nrActive)
                elif dist == 'lognormal':
                    mu = typeParam.mu
                    sigma = typeParam.sigma
                    offset = typeParam.offset
                    spikeTimes = offset + np.random.lognormal(mu, sigma, nrActive)
                else:
                    raise RuntimeError('Unknown spike time distribution: %s' % dist)
                print('initializing spike times for cell type %s' % (synType))
                # spike times are not allowed to be earlier than 0.1
                spikeTimes[spikeTimes < 0.1] = 0.1
                for cellIndex, spikeTime in zip(active, spikeTimes):
                    activeCell = self.cells[synType][cellIndex]
                    activeCell.append(spikeTime)
                    activeCell.play()
            elif spikeSource == 'spiketrain':
                typeParam = self.nwParam[synType]
                interval = typeParam.interval
                # defaults used if the (now mandatory) attributes are missing
                noise, start = 1.0, 0.0
                try:
                    noise = typeParam.noise
                    start = typeParam.start
                except AttributeError:
                    print('WARNING: Could not find attributes \"noise\" or \"start\" for \"spiketrain\" of cell type %s.' % synType)
                    print('         Support of \"spiketrains\" without these attributes is deprecated.')
                # optional argument: nr. of spikes
                nSpikes = getattr(typeParam, 'nspikes', None)
                # without simulation parameters the stop time stays at -1.0
                stop = -1.0 if self.simParam is None else self.simParam.tStop
                print('initializing spike trains with mean rate %.2f Hz for cell type %s' % (1000.0/interval, synType))
                for spikeTrain in self.cells[synType]:
                    spikeTrain.set_interval(interval)
                    spikeTrain.set_noise(noise)
                    spikeTrain.set_start(start)
                    spikeTrain.set_stop(stop)
                    spikeTrain.play(nSpikes)
        print('---------------------------')
    
    def _connect_functional_synapses(self):
        '''
        Connects anatomical synapses to spike
        generators (PointCells) according to physiological
        and/or anatomical constraints on connectivity
        (i.e., convergence of presynaptic cell type)
        '''
        synapses = self.postCell.synapses
        for synType in self.nwParam.keys():
            if not self.nwParam[synType].celltype == 'pointcell':
                continue
            print 'setting up functional connectivity for cell type %s' % (synType)
            activeSyn = 0
            connectedCells = set()
            nrPreCells = len(self.cells[synType])
            convergence = self.nwParam[synType].convergence
            # array with indices of presynaptic cells connected to postsynaptic cell
            connected = []
            # if there are synapses there have to be presynaptic neurons...
            while not len(connected):
                connected, = np.where(np.random.uniform(size=nrPreCells) < convergence)
            # array with indices of presynaptic cell assigned to each synapse
            # each connected presynaptic cell has at least 1 synapse by definition
            if len(synapses[synType]) < len(connected):
                # this should not be the anatomical reality, but for completeness...
                connectionIndex = np.random.randint(len(connected), size=len(synapses[synType]))
            else:
                connectionIndex = list(np.random.permutation(len(connected)))
                for i in range(len(connected), len(synapses[synType])):
                    connectionIndex.append(np.random.randint(len(connected)))
            for i in range(len(connectionIndex)):
                con = connected[connectionIndex[i]]
                preSynCell = self.cells[synType][con]
                connectedCells.add(con)
                syn = synapses[synType][i]
                synParameters = self.nwParam[synType].synapses
                for recepStr in synParameters.receptors.keys():
                    receptor = synParameters.receptors[recepStr]
                    self._assign_synapse_weights(receptor, recepStr, syn)
                if preSynCell.is_active():
                    if not syn.pruned:
                        activate_functional_synapse(syn, self.postCell, preSynCell, synParameters)
                    if syn.is_active():
                        activeSyn += 1
                    preSynCell._add_synapse_pointer(syn)
            self.connected_cells[synType] = connectedCells
            print '    connected cells: %d' % len(connectedCells)
            print '    active %s synapses: %d' % (synType, activeSyn)
        print '---------------------------'
    
    def _create_functional_connectivity_map(self):
        '''
        Connects anatomical synapses to spike
        generators (PointCells) according to physiological
        and/or anatomical constraints on connectivity
        (i.e., convergence of presynaptic cell type).
        Used to create fixed functional realization.
        Returns list of functional connections, where
        each functional connection is a tuple
        (cell type, presynaptic cell index, synapse index).
        cell type - string used for indexing point cells and synapses
        presynaptic cell index - index of cell in list self.cells[cell type]
        synapse index - index of synapse in list self.postCell.synapses[cell type]
        '''
#        visTest = {} # dict holding (cell type, cell, synapse) pairs for simple visualization test
        
        functionalMap = {}
        synapses = self.postCell.synapses
        for synType in self.nwParam.keys():
            if not self.nwParam[synType].celltype == 'pointcell':
                continue
            print 'creating functional connectivity map for cell type %s' % (synType)
            nrPreCells = len(self.cells[synType])
            convergence = self.nwParam[synType].convergence
            # array with indices of presynaptic cells connected to postsynaptic cell
            connected = []
            # if there are synapses there have to be presynaptic neurons...
            while not len(connected):
                connected, = np.where(np.random.uniform(size=nrPreCells) < convergence)
            # array with indices of presynaptic cell assigned to each synapse
            # each connected presynaptic cell has at least 1 synapse by definition
            if len(synapses[synType]) < len(connected):
                # this should not be the anatomical reality, but for completeness...
                connectionIndex = np.random.randint(len(connected), size=len(synapses[synType]))
            else:
                connectionIndex = list(np.random.permutation(len(connected)))
                for i in range(len(connected), len(synapses[synType])):
                    connectionIndex.append(np.random.randint(len(connected)))
            for i in range(len(connectionIndex)):
                con = connected[connectionIndex[i]]
                funCon = (synType, con, i)
                if not functionalMap.has_key(synType):
                    functionalMap[synType] = []
                functionalMap[synType].append(funCon)
#                if synType not in visTest.keys():
#                    visTest[synType] = []
#                visTest[synType].append((synType, con, i))
        
#        functional_connectivity_visualization(visTest, self.postCell)
        return functionalMap
    
    def _map_functional_realization(self, weights=None):
        '''
        Connects anatomical synapses to spike
        generators (PointCells) according to functional
        realization file.
        
        weights - optional saved synaptic weights, indexed as
            weights[cell type][synapse index]; if not given (or empty),
            weights are assigned from the receptor parameters
        
        Only cell types with spike source 'pointcell' are handled.
        Raises RuntimeError if the functional realization does not
        belong to the anatomical realization in use, or if a connection
        in the file has a cell type different from the one being mapped.
        '''
        synapses = self.postCell.synapses
        for synType in self.nwParam.keys():
            if not self.nwParam[synType].celltype == 'pointcell':
                continue
            print('setting up functional connectivity for cell type %s' % (synType))
            nrActiveSyn = 0
            connectedCells = set()
            synParameters = self.nwParam[synType].synapses
            connections, anatomicalID = reader.read_functional_realization_map(synParameters.connectionFile)
            functionalMap = connections[synType]
            # the functional map is only valid for the anatomical
            # realization (identified by file name) it was created from
            anatomicalRealizationName = synParameters.distributionFile.split('/')[-1]
            if anatomicalID != anatomicalRealizationName:
                raise RuntimeError('Functional mapping does not correspond to anatomical realization %s' % anatomicalRealizationName)
            for cellType, cellID, synID in functionalMap:
                if cellType != synType:
                    raise RuntimeError('Functional map cell type %s does not correspond to synapse type %s' % (cellType, synType))
                preSynCell = self.cells[synType][cellID]
                connectedCells.add(cellID)
                syn = synapses[synType][synID]
                if not weights:
                    for recepStr in synParameters.receptors.keys():
                        self._assign_synapse_weights(synParameters.receptors[recepStr], recepStr, syn)
                else:
                    syn.weight = weights[synType][synID]
                # synapses of inactive presynaptic cells are not activated
                if not preSynCell.is_active():
                    continue
                if not syn.pruned:
                    activate_functional_synapse(syn, self.postCell, preSynCell, synParameters)
                if syn.is_active():
                    nrActiveSyn += 1
                preSynCell._add_synapse_pointer(syn)
            self.connected_cells[synType] = connectedCells
            print('    connected cells: %d' % len(connectedCells))
            print('    active %s synapses: %d' % (synType, nrActiveSyn))
        print('---------------------------')
        
#        functional_connectivity_visualization(visTest, self.postCell)
    
    def _connect_spike_trains(self, weights=None, change=None):
        '''
        Connects spike generators with given
        mean spike rate (SpikeTrains) to synapse locations.
        All synapses are independent.
        
        weights - optional saved synaptic weights, indexed as
            weights[cell type][synapse index]; if not given (or empty),
            weights are assigned from the receptor parameters
        change - optional pair (tChange, changeParam); if given, tChange
            and changeParam[cell type].synapses are passed on to
            activate_functional_synapse
        
        Only cell types with spike source 'spiketrain' are handled;
        synapse i is driven by presynaptic cell i of the same type.
        '''
        synapses = self.postCell.synapses
        hasChange = change is not None
        if hasChange:
            tChange, changeParam = change
        for synType in self.nwParam.keys():
            if self.nwParam[synType].celltype != 'spiketrain':
                continue
            print('activating spike trains for cell type %s' % (synType))
            for synIndex in range(len(synapses[synType])):
                syn = synapses[synType][synIndex]
                synParameters = self.nwParam[synType].synapses
                preSynCell = self.cells[synType][synIndex]
                if not weights:
                    for recepStr in synParameters.receptors.keys():
                        self._assign_synapse_weights(synParameters.receptors[recepStr], recepStr, syn)
                else:
                    syn.weight = weights[synType][synIndex]
                # additional arguments only if parameters change during the simulation
                if hasChange:
                    changeArgs = (tChange, changeParam[synType].synapses)
                else:
                    changeArgs = ()
                activate_functional_synapse(syn, self.postCell, preSynCell, synParameters, *changeArgs)
        print('---------------------------')

    def _assign_synapse_weights(self, receptor, recepStr, syn):
        if syn.weight is None:
            syn.weight = {}
        if not syn.weight.has_key(recepStr):
            syn.weight[recepStr] = []
        if receptor.has_key("weightDistribution"):
            if receptor["weightDistribution"] == "lognormal":
                if isinstance(receptor.weight, Sequence):
                    for i in range(len(receptor.weight)):
                        mean = receptor.weight[i]
                        std = mean**2
                        sigma = np.sqrt(np.log(1+std**2/mean**2))
                        mu = np.log(mean) - 0.5*sigma**2
                        gmax = np.random.lognormal(mu, sigma)
                        syn.weight[recepStr].append(gmax)
                        #print '    weight[%d] = %.2f' % (i, syn.weight[recepStr][-1])
                else:
                    mean = receptor.weight
                    std = mean**2
                    sigma = np.sqrt(np.log(1+std**2/mean**2))
                    mu = np.log(mean) - 0.5*sigma**2
                    gmax = np.random.lognormal(mu, sigma)
                    syn.weight[recepStr].append(gmax)
                    #print '    weight = %.2f' % (syn.weight[recepStr][-1])
            else:
                distStr = receptor["weightDistribution"]
                errstr = 'Synaptic weight distribution %s not implemented yet!' % distStr
                raise NotImplementedError(errstr)
        else:
            if isinstance(receptor.weight, Sequence):
                for i in range(len(receptor.weight)):
                    syn.weight[recepStr].append(receptor.weight[i])
            else:
                syn.weight[recepStr].append(receptor.weight)
        

def activate_functional_synapse(syn, cell, preSynCell, synParameters, tChange=None, synParametersChange=None):
    '''Default method to activate single synapse.
    Currently, this implementation expects all presynaptic spike
    times to be pre-computed; can thus not be used in recurrent
    network models at this point.
    
    syn - synapse to be activated
    cell - postsynaptic cell
    preSynCell - presynaptic cell; its attribute spikeTimes is read
    synParameters - synapse parameter set (receptors, optional releaseProb)
    tChange - optional time from which on the release probability
        of synParametersChange is used instead
    synParametersChange - synapse parameter set providing releaseProb
        for spike times >= tChange; only read if tChange is given
    
    If synParameters has a release probability, every presynaptic
    spike leads to release with that probability; otherwise every
    spike leads to release. Nothing is activated (and None is
    returned) if no release time remains.'''
    if synParameters.has_key('releaseProb'):
        prel = synParameters.releaseProb
        prelChange = None
        if tChange is not None:
            prelChange = synParametersChange.releaseProb
        releaseTimes = []
        for t in preSynCell.spikeTimes:
            # exactly one random number is drawn per presynaptic spike
            afterChange = tChange is not None and t >= tChange
            releaseProbability = prelChange if afterChange else prel
            if np.random.rand() < releaseProbability:
                releaseTimes.append(t)
    else:
        releaseTimes = preSynCell.spikeTimes[:]
    if len(releaseTimes) == 0:
        return
    # the release site plays only those spikes that lead to release
    releaseSite = PointCell(releaseTimes)
    releaseSite.play()
    receptors = synParameters.receptors
    syn.activate_hoc_syn(releaseSite, preSynCell, cell, receptors)
    # set properties for all receptors here
    for recepStr in receptors.keys():
        recepParameters = receptors[recepStr].parameter
        for param in recepParameters.keys():
            # try treating parameters as hoc range variables,
            # then as hoc global variables
            # NOTE(review): the assignment statement is built from the
            # parameter set and executed with exec; parameter names and
            # values must therefore come from a trusted source
            try:
                paramStr = 'syn.receptors[\'' + recepStr + '\'].'
                paramStr += param + '=' + str(recepParameters[param])
                exec(paramStr)
            except LookupError:
                paramStr = param + '_' + recepStr + '='
                paramStr += str(recepParameters[param])
                h(paramStr)

def functional_connectivity_visualization(functionalMap, cell):
    '''Simple visualization test of a functional connectivity map.
    
    functionalMap - dictionary with lists of functional connections
        (cell type, presynaptic cell index, synapse index); only the
        entries 'L4ssD2' and 'L1D1' are used
    cell - postsynaptic cell; coordinates of its synapses are read
    
    Presynaptic cells are placed on regular grids (hard-coded cell
    numbers, origins and spacings). For every presynaptic cell a list
    of (cell position, synapse position) pairs is collected; cells
    without connection get a single pair pointing to themselves.
    The maps are written to 'L4ss_func_map3.am' and 'L1_func_map3.am'.'''
    nrL4ssCells = 3168
    nrL1Cells = 104
    rows = 10
    L4cols = nrL4ssCells//rows
    L1cols = nrL1Cells//rows
    
    L4origin = np.array([-150,-150,0])
    L4rowSpacing = np.array([1,0,0])
    L4colSpacing = np.array([0,30,0])
    L1origin = np.array([-550,-150,700])
    L1colSpacing = np.array([30,0,0])
    L1rowSpacing = np.array([0,30,0])
    
    # grid position of every presynaptic cell, indexed by cell index
    L4grid = {}
    for i in range(nrL4ssCells):
        col, row = divmod(i, L4cols)
        L4grid[i] = L4origin + row*L4rowSpacing + col*L4colSpacing
    L1grid = {}
    for i in range(nrL1Cells):
        row = i//rows
        col = i - row*L1cols
        L1grid[i] = L1origin + row*L1rowSpacing + col*L1colSpacing
    
    def collect_connections(connections, grid, nrOfCells):
        # presynaptic cell index -> list of (cell position, synapse position)
        cellMap = {}
        for cellType, cellID, synID in connections:
            synPos = cell.synapses[cellType][synID].coordinates
            cellMap.setdefault(cellID, []).append((grid[cellID], synPos))
        # unconnected cells are represented by a pair pointing to themselves
        for i in range(nrOfCells):
            if i not in cellMap:
                cellMap[i] = [(grid[i], grid[i])]
        return cellMap
    
    L4map = collect_connections(functionalMap['L4ssD2'], L4grid, nrL4ssCells)
    L1map = collect_connections(functionalMap['L1D1'], L1grid, nrL1Cells)
    
    writer.write_functional_map('L4ss_func_map3.am', L4map)
    writer.write_functional_map('L1_func_map3.am', L1map)