from inputs import *

import matplotlib.pyplot as plt

import time

import os


# Global plot styling. The family is sans-serif, so the preferred face has to
# go into the "font.sans-serif" list (writing it to "font.serif" has no effect
# on a sans-serif family). Ubuntu is prepended so matplotlib falls back to its
# default sans-serif faces when Ubuntu is not installed.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = ["Ubuntu"] + list(plt.rcParams["font.sans-serif"])
plt.rcParams["font.size"] = 13




def bin_spikes(spikes, tmax, dt):
    """Count spikes in consecutive bins of width ``dt``.

    Bin edges are ``np.arange(0, tmax + dt, dt)``; bin ``i`` counts the spikes
    ``t`` with ``edges[i] <= t < edges[i + 1]``. Spikes before the first edge
    or at/after the last edge are not counted.

    Parameters
    ----------
    spikes : array_like
        Spike times. Must be sorted ascending (required by ``np.searchsorted``).
    tmax : float
        End of the binned interval.
    dt : float
        Bin width.

    Returns
    -------
    ndarray of float
        Spike count per bin, ``len(edges) - 1`` entries.
    """
    edges = np.arange(0, tmax + dt, dt)
    # For every edge, how many spikes lie strictly before it; the difference
    # between neighbouring edges is the count inside that bin.
    n_before = np.searchsorted(spikes, edges)
    counts = n_before[1:] - n_before[:-1]
    return counts.astype(float)


def xcorr(x, y, s_range, dt, verbose=True):
    """Circular cross-covariance of two equally long signals over a lag window.

    For every lag ``s`` in ``range(-s_range, s_range)`` this computes
    ``mean(x * np.roll(y, s))`` and subtracts ``mean(x) * mean(y)``. Note the
    window is asymmetric: ``2 * s_range`` lags, ``-s_range`` included and
    ``+s_range`` excluded. ``np.roll`` wraps around, so the correlation is
    circular.

    Parameters
    ----------
    x, y : array_like
        Signals of equal length (e.g. binned spike trains).
    s_range : int
        Half-width of the lag window, in samples.
    dt : float
        Sample spacing, used to integrate the cross-covariance.
    verbose : bool, optional
        When True (default, the previous behaviour) print ``mean(x)`` and the
        time spent computing the lags.

    Returns
    -------
    ccvf : ndarray, shape (2 * s_range,)
        Cross-covariance; index ``s + s_range`` holds lag ``s``.
    lm : float
        ``sum(ccvf) * dt / mean(x)``. ``mean(x) == 0`` gives inf/nan
        (NumPy division by zero).

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in length.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if len(x) != len(y):
        raise ValueError('xcorr: x and y must have the same length, got %d and %d'
                         % (len(x), len(y)))

    # Each mean is taken over its own signal.
    x_mean = np.mean(x)
    y_mean = np.mean(y)

    if verbose:
        print(x_mean)

    t0 = time.time()

    # Raw (non mean-subtracted) circular cross-correlation at each lag.
    ccf = np.array([np.mean(x * np.roll(y, s)) for s in range(-s_range, s_range)])

    if verbose:
        print('time ' + str(time.time() - t0))

    ccvf = ccf - x_mean * y_mean

    # Integrate the cross-covariance over the window, normalised by mean(x).
    lm = np.sum(ccvf) * dt / x_mean

    return ccvf, lm


def xcorr0(x, y, dt):
    """Zero-lag cross-covariance of ``x`` and ``y``, scaled by ``dt / mean(x)``.

    Computes ``(mean(x * y) - mean(x) * mean(y)) * dt / mean(x)``. This is the
    lag-0 term of the sum that :func:`xcorr` reports as ``lm``.

    Parameters
    ----------
    x, y : array_like
        Signals of equal length.
    dt : float
        Sample spacing.

    Returns
    -------
    float
        The scaled zero-lag cross-covariance. ``mean(x) == 0`` gives inf/nan
        (NumPy division by zero).

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in length.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if len(x) != len(y):
        raise ValueError('xcorr0: x and y must have the same length, got %d and %d'
                         % (len(x), len(y)))

    # Each mean is taken over its own signal.
    x_mean = np.mean(x)
    y_mean = np.mean(y)

    ccf0 = np.mean(x * y)

    return (ccf0 - x_mean * y_mean) * dt / x_mean







# ---------------------------------------------------------------------------
# Parameters
# ---------------------------------------------------------------------------

n_cmp = 2    # passed to the spike-train generator (presumably number of compartments)
ns_syn = 2   # passed to the spike-train generator (presumably synapses per compartment)

rate = 20. # Hz
t_sim = 500.

# NOTE(review): corr_glob and corr_loc are not read anywhere below; the sweep
# passes 1. and `corr` to the generator directly.
corr_glob = 1.
corr_loc = 1.
jtr = 0.     # rebound by the jitter loop below

dt = 0.0005  # bin width for spike binning and cross-correlation

t_ref = 0.002

