"""
analytics for general integrate-and-fire neurons driven by dichotomous noise
"""
from analytics.decorators.cache_dec import cached
from analytics.decorators.param_dec import dictparams
from analytics.helpers import integrate, heav
from numpy import exp, cos, sin, sqrt, real, arctan, arctanh, log, abs
# Finite stand-in for infinity; PIF.intervals uses -infty as the lower
# integration bound when mu - s < 0.
infty = 15
# Offset applied to interval reference points in the models' intervals()
# methods. Currently disabled (alternative value: 1e-10).
dv = 0.0
# The following classes describe neuron models (PIF, LIF, QIF, ...) for use with dichotomous noise.
class PIF:
    """Perfect integrate-and-fire model with drift f(v) = mu."""

    def f(self, v, mu):
        """Return the drift, which does not depend on v."""
        return mu

    def fp(self, v, mu):
        """Return the derivative of the drift with respect to v (zero)."""
        return 0.

    def phi(self, v, mu, s, kp, km):
        """Return phi(v) = v*(kp/(mu+s) + km/(mu-s)).

        Its derivative in v is kp/(f+s) + km/(f-s) with f = mu.
        """
        return v*(kp/(mu + s) + km/(mu - s))

    def Tdp(self, mu, s, vr, vt):
        """Return the deterministic ISI in the + state, excluding refractory time."""
        # Previously added an undefined name `tr` here (NameError), contrary
        # to the stated contract "without tr" (cf. LIF.Tdp).
        return (vt-vr)/(mu+s)

    def Tdm(self, mu, s, vr, vt):
        """Return the deterministic ISI in the - state, excluding refractory time."""
        return (vt-vr)/(mu-s)

    @cached
    def intervals(self, mu, s, vr, vt):
        """Return the tuple of (left, right, c) integration intervals.

        If mu - s < 0, the single interval runs from -infty (module constant)
        to vt. Otherwise it runs from vr to vt.
        """
        if mu - s < 0:
            return ((-infty, vt, vt+dv),)
        return ((vr, vt, vr-dv),)

    def __hash__(self):  # needed for caching
        # NOTE(review): returns a str, which builtin hash() rejects. Presumably
        # the `cached` decorator calls __hash__() directly to build a stable
        # key. Confirm before changing to an int.
        return "PIF"
class LIF:
    """Leaky integrate-and-fire model with drift f(v) = mu - v."""

    def f(self, v, mu):
        """Return the leaky drift at voltage v."""
        return mu - v

    def fp(self, v, mu):
        """Return the derivative of the drift with respect to v (constant -1)."""
        return -1.

    def phi(self, v, mu, s, kp, km):
        """Return -kp*log|mu+s-v| - km*log|mu-s-v|.

        Its derivative in v is kp/(f+s) + km/(f-s) with f = mu - v.
        """
        plus_part = kp*log(abs(mu+s-v))
        minus_part = km*log(abs(mu-s-v))
        return -plus_part - minus_part

    def Tdp(self, mu, s, vr, vt):
        """Return the deterministic ISI in the + state, excluding refractory time."""
        upper = mu + s
        return log((upper - vr)/(upper - vt))

    def Tdm(self, mu, s, vr, vt):
        """Return the deterministic ISI in the - state, excluding refractory time."""
        lower = mu - s
        return log((lower - vr)/(lower - vt))

    @cached
    def intervals(self, mu, s, vr, vt):
        """Return the tuple of (left, right, c) integration intervals.

        fixed_point = mu - s is where the - state drift f(v) - s vanishes.
        The interval [vr, vt] is split at that point when it lies inside.
        """
        fixed_point = mu - s
        if fixed_point < vr:
            return ((fixed_point, vt, vt+dv),)
        if fixed_point > vt:
            return ((vr, vt, vr-dv),)
        return ((vr, fixed_point, vr-dv), (fixed_point, vt, vt+dv))

    def __hash__(self):  # needed for caching
        # NOTE(review): str return value; presumably consumed by `cached`
        # via a direct __hash__() call. Builtin hash() would reject it.
        return "LIF"
