import numpy as np
from numpy.fft import fft2,fftfreq,fftshift
from numpy.random import rand
from scipy.ndimage import rotate
from scipy.signal import fftconvolve
from scipy.stats.stats import pearsonr
from scipy.ndimage.filters import maximum_filter
from scipy.ndimage.morphology import generate_binary_structure, binary_erosion
from scipy.ndimage.measurements import label
from scipy.ndimage import gaussian_filter

from simlib import print_progress
import time

class GridProps(object):
  """
  Container for the properties of a population of grid patterns.

  n: number of phase samples per side of the phase rhombus
  grid_T: side-length of the phase rhombus (see get_phases)
  grid_angle: orientation of the phase rhombus (see get_phases)

  The instance stores n, N (=n**2), grid_T and grid_angle, the five values
  returned by get_phase_lattice (R_T, u1, u2, u1_rec, u2_rec) and the
  phase samples returned by get_phases.
  """
  def __init__(self,n,grid_T,grid_angle):

    self.n,self.N=n,n**2
    self.grid_T,self.grid_angle=grid_T,grid_angle

    # lattice description as returned by get_phase_lattice
    lattice=get_phase_lattice(self.grid_T,self.grid_angle)
    self.R_T,self.u1,self.u2,self.u1_rec,self.u2_rec=lattice

    # sampled phases only (the unit vectors are already stored above)
    self.phases=get_phases(self.n,self.grid_T,self.grid_angle,return_u12=False)


def get_mean_shift_length(dX,dY):
  """
  Returns the mean Euclidean length of the arrows of the vector field
  whose components are dX and dY
  """
  arrow_lengths=np.sqrt(dX**2+dY**2)
  return arrow_lengths.ravel().mean()

def get_vector_field(X_in,Y_in,X_out,Y_out):
  """
  Given one input grid of points with coordinates X_in,Y_in and
  an output grid of points with coordinates X_out,Y_out it computes
  the vector field that connects input and output points taking
  care of the periodic boundary conditions (in phase space).

  All four inputs must be square matrices of identical shape. Each
  coordinate is treated as periodic with period 1: for every point the
  displacement out-in is compared with its images shifted by -1 and +1
  and the one with the smallest absolute value is retained.

  Returns dX, dY (matrices with the same shape as the inputs) and the mean
  arrow length of the field.
  """
  assert(X_in.shape == Y_in.shape == X_out.shape == Y_out.shape)
  assert(X_in.shape[0]==X_in.shape[1])

  n_e=X_in.shape[0]

  def get_shift_one_dim(M1,M2):
    # displacement for the three periodic images (-1, 0, +1), all points at once
    diff=np.asarray(M2-M1,dtype=float).ravel()
    candidates=np.array([diff+period for period in (-1,0,1)])

    # argmin returns the first minimum, i.e., ties are resolved in favour of
    # the lowest period (same outcome as a scan with a strict '<' comparison)
    best=np.argmin(np.abs(candidates),axis=0)
    return candidates[best,np.arange(diff.size)]

  dX=get_shift_one_dim(X_in,X_out).reshape(n_e,n_e)
  dY=get_shift_one_dim(Y_in,Y_out).reshape(n_e,n_e)
  mean_len=get_mean_shift_length(dX,dY)

  return dX,dY,mean_len

    
def get_trans_index(W,ret_W_rolled=False):
  """
  Computes an index that quantifies how much a matrix is translation invariant.
  For a perfectly translation invariant matrix the index approaches 1.
  This is calculated by applying a circular shift to each row of the matrix with offset equal to the
  row number. For a perfectly translation-invariant matrix, the result is a matrix where all rows are equal.
  We then quantify translation invariance by computing the mean Pearson's correlation coefficient between
  each row and the mean row of the matrix.

  W: two-dimensional matrix
  ret_W_rolled: if True the row-shifted matrix is returned as well

  Returns trans_index, or (trans_index, W_rolled) if ret_W_rolled is set.
  """

  num_rows=W.shape[0]

  # undo the row-dependent offset: row i is circularly shifted by -i
  W_rolled=np.zeros_like(W)
  for row_idx in range(num_rows):
    W_rolled[row_idx,:]=np.roll(W[row_idx,:],-row_idx)

  W_mean_row=W_rolled.mean(axis=0)

  # correlation of each aligned row with the mean row
  # (pearsonr is the function imported at the top of the file)
  row_corrs=np.zeros(num_rows)
  for row_idx in range(num_rows):
    row_corrs[row_idx]=pearsonr(W_rolled[row_idx,:],W_mean_row)[0]

  # computed once after the loop, so it is defined also for a matrix without rows
  trans_index=row_corrs.mean()

  if ret_W_rolled:
    return trans_index,W_rolled
  else:
    return trans_index



def get_recurrent_matrix_tuning_index(W,gp):
  """
  W: recurrent weight matrix
  gp: instance of GridProps

  This function takes as input a recurrent weight matrix that connects grids with similar phases.
  It computes how strong the phase tuning is by comparing the amplitude of the first harmonic to the mean.
  Note that because the phase space is a rhombus we use a Fourier transform with non-orthogonal unit vectors.
  For each row of the matrix we compute the 2D spectrum of the weights in rhomboidal space and then average the
  amplitudes at the six first harmonics ([-1,-1],[1,1],[0,1],[0,-1],[1,0],[-1,0]).

  Returns the mean over the first gp.N rows of W of the single-neuron tuning index
  (see get_conn_to_one_neuron_tuning_index).
  """

  mean_amps=np.zeros(gp.N)
  for phase_idx in range(gp.N):
    mean_amps[phase_idx]=get_conn_to_one_neuron_tuning_index(W[phase_idx],gp)
  return mean_amps.mean()

def get_conn_to_one_neuron_tuning_index(W_one_neuron,gp):
  """
  Compute the recurrent connectivity tuning index for the input connections to one neuron

  W_one_neuron: weights of the connections to the neuron
  gp: instance of GridProps

  The index is the mean ratio between the spectrum amplitudes at the six
  first harmonics and the amplitude at the zero harmonic.
  """
  n_harm=2
  HX,HY,ft=fourier_on_lattice(gp.grid_T,gp.u1_rec,gp.u2_rec,gp.phases,W_one_neuron,
                              max_harmonic=n_harm,return_harmonics=True)

  # presumably ft holds one coefficient per (HX,HY) harmonic pair -- verify against fourier_on_lattice
  side=2*n_harm
  amps=np.abs(ft.reshape(side,side))

  # masks of the six first harmonics
  on_y_axis=np.bitwise_and(HX==0,np.abs(HY)==1)
  on_x_axis=np.bitwise_and(np.abs(HX)==1,HY==0)
  neg_diag=np.bitwise_and(HX==-1,HY==-1)
  pos_diag=np.bitwise_and(HX==1,HY==1)
  first_harmonics=on_y_axis | on_x_axis | neg_diag | pos_diag

  # mask of the baseline (zero harmonic)
  baseline=np.bitwise_and(np.abs(HX)==0,np.abs(HY)==0)

  ratios=amps[first_harmonics]/amps[baseline]
  return ratios.mean()


  
def get_reciprocal_rhombus_unit_vectors(u1,u2):
  """
  Returns the unit vectors u1_rec,u2_rec of the reciprocal rhombus in Fourier space,
  i.e., the vectors such that the dot product of ui_rec and uj is 2*pi if i==j
  and 0 otherwise
  """
  # direct lattice: the unit vectors are the columns of the matrix
  direct=np.transpose(np.vstack([u1, u2]))

  # reciprocal lattice: again one unit vector per column
  two_pi=2*np.pi
  reciprocal=two_pi*np.transpose(np.linalg.inv(direct))

  return reciprocal[:,0],reciprocal[:,1]


def get_single_phase_periodic_dist(ref_phase,phase,u1,u2):
  """
  Returns the periodic distance between phase and ref_phase on the rhombus
  with unit vectors u1,u2, i.e., the smallest Euclidean distance among the
  nine images of phase shifted by xp*u1+yp*u2 with xp,yp in (-1,0,1)
  """
  delta=phase-ref_phase
  image_dists=np.array([np.sqrt(((delta+(xp*u1+yp*u2))**2).sum())
                        for xp in (-1,0,1) for yp in (-1,0,1)])
  return image_dists.min()


def get_periodic_dist_on_rhombus(n,ref_phase,phases,u1,u2):
  """
  Periodic distance on the rhombus between ref_phase and every phase in phases.
  The output is an array of n**2 elements, i.e., one distance for each phase in the rhombus
  (phases is expected to have n**2 rows)
  """

  # the nine lattice translations that account for the 2d periodicity
  offsets=[xp*u1+yp*u2 for xp in (-1,0,1) for yp in (-1,0,1)]

  # one column of distances per translation
  image_dists=np.zeros((n**2,len(offsets)))
  for col,offset in enumerate(offsets):
    delta=phases-ref_phase+offset
    image_dists[:,col]=np.sqrt((delta**2).sum(axis=1))

  # the periodic distance is the distance to the closest image
  return image_dists.min(axis=1)



def get_corr_distance_fun(corr,n,u1,u2,phases):
  """
  Returns the radial profile (correlation VS distance function) for each row of the given correlation matrix

  corr: correlation matrix, one row per reference phase
  n: number of phase samples per rhombus side
  u1,u2: unit vectors of the phase rhombus
  phases: phase samples, row i is the reference phase of row i of corr

  Returns two lists with one entry per row of corr: the sorted unique periodic
  distances from the reference phase and the mean correlation at each of them.
  """

  print('rows: %d'%corr.shape[0])

  all_dist=[]
  all_profiles=[]
  for row_idx in range(corr.shape[0]):
    ref_phase=phases[row_idx,:]
    dist=get_periodic_dist_on_rhombus(n,ref_phase,phases,u1,u2)

    # group the phases by (exactly equal) distance from the reference phase
    dist,inverse,counts=np.unique(dist,return_counts=True,return_inverse=True)

    # sum of the correlations in each distance group divided by the group size
    profile=np.bincount(inverse,corr[row_idx,:].squeeze())
    profile/=counts

    all_dist.append(dist)
    all_profiles.append(profile)

  return all_dist,all_profiles


def find_place_fields(ratemap,max_th=0.3,min_size=9,ret_labels=False):
  """
  Segments place fields from a rate map
  max_th: threshold relative to the maximal firing rate
  min_size: connected regions are retained only if they have more than min_size pixels
  ret_labels: if True the map of the field labels is returned together with the number of fields
  """

  # connected regions above threshold
  above_th=ratemap>ratemap.max()*max_th
  raw_labels,raw_num_fields=label(above_th)

  # relabel consecutively the regions that are large enough
  labels=np.zeros_like(ratemap)
  next_label=1
  for raw_id in range(1,raw_num_fields+1):
    field_mask=(raw_labels==raw_id)
    if field_mask.sum()<=min_size:
      continue
    labels[field_mask]=next_label
    next_label+=1

  # labels are consecutive, so the largest one is the number of fields
  num_fields=labels.max()

  if not ret_labels:
    return num_fields
  return labels,num_fields

  

