from inputgenerator import *
import numpy

from datetime import datetime

class JitteredTemplateStimulus(Stimulus):
    def __init__(self, Tsim=0.5, actualTemplate=[1]):
        super(JitteredTemplateStimulus, self).__init__(Tsim)
        self.actualTemplate = actualTemplate

    def __str__(self):
        desc = '''  JitteredTemplateStimulus
  channel        : [1x%s struct]
  Tsim          : %s
  actualTemplate : %s\n''' % (len(self.channel), self.Tsim, self.actualTemplate)
        return desc


class JitteredTemplate_Segment(object):
    def __init__(self):
        self.template = []

    def __str__(self):
        desc = 'template : [1x%s struct]' % (len(self.template))
        return desc


class JitteredTemplate_Template(object):
    def __init__(self):
        self.st = []

    def __str__(self):
        desc = 'st : %s' % (self.st)
        return desc


class JitteredTemplate(InputGenerator):
    '''  jittered spike templates '''

    def __init__(self, **kwds):
        '''  nChannels  ... number of spike trains (channels)
  Tstim      ... length of spike trains
  jitter     ... jitter to add to each spike
  nTemplates ... number of templates per segment [1 x #Seg]
  nSpikes    ... number of spikes per template (uniformly dist) [1 x #Seg]
  freq       ... frequency of poisson spike train templates [1 x #Seg]
  segment    ... stores the actual spike templates
    '''
        self.nChannels = kwds.get('nChannels', 1)
        self.Tstim = float(kwds.get('Tstim', 1.0))
        self.jitter = float(kwds.get('jitter', 4e-3))
        self.nTemplates = kwds.get('nTemplates', [2])
        self.nSpikes = kwds.get('nSpikes', [])
        self.freq = [float(f) for f in kwds.get('freq', [20])]
        self.segment = []            # stores the actual spike templates

        self.generateTemplate()


    def __str__(self):
        desc = '''  JITTERED_TEMPLATE
  nChannels  : %s
  Tstim      : %s
  jitter     : %s
  nTemplates : %s
  nSpikes    : %s
  freq       : %s
  segment    : [1x%s struct] 
''' % (self.nChannels, self.Tstim, self.jitter, self.nTemplates, self.nSpikes, self.freq, len(self.segment))
        return desc


    def generateTemplate(self):
        self.nSegments = len(self.nTemplates)

        for s in range(self.nSegments):
            self.segment.append(JitteredTemplate_Segment())

            for i in range(self.nTemplates[s]):
                self.generateSegment(self.freq[s], s, i)


    def generateSegment(self, freq, s, i):
        tau_refract=3e-3
        
        lambd = 1/freq-tau_refract

        self.segment[s].template.append(JitteredTemplate_Template())
        self.segment[s].template[i].st=[]

        for j in range(self.nChannels):
            st=[];
            if len(self.nSpikes)>0:
                while len(st)==0:
                    st=numpy.cumsum(tau_refract + numpy.random.exponential(lambd, (1,self.nSpikes[s])))
                    st=st.compress((st<(self.Tstim/self.nSegments)).flat)
                    
            elif len(self.freq)>0:
                m  = numpy.ceil(5*freq*self.Tstim)
                st = numpy.cumsum(tau_refract + numpy.random.exponential(lambd, (1, m)))
                st = st.compress((st<=(self.Tstim/self.nSegments)).flat)

                st = st + s*self.Tstim/self.nSegments;
                self.segment[s].template[i].st.append(st.tolist())


    def generateChannel(self, freq, s, i):
        pass
    
    
    def generate(self, ti = None):
        '''generates a JitteredTemplateStimulus object'''
        
        nSegments = len(self.nTemplates)
        if not ti:
            ti = [int( numpy.random.uniform(0,self.nTemplates[i],1)) for i in range(nSegments)]
        stimulus = JitteredTemplateStimulus(self.Tstim, ti);
        
        for i in range(self.nChannels):
            stimulus.channel.append(Channel())
            
            st = []
            for s in range(nSegments):
                st += self.segment[s].template[ti[s]].st[i]
            st = numpy.asarray(st)

            if self.jitter > 0.0:
                st = st + numpy.random.randn(len(st))*self.jitter
                st = st[st>=0.0]
                st = st[st<=self.Tstim]
                st.sort()
                
            stimulus.channel[i].data = st
            
        return stimulus


    def plot(self, stimulus=None, COLOR=1, FS=1):
        '''Plotting the Template and a generated input'''
        import pylab

        if stimulus is None:
            stimulus=self.generate()
            
        pylab.figure()
        
        DEFCOLS = [];
        if COLOR != 0:
            DEFCOLS = [(1,0,0),(0,1,0),(0,0,1),(1,0.5,0),(0.3,0,0.8),(0.6,0,0),(0,0,0.6),(0,0.5,0),(0,1,1),(1,1,0),(1,0,1),(0,0,0)]
            
        col=0
        maxnt=0
        nSegments = len(self.segment) 
       
        for s in range(nSegments):
            maxnt=max(maxnt, len(self.segment[s].template))
            for t in range(len(self.segment[s].template)):
                col=col+1
                if COLOR == 1:
                    if col < len(DEFCOLS):
                        self.segment[s].template[t].col = DEFCOLS[col]
                    else: 
                        self.segment[s].template[t].col = (0,0,0)
                elif COLOR == 2:
                    if t < len(DEFCOLS):
                        self.segment[s].template[t].col = DEFCOLS[t]
                    else: 
                        self.segment[s].template[t].col = (0,0,0)
                        
                    
    
        sp1=pylab.subplot(211)
        MAXY=maxnt*self.nChannels
        L=self.Tstim/nSegments        
        for s in range(nSegments):
            for t in range(maxnt):
                if t < len(self.segment[s].template):
                    for j in range(len(self.segment[s].template[t].st)):
                        for spike in self.segment[s].template[t].st[j]:
                            pylab.plot(numpy.array([spike, spike]), MAXY-1-(numpy.array([[-0.3],[0.3]])+t*self.nChannels+j), color=self.segment[s].template[t].col, linewidth=2)
                else:
                    pylab.text((s+0.5)*L, MAXY-(t*self.nChannels+1), 'only ' +str(len(self.segment[s].template)) + ' templates', horizontalalignment='center', verticalalignment='center', fontsize=10*FS)                                
                
        for t in range(maxnt):
            pylab.plot([0,self.Tstim], MAXY-(numpy.array([0.5,0.5])+t*self.nChannels), color='k', linewidth=1)
        
        for s in range(nSegments-1):
            if COLOR > 1: 
                c = (0,0,0)
            else: 
                c = (0.5,0.5,0.5)        
            pylab.plot(numpy.array([1,1])*L*(s+1), MAXY-numpy.array([0.3, maxnt*self.nChannels+0.5]),color=c,linewidth=1,linestyle='--')
    
        for s in range(nSegments):
            pylab.text(L*(s+0.5), MAXY-0.2, str(s+1)+'. segment', verticalalignment='bottom', horizontalalignment='center', fontsize=10*FS)
    
        pylab.axis('tight')
        pylab.setp(pylab.gca(), xlim=[0,self.Tstim], ylim=[-1,maxnt*self.nChannels+0.5], xticks=[], yticks=[])   
        pylab.title('possible spike train segments',fontweight='bold',fontsize=10*FS)
    
        sp2=pylab.subplot(212)
        pylab.hold(1)
        MAXY=self.nChannels
            
        for j in range(self.nChannels):
            ST=stimulus.channel[j].data[:]
            for s in range(nSegments):
                t1=s*L
                t2=t1+L;
                st = [si for si in ST if (si>t1 and si<=t2)]            
                for spike in st:
                    pylab.plot([spike, spike], MAXY-1-(numpy.array([[-0.3],[0.3]])+j), color=self.segment[s].template[stimulus.actualTemplate[s]].col, linewidth=2)
        
        if nSegments > 1:
            for s in range(nSegments-1):
                if COLOR > 1: 
                    color = (0, 0, 0)
                else: 
                    c = (0.5, 0.5, 0.5)        
                pylab.plot(numpy.array([1,1])*L*(s+1), MAXY-1-(numpy.array([0.3, self.nChannels+0.5])), color=c, linewidth=1, linestyle='--')

        for s in range(nSegments):
            pylab.text(L*(s+0.5), MAXY-0.3, 'template '+str(stimulus.actualTemplate[s]), verticalalignment='bottom', horizontalalignment='center', fontsize=10*FS)
        
        pylab.axis('tight')
        pylab.setp(pylab.gca(), xlim=[0,self.Tstim], ylim=[-1,self.nChannels+0.5], yticks=[])
        pylab.xlabel('time [sec]', fontsize=10*FS)
    
        if self.nChannels > 1: 
            tit='resulting input spike trains'
        else: 
            tit='resulting input spike train'
            
        pylab.title(tit, fontweight='bold',fontsize=10*FS)
    
        pylab.setp(sp1,position=[0.07, 0.4, 0.86, 0.5])
        pylab.setp(sp2,position=[0.07, 0.1, 0.86, 0.2])



