# -*- coding: utf-8 -*-
####################################
#Nakano T, Otsuka M, Yoshimoto J, Doya K (2015) A Spiking Neural Network Model of Model-Free Reinforcement Learning with High-Dimensional Sensory Input and Perceptual Ambiguity. PLoS ONE 10(3): e0115620. doi:10.1371/journal.pone.0115620
####################################
import nest
nest.pynestkernel.logstdout("")
import nest.raster_plot
import nest.voltage_trace
import pylab
import numpy
import scipy
import time, sys
import matplotlib.pyplot as plt
import random
import operator
import math
import csv
# Start from a freshly reset NEST kernel and set its "overwrite_files" flag.
nest.ResetKernel()
nest.SetKernelStatus({"overwrite_files": True})
####################################
class SNNc:
    """Spiking network with state, hidden, action and memory populations (NEST).

    ``__init__`` creates the neurons, the current/noise generators and the
    recorders and wires them together with placeholder weights.  ``InitW``
    draws or loads the weight matrices; ``ConnectW_SM`` and ``ConnectW`` copy
    them onto the NEST connections.
    """

    def __init__(self):
        # define parameters
        self.NS = 20*15  # number of state neurons
        self.NH = 90  # number of hidden neurons
        self.NA = 90  # number of action neurons
        self.NM = 50  # number of memory neurons
        self.InitWhsmean = 20  # mean of the initial state-hidden weights (Whs)
        self.InitWhsstd = 11.88  # std of the initial Whs
        self.InitWhamean = 20  # mean of the initial action-hidden weights (Wha)
        self.InitWhastd = 11.88  # std of the initial Wha
        self.InitWhmmean = 20  # mean of the initial memory-hidden weights (Whm)
        self.InitWhmstd = 11.88  # std of the initial Whm
        # Wmm is loaded from file in InitW; these two values are not read in this class.
        self.InitWmmmean = 0  # mean weight Wmm
        self.InitWmmstd = 0.  # std weight Wmm
        self.Wseed = 0  # seed for weight
        self.Nseed = [123]  # seed for noise
        self.driveparams_on = {'amplitude': 2000.}  # current into an active state neuron
        self.driveparams_on_a = {'amplitude': 2000.}  # current into the selected action group
        self.driveparams_off_a = {'amplitude': -5000.}  # current into the non-selected action groups
        self.driveparams_off = {'amplitude': 0.}  # no current input
        self.driveparams_inh = {'amplitude': -5000.}  # inhibitory current for the memory neurons
        self.noiseparams = {'mean': 0.0, 'std': 300.}  # noise input
        self.sdparams = {"withtime": True, "withgid": True, 'to_file': False, 'to_screen': False,
                         'flush_after_simulate': True, 'flush_records': True}
        #neuronparams = { 'tau_m':20., 'V_th':-50., 'E_L':-60., 't_ref':2., 'V_reset':-60., 'C_m':200.}

        # create neurons and devices
        self.Sneurons = nest.Create('iaf_neuron', self.NS)
        self.Hneurons = nest.Create('iaf_neuron', self.NH)
        self.Aneurons = nest.Create('iaf_neuron', self.NA)
        self.Mneurons = nest.Create('iaf_neuron', self.NM)
        self.sd = nest.Create('spike_detector')
        # one generator per state neuron, four for the action groups, one for the memory neurons
        self.drive = nest.Create('dc_generator', self.NS+4+1)
        nest.SetKernelStatus({'rng_seeds': self.Nseed})  # noise
        self.noise = nest.Create('noise_generator', 6)
        self.voltmeter = nest.Create("voltmeter")

        # set parameters
        nest.SetStatus(self.sd, [self.sdparams])
        nest.SetStatus(self.noise, [self.noiseparams])  # if noise selection works, comment out this.

        # connect: one noise generator per population
        for k, population in enumerate((self.Sneurons, self.Aneurons, self.Hneurons, self.Mneurons)):
            nest.DivergentConnect(self.noise[k:k+1], population)
        # one DC generator per state neuron
        for i in range(self.NS):
            nest.Connect(self.drive[i:i+1], self.Sneurons[i:i+1])
        # one DC generator per quarter of the action population; floor division
        # keeps the slice bounds integral on Python 2 and Python 3 alike
        bounds = [0, self.NA//4, self.NA//2, self.NA*3//4, self.NA]
        for k in range(4):
            nest.DivergentConnect(self.drive[self.NS+k:self.NS+k+1],
                                  self.Aneurons[bounds[k]:bounds[k+1]])
        nest.DivergentConnect(self.drive[self.NS+4:self.NS+5], self.Mneurons)

        # The weights given here are placeholders; ConnectW/ConnectW_SM overwrite them.
        # The order of these calls is kept exactly as is: ConnectW assigns weight
        # lists to each neuron's outgoing connections in the order
        # hidden/memory/spike-detector, which presumably relies on
        # FindConnections returning connections in creation order (confirm).
        nest.ConvergentConnect(self.Sneurons, self.Hneurons, weight=100.0, delay=1.0)
        nest.ConvergentConnect(self.Hneurons, self.Aneurons, weight=100.0, delay=1.0)
        nest.ConvergentConnect(self.Aneurons, self.Hneurons, weight=100.0, delay=1.0)
        nest.ConvergentConnect(self.Mneurons, self.Hneurons, weight=100.0, delay=1.0)
        nest.ConvergentConnect(self.Mneurons, self.Mneurons, weight=0.0, delay=1.0)
        nest.ConvergentConnect(self.Sneurons, self.Mneurons, weight=0.0, delay=1.0)
        for population in (self.Sneurons, self.Hneurons, self.Aneurons, self.Mneurons):
            nest.ConvergentConnect(population, self.sd)
        nest.Connect(self.voltmeter, self.Hneurons[20:21])

        # weight
        random.seed(self.Wseed)

    @staticmethod
    def _read_csv_matrix(filename):
        """Return the CSV file *filename* as a list of rows of floats (file is closed afterwards)."""
        with open(filename) as csvfile:
            return [[float(elem) for elem in row] for row in csv.reader(csvfile)]

    def _random_weights(self, nrows, mean, std):
        """Return an nrows x NH matrix of normally distributed weights, drawn row by row."""
        return [[random.normalvariate(mean, std) for _ in range(self.NH)]
                for _ in range(nrows)]

    def InitW(self):
        """Initialise all weight matrices.

        Whs (NS x NH), Wha (NA x NH) and Whm (NM x NH) are drawn, in that
        order, from normal distributions using the ``random`` module.  Wmm is
        read from ``./MMweight2301.txt`` and Wms from ``Wcd50_noBias.txt``
        (comma-separated floats, one matrix row per line).
        """
        self.Whs = self._random_weights(self.NS, self.InitWhsmean, self.InitWhsstd)
        self.Wha = self._random_weights(self.NA, self.InitWhamean, self.InitWhastd)
        self.Whm = self._random_weights(self.NM, self.InitWhmmean, self.InitWhmstd)
        self.Wmm = self._read_csv_matrix("./MMweight2301.txt")
        self.Wms = self._read_csv_matrix("Wcd50_noBias.txt")

    def ConnectW_SM(self):
        """Write Wmm[i][j] onto the connection from memory neuron i to memory neuron j."""
        for i in range(self.NM):
            for j in range(self.NM):
                conn = nest.FindConnections([self.Mneurons[i]], [self.Mneurons[j]])
                nest.SetStatus(conn, 'weight', self.Wmm[i][j])

    def ConnectW(self):
        """Write Whs/Wms, Wha, its transpose Wah, and Whm onto the NEST connections.

        Every neuron also projects to the spike detector; that connection gets
        weight 1.0, which is why 1.0 is appended to each per-neuron weight list.
        """
        ##Whs (+ Wms): state neuron i -> hidden neurons, memory neurons, spike detector
        for i in range(self.NS):
            outgoing = nest.FindConnections([self.Sneurons[i]])
            nest.SetStatus(outgoing, 'weight', list(self.Whs[i]) + list(self.Wms[i]) + [1.0])
        ##Wha: action neuron i -> hidden neurons, spike detector
        for i in range(self.NA):
            outgoing = nest.FindConnections([self.Aneurons[i]])
            nest.SetStatus(outgoing, 'weight', list(self.Wha[i]) + [1.0])
        ##Wah: transpose of Wha; hidden neuron i -> action neurons, spike detector.
        # The stored rows keep the trailing 1.0, as before.
        self.Wah = [list(column) + [1.0] for column in zip(*self.Wha)]
        for i in range(self.NH):
            outgoing = nest.FindConnections([self.Hneurons[i]])
            nest.SetStatus(outgoing, 'weight', self.Wah[i])
        ##Whm: memory neuron i -> hidden neuron j, set one connection at a time
        for i in range(self.NM):
            for j in range(self.NH):
                conn = nest.FindConnections([self.Mneurons[i]], [self.Hneurons[j]])
                nest.SetStatus(conn, 'weight', self.Whm[i][j])
####################################
def Digit(state=1):
    """Load one randomly chosen digit image for *state* as a flat list of ints.

    State 3 is replaced by the module-level ``GoalState`` before the file is
    chosen.  One of ten sample files
    (``../shrunk_digit_easy_test_20_15T/digit<state>_<1..10>``) is picked with
    ``random.randint`` and its comma-separated values are returned row after
    row in a single list.

    The file is opened in a ``with`` block so the handle is released on every
    call; this function runs twice per simulation step.
    """
    if state == 3:
        state = GoalState
    #filename="digit"+str(state)
    filename = "../shrunk_digit_easy_test_20_15T/digit"+str(state)+"_"+str(random.randint(1, 10))
    with open(filename) as csvfile:
        return [int(elem) for row in csv.reader(csvfile) for elem in row]
####################################
def StateClamp(obs, action, inh=0):
    """Apply the input currents for one observation/action and advance the simulation.

    obs    -- sequence of 0/1 values; entry i == 1 switches on the DC drive of
              state neuron i
    action -- 1..4 switches on the drive of the matching action group and puts
              the negative "off" current on the other three; any other value
              (the callers in this file pass 0) leaves all action drives at zero
    inh    -- 1 additionally applies the inhibitory current to the drive that
              feeds the memory neurons

    Reads the module-level ``SNN`` and ``T``.
    """
    nest.SetStatus(SNN.drive,[SNN.driveparams_off] )# all drives are 0
    for i in range(len(obs)):
        if obs[i]==1:
            nest.SetStatus(SNN.drive[i:i+1],[SNN.driveparams_on] )
    if action==1:
        nest.SetStatus(SNN.drive[SNN.NS:SNN.NS+1],[SNN.driveparams_on_a] )
        nest.SetStatus(SNN.drive[SNN.NS+1:SNN.NS+2],[SNN.driveparams_off_a] )
        nest.SetStatus(SNN.drive[SNN.NS+2:SNN.NS+3],[SNN.driveparams_off_a] )
        nest.SetStatus(SNN.drive[SNN.NS+3:SNN.NS+4],[SNN.driveparams_off_a] )
    if action==2:
        nest.SetStatus(SNN.drive[SNN.NS:SNN.NS+1],[SNN.driveparams_off_a] )
        nest.SetStatus(SNN.drive[SNN.NS+1:SNN.NS+2],[SNN.driveparams_on_a] )
        nest.SetStatus(SNN.drive[SNN.NS+2:SNN.NS+3],[SNN.driveparams_off_a] )
        nest.SetStatus(SNN.drive[SNN.NS+3:SNN.NS+4],[SNN.driveparams_off_a] )
    if action==3:
        nest.SetStatus(SNN.drive[SNN.NS:SNN.NS+1],[SNN.driveparams_off_a] )
        nest.SetStatus(SNN.drive[SNN.NS+1:SNN.NS+2],[SNN.driveparams_off_a] )
        nest.SetStatus(SNN.drive[SNN.NS+2:SNN.NS+3],[SNN.driveparams_on_a] )
        nest.SetStatus(SNN.drive[SNN.NS+3:SNN.NS+4],[SNN.driveparams_off_a] )
    if action==4:
        nest.SetStatus(SNN.drive[SNN.NS:SNN.NS+1],[SNN.driveparams_off_a] )
        nest.SetStatus(SNN.drive[SNN.NS+1:SNN.NS+2],[SNN.driveparams_off_a] )
        nest.SetStatus(SNN.drive[SNN.NS+2:SNN.NS+3],[SNN.driveparams_off_a] )
        nest.SetStatus(SNN.drive[SNN.NS+3:SNN.NS+4],[SNN.driveparams_on_a] )
    if inh==1:
        nest.SetStatus(SNN.drive[SNN.NS+4:SNN.NS+5],[SNN.driveparams_inh] )
        # NOTE(review): with this nesting an inhibitory clamp advances the
        # simulation by 2*T and every other clamp by T, which keeps the end of
        # each step (two clamps) on a multiple of 2*T as BinSPcnt assumes.
        # Confirm that this first Simulate call belongs inside the branch.
        nest.Simulate(T)
    nest.Simulate(T)
####################################
def CalcFR():
    """Count spikes per neuron over the last 100 ms of the current T-ms window.

    Reads the spike detector of the module-level ``SNN``.  The window end is
    the first multiple of ``T`` after the last recorded spike time.

    Returns
    -------
    hiddenFR : list of int
        Spike count of each hidden neuron (global ids NS+1 .. NS+NH).
    actionFR : list of int
        Four totals, one per quarter of the action population.

    Both results are lists of zeros when nothing was recorded or no spike
    falls into the window (previously a float numpy array in that case).
    """
    events = nest.GetStatus(SNN.sd, 'events')[0]
    spikesender = events['senders']  # array
    spiketimes = events['times']  # array

    # index 0 is unused because NEST global ids start at 1
    nids = SNN.NS+SNN.NH+SNN.NA+1
    fr = [0]*nids
    if len(spiketimes) != 0:
        maxTime = (spiketimes[-1]//T + 1)*T
        recent = numpy.nonzero(spiketimes > maxTime-100)[0]  # last 100 ms
        if len(recent) != 0:
            tail = numpy.asarray(spikesender[recent[0]:], dtype=int)
            # one pass over the spikes instead of one list.count() per id;
            # ids beyond the action population (memory neurons) are cut off
            fr = numpy.bincount(tail, minlength=nids)[:nids].tolist()

    hiddenFR = fr[SNN.NS+1:SNN.NS+SNN.NH+1]
    # floor division keeps the group bounds integral under true division too
    first = SNN.NS+SNN.NH+1
    bounds = [0, SNN.NA//4, SNN.NA*2//4, SNN.NA*3//4, SNN.NA]
    actionFR = [sum(fr[first+bounds[k]:first+bounds[k+1]]) for k in range(4)]
    return hiddenFR, actionFR
####################################
def Actionselection(actionFR,episode,state):
    """Choose the next action from the action-group spike counts.

    Every state except 3 returns action 1.  In state 3 the choice is between
    action 2 and action 4, using a two-way softmax over ``actionFR[1]`` and
    ``actionFR[3]``.  The inverse temperature grows linearly with *episode*
    (module-level ``Beta`` and ``Nepisode``).

    The probability is computed as a logistic function of the rate difference.
    This is mathematically the same as exp(b*r2)/(exp(b*r2)+exp(b*r4)) but
    cannot raise OverflowError when the spike counts are large.
    """
    action = 1
    if state == 3:
        beta = Beta*float(episode)/float(Nepisode)+0.5/6.
        advantage = beta*(actionFR[1]-actionFR[3])
        # always exponentiate a non-positive number so math.exp stays in range
        if advantage >= 0:
            p_action2 = 1.0/(1.0+math.exp(-advantage))
        else:
            weight = math.exp(advantage)
            p_action2 = weight/(1.0+weight)
        if random.random() < p_action2:
            action = 2
        else:
            action = 4
    return action
####################################
def StateTrans(state,action,goalflag, length):
    """Return (nextState, reward, goal, goalflag, length) after taking *action* in *state*.

    The reward is -500 unless the episode ends successfully, in which case it
    is 20000 and ``goalflag`` becomes 1.  ``goal`` is 1 whenever state 4 is
    reached.  Success and the way back from state 2 depend on the module-level
    ``InitState`` and ``GoalState``.  ``length`` is returned decremented by one.
    """
    reward = -500.
    goal = 0
    if state == 0:
        nextState = 2 if action == 1 else 0
    elif state == 1:
        nextState = 2 if action == 1 else 1
    elif state == 2:
        if action == 1:
            # the length-dependent choice (3 if length==0 else 2) is followed by
            # an unconditional assignment of 3, so the result is always 3
            nextState = 3
        elif action == 3:
            if InitState == 0:
                nextState = 0
            if InitState == 1:
                nextState = 1
        else:
            nextState = 2
    elif state == 3:
        if action == 1:
            nextState = 3
        elif action == 2 or action == 4:
            goal = 1
            nextState = 4
            # action 2 wins when start and goal match, action 4 when they differ
            if action == 2:
                success = InitState == GoalState
            else:
                success = InitState != GoalState
            if success:
                reward = 20000.
                goalflag = 1
        elif action == 3:
            nextState = 2
    length = length-1
    return nextState, reward, goal, goalflag, length
####################################
def CalcFE_AVE(binSP, state, action, whs, wha, whm):
    """Compute the free energy from binned spikes, averaging over the bins.

    binSP          -- per-neuron spike bins in the order state, hidden, action,
                      memory (presumably 50 entries of 0/1 per neuron, as
                      produced by BinSPcnt -- confirm)
    state, action  -- only used to build the binary arrays for ``Flagsa == 1``
    whs, wha, whm  -- state-hidden, action-hidden and memory-hidden weights

    Returns, for the module-level ``Flagsa != 1``:
        (s_hat, a_hat, h_hat, m_hat, entropy, expEnergy_mean, freeEnergy)
    and for ``Flagsa == 1`` the 6-tuple
        (sarray, aarray, h_hat, entropy, expEnergy_mean, freeEnergy).
    NOTE(review): the callers in this file unpack seven values, so the
    ``Flagsa == 1`` return would not unpack there.
    """
    #h_hat
    maxFR=50.#47 spikes per 1 neuron for 100 ms. 50 is for safety
    #this can be considered as mean over bins if maxFR is 50
    bins=50
    s_hat=[]
    a_hat=[]
    h_hat=[]
    m_hat=[]
    # normalised activity of every neuron: bin total / maxFR, plus a small offset
    for i in range(0,SNN.NS):
        s_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
    for i in range(SNN.NS,SNN.NS+SNN.NH):
        h_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
    for i in range(SNN.NS+SNN.NH,SNN.NS+SNN.NH+SNN.NA):
        a_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
    for i in range(SNN.NS+SNN.NH+SNN.NA,SNN.NS+SNN.NH+SNN.NA+SNN.NM):
        m_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
    #calc entropy
    # NOTE(review): a hidden row that sums to 50 gives x = 1.001, and
    # math.log(1-x) then raises ValueError.
    entropy=0
    for x in h_hat:
        entropy=entropy-x*math.log(x)-(1-x)*math.log(1-x)
    #calc expected energy
    ##create state array
    # NOTE(review): the block sizes 16 and 50 below do not match NS=300 / NA=90
    # set in SNNc; these arrays are only used when Flagsa == 1.
    sarray=[0]*SNN.NS
    if state<3:
        sarray[state*16:state*16+16]=[1]*16
    else:
        sarray[(state-1)*16:(state-1)*16+16]=[1]*16
    sarraya=numpy.array(sarray)
    ##create action array
    aarray=[0]*SNN.NA
    aarray[(action+1)/2*50:(action+1)/2*50+50]=[1]*50
    aarraya=numpy.array(aarray)
    #create matrices
    smat=scipy.mat(binSP[0:SNN.NS][:])
    amat=scipy.mat(binSP[SNN.NS+SNN.NH:SNN.NS+SNN.NH+SNN.NA][:])
    hmat=scipy.mat(binSP[SNN.NS:SNN.NS+SNN.NH][:])
    mmat=scipy.mat(binSP[SNN.NS+SNN.NH+SNN.NA:SNN.NS+SNN.NH+SNN.NA+SNN.NM][:])
    whsmat=scipy.mat(whs)
    whamat=scipy.mat(wha)
    whmmat=scipy.mat(whm)
    #ExpEnergy
    if Flagsa==1:
        expEnergy_s=-sarraya*whsmat*hmat
        expEnergy_a=-aarraya*whamat*hmat
        temp2=expEnergy_a+expEnergy_s
    else:
        # the products pair every bin with every bin; the diagonal keeps the
        # entries where both factors come from the same bin
        expEnergy_s=-smat.T*whsmat*hmat
        expEnergy_a=-amat.T*whamat*hmat
        expEnergy_m=-mmat.T*whmmat*hmat
        temp=expEnergy_a+expEnergy_s+expEnergy_m
        temp2=temp.diagonal()
    expEnergy=temp2.tolist()[0]## convert array back to Python list
    expEnergy_mean=sum(expEnergy)/bins
    #FreeEnergy
    freeEnergy=-entropy+expEnergy_mean
    if Flagsa==1:
        return sarray, aarray, h_hat, entropy, expEnergy_mean, freeEnergy
    else:
        return s_hat, a_hat, h_hat, m_hat, entropy, expEnergy_mean, freeEnergy
####################################
def CalcFE_LPF(binSP, state, action, whs, wha, whm):
    """Compute the free energy from binned spikes with exponential low-pass filtering.

    binSP          -- per-neuron spike bins in the order state, hidden, action,
                      memory (presumably 50 entries of 0/1 per neuron, as
                      produced by BinSPcnt -- confirm)
    state, action  -- not used in this function
    whs, wha, whm  -- state-hidden, action-hidden and memory-hidden weights

    Returns (s_hat, a_hat, h_hat, m_hat, entropy of the last bin, expected
    energy of the last bin, filtered free energy of the last bin, list of the
    filtered free energy for every bin).
    """
    alpha_h=0.1 # filter coefficient for the hidden activity trace
    alpha_f=0.1 # filter coefficient for the free-energy trace
    s_hat=[]
    a_hat=[]
    h_hat=[]
    m_hat=[]
    maxFR=50.#47 spikes per 1 neuron for 100 ms. 50 is for safety
    bins=50
    for i in range(0,SNN.NS):
        s_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
    for i in range(SNN.NS+SNN.NH,SNN.NS+SNN.NH+SNN.NA):
        a_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
    for i in range(SNN.NS+SNN.NH+SNN.NA,SNN.NS+SNN.NH+SNN.NA+SNN.NM):
        m_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
    # h_trace[i][j]: low-pass filtered activity of hidden neuron i at bin j,
    # started at 0.5; h_hat keeps the final value of each trace
    h_trace=[[]]
    bin_h=binSP[SNN.NS:SNN.NS+SNN.NH]
    #binSP_T=map(list, zip(*binSP))#transposition
    for i in range(len(bin_h)):
        for j in range(bins):
            if j==0:
                h_trace[i].append(0.5)
            else:
                h_trace[i].append((1-alpha_h)*h_trace[i][j-1]+alpha_h*bin_h[i][j])
        h_hat.append(h_trace[i][-1])
        h_trace.append([])
    h_trace.pop()
    #calc entropy
    # entropy[i]: sum over hidden neurons j of the cross-entropy between the
    # spike in bin i and the filtered trace at bin i
    entropy=[[]]
    for i in range(len(h_trace[0])):
        for j in range(len(h_trace)):
            if j==0:
                entropy[i]=-bin_h[j][i]*math.log(h_trace[j][i])-(1.-bin_h[j][i])*math.log(1.-h_trace[j][i])
            else:
                entropy[i]=entropy[i]-bin_h[j][i]*math.log(h_trace[j][i])-(1.-bin_h[j][i])*math.log(1.-h_trace[j][i])
        entropy.append([])
    entropy.pop()
    #calc expected energy
    #create matrices
    smat=scipy.mat(binSP[0:SNN.NS][:])
    amat=scipy.mat(binSP[SNN.NS+SNN.NH:SNN.NS+SNN.NH+SNN.NA][:])
    hmat=scipy.mat(binSP[SNN.NS:SNN.NS+SNN.NH][:])
    mmat=scipy.mat(binSP[SNN.NS+SNN.NH+SNN.NA:SNN.NS+SNN.NH+SNN.NA+SNN.NM][:])
    whsmat=scipy.mat(whs)
    whamat=scipy.mat(wha)
    whmmat=scipy.mat(whm)
    #ExpEnergy
    # the diagonal keeps the entries where both factors come from the same bin
    expEnergy_s=-smat.T*whsmat*hmat
    expEnergy_a=-amat.T*whamat*hmat
    expEnergy_m=-mmat.T*whmmat*hmat
    temp=expEnergy_a+expEnergy_s+expEnergy_m
    temp2=temp.diagonal()
    expEnergy=temp2.tolist()[0]## convert array back to Python list
    #f
    # instantaneous free energy per bin
    f=[]
    for i in range(bins):
        f.append(-entropy[i]+expEnergy[i])
    #FreeEnergy
    # low-pass filter the per-bin free energy, starting from f[0]
    for i in range(bins):
        if i==0:
            freeEnergy_t=[f[0]]
        else:
            freeEnergy_t.append(freeEnergy_t[i-1]+alpha_f*(f[i]-freeEnergy_t[i-1]))
    #return sarray, aarray, h_hat, entropy, expEnergy_mean, freeEnergy
    return s_hat, a_hat, h_hat,m_hat, entropy[-1], expEnergy[-1], freeEnergy_t[-1],freeEnergy_t
####################################
def UpdateW(freeEnergy_1,freeEnergy, reward_1, reward, goal, sarray_1, aarray_1, marray_1, h_hat, h_hat_1):
    """Return the weight increments (deltaWhs, deltaWha, deltaWhm).

    Each increment is ``error * pre[i] * hidden[j]`` with ``pre`` being
    ``sarray_1``, ``aarray_1`` or ``marray_1``.

    goal == 0: error = reward_1 - Gamma*freeEnergy + freeEnergy_1 (module-level
               ``Gamma``), hidden = h_hat_1.
    otherwise: error = reward + freeEnergy, hidden = h_hat.

    Note that the hidden activities carry the small offset added in CalcFE.
    """
    if goal == 0:
        error = reward_1-Gamma*freeEnergy+freeEnergy_1
        hidden = h_hat_1
    else:
        error = reward+freeEnergy
        hidden = h_hat

    def increments(pre):
        # multiplication order (error * pre) * hidden matches the original expression
        return [[error*p*h for h in hidden] for p in pre]

    deltaWhs = increments(sarray_1)
    deltaWha = increments(aarray_1)
    deltaWhm = increments(marray_1)
    return deltaWhs, deltaWha, deltaWhm
####################################
def WinputSNN(deltaWhs, deltaWha, deltaWhm):
    """Add ``Alpha`` times each increment to SNN.Whs/Wha/Whm and push the result to NEST.

    The updated matrices are stored back as plain nested lists (via
    ``tolist``), so each row is a fresh list when ConnectW runs afterwards.
    """
    for name, delta in (("Whs", deltaWhs), ("Wha", deltaWha), ("Whm", deltaWhm)):
        current = scipy.mat(getattr(SNN, name))  # to matrix
        updated = current+Alpha*scipy.mat(delta)
        setattr(SNN, name, updated.tolist())  # back to list
    SNN.ConnectW()
####################################
def Outputsd():
    """Dump the recorded spike senders and spike times to comma-separated text files."""
    events = nest.GetStatus(SNN.sd, 'events')[0]
    for key, target in (('senders', "Spikesender.txt"), ('times', "Spiketimes.txt")):
        events[key].tofile(target, sep=', ', format="%e")
####################################
def BinSPcnt():
    """Bin the spikes of the last ~100 ms into 50 bins of 2 ms per neuron.

    Reads the spike detector of the module-level ``SNN``.  The window end is
    the first multiple of ``2*T`` after the last recorded spike time.  Bin i
    holds the spikes later than ``maxTime-101+2*i`` and not later than the
    next bin's threshold; the last bin runs to the end of the recording.

    Returns a list with one row per neuron (state, hidden, action, memory, in
    global-id order), each row holding 50 values: 1 if the neuron spiked in
    that bin, else 0.

    Changes from the earlier version:
    * an empty recording returns all zeros instead of raising IndexError;
    * when no spike lies beyond a bin's threshold, that bin and all later ones
      are empty (before, the remaining spikes were all put into the last bin);
    * every row is its own list (the all-zero fallback used to repeat one row
      object for every neuron).
    """
    events = nest.GetStatus(SNN.sd, 'events')[0]
    spikesender = events['senders']  # array
    spiketimes = events['times']  # array
    bins = 50
    binsize = 2
    nneurons = SNN.NS+SNN.NH+SNN.NA+SNN.NM
    nspikes = len(spiketimes)
    if nspikes == 0:
        return [[0]*bins for _ in range(nneurons)]

    maxTime = (spiketimes[-1]//(2*T) + 1)*2*T
    # binidx[i]: index of the first spike later than the start of bin i,
    # or nspikes when there is none (nonzero is like "find" in matlab)
    binidx = []
    for i in range(bins):
        later = numpy.nonzero(spiketimes > maxTime-101+i*binsize)[0]
        binidx.append(later[0] if len(later) != 0 else nspikes)
    binidx.append(nspikes)  # the last bin extends to the end of the recording

    # sets make the per-neuron membership test O(1)
    active = [set(int(gid) for gid in spikesender[binidx[j]:binidx[j+1]])
              for j in range(bins)]
    # global ids start at 1, hence row i belongs to id i+1
    return [[1 if gid in active[j] else 0 for j in range(bins)]
            for gid in range(1, nneurons+1)]
####################################
def ReadW():
    """Load Wha.txt, Whs.txt and Whm.txt into ``SNN`` and push them to NEST.

    Each file holds one matrix row per line as comma-separated floats.  The
    files are read inside ``with`` blocks so the handles are closed again.
    """
    def read_matrix(filename):
        # one list of floats per CSV row
        with open(filename) as csvfile:
            return [[float(elem) for elem in row] for row in csv.reader(csvfile)]

    wha = read_matrix("Wha.txt")
    whs = read_matrix("Whs.txt")
    whm = read_matrix("Whm.txt")
    SNN.Wha = wha
    SNN.Whs = whs
    SNN.Whm = whm
    SNN.ConnectW()
####################################
####################################
#main
# Build the network once; the functions above use this module-level instance.
SNN=SNNc()
Flag_FE=0#1 is LPF (CalcFE_LPF); any other value uses CalcFE_AVE
Flagsa=0#1 for using binary state and action in FEAVE
T=500 # simulation time passed to each nest.Simulate call in StateClamp
Nepoch=1 # number of epochs (outer loop)
Nepisode=3000 # episodes per epoch
Maxstep=30 # maximum number of steps per episode
Gamma=0.99 # discount factor used in UpdateW and for the cumulative reward
Alpha=0.0001 # learning rate applied to the weight increments in WinputSNN
Beta=0.1/6. # slope of the inverse temperature over episodes in Actionselection
# Histories collected over the run; one entry per epoch unless noted otherwise.
HistoryReward=[]
HistoryNstep=[]
HistoryCumR=[]
Historys=[]
Historya=[]
HistoryFE=[]
HistoryFEts=[] # free-energy traces at the first step (filled only if Flag_FE==1)
HistoryFEtg=[] # free-energy traces at the end of an episode (filled only if Flag_FE==1)
HistoryFR=[]
# Training loop: epochs -> episodes -> steps.
for Epoch in range(Nepoch):
    nest.SetKernelStatus({'time':0.0})
    # per-epoch histories
    Reward_info=[]
    Nstep=[]
    HistoryCumRtemp=[]
    Historystemp2=[]
    Historyatemp2=[]
    HistoryFEtemp2=[]
    HistoryFRtemp2=[]
    SNN.InitW()
    SNN.ConnectW_SM()
    SNN.ConnectW()
    # ReadW replaces Whs/Wha/Whm drawn by InitW with the contents of
    # Whs.txt/Wha.txt/Whm.txt, so those files must exist before the run.
    ReadW()
    for Episode in range(Nepisode):
        GoalFlag=0
        Goal=0
        Action=0
        InitState=Episode%2 # start state alternates between 0 and 1
        GoalState=random.randint(0,1)
        Length=random.randint(0,4)
        State=InitState
        State_1=InitState
        Action_1=1
        FreeEnergy=0.
        FreeEnergy_1=0.
        Reward=0.
        Reward_1=0.
        CumR=0
        # per-episode histories
        Historystemp=[]
        Historyatemp=[]
        HistoryFEtemp=[]
        HistoryFRtemp=[]
        Obs=Digit(State)
        # first clamp of the episode, with the inhibitory current on the memory drive
        StateClamp(Obs, 0,1) #run SNN
        for Step in range(Maxstep):
            Obs_1=Obs[:]
            Obs=Digit(State)
            # clamp the observation only, then pick the action from the firing rates
            StateClamp(Obs, 0) #run SNN
            [HiddenFR, ActionFR]=CalcFR()
            Action= Actionselection(ActionFR,Episode,State)
            print("State", State, "Action", Action)
            # clamp observation and chosen action; these spikes feed the free energy
            StateClamp(Obs, Action)# for FE
            #[HiddenFR, ActionFR]=CalcFR()
            BinSP=BinSPcnt()
            if Step==0:
                # no previous step yet: use the current values as "previous"
                State_1=State
                Action_1=Action
                #HiddenFR_1=HiddenFR
                BinSP_1=BinSP[:][:]
                Whs_1=SNN.Whs[:][:]
                Wha_1=SNN.Wha[:][:]
                Whm_1=SNN.Whm[:][:]
            # NOTE(review): x[:][:] is a shallow copy of the outer list only; the
            # rows are shared with the original.
            # Free energy of the current and of the previous step, both with the
            # previous weights.
            if Flag_FE==1:
                [Sarray, Aarray, H_hat, Marray, Entropy, ExpEnergy, FreeEnergy, FreeEnergy_t]=CalcFE_LPF(BinSP, State, Action, Whs_1, Wha_1, Whm_1)#calc FE
                [Sarray_1, Aarray_1, H_hat_1, Marray_1, Entropy_1,ExpEnergy_1, FreeEnergy_1, FreeEnergy_t_1]=CalcFE_LPF(BinSP_1, State_1, Action_1, Whs_1, Wha_1, Whm_1)#calc FE
            else:
                [Sarray, Aarray, H_hat, Marray, Entropy, ExpEnergy, FreeEnergy]=CalcFE_AVE(BinSP, State, Action, Whs_1, Wha_1, Whm_1)#calc FE
                [Sarray_1, Aarray_1, H_hat_1, Marray_1, Entropy_1,ExpEnergy_1, FreeEnergy_1]=CalcFE_AVE(BinSP_1, State_1, Action_1, Whs_1, Wha_1, Whm_1)#calc FE
            # remember the current step as "previous" for the next iteration;
            # Reward still holds the reward of the previous transition here
            State_1=State###
            Action_1=Action###
            Reward_1=Reward ###
            #HiddenFR_1=HiddenFR###
            BinSP_1=BinSP[:][:]
            Whs_1=SNN.Whs[:][:] ###
            Wha_1=SNN.Wha[:][:] ###
            Whm_1=SNN.Whm[:][:] ###
            [State, Reward, Goal, GoalFlag, Length]=StateTrans(State, Action, GoalFlag, Length)# move
            [DeltaWhs, DeltaWha, DeltaWhm]=UpdateW(FreeEnergy_1,FreeEnergy, Reward_1, Reward, 0, Sarray_1, Aarray_1, Marray_1, H_hat, H_hat_1)
            Print1="Epoch: %d, Episode: %d, Step: %d, Flag: %d"
            Print2="New state: %d, State: %d, Action: %d, Reward: %d, Goal: %d, CumR=%.3f"
            Print3="FreeEnergy(s=%d, a=%d)=%.3f, Entropy=%.3f, ExpEnergy=%.3f"
            Print4="ActionFR1=%d, ActionFR2=%d, ActionFR3=%d, ActionFR4=%d"
            print Print1 % (Epoch+1, Episode+1, Step+1, GoalFlag)
            print Print2 % (State, State_1, Action_1, Reward, Goal, CumR)
            print Print3 % (State_1, Action_1, FreeEnergy, Entropy, ExpEnergy)
            print Print4 % (ActionFR[0], ActionFR[1], ActionFR[2], ActionFR[3])
            print("Reward",Reward_info)
            if Step==0:
                Historystemp.append(State_1)
            Historystemp.append(State)
            Historyatemp.append(Action)
            HistoryFEtemp.append(FreeEnergy)
            HistoryFRtemp.append(ActionFR)
            if Step==0 and Epoch ==0 and Flag_FE==1:
                HistoryFEts.append(FreeEnergy_t)
            if Step != 0:
                WinputSNN(DeltaWhs, DeltaWha, DeltaWhm)#not run 1st move
                CumR=CumR+Reward_1*Gamma**(Step-1)
            if Goal == 1:
                # terminal update: uses the goal branch of UpdateW (goal=1)
                CumR=CumR+Reward*Gamma**(Step+1)
                Whs=SNN.Whs[:][:] ####not need if ResetNetwork is off
                Wha=SNN.Wha[:][:] ####not need if ResetNetwork is off
                Whm=SNN.Whm[:][:] ####not need if ResetNetwork is off
                # the network is reset after every finished episode except the very last one
                if Epoch != Nepoch-1 or Episode != Nepisode-1:
                    nest.ResetNetwork()#############
                SNN.Whs=Whs_1[:][:] ####not need if ResetNetwork is off
                SNN.Wha=Wha_1[:][:] ####not need if ResetNetwork is off
                SNN.Whm=Whm_1[:][:] ####not need if ResetNetwork is off
                [DeltaWhs, DeltaWha, DeltaWhm]=UpdateW(FreeEnergy_1,FreeEnergy, Reward_1, Reward, 1, Sarray, Aarray, Marray, H_hat, H_hat_1)
                SNN.Whs=Whs[:][:] ####not need if ResetNetwork is off
                SNN.Wha=Wha[:][:] ####not need if ResetNetwork is off
                SNN.Whm=Whm[:][:] ####not need if ResetNetwork is off
                WinputSNN(DeltaWhs, DeltaWha, DeltaWhm)#not run 1st move
                break
        # episode finished (goal reached or Maxstep used up): store its records
        Historystemp2.append(Historystemp)
        Historyatemp2.append(Historyatemp)
        HistoryFEtemp2.append(HistoryFEtemp)
        HistoryFRtemp2.append(HistoryFRtemp)
        Reward_info.append(GoalFlag)
        Nstep.append(Step+1)
        HistoryCumRtemp.append(CumR)
        if Epoch ==0 and Flag_FE==1:
            if Goal==0:
                HistoryFEtg.append([0 for _ in range(50)])
            else:
                HistoryFEtg.append(FreeEnergy_t)
    # epoch finished: store its records
    Historys.append(Historystemp2)
    Historya.append(Historyatemp2)
    HistoryFE.append(HistoryFEtemp2)
    HistoryFR.append(HistoryFRtemp2)
    HistoryNstep.append(Nstep)
    HistoryReward.append(Reward_info)
    HistoryCumR.append(HistoryCumRtemp)
#to file
# All files are written inside "with" blocks so they are flushed and closed
# even if a write fails part-way through.  The file contents are unchanged.

# History.txt: states, actions, free energies and action firing rates per episode
with open('History.txt', 'w') as fHistory:
    for i in range(Nepoch):
        for j in range(Nepisode):
            fHistory.write("Epoch:%d, Episode: %d\n" % (i+1, j+1))
            for record in (Historys, Historya, HistoryFE, HistoryFR):
                fHistory.write(str(record[i][j]))
                fHistory.write('\n')
            fHistory.write('\n')

# One line per epoch, one comma-separated value per episode
for filename, history in (('HistoryReward.txt', HistoryReward),
                          ('HistoryNstep.txt', HistoryNstep),
                          ('HistoryCumR.txt', HistoryCumR)):
    with open(filename, 'w') as fOut:
        for i in range(Nepoch):
            fOut.write(', '.join(str(history[i][j]) for j in range(Nepisode)))
            fOut.write('\n')

# Learned weights, one matrix row per line; these are the files ReadW loads
for filename, weights, nrows in (('Whs.txt', SNN.Whs, SNN.NS),
                                 ('Wha.txt', SNN.Wha, SNN.NA),
                                 ('Whm.txt', SNN.Whm, SNN.NM)):
    with open(filename, 'w') as fW:
        for i in range(nrows):
            fW.write(', '.join(str(weights[i][j]) for j in range(SNN.NH)))
            fW.write('\n')

# Free-energy traces (50 bins per episode), only recorded in LPF mode
if Flag_FE==1:
    for filename, traces in (('FEts.txt', HistoryFEts), ('FEtg.txt', HistoryFEtg)):
        with open(filename, 'w') as fFE:
            for i in range(Nepisode):
                fFE.write(', '.join(str(traces[i][j]) for j in range(50)))
                fFE.write('\n')
#nest.SetKernelStatus({'time':0.0})
#nest.SetStatus(SNN.voltmeter,[{"to_file": True, "withtime": True}])
#StateClamp(0)
#StateClamp(0,-1)
#StateClamp(0,0,0,-1)
#StateClamp(0,1,0,-1)
#StateClamp(1,0,0,1)
#StateClamp(1,1,0,1)
# Raster plot of the events held by the spike detector, saved as raster.eps.
nest.raster_plot.from_device(SNN.sd, hist=False)
#plt.xlim( (0, 1000) )
##plt.show()
plt.savefig('raster.eps')
#plt.close()
##nest.voltage_trace.from_device(SNN.voltmeter)