import numpy as np
import elephant
from neo.core import AnalogSignal,SpikeTrain
import quantities as q
from elephant.conversion import BinnedSpikeTrain
from matplotlib import pyplot as plt

def plot_cross_corr(pre_spikes,spiketime_dict,presyn,binsize,maxtime=0):
    mean_cc={};mean_cc_shuffle={};cc_shuffle_corrected={}
    for key in pre_spikes.keys(): #key indexes the simulation condition, e.g. stpYN
        numtrials=len(pre_spikes[key])
        if maxtime==0:
            last_spike=[train[-1] for train in pre_spikes[key][0][presyn]]
            t_end=max(np.round(np.max(last_spike),maxtime))
            maxtime=t_end
        else:
            t_end=maxtime
        cc_hist=[[] for t in range(numtrials)]
        fig,axes =plt.subplots(numtrials,numtrials,sharex=True)
        fig.suptitle('cross correlograms '+key)
        for trial_in in range(numtrials):
            for trial_out in range(numtrials):
                if isinstance(pre_spikes[key][trial_in][presyn], list):
                    spikes = np.sort(np.concatenate(pre_spikes[key][trial_in][presyn]))
                else:
                    spikes=pre_spikes[key][trial_in][presyn]
                train=SpikeTrain(spikes*q.s,t_start=0*q.s,t_stop=t_end*q.s,binsize=binsize*q.s)
                in_train=BinnedSpikeTrain(train,t_start=0*q.s,t_stop=t_end*q.s,binsize=binsize*q.s)
                train=SpikeTrain(spiketime_dict[key][trial_out]*q.s,t_stop=t_end*q.s)
                out_train=BinnedSpikeTrain(train,t_start=0*q.s,t_stop=t_end*q.s,binsize=binsize*q.s)
                #print('trial_in,trial_out', trial_in, trial_out)
                cc_hist[trial_in].append(elephant.spike_train_correlation.cross_correlation_histogram(in_train,out_train))
                axes[trial_in,trial_out].plot(cc_hist[trial_in][trial_out][0].magnitude[:,0])
            axes[trial_in,0].set_ylabel('input '+str(trial_in))
        for trial_out in range(trial_in,numtrials):
            axes[-1,trial_out].set_xlabel('output '+str(trial_out))
        #shuffle corrected mean cross-correlogram
        #initialize these to accumulate across conditions, e.g. pre and post-HFS, and possibly across keys (str freq)
        cc_same=[cc_hist[a][a][0].magnitude[:,0] for a in range(numtrials)]
        mean_cc[key]=np.mean(cc_same,axis=0)
        cc_diff=[cc_hist[a][b][0].magnitude[:,0] for a in range(numtrials) for b in range(numtrials) if b != a ]
        mean_cc_shuffle[key]=np.mean(cc_diff,axis=0)
        cc_shuffle_corrected[key]=mean_cc[key]-mean_cc_shuffle[key]
    #PLOT mean cc and shuffle corrected for each key on one figure
    xbins=np.linspace(-t_end,t_end,len(mean_cc[key]))
    fig,axes =plt.subplots(3,1,sharex=True)
    fig.suptitle('cross correlograms '+presyn)
    for key in mean_cc.keys():
        axes[0].plot(xbins,mean_cc[key],label=key)
        axes[1].plot(xbins,mean_cc_shuffle[key],label=key)
        axes[2].plot(xbins,cc_shuffle_corrected[key],label=key)
    axes[0].set_ylabel('mean cc')
    axes[1].set_ylabel('mean cc shuffled')
    axes[2].set_ylabel('mean cc shuffled-corrected')
    axes[2].legend()
    return