from sNMO.optimizer import snmOptimizer
from brian2.units import *
import nevergrad as ng
import network_CADEX as cadex
import numpy as np
from spike_train_utils import build_isi_from_spike_train, plot_patterns
from sNMO.error.spikeTrainErrors import emd_pdist_spk, isi_swasserstein_2d, isi_wasserstein_dd, biemd
from matplotlib.pyplot import savefig
import time as time
import pandas as pd
from joblib import Parallel, delayed, dump, load
import sbi
import torch
import os
import pyabf
# Searchable variant of the excitatory-input connectivity: number of input
# cells (Nexc) and connection probability (Exc_p). Currently unused by
# var_dict below, which is wired to default_input_conn instead.
pexc_input_conn = ng.p.Dict(Nexc=ng.p.TransitionChoice(np.array([10, 30, 50,100, 150, 200, 300, 400, 500]).astype(int)),
    Exc_p = ng.p.Log(lower=0.001, upper=1))
# Fixed excitatory-input connectivity; Exc_p=None makes parse_params set
# network_args['ECRH_params'] to None.
default_input_conn = ng.p.Dict(Nexc=500, Exc_p=None) #
# Candidate connection probabilities offered to the p_ei / p_ie choices.
p_conn = np.array([0.01, 0.02, 0.04, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5])
# Parameter space handed to the nevergrad optimizer. Plain Python values
# (numbers, strings) are constants; ng.p.* entries are searched.
# parse_params attaches the units: ee/ii/wcrh/wexti -> nS, de/di/dcrh -> ms,
# input_hz/exti_in -> Hz.
var_dict = ng.p.Dict(ee=ng.p.Scalar(lower=0.02, upper=15),
    ii=ng.p.Scalar(lower=0.02, upper=15),
    wcrh=ng.p.Scalar(lower=0.002, upper=15),
    # NOTE(review): 'der' is not read by parse_params in this file — confirm it is still needed.
    der= 1, #ng.p.Scalar(lower=0.01, upper=1),
    de=ng.p.Scalar(lower=0.1, upper=20),
    di=ng.p.Scalar(lower=0.1, upper=20),
    dcrh=ng.p.Scalar(lower=10, upper=500),
    input_hz=ng.p.Log(lower=10,upper=500),
    exti_in=0,#ng.p.Log(lower=0.001,upper=500),
    wexti=0,#ng.p.Log(lower=0.00002, upper=4),
    ##network params
    p_ei = ng.p.TransitionChoice(p_conn.copy()),
    p_ie = ng.p.TransitionChoice(p_conn.copy()),
    ##connectivity param
    # 'random' is passed through unchanged as network_args['connectivity'].
    connectivity='random',
    exc_input=default_input_conn, #ng.p.Choice([pexc_input_conn, default_input_conn]),
    #ext_to_gaba = ng.p.TransitionChoice([True, False]),
    # i_pr and b_change are only used inside the network-event string that
    # parse_params registers under key 30.
    i_pr = 0.05,
    b_change=ng.p.Scalar(lower=0.75, upper=7),
    )
