Skip to content

Commit 0b5e880

Browse files
committed
implement gradient ascent
1 parent ab1bc95 commit 0b5e880

File tree

9 files changed

+1018
-7
lines changed

9 files changed

+1018
-7
lines changed

data/model_info.txt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
model layer layer_index rf_size xn num_units
2+
alexnet conv1 0 11 15 64
3+
alexnet conv2 3 51 63 192
4+
alexnet conv3 6 99 127 384
5+
alexnet conv4 8 131 159 256
6+
alexnet conv5 10 163 191 256
7+
vgg16 conv1 0 3 5 64
8+
vgg16 conv2 2 5 7 64
9+
vgg16 conv3 5 10 14 128
10+
vgg16 conv4 7 14 18 128
11+
vgg16 conv5 10 24 28 256
12+
vgg16 conv6 12 32 36 256
13+
vgg16 conv7 14 40 52 256
14+
vgg16 conv8 17 60 72 512
15+
vgg16 conv9 19 76 88 512
16+
vgg16 conv10 21 92 104 512
17+
vgg16 conv11 24 132 176 512
18+
vgg16 conv12 26 164 208 512
19+
vgg16 conv13 28 196 240 512
20+
resnet18 conv1 0 7 9 64
21+
resnet18 conv2 4 19 25 64
22+
resnet18 conv3 7 27 33 64
23+
resnet18 conv4 10 35 41 64
24+
resnet18 conv5 13 43 49 64
25+
resnet18 conv6 16 51 65 128
26+
resnet18 conv7 19 67 81 128
27+
resnet18 conv8 21 43 49 128
28+
resnet18 conv9 24 83 97 128
29+
resnet18 conv10 27 99 113 128
30+
resnet18 conv11 30 115 129 256
31+
resnet18 conv12 33 147 193 256
32+
resnet18 conv13 35 99 129 256
33+
resnet18 conv14 38 179 225 256
34+
resnet18 conv15 41 211 257 256
35+
resnet18 conv16 44 243 321 512
36+
resnet18 conv17 47 307 385 512
37+
resnet18 conv18 49 211 257 512
38+
resnet18 conv19 52 371 449 512
39+
resnet18 conv20 55 435 513 512

src/rf_mapping/compare/gt_vs_rfmp.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@
2626
Rfmp4aWeighted as W)
2727

2828
# Please specify the model
29-
# model = models.alexnet()
30-
# model_name = 'alexnet'
29+
model = models.alexnet()
30+
model_name = 'alexnet'
3131
# model = models.vgg16()
3232
# model_name = 'vgg16'
33-
model = models.resnet18()
34-
model_name = 'resnet18'
33+
# model = models.resnet18()
34+
# model_name = 'resnet18'
3535

3636
# Please specify what ground_truth method versus what RFMP4
3737
is_occlude = False
@@ -823,7 +823,7 @@ def make_radius3_pdf():
823823

824824

825825
if __name__ == '__main__':
826-
make_radius3_pdf()
826+
# make_radius3_pdf()
827827
pass
828828

829829

@@ -895,7 +895,7 @@ def make_ori_pdf():
895895
plt.close()
896896

897897
if __name__ == '__main__':
898-
# make_ori_pdf()
898+
make_ori_pdf()
899899
pass
900900

901901

@@ -1287,7 +1287,7 @@ def make_error_coords2_pdf():
12871287
plt.close()
12881288

12891289
if __name__ == '__main__':
1290-
make_error_coords2_pdf()
1290+
# make_error_coords2_pdf()
12911291
pass
12921292

12931293

