# Read spike-time files (binary or text) and plot the data
# with matplotlib.
# Parameters of the Hebbian (trace-based) rule used by syn_hebb/hebbian.
aLTP = aLTD = 2e-2            # trace increments per pre / post spike
tauLTP = tauLTD = 20.         # trace time constants (spike-time units)
delay_post, delay_pre = 3, 1  # shifts added to post / pre spike trains
wmax = 1.                     # upper bound of the Hebbian weight

# Parameters of the non-Hebbian step rule used by syn_step:
# (LTD interval, LTP interval, half of the maximum step).
ltdinvl_excit, ltpinvl_excit, sighalf_excit = 250., 33.33, 25.0
ltdinvl_inhib, ltpinvl_inhib, sighalf_inhib = 250., 33.33, 25.0
import params
def syn_hebb(ispre, w, tpre, tpost, P, M, t):
    """Process one spike event of the trace-based Hebbian rule.

    Args:
        ispre: True for a pre-synaptic spike, False for a post-synaptic one.
        w: current weight.
        tpre: time of the previous pre-synaptic spike.
        tpost: time of the previous post-synaptic spike.
        P: pre-synaptic trace, as it was at ``tpre``.
        M: post-synaptic trace, as it was at ``tpost``.
        t: time of the current event.

    Returns:
        (w, P, M): the new weight clamped to [0, wmax] and the updated traces.
    """
    from math import exp
    if not ispre:
        # Post spike: decay M to now and lower it by aLTD; the weight moves
        # by the pre trace P decayed over the time since the previous pre.
        M = M * exp((tpost - t) / tauLTD) - aLTD
        dw = wmax * P * exp((tpre - t) / tauLTP)
    else:
        # Pre spike: decay P to now and raise it by aLTP; the weight moves
        # by the post trace M (non-positive when started from 0) decayed
        # over the time since the previous post.
        P = P * exp((tpre - t) / tauLTP) + aLTP
        dw = wmax * M * exp((tpost - t) / tauLTD)
    w = min(max(w + dw, 0), wmax)
    return w, P, M
def hebbian(winit, trpre, trpost):
    """Integrate the Hebbian rule over a pre- and a post-synaptic spike train.

    Both trains are shifted by ``delay_pre`` / ``delay_post`` and merged in
    time order; ties go to the pre spike. Element 0 of each train only
    serves as the "previous spike" reference for the first event of that
    train, so events are taken from index 1 onward. The trains are assumed
    to be sorted ascending (TODO confirm with callers).

    Args:
        winit: initial weight.
        trpre: pre-synaptic spike times (not modified).
        trpost: post-synaptic spike times (not modified).

    Returns:
        (t, w): event times, starting with 0, and the weight after each
        event, starting with ``winit``.
    """
    # Shifted copies: the previous version added the delays in place and
    # so destroyed the caller's lists.
    trpre = [x + delay_pre for x in trpre]
    trpost = [x + delay_post for x in trpost]
    npre, npost = len(trpre), len(trpost)
    t = [0]
    w = [winit]
    P = M = 0
    i = j = 1
    while i < npre or j < npost:
        # Take the pre spike when it is not later than the next post spike,
        # or when the post train is exhausted.
        ispre = j >= npost or (i < npre and trpre[i] <= trpost[j])
        tev = trpre[i] if ispre else trpost[j]
        t.append(tev)
        wnew, P, M = syn_hebb(ispre, w[-1], trpre[i - 1], trpost[j - 1],
                              P, M, tev)
        w.append(wnew)
        if ispre:
            i += 1
        else:
            j += 1
    return t, w