def get_phases(n,grid_T,grid_angle,return_u12=False):
  """
  Samples n**2 grid phases evenly on a rhombus with side-length grid_T and orientation grid_angle.
  The output is an array with one phase (x,y) per row.
  If return_u12=True the unit vectors u1,u2 of the rhombus are returned before the phases
  """

  # unit vectors of the direct lattice
  u1_angle=2*np.pi/3+grid_angle
  u1=grid_T*np.array([np.sin(u1_angle), -np.cos(u1_angle)])
  u2=grid_T*np.array([-np.sin(grid_angle), np.cos(grid_angle)])

  # fractions of the unit vectors at which the phases are sampled
  fractions=np.arange(-n/2.,n/2.)/n
  u1_x,u1_y=u1[0]*fractions,u1[1]*fractions
  u2_x,u2_y=u2[0]*fractions,u2[1]*fractions

  # all the combinations: u1 varies along the columns and u2 along the rows
  X=u1_x[np.newaxis,:]+u2_x[:,np.newaxis]
  Y=u1_y[np.newaxis,:]+u2_y[:,np.newaxis]

  phases=np.vstack((X.ravel(),Y.ravel())).T

  if return_u12 is True:
    return u1,u2,phases
  return phases
  
  
def get_pos_idx(p,pos):
  """
  Returns the index of the row of 'pos' (array with one position per row)
  that is closest to the position 'p'
  """
  offsets=np.array(p)-pos
  sq_dists=(offsets**2).sum(axis=1)
  return np.argmin(sq_dists)
  
