We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7fcc4fe commit ce65245Copy full SHA for ce65245
src/reward_preprocessing/common/utils.py
@@ -8,6 +8,7 @@
8
import torch as th
9
from torch import nn as nn
10
from torch.utils import data as torch_data
11
+import vegans.utils
12
13
14
def make_transition_to_tensor(num_acts):
@@ -206,3 +207,9 @@ def forward(self, transition_tensor: th.Tensor) -> th.Tensor:
206
207
208
dones = th.zeros_like(obs[:, 0])
209
return self.rew_net(state=obs, action=act, next_state=next_obs, done=dones)
210
+
211
212
+def save_loss_plots(losses, save_dir):
213
+ """Save plots of generator/adversary losses over training."""
214
+ fig, _ = vegans.utils.plot_losses(losses, show=False)
215
+ fig.savefig(Path(save_dir) / "loss_fig.png")
0 commit comments