"""
analysis.py

Functions to plot and analyse results

Version: 2014July9
"""
from pylab import pcolor, nonzero, mean, histogram, arange, bar, vstack,scatter, figure, isscalar, gca, unique, subplot, axes, shape, imshow, colorbar, plot, xlabel, ylabel, title, xlim, ylim, clim, show, zeros, legend, savefig, cm, specgram, get_cmap, psd
from scipy.io import loadmat
from scipy import loadtxt, size, array, linspace, ceil
from datetime import datetime
from time import time
import csv
import pickle
import shared as s

###############################################################################
### Simulation-related graph plotting functions
###############################################################################

## Create colormap
def bicolormap(gap=0.1,mingreen=0.2,redbluemix=0.5,epsilon=0.01):
    """Build a diverging blue -> gray -> red colormap named 'bicolormap'.

    The low end is blue, the high end is red, and a gray band of level
    1-gap surrounds the midpoint 0.5.

    gap        -- subtracted from 1 to give the gray level used near 0.5
    mingreen   -- green level at both extremes of the map
    redbluemix -- red level on the blue side of the central band, and blue
                  level on the red side of it (how much the hues mix)
    epsilon    -- half-width of the central gray band

    Returns a 256-entry matplotlib LinearSegmentedColormap.
    """
    from matplotlib.colors import LinearSegmentedColormap

    gray = 1 - gap
    below, centre, above = 0.5 - epsilon, 0.5, 0.5 + epsilon

    red = ((0.0, 0.0, 0.0),
           (below, redbluemix, gray),
           (centre, gray, gray),
           (above, gray, 1.0),
           (1.0, 1.0, 1.0))
    green = ((0.0, mingreen, mingreen),
             (below, gray, gray),
             (centre, gray, gray),
             (above, gray, gray),
             (1.0, mingreen, mingreen))
    blue = ((0.0, 1.0, 1.0),
            (below, 1.0, gray),
            (centre, gray, gray),
            (above, gray, redbluemix),
            (1.0, 0.0, 0.0))

    segments = {'red': red, 'green': green, 'blue': blue}
    return LinearSegmentedColormap('bicolormap', segments, 256)

## Raster plot
def plotraster(filename=None):
    """Draw a spike raster of all recorded spikes in a new figure.

    Reads s.allspiketimes, s.allspikecells, s.EorI, s.ncells,
    s.connspercell, s.backgroundweight, s.firingrate and s.duration from the
    shared module. Each spike is colored by s.EorI of its cell: index 0 is
    orange and index 1 is turquoise (the excitatory/inhibitory colors).
    Prints the time taken. Saves the figure to ``filename`` if it is truthy.
    """
    started = time()  # measure plotting time
    palette = array([(1,0.4,0), (0,0.2,0.8)])  # orange, turquoise
    spikecellids = array(s.allspikecells, dtype=int)
    spikecolors = palette[array(s.EorI)[spikecellids]]

    figure()
    scatter(s.allspiketimes, s.allspikecells, 10, spikecolors, linewidths=0.5, marker='|')
    xlabel('Time (ms)')
    ylabel('Cell ID')
    summary = 'cells=%i syns/cell=%0.1f noise=%0.1f rate=%0.1f Hz' % (
        s.ncells, s.connspercell, s.backgroundweight[0], s.firingrate)
    title(summary, fontsize=12)
    xlim(0, s.duration)
    ylim(0, s.ncells)

    elapsed = time() - started
    print('  Done; time = %0.1f s' % elapsed)
    if filename:
        savefig(filename)

