diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py
index 8ecb675535b..04c6ab7e0d5 100644
--- a/sota-implementations/ppo/ppo_atari.py
+++ b/sota-implementations/ppo/ppo_atari.py
@@ -168,12 +168,18 @@ def update(batch, num_network_updates):
         num_network_updates = num_network_updates + 1
         # Get a data batch
         batch = batch.to(device, non_blocking=True)
 
+        def forward(batch, num_network_updates):
+
+            # Forward pass PPO loss
+            loss = loss_module(batch)
+            loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+            return loss, loss_sum
+
+        loss, loss_sum = torch.compile(forward, backend="inductor", mode="reduce-overhead")(batch, num_network_updates)
-        # Forward pass PPO loss
-        loss = loss_module(batch)
-        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
         # Backward pass
         loss_sum.backward()
+
         torch.nn.utils.clip_grad_norm_(
             loss_module.parameters(), max_norm=cfg_optim_max_grad_norm
         )
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index 1d541750ec2..71038a6ff1d 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -985,9 +985,10 @@ def count_and_compile(*model_args, **model_kwargs):
             nonlocal count
             nonlocal compiled_model
             count += 1
-            if count == warmup:
-                compiled_model = torch.compile(model, *args, **kwargs)
-            return compiled_model(*model_args, **model_kwargs)
+            #if count == warmup:
+            #    compiled_model = torch.compile(model, fullgraph=True, backend="inductor")
+            out = compiled_model(*model_args, **model_kwargs)
+            return out
 
         return count_and_compile
 
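For context, here is a minimal sketch (not the repository's code) of the compile-once variant of what the first hunk does: the loss forward pass is wrapped in a function and compiled a single time outside the update loop, so `torch.compile` is not re-invoked on a freshly defined closure at every update. The `loss_module` stand-in, batch shapes, and optimizer below are hypothetical placeholders.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the PPO loss module; only the call pattern matters here.
loss_module = nn.Linear(4, 1)


def forward(batch):
    # Forward pass computing the scalar training loss.
    return loss_module(batch).sum()


# Compile once, outside the update loop, so the compiled graph is reused
# across iterations instead of being rebuilt on every call.
compiled_forward = torch.compile(forward, backend="inductor", mode="reduce-overhead")

optim = torch.optim.Adam(loss_module.parameters())
for _ in range(3):
    batch = torch.randn(8, 4)
    loss_sum = compiled_forward(batch)
    loss_sum.backward()
    torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=1.0)
    optim.step()
    optim.zero_grad(set_to_none=True)
```

Whether the first-call compilation cost is hidden behind a warmup counter (as the `count_and_compile` wrapper touched by the second hunk does) or simply paid on the first update is a separate choice; the sketch above just pays it up front.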