import operator
import itertools
import math
import pprint
from matplotlib import pyplot, patches
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from scipy import stats, interpolate
import pandas as pd
from ajustador.nrd_output import PUVC

from . import loader, fitnesses, utilities, xml, nrd_output, loadconc

def _on_close(event):
    event.canvas.figure.closed = True

try:
    _GRAPHS
except NameError:
    _GRAPHS = {}
def _get_graph(name, figsize=None, clear=True, newplot=False):
    """Return the figure registered under *name*, creating it if needed.

    Figures are cached in the module-level ``_GRAPHS`` dict so that repeated
    plotting calls reuse the same window while it is still open.

    Parameters
    ----------
    name : str
        Cache key; also the window title of a newly created figure.
    figsize : tuple, optional
        Passed to ``pyplot.figure`` when a new figure is created.
    clear : bool
        When reusing an open figure: clear it and reset ``plot_counter`` to 0
        (True), or keep its contents and increment ``plot_counter`` (False).
    newplot : bool
        Forget any cached figure under *name* and always create a new one.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, carrying ``closed`` and ``plot_counter`` attributes.
    """
    if newplot:
        _GRAPHS.pop(name, None)
    try:
        f = _GRAPHS[name]
    except KeyError:
        pass
    else:
        if not f.closed:
            if clear:
                f.clear()
                f.canvas.draw()  # show the emptied figure so changes are easier to see
                f.plot_counter = 0
            else:
                f.plot_counter += 1
            return f
    f = _GRAPHS[name] = pyplot.figure(figsize=figsize)
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6; the
    # figure manager's set_window_title is the supported spelling.
    f.canvas.manager.set_window_title(name)
    f.closed = False
    f.plot_counter = 0
    f.canvas.mpl_connect('close_event', _on_close)
    return f

def plot_neurord_tog(measurement, sim, labels=None, fit_rpt=None, norm=None):
    """Plot simulated and experimental concentration traces together.

    One subplot is drawn per molecule present in both *measurement* and
    *sim*. Simulated stimulations are coloured from the ``gist_heat``
    colormap and experimental ones from ``viridis``.

    Parameters
    ----------
    measurement
        Experimental data. Either an ``xml.NeurordResult``, whose ``.output``
        is used, or an object whose ``.data`` sequence holds items with a
        ``waves`` mapping.
    sim
        Simulation result whose ``.output`` is a sequence of outputs.
    labels
        Accepted for caller compatibility; not used.
    fit_rpt : sequence of str, optional
        Lines shown as the figure's super-title.
    norm : str, optional
        With ``'percent'``, simulated traces whose ``norm`` is also
        ``'percent'`` are divided by their basal concentration.

    Returns
    -------
    The figure. If no molecule is shared, it is returned without subplots.
    """
    f = _get_graph(measurement.name)
    if fit_rpt:
        f.suptitle('\n'.join(fit_rpt), fontsize=7)
    ms_per_sec = 1000

    # molecules to plot: those present in both the experiment and simulation
    mollist_sim = sim.output[0].specie_names
    if isinstance(measurement, xml.NeurordResult):
        mollist_exp = measurement.output[0].specie_names
        exp_data = measurement.output
    else:
        mollist_exp = list(measurement.data[0].waves.keys())
        exp_data = measurement.data
    # sorted so the subplot order does not depend on set iteration order
    mol_list = sorted(set(mollist_exp) & set(mollist_sim))
    if not mol_list:
        print('drawing.py: no molecules shared by measurement and simulation')
        return f

    # one column for few molecules, otherwise a roughly square grid
    nmol = len(mol_list)
    if nmol > 8:
        rows = int(np.round(np.sqrt(nmol)))
        cols = int(np.ceil(nmol / float(rows)))
    else:
        cols = 1
        rows = nmol
    axes = [f.add_subplot(rows, cols, v + 1) for v in range(nmol)]
    colors = [pyplot.get_cmap('gist_heat'), pyplot.get_cmap('viridis')]

    # dataset 0 = simulation, dataset 1 = experiment
    for i, dataset in enumerate([sim.output, exp_data]):
        color_increment = int(round(256.0 / (len(dataset) + 1)))
        for j, stim_data in enumerate(dataset):  # one or more stimulations
            colr = colors[i](j * color_increment % colors[i].N)
            labl = stim_data.injection
            for k, mol in enumerate(mol_list):
                if isinstance(stim_data, nrd_output.Output):
                    if mol in stim_data.specie_names:
                        plotdata = nrd_output.nrd_output_conc(stim_data, mol)
                        if stim_data.norm == 'percent' and norm == 'percent':
                            plotdata = plotdata / stim_data.basal(mol)['basal']
                        axes[k].plot(plotdata.index.values / ms_per_sec,
                                     plotdata.values[:, 0],
                                     label=labl, color=colr)
                elif isinstance(stim_data, loadconc.CSV_conc):
                    if mol in list(stim_data.waves.keys()):
                        trace = stim_data.waves[mol]
                        ydata = 1 + trace.scale * (trace.wave.y - 1)
                        axes[k].plot(trace.wave.x / ms_per_sec, ydata,
                                     color=colr, label=labl)
                else:
                    # report the item that could not be plotted
                    print('drawing.py: new type of data format', type(stim_data))
                axes[k].set_ylabel(mol + " uM")
    axes[0].legend(loc='upper right', fontsize=8, ncol=2)
    for ax in axes[-cols:]:
        ax.set_xlabel('time, sec')
    f.canvas.draw()
    f.show()
    return f
 