# Non-Hebbian step rule.
def syn_step(s, isi, excit=True):
    """Advance the non-Hebbian synaptic step by one inter-spike interval.

    The step rises by 1 when ``isi`` is shorter than the LTP interval, falls
    by 1 when it is shorter than the LTD interval, and is otherwise left
    alone. The result is clamped to ``[0, 2 * sighalf]``. The parameters are
    the ``*_excit`` module constants when ``excit`` is true, otherwise the
    ``*_inhib`` ones.
    """
    if excit:
        ltpinvl, ltdinvl, sighalf = ltpinvl_excit, ltdinvl_excit, sighalf_excit
    else:
        ltpinvl, ltdinvl, sighalf = ltpinvl_inhib, ltdinvl_inhib, sighalf_inhib

    if isi < ltpinvl:
        s += 1
    elif isi < ltdinvl:
        s -= 1

    upper = 2 * sighalf
    if s > upper:
        return upper
    if s < 0:
        return 0
    return s
def syn_weights(t, winit=0, excit=True):
    """Return the step value after each spike of the train ``t``.

    The first entry is ``winit``. Each later entry applies ``syn_step`` to
    the previous value, using the interval between consecutive spikes.
    """
    steps = [winit]
    for prev, cur in zip(t, t[1:]):
        steps.append(syn_step(steps[-1], cur - prev, excit))
    return steps
