|
1 | 1 | import torch
|
| 2 | +from parameterized import parameterized |
2 | 3 | from torch.testing._internal.common_utils import run_tests
|
3 | 4 | from torch_tensorrt import Input
|
4 | 5 |
|
5 | 6 | from .harness import DispatchTestCase
|
6 | 7 |
|
7 | 8 |
|
8 |
| -class TestSoftMaxConverter(DispatchTestCase): |
9 |
| - def test_softmax(self): |
| 9 | +class TestSoftmaxConverter(DispatchTestCase): |
| 10 | + @parameterized.expand( |
| 11 | + [ |
| 12 | + (torch.float, False), |
| 13 | + (torch.half, False), |
| 14 | + (torch.half, True), |
| 15 | + ] |
| 16 | + ) |
| 17 | + def test_softmax(self, dtype, half_to_float): |
10 | 18 | class TestModule(torch.nn.Module):
|
11 | 19 | def forward(self, x):
|
12 |
| - return torch.ops.aten._softmax.default(x, 1, False) |
| 20 | + return torch.ops.aten._softmax.default(x, 1, half_to_float) |
13 | 21 |
|
14 |
| - inputs = [torch.randn(1, 3, 224, 224)] |
| 22 | + inputs = [torch.randn(1, 3, 224, 224, dtype=dtype)] |
15 | 23 | self.run_test(TestModule(), inputs)
|
16 | 24 |
|
17 |
| - def test_softmax_with_dynamic_shape(self): |
| 25 | + @parameterized.expand( |
| 26 | + [ |
| 27 | + (torch.float, False), |
| 28 | + (torch.half, False), |
| 29 | + (torch.half, True), |
| 30 | + ] |
| 31 | + ) |
| 32 | + def test_softmax_with_dynamic_shape(self, dtype, half_to_float): |
18 | 33 | class TestModule(torch.nn.Module):
|
19 | 34 | def forward(self, x):
|
20 |
| - return torch.ops.aten._softmax.default(x, 2, False) |
| 35 | + return torch.ops.aten._softmax.default(x, 2, half_to_float) |
21 | 36 |
|
22 | 37 | input_specs = [
|
23 | 38 | Input(
|
24 |
| - shape=(-1, 3, -1, -1), |
25 |
| - dtype=torch.float32, |
26 |
| - shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], |
27 |
| - ), |
| 39 | + min_shape=(1, 1, 1, 1), |
| 40 | + opt_shape=(2, 4, 6, 8), |
| 41 | + max_shape=(8, 8, 8, 8), |
| 42 | + dtype=dtype, |
| 43 | + ) |
28 | 44 | ]
|
29 |
| - |
30 | 45 | self.run_test_with_dynamic_shape(TestModule(), input_specs)
|
31 | 46 |
|
32 | 47 |
|
|
0 commit comments