@@ -20,7 +20,7 @@ def __init__(self, background: Optional[torch.Tensor] = None) -> None:
20
20
"""
21
21
Args:
22
22
23
- background (tensor, optional): An NCHW image tensor to be used as the
23
+ background (tensor, optional): An NCHW image tensor to be used as the
24
24
Alpha channel's background.
25
25
Default: ``None``
26
26
"""
@@ -36,7 +36,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
36
36
x (torch.Tensor): RGBA image tensor to blend into an RGB image tensor.
37
37
38
38
Returns:
39
- ** blended** (torch.Tensor): RGB image tensor.
39
+ blended (torch.Tensor): RGB image tensor.
40
40
"""
41
41
assert x .dim () == 4
42
42
assert x .size (1 ) == 4
@@ -60,7 +60,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
60
60
x (torch.Tensor): RGBA image tensor.
61
61
62
62
Returns:
63
- ** rgb** (torch.Tensor): RGB image tensor without the alpha channel.
63
+ rgb (torch.Tensor): RGB image tensor without the alpha channel.
64
64
"""
65
65
assert x .dim () == 4
66
66
assert x .size (1 ) == 4
@@ -101,7 +101,7 @@ def klt_transform() -> torch.Tensor:
101
101
Karhunen-Loève transform (KLT) measured on ImageNet
102
102
103
103
Returns:
104
- ** transform** (torch.Tensor): A Karhunen-Loève transform (KLT) measured on
104
+ transform (torch.Tensor): A Karhunen-Loève transform (KLT) measured on
105
105
the ImageNet dataset.
106
106
"""
107
107
# Handle older versions of PyTorch
@@ -120,7 +120,7 @@ def klt_transform() -> torch.Tensor:
120
120
def i1i2i3_transform () -> torch .Tensor :
121
121
"""
122
122
Returns:
123
- ** transform** (torch.Tensor): An approximation of natural colors transform
123
+ transform (torch.Tensor): An approximation of natural colors transform
124
124
(i1i2i3).
125
125
"""
126
126
i1i2i3_matrix = [
@@ -134,7 +134,7 @@ def __init__(self, transform: Union[str, torch.Tensor] = "klt") -> None:
134
134
"""
135
135
Args:
136
136
137
- transform (str or tensor): Either a string for one of the precalculated
137
+ transform (str or tensor): Either a string for one of the precalculated
138
138
transform matrices, or a 3x3 matrix for the 3 RGB channels of input
139
139
tensors.
140
140
"""
@@ -352,7 +352,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
352
352
input (torch.Tensor): Input to center crop.
353
353
354
354
Returns:
355
- ** tensor** (torch.Tensor): A center cropped * tensor* .
355
+ tensor (torch.Tensor): A center cropped NCHW tensor.
356
356
"""
357
357
358
358
return center_crop (
@@ -402,7 +402,7 @@ def center_crop(
402
402
Default: ``0.0``
403
403
404
404
Returns:
405
- ** tensor** (torch.Tensor): A center cropped * tensor* .
405
+ tensor (torch.Tensor): A center cropped NCHW tensor.
406
406
"""
407
407
408
408
assert input .dim () == 3 or input .dim () == 4
@@ -537,7 +537,7 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
537
537
scale (float): The amount to scale the NCHW image by.
538
538
539
539
Returns:
540
- **x** (torch.Tensor): A scaled NCHW image tensor.
540
+ x (torch.Tensor): A scaled NCHW image tensor.
541
541
"""
542
542
if self ._has_antialias :
543
543
x = F .interpolate (
@@ -567,7 +567,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
567
567
x (torch.Tensor): NCHW image tensor to randomly scale.
568
568
569
569
Returns:
570
- **x** (torch.Tensor): A randomly scaled NCHW image * tensor* .
570
+ x (torch.Tensor): A randomly scaled NCHW image tensor.
571
571
"""
572
572
assert x .dim () == 4
573
573
if self ._is_distribution :
@@ -669,7 +669,7 @@ def _get_scale_mat(
669
669
m (float): The scale value to use.
670
670
671
671
Returns:
672
- ** scale_mat** (torch.Tensor): A scale matrix.
672
+ scale_mat (torch.Tensor): A scale matrix.
673
673
"""
674
674
scale_mat = torch .tensor (
675
675
[[m , 0.0 , 0.0 ], [0.0 , m , 0.0 ]], device = device , dtype = dtype
@@ -686,7 +686,7 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
686
686
scale (float): The amount to scale the NCHW image by.
687
687
688
688
Returns:
689
- **x** (torch.Tensor): A scaled NCHW image tensor.
689
+ x (torch.Tensor): A scaled NCHW image tensor.
690
690
"""
691
691
scale_matrix = self ._get_scale_mat (scale , x .device , x .dtype )[None , ...].repeat (
692
692
x .shape [0 ], 1 , 1
@@ -710,7 +710,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
710
710
x (torch.Tensor): NCHW image tensor to randomly scale.
711
711
712
712
Returns:
713
- **x** (torch.Tensor): A randomly scaled NCHW image * tensor* .
713
+ x (torch.Tensor): A randomly scaled NCHW image tensor.
714
714
"""
715
715
assert x .dim () == 4
716
716
if self ._is_distribution :
@@ -768,7 +768,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
768
768
input (torch.Tensor): Input to randomly translate.
769
769
770
770
Returns:
771
- ** tensor** (torch.Tensor): A randomly translated * tensor* .
771
+ tensor (torch.Tensor): A randomly translated NCHW tensor.
772
772
"""
773
773
insets = torch .randint (
774
774
high = self .pad_range ,
@@ -854,7 +854,7 @@ def _get_rot_mat(
854
854
theta (float): The rotation value in degrees.
855
855
856
856
Returns:
857
- ** rot_mat** (torch.Tensor): A rotation matrix.
857
+ rot_mat (torch.Tensor): A rotation matrix.
858
858
"""
859
859
theta = theta * math .pi / 180.0
860
860
rot_mat = torch .tensor (
@@ -877,7 +877,7 @@ def _rotate_tensor(self, x: torch.Tensor, theta: float) -> torch.Tensor:
877
877
theta (float): The amount to rotate the NCHW image, in degrees.
878
878
879
879
Returns:
880
- **x** (torch.Tensor): A rotated NCHW image tensor.
880
+ x (torch.Tensor): A rotated NCHW image tensor.
881
881
"""
882
882
rot_matrix = self ._get_rot_mat (theta , x .device , x .dtype )[None , ...].repeat (
883
883
x .shape [0 ], 1 , 1
@@ -901,7 +901,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
901
901
x (torch.Tensor): NCHW image tensor to randomly rotate.
902
902
903
903
Returns:
904
- **x** (torch.Tensor): A randomly rotated NCHW image * tensor* .
904
+ x (torch.Tensor): A randomly rotated NCHW image tensor.
905
905
"""
906
906
assert x .dim () == 4
907
907
if self ._is_distribution :
@@ -933,7 +933,7 @@ def __init__(self, multiplier: float = 1.0) -> None:
933
933
"""
934
934
Args:
935
935
936
- multiplier (float, optional): A float value used to scale the input.
936
+ multiplier (float, optional): A float value used to scale the input.
937
937
"""
938
938
super ().__init__ ()
939
939
self .multiplier = multiplier
@@ -947,7 +947,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
947
947
x (torch.Tensor): Input to scale values of.
948
948
949
949
Returns:
950
- ** tensor** (torch.Tensor): tensor with it's values scaled.
950
+ tensor (torch.Tensor): tensor with it's values scaled.
951
951
"""
952
952
return x * self .multiplier
953
953
@@ -966,7 +966,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
966
966
x (torch.Tensor): RGB image tensor to convert to BGR.
967
967
968
968
Returns:
969
- ** BGR tensor** (torch.Tensor): A BGR tensor.
969
+ BGR tensor (torch.Tensor): A BGR tensor.
970
970
"""
971
971
assert x .dim () == 4
972
972
assert x .size (1 ) == 3
@@ -1104,7 +1104,7 @@ def forward(
1104
1104
x (torch.Tensor): Input to apply symmetric padding on.
1105
1105
1106
1106
Returns:
1107
- ** tensor** (torch.Tensor): Padded tensor.
1107
+ tensor (torch.Tensor): Padded tensor.
1108
1108
"""
1109
1109
ctx .padding = padding
1110
1110
x_device = x .device
@@ -1127,7 +1127,7 @@ def backward(
1127
1127
grad_output (torch.Tensor): Input to remove symmetric padding from.
1128
1128
1129
1129
Returns:
1130
- ** grad_input** (torch.Tensor): Unpadded tensor.
1130
+ grad_input (torch.Tensor): Unpadded tensor.
1131
1131
"""
1132
1132
grad_input = grad_output .clone ()
1133
1133
B , C , H , W = grad_input .size ()
@@ -1166,7 +1166,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
1166
1166
x (torch.Tensor): Input to reduce channel dimensions on.
1167
1167
1168
1168
Returns:
1169
- **3 channel RGB tensor** (torch.Tensor): RGB image tensor.
1169
+ x (torch.Tensor): A 3 channel RGB image tensor.
1170
1170
"""
1171
1171
assert x .dim () == 4
1172
1172
return nchannels_to_rgb (x , self .warp )
@@ -1216,6 +1216,16 @@ def _center_crop(self, x: torch.Tensor) -> torch.Tensor:
1216
1216
]
1217
1217
1218
1218
def forward (self , x : torch .Tensor ) -> torch .Tensor :
1219
+ """
1220
+ Randomly crop an NCHW image tensor.
1221
+
1222
+ Args:
1223
+
1224
+ x (torch.Tensor): The NCHW image tensor to randomly crop.
1225
+
1226
+ Returns
1227
+ x (torch.Tensor): The randomly cropped NCHW image tensor.
1228
+ """
1219
1229
assert x .dim () == 4
1220
1230
hs = int (math .ceil ((x .shape [2 ] - self .crop_size [0 ]) / 2.0 ))
1221
1231
ws = int (math .ceil ((x .shape [3 ] - self .crop_size [1 ]) / 2.0 ))
0 commit comments