# fig_kc_removal.py ---
# Author: Subhasis Ray
# Created: Fri Aug 16 11:27:37 2019 (-0400)
# Last-Updated: Tue Aug 20 13:23:15 2019 (-0400)
# By: Subhasis Ray
# Version: $Id$
# Code:
"""This summarizes the results from series of KC removals.
In these simulations I first removed KCs that generated 6 or more spikes from the model and simulated repeatedly.
Once there were no more KCs to remove, I started removing KCs that
spiked 4 or more spikes.
Finally, I removed KCs that spiked at all.
"""
import os
import numpy as np
import h5py as h5
import pandas as pd
import network_data_analysis as nda
from matplotlib import pyplot as plt
DATA_DIR = 'D:/biowulf_stage/fixed_net/'
TEMPLATE_DIR = 'D:/biowulf_stage/fixed_net_templates/'
SIX_SPIKES = [
('22072442', '31370038', '31390354', '31397146', '31519150', '31689528', '31735952', '31766700', '31817932', '+31864090+'),
('22087964', '31370044', '31390198', '31401062', '31414919', '31519151', '31689529', '31730432', '31759138', '31793240', '+31832556+'),
('22087965', '31370045', '31392169', '31404140', '31419791', '31519152', '31689530', '31741849', '31791436', '31825166', '+31901939+'),
('22087966', '31370046', '31396924', '31412598', '31519153', '31689531', '31739668', '31763425', '+31818159+'),
('22087967', '31370047', '31397971', '31412573', '31427275', '31519154', '31689533', '31743014', '31763577', '31835183', '+31871621+'),
('22087969', '31370049', '31393197', '31404071', '31519155', '31689536', '31741660', '+31767463+'),
('22087970', '31370072', '31392820', '31401873', '31519157', '31689537', '31738490', '+31766626+'),
('22087971', '31370183', '31396623', '31412233', '31519159', '31689540', '31737800', '+31764868+'),
('22087972', '31370186', '31396423', '31412439', '31519161', '31689543', '31729998', '31761039', '+31805984+'),
('22087973', '31370242', '31389652', '31519162', '31689544', '31739760', '+31760348+'),]
THREE_SPIKES = [
('31817932', '33088890', '33130828', '33150821', '33169363', '33196902', '33230558', '+33248002+'),
('31793240', '33088894', '33138100', '33152203', '33201249', '33238084', '+33250972+'),
('31825166', '33088901', '33124461', '33149315', '33166758', '33183865', '+33212362+'),
('31763425', '33088907', '33131130', '33150972', '33177255', '33205609', '33234799', '33249051', '+33252123+' ),
('31835183', '33088915', '33127211', '33157592', '33172116', '33192104', '+33235021+'),
('31741660', '33088919', '33147305', '33164595', '33205556', '+33251787+'),
('31738490', '33088926', '33138263', '33155039', '33177259', '+33204717+' ),
('31737800', '33088930', '33139122', '33153551', '33169010', '33187885', '+33223605+' ),
('31761039', '33088933', '33124370', '33148095', '33163998', '33179470', '33217120', '+33239144+' ),
('31739760', '33088939', '33131194', '33151255', '+33167279+'),]
ALL_SPIKES = [
('33230558','33357781','33413275','33422096','33434464','33447914','33465825','+33477877+'),
('33238084','33357783','33393194','33415310','33426780','33437911','33447538','33458282','+33465056+'),
('33183865','33357784','33397857','33420873','33431566','33445268','33455574','+33462623+'),
('33249051','33357785','33399813','33425820','33436914','33445885','33455651','33462572','+33466633+'),
('33192104','33357786','33385823','33412221','33421554','33443595','33452792','+33462456+'),
('33205556','33357787','33383346','33407989','33421229','33436076','33444470','+33454014+'),
('33177259','33357788','33387409','33409757','33422380','33433787','33447922','+33462482+'),
('33187885','33357789','33378257','33400048','33414034','33422392','+33443173+'),
('33217120','33357790','33395217','33418472','33432306','33443443','33448552','33461063','33465508','+33485730+'),
('33151255','33357791','33380109','33414869','33431578','33441682','33448950','+33460146+'),
]
def make_psth_and_vm(ax_psth, ax_vm, ax_kc_hist):
binwidth=100
datalist = (SIX_SPIKES[0][1:-1], THREE_SPIKES[0][1:-1], ALL_SPIKES[0][1:-1])
colors = ['#e66101', '#5e3c99', '#009292']
ls = ['-', ':']
for ii, group in enumerate(datalist):
print(group)
for jj, jid in enumerate((group[0], group[-1])):
print(jid)
try:
fname = nda.find_h5_file(jid, DATA_DIR)
except:
# First entry is old data and moved to back up
# Template should still have all the data
print(f'JID {jid} not in datadir. Looking in template dir')
fname = nda.find_h5_file(jid, TEMPLATE_DIR)
with h5.File(fname, 'r') as fd:
kc_st, kc_id = nda.get_event_times(fd[nda.kc_st_path])
kc_sc = np.array([len(st) for st in kc_st])
try:
ax_kc_hist.hist(kc_sc, bins=np.arange(1, max(kc_sc)+0.5, 1),
color=colors[ii], ls=ls[jj],
label=f'{ii}: {jj}: {jid}',
histtype='step', linewidth=1)
except IndexError:
print(jid, ':', kc_sc, '|')
pop_st = np.concatenate(kc_st)
try:
ax_psth.hist(pop_st,
bins=np.arange(500, 2100, binwidth),
color=colors[ii], ls=ls[jj], histtype='step', label=jid)
except IndexError:
print(jid, pop_st)
ggn_vm, t = nda.get_ggn_vm(fd, 'basal')
ax_vm.plot(t, ggn_vm[0, :], label=jid, color=colors[ii], ls=ls[jj])
ax_psth.legend()
ax_vm.legend()
def plot_spike_counts(ax, fname=None):
if fname is None:
for ii in range(len(SIX_SPIKES)):
spike_counts = []
for jj, sim_set in enumerate([SIX_SPIKES, THREE_SPIKES, ALL_SPIKES]):
sim_list = sim_set[ii][1:-1]
for kk, jid in enumerate(sim_list):
try:
fname = nda.find_h5_file(jid, DATA_DIR)
except:
# First entry is old data and moved to back up
# Template should still have all the data
print(f'JID {jid} not in datadir. Looking in template dir')
fname = nda.find_h5_file(jid, TEMPLATE_DIR)
with h5.File(fname, 'r') as fd:
kc_st, kc_id = nda.get_event_times(fd[nda.kc_st_path])
kc_sc = sum([len(st) for st in kc_st])
print('JID:', jid, 'total spikes:', kc_sc)
# For all cases, plot the result from the last successful KC removal, and first
spike_counts.append(kc_sc)
ax.plot(len(spike_counts)-1, spike_counts[-1], 'k|')
ax.plot(spike_counts, 'o-', fillstyle='none')
else:
total_spike_count = pd.read_csv(fname, sep=',')
for ii, (series_id, simgrp) in enumerate(total_spike_count.groupby('series')):
ax.plot(simgrp['total_spikes'].values, 'o-', fillstyle='none')
series_df = simgrp.reset_index()
for removal, remgrp in series_df.groupby('removal'):
print('#', removal)
print(remgrp)
# jj += len(remgrp)
ax.plot(remgrp.index.values[-1], remgrp['total_spikes'].values[-1], 'k.')
def main():
plt.close('all')
fig = plt.figure()
ax_kc_psth = fig.add_subplot(221)
ax_total_spike_count = fig.add_subplot(222)
ax_ggn_vm = fig.add_subplot(223, sharex=ax_kc_psth)
ax_spike_count_hist = fig.add_subplot(224)
make_psth_and_vm(ax_kc_psth, ax_ggn_vm, ax_spike_count_hist)
if os.path.exists('total_spike_counts.csv'):
plot_spike_counts(ax_total_spike_count, 'total_spike_counts.csv')
ax_total_spike_count.set_xticks([0, 5, 10, 15, 20])
ax_ggn_vm.set_xlim(500, 2200)
ax_ggn_vm.set_ylim(-51, -40)
for ax in fig.axes:
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
fig.set_size_inches(175/25.0, 140/25.0)
fig.savefig('Figure_kc_removal.svg')
plt.show()
#
#ax = plt.subplot(111)
#plot_spike_counts(ax, fname='D:/subhasis_ggn/model/analysis/total_spike_counts.csv')
#
if __name__ == '__main__':
main()
#
# fig_kc_removal.py ends here