from __future__ import division
import dipoleWrapperNeuron
import poissonSourceAxon
from pylab import *
from random import randint, gauss
from numpy import *
from scipy.stats import poisson
import pickle
from scipy.signal import butter, lfilter, freqz
from matplotlib.font_manager import FontProperties, findfont
import matplotlib.font_manager as fm
from mpl_toolkits.mplot3d import Axes3D
# Set up time
tau = 0.1
maxTime = 2.0
maxTimeInMS = maxTime * 1000
maxTimeStep = maxTime / tau
tspan = arange(0,maxTimeInMS/tau, 1)
T1=len(tspan)/5
T2 = 4*T1
# Neuron Parameters
neuronHeight = 0.00056
predictionEOnset = 300
predictionIOnset = 320
predictionEOffset = 1800
predictionIOffset = 1820
sensoryEOnset = 100
sensoryEOffset = 1800
sensoryIOnset = 120
sensoryIOffset = 1820
predictionEWeight = 15
predictionIWeight = 12.5
sensoryEWeight = 4
sensoryIWeight = 18
predictionELambda = 0.01
predictionILambda = 0.01
sensoryELambda = 0.01
sensoryILambda = 0.01
offLambda = 0.0001
predictionInputCount = 200
sensoryInputCount = 200
ratioOfCorrelatedInhibition = 0.5
correlatedPredictionCount = int(sensoryInputCount * ratioOfCorrelatedInhibition)
unCorrelatedPredictionCount = sensoryInputCount - correlatedPredictionCount
correlatedSensoryCount = int(predictionInputCount * ratioOfCorrelatedInhibition)
unCorrelatedSensoryCount = predictionInputCount - correlatedSensoryCount
# Set the pharmacological manipulations
ACh = False
Retigabine = False
RetigabinePartial = False
RetigabineRescue = False
PCP = False
PCPPartial = False
# Uncomment the label you want for the saved plots, indicating which pharmacological manipulations were present.
pharmas = 'Baseline'
# pharmas = 'PCP_Partial'
# pharmas = 'PCP'
# pharmas = 'Retigabine'
# pharmas = 'RetigabinePartial'
# pharmas = 'PCPPartialPlusRetigabine'
# pharmas = 'PCPPlusRetigabine' # Used this one.
# pharmas = 'PCPPlusRetigabinePartial'
# pharmas = 'PCPPlusRetigabineRescueWeight'
# pharmas = 'TargetPrimedBaseline'
# pharmas = 'ACh'
# pharmas = 'AChPerSE'
# pharmas = 'noEPrime'
# pharmas = 'noIPrime'
trialCount = 10
correlationMax = 1
correlationRange = linspace(1.0, 0, 11)
correlationRangeLen = len(correlationRange)
trialRec = 0
noiseIdxRec = 0
## Robustness records
results = []
TPRecord = zeros((trialCount, correlationRangeLen))
FPRecord = zeros((trialCount, correlationRangeLen))
FNRecord = zeros((trialCount, correlationRangeLen))
TNRecord = zeros((trialCount, correlationRangeLen))
# # Uncomment to resume where you left off
# pickleJar = open("pickle.jar", "rb")
# [trialRec, noiseIdxRec, FPRecord, TPRecord, TNRecord, FNRecord] = pickle.load(pickleJar) # primedVoltageRaster, unprimedVoltageRaster
for trial in range(trialRec, trialCount):
print("TRIAL: ", trial)
# Loop
for noiseIdx in range(noiseIdxRec, correlationRangeLen):
ratioOfCorrelatedInhibition = correlationRange[noiseIdx]
print("\tCorrelated Proportion: ", ratioOfCorrelatedInhibition)
correlatedPredictionCount = int(predictionInputCount * ratioOfCorrelatedInhibition)
unCorrelatedPredictionCount = predictionInputCount - correlatedPredictionCount
for primeLambdas in [0.01, 0.001]:
sensoryELambda = primeLambdas
sensoryILambda = primeLambdas
# Set up the neuron
neuron = dipoleWrapperNeuron.DipoleWrapperNeuron(inputs_a=[], inputs_b=[], outputs=[], externalInput_a=0.0, externalInput_b=0.0, distance = neuronHeight, position = [randint(1,100), randint(1,100), 0.0], debug=False, tau=tau, tspan=tspan, inactivating=False)
if ACh:
neuron.base.gdapt = 3.2
neuron.apex.gdapt = 3.2
if PCP:
neuron.base.g_nmda = 1.75e-3
neuron.apex.g_nmda = 1.75e-3
sensoryILambda = 0.001
predictionILambda = 0.001
if PCPPartial:
neuron.base.g_nmda = 1.9e-3
neuron.apex.g_nmda = 1.9e-3
sensoryILambda = 0.009
if Retigabine:
neuron.base.gdapt = 5.0
neuron.apex.gdapt = 5.0
sensoryIWeight += 6.0
if RetigabinePartial:
neuron.base.gdapt = 4.5
neuron.apex.gdapt = 4.5
sensoryIWeight += 5.5
if RetigabineRescue:
neuron.base.gdapt = 4.1
neuron.apex.gdapt = 4.1
sensoryIWeight += 4
predictionEConnections = []
predictionIConnections = []
sensoryEConnections = []
sensoryIConnections = []
# Set up prediction inputs from CA3
for n in range(correlatedPredictionCount):
weightE = predictionEWeight * gauss(1.0, 0.1)
weightI = -1 * predictionIWeight * gauss(1.0, 0.1)
predictionEConnections.append(poissonSourceAxon.PoissonAxon(tau, predictionELambda, predictionEOnset, sensoryEOffset, offLambda))
neuron.addInput_b(predictionEConnections[-1], weightE)
neuron.addInput_b(predictionEConnections[-1], weightI)
for n in range(unCorrelatedPredictionCount):
weightE = predictionEWeight * gauss(1.0, 0.1)
weightI = -1 * predictionIWeight * gauss(1.0, 0.1)
predictionEConnections.append(poissonSourceAxon.PoissonAxon(tau, predictionELambda, predictionEOnset, sensoryEOffset, offLambda))
neuron.addInput_b(predictionEConnections[-1], weightE)
predictionIConnections.append(poissonSourceAxon.PoissonAxon(tau, predictionILambda, predictionIOnset, sensoryIOffset, offLambda))
neuron.addInput_b(predictionIConnections[-1], weightI)
# Set up sensory inputs from EC3
for n in range(correlatedSensoryCount):
weightE = sensoryEWeight * gauss(1.0, 0.1)
weightI = -1 * sensoryIWeight * gauss(1.0, 0.1)
sensoryEConnections.append(poissonSourceAxon.PoissonAxon(tau, sensoryELambda, sensoryEOnset, sensoryEOffset, offLambda))
neuron.addInput_a(sensoryEConnections[-1], weightE)
neuron.addInput_a(sensoryEConnections[-1], weightI)
for n in range(unCorrelatedSensoryCount):
weightE = sensoryEWeight * gauss(1.0, 0.1)
weightI = -1 * sensoryIWeight * gauss(1.0, 0.1)
sensoryEConnections.append(poissonSourceAxon.PoissonAxon(tau, sensoryELambda, sensoryEOnset, sensoryEOffset, offLambda))
neuron.addInput_a(sensoryEConnections[-1], weightE)
sensoryIConnections.append(poissonSourceAxon.PoissonAxon(tau, sensoryILambda, sensoryIOnset, sensoryIOffset, offLambda))
neuron.addInput_a(sensoryIConnections[-1], weightI)
# Set up records
Voltage = zeros(int(maxTimeInMS / tau))
# Run the simluation
for time in range(len(tspan)):
# if time > drivingOnset:
# neuron.base.debug = True
# print("time:", time)
neuron.step(time)
Voltage[time] = neuron.base.v
for connection in predictionEConnections:
connection.step()
for connection in predictionIConnections:
connection.step()
for connection in sensoryEConnections:
connection.step()
for connection in sensoryIConnections:
connection.step()
# Compute TP/TN/FP/FN
onsetSpikeRate = len([neuron.base.spikeRecord[s] for s in range(len(neuron.base.spikeRecord)) if
neuron.base.spikeRecord[s][0] > int(floor(predictionEOnset / tau)) and
neuron.base.spikeRecord[s][0] < int(floor(predictionEOnset / tau + 200 / tau))])
sustainedSpikeRate = len([neuron.base.spikeRecord[s] for s in range(len(neuron.base.spikeRecord)) if
neuron.base.spikeRecord[s][0] > int(floor(predictionEOnset / tau + 201 / tau)) and
neuron.base.spikeRecord[s][0] < int(floor(predictionEOffset / tau))])
sum(neuron.base.spikeRecord[int(floor(predictionEOnset / tau + 201 / tau)): int(
floor(predictionEOffset / tau))])
# print(onsetSpikeRate)
# print(sustainedSpikeRate)
if ((onsetSpikeRate + 0.0001) / (sustainedSpikeRate + 0.0001) >= 2.0) and (sustainedSpikeRate < 2):
phasic = 1
# print "PHASIC!"
else:
phasic = 0
if primeLambdas < 0.01 and phasic == 1:
FPRecord[trial, noiseIdx] = 1
print("False Positive!")
figure()
plot(arange(0, maxTimeInMS, tau), Voltage, linewidth=1.5)
# plot(arange(0, maxTimeInMS, tau), noiseRecord, linewidth=1.0)
title("False Positive with Noise:" + str(ratioOfCorrelatedInhibition))
xlabel("Time in MS")
ylabel("Voltage in mV")
a = gca()
ylim([-80, 40])
# show()
elif primeLambdas >= 0.01 and phasic == 1:
TPRecord[trial, noiseIdx] = 1
print("True Positive!")
elif primeLambdas < 0.01 and phasic == 0:
TNRecord[trial, noiseIdx] = 1
print("True Negative!")
elif primeLambdas >= 0.01 and phasic == 0:
FNRecord[trial, noiseIdx] = 1
print("False Negative!")
figure()
plot(arange(0, maxTimeInMS, tau), Voltage, linewidth=1.5)
# plot(arange(0, maxTimeInMS, tau), noiseRecord, linewidth=1.0)
title("False Negative with Noise:" + str(ratioOfCorrelatedInhibition))
xlabel("Time in MS")
ylabel("Voltage in mV")
a = gca()
ylim([-80, 40])
# show()
with open("pickle.jar", "wb") as pickleJar:
pickle.dump([trial, noiseIdx, FPRecord, TPRecord, TNRecord, FNRecord], pickleJar)
noiseIdxRec = 0
FPToPlot = mean(FPRecord, 0) / 2
TPToPlot = mean(TPRecord, 0) / 2
TNToPlot = mean(TNRecord, 0) / 2
FNToPlot = mean(FNRecord, 0) / 2
FPToPlot = [sum(FPRecord[:, n]) for n in range(correlationRangeLen)]
TPToPlot = [sum(TPRecord[:, n]) for n in range(correlationRangeLen)]
TNToPlot = [sum(TNRecord[:, n]) for n in range(correlationRangeLen)]
FNToPlot = [sum(FNRecord[:, n]) for n in range(correlationRangeLen)]
correlationRangeToPlot = linspace(10,0,11)
figure()
plot(correlationRangeToPlot, FPToPlot, linewidth=1.5, label="False Positives")
plot(correlationRangeToPlot, TPToPlot, linewidth=1.5, label="True Positives")
plot(correlationRangeToPlot, TNToPlot, linewidth=1.5, label="True Negatives")
plot(correlationRangeToPlot, FNToPlot, linewidth=1.5, label="False Negatives")
legend()
title("Performance with Decorrelated Inputs")
xlabel("Proportion of Linked E/I Inputs")
ylabel("Number of Trials")
a = gca()
# ylim([-0.1,1.1])
labels = [item.get_text() for item in a.get_xticklabels()]
labels[1] = 1.0
labels[2] = 0.8
labels[3] = 0.6
labels[4] = 0.4
labels[5] = 0.2
labels[6] = 0.0
a.set_xticklabels(labels)
figure()
plot(correlationRangeToPlot, FPToPlot, linewidth=1.5, label="False Positives")
plot(correlationRangeToPlot, TNToPlot, linewidth=1.5, label="True Negatives")
legend()
title("Performance with Decorrelated Inputs: Unprimed Trials")
xlabel("Proportion of Linked E/I Inputs")
ylabel("Number of Trials")
a = gca()
# ylim([-0.1,1.1])
labels = [item.get_text() for item in a.get_xticklabels()]
labels[1] = 1.0
labels[2] = 0.8
labels[3] = 0.6
labels[4] = 0.4
labels[5] = 0.2
labels[6] = 0.0
a.set_xticklabels(labels)
figure()
plot(correlationRangeToPlot, TPToPlot, linewidth=1.5, label="True Positives")
plot(correlationRangeToPlot, FNToPlot, linewidth=1.5, label="False Negatives")
legend()
title("Performance with Decorrelated Inputs: Primed Trials")
xlabel("Proportion of Linked E/I Inputs")
ylabel("Number of Trials")
a = gca()
# ylim([-0.1,1.1])
labels = [item.get_text() for item in a.get_xticklabels()]
labels[1] = 1.0
labels[2] = 0.8
labels[3] = 0.6
labels[4] = 0.4
labels[5] = 0.2
labels[6] = 0.0
a.set_xticklabels(labels)
show()