# Main simulation script: builds the network, runs it period by period, then dumps and plots the results.

execfile('init.py') # sets the parameters (N, M, randState, flags, ... are all expected to come from here)

#****************************************************
# Start-of-run banner, tagged with the random seed of this run.
for _bannerLine in ('************************',
                    '* Computing (rand=' + '%03d' % randState + ') *',
                    '************************'):
    printtime(_bannerLine)
del _bannerLine


# Nearest-spike mode keeps one integer flag per synapse (N inputs x M outputs),
# presumably updated by the STDP reset functions defined in init.py.
if nearestSpike:
    alreadyDepressed, alreadyPotentiated = zeros([N,M],dtype=int), zeros([N,M],dtype=int)

# To bound memory use the simulation is split into several periods;
# `count` is the index of the current period.
count = 0

if computeOutput: # instantiate output neurons
    # Output side of the network:
    #  - `mirror`: N neurons (one per input) carrying the presynaptic trace A_pre;
    #    they are driven by the input layer through an IdentityConnection created later,
    #  - `neurons`: M output neurons carrying the postsynaptic trace A_post,
    #  - `synapses`: dense N x M connection mirror -> neurons acting on 'ge'.

    _gmax = asarray(gmax) # faster implementation (does not verify dimensions)

    if nearestSpike: # flags useful for nearest spike mode
        _alreadyPotentiated = asarray(alreadyPotentiated)
        _alreadyDepressed = asarray(alreadyDepressed)

    # mirrors
    mirror_eqs=''' 
    v:1
    dA_pre/dt=-A_pre/tau_pre : 1
    ''' 
    mirror = NeuronGroup(N, model=mirror_eqs, threshold=0.5, reset=mirror_reset)

    #STDP neuron
    if conductanceOutput: # dimensionless ge multiplies the driving force (Ee-v)
        eqs_neurons='''
        dv/dt=(ge*(Ee-v)+El-v)/taum + sigma*xi/taum**.5 : volt
        dge/dt=-ge/taue : 1
        dA_post/dt=-A_post/tau_post : 1
        '''
    else: # ge (in volt) is added directly to the membrane equation
        eqs_neurons='''
        dv/dt=(ge+El-v)/taum + sigma*xi/taum**.5 : volt
        dge/dt=-ge/taue : volt
        dA_post/dt=-A_post/tau_post : 1
        '''

    neurons_cr = CustomRefractoriness(neurons_reset,period=refractoryPeriod,state='v')

    if poissonOutput: # stochastic spike generation
        neurons=NeuronGroup(M,model=eqs_neurons,threshold=PoissonThreshold(),reset=neurons_cr)
    else: # deterministic spike generation
        neurons=NeuronGroup(M,model=eqs_neurons,threshold=vt,reset=neurons_cr)

    #connections
    synapses=Connection(mirror,neurons,'ge',structure='dense')
    seed(randState) # initial weights and potentials below are reproducible for a given randState
    _weightFile = os.path.join('..','data','weight.'+'%03d' % (randState)+'.mat')
    if useSavedWeight and os.path.exists(_weightFile):
        print('Loading previously dumped weight')
        # The saved weights are presumably the gmax-normalized finalWeight dumped by
        # saveWeight.py; column j is scaled back by gmax[j].
        savedWeight = asarray(loadmat(_weightFile)['weight'])
        if M == 1:
            # loadmat always returns a 2-D array, so a single column may come back as
            # (N,1) or (1,N): normalize it to (N,1) so the column indexing below works
            # (assigning an (N,1) array into a 1-D column would fail to broadcast).
            savedWeight = savedWeight.reshape(N, 1)
        initialWeight = zeros([N,M])
        for j in range(M):
            initialWeight[:,j] = savedWeight[:,j]*gmax[j]
        del savedWeight
    else: # start from random synaptic weights
        initialWeight = zeros([N,M])
        for i in range(N):
            # one random value per input, shared by all M output neurons
            initialWeight[i,:] = initialWeight_min + rand(1)*(initialWeight_max-initialWeight_min)
        if initialWeight.max() > min(gmax):
            print('***********************************************************')
            print('* WARNING: Initial weight > gmax. This should not happen. *')
            print('***********************************************************')
    synapses.connect(mirror,neurons,initialWeight)
    synapses.compress()
    _synW = asarray(synapses.W) # fast (unit-less) view on the weight matrix

    # affect initial values: all output neurons start from the same random potential in [vr, vt)
    neurons.v_ = vr+rand(1)*ones(len(neurons))*(vt-vr)

    neurons.A_post_=0*volt
    neurons.ge_=0*volt
    mirror.A_pre_=0*volt
     
