import os
import torch.multiprocessing as mp
from parallel_env import ParallelEnv
from logger import Logger
# Cap OpenMP at one thread per process; presumably so that the worker
# processes started by learn() do not oversubscribe the CPU cores.
# NOTE(review): this runs after torch has been imported above, so it may only
# take effect in the child processes that inherit the environment -- confirm.
os.environ['OMP_NUM_THREADS'] = '1'
def learn(n_threads=16, run_num=100, rewards=(2, 6, 10, 14, 18), min_alpha=0.0,
          reward_to_alpha=0.05, env_sizes=(7, 9, 11), env_num_max=10):
    """Run a sweep of ParallelEnv trials and save one CSV log per environment size.

    For every environment size, every (reward, alpha) pair and every
    environment number, ``ParallelEnv`` is instantiated ``run_num`` times with
    a shared ``Logger``. The logger is reset before each environment size and
    saved to ``"<size>_<size>.csv"`` once that size is finished.

    Args:
        n_threads: Forwarded to ``ParallelEnv`` as ``num_threads``.
        run_num: Number of ``ParallelEnv`` instantiations per
            (env_size, reward, env_num) combination.
        rewards: Reward values to sweep; each one is forwarded as ``reward``.
            Any iterable is accepted (it is consumed once).
        min_alpha: Offset of the alpha paired with each reward.
        reward_to_alpha: Slope of the alpha paired with each reward, i.e.
            ``alpha = min_alpha + reward * reward_to_alpha``.
        env_sizes: Environment sizes to sweep; each one also determines the
            environment name and the log file name.
        env_num_max: Environment numbers ``0 .. env_num_max - 1`` are swept.

    Note:
        The process-wide multiprocessing start method is switched to
        ``'spawn'`` if it is not already ``'spawn'``.
    """
    # set_start_method() raises RuntimeError once a start method has been
    # fixed, so only call it when 'spawn' is not already in effect. This keeps
    # learn() callable more than once within the same process.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)

    n_actions = 4
    input_shape = [3, 42, 42]

    # Pair every reward with its alpha in a single pass so that `rewards` may
    # be any iterable, including a one-shot generator.
    reward_alpha_pairs = [(r, min_alpha + r * reward_to_alpha) for r in rewards]

    logger = Logger()
    for env_size in env_sizes:
        print('env_size={0}'.format(env_size))
        # Start a fresh log for this environment size.
        logger.reset()
        env_name = "{0}_{0}".format(env_size)

        for r, a in reward_alpha_pairs:
            for env_num in range(env_num_max):
                print('alpha={0}, env_num={1}'.format(a, env_num))
                for _ in range(run_num):
                    # The instance is not kept; presumably the constructor
                    # runs the trial and reports to `logger` -- TODO confirm.
                    ParallelEnv(
                        env_name=env_name,
                        env_size=env_size,
                        env_num=env_num,
                        num_threads=n_threads,
                        n_actions=n_actions,
                        input_shape=input_shape,
                        limit_round_steps=156,
                        limit_total_steps=3130,
                        reward=r,
                        alpha=a,
                        stop_threshold=5.0,
                        logger=logger,
                        icm=True,
                    )

        # One log file per environment size.
        logger.save_recoding_trial_log(env_name + '.csv')
# Script entry point. The guard matters here: learn() selects the 'spawn'
# start method, under which child processes import this module again, and
# the sweep must not be restarted on that import.
if __name__ == '__main__':
    learn()