# -*- coding: utf-8 -*-
"""
Created on Tue Feb  9 11:11:39 2016

@author: dalbis
"""

import time
import numpy as np
from numpy.linalg import norm
import os
from simlib import print_progress,gen_hash_id,gen_string_id
import simlib as sl

def get_pos_idx(p,pos):
  """
  Return the index of the row of `pos` that is closest to the point `p`.

  Closeness is measured by the squared Euclidean distance, summed along
  axis 1 of `pos`.
  """
  offsets=np.array(p)-pos
  squared_dist=(offsets**2).sum(axis=1)
  return np.argmin(squared_dist)
  
def inside_arena(p,arena_shape,L):
  """
  Tell whether the point p lies inside an arena of size L centered at the origin.

  'square': both coordinates must be in the half-open interval [-L/2, L/2).
  'circle': the norm of p must be strictly smaller than L/2.
  Any other arena_shape yields None.
  """
  half=L/2.

  if arena_shape == 'square':
    x_inside = p[0]>=-half and p[0]<half
    y_inside = p[1]>=-half and p[1]<half
    return bool(x_inside and y_inside)

  if arena_shape == 'circle':
    return bool(norm(p)<half)

  return None
      
def bounce_fun(p,v,arena_shape,L,dt,theta_sigma=0):
  """
  Bounce against the walls of the arena.

  Parameters:
    p            position that triggered the bounce (2-element array)
    v            velocity that led to p (2-element array)
    arena_shape  'square' or 'circle'; any other value leaves v unchanged
    L            side length (square) or diameter (circle) of the arena
    dt           time step used to advance p along the reflected velocity
    theta_sigma  if >0, standard deviation of the Gaussian noise added to
                 the returned heading

  Returns:
    (p_new, theta) where p_new = p + v_new*dt and theta is the heading of
    the reflected velocity v_new (plus noise when theta_sigma>0).

  Usage:
    p,theta=bounce_fun(p,v,arena_shape,L,dt)
  """
  if arena_shape == 'square':
    # only one component is flipped per call; the x walls take precedence
    if p[0]<-L/2. or p[0]>=L/2.:
      v_new= np.array([-v[0],v[1]])
    elif p[1]<-L/2. or p[1]>=L/2.:
      v_new= np.array([v[0],-v[1]])
    else:
      v_new= v
  elif arena_shape == 'circle':
    # mirror the velocity about the radial direction through p
    n = p/norm(p)
    v_new=v-2*n*np.dot(n,v)
  else:
    v_new = v

  theta=np.arctan2(v_new[1],v_new[0])
  if theta_sigma>0:
    # np.randn does not exist: the sampler lives in np.random
    theta=theta_sigma*np.random.randn()+theta

  # the noise affects only the returned heading, not this displacement
  p=p+v_new*dt
  return p,theta

def periodic_pos(p,L) :
  """
  Wrap the two coordinates of p back towards [-L/2, L/2) by shifting each
  out-of-range coordinate by one period L.

  p is modified in place and also returned.
  """
  half=L/2.
  for axis in range(2):
    coord=p[axis]
    if coord<-half:
      p[axis]+=L
    elif coord>=half:
      p[axis]-=L
  return p
    
    
# names of the parameters that always enter the walk id
key_params_const_speed=[
  'arena_shape',
  'L',
  'speed',
  'theta_sigma',
  'position_dt',
  'nx',
  'walk_seed',
  'walk_time',
  'periodic_walk',
  'bounce',
  'bounce_theta_sigma',
  'virtual_bound_ratio',
  'init_p',
  'init_theta',
]

