import string
## Setting GTK backend does not work on the gj cluster,
## as I don't have a local pyGTK for my local python2.6.
## So ignore when on gj or its nodes.
import socket
hname = socket.gethostname()
if 'gulabjamun' not in hname and 'node' not in hname:
import matplotlib
#matplotlib.use('Agg')
#matplotlib.use('GTK')
from pylab import *
from matplotlib import collections
from mpl_toolkits.axes_grid.inset_locator import inset_axes
from poisson_utils import *
## choose figure or poster defaults
poster = False
if not poster:
####### figure defaults
label_fontsize = 8 # pt
plot_linewidth = 0.5 # pt
linewidth = 1.0#0.5
axes_linewidth = 0.5
marker_size = 3.0 # markersize=<...>
cap_size = 2.0 # for errorbar caps, capsize=<...>
columnwidth = 85/25.4 # inches
twocolumnwidth = 174/25.4 # inches
linfig_height = columnwidth*2.0/3.0
fig_dpi = 300
else:
####### poster defaults
label_fontsize = 12 # pt
plot_linewidth = 1.0 # pt
linewidth = 1.0
axes_linewidth = 1.0
marker_size = 3.0
cap_size = 2.0 # for errorbar caps
columnwidth = 4 # inches
linfig_height = columnwidth*2.0/3.0
#######################
########## You need to run:
## From gj:
## ./restart_mpd_static
## The 0th (boss process) will always be node000 as it is the first node in ~/hostfile.
## Hence from node000: cd to the working directory simulations/
## (so that sys.path has accurate relative paths)
## mpiexec -machinefile ~/hostfile -n <numprocs> ~/Python-2.6.4/bin/python2.6 <script_name> [args]
## 0 rank process is for collating all jobs. (rank starts from 0)
## I assume rank 0 process always runs on the machine whose
## X window system has a Display connected and can show the graphs!!!!!!
## The rank 0 stdout is always directed to the terminal from which mpiexec was run.
## I hope X output also works the same way.
## For long simulations save results in a text file
## for replotting later and avoid above ambiguity.
from mpi4py import MPI
mpicomm = MPI.COMM_WORLD
mpisize = mpicomm.Get_size() # Total number of processes
mpirank = mpicomm.Get_rank() # Number of my process
mpiname = MPI.Get_processor_name() # Name of my node
# The 0th process is the boss who collates/receives all data from workers
boss = 0
print 'Process '+str(mpirank)+' on '+mpiname+'.'
def calc_STA( inputlist, spiketrain, dt, STAtime):
""" spiketrain is a list of output spiketimes,
inputlist is the input timeseries with dt sample time,
STAtime is time for which STA must be computed.
User must ensure that inputlist is at least
as long as max spiketime in spiketrain.
endidx = int(spiketime/dt) not round(spiketime/dt)
as index 0 represents input from time 0 to time dt.
returns number of relevant spikes and
sum of Spike Triggered Averages as a list of length STAtime/dt.
"""
lenSTA = int(STAtime/dt)
STAsum = zeros(lenSTA)
numspikes = 0
for spiketime in spiketrain:
if spiketime<STAtime: continue
endidx = int(spiketime/dt)
startidx = endidx - lenSTA
STAsum += inputlist[startidx:endidx]
numspikes += 1
return numspikes,STAsum
def get_phaseimage( phaselist, phasemax, dt,\
overlay = False, rasterwidth = 10, rasterheight = 1 ):
phaseim = None
for resplist in phaselist:
## each spike 'tick' is horizontal with dimensions rasterwidth x rasterheight
phaseim_line = zeros( (int(phasemax/dt)*rasterheight, rasterwidth) )
for phase in resplist:
## a horizontal line of rasterwidth for every spike
## of a given trial, respcycle and phase.
row = int(phase/dt)*rasterheight
phaseim_line[row:row+rasterheight,:] = 1.0
if phaseim is None: phaseim = phaseim_line
else:
if overlay: phaseim += phaseim_line
## on numpy array(), axis=1, keep rows same, add cols.
else: phaseim = append( phaseim, phaseim_line, axis=1)
return phaseim
def plot_rasters(listof_rasterlists, runtime,\
colorlist=['r','g','b'], labellist=['v1','v2','v3'], labels=True):
fig = figure(facecolor='none')
ax = fig.add_subplot(111)
numrasterlists = len(listof_rasterlists)
for rlistnum,rasterlist in enumerate(listof_rasterlists):
numrasters = float(len(rasterlist))
seglist = []
for rnum,raster in enumerate(rasterlist):
for t in raster:
## append a segment for a spike
seglist.append(((t,rlistnum+rnum/numrasters),\
(t,rlistnum+(rnum+1)/numrasters)))
## plot the raster
if labels:
segs = collections.LineCollection(seglist,\
color=colorlist[rlistnum%len(colorlist)],\
label=labellist[rlistnum%len(labellist)])
else:
segs = collections.LineCollection(seglist,\
color=colorlist[rlistnum%len(colorlist)])
ax.add_collection(segs)
ax.set_xlim(0.0,runtime)
if labels:
ax.set_ylim(0,numrasterlists*1.3) # extra 0.3 height for legend
biglegend()
else:
ax.set_ylim(0,numrasterlists)
axes_labels(ax,'time (s)','spike raster trial#')
title('Spike rasters', fontsize=24)
def crosscorr(x,y):
""" pass numpy arrays x and y, so that element by element division & multiplication works.
The older version of correlate() in numpy 1.4 gives only a scalar for tau=0. """
return correlate(x,y)/(sqrt(correlate(x,x))*correlate(y,y))
def crosscorrgram(x, y, dt, halfwindow, starttime, endtime, norm='none'):
""" pass arrays x and y of numtrials arrays of spike times.
x[trialnum][tnum], y[trialnum][tnum]
dt is the binsize, T = endtime-starttime
Valid time length of the correlogram is from
-halfwindow to +halfwindow.
Analysis as per http://mulab.physiol.upenn.edu/crosscorrelation.html
I further normalize by total number of spikes (Dhawale et al 2010 fig S5).
I restrict the reference spike train x, between starttime+halfwindow to endtime-halfwindow;
I restrict the compared spike train y, between starttime to endtime.
For each spike in x, there is a sliding window of spikes in y.
norm = 'overall': a la Ashesh et al 2010, divide by (total #spikes in all sliding windows over y).
Above is same as dividing by (#spikesx * (mean #spikes in a sliding window of y)).
norm = 'analogous': divide by (sqrt(#spikesx) * sqrt(#spikesy))
Similar to dividing by sqrt(autocorrelationx)*sqrt(autocorrelationy)
norm = 'ref': divide by (#spikesx)
i.e. use the number of spikes in the reference spiketrain as the norm factor.
Normalizes such that tau=0 value of auto-corr i.e. crosscorrgram(x,x,...) = 1
norm = 'none': no division
NOTE: mean is not subtracted from the two spike trains.
To do that, first convert the list of spiketimes to a spike raster of 0s and 1s.
Then subtract the respective means, and sum after element-wise multiplication.
Finally, this function returns the average correlogram over all the trials.
With a reference spiketrain x and normalization by #spikesy,
the crosscorrgram becomes somewhat asymmetrical wrt x and y? """
T = endtime-starttime
xstarttime = starttime+halfwindow
xendtime = endtime-halfwindow
## div 2 and +1 to make number of bins odd
bins = int(4*halfwindow/dt)/2 + 1
centralbinnum = bins/2 ## integer division
corrgramavg = array([0.0]*bins)
## x[trialnum][tnum]
numtrials = len(x)
corrnums = 0
for trialnum in range(numtrials):
xtrialnum = x[trialnum]
if len(xtrialnum) == 0: continue
spikenumx = 0
spikenumy_allwindows = 0
corrgram = array([0.0]*bins)
for tx in xtrialnum:
## be careful, MOOSE inserts 0.0-s at the end of the fire times list!!!
if tx<=xstarttime or tx>=xendtime: continue
## central bin is centered around t=0
## tx=ty falls in the center of the central bin.
ystarttime = tx-halfwindow
yendtime = tx+halfwindow
spikenumx += 1
for ty in y[trialnum]:
if ty<=ystarttime or ty>=yendtime: continue
binnum = round((ty-tx)/dt)+centralbinnum
corrgram[binnum] += 1.0
spikenumy_allwindows += 1
## if variable spikenumy_thistrial exists, add it.
#if 'spikenumy_thistrial' in locals():
# spikenumy += spikenumy_thistrial
## Normalization:
## Divide by (total #spikes in all sliding windows over y).
## This is same as dividing by (#spikesx * (mean #spikes in sliding windows of y)).
if norm=='overall':
if spikenumy_allwindows>0:
corrgram /= float(spikenumy_allwindows)
corrgramavg += corrgram
corrnums += 1
#else: corrgram = [float('nan')]
## Divide by (sqrt(#spikesx) * sqrt(#spikesy))
## Similar to dividing by sqrt(autocorrelationx)*sqrt(autocorrelationy)
elif norm=='analogous':
spikenumy = 0
for ty in y[trialnum]:
if ty<starttime or ty>endtime: continue
spikenumy += 1
if spikenumx>0 and spikenumy>0:
corrgram /= (sqrt(spikenumx)*sqrt(spikenumy))
corrgramavg += corrgram
corrnums += 1
## Divide by (#spikesx)
## Normalizes such that tau=0 value of auto-corr i.e. crosscorrgram(x,x,...) = 1
elif norm=='ref':
if spikenumx>0:
corrgram /= spikenumx
corrgramavg += corrgram
corrnums += 1
else:
corrgramavg += corrgram
corrnums += 1
if corrnums==0: return array([nan]*bins)
else: return corrgramavg/float(corrnums)
## --------------------------------------
## matplotlib stuff
def axes_off(ax,x=True,y=True):
if x:
for xlabel_i in ax.get_xticklabels():
xlabel_i.set_visible(False)
xlabel_i.set_fontsize(0.0)
if y:
for xlabel_i in ax.get_yticklabels():
xlabel_i.set_fontsize(0.0)
xlabel_i.set_visible(False)
if x:
for tick in ax.get_xticklines():
tick.set_visible(False)
if y:
for tick in ax.get_yticklines():
tick.set_visible(False)
def set_tick_widths(ax,tick_width):
for tick in ax.xaxis.get_major_ticks():
tick.tick1line.set_markeredgewidth(tick_width)
tick.tick2line.set_markeredgewidth(tick_width)
for tick in ax.xaxis.get_minor_ticks():
tick.tick1line.set_markeredgewidth(tick_width)
tick.tick2line.set_markeredgewidth(tick_width)
for tick in ax.yaxis.get_major_ticks():
tick.tick1line.set_markeredgewidth(tick_width)
tick.tick2line.set_markeredgewidth(tick_width)
for tick in ax.yaxis.get_minor_ticks():
tick.tick1line.set_markeredgewidth(tick_width)
tick.tick2line.set_markeredgewidth(tick_width)
def axes_labels(ax,xtext,ytext,adjustpos=False,fontsize=label_fontsize,xpad=None,ypad=None):
ax.set_xlabel(xtext,fontsize=fontsize,labelpad=xpad)
# increase xticks text sizes
for label in ax.get_xticklabels():
label.set_fontsize(fontsize)
ax.set_ylabel(ytext,fontsize=fontsize,labelpad=ypad)
# increase yticks text sizes
for label in ax.get_yticklabels():
label.set_fontsize(fontsize)
if adjustpos:
## [left,bottom,width,height]
ax.set_position([0.135,0.125,0.84,0.75])
set_tick_widths(ax,axes_linewidth)
def biglegend(legendlocation='upper right',ax=None,fontsize=label_fontsize, **kwargs):
if ax is not None:
leg=ax.legend(loc=legendlocation, **kwargs)
else:
leg=legend(loc=legendlocation, **kwargs)
# increase legend text sizes
for t in leg.get_texts():
t.set_fontsize(fontsize)
def beautify_plot(ax,x0min=True,y0min=True,
xticksposn='bottom',yticksposn='left',xticks=None,yticks=None,
drawxaxis=True,drawyaxis=True):
"""
x0min,y0min control whether to set min of axis at 0.
xticksposn,yticksposn governs whether ticks are at
'both', 'top', 'bottom', 'left', 'right', or 'none'.
xtickx/yticks is a list of ticks, else [min,max] is taken.
Due to rendering issues,
axes do not overlap exactly with the ticks, dunno why.
"""
ax.get_yaxis().set_ticks_position(yticksposn)
ax.get_xaxis().set_ticks_position(xticksposn)
xmin, xmax = ax.get_xaxis().get_view_interval()
ymin, ymax = ax.get_yaxis().get_view_interval()
if x0min: xmin=0
if y0min: ymin=0
if xticks is None: ax.set_xticks([xmin,xmax])
else: ax.set_xticks(xticks)
if yticks is None: ax.set_yticks([ymin,ymax])
else: ax.set_yticks(yticks)
### do not set width and color of axes by below method
### axhline and axvline are not influenced by spine below.
#ax.axhline(linewidth=axes_linewidth, color="k")
#ax.axvline(linewidth=axes_linewidth, color="k")
## spine method of hiding axes is cleaner,
## but alignment problem with ticks in TkAgg backend remains.
for loc, spine in ax.spines.items(): # items() returns [(key,value),...]
spine.set_linewidth(axes_linewidth)
if loc == 'left' and not drawyaxis:
spine.set_color('none') # don't draw spine
elif loc == 'bottom' and not drawxaxis:
spine.set_color('none') # don't draw spine
elif loc in ['right','top']:
spine.set_color('none') # don't draw spine
### alternate method of drawing axes, but for it,
### need to set frameon=False in add_subplot(), etc.
#if drawxaxis:
# ax.add_artist(Line2D((xmin, xmax), (ymin, ymin),\
# color='black', linewidth=axes_linewidth))
#if drawyaxis:
# ax.add_artist(Line2D((xmin, xmin), (ymin, ymax),\
# color='black', linewidth=axes_linewidth))
## axes_labels() sets sizes of tick labels too.
axes_labels(ax,'','',adjustpos=False)
ax.set_xlim(xmin,xmax)
ax.set_ylim(ymin,ymax)
return xmin,xmax,ymin,ymax
def fig_clip_off(fig):
## clipping off for all objects in this fig
for o in fig.findobj():
o.set_clip_on(False)
## ------
## from https://gist.github.com/dmeliza/3251476#file-scalebars-py
# Adapted from mpl_toolkits.axes_grid2
# LICENSE: Python Software Foundation (http://docs.python.org/license.html)
from matplotlib.offsetbox import AnchoredOffsetbox
class AnchoredScaleBar(AnchoredOffsetbox):
def __init__(self, transform, sizex=0, sizey=0, labelx=None, labely=None, loc=4,
pad=0.1, borderpad=0.1, sep=2, prop=None, label_fontsize=label_fontsize, color='k', **kwargs):
"""
Draw a horizontal and/or vertical bar with the size in data coordinate
of the give axes. A label will be drawn underneath (center-aligned).
- transform : the coordinate frame (typically axes.transData)
- sizex,sizey : width of x,y bar, in data units. 0 to omit
- labelx,labely : labels for x,y bars; None to omit
- loc : position in containing axes
- pad, borderpad : padding, in fraction of the legend font size (or prop)
- sep : separation between labels and bars in points.
- **kwargs : additional arguments passed to base class constructor
"""
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import AuxTransformBox, VPacker, HPacker, TextArea, DrawingArea
bars = AuxTransformBox(transform)
if sizex:
bars.add_artist(Rectangle((0,0), sizex, 0, fc="none", linewidth=axes_linewidth, color=color))
if sizey:
bars.add_artist(Rectangle((0,0), 0, sizey, fc="none", linewidth=axes_linewidth, color=color))
if sizex and labelx:
textareax = TextArea(labelx,minimumdescent=False,textprops=dict(size=label_fontsize,color=color))
bars = VPacker(children=[bars, textareax], align="center", pad=0, sep=sep)
if sizey and labely:
## VPack a padstr below the rotated labely, else label y goes below the scale bar
## Just adding spaces before labely doesn't work!
padstr = '\n '*len(labely)
textareafiller = TextArea(padstr,textprops=dict(size=label_fontsize/3.0))
textareay = TextArea(labely,textprops=dict(size=label_fontsize,rotation='vertical',color=color))
## filler / pad string VPack-ed below labely
textareayoffset = VPacker(children=[textareay, textareafiller], align="center", pad=0, sep=sep)
## now HPack this padded labely to the bars
bars = HPacker(children=[textareayoffset, bars], align="top", pad=0, sep=sep)
AnchoredOffsetbox.__init__(self, loc, pad=pad, borderpad=borderpad,
child=bars, prop=prop, frameon=False, **kwargs)
def add_scalebar(ax, matchx=True, matchy=True, hidex=True, hidey=True, \
label_fontsize=label_fontsize, color='k', **kwargs):
""" Add scalebars to axes
Adds a set of scale bars to *ax*, matching the size to the ticks of the plot
and optionally hiding the x and y axes
- ax : the axis to attach ticks to
- matchx,matchy : if True, set size of scale bars to spacing between ticks
if False, size should be set using sizex and sizey params
- hidex,hidey : if True, hide x-axis and y-axis of parent
- **kwargs : additional arguments passed to AnchoredScaleBars
Returns created scalebar object
"""
def f(axis):
l = axis.get_majorticklocs()
return len(l)>1 and (l[1] - l[0])
if matchx:
kwargs['sizex'] = f(ax.xaxis)
kwargs['labelx'] = str(kwargs['sizex'])
if matchy:
kwargs['sizey'] = f(ax.yaxis)
kwargs['labely'] = str(kwargs['sizey'])
sb = AnchoredScaleBar(ax.transData, label_fontsize=label_fontsize, color=color, **kwargs)
ax.add_artist(sb)
if hidex : ax.xaxis.set_visible(False)
if hidey : ax.yaxis.set_visible(False)
return sb
## from https://gist.github.com/dmeliza/3251476#file-scalebars-py -- ends
## ------
## matplotlib stuff ends
## -----------------------------------------
def plotSpikes(firetimes, runtime, plotdt):
firetimes = array(firetimes)
# MOOSE often inserts one or two spiketime = 0.0 entries when storing spikes, so discount those:
firetimes = firetimes[ where(firetimes>0.0)[0] ]
firetimes = firetimes[ where(diff(firetimes)>2*plotdt)[0] ] # Take the falling edge of every threshold crossing.
firearray = zeros(int(round(runtime/plotdt)),dtype=int8) # 1D array of type int8
firelen = len(firearray)
for firetime in firetimes:
firearray[int(round(firelen*float(firetime))/runtime)] = 1
return firearray
def plotBins(firetimes, numbins, runtime, settletime):
binlist = [0]*numbins
firetimes = array(firetimes)
## MOOSE often inserts one or two spiketime = 0.0 entries
## when storing spikes, so discount those:
firetimes = firetimes[ where(firetimes>0.0)[0] ]
for firetime in firetimes:
if firetime>=settletime:
## The small number has been added to the Dr to ensure no index out of range errors
## Nothing to do about causality here:
## while plotting, keep bintime to right edge to make it causal
binnum = int((firetime-settletime)/(runtime-settletime+0.0001)*numbins)
binlist[binnum] += 1
return [binspikes/((runtime-settletime)/float(numbins)) for binspikes in binlist] # return firing rate in Hz
def plotOverlappingBins(firetimes, numbins, time_period, settletime, bin_width_time):
"""
Firing rate in overlapping bins (moving average).
numbins # of bins in the time (settletime) to (time_period+settletime)
Assumes periodic/wrapped boundary conditions with period=time_period.
This way the end bins are accurate,
else they will not have data to one end and show lower firing rates.
Typically, adjust settletime to bin
only the first or second respiratory cycle.
"""
CAUSAL = True
binlist = [0]*numbins
firetimes = array(firetimes)
## MOOSE often inserts one or two spiketime = 0.0 entries
## when storing spikes, so discount those:
firetimes = firetimes[ where(firetimes>0.0)[0] ]
bindt = time_period/float(numbins)
## if CAUSAL, take spikes only to the left of bin centre_times.
if CAUSAL: centre_times = arange(bindt, time_period+bindt/2.0, bindt)
else: centre_times = arange(bindt/2, time_period, bindt)
bin_half_t = bin_width_time/2.0
rightmost_t = time_period
for firetime in firetimes:
## The end bins will not show correct firing rate!
if firetime>=settletime and firetime<(settletime+time_period):
firetime -= settletime
## Each firetime is in multiple bins depending on bin_width_time
for binnum,bin_centre_t in enumerate(centre_times):
## if CAUSAL, take spikes only to the left of bin centre_times.
if CAUSAL:
bin_left = bin_centre_t - bin_width_time
bin_right = bin_centre_t
else:
bin_left = bin_centre_t - bin_half_t
bin_right = bin_centre_t + bin_half_t
if firetime >= bin_left and firetime < bin_right:
binlist[binnum] += 1
## Next lines implement circularity of firetimes
if bin_left < 0 and firetime >= (bin_left+rightmost_t):
binlist[binnum] += 1
if bin_right > rightmost_t and firetime < (bin_right-rightmost_t):
binlist[binnum] += 1
return [float(binspikes)/bin_width_time for binspikes in binlist] # return firing rate in Hz
def calcFreq(timeTable, runtime, settletime, plotdt, threshold, spiketable):
# input: if spiketable is True: timeTable has spike times: i.e. a MOOSE table which has stepMode = TAB_SPIKE
# input: if spiketable is False: timeTable has Vm-s: i.e. a MOOSE table which has stepMode = TAB_BUF
# output: (meanrate, meanrate2, events)
# output: events is a list of times of falling edges of 'spikes' separated by at least 2*eventdt.
# output: meanrate2 is just #spikes/time
# output: meanrate is mean of 1/inter-spike-interval (removing very short ISIs)
tablenumpy = array(timeTable) # convert the MOOSE table into a numpy array
if spiketable: # timeTable has spike times
# only those spike times which are after settle time.
# it is important to do this even is settletime == 0.0,
# since MOOSE inserts spurious t=0.0 spitketime entries in a spike table.
events = tablenumpy[ where(tablenumpy>settletime)[0] ]
else: # timeTable has Vm-s
cutout = tablenumpy[ int(settletime/plotdt): ] # cutout only those spike times which are after settle time.
if len(cutout) <= 0:
events = []
else:
thresholded = where(cutout>threshold)[0] # gives indices whereever cutout > THRESHOLD
# where difference between two adjacent indices in thresholded > 2
# i.e. takes falling edge of every threshold crossing
# THIS IS UNLIKE SPIKETABLE ABOVE WHICH TAKES RISING EDGE!
take = where(diff(thresholded)>2)[0] # numpy's where and diff
indices = thresholded[ take ] # indexed by a list! works for ndarray only, not for usual python lists.
# numpy multiplication of array by scalar -- very different from python list multiplication by integer!!!
events = indices*plotdt + settletime
# calculate mean firing rate as 1/inter-spike-interval (removing very short ISIs)
if len(events)>1: # at least two events needed!
firingRateList = array([])
for i in range(len(events)-1):
firingtimespan = (events[i+1]-events[i])
firingRateList = append(firingRateList,1.0/firingtimespan)
############ Have to filter out the APs which have very closely spaced double firings :
############ typically happens for close to zero currents like 0.15nA etc.
## Keep removing all firing rate entries which are greater than twice the min.
while firingRateList.max() > 2*firingRateList.min():
firingRateList = delete(firingRateList,firingRateList.argmax())
################### Finally calculate the actual mean
meanrate = array(firingRateList).mean() # 1/s = Hz
else:
meanrate = 0
# mean firing rate as #spikes/time
meanrate2 = len(events)/(runtime-settletime) # Hz
return (meanrate, meanrate2, events)
def minimum_distance(a,b,p):
""" a,b,p are vectors.
Given two end-points a and b of a line segment,
find the minimum distance from p to the the line segment.
points could be 2D or 3D."""
a = array(a)
b = array(b)
p = array(p)
## length sq of line segment
len_sq = norm(b-a)**2
if len_sq==0: return norm(p-a)
## take infinte line as c = a + t*(b-a)
## find t where the point p drops a normal to line
t = dot(b-a,p-a)/len_sq
## point before a
if t<0: return norm(p-a)
## point after b
elif t>1.0: return norm(p-b)
else: return norm( p - (a+t*(b-a)) )
def outcode(x,y,xmin,ymin,xmax,ymax):
""" utility function for the Cohen-Sutherland clipper below. """
outcode = 0x0 # inside
if y>ymax: outcode |= 0x1 # top
elif y<ymin: outcode |= 0x2 # bottom
if x>xmax: outcode |= 0x4 # right
elif x<xmin: outcode |= 0x8 # left
return outcode
def clip_line_to_rectangle(x1,y1,x2,y2,xmin,ymin,xmax,ymax):
""" clip line segment from (x1,y1) to (x2,y2) inside a rectange.
2D points only. Cohen-Sutherland algorithm (Wikipedia)"""
outcode1 = outcode(x1,y1,xmin,ymin,xmax,ymax)
outcode2 = outcode(x2,y2,xmin,ymin,xmax,ymax)
accept = False
done = False
while not done:
if not (outcode1 | outcode2): # segment fully inside
accept = True
done = True
elif outcode1 & outcode2: # segment fully outside i.e.
# both points are to the left/right/top/bottom of rectangle
done = True
else:
if outcode1>0: outcode_ex = outcode1
else: outcode_ex = outcode2
if outcode_ex & 0x1 : # top
x = x1 + (x2 - x1) * (ymax - y1) / (y2 - y1)
y = ymax
elif outcode_ex & 0x2: # bottom
x = x1 + (x2 - x1) * (ymin - y1) / (y2 - y1)
y = ymin
elif outcode_ex & 0x4: # right
y = y1 + (y2 - y1) * (xmax - x1) / (x2 - x1)
x = xmax
else: # left
y = y1 + (y2 - y1) * (xmin - x1) / (x2 - x1)
x = xmin
## get the new co-ordinates of the line
if (outcode_ex == outcode1):
x1 = x
y1 = y
outcode1 = outcode(x1, y1, xmin, ymin, xmax, ymax)
else:
x2 = x
y2 = y
outcode2 = outcode(x2, y2, xmin, ymin, xmax, ymax)
return accept,x1,y1,x2,y2
## copied savitsky_golay from scipy cookbook: http://www.scipy.org/Cookbook/SavitzkyGolay
def savitzky_golay(y, window_size, order, deriv=0):
r"""Smooth (and optionally differentiate) data with a Savitzky-Golay filter.
The Savitzky-Golay filter removes high frequency noise from data.
It has the advantage of preserving the original shape and
features of the signal better than other types of filtering
approaches, such as moving averages techhniques.
Parameters
----------
y : array_like, shape (N,)
the values of the time history of the signal.
window_size : int
the length of the window. Must be an odd integer number.
order : int
the order of the polynomial used in the filtering.
Must be less then `window_size` - 1.
deriv: int
the order of the derivative to compute (default = 0 means only smoothing)
Returns
-------
ys : ndarray, shape (N)
the smoothed signal (or it's n-th derivative).
Notes
-----
The Savitzky-Golay is a type of low-pass filter, particularly
suited for smoothing noisy data. The main idea behind this
approach is to make for each point a least-square fit with a
polynomial of high order over a odd-sized window centered at
the point.
Examples
--------
t = np.linspace(-4, 4, 500)
y = np.exp( -t**2 ) + np.random.normal(0, 0.05, t.shape)
ysg = savitzky_golay(y, window_size=31, order=4)
import matplotlib.pyplot as plt
plt.plot(t, y, label='Noisy signal')
plt.plot(t, np.exp(-t**2), 'k', lw=1.5, label='Original signal')
plt.plot(t, ysg, 'r', label='Filtered signal')
plt.legend()
plt.show()
References
----------
.. [1] A. Savitzky, M. J. E. Golay, Smoothing and Differentiation of
Data by Simplified Least Squares Procedures. Analytical
Chemistry, 1964, 36 (8), pp 1627-1639.
.. [2] Numerical Recipes 3rd Edition: The Art of Scientific Computing
W.H. Press, S.A. Teukolsky, W.T. Vetterling, B.P. Flannery
Cambridge University Press ISBN-13: 9780521880688
"""
try:
window_size = np.abs(np.int(window_size))
order = np.abs(np.int(order))
except ValueError, msg:
raise ValueError("window_size and order have to be of type int")
if window_size % 2 != 1 or window_size < 1:
raise TypeError("window_size size must be a positive odd number")
if window_size < order + 2:
raise TypeError("window_size is too small for the polynomials order")
order_range = range(order+1)
half_window = (window_size -1) // 2
# precompute coefficients
b = np.mat([[k**i for i in order_range] for k in range(-half_window, half_window+1)])
m = np.linalg.pinv(b).A[deriv]
# pad the signal at the extremes with
# values taken from the signal itself
firstvals = y[0] - np.abs( y[1:half_window+1][::-1] - y[0] )
lastvals = y[-1] + np.abs(y[-half_window-1:-1][::-1] - y[-1])
y = np.concatenate((firstvals, y, lastvals))
return np.convolve( m, y, mode='valid')
def circular_convolve(x,y,n):
"""
From: http://www.dspguru.com/dsp/tutorials/a-little-mls-tutorial
I need to compute the circular cross-correlation of y and x
N-2
Ryx[n] = 1/(N-1) * SUM{ y[i] * x[i-n] }
i=0
python indexing (negative indices) automatically gives you circularity!!!
"""
N = len(y)
return array( 1.0/(N-1)*sum( [ y[i]*x[i-n] for i in range(N-1) ] ) )
def circular_correlate(x,y):
N = len(y)
return array( [ circular_convolve(x,y,n) for n in range(len(y)) ] )
############################ Information theory functions ###############################
def binary_entropy(p0,p1):
H = 0.0
if p0>0.0:
H -= p0*log2(p0)
if p1>0.0:
H -= p1*log2(p1)
return H
def makestring(intlist):
s = ''
for val in intlist:
s += str(val)
return s
## yield defines a generator
def find_substr_endchars(mainstr,substr,delay=0):
"""
Returns a list (generator) of characters
each following a 'substr' occurence in 'mainstr' after 'delay'.
"""
## don't use regular expressions re module, which finds only non-overlapping matches
## we want to find overlapping matches too.
substrlen = len(substr)
while True:
idx = mainstr.find(substr)
## find returns -1 if substr not found
if idx != -1:
endcharidx = idx+substrlen+delay
if endcharidx<len(mainstr):
yield mainstr[endcharidx]
else: # reached end of string
break
## chop the mainstr just after the start of substr,
## not after the end, as we want overlapping strings also
mainstr = mainstr[idx+1:]
else: # substr not found
break
## yield defines a generator
def find_substrs12_endchars(sidestr,mainstr,substr1,substr2,delay1=0,delay2=0):
"""
Returns a list (generator) of characters from mainstr Y,
each following a substr1 occurence in sidestr X after delay1,
and following a substr2 occurence in mainstr Y after delay2.
"""
## don't use regular expressions re module, which finds only non-overlapping matches
## we want to find overlapping matches too.
substr2len = len(substr2)
substr1len = len(substr1)
abs_idx1 = 0 ## mainstr is getting chopped, but we maintain abs index on sidestr
while True:
idx2 = mainstr.find(substr2)
## find returns -1 if substr2 not found
if idx2 != -1:
endcharidx2 = idx2+substr2len+delay2
### NOTE: abs_startidx1 is one earlier than definition!!! I think necessary for causality.
## put +1 below to switch to definition in Quinn et al 2010
abs_startidx1 = abs_idx1 + endcharidx2 - substr1len-delay1
if endcharidx2<len(mainstr): # mainstr Y has characters left?
if abs_startidx1 >= 0: # sidestr X has sufficient chars before?
## sidestr has substr1 before the char to be returned? and mainstr is not over
## IMP: below if's first term is the only place directed info enters.
## Remove first term below and you get just the entropy of mainstr Y: VERIFIED.
#print sidestr[abs_startidx1:abs_startidx1+substr1len], substr1, abs_startidx1
if sidestr[abs_startidx1:abs_startidx1+substr1len]==substr1:
yield mainstr[endcharidx2]
else: # reached end of string
break
## chop the mainstr just after the start of substr2,
## not after the end, as we want overlapping strings also
mainstr = mainstr[idx2+1:]
## don't chop sidestr as substr1len may be greater than substr2len
## in the next iteration, idx2 will be relative, but for sidestr we maintain abs_idx1
abs_idx1 += idx2+1
else: # substr2 not found
break
def calc_entropyrate(spiketrains,markovorder,delay=0):
"""
spiketrains is a list of spiketrain = [<0|1>,...]
should be int-s, else float 1.0-s make the str-s go crazy!
J = markovorder >= 1. Cannot handle non-Markov i.e. J=0 presently!!
Returns entropy rate, assuming markov chain of order J :
H(X_{J+1}|X_J..X_1).
Assumes spike train bins are binary i.e. 0/1 value in each timebin.
delay is self delay of effect of X on X.
NOTE: one time step i.e. causal delay is permanently present, 'delay' is extra.
"""
Hrate = 0.0
N = 0
if markovorder>0:
## create all possible binary sequences priorstr=X_1...X_J
## convert integer to binary repr str of length markovorder padded with zeros (=0)
reprstr = '{:=0'+str(markovorder)+'b}'
priorstrs = [ reprstr.format(i) for i in range(int(2**markovorder)) ]
else:
## return numpy nan if markovorder <= 0
return nan
## Convert the list of timebins to a string of 0s and 1s.
## Don't do it in loops below, else the same op is repeated len(priorstrs) times.
## Below conversion is quite computationally expensive.
mcs = []
for spiketrain in spiketrains:
## A generator expression is given as argument to makestring
mcs.append(makestring(val for val in spiketrain))
## Calculate entropy for each priorstr, and sum weighted by probability of each priorstr
for priorstr in priorstrs:
num1s = 0
num0s = 0
for mc in mcs:
for postchar in find_substr_endchars(mc,priorstr,delay):
if int(postchar): # if the character just after priorstr is nonzero i.e. 1
num1s += 1
else:
num0s += 1
N_givenprior = float(num1s + num0s)
## H(X|Y) = \sum p(Y=y)*H(X|Y=y) ; the normalization by N is done at the end
## p(Y=y) = N_givenprior/N where N is total after all loops
if N_givenprior>0:
Hrate += N_givenprior * binary_entropy(num0s/N_givenprior,num1s/N_givenprior)
N += N_givenprior
if N!=0: Hrate = Hrate/N
return Hrate
def calc_dirtinforate(spiketrains1,spiketrains2,markovorder1,markovorder2,delay1=0,delay2=0):
"""
Returns directed information rate from spiketrains1 X to spiketrains2 Y.
Returns directed information rate (lim_{n->/inf} 1/n ...),
assuming train2 as markov chain of order K,
and train1 affecting it with markov order J.
I(X^n->Y^n) = H( Y_{J+1} | Y_J..Y_1 ) - H( Y_{L} | Y^{L-1}_{L-J} X^{L-1}_{L-K} ),
where L = max(J,K).
NOTE: I have changed X^{L}_{L-K-1} in definition above to X^{L-1}_{L-K} for causality!
Assumes spike train bins are binary i.e. integer 0/1 value in each timebin.
spiketrains1 and 2 are each a list of spiketrain = [<0|1>,...]
should be int-s, else float 1.0-s make the str-s go crazy!
dimensions of both must be the same.
J = markovorder1 >= 1. Cannot handle non-Markov i.e. J=0 presently!!
K = markovorder2 >= 1. Cannot handle non-Markov i.e. K=0 presently!!
Keep J,K<5, else too computationally intensive.
The prior substrings are searched delay1 and delay2 before Y_n in trains 1 and 2.
delay1 is lateral/side delay of effect of X on Y,
delay2 is self/main delay of effect of Y on Y.
NOTE: one time step i.e. causal delay is permanently present, delay1 and 2 are extra.
"""
dirtIrate_term2 = 0.0
N = 0
## for the 'cause' spike train
if markovorder1>0:
## create all possible binary sequences priorstr=X_1...X_J
## convert integer to binary repr str of length markovorder padded with zeros (=0)
reprstr = '{:=0'+str(markovorder1)+'b}'
priorstrs1 = [ reprstr.format(i) for i in range(int(2**markovorder1)) ]
else:
## return numpy nan if markovorder <= 0
return nan
## for the 'effect' spike train
if markovorder2>0:
## create all possible binary sequences priorstr=X_1...X_K
## convert integer to binary repr str of length markovorder padded with zeros (=0)
reprstr = '{:=0'+str(markovorder2)+'b}'
priorstrs2 = [ reprstr.format(i) for i in range(int(2**markovorder2)) ]
else:
## return numpy nan if markovorder <= 0
return nan
## Convert the list of timebins to a string of 0s and 1s.
## Don't do it in loops below, else the same op is repeated len(priorstrs) times.
## Below conversion is quite computationally expensive.
mcs1 = []
for spiketrain in spiketrains1:
## A generator expression is given as argument to makestring
mcs1.append(makestring(val for val in spiketrain))
mcs2 = []
for spiketrain in spiketrains2:
## A generator expression is given as argument to makestring
mcs2.append(makestring(val for val in spiketrain))
## Calculate entropy for each combo of priorstr 1 & 2,
## and sum weighted by probability of each combo
for priorstr1 in priorstrs1:
for priorstr2 in priorstrs2:
num1s = 0
num0s = 0
for chaini,mc1 in enumerate(mcs1):
mc2 = mcs2[chaini]
for postchar in find_substrs12_endchars(mc1,mc2,priorstr1,priorstr2,delay1,delay2):
## if the character just after priorstr1 & priorstr2, is nonzero i.e. 1
if int(postchar):
num1s += 1
else:
num0s += 1
N_givenpriors = float(num1s + num0s)
## H(Y|Y^X^) = \sum p(Y^=y^)*H(Y|Y^=y^,X^=x^) ;
## the normalization by N is done at the end
## p(Y^=y^,X^=x^) = N_givenpriors/N where N is total after all loops
if N_givenpriors>0:
dirtIrate_term2 += N_givenpriors * \
binary_entropy(num0s/N_givenpriors,num1s/N_givenpriors)
N += N_givenpriors
if N!=0: dirtIrate_term2 = dirtIrate_term2/N
## H( Y_{J+1} | Y_J..Y_1 )
dirtIrate_term1 = calc_entropyrate(spiketrains2,markovorder2,delay2)
## I(X^n->Y^n) = H( Y_{J+1} | Y_J..Y_1 ) - H( Y_{L} | Y^{L-1}_{L-J} X^{L-1}_{L-K} )
dirtIrate = dirtIrate_term1 - dirtIrate_term2
return dirtIrate
def get_spiketrain_from_spiketimes(\
spiketimes,starttime,timerange,numbins,warnmultiple=True,forcebinary=True):
""" bin number of spikes from starttime to endtime, into dt bins.
if warnmultiple, warn if multiple spikes are binned into a single bin.
if forcebinary, set multiple spikes in a bin to 1.
"""
## important to make these int, else spikestrs in entropy calculations go haywire!
spiketrain = zeros(numbins,dtype=int)
for spiketime in spiketimes:
spiketime_reinit = spiketime-starttime
if 0.0 < spiketime_reinit < timerange:
binnum = int(spiketime_reinit/timerange*numbins)
spiketrain[binnum] += 1
if forcebinary or warnmultiple:
multiplespike_indices = where(spiketrain>1)[0] # spiketrain must be a numpy array()
if len(multiplespike_indices)>0:
## if non-empty number of multiple spikes indices, set to 1, print warning
if forcebinary:
spiketrain[multiplespike_indices] = 1
if warnmultiple:
## do not print warnings if user has turned them off
print "There are more than 1 spikes in", \
len(multiplespike_indices), "number of bins."
if forcebinary: print "Have forced them all to be 1."
return spiketrain
##############################################################################