# ConnPlotter --- A Tool to Generate Connectivity Pattern Matrices
#
# This file is part of ConnPlotter.
#
# Copyright (C) 2009 Hans Ekkehard Plesser/UMB
#
# ConnPlotter is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# ConnPlotter is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ConnPlotter.  If not, see <http://www.gnu.org/licenses/>.

"""
ConnPlotter is a tool to create connectivity pattern tables.

For background on ConnPlotter, please see

  Eilen Nordlie and Hans Ekkehard Plesser.
  Connection Pattern Tables: A new way to visualize connectivity
  in neuronal network models.
  Frontiers in Neuroinformatics 3:39 (2010)
  doi: 10.3389/neuro.11.039.2009

Example:
# code creating population and connection lists

from ConnPlotter import ConnectionPattern, SynType

# Case A: All connections have the same "synapse_model".
#
# Connections with weight >= 0 are classified as excitatory,
#                  weight <  0 are classified as inhibitory.
# Each sender must make either excitatory or inhibitory connection,
# not both. When computing totals, excit/inhib connections are
# weighted with +-1.
pattern = ConnectionPattern(layerList, connList)

# Case B: All connections have the same "synapse_model", but violate Dale's law
#
# Connections with weight >= 0 are classified as excitatory,
#                  weight <  0 are classified as inhibitory.
# A single sender may have excitatory and inhibitory connections.
# When computing totals, excit/inhib connections are
# weighted with +-1.
pattern = ConnectionPattern(layerList, connList, 
                            synTypes=(((SynType('exc',  1.0, 'b'),
                                        SynType('inh', -1.0, 'r')),)))

# Case C: Synapse models are "AMPA", "NMDA", "GABA_A", "GABA_B".
# 
# Connections are plotted by synapse model, with AMPA and NMDA
# on the top row, GABA_A and GABA_B in the bottom row when
# combining by layer. Senders must either have AMPA and NMDA or 
# GABA_A and GABA_B synapses, but not both. When computing totals, 
# AMPA and NMDA connections are weighted with +1, GABA_A and GABA_B
# with -1.
pattern = ConnectionPattern(layerList, connList)

# Case D: Explicit synapse types.
#
# If your network model uses other synapse types, or you want to use
# other weighting factors when computing totals, or you want different
# colormaps, you must specify synapse type information explicitly for
# ALL synapse models in your network. For each synapse model, you create
# a
#
#        SynType(name, tweight, cmap)
#
# object, where "name" is the synapse model name, "tweight" the weight
# to be given to the type when computing totals (usually >0 for excit, 
# <0 for inhib synapses), and "cmap" the "colormap": it may be a
# matplotlib.colors.Colormap instance or any valid matplotlib color
# specification; in the latter case, a colormap will be generated
# ranging from white to the given color.
# Synapse types are passed as a tuple of tuples. Synapses in a tuple form
# a group. ConnPlotter assumes that a sender may make synapses with all
# types in a single group, but never synapses with types from different
# groups (If you group by transmitter, this simply reflects Dale's law).
# When connections are aggregated by layer, each group is printed on one 
# row.
pattern = ConnectionPattern(layerList, connList, synTypes = \
    ((SynType('Asyn',  1.0, 'orange'),
      SynType('Bsyn',  2.5, 'r'),
      SynType('Csyn',  0.5, (1.0, 0.5, 0.0))),  # end first group
     (SynType('Dsyn', -1.5, matplotlib.pylab.cm.jet),
      SynType('Esyn', -3.2, '0.95'))))
# See documentation of class ConnectionPattern for more options.

# plotting the pattern

# show connection kernels for all sender-target pairs and all synapse models
pattern.plot()

# combine synapses of all types for each sender-target pair
# always used red-blue (inhib-excit) color scale
pattern.plot(aggrSyns=True)   

# for each sender-target layer pair, show sums for each synapse type
pattern.plot(aggrGroups=True)

# As mode layer, but combine synapse types.
# always used red-blue (inhib-excit) color scale
pattern.plot(aggrSyns=True, aggrGroups=True)

# Show only synapses of the selected type(s)
pattern.plot(mode=('AMPA',))
pattern.plot(mode=('AMPA', 'GABA_A'))

# use same color scales for all patches
pattern.plot(globalColors=True)  

# manually specify limits for global color scale
pattern.plot(globalColors=True, colorLimits=[0, 2.5])  

# save to file(s)
# NB: do not write to PDF directly, this seems to cause artifacts
pattern.plot(file='net.png')  
pattern.plot(file=('net.eps','net.png'))  

# You can adjust some properties of the figure by changing the
# default values in plotParams. 

# Experimentally, you can dump the connection pattern into a LaTeX table
pattern.toLaTeX('pattern.tex', standalone=True)

# Figure layout can be modified by changing the global variable plotParams.
# Please see the documentation for class PlotParams for details.

# Changes 30 June 2010:
# - Singular layers (extent 0x0) are ignored as target layers.
#   The reason for this is so that single-generator "layers" can be
#   displayed as input.
#   Problems:
#   - singularity is not made clear visually
#   - This messes up the diagonal shading
#   - makes no sense to aggregate any longer
"""

# Module metadata. The $Revision$ and $Date$ strings look like
# version-control keyword expansions and are stored verbatim.
__version__ = '$Revision: 546 $'
__date__    = '$Date: 2010-06-30 16:36:33 +0200 (Wed, 30 Jun 2010) $'
__author__  = 'Hans Ekkehard Plesser'

# Public names exported by a star-import of this module.
__all__ = ['ConnectionPattern', 'SynType', 'plotParams', 'PlotParams']

# ----------------------------------------------------------------------------

# To do:
# - proper testsuite
# - layers of different sizes not handled properly 
#   (find biggest layer extent in each direction, then center;
#    may run into problems with population label placement)
# - clean up main
# - color bars
# - "bad color" should be configurable
# - fix hack for colormaps import
# - use generators where possible (eg kernels?)

# ----------------------------------------------------------------------------

# The next is a hack that helps me during development (allows run ConnPlotter),
# should find a better solution.
if __name__ == "__main__":
    import colormaps as cm
else:
    from . import colormaps as cm

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import warnings

# ----------------------------------------------------------------------------

class SynType(object):
    """
    Rendering description for one synapse type.

    SynType objects, arranged in a singly nested sequence, can be handed
    to the ConnectionPattern constructor to control layout and rendering.
    """

    def __init__(self, name, relweight, cmap):
        """
        Arguments:
        name         Name of the synapse type (string, must be unique)
        relweight    Weight given to this synapse type when kernels are
                     aggregated across synapse types. Should be negative
                     for inhibitory connections.
        cmap         A matplotlib.colors.Colormap instance, or a color
                     specification. A color specification is turned into
                     a colormap running from white to that color, so the
                     color should be fully saturated. Colormaps should
                     have "set_bad(color='white')".
        """

        self.name = name
        self.relweight = relweight

        # Ready-made colormaps are stored untouched; everything else is
        # passed to the colormaps helper module to build a colormap.
        is_colormap = isinstance(cmap, mpl.colors.Colormap)
        self.cmap = cmap if is_colormap else cm.make_colormap(cmap)

# ----------------------------------------------------------------------------