class JitteredTemplateDFT(JitteredTemplate):
    '''jittered spike templates with different frequencies of templates'''

    def __init__(self, **kwds):
        '''  nChannels  ... number of spike trains (channels)
  Tstim      ... length of spike trains
  jitter     ... jitter to add to each spike
  nTemplates ... number of templates per segment [1 x #Seg]
  nSpikes    ... number of spikes per template (uniformly dist) [1 x #Seg]
  freq       ... frequency of poisson spike train templates [1 x max(nTemplates)]
  segment    ... stores the actual spike templates
    ''' 
        self.nChannels = kwds.get('nChannels', 1)
        self.Tstim = float(kwds.get('Tstim', 1.0))
        self.jitter = float(kwds.get('jitter', 4e-3))
        self.nTemplates = kwds.get('nTemplates', [2])
        self.nSpikes = kwds.get('nSpikes', [])
        self.freq = [float(f) for f in kwds.get('freq', [20, 20])]
        self.segment = []            # stores the actual spike templates
        self.generateTemplate()


    def __str__(self):
        desc = '''  JITTERED_TEMPLATE DFT (different frequences of templates)
  nChannels  : %s
  Tstim      : %s
  jitter     : %s
  nTemplates : %s
  nSpikes    : %s
  freq       : %s
  segment    : [1x%s struct] 
''' % (self.nChannels, self.Tstim, self.jitter, self.nTemplates, self.nSpikes, self.freq, len(self.segment))
        return desc


    def generateTemplate(self):
        self.nSegments = len(self.nTemplates)

        for s in range(self.nSegments):
            self.segment.append(JitteredTemplate_Segment())
            
            for i in range(self.nTemplates[s]):
                self.generateSegment(self.freq[i], s, i)



