# -*- coding: utf-8 -*-
import numpy as np
import os
import datetime,time
import pylab as pl
import grid_utils.plotlib as pp
import grid_utils.gridlib as gl
import grid_utils.simlib as sl
from grid_utils.random_walk import RandomWalk
from grid_utils.spatial_inputs import SpatialInputs
from grid_utils.spatial_inputs import rhombus_mises
def dir_vect(theta):
"""
Returns a 2-d vector given an angle theta
"""
return np.array([np.cos(theta),np.sin(theta)])
def find_bump_peak_idxs(map1d,unravel=True,**kwargs):
"""
Utility function that returns the 2D coordinates of the peak of an activity bump
(the flattened map is reshaped to a square, smoothed with a periodic Gaussian filter,
and the location of the maximum is returned)
"""
from scipy.ndimage import gaussian_filter
map_side=int(np.sqrt(len(map1d)))
map2d=map1d.reshape(map_side,map_side)
map2d_smoothed=gaussian_filter(map2d, sigma=3,mode='wrap')
max_idx=map2d_smoothed.argmax()
if unravel:
return np.unravel_index(max_idx, map2d.shape)
else:
return max_idx
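# A minimal usage sketch for find_bump_peak_idxs (illustrative only, not part of the model;
# the map size and bump location below are arbitrary choices)
def _example_find_bump_peak():
    n=32
    xx,yy=np.meshgrid(np.arange(n),np.arange(n),indexing='ij')
    # synthetic Gaussian bump centered at (10,20), flattened as expected by the utility
    bump2d=np.exp(-((xx-10)**2+(yy-20)**2)/(2.*3.**2))
    return find_bump_peak_idxs(bump2d.ravel()) # -> (10, 20)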
def get_bumps_idxs_all_shifts(r,N_e):
bump_idxs2d=np.zeros((2,N_e))
# compute the coordinates of the output bump for each stimulus-bump shift
for shift_idx in range(N_e):
bx,by=find_bump_peak_idxs(r[:N_e,shift_idx],unravel=True)
bump_idxs2d[0,shift_idx]=bx
bump_idxs2d[1,shift_idx]=by
return bump_idxs2d
def get_activation_fun(activation,r_max):
"""
Neuronal activation function
"""
clip = lambda x: x.clip(min=0,max=r_max)
tanh = lambda x: np.clip(np.tanh(x/float(r_max))*r_max,1e-16,r_max)
if activation == 'linear':
return lambda x : x
if activation == 'clip':
return clip
elif activation == 'tanh':
return tanh
else:
raise Exception('Invalid activation function')
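# A minimal sketch of how the two saturating activation functions differ (illustrative only):
# 'clip' rectifies linearly and saturates hard at r_max, while 'tanh' saturates smoothly
def _example_activation_fun():
    f_clip=get_activation_fun('clip',r_max=100)
    f_tanh=get_activation_fun('tanh',r_max=100)
    x=np.array([-10.,50.,1000.])
    return f_clip(x),f_tanh(x) # clip -> [0,50,100]; tanh -> [~0,~46.2,~100]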
class __RecAmp2Pop(object):
"""
This class implements the basic methods for simulating the amplification model in 2D with two populations of
neurons (excitatory and inhibitory). This class is private and abstract and is the common building block
upon which the classes RecAmp2PopLearn and RecAmp2PopSteady are built.
The class RecAmp2PopLearn deals with the learning of the recurrent connectivity.
The class RecAmp2PopSteady deals with the simulation of the output steady-state patterns.
"""
def __init__(self):
pass
def get_inputs(self,force_gen=False,comp_gridness_score=False,comp_tuning_index=True):
"""
Reads inputs from disk and loads them into the object
"""
#print self.paramMap
self.inputs=SpatialInputs(sl.map_merge(self.paramMap,{'n':self.n_e}),
force_gen=force_gen,
comp_gridness_score=comp_gridness_score,
comp_tuning_index=comp_tuning_index)
self.inputs_flat=self.inputs.inputs_flat
self.h_e=self.inputs_flat.T
self.h_i=np.zeros((self.N_i,self.NX))
self.h=np.vstack([self.h_e,self.h_i])
def get_hardwired_speed_weights(self):
"""
Generate hardwired speed-weight connectivity matrix
"""
phase_shift=self.speed_phase_shift
# row 1 has the weights of speed cells to grid cell 1
self.W_speed_east=np.zeros_like(self.W_ee)
self.W_speed_west=np.zeros_like(self.W_ee)
self.W_speed_north=np.zeros_like(self.W_ee)
self.W_speed_south=np.zeros_like(self.W_ee)
if self.use_eight_directions is True:
self.W_speed_north_east=np.zeros_like(self.W_ee)
self.W_speed_north_west=np.zeros_like(self.W_ee)
self.W_speed_south_east=np.zeros_like(self.W_ee)
self.W_speed_south_west=np.zeros_like(self.W_ee)
for phase_idx,phase in enumerate(self.gp.phases):
shifted_north_phase_idx=gl.get_pos_idx(phase+phase_shift*dir_vect(np.pi/2.),self.gp.phases)
shifted_south_phase_idx=gl.get_pos_idx(phase+phase_shift*dir_vect(-np.pi/2.),self.gp.phases)
shifted_east_phase_idx=gl.get_pos_idx(phase+phase_shift*dir_vect(0),self.gp.phases)
shifted_west_phase_idx=gl.get_pos_idx(phase+phase_shift*dir_vect(-np.pi),self.gp.phases)
self.W_speed_north[phase_idx,:]=self.W_ee[shifted_north_phase_idx,:]
self.W_speed_south[phase_idx,:]=self.W_ee[shifted_south_phase_idx,:]
self.W_speed_east[phase_idx,:]=self.W_ee[shifted_east_phase_idx,:]
self.W_speed_west[phase_idx,:]=self.W_ee[shifted_west_phase_idx,:]
if self.use_eight_directions is True:
shifted_north_east_phase_idx=gl.get_pos_idx(phase+phase_shift*dir_vect(np.pi/4),self.gp.phases)
shifted_north_west_phase_idx=gl.get_pos_idx(phase+phase_shift*dir_vect(np.pi*3/4),self.gp.phases)
shifted_south_east_phase_idx=gl.get_pos_idx(phase+phase_shift*dir_vect(-np.pi/4),self.gp.phases)
shifted_south_west_phase_idx=gl.get_pos_idx(phase+phase_shift*dir_vect(-np.pi*3/4),self.gp.phases)
self.W_speed_north_east[phase_idx,:]=self.W_ee[shifted_north_east_phase_idx,:]
self.W_speed_north_west[phase_idx,:]=self.W_ee[shifted_north_west_phase_idx,:]
self.W_speed_south_east[phase_idx,:]=self.W_ee[shifted_south_east_phase_idx,:]
self.W_speed_south_west[phase_idx,:]=self.W_ee[shifted_south_west_phase_idx,:]
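# In summary (a sketch of the construction above): for each movement direction d with unit
# vector u(d), the speed weights are phase-shifted copies of the recurrent weights,
#   W_speed_d[i,:] = W_ee[idx(phase_i + phase_shift*u(d)),:]
# so each neuron receives the recurrent input evaluated at a phase offset along d, which
# biases the attractor dynamics to translate the activity bump while that direction's
# speed cells are active.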
def switch_to_tuned_inputs(self):
"""
Switches on feed-forward spatial tuning (during the simulation)
"""
self.h_e=self.inputs_flat.T
self.h=np.vstack([self.h_e,self.h_i])
def switch_to_untuned_inputs(self):
"""
Switches off feed-forward spatial tuning (during the simulation)
"""
self.h_e=self.inputs.noise_flat.T
self.h=np.vstack([self.h_e,self.h_i])
def switch_to_no_feedforward_inputs(self):
"""
Switches off the feed-forward inputs entirely, replacing them with a constant value (during the simulation)
"""
self.h_e=np.ones_like(self.inputs.noise_flat.T)*self.feed_forward_off_value
self.h=np.vstack([self.h_e,self.h_i])
def get_tuned_excitatory_weights(self,binary=True):
"""
Computes a tuned excitatory connectivity matrix.
If binary is True, the weights of the num_conns_ee connections with the smallest phase difference are set to W_max_ee;
if fixed_connectivity_tuning is set to a value smaller than 1, only that fraction of the connections is tuned and the rest are chosen at random.
If binary is False, each outgoing weight profile is a smooth periodic bump over phase space normalized to W_tot_ee.
"""
self.W_ee=np.zeros((self.N_e,self.N_e))
if not hasattr(self,'fixed_connectivity_tuning'):
self.fixed_connectivity_tuning=1
num_tuned_conns=int(np.floor(self.fixed_connectivity_tuning*self.num_conns_ee))
num_untuned_conns=self.num_conns_ee-num_tuned_conns
for i in xrange(self.N_e):
ref_phase=self.gp.phases[i,:]
dists=gl.get_periodic_dist_on_rhombus(self.n_e,ref_phase,self.gp.phases,self.gp.u1,self.gp.u2)
if binary is True:
sorted_idxs=np.argsort(dists)
tuned_idxs=sorted_idxs[:self.num_conns_ee]
np.random.shuffle(tuned_idxs)
all_idxs=np.arange(self.N_e)
np.random.shuffle(all_idxs)
self.W_ee[i,tuned_idxs[0:num_tuned_conns]]=self.W_max_ee
self.W_ee[i,all_idxs[:num_untuned_conns]]=self.W_max_ee
else:
# initialize bump activity to zero
bump=np.zeros(self.N_e)
# To account for the periodicity of the phase space we need to add up 9 bumps, shifted
# by -1, 0, or +1 periods in each dimension (i.e., 3 horizontal x 3 vertical shifts)
for xp in (-1,0,1):
for yp in (-1,0,1):
shift=xp*self.gp.u1+yp*self.gp.u2
bump+=rhombus_mises(self.gp.phases-ref_phase[np.newaxis,:]+shift,0.1)
bump=bump/bump.sum()*self.W_tot_ee
self.W_ee[i,:]=bump
#self.W_ee[i,:]=expit(2*(1-dists/dists.max()))
#self.W_ee[i,:]/=self.W_ee[i,:].sum()
#self.W_ee[i,:]*=self.W_tot_ee
self.W[:self.N_e,:self.N_e]=self.W_ee
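# Note on the smooth (binary=False) branch above: each outgoing weight profile is a
# periodized von Mises bump; summing the 9 copies shifted by -1/0/+1 periods along the
# rhombus axes u1 and u2 makes the profile continuous across the borders of the phase
# rhombus, and the normalization fixes the total outgoing weight of each neuron to W_tot_ee.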
def get_random_inhibitory_weights(self):
"""
Compute random inhibitory connectivity matrix
"""
self.W_ei=np.zeros((self.N_e,self.N_i))
self.W_ie=np.zeros((self.N_i,self.N_e))
self.W_ii=np.zeros((self.N_i,self.N_i))
# connections to the excitatory neurons
for row_idx in range(self.N_e):
# from inhibitory
all_idxs_ei=np.arange(self.N_i)
np.random.shuffle(all_idxs_ei)
self.W_ei[row_idx,all_idxs_ei[0:self.num_conns_ei]]=self.W_max_ei
# connections to inhibitory neurons
for row_idx in range(self.N_i):
# from excitatory
all_idxs_ie=np.arange(self.N_e)
np.random.shuffle(all_idxs_ie)
self.W_ie[row_idx,all_idxs_ie[0:self.num_conns_ie]]=self.W_max_ie
# from inhibitory
all_idxs_ii=np.arange(self.N_i)
np.random.shuffle(all_idxs_ii)
self.W_ii[row_idx,all_idxs_ii[0:self.num_conns_ii]]=self.W_max_ii
self.W[:self.N_e,self.N_e:]=self.W_ei
self.W[self.N_e:,:self.N_e]=self.W_ie
self.W[self.N_e:,self.N_e:]=self.W_ii
def post_init(self,force_gen_inputs=False,comp_gridness_score=False,comp_tuning_index=True,get_inputs=True):
# set the seed
np.random.seed(self.seed)
# switches between local recurrent inhibition (n_i>0) and feed-forward inhibition (r0<0)
assert(self.n_i==0 or self.r0==0)
assert(self.n_i>0 or self.r0<0)
#print sl.params_to_str(self.paramMap)
self.N_e=self.n_e**2 # total number of excitatory neurons
self.N_i=self.n_i**2 # total number of inhibitory neurons
self.N=self.N_e+self.N_i # total number of neurons
self.NX=self.nx**2 # total number of space samples
self.num_conns_ee=int(np.floor(self.N_e*self.frac_conns_ee))
if self.n_i>0:
self.num_conns_ie=int(np.floor(self.N_e*self.frac_conns_ie))
self.num_conns_ei=int(np.floor(self.N_i*self.frac_conns_ei))
self.num_conns_ii=int(np.floor(self.N_i*self.frac_conns_ii))
# mean input/output weight for the excitatory neurons
self.W_av_star=float(self.W_tot_ee)/self.N_e
# maximal connection strength for excitatory/inhibitory neurons
self.W_max_ee=float(self.W_tot_ee)/self.num_conns_ee
if self.n_i>0:
self.W_max_ie=float(self.W_tot_ie)/self.num_conns_ie
self.W_max_ei=-float(self.W_tot_ei)/self.num_conns_ei
self.W_max_ii=-float(self.W_tot_ii)/self.num_conns_ii
# get grid properties
self.gp=gl.GridProps(self.n_e,self.grid_T,self.grid_angle)
# compute recurrent connectivity matrix (W_ee can be overwritten by learning or loaded from disk)
self.W=np.zeros((self.N,self.N))
self.get_tuned_excitatory_weights()
if self.N_i>0:
self.get_random_inhibitory_weights()
self.zero_phase_idx = gl.get_pos_idx([0.,0.],self.gp.phases)
# get inputs
if get_inputs:
self.get_inputs(force_gen_inputs,comp_gridness_score,comp_tuning_index)
def run_recurrent_dynamics(self,
record_mean_max=True,
record_rec_input_evo=False,
activation='tanh',
r_max=100,
init_r=None,
custom_h=None):
"""
The output activity is computed for each spatial pixel independently, without
explicitly modeling the random walk of the virtual rat
"""
print '\nRunning recurrent dynamics'
activation_fun=get_activation_fun(activation,r_max)
# initialization of the feed-forward input
if custom_h is None:
h=self.h
else:
assert(custom_h.shape[0]==self.N_e+self.N_i)
h=custom_h
# initialization of the output rates
if init_r is None:
r=np.zeros_like(h)
else:
assert(init_r.shape[0]==self.N_e+self.N_i)
r=init_r
# check that shapes match
assert(h.shape==r.shape)
# number of spatial locations
num_pos=h.shape[1]
num_steps=int(self.recdyn_time/self.dt)
if record_mean_max is True:
self.rec_input_mean_vect=np.zeros((self.N,num_steps))
self.rec_input_max_vect=np.zeros((self.N,num_steps))
self.r_mean_vect=np.zeros((self.N,num_steps))
self.r_max_vect=np.zeros((self.N,num_steps))
self.r_evo=np.zeros((self.N,num_pos,self.recdyn_num_snaps))
if record_rec_input_evo:
self.rec_input_evo=np.zeros((self.N,num_pos,self.recdyn_num_snaps))
delta_snap=num_steps/self.recdyn_num_snaps
snap_idx=0
rec_input=np.zeros_like(r)
start_clock=time.time()
for t in range(num_steps):
#print '%d/%d'%(t,num_steps)
if np.remainder(t,delta_snap)==0 and snap_idx<self.recdyn_num_snaps:
sl.print_progress(snap_idx,self.recdyn_num_snaps,start_clock=start_clock,step=1)
self.r_evo[:,:,snap_idx]=r
#print('r_min: ',r.min())
#print('r_max: ',r.max())
if record_rec_input_evo:
self.rec_input_evo[:,:,snap_idx]=rec_input
snap_idx+=1
if record_mean_max:
self.rec_input_mean_vect[:,t]=np.mean(rec_input,axis=1)
self.rec_input_max_vect[:,t]=np.max(rec_input,axis=1)
self.r_mean_vect[:,t]=np.mean(r,axis=1)
self.r_max_vect[:,t]=np.max(r,axis=1)
# recurrent input
rec_input=np.dot(self.W,r)
# total input, add feed-forward inhibition if recurrent inhibition is not explicitly modeled
tot_input=h+rec_input
if self.N_i==0:
tot_input+=self.r0
tot_activation = activation_fun(tot_input)
r=r+(self.dt/self.tau)*(-r+tot_activation)
self.r=r
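# The update in the loop above is the Euler discretization (step dt) of the rate dynamics
#   tau*dr/dt = -r + f(h + W*r [+ r0]),
# where f is the chosen activation function and the constant feed-forward inhibition r0
# is added only when no inhibitory population is modeled (N_i==0).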
def compute_steady_scores(self,comp_inhibitory_scores=True,force_input_scores=False):
# excitatory scores
R_e=self.r[0:self.N_e,:].T
R_e=np.reshape(R_e,(self.nx,self.nx,self.N_e))
self.re_scores,re_spacings=gl.gridness_evo(R_e[:,:,:],self.L/self.nx,num_steps=10)
if comp_inhibitory_scores is True:
# inhibitory scores
R_i=self.r[self.N_e:,:].T
R_i=np.reshape(R_i,(self.nx,self.nx,self.N_i))
self.ri_scores,ri_spacings=gl.gridness_evo(R_i[:,:,:],self.L/self.nx,num_steps=10)
# input scores
if not hasattr(self.inputs,'in_scores') or force_input_scores:
print 'Computing input scores'
self.inputs.gen_data(False,comp_gridness_score=True)
self.he_scores=self.inputs.in_scores
def save_steady_scores(self):
"""
Updates data files by adding gridness scores
"""
assert (hasattr(self,'re_scores') and hasattr(self,'ri_scores') and hasattr(self,'he_scores') )
data=np.load(self.data_path,allow_pickle=True)
dataMap=dict(data.items())
scores_attrs=['re_scores','ri_scores','he_scores']
for scores_attr in scores_attrs:
assert(hasattr(self,scores_attr)),'%s is not a field'%scores_attr
dataMap[scores_attr]=getattr(self,scores_attr)
np.savez(self.data_path,**dataMap)
def load_steady_scores(self):
"""
Loads gridness scores; generates and saves them if not present in the data file.
"""
data=np.load(self.data_path,allow_pickle=True)
scores_attrs=['re_scores','ri_scores','he_scores']
if 're_scores' not in data.keys():
self.compute_steady_scores()
self.save_steady_scores()
data=np.load(self.data_path,allow_pickle=True)
for scores_attr in scores_attrs:
assert(scores_attr in data.keys())
setattr(self,scores_attr,data[scores_attr])
def update_speed_weights_step(self):
"""
Update step to learn speed weights
"""
weights_list = [self.W_speed_east, self.W_speed_west,self.W_speed_north,self.W_speed_south]
speed_input_list = [self.speed_inputs_east,self.speed_inputs_west,
self.speed_inputs_north,self.speed_inputs_south]
if self.use_eight_directions is True:
weights_list+=[self.W_speed_north_east,
self.W_speed_north_west,self.W_speed_south_east,self.W_speed_south_west]
speed_input_list+=[self.speed_inputs_north_east,self.speed_inputs_north_west,
self.speed_inputs_south_east,self.speed_inputs_south_west]
for weights,speed_input in zip(weights_list,speed_input_list):
weight_update=speed_input*(self.rr[:self.N_e]-self.input_mean)*(self.rr_e_trace.T-self.input_mean)
weights+=self.learn_rate_speed_weights*weight_update
# normalize to fixed mean of incoming and outgoing weights
weights-=(weights.mean(axis=1)-self.W_av_star)[:,np.newaxis]
weights-=(weights.mean(axis=0)-self.W_av_star)[np.newaxis,:]
# clip weights
np.clip(weights,0,self.W_max_e,out=weights)
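# In summary, each direction's weights follow a gated covariance rule (a sketch of the
# update above): dW = eta*s*(r_post-mu)*(r_pre_trace-mu)^T, where s is the binary speed
# input for that direction, mu the input mean, and r_pre_trace a trace of the excitatory
# rates; a subtractive normalization then keeps the mean incoming and outgoing weights at
# W_av_star, and the weights are clipped to [0,W_max_e].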
def update_speed_input_step(self,curr_v):
"""
Update step for speed inputs (also used to learn speed weights)
"""
# update speed inputs
self.speed_inputs_east*=0
self.speed_inputs_west*=0
self.speed_inputs_north*=0
self.speed_inputs_south*=0
if self.use_eight_directions is True:
self.speed_inputs_north_east*=0
self.speed_inputs_north_west*=0
self.speed_inputs_south_east*=0
self.speed_inputs_south_west*=0
#speed_values=self.rr[:self.N_e,0]
speed_values=np.ones((self.N_e,1))
if curr_v[0]>0:
# north-east
if self.use_eight_directions is True and curr_v[1]>0:
self.speed_inputs_north_east=speed_values
# south-east
elif self.use_eight_directions is True and curr_v[1]<0:
self.speed_inputs_south_east=speed_values
#east
else:
self.speed_inputs_east=speed_values
elif curr_v[0]<0:
# north-west
if self.use_eight_directions is True and curr_v[1]>0:
self.speed_inputs_north_west=speed_values
# south-west
elif self.use_eight_directions is True and curr_v[1]<0:
self.speed_inputs_south_west=speed_values
# west
else:
self.speed_inputs_west=speed_values
else:
# north
if curr_v[1]>0:
self.speed_inputs_north=speed_values
# south
elif curr_v[1]<0:
self.speed_inputs_south=speed_values
def update_total_speed_input_step(self,curr_v):
"""
Update step to compute the total speed input to add to the recurrent dynamics
"""
tot_speed_input_east=np.dot(self.W_speed_east,self.speed_inputs_east)/self.N_e
tot_speed_input_west=np.dot(self.W_speed_west,self.speed_inputs_west)/self.N_e
tot_speed_input_north=np.dot(self.W_speed_north,self.speed_inputs_north)/self.N_e
tot_speed_input_south=np.dot(self.W_speed_south,self.speed_inputs_south)/self.N_e
self.tot_speed_input_all_padded[:self.N_e,0]=\
tot_speed_input_east+tot_speed_input_west+\
tot_speed_input_north+tot_speed_input_south
if self.use_eight_directions is True:
tot_speed_input_north_east=np.dot(self.W_speed_north_east,
self.speed_inputs_north_east)/self.N_e
tot_speed_input_north_west=np.dot(self.W_speed_north_west,
self.speed_inputs_north_west)/self.N_e
tot_speed_input_south_east=np.dot(self.W_speed_south_east,
self.speed_inputs_south_east)/self.N_e
tot_speed_input_south_west=np.dot(self.W_speed_south_west,
self.speed_inputs_south_west)/self.N_e
self.tot_speed_input_all_padded[:self.N_e,0]+=\
tot_speed_input_north_east+tot_speed_input_north_west+\
tot_speed_input_south_east+tot_speed_input_south_west
else:
# diagonal move with four directions
if abs(curr_v[0])>0 and abs(curr_v[1])>0:
self.tot_speed_input_all_padded[:self.N_e,0]*=.5
def update_recurrent_weights_step(self):
"""
Update step to learn recurrent weights
"""
# update weights: hebbian term
self.delta_Wee=self.learn_rate*(self.rr[0:self.N_e]-self.input_mean)*\
(self.rr[0:self.N_e].T-self.input_mean)
self.W_ee+=self.dt*self.delta_Wee
# update weights: normalize to fixed mean of incoming and outgoing weights
self.W_ee-=(self.W_ee.mean(axis=1)-self.W_av_star)[:,np.newaxis]
self.W_ee-=(self.W_ee.mean(axis=0)-self.W_av_star)[np.newaxis,:]
# clip weights
self.W_ee=np.clip(self.W_ee,0,self.W_max_ee)
# update excitatory weights in the big weight matrix
self.W[:self.N_e,:self.N_e]=self.W_ee
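# In summary, the recurrent plasticity above implements a covariance rule (sketch):
#   dW_ee/dt = eta*(r_e-mu)*(r_e-mu)^T, with mu=input_mean,
# followed by a subtractive normalization of the row and column means to W_av_star and a
# clip to [0,W_max_ee]; the normalization induces competition among the synapses of each
# neuron while keeping the total connection strength fixed on average.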
def run_recurrent_dynamics_with_walk(self,
walk_time,
num_snaps,
theta_sigma,
learn_recurrent_weights=False,
learn_speed_weights=False,
track_bump_evo=False,
track_cell_evo=False,
track_cell_idx=0,
run_in_circle=False,
sweep=False,
fixed_position=False,
use_recurrent_input=True,
use_theta_modulation=False,
theta_freq=10.,
use_tuning_switch=False,
switch_off_feedforward=False,
feed_forward_off_value=0.,
rec_gain_with_no_feedforward=1.,
switch_off_times=[],
switch_on_times=[],
tuning_time=1.,
evo_idxs=[],
force_walk=False,
periodic_walk=False,
init_p=np.array([0.,0.]),
init_theta=0.0,
interpolate_inputs=False,
activation='tanh',
r_max=100.,
position_dt=None,
synaptic_filter=False,
tau_synaptic=0.2,
walk_speed=None
):
# we cannot learn both recurrent weights and speed weights at the same time
assert(learn_recurrent_weights is False or learn_speed_weights is False)
# at most one of these flags can be true
assert(int(run_in_circle)+int(sweep)+int(fixed_position)<2)
self.run_in_circle=run_in_circle
self.sweep=sweep
self.fixed_position=fixed_position
self.rec_gain_with_no_feedforward=rec_gain_with_no_feedforward
self.synaptic_filter=synaptic_filter
self.tau_synaptic=tau_synaptic
self.feed_forward_off_value=feed_forward_off_value
self.walk_speed = walk_speed if walk_speed is not None else self.speed
# copy initial weights in case they are rescaled to compensate for the absence of feed-forward inputs
self.Wee_nogain=self.W_ee.copy()
self.W_nogain=self.W.copy()
# initialize switching times (in case we are turning off the feed-forward input or its tuning)
if len(switch_on_times)>0:
curr_switch_on_time=switch_on_times.pop(0)
else:
curr_switch_on_time=None
if len(switch_off_times)>0:
curr_switch_off_time=switch_off_times.pop(0)
else:
curr_switch_off_time=None
# activation function
activation_fun=get_activation_fun(activation,r_max)
### ============= LEARNING RECURRENT WEIGHTS ===============================
if learn_recurrent_weights is True:
# we cannot learn recurrent weights with speed input
assert(self.use_speed_input is False)
print 'Learning recurrent weights with random walk'
# initialize connections to the excitatory neurons
self.W_ee0=np.zeros((self.N_e,self.N_e))
# initial weights are random
if self.start_with_zero_connectivity is False:
print 'Initializing weights at random to the upper bound'
for row_idx in xrange(self.N_e):
idxs=np.arange(self.N_e)
np.random.shuffle(idxs)
self.W_ee0[row_idx,idxs[0:self.num_conns_ee]]=self.W_max_ee
# initial weights are set to zero
else:
print 'Initializing weights to zero'
# initializations
self.learn_snap_idx=0
self.learn_walk_step_idx=0
self.W_ee=self.W_ee0
self.W[:self.N_e,:self.N_e]=self.W_ee
self.rr=np.zeros((self.N,1))
# weight evolution vectors
if len(evo_idxs) == 0 :
self.Wee_evo=np.zeros((self.N_e,num_snaps))
else:
self.Wee_evo=np.zeros((len(evo_idxs),self.N_e,num_snaps))
self.mean_rr_evo=np.zeros(num_snaps)
### ============= LEARNING SPEED WEIGHTS ===================================
elif learn_speed_weights is True:
# we cannot learn speed weights with speed input
assert(self.use_speed_input is False)
print 'Learning speed weights with random walk'
self.W_speed_east_evo=np.zeros((self.N_e,num_snaps))
# target mean weight for input and output connections
self.W_av_star=(np.float(self.W_max_e)*self.num_conns_ee)/self.N_e
# initialize speed weights to zero
# row 1 has the weights of speed cells to grid cell 1
self.W_speed_east=np.zeros_like(self.W_ee)
self.W_speed_west=np.zeros_like(self.W_ee)
self.W_speed_north=np.zeros_like(self.W_ee)
self.W_speed_south=np.zeros_like(self.W_ee)
if self.use_eight_directions is True:
self.W_speed_north_east=np.zeros_like(self.W_ee)
self.W_speed_north_west=np.zeros_like(self.W_ee)
self.W_speed_south_east=np.zeros_like(self.W_ee)
self.W_speed_south_west=np.zeros_like(self.W_ee)
### ============= RUN DYNAMICS WITHOUT LEARNING ============================
else:
print 'Recurrent dynamics with random walk'
print 'use_speed_input: %s'%self.use_speed_input
self.num_walk_steps = int(walk_time/self.dt)
# rate at which we shall update the position (there is the option to interpolate inputs between updates)
if position_dt is None:
self.position_dt=self.L/self.nx/self.speed
else:
self.position_dt=position_dt
self.pos_dt_scale=int(self.position_dt/self.dt)
self.walk=RandomWalk(sl.map_merge(self.paramMap,{ 'walk_time':walk_time,
'position_dt':self.position_dt,
'theta_sigma':theta_sigma,
'sweep':sweep,
'init_p':init_p,
'init_theta':init_theta,
'periodic_walk':periodic_walk,
'speed':self.walk_speed,
}),
force=force_walk,
#init_p=init_p,
#init_theta=init_theta,
)
self.delta_snap = int(np.floor(float(self.num_walk_steps)/(num_snaps)))
assert(self.delta_snap>0)
print 'pos_dt_scale: %d'%self.pos_dt_scale
print 'delta_snap: %d'%self.delta_snap
self.start_clock=time.time()
self.startTime=datetime.datetime.fromtimestamp(time.time())
self.startTimeStr=self.startTime.strftime('%Y-%m-%d %H:%M:%S')
# initializations
self.snap_idx=0
self.walk_step_idx=0
self.rr=np.zeros((self.N,1))
self.start_clock=time.time()
self.r_e_walk_map=np.zeros((self.N_e,self.NX))
self.visits_map=np.zeros(self.NX)
if self.use_speed_input or learn_speed_weights:
self.speed_inputs_east=np.zeros(self.N_e)
self.speed_inputs_west=np.zeros(self.N_e)
self.speed_inputs_north=np.zeros(self.N_e)
self.speed_inputs_south=np.zeros(self.N_e)
if self.use_eight_directions is True:
self.speed_inputs_north_east=np.zeros(self.N_e)
self.speed_inputs_north_west=np.zeros(self.N_e)
self.speed_inputs_south_east=np.zeros(self.N_e)
self.speed_inputs_south_west=np.zeros(self.N_e)
self.tot_speed_input_all_padded=np.zeros((self.N,1))
if track_bump_evo is True:
self.bump_peak_evo=np.zeros((2,num_snaps))
self.bump_hh_peak_evo=np.zeros((2,num_snaps))
self.bump_evo=np.zeros((self.N_e,num_snaps))
self.bump_hh_evo=np.zeros((self.N_e,num_snaps))
self.bump_rec_evo=np.zeros((self.N_e,num_snaps))
self.bump_speed_evo=np.zeros((self.N_e,num_snaps))
if track_cell_evo is True:
self.cell_rr_evo=np.zeros(num_snaps)
self.cell_hh_evo=np.zeros(num_snaps)
self.cell_rec_input_evo=np.zeros(num_snaps)
self.cell_rec_input_from_e_evo=np.zeros(num_snaps)
self.cell_rec_input_from_i_evo=np.zeros(num_snaps)
pos_idx=-1
curr_p=self.walk.pos[pos_idx]
self.hh_e=self.h_e[:,pos_idx]
self.hh_i=self.h_i[:,pos_idx]
# feed-forward input vector
self.hh=np.zeros((self.N_e+self.N_i,1))
self.next_hh=np.zeros((self.N_e+self.N_i,1))
tot_input=np.zeros_like(self.hh)
filtered_tot_input=np.zeros_like(self.hh)
# run the simulation
for step_idx in xrange(self.num_walk_steps):
#print 'step_idx: %d'%step_idx
if self.fixed_position is False:
# ==== start of updating rat position ===============================
if np.remainder(step_idx,self.pos_dt_scale)==0:
#print 'updating position, interpolate_inputs=%d'%interpolate_inputs
# if we are at the end of the walk we start again
if self.walk_step_idx>=self.walk.walk_steps:
self.walk_step_idx=0
# read inputs at this walk step
new_pos_idx= self.walk.pidx_vect[self.walk_step_idx]
# check whether the position has really changed since the last walk step.
# note that the position could still be the same, because the rat may have moved
# less than the spatial discretization step L/nx. On straight trajectories the
# position updates every L/nx/speed seconds.
if not (new_pos_idx == pos_idx):
pos_idx=new_pos_idx
new_p=self.walk.pos[pos_idx]
# with speed input or learning speed weights we need to update current direction
if self.use_speed_input is True or learn_speed_weights is True:
if step_idx>0:
dp=new_p-curr_p
# if we are changing position update current direction
if not (dp[0]==0. and dp[1]==0):
curr_v=dp
self.update_speed_input_step(curr_v)
# compute total weighted speed inputs to add to the recurrent dynamics
if self.use_speed_input is True:
self.update_total_speed_input_step(curr_v)
# update speed weights
if learn_speed_weights is True:
self.update_speed_weights_step()
curr_p=new_p
# update the input at this position
self.hh[:self.N_e,0]=self.h_e[:,pos_idx]
self.hh[self.N_e:,0]=self.h_i[:,pos_idx]
# get the inputs at the next different position for interpolation
if interpolate_inputs:
j=1
while(j<1000):
# get inputs at the next different position (for interpolation)
next_walk_step_idx=self.walk_step_idx+j
if next_walk_step_idx>=self.walk.walk_steps:
next_walk_step_idx=0
next_pos_idx=self.walk.pidx_vect[next_walk_step_idx]
if not(next_pos_idx == pos_idx):
break
j+=1
# we found the next position
if j<1000:
self.next_hh[:self.N_e,0]=self.h_e[:,next_pos_idx]
self.next_hh[self.N_e:,0]=self.h_i[:,next_pos_idx]
# the next different position is j walk steps away
hh_increment=(self.next_hh-self.hh)/(self.pos_dt_scale*j)
# we did not find it, no increment
else:
hh_increment=np.zeros_like(self.hh)
else:
if interpolate_inputs:
### the position was the same -> interpolate
self.hh+=hh_increment
self.walk_step_idx+=1
# ==== End of update of rat position =================================
else:
if interpolate_inputs:
### time has passed but position not updated -> interpolate
self.hh+=hh_increment
# tuning switch
if use_tuning_switch is True:# and tuning_switch_count<len(switch_times):
if curr_switch_on_time is not None and step_idx*self.dt>=curr_switch_on_time:
print 'Switching back to tuned inputs, time=%.3f'%(step_idx*self.dt)
self.switch_to_tuned_inputs()
self.hh_e=self.h_e[:,pos_idx]
self.hh_i=self.h_i[:,pos_idx]
self.hh=np.vstack((self.hh_e[:,np.newaxis],self.hh_i[:,np.newaxis]))
# replug the normal recurrent weights
self.W_ee=self.Wee_nogain
self.W[:self.N_e,:self.N_e]=self.W_ee
if len(switch_on_times)>0:
curr_switch_on_time=switch_on_times.pop(0)
else:
curr_switch_on_time=None
if curr_switch_off_time is not None and step_idx*self.dt>=curr_switch_off_time:
if switch_off_feedforward is True:
print 'Switching off feed-forward inputs, time=%.3f'%(step_idx*self.dt)
self.switch_to_no_feedforward_inputs()
# add a gain factor to the excitatory recurrent weights
self.W_ee=self.Wee_nogain*rec_gain_with_no_feedforward
self.W[:self.N_e,:self.N_e]=self.W_ee
else:
print 'Switching to untuned inputs, time=%.3f'%(step_idx*self.dt)
self.switch_to_untuned_inputs()
self.hh_e=self.h_e[:,pos_idx]
self.hh_i=self.h_i[:,pos_idx]
self.hh=np.vstack((self.hh_e[:,np.newaxis],self.hh_i[:,np.newaxis]))
if len(switch_off_times)>0:
curr_switch_off_time=switch_off_times.pop(0)
else:
curr_switch_off_time=None
# reset total input
tot_input*=0
# get feedforward input
self.ff_input=self.hh
# add feed-forward input
tot_input+=self.ff_input
# add theta modulation (of the feed-forward inputs) if necessary
if use_theta_modulation is True:
theta_signal=0.5*(np.cos(2*np.pi*step_idx*self.dt*theta_freq)+1)
tot_input*=theta_signal
# compute recurrent input
self.rec_input=np.dot(self.W,self.rr)
# recurrent input from excitatory and inhibitory cells
if track_cell_evo is True:
self.rec_input_from_e=np.dot(self.W[:,:self.N_e],self.rr[:self.N_e,:])
self.rec_input_from_i=np.dot(self.W[:,self.N_e:],self.rr[self.N_e:,:])
# add recurrent and speed input
if self.use_speed_input is True:
tot_input+=(1-self.speed_input_scale)*self.rec_input+\
self.speed_input_scale*self.tot_speed_input_all_padded
# add recurrent input only
elif use_recurrent_input is True:
tot_input+=self.rec_input
# add feed-forward inhibition if needed (in case of no recurrent inhibition)
if self.N_i==0:
tot_input+=self.r0
# # add theta modulation (of the total input) if necessary
# if use_theta_modulation is True:
# theta_signal=0.5*(np.cos(2*np.pi*step_idx*self.dt*theta_freq)+1)
# tot_input*=theta_signal
if self.synaptic_filter is True:
filtered_tot_input+=(self.dt/self.tau_synaptic)*(-filtered_tot_input+tot_input)
else:
filtered_tot_input=tot_input
# compute firing-rate output
self.rr+=(self.dt/self.tau)*(-self.rr+activation_fun(filtered_tot_input))
# record output and position
self.r_e_walk_map[:,pos_idx]+=self.rr[:self.N_e,0]
self.visits_map[pos_idx]+=1
# update recurrent weights if needed
if learn_recurrent_weights is True:
self.update_recurrent_weights_step()
# progress
if np.remainder(step_idx,self.delta_snap)==0:
sl.print_progress(self.snap_idx,num_snaps,self.start_clock)
# track recurrent weights evolution
if learn_recurrent_weights is True:
if len(evo_idxs)==0:
self.Wee_evo[:,self.snap_idx]=self.W_ee[self.zero_phase_idx,:]
else:
for i,evo_cell_idx in enumerate(evo_idxs):
self.Wee_evo[i,:,self.snap_idx]=self.W_ee[evo_cell_idx,:]
self.mean_rr_evo[self.snap_idx]=self.rr.mean()
if learn_speed_weights is True:
self.W_speed_east_evo[:,self.snap_idx]=self.W_speed_east[self.zero_phase_idx,:]
# track bump evolution
if track_bump_evo is True and self.snap_idx<num_snaps:
self.bump_evo[:,self.snap_idx]=self.rr[:self.N_e,0]
self.bump_hh_evo[:,self.snap_idx]=self.hh[:self.N_e,0]
self.bump_rec_evo[:,self.snap_idx]=self.rec_input[:self.N_e,0]
if self.use_speed_input is True:
self.bump_speed_evo[:,self.snap_idx]=self.tot_speed_input_all_padded[:self.N_e,0]
bump_peak_xy=find_bump_peak_idxs(self.rr[:self.N_e,0])
bump_hh_peak_xy=find_bump_peak_idxs(self.hh[:self.N_e,0])
if bump_peak_xy is not None:
self.bump_peak_evo[0,self.snap_idx]=bump_peak_xy[0]
self.bump_peak_evo[1,self.snap_idx]=bump_peak_xy[1]
if bump_hh_peak_xy is not None:
self.bump_hh_peak_evo[0,self.snap_idx]=bump_hh_peak_xy[0]
self.bump_hh_peak_evo[1,self.snap_idx]=bump_hh_peak_xy[1]
# track single cell rate evolution
if track_cell_evo is True and self.snap_idx<num_snaps:
self.cell_rr_evo[self.snap_idx]=self.rr[track_cell_idx,0]
self.cell_hh_evo[self.snap_idx]=self.ff_input[track_cell_idx,0]
self.cell_rec_input_evo[self.snap_idx]=self.rec_input[track_cell_idx,0]
self.cell_rec_input_from_e_evo[self.snap_idx]=self.rec_input_from_e[track_cell_idx,0]
self.cell_rec_input_from_i_evo[self.snap_idx]=self.rec_input_from_i[track_cell_idx,0]
self.snap_idx+=1
self.visits_map[self.visits_map==0]=1
self.r_e_walk_map/=self.visits_map
# logging simulation end
self.endTime=datetime.datetime.fromtimestamp(time.time())
self.endTimeStr=self.endTime.strftime('%Y-%m-%d %H:%M:%S')
self.elapsedTime =time.time()-self.start_clock
print 'Simulation ends: %s'%self.endTimeStr
print 'Elapsed time: %s\n' %sl.format_elapsed_time(self.elapsedTime)
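# A minimal usage sketch (hypothetical parameter values; the full parameter map and the
# remaining attributes are set in post_init from the project's paramMap):
#   model.post_init()
#   model.run_recurrent_dynamics_with_walk(walk_time=600.,num_snaps=100,theta_sigma=0.7,
#                                          learn_recurrent_weights=True)
# After the run, model.r_e_walk_map holds the visit-normalized excitatory rate maps and,
# when learning, model.Wee_evo holds the snapshots of the learned weights.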
def load_weights_from_data_path(self,weights_data_path):
if os.path.exists(weights_data_path):
print 'Loading recurrent weights: %s'%weights_data_path
data=np.load(weights_data_path,allow_pickle=True)
self.Wee_evo=data['Wee_evo']
self.mean_rr_evo=data['mean_rr_evo']
self.W_ee=data['W_ee']
self.W_ee0=data['W_ee0']
self.W_av_star=data['W_av_star']
self.W[:self.N_e,:self.N_e]=self.W_ee
if 'conn_tuning_index' in data.keys():
self.conn_tuning_index=data['conn_tuning_index']
self.conn_trans_index=data['conn_trans_index']
else:
raise Exception('Data file does not exist: %s'%weights_data_path)
# def comp_inhibitory_tuning_index(self,verbose=False) :
# self.grid_tuning_out_inhib=gl.comp_grid_tuning_index(self.L,self.nx,(self.r[self.n_e**2:,:]).T,verbose=verbose)
def comp_amplification_index(self):
"""
Compute the input and output grid tuning indexes and the amplification index (a measure based on spatial power spectra)
"""
self.grid_tuning_in=self.inputs.grid_tuning_in
self.grid_tuning_out=gl.comp_grid_tuning_index(self.L,self.nx,(self.r[0:self.n_e**2,:]).T)
self.grid_tuning_out_inhib=gl.comp_grid_tuning_index(self.L,self.nx,(self.r[self.n_e**2:,:]).T)
self.grid_amp_index=self.grid_tuning_out/self.grid_tuning_in
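# Restating the ratio computed above: the amplification index is
#   grid_amp_index = grid_tuning_out / grid_tuning_in,
# i.e., by how much the recurrent network amplifies the spatial grid tuning already
# present in the feed-forward input.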
def comp_output_spectra(self):
"""
Compute output power spectra
"""
assert(hasattr(self,'r'))
self.nx=int(self.nx)
r_mat=self.r.T.reshape(self.nx,self.nx,self.N)
in_allfreqs = np.fft.fftshift(np.fft.fftfreq(self.nx,d=self.L/self.nx))
self.freqs=in_allfreqs[self.nx/2:]
r_dft_flat=np.fft.fftshift(np.fft.fft2(r_mat,axes=[0,1]),axes=[0,1])*(self.L/self.nx)**2
r_pw=abs(r_dft_flat)**2
r_pw_profiles=gl.dft2d_profiles(r_pw)
self.re_pw_profile=np.mean(r_pw_profiles,axis=0)
self.he_pw_profile=self.inputs.in_mean_pw_profile
def plot_recurrent_connectivity(self):
pp.plot_recurrent_weights(self.W_ee,self.gp,vmax=self.W_ee.max())
tuning_index= gl.get_recurrent_matrix_tuning_index(self.W_ee,self.gp)
tot_in_weight=self.W_ee.sum(axis=1).mean()
print 'Total input weight %.3f'%tot_in_weight
print 'Maximal weight %.3f'%self.W_ee.max()
print 'Connectivity tuning index: %.3f'%tuning_index
def plot_recurrent_dynamics(self,cell_idx=0,snap_idxs=[1,5,10,15,19]):
pl.rc('font',size=14)
time=np.arange(0,self.recdyn_time,self.dt)
pl.figure(figsize=(10,5))
pl.subplots_adjust(bottom=0.2,hspace=0.4)
pl.subplot(211)
pl.plot(time,self.r_mean_vect[cell_idx,:],lw=2,label='Output')
pl.plot(time,self.rec_input_mean_vect[cell_idx,:],lw=2,label='Rec. Input')
pp.custom_axes()
pl.ylabel('Mean rate [spike/s]')
pl.legend()
pl.subplot(212)
pl.plot(time,self.r_max_vect[cell_idx,:],lw=2,label='Output')
pl.plot(time,self.rec_input_max_vect[cell_idx,:],lw=2,label='Rec. Input')
pp.custom_axes()
pl.ylabel('Max rate [spike/s]')
pl.xlabel('Time [s]')
pl.figure(figsize=(12,5))
pl.subplots_adjust(bottom=0.2,hspace=0.4)
idx=1
for snap_idx in snap_idxs:
pl.subplot(2,len(snap_idxs),idx,aspect='equal')
r_map=self.r_evo[cell_idx,:,snap_idx].reshape(self.nx,self.nx).T
pl.pcolormesh(r_map,vmin=0,rasterized=True)
pl.title('%.2f s'%(snap_idx*(self.recdyn_time/self.recdyn_num_snaps)),fontsize=12)
pp.noframe()
if snap_idx==snap_idxs[0]:
pl.ylabel('Output')
idx+=1
for snap_idx in snap_idxs:
pl.subplot(2,len(snap_idxs),idx,aspect='equal')
r_map=self.rec_input_evo[cell_idx,:,snap_idx].reshape(self.nx,self.nx).T
pl.pcolormesh(r_map,vmin=0,rasterized=True)
pp.noframe()
if snap_idx==snap_idxs[0]:
pl.ylabel('Rec. Input')
idx+=1
def plot_steady_scores(self):
import pylab as pl
import grid_utils.plotlib as pp
pl.figure(figsize=(3.5,2.8))
pl.subplots_adjust(left=0.25,bottom=0.25)
pl.hist(self.he_scores,bins=40,range=[-0.5,2],color='gray',histtype='stepfilled',weights=np.ones_like(self.he_scores)/float(len(self.he_scores)),alpha=1)
pl.hist(self.re_scores,bins=40,range=[-0.5,2],color='black',histtype='stepfilled',weights=np.ones_like(self.re_scores)/float(len(self.re_scores)),alpha=1)
pl.hist(self.he_scores,bins=40,range=[-0.5,2],color='gray',histtype='step',weights=np.ones_like(self.he_scores)/float(len(self.he_scores)),alpha=1,lw=2)
pp.custom_axes()
pl.ylim(1e-3,0.7)
ax=pl.gca()
ax.set_yscale('log')
pl.xlabel('Gridness score')
pl.ylabel('Fraction of cells')
print 'Mean input gridness score: %.2f'%np.mean(self.he_scores)
print 'Mean output gridness score: %.2f'%np.mean(self.re_scores)
pl.title('%.2f'%np.mean(self.re_scores))
class RecAmp2PopLearn(__RecAmp2Pop):
"""
This class inherits from __RecAmp2Pop and implements the learning of the recurrent excitatory connections.
The results of these simulations are saved in the results subfolder 'recamp_2pop' and have prefix 'RecAmp2PopLearn_'.
"""
results_path=os.path.join(sl.get_results_path(),'recamp_2pop')
def __init__(self,paramMap):
# set parameter values from input map
for param,value in paramMap.items():
setattr(self,param,value)
self.paramMap=paramMap
self.str_id=sl.gen_string_id(self.paramMap)
self.hash_id=sl.gen_hash_id(self.str_id)
self.data_path=os.path.join(RecAmp2PopLearn.results_path,'RecAmp2PopLearn_'+self.hash_id+'_data.npz')
self.params_path=os.path.join(RecAmp2PopLearn.results_path,'RecAmp2PopLearn_'+self.hash_id+'_log.txt')
def save_learned_recurrent_weights(self):
# save variables
toSaveMap={'paramMap':self.paramMap,
'Wee_evo':self.Wee_evo,'mean_rr_evo':self.mean_rr_evo,
'W':self.W,'W_ee':self.W_ee,'W_ee0':self.W_ee0,'W_av_star':self.W_av_star,
'conn_tuning_index':self.conn_tuning_index,'conn_trans_index':self.conn_trans_index}
# save
sl.ensureParentDir(self.data_path)
np.savez(self.data_path,**toSaveMap)
print 'Recurrent weights saved in: %s\n'%self.data_path
def load_learned_recurrent_weights(self):
"""
"""
self.load_weights_from_data_path(self.data_path)
def learn_recurrent_weights(self,force=False):
if force or not os.path.exists(self.data_path):
self.post_init()
self.run_recurrent_dynamics_with_walk(self.learn_walk_time,
self.learn_num_snaps,
self.theta_sigma,
learn_recurrent_weights=True,
use_recurrent_input=self.learn_with_recurrent_input)
self.gp=gl.GridProps(self.n_e,self.grid_T,self.grid_angle)
self.conn_tuning_index= gl.get_recurrent_matrix_tuning_index(self.W_ee,self.gp)
self.conn_trans_index=gl.get_trans_index(self.W_ee)
self.save_learned_recurrent_weights()
else:
print 'Data already present: %s'%self.data_path
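# A minimal usage sketch for RecAmp2PopLearn (the contents of paramMap are project-specific
# and are shown here only as a hypothetical placeholder):
#   model=RecAmp2PopLearn(paramMap)
#   model.learn_recurrent_weights()          # runs the walk and saves the learned weights
#   model.load_learned_recurrent_weights()   # reloads them in a later session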
class RecAmp2PopSteady(__RecAmp2Pop):
"""
This class inherits from __RecAmp2Pop and implements the simulation of the steady-state outputs.
The results of these simulations are saved in the results subfolder 'recamp_2pop' and have prefix 'RecAmp2PopSteady_'.
"""
results_path=os.path.join(sl.get_results_path(),'recamp_2pop')
def __init__(self,paramMap):
# set parameter values from input map
for param,value in paramMap.items():
setattr(self,param,value)
self.paramMap=paramMap
self.str_id=sl.gen_string_id(self.paramMap)
self.hash_id=sl.gen_hash_id(self.str_id)
self.data_path=os.path.join(RecAmp2PopSteady.results_path,'RecAmp2PopSteady_'+self.hash_id+'_data.npz')
self.paramsPath=os.path.join(RecAmp2PopSteady.results_path,'RecAmp2PopSteady_'+self.hash_id+'_log.txt')
def save_steady_outputs(self):
# variables to be saved
toSaveMap={'paramMap':self.paramMap,
'r':self.r, # steady state output patterns
'grid_tuning_in':self.grid_tuning_in,
'grid_tuning_out':self.grid_tuning_out,
'grid_tuning_out_inhib':self.grid_tuning_out_inhib,
'grid_amp_index':self.grid_amp_index
}
# save
sl.ensureParentDir(self.data_path)
np.savez(self.data_path,**toSaveMap)
print 'Steady-state output saved in: %s\n'%self.data_path
def load_steady_outputs(self):
if os.path.exists(self.data_path):
print 'Loading steady state output: %s'%self.data_path
data=np.load(self.data_path,allow_pickle=True)
self.r=data['r']
self.grid_tuning_in=data['grid_tuning_in']
self.grid_tuning_out=data['grid_tuning_out']
self.grid_amp_index=data['grid_amp_index']
if 'grid_tuning_out_inhib' in data.keys():
self.grid_tuning_out_inhib=data['grid_tuning_out_inhib']
else:
print 'Steady state output does not exist: %s'%self.data_path
def recompute_and_save_amplification_index(self):
self.post_init()
self.load_steady_outputs()
self.comp_amplification_index()
self.save_steady_outputs()
def compute_and_save_steady_output(self,force=False):
"""
Computes steady-state output of the network and saves the results to disk
Parameters
----------
force : boolean, optional
Force the recomputation of the outputs even if they are already present on disk.
The default is False.
Returns
-------
None.
"""
if force or not os.path.exists(self.data_path):
self.post_init()
# we need to load already learned weights
if self.use_learned_recurrent_weights is True:
assert(self.recurrent_weights_path is not None)
self.load_weights_from_data_path(self.recurrent_weights_path)
self.run_recurrent_dynamics(record_mean_max=False)
self.comp_amplification_index()
self.save_steady_outputs()
else:
print 'Data already present: %s'%self.data_path
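# Typical call sequence (a sketch; paramMap is the project-specific parameter map):
#   model=RecAmp2PopSteady(paramMap)
#   model.compute_and_save_steady_output()   # simulates, or skips if results are cached
#   model.load_steady_outputs()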
def plot_example_outputs(self,cell_idxs=[1,2,3,4,5]):
import pylab as pl
import grid_utils.plotlib as pp
dx=self.L/self.nx
xran=np.arange(self.nx)*self.L/self.nx-self.L/2-dx/2.
vmax=None
pl.figure(figsize=(10,3))
pl.subplots_adjust(wspace=0.2,top=0.8)
for idx,cell_idx in enumerate(cell_idxs):
pl.subplot(1,5,idx+1,aspect='equal')
r_map=self.r[cell_idx,:].reshape(self.nx,self.nx).T
pl.pcolormesh(xran,xran,r_map,rasterized=True,vmin=0,vmax=vmax)
pl.title('%.1f %.2f'%(r_map.max(),self.grid_tuning_out[cell_idx]),fontsize=11)
pp.noframe()
class RecAmp2PopAttractor(__RecAmp2Pop):
"""
This class inherits from __RecAmp2Pop and implements the simulations to probe attractor dynamics
The results of these simulations are saved in the results subfolder 'recamp_2pop' and have prefix 'RecAmp2PopAttractor_'.
"""
results_path=os.path.join(sl.get_results_path(),'recamp_2pop')
def __init__(self,paramMap):
#self.as_super = super(RecAmp2PopAttractor, self)
# set parameter values from input map
for param,value in paramMap.items():
setattr(self,param,value)
self.paramMap=paramMap
self.str_id=sl.gen_string_id(self.paramMap)
self.hash_id=sl.gen_hash_id(self.str_id)
self.data_path=os.path.join(RecAmp2PopAttractor.results_path,'RecAmp2PopAttractor_'+self.hash_id+'_data.npz')
self.paramsPath=os.path.join(RecAmp2PopAttractor.results_path,'RecAmp2PopAttractor_'+self.hash_id+'_log.txt')
# def post_init(self):
# #return super(RecAmp2PopAttractor, self).post_init(get_inputs=False)
# return self.as_super.post_init(get_inputs=False)
def get_stimulus_bump(self):
"""
Generates artificial activity bumps in phase space to be presented as feed-forward input to the network.
Such a feed-forward stimulus is used to probe the attractor landscape of the network, i.e.,
does the network follow the imposed bump location? Where does the bump move when such input is removed?
We produce N_e such bumps, where N_e is the total number of excitatory neurons in the network.
Each bump is centered at the preferred phase of one excitatory neuron; its width is set by
self.stimulus_bump_sigma and its mean rate by self.input_mean.
Returns
-------
None. But it stores the following field in the object:
stimulus_bumps: Numpy array of dimensions (N_e+N_i) x N_e, holding one stimulus bump per preferred phase (inhibitory rows are zero)
"""
re_stimulus_bumps=np.zeros((self.N_e,self.N_e))
# loop across the preferred phases of all neurons
for idx,phase_shift in enumerate(self.gp.phases[:,:]):
# initialize bump activity to zero
bump=np.zeros(self.N_e)
# To account for the periodicity of the phase space we need to add up 9 bumps, shifted
# by -1, 0, or +1 periods in each dimension (i.e., 3 horizontal x 3 vertical shifts)
for xp in (-1,0,1):
for yp in (-1,0,1):
shift=xp*self.gp.u1+yp*self.gp.u2
bump+=rhombus_mises(self.gp.phases-phase_shift[np.newaxis,:]+shift,self.stimulus_bump_sigma)
# save the generated bump for the current preferred phase
re_stimulus_bumps[idx,:]=bump/bump.mean()*self.input_mean
self.stimulus_bumps=np.zeros((self.N_e+self.N_i,self.N_e))
self.stimulus_bumps[:self.N_e,:]=re_stimulus_bumps
def run_single_attractor_sim(self,init_r,time,stimulus_on):
"""
Runs the recurrent dynamics in two different scenarios:
1. With an external stimulus (stimulus_on=True), starting from a zero initial condition
2. Without an external stimulus (stimulus_on=False), starting from a non-zero initial condition
In case 1 the feed-forward input is generated using the method `get_stimulus_bump` and the rates
of all neurons are initialized at zero.
In case 2 the feed-forward input is off (set to a constant value) and the neurons are initialized
with the rates provided by the input argument `init_r`.
Parameters
----------
init_r : Numpy array of dimensions N_e+N_i
Initialization values for the rates of all neurons
time : float
Time length of the simulation
stimulus_on : bool
Indicates if an external feed-forward input needs to be provided.
If False, the rates must be initialized to non-zero values
Returns
-------
out_bumps: Numpy array of dimensions (N_e+N_i) x N_e
Rates of the neurons for all the N_e shifts of the stimulus bump
out_bump_idx2d: Numpy array of dimensions 2 x N_e
Coordinates of the output bump for all the N_e shifts of the stimulus bump
out_bump_evo_idx2d: Numpy array of dimensions recdyn_num_snaps x 2 x N_e
Coordinates of the output bump at each snapshot of the simulation
"""
# if we need a feed-forward stimulus we need to generate it first
if stimulus_on is True:
self.get_stimulus_bump()
assert(init_r is None)
# otherwise we should make sure that the rates have been previously initialized with a stimulus
else:
assert(init_r is not None)
# set the simulation time to the required value
self.recdyn_time=time
# run recurrent dynamics with stimulus on
if stimulus_on :
print('Running with stimulus on for %.1f s'%time)
self.run_recurrent_dynamics(
record_mean_max=False,
custom_h=self.stimulus_bumps,
init_r=np.zeros_like(self.stimulus_bumps),
)
# run recurrent dynamics with stimulus off (stimulus was presented previously)
else:
print('Running with stimulus off for %.1f s'%time)
self.run_recurrent_dynamics(
record_mean_max=False,
custom_h=np.ones_like(init_r)*self.flat_h_rate,
init_r=init_r,
)
# save the output rates at the end of the simulation
out_bumps=self.r
# get bump idxs at the end of the simulation
out_bump_idx2d=get_bumps_idxs_all_shifts(self.r,self.N_e)
out_bump_evo_idx2d=[]
for snap_idx in range(self.recdyn_num_snaps):
out_bump_evo_idx2d.append(get_bumps_idxs_all_shifts(np.squeeze(self.r_evo[:,:,snap_idx]),self.N_e))
out_bump_evo_idx2d=np.array(out_bump_evo_idx2d)
return out_bumps,out_bump_idx2d,out_bump_evo_idx2d
def run_attractor_sims(self):
self.out_bumps_stim_on,self.out_bumps_stim_on_idx2d,self.out_bumps_stim_on_evo_idx2d=self.run_single_attractor_sim(None,self.time_stimulus_on,True)
self.out_bumps_stim_off,self.out_bumps_stim_off_idx2d,self.out_bumps_stim_off_evo_idx2d=self.run_single_attractor_sim(self.out_bumps_stim_on,self.time_stimulus_off,False)
def save_attractor_outputs(self):
# variables to be saved
toSaveMap={'paramMap':self.paramMap,
'out_bumps_stim_on':self.out_bumps_stim_on,
'out_bumps_stim_on_idx2d':self.out_bumps_stim_on_idx2d,
'out_bumps_stim_on_evo_idx2d':self.out_bumps_stim_on_evo_idx2d,
'out_bumps_stim_off':self.out_bumps_stim_off,
'out_bumps_stim_off_idx2d':self.out_bumps_stim_off_idx2d,
'out_bumps_stim_off_evo_idx2d':self.out_bumps_stim_off_evo_idx2d,
'field_len_stim_on':self.field_len_stim_on,
'field_len_stim_off':self.field_len_stim_off,
'field_len_stim_on_evo':self.field_len_stim_on_evo,
'field_len_stim_off_evo':self.field_len_stim_off_evo,
'stim_off_speed_evo':self.stim_off_speed_evo,
'stim_off_avg_last_second_speed':self.stim_off_avg_last_second_speed,
'stim_off_avg_path_dist_evo':self.stim_off_avg_path_dist_evo,
'stim_off_end_dists':self.stim_off_end_dists,
'stim_off_avg_end_dist':self.stim_off_avg_end_dist,
'all_end_phases':self.all_end_phases,
'num_attractors':self.num_attractors,
}
# save
sl.ensureParentDir(self.data_path)
np.savez(self.data_path,**toSaveMap)
print 'Attractor output saved in: %s\n'%self.data_path
def load_attractor_outputs(self):
if os.path.exists(self.data_path):
print 'Loading attractor output: %s'%self.data_path
data=np.load(self.data_path,allow_pickle=True)
self.out_bumps_stim_on=data['out_bumps_stim_on']
self.out_bumps_stim_on_idx2d=data['out_bumps_stim_on_idx2d']
self.out_bumps_stim_on_evo_idx2d=data['out_bumps_stim_on_evo_idx2d']
self.out_bumps_stim_off=data['out_bumps_stim_off']
self.out_bumps_stim_off_idx2d=data['out_bumps_stim_off_idx2d']
self.out_bumps_stim_off_evo_idx2d=data['out_bumps_stim_off_evo_idx2d']
self.field_len_stim_on=data['field_len_stim_on']
self.field_len_stim_off=data['field_len_stim_off']
self.field_len_stim_on_evo=data['field_len_stim_on_evo']
self.field_len_stim_off_evo=data['field_len_stim_off_evo']
self.stim_off_speed_evo=data['stim_off_speed_evo']
self.stim_off_avg_last_second_speed=data['stim_off_avg_last_second_speed']
self.stim_off_avg_path_dist_evo=data['stim_off_avg_path_dist_evo']
self.stim_off_end_dists=data['stim_off_end_dists']
self.stim_off_avg_end_dist=data['stim_off_avg_end_dist']
self.all_end_phases=data['all_end_phases']
self.num_attractors=data['num_attractors']
else:
print 'Attractor output does not exist: %s'%self.data_path
def get_attractor_fields(self):
self.X_in,self.Y_in=[C/float(self.n_e) for C in np.mgrid[0:self.n_e,0:self.n_e]]
self.X_out_stim_on=self.out_bumps_stim_on_idx2d[0,:].reshape(self.n_e,self.n_e)/float(self.n_e)
self.Y_out_stim_on=self.out_bumps_stim_on_idx2d[1,:].reshape(self.n_e,self.n_e)/float(self.n_e)
self.X_out_stim_off=self.out_bumps_stim_off_idx2d[0,:].reshape(self.n_e,self.n_e)/float(self.n_e)
self.Y_out_stim_off=self.out_bumps_stim_off_idx2d[1,:].reshape(self.n_e,self.n_e)/float(self.n_e)
self.dX_stim_on,self.dY_stim_on,self.field_len_stim_on=gl.get_vector_field(self.X_in,self.Y_in
,self.X_out_stim_on,self.Y_out_stim_on)
self.dX_stim_off,self.dY_stim_off,self.field_len_stim_off=gl.get_vector_field(self.X_in,self.Y_in
,self.X_out_stim_off,self.Y_out_stim_off)
self.field_len_stim_off_evo=np.zeros(self.recdyn_num_snaps)
self.field_len_stim_on_evo=np.zeros(self.recdyn_num_snaps)
for snap_idx in range(self.recdyn_num_snaps):
curr_X_out_stim_off=self.out_bumps_stim_off_evo_idx2d[snap_idx,0,:].reshape(self.n_e,self.n_e)/float(self.n_e)
curr_Y_out_stim_off=self.out_bumps_stim_off_evo_idx2d[snap_idx,1,:].reshape(self.n_e,self.n_e)/float(self.n_e)
curr_X_out_stim_on=self.out_bumps_stim_on_evo_idx2d[snap_idx,0,:].reshape(self.n_e,self.n_e)/float(self.n_e)
curr_Y_out_stim_on=self.out_bumps_stim_on_evo_idx2d[snap_idx,1,:].reshape(self.n_e,self.n_e)/float(self.n_e)
_,_,field_len_stim_off=gl.get_vector_field(self.X_in,self.Y_in,curr_X_out_stim_off,curr_Y_out_stim_off)
_,_,field_len_stim_on=gl.get_vector_field(self.X_in,self.Y_in,curr_X_out_stim_on,curr_Y_out_stim_on)
self.field_len_stim_off_evo[snap_idx]=field_len_stim_off
self.field_len_stim_on_evo[snap_idx]=field_len_stim_on
def run_and_save_attractor(self,force=False):
if force or not os.path.exists(self.data_path):
self.post_init()
# we need to load already learned weights
if self.use_learned_recurrent_weights is True:
assert(self.recurrent_weights_path is not None)
self.load_weights_from_data_path(self.recurrent_weights_path)
if hasattr(self,'W_ee_scale_factor') and self.W_ee_scale_factor!=1:
self.W_ee*=self.W_ee_scale_factor
self.W[:self.N_e,:self.N_e]=self.W_ee
else:
if hasattr(self,'smooth_conn_bump') and self.smooth_conn_bump is True:
self.get_tuned_excitatory_weights(binary=False)
else:
self.get_tuned_excitatory_weights(binary=True)
self.run_attractor_sims()
self.get_attractor_fields()
self.get_bump_speed()
self.get_final_attractors()
self.save_attractor_outputs()
else:
print 'Data already present: %s'%self.data_path
def get_detected_phases(self,idx2d,time_idx):
"""
Converts detected bump-peak indexes into phases
idx2d: indexes to convert
time_idx: time index (None if idx2d has no time dimension)
"""
det_phases=np.zeros_like(self.gp.phases)
for idx in range(len(self.gp.phases)):
if time_idx is None:
indexes=idx2d[:,idx].astype(int)
else:
indexes=idx2d[time_idx,:,idx].astype(int)
det_phase_idx=np.ravel_multi_index(indexes, (self.n_e,self.n_e))
det_phases[idx,:]=self.gp.phases[det_phase_idx,:]
return det_phases
def get_bump_speed(self):
"""
Compute the distance travelled and the speed of the bump in the untuned condition
Returns
-------
None.
"""
old_det_phases=self.get_detected_phases(self.out_bumps_stim_off_evo_idx2d,0)
num_steps=int(self.time_stimulus_off/self.dt)
delta_snap=num_steps/self.recdyn_num_snaps
snap_time=delta_snap*self.dt
one_second_idx=int(1./snap_time)
assert(one_second_idx>0)
#print one_second_idx
stim_off_dists_evo=np.zeros((self.recdyn_num_snaps,self.N_e))
for time_idx in range(self.recdyn_num_snaps):
det_phases=self.get_detected_phases(self.out_bumps_stim_off_evo_idx2d,time_idx)
dists=gl.get_periodic_dist_on_rhombus(self.n_e,det_phases,old_det_phases,self.gp.u1,self.gp.u2)
stim_off_dists_evo[time_idx,:]=dists
old_det_phases=det_phases
self.stim_off_speed_evo=stim_off_dists_evo/snap_time
self.stim_off_avg_last_second_speed=self.stim_off_speed_evo[-one_second_idx:,:].mean(axis=0).mean()
self.stim_off_avg_path_dist_evo=np.cumsum(stim_off_dists_evo.mean(axis=1))
def get_final_attractors(self):
start_det_phases=self.get_detected_phases(self.out_bumps_stim_off_evo_idx2d,0)
end_phases=self.get_detected_phases(self.out_bumps_stim_off_idx2d,None)
self.stim_off_end_dists=gl.get_periodic_dist_on_rhombus(self.n_e,end_phases,start_det_phases,self.gp.u1,self.gp.u2)
self.stim_off_avg_end_dist=self.stim_off_end_dists.mean()
# get unique end phases
self.all_end_phases=set([])
for phase in end_phases:
self.all_end_phases.add(tuple(phase))
self.num_attractors=len(self.all_end_phases)
def plot_landscape(self,stim_on,time_idx=None,quiver=True,rhombus=True,plot=True,attractor_plots=False,rhombus_color='k',rhombus_lw=1):
"""
Plots the energy landscape of the attractor network
stim_on: bool, stimulus on or off
time_idx: time index
"""
import matplotlib
if time_idx is None:
if stim_on:
idxs2d=self.out_bumps_stim_on_idx2d
dX,dY=self.dX_stim_on,self.dY_stim_on
field_len=self.field_len_stim_on
else:
idxs2d=self.out_bumps_stim_off_idx2d
dX,dY=self.dX_stim_off,self.dY_stim_off
field_len=self.field_len_stim_off
else:
if stim_on:
idxs2d=self.out_bumps_stim_on_evo_idx2d[time_idx,:,:]
else:
idxs2d=self.out_bumps_stim_off_evo_idx2d[time_idx,:,:]
X_out=idxs2d[0,:].reshape(self.n_e,self.n_e)/float(self.n_e)
Y_out=idxs2d[1,:].reshape(self.n_e,self.n_e)/float(self.n_e)
dX,dY,field_len=gl.get_vector_field(self.X_in,self.Y_in,X_out,Y_out)
pl.gca().set_aspect('equal', adjustable='box')
H, xedges, yedges = np.histogram2d(idxs2d[0,:], idxs2d[1,:], bins=(range(self.n_e+1), range(self.n_e+1)))
end_phases=self.get_detected_phases(idxs2d,None)
Q=None
if plot is True:
if rhombus is False:
if quiver is False:
img=pl.pcolormesh(self.Y_in,self.X_in,H,norm=matplotlib.colors.LogNorm(vmin=1,vmax=100),cmap='viridis_r')
pp.colorbar(fixed_ticks=[1,10,100])
else:
img=pl.pcolormesh(self.Y_in,self.X_in,H,norm=matplotlib.colors.LogNorm(vmin=1,vmax=100),cmap='viridis_r')
pp.colorbar(fixed_ticks=[1,10,100])
Q=pl.quiver(self.Y_in+.5/self.n_e,self.X_in+.5/self.n_e,dY,dX,headwidth=14)
pl.xlim(0,1)
pl.ylim(0,1)
pp.custom_axes()
else:
if attractor_plots:
dists=np.sqrt(dX**2+dY**2).ravel()
dists/=dists.max()
img=pp.plot_on_rhombus(self.gp.R_T,1,0, self.N_e,self.gp.phases,
dists,plot_axes=False,plot_rhombus=True,
plot_cbar=False,cmap='Purples',rhombus_color=rhombus_color,rhombus_lw=rhombus_lw)
# get unique end phases
all_end_phases=set([])
for phase in end_phases:
all_end_phases.add(tuple(phase))
# plot only the unique end phases
for phase in all_end_phases:
pl.plot(phase[0],phase[1],'.k',ms=8)#,mec='tab:orange',mfc='tab:orange')
else:
img=pp.plot_on_rhombus(self.gp.R_T,1,0, self.N_e,self.gp.phases,
H.ravel(),plot_axes=False,plot_rhombus=True,
plot_cbar=False,norm=matplotlib.colors.LogNorm(vmin=1,vmax=100),cmap='viridis_r')
if quiver:
Q=pl.quiver(self.gp.phases[:,1],self.gp.phases[:,0],dY,dX,pivot='mid',headwidth=10)
pl.xlabel('Bump location (x)')
pl.ylabel('Bump location (y)')
if plot is True:
if attractor_plots:
return img
elif Q is not None:
return img,Q
else:
return (H,dY,dX) if quiver else H