Skip to content

Commit 82ca0a3

Browse files
chohk88, Chengzhe Xu
authored and
Chengzhe Xu
committed
feat: support masked_scatter by lowering path
1 parent 2e82d17 commit 82ca0a3

File tree

3 files changed

+122
-0
lines changed

3 files changed

+122
-0
lines changed

Diff for: py/torch_tensorrt/dynamo/conversion/impl/select.py

+1
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ def gather(
462462
) -> TRTTensor:
463463
input_shape = input.shape
464464
dim = get_positive_dim(dim, len(input_shape))
465+
index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
465466
gather_layer = ctx.net.add_gather(input, index, axis=dim)
466467
gather_layer.mode = trt.GatherMode.ELEMENT
467468
set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)

Diff for: py/torch_tensorrt/dynamo/lowering/_decompositions.py

+40
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,46 @@ def cudnn_grid_sampler_decomposition(
575575
return torch.grid_sampler_2d(x, grid, 0, 0, True)
576576

577577

578+
@register_torch_trt_decomposition(
    aten.masked_scatter, registry=TORCH_TRT_DECOMPOSITIONS
)
def masked_scatter_decomposition(
    input: torch.Tensor,
    mask: torch.Tensor,
    source: torch.Tensor,
) -> torch.Tensor:
    """
    Decompose `aten.masked_scatter` into cumsum / gather / where.

    Emulates copying consecutive elements of `source` into `input` at the
    positions where `mask` is True, without data-dependent indexing.

    Steps:
    1) Broadcast `input` and `mask` to a common shape
    2) Flatten them (and `source`)
    3) Convert `mask` to int64, compute its cumsum, and subtract 1 to get,
       for every True position, the index of the `source` element to use.
       Positions where `mask` is False are filled with index 0 so that no
       out-of-range (-1) index is ever passed to `gather`.
    4) Use `gather` to select elements from `source`
    5) Use `torch.where` to place gathered elements where `mask` is True
    6) Reshape the result to the broadcasted shape

    Args:
        input: Tensor whose masked positions are overwritten.
        mask: Boolean tensor, broadcastable with `input`.
        source: Tensor supplying the replacement values in flattened order.

    Returns:
        A tensor with the broadcasted shape of `input` and `mask`.
    """

    # 1) Broadcast `input` and `mask` to a common shape
    input_b, mask_b = aten.broadcast_tensors([input, mask])

    # 2) Flatten the broadcasted tensors and the source tensor
    input_flat = input_b.flatten()
    mask_flat = mask_b.flatten()
    source_flat = source.flatten()

    # 3) Compute gather indices: (cumsum of mask as int64) - 1.
    #    The k-th True position (0-based) maps to source index k.
    source_idx = mask_flat.to(torch.int64).cumsum(0) - 1
    # Every position before the first True would otherwise carry index -1,
    # which is out of range for `gather`. Those positions are discarded by
    # the `where` below, so any in-range index is fine; use 0.
    safe_idx = source_idx.masked_fill(~mask_flat, 0)

    # 4) Gather elements from source_flat using the sanitized indices
    gathered = source_flat.gather(0, safe_idx)

    # 5) Replace positions where mask is True with gathered values, otherwise keep original
    replaced = torch.where(mask_flat, gathered, input_flat)

    # 6) Reshape the result back to the broadcasted shape
    return replaced.view(input_b.shape)
616+
617+
578618
def get_decompositions(
579619
enable_experimental_decompositions: bool = False,
580620
) -> Dict[OpOverload, Callable[[Any], Any]]:

Diff for: tests/py/dynamo/lowering/test_decompositions.py

+81
Original file line numberDiff line numberDiff line change
@@ -2168,5 +2168,86 @@ def forward(self, x, grid):
21682168
)
21692169

21702170

2171+
@parameterized.expand(
    [
        ("float32_2d", torch.float32, (4, 4)),
        ("float16_3d", torch.float16, (2, 3, 4)),
    ]
)
def test_masked_scatter(self, _, dtype, shape):
    """
    Test that masked_scatter.default is correctly decomposed into
    (cumsum, gather, where, etc.) and that final TRT results match PyTorch.
    """

    class TestModule(torch.nn.Module):
        def forward(self, x, mask, source):
            return torch.ops.aten.masked_scatter.default(x, mask, source)

    x = torch.randn(*shape, dtype=dtype, device="cuda")

    mask = torch.rand(*shape, device="cuda") > 0.5
    num_trues = mask.sum().item()
    if num_trues == 0:
        # Guarantee at least one True. Set exactly ONE element through a
        # flat view: `mask[0] = True` would set a whole slice for the
        # multi-dimensional shapes used here, leaving `source` too short.
        mask.view(-1)[0] = True
        num_trues = 1
    # `source` must provide one value per True position in `mask`.
    source = torch.arange(num_trues, dtype=dtype, device="cuda")

    inputs = [x, mask, source]

    fx_graph = torch.fx.symbolic_trace(TestModule())

    # Ops the decomposition is expected to introduce / eliminate.
    expected_ops = {
        torch.ops.aten.where.self,
        torch.ops.aten.gather.default,
        torch.ops.aten.cumsum.default,
    }
    unexpected_ops = {torch.ops.aten.masked_scatter.default}

    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    trt_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        pass_through_build_failures=True,
    )
    with torch.no_grad():
        trt_results = trt_model(*inputs).detach().cpu()
        torch_results = fx_graph(*inputs).detach().cpu()

    # Compare the compiled output against the eager reference.
    max_diff = float(torch.max(torch.abs(trt_results - torch_results)))
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        f"Masked_scatter TRT outputs don't match with the original model. (diff={max_diff})",
    )
2250+
2251+
21712252
# Script entry point: call run_tests() only when this file is executed
# directly rather than imported.
if __name__ == "__main__":
21722253
run_tests()

0 commit comments

Comments
 (0)