
Commit 57ea951

More doc improvements
1 parent 3ab53ae commit 57ea951

File tree

5 files changed: 79 additions, 40 deletions

captum/optim/_param/image/transforms.py

Lines changed: 33 additions & 23 deletions
@@ -20,7 +20,7 @@ def __init__(self, background: Optional[torch.Tensor] = None) -> None:
        """
        Args:

-           background (tensor, optional): An NCHW image tensor to be used as the
+           background (tensor, optional): An NCHW image tensor to be used as the
                Alpha channel's background.
                Default: ``None``
        """

@@ -36,7 +36,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            x (torch.Tensor): RGBA image tensor to blend into an RGB image tensor.

        Returns:
-           **blended** (torch.Tensor): RGB image tensor.
+           blended (torch.Tensor): RGB image tensor.
        """
        assert x.dim() == 4
        assert x.size(1) == 4

@@ -60,7 +60,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            x (torch.Tensor): RGBA image tensor.

        Returns:
-           **rgb** (torch.Tensor): RGB image tensor without the alpha channel.
+           rgb (torch.Tensor): RGB image tensor without the alpha channel.
        """
        assert x.dim() == 4
        assert x.size(1) == 4

@@ -101,7 +101,7 @@ def klt_transform() -> torch.Tensor:
        Karhunen-Loève transform (KLT) measured on ImageNet

        Returns:
-           **transform** (torch.Tensor): A Karhunen-Loève transform (KLT) measured on
+           transform (torch.Tensor): A Karhunen-Loève transform (KLT) measured on
                the ImageNet dataset.
        """
        # Handle older versions of PyTorch

@@ -120,7 +120,7 @@ def klt_transform() -> torch.Tensor:
    def i1i2i3_transform() -> torch.Tensor:
        """
        Returns:
-           **transform** (torch.Tensor): An approximation of natural colors transform
+           transform (torch.Tensor): An approximation of natural colors transform
                (i1i2i3).
        """
        i1i2i3_matrix = [

@@ -134,7 +134,7 @@ def __init__(self, transform: Union[str, torch.Tensor] = "klt") -> None:
        """
        Args:

-           transform (str or tensor): Either a string for one of the precalculated
+           transform (str or tensor): Either a string for one of the precalculated
                transform matrices, or a 3x3 matrix for the 3 RGB channels of input
                tensors.
        """

@@ -352,7 +352,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
            input (torch.Tensor): Input to center crop.

        Returns:
-           **tensor** (torch.Tensor): A center cropped *tensor*.
+           tensor (torch.Tensor): A center cropped NCHW tensor.
        """

        return center_crop(

@@ -402,7 +402,7 @@ def center_crop(
            Default: ``0.0``

    Returns:
-       **tensor** (torch.Tensor): A center cropped *tensor*.
+       tensor (torch.Tensor): A center cropped NCHW tensor.
    """

    assert input.dim() == 3 or input.dim() == 4

@@ -537,7 +537,7 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
            scale (float): The amount to scale the NCHW image by.

        Returns:
-           **x** (torch.Tensor): A scaled NCHW image tensor.
+           x (torch.Tensor): A scaled NCHW image tensor.
        """
        if self._has_antialias:
            x = F.interpolate(

@@ -567,7 +567,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            x (torch.Tensor): NCHW image tensor to randomly scale.

        Returns:
-           **x** (torch.Tensor): A randomly scaled NCHW image *tensor*.
+           x (torch.Tensor): A randomly scaled NCHW image tensor.
        """
        assert x.dim() == 4
        if self._is_distribution:

@@ -669,7 +669,7 @@ def _get_scale_mat(
            m (float): The scale value to use.

        Returns:
-           **scale_mat** (torch.Tensor): A scale matrix.
+           scale_mat (torch.Tensor): A scale matrix.
        """
        scale_mat = torch.tensor(
            [[m, 0.0, 0.0], [0.0, m, 0.0]], device=device, dtype=dtype

@@ -686,7 +686,7 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
            scale (float): The amount to scale the NCHW image by.

        Returns:
-           **x** (torch.Tensor): A scaled NCHW image tensor.
+           x (torch.Tensor): A scaled NCHW image tensor.
        """
        scale_matrix = self._get_scale_mat(scale, x.device, x.dtype)[None, ...].repeat(
            x.shape[0], 1, 1

@@ -710,7 +710,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            x (torch.Tensor): NCHW image tensor to randomly scale.

        Returns:
-           **x** (torch.Tensor): A randomly scaled NCHW image *tensor*.
+           x (torch.Tensor): A randomly scaled NCHW image tensor.
        """
        assert x.dim() == 4
        if self._is_distribution:

@@ -768,7 +768,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
            input (torch.Tensor): Input to randomly translate.

        Returns:
-           **tensor** (torch.Tensor): A randomly translated *tensor*.
+           tensor (torch.Tensor): A randomly translated NCHW tensor.
        """
        insets = torch.randint(
            high=self.pad_range,

@@ -854,7 +854,7 @@ def _get_rot_mat(
            theta (float): The rotation value in degrees.

        Returns:
-           **rot_mat** (torch.Tensor): A rotation matrix.
+           rot_mat (torch.Tensor): A rotation matrix.
        """
        theta = theta * math.pi / 180.0
        rot_mat = torch.tensor(

@@ -877,7 +877,7 @@ def _rotate_tensor(self, x: torch.Tensor, theta: float) -> torch.Tensor:
            theta (float): The amount to rotate the NCHW image, in degrees.

        Returns:
-           **x** (torch.Tensor): A rotated NCHW image tensor.
+           x (torch.Tensor): A rotated NCHW image tensor.
        """
        rot_matrix = self._get_rot_mat(theta, x.device, x.dtype)[None, ...].repeat(
            x.shape[0], 1, 1

@@ -901,7 +901,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            x (torch.Tensor): NCHW image tensor to randomly rotate.

        Returns:
-           **x** (torch.Tensor): A randomly rotated NCHW image *tensor*.
+           x (torch.Tensor): A randomly rotated NCHW image tensor.
        """
        assert x.dim() == 4
        if self._is_distribution:

@@ -933,7 +933,7 @@ def __init__(self, multiplier: float = 1.0) -> None:
        """
        Args:

-           multiplier (float, optional): A float value used to scale the input.
+           multiplier (float, optional): A float value used to scale the input.
        """
        super().__init__()
        self.multiplier = multiplier

@@ -947,7 +947,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            x (torch.Tensor): Input to scale values of.

        Returns:
-           **tensor** (torch.Tensor): tensor with it's values scaled.
+           tensor (torch.Tensor): tensor with its values scaled.
        """
        return x * self.multiplier

@@ -966,7 +966,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            x (torch.Tensor): RGB image tensor to convert to BGR.

        Returns:
-           **BGR tensor** (torch.Tensor): A BGR tensor.
+           BGR tensor (torch.Tensor): A BGR tensor.
        """
        assert x.dim() == 4
        assert x.size(1) == 3

@@ -1104,7 +1104,7 @@ def forward(
            x (torch.Tensor): Input to apply symmetric padding on.

        Returns:
-           **tensor** (torch.Tensor): Padded tensor.
+           tensor (torch.Tensor): Padded tensor.
        """
        ctx.padding = padding
        x_device = x.device

@@ -1127,7 +1127,7 @@ def backward(
            grad_output (torch.Tensor): Input to remove symmetric padding from.

        Returns:
-           **grad_input** (torch.Tensor): Unpadded tensor.
+           grad_input (torch.Tensor): Unpadded tensor.
        """
        grad_input = grad_output.clone()
        B, C, H, W = grad_input.size()

@@ -1166,7 +1166,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            x (torch.Tensor): Input to reduce channel dimensions on.

        Returns:
-           **3 channel RGB tensor** (torch.Tensor): RGB image tensor.
+           x (torch.Tensor): A 3 channel RGB image tensor.
        """
        assert x.dim() == 4
        return nchannels_to_rgb(x, self.warp)

@@ -1216,6 +1216,16 @@ def _center_crop(self, x: torch.Tensor) -> torch.Tensor:
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
+       """
+       Randomly crop an NCHW image tensor.
+
+       Args:
+
+           x (torch.Tensor): The NCHW image tensor to randomly crop.
+
+       Returns:
+           x (torch.Tensor): The randomly cropped NCHW image tensor.
+       """
        assert x.dim() == 4
        hs = int(math.ceil((x.shape[2] - self.crop_size[0]) / 2.0))
        ws = int(math.ceil((x.shape[3] - self.crop_size[1]) / 2.0))
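Since each transform above takes and returns an NCHW image tensor, they compose directly with ``torch.nn.Sequential``. A minimal sketch of such a pipeline — the class names and constructor arguments are assumptions inferred from the methods shown in this diff, not verbatim API:

    >>> import torch
    >>> # Hypothetical composition of the transforms documented above;
    >>> # argument values are placeholders, not recommended settings.
    >>> transform = torch.nn.Sequential(
    >>>     RandomScale(scale=(0.95, 1.0, 1.05)),
    >>>     RandomRotation(degrees=list(range(-5, 6))),
    >>>     CenterCrop(size=224),
    >>> )
    >>> x = torch.randn(1, 3, 256, 256)  # NCHW image tensor
    >>> out = transform(x)  # still NCHW after cropping, e.g. [1, 3, 224, 224]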

captum/optim/_utils/circuits.py

Lines changed: 18 additions & 0 deletions
@@ -20,6 +20,24 @@ def extract_expanded_weights(
    literally adjacent in a neural network, or where the weights aren’t directly
    represented in a single weight tensor.

+   Example::
+
+       >>> # Load InceptionV1 model with nonlinear layers replaced by
+       >>> # their linear equivalents
+       >>> linear_model = opt.models.googlenet(
+       >>>     pretrained=True, use_linear_modules_only=True
+       >>> ).eval()
+       >>> # Extract weight interactions between target layers
+       >>> W_3a_3b = opt.circuits.extract_expanded_weights(
+       >>>     linear_model, linear_model.mixed3a, linear_model.mixed3b, 5
+       >>> )
+       >>> # Display results for channel 147 of mixed3a and channel 379 of
+       >>> # mixed3b, in human readable format
+       >>> W_3a_3b_hm = opt.weights_to_heatmap_2d(
+       >>>     W_3a_3b[379, 147, ...] / W_3a_3b[379, ...].max()
+       >>> )
+       >>> opt.show(W_3a_3b_hm)
+
    Voss, et al., "Visualizing Weights", Distill, 2021.
    See: https://distill.pub/2020/circuits/visualizing-weights/

captum/optim/_utils/image/atlas.py

Lines changed: 16 additions & 14 deletions
@@ -146,13 +146,14 @@ def compute_avg_cell_samples(
            Default: ``8``

    Returns:
-       cell_vecs (torch.tensor): A tensor containing all the direction vectors that
-           were created, stacked along the batch dimension with a shape of:
-           [n_vecs, n_channels].
-       cell_coords (list of Tuple[int, int, int]): List of coordinates for grid
-           spatial positions of each direction vector, and the number of samples used
-           for the cell. The list for each cell is in the format of:
-           [x_coord, y_coord, number_of_samples_used].
+       cell_vecs_and_cell_coords: A 2 element tuple of: ``(cell_vecs, cell_coords)``.
+       - cell_vecs (torch.tensor): A tensor containing all the direction vectors
+           that were created, stacked along the batch dimension with a shape of:
+           [n_vecs, n_channels].
+       - cell_coords (list of Tuple[int, int, int]): List of coordinates for grid
+           spatial positions of each direction vector, and the number of samples
+           used for the cell. The list for each cell is in the format of:
+           [x_coord, y_coord, number_of_samples_used].
    """
    assert raw_samples.dim() == 2

@@ -205,13 +206,14 @@ def create_atlas_vectors(
            Default: ``(0.0, 1.0)``

    Returns:
-       grid_vecs (torch.tensor): A tensor containing all the direction vectors that
-           were created, stacked along the batch dimension, with a shape of:
-           [n_vecs, n_channels].
-       cell_coords (list of Tuple[int, int, int]): List of coordinates for grid
-           spatial positions of each direction vector, and the number of samples used
-           for the cell. The list for each cell is in the format of:
-           [x_coord, y_coord, number_of_samples_used].
+       grid_vecs_and_cell_coords: A 2 element tuple of: ``(grid_vecs, cell_coords)``.
+       - grid_vecs (torch.tensor): A tensor containing all the direction vectors
+           that were created, stacked along the batch dimension, with a shape
+           of: [n_vecs, n_channels].
+       - cell_coords (list of Tuple[int, int, int]): List of coordinates for grid
+           spatial positions of each direction vector, and the number of samples
+           used for the cell. The list for each cell is in the format of:
+           [x_coord, y_coord, number_of_samples_used].
    """

    assert xy_grid.dim() == 2 and xy_grid.size(1) == 2
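Both functions now document a two-element tuple, so callers unpack it directly. A minimal sketch for ``create_atlas_vectors`` — only ``xy_grid`` is named in this diff, so the other inputs and argument names here are assumptions:

    >>> import torch
    >>> # Hypothetical inputs: one (x, y) coordinate and one activation
    >>> # vector per sample.
    >>> xy_grid = torch.rand(1000, 2)
    >>> raw_activations = torch.randn(1000, 512)
    >>> grid_vecs, cell_coords = create_atlas_vectors(
    >>>     xy_grid, raw_activations, grid_size=8
    >>> )
    >>> print(grid_vecs.shape)  # [n_vecs, n_channels]
    >>> x_coord, y_coord, n_samples = cell_coords[0]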

captum/optim/_utils/image/common.py

Lines changed: 1 addition & 1 deletion
@@ -348,7 +348,7 @@ def weights_to_heatmap_2d(
            Default: ``["0571b0", "92c5de", "f7f7f7", "f4a582", "ca0020"]``

    Returns:
-       color_tensor (torch.Tensor): A weight heatmap.
+       color_tensor (torch.Tensor): A weight heatmap.
    """

    assert weight.dim() == 2
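For context, a minimal usage sketch mirroring the normalization used in the ``extract_expanded_weights`` example above; per the ``assert weight.dim() == 2`` shown here, the input must be a 2D weight tensor:

    >>> import torch
    >>> weight = torch.randn(5, 5)
    >>> # Scale into roughly [-1, 1] before mapping weights to colors.
    >>> color_tensor = weights_to_heatmap_2d(weight / weight.abs().max())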

captum/optim/_utils/reducer.py

Lines changed: 11 additions & 2 deletions
@@ -21,6 +21,14 @@ class ChannelReducer:

    See here for more information: https://distill.pub/2018/building-blocks/

+   Example::
+
+       >>> reducer = opt.reducer.ChannelReducer(2, "NMF")
+       >>> x = torch.randn(1, 8, 128, 128).abs()
+       >>> output = reducer.fit_transform(x)
+       >>> print(output.shape)
+       torch.Size([1, 2, 128, 128])
+
    Args:

        n_components (int, optional): The number of channels to reduce the target

@@ -30,7 +38,7 @@ class ChannelReducer:
            from sklearn, which requires users to put inputs on CPU before passing them
            to :func:`ChannelReducer.fit_transform`.
            Default: ``NMF``
-       **kwargs (optional): Arbitrary keyword arguments used by the specified
+       **kwargs (any, optional): Arbitrary keyword arguments used by the specified
            reduction_alg.
    """
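A short sketch of the clarified ``**kwargs`` behavior: extra keyword arguments are forwarded to the chosen ``reduction_alg``. Here ``max_iter`` is an assumed example of a parameter accepted by sklearn's ``NMF``:

    >>> # Forward max_iter to sklearn's NMF through **kwargs.
    >>> reducer = opt.reducer.ChannelReducer(2, "NMF", max_iter=500)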

@@ -72,7 +80,8 @@ def fit_transform(
        self, x: torch.Tensor, swap_2nd_and_last_dims: bool = True
    ) -> torch.Tensor:
        """
-       Perform dimensionality reduction on an input tensor.
+       Perform dimensionality reduction on an input tensor using the specified
+       ``reduction_alg``'s ``.fit_transform`` function.

        Args:
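A sketch of the flag in the signature above; the behavior is inferred from the flag's name and the NCHW example earlier, not stated in this diff:

    >>> # By default dim 1 (channels) is swapped to the last position so the
    >>> # reduction runs over channels; pass False to reduce over the
    >>> # tensor's existing last dimension instead.
    >>> output = reducer.fit_transform(x, swap_2nd_and_last_dims=False)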