Skip to content

Commit ce65245

Browse files
committed
Add missing function to utils
1 parent 7fcc4fe commit ce65245

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/reward_preprocessing/common/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch as th
99
from torch import nn as nn
1010
from torch.utils import data as torch_data
11+
import vegans.utils
1112

1213

1314
def make_transition_to_tensor(num_acts):
@@ -206,3 +207,9 @@ def forward(self, transition_tensor: th.Tensor) -> th.Tensor:
206207

207208
dones = th.zeros_like(obs[:, 0])
208209
return self.rew_net(state=obs, action=act, next_state=next_obs, done=dones)
210+
211+
212+
def save_loss_plots(losses, save_dir):
213+
"""Save plots of generator/adversary losses over training."""
214+
fig, _ = vegans.utils.plot_losses(losses, show=False)
215+
fig.savefig(Path(save_dir) / "loss_fig.png")

0 commit comments

Comments
 (0)