import mytools
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
from pylab import *
from os.path import exists
from matplotlib.collections import PatchCollection

def format3(x):
    """Render *x* with three decimals, then drop trailing zeros and a dangling '.'."""
    text = "{:.3f}".format(x)
    text = text.rstrip("0")
    return text.rstrip(".")
def format4(x):
    """Render *x* with four decimals, then drop trailing zeros and a dangling '.'."""
    text = "{:.4f}".format(x)
    text = text.rstrip("0")
    return text.rstrip(".")


def hex_to_RGB(hex_color):
    """Convert an 'RRGGBB' / '#RRGGBB' string into an (r, g, b) tuple of ints 0-255.

    Only the first six hex digits after any leading '#' characters are read.
    """
    digits = hex_color.lstrip('#')
    channels = []
    for offset in range(0, 6, 2):
        channels.append(int(digits[offset:offset + 2], 16))
    return tuple(channels)

def RGB_to_hex(rgb_color):
    """Convert the first three channels of *rgb_color* into an upper-case '#RRGGBB' string.

    Each channel is passed through int(), so float channels are truncated toward zero.
    """
    channels = (int(rgb_color[k]) for k in range(3))
    return "#" + "".join("{:02x}".format(c) for c in channels).upper()

def generate_gradient(start_hex, end_hex, steps):
    """Return *steps* hex colours linearly interpolated from *start_hex* to *end_hex*.

    The first entry is *start_hex*; when steps > 1 the last entry is *end_hex*.
    With steps == 1 only the start colour is returned; with steps <= 0 the list is empty.
    """
    start_rgb = hex_to_RGB(start_hex)
    end_rgb = hex_to_RGB(end_hex)
    colours = []
    for t in range(steps):
        frac = 0 if steps <= 1 else t / (steps - 1)
        mixed = [lo + frac * (hi - lo) for lo, hi in zip(start_rgb, end_rgb)]
        colours.append(RGB_to_hex(mixed))
    return colours


def detect_bursts_population(spike_times, isi_threshold):
    """
    Detect bursts in a population spike train using an ISI threshold.

    A spike belongs to a burst when its inter-spike interval to the previous
    spike is <= ``isi_threshold``; consecutive such spikes, together with the
    spike that precedes the first of them, form one burst.

    Parameters
    ----------
    spike_times : array-like
        Sorted spike times (float). Unsorted input yields negative ISIs,
        which are counted as "within threshold".
    isi_threshold : float
        Maximum ISI to consider spikes part of the same burst.

    Returns
    -------
    bursts : list of numpy arrays
        One array of spike times per burst, in temporal order. Empty list when
        fewer than two spikes are given or no ISI is within the threshold.
    """
    spike_times = np.asarray(spike_times)

    # Always return a plain list (previously a ([], mask) tuple was returned
    # here, inconsistent with the normal return value).
    if spike_times.size < 2:
        return []

    n_spikes = spike_times.shape[0]
    # in_burst[k] is True when spike k follows spike k-1 within the threshold.
    in_burst = np.zeros(n_spikes, dtype=bool)
    in_burst[1:] = np.diff(spike_times) <= isi_threshold

    bursts = []
    i = 0
    while i < n_spikes:
        if not in_burst[i]:
            i += 1
            continue
        # The burst starts at the spike preceding the first short ISI ...
        start = i - 1
        j = i
        # ... and extends over every following spike that is still close.
        while j < n_spikes and in_burst[j]:
            j += 1
        bursts.append(spike_times[start:j].copy())
        i = j
    return bursts


import sys

# Simulation parameters. They are encoded in the result file names below, so
# they must match the values used when synfirefiles/*.mat were generated.
N = 50
Lmax = 70
stim_duration = 12.0
tot_duration = 16.0
stim_freq_Hz = 2.0
g_EE = 12.0
g_EI = 20.0
g_IE = 20.0

# Grid of STDP amplitudes: A_minus = A_plus * factor.
A_pluss = [0.02, 0.022, 0.024, 0.026, 0.027, 0.028, 0.029, 0.03, 0.032, 0.035, 0.04, 0.045]
A_minus_factors = [-0.7, -0.8, -0.9, -1.0, -1.1, -1.2, -1.5, -2.0, -3.0, -4.0]
istim_to_plot = [24, 31]

# One entry per analysed grid point (same index across all lists).
A_plus_saved = []
A_minus_saved = []
mean_spike_t_saved = []
std_spike_t_saved = []
Nspikes_saved = []
weights_saved = []

# Per grid point: number of seeds satisfying each criterion.
N_is_longbursts_saved = []
N_is_sparsebursts_saved = []
N_is_24_OKtimed_saved = []
N_is_1_OKtimed_after_12s_saved = []
N_is_3_OKtimed_after_12s_saved = []
N_is_8_OKtimed_after_12s_saved = []
N_is_3_bursts_after_12s_saved = []
N_is_8_bursts_after_12s_saved = []
N_is_spikes_after_12s_saved = []

