import numpy as np
import pandas as pd
import glob
from utils import *
from spike_train_utils import *
from sNMO.error.spikeTrainErrors import emd_pdist_isi_hist, isi_wasserstein_dd

import os
from joblib import load, Parallel, delayed
import pycatch22



# --- Reference data the simulated units are fitted against ---
burst_ref = np.load("unit_3_burst_match1.npy")
non_burst_ref = np.load("unit_3_nonburst_match.npy")

# --- Run configuration ---
# Compute the pairwise (inter-response) distance between all neurons.
# Slow, but needed for some analyses.
COMPUTE_INTER_RES_DIST = False
# Reuse rows from a previous concat_res.csv instead of recomputing them.
LOAD_PREV = True
# Window length in seconds used only as the default `runtime` of
# compute_features; parse_results always passes an explicit window length.
RUNTIME = 1
# Only results whose error is below this threshold get further processing:
# a large value processes everything, a small one only the good fits.
ERROR_THRES = 9999

out_class = ["tonic", "burst"]

# Analysis windows [start, stop] in seconds. Each bound is shifted by 5 s so
# the network has time to reach steady state.
_STEADY_STATE_DELAY_S = 5
BASELINE_TIMES = [t + _STEADY_STATE_DELAY_S for t in (10, 20)]
BASELINE_TIMES_END = [t + _STEADY_STATE_DELAY_S for t in (35, 45)]

def _spikes_filename(csv_name):
    """Map a result CSV name ``<prefix>_<stem>.<ext>`` to ``spikes_<stem>.joblib``."""
    # Drop everything up to and including the first underscore, then the final
    # extension (names without "_" or "." yield an empty stem, as before).
    _, _, rest = csv_name.partition("_")
    stem = rest.rpartition(".")[0]
    return f"spikes_{stem}.joblib"


def _window_features(spikes, volt, window, unit_list=None):
    """Run ``compute_features`` on the ``[window[0], window[1]]`` (seconds) slice of one simulation."""
    start_s, stop_s = window
    # NOTE(review): assumes the recorded traces are sampled every 0.1 ms -- confirm against the simulation dt.
    start_idx = int(start_s * second / (0.1 * ms))
    stop_idx = int(stop_s * second / (0.1 * ms))
    errors = {
        'burst_error': volt['burst_error'] if 'burst_error' in volt.keys() else None,
        'nonburst_error': volt['nonburst_error'] if 'nonburst_error' in volt.keys() else None,
    }
    isis = build_isi_from_spike_train(spikes, low_cut=start_s, high_cut=stop_s, indiv=True)
    isi_dict = dict(enumerate(isis))
    current = np.array(volt['d_I'] / pA)[start_idx:stop_idx]
    voltage = np.array(volt['v'] / mV)[start_idx:stop_idx]
    return compute_features(isi_dict, voltage, current, errors,
                            runtime=(stop_s - start_s) * second, unit_list=unit_list)


def parse_results(id, folder):
    """Compute baseline and end-window features for one simulation result.

    Parameters
    ----------
    id : str
        Basename of the result CSV; the matching ``spikes_<stem>.joblib`` file
        is loaded from ``folder``.
    folder : str
        Directory containing the joblib file.

    Returns
    -------
    dict
        Features of the ``BASELINE_TIMES`` window prefixed ``baseline_`` and of
        the ``BASELINE_TIMES_END`` window prefixed ``end_``. The end window
        reuses the baseline ``burst_unit_list`` so the same units are compared.
        An empty dict if the joblib file cannot be loaded.
    """
    spikes_name = _spikes_filename(id)
    try:
        spikes, volt = load(os.path.join(folder, spikes_name))
    except Exception as err:  # best-effort: a missing/corrupt file just skips this result
        print(f"Failed to load {spikes_name}: {err}")
        return {}

    baseline = _window_features(spikes, volt, BASELINE_TIMES)
    end = _window_features(spikes, volt, BASELINE_TIMES_END,
                           unit_list=baseline.get('burst_unit_list'))

    features = {f"baseline_{k}": v for k, v in baseline.items()}
    features.update({f"end_{k}": v for k, v in end.items()})
    return features



def _nan_if_empty(reducer, values):
    """Return ``reducer(values)``, or NaN when ``values`` is empty.

    NumPy's min/max reductions (including the ``nan*`` variants) raise
    ``ValueError`` on zero-size input instead of returning NaN.
    """
    return reducer(values) if len(values) > 0 else np.nan


