#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

Author: Natalie Ness, 2021

Grid cell-to-place cell transformation with AD-related synapse loss simulation

Relating to Ness N, Schultz SR. 'A computational grid-to-place-cell transformation 
model indicates a synaptic driver of place cell impairment in early-stage Alzheimer’s Disease' 

"""
# %%

import numpy as np 
import csv
import matplotlib.pyplot as plt 
from matplotlib.ticker import FormatStrFormatter
import pandas as pd

# %% Define Functions 

def initialize_place_cells(n_pc, n_grids):
    """ Initialize place cell population

    Parameters
    ----------
    n_pc : int
        Number of place cells.
    n_grids : int
        Number of grid cells.

    Returns
    -------
    place_cells : dict
        Dictionary with initialised place cell synapses and synaptic weights arrays.

    """
    #dictionary to store place_cell info 
    place_cells = {}
    #Number of GC-to-PC synapses per place cell: 1,200 per cell in a network of 311,500 place cells, scaled with the number of PCs
    n_syn = int(np.ceil(1200*(n_pc/311500)))
    #Randomly assign synapses between GCs and PCs 
    gc_input_binary = np.random.permutation(np.hstack((np.ones([1,n_syn]), np.zeros([1,n_grids-n_syn])))[0])
    PC_synapses = np.zeros([n_grids, n_pc])
    for i in range(0,n_pc):
        PC_synapses[:,i] = np.random.permutation(gc_input_binary[:])
    place_cells['synapses'] = PC_synapses
    #initialise synaptic weights randomly from synaptic strength pool 
    weights = np.random.choice(synaptic_strength_pool, size=(n_grids, n_pc), replace=True)
    weights = np.multiply(PC_synapses, weights) 
    place_cells['weights'] = weights
    return place_cells
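
# Worked example of the scaling above: with n_pc = 1000 (the simulate_place_cells
# default) each place cell receives ceil(1200 * 1000/311500) = 4 grid cell
# synapses, so 'synapses' is an (n_grids x 1000) binary matrix with four ones per
# column and 'weights' holds the matching draws from the synaptic strength pool.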

def turnover(place_cells):
    """ Synaptic turnover of grid cell-to-place cell synapses

    Parameters
    ----------
    place_cells : dict
        Dictionary with place cell synapses, synaptic weights and firing rates arrays.

    Returns
    -------
    place_cells : dict
        Dictionary with updated place cell synapses and synaptic weights arrays.

    """
    synapses = place_cells['synapses']
    weights = place_cells['weights']
    n_pc = synapses.shape[1]
    n_turnover = int(np.ceil(114*(n_pc/311500)))  #number of synapses to be replaced based on 114 in a network of 311,500 PCs

    for c in range(synapses.shape[1]):
        #turnover synapses
        lost_synapses = np.where(synapses[:,c] == 1.0)[0] 
        #check whether enough synapses remain to turn over n_turnover synapses despite AD-related loss
        if len(lost_synapses) >= n_turnover: 
            lost_synapses = np.random.choice(lost_synapses, size=n_turnover, replace=False)
            synapses[lost_synapses,c] = 0.
            weights[lost_synapses,c] = 0. 
            new_synapses = np.where(synapses[:,c] == 0.0)[0]
            new_synapses = np.random.choice(new_synapses, size=n_turnover, replace=False)
            synapses[new_synapses,c] = 1.
            #initialise new weights for naive synapses 
            weights[new_synapses, c] = np.random.choice(synaptic_strength_pool, size=len(new_synapses), replace=True)
        
        elif (len(lost_synapses) < n_turnover) and (len(lost_synapses) >=1):
            #if fewer than n_turnover synapses remain, turn over all remaining synapses
            synapses[lost_synapses,c] = 0.
            weights[lost_synapses,c] = 0.  
            new_synapses = np.where(synapses[:,c] == 0.0)[0]
            new_synapses = np.random.choice(new_synapses, size=len(lost_synapses), replace=False)
            synapses[new_synapses,c] = 1.
            #initialise new weights for naive synapses 
            weights[new_synapses, c] = np.random.choice(synaptic_strength_pool, size=len(new_synapses), replace=True)
    place_cells['synapses'] = synapses
    place_cells['weights'] = weights
    return place_cells 
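
# Worked example of the turnover level: with the default n_pc = 1000,
# ceil(114 * 1000/311500) = 1, i.e. each place cell replaces one grid cell
# synapse per simulated day, keeping its synapse count constant unless
# AD-related loss has depleted the column.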


def synapse_loss(place_cells, day_count_gc=0): 
    """ AD-related excitatory synapse loss (grid cell-to-place cell synapses)

    Parameters
    ----------
    place_cells : dict
        Dictionary with place cell synapses, synaptic weights and firing rates arrays.
    day_count_gc : int, optional
        Day count for excitatory synapse loss. The default is 0.

    Returns
    -------
    place_cells : dict
        Dictionary with place cell synapses, synaptic weights and firing rates arrays, reflecting loss of excitatory synapses.
    day_count_gc : int
        Day count for excitatory synapse loss.

    """
    
    synapses = place_cells['synapses']
    weights = place_cells['weights']
    n_pc = synapses.shape[1]
    n_syn_left = np.sum(synapses) #number of synapses left
    n_syn = n_pc*int(np.ceil(1200*(n_pc/311500))) #total number of synapses on day 0
    
    #determine number of synapses to be lost depending on time passed since simulation start
    timepoint = day_count_gc/30
    pr_timepoint = (day_count_gc-1)/30
    percent_synapses_uptodate = ((0.2716*(timepoint**2))-(9.0677*timepoint))/100
    percent_synapses_uptopreviousdate = ((0.2716*(pr_timepoint**2))-(9.0677*pr_timepoint))/100
    #the cumulative percentages are negative, so this difference is the positive fraction lost today
    percent_per_day = percent_synapses_uptopreviousdate - percent_synapses_uptodate
    #add time 
    day_count_gc += 1
    #number of synapses lost on this iteration
    n_loss = int(np.round(percent_per_day * n_syn))
    #if not enough synapses left, set n_loss equal to the total number of synapses left
    #this should not occur during a 365 day simulation
    if n_syn_left < n_loss:
        n_loss = int(n_syn_left)
    
    cells_affected = np.random.choice(np.arange(n_pc), size=n_loss, replace=True)
    #eliminate synapses 
    for c in range(synapses.shape[1]):
        if c in cells_affected:
            n_SYN_affected = np.count_nonzero(cells_affected == c)
            deleted_synapses = np.where(synapses[:,c] == 1.0)[0]
            if len(deleted_synapses) > n_SYN_affected:
                deleted_synapses = np.random.choice(deleted_synapses, size=n_SYN_affected, replace=False)
            synapses[deleted_synapses,c] = 0.
            weights[deleted_synapses,c]=0.
    place_cells['synapses'] = synapses
    place_cells['weights'] = weights 
    return [place_cells, day_count_gc]
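
# Worked example of the loss schedule: at day_count_gc = 30 (timepoint = 1 month)
# the cumulative term (0.2716*1**2 - 9.0677*1)/100 ~= -0.088, i.e. roughly 8.8%
# of the initial grid cell-to-place cell synapses are lost over the first month,
# applied day by day as the difference between consecutive timepoints.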

def interneuron_synapse_turnover(in_connectivity, out_connectivity):
    """ Synaptic turnover of place cell-to-interneuron and interneuron-to-place cell synapses

    Parameters
    ----------
    in_connectivity : list
        List of place cell-to-interneuron synapses. 
    out_connectivity : list
        List of interneuron-to-place cell synapses. 

    Returns
    -------
    in_connectivity : list
        List of place cell-to-interneuron synapses reflecting synaptic turnover. 
    out_connectivity : list
        List of interneuron-to-place cell synapses reflecting synaptic turnover. 

    """
    n_interneurons = len(in_connectivity)
    #get number of synapses for turnover per interneuron
    n_in = len(in_connectivity[0])
    n_out = len(out_connectivity[0])
    n_turnover_in = int(np.round(n_in-(n_in*np.exp(-1/10)))) #determined based on decay model N0-Nt
    n_turnover_out = int(np.round(n_out-(n_out*np.exp(-1/10))))
    
    for i in range(n_interneurons): 
        cell_pool = np.arange(n_pc) #note: relies on the module-level n_pc set below
        in_pool = [x for x in cell_pool if x not in in_connectivity[i]]
        out_pool = [x for x in cell_pool if x not in out_connectivity[i]]
        
        idx_lost_in = np.random.choice(np.arange(len(in_connectivity[i])), size=n_turnover_in, replace=False)
        idx_new_in = np.random.choice(in_pool, size=n_turnover_in, replace=False)
        
        #check whether enough output synapses remain for the determined level of
        #turnover; if not, turn over only as many as are left (this also avoids
        #indexing past the end of idx_lost_out when synapses have been depleted)
        n_replace_out = min(n_turnover_out, len(out_connectivity[i]))
        if n_replace_out > 0:
            idx_lost_out = np.random.choice(np.arange(len(out_connectivity[i])), size=n_replace_out, replace=False)
            idx_new_out = np.random.choice(out_pool, size=n_replace_out, replace=False)
            for r in range(n_replace_out):
                out_connectivity[i][idx_lost_out[r]] = idx_new_out[r]
        for j in range(n_turnover_in):
            in_connectivity[i][idx_lost_in[j]] = idx_new_in[j]
    return [in_connectivity, out_connectivity]
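
# Worked example of the turnover level: the decay model N0 - N0*exp(-t/10) with
# t = 1 day replaces a fraction 1 - exp(-1/10) ~= 9.5% of each interneuron's
# synapses per day, e.g. round(30 * 0.0952) = 3 of 30 outgoing synapses.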
    

def interneuron_syn_loss(place_cells, out_connectivity, day_count=0):
    """ Interneuron-to-place cell synapse loss

    Parameters
    ----------
    place_cells : dict
        Dictionary with place cell synapses, synaptic weights and firing rates arrays.
    out_connectivity : list
        List of interneuron-to-place cell synapses.
    day_count : int, optional
        Day count for inhibitory synapse loss. The default is 0.

    Returns
    -------
    place_cells : dict
        Dictionary with place cell synapses, synaptic weights and firing rates arrays.
    out_connectivity : list
        List of updated interneuron-to-place cell synapses reflecting interneuron-to-place cell synapse loss.
    day_count : int
        Day count for inhibitory synapse loss.

    """

    n_pc = place_cells['synapses'].shape[1]
    n_interneurons = int(np.ceil((n_pc/8)))
    #calculate synapse loss using quadratic equation based on values found in Schmid et al. (2010) 
    total_GABAsynapse = n_interneurons * int(np.ceil(928 * (n_pc/311500)))
    #determine timepoint of simulation
    timepoint = day_count/30
    pr_timepoint = (day_count-1)/30
    #determine percentage of synapses lost at current timepoint
    percent_loss_uptodate = ((0.0532*(timepoint**2)) + (2.2179*timepoint))/100
    percent_loss_uptopreviousdate = ((0.0532*(pr_timepoint**2)) + (2.2179*pr_timepoint))/100
    percent_day = percent_loss_uptodate - percent_loss_uptopreviousdate
    #update day count
    day_count += 1
    #number of synapses lost at current timepoint
    n_loss = int(np.round(percent_day * total_GABAsynapse)) 
    #pick n_loss random interneurons that will be affected 
    INs_affected = np.random.choice(np.arange(n_interneurons), size=n_loss, replace=True) #replace=True to allow one interneuron to lose multiple synapses  
    
    for i in range(n_interneurons):  
        #get outputs of each interneuron
        PC_output = out_connectivity[i]
        PC_output = [int(x) for x in PC_output]
        #AD-related synapse loss 
        
        if i in INs_affected:
            n_SYN_affected = np.count_nonzero(INs_affected == i)
            stop_point = len(PC_output) - n_SYN_affected 
            if stop_point > 0:
                PC_output = PC_output[0:stop_point] 
            else:
                PC_output = []
                print('Interneuron %i has lost all synapses'%i)
            out_connectivity[i] = PC_output
            
    return [place_cells, out_connectivity, day_count]


def interneurons_inp(place_cells, grid, winner_quantile=0.9, in_connectivity=[], out_connectivity=[]):
    """Initialises interneuron-place cell synapses and implements feedback inhibition of place cells

    Parameters
    ----------
    place_cells : dict
        Dictionary with place cell synapses, synaptic weights and firing rates arrays.
    grid : array
        Defines grid cell firing rates.
    winner_quantile : float, optional
        Fraction of the maximum firing rate among an interneuron's input cells that a cell must reach to escape inhibition. The default is 0.9.
    in_connectivity : list, optional
        List of place cell-to-interneuron synapses. The default is an empty list.
    out_connectivity : list, optional
        List of interneuron-to-place cell synapses. The default is an empty list.

    Returns
    -------
    place_cells : dict
        Dictionary with updated place cell firing rates reflecting competitive inhibition.
    in_connectivity : list
        Initialised list of place cell-to-interneuron synapses on the first iteration.
    out_connectivity : list
        Initialised list of interneuron-to-place cell synapses on the first iteration.

    """
    
    #calculate PC firing rates
    y = np.dot(grid, place_cells['weights'])
    #determine number of interneurons
    n_pc = place_cells['synapses'].shape[1]
    n_interneurons = int(np.ceil((n_pc/8)))
    #initialise connectivity on first iteration 
    if len(in_connectivity) == 0:
        #number of PC inputs to each interneuron
        cell_input = int(np.ceil(728 * (n_pc/311500)))
        in_connectivity = []
        #number of projections to PC from each interneuron
        cell_output = int(np.ceil(928 * (n_pc/311500)))
        out_connectivity = []
        #set connectivity for each interneuron
        for i in range(n_interneurons):
            PC_input = np.random.choice(np.arange(n_pc), size=cell_input, replace=False)
            in_connectivity.append(list(PC_input))
            PC_output = np.random.choice(np.arange(n_pc), size=cell_output, replace=False)
            out_connectivity.append(list(PC_output))
    
    #Competitive inhibition 
    for i in range(n_interneurons):  
        #get inputs and outputs for each interneuron
        PC_input = in_connectivity[i]
        PC_input = [int(x) for x in PC_input]
        PC_output = out_connectivity[i]
        PC_output = [int(x) for x in PC_output]
        #Competitive inhibition
        for p in range(y.shape[0]): #iterate over all track positions
            threshold_firing = np.amax(y[p,PC_input]) * winner_quantile 
            for j in range(len(PC_output)):
                if y[p,PC_output[j]] < threshold_firing:
                    y[p,PC_output[j]] = 0.
    place_cells['firing_rates'] = y
    
    return [place_cells, in_connectivity, out_connectivity]
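
# Example of the competitive rule: if, at a given position, the most active place
# cell among an interneuron's inputs fires at 10 (arbitrary units) and
# winner_quantile = 0.9, then every place cell that interneuron inhibits and
# that fires below 9 at that position is silenced.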


def scaling_function(x, scaling):
    """ Synaptic scaling function used within Hebbian and BCM learning function

    Parameters
    ----------
    x : 1D array
        Array of weights of all synapses converging onto one place cell.
    scaling : float
        Expected sum of synaptic weights converging onto a place cell.

    Returns
    -------
    ans : 1D array
        Array of scaled weights of all synapses converging onto one place cell.

    """

    if np.sum(x) >0:
        ans = (scaling/np.sum(x)) *x
    else:
        ans = x
    return ans
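
# Worked example: scaling_function(np.array([1., 2., 3.]), scaling=12.) returns
# [2., 4., 6.] -- relative weights are preserved while their sum is set to 12.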

def update_hebbian(place_cells, grid, learning_rate=0.001, scaling=149.1):
    """ Hebbian learning rule for grid cell-to-place cell synapses

    Parameters
    ----------
    place_cells : dict
        Dictionary with place cell synapses, synaptic weights and firing rates arrays.
    grid : array
        Defines grid cell firing rates.
    learning_rate : float, optional
        Hebbian learning rate. The default is 1e-3.
    scaling : float, optional
        Expected sum of synaptic weights converging onto a place cell. The default is 149.1, 
        expected in a network with 1,200 grid cell-to-place cell synapses.

    Returns
    -------
    place_cells : dict
        Dictionary with updated place cell synaptic weights.

    """

    y = place_cells['firing_rates']
    dw = np.dot(grid.T, y) 
    dw = place_cells['synapses'] * dw
    weight_change = (learning_rate*dw)
    place_cells['weights'] = place_cells['weights'] + weight_change

    #synaptic scaling 
    place_cells['weights'] = np.apply_along_axis(scaling_function,0, place_cells['weights'], scaling=scaling)
    #delete any potential negative weights 
    place_cells['weights'] = (place_cells['weights'] >0.)*place_cells['weights']
    
    #optional upper bounds for synaptic weights 
    #upperlimit = np.where(place_cells['weights'] > 2.)
    #place_cells['weights'][upperlimit] = 2.
    
    return place_cells

def update_bcm(place_cells, grid, scaling=149.1, limit=2):
    """ BCM learning rule for grid cell-to-place cell synapses

    Parameters
    ----------
    place_cells : dict
        Dictionary with place cell synapses, synaptic weights and firing rates arrays.
    grid : array
        Defines grid cell firing rates.
    scaling : float, optional
        Expected sum of synaptic weights converging onto a place cell. The default is 149.1, 
        expected in a network with 1,200 grid cell-to-place cell synapses.
    limit : int or float, optional
        Limit l for y*(y-T) term. The default is 2.

    Returns
    -------
    place_cells : dict
        Dictionary with updated place cell synaptic weights.

    """ 

    y = place_cells['firing_rates']
    #positive constant F_0 
    F_0 = 50
    #get dynamic threshold for each place cell 
    for c in range(y.shape[1]):
        F_mean = np.mean(y[:,c])
        T = ((F_mean/F_0)**2)*F_mean
        y[:,c] = y[:,c]*(y[:,c]-T)       
        for i in range(len(y[:,c])):
            if y[i,c] > limit:
                y[i,c] = limit
            elif y[i,c] < (-limit):
                y[i,c] = -limit
    #update synaptic weights 
    dw = np.dot(grid.T,y)
    dw = place_cells['synapses'] * dw 
    place_cells['weights'] = place_cells['weights'] + dw 
    #prevent negative weights 
    place_cells['weights'] = (place_cells['weights'] >= 0.)*place_cells['weights'] 
    #synaptic scaling
    place_cells['weights'] = np.apply_along_axis(scaling_function,0, place_cells['weights'], scaling=scaling)
    return place_cells
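
# Worked example of the sliding threshold: a cell with mean rate F_mean = 25 and
# F_0 = 50 gets T = (25/50)**2 * 25 = 6.25, so positions where it fires above
# 6.25 are potentiated and positions below are depressed, with the y*(y - T)
# term clipped to +/- limit before the weight update.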
    

def rle(response_array):
    """ Run length encoding function to get place field centroids in centroid_fun. 

    Parameters
    ----------
    response_array : 1D numpy array or list
        True/False vector array of place cell firing with True where place cell's firing rate 
        is above 50% of the maximum firing rate.

    Returns
    -------
    run_length, start_pos, values : tuple of 1D arrays
        Run lengths, start positions and the value of each run; (None, None, None) if the input is empty.
    
    """

    values = np.asarray(response_array) 
    N = len(values)
    if N == 0:
        return (None, None, None)
    else:
        y = np.array(values[1:] != values[:-1])     
        i = np.append(np.where(y), N - 1)   
        run_length = np.diff(np.append(-1, i))       
        start_pos = np.cumsum(np.append(0, run_length))[:-1] 
    return(run_length, start_pos, values[i])
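
# Worked example: rle(np.array([True, True, False, True])) returns
# run lengths [2, 1, 1], start positions [0, 2, 3] and values [True, False, True].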

def centroid_fun(response_cell, min_field_width=5):
    """ Function used within get_centroids to find place field centroids

    Parameters
    ----------
    response_cell : 1D numpy array 
        Array of firing rates of a single place cell.
    min_field_width : int, optional
        Minimum length of place fields. Default is 5.

    Returns
    -------
    centroid: int
        Position of centroid of the cell's place field along the track. 0 if no place field detected.
        
  """
    maximum = np.max(response_cell)
    runs = rle(response_cell >(maximum*0.5)) #gives True/False vector and runs rle on it
    long_runs = (runs[0] > min_field_width) & (runs[2] == True) & (runs[0] < 50)
    if (np.sum(long_runs) == 1):
        centroid = np.cumsum(runs[0])[long_runs] - runs[0][long_runs]/2 
    else:
        centroid = 0
        #no single qualifying place field detected; centroid defaults to position 0
    return centroid
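
# Worked example: a single above-half-max run of length 8 starting at position 10
# (positions 10-17) gives centroid = 18 - 8/2 = 14, i.e. the centre of the field.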
 
def get_centroids(response, min_field_width=5):
    """ Function to apply centroid_fun to a 2D array of place cell firing rates.

    Parameters
    ----------
    response : 2D array
        Firing rates of place cells
    min_field_width : int, optional
        Minimum length of place fields. Default is 5.

    Returns
    -------
    centroids : 1D array
        Array of place field centroid positions, with 0 for cells without a place field.

    """

    centroids = np.apply_along_axis(centroid_fun, 0, response, min_field_width=min_field_width)
    return centroids
    
def centroid_widths(response, min_field_width=5):
    """ Function to get place field widths

    Parameters
    ----------
    response : 2D array
        Firing rates of place cells
    min_field_width : int, optional
        Minimum length of place fields. The default is 5.

    Returns
    -------
    median_width : float
        Median place field width.
    mean_width : float
        Mean place field width.
    width_std : float
        Standard deviation of place field widths.

    """

    widths = []
    for i in range(response.shape[1]):
        maximum = np.max(response[:,i])
        runs = rle(response[:,i] >(maximum*0.5)) #gives True/False vector and runs rle on it
        long_runs = (runs[0] > min_field_width) & (runs[0] < 50) & (runs[2] == True)
        if (np.sum(long_runs) == 1):
            width = runs[0][long_runs]
            widths.append(width)
    median_width = np.median(widths)
    mean_width = np.mean(widths)
    width_std = np.std(widths)
    return (median_width, mean_width, width_std)

def get_place_field_properties(place_cell_list, samples):
    """ Function to get main place cell properties from simulation, including number of 
        place cells and place field widths

    Parameters
    ----------
    place_cell_list : dict
        Dictionary of place cell array for each sampled day in the simulation.
    samples : 1D array or list
        Timepoints at which to determine place cell properties.

    Returns
    -------
    tpcs : list
        Total number of place cells on each day given in samples.
    pf_width : list
        Median place field width on each day given in samples.
    pf_width_mean : list
        Mean place field width on each day given in samples.
    pf_width_std : list
        Standard deviation of place field width on each day given in samples.

    """
    # new working dictionary with place cell firing from sim output tuple
    pcs = {} 
    for d in samples:
        pcs[d] = np.array(place_cell_list[d]['firing_rates'])
    #get PF width on day 0
    [median_width_0, mean_width_0, width_std_0] = centroid_widths(pcs[samples[0]])
    #get centroids and locations of centroids on day 0
    centroids_day0 = get_centroids(pcs[samples[0]])
    if len(centroids_day0) == 1:
        centroids_day0 = centroids_day0[0] #avoid array in array problem 
    centroids_loc_day0 = [i for i,e in enumerate(centroids_day0) if e!= 0]
    
    #Day 0 
    tpcs = []
    tpcs.append(len(centroids_loc_day0))
    pf_width = [median_width_0]
    pf_width_mean = [mean_width_0]
    pf_width_std = [width_std_0]
    
    for i in samples[1:]:
        [median_width_i, mean_width_i, width_std_i] = centroid_widths(pcs[i])
        pf_width.append(median_width_i)
        pf_width_mean.append(mean_width_i)
        pf_width_std.append(width_std_i) 
        centroids_day_i = get_centroids(pcs[i])         
        if len(centroids_day_i) == 1:
             centroids_day_i = centroids_day_i[0]
        centroids_loc_day_i = [i for i,e in enumerate(centroids_day_i) if e!= 0]
        tpcs.append(len(centroids_loc_day_i))

    return (tpcs, pf_width, pf_width_mean, pf_width_std)

def get_activity_distribution(sim_run, samples):
    """ Get activity distribution of place cells (see Figure 6)
    
    Parameters
    ----------
    sim_run : dict
        Dictionary of place cell array for each sampled day in the simulation.
    samples : 1D array or list
        Timepoints at which to determine place cell properties.

    Returns
    -------
    cell_id : 2D array
        Activity group (0-3) of each cell at each sampled timepoint.

    """    
    n_pc = np.array(sim_run[samples[0]]['firing_rates']).shape[1]
    
    cell_id = np.zeros((n_pc, len(samples))) # should contain activity group for each cell at each time point
    #activity IDs: 0 = silent, 1 = rarely active, 2 = intermediately active, 3 = highly active
    high_cutoff= 200
    low_cutoff = 50
    for c, i in enumerate(samples):
        firing = np.array(sim_run[i]['firing_rates'])        
        for j in range(n_pc):
            s = np.sum(firing[:,j])
            if s>0:
                if s> high_cutoff: #highly active
                    cell_id[j,c] = 3
                elif s<low_cutoff: #rarely active
                    cell_id[j,c] = 1
                else: #intermediately active
                    cell_id[j,c] = 2
            else: #silent
                cell_id[j,c] = 0

    return cell_id
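
# Worked example of the activity cutoffs: a cell whose firing rates sum to 250
# across positions is classed as highly active (3), a sum of 30 as rarely active
# (1), a sum of 100 as intermediately active (2), and a sum of 0 as silent (0).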

def get_place_fields(day_firing_rates, min_field_width=5):
    """ Get array with firing rates in place field positions only (see Fig 1A)

    Parameters
    ----------
    day_firing_rates : 2D array
        Array of firing rates of place cells on one simulation day.
    min_field_width : int, optional
        Minimum place field width. The default is 5.

    Returns
    -------
    pfs.T : 2D array
        Array with firing rates of cells at place field positions, 0 otherwise.

    """
    firing_rates = np.array(day_firing_rates)
    pfs = np.zeros(firing_rates.shape)
    
    for col in range(firing_rates.shape[1]):
        maximum = np.max(firing_rates[:,col])
        runs = rle(firing_rates[:,col] > (maximum*0.8))
        long_runs = (runs[0] > min_field_width) & (runs[2] == True)
        if (np.sum(long_runs) == 1):
            start = int(runs[1][long_runs])
            end = int(np.cumsum(runs[0])[long_runs])
            pfs[start:(end+1), col] = firing_rates[start:(end+1),col]
    return pfs.T


def get_new_pc_indices(sim_run, samples):
    """ Get number of new place cells

    Parameters
    ----------
    sim_run : dict
        Dictionary of place cell array for each sampled day in the simulation.
    samples : 1D array or list
        Timepoints at which to determine place cell properties.

    Returns
    -------
    new_pc_indices : list
        Number of new place cells since last sample for all timepoints in samples.

    """
    pc_indices = []
    for i, d in enumerate(samples):  
        pfs = get_place_fields(sim_run[d]['firing_rates'])
        pc_idx = []
        for row in range(pfs.shape[0]):
            o = np.where(pfs[row,:] >0)
            if len(o[0]) > 0: 
                pc_idx.append(row)
        pc_indices.append(pc_idx)

    new_pc_indices = [] 
    for i in range(len(samples)):
        if i==0:
            new_pc_indices.append(len(pc_indices[0]))
        else:
            new_pcs=[]
            for j in pc_indices[i]:
                if j not in pc_indices[i-1]:
                    if j not in pc_indices[0]:
                        new_pcs.append(j)
            new_pc_indices.append(len(new_pcs))
    return new_pc_indices

def get_rpcs(sim_run, samples):
    """ Get recurring place cells 

    Parameters
    ----------
    sim_run : dict
        Dictionary of place cell array for each sampled day in the simulation.
    samples : 1D array or list
        Timepoints at which to determine place cell recurrence, with samples[1] as reference.
        
    Returns
    -------
    rpcs : list
        Number of recurring place cells whose centroid drift between timepoints is at most 5 units.

    """

    pcs = {}
    for d in samples:
        pcs[d] = np.array(sim_run[d]['firing_rates'])
    centroids_day0 = get_centroids(pcs[samples[0]])
    if len(centroids_day0) == 1:
        centroids_day0 = centroids_day0[0] #avoid array in array problem 
    centroids_loc_day0 = [i for i,e in enumerate(centroids_day0) if e!= 0]
    rpcs = [len(centroids_loc_day0)]
    
    for i in samples[1:]: 
        centroids_day_i = get_centroids(pcs[i])         
        if len(centroids_day_i) == 1:
             centroids_day_i = centroids_day_i[0]
        centroids_loc_day_i = [i for i,e in enumerate(centroids_day_i) if e!= 0]
        
        #save first centroids after day 0 as reference
        if i==samples[1]:
            ref_loc_centroids = centroids_loc_day_i
            ref_centroids_position = centroids_day_i

        rc_loc = [x for x in centroids_loc_day_i if x in ref_loc_centroids]
        pos_change = np.array(abs(centroids_day_i[rc_loc] - ref_centroids_position[rc_loc]))
        acceptchange = np.where(pos_change <=5)
        acceptchange = [x for x in acceptchange[0]]
        ref_centroids_position = centroids_day_i

        rc_loc = [e for i,e in enumerate(rc_loc) if i in acceptchange]
        rpcs.append(len(rc_loc))
    return (rpcs)

def get_recurrence_odds(sim_run, samples, rpcs): 
    """ Get probability for recurrence of a place cell or active cell between timepoints

    Parameters
    ----------
    sim_run : dict
        Dictionary of place cell array for each sampled day in the simulation.
    samples : 1D array or list
        Timepoints at which to determine place cell recurrence, with samples[1] as reference.
    rpcs : list
        Number of recurring place cells whose centroid drift between timepoints is at most 5 units.
        Output of the get_rpcs function.

    Returns
    -------
    active_probs : list
        Probability of recurrence of an active cell between two timepoints.
    recurrence_probs : list
        Probability of recurrence of a place cell between two timepoints.

    """
    n_pc = np.array(sim_run[samples[0]]['firing_rates']).shape[1]
    cell_id = np.zeros((n_pc, len(samples)))

    for c, i in enumerate(samples):
        firing = np.array(sim_run[i]['firing_rates'])
        for j in range(n_pc):
            s = np.sum(firing[:,j])
            if s > 0:
                cell_id[j,c] = 1
            elif s == 0:
                cell_id[j,c] = 0

    cell_activities = pd.DataFrame(cell_id, columns=samples)
    active_probs = []
    for i in range(1,len(samples)):
        overlap = pd.crosstab(cell_activities[samples[1]], cell_activities[samples[i]]).loc[1, 1]
        n_active_ref = pd.crosstab(cell_activities[samples[1]], cell_activities[samples[1]]).loc[1, 1]
        final_per = overlap / n_active_ref
        active_probs.append(final_per)

    recurrence_probs = []
    for i in range(1,len(samples)):
        recurrence_probs.append(rpcs[i]/rpcs[1])
    
    return (active_probs, recurrence_probs)

#%% Run simulation

#Load grid cell firing rates from grid_cells-2d.npy (activities of 10,000 grid cells)
grid = np.load("grid_cells-2d.npy")
grid = grid.T
grid = grid[:,:5000] #use the first 5,000 grid cells

#Load synaptic strength pool
with open("syn_str_pool.csv", newline='') as csvfile: 
    reader = csv.reader(csvfile, delimiter=' ')
    syn_str_pool = list(reader)
synaptic_strength_pool = []
for e in syn_str_pool:
    synaptic_strength_pool.append(float(e[0]))
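
#optional sanity check on the imported synaptic strength pool (an assumption-light
#guard added here: the pool must be non-empty before any weights are drawn)
assert len(synaptic_strength_pool) > 0, "synaptic strength pool is empty"
print('Loaded %i synaptic strengths; example draws:'%len(synaptic_strength_pool),
      np.random.choice(synaptic_strength_pool, size=3))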
 
def simulate_place_cells(grid, samples, learning_rate=1e-3, gc_syn_loss=False, inh_syn_loss=False, n_pc=1000):
    """ Simulate place cell activity

    Parameters
    ----------
    grid : 2D array
        Grid cell activity array.
    samples : 1D array or list
        Timepoints at which place cell activity samples are stored. The maximum defines the length of the simulation.
    learning_rate : float, optional
        Learning rate for Hebbian learning. The default is 1e-3.
    gc_syn_loss : bool, optional
        Whether AD-related grid cell-to-place cell synapse loss is implemented. The default is False.
    inh_syn_loss : bool, optional
        Whether AD-related interneuron-to-place cell synapse loss is implemented. The default is False.
    n_pc : int, optional
        Number of pyramidal cells in simulation. The default is 1000.

    Returns
    -------
    place_cell_list : dict
        Dictionary keyed by sample timepoint with place cell synapses, synaptic weights and firing rate arrays.

    """
    place_cell_list = {} 

    n_grids = grid.shape[1] #number of grid cells 
    winner_quantile=0.90 #for competitive inhibition

    #get the expected sum of synaptic weights converging onto each place cell
    n_syn = int(np.ceil(1200*(n_pc/311500)))
    E = []
    for i in range(10000):
        e = np.random.choice(synaptic_strength_pool, n_syn, replace=True)
        E.append(np.sum(e))
    scaling = np.mean(E)

    #initialize variables
    place_cells = initialize_place_cells(n_pc, n_grids)
    initial_in_connectivity = []
    initial_out_connectivity= []
    day_count_gc = 1 #start at day 1
    day_count = 1 #start inhibitory synapse loss at day 1
    
    #get initial firing rate and initialise interneuron-PC architecture 
    [place_cells, in_connectivity, out_connectivity] = interneurons_inp(place_cells, grid, winner_quantile, in_connectivity=initial_in_connectivity, out_connectivity=initial_out_connectivity) 
    #can either use Hebbian learning (update_hebbian) or BCM learning (update_bcm) here
    #place_cells = update_hebbian(place_cells,grid,learning_rate)
    place_cells = update_bcm(place_cells, grid, scaling=scaling, limit=200)
    [place_cells, in_connectivity, out_connectivity] = interneurons_inp(place_cells, grid, winner_quantile=winner_quantile, in_connectivity=in_connectivity, out_connectivity=out_connectivity) 

    #place_cells data mapped to tuples and saved 
    pc = {}
    pc['synapses'] = tuple(map(tuple, place_cells['synapses']))
    pc['weights'] = tuple(map(tuple, place_cells['weights']))
    pc['firing_rates'] = tuple(map(tuple, place_cells['firing_rates']))
    place_cell_list[0] = pc.copy()
    
    #Export data for re-import if there are memory issues due to length of simulation
    #np.savez_compressed('sim_0.npz', synapses=place_cells['synapses'], weights=place_cells['weights'], firing=place_cells['firing_rates'])
    
    #day 1 and beyond 
    for d in range(1,np.max(samples)+1):
        place_cells = turnover(place_cells) 
        [in_connectivity, out_connectivity] = interneuron_synapse_turnover(in_connectivity, out_connectivity) #interneuron synapse turnover 
        #AD related synapse loss
        if (gc_syn_loss==True): 
            [place_cells, day_count_gc] = synapse_loss(place_cells, day_count_gc)
        if (inh_syn_loss==True):
            [place_cells, out_connectivity, day_count] = interneuron_syn_loss(place_cells=place_cells, out_connectivity=out_connectivity, day_count=day_count)
        
        [place_cells, in_connectivity, out_connectivity] = interneurons_inp(place_cells, grid,  winner_quantile=winner_quantile, in_connectivity=in_connectivity, out_connectivity=out_connectivity)
        #place_cells = update_hebbian(place_cells, grid, learning_rate)
        place_cells = update_bcm(place_cells, grid, scaling=scaling, limit=200)
        [place_cells, in_connectivity, out_connectivity] = interneurons_inp(place_cells, grid, winner_quantile=winner_quantile, in_connectivity=in_connectivity, out_connectivity=out_connectivity)
        
        #save data 
        if d in samples:
            print('Days elapsed: %i'%d)
            #np.savez_compressed('sim_%i'%d, synapses=place_cells['synapses'], weights=place_cells['weights'], firing=place_cells['firing_rates'])
            pc['synapses'] = tuple(map(tuple, place_cells['synapses']))
            pc['weights'] = tuple(map(tuple, place_cells['weights']))
            pc['firing_rates'] = tuple(map(tuple, place_cells['firing_rates']))
            place_cell_list[d] = pc.copy() #save place cell data
    return place_cell_list

""" Set parameters:
    eta: learning rate for Hebbian learning
    n_pc: number of place cells
    gc_syn_loss: True if grid cell-to-place cell synaptic loss implemented
    inh_syn_loss: True if interneuron-to-place cell synaptic loss implemented
    samples: 'Days' on which to save simulation data 
    
    """
eta = 1e-3
n_pc = int(np.ceil(15575*0.5))
gc_syn_loss = False
inh_syn_loss = False
samples = np.arange(0,31,5)

sim_run = simulate_place_cells(grid, samples=samples, learning_rate=eta, gc_syn_loss=gc_syn_loss, inh_syn_loss=inh_syn_loss, n_pc=n_pc)

#%% Analyse place cell properties

(tpcs, pf_width, pf_width_mean, pf_width_std) = get_place_field_properties(sim_run, samples)
cell_id = get_activity_distribution(sim_run, samples)
npcs = get_new_pc_indices(sim_run, samples)

recurrence_samples = np.arange(0,31,5)
rpcs = get_rpcs(sim_run, recurrence_samples)
[active_probs, recurr_probs] = get_recurrence_odds(sim_run, recurrence_samples, rpcs)

#%% Demo graphs 

fig, (ax, ax1, ax2) = plt.subplots(3,1, figsize=(5,25))

ax.errorbar(samples, tpcs, c='C0', label = 'Place cells')
ax.errorbar(samples, npcs, c='C0', ls='--', label = 'New place cells')
ax.set_title('Number of place cells and new place cells')
ax.set_ylabel('Number of cells')
ax.locator_params(axis='y', nbins=6)
ax.tick_params(axis='x', direction='out', bottom=True, labelbottom=True)
ax.tick_params(axis='y', direction='out', left=True, labelleft=True)
ax.legend(facecolor='1', edgecolor='1', loc='upper right')
ax.set_ylim(bottom=0)

ax1.errorbar(samples, pf_width_mean, color='C0')
ax1.set_title('Mean place field width')
ax1.set_ylabel('Mean width (cm)')
ax1.locator_params(axis='y', nbins=6)
ax1.set_ylim(bottom=0, top=20)
ax1.tick_params(axis = 'both', which = 'major')

ax2.errorbar(recurrence_samples[:-1], recurr_probs, label = 'Place cells', ls='-', marker='.', color='C0')
ax2.errorbar(recurrence_samples[:-1], active_probs, label='Active cells', ls='--', marker='.', color='C0', alpha=0.4)
ax2.set_title('Recurrence probability')
ax2.set_ylim(0,1)
ax2.set_ylabel('Probability of recurrence')
ax2.legend(facecolor='1', edgecolor='1', loc='upper right')
ax2.locator_params(axis='y', nbins=8)

for axis in [ax,ax1,ax2]:
    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)
    axis.set_xlabel('Time from first session (days)')
    
fig.subplots_adjust(hspace=0.8)
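
#optionally export the demo figure (the filename below is a placeholder)
#fig.savefig('demo_place_cell_properties.png', dpi=300, bbox_inches='tight')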