# Read spike times from a spike file and plot them with matplotlib
# (firing frequencies, synaptic weight/step traces and a spike raster).

# Hebbian (spike-timing-based) plasticity: parameters and step function.
# Trace increments: P += aLTP at each pre spike, M -= aLTD at each post spike.
aLTP = aLTD = 2e-2
# Decay time constants of the P and M traces (same time unit as spike times).
tauLTP = tauLTD = 20.
# Offsets added to every pre / post spike time before the trains are paired.
delay_pre, delay_post = 1, 3
# Upper bound of the Hebbian weight (the lower bound is 0).
wmax = 1.

import params
def syn_hebb(ispre, w, tpre, tpost, P, M, t):
    """Apply one spike event of the pair-based Hebbian rule.

    ispre -- True for a presynaptic spike at time t, False for a
             postsynaptic one.
    w     -- weight before the event.
    tpre  -- time P was last updated (the previous pre spike).
    tpost -- time M was last updated (the previous post spike).
    P     -- potentiation trace; decays with tauLTP, +aLTP per pre spike.
    M     -- depression trace; decays with tauLTD, -aLTD per post spike
             (non-positive when it starts from 0).
    t     -- time of the current spike.

    A pre spike updates P and changes w by wmax * M (decayed to t), which
    is depression when M < 0.  A post spike updates M and changes w by
    wmax * P (decayed to t).  The new weight is clamped to [0, wmax].

    Returns (w, P, M).
    """
    from math import exp
    if not ispre:
        M = M * exp((tpost - t) / tauLTD) - aLTD
        dw = wmax * P * exp((tpre - t) / tauLTP)
    else:
        P = P * exp((tpre - t) / tauLTP) + aLTP
        dw = wmax * M * exp((tpost - t) / tauLTD)
    return min(max(w + dw, 0), wmax), P, M

def hebbian(winit, trpre, trpost):
    """Replay a pre and a post spike train through syn_hebb.

    winit  -- initial weight.
    trpre  -- presynaptic spike times; element 0 is a reference time used
              as the "previous pre spike" of the first real event (callers
              pass 0. there).
    trpost -- postsynaptic spike times, with the same convention.

    delay_pre / delay_post are added to every time before the trains are
    merged.  Events are processed in time order; when a pre and a post
    spike coincide, the pre spike is processed first.

    The input lists are left untouched (earlier versions shifted them in
    place).

    Returns (t, w): t[0] == 0 and w[0] == winit, followed by one entry
    (event time, weight after the event) per processed spike.
    """
    # Shifted copies instead of in-place edits of the caller's data.
    trpre = [tp + delay_pre for tp in trpre]
    trpost = [tp + delay_post for tp in trpost]
    npre, npost = len(trpre), len(trpost)

    t = [0]
    w = [winit]
    P = M = 0
    i = j = 1  # index 0 of each train is the reference time, not an event
    while i < npre or j < npost:
        # Take the pre spike when the post train is exhausted or the pre
        # spike is not later than the pending post spike.
        ispre = j >= npost or (i < npre and trpre[i] <= trpost[j])
        t.append(trpre[i] if ispre else trpost[j])
        # trpre[i - 1] / trpost[j - 1] are the last pre / post spikes
        # already processed, i.e. when P / M were last updated.
        wnew, P, M = syn_hebb(ispre, w[-1], trpre[i - 1], trpost[j - 1],
                              P, M, t[-1])
        w.append(wnew)
        if ispre:
            i += 1
        else:
            j += 1

    return t, w

# Non-Hebbian step rule (driven by inter-spike intervals): parameters and step function.

# ISI bounds of the step rule: ISI < ltpinvl steps up, ltpinvl <= ISI < ltdinvl
# steps down, longer ISIs leave the step unchanged.  The step is kept in
# [0, 2 * sighalf].  Excitatory and inhibitory synapses share the same values.
ltdinvl_excit = ltdinvl_inhib = 250.
ltpinvl_excit = ltpinvl_inhib = 33.33
sighalf_excit = sighalf_inhib = 25.



