# -*- coding: utf-8 -*-
"""
Created on Thu Aug  3 10:25:22 2017

@author: dalbis
"""


import numpy as np
import pylab as pl
import plotlib as pp
import os
from grid_const import ModelType
from grid_functions import map_merge
from grid_params import GridSpikeParams
from grid_batch import GridBatch
from simlib import ensureDir

# Output directory for the generated figure, relative to the working directory.
# ensureDir comes from simlib; presumably it creates the directory when it is
# missing -- confirm against simlib.
figures_path = '../figures'
ensureDir(figures_path)
  
#%%

######################## MULTIPLE LEARNING RATES ###################################


# Learning rates compared in the left panel.
learn_rates = [2e-5, 3e-5, 5e-5, 1e-4]

# Parameters shared by the batches: the preset map with 'a' overridden.
batch_default_map = map_merge(
    GridSpikeParams.gau_grid_small_arena_biphasic_neg,
    {'a': 1.1})

# Time vector: one entry per weight snapshot, in the units of 'dt'.
sim_dt = batch_default_map['dt']
num_snaps = batch_default_map['num_snaps']
num_sim_steps = int(batch_default_map['sim_time'] / sim_dt)
delta_snap = int(num_sim_steps / num_snaps)
snap_times = np.arange(num_snaps) * delta_snap * sim_dt

      
# Global plot style.
pl.rc('font', size=13)
pp.set_tick_size(4)

# Two side-by-side panels: ax1 (left) and ax2 (right).
fig = pl.figure(figsize=(8, 3), facecolor='w')
ax1 = pl.subplot(1, 2, 1)
ax2 = pl.subplot(1, 2, 2)
pl.subplots_adjust(bottom=0.2, wspace=0.4, left=0.1, right=0.97)

# One line colour per learning rate, sampled from the reversed 'gist_heat'
# colormap between 0.2 and 1.0.
colormap = pl.cm.gist_heat_r
line_colors = [colormap(i) for i in np.linspace(0.2, 1.0, len(learn_rates))]

pl.sca(ax1)
# Axes.set_color_cycle has been deprecated since matplotlib 1.5 and is no
# longer available in current releases; prefer set_prop_cycle and fall back to
# the old call only on installations that predate it.
if hasattr(ax1, 'set_prop_cycle'):
    ax1.set_prop_cycle(color=line_colors)
else:
    ax1.set_color_cycle(line_colors)


# For every learning rate (largest first) load -- or first simulate -- a batch
# of 40 seeds and plot the median score over time on the left panel.
for eta in reversed(learn_rates):
    batch_override_map = {'seed': np.arange(40), 'eta': (eta,)}

    batch = GridBatch(ModelType.MODEL_SPIKING, batch_default_map, batch_override_map)
    batch.post_init()

    # Simulations run only when no batch data file exists yet.
    if not os.path.exists(batch.batch_data_path):
        # Parenthesised single-argument print works on Python 2 and Python 3.
        print('Running simulations for learning rate %.3e  (for 40 different initial weights)' % eta)
        batch.run()
        batch.post_run()

    # 'evo_weight_scores_map' is stored as a 0-d object array wrapping a
    # dict-like object (hence the [()] unwrapping), which numpy >= 1.16.3 only
    # loads with allow_pickle=True. The file is produced by this project's own
    # batch run; do not point this at untrusted data.
    batch_data = np.load(batch.batch_data_path, allow_pickle=True)
    # list(...) is required on Python 3, where .values() returns a view that
    # np.array would wrap as a single object instead of stacking the rows.
    all_evo_weight_scores = list(batch_data['evo_weight_scores_map'][()].values())
    all_evo_weight_scores_mat = np.array(all_evo_weight_scores)

    # Median across the first axis (one row per stored entry).
    median_scores = np.median(all_evo_weight_scores_mat, axis=0)

    pl.sca(ax1)
    pl.plot(snap_times, median_scores, '-', lw=1.5)
  
# Cosmetics for the left panel: log time axis starting at 1e5, fixed score
# range, and tick labels expressed in units of 1e5.
pl.sca(ax1)
pp.custom_axes()
ax1.set_xscale('log')
ax1.set_xlim(1e5, snap_times.max())
ax1.set_ylim(0, 1.5)
ax1.set_xlabel('Time [10^5 s]')
ax1.set_ylabel('Gridness score')
ax1.set_yticks([0, 0.5, 1.0, 1.5])
ax1.set_xticks([1e5, 2e5, 5e5, 1e6])
ax1.set_xticklabels(['1', '2', '5', '10'])

#%%

##### PLOT CONSTANT VS VARIABLE SPEED ======================================================
                
# Right panel: batch of 40 seeds with the default parameter map.
pl.sca(ax2)
batch_override_map = {'seed': np.arange(40)}