class SpikesReader:
    """Read per-gid spike times from a binary (.spk/.spk2) or a text file.

    Binary layout, as parsed below, is all big-endian:
      * an int64 ``size``, whose value divided by 8 is the number of header
        entries;
      * for ``.spk2`` only, a float32 stop time;
      * the header entries, each two uint32 ``(gid, n)``;
      * the float32 spike times of every entry, in header order.

    Any other filename is read as text, with one ``time gid`` pair per line.

    An optional second positional argument names a text file of
    ``gid weight`` lines. Those integer weights seed ``step``.
    """

    def __init__(self, filename, *args):
        from struct import unpack
        self.sort = False          # sort spike times when they are read
        self.cache_max_len = 100   # max number of gids kept in the cache
        self.tstop = None          # stop time; only stored in .spk2 files
        self.bincoded = filename.endswith(('.spk', '.spk2'))
        self.header = {}           # gid -> [(byte offset, spike count), ...]
        self.__initweights = {}    # gid -> initial weight
        self.__cache = {}          # gid -> spike times
        self.__old = []            # cached gids, oldest first
        if args:
            with open(args[0], 'r') as f:
                for line in f:
                    tks = line.split()
                    if len(tks) < 2:
                        continue   # skip blank / malformed lines
                    self.__initweights[int(tks[0])] = int(tks[1])
        if not self.bincoded:
            # old textual format: parsed lazily in retrieve()
            self.fi = open(filename, 'r')
            return
        self.fi = open(filename, 'rb')
        offset = unpack('>q', self.fi.read(8))[0]
        hlen = offset // 8         # '//' keeps this an int on Python 3
        if filename.endswith('.spk2'):
            self.tstop = unpack('>f', self.fi.read(4))[0]
            # NOTE(review): if ``size`` counts header bytes only (as hlen
            # assumes), data would start at offset + 12 rather than + 4 --
            # confirm against whatever produces .spk2 files.
            offset += 4
        else:
            offset += 8            # data follows the int64 size and the header
        for _ in range(hlen):
            gid, n = unpack('>LL', self.fi.read(8))
            self.header.setdefault(gid, []).append((offset, n))
            offset += n * 4        # n float32 spike times

    def __read(self, gid):
        """Read the spike times of ``gid`` straight from the file (no cache)."""
        if self.bincoded:
            from struct import unpack
            # Only the last header entry of a gid is read; KeyError if absent.
            offset, n = self.header[gid][-1]
            self.fi.seek(offset)
            return list(unpack('>%df' % n, self.fi.read(4 * n)))
        t = []
        # NOTE(review): reading starts at byte 1, so the file's first
        # character is skipped; kept as before -- confirm seek(0) wasn't meant.
        self.fi.seek(1)
        for line in iter(self.fi.readline, ''):
            tks = line.split()
            if len(tks) >= 2 and int(tks[1]) == gid:
                t.append(float(tks[0]))
        return t

    def retrieve(self, gid):
        """Return a copy of the spike times of ``gid``.

        Raises KeyError if ``gid`` has no spikes. Results are cached; once
        ``cache_max_len`` gids are cached, the oldest one is evicted.
        """
        if gid not in self.__cache:
            t = self.__read(gid)
            if not t:
                raise KeyError(gid)
            if self.sort:
                t.sort()
            while self.__old and len(self.__cache) >= self.cache_max_len:
                del self.__cache[self.__old.pop(0)]
            self.__old.append(gid)
            self.__cache[gid] = t
        return list(self.__cache[gid])

    def freqvssniff(self, gid, tstart=50.0):
        """Return the firing rate of ``gid`` per sniff window.

        Spikes at or after ``tstart`` are binned into consecutive windows of
        ``params.sniff_invl``. Earlier spikes are ignored; previously they
        were lumped into the first window. Each spike adds
        ``1000 / sniff_invl`` (Hz if times are in ms -- assumption) to its
        window.

        Returns:
            (window centres, rate per window).
        """
        invl = params.sniff_invl
        rate = 1000.0 / invl
        t = [tstart + invl * 0.5]
        nspk = [0]
        for x in self.retrieve(gid):
            if x < tstart:
                continue
            i = int((x - tstart) / invl)
            # Create every window up to i, even across gaps with no spikes,
            # with centres spaced one full interval apart.
            while i >= len(t):
                t.append(tstart + (len(t) + 0.5) * invl)
                nspk.append(0)
            nspk[i] += rate
        return t, nspk

    def frequency(self, gid):
        """Return (times, instantaneous rates) of ``gid``.

        ``times`` is 0. followed by the spike times. ``rates[k]`` is
        ``1000 / (times[k] - times[k-1])``, with ``rates[0] == 0.``.
        """
        t = [0.] + self.retrieve(gid)
        fr = [0.] + [1000. / (b - a) for a, b in zip(t, t[1:])]
        return t, fr

    def stepvssniff(self, gid, tstart=50.0):
        """Sample ``step`` once per sniff interval, starting at ``tstart``."""
        return self.step(gid, dt=params.sniff_invl, tlast=tstart)

    def __spikes_or_empty(self, gid):
        """Spike times of ``gid``, or [] when it has none (any file format)."""
        try:
            return self.retrieve(gid)
        except KeyError:
            return []

    def step(self, gid, dt=None, tlast=50.0):
        """Return the synaptic weight/step trajectory of ``gid``.

        Returns None for gids below ``mgrs.gid_mgrs_begin``.

        Odd gids with ``params.use_fi_stdp`` set use the Hebbian rule, with
        the pre-synaptic train taken from ``gid`` and the post-synaptic train
        from ``gid + 1``. All other gids use the non-Hebbian step rule on the
        train of ``gid``, treated as excitatory when ``gid`` is even.

        In the non-Hebbian case, ``dt`` controls sampling. If it is None the
        trajectory is returned at every spike. Otherwise it is sampled on a
        grid of spacing ``dt`` starting at ``tlast``.
        """
        from mgrs import gid_mgrs_begin
        if gid < gid_mgrs_begin:
            return None
        if gid % 2 != 0 and params.use_fi_stdp:
            wi = self.__initweights.get(gid, 0)
            tpre = [0.] + self.__spikes_or_empty(gid)
            # The post train belongs to gid + 1 (previously gid was checked).
            tpost = [0.] + self.__spikes_or_empty(gid + 1)
            return hebbian(wi, tpre, tpost)

        excit = gid % 2 == 0
        if gid in self.__initweights:
            s = [self.__initweights[gid]]
        else:
            try:
                if gid % 2:
                    init_weight = params.init_inh_weight
                else:
                    init_weight = params.init_exc_weight
            except AttributeError:
                init_weight = 0  # params does not define initial weights
            s = [init_weight]
        t = [0.]
        try:
            t += self.retrieve(gid)
        except KeyError:
            return t, s
        if dt is None:
            for a, b in zip(t, t[1:]):
                s.append(syn_step(s[-1], b - a, excit=excit))
            return t, s
        _t = []
        _s = []
        for a, b in zip(t, t[1:]):
            s.append(syn_step(s[-1], b - a, excit=excit))
            if b + dt > tlast:
                _t.append(tlast)
                _s.append(s[-1])
                tlast += dt
        return _t, _s

    def close(self):
        """Close the underlying spike file."""
        self.fi.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
