# -*- coding: utf-8 -*-
"""
Created on Mon Oct 14 14:19:19 2019
@author: ocalvin
"""
import argparse
import time
import os
import numpy as np
from agent import DPX_Agent
from world import DPX
from DPXAnalysis2021 import dpx_raster_plot
# -------------- Parameters ----------------
# Create the mappings between the agent and world
stDist = 0.20
cueMap = {'A': np.pi * (0.5 - stDist),
'B': np.pi * (0.5 + stDist),
'X': np.pi * (1.5 - stDist),
'Y': np.pi * (1.5 + stDist)
}
actMap = {0: 'O', 1: 'L', 2: 'R'}
# Output Control
shwTrials = True # Assumes that you only want to run a single experiment
collectData = False
# ------------- Parse Program Arguments --------------------
parser = argparse.ArgumentParser(
prog='DPX_lab.py',
description='Runs a dual-ring agent on the DPX task.',
usage='%(prog)s outfolder [options]'
)
name = 'outfolder'
if not collectData: name = '--' + name
parser.add_argument(
name,
type=str,
help='Folder that recorded data will be stored in.'
)
dpxGroup = parser.add_argument_group('DPX Task')
DPX.fill_parser(dpxGroup)
# Add the parser arguments for the agent
agentGroup = parser.add_argument_group('Dual-Ring SoftMax Agent')
DPX_Agent.fill_parser(agentGroup)
args = vars(parser.parse_args())
# ------------------ Incorporate arguments ------------------------
# output location for this trial
outputFolder = args['outfolder']
# If running a local test use these parameters
if shwTrials:
args.update({'aProp': 1.00})
args.update({'axProp': 1.00})
args.update({'bxProp': 1.00})
#popped.setdefault('mpNMDAg', 1.00)
# Set the World defaults for this experiment
popped = DPX.pop_kwargs(args)
# Create the world
dpxTask = DPX(cueMap, **popped)
# Set the Agent defaults for this experiment
popped = DPX_Agent.pop_kwargs(args)
# Create the agent
rAgent = DPX_Agent(act_map=actMap, **popped)
# Set the weights for the agent's action kinetics
rAgent.set_act_weight(cueMap.get('A') - (np.pi * stDist),
cueMap.get('A') + (np.pi * stDist),
0.075, 'mem', 0
)
rAgent.set_act_weight(cueMap.get('B') - (np.pi * stDist),
cueMap.get('B') + (np.pi * stDist),
0.375, 'mem', 1
)
rAgent.set_act_weight(cueMap.get('X') - (np.pi * stDist),
cueMap.get('X') + (np.pi * stDist),
0.05, 'perc', 0
)
rAgent.set_act_weight(cueMap.get('Y') - (np.pi * stDist),
cueMap.get('Y') + (np.pi * stDist),
0.25, 'perc', 1
)
# ------------------ Print Details -------------------------
dpxTask.description()
rAgent.description()
# ------------------ Start the Task ------------------------
# Creates the directory if it doesn't exist yet
if collectData and not os.path.exists(outputFolder):
os.makedirs(outputFolder)
# Set the dpxParameters
tskDur = (dpxTask.preCueDur
+ dpxTask.cueDur
+ dpxTask.ISI
+ dpxTask.probeDur
+ dpxTask.ITI
)
# Run the trials
for t in range(dpxTask.numTrials):
startTime = time.time()
# Reset the state of the agent before starting the trial
rAgent.state_reset()
dpxTask.next_trial(rAgent)
# If the data is being collected
if collectData:
rec_pp, rec_pi, rec_mp, rec_mi = rAgent.pull_data()
# collects the spiking information
rec_pp.save(outputFolder + 'percPyr.npz', t+1)
rec_pi.save(outputFolder + 'percInt.npz', t+1)
rec_mp.save(outputFolder + 'memPyr.npz', t+1)
rec_mi.save(outputFolder + 'memInt.npz', t+1)
# save the behavioral data
dpxTask.output_data(outputFolder + 'dpx.csv', t+1)
if shwTrials:
print("Trial duration: ", round(time.time() - startTime), " seconds")
rec_pp, rec_pi, rec_mp, rec_mi = rAgent.pull_data()
# Plot the raster plots for the perception and memory
dpx_raster_plot(rec_pp, dpxTask.current_cp())
dpx_raster_plot(rec_mp, dpxTask.current_cp())
break
# Let the user know how long the trial took
print("Trial ", t+1, " duration: ", round(time.time() - startTime),
" seconds")