from xml.etree import ElementTree as ET

neuroml_ns='http://morphml.org/neuroml/schema'
nml_ns='http://morphml.org/networkml/schema'
mml_ns='http://morphml.org/morphml/schema'
bio_ns='http://morphml.org/biophysics/schema'
cml_ns='http://morphml.org/channelml/schema'
meta_ns='http://morphml.org/metadata/schema'
xsi_ns='http://www.w3.org/2001/XMLSchema-instance'

## ElementTree parsing is an order of magnitude or more faster than minidom,
## BUT it does not keep the original namespace prefixes when writing XML back out.
## Registering the conventional NeuroML prefixes below makes serialized
## elements use neuroml:, nml:, mml:, bio:, cml:, meta: and xsi:.
def _register_namespace_prefix(uri, prefix):
    """Map namespace *uri* to *prefix* for ElementTree serialization.

    Prefers the public ET.register_namespace() (Python 2.7+ / 3.x). Falls back
    to writing the private ET._namespace_map directly (the original effbot
    recipe, http://effbot.org/zone/element-namespaces.htm) only on older
    ElementTree releases that lack the public function.
    """
    register = getattr(ET, "register_namespace", None)
    if register is not None:
        register(prefix, uri)
    else:
        ET._namespace_map[uri] = prefix

_register_namespace_prefix(neuroml_ns, 'neuroml')
_register_namespace_prefix(nml_ns, 'nml')
_register_namespace_prefix(mml_ns, 'mml')
_register_namespace_prefix(bio_ns, 'bio')
_register_namespace_prefix(cml_ns, 'cml')
_register_namespace_prefix(meta_ns, 'meta')
_register_namespace_prefix(xsi_ns, 'xsi')

def set_neuroml_namespaces_attribs(neuromlroot):
    """Add NeuroML namespace declarations as literal attributes on *neuromlroot*.

    Each xmlns/xmlns:<prefix> attribute is set only when the root does not
    already carry it (via set_attrib_if_not_found), in a fixed order.
    """
    declarations = (
        ("xmlns", neuroml_ns),
        ("xmlns:nml", nml_ns),
        ("xmlns:mml", mml_ns),
        ("xmlns:bio", bio_ns),
        ("xmlns:cml", cml_ns),
        ("xmlns:meta", meta_ns),
        ## xmlns:xsi is deliberately NOT declared here: a later doc.write()
        ## emits the xsi namespace itself, and declaring it here too produces
        ## a double definition that breaks xsltproc.
    )
    for attrname, uri in declarations:
        set_attrib_if_not_found(neuromlroot, attrname, uri)

def set_attrib_if_not_found(elem, name, value):
    """Set attribute *name* on *elem* to *value* unless it already has one.

    An attribute counts as present whenever elem.get(name) returns
    something other than None; existing values are never overwritten.
    """
    if elem.get(name) is not None:
        return
    elem.set(name, value)

def tweak_model(root_element, params):
    """Apply the optional pruning steps requested in *params* to *root_element*, in place.

    Recognised keys of *params* (all optional):
      'excludePopulations', 'excludeProjections' : lists of name substrings
          handed to pruneExcludes(). Either may be given without the other;
          a missing one is treated as an empty list (previously a lone
          'excludePopulations' raised KeyError and a lone
          'excludeProjections' was silently ignored).
      'onlyInclude' : handed unchanged to keepOnlyInclude().
    """
    if 'excludePopulations' in params or 'excludeProjections' in params:
        pruneExcludes(root_element,
                      params.get('excludePopulations', []),
                      params.get('excludeProjections', []))
    if 'onlyInclude' in params:
        keepOnlyInclude(root_element, params['onlyInclude'])

def pruneExcludes(network, excludepops, excludeprojs):
    """
    Remove excluded populations and projections from *network*, in place.

    A <population> whose 'name' attribute contains any string of
    *excludepops* as a SUBSTRING is removed from <populations>; likewise a
    <projection> matching any string of *excludeprojs* is removed from
    <projections>. A missing <populations>/<projections> container means
    there is nothing to exclude and is skipped (previously AttributeError).
    """
    _remove_name_substring_matches(network, "populations", "population", excludepops)
    _remove_name_substring_matches(network, "projections", "projection", excludeprojs)

def _remove_name_substring_matches(network, containertag, itemtag, names):
    """Remove each <itemtag> under the first <containertag> whose name contains any of *names*."""
    container = network.find(".//{"+nml_ns+"}"+containertag)
    if container is None:
        return
    ## collect first and remove afterwards: never mutate while iterating.
    ## any() marks each element at most once, so no de-duplication is needed.
    doomed = [item for item in container.findall(".//{"+nml_ns+"}"+itemtag)
              if any(name in item.attrib["name"] for name in names)]
    for item in doomed:
        container.remove(item)

