#
# Stimulus and Input Generator classes for spike trains
# generated from spoken digits speech inputs preprocessed with
# a model of the cochlea
#
# Klampfl Stefan 10/2007
#

import numpy
import sys,os



from pyV1.dataformats import Stimulus
from pyV1.dataformats import Channel
from pyV1.inputs.inputgenerator import InputGenerator

# the stimulus class
class PreprocessedSpeechStimulus(Stimulus):
    """Stimulus describing one preprocessed spoken-digit sample.

    On top of what the ``Stimulus`` base class provides, an instance records
    which speaker, utterance and digit the sample belongs to, whether it is
    a reversed sample, and the group name ``s<speaker>_u<utterance>_d<digit>``
    under which it is filed.
    """

    def __init__(self, Tsim=0.5, speaker=1, digit=1, utterance=1, reversed=False):
        """Create a stimulus description.

        Tsim      -- forwarded unchanged to the ``Stimulus`` constructor
        speaker   -- speaker number (integer)
        digit     -- spoken digit (integer)
        utterance -- utterance number (integer)
        reversed  -- flag marking a reversed sample
        """
        super(PreprocessedSpeechStimulus, self).__init__(Tsim)
        self.speaker = speaker
        self.utterance = utterance
        self.digit = digit
        self.reversed = reversed
        # group/file name without the "_rev" suffix
        self.file = "s%d_u%d_d%d" % (speaker, utterance, digit)
        # register the additional attributes next to the ones the base class
        # already lists in save_attrs
        extra_attrs = ['speaker', 'utterance', 'digit', 'reversed', 'file']
        self.save_attrs += extra_attrs

    def __str__(self):
        """Return a multi-line, human readable summary of the stimulus."""
        template = "\n".join([
            "  PreprocessedSpeechStimulus",
            "  channel   : [1x%s struct]",
            "  Tsim      : %s",
            "  speaker   : %d",
            "  utterance : %d",
            "  digit     : %d",
            "  reversed  : %s",
            "  file      : %s",
        ]) + "\n"
        values = (
            len(self.channel),
            self.Tsim,
            self.speaker,
            self.utterance,
            self.digit,
            self.reversed,
            self.file,
        )
        return template % values

    def getNumChannels(self):
        """Return the number of entries in ``self.channel``."""
        n_channels = len(self.channel)
        return n_channels

    def getTitle(self):
        """Return the file name, followed by "_rev" for a reversed sample."""
        title = str(self.file)
        if self.reversed:
            title = title + "_rev"
        else:
            title = title + ""
        return title

