@@ -1,32 +1,8 @@
-"""
-Lightning callback functions for logging to Weights & Biases.
+from itertools import islice
 
-Includes a way to visualize RGB images derived from the raw logits of a Masked
-Autoencoder's decoder during the validation loop. I.e. to see if the Vision
-Transformer model is learning how to do image reconstruction.
-
-Usage:
-
-```
-import lightning as L
-
-from src.callbacks_wandb import LogMAEReconstruction
-
-trainer = L.Trainer(
-    ...,
-    callbacks=[LogMAEReconstruction(num_samples=6)]
-)
-```
-
-References:
-- https://lightning.ai/docs/pytorch/2.1.0/common/trainer.html#callbacks
-- https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY
-- https://github.com/ashleve/lightning-hydra-template/blob/wandb-callbacks/src/callbacks/wandb_callbacks.py#L245
-"""
 import lightning as L
 import matplotlib.pyplot as plt
 import numpy as np
-import skimage
 import torch
 from einops import rearrange
 
@@ -59,180 +35,157 @@ def get_wandb_logger(trainer: L.Trainer) -> L.pytorch.loggers.WandbLogger:
|
59 | 35 | )
|
60 | 36 |
|
61 | 37 |
|
62 |
| -class LogMAEReconstruction(L.Callback): |
63 |
| - """ |
64 |
| - Logs reconstructed RGB images from a Masked Autoencoder's decoder to WandB. |
65 |
| - """ |
66 |
| - |
67 |
| - def __init__(self, num_samples: int = 8): |
68 |
| - """ |
69 |
| - Define how many sample images to log. |
70 |
| -
|
71 |
| - Parameters |
72 |
| - ---------- |
73 |
| - num_samples : int |
74 |
| - The number of RGB image samples to upload to WandB. Default is 8. |
75 |
| - """ |
76 |
| - super().__init__() |
77 |
| - self.num_samples: int = num_samples |
78 |
| - self.ready: bool = False |
79 |
| - |
80 |
| - if wandb is None: |
81 |
| - raise ModuleNotFoundError( |
82 |
| - "Package `wandb` is required to be installed to use this callback. " |
83 |
| - "Please use `pip install wandb` or " |
84 |
| - "`conda install -c conda-forge wandb` " |
85 |
| - "to install the package" |
86 |
| - ) |
87 |
| - |
88 |
| - def on_sanity_check_start(self, trainer, pl_module): |
89 |
| - """ |
90 |
| - Don't execute callback before validation sanity checks are completed. |
91 |
| - """ |
92 |
| - self.ready = False |
93 |
| - |
94 |
| - def on_sanity_check_end(self, trainer, pl_module): |
95 |
| - """ |
96 |
| - Start executing callback only after all validation sanity checks end. |
97 |
| - """ |
98 |
| - self.ready = True |
99 |
| - |
100 |
| - def on_validation_batch_end( |
101 |
| - self, |
102 |
| - trainer: L.Trainer, |
103 |
| - pl_module: L.LightningModule, |
104 |
| - outputs: dict[str, torch.Tensor], |
105 |
| - batch: dict[str, torch.Tensor | list[str]], |
106 |
| - batch_idx: int, |
107 |
| - ) -> list: |
108 |
| - """ |
109 |
| - Called in the validation loop at the start of every mini-batch. |
110 |
| -
|
111 |
| - Gather a sample of data from the first mini-batch, get the RGB bands, |
112 |
| - apply histogram equalization to the image, and log it to WandB. |
113 |
| - """ |
114 |
| - if self.ready and batch_idx == 0: # only run on first mini-batch |
115 |
| - with torch.inference_mode(): |
116 |
| - # Get WandB logger |
117 |
| - self.logger = get_wandb_logger(trainer=trainer) |
118 |
| - |
119 |
| - # Turn raw logits into reconstructed 512x512 images |
120 |
| - patchified_pixel_values: torch.Tensor = outputs["logits"] |
121 |
| - # assert patchified_pixel_values.shape == torch.Size([32, 64, 53248]) |
122 |
| - y_hat: torch.Tensor = pl_module.vit.unpatchify( |
123 |
| - patchified_pixel_values=patchified_pixel_values |
124 |
| - ) |
125 |
| - # assert y_hat.shape == torch.Size([32, 13, 512, 512]) |
126 |
| - |
127 |
| - # Reshape tensors from channel-first to channel-last |
128 |
| - x: torch.Tensor = torch.einsum( |
129 |
| - "bchw->bhwc", batch["image"][: self.num_samples] |
130 |
| - ) |
131 |
| - y_hat: torch.Tensor = torch.einsum( |
132 |
| - "bchw->bhwc", y_hat[: self.num_samples] |
133 |
| - ) |
134 |
| - # assert y_hat.shape == torch.Size([8, 512, 512, 13]) |
135 |
| - assert x.shape == y_hat.shape |
136 |
| - |
137 |
| - # Plot original and reconstructed RGB images of Sentinel-2 |
138 |
| - rgb_original: np.ndarray = ( |
139 |
| - x[:, :, :, [2, 1, 0]].cpu().to(dtype=torch.float32).numpy() |
140 |
| - ) |
141 |
| - rgb_reconstruction: np.ndarray = ( |
142 |
| - y_hat[:, :, :, [2, 1, 0]].cpu().to(dtype=torch.float32).numpy() |
143 |
| - ) |
144 |
| - |
145 |
| - figures: list[wandb.Image] = [] |
146 |
| - for i in range(min(x.shape[0], self.num_samples)): |
147 |
| - img_original = wandb.Image( |
148 |
| - data_or_path=skimage.exposure.equalize_hist( |
149 |
| - image=rgb_original[i] |
150 |
| - ), |
151 |
| - caption=f"RGB Image {i}", |
152 |
| - ) |
153 |
| - figures.append(img_original) |
154 |
| - |
155 |
| - img_reconstruction = wandb.Image( |
156 |
| - data_or_path=skimage.exposure.equalize_hist( |
157 |
| - image=rgb_reconstruction[i] |
158 |
| - ), |
159 |
| - caption=f"Reconstructed {i}", |
160 |
| - ) |
161 |
| - figures.append(img_reconstruction) |
162 |
| - |
163 |
| - # Upload figures to WandB |
164 |
| - self.logger.experiment.log(data={"Examples": figures}) |
165 |
| - |
166 |
| - return figures |
167 |
| - |
168 |
| - |
169 |
| -class LogIntermediatePredictions(L.Callback): |
+class LogDINOPredictions(L.Callback):
     """Visualize the model results at the end of every epoch."""
 
     def __init__(self):
         """
         Instantiates with wandb-logger.
         """
         super().__init__()
+        self.selected_image = None
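+        # Index of the validation sample to visualize; set lazily by
+        # select_image() on the first validation batch.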
 
-    def on_validation_end(
+    def on_validation_batch_end(
         self,
         trainer: L.Trainer,
         pl_module: L.LightningModule,
+        outputs,
+        batch,
+        batch_idx,
     ) -> None:
-        """
-        Called when the validation loop ends.
-        At the end of each epoch, takes the first batch from validation dataset
-        & logs the model predictions to wandb-logger for humans to interpret
-        how model evolves over time.
-        """
         with torch.no_grad():
             # Get WandB logger
             self.logger = get_wandb_logger(trainer=trainer)
 
-            # get the first batch from trainer
-            batch = next(iter(trainer.val_dataloaders))
+            if self.selected_image is None:
+                self.selected_image = self.select_image(trainer, pl_module)
+
+            # Sync the student weights from the teacher after batch accumulation and backpropagation
+            pl_module.student.load_state_dict(pl_module.teacher.state_dict())
+
+            if batch_idx % trainer.log_every_n_steps == 0:
+                self.log_images(trainer, pl_module)
+
+    def select_image(self, trainer, pl_module):
+        print("Selecting image with max variance")
+        batches = islice(iter(trainer.val_dataloaders), 2)
+        max_variance = -1
+        for ibatch in batches:
             batch = {
                 k: v.to(pl_module.device)
-                for k, v in batch.items()
+                for k, v in ibatch.items()
                 if isinstance(v, torch.Tensor)
             }
-            # ENCODER
-            (
-                encoded_unmasked_patches,
-                unmasked_indices,
-                masked_indices,
-                masked_matrix,
-            ) = pl_module.model.encoder(batch)
-
-            # DECODER
-            pixels = pl_module.model.decoder(
-                encoded_unmasked_patches, unmasked_indices, masked_indices
-            )
-            pixels = rearrange(
-                pixels,
-                "b c (h w) (p1 p2) -> b c (h p1) (w p2)",
-                h=pl_module.model.image_size // pl_module.model.patch_size,
-                p1=pl_module.model.patch_size,
+            images = batch["pixels"]  # Shape: [batch_size, channels, height, width]
+            variances = images.var(
+                dim=[1, 2, 3], keepdim=False
+            )  # Calculate variance across C, H, W dimensions
+            max_var_index = torch.argmax(variances).item()
+            if variances[max_var_index] > max_variance:
+                max_variance = variances[max_var_index]
+                self.selected_image = max_var_index
+        assert self.selected_image is not None
+        print(f"Selected image with max variance: {self.selected_image}")
+        return self.selected_image
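+
+    # Worked example with hypothetical shapes: for a batch whose
+    # batch["pixels"] is [32, 13, 512, 512], images.var(dim=[1, 2, 3])
+    # yields 32 per-sample variances; torch.argmax then picks the most
+    # varied chip among the first two validation batches scanned above.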
+
+    def log_images(self, trainer, pl_module):
+        if self.selected_image >= trainer.val_dataloaders.batch_size:
+            batch = next(
+                islice(
+                    iter(trainer.val_dataloaders),
+                    self.selected_image // trainer.val_dataloaders.batch_size,
+                    None,
+                )
             )
+        else:
+            batch = next(iter(trainer.val_dataloaders))
 
-            assert pixels.shape == batch["pixels"].shape
+        batch = {
+            k: v.to(pl_module.device)
+            for k, v in batch.items()
+            if isinstance(v, torch.Tensor)
+        }
+        # ENCODER
+        (
+            encoded_unmasked_patches,
+            unmasked_indices,
+            masked_indices,
+            masked_matrix,
+        ) = pl_module.student.model.encoder(batch)
+
+        # DECODER
+        pixels = pl_module.student.model.decoder(
+            encoded_unmasked_patches, unmasked_indices, masked_indices
+        )
+        pixels = rearrange(
+            pixels,
+            "b c (h w) (p1 p2) -> b c (h p1) (w p2)",
+            h=pl_module.student.model.image_size // pl_module.student.model.patch_size,
+            p1=pl_module.student.model.patch_size,
+        )
 
-            n_rows = 2
-            n_cols = 8
+        assert pixels.shape == batch["pixels"].shape
+
+        band_groups = {
+            "rgb": (2, 1, 0),
+            "<rededge>": (3, 4, 5, 7),
+            "<ir>": (6, 8, 9),
+            "<sar>": (10, 11),
+            "dem": (12,),
+        }
+
+        n_rows, n_cols = (
+            3,
+            len(band_groups),
+        )  # Rows for Input, Prediction, Difference
+        fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5))
+
+        def normalize_img(img):
+            lower_percentile, upper_percentile = 1, 99
+            lower_bound = np.percentile(img, lower_percentile)
+            upper_bound = np.percentile(img, upper_percentile)
+            img_clipped = np.clip(img, lower_bound, upper_bound)
+            return (img_clipped - img_clipped.min()) / (
+                img_clipped.max() - img_clipped.min()
+            )
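+        # E.g. with values 0..100, the (1, 99) percentile clip maps ~[1, 99]
+        # onto [0, 1], so a handful of extreme pixels (common in SAR and DEM
+        # bands) cannot compress the display range of everything else.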
 
-            fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 4))
+        for col, (group_name, bands) in enumerate(band_groups.items()):
+            input_img = batch["pixels"][:, bands, :, :]
+            pred_img = pixels[:, bands, :, :]
+            input_img = (
+                input_img[self.selected_image].detach().cpu().numpy().transpose(1, 2, 0)
+            )
+            pred_img = (
+                pred_img[self.selected_image].detach().cpu().numpy().transpose(1, 2, 0)
+            )
 
-            for i in range(n_cols):
-                axs[0, i].imshow(
-                    batch["pixels"][i][0].detach().cpu().numpy(), cmap="viridis"
+            if group_name == "rgb":
+                # Normalize RGB images
+                input_norm = normalize_img(input_img)
+                pred_norm = normalize_img(pred_img)
+                # Calculate absolute difference for RGB
+                diff_rgb = np.abs(input_norm - pred_norm)
+            else:
+                # Calculate mean for non-RGB bands if necessary
+                input_mean = input_img.mean(axis=2) if input_img.ndim > 2 else input_img
+                pred_mean = pred_img.mean(axis=2) if pred_img.ndim > 2 else pred_img
+                # Normalize and calculate difference
+                input_norm = normalize_img(input_mean)
+                pred_norm = normalize_img(pred_mean)
+                diff_rgb = np.abs(input_norm - pred_norm)
+
+            axs[0, col].imshow(input_norm, cmap="gray" if group_name != "rgb" else None)
+            axs[1, col].imshow(pred_norm, cmap="gray" if group_name != "rgb" else None)
+            axs[2, col].imshow(diff_rgb, cmap="gray" if group_name != "rgb" else None)
+
+            for ax in axs[:, col]:
+                ax.set_title(
+                    f"""{group_name} {'Input' if ax == axs[0, col] else
+                    'Pred' if ax == axs[1, col] else
+                    'Diff'}"""
                 )
-                axs[0, i].set_title(f"Image {i}")
-                axs[0, i].axis("off")
-
-                axs[1, i].imshow(pixels[i][0].detach().cpu().numpy(), cmap="viridis")
-                axs[1, i].set_title(f"Preds {i}")
-                axs[1, i].axis("off")
+                ax.axis("off")
 
-            self.logger.experiment.log({"Images": wandb.Image(fig)})
-            plt.close(fig)
+        plt.tight_layout()
+        self.logger.experiment.log({"Images": wandb.Image(fig)})
+        plt.close(fig)
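
The usage block removed from the module docstring carries over almost unchanged to the new callback. A minimal sketch, assuming the import path stays `src.callbacks_wandb`; the `log_every_n_steps=50` value is arbitrary and only gates how often `log_images` fires:

```python
import lightning as L

from src.callbacks_wandb import LogDINOPredictions

trainer = L.Trainer(
    ...,
    callbacks=[LogDINOPredictions()],
    log_every_n_steps=50,
)
```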
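A self-contained check of the `rearrange` pattern that stitches the decoder's patchified output back into images, with made-up sizes standing in for the model's real `image_size` and `patch_size`:

```python
import torch
from einops import rearrange

b, c, h, p = 2, 13, 4, 8  # batch, bands, patches per side, patch size (made up)
patches = torch.randn(b, c, h * h, p * p)  # layout "b c (h w) (p1 p2)"
image = rearrange(patches, "b c (h w) (p1 p2) -> b c (h p1) (w p2)", h=h, p1=p)
assert image.shape == (b, c, h * p, h * p)  # [2, 13, 32, 32]
```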