import json

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.signal as ss
import scipy.stats

from raster_maker import SonataWriter
# Load the network parameter file; its ['time']['stop'] entry sets the
# simulation length used for the modulation trace and firing rates.
with open('L5NetParams.json') as param_file:
    params = json.load(param_file)
# Choose a random fraction (p_delta) of the exc_stim nodes that project onto
# dendritic sections; only these nodes get their spike trains re-generated
# with a sinusoidally modulated rate. No RNG seed is set here, so the
# selection differs between runs.
df = pd.read_csv('Connections.csv')
p_delta = 0.1  # fraction of dendritic exc_stim nodes to replace with modulated spikes
# Unique exc_stim node IDs having at least one connection whose Name contains
# 'dend'. Computed once and reused for both the count and the draw;
# na=False treats missing names as non-dendritic instead of raising.
dend_stim_ids = df.loc[
    (df['Source Population'] == 'exc_stim')
    & df['Name'].str.contains('dend', na=False),
    'Node ID',
].unique()
num_cells = int(p_delta * len(dend_stim_ids))
cells_to_change = np.random.choice(dend_stim_ids, num_cells, replace=False)
# Sinusoidal rate-modulation trace sampled on a fixed grid over the simulation.
# Its values span [1 - depth_of_mod, 1 + depth_of_mod] and average to 1 over
# whole periods, so scaling a constant rate by it preserves the mean rate.
depth_of_mod = 1                # modulation depth; 1 lets the rate dip to 0
freq = 64                       # oscillation frequency (Hz, since t is in seconds)
phase = 0                       # phase offset in radians (e.g. 2*np.pi/3)
tsim = params['time']['stop']   # seconds
dt = .001                       # grid step: 1 ms bins when tsim is in seconds
t = np.arange(0, tsim, dt)
omega = 2 * np.pi * freq        # angular frequency
mod_trace = depth_of_mod * (np.sin(omega * t - phase) + 1) + (1 - depth_of_mod)
# Split the existing exc_stim spikes into those of the selected cells
# (cells_to_change, to be re-generated) and all others (kept as-is).
with h5py.File('exc_stim_spikes.h5', 'r') as f:
    spikes_grp = f['spikes']['exc_stim']
    # Read each dataset once into memory; the file is closed on exit.
    all_timestamps = spikes_grp['timestamps'][:]
    all_nodeids = spikes_grp['node_ids'][:]
mask = np.isin(all_nodeids, cells_to_change)
anti_mask = ~mask
old_timestamps = all_timestamps[anti_mask]
old_nodeids = all_nodeids[anti_mask]
# Mean firing rate of each selected cell = spike count / tsim. Built from a
# dict so node ids keep their integer dtype instead of being coerced to float.
# The rate column keeps the name 'timestamps' because the spike-regeneration
# step reads it under that name.
fr_df = (
    pd.DataFrame({'timestamps': all_timestamps[mask], 'node_ids': all_nodeids[mask]})
    .groupby('node_ids')['timestamps']
    .count()
    / tsim
).reset_index()
# Re-draw spikes for each selected cell as an inhomogeneous Poisson process on
# the mod_trace grid: expected events per bin = rate * mod_trace / 1000.
# A bin with one or more events yields a single spike at that bin's index
# (ms when tsim is in seconds), so extra events in the same bin are dropped.
ts = []
nid = []
for n, fr in zip(fr_df['node_ids'], fr_df['timestamps']):
    numbPoints = scipy.stats.poisson(fr * mod_trace / 1000).rvs()
    spike_bins = np.flatnonzero(numbPoints > 0)
    ts.append(spike_bins)
    nid.append(np.repeat(n, spike_bins.size))
# np.concatenate([]) raises ValueError, so handle an empty selection explicitly.
new_timestamps = np.concatenate(ts) if ts else np.array([], dtype=int)
new_nodeids = np.concatenate(nid) if nid else np.array([], dtype=int)
# Keep timestamps as float (the output dataset is declared float); casting to
# int would truncate the fractional part of the untouched old spike times.
timestamps = np.concatenate((old_timestamps, new_timestamps)).astype(float)
node_ids = np.concatenate((old_nodeids, new_nodeids)).astype(int)
# Write the combined spike trains through SonataWriter, one contiguous run of
# spikes per node: the node id repeated once per spike, then its timestamps.
fname = 'exc_stim_spikes2.h5'
# Builtin float/int: the np.float / np.int aliases were removed in NumPy 1.24.
writer = SonataWriter(fname, ["spikes", "exc_stim"], ["timestamps", "node_ids"], [float, int])
try:
    for node in np.unique(node_ids):
        node_spikes = timestamps[node_ids == node]
        writer.append_repeat("node_ids", int(node), len(node_spikes))
        writer.append_ds(node_spikes, "timestamps")
finally:
    # Close the writer even if an append fails.
    writer.close()