@@ -234,15 +234,16 @@ def forward(self, latent_tens: th.Tensor):
234
234
return self .reward_net .forward (obs , action_vec , next_obs , done )
235
235
236
236
237
- def log_np_img_wandb (
238
- arr : np .ndarray ,
237
+ def log_img_wandb (
238
+ img : Union [ np .ndarray , PIL . Image . Image ] ,
239
239
caption : str ,
240
240
wandb_key : str ,
241
241
logger : HierarchicalLogger ,
242
242
scale : int = 1 ,
243
243
step : Optional [int ] = None ,
244
244
) -> None :
245
- """Log visualized np.ndarray to wandb using given logger.
245
+ """Log np.ndarray as image or PIL image. Logs to wandb using given logger.
246
+ If scale is provided, image will be scaled in both cases.
246
247
247
248
Args:
248
249
- arr: Array to turn into image, save.
@@ -253,8 +254,17 @@ def log_np_img_wandb(
253
254
- step: Step for logging. If not provided, the logger dumping will be skipped.
254
255
In that case logs will be dumped with the next dump().
255
256
"""
256
-
257
- pil_img = array_to_image (arr , scale )
257
+ if isinstance (img , np .ndarray ):
258
+ pil_img = array_to_image (img , scale )
259
+ elif isinstance (img , PIL .Image .Image ):
260
+ pil_img = img .resize (
261
+ # PIL expects tuple of (width, height), as opposed to numpy's
262
+ # (height, width).
263
+ size = (img .width * scale , img .height * scale ),
264
+ resample = Image .NEAREST ,
265
+ )
266
+ else :
267
+ raise ValueError (f"img must be np.ndarray or PIL.Image.Image, { type (img )= } " )
258
268
wb_img = wandb .Image (pil_img , caption = caption )
259
269
logger .record (wandb_key , wb_img )
260
270
if step is not None :
@@ -264,7 +274,8 @@ def log_np_img_wandb(
264
274
def array_to_image (arr : np .ndarray , scale : int ) -> PIL .Image .Image :
265
275
"""Take numpy array on [0,1] scale, return PIL image."""
266
276
return Image .fromarray (np .uint8 (arr * 255 ), mode = "RGB" ).resize (
267
- # PIL expects tuple of (width, height), numpy's index 1 is width, 0 height.
277
+ # PIL expects tuple of (width, height), numpy's dimension 1 is width, and
278
+ # dimension 0 height.
268
279
size = (arr .shape [1 ] * scale , arr .shape [0 ] * scale ),
269
280
resample = Image .NEAREST ,
270
281
)
0 commit comments