"""

Description: This file defines the class that acts as a template for creating the myelinated axon.
             It is based on the work of Scurfield and Latimer (2018), which in turn builds on the model of Gow and Devaux (2008).

Edit History: Created by Nilapratim Sengupta in July-August 2023.

"""

# Import statements
from neuron import h

# Defining the class
class AxonTemplate:
    """Template that creates, connects and groups the sections of a myelinated axon.

    Topology (proximal to distal): an axon hillock (``hill_axon``) feeds an
    initial segment (``iseg_axon``), which feeds node 0. After that come
    ``count_axonSegments`` repeating units. Each unit runs from one node to the next::

        node -> paranodeOne -> paranodeTwo -> paranodeThree -> paranodeFour
             -> juxtaparanode -> internode -> juxtaparanode
             -> paranodeFour -> paranodeThree -> paranodeTwo -> paranodeOne -> node

    Neighbouring units share their bounding node. As a result there are:

    * ``count_axonSegments + 1`` nodes;
    * ``2 * count_axonSegments`` sections of each paranode type and of
      juxtaparanodes (one set on each side of every internode);
    * ``count_axonSegments`` internodes.

    Every section is attached with ``child.connect(parent(1))``.

    Section lists built on the instance:

    * ``Nodes_axon``, ``ParanodeOnes_axon`` .. ``ParanodeFours_axon``,
      ``Juxtaparanodes_axon``, ``Internodes_axon``: one list per section type.
    * ``Paranodes_axon``: all four paranode types.
    * ``TotalAxon``: every section, including hillock and initial segment.
    * ``ExposedAxon``: nodes, hillock and initial segment.
    * ``MyelinatedAxon``: paranodes, juxtaparanodes and internodes.

    This class only creates sections, connects them and groups them. It does
    not set geometry or insert mechanisms.

    NOTE(review): sections are created without ``cell=``, so several
    instances yield sections with identical names.
    """

    def __init__(self, count_axonSegments: int):
        """Build the axon.

        Parameters
        ----------
        count_axonSegments : int
            Number of node-to-node units (equal to the number of internodes).
            Zero is allowed and yields hillock -> initial segment -> one node.

        Raises
        ------
        ValueError
            If ``count_axonSegments`` is negative.
        """
        # Reject negative counts before any NEURON section is created. Such a
        # count would leave nodes_axon empty and fail obscurely at nodes_axon[0].
        if count_axonSegments < 0:
            raise ValueError(
                f'count_axonSegments must be non-negative, got {count_axonSegments!r}'
            )

        # Defining number of sections
        self.count_axonSegments = count_axonSegments
        self.noOfNodes_axon = count_axonSegments + 1           # units share bounding nodes
        self.noOfParanodeSets_axon = 2 * count_axonSegments    # one set per internode side
        self.noOfJuxtaparanodeSets_axon = 2 * count_axonSegments
        self.noOfInternodes_axon = count_axonSegments

        # Creating sections (creation order is kept identical to the original layout)
        self.hill_axon = h.Section(name='hill_axon')
        self.iseg_axon = h.Section(name='iseg_axon')
        self.nodes_axon = self._make_sections('nodes_axon', self.noOfNodes_axon)
        self.paranodeOnes_axon = self._make_sections('paranodeOnes_axon', self.noOfParanodeSets_axon)
        self.paranodeTwos_axon = self._make_sections('paranodeTwos_axon', self.noOfParanodeSets_axon)
        self.paranodeThrees_axon = self._make_sections('paranodeThrees_axon', self.noOfParanodeSets_axon)
        self.paranodeFours_axon = self._make_sections('paranodeFours_axon', self.noOfParanodeSets_axon)
        self.juxtaparanodes_axon = self._make_sections('juxtaparanodes_axon', self.noOfJuxtaparanodeSets_axon)
        self.internodes_axon = self._make_sections('internodes_axon', self.noOfInternodes_axon)

        self._connect_sections()
        self._build_section_lists()

    @staticmethod
    def _make_sections(label, count):
        """Create ``count`` sections named ``label[0]`` .. ``label[count - 1]``."""
        return [h.Section(name=f'{label}[{index}]') for index in range(count)]

    def _connect_sections(self):
        """Wire hillock -> initial segment -> node 0, then chain every node-to-node unit."""
        self.iseg_axon.connect(self.hill_axon(1))
        self.nodes_axon[0].connect(self.iseg_axon(1))
        for unit in range(self.noOfInternodes_axon):
            # Paranode/juxtaparanode sets come in pairs per internode:
            # the even index sits on the proximal side, the odd index on the distal side.
            proximal, distal = 2 * unit, 2 * unit + 1
            chain = [
                self.nodes_axon[unit],
                self.paranodeOnes_axon[proximal],
                self.paranodeTwos_axon[proximal],
                self.paranodeThrees_axon[proximal],
                self.paranodeFours_axon[proximal],
                self.juxtaparanodes_axon[proximal],
                self.internodes_axon[unit],
                self.juxtaparanodes_axon[distal],
                self.paranodeFours_axon[distal],
                self.paranodeThrees_axon[distal],
                self.paranodeTwos_axon[distal],
                self.paranodeOnes_axon[distal],
                self.nodes_axon[unit + 1],
            ]
            # Attach each section to the 1 end of the section preceding it.
            for parent, child in zip(chain, chain[1:]):
                child.connect(parent(1))

    def _build_section_lists(self):
        """Create the NEURON SectionLists grouping sections by type and by region."""
        self.Nodes_axon = h.SectionList()
        self.ParanodeOnes_axon = h.SectionList()
        self.ParanodeTwos_axon = h.SectionList()
        self.ParanodeThrees_axon = h.SectionList()
        self.ParanodeFours_axon = h.SectionList()
        self.Paranodes_axon = h.SectionList()
        self.Juxtaparanodes_axon = h.SectionList()
        self.Internodes_axon = h.SectionList()
        self.TotalAxon = h.SectionList()
        self.ExposedAxon = h.SectionList()
        self.MyelinatedAxon = h.SectionList()

        # Nodes are unmyelinated (exposed).
        for sec in self.nodes_axon:
            self.Nodes_axon.append(sec)
            self.TotalAxon.append(sec)
            self.ExposedAxon.append(sec)

        # Paranodes: per-type lists, plus the combined lists in One/Two/Three/Four order per index.
        paranode_sets = zip(
            self.paranodeOnes_axon,
            self.paranodeTwos_axon,
            self.paranodeThrees_axon,
            self.paranodeFours_axon,
        )
        for one, two, three, four in paranode_sets:
            self.ParanodeOnes_axon.append(one)
            self.ParanodeTwos_axon.append(two)
            self.ParanodeThrees_axon.append(three)
            self.ParanodeFours_axon.append(four)
            for sec in (one, two, three, four):
                self.Paranodes_axon.append(sec)
                self.TotalAxon.append(sec)
                self.MyelinatedAxon.append(sec)

        for sec in self.juxtaparanodes_axon:
            self.Juxtaparanodes_axon.append(sec)
            self.TotalAxon.append(sec)
            self.MyelinatedAxon.append(sec)

        for sec in self.internodes_axon:
            self.Internodes_axon.append(sec)
            self.TotalAxon.append(sec)
            self.MyelinatedAxon.append(sec)

        # Hillock and initial segment are appended last, after all nodal/internodal sections.
        self.TotalAxon.append(self.hill_axon)
        self.TotalAxon.append(self.iseg_axon)
        self.ExposedAxon.append(self.hill_axon)
        self.ExposedAxon.append(self.iseg_axon)