startTime = timeOffset # simulated time at which the first period starts

if recomputeSpikeList: # input layer is computed. Need to instantiate corresponding neurons
    if poissonInput:
        input=PoissonGroup(N,zeros(N)) # rates are overwritten every period (input._S below)
    else:
        if sum(a)==0*namp: # no oscillatory drive
            input_eqs=''' 
            dv/dt = (-v + El + R*I )/taum + sigma*xi/taum**.5 : volt
            I : amp
            ''' 
        else:
            # sinusoidal oscillation
            input_eqs=''' 
            dv/dt = (-v + El + R * ( I + aa*(sin(t*2*pi*oscilFreq-pi))))/taum + sigma*xi/taum**.5 : volt
            I : amp
            aa : amp
            ''' 
#            # sawtooth oscillation
#            input_eqs=''' 
#            dv/dt = (-v + El + R * ( I + a*((t-floor(t*oscilFreq)*1/oscilFreq)*oscilFreq-1.0)))/taum + sigma*xi/taum**.5 : volt
#            I : amp
#            ''' 
        input=NeuronGroup(N,model=input_eqs,threshold=vt,reset=vr,refractory=refractoryPeriod)
        # affect initial values (reproducible for a given randState)
        seed(randState)
        input.v_ = vr+rand(N)*(vt-vr)
#        input.v_ = El
        input.aa_ = a # amplitude of the oscillatory input current (only declared by the oscillating model)

    if computeOutput: # connect input layer to its mirror
        C_input_mirror = IdentityConnection(input, mirror)

    if useReset: # load reset times from reset.###.mat
        reset=loadmat(os.path.join('..','data','reset.'+'%03d' % (randState)+'.mat'))
        reset=reset['resetTimes']
#        # to artificially double the reset frequency:
#        reset=.5*concatenate([reset,reset[-1]+250e-3+reset])
        # loadmat always returns a 2-D array: flatten it to scalar times, and append an
        # Inf sentinel so that reset[rcursor] stays valid once every reset has been consumed.
        reset=list(asarray(reset,dtype=float).ravel())+[Inf]
    else:
        reset=[Inf] # never reset
    rcursor=0 # cursor for reset array (index of the next pending reset)

    I = zeros(N)*namp # just to initialize the array

    # open input current file
    f = open(os.path.join('..','data','inputValues.'+'%03d' % (randState)+'.txt'),'r')

    printtime('Starting (recompute spike list)')

    localStartTime = time()*second

    inputSpike = SpikeMonitor(input,True) # input spikes need to be monitored all the time with this kind of computation

    for l in f: # iterate on file lines: each line gives an end time and the input currents of one period

        # read time
        endTime = double(l[0:9])*second/pbTimeCompression

        # imposed end time
        endTime = min(imposedEnd,endTime)

        # monitors (created once, when the simulated time first reaches monitorTime)
        if endTime>=monitorTime and not isMonitoring:
            print('********************')
            print('* Start monitoring *')
            print('********************')
            isMonitoring = True
#            if monitorInput:
#                inputSpike = SpikeMonitor(input,True)
            if monitorInputPot:
                inputPot = StateMonitor(input,'v',record=True)
            if monitorOutput:
                outputSpike = SpikeMonitor(neurons,True)
            if monitorPot:
                pot = StateMonitor(neurons,'v',record=True)
            if monitorCurrent:
                current = StateMonitor(neurons,'ge',record=True)
            if monitorRate:
                rate = []
                for i in range(M):
                    rate.append(PopulationRateMonitor(neurons[i],bin=10000*ms))

        # read input currents: fixed-width fields of 8 characters starting at column 11,
        # each value v mapped linearly to Imin[i] + v*(Imax[i]-Imin[i])
        for i in range(N):
            I[i]= Imin[i] + double(l[11+i*8:11+(i+1)*8-2]) * (Imax[i]-Imin[i])