# Directory that receives the per-evaluation plots, CSV files and joblib dumps.
OUTPUT_FOLDER = "/media/smestern/sgbackup/aoi_paper_2/output/"
# The optimizer budget is workers * rounds evaluations; `workers` is also the
# number of parallel joblib jobs per round.
rounds = 400
workers = 6
# Load the reference data to fit to (read from the working directory at import time).
burst_ref = np.load("unit_3_burst_match1.npy")
non_burst_ref = np.load("unit_3_nonburst_match.npy")
# Scale factor applied to the voltage/current penalty (snm_error).
SNM_SF = 0.05
# NOTE(review): DT is not referenced anywhere in the visible part of this file.
DT = 0.1*ms
# Passed as run_time to cadex.cadex_network (presumably seconds — confirm).
RUN_TIME = 40
# [low_cut, high_cut] windows used to extract ISIs at the start and at the end
# of the run; parse_params registers its network event under key 30, which
# lies inside the end window.
BASELINE_TIMES = [10, 15]
BASELINE_TIMES_END = [25, 40]
# When True, error_func runs a second ("tonic") simulation and the
# non-burst error term is included in the objective.
OPT_TONIC = False
def parse_params(params, all_fast=False, mono_synapse=False, threeway=False, slow_gaba=False, slow_crh=False, feedforward=False):
    """Split one optimizer sample into network and synapse keyword dicts.

    The caller's ``params`` mapping is not modified: the function works on a
    shallow copy, so the same sample can be parsed repeatedly.

    Args:
        params: mapping with the keys of ``var_dict`` (``ee``, ``ii``,
            ``wcrh``, ``de``, ``di``, ``dcrh``, ``input_hz``, ``exti_in``,
            ``wexti``, ``p_ei``, ``p_ie``, ``connectivity``, ``exc_input``,
            ``i_pr``, ``b_change``). ``exc_input`` must itself be a mapping
            providing ``Nexc`` and ``Exc_p``. Optional keys: ``e_pr``,
            ``exti_p``; ``p_ee`` and ``we_e`` are required when ``threeway``
            is set.
        all_fast: merge ``cadex.ALL_FAST_SYNAPSE_MODELS`` into the network args.
        mono_synapse: merge ``cadex.MONO_EXP_SYNAPSE_MODELS``.
        threeway: also forward ``p_ee`` and ``we_e``.
        slow_gaba: merge ``cadex.SLOW_GABA_SYNAPSE_MODELS``.
        slow_crh: only has an effect together with ``slow_gaba``; then the
            ``EI_model`` is reset to the default synapse model.
        feedforward: force ``p_ei`` and ``p_ie`` in the network args to 0.

    Returns:
        ``(network_args, synapse_params)``. ``network_args`` contains a
        ``'network_events'`` entry mapping ``30`` to an ``"eval: ..."``
        command string.
    """
    # Work on a copy so the exc_input keys are not written back into the
    # caller's dict.
    params = dict(params)
    # Bring the exc connectivity params (Nexc, Exc_p) up to the top level.
    params.update(params['exc_input'])

    network_args = {'p_ei': params['p_ei'], 'p_ie': params['p_ie']}
    synapse_params = {
        'taue': params['de']*ms,
        'taui': params['di']*ms,
        'taucrh': params['dcrh']*ms,
        'wcrh': params['wcrh']*nS,
        'we': params['ee']*nS,
        'wi': params['ii']*nS,
        'exte_in': params['input_hz']*Hz,
        'exti_in': params['exti_in']*Hz,
        'wexti': params['wexti']*nS,
        'Nexc': params['Nexc'],
    }

    if threeway:
        network_args.update({'p_ee': params['p_ee']})
        synapse_params.update({'we_e': params['we_e']*nS})

    if params['Exc_p'] is not None:
        network_args['ECRH_params'] = {'p': params['Exc_p']}
    else:
        network_args['ECRH_params'] = None

    # The model overrides are applied in this order, so a later flag wins
    # over an earlier one for any key they share.
    if all_fast:
        network_args.update(cadex.ALL_FAST_SYNAPSE_MODELS)
    if mono_synapse:
        network_args.update(cadex.MONO_EXP_SYNAPSE_MODELS)
    if slow_gaba and not slow_crh:
        network_args.update(cadex.SLOW_GABA_SYNAPSE_MODELS)
        #network_args["ECRH_model"] = cadex.MONO_EXP_SYNAPSE_MODELS["ECRH_model"]
    if slow_gaba and slow_crh:
        network_args.update(cadex.SLOW_GABA_SYNAPSE_MODELS)
        network_args["EI_model"] = cadex.DEFAULT_SYNAPSE_MODELS["EI_model"]
        # NOTE(review): the indentation of this file was lost; these two
        # overrides are assumed to belong to the slow_gaba+slow_crh branch
        # rather than being unconditional — confirm against the original.
        synapse_params['taup'] = 9999*second
        synapse_params['taubr'] = 9999*second

    # Anything other than the two named layouts is treated as a mapping that
    # describes a clustered graph.
    if params['connectivity'] not in ('random', 'threeway'):
        clusters = params['connectivity']['NUM_CLUSTERS']
        network_args['connectivity_params'] = {
            'NUM_CLUSTERS': clusters,
            'inter_cluster_num': int(clusters*params['connectivity']['inter_cluster_num'])+1,
        }
        network_args['connectivity_params'].update({'p_ei': params['p_ei'], 'p_ie': params['p_ie']})
        network_args['connectivity'] = "custom_circ_graph"
    else:
        network_args['connectivity'] = params['connectivity']

    if feedforward:
        # NOTE(review): this zeroes only the top-level probabilities; any
        # 'connectivity_params' built above keeps the original p_ei/p_ie.
        network_args['p_ei'] = 0.0
        network_args['p_ie'] = 0.0

    # Build the command string that is registered as the network event at key 30.
    base_str = "eval: "
    #base_str += f"poisson_exc.rates = poisson_exc.rates/Hz*{params['epsp_rate_switch']}*Hz; "
    base_str += f"P[:{500}].pr = {params['i_pr']}; "
    if 'e_pr' in params:
        base_str += f" P[:{500}].p_e = {params['e_pr']}; "
    if 'exti_p' in params:
        base_str += f"P[:{500}].p_exti = {params['exti_p']}; "
    base_str += f"CRH[:{500}].bd=(CRH.b[:{500}]/nS * {params['b_change']})*nS; CRH[:{500}].tauw = (CRH[:{500}].tauw/ms * {1.5})*ms;"
    network_args['network_events'] = {30: base_str}
    return network_args, synapse_params
