# class definitions to generate an MRG axon, extracellular electrode and trial objects

from neuron import h
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import LinearNDInterpolator
import os


def line(x, a, b):
    """Evaluate the straight line ``a*x + b`` (elementwise when ``x`` is a numpy array)."""
    scaled = a * x
    return scaled + b


def logfn(x, h, k):
    """Evaluate the logarithmic curve ``h*ln(x) + k``.

    NOTE: the parameter ``h`` shadows the module-level NEURON ``h`` inside this function only.
    """
    logTerm = np.log(x)
    return h * logTerm + k


def logreg(x, L, x0, k, b):
    """Evaluate the logistic curve ``L / (1 + exp(-k*(x - x0))) + b``.

    ``L`` is the height, ``x0`` the midpoint, ``k`` the steepness and ``b`` the vertical offset.
    """
    denominator = 1.0 + np.exp(-k * (x - x0))
    return L / denominator + b

# Global NEURON settings, applied as a side effect of importing this module.
h.celsius = 37.0  # simulation temperature [degC]
h.dt = 0.005  # ms; integration time step (axon.__init__ copies this value into self.timestep)

class axon():
    """Class to create an axon of passage

    Structure and electrical properties are obtained from the MRG model modified for sensory neurons in Gaines et al. 
    Contains functions to generate the axon with user defined properties and set the XYZ position.

    NOTE
    ----
    Instantiating the axon object DOES NOT automatically generate positions for the neuron segments.
    Object function setXYZpos MUST be run on the axon object AFTER it has been instantiated.
    """
    def __init__(self, axonnodes, fiberD, pos=(0,0,0)):
        """Build the NEURON sections of an MRG-style axon with ``axonnodes`` nodes.

        Parameters
        ----------
        axonnodes : int
            number of nodes to include in the axon (must be >= 1). Between consecutive nodes an
            internodal unit MYSA, FLUT, 6 x STIN, FLUT, MYSA is inserted.
        fiberD: float
            fiber diameter; all other geometric parameters are derived from it by
            getModelParamsContinuous.
        pos : tuple
            (x,y,z) position of endpoint of axon. Currently unused here: every section is
            registered at (0, 0, 0) and real coordinates are assigned by setXYZpos.

        Raises
        ------
        ValueError
            if axonnodes < 1
        """
        if axonnodes < 1:
            raise ValueError(f'axonnodes must be >= 1, got {axonnodes}')

        # electrical parameters
        self.rhoa = 0.7e6   # [Ohm-um] specific axoplasmic resistance
        self.mycm = 0.1     # uF/cm2/lamella membrane//
        self.mygm = 0.001   # S/cm2/lamella membrane//

        self.timestep = h.dt

        # morphological parameters: section counts per type
        self.axonnodes = axonnodes
        self.paranodes1 = 2*(axonnodes-1)
        self.paranodes2 = 2*(axonnodes-1)
        self.axoninter = 6*(axonnodes-1)
        self.numNodes = axonnodes
        self.numSections = self.axonnodes+self.paranodes1+self.paranodes2+self.axoninter

        self.fiberD = fiberD
        self.paralength1 = 3.0
        self.nodelength = 1.0
        self.space_p1 = 0.002
        self.space_p2 = 0.004
        self.space_i = 0.004

        # one row per section, in creation order; coordinates are placeholders until setXYZpos
        self.sectionDF = pd.DataFrame(columns=['nodeIndex', 'sectionIndex', 'object', 'sectionType', 'x', 'y', 'z', 'L'])

        self.getModelParamsContinuous(fiberD)

        # construct axon: (axonnodes - 1) node+internode units, then the terminal node
        parentSection = None
        for iNode in range(axonnodes-1):
            parentSection = self.createMRGfiber(iNode, parentSection=parentSection)
        # Index the terminal node from axonnodes directly so this also works when the loop
        # above is empty (axonnodes == 1); using iNode + 1 raised NameError in that case.
        self.createNodeSection(axonnodes - 1, parentSection=parentSection)

        # setup recording
        self.membraneV = []

    def createMRGfiber(self, nodeNum, parentSection=None):
        """Create one node followed by its internodal unit and chain them together.

        Sections are created (and appended to ``sectionDF``) in the order
        node, MYSA, FLUT, 6 x STIN, FLUT, MYSA. Each one after the node is attached to the one
        created just before it. The leading node is attached to ``parentSection`` (when truthy)
        only after the whole unit exists.

        Parameters
        ----------
        nodeNum : int
            index passed to createNodeSection for the leading node
        parentSection : h.Section, optional
            section the leading node is connected to

        Returns
        -------
        h.Section
            the trailing MYSA section, to which the next unit can be attached
        """
        leadingNode = self.createNodeSection(nodeNum)
        tip = self.createFLUTsection(self.createMYSAsection(leadingNode))
        for _ in range(6):
            tip = self.createSTINsection(parentSection=tip)
        tip = self.createMYSAsection(self.createFLUTsection(tip))

        if parentSection:
            leadingNode.connect(parentSection)

        return tip

    def createNodeSection(self, nodeNum, parentSection=None):
        """Create one node section, register it in ``sectionDF`` and optionally attach it.

        Parameters
        ----------
        nodeNum : int
            index of this node along the axon (0 .. axonnodes-1). The first and last nodes get a
            passive 'pas' membrane (g_pas = 0.0001, e_pas = -80); all others get 'node_sensory'.
        parentSection : h.Section, optional
            section to connect the new node to (skipped when falsy)

        Returns
        -------
        h.Section
            the new node section
        """
        tmp = h.Section(name='node')
        tmp.nseg = 1
        tmp.L = self.nodelength
        tmp.Ra = self.rhoa / 10000
        tmp.cm = 2.0
        tmp.diam = self.nodeD
        # xc is a hoc nickname for xc[0], i.e. the capacitance in the extracellular layer that is immediately adjacent to the
        # cell membrane. xc[1] is the capacitance in the outer extracellular layer. Similar comments apply to xg and xg[1].

        # End nodes are passive. Nodes are numbered 0 .. axonnodes-1, so the last node is
        # axonnodes-1; the previous comparison against axonnodes could never match.
        if nodeNum == 0 or nodeNum == self.axonnodes - 1:
            tmp.insert('pas')
            tmp.g_pas = 0.0001
            tmp.e_pas = -80
        else:
            tmp.insert('node_sensory')
        tmp.insert('extracellular')
        tmp.xraxial[0] = self.Rpn0
        tmp.xg[0] = 1e10
        tmp.xc[0] = 0.0

        # nodeIndex and sectionIndex are both the number of node rows registered so far
        nNodes = (self.sectionDF['sectionType'] == 'node').sum()
        newRow = pd.DataFrame([{'sectionType': 'node',
                                'nodeIndex': nNodes,
                                'sectionIndex': nNodes,
                                'x': 0, 'y': 0, 'z': 0, 'L': tmp.L, 'object': tmp}],
                              columns=self.sectionDF.columns)
        # DataFrame.append was removed in pandas 2.0; pd.concat is its replacement. The empty
        # case is special-cased to avoid pandas' empty-frame dtype deprecation warning.
        self.sectionDF = (newRow if self.sectionDF.empty
                          else pd.concat([self.sectionDF, newRow], ignore_index=True))

        if parentSection:
            tmp.connect(parentSection)

        return tmp

    def createMYSAsection(self, parentSection=None):
        """Create one MYSA section, register it in ``sectionDF`` and optionally attach it.

        Uses the first paranodal geometry (length ``paralength1``, inner diameter ``paraD1``),
        the 'mysa_sensory' mechanism and an 'extracellular' layer with ``xraxial[0] = Rpn1``.

        Parameters
        ----------
        parentSection : h.Section, optional
            section to connect the new MYSA to (skipped when falsy)

        Returns
        -------
        h.Section
            the new MYSA section
        """
        tmp = h.Section(name='MYSA')
        tmp.nseg = 1
        tmp.L = self.paralength1
        tmp.Ra = self.rhoa * (1 / (self.paraD1 / self.fiberD) ** 2.0) / 10000
        tmp.cm = 2.0 * self.paraD1 / self.fiberD
        tmp.diam = self.fiberD

        tmp.insert('mysa_sensory')

        tmp.insert('extracellular')
        tmp.xraxial[0] = self.Rpn1
        # per-lamella-membrane myelin values (mygm, mycm) divided by 2*nl
        tmp.xg[0] = self.mygm / (self.nl * 2)
        tmp.xc[0] = self.mycm / (self.nl * 2)

        # nodeIndex = node rows registered so far; sectionIndex = MYSA rows registered so far
        sectionTypes = self.sectionDF['sectionType']
        newRow = pd.DataFrame([{'sectionType': 'MYSA',
                                'nodeIndex': (sectionTypes == 'node').sum(),
                                'sectionIndex': (sectionTypes == 'MYSA').sum(),
                                'x': 0, 'y': 0, 'z': 0, 'L': tmp.L, 'object': tmp}],
                              columns=self.sectionDF.columns)
        # DataFrame.append was removed in pandas 2.0; pd.concat is its replacement.
        self.sectionDF = (newRow if self.sectionDF.empty
                          else pd.concat([self.sectionDF, newRow], ignore_index=True))

        if parentSection:
            tmp.connect(parentSection)

        return tmp

    def createFLUTsection(self, parentSection=None):
        """Create one FLUT section, register it in ``sectionDF`` and optionally attach it.

        Uses the second paranodal geometry (length ``paralength2``, inner diameter ``paraD2``),
        the 'flut_sensory' mechanism and an 'extracellular' layer with ``xraxial[0] = Rpn2``.

        Parameters
        ----------
        parentSection : h.Section, optional
            section to connect the new FLUT to (skipped when falsy)

        Returns
        -------
        h.Section
            the new FLUT section
        """
        tmp = h.Section(name='FLUT')
        tmp.nseg = 1
        tmp.L = self.paralength2
        tmp.Ra = self.rhoa * (1 / (self.paraD2 / self.fiberD) ** 2.0) / 10000
        tmp.cm = 2.0 * self.paraD2 / self.fiberD
        tmp.diam = self.fiberD

        tmp.insert('flut_sensory')

        tmp.insert('extracellular')
        tmp.xraxial[0] = self.Rpn2
        # per-lamella-membrane myelin values (mygm, mycm) divided by 2*nl
        tmp.xg[0] = self.mygm / (self.nl * 2)
        tmp.xc[0] = self.mycm / (self.nl * 2)

        # nodeIndex = node rows registered so far; sectionIndex = FLUT rows registered so far
        sectionTypes = self.sectionDF['sectionType']
        newRow = pd.DataFrame([{'sectionType': 'FLUT',
                                'nodeIndex': (sectionTypes == 'node').sum(),
                                'sectionIndex': (sectionTypes == 'FLUT').sum(),
                                'x': 0, 'y': 0, 'z': 0, 'L': tmp.L, 'object': tmp}],
                              columns=self.sectionDF.columns)
        # DataFrame.append was removed in pandas 2.0; pd.concat is its replacement.
        self.sectionDF = (newRow if self.sectionDF.empty
                          else pd.concat([self.sectionDF, newRow], ignore_index=True))

        if parentSection:
            tmp.connect(parentSection)

        return tmp

    def createSTINsection(self, parentSection=None):
        """Create one STIN (internodal) section, register it in ``sectionDF`` and optionally attach it.

        Uses length ``interlength``, inner diameter ``axonD``, the 'stin_sensory' mechanism and
        an 'extracellular' layer with ``xraxial[0] = Rpx``.

        Parameters
        ----------
        parentSection : h.Section, optional
            section to connect the new STIN to (skipped when falsy)

        Returns
        -------
        h.Section
            the new STIN section
        """
        tmp = h.Section(name='STIN')
        tmp.nseg = 1
        tmp.L = self.interlength
        tmp.Ra = self.rhoa * (1 / (self.axonD / self.fiberD) ** 2.0) / 10000
        tmp.cm = 2 * self.axonD / self.fiberD
        tmp.diam = self.fiberD

        tmp.insert('stin_sensory')

        tmp.insert('extracellular')
        tmp.xraxial[0] = self.Rpx
        # per-lamella-membrane myelin values (mygm, mycm) divided by 2*nl
        tmp.xg[0] = self.mygm / (self.nl * 2)
        tmp.xc[0] = self.mycm / (self.nl * 2)

        # nodeIndex = node rows registered so far; sectionIndex = STIN rows registered so far
        sectionTypes = self.sectionDF['sectionType']
        newRow = pd.DataFrame([{'sectionType': 'STIN',
                                'nodeIndex': (sectionTypes == 'node').sum(),
                                'sectionIndex': (sectionTypes == 'STIN').sum(),
                                'x': 0, 'y': 0, 'z': 0, 'L': tmp.L, 'object': tmp}],
                              columns=self.sectionDF.columns)
        # DataFrame.append was removed in pandas 2.0; pd.concat is its replacement.
        self.sectionDF = (newRow if self.sectionDF.empty
                          else pd.concat([self.sectionDF, newRow], ignore_index=True))

        if parentSection:
            tmp.connect(parentSection)

        return tmp

    def setXYZpos(self, pos, bent=False, peak_index=None, xpos_set=False, xshift=None):
        """Assign x, y, z coordinates to every section row of ``self.sectionDF``.

        Sections are laid out along -x in row order. Row i is placed at
        ``x_start - (cumsum(L)[i] - L[i]/2)``, i.e. at the midpoint of its length, with row 0
        closest to ``x_start``. y and z are set to ``pos[1]`` and ``pos[2]`` for every row.

        Parameters
        ----------
        pos : tuple
            (x, y, z). ``pos[0]`` is used as ``x_start`` only when ``xpos_set`` is True and is
            ignored otherwise.
        bent : bool
            If True, bend the axon at one section (see ``peak_index``).
            - Rows before it are re-laid along a straight line from the bend point toward
              (3200, yTube, zTube); rows after it toward (-3200, yTube, zTube).
            - Rows beyond |x| = 3200 keep the tube y/z values.
            - yTube/zTube equal pos[1]/pos[2], clamped to +/-600 when |pos[1]|/|pos[2]| >= 750.
        peak_index : int, optional
            Row index of the bend section. Any integer type (including numpy integers) is
            accepted. If None, a random row with -500 < x < 500 is chosen via np.random.
        xpos_set : bool
            If True, ``x_start = pos[0]``; ``self.xpos``/``self.xshift`` are not updated.
        xshift : float, optional
            Used when ``xpos_set`` is False: ``x_start = deltax * axonnodes / 2 + xshift``.
            Any real number is honoured (including 0 and ints). If None, ``xshift`` is drawn
            uniformly from [-deltax/2, deltax/2). The values used are stored in ``self.xpos``
            and ``self.xshift``.

        Returns
        -------
        bool
            Always True.

        Raises
        ------
        ValueError
            if ``bent`` is True, ``peak_index`` is None and no row has -500 < x < 500.
        """
        lengths = self.sectionDF['L']
        if xpos_set:
            xStart = pos[0]
        else:
            # `is None` instead of isinstance(..., float): ints / numpy ints (e.g. xshift=0)
            # used to be silently replaced by a random jitter.
            if xshift is None:
                xshift = (np.random.random() - 0.5) * self.deltax
            xStart = 0 + (self.deltax * (self.axonnodes / 2)) + xshift
            self.xpos = xStart
            self.xshift = xshift
        self.sectionDF['x'] = xStart - (lengths.cumsum() - lengths / 2)
        # float columns so the bend below can store non-integer y/z without an
        # (deprecated) in-place dtype change
        self.sectionDF['y'] = float(pos[1])
        self.sectionDF['z'] = float(pos[2])

        if not bent:
            return True

        df = self.sectionDF
        if peak_index is not None:
            cutoff = int(peak_index)
        else:
            window = (df['x'] > -500) & (df['x'] < 500)
            candidates = df.index[window]
            if len(candidates) == 0:
                raise ValueError('bent=True requires a section with -500 < x < 500 or an explicit peak_index')
            cutoff = int(candidates[np.random.randint(len(candidates))])

        # iloc (positional) and loc (label) agree here because sectionDF has a RangeIndex
        before = df.iloc[:cutoff].copy()      # rows laid toward the peripheral (+x) tube
        after = df.iloc[cutoff + 1:].copy()   # rows laid toward the central (-x) tube
        halfLen = df.loc[cutoff, 'L'] / 2
        x0 = df.loc[cutoff, 'x']
        y0 = pos[1]
        z0 = pos[2]

        def tubeCoord(v):
            # clamp far off-axis coordinates to +/-600
            if v >= 750:
                return 600
            if v <= -750:
                return -600
            return v

        yTube = tubeCoord(pos[1])
        zTube = tubeCoord(pos[2])
        centralTube = (-3200, yTube, zTube)
        peripheralTube = (3200, yTube, zTube)

        def directionTo(tube):
            # components (a, b, c) of the unit vector from `tube` to the bend point, plus the
            # x-fraction of a unit step along that line (rounded to 4 decimals)
            rho = np.sqrt((x0 - tube[0]) ** 2 + (y0 - tube[1]) ** 2 + (z0 - tube[2]) ** 2)
            a = (x0 - tube[0]) / rho
            b = (y0 - tube[1]) / rho
            c = (z0 - tube[2]) / rho
            if (tube[1] == y0) and (tube[2] == z0):
                gamma = np.pi / 2
            else:
                gamma = np.abs(np.arctan((tube[0] - x0) / np.sqrt((tube[1] - y0) ** 2 + (tube[2] - z0) ** 2)))
            return a, b, c, np.round(np.sin(gamma), 4)

        a_p, b_p, c_p, xStep_p = directionTo(peripheralTube)
        # walk away from the bend in reverse row order so the row nearest the bend comes first
        before.loc[:, 'x'] = x0 + (halfLen + before.loc[::-1, 'L'].cumsum() - before['L'] / 2) * xStep_p
        before.loc[:, 'y'] = peripheralTube[1] + b_p / a_p * (before.loc[:, 'x'] - peripheralTube[0])
        before.loc[:, 'z'] = peripheralTube[2] + c_p / a_p * (before.loc[:, 'x'] - peripheralTube[0])
        inTube = before['x'] >= peripheralTube[0]
        before.loc[inTube, 'z'] = peripheralTube[2]
        before.loc[inTube, 'y'] = peripheralTube[1]

        a_c, b_c, c_c, xStep_c = directionTo(centralTube)
        after.loc[:, 'x'] = x0 - (halfLen + after['L'].cumsum() - after['L'] / 2) * xStep_c
        after.loc[:, 'y'] = centralTube[1] + b_c / a_c * (after.loc[:, 'x'] - centralTube[0])
        after.loc[:, 'z'] = centralTube[2] + c_c / a_c * (after.loc[:, 'x'] - centralTube[0])
        inTube = after['x'] <= centralTube[0]
        after.loc[inTube, 'z'] = centralTube[2]
        after.loc[inTube, 'y'] = centralTube[1]

        # label-aligned write-back; the bend row gets NaN here and is set explicitly below
        for col in ('x', 'y', 'z'):
            df.loc[:cutoff, col] = before[col]
            df.loc[cutoff + 1:, col] = after[col]
        df.loc[cutoff, 'x'] = x0
        df.loc[cutoff, 'y'] = y0
        df.loc[cutoff, 'z'] = z0

        return True


    def setXYZpos_curve(self, pos):
        """Alternative layout: straight axon along -x whose ends curve around the DRG exterior.

        NOTE: not used for the paper results.

        Sections are placed like setXYZpos with ``xpos_set=True``: row i sits at
        ``pos[0] - (cumsum(L)[i] - L[i]/2)``, with y = pos[1] and z = pos[2]. When the radial
        offset R = sqrt(pos[1]^2 + pos[2]^2) is >= 700:
        - rows with x_cutoff <= |x| < 3200 are moved to radius 1500*sqrt(1 - (x/3700)^2),
          where x_cutoff = 3700*sqrt(1 - (R/1500)^2);
        - rows with |x| >= 3200 are moved to radius 700.
        Both keep the angular direction of (pos[1], pos[2]) in the y-z plane.

        Parameters
        ----------
        pos : tuple
            (x, y, z); pos[0] is the x coordinate the layout starts from.

        Returns
        -------
        bool
            Always True.
        """
        lengths = self.sectionDF['L']
        self.sectionDF['x'] = pos[0] - (lengths.cumsum() - lengths / 2)
        # float columns so the curved coordinates can be stored without a dtype change
        self.sectionDF['y'] = float(pos[1])
        self.sectionDF['z'] = float(pos[2])
        R_axon = np.sqrt(pos[1]**2 + pos[2]**2)
        # Signed angle in the y-z plane: sin(theta) = y/R, cos(theta) = z/R. The previous
        # abs(arctan(y/z)) folded every axon into the +y/+z quadrant, so axons at negative y or z
        # had their curved and tube sections mirrored to the opposite side.
        theta = np.arctan2(pos[1], pos[2])

        if R_axon >= 700:
            x_cutoff = 3700*np.sqrt(1-(R_axon/1500)**2)
            xs = self.sectionDF['x']

            curveIdx = ((xs >= x_cutoff) & (xs < 3200)) | ((xs <= -x_cutoff) & (xs > -3200))
            curveRadius = 1500*np.sqrt(1-(self.sectionDF.loc[curveIdx, 'x']/3700)**2)
            self.sectionDF.loc[curveIdx, 'y'] = curveRadius * np.round(np.sin(theta), 4)
            self.sectionDF.loc[curveIdx, 'z'] = curveRadius * np.round(np.cos(theta), 4)

            tubeIdx = (xs >= 3200) | (xs <= -3200)
            self.sectionDF.loc[tubeIdx, 'y'] = 700 * np.round(np.sin(theta), 4)
            self.sectionDF.loc[tubeIdx, 'z'] = 700 * np.round(np.cos(theta), 4)
        return True

    def getSectionFromDF(self, sType, nodeIdx=None, returnObj=True):
        """Look up section rows in ``self.sectionDF`` by type and, optionally, node index.

        Parameters
        ----------
        sType: str
            string associated with section type: 'node', 'MYSA', 'FLUT', OR 'STIN'
        nodeIdx: int, optional
            value of the ``nodeIndex`` column to match (Python or numpy integer). If None, all
            rows of type ``sType`` are selected.
        returnObj: boolean
            True if the section object is desired (the selection must contain exactly one row)
            False if the relevant row(s) of the dataframe containing information about the section are desired

        Returns
        -------
        the section object, or the selected DataFrame rows when ``returnObj`` is False

        Raises
        ------
        ValueError
            if neither an integer ``nodeIdx`` nor a truthy ``sType`` is given, or if
            ``returnObj`` is True and the selection does not contain exactly one row
        """
        df = self.sectionDF
        # numpy integers are accepted too; previously np.int64 indices were rejected, and
        # np.int64(0) fell through to "all rows of this type"
        if isinstance(nodeIdx, (int, np.integer)):
            dfRow = df[(df['sectionType'] == sType) & (df['nodeIndex'] == nodeIdx)]
        elif not nodeIdx and sType:
            dfRow = df[df['sectionType'] == sType]
        else:
            raise ValueError('missing input argument for section type and/or node index')

        if not returnObj:
            return dfRow
        if len(dfRow) != 1:
            raise ValueError(f'expected exactly one {sType!r} section for nodeIdx={nodeIdx!r}, '
                             f'found {len(dfRow)}; use returnObj=False to get all matching rows')
        return dfRow['object'].iloc[0]

    def getModelParamsContinuous(self, fiberD):
        """Set all fiber-diameter-dependent parameters from continuous fitted curves.

        There is no lookup of discrete MRG diameters here. Every geometric parameter is a fitted
        function of ``fiberD``:
        - straight lines (``line``) for g and the diameters;
        - logistic curves (``logreg``) for deltax and paralength2;
        - a logarithmic fit (``logfn``), rounded to an int, for nl.
        The periaxonal resistances and the STIN length are then derived from those values.

        Parameters
        ----------
        fiberD : float
            fiber diameter

        Reads ``rhoa``, ``space_p1``, ``space_p2``, ``space_i``, ``nodelength`` and
        ``paralength1`` from the instance. Writes ``g``, ``axonD``, ``nodeD``, ``paraD1``,
        ``paraD2``, ``deltax``, ``paralength2``, ``nl``, ``Rpn0``, ``Rpn1``, ``Rpn2``,
        ``Rpx`` and ``interlength``.
        """
        # straight-line fits (nodeD/paraD1 and axonD/paraD2 share coefficients)
        self.g = line(fiberD, 0.01716804, 0.5075587)
        self.nodeD = line(fiberD, 0.34490792, -0.14841106)
        self.paraD1 = line(fiberD, 0.34490792, -0.14841106)
        self.axonD = line(fiberD, 0.88904883, -1.9104369)
        self.paraD2 = line(fiberD, 0.88904883, -1.9104369)

        # logistic fits
        self.deltax = logreg(fiberD, 3.79906687e+03, 2.13820902e+00, 2.48122018e-01, -2.19548067e+03)
        self.paralength2 = logreg(fiberD, 30.77203038, 10.53182692, 0.42725082, 31.47653035)

        # log fit, rounded to a whole number (nl divides the per-lamella mygm/mycm values)
        self.nl = int(round(logfn(fiberD, 65.89739004, -32.66582976)))

        def periaxonal(diam, space):
            # rhoa*0.01 over the area of the annulus of width `space` around a cylinder of diameter `diam`
            inner = diam / 2
            return (self.rhoa * .01) / (np.pi * (((inner + space) ** 2) - (inner ** 2)))

        # used as xraxial[0] of node, MYSA, FLUT and STIN sections respectively
        self.Rpn0 = periaxonal(self.nodeD, self.space_p1)
        self.Rpn1 = periaxonal(self.paraD1, self.space_p1)
        self.Rpn2 = periaxonal(self.paraD2, self.space_p2)
        self.Rpx = periaxonal(self.axonD, self.space_i)

        # what remains of one internodal spacing after the node and both paranode pairs, split over 6 STINs
        self.interlength = (self.deltax - self.nodelength - (2 * self.paralength1) - (2 * self.paralength2)) / 6
        

    def getModelParameters(self, fiberD):
        """ALTERNATE PARAMETER SETTING - BEST IF USED FOR DISCRETE fiberD VALUES DEFINED IN MRG MODEL.

        If ``fiberD`` compares equal (``==``) to one of the tabulated MRG diameters (5.7, 7.3,
        8.7, 10.0, 11.5, 12.8, 14.0, 15.0, 16.0), the tabulated geometry is used as-is.
        Any other diameter falls back to fitted formulas; the original docstring attributes
        these interpolation functions to Gaines et al.
        The periaxonal resistances and STIN length are then recomputed from the chosen geometry.

        Parameters
        ----------
        fiberD : float
            fiber diameter

        Note: in the tabulated case ``nl`` is a float (e.g. 80.0) and in the fitted case it is
        not rounded; getModelParamsContinuous instead rounds it to an int.
        """
        # fiberD -> (g, axonD, nodeD, paraD1, paraD2, deltax, paralength2, nl)
        mrgTable = (
            (5.7, (0.605, 3.4, 1.9, 1.9, 3.4, 500.0, 35.0, 80.0)),
            (7.3, (0.630, 4.6, 2.4, 2.4, 4.6, 750.0, 38.0, 100.0)),
            (8.7, (0.661, 5.8, 2.8, 2.8, 5.8, 1000.0, 40.0, 110.0)),
            (10.0, (0.690, 6.9, 3.3, 3.3, 6.9, 1150.0, 46.0, 120.0)),
            (11.5, (0.700, 8.1, 3.7, 3.7, 8.1, 1250.0, 50.0, 130.0)),
            (12.8, (0.719, 9.2, 4.2, 4.2, 9.2, 1350.0, 54.0, 135.0)),
            (14.0, (0.739, 10.4, 4.7, 4.7, 10.4, 1400.0, 56.0, 140.0)),
            (15.0, (0.767, 11.5, 5.0, 5.0, 11.5, 1450.0, 58.0, 145.0)),
            (16.0, (0.791, 12.7, 5.5, 5.5, 12.7, 1500.0, 60.0, 150.0)),
        )
        for tabulatedD, tabulated in mrgTable:
            if fiberD == tabulatedD:
                geometry = tabulated
                break
        else:
            # not a tabulated MRG diameter: fitted formulas
            geometry = (
                0.0172 * fiberD + 0.5076,
                0.889 * fiberD - 1.9104,
                0.3449 * fiberD - 0.1484,
                0.3527 * fiberD - 0.1804,
                0.889 * fiberD - 1.9104,
                969.3 * np.log(fiberD) - 1144.6,
                2.5811 * fiberD + 19.59,
                (65.897 * np.log(fiberD) - 32.666),
            )
        (self.g, self.axonD, self.nodeD, self.paraD1, self.paraD2,
         self.deltax, self.paralength2, self.nl) = geometry

        def periaxonal(diam, space):
            # rhoa*0.01 over the area of the annulus of width `space` around a cylinder of diameter `diam`
            inner = diam / 2
            return (self.rhoa * .01) / (np.pi * (((inner + space) ** 2) - (inner ** 2)))

        self.Rpn0 = periaxonal(self.nodeD, self.space_p1)
        self.Rpn1 = periaxonal(self.paraD1, self.space_p1)
        self.Rpn2 = periaxonal(self.paraD2, self.space_p2)
        self.Rpx = periaxonal(self.axonD, self.space_i)
        self.interlength = (self.deltax - self.nodelength - (2 * self.paralength1) - (2 * self.paralength2)) / 6

    def plotMembraneV(self):
        """Plot ``self.membraneV`` against time and show the figure.

        Sample k is plotted at t = k * h.dt (ms). NOTE(review): this assumes membraneV holds
        one sample per h.dt starting at t = 0; confirm against the code that fills it.
        """
        nSamples = len(self.membraneV)
        # np.arange gives exactly dt spacing; the previous np.linspace(0, n*dt, n) stretched
        # the time axis by a factor n/(n-1).
        t_vec = np.arange(nSamples) * h.dt
        plt.plot(t_vec, self.membraneV)
        plt.xlabel('time (ms)')
        plt.ylabel('mV')
        plt.show()