# class definitions to generate an MRG axon and trial objects
from neuron import h
import numpy as np
import pandas as pd
import os
from scipy.interpolate import LinearNDInterpolator

def line(x, a, b):
    """Evaluate the straight line a*x + b (element-wise when x is a numpy array)."""
    return b + a * x


def logfn(x, h, k):
    """Evaluate the logarithmic fit h*ln(x) + k.

    Note: inside this function the coefficient parameter `h` shadows the module-level NEURON `h`.
    """
    return k + h * np.log(x)


def logreg(x, L, x0, k, b):
    """Evaluate the logistic curve L / (1 + exp(-k*(x - x0))) + b."""
    exponent = -k * (x - x0)
    return L / (1.0 + np.exp(exponent)) + b

# global NEURON simulation settings, applied when this module is imported
h.dt = 0.005  # integration time step [ms]
h.celsius = 37.0  # simulation temperature

class pseudounipolar_neuron:
    """Class to create a pseudounipolar neuron

    Geometric properties of the soma and neck are obtained from Devor et al. Structure and electrical properties are
    obtained from the MRG model modified for sensory neurons in Gaines et al. Contains functions to generate each branch
    of the neuron with user defined properties and set the XYZ position.


    NOTE
    ----
    Instantiating the pseudounipolar_neuron object DOES NOT automatically generate positions for the neuron segments.
    The setXYZpos method MUST be called on the pseudounipolar_neuron object AFTER it has been instantiated.
    """

    def __init__(self, centralFiberD=5.7, peripheralFiberD=7.3, neckFiberD=7.3, numNodes_p=100, numNodes_c=100, somaSize=80, femElec=''):
        """Build the whole neuron: soma + iseg, stem (neck) axon with T-junction node, central and peripheral branches.

        Parameters
        ----------
        centralFiberD : float
            diameter of central branch
        peripheralFiberD : float
            diameter of peripheral branch
        neckFiberD : float
            diameter of neck (stem) branch
        numNodes_p : int
            number of nodes after abnormal segments in peripheral branch
        numNodes_c : int
            number of nodes after abnormal segments in central branch
        somaSize : float
            soma length and diameter (the soma section has L == diam == somaSize)
        femElec : str
            electrode used to run simulation; stored as self.etype
        """

        self.centralFiberD = centralFiberD
        self.peripheralFiberD = peripheralFiberD
        self.neckFiberD = neckFiberD
        self.timestep = h.dt

        # number of nodes after initial abnormal sections
        self.numNodes_c = numNodes_c
        self.numNodes_p = numNodes_p
        self.somaSize = somaSize

        # electrical parameters
        self.rhoa = 0.7e6   # [Ohm-um] specific axoplasmic resistance
        self.mycm = 0.1     # uF/cm2/lamella membrane
        self.mygm = 0.001   # S/cm2/lamella membrane

        # fixed section lengths [um] and periaxonal space widths used by the model-parameter functions
        self.paralength1 = 3.0
        self.nodelength = 1.0
        self.space_p1 = 0.002
        self.space_p2 = 0.004
        self.space_i = 0.004

        # one row per created section; the 'object' column holds the NEURON Section itself
        self.sectionDF = pd.DataFrame(columns=['nodeIndex', 'sectionIndex', 'object', 'puniBranch', 'sectionType', 'x', 'y', 'z', 'L'])
        # build soma -> stem -> T-junction node, then both branches hanging off the T-junction node
        tjnRoot = self.createNeckSections(self.createSomaSection())
        self.createCentralBranch(tjnRoot)
        self.createPeripheralBranch(tjnRoot)

        neckIdx = self.sectionDF['puniBranch'] == 'n'
        self.stemLength = self.sectionDF[neckIdx]['L'].sum()

        self.etype = femElec

    def createSomaSection(self):
        """Create the soma and the unmyelinated stem initial segment (iseg) with reduced sodium channel densities.

        Both sections get the node_sensory mechanism (gnabar/gnapbar scaled down from the mechanism's own values,
        el = -80) and an extracellular layer. One row per section is appended to self.sectionDF on branch 'n'.

        Returns
        -------
        iseg : NEURON Section
            the initial segment, already connected to the soma
        """

        soma = h.Section(name='soma')
        soma.nseg = 1
        soma.L = self.somaSize
        soma.diam = self.somaSize
        soma.Ra = self.rhoa / 10000  # Ohm-um -> Ohm-cm
        soma.cm = 2

        soma.insert('node_sensory')
        # [1:-1] skips the first and last entries of allseg(), i.e. the 0 and 1 end points
        for seg in list(soma.allseg())[1:-1]:
            # soma sodium densities: 300/2000 of the node_sensory values
            seg.node_sensory.gnabar = seg.node_sensory.gnabar * (300/2000)
            seg.node_sensory.gnapbar = seg.node_sensory.gnapbar * (300/2000)
            seg.node_sensory.el = -80

        soma.insert('extracellular')
        soma.e_extracellular = 0
        # periaxonal resistance of a space_p1-wide annulus (the 0.01 factor converts Ohm/um to MOhm/cm)
        soma.xraxial[0] = (self.rhoa*.01)/(np.pi*((((soma.diam/2)+self.space_p1)**2)-((soma.diam/2)**2)))
        soma.xg[0] = 1e10
        soma.xc[0] = 0.0

        # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead
        self.sectionDF = pd.concat([self.sectionDF, pd.DataFrame([{
            'sectionType': 'soma',
            'nodeIndex': 0,
            'sectionIndex': 0,
            'puniBranch': 'n',
            'x': 0.0, 'y': 0.0, 'z': 0.0, 'L': soma.L, 'object': soma}])], ignore_index=True)

        # stem unmyelinated segment
        iseg = h.Section(name='iseg')
        iseg.nseg = 33
        iseg.L = 200.0
        iseg.diam = (5.0/7.3)*self.neckFiberD  # adjusting linearly for axon size
        iseg.Ra = self.rhoa / 10000  # Ohm-um -> Ohm-cm
        iseg.cm = 2.0

        iseg.insert('extracellular')
        iseg.e_extracellular = 0
        iseg.xraxial[0] = (self.rhoa*.01)/(np.pi*((((iseg.diam/2)+self.space_p1)**2)-((iseg.diam/2)**2)))
        iseg.xg[0] = 1e10
        iseg.xc[0] = 0.0

        iseg.insert('node_sensory')
        for idx, seg in enumerate(list(iseg.allseg())[1:-1]):
            # first interior segment: 1000/2000 of the node_sensory sodium densities; all others: 600/2000
            scale = (1000/2000) if idx == 0 else (600 / 2000)
            seg.node_sensory.gnabar = seg.node_sensory.gnabar * scale
            seg.node_sensory.gnapbar = seg.node_sensory.gnapbar * scale
            seg.node_sensory.el = -80

        self.sectionDF = pd.concat([self.sectionDF, pd.DataFrame([{
            'sectionType': 'iseg',
            'nodeIndex': 0,
            'sectionIndex': 0,
            'puniBranch': 'n',
            'x': 0.0, 'y': 0.0, 'z': 0.0, 'L': iseg.L, 'object': iseg}])], ignore_index=True)
        iseg.connect(soma)

        return iseg

    def createNeckSections(self, parentNode):
        """Build the stem (neck) axon: four single-node MRG pairs followed by the T-junction node.

        The pairs have increasing internode length (deltaX) and myelination (nlFactor scales the lamella count nl).

        Parameters
        ----------
        parentNode : NEURON Section
            section the first stem pair is attached to (the iseg returned by createSomaSection in __init__)

        Returns
        -------
        the T-junction node section (sectionType 'tjnNode')
        """
        # (deltaX, nlFactor) per pair, soma side first
        pairSpecs = ((85, 0.07747), (130, 0.48592), (168, 0.66197), (201, 1.0))
        pairs = [self.MRGaxon(axonnodes=1, fiberD=self.neckFiberD, prefix='n', deltaX=dx, nlFactor=factor)
                 for dx, factor in pairSpecs]

        # chain the pairs: first section of each pair hangs off the last section of the previous one
        upstream = parentNode
        for pair in pairs:
            pair[0].connect(upstream)
            upstream = pair[-1]

        tjnParams = self.getModelParamsContinuous(self.neckFiberD, None, 1, 'n')
        return self.createNodeSection(tjnParams, 'n', parentSection=upstream, sectionType='tjnNode')

    def createCentralBranch(self, parentSection):
        """Create the central MRG axon: three abnormal single-node pieces (c1-c3) followed by the regular fiber.

        Parameters
        ----------
        parentSection : NEURON Section
            section the first abnormal piece is attached to (the T-junction node in __init__)

        Returns
        -------
        list
            sections of the regular central fiber, as returned by MRGaxon
        """
        abnormal = [self.MRGaxon(axonnodes=1, fiberD=self.centralFiberD, prefix='c', deltaX=dx, spec=label)
                    for dx, label in ((358, 'c1'), (780, 'c2'), (1170, 'c3'))]
        fiber = self.MRGaxon(axonnodes=self.numNodes_c, fiberD=self.centralFiberD, prefix='c')

        upstream = parentSection
        for piece in abnormal + [fiber]:
            piece[0].connect(upstream)
            upstream = piece[-1]

        return fiber

    def createPeripheralBranch(self, parentSection):
        """Create the peripheral MRG axon: three abnormal single-node pieces (p1-p3) followed by the regular fiber.

        Parameters
        ----------
        parentSection : NEURON Section
            section the first abnormal piece is attached to (the T-junction node in __init__)

        Returns
        -------
        list
            sections of the regular peripheral fiber, as returned by MRGaxon
        """
        abnormal = [self.MRGaxon(axonnodes=1, fiberD=self.peripheralFiberD, prefix='p', deltaX=dx, spec=label)
                    for dx, label in ((461, 'p1'), (670, 'p2'), (1119, 'p3'))]
        fiber2 = self.MRGaxon(axonnodes=self.numNodes_p, fiberD=self.peripheralFiberD, prefix='p')

        upstream = parentSection
        for piece in abnormal + [fiber2]:
            piece[0].connect(upstream)
            upstream = piece[-1]

        return fiber2

    def MRGaxon(self, axonnodes, fiberD, prefix, deltaX=None, nlFactor=1.0, spec='none'):
        """Create an axon fiber from the sections defined in the MRG model.

        Order of sections per node: Node-MYSA-FLUT-[STIN x 6]-FLUT-MYSA.
        Each repetition ends in a MYSA section so that the output can be connected with a node parent.

        Parameters
        ----------
        axonnodes : int
            number of Node-...-MYSA repetitions; must be >= 1
        fiberD : float
            fiber diameter passed to getModelParamsContinuous
        prefix : str
            branch label recorded in sectionDF ('n', 'c' or 'p')
        deltaX : float or None
            node-to-node distance; None (or 0) selects the fiberD-dependent default
        nlFactor : float
            scale factor for the myelin lamella count nl
        spec : str
            label of the abnormal segment; accepted for callers but not used by this method

        Returns
        -------
        list
            wholetree() of the last MYSA section, i.e. all sections of the new fiber; callers use [0] as the
            attachment point and [-1] as the section the next piece connects to

        Raises
        ------
        ValueError
            if axonnodes < 1 (there would be no section to return)
        """
        if axonnodes < 1:
            raise ValueError('axonnodes must be at least 1, got {}'.format(axonnodes))

        modelParams = self.getModelParamsContinuous(fiberD, deltaX, nlFactor, prefix)

        # construct axon; each new node is attached to the closing MYSA of the previous repetition
        previousMYSA = None
        for _ in range(axonnodes):
            node = self.createNodeSection(modelParams, prefix)
            mysa1 = self.createMYSAsection(modelParams, prefix, node)
            flut1 = self.createFLUTsection(modelParams, prefix, mysa1)
            stin = flut1
            for _ in range(6):
                stin = self.createSTINsection(modelParams, prefix, parentSection=stin)
            flut2 = self.createFLUTsection(modelParams, prefix, stin)
            mysa2 = self.createMYSAsection(modelParams, prefix, flut2)
            if previousMYSA is not None:
                node.connect(previousMYSA)
            previousMYSA = mysa2

        return previousMYSA.wholetree()

    def createNodeSection(self, modelparams, prefix, parentSection=None, sectionType='node'):
        """Create a node section with the node_sensory mechanism from Gaines et al.

        Parameters
        ----------
        modelparams : dict
            model-parameter dict (getModelParamsContinuous / getModelParameters); uses 'nodeD' and 'Rpn0'
        prefix : str
            branch label ('n', 'c' or 'p') recorded in sectionDF
        parentSection : NEURON Section or None
            section to connect the new node to, if given
        sectionType : str
            type recorded in sectionDF ('node', or 'tjnNode' for the T-junction)

        Returns
        -------
        the new NEURON Section
        """

        tmp = h.Section(name='node')
        tmp.nseg = 1
        tmp.L = self.nodelength
        tmp.Ra = self.rhoa / 10000  # Ohm-um -> Ohm-cm
        tmp.cm = 2.0
        tmp.diam = modelparams['nodeD']

        # xc is a hoc nickname for xc[0], i.e. the capacitance in the extracellular layer that is immediately adjacent to the
        # cell membrane. xc[1] is the capacitance in the outer extracellular layer. Similar comments apply to xg and xg[1].

        tmp.insert('node_sensory')
        tmp.insert('extracellular')
        tmp.e_extracellular = 0
        tmp.xraxial[0] = modelparams['Rpn0']
        tmp.xg[0] = 1e10
        tmp.xc[0] = 0.0

        isNode = self.sectionDF['sectionType'] == 'node'
        row = {'sectionType': sectionType,
               # count of 'node' rows already on this branch ('tjnNode' rows are not counted)
               'nodeIndex': (isNode & (self.sectionDF['puniBranch'] == prefix)).sum(),
               'sectionIndex': isNode.sum(),
               'puniBranch': prefix,
               'x': 0.0, 'y': 0.0, 'z': 0.0, 'L': tmp.L, 'object': tmp}
        # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead
        self.sectionDF = pd.concat([self.sectionDF, pd.DataFrame([row])], ignore_index=True)

        if parentSection is not None:
            tmp.connect(parentSection)

        return tmp

    def createMYSAsection(self, modelparams, prefix, parentSection=None):
        """Create a MYSA section with the mysa_sensory mechanism from Gaines et al.

        Parameters
        ----------
        modelparams : dict
            model-parameter dict; uses 'paraD1', 'fiberD', 'Rpn1' and 'nl'
        prefix : str
            branch label ('n', 'c' or 'p') recorded in sectionDF
        parentSection : NEURON Section or None
            section to connect to, if given

        Returns
        -------
        the new NEURON Section
        """

        tmp = h.Section(name='MYSA')
        tmp.nseg = 1
        tmp.L = self.paralength1
        # drawn with diam = fiberD; Ra and cm are rescaled by paraD1/fiberD so the axial resistance and membrane
        # capacitance per length equal those of an axon of diameter paraD1
        tmp.Ra = self.rhoa * (1 / (modelparams['paraD1'] / modelparams['fiberD']) ** 2.0) / 10000
        tmp.cm = 2.0 * modelparams['paraD1'] / modelparams['fiberD']
        tmp.diam = modelparams['fiberD']

        tmp.insert('mysa_sensory')

        tmp.insert('extracellular')
        tmp.e_extracellular = 0
        tmp.xraxial[0] = modelparams['Rpn1']
        # myelin sheath: per-lamella-membrane values scaled by 1/(2*nl)
        tmp.xg[0] = self.mygm / (modelparams['nl'] * 2)
        tmp.xc[0] = self.mycm / (modelparams['nl'] * 2)

        row = {'sectionType': 'MYSA',
               # number of 'node' rows already on this branch
               'nodeIndex': ((self.sectionDF['sectionType'] == 'node') & (self.sectionDF['puniBranch'] == prefix)).sum(),
               'sectionIndex': (self.sectionDF['sectionType'] == 'MYSA').sum(),
               'puniBranch': prefix,
               'x': 0.0, 'y': 0.0, 'z': 0.0, 'L': tmp.L, 'object': tmp}
        # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead
        self.sectionDF = pd.concat([self.sectionDF, pd.DataFrame([row])], ignore_index=True)

        if parentSection is not None:
            tmp.connect(parentSection)

        return tmp

    def createFLUTsection(self, modelparams, prefix, parentSection=None):
        """Create a FLUT section with the flut_sensory mechanism from Gaines et al.

        Parameters
        ----------
        modelparams : dict
            model-parameter dict; uses 'paralength2', 'paraD2', 'fiberD', 'Rpn2' and 'nl'
        prefix : str
            branch label ('n', 'c' or 'p') recorded in sectionDF
        parentSection : NEURON Section or None
            section to connect to, if given

        Returns
        -------
        the new NEURON Section
        """

        tmp = h.Section(name='FLUT')
        tmp.nseg = 1
        tmp.L = modelparams['paralength2']
        # drawn with diam = fiberD; Ra and cm are rescaled by paraD2/fiberD so the axial resistance and membrane
        # capacitance per length equal those of an axon of diameter paraD2
        tmp.Ra = self.rhoa * (1 / (modelparams['paraD2'] / modelparams['fiberD']) ** 2.0) / 10000
        tmp.cm = 2.0 * modelparams['paraD2'] / modelparams['fiberD']
        tmp.diam = modelparams['fiberD']

        tmp.insert('flut_sensory')

        tmp.insert('extracellular')
        tmp.e_extracellular = 0
        tmp.xraxial[0] = modelparams['Rpn2']
        # myelin sheath: per-lamella-membrane values scaled by 1/(2*nl)
        tmp.xg[0] = self.mygm / (modelparams['nl'] * 2)
        tmp.xc[0] = self.mycm / (modelparams['nl'] * 2)

        row = {'sectionType': 'FLUT',
               # number of 'node' rows already on this branch
               'nodeIndex': ((self.sectionDF['sectionType'] == 'node') & (self.sectionDF['puniBranch'] == prefix)).sum(),
               'sectionIndex': (self.sectionDF['sectionType'] == 'FLUT').sum(),
               'puniBranch': prefix,
               'x': 0.0, 'y': 0.0, 'z': 0.0, 'L': tmp.L, 'object': tmp}
        # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead
        self.sectionDF = pd.concat([self.sectionDF, pd.DataFrame([row])], ignore_index=True)

        if parentSection is not None:
            tmp.connect(parentSection)

        return tmp

    def createSTINsection(self, modelparams, prefix, parentSection=None):
        """Create a STIN (internode) section with the stin_sensory mechanism from Gaines et al.

        Parameters
        ----------
        modelparams : dict
            model-parameter dict; uses 'interlength', 'axonD', 'fiberD', 'Rpx' and 'nl'
        prefix : str
            branch label ('n', 'c' or 'p') recorded in sectionDF
        parentSection : NEURON Section or None
            section to connect to, if given

        Returns
        -------
        the new NEURON Section
        """

        tmp = h.Section(name='STIN')
        tmp.nseg = 1
        tmp.L = modelparams['interlength']
        # drawn with diam = fiberD; Ra and cm are rescaled by axonD/fiberD so the axial resistance and membrane
        # capacitance per length equal those of an axon of diameter axonD
        tmp.Ra = self.rhoa * (1 / (modelparams['axonD'] / modelparams['fiberD']) ** 2.0) / 10000
        tmp.cm = 2 * modelparams['axonD'] / modelparams['fiberD']
        tmp.diam = modelparams['fiberD']

        tmp.insert('stin_sensory')

        tmp.insert('extracellular')
        tmp.e_extracellular = 0
        tmp.xraxial[0] = modelparams['Rpx']
        # myelin sheath: per-lamella-membrane values scaled by 1/(2*nl)
        tmp.xg[0] = self.mygm / (modelparams['nl'] * 2)
        tmp.xc[0] = self.mycm / (modelparams['nl'] * 2)

        row = {'sectionType': 'STIN',
               # number of 'node' rows already on this branch
               'nodeIndex': ((self.sectionDF['sectionType'] == 'node') & (self.sectionDF['puniBranch'] == prefix)).sum(),
               'sectionIndex': (self.sectionDF['sectionType'] == 'STIN').sum(),
               'puniBranch': prefix,
               'x': 0.0, 'y': 0.0, 'z': 0.0, 'L': tmp.L, 'object': tmp}
        # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead
        self.sectionDF = pd.concat([self.sectionDF, pd.DataFrame([row])], ignore_index=True)

        if parentSection is not None:
            tmp.connect(parentSection)

        return tmp

    def getModelParamsContinuous(self, fiberD, deltax, nlfactor, branch):
        """Generate interpolated model parameters for any fiberD value.

        The fits (line / logfn / logreg) are applied to the discrete values defined in the MRG model.

        Parameters
        ----------
        fiberD : float
            fiber diameter
        deltax : float or None
            node-to-node distance; a falsy value selects the fiberD-dependent default
        nlfactor : float
            scale factor applied to the lamella count nl
        branch : str
            'n' for the stem (neck) axon, 'c'/'p' for the central/peripheral branches

        Returns
        -------
        dict
            geometry ('fiberD', 'g', 'axonD', 'nodeD', 'paraD1', 'paraD2', 'deltax', 'paralength2', 'interlength'),
            lamella count 'nl' and periaxonal resistances 'Rpn0', 'Rpn1', 'Rpn2', 'Rpx'
        """
        g = line(fiberD, 0.01716804, 0.5075587)
        axonD = line(fiberD, 0.88904883, -1.9104369)
        nodeD = line(fiberD, 0.34490792, -0.14841106)
        paraD1 = line(fiberD, 0.34490792, -0.14841106)
        paraD2 = line(fiberD, 0.88904883, -1.9104369)
        if not deltax:
            deltax = logreg(fiberD, 3.79906687e+03, 2.13820902e+00, 2.48122018e-01, -2.19548067e+03)
        nl = int(round(logfn(fiberD, 65.89739004, -32.66582976)) * nlfactor)

        if branch == 'n':
            # smallest MRG fiberD has paralength2 of 35
            paralength2 = 35
            if deltax == 85:
                # first stem pair: node diameter uses the same 5.0/7.3 scaling as the iseg diameter
                nodeD = fiberD * (5.0 / 7.3)
        else:
            paralength2 = logreg(fiberD, 30.77203038, 10.53182692, 0.42725082, 31.47653035)
        interlength = (deltax - self.nodelength - (2 * self.paralength1) - (2 * paralength2)) / 6

        # periaxonal resistances (the 0.01 factor converts Ohm/um to MOhm/cm). Computed AFTER the nodeD override
        # above so that Rpn0 corresponds to the node diameter actually returned.
        Rpn0 = (self.rhoa * .01) / (np.pi * ((((nodeD / 2) + self.space_p1) ** 2) - ((nodeD / 2) ** 2)))
        Rpn1 = (self.rhoa * .01) / (np.pi * ((((paraD1 / 2) + self.space_p1) ** 2) - ((paraD1 / 2) ** 2)))
        Rpn2 = (self.rhoa * .01) / (np.pi * ((((paraD2 / 2) + self.space_p2) ** 2) - ((paraD2 / 2) ** 2)))
        Rpx = (self.rhoa * .01) / (np.pi * ((((axonD / 2) + self.space_i) ** 2) - ((axonD / 2) ** 2)))

        return {'fiberD': fiberD, 'g': g, 'axonD': axonD, 'nodeD': nodeD, 'paraD1': paraD1, 'paraD2': paraD2,
                'deltax': deltax, 'paralength2': paralength2, 'nl': nl, 'Rpn0': Rpn0, 'Rpn1': Rpn1, 'Rpn2': Rpn2,
                'Rpx': Rpx, 'interlength': interlength}

    def getModelParameters(self, fiberD, deltax, nlfactor, branch):
        """ ALTERNATE PARAMETER SETTING - BEST IF USED FOR DISCRETE fiberD VALUES DEFINED IN MRG MODEL.
        Obtain model parameters for discrete fiberD values from the MRG model. Generate interpolated model parameters
        for non-MRG fiberD values. Interpolation functions obtained from Gaines et. al."""

        if fiberD == 5.7:
            g = 0.605
            axonD = 3.4
            nodeD = 1.9
            paraD1 = 1.9
            paraD2 = 3.4
            if not deltax:
                deltax = 500.0
            paralength2 = 35.0
            nl = 80.0*nlfactor
        elif fiberD == 7.3:
            g = 0.630
            axonD = 4.6
            nodeD = 2.4
            paraD1 = 2.4
            paraD2 = 4.6
            if not deltax:
                deltax = 750.0
            paralength2 = 38.0
            nl = 100.0*nlfactor
        elif fiberD == 8.7:
            g = 0.661
            axonD = 5.8
            nodeD = 2.8
            paraD1 = 2.8
            paraD2 = 5.8
            if not deltax:
                deltax = 1000.0
            paralength2 = 40.0
            nl = 110.0*nlfactor
        elif fiberD == 10.0:
            g = 0.690
            axonD = 6.9
            nodeD = 3.3
            paraD1 = 3.3
            paraD2 = 6.9
            if not deltax:
                deltax = 1150.0
            paralength2 = 46.0
            nl = 120.0*nlfactor
        elif fiberD == 11.5:
            g = 0.700
            axonD = 8.1
            nodeD = 3.7
            paraD1 = 3.7
            paraD2 = 8.1
            if not deltax:
                deltax = 1250.0
            paralength2 = 50.0
            nl = 130.0*nlfactor
        elif fiberD == 12.8:
            g = 0.719
            axonD = 9.2
            nodeD = 4.2
            paraD1 = 4.2
            paraD2 = 9.2
            if not deltax:
                deltax = 1350.0
            paralength2 = 54.0
            nl = 135.0*nlfactor
        elif fiberD == 14.0:
            g = 0.739
            axonD = 10.4
            nodeD = 4.7
            paraD1 = 4.7
            paraD2 = 10.4
            if not deltax:
                deltax = 1400.0
            paralength2 = 56.0
            nl = 140.0*nlfactor
        elif fiberD == 15.0:
            g = 0.767
            axonD = 11.5
            nodeD = 5.0
            paraD1 = 5.0
            paraD2 = 11.5
            if not deltax:
                deltax = 1450.0
            paralength2 = 58.0
            nl = 145.0*nlfactor
        elif fiberD == 16.0:
            g = 0.791
            axonD = 12.7
            nodeD = 5.5
            paraD1 = 5.5
            paraD2 = 12.7
            if not deltax:
                deltax = 1500.0
            paralength2 = 60.0
            nl = 150.0*nlfactor
        else:   # interpolation from Gaines et al
            g = 0.0172 * fiberD + 0.5076
            axonD = 0.889 * fiberD - 1.9104
            nodeD = 0.3449 * fiberD - 0.1484
            paraD1 = 0.3527 * fiberD - 0.1804
            paraD2 = 0.889 * fiberD - 1.9104
            if not deltax:
                deltax = 969.3 * np.log(fiberD) - 1144.6
            paralength2 = 2.5811 * fiberD + 19.59
            nl = (65.897 * np.log(fiberD) - 32.666)*nlfactor

        Rpn0 = (self.rhoa * .01) / (np.pi * ((((nodeD / 2) + self.space_p1) ** 2) - ((nodeD / 2) ** 2)))
        Rpn1 = (self.rhoa * .01) / (np.pi * ((((paraD1 / 2) + self.space_p1) ** 2) - ((paraD1 / 2) ** 2)))
        Rpn2 = (self.rhoa * .01) / (np.pi * ((((paraD2 / 2) + self.space_p2) ** 2) - ((paraD2 / 2) ** 2)))
        Rpx = (self.rhoa * .01) / (np.pi * ((((axonD / 2) + self.space_i) ** 2) - ((axonD / 2) ** 2)))
        if branch == 'n':
            # smallest MRG fiberD has paralength2 of 35
            interlength = (deltax - self.nodelength - (2 * self.paralength1) - (2 * 35)) / 6
        else:
            interlength = (deltax - self.nodelength - (2 * self.paralength1) - (2 * paralength2)) / 6

        return {'fiberD':fiberD,'g':g, 'axonD':axonD, 'nodeD':nodeD, 'paraD1':paraD1, 'paraD2':paraD2,'deltax':deltax,
                'paralength2':paralength2,'nl':nl, 'Rpn0':Rpn0, 'Rpn1':Rpn1, 'Rpn2':Rpn2, 'Rpx':Rpx,'interlength':interlength}

    def getSectionFromDF(self, sType, branch=None, nodeIdx=None, returnObj=True):
        """
        Parameters: 

        sType: str
            string associated with section type: 'soma', 'iseg', 'tjnNode', 'node', 'MYSA', 'FLUT', OR 'STIN'
        branch: str
            string associated with the branch of the pseudounipolar neuron: 'n' for neck or stem axon, 'c' for central axon, 'p' for peripheral axon
        nodeIdx: int
            index of the node with which the section is associated
        returnObj: boolean
            True if the section object is desired
            False if the relevant row of the dataframe containing information about the section is desired
        """
        if isinstance(nodeIdx,int) and branch:
            dfRow = self.sectionDF[(self.sectionDF['sectionType'] == sType) &
                           (self.sectionDF['puniBranch'] == branch) &
                           (self.sectionDF['nodeIndex'] == nodeIdx)]
        elif not(nodeIdx and branch) and sType:
            dfRow = self.sectionDF[self.sectionDF['sectionType'] == sType]
        else:
            raise ValueError('missing input argument for branch and/or node index')

        if returnObj:
            return dfRow['object'].item()
        else:
            return dfRow

    def setXYZpos(self, femdict, tjnPos=(0, 0, 0), neckAngle=90):
        """Verify XYZ position of T-junction is inside the DRG and assign xyz coordinates of all sections
        x > 0 --> peripheral branch
        x < 0 --> central branch

        Parameters
        ----------
        femdict : dict
            dictionary of FEM dataframes; IF ALL NEURON POSITIONS ARE ALREADY VERIFIED WITHIN THE DRG, SET femdict = None
        tjnPos: tuple
            (x,y,z) position of t-junction
        neckAngle: float
            angle of the stem axon with the z-axis in the y-z plane
        """
        neckIdx = self.sectionDF['puniBranch'] == 'n'
        if femdict is not None:
            femdata = femdict['DRG']
            stemLength = self.sectionDF[neckIdx]['L'].sum()

            FEM_xMax = tjnPos[0]
            FEM_xPosMask = tjnPos[0] == femdata['x']
            if np.any(FEM_xPosMask):                                # coordinate exists in the fem
                FEM_yMax = femdata[FEM_xPosMask]['y'].max()
                FEM_zMax = femdata[FEM_xPosMask]['z'].max()

            else:                                                   # cooridnate needs interpolation
                # for coord, label in zip(tjnPos[0], 'x'):
                diffList = np.array(femdata['x'].unique()) - FEM_xMax
                x_lowerLim = FEM_xMax + np.max(diffList[diffList < 0])
                x_upperLim = FEM_xMax + np.min(diffList[diffList > 0])

                y_upperLim = femdata[femdata['x'] == x_upperLim]['y'].max()
                y_lowerLim = femdata[femdata['x'] == x_lowerLim]['y'].max()
                FEM_yMax = np.interp(FEM_xMax, [x_lowerLim, x_upperLim], [y_lowerLim, y_upperLim])

                z_upperLim = femdata[femdata['x'] == x_upperLim]['z'].max()
                z_lowerLim = femdata[femdata['x'] == x_lowerLim]['z'].max()
                FEM_zMax = np.interp(FEM_xMax, [x_lowerLim, x_upperLim], [z_lowerLim, z_upperLim])

            if FEM_yMax == FEM_zMax:
                DRG_CSradius_X = FEM_yMax # or FEM_zMax/2
            else:
                DRG_CSradius_X = np.min([FEM_yMax, FEM_zMax])

            punRadius = DRG_CSradius_X - stemLength  # the tjn node can be anywhere inside this circle

        else:
            stemLength = 0
            DRG_CSradius_X = np.inf
            punRadius = np.inf  # if femdata hasnt been provided all positions are valid

        # check that the input position is valid
        # an elegant way would be to create a Polygon object and find if the point is inside the polygon, this will work too
        neckAngle = np.deg2rad(neckAngle)
        xshift = tjnPos[0]
        yshift = tjnPos[1]
        zshift = tjnPos[2]

        if (np.sqrt(zshift ** 2 + yshift ** 2) <= punRadius): # tjn lies inside annulus
            self.sectionDF.loc[neckIdx, 'x'] = xshift
            self.sectionDF.loc[neckIdx, 'y'] = yshift + (self.sectionDF[neckIdx].loc[::-1, 'L'].cumsum()[::-1] - self.sectionDF[neckIdx]['L'] / 2) * np.round(np.sin(neckAngle), 4)
            self.sectionDF.loc[neckIdx, 'z'] = zshift + (self.sectionDF[neckIdx].loc[::-1, 'L'].cumsum()[::-1] - self.sectionDF[neckIdx]['L'] / 2) * np.round(np.cos(neckAngle), 4)

            outsideDRGidx = (self.sectionDF.loc[neckIdx, 'z']**2 + self.sectionDF.loc[neckIdx, 'y']**2)**(1/2) > DRG_CSradius_X
            if outsideDRGidx.any():
                insideDRGidx = (self.sectionDF.loc[neckIdx, 'z'] ** 2 + self.sectionDF.loc[neckIdx, 'y'] ** 2)**(1/2) < DRG_CSradius_X
                self.sectionDF.loc[neckIdx & outsideDRGidx, 'z'] = self.sectionDF.loc[neckIdx & insideDRGidx, 'z'].iloc[0]
                self.sectionDF.loc[neckIdx & outsideDRGidx, 'y'] = self.sectionDF.loc[neckIdx & insideDRGidx, 'y'].iloc[0]

            centralIdx = self.sectionDF['puniBranch'] == 'c'
            self.sectionDF.loc[centralIdx, 'x'] = xshift - (self.sectionDF[centralIdx].loc[::, 'L'].cumsum() - self.sectionDF[centralIdx]['L'] / 2)
            self.sectionDF.loc[centralIdx, 'y'] = yshift
            self.sectionDF.loc[centralIdx, 'z'] = zshift

            periIdx = self.sectionDF['puniBranch'] == 'p'
            self.sectionDF.loc[periIdx, 'x'] = xshift + (self.sectionDF[periIdx].loc[::, 'L'].cumsum() - self.sectionDF[periIdx]['L'] / 2)
            self.sectionDF.loc[periIdx, 'y'] = yshift
            self.sectionDF.loc[periIdx, 'z'] = zshift

            return 1
        elif (stemLength*np.sin(neckAngle) < DRG_CSradius_X) and (stemLength*np.cos(neckAngle) < DRG_CSradius_X): #soma coordinates are inside drg

            self.sectionDF.loc[neckIdx, 'x'] = xshift
            self.sectionDF.loc[neckIdx, 'y'] = yshift + (self.sectionDF[neckIdx].loc[::-1, 'L'].cumsum()[::-1] - self.sectionDF[neckIdx]['L'] / 2) * np.round(np.sin(neckAngle), 4)
            self.sectionDF.loc[neckIdx, 'z'] = zshift + (self.sectionDF[neckIdx].loc[::-1, 'L'].cumsum()[::-1] - self.sectionDF[neckIdx]['L'] / 2) * np.round(np.cos(neckAngle), 4)

            outsideDRGidx = (self.sectionDF.loc[neckIdx, 'z']**2 + self.sectionDF.loc[neckIdx, 'y']**2)**(1/2) > DRG_CSradius_X
            if outsideDRGidx.any():
                insideDRGidx = (self.sectionDF.loc[neckIdx, 'z'] ** 2 + self.sectionDF.loc[neckIdx, 'y'] ** 2)**(1/2) < DRG_CSradius_X
                self.sectionDF.loc[neckIdx & outsideDRGidx, 'z'] = self.sectionDF.loc[neckIdx & insideDRGidx, 'z'].iloc[0]
                self.sectionDF.loc[neckIdx & outsideDRGidx, 'y'] = self.sectionDF.loc[neckIdx & insideDRGidx, 'y'].iloc[0]

            if tjnPos[1] >= 750:
                yTube = 600
            elif tjnPos[1] <= -750:
                yTube = -600
            else:
                yTube = yshift

            if tjnPos[2] >= 750:
                zTube = 600
            elif tjnPos[2] <= -750:
                zTube = -600
            else:
                zTube = zshift
            centralTube = (-3200, yTube, zTube)
            peripheralTube = (3200, yTube, zTube)

            rho_c = np.sqrt((tjnPos[0]-centralTube[0])**2 + (tjnPos[1]-centralTube[1])**2 + (tjnPos[2]-centralTube[2])**2)
            a_c = (tjnPos[0]-centralTube[0])/rho_c
            b_c = (tjnPos[1]-centralTube[1])/rho_c
            c_c = (tjnPos[2]-centralTube[2])/rho_c
            if centralTube[2] == zshift:
                gamma_c = np.pi/2
            else:
                gamma_c = np.abs(np.arctan((centralTube[0] - xshift) / (centralTube[2]-zshift)))
            centralIdx = self.sectionDF['puniBranch'] == 'c'
            self.sectionDF.loc[centralIdx, 'x'] = xshift - (self.sectionDF[centralIdx].loc[::, 'L'].cumsum() - self.sectionDF[centralIdx]['L'] / 2) * np.round(np.sin(gamma_c), 4)
            self.sectionDF.loc[centralIdx, 'y'] = centralTube[1] + b_c/a_c*(self.sectionDF.loc[centralIdx, 'x'] - centralTube[0])
            self.sectionDF.loc[centralIdx, 'z'] = centralTube[2] + c_c/a_c*(self.sectionDF.loc[centralIdx, 'x'] - centralTube[0])
            tubeIdx_c = self.sectionDF[centralIdx]['x'] <= centralTube[0]
            self.sectionDF.loc[(centralIdx & tubeIdx_c),'z'] = centralTube[2]
            self.sectionDF.loc[(centralIdx & tubeIdx_c), 'y'] = centralTube[1]

            rho_p = np.sqrt((tjnPos[0] - peripheralTube[0]) ** 2 + (tjnPos[1] - peripheralTube[1]) ** 2 + (tjnPos[2] - peripheralTube[2]) ** 2)
            a_p = (tjnPos[0] - peripheralTube[0]) / rho_p
            b_p = (tjnPos[1] - peripheralTube[1]) / rho_p
            c_p = (tjnPos[2] - peripheralTube[2]) / rho_p
            if centralTube[2] == zshift:
                gamma_p = np.pi / 2
            else:
                gamma_p = np.abs(np.arctan((peripheralTube[0] - xshift) / (peripheralTube[2] - zshift)))
            periIdx = self.sectionDF['puniBranch'] == 'p'
            self.sectionDF.loc[periIdx, 'x'] = xshift + (self.sectionDF[periIdx].loc[::, 'L'].cumsum() - self.sectionDF[periIdx]['L'] / 2) * np.round(np.sin(gamma_p), 4)
            self.sectionDF.loc[periIdx, 'y'] = peripheralTube[1] + b_p / a_p * (self.sectionDF.loc[periIdx, 'x'] - peripheralTube[0])
            self.sectionDF.loc[periIdx, 'z'] = peripheralTube[2] + c_p / a_p * (self.sectionDF.loc[periIdx, 'x'] - peripheralTube[0])
            tubeIdx_p = self.sectionDF[periIdx]['x'] >= peripheralTube[0]
            self.sectionDF.loc[(periIdx & tubeIdx_p),'z'] = centralTube[2]
            self.sectionDF.loc[(periIdx & tubeIdx_p), 'y'] = centralTube[1]

            return 1
        else:
            return 0

    