
Commit 2fecbfd
Add test that shards, exports and runs with IREE a Conv2DLayer (#161)
In general we should avoid having many end-to-end tests, but I decided to add this one since it isolates a real problem (iree-org/iree#18283) and I already had the test written. We would like to have at least several E2E tests of tiny models that run on every PR.

File tree: 5 files changed (+268 −59 lines)


sharktank/conftest.py (+49)

@@ -5,6 +5,8 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from pathlib import Path
+import pytest
+from typing import Optional
 
 
 # Tests under each top-level directory will get a mark.
@@ -24,3 +26,50 @@ def pytest_collection_modifyitems(items, config):
         mark = TLD_MARKS.get(tld)
         if mark:
             item.add_marker(mark)
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--mlir",
+        type=Path,
+        default=None,
+        help="Path to exported MLIR program. If not specified a temporary file will be used.",
+    )
+    parser.addoption(
+        "--module",
+        type=Path,
+        default=None,
+        help="Path to exported IREE module. If not specified a temporary file will be used.",
+    )
+    parser.addoption(
+        "--parameters",
+        type=Path,
+        default=None,
+        help="Exported model parameters. If not specified a temporary file will be used.",
+    )
+    parser.addoption(
+        "--caching",
+        action="store_true",
+        default=False,
+        help="Load cached results if present instead of recomputing.",
+    )
+
+
+@pytest.fixture(scope="session")
+def mlir_path(pytestconfig: pytest.Config) -> Optional[Path]:
+    return pytestconfig.getoption("mlir")
+
+
+@pytest.fixture(scope="session")
+def module_path(pytestconfig: pytest.Config) -> Optional[Path]:
+    return pytestconfig.getoption("module")
+
+
+@pytest.fixture(scope="session")
+def parameters_path(pytestconfig: pytest.Config) -> Optional[Path]:
+    return pytestconfig.getoption("parameters")
+
+
+@pytest.fixture(scope="session")
+def caching(pytestconfig: pytest.Config) -> bool:
+    return pytestconfig.getoption("caching")
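These session-scoped fixtures simply forward the new command-line options, so an E2E test can be pointed at precomputed artifacts, e.g. pytest ... --mlir=/path/model.mlir --module=/path/module.vmfb --parameters=/path/params.irpa --caching. A minimal sketch of a test consuming them, assuming only what the diff above defines (the test body itself is illustrative, not from the commit):

# Illustrative consumer of the new fixtures; pytest injects them by name.
from pathlib import Path
from typing import Optional

def test_example(
    mlir_path: Optional[Path],
    module_path: Optional[Path],
    parameters_path: Optional[Path],
    caching: bool,
):
    # Each path is None unless the matching --mlir/--module/--parameters
    # option was passed to pytest; caching defaults to False.
    if caching and module_path is not None and module_path.exists():
        pass  # reuse the previously compiled module instead of recompiling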

sharktank/sharktank/ops/sharded_impls.py (+4 −3)

@@ -88,7 +88,7 @@ def conv2d_all_split(
         input.is_replicated or input.shard_dim == 1
     ), "Only sharding of input channel dimension is supported"
     assert (
-        weight.shard_dim == 0 and bias.shard_dim == 0
+        bias is None or weight.shard_dim == 0 and bias.shard_dim == 0
     ), "Only sharding of output channel dimension is supported"
 
     # TODO: allow for implementation where we don't all-gather, but gather
@@ -146,7 +146,7 @@ def conv2d_replicated_input_split_weight_and_bias(
     assert input.shard_count == weight.shard_count
     assert bias is None or weight.shard_count == bias.shard_count
     assert (
-        weight.shard_dim == 0 and bias.shard_dim == 0
+        bias is None or weight.shard_dim == 0 and bias.shard_dim == 0
     ), "Only sharding of output channel dimension is supported"
     assert groups == 1
 
@@ -189,7 +189,8 @@ def conv2d_split_weight_and_bias(
     accum_dtype,
 ) -> SplitPrimitiveTensor:
     assert accum_dtype is None, "accum_dtype not supported"
-    assert weight.shard_count == bias.shard_count
+    if bias is not None:
+        assert weight.shard_count == bias.shard_count
 
     # Output channels dimension is split.
     if weight.shard_dim == 0 and groups == 1:
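The rewritten assertions lean on Python operator precedence: and binds tighter than or, so bias is None or weight.shard_dim == 0 and bias.shard_dim == 0 reads as bias is None or (weight.shard_dim == 0 and bias.shard_dim == 0), and short-circuits before ever touching bias.shard_dim when there is no bias. A minimal standalone sketch of the same guard with explicit parentheses (SimpleNamespace objects stand in for the real sharded tensor types):

# Standalone sketch of the bias-optional guard; stand-ins, not sharktank types.
from types import SimpleNamespace

def check_output_channel_sharding(weight, bias) -> None:
    # Equivalent to the committed one-liner; `bias.shard_dim` is never
    # evaluated when `bias` is None because `or` short-circuits.
    assert bias is None or (
        weight.shard_dim == 0 and bias.shard_dim == 0
    ), "Only sharding of output channel dimension is supported"

weight = SimpleNamespace(shard_dim=0)
check_output_channel_sharding(weight, bias=None)                          # passes
check_output_channel_sharding(weight, bias=SimpleNamespace(shard_dim=0))  # passes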
New test file (+212)

@@ -0,0 +1,212 @@
+import unittest
+
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from pathlib import Path
+import tempfile
+import torch
+from shark_turbine import aot
+from sharktank.models.punet.layers import Conv2DLayer
+from sharktank import ops
+from sharktank.types import (
+    Dataset,
+    DefaultPrimitiveTensor,
+    Theta,
+    ShardedTensor,
+    SplitPrimitiveTensor,
+    unbox_tensor,
+)
+from sharktank.types.sharding import Conv2DSplitOutputChannelSharding
+import iree.runtime
+from typing import List, Optional
+import os
+
+vm_context: iree.runtime.VmContext = None
+
+
+def get_compiler_args(target_device_kind: str, shard_count: int) -> List[str]:
+    result = [
+        f"--iree-hal-target-device={target_device_kind}[{i}]"
+        for i in range(shard_count)
+    ]
+    return result
+
+
+def compile_iree_module(
+    export_output: aot.ExportOutput, module_path: str, shard_count: int
+):
+    export_output.session.set_flags(
+        *get_compiler_args(target_device_kind="llvm-cpu", shard_count=shard_count)
+    )
+    export_output.compile(save_to=module_path, target_backends=None)
+
+
+# TODO: improve IREE's Python API to be more concise in a multi-device context.
+# This run function should be way shorter.
+def run_iree_module(
+    sharded_input_image: ShardedTensor,
+    module_path: str,
+    parameters_path: str,
+) -> ShardedTensor:
+    shard_count = sharded_input_image.shard_count
+    hal_driver = iree.runtime.get_driver("local-task")
+    vm_instance = iree.runtime.VmInstance()
+    available_devices = hal_driver.query_available_devices()
+    # Use the same actual device for all devices.
+    devices = [
+        hal_driver.create_device(available_devices[0]) for _ in range(shard_count)
+    ]
+    hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices)
+    params_path = Path(parameters_path)
+    # TODO: make IREE able to load the parameters from the top parameter file
+    # without having to specify the parameter file for each shard separately.
+    parameter_index = iree.runtime.ParameterIndex()
+    for i in range(shard_count):
+        parameter_index.load(
+            file_path=str(
+                Path(params_path).with_suffix(f".rank{i}{params_path.suffix}")
+            )
+        )
+    parameter_provider = parameter_index.create_provider(scope="model")
+    parameters_module = iree.runtime.create_io_parameters_module(
+        vm_instance, parameter_provider
+    )
+
+    vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path))
+
+    # The context needs to be destroyed after the buffers, although
+    # it is not associated with them on the API level.
+    global vm_context
+    vm_context = iree.runtime.VmContext(
+        instance=vm_instance, modules=(hal_module, parameters_module, vm_module)
+    )
+    module_input_args = [
+        iree.runtime.asdevicearray(
+            devices[i], sharded_input_image.shards[i].as_torch().to("cpu").numpy()
+        )
+        for i in range(shard_count)
+    ]
+
+    vm_function = vm_module.lookup_function("main")
+    invoker = iree.runtime.FunctionInvoker(
+        vm_context=vm_context,
+        # TODO: rework iree.runtime.FunctionInvoker interface for multiple devices.
+        # This works, but does not look right.
+        device=devices[0],
+        vm_function=vm_function,
+    )
+    results = invoker(*module_input_args)
+    shards = [torch.tensor(tensor.to_host()) for tensor in results]
+    return SplitPrimitiveTensor(ts=shards, shard_dim=1)
+
+
+def run_test_sharded_conv2d_with_iree(
+    mlir_path: Path, module_path: Path, parameters_path: Path, caching: bool
+):
+    torch.set_default_dtype(torch.float32)
+    torch.manual_seed(123456)
+    batches = 2
+    in_channels = 6
+    out_channels = 8
+    height = 11
+    width = 13
+    kernel_height = 5
+    kernel_width = 5
+    shard_count = 2
+    unsharded_theta = Theta(
+        {
+            "weight": DefaultPrimitiveTensor(
+                data=torch.rand(
+                    out_channels,
+                    in_channels,
+                    kernel_height,
+                    kernel_width,
+                )
+            ),
+        }
+    )
+    unsharded_theta.rename_tensors_to_paths()
+
+    if not caching or not os.path.exists(parameters_path):
+        sharding_spec = Conv2DSplitOutputChannelSharding(shard_count=shard_count)
+        sharded_theta = ops.reshard(unsharded_theta, sharding_spec)
+
+        # Roundtrip the dataset, which anchors the tensors as parameters to be loaded
+        # vs constants to be frozen (TODO: this is a bit wonky).
+        sharded_dataset = Dataset({}, sharded_theta)
+        sharded_dataset.save(parameters_path)
+
+    sharded_dataset = Dataset.load(parameters_path)
+
+    input_image = torch.rand(
+        batches,
+        in_channels,
+        height,
+        width,
+    )
+
+    sharded_torch_module = Conv2DLayer(sharded_dataset.root_theta, padding=(0, 0))
+    sharded_input_image = ops.reshard_split(input_image, dim=1, count=shard_count)
+    expected_result = sharded_torch_module(sharded_input_image)
+
+    if not caching or not os.path.exists(module_path):
+        exported_module = aot.export(
+            sharded_torch_module,
+            args=(sharded_input_image,),
+        )
+        exported_module.save_mlir(mlir_path)
+
+        compile_iree_module(
+            export_output=exported_module,
+            module_path=module_path,
+            shard_count=shard_count,
+        )
+
+    actual_result = run_iree_module(
+        sharded_input_image=sharded_input_image,
+        module_path=module_path,
+        parameters_path=parameters_path,
+    )
+    assert len(actual_result.shards) == len(expected_result.shards)
+    assert actual_result.shard_dim == expected_result.shard_dim
+    # TODO: reenable this check once numerical issues are resolved.
+    # See https://github.com/iree-org/iree/issues/18283
+    # for actual_shard, expected_shard in zip(
+    #     actual_result.shards, expected_result.shards
+    # ):
+    #     torch.testing.assert_close(
+    #         unbox_tensor(actual_shard), unbox_tensor(expected_shard)
+    #     )
+
+
+def test_sharded_conv2d_with_iree(
+    mlir_path: Optional[Path],
+    module_path: Optional[Path],
+    parameters_path: Optional[Path],
+    caching: bool,
+):
+    """Test sharding, exporting and running a 2D convolution layer with IREE."""
+
+    with tempfile.TemporaryDirectory(
+        # TODO: verify hypothesis and remove ignore_cleanup_errors=True after a fix.
+        # torch.export.export is spawning some processes that don't exit when the
+        # function returns; this causes some objects to not get destroyed, which
+        # in turn holds files params.rank0.irpa and params.rank1.irpa open.
+        ignore_cleanup_errors=True
+    ) as tmp_dir:
+        mlir_path = Path(tmp_dir) / "model.mlir" if mlir_path is None else mlir_path
+        module_path = (
+            Path(tmp_dir) / "module.vmfb" if module_path is None else module_path
+        )
+        parameters_path = (
+            Path(tmp_dir) / "params.irpa"
+            if parameters_path is None
+            else parameters_path
+        )
+        run_test_sharded_conv2d_with_iree(
+            mlir_path, module_path, parameters_path, caching
+        )
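run_iree_module above resolves one parameter file per shard by rewriting the suffix of the top-level parameter path: params.irpa becomes params.rank0.irpa, params.rank1.irpa, and so on (the same per-rank files the TemporaryDirectory comment mentions). A standalone sketch of that derivation using only pathlib:

# Standalone sketch of the per-shard parameter naming used in run_iree_module.
from pathlib import Path

params_path = Path("params.irpa")
shard_count = 2
per_shard_paths = [
    # with_suffix() swaps ".irpa" for ".rank{i}.irpa" on the base path.
    params_path.with_suffix(f".rank{i}{params_path.suffix}")
    for i in range(shard_count)
]
print(per_shard_paths)  # e.g. [PosixPath('params.rank0.irpa'), PosixPath('params.rank1.irpa')]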

sharktank/tests/models/punet/conftest.py (−56)

This file was deleted.

sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py (+3)

@@ -40,6 +40,8 @@ def compile_iree_module(
     export_output.compile(save_to=module_path, target_backends=None)
 
 
+# TODO: improve IREE's Python API to be more concise in a multi-device context.
+# This run function should be way shorter.
 def run_iree_module(
     sharded_input_image: ShardedTensor,
     sharded_input_time_emb: ShardedTensor,
@@ -206,6 +208,7 @@ def run_test_sharded_resnet_block_with_iree(
     )
     assert len(actual_result.shards) == len(expected_result.shards)
     # TODO: reenable this check once numerical issues are resolved.
+    # See https://github.com/iree-org/iree/issues/18283
    # for actual_shard, expected_shard in zip(
    #     actual_result.shards, expected_result.shards
    # ):