def error_func(params):
    """Simulate the network for one parameter set and extract per-unit ISIs.

    Runs ``cadex.cadex_network`` with the arguments produced by
    ``parse_params``. When ``OPT_TONIC`` is set, a second simulation is run
    with different network events and with ``taup``/``taubr`` set to
    ``9e9*second``.

    Returns:
        ``(sim_isi_strt, sim_isi_end, spikes, volt, sim_isi_strt_tonic,
        sim_isi_end_tonic)``; the two tonic entries are ``None`` when
        ``OPT_TONIC`` is off.
    """
    def _windowed_isis(spike_data, window):
        # ISIs restricted to [window[0], window[1]], truncated to the first 500 entries.
        return build_isi_from_spike_train(spike_data, low_cut=window[0], high_cut=window[1], indiv=True)[:500]

    net_kwargs, syn_kwargs = parse_params(params)
    events = net_kwargs.pop('network_events')

    # Main simulation.
    output = cadex.cadex_network(run_time=RUN_TIME, GABA_param_file=cadex.CRH_PARAM_FILE, network_args=net_kwargs, synapse_params=syn_kwargs, network_events=events)

    # Optional second simulation in tonic mode.
    if OPT_TONIC:
        events = {5:'disconnect_GABA_increase_adapt',20: 'double_EPSP'}
        tonic_syn_kwargs = syn_kwargs.copy()
        tonic_syn_kwargs['taup'] = 9e9*second
        tonic_syn_kwargs['taubr'] = 9e9*second
        output_tonic = cadex.cadex_network(run_time=RUN_TIME, GABA_param_file=cadex.CRH_PARAM_FILE, network_args=net_kwargs, synapse_params=tonic_syn_kwargs, network_events=events)

    spikes = output['spikes']
    volt = output['states']

    sim_isi_strt = _windowed_isis(spikes, BASELINE_TIMES)
    sim_isi_end = _windowed_isis(spikes, BASELINE_TIMES_END)

    sim_isi_strt_tonic = None
    sim_isi_end_tonic = None
    if OPT_TONIC:
        spikes_tonic = output_tonic['spikes']
        sim_isi_strt_tonic = _windowed_isis(spikes_tonic, BASELINE_TIMES)
        sim_isi_end_tonic = _windowed_isis(spikes_tonic, BASELINE_TIMES_END)

    return sim_isi_strt, sim_isi_end, spikes, volt, sim_isi_strt_tonic, sim_isi_end_tonic
def leakyrelu_error(x, target=0.0, thres=550, scale=0.1):
    """Absolute error with a leaky-ReLU-like shape.

    Returns ``|x - target|`` where ``x >= thres`` and ``|x - target| * scale``
    where ``x < thres``. Works element-wise on arrays.
    """
    deviation = np.abs(x - target)
    damped = deviation * scale
    return np.where(x < thres, damped, deviation)
def ng_error_func(params):
    """Objective for the optimizer: simulate ``params`` and score the result.

    Runs ``error_func`` and forwards its six outputs, in order, to
    ``_ng_error_func`` together with ``params``.
    """
    simulation_outputs = error_func(params)
    return _ng_error_func(params, *simulation_outputs)
def _lowest_fraction_mean(errors, fraction=0.1, empty_penalty=999):
    """Mean of the smallest ``fraction`` of ``errors``, using at least one value.

    Returns ``empty_penalty`` when ``errors`` is empty, so the objective never
    becomes NaN because of an empty slice.
    """
    if len(errors) == 0:
        return empty_penalty
    n_keep = max(1, int(len(errors) * fraction))
    return np.mean(np.sort(errors)[:n_keep])