## Peri-event time histogram (PETH)
def plotPETH():
    """Plot a peri-event time histogram (spike counts per time bin) per population.

    Spikes in s.allspiketimes are binned into 20 ms bins from 0 to
    s.duration. Each spike is attributed to the population
    s.cellpops[cell] of the cell that fired it (from s.allspikecells),
    giving one curve per unique value of s.cellpops. X ticks are bin
    indices labelled with the bin's left edge in ms. Legend labels come
    from s.popnames.
    """
    binsize = 20 # bin size in ms
    binedges = arange(0, s.duration+binsize, binsize)
    # Population of every spike; computed once rather than once per population.
    spikepops = array([s.cellpops[int(i)] for i in s.allspikecells])
    peth = []
    for ipop in unique(s.cellpops):
        hist,binedges = histogram(s.allspiketimes[spikepops == ipop], binedges)
        peth.append(hist)
    figure()
    plot(array(peth).T)
    title('PETH (%d ms bins)'%binsize)
    xlabel('Time (ms)')
    ylabel('Spikes/bin')
    ylim(0,s.scale*binsize*2)
    # gca(), not axes(): a bare axes() call adds a new, empty Axes on top of
    # the plot in recent matplotlib instead of returning the current one.
    h = gca()
    # Roughly 10 ticks. The step must be an integer (a float step breaks
    # range() and slicing under Python 3) and at least 1 for short runs.
    nbins = len(binedges) - 1
    step = max(1, len(binedges) // 10)
    tickidx = list(range(0, nbins, step))
    h.set_xticks(tickidx)
    # Exactly one label per tick: bin index i is labelled with its left edge.
    h.set_xticklabels(binedges[tickidx].astype(int))
    legend(s.popnames)

## Plot power spectral density
def plotpsd(plotPops=False):
    """Plot the power spectral density (PSD) of the summed LFP.

    For each c in range(len(s.lfppops)), column c of s.lfps is taken as one
    LFP trace. The traces are summed, and the PSD of the total is drawn with
    pylab's psd() at Fs=200 in a new figure, with y tick labels hidden.

    plotPops -- if true (off by default, matching the previously hard-coded
                flag), first draw the PSD of each trace in its own figure.

    Calls show() at the end.

    NOTE(review): Fs=200 is hard-coded; confirm it matches the LFP sampling
    rate. The per-population legend labels are also hard-coded.
    """
    colorspsd=array([[0.42,0.67,0.84],[0.42,0.83,0.59],[0.90,0.76,0.00],[0.90,0.32,0.00],[0.34,0.67,0.67],[0.42,0.82,0.83],[0.90,0.59,0.00],[0.33,0.67,0.47],[1.00,0.85,0.00],[0.71,0.82,0.41],[0.57,0.67,0.33],[1.00,0.38,0.60],[0.5,0.2,0.0],[0.0,0.2,0.5]])

    # One trace per LFP population: column c of s.lfps.
    lfpv = [s.lfps[:,c] for c in range(len(s.lfppops))]
    lfptot = sum(lfpv)

    # plot pops separately
    if plotPops:
        figure() # Open a new figure
        for p, lfp in enumerate(lfpv):
            # Wrap the palette so more populations than colors still plot.
            psd(lfp, Fs=200, linewidth=2, color=colorspsd[p % len(colorspsd)])
        xlabel('Frequency (Hz)')
        ylabel('Power')
        # gca(): a bare axes() would add a new empty Axes over the PSD plot.
        gca().set_yticklabels([])
        legend(['L2/3','L5A', 'L5B', 'L6'])

    # plot overall psd
    figure() # Open a new figure
    psd(lfptot, Fs=200, linewidth=2)
    xlabel('Frequency (Hz)')
    ylabel('Power')
    gca().set_yticklabels([])

    show()


## Plot connectivity. For diagnostic purposes. Based on conndiagram.py.
def plotconn():
    """Draw the population-to-population connectivity matrix as a heat map.

    Entry [pre, post] sums, over the s.nreceptors receptor types, the
    connection probability s.connprobs[pre, post] times the receptor weight
    s.connweights[pre, post, w] times the scaling factor
    s.scaleconnweight[s.popEorI[pre], s.popEorI[post]]. Receptor indices
    >= 2 contribute negatively.

    The matrix is shown on the diverging bicolormap with a symmetric color
    range. Gray grid lines separate the populations, s.popnames labels both
    axes (x labels on top), and row 0 is at the top.
    """
    fig = figure(figsize=(8,6))
    # Tight margins on every side and no spacing between subplots.
    fig.subplots_adjust(left=0.02, right=0.98, top=0.96, bottom=0.02, wspace=0, hspace=0)
    h = axes()

    dims = shape(s.connprobs)
    totalconns = zeros(dims)
    for ipre in range(dims[0]):
        for ipost in range(dims[1]):
            for irec in range(s.nreceptors):
                sign = -1 if irec >= 2 else 1
                scaling = s.scaleconnweight[s.popEorI[ipre], s.popEorI[ipost]]
                totalconns[ipre, ipost] += s.connprobs[ipre, ipost]*s.connweights[ipre, ipost, irec]*sign*scaling
    imshow(totalconns, interpolation='nearest', cmap=bicolormap(gap=0))

    # Grid lines on the boundaries between populations.
    gridgray = (0.7,0.7,0.7)
    span = array([0, s.npops]) - 0.5
    for ipop in range(s.npops):
        boundary = array([ipop, ipop]) - 0.5
        plot(span, boundary, '-', c=gridgray)
        plot(boundary, span, '-', c=gridgray)

    # Population names on both axes, x labels at the top.
    h.set_xticks(list(range(s.npops)))
    h.set_yticks(list(range(s.npops)))
    h.set_xticklabels(s.popnames)
    h.set_yticklabels(s.popnames)
    h.xaxis.set_ticks_position('top')
    xlim(-0.5, s.npops-0.5)
    ylim(s.npops-0.5, -0.5)  # inverted so row 0 is at the top

    # Symmetric color range so zero sits at the center of the colormap.
    peak = abs(totalconns).max()
    clim(-peak, peak)
    colorbar()


## Plot weight changes
def plotweightchanges(filename=None, changeOverTime=False, relative=True):
    """Plot STDP weight changes as a presynaptic x postsynaptic cell matrix.

    Does nothing unless s.usestdp is true. Each STDP connection in
    s.allstdpconndata supplies (pre gid, post gid, receptor index). Its
    change is final minus initial weight, taken as the last field of the
    last and first records of the matching s.allweightchanges entry. The
    change is negated for receptor index >= 2 and stored at [pre, post].
    Axis ticks sit at s.popGidStart, labelled with s.popnames. The figure is
    saved to ``filename`` if it is truthy.

    changeOverTime -- if true, also draw a second figure of every
                      connection's weight at each update. Connections with
                      fewer updates hold their last value.
    relative       -- for that second figure, plot weight minus initial
                      weight (True) or the weight itself (False).
    """
    if not s.usestdp:
        return

    # create plot: tight margins and no spacing between subplots
    figh = figure(figsize=(1.2*8,1.2*6))
    figh.subplots_adjust(left=0.02, right=0.98, top=0.96, bottom=0.02, wspace=0, hspace=0)
    h = axes()

    # create data matrix: final minus initial weight of each STDP connection
    wcs = [x[-1][-1]-x[0][-1] for x in s.allweightchanges]
    pre,post,recep = list(zip(*[(x[0],x[1],x[2]) for x in s.allstdpconndata]))
    ncells = int(max(max(pre),max(post))+1)
    wcmat = zeros([ncells, ncells])

    for iwc,ipre,ipost,irecep in zip(wcs,pre,post,recep):
        # receptor index >= 2 drawn negative (presumably inhibitory - verify)
        wcmat[int(ipre),int(ipost)] = iwc *(-1 if irecep>=2 else 1)

    # plot
    imshow(wcmat,interpolation='nearest',cmap=bicolormap(gap=0,mingreen=0.2,redbluemix=0.1,epsilon=0.01))
    xlabel('post-synaptic cell id')
    ylabel('pre-synaptic cell id')
    h.set_xticks(s.popGidStart)
    h.set_yticks(s.popGidStart)
    h.set_xticklabels(s.popnames)
    h.set_yticklabels(s.popnames)
    h.xaxis.set_ticks_position('top')
    xlim(-0.5,ncells-0.5)
    ylim(ncells-0.5,-0.5)
    clim(-abs(wcmat).max(),abs(wcmat).max())
    colorbar()

    if filename: savefig(filename)

    if changeOverTime:
        # weight of every connection (rows) at every update step (columns)
        figure()
        maxSteps = max([len(w) for w in s.allweightchanges])
        wc = zeros((len(s.allweightchanges), maxSteps))
        for iconn,conn in enumerate(s.allweightchanges):
            for it in range(maxSteps):
                if it < len(conn):
                    wc[iconn, it] = conn[it][-1]-conn[0][-1] if relative else conn[it][-1]
                else:
                    wc[iconn, it] = wc[iconn, it-1] # hold the last recorded value

        pcolor(wc, cmap='hot_r', vmin=wc.min(), vmax=wc.max())
        xlim((0,maxSteps))
        ylim((0,len(wc)))
        xlabel('Time (weight updates)')
        ylabel('Synaptic connection id')
        colorbar()



## plot motor subpopulations connectivity changes
def _motorweightsums(cellrange):
    """Return (initial sum, final sum) of weights of STDP connections onto cellrange.

    Connection i is included when its postsynaptic gid
    s.allstdpconndata[i][1] is in cellrange. Its initial and final weights
    are the last field of the first and last records of
    s.allweightchanges[i].
    """
    wpre = []
    wpost = []
    for icon, x in enumerate(s.allweightchanges):
        if s.allstdpconndata[icon][1] in cellrange:
            wpre.append(x[0][-1])
            wpost.append(x[-1][-1])
    return sum(wpre), sum(wpost)


def _printmotorweightchanges(label, presums, postsums):
    """Print per-group initial/final weight sums and their absolute and relative differences."""
    diff = array(postsums) - array(presums)
    print('\ninitial %s weights: ' % label, presums)
    print('final %s weigths: ' % label, postsums)
    print('absolute %s difference: ' % label, diff)
    print('relative %s difference: ' % label, diff / array(presums))


def _plotmotorbars(presums, postsums, showlegend=False):
    """Draw one figure: initial (blue) vs final (red) bars on top, relative change below."""
    # Group names; sliced so fewer groups do not mismatch the tick count.
    labels = ('shext','shflex','elext','elflex')[:len(presums)]
    figh = figure(figsize=(1.2*8,1.2*6))
    ind = arange(len(presums))  # the x locations for the groups

    # align='edge' makes each bar start at its x value, which the tick offsets
    # below assume (matplotlib >= 2.0 otherwise centres bars on x).
    ax1 = figh.add_subplot(2,1,1)
    width = 0.35       # the width of the bars
    ax1.bar(ind, presums, width, color='b', align='edge')
    ax1.bar(ind+width, postsums, width, color='r', align='edge')
    ax1.set_xticks(ind+width)
    ax1.set_xticklabels(labels)
    if showlegend:
        ax1.legend(['pre','post'])
    ax1.grid()

    ax2 = figh.add_subplot(2,1,2)
    width = 0.70       # the width of the bars
    ax2.bar(ind, (array(postsums) - array(presums)) / array(presums), width, color='b', align='edge')
    ax2.set_xticks(ind+width/2)
    ax2.set_xticklabels(labels)
    ax2.grid()


def plotmotorpopchanges(showInh=True):
    """Compare initial and final STDP weights onto each motor command cell group.

    Does nothing unless s.usestdp is true. For each range in
    s.motorCmdCellRange, sums the initial and final weights of STDP
    connections whose postsynaptic cell is in that range. It then prints
    the sums with their absolute and relative differences and plots them.

    showInh -- if true, repeat the analysis for the same ranges shifted from
               the s.EDSC population's gid start to the s.IDSC population's.
    """
    if not s.usestdp:
        return

    EwpreSum, EwpostSum = [], []
    IwpreSum, IwpostSum = [], []
    for imus in range(len(s.motorCmdCellRange)):
        cellrange = s.motorCmdCellRange[imus]
        wpre, wpost = _motorweightsums(cellrange)
        EwpreSum.append(wpre)
        EwpostSum.append(wpost)

        if showInh:
            # same cell offsets, moved from the EDSC to the IDSC gid block
            motorInhCellRange = cellrange - s.popGidStart[s.EDSC] + s.popGidStart[s.IDSC]
            wpre, wpost = _motorweightsums(motorInhCellRange)
            IwpreSum.append(wpre)
            IwpostSum.append(wpost)

    _printmotorweightchanges('E', EwpreSum, EwpostSum)
    if showInh:
        _printmotorweightchanges('I', IwpreSum, IwpostSum)

    # plot
    _plotmotorbars(EwpreSum, EwpostSum)
    if showInh:
        _plotmotorbars(IwpreSum, IwpostSum, showlegend=True)


## plot 3d architecture:
def plot3darch(xlocs=None, ylocs=None, zlocs=None):
    """Scatter-plot cell locations in 3D, colored by z (cortical depth).

    xlocs, ylocs, zlocs -- coordinate sequences (mm, per the axis labels).
        Each defaults to the placeholder points previously hard-coded here:
        [1,2,3], [3,2,1], [0.1,0.5,1.2].
        NOTE(review): older commented-out code suggests s.xlocs/s.ylocs/s.zlocs
        were the intended data.

    The view uses elevation 60 and azimuth 40 degrees.
    """
    if xlocs is None:
        xlocs = [1,2,3]
    if ylocs is None:
        ylocs = [3,2,1]
    if zlocs is None:
        zlocs = [0.1,0.5,1.2]

    # create plot
    figh = figure(figsize=(1.2*8,1.2*6))
    # NOTE: projection='3d' requires mpl_toolkits.mplot3d to be registered
    # (automatic in matplotlib >= 3.2; older versions need an explicit import).
    # No bare axes() call afterwards: it would add a 2D Axes over this one.
    ax = figh.add_subplot(1,1,1, projection='3d')

    ax.scatter(xlocs, ylocs, zlocs, s=10, c=zlocs, edgecolors='none', cmap='jet_r', linewidths=0.0, alpha=1, marker='o')
    azim = 40
    elev = 60
    ax.view_init(elev, azim)
    # Label the 3D axes directly. pyplot xlabel/ylabel act on the current
    # Axes, and the depth label belongs on z, not y.
    ax.set_xlabel('lateral distance (mm)')
    ax.set_ylabel('lateral distance (mm)')
    ax.set_zlabel('cortical depth (mm)')


###############################################################################
### Evolutionary-algorithm analysis/plotting functions
###############################################################################

#%% plot filled error bars
def errorfill(x, y, yerr, lw=1, elinewidth=1, color=None, alpha_fill=0.2, ax=None):
    """Plot y against x as a line with a shaded error band.

    yerr -- a scalar or a sequence of len(y) values (band from y-yerr to
            y+yerr), or a pair (ymin, ymax) of explicit band limits. When
            len(y) == 2, a length-2 yerr is treated as symmetric errors.
    lw         -- line width
    elinewidth -- width of the band's edge
    color      -- line and band color; defaults to the next color in the
                  axes' property cycle
    alpha_fill -- band opacity
    ax         -- target Axes; defaults to the current one

    Raises ValueError if yerr matches none of the accepted forms.
    """
    import numpy as np

    ax = ax if ax is not None else gca()
    y = np.asarray(y)  # lets plain lists support y - yerr
    if np.ndim(yerr) == 0 or len(yerr) == len(y):
        yerr = np.asarray(yerr)
        ymin = y - yerr
        ymax = y + yerr
    elif len(yerr) == 2:
        ymin, ymax = yerr
    else:
        raise ValueError('yerr must be a scalar, have len(y) entries, or be a (ymin, ymax) pair')

    # Omit color when none is given so matplotlib takes it from the property
    # cycle (the private color_cycle attribute no longer exists). Then reuse
    # the line's actual color for the band.
    linekw = {'lw': lw}
    if color is not None:
        linekw['color'] = color
    line, = ax.plot(x, y, **linekw)
    ax.fill_between(x, ymax, ymin, color=line.get_color(), lw=elinewidth, alpha=alpha_fill)

#%% function to obtain unique list of lists
def uniqueList(seq):
    """Drop duplicate items from a sequence of sequences, keeping first occurrences.

    Two items count as duplicates when tuple(item) compares equal, so [1, 2]
    and (1, 2) are the same item.

    Returns (result, indices): the first occurrence of each distinct item
    (the original objects, in their original order), and the position of
    each of those in seq.
    """
    seen = set()
    kept = []
    for pos, item in enumerate(seq):
        key = tuple(item)
        if key not in seen:
            seen.add(key)
            kept.append((pos, item))
    indices = [pos for pos, _ in kept]
    result = [item for _, item in kept]
    return result, indices

#%% function to read data
def loadData(folder, islands, dataFrom):
    """Load evolutionary-algorithm results for each island of a run.

    For each island index i in range(islands), reads from
    '../data/<folder>_island_<i>/':
      individuals.csv -- generation, candidate, fitness, parameter values
      statistics.csv  -- generation and worst/best/average/std fitness
                         (columns 0, 2, 3, 4, 6). Only one row per
                         generation number is kept (its first occurrence),
                         ordered by generation.

    dataFrom -- 'fitness': rank the evaluations stored as pickles in
                gen_<g>_cand_<c>_error (with the matching _params file).
                'individuals': rank the unique parameter sets in
                individuals.csv by their fitness.
                Either way, the *_sort lists run from smallest fitness up.

    Returns, when islands > 1, a 13-tuple of per-island lists: ind_gens,
    ind_cands, ind_fits, ind_cs, stat_gens, stat_worstfits, stat_bestfits,
    stat_avgfits, stat_stdfits, fits_sort, gens_sort, cands_sort,
    params_sort.
    NOTE(review): the final 'if islands > 1' return covers only the
    multi-island case; confirm how single-island results are returned.
    """
    #%% Load data from files
    if islands > 1:
        ind_gens_isl=[] # individuals data for islands
        ind_cands_isl=[]
        ind_fits_isl=[]
        ind_cs_isl=[]

        stat_gens_isl=[] # statistics.csv for islands
        stat_worstfits_isl=[]
        stat_bestfits_isl=[]
        stat_avgfits_isl=[]
        stat_stdfits_isl=[]

        fits_sort_isl=[] #sorted data
        gens_sort_isl=[]
        cands_sort_isl=[]
        params_sort_isl=[]

    for island in range(islands):
        ind_gens=[] # individuals data
        ind_cands=[]
        ind_fits=[]
        ind_cs=[]

        eval_gens=[] # error files for each evaluation
        eval_cands=[]
        eval_fits=[]
        eval_params=[]

        stat_gens=[] # statistics.csv
        stat_worstfits=[]
        stat_bestfits=[]
        stat_avgfits=[]
        stat_stdfits=[]

        # NOTE(review): islands >= 1 inside this loop, so this condition is
        # always true and the else branch is unreachable. A single island
        # also reads '<folder>_island_0'. Confirm whether 'islands > 1' was meant.
        if islands > 0:
            folderFinal = folder+"_island_"+str(island)
        else:
            folderFinal = folder

        with open('../data/%s/individuals.csv'% (folderFinal)) as f: # read individuals.csv
            reader=csv.reader(f)
            for row in reader:
                ind_gens.append(int(row[0]))
                ind_cands.append(int(row[1]))
                ind_fits.append(float(row[2]))
                cs = [float(row[i].replace("[","").replace("]","")) for i in range(3,len(row))]
                ind_cs.append(cs)

        with open('../data/%s/statistics.csv'% (folderFinal)) as f: # read statistics.csv
            reader=csv.reader(f)
            for row in reader:
                stat_gens.append(float(row[0]))
                stat_worstfits.append(float(row[2]))
                stat_bestfits.append(float(row[3]))
                stat_avgfits.append(float(row[4]))
                stat_stdfits.append(float(row[6]))

        # unique generation number (sometimes repeated due to rerunning in hpc)
        stat_gens, stat_gens_indices = unique(stat_gens,1) # unique individuals
        stat_worstfits, stat_bestfits, stat_avgfits, stat_stdfits = list(zip(*[[stat_worstfits[i], stat_bestfits[i], stat_avgfits[i], stat_stdfits[i]] for i in stat_gens_indices]))

        if dataFrom == 'fitness':
            # Probe every (generation, candidate) pair up to AND including the
            # largest indices in individuals.csv. Pairs without files are skipped.
            for igen in range(max(ind_gens)+1): # read error files from evaluations
                for ican in range(max(ind_cands)+1):
                    basename = '../data/%s/gen_%d_cand_%d'%(folderFinal, igen, ican)
                    try:
                        # Pickles must be opened in binary mode under Python 3.
                        # pickle runs code from the file: load trusted results only.
                        with open(basename+'_error', 'rb') as f:
                            fit = pickle.load(f)
                        with open(basename+'_params', 'rb') as f:
                            params = pickle.load(f)
                    except (IOError, OSError, EOFError, ValueError, pickle.UnpicklingError):
                        continue # missing or unreadable evaluation: skip it
                    # Append only once both files loaded, so the eval_* lists stay aligned.
                    eval_fits.append(fit)
                    eval_params.append(params)
                    eval_gens.append(igen)
                    eval_cands.append(ican)

        # find x corresponding to smallest error from function evaluations
        if dataFrom == 'fitness':
            fits_sort_indices = sorted(range(len(eval_fits)), key=lambda k: eval_fits[k])
            fits_sort = [eval_fits[i] for i in fits_sort_indices]
            gens_sort = [eval_gens[i] for i in fits_sort_indices]
            cands_sort = [eval_cands[i] for i in fits_sort_indices]
            params_sort = [eval_params[i] for i in fits_sort_indices]
        # find x corresponding to smallest error from individuals file
        elif dataFrom == 'individuals':
            params_unique, unique_indices = uniqueList(ind_cs) # unique individuals
            fits_unique = [ind_fits[i] for i in unique_indices]
            gens_unique = [ind_gens[i] for i in unique_indices]
            cands_unique = [ind_cands[i] for i in unique_indices]

            sort_indices = sorted(range(len(fits_unique)), key=lambda k: fits_unique[k]) # sort fits
            fits_sort = [fits_unique[i] for i in sort_indices]
            gens_sort = [gens_unique[i] for i in sort_indices]
            cands_sort = [cands_unique[i] for i in sort_indices]
            params_sort = [params_unique[i] for i in sort_indices]

        # if multiple islands, save data for each
        if islands > 1:
            ind_gens_isl.append(ind_gens) # individuals data for islands
            ind_cands_isl.append(ind_cands)
            ind_fits_isl.append(ind_fits)
            ind_cs_isl.append(ind_cs)

            stat_gens_isl.append(stat_gens) # statistics.csv for islands
            stat_worstfits_isl.append(stat_worstfits)
            stat_bestfits_isl.append(stat_bestfits)
            stat_avgfits_isl.append(stat_avgfits)
            stat_stdfits_isl.append(stat_stdfits)

            fits_sort_isl.append(fits_sort) #sorted data
            gens_sort_isl.append(gens_sort)
            cands_sort_isl.append(cands_sort)
            params_sort_isl.append(params_sort)

    if islands > 1:
        return ind_gens_isl, ind_cands_isl, ind_fits_isl, ind_cs_isl, stat_gens_isl, \
            stat_worstfits_isl, stat_bestfits_isl, stat_avgfits_isl, stat_stdfits_isl, \
            fits_sort_isl, gens_sort_isl, cands_sort_isl, params_sort_isl