per step reward
Per-game reward is better for a pre-trained model
FilipinoGambino committed Feb 1, 2024
1 parent e12b4c8 commit bd095fe
Showing 3 changed files with 14 additions and 14 deletions.
connectx/connectx_gym/connectx_env.py (4 changes: 3 additions & 1 deletion)
@@ -1,6 +1,7 @@
 from kaggle_environments import make
 from typing import Dict, List, Optional, Tuple
 import gym
+import math
 import numpy as np
 from scipy.special import softmax

@@ -46,7 +47,8 @@ def reset(self, **kwargs):
         return obs, self.reward, done, self.info

     def step(self, action):
-        obs, self.reward, done, _ = self.trainer.step(action)
+        obs, _, done, _ = self.trainer.step(action)
+        self.reward = self.turn / math.prod(BOARD_SIZE)
         self._update(obs, self.reward, action)
         return obs, self.reward, done, self.info

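For intuition, a minimal standalone sketch of the per-step reward introduced above, assuming BOARD_SIZE is the standard 6x7 Connect Four grid; the constant and the helper function are illustrative stand-ins, not taken from the repository:

import math

# Assumed Connect Four dimensions; the repository defines its own BOARD_SIZE elsewhere.
BOARD_SIZE = (6, 7)

def per_step_reward(turn: int) -> float:
    # Mirrors the new line in step(): the reward grows linearly with the turn
    # counter and reaches 1.0 only if all 42 cells are filled.
    return turn / math.prod(BOARD_SIZE)

print(per_step_reward(1))   # ~0.024 after the first move
print(per_step_reward(21))  # 0.5 halfway through a full board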
connectx/connectx_gym/reward_spaces.py (17 changes: 8 additions & 9 deletions)
@@ -42,18 +42,18 @@ def compute_rewards(self, game_state: Environment) -> Tuple[Tuple[float, float], bool]:
         pass

     @abstractmethod
-    def _compute_rewards(self, game_state: dict, done: bool) -> Tuple[float, float]:
+    def _compute_rewards(self, game_state: dict) -> Tuple[float, float]:
         pass


 class GameResultReward(FullGameRewardSpace):
     @staticmethod
     def get_reward_spec() -> RewardSpec:
         return RewardSpec(
-            reward_min=-1.,
-            reward_max=1.,
+            reward_min=-10.,
+            reward_max=10.,
             zero_sum=True,
-            only_once=True
+            only_once=False
         )

     def __init__(self, early_stop: bool = False, **kwargs):
@@ -63,12 +63,11 @@ def __init__(self, early_stop: bool = False, **kwargs):
     def compute_rewards(self, game_state: Environment) -> Tuple[float, bool]:
         if self.early_stop:
             raise NotImplementedError # done = done or should_early_stop(game_state)
-        done = game_state.done
-        return self._compute_rewards(game_state, done), done
+        return self._compute_rewards(game_state), game_state.done

-    def _compute_rewards(self, game_state: Environment, done: bool) -> float:
-        if not done:
-            return 0.
+    def _compute_rewards(self, game_state: Environment) -> float:
+        if game_state.done:
+            return game_state.reward * 10.
         return game_state.reward

     # def compute_rewards(self, game_state: Environment) -> Tuple[Tuple[float, float], bool]:
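To see what the rewritten GameResultReward now returns, a rough sketch with a stand-in game_state in place of the real Environment object; the function and the SimpleNamespace fixtures below are simplifications for illustration, not the repository's implementation:

from types import SimpleNamespace

def game_result_reward(game_state) -> float:
    # Per the diff: intermediate steps pass the environment's raw reward through
    # unchanged, and the terminal step is scaled by 10 so the final win/loss
    # signal dominates, matching reward_min=-10. / reward_max=10. above.
    if game_state.done:
        return game_state.reward * 10.
    return game_state.reward

mid_game = SimpleNamespace(done=False, reward=0.25)    # hypothetical intermediate reward
final_step = SimpleNamespace(done=True, reward=1.0)    # hypothetical terminal reward
print(game_result_reward(mid_game))    # 0.25
print(game_result_reward(final_step))  # 10.0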
connectx/connectx_gym/wrappers.py (7 changes: 3 additions & 4 deletions)
@@ -18,14 +18,13 @@ def __init__(self, env: gym.Env, reward_space: BaseRewardSpace):
         super(LoggingEnv, self).__init__(env)
         self.reward_space = reward_space
         self.vals_peak = {}
-        self.reward_sum = [0., 0.]
+        self.reward_sum = []

     def info(self, info: Dict[str, np.ndarray], rewards: int) -> Dict[str, np.ndarray]:
         info = copy.copy(info)
         player = self.env.unwrapped.player_id
         logs = dict(step=self.env.unwrapped.turn)

-        self.reward_sum[player] = rewards + self.reward_sum[player]
+        self.reward_sum.append(rewards)
         logs["mean_cumulative_rewards"] = [np.mean(self.reward_sum)]
         logs["mean_cumulative_reward_magnitudes"] = [np.mean(np.abs(self.reward_sum))]
         logs["max_cumulative_rewards"] = [np.max(self.reward_sum)]
@@ -38,7 +37,7 @@ def info(self, info: Dict[str, np.ndarray], rewards: int) -> Dict[str, np.ndarray]:

     def reset(self, **kwargs):
         obs, reward, done, info = super(LoggingEnv, self).reset(**kwargs)
-        self.reward_sum = [0., 0.]
+        self.reward_sum = [reward]
         return obs, [reward], done, self.info(info, reward)

     def step(self, action: Dict[str, np.ndarray]):
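As a quick check on the logging change, a small sketch of the list-based accumulation with plain numbers in place of the wrapper's state; only numpy is assumed, and the values are made up for illustration:

import numpy as np

# After the change, LoggingEnv keeps a flat list of every reward it has seen
# (seeded with the reset reward) instead of one running total per player.
reward_sum = [0.0]                      # what reset() leaves behind for a reset reward of 0.0
for step_reward in (0.1, 0.25, -0.5):   # made-up per-step rewards
    reward_sum.append(step_reward)      # mirrors self.reward_sum.append(rewards)

print(np.mean(reward_sum))           # logged as mean_cumulative_rewards
print(np.mean(np.abs(reward_sum)))   # logged as mean_cumulative_reward_magnitudes
print(np.max(reward_sum))            # logged as max_cumulative_rewards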
