#include <limits.h>
#include <float.h>

#include "NsSystem.hh"
#include "NsTract.hh"


// Sanity-check helper: passes Util::isInRange(val, min, max) to ABORT_UNLESS
// together with a message containing the spelled-out name of the checked
// expression (via #val) and its current value.
// NOTE(review): whether the bounds are inclusive is decided by
// Util::isInRange, which is not visible here -- confirm before relying on it.
#define CHECK_RANGE(val, min, max) \
    ABORT_UNLESS(Util::isInRange(val, min, max), \
                 "bad value for '{}': {}", #val, val)


/**
 * Construct a tract connecting the units of one layer to those of another.
 *
 * The tract's tunable parameters are read from the property store under
 * keys of the form "<type>.<parameterName>". Parameters that are checked
 * with CHECK_RANGE must pass the (0.0, 1.0) range test. Finally, one
 * connection is created for every (from-unit, to-unit) pair whose two
 * units are not the same object.
 *
 * @param id        Tract identifier
 * @param fromLayer Layer supplying the "from" unit of each connection
 * @param toLayer   Layer supplying the "to" unit of each connection
 * @param type      Tract type; used as the property-key prefix
 */
NsTract::NsTract(const string &id,
                 NsLayer *fromLayer,
                 NsLayer *toLayer,
                 const string &type)
    : id(id), type(type), fromLayer(fromLayer), toLayer(toLayer),
      e3Level(0), lastE3Level(DBL_MAX), lastTimeStep(UINT_MAX)
{
    // Look up "<type>.<name>" in the property store
    //
    const string keyPrefix = type + '.';
    auto readParam = [&](const char *name) {
        return props.getDouble(keyPrefix + name);
    };

    acqLearnRate            = readParam("acqLearnRate");
    reactE3Level            = readParam("reactE3Level");
    consLearnRate01h        = readParam("consLearnRate01h");
    psdDecayRate01h         = readParam("psdDecayRate01h");
    cpAmparRemovalRate01h   = readParam("cpAmparRemovalRate01h");
    ciAmparInsertionRate01h = readParam("ciAmparInsertionRate01h");
    ciAmparRemovalRate01h   = readParam("ciAmparRemovalRate01h");
    baseDepotProb01h        = readParam("baseDepotProb01h");
    maxE3DepotProb01h       = readParam("maxE3DepotProb01h");
    e3DecayRate01h          = readParam("e3DecayRate01h");
    maxPotProb01h           = readParam("maxPotProb01h");

    // Sanity check
    //
    // Note: ciAmparInsertionRate01h is not range-checked; calcRates()
    // rescales it with calcConstantRate rather than as a decay rate or
    // probability.
    //
    CHECK_RANGE(acqLearnRate,            0.0, 1.0);
    CHECK_RANGE(reactE3Level,            0.0, 1.0);
    CHECK_RANGE(consLearnRate01h,        0.0, 1.0);
    CHECK_RANGE(psdDecayRate01h,         0.0, 1.0);
    CHECK_RANGE(cpAmparRemovalRate01h,   0.0, 1.0);
    CHECK_RANGE(ciAmparRemovalRate01h,   0.0, 1.0);
    CHECK_RANGE(baseDepotProb01h,        0.0, 1.0);
    CHECK_RANGE(e3DecayRate01h,          0.0, 1.0);
    CHECK_RANGE(maxE3DepotProb01h,       0.0, 1.0);
    CHECK_RANGE(maxPotProb01h,           0.0, 1.0);

    // Fully connect the two layers, skipping any unit paired with itself
    //
    for (auto src : fromLayer->units) {
        for (auto dst : toLayer->units) {
            if (src == dst) {
                continue;
            }
            connections.push_back(new NsConnection(this, src, dst));
        }
    }
}

/**
 * Convert an exponential decay rate given for an interval A into the
 * equivalent rate for another interval B.
 *
 * Plain exponential decay over one interval A is
 *
 *     x(t+A) = (1 - rateA) * x(t)
 *
 * In "increasing exponential decay" (also called the increasing or rising
 * form of exponential decay) it is the distance to an asymptote S that
 * decays by the factor (1 - rateA):
 *
 *     S - x(t+A) = (1 - rateA) * (S - x(t))
 *
 * In both cases the fraction retained over interval B is
 * (1 - rateA)^(B/A), so
 *
 *     rateB = 1 - (1 - rateA)^(B/A)
 *
 * @param rateA     Decay rate over interval A
 * @param intervalA Length of interval A (used as a divisor; not checked)
 * @param intervalB Length of interval B
 * @return The equivalent decay rate over interval B
 */
static double calcExpDecayRate(double rateA, double intervalA, double intervalB)
{
    const double exponent = intervalB / intervalA;
    const double retained = pow(1.0 - rateA, exponent);
    return 1.0 - retained;
}

/**
 * Convert the probability of an event happening during some interval A
 * into the equivalent probability for another interval B.
 *
 *     P(B) = 1 - (1 - P(A))^(B/A)
 *
 * Note: this is exactly the same formula as calcExpDecayRate, which is not
 *       surprising, since a constant probability of decay at the particle
 *       level translates to a constant *rate* of decay at the population
 *       level.
 *
 * @param probA     Probability of the event during interval A
 * @param intervalA Length of interval A (used as a divisor; not checked)
 * @param intervalB Length of interval B
 * @return The equivalent probability for interval B
 */
static double calcProb(double probA, double intervalA, double intervalB)
{
    const double exponent = intervalB / intervalA;
    const double probNoEvent = pow(1.0 - probA, exponent);
    return 1.0 - probNoEvent;
}

/**
 * Convert a constant (linear) rate given for an interval A into the
 * equivalent rate for another interval B by scaling it with B/A.
 *
 * @param rateA     Rate over interval A
 * @param intervalA Length of interval A (used as a divisor; not checked)
 * @param intervalB Length of interval B
 * @return The equivalent rate over interval B
 */
static double calcConstantRate(double rateA, double intervalA, double intervalB)
{
    const double scale = intervalB / intervalA;
    return scale * rateA;
}

/**
 * Recalculate all per-time-step rates and probabilities for the current
 * timeStep value.
 *
 * Each *01h member holds the value for a reference interval of 1.0 (in
 * whatever units timeStep is expressed in); it is rescaled here to an
 * interval of length timeStep.
 */
void NsTract::calcRates()
{
    const double refInterval = 1.0;

    // Rates rescaled as exponential decay
    //
    consLearnRate      = calcExpDecayRate(consLearnRate01h, refInterval, timeStep);
    psdDecayRate       = calcExpDecayRate(psdDecayRate01h, refInterval, timeStep);
    cpAmparRemovalRate = calcExpDecayRate(cpAmparRemovalRate01h, refInterval,
                                          timeStep);
    ciAmparRemovalRate = calcExpDecayRate(ciAmparRemovalRate01h, refInterval,
                                          timeStep);
    e3DecayRate        = calcExpDecayRate(e3DecayRate01h, refInterval, timeStep);

    // The one rate that is rescaled linearly
    //
    ciAmparInsertionRate = calcConstantRate(ciAmparInsertionRate01h, refInterval,
                                            timeStep);

    // Probabilities
    //
    baseDepotProb = calcProb(baseDepotProb01h, refInterval, timeStep);
    maxPotProb    = calcProb(maxPotProb01h, refInterval, timeStep);

    // Must come last: combines the freshly computed baseDepotProb with the
    // E3 contribution
    //
    calcDepotProb();
}

/*
 * Calculate the total probability of depotentiation as the combination of
 * two independent probabilities: the constitutive depotentiation
 * (baseDepotProb) and depotentition due to the E3 enzyme (e3DepotProb)
 *
 * This function is called whenever e3Level or timeStep changes.
 */
inline void NsTract::calcDepotProb()
{
    // Don't waste time if nothing changed
    //
    if (e3Level != lastE3Level || timeStep != lastTimeStep) {
        double e3DepotProb01h = maxE3DepotProb01h * e3Level;
        e3DepotProb = calcProb(e3DepotProb01h, 1.0, timeStep);
        depotProb = baseDepotProb + e3DepotProb - baseDepotProb * e3DepotProb;

        ABORT_IF(depotProb > 1.0, "impossible");

        lastE3Level = e3Level;
        lastTimeStep = timeStep;
    }
}

/**
 * Stimulate every connection of the tract.
 * @param learnRate     Learning rate forwarded to each connection
 * @param numStimCycles Number of stimulation cycles forwarded to each
 *                      connection
 * @param tag           Tag forwarded to each connection
 */
void NsTract::stimulate(double learnRate, uint numStimCycles,
                        const char *tag)
{
    for (auto *conn : connections) {
        conn->stimulate(learnRate, numStimCycles, tag);
    }
}

/**
 * Stimulate all of the tract's connections using the acquisition learning
 * rate (acqLearnRate).
 * @param numStimCycles Passed through unchanged to stimulate()
 * @param tag           Passed through unchanged to stimulate()
 */
void NsTract::acquire(uint numStimCycles, const char *tag)
{
    stimulate(acqLearnRate, numStimCycles, tag);
}

/**
 * Stimulate all of the tract's connections using the consolidation
 * learning rate (consLearnRate, as set by calcRates()) and the fixed tag
 * "cons".
 * @param numStimCycles Passed through unchanged to stimulate()
 */
void NsTract::consolidate(uint numStimCycles)
{
    stimulate(consLearnRate, numStimCycles, "cons");
}

/**
 * Run AMPAR trafficking on every connection of the tract, handing each one
 * the tract's current per-time-step CP-AMPAR removal, CI-AMPAR insertion
 * and CI-AMPAR removal rates.
 */
void NsTract::amparTrafficking()
{
    for (auto *conn : connections) {
        conn->amparTrafficking(cpAmparRemovalRate, ciAmparInsertionRate,
                               ciAmparRemovalRate);
    }
}

/**
 * Randomly depotentiate some of the potentiated connections.
 *
 * For each potentiated connection a value is drawn with
 * Util::randDouble(0.0, 1.0); the connection is depotentiated when the
 * draw is below the tract's current depotProb. No value is drawn for
 * connections that are not potentiated.
 */
void NsTract::depotentiateSome()
{
    for (auto *conn : connections) {
        if (!conn->isPotentiated) {
            continue;
        }

        const auto draw = Util::randDouble(0.0, 1.0);
        if (draw < depotProb) {
            conn->depotentiate("random");
        }
    }
}

/**
 * Run one round of the tract's maintenance processes.
 *
 * The order of the steps matters: random depotentiation runs first and
 * therefore uses the depotProb that was valid before this call's E3 decay;
 * depotProb is only refreshed after e3Level has been lowered.
 */
void NsTract::maintain()
{
    // Uses the current depotProb
    depotentiateSome();
    amparTrafficking();
    // Decay the E3 level by the fraction e3DecayRate
    e3Level -= e3DecayRate * e3Level;
    
    // Recalculate depotentiation probability after updating E3 level
    //
    calcDepotProb();

    // Trace the resulting state of the tract
    debugTrace("time: {} tract: {}  e3Level: {}  depotProb: {}\n",
               simTime, id, e3Level, depotProb);
}

/**
 * Switch PSI on or off on every connection of the tract.
 * @param state Value forwarded to each connection's togglePsi()
 */
void NsTract::togglePsi(bool state)
{
    for (auto *conn : connections) {
        conn->togglePsi(state);
    }
}

/**
 * Set the E3 level to reactE3Level and invoke reactivation processing in
 * all connections that are in the Hebbian condition, i.e. whose from-unit
 * and to-unit are both active.
 */
void NsTract::reactivate()
{
    // Activate the E3 enzyme (E3 increases the probability of
    // depotentiation) and refresh the depotentiation probability.
    //
    // TODO: should this be restricted to connections originating from or
    // terminating on the units selected in makePattern below? i.e. units
    // activated by reactivation.
    //
    e3Level = reactE3Level;
    calcDepotProb();

    for (auto *conn : connections) {
        if (!conn->fromUnit->isActive) {
            continue;
        }
        if (!conn->toUnit->isActive) {
            continue;
        }
        conn->reactivate();
    }
}


/**
 * Count the tract's connections whose isPotentiated flag is set.
 * @return Number of potentiated connections
 */
uint NsTract::getNumPotentiated() const
{
    uint numPotentiated = 0;
    for (auto *conn : connections) {
        numPotentiated += conn->isPotentiated ? 1 : 0;
    }
    return numPotentiated;
}

/**
 * Print the header line for the numPotentiated printouts produced by
 * printNumPotentiated(). The header is emitted through infoTrace.
 */
void NsTract::printNumPotentiatedHdr()
{
    infoTrace("time tract id numPotentiated\n");
}

/**
 * Print the number of potentiated connections in the tract.
 *
 * Emits one line through infoTrace containing simTime, the literal word
 * "tract", the tract id and the count returned by getNumPotentiated().
 */
void NsTract::printNumPotentiated() const
{
    infoTrace("{} tract {} {}\n", simTime, id, getNumPotentiated());
}

/**
 * Print the tract's state: first the number of potentiated connections,
 * then the state of each connection in turn.
 */
void NsTract::printState() const
{
    // Summary line for the tract as a whole
    printNumPotentiated();

    // Per-connection detail
    for (auto *conn : connections) {
        conn->printState();
    }
}

/**
 * Build a multi-line string representation of the tract and all of its
 * connections.
 *
 * The header line is indented with iLvl repetitions of iStr; the
 * acqLearnRate and consLearnRate lines are indented one level deeper, and
 * each connection is rendered with indentation level iLvl + 1.
 *
 * @param iLvl Indentation level of the tract's header line
 * @param iStr String repeated once per indentation level
 * @return The string representation
 */
string NsTract::toStr(uint iLvl, const string &iStr) const
{
    const uint childLvl = iLvl + 1;

    string out = fmt::format("{}NsTract[{}]: ",
                             Util::repeatStr(iStr, iLvl), id);

    // Both attribute lines share the same (one level deeper) indent
    const auto childIndent = Util::repeatStr(iStr, childLvl);
    out += fmt::format("\n{}acqLearnRate={}", childIndent, acqLearnRate);
    out += fmt::format("\n{}consLearnRate={}", childIndent, consLearnRate);

    for (auto *conn : connections) {
        out += "\n";
        out += conn->toStr(childLvl, iStr);
    }
    return out;
}