Skip to content

Commit 2754200

Browse files
albertbou92vmoens
andauthored
[Feature] Submitit run script (#1822)
Co-authored-by: vmoens <[email protected]>
1 parent 06fcac1 commit 2754200

File tree

94 files changed

+1043
-110
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

94 files changed

+1043
-110
lines changed

.github/unittest/linux_examples/scripts/run_test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
4545
optim.updates_per_episode=3 \
4646
optim.warmup_steps=10 \
4747
optim.device=cuda:0 \
48+
env.backend=gymnasium \
4849
logger.backend=
4950
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_offline.py \
5051
optim.gradient_steps=55 \

examples/a2c/a2c_atari.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,14 @@ def main(cfg: "DictConfig"): # noqa: F821
9393
if cfg.logger.backend:
9494
exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
9595
logger = get_logger(
96-
cfg.logger.backend, logger_name="a2c", experiment_name=exp_name
96+
cfg.logger.backend,
97+
logger_name="a2c",
98+
experiment_name=exp_name,
99+
wandb_kwargs={
100+
"config": dict(cfg),
101+
"project": cfg.logger.project_name,
102+
"group": cfg.logger.group_name,
103+
},
97104
)
98105

99106
# Create test environment

examples/a2c/a2c_mujoco.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,14 @@ def main(cfg: "DictConfig"): # noqa: F821
7979
if cfg.logger.backend:
8080
exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
8181
logger = get_logger(
82-
cfg.logger.backend, logger_name="a2c", experiment_name=exp_name
82+
cfg.logger.backend,
83+
logger_name="a2c",
84+
experiment_name=exp_name,
85+
wandb_kwargs={
86+
"config": dict(cfg),
87+
"project": cfg.logger.project_name,
88+
"group": cfg.logger.group_name,
89+
},
8390
)
8491

8592
# Create test environment

examples/a2c/config_atari.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ collector:
1111
# logger
1212
logger:
1313
backend: wandb
14+
project_name: torchrl_example_a2c
15+
group_name: null
1416
exp_name: Atari_Schulman17
1517
test_interval: 40_000_000
1618
num_test_episodes: 3

examples/a2c/config_mujoco.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# task and env
22
env:
3-
env_name: HalfCheetah-v3
3+
env_name: HalfCheetah-v4
44

55
# collector
66
collector:
@@ -10,6 +10,8 @@ collector:
1010
# logger
1111
logger:
1212
backend: wandb
13+
project_name: torchrl_example_a2c
14+
group_name: null
1315
exp_name: Mujoco_Schulman17
1416
test_interval: 1_000_000
1517
num_test_episodes: 5

examples/bandits/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Bandits example
2+
3+
## Note:
4+
This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it in the
5+
benchmarking of future releases, to ensure that it can be successfully run with the release code and that the
6+
results are consistent. For now, be aware that this additional check has not been performed in the case of this
7+
specific example.

examples/cql/cql_offline.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,19 @@
3232
@hydra.main(config_path=".", config_name="offline_config", version_base="1.1")
3333
def main(cfg: "DictConfig"): # noqa: F821
3434
# Create logger
35-
exp_name = generate_exp_name("CQL-offline", cfg.env.exp_name)
35+
exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name)
3636
logger = None
3737
if cfg.logger.backend:
3838
logger = get_logger(
3939
logger_type=cfg.logger.backend,
4040
logger_name="cql_logging",
4141
experiment_name=exp_name,
42-
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
42+
wandb_kwargs={
43+
"mode": cfg.logger.mode,
44+
"config": dict(cfg),
45+
"project": cfg.logger.project_name,
46+
"group": cfg.logger.group_name,
47+
},
4348
)
4449
# Set seeds
4550
torch.manual_seed(cfg.env.seed)

examples/cql/cql_online.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,19 @@
3636
@hydra.main(version_base="1.1", config_path=".", config_name="online_config")
3737
def main(cfg: "DictConfig"): # noqa: F821
3838
# Create logger
39-
exp_name = generate_exp_name("CQL-online", cfg.env.exp_name)
39+
exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name)
4040
logger = None
4141
if cfg.logger.backend:
4242
logger = get_logger(
4343
logger_type=cfg.logger.backend,
4444
logger_name="cql_logging",
4545
experiment_name=exp_name,
46-
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
46+
wandb_kwargs={
47+
"mode": cfg.logger.mode,
48+
"config": dict(cfg),
49+
"project": cfg.logger.project_name,
50+
"group": cfg.logger.group_name,
51+
},
4752
)
4853

4954
# Set seeds

examples/cql/discrete_cql_config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ env:
33
name: CartPole-v1
44
task: ""
55
backend: gym
6-
exp_name: cql_cartpole_gym
76
n_samples_stats: 1000
87
max_episode_steps: 200
98
seed: 0
@@ -24,6 +23,9 @@ collector:
2423
# Logger
2524
logger:
2625
backend: wandb
26+
project_name: torchrl_example_cql
27+
group_name: null
28+
exp_name: cql_cartpole_gym
2729
log_interval: 5000 # record interval in frames
2830
eval_steps: 200
2931
mode: online

examples/cql/discrete_cql_online.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,18 @@ def main(cfg: "DictConfig"): # noqa: F821
3838
device = torch.device(cfg.optim.device)
3939

4040
# Create logger
41-
exp_name = generate_exp_name("DiscreteCQL", cfg.env.exp_name)
41+
exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
4242
logger = None
4343
if cfg.logger.backend:
4444
logger = get_logger(
4545
logger_type=cfg.logger.backend,
4646
logger_name="discretecql_logging",
4747
experiment_name=exp_name,
48-
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
48+
wandb_kwargs={
49+
"mode": cfg.logger.mode,
50+
"config": dict(cfg),
51+
"project": cfg.logger.project_name,
52+
},
4953
)
5054

5155
# Set seeds

0 commit comments

Comments
 (0)