@@ -1251,6 +1251,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
1251
1251
return self ._center_crop (x )
1252
1252
1253
1253
1254
+ # Define TransformationRobustness defaults externally for easier Sphinx docs formatting
1255
+ _TR_TRANSLATE : List [int ] = [4 ] * 10
1256
+ _TR_SCALE : List [float ] = [0.995 ** n for n in range (- 5 , 80 )] + [
1257
+ 0.998 ** n for n in 2 * list (range (20 , 40 ))
1258
+ ]
1259
+ _TR_DEGREES : List [int ] = (
1260
+ list (range (- 20 , 20 )) + list (range (- 10 , 10 )) + list (range (- 5 , 5 )) + 5 * [0 ]
1261
+ )
1262
+
1263
+
1254
1264
class TransformationRobustness (nn .Module ):
1255
1265
"""
1256
1266
This transform combines the standard transforms (:class:`.RandomSpatialJitter`,
@@ -1269,15 +1279,9 @@ class TransformationRobustness(nn.Module):
1269
1279
def __init__ (
1270
1280
self ,
1271
1281
padding_transform : Optional [nn .Module ] = nn .ConstantPad2d (2 , value = 0.5 ),
1272
- translate : Optional [Union [int , List [int ]]] = [4 ] * 10 ,
1273
- scale : Optional [NumSeqOrTensorOrProbDistType ] = [
1274
- 0.995 ** n for n in range (- 5 , 80 )
1275
- ]
1276
- + [0.998 ** n for n in 2 * list (range (20 , 40 ))],
1277
- degrees : Optional [NumSeqOrTensorOrProbDistType ] = list (range (- 20 , 20 ))
1278
- + list (range (- 10 , 10 ))
1279
- + list (range (- 5 , 5 ))
1280
- + 5 * [0 ],
1282
+ translate : Optional [Union [int , List [int ]]] = _TR_TRANSLATE ,
1283
+ scale : Optional [NumSeqOrTensorOrProbDistType ] = _TR_SCALE ,
1284
+ degrees : Optional [NumSeqOrTensorOrProbDistType ] = _TR_DEGREES ,
1281
1285
final_translate : Optional [int ] = 2 ,
1282
1286
crop_or_pad_output : bool = False ,
1283
1287
) -> None :
0 commit comments