from sNMO.optimizer import snmOptimizer
from brian2.units import *
import nevergrad as ng
import network_CADEX as cadex
import numpy as np
from spike_train_utils import build_isi_from_spike_train, plot_patterns
from sNMO.error.spikeTrainErrors import emd_pdist_spk, isi_swasserstein_2d, isi_wasserstein_dd, biemd
from matplotlib.pyplot import savefig
import time as time
import pandas as pd
from joblib import Parallel, delayed, dump, load
import sbi
import torch
import os
import pyabf
    
##synaptic params
# Search space for the nevergrad optimizer (see nevergrad() below). Each
# candidate's `.value` is a flat dict that parse_params() converts into
# cadex network_args / synapse_params.

# Alternative excitatory-input search space; not used directly (it only
# appears in the commented-out Choice on the `exc_input` line below).
# Nexc = number of excitatory inputs, Exc_p -> network_args['ECRH_params']['p'].
pexc_input_conn = ng.p.Dict(Nexc=ng.p.TransitionChoice(np.array([10, 30, 50,100, 150, 200, 300, 400, 500]).astype(int)),
                    Exc_p = ng.p.Log(lower=0.001, upper=1))
# Fixed excitatory input actually used: Nexc=500 and Exc_p=None, so
# parse_params() sets network_args['ECRH_params'] to None.
default_input_conn = ng.p.Dict(Nexc=500, Exc_p=None) #

# Discrete connection probabilities offered to the p_ei / p_ie choices.
p_conn = np.array([0.01, 0.02, 0.04, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5])  
var_dict = ng.p.Dict(ee=ng.p.Scalar(lower=0.02, upper=15),  # -> synapse_params['we'] (nS)
                    ii=ng.p.Scalar(lower=0.02, upper=15),  # -> synapse_params['wi'] (nS)
                    wcrh=ng.p.Scalar(lower=0.002, upper=15),  # -> synapse_params['wcrh'] (nS)
                    der= 1, #ng.p.Scalar(lower=0.01, upper=1),  fixed; not read by parse_params
                    de=ng.p.Scalar(lower=0.1, upper=20),  # -> synapse_params['taue'] (ms)
                    di=ng.p.Scalar(lower=0.1, upper=20),  # -> synapse_params['taui'] (ms)
                    dcrh=ng.p.Scalar(lower=10, upper=500),  # -> synapse_params['taucrh'] (ms)
                    input_hz=ng.p.Log(lower=10,upper=500),  # -> synapse_params['exte_in'] (Hz), log-scaled
                    exti_in=0,#ng.p.Log(lower=0.001,upper=500),  fixed at 0 -> 'exti_in' (Hz)
                    
                    wexti=0,#ng.p.Log(lower=0.00002, upper=4),  fixed at 0 -> 'wexti' (nS)

                    ##network params
                    # connection probabilities, each picked from p_conn
                    p_ei = ng.p.TransitionChoice(p_conn.copy()),
                    p_ie = ng.p.TransitionChoice(p_conn.copy()),

                    ##connectivity param
                    # 'random' (or 'threeway') is passed straight through by
                    # parse_params; any other value is read as a dict with
                    # NUM_CLUSTERS / inter_cluster_num -> "custom_circ_graph".
                    connectivity='random',
                    exc_input=default_input_conn,  #ng.p.Choice([pexc_input_conn, default_input_conn]),
                    #ext_to_gaba = ng.p.TransitionChoice([True, False]),
                    # written into the network-event string as P[:500].pr
                    i_pr = 0.05,
                    # factor applied to CRH b to set CRH bd in the network event
                    b_change=ng.p.Scalar(lower=0.75, upper=7),

)

# Every evaluation writes its PNG / CSV / joblib outputs into this folder.
OUTPUT_FOLDER = "/media/smestern/sgbackup/aoi_paper_2/output/"



# The optimizer budget is workers * rounds evaluations, run `workers` at a time.
rounds = 400
workers = 6
#Load the data to fit to
# Reference ISI data that simulated units are compared against with
# isi_wasserstein_dd. The paths are relative to the current working
# directory and are loaded at import time.
burst_ref = np.load("unit_3_burst_match1.npy")
non_burst_ref = np.load("unit_3_nonburst_match.npy")


