import torch.multiprocessing as mp
import numpy as np
import utils as util
import math
import maze as m
from actor_critic import ActorCritic
from icm import ICM
from shared_adam import SharedAdam
from worker import worker
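
# ParallelEnv launches one worker process per thread, all sharing a global
# actor-critic (and, optionally, an ICM) through shared memory, then
# aggregates visit counts and goal statistics once every worker has exited.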

class ParallelEnv:
    def __init__(self, env_name, env_size, env_num,
                 input_shape, n_actions, num_threads,
                 limit_round_steps, limit_total_steps,
                 reward, alpha, stop_threshold, logger, icm=False):
        names = [str(i) for i in range(num_threads)]

        # The global actor-critic lives in shared memory so every worker
        # process reads and updates the same parameters; SharedAdam keeps the
        # optimizer state shared across processes as well.
        global_actor_critic = ActorCritic(input_shape, n_actions)
        global_actor_critic.share_memory()
        global_optim = SharedAdam(global_actor_critic.parameters(), lr=1e-4)

        # Optional Intrinsic Curiosity Module (ICM), shared the same way.
        if icm:
            global_icm = ICM(input_shape, n_actions)
            global_icm.share_memory()
            global_icm_optim = SharedAdam(global_icm.parameters(), lr=1e-4)
        else:
            global_icm = None
            global_icm_optim = None
        
        # Cross-process shared state.
        global_count_map = mp.Array('d', env_size**2)  # flattened env_size x env_size visit-count map
        global_round = mp.Value('i', 0)        # episodes finished across all workers
        global_total_steps = mp.Value('i', 0)  # environment steps across all workers
        global_is_fun = mp.Value('i', 1)       # shared flag polled by workers (see worker.py)
        global_clear = mp.Value('i', 0)        # episodes that reached the goal
        maze = self._make_maze(env_name, env_num)

        self.ps = [mp.Process(target=worker,
                              args=(name, input_shape, n_actions,
                                    global_actor_critic, global_optim,
                                    env_name, env_num, maze,
                                    global_icm, global_icm_optim,
                                    global_count_map, global_round, global_total_steps,
                                    global_is_fun, limit_round_steps, limit_total_steps,
                                    global_clear, alpha, stop_threshold, icm))
                   for name in names]

        # Start every worker, then block until they all finish; the trial's
        # results are aggregated below once join() returns.
        for p in self.ps:
            p.start()
        for p in self.ps:
            p.join()
        goal_rate = 0.0 if global_round.value == 0 else global_clear.value / float(global_round.value)
        count_map = np.array(global_count_map)
        entropy = util.calc_entropy(count_map.reshape(env_size, env_size), maze.conner_points)
        # Normalize by log(N) so a uniform distribution over the N conner_points scores 1.
        k = 1.0 / math.log(len(maze.conner_points))
        logger.record_trial(map_no=env_num, reward=reward, alpha=alpha,
                            round=global_round.value, step=global_total_steps.value,
                            goal_num=global_clear.value, goal_rate=goal_rate,
                            entropy=entropy, entropy_n=k*entropy)

    def _make_maze(self, env_name, env_num):
        # Load the env_num-th CSV map registered under env_name and build a Maze.
        maze_path = util.get_file_path(env_name)[env_num]
        maze_map = util.load_map_from_csv(maze_path)
        maze = m.Maze()
        maze.initialize_maze(maze_map)
        return maze
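

# Minimal usage sketch. Everything below is illustrative: the map name, grid
# size, input shape, step limits, and _StubLogger are assumed values, not
# taken from the project; any object exposing record_trial(**kwargs) works.
if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)  # safer default when mixing PyTorch and multiprocessing

    class _StubLogger:
        # Hypothetical stand-in for the project's logger.
        def record_trial(self, **kwargs):
            print(kwargs)

    ParallelEnv(env_name='maze', env_size=10, env_num=0,
                input_shape=(1, 10, 10), n_actions=4, num_threads=4,
                limit_round_steps=500, limit_total_steps=100000,
                reward=1.0, alpha=0.1, stop_threshold=0.9,
                logger=_StubLogger(), icm=True)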