"""

:copyright: Copyright 2006-2016 by the PyNN team, see AUTHORS.
:license: CeCILL, see LICENSE for details.
"""

import numpy
from pyNN import recording
from pyNN.neuron import simulator
import re
from neuron import h


# Parses a recording path of the form "var", "section.var" or
# "section(location).var" (e.g. "soma(0.5).ina").
# - `section` and `var` are runs of word characters.
# - `location` is an optional signed decimal number in parentheses.
# Note: the pattern is applied with .match(), so only a prefix of the path
# has to match; any trailing text is ignored.
recordable_pattern = re.compile(r'((?P<section>\w+)(\((?P<location>[-+]?[0-9]*\.?[0-9]+)\))?\.)?(?P<var>\w+)')


class Recorder(recording.Recorder):
    """Encapsulates data and functions related to recording model variables."""
    _simulator = simulator

    def _record(self, variable, new_ids, sampling_interval=None):
        """Add the cells in `new_ids` to the set of recorded cells.

        For ``'spikes'``, each cell's ``spike_times`` vector is attached to its
        ``rec`` object (cells whose ``rec`` is None are skipped). Any other
        variable is recorded as a sampled trace. ``sampling_interval`` falls
        back to the simulator time step when it is not given (or is falsy).
        """
        if variable == 'spikes':
            for cell_id in new_ids:
                if cell_id._cell.rec is not None:
                    cell_id._cell.rec.record(cell_id._cell.spike_times)
        else:
            self.sampling_interval = sampling_interval or self._simulator.state.dt
            for cell_id in new_ids:
                self._record_state_variable(cell_id._cell, variable)

    def _record_state_variable(self, cell, variable):
        """Start recording `variable` of `cell` into a new ``h.Vector``.

        The HOC reference is looked up in this order:

        1. ``cell.recordable``;
        2. the built-in names ``'v'``, ``'gsyn_exc'`` and ``'gsyn_inh'``;
        3. a ``section[(location)].var`` path (see `_resolve_variable`).

        The vector is stored in ``cell.traces[variable]``. If the cell's
        ``recording_time`` flag is not yet set, a ``cell.record_times``
        vector is also created to record the simulation time ``h.t``.
        """
        if hasattr(cell, 'recordable') and variable in cell.recordable:
            hoc_var = cell.recordable[variable]
        elif variable == 'v':
            hoc_var = cell.source_section(0.5)._ref_v  # or use "seg.v"?
        elif variable == 'gsyn_exc':
            hoc_var = cell.esyn._ref_g
        elif variable == 'gsyn_inh':
            hoc_var = cell.isyn._ref_g
        else:
            source, var_name = self._resolve_variable(cell, variable)
            hoc_var = getattr(source, "_ref_%s" % var_name)

        def start_recording(vector, reference):
            # When the interval equals dt, record every time step.
            # Otherwise pass the interval so NEURON samples at that period.
            if self.sampling_interval == self._simulator.state.dt:
                vector.record(reference)
            else:
                vector.record(reference, self.sampling_interval)

        cell.traces[variable] = vec = h.Vector()
        start_recording(vec, hoc_var)
        if not cell.recording_time:
            # Only the first recorded variable on this cell creates the time vector.
            cell.record_times = h.Vector()
            start_recording(cell.record_times, h._ref_t)
            cell.recording_time += 1

    # could be staticmethod
    def _resolve_variable(self, cell, variable_path):
        """Resolve a ``[section[(location)].]var`` path to ``(source, var_name)``.

        ``source`` is ``cell.<section>``. If a location is given, the section
        is called with it as a float. When no section is named, ``source`` is
        ``cell.source``.

        Raises AttributeError if `variable_path` does not match
        `recordable_pattern`.
        """
        match = recordable_pattern.match(variable_path)
        if match:
            parts = match.groupdict()
            if parts['section']:
                section = getattr(cell, parts['section'])
                if parts['location']:
                    source = section(float(parts['location']))
                else:
                    source = section
            else:
                source = cell.source
            return source, parts['var']
        else:
            raise AttributeError("Recording of %s not implemented." % variable_path)

    def _reset(self):
        """Reset the list of things to be recorded.

        For every recorded cell, this drops all traces, replaces the spike-time
        vector with an empty one, and clears the time-recording state. The
        time vector is therefore recreated the next time a variable is recorded.
        """
        # set().union(...) also copes with an empty `recorded` mapping, which
        # would make set.union(*...) raise TypeError.
        for cell_id in set().union(*self.recorded.values()):
            cell = cell_id._cell
            cell.traces = {}
            cell.spike_times = h.Vector(0)
            # Per-cell assignments: previously this was a no-op comparison
            # (`==`) executed only for the last cell, outside the loop.
            cell.recording_time = 0
            cell.record_times = None

    def _clear_simulator(self):
        """
        Should remove all recorded data held by the simulator and, ideally,
        free up the memory.

        Every trace vector is resized to zero length. Spike-time vectors are
        emptied when the cell has a ``rec`` object; otherwise the cell's
        ``clear_past_spikes()`` is called instead.
        """
        for cell_id in set().union(*self.recorded.values()):
            cell = cell_id._cell
            if hasattr(cell, "traces"):
                for trace in cell.traces.values():
                    trace.resize(0)
            if cell.rec is not None:
                cell.spike_times.resize(0)
            else:
                cell.clear_past_spikes()

    def _get_spiketimes(self, id):
        """Return recorded spike times no later than the current simulator time.

        If `id` has a length (a collection of IDs), return a dict mapping each
        ID to a numpy array of its spike times. Otherwise return the array for
        the single ID.
        """
        # 1e-9 tolerance so that a spike at exactly the current time is kept
        # despite floating-point rounding.
        t_limit = self._simulator.state.t + 1e-9

        def spikes_until_now(cell_id):
            spikes = numpy.array(cell_id._cell.spike_times)
            return spikes[spikes <= t_limit]

        if hasattr(id, "__len__"):
            return {cell_id: spikes_until_now(cell_id) for cell_id in id}
        return spikes_until_now(id)

    def _get_all_signals(self, variable, ids, clear=False):
        """Return the traces of `variable` for `ids` as a 2D array.

        Rows are samples and columns follow the iteration order of `ids`. If
        the traces are shorter than the expected
        ``rint(tstop / sampling_interval) + 1`` samples, the last sample is
        appended once more. The docstring's expectation is that this shortfall
        is generally caused by floating-point rounding. An empty array is
        returned when `ids` is empty.

        `clear` is accepted for interface compatibility and is not used here.
        """
        # assuming not using cvode, otherwise need to get times as well and use IrregularlySampledAnalogSignal
        if len(ids) == 0:
            return numpy.array([])
        # numpy.vstack needs a sequence: generators are deprecated since
        # NumPy 1.16 and rejected with TypeError by newer releases.
        signals = numpy.vstack([cell_id._cell.traces[variable] for cell_id in ids]).T
        expected_length = int(numpy.rint(self._simulator.state.tstop / self.sampling_interval)) + 1
        n_samples = signals.shape[0]
        # Pad only a non-empty, too-short result. Padding an over-long result
        # would lengthen it further, and an empty one has no last row to repeat.
        if 0 < n_samples < expected_length:  # generally due to floating point/rounding issues
            signals = numpy.vstack((signals, signals[-1:, :]))
        return signals

    def _local_count(self, variable, filter_ids=None):
        """Return a dict mapping ``int(id)`` to the number of recorded spikes.

        Only IDs returned by ``self.filter_recorded(variable, filter_ids)`` are
        counted. Raises NotImplementedError (a subclass of Exception) for any
        variable other than ``'spikes'``.
        """
        if variable != 'spikes':
            raise NotImplementedError("Only implemented for spikes")
        return {
            int(cell_id): cell_id._cell.spike_times.size()
            for cell_id in self.filter_recorded(variable, filter_ids)
        }