#some error scale params
SNM_SF = 0.05  # weight of the voltage/current ("SNM") penalty in the loss
DT = 0.1*ms  # not used by the functions below
RUN_TIME = 40  # run_time passed to cadex_network; also scales current_error
BASELINE_TIMES = [10, 15]  # [low_cut, high_cut] of the "start" ISI window
# [low_cut, high_cut] of the "end" ISI window.
# NOTE(review): parse_params schedules its network event at 30, which lies
# inside this window — confirm that is intended.
BASELINE_TIMES_END = [25, 40]
OPT_TONIC = False  # if True, error_func also runs and scores a second "tonic" simulation

def parse_params(params, all_fast=False, mono_synapse=False, threeway=False, slow_gaba=False, slow_crh=False, feedforward=False):
    """Translate one optimizer candidate into cadex network/synapse arguments.

    Parameters
    ----------
    params : dict
        A candidate drawn from ``var_dict`` (``candidate.value``). The nested
        ``params['exc_input']`` dict (``Nexc`` / ``Exc_p``) is flattened into
        a shallow *copy* of ``params``. The caller's dict is not modified.
    all_fast, mono_synapse, slow_gaba, slow_crh : bool
        Overlay alternative synapse-model sets from ``cadex`` onto
        ``network_args``. ``slow_crh`` only has an effect together with
        ``slow_gaba``.
    threeway : bool
        Also read ``p_ee`` and ``we_e`` (nS) from ``params``.
    feedforward : bool
        Force ``network_args['p_ei']`` and ``network_args['p_ie']`` to 0.

    Returns
    -------
    network_args : dict
        Connectivity and model arguments. Includes ``'network_events'``, a
        dict mapping 30 to an ``"eval: ..."`` statement string built from
        ``i_pr``, optional ``e_pr`` / ``exti_p``, and ``b_change``.
    synapse_params : dict
        Synaptic parameters with brian2 units (ms, nS, Hz, second) attached.
    """
    # Work on a shallow copy so flattening 'exc_input' does not leak back into
    # the caller's dict (previously the caller's dict was updated in place).
    params = dict(params)
    params.update(params['exc_input'])

    network_args = {'p_ei': params['p_ei'], 'p_ie': params['p_ie']}
    synapse_params = {'taue': params['de']*ms,
                      'taui': params['di']*ms, 'taucrh': params['dcrh']*ms,
                      'wcrh': params['wcrh']*nS, 'we': params['ee']*nS, 'wi': params['ii']*nS, 'exte_in': params['input_hz']*Hz,
                      'exti_in': params['exti_in']*Hz,
                      'wexti': params['wexti']*nS, 'Nexc': params['Nexc'],
                      }

    if threeway:
        network_args.update({'p_ee': params['p_ee']})
        synapse_params.update({'we_e': params['we_e']*nS})

    # Exc_p (if given) becomes the 'p' entry of ECRH_params; None is passed through.
    if params['Exc_p'] is not None:
        network_args['ECRH_params'] = {'p': params['Exc_p']}
    else:
        network_args['ECRH_params'] = None

    # Optional synapse-model overrides, applied in this order.
    if all_fast:
        network_args.update(cadex.ALL_FAST_SYNAPSE_MODELS)
    if mono_synapse:
        network_args.update(cadex.MONO_EXP_SYNAPSE_MODELS)
    if slow_gaba:
        network_args.update(cadex.SLOW_GABA_SYNAPSE_MODELS)
        if slow_crh:
            # keep the slow-GABA model set but restore the default EI model
            network_args["EI_model"] = cadex.DEFAULT_SYNAPSE_MODELS["EI_model"]
    # NOTE(review): slow_crh on its own (without slow_gaba) changes nothing.

    # taup / taubr are pinned to a very large value (9999 s).
    synapse_params['taup'] = 9999*second
    synapse_params['taubr'] = 9999*second

    # Connectivity: 'random' / 'threeway' are passed straight through. Any
    # other value is read as a cluster spec and mapped to "custom_circ_graph".
    connectivity = params['connectivity']
    if connectivity not in ('random', 'threeway'):
        num_clusters = connectivity['NUM_CLUSTERS']
        network_args['connectivity_params'] = {
            'NUM_CLUSTERS': num_clusters,
            # inter_cluster_num appears to be a fraction of NUM_CLUSTERS
            'inter_cluster_num': int(num_clusters*connectivity['inter_cluster_num'])+1,
            'p_ei': params['p_ei'],
            'p_ie': params['p_ie'],
        }
        network_args['connectivity'] = "custom_circ_graph"
    else:
        network_args['connectivity'] = connectivity

    if feedforward:
        # NOTE(review): for custom graphs, connectivity_params keeps the
        # original p_ei / p_ie; only the top-level values are zeroed here.
        network_args['p_ei'] = 0.0
        network_args['p_ie'] = 0.0

    # Statement string scheduled as a network event at key 30.
    base_str = "eval: "
    base_str += f"P[:{500}].pr = {params['i_pr']}; "
    base_str += f" P[:{500}].p_e = {params['e_pr']}; " if 'e_pr' in params else ""
    base_str += f"P[:{500}].p_exti = {params['exti_p']}; " if 'exti_p' in params else ""
    base_str += f"CRH[:{500}].bd=(CRH.b[:{500}]/nS *  {params['b_change']})*nS; CRH[:{500}].tauw = (CRH[:{500}].tauw/ms *  {1.5})*ms;"
    network_args['network_events'] = {30: base_str}
    return network_args, synapse_params