class PlotParams(object):
    """
    Collects parameters governing plotting.

    All parameters are exposed as properties. Most of them can be assigned
    to, and several setters validate the new value. The z-order values
    (z_layer, z_pop, z_conn) have no setter and are therefore read-only.
    The outer margins are available through the "margins" attribute, a
    PlotParams.Margins instance.
    """

    class Margins(object):
        """Width of outer margins, in mm. Assigned values are stored as float."""

        def __init__(self):
            """Set default values."""
            self._left = 15.0
            self._right = 10.0
            self._top = 10.0
            self._bottom = 10.0
            self._colbar = 10.0

        @property
        def left(self):
            """Left margin."""
            return self._left

        @left.setter
        def left(self, l):
            self._left = float(l)

        @property
        def right(self):
            """Right margin."""
            return self._right

        @right.setter
        def right(self, r):
            self._right = float(r)

        @property
        def top(self):
            """Top margin."""
            return self._top

        @top.setter
        def top(self, t):
            self._top = float(t)

        @property
        def bottom(self):
            """Bottom margin."""
            return self._bottom

        @bottom.setter
        def bottom(self, b):
            self._bottom = float(b)

        @property
        def colbar(self):
            """Colorbar margin (cbheight is given relative to this value)."""
            return self._colbar

        @colbar.setter
        def colbar(self, b):
            self._colbar = float(b)

    def __init__(self):
        """Set default values"""
        self._n_kern = 100
        self._patch_size = 20.0  # 20 mm
        self._layer_bg = {'super': '0.9', 'diag': '0.8', 'sub': '0.9'}
        self._layer_font = mpl.font_manager.FontProperties(size='large')
        self._layer_orient = {'sender': 'horizontal', 'target': 'horizontal'}
        self._pop_font = mpl.font_manager.FontProperties(size='small')
        self._pop_orient = {'sender': 'horizontal', 'target': 'horizontal'}
        self._lgd_tick_font = mpl.font_manager.FontProperties(size='x-small')
        self._lgd_title_font = mpl.font_manager.FontProperties(size='xx-small')
        self._lgd_ticks = None
        self._lgd_tick_fmt = None
        self._lgd_location = None
        self._cbwidth = None
        self._cbspace = None
        self._cbheight = None
        self._cboffset = None
        self._z_layer = 25
        self._z_pop = 50
        self._z_conn = 100
        self.margins = self.Margins()

    def reset(self):
        """
        Reset to default values.
        """
        self.__init__()

    @staticmethod
    def _merged_orientation(current, orient):
        """
        Return the orientation dictionary resulting from assigning orient.

        current: orientation dictionary in effect before the assignment
        orient : string or number (applied to both sender and target), or
                 a dictionary that is merged into a copy of current

        Raises ValueError if orient has an unsupported type or if the
        merged dictionary would hold more than two entries. The merge is
        done on a copy, so current is left untouched when the new value
        is rejected.
        """
        if isinstance(orient, (str, float, int)):
            return {'sender': orient, 'target': orient}
        if not isinstance(orient, dict):
            raise ValueError('Orientation ust be set to dict, string or number.')

        merged = dict(current)
        merged.update(orient)
        if len(merged) > 2:
            raise ValueError('Orientation dictionary can only contain keys "sender" and "target".')
        return merged

    @property
    def n_kern(self):
        """Sample long kernel dimension at N_kernel points."""
        return self._n_kern

    @n_kern.setter
    def n_kern(self, n):
        if n <= 0:
            raise ValueError('n_kern > 0 required')
        self._n_kern = n

    @property
    def patch_size(self):
        """Length of the longest edge of the largest patch, in mm."""
        return self._patch_size

    @patch_size.setter
    def patch_size(self, sz):
        if sz <= 0:
            raise ValueError('patch_size > 0 required')
        self._patch_size = sz

    @property
    def layer_bg(self):
        """
        Dictionary of colors for layer background.
        Entries "super", "diag", "sub". Each entry
        can be set to any valid color specification.
        If just a color is given, create dict by
        brightening/dimming.
        """
        return self._layer_bg

    @layer_bg.setter
    def layer_bg(self, bg):
        if isinstance(bg, dict):
            if set(bg.keys()) != set(('super', 'diag', 'sub')):
                raise ValueError('Background dict must have keys "super", "diag", "sub"')
            for bgc in bg.values():
                if not mpl.colors.is_color_like(bgc):
                    raise ValueError('Entries in background dict must be valid color specifications.')
            self._layer_bg = bg
        elif not mpl.colors.is_color_like(bg):
            raise ValueError('layer_bg must be dict or valid color specification.')
        else:  # is color like
            rgb = mpl.colors.colorConverter.to_rgb(bg)
            # Brightening is capped at 1.0: RGB components above 1 are
            # not a valid color specification (e.g. for bg = 'white').
            self._layer_bg = {'super': [min(1.0, 1.1 * c) for c in rgb],
                              'diag': rgb,
                              'sub': [0.9 * c for c in rgb]}

    @property
    def layer_font(self):
        """
        Font to use for layer labels.
        Can be set to a matplotlib.font_manager.FontProperties instance.
        """
        return self._layer_font

    @layer_font.setter
    def layer_font(self, font):
        if not isinstance(font, mpl.font_manager.FontProperties):
            raise ValueError('layer_font must be a matplotlib.font_manager.FontProperties instance.')
        self._layer_font = font

    @property
    def layer_orientation(self):
        """
        Orientation of layer labels.
        Dictionary with orientation of sender and target labels. Orientation
        is either 'horizontal', 'vertical', or a value in degrees.  When set
        to a single string or number, this value is used for both sender and
        target labels. When set to a dictionary, its entries are merged into
        the current setting.
        """
        return self._layer_orient

    @layer_orientation.setter
    def layer_orientation(self, orient):
        self._layer_orient = self._merged_orientation(self._layer_orient, orient)

    @property
    def pop_font(self):
        """
        Font to use for population labels.
        Can be set to a matplotlib.font_manager.FontProperties instance.
        """
        return self._pop_font

    @pop_font.setter
    def pop_font(self, font):
        if not isinstance(font, mpl.font_manager.FontProperties):
            raise ValueError('pop_font must be a matplotlib.font_manager.FontProperties instance.')
        self._pop_font = font

    @property
    def pop_orientation(self):
        """
        Orientation of population labels.
        Dictionary with orientation of sender and target labels. Orientation
        is either 'horizontal', 'vertical', or a value in degrees.  When set
        to a single string or number, this value is used for both sender and
        target labels. When set to a dictionary, its entries are merged into
        the current setting.
        """
        return self._pop_orient

    @pop_orientation.setter
    def pop_orientation(self, orient):
        self._pop_orient = self._merged_orientation(self._pop_orient, orient)

    @property
    def legend_tick_font(self):
        """
        FontProperties for legend (colorbar) ticks.
        """
        return self._lgd_tick_font

    @legend_tick_font.setter
    def legend_tick_font(self, font):
        if not isinstance(font, mpl.font_manager.FontProperties):
            raise ValueError('legend_tick_font must be a matplotlib.font_manager.FontProperties instance.')
        self._lgd_tick_font = font

    @property
    def legend_title_font(self):
        """
        FontProperties for legend (colorbar) titles.
        """
        return self._lgd_title_font

    @legend_title_font.setter
    def legend_title_font(self, font):
        if not isinstance(font, mpl.font_manager.FontProperties):
            raise ValueError('legend_title_font must be a matplotlib.font_manager.FontProperties instance.')
        self._lgd_title_font = font

    @property
    def legend_ticks(self):
        """
        Ordered list of values at which legend (colorbar) ticks shall be set.
        """
        return self._lgd_ticks

    @legend_ticks.setter
    def legend_ticks(self, ticks):
        self._lgd_ticks = ticks

    @property
    def legend_tick_format(self):
        """
        C-style format string for legend (colorbar) tick marks.
        """
        return self._lgd_tick_fmt

    @legend_tick_format.setter
    def legend_tick_format(self, tickfmt):
        self._lgd_tick_fmt = tickfmt

    @property
    def legend_location(self):
        """
        If set to 'top', place legend label above colorbar,
        if None, to the left.
        """
        return self._lgd_location

    @legend_location.setter
    def legend_location(self, loc):
        self._lgd_location = loc

    @property
    def cbwidth(self):
        """
        Width of single colorbar, relative to figure width.
        """
        return self._cbwidth

    @cbwidth.setter
    def cbwidth(self, cbw):
        self._cbwidth = cbw

    @property
    def cbheight(self):
        """
        Height of colorbar, relative to margins.colbar
        """
        return self._cbheight

    @cbheight.setter
    def cbheight(self, cbh):
        self._cbheight = cbh

    @property
    def cbspace(self):
        """
        Spacing between colorbars, relative to figure width.
        """
        return self._cbspace

    @cbspace.setter
    def cbspace(self, cbs):
        self._cbspace = cbs

    @property
    def cboffset(self):
        """
        Left offset of colorbar, relative to figure width.
        """
        return self._cboffset

    @cboffset.setter
    def cboffset(self, cbo):
        self._cboffset = cbo

    @property
    def z_layer(self):
        """Z-value for layer label axes."""
        return self._z_layer

    @property
    def z_pop(self):
        """Z-value for population label axes."""
        return self._z_pop

    @property
    def z_conn(self):
        """Z-value for connection kernel axes."""
        return self._z_conn
    
# ----------------------------------------------------------------------------

# Module-level PlotParams instance holding the plotting settings, created
# with the default values. The module docstring says figure layout is
# adjusted by changing this object.
plotParams = PlotParams()
        
# ----------------------------------------------------------------------------

