# -*- coding: utf-8 -*-
####################################
#Nakano T, Otsuka M, Yoshimoto J, Doya K (2015) A Spiking Neural Network Model of Model-Free Reinforcement Learning with High-Dimensional Sensory Input and Perceptual Ambiguity. PLoS ONE 10(3): e0115620. doi:10.1371/journal.pone.0115620
####################################
# Overview
#   * SNNc builds the NEST network: state (S), hidden (H), action (A) and
#     memory (M) neurons plus their drives, noise sources and recorders.
#   * The module-level functions clamp inputs, read spikes back from the
#     spike detector, compute a free-energy estimate and update the weights.
#   * The script at the bottom runs the episodes and writes the histories,
#     the final weights and a raster plot to disk.
# Integer slice bounds use explicit floor division (//) and formatted
# output uses single-argument print(...), which evaluate identically under
# Python 2.
import nest
nest.pynestkernel.logstdout("")
import nest.raster_plot
import nest.voltage_trace
import pylab
import numpy
import scipy
import time, sys
import matplotlib.pyplot as plt
import random
import operator
import math
import csv

nest.ResetKernel()
nest.SetKernelStatus({"overwrite_files": True})


####################################
def _read_float_matrix(filename):
    """Read a CSV file and return its rows as lists of floats."""
    with open(filename) as csvfile:
        return [[float(elem) for elem in row] for row in csv.reader(csvfile)]


def _write_rows(filename, rows):
    """Write every row as one line of ', '-separated str() values."""
    with open(filename, 'w') as out:
        for row in rows:
            out.write(', '.join([str(value) for value in row]))
            out.write('\n')


