#! /Users/joao/opt/anaconda3/bin/pythonimport colorsys
from matplotlib import pyplot as plt
import matplotlib as mpl
import netpyne
import numpy as np
import pickle
import math
import json
classColorMap():
defcolormap(data_list,c_map):
steprange = range(len(data_list))
color_map = plt.get_cmap(c_map)
print("Current map color: %s" %color_map.name)
map_colors = {}
for j,k inenumerate(data_list):
map_colors.update({k:color_map(1.*j/float(len(steprange)))})
return map_colors
classCellConnectivity():
defgetPopGIDs(conns, popFolder, target_pops):
# pop_GID_dict={}# --- Selects one population from the list of populations specifiedfor target_pop in target_pops:
print('target_pop: ',target_pop)
# --- Gets the gids of the cells in that population
target_gids = conns['net']['pops'][target_pop]['cellGids']
pop_GID_filename=popFolder+target_pop+'_GIDs.pkl'# pop_GID_dict.update({target_pop:target_gids})withopen(pop_GID_filename, 'wb') as f: pickle.dump(target_gids, f)
# return pop_GID_dictdefgenerateCellConnectivity(conns, connFolder, target_pops):
conn_dict={}
allPops = conns['net']['pops'].keys()
# --- Selects one population from the list of populations specifiedfor target_pop in target_pops:
print('target_pop: ',target_pop)
# --- Gets the gids of the cells in that population
target_gids = conns['net']['pops'][target_pop]['cellGids']
# --- Creates an empty dictionary for that population to store information
conn_dict.update({target_pop:{}})
# --- Loops through the gids of the postsynaptic cellsfor cell_gid in target_gids:
# --- Generates the filename for the .pkl file with the cell connectivity loaded in the updateConns step
cell_filename=connFolder+target_pop+'_'+str(cell_gid)+'.pkl'print('cell_gid: ',cell_gid,'\tcells remaining: ',target_gids[-1]-cell_gid)
conn_dict[target_pop].update({cell_gid:{}})
save_conns=[]
# --- Gets information about each conn from presynaptic cells in an array format [pre_gid, post_sec, post_loc, mech, weight, delay]for conn in conns['net']['cells'][cell_gid]['conns']:
iftype(conn) islist:
pre_gid = conn[0]
eliftype(conn) isdict:
pre_gid = conn['preGid']
pre_sec = conn['sec']
pre_loc = conn['loc']
pre_mec = conn['synMech']
pre_wei = conn['weight']
pre_del = conn['delay']
del conn
conn = [pre_gid,pre_sec,pre_loc,pre_mec,pre_wei,pre_del]
else:
return'invalid conn data format'# pre_gid_list.append(pre_gid)# --- Gets the keys of the popsfor pop in allPops:
# --- Checks from which pop the presynaptic cell comes fromif pre_gid in conns['net']['pops'][pop]['cellGids']:
conn.append(pop)
save_conns.append(conn)
# --- Saves the cell connectivity .pkl files | filename: pop_gid.pklwithopen(cell_filename, 'wb') as f: pickle.dump(save_conns, f)
defgenerate_cell_tags(conns,fileName):
# --- List of dictionaries ( format: [ {index1:{cell_tags1}}, {index2:{cell_tags2}}, ... ] )try: cell_tags = [{ind:conns['net']['cells'][ind].tags} for ind inrange(len(conns['net']['cells']))]
except: cell_tags = [{ind:conns['net']['cells'][ind]['tags']} for ind inrange(len(conns['net']['cells']))]
# --- Formatting to a better dictionary structure
cell_tags_dict={}
for cell_tag in cell_tags:
cell_tags_dict.update(cell_tag)
# --- Saves the cell connectivity .pkl files | filename: cell_tags.pklwithopen(fileName, 'wb') as f: pickle.dump(cell_tags_dict, f)
classConnectivity():
# --- Creates a list with all the GIDs of the presynaptic cellsdefget_pre_gids(conn_data):
pre_gids_=[]
for conn in conn_data:
pre_gids_.append(conn[0])
pre_gids = list(set(pre_gids_))
return pre_gids
# --- Creates a dictionary that organizes the spikes of the cells using the GIDs as keys# --- obs: SpktSpkid = the output of netpyne.analysis.tools.getSpktSpkid(cellGids=list_of_cell_gids)defgetSpkDict(SpktSpkid,all_cell_GIDs=None):
# --- Converting spike time and spike id to lists
spkts = SpktSpkid[1]
spkgids = [int(num) for num in SpktSpkid[2]]
# --- Creates a set of GIDs to generate the dictionary keysif all_cell_GIDs: # in this case, cells that did not spike are still allocated as an empty list# --- Removes duplicates
spkgids_set=list(set(all_cell_GIDs))
# --- Sort list of GIDs
spkgids_set.sort()
else:
# --- Removes duplicates
spkgids_set=list(set(spkgids))
# --- Sort list of GIDs
spkgids_set.sort()
# --- Create Spike dictionary
spk_dict={}
for spk_gid in spkgids_set:
spk_dict.update({spk_gid:[]})
# --- Creates a dictionary where (key=GID, value=spike times)for i inrange(len(spkts)):
spk_dict[spkgids[i]].append(spkts[i])
return spk_dict
defload_weightNorm(weightNorm_filePath):
withopen(weightNorm_filePath, 'rb') as weightNorm_fileObj: weightNorm_dict_ = pickle.load(weightNorm_fileObj)
# --- Fixing the dictionary to match pop names
weightNorm_dict={}
for cell_type in weightNorm_dict_.keys():
if (cell_type=='PT5B_full') or \
(cell_type=='PT5B_reduced'): pop_name = 'PT5B'elif cell_type.startswith('IT5B'): pop_name = 'IT5B'elif cell_type.startswith('PV4'): pop_name = 'PV4'elif cell_type.startswith('PV5A'): pop_name = 'PV5A'elif cell_type.startswith('SOM5A'): pop_name = 'SOM5A'elif cell_type.startswith('PV5B'): pop_name = 'PV5B'elif cell_type.startswith('CT6'): pop_name = 'CT6'elif cell_type.startswith('IT6'): pop_name = 'IT6'# --- A1 popselif cell_type.startswith('ITS4'): pop_name = 'ITS4'elif cell_type.startswith('ITP4'): pop_name = 'ITP4'elif cell_type.startswith('IT5A'): pop_name = 'IT5A'else:
print(cell_type)
print('Error loading weight normalization')
import sys; sys.exit()
weightNorm_dict.update({pop_name:weightNorm_dict_[cell_type]})
print('Weight norm loaded')
return weightNorm_dict
defgenerate_conn_dict(conn_data,simplify_mechs=False,weightNorm_dict=None):
# --- Obtaining a list of presynaptic populations
pre_cell_pops_=[]
for conn_ind,conn inenumerate(conn_data):
pre_cell_pops_.append(conn[-1])
pre_cell_pops=list(set(pre_cell_pops_))
# --- Allocating presynaptic pops as dictionary keys
conn_dict={}
temp_conn_dict={} # temp dict to store unsorted GIDsfor pre_cell_pop in pre_cell_pops:
conn_dict.update({pre_cell_pop:{'exc':{},'inh':{}}})
temp_conn_dict.update({pre_cell_pop:{'exc':{},'inh':{}}}) # temp dict to store unsorted GIDs# --- Splitting the dictionary in 'exc' and 'inh' connections to facilitate processing laterfor conn_ind,conn inenumerate(conn_data):
pre_cell_gid = conn[0]
pre_cell_pop = conn[-1]
if ('GABA'in conn[3]) and (pre_cell_gid notin conn_dict[pre_cell_pop]['inh'].keys()): mech_type='inh'elif ('GABA'notin conn[3]) and (pre_cell_gid notin conn_dict[pre_cell_pop]['exc'].keys()): mech_type='exc'else:
print('Unknown Mechanism: ', conn[3])
continue# --- Assigns unsorted GIDs to temporary dictionary
temp_conn_dict[pre_cell_pop][mech_type].update({pre_cell_gid:{'conns':[],'weights':[],'delays':[]}})
# --- Reassigns sorted GIDs to conn dictionaryfor pre_cell_pop in pre_cell_pops:
for mech_type in ['exc','inh']:
key_list = list(temp_conn_dict[pre_cell_pop][mech_type].keys())
key_list.sort()
for key in key_list:
conn_dict[pre_cell_pop][mech_type].update({key:temp_conn_dict[pre_cell_pop][mech_type][key]})
# --- Processing connectivity datafor conn_ind,conn inenumerate(conn_data):
pre_cell_gid = conn[0] # gid
pre_cell_sec = conn[1] # sec
pre_cell_loc = conn[2] # loc
pre_cell_mech = conn[3] # mech# --- Type of mechanism (exc or inh)if'GABA'in pre_cell_mech: mech_type = 'inh'else: mech_type = 'exc'# --- Weightif weightNorm_dict: pre_cell_weight = conn[4]/weightNorm_dict[conn[1]] # De-normalizing conn weightelse: pre_cell_weight = conn[4] # Using normalized conn weight
pre_cell_delay = conn[5] # delay
pre_cell_pop = conn[-1] # pop# --- String with all conn parameters
conn_name = pre_cell_sec+'|'+str(pre_cell_loc)+'|'+pre_cell_mech+'|'+str(pre_cell_weight)+'|'+str(pre_cell_delay)
# --- Storing the detailed connection and separate weight and delay values
conn_dict[pre_cell_pop][mech_type][pre_cell_gid]['conns'].append(conn_name)
conn_dict[pre_cell_pop][mech_type][pre_cell_gid]['weights'].append(pre_cell_weight)
conn_dict[pre_cell_pop][mech_type][pre_cell_gid]['delays'].append(pre_cell_delay)
# --- Adding weights and Averaging delays of the connections from the same presynaptic cellfor pre_cell_pop in conn_dict.keys():
for mech_type in conn_dict[pre_cell_pop].keys():
for pre_cell_gid in conn_dict[pre_cell_pop][mech_type].keys():
conn_dict[pre_cell_pop][mech_type][pre_cell_gid]['sum_weight']=sum(conn_dict[pre_cell_pop][mech_type][pre_cell_gid]['weights'])
conn_dict[pre_cell_pop][mech_type][pre_cell_gid]['avg_delay']=np.mean(conn_dict[pre_cell_pop][mech_type][pre_cell_gid]['delays'])
return conn_dict
defgeneratePostSpkHist(loadPop, all_post_cell_gids):
# === POSTSYNAPTIC SPIKES# --- Loads all the spike data for the postsynaptic population and stores in a dictionary format
spk_dict_post_={}
spk_dict_post={}
for pop in loadPop:
spk_dict_post_.update({pop:{}})
spk_dict_post.update({pop:{}})
post_cell_spk_info=netpyne.analysis.tools.getSpktSpkid(cellGids=all_post_cell_gids)
spk_dict = Connectivity.getSpkDict(post_cell_spk_info)
spk_dict_post_[pop].update(spk_dict)
# --- Adding empty lists for the cells that did not spike throughout the whole simulation# P.s.: these cells didnt fire in the whole sim / different of the 'silent' group, which includes these cells + others that fired only outside of the QUIET/MOVE periods, but mght have fired a e.g. 100 ms, which is out of those ranges
silent_cells = list(set(all_post_cell_gids).difference(spk_dict_post_[pop].keys()))
silent_cells.sort()
for silent_cell in silent_cells:
spk_dict_post_[pop].update({silent_cell:[]})
# --- Storing the information in a sorted dictionary
keylist = list(spk_dict_post_[pop].keys())
keylist.sort()
for key in keylist:
spk_dict_post[pop].update({key:spk_dict_post_[pop][key]})
return spk_dict_post, silent_cells
defgeneratePostSpkConn(loadPop, all_post_cell_gids, connFolder):
# === POSTSYNAPTIC CONNECTIVITY# --- Load all connectivity from presynaptic to postsynaptic cell beforehand
conn_dict_post={}
for pop in loadPop:
conn_dict_post.update({pop:{}})
print('\n\n##############################################')
print(' Generating %s %s connections '%(pop,len(all_post_cell_gids)))
print('##############################################')
for post_cell_gid_ind, post_cell_gid inenumerate(all_post_cell_gids):
if (post_cell_gid-all_post_cell_gids[0])%100==0:
print('\t %s cells left'%((len(all_post_cell_gids)-post_cell_gid_ind)))
conn_dict_post[pop].update({post_cell_gid:[]})
loadFile=connFolder+pop+'_'+str(post_cell_gid)+'.pkl'withopen(loadFile, 'rb') as cell_fileObj: conn_data = pickle.load(cell_fileObj)
conn_dict_post[pop][post_cell_gid]=conn_data
return conn_dict_post
defgeneratePreSpkHist(conn_dict_post, all_post_cell_gids):
# === PRESYNAPTIC SPIKES# --- Obtaining all the GIDs of presynaptic cells to load spikes beforehand
pre_gids_=[]
for pop in conn_dict_post.keys():
for post_cell_gid in all_post_cell_gids:
for i inrange(len(conn_dict_post[pop][post_cell_gid])):
pre_gids_.append(conn_dict_post[pop][post_cell_gid][i][0])
pre_cell_gids=list(set(pre_gids_))
pre_cell_gids.sort()
# --- Dictionary to store all presynaptic cell spikes by GID
pre_cell_spks=netpyne.analysis.tools.getSpktSpkid(cellGids=pre_cell_gids)
spk_dict_pre=Connectivity.getSpkDict(pre_cell_spks)
# --- Adding empty lists to represent cells that did not spike, but are still connectedfor pre_cell_gid in pre_cell_gids:
pre_spiking_list = list(spk_dict_pre.keys())
if pre_cell_gid notin pre_spiking_list: spk_dict_pre.update({pre_cell_gid:[]})
return spk_dict_pre
defgenerateSpikeHistogram(conn_dict, time_bins, post_cell_spk_times, connected_spk_dict_pre):
# ------------------------------------------------------------------------------------------------- ## Generates STPH for each postsynaptic cell# ------------------------------------------------------------------------------------------------- ## -------------------------------- Input arguments ------------------------------------------------ ## conn_dict : connection details# time_bins : time bins of the histogram to build# post_cell_spk_times : post cell spike times# connected_spk_dict_pre : dict{pre_cell_gid:pre cell spike times} (only connected cells)# ------------------------------------------------------------------------------------------------- ## --- Dictionary to store the spike histogram data for each postsynaptic cell
post_cell_spike_hist_dict={}
# --- Iterates over the presynaptic popfor pre_pop in conn_dict.keys():
post_cell_spike_hist_dict.update({pre_pop:{}})
# --- Iterates over the types of synaptic mechanisms ('exc' or 'inh')for mech_type in conn_dict[pre_pop].keys():
post_cell_spike_hist_dict[pre_pop].update({mech_type:[]})
# --- List to store sum of weighted normalized histograms
weighted_spike_histograms=[]
# --- Iterates over the GIDs of the presynaptic cellsfor pre_cell_gid in conn_dict[pre_pop][mech_type].keys():
valid_spike_differences = []
# --- Iterates over the list of spike times from post cellfor post_cell_spkt in post_cell_spk_times:
# --- Iterates over the list of spike times from pre cellfor pre_cell_spkt in connected_spk_dict_pre[pre_cell_gid]:
# --- Delay to propagate the spike
pre_spike_delay = conn_dict[pre_pop][mech_type][pre_cell_gid]['avg_delay']
# --- Effective spike time (takes into account the spike delay to reach the postsynaptic cell)
effective_pre_cell_spkt = pre_cell_spkt+pre_spike_delay
# --- Pre spike occurred after Post spikeif (effective_pre_cell_spkt)>post_cell_spkt: continue# --- Pre spike occurred before the max interval from Post spike (e.g.: 300-(18+5)>200 = True || 300-(150+5)>200 = False)elif post_cell_spkt-(effective_pre_cell_spkt)>time_bins[-1]: continueelse:
# --- Spike difference (post-(pre+delay))
spike_difference = post_cell_spkt-effective_pre_cell_spkt
# print(pre_pop,mech_type,pre_cell_gid,post_cell_spkt,pre_cell_spkt,pre_spike_delay,spike_difference)
valid_spike_differences.append(spike_difference)
valid_spike_differences.sort()
# print(valid_spike_differences)# --- Calculating spike histogram
spike_histogram,edges = np.histogram(valid_spike_differences,range=[0,time_bins[-1]],bins=len(time_bins))
# print('spike histogram: ', spike_histogram)# --- BREAKS THE CODE SO THAT EMPTY HISTOGRAMS ARE NOT STOREDifall(spk_count == 0for spk_count in spike_histogram): continue
list_spike_histogram = list(spike_histogram)
# --- Normalizing the histogram by the number of postsynaptic spikes, so that the firing frequency of a postsynaptic cell does not affect the results
normalized_spike_histogram = [bin/(len(post_cell_spk_times)) forbinin list_spike_histogram]
# print(pre_pop,mech_type,pre_cell_gid, '\t normalized histogram: ',normalized_spike_histogram)
sum_weight = conn_dict[pre_pop][mech_type][pre_cell_gid]['sum_weight']
# print(list(normalized_spike_histogram),sum_weight)
weighted_spike_histogram=[norm_spk*sum_weight for norm_spk in normalized_spike_histogram]
# print(pre_pop,mech_type,pre_cell_gid, '\t weighted histogram: ',weighted_spike_histogram,'\n')
weighted_spike_histograms.append(weighted_spike_histogram)
# print(pre_pop,mech_type,pre_cell_gid,'weighted_spike_histograms: ',weighted_spike_histograms)# --- Sum of weighted normalized spike histograms (empty histograms removed)
sum_weighted_spike_histograms_=np.sum(weighted_spike_histograms,axis=0)
try: sum_weighted_spike_histograms = list(sum_weighted_spike_histograms_)
except: sum_weighted_spike_histograms = [0]*len(time_bins)
# print(type(sum_weighted_spike_histograms), '\t', sum_weighted_spike_histograms)# --- Final data: List with the sum of weighted normalized spike histograms (empty histograms removed) for each time_bin
post_cell_spike_hist_dict[pre_pop][mech_type]=sum_weighted_spike_histograms
return post_cell_spike_hist_dict
classPlotFigures():
defformatData(pop_spike_hist_dict,all_pops):
mech_types = ['exc','inh']
# --- Dictionary to hold plotting variables # (because data is organized as dict[post_pop][post_cell_gid][pre_pop][mech_type][pre_cell_gid], but plots are easier organized by dict[mech_type][pre_pop])# (also, [post_pop] is not a major factor, because spike_hist_dict only contains a single [post_pop], which also goes in the filename)
pre_cell_pops_=[]
for post_cell_gid in pop_spike_hist_dict.keys():
# print(pop_spike_hist_dict[post_cell_gid].keys())for pre_pop in pop_spike_hist_dict[post_cell_gid].keys():
pre_cell_pops_.append(pre_pop)
pre_cell_pops=list(set(pre_cell_pops_))
pre_cell_pops.sort()
# --- Ordered populations according to network distribution
ordered_pre_pops=[]
for o_pop in all_pops:
if o_pop in pre_cell_pops: ordered_pre_pops.append(o_pop)
# print('Pre pop info | \npre_cell_pops: ',pre_cell_pops, '\nordered_pre_pops: ', ordered_pre_pops)# --- Allocating keys for the plotting dictionary
plot_spike_hist_dict={}
for mech_type in mech_types:
plot_spike_hist_dict.update({mech_type:{}})
for pre_cell_pop in ordered_pre_pops:
plot_spike_hist_dict[mech_type].update({pre_cell_pop:{}})
# --- Filtering out the empty histograms (post_cell has valid spikes, but out of timeRange)
valid_post_cell_gids=[]
for post_cell_gid in pop_spike_hist_dict.keys():
# --- Filters out empty dictionaries (no histogram, because the cell fired out of timeRange)ifany(pop_spike_hist_dict[post_cell_gid]): valid_post_cell_gids.append(post_cell_gid)
# --- Calculating Mean and Std across postsynaptic cellsfor pre_cell_pop in ordered_pre_pops:
for mech_type in mech_types:
post_cell_hists=[]
# --- Iterates over valid histograms (cells that fired) - cells that are originally silent/quiet are filtered outfor post_cell_gid in valid_post_cell_gids:
if pre_cell_pop in pop_spike_hist_dict[post_cell_gid].keys():
post_cell_hist = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type]
post_cell_hists.append(post_cell_hist)
post_pop_hist_mean = np.mean(post_cell_hists,axis=0)
post_pop_hist_std = np.std(post_cell_hists,axis=0)
# print(mech_type, pre_cell_pop, '->', pop,' : ', len(post_cell_hists))# print(post_pop_hist_mean,post_pop_hist_std,'\n')
plot_spike_hist_dict[mech_type][pre_cell_pop].update({'mean':post_pop_hist_mean,'std':post_pop_hist_std})
return plot_spike_hist_dict,ordered_pre_pops,valid_post_cell_gids
defplotSPTHtraces(plot_spike_hist_dict, ordered_pre_pops, time_bins, map_colors, divide_plots=False, select_plot_pops=None):
print('SPTH traces method')
# --- P.s.: Select_plot_pops should be a dictionary with the following format# select_plot_pops = {'exc':['pre_pop1', 'pre_pop2', ...], 'inh':['pre_pop2', 'pre_pop5', ...]}# e.g.:# select_plot_pops = {'exc':['PT5B','TVL','TPO','IT2','IT5A'],'inh':['PV5A','PV5B',]}# --- Figure
plt.figure(figsize=(15,10))
# plt.suptitle('spk_histogram_data: '+pop+' histogram | '+network_state+' | '+timeRange_name+' ms')# --- Plot variables
plot_xlim=(0,55)
plot_xticks=list(range(5,51,5))
plot_xticks_labels=[str(-1*(x))+' ← '+str(-1*(x-5)) for x in plot_xticks]
# plot_ylim=None
plot_ylim=(-1,17.5)
# plot_yticks=None
plot_yticks=[0,5,10,15]
plot_errorbar_capsize=3
plot_errorbar_alpha=0.75
plot_pre_pop=ordered_pre_pops
for mech_ind, mech_type inenumerate(plot_spike_hist_dict.keys()):
mech_color=['b','r']
plot_errorbar_marker=['^','v']
if divide_plots: plt.subplot(1, 2, mech_ind+1)
for pre_pop in plot_pre_pop:
# --- Change colors for the line plots
c=map_colors[pre_pop]
x_data = time_bins
y_data = plot_spike_hist_dict[mech_type][pre_pop]['mean']
y_std = plot_spike_hist_dict[mech_type][pre_pop]['std']
ifall(y_d == 0for y_d in y_data) andall(y_s == 0for y_s in y_std): continueif select_plot_pops isnotNone:
if pre_pop notin select_plot_pops[mech_type]: continueiflen(y_data)==len(x_data):
line_style='-'if pre_pop.startswith('PT5B'):c='lightgrey'; line_style='--'# changes color of 'PT5B' to grey
plt.plot(x_data,y_data,line_style,color=c,linewidth=3,label=pre_pop+'_'+mech_type)
for ind_p,p inenumerate(y_std):
# --- Plotting traces only# plt.plot(time_bins[ind_p],mean_exc_spks[ind_p],'o',markerfacecolor='blue',markeredgecolor='w')# --- Plotting errorbars. source: https://stackoverflow.com/questions/22481854/plot-mean-and-standard-deviation
plt.errorbar(x_data[ind_p],y_data[ind_p], y_std[ind_p],linestyle='None', color=c, marker=plot_errorbar_marker[mech_ind], markeredgecolor=mech_color[mech_ind], capsize=plot_errorbar_capsize, alpha=plot_errorbar_alpha)
# plt.errorbar(x_data[ind_p],y_data[ind_p], y_std[ind_p],linestyle='None', color=mech_color[mech_ind], marker=plot_errorbar_marker, capsize=plot_errorbar_capsize, alpha=plot_errorbar_alpha)else: print('skipping plot for ', mech_type, pre_pop)
# --- Plot formattingif divide_plots: plt.title(mech_type)
else: plt.title('Presynaptic population')
plt.legend(loc='upper left')
plt.xlim(plot_xlim)
plt.xticks(ticks=plot_xticks, labels=plot_xticks_labels)
plt.xlabel('Time interval before Postsynaptic spike (ms)')
if plot_ylim isnotNone: plt.ylim(plot_ylim)
if plot_yticks isnotNone: plt.yticks(ticks=plot_yticks)
plt.ylabel('Number of presynaptic spikes * synaptic strength')
plt.gca().invert_xaxis()
defbarPlot(pop_spike_hist_dict,ordered_pre_pops,valid_post_cell_gids,max_time_index,all_pops,ax_lim=None):
mech_types = ['exc','inh']
# --- Allocating keys for the plotting the bar graph
plot_bar_hist_dict={}
for mech_type in mech_types:
plot_bar_hist_dict.update({mech_type:{}})
for pre_cell_pop in ordered_pre_pops:
plot_bar_hist_dict[mech_type].update({pre_cell_pop:{}})
# --- Calculating Mean and Std across postsynaptic cellsfor mech_type in mech_types:
for pre_cell_pop in ordered_pre_pops:
post_cell_hists_windowSum=[]
for post_cell_gid in valid_post_cell_gids:
if pre_cell_pop in pop_spike_hist_dict[post_cell_gid].keys():
# --- Selects a time window from the histogram and adds it together to calculate the total weight
post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][0:max_time_index+1]
# post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][1:max_time_index+1]
post_cell_hists_windowSum.append(sum(post_cell_hist_windowSum))
post_pop_hist_windowSum_mean = np.mean(post_cell_hists_windowSum)
post_pop_hist_windowSum_std = np.std(post_cell_hists_windowSum)
plot_bar_hist_dict[mech_type][pre_cell_pop].update({'mean':post_pop_hist_windowSum_mean,'std':post_pop_hist_windowSum_std})
# --- Draw Figure - Bar plot v4
plt.figure(figsize=(5,5))
# plt.suptitle('Bar plot v4: '+pop+' histogram | '+str(max_time)+' max time | '+network_state+' | '+timeRange_name+' ms')# plt.grid(color='lightgrey')# --- Plot style
divide_plots=False
sum_of_means_exc=0
sum_of_means_inh=0
list_of_stds_exc=[] # sum of stds: https://study.com/skill/learn/how-to-calculate-the-standard-deviation-of-the-sum-of-two-random-variables-explanation.html
list_of_stds_inh=[] # sum of stds: https://study.com/skill/learn/how-to-calculate-the-standard-deviation-of-the-sum-of-two-random-variables-explanation.htmlfor pre_pop_ind, pre_pop inenumerate(ordered_pre_pops):
# for pre_pop_ind, pre_pop in enumerate(plot_bar_hist_dict['exc'].keys()):for mech_ind,mech_type inenumerate(plot_bar_hist_dict.keys()):
if divide_plots: plt.subplot(1, 2, mech_ind+1)
if mech_type == 'exc':
c = 'royalblue'if plot_bar_hist_dict['inh'][pre_pop]['mean'] == 0and plot_bar_hist_dict['inh'][pre_pop]['std'] == 0: displace=0else: displace=-0.2elif mech_type == 'inh':
c = 'r'if plot_bar_hist_dict['exc'][pre_pop]['mean'] == 0and plot_bar_hist_dict['exc'][pre_pop]['std'] == 0: displace=0else: displace=0.2else:
c = 'k'# --- skip empty barsif plot_bar_hist_dict[mech_type][pre_pop]['mean'] == 0and plot_bar_hist_dict[mech_type][pre_pop]['std'] == 0: continue
x_data = pre_pop_ind
if divide_plots: x_name=pre_pop
else: x_name=pre_pop+'_'+mech_type
plt.barh( x_data+displace,
plot_bar_hist_dict[mech_type][pre_pop]['mean'],
0.4,
xerr=plot_bar_hist_dict[mech_type][pre_pop]['std'],
capsize=3,
color=c,
label=x_name)
if mech_type == 'exc':
sum_of_means_exc+=plot_bar_hist_dict[mech_type][pre_pop]['mean']
list_of_stds_exc.append(plot_bar_hist_dict[mech_type][pre_pop]['std'])
else:
sum_of_means_inh+=plot_bar_hist_dict[mech_type][pre_pop]['mean']
list_of_stds_inh.append(plot_bar_hist_dict[mech_type][pre_pop]['std'])
# --- Square of the stds
std_exc_sqr=[std**2for std in list_of_stds_exc]
std_inh_sqr=[std**2for std in list_of_stds_inh]
sum_std_exc=math.sqrt(sum(std_exc_sqr))
sum_std_inh=math.sqrt(sum(std_inh_sqr))
# --- plot of sum of bars - exc
plt.barh( len(ordered_pre_pops)+1-0.2,
# len(plot_bar_hist_dict[mech_type].keys())+1-0.2,
sum_of_means_exc,
0.4,
xerr=sum_std_exc,
capsize=3,
color='royalblue',)
# --- plot of sum of bars - inh
plt.barh( len(ordered_pre_pops)+1+0.2,
# len(plot_bar_hist_dict[mech_type].keys())+1+0.2,
sum_of_means_inh,
0.4,
xerr=sum_std_inh,
capsize=3,
color='r',)
# --- Format figure
plot_labels = ordered_pre_pops+['','Sum']
# plot_labels = list(plot_bar_hist_dict['exc'].keys())+['','Sum']
plt.yticks(ticks=list(range(len(plot_labels))),labels=plot_labels)
plt.gca().invert_yaxis()
if ax_lim isnotNone: plt.xlim((ax_lim))
print('bar plot method')
defboxPlot(pop_spike_hist_dict, ordered_pre_pops, valid_post_cell_gids, max_time_index, ax_lim=None, select_pops=None):
mech_types = ['exc','inh']
mech_colors = ['royalblue','r']
# --- Boxplot - Figure
fig = plt.figure(figsize=(5,5))
# fig.suptitle('Spike Histogram Window Sum Boxplot: '+pop+' histogram | '+str(max_time)+' max time | '+network_state+' | '+timeRange_name+' ms')
boxplot_data=[]
boxplot_labels=[]
boxplot_colors=[]
# --- Calculating Mean and Std across postsynaptic cells
boxplot_hist_dict={} # --- Dictionary with data for Boxplotfor mech_type in mech_types:
boxplot_hist_dict.update({mech_type:{}})
for pre_cell_pop in ordered_pre_pops:
boxplot_hist_dict[mech_type].update({pre_cell_pop:[]})
post_cell_hists_windowSum=[]
for post_cell_gid in valid_post_cell_gids:
if pre_cell_pop in pop_spike_hist_dict[post_cell_gid].keys():
# --- Selects a time window from the histogram and adds it together to calculate the total weight
post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][0:max_time_index+1]
# post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][1:max_time_index+1]
post_cell_hists_windowSum.append(sum(post_cell_hist_windowSum))
# --- Dictionary with data for Boxplot
boxplot_hist_dict[mech_type][pre_cell_pop]=post_cell_hists_windowSum
# --- Boxplot Datafor mech_ind,mech_type inenumerate(boxplot_hist_dict.keys()):
for pre_pop in boxplot_hist_dict[mech_type].keys():
if select_pops isnotNone:
if ((mech_type == mech_types[mech_ind]) and (pre_pop in select_pops[mech_ind])):
boxplot_data.append(boxplot_hist_dict[mech_type][pre_pop])
boxplot_labels.append(pre_pop)
boxplot_colors.append(mech_colors[mech_ind])
else: continue# --- quits plotting if pre_pop not in the plot listelse:
boxplot_data.append(boxplot_hist_dict[mech_type][pre_pop])
boxplot_labels.append(pre_pop)
boxplot_colors.append(mech_colors[mech_ind])
ax = fig.add_subplot(111)
bp = ax.boxplot(boxplot_data,patch_artist = True,notch ='True')
bp_colors = [bp_color for bp_color in boxplot_colors]
for patch, bp_color inzip(bp['boxes'], bp_colors):
patch.set_facecolor(bp_color)
# Changing color and linewidth of whiskersfor whisker in bp['whiskers']:
whisker.set(color ='grey',linewidth = 1.5,linestyle =":")
# changing color and linewidth of mediansfor median in bp['medians']:
median.set(color ='k',linewidth = 1)
# changing style of fliersfor flier in bp['fliers']:
flier.set(marker ='.',color ='k',alpha = 0.5)
boxplot_ticks = [tick+1for tick inrange(len(boxplot_labels))]
ax.set_xticks(ticks=boxplot_ticks,rotation=45,labels=boxplot_labels)
if ax_lim isnotNone: ax.set_ylim(ax_lim)
defviolinPlot(pop_spike_hist_dict, ordered_pre_pops, valid_post_cell_gids, max_time_index, ax_lim=None, select_pops=None):
mech_types = ['exc','inh']
mech_colors = ['royalblue','r']
# --- Violin - Figure
fig = plt.figure(figsize=(5,5))
# fig.suptitle('Spike Histogram Window Sum Violin: '+pop+' histogram | '+str(max_time)+' max time | '+network_state+' | '+timeRange_name+' ms')
violinplot_data=[]; violinplot_labels=[]; violinplot_colors=[]
# --- Calculating Mean and Std across postsynaptic cells
violinplot_hist_dict={} # --- Dictionary with data for Violinfor mech_type in mech_types:
violinplot_hist_dict.update({mech_type:{}})
for pre_cell_pop in ordered_pre_pops:
violinplot_hist_dict[mech_type].update({pre_cell_pop:[]})
post_cell_hists_windowSum=[]
for post_cell_gid in valid_post_cell_gids:
if pre_cell_pop in pop_spike_hist_dict[post_cell_gid].keys():
# --- Selects a time window from the histogram and adds it together to calculate the total weight
post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][0:max_time_index+1]
# post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][1:max_time_index+1]
post_cell_hists_windowSum.append(sum(post_cell_hist_windowSum))
# --- Dictionary with data for Violin
violinplot_hist_dict[mech_type][pre_cell_pop]=post_cell_hists_windowSum
# --- Violinfor mech_ind,mech_type inenumerate(violinplot_hist_dict.keys()):
for pre_pop in violinplot_hist_dict[mech_type].keys():
if select_pops isnotNone:
if ((mech_type == mech_types[mech_ind]) and (pre_pop in select_pops[mech_ind])):
violinplot_data.append(violinplot_hist_dict[mech_type][pre_pop])
violinplot_labels.append(pre_pop)
violinplot_colors.append(mech_colors[mech_ind])
else: continue# --- quits plotting if pre_pop not in the plot listelse:
violinplot_data.append(violinplot_hist_dict[mech_type][pre_pop])
violinplot_labels.append(pre_pop)
violinplot_colors.append(mech_colors[mech_ind])
ax = fig.add_subplot(111)
bp = ax.violinplot(violinplot_data)
violinplot_ticks = [tick+1for tick inrange(len(violinplot_labels))]
ax.set_xticks(ticks=violinplot_ticks,rotation=45,labels=violinplot_labels)
if ax_lim isnotNone: ax.set_ylim(ax_lim)
defscatterPlot(pop_spk_dict_post, timeRange, all_post_cell_gids, cell_tags_dict, select_max_val = None, select_colormap = None, use_x_position=True):
# --- Scatter - Figure
fig = plt.figure(figsize=(5,5))
if select_colormap isnotNone: c_map = select_colormap
else: c_map = 'jet'if select_max_val isnotNone: reference_value = select_max_val+1else:
spk_count = [len(pop_spk_dict_post[key]) for key in pop_spk_dict_post.keys()]
reference_value = max(spk_count)
scatter_map_colors=ColorMap.colormap(list(range(reference_value)),c_map=c_map)
for post_cell_ind, post_cell_gid inenumerate(all_post_cell_gids):
cell_position_x = cell_tags_dict[post_cell_gid]['xnorm']
cell_position_y = cell_tags_dict[post_cell_gid]['ynorm']
if post_cell_gid notin (pop_spk_dict_post.keys()): cell_firing=0else:
valid_spikes=[]
for spkt in pop_spk_dict_post[post_cell_gid]:
if spkt>=timeRange[0] and spkt<timeRange[1]: valid_spikes.append(spkt)
spk_num=len(valid_spikes)
cell_firing = spk_num
# --- Different colors for cells with no spikesif cell_firing>0:
marker_color = scatter_map_colors[cell_firing]
edge_color = None
alpha = 1else:
marker_color = 'w'
edge_color = 'k'
alpha = 0.25# --- Choose which data to showif use_x_position:
data_1 = cell_position_x
data_2 = cell_position_y
else:
data_1 = cell_firing
data_2 = cell_position_y
plt.plot(data_1,data_2,marker='o',color=marker_color,markeredgecolor=edge_color,alpha=alpha)
plt.gca().invert_yaxis()
# --- Method to normalize a datasetdefNormalizeData(data):
return (data - np.min(data)) / (np.max(data) - np.min(data))
defformatMultivariateData(pop_spk_dict_post,pop_spike_hist_dict,featuredPops,valid_post_cell_gids,timeRange,max_time_index,cell_tags_dict,target_data = 'spk',isolate_mech=None,c_map='jet'):
# --- PCA TUTORIAL: # https://github.com/mGalarnyk/Python_Tutorials/blob/master/Sklearn/PCA/PCA_Data_Visualization_Iris_Dataset_Blog.ipynb# --- Pandas Dataframe tutorials:# https://www.includehelp.com/python/dataframe-is-it-pass-by-value-or-pass-by-reference.aspx# https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.from_dict.html
mech_types = ['exc','inh']
# --- Allocating keys for the Multivariate analysis
Multivariate_hist_dict={}
# --- Loops through the list of cells with valid spikesfor post_cell_gid in valid_post_cell_gids:
Multivariate_hist_dict.update({post_cell_gid:{}})
for mech_type in mech_types:
Multivariate_hist_dict[post_cell_gid].update({mech_type:{}})
for pre_cell_pop in featuredPops:
Multivariate_hist_dict[post_cell_gid][mech_type].update({pre_cell_pop:{}})
print(Multivariate_hist_dict[post_cell_gid][mech_type].keys())
# --- Calculating Mean and Std across postsynaptic cellsfor post_cell_gid in Multivariate_hist_dict.keys():
for mech_type in Multivariate_hist_dict[post_cell_gid].keys():
for pre_cell_pop in Multivariate_hist_dict[post_cell_gid][mech_type].keys():
# --- Checks if pre_cell_pop is part of that cell's dictionary (SOM5A and PV5A projections are not present in all cells)if pre_cell_pop in pop_spike_hist_dict[post_cell_gid].keys():
# --- Selects a time window from the histogram and adds it together to calculate the total weight
Multivariate_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][0:max_time_index+1]
# Multivariate_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][1:max_time_index+1]# else:# print(pre_cell_pop, ' not in ', pop_spike_hist_dict[post_cell_gid].keys())
Multivariate_hist_dict[post_cell_gid][mech_type].update({pre_cell_pop:sum(Multivariate_hist_windowSum)})
# --- Generate Colormapif target_data == 'spk': # Based on number of spikes# --- Calculating the max number of spikes to generate colormap
spk_nums=[]
for post_cell_gid in Multivariate_hist_dict.keys():
valid_spikes=[]
for spkt in pop_spk_dict_post[post_cell_gid]:
if spkt>=timeRange[0] and spkt<timeRange[1]: valid_spikes.append(spkt)
spk_nums.append(len(valid_spikes))
max_spk_num=max(spk_nums)
pca_map_colors=ColorMap.colormap(list(range(max_spk_num+1)),c_map=c_map)
# --- List with Colormap boundaries
colormap_boudaries=[0,max_spk_num]
elif target_data == 'ynorm': # Based on cell position in the y-axis
ynorm_list=[]
for post_cell_gid in Multivariate_hist_dict.keys():
ynorm_list.append(cell_tags_dict[post_cell_gid]['ynorm'])
ynorm_list.sort()
steprange = range(len(ynorm_list))
color_map = plt.get_cmap(c_map)
# --- List with Colormap boundaries
colormap_boudaries=[min(ynorm_list),max(ynorm_list)]
pca_map_reference={}
for j,k inenumerate(ynorm_list):
colormap_val=color_map(1.*j/float(len(steprange)))
pca_map_reference.update({k:colormap_val})
pca_map_colors={}
for post_cell_gid in Multivariate_hist_dict.keys():
pca_map_colors.update({post_cell_gid:pca_map_reference[cell_tags_dict[post_cell_gid]['ynorm']]})
elif target_data[0] == 'true_spk':
# --- Calculating the max number of spikes to generate colormap
spk_nums=[]
for post_cell_gid in Multivariate_hist_dict.keys():
valid_spikes=[]
for spkt in target_data[1][post_cell_gid]:
if spkt>=timeRange[0] and spkt<timeRange[1]: valid_spikes.append(spkt)
spk_nums.append(len(valid_spikes))
max_spk_num=max(spk_nums)
pca_map_colors=ColorMap.colormap(list(range(max_spk_num+1)),c_map=c_map)
# --- List with Colormap boundaries
colormap_boudaries=[0,max_spk_num]
elif target_data[0] == 'fixed_spk_range':
# --- Generate colormap for fixed spike range
pca_map_colors=ColorMap.colormap(list(range(target_data[2][-1]+1)),c_map=c_map)
# --- List with Colormap boundaries
colormap_boudaries=target_data[2]
elif target_data[0] == 'capped_spk_range':
# --- Generate colormap for fixed spike range
pca_map_colors=ColorMap.colormap(list(range(target_data[2][-1]+1)),c_map=c_map)
# --- List with Colormap boundaries
colormap_boudaries=target_data[2]
if isolate_mech isnotNone: print('plotting only a single mech: ', isolate_mech)
# --- Creating a dictionary to organize the data in the Pandas DataFrame format
DataFrame_dict={}
for post_cell_gid in Multivariate_hist_dict.keys():
DataFrame_dict.update({post_cell_gid:{}})
for mech_type in Multivariate_hist_dict[post_cell_gid].keys():
# --- Isolating a single mech_type in the PCAif isolate_mech isnotNone:
if mech_type == isolate_mech: continuefor pre_cell_pop in Multivariate_hist_dict[post_cell_gid][mech_type].keys():
DataFrame_dict[post_cell_gid].update({pre_cell_pop+'_'+mech_type:Multivariate_hist_dict[post_cell_gid][mech_type][pre_cell_pop]})
for post_cell_ind,post_cell_gid inenumerate(Multivariate_hist_dict.keys()):
if target_data == 'spk':
# --- Adding the number of spikes as a feature
valid_spikes=[]
for spkt in pop_spk_dict_post[post_cell_gid]:
if spkt>=timeRange[0] and spkt<timeRange[1]: valid_spikes.append(spkt)
spk_num=len(valid_spikes)
DataFrame_dict[post_cell_gid].update({'target':pca_map_colors[spk_num]})
elif target_data == 'ynorm':
DataFrame_dict[post_cell_gid].update({'target':pca_map_colors[post_cell_gid]})
# DataFrame_dict[post_cell_gid].update({'target':pca_map_colors[ynorm_norm[post_cell_ind]]})elif target_data[0] == 'true_spk':
# --- Adding the number of spikes as a feature
valid_spikes=[]
for spkt in target_data[1][post_cell_gid]:
if spkt>=timeRange[0] and spkt<timeRange[1]: valid_spikes.append(spkt)
spk_num=len(valid_spikes)
DataFrame_dict[post_cell_gid].update({'target':pca_map_colors[spk_num]})
elif target_data[0] == 'fixed_spk_range':
# --- Adding the number of spikes as a feature
valid_spikes=[]
for spkt in target_data[1][post_cell_gid]:
if spkt>=timeRange[0] and spkt<timeRange[1]: valid_spikes.append(spkt)
spk_num=len(valid_spikes)
DataFrame_dict[post_cell_gid].update({'target':pca_map_colors[spk_num]})
elif target_data[0] == 'capped_spk_range':
# --- Adding the number of spikes as a feature
valid_spikes=[]
for spkt in target_data[1][post_cell_gid]:
if spkt>=timeRange[0] and spkt<timeRange[1]: valid_spikes.append(spkt)
spk_num=len(valid_spikes)
# --- capped data to upper limitif spk_num>target_data[2][-1]: spk_num=target_data[2][-1]
DataFrame_dict[post_cell_gid].update({'target':pca_map_colors[spk_num]})
return DataFrame_dict, colormap_boudaries
defcreateDataFrame(DataFrame_dict):
import pandas as pd
from sklearn.preprocessing import StandardScaler
# --- Creating a Pandas DataFrame
df = pd.DataFrame.from_dict(DataFrame_dict, orient='index')
df_features = list(df.columns)[:-1] # removing last key, because is the 'target', used as classifier variable
df_values = df.loc[:, df_features].values
# --- Creating a new Dataframe for Target values because the GIDs and indexes were being mixed up
df_target_ = df.loc[:,['target']].values
df_target = pd.DataFrame(df_target_)
# --- Data Normalization
df_values_Norm = StandardScaler().fit_transform(df_values)
return df,df_features,df_values,df_target,df_values_Norm
defplotPCA(DataFrame_dict,pop_spk_dict_post,n_components=2):
import pandas as pd
from sklearn.decomposition import PCA
df,df_features,df_values,df_target,df_values_Norm = PlotFigures.createDataFrame(DataFrame_dict)
pca = PCA(n_components=n_components)
principalComponents = pca.fit_transform(df_values_Norm)
principalDf = pd.DataFrame(data = principalComponents, columns = ['PC1', 'PC2'])
finalDf = pd.concat([principalDf, df_target], axis = 1)
fig = plt.figure(figsize = (5,5))
ax = fig.add_subplot(1,1,1)
ax.set_xlabel('Principal Component 1', fontsize = 15)
ax.set_ylabel('Principal Component 2', fontsize = 15)
colors=list(df[['target']].values)
post_cell_gid_list=list(pop_spk_dict_post.keys())
off_PCA_value_gids=[]
for ind,color inenumerate(colors):
ax.scatter(finalDf.loc[ind, 'PC1'], finalDf.loc[ind, 'PC2'], c = color, s = 50)
if finalDf.loc[ind, 'PC1']>5: off_PCA_value_gids.append(post_cell_gid_list[ind])
# --- Removing boxes from plot axis
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
# --- Removing ticks from plot
ax.set_xticks(ticks=[],labels=None)
ax.set_yticks(ticks=[],labels=None)
return pca
defplotUMAP(DataFrame_dict,n_neighbors=30,color_criteria='gid',cellType_dict=None,c_map='jet',colormap_boudaries=None):
import umap
import sklearn
df,df_features,df_values,df_target,df_values_Norm = PlotFigures.createDataFrame(DataFrame_dict)
reducer = umap.UMAP(random_state=1, n_neighbors=n_neighbors, min_dist=0.0, n_components=2) # 15, 0.1
embedding = reducer.fit_transform(df_values_Norm)
embedding.shape
if color_criteria == 'gid': colors=ColorMap.colormap(list(df.index),c_map=c_map)
elif color_criteria == 'cellType':
colors={}
for gid inlist(df.index):
if gid in cellType_dict['enhanced']: c='seagreen'elif gid in cellType_dict['suppressed']:c='gold'else: c='k'
colors.update({gid:c})
# --- The 'target' values here are already passed as a colormapelif (color_criteria == 'ynorm') or (color_criteria == 'spk') or (color_criteria[0] == 'true_spk') or (color_criteria[0] == 'fixed_spk_range') or (color_criteria[0] == 'capped_spk_range'):
colors_list=list(df[['target']].values)
colors={}
for ind,gid inenumerate(list(df.index)):
colors.update({gid:colors_list[ind][0]})
if colormap_boudaries isnotNone:
# c = np.arange(1, len(colors_list.keys())+1)
norm = mpl.colors.Normalize(vmin=min(colormap_boudaries), vmax=max(colormap_boudaries))
colorbar_cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.jet)
colorbar_cmap.set_array([])
fig, ax = plt.subplots(figsize=(5,5))
# key is the GID of the cellsfor ind,key inenumerate(colors.keys()):
marker='.'
ax.plot( embedding[ind, 0], embedding[ind, 1], color=colors[key], marker=marker, markersize=3,)
ax.set_aspect('equal', 'datalim')
# --- Removing boxes from plot axis
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
# --- Removing ticks from plot
ax.set_xticks(ticks=[],labels=None)
ax.set_yticks(ticks=[],labels=None)
if colormap_boudaries isnotNone: fig.colorbar(colorbar_cmap, ticks=colormap_boudaries)
return embedding, df
defapplyKMeans(dataset,n_clusters = 4):
# apply k-means to UMAPfrom sklearn import cluster
kmeans_dataset = cluster.KMeans(n_clusters=n_clusters).fit(dataset)
return kmeans_dataset
defplotKMeans(kmeans_dataset, dataset, dataframe,customLabels=None, customColors=None):
cluster_dictionary={}
dataFrame_gids=list(dataframe.index)
if customLabels isnotNone: kMeans_labels=customLabels
else: kMeans_labels=kmeans_dataset.labels_
kMeans_list=list(kMeans_labels)
kMeans_groups=list(set(kMeans_list))
if customColors isnotNone: cluster_colors_list=customColors
else:
cluster_colors=ColorMap.colormap(list(range(kmeans_dataset.n_clusters)),c_map='jet')
cluster_colors_list=list(cluster_colors.values())
fig, ax = plt.subplots(figsize=(5,5))
for i, c inzip(range(kmeans_dataset.n_clusters), cluster_colors_list):
ax.scatter(dataset[kMeans_labels==i, 0], dataset[kMeans_labels==i, 1], color=c, marker='.',s=5,label='Cluster '+str(i+1),)
ax.set_aspect('equal', 'datalim')
lgnd=ax.legend(loc='lower left',framealpha=0)
for i,lg inenumerate(lgnd.legendHandles):
# change the marker size manually for both lines
lgnd.legendHandles[i]._sizes = [500]
for i inlist(kMeans_groups):
cluster_dictionary.update({i:[]})
for ind, i inenumerate(kMeans_list):
cluster_dictionary[i].append(dataFrame_gids[ind])
# --- Removing boxes from plot axis
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
# --- Removing ticks from plot
ax.set_xticks(ticks=[],labels=None)
ax.set_yticks(ticks=[],labels=None)
return cluster_dictionary
defbarplotKMeans2(cluster_dictionary, pop_spike_hist_dict, featuredPops, max_time_index, divide_plots=False):
mech_types=['exc','inh']
print('barplotKMeans')
# --- Allocating keys for the plotting the bar graph
plot_bar_kMeans={}
for kMeans_group in cluster_dictionary.keys():
plot_bar_kMeans.update({kMeans_group:{}})
for mech_type in mech_types:
plot_bar_kMeans[kMeans_group].update({mech_type:{}})
for pre_cell_pop in featuredPops:
plot_bar_kMeans[kMeans_group][mech_type].update({pre_cell_pop:{}})
for kMeans_group in cluster_dictionary.keys():
# --- Calculating Mean and Std across postsynaptic cellsfor mech_type in mech_types:
for pre_cell_pop in featuredPops:
post_cell_hists_windowSum=[]
for post_cell_gid in cluster_dictionary[kMeans_group]:
# --- Checks if pre_cell_pop is part of that cell's dictionary (SOM5A and PV5A projections are not present in all cells)if pre_cell_pop in pop_spike_hist_dict[post_cell_gid].keys():
# --- Selects a time window from the histogram and adds it together to calculate the total weight
post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][0:max_time_index+1]
# post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][1:max_time_index+1]
post_cell_hists_windowSum.append(sum(post_cell_hist_windowSum))
post_pop_hist_windowSum_mean = np.mean(post_cell_hists_windowSum)
post_pop_hist_windowSum_std = np.std(post_cell_hists_windowSum)
plot_bar_kMeans[kMeans_group][mech_type][pre_cell_pop].update({'mean':post_pop_hist_windowSum_mean,'std':post_pop_hist_windowSum_std})
# --- Barplot - One fig per kMeans cluster
fig = plt.figure(figsize=(15,10))
for kMeans_ind,kMeans_group inenumerate(plot_bar_kMeans.keys()):
# --- Figure - Bar plot v3
ax = fig.add_subplot(1, len(plot_bar_kMeans.keys()), kMeans_ind+1)
# plt.subplot(1, len(plot_bar_kMeans.keys()), kMeans_ind+1)for mech_ind,mech_type inenumerate(plot_bar_kMeans[kMeans_group].keys()):
# if divide_plots: plt.subplot(1, 2, mech_ind+1)
sum_of_means=0if mech_type == 'exc':
c = 'royalblue'; displace=-0.2elif mech_type == 'inh':
c = 'r'; displace=+0.2else:
c = 'k'for pre_pop_ind, pre_pop inenumerate(plot_bar_kMeans[kMeans_group][mech_type].keys()):
x_data = pre_pop_ind
if divide_plots: x_name=pre_pop
else: x_name=pre_pop+'_'+mech_type
if mech_ind==1: flipper = 1else: flipper = 1# --- skip empty barsif plot_bar_kMeans[kMeans_group][mech_type][pre_pop]['mean'] == 0and plot_bar_kMeans[kMeans_group][mech_type][pre_pop]['std'] == 0: continue# --- Fix 'nan' values from mean and std calculationif math.isnan(plot_bar_kMeans[kMeans_group][mech_type][pre_pop]['mean']): plot_bar_kMeans[kMeans_group][mech_type][pre_pop]['mean'] = 0if math.isnan(plot_bar_kMeans[kMeans_group][mech_type][pre_pop]['std']): plot_bar_kMeans[kMeans_group][mech_type][pre_pop]['std'] = 0
ax.barh( x_data+displace,
flipper*plot_bar_kMeans[kMeans_group][mech_type][pre_pop]['mean'],
0.4,
xerr=plot_bar_kMeans[kMeans_group][mech_type][pre_pop]['std'],
capsize=3,
color=c,
label=x_name)
# --- Calculates the sum of bars
sum_of_means+=plot_bar_kMeans[kMeans_group][mech_type][pre_pop]['mean']
# --- plot of sum of bars
ax.barh( len(plot_bar_kMeans[kMeans_group][mech_type].keys())+1+displace,
flipper*sum_of_means,
0.4,
capsize=3,
color=c,)
# --- Format figure
plot_labels = list(plot_bar_kMeans[kMeans_group]['exc'].keys())+['','Sum']
if kMeans_ind==0: ax.set_yticks(ticks=list(range(len(plot_labels))),labels=plot_labels)
else: ax.set_yticks(ticks=[],labels=None)
ax.invert_yaxis()
defboxplotKMeans(cluster_dictionary, pop_spike_hist_dict, featuredPops, max_time_index, divide_plots=False, select_pops=None):
mech_types = ['exc','inh']
mech_colors = ['royalblue','r']
print('boxplotKMeans')
# --- Dictionary with data for Boxplot
boxplot_kMeans={}
for kMeans_group in cluster_dictionary.keys():
boxplot_kMeans.update({kMeans_group:{}})
# --- Calculating Mean and Std across postsynaptic cellsfor mech_type in mech_types:
boxplot_kMeans[kMeans_group].update({mech_type:{}})
for pre_cell_pop in featuredPops:
boxplot_kMeans[kMeans_group][mech_type].update({pre_cell_pop:[]})
post_cell_hists_windowSum=[]
for post_cell_gid in cluster_dictionary[kMeans_group]:
if pre_cell_pop in pop_spike_hist_dict[post_cell_gid].keys():
# --- Selects a time window from the histogram and adds it together to calculate the total weight
post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][0:max_time_index+1]
# post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][1:max_time_index+1]
post_cell_hists_windowSum.append(sum(post_cell_hist_windowSum))
# --- Dictionary with data for Boxplot
boxplot_kMeans[kMeans_group][mech_type][pre_cell_pop]=post_cell_hists_windowSum
# --- Boxplot - Figure
fig = plt.figure(figsize=(15,10))
# fig.suptitle('Spike Histogram Window Sum Boxplot - K-Means: '+pop+' histogram | '+str(max_time)+' max time | '+network_state+' | '+timeRange_name+' ms')for kMeans_ind,kMeans_group inenumerate(boxplot_kMeans.keys()):
# --- Plot style# plt.subplot(2, 2, kMeans_ind+1)# ax = fig.add_subplot(2, 2, kMeans_ind+1)
ax = fig.add_subplot(1, len(boxplot_kMeans.keys()), kMeans_ind+1)
boxplot_data=[]; boxplot_labels=[]; boxplot_colors=[]
for mech_ind,mech_type inenumerate(boxplot_kMeans[kMeans_group].keys()):
for pre_pop in boxplot_kMeans[kMeans_group][mech_type].keys():
# --- Selecting populations to plot in the figureif select_pops isnotNone:
if ((mech_type == mech_types[mech_ind]) and (pre_pop in select_pops[mech_ind])):
boxplot_data.append(boxplot_kMeans[kMeans_group][mech_type][pre_pop])
boxplot_labels.append(pre_pop)
boxplot_colors.append(mech_colors[mech_ind])
else: continue# --- quits plotting if pre_pop not in the plot listelse:
boxplot_data.append(boxplot_kMeans[kMeans_group][mech_type][pre_pop])
boxplot_labels.append(pre_pop)
boxplot_colors.append(mech_colors[mech_ind])
# --- Vertical Plot Orientation
vert=False
bp = ax.boxplot(boxplot_data,patch_artist = True,notch ='True',vert=vert)
bp_colors = [bp_color for bp_color in boxplot_colors]
for patch, bp_color inzip(bp['boxes'], bp_colors):
patch.set_facecolor(bp_color)
# changing color and linewidth of whiskersfor whisker in bp['whiskers']:
whisker.set(color ='grey', linewidth = 1.5, linestyle =":")
# changing color and linewidth of mediansfor median in bp['medians']:
median.set(color ='k', linewidth = 1)
# changing style of fliersfor flier in bp['fliers']:
flier.set(marker ='.', color ='k', alpha = 0.5)
boxplot_ticks = [tick+1for tick inrange(len(boxplot_labels))]
ax.set_title("Num of cells: %s"%(len(boxplot_data[0])))
if vert: ax.set_xticks(ticks=boxplot_ticks,rotation=45,labels=boxplot_labels)
else: ax.set_yticks(ticks=boxplot_ticks,rotation=45,labels=boxplot_labels)
return boxplot_kMeans
# --- Class to process and plot data merging QUIET and MOVEMENT statesclassPostAnalysis():
defmergedBarPlot( spk_hist_A,spk_hist_B,
ordered_pre_pops,
post_GIDs_A,post_GIDs_B,
max_time_index,
long_range_pops,
select_pre_pops=None,
ax_lim=None,
verticalPlot=True,
perSecond=False,
timeScaling=None,
states = ['Quiet','Movement'],
state_colors=['royalblue','crimson'],
remove_long_range_inh=True,
showLabels=True,
export_values=True,
export_filename='export_dict.json',
plotStatistics=False):
print('Merged bar plot')
import scipy
# A and B are two different states # (e.g.: # Quiet vs Move; # Cluster 0 vs Cluster 1; # etc# )if select_pre_pops==None:
select_pre_pops=ordered_pre_pops
sumLeftover=Falseelse: sumLeftover=True# --- Adding missing pops if selected but not present in the original list of pops (will appear as a empty bar)for s_pop in select_pre_pops:
if s_pop notin ordered_pre_pops:
print('adding ', s_pop, ' to ', ordered_pre_pops)
ordered_pre_pops.append(s_pop)
mech_types = ['exc','inh']
# --- Allocating keys for the plotting the bar graph
plot_bar_hist_dict={}
for mech_type in mech_types:
plot_bar_hist_dict.update({mech_type:{}})
for pre_cell_pop in ordered_pre_pops:
plot_bar_hist_dict[mech_type].update({pre_cell_pop:{}})
for state in states:
plot_bar_hist_dict[mech_type][pre_cell_pop].update({state:{'mean':0,'std':0}})
if'PT5B'notin plot_bar_hist_dict['exc'].keys():
print('NO PT5B', plot_bar_hist_dict['exc'].keys())
plot_bar_hist_dict['exc'].update({'PT5B':{states[0]:{'mean':0,'std':0},states[1]:{'mean':0,'std':0}}})
print('---->', plot_bar_hist_dict['exc'].keys())
# --- Allocating keys for the exporting valuesif export_values:
export_dict={}
for state in states:
export_dict.update({state:{}})
for mech_type in mech_types:
export_dict[state].update({mech_type:{}})
for pre_cell_pop in ordered_pre_pops:
export_dict[state][mech_type].update({pre_cell_pop:{'mean':0,'std':0,'vals':[]}})
# --- Calculating Mean and Std across postsynaptic cellsfor mech_type in mech_types:
for pre_cell_pop in ordered_pre_pops:
for state in states:
if state == states[0]:
valid_post_cell_gids = post_GIDs_A
pop_spike_hist_dict = spk_hist_A
elif state == states[1]:
valid_post_cell_gids = post_GIDs_B
pop_spike_hist_dict = spk_hist_B
post_cell_hists_windowSum=[]
for post_cell_gid in valid_post_cell_gids:
if pre_cell_pop in pop_spike_hist_dict[post_cell_gid].keys():
# --- Selects a time window from the histogram and adds it together to calculate the total weight
post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][0:max_time_index+1]
# post_cell_hist_windowSum = pop_spike_hist_dict[post_cell_gid][pre_cell_pop][mech_type][1:max_time_index+1]
post_cell_hists_windowSum.append(sum(post_cell_hist_windowSum))
if perSecond:
post_pop_hist_windowSum_mean = np.mean(post_cell_hists_windowSum)/timeScaling
post_pop_hist_windowSum_std = np.std(post_cell_hists_windowSum)/timeScaling
post_pop_hist_windowSum_vals = [x/timeScaling for x in post_cell_hists_windowSum]
else:
post_pop_hist_windowSum_mean = np.mean(post_cell_hists_windowSum)
post_pop_hist_windowSum_std = np.std(post_cell_hists_windowSum)
post_pop_hist_windowSum_vals = post_cell_hists_windowSum
plot_bar_hist_dict[mech_type][pre_cell_pop][state].update({'mean':post_pop_hist_windowSum_mean,'std':post_pop_hist_windowSum_std,'vals':post_pop_hist_windowSum_vals})
# --- removing 'nan' values:for mech_type in mech_types:
for pre_cell_pop in ordered_pre_pops:
for state in states:
if math.isnan(plot_bar_hist_dict[mech_type][pre_cell_pop][state]['mean']): plot_bar_hist_dict[mech_type][pre_cell_pop][state]['mean']=0if math.isnan(plot_bar_hist_dict[mech_type][pre_cell_pop][state]['std']): plot_bar_hist_dict[mech_type][pre_cell_pop][state]['std']=0# --- Draw Figure - Bar plot v4if verticalPlot: fig, ax = plt.subplots(figsize=(10,5))
else: fig, ax = plt.subplots(figsize=(6,14))
if ax_lim isnotNone: ax.set_xlim((ax_lim))
# --- Plot style
divide_plots=False
sum_of_means_exc_Q=0; sum_of_means_exc_M=0; sum_of_means_inh_Q=0; sum_of_means_inh_M=0
list_of_stds_exc_Q=[]; list_of_stds_exc_M=[]; list_of_stds_inh_Q=[]; list_of_stds_inh_M=[]
leftover_sum_of_means_exc_Q=0; leftover_sum_of_means_exc_M=0; leftover_sum_of_means_inh_Q=0; leftover_sum_of_means_inh_M=0
leftover_list_of_stds_exc_Q=[]; leftover_list_of_stds_exc_M=[]; leftover_list_of_stds_inh_Q=[]; leftover_list_of_stds_inh_M=[]
errbar_color='k'; errbar_alpha=0.5
bar_widht_thick = 0.36
bar_widht_slim = 0.23for pre_pop_ind, pre_pop inenumerate(select_pre_pops):
if pre_pop in long_range_pops:
if remove_long_range_inh: bar_width = bar_widht_thick
else: bar_width = bar_widht_slim
else: bar_width = bar_widht_thick
shift_distance = bar_width/2
capsize=1.5for mech_ind,mech_type inenumerate(plot_bar_hist_dict.keys()):
if mech_type == 'exc': c = state_colors[0]
elif mech_type == 'inh': c = state_colors[1]
else: c = 'k'for state_ind,state inenumerate(states):
# --- Plots the pops in <select_pre_pops>if state == states[0]: alpha = 0.5else: alpha = 1if divide_plots: ax.subplot(1, 2, mech_ind+1)
# --- Creates 4 slots for long range popsif pre_pop in long_range_pops:
if remove_long_range_inh:
if state == states[0]: displace=-shift_distance*1elif state == states[1]: displace=shift_distance*1else:
if mech_type == 'exc'and state == states[0]: displace=-shift_distance*3elif mech_type == 'exc'and state == states[1]: displace=-shift_distance*1elif mech_type == 'inh'and state == states[0]: displace=shift_distance*1elif mech_type == 'inh'and state == states[1]: displace=shift_distance*3# --- Creates 2 slots for long range popselse:
if state == states[0]: displace=-shift_distance*1elif state == states[1]: displace=shift_distance*1# --- skip empty barsif plot_bar_hist_dict[mech_type][pre_pop][state]['mean'] == 0and plot_bar_hist_dict[mech_type][pre_pop][state]['std'] == 0: continueelif math.isnan(plot_bar_hist_dict[mech_type][pre_pop][state]['mean']) or math.isnan(plot_bar_hist_dict[mech_type][pre_pop][state]['std']): continue# else:# print('plotting pop ', pre_pop, state)# --- Store data for SUM bar plotif mech_type == 'exc'and state == states[0]:
sum_of_means_exc_Q+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
list_of_stds_exc_Q.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
elif mech_type == 'exc'and state == states[1]:
sum_of_means_exc_M+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
list_of_stds_exc_M.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
elif mech_type == 'inh'and state == states[0]:
sum_of_means_inh_Q+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
list_of_stds_inh_Q.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
elif mech_type == 'inh'and state == states[1]:
sum_of_means_inh_M+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
list_of_stds_inh_M.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
# --- skip plotting long range inh if remove_long_range_inh:
if (pre_pop in long_range_pops) and (mech_type=='inh'): continue# --- Plot data for SUM bar plot
x_data = pre_pop_ind
if divide_plots: x_name=pre_pop
else: x_name=pre_pop+'_'+mech_type
if verticalPlot:
ax.bar( x_data+displace,
plot_bar_hist_dict[mech_type][pre_pop][state]['mean'],
bar_width,
color=c,
# label=x_name,
alpha=alpha,
label=None,)
ax.errorbar(x_data+displace,
plot_bar_hist_dict[mech_type][pre_pop][state]['mean'],
yerr=plot_bar_hist_dict[mech_type][pre_pop][state]['std'],
lolims=True,
uplims=False,
capsize=capsize,
color=errbar_color,
alpha=errbar_alpha,
label=None,)
else:
ax.barh( x_data+displace,
plot_bar_hist_dict[mech_type][pre_pop][state]['mean'],
bar_width,
color=c,
# label=x_name,
alpha=alpha,
label=None,)
ax.errorbar(plot_bar_hist_dict[mech_type][pre_pop][state]['mean'],
x_data+displace,
xerr=plot_bar_hist_dict[mech_type][pre_pop][state]['std'],
xlolims=True,
xuplims=False,
capsize=capsize,
color=errbar_color,
alpha=errbar_alpha,
label=None,)
if export_values:
export_dict[state][mech_type][pre_pop]['mean']=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
export_dict[state][mech_type][pre_pop]['std']=plot_bar_hist_dict[mech_type][pre_pop][state]['std']
export_dict[state][mech_type][pre_pop]['vals']=plot_bar_hist_dict[mech_type][pre_pop][state]['vals']
if plotStatistics:
Data1 = plot_bar_hist_dict[mech_type][pre_pop][states[0]]['vals']
Data2 = plot_bar_hist_dict[mech_type][pre_pop][states[1]]['vals']
# --- skip empty barsiflen(Data1) <= 0orlen(Data2) <= 0: continueif remove_long_range_inh:
if mech_type=='inh'and pre_pop in long_range_pops: continue
marker_position = max([plot_bar_hist_dict[mech_type][pre_pop][states[0]]['mean']+plot_bar_hist_dict[mech_type][pre_pop][states[0]]['std'],
plot_bar_hist_dict[mech_type][pre_pop][states[1]]['mean']+plot_bar_hist_dict[mech_type][pre_pop][states[1]]['std']])+200
stats = scipy.stats.mannwhitneyu(Data1, Data2)
if stats.pvalue>0.05: stat_symbol = ''elif stats.pvalue<=0.05and stats.pvalue>0.01: stat_symbol = '*'elif stats.pvalue<=0.01and stats.pvalue>0.001: stat_symbol = '**'elif stats.pvalue<=0.001: stat_symbol = '***'print('plotting ',stats.pvalue,'\t',stat_symbol,'\t',pre_pop_ind,marker_position)
ax.text(marker_position,pre_pop_ind+0.18,s=stat_symbol,size=10,rotation='vertical')
if sumLeftover:
# --- position of the sum bars in the plots
sum_bar_position = len(select_pre_pops)+3
leftover_bar_position = len(select_pre_pops)+1
plot_labels = select_pre_pops+['','Others']+['','Sum']
leftover_list=list(set(ordered_pre_pops).difference(select_pre_pops))
for pre_pop_ind, pre_pop inenumerate(leftover_list):
for mech_ind,mech_type inenumerate(plot_bar_hist_dict.keys()):
for state_ind,state inenumerate(states):
# --- Store data for leftover SUM bar plotif mech_type == 'exc'and state == states[0]:
sum_of_means_exc_Q+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
list_of_stds_exc_Q.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
leftover_sum_of_means_exc_Q+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
leftover_list_of_stds_exc_Q.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
elif mech_type == 'exc'and state == states[1]:
sum_of_means_exc_M+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
list_of_stds_exc_M.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
leftover_sum_of_means_exc_M+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
leftover_list_of_stds_exc_M.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
elif mech_type == 'inh'and state == states[0]:
sum_of_means_inh_Q+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
list_of_stds_inh_Q.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
leftover_sum_of_means_inh_Q+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
leftover_list_of_stds_inh_Q.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
elif mech_type == 'inh'and state == states[1]:
sum_of_means_inh_M+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
list_of_stds_inh_M.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
leftover_sum_of_means_inh_M+=plot_bar_hist_dict[mech_type][pre_pop][state]['mean']
leftover_list_of_stds_inh_M.append(plot_bar_hist_dict[mech_type][pre_pop][state]['std'])
# --- Square of the stds - leftover
leftover_std_exc_sqr_Q=[std**2for std in leftover_list_of_stds_exc_Q]
leftover_std_exc_sqr_M=[std**2for std in leftover_list_of_stds_exc_M]
leftover_std_inh_sqr_Q=[std**2for std in leftover_list_of_stds_inh_Q]
leftover_std_inh_sqr_M=[std**2for std in leftover_list_of_stds_inh_M]
leftover_sum_std_exc_Q=math.sqrt(sum(leftover_std_exc_sqr_Q))
leftover_sum_std_exc_M=math.sqrt(sum(leftover_std_exc_sqr_M))
leftover_sum_std_inh_Q=math.sqrt(sum(leftover_std_inh_sqr_Q))
leftover_sum_std_inh_M=math.sqrt(sum(leftover_std_inh_sqr_M))
positions =[(leftover_bar_position-shift_distance*3),(leftover_bar_position-shift_distance*1),(leftover_bar_position+shift_distance*1),(leftover_bar_position+shift_distance*3)]
sum_means =[leftover_sum_of_means_exc_Q,leftover_sum_of_means_exc_M,leftover_sum_of_means_inh_Q,leftover_sum_of_means_inh_M]
xerrs =[leftover_sum_std_exc_Q,leftover_sum_std_exc_M,leftover_sum_std_inh_Q,leftover_sum_std_inh_M]
colors =[state_colors[0],state_colors[0],state_colors[1],state_colors[1]]
alphas =[0.5,1,0.5,1]
for ind,sum_mean inenumerate(sum_means):
if verticalPlot:
ax.bar( positions[ind],
sum_means[ind],
bar_width,
color=colors[ind],
alpha=alphas[ind],
label=None,)
ax.errorbar(positions[ind],
sum_means[ind],
yerr=xerrs[ind],
lolims=True,
uplims=False,
capsize=capsize,
color=errbar_color,
alpha=errbar_alpha,
label=None,)
else:
ax.barh( positions[ind],
sum_means[ind],
bar_width,
color=colors[ind],
alpha=alphas[ind],
label=None,)
ax.errorbar(sum_means[ind],
positions[ind],
xerr=xerrs[ind],
xlolims=True,
xuplims=False,
capsize=capsize,
color=errbar_color,
alpha=errbar_alpha,
label=None,)
if export_values:
export_dict.update({'leftover':{}})
ind=0for mech_type in mech_types:
for state in states:
export_dict['leftover'].update({mech_type+state:{'mean':sum_means[ind],'std':xerrs[ind]}})
ind+=1else:
sum_bar_position= len(select_pre_pops)+1
plot_labels = select_pre_pops+['','Sum']
# --- Square of the stds
std_exc_sqr_Q=[std**2for std in list_of_stds_exc_Q]
std_exc_sqr_M=[std**2for std in list_of_stds_exc_M]
std_inh_sqr_Q=[std**2for std in list_of_stds_inh_Q]
std_inh_sqr_M=[std**2for std in list_of_stds_inh_M]
sum_std_exc_Q=math.sqrt(sum(std_exc_sqr_Q))
sum_std_exc_M=math.sqrt(sum(std_exc_sqr_M))
sum_std_inh_Q=math.sqrt(sum(std_inh_sqr_Q))
sum_std_inh_M=math.sqrt(sum(std_inh_sqr_M))
# --- parameters to draw the sum bars
sum_bar_width = bar_widht_thick
sum_shift_distance = sum_bar_width/2
sum_capsize = 1.5
positions =[(sum_bar_position-sum_shift_distance*3),(sum_bar_position-sum_shift_distance*1),(sum_bar_position+sum_shift_distance*1),(sum_bar_position+sum_shift_distance*3)]
sum_means =[sum_of_means_exc_Q,sum_of_means_exc_M,sum_of_means_inh_Q,sum_of_means_inh_M]
xerrs =[sum_std_exc_Q,sum_std_exc_M,sum_std_inh_Q,sum_std_inh_M]
colors =[state_colors[0],state_colors[0],state_colors[1],state_colors[1]]
alphas =[0.5,1,0.5,1]
if showLabels: labels =['Exc '+states[0],'Exc '+states[1],'Inh '+states[0],'Inh '+states[1]]
else: labels =['','','','']
for ind,sum_mean inenumerate(sum_means):
if verticalPlot:
ax.bar( positions[ind],
sum_means[ind],
sum_bar_width,
# bar_width,
color=colors[ind],
alpha=alphas[ind],
label=labels[ind],)
ax.errorbar(positions[ind],
sum_means[ind],
yerr=xerrs[ind],
lolims=True,
uplims=False,
capsize=sum_capsize,
# capsize=capsize,
color=errbar_color,
alpha=errbar_alpha,
label=None,)
else:
ax.barh( positions[ind],
sum_means[ind],
sum_bar_width,
# bar_width,
color=colors[ind],
alpha=alphas[ind],
label=labels[ind],)
ax.errorbar(sum_means[ind],
positions[ind],
xerr=xerrs[ind],
xlolims=True,
xuplims=False,
capsize=sum_capsize,
# capsize=capsize,
color=errbar_color,
alpha=errbar_alpha,
label=None,)
if export_values:
export_dict.update({'sum':{}})
ind=0for mech_type in mech_types:
for state in states:
export_dict['sum'].update({mech_type+state:{'mean':sum_means[ind],'std':xerrs[ind]}})
ind+=1# --- Format figureif verticalPlot:
ax.set_xticks(ticks=list(range(len(plot_labels))),labels=plot_labels,rotation=45)
ax.set_xlabel('Presynaptic population')
ax.set_ylabel('Estimated synaptic drive')
else:
ax.set_yticks(ticks=list(range(len(plot_labels))),labels=plot_labels)
# plt.gca().invert_yaxis()
ax.invert_yaxis()
ax.set_xlabel('Estimated synaptic drive')
ax.set_ylabel('Presynaptic population')
# ax.set_xticks([0,1.0,2.0])# --- Removing boxes from plot axis
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
print('bar plot method')
# --- Changing font
plt.rcParams.update({'font.size': 20})
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels()):
item.set_fontsize(20)
if export_values:
# Writing to sample.json
json_object = json.dumps(export_dict, indent=4)
withopen(export_filename, "w") as outfile: outfile.write(json_object)
classBarPlot():
# source: https://stackoverflow.com/questions/14270391/python-matplotlib-multiple-barsdefbar_plot(ax, data, colors=None, total_width=0.8, single_width=1, legend=True, title=None):
"""Draws a bar plot with multiple bars per data point.
Parameters
----------
ax : matplotlib.pyplot.axis
The axis we want to draw our plot on.
data: dictionary
A dictionary containing the data we want to plot. Keys are the names of the
data, the items is a list of the values.
Example:
data = {
"x":[1,2,3],
"y":[1,2,3],
"z":[1,2,3],
}
colors : array-like, optional
A list of colors which are used for the bars. If None, the colors
will be the standard matplotlib color cyle. (default: None)
total_width : float, optional, default: 0.8
The width of a bar group. 0.8 means that 80% of the x-axis is covered
by bars and 20% will be spaces between the bars.
single_width: float, optional, default: 1
The relative width of a single bar within a group. 1 means the bars
will touch eachother within a group, values less than 1 will make
these bars thinner.
legend: bool, optional, default: True
If this is set to true, a legend will be added to the axis.
"""# Check if colors where provided, otherwhise use the default color cycleif colors isNone: colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# Number of bars per group
n_bars = len(data)
# The width of a single bar
bar_width = total_width / n_bars
# List containing handles for the drawn bars, used for the legend
bars = []
# Iterate over all datafor i, (name, values) inenumerate(data.items()):
x_offset = (i - n_bars / 2) * bar_width + bar_width / 2# The offset in x direction of that bar# Draw a bar for every value of that typefor x, y inenumerate(values):
bar = ax.bar(x + x_offset, y, width=bar_width * single_width, color=colors[i % len(colors)])
bars.append(bar[0]) # Add a handle to the last drawn bar, which we'll need for the legend# Draw legend if we needif legend: ax.legend(bars, data.keys())
# Draw legend if we needif title: ax.set_title(title)