def _summary_stats(suffix, values, nan_aware=True):
    """Return mean/std/min/max of ``values`` keyed ``{stat}_{suffix}``.

    ``nan_aware`` selects the ``np.nan*`` reducers (NaNs ignored) over the plain
    ones. Empty ``values`` yield NaN for every statistic instead of raising.
    """
    if nan_aware:
        reducers = (np.nanmean, np.nanstd, np.nanmin, np.nanmax)
    else:
        reducers = (np.mean, np.std, np.min, np.max)
    return {f"{stat}_{suffix}": _nan_if_empty(fn, values)
            for stat, fn in zip(("mean", "std", "min", "max"), reducers)}


def _unit_state_features(spikes_isi, units_to_inspect=(0, 1, 2, 3, 10, 13)):
    """Classify selected units as tonic or bursting from ``compute_states``.

    Returns two dicts keyed ``unit_{i}_class`` (an entry of ``out_class``) and
    ``unit_{i}_mean_state``. Keys are built per unit, so a unit missing from
    ``spikes_isi`` is simply absent instead of shifting the others' labels.
    """
    bins = np.arange(0, 60, 3) * 1000
    class_dict = {}
    mean_state_dict = {}
    for i, isi in spikes_isi.items():
        if i not in units_to_inspect:
            continue
        states, _binned_labels, _binned_tonic = compute_states(isi, bins=bins)
        # State code 1 counts as tonic, 3 as burst; 0 is excluded from the mean.
        mean_state = np.mean(states[states != 0])
        tonic_frac = np.sum(states == 1) / len(states)
        burst_frac = np.sum(states == 3) / len(states)
        class_dict[f"unit_{i}_class"] = out_class[np.argmax([tonic_frac, burst_frac])]
        mean_state_dict[f"unit_{i}_mean_state"] = mean_state
    return class_dict, mean_state_dict


