diff --git a/.github/workflows/build_packages.yml b/.github/workflows/build_packages.yml index 1200234c4..6886bfa83 100644 --- a/.github/workflows/build_packages.yml +++ b/.github/workflows/build_packages.yml @@ -12,6 +12,9 @@ on: # Runs at 11:00 AM UTC, which is 3:00 AM PST (UTC-8) - cron: '0 11 * * *' +permissions: + contents: read + jobs: # Note: metadata generation could happen in a separate trigger/schedule # workflow. For cross platform builds, it's useful to just generate the @@ -40,7 +43,7 @@ jobs: sharktank_package_version=$(python3 build_tools/python_deploy/compute_local_version.py sharktank) shortfin_package_version=$(python3 build_tools/python_deploy/compute_local_version.py shortfin) - name: Upload version_local.json - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: name: version_local path: | @@ -50,6 +53,8 @@ jobs: build_packages: name: "${{ matrix.package }} :: ${{ matrix.platform }} :: ${{ matrix.python-version }}" runs-on: ${{ matrix.runs-on }} + permissions: + contents: write needs: [setup_metadata] strategy: fail-fast: false @@ -116,7 +121,7 @@ jobs: ./c/shortfin/build_tools/build_linux_package.sh - name: Upload python wheels - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: if-no-files-found: error name: snapshot-${{ matrix.package }}-${{ matrix.platform }}-${{ matrix.python-version }} @@ -126,7 +131,6 @@ jobs: uses: ncipollo/release-action@2c591bcc8ecdcd2db72b97d6147f871fcd833ba5 # v1.14.0 with: artifacts: bindist/*.whl - token: "${{ secrets.RELEASE_PUBLISH_ACCESS_TOKEN }}" tag: "dev-wheels" name: "dev-wheels" body: "Automatic snapshot release of shark-ai python wheels." diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml index 41ad5af6b..6a3b764b8 100644 --- a/.github/workflows/ci-llama-large-tests.yaml +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -70,8 +70,8 @@ jobs: # Test with pinned nightly releases, not what iree-turbine uses. pip install -f https://iree.dev/pip-release-links.html --upgrade \ - iree-base-compiler==3.0.0rc20241115 \ - iree-base-runtime==3.0.0rc20241115 + iree-base-compiler==3.0.0rc20241118 \ + iree-base-runtime==3.0.0rc20241118 - name: Run llama tests run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-all-llama --iree-hip-target=gfx942 --html=out/index.html diff --git a/.github/workflows/ci-llama-quick-tests.yaml b/.github/workflows/ci-llama-quick-tests.yaml index 63637e9b9..6c381b658 100644 --- a/.github/workflows/ci-llama-quick-tests.yaml +++ b/.github/workflows/ci-llama-quick-tests.yaml @@ -71,14 +71,14 @@ jobs: # Test with pinned nightly releases, not what iree-turbine uses. 
pip install -f https://iree.dev/pip-release-links.html --upgrade \ - iree-base-compiler==3.0.0rc20241115 \ - iree-base-runtime==3.0.0rc20241115 + iree-base-compiler==3.0.0rc20241118 \ + iree-base-runtime==3.0.0rc20241118 - name: Run llama 8b tests run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --iree-hip-target=gfx942 --run-8b-llama - name: Upload llama executable files - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: name: llama-files path: ${{ github.workspace }}/${{ steps.date.outputs.date }} diff --git a/.github/workflows/ci-sdxl.yaml b/.github/workflows/ci-sdxl.yaml index 31218d25f..355ffcf8b 100644 --- a/.github/workflows/ci-sdxl.yaml +++ b/.github/workflows/ci-sdxl.yaml @@ -64,7 +64,7 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_REPO_DIR }} submodules: false - ref: iree-3.0.0rc20241115 + ref: iree-3.0.0rc20241118 - name: Initalize IREE submodules working-directory: ${{ env.IREE_REPO_DIR }} diff --git a/.github/workflows/ci-sglang-benchmark.yml b/.github/workflows/ci-sglang-benchmark.yml index 6a5fa4112..cb9df91b8 100644 --- a/.github/workflows/ci-sglang-benchmark.yml +++ b/.github/workflows/ci-sglang-benchmark.yml @@ -69,8 +69,8 @@ jobs: # We could also pin to a known working or stable version. # This should eventually stabilize. Do the best we can for now. pip install -f https://iree.dev/pip-release-links.html --upgrade \ - iree-base-compiler==3.0.0rc20241115 \ - iree-base-runtime==3.0.0rc20241115 \ + iree-base-compiler==3.0.0rc20241118 \ + iree-base-runtime==3.0.0rc20241118 \ "numpy<2.0" - name: Install SGLang diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index 45ddfe90d..afeca11a6 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -59,7 +59,7 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_REPO_DIR }} submodules: false - ref: iree-3.0.0rc20241115 + ref: iree-3.0.0rc20241118 - name: Initalize IREE submodules working-directory: ${{ env.IREE_REPO_DIR }} diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml index 5692a8336..42de8f0f6 100644 --- a/.github/workflows/ci_linux_x64_asan-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml @@ -109,7 +109,7 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_SOURCE_DIR }} submodules: false - ref: iree-3.0.0rc20241115 + ref: iree-3.0.0rc20241118 - name: Initalize IREE submodules working-directory: ${{ env.IREE_SOURCE_DIR }} diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml index c382edbf4..0e0e1db2a 100644 --- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml @@ -57,7 +57,7 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_REPO_DIR }} submodules: false - ref: iree-3.0.0rc20241115 + ref: iree-3.0.0rc20241118 - name: Initalize IREE submodules working-directory: ${{ env.IREE_REPO_DIR }} diff --git a/.github/workflows/ci_windows_x64-libshortfin.yml b/.github/workflows/ci_windows_x64-libshortfin.yml index 544b45c76..00873c432 100644 --- a/.github/workflows/ci_windows_x64-libshortfin.yml +++ b/.github/workflows/ci_windows_x64-libshortfin.yml @@ -54,7 +54,7 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_REPO_DIR }} submodules: false - ref: iree-3.0.0rc20241115 + ref: 
iree-3.0.0rc20241118 - name: Initalize IREE submodules working-directory: ${{ env.IREE_REPO_DIR }} diff --git a/docs/user_guide.md b/docs/user_guide.md index c3da1f4f5..aedc9f546 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -17,7 +17,7 @@ Our current user guide requires that you have: This section will help you install Python and set up a Python environment with venv. -Officially we support Python versions: 3.11, 3.12, 3.13, 3.13t +Officially we support Python versions: 3.11, 3.12, 3.13 The rest of this guide assumes you are using Python 3.11. @@ -39,6 +39,10 @@ Setup your Python environment with the following commands: # Set up a virtual environment to isolate packages from other envs. python3.11 -m venv 3.11.venv source 3.11.venv/bin/activate + +# Optional: faster installation of torch with just CPU support. +# See other options at https://pytorch.org/get-started/locally/ +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu ``` ## Install SHARK and its dependencies diff --git a/sharktank/version.json b/sharktank/version.json index f09f61d2a..85afb41ed 100644 --- a/sharktank/version.json +++ b/sharktank/version.json @@ -1,3 +1,3 @@ { - "package-version": "2.9.2.dev" + "package-version": "3.0.0.dev" } diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt index 93ee63594..f025eccfe 100644 --- a/shortfin/CMakeLists.txt +++ b/shortfin/CMakeLists.txt @@ -40,7 +40,7 @@ if(NOT WIN32) endif() # Pins -set(SHORTFIN_IREE_GIT_TAG "iree-3.0.0rc20241115") +set(SHORTFIN_IREE_GIT_TAG "iree-3.0.0rc20241118") # build options option(SHORTFIN_BUILD_PYTHON_BINDINGS "Builds Python Bindings" OFF) diff --git a/shortfin/python/lib_ext.cc b/shortfin/python/lib_ext.cc index d17606b4b..c668e6a8b 100644 --- a/shortfin/python/lib_ext.cc +++ b/shortfin/python/lib_ext.cc @@ -428,16 +428,15 @@ ConfigOptions CreateConfigOptions(std::optional &env_prefix, } // namespace NB_MODULE(lib, m) { -// Tragically, debug builds of Python do the right thing and don't immortalize -// many identifiers and such. This makes the last chance leak checking that -// nanobind does somewhat unreliable since the reports it prints may be -// to identifiers that are no longer live (at a time in process shutdown -// where it is expected that everything left just gets dropped on the floor). -// This causes segfaults or ASAN violations in the leak checker on exit in -// certain scenarios where we have spurious "leaks" of global objects. -#if defined(Py_DEBUG) + // Tragically, debug builds of Python do the right thing and don't immortalize + // many identifiers and such. This makes the last chance leak checking that + // nanobind does somewhat unreliable since the reports it prints may be + // to identifiers that are no longer live (at a time in process shutdown + // where it is expected that everything left just gets dropped on the floor). + // This causes segfaults or ASAN violations in the leak checker on exit in + // certain scenarios where we have spurious "leaks" of global objects. 
+ py::set_leak_warnings(false); -#endif logging::InitializeFromEnv(); diff --git a/shortfin/python/shortfin_apps/sd/_deps.py b/shortfin/python/shortfin_apps/sd/_deps.py index 9965065ce..92bd089ec 100644 --- a/shortfin/python/shortfin_apps/sd/_deps.py +++ b/shortfin/python/shortfin_apps/sd/_deps.py @@ -9,7 +9,7 @@ try: import transformers except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "diffusers") from e + raise ShortfinDepNotFoundError(__name__, "transformers") from e try: import tokenizers diff --git a/shortfin/python/shortfin_apps/sd/components/builders.py b/shortfin/python/shortfin_apps/sd/components/builders.py index f23922dd6..98678c46d 100644 --- a/shortfin/python/shortfin_apps/sd/components/builders.py +++ b/shortfin/python/shortfin_apps/sd/components/builders.py @@ -5,9 +5,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from iree.build import * -from iree.build.executor import FileNamespace +from iree.build.executor import FileNamespace, BuildAction, BuildContext, BuildFile import itertools import os +import urllib import shortfin.array as sfnp import copy @@ -24,7 +25,7 @@ sfnp.bfloat16: "bf16", } -ARTIFACT_VERSION = "11132024" +ARTIFACT_VERSION = "11182024" SDXL_BUCKET = ( f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/" ) @@ -162,34 +163,93 @@ def needs_update(ctx): return False -def needs_file(filename, ctx, namespace=FileNamespace.GEN): +def needs_file(filename, ctx, url=None, namespace=FileNamespace.GEN): out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() + needed = True if os.path.exists(out_file): - needed = False - else: - # name_path = "bin" if namespace == FileNamespace.BIN else "" - # if name_path: - # filename = os.path.join(name_path, filename) - filekey = os.path.join(ctx.path, filename) - ctx.executor.all[filekey] = None - needed = True - return needed + if url: + needed = not is_valid_size(out_file, url) + if not needed: + return False + filekey = os.path.join(ctx.path, filename) + ctx.executor.all[filekey] = None + return True def needs_compile(filename, target, ctx): - device = "amdgpu" if "gfx" in target else "llvmcpu" - vmfb_name = f"{filename}_{device}-{target}.vmfb" + vmfb_name = f"{filename}_{target}.vmfb" namespace = FileNamespace.BIN - return needs_file(vmfb_name, ctx, namespace) + return needs_file(vmfb_name, ctx, namespace=namespace) def get_cached_vmfb(filename, target, ctx): - device = "amdgpu" if "gfx" in target else "llvmcpu" - vmfb_name = f"{filename}_{device}-{target}.vmfb" - namespace = FileNamespace.BIN + vmfb_name = f"{filename}_{target}.vmfb" return ctx.file(vmfb_name) +def is_valid_size(file_path, url): + if not url: + return True + with urllib.request.urlopen(url) as response: + content_length = response.getheader("Content-Length") + local_size = get_file_size(str(file_path)) + if content_length: + content_length = int(content_length) + if content_length != local_size: + return False + return True + + +def get_file_size(file_path): + """Gets the size of a local file in bytes as an integer.""" + + file_stats = os.stat(file_path) + return file_stats.st_size + + +def fetch_http_check_size(*, name: str, url: str) -> BuildFile: + context = BuildContext.current() + output_file = context.allocate_file(name) + action = FetchHttpWithCheckAction( + url=url, output_file=output_file, desc=f"Fetch {url}", executor=context.executor + ) + output_file.deps.add(action) + return output_file + + +class FetchHttpWithCheckAction(BuildAction): + def __init__(self, 
url: str, output_file: BuildFile, **kwargs):
+        super().__init__(**kwargs)
+        self.url = url
+        self.output_file = output_file
+
+    def _invoke(self, retries=4):
+        path = self.output_file.get_fs_path()
+        self.executor.write_status(f"Fetching URL: {self.url} -> {path}")
+        try:
+            urllib.request.urlretrieve(self.url, str(path))
+        except urllib.error.HTTPError as e:
+            if retries > 0:
+                retries -= 1
+                # The recursive retry performs its own size check below, so
+                # return here rather than validating the same file twice.
+                self._invoke(retries=retries)
+                return
+            else:
+                raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None
+        local_size = get_file_size(str(path))
+        try:
+            with urllib.request.urlopen(self.url) as response:
+                content_length = response.getheader("Content-Length")
+                if content_length:
+                    content_length = int(content_length)
+                    if content_length != local_size:
+                        raise IOError(
+                            f"Size of downloaded artifact does not match content-length header! {content_length} != {local_size}"
+                        )
+        except IOError:
+            if retries > 0:
+                retries -= 1
+                self._invoke(retries=retries)
+            else:
+                # Out of retries: re-raise instead of silently accepting a
+                # truncated download.
+                raise
+
+
 @entrypoint(description="Retrieves a set of SDXL submodels.")
 def sdxl(
     model_json=cl_arg(
@@ -224,7 +284,7 @@ def sdxl(
     mlir_filenames = get_mlir_filenames(model_params, model)
     mlir_urls = get_url_map(mlir_filenames, mlir_bucket)
     for f, url in mlir_urls.items():
-        if update or needs_file(f, ctx):
+        if update or needs_file(f, ctx, url):
             fetch_http(name=f, url=url)
 
     vmfb_filenames = get_vmfb_filenames(model_params, model=model, target=target)
@@ -244,15 +304,14 @@
                 vmfb_filenames[idx] = get_cached_vmfb(file_stem, target, ctx)
     else:
         for f, url in vmfb_urls.items():
-            if update or needs_file(f, ctx):
+            if update or needs_file(f, ctx, url):
                 fetch_http(name=f, url=url)
 
     params_filenames = get_params_filenames(model_params, model=model, splat=splat)
     params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
     for f, url in params_urls.items():
-        out_file = os.path.join(ctx.executor.output_dir, f)
-        if needs_file(f, ctx):
-            fetch_http(name=f, url=url)
+        if needs_file(f, ctx, url):
+            fetch_http_check_size(name=f, url=url)
 
     filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames]
     return filenames
diff --git a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py
index f3502f22e..432f08b4e 100644
--- a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py
+++ b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py
@@ -6,12 +6,9 @@
 from iree.build import *
 from iree.build.executor import FileNamespace
-import itertools
 import os
-import shortfin.array as sfnp
-import copy
 
-ARTIFACT_VERSION = "11132024"
+ARTIFACT_VERSION = "11182024"
 SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/"
 
@@ -72,21 +69,18 @@ def sdxlconfig(
     model_config_filenames = [f"{model}_config_i8.json"]
     model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET)
     for f, url in model_config_urls.items():
-        out_file = os.path.join(ctx.executor.output_dir, f)
         if update or needs_file(f, ctx):
             fetch_http(name=f, url=url)
 
     topology_config_filenames = [f"topology_config_{topology}.txt"]
     topology_config_urls = get_url_map(topology_config_filenames, SDXL_CONFIG_BUCKET)
     for f, url in topology_config_urls.items():
-        out_file = os.path.join(ctx.executor.output_dir, f)
         if update or needs_file(f, ctx):
             fetch_http(name=f, url=url)
 
     flagfile_filenames = [f"{model}_flagfile_{target}.txt"]
     flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET)
     for f, url in flagfile_urls.items():
-        out_file =
os.path.join(ctx.executor.output_dir, f) if update or needs_file(f, ctx): fetch_http(name=f, url=url) @@ -95,7 +89,6 @@ def sdxlconfig( ) tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET) for f, url in tuning_urls.items(): - out_file = os.path.join(ctx.executor.output_dir, f) if update or needs_file(f, ctx): fetch_http(name=f, url=url) filenames = [ diff --git a/shortfin/python/shortfin_apps/sd/components/config_struct.py b/shortfin/python/shortfin_apps/sd/components/config_struct.py index 478d03ad8..2b954c18b 100644 --- a/shortfin/python/shortfin_apps/sd/components/config_struct.py +++ b/shortfin/python/shortfin_apps/sd/components/config_struct.py @@ -15,7 +15,6 @@ from dataclasses import dataclass from pathlib import Path -import dataclasses_json from dataclasses_json import dataclass_json, Undefined import shortfin.array as sfnp diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py index 1afa73d5e..62ac5e855 100644 --- a/shortfin/python/shortfin_apps/sd/components/generate.py +++ b/shortfin/python/shortfin_apps/sd/components/generate.py @@ -5,18 +5,16 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import asyncio -import io import logging import json import shortfin as sf -import shortfin.array as sfnp # TODO: Have a generic "Responder" interface vs just the concrete impl. from shortfin.interop.fastapi import FastAPIResponder from .io_struct import GenerateReqInput -from .messages import InferenceExecRequest, InferencePhase +from .messages import InferenceExecRequest from .service import GenerateService from .metrics import measure @@ -83,7 +81,6 @@ def __init__( self.batcher = service.batcher self.complete_infeed = self.system.create_queue() - @measure(type="throughput", num_items="num_output_images", freq=1, label="samples") async def run(self): logger.debug("Started ClientBatchGenerateProcess: %r", self) try: diff --git a/shortfin/python/shortfin_apps/sd/components/io_struct.py b/shortfin/python/shortfin_apps/sd/components/io_struct.py index d1d9cf41a..73e77316f 100644 --- a/shortfin/python/shortfin_apps/sd/components/io_struct.py +++ b/shortfin/python/shortfin_apps/sd/components/io_struct.py @@ -4,12 +4,10 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union from dataclasses import dataclass import uuid -import shortfin.array as sfnp - @dataclass class GenerateReqInput: diff --git a/shortfin/python/shortfin_apps/sd/components/manager.py b/shortfin/python/shortfin_apps/sd/components/manager.py index ea29b69a4..e416592d0 100644 --- a/shortfin/python/shortfin_apps/sd/components/manager.py +++ b/shortfin/python/shortfin_apps/sd/components/manager.py @@ -25,7 +25,7 @@ def __init__(self, device="local-task", device_ids=None, async_allocs=True): sb.visible_devices = sb.available_devices sb.visible_devices = get_selected_devices(sb, device_ids) self.ls = sb.create_system() - logging.info(f"Created local system with {self.ls.device_names} devices") + logger.info(f"Created local system with {self.ls.device_names} devices") # TODO: Come up with an easier bootstrap thing than manually # running a thread. 
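        # A sketch of the lifecycle implied by the code below and by shutdown()/
        # run(): the system loop runs on a plain Python thread; shutdown() closes
        # self.command_queue, which makes `await reader()` in run() return a
        # falsy value, ending the `while command := await reader():` loop so the
        # thread can exit before self.ls.shutdown() tears the system down.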
self.t = threading.Thread(target=lambda: self.ls.run(self.run())) @@ -39,9 +39,10 @@ def start(self): def shutdown(self): logger.info("Shutting down system manager") self.command_queue.close() + self.ls.shutdown() async def run(self): reader = self.command_queue.reader() while command := await reader(): ... - logging.info("System manager command processor stopped") + logger.info("System manager command processor stopped") diff --git a/shortfin/python/shortfin_apps/sd/components/metrics.py b/shortfin/python/shortfin_apps/sd/components/metrics.py index a1811beea..62e855698 100644 --- a/shortfin/python/shortfin_apps/sd/components/metrics.py +++ b/shortfin/python/shortfin_apps/sd/components/metrics.py @@ -6,8 +6,7 @@ import logging import time -import asyncio -from typing import Callable, Any +from typing import Any import functools logger = logging.getLogger("shortfin-sd.metrics") diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index ad3fd9404..9b09632a6 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -6,12 +6,10 @@ import asyncio import logging -import math import numpy as np from tqdm.auto import tqdm from pathlib import Path from PIL import Image -import io import base64 import shortfin as sf @@ -23,9 +21,7 @@ from .tokenizer import Tokenizer from .metrics import measure - logger = logging.getLogger("shortfin-sd.service") -logger.setLevel(logging.DEBUG) prog_isolations = { "none": sf.ProgramIsolation.NONE, @@ -79,23 +75,32 @@ def __init__( self.workers = [] self.fibers = [] - self.fiber_status = [] + self.idle_fibers = set() for idx, device in enumerate(self.sysman.ls.devices): for i in range(self.workers_per_device): worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}") self.workers.append(worker) + for idx, device in enumerate(self.sysman.ls.devices): for i in range(self.fibers_per_device): - fiber = sysman.ls.create_fiber( - self.workers[i % len(self.workers)], devices=[device] - ) + tgt_worker = self.workers[i % len(self.workers)] + fiber = sysman.ls.create_fiber(tgt_worker, devices=[device]) self.fibers.append(fiber) - self.fiber_status.append(0) + self.idle_fibers.add(fiber) for idx in range(len(self.workers)): self.inference_programs[idx] = {} self.inference_functions[idx] = {} # Scope dependent objects. self.batcher = BatcherProcess(self) + def get_worker_index(self, fiber): + if fiber not in self.fibers: + raise ValueError("A worker was requested from a rogue fiber.") + fiber_idx = self.fibers.index(fiber) + worker_idx = int( + (fiber_idx - fiber_idx % self.fibers_per_worker) / self.fibers_per_worker + ) + return worker_idx + def load_inference_module(self, vmfb_path: Path, component: str = None): if not self.inference_modules.get(component): self.inference_modules[component] = [] @@ -112,7 +117,7 @@ def load_inference_parameters( ): p = sf.StaticProgramParameters(self.sysman.ls, parameter_scope=parameter_scope) for path in paths: - logging.info("Loading parameter fiber '%s' from: %s", parameter_scope, path) + logger.info("Loading parameter fiber '%s' from: %s", parameter_scope, path) p.load(path, format=format) if not self.inference_parameters.get(component): self.inference_parameters[component] = [] @@ -121,6 +126,7 @@ def load_inference_parameters( def start(self): # Initialize programs. 
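        # Rough shape of the structures this loop fills in, inferred from their
        # uses elsewhere in this class (a sketch, not an exhaustive spec):
        #   self.inference_programs[worker_idx][component] -> sf.Program
        #   self.inference_functions[worker_idx][phase]    -> loaded entrypoints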
for component in self.inference_modules: + logger.info(f"Loading component: {component}") component_modules = [ sf.ProgramModule.parameter_provider( self.sysman.ls, *self.inference_parameters.get(component, []) @@ -141,7 +147,6 @@ def start(self): isolation=self.prog_isolation, trace_execution=self.trace_execution, ) - logger.info("Program loaded.") for worker_idx, worker in enumerate(self.workers): self.inference_functions[worker_idx]["encode"] = {} @@ -270,14 +275,17 @@ def board_flights(self): return self.strobes = 0 batches = self.sort_batches() - for idx, batch in batches.items(): - for fidx, status in enumerate(self.service.fiber_status): - if ( - status == 0 - or self.service.prog_isolation == sf.ProgramIsolation.PER_CALL - ): - self.board(batch["reqs"], index=fidx) - break + for batch in batches.values(): + # Assign the batch to the next idle fiber. + if len(self.service.idle_fibers) == 0: + return + fiber = self.service.idle_fibers.pop() + fiber_idx = self.service.fibers.index(fiber) + worker_idx = self.service.get_worker_index(fiber) + logger.debug(f"Sending batch to fiber {fiber_idx} (worker {worker_idx})") + self.board(batch["reqs"], fiber=fiber) + if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER: + self.service.idle_fibers.add(fiber) def sort_batches(self): """Files pending requests into sorted batches suitable for program invocations.""" @@ -310,11 +318,11 @@ def sort_batches(self): } return batches - def board(self, request_bundle, index): + def board(self, request_bundle, fiber): pending = request_bundle if len(pending) == 0: return - exec_process = InferenceExecutorProcess(self.service, index) + exec_process = InferenceExecutorProcess(self.service, fiber) for req in pending: if len(exec_process.exec_requests) >= self.ideal_batch_size: break @@ -322,8 +330,6 @@ def board(self, request_bundle, index): if exec_process.exec_requests: for flighted_request in exec_process.exec_requests: self.pending_requests.remove(flighted_request) - if self.service.prog_isolation != sf.ProgramIsolation.PER_CALL: - self.service.fiber_status[index] = 1 exec_process.launch() @@ -338,15 +344,11 @@ class InferenceExecutorProcess(sf.Process): def __init__( self, service: GenerateService, - index: int, + fiber, ): - super().__init__(fiber=service.fibers[index]) + super().__init__(fiber=fiber) self.service = service - self.fiber_index = index - self.worker_index = int( - (index - index % self.service.fibers_per_worker) - / self.service.fibers_per_worker - ) + self.worker_index = self.service.get_worker_index(fiber) self.exec_requests: list[InferenceExecRequest] = [] @measure(type="exec", task="inference process") @@ -360,7 +362,7 @@ async def run(self): phase = req.phase phases = self.exec_requests[0].phases req_count = len(self.exec_requests) - device0 = self.service.fibers[self.fiber_index].device(0) + device0 = self.fiber.device(0) if phases[InferencePhase.PREPARE]["required"]: await self._prepare(device=device0, requests=self.exec_requests) if phases[InferencePhase.ENCODE]["required"]: @@ -375,7 +377,8 @@ async def run(self): for i in range(req_count): req = self.exec_requests[i] req.done.set_success() - self.service.fiber_status[self.fiber_index] = 0 + if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER: + self.service.idle_fibers.add(self.fiber) except Exception: logger.exception("Fatal error in image generation") @@ -574,7 +577,7 @@ async def _denoise(self, device, requests): for i, t in tqdm( enumerate(range(step_count)), disable=(not self.service.show_progress), - 
desc=f"Worker #{self.worker_index} DENOISE (bs{req_bs})", + desc=f"DENOISE (bs{req_bs})", ): step = sfnp.device_array.for_device(device, [1], sfnp.sint64) s_host = step.for_transfer() diff --git a/shortfin/python/shortfin_apps/sd/components/tokenizer.py b/shortfin/python/shortfin_apps/sd/components/tokenizer.py index 2bd3781d1..5903d89a5 100644 --- a/shortfin/python/shortfin_apps/sd/components/tokenizer.py +++ b/shortfin/python/shortfin_apps/sd/components/tokenizer.py @@ -4,12 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from pathlib import Path - from transformers import CLIPTokenizer, BatchEncoding -import numpy as np - import shortfin as sf import shortfin.array as sfnp diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 9cd624241..4e3835690 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -5,23 +5,21 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import Any - import argparse import logging from pathlib import Path import sys import os -import io import copy import subprocess +from contextlib import asynccontextmanager +import uvicorn # Import first as it does dep checking and reporting. from shortfin.interop.fastapi import FastAPIResponder - -from contextlib import asynccontextmanager +from shortfin.support.logging_setup import native_handler from fastapi import FastAPI, Request, Response -import uvicorn from .components.generate import ClientGenerateBatchProcess from .components.config_struct import ModelParams @@ -29,25 +27,49 @@ from .components.manager import SystemManager from .components.service import GenerateService from .components.tokenizer import Tokenizer -from .components.builders import sdxl -from shortfin.support.logging_setup import native_handler, configure_main_logger logger = logging.getLogger("shortfin-sd") logger.addHandler(native_handler) -logger.setLevel(logging.INFO) logger.propagate = False THIS_DIR = Path(__file__).resolve().parent +UVICORN_LOG_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "()": "uvicorn.logging.DefaultFormatter", + "format": "[{asctime}] {message}", + "datefmt": "%Y-%m-%d %H:%M:%S", + "style": "{", + "use_colors": True, + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + }, + }, + "loggers": { + "uvicorn": { + "handlers": ["console"], + "level": "INFO", + "propagate": False, + }, + }, +} + @asynccontextmanager async def lifespan(app: FastAPI): sysman.start() try: for service_name, service in services.items(): - logging.info("Initializing service '%s':", service_name) - logging.info(str(service)) + logger.info("Initializing service '%s':", service_name) + logger.info(str(service)) service.start() except: sysman.shutdown() @@ -55,7 +77,7 @@ async def lifespan(app: FastAPI): yield try: for service_name, service in services.items(): - logging.info("Shutting down service '%s'", service_name) + logger.info("Shutting down service '%s'", service_name) service.shutdown() finally: sysman.shutdown() @@ -83,11 +105,14 @@ async def generate_request(gen_req: GenerateReqInput, request: Request): app.put("/generate")(generate_request) -def configure(args) -> SystemManager: +def configure_sys(args) -> SystemManager: # Setup system (configure devices, etc). 
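    # Intended call order, mirroring main() below (a sketch; note the actual
    # return value is the tuple (sysman, model_config, flagfile, tuning_spec),
    # not just the SystemManager named in the annotation):
    #   sysman, model_config, flagfile, tuning_spec = configure_sys(args)
    #   configure_service(args, sysman, model_config, flagfile, tuning_spec)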
model_config, topology_config, flagfile, tuning_spec, args = get_configs(args) sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations) + return sysman, model_config, flagfile, tuning_spec + +def configure_service(args, sysman, model_config, flagfile, tuning_spec): # Setup each service we are hosting. tokenizers = [] for idx, tok_name in enumerate(args.tokenizers): @@ -163,13 +188,13 @@ def get_configs(args): try: val = int(val) except ValueError: - continue + val = val elif len(arglist) == 2: value = arglist[-1] try: value = int(value) except ValueError: - continue + value = value else: # It's a boolean arg. value = True @@ -178,7 +203,6 @@ def get_configs(args): # It's an env var. arglist = spec.split("=") os.environ[arglist[0]] = arglist[1] - return model_config, topology_config, flagfile, tuning_spec, args @@ -207,6 +231,7 @@ def get_modules(args, model_config, flagfile, td_spec): filenames = [] for modelname in vmfbs.keys(): ireec_args = model_flags["all"] + model_flags[modelname] + ireec_extra_args = " ".join(ireec_args) builder_args = [ sys.executable, "-m", @@ -220,8 +245,12 @@ def get_modules(args, model_config, flagfile, td_spec): f"--model={modelname}", f"--iree-hal-target-device={args.device}", f"--iree-hip-target={args.target}", - f"--iree-compile-extra-args={' '.join(ireec_args)}", + f"--iree-compile-extra-args={ireec_extra_args}", ] + logger.info(f"Preparing runtime artifacts for {modelname}...") + logger.debug( + "COMMAND LINE EQUIVALENT: " + " ".join([str(argn) for argn in builder_args]) + ) output = subprocess.check_output(builder_args).decode() output_paths = output.splitlines() @@ -229,16 +258,14 @@ def get_modules(args, model_config, flagfile, td_spec): for name in filenames: for key in vmfbs.keys(): if key in name.lower(): - if any([x in name for x in [".irpa", ".safetensors", ".gguf"]]): + if any(x in name for x in [".irpa", ".safetensors", ".gguf"]): params[key].extend([name]) elif "vmfb" in name: vmfbs[key].extend([name]) return vmfbs, params -def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): - from pathlib import Path - +def main(argv, log_config=UVICORN_LOG_CONFIG): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8000) @@ -257,7 +284,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): type=str, required=False, default="gfx942", - choices=["gfx942", "gfx1100"], + choices=["gfx942", "gfx1100", "gfx90a"], help="Primary inferencing device LLVM target arch.", ) parser.add_argument( @@ -297,7 +324,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): parser.add_argument( "--isolation", type=str, - default="per_fiber", + default="per_call", choices=["per_fiber", "per_call", "none"], help="Concurrency control -- How to isolate programs.", ) @@ -365,15 +392,17 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): default=1, help="Use tunings for attention and matmul ops. 
0 to disable.", ) - args = parser.parse_args(argv) if not args.artifacts_dir: home = Path.home() artdir = home / ".cache" / "shark" args.artifacts_dir = str(artdir) + else: + args.artifacts_dir = Path(args.artifacts_dir).resolve() global sysman - sysman = configure(args) + sysman, model_config, flagfile, tuning_spec = configure_sys(args) + configure_service(args, sysman, model_config, flagfile, tuning_spec) uvicorn.run( app, host=args.host, @@ -388,27 +417,5 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): main( sys.argv[1:], # Make logging defer to the default shortfin logging config. - log_config={ - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "default": { - "format": "%(asctime)s - %(levelname)s - %(message)s", - "datefmt": "%Y-%m-%d %H:%M:%S", - }, - }, - "handlers": { - "console": { - "class": "logging.StreamHandler", - "formatter": "default", - }, - }, - "loggers": { - "uvicorn": { - "handlers": ["console"], - "level": "INFO", - "propagate": False, - }, - }, - }, + log_config=UVICORN_LOG_CONFIG, ) diff --git a/shortfin/python/shortfin_apps/sd/simple_client.py b/shortfin/python/shortfin_apps/sd/simple_client.py index bc0f10655..0d88a59c7 100644 --- a/shortfin/python/shortfin_apps/sd/simple_client.py +++ b/shortfin/python/shortfin_apps/sd/simple_client.py @@ -4,17 +4,17 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from datetime import datetime as dt +import os +import sys +import time import json -import requests import argparse import base64 -import time import asyncio import aiohttp -import sys -import os +import requests -from datetime import datetime as dt from PIL import Image sample_request = { @@ -32,10 +32,10 @@ } -def bytes_to_img(bytes, outputdir, idx=0, width=1024, height=1024): +def bytes_to_img(in_bytes, outputdir, idx=0, width=1024, height=1024): timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") image = Image.frombytes( - mode="RGB", size=(width, height), data=base64.b64decode(bytes) + mode="RGB", size=(width, height), data=base64.b64decode(in_bytes) ) if not os.path.isdir(outputdir): os.mkdir(outputdir) @@ -58,14 +58,13 @@ def get_batched(request, arg, idx): async def send_request(session, rep, args, data): print("Sending request batch #", rep) - url = f"http://0.0.0.0:{args.port}/generate" + url = f"{args.host}:{args.port}/generate" start = time.time() async with session.post(url, json=data) as response: end = time.time() # Check if the response was successful if response.status == 200: response.raise_for_status() # Raise an error for bad responses - timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") res_json = await response.json(content_type=None) if args.save: for idx, item in enumerate(res_json["images"]): @@ -78,9 +77,8 @@ async def send_request(session, rep, args, data): latency = end - start print("Responses processed.") return latency, len(data["prompt"]) - else: - print(f"Error: Received {response.status} from server") - raise Exception + print(f"Error: Received {response.status} from server") + raise Exception async def static(args): @@ -116,7 +114,7 @@ async def static(args): latencies.append(latency) sample_counts.append(num_samples) end = time.time() - if not any([i is None for i in [latencies, sample_counts]]): + if not any(i is None for i in [latencies, sample_counts]): total_num_samples = sum(sample_counts) sps = str(total_num_samples / (end - start)) # Until we have better measurements, don't report the throughput that includes saving 
images. @@ -163,9 +161,9 @@ async def interactive(args): pending, return_when=asyncio.ALL_COMPLETED ) for task in done: - latency, num_samples = await task + _, _ = await task pending = [] - if any([i is None for i in [latencies, sample_counts]]): + if any(i is None for i in [latencies, sample_counts]): raise ValueError("Received error response from server.") @@ -175,11 +173,27 @@ async def ainput(prompt: str) -> str: async def async_range(count): for i in range(count): - yield (i) + yield i await asyncio.sleep(0.0) -def main(argv): +def check_health(url): + ready = False + print("Waiting for server.", end=None) + while not ready: + try: + if requests.get(f"{url}/health", timeout=20).status_code == 200: + print("Successfully connected to server.") + ready = True + return + time.sleep(2) + print(".", end=None) + except: + time.sleep(2) + print(".", end=None) + + +def main(): p = argparse.ArgumentParser() p.add_argument( "--file", @@ -205,6 +219,9 @@ def main(argv): default="gen_imgs", help="Directory to which images get saved.", ) + p.add_argument( + "--host", type=str, default="http://0.0.0.0", help="Server host address." + ) p.add_argument("--port", type=str, default="8000", help="Server port") p.add_argument( "--steps", @@ -218,6 +235,7 @@ def main(argv): help="Start as an example CLI client instead of sending static requests.", ) args = p.parse_args() + check_health(f"{args.host}:{args.port}") if args.interactive: asyncio.run(interactive(args)) else: @@ -225,4 +243,4 @@ def main(argv): if __name__ == "__main__": - main(sys.argv) + main() diff --git a/shortfin/requirements-iree-compiler.txt b/shortfin/requirements-iree-compiler.txt index ec033c57c..ada82f2eb 100644 --- a/shortfin/requirements-iree-compiler.txt +++ b/shortfin/requirements-iree-compiler.txt @@ -1,4 +1,4 @@ # Keep in sync with "ref: iree-" in .github/workflows/* and GIT_TAG in CMakeLists.txt -f https://iree.dev/pip-release-links.html -iree-base-compiler==3.0.0rc20241115 -iree-base-runtime==3.0.0rc20241115 +iree-base-compiler==3.0.0rc20241118 +iree-base-runtime==3.0.0rc20241118 diff --git a/shortfin/src/shortfin/local/worker.cc b/shortfin/src/shortfin/local/worker.cc index d5ffafdbe..eed500891 100644 --- a/shortfin/src/shortfin/local/worker.cc +++ b/shortfin/src/shortfin/local/worker.cc @@ -46,8 +46,8 @@ Worker::Worker(const Options options) iree_status_ignore(status); }; // TODO: We need a way to dynamically resize this vs having a hard limit. - iree_loop_sync_options_t loop_options = {.max_queue_depth = 256, - .max_wait_count = 256}; + iree_loop_sync_options_t loop_options = {.max_queue_depth = 2048, + .max_wait_count = 2048}; SHORTFIN_THROW_IF_ERROR( iree_loop_sync_allocate(loop_options, options_.allocator, &loop_sync_)); iree_loop_sync_scope_initialize(loop_sync_, OnError, this, &loop_scope_); diff --git a/shortfin/version.json b/shortfin/version.json index f09f61d2a..85afb41ed 100644 --- a/shortfin/version.json +++ b/shortfin/version.json @@ -1,3 +1,3 @@ { - "package-version": "2.9.2.dev" + "package-version": "3.0.0.dev" } diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index b50df12d5..ee331a2a6 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -517,48 +517,52 @@ def tune( with ir.Context() as ctx: tuner_context = TunerContext(ctx, tune_logger) - mlir_module: ir.Module = parse_mlir(mlir_text, tuner_context) - # Save the input file as the first candidate. 
- with open(path.join(output, f"0.mlir"), "w") as f: - f.write(mlir_text) - - dispatch_tuner_registry = DispatchTunerRegistry() - dispatch_tuner_registry.register( - [ - MmtTuner(), - ConvTuner(), - ContractionTuner(lhs_dims, rhs_dims, tile_dims), - BatchMmtTuner(), - BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), - ] - ) + with parse_mlir(mlir_text, tuner_context) as mlir_module: + # Save the input file as the first candidate. + with open(path.join(output, f"0.mlir"), "w") as f: + f.write(mlir_text) + + dispatch_tuner_registry = DispatchTunerRegistry() + dispatch_tuner_registry.register( + [ + MmtTuner(), + ConvTuner(), + ContractionTuner(lhs_dims, rhs_dims, tile_dims), + BatchMmtTuner(), + BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), + ] + ) - walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry) - - dispatch_tuner = walk_result.dispatch_tuner - assert dispatch_tuner, "No suitable dispatch tuner found" - problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) - tune_logger.debug(str(problem_size)) - configs = [] - for i, config in enumerate( - generate_solutions(tuner_context, problem_size, num_subgroups) - ): - if i >= limit: - break - tune_logger.info(f"Solution #{i+1}: {config}") - configs.append(config) - tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) - - with open(path.join(output, f"{i+1}.mlir"), "w") as f: - f.write(tf_mlir.modified) - with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: - f.write(tf_mlir.embeddable) - - with open(path.join(output, "configs.pkl"), "wb") as file: - pickle.dump(configs, file) - - tune_logger.info(f"Generated {len(configs)} candidates") - tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") + walk_result: OpWalkResult = walk_mlir_op( + mlir_module, dispatch_tuner_registry + ) + + dispatch_tuner = walk_result.dispatch_tuner + assert dispatch_tuner, "No suitable dispatch tuner found" + problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) + tune_logger.debug(str(problem_size)) + configs = [] + for i, config in enumerate( + generate_solutions(tune_logger, problem_size, num_subgroups) + ): + if i >= limit: + break + tune_logger.info(f"Solution #{i+1}: {config}") + configs.append(config) + tf_mlir = dispatch_tuner.apply_params( + problem_size, mlir_template, config + ) + + with open(path.join(output, f"{i+1}.mlir"), "w") as f: + f.write(tf_mlir.modified) + with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: + f.write(tf_mlir.embeddable) + + with open(path.join(output, "configs.pkl"), "wb") as file: + pickle.dump(configs, file) + + tune_logger.info(f"Generated {len(configs)} candidates") + tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") def main(): diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 47e351fc7..36fb87cbb 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -10,17 +10,31 @@ import pytest +from typing import Generator + +from iree.compiler import ir # type: ignore + from . import candidate_gen from . 
import common +@pytest.fixture +def tuner_ctx() -> Generator[common.TunerContext, None, None]: + from logging import Logger + from unittest.mock import MagicMock + + with ir.Context() as ctx: + logger: Logger = MagicMock(spec=Logger) + yield common.TunerContext(ctx, logger) + + def remove_comments(mlir: str) -> str: return "\n".join( filter(lambda x: not x.lstrip().startswith("//"), mlir.splitlines()) ) -def test_apply_params_mmt() -> None: +def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", " None: problem_size = common.ProblemSize( common.MatmulSize(M, N, K), - common.ShapedType([M, K], common.ElementType.f16), - common.ShapedType([N, K], common.ElementType.f16), - common.ShapedType([M, N], common.ElementType.f32), + common.ShapedType([M, K], tuner_ctx.type.f16), + common.ShapedType([N, K], tuner_ctx.type.f16), + common.ShapedType([M, N], tuner_ctx.type.f32), common.DispatchKind.mmt, ) tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) @@ -73,7 +87,7 @@ def test_apply_params_mmt() -> None: assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified -def test_apply_params_conv() -> None: +def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", " None: problem_size = common.ProblemSize( common.MatmulSize(oh * ow, oc, fh * fw * ic), - common.ShapedType([n, oh + 2, ow + 2, oc], common.ElementType.f16), - common.ShapedType([fh, fw, ic, oc], common.ElementType.f16), - common.ShapedType([n, oh, ow, oc], common.ElementType.f32), + common.ShapedType([n, oh + 2, ow + 2, oc], tuner_ctx.type.f16), + common.ShapedType([fh, fw, ic, oc], tuner_ctx.type.f16), + common.ShapedType([n, oh, ow, oc], tuner_ctx.type.f32), common.DispatchKind.conv, ) tf_mlir = candidate_gen.ConvTuner().apply_params( @@ -130,7 +144,7 @@ def test_apply_params_conv() -> None: assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified -def test_apply_params_contract() -> None: +def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 2, subgroup_n_count = 2>}>", " None: tile_dims = "*mnk" problem_size = common.ProblemSize( common.MatmulSize(2048, 3840, 1280), - common.ShapedType([2, 1024, 1280], common.ElementType.f16), - common.ShapedType([3, 20, 64, 1280], common.ElementType.f16), - common.ShapedType([3, 2, 20, 1024, 64], common.ElementType.f32), + common.ShapedType([2, 1024, 1280], tuner_ctx.type.f16), + common.ShapedType([3, 20, 64, 1280], tuner_ctx.type.f16), + common.ShapedType([3, 2, 20, 1024, 64], tuner_ctx.type.f32), common.DispatchKind.contraction, ) @@ -177,7 +191,7 @@ def test_apply_params_contract() -> None: assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir -def test_apply_params_batch_matmul() -> None: +def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: tile_dims = "bmnk" problem_size = common.ProblemSize( common.MatmulSize(968, 320, 640, 64), - common.ShapedType([64, 968, 640], common.ElementType.f16), - common.ShapedType([64, 640, 320], common.ElementType.f16), - common.ShapedType([64, 968, 320], common.ElementType.f32), + common.ShapedType([64, 968, 640], tuner_ctx.type.f16), + common.ShapedType([64, 640, 320], tuner_ctx.type.f16), + common.ShapedType([64, 968, 320], tuner_ctx.type.f32), common.DispatchKind.batch_matmul, ) 
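These tests all migrate from the removed `common.ElementType` enum to real MLIR element types hung off `TunerContext` (via `CommonTypes` in common.py). A minimal sketch of that fixture pattern, assuming `iree.compiler` is installed and using a hypothetical top-level import path (the tests themselves use `from . import common`):

```python
from logging import Logger
from unittest.mock import MagicMock

from iree.compiler import ir  # type: ignore

from tuner import common  # hypothetical path; run from the tuner/ directory

# A live ir.Context supplies real element types; a MagicMock stands in for the logger.
with ir.Context() as ctx:
    tuner_ctx = common.TunerContext(ctx, MagicMock(spec=Logger))
    # ShapedType.bitwidth now defers to the MLIR element type's width.
    assert common.ShapedType([128, 256], tuner_ctx.type.f16).bitwidth == 16
    assert str(common.ShapedType([128, 256], tuner_ctx.type.f16)) == "128x256xf16"
```
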
@@ -228,7 +242,7 @@ def test_apply_params_batch_matmul() -> None: assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified -def test_apply_params_batch_mmt_float() -> None: +def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: problem_size = common.ProblemSize( common.MatmulSize(4096, 640, 640, 2), - common.ShapedType([2, 4096, 640], common.ElementType.f16), - common.ShapedType([2, 640, 640], common.ElementType.f16), - common.ShapedType([2, 4096, 640], common.ElementType.f32), + common.ShapedType([2, 4096, 640], tuner_ctx.type.f16), + common.ShapedType([2, 640, 640], tuner_ctx.type.f16), + common.ShapedType([2, 4096, 640], tuner_ctx.type.f32), common.DispatchKind.batch_mmt, ) @@ -276,7 +290,7 @@ def test_apply_params_batch_mmt_float() -> None: assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified -def test_apply_params_batch_mmt_int() -> None: +def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: problem_size = common.ProblemSize( common.MatmulSize(4096, 640, 640, 2), - common.ShapedType([2, 4096, 640], common.ElementType.i8), - common.ShapedType([2, 640, 640], common.ElementType.i8), - common.ShapedType([2, 4096, 640], common.ElementType.i32), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i8), + common.ShapedType([2, 640, 640], tuner_ctx.type.i8), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i32), common.DispatchKind.batch_mmt, ) @@ -347,7 +361,7 @@ def test_apply_params_batch_mmt_int() -> None: assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable -def test_apply_params_broadcast_rhs_mmt() -> None: +def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: problem_size = common.ProblemSize( common.MatmulSize(4096, 640, 640, 2), - common.ShapedType([2, 4096, 640], common.ElementType.i8), - common.ShapedType([640, 640], common.ElementType.i8), - common.ShapedType([2, 4096, 640], common.ElementType.i32), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i8), + common.ShapedType([640, 640], tuner_ctx.type.i8), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i32), common.DispatchKind.broadcast_rhs_mmt, ) @@ -422,7 +436,7 @@ def test_apply_params_broadcast_rhs_mmt() -> None: assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable -def test_detect_broadcast_rhs_mmt() -> None: +def test_detect_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mlir_lines = [ r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 7b295cdb0..a34f172eb 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -13,10 +13,27 @@ from iree.compiler import ir # type: ignore +class CommonTypes: + def __init__(self, ctx: ir.Context): + assert ctx + self.i1 = ir.IntegerType.get_signless(1, ctx) + self.i8 = ir.IntegerType.get_signless(8, ctx) + self.i16 = ir.IntegerType.get_signless(16, ctx) + self.i32 = ir.IntegerType.get_signless(32, ctx) + + self.f8E4M3FNUZ = ir.Float8E4M3FNUZType.get(ctx) + self.f8E5M2FNUZ = ir.Float8E5M2FNUZType.get(ctx) + self.f16 = ir.F16Type.get(ctx) + self.f32 = ir.F32Type.get(ctx) + + self.bf16 
= ir.BF16Type.get(ctx) + + class TunerContext: def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger): - self.mlir_ctx = mlir_ctx - self.logger = logger + self.mlir_ctx: ir.Context = mlir_ctx + self.logger: logging.Logger = logger + self.type: CommonTypes = CommonTypes(mlir_ctx) class DispatchKind(Enum): @@ -28,40 +45,17 @@ class DispatchKind(Enum): broadcast_rhs_mmt = 6 -class ElementType(Enum): - i8 = 1 - i32 = 2 - f8 = 3 - f16 = 4 - f32 = 5 - - @property - def bitwidth(self) -> int: - match self: - case ElementType.i8 | ElementType.f8: - return 8 - case ElementType.f16: - return 16 - case ElementType.i32 | ElementType.f32: - return 32 - case _: - assert False, "unhandled case" - - def __str__(self) -> str: - return self.name - - @dataclass class ShapedType: shape: list[int] - element_type: ElementType + element_type: ir.IntegerType | ir.FloatType def rank(self) -> int: return len(self.shape) @property def bitwidth(self) -> int: - return self.element_type.bitwidth + return self.element_type.width def __str__(self) -> str: dim_to_str = lambda dim: str(dim) if dim != -1 else "?" @@ -91,11 +85,11 @@ def MNK(self) -> tuple[int, int, int]: @dataclass class MfmaIntrinsic: - output_type: ElementType + output_type: ir.IntegerType | ir.FloatType m: int n: int k: int - input_type: ElementType + input_type: ir.IntegerType | ir.FloatType def __str__(self) -> str: input = str(self.input_type).upper() @@ -104,19 +98,27 @@ def __str__(self) -> str: @staticmethod def mfma_f32_16x16x16_f16(): - return MfmaIntrinsic(ElementType.f32, 16, 16, 16, ElementType.f16) + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + return MfmaIntrinsic(f32, 16, 16, 16, f16) @staticmethod def mfma_f32_32x32x8_f16(): - return MfmaIntrinsic(ElementType.f32, 32, 32, 8, ElementType.f16) + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + return MfmaIntrinsic(f32, 32, 32, 8, f16) @staticmethod def mfma_i32_16x16x32_i8(): - return MfmaIntrinsic(ElementType.i32, 16, 16, 32, ElementType.i8) + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) + return MfmaIntrinsic(i32, 16, 16, 32, i8) @staticmethod def mfma_i32_32x32x16_i8(): - return MfmaIntrinsic(ElementType.i32, 32, 32, 16, ElementType.i8) + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) + return MfmaIntrinsic(i32, 32, 32, 16, i8) @staticmethod def all(): @@ -201,22 +203,6 @@ def get_pipeline_config(configuration: Configuration) -> str: return extra_config -class MlirRegex(Enum): - ssa_value = r"%[a-zA-Z0-9-_]+" - tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>" - - def __str__(self) -> str: - return self.value - - @staticmethod - def dps_ins_two_args() -> str: - return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)" - - @staticmethod - def dps_outs_one_arg() -> str: - return rf"outs\({MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type})\)" - - def read_input_mlir(filename: str) -> list[str]: with open(filename, "r") as f: return f.readlines() @@ -243,18 +229,6 @@ def from_problem_size(problem_size: ProblemSize): return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) -def parse_tensor_type(tensor_type: str) -> ShapedType: - shape_match = re.search(str(MlirRegex.tensor_type), tensor_type) - assert shape_match - - shape_str = shape_match.group(1) - dims_and_elem = shape_str.split("x") - dims = [int(x) for x in dims_and_elem[:-1]] - elem = dims_and_elem[-1] - str_to_elem_ty = {x.name: x for x in ElementType} - return 
ShapedType(dims, str_to_elem_ty[elem]) - - @dataclass class MLIRTransformation: """Transformation of MLIR context""" diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 858d593c9..891d703e2 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -11,28 +11,39 @@ import pytest from . import common +from typing import Generator -def test_get_shaped_type_element_bitwidth() -> None: - assert common.ShapedType([1024, 2048], common.ElementType.i8).bitwidth == 8 - assert common.ShapedType([2048], common.ElementType.i32).bitwidth == 32 - assert common.ShapedType([2048, 512, 384], common.ElementType.f8).bitwidth == 8 - assert common.ShapedType([1, 1], common.ElementType.f16).bitwidth == 16 +from iree.compiler import ir # type: ignore -def test_get_shaped_type_to_str() -> None: - assert str(common.ShapedType([1024, 2048], common.ElementType.i8)) == "1024x2048xi8" - assert str(common.ShapedType([1024], common.ElementType.f32)) == "1024xf32" - assert str(common.ShapedType([1, 2, 3], common.ElementType.f16)) == "1x2x3xf16" - assert str(common.ShapedType([-1, 2, 3], common.ElementType.f16)) == "?x2x3xf16" +@pytest.fixture +def tuner_ctx() -> Generator[common.TunerContext, None, None]: + from logging import Logger + from unittest.mock import MagicMock + with ir.Context() as ctx: + logger: Logger = MagicMock(spec=Logger) + yield common.TunerContext(ctx, logger) -def test_parse_tensor_type() -> None: - assert common.parse_tensor_type("tensor<1x2x3xf32>") == common.ShapedType( - [1, 2, 3], common.ElementType.f32 - ) - assert common.parse_tensor_type("tensor<123xi8>") == common.ShapedType( - [123], common.ElementType.i8 - ) + +@pytest.fixture +def mlir_ctx() -> Generator[ir.Context, None, None]: + with ir.Context() as ctx: + yield ctx + + +def test_get_shaped_type_element_bitwidth(tuner_ctx: common.TunerContext) -> None: + assert common.ShapedType([1024, 2048], tuner_ctx.type.i8).bitwidth == 8 + assert common.ShapedType([2048], tuner_ctx.type.i32).bitwidth == 32 + assert common.ShapedType([2048, 512, 384], tuner_ctx.type.f8E4M3FNUZ).bitwidth == 8 + assert common.ShapedType([1, 1], tuner_ctx.type.f16).bitwidth == 16 + + +def test_get_shaped_type_to_str(tuner_ctx: common.TunerContext) -> None: + assert str(common.ShapedType([1024, 2048], tuner_ctx.type.i8)) == "1024x2048xi8" + assert str(common.ShapedType([1024], tuner_ctx.type.f32)) == "1024xf32" + assert str(common.ShapedType([1, 2, 3], tuner_ctx.type.f16)) == "1x2x3xf16" + assert str(common.ShapedType([-1, 2, 3], tuner_ctx.type.f16)) == "?x2x3xf16" def test_gpu_pipeline_options() -> None: @@ -59,7 +70,7 @@ def test_gpu_pipeline_options() -> None: ) -def test_get_pipeline_config() -> None: +def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: config = common.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], @@ -85,18 +96,18 @@ def test_get_pipeline_config() -> None: ) -def test_mfma_intrinsic_to_str() -> None: +def test_mfma_intrinsic_to_str(mlir_ctx: ir.Context) -> None: assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16" assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8" -def test_get_compatible_mfma_intrinsics() -> None: +def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: assert common.get_compatible_mfma_intrinsics( common.ProblemSize( common.MatmulSize(2048, 1280, 1280), - common.ShapedType([2048, 1280], common.ElementType.f16), - common.ShapedType([1280, 1280], common.ElementType.f16), - 
diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py
index ac46d8edd..edd7ccc38 100644
--- a/tuner/tuner/dispatch_constraints.py
+++ b/tuner/tuner/dispatch_constraints.py
@@ -130,10 +130,10 @@ def generate_constraints(
 
 
 def generate_solutions(
-    ctx: TunerContext, problem_size: ProblemSize, num_subgrups: int
+    logger: logging.Logger, problem_size: ProblemSize, num_subgrups: int
 ) -> Iterator[Configuration]:
     M, N, K = problem_size.MNK
-    ctx.logger.info(f"{M},{N},{K}")
+    logger.info(f"{M},{N},{K}")
     m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k")
     subgroup_size = z3.Int("subgroup_size")
     intrinsic_mn = z3.Int("intrinsic_mn")
@@ -170,7 +170,7 @@ def generate_solutions(
         waves_per_eu,
     )
     solver.add(z3.simplify(z3.And(constraints)))
-    ctx.logger.debug(f"Initial constraints: {solver}")
+    logger.debug(f"Initial constraints: {solver}")
     i = 0
     while solver.check() == z3.sat:
         model = solver.model()
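Note on the dispatch_constraints.py hunk: as the changed lines show, the only uses of the TunerContext here were ctx.logger calls, so the function now takes the logging.Logger directly and the z3 search stays decoupled from MLIR state. The surrounding while solver.check() == z3.sat loop is the standard z3 model-enumeration idiom; a self-contained sketch reduced to a single variable (my simplification, not the tuner's actual constraint set):

    import z3  # type: ignore

    x = z3.Int("x")
    solver = z3.Solver()
    solver.add(x > 0, x < 4)

    while solver.check() == z3.sat:   # keep asking while satisfiable
        model = solver.model()
        print(model[x])               # one concrete solution
        solver.add(x != model[x])     # block it so the next check must differ

This prints 1, 2, and 3 in some solver-chosen order, then the loop terminates once the constraints become unsatisfiable.
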
diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py
index 55f3a8c43..7e1a5c55d 100644
--- a/tuner/tuner/dispatch_constraints_test.py
+++ b/tuner/tuner/dispatch_constraints_test.py
@@ -11,32 +11,41 @@
 import pytest
 import z3  # type: ignore
 
-from logging import Logger
-from unittest.mock import MagicMock
+from typing import Generator
+
+from iree.compiler import ir  # type: ignore
 
 from . import common
 from . import dispatch_constraints
 
 
-def test_generate_solutions() -> None:
+@pytest.fixture
+def tuner_ctx() -> Generator[common.TunerContext, None, None]:
+    from logging import Logger
+    from unittest.mock import MagicMock
+
+    with ir.Context() as ctx:
+        logger: Logger = MagicMock(spec=Logger)
+        yield common.TunerContext(ctx, logger)
+
+
+def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
     matmul_size = common.MatmulSize(2048, 3840, 1280)
-    lhs_type = common.ShapedType([2048, 1280], common.ElementType.f16)
-    rhs_type = common.ShapedType([3840, 1280], common.ElementType.f16)
-    res_type = common.ShapedType([2048, 3840], common.ElementType.f32)
+    lhs_type = common.ShapedType([2048, 1280], tuner_ctx.type.f16)
+    rhs_type = common.ShapedType([3840, 1280], tuner_ctx.type.f16)
+    res_type = common.ShapedType([2048, 3840], tuner_ctx.type.f32)
     problem_size = common.ProblemSize(
         matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
     )
-    logger: Logger = MagicMock(spec=Logger)
-    ctx = common.TunerContext(None, logger)
-    configs = dispatch_constraints.generate_solutions(ctx, problem_size, 4)
+    configs = dispatch_constraints.generate_solutions(tuner_ctx.logger, problem_size, 4)
     assert configs is not None
 
 
-def test_calculate_shared_memory_usage_in_bytes() -> None:
+def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext) -> None:
     matmul_size = common.MatmulSize(1024, 1024, 1024)
-    lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
-    rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
-    res_type = common.ShapedType([1024, 1024], common.ElementType.f32)
+    lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16)
+    rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16)
+    res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32)
     problem_size = common.ProblemSize(
         matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
     )
@@ -47,7 +56,7 @@
         == 147456
     )
 
-    lhs_type = common.ShapedType([1024, 1024], common.ElementType.i8)
+    lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i8)
     problem_size = common.ProblemSize(
         matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
     )
@@ -58,7 +67,7 @@
         == 81920
     )
 
-    rhs_type = common.ShapedType([1024, 1024], common.ElementType.i32)
+    rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i32)
     problem_size = common.ProblemSize(
         matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
     )
@@ -70,11 +79,11 @@
     )
 
 
-def test_generate_constraints_valid_input() -> None:
+def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> None:
     matmul_size = common.MatmulSize(1024, 1024, 1024)
-    lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
-    rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
-    res_type = common.ShapedType([1024, 1024], common.ElementType.f32)
+    lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16)
+    rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16)
+    res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32)
     problem_size = common.ProblemSize(
         matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
     )
@@ -115,12 +124,12 @@
     assert solver.check() == z3.sat
 
 
-def test_generate_constraints_invalid_input() -> None:
+def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> None:
     # Define input parameters that should lead to unsatisfiable constraints
     matmul_size = common.MatmulSize(1024, 1024, 1024)
-    lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
-    rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
-    res_type = common.ShapedType([1024, 1024], common.ElementType.f32)
+    lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16)
+    rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16)
+    res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32)
     problem_size = common.ProblemSize(
         matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
    )
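Note on the dispatch_constraints_test.py hunk: the expected constants are consistent with a simple LHS-tile-plus-RHS-tile shared-memory account, assuming (my reading of the call, not something the hunk states) that the 512, 64, 128 arguments are the m, n, k tile sizes:

    # f16 everywhere: 2 bytes per element.
    assert 512 * 128 * 2 + 64 * 128 * 2 == 147456
    # i8 LHS: the LHS tile drops to 1 byte per element, the RHS stays f16.
    assert 512 * 128 * 1 + 64 * 128 * 2 == 81920
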
diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py
index 670f8c3f7..c4b4b9ad5 100644
--- a/tuner/tuner/dispatch_parser.py
+++ b/tuner/tuner/dispatch_parser.py
@@ -14,6 +14,12 @@
 from .common import *
 
 
+def parse_tensor_type(tensor_type: str) -> ShapedType:
+    shaped_ty = ir.RankedTensorType(ir.Type.parse(tensor_type))
+    assert shaped_ty
+    return ShapedType(shaped_ty.shape, shaped_ty.element_type)
+
+
 def get_mmt_tile_sizes(configuration: Configuration):
     return configuration.tile_sizes
 
@@ -35,6 +41,22 @@ def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]:
     return [1] + configuration.tile_sizes
 
 
+class MlirRegex(Enum):
+    ssa_value = r"%[a-zA-Z0-9-_]+"
+    tensor_type = r"tensor<([^>]+)>"
+
+    def __str__(self) -> str:
+        return self.value
+
+    @staticmethod
+    def dps_ins_two_args() -> str:
+        return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P<LHS>{MlirRegex.tensor_type}), (?P<RHS>{MlirRegex.tensor_type})\)"
+
+    @staticmethod
+    def dps_outs_one_arg() -> str:
+        return rf"outs\({MlirRegex.ssa_value} : (?P<RES>{MlirRegex.tensor_type})\)"
+
+
 def parse_mlir(mlir_text: str, ctx: TunerContext) -> ir.Module:
     mlir_module = None
     try:
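Note on the dispatch_parser.py hunk: parse_tensor_type now defers to MLIR's own parser instead of hand-splitting the string, which is why MlirRegex.tensor_type can loosen to tensor<([^>]+)>. The regex only has to locate the span; validation happens in ir.Type.parse, so dynamic (?) dimensions and element types like f8E4M3FNUZ, which the old (([0-9]+x)+((f|i)[0-9]+)) pattern rejected, now parse for free. A sketch of the new path (illustrative, and like everything else here it needs a live context):

    from iree.compiler import ir  # type: ignore

    with ir.Context():
        ty = ir.RankedTensorType(ir.Type.parse("tensor<2x3xf16>"))
        assert list(ty.shape) == [2, 3]
        assert str(ty.element_type) == "f16"
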
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
index bcdee240c..d3a99806f 100644
--- a/tuner/tuner/dispatch_parser_test.py
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -10,8 +10,7 @@
 import pytest
 
-from logging import Logger
-from unittest.mock import MagicMock
+from typing import Generator
 
 from iree.compiler import ir  # type: ignore
 from iree.compiler.dialects import func  # type: ignore
@@ -20,7 +19,26 @@
 from . import dispatch_parser
 
 
-def test_get_mmt_tile_sizes() -> None:
+@pytest.fixture
+def tuner_ctx() -> Generator[common.TunerContext, None, None]:
+    from logging import Logger
+    from unittest.mock import MagicMock
+
+    with ir.Context() as ctx:
+        logger: Logger = MagicMock(spec=Logger)
+        yield common.TunerContext(ctx, logger)
+
+
+def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:
+    assert dispatch_parser.parse_tensor_type("tensor<1x2x3xf32>") == common.ShapedType(
+        [1, 2, 3], tuner_ctx.type.f32
+    )
+    assert dispatch_parser.parse_tensor_type("tensor<123xi8>") == common.ShapedType(
+        [123], tuner_ctx.type.i8
+    )
+
+
+def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
     config = dispatch_parser.Configuration(
         subgroup_size=0,
         workgroup_size=[],
@@ -34,7 +52,7 @@
     assert dispatch_parser.get_mmt_tile_sizes(config) == [128, 320, 32]
 
 
-def test_get_conv_tile_sizes() -> None:
+def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
     config = dispatch_parser.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
@@ -56,7 +74,7 @@
     ]
 
 
-def test_get_contract_tile_sizes() -> None:
+def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None:
     config = dispatch_parser.Configuration(
         subgroup_size=32,
         workgroup_size=[16, 16, 1],
@@ -77,7 +95,7 @@
     ]
 
 
-def test_get_shapes_mmt() -> None:
+def test_get_shapes_mmt(tuner_ctx: common.TunerContext) -> None:
     template = [
         r"%18 = tensor.empty() : tensor<2048x1280xf32>",
         r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>",
@@ -86,14 +104,14 @@
     ]
     assert dispatch_parser.MmtParser().get_shapes(template) == common.ProblemSize(
         common.MatmulSize(2048, 1280, 1280),
-        common.ShapedType([2048, 1280], common.ElementType.f16),
-        common.ShapedType([1280, 1280], common.ElementType.f16),
-        common.ShapedType([2048, 1280], common.ElementType.f32),
+        common.ShapedType([2048, 1280], tuner_ctx.type.f16),
+        common.ShapedType([1280, 1280], tuner_ctx.type.f16),
+        common.ShapedType([2048, 1280], tuner_ctx.type.f32),
         dispatch_parser.DispatchKind.mmt,
     )
 
 
-def test_get_shapes_conv() -> None:
+def test_get_shapes_conv(tuner_ctx: common.TunerContext) -> None:
     template = [
         r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>",
         r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>",
@@ -101,14 +119,14 @@
     ]
     assert dispatch_parser.ConvParser().get_shapes(template) == common.ProblemSize(
         common.MatmulSize(32, 256, 11520),
-        common.ShapedType([1, 3, 34, 1280], common.ElementType.f16),
-        common.ShapedType([3, 3, 1280, 256], common.ElementType.f16),
-        common.ShapedType([1, 1, 32, 256], common.ElementType.f32),
+        common.ShapedType([1, 3, 34, 1280], tuner_ctx.type.f16),
+        common.ShapedType([3, 3, 1280, 256], tuner_ctx.type.f16),
+        common.ShapedType([1, 1, 32, 256], tuner_ctx.type.f32),
         dispatch_parser.DispatchKind.conv,
     )
 
 
-def test_get_shapes_contract() -> None:
+def test_get_shapes_contract(tuner_ctx: common.TunerContext) -> None:
     template = [
         r"%18 = tensor.empty() : tensor<2048x1280xf32>",
         r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>",
@@ -119,14 +137,14 @@
         template
     ) == common.ProblemSize(
         common.MatmulSize(2048, 1280, 1280),
-        common.ShapedType([2048, 1280], common.ElementType.f16),
-        common.ShapedType([1280, 1280], common.ElementType.f16),
-        common.ShapedType([2048, 1280], common.ElementType.f32),
+        common.ShapedType([2048, 1280], tuner_ctx.type.f16),
+        common.ShapedType([1280, 1280], tuner_ctx.type.f16),
+        common.ShapedType([2048, 1280], tuner_ctx.type.f32),
         dispatch_parser.DispatchKind.contraction,
     )
 
 
-def test_get_shapes_batch_matmul() -> None:
+def test_get_shapes_batch_matmul(tuner_ctx: common.TunerContext) -> None:
     template = [
         "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
         "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
@@ -136,14 +154,14 @@
         template
     ) == common.ProblemSize(
         common.MatmulSize(32, 32, 1024, 1),
-        common.ShapedType([1, 32, 1024], common.ElementType.f32),
-        common.ShapedType([1, 1024, 32], common.ElementType.f32),
-        common.ShapedType([1, 32, 32], common.ElementType.f32),
+        common.ShapedType([1, 32, 1024], tuner_ctx.type.f32),
+        common.ShapedType([1, 1024, 32], tuner_ctx.type.f32),
+        common.ShapedType([1, 32, 32], tuner_ctx.type.f32),
         dispatch_parser.DispatchKind.batch_matmul,
     )
 
 
-def test_get_shapes_batch_mmt() -> None:
+def test_get_shapes_batch_mmt(tuner_ctx: common.TunerContext) -> None:
     template = [
         r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>",
         r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
@@ -151,26 +169,23 @@
     ]
     assert dispatch_parser.BatchMmtParser().get_shapes(template) == common.ProblemSize(
         common.MatmulSize(4096, 640, 640, 2),
-        common.ShapedType([2, 4096, 640], common.ElementType.i8),
-        common.ShapedType([2, 640, 640], common.ElementType.i8),
-        common.ShapedType([2, 4096, 640], common.ElementType.i32),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.i8),
+        common.ShapedType([2, 640, 640], tuner_ctx.type.i8),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.i32),
         dispatch_parser.DispatchKind.batch_mmt,
     )
 
 
-def test_parse_mlir() -> None:
-    with ir.Context() as ctx:
-        mlir_str = r"""
-        builtin.module {
-            func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
-                %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
-                return %0 : tensor<4xf32>
-            }
-        }
-        """
-        logger: Logger = MagicMock(spec=Logger)
-        tuner_context = common.TunerContext(ctx, logger)
-        mlir_module = dispatch_parser.parse_mlir(mlir_str, tuner_context)
-        assert mlir_module is not None
-        assert isinstance(mlir_module, ir.Module)
-        assert isinstance(mlir_module.body.operations[0], func.FuncOp)
+def test_parse_mlir(tuner_ctx: common.TunerContext) -> None:
+    mlir_str = r"""
+    builtin.module {
+        func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+            %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+            return %0 : tensor<4xf32>
+        }
+    }
+"""
+    mlir_module = dispatch_parser.parse_mlir(mlir_str, tuner_ctx)
+    assert mlir_module is not None
+    assert isinstance(mlir_module, ir.Module)
+    assert isinstance(mlir_module.body.operations[0], func.FuncOp)
diff --git a/tuner/version.json b/tuner/version.json
index f09f61d2a..85afb41ed 100644
--- a/tuner/version.json
+++ b/tuner/version.json
@@ -1,3 +1,3 @@
 {
-  "package-version": "2.9.2.dev"
+  "package-version": "3.0.0.dev"
 }