class JitteredTemplateDFCH(JitteredTemplate):
    '''jittered spike templates with different frequencies of channels'''

    def __init__(self, **kwds):
        '''  nChannels  ... number of spike trains (channels)
  Tstim      ... length of spike trains
  jitter     ... jitter to add to each spike
  nTemplates ... number of templates per segment [1 x #Seg]
  nSpikes    ... number of spikes per template (uniformly dist) [1 x #Seg]
  freq       ... frequency of poisson spike train templates [1 x arbitary number]
  segment    ... stores the actual spike templates
    ''' 
        self.nChannels = kwds.get('nChannels', 1)
        self.Tstim = float(kwds.get('Tstim', 1.0))
        self.jitter = float(kwds.get('jitter', 4e-3))
        self.nTemplates = kwds.get('nTemplates', [2])
        self.nSpikes = kwds.get('nSpikes', [])
        self.freq = [float(f) for f in kwds.get('freq', [20, 20])]
        self.segment = []            # stores the actual spike templates
        self.generateTemplate()


    def __str__(self):
        desc = '''  JITTERED_TEMPLATE DFCH (different frequences of channels)
  nChannels  : %s
  Tstim      : %s
  jitter     : %s
  nTemplates : %s
  nSpikes    : %s
  freq       : %s
  segment    : [1x%s struct] 
''' % (self.nChannels, self.Tstim, self.jitter, self.nTemplates, self.nSpikes, self.freq, len(self.segment))
        return desc

    def generateTemplate(self):
        self.nSegments = len(self.nTemplates)

        for s in range(self.nSegments):
            self.segment.append(JitteredTemplate_Segment())
            
            for i in range(self.nTemplates[s]):
                self.generateSegment(self.freq[i], s, i)