def plot_together(*groups, offset=False, labels=None, separate=False):
    """Overlay the traces of one or more waves groups in one figure.

    Parameters
    ----------
    *groups
        Objects with ``name``, ``injection`` (array of injection values) and
        ``waves`` (curves with ``wave.x``, ``wave.y``, ``injection`` and
        ``baseline.x``).
    offset : bool
        Subtract each curve's ``baseline.x`` before plotting.
    labels : sequence, optional
        Legend label per group. ``None`` entries fall back to the group name.
    separate : bool
        Put curve *j* of every group in its own subplot, with the layout
        taken from the first group, instead of a single shared axes.

    Returns the figure.
    """
    f = _get_graph(groups[0].name + ' together')

    if separate:
        f.subplots_adjust(left=0.03, bottom=0.03, right=0.97, top=0.97,
                          wspace=0.26, hspace=0.20)
        n = len(groups[0].waves)
        columns = 1 if n == 1 else 2 if n in (2, 4) else 3 if n <= 9 else 4 if n <= 16 else 5
    else:
        ax = f.gca()

    for i, waves in enumerate(groups):
        c = [0, 0, 0]
        # np.ptp rather than the ndarray.ptp method, which NumPy 2.0 removed
        ptp = np.ptp(waves.injection)
        if ptp > 0:
            off = waves.injection.min() - 0.2 * ptp
            ptp *= 1.2
        else:
            off = waves.injection.min() - 100e-12
            ptp = 100e-12

        for j, curve in enumerate(waves.waves):
            if separate:
                ax = f.add_subplot(int(math.ceil(n/columns)), columns, j+1)

            # each group owns one RGB channel; its intensity encodes the
            # curve's injection, normalised into [0, 1]
            c[(i+2) % len(c)] = np.clip((curve.injection - off)/ptp, 0, 1)
            kwargs = {}
            if j == len(waves.waves)-1:
                # label only the last curve so each group appears once in the legend
                if labels is None or labels[i] is None:
                    kwargs['label'] = waves.name
                else:
                    kwargs['label'] = labels[i]
            y = curve.wave.y - (curve.baseline.x if offset else 0)
            ax.plot(curve.wave.x, y, c=tuple(c), **kwargs)

    ax.legend(loc='lower right', fontsize=8)
    f.canvas.draw()
    f.show()
    return f

