"""
Avrama Blackwell
2021 November 14
Subset of analyses and slight modification from net_anal
no regular stimulation
don't calculate sta
don't plot input rasters
calculate PSD from multiple neuron types
"""
import numpy as np
import glob
import utils as u
import net_anal_utils as nau #must be imported prior to na and isc, which import from it
import plot_utils as pu
filedir='/home/avrama/moose/moose_nerp/moose_nerp/networks/'
param1=['matrix1s','striosome1s','strios2.5_1.5'] #simulation conditions, used as file name prefixes
labels={'matrix1s': 'matrix','striosome1s':'striosome','strios2.5_1.5':'striosome\nstrong dSPN->iSPN'} #plot labels for each condition
suffix="_out"
dt=0.1e-3 #sampling interval of the Vm traces (sec)
simtime=1.0 #sec; updated below from the time of the last spike if needed
numtraces=0 #set to zero to avoid plotting traces
plot_fft_isi=True #set to false to avoid plotting fft and isi histogram
numbins=100 #number of bins (for histograms) if no overlap of bins
min_max=[0.0,0.6] #sec; range for the ISI histogram (alternative upper limit: 1.2)
smooth_window=2e-3 #for instantaneous firing rate
maxfreq=100 #only save PSD from FFT from frequencies up to this number
init_time=0.200 #sec; initial transient excluded from firing rate and FFT analysis
shuffle_correct=False #set to false to calculate uncorrected FFT
num_shuffles=5
alldata={m:{} for m in ['spike_time','spike_rate','spike_rate_t','isis','histbins','fft','freqs']}
mean_spikes={m:{} for m in param1};ste_spikes={m:{} for m in param1}
norm_mean={}
def plot_traces(vmdat,dt,numtraces,ftitle=''):
    from matplotlib import pyplot as plt
    fig,ax=plt.subplots(len(vmdat),1,sharex=True)
    fig.suptitle(ftitle)
    for ii,ntype in enumerate(vmdat.keys()):
        time=np.linspace(0,dt*len(vmdat[ntype][0]),len(vmdat[ntype][0]),endpoint=False)
        for tracenum in range(numtraces):
            ax[ii].plot(time,vmdat[ntype][tracenum])
        ax[ii].set_ylabel(ntype+' Vm (mV)')
    ax[ii].set_xlabel('Time (sec)')
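# Illustrative sketch only (not called anywhere): the analysis below relies on
# u.spiketime_from_vm to extract spike times from the Vm traces.  Assuming it
# does simple threshold-crossing detection, a minimal stand-in could look like
# this; the function name and the threshold value are assumptions, not the
# project's actual implementation.
def _example_spiketime_from_vm(vmdat,dt,thresh=0.0):
    spike_time={}
    for ntype,traces in vmdat.items():
        spike_time[ntype]=[]
        for vm in traces:
            vm=np.asarray(vm)
            #sample indices where Vm crosses thresh from below
            crossings=np.where((vm[:-1]<thresh)&(vm[1:]>=thresh))[0]
            spike_time[ntype].append(crossings*dt)
    return spike_time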
def bar_plots(spike_rate_mean,spike_rate_std,labels):
    from matplotlib import pyplot as plt
    plt.figure()
    x=np.arange(2) #one group of bars per neuron type (D1, D2)
    for ii,cond in enumerate(spike_rate_mean.keys()):
        plt.bar(x+ii/4.0,list(spike_rate_mean[cond].values()),yerr=list(spike_rate_std[cond].values()),
                ecolor='black',width=0.25,label=labels[cond])
    plt.xlabel('Neuron Type')
    plt.legend(fontsize=10, loc='upper left')
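# Illustrative sketch only (not called anywhere): u.spikerate_func is a project
# utility, so its internals are an assumption here.  The code below treats its
# output as an instantaneous firing rate sampled every smooth_window seconds
# and averaged over neurons; a minimal numpy stand-in for one neuron type:
def _example_spikerate(spike_times,simtime,smooth_window):
    edges=np.arange(0,simtime+smooth_window,smooth_window)
    #per-neuron spike counts per bin, converted to Hz
    rates=[np.histogram(st,bins=edges)[0]/smooth_window for st in spike_times]
    return np.mean(rates,axis=0) #population-averaged rate versus time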
for cond in param1:
    accum_key=cond
    pattern=filedir+cond+suffix
    files=sorted(glob.glob(pattern+'*.npz'))
    if len(files)==0:
        print('************ no files found using',pattern)
    elif len(files)>1:
        print(' ####### Too many files, refine pattern')
    else:
        for f in files:
            print('************ analyzing', f)
            dat=np.load(f,'r',allow_pickle=True)
            vmdat=dat['vm'].item() #dict of Vm traces, keyed by neuron type
            spike_time=u.spiketime_from_vm(vmdat,dt)
            maxtime=np.max([np.max(u.flatten(st)) for st in spike_time.values()]) #time of last spike
            simtime = round(max(simtime, maxtime),1) #update simtime if needed
            spike_rate,spike_rate_mean,spike_rate_tvals=u.spikerate_func(spike_time,simtime,smooth_window)
            isi,isi_stats,isi_hist,histbins = u.isi(spike_time,numbins,min_max)
            mean_fft,std_fft,freqs=u.fft_func(spike_rate,smooth_window,init_time,maxfreq)
            #PSD of shuffled spike trains, used to correct the PSD if shuffle_correct is True
            shuffled_spikes=u.shuffle(spike_time,num_shuffles)
            shuffle_rate,_,_=u.spikerate_func(shuffled_spikes,simtime,smooth_window)
            shuffled_fft,shuffled_std,shuffle_freqs=u.fft_func(shuffle_rate,smooth_window,init_time,maxfreq)
            correct_fft={ntype: mean_fft[ntype]-shuffled_fft[ntype] for ntype in mean_fft.keys()}
            #accumulate results for the across-condition plots below
            alldata['spike_time'][cond]=spike_time
            alldata['spike_rate'][cond]=spike_rate_mean
            alldata['spike_rate_t'][cond]=spike_rate_tvals
            alldata['isis'][cond]=isi_hist
            alldata['histbins'][cond]=histbins
            if shuffle_correct:
                alldata['fft'][cond]= correct_fft
            else:
                alldata['fft'][cond]= mean_fft
            alldata['freqs'][cond]=freqs
            #mean and standard error of firing rate per neuron type, excluding the initial transient
            for ntype in spike_time.keys():
                time=np.linspace(0,smooth_window*len(spike_rate_mean[ntype]),len(spike_rate_mean[ntype]),endpoint=False)
                start_index=np.min(np.where(time>init_time))
                mean_spikes[cond][ntype]=np.mean(spike_rate_mean[ntype][start_index:])
                ste_spikes[cond][ntype]=np.std(spike_rate_mean[ntype][start_index:])/np.sqrt(len(spike_rate[ntype]))
            norm_mean[cond]=np.mean(mean_spikes[cond]['D1']/mean_spikes[cond]['D2'])
            ###################################################
            ######## Single parameter set plots
            ###################################################
            pu.plot_raster(spike_time,simtime,ftitle=cond)
            print('******** ISI stats for', cond,isi_stats)
            if numtraces>0:
                plot_traces(vmdat,dt,numtraces,ftitle=cond)
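# Illustrative sketch only (not called anywhere): u.shuffle and u.fft_func are
# project utilities, so their internals are assumptions.  The correction above
# subtracts the PSD of shuffled spike trains from the PSD of the real trains,
# leaving power not explained by the firing rate alone.  A minimal numpy
# version of that idea for one rate trace sampled every dt_rate seconds:
def _example_shuffle_corrected_psd(rate,shuffled_rates,dt_rate,maxfreq):
    freqs=np.fft.rfftfreq(len(rate),d=dt_rate)
    keep=freqs<=maxfreq #only keep frequencies up to maxfreq, as above
    psd=np.abs(np.fft.rfft(rate-np.mean(rate)))**2
    shuffled_psd=np.mean([np.abs(np.fft.rfft(r-np.mean(r)))**2 for r in shuffled_rates],axis=0)
    return (psd-shuffled_psd)[keep],freqs[keep]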
###################################################
# plots comparing data across cond
###################################################
from matplotlib import pyplot as plt
plt.figure()
plt.bar(list(labels.values()),list(norm_mean.values()), ecolor='black') #ratio of D1 to D2 mean firing rate per condition
pu.plot_dict_of_dicts(alldata['spike_rate'],alldata['spike_rate_t'],ylabel='rate (Hz)')
bar_plots(mean_spikes,ste_spikes,labels)
if plot_fft_isi:
    pu.plot_dict_of_dicts(alldata['fft'],alldata['freqs'],ylabel='PSD',xlabel='Freq (Hz)')
    hist_xvals= {p:[(bins[i]+bins[i+1])/2 for i in range(len(bins)-1)] for p,bins in alldata['histbins'].items()}
    pu.plot_dict_of_dicts(alldata['isis'],hist_xvals,xlabel='ISI (sec)', ylabel='number')
for cond in param1:
    confile_name=filedir+cond+'*_connect.npz'
    nau.print_con(confile_name)
'''
np.save(outfname,alldata)
if savetxt:
nau.write_dict_of_dicts(spikerate_mean,rate_xvals,'spike_rate_'+out_fname,'rate',spikerate_std)
nau.write_dict_of_dicts(spikerate_elphmean,elph_xvals,'elph_spike_rate_'+out_fname,'Erate',spikerate_elphstd)
nau.write_triple_dict(isihist_mean,'isi_histogram_'+out_fname,'isiN',isihist_std,xdata=hist_xvals,xheader='isi_bin') #possibly delete triple dict and loop over neur type?
nau.write_dict_of_dicts(mean_fft,alldata.freqs,'fft_'+out_fname,'fft',std_fft,xheader='freq') #this may need triple dict if do fft for multiple neur types
'''
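# Illustrative sketch only, echoing the disabled block above: alldata is a dict
# of dicts, so one way to save and reload it is with numpy pickling.  The flag
# and file name below are assumptions, not part of the original script.
save_alldata=False
if save_alldata:
    np.save('alldata.npy',alldata,allow_pickle=True) #numpy pickles the nested dict
    #reloaded=np.load('alldata.npy',allow_pickle=True).item()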