# -*- coding: utf-8 -*-
########## THIS FITTING PROGRAM IS MEANT TO ROUGHLY FOLLOW PRIYANKA'S ANALYSIS
########## This variant does not use an air kernel, rather a constant air rate.
## USAGE1: python2.6 fit_scaledpulses.py <../results/odor_pulses/scaledpulses_ ... .pickle> <stimseed>
from scipy import optimize
from scipy import special
import scipy.interpolate
from scipy import signal
import scipy.io
from scipy.stats import linregress
from pylab import *
import pickle
import sys
import math
import copy as cp
sys.path.extend(["..","../networks","../generators","../simulations"])
from stimuliConstants import * # has SETTLETIME, SCALED_RUNTIME
from networkConstants import * # has central_glom
from data_utils import * # has axes_labels()
from analysis_utils import * # has load_morph...(), predict_respresp() and set_min_err()
## ---- analysis switches ----
## Index into scaledList of the response that every other scaled response
## is regressed against.
ref_response_scalenum = 3
## Only selects which "BEWARE" message is printed before plotting.
peak_detect = False
## When True, extra bin-size / variance diagnostics are computed and plotted.
NOISE_ANALYSIS = False

## ---- rebinning of spikes (applied irrespective of any earlier binning) ----
## Bin width in seconds: 50 ms, as Priyanka uses.
pulserebindt = 50e-3
## Number of whole bins that fit in one scaled-pulse run.
pulserebins = int(SCALED_RUNTIME/pulserebindt)
## Bin-centre times: first centre at half a bin width, then one per bin.
pulsetlist = arange(0.5*pulserebindt,SCALED_RUNTIME,pulserebindt)