class QIF:
    """Quadratic integrate-and-fire model with drift f(v) = mu + v**2."""

    def f(self, v, mu):
        """Return the quadratic drift at voltage v."""
        return mu + v**2

    def fp(self, v, mu):
        """Return the derivative of the drift with respect to v."""
        return 2*v

    def phi(self, v, mu, s, kp, km):
        """Return an antiderivative of kp/(f+s) + km/(f-s) with f = mu + v**2.

        The + term is an arctan and needs mu + s > 0 for a real result.
        The - term depends on the sign of mu - s:
        - For mu > s it is also an arctan.
        - Otherwise, with b = sqrt(s - mu), it is -arctanh(v/b)/b for |v| < b.
          Outside that range it is the equivalent log form.
        v must be a scalar, because it is used in a Python comparison.
        """
        plus_part = kp/sqrt(mu+s) * arctan(v/sqrt(mu+s))
        if mu > s:
            minus_part = 1./sqrt(mu-s)*arctan(v/sqrt(mu-s))
        else:
            b = sqrt(s-mu)
            if b > v and v > -b:
                minus_part = -1./b * arctanh(v/b)
            else:
                minus_part = -1./b * (.5*log((v/b+1.)/(v/b-1.)))
        return plus_part + km * minus_part

    def Tdp(self, mu, s, vr, vt):
        """Return the deterministic ISI in the + state, excluding refractory time."""
        return (arctan(vt/sqrt(mu+s))-arctan(vr/sqrt(mu+s)))/sqrt(mu+s)

    def Tdm(self, mu, s, vr, vt):
        """Return the deterministic ISI in the - state, excluding refractory time."""
        # Both arctan arguments must be scaled by sqrt(mu-s); the vr term
        # previously used sqrt(mu+s) by mistake.
        return (arctan(vt/sqrt(mu-s))-arctan(vr/sqrt(mu-s)))/sqrt(mu-s)

    @cached
    def intervals(self, mu, s, vr, vt):
        """Return the tuple of (left, right, c) integration intervals.

        For s >= mu, the - state drift f(v) - s vanishes at v = -sqrt(s-mu)
        (sfp) and v = +sqrt(s-mu) (ufp). The interval is split at those points.
        """
        if s < mu:
            # No fixed points. Use vr - dv for consistency with PIF/LIF
            # (identical to the previous `vr` while dv == 0).
            return ((vr, vt, vr-dv),)
        sfp = -sqrt(s-mu)
        ufp = sqrt(s-mu)
        if self.f(vr, mu)-s > 0:
            if vr < sfp < vt:
                if sfp < ufp < vt:
                    return ((vr, sfp, vr-dv), (sfp, ufp, ufp), (ufp, vt, ufp))
                return ((vr, sfp, vr-dv), (sfp, vt, vt+dv))
            return ((vr, vt, vr-dv),)
        if vr < ufp < vt:
            return ((sfp, ufp, ufp), (ufp, vt, ufp))
        return ((sfp, vt, vt+dv),)

    def __hash__(self):  # needed for caching
        # NOTE(review): str return value; presumably consumed by `cached`
        # via a direct __hash__() call. Builtin hash() would reject it.
        return "QIF"
@dictparams
@cached
def if_dicho_alpha(model, mu, s, kp, km, vr, vt, tr):
    """Return the fraction of trajectories crossing the threshold in the + state.

    If f(vt) - s < 0, the - state drift is negative at vt, so only + state
    trajectories can cross; 1. is returned. Otherwise alpha is built from
    phi evaluated at vr and vt and an integral from the last interval's
    reference point cn up to vt.
    """
    phi = lambda x: model.phi(x, mu, s, kp, km)
    f = lambda x: model.f(x, mu)
    ekt = exp(-(kp+km)*tr)
    if f(vt)-s < 0:
        return 1.
    # Only the reference point of the last interval is needed.
    cn = model.intervals(mu, s, vr, vt)[-1][2]
    step = heav(vr-cn)
    numerator = (step * kp/(kp+km) * exp(phi(vr)) * (1.-ekt)
                 + kp * integrate(lambda x: exp(phi(x))*heav(x-vr)/(f(x)+s), cn, vt))
    denominator = exp(phi(vt)) - step*ekt*exp(phi(vr))
    return 1. - numerator/denominator