def error_func(params):
    """Simulate one parameter set and extract per-unit inter-spike intervals.

    The network is built from ``parse_params(params)`` and run for
    ``RUN_TIME`` with the events that ``parse_params`` scheduled. When
    ``OPT_TONIC`` is set, a second run is made with the events
    ``{5: 'disconnect_GABA_increase_adapt', 20: 'double_EPSP'}`` and with
    ``taup`` / ``taubr`` set to 9e9 s.

    Returns
    -------
    tuple
        ``(sim_isi_strt, sim_isi_end, spikes, volt, sim_isi_strt_tonic,
        sim_isi_end_tonic)``. The ISI entries come from
        ``build_isi_from_spike_train`` over ``BASELINE_TIMES`` and
        ``BASELINE_TIMES_END``, truncated to the first 500 entries. The two
        tonic entries are ``None`` unless ``OPT_TONIC`` is set.
    """
    net_args, syn_params = parse_params(params)
    events = net_args.pop('network_events')
    # arguments shared by the main run and the optional tonic run
    shared = dict(run_time=RUN_TIME, GABA_param_file=cadex.CRH_PARAM_FILE, network_args=net_args)

    main_run = cadex.cadex_network(**shared, synapse_params=syn_params, network_events=events)

    if OPT_TONIC:
        tonic_syn = syn_params.copy()
        tonic_syn['taup'] = 9e9*second
        tonic_syn['taubr'] = 9e9*second
        tonic_events = {5: 'disconnect_GABA_increase_adapt', 20: 'double_EPSP'}
        tonic_run = cadex.cadex_network(**shared, synapse_params=tonic_syn, network_events=tonic_events)

    spikes = main_run['spikes']
    volt = main_run['states']

    def isi_windows(spike_data):
        # per-unit ISIs in the start and end windows, first 500 entries each
        start = build_isi_from_spike_train(spike_data, low_cut=BASELINE_TIMES[0], high_cut=BASELINE_TIMES[1], indiv=True)[:500]
        end = build_isi_from_spike_train(spike_data, low_cut=BASELINE_TIMES_END[0], high_cut=BASELINE_TIMES_END[1], indiv=True)[:500]
        return start, end

    sim_isi_strt, sim_isi_end = isi_windows(spikes)
    if OPT_TONIC:
        sim_isi_strt_tonic, sim_isi_end_tonic = isi_windows(tonic_run['spikes'])
    else:
        sim_isi_strt_tonic = sim_isi_end_tonic = None

    return sim_isi_strt, sim_isi_end, spikes, volt, sim_isi_strt_tonic, sim_isi_end_tonic


