import numpy
import sys
import logging

from pyV1.dataformats import *
from datetime import datetime

import pyV1.utils

def spikes2exp(spikes, sample_time, tau=30e-3, tau2=0):
    """Exponentially filter a spike train and sample it at one time point.

    Every spike that occurred at or before ``sample_time`` contributes
    ``exp((spike - sample_time) / tau)``; later spikes are ignored.

    Parameters:
        spikes: numpy array of spike times.
        sample_time: time at which the filtered trace is read out.
        tau: decay time constant of the exponential kernel.
        tau2: not used; accepted so that all filters share one call signature.

    Returns:
        The sum of the kernel contributions (0.0 if no spike qualifies).
    """
    is_past = spikes <= sample_time
    past_spikes = spikes.compress(is_past)
    exponents = (past_spikes - sample_time) / tau
    return numpy.exp(exponents).sum()


def spikes2alpha(spikes, sample_time, tau1=30e-3, tau2=3e-3):
    """Filter a spike train with a difference-of-exponentials kernel.

    Every spike that occurred at or before ``sample_time`` contributes
    ``exp(delta / tau1) - exp(delta / tau2)`` with
    ``delta = spike - sample_time``; later spikes are ignored.

    Parameters:
        spikes: numpy array of spike times.
        sample_time: time at which the filtered trace is read out.
        tau1: time constant of the first (added) exponential.
        tau2: time constant of the second (subtracted) exponential.
            Must be non-zero, otherwise ``delta / tau2`` divides by zero.

    Returns:
        The sum of the kernel contributions (0.0 if no spike qualifies).
    """
    delta = spikes.compress(spikes <= sample_time) - sample_time
    return (numpy.exp(delta / tau1) - numpy.exp(delta / tau2)).sum()


def spikes2count(spikes, sample_time, tau1=30e-3, tau2=0):
    """Count the spikes inside a window that ends at ``sample_time``.

    A spike is counted when it lies in the half-open interval
    ``(sample_time - tau1, sample_time]``.

    Parameters:
        spikes: numpy array of spike times.
        sample_time: right edge (inclusive) of the counting window.
        tau1: length of the counting window.
        tau2: not used; accepted so that all filters share one call signature.

    Returns:
        Number of spikes inside the window.
    """
    not_later = spikes.compress(spikes <= sample_time)
    offsets = not_later - sample_time
    inside_window = offsets.compress(offsets > -tau1)
    return inside_window.shape[0]


class State(object):
    """Container pairing a vector of state values with its sample time.

    Attributes:
        X: the state values, stored exactly as passed in.
        t: the time the values belong to.
    """

    def __init__(self, values, t):
        self.t = t
        self.X = values


