File tree Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -362,10 +362,14 @@ def interpret(
362
362
for feature_i in range (num_features ):
363
363
custom_logger .log (f"Feature { feature_i } " )
364
364
365
- np_trans_tens , indices = nmf .vis_dataset_thumbnail (
365
+ dataset_thumbnails , indices = nmf .vis_dataset_thumbnail (
366
366
feature = feature_i , num_mult = 4 , expand_mult = 1
367
367
)
368
368
369
+ # remove opacity channel from dataset thumbnails
370
+ num_channels_total = dataset_thumbnails .shape [0 ]
371
+ np_trans_tens = dataset_thumbnails [0 : num_channels_total - 1 , :, :]
372
+
369
373
obs , _ , next_obs = ndarray_to_transition (np_trans_tens )
370
374
371
375
_log_single_transition_wandb (
@@ -380,6 +384,15 @@ def interpret(
380
384
rows ,
381
385
)
382
386
387
+ if img_save_path is not None :
388
+ obs_PIL = array_to_image (obs , vis_scale )
389
+ obs_PIL .save (img_save_path + f"{ feature_i } _obs.png" )
390
+ next_obs_PIL = array_to_image (next_obs , vis_scale )
391
+ next_obs_PIL .save (img_save_path + f"{ feature_i } _next_obs.png" )
392
+ custom_logger .log (
393
+ f"Saved feature { feature_i } viz in dir { img_save_path } ."
394
+ )
395
+
383
396
if pyplot :
384
397
plt .show ()
385
398
custom_logger .log ("Done with visualization." )
You can’t perform that action at this time.
0 commit comments