# round() before int(): a float ratio that should be an integer can land just
# below it (e.g. 0.3 / 0.1 == 2.9999999999999996) and int() alone truncates.
nt = int(round(t_sim / dt))   # number of dt steps in t_sim

# Half-width of the cross-correlation lag window, in bins (t_ref / dt).
s_range = int(round(t_ref / dt))

tot_corrA = []    # flat rows [corr, jtr, lambda], one per (jitter, corr) pair
tot_corr_B = []   # one list per jitter value, holding lambda for each corr

jtr_list = np.arange(0., 0.006, 0.001)
corr_list = np.arange(0., 1.1, 0.1)

mode = ''   # set to 'detailed' to plot the binned trains and cross-covariance of each run

# Spikes dropped from each end of a train before binning. Must be >= 1:
# with 0 the slice [0:-0] would be empty.
margin_cut = 1

# ---------------------------------------------------------------------------
# Sweep: for every jitter and shared-spike ratio, generate trains and measure
# the integrated cross-covariance (lambda).
# ---------------------------------------------------------------------------

for jtr in jtr_list:

    tot_corr_C = []

    for corr in corr_list:

        t0 = time.time()
        spt = spike_trains_hierarch_ind_global(n_cmp, ns_syn, rate, t_sim, 1., corr, jtr)
        print('time generation: ' + str(time.time() - t0))

        # Compares train spt[1][0] with spt[0][1] (indexing presumably
        # [compartment][synapse] -- confirm against the generator in `inputs`).
        # Division by dt turns counts per bin into counts per unit time.
        binned_1 = bin_spikes(spt[1][0][margin_cut:-margin_cut], t_sim, dt)/dt
        binned_2 = bin_spikes(spt[0][1][margin_cut:-margin_cut], t_sim, dt)/dt

        t0 = time.time()
        xcorrT, lm = xcorr(binned_1, binned_2, s_range, dt)
        print('time cross correlation: ' + str(time.time() - t0))

        print('lambda ' + str(lm))

        tot_corrA.append([corr, jtr, lm])
        tot_corr_C.append(lm)

        if mode == 'detailed':

            plt.figure()
            plt.plot(binned_1)
            plt.plot(binned_2)
            plt.show()

            plt.figure()
            plt.plot(xcorrT)
            plt.show()

    tot_corr_B.append(tot_corr_C)







# for jtr in jtr_list:
#
#     tot_corr_C = []
#
#     for corr in corr_list:
#
#         t0 = time.time()
#
#         spt = spike_trains_hierarch(n_cmp, ns_syn, rate, t_sim, corr, corr, jtr)
#
#         print('time generation: ' + str(time.time() - t0))
#
#         margin_cut = 1
#
#         # division by dt is for scaling
#
#         binned_1 = bin_spikes(spt[0][0][margin_cut:-margin_cut], t_sim, dt)/dt
#         binned_2 = bin_spikes(spt[0][1][margin_cut:-margin_cut], t_sim, dt)/dt
#
#         t0 = time.time()
#
#         xcorr0T = xcorr0(binned_1, binned_2, dt)
#
#         tot_corrA.append([corr,jtr,xcorr0T])
#
#         tot_corr_C.append(xcorr0T)
#
#
#     tot_corr_B.append(tot_corr_C)

date_time = time.strftime('%y-%m-%d--%H-%M-%S', time.localtime())

# Output directory for this run, relative to the current working directory.
# Figures are saved via full paths instead of os.chdir(), so the working
# directory is never left changed if plotting or saving raises.
out_dir = os.path.join('..', 'DATA', str(date_time) + '_corr_measure')
os.makedirs(out_dir, exist_ok=True)


tot_corrA = np.array(tot_corrA)

print(tot_corrA)


# Lambda vs. shared-spike ratio; all jitter values are concatenated into a
# single line.
plt.plot(tot_corrA[:, 0], tot_corrA[:, 2])
plt.show()


plt.figure()

# jtr_list appears to be in seconds (rate is in Hz, t_ref = 0.002 reads as a
# 2 ms refractory period); scale to ms so the axis matches its "[ms]" label.
jtr_ms = np.asarray(jtr_list) * 1e3

# extend='both': lambda values outside [0, 1] are drawn with the end colours
# instead of being left blank.
contf = plt.contourf(corr_list, jtr_ms, tot_corr_B, cmap='YlOrRd', alpha=1.,
                     levels=np.arange(0., 1.05, 0.05), extend='both')

cbar = plt.colorbar(contf, orientation = 'vertical', pad = 0.05, shrink = 0.9)


cbar.ax.get_yaxis().labelpad = 25
cbar.ax.tick_params(labelsize='14' )
cbar.ax.set_ylabel('Correlation', fontsize='18', rotation=270)

plt.xlabel('Ratio of shared spikes')
plt.ylabel('Jitter [ms]')

plt.savefig(os.path.join(out_dir, 'figure_ref_doubled.png'), dpi = 300)
plt.savefig(os.path.join(out_dir, 'figure_ref_doubledsv.svg'), dpi = 300)

plt.show()