def plot_waves(waves):
    """Draw each trace of *waves* in its own panel with baseline/steady bands.

    Each panel shows:
    - yellow horizontal lines at baseline ±3 dev;
    - green lines at steady state ±3 dev;
    - red vertical lines from -0.08 up to each spike, with a spike count,
      when spikes exist.

    Returns the figure.
    """
    fig = _get_graph(waves.name + ' baseline and steady state', figsize=(16,10))
    fig.subplots_adjust(left=0.03, bottom=0.03, right=0.97, top=0.97,
                        wspace=0.26, hspace=0.20)

    count = len(waves.waves)
    if count == 1:
        ncols = 1
    elif count in (2, 4):
        ncols = 2
    elif count <= 9:
        ncols = 3
    elif count <= 16:
        ncols = 4
    else:
        ncols = 5
    nrows = int(math.ceil(count / ncols))

    def draw_band(panel, xs, stat, colour):
        # centre line plus ±3 standard-deviation lines across the full x range
        levels = [stat.x, stat.x + stat.dev*3, stat.x - stat.dev*3]
        panel.hlines(levels, xs.min(), xs.max(), colour)

    for slot, curve in enumerate(waves.waves, start=1):
        panel = fig.add_subplot(nrows, ncols, slot)
        panel.plot(curve.wave.x, curve.wave.y)
        panel.set_title('{} / {}V'.format(waves.name, curve.injection), fontsize=8)

        draw_band(panel, curve.wave.x, curve.baseline, 'y')
        draw_band(panel, curve.wave.x, curve.steady, 'g')

        spikes = curve.spikes
        if not spikes.size:
            continue
        panel.vlines(spikes.x, -0.08, spikes.y, 'r')
        panel.text(0.5, 0.5, '{} spikes'.format(len(spikes)),
                   horizontalalignment='center',
                   transform=panel.transAxes)

    fig.canvas.draw()
    fig.show()
    return fig

def plot_rectification(waves):
    """Plot the falling curve and its fit for traces with a negative response.

    Traces whose ``response.x`` is >= -12e-12 are skipped, and their subplot
    slot stays empty. For each remaining trace the falling curve is drawn in
    red. When its fit is good, the fitted function plus the baseline is drawn
    dashed green, and horizontal lines mark ``steady.x`` and
    ``steady.x - rectification.x`` between x=0.20 and x=0.40.

    Returns the figure.
    """
    fig = _get_graph(waves.name + ' activation', figsize=(16,10))

    total = len(waves.waves)
    if total == 1:
        ncols = 1
    elif total in (2, 4):
        ncols = 2
    elif total <= 9:
        ncols = 3
    elif total <= 16:
        ncols = 4
    else:
        ncols = 5
    nrows = int(math.ceil(total / ncols))

    for slot, curve in enumerate(waves.waves, start=1):
        if curve.response.x >= -12e-12:
            continue
        panel = fig.add_subplot(nrows, ncols, slot)
        panel.plot(curve.wave.x, curve.wave.y)
        panel.set_title('{0.filename} / {0.injection}V'.format(curve), fontsize=8)

        falling = curve.falling_curve
        base = curve.baseline
        steady_state = curve.steady
        rectification = curve.rectification
        panel.plot(falling.x, falling.y, 'r')
        panel.set_xlim(curve.baseline_before, falling.x.max() + .01)

        fit = curve.falling_curve_fit
        if not fit.good:
            continue
        panel.plot(falling.x, base.x + fit.function(falling.x, *fit.params), 'g--')
        panel.hlines([steady_state.x, steady_state.x - rectification.x], 0.20, 0.40)

    fig.canvas.draw()
    fig.show()
    return fig

def plot_shape(what, *group):
    """Plot attribute *what* of each waves object against its injection.

    *what* is a (possibly dotted) attribute path read with
    ``operator.attrgetter``. Objects from ``ajustador.loader`` are drawn with
    ``'-o'``; everything else with ``'--+'``.

    Returns the figure.
    """
    f = _get_graph('shape')
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6
    f.canvas.manager.set_window_title('shape for {}'.format(what))
    ax = f.gca()
    op = operator.attrgetter(what)
    for waves in group:
        inj, val = waves.injection, op(waves)
        order = inj.argsort()  # connect points in increasing-injection order
        ax.plot(inj[order], val[order],
                '-o' if waves.__class__.__module__ == 'ajustador.loader' else '--+',
                label=getattr(waves, 'name', '(mixed)'))
    ax.legend(loc='best', fontsize=8)

    f.canvas.draw()
    f.show()
    return f