def _ng_error_func(params, sim_isi_strt, sim_isi_end, spikes, volt, sim_isi_strt_tonic, sim_isi_end_tonic, parse_params_func=parse_params):
    """Score one simulation and write its artefacts to ``OUTPUT_FOLDER``.

    Args:
        params: the parameter mapping that produced the simulation; it is
            re-parsed with ``parse_params_func`` for the saved files.
        sim_isi_strt, sim_isi_end: per-unit ISI sequences for the start and
            end windows, indexable by unit 0..499.
        spikes: object exposing ``spike_trains()`` (presumably a brian2
            SpikeMonitor — confirm).
        volt: object exposing ``v``, ``d_I`` and ``get_states`` (presumably a
            brian2 StateMonitor — confirm).
        sim_isi_strt_tonic, sim_isi_end_tonic: tonic-run counterparts, only
            read when ``OPT_TONIC`` is set.
        parse_params_func: callable returning ``(network_args, synapse_params)``.

    Returns:
        ``999`` when the mean voltage or the current spread is NaN (nothing is
        saved in that case); otherwise
        ``burst_error + nonburst_error + snm_error + c22_error``.

    Side effects:
        Saves a PNG, a CSV and two joblib files per call, named with the
        rounded error and a run id.
    """
    # Random number plus the wall clock, used only to make file names unique.
    run_id = np.random.randint(0,1000000) + time.time()

    # Window lengths are loop-invariant, so compute them once.
    pre_window = np.diff(BASELINE_TIMES)[0]
    post_window = np.diff(BASELINE_TIMES_END)[0]

    # Iterate through and compute the error for each unit.
    burst_error = []
    nonburst_error = []
    burst_pre_fr = []
    nonburst_pre_fr = []
    burst_post_fr = []
    nonburst_post_fr = []
    for unit in range(500):
        if len(sim_isi_strt[unit]) > 2:
            burst_error.append(isi_wasserstein_dd(sim_isi_strt[unit], burst_ref))
            nonburst_error.append(isi_wasserstein_dd(sim_isi_strt[unit], non_burst_ref))
            # Rates are ISI counts divided by the window length.
            burst_pre_fr.append(len(sim_isi_strt[unit])/pre_window)
            burst_post_fr.append(len(sim_isi_end[unit])/post_window)
            if OPT_TONIC:
                nonburst_pre_fr.append(len(sim_isi_strt_tonic[unit])/pre_window)
                nonburst_post_fr.append(len(sim_isi_end_tonic[unit])/post_window)
            else:
                nonburst_pre_fr.append(0)
                nonburst_post_fr.append(0)
        else:
            # Too few ISIs to compare: penalise the unit and report zero rates.
            burst_error.append(999)
            nonburst_error.append(999)
            burst_pre_fr.append(0)
            nonburst_pre_fr.append(0)
            burst_post_fr.append(0)
            nonburst_post_fr.append(0)

    # Drop NaN distances and keep the full per-unit vectors for the dump below.
    burst_error = np.array(burst_error)
    nonburst_error = np.array(nonburst_error)
    burst_error = burst_error[~np.isnan(burst_error)]
    nonburst_error = nonburst_error[~np.isnan(nonburst_error)]
    burst_error_full = np.copy(burst_error)
    nonburst_error_full = np.copy(nonburst_error)

    # Mean of the lowest 10% of the errors (at least one value).
    burst_error = _lowest_fraction_mean(burst_error)
    # We also want some non-bursting units (heterogeneity), but this term
    # should weigh less, hence the division by 50.
    nonburst_error = _lowest_fraction_mean(nonburst_error) / 50 if OPT_TONIC else 0

    # Mean and max pre/post rates for burst and non-burst.
    pre_post_dict = {
        'burst_pre_fr_mean': np.mean(burst_pre_fr),
        'nonburst_pre_fr_mean': np.mean(nonburst_pre_fr),
        'burst_post_fr_mean': np.mean(burst_post_fr),
        'nonburst_post_fr_mean': np.mean(nonburst_post_fr),
        'burst_pre_max': np.max(burst_pre_fr),
        'nonburst_pre_max': np.max(nonburst_pre_fr),
        'burst_post_max': np.max(burst_post_fr),
        'nonburst_post_max': np.max(nonburst_post_fr),
    }

    # The catch22-based term is disabled; it is kept as 0 so the saved states
    # and the returned sum keep their shape.
    c22_error = 0

    # Diagnostics, not used for the optimization. Note that 'mean_I' holds
    # the standard deviation of d_I, not its mean.
    mean_volt = np.mean(volt.v[:500, :]/mV)
    mean_I = np.std(volt.d_I[:500, :]/pA)

    # If there is no usable voltage or current, reject the run.
    if np.isnan(mean_volt) or np.isnan(mean_I):
        return 999

    # Fraction of samples with the membrane potential above -10 mV.
    volt_error = np.mean(np.where(volt.v[:500, :]/mV > -10, 1, 0))
    # Fraction of samples with d_I above 400 pA, scaled by the run time.
    current_error = np.mean(np.where(volt.d_I[:500, :]/pA > 400, 1, 0)) * ( RUN_TIME)
    snm_error = (volt_error + current_error) * SNM_SF

    # Plot and save.
    error_id = str(np.round(burst_error + nonburst_error + snm_error,3)).ljust(5, '0')
    plot_patterns(spikes.spike_trains(), 500)
    savefig(os.path.join(OUTPUT_FOLDER, f"spikes_{error_id}_{int(run_id)}.png"))

    # Save the params.
    params_out = {}
    network_args, synapse_params = parse_params_func(params)
    params_out.update(network_args)
    params_out.update(synapse_params)
    if network_args['ECRH_params'] is not None:
        params_out.update(network_args['ECRH_params'])
    params_out['burst_error'] = burst_error
    params_out['nonburst_error'] = nonburst_error
    params_out['mean_volt'] = mean_volt
    params_out['mean_I'] = mean_I
    params_out['snm_error'] = snm_error
    params_out['runtime'] = RUN_TIME
    params_out['max_D_I'] = np.amax(volt.d_I[:500, :]/pA)
    params_out.update(pre_post_dict)
    params_df = pd.DataFrame(params_out, index=[0])
    params_df['connectivity'] = str(params['connectivity'])
    params_df.to_csv(os.path.join(OUTPUT_FOLDER, f"params_{error_id}_{int(run_id)}.csv"))

    # Save the spikes and a reduced copy of the states with heavy compression;
    # only the first 10 columns of each recorded array are kept.
    states = volt.get_states(['v', 'd_I', 'ge', 'gi'])
    for key in states:
        states[key] = states[key][:, :10]
    # Add the per-unit burst and nonburst errors to the states.
    states['burst_error'] = burst_error_full
    states['nonburst_error'] = nonburst_error_full
    # Add in the pre-post firing rates.
    states['burst_pre_fr'] = burst_pre_fr
    states['nonburst_pre_fr'] = nonburst_pre_fr
    states['burst_post_fr'] = burst_post_fr
    states['nonburst_post_fr'] = nonburst_post_fr
    # Add in the c22 error.
    states['c22_error'] = c22_error
    dump([spikes.spike_trains(), states], os.path.join(OUTPUT_FOLDER, f"spikes_{error_id}_{int(run_id)}.joblib"), compress=('lzma', 9))
    # Also dump the exact params for later use.
    dump({'network_args': network_args, 'synapse_params': synapse_params, 'network_events': {}}, os.path.join(OUTPUT_FOLDER, f"params_{error_id}_{int(run_id)}.joblib"),
         compress=('lzma', 9))
    return burst_error + nonburst_error + snm_error + c22_error