class ConnectionPattern(object):
    """
    Connection pattern representation for plotting.

    When a ConnectionPattern is instantiated, all connection kernels
    are pre-computed. They can later be plotted in various forms by
    calling the plot() method.

    The constructor requires layer and connection lists:

       ConnectionPattern(layerList, connList, synTypes, **kwargs)

    The layerList is used to:
       - determine the size of patches
       - determine the block structure

    All other information is taken from the connList. Information
    about synapses is inferred from the connList.

    The following keyword arguments can also be given:

    poporder :  Population order. A dictionary mapping population names 
                to numbers; populations will be sorted in diagram in order 
                of increasing numbers. Otherwise, they are sorted 
                alphabetically.

    intensity:  'wp'  -  use weight * probability (default)
                'p'   -  use probability alone 
                'tcd' -  use total charge deposited * probability
                         requires mList and Vmem; per v 0.7 only supported
                         for ht_neuron.

    mList    :  model list; required for 'tcd'
    Vmem     :  membrane potential; required for 'tcd'
    """

    # ------------------------------------------------------------------------
    
    class _LayerProps(object):
        """
        Name and spatial extent of a layer.
        """
        def __init__(self, name, extent):
            """
            name  : name of layer
            extent: spatial extent of the layer; its first two entries
                    are inspected to detect singular (0 x 0) layers
            """
            self.name = name
            self.ext = extent

            # A layer is singular when both extent components are zero.
            # The second component is only looked at if the first is zero.
            first_is_zero = extent[0] == 0.0
            self.singular = (extent[1] == 0.0) if first_is_zero else first_is_zero

    # ------------------------------------------------------------------------
    
    class _SynProps(object):
        """
        Plotting information for the patches of one synapse type.
        """
        def __init__(self, row, col, tweight, cmap, idx):
            """
            row, col: position of the synapse type in the grid of synapse
                      patches, counting from 0,0
            tweight : weight applied when kernels of different synapse
                      types are added
            cmap    : colormap for the synapse type
                      (matplotlib.colors.Colormap instance)
            idx     : linear index, used to order colorbars in figure
            """
            self.r = row
            self.c = col
            self.tw = tweight
            self.cmap = cmap
            self.index = idx

    # --------------------------------------------------------------------

    class _PlotKern(object):
        """
        A connection kernel together with the labels needed to plot it.
        """

        def __init__(self, sl, sn, tl, tn, syn, kern):
            """
            sl  : sender layer
            sn  : sender neuron/population
            tl  : target layer
            tn  : target neuron/population
            syn : synapse model
            kern: kernel values (numpy masked array)
            All arguments but kern are strings.
            """
            # sender side
            self.sl, self.sn = sl, sn
            # target side
            self.tl, self.tn = tl, tn
            # synapse model and the kernel itself
            self.syn, self.kern = syn, kern
            
    # ------------------------------------------------------------------------

    class _Connection(object):
        """
        One entry of the connection list, prepared for plotting.

        If the target layer is singular (extent 0x0), only the attributes
        slayer, tlayer and singular are set; all other attributes and the
        kernel-related properties are unavailable for such connections.
        """

        def __init__(self, conninfo, layers, synapses, intensity, tcd, Vmem):
            """
            Arguments:
            conninfo : connection info entry: (sender, target, conn_dict)
            layers   : list of layer objects providing name, ext and
                       singular attributes (_LayerProps)
            synapses : mapping keyed by synapse model name; if its keys are
                       exactly 'exc' and 'inh', the synapse model is derived
                       from the sign of the mean weight
            intensity: 'wp', 'p', 'tcd'
            tcd      : callable, invoked as tcd(synmodel, tnrn, Vmem) when
                       intensity is 'tcd'; otherwise unused
            Vmem     : reference membrane potential for tcd calculations

            Raises:
            Exception  for unknown layers, a missing or unknown synapse
                       model, or a connection dictionary lacking mask,
                       kernel or weights
            ValueError for source/target specifications not of the form
                       {"model": ...} and for unusable weight information
            """

            self._intensity = intensity

            # get source and target layer
            self.slayer, self.tlayer = conninfo[:2]
            lnames = [l.name for l in layers]

            if self.slayer not in lnames:
                raise Exception('Unknown source layer "%s".' % self.slayer)
            if self.tlayer not in lnames:
                raise Exception('Unknown target layer "%s".' % self.tlayer)

            # if target layer is singular (extent==(0,0)) we do not create a full object
            self.singular = False
            for l in layers:
                if l.name == self.tlayer and l.singular:
                    self.singular = True
                    return

            # see if we connect to/from specific neuron types
            cdict = conninfo[2]
            self.snrn = self._neuron_model(cdict, 'sources')
            self.tnrn = self._neuron_model(cdict, 'targets')

            # now get (mean) weight, we need this if we classify
            # connections by sign of weight only
            try:
                self._mean_wght = _weighteval(cdict['weights'])
            except Exception:
                raise ValueError('No or corrupt weight information.')

            # synapse model
            if sorted(synapses.keys()) == ['exc', 'inh']:
                # implicit synapse type, we ignore value of
                # 'synapse_model', it is for use by NEST only
                if self._mean_wght >= 0:
                    self.synmodel = 'exc'
                else:
                    self.synmodel = 'inh'
            else:
                try:
                    self.synmodel = cdict['synapse_model']
                except KeyError:
                    raise Exception('Explicit synapse model info required.')
                # Checked outside the try block so that this message is
                # not replaced by the "info required" one.
                if self.synmodel not in synapses:
                    raise Exception('Unknown synapse model "%s".'
                                    % self.synmodel)

            # store information about connection
            try:
                self._mask = cdict['mask']
                self._kern = cdict['kernel']
                self._wght = cdict['weights']
                # next line presumes only one layer name will match
                self._textent = [tl.ext for tl in layers
                                 if tl.name == self.tlayer][0]
                if intensity == 'tcd':
                    self._tcd = tcd(self.synmodel, self.tnrn, Vmem)
                else:
                    self._tcd = None
            except Exception:
                raise Exception('Corrupt connection dictionary')

            # prepare for lazy evaluation
            self._kernel = None

        # --------------------------------------------------------------------

        @staticmethod
        def _neuron_model(cdict, key):
            """
            Return the neuron model named under cdict[key], or None.

            cdict: connection dictionary
            key  : 'sources' or 'targets'

            Returns None if key is absent from cdict. Raises ValueError
            unless cdict[key] has 'model' as its only key.
            """
            if key not in cdict:
                return None
            # list() is required: on Python 3, keys() is a view object
            # that never compares equal to a list.
            if list(cdict[key].keys()) == ['model']:
                return cdict[key]['model']
            raise ValueError('Can only handle %s in form {"model": ...}' % key)

        # --------------------------------------------------------------------

        @property
        def keyval(self):
            """
            Return key and _Connection as tuple.
            Useful to create dictionary via list comprehension.
            The key is None for singular connections, otherwise the tuple
            (slayer, snrn, tlayer, tnrn, synmodel).
            """
            if self.singular:
                return (None, self)
            else:
                return ((self.slayer, self.snrn, self.tlayer, self.tnrn, self.synmodel),
                        self)

        # --------------------------------------------------------------------

        @property
        def kernval(self):
            """Kernel value, as masked array. Computed on first access, then cached."""
            if self._kernel is None:
                self._kernel = _evalkernel(self._mask, self._kern, self._mean_wght,
                                           self._textent, self._intensity,
                                           self._tcd)
            return self._kernel

        # --------------------------------------------------------------------

        @property
        def mask(self):
            """Dictionary describing the mask."""
            return self._mask

        # --------------------------------------------------------------------

        @property
        def kernel(self):
            """Dictionary describing the kernel."""
            return self._kern

        # --------------------------------------------------------------------

        @property
        def weight(self):
            """Dictionary describing weight distribution."""
            return self._wght

        # --------------------------------------------------------------------

        def matches(self, sl=None, sn=None, tl=None, tn=None, syn=None):
            """
            Return True if all non-None arguments match.
            Arguments:
            sl : sender layer
            sn : sender neuron type
            tl : target layer
            tn : target neuron type
            syn: synapse type
            """

            return  (sl  is None or sl  == self.slayer) \
                and (sn  is None or sn  == self.snrn) \
                and (tl  is None or tl  == self.tlayer) \
                and (tn  is None or tn  == self.tnrn) \
                and (syn is None or syn == self.synmodel)

    # ----------------------------------------------------------------------------

    class _Patch(object):
        """
        A patch is an axes that will eventually hold the imshow graphic of
        one connection kernel.

        The object stores the physical coordinates of the patch, the
        sender/target label strings, and a reference to the Axes object
        (None until it is assigned).

        All coordinates refer to a system with its origin in the top left
        corner and the vertical axis pointing down.
        """

        # --------------------------------------------------------------------

        def __init__(self, left, top, row, col, width, height,
                     slabel=None, tlabel=None, parent=None):
            """
            Arguments:
            left, top     : Location of top-left corner
            row, col      : row, column location in parent block
            width, height : Width and height of patch
            slabel, tlabel: Values for sender/target label
            parent        : _Block to which _Patch/_Block belongs
            """
            # geometry
            self.l, self.t = left, top
            self.w, self.h = width, height
            # grid position inside the parent block
            self.r, self.c = row, col
            # labels
            self.slbl = slabel
            self.tlbl = tlabel
            # bookkeeping
            self._parent = parent
            self.ax = None

        # --------------------------------------------------------------------

        def _update_size(self, new_lr):
            """
            Resize so that the lower right corner is at new_lr, then pass
            new_lr on to the parent, if there is one.

            Raises ValueError if new_lr lies left of or above the top left
            corner.
            """
            right = new_lr[0]
            if right < self.l:
                raise ValueError("new_lr[0] = %f < l = %f" % (right, self.l))
            bottom = new_lr[1]
            if bottom < self.t:
                raise ValueError("new_lr[1] = %f < t = %f" % (bottom, self.t))

            self.w = right - self.l
            self.h = bottom - self.t

            if self._parent:
                self._parent._update_size(new_lr)

        # --------------------------------------------------------------------

        @property
        def tl(self):
            """Top left corner as (left, top)."""
            return (self.l, self.t)

        # --------------------------------------------------------------------

        @property
        def lr(self):
            """Lower right corner as (right, bottom)."""
            right = self.l + self.w
            bottom = self.t + self.h
            return (right, bottom)

        # --------------------------------------------------------------------

        @property
        def l_patches(self):
            """Left edge of leftmost _Patch in _Block."""
            if not isinstance(self, ConnectionPattern._Block):
                # plain patch: its own left edge
                return self.l
            return min([el.l_patches for el in _flattened(self.elements)])

        # --------------------------------------------------------------------

        @property
        def t_patches(self):
            """Top edge of topmost _Patch in _Block."""
            if not isinstance(self, ConnectionPattern._Block):
                # plain patch: its own top edge
                return self.t
            return min([el.t_patches for el in _flattened(self.elements)])

        # --------------------------------------------------------------------

        @property
        def r_patches(self):
            """Right edge of rightmost _Patch in _Block."""
            if not isinstance(self, ConnectionPattern._Block):
                # plain patch: its own right edge
                return self.l + self.w
            return max([el.r_patches for el in _flattened(self.elements)])

        # --------------------------------------------------------------------

        @property
        def b_patches(self):
            """Bottom edge of lowest _Patch in _Block."""
            if not isinstance(self, ConnectionPattern._Block):
                # plain patch: its own bottom edge
                return self.t + self.h
            return max([el.b_patches for el in _flattened(self.elements)])

        # --------------------------------------------------------------------

        @property
        def location(self):
            """
            Position relative to the diagonal of the parent grid: 'super'
            if row < column, 'diag' if row == column, 'sub' otherwise.
            """
            if self.r < self.c:
                return 'super'
            if self.r == self.c:
                return 'diag'
            return 'sub'

    # ----------------------------------------------------------------------------

    class _Block(_Patch):
        """
        Represents a block of patches. 

        A block is initialized with its top left corner and is then built
        row-wise downward and column-wise to the right. Rows are added by

        block.newRow(2.0, 1.5)

        where 2.0 is the space between rows, 1.5 the space before the
        first row. Elements are added to a row by

        el = block.newElement(1.0, 0.6, 's', 't')
        el = block.newElement(1.0, 0.6, 's', 't', size=[2.0, 3.0])

        The first example adds a new _Block to the row. 1.0 is the space between
        blocks, 0.6 the space before the first block in a row. 's' and 't' are
        stored as slbl and tlbl (optional). If size is given, a _Patch with
        the given size is created. _Patch is atomic. newElement() returns the
        _Block or _Patch created.
        """

        # ------------------------------------------------------------------------

        def __init__(self, left, top, row, col, slabel=None, tlabel=None, parent=None):
            ConnectionPattern._Patch.__init__(self, left, top, row, col, 0, 0, slabel, tlabel, parent)
            self.elements = []
            self._row_top = None  # top of current row
            self._row = 0
            self._col = 0

        # ------------------------------------------------------------------------

        def newRow(self, dy=0.0, dynew=0.0):
            """
            Open new row of elements.
            Arguments:
            dy   : vertical skip before new row
            dynew: vertical skip if new row is first row
            """

            if self.elements:
                # top of row is bottom of block so far + dy
                self._row_top = self.lr[1] + dy
            else:
                # place relative to top edge of parent
                self._row_top = self.tl[1] + dynew
            self._row += 1
            self._col = 0

            self.elements.append([])
            
        # ------------------------------------------------------------------------

        def newElement(self, dx=0.0, dxnew=0.0, slabel=None, tlabel=None, 
                       size=None):
            """
            Append new element to last row. 
            Creates _Block instance if size is not given, otherwise _Patch.

            Arguments:
            dx    : horizontal skip before new element
            dxnew : horizontal skip if new element is first
            slabel: sender label (on y-axis)
            tlabel: target label (on x-axis)
            size  : size of _Patch to create

            Returns:
            Created _Block or _Patch.
            """

            assert(self.elements)

            if self.elements[-1]:
                # left edge is right edge of block so far + dx
                col_left = self.lr[0] + dx
            else:
                # place relative to left edge of parent
                col_left = self.tl[0] + dxnew

            self._col += 1
            
            if not size is None:
                elem = ConnectionPattern._Patch(col_left, self._row_top, self._row, self._col,
                                                size[0], size[1], slabel, tlabel, self)
            else:
                elem = ConnectionPattern._Block(col_left, self._row_top, self._row, self._col,
                                                slabel, tlabel, self)
            
            self.elements[-1].append(elem)
            self._update_size(elem.lr)

            return elem

        # ------------------------------------------------------------------------

        def addMargin(self, rmarg=0.0, bmarg=0.0):
            """Extend block by margin to right and bottom."""
            if rmarg < 0.0: raise ValueError('rmarg must not be negative!')
            if bmarg < 0.0: raise ValueError('bmarg must not be negative!')

            lr = self.lr
            self._update_size((lr[0]+rmarg, lr[1]+bmarg))

    # ----------------------------------------------------------------------------

    def _prepareAxes(self, mode, showLegend):
        """
        Prepare information for all axes, but do not create the actual axes yet.

        Builds the nested block/patch layout in self._axes and fills
        self._patchTable, which maps (sender layer, sender population,
        target layer, target population, synapse model) tuples to patches;
        entries not applicable in the given mode are None. If showLegend
        is true, space for the color bar is added at the bottom and
        self._cbPatches is set up (a single patch for 'totals' and
        'population', otherwise a dictionary mapping synapse names to
        patches).

        mode: 'totals'     - one patch per pair of layers
              'layer'      - for each pair of layers, one patch per position
                             in the synapse type grid
              'population' - one patch per pair of populations
              any other value (plot() passes None or 'select') - detailed
                             layout; for each pair of populations, one patch
                             per column of the synapse type grid
        showLegend: if true, reserve space for and lay out color bar patches
        """

        # parameters for figure, all quantities are in mm
        patchmax = plotParams.patch_size # length of largest patch dimension

        # figure margins and margin reserved for the color bar
        lmargin = plotParams.margins.left
        tmargin = plotParams.margins.top
        rmargin = plotParams.margins.right
        bmargin = plotParams.margins.bottom
        cbmargin= plotParams.margins.colbar
        
        # actual separations, scaled from default patchmax = 20mm
        blksep  =  3./20. * patchmax  # distance between blocks
        popsep  =  2./20. * patchmax  # distance between populations
        synsep  =  0.5/20.* patchmax  # distance between synapse types

        # find maximal extents of individual patches, horizontal and vertical
        maxext = max(_flattened([l.ext for l in self._layers]))
        
        # factor converting layer extents to patch sizes; the largest
        # extent is mapped to patchmax
        patchscale = patchmax / float(maxext)  # determines patch size

        # obtain number of synaptic patches per population pair
        # maximum column across all synapse types, same for rows
        nsyncols = max([s.c for s in self._synAttr.values()]) + 1
        nsynrows = max([s.r for s in self._synAttr.values()]) + 1

        # dictionary mapping into patch-axes, so they can be found later
        self._patchTable = {}

        # set to store all created patches to avoid multiple
        # creation of patches at same location
        # NOTE(review): nothing in this method ever adds to axset, so the
        # membership test in the detailed branch below never succeeds.
        axset = set()

        # create entire setup, top-down
        self._axes = self._Block(lmargin, tmargin, 1, 1)
        for sl in self._layers:

            # get sorted list of populations for sender layer
            spops = sorted([p[1] for p in self._pops if p[0] == sl.name],
                           key=lambda pn: self._poporder[pn])

            # one row of blocks per sender layer
            self._axes.newRow(blksep, 0.0)

            for tl in self._layers:

                # ignore singular target layers
                if tl.singular:
                    continue

                # get sorted list of populations for target layer
                tpops = sorted([p[1] for p in self._pops if p[0] == tl.name],
                               key=lambda pn: self._poporder[pn])

                # compute size for patches; depends on target layer extent only
                patchsize = patchscale * np.array(tl.ext)

                # one block per (sender layer, target layer) pair
                block = self._axes.newElement(blksep, 0.0, sl.name, tl.name)

                if mode == 'totals':
                    # single patch
                    block.newRow(popsep, popsep/2.)
                    p = block.newElement(popsep, popsep/2., size=patchsize)
                    self._patchTable[(sl.name, None, tl.name, None, None)] = p

                elif mode == 'layer':
                    # We loop over all rows and columns in the synapse patch grid.
                    # For each (r,c), we find the pertaining synapse name by reverse
                    # lookup in the _synAttr dictionary. This is inefficient, but
                    # should not be too costly overall. But we must create the
                    # patches in the order they are placed.
                    # NB: We must create also those block.newElement() that are not
                    #     registered later, since block would otherwise not skip 
                    #     over the unused location.
                    for r in xrange(nsynrows):
                        block.newRow(synsep, popsep/2.)
                        for c in xrange(nsyncols):
                            p = block.newElement(synsep, popsep/2., size=patchsize)
                            smod = [k for k,s in self._synAttr.iteritems()
                                    if s.r == r and s.c == c]
                            if smod:
                                # at most one synapse type per grid position
                                assert(len(smod)==1)
                                self._patchTable[(sl.name,None,tl.name,None,smod[0])] = p

                elif mode == 'population':
                    # one patch per population pair
                    for sp in spops:
                        block.newRow(popsep, popsep/2.)
                        for tp in tpops:
                            # inner block carries the population labels and
                            # holds the single patch for this pair
                            pblk = block.newElement(popsep, popsep/2., sp, tp)
                            pblk.newRow(synsep, synsep/2.)
                            self._patchTable[(sl.name,sp,tl.name,tp,None)] = \
                                pblk.newElement(synsep, blksep/2., size=patchsize)

                else:
                    # detailed presentation of all pops
                    for sp in spops:
                        block.newRow(popsep, popsep/2.)
                        for tp in tpops:
                            pblk = block.newElement(popsep, popsep/2., sp, tp)
                            pblk.newRow(synsep, synsep/2.)

                            # Find all connections with matching properties
                            # all information we need here is synapse model.
                            # We store this in a dictionary mapping synapse
                            # patch column to synapse model, for use below.
                            syns = dict([(self._synAttr[c.synmodel].c, c.synmodel)
                                         for c in _flattened(self._cTable.values())
                                         if c.matches(sl.name, sp, tl.name, tp)])
                        
                            # create all synapse patches
                            for n in xrange(nsyncols):

                                # Do not duplicate existing axes.
                                if (sl.name,sp,tl.name,tp,n) in axset:
                                    continue  

                                # Create patch. We must create also such patches
                                # that do not have synapses, since spacing would
                                # go wrong otherwise.
                                p = pblk.newElement(synsep, 0.0, size=patchsize)

                                # if patch represents existing synapse, register
                                if n in syns:
                                    self._patchTable[(sl.name,sp,tl.name,tp,syns[n])] = p

                # pad block at right and bottom
                block.addMargin(popsep/2., popsep/2.)

        # add right and bottom figure margins
        self._axes.addMargin(rmargin, bmargin)
        if showLegend:
            self._axes.addMargin(0, cbmargin)  # add color bar at bottom
            figwidth = self._axes.lr[0] - self._axes.tl[0] - rmargin # keep right marg out of calc

            if mode == 'totals' or mode == 'population':

                # single color bar patch, by default 20% of figure width
                # NOTE(review): an earlier comment placed this patch at the
                # right edge, but its left coordinate below is
                # self._axes.tl[0], i.e., the left edge of the layout.
                if plotParams.cbwidth:
                    # user-specified width, given as fraction of figwidth
                    lwidth = plotParams.cbwidth * figwidth
                else:
                    lwidth = 0.2 * figwidth
                    if lwidth > 100.0:   # colorbar should not be wider than 10cm
                        lwidth = 100.0
                # height as fraction of color bar margin, default 30%
                lheight = plotParams.cbheight*cbmargin if plotParams.cbheight else 0.3*cbmargin

                # distance of top edge of color bar above bottom of figure
                if plotParams.legend_location is None:
                    cblift = 0.9 * cbmargin
                else:
                    cblift = 0.7 * cbmargin

                self._cbPatches = self._Patch(self._axes.tl[0],
                                              self._axes.lr[1]- cblift,
                                              None, None, 
                                              lwidth,
                                              lheight)
            else:
                # one patch per synapse type, 20% of figure or less

                # we need to get the synapse names in ascending order of synapse indices
                snames = [s[0] for s in 
                          sorted([(k,v) for k,v in self._synAttr.iteritems()], 
                                 key=lambda kv: kv[1].index)
                          ]
                snum = len(snames)
                
                # lwidth: width of each color bar, lstep: gap between color bars
                if plotParams.cbwidth:
                    # user-specified width and gap, as fractions of figwidth;
                    # gap defaults to half the bar width
                    lwidth = plotParams.cbwidth * figwidth
                    if plotParams.cbspace:
                        lstep = plotParams.cbspace * figwidth
                    else:
                        lstep = 0.5 * lwidth
                else:
                    if snum < 5:
                        lwidth = 0.15 * figwidth
                        lstep  = 0.1  * figwidth
                    else:
                        # many synapse types: spread bars and gaps evenly
                        # across figwidth
                        lwidth = figwidth / (snum + 1.0)
                        lstep = (figwidth - snum*lwidth) / (snum - 1.0)
                    if lwidth > 100.0:   # colorbar should not be wider than 10cm
                        lwidth = 100.0
                        lstep = 30.0
                # height as fraction of color bar margin, default 30%
                lheight = plotParams.cbheight*cbmargin if plotParams.cbheight else 0.3*cbmargin

                # horizontal offset of first color bar from left edge of layout
                if plotParams.cboffset is not None:
                    offset = plotParams.cboffset
                else:
                    offset = lstep
                    
                # distance of top edge of color bars above bottom of figure
                if plotParams.legend_location is None:
                    cblift = 0.9 * cbmargin
                else:
                    cblift = 0.7 * cbmargin

                # place color bars left to right in order of synapse index
                self._cbPatches = {}
                for j in xrange(snum):
                    self._cbPatches[snames[j]] = \
                        self._Patch(self._axes.tl[0] + offset + j * (lstep + lwidth),
                                    self._axes.lr[1] - cblift,
                                    None, None,
                                    lwidth,
                                    lheight)

    # ----------------------------------------------------------------------------

    def _scaledBox(self, p):
        """Scaled axes rectangle for patch, reverses y-direction."""
        xsc, ysc = self._axes.lr
        return self._figscale * np.array([p.l/xsc, 1-(p.t+p.h)/ysc, p.w/xsc, p.h/ysc])

    # ----------------------------------------------------------------------------

    def _scaledBoxNR(self, p):
        """Scaled axes rectangle for patch, does not reverse y-direction."""
        xsc, ysc = self._axes.lr
        return self._figscale * np.array([p.l/xsc, p.t/ysc, p.w/xsc, p.h/ysc])

    # ----------------------------------------------------------------------------

    def _configSynapses(self, cList, synTypes):
        """
        Set up self._synAttr from the connection list and user-given types.

        If synTypes is given, it is validated and used as is. Otherwise,
        default synapse types are chosen based on the synapse models and
        the signs of the weights found in cList.

        cList   : connection list; the third entry of each connection must
                  provide 'synapse_model' and 'weights'
        synTypes: nested sequence of SynType, one inner sequence per row of
                  the synapse patch grid, or None

        Raises ValueError if names in synTypes are not unique, if synTypes
        does not cover all synapse models when more than one is in use, or
        if no synTypes are given and the models in use are not known.
        """

        # synapse models and evaluated weights occurring in connection list
        models = set(conn[2]['synapse_model'] for conn in cList)
        weights = set(_weighteval(conn[2]['weights']) for conn in cList)

        if synTypes:
            # user-provided types: check names and coverage
            given = _flattened([[stype.name for stype in group] for group in synTypes])
            if len(set(given)) != len(given):
                raise ValueError('Names of synapse types in synTypes must be unique!')
            if len(models) > 1 and not models <= set(given):
                raise ValueError('synTypes must provide information about all synapse types.')

        elif len(models) == 1:
            # single synapse model: classify by sign of weights
            if min(weights) >= 0:
                # no negative weights
                synTypes = ((SynType('exc', 1.0, 'red'),),)
            elif max(weights) <= 0:
                # no positive weights
                synTypes = ((SynType('inh', -1.0, 'blue'),),)
            else:
                # weights of both signs, assume Dale holds
                synTypes = ((SynType('exc', 1.0, 'red'),),
                            (SynType('inh', -1.0, 'blue'),))

        elif models == set(['AMPA', 'GABA_A']):
            # exactly AMPA and GABA_A: one type per row
            synTypes = ((SynType('AMPA', 1.0, 'red'),),
                        (SynType('GABA_A', -1.0, 'blue'),))

        elif models <= set(['AMPA', 'NMDA', 'GABA_A', 'GABA_B']):
            # any other subset of the four known models: two rows of two
            synTypes = ((SynType('AMPA', 1.0, 'red'),
                         SynType('NMDA', 1.0, 'orange'),),
                        (SynType('GABA_A', -1.0, 'blue'),
                         SynType('GABA_B', -1.0, 'purple'),))

        else:
            raise ValueError('Connection list contains unknown synapse models; synTypes required.')

        # Build _synAttr: each inner sequence of synTypes becomes one row of
        # the synapse patch grid; index counts types across all rows.
        self._synAttr = {}
        index = 0
        for row, group in enumerate(synTypes):
            for col, stype in enumerate(group):
                self._synAttr[stype.name] = self._SynProps(row, col,
                                                           stype.relweight,
                                                           stype.cmap, index)
                index += 1

    # ----------------------------------------------------------------------------

    def __init__(self, lList, cList, synTypes=None, intensity='wp', 
                 mList=None, Vmem=None, poporder=None):
        """
        Build the internal representation of layers, populations and
        connections from which connection pattern tables are plotted.

        lList    : layer list
        cList    : connection list
        synTypes : nested list of synapse types
        intensity: 'wp'  - weight * probability
                   'p'   - probability
                   'tcd' - |total charge deposited| * probability
                           requires mList; currently only for ht_model
                           proper results only if Vmem within reversal potentials
        mList    : model list; only needed with 'tcd'
        Vmem     : reference membrane potential for 'tcd'
        poporder : dictionary mapping population names to numbers; populations
                   will be sorted in diagram in order of increasing numbers.
                   Populations missing from the dictionary are placed after
                   all listed ones, in alphabetical order. The dictionary
                   passed in is copied and never modified.

        Raises ValueError if layer names are not unique.
        """
        # extract layers to list of layer properties (name and extent)
        self._layers = [self._LayerProps(l[0], l[1]['extent']) for l in lList]

        # ensure layer names are unique
        lnames = [l.name for l in self._layers]
        if len(lnames) != len(set(lnames)):
            raise ValueError('Layer names must be unique.')

        # set up synapse attributes
        self._configSynapses(cList, synTypes)

        # if tcd mode, build tcd representation
        if intensity != 'tcd':
            tcd = None
        else:
            assert mList, "mList is required when intensity is 'tcd'"
            import tcd_nest
            tcd = tcd_nest.TCD(mList)

        # Build internal representation of connections.
        # This representation contains one entry for each sender pop, target pop,
        # synapse type tuple. Creating the connection object implies computation
        # of the kernel.
        # Several connections may agree in all properties; these are collected
        # in a list under their common key, so they can be added up later.
        self._cTable = {}
        for conn in cList:
            key, val = self._Connection(conn, self._layers, self._synAttr,
                                        intensity, tcd, Vmem).keyval
            # connections yielding no key are not stored
            if key:
                self._cTable.setdefault(key, []).append(val)

        # number of layers
        self._nlyr = len(self._layers)

        # compile list of populations, list(set()) makes list unique
        self._pops = list(set(_flattened([[(c.slayer, c.snrn), (c.tlayer, c.tnrn)]
                                          for c in _flattened(self._cTable.values())])))
        self._npop = len(self._pops)

        # store population ordering; if not given, use alphabetical ordering
        # also add any missing populations alphabetically at end
        # layers are ignored
        # create alphabetically sorted list of unique population names
        popnames = sorted(set([p[1] for p in self._pops]))
        if poporder:
            # work on a copy: missing populations are added below and the
            # caller's dictionary must not be changed as a side effect
            self._poporder = dict(poporder)
            next_index = max(self._poporder.values()) + 1  # next free sorting index
        else:
            self._poporder = {}
            next_index = 0
        for pname in popnames:
            if pname not in self._poporder:
                self._poporder[pname] = next_index
                next_index += 1

        # compile list of synapse types
        self._synTypes = list(set([c.synmodel for c in _flattened(self._cTable.values())]))

    # ----------------------------------------------------------------------------

    def plot(self, aggrGroups=False, aggrSyns=False, globalColors=False,
             colorLimits=None, showLegend=True,
             selectSyns=None, file=None, fixedWidth=None):
        """
        Plot connection pattern.
        By default, connections between any pair of populations
        are plotted on the screen, with separate color scales for
        all patches.

        Arguments:
        aggrGroups  If True, aggregate projections with the same synapse type
                    and the same source and target groups (default: False)

        aggrSyns    If True, aggregate projections with the same synapse model (default: False)

        globalColors  If True, use global color scale, otherwise local (default: False)

        colorLimits   If given, must be two element vector for lower and upper limits of
                      color scale. Implies globalColors (default: None)

        showLegend   If True, show legend below CPT (default: True).

        selectSyns    If tuple of synapse models, show only connections of the 
                      give types. Cannot be combined with aggregation.

        file     If given, save plot to given file name; file may also be a tuple of 
                 file names, the figure will then be saved to all files. This may be 
                 useful if you want to save the same figure in several formats.
                 You should not save to PDF directly, this may lead to artefacts; 
                 rather save to PS or EPS, then convert. 

        fixedWidth Figure will be scaled to this width in mm by changing patch size.

        Returns:
        kern_min, kern_max   Minimal and maximal values of kernels, with kern_min <=0, kern_max>=0.

        Output:
        figure created
        """

        # translate new to old paramter names (per v 0.5)
        normalize = globalColors
        if colorLimits:
            normalize = True

        if selectSyns:
            if aggrPops or aggrSyns:
                raise ValueError('selectSyns cannot be combined with aggregation.')
            selected = selectSyns
            mode = 'select'
        elif aggrGroups and aggrSyns:
            mode = 'totals'
        elif aggrGroups and not aggrSyns:
            mode = 'layer'
        elif aggrSyns and not aggrGroups:
            mode = 'population'
        else:
            mode = None

        if mode == 'layer':
            # reduce to dimensions sender layer, target layer, synapse type
            # add all kernels agreeing on these three attributes
            plotKerns = []
            for slayer in self._layers:
                for tlayer in self._layers:
                    for synmodel in self._synTypes:
                        kerns = [c.kernval for c in _flattened(self._cTable.values())
                                 if c.matches(sl=slayer.name, tl=tlayer.name, syn=synmodel)]
                        if len(kerns) > 0:
                            plotKerns.append(self._PlotKern(slayer.name, None, tlayer.name, None, synmodel, 
                                                            _addKernels(kerns)))
        elif mode == 'population':
            # reduce to dimensions sender layer, target layer
            # all all kernels, weighting according to synapse type
            plotKerns = []
            for spop in self._pops:
                for tpop in self._pops:
                    kerns = [self._synAttr[c.synmodel].tw * c.kernval for c in _flattened(self._cTable.values())
                             if c.matches(sl=spop[0], sn=spop[1], tl=tpop[0], tn=tpop[1])]
                    if len(kerns) > 0:
                        plotKerns.append(self._PlotKern(spop[0], spop[1], tpop[0], tpop[1], None, 
                                                        _addKernels(kerns)))
        elif mode == 'totals':
            # reduce to dimensions sender layer, target layer
            # all all kernels, weighting according to synapse type
            plotKerns = []
            for slayer in self._layers:
                for tlayer in self._layers:
                    kerns = [self._synAttr[c.synmodel].tw * c.kernval for c in _flattened(self._cTable.values())
                             if c.matches(sl=slayer.name, tl=tlayer.name)]
                    if len(kerns) > 0:
                        plotKerns.append(self._PlotKern(slayer.name, None, tlayer.name, None, None, 
                                                        _addKernels(kerns)))
        elif mode == 'select':
            # copy only those kernels that have the requested synapse type,
            # no dimension reduction
            # nb: we need to sum all kernels in the list for a set of attributes
            plotKerns = [self._PlotKern(clist[0].slayer, clist[0].snrn, clist[0].tlayer, clist[0].tnrn, 
                                        clist[0].synmodel, _addKernels([c.kernval for c in clist]))
                         for clist in self._cTable.values() if clist[0].synmodel in selected]
        else:
            # copy all
            # nb: we need to sum all kernels in the list for a set of attributes
            plotKerns = [self._PlotKern(clist[0].slayer, clist[0].snrn, clist[0].tlayer, clist[0].tnrn, 
                                        clist[0].synmodel, _addKernels([c.kernval for c in clist]))
                         for clist in self._cTable.values()]
            
        self._prepareAxes(mode, showLegend)
        if fixedWidth:

            margs = plotParams.margins.left + plotParams.margins.right

            if fixedWidth <= margs:
                raise ValueError('Requested width must be less than width of margins (%g mm)' % margs)

            currWidth = self._axes.lr[0]
            currPatchMax = plotParams.patch_size # store
            
            # compute required patch size
            plotParams.patch_size = (fixedWidth - margs) / (currWidth - margs) * currPatchMax

            # build new axes
            del self._axes
            self._prepareAxes(mode, showLegend)

            # restore patch size
            plotParams.patch_size = currPatchMax

        # create figure with desired size
        fsize = np.array(self._axes.lr) / 25.4 # convert mm to inches
        f = plt.figure(figsize=fsize, facecolor='w')

        # size will be rounded according to DPI setting, adjust fsize 
        dpi = f.get_dpi()
        fsize = np.floor(fsize*dpi) / dpi

        # check that we got the correct size
        actsize =np.array([f.get_figwidth(), f.get_figheight()], dtype=float)
        if all(actsize == fsize):
            self._figscale = 1.0  # no scaling
        else:
            warnings.warn("""
                  WARNING: Figure shrunk on screen!
                  The figure is shrunk to fit onto the screen.

                  Please specify a different backend using the -d
                  option to obtain full-size figures. Your current
                  backend is: %s
                  """ % mpl.get_backend())
            plt.close(f)
            
            # determine scale: most shrunk dimension
            self._figscale = np.min(actsize / fsize)
            
            # create shrunk on-screen figure
            f = plt.figure(figsize=self._figscale*fsize, facecolor='w') 

            # just ensure all is well now
            actsize =np.array([f.get_figwidth(), f.get_figheight()], dtype=float)


        # add decoration
        for block in _flattened(self._axes.elements):

            ax = f.add_axes(self._scaledBox(block),
                            axisbg = plotParams.layer_bg[block.location], xticks=[], yticks=[],
                            zorder = plotParams.z_layer)
            if hasattr(ax, 'frame'):
                ax.frame.set_visible(False)
            else:
                for sp in ax.spines.values():
                    sp.set_color('none')  # turn off axis lines, make room for frame edge
            if block.l <= self._axes.l_patches and block.slbl:
                ax.set_ylabel(block.slbl, 
                              rotation = plotParams.layer_orientation['sender'],
                              fontproperties = plotParams.layer_font)
            if block.t <= self._axes.t_patches and block.tlbl:
                ax.set_xlabel(block.tlbl, 
                              rotation = plotParams.layer_orientation['target'],
                              fontproperties = plotParams.layer_font)
                ax.xaxis.set_label_position('top')

            # inner blocks for population labels
            if not mode in ('totals', 'layer'):
                for pb in _flattened(block.elements):
                    if not isinstance(pb, self._Block):
                        continue # should not happen

                    ax = f.add_axes(self._scaledBox(pb),
                                    axisbg = 'none', xticks=[], yticks=[], 
                                    zorder = plotParams.z_pop)
                    if hasattr(ax, 'frame'):
                        ax.frame.set_visible(False)
                    else:
                        for sp in ax.spines.values():
                            sp.set_color('none')  # turn off axis lines, make room for frame edge
                    if pb.l+pb.w >= self._axes.r_patches and pb.slbl:
                        ax.set_ylabel(pb.slbl, 
                                      rotation=plotParams.pop_orientation['sender'],
                                      fontproperties=plotParams.pop_font)
                        ax.yaxis.set_label_position('right')
                    if pb.t+pb.h >= self._axes.b_patches and pb.tlbl:
                        ax.set_xlabel(pb.tlbl, 
                                      rotation=plotParams.pop_orientation['target'],
                                      fontproperties=plotParams.pop_font)

        # determine minimum and maximum values across all kernels, but set min <= 0, max >= 0
        kern_max = max(0.0, max([np.max(kern.kern) for kern in plotKerns]))
        kern_min = min(0.0, min([np.min(kern.kern) for kern in plotKerns]))

        # determine color limits for plots
        if colorLimits:
            c_min, c_max = colorLimits  # explicit values

        else:
            # default values for color limits
            # always 0 as lower limit so anything > 0 is non-white, except when totals or populations
            c_min = None if mode in ('totals','population') else 0.0
            c_max = None   # use patch maximum as upper limit
            
            if normalize:
                # use overall maximum, at least 0
                c_max = kern_max
                if aggrSyns:
                    # use overall minimum, if negative, otherwise 0
                    c_min = kern_min
                    # for c_max, use the larger of the two absolute values
                    c_max = kern_max
                    # if c_min is non-zero, use same color scale for neg values
                    if c_min < 0:
                        c_min = -c_max

        # Initialize dict storing sample patches for each synapse type for use
        # in creating color bars. We will store the last patch of any given 
        # synapse type for reference. When aggrSyns, we have only one patch type
        # and store that.
        if not aggrSyns:
            samplePatches = dict([(sname, None) for sname in self._synAttr.keys()])
        else:
            # only single type of patches
            samplePatches = None

        for kern in plotKerns:
            p = self._patchTable[(kern.sl,kern.sn,kern.tl,kern.tn,kern.syn)]
            
            p.ax = f.add_axes(self._scaledBox(p), aspect='equal', 
                              xticks=[], yticks=[], zorder=plotParams.z_conn)
            p.ax.patch.set_edgecolor('none')
            if hasattr(p.ax, 'frame'):
                p.ax.frame.set_visible(False)
            else:
                for sp in p.ax.spines.values():
                    sp.set_color('none')  # turn off axis lines, make room for frame edge

            if not aggrSyns:
                # we have synapse information -> not totals, a vals positive
                assert(kern.syn)
                assert(np.min(kern.kern) >= 0.0)

                # we may overwrite here, but this does not matter, we only need 
                # some reference patch
                samplePatches[kern.syn] = p.ax.imshow(kern.kern, 
                                                      vmin = c_min, vmax = c_max,
                                                      cmap = self._synAttr[kern.syn].cmap)#,
                                  # interpolation='nearest')

            else:
                # we have totals, special color table and normalization

                # we may overwrite here, but this does not matter, we only need 
                # some reference patch
                samplePatches = p.ax.imshow(kern.kern, 
                                  vmin = c_min, vmax = c_max,
                                  cmap = cm.bluered,
                                  norm = cm.ZeroCenterNorm()) #, # must be instance
                                  # interpolation='nearest')

        # Create colorbars at bottom of figure
        if showLegend:

            # Do we have kernel values exceeding the color limits?
            if c_min <= kern_min and kern_max <= c_max:
                extmode = 'neither'
            elif c_min > kern_min and kern_max <= c_max:
                extmode = 'min'
            elif c_min <= kern_min and kern_max > c_max:
                extmode = 'max'
            else:
                extmode = 'both'

            if aggrSyns:
                cbax = f.add_axes(self._scaledBox(self._cbPatches))

                # by default, use 4 ticks to avoid clogging
                # according to docu, we need a separate Locator object
                # for each axis.
                if plotParams.legend_ticks:
                    tcks = plotParams.legend_ticks
                else:
                    tcks = mpl.ticker.MaxNLocator(nbins=4)

                if normalize:
                    # colorbar with freely settable ticks
                    cb = f.colorbar(samplePatches, cax = cbax,
                                    orientation = 'horizontal',
                                    ticks = tcks,
                                    format = plotParams.legend_tick_format, extend=extmode)
                else:
                    # colorbar with tick labels 'Exc', 'Inh'
                    # we add the color bare here explicitly, so we get no problems
                    # if the sample patch includes only pos or only neg values
                    cb = mpl.colorbar.ColorbarBase(cbax, cmap=cm.bluered, orientation='horizontal')
                    cbax.set_xticks([0, 1])
                    cbax.set_xticklabels(['Inh', 'Exc'])

                cb.outline.set_linewidth(0.5)  # narrower line around colorbar

                # fix font for ticks
                plt.setp(cbax.get_xticklabels(), fontproperties=plotParams.legend_tick_font)

                # no title in this case

            else:
                # loop over synapse types
                for syn in self._synAttr.keys():
                    cbax = f.add_axes(self._scaledBox(self._cbPatches[syn]))
                    if plotParams.legend_location is None:
                        cbax.set_ylabel(syn, fontproperties=plotParams.legend_title_font,
                                        rotation='horizontal')
                    else:
                        cbax.set_title(syn, fontproperties=plotParams.legend_title_font,
                                       rotation='horizontal')
                    
                    if normalize:

                        # by default, use 4 ticks to avoid clogging
                        # according to docu, we need a separate Locator object
                        # for each axis.
                        if plotParams.legend_ticks:
                            tcks = plotParams.legend_ticks
                        else:
                            tcks = mpl.ticker.MaxNLocator(nbins=4)

                        # proper colorbar
                        cb = f.colorbar(samplePatches[syn], cax = cbax,                               
                                        orientation = 'horizontal',
                                        ticks = tcks,
                                        format = plotParams.legend_tick_format,
                                        extend = extmode)
                        cb.outline.set_linewidth(0.5)  # narrower line around colorbar

                        # fix font for ticks
                        plt.setp(cbax.get_xticklabels(), 
                                 fontproperties=plotParams.legend_tick_font)

                    else:
                        # just a solid color bar with no ticks
                        cbax.set_xticks([])
                        cbax.set_yticks([])

                        # full-intensity color from color map
                        cbax.set_axis_bgcolor(self._synAttr[syn].cmap(1.0))

                        # narrower border
                        if hasattr(cbax, 'frame'):
                            cbax.frame.set_linewidth(0.5)
                        else:
                            for sp in cbax.spines.values():
                                sp.set_linewidth(0.5)

        # save to file(s), use full size
        f.set_size_inches(fsize)
        if isinstance(file, (list,tuple)):
            for fn in file:
                f.savefig(fn)
        elif isinstance(file, str):
            f.savefig(file)
        f.set_size_inches(actsize)  # reset size for further interactive work

        return kern_min, kern_max

    # ----------------------------------------------------------------------------

    def toLaTeX(self, file, standalone = False, enumerate = False, legend = True):
        """
        Write connection table to file.

        Arguments:
        file        output file name
        standalone  create complete LaTeX file (default: False)
        enumerate   enumerate connections (default: False)
        legend      add explanation of functions used (default: True)
        """

        lfile = open(file, 'w')
        if not lfile:
            raise Exception('Could not open file "%s"' % file)

        if standalone:
            lfile.write(\
r"""
\documentclass[a4paper,american]{article}
\usepackage[pdftex,margin=1in,centering,noheadfoot,a4paper]{geometry}
\usepackage[T1]{fontenc}
\usepackage[utf8]{inputenc}
	
\usepackage{color}
\usepackage{calc}
\usepackage{tabularx}  % automatically adjusts column width in tables
\usepackage{multirow}  % allows entries spanning several rows  
\usepackage{colortbl}  % allows coloring tables
\usepackage[fleqn]{amsmath}
\setlength{\mathindent}{0em}

\usepackage{mathpazo}
\usepackage[scaled=.95]{helvet}
\renewcommand\familydefault{\sfdefault}

\renewcommand\arraystretch{1.2}
\pagestyle{empty}

% \hdr{ncols}{label}{title}
%
% Typeset header bar across table with ncols columns 
% with label at left margin and centered title
%
\newcommand{\hdr}[3]{%
  \multicolumn{#1}{|l|}{%
    \color{white}\cellcolor[gray]{0.0}%
    \textbf{\makebox[0pt]{#2}\hspace{0.5\linewidth}\makebox[0pt][c]{#3}}%
  }%
}

\begin{document}
""")
        lfile.write(\
r"""
\noindent\begin{tabularx}{\linewidth}{%s|l|l|l|c|c|X|}\hline
\hdr{%d}{}{Connectivity}\\\hline
%s \textbf{Src} & \textbf{Tgt} & \textbf{Syn} &
\textbf{Wght} & \textbf{Mask} & \textbf{Kernel} \\\hline
""" % (('|r',7,'&') if enumerate else ('',6,'')))
        # ensure sorting according to keys, gives some alphabetic sorting
        haveU, haveG = False, False

        cctr = 0  # connection counter
        for ckey in sorted(self._cTable.keys()):
            for conn in self._cTable[ckey]:

                cctr += 1
                if enumerate: lfile.write('%d &' % cctr)

                # take care to escape _ in names such as GABA_A
                # also remove any pending '/None'
                lfile.write((r'%s/%s & %s/%s & %s' % \
                                 (conn.slayer, conn.snrn, conn.tlayer, conn.tnrn,
                                  conn.synmodel)).replace('_', r'\_').replace('/None',''))
                lfile.write(' & \n')

                if isinstance(conn.weight, (int,float)):
                    lfile.write(r'%g' % conn.weight)
                elif 'uniform' in conn.weight:
                    cw = conn.weight['uniform']
                    lfile.write(r'$\mathcal{U}[%g, %g)$' % (cw['min'], cw['max']))
                    haveU = True
                else:
                    raise ValueError('Unkown weight type "%s"' % conn.weight.__str__)
                lfile.write(' & \n')

                if 'circular' in conn.mask:
                    lfile.write(r'$\leq %g$' % conn.mask['circular']['radius'])
                elif 'rectangular' in conn.mask:
                    cmr = conn.mask['rectangular']
                    lfile.write(\
                        r"""$[(%+g, %+g), (%+g, %+g)]$""" \
                            % (cmr['lower_left'][0], cmr['lower_left'][1],
                               cmr['upper_right'][0], cmr['upper_right'][1]))
                else:
                    raise ValueError('Unknown mask type "%s"' % conn.mask.__str__)
                lfile.write(' & \n')

                if isinstance(conn.kernel, (int, float)):
                    lfile.write(r'$%g$' % conn.kernel)
                elif 'gaussian' in conn.kernel:
                    ckg = conn.kernel['gaussian']
                    lfile.write(r'$\mathcal{G}(p_0 = %g, \sigma = %g)$' % \
                                    (ckg['p_center'], ckg['sigma']))
                    haveG = True
                else:
                    raise ValueError('Unkown kernel type "%s"' % conn.kernel.__str__)
                lfile.write('\n')

                lfile.write(r'\\\hline' '\n')

        if legend and (haveU or haveG):
            # add bottom line with legend
            lfile.write(r'\hline' '\n')
            lfile.write(r'\multicolumn{%d}{|l|}{\footnotesize ' % (7 if enumerate else 6))
            if haveG:
                lfile.write(r'$\mathcal{G}(p_0, \sigma)$: $p(\mathbf{x})=p_0 e^{-\mathbf{x}^2/2\sigma^2}$')
            if haveG and haveU:
                lfile.write(r', ')
            if haveU:
                lfile.write(r'$\mathcal{U}[a, b)$: uniform distribution on $[a, b)$')
            lfile.write(r'}\\\hline' '\n')

        lfile.write(r'\end{tabularx}' '\n\n')

        if standalone:
            lfile.write(r'\end{document}''\n')

        lfile.close()
        
