import torch.multiprocessing as mp
import numpy as np
import utils as util
import math
import maze as m
from actor_critic import ActorCritic
from icm import ICM
from shared_adam import SharedAdam
from worker import worker
class ParallelEnv:
    """Run one training trial across several worker processes and log it.

    Constructing an instance does all the work: it builds the shared
    (global) model, optimizer and counters, spawns one ``worker`` process
    per entry in ``range(num_threads)``, blocks until every process has
    exited, and finally reports summary statistics via
    ``logger.record_trial``.

    Attributes:
        ps: The list of ``mp.Process`` objects that were started and joined.
    """

    def __init__(self, env_name, env_size, env_num,
                 input_shape, n_actions, num_threads, limit_round_steps,
                 limit_total_steps, reward, alpha, stop_threshold, logger,
                 icm=False):
        """Build shared state, run all workers to completion, record results.

        Args:
            env_name: Passed to ``util.get_file_path`` to locate maze files,
                and forwarded to every worker.
            env_size: Side length used to size the shared count map
                (``env_size ** 2`` doubles) and to reshape it afterwards.
            env_num: Index of the maze file to load; also logged as
                ``map_no``.
            input_shape: Forwarded to ``ActorCritic`` / ``ICM`` and workers.
            n_actions: Forwarded to ``ActorCritic`` / ``ICM`` and workers.
            num_threads: Number of worker processes to launch.
            limit_round_steps: Forwarded to every worker.
            limit_total_steps: Forwarded to every worker.
            reward: Only used as a field of the logged trial record.
            alpha: Forwarded to every worker and logged.
            stop_threshold: Forwarded to every worker.
            logger: Object exposing ``record_trial(**fields)``.
            icm: When true, a shared ``ICM`` module and its optimizer are
                created; otherwise ``None`` is passed in their place.
        """
        worker_names = [str(i) for i in range(num_threads)]

        # Global network whose parameters live in shared memory so that every
        # worker process operates on the same tensors.
        global_actor_critic = ActorCritic(input_shape, n_actions)
        global_actor_critic.share_memory()
        global_optim = SharedAdam(global_actor_critic.parameters(), lr=1e-4)

        if icm:
            global_icm = ICM(input_shape, n_actions)
            global_icm.share_memory()
            global_icm_optim = SharedAdam(global_icm.parameters(), lr=1e-4)
        else:
            global_icm = None
            global_icm_optim = None

        # Cross-process counters / flags. Their exact meaning is defined by
        # ``worker``; presumably the count map holds per-cell visit counts
        # -- TODO confirm against worker.py.
        global_count_map = mp.Array('d', env_size**2)
        global_round = mp.Value('i', 0)
        global_total_steps = mp.Value('i', 0)
        global_is_fun = mp.Value('i', 1)
        global_clear = mp.Value('i', 0)

        maze = self._make_maze(env_name, env_num)

        # NOTE: the positional order below is the contract with ``worker``.
        self.ps = [
            mp.Process(
                target=worker,
                args=(name, input_shape, n_actions,
                      global_actor_critic, global_optim, env_name, env_num,
                      maze, global_icm, global_icm_optim, global_count_map,
                      global_round, global_total_steps, global_is_fun,
                      limit_round_steps, limit_total_steps, global_clear,
                      alpha, stop_threshold, icm),
            )
            for name in worker_names
        ]

        # Plain loops: these calls are made purely for their side effects.
        for proc in self.ps:
            proc.start()
        for proc in self.ps:
            proc.join()

        rounds = global_round.value
        cleared = global_clear.value
        # Avoid dividing by zero when no round was ever completed.
        goal_rate = 0.0 if rounds == 0 else cleared / float(rounds)

        count_map = np.array(global_count_map)
        corner_points = maze.conner_points
        entropy = util.calc_entropy(
            count_map.reshape(env_size, env_size), corner_points)
        entropy_n = self._normalized_entropy(entropy, len(corner_points))

        logger.record_trial(map_no=env_num, reward=reward, alpha=alpha,
                            round=rounds, step=global_total_steps.value,
                            goal_num=cleared, goal_rate=goal_rate,
                            entropy=entropy, entropy_n=entropy_n)

    @staticmethod
    def _normalized_entropy(entropy, n_points):
        """Scale ``entropy`` by ``1 / log(n_points)``.

        With fewer than two points ``log(n_points)`` is zero or undefined,
        so the scale factor cannot be computed; 0.0 is returned instead of
        raising, so that a finished trial can still be logged.

        Args:
            entropy: Raw entropy value to scale.
            n_points: Number of points the entropy was computed over.

        Returns:
            ``entropy / log(n_points)``, or ``0.0`` when ``n_points <= 1``.
        """
        if n_points <= 1:
            return 0.0
        scale = 1.0 / math.log(n_points)
        return scale * entropy

    def _make_maze(self, env_name, env_num):
        """Load maze number ``env_num`` of ``env_name`` into a new ``Maze``.

        Args:
            env_name: Key passed to ``util.get_file_path``.
            env_num: Index into the paths returned by ``util.get_file_path``.

        Returns:
            A ``m.Maze`` initialized from the CSV map at the selected path.
        """
        maze_path = util.get_file_path(env_name)[env_num]
        maze_map = util.load_map_from_csv(maze_path)
        maze = m.Maze()
        maze.initialize_maze(maze_map)
        return maze