Merged
Commits
38 commits
7e0d74d
Updated
DenSumy Nov 28, 2024
aeacb7d
Clean-up and fixed imports.
ViktorM Dec 1, 2024
04f72d9
Merge from master.
ViktorM Dec 1, 2024
7ea2ffb
Fixed device RNN reset issues.
ViktorM Dec 1, 2024
5e5511e
Reverted RNN changes.
ViktorM Dec 1, 2024
9062567
Added fusion. Clean-up.
ViktorM Dec 1, 2024
65c86c8
Clean up.
ViktorM Dec 3, 2024
249b1f0
Improved error handling.
ViktorM Dec 6, 2024
4f6c8d2
More error messages improvements.
ViktorM Dec 6, 2024
7fc8551
Readme update
ViktorM Dec 6, 2024
5fe2dc1
Merge branch 'master' into VM/torch_compile
ViktorM Dec 10, 2024
3009d02
Better error handling.
ViktorM Dec 16, 2024
6819a1d
Merge branch 'VM/torch_compile' of https://github.com/Denys88/rl_game…
ViktorM Dec 16, 2024
f50c142
Updated configs and README.
ViktorM Dec 26, 2024
f1d0704
Updated config.
ViktorM Dec 26, 2024
3bb11eb
Losses clean up and docs.
ViktorM Feb 6, 2025
73f499d
Fixed having a too small default number of games.
ViktorM Feb 6, 2025
bf2fd86
Code clean-up.
ViktorM Feb 6, 2025
0dc6b91
Torch compile is working and accelerates training.
ViktorM Feb 6, 2025
ca03e70
Updated release docs.
ViktorM Feb 10, 2025
d2e210a
Fixed loading weights after torch.compile was added.
ViktorM Feb 11, 2025
d3ec8c0
Fixed python-package.
ViktorM Mar 7, 2025
ff9d12a
Fixed potential memory leak in running mean std calculations.
ViktorM Mar 10, 2025
778d5d8
Perf improvements in running mean std calculations.
ViktorM Mar 10, 2025
fb2496f
More compile optimizations and code improvements.
ViktorM Mar 14, 2025
e991407
Fixed compile training errors. Asymmetric training is still not worki…
ViktorM Mar 19, 2025
9df3047
Fix.
ViktorM Mar 19, 2025
7d3e80c
Fix.
ViktorM Mar 21, 2025
4b32736
Fixed CV.
ViktorM Mar 21, 2025
d8ee304
WIP
ViktorM Mar 24, 2025
83ba02f
compile() fixes WIP.
ViktorM Mar 25, 2025
987e962
Revert "compile() fixes WIP."
ViktorM Apr 5, 2025
acd4529
torch.compile() fixes.
ViktorM Apr 5, 2025
5bfbeb8
Fixed loading weights for training when a model was compiled.
ViktorM Apr 18, 2025
c7419fa
Move precision to top.
ViktorM Apr 18, 2025
d3b02ca
Merged from master.
ViktorM Jun 17, 2025
a538c65
Fixed gymnasium import. Better Maniskill configs.
ViktorM Jun 18, 2025
d50789c
Fixed maniskill play.
ViktorM Jun 18, 2025
52 changes: 35 additions & 17 deletions README.md
@@ -69,14 +69,22 @@ Explore RL Games quickly and easily in Colab notebooks:

For maximum training performance, a preliminary installation of PyTorch 2.2 or newer with CUDA 12.1 or newer is highly recommended:

```conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia``` or:
```pip install pip3 install torch torchvision```
```bash
pip3 install torch torchvision
```
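
To verify that the CUDA-enabled build is active before training, a quick check from Python (not specific to rl_games) is:

```python
import torch

print(torch.__version__)          # expect 2.2 or newer
print(torch.cuda.is_available())  # True if a CUDA build and a visible GPU are present
print(torch.version.cuda)         # e.g. '12.1'; None on CPU-only builds
```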

Then:

```pip install rl-games```
```bash
pip install rl-games
```

Or clone the repo and install the latest version from source:
```bash
pip install -e .
```

To run CPU-based environments either Ray or envpool are required ```pip install envpool``` or ```pip install ray```
To run CPU-based environments, either envpool (if supported) or Ray is required: ```pip install envpool``` or ```pip install ray```
To train on Mujoco, Atari, or Box2d based environments, the corresponding packages need to be installed additionally with ```pip install gym[mujoco]```, ```pip install gym[atari]``` or ```pip install gym[box2d]``` respectively.

To run Atari, ```pip install opencv-python``` is also required. In addition, installing envpool is highly recommended for maximum simulation and training performance of Mujoco and Atari environments: ```pip install envpool```
@@ -114,13 +122,17 @@ And IsaacGymEnvs: https://github.com/NVIDIA-Omniverse/IsaacGymEnvs

*Ant*

```python train.py task=Ant headless=True```
```python train.py task=Ant test=True checkpoint=nn/Ant.pth num_envs=100```
```bash
python train.py task=Ant headless=True
python train.py task=Ant test=True checkpoint=nn/Ant.pth num_envs=100
```

*Humanoid*

```python train.py task=Humanoid headless=True```
```python train.py task=Humanoid test=True checkpoint=nn/Humanoid.pth num_envs=100```
```bash
python train.py task=Humanoid headless=True
python train.py task=Humanoid test=True checkpoint=nn/Humanoid.pth num_envs=100
```

*Shadow Hand block orientation task*

@@ -131,6 +143,13 @@ And IsaacGymEnvs: https://github.com/NVIDIA-Omniverse/IsaacGymEnvs

*Atari Pong*

```bash
python runner.py --train --file rl_games/configs/atari/ppo_pong.yaml
python runner.py --play --file rl_games/configs/atari/ppo_pong.yaml --checkpoint nn/PongNoFrameskip.pth
```

Or with poetry:

```bash
poetry install -E atari
poetry run python runner.py --train --file rl_games/configs/atari/ppo_pong.yaml
@@ -140,22 +159,21 @@ poetry run python runner.py --play --file rl_games/configs/atari/ppo_pong.yaml -
*Brax Ant*

```bash
poetry install -E brax
poetry run pip install --upgrade "jax[cuda]==0.3.13" -f https://storage.googleapis.com/jax-releases/jax_releases.html
poetry run python runner.py --train --file rl_games/configs/brax/ppo_ant.yaml
poetry run python runner.py --play --file rl_games/configs/brax/ppo_ant.yaml --checkpoint runs/Ant_brax/nn/Ant_brax.pth
pip install -U "jax[cuda12]"
pip install brax
python runner.py --train --file rl_games/configs/brax/ppo_ant.yaml
python runner.py --play --file rl_games/configs/brax/ppo_ant.yaml --checkpoint runs/Ant_brax/nn/Ant_brax.pth
```
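
Before launching Brax training, you can confirm that the CUDA JAX wheel is actually in use (a quick check, not part of the rl_games tooling):

```python
import jax

# Expect GPU devices here (e.g. [CudaDevice(id=0)]); a CPU-only list means the
# CUDA-enabled jax wheel was not picked up.
print(jax.devices())
```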

## Experiment tracking

rl_games supports experiment tracking with [Weights and Biases](https://wandb.ai).

```bash
poetry install -E atari
poetry run python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --track
WANDB_API_KEY=xxxx poetry run python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --track
poetry run python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --wandb-project-name rl-games-special-test --track
poetry run python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --wandb-project-name rl-games-special-test -wandb-entity openrlbenchmark --track
python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --track
WANDB_API_KEY=xxxx python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --track
python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --wandb-project-name rl-games-special-test --track
python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --wandb-project-name rl-games-special-test --wandb-entity openrlbenchmark --track
```
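
If you prefer authenticating from Python rather than exporting WANDB_API_KEY, something like the following works (a sketch; assumes the wandb package is already installed):

```python
import wandb

# Prompts for an API key on first use; wandb.login(key="...") is the
# non-interactive equivalent of setting WANDB_API_KEY.
wandb.login()
```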


6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rl_games"
version = "1.6.1"
version = "1.6.5"
description = ""
readme = "README.md"
authors = [
@@ -9,7 +9,7 @@ authors = [
]

[tool.poetry.dependencies]
python = ">=3.7.1,<3.11"
python = ">=3.7.1"
gym = {version = "^0.23.0", extras = ["classic_control"]}
tensorboard = "^2.8.0"
tensorboardX = "^2.5"
68 changes: 35 additions & 33 deletions rl_games/algos_torch/a2c_continuous.py
@@ -6,7 +6,7 @@
from rl_games.common import datasets

from torch import optim
import torch
import torch


class A2CAgent(a2c_common.ContinuousA2CBase):
@@ -27,39 +27,43 @@ def __init__(self, base_name, params):
a2c_common.ContinuousA2CBase.__init__(self, base_name, params)
obs_shape = self.obs_shape
build_config = {
'actions_num' : self.actions_num,
'input_shape' : obs_shape,
'num_seqs' : self.num_actors * self.num_agents,
'value_size': self.env_info.get('value_size',1),
'normalize_value' : self.normalize_value,
'actions_num': self.actions_num,
'input_shape': obs_shape,
'num_seqs': self.num_actors * self.num_agents,
'value_size': self.env_info.get('value_size', 1),
'normalize_value': self.normalize_value,
'normalize_input': self.normalize_input,
}

self.model = self.network.build(build_config)
self.model.to(self.ppo_device)
self.states = None
self.init_rnn_from_model(self.model)
self.last_lr = float(self.last_lr)
self.bound_loss_type = self.config.get('bound_loss_type', 'bound') # 'regularisation' or 'bound'
self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay)
self.optimizer = optim.Adam(self.model.parameters(),
float(self.last_lr),
eps=1e-08,
weight_decay=self.weight_decay,
fused=True)

if self.has_central_value:
cv_config = {
'state_shape' : self.state_shape,
'value_size' : self.value_size,
'ppo_device' : self.ppo_device,
'num_agents' : self.num_agents,
'horizon_length' : self.horizon_length,
'num_actors' : self.num_actors,
'num_actions' : self.actions_num,
'seq_length' : self.seq_length,
'normalize_value' : self.normalize_value,
'network' : self.central_value_config['network'],
'config' : self.central_value_config,
'writter' : self.writer,
'max_epochs' : self.max_epochs,
'multi_gpu' : self.multi_gpu,
'zero_rnn_on_done' : self.zero_rnn_on_done
'state_shape': self.state_shape,
'value_size': self.value_size,
'ppo_device': self.ppo_device,
'num_agents': self.num_agents,
'horizon_length': self.horizon_length,
'num_actors': self.num_actors,
'num_actions': self.actions_num,
'seq_length': self.seq_length,
'normalize_value': self.normalize_value,
'network': self.central_value_config['network'],
'config': self.central_value_config,
'writter': self.writer,
'max_epochs': self.max_epochs,
'multi_gpu': self.multi_gpu,
'zero_rnn_on_done': self.zero_rnn_on_done
}
self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)

@@ -74,7 +78,7 @@ def __init__(self, base_name, params):
def update_epoch(self):
self.epoch_num += 1
return self.epoch_num

def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)
@@ -88,7 +92,7 @@ def restore_central_value_function(self, fn):
self.set_central_value_function_weights(checkpoint)

def get_masked_action_values(self, obs, action_masks):
assert False
raise NotImplementedError("Masked action values are not implemented for continuous actions")

def calc_gradients(self, input_dict):
"""Compute gradients needed to step the networks of the algorithm.
@@ -114,8 +118,8 @@ def calc_gradients(self, input_dict):

batch_dict = {
'is_train': True,
'prev_actions': actions_batch,
'obs' : obs_batch,
'prev_actions': actions_batch,
'obs': obs_batch,
}

rnn_masks = None
@@ -125,9 +129,9 @@ def calc_gradients(self, input_dict):
batch_dict['seq_length'] = self.seq_length

if self.zero_rnn_on_done:
batch_dict['dones'] = input_dict['dones']
batch_dict['dones'] = input_dict['dones']

with torch.cuda.amp.autocast(enabled=self.mixed_precision):
with torch.amp.autocast('cuda', enabled=self.mixed_precision):
res_dict = self.model(batch_dict)
action_log_probs = res_dict['prev_neglogp']
values = res_dict['values']
@@ -138,7 +142,7 @@ def calc_gradients(self, input_dict):
a_loss = self.actor_loss_func(old_action_log_probs_batch, action_log_probs, advantage, self.ppo, curr_e_clip)

if self.has_value_loss:
c_loss = common_losses.critic_loss(self.model,value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)
c_loss = common_losses.critic_loss(self.model, value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)
else:
c_loss = torch.zeros(1, device=self.ppo_device)
if self.bound_loss_type == 'regularisation':
@@ -183,7 +187,7 @@ def calc_gradients(self, input_dict):
'new_neglogp' : action_log_probs,
'old_neglogp' : old_action_log_probs_batch,
'masks' : rnn_masks
}, curr_e_clip, 0)
}, curr_e_clip, 0)

self.train_result = (a_loss, c_loss, entropy, \
kl_dist, self.last_lr, lr_mul, \
@@ -209,5 +213,3 @@ def bound_loss(self, mu):
else:
b_loss = 0
return b_loss
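
The optimizer in this diff is constructed with `fused=True`, which assumes a recent PyTorch build and CUDA floating-point parameters. A defensive variant that falls back to the default Adam implementation could look like the following (a hypothetical helper, not part of rl_games):

```python
from torch import optim

def build_adam(params, lr, weight_decay=0.0):
    """Prefer the fused CUDA Adam kernel; fall back to the default implementation."""
    try:
        # fused=True raises on PyTorch versions without the kwarg (TypeError)
        # or when the parameters are not CUDA floating-point tensors (RuntimeError).
        return optim.Adam(params, lr, eps=1e-8, weight_decay=weight_decay, fused=True)
    except (TypeError, RuntimeError):
        return optim.Adam(params, lr, eps=1e-8, weight_decay=weight_decay)
```

Several commits also mention fixing weight loading after `torch.compile()` was introduced. Compiling wraps the module, so its `state_dict()` keys gain an `_orig_mod.` prefix; a common way to keep checkpoints interchangeable is to strip that prefix before calling `load_state_dict` (a sketch that assumes the checkpoint stores the weights under a `model` key):

```python
import torch

def load_model_weights(model, checkpoint_path, map_location="cpu"):
    """Load weights saved from either a compiled or an uncompiled model."""
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    weights = checkpoint.get("model", checkpoint)  # assumed checkpoint layout
    prefix = "_orig_mod."
    weights = {(k[len(prefix):] if k.startswith(prefix) else k): v
               for k, v in weights.items()}
    model.load_state_dict(weights)
```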