# ----------------------------------------------------------------------------

def _evalkernel(mask, kernel, weight, extent, intensity, tcd):
    """
    Evaluate kernel on a regular grid covering extent.

    The grid is centered on the origin and has about plotParams.n_kern
    points along the longer side of extent, with (approximately) the
    same spacing in both directions.

    Arguments:
    mask       mask specification, passed on to _maskeval()
    kernel     kernel specification, passed on to _kerneval()
    weight     numeric weight; only abs(weight) is used. abs() is applied
               directly, so distributions presumably must be reduced to a
               number (e.g. by _weighteval()) by the caller -- verify.
    extent     pair (width, height) of the region to cover
    intensity  selects the scaling of the kernel values:
               'p'    kernel values as they are
               'wp'   kernel values times abs(weight)
               'tcd'  kernel values times abs(tcd) * abs(weight)
    tcd        numeric factor, only used when intensity is 'tcd'

    Returns:
    Masked array, in which the values outside the mask are masked.

    Raises:
    ValueError if intensity is not one of 'p', 'wp', 'tcd'.
    """

    # determine resolution, number of data points
    # float() guards against integer division for integer extents
    dx = float(max(extent)) / plotParams.n_kern
    # np.linspace requires an integer number of samples, np.ceil
    # returns a float
    nx = int(np.ceil(extent[0] / dx))
    ny = int(np.ceil(extent[1] / dx))

    x = np.linspace(-0.5*extent[0], 0.5*extent[0], nx)
    y = np.linspace(-0.5*extent[1], 0.5*extent[1], ny)
    X, Y = np.meshgrid(x, y)

    if intensity == 'wp':
        values = abs(weight) * _kerneval(X, Y, kernel)
    elif intensity == 'p':
        values = _kerneval(X, Y, kernel)
    elif intensity == 'tcd':
        values = abs(tcd) * abs(weight) * _kerneval(X, Y, kernel)
    else:
        # previously fell through and silently returned None
        raise ValueError('Unknown intensity "%s"' % str(intensity))

    # masked_array masks where its mask is True, _maskeval is True inside
    return np.ma.masked_array(values, np.logical_not(_maskeval(X, Y, mask)))


