# Copyright (c) 2022 Adrian Negrean
# negreanadrian@gmail.com
#
# Software released under MIT license, see license.txt for conditions
import math
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.ticker as tck
import seaborn as sb

def distribute_subplots(max_nrows, max_ncols, nsubplots, layout = 'flexible'):
    """
    Distributes n subplots into multiple subplot grids of maximum size nrows x ncols.
    Distribution is done first by row, then by column with [0,0] corresponding to
    the upper left corner.

    Parameters
    ----------
    max_nrows, max_ncols : int
        Maximum number of rows and columns in a figure. Both must be at least 1.

    nsubplots : int
        Total number of subplots to distribute. Must be non-negative.

    layout : str
        Choose 'fixed' to distribute plots on a fixed grid size on each page or
        choose 'flexible' to adjust the grid size depending on the number of plots
        to maximize space filling on each page.

    Returns
    -------
    list of tuple:
        One list element per figure needed to distribute the plots (empty list when
        nsubplots is 0). Each tuple is of the form
        ((nrows, ncols), [(plt_idx, row_1, col_1),...(plt_idx, row_N, col_N)])
        where:
        - (nrows, ncols) is the grid size of the figure in # rows and # columns
        - plt_idx is the global plot index in [0, nsubplots)
        - row_i, col_i are the 0-indexed grid coordinates of that subplot

    Raises
    ------
    ValueError
        If layout is not 'flexible' or 'fixed', if max_nrows or max_ncols is smaller
        than 1, or if nsubplots is negative.
    """
    # validate everything up front: a non-positive grid size would otherwise never
    # consume any subplot (endless loop) or divide by zero
    if layout not in ('flexible', 'fixed'):
        raise ValueError("Grid layout can be 'flexible' or 'fixed'.")
    if max_nrows < 1 or max_ncols < 1:
        raise ValueError("max_nrows and max_ncols must be at least 1.")
    if nsubplots < 0:
        raise ValueError("nsubplots must be non-negative.")

    out = []
    # global index of the first subplot placed on the current figure
    first_plt_idx = 0
    while first_plt_idx < nsubplots:
        nleft_to_assign = nsubplots - first_plt_idx
        if layout == 'flexible':
            # shrink the grid to what is still needed; -(-a // b) is an exact integer ceil(a/b)
            ncols = min(max_ncols, nleft_to_assign)
            nrows = min(max_nrows, -(-nleft_to_assign // ncols))
        else:
            ncols = max_ncols
            nrows = max_nrows

        # number of subplots that fit on the current figure
        nplts = min(nleft_to_assign, ncols*nrows)
        # fill row by row: divmod gives (row, col) for the k-th subplot of this figure
        fig_plots = [(first_plt_idx + k,) + divmod(k, ncols) for k in range(nplts)]

        out.append(((nrows, ncols), fig_plots))
        first_plt_idx += nplts

    return out

def rec_grid_plot(pset, mruns, recpar, dt, secseg_names_filter = None):
    """
    Plot recorded variables over a grid.

    Each grid row shows one section/segment name and each grid column one index along
    the first axis of mruns. Within a subplot one line is drawn per index along the
    last axis of mruns, each with its own color from a cubehelix palette.

    Parameters
    ----------
    pset : dict
        Plot settings. pset["display"] must exist; its optional "xlim" and "ylim"
        entries are applied to every subplot.
    mruns : numpy nd.array of dict
        Recorded parameters, 1D or 2D. Every element is read as
        element["rec"][recpar][name][0] to obtain the trace to plot. A 1D array is
        treated as a single column.
    recpar : str
        Name of recorded parameter to plot. Also used as the legend label.
    dt : float
        Time step in [ms].
    secseg_names_filter : list of str, optional
        Filter sections and segment names. When empty or None, all names recorded
        in the first element of mruns are plotted.

    Returns
    -------
    fig : matplotlib.Figure

    Raises
    ------
    ValueError
        If mruns has more than two dimensions (or none).
    """
    secseg_names = mruns.flat[0]["rec"][recpar].keys()

    # filter sections and segments to plot from all recorded sections and segments
    if secseg_names_filter:
        # NOTE(review): neither `_expand_secseg_list` nor `env_set` is defined inside this
        # function; both must exist at module scope or this branch raises NameError - confirm.
        secseg_names = [x for x in _expand_secseg_list(secsegs = secseg_names_filter, env_set = env_set) if x in secseg_names]

    ndim = len(mruns.shape)
    if ndim == 1:
        # adjust dimensions to make it compatible with plotting (single column)
        mruns = np.expand_dims(mruns,0)
    elif ndim != 2:
        raise ValueError("Parameter sweep dimension {} is not compatible with grid plot.".format(ndim))

    # organize columns by the first axis of the parameter sweep (last axis will be color hue in the multiline plot)
    ncols = mruns.shape[0]
    # plot on each row parameters recorded in a certain section or segment
    nrows = len(secseg_names)

    # number of colors in the palette is determined by the last axis dimension of mruns, i.e. the last axis of the parameter sweep
    cmap = sb.cubehelix_palette(n_colors = mruns.shape[-1], start = 2.7, rot = 0, dark = 0.4, light = .9, reverse = False)
    last_color_idx = len(cmap)-1

    display = pset["display"]

    fig, ax = plt.subplots(nrows, ncols, squeeze = False, sharex = True, sharey = True)
    for row_idx, sec_seg_name in enumerate(secseg_names):
        for col_idx in range(ncols):
            curr_ax = ax[row_idx,col_idx]
            for color_idx, color in enumerate(cmap):
                y = mruns[col_idx,color_idx]["rec"][recpar][sec_seg_name][0]
                x = dt*np.arange(len(y))

                line_kwargs = {"color": color}
                # only the last line of each subplot carries a label, so the legend shows a single entry
                if color_idx == last_color_idx:
                    line_kwargs["label"] = recpar
                curr_ax.plot(x, y, **line_kwargs)

            # legend only on the upper left subplot
            if not row_idx and not col_idx:
                curr_ax.legend()
            # set range of axes
            if "xlim" in display:
                curr_ax.set_xlim(display["xlim"])
            if "ylim" in display:
                curr_ax.set_ylim(display["ylim"])
            # add y-axis minor ticks
            curr_ax.yaxis.set_minor_locator(tck.AutoMinorLocator())

        ax[row_idx, 0].set_ylabel(sec_seg_name)

    return fig

def plot_dendrogram(dtree, secdata, ax, linestyle = "-", color = (0,0,1,1), alpha = None):
    """
    Draws a segment level parameter against distance in a dendrogram style display.

    Parameters
    ----------
    dtree : dict
        Dendrogram structure. Keys are section names; each value is indexable, with
        element 0 holding the distances along that section and element 1 holding a
        dict of child sections laid out the same way.
    secdata : dict
        Section data to plot. Keys are section names; each value is indexed as
        value[0, :, 0] to obtain the samples drawn against the section distances.
    ax : object
        Axes whose plot method is used for all drawing.
    linestyle, color, alpha :
        Line appearance, passed unchanged to every ax.plot call.
    """
    # appearance shared by every line of the dendrogram
    line_style = {"linestyle": linestyle, "color": color, "alpha": alpha}

    def _draw_subtree(subtree, offset):
        """
        Draw all sections of subtree, shifted by offset, then descend into their children.
        """
        for sec_name in subtree:
            sec_dist = subtree[sec_name][0]
            children = subtree[sec_name][1]
            sec_values = secdata[sec_name]

            # absolute distance at which this section ends
            sec_end = offset + sec_dist[-1]

            # the section itself
            ax.plot(offset + sec_dist, sec_values[0,:,0], **line_style)

            # bridge from the end of this section to the first point of each child
            for child_name in children:
                child_start = sec_end + children[child_name][0][0]
                bridge_y = [sec_values[0,-1,0], secdata[child_name][0,0,0]]
                ax.plot([sec_end, child_start], bridge_y, **line_style)

            # children are measured from the end of this section
            _draw_subtree(children, sec_end)

    _draw_subtree(dtree, 0)