"""
Tool used for plotting various things throughout the project.
Helpful for navigating the BMTK h5 structuring.
"""
import matplotlib.pyplot as plt
import numpy as np
import h5py
import scipy.signal as s  # currently only used by the commented-out functions below


def get_key(group, index=0):
"""From the list of the keys of the given group, returns the key at
the given index.
Parameters
----------
group : h5py.Group
the h5py group to get the key from
index : int, optional
index of the key in the list of the group's keys, by default 0
Returns
-------
str
desired key for the h5py group
"""
return list(group.keys())[index]
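

# A minimal usage sketch for get_key (the "spikes.h5" file name and group
# layout are assumptions based on typical BMTK output, not guaranteed here):
#
#   with h5py.File("spikes.h5", "r") as f:
#       top = get_key(f)       # e.g. "spikes"
#       pop = get_key(f[top])  # e.g. the population name
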
def load_dataset(fname, groups=2):
"""Gets a dataset within the given h5 file.
Many BMTK h5 files have one dataset within some layers of group,
and this is a useful function for getting to that dataset.
Assumes that each group just has one key.
Parameters
----------
fname : str
h5 file to load
groups : int, optional
number of groups before the dataset, by default 2
Returns
-------
    h5py.Dataset or h5py.Group
        the desired dataset, or the group that contains it
"""
    # The file handle is left open on purpose so that the returned
    # dataset/group remains readable by the caller.
    f = h5py.File(fname, 'r')
    for i in range(groups):
        f = f[get_key(f)]
    return f
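

# Sketch of the h5 layouts load_dataset is typically pointed at (paths are
# assumptions based on common BMTK/SONATA output and may differ per simulation):
#
#   spikes.h5           /spikes/<population>/{node_ids, timestamps}  -> groups=2
#   v_report.h5         /report/<population>/data                    -> groups=2
#   se_clamp_report.h5  one group above the dataset                  -> groups=1
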
def plot_spikes(file, show=False, id_scale=-1, id_shift=0, time_scale=10):
    """Plots a spike raster from the given BMTK spike file.
    Parameters
    ----------
    file : str
        location of the h5py file
    show : bool, optional
        whether to call plt.show() at the end, by default False
    id_scale : int, optional
        if positive, node ids are rescaled so the largest id equals
        id_scale, by default -1 (no rescaling)
    id_shift : int, optional
        amount added to each node id before plotting, by default 0
    time_scale : int, optional
        factor the timestamps are multiplied by, by default 10
    """
    data = load_dataset(file)
    scale = 1
    if id_scale > 0:
        scale = id_scale / np.max(data['node_ids'])
    plt.plot(np.array(data['timestamps']) * time_scale,
             np.array(data['node_ids']) * scale + id_shift, '.')
    if show:
        plt.show()
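

# Example overlaying two spike files on one raster (file names are
# hypothetical):
#
#   plot_spikes("output/spikes.h5")
#   plot_spikes("output/inh_spikes.h5", id_shift=100, show=True)
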
def plot_v(file, show=False, ms=False):
"""Plots the membrane potential from the given BMTK v_report.h5 file.
Parameters
----------
file : str
location of the h5py file
show : bool, optional
whether to call plt.show() at the end, by default False
    ms : bool, optional
        whether to divide x by 10 to get a ms scale (assumes a 0.1 ms
        report timestep), by default False
"""
    data = load_dataset(file)
    # Read the length from the dataset's shape without loading it into memory.
    x = np.arange(data['data'].shape[0])
    if ms:
        x = x / 10
    plt.plot(x, data['data'][:, 0])
    if show:
        plt.show()


def plot_all_v(file, ms=False):
"""Plots each membrane potential in the given BMTK v_report.h5 file.
Parameters
----------
file : str
location of the h5py file
    ms : bool, optional
        whether to divide x by 10 to get a ms scale (assumes a 0.1 ms
        report timestep), by default False
"""
    data = load_dataset(file)
    x = np.arange(data['data'].shape[0])
    if ms:
        x = x / 10
    # Plot every node's trace on the same axes.
    for i in range(data['data'].shape[1]):
        plt.plot(x, data['data'][:, i])
    plt.show()
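

# plot_v draws only the first node's trace, while plot_all_v overlays every
# node. Example (hypothetical path):
#
#   plot_v("output/v_report.h5", ms=True, show=True)
#   plot_all_v("output/v_report.h5", ms=True)
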
def plot_se(file, show=False):
"""Used to plot se_clamp_reports from BMTK.
Parameters
----------
file : str
location of the h5py file
show : bool, optional
whether to call plt.show() at the end, by default False
"""
data = load_dataset(file, groups=1)
plt.plot(data[:, 0])
    if show:
        plt.show()
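

# Example for se_clamp reports (hypothetical path):
#
#   plot_se("output/se_clamp_report.h5", show=True)

# NOTE: the functions below are commented out; they depend on a
# generate_spike_gamma() helper that is not defined in this module.
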
# def generate_spike_probs(inh_file, spike_file, time):
# gamma = generate_spike_gamma(inh_file, time)
# data = load_dataset(spike_file)
# timestamps = np.array(data['timestamps'])
# troughs = s.find_peaks(-gamma)[0]
# n_parts = 10
# parts = np.zeros(n_parts)
# for i in range(len(troughs) - 1):
# start = troughs[i]
# part_len = (troughs[i+1] - start)/n_parts
# for j in range(n_parts):
# parts[j] += len(np.where((timestamps >= j*part_len + start) & (timestamps < (j+1)*part_len + start))[0])
# parts = np.array(parts) / parts.sum()
# #t1 = gamma[troughs[100]:troughs[101]]
# t1 = gamma[troughs[0]:troughs[1]]
# plt.plot(parts, label="spike probability")
# plt.plot(np.arange(len(t1)) * (n_parts/len(t1)), t1/10, label="gamma ex.")
# plt.legend()
# plt.show()
# return parts
# def generate_prob_raster(inh_file, spike_file, time):
# gamma = generate_spike_gamma(inh_file, time)
# data = load_dataset(spike_file)
# node_ids = np.array(data['node_ids'])
# timestamps = np.array(data['timestamps'])
# troughs = s.find_peaks(-gamma)[0]
# new_ts = np.zeros(len(timestamps))
# #ids = np.arange(len(timestamps))
# cycle_num = np.zeros(len(new_ts))
# for i in range(len(troughs) - 1):
# start = troughs[i]
# stop = troughs[i + 1]
# length = stop - start
# spikes = np.where((timestamps >= start) & (timestamps < stop))[0]
# cycle_num[spikes] = i
# times = timestamps[spikes]
# times = times - start
# times = times / length
# new_ts[spikes] = times
# #part_len = (troughs[i+1] - start)/n_parts
# # for j in range(n_parts):
# # parts[j] += len(np.where((timestamps >= j*part_len + start) & (timestamps < (j+1)*part_len + start))[0])
# parts = np.zeros(10)
# sep = 0.1
# for i in range(10):
# parts[i] = len(np.where((new_ts >= (i * sep)) & (new_ts < ((i+1)*sep)))[0])
# #parts = np.array(parts) / parts.sum()
# #t1 = gamma[troughs[100]:troughs[101]]
# t1 = gamma[troughs[0]:troughs[1]]
# #plt.plot(parts, label="spike probability")
# plt.plot(np.arange(10)+0.5, (parts / parts.sum()), color="black", label="spike probability")
# plt.plot(new_ts*(10), cycle_num/max(cycle_num), ".")
# plt.xticks([0, 5, 10], labels = ["-" + r'$\pi$', 0, r'$\pi$'])
# plt.axvline(x=5, ls="--", color = "black")
# #plt.plot(np.arange(len(t1)) * (len(t1)/len(t1)), t1/3, label="gamma ex.")
# #plt.plot(t1/3, label="gamma ex.")
# plt.legend()
# ax = plt.gca()
# #ax.axes.xaxis.set_visible(False)
# ax.axes.yaxis.set_visible(False)
# plt.show()
# #return parts
# def plot_spike_gamma(file, time):
# gamma = generate_spike_gamma(file, time)
# # troughs = s.find_peaks(-smooth)[0]
# # parts = np.zeros(10)
# # for i in range(len(troughs) - 1):
# # part_len = (troughs[i+1] - troughs[i])/10
# # for j in range(10):
# # parts[j] += len(np.where((timestamps >= j*part_len) & (timestamps < (j+1)*part_len))[0])
#     plt.plot(np.arange(time) * 10, gamma)