Skip to content

Commit 52a12a8

Browse files
committed
Merge remote-tracking branch 'origin/main' into release/0.3.0
2 parents b9cf712 + 69453a6 commit 52a12a8

File tree

120 files changed

+1136
-224
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

120 files changed

+1136
-224
lines changed

.github/unittest/linux_examples/scripts/run_test.sh

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
4545
optim.updates_per_episode=3 \
4646
optim.warmup_steps=10 \
4747
optim.device=cuda:0 \
48+
env.backend=gymnasium \
4849
logger.backend=
4950
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_offline.py \
5051
optim.gradient_steps=55 \

examples/a2c/a2c_atari.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,14 @@ def main(cfg: "DictConfig"): # noqa: F821
9393
if cfg.logger.backend:
9494
exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
9595
logger = get_logger(
96-
cfg.logger.backend, logger_name="a2c", experiment_name=exp_name
96+
cfg.logger.backend,
97+
logger_name="a2c",
98+
experiment_name=exp_name,
99+
wandb_kwargs={
100+
"config": dict(cfg),
101+
"project": cfg.logger.project_name,
102+
"group": cfg.logger.group_name,
103+
},
97104
)
98105

99106
# Create test environment

examples/a2c/a2c_mujoco.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,14 @@ def main(cfg: "DictConfig"): # noqa: F821
7979
if cfg.logger.backend:
8080
exp_name = generate_exp_name("A2C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
8181
logger = get_logger(
82-
cfg.logger.backend, logger_name="a2c", experiment_name=exp_name
82+
cfg.logger.backend,
83+
logger_name="a2c",
84+
experiment_name=exp_name,
85+
wandb_kwargs={
86+
"config": dict(cfg),
87+
"project": cfg.logger.project_name,
88+
"group": cfg.logger.group_name,
89+
},
8390
)
8491

8592
# Create test environment

examples/a2c/config_atari.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ collector:
1111
# logger
1212
logger:
1313
backend: wandb
14+
project_name: torchrl_example_a2c
15+
group_name: null
1416
exp_name: Atari_Schulman17
1517
test_interval: 40_000_000
1618
num_test_episodes: 3

examples/a2c/config_mujoco.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# task and env
22
env:
3-
env_name: HalfCheetah-v3
3+
env_name: HalfCheetah-v4
44

55
# collector
66
collector:
@@ -10,6 +10,8 @@ collector:
1010
# logger
1111
logger:
1212
backend: wandb
13+
project_name: torchrl_example_a2c
14+
group_name: null
1315
exp_name: Mujoco_Schulman17
1416
test_interval: 1_000_000
1517
num_test_episodes: 5

examples/bandits/README.md

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Bandits example
2+
3+
## Note:
4+
This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it in the
5+
benchmarking of future releases, to ensure that it can be successfully run with the release code and that the
6+
results are consistent. For now, be aware that this additional check has not been performed in the case of this
7+
specific example.

examples/cql/cql_offline.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,19 @@
3232
@hydra.main(config_path=".", config_name="offline_config", version_base="1.1")
3333
def main(cfg: "DictConfig"): # noqa: F821
3434
# Create logger
35-
exp_name = generate_exp_name("CQL-offline", cfg.env.exp_name)
35+
exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name)
3636
logger = None
3737
if cfg.logger.backend:
3838
logger = get_logger(
3939
logger_type=cfg.logger.backend,
4040
logger_name="cql_logging",
4141
experiment_name=exp_name,
42-
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
42+
wandb_kwargs={
43+
"mode": cfg.logger.mode,
44+
"config": dict(cfg),
45+
"project": cfg.logger.project_name,
46+
"group": cfg.logger.group_name,
47+
},
4348
)
4449
# Set seeds
4550
torch.manual_seed(cfg.env.seed)

examples/cql/cql_online.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,19 @@
3636
@hydra.main(version_base="1.1", config_path=".", config_name="online_config")
3737
def main(cfg: "DictConfig"): # noqa: F821
3838
# Create logger
39-
exp_name = generate_exp_name("CQL-online", cfg.env.exp_name)
39+
exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name)
4040
logger = None
4141
if cfg.logger.backend:
4242
logger = get_logger(
4343
logger_type=cfg.logger.backend,
4444
logger_name="cql_logging",
4545
experiment_name=exp_name,
46-
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
46+
wandb_kwargs={
47+
"mode": cfg.logger.mode,
48+
"config": dict(cfg),
49+
"project": cfg.logger.project_name,
50+
"group": cfg.logger.group_name,
51+
},
4752
)
4853

4954
# Set seeds

