# -*- coding: utf-8 -*-
########## THIS FITTING PROGRAM IS MEANT TO FIT sinusoids to 'mitral responses to sinusoids'!
## USAGE: python2.6 <this_script>.py
## (no command-line arguments are read; the result pickles to fit are listed in filelist below)
from scipy import optimize
from scipy.special import * # has error function erf() and inverse erfinv()
from pylab import *
import pickle
import sys
import math
sys.path.extend(["..","../networks","../generators","../simulations"])
from stimuliConstants import * # has SETTLETIME, inputList and pulseList, GLOMS_ODOR, GLOMS_NIL
from networkConstants import * # has central_glom
from sim_utils import * # has rebin() to alter binsize
from analysis_utils import * # has read_morphfile() and NUM_REBINS, etc.
## counts chisqfunc() evaluations; used only for its periodic progress print
iterationnum = 0
## amplitude[sinnum], phase[sinnum] and DC offset are the params
NUMPARAMS = 2*num_sins+1
## I don't use the NUMBINS in simset_odor.py, rather I rebin()
bindt = 5e-3 #s
## Number of bindt-wide bins in SIN_RUNTIME (overrides any NUM_REBINS pulled in
## by the star imports above). The tiny epsilon stops int() from truncating a
## ratio that is integral in exact arithmetic but evaluates to e.g.
## 59.99999999999999 in floating point.
NUM_REBINS = int(SIN_RUNTIME/bindt + 1e-9)
### numbers of mitral to be fitted.
fitted_mitral_list = [2*central_glom+0, 2*central_glom+1]
## True: fit the sims listed in the first branch below (per the original note:
## const rate at central glom and sinusoids at lateral glom);
## False: fit the NOLAT sims listed in the else-branch.
FIT_LAT_SINS = True
## Each filelist entry is (ORN rate in Hz used to normalize the gain plot, pickle path).
if FIT_LAT_SINS:
    filelist = [
        #(5.0,'../results/odor_sins/2012_02_14_15_36_sins_SINGLES_JOINTS_NOPGS_numgloms2.pickle') # 20 to 60 Hz
        #(5.0,'../results/odor_sins/2012_02_14_17_36_sins_SINGLES_JOINTS_NOPGS_numgloms2.pickle') # 1 to 15 Hz # 5 Hz central
        #(5.0,'../results/odor_sins/2012_02_14_19_55_sins_SINGLES_JOINTS_NOPGS_numgloms2.pickle') # 1 to 15 Hz # 9 Hz central
        #(5.0,'../results/odor_sins/2012_02_15_08_20_sins_SINGLES_JOINTS_NOPGS_numgloms2.pickle') # 1 to 15 Hz # 3 Hz central
        (3.0,'../results/odor_sins/2012_02_15_21_54_sins_SINGLES_JOINTS_NOPGS_numgloms2.pickle') # 1 to 15 Hz # 3 Hz central
    ]
else:
    ## 30 trials, only mitrals
    #filelist = [
    #(1.0,'../results/odor_sins/2012_02_02_15_32_sins_NOSINGLES_NOJOINTS_NOPGS_NOLAT_numgloms2.pickle'),
    #(2.0,'../results/odor_sins/2012_02_02_17_10_sins_NOSINGLES_NOJOINTS_NOPGS_NOLAT_numgloms2.pickle'),
    #(3.0,'../results/odor_sins/2012_02_02_19_18_sins_NOSINGLES_NOJOINTS_NOPGS_NOLAT_numgloms2.pickle')
    #]
    ## 40 trials, higher frequencies, only mitrals
    ## (this used to be assigned and then immediately overwritten below; kept as an alternative)
    #filelist = [
    #(1.0,'../results/odor_sins/2012_02_04_13_47_sins_NOSINGLES_NOJOINTS_NOPGS_NOLAT_numgloms2.pickle')
    #]
    ## 40 trials, higher frequencies, only mitrals, 10x longer time (currently selected)
    filelist = [
        (1.0,'../results/odor_sins/2012_02_04_17_51_sins_NOSINGLES_NOJOINTS_NOPGS_NOLAT_numgloms2.pickle')
    ]
    ## 30 trials, mitrals + spines + singles + PGs
    #filelist = [
    #(1.0,'../results/odor_sins/2012_02_02_16_17_sins_SINGLES_NOJOINTS_PGS_NOLAT_numgloms2.pickle'),
    #(2.0,'../results/odor_sins/2012_02_02_17_54_sins_SINGLES_NOJOINTS_PGS_NOLAT_numgloms2.pickle'),
    #(3.0,'../results/odor_sins/2012_02_02_20_08_sins_SINGLES_NOJOINTS_PGS_NOLAT_numgloms2.pickle')
    #]
    ## 40 trials, higher frequencies, mitrals + spines + singles + PGs:
    #filelist = [
    #(1.0,'../results/odor_sins/2012_02_04_14_00_sins_SINGLES_NOJOINTS_PGS_NOLAT_numgloms2.pickle')
    #]