# constant speed
batch = GridBatch(ModelType.MODEL_SPIKING, batch_default_map, batch_override_map)
batch.post_init()

# Simulations run only when no batch data file exists yet.
if not os.path.exists(batch.batch_data_path):
    # Parenthesised single-argument print works on Python 2 and Python 3.
    print('Running simulations with constant speed (for 40 different initial weights)')
    batch.run()
    batch.post_run()

# 'evo_weight_scores_map' is stored as a 0-d object array wrapping a dict-like
# object (hence the [()] unwrapping), which numpy >= 1.16.3 only loads with
# allow_pickle=True. The file is produced by this project's own batch run; do
# not point this at untrusted data.
batch_data = np.load(batch.batch_data_path, allow_pickle=True)
# list(...) is required on Python 3, where .values() returns a view that
# np.array would wrap as a single object instead of stacking the rows.
all_evo_weight_scores = list(batch_data['evo_weight_scores_map'][()].values())
all_evo_weight_scores_mat = np.array(all_evo_weight_scores)

# Individual traces (dashed black), one per stored entry ...
pl.plot(snap_times, all_evo_weight_scores_mat.T, '--k')

# ... and their median across entries (solid black, thicker).
median_scores = np.median(all_evo_weight_scores_mat, axis=0)

pl.plot(snap_times, median_scores, '-k', lw=1.5)
#pl.plot(snap_times,all_evo_weight_scores_mat.T)

#%%
# Same preset as the default map, with 'variable_speed' switched on.
batch_default_map_var_speed = map_merge(
    GridSpikeParams.gau_grid_small_arena_biphasic_neg,
    {
        'a': 1.1,
        'variable_speed': True,
    })

batch_override_map = {'seed': np.arange(40)}

# variable speed
batch_var = GridBatch(ModelType.MODEL_SPIKING, batch_default_map_var_speed, batch_override_map)
batch_var.post_init()

# Simulations run only when no batch data file exists yet.
if not os.path.exists(batch_var.batch_data_path):
    # Parenthesised single-argument print works on Python 2 and Python 3.
    print('Running simulations with variable speed (for 40 different initial weights)')
    batch_var.run()
    batch_var.post_run()

# 'evo_weight_scores_map' is stored as a 0-d object array wrapping a dict-like
# object (hence the [()] unwrapping), which numpy >= 1.16.3 only loads with
# allow_pickle=True. The file is produced by this project's own batch run; do
# not point this at untrusted data.
batch_data = np.load(batch_var.batch_data_path, allow_pickle=True)
# list(...) is required on Python 3, where .values() returns a view that
# np.array would wrap as a single object instead of stacking the rows.
all_evo_weight_scores = list(batch_data['evo_weight_scores_map'][()].values())
all_evo_weight_scores_mat = np.array(all_evo_weight_scores)

# Both plots go to the current pylab axes: individual traces in green, then
# their median across entries as a thicker line in plotlib's green.
pl.plot(snap_times, all_evo_weight_scores_mat.T, 'g')
median_scores = np.median(all_evo_weight_scores_mat, axis=0)

pl.plot(snap_times, median_scores, '-', lw=1.5, color=pp.green)
#pl.plot(snap_times,all_evo_weight_scores_mat.T)

# Cosmetics for the right panel: log time axis starting at 1e5, fixed score
# range, and tick labels expressed in units of 1e5.
pl.sca(ax2)
pp.custom_axes()
ax2.set_xscale('log')
ax2.set_xlim(1e5, snap_times.max())
ax2.set_ylim(0, 1.5)
ax2.set_xlabel('Time [10^5 s]')
ax2.set_ylabel('Gridness score')
ax2.set_yticks([0, 0.5, 1.0, 1.5])
ax2.set_xticks([1e5, 2e5, 5e5, 1e6])
ax2.set_xticklabels(['1', '2', '5', '10'])

# Write both panels to a single EPS file in the figures directory.
output_file = os.path.join(figures_path, 'fig6.eps')
fig.savefig(output_file, dpi=300, transparent=True)

##%%
#from grid_walk import GridWalk
#walk=GridWalk(map_merge(batch_default_map_var_speed,
#                            {'arena_shape':'square',  
#                             'virtual_bound_ratio':1.0,
#                             'bounce_theta_sigma':0.0,
#                             'position_dt':1/200.  }))
#
#
#fig=pl.figure(figsize=(1.5,1.))
#pl.subplots_adjust(left=0.2,bottom=0.2)
#n,bins,patches=pl.hist(walk.speed_vect,100,color=pp.green,edgecolor=pp.green,normed=1)
#pl.yticks([0,25])
#pl.xticks([.15,.25,.35])
#pp.custom_axes()
#    
#fig.savefig(os.path.join(figures_path,'fig6b_inset.eps'),dpi=300,transparent=True)