Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 27, 2024
1 parent 412343d commit 7d70240
Showing 1 changed file with 50 additions and 37 deletions.
87 changes: 50 additions & 37 deletions docs/visualize-embeddings.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@
"outputs": [],
"source": [
"# As we want to visualize the embeddings from the model, we neither mask the input image or shuffle the patches\n",
"module = ClayMAEModule.load_from_checkpoint(checkpoint_path=CHECKPOINT_PATH,\n",
" metadata_path=METADATA_PATH,\n",
" mask_ratio=0., \n",
" shuffle=False)\n",
"module = ClayMAEModule.load_from_checkpoint(\n",
" checkpoint_path=CHECKPOINT_PATH,\n",
" metadata_path=METADATA_PATH,\n",
" mask_ratio=0.0,\n",
" shuffle=False,\n",
")\n",
"\n",
"module.eval();"
]
Expand Down Expand Up @@ -103,15 +105,15 @@
}
],
"source": [
"# For model training, we stack chips from one sensor into batches of size 128. \n",
"# For model training, we stack chips from one sensor into batches of size 128.\n",
"# This reduces the num_workers we need to load the batches & speeds up the training process.\n",
"# Here, although the batch size is 1, the data module reads batch of size 128.\n",
"dm = ClayDataModule(\n",
" data_dir=DATA_DIR,\n",
" metadata_path=METADATA_PATH,\n",
" size=CHIP_SIZE,\n",
" batch_size=1,\n",
" num_workers=1\n",
" num_workers=1,\n",
")\n",
"dm.setup(stage=\"fit\")"
]
Expand Down Expand Up @@ -350,8 +352,15 @@
}
],
"source": [
"for sensor, chips in zip((\"l1\", \"l2\", \"linz\", \"naip\", \"s1\", \"s2\"), (l1, l2, linz, naip, s1, s2)):\n",
" print(f\"{chips['platform'][0]:<15}\", chips[\"pixels\"].shape, chips[\"time\"].shape, chips[\"latlon\"].shape)"
"for sensor, chips in zip(\n",
" (\"l1\", \"l2\", \"linz\", \"naip\", \"s1\", \"s2\"), (l1, l2, linz, naip, s1, s2)\n",
"):\n",
" print(\n",
" f\"{chips['platform'][0]:<15}\",\n",
" chips[\"pixels\"].shape,\n",
" chips[\"time\"].shape,\n",
" chips[\"latlon\"].shape,\n",
" )"
]
},
{
Expand Down Expand Up @@ -382,14 +391,14 @@
"source": [
"def create_batch(chips, wavelengths, gsd, device):\n",
" batch = {}\n",
" \n",
"\n",
" batch[\"pixels\"] = chips[\"pixels\"].to(device)\n",
" batch[\"time\"] = chips[\"time\"].to(device)\n",
" batch[\"latlon\"] = chips[\"latlon\"].to(device)\n",
" \n",
"\n",
" batch[\"waves\"] = torch.tensor(wavelengths)\n",
" batch[\"gsd\"] = torch.tensor(gsd)\n",
" \n",
"\n",
" return batch"
]
},
Expand Down Expand Up @@ -550,8 +559,8 @@
"source": [
"fig, axs = plt.subplots(3, 8, figsize=(20, 8))\n",
"\n",
"for idx,ax in enumerate(axs.flatten()):\n",
" ax.imshow(batch_naip_pixels[idx, :3,...].transpose(1,2,0))\n",
"for idx, ax in enumerate(axs.flatten()):\n",
" ax.imshow(batch_naip_pixels[idx, :3, ...].transpose(1, 2, 0))\n",
" ax.set_axis_off()\n",
" ax.set_title(idx)"
]
Expand All @@ -576,10 +585,9 @@
"metadata": {},
"outputs": [],
"source": [
"unmsk_embed = rearrange(unmsk_patch_naip[:,1:,:].detach().cpu().numpy(), \n",
" \"b (h w) d-> b d h w\",\n",
" h=28,\n",
" w=28)"
"unmsk_embed = rearrange(\n",
" unmsk_patch_naip[:, 1:, :].detach().cpu().numpy(), \"b (h w) d-> b d h w\", h=28, w=28\n",
")"
]
},
{
Expand Down Expand Up @@ -608,10 +616,10 @@
}
],
"source": [
"embed = unmsk_embed[3] # 3 is randomly picked chip\n",
"embed = unmsk_embed[3] # 3 is randomly picked chip\n",
"fig, axs = plt.subplots(16, 16, figsize=(20, 20))\n",
"\n",
"for idx,ax in enumerate(axs.flatten()):\n",
"for idx, ax in enumerate(axs.flatten()):\n",
" ax.imshow(embed[idx], cmap=\"bwr\")\n",
" ax.set_axis_off()\n",
" ax.set_title(idx)\n",
Expand Down Expand Up @@ -646,18 +654,18 @@
],
"source": [
"fig, axs = plt.subplots(6, 8, figsize=(20, 14))\n",
"embed_dim = 97 # pick any embedding dimension\n",
"embed_dim = 97 # pick any embedding dimension\n",
"\n",
"for i in range(0, 6, 2):\n",
" for j in range(8):\n",
" idx = (i//2)*8+j\n",
" axs[i][j].imshow(batch_naip_pixels[idx, :3,...].transpose(1,2,0))\n",
" idx = (i // 2) * 8 + j\n",
" axs[i][j].imshow(batch_naip_pixels[idx, :3, ...].transpose(1, 2, 0))\n",
" axs[i][j].set_axis_off()\n",
" axs[i][j].set_title(f\"Image {idx}\")\n",
" embed = unmsk_embed[idx]\n",
" axs[i+1][j].imshow(embed[embed_dim], cmap=\"gray\")\n",
" axs[i+1][j].set_axis_off()\n",
" axs[i+1][j].set_title(f\"Embed {idx}\")"
" axs[i + 1][j].imshow(embed[embed_dim], cmap=\"gray\")\n",
" axs[i + 1][j].set_axis_off()\n",
" axs[i + 1][j].set_title(f\"Embed {idx}\")"
]
},
{
Expand Down Expand Up @@ -704,8 +712,10 @@
"source": [
"fig, axs = plt.subplots(3, 8, figsize=(20, 8))\n",
"\n",
"for idx,ax in enumerate(axs.flatten()):\n",
" ax.imshow(np.clip(batch_s2_pixels[idx, [2, 1, 0],...].transpose(1,2,0)/2000, 0, 1))\n",
"for idx, ax in enumerate(axs.flatten()):\n",
" ax.imshow(\n",
" np.clip(batch_s2_pixels[idx, [2, 1, 0], ...].transpose(1, 2, 0) / 2000, 0, 1)\n",
" )\n",
" ax.set_axis_off()\n",
" ax.set_title(idx)"
]
Expand All @@ -717,10 +727,9 @@
"metadata": {},
"outputs": [],
"source": [
"unmsk_embed_s2 = rearrange(unmsk_patch_s2[:,1:,:].detach().cpu().numpy(), \n",
" \"b (h w) d-> b d h w\",\n",
" h=28,\n",
" w=28)"
"unmsk_embed_s2 = rearrange(\n",
" unmsk_patch_s2[:, 1:, :].detach().cpu().numpy(), \"b (h w) d-> b d h w\", h=28, w=28\n",
")"
]
},
{
Expand All @@ -744,7 +753,7 @@
"embed_s2 = unmsk_embed_s2[8]\n",
"fig, axs = plt.subplots(16, 16, figsize=(20, 20))\n",
"\n",
"for idx,ax in enumerate(axs.flatten()):\n",
"for idx, ax in enumerate(axs.flatten()):\n",
" ax.imshow(embed_s2[idx], cmap=\"bwr\")\n",
" ax.set_axis_off()\n",
" ax.set_title(idx)\n",
Expand Down Expand Up @@ -774,14 +783,18 @@
"\n",
"for i in range(0, 6, 2):\n",
" for j in range(8):\n",
" idx = (i//2)*8+j\n",
" axs[i][j].imshow(np.clip(batch_s2_pixels[idx,[2, 1, 0],...].transpose(1,2,0)/2000, 0, 1))\n",
" idx = (i // 2) * 8 + j\n",
" axs[i][j].imshow(\n",
" np.clip(\n",
" batch_s2_pixels[idx, [2, 1, 0], ...].transpose(1, 2, 0) / 2000, 0, 1\n",
" )\n",
" )\n",
" axs[i][j].set_axis_off()\n",
" axs[i][j].set_title(f\"Image {idx}\")\n",
" embed_s2 = unmsk_embed_s2[idx]\n",
" axs[i+1][j].imshow(embed_s2[embed_dim], cmap=\"gray\")\n",
" axs[i+1][j].set_axis_off()\n",
" axs[i+1][j].set_title(f\"Embed {idx}\")"
" axs[i + 1][j].imshow(embed_s2[embed_dim], cmap=\"gray\")\n",
" axs[i + 1][j].set_axis_off()\n",
" axs[i + 1][j].set_title(f\"Embed {idx}\")"
]
},
{
Expand Down

0 comments on commit 7d70240

Please sign in to comment.