# -*- coding: utf-8 -*-
"""
Created on Tue Feb 9 11:11:39 2016
@author: dalbis
"""
from time import clock
import numpy as np
from numpy.linalg import norm
import os
from simlib import print_progress,gen_hash_id,gen_string_id
from simlib import ensureParentDir
def get_pos_idx(p, pos):
    """
    Return the row index of the point in `pos` nearest to `p`.

    p   : a single point (x, y); anything accepted by np.array
    pos : 2D array of candidate points, one point per row

    Ties are resolved in favour of the lowest index (np.argmin semantics).
    """
    offsets = np.array(p) - pos
    squared_dist = np.sum(offsets * offsets, axis=1)
    return np.argmin(squared_dist)
def inside_arena(p, arena_shape, L):
    """
    Check whether position `p` lies inside the arena.

    'square' : half-open box [-L/2, L/2) on both axes
    'circle' : open disc of diameter L centred at the origin

    Returns True/False for the two known shapes and None for any other
    `arena_shape`.
    """
    half = L / 2.
    if arena_shape == 'square':
        return bool(-half <= p[0] < half and -half <= p[1] < half)
    if arena_shape == 'circle':
        return bool(norm(p) < half)
    return None
def bounce_fun(p, v, arena_shape, L, dt, theta_sigma=0):
    """
    Reflect the velocity off the arena walls and advance one step with it.

    Usage:
        p_new, theta = bounce_fun(p, v, arena_shape, L, dt)

    Parameters
    ----------
    p : position (x, y) reached by the last step, typically outside the arena
    v : velocity (vx, vy) used for that step
    arena_shape : 'square' or 'circle'; any other value leaves v unchanged
    L : side length of the square / diameter of the circle
    dt : time step used to advance `p` with the reflected velocity
    theta_sigma : if > 0, standard deviation of Gaussian noise added to the
        outgoing heading angle

    Returns
    -------
    p_new : p + v_new * dt, with v_new the reflected velocity
    theta : heading angle of v_new (plus noise when theta_sigma > 0)
    """
    p = np.asarray(p, dtype=float)
    v = np.asarray(v, dtype=float)
    half = L / 2.

    if arena_shape == 'square':
        v_new = v.copy()
        # Reflect each component independently: when a corner is crossed
        # both coordinates are out of bounds and both must flip, otherwise
        # the walker is pushed even further outside along the other axis.
        if p[0] < -half or p[0] >= half:
            v_new[0] = -v_new[0]
        if p[1] < -half or p[1] >= half:
            v_new[1] = -v_new[1]
    elif arena_shape == 'circle':
        # mirror v about the tangent at p: subtract twice its normal component
        n = p / norm(p)
        v_new = v - 2 * n * np.dot(n, v)
    else:
        v_new = v

    theta = np.arctan2(v_new[1], v_new[0])
    if theta_sigma > 0:
        theta = theta_sigma * np.random.randn() + theta

    p = p + v_new * dt
    return p, theta
def periodic_pos(p, L):
    """
    Wrap position `p` into the periodic box [-L/2, L/2) on both axes.

    A coordinate outside the box is shifted by exactly one period L, so it
    is assumed to have left the box by less than L.

    Note: `p` is modified in place and is also returned.
    """
    lo, hi = -L / 2., L / 2.
    for axis in range(2):
        coord = p[axis]
        if coord < lo:
            p[axis] = coord + L
        elif coord >= hi:
            p[axis] = coord - L
    return p