## ---- which cells to fit ----
## Indices of the mitrals to be fitted: the pair 2*central_glom, 2*central_glom+1.
fitted_mitrals = [2*central_glom,1+2*central_glom]
class fit_plot_scaledpulses():
    """Regress scaled-pulse mitral responses against a reference response and plot them.

    For each mitral in ``fitted_mitrals`` the air-subtracted response to every
    scaled pulse is linearly regressed (``linregress``) against the
    air-subtracted response at ``ref_response_scalenum``; the fit results and
    the peak of each response are collected and optionally plotted.
    """

    def __init__(self,filename,stimseed,scaledWidth=scaledWidth):
        """Remember the inputs; nothing is loaded until fit_pulses() runs.

        filename    -- path handed to read_pulsefile() (used when not in test mode).
        stimseed    -- seed handed to read_scaledpulses_stimuli() (test mode);
                       also used in printed messages and saved figure names.
        scaledWidth -- width of the scaled pulse, in the same time unit as
                       PULSE_START / pulserebindt; defaults to the module-level
                       scaledWidth.
        """
        self.filename = filename
        self.stimseed = stimseed
        self.scaledWidth = scaledWidth

    def fit_pulses(self,dirextn='',_noshow=True,_savefig=False, _test=False):
        """Fit (and optionally plot) the response scaling for every fitted mitral.

        dirextn  -- accepted for interface compatibility; not used in this method.
        _noshow  -- when True no figures are built (and hence none are saved).
        _savefig -- when True (and _noshow is False) each mitral's figure is
                    written to ../figures/ as .svg and .png.
        _test    -- when True the rebinned stimulus pulses themselves are fitted
                    instead of the mitral responses read from self.filename.

        Returns (fits_2mits, peak_scales_2mits), one entry per fitted mitral:
          fits_2mits[i]        -- list, one tuple per scalenum 1..5:
              (scale, slope, intercept, r_value, p_value, see, se_slope, avgfrate)
          peak_scales_2mits[i] -- array of max(air-subtracted response) per scalenum 1..5.
        """
        if _test:
            ########## load in the stimuli
            ## scaledPulsesList[glomnum][scalenum][odornum][binnum]
            scaledPulsesList,genORNkernels \
                = read_scaledpulses_stimuli(self.stimseed,self.scaledWidth)
            ## decimate pulseList by taking every convolutiondt/FIRINGFILLDT bins
            ## and decimate ORNkernels by taking every kerneldt/FIRINGFILLDT=50/1=50 bins
            ## pulserebinList[scalenum,binnum] numpy array
            pulserebinList,linORNkernelA,linORNkernelB = \
                decimate(scaledPulsesList[0,:,0],pulserebindt,genORNkernels,kerneldt,kernel_size)
            numtrials = 1
        else:
            ######### load in the mitral pulse responses
            #### mitral_responses_list[avgnum][scalenum][mitralnum][spikenum]
            #### mitral_responses_binned_list[avgnum][scalenum][mitralnum][binnum]
            mitral_responses_list, mitral_responses_binned_list = read_pulsefile(self.filename)

            if NOISE_ANALYSIS:
                ##---------- Print/plot for best bin size, but this bin size is not used presently
                ## From Neural Computation 19, 1503-1527 (2007) Shimazaki & Shinomoto, pg 6
                DeltaList = arange(25.0e-3,500.0e-3,25.0e-3)
                SSvals = []
                SSvals_extended = []
                ## number of trials the cost function is extrapolated to
                Priyanka_numtrials = 12
                ## the bin-size scan always looks at mitral 0
                noise_mitral = 0
                for Delta in DeltaList:
                    numtrials,mitral_responses_mean1,mitral_responses_std1 = \
                        rebin_mean(mitral_responses_list,Delta,SCALED_RUNTIME)
                    firingbinsmeanList1 = mitral_responses_mean1[:,noise_mitral]
                    ## each val in firingbinsmeanList[scalenum][binnum] is k_i/(n*Delta)
                    kbar = mean(firingbinsmeanList1)*numtrials*Delta
                    kvariance = (std(firingbinsmeanList1)*numtrials*Delta)**2
                    shimazaki_shinomoto_val = (2*kbar-kvariance)/((numtrials*Delta)**2)
                    SSvals.append(shimazaki_shinomoto_val)
                    ## single-argument print: same text under Python 2 and 3
                    print('Delta t = %s shimazaki_shinomoto_val = %s' \
                        % (Delta,shimazaki_shinomoto_val))
                    ## Cost function i.e. shimazaki_shinomoto_val for different # of trials
                    ssval_extended = (1.0/Priyanka_numtrials - 1.0/numtrials)*kbar/numtrials/Delta**2 + \
                        shimazaki_shinomoto_val
                    SSvals_extended.append(ssval_extended)
                if not _noshow:
                    fig = figure()
                    ax = fig.add_subplot(111)
                    ax.plot(DeltaList,SSvals,label='#='+str(numtrials))
                    ax.plot(DeltaList,SSvals_extended,label='#='+str(Priyanka_numtrials))
                    ax.legend()
                    axes_labels(ax,'Delta t (s)','SS Cost function (Hz^2)',fontsize=20)

            ##-------------------------- rebin the responses and pulses ------------------------------
            ## rebin sim responses to pulserebindt=50ms, then take mean
            numtrials,mitral_responses_mean,mitral_responses_std = \
                rebin_mean(mitral_responses_list,pulserebindt,SCALED_RUNTIME)

        ## Bin window analysed: from pulse start to pulse end plus kernel_time.
        ## Uses the width this instance was constructed with (not the module
        ## default), so it agrees with the stimuli loaded in test mode.
        starti = int(PULSE_START/pulserebindt)
        endi = int((PULSE_START+self.scaledWidth+kernel_time)/pulserebindt)

        ## full fitting data for both mitrals
        fits_2mits = []
        peak_scales_2mits = []
        for mit_i,fitted_mitral in enumerate(fitted_mitrals):
            if _test:
                firingbinsmeanList = pulserebinList
                ## noise hook with zero amplitude: adds zeros, in place
                firingbinsmeanList += uniform(-0,0,shape(firingbinsmeanList))
                firingbinserrList = zeros(shape(firingbinsmeanList))
            else:
                ## take the odor responses of the mitral to be fitted
                firingbinsmeanList = mitral_responses_mean[:,fitted_mitral]
                ## The model predicts the individual response not the mean.
                ## Hence below fitting uses standard deviation, not standard error of the mean.
                firingbinserrList = mitral_responses_std[:,fitted_mitral]
            ## index 0 is the air-only (background) response
            air_bgnd_relevant = firingbinsmeanList[0][starti:endi]

            ##---------------------------- fit scaled pulse responses ------------------------------------------
            ## define the reference response / scaling
            ref_scale = scaledList[ref_response_scalenum]
            ref_response = firingbinsmeanList[ref_response_scalenum][starti:endi]-air_bgnd_relevant
            fits = []
            peak_scales = []
            for scalenum in [1,2,3,4,5]: ## conc scaled pulses
                scale = scaledList[scalenum]
                response = firingbinsmeanList[scalenum][starti:endi]-air_bgnd_relevant
                ## http://www.jerrydallal.com/LHSP/slrout.htm for defn of std error of the estimate
                ## SEE is std error of data about the regression line
                slope, intercept, r_value, p_value, see = linregress(ref_response,response)
                ## SEE is called \hat{\sigma}_\epsilon i.e. sqrt(MSE) here:
                ## http://en.wikipedia.org/wiki/Regression_analysis
                ## formula for std error of slope is also from above Wikipedia article
                ## NOTE(review): current SciPy documents linregress's fifth return
                ## value as the standard error of the slope itself; if that holds for
                ## the SciPy in use, this division normalizes a second time -- confirm.
                se_slope = see/std(ref_response)
                avgfrate = sum(firingbinsmeanList[scalenum][starti:endi])/float(endi-starti)
                fits.append((scale,slope,intercept,r_value,p_value,see,se_slope,avgfrate))
                ## peak scaling
                peak_scales.append(max(response))
            peak_scales = array(peak_scales)
            peak_scales_2mits.append(peak_scales)
            fits_2mits.append(fits)

            ##---------------------------- plot scaled pulse responses -----------------------------------------
            if not _noshow:
                if peak_detect: print("BEWARE. Using peak scaling")
                else: print("BEWARE. Using fitted scaling")
                scale_colors = ['r','b','g','m','c']
                scale_markers = ['s','+','d','x','.']

                ############################### plot the responses and the fits
                fig = figure(figsize=(columnwidth,linfig_height/2),dpi=300,facecolor='w') # 'none' is transparent

                ## --- panel 1: air-subtracted responses vs time ---
                ## conc scaled pulses, leave the 0th pulse which is air_bgnd
                ax = plt.subplot2grid((1,3),(0,0),rowspan=1)
                scaledpulsetime = array(pulsetlist[starti:endi]) - pulsetlist[starti] # start from t=0
                for scalenum in range(1,len(scaledList)): ## conc scaled pulses
                    ## unsmoothed response (no gaussian smoothing for scaled pulses)
                    simresponse = firingbinsmeanList[scalenum][starti:endi]-air_bgnd_relevant
                    ## shaded band is +/- std/sqrt(numtrials) about the mean response
                    ## numpy arrays, hence adds element by element
                    errband = firingbinserrList[scalenum][starti:endi]/sqrt(numtrials)
                    scale_color = scale_colors[scalenum-1]
                    fill_between(scaledpulsetime,simresponse+errband,simresponse-errband,
                        color=scale_color,alpha=0.3)
                    plot(scaledpulsetime,simresponse,linewidth=plot_linewidth,color=scale_color)
                xmin,xmax,ymin,ymax = beautify_plot(ax,x0min=True,y0min=False,xticksposn='bottom',yticksposn='left')
                ax.set_xticks([0,xmax])
                if ymin<10: ax.set_yticks([ymin,0,ymax])
                else: ax.set_yticks([ymin,ymax])
                ## red bar marks the duration of the scaled pulse
                plot([0,self.scaledWidth],[ymin+2,ymin+2],linewidth=plot_linewidth*3,color='r')
                axes_labels(ax,'s','Hz',adjustpos=False,xpad=0,ypad=-3)

                ## --- panel 2: each response vs the reference response, with fitted line ---
                ax = plt.subplot2grid((1,3),(0,1),rowspan=1)
                minx = min(ref_response)
                maxx = max(ref_response)
                xlist = arange(minx,maxx,(maxx-minx)/50.0) # this is in Hz
                for scalenum in [1,2,3,4,5]: ## conc scaled pulses
                    response = firingbinsmeanList[scalenum][starti:endi]-air_bgnd_relevant
                    scale_normed,slope,intercept,r_value,_,_,_,_ = fits[scalenum-1]
                    print("stimseed = %s , scale = %s r^2 =  %s" \
                        % (self.stimseed,scale_normed,r_value**2))
                    ## NOTE(review): scatter and fit line are both skipped for the
                    ## reference response itself -- nesting assumed, confirm.
                    if scalenum != ref_response_scalenum:
                        color4scale = scale_colors[scalenum-1]
                        marker4scale = scale_markers[scalenum-1]
                        scatter(ref_response,response,s=marker_size,marker=marker4scale,\
                            color=color4scale,edgecolor=color4scale)
                        ylist = slope*xlist+intercept
                        plot(xlist,ylist,color=color4scale,linewidth=linewidth)
                xmin,xmax,ymin,ymax = beautify_plot(ax,x0min=False,y0min=False,\
                    xticksposn='bottom',yticksposn='left')
                ax.set_xticks([xmin,0,xmax])
                ax.set_yticks([ymin,0,ymax])
                axes_labels(ax,'Hz','',adjustpos=False,xpad=1)

                ## --- panel 3: response scaling vs conc scaling ---
                ax = plt.subplot2grid((1,3),(0,2),rowspan=1)
                plot(range(6),range(6),color=(0,0,0.7,0.5),dashes=(2.0,1.0)) ## linear reference
                concratios,slopevsconc,_,_,_,_,se_slope,avg_frates = zip(*fits)
                print("Average firing rates for mitral %s for different scales is %s" \
                    % (fitted_mitral,avg_frates))
                ## fitted slopes are relative to the reference response, hence *ref_scale;
                ## a (0,0) point is prepended
                errorbar(append([0],concratios),y=append([0],array(slopevsconc))*ref_scale,\
                    yerr=append([0],array(se_slope))*ref_scale,\
                    color='b',linewidth=linewidth,marker='o',ms=marker_size,dashes=(2.0,1.0))
                ax_twin = ax.twinx()
                ax_twin.plot(scaledList,append([0],peak_scales),color='k',linewidth=linewidth,\
                    marker='o',ms=marker_size,dashes=(0.5,1.0))
                beautify_plot(ax,x0min=False,y0min=False,xticksposn='bottom',yticksposn='left')
                ## Draw the twin y axis (turned off always by beautify_plot)
                for loc, spine in ax.spines.items(): # items() returns [(key,value),...]
                    spine.set_linewidth(axes_linewidth)
                    if loc in ['right']:
                        spine.set_color('k') # draw spine in black
                ax.set_ylim(0,5)
                ax.set_xticks([0,1,5])
                ax.set_yticks([0,1,5])
                ax_twin.set_ylim(0,80)
                ax_twin.set_yticks([0,80])
                axes_labels(ax,'ORN scaling','mitral scaling',adjustpos=False,xpad=1,ypad=0)
                axes_labels(ax_twin,'','mitral peak',adjustpos=False,ypad=-1)
                fig.tight_layout()
                fig_clip_off(fig)
                fig.subplots_adjust(top=0.94,left=0.1,right=0.91,hspace=0.4,wspace=0.5)
                if _savefig:
                    figbase = '../figures/scalelinearity_example_'+str(self.stimseed)+\
                        '_mit'+str(mit_i)
                    fig.savefig(figbase+'.svg',dpi=fig.dpi)
                    fig.savefig(figbase+'.png',dpi=fig.dpi)

                ## NOTE(review): these diagnostics are assumed to belong inside the
                ## plotting branch above (they only create figures) -- confirm.
                if NOISE_ANALYSIS:
                    ## plot the variance vs firing rate mean for each mitral
                    ## variance = mean/bintime of firing rate for Poisson process
                    fig2 = figure()
                    ax2 = fig2.add_subplot(111)
                    for scaleiter in range(len(scaledList)): ## conc scaled pulses including air
                        ax2.scatter( firingbinsmeanList[scaleiter], \
                            firingbinserrList[scaleiter]**2, color='r' )
                    beautify_plot(ax2)
                    axes_labels(ax2,'mean rate (Hz)','variance (Hz^2)',fontsize=14)
                    ## plot individual trials for a given response;
                    ## per-trial spike lists are only loaded when not in test mode
                    if not _test:
                        fig3 = figure()
                        ax3 = fig3.add_subplot(111)
                        for trialspikelist in mitral_responses_list:
                            plot( plotBins( trialspikelist[0][fitted_mitral],\
                                pulserebins, SCALED_RUNTIME, 0.0) )

        return fits_2mits, peak_scales_2mits
if __name__ == "__main__":
    ## Command line: <pickled scaled-pulse responses file> <stimseed> [TEST]
    ## Figures are shown and saved when run as a script.
    NOSHOW = False
    SAVEFIG = True
    if len(sys.argv) > 2:
        filename = sys.argv[1]
        stimseed = sys.argv[2]
        worker = fit_plot_scaledpulses(filename,stimseed)
        ## directory extension = whatever follows 'odor_pulses' up to the next '/';
        ## a path without 'odor_pulses' gives an empty extension instead of an IndexError
        path_pieces = filename.split('odor_pulses')
        if len(path_pieces) > 1: dirextn = path_pieces[1].split('/')[0]
        else: dirextn = ''
        ## single-argument print: same text under Python 2 and 3
        print('directory extension is %s' % (dirextn,))
        ## the literal argument TEST anywhere on the command line selects test mode
        TEST = 'TEST' in sys.argv
        worker.fit_pulses(dirextn,NOSHOW,SAVEFIG,TEST)
        show()
    else:
        print("At least specify data file containing pickled mitral responses, and ORN frate seed.")
        sys.exit(1)