def plot_shape2(what, *group):
    """Plot ``<waves>.<what>.x`` against the minimum of each falling curve.

    Traces whose falling curve is empty contribute NaN. Objects from
    ``ajustador.loader`` are drawn with ``'-o'``; everything else with
    ``'--+'``.

    Returns the figure.
    """
    fig = _get_graph('activation')
    axis = fig.gca()
    for waves in group:
        minima = [np.nan if wave.falling_curve.y.size <= 0 else wave.falling_curve.y.min()
                  for wave in waves]
        style = '--+'
        if waves.__class__.__module__ == 'ajustador.loader':
            style = '-o'
        axis.plot(minima, getattr(waves, what).x, style,
                  label=getattr(waves, 'name', '(mixed)'))
    axis.legend(loc='best', fontsize=8)

    fig.canvas.draw()
    fig.show()
    return fig

def plot_param_space(group, measurement=None, *what, **options):
    """3-D scatter of up to three parameters of the evaluations in *group*.

    Options
    -------
    age : bool
        Colour points by evaluation index (1..N) instead of fitness.
    fitness : callable
        Fitness function, default ``fitnesses.combined_fitness``.
    history : bool
        Also connect the points in evaluation order with a black line.

    Points are coloured, with a colorbar, only when *measurement* is given.
    The fitness function is only called in that case.

    Returns the figure.
    """
    age = options.get('age', False)
    fitness_func = options.get('fitness', fitnesses.combined_fitness)
    history = options.get('history', False)
    values = group.param_values(*what)

    # the colour values are only used when a measurement is supplied
    fitness = None
    if measurement is not None:
        if age:
            fitness = np.arange(1, len(values)+1)
        else:
            fitness = [fitness_func(item, measurement) for item in group]

    f = _get_graph('param space')
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6
    f.canvas.manager.set_window_title('3-param view for {}'.format(fitness_func.__name__))
    # _get_graph returns an empty figure, so this is the axes gca() would
    # create; gca(projection=...) is rejected by matplotlib >= 3.6.
    ax = f.add_subplot(111, projection='3d')
    if fitness is not None:
        sca = ax.scatter(*values.T, c=fitness)
        f.colorbar(sca, shrink=0.5, aspect=10)
    else:
        ax.scatter(*values.T)
    ax.set_xlabel(what[0])
    ax.set_ylabel(what[1])
    if len(what) > 2:
        ax.set_zlabel(what[2])

    if history:
        ax.plot(*values.T, c='k')

    f.canvas.draw()
    f.show()
    return f

