# Compute mutual info between the presence of the stimulus and the postynaptic responses
# Results are plotted, and dumped in data.mat
from scipy.stats import norm
printtime('**************************')
printtime('* Mutual Info (rand=' + '%03d' % randState + ') *')
printtime('**************************')
# note: patternPeriod should be loaded
eval = array([ max(0,round(endTime)-200.0/pbTimeCompression), round(endTime) ]) # period of evaluation
# load pattern periods
if os.path.exists(os.path.join('..','data','patternPeriod.'+'%03d' % (randState)+'.mat')):
patternPeriod=loadmat(os.path.join('..','data','patternPeriod.'+'%03d' % (randState)+'.mat'))
patternPeriod=patternPeriod['patternPeriod']/pbTimeCompression
for i in range(1):
#dt = min(100*second,endTime) # period of evaluation
#idx = 0; # cursor for patternPeriod array
#
## init arrays
#nH = zeros(M)
#nFA = zeros(M)
#if a>0: #oscillatory mode
# # gather data
# for i in range(outputSpike.nspikes):
# if outputSpike.spikes[i][1]<endTime-dt:
# continue
# while idx<len(patternPeriod) and patternPeriod[idx][1] + .5/oscilFreq < outputSpike.spikes[i][1]:
# idx+=1
# if idx<len(patternPeriod) and patternPeriod[idx][0] + .5/oscilFreq < outputSpike.spikes[i][1] and patternPeriod[idx][1]-patternPeriod[idx][0]>.5/oscilFreq:
# nH[outputSpike.spikes[i][0]] += 1
# else:
# nFA[outputSpike.spikes[i][0]] += 1
#
#else:
# if useReset: # reset mode
evalPeriod = 1/oscilFreq# time bins
# evalPeriod = 50e-3# time bins
if poissonOutput:
nSpikeCriterion = 1 # spike number threshold
else:
nSpikeCriterion = 1 # spike number threshold
latencyRange = [ 0e-3 ] #1e-3*arange(0,0,25) # causes a 1/2 latency shift between stimulus times and response times
maxI = zeros(len(latencyRange))
for l in range(len(latencyRange)):
latency = latencyRange[l]
idxP = 0 # cursor for pattern periods
idxS = 0 # cursor for postsynaptic spikes
nHit = zeros(M)
nMissed = zeros(M)
nFA = zeros(M) # false alarm
nCR = zeros(M) # correct rejection
# move to pertinent pattern periods
while idxP<len(patternPeriod) and patternPeriod[idxP][1] < eval[0]:
idxP+=1
# move to pertinent spikes
while idxS<outputSpike.nspikes and outputSpike.spikes[idxS][1]-.5*latency< eval[0]:
idxS+=1
for i in range(int((eval[1]-eval[0])/evalPeriod)):
t = eval[0]+i*evalPeriod
# determine if evalPeriod is mostly a stimulus or a distractor
patternTime = 0
while idxP<len(patternPeriod) and patternPeriod[idxP][0] < t+evalPeriod:
patternTime += min(t+evalPeriod,patternPeriod[idxP][1]) - max(t,patternPeriod[idxP][0])
if patternPeriod[idxP][1] < t+evalPeriod:
idxP+=1
else:
break
# print patternTime
# count spikes
spikeCount = zeros(M)
while idxS<outputSpike.nspikes and outputSpike.spikes[idxS][1]-.5*latency< t+evalPeriod:
spikeCount[outputSpike.spikes[idxS][0]] += 1
idxS+=1
if patternTime>.5*evalPeriod: # stimulus bin
nHit = nHit + (spikeCount>=nSpikeCriterion)
nMissed = nMissed + (spikeCount<nSpikeCriterion)
else: # distractor bin
nFA = nFA + (spikeCount>=nSpikeCriterion)
nCR = nCR + (spikeCount<nSpikeCriterion)
I = zeros(M)
dprime = zeros(M)
card = int((eval[1]-eval[0])/evalPeriod)
for m in range(M):
if nHit[m]>0:
I[m] += nHit[m] * log( card * nHit[m]/(nFA[m]+nHit[m])/(nMissed[m]+nHit[m]) )
if nFA[m]>0:
I[m] += nFA[m] * log( card * nFA[m]/(nFA[m]+nCR[m])/(nFA[m]+nHit[m]) )
if nMissed[m]>0:
I[m] += nMissed[m] * log( card * nMissed[m]/(nMissed[m]+nCR[m])/(nMissed[m]+nHit[m]) )
if nCR[m]>0:
I[m] += nCR[m] * log( card * nCR[m]/(nMissed[m]+nCR[m])/(nFA[m]+nCR[m]) )
if nHit[m]>0 and nFA[m]>0 and nMissed[m]>0 and nCR[m]>0:
dprime[m] = norm.ppf(nHit[m]/(nHit[m]+nMissed[m])) - norm.ppf(nFA[m]/(nFA[m]+nCR[m]))
I *= 1.0/card * 1/log(2)
Hs = - 1/log(2) * 1.0/card * ( (nMissed[0]+nHit[0]) * log((nMissed[0]+nHit[0])/card) + (nCR[0]+nFA[0]) * log((nCR[0]+nFA[0])/card) )
print 'prop of stim = ' + str((nHit[0]+nMissed[0])/card)
print 'H(stim) = ' + str(Hs)
print 'max(I) = ' + str(max(I))
print 'mean(I) = ' + str(mean(I))
print 'std(I) = ' + str(std(I))
print 'max(I)/H(stim) = ' + str(max(I)/Hs)
print 'mean(I)/H(stim) = ' + str(mean(I)/Hs)
print 'max(dprime) = ' + str(max(dprime))
maxI[l]=max(I)
if os.path.exists(os.path.join('..','data','data.mat')):
data=loadmat(os.path.join('..','data','data.mat'))
data_randState = data['randState']
data_I=data['I']
data_dprime=data['dprime']
data_Hs = data['Hs']
if size(data_randState) == 1:
data_randState = [data_randState]
data_I = [data_I]
data_dprime = [data_dprime]
data_Hs = [data_Hs]
data_randState = concatenate((data_randState,[randState]))
data_I = concatenate((data_I,[I]))
data_dprime = concatenate((data_dprime,[dprime]))
data_Hs = concatenate((data_Hs,[Hs]))
else:
data_randState = array([randState])
data_I = array([I])
data_dprime = array([dprime])
data_Hs = array([Hs])
data={'randState':data_randState,'I':data_I,'dprime':data_dprime,'Hs':data_Hs}
savemat('../data/data.mat',data)
del data
if graph:
figure()
subplot(2,1,1)
imshow(reshape(I,(nR,nG)))
xlabel('gmax')
ylabel('LTD/LTP')
colorbar()
subplot(2,1,2)
imshow(reshape(dprime,(nR,nG)))
xlabel('gmax')
ylabel('LTD/LTP')
colorbar()
eval = eval-100/pbTimeCompression # end of loop on eval periods
if len(latencyRange)>1:
print latencyRange
print maxI
if graph:
figure()
plot(latencyRange,maxI)
if graph:
show()
# for t in range(len(reset)):
# if reset[t]<endTime-dt:
# continue
# if reset[t]>=endTime:
# break
# if reset[t+1]-reset[t]<evalPeriod: # can not evaluate first reset if a second one follows right afterwards
# continue
# # move on to adequate pattern period
# while idxP<len(patternPeriod) and patternPeriod[idxP][1] < reset[t]:
# idxP+=1
#
# # determine if evalPeriod is mostly a stimulus or a distractor
# idxPT = idxP # temporary pattern cursor
# patternTime = 0
# while idxPT<len(patternPeriod) and patternPeriod[idxPT][0] < reset[t]+evalPeriod:
# patternTime += min(reset[t]+evalPeriod,patternPeriod[idxPT][1]) - max(reset[t],patternPeriod[idxPT][0])
# idxPT+=1
#
# # see who has detected it
# hasDetetected = zeros(M)
#