Skip to content

Commit 13b96ae

Browse files
authored
Fixed fold_constants, test_handler switched to onnx (Project-MONAI#8211)
Fixed fold_constants: the result was not saved. test_handler switched to onnx, as torch-tensorrt is causing issues with CI on various Torch versions and is not used anyway. ### Description `fold_constants` returns a new ONNX model rather than modifying its argument in place; the return value was previously discarded, so constant folding had no effect. The fix assigns the result back to `onnx_model` and persists it with `save_onnx`. In addition, `test_handler` now uses the `onnx` conversion method (with `dynamic_batchsize`) instead of `torch_trt`, since torch-tensorrt causes CI failures on several Torch versions and is not otherwise exercised. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Boris Fomitchev <[email protected]>
1 parent 746a97a commit 13b96ae

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

monai/networks/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -712,9 +712,10 @@ def convert_to_onnx(
712712
onnx_model = onnx.load(f)
713713

714714
if do_constant_folding and polygraphy_imported:
715-
from polygraphy.backend.onnx.loader import fold_constants
715+
from polygraphy.backend.onnx.loader import fold_constants, save_onnx
716716

717-
fold_constants(onnx_model, size_threshold=constant_size_threshold)
717+
onnx_model = fold_constants(onnx_model, size_threshold=constant_size_threshold)
718+
save_onnx(onnx_model, f)
718719

719720
if verify:
720721
if isinstance(inputs, dict):

tests/test_trt_compile.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def tearDown(self):
6161
if current_device != self.gpu_device:
6262
torch.cuda.set_device(self.gpu_device)
6363

64-
@unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required")
64+
# @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required")
6565
def test_handler(self):
6666
from ignite.engine import Engine
6767

@@ -74,7 +74,7 @@ def test_handler(self):
7474

7575
with tempfile.TemporaryDirectory() as tempdir:
7676
engine = Engine(lambda e, b: None)
77-
args = {"method": "torch_trt"}
77+
args = {"method": "onnx", "dynamic_batchsize": [1, 4, 8]}
7878
TrtHandler(net1, tempdir + "/trt_handler", args=args).attach(engine)
7979
engine.run([0] * 8, max_epochs=1)
8080
self.assertIsNotNone(net1._trt_compiler)
@@ -86,7 +86,11 @@ def test_lists(self):
8686
model = ListAdd().cuda()
8787

8888
with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
89-
args = {"output_lists": [[-1], [2], []], "export_args": {"dynamo": False, "verbose": True}}
89+
args = {
90+
"output_lists": [[-1], [2], []],
91+
"export_args": {"dynamo": False, "verbose": True},
92+
"dynamic_batchsize": [1, 4, 8],
93+
}
9094
x = torch.randn(1, 16).to("cuda")
9195
y = torch.randn(1, 16).to("cuda")
9296
z = torch.randn(1, 16).to("cuda")

0 commit comments

Comments
 (0)