def keepOnlyInclude(network, onlyInclude):
    """
    Keep only the cells that are in onlyInclude['includePopulation']
    and also keep cells that are connected to cells in onlyInclude['includePopulation']
    and keep connections to any of the cells in onlyInclude['includePopulation'].
    Prune the extraneous connections
    but keep those connections in onlyInclude['includeProjections']
    on cells connected to those in onlyInclude['includePopulation']

    onlyInclude['includePopulation'] : (population_name, cell_ids) pair.
        cell_ids are compared against XML attribute values, so they must be
        strings to match. The caller's collection is copied, never modified.
    onlyInclude['includeProjections'] : iterable of projection-name substrings.

    The network is modified in place. The 'size' attribs of <instances> and
    <connections> are NOT updated: this reader does not use them and the
    pruned network is not saved.
    """
    includepopname = onlyInclude['includePopulation'][0]
    ## Work on a private set: fast membership tests, and the user's list is
    ## neither mutated nor grown while connections are still being scanned
    ## (growing it would cascade to third-order cells in self-projections).
    includecellids = set(onlyInclude['includePopulation'][1])
    ## first of all, include those primary cells that the user instructs.
    includeCellsDict = {includepopname: set(includecellids)}

    includeProjs = _prune_unrelated_connections(
        network, onlyInclude['includeProjections'],
        includepopname, includecellids, includeCellsDict)

    print("removing extra cells ... ")
    _prune_unincluded_cells(network, includeCellsDict)

    ### Prune the 'includeProjections' that we skipped pruning before,
    ### while keeping connections to second order cells!
    _prune_include_projections(includeProjs, includeCellsDict)

def _prune_unrelated_connections(network, includeprojnames, includepopname,
                                 includecellids, includeCellsDict):
    """First pass: drop connections not touching a user-included cell.

    Cells on the far end of a kept connection are recorded in
    includeCellsDict (population name -> set of cell ids). Projections whose
    name contains one of *includeprojnames* are not pruned here; they are
    returned so they can be pruned later against the complete cell set.
    """
    includeProjs = []
    print("removing obviously extra connections in ... ")
    for projection in network.findall(".//{"+nml_ns+"}projection"):
        projname = projection.attrib['name']
        ## substring match against any of the given includeProjections names
        includeProj = any(includeprojname in projname
                          for includeprojname in includeprojnames)
        if includeProj:
            includeProjs.append(projection)
        source = projection.attrib["source"]
        target = projection.attrib["target"]
        print("%s %s %s" % (projname, source, target))
        connections = projection.find(".//{"+nml_ns+"}connections")
        if connections is None:
            continue
        for connection in connections.findall(".//{"+nml_ns+"}connection"):
            pre_cell_id = connection.attrib['pre_cell_id']
            post_cell_id = connection.attrib['post_cell_id']
            if includepopname == source and pre_cell_id in includecellids:
                ## source is included, so keep the target cell too
                ## (self connections within one population are handled the same way)
                includeCellsDict.setdefault(target, set()).add(post_cell_id)
            elif includepopname == target and post_cell_id in includecellids:
                ## target is included, so keep the source cell too,
                ## except when the source is a file
                if 'file' not in source:
                    includeCellsDict.setdefault(source, set()).add(pre_cell_id)
            elif not includeProj:
                ## extraneous, and not deferred to the includeProjections pass
                connections.remove(connection)
    return includeProjs

def _prune_unincluded_cells(network, includeCellsDict):
    """Remove instances (and whole populations) absent from includeCellsDict."""
    populations = network.find(".//{"+nml_ns+"}populations")
    for population in network.findall(".//{"+nml_ns+"}population"):
        popname = population.attrib["name"]
        if popname not in includeCellsDict:
            ## this whole population is not required!
            populations.remove(population)
            continue
        keepids = includeCellsDict[popname]
        instances = population.find(".//{"+nml_ns+"}instances")
        if instances is None:
            continue
        for instance in instances.findall(".//{"+nml_ns+"}instance"):
            ## not a connected cell, so remove
            if instance.attrib['id'] not in keepids:
                instances.remove(instance)

def _prune_include_projections(includeProjs, includeCellsDict):
    """Keep only connections of *includeProjs* whose both ends are included cells.

    A source population whose name contains 'file' counts as included.
    """
    for projection in includeProjs:
        print("removing projection %s keeping second-order connections."
              % projection.attrib['name'])
        source = projection.attrib["source"]
        target = projection.attrib["target"]
        file_in_source = 'file' in source
        sourceids = includeCellsDict.get(source)
        targetids = includeCellsDict.get(target)
        connections = projection.find(".//{"+nml_ns+"}connections")
        if connections is None:
            continue
        for connection in connections.findall(".//{"+nml_ns+"}connection"):
            includecellinsource = file_in_source or (
                sourceids is not None
                and connection.attrib['pre_cell_id'] in sourceids)
            includecellintarget = (
                targetids is not None
                and connection.attrib['post_cell_id'] in targetids)
            ## extraneous if either end is not an included cell
            if not (includecellinsource and includecellintarget):
                ## remove is a very slow operation!
                connections.remove(connection)

## in-place pretty-print formatter adapted from http://effbot.org/zone/element-lib.htm
## usage: call indent(root) first (level defaults to 0), then doc.write(filename)
def indent(elem, level=0):
    """Add whitespace-only text/tails so *elem* serializes with 2-space indentation.

    Only whitespace-only (or missing) text and tails are replaced; real
    content is left untouched. The last child's tail is reset to the parent's
    indentation so the parent's closing tag lines up with its opening tag.
    """
    newline_pad = "\n" + "  " * level

    def _is_blank(text):
        return not text or not text.strip()

    children = list(elem)
    if not children:
        ## leaf: only non-root leaves get a trailing newline + indentation
        if level and _is_blank(elem.tail):
            elem.tail = newline_pad
        return

    if _is_blank(elem.text):
        elem.text = newline_pad + "  "
    if _is_blank(elem.tail):
        elem.tail = newline_pad
    for child in children:
        indent(child, level + 1)
    last_child = children[-1]
    if _is_blank(last_child.tail):
        last_child.tail = newline_pad