import numpy as np
import elephant
#must import elephant prior to matplotlib else plt.ion() doesn't work properly
import ISI_anal
import ep_plot_utils as pu
exec(open('/home/avrama/ephys_anal/fft_utils.py').read())
from matplotlib import pyplot as plt
plt.ion()
colors=['r','k','b']
####################################
# Parameters of set of files to analyze
neurtype='ep'
plasYN=1
presyn=['str','GPe']
numbins=10
networksim=1
if networksim:
condition=['POST-NoDaosc', 'POST-HFSosc', 'GABAosc']
condition=['GABA','GABAosc']
inj='0.0'
presyn_set=[(0,'non',1)]#,(0,'non',0)]#(20,'str'),(40,'GPe'),
filedir='ep_net/output/'
else:
stim_freqs=[5,10,20,40]
condition=['-1e-11']#'0.0',
presyn_set=[(freq,syn) for freq in stim_freqs for syn in presyn]
rootname='ep_syn'
filedir='ep/output/'
show_plots=1
############################################################
####### plots for single neuron simulations, multiple frequencies, single trials:
if not networksim:
for inj in condition:
#specify file name pattern
suffix='_plas'+str(plasYN)+'_inj'+inj+'*.npz'
fileroot=filedir+rootname
all_results,all_xvals=pu.plot_freq_dep_psp(fileroot,presyn,suffix,neurtype)
pu.plot_freq_dep_vm(fileroot,presyn,plasYN,inj,neurtype)
else:
#Network simulations
#time points for spike triggered average
mean_prespike_sta1={};mean_prespike_sta2={}
mean_sta_vm={};vmdat={};sta_list={}
spiketime_dict={};syntt_info={}
lat_mean={};lat_std={}
isi_mean={};isi_std={}
isi_set={};all_isi_mean={}
fft_wave={};phase={};freqs={};mean_fft_phase={}
for cond in condition:
sta_start=-40e-3
sta_end=0
#specify file name pattern
rootname='ep'+cond+'_syn'
#suffix='_plas'+str(plasYN)+'_inj'+inj+'*.npz'
suffix='_inj'+inj+'*.npz'
fileroot=filedir+rootname
##### 1st set of analyses ignores the input spikes; most analyses,except for sta, assume multiple trials
mean_sta_vm[cond]={}
fft_wave[cond]={};phase[cond]={};mean_fft_phase[cond]={}
for (freq,syn,plasYN) in presyn_set:
#key=syn+'_'+'freq'+str(freq)
key=syn+'_'+'freq'+str(freq)+'_plas'+str(plasYN)
pattern=fileroot+key+suffix
files=ISI_anal.file_set(pattern)
if len(files):
spiketime_dict[key],syntt_info[key]=ISI_anal.get_spiketimes(files,neurtype)
if freq>0:
#latency not defined if no regular stimulation
#isi in these two functions is separated into pre and post stimulation - requires regulator stimulation
lat_mean[key],lat_std[key],isi_mean[key],isi_std[key], bins=ISI_anal.latency(files,freq,neurtype,numbins)
isi_set[key]=ISI_anal.ISI_histogram(files,freq,neurtype)
if sta_start != sta_end:
#ep spike triggered average of vm before the spike (the standard sta)
sta_list[key],pre_xvals,plotdt,vmdat[key]=ISI_anal.sta_set(files,spiketime_dict[key],neurtype,sta_start,sta_end)
mean_sta_vm[cond][key]=np.mean(sta_list[key],axis=0)
time_wave=np.linspace(0,plotdt*len(vmdat[key][0][0]),len(vmdat[key][0][0]),endpoint=False)
fft_wave[cond][key],phase[cond][key],freqs[cond],mean_vm,mean_fft_phase[cond][key]=ISI_anal.fft_func(vmdat[key],time_wave,init_time=1.0,endtime=18.0)
all_isi_mean[cond]=isi_mean
#
#####1st set of graphs
if show_plots:
pu.plot_postsyn_raster(rootname,suffix,spiketime_dict,syntt_info)
if len(lat_mean):
pu.plot_latency(rootname,lat_mean,lat_std,suffix)
#latency not too meaningful if spikes occur only every few IPSPs, e.g. with 40 Hz stimulation
pu.plot_ISI(rootname,isi_mean,isi_std,bins,suffix)
#ISI histogram
pu.plot_isi_hist(rootname,isi_set,numbins,suffix)
if sta_start != sta_end:
#ep spike triggered average of vm before the spike (the standard sta)
pu.plot_sta_vm(pre_xvals,sta_list,fileroot,suffix)
fft_plot(time_wave,vmdat[key],freqs[cond],fft_wave[cond][key],phase=phase[cond][key],title=cond+key)#,mean_fft=mean_fft_phase[cond])
#
################## additional spike triggered averages and raster plot of input spike times,
######## This next set of analyses requires the input spikes
#2. spike triggered Vm after an input spike
#uses different filenames and different sta start and end
sta_start=0e-3
sta_end=20e-3
pre_spikes={}
post_sta={}
mean_sta={}
fileroot=filedir+'tt'+rootname
suffix=suffix.split('npz')[0]+'npy'
for (freq,syn,stp) in presyn_set:
#pattern=fileroot+syn+'_freq'+str(freq)+suffix
pattern=fileroot+syn+'_freq'+str(freq)+'_plas'+str(stp)+suffix
files=ISI_anal.file_set(pattern)
print('tt files',pattern, 'num files',len(files))
#key=syn+'_'+'freq'+str(freq)
if len(files):
key=syn+'_'+'freq'+str(freq)+'_plas'+str(plasYN)
#calculate raster of pre-synaptic spikes
pre_spikes[key]=ISI_anal.input_raster(files)
# input Spike triggered average Vm after the spike
post_sta[key],mean_sta[key],post_xvals=ISI_anal.post_sta_set(pre_spikes[key],sta_start,sta_end,plotdt,vmdat[key])
if show_plots:
pu.plot_sta_post_vm(pre_spikes,post_sta,mean_sta,post_xvals)
for key in pre_spikes:
pu.plot_input_raster(pre_spikes[key],pattern,maxplots=1)
#
#3. use both pre-synaptic and post-synaptic spikes for spike triggered average input:
#1st calculate instantaneous input firing frequency for each type of input
#2nd calculate sta using input fire freq instead of vmdat
#weights used to sum the different external inputs - values are weights from param_net
weights={'gabaextern2':-2,'gabaextern3':-1,'ampaextern1':1}
binsize=plotdt*10#*100
sta_start=-20e-3
sta_end=0
inst_rate1={}; inst_rate2={}
prespike_sta1={}; prespike_sta2={}
mean_pre_sta1={}; mean_pre_sta2={}
'''
#
for synfreq in pre_spikes:
inst_rate1[synfreq],inst_rate2[synfreq],xbins=ISI_anal.input_fire_freq(pre_spikes[synfreq],binsize)
prespike_sta1[synfreq],mean_pre_sta1[synfreq],bins1=ISI_anal.sta_fire_freq(inst_rate1[synfreq],spiketime_dict[synfreq],sta_start,sta_end,weights,xbins)
prespike_sta2[synfreq],mean_pre_sta2[synfreq],bins2=ISI_anal.sta_fire_freq(inst_rate2[synfreq],spiketime_dict[synfreq],sta_start,sta_end,weights,xbins)
mean_prespike_sta1[cond]=mean_pre_sta1
mean_prespike_sta2[cond]=mean_pre_sta2
######## second set of graphs
if show_plots:
for synfreq in inst_rate1:
pu.plot_inst_firing(inst_rate1[synfreq],xbins,title=cond+synfreq+' smoothed')
#pu.plot_inst_firing(inst_rate2[synfreq],xbins,title=cond+synfreq)
pu.plot_prespike_sta(prespike_sta1[synfreq],mean_pre_sta1[synfreq],bins1,title=cond+synfreq+' smoothed')
pu.plot_prespike_sta(prespike_sta2[synfreq],mean_pre_sta2[synfreq],bins2,title=cond+synfreq)
'''
#
##### Plots of means compared across conditions or across presyn_set
if len(condition)>1:
pu.plot_sta_vm_cond(pre_xvals,sta_list,mean_sta_vm)
#pu.plot_fft_cond(freqs,fft_mean,fft_wave)
fig,axes=plt.subplots(1,1)
fig.suptitle('Mean fft')
for i,(cond,fft_set) in enumerate(mean_fft_phase.items()):
maxval=np.max([np.max(np.abs(f['mag'][1:])) for f in mean_fft_phase[cond].values()])
maxfreq=np.min(np.where(freqs[cond]>500))
for key,fft in fft_set.items():
axes.plot(freqs[cond][0:maxfreq], np.abs(fft['mag'])[0:maxfreq], label=cond+' '+key+' mean',color=colors[i])
mean_of_fft=np.mean([np.abs(fft) for fft in fft_wave[cond][key]],axis=0)
axes.plot(freqs[cond][0:maxfreq], mean_of_fft[0:maxfreq],'--',label='mean of '+cond+' '+key,color=colors[i])
axes.set_xlabel('Frequency in Hertz [Hz]')
axes.set_ylabel('FFT Magnitude')
axes.set_xlim(0 , freqs[cond][maxfreq] )
axes.set_ylim(0,np.round(maxval) )
axes.legend()
#
if len(lat_mean):
pu.plot_ISI_cond(all_isi_mean,bins)
#
if len(inst_rate1):
pu.plot_prespike_sta_cond(mean_prespike_sta1,bins1)
#
else:
if len(presyn_set)>1:
cond=condition[0]
plt.figure()
plt.suptitle('mean sta vm')
for i,(synfreq,mean_sta) in enumerate(mean_sta_vm[cond].items()):
plt.plot(pre_xvals,mean_sta,label=synfreq)
plt.legend()
plt.xlabel('time (s)')
plt.ylabel('Vm (V)')
#
fig,axes=plt.subplots(1,1)
fig.suptitle('Mean fft')
maxval=np.max([np.max(np.abs(f['mag'][1:])) for f in mean_fft_phase[cond].values()])
maxfreq=np.min(np.where(freqs[cond]>500))
for i,(key,fft) in enumerate(mean_fft_phase[cond].items()):
axes.plot(freqs[cond][0:maxfreq], np.abs(fft['mag'])[0:maxfreq], label=cond+' '+key+' mean',color=colors[i])
axes.set_xlabel('Frequency in Hertz [Hz]')
axes.set_ylabel('FFT Magnitude')
axes.set_xlim(0 , freqs[cond][maxfreq] )
axes.set_ylim(0,np.round(maxval) )
axes.legend()