class SpikesWriter:
    """Write spike trains as separate ``.data``/``.time``/``.header``/``.size`` files.

    All files are big-endian:
      * ``<filename>.data`` holds the float32 spike times in write order;
      * ``<filename>.header`` holds one uint32 ``(gid, count)`` pair per
        ``write`` call, in the same order;
      * ``<filename>.time`` holds the float32 stop time;
      * ``<filename>.size`` holds the int64 byte size of ``.header``.
    """

    def __init__(self, filename, tstop):
        self.filename = filename
        self.__fo = open(filename + '.data', 'wb')
        self.header = {}       # gid -> spike count of its latest write
        self.__entries = []    # (gid, count) per write, in data order
        self.tstop = tstop

    def write(self, gid, t):
        """Append the spike times ``t`` (any iterable of numbers) of ``gid``."""
        from struct import pack
        t = list(t)
        self.header[gid] = len(t)
        # Header entries must follow data order for readers that compute
        # offsets sequentially, so they are recorded in a list, not a dict.
        self.__entries.append((gid, len(t)))
        self.__fo.write(pack('>%df' % len(t), *t))

    def close(self, filename=None):
        """Close the data file and write the .time, .header and .size files.

        ``filename`` is unused; it is kept only for backward compatibility.
        """
        from os import path
        from struct import pack
        self.__fo.close()
        with open(self.filename + '.time', 'wb') as fo:
            fo.write(pack('>f', self.tstop))
        header_name = self.filename + '.header'
        with open(header_name, 'wb') as fo:
            for gid, n in self.__entries:
                fo.write(pack('>LL', gid, n))
        with open(self.filename + '.size', 'wb') as fo:
            fo.write(pack('>q', path.getsize(header_name)))
# Default x-axis limit for the plots; can be overridden on the command line
# with "-tstop <value>". Used when a reader carries no stop time of its own.
tstop = 20050
from sys import argv
try:
    tstop = float(argv[argv.index('-tstop') + 1])
except (ValueError, IndexError):
    # Flag absent, value missing, or value not a number: keep the default.
    pass
# @@@@@@@@@@@@@@
def show(sr, gids, xlabel, ylabel, call, title, ylim, legend=True):
    """Draw one figure with one curve per gid.

    ``call(gid, index, color, description)`` draws the curve of one gid and
    returns a truthy value if it drew something. The description is the last
    field of ``bindict.query(gid)``. If no gid was drawn, the figure is
    closed.

    Args:
        ylim: ``[ymin, ymax]`` to apply; any other length leaves y automatic.

    Returns:
        True if the figure was kept, False otherwise (including empty gids).
    """
    if len(gids) == 0:
        return False
    from bindict import query as descr
    import matplotlib.pyplot as plt
    plt.figure()
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    drawn = False
    for i, g in enumerate(gids):
        # Every gid is drawn; no short-circuit after the first success.
        if call(g, i, colors[i % len(colors)], descr(g)[-1]):
            drawn = True
    if not drawn:
        plt.close()
        return False
    if legend:
        leg = plt.legend()
        # Legend.draggable() was replaced by set_draggable() in matplotlib.
        if hasattr(leg, 'set_draggable'):
            leg.set_draggable(True)
        else:
            leg.draggable()
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    plt.title(title)
    if len(ylim) == 2:
        plt.ylim(ylim)
    # Prefer the reader's own stop time, then the module-level default.
    xmax = sr.tstop or tstop
    if xmax:
        plt.xlim([0, xmax])
    plt.draw()
    return True
def show_raster(sr, gids):
    """Plot a spike raster where row ``i`` holds the spikes of the i-th gid."""
    import matplotlib.pyplot as plt

    def raster(gid, row, col, descr):
        try:
            spikes = sr.retrieve(gid)
            plt.scatter(spikes, [row] * len(spikes), s=10, marker='|',
                        label=descr, c=col)
        except KeyError:
            return False
        else:
            return True

    ylim = [-1, len(gids) + 1]
    return show(sr, gids, 'spike time (ms)', '', raster, 'Spike raster', ylim)
