diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml
index 024afd8c62..b342743fab 100644
--- a/.github/workflows/build-test-linux.yml
+++ b/.github/workflows/build-test-linux.yml
@@ -143,6 +143,7 @@ jobs:
         python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
         python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin.py
         python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py
+        python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_flashinfer_rmsnorm.py
         popd
 
   tests-py-dynamo-fe:
diff --git a/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py b/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
index ae60f8cda7..8ab47def08 100644
--- a/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
+++ b/tests/py/dynamo/automatic_plugin/test_automatic_plugin.py
@@ -81,12 +81,3 @@ def forward(self, lhs, rhs):
 
 if __name__ == "__main__":
     run_tests()
-
-# Example Usage
-# A = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
-# B = torch.full((64, 64), 3, device="cuda", dtype=torch.float)
-
-# C, D = torch.ops.torchtrt_ex.elementwise_add_mul.default(A, B)
-
-# print("C (Addition):", C)
-# print("D (Multiplication):", D)
diff --git a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
new file mode 100644
index 0000000000..6c10aafb7a
--- /dev/null
+++ b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
@@ -0,0 +1,60 @@
+import pytest
+
+# Skip this whole module when flashinfer cannot be imported.
+flashinfer = pytest.importorskip("flashinfer")
+import torch
+import torch.nn as nn
+import torch_tensorrt
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt._enums import dtype
+
+from ..conversion.harness import DispatchTestCase
+
+
+@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())  # type: ignore[misc]
+def flashinfer_rmsnorm(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+    """Wrap ``flashinfer.norm.rmsnorm`` as the ``flashinfer::rmsnorm`` op."""
+    # Forward eps so the op honors the value supplied by the caller.
+    return flashinfer.norm.rmsnorm(input, weight, eps=eps)
+
+
+@torch.library.register_fake("flashinfer::rmsnorm")
+def _(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    """Fake kernel: produce an output with the same metadata as ``input``."""
+    # Return a fresh tensor; a custom op's output must not alias its input.
+    return torch.empty_like(input)
+
+
+# Register the op with Torch-TensorRT's plugin utilities, dynamic shapes on.
+torch_tensorrt.dynamo.conversion.plugins.custom_op(
+    "flashinfer::rmsnorm", supports_dynamic_shapes=True
+)
+
+
+class TestAutomaticPlugin(DispatchTestCase):
+    """Run the ``flashinfer::rmsnorm`` op through the dispatch test harness."""
+
+    @parameterized.expand(
+        [
+            ((64, 64), (64,), torch.float16),
+            ((256, 256), (256,), torch.float16),
+        ]
+    )
+    def test_rmsnorm_float(self, input_shape, weight_shape, data_type):
+        class rmsnorm(nn.Module):
+            def forward(self, input, weight):
+                return torch.ops.flashinfer.rmsnorm.default(input, weight)
+
+        inputs = [
+            torch.randn(input_shape, device="cuda", dtype=data_type),
+            torch.randn(weight_shape, device="cuda", dtype=data_type),
+        ]
+
+        self.run_test(rmsnorm(), inputs, precision=dtype.f16)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/requirements.txt b/tests/py/requirements.txt
index 6fb6128089..4f3c4e083b 100644
--- a/tests/py/requirements.txt
+++ b/tests/py/requirements.txt
@@ -8,6 +8,6 @@ pytest>=8.2.1
 pytest-xdist>=3.6.1
 pyyaml
 timm>=1.0.3
-transformers==4.40.2
-nvidia-modelopt[deploy,hf,torch]~=0.17.0
+transformers==4.49.0
+nvidia-modelopt[deploy,hf,torch]~=0.17.0; python_version < "3.13"
 --extra-index-url https://pypi.nvidia.com