from scipy.interpolate.interpolate import interp1d

import numpy
import pyV1.utils

from pyV1.inputs.jitteredtemplate import *
from pyV1.inputs.constantrate import *
from pyV1.inputs.randomrate import *
from pyV1.inputs.randombndrate import *
from pyV1.inputs.randommodrate import *

from pyV1.inputs.spikes2rate import spikes2rate

import scipy.interpolate

#import pdb as pdbm

class TargetFunction(object):
    '''Base class of all target functions.

    A target function maps a stimulus and a sequence of time points to the
    values a readout should produce at those times; subclasses override
    ``values``.

    pos_gen ... index into ``stim.stim_list`` that selects the sub-stimulus
                to use when a CombinedStimulus is given (None = use the
                stimulus exactly as passed)
    '''

    def __init__(self, pos_gen=None):
        self.pos_gen = pos_gen

    def values(self, stim, at_t):
        '''Return the target values of ``stim`` at the times ``at_t``.

        The base implementation computes nothing and returns None.
        '''
        return None

    def getStim(self, stim):
        '''Return the stimulus this target function should operate on.

        For a CombinedStimulus with ``pos_gen`` set, that is entry
        ``pos_gen`` of its ``stim_list``; otherwise ``stim`` itself.
        '''
        combined = isinstance(stim, CombinedStimulus)
        if combined and self.pos_gen is not None:
            return stim.stim_list[self.pos_gen]
        return stim


class SegmentClassification(TargetFunction):
    '''segment classification

    The target value is the ``actualTemplate`` entry of one segment of a
    template stimulus, repeated for every requested time point.
    (NOTE(review): presumably the index of the template used in that
    segment -- confirm against the stimulus generator.)
    '''

    def __init__(self, posSeg=0, pos_gen=None):
        '''posSeg ... Segment position to classify (Python-style index,
              negative values count from the last segment)
pos_gen ... which generator to classify (only if CombinedInput)
'''
        super(SegmentClassification, self).__init__(pos_gen)
        self.posSeg = posSeg  # number of segment to classify as positive

    def values(self, stim, at_t):
        '''Return ``actualTemplate[posSeg]`` tiled ``len(at_t)`` times.

        stim ... template stimulus, or a CombinedStimulus (see pos_gen)
        at_t ... sequence of time points; only its length is used

        Raises Exception if the (selected) stimulus has no usable
        ``actualTemplate`` or if ``posSeg`` is out of range.
        '''
        s = super(SegmentClassification, self).getStim(stim)

        try:
            nSegments = len(s.actualTemplate)
        except (AttributeError, TypeError):
            # Either no actualTemplate attribute at all, or one without a
            # length.  A bare except here would also swallow
            # KeyboardInterrupt/SystemExit.
            raise Exception('Given stimulus is not a template stimulus!')

        # Valid Python indices into a sequence of length n: -n .. n-1.
        if not -nSegments <= self.posSeg < nSegments:
            raise Exception('segment out of range!')

        v = s.actualTemplate[self.posSeg]

        return numpy.tile(v, len(at_t))



