from sNMO.optimizer import snmOptimizer
from brian2.units import *
import nevergrad as ng
import network_CADEX as cadex
import numpy as np
from spike_train_utils import build_isi_from_spike_train, plot_patterns, ecv2
from sNMO.error.spikeTrainErrors import emd_pdist_spk, isi_swasserstein_2d, isi_wasserstein_dd, biemd
from matplotlib.pyplot import savefig
from network_fitting import _ng_error_func, SNM_SF, RUN_TIME
import time as time
import pandas as pd
from joblib import Parallel, delayed, dump, load
import sbi
import torch
import os


# Parameter files of the best networks from the previous fitting stage; every
# optimiser point selects one of them by its 'path' value.
prev_results_paths = pd.read_csv('./output/best_networks_paths.csv')['path'].values

# Search space: a starting network plus three scalar modifiers that
# parse_params_func applies to it through a scheduled network event.
var_dict = ng.p.Dict(
    path=ng.p.Choice(prev_results_paths),
    i_pr=ng.p.Scalar(lower=0.01, upper=0.2),
    b_change=ng.p.Scalar(lower=0.75, upper=10),
    tauw=ng.p.Scalar(lower=0.5, upper=2.5),
)

# Evaluation budget: `rounds` batches of `workers` parallel simulations each.
rounds = 50
workers = 12
# Simulated duration passed to cadex_network by error_func. This rebinds (and
# therefore overrides) the RUN_TIME imported from network_fitting above.
RUN_TIME = 60 * second

# "ap2" in-vivo reference statistics the simulated network is fitted against.
results_comp = pd.read_csv('./output/_ap2_invivo_ref.csv')
ref_cv2_post = results_comp['invivo_post_cv2'].values
ref_fr_post = results_comp['invivo_post_fr'].values

def parse_params_func(params):
    """Turn one optimiser point into the arguments for ``cadex.cadex_network``.

    The network and synapse parameters are loaded with
    ``cadex.load_params_from_joblib`` from the ``.joblib`` file obtained by
    replacing ``.csv`` with ``.joblib`` in ``params['path']``.

    A single network event is then stored under key ``10``. It is an
    ``eval:`` command string that:

    - sets ``P[:500].pr`` to ``params['i_pr']``;
    - sets ``CRH.bd`` to ``CRH_params["b"]`` scaled by ``params['b_change']``;
    - scales ``CRH[:500].tauw`` by ``params['tauw']``.

    How cadex interprets the key and the ``eval:`` prefix is not visible
    here (presumably a time point and a code string to execute — confirm).

    Returns:
        ``(network_args, synapse_params)``. ``network_args['network_events']``
        is set to the new event dict, replacing anything loaded from file
        under that key.
    """
    joblib_path = params['path'].replace('.csv', '.joblib')
    network_args, synapse_params = cadex.load_params_from_joblib(joblib_path)

    event_cmd = (
        f'eval: P[:500].pr = {params["i_pr"]}; '
        f'CRH.bd = CRH_params["b"] * {params["b_change"]}; '
        f'CRH[:500].tauw = (CRH[:500].tauw/ms *  {params["tauw"]})*ms'
    )
    network_args['network_events'] = {10: event_cmd}
    return network_args, synapse_params

def error_func(params):
    """Simulate the network for one optimiser point and return its ISIs.

    The point is converted with ``parse_params_func`` and simulated for the
    module-level ``RUN_TIME`` with ``cadex.cadex_network``. The resulting spikes
    are turned into individual inter-spike-interval trains by
    ``build_isi_from_spike_train`` (``low_cut=16``, ``high_cut=40``,
    ``indiv=True``).

    Returns:
        The first 500 entries of that per-neuron ISI collection.
    """
    network_args, synapse_params = parse_params_func(params)
    # cadex_network takes the events as their own keyword argument, so detach
    # them from network_args before passing it on.
    events = network_args.pop('network_events')

    sim_output = cadex.cadex_network(
        run_time=RUN_TIME,
        GABA_param_file=cadex.CRH_PARAM_FILE,
        network_args=network_args,
        synapse_params=synapse_params,
        network_events=events,
    )

    per_neuron_isi = build_isi_from_spike_train(
        sim_output['spikes'], low_cut=16, high_cut=40, indiv=True
    )
    return per_neuron_isi[:500]