def leakyrelu_error(x, target=0.0, thres=550, scale=0.1):
    """Return ``|x - target|``, damped by ``scale`` wherever ``x < thres``.

    The shape resembles a leaky ReLU. Entries with ``x`` at or above
    ``thres`` contribute their full absolute error. Entries below it
    contribute only ``scale`` times as much. The function is element-wise
    over scalars or numpy arrays and returns an ndarray (0-d for scalar input).
    """
    abs_err = np.abs(x - target)
    damped = abs_err * scale
    return np.where(x < thres, damped, abs_err)

def ng_error_func(params):
    """Nevergrad objective: simulate ``params`` and return the scalar loss.

    This composes ``error_func`` (simulation and ISI extraction) with
    ``_ng_error_func`` (scoring and saving), passing the six values that
    ``error_func`` returns positionally.
    """
    simulation_outputs = error_func(params)
    return _ng_error_func(params, *simulation_outputs)



def _lowest_fraction_mean(errors, fraction=0.1, empty_value=999):
    """Return the mean of the smallest ``fraction`` of ``errors``.

    At least one value is always averaged. This means inputs with fewer than
    ``1 / fraction`` entries no longer produce ``np.mean([])``, which is NaN
    plus a RuntimeWarning. An empty input returns ``empty_value`` (default
    999, the same sentinel used for units with too few spikes).
    """
    errors = np.asarray(errors)
    if errors.size == 0:
        return empty_value
    n_keep = max(1, int(errors.size * fraction))
    return np.mean(np.sort(errors)[:n_keep])