class SumOfRates(TargetFunction):
    '''Mean firing rate over a set of channels, read out at time t-delay.

    All spikes of the selected channels are merged, turned into a rate trace
    by ``spikes2rate`` (bin width ``input_binwidth``, dt = 1 ms), divided by
    the number of channels (and additionally by ``fmax`` if ``norm`` is set)
    and linearly interpolated at ``at_t - delay``.  Query times outside the
    rate trace are clamped to its first/last sample.
    '''

    def __init__(self, delay=0e-3, fmax=100, input_binwidth=15e-3, nValues=numpy.inf, channels=None, pos_gen=None, norm=False):
        super(SumOfRates, self).__init__(pos_gen)
        self.delay = delay                    # the rate is read out at t - delay
        self.fmax = fmax                      # the maximal frequency assumed to occur (normalization constant)
        self.nValues = nValues                # number of possible values of the target function (inf for real valued); not used by values()
        self.input_binwidth = input_binwidth  # bin width handed to spikes2rate
        self.norm = norm                      # if True, divide the rates by fmax

        if channels is None:
            channels = []

        self.channels = channels              # channels to merge (empty list = all channels)

    def values(self, input, at_t):
        '''Return the (optionally normalized) mean rate at ``at_t - delay``.

        input ... stimulus, or a CombinedStimulus (see pos_gen); its
                  ``Tsim`` attribute is read
        at_t  ... time points at which to evaluate the target

        Raises ValueError if no channels are selected/available.
        '''
        stim = super(SumOfRates, self).getStim(input)

        if len(self.channels) > 0:
            trains = [stim.channel[j].data for j in self.channels]
        else:
            trains = [c.data for c in stim.channel]

        d = len(trains)
        if d == 0:
            # Previously this surfaced as an accidental ZeroDivisionError.
            raise ValueError('SumOfRates: no channels to average over')

        # Merge all spike trains into a single train and append Tsim as an
        # extra event (NOTE(review): presumably so the rate trace spans the
        # whole stimulus duration -- confirm against spikes2rate).
        # pyV1.utils is the module bound by ``import pyV1.utils``; a bare
        # ``utils`` name is not guaranteed to exist in this module.
        spikes = pyV1.utils.flatten([trains, input.Tsim])

        (rates, t) = spikes2rate(numpy.asarray(spikes), self.input_binwidth, dt=1e-3)
        # Average per channel; force float division so integer rate arrays
        # are not floor-divided.
        rates = numpy.asarray(rates, dtype=float) / d
        if self.norm:
            rates = rates / self.fmax

        ylin = scipy.interpolate.interp1d(t, rates)
        # Clamp query times into [t.min(), t.max()] so interp1d never raises
        # for times outside the rate trace.
        att = numpy.clip(numpy.asarray(at_t, dtype=float) - self.delay,
                         t.min(), t.max())

        return numpy.asarray(ylin(att))



class SpikeCorrelation(TargetFunction):
    '''Number of coincident spikes within the interval (t-delay-W, t-delay].

    A spike at time s (from any selected channel) is coincident when every
    selected channel has a spike in [s-delta, s].  For each t in ``at_t`` the
    target is the number of coincident spikes with
    t-delay-W < s <= t-delay.
    '''

    def __init__(self, delay=0e-3, W=50e-3, delta=5e-3, channels=None, pos_gen=None):
        super(SpikeCorrelation, self).__init__(pos_gen)
        self.delay = delay  # the delay of the time window (t-delay-W, t-delay]
        self.W = W          # the width of the time window (t-delay-W, t-delay]
        self.delta = delta  # precision of coincidence detection

        if channels is None:
            channels = []

        self.channels = channels  # channels checked for coincidences (empty list = all channels)

    def values(self, input, at_t):
        '''Return the coincidence counts (as floats) for every time in ``at_t``.

        input ... stimulus, or a CombinedStimulus (see pos_gen)
        at_t  ... time points at which to evaluate the target
        '''
        stim = super(SpikeCorrelation, self).getStim(input)

        # An empty selection means "all channels" -- the same channels the
        # spikes are collected from.  (Checking zero channels would make
        # every spike count as coincident.)
        if len(self.channels) > 0:
            channel_data = [stim.channel[j].data for j in self.channels]
        else:
            channel_data = [c.data for c in stim.channel]

        # Sorted spike times per channel, and all of them merged.
        trains = [numpy.sort(numpy.ravel(numpy.asarray(data, dtype=float)))
                  for data in channel_data]
        if trains:
            spikes = numpy.sort(numpy.concatenate(trains))
        else:
            spikes = numpy.zeros(0)

        # Spike s is coincident iff in every channel the first spike
        # >= s-delta exists and is <= s, i.e. the channel fired in [s-delta, s].
        coincident = numpy.ones(len(spikes), dtype=bool)
        for train in trains:
            first = numpy.searchsorted(train, spikes - self.delta, side='left')
            exists = first < len(train)
            hit = numpy.zeros(len(spikes), dtype=bool)
            hit[exists] = train[first[exists]] <= spikes[exists]
            coincident &= hit

        tco = spikes[coincident]  # coincidence times, still sorted

        t = numpy.atleast_1d(numpy.asarray(at_t, dtype=float))
        upper = t - self.delay
        lower = t - self.W - self.delay
        # #{c in tco : lower < c <= upper}, using that tco is sorted.
        counts = (numpy.searchsorted(tco, upper, side='right')
                  - numpy.searchsorted(tco, lower, side='right'))

        return numpy.maximum(counts, 0).astype(float)



