# -*- coding: utf-8 -*-
####################################
#Nakano T, Otsuka M, Yoshimoto J, Doya K (2015) A Spiking Neural Network Model of Model-Free Reinforcement Learning with High-Dimensional Sensory Input and Perceptual Ambiguity. PLoS ONE 10(3): e0115620. doi:10.1371/journal.pone.0115620
####################################
import nest
nest.pynestkernel.logstdout("log.txt")
import nest.raster_plot
import nest.voltage_trace
import pylab
import numpy
import scipy
import time, sys
import matplotlib.pyplot as plt
import random
import operator
import math
import csv
nest.ResetKernel()
nest.SetKernelStatus({"overwrite_files": True})


####################################
class SNNc:
    """State / hidden / action spiking network built with PyNEST.

    All NEST nodes and connections are created in the class body, i.e. once,
    when this ``class`` statement is executed; every instance shares them.
    The weight matrices ``Whs`` (NS x NH), ``Wha`` (NA x NH) and ``Wah`` are
    instance attributes created by ``InitW`` / ``ConnectW``.
    """

    def InitW(self):
        """Draw fresh Gaussian weights for ``Whs`` and ``Wha``.

        Values are drawn row by row with ``random.normalvariate`` so the
        sequence of draws depends only on the state of the ``random`` module.
        """
        # state -> hidden weights, one row per state neuron
        self.Whs = [
            [random.normalvariate(self.InitWhsmean, self.InitWhsstd)
             for _ in range(self.NH)]
            for _ in range(self.NS)
        ]
        # action -> hidden weights, one row per action neuron
        self.Wha = [
            [random.normalvariate(self.InitWhamean, self.InitWhastd)
             for _ in range(self.NH)]
            for _ in range(self.NA)
        ]

    def ConnectW(self):
        """Write ``Whs`` / ``Wha`` (and the transpose ``Wah``) into NEST.

        For each source neuron the outgoing connections are looked up with
        ``nest.FindConnections`` and their ``weight`` is set from the matching
        matrix row. Each row is extended by one trailing ``1.0`` before it is
        handed to NEST.
        NOTE(review): the extra 1.0 is presumably the weight of the neuron's
        connection to the spike detector, which is created after the
        neuron-to-neuron connections in the class body -- confirm against the
        ordering returned by FindConnections.
        """
        # state -> hidden
        for i in range(self.NS):
            conns = nest.FindConnections([self.Sneurons[i]])
            nest.SetStatus(conns, 'weight', self.Whs[i] + [1.0])

        # action -> hidden
        for i in range(self.NA):
            conns = nest.FindConnections([self.Aneurons[i]])
            nest.SetStatus(conns, 'weight', self.Wha[i] + [1.0])

        # hidden -> action uses the transpose of Wha. A list comprehension is
        # used instead of map() so the rows can be indexed on Python 2 and 3.
        # The trailing 1.0 stays stored in self.Wah, as before.
        self.Wah = [list(column) + [1.0] for column in zip(*self.Wha)]
        for i in range(self.NH):
            conns = nest.FindConnections([self.Hneurons[i]])
            nest.SetStatus(conns, 'weight', self.Wah[i])

    # ---- parameters ----
    NS = 22 * 22  # number of state neurons
    NH = 90  # number of hidden neurons
    NA = 90  # number of action neurons
    InitWhsmean = 10  # mean of initial weights Whs
    InitWhsstd = 11.88  # std of initial weights Whs
    InitWhamean = 10  # mean of initial weights Wha
    InitWhastd = 11.88  # std of initial weights Wha
    Wseed = 0  # seed for the weight RNG (python `random`)
    Nseed = [123]  # seed(s) for the NEST kernel RNG (noise)
    driveparams_on = {'amplitude': 1000.}  # current into an active state neuron
    driveparams_on_a = {'amplitude': 1000.}  # current into the selected action population
    driveparams_off_a = {'amplitude': -2000.}  # current into the non-selected action population
    driveparams_off = {'amplitude': 0.}  # no current input
    noiseparams = {'mean': 0.0, 'std': 600.}  # noise current parameters
    sdparams = {"withtime": True, "withgid": True, 'to_file': False,
                'to_screen': False, 'flush_after_simulate': True,
                'flush_records': True}
    #neuronparams = { 'tau_m':20., 'V_th':-50., 'E_L':-60., 't_ref':2., 'V_reset':-60., 'C_m':200.}

    # ---- create nodes ----
    # The neurons are created first and in this order; CalcFR and BinSPcnt
    # index spike senders by position (state, then hidden, then action).
    Sneurons = nest.Create('iaf_neuron', NS)
    Hneurons = nest.Create('iaf_neuron', NH)
    Aneurons = nest.Create('iaf_neuron', NA)

    sd = nest.Create('spike_detector')
    drive = nest.Create('dc_generator', NS + 2)  # one per state neuron + two for the action populations
    nest.SetKernelStatus({'rng_seeds': Nseed})  # noise
    noise = nest.Create('noise_generator', 3)
    voltmeter = nest.Create("voltmeter")

    # ---- set parameters ----
    nest.SetStatus(sd, [sdparams])
    nest.SetStatus(noise, [noiseparams])  # if noise selection works, comment this out
    #nest.SetStatus(Sneurons, [neuronparams])
    #nest.SetStatus(Hneurons, [neuronparams])
    #nest.SetStatus(Aneurons, [neuronparams])

    # ---- connect ----
    # The order of the Connect calls is kept exactly as it was: ConnectW
    # relies on the per-neuron connection order.
    nest.DivergentConnect(noise[0:1], Sneurons)
    nest.DivergentConnect(noise[1:2], Aneurons)
    nest.DivergentConnect(noise[2:3], Hneurons)

    for i in range(NS):
        nest.Connect(drive[i:i + 1], Sneurons[i:i + 1])
    # `//` keeps the slice bounds integral on Python 3 as well.
    nest.DivergentConnect(drive[NS:NS + 1], Aneurons[0:NA // 2])
    nest.DivergentConnect(drive[NS + 1:NS + 2], Aneurons[NA // 2:NA])

    # The weights given here are placeholders; ConnectW overwrites them.
    nest.ConvergentConnect(Sneurons, Hneurons, weight=100.0, delay=1.0)
    nest.ConvergentConnect(Hneurons, Aneurons, weight=100.0, delay=1.0)  # NOTE(review): original author asks "DivergentConnect?"
    nest.ConvergentConnect(Aneurons, Hneurons, weight=100.0, delay=1.0)
    nest.ConvergentConnect(Sneurons, sd)
    nest.ConvergentConnect(Hneurons, sd)
    nest.ConvergentConnect(Aneurons, sd)

    nest.Connect(voltmeter, Hneurons[20:21])

    # seed the weight RNG
    random.seed(Wseed)


####################################
def _spike_events():
    """Return ``(senders, times)`` arrays recorded by the spike detector."""
    events = nest.GetStatus(SNN.sd, 'events')[0]
    return events['senders'], events['times']


def _first_index_after(times, threshold):
    """Return the first index with ``times > threshold``, or ``len(times)``."""
    hits = numpy.nonzero(times > threshold)[0]
    if hits.size:
        return int(hits[0])
    return len(times)


def _mean_activity(binSP, start, stop, maxFR):
    """Return spike-count / maxFR + 0.001 for rows ``start..stop-1`` of binSP.

    The 0.001 offset keeps the value away from 0 so that log() is defined.
    """
    return [sum(binSP[i]) / maxFR + 0.001 for i in range(start, stop)]


####################################
def Digit(state=1):
    """Load a randomly chosen digit image for ``state`` as a flat int list.

    Reads ``./digit22/digit<state>_<k>`` with ``k`` drawn uniformly from
    1..10 and returns all CSV fields, row by row, converted with ``int``.
    """
    filename = "./digit22/digit" + str(state) + "_" + str(random.randint(1, 10))
    # The file is now closed deterministically; this function is called once
    # per step, so a leaked handle per call adds up.
    with open(filename) as csvfile:
        return [int(elem) for row in csv.reader(csvfile) for elem in row]


####################################
def StateClamp(obs, action=0):
    """Apply the observation (and optionally an action) and simulate T ms.

    Every drive is first switched off, then the drive of each pixel equal to
    1 is switched on. ``action`` -1 / 1 additionally excites one action
    population and inhibits the other; any other value leaves both off.
    """
    nest.SetStatus(SNN.drive, [SNN.driveparams_off])  # all drives to 0
    for i, pixel in enumerate(obs):
        if pixel == 1:
            nest.SetStatus(SNN.drive[i:i + 1], [SNN.driveparams_on])

    first = SNN.drive[SNN.NS:SNN.NS + 1]
    second = SNN.drive[SNN.NS + 1:SNN.NS + 2]
    if action == -1:
        nest.SetStatus(first, [SNN.driveparams_on_a])
        nest.SetStatus(second, [SNN.driveparams_off_a])
    elif action == 1:
        nest.SetStatus(first, [SNN.driveparams_off_a])
        nest.SetStatus(second, [SNN.driveparams_on_a])

    nest.Simulate(T)


####################################
def CalcFR():
    """Count spikes per neuron over the last 100 ms of the latest T-window.

    Returns:
        (hiddenFR, actionFR): ``hiddenFR`` is a list with one count per
        hidden neuron; ``actionFR`` is ``[first_half_total,
        second_half_total]`` over the two halves of the action neurons.

    The window end is derived from the last recorded spike time rounded up
    to a multiple of T. If nothing was recorded, or nothing falls inside the
    window, all counts are zero (previously this raised IndexError).
    """
    senders, times = _spike_events()
    n_ids = SNN.NS + SNN.NH + SNN.NA + 1  # index 0 is unused; ids start at 1
    fr = [0] * n_ids
    if len(times):
        maxTime = (times[-1] // T + 1) * T
        start = _first_index_after(times, maxTime - 100)  # last 100 ms
        for gid in numpy.asarray(senders[start:]).tolist():
            if 0 <= gid < n_ids:
                fr[gid] += 1

    hidden_lo = SNN.NS + 1
    action_lo = SNN.NS + SNN.NH + 1
    action_mid = action_lo + SNN.NA // 2
    action_hi = action_lo + SNN.NA

    hiddenFR = fr[hidden_lo:action_lo]
    actionFR = [sum(fr[action_lo:action_mid]), sum(fr[action_mid:action_hi])]
    return hiddenFR, actionFR


####################################
def Actionselection(actionFR, episode):
    """Sample an action (-1 or 1) with a softmax over the two firing rates.

    The inverse temperature grows linearly with ``episode`` from 0 to
    ``Beta``. The larger logit is subtracted from both before
    exponentiating, which leaves the probability unchanged mathematically
    but prevents ``math.exp`` from overflowing for large rates.
    """
    beta = Beta * float(episode) / float(Nepisode)
    logit_first = beta * actionFR[0]
    logit_second = beta * actionFR[1]
    shift = max(logit_first, logit_second)
    weight_first = math.exp(logit_first - shift)
    weight_second = math.exp(logit_second - shift)
    if random.random() < weight_first / (weight_first + weight_second):
        return -1
    return 1


####################################
def StateTrans(state, action):
    """Move along the 7-state chain and return ``(nextState, reward, goal)``.

    The position is clipped to 0..6. Reaching state 3 gives reward 50000 and
    ``goal`` 1; every other transition gives reward -1000 and ``goal`` 0.
    """
    nextState = state + action
    if nextState == -1:
        nextState = 0
    if nextState == 7:
        nextState = 6
    if nextState == 3:
        return nextState, 50000., 1
    return nextState, -1000., 0


####################################
def CalcFE_AVE(binSP, state, action, whs, wha):
    """Free energy from bin-averaged activities.

    Args:
        binSP: per-neuron lists of 0/1 spike indicators (see BinSPcnt).
        state, action: used only to build the binary state / action arrays.
        whs, wha: weight matrices (NS x NH and NA x NH nested lists).

    Returns:
        ``(s, a, h_hat, entropy, expEnergy_mean, freeEnergy)`` where ``s`` /
        ``a`` are the binary arrays when the global ``Flagsa`` is 1 and the
        averaged activities ``s_hat`` / ``a_hat`` otherwise.
    """
    # Original note: 47 spikes per neuron in 100 ms is the observed maximum;
    # 50 is used for safety and equals the number of bins.
    maxFR = 50.
    bins = 50
    n_s, n_h, n_a = SNN.NS, SNN.NH, SNN.NA

    s_hat = _mean_activity(binSP, 0, n_s, maxFR)
    h_hat = _mean_activity(binSP, n_s, n_s + n_h, maxFR)
    a_hat = _mean_activity(binSP, n_s + n_h, n_s + n_h + n_a, maxFR)

    # Entropy of the hidden layer. The probability is clipped only inside
    # the log terms so a unit active in every bin (h = 1.001) no longer
    # raises a math domain error; values already inside (0, 1) are untouched.
    eps = 1e-12
    entropy = 0
    for x in h_hat:
        p = min(max(x, eps), 1.0 - eps)
        entropy = entropy - p * math.log(p) - (1 - p) * math.log(1 - p)

    # Binary state array.
    # NOTE(review): the 16-wide blocks look like a leftover from an earlier
    # state encoding; kept unchanged -- confirm before using Flagsa == 1.
    sarray = [0] * n_s
    block = state if state < 3 else state - 1
    sarray[block * 16:block * 16 + 16] = [1] * 16
    sarraya = numpy.array(sarray)

    # Binary action array: one half of the action population per action.
    # The half width follows NA; the former hard-coded 50 made the slice
    # assignment grow the list to 100 entries for action == 1 when NA == 90.
    half = n_a // 2
    first = (action + 1) // 2 * half
    aarray = [0] * n_a
    aarray[first:first + half] = [1] * half
    aarraya = numpy.array(aarray)

    # numpy.asmatrix is what the deprecated alias scipy.mat pointed to.
    smat = numpy.asmatrix(binSP[0:n_s])
    hmat = numpy.asmatrix(binSP[n_s:n_s + n_h])
    amat = numpy.asmatrix(binSP[n_s + n_h:n_s + n_h + n_a])
    whsmat = numpy.asmatrix(whs)
    whamat = numpy.asmatrix(wha)

    # Expected energy per bin.
    if Flagsa == 1:
        expEnergy_s = -sarraya * whsmat * hmat
        expEnergy_a = -aarraya * whamat * hmat
        per_bin = expEnergy_a + expEnergy_s
    else:
        expEnergy_s = -smat.T * whsmat * hmat
        expEnergy_a = -amat.T * whamat * hmat
        per_bin = (expEnergy_a + expEnergy_s).diagonal()
    expEnergy = per_bin.tolist()[0]  # matrix row -> plain list
    expEnergy_mean = sum(expEnergy) / bins

    freeEnergy = -entropy + expEnergy_mean

    if Flagsa == 1:
        return sarray, aarray, h_hat, entropy, expEnergy_mean, freeEnergy
    return s_hat, a_hat, h_hat, entropy, expEnergy_mean, freeEnergy


####################################
def CalcFE_LPF(binSP, state, action, whs, wha):
    """Free energy from low-pass-filtered hidden activity.

    ``state`` and ``action`` are accepted for signature parity with
    ``CalcFE_AVE`` and are not used.

    Returns:
        ``(s_hat, a_hat, h_hat, entropy_last, expEnergy_last,
        freeEnergy_last, freeEnergy_t)`` where ``freeEnergy_t`` is the
        filtered free energy for every bin.
    """
    alpha_h = 0.1  # filter constant of the hidden-activity trace
    alpha_f = 0.1  # filter constant of the free-energy trace
    maxFR = 50.  # see CalcFE_AVE
    bins = 50
    n_s, n_h, n_a = SNN.NS, SNN.NH, SNN.NA

    s_hat = _mean_activity(binSP, 0, n_s, maxFR)
    a_hat = _mean_activity(binSP, n_s + n_h, n_s + n_h + n_a, maxFR)

    # Trace of every hidden neuron; it starts at 0.5 in bin 0.
    bin_h = binSP[n_s:n_s + n_h]
    h_trace = []
    for spikes in bin_h:
        trace = [0.5]
        for j in range(1, bins):
            trace.append((1 - alpha_h) * trace[j - 1] + alpha_h * spikes[j])
        h_trace.append(trace)
    h_hat = [trace[-1] for trace in h_trace]

    # Per-bin sum over hidden neurons of -b*log(h) - (1-b)*log(1-h).
    entropy = []
    for i in range(len(h_trace[0])):
        total = 0.0
        for j in range(len(h_trace)):
            total = (total
                     - bin_h[j][i] * math.log(h_trace[j][i])
                     - (1. - bin_h[j][i]) * math.log(1. - h_trace[j][i]))
        entropy.append(total)

    # Expected energy per bin (numpy.asmatrix replaces the alias scipy.mat).
    smat = numpy.asmatrix(binSP[0:n_s])
    hmat = numpy.asmatrix(bin_h)
    amat = numpy.asmatrix(binSP[n_s + n_h:n_s + n_h + n_a])
    whsmat = numpy.asmatrix(whs)
    whamat = numpy.asmatrix(wha)
    expEnergy_s = -smat.T * whsmat * hmat
    expEnergy_a = -amat.T * whamat * hmat
    expEnergy = (expEnergy_a + expEnergy_s).diagonal().tolist()[0]

    # Instantaneous free energy per bin, then its low-pass filtered trace.
    f = [-entropy[i] + expEnergy[i] for i in range(bins)]
    freeEnergy_t = [f[0]]
    for i in range(1, bins):
        previous = freeEnergy_t[i - 1]
        freeEnergy_t.append(previous + alpha_f * (f[i] - previous))

    return (s_hat, a_hat, h_hat, entropy[-1], expEnergy[-1],
            freeEnergy_t[-1], freeEnergy_t)


####################################
def UpdateW(freeEnergy_1, freeEnergy, reward_1, reward, goal, sarray_1, aarray_1, h_hat, h_hat_1):
    """Return the weight increments ``(deltaWhs, deltaWha)``.

    With ``goal == 0`` the factor is
    ``reward_1 - Gamma*freeEnergy + freeEnergy_1`` and the previous hidden
    activity ``h_hat_1`` is used; otherwise (terminal update) the factor is
    ``reward + freeEnergy`` and ``h_hat`` is used. Each increment is
    ``factor * presynaptic_value * hidden_value``.

    Note: the hidden activities from CalcFE_AVE already include the +0.001
    offset.
    """
    if goal == 0:
        factor = reward_1 - Gamma * freeEnergy + freeEnergy_1
        hidden = h_hat_1
    else:
        factor = reward + freeEnergy
        hidden = h_hat

    # The factor is loop-invariant, so it is computed once. The grouping
    # (factor * pre) * h matches the original left-to-right evaluation.
    deltaWhs = [[factor * s * x for x in hidden] for s in sarray_1]
    deltaWha = [[factor * a * x for x in hidden] for a in aarray_1]
    return deltaWhs, deltaWha


####################################
def WinputSNN(deltaWhs, deltaWha):
    """Add ``Alpha * delta`` to the stored weights and push them into NEST."""
    newWhs = numpy.asmatrix(SNN.Whs) + Alpha * numpy.asmatrix(deltaWhs)
    newWha = numpy.asmatrix(SNN.Wha) + Alpha * numpy.asmatrix(deltaWha)

    # tolist() produces fresh nested lists, so earlier shallow copies of
    # SNN.Whs / SNN.Wha keep the previous values.
    SNN.Whs = newWhs.tolist()
    SNN.Wha = newWha.tolist()

    SNN.ConnectW()


####################################
def Outputsd():
    """Dump the recorded spike senders and times to text files."""
    senders, times = _spike_events()
    senders.tofile("Spikesender.txt", sep=', ', format="%e")
    times.tofile("Spiketimes.txt", sep=', ', format="%e")


####################################
def BinSPcnt():
    """Return 0/1 spike indicators per neuron for the last 50 bins of 2 ms.

    The result has one list per neuron (state, hidden, action, in that
    order) with one entry per bin. Bin k holds the spikes with
    ``maxTime-101+2k < t <= maxTime-101+2(k+1)`` (the last bin is open
    ended), where ``maxTime`` is the last spike time rounded up to a
    multiple of 1000 ms.

    Fixes relative to the previous version:
    * a bin threshold after the last spike now maps to the end of the
      recording. Before, the index of an earlier bin was reused, which moved
      the trailing spikes into the final bin, and raised UnboundLocalError
      when even the first bin was empty;
    * an empty recording yields all zeros instead of IndexError;
    * membership is tested against sets instead of sorted lists.
    """
    senders, times = _spike_events()
    bins = 50
    binsize = 2
    n_neurons = SNN.NS + SNN.NH + SNN.NA

    if len(times) == 0:
        return [[0] * bins for _ in range(n_neurons)]

    # NOTE(review): 1000 presumably equals 2*T (two StateClamp calls per
    # step); kept as a literal -- confirm before changing T.
    maxTime = (times[-1] // 1000 + 1) * 1000

    edges = [_first_index_after(times, maxTime - 101 + k * binsize)
             for k in range(bins)]
    edges.append(len(times))

    fired = [set(numpy.asarray(senders[edges[k]:edges[k + 1]]).tolist())
             for k in range(bins)]

    # Neuron ids are taken to start at 1, hence position i maps to id i + 1.
    return [[1 if gid in fired[k] else 0 for k in range(bins)]
            for gid in range(1, n_neurons + 1)]


####################################
####################################
#main
SNN=SNNc()
# ---- run configuration ----
Flag_FE=0  # 1 selects CalcFE_LPF, anything else CalcFE_AVE
Flagsa=0  # 1: use binary state and action arrays in CalcFE_AVE
T=500  # simulated time per StateClamp call (ms)
Nepoch=1
Nepisode=2000
Maxstep=30  # maximum number of steps per episode
Gamma=0.99  # discount factor
Alpha=0.0001  # learning rate
Beta=0.1  # final inverse temperature of the action softmax

# ---- recorded histories (one entry per epoch unless noted) ----
HistoryNstep=[]
HistoryCumR=[]
Historys=[]
Historya=[]
HistoryFE=[]
HistoryFEts=[]  # per episode, epoch 0 only: free-energy trace of the first step
HistoryFEtg=[]  # per episode, epoch 0 only: free-energy trace at the goal (zeros if not reached)
HistoryFR=[]

for Epoch in range(Nepoch):
    nest.SetKernelStatus({'time':0.0})
    Nstep=[]
    HistoryCumRtemp=[]
    Historystemp2=[]
    Historyatemp2=[]
    HistoryFEtemp2=[]
    HistoryFRtemp2=[]

    SNN.InitW()
    SNN.ConnectW()

    for Episode in range(Nepisode):
        Goal=0
        Action=0
        # start at either end of the chain (state 0 or state 6)
        State=random.randint(0,1)
        if State==1:
            State=6
        FreeEnergy=0.
        FreeEnergy_1=0.
        Reward=0.
        Reward_1=0.
        CumR=0
        Historystemp=[]
        Historyatemp=[]
        HistoryFEtemp=[]
        HistoryFRtemp=[]

        for Step in range(Maxstep):
            Obs=Digit(State)
            StateClamp(Obs)  # run the network on the observation only
            [HiddenFR, ActionFR]=CalcFR()
            Action=Actionselection(ActionFR,Episode)
            # On Python 2 this prints a tuple; kept as is.
            print("State", State, "Action", Action)
            StateClamp(Obs, Action)  # run again with the action clamped, for the free energy
            BinSP=BinSPcnt()
            if Step==0:
                # No previous step yet: the "previous" values are the current ones.
                # list(x) is the same shallow copy the former x[:][:] made.
                State_1=State
                Action_1=Action
                BinSP_1=list(BinSP)
                Whs_1=list(SNN.Whs)
                Wha_1=list(SNN.Wha)
            if Flag_FE==1:
                [Sarray, Aarray, H_hat, Entropy, ExpEnergy, FreeEnergy, FreeEnergy_t]=CalcFE_LPF(BinSP, State, Action, Whs_1, Wha_1)
                [Sarray_1, Aarray_1, H_hat_1, Entropy_1, ExpEnergy_1, FreeEnergy_1, FreeEnergy_t_1]=CalcFE_LPF(BinSP_1, State_1, Action_1, Whs_1, Wha_1)
            else:
                [Sarray, Aarray, H_hat, Entropy, ExpEnergy, FreeEnergy]=CalcFE_AVE(BinSP, State, Action, Whs_1, Wha_1)
                [Sarray_1, Aarray_1, H_hat_1, Entropy_1, ExpEnergy_1, FreeEnergy_1]=CalcFE_AVE(BinSP_1, State_1, Action_1, Whs_1, Wha_1)

            # Remember this step as "previous" for the next iteration.
            State_1=State
            Action_1=Action
            Reward_1=Reward
            BinSP_1=list(BinSP)
            Whs_1=list(SNN.Whs)
            Wha_1=list(SNN.Wha)
            [State, Reward, Goal]=StateTrans(State, Action)  # move

            Print1="Epoch: %d, Episode: %d, Step: %d"
            Print2="New state: %d, State: %d, Action: %d, Reward: %d, Goal: %d, CumR=%.3f"
            Print3="FreeEnergy(s=%d, a=%d)=%.3f, Entropy=%.3f, ExpEnergy=%.3f"
            Print4="ActionFR1=%d, ActionFR2=%d"
            # Single-value print(...) gives the same output on Python 2 as
            # the former print statement and also parses on Python 3.
            print(Print1 % (Epoch+1, Episode+1, Step+1))
            print(Print2 % (State, State_1, Action_1, Reward, Goal, CumR))
            print(Print3 % (State_1, Action_1, FreeEnergy, Entropy, ExpEnergy))
            print(Print4 % (ActionFR[0], ActionFR[1]))
            print("Step",Nstep)

            if Step==0:
                Historystemp.append(State_1)
            Historystemp.append(State)
            Historyatemp.append(Action)
            HistoryFEtemp.append(FreeEnergy)
            HistoryFRtemp.append(ActionFR)
            if Step==0 and Epoch==0 and Flag_FE==1:
                HistoryFEts.append(FreeEnergy_t)

            if Step != 0:
                # No update on the first move. UpdateW only computes from its
                # arguments, so it is evaluated here, where the result is
                # used, instead of on every step including step 0.
                [DeltaWhs, DeltaWha]=UpdateW(FreeEnergy_1, FreeEnergy, Reward_1, Reward, 0, Sarray_1, Aarray_1, H_hat, H_hat_1)
                WinputSNN(DeltaWhs, DeltaWha)
                CumR=CumR+Reward_1*Gamma**(Step-1)
            if Goal == 1:
                CumR=CumR+Reward*Gamma**(Step+1)
                nest.ResetNetwork()
                # Terminal update. The former save / swap / restore of
                # SNN.Whs and SNN.Wha around this call had no effect, because
                # UpdateW does not read them; it was removed.
                [DeltaWhs, DeltaWha]=UpdateW(FreeEnergy_1, FreeEnergy, Reward_1, Reward, 1, Sarray, Aarray, H_hat, H_hat_1)
                WinputSNN(DeltaWhs, DeltaWha)
                break

        Historystemp2.append(Historystemp)
        Historyatemp2.append(Historyatemp)
        HistoryFEtemp2.append(HistoryFEtemp)
        HistoryFRtemp2.append(HistoryFRtemp)
        Nstep.append(Step+1)
        HistoryCumRtemp.append(CumR)
        if Epoch==0 and Flag_FE==1:
            if Goal==0:
                HistoryFEtg.append([0]*50)  # 50 = number of bins in CalcFE_LPF
            else:
                HistoryFEtg.append(FreeEnergy_t)

    Historys.append(Historystemp2)
    Historya.append(Historyatemp2)
    HistoryFE.append(HistoryFEtemp2)
    HistoryFR.append(HistoryFRtemp2)
    HistoryNstep.append(Nstep)
    HistoryCumR.append(HistoryCumRtemp)

# ---- write results ----
# Every file is written inside a `with` block so it is closed even if a
# write fails; the bytes written are the same as before.
with open('History.txt', 'w') as fHistory:
    for i in range(Nepoch):
        for j in range(Nepisode):
            fHistory.write("Epoch:%d, Episode: %d\n" % (i+1, j+1))
            for history in (Historys, Historya, HistoryFE, HistoryFR):
                fHistory.write(str(history[i][j]))
                fHistory.write('\n')
        fHistory.write('\n')

with open('HistoryNstep.txt', 'w') as fHistoryNstep:
    for i in range(Nepoch):
        fHistoryNstep.write(', '.join(str(HistoryNstep[i][j]) for j in range(Nepisode)))
        fHistoryNstep.write('\n')

with open('HistoryCumR.txt', 'w') as fHistoryCumR:
    for i in range(Nepoch):
        fHistoryCumR.write(', '.join(str(HistoryCumR[i][j]) for j in range(Nepisode)))
        fHistoryCumR.write('\n')

with open('Whs.txt', 'w') as fWhs:
    for i in range(SNN.NS):
        fWhs.write(', '.join(str(SNN.Whs[i][j]) for j in range(SNN.NH)))
        fWhs.write('\n')

with open('Wha.txt', 'w') as fWha:
    for i in range(SNN.NA):
        fWha.write(', '.join(str(SNN.Wha[i][j]) for j in range(SNN.NH)))
        fWha.write('\n')

if Flag_FE==1:
    with open('FEts.txt', 'w') as fFEts:
        for i in range(Nepisode):
            fFEts.write(', '.join(str(HistoryFEts[i][j]) for j in range(50)))
            fFEts.write('\n')
    with open('FEtg.txt', 'w') as fFEtg:
        for i in range(Nepisode):
            fFEtg.write(', '.join(str(HistoryFEtg[i][j]) for j in range(50)))
            fFEtg.write('\n')

# ---- final demonstration run and raster plot ----
nest.SetKernelStatus({'time':0.0})
nest.SetStatus(SNN.voltmeter,[{"to_file": True, "withtime": True}])
Obs=Digit(1)
StateClamp(Obs)
StateClamp(Obs,1)
nest.raster_plot.from_device(SNN.sd, hist=False)
plt.xlim( (0, 1000) )
#plt.show()
plt.savefig('raster.eps')
#plt.close()
#nest.voltage_trace.from_device(SNN.voltmeter)