def chisqfunc(params, mitnum, ydata, errdata):
    """Residual function for optimize.leastsq: one rectified sinusoid per frequency.

    params  : [ampl_0 .. ampl_{n-1}, phase_0 .. phase_{n-1}, DC] with
              n = num_sins. The amplitude enters as its absolute value
              (the sign is absorbed by the phase) and the phase is taken
              modulo 2*pi.
    mitnum  : index of the fitted mitral. Unused here; kept so the
              args=(...) tuple passed by fit_sins() still matches.
    ydata   : ydata[sinnum][binnum], mean binned firing rate per sinusoid.
    errdata : same layout as ydata, the per-bin error (divisor, so it must
              be non-zero; fit_sins() puts a floor under it).

    Returns a 1-D array of (data - model)/error over all fitted bins,
    divided by sqrt(dof), so its sum of squares is a reduced chi-square.
    """
    global iterationnum
    if iterationnum % 100 == 0:
        print('iteration number = %d' % iterationnum)
    ampl = params[0:num_sins]
    phase = params[num_sins:2*num_sins]
    DC = params[-1]
    residuals = []
    for sinnum, f in enumerate(sine_frequencies):
        ## Leave the first cycle of this frequency out as a settling transient
        startcyclenum = 1
        startbin = int(startcyclenum/float(f)/bindt)
        binnums = arange(startbin, NUM_REBINS)
        omegabindt = 2*pi*f*bindt
        ## ampl must be positive, sign appears via phase; phase modulo 2pi
        Rmodel = DC + absolute(ampl[sinnum]) * \
            sin(omegabindt*binnums + (phase[sinnum] % (2*pi)))
        ## firing rates cannot go negative: threshold the model at zero
        Rmodel = maximum(Rmodel, 0.0)
        data = ydata[sinnum][startbin:NUM_REBINS]
        error = errdata[sinnum][startbin:NUM_REBINS]
        ## divide by error to do chi-square fit
        residuals.append((data - Rmodel)/error)
    residuals = concatenate(residuals) if residuals else array([])
    ## Normalize 'chi' by sqrt of the degrees of freedom actually fitted:
    ## the settling bins skipped above are not data points of the fit.
    dof = len(residuals) - NUMPARAMS
    if dof < 1:
        dof = 1
    iterationnum += 1
    return residuals / sqrt(dof)
def fit_sins(filename, fitted_mitral):
    """Fit DC + one sinusoid per stimulus frequency to one mitral's binned response.

    filename      : path to a pickled list indexed
                    [avgnum][sinnum][mitnum][spikenum] (spike times per trial).
    fitted_mitral : mitral index selected from the binned responses.

    Returns (params, chisq, fitted_responses, firingbinsmeanList, firingbinserrList):
      params             - leastsq solution [ampl..., phase..., DC]
      chisq              - sum of squares of chisqfunc() residuals at params
      fitted_responses   - per-sinusoid list of NUM_REBINS model rates,
                           rectified at zero like the fitted model
      firingbinsmeanList - mean binned rate [sinnum][binnum]
      firingbinserrList  - standard error [sinnum][binnum], floored (see below)
    """
    ## NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
    with open(filename, 'rb') as f:
        mitral_responses_list = pickle.load(f)
    ## mitral_responses_list[avgnum][sinnum][mitnum][spikenum]
    mitral_responses_binned_list = \
        rebin_pulses(mitral_responses_list, NUM_REBINS, SIN_RUNTIME, 0.0)
    numavgs = len(mitral_responses_list)
    mitral_responses_mean = mean(mitral_responses_binned_list, axis=0)
    mitral_responses_std = std(mitral_responses_binned_list, axis=0)
    ## take only the responses of the mitral to be fitted
    firingbinsmeanList = mitral_responses_mean[:, fitted_mitral, :]
    firingbinserrList = mitral_responses_std[:, fitted_mitral, :]/sqrt(numavgs)
    ## amplitude of sine wave, phase shift and DC offset
    params0 = [0.0]*num_sins + [0.0]*num_sins + [0.0]
    ## Put in a minimum error, else divide-by-zero / NaN params and fits:
    ## use the smallest error that is > errcut, or errcut itself if none is.
    largeerrors = firingbinserrList[where(firingbinserrList > errcut)]
    if largeerrors.size > 0:
        errmin = largeerrors.min()
    else:
        errmin = errcut
    ## replace every error <= errcut by errmin
    firingbinserrList = where(firingbinserrList > errcut, firingbinserrList, errmin)
    ###################################### Fitting
    fitoutput = optimize.leastsq(chisqfunc, params0,
        args=(fitted_mitral, firingbinsmeanList, firingbinserrList),
        full_output=1, maxfev=10000)
    print(fitoutput[3])  # leastsq's status message
    params = fitoutput[0]  # leastsq returns a whole tuple of stuff - errmsg etc.
    print("ampl[sinnum]+phase[sinnum]+DC = %s" % (params,))
    ## Calculate sum of squares of the residual array
    residuals = chisqfunc(params, fitted_mitral, firingbinsmeanList, firingbinserrList)
    chisq = (residuals**2).sum()
    ############################## Calculate fitted responses and return them
    DC_fit = params[-1]
    ampl_fit = absolute(params[0:num_sins])
    phase_fit = params[num_sins:2*num_sins] % (2*pi)
    ## Evaluate on exactly the NUM_REBINS bin times used by the fit
    ## (arange(0, SIN_RUNTIME, bindt) can give an extra point through float
    ## rounding), and threshold at zero like the model in chisqfunc().
    bintimes = arange(NUM_REBINS)*bindt
    fitted_responses = [
        maximum(DC_fit + ampl_fit[sinnum]*sin(2*pi*f*bintimes + phase_fit[sinnum]),
                0.0).tolist()
        for sinnum, f in enumerate(sine_frequencies)]
    return (params, chisq, fitted_responses, firingbinsmeanList, firingbinserrList)