def _ng_error_func(params, sim_isi_strt, sim_isi_end, spikes, volt, sim_isi_strt_tonic, sim_isi_end_tonic, parse_params_func=parse_params):
    """Score one simulation, save its diagnostics, and return the scalar loss.

    Parameters
    ----------
    params : dict
        The candidate parameters. They are re-parsed with
        ``parse_params_func`` only to record what was simulated.
    sim_isi_strt, sim_isi_end : sequence
        Per-unit ISI collections for the ``BASELINE_TIMES`` /
        ``BASELINE_TIMES_END`` windows, as returned by ``error_func``.
        Indexed for units 0..499.
    spikes, volt :
        Simulation monitors. This function uses ``spikes.spike_trains()``,
        ``volt.v``, ``volt.d_I`` and ``volt.get_states``.
    sim_isi_strt_tonic, sim_isi_end_tonic : sequence or None
        The same windows for the tonic run. They are only read when
        ``OPT_TONIC`` is set.
    parse_params_func : callable
        Returns ``(network_args, synapse_params)`` for ``params``.

    Returns
    -------
    float
        ``burst_error + nonburst_error + snm_error + c22_error``. Returns
        ``999`` if the mean voltage or the ``mean_I`` summary is NaN, in
        which case nothing is saved.

    Writes ``spikes_<id>.png``, ``params_<id>.csv``, ``spikes_<id>.joblib``
    and ``params_<id>.joblib`` into ``OUTPUT_FOLDER``.
    """
    import matplotlib.pyplot as plt  # only `savefig` is imported at module level

    # random + wall-clock component so output file names differ between runs
    run_id = np.random.randint(0, 1000000) + time.time()
    pre_window = np.diff(BASELINE_TIMES)[0]
    post_window = np.diff(BASELINE_TIMES_END)[0]

    # --- per-unit metrics ---------------------------------------------------
    burst_error, nonburst_error = [], []
    burst_pre_fr, nonburst_pre_fr = [], []
    burst_post_fr, nonburst_post_fr = [], []
    for unit in range(500):
        if len(sim_isi_strt[unit]) > 2:
            burst_error.append(isi_wasserstein_dd(sim_isi_strt[unit], burst_ref))
            nonburst_error.append(isi_wasserstein_dd(sim_isi_strt[unit], non_burst_ref))
            # "rate" = number of entries in the unit's ISI collection / window length
            burst_pre_fr.append(len(sim_isi_strt[unit])/pre_window)
            burst_post_fr.append(len(sim_isi_end[unit])/post_window)
            if OPT_TONIC:
                nonburst_pre_fr.append(len(sim_isi_strt_tonic[unit])/pre_window)
                nonburst_post_fr.append(len(sim_isi_end_tonic[unit])/post_window)
            else:
                nonburst_pre_fr.append(0)
                nonburst_post_fr.append(0)
        else:
            # too few ISIs to compare: sentinel error, zero rates
            burst_error.append(999)
            nonburst_error.append(999)
            burst_pre_fr.append(0)
            nonburst_pre_fr.append(0)
            burst_post_fr.append(0)
            nonburst_post_fr.append(0)

    # Drop NaN distances. The full (filtered) arrays are also saved below.
    burst_error_full = np.array(burst_error)
    burst_error_full = burst_error_full[~np.isnan(burst_error_full)]
    nonburst_error_full = np.array(nonburst_error)
    nonburst_error_full = nonburst_error_full[~np.isnan(nonburst_error_full)]

    # Loss terms: mean of the best-matching 10% of units. The nonburst term
    # (for heterogeneity) is down-weighted by 50 and only used with OPT_TONIC.
    burst_error = _lowest_fraction_mean(burst_error_full)
    nonburst_error = _lowest_fraction_mean(nonburst_error_full) / 50 if OPT_TONIC else 0

    # Summary firing-rate statistics (saved, not part of the loss).
    pre_post_dict = {
        'burst_pre_fr_mean': np.mean(burst_pre_fr),
        'nonburst_pre_fr_mean': np.mean(nonburst_pre_fr),
        'burst_post_fr_mean': np.mean(burst_post_fr),
        'nonburst_post_fr_mean': np.mean(nonburst_post_fr),
        'burst_pre_max': np.max(burst_pre_fr),
        'nonburst_pre_max': np.max(nonburst_pre_fr),
        'burst_post_max': np.max(burst_post_fr),
        'nonburst_post_max': np.max(nonburst_post_fr),
    }

    # The catch22-based error term is currently disabled and contributes 0.
    c22_error = 0

    # Summary statistics saved with the run and used as a NaN check.
    mean_volt = np.mean(volt.v[:500, :]/mV)
    # NOTE(review): despite the name (and the saved 'mean_I' column) this is
    # the standard deviation of d_I; kept for compatibility with saved runs.
    mean_I = np.std(volt.d_I[:500, :]/pA)

    # A NaN summary means the simulation output is unusable: no files are written.
    if np.isnan(mean_volt) or np.isnan(mean_I):
        return 999

    # SNM penalty:
    # fraction of entries in volt.v[:500, :] above -10 mV
    volt_error = np.mean(np.where(volt.v[:500, :]/mV > -10, 1, 0))
    # fraction of entries in volt.d_I[:500, :] above 400 pA, scaled by RUN_TIME
    current_error = np.mean(np.where(volt.d_I[:500, :]/pA > 400, 1, 0)) * RUN_TIME
    snm_error = (volt_error + current_error) * SNM_SF

    # --- plot and save -------------------------------------------------------
    total_error = burst_error + nonburst_error + snm_error
    error_id = str(np.round(total_error, 3)).ljust(5, '0')
    file_tag = f"{error_id}_{int(run_id)}"

    plot_patterns(spikes.spike_trains(), 500)
    savefig(os.path.join(OUTPUT_FOLDER, f"spikes_{file_tag}.png"))
    # Release the saved figure. Otherwise figures pile up across evaluations
    # that run in the same (long-lived) worker process.
    plt.close()

    # Record the exact parameters that were simulated, plus the error terms.
    network_args, synapse_params = parse_params_func(params)
    params_out = {**network_args, **synapse_params}
    if network_args['ECRH_params'] is not None:
        params_out.update(network_args['ECRH_params'])
    params_out['burst_error'] = burst_error
    params_out['nonburst_error'] = nonburst_error
    params_out['mean_volt'] = mean_volt
    params_out['mean_I'] = mean_I
    params_out['snm_error'] = snm_error
    params_out['runtime'] = RUN_TIME
    params_out['max_D_I'] = np.amax(volt.d_I[:500, :]/pA)
    params_out.update(pre_post_dict)
    params_df = pd.DataFrame(params_out, index=[0])
    params_df['connectivity'] = str(params['connectivity'])
    params_df.to_csv(os.path.join(OUTPUT_FOLDER, f"params_{file_tag}.csv"))

    # Keep only the first 10 columns of each recorded state variable to
    # limit file size.
    states = volt.get_states(['v', 'd_I', 'ge', 'gi'])
    for key in states:
        states[key] = states[key][:, :10]

    # Per-unit errors and firing rates travel with the states.
    states['burst_error'] = burst_error_full
    states['nonburst_error'] = nonburst_error_full
    states['burst_pre_fr'] = burst_pre_fr
    states['nonburst_pre_fr'] = nonburst_pre_fr
    states['burst_post_fr'] = burst_post_fr
    states['nonburst_post_fr'] = nonburst_post_fr
    states['c22_error'] = c22_error

    dump([spikes.spike_trains(), states], os.path.join(OUTPUT_FOLDER, f"spikes_{file_tag}.joblib"), compress=('lzma', 9))
    # NOTE(review): the top-level 'network_events' is saved as {}. With the
    # default parse_params, the events actually used are still available as
    # network_args['network_events']. Confirm loaders read that entry.
    dump({'network_args': network_args, 'synapse_params': synapse_params, 'network_events': {}},
         os.path.join(OUTPUT_FOLDER, f"params_{file_tag}.joblib"),
         compress=('lzma', 9))
    return total_error + c22_error