def show_freqs(sr, gids):
import matplotlib.pyplot as plt
def freq(gid, i, col, descr):
try:
t, fr = sr.frequency(gid)
plt.plot(t, fr, '-' + col + 'o', label=descr)
except KeyError:
return False
return True
return show(sr, gids, 'spike time (ms)', 'Freq. (Hz)', freq, 'Frequency', [])
def show_weights(sr, gids):
    """Plot the normalised synaptic step/weight of each synapse gid.

    Gids below ``mgrs.gid_mgrs_begin`` are dropped. ``gids`` may be any
    iterable. Each trajectory from ``sr.step`` is divided by its maximum so
    all curves share the [0, 1] range. The maximum is ``2 * sighalf`` for the
    step rule and ``wmax`` for the Hebbian rule.
    """
    from mgrs import gid_mgrs_begin
    import matplotlib.pyplot as plt
    # Filter directly instead of building set(range(gid_mgrs_begin)).
    gids = set(g for g in gids if g >= gid_mgrs_begin)

    def step(gid, i, col, descr):
        try:
            t, d = sr.step(gid)
            if gid % 2 == 0:
                maxsig = 2 * sighalf_excit
            elif params.use_fi_stdp:
                maxsig = wmax
            else:
                maxsig = 2 * sighalf_inhib
            d = [x / maxsig for x in d]
            plt.plot(t, d, col + '-', label=descr)
        except KeyError:
            return False
        return True

    return show(sr, gids, 'spike time (ms)', 'Step', step, 'Syn. Steps',
                [-0.1, 1.1])
def show_evol(sr, gids, tstart=50.0):
    """Plot, per gid, a running count of step changes within each sniff cycle.

    Each spike's instantaneous rate (from ``sr.frequency``) adds +1 when it
    reaches ``1000 / ltpinvl``. Otherwise it adds -1 when it reaches
    ``1000 / ltdinvl``. The count restarts at 0 whenever
    ``int(t / params.sniff_invl)`` moves to a new cycle. ``tstart`` is
    currently unused.
    """
    import matplotlib.pyplot as plt

    def evol(gid, i, col, descr):
        try:
            if gid % 2:
                ltp_rate = 1000 / ltpinvl_inhib
                ltd_rate = 1000 / ltdinvl_inhib
            else:
                ltp_rate = 1000 / ltpinvl_excit
                ltd_rate = 1000 / ltdinvl_excit
            t, fr = sr.frequency(gid)
            dw = [0] * len(fr)
            cycle = 0
            running = 0
            for k in range(1, len(t)):
                this_cycle = int(t[k] / params.sniff_invl)
                if this_cycle > cycle:
                    cycle, running = this_cycle, 0
                if fr[k] >= ltp_rate:
                    running += 1
                elif fr[k] >= ltd_rate:
                    running -= 1
                dw[k] = running
            plt.plot(t, dw, '-' + col + 'o', label=descr)
        except KeyError:
            return False
        return True

    return show(sr, gids, 'spike time (ms)', 'DStep', evol, 'Evolution', [])
# Command-line entry point:
#   python <this file> -i <spike file> -gid <gid> [<gid> ...]
if __name__ == '__main__':
    from sys import argv
    import matplotlib.pyplot as plt
    try:
        spkfile = argv[argv.index('-i') + 1]
        gidpos = argv.index('-gid') + 1
    except (ValueError, IndexError):
        raise SystemExit('usage: %s -i <spike file> -gid <gid> [<gid> ...] '
                         '[-tstop <time>]' % argv[0])
    gids = set()
    for sg in argv[gidpos:]:
        try:
            gids.add(int(sg))
        except ValueError:
            break  # the gid list ends at the first non-integer argument
    sr = SpikesReader(spkfile)
    try:
        show_freqs(sr, gids)
        show_weights(sr, gids)
        show_raster(sr, gids)
        plt.show()
    finally:
        sr.close()