From e793cdbcd0cadd0f612bb462ea21206e2a150af9 Mon Sep 17 00:00:00 2001 From: Nicholas Gorichs Date: Fri, 2 Feb 2024 14:14:04 -0600 Subject: [PATCH] dropping logging actions for now --- conf/new_beginnings.yaml | 4 ++++ connectx/connectx_gym/__init__.py | 2 +- connectx/connectx_gym/connectx_env.py | 5 ++--- connectx/torchbeast/monobeast.py | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/conf/new_beginnings.yaml b/conf/new_beginnings.yaml index 697fa2f..589c42a 100644 --- a/conf/new_beginnings.yaml +++ b/conf/new_beginnings.yaml @@ -36,6 +36,10 @@ obs_space_kwargs: {} reward_space_kwargs: {} debug: True +# Environment params +adversary: random +#adversary: negamax + # Model params act_space: BasicActionSpace obs_space: BasicObsSpace diff --git a/connectx/connectx_gym/__init__.py b/connectx/connectx_gym/__init__.py index dfa2cdb..b780d30 100644 --- a/connectx/connectx_gym/__init__.py +++ b/connectx/connectx_gym/__init__.py @@ -31,7 +31,7 @@ def create_env(flags, device: torch.device, teacher_flags: Optional = None, seed act_space=flags.act_space(), obs_space=create_flexible_obs_space(flags, teacher_flags), player_id=flags.player_id, - seed=seed + adversary=flags.adversary ) reward_space = create_reward_space(flags) env = RewardSpaceWrapper(env, reward_space) diff --git a/connectx/connectx_gym/connectx_env.py b/connectx/connectx_gym/connectx_env.py index 5840a53..24ca25e 100644 --- a/connectx/connectx_gym/connectx_env.py +++ b/connectx/connectx_gym/connectx_env.py @@ -20,13 +20,12 @@ def __init__( act_space: BaseActSpace, obs_space: BaseObsSpace, player_id: int, - seed: Optional[int] = 42, + adversary: str, ): super(ConnectFour, self).__init__() self.env = make("connectx", debug=True) self.player_id = player_id - # players = ["negamax", "negamax"] - players = ["random", "random"] + players = [adversary, adversary] players[player_id] = None self.trainer = self.env.train(players) diff --git a/connectx/torchbeast/monobeast.py b/connectx/torchbeast/monobeast.py index f446141..5ca68f0 100644 --- a/connectx/torchbeast/monobeast.py +++ b/connectx/torchbeast/monobeast.py @@ -492,7 +492,7 @@ def learn( for key, val in batch["info"].items() if key.startswith("LOGGING_") and "ACTIONS_" not in key }, - "Actions": batch['actions'], + # "Actions": batch['actions'], "Loss": { "vtrace_pg_loss": vtrace_pg_loss.detach().item(), "upgo_pg_loss": upgo_pg_loss.detach().item(),