def compute_features(spikes_isi, volt, current, burst_error, runtime=RUNTIME*second, unit_list=None):
    """Summarise one simulation window into a flat dict of features.

    Parameters
    ----------
    spikes_isi : dict
        Unit index -> that unit's interval array (as built by ``parse_results``
        from ``build_isi_from_spike_train``). Indices below
        ``len(spikes_isi) // 2`` are treated as CRH units, the rest as GABA units.
    volt : array-like
        Voltage trace(s) of the window; only its NaN-mean is used.
    current : np.ndarray
        Current trace(s) of the window (divided by pA in ``parse_results``).
        NOTE(review): the layout (time x units vs units x time) is not visible
        here -- confirm it matches the time slicing done by the caller.
    burst_error : dict
        Currently unused: burst / non-burst errors are recomputed here against
        ``burst_ref`` / ``non_burst_ref``. Kept for interface compatibility.
    runtime : Quantity
        Window duration, used to turn spike counts into firing rates.
    unit_list : array-like of int, optional
        Units analysed as "bursters". Defaults to the 10 % of CRH units with
        the lowest burst error.

    Returns
    -------
    dict
        Feature name -> value; some values are arrays (histograms, the burster
        unit list).
    """
    # === Current / voltage features ===
    above = current[current > 0]
    below = current[current < 0]
    mean_current_above = np.mean(above)
    mean_current_below = np.mean(below)
    # Guard each sign separately: a trace that never goes positive (or never
    # negative) would otherwise hit nanmax on an empty array and raise.
    max_current_above = _nan_if_empty(np.nanmax, above)
    max_current_below = _nan_if_empty(np.nanmax, -below)
    mean_volt = np.nanmean(volt)
    # Percentage of samples above 250 (pA); `.size` also handles 1-D traces.
    time_spent_above = np.sum(current > 250) / current.size * 100

    class_dict, mean_state_dict = _unit_state_features(spikes_isi)

    # === Burstiness and firing rate ===
    n_crh = len(spikes_isi) // 2
    firing_rates, bursts, spikes_per_burst, cv2_values = [], [], [], []
    isi_hists, burst_errors, nonburst_errors = [], [], []
    gaba_fr, gaba_cv2 = [], []
    hist_bins = np.logspace(0, 4, 20)

    for i, sp in spikes_isi.items():
        if i >= n_crh:
            # GABA units: only firing rate and CV2 are tracked.
            gaba_fr.append(len(sp) / runtime)
            gaba_cv2.append(ecv2(sp / ms))
            continue
        if len(sp) > 2:
            isi = sp
            firing_rates.append(len(sp) / runtime)
            labeled_isi, _bursts, _non_burst = filter_bursts(isi, sil=25)
            bursts.append(len(_bursts) / len(labeled_isi))
            spikes_per_burst.append(np.nanmean([len(_b) for _b in _bursts]))
            isi_hists.append(np.histogram(isi, bins=hist_bins)[0])
            cv2_values.append(ecv2(isi / ms))
            # Distance of this unit's intervals to the burst / non-burst references.
            burst_errors.append(isi_wasserstein_dd(isi, burst_ref, hist=False))
            nonburst_errors.append(isi_wasserstein_dd(isi, non_burst_ref, hist=False))
        else:
            # Too few spikes for interval statistics.
            firing_rates.append(0)
            bursts.append(0)
            isi_hists.append(np.zeros(len(hist_bins) - 1))
            spikes_per_burst.append(np.nan)
            burst_errors.append(np.nan)
            nonburst_errors.append(np.nan)
            cv2_values.append(np.nan)

    # Every CRH unit must contribute exactly one entry to each per-unit list.
    for label, values in (("Firing rates", firing_rates), ("Bursts", bursts),
                          ("ISI hists", isi_hists), ("Spikes per burst", spikes_per_burst),
                          ("Burst errors", burst_errors), ("Non-burst errors", nonburst_errors),
                          ("CV2 values", cv2_values)):
        assert len(values) == n_crh, f"{label} and spikes_isi length mismatch"

    min_burst_error = _nan_if_empty(np.nanmin, burst_errors)
    max_burst_error = _nan_if_empty(np.nanmax, burst_errors)
    min_nonburst_error = _nan_if_empty(np.nanmin, nonburst_errors)
    max_nonburst_error = _nan_if_empty(np.nanmax, nonburst_errors)

    # === Burster analysis (after Ichiyama et al. 2022) ===
    # np.argsort puts NaN last, so units with too few spikes are never picked.
    if unit_list is None:
        burst_idx = np.argsort(burst_errors)[:int(len(burst_errors) * 0.1)]
    else:
        burst_idx = unit_list

    ibi_hist_full, prob_burst_full, events_per_burst_full = [], [], []
    cv2_bursters, fr_bursts = [], []
    for i in burst_idx:
        isi = spikes_isi[i]
        ibi_hist, _ibi_raw = inter_burst_hist(isi)
        ibi_hist_full.append(ibi_hist)
        prob_burst_full.append(prob_burst_per_isi(isi))
        events_len_times, _event_len_uni = preceeding_sil_per_event_len(isi)
        events_per_burst_full.append(np.array([np.mean(_e) for _e in events_len_times]))
        cv2_bursters.append(ecv2(isi / ms))
        fr_bursts.append(len(isi) / runtime)

    # catch22 features are currently disabled (pycatch22.catch22_all was stubbed
    # out); the placeholder has no names, so it contributes no columns.
    catch_22 = {'names': '', 'values': [0]}
    ranked = np.argsort(burst_errors)
    catch22_units = ranked[ranked < 10][:1]
    catch_22_bursters = [catch_22['values'] for _ in catch22_units]
    catch_22_features = {}
    if catch_22_bursters:
        mean_catch22 = np.nanmean(np.vstack(catch_22_bursters), axis=0)
        catch_22_features = {name: mean_catch22[k] for k, name in enumerate(catch_22['names'])}

    if not events_per_burst_full:
        # No bursters: fall back to single zero rows of the placeholder widths.
        events_per_burst_full = np.zeros((1, 20))
        ibi_hist_full = np.zeros((1, 19))
        prob_burst_full = np.zeros((1, 19))
    mean_events_per_burst = np.nanmean(np.vstack(events_per_burst_full), axis=0)
    mean_ibi_hist = np.nanmean(np.vstack(ibi_hist_full), axis=0)
    mean_prob_burst = np.nanmean(np.vstack(prob_burst_full), axis=0)

    mean_isi_hist = np.nanmean(isi_hists, axis=0)

    res = {
        "mean_volt": mean_volt,
        "mean_firing_rate": np.nanmean(firing_rates),
        "mean_burstiness": np.nanmean(bursts),
        "mean_spikes_per_burst": np.nanmean(spikes_per_burst),
        "mean_current_above": mean_current_above,
        "mean_current_below": mean_current_below,
        "max_current_above": max_current_above,
        "max_current_below": max_current_below,
        "time_spent_above": time_spent_above,
        "min_burst_error": min_burst_error,
        "max_burst_error": max_burst_error,
        "min_nonburst_error": min_nonburst_error,
        "max_nonburst_error": max_nonburst_error,
        "mean_isi_hist": mean_isi_hist,
        "mean_ibi_hist": mean_ibi_hist,
        "mean_prob_burst": mean_prob_burst,
        "mean_events_per_burst": mean_events_per_burst,
        **_summary_stats("cv2_overall", cv2_values),
        **_summary_stats("cv2_bursters", cv2_bursters),
        **_summary_stats("fr_bursters", fr_bursts),
        'burst_unit_list': burst_idx,
        'burst_unit_list_len': len(burst_idx),
        **_summary_stats("gaba_fr", gaba_fr, nan_aware=False),
        **_summary_stats("gaba_cv2", gaba_cv2, nan_aware=False),
    }

    if COMPUTE_INTER_RES_DIST:
        # Pairwise EMD between all CRH ISI histograms (slow). `> 0` drops zero
        # entries (presumably self-distances) and NaNs (NaN comparisons are False).
        emd_dist = np.ravel(emd_pdist_isi_hist(isi_hists))
        emd_dist = emd_dist[emd_dist > 0]
        res.update(_summary_stats("emd_dist", emd_dist, nan_aware=False))

    res.update(class_dict)
    res.update(mean_state_dict)
    res.update(catch_22_features)
    return res