class GridWalk(object):
    """
    Random walk in a square or circular arena, sampled on a regular grid.

    Each walk is identified by a hash of its key parameters (`key_params`).
    Its data is stored under `results_path` and is regenerated only when the
    data file is missing or `force=True`.

    After construction the object carries every key parameter plus the
    stored arrays as attributes: `pos`, `pidx_vect`, `dx`, `nx`,
    `walk_steps`, `speed_vect`.
    """

    # key parameters: they define the walk and therefore its data-file id
    key_params = ['arena_shape', 'L', 'speed', 'theta_sigma', 'position_dt',
                  'walk_seed', 'walk_time', 'periodic_walk', 'bounce',
                  'bounce_theta_sigma', 'virtual_bound_ratio',
                  'variable_speed', 'speed_theta', 'speed_sigma']

    results_path = '../results/grid_walk'

    @staticmethod
    def get_id(paramMap):
        """
        Return the hash id of the walk described by `paramMap`.

        The string id built from the key parameters is printed before
        being hashed.
        """
        string_id = gen_string_id(paramMap, key_params=GridWalk.key_params)
        print(string_id)
        return gen_hash_id(string_id)

    @staticmethod
    def get_data_path(paramMap):
        """Return the .npz data-file path of the walk described by `paramMap`."""
        return os.path.join(GridWalk.results_path,
                            GridWalk.get_id(paramMap) + '_data.npz')

    def __init__(self, keyParamMap, do_print=True, force=False,
                 init_p=np.array([0., 0.]), init_theta=0.0):
        """
        keyParamMap : dict containing (at least) every name in `key_params`
        do_print    : print status messages
        force       : regenerate the data even if the data file exists
        init_p      : starting position (x, y)
        init_theta  : starting heading angle [rad]
        """
        # import parameters
        for param in GridWalk.key_params:
            setattr(self, param, keyParamMap[param])

        self.init_theta = init_theta
        # copy, so the walk never aliases the caller's array or the shared
        # default argument
        self.init_p = np.array(init_p, dtype=float)

        self.id = GridWalk.get_id(keyParamMap)
        self.paramsPath = os.path.join(self.results_path, self.id + '_log.txt')
        self.dataPath = os.path.join(self.results_path, self.id + '_data.npz')

        if force or not os.path.exists(self.dataPath):
            # generate and save data
            self.gen_data(do_print)

        # load data
        self.load_data(do_print)

    def load_data(self, do_print):
        """
        Load the walk data from disk and set every stored array as an
        attribute of this object.
        """
        if do_print:
            print('')
            print('Loading walk data, Id = %s' % self.id)

        # the context manager closes the file handle held by the NpzFile
        with np.load(self.dataPath) as data:
            names = list(data.files)
            for name in names:
                setattr(self, name, data[name])

        if do_print:
            print('Loaded variables: ' + ' '.join(names))

    def gen_data(self, do_print):
        """Generate the walk and save it to `self.dataPath`."""
        if do_print:
            print('')
            print('Generating walk data, Id = %s' % self.id)

        self.post_init()
        self.run()

        toSaveMap = {'pos': self.pos,
                     'pidx_vect': self.pidx_vect,
                     'dx': self.dx,
                     'nx': self.nx,
                     'walk_steps': self.walk_steps,
                     'speed_vect': self.speed_vect}

        ensureParentDir(self.dataPath)
        np.savez(self.dataPath, **toSaveMap)

        if do_print:
            print('Result saved in: %s\n' % self.dataPath)

    def post_init(self):
        """
        Derive grid spacing and step count from the key parameters, seed the
        random generator and build the grid of positions (`pos`) that the
        walker's position is snapped to in `run`.
        """
        self.dx = self.position_dt
        self.nx = int(self.L / self.position_dt)
        self.walk_steps = int(self.walk_time / self.position_dt)
        self.current_speed = self.speed

        np.random.seed(self.walk_seed)

        # Explicit true division for coordinates (an int L must not be
        # floor-divided under Python 2) and floor division for indices.
        half_L = self.L / 2.
        X, Y = np.mgrid[-half_L:half_L:self.dx, -half_L:half_L:self.dx]
        iX, iY = np.mgrid[-self.nx // 2:self.nx // 2:1,
                          -self.nx // 2:self.nx // 2:1]
        # NOTE(review): with a float step, mgrid may produce a different
        # number of points per axis than nx, so pos and ipos need not have
        # the same length -- confirm this is acceptable.
        self.pos = np.array([np.ravel(X), np.ravel(Y)]).T
        self.ipos = np.array([np.ravel(iX), np.ravel(iY)]).T

        self.startClock = clock()

    def update_speed(self):
        """
        One Euler-Maruyama step of an Ornstein-Uhlenbeck process for the
        running speed (mean `speed`, rate `speed_theta`, noise amplitude
        `speed_sigma`), clipped at zero.
        """
        dt = self.position_dt
        drift = self.speed_theta * (self.speed - self.current_speed) * dt
        noise = self.speed_sigma * np.sqrt(dt) * np.random.randn()
        # the speed cannot become negative
        self.current_speed = max(self.current_speed + drift + noise, 0.)

    def update_position(self):
        """
        Advance the walker by one time step.

        The heading diffuses with standard deviation theta_sigma*sqrt(dt).
        If the new position falls outside the arena (scaled by
        virtual_bound_ratio), it is wrapped (periodic_walk), reflected
        (bounce), or otherwise the heading is redrawn from the previous
        position until the step lands inside.

        Raises RuntimeError if the walker leaves the arena while
        theta_sigma == 0 and neither periodic_walk nor bounce is set, since
        redrawing a noiseless heading could never bring it back inside.
        """
        p0 = self.p
        bound = self.L * self.virtual_bound_ratio
        while True:
            # update theta
            self.theta = (self.theta + self.theta_sigma *
                          np.sqrt(self.position_dt) * np.random.randn())
            self.v = (np.array([np.cos(self.theta), np.sin(self.theta)]) *
                      self.current_speed)
            self.p = p0 + self.v * self.position_dt

            # boundary conditions
            if inside_arena(self.p, self.arena_shape, bound):
                break
            if self.periodic_walk:
                self.p = periodic_pos(self.p, self.L)
                break
            if self.bounce:
                self.p, self.theta = bounce_fun(
                    self.p, self.v, self.arena_shape, bound,
                    self.position_dt, theta_sigma=self.bounce_theta_sigma)
                break
            if self.theta_sigma == 0:
                raise RuntimeError('Walker left the arena with theta_sigma=0 '
                                   'and neither periodic_walk nor bounce set')

    def run(self):
        """
        Simulate `walk_steps` steps starting from `init_p` / `init_theta`.

        Fills `pidx_vect` with the index (into `pos`) of the grid point
        nearest to the walker at every step, and `speed_vect` with the
        running speed sampled every max(1, walk_steps // 2000) steps
        (at most 2000 samples).
        """
        self.p = np.array(self.init_p, dtype=float)
        self.theta = self.init_theta
        self.pidx_vect = np.zeros(self.walk_steps, dtype=np.int32)

        progress_clock = clock()
        num_snaps = 2000
        # integer stride, at least 1 to avoid a modulo by zero when
        # walk_steps < num_snaps
        delta_snap = max(1, self.walk_steps // num_snaps)
        self.speed_vect = np.zeros(num_snaps)
        snap_idx = 0

        for step_idx in xrange(self.walk_steps):
            if self.variable_speed:
                self.update_speed()
            self.update_position()

            if step_idx % delta_snap == 0 and snap_idx < num_snaps:
                self.speed_vect[snap_idx] = self.current_speed
                print_progress(snap_idx, num_snaps, progress_clock)
                snap_idx += 1

            self.pidx_vect[step_idx] = get_pos_idx(self.p, self.pos)

    def plot_occupancy(self):
        """
        Plot the visited positions as a 2D histogram over the arena, a
        histogram of the visited grid indices, and the grid index over time.
        """
        import pylab as pl
        from plotlib import custom_axes

        half_L = self.L / 2.
        visited = self.pos[self.pidx_vect]
        p_hist, x, y = np.histogram2d(visited[:, 0], visited[:, 1],
                                      range=[[-half_L, half_L],
                                             [-half_L, half_L]],
                                      bins=50)

        pl.figure()
        # plot the results
        pl.subplot(111, aspect='equal')
        custom_axes()
        pl.xlim(-half_L, half_L)
        pl.ylim(-half_L, half_L)
        # histogram2d indexes x along the first axis while pcolormesh expects
        # rows to be y, hence the transpose
        pl.pcolormesh(x, y, p_hist.T)
        pl.colorbar()
        pl.xlabel('X bin')
        pl.ylabel('Y bin')
        pl.title('Visits')

        pl.figure()
        pl.hist(self.pidx_vect, color='k', bins=100)

        pl.figure()
        pl.plot(self.pidx_vect, '-k')

    def plot(self, num_steps=1000):
        """
        Plot the first `num_steps` visited positions (all of them if None),
        the sampled running speed over time and the histogram of the speed.
        """
        import pylab as pl
        from plotlib import custom_axes

        half_L = self.L / 2.
        pl.figure()
        pl.subplot(111, aspect='equal')
        custom_axes()
        pl.xlim(-half_L, half_L)
        pl.ylim(-half_L, half_L)
        if num_steps is None:
            pidx = self.pidx_vect
        else:
            pidx = self.pidx_vect[0:num_steps]
        pl.plot(self.pos[pidx, 0], self.pos[pidx, 1], '.k', ms=2)

        pl.figure(figsize=(10, 4))
        pl.subplot(121)
        t_axis = np.linspace(0, self.walk_time, len(self.speed_vect))
        pl.plot(t_axis, self.speed_vect, '-k')
        pl.xlabel('Time [s]')
        pl.ylabel('Speed [m/s]')
        custom_axes()
        pl.xlim(0, 5)

        pl.subplot(122)
        pl.hist(self.speed_vect, color='k', bins=100)
        pl.xlabel('Speed [m/s]')
        pl.ylabel('Count')
        custom_axes()
        pl.title('Mean = %.2f Var = %.3e' % (self.speed_vect.mean(),
                                              self.speed_vect.var()))

    def remove_data(self):
        """Delete this walk's data file from disk."""
        os.remove(self.dataPath)
if __name__ == '__main__':
    L = 1.5
    params_map = {'arena_shape': 'square',
                  'L': L,
                  'speed': 0.25,
                  'theta_sigma': 0.7,
                  'position_dt': L / 50.,
                  'walk_seed': 0,
                  'walk_time': 300,
                  'periodic_walk': False,
                  'bounce': True,
                  'bounce_theta_sigma': 0.,
                  'virtual_bound_ratio': 1.0,
                  'variable_speed': False,
                  # GridWalk requires every name in GridWalk.key_params;
                  # these only drive the dynamics when variable_speed is True
                  # (values chosen here as neutral placeholders -- adjust)
                  'speed_theta': 0.,
                  'speed_sigma': 0.,
                  }

    tw = GridWalk(params_map, force=True)
    tw.plot_occupancy()