@dictparams
@cached
def T1(model, mu, s, kp, km, vr, vt, tr):
    """Return the mean first passage time.

    The refractory time tr is added to the result. The value combines:
    - one double integral per model interval;
    - a + state integral from the last interval's reference point to vt;
    - a term weighted by alpha (if_dicho_alpha) and exp(-(kp+km)*tr).
    """
    # Called before any other local exists, so the locals() dict handed to
    # the dictparams wrapper contains exactly this function's parameters.
    al = if_dicho_alpha(locals())
    phi = lambda x: model.phi(x, mu, s, kp, km)
    f = lambda x: model.f(x, mu)
    ekt = exp(-(kp+km)*tr)

    def plusint(l, r):
        # integral over [l, r] of exp(phi(x) - phi(r)) / (f(x) + s)
        return integrate(lambda x: exp(phi(x)-phi(r))/(f(x)+s), l, r)

    def minusint(l, r):
        # integral over [l, r] of exp(phi(l) - phi(x)) / (f(x) - s)
        return integrate(lambda x: exp(-phi(x)+phi(l))/(f(x)-s), l, r)

    def other_end(l, r, c):
        # l if the reference point c equals r, otherwise r
        return l if c == r else r

    def segment(l, r, c):
        cbar = other_end(l, r, c)
        return (kp+km) * integrate(lambda x: heav(x-vr)/(f(x)+s) * minusint(x, cbar), l, r)

    ints = model.intervals(mu, s, vr, vt)
    res = sum((segment(*seg) for seg in ints), 0.)
    cn = ints[-1][2]
    c0bar = other_end(*ints[0])
    weight = (1-al)/kp*ekt + 1./(kp+km)*(1-ekt)
    boundary_term = exp(phi(vr)-phi(c0bar)) - 1. + (kp+km) * minusint(vr, c0bar)
    return res + tr + plusint(cn, vt) + weight * boundary_term
@dictparams
@cached
def P0(model, v, mu, s, kp, km, vr, vt, tr):
    """Return the stationary density at voltage v. v may not be an array!

    v is compared against interval bounds with Python operators, so it must
    be a scalar. Returns 0. for v <= left end of the first interval or
    v > right end of the last interval.
    """
    phi = lambda x: model.phi(x, mu, s, kp, km)
    f = lambda x: model.f(x, mu)
    fp = lambda x: model.fp(x, mu)
    ekt = exp(-(kp+km)*tr)
    ints = model.intervals(mu, s, vr, vt)
    # Single support check (previously split into `v < lower` and a later,
    # redundant `v <= lower`).
    if v <= ints[0][0] or v > ints[-1][1]:
        return 0.
    # The mean first passage time is T1 in this module; the former call to
    # the undefined name `if_dicho_T1` raised NameError.
    r0 = 1./T1(model, mu, s, kp, km, vr, vt, tr)
    al = if_dicho_alpha(model, mu, s, kp, km, vr, vt, tr)
    Gd = (2*al-1)*ekt + (km-kp)/(kp+km) * (1.-ekt)
    # Find the interval containing v. If v lies exactly on a boundary, no
    # break happens and the last interval's values remain bound.
    # (ints is non-empty, otherwise the support check above would fail.)
    for l, r, c in ints:
        if l < v < r:
            break
    # Explicit comparisons stand in for heav(c-vr) and heav(c-vt), fixing
    # the value at equality: 1 for c > vr and for c >= vt.
    step_vr = (1. if c > vr else 0.) - heav(v-vr)
    step_vt = (1. if c >= vt else 0.) - heav(v-vt)
    return r0/(f(v)**2-s**2) * (
        step_vr * (s*Gd-f(vr))*exp(phi(vr)-phi(v))
        - step_vt * (s*(2*al-1)-f(vt))*exp(phi(vt)-phi(v))
        + integrate(lambda x: exp(phi(x)-phi(v))*(heav(x-vr)-heav(x-vt))*(fp(x)+kp+km), c, v))