#            I[i]= Imin[i] + i*1.0/(N-1) * (Imax[i]-Imin[i])
#            I[i]= Imin[i] + mod(i,12)*1.0/(N/9-1) * (Imax[i]-Imin[i])

        # affect currents
        if poissonInput: # with poisson currents in fact correspond to rates
            input._S[0,:] = I
        else:
            input.I_ = I

        # Run Brian until endTime, pausing at every pending reset time <= endTime to
        # reset the input potentials. With no pending reset this is a single run.
        fromTime = startTime
        while reset[rcursor]<= endTime: # iterate on resets
            defaultclock.reinit(fromTime) # make sure end time is exactly the one we want, to avoid drifting
            run(reset[rcursor]-fromTime) # run Brian until reset
#            input.v_ = El # reset to resting potential
            input.v_ = vr # reset to reset potential
            fromTime = reset[rcursor]
            rcursor = rcursor+1
        # last (or only) bit
        defaultclock.reinit(fromTime) # make sure end time is exactly the one we want, to avoid drifting
        run(endTime-fromTime) # run Brian simulator until endTime

        if endTime>=imposedEnd: # exit condition
            break

        if inputSpike.nspikes > 1000000: # periodic log output (and dump, to bound memory use)
            printtime('Period # '+ str(count+1) +' - simulated time: '+str(endTime)+' - computation time: ' + str(time()*second-localStartTime))
            localStartTime = time()*second
            if dumpSpikeList: # dump spike list of this period
                print('Dumping #' + str(count+1) + ' (nspikes=' + str(inputSpike.nspikes) + ')')
                savemat(os.path.join('..','data','spikeList.'+'%03d' % (randState)+'.'+ '%03d' % (count+1) +'.mat'),{'sl':inputSpike.spikes})
                print('Dumping time: '+ str(time()*second-localStartTime))
                localStartTime = time()*second
            inputSpike.reinit()
            count += 1

        # periodic graphic plot output
        if floor(endTime/analyzePeriod)!=floor(startTime/analyzePeriod):
            # weights normalized by gmax; outputs with gmax == 0 keep 0 instead of 0/0 = nan
            if computeOutput:
                finalWeight = zeros([N,M])
                for j in range(M):
                    _gj = float(gmax[j])
                    if _gj != 0:
                        finalWeight[:,j] = _synW[:,j]/_gj
            execfile('analyze.py')

        # start := end
        startTime = endTime

    f.close() # close input current file

    # last dump
    if inputSpike.nspikes>0: # something to dump
        printtime('Period # '+ str(count+1) +' - simulated time: '+str(endTime)+' - computation time: ' + str(time()*second-localStartTime))
        localStartTime = time()*second
        if dumpSpikeList:
            print('Dumping #' + str(count+1) + ' (nspikes=' + str(inputSpike.nspikes) + ')')
            savemat(os.path.join('..','data','spikeList.'+'%03d' % (randState)+'.'+ '%03d' % (count+1) +'.mat'),{'sl':inputSpike.spikes})
            print('Dumping time: ' + str(time()*second-localStartTime))
            localStartTime = time()*second

