import numpy as np
import torch as T
from actor_critic import ActorCritic
from icm import ICM
from memory import Memory
import maze_env as me
import utils as util
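# A3C worker with an optional Intrinsic Curiosity Module (ICM) for the maze
# environment: each worker keeps local copies of the shared actor-critic and
# ICM networks, collects rollouts, and pushes gradients to the global models.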
def worker(name, input_shape, n_actions, global_agent, optimizer,
           env_name, env_num, maze, global_icm, icm_optimizer,
           global_count_map, global_round, global_total_steps,
           global_is_fun, limit_round_steps, limit_total_steps,
           global_clear, alpha, stop_threshold, icm):
    """A3C worker process: repeatedly plays rounds in its own MazeEnv copy,
    pushes gradients to the shared global networks, and updates the shared
    counters until global_is_fun is cleared."""
local_agent = ActorCritic(input_shape, n_actions)
if icm:
local_icm = ICM(input_shape, n_actions, alpha=alpha)
else:
local_icm = None
memory = Memory()
env = me.MazeEnv()
env.load_maze(maze=maze, repeat=1)
    while global_is_fun.value == 1:
        with global_total_steps.get_lock():
            current_total_steps = global_total_steps.value
        # Never let a round run past the remaining global step budget.
        do_steps = min(limit_total_steps - current_total_steps, limit_round_steps)
        done, steps, i_reward, count_map = do_round(
            env=env, global_agent=global_agent, global_icm=global_icm,
            local_agent=local_agent, memory=memory, optimizer=optimizer,
            icm=icm, local_icm=local_icm, icm_optimizer=icm_optimizer,
            limit_round_steps=do_steps)
        with global_is_fun.get_lock():
            if global_is_fun.value == 0:
                break
            # Score the round by its accumulated intrinsic reward (plus a
            # small noise term from utils); once curiosity drops below the
            # threshold, signal every worker to stop.
            evaluate_value = i_reward * 500 + util.noize(0.5)
            if evaluate_value <= stop_threshold:
                global_is_fun.value = 0
        with global_round.get_lock():
            global_round.value += 1
        with global_total_steps.get_lock():
            global_total_steps.value += steps
        with global_clear.get_lock():
            global_clear.value += 1 if done else 0
        # Publish this round's visit counts to the shared flat array.
        count_map_flat = count_map.ravel()
        for i in range(env.maze_size ** 2):
            global_count_map[i] = count_map_flat[i]
def do_round(env, global_agent, global_icm, local_agent, memory, optimizer,
             icm, local_icm, icm_optimizer, limit_round_steps):
    """Play one round (episode) and update the global networks every T_MAX
    steps, at episode end, and when the step limit is hit.  Returns the done
    flag, the number of steps taken, the accumulated intrinsic reward, and a
    per-cell visit-count map."""
    T_MAX = 20                      # update interval (steps per rollout segment)
    obs = env.reset()
    count_map = np.zeros((env.maze_size, env.maze_size))
    done, ep_steps, i_reward = False, 0, 0.0
    hx = T.zeros(1, 256)            # recurrent hidden state of the actor-critic
while not done and ep_steps < limit_round_steps:
state = T.tensor([obs], dtype=T.float)
action, value, log_prob, hx = local_agent(state, hx)
old_posx, old_posy = env.posx, env.posy
obs_, reward, done, info = env.step(action)
        # Count a visit only when the agent actually moved to a new cell.
        if (env.posx, env.posy) != (old_posx, old_posy):
            count_map[env.posy, env.posx] += 1
memory.remember(obs, action, obs_, reward, value, log_prob)
obs = obs_
ep_steps += 1
        if ep_steps % T_MAX == 0 or done or ep_steps >= limit_round_steps:
            states, actions, new_states, rewards, values, log_probs = \
                memory.sample_memory()
            if icm:
                intrinsic_reward, L_I, L_F = \
                    local_icm.calc_loss(states, new_states, actions)
                i_reward += intrinsic_reward.sum().item()
            else:
                intrinsic_reward = None   # no curiosity bonus: extrinsic rewards only
            loss = local_agent.calc_cost(obs, hx, done, rewards,
                                         values, log_probs,
                                         intrinsic_reward)
            optimizer.zero_grad()
            hx = hx.detach_()             # truncate backprop through the recurrent state
            if icm:
                icm_optimizer.zero_grad()
                (L_I + L_F).backward()    # inverse- and forward-model losses of the ICM
            loss.backward()
            T.nn.utils.clip_grad_norm_(local_agent.parameters(), 40)
            # Copy the local gradients onto the shared global network, step
            # the shared optimizer, then pull the updated weights back.
            for local_param, global_param in zip(
                    local_agent.parameters(),
                    global_agent.parameters()):
                global_param._grad = local_param.grad
            optimizer.step()
            local_agent.load_state_dict(global_agent.state_dict())
            if icm:
                # Same gradient hand-off for the curiosity module.
                for local_param, global_param in zip(
                        local_icm.parameters(),
                        global_icm.parameters()):
                    global_param._grad = local_param.grad
                icm_optimizer.step()
                local_icm.load_state_dict(global_icm.state_dict())
            memory.clear_memory()
return done, ep_steps, i_reward, count_map
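
# ---------------------------------------------------------------------------
# Minimal launch sketch (not part of the original training script): how the
# worker processes could be wired up with torch.multiprocessing.  The maze
# layout, the hyperparameter values, and the use of plain Adam are assumptions
# made for illustration; the repository would normally supply its own maze and
# a shared optimizer instead.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.multiprocessing as mp

    maze_size = 5                               # assumed maze dimensions
    input_shape = [4, maze_size, maze_size]     # assumed observation shape for ActorCritic
    n_actions = 4
    maze = np.zeros((maze_size, maze_size))     # placeholder maze layout
    n_workers = 4

    global_agent = ActorCritic(input_shape, n_actions)
    global_agent.share_memory()
    global_icm = ICM(input_shape, n_actions, alpha=0.1)
    global_icm.share_memory()

    # Assumption: plain Adam over the shared parameters as a stand-in for a
    # shared optimizer.
    optimizer = T.optim.Adam(global_agent.parameters(), lr=1e-4)
    icm_optimizer = T.optim.Adam(global_icm.parameters(), lr=1e-4)

    # Shared bookkeeping used by worker().
    global_count_map = mp.Array('d', maze_size ** 2)
    global_round = mp.Value('i', 0)
    global_total_steps = mp.Value('i', 0)
    global_is_fun = mp.Value('i', 1)            # 1 while training should continue
    global_clear = mp.Value('i', 0)

    processes = [mp.Process(target=worker,
                            args=(str(i), input_shape, n_actions, global_agent,
                                  optimizer, 'maze', i, maze, global_icm,
                                  icm_optimizer, global_count_map, global_round,
                                  global_total_steps, global_is_fun,
                                  300,          # limit_round_steps (assumed)
                                  100_000,      # limit_total_steps (assumed)
                                  global_clear,
                                  0.1,          # alpha (assumed)
                                  0.0,          # stop_threshold (assumed)
                                  True))        # icm enabled
                 for i in range(n_workers)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()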