Skip to content

Commit 3566f6b

Browse files
committed
Fix type for visualizing samples
1 parent f291c2f commit 3566f6b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/reward_preprocessing/scripts/train_gan.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def train_gan(
101101
# save samples, return losses, save plot of losses
102102
samples, losses = gan.get_training_results()
103103
utils.save_loss_plots(losses, gan.folder)
104-
utils.visualize_samples(samples.detach().cpu().numpy(), gan.folder)
104+
utils.visualize_samples(samples, gan.folder)
105105
return losses
106106

107107

0 commit comments

Comments
 (0)