Skip to content

Commit 953780e

Browse files
authored
Fix TransformationRobustness doc formatting & add missing RedirectedReLU forward docs
1 parent 07c9e60 commit 953780e

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

captum/optim/_param/image/transforms.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -1251,6 +1251,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
12511251
return self._center_crop(x)
12521252

12531253

1254+
# Define TransformationRobustness defaults externally for easier Sphinx docs formatting
1255+
_TR_TRANSLATE: List[int] = [4] * 10
1256+
_TR_SCALE: List[float] = [0.995**n for n in range(-5, 80)] + [
1257+
0.998**n for n in 2 * list(range(20, 40))
1258+
]
1259+
_TR_DEGREES: List[int] = (
1260+
list(range(-20, 20)) + list(range(-10, 10)) + list(range(-5, 5)) + 5 * [0]
1261+
)
1262+
1263+
12541264
class TransformationRobustness(nn.Module):
12551265
"""
12561266
This transform combines the standard transforms (:class:`.RandomSpatialJitter`,
@@ -1269,15 +1279,9 @@ class TransformationRobustness(nn.Module):
12691279
def __init__(
12701280
self,
12711281
padding_transform: Optional[nn.Module] = nn.ConstantPad2d(2, value=0.5),
1272-
translate: Optional[Union[int, List[int]]] = [4] * 10,
1273-
scale: Optional[NumSeqOrTensorOrProbDistType] = [
1274-
0.995**n for n in range(-5, 80)
1275-
]
1276-
+ [0.998**n for n in 2 * list(range(20, 40))],
1277-
degrees: Optional[NumSeqOrTensorOrProbDistType] = list(range(-20, 20))
1278-
+ list(range(-10, 10))
1279-
+ list(range(-5, 5))
1280-
+ 5 * [0],
1282+
translate: Optional[Union[int, List[int]]] = _TR_TRANSLATE,
1283+
scale: Optional[NumSeqOrTensorOrProbDistType] = _TR_SCALE,
1284+
degrees: Optional[NumSeqOrTensorOrProbDistType] = _TR_DEGREES,
12811285
final_translate: Optional[int] = 2,
12821286
crop_or_pad_output: bool = False,
12831287
) -> None:

captum/optim/_utils/circuits.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def extract_expanded_weights(
4444
Args:
4545
4646
model (nn.Module): The reference to PyTorch model instance.
47-
target1 (nn.module): The starting target layer. Must be below the layer
47+
target1 (nn.Module): The starting target layer. Must be below the layer
4848
specified for ``target2``.
4949
target2 (nn.Module): The end target layer. Must be above the layer
5050
specified for ``target1``.

captum/optim/models/_common.py

+8
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ class RedirectedReluLayer(nn.Module):
6868

6969
@torch.jit.ignore
7070
def forward(self, input: torch.Tensor) -> torch.Tensor:
71+
"""
72+
Args:
73+
74+
x (torch.Tensor): A tensor to pass through RedirectedReLU.
75+
76+
Returns:
77+
x (torch.Tensor): The output of RedirectedReLU.
78+
"""
7179
return RedirectedReLU.apply(input)
7280

7381

0 commit comments

Comments
 (0)