# ----------------------------------------------------------------------------

def _weighteval(weight):
    """Returns weight, or mean of distribution, signed."""
            
    w = None
    if isinstance(weight, (float, int)):
        w = weight
    elif isinstance(weight, dict):
        assert(len(weight) == 1)
        if 'uniform' in weight:
            w = 0.5 * (weight['uniform']['min'] 
                       + weight['uniform']['max'])
        elif 'gaussian' in weight:
            w = weight['gaussian']['mean']
        else:
            raise Exception('Unknown weight type "%s"' % weight.keys()[0])
        
    if not w:
        raise Exception('Cannot handle weight.')
        
    return float(w)

# ----------------------------------------------------------------------------

def _maskeval(x, y, mask):
    """
    Evaluate mask given as topology style dict at
    (x,y). Assume x,y are 2d numpy matrices.
    """

    assert(len(mask)==1)
    if 'circular' in mask:
        r = mask['circular']['radius']
        m = x**2+y**2 <= r**2
    elif 'doughnut' in mask:
        ri = mask['doughnut']['inner_radius']
        ro = mask['doughnut']['outer_radius']
        d = x**2 + y**2
        m = np.logical_and(ri <= d, d <= ro)
    elif 'rectangular' in mask:
        ll = mask['rectangular']['lower_left']        
        ur = mask['rectangular']['upper_right']
        m = np.logical_and(np.logical_and(ll[0] <= x, x <= ur[0]),
                           np.logical_and(ll[1] <= y, y <= ur[1]))
    else:
        raise Exception('Unknown mask type "%s"' % mask.keys()[0])
    
    return m