class CombinedTarget(TargetFunction):
    '''combine arbitrary target functions

    The values of the i-th target (1-based) are available as ``fi`` inside
    ``expr``; module globals (e.g. ``numpy``) as well as ``input`` and
    ``at_t`` are visible too.  Example: ``'(f1+f2) % 2'`` or
    ``'f1*f2+numpy.sin(f3)'``.
    '''

    def __init__(self, targets=None, expr='f1'):
        super(CombinedTarget, self).__init__(None)
        if targets is None:
            # Fresh list per instance; a ``[]`` default would be shared.
            targets = []
        self.targets = targets  # list of target functions to be combined
        self.expr = expr        # expression over f1, f2, ... defining how to combine the target functions

    def values(self, input, at_t):
        '''Evaluate ``expr`` on the values of all targets at ``at_t``.

        Returns at least a 1-d array; NaNs of length ``len(at_t)`` when no
        targets are configured.
        '''
        if len(self.targets) == 0:
            return numpy.atleast_1d(numpy.nan * numpy.ones(len(at_t)))

        # Explicit namespace instead of exec-created locals, which is fragile
        # (it relies on exec mutating the function's local scope).
        namespace = {'input': input, 'at_t': at_t}
        for i, target in enumerate(self.targets, 1):
            namespace['f%d' % i] = numpy.array(target.values(input, at_t))

        # NOTE: eval() executes arbitrary code -- expr must come from trusted
        # code, never from external/untrusted input.
        y = eval(self.expr, globals(), namespace)

        return numpy.atleast_1d(y)
    

if __name__=='__main__':
    Tstim=1.0
    bin = 0.2
    times = numpy.arange(0, Tstim + bin/2.0, bin)
    
    jtemp=JitteredTemplate(Tstim=Tstim, nChannels=2, nTemplates=[2, 2], jitter=4e-3, freq=[5,10])
    jtemp_multi=JitteredTemplate(Tstim=Tstim, nChannels=3, nTemplates=[3, 3], jitter=4e-3, freq=[10,5,10])
    crate=ConstantRate(Tstim=Tstim)
    
    stim_seg=jtemp.generate()
    stim_seg_multi=jtemp_multi.generate()
    stim_rates=crate.generate()
    
    class_seg0=SegmentClassification(posSeg=0)    
    v_seg0=class_seg0.values(stim_seg, times)
    print "SegmentClassification (posSeg=0) values:", v_seg0, "\n"

    class_seg1=SegmentClassification(posSeg=1)
    v_seg1=class_seg1.values(stim_seg, times)
    print "SegmentClassification (posSeg=1) values:", v_seg1, "\n"

    class_seg0_multi=SegmentClassification(posSeg=0)
    v_seg0_multi=class_seg0_multi.values(stim_seg_multi, times)
    print "SegmentClassification (posSeg=0, multi) values:", v_seg0_multi, "\n"

    class_seg1_multi=SegmentClassification(posSeg=1)
    v_seg1_multi=class_seg1_multi.values(stim_seg_multi, times)
    print "SegmentClassification (posSeg=1, multi) values:", v_seg1_multi, "\n"

    try:
        class_seg2_multi=SegmentClassification(posSeg=2)
        v_seg2_multi=class_seg2_multi.values(stim_seg_multi, times)
        print "SegmentClassification (posSeg=2, multi) values:", v_seg2_multi, "\n"
    except:
        print 'Exception: ok'

    try:
        class_seg2_multi=SegmentClassification(posSeg=-3)
        v_seg2_multi=class_seg2_multi.values(stim_seg_multi, times)
        print "SegmentClassification (posSeg=2, multi) values:", v_seg2_multi, "\n"
    except:
        print 'Exception: ok'

    class_seg0_multi=SegmentClassification(posSeg=0)
    v_seg0_multi=class_seg0_multi.values(stim_seg_multi, times)
    print "SegmentClassification (posSeg=0, multi) values:", v_seg0_multi, "\n"

    spike_corr=SpikeCorrelation()
    v_corr=spike_corr.values(stim_rates, times)
    print "SpikeCorrelation values:", v_corr, "\n"

    sum_rates=SumOfRates()
    v_sumrates=sum_rates.values(stim_rates, times)
    print "SumOfRates values:", v_sumrates, "\n"

    sum_rates_norm=SumOfRates(norm=True)
    v_sumrates_norm=sum_rates_norm.values(stim_rates, times)
    print "SumOfRates (normalized) values:", v_sumrates_norm, "\n"