def plot_history(groups, measurement=None, *,
                 show_quit=False, labels=None, ymax=None, fitness=None,
                 clear=True,
                 newplot=False,
                 Norm=None):
    """Plot fitness against evaluation number for one or more fit groups.

    Clicking a point opens a comparison plot of that evaluation, taken from
    the first group, against *measurement*. NeuroRD simulations use
    ``plot_neurord_tog``; anything else uses ``plot_together``.

    Parameters
    ----------
    groups
        A group of simulations, or a sequence of groups. If ``groups[0]`` has
        a ``name`` attribute, *groups* is taken to be a single group. Each
        group is iterable and has ``fitness_func``.
    measurement
        Reference data passed to the fitness function. If it has a non-empty
        ``name``, that names the figure. Otherwise the second-to-last path
        component of ``groups[0].dirname`` is used.
    show_quit : bool
        Draw evaluations flagged by ``fitnesses.fit_finished`` in grey.
    labels : sequence of str, optional
        Legend label per group. Default: ``'<group name> <fitness name>'``.
    ymax : float, optional
        Upper y limit.
    fitness : callable, optional
        Fitness function overriding each group's ``fitness_func``.
    clear, newplot : bool
        Forwarded to ``_get_graph``.
    Norm : str, optional
        Forwarded as ``norm`` to ``plot_neurord_tog`` when a point is picked.

    Returns the history figure.
    """
    if hasattr(groups[0], 'name'):
        groups = groups,

    func = fitness or groups[0].fitness_func
    # measurement may be None (the default) or have no name
    if getattr(measurement, 'name', None):
        name = 'fit history {}'.format(measurement.name)
    else:
        name = 'fit history {}'.format(groups[0].dirname.split('/')[-2])
    f = _get_graph(name, clear=clear, newplot=newplot)
    ax = f.gca()

    # rotate the palette by the number of overlays already drawn into this
    # figure (clear=False) so successive calls get different styles
    colors = list('rgbkmc')
    markers = 'x+12348'
    cshift = f.plot_counter % len(colors)
    mshift = f.plot_counter % len(markers)
    colors = colors[cshift:] + colors[:cshift]
    markers = markers[mshift:] + markers[:mshift]

    for i, group in enumerate(groups):
        func = fitness or group.fitness_func

        # named so it does not shadow the ``fitnesses`` module used below
        fitness_values = pd.DataFrame([func(item, measurement) for item in group])

        color = colors[i % len(colors)]
        marker = markers[i % len(markers)]

        label = (labels[i] if labels is not None else
                 '{} {}'.format(group.name, func.__name__))
        if show_quit:
            # NOTE(review): assumes fitnesses.fit_finished returns a boolean
            # mask over the evaluations — confirm.
            finished = np.asarray(fitnesses.fit_finished(fitness_values), dtype=bool)
            # ``~`` (not unary ``-``) negates a boolean array in numpy
            ax.plot(fitness_values[~finished], color + marker, label=label, picker=5)
            ax.plot(fitness_values[finished], marker=marker, color='0.5', picker=5)
        else:
            ax.plot(fitness_values, color + marker, label=label, picker=5)

    if ymax is not None:
        ax.set_ylim(top=ymax)
    ax.legend(frameon=True, loc='upper right', fontsize=8, numpoints=1)
    ax.set_xlabel('model evaluation')
    ax.set_ylabel(func.__name__)
    f.tight_layout()
    f.canvas.draw()

    def onpick(event):
        """Show the picked evaluation next to the measurement."""
        xdata = event.artist.get_xdata()
        x = xdata[event.ind][0]
        # FIXME: map from artist to group; the first group is always used
        sim = groups[0][int(x)]

        texts = []
        if hasattr(sim, 'report'):
            texts.append(sim.report())
        if hasattr(measurement, 'report'):
            texts.append(measurement.report())

        if isinstance(sim, xml.NeurordSimulation):
            # every second word of the name is a parameter value
            params = sim.name.split()[1::2]
            fit_dict = func(sim, measurement, full=1)
            for mol, molfit in fit_dict.items():
                parts = [mol] + ['{}: {}'.format(fit_name, round(value, 2))
                                 for fit_name, value in molfit.items()]
                texts.append(' '.join(parts))
            print(params)
            print('Fitness report', fit_dict)
            plot_neurord_tog(measurement, sim,
                             labels='iteration {}:{}'.format(x, ' '.join(params)),
                             fit_rpt=texts, norm=Norm)
        else:
            if measurement:
                fig = plot_together(measurement, sim,
                                    labels=[None, '{}: {}'.format(x, sim.name)])
                if hasattr(func, 'report'):
                    texts.append(func.report(sim, measurement))
            else:
                fig = plot_together(sim)
            if texts:
                target = fig.axes[0]
                # position in the target axes' own coordinates (top-left corner)
                target.text(0, 1, '\n\n'.join(texts),
                            verticalalignment='top',
                            transform=target.transAxes,
                            fontsize=7)
                fig.canvas.draw()

    # replace any handler installed by an earlier call on this figure
    if hasattr(f, '_pick_event_id'):
        f.canvas.mpl_disconnect(f._pick_event_id)
    f._pick_event_id = f.canvas.mpl_connect('pick_event', onpick)

    f.show()
    return f

def plot_param_view(group, measurement, *what, **options):
    """3-D scatter of fitness (z, and colour) over two parameters (x, y).

    Options
    -------
    fitness : callable
        Fitness function, default ``fitnesses.combined_fitness``.
    history : bool
        Also connect the parameter points in evaluation order with a black
        line.

    Returns the figure.
    """
    fitness_func = options.get('fitness', fitnesses.combined_fitness)

    values = group.param_values(*what)
    fitness = [fitness_func(item, measurement) for item in group]

    f = _get_graph('param space')
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6
    f.canvas.manager.set_window_title('2-param view for {}'.format(fitness_func.__name__))
    # _get_graph returns an empty figure, so this is the axes gca() would
    # create; gca(projection=...) is rejected by matplotlib >= 3.6.
    ax = f.add_subplot(111, projection='3d')
    sca = ax.scatter(values[:, 0], values[:, 1], fitness, c=fitness)
    f.colorbar(sca, shrink=0.5, aspect=10)
    ax.set_xlabel(what[0])
    ax.set_ylabel(what[1])
    ax.set_zlabel("fitness")

    history = options.get('history', False)
    if history:
        ax.plot(*values.T, c='k')

    f.canvas.draw()
    f.show()
    return f