def _ask_with_retry(optimizer, max_retries):
    """Return ``optimizer.ask()``, retrying on ordinary exceptions.

    Only ``Exception`` subclasses are retried, so KeyboardInterrupt and
    SystemExit still stop the run. After ``max_retries`` consecutive failures
    a RuntimeError is raised, chained to the last failure. Pass ``None`` to
    retry indefinitely.
    """
    failures = 0
    while True:
        try:
            return optimizer.ask()
        except Exception as exc:
            failures += 1
            print("failed to ask")
            if max_retries is not None and failures >= max_retries:
                raise RuntimeError(f"optimizer.ask() failed {failures} consecutive times") from exc


def nevergrad(max_ask_retries=100):
    """Run the batched ask / evaluate / tell optimization loop.

    A ``ScrHammersleySearch`` optimizer over ``var_dict`` gets a budget of
    ``workers * rounds`` evaluations. Each iteration does the following:
    asks for ``workers`` candidates, evaluates them in parallel with
    ``ng_error_func`` via joblib, tells the losses back, and dumps the
    optimizer state to ``optimizer.joblib``.

    Parameters
    ----------
    max_ask_retries : int or None
        Consecutive ``optimizer.ask()`` failures tolerated for one candidate
        before raising ``RuntimeError``. ``None`` retries forever.
    """
    optimizer = ng.optimizers.ScrHammersleySearch(parametrization=var_dict, budget=int(workers*rounds), num_workers=workers)
    optimizer.enable_pickling()
    #optimizer = ng.optimizers.ParaPortfolio.load(f"optimizer.joblib")
    n_batches = (optimizer.budget - optimizer.num_ask) // workers
    for _ in range(n_batches):
        candidates = [_ask_with_retry(optimizer, max_ask_retries) for _ in range(workers)]
        losses = Parallel(n_jobs=workers)(delayed(ng_error_func)(cand.value) for cand in candidates)
        for cand, loss in zip(candidates, losses):
            optimizer.tell(cand, loss)
        # checkpoint the optimizer state after every batch
        optimizer.dump("optimizer.joblib")

if __name__ == '__main__':
    # Run the optimization only when executed as a script. Importing the
    # module still loads the reference .npy files at module level.
    nevergrad()