import os
import torch.multiprocessing as mp
from parallel_env import ParallelEnv
from logger import Logger

# Cap OpenMP at one thread per process. learn() below selects the 'spawn'
# start method; spawned child processes start fresh interpreters that inherit
# os.environ, so any children presumably pick this value up when they import
# torch (assumes ParallelEnv starts worker processes — TODO confirm).
# NOTE(review): this line runs after torch is imported above, so it may not
# affect the parent process's already-initialized thread pool; move it above
# the torch import if the parent's thread count matters.
os.environ['OMP_NUM_THREADS'] = '1'

def learn(n_threads=16, run_num=100, rewards=(2, 6, 10, 14, 18), min_alpha=0.0, reward_to_alpha=0.05,
          env_sizes=(7, 9, 11), env_num_max=10, *, limit_round_steps=156, limit_total_steps=3130,
          stop_threshold=5.0):
    """Sweep ParallelEnv runs over environment sizes, rewards and environment instances.

    For every ``env_size`` in ``env_sizes`` the logger is reset; then for each
    reward ``r`` (paired with ``alpha = min_alpha + r * reward_to_alpha``) and
    each ``env_num`` in ``range(env_num_max)``, a ``ParallelEnv`` is constructed
    ``run_num`` times. (Presumably the ParallelEnv constructor performs the
    actual run — TODO confirm.) After all runs for one size, the logger's trial
    log is written to ``"<size>_<size>.csv"``.

    Args:
        n_threads: forwarded to ParallelEnv as ``num_threads``.
        run_num: number of ParallelEnv constructions per (reward, env_num) pair.
        rewards: reward values to sweep; each is forwarded as ``reward``.
        min_alpha: offset of the reward -> alpha mapping.
        reward_to_alpha: slope of the reward -> alpha mapping.
        env_sizes: sizes to sweep; each yields ``env_name = "<size>_<size>"``.
        env_num_max: ``env_num`` takes values ``0 .. env_num_max - 1``.
        limit_round_steps: forwarded to ParallelEnv (previously hard-coded 156).
        limit_total_steps: forwarded to ParallelEnv (previously hard-coded 3130).
        stop_threshold: forwarded to ParallelEnv (previously hard-coded 5.0).
    """
    # force=True: without it, a second call to learn() in the same process
    # (or a caller that already chose a start method) raises
    # RuntimeError("context has already been set").
    mp.set_start_method('spawn', force=True)
    n_actions = 4
    input_shape = [3, 42, 42]
    # Each reward gets its own alpha via a linear mapping.
    alphas = [min_alpha + r * reward_to_alpha for r in rewards]
    logger = Logger()

    for env_size in env_sizes:
        print('env_size={0}'.format(env_size))
        # One log file per environment size, so start each size from a clean logger.
        logger.reset()
        env_name = "{0}_{0}".format(env_size)
        for r, a in zip(rewards, alphas):
            for env_num in range(env_num_max):
                print('alpha={0}, env_num={1}'.format(a, env_num))
                for _ in range(run_num):
                    ParallelEnv(env_name=env_name, env_size=env_size, env_num=env_num, num_threads=n_threads,
                                n_actions=n_actions, input_shape=input_shape,
                                limit_round_steps=limit_round_steps, limit_total_steps=limit_total_steps,
                                reward=r, alpha=a, stop_threshold=stop_threshold, logger=logger, icm=True)
        logger.save_recoding_trial_log(env_name + '.csv')

# Entry-point guard. learn() uses the 'spawn' start method, under which child
# processes re-import this module (as '__mp_main__'), so this guard keeps them
# from re-running learn() recursively.
if __name__ == '__main__':
    learn()