def plot_param_section(group, measurement, *what, regression=False,
                       fitness=None, fitness_name=None,
                       log=False):
    """Scatter fitness against each parameter, one subplot per parameter.

    Points are coloured by evaluation index. Subplots are filled column by
    column, using two columns once there are six or more parameters.

    Parameters
    ----------
    group
        Fit group providing ``param_names()``, ``param_values(*names)`` and,
        unless *fitness* is given, ``fitness_func``.
    measurement
        Passed to the fitness function. When ``None`` the function is called
        with the item only.
    *what : str
        Parameters to plot. Default: all of ``group.param_names()``.
    regression : bool
        Overlay a least-squares line per parameter.
    fitness : callable, optional
        Fitness function. Default: ``group.fitness_func``.
    fitness_name : str, optional
        Label for the fitness axis. Default: the function's ``__name__``.
    log : bool or str
        ``True`` gives a ``'symlog'`` y scale. A string is used as the scale
        name.

    Returns the figure.

    Raises
    ------
    ValueError
        If there are no parameters to plot.
    """
    if not what:
        what = group.param_names()
    if not len(what):
        raise ValueError('plot_param_section: no parameters to plot')
    columns = 1 if len(what) < 6 else 2

    if fitness is None:
        fitness = group.fitness_func
    if fitness_name is None:
        fitness_name = getattr(fitness, '__name__', str(fitness))

    values = group.param_values(*what)
    # named so it does not shadow the ``fitnesses`` module
    fitness_values = [fitness(item, measurement) if measurement is not None else fitness(item)
                      for item in group]

    rows = int(math.ceil(values.shape[1] / columns))

    f = _get_graph(' '.join(('param section',
                             getattr(group, 'name', '(no name)'),
                             fitness_name)))
    f.subplots_adjust(left=0.08, bottom=0.06, right=0.96, top=0.97,
                      wspace=0.17, hspace=0.24)

    for n, param in enumerate(what):
        # column-major placement: parameter n goes to row n % rows, column n // rows
        ax = f.add_subplot(rows, columns, (n % rows)*columns + n//rows + 1)
        res = ax.scatter(values.T[n], fitness_values,
                         c=range(len(values)))

        if regression:
            a, b = stats.linregress(values.T[n], fitness_values)[:2]
            x1, x2 = values.T[n].min(), values.T[n].max()
            ax.plot([x1, x2], [a*x1+b, a*x2+b], 'r--')

        if log:
            ax.set_yscale('symlog' if isinstance(log, int) else log)

        # label the middle panel of the first column; with column-major
        # placement that column holds parameters 0 .. rows-1
        if n == (rows - 1) // 2:
            ax.set_ylabel(fitness_name)
        # parameter name on the right-hand side, without ticks
        ax2 = ax.twinx()
        ax2.set_ylabel(param)
        ax2.set_yticks([])

    f.colorbar(res, ax=f.axes, shrink=0.5, aspect=10)
    f.canvas.draw()
    f.show()
    return f


def _product(seq):
    return reduce(operator.mul, seq, 1)

def clutter(array):
    """Score how "busy" a 2-D layout looks; lower is smoother.

    Returns the root-mean-square of the differences between adjacent
    elements along both axes, ignoring NaNs. Layouts with more rows than
    columns score ``np.inf``, because wide layouts fit the window better.
    """
    nrows, ncols = array.shape[0], array.shape[1]
    if nrows > ncols:
        return np.inf
    vertical = np.diff(array, axis=0) ** 2
    horizontal = np.diff(array, axis=1) ** 2
    squared = np.concatenate((vertical.ravel(), horizontal.ravel()))
    return np.nanmean(squared) ** 0.5

def cbdr(values, func, xnames, yname, order=None, debug=False):
    """We have n dimensions, with a shape like (d0, d1, ..., d(n-1)).
    Each variable has a range... but let's map them to (0,1).
    Then final mapping is:

    X = x'(n-1) + x'(n-3) * d(n-1) + ... + x'(0 or 1) * d(2 or 3)
    Y = x'(n-2) + x'(n-4) * d(n-2) + ... + x'(1 or 0) * d(3 or 2)

    where

    x'(i) = [x(i) - min x(i)] / [max x(i) - min x(i)]

    So the multiplier for x' is

    (1, 1, d(2), d(3), d(4), ..., d(n-1))

    Parameters
    ----------
    values : ndarray
        One row per evaluation, one column per parameter.
    func : sequence of float
        Fitness for each row of *values*.
    xnames : sequence of str
        Name of each column of *values*.
    yname : str
        Fitness name, used in the titles.
    order : tuple of int, optional
        Fixed axis permutation. By default every permutation is tried and
        the one with the lowest `clutter` is kept.
    debug : bool
        Print the clutter of every permutation tried, and show the squared
        neighbour differences of the chosen image in a second figure.

    Returns the figure.
    """
    dimsplit = values.shape[1] // 2
    orders = ((order,) if order is not None
              else itertools.permutations(range(values.shape[1])))

    xorig, yorig = utilities.arange_values(values, func)

    # keep the axis permutation whose flattened 2-D image is least cluttered
    best = np.inf
    for perm in orders:
        _xs = utilities.reorder_list(xorig, perm)
        _ys = utilities.reorder_array(yorig, perm)

        _ys_shape = np.array(_ys.shape)
        _finalshape = (_product(_ys_shape[:dimsplit]), _product(_ys_shape[dimsplit:]))
        _final = np.resize(_ys, _finalshape)

        cl = clutter(_final)

        if debug:
            print('{} {} → rms(clutter)={}, {}'
                  .format(perm,
                          '-'.join(xnames[i] for i in perm), cl,
                          '*' if cl < best else ''))
        # np.isinf(best): accept the first candidate even if its clutter is inf
        if cl < best or np.isinf(best):
            xs, ys = _xs, _ys
            best = cl
            order = perm
            finalshape, final = _finalshape, _final
            ys_shape = _ys_shape

    print('Parameters:')
    m = max(len(p) for p in xnames)
    for i in range(len(order)):
        print('(axis {}) {}: {:{}} {}'.format(order[i], '-|'[i < dimsplit], xnames[order[i]], m, xs[i].flatten()))

    f = _get_graph('cbdr')
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6
    f.canvas.manager.set_window_title('cbdr {} × {} → {}'
                                      .format('-'.join(xnames[i] for i in order[:dimsplit]),
                                              '-'.join(xnames[i] for i in order[dimsplit:]),
                                              yname))
    ax = f.gca()
    rms = (np.array(func)**2).mean()**0.5
    ax.set_title('{} rms(fitness)={} rms(clutter)={}'.format(yname, rms, best))
    im = ax.imshow(final, interpolation='none', origin='lower')
    ax.set_xticks([])
    ax.set_yticks([])
    f.colorbar(im, shrink=0.5, aspect=10)

    if debug:
        f2 = _get_graph('cbdr - clutter')
        print('final shape', finalshape, final.shape)
        im = f2.add_subplot(2, 1, 1).imshow(np.diff(final, axis=0)**2,
                                            interpolation='none', origin='lower')
        f2.colorbar(im, shrink=0.5, aspect=10)
        im = f2.add_subplot(2, 1, 2).imshow(np.diff(final, axis=1)**2,
                                            interpolation='none', origin='lower')
        f2.colorbar(im, shrink=0.5, aspect=10)
        f2.canvas.draw()
        f2.show()

    # grey bars outside the image show the extent of each nested axis,
    # labelled with the parameter name
    for i in range(len(ys_shape)):
        if i < dimsplit:
            size = _product(ys_shape[i+1:dimsplit])
            w, h = 1, size
            pos = (-dimsplit + i) * 2 - 1.5, 0 - .5
            textpos = pos[0] + .5, size - .25
            textopt = dict(verticalalignment='bottom', horizontalalignment='center', rotation=90)
        else:
            size = _product(ys_shape[i+1:])
            w, h = size, 1
            pos = 0 - .5, (-len(ys_shape) + i) * 2 - 1.5
            textpos = size - .25, pos[1] + .5
            textopt = dict(verticalalignment='center')
        ax.add_patch(patches.Rectangle(pos, w, h, clip_on=False, alpha=0.3, facecolor='grey'))
        ax.text(textpos[0], textpos[1], xnames[order[i]], **textopt)

    f.canvas.draw()
    f.show()
    return f

def plot_flat(group, measurement, *what, **options):
    """Show the fitness over all varying parameters flattened to an image.

    Parameters that barely vary (peak-to-peak <= 1e-10) are dropped, and
    the rest are laid out with `cbdr`.

    Options
    -------
    fitness : callable
        Fitness function, default ``fitnesses.combined_fitness``.
    log : bool
        Plot ``log(fitness)``.

    Any remaining options (e.g. ``order``, ``debug``) are forwarded to
    `cbdr`. Returns the figure from `cbdr`.
    """
    if not what:
        what = group.param_names()

    fitness_func = options.pop('fitness', fitnesses.combined_fitness)
    log = options.pop('log', False)

    values = group.param_values(*what)
    # the leftover options are cbdr's arguments, so the fitness function only
    # receives the item and the measurement
    fitness = [fitness_func(item, measurement) for item in group]

    nontrivial = np.ptp(values, axis=0) > 1e-10
    values = values[:, nontrivial]
    what = np.array(what)[nontrivial]
    print(values)
    print(what)
    if log:
        fitness = np.log(fitness)

    return cbdr(values, fitness, what, fitness_func.__name__, **options)

def _make_grid(values, npoints=200):
    # values is (measures × dimensions)
    xi = (np.linspace(dim.min(), dim.max(), npoints)
          for dim in values.T)
    return np.meshgrid(*xi, sparse=True)

def find_min_values(values, fitness):
    """Collapse duplicate parameter points, keeping the lowest fitness.

    Returns ``(points, best)``. *points* holds the unique rows of *values*
    in sorted order, and *best* holds the minimum fitness seen for each of
    them.
    """
    ncols = values.shape[1]
    table = pd.DataFrame(np.column_stack((values, np.array(fitness))))
    best = table.groupby(list(range(ncols))).min().reset_index()
    arr = best.values
    return arr[:, :-1], arr[:, -1]

def plot_map(group, measurement, *what, **options):
    """Draw an interpolated 2-D map of fitness over two parameters.

    For duplicated parameter points the lowest fitness is kept (see
    `find_min_values`). The points are then interpolated onto a regular
    grid with ``scipy.interpolate.griddata``.

    Options
    -------
    fitness : callable
        Fitness function, default ``fitnesses.combined_fitness``.
    log : bool
        Map ``log(fitness)``. The rms in the title is computed before the
        log is taken.
    dots : bool
        Overlay the sampled parameter points.
    method : str
        ``griddata`` interpolation method: ``'linear'`` (default),
        ``'nearest'`` or ``'cubic'``.

    Any other options are passed both to the fitness function and to
    ``imshow``. Returns the figure.
    """
    fitness_func = options.pop('fitness', fitnesses.combined_fitness)
    log = options.pop('log', False)
    dots = options.pop('dots', False)
    # griddata interpolation method; 'linear' matches griddata's own default
    method = options.pop('method', 'linear')

    values = group.param_values(*what)
    # NOTE(review): leftover options reach both fitness_func and imshow, so
    # each one must be acceptable to both — confirm this is intended.
    fitness = [fitness_func(item, measurement, **options) for item in group]

    rms = (np.array(fitness)**2).mean()**0.5
    if log:
        fitness = np.log(fitness)

    yname = fitness_func.__name__
    f = _get_graph('param map')
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6
    f.canvas.manager.set_window_title('params {} × {} → {}'.format(what[0], what[1], yname))

    values, fitness = find_min_values(values, fitness)
    grid_x, grid_y = _make_grid(values)
    points = interpolate.griddata(values, fitness, (grid_x, grid_y), method=method)

    extent = (values[:,0].min(), values[:,0].max(),
              values[:,1].min(), values[:,1].max())

    ax = f.gca()
    ax.set_title('{} rms(fitness)={}'.format(yname, rms))
    ax.set_xlabel(what[0])
    ax.set_ylabel(what[1])
    im = ax.imshow(points,
                   extent=extent,
                   origin='lower', aspect='auto', **options)
    ax.set_xlim(extent[0], extent[1])
    ax.set_ylim(extent[2], extent[3])
    f.colorbar(im, shrink=0.5, aspect=10)

    if dots:
        ax.plot(values[:,0], values[:,1], 'k.', ms=1)

    f.canvas.draw()
    f.show()
    return f