Commit

LongGameReward
FilipinoGambino committed Feb 2, 2024
1 parent bd095fe · commit 155e708
Showing 3 changed files with 67 additions and 52 deletions.
2 changes: 1 addition & 1 deletion config.yaml
@@ -23,7 +23,7 @@ reward_space_kwargs: {}
 debug: true
 act_space: BasicActionSpace
 obs_space: BasicObsSpace
-reward_space: GameResultReward
+reward_space: LongGameReward
 optimizer_class: Adam
 optimizer_kwargs:
   lr: 0.0001
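With this change, training configured from config.yaml picks up the new reward space by class name. As a rough sketch only (the loader that consumes this setting is not part of this diff, and the import path assumes the package layout implied by the file paths), the string would typically be resolved to the class in reward_spaces.py along these lines:

    # Hypothetical resolution of the `reward_space` config value; illustrative names only.
    from connectx.connectx_gym import reward_spaces

    def build_reward_space(name: str, **kwargs):
        # Look the class up by name on the reward_spaces module and instantiate it.
        cls = getattr(reward_spaces, name)
        return cls(**kwargs)

    reward_space = build_reward_space("LongGameReward")  # matches the new config value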
64 changes: 33 additions & 31 deletions connectx/connectx_gym/connectx_env.py
@@ -32,49 +32,47 @@ def __init__(
         self.rows = self.env.configuration.rows
         self.columns = self.env.configuration.columns

-        self.reward = 0.
         self.action_space = act_space
         self.obs_space = obs_space
         self.default_reward_space = GameResultReward()
         self.info = dict()

     def reset(self, **kwargs):
         obs = self.trainer.reset()
-        self.reward = 0.
+        reward = 0.
         done = False
-        self._update(obs, self.reward)
+        self._update(obs, reward)

-        return obs, self.reward, done, self.info
+        return obs, reward, done, self.info

     def step(self, action):
-        obs, _, done, _ = self.trainer.step(action)
-        self.reward = self.turn / math.prod(BOARD_SIZE)
-        self._update(obs, self.reward, action)
-        return obs, self.reward, done, self.info
+        obs, reward, done, _ = self.trainer.step(action)
+        self._update(obs, reward, action)
+        return obs, reward, done, self.info

-    def process_actions(self, logits: np.ndarray) -> Tuple[List[List[str]], Dict[str, np.ndarray]]:
-        step = self.env.state[0]['observation']['step']
-        board = self.env.state[0]['observation']['board']
-        obs = np.array(board).reshape(BOARD_SIZE)
-        print(f"\naction logits:\n{logits}")
-        valid_action_logits = self.action_space.process_actions(
-            logits,
-            obs,
-        )
-        print(f"\nvalid actions:\n{valid_action_logits}")
-        valid_action_probs = softmax(valid_action_logits)
-        action = np.random.choice(BOARD_SIZE[1], p=valid_action_probs)
-
-        self.info.update(
-            dict(
-                logits=logits,
-                masked_logits=valid_action_logits,
-                masked_probs=valid_action_probs,
-                action=action,
-                step=step,
-            )
-        )
-        return action
+    # def process_actions(self, logits: np.ndarray) -> Tuple[List[List[str]], Dict[str, np.ndarray]]:
+    #     step = self.env.state[0]['observation']['step']
+    #     board = self.env.state[0]['observation']['board']
+    #     obs = np.array(board).reshape(BOARD_SIZE)
+    #     print(f"\naction logits:\n{logits}")
+    #     valid_action_logits = self.action_space.process_actions(
+    #         logits,
+    #         obs,
+    #     )
+    #     print(f"\nvalid actions:\n{valid_action_logits}")
+    #     valid_action_probs = softmax(valid_action_logits)
+    #     action = np.random.choice(BOARD_SIZE[1], p=valid_action_probs)
+    #
+    #     self.info.update(
+    #         dict(
+    #             logits=logits,
+    #             masked_logits=valid_action_logits,
+    #             masked_probs=valid_action_probs,
+    #             action=action,
+    #             step=step,
+    #         )
+    #     )
+    #     return action

     def _update(self, obs, reward, action=-1):
         obs_array = np.array(obs['board']).reshape((1,*BOARD_SIZE))
@@ -91,3 +89,7 @@ def render(self, **kwargs):
     @property
     def turn(self):
         return self.env.state[0]['observation']['step']
+
+    @property
+    def done(self):
+        return self.env.done
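The net effect in connectx_env.py is that reset() and step() now return the trainer's reward untouched (the length-based shaping moved into the reward space), and the new done property exposes the underlying kaggle environment's terminal flag. A minimal sketch of how the updated contract might be consumed; play_episode and policy are illustrative names, not code from this commit:

    def play_episode(env, policy):
        # `env` is assumed to be a ConnectFour instance; `policy` any callable obs -> column.
        obs, reward, done, info = env.reset()   # reset() now reports a local reward of 0.
        total_reward = 0.0
        while not done:
            action = policy(obs)
            # step() passes the trainer's reward through unchanged; the old
            # turn / math.prod(BOARD_SIZE) shaping now lives in LongGameReward.
            obs, reward, done, info = env.step(action)
            total_reward += reward
        return total_reward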
53 changes: 33 additions & 20 deletions connectx/connectx_gym/reward_spaces.py
@@ -1,10 +1,14 @@
 from typing import NamedTuple, Tuple, Dict
 from abc import ABC, abstractmethod
 import logging
+import math

 from kaggle_environments.core import Environment
 import numpy as np

+from connectx_env import ConnectFour
+from ..utility_constants import BOARD_SIZE
+
 class RewardSpec(NamedTuple):
     reward_min: float
     reward_max: float
@@ -50,34 +54,43 @@ class GameResultReward(FullGameRewardSpace):
     @staticmethod
     def get_reward_spec() -> RewardSpec:
         return RewardSpec(
-            reward_min=-10.,
-            reward_max=10.,
+            reward_min=-1.,
+            reward_max=1.,
             zero_sum=True,
-            only_once=False
+            only_once=True
         )

     def __init__(self, early_stop: bool = False, **kwargs):
         super(GameResultReward, self).__init__(**kwargs)
         self.early_stop = early_stop

-    def compute_rewards(self, game_state: Environment) -> Tuple[float, bool]:
+    def compute_rewards(self, game_state: ConnectFour) -> Tuple[float, bool]:
         if self.early_stop:
             raise NotImplementedError # done = done or should_early_stop(game_state)
         return self._compute_rewards(game_state), game_state.done

+    def _compute_rewards(self, game_state: ConnectFour) -> float:
+        if not game_state.done:
+            return 0.
+        return game_state.info['reward']
+
+class LongGameReward(BaseRewardSpace):
+    @staticmethod
+    def get_reward_spec() -> RewardSpec:
+        return RewardSpec(
+            reward_min=-1.,
+            reward_max=1.,
+            zero_sum=False,
+            only_once=False
+        )
+    def __init__(self, early_stop: bool = False, **kwargs):
+        super(LongGameReward, self).__init__(**kwargs)
+        self.early_stop = early_stop
+
+    def compute_rewards(self, game_state: ConnectFour) -> Tuple[float, bool]:
+        if self.early_stop:
+            raise NotImplementedError # done = done or should_early_stop(game_state)
+        return self._compute_rewards(game_state), game_state.done
+
-    def _compute_rewards(self, game_state: Environment) -> float:
-        if game_state.done:
-            return game_state.reward * 10.
-        return game_state.reward
-
-    # def compute_rewards(self, game_state: Environment) -> Tuple[Tuple[float, float], bool]:
-    #     if self.early_stop:
-    #         raise NotImplementedError # done = done or should_early_stop(game_state)
-    #     done = game_state.done
-    #     return self._compute_rewards(game_state, done), done
-    #
-    # def _compute_rewards(self, game_state: Environment, done: bool) -> Tuple[float, float]:
-    #     if not done:
-    #         return 0., 0.
-    #     rewards = (game_state.state[0]['reward'], game_state.state[1]['reward'])
-    #     return rewards
+    def _compute_rewards(self, game_state: ConnectFour) -> float:
+        return game_state.turn / math.prod(BOARD_SIZE)
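For intuition about the new reward space: LongGameReward pays a small, growing amount every step, so longer games accumulate more reward, and the value itself never drops below zero (reward_min=-1. is just a loose bound). Assuming the standard 6x7 Connect Four board for BOARD_SIZE (an assumption, since utility_constants is not shown in this diff), math.prod(BOARD_SIZE) is 42 and the per-step reward is turn/42:

    # Worked example of LongGameReward's per-step value, assuming BOARD_SIZE == (6, 7).
    import math

    BOARD_SIZE = (6, 7)  # assumed value; the repo imports it via `from ..utility_constants import BOARD_SIZE`

    def long_game_reward(turn: int) -> float:
        # Mirrors LongGameReward._compute_rewards: turn / math.prod(BOARD_SIZE)
        return turn / math.prod(BOARD_SIZE)

    print(long_game_reward(1))   # ~0.024 after the first move
    print(long_game_reward(21))  # 0.5 halfway to a full board
    print(long_game_reward(42))  # 1.0 if the board fills completely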
