#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 14 21:28:15 2017

@author: spiros
"""

import numpy as np
import pickle, sys, os, time
import scipy.ndimage.filters as flt
from functions_analysis import spike_map, binning
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
matplotlib.use('agg')
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def analysis_path_cluster(ntrial, case, learning):
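    """
    Build, plot, and save place-cell rate maps for one trial.

    Parameters (all strings, as passed on the command line):
    ntrial   -- trial index used in the input/output file names
    case     -- condition key used to index the pre-computed spike-time data
    learning -- learning-phase label; also names the figures/metrics sub-folders
    """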
    
    folder1 = 'data_analysis'

    os.system('mkdir -p ' + folder1 + '/figures/')
    os.system('mkdir -p ' + folder1 + '/metrics/')
    fdname1 = '/' + folder1 + '/figures/'
    fdname2 = '/' + folder1 + '/metrics/'

    print "Analyse ... " + case +" trial "+ntrial +" " +learning 
    
    os.system('mkdir -p ' + folder1 + '/figures/' + learning + '/')
    os.system('mkdir -p ' + folder1 + '/metrics/' + learning + '/')
    maindir = os.getcwd()

    # Path (track) dimensions
    npath_x = 200
    npath_y = 1
    # Number of pyramidal cells and number of spatial bins
    Ncells  = 130
    Nbins   = 100
    # Gaussian smoothing sigma: 3 path units expressed in bins (npath_x/Nbins path units per bin)
    skernel = 3.0 / (npath_x/Nbins)
    # Number of runs per trial
    runsAll = 5
    
    # Pre-allocate 3-d arrays (cell x bin x y) for all pyramidal cells
    rateMaps            = np.zeros((Ncells, Nbins, npath_y))
    rateMaps_unsmoothed = np.zeros((Ncells, Nbins, npath_y))
    time_array_in_bin   = np.zeros((Ncells, Nbins, npath_y))
    
    # File location of the pre-computed path and spike-time data
    fileload = folder1 + '/metrics_permutations/' + learning

    with open(fileload+'/path_all_trial_'+str(ntrial)+'.pkl', 'rb') as f:
        path_all = pickle.load(f)

    with open(fileload+'/spiketimes_all_trial_'+str(ntrial)+'.pkl', 'rb') as f:
        spiketimes_all = pickle.load(f)
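    # Expected layout of the loaded data (inferred from the access pattern below):
    #   path_all[nrun-1]: positions of run 'nrun' at each time step (column 0 = x position)
    #   spiketimes_all['Run<k>']['Pyramidal<i>'][case]: spike times of cell i in run k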
    
    # Loop over all pyramidal cells
    for npyr in range(Ncells):

        # Accumulators over runs: spike counts per bin and occupancy time per bin
        Zall = np.zeros((Nbins, npath_y))
        time_array_all = np.zeros(Nbins*npath_y)

        for nrun in range(1,runsAll+1):
            
            # Path of this run (different for each run)
            path = path_all[nrun-1]

            # Occupancy: number of time steps spent at each x position (1..npath_x),
            # and its cumulative sum (presumably used by spike_map to map spike
            # times onto track positions)
            time_array = np.bincount(path[:,0])[1:]
            csum = np.cumsum(time_array)

            # Spike times of this cell for this run and condition (loaded above)
            spiketimes = spiketimes_all['Run'+str(nrun)]['Pyramidal'+str(npyr)][case]

            # Spike counts along the track for this run
            Z = spike_map(spiketimes, csum, npath_x, npath_y)
            # Accumulate spike counts over runs, re-binned to Nbins
            Zall += binning(Z, Nbins, 'summing')
            time_array_binned = binning(time_array, Nbins, 'summing').squeeze()

            # Occupancy per bin, converted from ms to seconds
            time_array_all += time_array_binned / 1000.0

        # Smooth occupancy and spike counts with the same Gaussian kernel
        time_smoothed = flt.gaussian_filter1d(time_array_all, sigma=skernel, mode='nearest', truncate=3.0)
        Zsmoothed     = flt.gaussian_filter1d(Zall.squeeze(), sigma=skernel, mode='nearest', truncate=3.0)
        # Firing rate in Hz: smoothed spike counts divided by smoothed occupancy (seconds)
        Zmean = np.divide(Zsmoothed, time_smoothed)
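        # Note: a bin with zero smoothed occupancy would give inf/nan here; occupancy
        # is assumed to be non-zero everywhere on this fully traversed linear track.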
        
        
        rateMaps_unsmoothed[npyr, :, :] = Zall
        rateMaps[npyr, :, :]            = Zmean.reshape(-1, 1)
        time_array_in_bin[npyr, :, :]   = time_array_all.reshape(-1, 1)
        

    print '\nDone with the rate maps'


    # One normalised rate-map panel per cell (13 x 10 grid = 130 cells)
    fig, axes = plt.subplots(nrows=13, ncols=10, figsize=(20, 20))
    nn = 0
    for ax in axes.flat:
        Max = np.max(rateMaps[nn, :, :])
        norm = Max if Max > 0 else 1.0  # avoid division by zero for silent cells
        im = ax.imshow(rateMaps[nn, :, :].T/norm, origin='lower', cmap="jet", aspect=10)
        ax.tick_params(axis='y', which='both', right=False, left=False, labelleft=False)
        ax.title.set_text('PC'+str(nn) + '  ' + str(np.round(Max, 1)) + ' Hz')
        nn += 1

    fig.colorbar(im, ax=axes.ravel().tolist())
    
    if not os.path.exists(maindir+fdname1+learning+'/'):
        os.makedirs(maindir+fdname1+learning+'/')    
    
    plt.savefig(maindir+fdname1+learning+'/'+case+'_'+str(ntrial)+'_heatmap.pdf',format='pdf',dpi=300)
#    plt.savefig(maindir+fdname1+learning+'/'+case+'_'+str(ntrial)+'_heatmap.png',format='png',dpi=300)
#    plt.close(fig)
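    # Sort cells by the bin of their peak firing rate so the population heatmap
    # below shows place fields ordered along the track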
    
    idx        = np.argmax(rateMaps.squeeze(), axis=1)
    new_idx    = np.lexsort((range(Ncells), idx))
    rtMaps     = rateMaps[new_idx,:,:].squeeze()
    
          
    rate_maps_all = rtMaps

    # Population heatmap: all cells stacked, ordered by place-field peak location
    fig, ax = plt.subplots(figsize=(15, 15))
    im = ax.imshow(rate_maps_all, cmap="jet", aspect='equal')
    # create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    ax.set_xlim((0, Nbins))
    ax.set_xticks(range(0,Nbins+1, 50 / (npath_x/Nbins)) )
    ax.set_xticklabels(['-0.5', '-0.25', '0', '0.25', '0.5'], fontsize = 13)
    ax.set_yticks(range(0,Ncells+1, 20))
    ax.set_yticklabels([str(x) for x in range(0,Ncells+1, 20)], fontsize = 13)    
    ax.set_title(case, fontsize=14)
    plt.savefig(maindir+fdname1+learning+'/'+case+'_'+str(ntrial)+'_heatmap_all_cells.pdf',format='pdf',dpi=300)
#    plt.savefig(maindir+fdname1+learning+'/'+case+'_'+ntrial+'_heatmap_all_cells.png',format='png',dpi=300)
    #==============================================================================
    # ##################### RATE MAPS SAVING #################################
    #==============================================================================

    mydict = {'maps': rateMaps,
              'maps_unsmoothed': rateMaps_unsmoothed,
              'time_in_bin': time_array_in_bin}
    filesave = maindir+fdname2+learning
    if not os.path.exists(filesave):
        os.makedirs(filesave)
    
    with open(filesave+'/pickled_sn_'+case+'_'+ntrial+'.pkl', 'wb') as handle:
        pickle.dump(mydict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
    print "\nDone with "+case+" analysis. Done with trial "+ntrial


tic      = time.time()
ntrial   = sys.argv[1]
case     = sys.argv[2]
learning = sys.argv[3]
analysis_path_cluster(ntrial, case, learning)
toc      = time.time()

print "\nTotal time: "+str(round(toc-tic,3))+" seconds"