From 0f2ebaf2c150d777a397097d705843f4d9ce87d9 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Mon, 8 Apr 2024 12:55:25 -0700 Subject: [PATCH 1/2] add broadcast test --- test/python/fx_importer/basic_test.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 08ef9fdc9cd3..68a1b3165946 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -105,6 +105,31 @@ def forward(self, x): print(m) +@run +# CHECK-LABEL: test_broadcast_with_dynamic_shapes +# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> +def test_broadcast_with_dynamic_shapes(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.broadcast_to(x, (y.shape[0], -1)) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(10) + + dim_0 = Dim("dim_0") + dynamic_shapes = { + "x": {}, + "y": {0: dim_0}, + } + + m = fx.export_and_import(Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net") + print(m) + + @make_boxed_compiler def fx_import_aot_autograd_backend( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor] @@ -117,7 +142,7 @@ def fx_import_aot_autograd_backend( @run # CHECK-LABEL: test_stateless_fx_import -# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> # CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> # CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32> def test_stateless_fx_import(): From 5b84db3d37ec9177ca4deb8fba4c705588cdd676 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Mon, 29 Apr 2024 04:14:31 -0700 Subject: [PATCH 2/2] run black --- test/python/fx_importer/basic_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 68a1b3165946..fde318630077 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -126,7 +126,9 @@ def forward(self, x, y): "y": {0: dim_0}, } - m = fx.export_and_import(Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net") + m = fx.export_and_import( + Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net" + ) print(m)