Skip to content

Commit

Permalink
Dropped Tensorflow wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
FilipinoGambino committed Feb 2, 2024
1 parent add8931 commit 9a9995c
Show file tree
Hide file tree
Showing 11 changed files with 550 additions and 33 deletions.
2 changes: 1 addition & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ reward_space_kwargs: {}
debug: true
act_space: BasicActionSpace
obs_space: BasicObsSpace
reward_space: LongGameReward
reward_space: GameResultReward
optimizer_class: Adam
optimizer_kwargs:
lr: 0.0001
Expand Down
2 changes: 1 addition & 1 deletion connectx/connectx_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from . import act_spaces, obs_spaces, reward_spaces
from .connectx_env import ConnectFour
from .wrappers import DictEnv, LoggingEnv, PytorchEnv, RewardSpaceWrapper, TensorflowEnv, VecEnv
from .wrappers import DictEnv, LoggingEnv, PytorchEnv, RewardSpaceWrapper, VecEnv

ACT_SPACES_DICT = {
key: val for key, val in act_spaces.__dict__.items()
Expand Down
1 change: 0 additions & 1 deletion connectx/connectx_gym/connectx_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from .act_spaces import BaseActSpace
from .obs_spaces import BaseObsSpace
from .reward_spaces import GameResultReward

from ..utility_constants import BOARD_SIZE

Expand Down
34 changes: 4 additions & 30 deletions connectx/connectx_gym/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import copy
import gym
import torch
import numpy as np
import copy
# import tensorflow as tf

from kaggle_environments.core import Environment, make
from typing import Dict, List, NoReturn, Optional, Union, Tuple
import torch
from typing import Dict, List, Union, Tuple

# from . import ConnectFour
from .act_spaces import BaseActSpace
from .obs_spaces import BaseObsSpace
from .reward_spaces import BaseRewardSpace, GameResultReward
from .reward_spaces import BaseRewardSpace


class LoggingEnv(gym.Wrapper):
Expand Down Expand Up @@ -166,26 +160,6 @@ def _to_tensor(self, x: Union[Dict, np.ndarray]) -> Dict[str, Union[Dict, torch.
elif isinstance(x, np.ndarray):
return torch.from_numpy(x).to(device=self.device)

class TensorflowEnv(gym.Wrapper):
def __init__(self, env: Union[gym.Env, VecEnv], device: torch.device = torch.device("cpu")):
super(TensorflowEnv, self).__init__(env)
self.device = device

def reset(self, **kwargs) -> Tuple[Dict, List, bool, List]:
return tuple([self._to_tensor(out) for out in super(TensorflowEnv, self).reset(**kwargs)])

def step(self, actions: List[torch.Tensor]):
action = [
act.cpu().numpy() for act in actions
]
return tuple([self._to_tensor(out) for out in super(TensorflowEnv, self).step(action)])

def _to_tensor(self, x: Union[Dict, np.ndarray]) -> Dict[str, Union[Dict, torch.Tensor]]:
if isinstance(x, dict):
return {key: self._to_tensor(val) for key, val in x.items()}
else:
return tf.convert_to_tensor(x, dtype=tf.float32)


class DictEnv(gym.Wrapper):
@staticmethod
Expand Down
41 changes: 41 additions & 0 deletions outputs/02-02/09-55-53/.hydra/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: new_beginnings
project: ConnectX
entity: filipinogambino
group: debug
use_mixed_precision: false
total_steps: 10000.0
batch_size: 8
checkpoint_freq: 60.0
num_actors: 1
n_actor_envs: 2
unroll_length: 42
player_id: 0
seed: 42
model_arch: mha_model
embedding_dim: 32
hidden_dim: 128
n_heads: 4
n_blocks: 1
device: cpu
rescale_value_input: false
obs_space_kwargs: {}
reward_space_kwargs: {}
debug: true
act_space: BasicActionSpace
obs_space: BasicObsSpace
reward_space: GameResultReward
optimizer_class: Adam
optimizer_kwargs:
lr: 0.0001
eps: 0.0003
min_lr_mod: 0.01
entropy_cost: 0.001
baseline_cost: 1.0
teacher_kl_cost: 0.0
lmb: 0.8
reduction: sum
actor_device: cpu
learner_device: cpu
disable_wandb: false
model_log_freq: 100
sharing_strategy: file_descriptor
154 changes: 154 additions & 0 deletions outputs/02-02/09-55-53/.hydra/hydra.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
hydra:
run:
dir: ./outputs/${now:%m-%d}/${now:%H-%M-%S}
sweep:
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
subdir: ${hydra.job.num}
launcher:
_target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
sweeper:
_target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
max_batch_size: null
params: null
help:
app_name: ${hydra.job.name}
header: '${hydra.help.app_name} is powered by Hydra.
'
footer: 'Powered by Hydra (https://hydra.cc)
Use --hydra-help to view Hydra specific help
'
template: '${hydra.help.header}
== Configuration groups ==
Compose your configuration from those groups (group=option)
$APP_CONFIG_GROUPS
== Config ==
Override anything in the config (foo.bar=value)
$CONFIG
${hydra.help.footer}
'
hydra_help:
template: 'Hydra (${hydra.runtime.version})
See https://hydra.cc for more info.
== Flags ==
$FLAGS_HELP
== Configuration groups ==
Compose your configuration from those groups (For example, append hydra/job_logging=disabled
to command line)
$HYDRA_CONFIG_GROUPS
Use ''--cfg hydra'' to Show the Hydra config.
'
hydra_help: ???
hydra_logging:
version: 1
formatters:
simple:
format: '[%(asctime)s][HYDRA] %(message)s'
handlers:
console:
class: logging.StreamHandler
formatter: simple
stream: ext://sys.stdout
root:
level: INFO
handlers:
- console
loggers:
logging_example:
level: DEBUG
disable_existing_loggers: false
job_logging:
version: 1
formatters:
simple:
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
handlers:
console:
class: logging.StreamHandler
formatter: simple
stream: ext://sys.stdout
file:
class: logging.FileHandler
formatter: simple
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
root:
level: INFO
handlers:
- console
- file
disable_existing_loggers: false
env: {}
mode: RUN
searchpath: []
callbacks: {}
output_subdir: .hydra
overrides:
hydra:
- hydra.mode=RUN
task: []
job:
name: run_monobeast
chdir: null
override_dirname: ''
id: ???
num: ???
config_name: new_beginnings
env_set: {}
env_copy: []
config:
override_dirname:
kv_sep: '='
item_sep: ','
exclude_keys: []
runtime:
version: 1.3.2
version_base: '1.3'
cwd: C:\Users\nick.gorichs\PycharmProjects\Connect_Four_2
config_sources:
- path: hydra.conf
schema: pkg
provider: hydra
- path: C:\Users\nick.gorichs\PycharmProjects\Connect_Four_2\conf
schema: file
provider: main
- path: ''
schema: structured
provider: schema
output_dir: C:\Users\nick.gorichs\PycharmProjects\Connect_Four_2\outputs\02-02\09-55-53
choices:
hydra/env: default
hydra/callbacks: null
hydra/job_logging: default
hydra/hydra_logging: default
hydra/hydra_help: default
hydra/help: default
hydra/sweeper: basic
hydra/launcher: basic
hydra/output: default
verbose: false
1 change: 1 addition & 0 deletions outputs/02-02/09-55-53/.hydra/overrides.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Loading

0 comments on commit 9a9995c

Please sign in to comment.