#    combined=CombinedTarget(targets=[SumOfRates(channels=[0]), SumOfRates(channels=[1]),
#                                    SumOfRates(channels=[2]), SumOfRates(channels=[3])],
#                                    expr='f1+f2+f3+f4')
#    v_comb=combined.values(stim_rates, times)
#    print "Combined values:", v_comb, "\n"
    
    
    jtemp1=JitteredTemplate(Tstim=Tstim, nChannels=2, nTemplates=[2, 2], jitter=4e-3, freq=[5,10])
    jtemp2=JitteredTemplate(Tstim=Tstim, nChannels=2, nTemplates=[2, 2], jitter=4e-3, freq=[5,10])
    comb_jtemp_gen=CombinedInputGenerator([jtemp1, jtemp2])

    comb_jtemp_stim=comb_jtemp_gen.generate()

    class_seg0_gen0=SegmentClassification(posSeg=-2, pos_gen=0)
    v_seg0_gen0=class_seg0_gen0.values(comb_jtemp_stim, times)
    print "SegmentClassification (posSeg=0, pos_gen=0) values:", v_seg0_gen0
    print 'actualTemplate:', comb_jtemp_stim.stim_list[0].actualTemplate, '\n'

    class_seg0_gen1=SegmentClassification(posSeg=-2, pos_gen=1)
    v_seg0_gen1=class_seg0_gen1.values(comb_jtemp_stim, times)
    print "SegmentClassification (posSeg=0, pos_gen=1) values:", v_seg0_gen1
    print 'actualTemplate:', comb_jtemp_stim.stim_list[1].actualTemplate, '\n'

    xor_target = CombinedTarget(targets=[SegmentClassification(posSeg=(0), pos_gen=0), \
                                         SegmentClassification(posSeg=(0), pos_gen=1)], expr='(f1+f2) % 2')
    v_xor_target=xor_target.values(comb_jtemp_stim, times)
    print "xor_target:", v_xor_target, "\n"


    fmax1 = 25; fmax2 = 25; fmin1 = 15; fmin2 = 15;
    rate_gen1=RandomBndRate(Tstim=Tstim, binwidth=bin, nChannels=2, fmin=fmin1, fmax=fmax1)
    rate_gen2=RandomBndRate(Tstim=Tstim, binwidth=bin, nChannels=2, fmin=fmin2, fmax=fmax2)
    comb_rate_gen=CombinedInputGenerator([rate_gen1, rate_gen2])


    comb_rate_stim=comb_rate_gen.generate()

    comb_rate_target=CombinedTarget(targets=[SumOfRates(pos_gen=0,delay=0.004), SumOfRates(pos_gen=1, delay=0.004)], expr='f1/f2')
    v_comb_target=comb_rate_target.values(comb_rate_stim, times)
    print "Combined values (f1/f2):", v_comb_target
    sum_rates_delay=SumOfRates(delay=0.004)
    v_comb_rate=sum_rates_delay.values(comb_rate_stim, times)
    print "Combined values sum_rates all:", v_comb_rate
    sum_rates_gen0=SumOfRates(pos_gen=0, delay=0.004)
    v_comb_rate_gen0=sum_rates_gen0.values(comb_rate_stim, times)
    print "Combined values sum_rates 0:", v_comb_rate_gen0
    sum_rates_gen1=SumOfRates(pos_gen=1, delay=0.004)
    v_comb_rate_gen1=sum_rates_gen1.values(comb_rate_stim, times)
    print "Combined values sum_rates 1:", v_comb_rate_gen1
    print "Combined values (f1/f2) test:", v_comb_rate_gen0/v_comb_rate_gen1, "\n"

    comb_rate_target=CombinedTarget(targets=[SumOfRates(pos_gen=0), SumOfRates(pos_gen=1)], expr='f1/f2')
    v_comb_target=comb_rate_target.values(comb_rate_stim, times)
    print "Combined values (f1/f2):", v_comb_target
    sum_rates=SumOfRates()
    v_comb_rate=sum_rates.values(comb_rate_stim, times)
    print "Combined values sum_rates all:", v_comb_rate
    sum_rates_gen0=SumOfRates(pos_gen=0)
    v_comb_rate_gen0=sum_rates_gen0.values(comb_rate_stim, times)
    print "Combined values sum_rates 0:", v_comb_rate_gen0
    sum_rates_gen1=SumOfRates(pos_gen=1)
    v_comb_rate_gen1=sum_rates_gen1.values(comb_rate_stim, times)
    print "Combined values sum_rates 1:", v_comb_rate_gen1
    print "Combined values (f1/f2) test:", v_comb_rate_gen0/v_comb_rate_gen1, "\n"