Skip to content

Commit 283e84b

Browse files
authored
Merge pull request #22 from HumanCompatibleAI/try-feature-vis
Action labels and overview visualization
2 parents b694993 + 6d7f7d2 commit 283e84b

File tree

6 files changed

+212
-75
lines changed

6 files changed

+212
-75
lines changed

src/reward_preprocessing/common/utils.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,15 +234,16 @@ def forward(self, latent_tens: th.Tensor):
234234
return self.reward_net.forward(obs, action_vec, next_obs, done)
235235

236236

237-
def log_np_img_wandb(
238-
arr: np.ndarray,
237+
def log_img_wandb(
238+
img: Union[np.ndarray, PIL.Image.Image],
239239
caption: str,
240240
wandb_key: str,
241241
logger: HierarchicalLogger,
242242
scale: int = 1,
243243
step: Optional[int] = None,
244244
) -> None:
245-
"""Log visualized np.ndarray to wandb using given logger.
245+
"""Log np.ndarray as image or PIL image. Logs to wandb using given logger.
246+
If scale is provided, image will be scaled in both cases.
246247
247248
Args:
248249
- arr: Array to turn into image, save.
@@ -253,8 +254,17 @@ def log_np_img_wandb(
253254
- step: Step for logging. If not provided, the logger dumping will be skipped.
254255
In that case logs will be dumped with the next dump().
255256
"""
256-
257-
pil_img = array_to_image(arr, scale)
257+
if isinstance(img, np.ndarray):
258+
pil_img = array_to_image(img, scale)
259+
elif isinstance(img, PIL.Image.Image):
260+
pil_img = img.resize(
261+
# PIL expects tuple of (width, height), as opposed to numpy's
262+
# (height, width).
263+
size=(img.width * scale, img.height * scale),
264+
resample=Image.NEAREST,
265+
)
266+
else:
267+
raise ValueError(f"img must be np.ndarray or PIL.Image.Image, {type(img)=}")
258268
wb_img = wandb.Image(pil_img, caption=caption)
259269
logger.record(wandb_key, wb_img)
260270
if step is not None:
@@ -264,7 +274,8 @@ def log_np_img_wandb(
264274
def array_to_image(arr: np.ndarray, scale: int) -> PIL.Image.Image:
265275
"""Take numpy array on [0,1] scale, return PIL image."""
266276
return Image.fromarray(np.uint8(arr * 255), mode="RGB").resize(
267-
# PIL expects tuple of (width, height), numpy's index 1 is width, 0 height.
277+
# PIL expects tuple of (width, height), numpy's dimension 1 is width, and
278+
# dimension 0 height.
268279
size=(arr.shape[1] * scale, arr.shape[0] * scale),
269280
resample=Image.NEAREST,
270281
)

0 commit comments

Comments
 (0)