import sys, os
from bmtk.simulator import bionet
import numpy as np
import pandas as pd
import h5py
from neuron import h
from scipy.stats import skew
import synapses
from bmtk.simulator.bionet.pyfunction_cache import add_weight_function
#import load_processor
#load_processor.load()
np.random.seed(42)
synapses.load()
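# When launched under mpiexec, BioNet distributes cells across ranks; the
# rank/size below let per-rank code (e.g. file output) coordinate.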
pc = h.ParallelContext() # object to access MPI methods
MPI_size = int(pc.nhost())
MPI_rank = int(pc.id())
# if __name__ == '__main__':
# if __file__ != sys.argv[-1]:
# inp = sys.argv[-1]
# else:
# raise Exception("no work" + str(sys.argv[-1]))
# fname = str(inp)
config_file = 'simulation_config.json'
def gaussianBL(edge_props, source, target):
    # Draw a synaptic weight from a Gaussian centered on syn_weight.
    w0 = edge_props["syn_weight"]
    sigma = edge_props["weight_sigma"]
    weight = np.random.normal(w0, sigma)  # scalar draw, not a length-1 array
    try:
        return min(edge_props["weight_max"], weight)
    except KeyError:  # no cap defined for this edge type
        return weight
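# For the lognormal rule below, the edge files specify the desired arithmetic
# mean (syn_weight = m) and std (weight_sigma = s) of the weights. The
# underlying normal's parameters follow from the lognormal moment formulas:
#   mu = ln(m) - 0.5 * ln(1 + (s/m)**2),  sigma = sqrt(ln(1 + (s/m)**2))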
def lognormal(edge_props, source, target):
    # Draw a positive weight whose arithmetic mean and std match
    # syn_weight and weight_sigma (see the mu/sigma conversion above).
    m = edge_props["syn_weight"]
    s = edge_props["weight_sigma"]
    mean = np.log(m) - 0.5 * np.log((s / m) ** 2 + 1)
    std = np.sqrt(np.log((s / m) ** 2 + 1))
    weight = np.random.lognormal(mean, std)  # scalar draw; always positive
    try:
        return min(edge_props["weight_max"], weight)
    except KeyError:  # no cap defined for this edge type
        return weight
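# Quick sanity check of the parameterization (uncomment to run); with the
# illustrative values below, the sample mean/std should approach m and s:
# m, s = 2.0, 0.5
# mu = np.log(m) - 0.5 * np.log((s / m) ** 2 + 1)
# sg = np.sqrt(np.log((s / m) ** 2 + 1))
# print(np.random.lognormal(mu, sg, 100000).mean())  # ~2.0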
add_weight_function(lognormal)
add_weight_function(gaussianBL)
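# Each function is registered under its __name__, so edge files can select
# it through their 'weight_function' field.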
conf = bionet.Config.from_json(config_file, validate=True)
conf.build_env()
graph = bionet.BioNetwork.from_config(conf)
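# The graph holds only this rank's share of the network; cell and synapse
# objects are instantiated from the files listed in simulation_config.json.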
sim = bionet.BioSimulator.from_config(conf, network=graph)
# from analyze_area import analyze_area
# analyze_area(graph.get_local_cells()[0]._morph.seg_prop)
# cell = graph.get_local_cells()[1]
# memb = h.Vector()
# memb.record(cell.hobj.soma[0](0.5)._ref_v)
cells = graph.get_local_cells()
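# get_local_cells() returns a {gid: cell} dict containing only the cells
# built on this MPI rank.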
#gid_min = min(cells.keys())
# exc_strengths = {}
# inh_strengths = {}
# comb_strens = {}
# for gid, cell in cells.items():
# exc_strens = []
# inh_strens = []
# fr_comb_total = 0
# for con in cell._connections:
# #if len(np.array(con.source_node._train_vec)) > 0:
# #import pdb; pdb.set_trace()
# if con._edge_prop.source_population == 'exc_stim':
# fr_comb_total += len(np.array(con.source_node._train_vec)) * con._syn.initW
# #if len(np.array(con.source_node._train_vec)) > 0:
# #exc_strens.append(con.syn_weight)
# exc_strens.append(con._syn.initW)
# elif con._edge_prop.source_population == 'inh_stim':
# #inh_strens.append(con.syn_weight)
# inh_strens.append(con._syn.initW)
# else:
# raise Exception("Source pop is: " + str(con._edge_prop.source_population))
# #print(gid, ":", fr_comb_total / len(cell._connections))
# comb_strens[gid] = fr_comb_total / len(exc_strens)
# #exc_strengths[gid - gid_min] = exc_strens
# #inh_strengths[gid - gid_min] = inh_strens
# exc_strengths[gid] = exc_strens
# inh_strengths[gid] = inh_strens
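# Collect per-synapse info (source id, weight, distance to soma, basal vs.
# apical target) for the cell with gid 0. NOTE: cells is keyed by gid, so
# this assumes gid 0 lives on this rank (true for a single-process run).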
conns = cells[0]._connections
# inh_conns = [con for con in conns if con.source_node._population == "inh_stim"]
# exc_conns = [con for con in conns if con.source_node._population == "exc_stim"]
ids = []
weights = []
distance = []
is_basal = []
soma = cells[0].hobj.soma[0](0.5)
for con in conns:
    if con.source_node._population == "exc_stim":
        ids.append(con.source_node.node_id)
        weights.append(con._syn.initW)
        postseg = con._connector.postseg()
        # Path distance (microns) from the soma to the postsynaptic segment.
        distance.append(float(h.distance(soma, postseg)))
        # Classify by section name instead of parsing str(postseg).
        sec_name = postseg.sec.name()
        if "dend" in sec_name:
            is_basal.append(True)
        elif "apic" in sec_name:
            is_basal.append(False)
        else:
            raise Exception(sec_name + " should be dend or apic.")
exc_syns = pd.DataFrame({
    "node_id": ids,
    "weight": weights,
    "is_basal": is_basal,
    "distance": distance,
})
exc_syns.to_csv("exc_syn_info.csv", index=False)
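# If this script is ever run on multiple ranks, the block above should be
# guarded (e.g. `if 0 in cells:`) so only the rank owning gid 0 executes it.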
sim.run()
# pc.barrier()
# raster_file = './output/spikes.h5'
# frs = {}
# local_gids = list(exc_strengths.keys())
# #local_gids = local_gids - np.min(local_gids)
# for key in local_gids:
# frs[key] = 0
# try:
# f = h5py.File(raster_file,'r')
# spike_keys = list(f['spikes'].keys())
# if len(spike_keys) > 1:
# raise Exception("Spike keys: " + str(spike_keys))
# spike_key = list(f['spikes'].keys())[0]
# timestamps = f['spikes'][spike_key]['timestamps'][()]
# gids = f['spikes'][spike_key]['node_ids'][()]
# for i in range(len(gids)):
# gid = gids[i] + min(local_gids)
# if gid in local_gids and timestamps[i] >= 200:
# frs[gid] += 1
# except (OSError, KeyError):
# print("No spikes.")
# df = pd.DataFrame()
# dicts = [{"gid": gid, "FR": frs[gid] / 5, "num_exc": len(exc_strengths[gid]), "num_inh": len(inh_strengths[gid]),
# "avg_exc": np.mean(exc_strengths[gid]), "avg_inh": np.mean(inh_strengths[gid]),
# "max_exc": np.max(exc_strengths[gid]), "max_inh": np.max(inh_strengths[gid]),
# "std_exc": np.std(exc_strengths[gid]), "std_inh": np.std(inh_strengths[gid]),
# "skew_exc": skew(exc_strengths[gid]), "skew_inh": skew(inh_strengths[gid]), "comb_stren": comb_strens[gid]} for gid in local_gids]
# df = pd.DataFrame(dicts)
# #df.set_index("gid")
# df.to_csv(fname+str(MPI_rank)+'.csv', index=False)
# pc.barrier()
# if MPI_rank == 0:
# base_df = pd.read_csv(fname+"0.csv", index_col="gid")
# res_df = pd.concat([base_df] + [pd.read_csv(fname+str(rank)+".csv", index_col="gid") for rank in range(1, MPI_size)])
# frs_df = pd.read_csv('frs_temp.csv', index_col="gid")
# res_df = res_df.join(frs_df)
# os.remove('frs_temp.csv')
# [os.remove(fname+str(rank)+".csv") for rank in range(MPI_size)]
# res_df.to_csv(fname+".csv")