# the input generator class
class PreprocessedSpeech(InputGenerator):
    """Input generator serving spoken-digit stimuli stored in an HDF5 file.

    Samples are addressed either by (speaker, utterance, digit, reversed) or
    by a flat index.  Indices ``0 .. size-1`` address the forward samples,
    indices ``size .. size+rev_size-1`` the reversed ones.  Within each
    range the digit varies fastest, then the speaker, then the utterance.
    """

    def __init__(self, **kwds):
        """Configure the generator from keyword arguments.

        Recognised keywords (all optional):

        speakers, digits, utterances
            lists of the speakers / digits / utterances to use
        rev_speakers, rev_digits, rev_utterances
            the same for the reversed samples; each defaults to the
            corresponding forward list.  If one of them is [] no reversed
            samples are used.
        rev_st
            if true, reversed samples are produced by mirroring the spike
            trains of the forward sample in time; otherwise the "_rev"
            group saved in the HDF5 file is loaded (default False)
        path
            directory containing the HDF5 file (default '.')
        h5filename
            name of the HDF5 file (default 'spkdata_10.h5')
        """
        # the speakers, digits, and utterances to use
        self.speakers = kwds.get('speakers',[1,2,5,6,7])
        self.digits = kwds.get('digits',[0,1,2,3,4,5,6,7,8,9])
        self.utterances = kwds.get('utterances',[1,2,3,4,5,6,7,8,9,10])

        # the speakers, digits, and utterances to use for the reversed samples
        # by default, all reversed samples are used, if one of these lists is []
        # no reversed samples are used
        self.rev_speakers = kwds.get('rev_speakers',self.speakers)
        self.rev_digits = kwds.get('rev_digits',self.digits)
        self.rev_utterances = kwds.get('rev_utterances',self.utterances)

        # indicates whether reversed samples are generated by just reversing the spike trains,
        # or by considering the cochleagram of the reversed audio sample, as saved in the HDF5 file
        self.rev_st = kwds.get('rev_st',False)
        # the path of the HDF5 file
        self.path = os.path.realpath(kwds.get('path','.'))
        # the number of forward speech samples
        # (computed once here; not updated if the lists are changed later)
        self.size = len(self.speakers)*len(self.digits)*len(self.utterances)
        # the number of reversed speech samples
        self.rev_size = len(self.rev_speakers)*len(self.rev_digits)*len(self.rev_utterances)
        # the h5filename containing the saved spike trains
        self.h5filename = kwds.get('h5filename','spkdata_10.h5')

    def __str__(self):
        """Return a multi-line, human readable summary of the configuration."""
        desc = '''  PREPROCESSED_SPEECH
  file           : %s
  speakers       : %s
  digits         : %s
  utterances     : %s
  rev_speakers   : %s
  rev_digits     : %s
  rev_utterances : %s
  size           : %s
  rev_size       : %s
''' % (self.h5filename, self.speakers, self.digits, self.utterances,
            self.rev_speakers, self.rev_digits, self.rev_utterances,
            self.size, self.rev_size)
        return desc

    def get_sud(self, index, reversed):
        """Translate a flat sample index into (speaker, utterance, digit).

        index    -- flat index as accepted by generateByIdx
        reversed -- if true, ``index`` addresses a reversed sample: it is
                    first reduced by ``self.size`` and the ``rev_*`` lists
                    are consulted instead of the forward lists

        No range check is done here; generateByIdx validates the index
        before calling.
        """
        if reversed:
            digits = self.rev_digits
            speakers = self.rev_speakers
            utterances = self.rev_utterances
            i = index - self.size
        else:
            digits = self.digits
            speakers = self.speakers
            utterances = self.utterances
            i = index
        # mixed-radix decomposition, digit is the fastest varying component;
        # '//' keeps the arithmetic integral regardless of the division
        # semantics in effect ('/' would yield floats under true division)
        d = digits[i % len(digits)]
        i = i // len(digits)
        s = speakers[i % len(speakers)]
        i = i // len(speakers)
        u = utterances[i]
        return (s,u,d)

    def getTotalSize(self):
        """Return the total number of samples (forward plus reversed)."""
        return self.size + self.rev_size

    def getFullH5Path(self):
        """Return the full path to the HDF5 file.

        The path is built with os.path.join, so an absolute ``h5filename``
        is returned as is.
        """
        return os.path.join(self.path, self.h5filename)

    def generate(self, what=-1):
        """Generate a stimulus.

        what -- selects the sample:
                * a negative integer (default -1): a random sample
                * an integer >= 0: the sample with that flat index
                * a tuple (speaker, utterance, digit, reversed): that
                  particular sample; ``reversed`` is a boolean

        Raises TypeError if ``what`` is neither an integer nor a tuple, or
        if the tuple does not have exactly 4 elements.
        """
        if isinstance(what, tuple):
            if len(what)!=4:
                raise TypeError("requires tuple of length 4")
            s,u,d,rev = what
            return self.generateBySUD(s,u,d,rev)
        # numpy integer scalars are accepted as indices as well; bool is an
        # int subclass but is not a meaningful index, so it stays rejected
        if isinstance(what, bool) or not isinstance(what, (int, numpy.integer)):
            raise TypeError("invalid type; expected <int> or <tuple>")
        if what<0:
            return self.generateRandom()
        return self.generateByIdx(what)

    def generateBySUD(self, speaker, utterance, digit, reversed, Tsim = None):
        """Load the stimulus for one speaker / utterance / digit combination.

        speaker, utterance, digit -- integers identifying the sample
        reversed -- true to obtain the reversed sample
        Tsim     -- if not None, assigned to the stimulus' Tsim before the
                    data is loaded
                    NOTE(review): load() may overwrite this value with the
                    one stored in the file -- confirm against Stimulus.load

        Returns the loaded PreprocessedSpeechStimulus.
        """
        stimulus = PreprocessedSpeechStimulus()
        if Tsim is not None:
            stimulus.Tsim = Tsim
        grpname = "s%d_u%d_d%d" % (speaker,utterance,digit)
        # without rev_st the reversed sample is its own group in the file;
        # plain truthiness is used so any true/false flag is handled
        if reversed and not self.rev_st:
            grpname += "_rev"
        stimulus.load(filename=self.getFullH5Path(), grpname=grpname)
        if reversed and self.rev_st:
            # mirror every spike train in time and restore ascending order
            for c in stimulus.channel:
                c.data = numpy.sort(stimulus.Tsim - c.data)
            stimulus.reversed = True
        return stimulus

    def generateByIdx(self, index):
        """Generate the stimulus with the given flat index.

        Valid indices are 0 <= index < getTotalSize(); anything else raises
        IndexError.
        """
        if index < 0 or index >= self.getTotalSize():
            raise IndexError("index %d out of bounds" % (index))
        # indices past the forward samples address the reversed ones
        is_reversed = index >= self.size
        s,u,d = self.get_sud(index,is_reversed)
        return self.generateBySUD(s,u,d,is_reversed)

    def generateRandom(self):
        """Generate a stimulus with a randomly drawn flat index."""
        idx = numpy.random.randint(self.getTotalSize())
        return self.generateByIdx(idx)

