#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 25 17:36:15 2019

@author: dalbis
"""


import pylab as pl
import numpy as np
import grid_utils.plotlib as pp
import grid_utils.gridlib as gl
import grid_utils.simlib as sl

from recamp_2pop import RecAmp2PopLearn,RecAmp2PopSteady
import amp_paper_2d_main as apm


# colour used for the input population in the tuning-index histogram
input_color='m'


#%%
### LOAD DEFAULT SIMULATION DATA

# No overrides: both parameter sets are the defaults from apm merged with an empty dict.
extra_params={}

learn_params=sl.map_merge(apm.def_recamp_learn_params,extra_params)
sim_conn=RecAmp2PopLearn(learn_params)

# Point the second simulation at the data path of the first one for its recurrent weights.
weights_source={'recurrent_weights_path':sim_conn.data_path}
steady_params=sl.map_merge(apm.def_recamp_steady_params,extra_params,weights_source)
sim=RecAmp2PopSteady(steady_params)

# Load weights, outputs and scores from disk; the plotting cells below read them from `sim`.
sim.post_init()
sim.load_weights_from_data_path(sim.recurrent_weights_path)
sim.load_steady_outputs()
sim.load_steady_scores()



#%% ========================================================================================

### EXAMPLE INPUTS

# columns of sim.inputs.inputs_flat to display, one panel each
input_idxs=[1,2,3]

# coordinate vector passed to pcolormesh for both axes
dx=sim.L/sim.nx
xran=np.arange(sim.nx)*sim.L/sim.nx-sim.L/2-dx/2.

pl.figure(figsize=(5.5,3))
pl.subplots_adjust(wspace=0.2)

# title shows the map maximum followed by the input grid tuning index in bold
title_fmt='%2.1f    $\\bf{%2.2f}$'

for panel,input_idx in enumerate(input_idxs,1):
  pl.subplot(2,5,panel,aspect='equal')

  # one column of the flat input matrix, reshaped to an nx-by-nx map and transposed
  rate_map=sim.inputs.inputs_flat[:,input_idx].reshape(sim.nx,sim.nx).T
  pl.pcolormesh(xran,xran,rate_map,vmin=0,rasterized=True)

  pl.title(title_fmt%(rate_map.max(),sim.grid_tuning_in[input_idx]),fontsize=8)
  pp.noframe()


fname = 'fig2d_model_example_inputs'
pp.save_fig(sl.get_figures_path(),fname,exts=['png','svg'])

#%%

### PLOT EXAMPLE EXCITATORY/INHIBITORY OUTPUTS ===========================================

output_scores_to_show=[0.35,0.4,0.45,0.5,0.6]

# for each target score, pick the index whose entry in sim.grid_tuning_out is closest to it
exc_cell_idxs=[np.abs(sim.grid_tuning_out-target).argmin() for target in output_scores_to_show]

# indices relative to the inhibitory population; the sim.N_e offset is added when indexing sim.r
inhib_cell_idxs=np.array([0,1,2,3,4])

# coordinate vector passed to pcolormesh for both axes
dx=sim.L/sim.nx
xran=np.arange(sim.nx)*sim.L/sim.nx-sim.L/2-dx/2.

plot_scores=True
vmax=None

# one figure per population: excitatory first, then inhibitory
for inhibitory in False,True:

  if inhibitory:
    cell_idxs=inhib_cell_idxs
    fname='fig2d_model_inhib_outputs'
  else:
    cell_idxs=exc_cell_idxs
    fname='fig2d_model_exc_outputs'

  pl.figure(figsize=(5.5,3))
  pl.subplots_adjust(wspace=0.2)

  for idx,cell_idx in enumerate(cell_idxs):

    pl.subplot(2,5,idx+1,aspect='equal')

    # inhibitory rows of sim.r are addressed with an offset of sim.N_e
    row=cell_idx+sim.N_e if inhibitory else cell_idx
    r_map=sim.r[row,:].reshape(sim.nx,sim.nx).T

    pl.pcolormesh(xran,xran,r_map,rasterized=True,vmin=0,vmax=vmax)

    if plot_scores:
      # title shows the map maximum followed by the cell's grid tuning index in bold
      scores=sim.grid_tuning_out_inhib if inhibitory else sim.grid_tuning_out
      pl.title('%2.1f    $\\bf{%2.2f}$'%(r_map.max(),scores[cell_idx]),fontsize=8)

    pp.noframe()

  pp.save_fig(sl.get_figures_path(),fname,exts=['png','svg'])
    
#%% =================================================================================

### PLOT RECURRENT CONNECTIVITY


# Two figures: sim.W_ee is saved under the "learned" file name, sim.W_ee0 under the "init" one.
for learned in True,False:

  W=sim.W_ee if learned else sim.W_ee0
  pp.plot_recurrent_weights(W,sim.gp,vmax=sim.W_max_ee,ms=5,figsize=(3.2,3.5))
  tuning_index= gl.get_recurrent_matrix_tuning_index(W,sim.gp)

  fname = 'fig2d_model_rec_weights_learned' if learned else 'fig2d_model_rec_weights_init'
  pp.save_fig(sl.get_figures_path(),fname,exts=['png','svg'])

  # Single-argument parenthesised print: same output on Python 2, and valid syntax on Python 3.
  print('Connectivity tuning index: %.3f'%tuning_index)

#%%

  
### PLOT TUNING INDEX HISTOGRAMS


pl.rc('font',size=10)

# 100 equal-width bins over [0,1]; the same edges are used for every histogram below
bins = np.histogram_bin_edges(sim.grid_tuning_in,bins=100,range=[0,1])

pl.figure(figsize=(2.7,2.3))
pl.subplots_adjust(left=0.3,bottom=0.26,right=0.95)

# Drawing order matters: the inhibitory scores are drawn filled first and then
# once more as an outline-only 'step' histogram after the other two.
hist_layers=(
  (sim.grid_tuning_out_inhib,'dodgerblue','stepfilled'),
  (sim.grid_tuning_in,input_color,'stepfilled'),
  (sim.grid_tuning_out,'black','stepfilled'),
  (sim.grid_tuning_out_inhib,'dodgerblue','step'),
)

for scores,color,histtype in hist_layers:
  # each sample contributes 1/N, so bar heights are fractions of cells rather than counts
  fraction_weights=np.ones_like(scores)/float(len(scores))
  pl.hist(scores,bins=bins,color=color,histtype=histtype,weights=fraction_weights,alpha=1)

pp.custom_axes()
pl.xlim(0,0.7)
pl.xlabel('Grid tuning index')
pl.ylabel('Fraction of cells')

# Single-argument parenthesised print: same output on Python 2, and valid syntax on Python 3.
print('Mean input grid tuning index: %.2f'%np.mean(sim.grid_tuning_in))
print('Mean output grid tuning index: %.2f'%np.mean(sim.grid_tuning_out))


fname = 'fig2d_model_grid_tuning_hists'

# NOTE(review): saving of this figure is disabled in the original; re-enable if the file is wanted.
#pp.save_fig(sl.get_figures_path(),fname,exts=['png','svg'])