def nevergrad():
    """Run the ask/evaluate/tell loop over ``var_dict``.

    Each round asks the optimizer for ``workers`` candidates, evaluates them
    in parallel with joblib through ``ng_error_func``, tells the results back
    and dumps the optimizer to ``optimizer.joblib``.
    """
    optimizer = ng.optimizers.ScrHammersleySearch(parametrization=var_dict, budget=int(workers*rounds), num_workers=workers, )
    optimizer.enable_pickling()
    #optimizer = ng.optimizers.ParaPortfolio.load(f"optimizer.joblib")

    # Rounds left in the budget; num_ask is non-zero when a dumped optimizer
    # is loaded instead of a fresh one.
    remaining_rounds = (optimizer.budget - optimizer.num_ask) // workers
    for _ in range(remaining_rounds):
        points_list = []
        for _slot in range(workers):
            # Keep asking until a candidate is produced. Only Exception is
            # caught, so Ctrl-C and SystemExit still stop the run.
            while True:
                try:
                    candidate = optimizer.ask()
                except Exception as exc:
                    print(f"failed to ask: {exc!r}")
                else:
                    points_list.append(candidate)
                    break
        error = Parallel(n_jobs=workers)(delayed(ng_error_func)(p.value) for p in points_list)
        for points, er in zip(points_list, error):
            optimizer.tell(points, er)
        # NOTE(review): the indentation of this file was lost; the dump is
        # placed inside the round loop so every round is checkpointed. After
        # the last round the file holds the same state as a single final dump.
        optimizer.dump("optimizer.joblib")


if __name__ == '__main__':
    nevergrad()