# Optional CLI overrides: argv[1] = burst ISI threshold / timing tolerance DT,
# argv[2] = burst duration above which a burst counts as "long".
DT = 0.05
longBurstLen = 0.075
if len(sys.argv) > 1:
    DT = float(sys.argv[1])
if len(sys.argv) > 2:
    longBurstLen = float(sys.argv[2])

# Expected interval between consecutive bursts, and the time after which
# bursts are counted as "after 12 s" (24 * 0.5 s; presumably the end of the
# stimulation at stim_freq_Hz = 2 Hz -- TODO confirm).
STIM_PERIOD = 0.5
T_AFTER = 24*STIM_PERIOD
N_SEEDS = 20

for A_plus_nS in A_pluss:
    w_init_nS = A_plus_nS
    for A_minus_factor in A_minus_factors:
        A_minus_nS = A_plus_nS*A_minus_factor

        # Per-seed outcomes for this (A_plus, A_minus) grid point.
        spikes_all = []
        weights_all = []
        is_longbursts = []
        is_sparsebursts = []
        is_24_OKtimed = []
        is_1_OKtimed_after_12s = []
        is_3_OKtimed_after_12s = []
        is_8_OKtimed_after_12s = []
        is_3_bursts_after_12s = []
        is_8_bursts_after_12s = []
        is_spikes_after_12s = []
        n_seeds_loaded = 0

        for myseed in range(1, N_SEEDS + 1):
            seed_addition = '' if myseed == 1 else '_seed' + str(myseed)
            filename_body = (f"stdpsynfire_synstim_N{N}_L{Lmax}_T{stim_duration}_{tot_duration}_"
                             f"{stim_freq_Hz}Hz_gEE{g_EE}_gEI{g_EI}_gIE{g_IE}_"
                             f"A{format4(A_plus_nS)}_{format4(A_minus_nS)}_{format4(w_init_nS)}")
            mat_path = "synfirefiles/" + filename_body + seed_addition + ".mat"
            if not exists(mat_path):
                continue
            try:
                A = scipy.io.loadmat(mat_path)
            except Exception as err:  # corrupt/truncated file: report and skip this seed
                print("Could not load " + mat_path + ", continuing")
                print("  reason: " + repr(err))
                continue

            # Saved as [[spike_monitors[i].t/second, ...], [spike_monitors[i].i, ...]]
            spikes_t = A['spikes'][0]
            spikes_i = A['spikes'][1]
            # Saved as [weight_monitors[0].t/second, [mean weights per monitor, ...]]
            weights = A['weights']
            if max(weights[0][1][0]) < 1e-4:
                print(mat_path + " is not in pS!")
                weights[0][1][0] = weights[0][1][0]*1e12

            spikes_all.append([spikes_t[:], spikes_i[:]])
            weights_all.append(weights[:])
            n_seeds_loaded += 1

            # Spike times of the first spike monitor (the second [0] presumably
            # unwraps loadmat's 2-D row vector -- TODO confirm).
            spike_times = spikes_t[0][0]
            bursts = detect_bursts_population(spike_times, DT)
            burstlens = [x[-1] - x[0] for x in bursts]

            # Guards below avoid max()/min() on empty sequences for silent runs.
            is_spikes_after_12s.append(len(spike_times) > 0 and max(spike_times) > 12.6)
            is_longbursts.append(len(burstlens) > 0 and max(burstlens) > longBurstLen)
            # List (not generator) so this is correct whether any() is the
            # builtin or numpy's version from the pylab star import.
            is_sparsebursts.append(any([len(x) < 5 for x in bursts if x[0] < T_AFTER]))

            NburstsOK = 0
            burstsOK_after_12s = []
            for iburst in range(1, len(bursts)):
                prev_end = bursts[iburst-1][-1]
                this_burst = bursts[iburst]
                # OK-timed: the whole burst falls within +-DT of STIM_PERIOD
                # after the previous burst ended, and lasts at most DT.
                ok = (prev_end + STIM_PERIOD - DT <= this_burst[0] <= this_burst[-1] <= prev_end + STIM_PERIOD + DT
                      and burstlens[iburst] <= DT)
                if ok:
                    NburstsOK = NburstsOK + 1
                if this_burst[0] >= T_AFTER:
                    burstsOK_after_12s.append(ok)

            # Short-circuit guarantees bursts[23] exists (NburstsOK >= 24 implies >= 25 bursts).
            is_24_OKtimed.append(NburstsOK >= 24 and STIM_PERIOD*23 - DT <= bursts[23][0] <= bursts[23][-1] <= STIM_PERIOD*23 + DT)
            # "first k bursts after 12 s are all OK-timed"; the 3-burst check
            # previously only required the first three to be *equal*, which
            # also accepted three mis-timed bursts.
            is_1_OKtimed_after_12s.append(sum(burstsOK_after_12s[:1]) == 1)
            is_3_OKtimed_after_12s.append(sum(burstsOK_after_12s[:3]) == 3)
            is_8_OKtimed_after_12s.append(sum(burstsOK_after_12s[:8]) == 8)
            n_bursts_after_12s = len([x for x in bursts if x[0] >= T_AFTER])
            is_3_bursts_after_12s.append(n_bursts_after_12s > 2)
            is_8_bursts_after_12s.append(n_bursts_after_12s > 7)

        # Record only grid points with at least one analysed seed, so that a
        # point without results is absent (the plotting lookup then finds no
        # entry for it) instead of appearing as "0 of N seeds succeeded".
        if n_seeds_loaded == 0:
            continue

        A_plus_saved.append(A_plus_nS)
        A_minus_saved.append(A_minus_nS)

        N_is_spikes_after_12s_saved.append(int(sum(is_spikes_after_12s)))
        N_is_longbursts_saved.append(int(sum(is_longbursts)))
        N_is_sparsebursts_saved.append(int(sum(is_sparsebursts)))
        N_is_24_OKtimed_saved.append(int(sum(is_24_OKtimed)))
        N_is_1_OKtimed_after_12s_saved.append(int(sum(is_1_OKtimed_after_12s)))
        N_is_3_OKtimed_after_12s_saved.append(int(sum(is_3_OKtimed_after_12s)))
        N_is_8_OKtimed_after_12s_saved.append(int(sum(is_8_OKtimed_after_12s)))
        N_is_3_bursts_after_12s_saved.append(int(sum(is_3_bursts_after_12s)))
        N_is_8_bursts_after_12s_saved.append(int(sum(is_8_bursts_after_12s)))

