from sNMO.optimizer import snmOptimizer
from brian2.units import *
import nevergrad as ng
import network_CADEX as cadex
import numpy as np
from spike_train_utils import build_isi_from_spike_train, plot_patterns
from sNMO.error.spikeTrainErrors import emd_pdist_spk, isi_swasserstein_2d, isi_wasserstein_dd, biemd
from matplotlib.pyplot import savefig
from network_fitting import _ng_error_func, parse_params, SNM_SF, RUN_TIME, BASELINE_TIMES, BASELINE_TIMES_END, OPT_TONIC
import time as time
import pandas as pd
from joblib import Parallel, delayed, dump, load
import sbi
import torch
import os
# Parameter set for clustered connectivity: an integer-cast cluster count and a
# bounded scalar for `inter_cluster_num`.
# NOTE(review): nothing in this part of the file references this dict (var_dict
# below uses connectivity='random') — confirm whether it is still needed.
clustered_conn_params = ng.p.Dict(
    NUM_CLUSTERS=ng.p.Scalar(lower=2, upper=40).set_integer_casting(),
    inter_cluster_num=ng.p.Scalar(lower=0.025, upper=1),
)
## synaptic params
# Optimisable description of the excitatory input: `Nexc` is an ordered choice
# over the listed integer values, `Exc_p` a log-distributed scalar.
pexc_input_conn = ng.p.Dict(
    Nexc=ng.p.TransitionChoice(
        np.array([10, 30, 50, 100, 150, 200, 300, 400, 500]).astype(int)
    ),
    Exc_p=ng.p.Log(lower=0.001, upper=1),
)
# Fixed (non-optimised) alternative with constant values.
default_input_conn = ng.p.Dict(Nexc=500, Exc_p=None)
# Candidate values for the transition choices built from this array
# (presumably connection probabilities — confirm against parse_params).
# 0.005 is intentionally left out of the list.
p_conn = np.array([0.01, 0.02, 0.04, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5])
# Search space for the fit. Plain Python values are held constant; every
# ng.p.* entry is a dimension that gets sampled.
# NOTE(review): the physical meaning of each key is decided by parse_params and
# the network model, not by this file — confirm there before changing bounds.
var_dict = ng.p.Dict(
    # --- bounded scalar / constant terms ---
    ee=ng.p.Scalar(lower=0.02, upper=15),
    ii=ng.p.Scalar(lower=0.02, upper=15),
    wcrh=0,
    weinh=ng.p.Scalar(lower=0.02, upper=20),
    # der=ng.p.Scalar(lower=0.01, upper=2),
    de=ng.p.Scalar(lower=0.1, upper=20),
    di=ng.p.Scalar(lower=0.1, upper=20),
    dcrh=1,
    # --- log-distributed and external-input terms ---
    input_hz=ng.p.Log(lower=10, upper=500),
    exti_in=ng.p.Log(lower=0.001, upper=500),
    wexti=ng.p.Scalar(lower=0.00002, upper=15),
    ## network params
    # Each choice gets its own copy of the candidate array.
    p_ei=ng.p.TransitionChoice(p_conn.copy()),
    p_ie=ng.p.TransitionChoice(p_conn.copy()),
    ## connectivity param
    connectivity='random',
    # Disabled alternatives, kept for reference:
    #   exc_input=ng.p.Choice([pexc_input_conn, default_input_conn]),
    #   ext_to_gaba=ng.p.TransitionChoice([True, False])
    exc_input=default_input_conn,
    i_pr=0.05,
    b_change=ng.p.Scalar(lower=0.75, upper=7),
)
OUTPUT_FOLDER = "/media/smestern/sgbackup/output/"
# The optimiser budget is workers * rounds evaluations, asked for in batches
# of `workers` candidates.
rounds = 400
workers = 6


def parse_params_func(x):
    """Forward `x` to parse_params with all_fast=True and feedforward=True."""
    return parse_params(x, all_fast=True, feedforward=True)
def _baseline_isis(spike_train):
    """Build the start- and end-window ISI sets for one spike train.

    Each result is sliced to its first 500 items.
    """
    isi_start = build_isi_from_spike_train(
        spike_train, low_cut=BASELINE_TIMES[0], high_cut=BASELINE_TIMES[1], indiv=True
    )[:500]
    isi_end = build_isi_from_spike_train(
        spike_train, low_cut=BASELINE_TIMES_END[0], high_cut=BASELINE_TIMES_END[1], indiv=True
    )[:500]
    return isi_start, isi_end


