Skip to content

Commit

Permalink
dropping logging actions for now
Browse files Browse the repository at this point in the history
  • Loading branch information
FilipinoGambino committed Feb 2, 2024
1 parent 3f32bb6 commit e793cdb
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
4 changes: 4 additions & 0 deletions conf/new_beginnings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ obs_space_kwargs: {}
reward_space_kwargs: {}
debug: True

# Environment params
adversary: random
#adversary: negamax

# Model params
act_space: BasicActionSpace
obs_space: BasicObsSpace
Expand Down
2 changes: 1 addition & 1 deletion connectx/connectx_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def create_env(flags, device: torch.device, teacher_flags: Optional = None, seed
act_space=flags.act_space(),
obs_space=create_flexible_obs_space(flags, teacher_flags),
player_id=flags.player_id,
seed=seed
adversary=flags.adversary
)
reward_space = create_reward_space(flags)
env = RewardSpaceWrapper(env, reward_space)
Expand Down
5 changes: 2 additions & 3 deletions connectx/connectx_gym/connectx_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ def __init__(
act_space: BaseActSpace,
obs_space: BaseObsSpace,
player_id: int,
seed: Optional[int] = 42,
adversary: str,
):
super(ConnectFour, self).__init__()
self.env = make("connectx", debug=True)
self.player_id = player_id
# players = ["negamax", "negamax"]
players = ["random", "random"]
players = [adversary, adversary]
players[player_id] = None
self.trainer = self.env.train(players)

Expand Down
2 changes: 1 addition & 1 deletion connectx/torchbeast/monobeast.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def learn(
for key, val in batch["info"].items()
if key.startswith("LOGGING_") and "ACTIONS_" not in key
},
"Actions": batch['actions'],
# "Actions": batch['actions'],
"Loss": {
"vtrace_pg_loss": vtrace_pg_loss.detach().item(),
"upgo_pg_loss": upgo_pg_loss.detach().item(),
Expand Down

0 comments on commit e793cdb

Please sign in to comment.