# -*- coding: utf-8 -*-
"""
Created on Wed Aug 19 12:26:59 2020
@author: kblackw1
"""
import numpy as np
from RL_class import Environment
############ NOTE: this environment can be used with any task whose states are tuples
class separable_T(Environment):
    """Environment whose full state is a tuple of separable components,
    e.g. (location, press_history).

    Reward R is indexed by the full state tuple, while transitions T are
    indexed only by the environment part of the state (element
    ``env_state_bits`` of the tuple); the press-history part is updated
    separately in ``step``.

    Parameters
    ----------
    states : dict of {state_type_name: {state_name: state_number}}
    actions : dict of {action_name: action_number}
    R : dict, R[state_tuple][action] -> list of (reward, probability) tuples
    T : dict, T[env_state][action] -> list of (new_env_state, probability) tuples
    params : dict with keys 'hx_len' (press-history length), 'start'
        (start state tuple) and 'hx_act' (name of the action that extends
        the press history)
    printR : if True, print rewarded state/action pairs (debugging aid)
    """
    def __init__(self, states,actions,R,T,params,printR=False):
        # index -> state-type name, e.g. {0: 'loc', 1: 'hx'}
        self.state_types={v:k for v,k in enumerate(states.keys())}
        self.states=states #per-type name->number maps; full space is their product, not enumerated
        self.num_states={st:len(v) for st,v in states.items()}
        # total number of combined (type0, type1) states
        self.Ns=self.num_states[self.state_types[0]]*self.num_states[self.state_types[1]]
        print(' ## env init, num states:',self.Ns, 'states types:',self.state_types,'\n states',self.states,self.num_states)
        self.actions=actions
        self.Na = len(self.actions)
        super().__init__(self.Ns,self.Na)
        self.R=R #reward matrix
        self.T=T #transition matrix
        self.Hx_len=params['hx_len']
        self.start_state=params['start']
        self.hx_act=params['hx_act']
        #which element of the state tuple is the external/env state; TODO: make this a parameter
        self.env_state_bits=0
        # Inverse lookup tables (number -> name), built once so that
        # state_from_number/action_from_number avoid an O(n) list search on
        # every environment step.  Assumes each name->number map is
        # one-to-one, which it must be for decoding to be well defined.
        self._state_names={st:{num:name for name,num in d.items()} for st,d in states.items()}
        self._action_names={num:name for name,num in actions.items()}
        if printR:
            reward_thresh_for_printing=0
            print('########## R ############')
            for s in self.R.keys():
                s_words=self.state_from_number(s)
                print('state:',s, '=',s_words)
                for a in self.R[s].keys():
                    #only print pairs that can yield reward above threshold
                    if np.any([rw[0]>reward_thresh_for_printing for rw in self.R[s][a]]):
                        print('reward of ',self.R[s][a], 'for state,action pair:',
                              s_words,self.action_from_number(a))
                for a in self.R[s].keys():
                    print('action:',a, '=', self.action_from_number(a))
    def state_from_number(self,s):
        """Translate a tuple of state numbers into a list of state names."""
        return [self._state_names[self.state_types[i]][t] for i,t in enumerate(s)]
    def action_from_number(self,a):
        """Translate an action number into its action name."""
        return self._action_names[a]
    def T_for_action_hx(self,lever):
        """Append ``lever`` (a single character) to the press history,
        dropping the oldest press to keep the history length fixed.
        Returns the updated history string."""
        #shift out the oldest press
        presses=self.pressHx[1:]
        #append the new press
        self.pressHx=presses+lever
        return self.pressHx
    def step(self, action,prn_info=False):
        """Take ``action`` from the current state; return (reward, new_state).

        Samples a reward from R[state][action] and a new environment state
        from T[env_state][action]; the press-history part of the state
        changes only when the history action is taken at a lever.
        """
        #R[s][a] and T[s][a] are lists of tuples;
        #in each tuple [0] is the reward/new state and [1] is its probability
        act=self.action_from_number(action)
        num_choices=len(self.R[self.state][action])
        weights=[p[1] for p in self.R[self.state][action]]
        choice = np.random.choice(num_choices,p=weights) #probabilistic reward
        self.reward=self.R[self.state][action][choice][0]
        if prn_info and np.abs(self.reward)>2:
            print('******* env step, reward=', self.reward,',state=',self.state,self.state_from_number(self.state),',action=',action,act)
        #Determine new state (location or other external state) from taking action in state
        num_choices=len(self.T[self.state[self.env_state_bits]][action]) #How to select action if T doesn't include pressHx
        weights=[p[1] for p in self.T[self.state[self.env_state_bits]][action]]
        choice=np.random.choice(num_choices,p=weights) #new state
        state_loc=self.T[self.state[self.env_state_bits]][action][choice][0]
        #Now determine press-history part of the state
        env_state=self.state_from_number(self.state)[self.env_state_bits]
        if act==self.hx_act and env_state.endswith('lever'):
            new_presshx= self.T_for_action_hx(env_state[0]) #1st character of location part of state
            state_press=self.states['hx'][new_presshx] #go from press history to state number
        else:
            state_press=self.state[1] #press_hx part of state, unchanged if no press
        newstate=(state_loc,state_press)
        #once reward is received, reset to the start state to scramble press_hx
        #and prevent the agent from collecting numerous rewards;
        #this is a kluge that avoids enumerating all possible transitions
        if self.reward>0:
            newstate=self.start_state
            self.pressHx=self.state_from_number(newstate)[1]
        self.state=newstate
        return self.reward, self.state
    def start(self):
        """Start an episode: reset state and press history; return the state."""
        self.state = self.start_state
        self.pressHx=self.state_from_number(self.state)[1]
        return self.state
    def encode(self, state):
        """Return the integer index of ``state`` in the ordering of R's keys."""
        i=list(self.R.keys()).index(state)
        return i
    #
    def decode(self, i):
        """Inverse of ``encode``: return the state tuple at index ``i``."""
        st=list(self.R.keys())[i]
        return st