# code for generating CSIM stimuli in MATLAB and saving them
# into HDF5 from Python
# NOTE: this solution is very slow
def generateBSAinput(scale = 10, simtools_root = '/home/mammoth/dejan/simtools'):
    """Generate all speech stimuli in MATLAB and save them into an HDF5 file.

    scale         -- passed to the MATLAB function ``preprocessed_speech``
                     as its 'scale' option; also determines the output file
                     name 'spkdata_<scale>.h5'
    simtools_root -- directory containing the MATLAB toolboxes
                     (AuditoryToolbox, RCToolbox, speech)

    Every stimulus is saved as its own group, named after the stimulus file
    with "_rev" appended for reversed samples.  Requires mlabwrap and a
    working MATLAB installation.
    """
    from mlabwrap import mlab
    toolbox_dirs = ('AuditoryToolbox',
                    'RCToolbox',
                    'RCToolbox/spike_coding',
                    'RCToolbox/utility',
                    'speech')
    for subdir in toolbox_dirs:
        mlab.addpath(simtools_root + '/' + subdir)
    # startup is invoked from within the speech directory
    mlab.cd(simtools_root + '/speech')
    mlab.startup()
    mlab.cd(simtools_root + '/RCToolbox')
    InputDist = mlab.preprocessed_speech('scale',scale)
    h5filename = 'spkdata_%d.h5' % scale
    # the values come back as arrays (hence the flatten()); convert them to
    # plain ints explicitly instead of relying on range() to coerce them
    n_forward = int(mlab.get(InputDist,'size').flatten()[0])
    n_reversed = int(mlab.get(InputDist,'rev_size').flatten()[0])
    print("Saving inputs...")
    for i in range(n_forward + n_reversed):
        stimulus = PreprocessedSpeechStimulus()
        # the index passed to MATLAB is 1-based
        S = mlab.generate_input(InputDist,i+1)
        n_channels = int(mlab.length(S.channel).flatten()[0])
        for c in range(n_channels):
            channel = Channel(S.channel[c].data.flatten())
            stimulus.channel.append(channel)
        stimulus.Tsim = S.info.Tstim.flatten()[0]
        stimulus.file = S.info.file
        stimulus.speaker = S.info.speaker.flatten()[0]
        stimulus.utterance = S.info.utterance.flatten()[0]
        stimulus.digit = S.info.digit.flatten()[0]
        stimulus.reversed = S.info.reversed.flatten()[0]==1
        if stimulus.reversed:
            grpname = stimulus.file + "_rev"
        else:
            grpname = stimulus.file
        print("%d: %s" % (i,grpname))
        stimulus.save(filename=h5filename, grpname=grpname)

if __name__ == '__main__':
    import pylab

    def _show(stim):
        # plot the stimulus in a new, titled figure and print its summary
        pylab.figure()
        stim.plot()
        pylab.title(stim.getTitle())
        print(stim)

    # input distribution restricted to digits 1 and 2, without reversed digits
    indist = PreprocessedSpeech(digits=[1,2], rev_digits=[])
    print("total size: %s" % (indist.getTotalSize(),))

    # a randomly chosen stimulus
    stim1 = indist.generate()
    _show(stim1)

    # the stimulus with index 0
    stim2 = indist.generate(0)
    _show(stim2)

    # the stimulus given by speaker, utterance, digit, and reversed (bool)
    stim3 = indist.generate((2,2,2,False))
    _show(stim3)

    pylab.show()