else: # not in recomputeSpikeList mode. Thus spikes are read from files spikeList.###.###.mat (first number: random seed, second number: file number) 

    if not computeOutput:
        print('Warning: bad configuration: compute neither input nor output...')

    printtime('Starting (use saved spike list)')

    # look for spike list files
    fileList = listMatFile('../data/',randState)
    print(str(len(fileList)) + ' spike list files found')

    for fl in fileList: # iterate on spike list files

        # read spike list (column 1 holds the spike times)
        localStartTime = time()*second
        print('Reading '+ fl)
        spikeList=loadmat(os.path.join('..','data',fl))
        spikeList=spikeList['sl']
        spikeList[:,1]+=timeOffset
        spikeList[:,1]/=pbTimeCompression
        print(str(size(spikeList,0)) + ' spikes read (in ' + str(time()*second-localStartTime) + ')')

        input = SpikeGeneratorGroup(N, spikeList) # special Brian NeuronGroup that fire at specified dates
        endTime = spikeList[-1][1] # time of the last row (assumes the file is sorted by time)
        del spikeList

        # monitors (created once, when the simulated time first reaches monitorTime)
        if endTime>=monitorTime and not isMonitoring:
            print('********************')
            print('* Start monitoring *')
            print('********************')
            isMonitoring = True
            if monitorInput:
                # NOTE(review): `input` is rebuilt for every spike-list file, so this monitor
                # presumably only sees the spikes of the group that exists right now — confirm intended.
                inputSpike = SpikeMonitor(input,True)
            if monitorOutput:
                outputSpike = SpikeMonitor(neurons,True)
            if monitorPot:
                pot = StateMonitor(neurons,'v',record=True)
            if monitorCurrent:
                current = StateMonitor(neurons,'ge',record=True)
            if monitorRate:
                rate = []
                for i in range(M):
                    rate.append(PopulationRateMonitor(neurons[i],bin=2000*ms))

        # imposed end time
        endTime = min(imposedEnd,endTime)

        # connect new spike generator
        C_input_mirror = IdentityConnection(input, mirror)

        # run
        print('Running from t=' + str(startTime) + ' to t=' + str(endTime))
        defaultclock.reinit(startTime) # make sure end time is exactly the one we want, to avoid drifting
        run(endTime-startTime) # run Brian simulator until endTime

        # periodic graphic plot output
        if floor(endTime/analyzePeriod)!=floor(startTime/analyzePeriod):
            # weights normalized by gmax; disconnected outputs (gmax == 0) keep 0 instead of 0/0 = nan
            if computeOutput:
                finalWeight = zeros([N,M])
                for j in range(M):
                    _gj = float(gmax[j])
                    if _gj != 0:
                        finalWeight[:,j] = _synW[:,j]/_gj
            execfile('analyze.py')

        # start := end
        startTime = endTime

        # explicitly free memory
        # NOTE(review): C_input_mirror (and any input monitor) still reference this group,
        # so it is presumably only released when they are rebound.
        del input

        printtime('Period # '+ str(count+1) +': computation time: ' + str(time()*second-localStartTime))
        localStartTime = time()*second
        count += 1

        if endTime>=imposedEnd:
            break

        # Disconnect output neurons whose mean normalized weight exceeds burstingCriterion.
        # Already-disconnected neurons (gmax == 0) are skipped; 0/0 would give nan anyway.
        for j in range(M):
            if float(gmax[j]) != 0 and mean(_synW[:,j])/gmax[j]>burstingCriterion:
                print('WARNING: neuron # ' + str(j) + ' is bursting. Disconnecting it.')
                _synW[:,j] = 0*mV
                gmax[j] = 0*mV
                # asarray(gmax) copies gmax when it is not already an array (e.g. a list):
                # keep the fast copy in sync so the disconnection also applies there.
                _gmax[j] = 0*mV


        

print('Total computation time: ' + str(time()*second-globalStartTime))

# Final weights normalized by gmax, column by column. Output neurons with gmax == 0
# (e.g. disconnected for bursting, whose weights were zeroed) keep 0 instead of the
# 0/0 = nan the element-wise division would produce, which would otherwise be dumped
# by saveWeight.py.
if computeOutput:
    finalWeight = zeros([N,M])
    for j in range(M):
        _gj = float(gmax[j])
        if _gj != 0:
            finalWeight[:,j] = _synW[:,j]/_gj


#execfile('pickleAll.py') # pickle all variable (under development...)
if imposedEnd>6: # don't dump short simulations, probably done for display only
    execfile('saveWeight.py') # dump final weights

#execfile('toMatlab.py') # dump variables in a mat file (under development)

if graph: # graphical plot output
    execfile('analyze.py')
    show()

if imposedEnd>6 and monitorOutput: # mutual info (stimulus, response)
    execfile('mutualInfo.py')
    show()