File tree 2 files changed +17
-2
lines changed
2 files changed +17
-2
lines changed Original file line number Diff line number Diff line change @@ -141,7 +141,10 @@ def visualize_samples(samples: np.ndarray, save_dir):
141
141
142
142
143
143
def process_image_array (img : np .ndarray ) -> np .ndarray :
144
- """Process a numpy array for feeding into PIL.Image.fromarray."""
144
+ """Process a numpy array for feeding into PIL.Image.fromarray.
145
+
146
+ Should already be in (h,w,c) format.
147
+ """
145
148
up_multiplied = img * 255
146
149
clipped = np .clip (up_multiplied , 0 , 255 )
147
150
cast = clipped .astype (np .uint8 )
Original file line number Diff line number Diff line change @@ -381,10 +381,13 @@ def param_f():
381
381
for feature_i in range (num_features ):
382
382
custom_logger .log (f"Feature { feature_i } " )
383
383
384
- np_trans_tens , indices = nmf .vis_dataset_thumbnail (
384
+ dataset_thumbnails , indices = nmf .vis_dataset_thumbnail (
385
385
feature = feature_i , num_mult = 4 , expand_mult = 1
386
386
)
387
387
388
+ # remove opacity channel from dataset thumbnails
389
+ np_trans_tens = dataset_thumbnails [:- 1 , :, :]
390
+
388
391
obs , _ , next_obs = ndarray_to_transition (np_trans_tens )
389
392
390
393
_log_single_transition_wandb (
@@ -399,6 +402,15 @@ def param_f():
399
402
rows ,
400
403
)
401
404
405
+ if img_save_path is not None :
406
+ obs_PIL = array_to_image (obs , vis_scale )
407
+ obs_PIL .save (img_save_path + f"{ feature_i } _obs.png" )
408
+ next_obs_PIL = array_to_image (next_obs , vis_scale )
409
+ next_obs_PIL .save (img_save_path + f"{ feature_i } _next_obs.png" )
410
+ custom_logger .log (
411
+ f"Saved feature { feature_i } viz in dir { img_save_path } ."
412
+ )
413
+
402
414
if pyplot :
403
415
plt .show ()
404
416
custom_logger .log ("Done with visualization." )
You can’t perform that action at this time.
0 commit comments