
Commit 7b92c1a

Implement DINO w/ better logs
1 parent 1fe4fdf commit 7b92c1a

File tree

3 files changed: +336 -194 lines changed

src/callbacks_wandb.py

Lines changed: 128 additions & 175 deletions
@@ -1,32 +1,8 @@
-"""
-Lightning callback functions for logging to Weights & Biases.
+from itertools import islice
 
-Includes a way to visualize RGB images derived from the raw logits of a Masked
-Autoencoder's decoder during the validation loop. I.e. to see if the Vision
-Transformer model is learning how to do image reconstruction.
-
-Usage:
-
-```
-import lightning as L
-
-from src.callbacks_wandb import LogMAEReconstruction
-
-trainer = L.Trainer(
-    ...,
-    callbacks=[LogMAEReconstruction(num_samples=6)]
-)
-```
-
-References:
-- https://lightning.ai/docs/pytorch/2.1.0/common/trainer.html#callbacks
-- https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY
-- https://github.com/ashleve/lightning-hydra-template/blob/wandb-callbacks/src/callbacks/wandb_callbacks.py#L245
-"""
 import lightning as L
 import matplotlib.pyplot as plt
 import numpy as np
-import skimage
 import torch
 from einops import rearrange
 
@@ -59,180 +35,157 @@ def get_wandb_logger(trainer: L.Trainer) -> L.pytorch.loggers.WandbLogger:
     )
 
 
-class LogMAEReconstruction(L.Callback):
-    """
-    Logs reconstructed RGB images from a Masked Autoencoder's decoder to WandB.
-    """
-
-    def __init__(self, num_samples: int = 8):
-        """
-        Define how many sample images to log.
-
-        Parameters
-        ----------
-        num_samples : int
-            The number of RGB image samples to upload to WandB. Default is 8.
-        """
-        super().__init__()
-        self.num_samples: int = num_samples
-        self.ready: bool = False
-
-        if wandb is None:
-            raise ModuleNotFoundError(
-                "Package `wandb` is required to be installed to use this callback. "
-                "Please use `pip install wandb` or "
-                "`conda install -c conda-forge wandb` "
-                "to install the package"
-            )
-
-    def on_sanity_check_start(self, trainer, pl_module):
-        """
-        Don't execute callback before validation sanity checks are completed.
-        """
-        self.ready = False
-
-    def on_sanity_check_end(self, trainer, pl_module):
-        """
-        Start executing callback only after all validation sanity checks end.
-        """
-        self.ready = True
-
-    def on_validation_batch_end(
-        self,
-        trainer: L.Trainer,
-        pl_module: L.LightningModule,
-        outputs: dict[str, torch.Tensor],
-        batch: dict[str, torch.Tensor | list[str]],
-        batch_idx: int,
-    ) -> list:
-        """
-        Called in the validation loop at the start of every mini-batch.
-
-        Gather a sample of data from the first mini-batch, get the RGB bands,
-        apply histogram equalization to the image, and log it to WandB.
-        """
-        if self.ready and batch_idx == 0:  # only run on first mini-batch
-            with torch.inference_mode():
-                # Get WandB logger
-                self.logger = get_wandb_logger(trainer=trainer)
-
-                # Turn raw logits into reconstructed 512x512 images
-                patchified_pixel_values: torch.Tensor = outputs["logits"]
-                # assert patchified_pixel_values.shape == torch.Size([32, 64, 53248])
-                y_hat: torch.Tensor = pl_module.vit.unpatchify(
-                    patchified_pixel_values=patchified_pixel_values
-                )
-                # assert y_hat.shape == torch.Size([32, 13, 512, 512])
-
-                # Reshape tensors from channel-first to channel-last
-                x: torch.Tensor = torch.einsum(
-                    "bchw->bhwc", batch["image"][: self.num_samples]
-                )
-                y_hat: torch.Tensor = torch.einsum(
-                    "bchw->bhwc", y_hat[: self.num_samples]
-                )
-                # assert y_hat.shape == torch.Size([8, 512, 512, 13])
-                assert x.shape == y_hat.shape
-
-                # Plot original and reconstructed RGB images of Sentinel-2
-                rgb_original: np.ndarray = (
-                    x[:, :, :, [2, 1, 0]].cpu().to(dtype=torch.float32).numpy()
-                )
-                rgb_reconstruction: np.ndarray = (
-                    y_hat[:, :, :, [2, 1, 0]].cpu().to(dtype=torch.float32).numpy()
-                )
-
-                figures: list[wandb.Image] = []
-                for i in range(min(x.shape[0], self.num_samples)):
-                    img_original = wandb.Image(
-                        data_or_path=skimage.exposure.equalize_hist(
-                            image=rgb_original[i]
-                        ),
-                        caption=f"RGB Image {i}",
-                    )
-                    figures.append(img_original)
-
-                    img_reconstruction = wandb.Image(
-                        data_or_path=skimage.exposure.equalize_hist(
-                            image=rgb_reconstruction[i]
-                        ),
-                        caption=f"Reconstructed {i}",
-                    )
-                    figures.append(img_reconstruction)
-
-                # Upload figures to WandB
-                self.logger.experiment.log(data={"Examples": figures})
-
-            return figures
-
-
-class LogIntermediatePredictions(L.Callback):
+class LogDINOPredictions(L.Callback):
     """Visualize the model results at the end of every epoch."""
 
     def __init__(self):
         """
         Instantiates with wandb-logger.
         """
         super().__init__()
+        self.selected_image = None
 
-    def on_validation_end(
+    def on_validation_batch_end(
         self,
         trainer: L.Trainer,
         pl_module: L.LightningModule,
+        outputs,
+        batch,
+        batch_idx,
     ) -> None:
-        """
-        Called when the validation loop ends.
-        At the end of each epoch, takes the first batch from validation dataset
-        & logs the model predictions to wandb-logger for humans to interpret
-        how model evolves over time.
-        """
         with torch.no_grad():
             # Get WandB logger
             self.logger = get_wandb_logger(trainer=trainer)
 
-            # get the first batch from trainer
-            batch = next(iter(trainer.val_dataloaders))
+            if self.selected_image is None:
+                self.selected_image = self.select_image(trainer, pl_module)
+
+            # Update the model weights after batch accumulation and backpropagation
+            pl_module.student.load_state_dict(pl_module.teacher.state_dict())
+
+            if batch_idx % trainer.log_every_n_steps == 0:
+                self.log_images(trainer, pl_module)
+
+    def select_image(self, trainer, pl_module):
+        print("Selecting image with max variance")
+        batches = islice(iter(trainer.val_dataloaders), 2)
+        max_variance = -1
+        for ibatch in batches:
             batch = {
                 k: v.to(pl_module.device)
-                for k, v in batch.items()
+                for k, v in ibatch.items()
                 if isinstance(v, torch.Tensor)
             }
-            # ENCODER
-            (
-                encoded_unmasked_patches,
-                unmasked_indices,
-                masked_indices,
-                masked_matrix,
-            ) = pl_module.model.encoder(batch)
-
-            # DECODER
-            pixels = pl_module.model.decoder(
-                encoded_unmasked_patches, unmasked_indices, masked_indices
-            )
-            pixels = rearrange(
-                pixels,
-                "b c (h w) (p1 p2) -> b c (h p1) (w p2)",
-                h=pl_module.model.image_size // pl_module.model.patch_size,
-                p1=pl_module.model.patch_size,
+            images = batch["pixels"]  # Shape: [batch_size, channels, height, width]
+            variances = images.var(
+                dim=[1, 2, 3], keepdim=False
+            )  # Calculate variance across C, H, W dimensions
+            max_var_index = torch.argmax(variances).item()
+            if variances[max_var_index] > max_variance:
+                max_variance = variances[max_var_index]
+                self.selected_image = max_var_index
+        assert self.selected_image is not None
+        print(f"Selected image with max variance: {self.selected_image}")
+        return self.selected_image
+
+    def log_images(self, trainer, pl_module):
+        if self.selected_image >= trainer.val_dataloaders.batch_size:
+            batch = next(
+                islice(
+                    iter(trainer.val_dataloaders),
+                    self.selected_image // trainer.val_dataloaders.batch_size,
+                    None,
+                )
             )
+        else:
+            batch = next(iter(trainer.val_dataloaders))
 
-            assert pixels.shape == batch["pixels"].shape
+        batch = {
+            k: v.to(pl_module.device)
+            for k, v in batch.items()
+            if isinstance(v, torch.Tensor)
+        }
+        # ENCODER
+        (
+            encoded_unmasked_patches,
+            unmasked_indices,
+            masked_indices,
+            masked_matrix,
+        ) = pl_module.student.model.encoder(batch)
+
+        # DECODER
+        pixels = pl_module.student.model.decoder(
+            encoded_unmasked_patches, unmasked_indices, masked_indices
+        )
+        pixels = rearrange(
+            pixels,
+            "b c (h w) (p1 p2) -> b c (h p1) (w p2)",
+            h=pl_module.student.model.image_size // pl_module.student.model.patch_size,
+            p1=pl_module.student.model.patch_size,
+        )
 
-            n_rows = 2
-            n_cols = 8
+        assert pixels.shape == batch["pixels"].shape
+
+        band_groups = {
+            "rgb": (2, 1, 0),
+            "<rededge>": (3, 4, 5, 7),
+            "<ir>": (6, 8, 9),
+            "<sar>": (10, 11),
+            "dem": (12,),
+        }
+
+        n_rows, n_cols = (
+            3,
+            len(band_groups),
+        )  # Rows for Input, Prediction, Difference
+        fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5))
+
+        def normalize_img(img):
+            lower_percentile, upper_percentile = 1, 99
+            lower_bound = np.percentile(img, lower_percentile)
+            upper_bound = np.percentile(img, upper_percentile)
+            img_clipped = np.clip(img, lower_bound, upper_bound)
+            return (img_clipped - img_clipped.min()) / (
+                img_clipped.max() - img_clipped.min()
+            )
 
-            fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 4))
+        for col, (group_name, bands) in enumerate(band_groups.items()):
+            input_img = batch["pixels"][:, bands, :, :]
+            pred_img = pixels[:, bands, :, :]
+            input_img = (
+                input_img[self.selected_image].detach().cpu().numpy().transpose(1, 2, 0)
+            )
+            pred_img = (
+                pred_img[self.selected_image].detach().cpu().numpy().transpose(1, 2, 0)
+            )
 
-            for i in range(n_cols):
-                axs[0, i].imshow(
-                    batch["pixels"][i][0].detach().cpu().numpy(), cmap="viridis"
+            if group_name == "rgb":
+                # Normalize RGB images
+                input_norm = normalize_img(input_img)
+                pred_norm = normalize_img(pred_img)
+                # Calculate absolute difference for RGB
+                diff_rgb = np.abs(input_norm - pred_norm)
+            else:
+                # Calculate mean for non-RGB bands if necessary
+                input_mean = input_img.mean(axis=2) if input_img.ndim > 2 else input_img
+                pred_mean = pred_img.mean(axis=2) if pred_img.ndim > 2 else pred_img
+                # Normalize and calculate difference
+                input_norm = normalize_img(input_mean)
+                pred_norm = normalize_img(pred_mean)
+                diff_rgb = np.abs(input_norm - pred_norm)
+
+            axs[0, col].imshow(input_norm, cmap="gray" if group_name != "rgb" else None)
+            axs[1, col].imshow(pred_norm, cmap="gray" if group_name != "rgb" else None)
+            axs[2, col].imshow(diff_rgb, cmap="gray" if group_name != "rgb" else None)
+
+            for ax in axs[:, col]:
+                ax.set_title(
+                    f"""{group_name} {'Input' if ax == axs[0, col] else
+                    'Pred' if ax == axs[1, col] else
+                    'Diff'}"""
                 )
-                axs[0, i].set_title(f"Image {i}")
-                axs[0, i].axis("off")
-
-                axs[1, i].imshow(pixels[i][0].detach().cpu().numpy(), cmap="viridis")
-                axs[1, i].set_title(f"Preds {i}")
-                axs[1, i].axis("off")
+                ax.axis("off")
 
-            self.logger.experiment.log({"Images": wandb.Image(fig)})
-            plt.close(fig)
+        plt.tight_layout()
+        self.logger.experiment.log({"Images": wandb.Image(fig)})
+        plt.close(fig)
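For reference, the usage block removed from the old module docstring showed how these callbacks are wired into a Lightning `Trainer`. A minimal sketch of the equivalent setup for the new `LogDINOPredictions` callback (same import path as before; the `...` stands for the rest of the Trainer configuration, as in the original docstring):

```
import lightning as L

from src.callbacks_wandb import LogDINOPredictions

trainer = L.Trainer(
    ...,
    callbacks=[LogDINOPredictions()],
)
```

Unlike `LogMAEReconstruction(num_samples=6)`, the new callback takes no constructor arguments: it caches a single high-variance validation image in `select_image`, then re-logs the input/prediction/difference figure for each band group every `trainer.log_every_n_steps` validation batches.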
