'''
Function:     sim_vs_net

Arguments:    param_dict - Parameter dictionary, expected variables and definitions below. See run_time_vars at end of script for
                           example parameter values (sim_vs_net itself does not supply defaults).
              movie (optional) - The movie to supply as input to the system, an array of dimension
                                 (y_pixels,x_pixels,frames) whose int(T/dt_im)+1 frames span [0,T]. If a movie is not
                                 provided, one will be generated using the subdictionary param_dict['image']
                                 
              Expected contents of the dictionary param_dict:
              
              g_gap - Axo-axonal gap junction coupling strength. (nS)
              g_inh_scale - Scaling factor for the negative conductance electrical coupling between VS1 and VS10. The
                            strength of the connection is set to g_gap*g_inh_scale. (no units)
              g_L_dend - Dendritic leak conductance. (nS)
              g_L_axon - Axonal leak conductance. (nS)
              g_axon_dend - Axo-dendritic electric coupling strength. (nS)
              g_exc_in - Excitatory input conductance strength. (nS)
              g_inh_in - Inhibitory input conductance strength. (nS)
              E_E - Excitatory input reversal potential. (mV)
              E_I - Inhibitory input reversal potential. (mV)
              tau_m - Axonal and dendritic membrane time constant. (ms)
              noise_std - Intensity of intrinsic white noise fluctuations in axonal and dendritic compartments. (pA) 
              tau_hp - Time constant of the high-pass filter component of the Reichardt detectors. (ms)
              tau_lp - Time constant of the low-pass filter component of the Reichardt detectors. (ms)
              rot_angle - Horizontal rotation axis. Rotation about this axis induces apparent clockwise rotation of the
                          external world to the fly (i.e., it corresponds to a counter-clockwise roll of the fly about
                          this axis). (deg)
                          For the stripe stimuli, rot_angle is the azimuthal center of the stripe stimulus. (deg)
              deg_per_ms - Speed of the clockwise rotation about the horizontal axis at the angle rot_angle. (deg/ms)
                           For the stripe stimulus, this is the speed of downward motion. (deg/ms)
              T - Total simulation duration. (ms)
              dt_sim - Time step of integration for the membrane potentials. (ms)
              dt_im - Time step used for generating the movie. The induced conductances are then linearly interpolated
                      to the time step dt_sim for integration. (ms)
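              image - (required if no movie is supplied) Image parameter dictionary, with parameters
                          type - Image type, determines which image generation module is called. One of 'white',
                                 'bar', 'natural' or 'stripe'.
                          x_pixels - Number of pixels along the horizontal direction
                          y_pixels - Number of pixels along the vertical direction
                      plus the type-specific keys read by the corresponding generator: 'upsample' ('white'),
                      'image_lib_path' ('natural'), 'strip_rad' and 'freq_s' ('stripe'). For 'bar', the whole
                      dictionary is passed to the bar image generator.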
              
           
Output:       V_axon_L - Time course of the left-side VS axonal responses at times in the interval [0,T-dt_sim] with 
                         time step dt_sim. Dimension is (10,T/dt_sim), and V_axon_L[i,:] is the response of the neuron
                         VS(i+1). (mV)
              V_axon_R - Same as V_axon_L, but for right-side VS axonal responses.
           
Description:  Simulate the blow fly VS network. The method of simulation is adapted from that of Borst and Weber (2011)
              "Neural action fields for optic flow based navigation: a simulation study of the fly lobula plate 
              network...". Includes only the portion corresponding to the left- and right-side VS network, and the
              "repulsive" coupling which approximates impact of the Vi, Vi2 cells. Generates a movie corresponding to
              a rotation of a generated spherical image about the given horizontal axis. Alternatively, can take as an
              argument a user-generated movie. For details, refer to:
                    Near optimal decoding of transient stimuli from coupled neuronal subpopulations
                    by James Trousdale, Sam Carroll, Fabrizio Gabbiani, and Kresimir Josic.
                    J. Neurosci. ??:??-??.

Authors:     James Trousdale - jamest212@gmail.com
             * Adapted from MATLAB code provided by Drs. Y. Elyada and A. Borst
'''


__all__ = ["sim_vs_net","reichardt_loc","reichardt_filter"]

import numpy as np
from scipy.interpolate import interp1d as interp1d
from scipy.signal import sepfir2d as sepfir2d
from utilities import ang_dist,repmat_3d,my_2d_hist

#                                             #
#                                             #
###### definition of function sim_vs_net ######
#                                             #
#                                             #

def _generate_movie(param_dict, T, dt_im):
    """Build the input movie described by param_dict['image'] (helper for sim_vs_net).

    Dispatches on param_dict['image']['type'] to the generators of the project-local
    rotate_image module. For 'white', 'bar' and 'natural' a static image is generated
    and then rotated with rotate_sphere; for 'stripe' the movie is generated directly.

    Raises
    ------
    ValueError
        If the image type is not one of 'white', 'bar', 'natural', 'stripe'.
    """
    import rotate_image as ri

    im = param_dict['image']
    im_type = im['type']

    if im_type == 'stripe':
        # For the stripe stimulus, rot_angle is the stripe's azimuthal center and
        # deg_per_ms its downward speed.
        return ri.vert_strip(im['x_pixels'], im['y_pixels'], param_dict['rot_angle'],
                             im['strip_rad'], im['freq_s'], param_dict['deg_per_ms'], T, dt_im)

    if im_type == 'white':
        image = ri.white_image(im['x_pixels'], im['y_pixels'], im['upsample'])
    elif im_type == 'bar':
        image = ri.bar_image(im['x_pixels'], im['y_pixels'], im)
    elif im_type == 'natural':
        image = ri.natural_image(im['x_pixels'], im['y_pixels'], im['image_lib_path'])
    else:
        raise ValueError("Unknown image type %r; expected 'white', 'bar', 'natural' or 'stripe'"
                         % (im_type,))

    return ri.rotate_sphere(image, param_dict['rot_angle'], param_dict['deg_per_ms'], T, dt_im)


def sim_vs_net(param_dict, movie=None):
    """Simulate the left- and right-side blow fly VS networks driven by a movie.

    The keys expected in ``param_dict`` (with units) are listed in the module docstring.

    Parameters
    ----------
    param_dict : dict
        Model, stimulus and integration parameters.
    movie : array_like, optional
        Input movie of dimension (y_pixels, x_pixels, frames). Its frames are taken to be
        evenly spaced over [0, T], and there must be int(T/dt_im)+1 of them. If None
        (the default) or empty, a movie is generated from ``param_dict['image']``.

    Returns
    -------
    (V_axon_L, V_axon_R) : tuple of ndarray
        Axonal membrane potentials (mV) of VS1..VS10 on the left and right sides, each of
        dimension (10, int(T/dt_sim)). Row k is VS(k+1). Column i corresponds to time
        i*T/int(T/dt_sim), and the potentials start from 0 at column 0.

    Raises
    ------
    ValueError
        If a movie must be generated and the image type is unknown, or if the movie does
        not have dimension (y_pixels, x_pixels, int(T/dt_im)+1).
    """
    # Receptive-field geometry (degrees).
    rs_ctr = np.array([10, 26, 42, 58, 74, 90, 106, 122, 138, 154])  # horizontal RF centers, VS1..VS10
    sig_th = 12   # horizontal standard deviation of the Gaussian RF
    sig_phi = 60  # vertical standard deviation of the Gaussian RF

    T = param_dict['T']
    dt_sim = param_dict['dt_sim']
    dt_im = param_dict['dt_im']

    g_gap = float(param_dict['g_gap'])
    g_inh = g_gap*param_dict['g_inh_scale']
    g_L_dend = float(param_dict['g_L_dend'])
    g_L_axon = float(param_dict['g_L_axon'])
    g_axon_dend = float(param_dict['g_axon_dend'])

    g_exc_in = float(param_dict['g_exc_in'])
    g_inh_in = float(param_dict['g_inh_in'])
    E_E = float(param_dict['E_E'])
    E_I = float(param_dict['E_I'])

    tau_m = float(param_dict['tau_m'])
    noise_std = float(param_dict['noise_std'])

    tau_hp = float(param_dict['tau_hp'])
    tau_lp = float(param_dict['tau_lp'])

    # Use the supplied movie, or generate one from param_dict['image'].
    if movie is None or np.size(movie) == 0:
        movie = _generate_movie(param_dict, T, dt_im)
    # Work in floating point: "movie + 1" on an integer (e.g. uint8) movie could wrap around.
    movie = np.asarray(movie, dtype=float)

    coarse_steps = int(T/dt_im) + 1  # number of movie frames covering [0, T]
    if movie.ndim != 3 or movie.shape[2] != coarse_steps:
        raise ValueError("movie must have dimension (y_pixels, x_pixels, %d); got %s"
                         % (coarse_steps, movie.shape))
    y_pixels, x_pixels = movie.shape[0], movie.shape[1]

    # Logarithmic scaling accounts for signal transduction in the retinal photoreceptors;
    # the global mean is then removed.
    log_movie = np.log10(movie + 1)
    movie = log_movie - np.mean(log_movie)

    # Compartment indexing used throughout: even index 2k is the dendrite of VS(k+1),
    # odd index 2k+1 is the axon of VS(k+1).
    g_L = np.zeros(20)
    g_L[0::2] = g_L_dend
    g_L[1::2] = g_L_axon
    C_m = tau_m*g_L  # capacitances, from tau_m = C_m/g_L

    # Electrical coupling matrix for one side (the two sides are not coupled), so that the
    # system reads  C_m V' = G V + synaptic input + noise.
    G = np.zeros((20, 20))
    for ax in range(1, 20, 2):
        # Axon <-> dendrite of the same cell.
        G[ax-1, ax] = G[ax, ax-1] = g_axon_dend
    for ax in range(3, 20, 2):
        # Axon <-> axon of consecutive cells (gap junctions).
        G[ax-2, ax] = G[ax, ax-2] = g_gap
    # "Repulsive" coupling (electrical synapse with negative conductance) between the VS1
    # and VS10 dendrites, as in Weber et al. (2008) "Eigenanalysis of a neural network...",
    # emulating the inhibitory spiking neurons Vi, Vi2.
    G[0, 18] = G[18, 0] = -g_inh
    # Diagonal: minus the total coupling of each compartment minus its leak, so row i of
    # G V equals sum_j G_ij (V_j - V_i) - g_L[i] V_i.
    G[np.diag_indices(20)] = -G.sum(axis=1) - g_L

    # Reichardt detector locations and their up/down motion outputs.
    reichardt_locs = reichardt_loc(x_pixels, y_pixels)
    (reich_up, reich_down) = reichardt_filter(movie, reichardt_locs, tau_hp, tau_lp, dt_im)

    # Smooth every frame of both components with a separable 3x3 box filter.
    box = np.ones(3)/3
    for k in range(reich_up.shape[2]):
        reich_up[:, :, k] = sepfir2d(reich_up[:, :, k], box, box)
        reich_down[:, :, k] = sepfir2d(reich_down[:, :, k], box, box)

    # Receptive fields are products of a vertical Gaussian (in elevation) and a horizontal
    # Gaussian in angular distance from each cell's center (centers at -rs_ctr on the left
    # and +rs_ctr on the right). Each field is normalized to sum to 1 over the detector
    # pixels; pixels without a detector do not contribute.
    phi_exp = np.exp(-np.linspace(-90, 90, y_pixels, endpoint=False)**2/(2*sig_phi**2)).reshape(y_pixels, 1)
    theta_vals = np.linspace(0, 360, x_pixels, endpoint=False)  # horizontal angle of each column

    recf_L = np.zeros((10, y_pixels, x_pixels))
    recf_R = np.zeros_like(recf_L)
    for n in range(10):
        th_exp_L = np.reshape(np.exp(-ang_dist(theta_vals, -rs_ctr[n])**2/(2*sig_th**2)), (1, x_pixels))
        th_exp_R = np.reshape(np.exp(-ang_dist(theta_vals, rs_ctr[n])**2/(2*sig_th**2)), (1, x_pixels))

        rf_L = np.dot(phi_exp, th_exp_L)  # outer product -> (y_pixels, x_pixels)
        recf_L[n] = rf_L/np.sum(rf_L*reichardt_locs)

        rf_R = np.dot(phi_exp, th_exp_R)
        recf_R[n] = rf_R/np.sum(rf_R*reichardt_locs)

    # Coarse (movie) and fine (integration) time grids.
    time_steps = int(T/dt_sim)
    t_range = np.linspace(0, T, time_steps, endpoint=False)
    t_coarse = np.linspace(0, T, coarse_steps, endpoint=True)

    def _project(recf, reich):
        # Receptive-field-weighted sum over pixels of the detector output at every frame
        # (sum_ij recf[i,j]*reich[i,j,k]), linearly interpolated onto the integration grid.
        coarse = np.einsum('ij,ijk->k', recf, reich)
        return interp1d(t_coarse, coarse)(t_range)

    # Downward motion drives excitation and upward motion drives inhibition. Input reaches
    # only the dendritic (even-index) compartments.
    g_exc_L = np.zeros((20, time_steps))
    g_inh_L = np.zeros((20, time_steps))
    g_exc_R = np.zeros((20, time_steps))
    g_inh_R = np.zeros((20, time_steps))
    for n in range(10):
        dend = 2*n
        g_exc_L[dend] = _project(recf_L[n], reich_down)*g_exc_in
        g_inh_L[dend] = _project(recf_L[n], reich_up)*g_inh_in
        g_exc_R[dend] = _project(recf_R[n], reich_down)*g_exc_in
        g_inh_R[dend] = _project(recf_R[n], reich_up)*g_inh_in

    # Pre-generate the white noise injected into each dendritic and axonal compartment.
    dW_L = 1.0/np.sqrt(dt_sim)*np.random.randn(20, time_steps)*noise_std
    dW_R = 1.0/np.sqrt(dt_sim)*np.random.randn(20, time_steps)*noise_std

    # Synaptic driving current (conductances weighted by reversal potentials) and the
    # total synaptic conductance, which enters the implicit system matrix.
    syn_curr_L = g_exc_L*E_E + g_inh_L*E_I
    syn_curr_R = g_exc_R*E_E + g_inh_R*E_I
    g_syn_L = g_exc_L + g_inh_L
    g_syn_R = g_exc_R + g_inh_R

    # First-order implicit (backward Euler) integration:
    #   (C_m/dt - G + diag(g_syn)) V[i] = syn_curr[i] + (C_m/dt) V[i-1] + sqrt(tau_m) dW[i]
    C_dt = C_m/dt_sim
    H = np.diag(C_dt) - G
    noise_gain = np.sqrt(tau_m)
    V_L = np.zeros((20, time_steps))
    V_R = np.zeros((20, time_steps))
    for i in range(1, time_steps):
        V_L[:, i] = np.linalg.solve(H + np.diag(g_syn_L[:, i]),
                                    syn_curr_L[:, i] + V_L[:, i-1]*C_dt + noise_gain*dW_L[:, i])
        V_R[:, i] = np.linalg.solve(H + np.diag(g_syn_R[:, i]),
                                    syn_curr_R[:, i] + V_R[:, i-1]*C_dt + noise_gain*dW_R[:, i])

    # Return only the axonal (odd-index) potentials.
    return (V_L[1::2, :].copy(), V_R[1::2, :].copy())




'''
Function:     reichardt_loc

Arguments:    x_pixels - number of horizontal pixels which will be utilized in images and movies
              y_pixels - number of vertical pixels which will be utilized in images and movies
           
Output:       A dimension (y_pixels,x_pixels) binary array, with 1's corresponding to pixels containing a Reichardt
              detector.
           
Description:  Given horizontal/vertical image dimensions x_pixels/y_pixels, spreads 10,000 Reichardt detectors 
              approximately evenly on the surface of a sphere using the Golden Section spiral algorithm. For reference, 
              see Patrick Boucher's blog post at http://www.softimageblog.com/archives/115 (posted October 4, 2006; 
              last accessed June 25, 2014). For 10,000 detectors, the approximate separation between detectors is 
              between 2 and 3 degrees, in the biologically realistic range.

Authors:     James Trousdale - jamest212@gmail.com
             * Adapted from code provided by Patrick Boucher
'''


#                                                #
#                                                #
###### definition of function reichardt_loc ######
#                                                #
#                                                #

def reichardt_loc(x_pixels, y_pixels, num_detectors=10000):
    """Return a (y_pixels, x_pixels) map of the pixels that contain a Reichardt detector.

    The detectors are spread approximately evenly over the unit sphere with the
    golden-section spiral. Each detector is then binned into the equirectangular pixel grid
    (columns span 0-360 degrees horizontally, rows span 0-180 degrees vertically).

    Parameters
    ----------
    x_pixels, y_pixels : int
        Horizontal and vertical image dimensions.
    num_detectors : int, optional
        Number of detectors to place (default 10000, the value used by the original model).

    Returns
    -------
    ndarray of dimension (y_pixels, x_pixels), produced by my_2d_hist over the pixel bin
    edges. NOTE(review): if several detectors fall in the same pixel, whether the entry
    is 1 or a count depends on my_2d_hist -- confirm.

    Raises
    ------
    ValueError
        If num_detectors < 1.
    """
    num_detectors = int(num_detectors)
    if num_detectors < 1:
        raise ValueError("num_detectors must be a positive integer, got %d" % num_detectors)

    theta_edges = np.linspace(0, 360, x_pixels + 1, endpoint=True)
    phi_edges = np.linspace(0, 180, y_pixels + 1, endpoint=True)

    # Golden-section spiral: y is evenly spaced in (-1, 1) and successive points advance
    # by the golden angle pi*(3 - sqrt(5)) around the y axis.
    inc = np.pi*(3 - np.sqrt(5))
    off = 2.0/num_detectors
    k = np.arange(num_detectors)
    y = k*off - 1 + off/2
    r = np.sqrt(1 - y*y)
    phi = k*inc
    x = np.cos(phi)*r
    z = np.sin(phi)*r

    # Spherical coordinates (degrees): azimuth in the x-y plane, polar angle from +z.
    theta_coords = np.mod(np.arctan2(y, x), 2*np.pi)*180/np.pi
    phi_coords = np.arccos(z)*180/np.pi

    return my_2d_hist(theta_coords, phi_coords, theta_edges, phi_edges)
        
        


'''
Function:     reichardt_filter

Arguments:    movie - a movie (here, generated by rotation of an image about an axis) of dimension
                      (y_pixels,x_pixels,time_steps)
              reichardt_locs - binary array of dimension (y_pixels,x_pixels), with 1 corresponding to a pixel containing
                               a Reichardt detector
              tau_hp - time constant of the high-pass filter in each half of the Reichardt detector
              tau_lp - time constant of the low-pass filter in each half of the Reichardt detector
              dt - time step of integration for the application of the Reichardt filtering
           
Output:       reich_up - Upward component of motion detected by the Reichardt filtering
              reich_down - Downward component of motion detected by the Reichardt filtering
           
Description: The reichardt_filter function filters the given movie through an array of Reichardt detectors, returning
             components corresponding to upward and downward motion. The type of detector implemented here makes use of
             a total of four filters --- two high pass and two low pass. For details, see Borst and Weber (2011) "Neural Action
             Fields for Optic Flow Based Navigation: A Simulation Study of the Fly Lobula Plate Network", 
             PLoS ONE 6(1): e16303. doi:10.1371/journal.pone.0016303.

Authors:     James Trousdale - jamest212@gmail.com
             * Adapted from code provided by Drs. Y. Elyada and A. Borst
'''

#                                                   #
#                                                   #
###### definition of function reichardt_filter ######
#                                                   #
#                                                   #


def reichardt_filter(movie, reichardt_locs, tau_hp, tau_lp, dt):
    """Filter a movie through vertically oriented Reichardt detectors.

    Parameters
    ----------
    movie : array_like, dimension (y_pixels, x_pixels, frames)
        Input movie; the y_pixels rows span 180 degrees vertically.
    reichardt_locs : array_like, dimension (y_pixels, x_pixels)
        Detector map; outputs are multiplied by it, so pixels with 0 produce no output.
    tau_hp, tau_lp : float
        Time constants of the high- and low-pass filters (same units as dt).
    dt : float
        Time step between movie frames.

    Returns
    -------
    (reich_up, reich_down) : tuple of ndarray, each with the same dimension as movie
        Non-negative magnitudes of the upward and downward motion components.

    Raises
    ------
    ValueError
        If y_pixels < 90, i.e. pixels are coarser than 2 degrees vertically, so that the
        two detector subunits would fall in the same pixel.
    """
    movie = np.asarray(movie)
    y_pixels = np.size(movie, 0)

    # The two subunits of a detector are 2 degrees apart vertically. At 180/y_pixels
    # degrees per pixel that is floor(y_pixels/90) pixels. Integer arithmetic avoids the
    # float truncation of int(2.0/(180.0/y_pixels)).
    shift = (2*y_pixels)//180
    if shift < 1:
        raise ValueError("reichardt_filter needs at least 90 vertical pixels (<= 2 deg/pixel), "
                         "got %d" % y_pixels)

    # First-order recursive filters in time. The filter state at frame 0 is 0. The
    # high-pass output is the movie minus its low-pass (time constant tau_hp) version.
    a_hp = dt/(tau_hp + dt)
    a_lp = dt/(tau_lp + dt)
    hp_state = np.zeros(np.shape(movie))
    low_pass = np.zeros(np.shape(movie))
    for i in range(1, np.size(movie, 2)):
        hp_state[:, :, i] = a_hp*movie[:, :, i] + (1 - a_hp)*hp_state[:, :, i-1]
        low_pass[:, :, i] = a_lp*movie[:, :, i] + (1 - a_lp)*low_pass[:, :, i-1]
    high_pass = movie - hp_state

    # Each detector cross-multiplies the low-pass of the pixel `shift` rows above with the
    # high-pass of its own pixel, and subtracts the mirror-image product. A positive result
    # indicates downward motion and a negative result upward motion. The top `shift` rows
    # have no upper partner and stay 0.
    reich = np.zeros(np.shape(movie))
    reich[shift:, :, :] = (low_pass[:-shift, :, :]*high_pass[shift:, :, :]
                           - high_pass[:-shift, :, :]*low_pass[shift:, :, :])

    # Split into non-negative upward/downward magnitudes and keep detector pixels only.
    mask = np.asarray(reichardt_locs)[:, :, np.newaxis]
    reich_up = -np.minimum(reich, 0.0)*mask
    reich_down = np.maximum(reich, 0.0)*mask

    return (reich_up, reich_down)
    
    
#                                                 #
#                                                 #
###### default code executed when sim_vs_net ###### 
###### is called without parameters          ######
#                                                 #
#                                                 #
    

if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Available stimulus types; select one by changing the index (0-3).
    stim_type = ['bar', 'white', 'natural', 'stripe']
    stim = stim_type[3]

    run_time_vars = {'g_gap': 1000,  # nS
                     'g_inh_scale': 0.06,  # factor which multiplies g_gap to give g_inh
                     'g_L_dend': 180,  # nS
                     'g_L_axon': 30,  # nS
                     'g_axon_dend': 110,  # nS
                     'g_exc_in': 10000,  # nS
                     'g_inh_in': 15000,  # nS
                     'E_E': 60,  # mV
                     'E_I': -40,  # mV
                     'tau_m': 1.4,  # ms
                     'noise_std': 6,  # pA
                     'tau_hp': 50,  # ms
                     'tau_lp': 20,  # ms
                     'dt_sim': 0.01,  # ms
                     'dt_im': 1,  # ms
                     'rot_angle': 90,  # deg
                     'deg_per_ms': 0.5,  # deg/ms
                     'T': 100,  # ms
                     }

    if stim == 'bar':
        run_time_vars['image'] = {'type': 'bar', 'x_pixels': 360, 'y_pixels': 180, 'num_bars': 25,
                                  'arc_length': 40, 'arc_width': 5}
    elif stim == 'white':
        run_time_vars['image'] = {'type': 'white', 'x_pixels': 360, 'y_pixels': 180, 'upsample': 4}
    elif stim == 'natural':
        run_time_vars['image'] = {'type': 'natural', 'x_pixels': 360, 'y_pixels': 180,
                                  'image_lib_path': './rotate_image/'}
    elif stim == 'stripe':
        run_time_vars['rot_angle'] = 0
        run_time_vars['deg_per_ms'] = 0.125
        run_time_vars['image'] = {'type': 'stripe', 'x_pixels': 360, 'y_pixels': 180, 'strip_rad': 5,
                                  'freq_s': 8}

    (V_L, V_R) = sim_vs_net(run_time_vars)

    # Time axis matching the simulator's output grid. The sample count must be an integer:
    # passing the float T/dt_sim to np.linspace raises TypeError in current NumPy.
    T = run_time_vars['T']
    t = np.linspace(0, T, np.size(V_L, 1), endpoint=False)

    fig1 = plt.figure()
    ax1 = fig1.add_subplot(121)
    ax2 = fig1.add_subplot(122)
    for i in range(10):
        ax1.plot(t, V_L[i, :], label='VS' + str(i + 1))
        ax2.plot(t, V_R[i, :])

    ax1.legend()
    ax1.set_title('Left side responses')
    ax2.set_title('Right side responses')
    # Symmetric y-limits with a 10% margin, shared by both panels.
    disp_lim = np.max([np.max(np.abs(V_L)), np.max(np.abs(V_R))])*1.1
    ax1.set_ylim([-disp_lim, disp_lim])
    ax2.set_ylim([-disp_lim, disp_lim])

    plt.show()