####################################
class SNNc:
    """Spiking network with state, hidden, action and memory populations.

    The populations are created in the order S, H, A, M directly after
    nest.ResetKernel().  CalcFR and BinSPcnt index spikes by sender id and
    assume that this yields consecutive ids starting at 1 (S first).
    """

    def __init__(self):
        """Create all nodes, set their parameters and wire the network."""
        #define parameters
        self.NS=20*15 # number of state neurons
        self.NH=90 # number of hidden neurons
        self.NA=90 # number of action neurons
        self.NM=50 # number of memory neurons
        self.InitWhsmean=20 # mean of initial weights Whs
        self.InitWhsstd=11.88  # std of initial weights Whs
        self.InitWhamean=20 # mean of initial weights Wha
        self.InitWhastd=11.88  # std of initial weights Wha
        self.InitWhmmean=20 # mean of initial weights Whm
        self.InitWhmstd=11.88  # std of initial weights Whm
        self.InitWmmmean=0 # mean of weights Wmm (not read anywhere in this file)
        self.InitWmmstd=0.  # std of weights Wmm (not read anywhere in this file)

        self.Wseed=0#seed for weight
        self.Nseed=[123]#seed for noise
        self.driveparams_on  = {'amplitude':2000.}#current into a clamped-on state neuron
        self.driveparams_on_a  = {'amplitude':2000.}#current into the selected action group
        self.driveparams_off_a  = {'amplitude':-5000.}#current into the non-selected action groups
        self.driveparams_off  = {'amplitude':0.}#no current input
        self.driveparams_inh  = {'amplitude':-5000.}#negative current applied to the memory drive (see StateClamp)
        self.noiseparams  = {'mean':0.0, 'std':300.}#noise input to the S, A, H and M populations
        self.sdparams  = { "withtime": True, "withgid" : True,'to_file':False, 'to_screen':False,'flush_after_simulate':True,'flush_records':True}
        #neuronparams = { 'tau_m':20., 'V_th':-50., 'E_L':-60., 't_ref':2., 'V_reset':-60., 'C_m':200.}

        #create neurons
        self.Sneurons = nest.Create('iaf_neuron',self.NS)
        self.Hneurons = nest.Create('iaf_neuron',self.NH)
        self.Aneurons = nest.Create('iaf_neuron',self.NA)
        self.Mneurons = nest.Create('iaf_neuron',self.NM)

        self.sd= nest.Create('spike_detector')
        # NS drives for the state neurons, 4 for the action groups, 1 for
        # the memory population.
        self.drive= nest.Create('dc_generator',self.NS+4+1)
        nest.SetKernelStatus({'rng_seeds':self.Nseed})#noise
        self.noise= nest.Create('noise_generator',6)  # only the first 4 are connected below
        self.voltmeter = nest.Create("voltmeter")

        #set parameters
        nest.SetStatus(self.sd,[self.sdparams] )
        nest.SetStatus(self.noise,[self.noiseparams] ) # if noise selection works, comment out this.
        #nest.SetStatus(self.Sneurons, [self.neuronparams])
        #nest.SetStatus(self.Hneurons, [self.neuronparams])
        #nest.SetStatus(self.Aneurons, [self.neuronparams])

        #connect
        # NOTE: ConnectW relies on the order in which each neuron's outgoing
        # connections are created here (the spike detector comes last), so
        # do not reorder these calls.
        nest.DivergentConnect(self.noise[0:1], self.Sneurons)
        nest.DivergentConnect(self.noise[1:2], self.Aneurons)
        nest.DivergentConnect(self.noise[2:3], self.Hneurons)
        nest.DivergentConnect(self.noise[3:4], self.Mneurons)

        for i in range(self.NS):
            nest.Connect(self.drive[i:i+1],self.Sneurons[i:i+1] )

        # One drive per quarter of the action population.
        nest.DivergentConnect(self.drive[self.NS:self.NS+1], self.Aneurons[0:self.NA//4])
        nest.DivergentConnect(self.drive[self.NS+1:self.NS+2], self.Aneurons[self.NA//4:self.NA//2])
        nest.DivergentConnect(self.drive[self.NS+2:self.NS+3], self.Aneurons[self.NA//2:self.NA*3//4])
        nest.DivergentConnect(self.drive[self.NS+3:self.NS+4], self.Aneurons[self.NA*3//4:self.NA])

        nest.DivergentConnect(self.drive[self.NS+4:self.NS+5], self.Mneurons)

        # The weights given here are placeholders; ConnectW / ConnectW_SM
        # overwrite them.
        nest.ConvergentConnect(self.Sneurons, self.Hneurons, weight=100.0, delay=1.0)
        nest.ConvergentConnect(self.Hneurons, self.Aneurons, weight=100.0, delay=1.0)#DivergentConnect?
        nest.ConvergentConnect(self.Aneurons, self.Hneurons, weight=100.0, delay=1.0)
        nest.ConvergentConnect(self.Mneurons, self.Hneurons, weight=100.0, delay=1.0)
        #nest.sli_run('/RandomDivergentConnect << /allow_multapses false >> SetOptions')
        #nest.sli_run('/RandomDivergentConnect << /allow_autapses false >> SetOptions')
        #nest.RandomDivergentConnect(self.Mneurons, self.Mneurons, self.NM/5, weight=100.0, delay=1.0)
        nest.ConvergentConnect(self.Mneurons, self.Mneurons, weight=0.0, delay=1.0)
        nest.ConvergentConnect(self.Sneurons, self.Mneurons, weight=0.0, delay=1.0)

        nest.ConvergentConnect(self.Sneurons, self.sd)
        nest.ConvergentConnect(self.Hneurons, self.sd)
        nest.ConvergentConnect(self.Aneurons, self.sd)
        nest.ConvergentConnect(self.Mneurons, self.sd)

        nest.Connect(self.voltmeter, self.Hneurons[20:21])

        #weight
        random.seed(self.Wseed)

    def _random_matrix(self, nrows, mean, std):
        """Return nrows x NH normal samples, drawn row by row."""
        return [[random.normalvariate(mean, std) for _ in range(self.NH)]
                for _ in range(nrows)]

    def InitW(self):
        """Draw random Whs, Wha, Whm and load Wmm and Wms from CSV files.

        The draw order (Whs, Wha, Whm; rows outer, hidden index inner) is
        kept fixed so a given seed reproduces the same weights.
        """
        self.Whs=self._random_matrix(self.NS, self.InitWhsmean, self.InitWhsstd)
        self.Wha=self._random_matrix(self.NA, self.InitWhamean, self.InitWhastd)
        self.Whm=self._random_matrix(self.NM, self.InitWhmmean, self.InitWhmstd)

        ##Wmm import
        self.Wmm=_read_float_matrix("./MMweight2301.txt")

        #Wms import
        self.Wms=_read_float_matrix("Wcd50_noBias.txt")

    def ConnectW_SM(self):
        """Write self.Wmm[i][j] onto the connection M[i] -> M[j]."""
        ##Wmm
        for i in range(self.NM):
            for j in range(self.NM):
                conn=nest.FindConnections([self.Mneurons[i]],[self.Mneurons[j]])
                nest.SetStatus(conn,'weight',self.Wmm[i][j])

    def ConnectW(self):
        """Write the current weight matrices into the NEST connections.

        For S, A and H neurons all outgoing connections of a neuron are
        set in one call; the trailing 1.0 is the weight of the connection
        to the spike detector, which is created last in __init__.
        """
        ##Whs (S -> H) followed by Wms (S -> M)
        # assumes self.Wms has one row per state neuron -- TODO confirm
        for i in range(self.NS):
            conn=nest.FindConnections([self.Sneurons[i]])
            nest.SetStatus(conn,'weight',self.Whs[i]+self.Wms[i]+[1.0])

        ##Wha (A -> H)
        for i in range(self.NA):
            conn=nest.FindConnections([self.Aneurons[i]])
            nest.SetStatus(conn,'weight',self.Wha[i]+[1.0])

        ##Wah (H -> A): transpose of Wha, each row extended by the 1.0
        self.Wah=[list(column)+[1.0] for column in zip(*self.Wha)]
        for i in range(self.NH):
            conn=nest.FindConnections([self.Hneurons[i]])
            nest.SetStatus(conn,'weight',self.Wah[i])

        ##Whm (M -> H), set one connection at a time
        for i in range(self.NM):
            for j in range(self.NH):
                conn=nest.FindConnections([self.Mneurons[i]],[self.Hneurons[j]])
                nest.SetStatus(conn,'weight',self.Whm[i][j])


####################################
def Digit(state=1):
    """Load one observation for `state` as a flat list of ints.

    State 3 shows the digit of the module-level GoalState.  One of ten
    sample files per digit is picked with random.randint(1, 10).
    """
    if state==3:
        state= GoalState

    #filename="digit"+str(state)
    filename="../shrunk_digit_easy_test_20_15T/digit"+str(state)+"_"+str(random.randint(1,10))
    with open(filename) as csvfile:
        return [int(elem) for row in csv.reader(csvfile) for elem in row]


####################################
def StateClamp(obs, action, inh=0):
    """Clamp the drives for one observation/action and run the simulation.

    obs    -- sequence; every entry equal to 1 switches the drive of the
              state neuron with the same index on.
    action -- 1..4 switches that action group's drive on and the other
              three to driveparams_off_a; any other value leaves all four
              action drives at zero amplitude.
    inh    -- 1 applies driveparams_inh to the memory drive and simulates
              an additional T.
    """
    nest.SetStatus(SNN.drive,[SNN.driveparams_off] )# all drives are 0

    for i in range(len(obs)):
        if obs[i]==1:
            nest.SetStatus(SNN.drive[i:i+1],[SNN.driveparams_on] )

    if action in (1, 2, 3, 4):
        for group in range(4):
            if group==action-1:
                params=SNN.driveparams_on_a
            else:
                params=SNN.driveparams_off_a
            nest.SetStatus(SNN.drive[SNN.NS+group:SNN.NS+group+1],[params] )

    # NOTE(review): the original line breaks were lost, so the nesting of
    # the two Simulate calls is reconstructed.  This variant (2*T with
    # inhibition, T otherwise) keeps the clock at odd multiples of T when
    # CalcFR runs and at even multiples when BinSPcnt runs, matching their
    # //T and //(2*T) window computations -- confirm against the original.
    if inh==1:
        nest.SetStatus(SNN.drive[SNN.NS+4:SNN.NS+5],[SNN.driveparams_inh] )
        nest.Simulate(T)

    nest.Simulate(T)


####################################
def CalcFR():
    """Count spikes per sender id over the last 100 ms.

    Returns (hiddenFR, actionFR): the per-neuron counts of the hidden
    population and the summed counts of the four action groups.
    """
    events=nest.GetStatus(SNN.sd,'events')[0]
    spikesender=events['senders']#array
    spiketimes=events['times']#array

    #note
    #len(fr)=SNN.NS+SNN.NH+SNN.NA+1
    #fr[0] has no meaning (sender ids start at 1)
    nids=SNN.NS+SNN.NH+SNN.NA+1
    fr=numpy.zeros(nids)
    if len(spiketimes)!=0:
        maxTime=(spiketimes[-1]//T +1)*T  # end of the T-window holding the last spike
        if max(spiketimes)>maxTime-100:
            IdxLastSP=numpy.nonzero(spiketimes>maxTime-100)[0][0]# last 100 ms
            recent=list(spikesender[IdxLastSP:])
            fr=[recent.count(i) for i in range(nids)]

    base=SNN.NS+SNN.NH
    hiddenFR=fr[SNN.NS+1:base+1]
    edges=[base+1,
           base+SNN.NA//4+1,
           base+SNN.NA*2//4+1,
           base+SNN.NA*3//4+1,
           base+SNN.NA+1]
    actionFR=[sum(fr[lo:hi]) for lo, hi in zip(edges[:-1], edges[1:])]
    return hiddenFR, actionFR


####################################
def Actionselection(actionFR,episode,state):
    """Return the next action.

    Outside state 3 the action is always 1.  In state 3 the choice
    between actions 2 and 4 is a softmax over actionFR[1] and
    actionFR[3]; the inverse temperature grows linearly with `episode`.
    """
    action=1

    if state==3:
        beta=Beta*float(episode)/float(Nepisode)+0.5/6.
        weight2=math.exp(beta*actionFR[1])
        weight4=math.exp(beta*actionFR[3])
        if random.random()<weight2/(weight2+weight4):
            action=2
        else:
            action=4

    return action


####################################
def StateTrans(state,action,goalflag, length):
    """Apply `action` in `state` and return the transition outcome.

    Returns (nextState, reward, goal, goalflag, length).  Reads the
    module-level InitState and GoalState.  The reward is -500 unless the
    choice made in state 3 is the rewarded one (20000): action 2 when
    InitState equals GoalState, action 4 otherwise.
    """
    reward=-500.
    goal=0

    if state==0:
        if action==1:
            nextState=2
        else:
            nextState=0
    if state==1:
        if action==1:
            nextState=2
        else:
            nextState=1

    if state==2:
        if action==1:
            if length==0:
                nextState=3
            else:
                nextState=2
            # NOTE(review): this unconditional assignment overrides the
            # length-dependent branch above, so `length` currently has no
            # effect on the transition.  Kept as found.
            nextState=3
        elif action==3:
            if InitState==0:
                nextState=0
            if InitState==1:
                nextState=1
        else:
            nextState=2

    if state==3:
        if action==1:
            nextState=3
        if action==2:
            goal=1
            nextState=4
            if InitState==GoalState:
                reward=20000.
                goalflag=1
        if action==4:
            goal=1
            nextState=4
            if InitState!=GoalState:
                reward=20000.
                goalflag=1
        if action==3:
            nextState=2

    # NOTE(review): nesting level reconstructed (line breaks were lost);
    # it does not influence the transition because of the override above.
    length=length-1

    return nextState, reward, goal, goalflag, length


####################################
def _rate_estimates(binSP, start, stop, maxFR):
    """Return sum(binSP[i])/maxFR+0.001 for every row index in [start, stop)."""
    #0.001 keeps the estimate away from 0 (avoids log(0) in the entropy)
    return [sum(binSP[i])/maxFR+0.001 for i in range(start, stop)]


def _expected_energy_per_bin(binSP, whs, wha, whm):
    """Return, per time bin, -(s'Whs h + a'Wha h + m'Whm h) as a list."""
    s_end=SNN.NS
    h_end=s_end+SNN.NH
    a_end=h_end+SNN.NA
    m_end=a_end+SNN.NM
    #create matrices (rows: neurons, columns: time bins)
    smat=numpy.asmatrix(binSP[0:s_end])
    hmat=numpy.asmatrix(binSP[s_end:h_end])
    amat=numpy.asmatrix(binSP[h_end:a_end])
    mmat=numpy.asmatrix(binSP[a_end:m_end])
    whsmat=numpy.asmatrix(whs)
    whamat=numpy.asmatrix(wha)
    whmmat=numpy.asmatrix(whm)
    #ExpEnergy
    expEnergy_s=-smat.T*whsmat*hmat
    expEnergy_a=-amat.T*whamat*hmat
    expEnergy_m=-mmat.T*whmmat*hmat
    temp=expEnergy_a+expEnergy_s+expEnergy_m
    # Only same-bin products are wanted, i.e. the diagonal.
    return temp.diagonal().tolist()[0]## convert matrix back to Python list


def CalcFE_AVE(binSP, state, action, whs, wha, whm):
    """Free energy from activity averaged over all bins.

    binSP is the 0/1 raster from BinSPcnt (rows ordered S, H, A, M).
    With Flagsa == 0 returns
    (s_hat, a_hat, h_hat, m_hat, entropy, expEnergy_mean, freeEnergy);
    with Flagsa == 1 returns
    (sarray, aarray, h_hat, entropy, expEnergy_mean, freeEnergy).
    """
    maxFR=50.#47 spikes per 1 neuron for 100 ms. 50 is for safety
    #this can be considered as mean over bins if maxFR is 50
    bins=50
    s_end=SNN.NS
    h_end=s_end+SNN.NH
    a_end=h_end+SNN.NA
    m_end=a_end+SNN.NM
    s_hat=_rate_estimates(binSP, 0, s_end, maxFR)
    h_hat=_rate_estimates(binSP, s_end, h_end, maxFR)
    a_hat=_rate_estimates(binSP, h_end, a_end, maxFR)
    m_hat=_rate_estimates(binSP, a_end, m_end, maxFR)

    #calc entropy
    entropy=0
    for x in h_hat:
        entropy=entropy-x*math.log(x)-(1-x)*math.log(1-x)

    #calc expected energy
    if Flagsa==1:
        # NOTE(review): legacy branch.  The hard-coded block sizes (16 and
        # 50) do not match NS/NA as set in SNNc and the caller unpacks
        # seven values, so this path looks unusable as is; the script runs
        # with Flagsa == 0.
        ##create state array
        sarray=[0]*SNN.NS
        if state<3:
            sarray[state*16:state*16+16]=[1]*16
        else:
            sarray[(state-1)*16:(state-1)*16+16]=[1]*16
        sarraya=numpy.array(sarray)

        ##create action array
        aarray=[0]*SNN.NA
        aarray[(action+1)//2*50:(action+1)//2*50+50]=[1]*50
        aarraya=numpy.array(aarray)

        hmat=numpy.asmatrix(binSP[s_end:h_end])
        expEnergy_s=-sarraya*numpy.asmatrix(whs)*hmat
        expEnergy_a=-aarraya*numpy.asmatrix(wha)*hmat
        expEnergy=(expEnergy_a+expEnergy_s).tolist()[0]
    else:
        expEnergy=_expected_energy_per_bin(binSP, whs, wha, whm)
    expEnergy_mean=sum(expEnergy)/bins

    #FreeEnergy
    freeEnergy=-entropy+expEnergy_mean

    if Flagsa==1:
        return sarray, aarray, h_hat, entropy, expEnergy_mean, freeEnergy
    else:
        return s_hat, a_hat, h_hat, m_hat, entropy, expEnergy_mean, freeEnergy


####################################
def CalcFE_LPF(binSP, state, action, whs, wha, whm):
    """Free energy from low-pass filtered per-bin quantities.

    Returns (s_hat, a_hat, h_hat, m_hat, entropy, expEnergy, freeEnergy,
    freeEnergy_t): entropy, expEnergy and freeEnergy are the values of the
    last bin and freeEnergy_t is the filtered trace over all bins.
    `state` and `action` are accepted for signature parity with
    CalcFE_AVE and are not used.
    """
    alpha_h=0.1  # filter constant of the hidden-activity trace
    alpha_f=0.1  # filter constant of the free-energy trace
    maxFR=50.#47 spikes per 1 neuron for 100 ms. 50 is for safety
    bins=50
    s_end=SNN.NS
    h_end=s_end+SNN.NH
    a_end=h_end+SNN.NA
    m_end=a_end+SNN.NM
    s_hat=_rate_estimates(binSP, 0, s_end, maxFR)
    a_hat=_rate_estimates(binSP, h_end, a_end, maxFR)
    m_hat=_rate_estimates(binSP, a_end, m_end, maxFR)

    # Exponential trace of every hidden neuron, started at 0.5.
    bin_h=binSP[s_end:h_end]
    h_trace=[]
    h_hat=[]
    for spikes in bin_h:
        trace=[0.5]
        for j in range(1, bins):
            trace.append((1-alpha_h)*trace[j-1]+alpha_h*spikes[j])
        h_trace.append(trace)
        h_hat.append(trace[-1])

    #calc entropy (cross-entropy of the spikes under the trace, per bin)
    entropy=[]
    for i in range(bins):
        total=-bin_h[0][i]*math.log(h_trace[0][i])-(1.-bin_h[0][i])*math.log(1.-h_trace[0][i])
        for j in range(1, len(h_trace)):
            total=total-bin_h[j][i]*math.log(h_trace[j][i])-(1.-bin_h[j][i])*math.log(1.-h_trace[j][i])
        entropy.append(total)

    #calc expected energy
    expEnergy=_expected_energy_per_bin(binSP, whs, wha, whm)

    #f
    f=[-entropy[i]+expEnergy[i] for i in range(bins)]

    #FreeEnergy
    freeEnergy_t=[f[0]]
    for i in range(1, bins):
        freeEnergy_t.append(freeEnergy_t[i-1]+alpha_f*(f[i]-freeEnergy_t[i-1]))

    return s_hat, a_hat, h_hat,m_hat, entropy[-1], expEnergy[-1], freeEnergy_t[-1],freeEnergy_t


####################################
def _scaled_outer(scale, pre, post):
    """Return the matrix with entries scale*pre[i]*post[j]."""
    return [[scale*p*x for x in post] for p in pre]


def UpdateW(freeEnergy_1,freeEnergy, reward_1, reward, goal, sarray_1, aarray_1, marray_1, h_hat, h_hat_1):
    """Return the weight increments (deltaWhs, deltaWha, deltaWhm).

    goal == 0: scale is reward_1 - Gamma*freeEnergy + freeEnergy_1 and the
               hidden factor is h_hat_1.
    otherwise: scale is reward + freeEnergy and the hidden factor is h_hat.
    The learning rate is applied later, in WinputSNN.
    """
    #Notice! 0.001 was added to the h_hat values in CalcFE
    if goal==0:
        scale=reward_1-Gamma*freeEnergy+freeEnergy_1
        hidden=h_hat_1
    else:
        scale=reward+freeEnergy
        hidden=h_hat

    deltaWhs=_scaled_outer(scale, sarray_1, hidden)
    deltaWha=_scaled_outer(scale, aarray_1, hidden)
    deltaWhm=_scaled_outer(scale, marray_1, hidden)
    return deltaWhs, deltaWha, deltaWhm


####################################
def WinputSNN(deltaWhs, deltaWha, deltaWhm):
    """Add Alpha times the increments to SNN's weights and push them to NEST."""
    newWhsmat=numpy.asmatrix(SNN.Whs)+Alpha*numpy.asmatrix(deltaWhs)
    newWhamat=numpy.asmatrix(SNN.Wha)+Alpha*numpy.asmatrix(deltaWha)
    newWhmmat=numpy.asmatrix(SNN.Whm)+Alpha*numpy.asmatrix(deltaWhm)

    # tolist() creates fresh nested lists, so earlier copies of the weight
    # lists are not affected by this update.
    SNN.Whs=newWhsmat.tolist()
    SNN.Wha=newWhamat.tolist()
    SNN.Whm=newWhmmat.tolist()
    SNN.ConnectW()


####################################
def Outputsd():
    """Dump the recorded spike senders and times to text files."""
    events=nest.GetStatus(SNN.sd,'events')[0]
    spikesender=events['senders']#array
    spiketimes=events['times']#array
    spikesender.tofile("Spikesender.txt", sep=', ', format = "%e")
    spiketimes.tofile("Spiketimes.txt", sep=', ', format = "%e")


####################################
def BinSPcnt():
    """Return a 0/1 spike raster of the last ~100 ms in 50 bins.

    The result has one row per neuron (S, H, A, M order) and one column
    per bin; an entry is 1 if the neuron spiked at least once in the bin.
    Bin i starts just after maxTime-101+2*i, where maxTime is the end of
    the 2*T window that holds the last spike; the last bin runs to the end
    of the record.  An all-zero raster is returned when there are no
    spikes at all or none inside the window.
    """
    events=nest.GetStatus(SNN.sd,'events')[0]
    spikesender=events['senders']#array
    spiketimes=events['times']#array
    bins=50
    binsize=2
    nneurons=SNN.NS+SNN.NH+SNN.NA+SNN.NM

    # Index of the first spike of every bin.
    binidx=[]
    if len(spiketimes)!=0:  # an empty record used to raise IndexError here
        maxTime=(spiketimes[-1]//(2*T) +1)*2*T
        latest=max(spiketimes)
        for i in range(bins):
            edge=maxTime-101+i*binsize
            if latest>edge:
                #nonzero is like "find" in matlab
                #[0][0] to get the first index
                binidx.append(numpy.nonzero(spiketimes>edge)[0][0])
            elif binidx:
                # No spike after this edge: repeat the last start index so
                # the remaining bins but the final one come out empty.
                binidx.append(binidx[-1])

    if not binidx:
        # Independent rows (no aliasing of one shared list).
        return [[0]*bins for _ in range(nneurons)]

    bin_members=[]
    for i in range(bins-1):
        bin_members.append(set(int(gid) for gid in spikesender[binidx[i]:binidx[i+1]]))
    bin_members.append(set(int(gid) for gid in spikesender[binidx[bins-1]:]))

    # Row i belongs to the neuron with sender id i+1.
    binSP=[[1 if gid in members else 0 for members in bin_members]
           for gid in range(1, nneurons+1)]
    #binSP=map(list, zip(*binSP))#transposition
    return binSP


####################################
def ReadW():
    """Load Wha, Whs and Whm from their text files and push them to NEST."""
    wha=_read_float_matrix("Wha.txt")
    whs=_read_float_matrix("Whs.txt")
    whm=_read_float_matrix("Whm.txt")

    SNN.Wha=wha
    SNN.Whs=whs
    SNN.Whm=whm

    SNN.ConnectW()


####################################
####################################
#main
SNN=SNNc()
Flag_FE=0#1 is LPF
Flagsa=0#1 for using binary state and action in FEAVE
T=500
Nepoch=1
Nepisode=3000
Maxstep=30
Gamma=0.99
Alpha=0.0001
Beta=0.1/6.
HistoryReward=[]
HistoryNstep=[]
HistoryCumR=[]
Historys=[]
Historya=[]
HistoryFE=[]
HistoryFEts=[]
HistoryFEtg=[]
HistoryFR=[]

for Epoch in range(Nepoch):
    nest.SetKernelStatus({'time':0.0})
    Reward_info=[]
    Nstep=[]
    HistoryCumRtemp=[]
    Historystemp2=[]
    Historyatemp2=[]
    HistoryFEtemp2=[]
    HistoryFRtemp2=[]

    SNN.InitW()
    SNN.ConnectW_SM()
    SNN.ConnectW()

    # Replaces the freshly drawn Whs/Wha/Whm by the ones stored on disk.
    ReadW()

    for Episode in range(Nepisode):
        GoalFlag=0
        Goal=0
        Action=0
        InitState=Episode%2
        GoalState=random.randint(0,1)
        Length=random.randint(0,4)
        State=InitState
        State_1=InitState
        Action_1=1
        FreeEnergy=0.
        FreeEnergy_1=0.
        Reward=0.
        Reward_1=0.
        CumR=0
        Historystemp=[]
        Historyatemp=[]
        HistoryFEtemp=[]
        HistoryFRtemp=[]

        Obs=Digit(State)
        StateClamp(Obs, 0,1) #run SNN with the memory drive set to driveparams_inh

        for Step in range(Maxstep):
            Obs_1=Obs[:]
            Obs=Digit(State)
            StateClamp(Obs, 0) #run SNN
            [HiddenFR, ActionFR]=CalcFR()
            Action= Actionselection(ActionFR,Episode,State)
            print("State", State, "Action", Action)
            StateClamp(Obs, Action)# for FE
            #[HiddenFR, ActionFR]=CalcFR()
            BinSP=BinSPcnt()
            if Step==0:
                State_1=State
                Action_1=Action
                #HiddenFR_1=HiddenFR
                BinSP_1=BinSP[:][:]
                Whs_1=SNN.Whs[:][:]
                Wha_1=SNN.Wha[:][:]
                Whm_1=SNN.Whm[:][:]
            if Flag_FE==1:
                [Sarray, Aarray, H_hat, Marray, Entropy, ExpEnergy, FreeEnergy, FreeEnergy_t]=CalcFE_LPF(BinSP, State, Action, Whs_1, Wha_1, Whm_1)#calc FE
                [Sarray_1, Aarray_1, H_hat_1, Marray_1, Entropy_1,ExpEnergy_1, FreeEnergy_1, FreeEnergy_t_1]=CalcFE_LPF(BinSP_1, State_1, Action_1, Whs_1, Wha_1, Whm_1)#calc FE
            else:
                [Sarray, Aarray, H_hat, Marray, Entropy, ExpEnergy, FreeEnergy]=CalcFE_AVE(BinSP, State, Action, Whs_1, Wha_1, Whm_1)#calc FE
                [Sarray_1, Aarray_1, H_hat_1, Marray_1, Entropy_1,ExpEnergy_1, FreeEnergy_1]=CalcFE_AVE(BinSP_1, State_1, Action_1, Whs_1, Wha_1, Whm_1)#calc FE
            # Remember this step's quantities for the next update.
            State_1=State###
            Action_1=Action###
            Reward_1=Reward	###
            #HiddenFR_1=HiddenFR###
            BinSP_1=BinSP[:][:]
            Whs_1=SNN.Whs[:][:]	###
            Wha_1=SNN.Wha[:][:]	###
            Whm_1=SNN.Whm[:][:]	###
            [State, Reward, Goal, GoalFlag, Length]=StateTrans(State, Action, GoalFlag, Length)# move
            [DeltaWhs, DeltaWha, DeltaWhm]=UpdateW(FreeEnergy_1,FreeEnergy, Reward_1, Reward, 0, Sarray_1,  Aarray_1,  Marray_1, H_hat, H_hat_1)
            Print1="Epoch: %d, Episode: %d, Step: %d, Flag: %d"
            Print2="New state: %d, State: %d, Action: %d, Reward: %d, Goal: %d, CumR=%.3f"
            Print3="FreeEnergy(s=%d, a=%d)=%.3f, Entropy=%.3f, ExpEnergy=%.3f"
            Print4="ActionFR1=%d, ActionFR2=%d, ActionFR3=%d, ActionFR4=%d"
            print(Print1 % (Epoch+1, Episode+1, Step+1, GoalFlag))
            print(Print2 % (State, State_1, Action_1, Reward, Goal, CumR))
            print(Print3 % (State_1, Action_1, FreeEnergy, Entropy, ExpEnergy))
            print(Print4 % (ActionFR[0], ActionFR[1], ActionFR[2], ActionFR[3]))
            print("Reward",Reward_info)
            if Step==0:
                Historystemp.append(State_1)
            Historystemp.append(State)
            Historyatemp.append(Action)
            HistoryFEtemp.append(FreeEnergy)
            HistoryFRtemp.append(ActionFR)
            if Step==0 and Epoch ==0 and Flag_FE==1:
                HistoryFEts.append(FreeEnergy_t)
            if Step != 0:
                WinputSNN(DeltaWhs, DeltaWha, DeltaWhm)#not run 1st move
                CumR=CumR+Reward_1*Gamma**(Step-1)
            if Goal == 1:
                CumR=CumR+Reward*Gamma**(Step+1)
                Whs=SNN.Whs[:][:]	####not need if ResetNetwork is off
                Wha=SNN.Wha[:][:]	####not need if ResetNetwork is off
                Whm=SNN.Whm[:][:]	####not need if ResetNetwork is off
                # The network is left untouched after the very last episode
                # so that its spikes are still available for the raster plot.
                if Epoch != Nepoch-1 or Episode != Nepisode-1:
                    nest.ResetNetwork()#############
                SNN.Whs=Whs_1[:][:]	####not need if ResetNetwork is off
                SNN.Wha=Wha_1[:][:]	####not need if ResetNetwork is off
                SNN.Whm=Whm_1[:][:]	####not need if ResetNetwork is off
                [DeltaWhs, DeltaWha, DeltaWhm]=UpdateW(FreeEnergy_1,FreeEnergy, Reward_1, Reward, 1, Sarray, Aarray, Marray, H_hat, H_hat_1)
                SNN.Whs=Whs[:][:]	####not need if ResetNetwork is off
                SNN.Wha=Wha[:][:]	####not need if ResetNetwork is off
                SNN.Whm=Whm[:][:]	####not need if ResetNetwork is off
                WinputSNN(DeltaWhs, DeltaWha, DeltaWhm)#terminal update
                break

        Historystemp2.append(Historystemp)
        Historyatemp2.append(Historyatemp)
        HistoryFEtemp2.append(HistoryFEtemp)
        HistoryFRtemp2.append(HistoryFRtemp)
        Reward_info.append(GoalFlag)
        Nstep.append(Step+1)
        HistoryCumRtemp.append(CumR)
        if Epoch ==0 and Flag_FE==1:
            if Goal==0:
                HistoryFEtg.append([0 for _ in range(50)])
            else:
                HistoryFEtg.append(FreeEnergy_t)
    Historys.append(Historystemp2)
    Historya.append(Historyatemp2)
    HistoryFE.append(HistoryFEtemp2)
    HistoryFR.append(HistoryFRtemp2)
    HistoryNstep.append(Nstep)
    HistoryReward.append(Reward_info)
    HistoryCumR.append(HistoryCumRtemp)

#to file
with open('History.txt', 'w') as fHistory:
    for i in range(Nepoch):
        for j in range(Nepisode):
            fHistory.write("Epoch:%d, Episode: %d\n" % (i+1, j+1))
            for record in (Historys, Historya, HistoryFE, HistoryFR):
                fHistory.write(str(record[i][j]) )
                fHistory.write('\n')
        fHistory.write('\n')

# One line per epoch, one ', '-separated value per episode.
_write_rows('HistoryReward.txt', HistoryReward)
_write_rows('HistoryNstep.txt', HistoryNstep)
_write_rows('HistoryCumR.txt', HistoryCumR)

# Final weights, one row per presynaptic neuron and one column per hidden
# neuron.
_write_rows('Whs.txt', [row[:SNN.NH] for row in SNN.Whs[:SNN.NS]])
_write_rows('Wha.txt', [row[:SNN.NH] for row in SNN.Wha[:SNN.NA]])
_write_rows('Whm.txt', [row[:SNN.NH] for row in SNN.Whm[:SNN.NM]])

if Flag_FE==1:
    _write_rows('FEts.txt', [HistoryFEts[i][:50] for i in range(Nepisode)])
    _write_rows('FEtg.txt', [HistoryFEtg[i][:50] for i in range(Nepisode)])

#nest.SetKernelStatus({'time':0.0})
#nest.SetStatus(SNN.voltmeter,[{"to_file": True, "withtime": True}])
#StateClamp(0)
#StateClamp(0,-1)
#StateClamp(0,0,0,-1)
#StateClamp(0,1,0,-1)
#StateClamp(1,0,0,1)
#StateClamp(1,1,0,1)
nest.raster_plot.from_device(SNN.sd, hist=False)
#plt.xlim( (0, 1000) )
##plt.show()
plt.savefig('raster.eps')
#plt.close()
##nest.voltage_trace.from_device(SNN.voltmeter)