def error_func(params):
    """Simulate the network for one parameter set and gather its ISI data.

    Args:
        params: a sampled parameter set; handed unchanged to parse_params_func.

    Returns:
        A 6-tuple ``(sim_isi_strt, sim_isi_end, spikes, volt,
        sim_isi_strt_tonic, sim_isi_end_tonic)``. The two tonic entries are
        ``None`` when OPT_TONIC is falsy.
    """
    network_args, synapse_params = parse_params_func(params)
    # The event schedule travels inside network_args but is passed separately.
    scheduled_events = network_args.pop('network_events')

    # Primary simulation.
    output = cadex.cadex_network(
        run_time=RUN_TIME,
        GABA_param_file=cadex.CRH_PARAM_FILE,
        network_args=network_args,
        synapse_params=synapse_params,
        network_events=scheduled_events,
    )

    # Optional second simulation ("tonic mode"): fixed event schedule and two
    # synaptic time constants replaced by a very large value.
    if OPT_TONIC:
        tonic_events = {5: 'disconnect_GABA_increase_adapt', 20: 'double_EPSP'}
        tonic_synapse_params = synapse_params.copy()
        for tau_key in ('taup', 'taubr'):
            tonic_synapse_params[tau_key] = 9e9*second
        output_tonic = cadex.cadex_network(
            run_time=RUN_TIME,
            GABA_param_file=cadex.CRH_PARAM_FILE,
            network_args=network_args,
            synapse_params=tonic_synapse_params,
            network_events=tonic_events,
        )

    spikes = output['spikes']
    volt = output['states']

    # ISI extraction happens only after every simulation has finished.
    sim_isi_strt, sim_isi_end = _baseline_isis(spikes)
    if not OPT_TONIC:
        return sim_isi_strt, sim_isi_end, spikes, volt, None, None

    sim_isi_strt_tonic, sim_isi_end_tonic = _baseline_isis(output_tonic['spikes'])
    return sim_isi_strt, sim_isi_end, spikes, volt, sim_isi_strt_tonic, sim_isi_end_tonic
def ng_error_func(params):
    """Objective evaluated for each candidate: simulate, then score.

    Args:
        params: the candidate's parameter values; passed through unchanged to
            error_func and _ng_error_func.

    Returns:
        Whatever _ng_error_func returns for this simulation.
    """
    # NOTE(review): this identifier is never read afterwards. The draw is kept
    # so the global NumPy random stream is consumed exactly as before.
    _run_id = np.random.randint(0, 1000000) + time.time()

    # error_func yields (isi_start, isi_end, spikes, volt, isi_start_tonic,
    # isi_end_tonic), which are forwarded positionally in that order.
    simulation_results = error_func(params)
    return _ng_error_func(params, *simulation_results, parse_params_func=parse_params_func)
def nevergrad():
    """Run, or resume, the batched ScrHammersleySearch fit with checkpoints.

    If ``optimizer.joblib`` exists in the working directory the optimizer is
    restored from it; otherwise a new one is created over ``var_dict`` with a
    budget of ``workers * rounds`` evaluations. Each round asks for ``workers``
    candidates, evaluates them in parallel with ng_error_func, reports the
    losses back, and writes a checkpoint so an interrupted run can resume.

    Returns:
        None. Progress is persisted to ``optimizer.joblib``.
    """
    checkpoint_path = "optimizer.joblib"

    if os.path.exists(checkpoint_path):
        # Resume from the previous run. NOTE: load() unpickles the file, so
        # only point this at checkpoints you produced yourself.
        optimizer = ng.optimizers.ScrHammersleySearch.load(checkpoint_path)
    else:
        # First run: previously the unconditional load() raised here because
        # there was no checkpoint to read yet.
        optimizer = ng.optimizers.ScrHammersleySearch(
            parametrization=var_dict,
            budget=int(workers * rounds),
            num_workers=workers,
        )
        # Must be enabled before any ask/tell so dump() can pickle the state.
        optimizer.enable_pickling()

    # Only the part of the budget that has not been asked yet is run.
    remaining_rounds = (optimizer.budget - optimizer.num_ask) // workers
    for _round in range(remaining_rounds):
        candidates = []
        for _slot in range(workers):
            # Keep retrying until a candidate is obtained. Only Exception is
            # caught so Ctrl-C / SystemExit still stop the run.
            while True:
                try:
                    candidates.append(optimizer.ask())
                    break
                except Exception as exc:
                    print("failed to ask", repr(exc))

        losses = Parallel(n_jobs=workers)(
            delayed(ng_error_func)(candidate.value) for candidate in candidates
        )
        for candidate, loss in zip(candidates, losses):
            optimizer.tell(candidate, loss)

        # Checkpoint after every completed round.
        optimizer.dump(checkpoint_path)
# Start the optimisation only when executed as a script, not on import.
if __name__ == "__main__":
    nevergrad()