# ----------------------------------------------------------------------------

def _kerneval(x, y, fun):
    """
    Evaluate function given as topology style dict at
    (x,y). Assume x,y are 2d numpy matrices
    """
    
    if isinstance(fun, (float, int)):
        return float(fun) * np.ones(np.shape(x))
    elif isinstance(fun, dict):
        assert(len(fun) == 1)
        
    if 'gaussian' in fun:
        g = fun['gaussian']
        p0 = g['p_center']
        sig = g['sigma']
        return p0 * np.exp(-0.5*(x**2+y**2)/sig**2)
    else:
        raise Exception('Unknown kernel "%s"', fun.keys()[0])

    # something very wrong
    raise Exception('Cannot handle kernel.')

# ----------------------------------------------------------------------------

def _addKernels(kList):
    """
    Add a list of kernels.
    
    Arguments:
    kList:  List of masked arrays of equal size.
    
    Returns:
    Masked array of same size as input. All values are added,
    setting masked values to 0. The mask for the sum is the 
    logical AND of all individual masks, so that only such
    values are masked that are masked in all kernels.
    _addKernels always returns a new array object, even if
    kList has only a single element.
    """

    assert(len(kList) > 0)

    if len(kList) < 2:
        return kList[0].copy()
    
    d = np.ma.filled(kList[0], fill_value = 0).copy()
    m = kList[0].mask.copy()
    for k in kList[1:]:
        d += np.ma.filled(k, fill_value = 0)
        m  = np.logical_and(m, k.mask)
        
    return np.ma.masked_array(d, m)
    