examples/cql/discrete_cql_config.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ env:
33
name: CartPole-v1
44
task: ""
55
backend: gym
6-
exp_name: cql_cartpole_gym
76
n_samples_stats: 1000
87
max_episode_steps: 200
98
seed: 0
@@ -24,6 +23,9 @@ collector:
2423
# Logger
2524
logger:
2625
backend: wandb
26+
project_name: torchrl_example_cql
27+
group_name: null
28+
exp_name: cql_cartpole_gym
2729
log_interval: 5000 # record interval in frames
2830
eval_steps: 200
2931
mode: online

examples/cql/discrete_cql_online.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,18 @@ def main(cfg: "DictConfig"): # noqa: F821
3838
device = torch.device(cfg.optim.device)
3939

4040
# Create logger
41-
exp_name = generate_exp_name("DiscreteCQL", cfg.env.exp_name)
41+
exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
4242
logger = None
4343
if cfg.logger.backend:
4444
logger = get_logger(
4545
logger_type=cfg.logger.backend,
4646
logger_name="discretecql_logging",
4747
experiment_name=exp_name,
48-
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
48+
wandb_kwargs={
49+
"mode": cfg.logger.mode,
50+
"config": dict(cfg),
51+
"project": cfg.logger.project_name,
52+
},
4953
)
5054

5155
# Set seeds

examples/cql/offline_config.yaml

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
# env and task
22
env:
3-
name: Hopper-v2
3+
name: Hopper-v4
44
task: ""
55
library: gym
6-
exp_name: cql_${replay_buffer.dataset}
76
n_samples_stats: 1000
87
seed: 0
98
backend: gym # D4RL uses gym so we make sure gymnasium is hidden
109

1110
# logger
1211
logger:
1312
backend: wandb
13+
project_name: torchrl_example_cql
14+
group_name: null
15+
exp_name: cql_${replay_buffer.dataset}
1416
eval_iter: 5000
1517
eval_steps: 1000
1618
mode: online

examples/cql/online_config.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
env:
33
name: Pendulum-v1
44
task: ""
5-
exp_name: cql_${env.name}
65
n_samples_stats: 1000
76
seed: 0
87
train_num_envs: 1
@@ -23,6 +22,9 @@ collector:
2322
# logger
2423
logger:
2524
backend: wandb
25+
project_name: torchrl_example_cql
26+
group_name: null
27+
exp_name: cql_${env.name}
2628
log_interval: 5000 # record interval in frames
2729
eval_steps: 1000
2830
mode: online

examples/ddpg/config.yaml

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# environment and task
22
env:
3-
name: HalfCheetah-v3
3+
name: HalfCheetah-v4
44
task: ""
5-
exp_name: ${env.name}_DDPG
65
library: gymnasium
76
max_episode_steps: 1000
87
seed: 42
@@ -22,7 +21,7 @@ collector:
2221
replay_buffer:
2322
size: 1000000
2423
prb: 0 # use prioritized experience replay
25-
scratch_dir: ${env.exp_name}_${env.seed}
24+
scratch_dir: ${logger.exp_name}_${env.seed}
2625

2726
# optimization
2827
optim:
@@ -44,5 +43,8 @@ network:
4443
# logging
4544
logger:
4645
backend: wandb
46+
project_name: torchrl_example_ddpg
47+
group_name: null
48+
exp_name: ${env.name}_DDPG
4749
mode: online
4850
eval_iter: 25000

examples/ddpg/ddpg.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,19 @@ def main(cfg: "DictConfig"): # noqa: F821
3838
device = torch.device(cfg.network.device)
3939

4040
# Create logger
41-
exp_name = generate_exp_name("DDPG", cfg.env.exp_name)
41+
exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
4242
logger = None
4343
if cfg.logger.backend:
4444
logger = get_logger(
4545
logger_type=cfg.logger.backend,
4646
logger_name="ddpg_logging",
4747
experiment_name=exp_name,
48-
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
48+
wandb_kwargs={
49+
"mode": cfg.logger.mode,
50+
"config": dict(cfg),
51+
"project": cfg.logger.project_name,
52+
"group": cfg.logger.group_name,
53+
},
4954
)
5055

5156
# Set seeds

examples/decision_transformer/dt_config.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# environment and task
22
env:
3-
name: HalfCheetah-v3
3+
name: HalfCheetah-v4
44
task: ""
55
library: gym
66
stacked_frames: 20
@@ -20,7 +20,9 @@ env:
2020
# logger
2121
logger:
2222
backend: wandb
23+
project_name: torchrl_example_dt
2324
model_name: DT
25+
group_name: null
2426
exp_name: DT-HalfCheetah-medium-v2
2527
pretrain_log_interval: 500 # record interval in frames
2628
fintune_log_interval: 1

