# -*- coding: utf-8 -*-
"""
Created on Wed Mar 9 12:36:43 2016
@author: dalbis
"""
import os
import datetime, time
from time import clock

import numpy as np
from numpy.fft import fft2, ifft2, fftshift
from numpy.linalg import eigvals
from scipy.integrate import romb, quad

from simlib import print_progress, ensureParentDir
from grid_inputs import GridInputs
from grid_functions import (get_periodic_dist, get_non_periodic_dist, map_merge,
                            K_t, K_outeq_t, corr_rate)
from grid_const import InputType, FilterType
class GridCorrSpace(object):
"""
Input correlations in space
"""
results_path='../results/grid_corr_space'
@staticmethod
def get_key_params(paramMap):
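        """
        Returns the names of the parameters that identify a simulation (and hence its cached file).
        """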
        filter_key_params_filter_input  = ['mu1','mu2','mu3','tau1','tau2','tau3']
        filter_key_params_filter_output = ['mu_out','tau_in','tau_out']
basic_key_params=GridInputs.get_key_params(paramMap)+['speed',]
key_params=[]
        if 'filter_type' not in paramMap or paramMap['filter_type']==FilterType.FILTER_INPUT:
key_params=basic_key_params+filter_key_params_filter_input
else:
key_params=basic_key_params+filter_key_params_filter_output
if 'norm_bound_add' in paramMap:
key_params=key_params+['norm_bound_add',]
if 'norm_bound_mul' in paramMap:
key_params=key_params+['norm_bound_mul',]
return key_params
@staticmethod
def get_id(paramMap):
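        """
        Builds a string identifier by concatenating the key parameter names and values.
        """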
filename=''
for param in GridCorrSpace.get_key_params(paramMap):
filename+='%s=%s_'%(param,paramMap[param])
filename=filename[:-1]
return filename
@staticmethod
def get_data_path(paramMap):
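        """
        Returns the path of the cached .npz data file for the given parameters.
        """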
return os.path.join(GridCorrSpace.results_path,GridCorrSpace.get_id(paramMap)+'_data.npz')
    def __init__(self, paramMap, auto_gen=True, do_print=True,
                 force_gen_inputs=False, force_gen_corr=False,
                 keys_to_load=[], use_theory=True,
                 normalize_equal_mean_adding=False, normalize_equal_mean_scaling=False):
self.filter_type=FilterType.FILTER_INPUT
self.use_theory=use_theory
# import parameters
for param in GridCorrSpace.get_key_params(paramMap):
setattr(self,param,paramMap[param])
        if 'filter_type' in paramMap:
self.filter_type=paramMap['filter_type']
self.id=GridCorrSpace.get_id(paramMap)
self.dataPath=os.path.join(GridCorrSpace.results_path,self.id+'_data.npz')
if auto_gen is True:
# generate and save data
if force_gen_corr or not os.path.exists(self.dataPath):
self.gen_data(do_print=do_print,force_gen_inputs=force_gen_inputs)
# load data
self.load_data(do_print=do_print,keys_to_load=keys_to_load)
def gen_data(self,do_print=True,force_gen_inputs=False):
"""
Generates corr space data and saves it to disk
"""
if do_print:
print
print 'Generating corr space data, id = %s'%self.id
self.post_init(force_gen_inputs=force_gen_inputs)
self.run()
self.post_run()
def load_data(self,do_print=True,keys_to_load=[]):
"""
Loads data from disk
"""
if do_print:
print
print 'Loading corr_space data, Id = %s'%self.id
data= np.load(self.dataPath,mmap_mode='r')
loaded_keys=[]
        if not keys_to_load:
for k,v in data.items():
setattr(self,k,v)
loaded_keys.append(k)
else:
for k in keys_to_load:
setattr(self,k,data[k])
loaded_keys.append(k)
if do_print:
print 'Loaded variables: '+' '.join(loaded_keys)
def post_init(self,force_gen_inputs=False):
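        """
        Sets up derived quantities needed by run(): the spatial grid, the sampled
        temporal filter, and either the Gaussian field centers (analytical case)
        or the input rate maps (numerical case).
        """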
self.startClock=clock()
self.startTime=datetime.datetime.fromtimestamp(time.time())
self.startTimeStr=self.startTime.strftime('%Y-%m-%d %H:%M:%S')
self.dx=self.L/self.nx
X,Y=np.mgrid[-self.L/2:self.L/2:self.dx,-self.L/2:self.L/2:self.dx]
self.pos=np.array([np.ravel(X), np.ravel(Y)]).T
if self.filter_type==FilterType.FILTER_INPUT:
self.b1=1./self.tau1
self.b2=1./self.tau2
self.b3=1./self.tau3
elif self.filter_type==FilterType.FILTER_OUTPUT:
self.b_in=1./self.tau_in
self.b_out=1./self.tau_out
self.N=self.n**2
# number of samples for the filter
self.tau_samps=2**8+1
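        # 2**k + 1 samples so that the correlation profile can be integrated with scipy.integrate.romb in run()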
self.tau_ran=np.arange(self.tau_samps)*self.dx
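        # the temporal kernel is sampled as a function of travelled distance tau = speed*t;
        # the 1/speed factor below is the Jacobian of this change of variables from time to distance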
if self.filter_type==FilterType.FILTER_INPUT:
self.K_samp= K_t(self.b1,self.b2,self.b3,self.mu1,self.mu2,self.mu3,self.tau_ran/self.speed)/self.speed
elif self.filter_type==FilterType.FILTER_OUTPUT:
self.K_samp= K_outeq_t(self.b_in,self.b_out,self.mu_out,self.tau_ran/self.speed)/self.speed
        tapered = getattr(self, 'tap_inputs', False) is True
        # for Gaussian grid inputs (periodic or not) the correlations are computed analytically
if self.inputs_type == InputType.INPUT_GAU_GRID \
and self.use_theory is True and tapered is False:
self.compute_analytically=True
print 'Analytical estimation for Gaussian receptive fields (%s)'%('periodic' if self.periodic_inputs else 'non periodic')
ran,step=np.linspace(-self.L/2.,self.L/2.,self.n,endpoint=False,retstep=True)
SSX,SSY = np.meshgrid(ran,ran)
self.centers= np.array([np.ravel(SSX), np.ravel(SSY)]).T
self.amp=self.input_mean*self.L**2/(2*np.pi*self.sigma**2)
else:
print 'Numerical estimation for general inputs'
self.compute_analytically=False
# load inputs
self.inputs=GridInputs(self.__dict__,force_gen=force_gen_inputs)
self.inputs_flat=self.inputs.inputs_flat
self.inputs_path=self.inputs.dataPath
# parameters map
self.paramMap = {'id':self.id,
'L':self.L,'n':self.n,'speed':self.speed,'nx':self.nx,
'sigma':self.sigma,'input_mean':self.input_mean,
'periodic_inputs':self.periodic_inputs,
'inputs_type':self.inputs_type}
if self.filter_type == FilterType.FILTER_INPUT:
self.paramMap=map_merge(self.paramMap, {'tau1':self.tau1,'tau2':self.tau2,'tau3':self.tau3,
'mu1':self.mu1,'mu2':self.mu2,'mu3':self.mu3,
'b1':self.b1,'b2':self.b2,'b3':self.b3})
elif self.filter_type == FilterType.FILTER_OUTPUT:
self.paramMap=map_merge(self.paramMap, {'tau_in':self.tau_in,'tau_out':self.tau_out,'mu_out':self.mu_out,
'b_in':self.b_in,'b_out':self.b_out})
#@profile
def run(self):
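        """
        Computes the N x N input correlation matrix CC_teo (analytically for Gaussian
        grid inputs, numerically via FFT otherwise) and its eigenvalues.
        """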
# initialize CC matrix
CC_teo=np.zeros((self.N,self.N))
        # analytical computation for Gaussian receptive fields (periodic or not)
if self.compute_analytically is True:
# choose the K filter function depending on the filter type (input or equivalent output)
if self.filter_type==FilterType.FILTER_INPUT:
K_t_fun=lambda t: K_t(self.b1,self.b2,self.b3,self.mu1,self.mu2,self.mu3,t)
elif self.filter_type==FilterType.FILTER_OUTPUT:
K_t_fun=lambda t: K_outeq_t(self.b_in,self.b_out,self.mu_out,t)
# shorthand for the analytical correlation function
corr_rate_short=lambda tau,u: corr_rate(K_t_fun,self.speed,self.sigma,tau,u)
            # compute the matrix of pairwise distances between field centers
CC_dist=np.zeros_like(CC_teo)
if self.periodic_inputs is True:
get_dist=get_periodic_dist
else:
get_dist=get_non_periodic_dist
for i in xrange(self.N):
for j in xrange(self.N):
CC_dist[i,j]= get_dist(self.centers[i,:],self.centers[j,:],self.L)
# fill in correlation value for each distance
all_dist=np.unique(CC_dist.ravel())
for dist in all_dist:
                corr=np.pi*self.amp**2*self.sigma**2/self.L**2*quad(corr_rate_short,0.,2.,args=(dist,))[0]
CC_teo[CC_dist==dist]=corr
# numerical calculation for general inputs
else:
#pyfftw.interfaces.cache.enable()
# compute DFTs
#inputs_mat=pyfftw.empty_aligned((self.nx,self.nx,self.N), dtype='float32')
#inputs_mat[:]=self.inputs_flat.reshape(self.nx,self.nx,self.N)
inputs_mat=self.inputs_flat.reshape(self.nx,self.nx,self.N)
inputs_dfts=fft2(inputs_mat,axes=[0,1])
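            # inputs_dfts[:,:,k] is the 2D DFT of the k-th input map; the pairwise spatial
            # correlations below are obtained from conjugate products of these DFTs
            # (correlation theorem) rather than by explicit spatial convolution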
            # radial binning used for the angular average below
center=np.array([self.nx,self.nx])/2
yr, xr = np.indices(([self.nx,self.nx]))
r = np.around(np.sqrt((xr - center[0])**2 + (yr - center[1])**2)).astype(int)
nr = np.bincount(r.ravel())
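            # r holds the (rounded) integer distance of each pixel from the center of the box,
            # and nr the number of pixels per radial bin: a weighted bincount divided by nr
            # gives the angular average at each radius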
snap_idx=0
prog_clock=clock()
num_snaps=self.N*(self.N+1)/2
            # loop over the matrix elements of the upper triangle (the matrix is symmetric)
for i in xrange(self.N):
input_i_dft=inputs_dfts[:,:,i]
for j in xrange(i,self.N):
print_progress(snap_idx,num_snaps,start_clock=prog_clock,step=num_snaps/100)
# get j-input
input_j_dft=inputs_dfts[:,:,j]
# inputs correlation
dft_prod=input_i_dft*np.conj(input_j_dft)
#dft_prod=pyfftw.empty_aligned((self.nx,self.nx), dtype='complex64')
#dft_prod[:]=input_i_dft*np.conj(input_j_dft)
input_corr=fftshift(np.real(ifft2(dft_prod)))*self.dx**2
                    # angular average over the circle of radius tau
                    count = np.bincount(r.ravel(), input_corr.ravel())/nr
corr_prof_teo=np.zeros(len(self.tau_ran))
corr_prof_teo[:self.nx/2] = count[:self.nx/2]
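                    # keep only radii up to nx/2: larger radii extend beyond the box and are
                    # only partially sampled by the square domain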
# convolution with the filter
CC_teo[i,j]=romb(self.K_samp*corr_prof_teo,dx=self.dx)
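                    # Romberg-rule integral over distance tau of the filter-weighted correlation profile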
snap_idx+=1
            # normalize by the area and symmetrize: only the upper triangle was computed,
            # and the diagonal is halved because it gets added to itself
CC_teo/=self.L**2
CC_teo=CC_teo+CC_teo.T
            CC_teo[np.diag_indices(self.N)]*=0.5
# normalize correlation matrix to equal mean at the boundary (additive approach)
if self.periodic_inputs is False and hasattr(self,'norm_bound_add') and self.norm_bound_add is True:
print 'Correlation matrix boundary normalization (additive)'
C4d=CC_teo.reshape(self.n,self.n,self.n,self.n)
C4d_norm=np.zeros_like(C4d)
mean_C=C4d.mean(axis=3).mean(axis=2)
for i in xrange(self.n):
for j in xrange(self.n):
C4d_norm[i,j,:,:]=C4d[i,j,:,:]-mean_C[i,j]+mean_C.min()
C_norm=C4d_norm.reshape(self.N,self.N)
CC_teo=C_norm
# normalize correlation matrix to equal mean at the boundary (multiplicative approach)
if self.periodic_inputs is False and hasattr(self,'norm_bound_mul') and self.norm_bound_mul is True:
print 'Correlation matrix boundary normalization (multiplicative)'
C4d=CC_teo.reshape(self.n,self.n,self.n,self.n)
C4d_norm=np.zeros_like(C4d)
mean_C=C4d.mean(axis=3).mean(axis=2)
for i in xrange(self.n):
for j in xrange(self.n):
C4d_norm[i,j,:,:]=C4d[i,j,:,:]/mean_C[i,j]*mean_C.min()
C_norm=C4d_norm.reshape(self.N,self.N)
CC_teo=C_norm
self.CC_teo=CC_teo
self.eigs=eigvals(self.CC_teo)
def post_run(self):
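        """
        Saves the correlation matrix, its eigenvalues, and the parameter map to disk.
        """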
ensureParentDir(self.dataPath)
np.savez(self.dataPath,paramMap=self.paramMap,
CC_teo=self.CC_teo,
eigs=self.eigs,computed_analytically=self.compute_analytically)