'''
Klaus Schuch schuch@igi.tugraz.at
March 2007
'''
from inputgenerator import *
from spikes2rate import *
from rate2spikes import *
import numpy
import scipy
import pyV1.utils
from scipy.interpolate.interpolate import interp1d
class RandomRate(RateGenerator):
    '''Generates spike trains with a rate modulated by r(t), which is drawn randomly.

    Random rate values are drawn on a coarse grid with spacing ``binwidth`` and
    linearly interpolated onto the fine time grid with step ``dt``.
    '''

    def __init__(self, **kwds):
        ''' nChannels ... number of spike trains (channels)
            Tstim     ... length of spike trains
            dt        ... time step for rate calculation
            nRates    ... number of different rates that can be chosen (equally spaced
                          in the interval [0,fmax]); default inf, i.e. continuous rates
            binwidth  ... in each interval of length binwidth a new random rate is drawn
                          (default 30e-3)
            fmax      ... maximal rate of a single channel (default 80)

            nChannels, Tstim and dt are not stored here; all keywords are forwarded
            to the base class constructor, which is expected to set them.
        '''
        super(RandomRate, self).__init__(**kwds)
        # stored as float so that it can be compared against numpy.inf in calcRate
        self.nRates = float(kwds.get('nRates', numpy.inf))
        self.binwidth = kwds.get('binwidth', 30e-3)
        self.fmax = kwds.get('fmax', 80)

    def __str__(self):
        '''Return a multi-line, human readable listing of the generator parameters.'''
        desc = ''' RANDOM RATE
nChannels : %s
Tstim : %s
nRates : %s
binwidth : %s
fmax : %s
dt : %s
''' % (self.nChannels, self.Tstim, self.nRates, self.binwidth, self.fmax, self.dt)
        return desc

    def calcRate(self):
        '''Draw a new random rate and return it sampled at numpy.arange(0, Tstim, dt).

        nRates == inf : rates are uniform in [0, fmax]
        nRates > 1    : rates are multiples of fmax/nRates
        otherwise     : the rate is constant at fmax
        '''
        tt = numpy.arange(0, self.Tstim, self.dt)
        # support points start half a bin before 0 and run past Tstim so that
        # the interpolation range spans the whole sampling grid tt
        t = numpy.arange(-self.binwidth / 2, self.Tstim + self.binwidth, self.binwidth)
        if self.nRates == numpy.inf:
            r = numpy.random.uniform(0, self.fmax, len(t))
        elif self.nRates > 1:
            # numpy.random (not the stdlib random module) is needed for the size argument
            levels = numpy.ceil(numpy.random.uniform(0, self.nRates, len(t)))
            r = levels / self.nRates * self.fmax
        else:
            r = numpy.ones(len(t)) * self.fmax
        # linear interpolation between the support points ('nearest' was not
        # implemented when this was written)
        return interp1d(t, r, 'linear')(tt)

    def plot(self, stim=None, Tseg=0):
        '''Plot rates, spike raster and auto-correlations of a stimulus in a new figure.

        stim ... stimulus to plot; if None a new one is created with self.generate().
                 The attributes Tsim, dt, r and channel[j].data are read from it.
        Tseg ... if > 0 the auto-correlation of r(t) is computed from
                 self.rand_rate(Tseg) instead of from stim.r

        Side effect: switches matplotlib to LaTeX text rendering (usetex) globally.
        '''
        import pylab
        import matplotlib
        if stim is None:
            stim = self.generate()
        matplotlib.rc('text', usetex=True)
        pylab.figure()

        # --- panel 1: prescribed rate vs. rate measured from the spikes ---
        pylab.subplot(3, 1, 1)
        binwidth = 30e-3  # self.binwidth
        tr = numpy.arange(0, stim.Tsim, stim.dt)
        # pool the spikes of all channels; 'import pyV1.utils' only binds the name pyV1
        (y, ty) = spikes2rate(pyV1.utils.flatten([c.data for c in stim.channel]), binwidth)
        # drop undefined rate samples; NaN never compares equal to anything, so
        # 'y != numpy.nan' would keep every sample -- isnan is required here
        valid = ~numpy.isnan(y)
        ty = ty.compress(valid.ravel())
        y = y.compress(valid.ravel())
        pylab.plot(tr, stim.r, ty, y / self.nChannels, 'r--')
        pylab.axis('tight')
        pylab.setp(pylab.gca(), xlim=[0, self.Tstim])
        pylab.xlabel('time [s]')
        pylab.ylabel('rate per spike train [Hz]')
        pylab.title('rates')
        pylab.legend(('r(t)', r'$r_{measured}$ ($\Delta=%s~ms$)' % str(binwidth*1000)))

        # --- panel 2: spike raster, one row per channel ---
        pylab.subplot(3, 1, 2)
        tick = numpy.array([[-0.3], [0.3]])
        for j in range(self.nChannels):
            for spike in stim.channel[j].data:
                pylab.plot(numpy.array([spike, spike]), tick + j + 1, color='k')
        pylab.setp(pylab.gca(), xlim=[0, self.Tstim], ylim=[0.5, self.nChannels+0.5], yticks=numpy.arange(self.nChannels)+1)
        pylab.xlabel('time [s]')
        pylab.ylabel('channel')
        pylab.title('spike trains')

        # --- panel 3: normalised auto-correlation of prescribed and measured rate ---
        pylab.subplot(3, 1, 3)
        if Tseg > 0:
            # NOTE(review): rand_rate is not defined in this class; presumably it is
            # provided by the base class -- confirm before relying on Tseg > 0
            r = self.rand_rate(Tseg)
        else:
            r = stim.r
        r0 = r - numpy.mean(r)
        # numpy.correlate is what the former scipy.correlate alias referred to
        cr = numpy.correlate(r0, r0, mode='full')
        cr = cr / cr.max()
        # integer division keeps lag 0 on the central sample (len(cr) is odd in 'full' mode)
        tr = (numpy.arange(len(cr)) - len(cr) // 2) * self.dt
        y0 = y - y.mean()
        cs = numpy.correlate(y0, y0, mode='full')
        cs = cs / cs.max()
        # NOTE(review): the lag axis of the measured rate also uses self.dt; verify
        # that this matches the sample spacing returned by spikes2rate
        ts = (numpy.arange(len(cs)) - len(cs) // 2) * self.dt
        pylab.plot(tr, cr, ts, cs, 'r--')
        pylab.xlabel('lag [s]')
        pylab.ylabel('correlation coeff')
        # symmetric x-range that covers both lag axes
        mm = max(numpy.abs(tr).max(), numpy.abs(ts).max())
        pylab.setp(pylab.gca(), xlim=[-mm, mm], ylim=[min(cs), 1])
        pylab.title('auto-correlation', fontweight='bold')
        pylab.legend(('r(t)', r'$r_{measured}$ ($\Delta=%s~ms$)' % str(binwidth*1000)))
if __name__ == '__main__':
    # Demo: build a one second, three channel random-rate generator and show
    # its diagnostic figure.
    import pylab
    demo = RandomRate(nChannels=3, Tstim=1.0)
    # plot() called without a stimulus generates one itself
    demo.plot()
    pylab.show()