
Commit 873dd36

feat: log_softmax decomposition (#3137)

1 parent 9fed78c

7 files changed: +111, -36 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+6, -1)

@@ -704,6 +704,11 @@ def aten_ops_unsqueeze(
 @dynamo_tensorrt_converter(
     torch.ops.aten._softmax.default, supports_dynamic_shapes=True
 )
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
 def aten_ops_softmax(
     ctx: ConversionContext,
     target: Target,
@@ -712,7 +717,7 @@ def aten_ops_softmax(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.normalization.softmax(
-        ctx, target, SourceIR.ATEN, name, args[0], args[1]
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
     )

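For reference, a minimal sketch of the ATen overload this converter handles, now forwarded in full (args[0] = input, args[1] = dim, args[2] = half_to_float); the CUDA half-precision tensor below is only an illustrative assumption mirroring the new tests:

    import torch

    x = torch.randn(1, 3, 224, 224, dtype=torch.half, device="cuda")
    # half_to_float=True asks aten._softmax to compute and return the result in float32
    y = torch.ops.aten._softmax.default(x, 1, True)
    assert y.dtype == torch.float32
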
py/torch_tensorrt/dynamo/conversion/impl/attention.py (+1, -1)

@@ -151,7 +151,7 @@ def scaled_dot_product_attention(
     )

     softmax = impl.normalization.softmax(
-        ctx, target, source_ir, name + "_softmax", scaled, -1
+        ctx, target, source_ir, name + "_softmax", scaled, -1, False
     )
     out = impl.matmul.matrix_multiply(
         ctx,

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py (+5, -22)

@@ -439,30 +439,13 @@ def softmax(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: Optional[Any] = None,
+    dim: int,
+    half_to_float: bool,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_ranks = len(input.shape)
+    dim = get_positive_dim(dim, len(input.shape))

-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"softmax received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
-    def get_softmax_dim(ndim: int) -> int:
-        if ndim == 0 or ndim == 1 or ndim == 3:
-            ret = 0
-        else:
-            ret = 1
-        return ret
-
-    if dim is None:
-        dim = get_softmax_dim(input_ranks)
-    else:
-        dim = cast(int, dim)
-
-    dim = get_positive_dim(dim, input_ranks)
+    if half_to_float:
+        input = cast_trt_tensor(ctx, input, torch.float, name, target, source_ir)

     layer = ctx.net.add_softmax(input)
     layer.axes = 1 << dim

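For context, half_to_float in aten._softmax means the softmax of a float16 input should be computed and returned in float32; the converter mirrors this by casting the TRT tensor to float before the softmax layer. A rough pure-PyTorch sketch of those semantics (illustration only, not the converter code; softmax_reference is a hypothetical name):

    import torch

    def softmax_reference(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor:
        # Normalize a negative dim, as get_positive_dim does above
        dim = dim if dim >= 0 else dim + x.ndim
        # Upcast to float32 before the softmax when requested
        if half_to_float:
            x = x.float()
        return torch.softmax(x, dim)
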
py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py (-1)

@@ -76,7 +76,6 @@
     aten.logit_backward,
     aten.log_sigmoid_backward,
     aten.log_sigmoid_forward,
-    aten._log_softmax,
     aten._log_softmax_backward_data,
     aten.logspace,
     aten.logsumexp.default,

py/torch_tensorrt/dynamo/lowering/_decompositions.py (+11)

@@ -388,6 +388,17 @@ def scatter_reduce_decomposition(
     return scatter_loop_tensor


+@register_torch_trt_decomposition(aten._log_softmax, registry=TORCH_TRT_DECOMPOSITIONS)
+def log_softmax_decomposition(
+    x: torch.Tensor,
+    dim: int,
+    half_to_float: bool,
+) -> torch.Tensor:
+    return torch.log(
+        torch.softmax(x, dim, dtype=torch.float if half_to_float else None)
+    )
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:

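As a quick sanity check of the decomposition registered above, log(softmax(x, dim)) matches aten._log_softmax, including the float32 upcast when half_to_float is set; a standalone sketch in plain PyTorch (the CUDA half tensor and the 1e-3 tolerance are illustrative assumptions):

    import torch

    x = torch.randn(2, 5, dtype=torch.half, device="cuda")

    # Decomposed form: softmax computed in float32 (half_to_float=True), then log
    decomposed = torch.log(torch.softmax(x, 1, dtype=torch.float))
    # Reference op with the same half_to_float flag
    reference = torch.ops.aten._log_softmax.default(x, 1, True)

    assert decomposed.dtype == reference.dtype == torch.float32
    assert torch.allclose(decomposed, reference, atol=1e-3)
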
tests/py/dynamo/conversion/test_softmax_aten.py (+26, -11)

@@ -1,32 +1,47 @@
 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input

 from .harness import DispatchTestCase


-class TestSoftMaxConverter(DispatchTestCase):
-    def test_softmax(self):
+class TestSoftmaxConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (torch.float, False),
+            (torch.half, False),
+            (torch.half, True),
+        ]
+    )
+    def test_softmax(self, dtype, half_to_float):
         class TestModule(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten._softmax.default(x, 1, False)
+                return torch.ops.aten._softmax.default(x, 1, half_to_float)

-        inputs = [torch.randn(1, 3, 224, 224)]
+        inputs = [torch.randn(1, 3, 224, 224, dtype=dtype)]
         self.run_test(TestModule(), inputs)

-    def test_softmax_with_dynamic_shape(self):
+    @parameterized.expand(
+        [
+            (torch.float, False),
+            (torch.half, False),
+            (torch.half, True),
+        ]
+    )
+    def test_softmax_with_dynamic_shape(self, dtype, half_to_float):
         class TestModule(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten._softmax.default(x, 2, False)
+                return torch.ops.aten._softmax.default(x, 2, half_to_float)

         input_specs = [
             Input(
-                shape=(-1, 3, -1, -1),
-                dtype=torch.float32,
-                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
-            ),
+                min_shape=(1, 1, 1, 1),
+                opt_shape=(2, 4, 6, 8),
+                max_shape=(8, 8, 8, 8),
+                dtype=dtype,
+            )
         ]
-
         self.run_test_with_dynamic_shape(TestModule(), input_specs)

tests/py/dynamo/lowering/test_decompositions.py (+62)

@@ -1525,6 +1525,68 @@ def forward(self, input):
             f"Scatter_reduce TRT outputs don't match with the original model.",
         )

+    @parameterized.expand(
+        [
+            (torch.float, False),
+            (torch.half, False),
+            (torch.half, True),
+        ]
+    )
+    def test_lowering_log_softmax(self, dtype, half_to_float):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten._log_softmax.default(x, 1, half_to_float)
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {torch.ops.aten._softmax.default, torch.ops.aten.log.default}
+        unexpected_ops = {torch.ops.aten._log_softmax.default}
+
+        inputs = [torch.randn(1, 3, 5, 7, dtype=dtype, device="cuda")]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Log_softmax TRT outputs don't match with the original model.",
+        )
+

 if __name__ == "__main__":
     run_tests()