examples/decision_transformer/odt_config.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# environment and task
22
env:
3-
name: HalfCheetah-v3
3+
name: HalfCheetah-v4
44
task: ""
55
library: gym
66
stacked_frames: 20
@@ -20,6 +20,8 @@ env:
2020
# logger
2121
logger:
2222
backend: wandb
23+
project_name: torchrl_example_odt
24+
group_name: null
2325
exp_name: oDT-HalfCheetah-medium-v2
2426
model_name: oDT
2527
pretrain_log_interval: 500 # record interval in frames

examples/decision_transformer/utils.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -493,17 +493,18 @@ def make_dt_optimizer(optim_cfg, loss_module):
493493

494494

495495
def make_logger(cfg):
496-
from omegaconf import OmegaConf
497-
498496
if not cfg.logger.backend:
499497
return None
500498
exp_name = generate_exp_name(cfg.logger.model_name, cfg.logger.exp_name)
501-
cfg.logger.exp_name = exp_name
502499
logger = get_logger(
503500
cfg.logger.backend,
504501
logger_name=cfg.logger.model_name,
505502
experiment_name=exp_name,
506-
wandb_kwargs={"config": OmegaConf.to_container(cfg)},
503+
wandb_kwargs={
504+
"config": dict(cfg),
505+
"project": cfg.logger.project_name,
506+
"group": cfg.logger.group_name,
507+
},
507508
)
508509
return logger
509510

examples/discrete_sac/config.yaml

+4-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
env:
44
name: CartPole-v1
55
task: ""
6-
exp_name: ${env.name}_DiscreteSAC
76
library: gym
87
seed: 42
98
max_episode_steps: 500
@@ -23,7 +22,7 @@ collector:
2322
replay_buffer:
2423
prb: 0 # use prioritized experience replay
2524
size: 1000000
26-
scratch_dir: ${env.exp_name}_${env.seed}
25+
scratch_dir: ${logger.exp_name}_${env.seed}
2726

2827
# optim
2928
optim:
@@ -48,5 +47,8 @@ network:
4847
# logging
4948
logger:
5049
backend: wandb
50+
project_name: torchrl_example_discrete_sac
51+
group_name: null
52+
exp_name: ${env.name}_DiscreteSAC
5153
mode: online
5254
eval_iter: 5000

examples/discrete_sac/discrete_sac.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,19 @@ def main(cfg: "DictConfig"): # noqa: F821
3838
device = torch.device(cfg.network.device)
3939

4040
# Create logger
41-
exp_name = generate_exp_name("DiscreteSAC", cfg.env.exp_name)
41+
exp_name = generate_exp_name("DiscreteSAC", cfg.logger.exp_name)
4242
logger = None
4343
if cfg.logger.backend:
4444
logger = get_logger(
4545
logger_type=cfg.logger.backend,
4646
logger_name="DiscreteSAC_logging",
4747
experiment_name=exp_name,
48-
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
48+
wandb_kwargs={
49+
"mode": cfg.logger.mode,
50+
"config": dict(cfg),
51+
"project": cfg.logger.project_name,
52+
"group": cfg.logger.group_name,
53+
},
4954
)
5055

5156
# Set seeds

examples/dqn/config_atari.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ buffer:
2121

2222
# logger
2323
logger:
24-
backend: null
24+
backend: wandb
25+
project_name: torchrl_example_dqn
26+
group_name: null
2527
exp_name: DQN
2628
test_interval: 1_000_000
2729
num_test_episodes: 3

examples/dqn/config_cartpole.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ buffer:
2020

2121
# logger
2222
logger:
23-
backend: null
23+
backend: wandb
24+
project_name: torchrl_example_dqn
25+
group_name: null
2426
exp_name: DQN
2527
test_interval: 50_000
2628
num_test_episodes: 5

examples/dqn/dqn_atari.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,14 @@ def main(cfg: "DictConfig"): # noqa: F821
9999
if cfg.logger.backend:
100100
exp_name = generate_exp_name("DQN", f"Atari_mnih15_{cfg.env.env_name}")
101101
logger = get_logger(
102-
cfg.logger.backend, logger_name="dqn", experiment_name=exp_name
102+
cfg.logger.backend,
103+
logger_name="dqn",
104+
experiment_name=exp_name,
105+
wandb_kwargs={
106+
"config": dict(cfg),
107+
"project": cfg.logger.project_name,
108+
"group": cfg.logger.group_name,
109+
},
103110
)
104111

105112
# Create the test environment

0 commit comments

Comments
 (0)