def file_load(x, res_folder="", further_processing=True, prev_res=None):
    """Load one result CSV into a DataFrame, optionally adding spike-train features.

    Parameters
    ----------
    x : str
        Path to the result CSV.
    res_folder : str
        Unused: the matching spikes_*.joblib is looked up next to the CSV
        itself. Kept for interface compatibility.
    further_processing : bool
        If True and the error encoded in the file name is below
        ``ERROR_THRES``, add the features from ``parse_results`` as columns.
    prev_res : pd.DataFrame, optional
        A previous ``concat_res.csv`` indexed by file basename. When
        ``LOAD_PREV`` is set, a cached row is returned instead of recomputing.

    Returns
    -------
    pd.DataFrame
        The row(s) for this file with ``id`` (basename) and ``path`` columns;
        an empty frame for aggregated "concat" files.
    """
    csv_path = x
    csv_name = os.path.basename(csv_path)

    # Aggregated output files (e.g. concat_res.csv) are not simulation results.
    if "concat" in csv_path:
        return pd.DataFrame()

    if LOAD_PREV and prev_res is not None and csv_name in prev_res.index:
        # `.loc[[name]]` always returns a DataFrame (keeping column dtypes);
        # keep only the first row so duplicate basenames neither crash nor
        # multiply rows from run to run.
        temp = prev_res.loc[[csv_name]].iloc[[0]].copy()
        temp['id'] = csv_name
        temp['path'] = csv_path
        return temp

    temp = pd.read_csv(csv_path, index_col=0)
    temp['id'] = csv_name
    temp['path'] = csv_path

    # The fit error is the second-to-last "_"-separated token of the path,
    # optionally wrapped in square brackets.
    error_rate = float(csv_path.split("_")[-2].strip("[]"))

    if further_processing and error_rate < ERROR_THRES:
        further_res = parse_results(csv_name, folder=os.path.dirname(csv_path))
        flat = {}
        for key, value in further_res.items():
            if isinstance(value, np.ndarray) and value.ndim > 0:
                # Spread array-valued features over one column per element.
                flat.update({f"{key}_{i}": v for i, v in enumerate(value)})
            else:
                flat[key] = value
        temp = temp.assign(**flat)
    return temp

def load_results(res_folder="/media/smestern/sgbackup/aoi_paper_2/", n_jobs=12):
    """Aggregate every result CSV under ``res_folder`` into ``concat_res.csv``.

    Parameters
    ----------
    res_folder : str
        Root folder searched recursively for ``*.csv`` result files; the
        aggregated ``concat_res.csv`` is read from (when ``LOAD_PREV``) and
        written to this folder.
    n_jobs : int
        Number of joblib workers used to load the files.

    Returns
    -------
    pd.DataFrame or None
        The aggregated frame indexed by file basename, or None when no result
        files were found (nothing is written in that case).
    """
    csv_files = glob.glob(os.path.join(res_folder, "**", "*.csv"), recursive=True)
    concat_path = os.path.join(res_folder, "concat_res.csv")

    # Reuse previous results when available; on the first run there is no
    # concat_res.csv yet, so fall back to computing everything.
    prev_res = None
    if LOAD_PREV and os.path.isfile(concat_path):
        prev_res = pd.read_csv(concat_path, index_col=0)

    results = Parallel(n_jobs=n_jobs, verbose=5)(
        delayed(file_load)(path, res_folder=res_folder, further_processing=True, prev_res=prev_res)
        for path in csv_files
    )
    frames = [frame for frame in results if not frame.empty]
    if not frames:
        print(f"No result files found under {res_folder}")
        return None

    df = pd.concat(frames, axis=0, ignore_index=True)
    df = df.dropna(axis=1, how='all')  # drop all-NaN columns
    df.set_index("id", inplace=True)
    df.to_csv(concat_path)
    return df


if __name__ == "__main__":
    load_results()