#!/usr/bin/env python
# coding: utf-8
import matplotlib
#matplotlib.use('Qt5Agg')
from matplotlib import pyplot as plt
plt.rcParams.update(plt.rcParamsDefault)
plt.ion()
import seaborn as sns
import numpy as np
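#NOTE: matplotlib >= 3.6 renamed this style to 'seaborn-v0_8-paper'; 'seaborn-paper'
#is kept here for the older matplotlib environments this script targets.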
plt.style.use(['seaborn-paper',
{'axes.spines.right':False,
'axes.spines.top': False,
'figure.constrained_layout.use': True,
'pdf.fonttype': 3, #3 = Type 3 fonts; use 42 for TrueType (editable text in PDFs)
'ps.fonttype': 3,
'savefig.dpi': 300,
'savefig.pad_inches': 0,
'figure.figsize':[8,8/1.333],#[11.32,11.32/1.3333],
}])
import sys
import glob
from quantities import s,Hz
import elephant
import neo
import pandas as pd
from scipy.stats import pearsonr
import plas_sim_anal_utils as psau #from moose_nerp.anal
import plas_sim_plots as plas_plot
#### Where to put figures when savefig=True.
figure_path='/home/avrama/moose/NSGPlas_2022jun23_ran0_220_uni/'
#alternatives: '/home/dbd/Dissertation/plasChapterFigs/' or '/home/dbd/ResearchNotebook/IBAGS2019/FiguresSrc/'
####data file directory: need to allow this to be specified with args
ddir = figure_path+'NSGPlasticity/testdata/'
#ddir='/media/mybook/ExternalData/plasticity_full_output/NSGPlasticity/testdata' #old seed (not clustered)
csvdir=figure_path+'NSGPlasticity/'
spine_soma_dist_file='/home/avrama/python/DormanAnal/spine_soma_distance.csv'
spine_to_spine_dist_file=csvdir+'D1_short_patch_8753287_D1_17_s2sdist.npz'
#### For truncated normal and moved spikes: these assignments override the paths above
ddir='/home/avrama/python/DormanAnal/'
input_dir='/home/dbd/moose_nerp_old/' #also has trunc_normal and moved outputs with longer filenames
figure_path=ddir+'figures/'
csvdir=ddir
spine_to_spine_dist_file=csvdir+'spine_to_spine_dist_D1PatchSample5.csv'
####for clustered inputs
param_file=csvdir+'testparams.pickle'
#### Control which analyses and graphs to generate
plot_input=False
plot_hist=False
regression_all=False #this takes a long time
other_plots=True
RF_use_binned_weight=False
RF_plots=False
linear_reg=False
plot_neighbor_image=False
combined_presyn_cal=False
combined_spatial=False
savefig=False
fontsize=12
warnings=5 #passed to psau utilities to control how many warning messages are printed
dW=False #generate instantaneous weight change (dW) analyses and plots
try:
    commandline = ARGS.split() #when running interactively, define a space-separated ARGS string before exec-ing this script
do_exit = False
except NameError: #otherwise, read the filename and other parameters from the command line
commandline = sys.argv[1:]
do_exit = True
sim_files=commandline[0] #'trunc_normal' #choices: 'nmda_block','moved','trunc_normal', 'seed_1'
#
if len(commandline)>1:
fstart=float(commandline[1])
fend=float(commandline[2])
else:
fstart=0
fend=1
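#Example invocations (the script filename is illustrative):
#  python plas_sim_anal.py trunc_normal        #analyze all truncated-normal files
#  python plas_sim_anal.py seed_1 0 0.5        #analyze the first half of the seed_1 files
#or interactively, define ARGS before exec-ing this script:
#  ARGS = 'moved'; exec(open('plas_sim_anal.py').read())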
#
tt_names={}
sigma={}
############################ Specify Files ############################
if sim_files=='trunc_normal':
###### 1. Truncated Normal, control using glob
files = {k:glob.glob(ddir+'plas_sim*'+fn+'*simtime_21.5*thresh_mod_1.158*')[0] for k,fn in zip(['low','medium','high','higher'],['LowV','MediumV','HighV','HigherV'])}
#files = {k:glob.glob(ddir+'*'+fn+'*')[0] for k,fn in zip(['low','medium','high'],['LowV','MediumV','HighV'])}
if plot_input:
tt_names={k:glob.glob(input_dir+'FullTrial'+fn+'*TruncatedNormal.npz')[0] for k,fn in zip(['low','medium','high','higher'],['LowV','MediumV','HighV','HigherV'])}
wt_change_title=''
elif sim_files=='nmda_block':
###### 2. Truncated Normal, nmda block using glob
files = {k:glob.glob(ddir+'*'+fn+'*nonmda*')[0] for k,fn in zip(['low','medium','high','higher'],['LowV','MediumV','HighV','HigherV'])}
## input time tables for truncated normal simulations
if plot_input:
tt_names={k:glob.glob(input_dir+'FullTrial'+fn+'*TruncatedNormal.npz')[0] for k,fn in zip(['low','medium','high','higher'],['LowV','MediumV','HighV','HigherV'])}
wt_change_title='nmda block'
elif sim_files=='moved':
###### 3. Alternative method of introducing variability
files = {int(f.split('Prob_')[-1].split('_Percent')[0].split('.')[0]):f for f in glob.glob(ddir+'*MovedSpikes*.npy')}
files = {k:f for k,f in sorted(files.items())} #sort by move probability
## input time tables for moved spikes
if plot_input:
tt_names={k:glob.glob(input_dir+'MovedSpikesToOtherTrains_Prob_'+str(k)+'_Percent.npz')[0] for k in range(10,101,10)}
#low=list(tt_names.keys())[0]
#high=list(tt_names.keys())[-1]
wt_change_title='Moved Spikes'
elif sim_files.startswith('seed'):
###### 4. single trials, with each trial having different random assignment of spike trains
#
all_files=glob.glob(ddir+'plas_simD1PatchSample5*'+sim_files+'*.npy')
if fstart>0 or fend<1:
file_start=int(fstart*len(all_files))
file_end=int(fend*len(all_files))
files=all_files[file_start:file_end]
print(sim_files,'num files', len(files))
sim_files=sim_files+str(file_start)+'_'+str(file_end)
else:
files=all_files
files = {f.split('seed_')[-1].split('.')[0]:f for f in files}
wt_change_title='Clusters'
#########################################################################
if 'low' in files.keys():
    titles = [r'$\sigma=1$ ms',r'$\sigma=10$ ms',r'$\sigma=100$ ms',r'$\sigma=200$ ms']
sigma={'low':1,'medium':10,'high':100,'higher':200}
keys=list(files.keys())
low='low'
high='higher'
elif len(files)==4:
keys=list(files.keys())
titles=['P(move)='+str(k)+'%' for k in keys]
low=keys[0]
high=keys[-1]
elif len(files)<=10:
    #[::3] selects a subset of the inputs, since weight_vs_variability only has room for 4 panels
keys=[k for k in files.keys()][::3]
titles=['P(move)='+str(k)+'%' for k in files.keys()][::3]
low=keys[0]
high=keys[-1]
else: #if many files, randomly select 4 of them
keys=list(np.random.choice(list(files.keys()),4, replace=False))
titles=['seed='+str(k) for k in keys]
    low=keys[0]
    high=keys[-1] #for consistency with the other branches
## load subset of data, determine some parameters ##
data = {k:np.load(files[k],mmap_mode='r') for k in keys} #subset of data, used for weight_vs_variability plot
datalow=data[low]
simtime=round(datalow['time'][-1])+1
dt=datalow['time'][1]
nochangedW=0.00001 #|dW| below this threshold counts as no change
params={'neighbors':20,
'bins_per_sec':100,
'samp_rate':int(1/dt),
'ITI':2, #inter-trial interval - not sure how to read from data
'dt':dt,
'length_of_firing_rate_vector':100,
'ca_downsamp':10,
'simtime':simtime,
'nochangedW':nochangedW}
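#Sanity sketch for the params above (assumes 'time' is uniformly sampled): each
#firing-rate bin spans samp_rate/bins_per_sec samples, e.g. dt=1e-4 s gives
#samp_rate=10000 and 100 samples per bin.
print('samples per firing-rate bin:',params['samp_rate']//params['bins_per_sec'])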
if plot_input and len(tt_names): ######### fcombined is Figure 1 in manuscript
############ Input spike trains ##################
tt_Ctx_SPN={k:np.load(f,allow_pickle=True) for k,f in tt_names.items()}
fraster,fcombined=plas_plot.input_plot(tt_Ctx_SPN,datalow,low,high)
if savefig:
fraster.savefig(figure_path+sim_files+'initialTrialRasterPSTH.pdf')
fcombined.savefig(figure_path+sim_files+'RasterPSTHSomaVmCombined.pdf')
#### For Fig 2 - Find spine that potentiates the most and that depresses the most
spine_weight_dict = {}
for n in datalow.dtype.names:
if n.endswith('headplas'):
        weight = datalow[n][int(1.1/dt)] #weight at t=1.1 s, the end of the first trial's stimulation
spine_weight_dict[n]=weight
spine_weights = pd.Series(spine_weight_dict).sort_values()
print('min=',spine_weights.iloc[[0]],'\nmax=',spine_weights.iloc[[-1]])
#potentiated: '/data/D1-extern1_to_228_3-sp1headplas' for random, "/data/D1-extern1_to_312_3-sp0headplas" for truncated normal
#depressed: '/data/D1-extern1_to_259_3-sp1headplas' for random, "/data/D1-extern1_to_154_3-sp0headplas" for truncated normal
pot_ex=spine_weights.index[-1]
dep_ex=spine_weights.index[0]
############## Create weight_change_event array, binned spike train array, calcium traces array for calculating weight change triggered average ########
df,weight_change_event_df,inst_rate_array,trains,binned_trains_index,ca_trace_array,ca_index,t1weight_distr,inst_weight_change=psau.weight_change_events_spikes_calcium(files,params,warnings)
trial1_stimdf=weight_change_event_df[weight_change_event_df.time==params['ITI']]
trial1_stim_distr={'mean':trial1_stimdf.groupby('trial').mean().weightchange.values,
'std':trial1_stimdf.groupby('trial').std().weightchange.values,
'files':np.unique(trial1_stimdf.trial),
'no change':np.array([np.sum(g.weightchange==0) for k,g in trial1_stimdf.groupby('trial')])}
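#trial1_stim_distr summarizes stimulated-synapse weight change after the first trial
#(events at time == ITI): per-file mean and std, plus the count of unchanged synapses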
print('wce df', weight_change_event_df.head())
######### spine to spine distance, spine to soma distance, and cluster information ######
import os
if len(param_file) and sim_files.startswith('seed'):
weight_change_event_df,df,inst_weight_change=psau.add_cluster_info(weight_change_event_df,df,inst_weight_change,param_file)
filename, file_extension = os.path.splitext(spine_to_spine_dist_file)
if file_extension == '.csv':
sp2sp = pd.read_csv(spine_to_spine_dist_file,index_col=0)
else:
sp2sp_data=np.load(spine_to_spine_dist_file,allow_pickle=True)
spine_to_spine_dist_array=sp2sp_data['s2sd']
allspines=sp2sp_data['index'].item().keys()
sp2sp = pd.DataFrame(spine_to_spine_dist_array,columns=allspines,index=allspines)
newindex = [s.replace('[0]','').replace('/','_').replace('D1','').lstrip('_') for s in sp2sp.index]
sp2sp.index = newindex
sp2sp.columns = newindex
if 'soma' in sp2sp.columns:
weight_change_event_df,df=psau.add_spine_soma_dist(weight_change_event_df,df, sp2sp.soma,warnings=warnings)
else:
weight_change_event_df,df=psau.add_spine_soma_dist(weight_change_event_df,df, spine_soma_dist_file,warnings=warnings)
df['spine'] = df['spine'].apply(lambda sp: sp.replace('ecdend','secdend')) #presumably restores the leading 's' of 'secdend' stripped during the renaming above
grouped=df.groupby('trial') ## for sim with spine clusters, need to groupby seed
sorted_other_stim_spines={}
for grp in grouped.groups.keys():
    dfx=grouped.get_group(grp)
    stim_spines = dfx.loc[dfx['stim']==True].spine.drop_duplicates()
    stimspinetospinedist = sp2sp.filter(items=stim_spines,axis=0).filter(items=stim_spines,axis=1)
    sorted_other_stim_spines[grp] = pd.concat([pd.Series(stimspinetospinedist[c].sort_values().index,name=c) for c in stimspinetospinedist.columns],axis=1)
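#sorted_other_stim_spines[grp] has one column per stimulated spine c, listing stimulated
#spines by increasing distance from c; assuming the distance matrix has zeros on the
#diagonal, row 0 of column c is c itself and row 1 is its nearest stimulated neighbor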
##################### Figure 3: histogram of synaptic weights at the end of the first trial
### Because only the first trial is used, results are the same for all variabilities
if plot_hist:
fhist=plas_plot.weight_histogram(datalow)
endweight_dist={}
else:
print('********* trial1 end weight for all synapses ************ ')
for i in range(len(t1weight_distr['mean'])):
print('file:',t1weight_distr['files'][i], 'mean weight=',round(t1weight_distr['mean'][i],3),'+/-',
round(t1weight_distr['std'][i],3),', no change:',t1weight_distr['no change'][i])
print('overall mean weight change=',np.mean(t1weight_distr['mean']))
#for k,d in data.items():
#index_for_weight = (np.abs(d['time'] - 2)).argmin() #1st trial ending weight. 2 = ITI? get weight just prior to next trial
#t1weight = [d[n][index_for_weight] for n in d.dtype.names if 'plas' in n]
#print(k,round(np.mean(t1weight),3),'+/-',round(np.std(t1weight),3),', no change:',t1weight.count(1.0))
changedf=df[(df['endweight']<.99) | (df['endweight']>1.01)]
if other_plots:
####### Unused Figure #######
#weight_change_plots(weight_change_event_df)
############# Figure 2 in manuscript - illustrates calcium based learning rule
if not sim_files.startswith('seed'):
fig=plas_plot.plot_spine_calcium_and_weight(datalow,pot_ex,dep_ex)
########## Figure 4 top panels in manuscript
f4top,f4bot=plas_plot.weight_vs_variability(data,df,titles,keys,sigma=sigma)
################ Figure 5 in manuscript - end weight vs presyn is shaped like BCM curve, only use 4 data files
f_bcm=plas_plot.endwt_plot(df,'spikecount','Total Presynaptic Spike Count',titles)
#### Figure 5-supplement in manuscript - end weight vs spine location
    if 'spinedist' in changedf.columns:
        f_spinedist=plas_plot.endwt_plot(df,'spinedist',r'Distance to Soma ($\mu$m)',titles)
        #restricting to synapses whose weight changed may be the better plot:
        f_spinedist=plas_plot.endwt_plot(changedf,'spinedist',r'Distance to Soma ($\mu$m)',titles)
print('########### Spine distance correlation ########')
if len(np.unique(changedf['trial']))<=10:
for var,group in changedf.groupby('trial'):
print(var,pearsonr(group['spinedist'],group['endweight']))
print('spinedist corr, all',pearsonr(changedf['spinedist'],changedf['endweight']))
if savefig:
if plot_hist:
fhist.savefig(figure_path+sim_files+'WeightHistAfterInitialTrial.pdf')
if not sim_files.startswith('seed'):
#fig.savefig(figure_path+sim_files+'ca_plas_example_fig.eps')
#fig.savefig(figure_path+sim_files+'CaPlasExample.pdf')
#fractional_size(f4top,[1,.75])
#f4top.savefig(figure_path+sim_files+'combined_synaptic_weight_figure.eps')
f4top.savefig(figure_path+sim_files+'CombinedSynapticWeightFullTrialVariability.pdf')
plas_plot.fractional_size(f4bot,[1,.75])
f4bot.savefig(figure_path+sim_files+'EndweightDistByLTPandLTDvsSigma.pdf')
plas_plot.fractional_size(f_bcm,[1,1])
f_bcm.savefig(figure_path+sim_files+'EndingWeightVsPresynSpikeCount.pdf')
plas_plot.fractional_size(f_spinedist,[1,1])
f_spinedist.savefig(figure_path+sim_files+'EndingWeightVsSpine2somaDistance.pdf')
if 'cluster_length' in changedf.columns:
f_clusterL=plas_plot.endwt_plot(df,'cluster_length','cluster length', titles)
        f_clusterS=plas_plot.endwt_plot(df,'spines_per_cluster','spines per cluster', titles)
if savefig:
f_clusterL.savefig(figure_path+sim_files+'EndWeight_vs_clust_length.pdf')
f_clusterS.savefig(figure_path+sim_files+'EndWeight_vs_spines_cluster.pdf')
if 'cluster_length' in changedf.columns:
print('cluster length corr, changedf:',pearsonr(changedf['cluster_length'],changedf['endweight']),
'all synapses',pearsonr(df['cluster_length'],df['endweight']))
print('spines_per_cluster corr, changedf:',pearsonr(changedf['spines_per_cluster'],changedf['endweight']),
'all synapses',pearsonr(df['spines_per_cluster'],df['endweight']))
###################################################################################################################
###################### Weight change triggered average of pre-synaptic firing and of calcium dynamics #################
## Generate the 2d array of binned spike times and synaptic distance for each synapse around each weight change event
## Run this to create arrays of instantaneous firing rate, to be averaged over synapses with similar weight change
## Also, combine neighboring spike trains into one to examine the neighboring firing rate
params['tstart']=1.9 #input spikes begin 1.9 sec prior to weight change time. I.e., trials are 0.1-1.1 sec for the 2 sec weight time
params['tend']=0.9 #not used
params['duration']=1.0#tstart-tend #evaluate 1 sec of firing
if 'cluster_length' in changedf.columns:
params['max_dist']=np.max(weight_change_event_df['cluster_length']) #~200 microns
else:
params['max_dist']=np.max(np.max(sp2sp))
params['dist_thresh']=50e-6 #max_dist ~200 microns, try 50 or 100 microns
print('>>>>>>>>>> Calculating weight change triggered calcium and pre-synaptic firing rate ')
weight_change_alligned_array,weight_change_weighted_array,combined_neighbors_array,weight_change_alligned_ca=psau.weight_changed_aligned_array(params,weight_change_event_df,ca_index,sorted_other_stim_spines,binned_trains_index,ca_trace_array,sp2sp,inst_rate_array,trains)
############# create bins of weight change
## option 1 - 9 evenly spaced bin edges (superseded by option 2 below)
numbins=9
nochange=0.01
binned_weight_change_index = np.linspace(weight_change_event_df.weightchange.min(),weight_change_event_df.weightchange.max(),numbins)
## option 2 - 7 groups: Strong LTD, Moderate LTD, Weak LTD, No Change, Weak LTP, Moderate LTP, Strong LTP
binned_weight_change_dict,binned_weight_change_index=psau.binned_weight_change(weight_change_event_df,numbins,'weightchange',nochange)
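#Downstream loops assume binned_weight_change_dict maps a (bin_low,bin_high) tuple to
#the integer indices of the weight change events in that bin, e.g.:
#  for (lo,hi),idx in binned_weight_change_dict.items():
#      mean_ca = np.nanmean(weight_change_alligned_ca[:,idx],axis=1)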
#### calculate mean value within each bin
binned_means = {};binned_std={}
binned_weighted_means = {}
binned_calcium={};binned_calcium_std={}
binned_max={}
for k,v in binned_weight_change_dict.items():
binned_means[k],binned_std[k],_ = psau.mean_std(weight_change_alligned_array,v)
binned_calcium[k]=np.nanmean(weight_change_alligned_ca[:,v],axis=1)
binned_calcium_std[k]=np.nanstd(weight_change_alligned_ca[:,v],axis=1)
binned_weighted_means[k]=np.nanmean(weight_change_weighted_array[:,:,v],axis=2)
print('wt bin',k,len(v))
#Binned means for instantaneous weight change
dW_aligned_array,inst_weight_change=psau.calc_dW_aligned_array(inst_weight_change,sorted_other_stim_spines,params,binned_trains_index,inst_rate_array,duration=0.050)
variables=['post_spike_dt','pre_spike_dt','isi','pre_rate','pre_spike2_dt','isi2','pre_interval']
binned_dW_dict,binned_dW_index=psau.binned_weight_change(inst_weight_change,numbins,'dW',nochangedW)
binned_spiketime,binned_spiketime_std=psau.bin_spiketime(binned_dW_dict,inst_weight_change,variables)
print('dW weight change bins')
binned_dW_means={}
binned_dW_std={}
for k,v in binned_dW_dict.items():
print(k,len(v))
binned_dW_means[k],binned_dW_std[k],_=psau.mean_std(dW_aligned_array,v)
if dW:
cs = plt.cm.coolwarm(np.linspace(0,1,len(binned_dW_means)-1))
cs =list(cs)
    cs.insert(len(binned_dW_means)//2,plt.cm.gray(0.5)) #gray for the no-change bin
f_dW,ax=plt.subplots(1,1,constrained_layout=True,sharey=True)
x = np.linspace(0,0.05,params['bins_per_sec']) #only plot 50 ms
for i,(k,v) in enumerate(binned_dW_means.items()):
m=v[1:,:].mean(axis=0)
err = v[1:,:].std(axis=0)
#ax.fill_between(x[0:len(m)],m-err, m+err,alpha=.5,color=cs[i])
lbl=' '.join([str(round(float(k[0]),5)),'to',str(round(float(k[1]),5))])
ax.plot(x[0:len(m)],m,c=cs[i],label=lbl)
    ax.set_ylabel('dW Neighbors Instantaneous Firing Rate (Hz)')
ax.set_xlabel('Time (s)')
#if cbar:
# plas_plot.colorbar7(cs,binned_weight_change_index,ax)
#else:
ax.legend()
f_im= plas_plot.nearby_synapse_image(np.nanmean(dW_aligned_array,axis=2),binned_dW_means,tmax=0.05)
all_mean = np.nanmean(weight_change_alligned_array,axis=2)
pot_index = weight_change_event_df.loc[weight_change_event_df.weightchange>nochange].index.to_numpy()
pot_mean,pot_std,pot_absmean=psau.mean_std(weight_change_alligned_array,pot_index)
dep_index = weight_change_event_df.loc[weight_change_event_df.weightchange<-nochange].index.to_numpy()
dep_mean,dep_std,dep_absmean=psau.mean_std(weight_change_alligned_array,dep_index)
nochange_index = weight_change_event_df.loc[weight_change_event_df.weightchange==0].index.to_numpy()#[::100]
nochange_mean,nochange_std,nochange_absmean=psau.mean_std(weight_change_alligned_array,nochange_index)
mean_firing_3bins={'Potentiation':pot_mean,'Depression':dep_mean,'No-change':nochange_mean}
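#Three coarse categories, per the index definitions above: potentiation (weightchange >
#nochange), depression (weightchange < -nochange), and no-change (weightchange exactly 0)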
calcium_fft,fft_freq=psau.fft_anal(binned_calcium)
###################################################################################################################
######################## Plots of Weight Change Triggered Average Pre-Syn firing #####################
if other_plots:
print('begin plots of wcta')
    # Bin the colorbar; colorbar7 needs matplotlib 3+, so compare the major version numerically rather than as a string
    cbar = int(matplotlib.__version__.split('.')[0]) >= 3
if combined_presyn_cal:
f_7bins,cs=plas_plot.combined_figure(binned_means,binned_calcium,binned_weight_change_index,params['duration'],std=[binned_std,binned_calcium_std])
else:
f_7bins,cs=plas_plot.weight_change_trig_avg(binned_means,binned_weight_change_index,params['duration'],std=binned_std,title=wt_change_title,colorbar=cbar)
f_ca,cs=plas_plot.weight_change_trig_avg(binned_calcium,binned_weight_change_index,params['duration'],std=binned_calcium_std,ylabel='Mean Calcium Concentration (mM)',colorbar=cbar)
#f_cafft,cs=plas_plot.weight_change_trig_avg(calcium_fft,binned_weight_change_index,max(fft_freq),ylabel='FFT Calcium Concentration (mM)',colorbar=cbar)
#### Use 3 bins - potentiate, depress, no change
cs3=[plt.cm.tab10(0),plt.cm.tab10(2),plt.cm.tab10(1)]
f3bins,_=plas_plot.weight_change_trig_avg(mean_firing_3bins,binned_weight_change_index,params['duration'],cs=cs3,title='Weight-change triggered average presynaptic firing rate')
################### Nearby Synapses, using 7 bins
f_near,ax=plt.subplots(1,1,constrained_layout=True,sharey=True)
x = np.linspace(0,1,params['bins_per_sec'])
for i,(k,v) in enumerate(binned_means.items()):
#f,ax=plt.subplots(1,1,constrained_layout=True,sharey=True)
#ax.plot(x,v[0,:],c=cs[i],label=k)
m=v[1:,:].mean(axis=0)-nochange_absmean
err = v[1:,:].std(axis=0)
#ax.fill_between(x,m-err, m+err,alpha=.5,color=cs[i])
lbl=' '.join([str(round(float(k[0]),3)),'to',str(round(float(k[1]),3))])
ax.plot(x,m,c=cs[i],label=lbl)
ax.set_ylabel('Mean-Subtracted Instantaneous Firing Rate (Hz)')
ax.set_xlabel('Time (s)')
if cbar:
plas_plot.colorbar7(cs,binned_weight_change_index,ax)
else:
ax.legend()
ax.set_title('Nearby Synapses '+wt_change_title)
    ################### Neighboring spike trains combined into one, using 7 bins
fig_comb_neighbors,ax = plt.subplots(1,1,constrained_layout=False)
x = np.linspace(0,1,len(combined_neighbors_array))
for i,(k,v) in enumerate(binned_weight_change_dict.items()):
lbl=' '.join([str(round(k[0],3)),'to',str(round(k[1],3))])
ax.plot(x,np.mean(combined_neighbors_array[:,v],axis=1),label=lbl,c = cs[i])
#ax.plot(x,np.mean(combined_neighbors_array[:,:],axis=1),label='all',c='k',alpha=.5)
ax.set_ylabel('Combined Instantaneous Firing\n Rate of Neighboring Synapses (Hz)',fontsize=fontsize)
ax.set_xlabel('Time (s)',fontsize=fontsize)
if cbar:
plas_plot.colorbar7(cs,binned_weight_change_index,ax,fontsize=fontsize)
else:
plt.legend()
if savefig:
plas_plot.fractional_size(f3bins,[1.5/2., .6])
plas_plot.new_fontsize(f_7bins,fontsize)
if combined_presyn_cal:
f_7bins.savefig(figure_path+sim_files+'BinnedWeightChangeTriggeredAverageCalcium.pdf')
else:
f_7bins.savefig(figure_path+sim_files+'BinnedWeightChangeTriggeredAverage.pdf')
plas_plot.new_fontsize(f_ca,fontsize)
f_ca.savefig(figure_path+sim_files+'BinnedWeightChangeTriggeredCalcium.pdf')
#f3bins.tight_layout()
#f3bins.savefig(figure_path+sim_files+'weight_change_triggered_average_firingrate_figure.eps')
#f_near.savefig(figure_path+sim_files+'nearby_synapses_weight_change_average_figure.svg')
f_near.savefig(figure_path+sim_files+'nearby_synapses_weight_change_average_figure.pdf')
        fig_comb_neighbors.savefig(figure_path+sim_files+'CombinedNeighboringFiringRate_'+str(round(params['dist_thresh']*1e6))+'um.pdf') #save the figure explicitly rather than relying on plt's current figure
############### Image plot of firing rate input to nearby synapses.
#### One image per weight change bin
if plot_neighbor_image:
#f_im= plas_plot.nearby_synapse_image(all_mean,binned_means)
f_im= plas_plot.nearby_synapse_1image(all_mean,binned_means)
#f_im= plas_plot.nearby_synapse_image(all_mean,binned_means,mean_sub=True,title='')
#f_im= plas_plot.nearby_synapse_image(all_mean,binned_means,mean_sub=False,title='')
wf_im= plas_plot.nearby_synapse_1image(all_mean,binned_weighted_means,mean_sub=True,title='wgt,')
plas_plot.fractional_size(f_im,[1.5/2,.6])
if savefig :
if isinstance(f_im,list):
for i,fim in enumerate(f_im):
fim.savefig(figure_path+sim_files+'NeighboringSynapse_Average_Heatmap'+str(i)+'.tif')
else:
#f_im.savefig(figure_path+sim_files+'NeighboringSynapse_WeightedAverage_Heatmap.svg')
f_im.savefig(figure_path+sim_files+'NeighboringSynapse_Average_Heatmap.tif')
wf_im.savefig(figure_path+sim_files+'NeighboringSynapse_WeightedAverage_Heatmap.tif')
if dW :
if len(keys)>3:
subset_df=inst_weight_change[(inst_weight_change.trial == keys[0]) | (inst_weight_change.trial == keys[1]) | (inst_weight_change.trial == keys[2]) | (inst_weight_change.trial == keys[3]) ]
f_isi1=plas_plot.inst_wt_change_plot(subset_df)
f_isi3=plas_plot.dW_2Dcolor_plot(subset_df)
else:
f_isi1=plas_plot.inst_wt_change_plot(inst_weight_change)
f_isi3=plas_plot.dW_2Dcolor_plot(inst_weight_change)
f_isi2=plas_plot.spiketime_plot(binned_spiketime,binned_spiketime_std)
if savefig:
for f in f_isi1.keys():
f_isi1[f].savefig(figure_path+sim_files+f+'.pdf')
for f in f_isi2.keys():
f_isi2[f].savefig(figure_path+sim_files+f+'binned.pdf')
for f in f_isi3.keys():
f_isi3[f].savefig(figure_path+sim_files+f+'_2Dcolor.pdf')
###################################################################################################################
####################### Covariance of spine input with neighboring spines #################################
'''
##calculating covariance separately for each weight_change bin
cov_dict={k:[] for k in binned_weight_change_dict.keys()}
for weight_change_bin,indices in binned_weight_change_dict.items():
for i in indices:
temp_cov = np.cov(combined_neighbors_array[:,i].T, weight_change_alligned_array[0,:,i] , bias=True)
cov_dict[weight_change_bin].append(temp_cov[0,1])
pot_cov = np.cov(combined_neighbors_array[::100,pot_index], weight_change_alligned_array[0,:,pot_index].T , bias=True)
'''
all_cov = []
all_xcor = []
all_cor = []
num_nan_corr=0
for i in range(combined_neighbors_array.shape[1]):
if not any(np.isnan(combined_neighbors_array[:,i])):
temp_cov = np.cov(combined_neighbors_array[:,i].T, weight_change_alligned_array[0,:,i] , bias=True)
all_cov.append(temp_cov[0,1])
temp_cor = np.correlate(combined_neighbors_array[:,i].T, weight_change_alligned_array[0,:,i])
all_xcor.append(temp_cor[0])
#print('shape=',np.shape(temp_cov),', cor=',round(temp_cor[0],3),', cov=',round(temp_cov[0,1],1))
temp_corr = np.corrcoef(combined_neighbors_array[:,i].T, weight_change_alligned_array[0,:,i])
if np.isnan(temp_corr[0,1]):
all_cor.append(-9)#FIXME
if combined_neighbors_array.shape[1]==len(binned_trains_index): #this is true if only a single trial was simulated
print ('NAN in temp_corr',i,binned_trains_index[i])
else:
num_nan_corr+=1
else:
all_cor.append(temp_corr[0,1])
if num_nan_corr>0:
print('number of NANs in temp_corr=',num_nan_corr)
all_cov = np.array(all_cov)
all_xcor = np.array(all_xcor)
all_cor = np.array(all_cor)
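#Quick summary of the neighbor correlations; excludes the -9 sentinel stored above
#when the correlation coefficient was NaN
valid_cor = all_cor[all_cor > -9]
print('neighbor correlation: n=',len(valid_cor),', mean=',round(np.mean(valid_cor),3) if len(valid_cor) else 'n/a')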
############ histogram figure from manuscript ####################
if other_plots:
print('plot histogram of neighbor correlations')
f_cov,ax = plt.subplots(constrained_layout=False)
for i,(k,v) in enumerate(binned_weight_change_dict.items()):
#plt.figure()
lbl=' '.join([str(round(k[0],3)),'to',str(round(k[1],3))])
        ax.hist(all_cor[v],histtype='step',bins=21,range=(-1,1),density=True,label=lbl,linewidth=2,color=cs[i],alpha=.8)
ax.set_xlim(-1,1)
ax.set_xlabel('Correlation Coefficient between Direct and Neighboring Input')
ax.set_ylabel('Normalized Histogram')
#ax.hist(all_cor,histtype='step',bins=20,range=(-1,1),density=True,label=k,linewidth=2,color='k',alpha=.9,zorder=-10,linestyle='--')
if cbar:
plas_plot.colorbar7(cs,binned_weight_change_index,ax)
    else:
        ax.legend()
f_cov5,ax = plt.subplots(constrained_layout=False)
histlist = [all_cor[v] for v in binned_weight_change_dict.values()]
histlist.append(all_cor)
histcolors=list(cs)
histcolors.append('k')
ax.hist(histlist,histtype='bar',bins=5,range=(-1,1),density=True,linewidth=2,color=histcolors,stacked=False)
ax.set_xlabel('Correlation Coefficient between Direct and Neighboring Input')
ax.set_ylabel('Normalized Histogram')
if cbar:
plas_plot.colorbar7(cs,binned_weight_change_index,ax)
else:
labels=[' '.join([str(round(k[0],3)),'to',str(round(k[1],3))]) for k in binned_weight_change_dict.keys()]+['all bins']
ax.legend(labels)
if savefig:
plas_plot.fractional_size(f_cov,(1,1))
plas_plot.new_fontsize(f_cov,fontsize)
f_cov.savefig(figure_path+sim_files+'CorrelationHistDirectNeighbors.pdf')
plas_plot.fractional_size(f_cov5,(1,1))
f_cov5.savefig(figure_path+sim_files+'Correl_5binsHistDirectNeighbors.pdf')
if combined_spatial:
fig_corr_space=plas_plot.combined_spatial(combined_neighbors_array,binned_weight_change_dict,binned_weight_change_index,all_cor,cs)
plas_plot.new_fontsize(fig_corr_space,fontsize)
fig_corr_space.savefig(figure_path+sim_files+'neighborsCorrelation.pdf')
else:
f_cov.suptitle('Distribution of Correlation Coefficients Between Direct and Neighboring Inputs')
f_cov5.suptitle('Distribution of Correlation Coefficients Between Direct and Neighboring Inputs')
'''
############## colormap when no other figures created ###################
cs = plt.cm.coolwarm(np.linspace(0,1,len(binned_means)-1))
cs =list(cs)
cs.insert(len(binned_means)//2,plt.cm.gray(0.5))
fig3D=plas_plot.plot_3D_scatter(inst_weight_change,'isi','pre_rate','dW',binned_dW_dict,cs)
fig3D.savefig(figure_path+sim_files+'_3DdW.pdf')
'''
###################################################################################################################
###############################################################
# ## Random Forest using average firing rate in 10 bins for both direct and n nearest neighbors + starting weight
######## X = weight_change_alligned_array
import RF_utils as rfu
num_events = weight_change_alligned_array.shape[2]
if RF_use_binned_weight:
y=weight_change_event_df.weight_bin
save_name=figure_path+sim_files+'weight_bin.npz'
else:
y = weight_change_event_df.weightchange
save_name=figure_path+sim_files
print('SAVE NAME', save_name)
if linear_reg:
save_name=save_name+'lin_reg'
if regression_all:
from sklearn.model_selection import train_test_split
######### Divide into training and testing set
    ## Need to reduce weight_change_alligned_array to 2 dimensions by stacking neighboring spines
    #weight_change_alligned_array dimensions: (n+1) spines (1 direct + n neighbors) x trace length (100 bins) x events (1 or 10 trials x number of spines assessed)
X = weight_change_alligned_array.reshape((weight_change_alligned_array.shape[0]*weight_change_alligned_array.shape[1],num_events)).T
X = X.reshape(num_events,params['neighbors']*10,10).mean(axis=2) #reduce number of temporal bins from 100 to 10, by taking mean of every 10 bins
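    #For a single length-100 trace t, the reshape-and-mean idiom above is equivalent to
    #t.reshape(10,10).mean(axis=1), i.e. averaging each run of 10 consecutive time bins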
Xall=np.concatenate((X,weight_change_event_df.startingweight.to_numpy().reshape(num_events,1)),axis=1) #use all spines as is
#from sklearn import linear_model
#from sklearn.feature_selection import RFE
X_train, X_test, y_train, y_test = train_test_split(Xall, y, test_size=0.1, random_state=42)
######### Does linear regression work?
#reg = linear_model.LinearRegression(fit_intercept=True).fit(X_train_, y_train)
#print(reg.score(X_train_, y_train),reg.coef_,reg.intercept_)
##### Implement random forest on training set, look at predictions for train and test set
print('########### not binned ##############')
reg,fit,pred=rfu.do_random_forest(X_train,y_train,X_test)
rfu.rand_forest_plot(y_train,y_test,fit,pred,'no bins')
#rfe = RFE(estimator=linear_model.LinearRegression(), n_features_to_select=1, step=1)
#rfe.fit(X_train_, y_train)
#ranking = rfe.ranking_
#print([X.columns[ranking==i] for i in range(1,11)])
########## Plots of random forest results
'''
plt.plot(reg.feature_importances_[:-1].reshape(20,10).T[:,0:10]);
from sklearn.inspection import permutation_importance
result = permutation_importance(regr,X_test,y_test)
plt.figure()
plt.plot(result.importances_mean)
'''
newX=weight_change_alligned_array[0,:,:].T #newX is single spine - transpose to become trials X length of trace
adjacentX=weight_change_alligned_array[1:,:,:].mean(axis=0).T #average over neighboring spines
bin_set=[1,3,5]
### train on 90% of data, test on 10%, optionally do some plots, uses random seed of 42
#reg_score,feature_import,linreg_score=rfu.random_forest_variations(newX, adjacentX-np.mean(adjacentX,axis=0), y, weight_change_event_df,all_cor,bin_set,num_events,wt_change_title,RF_plots=False,linear=linear_reg)
### repeat n (trials) times: train on 1-1/trials % of data, test on 1/trials % of data
#reg_score,feature_import,linreg_score=rfu.random_forest_LeaveNout(newX, adjacentX-np.mean(adjacentX,axis=0), y, weight_change_event_df,all_cor,bin_set,num_events,trials=4,linear=linear_reg)
############ RF on calcium bins #############
#ca_reg_score=rfu.RF_calcium(weight_change_alligned_ca,y,bin_set,trials=4)
#np.savez(save_name+'.npz',reg_score=reg_score,t1_endwt=t1weight_distr,t1_endwt_stim=trial1_stim_distr,feature_import=feature_import,linreg=linreg_score,ca_reg_score=ca_reg_score)
#rfu.RF_oob(weight_change_alligned_array,all_cor,weight_change_event_df,y, RF_plots=False)
###### calculate linear correlation between weight change and input firing frequency ###############
numbins=1
Xbins,_=rfu.downsample_X(newX,numbins,weight_change_event_df,num_events,add_start_wt=False)
Xadj,_=rfu.downsample_X(adjacentX,numbins,weight_change_event_df,num_events,add_start_wt=False)
newdf=pd.DataFrame(data=np.column_stack([Xbins,Xadj,y]),columns=['synapse','adj','weight_change'])
print('ALL: correlation, direct firing:',pearsonr(newdf['synapse'],newdf['weight_change']),', adj firing:',pearsonr(newdf['adj'],newdf['weight_change']))
newchangedf=newdf[(newdf.weight_change>.01) | (newdf.weight_change<-0.01)]
print('CHANGE DF: correlation, direct firing:',pearsonr(newchangedf['synapse'],newchangedf['weight_change']),
', adj firing:',pearsonr(newchangedf['adj'],newchangedf['weight_change']))
if 'cluster_length' in weight_change_event_df.columns:
clust_len=weight_change_event_df['cluster_length'].to_numpy().reshape(num_events,1)
clust_sp=weight_change_event_df['spines_per_cluster'].to_numpy().reshape(num_events,1)
newdf=pd.DataFrame(data=np.column_stack([Xbins,Xadj,y,clust_len,clust_sp]),columns=['synapse','adj','weight_change','cluster_length','spines_per_cluster'])
newchangedf=newdf[(newdf.weight_change>.01) | (newdf.weight_change<-0.01)]
print('ALL: correlation, cluster length',pearsonr(newdf['cluster_length'],newdf['weight_change']))
print('CHANGE DF: correlation, cluster length',pearsonr(newchangedf['cluster_length'],newchangedf['weight_change']))
print('cluster length vs spines per cluster',pearsonr(newdf['cluster_length'],newdf['spines_per_cluster']))
################# Calculate correlation using maximum across neighbors
#adjacentX=weight_change_alligned_array[1:,:,:].max(axis=1).max(axis=0) # pearsonr(newdf['adj'],newdf['wc'])=(0.037644946344189606, 0.011199245773898055)
#adjacentX=weight_change_alligned_array[1:,:,:].mean(axis=1).max(axis=0)# pearsonr(newdf['adj'],newdf['wc'])=(0.02789514149579918, 0.060216871815671015)
Xinput=weight_change_alligned_array[1:,:,:]
adjbins=10# pearsonr(newdf['adj'],newdf['wc'])=(0.038146383424438976, 0.010163113730931867)
#adjbins=20# pearsonr(newdf['adj'],newdf['wc'])=(0.03781034055692676, 0.0108475656243252)
#adjbins=5# pearsonr(newdf['adj'],newdf['wc'])=(0.03541923162826019, 0.01701610633017229)
downsamp=np.shape(Xinput)[1]//adjbins
Adjbins=np.zeros((np.shape(Xinput)[0],adjbins,np.shape(Xinput)[2]))
for b in range(adjbins):
Adjbins[:,b,:]=np.mean(Xinput[:,b*downsamp:(b+1)*downsamp,:],axis=1)
adjacentX=Adjbins.max(axis=1).max(axis=0)
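#Adjbins averages each neighbor's trace into adjbins coarse time bins; taking max over
#time bins and then over neighbors gives, per event, the largest binned mean rate of
#any neighboring spine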
newdf=pd.DataFrame(data=np.column_stack([Xbins,adjacentX,y]),columns=['syn','adj','wc'])
print('correlation max adjacent vs weight=',pearsonr(newdf['adj'],newdf['wc']))
############# Try RF using maximum across all neighbors
bin_set=[1,3,5]
trials=4
from sklearn.model_selection import train_test_split
reg_score={str(bn)+feat:[] for bn in bin_set for feat in ['_max_'+str(round(1/adjbins,2))+'secmean','_neighbor'+str(round(params['dist_thresh']*1e6))+'um']}
feature_import={str(bn)+feat:[] for bn in bin_set for feat in ['_max_'+str(round(1/adjbins,2))+'secmean','_neighbor'+str(round(params['dist_thresh']*1e6))+'um']}
for numbins in bin_set:
Xbins,_=rfu.downsample_X(newX,numbins,weight_change_event_df,num_events,add_start_wt=False)
neighbors,_=rfu.downsample_X(combined_neighbors_array.T,numbins,weight_change_event_df,num_events,add_start_wt=False)
features={'_max_'+str(round(1/adjbins,2))+'secmean':adjacentX,'_neighbor'+str(round(params['dist_thresh']*1e6))+'um':neighbors}
for feat,adjX in features.items():
data=np.column_stack([Xbins,adjX])
for i in range(trials):
X_train, X_test, y_train, y_test = train_test_split(data, y, test_size=1/trials)
reg,fit,pred,regscore,feat_vals=rfu.do_random_forest(X_train,y_train,X_test,y_test,feat=feat)
reg_score[str(numbins)+feat].append(regscore)
feature_import[str(numbins)+feat]=feat_vals
feature_import[str(numbins)+'features']=[str(b) for b in range(numbins)]+['adj'+str(i) for i in range(numbins)]
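#reg_score now maps keys like '3_max_0.1secmean' to one test-set score per train/test
#split, and feature_import holds the matching feature importances (assuming
#do_random_forest returns reg,fit,pred,score,feature_importances as unpacked above)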
############# Try RF on instantaneous weight change ##############
def dwsets_for_rf(df):
features=['isi','pre_rate','neighbor'] #'pre_spike_dt',
inst_data={'dW='+feat:np.column_stack([df[feat]]) for feat in features}
for i in range(len(features)):
for j in range(i+1,len(features)):
inst_data['dW='+features[i]+','+features[j]]=np.column_stack([df[features[i]],df[features[j]]])
'''if 'cluster_length' in df:
inst_data['dW_clust']=np.column_stack([df.isi,df.cluster_length])'''
#inst_data['dW_time']=np.column_stack([df.pre_rate,df.isi,df.dWt])
df_nonan=df.dropna()
features2=['isi2','pre_interval','cluster_length','neighbor']
for feat1 in features:
for feat2 in features2:
if feat1 != feat2:
inst_data['dW='+feat1+','+feat2]=np.column_stack([df_nonan[feat1],df_nonan[feat2]])
feat0='isi';feat1='pre_rate'
for feat2 in features2:
inst_data['dW='+feat0+','+feat1+','+feat2]=np.column_stack([df_nonan[feat0],df_nonan[feat1],df_nonan[feat2]])
feat1='pre_spike_dt'
for feat2 in features2:
inst_data['dW='+feat0+','+feat1+','+feat2]=np.column_stack([df_nonan[feat0],df_nonan[feat1],df_nonan[feat2]])
y=df.dW
yno_nan=df_nonan.dW
print('\n****** dW correlation ******\n',df_nonan[['pre_spike_dt','pre_interval','isi','isi2','pre_rate','neighbor']].corr())
return inst_data,y,yno_nan
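#dwsets_for_rf returns feature matrices keyed by feature combination, plus targets for
#all rows (y) and for NaN-free rows (yno_nan); used by the dW random forest block below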
if dW:
adjacentX=dW_aligned_array[1:,:,:].mean(axis=0).T.mean(axis=1) #average over neighboring spines
inst_weight_change['neighbor']=adjacentX
dWchange_df=inst_weight_change.loc[(inst_weight_change.dW>=nochangedW)|(inst_weight_change.dW<=-nochangedW)]
inst_data,all_y,y_nonan=dwsets_for_rf(dWchange_df) #or inst_weight_change
for feat,idata in inst_data.items():
reg_score[feat]=[]
if len(idata)==len(all_y):
y=all_y
elif len(idata)==len(y_nonan):
y=y_nonan
else:
            print('PROBLEM: length of y data does not match length of X data')
for i in range(trials):
X_train, X_test, y_train, y_test = train_test_split(idata, y, test_size=1/trials)
reg,fit,pred,regscore,feat_vals=rfu.do_random_forest(X_train,y_train,X_test,y_test,feat=feat)
reg_score[feat].append(regscore)
feature_import[feat]=feat_vals
#feature_import['dW_features']=['pre_rate','isi','key/isi2','pre_int/pre_dt']
#np.savez(save_name+'_dW.npz',reg_score=reg_score,feature_import=feature_import)