f, axarr = plt.subplots(1, 1)

# Colour for a grid cell indexed by the number of seeds (0..20) meeting the
# "first 3 bursts after 12 s OK-timed" criterion: dark blue (0) -> white (20).
cols = generate_gradient('#0000CC', '#FFFFFF', 21)
stuffPlotted = []  # [[cell x-coords, cell y-coords], facecolor, hatch] per cell
for iA_plus, A_plus_nS in enumerate(A_pluss):
    for ifA, A_minus_factor in enumerate(A_minus_factors):
        # Same expression as in the analysis loop, so the float keys match exactly.
        A_minus_nS = A_plus_nS*A_minus_factor

        # Row labels along the left edge. Emitted before the data lookup so a
        # missing cell in the first column does not drop its label.
        if iA_plus == 0:
            axarr.text(-0.05, ifA+0.5, str(A_minus_factor), fontsize=5, rotation=0, va='center', ha='right', clip_on=False)

        # Unit square: x spans the A_plus index, y the A_minus-factor index.
        cell_xy = [[iA_plus, iA_plus, iA_plus+1, iA_plus+1], [ifA, ifA+1, ifA+1, ifA]]

        ind = [i for i in range(len(N_is_longbursts_saved))
               if A_plus_saved[i] == A_plus_nS and A_minus_saved[i] == A_minus_nS]
        if len(ind) > 1:
            raise RuntimeError("Duplicate results for A_plus=%r nS, A_minus=%r nS" % (A_plus_nS, A_minus_nS))

        if ind:
            # int() so numpy scalar counts can be used as list indices.
            N_is_longbursts = int(N_is_longbursts_saved[ind[0]])
            N_is_3_OKtimed_after_12s = int(N_is_3_OKtimed_after_12s_saved[ind[0]])
            facecolor = cols[N_is_3_OKtimed_after_12s]
            # Hatch density grows with the number of seeds showing long bursts;
            # cells where 20 or more seeds succeed are left unhatched.
            hatch = '/'*(N_is_longbursts // 4) if N_is_3_OKtimed_after_12s < 20 else None
        else:
            facecolor = '#000000'  # no results for this grid point
            hatch = None

        p = PatchCollection([Polygon(np.array(cell_xy).T)])
        p.set_facecolor(facecolor)
        p.set_edgecolor(None)
        if hatch is not None:
            p.set_hatch(hatch)
        axarr.add_collection(p)
        stuffPlotted.append([cell_xy, facecolor, hatch or ''])

    # Column labels along the bottom edge.
    axarr.text(iA_plus+0.5, -0.05, str(A_plus_nS)+' nS', fontsize=5, va='top', ha='center', clip_on=False)

axarr.set_xlim([0, len(A_pluss)])
axarr.set_ylim([0, len(A_minus_factors)])
axarr.set_xticks([])
axarr.set_yticks([])

out_base = "fig_stdpsynfire_gridsearch3_"+str(DT)+"_"+str(longBurstLen)
f.savefig(out_base+".pdf")
scipy.io.savemat(out_base+".mat", {'stuffPlotted': stuffPlotted, 'A_pluss': A_pluss, 'A_minus_factors': A_minus_factors})