src/rf_mapping/grad_ascent.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import torch
2+
import torch.optim as optim
3+
4+
__all__ = ['GradientAscent']
5+
6+
7+
class GradientAscent:
8+
def __init__(self, truncated_model: torch.nn.Module, unit_index: int, img: torch.Tensor,
9+
lr: float = 0.1, optimizer: str = 'SGD', momentum: bool = False):
10+
"""
11+
Performs gradient ascent on a given image to maximize the response of a specified unit in a neural network.
12+
13+
Args:
14+
truncated_model: The truncated neural network.
15+
unit_index: The index of the unit of interest.
16+
img: The starting image for optimization.
17+
lr: The learning rate for the optimizer.
18+
optimizer: The optimizer to use. Options: 'SGD', 'Adam'.
19+
momentum: Whether to use momentum with the optimizer.
20+
"""
21+
self.model = truncated_model
22+
self.unit_index = unit_index
23+
self.img = img.requires_grad_(True)
24+
self.optimizer = self._get_optimizer(optimizer, lr, momentum)
25+
26+
def _get_optimizer(self, optimizer_name: str, lr: float, momentum: bool) -> optim.Optimizer:
27+
if optimizer_name == 'Adam':
28+
return optim.Adam([self.img], lr=lr)
29+
elif optimizer_name == 'SGD':
30+
return optim.SGD([self.img], lr=lr, momentum=momentum)
31+
else:
32+
raise ValueError(f'Optimizer "{optimizer_name}" not supported')
33+
34+
def _objective_function(self, x: torch.Tensor) -> torch.Tensor:
35+
responses = self.model(x)
36+
num_images, num_units, ny, nx = responses.shape
37+
return responses[0, self.unit_index, ny//2, nx//2]
38+
39+
def step(self) -> torch.Tensor:
40+
"""
41+
Takes one optimization step and returns the updated image tensor.
42+
43+
Returns:
44+
The updated image tensor.
45+
"""
46+
self.optimizer.zero_grad()
47+
48+
# Need to put a negative sign because optimizer minimizes the "loss",
49+
# but this is an response, and we want to maximize it.
50+
response = -self._objective_function(self.img)
51+
52+
# Compute the gradient of the response with respect to the image.
53+
response.backward()
54+
55+
# Update the image using the optimizer.
56+
self.optimizer.step()
57+
58+
# Reset the gradient to zero.
59+
self.img.grad.zero_()
60+
61+
return self.img
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
"""
2+
Making gradient ascent visualization. Similar to make_zero_intialized_grad_ascent.py,
3+
but the images are not initialized with zeros. Instead, they are initialized by
4+
the image patches that give them the highest (most positive) responses.
5+
6+
Logs
7+
----
8+
Model Name: alexnet
9+
Optimization Method: SGD (no momentum)
10+
Device: Macbook Pro 16" Late 2021 with M1 Pro
11+
Time: 37 minutes
12+
Space: generates about 674 MB of data.
13+
14+
Tony Fu, Bair Lab, March 2023
15+
16+
"""
17+
18+
# #################################### GUARD ####################################
19+
20+
# confirmation = input("Are you sure you want to run this code? (Y/N)")
21+
# if confirmation.lower() == "y":
22+
# pass
23+
# else:
24+
# print("Code execution aborted.")
25+
26+
# ###############################################################################
27+
28+
29+
import os
30+
import sys
31+
import multiprocessing
32+
from typing import Tuple
33+
34+
import torch
35+
import numpy as np
36+
from torchvision import models
37+
from tqdm import tqdm
38+
import matplotlib.pyplot as plt
39+
40+
41+
sys.path.append('../../..')
42+
import src.rf_mapping.constants as c
43+
from src.rf_mapping.spatial import SpatialIndexConverter
44+
from src.rf_mapping.net import get_truncated_model
45+
from src.rf_mapping.image import normalize_img
46+
from src.rf_mapping.grad_ascent import GradientAscent
47+
from src.rf_mapping.model_utils import ModelInfo
48+
49+
50+
# Please specify some model details here:
51+
MODEL_NAME = "alexnet"
52+
53+
# Specify optimization method
54+
OPTIMIZATION_METHOD = 'SGD' # options: SGD and Adam
55+
NUM_ITER = 100
56+
LR = 0.1
57+
MOMENTUM = False
58+
59+
# Set the result directory
60+
RESULT_DIR = os.path.join(c.REPO_DIR, 'results', 'gradient_ascent','mapping', MODEL_NAME)
61+
62+
########################### DON'T TOUCH CODE BELOW ############################
63+
64+
# Load model and related information
65+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66+
MODEL = getattr(models, MODEL_NAME)(pretrained=True).to(DEVICE)
67+
MODEL_INFO = ModelInfo()
68+
LAYER_NAMES = MODEL_INFO.get_layer_names(MODEL_NAME)
69+
70+
# Get the image directory
71+
IMG_SIZE = (227, 227)
72+
NUM_TOP_IMAGES = 100
73+
74+
"""
75+
Obviously, I cannot upload the content of IMG_DIR to GitHub because it is too
76+
big. Here is some more information about the images at IMG_DIR so you can
77+
try it yourself:
78+
79+
- Source: a subset ImageNet's testing set
80+
- Number of image: 50000
81+
- Size: 3 x 227 x 227
82+
- Total size: about 62 GB
83+
- Format: .npy (NumPy arrays)
84+
- Note: The RGB values are batch-normalized to be roughly in the range of -1.0 to +1.0.
85+
- Where you can download it: http://wartburg.biostr.washington.edu/loc/course/artiphys/data/i50k.html
86+
"""
87+
88+
# Get the spatial indicies (for more info, see below)
89+
def get_max_min_indicies(layer_name):
90+
spatial_index_path = os.path.join(c.REPO_DIR, 'results', 'ground_truth', 'top_n',
91+
MODEL_NAME, f"{layer_name}.npy")
92+
return np.load(spatial_index_path).astype(int)
93+
"""
94+
I feel the need to explain what these spatial indices are all about. This was an idea that was developed by Dr. Wyeth Bair and his PhD student, Dr. Dean Pospisil. They presented the 50,000 images to the neural networks. As a reminder, the convolution (or more accurately, cross-correlation) operation slides the unit's kernel along the two spatial dimensions of the input. The first thing they did was to find the spatial locations that produced the maximum (most positive) responses. They repeated this process for all 50,000 images, and then ranked the resulting max locations to find out which image patches gave the strongest responses.
95+
96+
`MAX_MIN_INDICES` contains the results of this ranking for a particular convolutional layer of a model. The array has dimensions [num_units, 100, 4].
97+
98+
`num_units` represents the number of unique kernels in the convolutional layer. For example, Conv1 of AlexNet has 64 unique kernels. The second dimension `k` is 100 because the array stores the top and bottom 100 image patches. The last dimension has a size of 4 because it contains:
99+
(1) `max_img_idx`: the index (ranging from 0 to 49,999) of the k-th most positive response image.
100+
(2) `max_spatial_idx`: the spatial index of the kernel (not the pixel). The kernel is first slided along the x-axis, then y-axis. For instance, a spatial index of 0 corresponds to (0,0), and 1 to (0, 1). We can convert from 1D indexing to 2D using np.unravel_index(spatial_index, (output_height, output_width)).
101+
(3) `min_img_idx`: same as `max_img_idx`, but for the k-th most negative response image.
102+
(4) `min_spatial_idx`: same as `max_spatial_idx`, but for the k-th most negative response image patch.
103+
"""
104+
105+
# Initiate helper objects. This object converts the spatial index from the
106+
# output layer to that of the input layer (i.e., pixel coordinates).
107+
converter = SpatialIndexConverter(MODEL, IMG_SIZE)
108+
109+
##################### Define a few small helper functions #####################
110+
111+
def clip(x, min_value, max_value):
112+
return max(min(x, max_value), min_value)
113+
114+
def pad_box(box: Tuple[int, int, int, int], padding: int):
115+
"""Makes sure box does not go beyond the image after padding."""
116+
y_min, x_min, y_max, x_max = box
117+
new_y_min = clip(y_min-padding, 0, IMG_SIZE[0])
118+
new_x_min = clip(x_min-padding, 0, IMG_SIZE[1])
119+
new_y_max = clip(y_max+padding, 0, IMG_SIZE[0])
120+
new_x_max = clip(x_max+padding, 0, IMG_SIZE[1])
121+
return new_y_min, new_x_min, new_y_max, new_x_max
122+
123+
def process_tensor(img_tensor: torch.Tensor, normalize=True) -> np.ndarray:
124+
"""
125+
Converts a tensor to a Numpy array and normalize it to [0, 1].
126+
127+
Args:
128+
img_tensor: A tensor of shape (C, H, W).
129+
normalize: A flag that decides whether the output should be normalized or not.
130+
131+
Returns:
132+
A Numpy array of shape (H, W, C).
133+
134+
"""
135+
img_numpy = img_tensor.detach().cpu().numpy()
136+
img_numpy = np.squeeze(img_numpy)
137+
img_numpy = np.transpose(img_numpy, (1, 2, 0))
138+
139+
if normalize:
140+
# Normalizes pixel values to [0, 1.0]
141+
img_range = img_numpy.max() - img_numpy.min()
142+
if not np.isclose(img_range, 0, rtol=0, atol=1e-5):
143+
img_numpy = (img_numpy - img_numpy.min()) / img_range
144+
return img_numpy
145+
146+
147+
def one_sided_zero_pad(patch: np.ndarray, desired_size: int, box: Tuple[int, int, int, int]):
148+
"""
149+
Return original patch if it is the right size. Assumes that the patch
150+
given is always smaller or equal to the desired size. The box tells us
151+
the spatial location of the patch on the image.
152+
"""
153+
if len(patch.shape) != 3 or patch.shape[0] != 3:
154+
raise ValueError(f"patch must be have shape (3, height, width), but got {patch.shape}")
155+
if patch.shape[1] == desired_size and patch.shape[2] == desired_size:
156+
return patch
157+
158+
vx_min, hx_min, vx_max, hx_max = box
159+
touching_top_edge = (vx_min <= 0)
160+
touching_left_edge = (hx_min <= 0)
161+
162+
padded_patch = np.zeros((3, desired_size, desired_size))
163+
_, patch_h, patch_w = patch.shape
164+
165+
if touching_top_edge and touching_top_edge:
166+
padded_patch[:, -patch_h:, -patch_w:] = patch # fill from bottom right
167+
elif touching_top_edge:
168+
padded_patch[:, -patch_h:, :patch_w] = patch # fill from bottom left
169+
elif touching_left_edge:
170+
padded_patch[:, :patch_h, -patch_w:] = patch # fill from top right
171+
else:
172+
padded_patch[:, :patch_h, :patch_w] = patch # fill from top left
173+
174+
return padded_patch
175+
176+
###############################################################################
177+
178+
def create_visualizations_for_layer(layer_name):
179+
# Determine layer-specific information
180+
num_units = MODEL_INFO.get_num_units(MODEL_NAME, layer_name)
181+
layer_index = MODEL_INFO.get_layer_index(MODEL_NAME, layer_name)
182+
xn = MODEL_INFO.get_xn(MODEL_NAME, layer_name)
183+
rf_size = MODEL_INFO.get_rf_size(MODEL_NAME, layer_name)
184+
padding = (xn - rf_size) // 2
185+
186+
# Use the truncated model to save time
187+
truncated_model = get_truncated_model(MODEL, layer_index)
188+
189+
# Define the output directory, create it if necessary
190+
layer_dir = os.path.join(RESULT_DIR, layer_name)
191+
if not os.path.exists(layer_dir):
192+
os.makedirs(layer_dir)
193+
194+
# Find the top- and bottom-100 image patches ranking of the layer
195+
max_min_indicies = get_max_min_indicies(layer_name)
196+
197+
# We will also store the results in a numpy array
198+
result_array = np.zeros((num_units, xn, xn, 3))
199+
200+
for unit_index in tqdm(range(num_units)):
201+
for top_i in range(NUM_TOP_IMAGES):
202+
# Get top and bottom image indices and patch spatial indices
203+
max_n_img_index = max_min_indicies[unit_index, top_i, 0]
204+
max_n_patch_index = max_min_indicies[unit_index, top_i, 1]
205+
206+
# Convert from output spatial index to pixel coordinate
207+
box = converter.convert(max_n_patch_index, layer_index, 0, is_forward=False)
208+
209+
# Prevent indexing out of range
210+
y_min, x_min, y_max, x_max = pad_box(box, padding)
211+
212+
# Load the image
213+
img_path = os.path.join(c.IMG_DIR, f"{max_n_img_index}.npy")
214+
img_numpy = np.load(img_path)[:, y_min:y_max+1, x_min:x_max+1]
215+
216+
# Pad it to (3, xn, xn) if necessary
217+
img_numpy = one_sided_zero_pad(img_numpy, xn, (y_min, x_min, y_max, x_max))
218+
219+
# Convert to tensor
220+
img = torch.from_numpy(img_numpy).type('torch.FloatTensor').unsqueeze(0)
221+
img.requires_grad = True
222+
img.to(DEVICE)
223+
224+
# Computer gradient ascent
225+
ga = GradientAscent(truncated_model, unit_index, img, lr=LR,
226+
optimizer=OPTIMIZATION_METHOD, momentum=MOMENTUM)
227+
for _ in range(NUM_ITER - 1):
228+
ga.step()
229+
result_tensor = ga.step()
230+
231+
# Subtract the original image, then convert to numpy array
232+
result_npy = process_tensor(result_tensor, normalize=False) - img_numpy.transpose(1, 2, 0)
233+
234+
# Save result to an image
235+
result_array[unit_index] += result_npy
236+
237+
plt.imshow(normalize_img(result_array[unit_index]))
238+
plt.axis('off')
239+
plt.savefig(os.path.join(layer_dir, f"{unit_index}.png"))
240+
plt.close()
241+
242+
np.save(os.path.join(RESULT_DIR, f"{layer_name}.npy"), result_array)
243+
244+
if __name__ == '__main__':
245+
with multiprocessing.Pool(processes=len(LAYER_NAMES)) as pool:
246+
pool.map(create_visualizations_for_layer, LAYER_NAMES)

0 commit comments

Comments
 (0)