
Commit bb3ff16

fix: support masked_scatter by lowering path and corner case of masked_scatter

1 parent 79083b6

5 files changed, +178 -1 lines

py/torch_tensorrt/dynamo/conversion/converter_utils.py (+43)
@@ -344,6 +344,10 @@ def create_constant(
     with unset_fake_temporarily():

         torch_value = to_torch(value, dtype)
+        if torch_value is None:
+            raise ValueError(
+                f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None."
+            )
         if torch_value.dtype == torch.float64:
             raise ValueError(
                 "TensorRT does not support float64 (double) precision. To resolve this, please set truncate_double=True in your compilation settings and re-run the model."
@@ -1065,3 +1069,42 @@ def load_tensorrt_llm() -> bool:
10651069
)
10661070
return False
10671071
return False
1072+
1073+
1074+
def promote_trt_tensors_to_same_dtype(
1075+
ctx: ConversionContext, lhs: TRTTensor, rhs: TRTTensor, name_prefix: str
1076+
) -> tuple[TRTTensor, TRTTensor]:
1077+
"""
1078+
Promotes two TensorRT tensors to a common data type to ensure type compatibility
1079+
during operations (e.g., select, where, etc.), following simplified PyTorch promotion rules.
1080+
1081+
Args:
1082+
ctx: Conversion context containing the TRT network definition.
1083+
lhs: The left-hand-side TensorRT tensor.
1084+
rhs: The right-hand-side TensorRT tensor.
1085+
name_prefix: A prefix string used to name any cast operations.
1086+
1087+
Returns:
1088+
A tuple of (lhs_cast, rhs_cast) TensorRT tensors, both cast to the promoted dtype.
1089+
"""
1090+
1091+
# Define supported float types (TensorRT supports float16 and float32)
1092+
float_types = {trt.float16, trt.float32}
1093+
1094+
# Case 1: If either tensor is a float, promote to the wider float type
1095+
if lhs.dtype in float_types or rhs.dtype in float_types:
1096+
# Prefer float32 if either tensor is float32
1097+
if lhs.dtype == trt.float32 or rhs.dtype == trt.float32:
1098+
promoted_dtype = trt.float32
1099+
else:
1100+
promoted_dtype = trt.float16
1101+
else:
1102+
# Case 2: If both tensors are int types (e.g., int32, int64), promote to int32
1103+
# (Note: TensorRT does not support int64 for many ops like select/where)
1104+
promoted_dtype = trt.int32
1105+
1106+
# Cast both tensors to the promoted dtype
1107+
lhs_cast = cast_trt_tensor(ctx, lhs, promoted_dtype, f"{name_prefix}lhs_cast")
1108+
rhs_cast = cast_trt_tensor(ctx, rhs, promoted_dtype, f"{name_prefix}rhs_cast")
1109+
1110+
return lhs_cast, rhs_cast
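
For reference, the simplified promotion rules above can be mirrored in eager PyTorch. The following standalone sketch is not part of the commit and uses plain torch dtypes instead of TRT dtypes; it only illustrates the intended behavior:

import torch

# Minimal sketch of the promotion rules implemented by
# promote_trt_tensors_to_same_dtype, expressed over torch dtypes.
def promoted_dtype(lhs: torch.dtype, rhs: torch.dtype) -> torch.dtype:
    float_types = {torch.float16, torch.float32}
    if lhs in float_types or rhs in float_types:
        # Prefer float32 if either side is float32, otherwise stay in float16
        return torch.float32 if torch.float32 in (lhs, rhs) else torch.float16
    # Both sides are integer types: fold to int32, since TRT select/where
    # do not accept int64 operands
    return torch.int32

assert promoted_dtype(torch.float16, torch.float32) == torch.float32
assert promoted_dtype(torch.int64, torch.float16) == torch.float16
assert promoted_dtype(torch.int64, torch.int32) == torch.int32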

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py (+4)
@@ -11,6 +11,7 @@
     cast_trt_tensor,
     get_trt_tensor,
     prepend_ones,
+    promote_trt_tensors_to_same_dtype,
     set_layer_name,
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise import ne
@@ -57,6 +58,9 @@ def where(
     if diff > 0:
         other = prepend_ones(ctx, other, f"{name}_other_broadcast", diff)

+    # Ensure that input and other have the same TRT dtype
+    input, other = promote_trt_tensors_to_same_dtype(ctx, input, other, name)
+
     return select(ctx, target, source_ir, name, input, other, condition)

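For context, eager PyTorch promotes mismatched value dtypes in `torch.where` automatically, whereas the TRT select layer requires both branches to share a dtype. A small eager-mode illustration (not part of the commit) of the behavior the new cast reproduces:

import torch

# In eager mode the two value branches are promoted automatically;
# the converter now inserts explicit casts so TRT sees matching dtypes.
cond = torch.tensor([True, False, True])
a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
b = torch.tensor([10, 20, 30], dtype=torch.int64)

out = torch.where(cond, a, b)
print(out.dtype)  # torch.float16 under PyTorch's promotion rules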

py/torch_tensorrt/dynamo/conversion/impl/select.py (+1)
@@ -462,6 +462,7 @@ def gather(
 ) -> TRTTensor:
     input_shape = input.shape
     dim = get_positive_dim(dim, len(input_shape))
+    index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
     gather_layer = ctx.net.add_gather(input, index, axis=dim)
     gather_layer.mode = trt.GatherMode.ELEMENT
     set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
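
For context, the elementwise gather converted here corresponds to `torch.gather`, whose indices are int64 in PyTorch while TensorRT's gather layer expects int32, hence the added cast. A minimal eager-mode sketch of the op being converted (not part of the commit):

import torch

# torch.gather with elementwise semantics: out[i][j] = src[i][index[i][j]]
# for dim=1. PyTorch indices are int64; the TRT converter casts them to int32.
src = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
index = torch.tensor([[0, 0], [1, 0]])  # int64 by default

out = torch.gather(src, dim=1, index=index)
print(out)  # tensor([[1., 1.], [4., 3.]])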

py/torch_tensorrt/dynamo/lowering/_decompositions.py (+50 -1)
@@ -196,8 +196,10 @@ def slice_scatter_decomposition(
 ) -> torch.Tensor:
     dim_size = input_tensor.shape[dim]
     device_input_tensor = input_tensor.device
+
+    start = 0 if start is None else start  # Ensure start is int
     start = get_positive_dim(start, input_tensor.shape[dim])
-    if end is None:
+    if end is None:  # Ensure end is int
         end = dim_size
     end = get_positive_dim(end, input_tensor.shape[dim])
     if step is None:
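
For context, `start` and `end` may legitimately arrive as `None` in the aten schema, which is the case the added default handles. A quick eager-mode illustration (not part of the commit):

import torch

# start=None / end=None mean "full extent of the dimension"; the decomposition
# now normalizes start to 0 before passing it to get_positive_dim.
x = torch.zeros(4, 6)
src = torch.ones(4, 6)

out = torch.slice_scatter(x, src, dim=1, start=None, end=None, step=1)
assert torch.equal(out, src)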
@@ -575,6 +577,53 @@ def cudnn_grid_sampler_decomposition(
     return torch.grid_sampler_2d(x, grid, 0, 0, True)


+@register_torch_trt_decomposition(
+    aten.masked_scatter, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def masked_scatter_decomposition(
+    input: torch.Tensor,
+    mask: torch.Tensor,
+    source: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Decomposition of `aten.masked_scatter` for TensorRT.
+
+    Emulates the behavior of `input[mask] = source` using only TensorRT-compatible ops.
+
+    Steps:
+        1) Broadcast `input` and `mask` to a common shape.
+        2) Flatten all tensors for uniform indexing.
+        3) Compute gather indices for `source` by applying cumsum to the boolean mask.
+           - Use `masked_fill` to avoid invalid indices in positions where `mask` is False.
+        4) Gather values from `source` at valid positions.
+        5) Use `torch.where` to insert gathered values into `input` where `mask` is True.
+        6) Reshape the result back to the original broadcasted shape.
+    """
+
+    # 1) Broadcast input and mask to the same shape
+    input_b, mask_b = aten.broadcast_tensors([input, mask])
+
+    # 2) Flatten tensors for element-wise operations
+    input_flat = input_b.flatten()
+    mask_flat = mask_b.flatten()
+    source_flat = source.flatten()
+
+    # 3) Compute gather indices from cumsum of the mask
+    #    Subtract 1 so that the first True position maps to index 0 in source
+    source_idx = mask_flat.cumsum(0) - 1
+    # Set gather index to 0 where mask is False (these will be ignored later)
+    safe_idx = source_idx.masked_fill(~mask_flat, 0)
+
+    # 4) Gather values from source using computed indices
+    gathered = source_flat.gather(0, safe_idx)
+
+    # 5) Replace masked positions in input with gathered values
+    replaced = torch.where(mask_flat, gathered, input_flat)
+
+    # 6) Reshape the result to match the original broadcasted shape
+    return replaced.view(input_b.shape)
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
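
As a sanity check, the decomposition above can be mirrored in plain eager PyTorch and compared against the native op. This standalone sketch is not part of the commit; it follows the same six steps:

import torch

def masked_scatter_reference(input, mask, source):
    # 1-2) Broadcast input/mask, then flatten everything
    input_b, mask_b = torch.broadcast_tensors(input, mask)
    input_flat, mask_flat = input_b.flatten(), mask_b.flatten()
    source_flat = source.flatten()
    # 3) Running index into source; clamp False positions to a safe index 0
    source_idx = mask_flat.cumsum(0) - 1
    safe_idx = source_idx.masked_fill(~mask_flat, 0)
    # 4-6) Gather, select, and restore the broadcast shape
    gathered = source_flat.gather(0, safe_idx)
    return torch.where(mask_flat, gathered, input_flat).view(input_b.shape)

x = torch.randn(4, 4)
mask = torch.rand(4, 4) > 0.5
source = torch.arange(mask.numel(), dtype=x.dtype)  # enough elements for any mask

assert torch.equal(masked_scatter_reference(x, mask, source),
                   x.masked_scatter(mask, source))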

tests/py/dynamo/lowering/test_decompositions.py (+80)
@@ -2167,6 +2167,86 @@ def forward(self, x, grid):
             msg="Cudnn_grid_sampler TRT outputs don't match with the original model.",
         )

+    @parameterized.expand(
+        [
+            ("float32_2d", torch.float32, (4, 4)),
+            ("float16_3d", torch.float16, (2, 3, 4)),
+        ]
+    )
+    def test_masked_scatter(self, _, dtype, shape):
+        """
+        Test that masked_scatter.default is correctly decomposed into
+        (cumsum, gather, where, etc.) and that final TRT results match PyTorch.
+        """
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, mask, source):
+                return torch.ops.aten.masked_scatter.default(x, mask, source)
+
+        x = torch.randn(*shape, dtype=dtype, device="cuda")
+
+        mask = torch.rand(*shape, device="cuda") > 0.5
+        num_trues = mask.sum().item()
+        if num_trues == 0:
+            mask[0] = True
+            num_trues = 1
+        source = torch.arange(num_trues, dtype=dtype, device="cuda")
+
+        inputs = [x, mask, source]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+
+        expected_ops = {
+            torch.ops.aten.where.self,
+            torch.ops.aten.gather.default,
+            torch.ops.aten.cumsum.default,
+        }
+        unexpected_ops = {torch.ops.aten.masked_scatter.default}
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        trt_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        with torch.no_grad():
+            trt_results = trt_model(*inputs).detach().cpu()
+            torch_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(torch.max(torch.abs(trt_results - torch_results)))
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Masked_scatter TRT outputs don't match with the original model. (diff={max_diff})",
+        )
+

 if __name__ == "__main__":
     run_tests()
