
Commit 946c3c1

feat: support masked_scatter by lowering path
1 parent 26ea41e commit 946c3c1

3 files changed: +122 −2 lines changed


Diff for: py/torch_tensorrt/dynamo/conversion/impl/select.py

+2 −2

@@ -2,6 +2,7 @@
 from typing import Optional, Sequence, Union
 
 import numpy as np
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -23,8 +24,6 @@
 )
 from torch_tensorrt.fx.types import TRTTensor
 
-import tensorrt as trt
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -463,6 +462,7 @@ def gather(
 ) -> TRTTensor:
     input_shape = input.shape
     dim = get_positive_dim(dim, len(input_shape))
+    index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
     gather_layer = ctx.net.add_gather(input, index, axis=dim)
     gather_layer.mode = trt.GatherMode.ELEMENT
     set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
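Note on the cast added to the gather converter: the masked_scatter decomposition introduced in the next file builds its gather indices with aten.cumsum, which produces int64 values, while the converter hands the index tensor to TensorRT's gather layer as int32, hence the explicit cast. A minimal eager-mode sketch of the dtype involved (the tensors here are illustrative, not taken from the converter):

import torch

# Indices built the same way the decomposition below builds them
mask = torch.tensor([True, False, True, True])
gather_idx = mask.to(torch.int64).cumsum(0) - 1

print(gather_idx)        # tensor([0, 0, 1, 2])
print(gather_idx.dtype)  # torch.int64; the converter casts this to trt.int32 before add_gather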

Diff for: py/torch_tensorrt/dynamo/lowering/_decompositions.py

+40
@@ -566,6 +566,46 @@ def scaled_dot_product_cudnn_attention_decomposition(
     return attn, None, None, None, 0, 0, None, None, None
 
 
+@register_torch_trt_decomposition(
+    aten.masked_scatter, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def masked_scatter_decomposition(
+    input: torch.Tensor,
+    mask: torch.Tensor,
+    source: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Performs an operation equivalent to `input[mask] = source`.
+    Steps:
+    1) Broadcast `input` and `mask` to a common shape
+    2) Flatten them
+    3) Convert `mask` to int64, compute its cumsum, and subtract 1 to get gather indices
+    4) Use `gather` to select elements from `source`
+    5) Use `torch.where` to place gathered elements where `mask` is True
+    6) Reshape the result to the original shape
+    """
+
+    # 1) Broadcast `input` and `mask` to a common shape
+    input_b, mask_b = aten.broadcast_tensors([input, mask])
+
+    # 2) Flatten the broadcasted tensors and the source tensor
+    input_flat = input_b.flatten()
+    mask_flat = mask_b.flatten()
+    source_flat = source.flatten()
+
+    # 3) Compute gather indices: (cumsum of mask as int64) - 1
+    source_idx = mask_flat.to(torch.int64).cumsum(0) - 1
+
+    # 4) Gather elements from source_flat using these indices
+    gathered = source_flat.gather(0, source_idx)
+
+    # 5) Replace positions where mask is True with gathered values, otherwise keep original
+    replaced = torch.where(mask_flat, gathered, input_flat)
+
+    # 6) Reshape the result back to the broadcasted shape
+    return replaced.view(input_b.shape)
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
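The six steps in the docstring can be reproduced eagerly on a tiny input to see why the cumsum/gather/where combination is equivalent to `input[mask] = source`. The following is a sketch with plain torch calls (not part of the commit), checked against Tensor.masked_scatter as the reference:

import torch

inp = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])
source = torch.tensor([10.0, 20.0, 30.0])

# 1) broadcast input and mask, 2) flatten everything
inp_b, mask_b = torch.broadcast_tensors(inp, mask)
inp_flat, mask_flat = inp_b.flatten(), mask_b.flatten()
source_flat = source.flatten()

# 3) cumsum(mask) - 1 gives, at every True position, the index of the source
#    element that belongs there: [0, 0, 1, 1, 2, 2]
source_idx = mask_flat.to(torch.int64).cumsum(0) - 1

# 4) gather pulls those source elements; values gathered at False positions
#    are discarded by the where in the next step
gathered = source_flat.gather(0, source_idx)

# 5) keep gathered values only where mask is True, 6) reshape back
out = torch.where(mask_flat, gathered, inp_flat).view(inp_b.shape)

print(out)                                                 # [[10., 0., 20.], [0., 30., 0.]]
print(torch.equal(out, inp.masked_scatter(mask, source)))  # True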

Diff for: tests/py/dynamo/lowering/test_decompositions.py

+80
@@ -2117,6 +2117,86 @@ def forward(self, query, key, value, attn_bias=None):
             msg="Scaled_dot_product_cudnn_attention TRT outputs don't match with the original model.",
         )
 
+    @parameterized.expand(
+        [
+            ("float32_2d", torch.float32, (4, 4)),
+            ("float16_3d", torch.float16, (2, 3, 4)),
+        ]
+    )
+    def test_masked_scatter(self, _, dtype, shape):
+        """
+        Test that masked_scatter.default is correctly decomposed into
+        (cumsum, gather, where, etc.) and that final TRT results match PyTorch.
+        """
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, mask, source):
+                return torch.ops.aten.masked_scatter.default(x, mask, source)
+
+        x = torch.randn(*shape, dtype=dtype, device="cuda")
+
+        mask = torch.rand(*shape, device="cuda") > 0.5
+        num_trues = mask.sum().item()
+        if num_trues == 0:
+            mask[0] = True
+            num_trues = 1
+        source = torch.arange(num_trues, dtype=dtype, device="cuda")
+
+        inputs = [x, mask, source]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+
+        expected_ops = {
+            torch.ops.aten.where.self,
+            torch.ops.aten.gather.default,
+            torch.ops.aten.cumsum.default,
+        }
+        unexpected_ops = {torch.ops.aten.masked_scatter.default}
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        trt_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        with torch.no_grad():
+            trt_results = trt_model(*inputs).detach().cpu()
+            torch_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(torch.max(torch.abs(trt_results - torch_results)))
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Masked_scatter TRT outputs don't match with the original model. (diff={max_diff})",
+        )
+
 
 if __name__ == "__main__":
     run_tests()
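End to end, the decomposition plus the converter cast let a model that calls masked_scatter be compiled through the torch_compile path, as the test above exercises. A hypothetical usage sketch (module, shapes, and names are illustrative; it assumes a CUDA device and a working torch_tensorrt install, and mirrors the compile call used in the test):

import torch
import torch_tensorrt

class FillMasked(torch.nn.Module):
    def forward(self, x, mask, source):
        # traces to torch.ops.aten.masked_scatter.default, now decomposed during lowering
        return x.masked_scatter(mask, source)

x = torch.zeros(4, 4, device="cuda")
mask = torch.rand(4, 4, device="cuda") > 0.5
source = torch.arange(int(mask.sum()), dtype=torch.float32, device="cuda")

trt_model = torch_tensorrt.compile(
    FillMasked().cuda().eval(),
    "torch_compile",      # same ir string the test uses
    [x, mask, source],
    min_block_size=1,
)
print(trt_model(x, mask, source))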