# ----------------------------------------------------------------------------

def _flattened(lst):
    """Returned list flattend at first level."""
    return sum(lst, [])

# ----------------------------------------------------------------------------

"""
if __name__ == "__main__":

    import sys
    sys.path += ['./examples']

#    import simple
#    reload(simple)
    cp = ConnectionPattern(simple.layerList, simple.connectList)

    import simple2
    reload(simple2)
    cp2 = ConnectionPattern(simple2.layerList, simple2.connectList)

    st3 = ((SynType('GABA_B', -5.0, 'orange'),
            SynType('GABA_A', -1.0, 'm')),
           (SynType('NMDA', 5.0, 'b'),
            SynType('FOO', 1.0, 'aqua'),
            SynType('AMPA', 3.0, 'g')))
    cp3s = ConnectionPattern(simple2.layerList, simple2.connectList,
                             synTypes=st3)       


    import simple3
    reload(simple3)
    cp3 = ConnectionPattern(simple3.layerList, simple3.connectList)

    
#    cp._prepareAxes('by layer')
#    cp2._prepareAxes('by layer')
#    cp3._prepareAxes('detailed')

    cp2.plot()
    cp2.plot(mode='layer')
    cp2.plot(mode='population')
    cp2.plot(mode='totals')
    cp2.plot(mode=('AMPA',))
    cp2.plot(mode=('AMPA','GABA_B'))
#    cp3.plot()
#    cp3.plot(mode='population')
#    cp3.plot(mode='layer')
#    cp3.plot(mode='totals')
#    cp.plot(normalize=True)
#    cp.plot(totals=True, normalize=True)
#    cp2.plot()
#    cp2.plot(file=('cp3.eps'))
#    cp2.plot(byLayer=True)
#    cp2.plot(totals=True)

"""