# -*- coding: utf-8 -*-
"""
Created on Wed Aug 19 12:26:59 2020
@author: kblackw1
"""
import numpy as np
from RL_class import Environment
import copy
############ This environment can be used with any task whose states are tuples
class completeT(Environment):
    """Tabular environment defined by complete reward (R) and transition (T) tables.
    R and T are dicts indexed by state tuple, then action number; each entry is a
    list of (outcome, probability) tuples.
    """
def __init__(self, states,actions,R,T,params,printR=False):
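        """
        states: dict mapping each state-type name to a dict of {state_name: number},
                e.g. (hypothetical) {'loc': {'home': 0, 'goal': 1}}
        actions: dict of {action_name: number}
        R, T: reward and transition tables, indexed as R[state_tuple][action_number];
              each entry is a list of (outcome, probability) tuples
        params: dict of task parameters; must contain 'start', the initial state tuple
        """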
        self.state_types={i:k for i,k in enumerate(states.keys())} #map state-tuple position -> state-type name
self.states=copy.copy(states)
self.num_states={st:len(v) for st,v in states.items()}
self.Ns=len(T)
        print(' ## env init, num states:',self.Ns, 'state types:',self.state_types,'\n states',self.states,self.num_states)
self.actions=copy.copy(actions)
self.Na = len(self.actions)
super().__init__(self.Ns,self.Na)
self.R=copy.deepcopy(R) #reward matrix
self.T=copy.deepcopy(T) #transition matrix
self.start_state=copy.copy(params['start'])
if printR:
reward_thresh_for_printing=0
print('########## R ############')
for s in self.R.keys():
s_words=self.state_from_number(s)
print('state:',s, '=',s_words)
for a in self.R[s].keys():
if np.any([rw[0]>reward_thresh_for_printing and rw[1]>0 for rw in self.R[s][a]]):
print('reward of ',self.R[s][a], 'for state,action pair:',
s_words,self.action_from_number(a))
for a in self.R[s].keys():
                    print('action:',a, '=', self.action_from_number(a))
    def state_from_number(self,s):
        """convert a tuple of state numbers into the list of corresponding state names"""
        return [list(self.states[self.state_types[i]].keys())[list(self.states[self.state_types[i]].values()).index(t)] for i,t in enumerate(s)]
    def action_from_number(self,a):
        """convert an action number back into its action name"""
        return list(self.actions.keys())[list(self.actions.values()).index(a)]
    def step(self, action,prn_info=False):
        """take one step: apply the action in the current state, return (reward, new state)"""
        #reward from taking action in the current state
        #R[s][a] and T[s][a] are lists of (outcome,probability) tuples:
        #in each tuple, element [0] is the reward (for R) or new state (for T), and element [1] is its probability
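        #e.g. (hypothetical values) R[s][a]=[(0,0.9),(10,0.1)] gives reward 0 with prob 0.9 and reward 10 with prob 0.1,
        #while T[s][a]=[((0,1),1.0)] always moves to state (0,1)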
        num_choices=len(self.R[self.state][action])
        weights=[p[1] for p in self.R[self.state][action]]
        if not np.isclose(np.sum(weights),1.0): #isclose avoids spurious warnings from float round-off
            print('Reward probs do not sum to 1',self.state,self.state_from_number(self.state),action,self.action_from_number(action))
        choice = np.random.choice(num_choices,p=weights) #probabilistic reward
        self.reward=self.R[self.state][action][choice][0] #element 0 holds the reward
if prn_info and np.abs(self.reward)>2:
print('******* env reward', self.reward,'state,action',self.state,action)
#Determine new state from taking action in state
        if len(self.T[self.state][action])!=len(self.R[self.state][action]):
            #T has different outcomes than R, so draw the transition separately
            num_choices=len(self.T[self.state][action])
            Tweights=[p[1] for p in self.T[self.state][action]]
            if not np.isclose(np.sum(Tweights),1.0):
                print('transition probs do not sum to 1',self.state,self.state_from_number(self.state),action,self.action_from_number(action))
            if Tweights!=weights:
                choice=np.random.choice(num_choices,p=Tweights) #transition selection is separate from reward selection
        self.state=self.T[self.state][action][choice][0] #element 0 holds the new state
return self.reward, self.state
def start(self):
"""start an episode"""
self.state = self.start_state
return self.state
    def encode(self, state):
        """convert a state tuple into its integer index (its position in the R table)"""
        i=list(self.R.keys()).index(state)
        return i
    #
    def decode(self, i):
        """convert an integer index back into the corresponding state tuple"""
        st=list(self.R.keys())[i]
        return st