From 7d70240955943af73096fd6850f5feee9b6e06ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 May 2024 11:27:07 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/visualize-embeddings.ipynb | 87 +++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 37 deletions(-) diff --git a/docs/visualize-embeddings.ipynb b/docs/visualize-embeddings.ipynb index 5ce19e6a..89b12b27 100644 --- a/docs/visualize-embeddings.ipynb +++ b/docs/visualize-embeddings.ipynb @@ -70,10 +70,12 @@ "outputs": [], "source": [ "# As we want to visualize the embeddings from the model, we neither mask the input image or shuffle the patches\n", - "module = ClayMAEModule.load_from_checkpoint(checkpoint_path=CHECKPOINT_PATH,\n", - " metadata_path=METADATA_PATH,\n", - " mask_ratio=0., \n", - " shuffle=False)\n", + "module = ClayMAEModule.load_from_checkpoint(\n", + " checkpoint_path=CHECKPOINT_PATH,\n", + " metadata_path=METADATA_PATH,\n", + " mask_ratio=0.0,\n", + " shuffle=False,\n", + ")\n", "\n", "module.eval();" ] @@ -103,7 +105,7 @@ } ], "source": [ - "# For model training, we stack chips from one sensor into batches of size 128. \n", + "# For model training, we stack chips from one sensor into batches of size 128.\n", "# This reduces the num_workers we need to load the batches & speeds up the training process.\n", "# Here, although the batch size is 1, the data module reads batch of size 128.\n", "dm = ClayDataModule(\n", @@ -111,7 +113,7 @@ " metadata_path=METADATA_PATH,\n", " size=CHIP_SIZE,\n", " batch_size=1,\n", - " num_workers=1\n", + " num_workers=1,\n", ")\n", "dm.setup(stage=\"fit\")" ] @@ -350,8 +352,15 @@ } ], "source": [ - "for sensor, chips in zip((\"l1\", \"l2\", \"linz\", \"naip\", \"s1\", \"s2\"), (l1, l2, linz, naip, s1, s2)):\n", - " print(f\"{chips['platform'][0]:<15}\", chips[\"pixels\"].shape, chips[\"time\"].shape, chips[\"latlon\"].shape)" + "for sensor, chips in zip(\n", + " (\"l1\", \"l2\", \"linz\", \"naip\", \"s1\", \"s2\"), (l1, l2, linz, naip, s1, s2)\n", + "):\n", + " print(\n", + " f\"{chips['platform'][0]:<15}\",\n", + " chips[\"pixels\"].shape,\n", + " chips[\"time\"].shape,\n", + " chips[\"latlon\"].shape,\n", + " )" ] }, { @@ -382,14 +391,14 @@ "source": [ "def create_batch(chips, wavelengths, gsd, device):\n", " batch = {}\n", - " \n", + "\n", " batch[\"pixels\"] = chips[\"pixels\"].to(device)\n", " batch[\"time\"] = chips[\"time\"].to(device)\n", " batch[\"latlon\"] = chips[\"latlon\"].to(device)\n", - " \n", + "\n", " batch[\"waves\"] = torch.tensor(wavelengths)\n", " batch[\"gsd\"] = torch.tensor(gsd)\n", - " \n", + "\n", " return batch" ] }, @@ -550,8 +559,8 @@ "source": [ "fig, axs = plt.subplots(3, 8, figsize=(20, 8))\n", "\n", - "for idx,ax in enumerate(axs.flatten()):\n", - " ax.imshow(batch_naip_pixels[idx, :3,...].transpose(1,2,0))\n", + "for idx, ax in enumerate(axs.flatten()):\n", + " ax.imshow(batch_naip_pixels[idx, :3, ...].transpose(1, 2, 0))\n", " ax.set_axis_off()\n", " ax.set_title(idx)" ] @@ -576,10 +585,9 @@ "metadata": {}, "outputs": [], "source": [ - "unmsk_embed = rearrange(unmsk_patch_naip[:,1:,:].detach().cpu().numpy(), \n", - " \"b (h w) d-> b d h w\",\n", - " h=28,\n", - " w=28)" + "unmsk_embed = rearrange(\n", + " unmsk_patch_naip[:, 1:, :].detach().cpu().numpy(), \"b (h w) d-> b d h w\", h=28, w=28\n", + ")" ] }, { @@ -608,10 +616,10 @@ } ], "source": [ - "embed = unmsk_embed[3] # 3 is randomly picked chip\n", + "embed = unmsk_embed[3] # 3 is randomly picked chip\n", "fig, axs = plt.subplots(16, 16, figsize=(20, 20))\n", "\n", - "for idx,ax in enumerate(axs.flatten()):\n", + "for idx, ax in enumerate(axs.flatten()):\n", " ax.imshow(embed[idx], cmap=\"bwr\")\n", " ax.set_axis_off()\n", " ax.set_title(idx)\n", @@ -646,18 +654,18 @@ ], "source": [ "fig, axs = plt.subplots(6, 8, figsize=(20, 14))\n", - "embed_dim = 97 # pick any embedding dimension\n", + "embed_dim = 97 # pick any embedding dimension\n", "\n", "for i in range(0, 6, 2):\n", " for j in range(8):\n", - " idx = (i//2)*8+j\n", - " axs[i][j].imshow(batch_naip_pixels[idx, :3,...].transpose(1,2,0))\n", + " idx = (i // 2) * 8 + j\n", + " axs[i][j].imshow(batch_naip_pixels[idx, :3, ...].transpose(1, 2, 0))\n", " axs[i][j].set_axis_off()\n", " axs[i][j].set_title(f\"Image {idx}\")\n", " embed = unmsk_embed[idx]\n", - " axs[i+1][j].imshow(embed[embed_dim], cmap=\"gray\")\n", - " axs[i+1][j].set_axis_off()\n", - " axs[i+1][j].set_title(f\"Embed {idx}\")" + " axs[i + 1][j].imshow(embed[embed_dim], cmap=\"gray\")\n", + " axs[i + 1][j].set_axis_off()\n", + " axs[i + 1][j].set_title(f\"Embed {idx}\")" ] }, { @@ -704,8 +712,10 @@ "source": [ "fig, axs = plt.subplots(3, 8, figsize=(20, 8))\n", "\n", - "for idx,ax in enumerate(axs.flatten()):\n", - " ax.imshow(np.clip(batch_s2_pixels[idx, [2, 1, 0],...].transpose(1,2,0)/2000, 0, 1))\n", + "for idx, ax in enumerate(axs.flatten()):\n", + " ax.imshow(\n", + " np.clip(batch_s2_pixels[idx, [2, 1, 0], ...].transpose(1, 2, 0) / 2000, 0, 1)\n", + " )\n", " ax.set_axis_off()\n", " ax.set_title(idx)" ] @@ -717,10 +727,9 @@ "metadata": {}, "outputs": [], "source": [ - "unmsk_embed_s2 = rearrange(unmsk_patch_s2[:,1:,:].detach().cpu().numpy(), \n", - " \"b (h w) d-> b d h w\",\n", - " h=28,\n", - " w=28)" + "unmsk_embed_s2 = rearrange(\n", + " unmsk_patch_s2[:, 1:, :].detach().cpu().numpy(), \"b (h w) d-> b d h w\", h=28, w=28\n", + ")" ] }, { @@ -744,7 +753,7 @@ "embed_s2 = unmsk_embed_s2[8]\n", "fig, axs = plt.subplots(16, 16, figsize=(20, 20))\n", "\n", - "for idx,ax in enumerate(axs.flatten()):\n", + "for idx, ax in enumerate(axs.flatten()):\n", " ax.imshow(embed_s2[idx], cmap=\"bwr\")\n", " ax.set_axis_off()\n", " ax.set_title(idx)\n", @@ -774,14 +783,18 @@ "\n", "for i in range(0, 6, 2):\n", " for j in range(8):\n", - " idx = (i//2)*8+j\n", - " axs[i][j].imshow(np.clip(batch_s2_pixels[idx,[2, 1, 0],...].transpose(1,2,0)/2000, 0, 1))\n", + " idx = (i // 2) * 8 + j\n", + " axs[i][j].imshow(\n", + " np.clip(\n", + " batch_s2_pixels[idx, [2, 1, 0], ...].transpose(1, 2, 0) / 2000, 0, 1\n", + " )\n", + " )\n", " axs[i][j].set_axis_off()\n", " axs[i][j].set_title(f\"Image {idx}\")\n", " embed_s2 = unmsk_embed_s2[idx]\n", - " axs[i+1][j].imshow(embed_s2[embed_dim], cmap=\"gray\")\n", - " axs[i+1][j].set_axis_off()\n", - " axs[i+1][j].set_title(f\"Embed {idx}\")" + " axs[i + 1][j].imshow(embed_s2[embed_dim], cmap=\"gray\")\n", + " axs[i + 1][j].set_axis_off()\n", + " axs[i + 1][j].set_title(f\"Embed {idx}\")" ] }, {