#@utils.timedTest
def response2states(Rin, sampling=None, filter=spikes2exp, tau1=30e-3, tau2=3e-3, channels=None):
    """Convert one or more responses into lists of filtered state vectors.

    For every response, every sample time and every selected channel the
    spike times of the channel are passed to ``filter``; the per-channel
    results of one sample time form one ``State``.

    Parameters:
        Rin: a single response or a sequence of responses. Each response
            must provide ``Tsim`` and an indexable ``channel`` collection
            whose items have a ``data`` attribute (the spike times).
        sampling: sample times. ``None`` means a single sample at each
            response's own ``Tsim``. A string is appended to ``numpy.r_``
            and evaluated, e.g. ``"[0:1:0.1]"``.
        filter: callable ``filter(spikes, sample_time, tau1, tau2)``.
        tau1, tau2: filter time constants; they are swapped if necessary
            so that the larger one is always passed as ``tau1``.
        channels: indices of the channels to use. ``None`` means all
            channels of each response.

    Returns:
        A list with one entry per response; each entry is a list of
        ``State`` objects, one per sample time.

    Side effects:
        Writes a textual progress indicator to ``sys.stdout``.
    """
    # The module imports ``pyV1.utils``; that binds the name ``pyV1`` only,
    # so the helper has to be reached through the package.
    if pyV1.utils.isScalar(Rin):
        Rin = [Rin]

    # NOTE(security): a string ``sampling`` is evaluated with eval(). Never
    # pass text from an untrusted source here.
    if isinstance(sampling, str):
        sampling = eval("numpy.r_" + sampling)

    # Loop invariant: make sure tau1 is the larger time constant.
    if tau2 > tau1:
        tau1, tau2 = tau2, tau1

    # Resolve the defaults separately for every response. Previously the
    # values derived from the first response were silently reused for all
    # following responses.
    plans = []
    total = 0
    for R in Rin:
        r_sampling = [R.Tsim] if sampling is None else sampling
        r_channels = range(len(R.channel)) if channels is None else channels
        plans.append((R, r_sampling, r_channels))
        total += len(r_sampling) * len(r_channels)

    all_states = []
    count = 0
    sys.stdout.write("response2states:   0%")
    for R, r_sampling, r_channels in plans:
        logging.debug("Sampling %d %d-dimensional states...\n",
                      len(r_sampling), len(r_channels))
        states = []
        for sample_time in r_sampling:
            state = []
            for chid in r_channels:
                spikes = numpy.atleast_1d(R.channel[chid].data)
                state.append(filter(spikes, sample_time, tau1, tau2))
                count += 1
                # The four backspaces erase the previous "NNN%" in place.
                sys.stdout.write("\b\b\b\b%3d%%" % (count * 100 // total))

            states.append(State(values=numpy.array(state), t=sample_time))

        all_states.append(states)
    sys.stdout.write("\b\b\b\b100%\n")
    return all_states


def plot_states(states):
    """Plot every channel of a state sequence as a vertically offset trace.

    Channel ``c`` is drawn over the sample times and shifted upwards by
    ``c * ceil(max over all values)`` so the traces are stacked.

    Parameters:
        states: sequence of ``State``-like objects with attributes ``X``
            (per-channel values) and ``t`` (sample time). An empty
            sequence is a no-op.

    Returns:
        None. Draws into the current pylab axes.
    """
    # Nothing to draw; also avoids indexing states[0] on an empty sequence.
    if len(states) == 0:
        return

    import pylab

    numChannels = len(states[0].X)

    sampling = [state.t for state in states]
    y = numpy.array([state.X for state in states])
    maxy = numpy.ceil(y.max())

    for c in range(numChannels):
        pylab.plot(sampling, y[:, c] + c * maxy)




def getLayerwiseStates(resp, layer_depth=(3, 3, 3), layer_base=None, **kwds):
    """Compute the states of a response and arrange them as a 3-D grid.

    The states are computed with ``response2states(resp, **kwds)``. For
    each state the values belonging to one layer are reshaped to
    ``(layer_base[0], layer_base[1], layer_depth[layer])`` and the layers
    are concatenated along the third axis in reverse layer order.

    Parameters:
        resp: response providing ``layers`` and ``layeridx`` (the layer
            number of every channel).
        layer_depth: depth of each layer, indexed by layer number.
        layer_base: ``[nx, ny]`` extent of the grid. If ``None`` a square
            grid is derived from the number of channels in layer 0.
        **kwds: forwarded to ``response2states``.

    Returns:
        The list of states of the first response. Every state has its
        ``X`` replaced by the 3-D array and gains ``layeridx`` (layer
        number of every slice along the third axis) and ``layers``.

    Raises:
        Exception: if ``resp`` has no ``layers`` attribute, or if no
            ``layer_base`` was given and layer 0 does not form a square.
    """
    if 'layers' not in resp.__dict__:
        raise Exception('getLayerIdx not done')

    # Determine and validate the grid before the expensive state computation.
    if layer_base is None:
        # Assume a square base. Compare integer counts: the earlier test
        # int(floor(nx)) != int(nx) could never be true for nx >= 0.
        n_first = int(numpy.sum(resp.layeridx == 0))
        nx = int(round(numpy.sqrt(float(n_first) / float(layer_depth[0]))))
        if nx * nx * layer_depth[0] != n_first:
            raise Exception('not square. provide layer_base')
        layer_base = [nx, nx]

    states = response2states(resp, **kwds)

    num_layers = len(resp.layers)
    idx_lst = [numpy.where(resp.layeridx == lay)[0] for lay in range(num_layers)]

    # Layer number of every slice along the third axis, in the same
    # (reversed) order in which the layers are concatenated below. Kept as
    # integers because the values are used as indices.
    layer_lst = numpy.concatenate(
        [numpy.repeat(lay, layer_depth[lay]) for lay in reversed(range(num_layers))])

    for s in states[0]:
        X = []
        for lay in range(num_layers):
            shape = (int(layer_base[0]), int(layer_base[1]), int(layer_depth[lay]))
            X.append(s.X[idx_lst[lay]].reshape(shape))

        s.X = numpy.concatenate(X[::-1], axis=2)
        s.layeridx = layer_lst
        s.layers = resp.layers

    return states[0]



def plotLayerwiseStates(states, idx=None):
    """Plot layer-wise states, one figure per selected state.

    Every slice ``X[:, :, k]`` of a state is shown as an image in its own
    subplot. All images share the colour range ``[0, mx]`` where ``mx`` is
    the largest value over the selected states (at least 0).

    Parameters:
        states: sequence of states whose ``X`` is a 3-D array and which
            carry ``layers``, ``layeridx`` and ``t``.
        idx: indices of the states to plot; ``None`` selects all.

    Returns:
        List of the created figure handles (empty if nothing is selected).
    """
    if idx is None:
        idx = range(len(states))
    idx = list(idx)
    if not idx:
        # Nothing selected: no need to load the plotting stack.
        return []

    from utils.mscfuncs import imagesc
    import pylab

    mx = numpy.max([0.] + [states[i].X.max() for i in idx])

    handle_lst = []
    for i in idx:
        state = states[i]
        handle_lst.append(pylab.figure())

        num_planes = state.X.shape[2]
        # Keep the 3x3 layout for up to nine slices; grow it beyond that.
        ncols = max(3, int(numpy.ceil(numpy.sqrt(num_planes))))
        nrows = max(3, int(numpy.ceil(float(num_planes) / ncols)))

        for lay in range(num_planes):
            ax = pylab.subplot(nrows, ncols, lay + 1)
            imagesc(state.X[:, :, lay])
            # layeridx may hold floats; an index has to be an integer.
            layer_name = state.layers[int(state.layeridx[lay])]
            pylab.title('layer %s (%d) (t=%1.3fs)' % (layer_name, lay, state.t), fontsize=8)
            pylab.clim([0, mx])
            pylab.setp(ax, 'xticklabels', [])
            pylab.setp(ax, 'yticklabels', [])

    return handle_lst




# Manual demo / smoke test, only executed when the file is run as a script.
# NOTE: the print statements below use Python 2 syntax.
if __name__ == '__main__':
    import pylab


    # Disabled demo: filter random spike trains with all three filters and
    # save the stacked traces to a figure.
    if 0:

        Tsim=1.0
        freq=20
        
        # Five channels of uniformly distributed random spike times in [0, Tsim).
        # NOTE(review): freq*Tsim is a float (20.0); newer numpy versions may
        # require an integer sample count here -- confirm before enabling.
        resp = Response(Tsim)
        resp.appendChannel(Channel(data=numpy.random.uniform(0,Tsim, freq*Tsim)))
        resp.appendChannel(Channel(data=numpy.random.uniform(0,Tsim, freq*Tsim)))
        resp.appendChannel(Channel(data=numpy.random.uniform(0,Tsim, freq*Tsim)))
        resp.appendChannel(Channel(data=numpy.random.uniform(0,Tsim, freq*Tsim)))
        resp.appendChannel(Channel(data=numpy.random.uniform(0,Tsim, freq*Tsim)))
        
        # Sample times from 0 up to (but excluding) Tsim in steps of 0.001.
        sample=numpy.arange(0, Tsim, 0.001)

        #statesexp = response2states(resp, sample, spikes2exp)    
        #pylab.figure()
        #plot_states(statesexp)

        print "response2states spikes2exp"
        statesexp = response2states(resp, sample, spikes2exp)
        print "response2states spikes2alpha"
        statesalpha = response2states(resp, sample, spikes2alpha)
        print "response2states spikes2count"
        statescount = response2states(resp, sample, spikes2count)

        # NOTE(review): response2states returns one list of states per
        # response, while plot_states reads .X/.t from the items of its
        # argument -- confirm whether [0] is needed before enabling.
        pylab.figure()
        pylab.subplot(311)
        plot_states(statesexp)
        pylab.subplot(312)
        plot_states(statesalpha)
        pylab.subplot(313)
        plot_states(statescount)
        pylab.savefig('states')

        #pylab.show()

    # Active demo: load a stored response and show its layer-wise states.
    if 1:
        if 1:
            # NOTE(review): imported as top-level module here, presumably the
            # same module as pyV1.dataformats -- confirm.
            import dataformats
            #fname = '/home/malte/saves/pysim/sim_WInscaleNoRecurrent_ephysBackgroundMovie/WInscaleNoRecurrent_ephysBackgroundMovie.231'
            # Hard-coded, machine-specific path to the stored response.
            fname= '/home/malte/saves/pysim/sim_WscaleAndLRW_seed1434103_getStandardMovieStm_62x62_HC1/WscaleAndLRW_seed1434103_getStandardMovieStm_62x62_HC1.199'
            rsp = dataformats.Response()
            rsp.load(fname,'/Response')

            # 50 sample times, linearly spaced from 0 to 10.
            states = getLayerwiseStates(rsp,sampling=numpy.linspace(0,10.,50))
        

        # One figure per state; show() blocks until the windows are closed.
        plotLayerwiseStates(states)
        pylab.show()