Skip to content

Commit c8b5da1

Browse files
committed
Add script to visualize GAN samples
1 parent 1d68d50 commit c8b5da1

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch as th
2+
3+
from reward_preprocessing.common import utils
4+
5+
GAN_TIMESTAMP = "20221104_163134"
6+
MODEL_NUMBER = "13720"
7+
8+
if __name__ == "__main__":
9+
gan_path = (
10+
"/nas/ucb/daniel/gan_test_data_"
11+
+ GAN_TIMESTAMP
12+
+ "/models/model_"
13+
+ MODEL_NUMBER
14+
+ ".torch"
15+
)
16+
device = "cuda" if th.cuda.is_available() else "cpu"
17+
gan = th.load(gan_path, map_location=th.device(device))
18+
samples, _ = gan.get_training_results()
19+
utils.visualize_samples(samples.detach().cpu().numpy(), gan.folder)

0 commit comments

Comments
 (0)