if __name__ == "__main__":
    ## Input pickles come from filelist above; no command-line arguments are read.
    for fitted_mitral in fitted_mitral_list:
        ## gain (|amplitude| / ORN rate) vs input frequency
        mainfig = figure(facecolor='w')
        mainax = mainfig.add_subplot(111)
        title('Mitral '+str(fitted_mitral)+' frequency response',fontsize=24)
        ## fitted phase vs input frequency
        mainfig2 = figure(facecolor='w')
        mainax2 = mainfig2.add_subplot(111)
        title('Mitral '+str(fitted_mitral)+' phase response',fontsize=24)
        paramsList = []
        for ampl, filename in filelist:
            params, chisq, fitted_responses, firingbinsmeanList, firingbinserrList \
                = fit_sins(filename, fitted_mitral)
            print("Mit %s normalized chisq = %s" % (fitted_mitral, chisq))
            paramsList.append((ampl, params))
            ################# Plot simulated and fitted responses
            ## Per-sinusoid detail figures only for mitral 0 (keeps the window
            ## count down); the response curves below are drawn for every mitral.
            if fitted_mitral == 0:
                for sinnum in range(num_sins):
                    fig = figure(facecolor='w')
                    ax = fig.add_subplot(3,1,2)
                    sincolor = (sinnum+1) / float(num_sins)
                    ## mean + error (lighter/whiter shade than mean below)
                    ax.plot(range(NUM_REBINS),
                        firingbinsmeanList[sinnum]+firingbinserrList[sinnum],
                        color=(0,(1-sincolor)*0.25+0.75,sincolor*0.25+0.75),
                        marker='+',linestyle='solid', linewidth=2)
                    ## mean
                    ax.plot(range(NUM_REBINS),firingbinsmeanList[sinnum],
                        color=(0,1-sincolor,sincolor),
                        marker='+',linestyle='solid', linewidth=2)
                    ## fitted
                    ax.plot(range(NUM_REBINS),fitted_responses[sinnum],
                        color=(1,1-sincolor,sincolor),
                        marker='x',linestyle='solid', linewidth=2)
                    titlestr = 'Mitral %d response & sinusoid f=%f fit'\
                        %(fitted_mitral,sine_frequencies[sinnum])
                    title(titlestr, fontsize=24)
                    axes_labels(ax,'respiratory phase bin','firing rate (Hz)',adjustpos=True)
            ################# Plot frequency and phase responses
            ## params layout: [ampl..., phase..., DC]; the phase plot must use
            ## the phase slice, not the amplitudes.
            ampl_fit = abs(params[0:num_sins])
            phase_fit = params[num_sins:2*num_sins] % (2*pi)
            mainax.plot(sine_frequencies,ampl_fit/float(ampl),label=str(ampl)+'Hz ORN')
            mainax2.plot(sine_frequencies,phase_fit/pi*180,label=str(ampl)+'Hz ORN')
        axes_labels(mainax,'input frequency (Hz)','stimulus normalized output',adjustpos=True)
        mainax.legend()
        axes_labels(mainax2,'input frequency (Hz)','output phase (degrees)',adjustpos=True)
        mainax2.legend()
    show()