def shift_error_func(ISI, low_cut=16, high_cut=40):
    """Summarise simulated ISIs and score them against the in-vivo reference.

    Parameters
    ----------
    ISI : sequence of array-like
        One inter-spike-interval array per neuron, e.g. the output of
        ``error_func``.
    low_cut, high_cut : float, optional
        Bounds of the analysis window used to turn spike counts into firing
        rates. The defaults match the cuts ``error_func`` passes to
        ``build_isi_from_spike_train`` and must be kept in sync with them.
        Units are whatever that helper uses (presumably seconds — confirm).

    Returns
    -------
    dict
        ``fr_error`` / ``cv2_error``
            Squared difference between the simulated population mean and the
            mean of ``ref_fr_post`` / ``ref_cv2_post``.
        ``mean_fr`` / ``mean_cv2``
            The simulated population means, computed NaN-ignoring.
    """
    window = high_cut - low_cut
    # k ISIs correspond to k + 1 spikes inside the window.
    # NOTE(review): if a silent neuron also yields an empty ISI array it is
    # counted as one spike here — confirm build_isi_from_spike_train's output.
    mean_fr = np.nanmean([(len(isi) + 1) / window for isi in ISI])
    mean_cv2 = np.nanmean([ecv2(isi) for isi in ISI])

    # Both sides are scalars, so the "MSE" reduces to a squared difference.
    fr_error = (mean_fr - np.mean(ref_fr_post)) ** 2
    cv2_error = (mean_cv2 - np.mean(ref_cv2_post)) ** 2
    return {'fr_error': fr_error, 'cv2_error': cv2_error,
            'mean_fr': mean_fr, 'mean_cv2': mean_cv2}



def ng_error_func(params):
    """Evaluate one optimiser point end to end; this is the joblib worker target.

    Args:
        params: The ``value`` of a point drawn from ``var_dict``. This is a
            dict with keys ``path``, ``i_pr``, ``b_change`` and ``tauw``.

    Returns:
        The error/statistics dict produced by ``shift_error_func`` for the
        ISIs simulated by ``error_func``.
    """
    return shift_error_func(error_func(params))



def _ask_with_retry(optimizer, max_attempts=100):
    """Return ``optimizer.ask()``, retrying up to ``max_attempts`` times before re-raising."""
    for attempt in range(1, max_attempts + 1):
        try:
            return optimizer.ask()
        except Exception as exc:  # nevergrad may raise various errors from ask()
            print(f"failed to ask (attempt {attempt}/{max_attempts}): {exc!r}")
            if attempt == max_attempts:
                raise


def _point_record(point, er):
    """Flatten one evaluated point and its error dict into a results row."""
    value = point.value
    return {
        'i_pr': value['i_pr'],
        'b_change': value['b_change'],
        'tauw': value['tauw'],
        'fr_error': er['fr_error'],
        'mean_fr': er['mean_fr'],
        'cv2_error': er['cv2_error'],
        'mean_cv2': er['mean_cv2'],
        'path': value['path'],
    }


def nevergrad(results_path='./output/refitsswitch.csv'):
    """Run a batched, parallel ``ScrHammersleySearch`` over ``var_dict``.

    Each round proceeds as follows:

    1. Ask ``workers`` points from the optimiser.
    2. Evaluate them concurrently with joblib via ``ng_error_func``.
    3. Tell the optimiser each point's loss, ``fr_error + cv2_error``.

    After every round the optimiser is dumped to ``optimizer.joblib`` and all
    rows evaluated so far are written to ``results_path``. A crash therefore
    loses at most one round of results.

    Args:
        results_path: CSV file that receives one row per evaluated point.

    Raises:
        Exception: whatever ``optimizer.ask()`` raised, if it fails
            repeatedly (see ``_ask_with_retry``).
    """
    optimizer = ng.optimizers.ScrHammersleySearch(
        parametrization=var_dict, budget=int(workers * rounds), num_workers=workers
    )
    optimizer.enable_pickling()
    # To resume a run instead:
    # optimizer = ng.optimizers.ScrHammersleySearch.load("optimizer.joblib")

    output_dict = {}

    def _save_results():
        pd.DataFrame.from_dict(output_dict, orient='index').to_csv(results_path)

    n_rounds = (optimizer.budget - optimizer.num_ask) // workers
    for _ in range(n_rounds):
        points_list = [_ask_with_retry(optimizer) for _ in range(workers)]
        errors = Parallel(n_jobs=workers)(
            delayed(ng_error_func)(p.value) for p in points_list
        )
        for point, er in zip(points_list, errors):
            optimizer.tell(point, er['fr_error'] + er['cv2_error'])
            # A running counter guarantees unique row labels; a random suffix
            # could collide and silently overwrite an earlier result.
            output_dict[f"{point.value['path']}_{len(output_dict)}"] = _point_record(point, er)

        optimizer.dump("optimizer.joblib")
        _save_results()

    # Also covers the zero-round case, so the CSV always exists afterwards.
    _save_results()
if __name__ == '__main__':
    # Start the optimisation only when run as a script, never on import.
    nevergrad()