Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxinyu committed Apr 29, 2024
1 parent 228f25e commit 69c7202
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,8 @@
"PermuteNegativeIndexModule_basic",
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"PixelUnshuffleModuleStaticRank3Int64_basic",
"PixelUnshuffleModuleStaticRank4Float32_basic",
"PowIntFloatModule_basic",
"PrimListUnpackNumMismatchModule_basic",
"PrimMaxIntModule_basic",
Expand Down
54 changes: 54 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,24 @@ def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils):
# ==============================================================================


class PixelUnshuffleModuleStaticRank3int64(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([12, 2, 2], torch.int64, True)])
def forward(self, x):
return torch.ops.aten.pixel_unshuffle(x, 2)


@register_test_case(module_factory=lambda: PixelUnshuffleModuleStaticRank3int64())
def PixelUnshuffleModuleStaticRank3int64_basic(module, tu: TestUtils):
module.forward(tu.randint(12, 2, 2, low=0, high=100))


# ==============================================================================


class PixelUnshuffleModuleStaticRank4Float32(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -1004,6 +1022,42 @@ def PixelUnshuffleModuleStaticRank4Float32_basic(module, tu: TestUtils):
# ==============================================================================


class PixelUnshuffleModuleFullDynamic(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1, -1, -1, -1], torch.int64, True)])
def forward(self, x):
return torch.ops.aten.pixel_unshuffle(x, 2)


@register_test_case(module_factory=lambda: PixelUnshuffleModuleFullDynamic())
def PixelUnshuffleModuleFullDynamic_basic(module, tu: TestUtils):
module.forward(tu.randint(1, 8, 4, 4, low=0, high=100))


# ==============================================================================


class PixelUnshuffleModuleSpatiallyDynamic(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([2, 1, -1, -1], torch.int64, True)])
def forward(self, x):
return torch.ops.aten.pixel_unshuffle(x, 2)


@register_test_case(module_factory=lambda: PixelUnshuffleModuleSpatiallyDynamic())
def PixelUnshuffleModuleSpatiallyDynamic_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 1, 8, 8, low=0, high=100))


# ==============================================================================


class TensorsConcatModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 69c7202

Please sign in to comment.