# TODO


class JitteredTemplateRndInit(JitteredTemplate):
    '''jittered spike templates with different frequencies of templates
the initial input in the first Tinit seconds will be a random Poisson spike train of rate finit'''
    pass


class JitteredTemplateDetection(JitteredTemplate):
    '''jittered spike templates with '''
    def generate(self):
        '''generates a JitteredTemplateStimulus object'''        
        nSegments = len(self.nTemplates)
        ti = [int( numpy.random.uniform(0,self.nTemplates[i],1)) for i in range(nSegments)]
        stimulus = JitteredTemplateStimulus(self.Tstim, ti);
        
        for i in range(self.nChannels):
            stimulus.channel.append(Channel())
            st = []
            for s in range(nSegments):
                if ti[s] == 0:
                    tau_refract=3e-3
                    lambd = 1./self.freq[s]-tau_refract
                    m  = numpy.ceil(5*self.freq[s]*self.Tstim)
                    
                    spikes = numpy.cumsum(tau_refract + numpy.random.exponential(lambd, (1, m)))
                    spikes = spikes.compress((spikes <= (self.Tstim/self.nSegments)).flat)
                    spikes = spikes + s*self.Tstim/self.nSegments;
                    st += spikes.tolist()
                else:
                    st += self.segment[s].template[ti[s]].st[i]
                
            st = numpy.asarray(st)
            if self.jitter > 0.0:
                st = st + numpy.random.randn(len(st))*self.jitter
                st = st[st>=0.0]
                st = st[st<=self.Tstim]
                st.sort()
            stimulus.channel[i].data = st
            
        return stimulus
    

if __name__ == '__main__':
#    import pylab
#    jtemp=JitteredTemplate(Tstim=1.0, nChannels=3, nTemplates=[2,2,3,2], jitter=4e-3, freq=[10,20,10,30])
#    jtemp1=JitteredTemplate(Tstim=1.0, nChannels=1, nTemplates=[2,2,2], jitter=4e-3, freq=[20, 20, 20])
#    jtemp2=JitteredTemplateDFT(Tstim=1.0, nChannels=1, nTemplates=[2,3,3], jitter=4e-3, freq=[10, 30, 10])
#    jtemp2=JitteredTemplateDFCH(Tstim=1.0, nChannels=1, nTemplates=[2,3], jitter=4e-3, freq=[10, 30, 10, 40])
#    stim1=jtemp1.generate()
#    stim2=jtemp2.generate()

#    t_start=datetime.today()
#    r=[jtemp2.generate() for i in range(0, 1000)]
#    print 'duration:', datetime.today()-t_start

    jtempdetection1=JitteredTemplateDetection(Tstim=1.0, nChannels=1, nTemplates=[2,2,2,2,2,2], jitter=1e-3, freq=[20, 20, 20, 20, 20, 20])
    s1=jtempdetection1.generate()
    
    jtempdetection1.plot(s1)
    
#    jtemp1.plot(stim1)
#    jtemp1.plot(stim1, COLOR=2)
#    jtemp2.plot(stim2)
#    jtemp2.plot(stim2, COLOR=2)
#    pylab.savefig('jitTemp')
#    pylab.show()