def get_angle_amps(num_dft,freq_idx,nx):
  """
  Compute the amplitudes of modes of a given frequency (freq_idx) but different angles
  num_dft: DFT of a grid pattern (nx x nx x num_snaps, zero frequency at index nx//2)
  freq_idx: frequency (radius in pixels, rounded) at which the modes need to be computed
  nx: number of samples per side of the DFT

  Returns the angles (in degrees, modulo 180, relative to the angle of the mode
  with the largest amplitude in the last snapshot) and the list of the
  corresponding amplitudes over the snapshots.
  """

  # integer division: the centre must be a valid index for both even and odd nx
  center=nx//2

  # all radii and all angles
  yr, xr = np.indices((nx,nx))
  all_r = np.around(np.sqrt((xr - center)**2 + (yr - center)**2))
  all_ang = np.arctan2(yr-center,xr-center)*180/np.pi+180
  all_ang[all_ang==360]=0

  # flat dfts
  allr_flat=all_r.reshape(nx**2)
  all_ang_flat=all_ang.reshape(nx**2)
  num_dft_flat=num_dft.reshape(nx**2,num_dft.shape[2])

  # indexes of the modes at the requested frequency, only the first half
  # (in flattened order) is retained
  idxs=np.arange(nx**2)[allr_flat==freq_idx]
  uidxs=idxs[:len(idxs)//2]

  # take as zero the angle of the fastest growing mode
  max_ang=all_ang_flat[np.argmax(num_dft_flat[:,-1])]
  angles=np.remainder(all_ang_flat-max_ang,180)[uidxs]

  # amplitudes
  amps=[np.squeeze(num_dft_flat[idx,:]) for idx in uidxs]

  return angles,amps
  

def get_grid_params(J,L,nx,num_steps=50,return_cx=False):
  """
  Estimates parameters of a grid pattern, i.e., gridness score, grid spacing,
  grid angle, and grid phase

  J: input grid pattern (nx x nx)
  L: side-length of the environment
  nx: number of space samples
  num_steps: number of iteration steps to compute the gridness score
  return_cx: if True returns also the autocorrelation matrix

  Returns score, spacing, angle, phase (and cx if return_cx is True).
  The phase is NaN when either the spacing or the angle is NaN.
  """

  # float division, so that integer L and nx are not truncated
  dx=float(L)/nx

  # exactly nx samples per side starting at -L/2 (a range with a float step
  # may produce one sample more because of rounding)
  samples=-L/2.+dx*np.arange(nx)
  X,Y=np.meshgrid(samples,samples,indexing='ij')
  pos=np.array([np.ravel(X), np.ravel(Y)])

  # the autocorrelation is always requested and returned only on demand
  score,best_outr,angle,spacing,cx=gridness(J,dx,
                                            computeAngle=True,doPlot=False,
                                            num_steps=num_steps,return_cx=True)

  # NaN must be tested by value: an identity check misses computed NaNs
  if not (np.isnan(spacing) or np.isnan(angle)):
    ref_grid=simple_grid_fun(pos,grid_T=spacing,
                             angle=-angle,phase=[0, 0]).reshape(nx,nx)
    phase=get_grid_phase(J,ref_grid,dx,doPlot=False,use_crosscorr=True)
  else:
    phase=np.nan

  if return_cx is True:
    return score, spacing,angle,phase,cx
  else:
    return score, spacing,angle,phase


  
def compute_scores_evo(J_vect,n,L,num_steps=50):
  """
  Computes gridness scores for a matrix at different time points
  J_vect = N x num_snaps, with N=n**2 (each column is a flattened n x n pattern)
  n: number of space samples per side
  L: side-length of the environment
  num_steps: number of iteration steps to compute the gridness score

  Returns the arrays scores, spacings, angles (one element per snapshot)
  and phases (2 x num_snaps). A summary of the scores is printed at the end.
  """

  num_snaps=J_vect.shape[1]
  assert(J_vect.shape[0]==n**2)
  start_clock=time.time()
  best_score=-1

  scores=np.zeros(num_snaps)
  spacings=np.zeros(num_snaps)
  angles=np.zeros(num_snaps)
  phases=np.zeros((2,num_snaps))

  for snap_idx in range(num_snaps):
    print_progress(snap_idx,num_snaps,start_clock=start_clock)

    J=J_vect[:,snap_idx]

    score,spacing,angle,phase= get_grid_params(J.reshape(n,n),L,n,num_steps=num_steps)

    best_score=max(best_score,score)
    scores[snap_idx]=score
    spacings[snap_idx]=spacing
    angles[snap_idx]=angle
    phases[:,snap_idx]=phase

  # the summary needs at least one snapshot (there is no final score otherwise)
  if num_snaps>0:
    score_string='final_score: %.2f    best_score: %.2f    mean_score: %.2f\n'%(scores[-1],best_score,np.mean(scores))
    print(score_string)

  return scores,spacings,angles,phases

        
def dft2d_num(M_evo,L,n,nozero=True):
  """
  Computes the amplitude of the 2D DFT of a n x n x time_samples matrix wrt the first two dimensions.
  The zero frequency is shifted to the index (n//2,n//2).

  M_evo: n x n x time_samples matrix
  L: side-length of the sampled domain (the sampling interval is L/n)
  n: number of samples per side
  nozero: if True the DC component is set to zero

  Returns the DFT amplitudes, the non-negative frequencies and all the frequencies.
  """

  assert(len(M_evo.shape)==3)
  assert(M_evo.shape[0]==M_evo.shape[1])

  # integer division: the index of the zero frequency after the shift
  half=n//2

  allfreqs=fftshift(fftfreq(n,d=float(L)/n))
  freqs=allfreqs[half:]
  M_dft_evo=fftshift(abs(fft2(M_evo,axes=(0,1))),axes=(0,1))
  if nozero is True:
    M_dft_evo[half,half,:]=0
  return M_dft_evo,freqs,allfreqs



def dft2d_teo(J0_dft,eigs,time,n):
  """
  Compute the theoretical DFT solution of a linear dynamical system
  given the eigenvalues and the initial condition

  J0_dft: DFT of the initial condition (n**2 elements)
  eigs: eigenvalues (n**2 elements, same order as J0_dft)
  time: array of the time points
  n: number of samples per side

  Returns a n x n x len(time) array with J0_dft*exp(eigs*time),
  where the component at (n//2,n//2) is set to zero for all times.
  """
  N=n**2
  teo_dft=J0_dft.reshape(N,1)*np.exp(time[np.newaxis,:]*eigs.reshape(N,1))
  teo_dft=teo_dft.reshape(n,n,len(time))

  # integer division: the index must be an integer
  teo_dft[n//2,n//2]=0

  return teo_dft

  
  
def radial_profile(data,norm=False):
  """
  Compute radial profile of a 2D function sampled on a square domain,
  assumes the function is centered in the middle of the square
  (at the index n//2 along both axes, where n is the side of the square).

  data: square matrix
  norm: if True each bin contains the mean of the values at that radius,
        otherwise their sum

  Returns an array of n//2 elements, element r is the sum (or the mean) of
  the values whose distance from the center rounds to r pixels.

  # TEST:
  #
  #  ran=np.arange(-1.01,1.01,0.01)
  #  SSX,SSY = meshgrid(ran,ran)
  #
  #  T=np.exp(-(SSX**2+SSY**2))
  #  P=radial_profile(T,norm=True)
  #
  #  pl.figure()
  #  pl.subplot(111,aspect='equal')
  #  pl.pcolormesh(SSX,SSY,T)
  #  custom_axes()
  #  colorbar()
  #
  #  pl.figure()
  #  pl.plot(ran[101:],P)

  """

  assert(len(data.shape)==2)
  assert(data.shape[0]==data.shape[1])

  side=data.shape[0]

  # integer division: the center is a pixel index for even and odd sides
  center=side//2

  # distance of each pixel from the center, rounded to the closest integer
  yr, xr = np.indices(data.shape)
  r = np.around(np.sqrt((xr - center)**2 + (yr - center)**2)).astype(int)

  # sum of the values at each radius
  profile=np.bincount(r.ravel(), data.ravel())

  if norm is True:
    # divide by the number of pixels at each radius
    profile/=np.bincount(r.ravel())

  # only the radii that fit within the square are retained
  return profile[:side//2]
  

def dft2d_profiles(M_dft_evo):
  """
  Computes DFT 2D radial profiles at different time points

  M_dft_evo: n x n x num_snaps matrix

  Returns an array with one row per snapshot, i.e., the normalized radial
  profile of the absolute value of the corresponding n x n slice.
  """

  assert(len(M_dft_evo.shape)==3)
  assert(M_dft_evo.shape[0]==M_dft_evo.shape[1])

  num_snaps=M_dft_evo.shape[2]
  profiles=np.array([radial_profile(abs(M_dft_evo[:,:,idx]),norm=True) for idx in range(num_snaps)])
  return profiles


def gridness_evo(M,dx,num_steps=50):
  """
  Compute gridness evolution
  M: matrix nx X nx X num_steps (or num_cells)
  dx: space discretization step
  num_steps: number of outer radii tested for each gridness score

  Returns the lists of the gridness scores and of the grid spacings,
  one element per slice M[:,:,idx].
  """
  assert(len(M.shape)==3)
  num_snaps=M.shape[2]

  scores=[]
  spacings=[]
  print('Computing scores...')
  for idx in range(num_snaps):

    score,best_outr,orientation,spacing=gridness(M[:,:,idx],dx,computeAngle=False,doPlot=False,num_steps=num_steps)
    scores.append(score)
    spacings.append(spacing)
    print_progress(idx,num_snaps)
  return scores,spacings
        
def unique_rows(a):
  """
  Removes duplicate rows from a matrix

  a: two-dimensional array (or nested sequence)

  Returns a two-dimensional array with the unique rows of a, sorted.
  """
  # the view below requires a contiguous memory layout
  a=np.ascontiguousarray(a)

  # each row is seen as a single structured element, so that np.unique
  # compares entire rows instead of single entries
  row_dtype=[('', a.dtype)]*a.shape[1]
  unique_a=np.unique(a.view(row_dtype))

  # back to a plain 2D array
  return unique_a.view(a.dtype).reshape((unique_a.shape[0], a.shape[1]))
  
  
def detect_peaks(cx,size=2,do_plot=False):
  """
  Takes an image and detects the peaks using the local maximum filter.
  Returns an integer mask of the peaks (i.e. 1 when
  the pixel's value is the neighborhood maximum, 0 otherwise)

  cx: image (e.g., an autocorrelogram), smoothed with a Gaussian filter (sigma=3)
      before the detection
  size: connectivity of the neighborhood (for a 2D image, 1: the 4 direct
        neighbors, 2: the 8 neighbors including the diagonals)
  do_plot: if True plots the smoothed image, the local maxima and the background
  """

  # smooth the autocorrelation for noise reduction
  cx_smooth=gaussian_filter(cx, sigma=3)

  # define a size-connected neighborhood; the rank of the structure must match
  # the number of dimensions of the image, only the connectivity is given by size
  neighborhood = generate_binary_structure(cx_smooth.ndim,size)

  # apply the local maximum filter; all pixels of maximal value in their neighborhood are set to 1
  local_max = maximum_filter(cx_smooth, footprint=neighborhood)==cx_smooth
  # local_max is a mask that contains the peaks we are looking for, but also the background.
  # In order to isolate the peaks we must remove the background from the mask.

  # we create the mask of the background
  background = (cx_smooth==0)

  # a little technicality: we must erode the background in order to successfully subtract it from local_max,
  # otherwise a line will appear along the background border (artifact of the local maximum filter)
  eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)

  # we obtain the final mask, containing only peaks, by removing the background from the local_max mask
  detected_peaks = local_max.astype(int) - eroded_background.astype(int)

  if do_plot is True:
    import pylab as pl

    pl.figure(figsize=(10,3))
    pl.subplots_adjust(wspace=0.3)
    pl.subplot(131,aspect='equal')
    pl.pcolormesh(cx_smooth)
    pl.title('Smoothed autocorrelation')
    pl.subplot(132,aspect='equal')
    pl.pcolormesh(local_max)
    pl.title('Local maxima')
    pl.subplot(133,aspect='equal')
    pl.pcolormesh(background)
    pl.title('Background')

  return detected_peaks

def detect_six_closest_peaks(cx,doPlot=False):
  """
  Detects the six peaks closest to the center of the autocorrelogram
  cx: autocorrelogram matrix (square)
  doPlot: if True plots the autocorrelogram with the six peaks and the center

  Returns the coordinates of the six peaks (2 x 6 array, sorted by increasing
  distance from the center) and the coordinates of the center (2 x 1 array).
  If fewer than six suitable peaks are found, a warning is printed and the
  peak coordinates are all zeros.
  """

  # indexes to cut a circle in the auto-correlation matrix
  SX,SY = np.meshgrid(range(cx.shape[0]),range(cx.shape[1]))

  # central pixel: integer division gives (s-1)/2 for an odd side s and s/2 for an even one
  tile_center = np.array([[cx.shape[0]//2, cx.shape[1]//2]]).T

  peaks=detect_peaks(cx)
  peaks_xy = np.array([SX[peaks==1],SY[peaks==1]])

  if peaks_xy.shape[1]<6:
    print('Warning: less than 6 peaks found!!!')
    return np.zeros((2,6)),tile_center

  # sort the peaks by distance from the center
  peaks_dist = np.sqrt(np.sum((peaks_xy-tile_center)**2,axis=0))
  sort_idxs = np.argsort(peaks_dist)
  peaks_dist=peaks_dist[sort_idxs]
  peaks_xy=peaks_xy[:,sort_idxs]

  # filter out center and peaks too close to the center (duplicates)
  to_retain_idxs=peaks_dist>2

  if np.count_nonzero(to_retain_idxs)<6:
    print('Warning: less than 6 peaks to retain!!!')
    return np.zeros((2,6)),tile_center

  peaks_dist=peaks_dist[to_retain_idxs]
  peaks_xy=peaks_xy[:,to_retain_idxs]
  idxs=np.arange(6)

  if doPlot is True:
    import pylab as pl
    pl.figure()
    pl.subplot(111,aspect='equal')
    pl.pcolormesh(cx)
    pl.scatter(peaks_xy[0,idxs]+0.5,peaks_xy[1,idxs]+0.5)
    pl.scatter(tile_center[0]+0.5,tile_center[1]+0.5)

  return peaks_xy[:,idxs],tile_center

    

def comp_psi_scores(L,nx,r_maps):
  """
  Similar to Simon's score for firing rate maps

  L: side-length of the environment (currently not used)
  nx: number of space samples per side
  r_maps: firing rate maps, nx**2 x num_cells (or a single flattened map)

  For each map, the angles of the six autocorrelogram peaks closest to the
  center are computed, together with psi_M=|mean(exp(1j*M*angles))| for
  M=2,...,6. The score of the map is psi_6 if this is the largest psi_M and 0
  otherwise. Returns an array with one score per map.
  """

  if len(r_maps.shape)==1:
    r_maps=r_maps[:,np.newaxis]

  assert(r_maps.shape[0]==nx**2)

  all_psi=[]
  for cell_idx in range(r_maps.shape[1]):

    r_map=r_maps[:,cell_idx]
    cx=norm_autocorr(r_map.reshape(nx,nx))

    peaks,center = detect_six_closest_peaks(cx,False)        # get six closest peaks

    cent_peaks=peaks-center                                  # center them
    peak_dists = np.sqrt(np.sum(cent_peaks**2,axis=0))       # peak distances
    norm_peaks = cent_peaks/peak_dists                       # normalize to unit norm

    angles = np.arccos(norm_peaks[0,:])                      # calculate angle

    # order parameter for each candidate symmetry M
    psi_M=np.array([np.abs(np.mean(np.exp(1j*M*angles))) for M in (2,3,4,5,6)])

    # the score is non-zero only if the six-fold symmetry (last entry) dominates
    if np.argmax(psi_M)==4:
      psi=psi_M[4]
    else:
      psi=0.

    all_psi.append(psi)

  return np.array(all_psi)
        

  
def get_grid_phase(x,x0,ds,doPlot=False,use_crosscorr=True):
  """
  Return the grid phase relative to a given reference grid
  x: grid for which the phase has to be estimated
  x0: reference grid
  ds: space discretization step
  doPlot: if True plots the map used for the estimation and the selected peak
  use_crosscorr: if True the phase is estimated from the central part of the
                 cross-correlogram between x and x0, otherwise from x itself

  Returns the position (x,y) of the detected peak that is closest to the
  origin, or NaN if no peak is detected.
  """
  if use_crosscorr is True:
    cx=norm_crosscorr(x,x0,type='full')
    # cutout the central part (integer division: the bounds are indexes)
    n=(cx.shape[0]+1)//2
    half=n//2
    cx=cx[n-1-half:n-1+half,n-1-half:n-1+half]
  else:
    cx=x
    n=cx.shape[0]

  L=n*ds
  SX,SY,tiles=get_tiles(L,ds)

  peaks=detect_peaks(cx)

  if peaks.sum()>0:
    peaks_xy = np.array([SX[peaks==1],SY[peaks==1]])
    peaks_dist = np.sum(peaks_xy**2,axis=0)

    # peak closest to the origin
    idxs = np.argmin(peaks_dist)

    if doPlot is True:
      import pylab as pl
      pl.pcolormesh(SX,SY,cx)
      pl.scatter(peaks_xy[0,idxs],peaks_xy[1,idxs])
      pl.plot(0,0,'.g')

    phase = peaks_xy[:,idxs]
  else:
    phase=np.nan

  return phase


def get_grid_spacing_and_orientation(cx,ds,doPlot=False,compute_angle=True,ax=None):
  """
  Returns the grid orientation and the grid spacing given the autocorrelogram
  cx: autocorrelogram matrix
  ds: space discretization step
  doPlot: if True (and an orientation is found) plots the autocorrelogram with
          the six peaks and the orientation line
  compute_angle: if False only the spacing is computed and the angle is NaN
  ax: if None a new figure is created for the plot
  :returns: the angle in radians (NaN if it is not computed or if there is no
            peak in the first quadrant) and the spacing (mean distance of the
            six peaks from the center multiplied by ds)
  """

  peaks,center = detect_six_closest_peaks(cx)             # get six closest peaks

  cent_peaks=peaks-center                                 # center them

  peak_dists = np.sqrt(np.sum(cent_peaks**2,axis=0))      # peak distances
  norm_peaks = cent_peaks/peak_dists                      # normalize to unit norm

  # indexes of the peaks in the first quadrant x>0 and y>0
  norm_peak_1quad_idxs=np.bitwise_and(norm_peaks[1,:]>0,norm_peaks[0,:]>0)
  spacing=np.mean(peak_dists)*ds

  angle=np.nan

  if compute_angle is True:
    # if we have at least one peak in the first quadrant
    if np.any(norm_peak_1quad_idxs):
      norm_peaks_1quad=norm_peaks[:,norm_peak_1quad_idxs]                   # normalized coordinates of the peaks in the first quadrant
      norm_orientation_peak_idx=np.argmin(norm_peaks_1quad[1,:])            # index of the peak with minimum y
      norm_orientation_peak=norm_peaks_1quad[:,norm_orientation_peak_idx]   # normalized coordinates of the peak with minimum y

      peaks_1quad = peaks[:,norm_peak_1quad_idxs]                           # coordinates of the peaks in the first quadrant
      orientation_peak=peaks_1quad[:,norm_orientation_peak_idx]             # coordinates of the peak with minimum y

      # the arccos is never negative, so the angle needs no further correction
      angle = np.arccos(norm_orientation_peak[0])

      if doPlot is True:
        import pylab as pl
        if ax is None:
          pl.figure()
          pl.subplot(111,aspect='equal')

        # the center is used as an index and must be an integer
        c0,c1=int(center[0,0]),int(center[1,0])

        pl.pcolormesh(cx/(cx[c0,c1]),vmax=1.,cmap='binary',rasterized=True)
        pl.plot([c0+.5,orientation_peak[0]+.5],[c1+.5,orientation_peak[1]+.5],'-y',linewidth=2)

        for i in range(6):
          pl.scatter(peaks[0,i]+0.5,peaks[1,i]+0.5,c='r')

        pl.scatter(c0+0.5,c1+0.5,c='r')
        hlen=cx.shape[0]/3.
        pl.xlim([c0-hlen,c0+hlen])
        pl.ylim([c1-hlen,c1+hlen])

  return angle,spacing


def fr_fun(h,gain=.1,th=0,sat=1,type='arctan'):
  """
  Firing rate function
  h: input
  gain: gain
  th: threshold
  sat: saturation level
  type: 'arctan' (threshold-saturation), 'sigmoid', 'rectified'
        (threshold-linear) or 'linear' (identity)

  For any other type nothing is computed and None is returned.
  """
  if type=='linear':
    return h

  if type=='rectified':
    # sign(h-th)+1 is zero below threshold and two above
    step_x2=np.sign(h-th) + 1
    return h*0.5*step_x2

  if type=='sigmoid':
    exponent=-gain*(h-th)
    return sat*1/(1+np.exp(exponent))

  if type=='arctan':
    drive=h-th
    step_x2=np.sign(drive) + 1
    return sat*2/np.pi*np.arctan(gain*drive)*0.5*step_x2

  return None



def pf_fun(pos,center=np.array([0,0]),sigma=0.05,amp=1):
  """
  Gaussian place-field input function
  pos: position (one vector, or one position per column)
  center: center of the place field (one vector, or one center per column)
  sigma: place field width
  amp: maximal amplitude
  """

  pos_dims=len(pos.shape)
  center_dims=len(center.shape)

  if pos_dims>1 and center_dims==1:
    # multiple positions one center: the center becomes a column vector
    center=np.array([center]).T
  elif pos_dims==1 and center_dims>1:
    # one position multiple centers: the position becomes a column vector
    pos=np.array([pos]).T

  sq_dist=sum((pos-center)**2,0)
  return np.exp(-sq_dist/(2*sigma**2))*amp



def simple_grid_fun(pos,grid_T,angle=0,phase=[0, 0],waves=[0,1,2]):
  """
  Another function for a grid with a simpler mathematical description,
  i.e., the sum of plane waves cos(k*(pos+phase)) with wave vectors k of
  norm 4*pi/(sqrt(3)*grid_T) and directions pi*i/3+angle for i in waves

  pos: position (one vector, or one position per column)
  grid_T: grid spacing
  angle: grid orientation
  phase: grid phase (x,y)
  waves: indexes of the plane waves to sum

  An AssertionError is raised if grid_T, angle or phase contain NaNs.
  The default lists are never modified.
  """
  # NaNs are tested by value (an identity test would miss computed NaNs)
  assert not (   np.any(np.isnan(grid_T))
              or np.any(np.isnan(angle))
              or np.any(np.isnan(phase)))

  alpha=np.array([np.pi*i/3+angle for i in waves])
  k=4*np.pi/(np.sqrt(3)*grid_T)*np.array([np.cos(alpha),np.sin(alpha)]).T

  # multiple positions: the phase becomes a column vector
  if len(pos.shape)>1:
    phase = np.array([phase]).T

  # sum over the plane waves
  return np.sum(np.cos(np.dot(k,pos+phase)),axis=0)
  

def norm_crosscorr(x,y,type='full',pearson=True):
  """
  Normalized cross-correlogram between the matrices x and y

  x,y: two-dimensional matrices
  type: mode passed to fftconvolve ('full', 'valid' or 'same')
  pearson: if True the result is the Pearson's correlation coefficient,
           computed with the (global) means and standard deviations of x and y

  The raw cross-correlation is divided by the amount of overlap, which is
  given by the convolution of two matrices of ones with the shapes of x and y.
  """
  # number of overlapping elements at each lag
  overlap=fftconvolve(np.ones(x.shape),np.ones(y.shape),type)

  # mean product of the overlapping elements at each lag
  cx=np.divide(fftconvolve(np.flipud(np.fliplr(x)),y,type),overlap)

  if pearson is True:
    # both signals enter the normalization (this reduces to
    # (cx-mean**2)/var when x and y have the same mean and variance)
    return (cx-x.mean()*y.mean())/(x.std()*y.std())
  else:
    return cx
    
def norm_autocorr(x,type='full',pearson=True):
  """
  Normalized autocorrelation of the mean-subtracted matrix x. The raw autocorrelation
  is divided by the amount of overlap, which is given by the autoconvolution of a matrix of ones.
  If pearson is True the result is also divided by the variance of x
  """
  centered=x-x.mean()

  # number of overlapping elements at each lag
  ones=np.ones(centered.shape)
  overlap=fftconvolve(ones,np.ones(centered.shape),type)

  # correlation = convolution with the matrix flipped along both axes
  flipped=np.flipud(np.fliplr(centered))
  cx=np.divide(fftconvolve(flipped,centered,type),overlap)

  if pearson is True:
    return cx/x.var()
  return cx

def comp_score(cx,idxs,min_diff=False):
  """
  Gridness score of an autocorrelation pattern restricted to a given set of elements.
  For the final gridness score the elements shall be outside an inner radius around the central peak
  and inside an outer radius containing the six closest peaks
  cx: autocorrelogram
  idxs: array of indexes (2 x num_elements) for the elements to retain
  min_diff: if True the score is the minimum correlation at 60 and 120 degrees minus the maximum
            correlation at 30, 90 and 150 degrees, otherwise the difference of the two means
  """
  rows,cols=idxs[0,:],idxs[1,:]
  reference=cx[rows,cols]      # retained elements of the original pattern

  def corr_at(angle_deg):
    # correlation between the retained elements before and after the rotation
    rotated=rotate(cx,angle_deg,reshape=False)
    return pearsonr(reference,rotated[rows,cols])[0]

  on_grid=np.array([corr_at(deg) for deg in (60, 120)])        # high for a hexagonal pattern
  off_grid=np.array([corr_at(deg) for deg in (30, 90, 150)])   # low for a hexagonal pattern

  if min_diff is True:
    return on_grid.min()-off_grid.max()
  return np.mean(on_grid)-np.mean(off_grid)

def get_score_corr_angle(cx,idxs):
  """
  Compute the Pearson's correlation between the autocorrelogram and its rotated
  version for all rotation angles from 0 to 179 degrees (in steps of one degree)
  cx: autocorrelogram
  idxs: array of indexes (2 x num_elements) for the elements to retain
  Returns the array of the angles and the array of the correlations
  """

  angles=np.arange(0,180,1)
  rows,cols=idxs[0,:],idxs[1,:]
  reference=cx[rows,cols]      # retained elements of the original pattern

  corrs=np.zeros(len(angles))
  for pos,angle_deg in enumerate(angles):
    rotated=rotate(cx,angle_deg,reshape=False)
    corrs[pos]=pearsonr(reference,rotated[rows,cols])[0]

  return angles,corrs
  



def gridness(x,ds,doPlot=False,computeAngle=False,num_steps=20,
             score_th_for_orientation=0.3,axes=None,cx=None,pearson=True,return_cx=False,min_diff=False):
  """
  Computes the gridness score of a pattern

  x: pattern (square matrix)
  ds: space discretization step
  doPlot: if True plots the pattern, the autocorrelogram and the correlation VS rotation angle
  computeAngle: if True the grid orientation is estimated, provided that the
                score is larger than score_th_for_orientation
  num_steps: number of outer radii to test
  score_th_for_orientation: minimal score required to estimate the orientation
  axes: axes made current before plotting (the plots are drawn in a new figure anyway)
  cx: precomputed autocorrelogram, if None it is computed from x
  pearson: passed to norm_autocorr when the autocorrelogram is computed
  return_cx: if True returns also the autocorrelogram
  min_diff: NOTE: currently not used, the scores are always computed with min_diff=True

  The score is computed on the ring of the autocorrelogram between an inner
  radius (the first radius at which the radial profile is negative) and an
  outer radius. Outer radii between 0.5 and 2 grid spacings are tested and
  the best score is retained.

  Returns best_score, best_outr (in pixels), angle (NaN if not estimated),
  spacing (and cx if return_cx is True).
  An IndexError is raised if the radial profile is never negative.
  """

  if cx is None:
    cx = norm_autocorr(x,pearson=pearson)                        # compute the normalized autocorrelation of the pattern

  # compute the radial profile and the inner radius of the ring
  profile=radial_profile(cx,norm=False)
  inrad=np.argwhere(profile<0)[0]

  # compute grid spacing
  angle,spacing=get_grid_spacing_and_orientation(cx,ds,doPlot=False,compute_angle=False)

  # indexes to cut a circle in the auto-correlation matrix
  SX,SY = np.meshgrid(range(cx.shape[0]),range(cx.shape[1]))
  tiles= np.array([np.ravel(SX), np.ravel(SY)])

  # central pixel: integer division gives (s-1)/2 for an odd side s and s/2 for an even one
  tile_center = np.array([[cx.shape[0]//2, cx.shape[1]//2]]).T
  tiles_dist = np.sqrt(np.sum((tiles-tile_center)**2,axis=0))

  # minimal and maximal outer radii
  max_outr=np.ceil(spacing*2/ds)
  min_outr=np.floor(spacing*0.5/ds)

  outr_ran = np.arange(min_outr,max_outr,max_outr/num_steps)    # range of outer radii for the gridness score
  best_score = -2                                               # best gridness score
  best_outr = min_outr                                          # best radius

  # loop over increasing radii and retain the best score
  for outr_idx in range(len(outr_ran)):

    # compute score for the current outer radius
    idxs=tiles[:,np.bitwise_and(tiles_dist>inrad,tiles_dist<outr_ran[outr_idx])]
    score = comp_score(cx,idxs,min_diff=True)

    # retain best score
    if score > best_score:
      best_score = score
      best_outr = outr_ran[outr_idx]

  # plot if requested
  if doPlot is True:
    import pylab as pl
    import plotlib as pp
    ax=pl.gca() if axes is None else axes
    pl.sca(ax)
    pl.figure(figsize=(8,6))

    ax_grid = pl.GridSpec(2, 2, wspace=0.4, hspace=0.3)

    # input pattern
    ax=pl.subplot(ax_grid[0,0])
    pl.axis('equal')
    pl.pcolormesh(x.T,rasterized=True)
    pp.colorbar()
    pp.noframe()

    # autocorrelogram with the best, maximal, spacing and inner radii
    ax=pl.subplot(ax_grid[0,1])
    pl.axis('equal')
    pl.pcolormesh(cx,rasterized=True)
    ax.axes.get_yaxis().set_visible(False)
    ax.axes.get_xaxis().set_visible(False)
    ax.set_frame_on(False)
    theta_ran = np.arange(0,2*np.pi,0.1)
    pl.plot(best_outr*np.cos(theta_ran)+tile_center[0],best_outr*np.sin(theta_ran)+tile_center[1],'w')
    pl.plot(max_outr*np.cos(theta_ran)+tile_center[0],max_outr*np.sin(theta_ran)+tile_center[1],'k')
    pl.plot((spacing/ds)*np.cos(theta_ran)+tile_center[0],(spacing/ds)*np.sin(theta_ran)+tile_center[1],'g')
    pl.plot(inrad*np.cos(theta_ran)+tile_center[0],inrad*np.sin(theta_ran)+tile_center[1],'w')
    pl.text(10,10,'%.2f'%best_score, color='black',fontsize=10, weight='bold',bbox={'facecolor':'white'})
    pp.colorbar()

    # correlation VS rotation angle (for the ring of the last tested radius)
    ax=pl.subplot(ax_grid[1,:])
    deg_ran,c = get_score_corr_angle(cx,idxs)
    pl.plot(deg_ran,c,'-k')
    pl.xlabel('Angle')
    pl.ylabel('Correlation')
    pp.custom_axes()
    pl.axhline(1)
    pl.ylim(-1,1)
    pl.title('max= %.2f min=%.2f'%(c.max(),c.min()))

  # calculate the angle only if we pass a threshold for the gridness
  angle=np.nan
  if computeAngle is True and best_score > score_th_for_orientation:
    angle,spacing = get_grid_spacing_and_orientation(cx,ds,compute_angle=True)

  if return_cx  is True:
    return best_score,best_outr,angle,spacing,cx
  else:
    return best_score,best_outr,angle,spacing

  
               
  
  
#def gridness_old(x,ds,doPlot=False,computeAngle=False,num_steps=20,
#             score_th_for_orientation=0.3,axes=None,cx=None,pearson=True,return_cx=False,min_diff=False):
#  """
#  Computes the gridness score of a pattern
#  x: pattern 
#  doPolt: plots the autocorrelogram and the gridness score
#  """
#  if cx is None:
#    cx = norm_autocorr(x,pearson=pearson)                                         # compute the normalized autocorrelation of the pattern
#
#
#  # compute grid spacing and orientation    
#  angle,spacing=get_grid_spacing_and_orientation(cx,ds,doPlot=False,compute_angle=False)
#
#  max_outr=np.ceil(spacing*2.5/ds)
#  min_outr=np.floor(spacing*0.7/ds)
#  outr_ran = np.arange(min_outr,max_outr,max_outr/num_steps)    # range of outer radii for the gridness score
#  best_score = -2                                               # best gridness score
#  best_outr = min_outr                                          # best radius
#
#  # indexes to cut a circle in the auto-correlation matrix
#  SX,SY = np.meshgrid(range(cx.shape[0]),range(cx.shape[1]))
#  tiles= np.array([np.ravel(SX), np.ravel(SY)])
#  if np.remainder(cx.shape[0],2)==1:
#    tile_center = np.array([[(cx.shape[0]+1)/2-1, (cx.shape[1]+1)/2-1]]).T
#  else:
#    tile_center = np.array([[(cx.shape[0])/2, (cx.shape[1])/2]]).T
#  tiles_dist = np.sqrt(sum((tiles-tile_center)**2,0))
#
#  # loop over increasing radii and retain the best score
#  for outr_idx in range(len(outr_ran)):
#
#    # compute score for the current outer radius
#    idxs=tiles[:,tiles_dist<outr_ran[outr_idx]]
#    score = comp_score(cx,idxs,min_diff=min_diff)
#
#    # retain best score
#    if score > best_score:
#      best_score = score
#      best_outr = outr_ran[outr_idx]
#
#  # take as inner radius half of the outer radius and recompute the score
#  in_r = best_outr/2
#  idxs= tiles[:,np.logical_and(tiles_dist>in_r,tiles_dist<best_outr)]
#  best_score = comp_score(cx,idxs,min_diff=min_diff)
#
#  # plot if requested
#  if doPlot is True:
#    import pylab as pl    
#    import plotlib as pp
#    #ax=pl.gca() if axes is None else axes
#    #pl.sca(ax)
#    pl.figure(figsize=(8,6))
#    
#    ax_grid = pl.GridSpec(2, 2, wspace=0.4, hspace=0.3)
#    
#    #pl.subplots_adjust(wspace=0.5)
#    
#    ax=pl.subplot(ax_grid[0,0])
#    pl.axis('equal')
#    pl.pcolormesh(x.T,rasterized=True)
#    pp.colorbar()
#    pp.noframe()
#    
#    ax=pl.subplot(ax_grid[0,1])
#    pl.axis('equal')
#    pl.pcolormesh(cx,rasterized=True)
#    ax.axes.get_yaxis().set_visible(False)
#    ax.axes.get_xaxis().set_visible(False)
#    ax.set_frame_on(False)
#    theta_ran = np.arange(0,2*np.pi,0.1)
#    pl.plot(best_outr*np.cos(theta_ran)+tile_center[0],best_outr*np.sin(theta_ran)+tile_center[1],'w')
#    pl.plot(max_outr*np.cos(theta_ran)+tile_center[0],max_outr*np.sin(theta_ran)+tile_center[1],'k')
#    pl.plot((spacing/ds)*np.cos(theta_ran)+tile_center[0],(spacing/ds)*np.sin(theta_ran)+tile_center[1],'g')
#    pl.plot(in_r*np.cos(theta_ran)+tile_center[0],in_r*np.sin(theta_ran)+tile_center[1],'w')
#    pl.text(10,10,'%.2f'%best_score, color='black',fontsize=10, weight='bold',bbox={'facecolor':'white'})
#    pp.colorbar()
#    
#    ax=pl.subplot(ax_grid[1,:])
#    deg_ran,c = get_score_corr_angle(cx,idxs)
#    pl.plot(deg_ran,c,'-k')
#    pl.xlabel('Angle')
#    pl.ylabel('Correlation')
#    pp.custom_axes()
#    pl.axhline(1)
#    pl.ylim(-1,1)
#    pl.title('max= %.2f min=%.2f'%(c.max(),c.min()))
#    
#    
#  # calculate angle if there we pass a threshold for the gridness
#  angle=np.NaN
#  if computeAngle is True and best_score > score_th_for_orientation:
#    angle,spacing = get_grid_spacing_and_orientation(cx,ds,compute_angle=True)
#  
#  if return_cx  is True:
#    return best_score,best_outr,angle,spacing,cx
#  else:
#    return best_score,best_outr,angle,spacing

def get_tiles(L,ds):
  """
  Builds a square sampling grid centred on the origin.

  L: side length of the square; both axes are sampled on [-L/2, L/2)
  ds: distance between neighbouring samples

  Returns SX, SY (the meshgrid coordinate matrices) and tiles, a
  2 x num_points array holding the flattened x and y coordinates.
  """
  axis_samples = np.arange(-L/2., L/2., ds)
  SX, SY = np.meshgrid(axis_samples, axis_samples)
  tiles = np.vstack((SX.ravel(), SY.ravel()))
  return SX, SY, tiles
  
def get_tiles_int(L,num_samp=200):
  """
  Builds a square sampling grid of side length L centred on the origin,
  specified by the number of samples rather than by the sample spacing.

  L: side length of the square
  num_samp: number of samples along each axis

  Returns SX, SY (the meshgrid coordinate matrices) and tiles, a
  2 x num_points array holding the flattened x and y coordinates.
  """
  # integer sample indexes centred on zero, then rescaled to [-L/2, L/2)
  sample_idxs = np.arange(-num_samp/2, num_samp/2)
  axis_samples = sample_idxs/float(num_samp)*L
  SX, SY = np.meshgrid(axis_samples, axis_samples)
  tiles = np.vstack((SX.ravel(), SY.ravel()))
  return SX, SY, tiles
  
def divide_triangle(parent_triangle,grid_vertices,level=0,max_level=3,prec=9):
  """
  Recursively tessellates a triangle, collecting the generated vertices.

  parent_triangle: list of the three vertices of the triangle to tessellate
  grid_vertices: output set; the vertices created by the subdivision are
                 added to it in place (the corners of the initial triangle
                 are not added)
  level: current level of the recursion
  max_level: desired level of recursive tessellation
  prec: accepted for interface compatibility but not used; midpoints are
        always rounded to 6 decimals

  The triangle is split in four child triangles by taking the midpoints of
  its edges. The central child has the three midpoints as vertices; each of
  the three sibling triangles is made of one parent corner and the two
  midpoints adjacent to it. While level < max_level the function calls
  itself on all four children. Nothing is returned.
  """

  def edge_midpoint(vertex_a, vertex_b):
    # rounded so that the same vertex reached from different triangles
    # compares equal inside the set
    return tuple(np.around(0.5*(np.array(vertex_a)+np.array(vertex_b)), 6))

  # central child triangle: midpoints of the edges (v0,v2), (v1,v0), (v2,v1)
  child_triangle = [edge_midpoint(parent_triangle[i], parent_triangle[i-1]) for i in range(3)]
  for vertex in child_triangle:
    if vertex not in grid_vertices:
      grid_vertices.add(vertex)

  # deepest level reached: nothing left to subdivide
  if not level < max_level:
    return

  sibling_triangles = [
    [parent_triangle[p-3], child_triangle[p-3], child_triangle[p-2]]
    for p in range(3)
  ]
  for new_parent_triangle in [child_triangle] + sibling_triangles:
    divide_triangle(new_parent_triangle, grid_vertices, level+1, max_level, prec)
  
 
 
def get_all_phases(freq,angle,num_phases=100):
  """
  Returns a set of phases evenly distributed within the whole phase space

  freq: grid frequency; the elementary hexagon has side sqrt(3)/(3*freq)
  angle: grid orientation in radians
  num_phases: number of phases sampled on the parallelogram before the
              shifted copies are added and the result is cropped

  Returns an array with one (x, y) phase per row containing only the phases
  that fall inside the elementary hexagon.
  """
  # function-scope import, as done for the rhombus cropping elsewhere in this
  # file. NumPy has no points_inside_poly function, so the point-in-polygon
  # test is delegated to matplotlib's Path.
  from matplotlib.path import Path

  # the elementary phase space is an hexagon
  side=np.sqrt(3)/(3*freq)
  hexagon=get_hexagon(side,angle)

  # first we get the phases uniformly spaced on a parallelogram
  axes=(0,1)
  phases=get_phases_on_pgram(freq,angle,num_phases,axes=axes)

  # then we shift these phases by +-lshift in the direction of the largest
  # diagonal of the parallelogram
  lshift=np.sqrt(3)/(6*freq)
  alpha1=angle+np.pi/6+axes[0]*np.pi/3
  alpha2=angle+np.pi/6+axes[1]*np.pi/3
  shift= lshift*(np.array([np.cos(alpha1)+np.cos(alpha2),np.sin(alpha1)+np.sin(alpha2)]))
  phases_shift1=phases+np.array([shift])
  phases_shift2=phases-np.array([shift])

  # we stack the three sets of phases obtained
  all_phases=np.vstack((phases,phases_shift1,phases_shift2))

  # we discard the phases outside the elementary phase space
  idxs=Path(hexagon).contains_points(all_phases)
  all_phases=all_phases[idxs,:]
  return all_phases
  
  
  
def get_hexagon(side,angle):
  """
  Computes the vertices of a regular hexagon centred on the origin.

  side: side length (equal to the distance of each vertex from the centre)
  angle: orientation in radians; the first vertex lies at angle + pi/6

  Returns a 7 x 2 array of (x, y) vertices. The last vertex repeats the
  first one so that plotting the rows draws a closed line.
  """
  vertex_angles = [angle+np.pi/6+k*np.pi/3 for k in range(7)]
  return np.array([[side*np.cos(a), side*np.sin(a)] for a in vertex_angles])
  
def get_rhombus(side,angle=np.pi/6):
  """
  Computes the vertices of a rhombus whose two edge directions are 60
  degrees apart (at angle and angle + pi/3), centred on the origin.

  side: length of the edges
  angle: direction of the first edge in radians

  Returns a 5 x 2 array of (x, y) vertices. The last vertex repeats the
  first one so that plotting the rows draws a closed line.
  """
  edge_a = side*np.array([np.cos(angle), np.sin(angle)])
  edge_b = side*np.array([np.cos(angle+np.pi/3), np.sin(angle+np.pi/3)])
  origin = np.zeros(2)

  # walk around the rhombus starting (and ending) at the origin corner
  corners = np.array([origin, edge_a, edge_a+edge_b, edge_b, origin])

  # move the centre of the rhombus onto the origin
  return corners - (edge_a+edge_b)/2
  
def get_simple_hexagon(side,angle):
  """
  Computes the vertices of a regular hexagon centred on the origin, like
  get_hexagon but with the first vertex exactly at the given angle (no pi/6
  offset).

  Returns a 7 x 2 array of (x, y) vertices; the last vertex repeats the first.
  """
  vertex_angles = [angle+k*np.pi/3 for k in range(7)]
  return np.array([[side*np.cos(a), side*np.sin(a)] for a in vertex_angles])
  
  
  
def get_phases_on_pgram(freq,angle,num_phases=36,axes=(0,1)):
  """
  Samples phases uniformly within a parallelogram centred on the origin.

  The parallelogram is spanned by two vectors oriented as two of the three
  grid axes (selected by `axes`, axis i being at angle + pi/6 + i*pi/3), each
  covering twice the period of the cosine waves that form the grid.

  freq: grid frequency; the cosine period is sqrt(3)/(2*freq)
  angle: grid orientation in radians
  num_phases: requested number of phases (sqrt(num_phases) samples per axis)
  axes: indexes of the two grid axes spanning the parallelogram

  Returns an array with one (x, y) phase per row.
  """
  # period of the cosines of the grid and sample spacing along each axis
  period = np.sqrt(3)/(2*freq)
  step = period/(np.sqrt(num_phases)/2)
  offsets = np.arange(-period, period, step)+step/2

  # directions of the two spanning axes
  first_dir = angle+np.pi/6+axes[0]*np.pi/3
  second_dir = angle+np.pi/6+axes[1]*np.pi/3

  # samples along the first axis (columns) and along the second axis (rows)
  x_first, y_first = np.cos(first_dir)*offsets, np.sin(first_dir)*offsets
  x_second, y_second = np.cos(second_dir)*offsets, np.sin(second_dir)*offsets

  # every combination of one point per axis
  X = x_first[np.newaxis,:]+x_second[:,np.newaxis]
  Y = y_first[np.newaxis,:]+y_second[:,np.newaxis]
  return np.array([X.ravel(), Y.ravel()]).T
    


def get_phases_on_axes(freq,angle,num_phases=60,axes=(0,1,2)):
  """
  Samples phases along straight lines through the origin, one line per
  entry of `axes` (axis i being oriented at angle + pi/6 + i*pi/3).

  Each line covers twice the period of the cosine waves of the grid. With
  the default three axes the lines are 60 degrees apart and the phases are
  chosen so that the sum of the grids with these phases is flat, i.e., all
  grids cancel out.

  freq: grid frequency; the cosine period is sqrt(3)/(2*freq)
  angle: grid orientation in radians
  num_phases: total number of phases, shared among the axes
  axes: indexes of the grid axes to sample

  Returns an array with one (x, y) phase per row, axis after axis.
  """
  # period of the cosines of the grid and sample spacing along each line
  period = np.sqrt(3)/(2*freq)
  phases_per_axis = num_phases/len(axes)
  step = period/phases_per_axis
  offsets = np.arange(-period, period, step)+step/2

  axis_dirs = [angle+np.pi/6+i*np.pi/3 for i in axes]
  x_phases = np.concatenate([np.cos(direction)*offsets for direction in axis_dirs])
  y_phases = np.concatenate([np.sin(direction)*offsets for direction in axis_dirs])

  return np.array((x_phases, y_phases)).T
  
  
################################
#### 2D GRIDS AND LATTICE ######
################################


  
def fourier_on_lattice(side,p1_rec,p2_rec,samples,signal,max_harmonic=None,return_harmonics=False):
  """
  Fourier series on a Bravais lattice whose unit vectors are 60 degrees apart
  
  inputs:
  --------
  
  side: side-length of the rhomboidal primary cell of the direct lattice
  p1_rec: primary vector of the reciprocal lattice
  p2_rec: primary vector of the reciprocal lattice
  samples: space samples in the lattice. SHAPE: num_samples X 2
  signal: signal of which the Fourier transform should be taken. SHAPE: num_samples X num_signals
  max_harmonic: maximum number of harmonic to compute (limit for faster computation).
                If None it is set to int(sqrt(num_samples)/2)
  
  return_harmonics: if True the harmonics matrices are returned too
  
  output:
  -------

  F: Fourier transform on lattice. SHAPE: (2*max_harmonic)**2 x num_signals
     e.g. if max_harmonic is 2: we have 16 x num_signals, so harmonics go from -max_harmonic to max_harmonic-1

  If return_harmonics is True the tuple (A, B, F) is returned, where A and B
  are the meshgrid matrices of the harmonic indexes along the two
  reciprocal vectors (row i of F corresponds to A.ravel()[i], B.ravel()[i]).
  """
  
  if max_harmonic is None:
    # builtin int: the np.int alias was removed in NumPy 1.24
    max_harmonic=int(np.sqrt(len(samples))/2)
    
  # projections of the samples on the two reciprocal vectors
  s1 = np.dot(samples,p1_rec)
  s2 = np.dot(samples,p2_rec)
  s12 = np.array([s1,s2])
  
  # all the pairs of harmonic indexes
  k_ran = np.arange(-max_harmonic,max_harmonic)  
  A,B = np.meshgrid(k_ran,k_ran)
  ab= np.array([np.ravel(A), np.ravel(B)]).T

  # one complex exponential per (harmonic pair, sample), summed over samples
  F= np.dot(np.exp(-1j*np.dot(ab,s12)),signal)
  
  # normalize by the area of the rhomboidal primary cell divided by the
  # number of samples
  V=side*side*np.sqrt(3)/2.
  F=F*V/len(samples)
  
  if return_harmonics:
    return A,B,F
  else:
    return F        


def inverse_fourier_on_lattice(side,p1_rec,p2_rec,samples,F):
  """
  Inverse Fourier series on a Bravais lattice whose unit vectors are 60 degrees apart
  
  input:
  -------
  
  side: side-length of the rhomboidal primary cell of the direct lattice
  p1_rec: primary vector of the reciprocal lattice
  p2_rec: primary vector of the reciprocal lattice
  samples: space samples where the signal is evaluated. SHAPE: num_samples X 2
  F: Fourier coefficients to invert. SHAPE: (2*max_harmonic)**2 X num_signals,
     with harmonics going from -max_harmonic to max_harmonic-1 along each
     reciprocal vector
  
  output:
  -------
  signal: real part of the inverse Fourier transform of F, 
          SHAPE: num_signals X num_samples
  
  """
  
  # the number of harmonics per axis is recovered from the number of rows of F
  max_harmonic = int(np.sqrt(F.shape[0])/2)
  harmonics = np.arange(-max_harmonic, max_harmonic)
  A, B = np.meshgrid(harmonics, harmonics)
  ab = np.array([np.ravel(A), np.ravel(B)]).T

  # projections of the samples on the two reciprocal vectors
  proj_1 = np.dot(samples, p1_rec)
  proj_2 = np.dot(samples, p2_rec)
  s12 = np.array([proj_1, proj_2])

  # weighted sum of the complex exponentials at every sample
  synthesis = np.dot(F.T, np.exp(1j*np.dot(ab, s12)))

  # normalize by the area of the rhomboidal primary cell
  cell_area = side*side*np.sqrt(3)/2
  return np.real(synthesis)/cell_area

def power_on_lattice(side,p1_rec,p2_rec,samples,signal,max_harmonic=None):  
  """
  Power spectrum on the lattice: the Fourier coefficients returned by
  fourier_on_lattice multiplied by their complex conjugate and divided by the
  area of the rhomboidal primary cell. The result keeps the complex dtype
  and the shape of the Fourier coefficients.
  """
  spectrum = fourier_on_lattice(side,p1_rec,p2_rec,samples,signal,max_harmonic=max_harmonic)
  cell_area = side**2*np.sqrt(3)/2
  return (spectrum*spectrum.conjugate())/cell_area

def autocorr_on_lattice(side,p1_rec,p2_rec,samples,signal):
  """
  Autocorrelation on the lattice, obtained by feeding the power spectrum of
  the signal (power_on_lattice with the default number of harmonics) to
  inverse_fourier_on_lattice.
  """
  power = power_on_lattice(side,p1_rec,p2_rec,samples,signal)
  return inverse_fourier_on_lattice(side,p1_rec,p2_rec,samples,power)


def get_mean_pw_on_phase_rhombus(n,nx,L,signal,max_harmonic=10):
  """
  Computes the power spectrum of a signal sampled on the phase rhombus of
  period 2*pi and zero orientation, averaged over axis 1 of the spectrum.

  n: number of phase samples per dimension
  nx, L: not used by this function
  signal: signal of which the Fourier transform should be taken.
          SHAPE: n**2 X num_signals
  max_harmonic: harmonics from -max_harmonic to max_harmonic-1 are computed

  Returns hran (the harmonic indexes along one axis) and pw, the real part
  of the mean power reshaped to 2*max_harmonic X 2*max_harmonic.
  """
  lattice = GridProps(n,2*np.pi,0)
  num_harmonics = max_harmonic*2

  # power for every harmonic pair, averaged across the signals
  pw_all = power_on_lattice(2*np.pi,lattice.u1_rec,lattice.u2_rec,lattice.phases,signal,max_harmonic=max_harmonic)
  pw = pw_all.mean(axis=1).real.reshape(num_harmonics,num_harmonics)

  hran = np.arange(2*max_harmonic)-max_harmonic
  return hran,pw


def get_mean_pw_on_space_rhombus(nx,L,signal,max_harmonic=10):
  """
  Mean power estimated on a triangular lattice in space: the squared
  amplitude of the spectrum returned by get_spectrum_on_space_rhombus (with
  zero grid angle) averaged over its last axis.

  Returns hran (the harmonic indexes along one axis) and the mean power.
  """
  hran,spectrum = get_spectrum_on_space_rhombus(nx,L,signal,0.,max_harmonic=max_harmonic)
  power = np.abs(spectrum)**2
  return hran,power.mean(axis=2)
  
def get_spectrum_on_space_rhombus(nx,L,signal,grid_angle,max_harmonic=10,do_plot=False,swap_xy=False):
  """
  Computes the Fourier spectrum on a lattice of firing rate patterns (signal)
  that are sampled on a square grid.

  A rhombus of side L is first cropped out of the square grid where the
  signal is sampled; the spectrum of the cropped samples is then computed
  with fourier_on_lattice.

  nx: number of space samples per dimension
  L: side length of the square arena (and of the cropped rhombus)
  signal: nx**2 values, or an nx**2 X num_signals matrix
  grid_angle: orientation of the lattice
  max_harmonic: harmonics from -max_harmonic to max_harmonic-1 are computed
  do_plot: if True the first cropped pattern and its amplitude are plotted
  swap_xy: if True the two space axes of the signal are swapped before
           the spectrum is computed

  Returns hran (the harmonic indexes along one axis) and F, the spectrum
  with SHAPE 2*max_harmonic X 2*max_harmonic X num_signals.
  """
  
  from matplotlib.path import Path
  
  # a single pattern is treated as a one-column matrix
  if len(signal.shape)==1:
    signal=signal[:,np.newaxis]
  num_signals=signal.shape[1]
    
  # positions of the space samples
  step=float(L)/nx
  X,Y=np.mgrid[-L/2.:L/2.:step,-L/2.:L/2.:step]
  pos=np.array([np.ravel(X), np.ravel(Y)]).T
  
  # keep only the samples that fall inside the rhombus
  lattice=GridProps(nx,L,grid_angle)
  inside=Path(lattice.R_T).contains_points(pos)
  pos_romb=pos[inside,:]
  
  if swap_xy is True:
    # swap x and y if needed
    signal=signal.reshape(nx,nx,num_signals).transpose((1,0,2)).reshape(nx**2,num_signals)
  
  # compute Fourier spectrum on lattice 
  num_harmonics=max_harmonic*2
  F=fourier_on_lattice(L,lattice.u1_rec,lattice.u2_rec,pos_romb,signal[inside,:],max_harmonic=max_harmonic)
  F=F.reshape(num_harmonics,num_harmonics,num_signals)
  hran=np.arange(2*max_harmonic)-max_harmonic
  
  if do_plot is True:
    import pylab as pl
    import plotlib as pp

    # amplitude of the first input
    plot_amp=np.squeeze(np.abs(F[:,:,0]))
    first_cropped=signal[inside,0]

    pl.figure(figsize=(10,5))
    pl.subplots_adjust(wspace=0.3)
    
    pl.subplot(121,aspect='equal')
    pp.plot_on_rhombus(lattice.R_T,L,0,len(first_cropped),pos_romb,first_cropped)       
    pl.title('Pattern cropped on lattice')  
    
    pl.subplot(122,aspect='equal')
    pl.pcolormesh(plot_amp)
    pl.title('Fourier amplitude')  
    pp.colorbar()  
  
  return hran,F

def get_tuning_harmonic_masks(grid_harmonic,hran):
  """
  Builds the binary mask that selects the 2D harmonics of a grid.

  grid_harmonic: index of the grid harmonic
  hran: harmonic indexes along one axis of the 2D spectrum

  Returns a boolean len(hran) X len(hran) matrix that is True at the six
  harmonic pairs (+-h,0), (0,+-h), (h,h) and (-h,-h), with h=grid_harmonic,
  wherever those pairs are present in hran.
  """
  HX,HY=np.meshgrid(hran,hran)

  along_x=(np.abs(HX)==grid_harmonic) & (HY==0)
  along_y=(np.abs(HY)==grid_harmonic) & (HX==0)
  diag_pos=(HX==grid_harmonic) & (HY==grid_harmonic)
  diag_neg=(HX==-grid_harmonic) & (HY==-grid_harmonic)

  return along_x | along_y | diag_pos | diag_neg


def get_phase_lattice(T,grid_angle):  
  """
  Computes the geometry of the phase space of a grid.

  T: side length of the phase-space rhombus (length of the unit vectors)
  grid_angle: grid orientation in radians

  Returns R_T,u1,u2,u1_rec,u2_rec where R_T are the vertices of the
  phase-space rhombus, u1 and u2 are the unit vectors of the direct lattice
  (at grid_angle + pi/6 and grid_angle + pi/2) and u1_rec and u2_rec are the
  unit vectors of the reciprocal lattice, scaled by 2*pi.
  """
  
  first_dir=np.pi/6+grid_angle
  second_dir=grid_angle+np.pi/2

  # phase space
  R_T=get_rhombus(T,first_dir)
  
  # unit vectors of the direct lattice
  u1 = T*np.array([np.cos(first_dir), np.sin(first_dir)])
  u2 = T*np.array([np.cos(second_dir), np.sin(second_dir)])
  
  # unit vectors of the reciprocal lattice, note the scaling by 2pi
  direct_basis = np.vstack([u1, u2]).T
  reciprocal_basis = 2*np.pi*(np.linalg.inv(direct_basis)).T

  return R_T,u1,u2,reciprocal_basis[:,0],reciprocal_basis[:,1]

def get_phase_samples(n,u1,u2):
  """
  Samples phases on the lattice cell spanned by the unit vectors u1 and u2.

  n: number of samples along each unit vector
  u1, u2: the two unit vectors of the lattice

  Each phase is a*u1 + b*u2 with a and b taken from arange(-n/2, n/2)/n.
  Returns an n**2 X 2 array with one (x, y) phase per row; a changes
  fastest along the rows.
  """

  # fractional positions along each unit vector
  fractions = np.arange(-n/2.,n/2.)/n
  along_u1 = np.array([u1])*fractions[:,np.newaxis]
  along_u2 = np.array([u2])*fractions[:,np.newaxis]

  # all the combinations: u1 samples along the columns, u2 along the rows
  X = along_u1[np.newaxis,:,0]+along_u2[:,np.newaxis,0]
  Y = along_u1[np.newaxis,:,1]+along_u2[:,np.newaxis,1]

  return np.array([X.ravel(), Y.ravel()]).T

    
def get_space_samples(nx,L):
  """
  Samples positions on a square lattice of side-length L centred on the origin.

  nx: number of samples per dimension
  L: side length of the square

  Returns an nx**2 X 2 array with one (x, y) position per row; the y
  coordinate changes fastest along the rows.
  """
  
  # fractional positions along each side
  fractions = np.arange(-nx/2.,nx/2.)/nx
  
  # orthogonal unit vectors for space
  vertical=(L)*np.array([0,1])
  horizontal=(L)*np.array([1,0])
  
  along_vertical = np.array([vertical])*fractions[:,np.newaxis]
  along_horizontal = np.array([horizontal])*fractions[:,np.newaxis]
  
  # all the combinations: vertical samples along the columns, horizontal along the rows
  X = along_vertical[np.newaxis,:,0]+along_horizontal[:,np.newaxis,0]
  Y = along_vertical[np.newaxis,:,1]+along_horizontal[:,np.newaxis,1]
  
  return np.array([X.ravel(), Y.ravel()]).T



def get_square_signal(N,NX,pos,phases,T):
  """
  Computes square grids as the sum of two orthogonal cosine waves.

  N, NX: ignored; the sizes are always taken from phases and pos
  pos: positions where the grids are evaluated. SHAPE: num_pos X 2
  phases: spatial phase (shift) of each grid. SHAPE: num_phases X 2
  T: period of the cosine waves

  Returns a num_phases X num_pos matrix with one grid per row.
  """
  
  N=len(phases)
  NX=len(pos)

  # wave vectors at 0 and 90 degrees
  wave_angles=np.array([np.pi/2*i for i in np.arange(2)])
  k=2*np.pi/T*np.array([np.cos(wave_angles),np.sin(wave_angles)]).T
  
  # position plus phase for every (grid, position) pair
  shifted_x = pos[np.newaxis,:,0]+phases[:,np.newaxis,0]
  shifted_y = pos[np.newaxis,:,1]+phases[:,np.newaxis,1]
  
  g=np.zeros((N,NX))
  for kx,ky in k:
    g+=np.cos(kx*shifted_x+ky*shifted_y)  
  return g  


def get_grid_signal(N,NX,pos,phases,T,grid_angle):
  """
  Computes a triangular grid as the sum of three waves

  N, NX: ignored; the sizes are always taken from phases and pos
  pos: positions where the grids are evaluated. SHAPE: num_pos X 2
  phases: spatial phase (shift) of each grid. SHAPE: num_phases X 2
  T: grid period; the three cosine waves have period T*sqrt(3)/2
  grid_angle: direction of the first wave vector in radians; the other
              two are rotated by 60 and 120 degrees

  Returns a num_phases X num_pos matrix with one grid per row.
  """
  
  N=len(phases)
  NX=len(pos)

  # multiply before halving: T/2 would be a floor division for an integer T
  # under Python 2 (e.g. T=1 would give a zero period)
  T_cos = T*np.sqrt(3)/2 
  
  # wave vectors 60 degrees apart
  angles=np.array([np.pi*i/3+grid_angle for i in np.arange(3)])
  k=2*np.pi/T_cos*np.array([np.cos(angles),np.sin(angles)]).T
  
  pos_x = pos[:,0]
  pos_y = pos[:,1]
  
  phases_x = phases[:,0]
  phases_y = phases[:,1]
  
  # position plus phase for every (grid, position) pair
  pp_x = pos_x[np.newaxis,:]+phases_x[:,np.newaxis]
  pp_y = pos_y[np.newaxis,:]+phases_y[:,np.newaxis]
  
  g=np.zeros((N,NX))
  
  for i in range(3):
    g+=np.cos(k[i,0]*pp_x+k[i,1]*pp_y)  
  return g  

 
  
def clipped_zoom(img, zoom_factor, **kwargs):
  """
  Zooms into (or out of) the centre of an image while keeping the number
  of pixels constant.

  img: image array; the first two dimensions are height and width
  zoom_factor: values < 1 shrink the image and pad the borders with zeros,
               values > 1 enlarge the central region of the image
  kwargs: forwarded to scipy.ndimage.zoom

  Returns the zoomed image; img itself is returned when neither branch
  applies (zoom_factor equal to 1).
  """

  import warnings
  from scipy.ndimage import zoom

  n_rows, n_cols = img.shape[:2]

  # the zoom is applied to height and width only: any trailing dimension
  # (e.g. the colour channels) gets a factor of 1
  per_axis_zoom = (zoom_factor,) * 2 + (1,) * (img.ndim - 2)

  # Zooming out: the shrunk image is pasted in the centre of a zero canvas
  if zoom_factor < 1:
    small_rows = int(np.round(n_rows * zoom_factor))
    small_cols = int(np.round(n_cols * zoom_factor))
    first_row = (n_rows - small_rows) // 2
    first_col = (n_cols - small_cols) // 2

    canvas = np.zeros_like(img)
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")
      canvas[first_row:first_row+small_rows, first_col:first_col+small_cols] = zoom(img, per_axis_zoom, **kwargs)
    return canvas

  # Zooming in: only the central region that fills the output is enlarged
  if zoom_factor > 1:
    crop_rows = int(np.ceil(n_rows / zoom_factor))
    crop_cols = int(np.ceil(n_cols / zoom_factor))
    first_row = (n_rows - crop_rows) // 2
    first_col = (n_cols - crop_cols) // 2

    with warnings.catch_warnings():
      warnings.simplefilter("ignore")
      zoomed = zoom(img[first_row:first_row+crop_rows, first_col:first_col+crop_cols], per_axis_zoom, **kwargs)

    # the result might still be slightly larger than img due to rounding,
    # so any extra pixel at the edges is trimmed off
    trim_rows = ((zoomed.shape[0] - n_rows) // 2)
    trim_cols = ((zoomed.shape[1] - n_cols) // 2)
    return zoomed[trim_rows:trim_rows+n_rows, trim_cols:trim_cols+n_cols]

  # no zoom requested
  return img
    

def normalize_grid(L,nx,r_map,target_harmonic,est_angle,est_spacing,do_plot=False,verbose=False):
  """
  Normalize grid to fixed spacing and orientation in order to compute the grid tuning index.
  We normalize the grid to a spacing of L/target_harmonic where L is the arena length and target_harmonic is an integer
  The default orientation is zero angle (aligned to vertical border)
  
  L: side length of the arena
  nx: number of samples per dimension
  r_map: one firing rate map in the arena (nx**2 values)
  target_harmonic: an integer (> 1); we normalize to a spacing of L/target_harmonic 
  est_angle: estimated grid orientation in radians
  est_spacing: estimated grid spacing, in the same units as L
  do_plot: if True the original, rescaled and rotated maps are plotted
  verbose: if True the zoom factor and the rotation angle are printed

  Returns the rescaled and rotated firing rate map as an nx X nx matrix.
  """
  
  assert(type(target_harmonic)==int and target_harmonic>1 )
  
  # make sure the data is 2D
  r_map=r_map.reshape(nx,nx)
  
  # zoom factor that brings the estimated spacing to L/target_harmonic
  zoom_factor=(float(L)/target_harmonic)/est_spacing

  # rotation in degrees, folded to (-30, 30] by using the 60-degree
  # symmetry of the grid
  rotation_angle=np.remainder(est_angle*180/np.pi,60)
  if rotation_angle>30.:
    rotation_angle=rotation_angle-60.
  
  if verbose:
    # single-argument call form: prints the same text under Python 2 and 3
    print('zoom: %.2f rotation: %.2f '%(zoom_factor, rotation_angle))

  # rescale and rotate to normalized space  
  r_map_rescaled=clipped_zoom(r_map,zoom_factor) 
  
  assert np.all(r_map_rescaled.shape==(nx,nx))
  r_map_norm=rotate(r_map_rescaled, rotation_angle,reshape=False)
  
  # plotting  
  if do_plot is True:
    import pylab as pl  
    import plotlib as pp
    
    pl.figure(figsize=(12,5))
    pl.subplots_adjust(wspace=0.3)

    pl.subplot(131,aspect='equal')
    pl.pcolormesh(r_map.T)  
    pp.colorbar()
    pl.title('Original')
    
    pl.subplot(132,aspect='equal')
    pl.pcolormesh(r_map_rescaled.T)    
    pp.colorbar()               
    pl.title('Rescaled')
    
    pl.subplot(133,aspect='equal')
    pl.pcolormesh(r_map_norm.T)
    pp.colorbar()
    pl.title('Rescaled and Rotated')
    
  return r_map_norm  




def comp_grid_tuning_index(L,nx,r_maps,do_plot=False,verbose=False,warnings=False,return_tuning_harmonics=False):
  """
  Compute grid tuning index for a batch of firing rate maps

  Each map is normalized to a spacing of L/target_harmonic and to zero
  orientation, where target_harmonic is round(L/estimated spacing). The
  index is the mean Fourier amplitude at the six harmonics selected by
  get_tuning_harmonic_masks divided by the amplitude at DC. The index is
  left at 0 for the maps whose spacing cannot be estimated or whose target
  harmonic is smaller than 2.

  L: side length of the arena
  nx: number of samples per dimension
  r_maps: nx**2 X N firing rate maps (non negative); a single map with
          nx**2 values is accepted too
  do_plot: if True diagnostic plots are produced (per-cell plots only for
           the first cell)
  verbose: if True progress information is printed
  warnings: if True a message is printed when the angle or the spacing
            cannot be estimated or the target harmonic is too small
  return_tuning_harmonics: if True the target harmonics are returned too

  Returns the grid tuning indexes (one per map) or, if return_tuning_harmonics
  is True, the tuple (target_harmonics, grid_tuning_indexes); the target
  harmonic is NaN for the maps that could not be normalized.
  """
  
  assert(r_maps.min()>=0)
  
  max_harmonic=10
  
  if len(r_maps.shape)==1:
    r_maps=r_maps[:,np.newaxis]

  num_cells=r_maps.shape[1]
  assert(r_maps.shape[0]==nx**2)
    
  dx=float(L)/nx

  r_maps_norm=np.zeros_like(r_maps)  
  target_harmonics=np.zeros(num_cells)
  
  for cell_idx in range(num_cells):
    if verbose:
      print('Normalizing %d/%d'%(cell_idx,num_cells))
    
    # per-cell plots are produced for the first cell only
    plot_curr_cell=do_plot & (cell_idx==0)
    
    r_map=r_maps[:,cell_idx]

    # compute autocorrelation and estimate grid spacing and orientation  
    cx=norm_autocorr(r_map.reshape(nx,nx))
    est_angle,est_spacing=get_grid_spacing_and_orientation(cx,dx,doPlot=plot_curr_cell)
    if np.isnan(est_angle):
      est_angle=0.
      if warnings:
        print('Cannot estimate angle for cell_idx=%d, pattern will not be rotated'%cell_idx)
    
    if  not np.isnan(est_spacing):
      target_harmonic=int(round(L/est_spacing))
      
      if verbose:
        print('target_harmonic=%d for  cell_idx=%d'%(target_harmonic,cell_idx))
      
      if target_harmonic>1:
        # normalize firing rate map to scale L/target_harmonic and angle zero
        r_map_norm = normalize_grid(L,nx,r_map,target_harmonic,est_angle,est_spacing,do_plot=plot_curr_cell,verbose=verbose)
      else:
        if warnings:
          print('WARNING: Target harmonic smaller than 2 for cell_idx=%d'%cell_idx)
        target_harmonic=np.nan
        r_map_norm=np.zeros(nx**2)
        
    else:
      if warnings:
        print('WARNING: Cannot estimate spacing for cell_idx=%d, setting grid-tuning index to 0'%cell_idx)
      target_harmonic=np.nan
      r_map_norm=np.zeros(nx**2)

    # save normalized map and target harmonic for this pattern
    r_maps_norm[:,cell_idx]=r_map_norm.reshape(nx**2)
    target_harmonics[cell_idx]=target_harmonic
    
    
  # compute the Fourier amplitude of all the patterns on space rhombus
  hran,F=get_spectrum_on_space_rhombus(nx,L,r_maps_norm,0.,max_harmonic,do_plot=do_plot)
  F_amp=np.abs(F)
  
  grid_tuning_indexes=np.zeros(num_cells)
  
  for cell_idx in range(num_cells):
    if not np.isnan(target_harmonics[cell_idx]):
      # compute mean amplitude at the target harmonic and at DC
      tuning_mask=get_tuning_harmonic_masks(target_harmonics[cell_idx],hran)
      tuning_amp=F_amp[tuning_mask,cell_idx].mean(axis=0)    
      dc_value=F_amp[max_harmonic,max_harmonic,cell_idx] #the DC is at position max_harmonic 
  
      # compute the tuning index
      grid_tuning_indexes[cell_idx]=tuning_amp/dc_value
  
  assert np.all(np.isfinite(grid_tuning_indexes))
  
  if return_tuning_harmonics:
    return target_harmonics,grid_tuning_indexes
  else:
    return grid_tuning_indexes
  

  
#################
#### TESTING ####
#################

def get_test_grids_spacing_range(L,nx,num_grids,spacing_min,spacing_max,zero_phases=True,grid_angle=0):
  """
  Generates test grids with spacings evenly distributed between
  spacing_min and spacing_max (both included) and zero phase.

  L: side length of the arena
  nx: number of samples per dimension
  num_grids: number of grids to generate
  zero_phases: not used by this function
  grid_angle: orientation passed to simple_grid_fun for all the grids

  Returns grids (nx X nx X num_grids) and the array of spacings.
  """

  spacings=np.linspace(spacing_min,spacing_max,num_grids)

  # normalized phases, rescaled in place by the phase range of each grid
  # (they are all zero, so the rescaled phases are zero too)
  norm_phases = np.zeros((2,num_grids)).T

  tiles=get_tiles(L,float(L)/nx)[2]
  grids=np.zeros((nx,nx,num_grids))

  for grid_idx,grid_spacing in enumerate(spacings):
    phase_range_x = 2.0*grid_spacing
    phase_range_y = grid_spacing*np.sqrt(3)

    grid_phase=norm_phases[grid_idx,:]
    grid_phase[0]*=phase_range_x
    grid_phase[1]*=phase_range_y

    pattern=simple_grid_fun(tiles,grid_spacing,angle=grid_angle,phase=grid_phase)
    grids[:,:,grid_idx]=pattern.reshape(nx,nx)

  return grids,spacings



def get_test_grids_all_angles(L,nx,num_grids,grid_spacing,zero_phases=True):
  """
  Generates test grids with a fixed spacing, random phases and orientations
  evenly distributed in [0, pi/3).

  L: side length of the arena
  nx: number of samples per dimension
  num_grids: number of grids to generate
  grid_spacing: spacing of all the grids
  zero_phases: not used by this function

  Returns grids (nx X nx X num_grids) and the array of orientations.
  """

  ang_range=np.arange(0,np.pi/3,np.pi/3/num_grids)

  # random phases: x in [0, 2*spacing), y in [0, sqrt(3)*spacing)
  phase_range_x = 2.0*grid_spacing
  phase_range_y = grid_spacing*np.sqrt(3)
  phases = np.zeros((2,num_grids)).T
  phases[:,0] = rand(num_grids)*phase_range_x
  phases[:,1] = rand(num_grids)*phase_range_y

  tiles=get_tiles(L,float(L)/nx)[2]
  grids=np.zeros((nx,nx,num_grids))

  for grid_idx,grid_angle in enumerate(ang_range):
    pattern=simple_grid_fun(tiles,grid_spacing,angle=grid_angle,phase=phases[grid_idx,:])
    grids[:,:,grid_idx]=pattern.reshape(nx,nx)

  return grids,ang_range  
    
  
def test_orientation_and_spacing_detection():
  """
  A function to test grid orientation detection

  Generates 25 test grids with spacing 0.5 and orientations spread over
  [0, pi/3) in an arena of side 2 sampled with 100 points per dimension.
  For each grid the orientation and the spacing are estimated from the
  normalized autocorrelogram and the real and estimated values are written
  on the corresponding panel of a 5 x 5 figure.
  """
  import pylab as pl  

  # arena size, resolution and parameters of the test grids
  L=2
  nx=100
  num_grids=25
  grid_spacing=.5
  
  grids,angles=get_test_grids_all_angles(L,nx,num_grids,grid_spacing,zero_phases=True)

  
  pl.figure(figsize=(10,10))

  for idx in xrange(num_grids):
    # one panel per grid, without axes and frame
    ax=pl.subplot(5,5,idx+1,aspect='equal')
    ax.axes.get_yaxis().set_visible(False)
    ax.axes.get_xaxis().set_visible(False)
    ax.set_frame_on(False)
    grid=grids[:,:,idx]
    cx=norm_autocorr(grid)
    pl.axis('equal')    
    # the estimation routine is asked to plot in the current panel
    est_angle,est_spacing=get_grid_spacing_and_orientation(cx,float(L)/nx,doPlot=True,ax=pl.gca())

    # real angle in degrees; the estimate is converted to degrees, shifted
    # by -30 and folded into [0, 60) before being displayed
    ang_deg=angles[idx]*360/(2*np.pi)
    est_angle_deg=np.remainder((est_angle*360/(2*np.pi))-30,60)
    pl.text(20,20,'Real Ang.: %.2f\nEst Ang.: %.2f'%(ang_deg,est_angle_deg),fontsize=9,color='k',weight='bold',bbox={'facecolor':'white','edgecolor':'white'})
    pl.text(20,80,'Real Sp.: %.2f\nEst Sp.: %.2f'%(grid_spacing,est_spacing),fontsize=9,color='k',weight='bold',bbox={'facecolor':'white'})
  pl.subplots_adjust(hspace=0.1,wspace=0.1,left=0.05,right=0.95,top=0.95,bottom=0.05)