def syn_step(s, isi, excit=True):
    """Advance the discrete synaptic step s by one spike.

    s     -- current step value.
    isi   -- interval since the previous spike.
    excit -- selects the excitatory (*_excit) or inhibitory (*_inhib)
             parameter set.

    An interval shorter than ltpinvl increments s.  An interval in
    [ltpinvl, ltdinvl) decrements it.  Longer intervals leave it unchanged.
    The result is clamped to [0, 2 * sighalf].
    """
    if excit:
        ltpinvl, ltdinvl, sighalf = ltpinvl_excit, ltdinvl_excit, sighalf_excit
    else:
        ltpinvl, ltdinvl, sighalf = ltpinvl_inhib, ltdinvl_inhib, sighalf_inhib

    if isi < ltpinvl:
        s += 1
    elif isi < ltdinvl:
        s -= 1

    smax = 2 * sighalf
    if s > smax:
        return smax
    if s < 0:
        return 0
    return s

class SpikesReader:
    """Random-access reader for the spike times of individual gids.

    The on-disk format is chosen from the file name:

    * ``*.spk`` -- binary, big-endian: an 8-byte signed header size H, then
      H // 8 records of (gid, n) as two unsigned 32-bit ints, then the spike
      times as 32-bit floats, one run of n values per record in header
      order.  A gid may own several records; their runs are concatenated.
    * anything else -- text, one spike per line written as "time gid".

    An optional second positional argument names a text file of
    "gid weight" lines.  Those integer weights seed step(); gids not listed
    start at 0.

    Up to ``cache_max_len`` retrieved gids are cached, and the oldest is
    evicted first.  Setting ``sort`` to True sorts spike times on retrieval.
    """

    def __init__(self, filename, *args):
        from struct import unpack
        self.sort = False
        self.cache_max_len = 100

        self.bincoded = filename.endswith('.spk')
        self.__initweights = {}  # gid -> initial weight

        if len(args) > 0:
            with open(args[0], 'r') as f:
                for line in f:
                    fields = line.split()
                    if len(fields) < 2:
                        continue  # tolerate blank/short lines (e.g. trailing newline)
                    self.__initweights[int(fields[0])] = int(fields[1])

        if self.bincoded:
            self.header = {}  # gid -> [(byte offset of spike run, spike count), ...]
            self.fi = open(filename, 'rb')
            hsize = unpack('>q', self.fi.read(8))[0]
            offset = 8 + hsize  # spike data starts right after the header
            # '//' keeps the record count an int on Python 3 as well.
            for _ in range(hsize // 8):
                gid, n = unpack('>LL', self.fi.read(8))
                self.header.setdefault(gid, []).append((offset, n))
                offset += n * 4
        else:
            self.fi = open(filename, 'r')

        self.__cache = {}  # gid -> list of spike times
        self.__old = []    # cached gids, oldest first

    def _read_binary(self, gid):
        """Read all spike runs of gid; KeyError if gid has no header record."""
        from struct import unpack
        t = []
        for offset, n in self.header[gid]:
            self.fi.seek(offset)
            # One unpack per run instead of one per float.
            t.extend(unpack('>%df' % n, self.fi.read(4 * n)))
        return t

    def _read_text(self, gid):
        """Scan the text file for gid; KeyError if it has no spikes."""
        t = []
        # Start at byte 0 (seeking to 1 dropped the first character of the
        # file); lines without two fields are skipped instead.
        self.fi.seek(0)
        for line in self.fi:
            tks = line.split()
            if len(tks) >= 2 and int(tks[1]) == gid:
                t.append(float(tks[0]))
        if not t:
            raise KeyError(gid)
        return t

    def retrieve(self, gid):
        """Return a copy of the spike times of gid.

        Raises KeyError when the file holds no spikes for gid.
        """
        if gid not in self.__cache:
            # Evict the oldest entry when the cache is full.
            if len(self.__cache) >= self.cache_max_len:
                del self.__cache[self.__old.pop(0)]

            if self.bincoded:
                t = self._read_binary(gid)
            else:
                t = self._read_text(gid)
            if self.sort:
                t = sorted(t)
            self.__old.append(gid)
            self.__cache[gid] = t

        # Copy so callers cannot alter the cached list.
        return list(self.__cache[gid])

    def _retrieve_or_empty(self, gid):
        """Like retrieve(), but return [] when gid has no spikes."""
        try:
            return self.retrieve(gid)
        except KeyError:
            return []

    def frequency(self, gid):
        """Return (t, fr) for gid.

        t is [0.] followed by the spike times.  fr[0] is 0. and
        fr[k] = 1000 / (t[k] - t[k-1]), i.e. Hz when times are in ms.
        Coincident times raise ZeroDivisionError.
        """
        t = [0.] + self.retrieve(gid)
        fr = [0.] + [1000. / (cur - prev) for prev, cur in zip(t, t[1:])]
        return t, fr

    def step(self, gid):
        """Return the (times, values) plasticity trace of gid.

        Returns None for gid < mgrs.gid_mgrs_begin.

        For odd gids with params.use_fi_stdp set, the Hebbian rule is
        replayed.  The spikes of gid act as presynaptic and those of gid + 1
        as postsynaptic; either train may be empty.

        Otherwise the non-Hebbian step rule runs over the spikes of gid.
        It uses the excitatory parameters for even gids.  Missing spikes
        raise KeyError here.
        """
        from mgrs import gid_mgrs_begin
        if gid < gid_mgrs_begin:
            return None
        if gid % 2 != 0 and params.use_fi_stdp:
            wi = self.__initweights.get(gid, 0)
            tpre = [0.] + self._retrieve_or_empty(gid)
            # The post train belongs to gid + 1.  The old code tested gid's
            # presence here and failed on text files with no header.
            tpost = [0.] + self._retrieve_or_empty(gid + 1)
            return hebbian(wi, tpre, tpost)

        t = [0.] + self.retrieve(gid)
        s = [self.__initweights.get(gid, 0)]
        excit = gid % 2 == 0
        for prev, cur in zip(t, t[1:]):
            s.append(syn_step(s[-1], cur - prev, excit=excit))
        return t, s

    def close(self):
        """Close the underlying spike file."""
        self.fi.close()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()
        return False


# read time stop
tstop = None
try:
  from sys import argv
  tstop = float(argv[argv.index('-tstop') + 1])
except:
  pass
# @@@@@@@@@@@@@@

def show(sr, gids, xlabel, ylabel, call, title, ylim, legend=True):
    """Open a figure and draw one series per gid through `call`.

    call(gid, index, color, label) draws the series for gid and returns
    True, or returns False when there is nothing to draw.  The label comes
    from bindict.query(gid)[-1].  sr is not used here; it is kept for the
    callers' common signature.

    ylim is applied only when it has two elements.  The module-level tstop
    fixes the x range when set.

    Returns True if at least one series was drawn.  Otherwise the figure is
    closed (or never opened, for empty gids) and False is returned.
    """
    if len(gids) == 0:
        return False
    from bindict import query as descr

    import matplotlib.pyplot as plt
    plt.figure()

    color = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    drawn = False
    for i, g in enumerate(gids):
        # Always call, so every gid is drawn even after one has succeeded.
        if call(g, i, color[i % len(color)], descr(g)[-1]):
            drawn = True

    if not drawn:
        plt.close()
        return False

    if legend:
        leg = plt.legend()
        # Legend.draggable() was removed in matplotlib 3.3; set_draggable()
        # (matplotlib >= 3.0) replaces it.  Fall back for older versions.
        if hasattr(leg, 'set_draggable'):
            leg.set_draggable(True)
        else:
            leg.draggable()
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    plt.title(title)
    if len(ylim) == 2:
        plt.ylim(ylim)
    if tstop:
        plt.xlim([0, tstop])
    plt.draw()
    return True

def show_raster(sr, gids):
    """Draw a spike raster: one row of '|' markers per gid.

    Row k holds the spikes of the k-th gid in iteration order.  Gids without
    spikes (KeyError from sr.retrieve) are skipped.  Returns show()'s result.
    """
    import matplotlib.pyplot as plt

    def raster(gid, i, col, descr):
        try:
            spikes = sr.retrieve(gid)
            row = [i] * len(spikes)
            plt.scatter(spikes, row, s=10, marker='|', label=descr, c=col)
        except KeyError:
            return False
        return True

    y_range = [-1, len(gids) + 1]  # one row per gid plus a margin
    return show(sr, gids, 'spike time (ms)', '', raster, 'Spike raster', y_range)
        

def show_freqs(sr, gids):
    """Plot the instantaneous firing frequency of each gid as a line with markers.

    Gids without spikes (KeyError) are skipped.  Returns show()'s result.
    """
    import matplotlib.pyplot as plt

    def freq(gid, i, col, descr):
        try:
            times, rates = sr.frequency(gid)
            plt.plot(times, rates, '-' + col + 'o', label=descr)
        except KeyError:
            return False
        return True

    autoscale = []  # no fixed y range
    return show(sr, gids, 'spike time (ms)', 'Freq. (Hz)', freq, 'Frequency', autoscale)
            
        

def show_weights(sr, gids):
    """Plot the normalized plasticity trace of each synapse gid.

    Only gids >= mgrs.gid_mgrs_begin are plotted, since SpikesReader.step
    returns None below that.  Each trace from sr.step is divided by its
    clamp maximum so all curves share the [0, 1] range:
    - even gids: 2 * sighalf_excit;
    - odd gids: wmax when params.use_fi_stdp is set, else 2 * sighalf_inhib.

    Gids whose spikes are missing (KeyError) are skipped.  Returns show()'s
    result.
    """
    from mgrs import gid_mgrs_begin

    # Filter instead of subtracting set(range(gid_mgrs_begin)): no
    # O(gid_mgrs_begin) temporary set, and negative gids (for which step()
    # returns None) are dropped too.
    gids = set(g for g in gids if g >= gid_mgrs_begin)

    import matplotlib.pyplot as plt

    def step(gid, i, col, descr):
        try:
            t, d = sr.step(gid)
            if gid % 2 == 0:
                maxsig = 2 * sighalf_excit
            elif params.use_fi_stdp:
                maxsig = wmax
            else:
                maxsig = 2 * sighalf_inhib
            d = [x / maxsig for x in d]
            plt.plot(t, d, col + '-', label=descr)
        except KeyError:
            return False
        return True

    return show(sr, gids, 'spike time (ms)', 'Step', step, 'Syn. Steps', [-0.1, 1.1])
        


# Command-line entry point:
#   <script> -i <spike file> -gid <gid> [<gid> ...] [-tstop <value>]
if __name__ == '__main__':
    import sys
    import matplotlib.pyplot as plt

    argv = sys.argv
    usage = 'usage: %s -i <spike file> -gid <gid> [<gid> ...] [-tstop <value>]' % argv[0]
    # Both options are required: exit with a usage line instead of a
    # ValueError/IndexError traceback.
    if '-i' not in argv or '-gid' not in argv or argv.index('-i') + 1 >= len(argv):
        sys.exit(usage)

    # The gids are the integers after '-gid', up to the first non-integer token.
    gids = set()
    for sg in argv[argv.index('-gid') + 1:]:
        try:
            gids.add(int(sg))
        except ValueError:
            break

    sr = SpikesReader(argv[argv.index('-i') + 1])
    try:
        show_freqs(sr, gids)
        show_weights(sr, gids)
        show_raster(sr, gids)
        plt.show()
    finally:
        sr.close()