#include "RewardGenerator.h"
RewardGenerator::RewardGenerator(double Rduration, double Rnegbase, bool AvgRewardZero, bool OneOverRateScale)
: rewardDuration(Rduration), Rnegbase(Rnegbase), AvgRewardZero(AvgRewardZero), OneOverRateScale(OneOverRateScale)
{}
RewardGenerator::~RewardGenerator()
{
}
int RewardGenerator::reset(double dt)
{
A = 1;
stepsLeftForReward = int(rewardDuration / dt);
lastSpikeTime = - stepsLeftForReward * dt - 2 * dt;
stepsLeftForReward = 0;
reward = Rnegbase;
instRate = 0;
isActive = true;
return 0;
}
int RewardGenerator::advance( AdvanceInfo const& )
{
if (stepsLeftForReward > 0)
reward = A;
else
reward = Rnegbase;
stepsLeftForReward = stepsLeftForReward - 1;
if (!isActive)
reward = Rnegbase;
return 0;
}
int RewardGenerator::spikeHit( spikeport_t port, SpikeEvent const& spike )
{
if (stepsLeftForReward > 0)
return 0;
stepsLeftForReward = int(rewardDuration / spike.dt.in_sec());
if (AvgRewardZero)
if (OneOverRateScale)
A = ((1/instRate) - spike.dt.in_sec()) * fabs(Rnegbase) / (stepsLeftForReward * spike.dt.in_sec());
else
A = (spike.t - lastSpikeTime - (stepsLeftForReward * spike.dt.in_sec())) * fabs(Rnegbase) / (stepsLeftForReward * spike.dt.in_sec());
lastSpikeTime = spike.t;
return 0;
}
double RewardGenerator::getAnalogOutput(analog_port_id_t port ) const
{
return reward;
}
//! Analog input to given port
void RewardGenerator::setAnalogInput(double value, analog_port_id_t port)
{
instRate = value;
}
int RewardGenerator::nSpikeInputPorts() const
{
return 1;
}
int RewardGenerator::nAnalogInputPorts() const
{
return 1;
}
int RewardGenerator::nAnalogOutputPorts() const
{
return 1;
}
SimObject::PortType RewardGenerator::outputPortType(port_t p) const
{
if (p == 0)
return analog;
return undefined;
}
SimObject::PortType RewardGenerator::inputPortType(port_t p) const
{
if (p == 0)
return spiking;
else if (p == 1)
return analog;
return undefined;
}