diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml
index 32ff30b818..a9e73170f7 100644
--- a/.github/workflows/build-test-linux.yml
+++ b/.github/workflows/build-test-linux.yml
@@ -143,6 +143,7 @@ jobs:
         python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
         python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin.py
         python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py
+        python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_flashinfer_rmsnorm.py
         popd
 
   tests-py-dynamo-fe:
diff --git a/py/requirements.txt b/py/requirements.txt
index 5644656330..9cb947c4d8 100644
--- a/py/requirements.txt
+++ b/py/requirements.txt
@@ -6,3 +6,4 @@ torch>=2.7.0.dev,<2.8.0
 torchvision>=0.22.0.dev,<0.23.0
 --extra-index-url https://pypi.ngc.nvidia.com
 pyyaml
+flashinfer-python
diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py
index 4211bae1fa..be0cc756ca 100644
--- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py
+++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py
@@ -1,3 +1,4 @@
+import itertools
 import logging
 from types import FunctionType
 from typing import Any, Callable, Tuple
@@ -130,16 +131,29 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
             output = torch_op(*fake_args, **kwargs)
 
         # We assume that number of dimensions are the same in torch op
-        shape_calc_fns = [None] * args[0].ndim
-        for i in range(args[0].ndim):
-            input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args]
+        shape_calc_fns = [None] * output.ndim
+
+        # Each shape function is lambdified over the flattened symbols of every
+        # dim of every input. That list does not depend on ``i``, so build it once.
+        input_node_expr = list(
+            itertools.chain.from_iterable(
+                [sym.node.expr for sym in syms_arg] for syms_arg in syms_args
+            )
+        )
+
+        for i in range(output.ndim):
             shape_calc_fns[i] = lambdify(
                 tuple(input_node_expr), output.shape[i].node.expr, "math"
             )
 
         out_desc = tensor_args[0].like()
+        # Flatten the input shape expressions the same way so the positional arguments
+        # match the lambdified parameters (assumes ``tensor_args`` and ``syms_args`` are ordered alike).
+        input_shape_expr = list(
+            itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args)
+        )
+
         for i in range(out_desc.ndim):
-            input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args]
             if output.shape[i].node.expr is None:
                 raise ValueError(f"output.shape[{i}].node.expr cannot be None")
             out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr)  # type: ignore[misc]
diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
index ae60f8cda7..8ab47def08 100644
--- a/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
+++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
@@ -81,12 +81,3 @@ def forward(self, lhs, rhs):
 
 if __name__ == "__main__":
     run_tests()
-
-# Example Usage
-# A = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
-# B = torch.full((64, 64), 3, device="cuda", dtype=torch.float)
-
-# C, D = torch.ops.torchtrt_ex.elementwise_add_mul.default(A, B)
-
-# print("C (Addition):", C)
-# print("D (Multiplication):", D)
diff --git a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
new file mode 100644
index 0000000000..9fc7e7df8f
--- /dev/null
+++ b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
@@ -0,0 +1,65 @@
+"""Exercise FlashInfer's RMSNorm as a torch custom op under Torch-TensorRT.
+
+The FlashInfer kernel is wrapped as the ``flashinfer::rmsnorm`` torch custom op
+and handed to ``torch_tensorrt.dynamo.conversion.plugins.custom_op``; the test
+case then runs a module calling that op through the dispatch test harness.
+"""
+
+import flashinfer
+import torch
+import torch.nn as nn
+import torch_tensorrt
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt._enums import dtype
+
+from ..conversion.harness import DispatchTestCase
+
+
+@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())  # type: ignore[misc]
+def flashinfer_rmsnorm(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+    """Apply FlashInfer RMSNorm to ``input`` using ``weight`` and ``eps``."""
+    # Forward ``eps``; otherwise the argument is accepted but silently ignored.
+    return flashinfer.norm.rmsnorm(input, weight, eps)
+
+
+@torch.library.register_fake("flashinfer::rmsnorm")
+def _(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    """Fake kernel: the output has the same shape and dtype as ``input``."""
+    # Return a fresh tensor rather than ``input`` itself: the op is declared
+    # with ``mutates_args=()`` and must not hand back an alias of its input.
+    return torch.empty_like(input)
+
+
+# Register the op defined above with the Torch-TensorRT plugin helper.
+torch_tensorrt.dynamo.conversion.plugins.custom_op(
+    "flashinfer::rmsnorm", supports_dynamic_shapes=True
+)
+
+
+class TestAutomaticPlugin(DispatchTestCase):
+    """Runs ``flashinfer::rmsnorm`` through the dispatch test harness."""
+
+    @parameterized.expand(
+        [
+            ((64, 64), (64,), torch.float16),
+            ((256, 256), (256,), torch.float16),
+        ]
+    )
+    def test_rmsnorm_float(self, input_shape, weight_shape, data_type):
+        class rmsnorm(nn.Module):
+            def forward(self, input, weight):
+                return torch.ops.flashinfer.rmsnorm.default(input, weight)
+
+        inputs = [
+            torch.randn(input_shape, device="cuda", dtype=data_type),
+            torch.randn(weight_shape, device="cuda", dtype=data_type),
+        ]
+
+        self.run_test(rmsnorm(), inputs, precision=dtype.f16)
+
+
+if __name__ == "__main__":
+    run_tests()