# additional names used when the 'variable_speed' parameter is True
var_speed_params=[
  'variable_speed',
  'speed_theta',
  'speed_sigma',
]


    
class RandomWalk(object):
  """
  Walk of a point in a 2D arena, generated once and cached on disk.

  The walk is identified by a hash of its key parameters (see
  get_key_params). On construction the data file for that hash is
  generated if missing (or if force=True) and then loaded; every array
  stored in the file becomes an attribute of the instance
  (pos, pidx_vect, walk_steps, speed_vect).
  """

  results_path=os.path.join(sl.get_results_path(),'random_walk')


  @staticmethod
  def get_id(paramMap):
    """
    Returns the hash id built from the key parameters found in paramMap
    """
    string_id=gen_string_id(paramMap,key_params=RandomWalk.get_key_params(paramMap))
    return gen_hash_id(string_id)


  @staticmethod
  def get_data_path(paramMap):
    """
    Returns the path of the data file of the walk described by paramMap
    """
    return os.path.join(RandomWalk.results_path,RandomWalk.get_id(paramMap)+'_data.npz')

  @staticmethod
  def get_key_params(paramMap):
    """
    Returns a new list with the names of the parameters that identify the walk.

    The optional names are included only when the corresponding flag
    ('variable_speed', 'run_in_circle', 'sweep') is present in paramMap
    and is exactly True.
    """
    # work on a copy: `+=` on the module-level list would extend it in place,
    # so that the key list (and the walk id) would depend on previous calls
    key_params=list(key_params_const_speed)

    if 'variable_speed' in paramMap and paramMap['variable_speed'] is True:
      key_params+=var_speed_params

    if 'run_in_circle' in paramMap and paramMap['run_in_circle'] is True:
      key_params+=['run_in_circle',]

    if 'sweep' in paramMap and paramMap['sweep'] is True:
      key_params+=['sweep',]

    return key_params


  def __init__(self,keyParamMap,do_print=True,force=False):
    """
    Imports the key parameters, then generates (if needed) and loads the walk.

    keyParamMap: mapping with an entry for each name returned by
                 get_key_params(keyParamMap); each one is copied to an
                 attribute with the same name
    do_print:    print messages while generating/loading
    force:       regenerate the data even if the data file already exists
    """

    # import parameters
    for param in RandomWalk.get_key_params(keyParamMap):
      setattr(self,param,keyParamMap[param])

    # trajectory used when run_in_circle is set
    self.circle_radius=0.397
    self.circle_omega=self.speed/self.circle_radius

    # optional flags are attributes only if they were key parameters
    if not hasattr(self,'variable_speed'):
      self.variable_speed=False

    if not hasattr(self,'run_in_circle'):
      self.run_in_circle=False

    if not hasattr(self,'sweep'):
      self.sweep=False

    self.id=RandomWalk.get_id(keyParamMap)
    self.paramsPath=os.path.join(self.results_path,self.id+'_log.txt')
    self.dataPath=os.path.join(self.results_path,self.id+'_data.npz')

    if force or not os.path.exists(self.dataPath):

      # generate and save data
      self.gen_data(do_print)

    # load data
    self.load_data(do_print)


  def load_data(self,do_print):
    """
    Loads data from disk

    Every array stored in the data file is set as an attribute of the
    instance, named after its key in the file.
    """

    # single-argument parenthesized prints are valid both as Python 2
    # statements and as Python 3 function calls
    if do_print:
      print('')
      print('Loading walk data, Id = %s'%self.id)

    # the context manager closes the archive once the arrays have been read
    with np.load(self.dataPath,allow_pickle=True) as data:
      loaded_keys=list(data.keys())
      for k in loaded_keys:
        setattr(self,k,data[k])

    if do_print:
      print('Loaded variables: '+' '.join(loaded_keys))


  def gen_data(self,do_print):
    """
    Generates walk data and saves it to disk
    """

    if do_print:
      print('')
      print('Generating walk data, Id = %s'%self.id)

    self.post_init()
    self.run()

    # np.savez does not create missing directories
    if not os.path.isdir(self.results_path):
      try:
        os.makedirs(self.results_path)
      except OSError:
        # tolerate the folder having been created in the meantime
        if not os.path.isdir(self.results_path):
          raise

    toSaveMap={'pos':self.pos,'pidx_vect':self.pidx_vect,'walk_steps':self.walk_steps,'speed_vect':self.speed_vect}
    np.savez(self.dataPath,**toSaveMap)

    if do_print:
      print('Result saved in: %s\n'%self.dataPath)


  def post_init(self):
    """
    Computes the quantities needed by run(): grid spacing, number of steps,
    grid of positions, and seeds the numpy random generator with walk_seed
    """

    self.dx=float(self.L)/self.nx

    self.walk_steps = int(self.walk_time/self.position_dt)

    self.current_speed=self.speed

    np.random.seed(self.walk_seed)

    # grid of positions (pos) and of integer grid coordinates (ipos),
    # one row per grid node
    X,Y=np.mgrid[-self.L/2:self.L/2:self.dx,-self.L/2:self.L/2:self.dx]
    iX,iY=np.mgrid[-self.nx/2:self.nx/2:1,-self.nx/2:self.nx/2:1]

    self.pos=np.array([np.ravel(X), np.ravel(Y)]).T
    self.ipos=np.array([np.ravel(iX), np.ravel(iY)]).T

    self.startClock=time.time()

  def update_speed(self):
    """
    Updates current_speed with one noisy step that relaxes towards `speed`
    (rate speed_theta, noise amplitude speed_sigma), clipped at zero
    """

    self.current_speed = self.current_speed+self.speed_theta*(self.speed-self.current_speed)*self.position_dt+self.speed_sigma*np.sqrt(self.position_dt)*np.random.randn()
    self.current_speed=self.current_speed.clip(min=0)

  def update_position(self):
    """
    Advances the walk by one step.

    The heading receives Gaussian noise and the position moves along it.
    If the new position is outside the (virtual) arena it is either wrapped
    (periodic_walk), bounced (bounce) or, when neither flag is True, a new
    heading is drawn starting from the last one until the step lands inside.
    """

    # choose next running direction and position
    p0=self.p
    while True:

      # update theta
      self.theta = self.theta+self.theta_sigma*np.sqrt(self.position_dt)*np.random.randn()

      self.v = np.array([np.cos(self.theta),np.sin(self.theta)])*self.current_speed
      self.p=p0+self.v*self.position_dt

      # boundary conditions
      if inside_arena(self.p,self.arena_shape,self.L*self.virtual_bound_ratio):
        break

      if self.periodic_walk is True:
        self.p=periodic_pos(self.p,self.L)
        break

      if self.bounce is True:
        self.p,self.theta=bounce_fun(self.p,self.v,self.arena_shape,self.L*self.virtual_bound_ratio,self.position_dt,theta_sigma=self.bounce_theta_sigma)
        break


  def update_position_circle(self):
    """
    Advances the position by one step along a circle of radius circle_radius
    """

    self.circle_angle+=self.circle_omega*self.position_dt
    self.p=self.circle_radius*np.array([np.cos(self.circle_angle),np.sin(self.circle_angle)])


  def update_position_sweep(self):
    """
    Advances the position by one step of the sweep.

    The point moves along theta; when the next position would be outside
    the (virtual) arena it takes one step downwards instead and the heading
    is swapped between 0 and -pi. If the downward step is outside too,
    sweep_end is set and the position is left unchanged.
    """

    self.v = np.array([np.cos(self.theta),np.sin(self.theta)])*self.current_speed
    self.p_next=self.p+self.v*self.position_dt

    # got to a boundary
    if not inside_arena(self.p_next,self.arena_shape,self.L*self.virtual_bound_ratio):
      down_theta=-np.pi/2.

      self.v = np.array([np.cos(down_theta),np.sin(down_theta)])*self.current_speed
      self.p_next=self.p+self.v*self.position_dt

      # changing direction
      if self.theta==0:
        self.theta=-np.pi
      elif self.theta==-np.pi:
        self.theta=0.

      if not inside_arena(self.p_next,self.arena_shape,self.L*self.virtual_bound_ratio):
        print('sweep end')
        self.sweep_end=True
        return

    self.p=self.p_next


  def run(self):
    """
    Runs the walk for walk_steps steps.

    Fills pidx_vect with the index (row of pos) of the grid node closest to
    the position at each step, and speed_vect with up to 100 regularly
    spaced samples of current_speed. If the sweep ends early, the
    remaining entries of pidx_vect are left at zero.
    """

    self.p = self.init_p

    if self.sweep:
      self.p=np.array([-self.L/2.,self.L/2.])
      self.sweep_end=False

    elif self.run_in_circle:
      self.circle_angle=0.
      self.p=np.array([self.circle_radius,0.])

    self.theta=self.init_theta

    self.pidx_vect=np.zeros((self.walk_steps),dtype=np.int32)

    progress_clock=time.time()

    snap_idx=0
    num_snaps=100
    # integer stride between snapshots; at least 1 so that walks shorter
    # than num_snaps steps do not take a modulo by zero
    delta_snap=max(1,self.walk_steps//num_snaps)
    self.speed_vect=np.zeros(num_snaps)

    # xrange exists only on Python 2; fall back to range elsewhere
    try:
      step_range=xrange(self.walk_steps)
    except NameError:
      step_range=range(self.walk_steps)

    for step_idx in step_range:

      if self.variable_speed is True:
        self.update_speed()

      if self.sweep is True:
        if self.sweep_end:
          break
        else:
          self.update_position_sweep()

      elif self.run_in_circle is True:
        self.update_position_circle()
      else:
        self.update_position()

      if step_idx%delta_snap==0:
        # there can be more than num_snaps strides: the extra ones are not stored
        if snap_idx<num_snaps:
          self.speed_vect[snap_idx]=self.current_speed
        print_progress(snap_idx,num_snaps,progress_clock)
        snap_idx+=1

      pidx=get_pos_idx(self.p,self.pos)
      self.pidx_vect[step_idx]=pidx

  def plot_occupancy(self):
    """
    Plots the 2D histogram of the visited grid positions, the histogram of
    the visited position indexes and the position index at each step
    """
    import pylab as pl
    from plotlib import custom_axes

    p_hist,x,y = np.histogram2d(self.pos[self.pidx_vect,0],self.pos[self.pidx_vect,1],
                                range=[[-self.L/2, self.L/2], [-self.L/2, self.L/2]],bins=50)

    pl.figure()
    # plot the results
    pl.subplot(111,aspect='equal')
    custom_axes()
    pl.xlim(-self.L/2,self.L/2)
    pl.ylim(-self.L/2,self.L/2)

    pl.pcolormesh(x,y,p_hist)
    pl.colorbar()
    pl.xlabel('X bin')
    pl.ylabel('Y bin')
    pl.title('Visits')

    pl.figure()
    pl.hist(self.pidx_vect,color='k',bins=100)

    pl.figure()
    pl.plot(self.pidx_vect,'-k')


  def plot(self,num_steps=1000):
    """
    Plots the visited grid positions (only the first num_steps ones, or all
    of them if num_steps is None) and the sampled speed with its histogram
    """
    import pylab as pl
    from plotlib import custom_axes

    pl.figure()
    pl.subplot(111,aspect='equal')
    custom_axes()
    pl.xlim(-self.L/2,self.L/2)
    pl.ylim(-self.L/2,self.L/2)

    if num_steps is not None:
      pl.plot(self.pos[self.pidx_vect[0:num_steps],0],self.pos[self.pidx_vect[0:num_steps],1],'.k',ms=2)
    else:
      pl.plot(self.pos[self.pidx_vect,0],self.pos[self.pidx_vect,1],'.k',ms=2)

    pl.figure(figsize=(10,4))
    pl.subplot(121)
    # not called `time`, to avoid shadowing the time module
    time_axis=np.linspace(0,self.walk_time,len(self.speed_vect))
    pl.plot(time_axis,self.speed_vect,'-k')
    pl.xlabel('Time [s]')
    pl.ylabel('Speed [m/s]')
    custom_axes()
    pl.xlim(0,5)

    pl.subplot(122)
    pl.hist(self.speed_vect,color='k',bins=100)
    pl.xlabel('Speed [m/s]')
    pl.ylabel('Count')
    custom_axes()
    pl.title('Mean = %.2f  Var = %.3e'%(self.speed_vect.mean(),self.speed_vect.var()))

  def remove_data(self):
    """
    Deletes the data file of this walk from disk
    """
    os.remove(self.dataPath)
    
if __name__ == '__main__':

  ## TESTING
  # Generates a walk in a square arena (overwriting any cached data)
  # and plots its occupancy.

  L=1.5
  params_map = { 'arena_shape':'square',
                'L':L,
                'speed':0.25,
                'theta_sigma':0.7,
                'position_dt':L/50.,
                # RandomWalk.__init__ reads every name in key_params_const_speed,
                # so 'nx', 'init_p' and 'init_theta' must be given as well
                'nx':50,
                'walk_seed':0,
                'walk_time':300,
                'periodic_walk':False,
                'bounce':True,
                'bounce_theta_sigma':0.,
                'virtual_bound_ratio':1.0,
                # NOTE(review): start at the arena center heading along +x;
                # confirm these are the intended test values
                'init_p':np.array([0.,0.]),
                'init_theta':0.0,
                'variable_speed':False,
                'run_in_circle':False
              }

  tw=RandomWalk(params_map,force=True)
  tw.plot_occupancy()