from bulbdef import gid_is_mitral, gid_is_mtufted, gid_is_granule
class nonhebbian:
inh_invl_ltd = 250.
inh_invl_ltp = 33.33
inh_sighalf = 25.0
exc_invl_ltd = 250.
exc_invl_ltp = 33.33
exc_sighalf = 25.0
inh_slope = 10.
exc_slope = 10.
@staticmethod
def step(wlast, isi, excit):
if excit:
invl_ltp = nonhebbian.exc_invl_ltp
invl_ltd = nonhebbian.exc_invl_ltd
sighalf = nonhebbian.exc_sighalf
else:
invl_ltp = nonhebbian.inh_invl_ltp
invl_ltd = nonhebbian.inh_invl_ltd
sighalf = nonhebbian.inh_sighalf
w = wlast * 2 * sighalf
if isi < invl_ltp:
w += 1
elif isi < invl_ltd:
w -= 1
if w > 2 * sighalf:
w = 2 * sighalf
elif w < 0:
w = 0
return w / (2 * sighalf)
@staticmethod
def weights(t, winit, excit):
if excit:
winit /= 2*nonhebbian.exc_sighalf
else:
winit /= 2*nonhebbian.inh_sighalf
t = [0]+t
w = [winit]
for i in range(1, len(t)):
w.append(nonhebbian.step(w[-1], t[i]-t[i-1], excit))
return t, w
class hebbian:
pre_delay = 1
post_delay = 1
wmax = 1
tauLTP = 20.
tauLTD = 20.
aLTP = 0.001
aLTD = 0.00106
@staticmethod
def step(wlast, t, tpre, tpost, P, M, is_pre=True):
from math import exp
wmax = hebbian.wmax
tauLTP = hebbian.tauLTP
tauLTD = hebbian.tauLTD
aLTP = hebbian.aLTP
aLTD = hebbian.aLTD
if is_pre:
P = P*exp((tpre - t)/tauLTP)+aLTP
interval = tpost-t
dw = wmax*M*exp(interval/tauLTD)
else:
M = M*exp((tpost - t)/tauLTD)-aLTD
interval = t-tpre
dw = wmax*P*exp(-interval/tauLTP)
w = wlast + dw
if w > wmax:
w = wmax
elif w < 0:
w = 0
return w, P, M
@staticmethod
def weights(tpre, tpost, winit):
for i in range(len(tpre)): tpre[i] += hebbian.pre_delay
for i in range(len(tpost)): tpost[i] += hebbian.post_delay
t = [0]
w = [winit]
P = 0
M = 0
i = 1
j = 1
while i < len(tpre) and j < len(tpost):
if tpre[i] <= tpost[j]:
t.append(tpre[i])
wnew, P, M = hebbian.step(w[-1], t[-1], tpre[i-1], tpost[j-1], P, M)
i += 1
else:
t.append(tpost[j])
wnew, P, M = hebbian.step(w[-1], t[-1], tpre[i-1], tpost[j-1], P, M, False)
j += 1
w.append(wnew)
for i in range(i, len(tpre)):
t.append(tpre[i])
wnew, P, M = hebbian.step(w[-1], t[-1], tpre[i-1], tpost[j-1], P, M)
w.append(wnew)
for j in range(j, len(tpost)):
t.append(tpost[j])
wnew, P, M = hebbian.step(w[-1], t[-1], tpre[i-1], tpost[j-1], P, M, False)
w.append(wnew)
return t, w
class SpikesReader:
def __init__(self, spkfilename, wfilename=None):
from struct import unpack
self.__spkcache = {}
self.__rank_by_age = []
self.cache_size = 100
self.initweights = {}
self.tstop = None
self.fi = open(spkfilename, 'rb')
offset = unpack('>q', self.fi.read(8))[0]
Nrecord = offset/8
offset += 8
# initial weights
if wfilename:
with open(wfilename, 'r') as wfi:
line = wfi.readline()
while line:
tk = line.split()
self.initweights[int(tk[0])] = int(tk[1])
line = wfi.readline()
self.hebbian = False
# read time
if spkfilename.endswith('.spk2'):
self.tstop = unpack('>f', self.fi.read(4))[0]
offset += 4
# read the header
self.header = {}
for i in range(Nrecord):
gid, nspk = unpack('>LL', self.fi.read(8))
self.header[gid] = (offset, nspk)
offset += nspk*4
def close(self):
self.fi.close()
def retrieve(self, gid):
from struct import unpack
try:
t = self.__spkcache[gid] # retrieve from the cache before
# push on head the gid
i = self.__rank_by_age.index(gid)
aux = self.__rank_by_age[-1]
self.__rank_by_age[i] = aux
self.__rank_by_age[-1] = gid
return t+[] # clone list
except KeyError:
pass
# not cached, retrieve offset, nspikes
offset, nspk = self.header[gid]
self.fi.seek(offset)
tspk = list(unpack('>' + 'f'*nspk, self.fi.read(4*nspk)))
# add to the cache
if len(self.__spkcache) >= self.cache_size:
oldest_gid = self.__rank_by_age[0]
del self.__spkcache[oldest_gid]
del self.__rank_by_age[0]
self.__spkcache[gid] = tspk
self.__rank_by_age.append(gid)
return tspk
def weight(self, gid):
# soma does not allow syn. weight
if gid_is_mitral(gid) or \
gid_is_mtufted(gid) or \
gid_is_granule(gid):
return None
# init
try:
winit = self.initweights[gid]
except KeyError:
winit = 0
if gid % 2 != 0 and self.hebbian:
tpre = [0]
try:
tpre += self.retrieve(gid)
except KeyError: pass
tpost = [0]
try:
tpost += self.retrieve(gid+1)
except KeyError: pass
return hebbian.weights(tpre, tpost, winit)
else:
return nonhebbian.weights(self.retrieve(gid), winit, gid % 2 == 0)
def frequency(self, gid):
t = self.retrieve(gid)
fr = []
for i in range(1, len(t)):
fr.append(1000.0/(t[i]-t[i-1]))
return t[1:], fr
if __name__ == '__main__':
sr = SpikesReader('out.spk2')
print sr.tstop
print sr.retrieve(0)