# -*- coding: utf-8 -*-
####################################
#Nakano T, Otsuka M, Yoshimoto J, Doya K (2015) A Spiking Neural Network Model of Model-Free Reinforcement Learning with High-Dimensional Sensory Input and Perceptual Ambiguity. PLoS ONE 10(3): e0115620. doi:10.1371/journal.pone.0115620
####################################

import nest
nest.pynestkernel.logstdout("log.txt")
import nest.raster_plot
import nest.voltage_trace
import pylab
import numpy
import scipy
import time, sys
import matplotlib.pyplot as plt
import random
import operator
import math

# Start from a clean NEST kernel and allow recording devices to overwrite
# existing output files.
nest.ResetKernel()
kernelSettings={"overwrite_files": True}
nest.SetKernelStatus(kernelSettings)


class SNNc:
	"""State (S) / hidden (H) / action (A) spiking network.

	Everything after the two methods runs once, when this class statement
	is executed: NEST nodes are created and wired and Python's `random`
	module is seeded. Node creation order matters: CalcFR and BinSPcnt
	index neurons by GID assuming S = 1..NS, H = NS+1..NS+NH and
	A = NS+NH+1..NS+NH+NA.
	"""

	def InitW(self):
		"""Draw fresh Gaussian weights into self.Whs and self.Wha.

		self.Whs has NS rows (one per state neuron) of NH values and
		self.Wha has NA rows (one per action neuron) of NH values. Values
		are drawn row by row from the module-level `random` generator.
		"""
		self.Whs=[[random.normalvariate(self.InitWhsmean, self.InitWhsstd) for i in range(self.NH)] for j in range(self.NS)]
		self.Wha=[[random.normalvariate(self.InitWhamean, self.InitWhastd) for i in range(self.NH)] for j in range(self.NA)]

	def ConnectW(self):
		"""Write self.Whs / self.Wha into the NEST connection weights.

		For each source neuron, nest.FindConnections returns all of its
		outgoing connections. The code assumes these are the connections to
		the target population followed by the one to the spike detector (the
		order they are created in the class body), so every weight row is
		passed with a trailing 1.0 for the detector connection. The rows are
		extended by building new lists; the stored weight rows, which the
		training loop shares via shallow copies, are never mutated.

		The hidden->action weights (self.Wah) are the transpose of Wha, so
		the action<->hidden coupling is symmetric. self.Wah keeps the
		trailing 1.0 on each row.
		"""
		##Whs: state -> hidden (+ spike detector)
		for i in range(self.NS):
			conn=nest.FindConnections([self.Sneurons[i]])
			nest.SetStatus(conn,'weight',self.Whs[i]+[1.0])

		##Wha: action -> hidden (+ spike detector)
		for i in range(self.NA):
			conn=nest.FindConnections([self.Aneurons[i]])
			nest.SetStatus(conn,'weight',self.Wha[i]+[1.0])

		##Wah: hidden -> action (+ spike detector), transpose of Wha
		self.Wah=[list(column)+[1.0] for column in zip(*self.Wha)]
		for i in range(self.NH):
			conn=nest.FindConnections([self.Hneurons[i]])
			nest.SetStatus(conn,'weight',self.Wah[i])

	#define parameters
	NS=90 # number of state neurons (6 groups, one per non-goal state)
	NH=90 # number of hidden neurons
	NA=90 # number of action neurons (2 groups, one per action)
	InitWhsmean=10 # mean of the initial state-hidden weights (Whs)
	InitWhsstd=11.88  # std of the initial Whs
	InitWhamean=10 # mean of the initial action-hidden weights (Wha)
	InitWhastd=11.88  # std of the initial Wha
	Wseed=0 # seed for Python's `random` module (weights and every later random draw)
	Nseed=[123] # NEST rng_seeds (noise generators)
	driveparams_on  = {'amplitude':1000.} # current into the clamped state group
	driveparams_on_a  = {'amplitude':1000.} # current into the chosen action group
	driveparams_off_a  = {'amplitude':-2000.} # inhibitory current into the other action group
	driveparams_off  = {'amplitude':0.} # no current input
	noiseparams  = {'mean':0.0, 'std':600.} # noise current into every population
	sdparams  = { "withtime": True, "withgid" : True,'to_file':False, 'to_screen':False,'flush_after_simulate':True,'flush_records':True}
	#neuronparams = { 'tau_m':20., 'V_th':-50., 'E_L':-60., 't_ref':2., 'V_reset':-60., 'C_m':200.}

	#create neurons (creation order fixes the GIDs, see the class docstring)
	Sneurons = nest.Create('iaf_neuron',NS)
	Hneurons = nest.Create('iaf_neuron',NH)
	Aneurons = nest.Create('iaf_neuron',NA)

	sd= nest.Create('spike_detector')
	drive= nest.Create('dc_generator',6+2) # 6 state groups + 2 action groups
	nest.SetKernelStatus({'rng_seeds':Nseed})#noise
	noise= nest.Create('noise_generator',3)
	voltmeter = nest.Create("voltmeter")

	#set parameters
	nest.SetStatus(sd,[sdparams] )
	nest.SetStatus(noise,[noiseparams] ) # if noise selection works, comment out this.
	#nest.SetStatus(Sneurons, [neuronparams])
	#nest.SetStatus(Hneurons, [neuronparams])
	#nest.SetStatus(Aneurons, [neuronparams])

	#connect: one noise generator per population
	nest.DivergentConnect(noise[0:1], Sneurons)
	nest.DivergentConnect(noise[1:2], Aneurons)
	nest.DivergentConnect(noise[2:3], Hneurons)

	# drive[k] (k<6) feeds state group k; drive[6]/drive[7] feed the two action halves
	nest.DivergentConnect(drive[0:1], Sneurons[0:NS//6])
	nest.DivergentConnect(drive[1:2], Sneurons[NS//6:NS*2//6])
	nest.DivergentConnect(drive[2:3], Sneurons[NS*2//6:NS*3//6])
	nest.DivergentConnect(drive[3:4], Sneurons[NS*3//6:NS*4//6])
	nest.DivergentConnect(drive[4:5], Sneurons[NS*4//6:NS*5//6])
	nest.DivergentConnect(drive[5:6], Sneurons[NS*5//6:NS])
	nest.DivergentConnect(drive[6:7], Aneurons[0:NA//2])
	nest.DivergentConnect(drive[7:8], Aneurons[NA//2:NA])

	# The weight 100.0 is a placeholder; ConnectW overwrites these weights.
	nest.ConvergentConnect(Sneurons, Hneurons, weight=100.0, delay=1.0)
	nest.ConvergentConnect(Hneurons, Aneurons, weight=100.0, delay=1.0)
	nest.ConvergentConnect(Aneurons, Hneurons, weight=100.0, delay=1.0)

	# Spike-detector connections are created after the population ones,
	# which ConnectW relies on when it appends 1.0 to each weight row.
	nest.ConvergentConnect(Sneurons, sd)
	nest.ConvergentConnect(Hneurons, sd)
	nest.ConvergentConnect(Aneurons, sd)

	nest.Connect(voltmeter, Hneurons[20:21])

	#weight RNG
	random.seed(Wseed)

	
####################################
def StateClamp(state=1, action=0):
	"""Run the network for T ms with a state (and optionally an action) clamped.

	All dc_generators are first set to zero. States 0, 1, 2, 4, 5, 6 switch on
	state drives 0..5 respectively; any other state (e.g. the goal state 3)
	gets no state drive. Action -1 excites the first action half (drive 6)
	and inhibits the second (drive 7); action 1 does the opposite; any other
	action leaves both action drives at zero.
	"""
	nest.SetStatus(SNN.drive,[SNN.driveparams_off] )

	stateDrive={0: 0, 1: 1, 2: 2, 4: 3, 5: 4, 6: 5}.get(state)
	if stateDrive is not None:
		nest.SetStatus(SNN.drive[stateDrive:stateDrive+1],[SNN.driveparams_on] )

	actionDrives={
		-1: (SNN.driveparams_on_a, SNN.driveparams_off_a),
		1: (SNN.driveparams_off_a, SNN.driveparams_on_a),
	}.get(action)
	if actionDrives is not None:
		firstHalf, secondHalf=actionDrives
		nest.SetStatus(SNN.drive[6:7],[firstHalf] )
		nest.SetStatus(SNN.drive[7:8],[secondHalf] )

	nest.Simulate(T)

####################################
def CalcFR():
	"""Count spikes per neuron during the last 100 ms of the simulation so far.

	Spike events are read from SNN.sd. Neuron GIDs are assumed to be laid
	out as state 1..NS, hidden NS+1..NS+NH, action NS+NH+1..NS+NH+NA (the
	order in which SNNc creates them).

	Returns:
		hiddenFR: list of NH spike counts, one per hidden neuron.
		actionFR: [spikes of the first NA//2 action neurons,
		           spikes of the remaining action neurons].
	"""
	events=nest.GetStatus(SNN.sd,'events')[0]
	spikesender=numpy.asarray(events['senders'], dtype=int)
	spiketimes=numpy.asarray(events['times'], dtype=float)
	nTotal=SNN.NS+SNN.NH+SNN.NA

	# End of the simulated period, read from the kernel clock. Deriving it
	# from the last spike time raised IndexError when nothing was recorded or
	# no neuron fired in the window, and shifted the window by a whole
	# period when the last spike lay exactly on a T boundary.
	maxTime=nest.GetKernelStatus()['time']
	recent=spikesender[spiketimes>maxTime-100]

	# fr[gid] = number of spikes of node gid; fr[0] is unused (GIDs start at 1).
	fr=numpy.bincount(recent, minlength=nTotal+1)[:nTotal+1].tolist()

	hiddenFR=fr[SNN.NS+1:SNN.NS+SNN.NH+1]
	firstA=SNN.NS+SNN.NH+1
	midA=firstA+SNN.NA//2
	actionFR=[sum(fr[firstA:midA]), sum(fr[midA:nTotal+1])]
	return hiddenFR, actionFR


####################################
def Actionselection(actionFR, episode):
	"""Softmax (Boltzmann) choice between the two actions.

	The inverse temperature grows linearly with the episode index:
	beta = Beta*episode/Nepisode. Action -1 is returned with probability
	exp(beta*FR0) / (exp(beta*FR0) + exp(beta*FR1)), action 1 otherwise.
	Exactly one random.random() draw is made per call.
	"""
	beta=Beta*float(episode)/float(Nepisode)
	# Same probability written as a logistic of the rate difference. exp is
	# only evaluated on a non-positive argument, so large beta*FR values can
	# no longer raise OverflowError.
	d=beta*(actionFR[0]-actionFR[1])
	if d>=0:
		pMinus=1.0/(1.0+math.exp(-d))
	else:
		e=math.exp(d)
		pMinus=e/(1.0+e)
	if random.random()<pMinus:
		return -1
	return 1
	

####################################
def StateTrans(state,action):
	"""Apply one move on the 0..6 track.

	A move to -1 is undone to 0 and a move to 7 is undone to 6. Landing on
	position 3 is the goal (reward 50000., goal flag 1); every other move
	costs -1000. with goal flag 0.

	Returns (nextState, reward, goal).
	"""
	nextState=state+action
	if nextState==-1:
		nextState=0
	elif nextState==7:
		nextState=6

	if nextState==3:
		return nextState, 50000., 1
	return nextState, -1000., 0
	
####################################
def CalcFE_AVE(binSP, state, action, whs, wha):
	"""Free energy of the clamped network from bin-averaged spike activity.

	binSP holds one row of binary spike indicators (one value per bin) for
	every neuron: NS state rows, then NH hidden rows, then NA action rows.
	whs (NS x NH) and wha (NA x NH) are the weights to evaluate with.

	Returns (s, a, h_hat, entropy, expEnergy_mean, freeEnergy). When
	Flagsa == 1, s and a are the binary clamp patterns of `state` and
	`action`; otherwise they are the averaged state/action activities.
	"""
	maxFR=50.  # one spike per bin at most over 50 bins
	bins=50
	nS, nH, nA=SNN.NS, SNN.NH, SNN.NA
	sRows=binSP[0:nS]
	hRows=binSP[nS:nS+nH]
	aRows=binSP[nS+nH:nS+nH+nA]

	# Mean activity per neuron; the +0.001 keeps log() away from 0.
	s_hat=[sum(row)/maxFR+0.001 for row in sRows]
	h_hat=[sum(row)/maxFR+0.001 for row in hRows]
	a_hat=[sum(row)/maxFR+0.001 for row in aRows]

	# Bernoulli entropy of the hidden units. A neuron that spiked in every
	# bin has h_hat = 1.001, so clamp below 1 to keep log(1-x) defined.
	entropy=0
	for x in h_hat:
		x=min(x, 1.-0.001)
		entropy=entropy-x*math.log(x)-(1-x)*math.log(1-x)

	# Binary clamp patterns, using the same population split as the
	# dc_generator wiring in SNNc: states 0-2 -> groups 0-2, states 4-6 ->
	# groups 3-5 (NS/6 neurons each); action -1 -> first NA/2 neurons,
	# action 1 -> the rest. The old hard-coded widths (16 / 50) did not
	# match NS=90 / NA=90 and made the lists grow past NS / NA.
	group=state if state<3 else state-1
	sLo, sHi=nS*group//6, nS*(group+1)//6
	sarray=[0]*nS
	sarray[sLo:sHi]=[1]*(sHi-sLo)

	if (action+1)//2==0:
		aLo, aHi=0, nA//2
	else:
		aLo, aHi=nA//2, nA
	aarray=[0]*nA
	aarray[aLo:aHi]=[1]*(aHi-aLo)

	whsmat=numpy.asmatrix(whs)
	whamat=numpy.asmatrix(wha)
	hmat=numpy.asmatrix(hRows)

	# Expected energy per bin: -s^T Whs h - a^T Wha h
	if Flagsa==1:
		expEnergy_s=-numpy.asmatrix(sarray)*whsmat*hmat
		expEnergy_a=-numpy.asmatrix(aarray)*whamat*hmat
		temp2=expEnergy_a+expEnergy_s  # 1 x bins
	else:
		smat=numpy.asmatrix(sRows)
		amat=numpy.asmatrix(aRows)
		expEnergy_s=-smat.T*whsmat*hmat
		expEnergy_a=-amat.T*whamat*hmat
		temp=expEnergy_a+expEnergy_s  # bins x bins
		temp2=temp.diagonal()  # same-bin terms only
	expEnergy=temp2.tolist()[0]
	expEnergy_mean=sum(expEnergy)/bins

	freeEnergy=-entropy+expEnergy_mean

	if Flagsa==1:
		return sarray, aarray, h_hat, entropy, expEnergy_mean, freeEnergy
	else:
		return s_hat, a_hat, h_hat, entropy, expEnergy_mean, freeEnergy
####################################
def CalcFE_LPF(binSP, state, action, whs, wha):
	"""Free energy of the clamped network using low-pass filtered hidden traces.

	Each hidden neuron's binary spike row is turned into a trace that starts
	at 0.5 and then follows trace <- (1-alpha_h)*trace + alpha_h*spike for
	bins 1..49. For every bin, the entropy term is the cross-entropy of the
	hidden spikes against their traces and the energy term is
	-s^T Whs h - a^T Wha h on that bin. The per-bin free energy is then
	smoothed with rate alpha_f.

	`state` and `action` are accepted for parity with CalcFE_AVE but unused.

	Returns (s_hat, a_hat, h_hat, last-bin entropy, last-bin energy,
	last smoothed free energy, whole smoothed free-energy trace).
	"""
	alpha_h=0.1
	alpha_f=0.1
	maxFR=50.
	bins=50
	nS, nH, nA=SNN.NS, SNN.NH, SNN.NA

	# +0.001 keeps these strictly positive
	s_hat=[sum(row)/maxFR+0.001 for row in binSP[0:nS]]
	a_hat=[sum(row)/maxFR+0.001 for row in binSP[nS+nH:nS+nH+nA]]

	bin_h=binSP[nS:nS+nH]
	h_trace=[]
	for spikes in bin_h:
		trace=[0.5]
		for j in range(1, bins):
			trace.append((1-alpha_h)*trace[j-1]+alpha_h*spikes[j])
		h_trace.append(trace)
	h_hat=[trace[-1] for trace in h_trace]

	# Per-bin cross-entropy, summed over hidden neurons in order.
	entropy=[]
	for i in range(len(h_trace[0])):
		total=None
		for spikes, trace in zip(bin_h, h_trace):
			b=spikes[i]
			h=trace[i]
			if total is None:
				total=-b*math.log(h)-(1.-b)*math.log(1.-h)
			else:
				total=total-b*math.log(h)-(1.-b)*math.log(1.-h)
		entropy.append(total)

	hmat=numpy.asmatrix(bin_h)
	expEnergy_s=-numpy.asmatrix(binSP[0:nS]).T*numpy.asmatrix(whs)*hmat
	expEnergy_a=-numpy.asmatrix(binSP[nS+nH:nS+nH+nA]).T*numpy.asmatrix(wha)*hmat
	perBin=(expEnergy_a+expEnergy_s).diagonal()  # keep same-bin terms only
	expEnergy=perBin.tolist()[0]

	f=[-entropy[i]+expEnergy[i] for i in range(bins)]

	freeEnergy_t=[f[0]]
	for value in f[1:]:
		previous=freeEnergy_t[-1]
		freeEnergy_t.append(previous+alpha_f*(value-previous))

	return s_hat, a_hat, h_hat, entropy[-1], expEnergy[-1], freeEnergy_t[-1],freeEnergy_t

####################################
def UpdateW(freeEnergy_1,freeEnergy, reward_1, reward, goal, sarray_1, aarray_1, h_hat, h_hat_1):
	"""Weight increments driven by a TD error on the negative free energy.

	goal == 0: td = reward_1 - Gamma*freeEnergy + freeEnergy_1, combined with
	the hidden activities h_hat_1. Otherwise (terminal step):
	td = reward + freeEnergy, combined with h_hat.

	Returns (deltaWhs, deltaWha) where
	deltaWhs[i][k] = td*sarray_1[i]*h[k] and deltaWha[i][k] = td*aarray_1[i]*h[k].
	"""
	if goal==0:
		tdError=reward_1-Gamma*freeEnergy+freeEnergy_1
		hidden=h_hat_1
	else:
		tdError=reward+freeEnergy
		hidden=h_hat
	# With CalcFE_AVE the h values include its +0.001 offset.
	deltaWhs=[[tdError*s*x for x in hidden] for s in sarray_1]
	deltaWha=[[tdError*a*x for x in hidden] for a in aarray_1]
	return deltaWhs, deltaWha


####################################
def WinputSNN(deltaWhs, deltaWha):
	"""Apply one learning step and upload the new weights to NEST.

	SNN.Whs and SNN.Wha each get Alpha times their increment added
	elementwise. The results are rebuilt as fresh nested lists rather than
	edited in place. SNN.ConnectW() then writes them into the connections.
	"""
	updated=[
		(numpy.asmatrix(current)+Alpha*numpy.asmatrix(delta)).tolist()
		for current, delta in ((SNN.Whs, deltaWhs), (SNN.Wha, deltaWha))
	]
	SNN.Whs, SNN.Wha=updated
	SNN.ConnectW()

####################################
def Outputsd():
	"""Dump every spike recorded by SNN.sd (senders, then times) to comma-separated text files."""
	events=nest.GetStatus(SNN.sd,'events')[0]
	for field, path in (('senders', "Spikesender.txt"), ('times', "Spiketimes.txt")):
		events[field].tofile(path, sep=', ', format = "%e")

####################################
def BinSPcnt():
	"""Binary spike matrix for the last ~100 ms of the simulation so far.

	The window is split into 50 bins of 2 ms whose left edges are
	maxTime-101, maxTime-99, ..., maxTime-3; the last bin runs to the end
	of the recording. Returns a list with one row per neuron (GID 1 to
	NS+NH+NA, in GID order); row[k] is 1 if that neuron spiked in bin k,
	otherwise 0.
	"""
	events=nest.GetStatus(SNN.sd,'events')[0]
	spikesender=numpy.asarray(events['senders'], dtype=int)
	spiketimes=numpy.asarray(events['times'], dtype=float)

	# End of the simulated period, read from the kernel clock. The old
	# spike-derived value hard-coded 1000 ms and shifted the window by a
	# whole period when the last spike lay exactly on that boundary.
	maxTime=nest.GetKernelStatus()['time']
	bins=50
	binsize=2
	nSpikes=len(spiketimes)

	# binidx[k] = index of the first spike after the left edge of bin k.
	# When no spike follows an edge, that bin and all later ones are empty,
	# so point at the end of the array. The old code crashed here (unbound
	# `j`) or reused a stale index that moved the spikes of the last
	# non-empty bin into bin 49.
	binidx=[]
	for i in range(bins):
		after=numpy.nonzero(spiketimes>maxTime-101+i*binsize)[0]
		binidx.append(after[0] if after.size else nSpikes)
	binidx.append(nSpikes)

	binspikes=[set(spikesender[binidx[k]:binidx[k+1]].tolist()) for k in range(bins)]

	nTotal=SNN.NS+SNN.NH+SNN.NA
	binSP=[[1 if gid in members else 0 for members in binspikes] for gid in range(1, nTotal+1)]
	return binSP
	
	
####################################
####################################
#main
SNN=SNNc()  # the NEST nodes were already built by the SNNc class body; this only makes an instance
Flag_FE=0  # free-energy estimator: 0 = CalcFE_AVE (bin average), 1 = CalcFE_LPF (low-pass traces)
Flagsa=0  # 1 = CalcFE_AVE uses binary clamp patterns for state/action instead of measured activity
T=500  # ms simulated by each StateClamp call
Nepoch=1  # number of training runs; weights are redrawn (SNN.InitW) at the start of each
Nepisode=2000  # episodes per epoch
Maxstep=30  # step limit per episode
Gamma=0.99  # discount factor
Alpha=0.0001  # learning rate applied in WinputSNN
Beta=0.1  # final inverse temperature; Actionselection ramps beta from 0 to Beta
# Histories indexed [epoch][episode] (HistoryFEts / HistoryFEtg: [episode] of epoch 0 only)
HistoryNstep=[]
HistoryCumR=[]
Historys=[]
Historya=[]
HistoryFE=[]
HistoryFEts=[]  # Flag_FE==1: FreeEnergy_t trace of each episode's first step
HistoryFEtg=[]  # Flag_FE==1: FreeEnergy_t trace of the goal step (50 zeros if the goal was not reached)
HistoryFR=[]
	
for Epoch in range(Nepoch):
	# restart the NEST clock, then draw and upload fresh weights
	nest.SetKernelStatus({'time':0.0})
	Nstep=[]
	HistoryCumRtemp=[]
	Historystemp2=[]
	Historyatemp2=[]
	HistoryFEtemp2=[]
	HistoryFRtemp2=[]
	
	SNN.InitW()
	SNN.ConnectW()
	
	
	for Episode in range(Nepisode):
		Goal=0
		Action=0
		# each episode starts at one end of the track: position 0 or 6
		#State=Episode%2
		State=random.randint(0,1)
		if State==1:
			State=6
		FreeEnergy=0.
		FreeEnergy_1=0.
		Reward=0.
		Reward_1=0.
		CumR=0
		Historystemp=[]
		Historyatemp=[]
		HistoryFEtemp=[]
		HistoryFRtemp=[]
		
		for Step in range(Maxstep):
			# 1) clamp the state with both action drives off and read the action firing rates
			StateClamp(State) #run SNN
			[HiddenFR, ActionFR]=CalcFR()
			# 2) softmax action choice
			Action= Actionselection(ActionFR,Episode)
			print("State", State, "Action", Action)
			# 3) clamp state and chosen action, then bin the last ~100 ms of spikes
			StateClamp(State, Action)# for FE
			#[HiddenFR, ActionFR]=CalcFR()
			BinSP=BinSPcnt()
			if Step==0:
				# no previous step yet: the "previous" (s, a) is the current one.
				# note: X[:][:] is a shallow copy; the row lists stay shared with SNN.Whs/Wha
				State_1=State
				Action_1=Action
				#HiddenFR_1=HiddenFR
				BinSP_1=BinSP[:][:]
				Whs_1=SNN.Whs[:][:]
				Wha_1=SNN.Wha[:][:]
			# 4) free energy of the current and the previous (s, a), both evaluated
			#    with the weights saved at the previous step (Whs_1 / Wha_1).
			#    At Step 0 both calls get identical inputs, so FreeEnergy_1 == FreeEnergy.
			if Flag_FE==1:
				[Sarray, Aarray, H_hat, Entropy, ExpEnergy, FreeEnergy, FreeEnergy_t]=CalcFE_LPF(BinSP, State, Action, Whs_1, Wha_1)#calc FE
				[Sarray_1, Aarray_1, H_hat_1, Entropy_1,ExpEnergy_1, FreeEnergy_1, FreeEnergy_t_1]=CalcFE_LPF(BinSP_1, State_1, Action_1, Whs_1, Wha_1)#calc FE
			else:
				[Sarray, Aarray, H_hat, Entropy, ExpEnergy, FreeEnergy]=CalcFE_AVE(BinSP, State, Action, Whs_1, Wha_1)#calc FE
				[Sarray_1, Aarray_1, H_hat_1, Entropy_1,ExpEnergy_1, FreeEnergy_1]=CalcFE_AVE(BinSP_1, State_1, Action_1, Whs_1, Wha_1)#calc FE
			# 5) shift current -> previous before the transition overwrites State/Reward.
			#    Reward_1 now holds the reward of the previous transition, and
			#    Whs_1/Wha_1 hold the weights before this step's update.
			State_1=State
			Action_1=Action
			Reward_1=Reward
			#HiddenFR_1=HiddenFR
			BinSP_1=BinSP[:][:]
			Whs_1=SNN.Whs[:][:]
			Wha_1=SNN.Wha[:][:]
			# 6) environment transition
			[State, Reward, Goal]=StateTrans(State, Action)# move
			# 7) non-goal TD increment for the previous (s, a):
			#    (Reward_1 - Gamma*FreeEnergy + FreeEnergy_1) * s_prev * h_prev.
			#    Computed every step but only applied from Step 1 on (below).
			[DeltaWhs, DeltaWha]=UpdateW(FreeEnergy_1,FreeEnergy, Reward_1, Reward, 0, Sarray_1,  Aarray_1, H_hat, H_hat_1)
			Print1="Epoch: %d, Episode: %d, Step: %d"
			Print2="New state: %d, State: %d, Action: %d, Reward: %d, Goal: %d, CumR=%.3f"
			Print3="FreeEnergy(s=%d, a=%d)=%.3f, Entropy=%.3f, ExpEnergy=%.3f"
			Print4="ActionFR1=%d, ActionFR2=%d"
			print Print1 % (Epoch+1, Episode+1, Step+1)
			print Print2 % (State, State_1, Action_1, Reward, Goal, CumR)
			print Print3 % (State_1, Action_1, FreeEnergy, Entropy, ExpEnergy)
			print Print4 % (ActionFR[0], ActionFR[1])
			print("Step",Nstep)  # prints the step counts of all finished episodes of this epoch
			if Step==0:
				Historystemp.append(State_1)  # record the start position once
			Historystemp.append(State)
			Historyatemp.append(Action)
			HistoryFEtemp.append(FreeEnergy)
			HistoryFRtemp.append(ActionFR)
			if Step==0 and Epoch ==0 and Flag_FE==1:
				HistoryFEts.append(FreeEnergy_t)
			if Step != 0:
				WinputSNN(DeltaWhs, DeltaWha)#not run 1st move
				# reward of the transition made at step Step-1, discounted by Gamma**(Step-1)
				CumR=CumR+Reward_1*Gamma**(Step-1)
			if Goal == 1:
				# NOTE(review): the goal reward of step Step is discounted by
				# Gamma**(Step+1), while non-goal rewards of step t use Gamma**t
				# (see above) -- possible off-by-one; confirm the intended return.
				CumR=CumR+Reward*Gamma**(Step+1)
				# UpdateW does not read SNN.Whs/SNN.Wha, so swapping them to
				# Whs_1/Wha_1 and back around the call does not change
				# DeltaWhs/DeltaWha. WinputSNN then adds the terminal increment
				# (Reward + FreeEnergy) * s * h on top of the current weights.
				Whs=SNN.Whs[:][:]	####not need if ResetNetwork is off
				Wha=SNN.Wha[:][:]	####not need if ResetNetwork is off
				nest.ResetNetwork()  # presumably resets neuron dynamics before the next episode (NEST semantics)
				SNN.Whs=Whs_1[:][:]	####not need if ResetNetwork is off
				SNN.Wha=Wha_1[:][:]	####not need if ResetNetwork is off
				[DeltaWhs, DeltaWha]=UpdateW(FreeEnergy_1,FreeEnergy, Reward_1, Reward, 1, Sarray, Aarray, H_hat, H_hat_1)
				SNN.Whs=Whs[:][:]	####not need if ResetNetwork is off
				SNN.Wha=Wha[:][:]	####not need if ResetNetwork is off
				WinputSNN(DeltaWhs, DeltaWha)
				break
			
		Historystemp2.append(Historystemp)
		Historyatemp2.append(Historyatemp)
		HistoryFEtemp2.append(HistoryFEtemp)
		HistoryFRtemp2.append(HistoryFRtemp)
		Nstep.append(Step+1)  # steps taken in this episode
		HistoryCumRtemp.append(CumR)
		if Epoch ==0 and Flag_FE==1:
			if Goal==0:
				HistoryFEtg.append([0 for _ in range(50)])  # 50 = number of bins
			else:
				HistoryFEtg.append(FreeEnergy_t)


	Historys.append(Historystemp2)
	Historya.append(Historyatemp2)
	HistoryFE.append(HistoryFEtemp2)
	HistoryFR.append(HistoryFRtemp2)
	HistoryNstep.append(Nstep)
	HistoryCumR.append(HistoryCumRtemp)


#to file
def _writeCsvRows(path, rows, ncols):
	"""Write the first ncols values of every row to path, ', '-separated, one row per line."""
	with open(path, 'w') as f:
		for row in rows:
			f.write(', '.join(str(row[j]) for j in range(ncols)))
			f.write('\n')

# Per-episode trajectories (states, actions, free energies, action firing
# rates): one block per episode and a blank line after each epoch.
# `with` closes the file even if a write fails part-way.
with open('History.txt', 'w') as fHistory:
	for i in range(Nepoch):
		for j in range(Nepisode):
			fHistory.write("Epoch:%d, Episode: %d\n" % (i+1, j+1))
			for history in (Historys, Historya, HistoryFE, HistoryFR):
				fHistory.write(str(history[i][j]))
				fHistory.write('\n')
		fHistory.write('\n')

# One line per epoch: steps taken / discounted return of every episode.
_writeCsvRows('HistoryNstep.txt', HistoryNstep[:Nepoch], Nepisode)
_writeCsvRows('HistoryCumR.txt', HistoryCumR[:Nepoch], Nepisode)

# Final weights: one line per state neuron (Whs) / action neuron (Wha), NH columns.
_writeCsvRows('Whs.txt', SNN.Whs[:SNN.NS], SNN.NH)
_writeCsvRows('Wha.txt', SNN.Wha[:SNN.NA], SNN.NH)

if Flag_FE==1:
	# 50-bin smoothed free-energy traces per episode of the first epoch:
	# first step (FEts) and goal step (FEtg, zeros when the goal was missed).
	_writeCsvRows('FEts.txt', HistoryFEts[:Nepisode], 50)
	_writeCsvRows('FEtg.txt', HistoryFEtg[:Nepisode], 50)



# Post-training demo: restart the clock, let the voltmeter write to a file,
# clamp state 1 with the action drives off and then with action 1, and draw
# the raster of the spike detector's events over the first 1000 ms.
nest.SetKernelStatus({'time':0.0})
nest.SetStatus(SNN.voltmeter,[{"to_file": True, "withtime": True}])
for clampArgs in ((1,), (1, 1)):
	StateClamp(*clampArgs)
nest.raster_plot.from_device(SNN.sd, hist=False)
plt.xlim(0, 1000)
# Optional outputs (disabled):
# plt.show()
# plt.savefig('raster.eps')
# plt.close()
# nest.voltage_trace.from_device(SNN.voltmeter)