diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci-libshortfin.yml
similarity index 54%
rename from .github/workflows/ci_linux_x64-libshortfin.yml
rename to .github/workflows/ci-libshortfin.yml
index afeca11a6..33a6df72b 100644
--- a/.github/workflows/ci_linux_x64-libshortfin.yml
+++ b/.github/workflows/ci-libshortfin.yml
@@ -10,13 +10,13 @@ on:
workflow_dispatch:
pull_request:
paths:
- - '.github/workflows/ci_linux_x64-libshortfin.yml'
+ - '.github/workflows/ci-libshortfin.yml'
- 'shortfin/**'
push:
branches:
- main
paths:
- - '.github/workflows/ci_linux_x64-libshortfin.yml'
+ - '.github/workflows/ci-libshortfin.yml'
- 'shortfin/**'
permissions:
@@ -36,17 +36,55 @@ env:
jobs:
build-and-test:
- name: Build and test
- runs-on: ubuntu-24.04
+ name: "Unit tests :: ${{ matrix.name }} :: ${{ matrix.python-version }}"
+ runs-on: ${{ matrix.runs-on }}
+ defaults:
+ run:
+ shell: bash
strategy:
+ fail-fast: false
matrix:
+ name: ["Ubuntu (Clang)(full)", "Ubuntu (Clang)(host-only)", "Ubuntu (GCC)", "Windows (MSVC)"]
python-version: ["3.10", "3.11", "3.12"]
+ include:
+ - name: Ubuntu (Clang)(full)
+ runs-on: ubuntu-24.04
+ cmake-options:
+ -DCMAKE_C_COMPILER=clang-18 -DCMAKE_CXX_COMPILER=clang++-18 -DCMAKE_LINKER_TYPE=LLD
+ additional-packages: clang lld
+ - name: Ubuntu (Clang)(host-only)
+ runs-on: ubuntu-24.04
+ # In this configuration, also build static+dynamic in order to verify
+ # that path structurally works.
+ cmake-options:
+ -DCMAKE_C_COMPILER=clang-18 -DCMAKE_CXX_COMPILER=clang++-18 -DCMAKE_LINKER_TYPE=LLD -DSHORTFIN_HAVE_AMDGPU=OFF -DSHORTFIN_BUILD_STATIC=ON -DSHORTFIN_BUILD_DYNAMIC=ON
+ additional-packages: clang lld
+ - name: Ubuntu (GCC)
+ runs-on: ubuntu-24.04
+ - name: Windows (MSVC)
+ runs-on: windows-2022
+ exclude:
+ # Only test Python 3.12 with GCC
+ - name: Ubuntu (GCC)
+ python-version: "3.10"
+ - name: Ubuntu (GCC)
+ python-version: "3.11"
+      # TODO: Include additional Python versions for Windows after the build is fixed
+ - name: Windows (MSVC)
+ python-version: "3.10"
+ - name: Windows (MSVC)
+ python-version: "3.11"
steps:
- - name: Install dependencies
+ - name: (Linux) Install dependencies
+ if: "runner.os == 'Linux'"
run: |
sudo apt update
- sudo apt install clang lld cmake ninja-build
+ sudo apt install cmake ninja-build ${{matrix.additional-packages}}
+
+ - name: (Windows) Configure MSVC
+ if: "runner.os == 'Windows'"
+ uses: ilammy/msvc-dev-cmd@0b201ec74fa43914dc39ae48a89fd1d8cb592756 # v1.13.0
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -70,56 +108,42 @@ jobs:
git submodule update --init --depth 1 -- third_party/googletest
git submodule update --init --depth 1 -- third_party/hip-build-deps/
- - name: Setup Python ${{ matrix.python-version }}
+ - name: "Setup Python ${{ matrix.python-version }}"
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
+ cache-dependency-path: |
+          shortfin/requirements-tests.txt
+          shortfin/requirements-iree-compiler.txt
- name: Install Python packages
- # TODO: Switch to `pip install -r requirements.txt -e shortfin/`.
working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
pip install -r requirements-tests.txt
pip install -r requirements-iree-compiler.txt
pip freeze
- - name: Build shortfin (full)
+ - name: Build shortfin
working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
mkdir build
cmake -GNinja \
-S. \
-Bbuild \
- -DCMAKE_C_COMPILER=clang-18 \
- -DCMAKE_CXX_COMPILER=clang++-18 \
- -DCMAKE_LINKER_TYPE=LLD \
- -DSHORTFIN_BUNDLE_DEPS=ON \
-DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
- -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON
+ -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \
+ ${{matrix.cmake-options}}
cmake --build build --target all
- pip install -v -e build/
- - name: Test shortfin (full)
+ - name: pip install shortfin
+ if: ${{ matrix.name != 'Ubuntu (Clang)(host-only)'}}
working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
- ctest --timeout 30 --output-on-failure --test-dir build
- pytest -s
+ pip install -v -e build/
- - name: Build shortfin (host-only)
+ - name: Test shortfin
+ if: ${{ matrix.name != 'Ubuntu (Clang)(host-only)'}}
working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
- mkdir build-host-only
- # In this configuration, also build static+dynamic in order to verify
- # that path structurally works.
- cmake -GNinja \
- -S. \
- -Bbuild-host-only \
- -DCMAKE_C_COMPILER=clang-18 \
- -DCMAKE_CXX_COMPILER=clang++-18 \
- -DCMAKE_LINKER_TYPE=LLD \
- -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
- -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \
- -DSHORTFIN_HAVE_AMDGPU=OFF \
- -DSHORTFIN_BUILD_STATIC=ON \
- -DSHORTFIN_BUILD_DYNAMIC=ON
- cmake --build build-host-only --target all
+ ctest --timeout 30 --output-on-failure --test-dir build
+ pytest -s
diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml
index 34e91cebb..644066094 100644
--- a/.github/workflows/ci-llama-large-tests.yaml
+++ b/.github/workflows/ci-llama-large-tests.yaml
@@ -76,14 +76,14 @@ jobs:
iree-base-runtime
- name: Run llama tests
- run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --html=out/index.html
+ run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --html=out/llm/llama/benchmark/index.html
- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
with:
github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
- publish_dir: ./out/llm/llama/benchmarks
- destination_dir: ./llm/llama/benchmarks
+ publish_dir: ./out/llm/llama/benchmark
+ destination_dir: ./llm/llama/benchmark
keep_files: true
- name: Upload llama executable files
diff --git a/.github/workflows/ci-sglang-benchmark.yml b/.github/workflows/ci-sglang-benchmark.yml
index 504e7e5e3..f44e2772b 100644
--- a/.github/workflows/ci-sglang-benchmark.yml
+++ b/.github/workflows/ci-sglang-benchmark.yml
@@ -28,7 +28,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
- runs-on: llama-mi300x-3
+ runs-on: mi300x-4
defaults:
run:
shell: bash
@@ -78,7 +78,7 @@ jobs:
run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"
- name: Launch Shortfin Server
- run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html
+ run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html
- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
diff --git a/.github/workflows/ci-sglang-integration-tests.yml b/.github/workflows/ci-sglang-integration-tests.yml
index 1c382617d..c61756d78 100644
--- a/.github/workflows/ci-sglang-integration-tests.yml
+++ b/.github/workflows/ci-sglang-integration-tests.yml
@@ -29,7 +29,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
- runs-on: llama-mi300x-3
+ runs-on: mi300x-4
defaults:
run:
shell: bash
diff --git a/.github/workflows/ci-shark-ai.yml b/.github/workflows/ci-shark-ai.yml
index bf8007e65..fc85a76a7 100644
--- a/.github/workflows/ci-shark-ai.yml
+++ b/.github/workflows/ci-shark-ai.yml
@@ -49,7 +49,7 @@ jobs:
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
- key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
+ key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','shortfin/requirements*.txt','sharktank/requirements*.txt') }}
- name: Install pip deps
run: |
diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml
index 54aa3c763..0164b6cdc 100644
--- a/.github/workflows/ci_eval.yaml
+++ b/.github/workflows/ci_eval.yaml
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-name: CI - Perplexity
+name: CI - sharktank perplexity
on:
workflow_dispatch:
@@ -21,10 +21,10 @@ concurrency:
cancel-in-progress: true
jobs:
- test_perplexity_vmfb:
+ test_perplexity_iree:
if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
timeout-minutes: 1000
- name: "IREE/vmfb"
+ name: "Perplexity-IREE"
strategy:
matrix:
version: [3.11]
@@ -74,13 +74,21 @@ jobs:
iree-base-compiler \
iree-base-runtime
- - name: Run perplexity test with vmfb
- run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
+ - name: Run perplexity test with IREE
+ run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html
+
+ - name: Deploy to GitHub Pages
+ uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
+ with:
+ github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
+ publish_dir: ./out/llm/llama/perplexity/iree_perplexity
+ destination_dir: ./llm/llama/perplexity/iree_perplexity
+ keep_files: true
test_perplexity_torch:
if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
timeout-minutes: 1000
- name: "Torch/eager mode"
+ name: "Perplexity-Torch"
strategy:
matrix:
version: [3.11]
@@ -123,5 +131,13 @@ jobs:
pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
- - name: Run perplexity test in eager mode
- run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
+ - name: Run perplexity test with Torch
+ run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html
+
+ - name: Deploy to GitHub Pages
+ uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
+ with:
+ github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
+ publish_dir: ./out/llm/llama/perplexity/torch_perplexity
+ destination_dir: ./llm/llama/perplexity/torch_perplexity
+ keep_files: true
diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml
new file mode 100644
index 000000000..4622f5c57
--- /dev/null
+++ b/.github/workflows/ci_eval_short.yaml
@@ -0,0 +1,77 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: CI - sharktank perplexity short
+
+on:
+ workflow_dispatch:
+ pull_request:
+ push:
+ branches:
+ - main
+
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
+jobs:
+ test_perplexity_iree:
+ name: "Llama3.1 8B FP16"
+ strategy:
+ matrix:
+ version: [3.11]
+ runs-on: [llama-mi300x-3]
+ fail-fast: false
+ runs-on: ${{matrix.runs-on}}
+ defaults:
+ run:
+ shell: bash
+ env:
+ PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+ SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }}
+ steps:
+ - name: "Setting up Python"
+ id: setup_python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{matrix.version}}
+
+ - name: "Checkout Code"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Cache Pip Packages
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+ id: cache-pip
+ with:
+ path: ${{ env.PIP_CACHE_DIR }}
+ key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}
+
+ - name: Install sharktank deps
+ run: |
+ python -m pip install --no-compile --upgrade pip
+ # Note: We install in three steps in order to satisfy requirements
+ # from non default locations first. Installing the PyTorch CPU
+ # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+ pip install --no-compile -r pytorch-cpu-requirements.txt
+ pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
+
+      # Install the latest iree-turbine.
+ pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+
+ # Try with the latest IREE nightly releases, not what iree-turbine pins.
+ # We could also pin to a known working or stable version.
+ # This should eventually stabilize. Do the best we can for now.
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler \
+ iree-base-runtime
+
+    - name: Run perplexity test with IREE
+ run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
index 0e0e1db2a..550366e1b 100644
--- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
+++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
@@ -98,6 +98,6 @@ jobs:
- name: Run shortfin Python tests (full)
working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
- pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py --ignore=tests/apps/sd
+ pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/sd
# TODO: Enable further tests and switch to
# pytest -s
diff --git a/.github/workflows/ci_windows_x64-libshortfin.yml b/.github/workflows/ci_windows_x64-libshortfin.yml
deleted file mode 100644
index 00873c432..000000000
--- a/.github/workflows/ci_windows_x64-libshortfin.yml
+++ /dev/null
@@ -1,98 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-name: CI - shortfin - Windows
-
-on:
- workflow_dispatch:
- pull_request:
- paths:
- - '.github/workflows/ci_windows_x64-libshortfin.yml'
- - 'shortfin/**'
- push:
- branches:
- - main
- paths:
- - '.github/workflows/ci_windows_x64-libshortfin.yml'
- - 'shortfin/**'
-
-permissions:
- contents: read
-
-concurrency:
- # A PR number if a pull request and otherwise the commit hash. This cancels
- # queued and in-progress runs for the same PR (presubmit) or commit
- # (postsubmit). The workflow name is prepended to avoid conflicts between
- # different workflows.
- group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
- cancel-in-progress: true
-
-env:
- IREE_REPO_DIR: ${{ github.workspace }}/iree
- LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/
-
-jobs:
- build-and-test:
- name: Build and test
- runs-on: windows-2022
-
- steps:
- - name: Configure MSVC
- uses: ilammy/msvc-dev-cmd@0b201ec74fa43914dc39ae48a89fd1d8cb592756 # v1.13.0
-
- - name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- submodules: false
-
- - name: Checkout IREE repo
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- repository: iree-org/iree
- path: ${{ env.IREE_REPO_DIR }}
- submodules: false
- ref: iree-3.0.0rc20241118
-
- - name: Initalize IREE submodules
- working-directory: ${{ env.IREE_REPO_DIR }}
- run : |
- git submodule update --init --depth 1 -- third_party/benchmark
- git submodule update --init --depth 1 -- third_party/cpuinfo/
- git submodule update --init --depth 1 -- third_party/flatcc
- git submodule update --init --depth 1 -- third_party/googletest
- git submodule update --init --depth 1 -- third_party/hip-build-deps/
-
- - name: Setup Python
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
- with:
- python-version: "3.12"
- cache: "pip"
- - name: Install Python packages
- working-directory: ${{ env.LIBSHORTFIN_DIR }}
- run: |
- pip install -r requirements-tests.txt
- pip install -r requirements-iree-compiler.txt
- pip freeze
-
- - name: Build shortfin (full)
- working-directory: ${{ env.LIBSHORTFIN_DIR }}
- shell: bash
- run: |
- mkdir build
- cmake -GNinja \
- -S. \
- -Bbuild \
- -DSHORTFIN_BUNDLE_DEPS=ON \
- -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
- -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON
- cmake --build build --target all
- pip install -v -e build/
-
- - name: Test shortfin (full)
- working-directory: ${{ env.LIBSHORTFIN_DIR }}
- run: |
- ctest --timeout 30 --output-on-failure --test-dir build
- pytest -s
diff --git a/README.md b/README.md
index 44f1e6113..ae3eac423 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ If you're looking to use SHARK check out our [User Guide](docs/user_guide.md). F
-[](https://badge.fury.io/py/shortfin) [](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml?query=event%3Apush)
+[](https://badge.fury.io/py/shortfin) [](https://github.com/nod-ai/shark-ai/actions/workflows/ci-libshortfin.yml?query=event%3Apush)
The shortfin sub-project is SHARK's high performance inference library and
serving engine.
diff --git a/app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py
new file mode 100644
index 000000000..a85ba359d
--- /dev/null
+++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/app_tests/benchmark_tests/llm/conftest.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
similarity index 74%
rename from app_tests/benchmark_tests/llm/conftest.py
rename to app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
index cc354b7eb..1e1c64b24 100644
--- a/app_tests/benchmark_tests/llm/conftest.py
+++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
@@ -9,15 +9,22 @@
import pytest
import sys
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
-from integration_tests.llm.utils import compile_model, export_paged_llm_v1
+sys.path.append(
+ os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
+)
+from integration_tests.llm.utils import (
+ compile_model,
+ export_paged_llm_v1,
+ download_with_hf_datasets,
+)
@pytest.fixture(scope="module")
def pre_process_model(request, tmp_path_factory):
tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")
- model_path = request.param["model_path"]
+ model_name = request.param["model_name"]
+ model_param_file_name = request.param["model_param_file_name"]
settings = request.param["settings"]
batch_sizes = request.param["batch_sizes"]
@@ -25,6 +32,9 @@ def pre_process_model(request, tmp_path_factory):
config_path = tmp_dir / "config.json"
vmfb_path = tmp_dir / "model.vmfb"
+ model_path = tmp_dir / model_param_file_name
+ download_with_hf_datasets(tmp_dir, model_name)
+
export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)
config = {
diff --git a/app_tests/benchmark_tests/llm/sglang_benchmark_test.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py
similarity index 76%
rename from app_tests/benchmark_tests/llm/sglang_benchmark_test.py
rename to app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py
index 0de775795..b66904570 100644
--- a/app_tests/benchmark_tests/llm/sglang_benchmark_test.py
+++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py
@@ -4,7 +4,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-import json
import logging
import multiprocessing
import os
@@ -16,14 +15,14 @@
pytest.importorskip("sglang")
from sglang import bench_serving
-from utils import SGLangBenchmarkArgs
+from .utils import SGLangBenchmarkArgs, log_jsonl_result
from integration_tests.llm.utils import (
find_available_port,
start_llm_server,
)
-logger = logging.getLogger("__name__")
+logger = logging.getLogger(__name__)
device_settings = {
"device_flags": [
@@ -33,30 +32,21 @@
"device": "hip",
}
-# TODO: Download on demand instead of assuming files exist at this path
-MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa")
-TOKENIZER_DIR = Path("/data/llama3.1/8b/")
-
-
-def log_jsonl_result(file_path):
- with open(file_path, "r") as file:
- json_string = file.readline().strip()
-
- json_data = json.loads(json_string)
- for key, val in json_data.items():
- logger.info(f"{key.upper()}: {val}")
-
@pytest.mark.parametrize(
- "request_rate",
- [1, 2, 4, 8, 16, 32],
+ "request_rate,model_param_file_name",
+ [
+ (req_rate, "meta-llama-3.1-8b-instruct.f16.gguf")
+ for req_rate in [1, 2, 4, 8, 16, 32]
+ ],
)
@pytest.mark.parametrize(
"pre_process_model",
[
(
{
- "model_path": MODEL_PATH,
+ "model_name": "llama3_8B_fp16",
+ "model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf",
"settings": device_settings,
"batch_sizes": [1, 4],
}
@@ -64,7 +54,9 @@ def log_jsonl_result(file_path):
],
indirect=True,
)
-def test_sglang_benchmark_server(request_rate, pre_process_model):
+def test_sglang_benchmark_server(
+ request_rate, model_param_file_name, pre_process_model
+):
# TODO: Remove when multi-device is fixed
os.environ["ROCR_VISIBLE_DEVICES"] = "1"
@@ -72,7 +64,8 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
config_path = tmp_dir / "config.json"
vmfb_path = tmp_dir / "model.vmfb"
- tokenizer_path = TOKENIZER_DIR / "tokenizer.json"
+ tokenizer_path = tmp_dir / "tokenizer.json"
+ model_path = tmp_dir / model_param_file_name
# Start shortfin llm server
port = find_available_port()
@@ -81,7 +74,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
tokenizer_path,
config_path,
vmfb_path,
- MODEL_PATH,
+ model_path,
device_settings,
timeout=30,
)
@@ -91,7 +84,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
backend="shortfin",
num_prompt=10,
base_url=f"http://localhost:{port}",
- tokenizer=TOKENIZER_DIR,
+ tokenizer=tmp_dir,
request_rate=request_rate,
)
output_file = (
@@ -116,7 +109,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
logger.info("======== RESULTS ========")
log_jsonl_result(benchmark_args.output_file)
except Exception as e:
- logger.info(e)
+ logger.error(e)
server_process.terminate()
server_process.wait()
diff --git a/app_tests/benchmark_tests/llm/utils.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py
similarity index 84%
rename from app_tests/benchmark_tests/llm/utils.py
rename to app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py
index 55b01da04..47cea4d76 100644
--- a/app_tests/benchmark_tests/llm/utils.py
+++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py
@@ -6,8 +6,12 @@
from argparse import Namespace
from dataclasses import dataclass
+import json
+import logging
from pathlib import Path
+logger = logging.getLogger(__name__)
+
@dataclass
class SGLangBenchmarkArgs:
@@ -54,3 +58,12 @@ def __repr__(self):
f"Tokenizer: {self.tokenizer}\n"
f"Request Rate: {self.request_rate}"
)
+
+
+def log_jsonl_result(file_path):
+ with open(file_path, "r") as file:
+ json_string = file.readline().strip()
+
+ json_data = json.loads(json_string)
+ for key, val in json_data.items():
+ logger.info(f"{key.upper()}: {val}")
diff --git a/app_tests/integration_tests/llm/utils.py b/app_tests/integration_tests/llm/utils.py
index 05712039e..80b5b3c09 100644
--- a/app_tests/integration_tests/llm/utils.py
+++ b/app_tests/integration_tests/llm/utils.py
@@ -15,7 +15,7 @@
import requests
from transformers import AutoTokenizer
-logger = logging.getLogger("__name__")
+logger = logging.getLogger(__name__)
class AccuracyValidationException(RuntimeError):
diff --git a/docs/amdgpu_kernel_optimization_guide.md b/docs/amdgpu_kernel_optimization_guide.md
index 09c5b59f9..91b7f1385 100644
--- a/docs/amdgpu_kernel_optimization_guide.md
+++ b/docs/amdgpu_kernel_optimization_guide.md
@@ -4,7 +4,7 @@ Author: Jakub Kuderski @kuhar
Date: 2024-06-24
-Last Update: 2024-08-22
+Last Update: 2024-11-22
## Introduction
@@ -293,3 +293,124 @@ forms a *clause* that translates to a single data fabric transaction.
> [!TIP]
> For allocations of 4 GB or less, you can implement predicated loads using the
> `buffer` instructions.
+
+## Data-Parallel Primitives and Warp-level Reduction
+
+For cross-lane data sharing, the most straightforward way is LDS: some lanes
+write data to LDS locations and other lanes read that data back. In addition,
+several instructions can be used to share data across lanes within a
+wavefront/warp.
+
+Here's a brief introduction to these instructions. Please check out [this
+blog](https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/) for
+details.
+
+### ds_permute/ds_bpermute
+
+The `ds_permute`/`ds_bpermute` instructions use LDS hardware for data sharing
+but don't actually write to an LDS location. They still need an `s_waitcnt`
+instruction to determine when data is returned to the `dest` VGPR.
+
+Example:
+```nasm
+ds_bpermute_b32 dest, addr, src [offset:addr_offset]
+```
+
+### ds_swizzle
+
+Compared to `ds_bpermute`, the `ds_swizzle` instruction doesn't require an
+additional VGPR for the offset since it's encoded in the instruction itself.
+
+`ds_swizzle` is likely to require fewer address-generation instructions than
+`ds_bpermute`.
+
+The cons are:
+1. It supports only a limited set of patterns.
+2. Similar to `ds_bpermute`, an `s_waitcnt` is required to wait for the `dest` VGPR.
+
+Example:
+```nasm
+ds_swizzle_b32 dest, src offset:ds_pattern
+```
+
+### Data-Parallel Primitives, DPP
+
+DPP is a 32-bit instruction modifier appended to normal VALU instructions.
+It allows VALU instructions to access data in neighboring lanes directly
+without going through LDS hardware, so `s_waitcnt` instructions are
+**not required**.
+
+Unfortunately, like `ds_swizzle`, it only supports limited patterns, and some
+instructions can't be modified by DPP.
+
+Example:
+```nasm
+; Normal VALU instruction.
+v_add_f32
+
+; Instruction modified by DPP.
+v_add_f32_dpp
+```
+
+It's worth mentioning that DPP has different names and syntaxes on different
+architectures:
+* CDNA: DPP
+* RDNA: DPP8/DPP16
+
+For details, please check the [MI300 ISA Reference
+Guide](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf)
+and the [RDNA3 ISA Reference
+Guide](https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna3-shader-instruction-set-architecture-feb-2023_0.pdf).
+
+### How to use them in MLIR
+
+Each instruction has a corresponding op in MLIR (except for `ds_permute`, which
+is not implemented at the time of writing):
+* `ds_bpermute`: `rocdl.ds_bpermute`
+* `ds_swizzle`: `rocdl.ds_swizzle`
+* DPP: `rocdl.update.dpp` and `amdgpu.dpp` (a thin wrapper around
+  `rocdl.update.dpp` with a more comprehensive user interface, e.g., replacing
+  magic numbers with enums)
+
+The first two are straightforward, while DPP works differently.
+
+Since DPP is an instruction modifier rather than an instruction itself, there is
+a tremendous number of combinations of VALU instructions and DPP. To handle
+that, `rocdl.update.dpp` and `amdgpu.dpp` are designed as wrappers around the
+`v_mov_b32_dpp` instruction; it is left to the LLVM compiler to fuse it with
+the subsequent VALU instruction **on a best-effort basis**.
+
+For example, `v_mov_b32_dpp` + `v_add_f32_e32` might be fused into `v_add_f32_dpp`.
+
+There are plenty of constraints that prevent an instruction from being fused. For
+example, if either the `bank_mask` or the `row_mask` is not `0xf`, it can't be
+fused. You can check the
+[GCNDPPCombine::combineDPPMov](https://github.com/llvm/llvm-project/blob/ab51eccf88f5321e7c60591c5546b254b6afab99/llvm/lib/Target/AMDGPU/GCNDPPCombine.cpp#L522)
+function to see how it works.
+
+### Comparison
+
+To summarize, there's no free lunch: an instruction's expressivity comes at the
+expense of performance.
+
+The relative performance of cross-lane instructions is as follows:
+
+DPP > `ds_swizzle` >= `ds_permute` > `ds_bpermute`
+
+while the generality ranking is the reverse:
+
+DPP < `ds_swizzle` < `ds_permute` < `ds_bpermute`
+
+The table below presents approximate instruction latencies, collected
+experimentally on a fused softmax kernel with
+[rocprofv2](https://github.com/ROCm/rocprofiler?tab=readme-ov-file#plugin-support)
+on an MI300 GPU:
+
+| Instructions | MLIR Op | Hardware | latency/#cycles |
+| ---------------------- | ---------------------------- | ------------ | --------------- |
+| ds_permute/ds_bpermute | rocdl.ds_bpermute | LDS hardware | ~50* |
+| ds_swizzle | rocdl.ds_swizzle | LDS hardware | ~50* |
+| DPP | rocdl.update.dpp, amdgpu.dpp | VALU | 4~12 |
+
+*: For `ds_permute`/`ds_bpermute` and `ds_swizzle`, the latency includes the
+instruction itself and its corresponding `s_waitcnt` instruction.
diff --git a/docs/developer_guide.md b/docs/developer_guide.md
index 832466688..73aee61f7 100644
--- a/docs/developer_guide.md
+++ b/docs/developer_guide.md
@@ -3,6 +3,57 @@
Each sub-project has its own developer guide. If you would like to work across
projects, these instructions should help you get started:
+
+### Install Dependencies
+
+Install the shortfin dependencies:
+```bash
+sudo apt update && sudo apt install -y clang lld
+```
+
+### Prepare your python environment
+
+Install:
+
+```bash
+sudo apt install python-is-python3 python3-venv python3-dev
+```
+
+
+
+Alternatively, use `pyenv` to manage a separate Python installation for more control over its version:
+
+
+The following instructions are taken from pyenv's guide here: https://github.com/pyenv/pyenv?tab=readme-ov-file#a-getting-pyenv
+
+First, install pyenv and its dependencies.
+
+```bash
+sudo apt update; sudo apt install build-essential libssl-dev zlib1g-dev \
+libbz2-dev libreadline-dev libsqlite3-dev curl git \
+libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev
+curl https://pyenv.run | bash
+```
+
+Then, make pyenv available by adding the below to your `~/.bashrc`:
+
+```bash
+export PYENV_ROOT="$HOME/.pyenv"
+command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"
+eval "$(pyenv init -)"
+```
+
+Finally, install a pyenv-managed version of Python:
+
+```bash
+pyenv install 3.12 # or whichever python version you'd like
+pyenv local 3.12
+```
+
+Now, your python, pip, and venv should be managed by pyenv instead.
+
+
+
### Setup a venv
We recommend setting up a Python
@@ -54,8 +105,10 @@ See also: [nightly_releases.md](nightly_releases.md).
### Running tests
```bash
+pip install -r shortfin/requirements-tests.txt
pytest sharktank
pytest shortfin
+pytest app_tests/integration_tests
```
### Optional: pre-commits and developer settings
diff --git a/docs/model_cookbook.md b/docs/model_cookbook.md
index ddc0cb3bb..03e625b96 100644
--- a/docs/model_cookbook.md
+++ b/docs/model_cookbook.md
@@ -210,6 +210,8 @@ iree-compile /tmp/open_llama_3b_v2/open-llama-3b-v2-f16.mlir \
-o /tmp/open_llama_3b_v2/open-llama-3b-v2-f16_cpu.vmfb
```
+TODO: replace these instructions with the newer shortfin code
+
Run via `service_v1_cli.py` (shortfin serving, with tokenizer):
* TODO: script (via service CLI?) to dump inputs/outputs to .bin/.npy files
diff --git a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md
index 5e0749546..4a8423bc8 100644
--- a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md
+++ b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md
@@ -64,7 +64,7 @@ We will use the `hf_datasets` module in `sharktank` to download a
LLama3.1 8b f16 model.
```bash
-python -m sharktank.utils.hf_datasets amd-shark/llama3.1-8B --local-dir $EXPORT_DIR
+python -m sharktank.utils.hf_datasets llama3_8B_fp16 --local-dir $EXPORT_DIR
```
### Define environment variables
diff --git a/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md
new file mode 100644
index 000000000..b63861a56
--- /dev/null
+++ b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md
@@ -0,0 +1,254 @@
+# Using `shortfin` with `sglang`
+
+This doc includes basic steps for hooking up sglang with a running Shortfin server.
+
+## Current Support Status
+
+| Feature | Description | Enabled | Reference |
+| ----------- | ----------- | ---------- | ------------ |
+| `gen` | Generate shortfin completion, given a prompt | ✅ | [Shortfin Implementation](https://github.com/nod-ai/sglang/blob/main/python/sglang/lang/backend/shortfin.py) |
+| `streaming` | Stream shortfin completion, given a prompt | ✅ | [Streaming](https://sgl-project.github.io/frontend/frontend.html#streaming) |
+| `run_batch` | Run a batch of disjoint requests with continuous batching | ✅ | [Batching](https://sgl-project.github.io/frontend/frontend.html#batching) |
+| `fork` | Generate sections of the same prompt in parallel | ✅ | [Fork Docs](https://sgl-project.github.io/frontend/frontend.html#parallelism) |
+| `choices` | Given a set of choices, generate a response based on the best log probs | ❌ | [Choices Methods](https://sgl-project.github.io/frontend/choices_methods.html#choices-methods-in-sglang) |
+| `image` | Pass image as part of multi-modal prompt | ❌ | [sgl.image](https://sgl-project.github.io/frontend/frontend.html#multi-modality) |
+| `regex` | Specify regular expression as decoding constraint | ❌ | [Regex](https://sgl-project.github.io/frontend/frontend.html#constrained-decoding) |
+
+## Prerequisites
+
+For this tutorial, you will need to meet the following prerequisites:
+
+### Software
+
+- Python >= 3.11
+ - You can check out [pyenv](https://github.com/pyenv/pyenv)
+    as a good tool for managing multiple versions of Python
+    on the same system.
+- A running `shortfin` LLM server as described [below](#installstart-shortfin-llm-server)
+ - We will use the shortfin server as the `backend` to generate completions
+ from SGLang's `frontend language`. In this tutorial, you can think of
+ `sglang` as the client and `shortfin` as the server.
+
+### Hardware
+
+- This tutorial is designed to run on an [AMD MI300X GPU](https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html)
+
+## Install/Start `shortfin` LLM server
+
+Follow the steps [here](https://github.com/nod-ai/shark-ai/blob/main/docs/shortfin/llm/user/e2e_llama8b_mi300x.md)
+to export a model with `sharktank` and start a `shortfin` LLM server
+with that model.
+
+## Install sglang
+
+### Install sglang inside of virtual environment
+
+Currently, we have our SGLang integration located at this [forked repo](https://github.com/nod-ai/sglang).
+We can use pip to install it in the same virtual environment that we used
+to start our Shortfin LLM Server.
+
+```bash
+pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"
+```
+
+## Getting started
+
+You can verify the installation/setup through the following examples:
+
+- [Multi-Turn Q&A Example](#multi-turn-qa-example)
+- [Fork Example](#fork-example)
+- [Benchmark Shortfin](#benchmark-shortfin-w-sglang-bench_serving-script)
+
+## Multi-Turn Q&A example
+
+Now that we have sglang installed, we can run an example to show a multi-turn
+Q&A flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html):
+
+### Open python interpreter
+
+```bash
+python
+```
+
+### Run example
+
+You can copy and paste the following example into your interpreter:
+
+```python
+import sglang as sgl
+
+from sglang.lang.chat_template import get_chat_template
+
+backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000")  # Change base_url if running at a different address
+
+sgl.set_default_backend(backend)
+
+@sgl.function
+def multi_turn_question(s, question_1, question_2):
+ s += sgl.user(question_1)
+ s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
+ s += sgl.user(question_2)
+ s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
+
+state = multi_turn_question.run(question_1="Name the capital city of the USA.", question_2="The Smithsonian is in this location.")
+
+for m in state.messages():
+ print(m["role"], m["content"])
+```
+
+### Shortfin example output
+
+You should see an output similar to this:
+
+```text
+========== single ==========
+
+user : Name the capital city of the USA
+assistant : The capital city of the United States of America is Washington, D.C. (short for District of Columbia).
+user : The Smithsonian is in this location.
+assistant : The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes.
+```
+
+## Fork example
+
+Now that we have sglang installed, we can run an example to show a `fork`
+flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html):
+
+### Open python interpreter
+
+```bash
+python
+```
+
+### Run example
+
+You can copy and paste the following example into your interpreter:
+
+```python
+import sglang as sgl
+
+from sglang.lang.chat_template import get_chat_template
+
+backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000")  # Change base_url if running at a different address
+
+sgl.set_default_backend(backend)
+
+@sgl.function
+def tip_suggestion(s):
+ s += (
+ "Here are two tips for staying healthy: "
+ "1. Balanced Diet. 2. Regular Exercise.\n\n"
+ )
+ forks = s.fork(2)
+ for i, f in enumerate(forks):
+ f += f"Now, expand tip {i+1} into a paragraph:\n"
+ f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n")
+ s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
+ s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
+ s += "In summary" + sgl.gen("summary")
+
+state = tip_suggestion.run()
+
+print(state.text())
+```
+
+### Shortfin example output
+
+You should see an output similar to this:
+
+```text
+Here are two tips for staying healthy: 1. Balanced Diet. 2. Regular Exercise.
+
+Tip 1:A balanced diet is important for maintaining good health. It should
+include a variety of foods from all the major food groups, such as fruits,
+vegetables, grains, proteins, and dairy. Eating a balanced diet can help
+prevent chronic diseases such as heart disease, diabetes, and obesity.
+
+Now, expand tip 2 into a paragraph:
+Regular exercise is also important for maintaining good health. It can help
+improve cardiovascular health, strengthen muscles and bones, and reduce the
+risk of chronic diseases. Exercise can also help improve mental health by
+reducing stress and anxiety. It is recommended that adults get at least 150
+minutes of moderate-intensity exercise or 75 minutes of vigorous-intensity
+exercise per week.
+
+Now, combine the two paragraphs into a single paragraph:
+A balanced diet and regular exercise are both important for maintaining good
+health. A balanced diet should include a variety of foods from all the major
+food groups, such as fruits, vegetables, grains, proteins, and dairy.
+Eating a balanced diet can help prevent chronic diseases such as heart disease,
+diabetes, and obesity. Regular exercise is also important for maintaining good
+health. It can help improve cardiovascular health, strengthen muscles and bones,
+and reduce the risk of chronic diseases. Exercise can also help improve mental
+health by reducing stress and anxiety. It is recommended that
+
+Tip 2:Regular exercise is important for maintaining a healthy body and mind.
+It can help improve cardiovascular health, strengthen muscles and bones,
+and reduce the risk of chronic diseases such as diabetes and heart disease.
+Additionally, exercise has been shown to improve mood, reduce stress,
+and increase overall well-being. It is recommended that adults engage in
+at least 150 minutes of moderate-intensity aerobic activity or 75 minutes of
+vigorous-intensity aerobic activity per week, as well as strength training
+exercises at least two days per week.
+
+In summary, a balanced diet and regular exercise are both essential for
+maintaining good health. A balanced diet should include a variety of foods from
+all the major food groups, while regular exercise can help improve
+cardiovascular health, strengthen muscles and bones, reduce the risk of
+chronic diseases, and improve mental health. It is recommended that adults
+engage in at least 150 minutes of moderate-intensity aerobic activity or
+75 minutes of vigorous-intensity aerobic activity per week,
+as well as strength training exercises at least two days per week.
+```
+
+## Benchmark shortfin w/ sglang `bench_serving` script
+
+We can obtain benchmarking metrics using the `bench_serving` script
+provided by SGLang:
+
+**NOTE: Change `--base-url` if running at a different address**
+
+```bash
+python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer /path/to/tokenizer/dir --request-rate 1
+```
+
+The script captures more metrics, but the most relevant are the following (a rough sketch of how they relate follows the list):
+
+- E2E Latency
+- TTFT (Time to First Token)
+- TPOT (Time per Output Token)
+- ITL (Inter-Token Latency)
+- Request Throughput
+- Benchmark Duration
+
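+As a rough sketch of how these metrics relate (using commonly cited definitions rather than the exact formulas `bench_serving` implements), given hypothetical per-request timestamps:
+
+```python
+from dataclasses import dataclass
+
+
+@dataclass
+class RequestTrace:
+    """Hypothetical per-request timing data, in seconds."""
+    sent_at: float
+    token_times: list[float]  # arrival time of each output token
+
+
+def summarize(trace: RequestTrace) -> dict[str, float]:
+    ttft = trace.token_times[0] - trace.sent_at   # Time to First Token
+    e2e = trace.token_times[-1] - trace.sent_at   # end-to-end latency
+    n = len(trace.token_times)
+    # Time per Output Token, excluding the first token.
+    tpot = (e2e - ttft) / (n - 1) if n > 1 else 0.0
+    # Inter-token latencies: gaps between consecutive output tokens.
+    itls = [b - a for a, b in zip(trace.token_times, trace.token_times[1:])]
+    mean_itl = sum(itls) / len(itls) if itls else 0.0
+    return {"ttft_s": ttft, "e2e_s": e2e, "tpot_s": tpot, "mean_itl_s": mean_itl}
+```
+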
+When complete, you should see an output similar to this:
+
+```text
+============ Serving Benchmark Result ============
+Backend: shortfin
+Traffic request rate: 1.0
+Successful requests: 10
+Benchmark duration (s): 427.91
+Total input tokens: 1960
+Total generated tokens: 2774
+Total generated tokens (retokenized): 63
+Request throughput (req/s): 0.02
+Input token throughput (tok/s): 4.58
+Output token throughput (tok/s): 6.48
+----------------End-to-End Latency----------------
+Mean E2E Latency (ms): 416268.77
+Median E2E Latency (ms): 417159.14
+---------------Time to First Token----------------
+Mean TTFT (ms): 292404.29
+Median TTFT (ms): 365989.01
+P99 TTFT (ms): 367325.63
+-----Time per Output Token (excl. 1st token)------
+Mean TPOT (ms): 1359.41
+Median TPOT (ms): 163.96
+P99 TPOT (ms): 6316.12
+---------------Inter-token Latency----------------
+Mean ITL (ms): 2238.99
+Median ITL (ms): 958.75
+P99 ITL (ms): 2719.50
+==================================================
+```
diff --git a/sharktank/conftest.py b/sharktank/conftest.py
index 475f386be..ddd371198 100644
--- a/sharktank/conftest.py
+++ b/sharktank/conftest.py
@@ -153,16 +153,28 @@ def pytest_addoption(parser):
# --outtype=f32 \
# t5-v1_1-small
parser.addoption(
- "--google-t5-v1-1-small-fp32-model-path",
+ "--google-t5-v1-1-small-f32-model-path",
type=Path,
- default="/data/t5/small/google__t5-v1_1-small_fp32.gguf",
- help="Google T5 v1.1 small fp32 model path",
+ default="/data/t5/small/google__t5-v1_1-small_f32.gguf",
+ help="Google T5 v1.1 small float32 model path",
)
parser.addoption(
- "--google-t5-v1-1-xxl-fp32-model-path",
+ "--google-t5-v1-1-small-bf16-model-path",
type=Path,
- default="/data/t5/xxl/google__t5-v1_1-xxl_fp32.gguf",
- help="Google T5 v1.1 XXL fp32 model path",
+ default="/data/t5/small/google__t5-v1_1-small_bf16.gguf",
+ help="Google T5 v1.1 small bfloat16 model path",
+ )
+ parser.addoption(
+ "--google-t5-v1-1-xxl-f32-model-path",
+ type=Path,
+ default="/data/t5/xxl/google__t5-v1_1-xxl_f32.gguf",
+ help="Google T5 v1.1 XXL float32 model path",
+ )
+ parser.addoption(
+ "--google-t5-v1-1-xxl-bf16-model-path",
+ type=Path,
+ default="/data/t5/xxl/google__t5-v1_1-xxl_bf16.gguf",
+ help="Google T5 v1.1 XXL bfloat16 model path",
)
parser.addoption(
@@ -288,15 +300,20 @@ def get_model_artifacts(request: FixtureRequest):
model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model"
)
- model_path["google__t5_v1_1_small_fp32_model_path"] = set_fixture_from_cli_option(
+ model_path["google__t5_v1_1_small_f32_model_path"] = set_fixture_from_cli_option(
+ request,
+ "--google-t5-v1-1-small-f32-model-path",
+ "google__t5_v1_1_small_f32_model",
+ )
+ model_path["google__t5_v1_1_small_bf16_model_path"] = set_fixture_from_cli_option(
request,
- "--google-t5-v1-1-small-fp32-model-path",
- "google__t5_v1_1_small_fp32_model",
+ "--google-t5-v1-1-small-bf16-model-path",
+ "google__t5_v1_1_small_bf16_model",
)
- model_path["google__t5_v1_1_xxl_fp32_model_path"] = set_fixture_from_cli_option(
+ model_path["google__t5_v1_1_xxl_f32_model_path"] = set_fixture_from_cli_option(
request,
- "--google-t5-v1-1-xxl-fp32-model-path",
- "google__t5_v1_1_xxl_fp32_model",
+ "--google-t5-v1-1-xxl-f32-model-path",
+ "google__t5_v1_1_xxl_f32_model",
)
return model_path
diff --git a/sharktank/integration/models/punet/integration_test.py b/sharktank/integration/models/punet/integration_test.py
index 182b37a50..45af24004 100644
--- a/sharktank/integration/models/punet/integration_test.py
+++ b/sharktank/integration/models/punet/integration_test.py
@@ -89,12 +89,13 @@ def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir):
def sdxl_int8_base_files():
from huggingface_hub import hf_hub_download
- REPO_ID = "amd-shark/sdxl-quant-models"
- REVISION = "942e771bf0c2657a8b33380103d04747a75dfa4a"
+ REPO_ID = "amd-shark/sdxl-quant-int8"
+ SUBFOLDER = "mi300_all_sym_8_step14_fp32"
+ REVISION = "efda8afb35fd72c1769e02370b320b1011622958"
def download(filename):
return hf_hub_download(
- repo_id=REPO_ID, subfolder="unet/int8", filename=filename, revision=REVISION
+ repo_id=REPO_ID, subfolder=SUBFOLDER, filename=filename, revision=REVISION
)
return {
diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt
index 6b533d977..26d89c59d 100644
--- a/sharktank/requirements.txt
+++ b/sharktank/requirements.txt
@@ -1,7 +1,7 @@
iree-turbine
# Runtime deps.
-gguf==0.6.0
+gguf==0.10.0
numpy<2.0
# Needed for newer gguf versions (TODO: remove when gguf package includes this)
diff --git a/sharktank/sharktank/evaluate/README.md b/sharktank/sharktank/evaluate/README.md
index 784bb24fd..beb0281cd 100644
--- a/sharktank/sharktank/evaluate/README.md
+++ b/sharktank/sharktank/evaluate/README.md
@@ -9,16 +9,32 @@ pip install -r sharktank/requirements-tests.txt
### Perplexity
-Test perplexity for Llama3.1 8B & 405B (FP16 & FP8) models:
+Perplexity measures the ability of a language model to predict the next token in a sequence. A lower score indicates that the model has higher certainty in its predictions. Perplexity acts as an intrinsic evaluation metric that measures model quality independently of any downstream task.
+
+In SHARK-Platform, we use perplexity to track code regressions and quality loss across quantized models (with FP16 as the baseline). We use 100 prompts randomly selected from the Wikitext-2 test set and calculate the mean perplexities shown below. These numbers are not comparable between models with different tokenizers, nor with other projects, due to varying implementations.
+
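+Conceptually, perplexity is the exponential of the mean negative log-likelihood the model assigns to each next token. The snippet below is only a minimal PyTorch sketch of that calculation; the function name and tensor shapes are illustrative assumptions, and the actual evaluators in `sharktank/sharktank/evaluate/` additionally handle padding, batching, and device placement.
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def perplexity(logits: torch.Tensor, token_ids: torch.Tensor) -> float:
+    """logits: [batch, seq_len, vocab]; token_ids: [batch, seq_len]."""
+    # Each position predicts the *next* token, so shift logits and targets by one.
+    shifted_logits = logits[..., :-1, :]
+    targets = token_ids[..., 1:]
+    # Mean negative log-likelihood over all predicted positions.
+    nll = F.cross_entropy(
+        shifted_logits.transpose(1, 2),  # [batch, vocab, seq_len - 1]
+        targets,
+        reduction="mean",
+    )
+    # Perplexity is exp(mean NLL); lower is better.
+    return torch.exp(nll).item()
+```
+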
+* Test perplexity for Llama3.1 8B (FP16) model:
```bash
pytest sharktank/tests/evaluate/perplexity_test.py --longrun
```
-Get perplexity for a new model:
+* Calculate perplexity for a new model:
```bash
python -m sharktank.evaluate.perplexity \
--gguf-file=llama3_70b_f16.gguf \
--tokenizer-config-json=tokenizer_config.json
```
+
+### Perplexity Scoreboard
+
+| CPU | GPU |
+|:-------------: |:----------:|
+| AMD EPYC 9554 | MI300X |
+
+#### LLaMA 3.1
+
+|Models |Model size (GB) |Torch score |IREE score |
+|:----------------------|:---------------|:-------------|:-------------|
+|8B FP16 TP1 decomposed |16.07 |14.930181 |14.991893 |
diff --git a/sharktank/sharktank/evaluate/perplexity_vmfb.py b/sharktank/sharktank/evaluate/perplexity_iree.py
similarity index 85%
rename from sharktank/sharktank/evaluate/perplexity_vmfb.py
rename to sharktank/sharktank/evaluate/perplexity_iree.py
index 4f95ae1bd..6060eb91b 100644
--- a/sharktank/sharktank/evaluate/perplexity_vmfb.py
+++ b/sharktank/sharktank/evaluate/perplexity_iree.py
@@ -9,6 +9,7 @@
import json
import time
import random
+import re
from datetime import timedelta
from tqdm import tqdm
@@ -83,17 +84,24 @@ def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
end = time.time()
- seconds = end - start
- time_taken = abs(timedelta(seconds=round(seconds)))
-
- if seconds < 1:
- time_taken = f" {seconds * 1000} ms"
+ total_seconds = end - start
+ time_taken = abs(timedelta(seconds=total_seconds))
+ hours, minutes, seconds = re.split(":", str(time_taken))
+
+ if total_seconds < 1:
+ time_taken = f" {round(total_seconds * 1000, 3)} ms"
+ elif total_seconds < 60:
+ time_taken = "{:.2f} secs".format(round(float(total_seconds), 2))
+ else:
+ time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format(
+ int(hours), int(minutes), round(float(seconds), 2)
+ )
func_name = func.__name__
if func_name == "get_perplexity":
- func_name = f"Total time to calculate perplexity"
+ func_name = f"Calculate perplexity"
elif func_name == "compile_model":
- func_name = f"Total time to export and compile"
+ func_name = f"Export & compile"
logger.info(f" {func_name}: {time_taken}")
return result
@@ -119,7 +127,7 @@ def print_token_comparison(self, i):
def compile_model(self, weight_path_str):
self.weight_path_str = weight_path_str
- logger.info(f"Compiling: {self.weight_path_str}")
+ logger.info(f" Compiling: {self.weight_path_str}")
export_artifacts = ExportArtifacts(
irpa_path=self.weight_path_str,
@@ -135,7 +143,7 @@ def compile_model(self, weight_path_str):
@timeit
def load_model(self, weight_path, tokenizer, vmfb_path):
- config = LlamaModelConfig(
+ self.config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(weight_path.properties),
block_seq_stride=16,
kv_cache_type=self.kv_cache_type,
@@ -145,18 +153,18 @@ def load_model(self, weight_path, tokenizer, vmfb_path):
tensor_parallelism_size=self.tensor_parallelism_size,
)
- if config.tensor_parallelism_size > 1:
- weight_path.root_theta = shard_theta(weight_path.root_theta, config)
+ if self.config.tensor_parallelism_size > 1:
+ weight_path.root_theta = shard_theta(weight_path.root_theta, self.config)
theta = weight_path.root_theta
- if config.hp.expert_count:
- if config.hp.model_arch == "grok":
- model = PagedGrokModelV1(theta, config)
+ if self.config.hp.expert_count:
+ if self.config.hp.model_arch == "grok":
+ model = PagedGrokModelV1(theta, self.config)
else:
- model = PagedMixtralModelV1(theta, config)
+ model = PagedMixtralModelV1(theta, self.config)
else:
- model = PagedLlamaModelV1(theta, config)
+ model = PagedLlamaModelV1(theta, self.config)
self.generator = TorchGenerator(model, tokenizer)
@@ -169,7 +177,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path):
self.haldevice = self.runner.config.device
@timeit
- def get_prompts(self):
+ def get_prompts(self, num_prompts):
test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[
"text"
]
@@ -183,12 +191,15 @@ def get_prompts(self):
s.replace("\n", "").rstrip()
for s in test_prompts
if s != "" and len(s.split()) >= 20 and s.count("=") < 2
- ]
+ ][0:num_prompts]
+
+ self.test_prompts = test_prompts
self.bs = len(test_prompts)
- return test_prompts
+ logger.info(f" Batch size: {self.bs}")
+ @timeit
def prefill_vmfb(self, token_batch, i):
seq_block_ids = self.batch.pad_block_ids()
@@ -244,25 +255,7 @@ def decode_vmfb(self, token_batch, i):
return decode_logits
@timeit
- def get_logits(self):
-
- token_ids, seq_lens = self.generator.tokenizer.encode(
- self.test_prompts,
- pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
- )
-
- logger.info(f" Prompts for Evaluation:")
- for idx, prompt in enumerate(self.test_prompts):
- logger.info(
- f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n"
- )
-
- self.max_prompt_length = max(seq_lens)
-
- self.token_ids = torch.tensor(token_ids, device=self.torch_device)
- self.attention_mask = (
- (self.token_ids != 0).int().detach().clone().to(self.torch_device)
- )
+ def get_logits(self, page_cache_size):
is_first_token = True
start = 0
@@ -298,6 +291,7 @@ def get_logits(self):
token_batch=token_batch,
seq_lens_batch=self.seq_lens_batch,
bs=self.bs,
+ page_cache_size=page_cache_size,
)
self.cache_state = ireert.asdevicearray(
@@ -347,11 +341,31 @@ def compute_perplexity(self):
}
@timeit
- def get_perplexity(self, test_prompts):
+ def get_perplexity(self):
- self.test_prompts = test_prompts
+ token_ids, seq_lens = self.generator.tokenizer.encode(
+ self.test_prompts,
+ pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
+ )
+
+ self.page_cache_size = (
+ len(token_ids[0]) // self.config.block_seq_stride
+ ) * self.bs + 1
+
+ logger.debug(f" Prompts for Evaluation:")
+ for idx, prompt in enumerate(self.test_prompts):
+ logger.debug(
+ f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n"
+ )
- self.get_logits()
+ self.max_prompt_length = max(seq_lens)
+
+ self.token_ids = torch.tensor(token_ids, device=self.torch_device)
+ self.attention_mask = (
+ (self.token_ids != 0).int().detach().clone().to(self.torch_device)
+ )
+
+ self.get_logits(page_cache_size=self.page_cache_size)
self.out_logits = self.out_logits[..., :-1, :].contiguous()
self.token_ids = self.token_ids[..., 1:].contiguous()
@@ -379,7 +393,9 @@ def run_perplexity(
kv_cache_type,
tensor_parallelism_size,
attention_kernel,
+ num_prompts,
):
+ start = time.time()
perplexity = Perplexity(
torch_device=torch_device,
iree_device=iree_device,
@@ -390,12 +406,19 @@ def run_perplexity(
attention_kernel=attention_kernel,
)
- test_prompts = perplexity.get_prompts()
- logger.info(f" Total test prompts: {len(test_prompts)}")
+ perplexity.get_prompts(num_prompts=num_prompts)
vmfb_path = perplexity.compile_model(weight_path_str)
perplexity.load_model(weight_path, tokenizer, vmfb_path)
- ppl = perplexity.get_perplexity(test_prompts)
+ ppl = perplexity.get_perplexity()
+
+ end = time.time()
+ total_time = round(end - start, 2)
+ if total_time < 60:
+ total_time = str(total_time) + " secs"
+ else:
+ total_time = str(round(total_time / 60, 2)) + " mins"
+ logger.info(f" Total time taken: {total_time}")
return ppl
@@ -404,7 +427,7 @@ def main(argv):
parser = cli.create_parser()
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
parser.add_argument("--torch-device", help="Torch device (or default)")
- parser.add_argument("--iree-device", help="List an IREE device, eg: 'hip://0'")
+ parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')")
parser.add_argument(
"--iree-hip-target",
action="store",
@@ -429,6 +452,12 @@ def main(argv):
default=1,
help="Number of devices for tensor parallel sharding",
)
+ parser.add_argument(
+ "--num-prompts",
+ type=int,
+ default=100,
+ help="Number of prompts for perplexity test",
+ )
cli.add_tokenizer_options(parser)
cli.add_input_dataset_options(parser)
@@ -452,6 +481,7 @@ def main(argv):
kv_cache_type=kv_cache_type,
tensor_parallelism_size=args.tensor_parallelism_size,
attention_kernel=args.attention_kernel,
+ num_prompts=args.num_prompts,
)
logger.info(f"\n{json.dumps(ppl, indent=2)}")
diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py
index fc3aa5fca..258e8c9a0 100644
--- a/sharktank/sharktank/evaluate/perplexity_torch.py
+++ b/sharktank/sharktank/evaluate/perplexity_torch.py
@@ -8,6 +8,7 @@
import logging
import time
import random
+import re
from datetime import timedelta
import json
import numpy as np
@@ -69,15 +70,22 @@ def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
end = time.time()
- seconds = end - start
- time_taken = abs(timedelta(seconds=round(seconds)))
-
- if seconds < 1:
- time_taken = f" {seconds * 1000} ms"
+ total_seconds = end - start
+ time_taken = abs(timedelta(seconds=total_seconds))
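+            # str(timedelta) renders as "H:MM:SS[.ffffff]" for durations under a
+            # day, so splitting on ":" yields hours, minutes, and seconds.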
+ hours, minutes, seconds = re.split(":", str(time_taken))
+
+ if total_seconds < 1:
+ time_taken = f" {round(total_seconds * 1000, 3)} ms"
+ elif total_seconds < 60:
+ time_taken = "{:.2f} secs".format(round(float(total_seconds), 2))
+ else:
+ time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format(
+ int(hours), int(minutes), round(float(seconds), 2)
+ )
func_name = func.__name__
if func_name == "get_perplexity":
- func_name = "Total time"
+ func_name = "Calculate perplexity"
logger.info(f" {func_name}: {time_taken}")
return result
@@ -102,7 +110,7 @@ def print_token_comparison(self, i):
@timeit
def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kernel):
- config = LlamaModelConfig(
+ self.config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
block_seq_stride=16,
kv_cache_type=self.kv_cache_type,
@@ -112,23 +120,23 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kern
tensor_parallelism_size=tensor_parallelism_size,
)
- if config.tensor_parallelism_size > 1:
- dataset.root_theta = shard_theta(dataset.root_theta, config)
+ if self.config.tensor_parallelism_size > 1:
+ dataset.root_theta = shard_theta(dataset.root_theta, self.config)
theta = dataset.root_theta
- if config.hp.expert_count:
- if config.hp.model_arch == "grok":
- model = PagedGrokModelV1(theta, config)
+ if self.config.hp.expert_count:
+ if self.config.hp.model_arch == "grok":
+ model = PagedGrokModelV1(theta, self.config)
else:
- model = PagedMixtralModelV1(theta, config)
+ model = PagedMixtralModelV1(theta, self.config)
else:
- model = PagedLlamaModelV1(theta, config)
+ model = PagedLlamaModelV1(theta, self.config)
self.generator = TorchGenerator(model, tokenizer)
@timeit
- def get_prompts(self):
+ def get_prompts(self, num_prompts):
test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[
"text"
@@ -144,34 +152,16 @@ def get_prompts(self):
s.replace("\n", "").rstrip()
for s in test_prompts
if s != "" and len(s.split()) >= 20 and s.count("=") < 2
- ]
-
- logger.info(f" num_test_prompts: {len(test_prompts)}")
-
- return test_prompts
-
- @timeit
- def get_logits(self):
-
- token_ids, seq_lens = self.generator.tokenizer.encode(
- self.test_prompts,
- pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
- )
+ ][0:num_prompts]
- logger.info(f" Prompts for Evaluation:")
- for idx, prompt in enumerate(self.test_prompts):
- logger.info(
- f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n"
- )
+ self.test_prompts = test_prompts
- self.max_prompt_length = max(seq_lens)
+ self.bs = len(test_prompts)
- self.token_ids = torch.tensor(token_ids, device=self.device)
- self.attention_mask = (
- (self.token_ids != 0).int().detach().clone().to(self.device)
- )
+ logger.info(f" Batch size: {self.bs}")
- self.bs = len(self.test_prompts)
+ @timeit
+ def get_logits(self, page_cache_size):
is_first_token = True
start = 0
@@ -204,6 +194,7 @@ def get_logits(self):
token_batch=token_batch,
seq_lens_batch=seq_lens_batch,
bs=self.bs,
+ page_cache_size=page_cache_size,
)
self.batch.prefill()
@@ -260,10 +251,31 @@ def compute_perplexity(self):
}
@timeit
- def get_perplexity(self, test_prompts):
+ def get_perplexity(self):
- self.test_prompts = test_prompts
- self.get_logits()
+ token_ids, seq_lens = self.generator.tokenizer.encode(
+ self.test_prompts,
+ pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
+ )
+
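+        # One cache page covers block_seq_stride tokens; reserve enough pages for
+        # every sequence in the batch, plus one spare page.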
+ self.page_cache_size = (
+ len(token_ids[0]) // self.config.block_seq_stride
+ ) * self.bs + 1
+
+        logger.debug(" Prompts for Evaluation:")
+ for idx, prompt in enumerate(self.test_prompts):
+ logger.debug(
+ f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n"
+ )
+
+ self.max_prompt_length = max(seq_lens)
+
+ self.token_ids = torch.tensor(token_ids, device=self.device)
+ self.attention_mask = (
+ (self.token_ids != 0).int().detach().clone().to(self.device)
+ )
+
+ self.get_logits(page_cache_size=self.page_cache_size)
self.out_logits = self.out_logits[..., :-1, :].contiguous()
self.token_ids = self.token_ids[..., 1:].contiguous()
@@ -287,12 +299,22 @@ def run_perplexity_torch(
kv_cache_type,
tensor_parallelism_size,
attention_kernel,
+ num_prompts,
):
- perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type)
+ start = time.time()
+ perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type)
+ perplexity.get_prompts(num_prompts=num_prompts)
perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel)
- test_prompts = perplexity.get_prompts()
- ppl = perplexity.get_perplexity(test_prompts=test_prompts)
+ ppl = perplexity.get_perplexity()
+
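+    # Report the overall wall-clock time, covering prompt loading, model setup,
+    # and perplexity evaluation.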
+ end = time.time()
+ total_time = round(end - start, 2)
+    if total_time < 60:
+        total_time = f"{total_time} secs"
+    else:
+        total_time = f"{round(total_time / 60, 2)} mins"
+ logger.info(f" Total time taken: {total_time}")
return ppl
@@ -314,6 +336,12 @@ def main(argv):
default=1,
help="Number of devices for tensor parallel sharding.",
)
+ parser.add_argument(
+ "--num-prompts",
+ type=int,
+ default=100,
+ help="Number of prompts for perplexity test",
+ )
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
@@ -331,6 +359,7 @@ def main(argv):
kv_cache_type=kv_cache_type,
tensor_parallelism_size=args.tensor_parallelism_size,
attention_kernel=args.attention_kernel,
+ num_prompts=args.num_prompts,
)
logger.info(f"\n{json.dumps(ppl, indent=2)}")
diff --git a/sharktank/sharktank/kernels/templates/flash_attention.mlir b/sharktank/sharktank/kernels/templates/flash_attention.mlir
index 15d75c372..4085fef9c 100644
--- a/sharktank/sharktank/kernels/templates/flash_attention.mlir
+++ b/sharktank/sharktank/kernels/templates/flash_attention.mlir
@@ -33,19 +33,16 @@ util.func private @sharktank_flash_attention_{{l}}_{{s}}_{{d}}_{{e}}_{{i_type}}_
%scale = tensor.extract %s[] : !s_type
- %init_trans_v = tensor.empty(%b0, %b1) : !trans_v_type
- %transpose_v = linalg.transpose ins(%v: !v_type) outs(%init_trans_v: !trans_v_type) permutation = [0, 1, 3, 2]
-
%empty_dyn = tensor.empty(%b0, %b1, %l, %e) : !o_dyn_type
%empty = tensor.cast %empty_dyn : !o_dyn_type to !o_type
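+    // The V operand is read through a transposed indexing map (d5, d4), so no
+    // explicit linalg.transpose of V is required.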
%atten = iree_linalg_ext.attention {indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>,
- affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>]}
- ins(%q, %k, %transpose_v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) {
+ ins(%q, %k, %v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> !o_type
diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py
index 35a2ee570..996a92152 100644
--- a/sharktank/sharktank/layers/configs/llm_configs.py
+++ b/sharktank/sharktank/layers/configs/llm_configs.py
@@ -227,6 +227,8 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
== properties["t5.attention.layer_norm_rms_epsilon"]
)
+ all_kwargs = {"vocab_size": None, "feed_forward_proj": None}
+
gguf_to_config_names_map = {
"t5.context_length": ["context_length"],
"t5.embedding_length": ["d_model"],
@@ -236,11 +238,9 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
"t5.attention.key_length": ["d_kv"],
"t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"],
"t5.attention.relative_buckets_count": ["relative_attention_num_buckets"],
- "t5.decoder_start_token_id": ["decoder_start_token_id"],
"tokenizer.ggml.eos_token_id": ["eos_token_id"],
"tokenizer.ggml.padding_token_id": ["pad_token_id"],
}
- all_kwargs = {"vocab_size": None, "feed_forward_proj": None}
all_kwargs.update(
{
config_name: properties[gguf_name]
@@ -248,6 +248,19 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
for config_name in config_names
}
)
+
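+    # Some GGUF files omit these keys, so copy them over only when present.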
+ gguf_to_optional_config_names_map = {
+ "t5.decoder_start_token_id": ["decoder_start_token_id"],
+ }
+ all_kwargs.update(
+ {
+ config_name: properties[gguf_name]
+ for gguf_name, config_names in gguf_to_optional_config_names_map.items()
+ for config_name in config_names
+ if gguf_name in properties
+ }
+ )
+
if "tokenizer.ggml.tokens" in properties:
all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"])
all_kwargs.update(kwargs)
diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py
index b679dccde..acd9b8a37 100644
--- a/sharktank/sharktank/layers/linear.py
+++ b/sharktank/sharktank/layers/linear.py
@@ -31,9 +31,8 @@ class LinearLayer(ThetaLayer):
x = x * premul_input
matmul(x, weight.T) + bias
- fake_quant exists to allow export without adding dequant ops.
- when fake_quant is True, the op will in quant dequant fashion.
- When false, it will keep quantized types.
+    `fake_quant` exists only to let `q_input` act as a QDQ (quantize-dequantize) op.
+    When `fake_quant` is False, `q_input` quantizes normally.
```
"""
@@ -43,7 +42,7 @@ def __init__(
*,
weight_name: str = "weight",
bias_name: str = "bias",
- fake_quant: bool = True,
+ fake_quant: bool = False,
):
super().__init__(theta)
self._simulate_native_quant = True
@@ -74,21 +73,23 @@ def forward(self, x):
x = q_input.quantize(x)
if self.fake_quant:
x = x.unpack().dequant()
- elif qdq_input is not None and self.fake_quant:
+
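+        # Without a quantized input, optionally apply a quantize-dequantize
+        # round trip to the activations.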
+ elif qdq_input is not None:
x = qdq_input.quantize(x).unpack().dequant()
y = ops.linear(x, weight, bias)
# Unconditionally dequantize.
- if isinstance(y, QuantizedTensor) and not self.fake_quant:
+ if isinstance(y, QuantizedTensor):
y = y.unpack().dequant()
# Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32.
# We can truncate to fp16 in iree, so we do a cast here
# to account for this in the IR. This is may not be the right
# level to do this, but for now its here.
- if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz:
- y = ops.to(y, torch.float16)
- return y
- if qdq_output is not None and self.fake_quant:
+ if not isinstance(y, QuantizedTensor):
+ if y.dtype == torch.float8_e4m3fnuz:
+ y = ops.to(y, torch.float16)
+ return y
+ if qdq_output is not None:
y = qdq_output.quantize(y).unpack().dequant()
return y
diff --git a/sharktank/sharktank/layers/token_embedding.py b/sharktank/sharktank/layers/token_embedding.py
index 32e7fec8f..e5e06c0ef 100644
--- a/sharktank/sharktank/layers/token_embedding.py
+++ b/sharktank/sharktank/layers/token_embedding.py
@@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import torch
+from typing import Optional
from .. import ops
from .base import Theta, ThetaLayer
@@ -16,7 +17,7 @@ def __init__(
theta: Theta,
*,
weight_name: str = "weight",
- dtype: torch.dtype = torch.float32,
+ dtype: Optional[torch.dtype] = torch.float32,
):
super().__init__(theta)
self.weight = self.theta_tensor(weight_name)
diff --git a/sharktank/sharktank/models/t5/export.py b/sharktank/sharktank/models/t5/export.py
index 7bd5eef3d..8d5f75db2 100644
--- a/sharktank/sharktank/models/t5/export.py
+++ b/sharktank/sharktank/models/t5/export.py
@@ -4,12 +4,15 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Union
+import functools
+from typing import Optional, Union
from pathlib import Path
import torch
+from copy import copy
from .t5 import T5Config, T5Encoder
from ...types import Dataset
+from ...transforms.dataset import set_float_dtype
from iree.turbine.aot import FxProgramsBuilder, export
__all__ = [
@@ -91,7 +94,18 @@ def prune_decoder_parameters(dataset: Dataset):
pass
-def export_encoder_iree_parameters(model_path: str, output_path: str):
- dataset = Dataset.load(model_path)
+def export_encoder_iree_parameters(
+ model_path_or_dataset: str | Dataset,
+ output_path: str,
+ dtype: Optional[torch.dtype] = None,
+):
+ if isinstance(model_path_or_dataset, Dataset):
+ dataset = copy(model_path_or_dataset)
+ else:
+ dataset = Dataset.load(model_path_or_dataset)
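+    # Optionally cast floating-point parameters to the requested dtype before
+    # pruning and saving.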
+ if dtype:
+ dataset.root_theta = dataset.root_theta.transform(
+ functools.partial(set_float_dtype, dtype=dtype)
+ )
prune_decoder_parameters(dataset)
dataset.save(output_path)
diff --git a/sharktank/sharktank/models/t5/t5.py b/sharktank/sharktank/models/t5/t5.py
index 4ae9108d5..88472db1d 100644
--- a/sharktank/sharktank/models/t5/t5.py
+++ b/sharktank/sharktank/models/t5/t5.py
@@ -684,7 +684,9 @@ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None):
self.add_module(
"final_layer_norm",
RMSNormLayer(
- theta(f"{theta_prefix}.output_norm"), epsilon=config.layer_norm_epsilon
+ theta(f"{theta_prefix}.output_norm"),
+ epsilon=config.layer_norm_epsilon,
+ dtype=config.activation_dtype,
),
)
@@ -1046,7 +1048,9 @@ def __init__(self, theta: Theta, config: T5Config):
super().__init__()
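+        # Embed tokens in the dtype of the stored embedding weight rather than
+        # the configured activation dtype.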
self.add_module(
"token_embedding",
- TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
+ TokenEmbeddingLayer(
+ theta("token_embd"), dtype=theta("token_embd").tensor("weight").dtype
+ ),
)
encoder_config = copy.deepcopy(config)
diff --git a/sharktank/sharktank/ops/attention_impls.py b/sharktank/sharktank/ops/attention_impls.py
index 8ffb1f51e..d1353daaa 100644
--- a/sharktank/sharktank/ops/attention_impls.py
+++ b/sharktank/sharktank/ops/attention_impls.py
@@ -47,7 +47,7 @@ def _extract_linear_scale(t):
return unbox_tensor(t), None
-def flash_attention(q, k, v, a):
+def flash_attention(q, k, v, a, is_causal, scale):
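+    # is_causal and scale are accepted by the updated signature; the scale used
+    # below is recomputed from the head dimension.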
scale = torch.scalar_tensor(1.0 / math.sqrt(q.shape[-1]), dtype=torch.float32)
q, qscale = _extract_linear_scale(q)
diff --git a/sharktank/sharktank/serving_poc/__init__.py b/sharktank/sharktank/serving_poc/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/sharktank/sharktank/serving_poc/framework/logging.py b/sharktank/sharktank/serving_poc/framework/logging.py
deleted file mode 100644
index fe5ffc069..000000000
--- a/sharktank/sharktank/serving_poc/framework/logging.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import logging
-import os
-import sys
-
-
-# Whether debug assertions are disabled.
-NDEBUG: bool = False
-
-_default_log_level = os.getenv("TURBINE_LOG_LEVEL", "DEBUG")
-
-
-class DefaultFormatter(logging.Formatter):
- def __init__(self):
- super().__init__(
- "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s",
- "%m-%d %H:%M:%S",
- )
-
-
-def _setup_logger():
- root_logger = logging.getLogger("sharktank.serving_poc")
- root_logger.setLevel(logging.DEBUG)
- default_handler = logging.StreamHandler(sys.stderr)
- default_handler.flush = sys.stderr.flush
- default_handler.setLevel(_default_log_level)
- default_handler.setFormatter(DefaultFormatter())
- root_logger.addHandler(default_handler)
- root_logger.propagate = False
- return root_logger, default_handler
-
-
-root_logger, default_handler = _setup_logger()
-
-logging.getLogger("asyncio").addHandler(default_handler)
-
-
-def get_logger(name: str):
- logger = logging.getLogger(name)
- logger.setLevel(_default_log_level)
- logger.addHandler(default_handler)
- logger.propagate = False
- return logger
diff --git a/sharktank/sharktank/serving_poc/framework/session.py b/sharktank/sharktank/serving_poc/framework/session.py
deleted file mode 100644
index 28af0fd44..000000000
--- a/sharktank/sharktank/serving_poc/framework/session.py
+++ /dev/null
@@ -1,610 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-"""Runtime session constructs.
-
-Key concepts:
-
- * DeviceSession: A single HAL device and other process-level globals. Shared global
- memory and corresponding synchronization handles are accessible from here.
- * WorkQueue: Logical stream of execution, nested under the DeviceSession. Each
- queue holds a timeline semaphore which sequences invocations. For these models,
- we route workloads of vastly different characteristics to distinct queues (i.e.
- prefill vs decode step).
- * LoadedModule: Modules that have been loaded but have not yet been instantiated into
- a context.
- * HostContext: At least one HostContext is created per LoadedModule. It encapsulates
- a VMContext and performs invocations on a dedicated thread. Typically, there will
- be more that one HostContext per LoadedModule as it helps us load balance the
- host side work across multiple OS threads, ensuring faster feeding of the device.
-"""
-
-from typing import Any, Callable, Coroutine, Generic, TypeVar, Optional, Union
-
-import asyncio
-import concurrent.futures
-import math
-import queue
-from threading import Lock, Thread
-import warnings
-
-import numpy as np
-
-from iree.runtime import ( # type: ignore[import-untyped]
- create_hal_module,
- create_io_parameters_module,
- get_driver,
- BufferUsage,
- HalBufferView,
- HalCommandBuffer,
- HalDevice,
- HalDeviceLoopBridge,
- HalDriver,
- HalElementType,
- HalFence,
- HalSemaphore,
- MemoryType,
- ParameterIndex,
- VmFunction,
- VmInstance,
- VmContext,
- VmModule,
-)
-
-from .logging import get_logger, NDEBUG
-
-T = TypeVar("T")
-
-logger = get_logger("shark_turbine.serving.session")
-_CONFIG_LOCK = Lock()
-_GLOBAL_VM_INSTANCE: Optional[VmInstance] = None
-
-
-def get_vm_instance() -> VmInstance:
- global _GLOBAL_VM_INSTANCE
- if not _GLOBAL_VM_INSTANCE:
- with _CONFIG_LOCK:
- if not _GLOBAL_VM_INSTANCE:
- _GLOBAL_VM_INSTANCE = VmInstance()
- return _GLOBAL_VM_INSTANCE
-
-
-class DeviceSession:
- """Top-level object associated with a single attached device."""
-
- __slots__ = [
- "device",
- "driver",
- "_module_sets",
- "queues",
- "_queue_request_count",
- "vm_instance",
- ]
-
- def __init__(
- self,
- *,
- uri: Optional[str] = None,
- driver: Optional[Union[str, HalDriver]] = None,
- device: Optional[HalDevice] = None,
- vm_instance: Optional[VmInstance] = None,
- queue_count: int = 1,
- ):
- self._queue_request_count = 0
- self.vm_instance = vm_instance or get_vm_instance()
- if uri is not None:
- assert (
- driver is None and device is None
- ), "If 'uri' is given, 'driver' and 'device' cannot be set"
- logger.info("Opening device by uri: %s", uri)
- driver = uri_driver = get_driver(uri)
- device = uri_driver.create_device_by_uri(uri)
- assert driver is not None, "'driver' cannot be None"
- self.driver = driver if not isinstance(driver, str) else get_driver(driver)
- self.device = device if device else self.driver.create_default_device()
-
- # Dependent objects.
- self._module_sets: dict[str, "ModuleSet"] = {}
- self.queues = [WorkQueue(self, i) for i in range(queue_count)]
-
- def shutdown(self):
- for ms in self._module_sets.values():
- ms.shutdown()
-
- def create_module_set(self, name: str, *, context_count: int = 1) -> "ModuleSet":
- assert (
- name not in self._module_sets
- ), f"Modules with name {name} already created"
- lm = ModuleSet(self, name, context_count=context_count)
- self._module_sets[name] = lm
- return lm
-
- def module_set(self, name: str) -> "ModuleSet":
- try:
- return self._module_sets[name]
- except KeyError:
- raise KeyError(
- f"ModuleSet '{name}' not found. Available: {self._module_sets.keys()}"
- )
-
- def queue(self, index: int = -1) -> "WorkQueue":
- """Gets a queue either with an explicit index or in some rotating fashion."""
- if index >= 0:
- return self.queues[index]
- else:
- self._queue_request_count += 1
- qc = self._queue_request_count
- return self.queues[qc % len(self.queues)]
-
-
-class ModuleSet:
- __slots__ = [
- "contexts",
- "modules",
- "name",
- "session",
- "_context_counter",
- ]
-
- def __init__(self, session: DeviceSession, name: str, *, context_count: int):
- assert context_count > 0
- self.session = session
- self.name = name
- self.modules: list[VmModule] = [
- create_hal_module(session.vm_instance, session.device)
- ]
- self.contexts = [None] * context_count
- self._context_counter = 0
-
- @property
- def initialized(self) -> bool:
- return self.contexts[-1] is not None
-
- def add(self, *modules: VmModule):
- for module in modules:
- self.modules.append(module)
-
- def load_vmfb(self, vmfb_path: str):
- logger.info("Loading VMFB %s", vmfb_path)
- self.add(VmModule.mmap(self.session.vm_instance, vmfb_path))
-
- def load_io_module(self, sources_path: str):
- logger.info("Loading IO Module %s", sources_path)
- index = ParameterIndex()
- index.load(sources_path)
- par_provider = index.create_provider(scope="model")
- self.add(create_io_parameters_module(self.session.vm_instance, par_provider))
-
- def initialize(self):
- assert not self.initialized, "Already initialized"
- count = len(self.contexts)
- logger.info("Initializing %s contexts for %s", count, self.name)
- for i in range(count):
- self.contexts[i] = HostContext(
- self.session, self.modules, name=f"HostContext-{self.name}-{i}"
- )
-
- def shutdown(self):
- for hc in self.contexts:
- if hc is not None:
- hc.shutdown()
-
- def module(self, name: str) -> VmModule:
- for m in self.modules:
- if m.name == name:
- return m
- raise KeyError(
- f"Module `{name}` not found. Available: {[m.name for m in self.modules]}"
- )
-
- def function(self, module_name: str, function_name: str) -> VmFunction:
- m = self.module(module_name)
- f = m.lookup_function(function_name)
- if f is None:
- raise KeyError(
- f"Function '{function_name}' not found in '{module_name}'. "
- f"Available: {m.function_names}"
- )
- return f
-
- @property
- def host_context(self) -> "HostContext":
- """Gets a context, load balancing across available instances."""
- with _CONFIG_LOCK:
- self._context_counter += 1
- counter = self._context_counter
- contexts = self.contexts
- context = contexts[counter % len(contexts)]
- assert context is not None, "Module set not initialized"
- return context
-
-
-_ThunkQueueT = queue.SimpleQueue[Union[None, Callable[[], None]]]
-
-
-class HostContext:
- def __init__(self, session: DeviceSession, modules: list[VmModule], name: str):
- self.session = session
- self.vm_context = VmContext(session.vm_instance, modules=modules)
- self.name = name
- self.loop = asyncio.new_event_loop()
- self.loop.set_debug(True)
-
- # def exc_handler(loop, context):
- # print("[EXCEPTION]", loop, context)
- # self.loop.set_exception_handler(exc_handler)
-
- self._device_bridge = HalDeviceLoopBridge(session.device, self.loop)
- self._shutdown_future = self.loop.create_future()
- logger.info(f"Starting asyncio loop thread %s", name)
- self._loop_thread = Thread(
- target=self.loop.run_until_complete,
- args=[self._shutdown_future],
- name=name,
- daemon=False,
- )
- self._loop_thread.start()
-
- def shutdown(self, join: bool = True):
- if self._shutdown_future is None:
- return
- logger.info("Signalling shutdown of host context %s", self.name)
- local_future = self._shutdown_future
- del self._shutdown_future
-
- def _shutdown():
- local_future.set_result(True)
-
- self.loop.call_soon_threadsafe(_shutdown)
- self._device_bridge.stop()
- if join:
- self._loop_thread.join()
- self.loop.close()
-
- def __del__(self):
- if hasattr(self, "_shutdown_future"):
- warnings.warn(f"HostContext deallocated without shutdown(): {self}")
- self.shutdown(join=False)
-
- def run_concurrent(
- self, coro: Coroutine[Any, Any, T]
- ) -> concurrent.futures.Future[T]:
- """Runs a coroutine from another thread, returning a concurrent Future.
-
- This should be used for submitting initial work to the host context from
- another thread or event loop.
-
- Note that the concurrent Future should have its result() retrieved to
- ensure that any asynchronous exceptions are propagated. Otherwise, they will
- be silently consumed.
- """
- return asyncio.run_coroutine_threadsafe(coro, self.loop)
-
- def run_sync(self, coro: Coroutine[Any, Any, T]) -> T:
- """Runs a coroutine on the host context loop from another thread.
-
- Waits on and returns the result.
- This is primarily intended for testing.
- """
- return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
-
- def on_semaphore(
- self, sem: HalSemaphore, payload: int, value: Any
- ) -> asyncio.Future:
- """Returns an awaitable for when the semaphore attains a payload timepoint.
-
- The resulting Future will take the given `value` once complete.
- """
- return self._device_bridge.on_semaphore(sem, payload, value)
-
-
-class WorkQueue:
- """Models a queue as a progression of steps against a timeline semaphore."""
-
- __slots__ = [
- "_device",
- "_lock",
- "_semaphore",
- "_step",
- "index",
- ]
-
- def __init__(self, session: DeviceSession, index: int = 0):
- self.index = index
- self._device = session.device
- self._lock = Lock()
- self._semaphore = session.device.create_semaphore(0)
- self._step = 0
-
- def execute_sequential(self, command_buffer: HalCommandBuffer):
- """Executes a list of command buffers at the current step, advancing to the
- next.
- """
- with self._lock:
- current_step = self._step
- next_step = current_step + 1
- self._step = next_step
- sem = self._semaphore
- self._device.queue_execute(
- command_buffer, [(sem, current_step)], [(sem, next_step)]
- )
-
- def current_fence(self) -> HalFence:
- """Gets a fence representing the current step."""
- with self._lock:
- return HalFence.create_at(self._semaphore, self._step)
-
- def step_fences(self) -> tuple[HalFence, HalFence]:
- """Gets two fences, one at the current step and one at the next."""
- with self._lock:
- current_step = self._step
- next_step = current_step + 1
- self._step = next_step
- sem = self._semaphore
- return HalFence.create_at(sem, current_step), HalFence.create_at(sem, next_step)
-
- def sync(self, host_context: HostContext) -> asyncio.Future:
- """Awaitable that completes when all work currently queued completed."""
- with self._lock:
- current_step = self._step
- return host_context.on_semaphore(self._semaphore, current_step, True)
-
- def guard(self, value: T) -> "TimelineGuarded[T]":
- """Guards an arbitrary value as a timeline guard at the current queue
- position. The value will become available when the queue is sync'd."""
- return TimelineGuarded(value, self._semaphore, self._step)
-
- def __repr__(self):
- with self._lock:
- return f"WorkQueue[{self.index}](semaphore={self._semaphore}, step={self._step}"
-
-
-class TransferBuffer:
- """Transfer buffers are pairs of host/device buffers of a specific size.
-
- They are used for streaming to/from the device.
- """
-
- __slots__ = [
- "host_buffer",
- "device_buffer",
- "host_buffer_map",
- "_pool",
- ]
-
- def __init__(self, session: DeviceSession, buffer_size_bytes: int):
- self.host_buffer = session.device.allocator.allocate_buffer(
- memory_type=MemoryType.HOST_LOCAL | MemoryType.DEVICE_VISIBLE,
- allowed_usage=BufferUsage.DEFAULT,
- allocation_size=buffer_size_bytes,
- )
- self.device_buffer = session.device.allocator.allocate_buffer(
- memory_type=MemoryType.DEVICE_LOCAL,
- allowed_usage=BufferUsage.DEFAULT,
- allocation_size=buffer_size_bytes,
- )
- self.host_buffer_map = self.host_buffer.map()
- self._pool: Optional["TransferBufferPool"] = None
-
- @staticmethod
- def allocate_shaped(
- session: DeviceSession, shape: list[int], element_type: HalElementType
- ) -> "TransferBuffer":
- assert HalElementType.is_byte_aligned(element_type)
- buffer_size_bytes = math.prod(shape) * HalElementType.dense_byte_count(
- element_type
- )
- return TransferBuffer(session, buffer_size_bytes)
-
- def recycle(self):
- pool = self._pool
- assert (
- pool is not None
- ), f"Cannot recycle a TransferBuffer that was not acquired from a pool ({self})"
- self._pool = None
- pool.recycle(self)
-
- def h2d_array(
- self,
- cb: HalCommandBuffer,
- shape: list[int],
- element_type: HalElementType,
- *,
- fill_value: Any = None,
- ) -> tuple[np.ndarray, HalBufferView]:
- """Performs an h2d transfer on the given CommandBuffer of the given shape and
- element type.
-
- Returns a host array and device buffer view. Because transfers do not start
- until the command buffer is submitted, the host array should be populated
- between the return from this call and submission.
- """
- ary = self.host_buffer_map.asarray(
- shape, HalElementType.map_to_dtype(element_type)
- )
- if fill_value is not None:
- ary.fill(fill_value)
- bv = HalBufferView(self.device_buffer, shape, element_type)
- cb.copy(self.host_buffer, self.device_buffer, length=bv.byte_length)
- return ary, bv
-
- def __repr__(self):
- if self._pool is None:
- return f"TransferBuffer(FREE)"
- else:
- return f"TransferBuffer({self._pool})"
-
- if not NDEBUG:
-
- def __del__(self):
- if self._pool is not None:
- warnings.warn(
- f"Deallocated TransferBuffer which needed to be recycled: {self}"
- )
-
-
-class TransferBufferPool:
- """Pool of transfer buffers of a fixed size."""
-
- __slots__ = [
- "_allocator",
- "_free_list",
- "name",
- ]
-
- def __init__(
- self,
- allocator: Callable[[], TransferBuffer],
- *,
- initial_capacity: int,
- growable: bool = False,
- name: str = "",
- ):
- self.name = name
- if initial_capacity > 0:
- self._free_list = [allocator() for _ in range(initial_capacity)]
- self._allocator = None
- if growable:
- self._allocator = allocator
-
- @staticmethod
- def shaped(
- session: DeviceSession,
- shape: list[int],
- element_type: HalElementType,
- *,
- initial_capacity: int,
- growable: bool = False,
- name: str = "",
- ) -> "TransferBufferPool":
- """Allocates a pool of transfer buffers of the given shape."""
- if initial_capacity > 0:
- logger.info(
- "Allocating initial capacity %s of '%s' transfer buffers: %s x %r",
- initial_capacity,
- name,
- shape,
- element_type,
- )
- return TransferBufferPool(
- lambda: TransferBuffer.allocate_shaped(session, shape, element_type),
- initial_capacity=initial_capacity,
- growable=growable,
- name=name,
- )
-
- @staticmethod
- def sized(
- session: DeviceSession,
- buffer_byte_size: int,
- *,
- initial_capacity: int,
- growable: bool = False,
- name: str = "",
- ) -> "TransferBufferPool":
- """Allocates a pool of transfer buffers of a given size in bytes."""
- if initial_capacity > 0:
- logger.info(
- "Allocating initial capacity %s of '%s' transfer buffers: %s bytes",
- initial_capacity,
- name,
- buffer_byte_size,
- )
- return TransferBufferPool(
- lambda: TransferBuffer(session, buffer_byte_size),
- initial_capacity=initial_capacity,
- growable=growable,
- name=name,
- )
-
- def acquire(self) -> TransferBuffer:
- """Acquires a transfer buffer from the pool.
-
- Must be returned via recycle() when done.
- """
- free_list = self._free_list
- if len(free_list) > 0:
- tb = free_list.pop()
- assert tb._pool is None
- tb._pool = self
- return tb
-
- allocator = self._allocator
- if not allocator:
- raise RuntimeError(
- f"Transfer buffer pool '%s' exhausted and not growable", self.name
- )
- logger.info("Grow transfer buffer pool '%s'", self.name)
- tb = allocator()
- assert tb._pool is None
- tb._pool = self
- return tb
-
- def recycle(self, tb: TransferBuffer):
- """Recycles an acquired transfer buffer."""
- self._free_list.append(tb)
-
- def __repr__(self):
- return f"TransferBufferPool({self.name})"
-
-
-class AsyncResources:
- """Resources held for some asynchronous scope."""
-
- __slots__ = [
- "_resources",
- ]
-
- def __init__(self):
- self._resources: list[Union[TransferBuffer, "AsyncResources"]] = []
-
- def acquire_transfer_buffer(self, pool: TransferBufferPool) -> TransferBuffer:
- tb = pool.acquire()
- self._resources.append(tb)
- return tb
-
- def recycle(self):
- for r in self._resources:
- r.recycle()
- self._resources.clear()
-
- if not NDEBUG:
-
- def __del__(self):
- if len(self._resources) != 0:
- warnings.warn(
- f"Deallocated AsyncResources that was not recycled: {self}"
- )
- self.recycle()
-
-
-class TimelineGuarded(Generic[T]):
- """Some form of results that are structurally available now but will not be
- populated until some point in the future.
-
- This is used to encapsulate entities that are guarded by availability of
- a timepoint. Note that we only allow a single timepoint guard in order to
- simplify subsequent coordination. This will typically be the case when the
- guard is derived from a queue of some form (as opposed to a gather).
- """
-
- __slots__ = [
- "value",
- "sem",
- "timeline",
- ]
-
- def __init__(self, value: T, sem: HalSemaphore, timeline: int):
- self.value = value
- self.sem = sem
- self.timeline = timeline
-
- def resolve(self, host_context: HostContext) -> asyncio.Future[T]:
- """Produces an awaitable that resolves to the value once available."""
- return host_context.on_semaphore(self.sem, self.timeline, self.value)
-
- def __repr__(self):
- return f"TimelineGuarded[{self.sem} @ {self.timeline}] = {self.value}"
diff --git a/sharktank/sharktank/serving_poc/llm/__init__.py b/sharktank/sharktank/serving_poc/llm/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/sharktank/sharktank/serving_poc/llm/api/rest_server.py b/sharktank/sharktank/serving_poc/llm/api/rest_server.py
deleted file mode 100644
index 67536173f..000000000
--- a/sharktank/sharktank/serving_poc/llm/api/rest_server.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-# Heavily adapted from the vllm api_server.py.
-
-from typing import AsyncGenerator, Optional, Sequence
-
-import argparse
-import json
-
-from fastapi import FastAPI, Request
-from fastapi.responses import JSONResponse, Response, StreamingResponse
-import sys
-import uuid
-import uvicorn
-
-from ...framework.logging import get_logger
-from ...framework.session import DeviceSession
-
-
-from ..service import (
- create_mock_generate_service,
- GenerateService,
- GenerateRequest,
-)
-
-logger = get_logger("sharktank.serving_poc.llm.api_server")
-app = FastAPI()
-service: Optional[GenerateService] = None
-
-
-def get_service() -> GenerateService:
- assert service is not None, "Service was not initialized"
- return service
-
-
-@app.get("/health")
-async def health() -> Response:
- get_service()
- return Response(status_code=200)
-
-
-@app.post("/generate")
-async def generate(request: Request) -> Response:
- service = get_service()
- r = await request.json()
- prompt = r.pop("prompt")
- stream = bool(r.pop("stream", False))
- request_id = uuid.uuid4().hex
-
- generate_request = GenerateRequest(request_id=request_id, prompt=prompt)
- result_parts = service.handle_request(generate_request)
-
- if stream:
- # TODO: This isn't entirely matching how others do it: we should be returning
- # the full result on each update.
- async def stream_contents() -> AsyncGenerator[bytes, None]:
- async for part in result_parts:
- response_record = json.dumps({"text": part.text})
- yield (response_record + "\0").encode()
-
- return StreamingResponse(stream_contents())
-
- # Non-streaming just reads to the final.
- async for result_part in result_parts:
- if await request.is_disconnected():
- # Abort.
- await service.abort(generate_request.request_id)
- return Response(status_code=499)
-
- assert result_part is not None, "No results generated!"
- return JSONResponse({"text": result_part.text})
-
-
-def main(clargs: Sequence[str]):
- global service
- parser = argparse.ArgumentParser()
- parser.add_argument("--host", type=str, default=None)
- parser.add_argument("--port", type=int, default=8000)
- parser.add_argument(
- "--root-path",
- type=str,
- default=None,
- help="Root path to use for installing behind path based proxy.",
- )
- parser.add_argument(
- "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout"
- )
- parser.add_argument(
- "--testing-mock-service",
- action="store_true",
- help="Enable the mock testing service",
- )
- parser.add_argument(
- "--device-uri", type=str, default="local-task", help="Device URI to serve on"
- )
-
- args = parser.parse_args(clargs)
-
- # Spin up the device machinery.
- # Note that in the future, for multi-device, we will need more scaffolding for
- # configuration and bringup, obviously.
- device_session = DeviceSession(uri=args.device_uri)
-
- if args.testing_mock_service:
- logger.info("Enabling mock LLM generate service")
- service = create_mock_generate_service()
-
- app.root_path = args.root_path
- uvicorn.run(
- app,
- host=args.host,
- port=args.port,
- log_level="debug",
- timeout_keep_alive=args.timeout_keep_alive,
- )
-
-
-if __name__ == "__main__":
- main(sys.argv[1:])
diff --git a/sharktank/sharktank/serving_poc/llm/attn_block_cache.py b/sharktank/sharktank/serving_poc/llm/attn_block_cache.py
deleted file mode 100644
index a2299c67e..000000000
--- a/sharktank/sharktank/serving_poc/llm/attn_block_cache.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-"""Manages the block cache."""
-
-from iree.runtime import ( # type: ignore
- HalBufferView,
- HalElementType,
- BufferUsage,
- MemoryType,
- PyModuleInterface,
- VmModule,
-)
-
-from ..framework.logging import get_logger
-from ..framework.session import DeviceSession
-
-from .config import human_size, CacheParams
-
-
-logger = get_logger("sharktank.serving_poc.llm.cache")
-
-
-class AttnBlockCacheEntry:
- __slots__ = [
- "index",
- "in_use",
- ]
-
- def __init__(self, index: int):
- self.index = index
- self.in_use = False
-
- def __repr__(self):
- return f"Block({self.index}, {'FREE' if not self.in_use else 'BUSY'})"
-
-
-class AttnBlockCache:
- def __init__(self, session: DeviceSession, cache_params: CacheParams):
- self.session = session
- self.cache_params = cache_params
- self._initialize_block_cache()
-
- def _initialize_block_cache(self):
- model_params = self.cache_params.model
- # Allocate the on-device cache slab.
- attn_block_count = self.cache_params.device_block_count
- attn_block_size_elements = self.cache_params.attn_block_size_elements
- attn_block_size_bytes = attn_block_size_elements * model_params.attn_dtype_size
- attn_cache_size_bytes = attn_block_count * attn_block_size_bytes
-
- logger.info("Setting up cache for\n %r", self.cache_params)
- logger.info(
- "Allocating attention static cache on device of %s "
- "(blocks=%s, block_size=%s bytes)",
- human_size(attn_cache_size_bytes),
- attn_block_count,
- attn_block_size_bytes,
- )
- self.attn_block_buffer = self.session.device.allocator.allocate_buffer(
- memory_type=MemoryType.DEVICE_LOCAL,
- allowed_usage=BufferUsage.DEFAULT,
- allocation_size=attn_cache_size_bytes,
- )
-
- # Attn block logical view.
- self.attn_block_buffer_view = HalBufferView(
- self.attn_block_buffer,
- [
- attn_block_count,
- attn_block_size_elements,
- ],
- model_params.attn_dtype,
- )
-
- # Accounting structs.
- self.attn_block_entries = [
- AttnBlockCacheEntry(i) for i in range(attn_block_count)
- ]
- self.attn_block_free = list(self.attn_block_entries)
-
- async def acquire_attn_blocks(
- self, count: int, into_list: list[AttnBlockCacheEntry]
- ):
- """Acquires 'count' attention blocks.
-
- If there are insufficient free blocks, raises an exception.
- """
- free_list = self.attn_block_free
- assert (
- len(free_list) >= count
- ), f"Cache does not contain requested {count} free attn blocks"
- for i in range(count):
- into_list.append(free_list.pop())
-
- async def release_attn_blocks(self, blocks: list[AttnBlockCacheEntry]):
- """Releases a list of attention blocks.
-
- If at all possible, this should be batched to include all blocks that need to
- be released at a given time since this will trigger heavy-weight scheduling
- that will work better with a view of the new free list as a whole.
- """
- free_list = self.attn_block_free
- for block in blocks:
- free_list.append(block)
-
-
-def create_attn_block_cache_module(attn_block_cache: AttnBlockCache) -> VmModule:
- """Creates a VM module that exports the attention block cache.
-
- For in-system use, we use a dynamic module that can provide the block cache
- slab. In other uses, this may be provided by a statically compiled module
- that does the same.
-
- Interface:
- Module name: attn_block_cache
- Exports:
- func @attn_block_cache.get_shared_slab() -> (!hal.buffer_view)
- """
-
- class Module:
- def __init__(self, iface):
- ...
-
- def get_shared_slab(self):
- return attn_block_cache.attn_block_buffer_view.ref
-
- iface = PyModuleInterface(module_name="attn_block_cache", ctor=Module)
- iface.export("get_shared_slab", "0v_r", Module.get_shared_slab)
- return iface.create()
diff --git a/sharktank/sharktank/serving_poc/llm/config.py b/sharktank/sharktank/serving_poc/llm/config.py
deleted file mode 100644
index df5db5f8f..000000000
--- a/sharktank/sharktank/serving_poc/llm/config.py
+++ /dev/null
@@ -1,181 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-"""Configuration objects.
-
-Parameters that are intrinsic to a specific model.
-
-In a typical transformer model, the KV cache is organized similar to (mapped to
-our parameter names below):
- k = tensor.empty(transformer_block_count, batch_size, seq,
- attn_head_count, attn_head_dim)
- v = ...
-
-For context, a popular model has parameters of:
- attn_dtype_size = 2 # (fp16)
- max_seq_len = 2048
- transformer_block_count = 32
- attn_head_count = 32
- attn_head_dim = 128 # (dim / head_count)
-
-If paging, then we primary care about the organization of a single block, where
-a block represents a single position in the sequence for a single item in the batch.
-Therefore, it will be organized like:
- block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim)
-
-In this scenario, we declare that one block holds the KV cache for all transformer
-block layers because it reduces the accounting. As such, for the above example,
-a single position in the sequence will be 524,288 bytes, assuming a 2-byte element
-type. If we choose to block by block_stride=16 positions, each block will be 8MiB.
-Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536
-blocks for a total number of sequence positions of 24,576.
-
-These are well-known numbers but are derived above to give a sense of scale.
-
-In order to indirect through to the block cache, we have to provide the index map
-to specific invocations:
-
-* Prefill: Prefill is only writing to the blocks from [0:prompt_len], so it will
- need write indices of [batch_size, prompt_len // block_stride + 1].
-* Decode step: Decode is auto-regressive, and needs to first compute the new kv
- row and then attend over all rows in the cache up to this point in the sequence.
-
-If wanting to avoid dynamic allocation of transients, we can also pool the index
-tables based on the maximum batch size and maximum sequence length. Since all
-block cache sizes are well within the range of an i16, we will use that for storage.
-Therefore, each batch invocation would need a block lookup table of:
-
- byte_size = max_batch_size * (max_seq_len // block_stride) * sizeof(int16_t)
-
-For a max_batch_size of 16, this is 4KiB of block index table lookups per
-invocation. We don't have to statically allocate this, but the system is more
-predictable if we just reserve what we need. Again, numbers are given to give a
-sense of scale only: real workloads will vary.
-"""
-
-from dataclasses import dataclass
-
-from iree.runtime import ( # type: ignore
- HalElementType,
-)
-
-import json
-
-
-@dataclass
-class ModelParams:
- """Parameters for a specific compiled model, sufficient to do cache planning and
- invocations."""
-
- # The element type of the attention caches.
- attn_dtype: HalElementType
-
- # Maximum length of a sequence including prompt and output.
- max_seq_len: int
-
- # Number of transformer blocks.
- transformer_block_count: int
-
- # Number of attention heads per block.
- attn_head_count: int
-
- # Dimensionality of each attention head
- attn_head_dim: int
-
- # Position stride per attention block
- block_seq_stride: int
-
- # Batch sizes that the prefill stage is compiled for. These are expected to be
- # functions exported from the model with suffixes of "_bs{batch_size}". Must
- # be in ascending order.
- prefill_batch_sizes: list[int]
-
- # Similarly, batch sizes that the decode stage is compiled for.
- decode_batch_sizes: list[int]
-
- # Name of the IREE module implementing the model.
- module_name: str = "module"
-
- # ABI of the module.
- module_abi_version: int = 1
-
- # Size in bytes of the KV cache dtype.
- @property
- def attn_dtype_size(self) -> int:
- assert HalElementType.is_byte_aligned(self.attn_dtype)
- return HalElementType.dense_byte_count(self.attn_dtype)
-
- @property
- def max_prefill_batch_size(self) -> int:
- return self.prefill_batch_sizes[-1]
-
- @property
- def max_decode_batch_size(self) -> int:
- return self.decode_batch_sizes[-1]
-
- @property
- def max_batch_size(self):
- return max(self.max_prefill_batch_size, self.max_decode_batch_size)
-
- @staticmethod
- def load_json(path):
- f = open(path)
- j = json.load(f)
- return ModelParams(attn_dtype=HalElementType.FLOAT_16, **j)
-
-
-@dataclass
-class CacheParams:
- """Parameters for management of the block cache.
-
- This is paired with a ModelParams.
-
- We presently use a static block cache configuration and hand-wave either a tuning
- run or pen/paper analysis to derive the parameters.
- """
-
- model: ModelParams
-
- # The size of the static block cache on the device.
- device_block_count: int
-
- # The stride of each block in sequence positions.
- block_pos_stride: int
-
- @property
- def attn_unit_size_elements(self) -> int:
- """Size in bytes of each cache line in the attention cache.
-
- Each cache line can store a unit position stride.
- """
- size = 1
- size *= self.model.transformer_block_count
- size *= 2 # K and V cache line
- size *= self.model.attn_head_count
- size *= self.model.attn_head_dim
- return size
-
- @property
- def attn_block_size_elements(self) -> int:
- """Size in bytes of each attention block of {block_position_stride} positions."""
- return self.attn_unit_size_elements * self.block_pos_stride
-
-
-@dataclass
-class ServiceParams:
- """Parameters for the serving service."""
-
- cache: CacheParams
- model: ModelParams
-
-
-# From: https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
-def human_size(num, suffix="B"):
- for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
- if abs(num) < 1024.0:
- return f"{num:3.1f}{unit}{suffix}"
- num /= 1024.0
- return f"{num:.1f}Yi{suffix}"
diff --git a/sharktank/sharktank/serving_poc/llm/impl/service_v1.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1.py
deleted file mode 100644
index 8ae0be637..000000000
--- a/sharktank/sharktank/serving_poc/llm/impl/service_v1.py
+++ /dev/null
@@ -1,495 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-"""Implements the BatchGenerateService for V1 compiled models.
-
-This is far from where we want to land but is intended for first round bootstrapping.
-Perhaps the biggest issue is that it wouldn't mate well as-is with samplers.
-"""
-
-import asyncio
-from dataclasses import dataclass
-
-import numpy as np
-
-from iree.runtime import ( # type: ignore
- HalBufferView,
- HalCommandBuffer,
- HalElementType,
- HalFence,
- VmFunction,
- VmVariantList,
-)
-
-from ...framework.logging import get_logger, NDEBUG
-from ...framework.session import (
- AsyncResources,
- DeviceSession,
- TimelineGuarded,
- TransferBufferPool,
- WorkQueue,
-)
-
-from ..attn_block_cache import AttnBlockCacheEntry, AttnBlockCache
-from ..config import ServiceParams
-from ..service import (
- BatchGenerateService,
- BatchGenerateState,
- GenerateRequest,
-)
-
-
-logger = get_logger("sharktank.serving_poc.llm.impl.service_v1")
-
-EXPECTED_CONCURRENCY = 10
-
-
-class GenerateServiceV1(BatchGenerateService):
- def __init__(
- self, *, session: DeviceSession, params: ServiceParams, cache: AttnBlockCache
- ):
- self.params = params
- self.block_pos_stride = params.cache.block_pos_stride
- self.batch_sizes = params.model.prefill_batch_sizes
- # TODO: Remove distinction between prefill and decode batch sizes.
- assert params.model.decode_batch_sizes == self.batch_sizes
- self.session = session
- self.cache = cache
- module_name = params.model.module_name
- logger.info("Configuring serving for module set %s", module_name)
- self.module_set = session.module_set(params.model.module_name)
-
- # Initialize prefill entry-points (1 per batch size).
- self.prefill_functions: dict[int, VmFunction] = {}
- for bs in self.batch_sizes:
- assert bs not in self.prefill_functions
- symbol_name = f"prefill_bs{bs}"
- logger.info("Looking up symbol '%s'", symbol_name)
- self.prefill_functions[bs] = self.module_set.function(
- module_name, symbol_name
- )
-
- # Initialize decode entry-points (1 per batch size).
- self.decode_functions: dict[int, VmFunction] = {}
- for bs in self.batch_sizes:
- assert bs not in self.decode_functions
- symbol_name = f"decode_bs{bs}"
- logger.info("Looking up symbol '%s'", symbol_name)
- self.decode_functions[bs] = self.module_set.function(
- module_name, symbol_name
- )
-
- self._initialize_transfer_pools()
-
- def _initialize_transfer_pools(self):
- params = self.params
- max_bs = params.model.max_batch_size
- max_sl = params.model.max_seq_len
- initial_inflight = EXPECTED_CONCURRENCY
-
- # block_indices_pool: array([max_batch_size, max_attn_blocks], np.int64)
- # Suitable to handle the sequence->block mapping for all steps.
- self.block_indices_pool = TransferBufferPool.shaped(
- self.session,
- [
- max_bs,
- max_sl // self.block_pos_stride,
- ],
- HalElementType.SINT_64,
- initial_capacity=initial_inflight,
- growable=True,
- name="block_cache_indices",
- )
-
- # Prefill tokens: array([max_batch_size, max_seq_len], np.int64)
- # Tokens inputs to prefill.
- self.prefill_tokens_pool = TransferBufferPool.shaped(
- self.session,
- [
- max_bs,
- max_sl,
- ],
- HalElementType.SINT_64,
- initial_capacity=initial_inflight,
- growable=True,
- name="prefill_tokens",
- )
-
- # Prefill sequence lengths: array([max_batch_size], np.int64)
- # Sequence lengths of input tokens.
- self.prefill_seq_lens_pool = TransferBufferPool.shaped(
- self.session,
- [max_bs],
- HalElementType.SINT_64,
- initial_capacity=initial_inflight,
- growable=True,
- name="prefill_seq_lens",
- )
-
- # Decode tokens: array([max_batch_size], np.int64)
- # Tokens to perform a decode step with.
- self.decode_tokens_pool = TransferBufferPool.shaped(
- self.session,
- [max_bs, 1],
- HalElementType.SINT_64,
- initial_capacity=initial_inflight,
- growable=True,
- name="decode_tokens",
- )
-
- # Decode seq lengths: array([max_batch_size], np.int64)
- # Decoder seq length for this step
- self.decode_seq_lens_pool = TransferBufferPool.shaped(
- self.session,
- [max_bs],
- HalElementType.SINT_64,
- initial_capacity=initial_inflight,
- growable=True,
- name="decode_seq_len",
- )
-
- # Decode start positions: array([max_batch_size], np.int64)
- # Tokens to perform a decode step with.
- self.decode_start_pos_pool = TransferBufferPool.shaped(
- self.session,
- [max_bs],
- HalElementType.SINT_64,
- initial_capacity=initial_inflight,
- growable=True,
- name="decode_start_pos",
- )
-
- def start(self) -> "GenerateState":
- return GenerateState(self)
-
- def shutdown(self):
- self.session.shutdown()
-
-
-class _Sequence:
- __slots__ = [
- "attn_blocks",
- "attn_blocks_needed",
- "current_token_ids",
- "decode_token_ids",
- "request",
- "seq_length",
- ]
-
- current_token_ids: list[int]
- decode_token_ids: list[int]
-
- def __init__(self, request: GenerateRequest):
- self.request = request
- self.seq_length: int = 0
- self.attn_blocks: list[AttnBlockCacheEntry] = []
- self.attn_blocks_needed: int = 0
- self.decode_token_ids = []
- self.current_token_ids = []
-
- def attn_blocks_available(self):
- return len(self.attn_blocks)
-
- def resize_attention(self, new_size):
- old_size = self.attn_blocks_needed
- self.attn_blocks_needed = new_size
- return new_size - old_size
-
-
-class GenerateState(BatchGenerateState):
- __slots__ = [
- "_bs",
- "_decode_function",
- "_prefill_function",
- "_max_attn_blocks_length",
- "_max_seq_length",
- "_resources",
- "_service",
- "_sequences",
- "_batch_queue",
- ]
-
- def __init__(self, service: GenerateServiceV1):
- super().__init__(service.module_set.host_context)
- self._resources = AsyncResources()
- self._service = service
- self._sequences: list[_Sequence] = []
- self._batch_queue = WorkQueue(service.session)
-
- async def recycle(self):
- """Recycles or releases all resources consumed by this instance."""
- cache = self._service.cache
- self._batch_queue.sync(self.host_context)
- self._resources.recycle()
- all_blocks = []
- for seq in self._sequences:
- all_blocks.extend(seq.attn_blocks)
- seq.attn_blocks.clear()
- self._sequences = []
- await cache.release_attn_blocks(all_blocks)
-
- async def set_sequences(self, requests: list[GenerateRequest]):
- """Initiates processing of a list of sequences that make up a batch.
-
- This is async because it acquires resources which may not be available.
- """
- service = self._service
- block_pos_stride = service.block_pos_stride
-
- # Loop through each request and reserve initial attention blocks.
- bs = 0
- sequences = self._sequences
- assert not sequences, "set_sequences already called"
- max_attn_blocks_length = 0
- max_seq_length = 0
- attn_blocks_required = 0
-
- for req in requests:
- bs += 1
- seq = _Sequence(req)
- sequences.append(seq)
- seq.current_token_ids = req.required_prompt_token_ids
- seq_length = len(seq.current_token_ids)
- seq.seq_length = seq_length
- max_seq_length = max(max_seq_length, seq_length)
- initial_block_count = seq_length // block_pos_stride + 1
- attn_blocks_required += initial_block_count
- seq.attn_blocks_needed = initial_block_count
- max_attn_blocks_length = max(max_attn_blocks_length, initial_block_count)
-
- # Determine the appropriate batched entrypoints.
- assert bs > 0
- for allowed_bs in service.batch_sizes:
- if allowed_bs >= bs:
- self._prefill_function = service.prefill_functions[allowed_bs]
- self._decode_function = service.decode_functions[allowed_bs]
- break
- else:
- raise AssertionError(f"Unsupported batch size: {bs}")
-
- # Acquire the needed attention blocks in one batch so as to give the scheduler
- # the most visibility into the need.
- logger.debug("Acquire prefill attn blocks: %s", attn_blocks_required)
- all_attn_blocks: list[AttnBlockCacheEntry] = []
- await service.cache.acquire_attn_blocks(attn_blocks_required, all_attn_blocks)
- block_index = 0
- for seq in sequences:
- next_block_count = seq.attn_blocks_needed
- seq.attn_blocks.extend(
- all_attn_blocks[block_index : block_index + seq.attn_blocks_needed]
- )
- block_index += next_block_count
-
- # Save state.
- self._bs = allowed_bs
- self._max_attn_blocks_length = max_attn_blocks_length
- self._max_seq_length = max_seq_length
-
- async def prefill(self) -> TimelineGuarded[HalBufferView]:
- hc = self.host_context
- service = self._service
- resources = self._resources
- bs = self._bs
- service = self._service
- block_pos_stride = service.block_pos_stride
- max_attn_blocks_length = self._max_attn_blocks_length
- max_seq_length = max_attn_blocks_length * block_pos_stride
- sequences = self._sequences
- work_queue = self._batch_queue
-
- # Record a command buffer for performing h2d transfers.
- cb = HalCommandBuffer(hc.session.device)
-
- # Prepare input tokens, sequence lengths and block indices.
- # We acquire a transfer buffer of each from the respective pool, populate its
- # host side and enqueue.
- # prefill_tokens: array([bs, max_seq_length], np.int32)
- prefill_tokens_host, prefill_tokens_device = resources.acquire_transfer_buffer(
- service.prefill_tokens_pool
- ).h2d_array(cb, [bs, max_seq_length], HalElementType.SINT_64, fill_value=0)
-
- # prefill_seq_lens: array([bs], np.int32)
- (
- prefill_seq_lens_host,
- prefill_seq_lens_device,
- ) = resources.acquire_transfer_buffer(service.prefill_seq_lens_pool).h2d_array(
- cb, [bs], HalElementType.SINT_64, fill_value=0
- )
-
- # attn_block_indices: array([bs, max_attn_blocks], np.int16)
- (
- prefill_attn_block_indices_host,
- prefill_attn_block_indices_device,
- ) = resources.acquire_transfer_buffer(service.block_indices_pool).h2d_array(
- cb, [bs, max_attn_blocks_length], HalElementType.SINT_64, fill_value=0
- )
-
- # Populate host buffers for each sequence.
- for i in range(len(sequences)):
- seq = sequences[i]
- attn_blocks = seq.attn_blocks
- current_token_ids = seq.current_token_ids
- row_seq_len = len(current_token_ids)
- prefill_tokens_host[i, 0:row_seq_len] = current_token_ids
- prefill_seq_lens_host[i] = row_seq_len
- for j in range(len(seq.attn_blocks)):
- prefill_attn_block_indices_host[i, j] = attn_blocks[j].index
-
- # Perform h2d transfers.
- cb.end()
- work_queue.execute_sequential(cb)
-
- # Inputs:
- # token_ids
- # seq_lens
- # attn_block_indices
- # attn_block_buffer_view (the entire slab passed as input)
- # wait, signal semaphores
- # tied attn_block_buffer (for input[2])
- # tied attn_block_buffer (for result[0])
- inputs = VmVariantList(3)
- inputs.push_ref(prefill_tokens_device)
- inputs.push_ref(prefill_seq_lens_device)
- inputs.push_ref(prefill_attn_block_indices_device)
- inputs.push_ref(service.cache.attn_block_buffer_view)
-
- # Outputs:
- # attn_block_buffer_view (tied output)
- # decode_tokens
- outputs = VmVariantList(1)
- # TODO: Async invoke.
- hc.vm_context.invoke(self._prefill_function, inputs, outputs)
- return work_queue.guard(outputs.get_as_ref(0).deref(HalBufferView))
-
- async def set_decode_step(self, tokens):
- """Initiates processing of a list of tokens to decode across each batch
-
- This is async because it acquires resources which may not be available.
- """
- service = self._service
- block_pos_stride = service.block_pos_stride
-
- sequences = self._sequences
- assert sequences, "set_sequences was not called yet"
- assert len(sequences) == len(tokens), "expected token for each sequence"
-
- max_attn_blocks_length = 0
- max_seq_length = 0
- attn_blocks_required = 0
-
- for tok, seq in zip(tokens, self._sequences):
- seq.decode_token_ids.append(tok)
- seq.seq_length = seq.seq_length + 1
-
- max_seq_length = max(max_seq_length, seq.seq_length)
- block_count = seq.seq_length // block_pos_stride + 1
-
- seq.attn_blocks_needed = block_count
- attn_blocks_required += block_count - seq.attn_blocks_available()
- max_attn_blocks_length = max(max_attn_blocks_length, block_count)
-
- # Acquire the needed attention blocks in one batch so as to give the scheduler
- # the most visibility into the need.
- logger.debug("Acquire decode attn blocks: %s", attn_blocks_required)
- all_attn_blocks: list[AttnBlockCacheEntry] = []
- await service.cache.acquire_attn_blocks(attn_blocks_required, all_attn_blocks)
- block_index = 0
- for seq in sequences:
- next_block_count = seq.attn_blocks_needed - seq.attn_blocks_available()
- seq.attn_blocks.extend(
- all_attn_blocks[block_index : block_index + next_block_count]
- )
- block_index += next_block_count
-
- # Save state.
- self._max_attn_blocks_length = max_attn_blocks_length
- self._max_seq_length = max_seq_length
-
- async def decode(self) -> TimelineGuarded[HalBufferView]:
- hc = self.host_context
- service = self._service
- resources = self._resources
- bs = self._bs
- max_attn_blocks_length = self._max_attn_blocks_length
- sequences = self._sequences
- work_queue = self._batch_queue
-
- # Record a command buffer for performing h2d transfers.
- cb = HalCommandBuffer(hc.session.device)
-
- # decode_tokens: array([bs, 1], np.int32)
- (decode_tokens_host, decode_tokens_device,) = resources.acquire_transfer_buffer(
- service.decode_tokens_pool
- ).h2d_array(cb, [bs, 1], HalElementType.SINT_64, fill_value=0)
-
- # decode_seq_lens: array([bs], np.int32)
- (
- decode_seq_lens_host,
- decode_seq_lens_device,
- ) = resources.acquire_transfer_buffer(service.decode_seq_lens_pool).h2d_array(
- cb, [bs], HalElementType.SINT_64, fill_value=0
- )
-
- # decode_start_pos: array([bs], np.int32)
- (
- decode_start_pos_host,
- decode_start_pos_device,
- ) = resources.acquire_transfer_buffer(service.decode_start_pos_pool).h2d_array(
- cb, [bs], HalElementType.SINT_64, fill_value=0
- )
-
- # attn_block_indices: array([bs, max_attn_blocks], np.int16)
- (
- decode_attn_block_indices_host,
- decode_attn_block_indices_device,
- ) = resources.acquire_transfer_buffer(service.block_indices_pool).h2d_array(
- cb, [bs, max_attn_blocks_length], HalElementType.SINT_64, fill_value=0
- )
-
- # Populate host buffers for each sequence.
- for i in range(len(sequences)):
- seq = sequences[i]
- attn_blocks = seq.attn_blocks
-
- tok = seq.decode_token_ids[0]
- seq_len = len(seq.current_token_ids)
- print(f"seq.current_token_ids: {seq.current_token_ids}")
- seq.current_token_ids.append(tok)
- seq.decode_token_ids = seq.decode_token_ids[1:]
-
- decode_tokens_host[i, 0] = tok
- decode_start_pos_host[i] = seq_len
- decode_seq_lens_host[i] = seq_len
- for j in range(len(seq.attn_blocks)):
- decode_attn_block_indices_host[i, j] = attn_blocks[j].index
-
- # Perform h2d transfers.
- cb.end()
- work_queue.execute_sequential(cb)
-
- # Inputs:
- # token_ids
- # seq_lens
- # start_pos
- # attn_block_indices
- # attn_block_buffer_view (the entire slab passed as input)
- # wait, signal semaphores
- # tied attn_block_buffer (for input[4])
- # tied attn_block_buffer (for result[0])
- inputs = VmVariantList(5)
- inputs.push_ref(decode_tokens_device)
- inputs.push_ref(decode_seq_lens_device)
- inputs.push_ref(decode_start_pos_device)
- inputs.push_ref(decode_attn_block_indices_device)
- inputs.push_ref(service.cache.attn_block_buffer_view)
-
- # Outputs:
- # attn_block_buffer_view (tied output)
- # decode_tokens
- outputs = VmVariantList(1)
- # TODO: Async invoke.
- hc.vm_context.invoke(self._decode_function, inputs, outputs)
- return work_queue.guard(outputs.get_as_ref(0).deref(HalBufferView))
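Both the prefill and decode paths above size their KV-cache reservation as `seq_length // block_pos_stride + 1`, and decode then acquires only the difference from `attn_blocks_available()`. A standalone sketch of that accounting (the `blocks_needed` helper is hypothetical, not code from this change):

```python
# One attention block per full block_pos_stride of tokens, plus one extra so
# the reservation also covers at least the next generated token.
def blocks_needed(seq_length: int, block_pos_stride: int) -> int:
    return seq_length // block_pos_stride + 1

assert blocks_needed(seq_length=15, block_pos_stride=16) == 1
assert blocks_needed(seq_length=16, block_pos_stride=16) == 2
assert blocks_needed(seq_length=33, block_pos_stride=16) == 3
```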
diff --git a/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py
deleted file mode 100644
index 7895341c9..000000000
--- a/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import asyncio
-import argparse
-import numpy
-import sys
-
-from transformers import LlamaTokenizer # type: ignore
-
-from iree.runtime import ( # type: ignore
- HalElementType,
-)
-
-from sharktank.serving_poc.framework.session import DeviceSession
-
-from sharktank.serving_poc.llm.attn_block_cache import (
- create_attn_block_cache_module,
- AttnBlockCache,
-)
-
-from sharktank.serving_poc.llm.config import (
- CacheParams,
- ModelParams,
- ServiceParams,
-)
-
-from sharktank.serving_poc.llm.impl.service_v1 import GenerateServiceV1
-from sharktank.serving_poc.llm.service import GenerateRequest
-
-
-def setup(vmfb_path, config_path, gguf_path):
- from iree.runtime._binding import disable_leak_checker # type: ignore
-
- model_params = ModelParams.load_json(config_path)
-
- device_block_count = model_params.max_seq_len // model_params.block_seq_stride
- cache_params = CacheParams(
- model=model_params,
- device_block_count=device_block_count,
- block_pos_stride=model_params.block_seq_stride,
- )
-
- disable_leak_checker()
- session = DeviceSession(uri="local-sync", queue_count=2)
- attn_block_cache = AttnBlockCache(session, cache_params)
-
- lms = session.create_module_set(model_params.module_name, context_count=1)
- lms.load_io_module(gguf_path)
- lms.load_vmfb(vmfb_path)
- lms.add(create_attn_block_cache_module(attn_block_cache))
- lms.initialize()
-
- params = ServiceParams(cache=cache_params, model=model_params)
- service = GenerateServiceV1(session=session, params=params, cache=attn_block_cache)
- return service
-
-
-def map_buffer(value):
- mapped = value.map()
- return mapped.asarray(value.shape, HalElementType.map_to_dtype(value.element_type))
-
-
-async def main(argv):
- parser = argparse.ArgumentParser()
- parser.add_argument("--tokenizer", help="name of hugginface tokenizer to use")
- parser.add_argument("--config", help="json config file with hyperparameters")
- parser.add_argument("--vmfb", help="vmfb with compiler LLM kernels")
- parser.add_argument("--gguf", help="gguf file containing modle coefficients")
- parsed = parser.parse_args(argv)
-
- hf_path = parsed.tokenizer
- config_path = parsed.config
- vmfb_path = parsed.vmfb
- gguf_path = parsed.gguf
-
- service = setup(vmfb_path, config_path, gguf_path)
- tokenizer = LlamaTokenizer.from_pretrained(hf_path)
- state = service.start()
-
- for line in ["one two three four five six seven eight"]:
- prompt = line.strip()
- if not prompt:
- break
-
- input_ids = tokenizer.encode(prompt, return_tensors="pt")[0].tolist()
- print(input_ids)
- request = GenerateRequest("request_id", prompt, input_ids)
- await state.set_sequences([request])
- logits = await state.prefill()
-
- seq_len = len(input_ids)
- mapped_logits = map_buffer(logits.value)
- predicted_tokens = numpy.argmax(mapped_logits[0, :seq_len], axis=-1)
- predicted_token = predicted_tokens[-1]
- decoded_token = tokenizer.decode(predicted_token)
- print(f"Prefill predicted token: {predicted_token}, decoded: '{decoded_token}'")
-
- # TODO(scotttodd): sanity check tokenizer use, document inputs/outputs
- # 'prefill' is for initialization with multiple steps at once
- # 'decode' is for hypothesis exploration, one step at a time
- await state.set_decode_step([predicted_token])
- logits = await state.decode()
- mapped_logits = map_buffer(logits.value)
- predicted_tokens = numpy.argmax(mapped_logits, axis=-1)
- predicted_token = predicted_tokens[0]
- decoded_token = tokenizer.decode(predicted_token)
- print(f"Decode predicted token: {predicted_token}, decoded: '{decoded_token}'")
- await state.recycle()
-
- service.shutdown()
-
-
-if __name__ == "__main__":
- asyncio.run(main(sys.argv[1:]))
diff --git a/sharktank/sharktank/serving_poc/llm/service.py b/sharktank/sharktank/serving_poc/llm/service.py
deleted file mode 100644
index c5d4ffb44..000000000
--- a/sharktank/sharktank/serving_poc/llm/service.py
+++ /dev/null
@@ -1,189 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from typing import AsyncIterator, Callable, Optional
-
-from abc import abstractmethod, ABC
-import asyncio
-from dataclasses import dataclass
-
-from ..framework.session import (
- HostContext,
-)
-
-########################################################################################
-# User-level single request service
-########################################################################################
-
-
-@dataclass
-class GenerateRequest:
- """Encapsulates a request to perform LLM generation.
-
- Requests are bootstrapped from user values and then pumped through the pipeline,
- receiving additional elaboration needed to actually begin generation.
- """
-
- # Client set fields
- request_id: str
- prompt: str
-
- # Fields that are set as the request is processed.
- prompt_token_ids: Optional[list[int]] = None
-
- @property
- def required_prompt_token_ids(self) -> list[int]:
- ids = self.prompt_token_ids
- assert ids is not None
- return ids
-
-
-@dataclass
-class GenerateResponsePart:
- """A response part from an LLM generation request."""
-
- request: GenerateRequest
- index: int
- token_ids: list[int]
-
- # Fields that can be set as the response is post-processed.
- text: Optional[str] = None
- finished: bool = False
-
-
-class GenerateService(ABC):
- """Asynchronous generator service which processes requests into response parts."""
-
- @abstractmethod
- def handle_request(
- self,
- request: GenerateRequest,
- ) -> AsyncIterator[GenerateResponsePart]:
- """Generates response parts for a request."""
- ...
-
- @abstractmethod
- async def abort(self, request_id: str) -> None:
- """Aborts a submitted request."""
- ...
-
-
-########################################################################################
-# Batch generation service
-# This service is completely asynchronous and operates on a BatchGenerateRequest as
-# a state machine. It is expected to have an external actor stepping it through
-# states.
-########################################################################################
-
-
-class BatchGenerateService(ABC):
- """Handles generation of a batch of requests."""
-
- __slots__ = [] # type: ignore
-
- # def start_prefill(self, request: BatchGenerateRequest):
- # ...
- @abstractmethod
- def start(self) -> "BatchGenerateState":
- ...
-
-
-class BatchGenerateState(ABC):
- """In-progress batch generation state."""
-
- __slots__ = [
- "host_context",
- ]
-
- def __init__(self, host_context: HostContext):
- self.host_context = host_context
-
-
-########################################################################################
-# Utilities
-########################################################################################
-
-
-class SyncGenerateFilter(GenerateService):
- """GenerateService filter which can synchronously pre/post process."""
-
- __slots__ = ["_next"]
-
- def __init__(self, next: GenerateService):
- self._next = next
-
- def filter_request(self, request: GenerateRequest):
- ...
-
- def filter_response(self, part: GenerateResponsePart):
- ...
-
- async def handle_request(
- self,
- request: GenerateRequest,
- ) -> AsyncIterator[GenerateResponsePart]:
- self.filter_request(request)
- async for part in self._next.handle_request(request):
- self.filter_response(part)
- yield part
-
- async def abort(self, request_id: str) -> None:
- """Aborts a submitted request."""
- await self._next.abort(request_id)
-
-
-########################################################################################
-# Testing and mock types
-########################################################################################
-
-
-def create_mock_generate_service() -> GenerateService:
- return DummyTokenizerService(EchoGenerateService())
-
-
-class DummyTokenizerService(SyncGenerateFilter):
- """Tokenizer service which will map to code points.
-
- Useful for testing.
- """
-
- def filter_request(self, request: GenerateRequest):
- if request.prompt_token_ids is None:
- request.prompt_token_ids = [ord(c) for c in request.prompt]
-
- def filter_response(self, part: GenerateResponsePart):
- if part.text is None:
- part.text = "".join([chr(x) for x in part.token_ids])
-
-
-class EchoGenerateService(GenerateService):
- """Dummy implementation of a generate service.
-
- It just echoes back the request five times after a delay.
- """
-
- def __init__(self, delay: float = 0.1):
- self._delay = delay
-
- async def handle_request(
- self,
- request: GenerateRequest,
- ) -> AsyncIterator[GenerateResponsePart]:
- next = None
- for i in range(5):
- if next:
- yield next
- assert request.prompt_token_ids, "Request lacks prompt tokens"
- next = GenerateResponsePart(
- request, i, request.prompt_token_ids, finished=False
- )
- await asyncio.sleep(self._delay)
- if next:
- next.finished = True
- yield next
-
- async def abort(self, request_id: str) -> None:
- pass
diff --git a/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py b/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py
deleted file mode 100644
index a36ebe667..000000000
--- a/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-"""Implements a service_v1 compliant module in Python for testing.
-
-This uses a PyModuleInterface to define a fake VmModule that exposes 'prefill_bs{n}'
-and 'decode_bs{n}' such that the call sequence and args/results can be manipulated.
-"""
-
-import numpy as np
-import textwrap
-import threading
-
-from iree.runtime import ( # type: ignore
- BufferUsage,
- HalBuffer,
- HalBufferView,
- HalDevice,
- HalElementType,
- HalFence,
- MemoryType,
- PyModuleInterface,
- VmModule,
- VmRef,
-)
-
-from ..config import ModelParams
-
-
-def create_fake_module(
- device: HalDevice, module_name: str, model_params: ModelParams
-) -> VmModule:
- class ServiceV1Module:
- def __init__(self, iface):
- ...
- print("IFACE:", iface, dir(iface))
-
- def prefill(
- self,
- bs: int,
- token_ids_ref: VmRef,
- seq_lens_ref: VmRef,
- attn_block_indices_ref: VmRef,
- attn_block_buffer_view: VmRef,
- ):
- result_array: np.ndarray = np.ndarray([bs, 1], dtype=np.int32)
-
- def run():
- print(f"FAKE_V1_MODULE: PREFILL bs={bs} : WAIT")
- print(" - READY")
- _format_device_buffer_view(
- lambda s: print(" token_ids =", s), token_ids_ref
- )
- _format_device_buffer_view(
- lambda s: print(" seq_lens =", s), seq_lens_ref
- )
- _format_device_buffer_view(
- lambda s: print(" attn_block_indices =", s),
- attn_block_indices_ref,
- )
- _format_device_buffer_view(
- lambda s: print(" attn_block_buffer_view =", s),
- attn_block_buffer_view,
- )
-
- # Async populate.
- device_array = result_bv.map().asarray(
- result_array.shape, result_array.dtype
- )
- for i in range(bs):
- device_array[i, 0] = i + 1
-
- threading.Thread(target=run).start()
-
- result_buffer = device.allocator.allocate_buffer(
- memory_type=MemoryType.DEVICE_LOCAL | MemoryType.HOST_VISIBLE,
- allowed_usage=BufferUsage.DEFAULT,
- allocation_size=result_array.size * result_array.itemsize,
- )
- result_bv = HalBufferView(
- result_buffer, result_array.shape, HalElementType.INT_32
- )
- return result_bv.ref
-
- def decode(self, bs: int):
- print(f"FAKE_V1_MODULE: DECODE bs={bs}")
-
- iface = PyModuleInterface(module_name=module_name, ctor=ServiceV1Module)
-
- # Dynamically define prefill functions.
- def add_prefill_bs(bs: int):
- def trampoline(self, *args):
- return self.prefill(bs, *args)
-
- iface.export(f"prefill_bs{bs}", "0rrrr_r", trampoline)
-
- [add_prefill_bs(bs) for bs in model_params.prefill_batch_sizes]
-
- # Dynamically define decode functions.
- def add_decode_bs(bs: int):
- def trampoline(self, *args):
- return self.decode(bs, *args)
-
- iface.export(f"decode_bs{bs}", "0v_v", trampoline)
-
- [add_decode_bs(bs) for bs in model_params.decode_batch_sizes]
-
- return iface.create()
-
-
-def _format_device_buffer_view(callback, bv_ref: VmRef):
- bv = bv_ref.deref(HalBufferView) # type: HalBufferView
- value = bv.map().asarray(bv.shape, HalElementType.map_to_dtype(bv.element_type))
- value_indented = textwrap.indent(repr(value), " ")
- callback(f"{bv!r}\n{value_indented}")
diff --git a/sharktank/sharktank/serving_poc/py.typed b/sharktank/sharktank/serving_poc/py.typed
deleted file mode 100644
index 5e43cc13b..000000000
--- a/sharktank/sharktank/serving_poc/py.typed
+++ /dev/null
@@ -1 +0,0 @@
-# Marker file for PEP 561 inline type checking.
diff --git a/sharktank/sharktank/transforms/dataset/__init__.py b/sharktank/sharktank/transforms/dataset/__init__.py
index b6a2a400a..e2a58ea5d 100644
--- a/sharktank/sharktank/transforms/dataset/__init__.py
+++ b/sharktank/sharktank/transforms/dataset/__init__.py
@@ -5,3 +5,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .sharding import *
+from .dataset import *
diff --git a/sharktank/sharktank/transforms/dataset/dataset.py b/sharktank/sharktank/transforms/dataset/dataset.py
new file mode 100644
index 000000000..c600865e4
--- /dev/null
+++ b/sharktank/sharktank/transforms/dataset/dataset.py
@@ -0,0 +1,19 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+
+from ...types.tensors import InferenceTensor, PrimitiveTensor, DefaultPrimitiveTensor
+from ... import ops
+
+
+def set_float_dtype(tensor: InferenceTensor, dtype: torch.dtype) -> InferenceTensor:
+ if isinstance(tensor, PrimitiveTensor) and tensor.dtype.is_floating_point:
+ return DefaultPrimitiveTensor(
+ name=tensor.name, data=ops.to(tensor, dtype=dtype)
+ )
+
+ return tensor
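The t5 test changes later in this diff apply the new `set_float_dtype` transform through `Theta.transform` to cast every floating-point tensor of a loaded dataset. A minimal usage sketch (the `.irpa` path is a placeholder):

```python
# Cast all floating-point tensors of a Dataset to bfloat16 with the new transform.
import functools

import torch

from sharktank.transforms.dataset import set_float_dtype
from sharktank.types import Dataset

dataset = Dataset.load("/path/to/model.irpa")  # placeholder path
dataset.root_theta = dataset.root_theta.transform(
    functools.partial(set_float_dtype, dtype=torch.bfloat16)
)
```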
diff --git a/sharktank/sharktank/types/gguf_interop/base.py b/sharktank/sharktank/types/gguf_interop/base.py
index 9a7dcf1ee..494607f97 100644
--- a/sharktank/sharktank/types/gguf_interop/base.py
+++ b/sharktank/sharktank/types/gguf_interop/base.py
@@ -118,6 +118,15 @@ def _wrap_tensor(
name=name, data=_externalize_tensor(name, data, logical_shape)
)
+ if type_name == "BF16":
+ assert data.dtype == np.uint8
+ return DefaultPrimitiveTensor(
+ name=name,
+ data=_externalize_tensor(name, data.view(np.int16), logical_shape).view(
+ dtype=torch.bfloat16
+ ),
+ )
+
quantized_type = _quantized_types.get(type_name)
if quantized_type is not None:
return quantized_type(
diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py
index f870aa101..2c267ac49 100644
--- a/sharktank/sharktank/types/tensors.py
+++ b/sharktank/sharktank/types/tensors.py
@@ -41,6 +41,7 @@
"AnyTensor",
"DefaultPrimitiveTensor",
"dtype_to_serialized_name",
+ "dtype_to_serialized_short_name",
"flatten_tensor_tree",
"InferenceTensor",
"MetaDataValueType",
@@ -51,6 +52,7 @@
"register_quantized_layout",
"ReplicatedTensor",
"serialized_name_to_dtype",
+ "serialized_short_name_to_dtype",
"ShardedTensor",
"SplitPrimitiveTensor",
"torch_tree_flatten",
@@ -1286,6 +1288,15 @@ def dtype_to_serialized_name(dtype: torch.dtype) -> str:
) from e
+def dtype_to_serialized_short_name(dtype: torch.dtype) -> str:
+ try:
+ return _DTYPE_TO_SHORT_NAME[dtype]
+ except KeyError as e:
+ raise KeyError(
+ f"Missing mapping for dtype {dtype}. Please add to the _SHORT_NAME_TO_DTYPE dict"
+ ) from e
+
+
def serialized_name_to_dtype(dtype_name: str) -> torch.dtype:
try:
return _NAME_TO_DTYPE[dtype_name]
@@ -1295,6 +1306,15 @@ def serialized_name_to_dtype(dtype_name: str) -> torch.dtype:
) from e
+def serialized_short_name_to_dtype(dtype_name: str) -> torch.dtype:
+ try:
+ return _SHORT_NAME_TO_DTYPE[dtype_name]
+ except KeyError as e:
+ raise KeyError(
+ f"Missing mapping for dtype '{dtype_name}'. Please add to the _SHORT_NAME_TO_DTYPE dict"
+ ) from e
+
+
_NAME_TO_DTYPE: dict[str, torch.dtype] = {
"float32": torch.float32,
"float64": torch.float64,
@@ -1338,6 +1358,26 @@ def _maybe_dtype(*names: str):
_DTYPE_TO_NAME: dict[torch.dtype, str] = {v: k for k, v in _NAME_TO_DTYPE.items()}
+_SHORT_NAME_TO_DTYPE: dict[str, torch.dtype] = {
+ "f32": torch.float32,
+ "f64": torch.float64,
+ "c64": torch.complex64,
+ "c128": torch.complex128,
+ "f16": torch.float16,
+ "bf16": torch.bfloat16,
+ "ui8": torch.uint8,
+ "i8": torch.int8,
+ "i16": torch.int16,
+ "i32": torch.int32,
+ "i64": torch.int64,
+ "b": torch.bool,
+ "f8_e4m3fnuz": torch.float8_e4m3fnuz,
+}
+
+_DTYPE_TO_SHORT_NAME: dict[torch.dtype, str] = {
+ v: k for k, v in _SHORT_NAME_TO_DTYPE.items()
+}
+
AnyTensor = Union[torch.Tensor, InferenceTensor]
########################################################################################
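A quick round trip through the new short-name helpers (illustrative sketch; the t5 tests later in this diff use these short names to build artifact file names such as `..._encoder_bf16.mlir`):

```python
import torch

from sharktank.types.tensors import (
    dtype_to_serialized_short_name,
    serialized_short_name_to_dtype,
)

assert dtype_to_serialized_short_name(torch.bfloat16) == "bf16"
assert serialized_short_name_to_dtype("f8_e4m3fnuz") is torch.float8_e4m3fnuz
```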
diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py
index bd33e1a62..c950a875a 100644
--- a/sharktank/sharktank/utils/export_artifacts.py
+++ b/sharktank/sharktank/utils/export_artifacts.py
@@ -9,6 +9,7 @@
import subprocess
import logging
import time
+import re
from pathlib import Path
from datetime import timedelta
from typing import List, Optional
@@ -107,11 +108,18 @@ def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
end = time.time()
- seconds = end - start
- time_taken = abs(timedelta(seconds=round(seconds)))
-
- if seconds < 1:
- time_taken = f" {seconds * 1000} ms"
+ total_seconds = end - start
+ time_taken = abs(timedelta(seconds=total_seconds))
+ hours, minutes, seconds = re.split(":", str(time_taken))
+
+ if total_seconds < 1:
+ time_taken = f" {round(total_seconds * 1000, 3)} ms"
+ elif total_seconds < 60:
+ time_taken = "{:.2f} secs".format(round(float(total_seconds), 2))
+ else:
+ time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format(
+ int(hours), int(minutes), round(float(seconds), 2)
+ )
func_name = func.__name__
logger.info(f" {func_name}: {time_taken}")
@@ -180,13 +188,13 @@ def export_to_mlir(
cwd = self.sharktank_dir
cmd = subprocess.list2cmdline(export_args)
- logger.info(f"Exporting mlir:\n" f"cd {cwd} && {cmd}")
+ logger.info(f" Exporting mlir:\n" f"cd {cwd} && {cmd}")
proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd, text=True)
if proc.returncode != 0:
raise ExportMlirException(proc, cwd)
else:
- logger.info(f"Exported to mlir successfully:\n" f"{proc.stdout}")
+ logger.info(f" Exported to mlir successfully:\n" f"{proc.stdout}")
return proc.returncode
@@ -223,7 +231,7 @@ def compile_to_vmfb(
compile_args += args
cmd = subprocess.list2cmdline(compile_args)
- logging.getLogger().info(f"Launching compile command:\n" f"cd {cwd} && {cmd}")
+ logger.info(f" Launching compile command:\n" f"cd {cwd} && {cmd}")
proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd)
return_code = proc.returncode
if return_code != 0:
@@ -277,7 +285,7 @@ def iree_benchmark_vmfb(
benchmark_args += devices
benchmark_args += args
cmd = subprocess.list2cmdline(benchmark_args)
- logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}")
+ logger.info(f" Launching run command:\n" f"cd {cwd} && {cmd}")
proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd)
return_code = proc.returncode
if return_code != 0:
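The reworked timing decorator above now reports three ranges: milliseconds under one second, seconds under one minute, and an hrs/mins/secs breakdown otherwise. A standalone sketch of that formatting (`format_elapsed` is a hypothetical helper, and it assumes durations stay under 24 hours since `str(timedelta)` changes shape beyond that):

```python
import re
from datetime import timedelta


def format_elapsed(total_seconds: float) -> str:
    # Mirrors the branching in the timing wrapper: ms, then secs, then h/m/s.
    if total_seconds < 1:
        return f"{round(total_seconds * 1000, 3)} ms"
    if total_seconds < 60:
        return "{:.2f} secs".format(total_seconds)
    hours, minutes, seconds = re.split(":", str(timedelta(seconds=total_seconds)))
    return "{:02d} hrs : {:02d} mins : {:.2f} secs".format(
        int(hours), int(minutes), float(seconds)
    )


print(format_elapsed(0.0421))   # 42.1 ms
print(format_elapsed(42.5))     # 42.50 secs
print(format_elapsed(3725.25))  # 01 hrs : 02 mins : 5.25 secs
```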
diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py
index d5976ec48..a9097cf06 100644
--- a/sharktank/sharktank/utils/iree.py
+++ b/sharktank/sharktank/utils/iree.py
@@ -71,6 +71,54 @@ def load_iree_module(
return vm_module, vm_context, vm_instance
+def promote_bfloat16_to_float32(tensor: torch.Tensor) -> torch.Tensor:
+ if tensor.dtype == torch.bfloat16:
+ return tensor.to(dtype=torch.float32)
+ else:
+ return tensor
+
+
+def device_array_to_host(device_array: iree.runtime.DeviceArray) -> torch.Tensor:
+ def reinterpret_hal_buffer_view_element_type(
+ buffer_view: iree.runtime.HalBufferView,
+ element_type: iree.runtime.HalElementType,
+ ) -> iree.runtime.HalBufferView:
+ return iree.runtime.HalBufferView(
+ buffer=buffer_view.get_buffer(),
+ shape=buffer_view.shape,
+ element_type=element_type,
+ )
+
+ def reinterpret_device_array_dtype(
+ device_array: iree.runtime.DeviceArray, dtype: np.dtype
+ ) -> iree.runtime.DeviceArray:
+ return iree.runtime.DeviceArray(
+ device=device_array._device,
+ buffer_view=reinterpret_hal_buffer_view_element_type(
+ device_array._buffer_view,
+ iree.runtime.array_interop.map_dtype_to_element_type(dtype),
+ ),
+ )
+
+ # Circumvent the lack of bfloat16 in numpy.
+ # TODO: This uses private fields _device and _buffer_view in iree.runtime.DeviceArray.
+ # Improve DeviceArray to provide an escape hatch that allows reinterpretation of
+ # the element type of the underlying buffer.
+ def bfloat16_device_array_to_torch(
+ device_array: iree.runtime.DeviceArray,
+ ) -> torch.Tensor:
+ device_array_as_int16 = reinterpret_device_array_dtype(device_array, np.int16)
+ torch_tensor_as_int16 = torch.tensor(device_array_as_int16.to_host())
+ return torch_tensor_as_int16.view(dtype=torch.bfloat16)
+
+ if device_array._buffer_view.element_type == int(
+ iree.runtime.HalElementType.BFLOAT_16
+ ):
+ return bfloat16_device_array_to_torch(device_array)
+ else:
+ return torch.tensor(device_array.to_host())
+
+
def run_iree_module_function(
module: iree.runtime.VmModule,
vm_context: iree.runtime.VmContext,
@@ -88,9 +136,13 @@ def run_iree_module_function(
device=iree.runtime.get_device(driver, cache=False),
vm_function=vm_function,
)
+
if trace_path_prefix is not None:
for i, arg in enumerate(args):
- np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg.to_host())
+ np.save(
+ f"{trace_path_prefix}{function_name}_arg{i}.npy",
+ promote_bfloat16_to_float32(device_array_to_host(arg)).numpy(),
+ )
results = invoker(*args)
if isinstance(results, iree.runtime.DeviceArray):
results = (results,)
@@ -99,10 +151,13 @@ def run_iree_module_function(
for i, arg in enumerate(args):
np.save(
f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy",
- arg.to_host(),
+ device_array_to_host(arg).numpy(),
)
for i, arg in enumerate(results):
- np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", arg.to_host())
+ np.save(
+ f"{trace_path_prefix}{function_name}_result{i}.npy",
+ promote_bfloat16_to_float32(device_array_to_host(arg)).numpy(),
+ )
return results
@@ -158,7 +213,7 @@ def call_torch_module_function(
for i, arg in enumerate(flat_args):
np.save(
f"{trace_path_prefix}{function_name}_arg{i}.npy",
- arg.to("cpu").numpy(),
+ promote_bfloat16_to_float32(arg.to("cpu")).numpy(),
)
res = getattr(module, function_name)(**kwargs)
if trace_path_prefix is not None:
@@ -166,7 +221,7 @@ def call_torch_module_function(
for i, arg in enumerate(flat_args):
np.save(
f"{trace_path_prefix}{function_name}_arg{i}.npy",
- arg.to("cpu").numpy(),
+ promote_bfloat16_to_float32(arg.to("cpu")).numpy(),
)
results = (
(res,)
@@ -189,4 +244,4 @@ def call_torch_module_function(
def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]:
- return [torch.tensor(tensor.to_host()) for tensor in tensors]
+ return [device_array_to_host(tensor) for tensor in tensors]
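The bfloat16 branch added to `device_array_to_host` above works around numpy's missing bfloat16 dtype by moving the raw 16-bit payload as int16 and reinterpreting it on the torch side. The same trick in isolation (a sketch independent of IREE; tensor values are arbitrary):

```python
import numpy as np
import torch

original = torch.tensor([1.5, -2.0, 3.25], dtype=torch.bfloat16)
payload = original.view(dtype=torch.int16).numpy()  # numpy-safe 2-byte payload
assert payload.dtype == np.int16
recovered = torch.from_numpy(payload).view(dtype=torch.bfloat16)
assert torch.equal(original, recovered)
```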
diff --git a/sharktank/sharktank/utils/load_llm.py b/sharktank/sharktank/utils/load_llm.py
index acf56eb1b..47d9f0244 100644
--- a/sharktank/sharktank/utils/load_llm.py
+++ b/sharktank/sharktank/utils/load_llm.py
@@ -23,24 +23,20 @@ def __init__(
self,
model: PagedLlamaModelV1,
tokenizer: InferenceTokenizer,
- page_cache_size: int = 8192,
# Need to look at the model more for this.
end_token: int = 2,
):
self.model = model
self.tokenizer = tokenizer
- if model.cache.is_paged:
- self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
- self.free_pages = list(range(1, page_cache_size))
- else:
- self.shared_cache_state = None
self.end_token = end_token
@property
def block_seq_stride(self) -> int:
return self.model.cache.block_seq_stride
- def begin_batch(self, prompts: list[str], add_start_token: bool):
+ def begin_batch(
+ self, prompts: list[str], add_start_token: bool, page_cache_size: int = 128
+ ):
token_ids, seq_lens = self.tokenizer.encode(
prompts,
pad_to_multiple_of=self.model.cache.pad_sequence_stride,
@@ -48,8 +44,10 @@ def begin_batch(self, prompts: list[str], add_start_token: bool):
)
token_ids = torch.tensor(token_ids, device=self.model.device)
seq_lens = torch.tensor(seq_lens, device=self.model.device)
- if self.shared_cache_state is not None:
- cache_state = self.shared_cache_state
+
+ if self.model.cache.is_paged:
+ cache_state = self.model.cache.paged.allocate(page_cache_size)
+ self.free_pages = list(range(1, page_cache_size))
else:
cache_state = self.model.cache.direct.allocate(bs=len(prompts))
return Batch(self, token_ids, seq_lens, cache_state)
@@ -59,10 +57,11 @@ def begin_eval_batch(
token_batch: torch.tensor,
seq_lens_batch: torch.tensor,
bs: int,
+ page_cache_size: int = 128,
):
-
- if self.shared_cache_state is not None:
- cache_state = self.shared_cache_state
+ if self.model.cache.is_paged:
+ cache_state = self.model.cache.paged.allocate(page_cache_size)
+ self.free_pages = list(range(1, page_cache_size))
else:
cache_state = self.model.cache.direct.allocate(bs=bs)
return Batch(self, token_batch, seq_lens_batch, cache_state)
diff --git a/sharktank/sharktank/utils/testing.py b/sharktank/sharktank/utils/testing.py
index 933bfd2b6..32acec8ac 100644
--- a/sharktank/sharktank/utils/testing.py
+++ b/sharktank/sharktank/utils/testing.py
@@ -4,6 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Optional
import contextlib
from pathlib import Path
import os
@@ -20,7 +21,7 @@
# Range of torch.rand() is [0,1)
# Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values
-def make_rand_torch(shape, dtype=torch.float32):
+def make_rand_torch(shape: list[int], dtype: Optional[torch.dtype] = torch.float32):
return torch.rand(shape, dtype=dtype) * 2 - 1
diff --git a/sharktank/tests/evaluate/baseline_perplexity_scores.json b/sharktank/tests/evaluate/baseline_perplexity_scores.json
index ac2cd7b83..24511b05f 100644
--- a/sharktank/tests/evaluate/baseline_perplexity_scores.json
+++ b/sharktank/tests/evaluate/baseline_perplexity_scores.json
@@ -210,7 +210,7 @@
],
"mean_perplexity": 6.060831
},
- "llama3_8B_f16_decomposed_vmfb": {
+ "llama3_8B_f16_decomposed_iree": {
"perplexities": [
6.651368,
22.059452,
diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py
similarity index 58%
rename from sharktank/tests/evaluate/perplexity_vmfb_test.py
rename to sharktank/tests/evaluate/perplexity_iree_test.py
index 93ffbe61c..d10d9f5db 100644
--- a/sharktank/tests/evaluate/perplexity_vmfb_test.py
+++ b/sharktank/tests/evaluate/perplexity_iree_test.py
@@ -7,10 +7,15 @@
import unittest
import pytest
import json
+import numpy as np
-from sharktank.evaluate import perplexity_vmfb
+from sharktank.evaluate import perplexity_iree
-longrun = pytest.mark.skipif("not config.getoption('longrun')")
+is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'")
+skipif_run_quick_llama_test = pytest.mark.skipif(
+ 'not config.getoption("run-nightly-llama-tests")',
+ reason="Run large tests if --run-nightly-llama-tests is passed",
+)
@pytest.mark.usefixtures(
@@ -18,7 +23,9 @@
"get_iree_flags",
"tensor_parallelism_size",
"baseline_perplexity_scores",
+ "batch_size",
)
+@is_mi300x
class PerplexityTest(unittest.TestCase):
def setUp(self):
self.current_perplexity_all = {}
@@ -27,15 +34,14 @@ def setUp(self):
with open(self.baseline_perplexity_scores, "r") as f:
self.baseline_perplexity = json.load(f)
- @longrun
def test_llama3_8B_f16_decomposed(self):
# Llama 3.1 8B decomposed
- model_name = "llama3_8B_f16_decomposed_vmfb"
+ model_name = "llama3_8B_f16_decomposed_iree"
baseline_perplexity = self.baseline_perplexity[model_name]
- current_perplexity = perplexity_vmfb.main(
+ current_perplexity = perplexity_iree.main(
[
f"--irpa-file={self.llama3_8b_f16_model}",
f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
@@ -44,33 +50,34 @@ def test_llama3_8B_f16_decomposed(self):
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size=1",
f"--attention-kernel=decomposed",
+ f"--num-prompts={self.batch_size}",
]
)
- perplexity_difference = (
- current_perplexity["mean_perplexity"]
- - baseline_perplexity["mean_perplexity"]
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
)
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
self.assertAlmostEqual(
- baseline_perplexity["mean_perplexity"],
- current_perplexity["mean_perplexity"],
+ baseline_mean_perplexity,
+ current_mean_perplexity,
delta=self.delta,
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)
- @pytest.mark.xfail(
- reason="Non-decomposed attention is not supported yet",
- )
- @longrun
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
def test_llama3_8B_f16(self):
# Llama 3.1 8B non-decomposed
- model_name = "llama3_8B_f16_vmfb"
+ model_name = "llama3_8B_f16_iree"
baseline_perplexity = self.baseline_perplexity[model_name]
- current_perplexity = perplexity_vmfb.main(
+ current_perplexity = perplexity_iree.main(
[
f"--irpa-file={self.llama3_8b_f16_model}",
f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
@@ -79,33 +86,34 @@ def test_llama3_8B_f16(self):
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size=1",
f"--attention-kernel=torch_sdpa",
+ f"--num-prompts={self.batch_size}",
]
)
- perplexity_difference = (
- current_perplexity["mean_perplexity"]
- - baseline_perplexity["mean_perplexity"]
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
)
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
self.assertAlmostEqual(
- baseline_perplexity["mean_perplexity"],
- current_perplexity["mean_perplexity"],
+ baseline_mean_perplexity,
+ current_mean_perplexity,
delta=self.delta,
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)
- @pytest.mark.xfail(
- reason="FP8 model is unsupported",
- )
- @longrun
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
def test_llama3_8B_fp8_decomposed(self):
# Llama 3.1 8B decomposed
- model_name = "llama3_8B_fp8_decomposed_vmfb"
+ model_name = "llama3_8B_fp8_decomposed_iree"
baseline_perplexity = self.baseline_perplexity[model_name]
- current_perplexity = perplexity_vmfb.main(
+ current_perplexity = perplexity_iree.main(
[
f"--irpa-file={self.llama3_8b_fp8_model}",
f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
@@ -114,33 +122,34 @@ def test_llama3_8B_fp8_decomposed(self):
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size=1",
f"--attention-kernel=decomposed",
+ f"--num-prompts={self.batch_size}",
]
)
- perplexity_difference = (
- current_perplexity["mean_perplexity"]
- - baseline_perplexity["mean_perplexity"]
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
)
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
self.assertAlmostEqual(
- baseline_perplexity["mean_perplexity"],
- current_perplexity["mean_perplexity"],
+ baseline_mean_perplexity,
+ current_mean_perplexity,
delta=self.delta,
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)
- @pytest.mark.xfail(
- reason="FP8 model is unsupported",
- )
- @longrun
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
def test_llama3_8B_fp8(self):
# Llama 3.1 8B non-decomposed
- model_name = "llama3_8B_fp8_vmfb"
+ model_name = "llama3_8B_fp8_iree"
baseline_perplexity = self.baseline_perplexity[model_name]
- current_perplexity = perplexity_vmfb.main(
+ current_perplexity = perplexity_iree.main(
[
f"--irpa-file={self.llama3_8b_fp8_model}",
f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
@@ -149,33 +158,36 @@ def test_llama3_8B_fp8(self):
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size=1",
f"--attention-kernel=torch_sdpa",
+ f"--num-prompts={self.batch_size}",
]
)
- perplexity_difference = (
- current_perplexity["mean_perplexity"]
- - baseline_perplexity["mean_perplexity"]
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
)
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
self.assertAlmostEqual(
- baseline_perplexity["mean_perplexity"],
- current_perplexity["mean_perplexity"],
+ baseline_mean_perplexity,
+ current_mean_perplexity,
delta=self.delta,
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)
+ @skipif_run_quick_llama_test
@pytest.mark.xfail(
reason="Sharding is unsupported",
)
- @longrun
def test_llama3_405B_f16_decomposed(self):
# Llama 3.1 405B decomposed
- model_name = "llama3_405B_f16_decomposed_vmfb"
+ model_name = "llama3_405B_f16_decomposed_iree"
baseline_perplexity = self.baseline_perplexity[model_name]
- current_perplexity = perplexity_vmfb.main(
+ current_perplexity = perplexity_iree.main(
[
f"--irpa-file={self.llama3_405b_f16_model}",
f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
@@ -184,33 +196,34 @@ def test_llama3_405B_f16_decomposed(self):
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
f"--attention-kernel=decomposed",
+ f"--num-prompts={self.batch_size}",
]
)
- perplexity_difference = (
- current_perplexity["mean_perplexity"]
- - baseline_perplexity["mean_perplexity"]
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
)
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
self.assertAlmostEqual(
- baseline_perplexity["mean_perplexity"],
- current_perplexity["mean_perplexity"],
+ baseline_mean_perplexity,
+ current_mean_perplexity,
delta=self.delta,
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)
- @pytest.mark.xfail(
- reason="Non-decomposed attention is not supported yet",
- )
- @longrun
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
def test_llama3_405B_f16(self):
# Llama 3.1 405B non-decomposed
- model_name = "llama3_405B_f16_vmfb"
+ model_name = "llama3_405B_f16_iree"
baseline_perplexity = self.baseline_perplexity[model_name]
- current_perplexity = perplexity_vmfb.main(
+ current_perplexity = perplexity_iree.main(
[
f"--irpa-file={self.llama3_405b_f16_model}",
f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
@@ -219,33 +232,34 @@ def test_llama3_405B_f16(self):
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
f"--attention-kernel=torch_sdpa",
+ f"--num-prompts={self.batch_size}",
]
)
- perplexity_difference = (
- current_perplexity["mean_perplexity"]
- - baseline_perplexity["mean_perplexity"]
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
)
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
self.assertAlmostEqual(
- baseline_perplexity["mean_perplexity"],
- current_perplexity["mean_perplexity"],
+ baseline_mean_perplexity,
+ current_mean_perplexity,
delta=self.delta,
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)
- @pytest.mark.xfail(
- reason="FP8 model is unsupported",
- )
- @longrun
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
def test_llama3_405B_fp8_decomposed(self):
# Llama 3.1 405B decomposed
- model_name = "llama3_405B_fp8_decomposed_vmfb"
+ model_name = "llama3_405B_fp8_decomposed_iree"
baseline_perplexity = self.baseline_perplexity[model_name]
- current_perplexity = perplexity_vmfb.main(
+ current_perplexity = perplexity_iree.main(
[
f"--irpa-file={self.llama3_405b_fp8_model}",
f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
@@ -254,33 +268,34 @@ def test_llama3_405B_fp8_decomposed(self):
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
f"--attention-kernel=decomposed",
+ f"--num-prompts={self.batch_size}",
]
)
- perplexity_difference = (
- current_perplexity["mean_perplexity"]
- - baseline_perplexity["mean_perplexity"]
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
)
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
self.assertAlmostEqual(
- baseline_perplexity["mean_perplexity"],
- current_perplexity["mean_perplexity"],
+ baseline_mean_perplexity,
+ current_mean_perplexity,
delta=self.delta,
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)
- @pytest.mark.xfail(
- reason="FP8 model is unsupported",
- )
- @longrun
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
def test_llama3_405B_fp8(self):
# Llama 3.1 405B non-decomposed
- model_name = "llama3_405B_fp8_vmfb"
+ model_name = "llama3_405B_fp8_iree"
baseline_perplexity = self.baseline_perplexity[model_name]
- current_perplexity = perplexity_vmfb.main(
+ current_perplexity = perplexity_iree.main(
[
f"--irpa-file={self.llama3_405b_fp8_model}",
f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
@@ -289,17 +304,20 @@ def test_llama3_405B_fp8(self):
f"--iree-hip-target={self.iree_hip_target}",
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
f"--attention-kernel=torch_sdpa",
+ f"--num-prompts={self.batch_size}",
]
)
- perplexity_difference = (
- current_perplexity["mean_perplexity"]
- - baseline_perplexity["mean_perplexity"]
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
)
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
self.assertAlmostEqual(
- baseline_perplexity["mean_perplexity"],
- current_perplexity["mean_perplexity"],
+ baseline_mean_perplexity,
+ current_mean_perplexity,
delta=self.delta,
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)
diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
index 125a0cfdc..751615a85 100644
--- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py
+++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -197,7 +197,6 @@ def testBenchmark8B_f16_Decomposed(self):
)
@skipif_run_quick_llama_test
- @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
def testBenchmark8B_f16_Non_Decomposed_Prefill(self):
output_file_name = self.dir_path_8b / "f16_torch_prefill"
output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
@@ -780,7 +779,9 @@ def testBenchmark405B_f16_TP8_Decomposed(self):
cwd=self.repo_root,
)
- @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+ @pytest.mark.xfail(
+ reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException
+ )
def testBenchmark405B_f16_TP8_Non_Decomposed(self):
output_file_name = self.dir_path_405b / "f16_torch"
output_mlir = self.llama405b_f16_torch_sdpa_artifacts.create_file(
@@ -828,7 +829,9 @@ def testBenchmark405B_f16_TP8_Non_Decomposed(self):
cwd=self.repo_root,
)
- @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+ @pytest.mark.xfail(
+ reason="KeyError in theta.py", strict=True, raises=ExportMlirException
+ )
def testBenchmark405B_fp8_TP8_Decomposed(self):
output_file_name = self.dir_path_405b / "fp8_decomposed"
output_mlir = self.llama405b_fp8_decomposed_artifacts.create_file(
@@ -874,7 +877,9 @@ def testBenchmark405B_fp8_TP8_Decomposed(self):
cwd=self.repo_root,
)
- @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+ @pytest.mark.xfail(
+ reason="KeyError in theta.py", strict=True, raises=ExportMlirException
+ )
def testBenchmark405B_fp8_TP8_Non_Decomposed(self):
output_file_name = self.dir_path_405b / "fp8_torch"
output_mlir = self.llama405b_fp8_torch_sdpa_artifacts.create_file(
diff --git a/sharktank/tests/models/t5/t5_test.py b/sharktank/tests/models/t5/t5_test.py
index 076404e5d..1a696ba57 100644
--- a/sharktank/tests/models/t5/t5_test.py
+++ b/sharktank/tests/models/t5/t5_test.py
@@ -4,6 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+import functools
from transformers.models.t5.modeling_t5 import (
T5Attention as ReferenceT5Attention,
T5LayerSelfAttention as ReferenceT5LayerSelfAttention,
@@ -14,13 +15,21 @@
T5EncoderModel as ReferenceT5EncoderModel,
T5Config as ReferenceT5Config,
)
+from typing import Optional
import os
from collections import OrderedDict
import pytest
import torch
+from torch.utils._pytree import tree_map, tree_unflatten, tree_flatten_with_path
from unittest import TestCase
from parameterized import parameterized
-from sharktank.types import Theta, DefaultPrimitiveTensor, unbox_tensor, Dataset
+from sharktank.types import (
+ Theta,
+ DefaultPrimitiveTensor,
+ unbox_tensor,
+ Dataset,
+ dtype_to_serialized_short_name,
+)
from sharktank.models.t5 import (
T5Attention,
T5SelfAttention,
@@ -41,6 +50,8 @@
flatten_for_iree_signature,
iree_to_torch,
)
+from sharktank.transforms.dataset import set_float_dtype
+from sharktank import ops
import iree.compiler
with_t5_data = pytest.mark.skipif("not config.getoption('with_t5_data')")
@@ -67,45 +78,210 @@ def setUp(self):
torch.random.manual_seed(12345)
torch.no_grad()
- def runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace(
- self, huggingface_repo_id: str
+ @with_t5_data
+ def testXxlBf16AgainstFluxGolden(self):
+ """The ground-truth values were acquired from the Flux pipeline."""
+ target_model_name = (
+ f"{'google/t5-v1_1-xxl'.replace('/', '__').replace('-', '_')}_f32_model"
+ )
+ target_model_path = getattr(self, target_model_name)
+ dataset = Dataset.load(target_model_path)
+ dataset.root_theta = dataset.root_theta.transform(
+ functools.partial(set_float_dtype, dtype=torch.bfloat16)
+ )
+ config = T5Config.from_gguf_properties(
+ dataset.properties,
+ feed_forward_proj="gated-gelu",
+ )
+ model = T5Encoder(theta=dataset.root_theta, config=config)
+ model.eval()
+
+ with open(
+ "/data/t5/xxl/flux_schnell_t5_v1_1_xxl_encoder_bf16_input_ids.pt", "rb"
+ ) as f:
+ reference_input_ids = torch.load(f)
+
+ outputs = model(
+ input_ids=reference_input_ids,
+ attention_mask=None,
+ output_hidden_states=False,
+ )
+
+ with open(
+ "/data/t5/xxl/flux_schnell_t5_v1_1_xxl_encoder_bf16_output_last_hidden_state.pt",
+ "rb",
+ ) as f:
+ reference_last_hidden_state = torch.load(f)
+
+ torch.testing.assert_close(
+ outputs["last_hidden_state"], reference_last_hidden_state
+ )
+
+ def runTestV1_1CompareTorchEagerHuggingFace(
+ self,
+ huggingface_repo_id: str,
+ reference_dtype: torch.dtype,
+ target_dtype: torch.dtype,
+ atol: Optional[float] = None,
+ rtol: Optional[float] = None,
):
get_dataset(
huggingface_repo_id,
).download()
tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id)
- reference_model = ReferenceT5EncoderModel.from_pretrained(huggingface_repo_id)
+ reference_model = ReferenceT5EncoderModel.from_pretrained(
+ huggingface_repo_id, torch_dtype=reference_dtype
+ )
reference_model.eval()
+ model = ReferenceT5EncoderModel.from_pretrained(
+ huggingface_repo_id, torch_dtype=target_dtype
+ )
+ model.eval()
+
input_ids = tokenizer(
test_prompts,
return_tensors="pt",
padding=True,
+ pad_to_multiple_of=16,
).input_ids
+ expected_outputs = dict(reference_model(input_ids=input_ids))
+ actual_outputs = dict(model(input_ids=input_ids))
+ actual_outputs = tree_map(
+ lambda t: ops.to(t, dtype=reference_dtype), actual_outputs
+ )
+
+ torch.testing.assert_close(
+ actual_outputs, expected_outputs, atol=atol, rtol=rtol
+ )
+
+ def runTestV1_1CompareTorchEagerAgainstHuggingFace(
+ self,
+ huggingface_repo_id: str,
+ reference_dtype: torch.dtype,
+ target_dtype: torch.dtype,
+ atol: Optional[float] = None,
+ rtol: Optional[float] = None,
+ ):
+ get_dataset(
+ huggingface_repo_id,
+ ).download()
+ tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id)
+ reference_model = ReferenceT5EncoderModel.from_pretrained(
+ huggingface_repo_id, torch_dtype=reference_dtype
+ )
+ reference_model.eval()
+
target_model_name = (
- f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}_fp32_model"
+ f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}_f32_model"
)
target_model_path = getattr(self, target_model_name)
dataset = Dataset.load(target_model_path)
+ dataset.root_theta = dataset.root_theta.transform(
+ functools.partial(set_float_dtype, dtype=target_dtype)
+ )
config = T5Config.from_gguf_properties(
dataset.properties,
feed_forward_proj="gated-gelu",
)
+
+ input_ids = tokenizer(
+ test_prompts,
+ return_tensors="pt",
+ padding=True,
+ pad_to_multiple_of=config.context_length_padding_block_size,
+ ).input_ids
+
model = T5Encoder(theta=dataset.root_theta, config=config)
model.eval()
expected_outputs = reference_model(input_ids=input_ids)
actual_outputs = model(input_ids=input_ids)
- torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0)
+ actual_outputs = tree_map(
+ lambda t: ops.to(t, dtype=reference_dtype), actual_outputs
+ )
+
+ torch.testing.assert_close(
+ actual_outputs, expected_outputs, atol=atol, rtol=rtol
+ )
+
+ @pytest.mark.xfail(
+ raises=AssertionError,
+ reason=(
+ "The accuracy is bad, "
+ "but for XXL we get the same result as the Flux pipeline. "
+ "This need further investigation how Flux works at all like that."
+ ),
+ )
+ @with_t5_data
+ def testV1_1SmallCompareTorchEagerHuggingFaceBf16AgainstF32(self):
+ self.runTestV1_1CompareTorchEagerHuggingFace(
+ "google/t5-v1_1-small",
+ reference_dtype=torch.float32,
+ target_dtype=torch.bfloat16,
+ atol=1e-2,
+ rtol=1.6e-2,
+ )
+
+ @with_t5_data
+ def testV1_1SmallF32CompareTorchEagerAgainstHuggingFace(self):
+ self.runTestV1_1CompareTorchEagerAgainstHuggingFace(
+ "google/t5-v1_1-small",
+ reference_dtype=torch.float32,
+ target_dtype=torch.float32,
+ )
+
+ @pytest.mark.xfail(
+ raises=AssertionError,
+ reason=(
+ "The accuracy is bad, "
+ "but for XXL we get the same result as the Flux pipeline. "
+ "This need further investigation how Flux works at all like that."
+ ),
+ )
+ @with_t5_data
+ def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFaceF32(self):
+ self.runTestV1_1CompareTorchEagerAgainstHuggingFace(
+ "google/t5-v1_1-small",
+ reference_dtype=torch.float32,
+ target_dtype=torch.bfloat16,
+ atol=1e-2,
+ rtol=1.6e-2,
+ )
@with_t5_data
- def testV1_1SmallFp32CompareTorchEagerAgainstHuggingFace(self):
- self.runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace("google/t5-v1_1-small")
+ def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFace(self):
+ self.runTestV1_1CompareTorchEagerAgainstHuggingFace(
+ "google/t5-v1_1-small",
+ reference_dtype=torch.bfloat16,
+ target_dtype=torch.bfloat16,
+ )
@with_t5_data
- def testV1_1XxlFp32CompareTorchEagerAgainstHuggingFace(self):
- self.runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace("google/t5-v1_1-xxl")
+ def testV1_1XxlF32CompareTorchEagerAgainstHuggingFace(self):
+ self.runTestV1_1CompareTorchEagerAgainstHuggingFace(
+ "google/t5-v1_1-xxl",
+ reference_dtype=torch.float32,
+ target_dtype=torch.float32,
+ )
+
+ @pytest.mark.xfail(
+ raises=AssertionError,
+ reason=(
+ "The accuracy is bad, but we get the same result as the Flux pipeline. "
+ "This need further investigation how Flux works at all like that."
+ ),
+ )
+ @with_t5_data
+ def testV1_1XxlBf16CompareTorchEagerAgainstHuggingFaceF32(self):
+ self.runTestV1_1CompareTorchEagerAgainstHuggingFace(
+ "google/t5-v1_1-xxl",
+ reference_dtype=torch.float32,
+ target_dtype=torch.bfloat16,
+ atol=1e-2,
+ rtol=1.6e-2,
+ )
@pytest.mark.usefixtures("caching", "get_model_artifacts", "path_prefix")
@@ -115,14 +291,14 @@ def setUp(self):
if self.path_prefix is None:
self.path_prefix = f"{self._temp_dir}/"
- @parameterized.expand(
- [
- "google/t5-v1_1-small",
- "google/t5-v1_1-xxl",
- ]
- )
- @with_t5_data
- def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str):
+ def runTestV1_1CompareIreeAgainstTorchEager(
+ self,
+ huggingface_repo_id: str,
+ reference_dtype: torch.dtype,
+ target_dtype: torch.dtype,
+ atol: Optional[float] = None,
+ rtol: Optional[float] = None,
+ ):
get_dataset(
huggingface_repo_id,
).download()
@@ -131,12 +307,15 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str):
huggingface_repo_id_as_path = (
f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}"
)
- source_model_name = f"{huggingface_repo_id_as_path}_fp32_model"
+ source_model_name = f"{huggingface_repo_id_as_path}_f32_model"
source_model_path = getattr(self, source_model_name)
- dataset = Dataset.load(source_model_path)
+ reference_dataset = Dataset.load(source_model_path)
+ reference_dataset.root_theta = reference_dataset.root_theta.transform(
+ functools.partial(set_float_dtype, dtype=reference_dtype)
+ )
config = T5Config.from_gguf_properties(
- dataset.properties,
+ reference_dataset.properties,
feed_forward_proj="gated-gelu",
)
@@ -149,24 +328,31 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str):
input_args = OrderedDict([("input_ids", input_ids)])
batch_size = input_ids.shape[0]
- reference_model = T5Encoder(theta=dataset.root_theta, config=config)
- reference_result = flatten_for_iree_signature(
- call_torch_module_function(
- module=reference_model,
- function_name="forward",
- kwargs=input_args,
- trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_torch_",
- )
+ reference_dtype_name = dtype_to_serialized_short_name(reference_dtype)
+ target_dtype_name = dtype_to_serialized_short_name(target_dtype)
+ target_model_path_prefix = f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_{target_dtype_name}"
+
+ reference_model = T5Encoder(theta=reference_dataset.root_theta, config=config)
+ reference_result_dict = call_torch_module_function(
+ module=reference_model,
+ function_name="forward",
+ kwargs=input_args,
+ trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_{reference_dtype_name}_torch_",
)
+ reference_result = flatten_for_iree_signature(reference_result_dict)
- mlir_path = f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.mlir"
+ parameters_path = f"{target_model_path_prefix}.irpa"
+ if not self.caching or not os.path.exists(parameters_path):
+ export_encoder_iree_parameters(
+ source_model_path, parameters_path, dtype=target_dtype
+ )
+
+ mlir_path = f"{target_model_path_prefix}.mlir"
if not self.caching or not os.path.exists(mlir_path):
export_encoder_mlir(
- source_model_path, batch_sizes=[batch_size], mlir_output_path=mlir_path
+ parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path
)
- iree_module_path = (
- f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.vmfb"
- )
+ iree_module_path = f"{target_model_path_prefix}.vmfb"
if not self.caching or not os.path.exists(iree_module_path):
iree.compiler.compile_file(
mlir_path,
@@ -174,12 +360,6 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str):
extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"],
)
- parameters_path = (
- f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.irpa"
- )
- if not self.caching or not os.path.exists(parameters_path):
- export_encoder_iree_parameters(source_model_path, parameters_path)
-
iree_devices = get_iree_devices(driver="hip", device_count=1)
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path=iree_module_path,
@@ -196,12 +376,70 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str):
args=iree_args,
driver="hip",
function_name=f"forward_bs{batch_size}",
- trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_iree_",
+ trace_path_prefix=f"{target_model_path_prefix}_iree_",
)
)
+ iree_result = [
+ ops.to(iree_result[i], dtype=reference_result[i].dtype)
+ for i in range(len(reference_result))
+ ]
- torch.testing.assert_close(
- reference_result, iree_result, atol=1e-4, rtol=2.0e-3
+ torch.testing.assert_close(reference_result, iree_result, atol=atol, rtol=rtol)
+
+ @with_t5_data
+ def testV1_1CompareSmallIreeF32AgainstTorchEagerF32(self):
+ self.runTestV1_1CompareIreeAgainstTorchEager(
+ "google/t5-v1_1-small",
+ reference_dtype=torch.float32,
+ target_dtype=torch.float32,
+ atol=1e-4,
+ rtol=2.0e-3,
+ )
+
+ @pytest.mark.xfail(
+ raises=AssertionError,
+ reason=(
+ "The accuracy is bad, "
+ "but but it is no worse than the accuracy for of eager bfloat16. "
+ "This need further investigation how Flux works at all like that."
+ ),
+ )
+ @with_t5_data
+ def testV1_1CompareSmallIreeBf16AgainstTorchEagerF32(self):
+ self.runTestV1_1CompareIreeAgainstTorchEager(
+ "google/t5-v1_1-small",
+ reference_dtype=torch.float32,
+ target_dtype=torch.bfloat16,
+ atol=1e-2,
+ rtol=1.6e-2,
+ )
+
+ @with_t5_data
+ def testV1_1CompareXxlIreeF32AgainstTorchEagerF32(self):
+ self.runTestV1_1CompareIreeAgainstTorchEager(
+ "google/t5-v1_1-xxl",
+ reference_dtype=torch.float32,
+ target_dtype=torch.float32,
+ atol=1e-4,
+ rtol=2.0e-3,
+ )
+
+ @pytest.mark.xfail(
+ raises=AssertionError,
+ reason=(
+ "The accuracy is bad, "
+ "but but it is no worse than the accuracy for of eager bfloat16. "
+ "This need further investigation how Flux works at all like that."
+ ),
+ )
+ @with_t5_data
+ def testV1_1CompareXxlIreeBf16AgainstTorchEagerF32(self):
+ self.runTestV1_1CompareIreeAgainstTorchEager(
+ "google/t5-v1_1-xxl",
+ reference_dtype=torch.float32,
+ target_dtype=torch.bfloat16,
+ atol=1e-2,
+ rtol=1.6e-2,
)
@@ -211,8 +449,21 @@ def setUp(self):
torch.random.manual_seed(12345)
torch.no_grad()
- def testCompareAgainstTransformersFp32(self):
- dtype = torch.float32
+ @parameterized.expand(
+ [
+ [torch.float32, torch.float32],
+ [torch.bfloat16, torch.bfloat16],
+ [torch.float32, torch.bfloat16, 1e-2, 1.6e-2],
+ ]
+ )
+ def testCompareAgainstTransformers(
+ self,
+ reference_dtype: torch.dtype,
+ target_dtype: torch.dtype,
+ atol: Optional[float] = None,
+ rtol: Optional[float] = None,
+ ):
+ torch.set_default_dtype(reference_dtype)
batch_size = 19
batch_seq_len = 23
reference_config = ReferenceT5Config(
@@ -233,19 +484,21 @@ def testCompareAgainstTransformersFp32(self):
theta = Theta(
{
"attn_q.weight": DefaultPrimitiveTensor(
- data=reference_model.q.weight.data
+ data=reference_model.q.weight.to(dtype=target_dtype)
),
"attn_k.weight": DefaultPrimitiveTensor(
- data=reference_model.k.weight.data
+ data=reference_model.k.weight.to(dtype=target_dtype)
),
"attn_v.weight": DefaultPrimitiveTensor(
- data=reference_model.v.weight.data
+ data=reference_model.v.weight.to(dtype=target_dtype)
),
"attn_o.weight": DefaultPrimitiveTensor(
- data=reference_model.o.weight.data
+ data=reference_model.o.weight.to(dtype=target_dtype)
),
"attn_rel_b.weight": DefaultPrimitiveTensor(
- data=reference_model.relative_attention_bias.weight.data
+ data=reference_model.relative_attention_bias.weight.to(
+ dtype=target_dtype
+ )
),
}
)
@@ -257,24 +510,52 @@ def testCompareAgainstTransformersFp32(self):
d_model=reference_config.d_model,
d_kv=reference_config.d_kv,
num_heads=reference_config.num_heads,
- activation_dtype=dtype,
+ activation_dtype=target_dtype,
has_relative_attention_bias=True,
)
model.eval()
- hidden_states = make_rand_torch(
- shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype
+ reference_hidden_states = make_rand_torch(
+ shape=[batch_size, batch_seq_len, reference_config.d_model],
+ dtype=reference_dtype,
+ )
+ reference_mask = make_random_mask(
+ shape=[batch_size, 1, 1, batch_seq_len], dtype=reference_dtype
)
- mask = make_random_mask(shape=[batch_size, 1, 1, batch_seq_len], dtype=dtype)
- expected_outputs = reference_model(hidden_states=hidden_states, mask=mask)
+ expected_outputs = reference_model(
+ hidden_states=reference_hidden_states, mask=reference_mask
+ )
+
+ hidden_states = ops.to(reference_hidden_states, dtype=target_dtype)
+ mask = ops.to(reference_mask, dtype=target_dtype)
actual_outputs = model(
hidden_states=DefaultPrimitiveTensor(data=hidden_states),
mask=DefaultPrimitiveTensor(data=mask),
)
- torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0)
+ actual_outputs = tree_map(
+ lambda t: None if t is None else ops.to(t, dtype=reference_dtype),
+ actual_outputs,
+ )
- def testCompareSelfAttentionAgainstTransformersFp32(self):
- dtype = torch.float32
+ torch.testing.assert_close(
+ actual_outputs, expected_outputs, atol=atol, rtol=rtol
+ )
+
+ @parameterized.expand(
+ [
+ [torch.float32, torch.float32],
+ [torch.bfloat16, torch.bfloat16],
+ [torch.float32, torch.bfloat16, 1e-2, 1.6e-2],
+ ]
+ )
+ def testCompareSelfAttentionAgainstTransformers(
+ self,
+ reference_dtype: torch.dtype,
+ target_dtype: torch.dtype,
+ atol: Optional[float] = None,
+ rtol: Optional[float] = None,
+ ):
+ torch.set_default_dtype(reference_dtype)
batch_size = 19
batch_seq_len = 23
reference_config = ReferenceT5Config(
@@ -296,22 +577,24 @@ def testCompareSelfAttentionAgainstTransformersFp32(self):
theta = Theta(
{
"attn_q.weight": DefaultPrimitiveTensor(
- data=reference_model.SelfAttention.q.weight.data
+ data=reference_model.SelfAttention.q.weight.to(dtype=target_dtype)
),
"attn_k.weight": DefaultPrimitiveTensor(
- data=reference_model.SelfAttention.k.weight.data
+ data=reference_model.SelfAttention.k.weight.to(dtype=target_dtype)
),
"attn_v.weight": DefaultPrimitiveTensor(
- data=reference_model.SelfAttention.v.weight.data
+ data=reference_model.SelfAttention.v.weight.to(dtype=target_dtype)
),
"attn_o.weight": DefaultPrimitiveTensor(
- data=reference_model.SelfAttention.o.weight.data
+ data=reference_model.SelfAttention.o.weight.to(dtype=target_dtype)
),
"attn_rel_b.weight": DefaultPrimitiveTensor(
- data=reference_model.SelfAttention.relative_attention_bias.weight.data
+ data=reference_model.SelfAttention.relative_attention_bias.weight.to(
+ dtype=target_dtype
+ )
),
"attn_norm.weight": DefaultPrimitiveTensor(
- data=reference_model.layer_norm.weight.data
+ data=reference_model.layer_norm.weight.to(dtype=target_dtype)
),
}
)
@@ -323,24 +606,37 @@ def testCompareSelfAttentionAgainstTransformersFp32(self):
d_model=reference_config.d_model,
d_kv=reference_config.d_kv,
num_heads=reference_config.num_heads,
- activation_dtype=dtype,
+ activation_dtype=torch.float32,
layer_norm_epsilon=reference_config.layer_norm_epsilon,
has_relative_attention_bias=True,
)
model.eval()
- hidden_states = make_rand_torch(
- shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype
+ reference_hidden_states = make_rand_torch(
+ shape=[batch_size, batch_seq_len, reference_config.d_model],
+ dtype=reference_dtype,
)
- mask = make_random_mask(shape=[batch_size, 1, 1, batch_seq_len], dtype=dtype)
- position_bias = make_rand_torch(
- shape=[batch_size, reference_config.num_heads, batch_seq_len, batch_seq_len]
+ reference_mask = make_random_mask(
+ shape=[batch_size, 1, 1, batch_seq_len], dtype=reference_dtype
+ )
+ reference_position_bias = make_rand_torch(
+ shape=[
+ batch_size,
+ reference_config.num_heads,
+ batch_seq_len,
+ batch_seq_len,
+ ],
+ dtype=reference_dtype,
)
expected_outputs = reference_model(
- hidden_states=hidden_states,
- attention_mask=mask,
- position_bias=position_bias,
+ hidden_states=reference_hidden_states,
+ attention_mask=reference_mask,
+ position_bias=reference_position_bias,
)
+
+ hidden_states = ops.to(reference_hidden_states, dtype=target_dtype)
+ mask = ops.to(reference_mask, dtype=target_dtype)
+ position_bias = ops.to(reference_position_bias, dtype=target_dtype)
actual_outputs = model(
hidden_states=DefaultPrimitiveTensor(data=hidden_states),
attention_mask=DefaultPrimitiveTensor(data=mask),
@@ -349,7 +645,14 @@ def testCompareSelfAttentionAgainstTransformersFp32(self):
actual_outputs = [
unbox_tensor(t) if t is not None else t for t in actual_outputs
]
- torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0)
+ actual_outputs = tree_map(
+ lambda t: None if t is None else ops.to(t, dtype=reference_dtype),
+ actual_outputs,
+ )
+
+ torch.testing.assert_close(
+ actual_outputs, expected_outputs, atol=atol, rtol=rtol
+ )
class T5LayerFFTest(TestCase):
@@ -358,8 +661,21 @@ def setUp(self):
torch.random.manual_seed(12345)
torch.no_grad()
- def testCompareAgainstTransformersFp32(self):
- dtype = torch.float32
+ @parameterized.expand(
+ [
+ [torch.float32, torch.float32],
+ [torch.bfloat16, torch.bfloat16],
+ [torch.float32, torch.bfloat16, 1e-2, 1.6e-2],
+ ]
+ )
+ def testCompareAgainstTransformers(
+ self,
+ reference_dtype: torch.dtype,
+ target_dtype: torch.dtype,
+ atol: Optional[float] = None,
+ rtol: Optional[float] = None,
+ ):
+ torch.set_default_dtype(reference_dtype)
batch_size = 19
batch_seq_len = 23
reference_config = ReferenceT5Config(
@@ -376,16 +692,20 @@ def testCompareAgainstTransformersFp32(self):
theta = Theta(
{
"ffn_gate.weight": DefaultPrimitiveTensor(
- data=reference_model.DenseReluDense.wi_0.weight
+ data=reference_model.DenseReluDense.wi_0.weight.to(
+ dtype=target_dtype
+ )
),
"ffn_up.weight": DefaultPrimitiveTensor(
- data=reference_model.DenseReluDense.wi_1.weight
+ data=reference_model.DenseReluDense.wi_1.weight.to(
+ dtype=target_dtype
+ )
),
"ffn_down.weight": DefaultPrimitiveTensor(
- data=reference_model.DenseReluDense.wo.weight
+ data=reference_model.DenseReluDense.wo.weight.to(dtype=target_dtype)
),
"ffn_norm.weight": DefaultPrimitiveTensor(
- data=reference_model.layer_norm.weight
+ data=reference_model.layer_norm.weight.to(dtype=target_dtype)
),
}
)
@@ -394,17 +714,24 @@ def testCompareAgainstTransformersFp32(self):
is_gated_act=reference_config.is_gated_act,
dense_act_fn=reference_config.dense_act_fn,
layer_norm_epsilon=reference_config.layer_norm_epsilon,
- activation_dtype=dtype,
+ activation_dtype=torch.float32,
)
- hidden_states = make_rand_torch(
- shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype
+ reference_hidden_states = make_rand_torch(
+ shape=[batch_size, batch_seq_len, reference_config.d_model],
+ dtype=reference_dtype,
)
-
expected_output = reference_model(
- hidden_states=hidden_states,
+ hidden_states=reference_hidden_states,
)
+
+ hidden_states = ops.to(reference_hidden_states, dtype=target_dtype)
actual_output = model(
hidden_states=DefaultPrimitiveTensor(data=hidden_states),
)
- torch.testing.assert_close(actual_output, expected_output, atol=1e-5, rtol=0)
+ actual_output = tree_map(
+ lambda t: None if t is None else ops.to(t, dtype=reference_dtype),
+ actual_output,
+ )
+
+ torch.testing.assert_close(actual_output, expected_output, atol=atol, rtol=rtol)
diff --git a/sharktank/tests/serving_poc/framework/device_session_test.py b/sharktank/tests/serving_poc/framework/device_session_test.py
deleted file mode 100644
index 5dfdd5f46..000000000
--- a/sharktank/tests/serving_poc/framework/device_session_test.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import pytest
-
-from sharktank.serving_poc.framework.session import (
- DeviceSession,
-)
-
-
-@pytest.fixture
-def local_device_session():
- session = DeviceSession(uri="local-task")
- yield session
- session.shutdown()
-
-
-def test_start_shutdown_no_host_contexts(local_device_session: DeviceSession):
- ms = local_device_session.create_module_set("default")
- ms.initialize()
-
-
-def test_host_context_start_stop(local_device_session: DeviceSession):
- ms = local_device_session.create_module_set("default")
- ms.initialize()
- hc = ms.host_context
-
-
-def test_host_context_scheduling(local_device_session: DeviceSession):
- device = local_device_session.device
- ms = local_device_session.create_module_set("default")
- ms.initialize()
- hc = ms.host_context
-
- sem = device.create_semaphore(0)
-
- async def task1():
- print("[coro1] test_host_context_scheduling.task")
- await hc.on_semaphore(sem, 1, True)
- print("[coro1] await completed")
- sem.signal(2)
-
- async def task2():
- print("[coro2] waiting for 2")
- await hc.on_semaphore(sem, 2, True)
- sem.fail("Fail from task2")
-
- f1 = hc.run_concurrent(task1())
- f2 = hc.run_concurrent(task2())
- sem.signal(1)
- print("[main] Waiting for semaphore")
-
- # Ensure task completion. Important to consume to ensure that exceptions
- # propagate.
- f1.result()
- f2.result()
-
- print("[main] Waiting on semaphore payload 3")
- with pytest.raises(Exception, match="Fail from task2"):
- sem.wait(3)
diff --git a/sharktank/tests/serving_poc/llm/api_server_test.py b/sharktank/tests/serving_poc/llm/api_server_test.py
deleted file mode 100644
index c2d2cc36a..000000000
--- a/sharktank/tests/serving_poc/llm/api_server_test.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import os
-from contextlib import closing
-from pathlib import Path
-import pytest
-import requests
-import socket
-import subprocess
-import sys
-import time
-
-
-def find_free_port():
- """This tries to find a free port to run a server on for the test.
-
- Race conditions are possible - the port can be acquired between when this
- runs and when the server starts.
-
- https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
- """
- with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
- s.bind(("localhost", 0))
- s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- return s.getsockname()[1]
-
-
-class ServerRunner:
- def __init__(self, args):
- port = str(find_free_port())
- self.url = "http://localhost:" + port
- env = os.environ.copy()
- env["PYTHONUNBUFFERED"] = "1"
- self.process = subprocess.Popen(
- [
- sys.executable,
- "-m",
- "sharktank.serving_poc.llm.api.rest_server",
- "--testing-mock-service",
- "--port=" + port,
- ]
- + args,
- env=env,
- # TODO: Have a more robust way of forking a subprocess.
- cwd=str(Path(__file__).resolve().parent.parent.parent),
- stdout=sys.stdout,
- stderr=sys.stderr,
- )
- self._wait_for_ready()
-
- def _wait_for_ready(self):
- start = time.time()
- while True:
- try:
- if requests.get(f"{self.url}/health").status_code == 200:
- return
- except Exception as e:
- if self.process.poll() is not None:
- raise RuntimeError("API server processs terminated") from e
- time.sleep(1.0)
- if (time.time() - start) > 30:
- raise RuntimeError("Timeout waiting for server start")
-
- def __del__(self):
- try:
- process = self.process
- except AttributeError:
- pass
- else:
- process.terminate()
- process.wait()
-
-
-@pytest.fixture(scope="session")
-def server():
- try:
- import fastapi
- import uvicorn
- except ModuleNotFoundError as e:
- pytest.skip(f"Skipping server test because deps are missing: {e}")
- runner = ServerRunner([])
- yield runner
-
-
-def test_health(server: ServerRunner):
- # Health check is part of getting the fixture.
- ...
-
-
-def test_generate_non_streaming(server: ServerRunner):
- resp = requests.post(
- f"{server.url}/generate",
- json={
- "prompt": "Hi Bob",
- },
- )
- resp.raise_for_status()
- d = resp.json()
- assert d["text"] == "Hi Bob", repr(d)
-
-
-def test_generate_streaming(server: ServerRunner):
- resp = requests.post(
- f"{server.url}/generate", json={"prompt": "Hi Bob!", "stream": True}
- )
- resp.raise_for_status()
- full_contents = resp.content
- expected_contents = b'{"text": "Hi Bob!"}\x00' * 5
- assert (
- full_contents == expected_contents
- ), f"Expected {expected_contents!r} vs {full_contents!r}"
diff --git a/sharktank/tests/serving_poc/llm/service_v1_test.py b/sharktank/tests/serving_poc/llm/service_v1_test.py
deleted file mode 100644
index c010e2034..000000000
--- a/sharktank/tests/serving_poc/llm/service_v1_test.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import pytest
-
-from iree.runtime import ( # type: ignore
- HalElementType,
-)
-
-from sharktank.serving_poc.framework.session import DeviceSession
-from sharktank.serving_poc.llm.config import (
- CacheParams,
- ModelParams,
- ServiceParams,
-)
-
-from sharktank.serving_poc.llm.service import (
- GenerateRequest,
- GenerateResponsePart,
-)
-
-from sharktank.serving_poc.llm.attn_block_cache import (
- create_attn_block_cache_module,
- AttnBlockCache,
-)
-
-from sharktank.serving_poc.llm.impl.service_v1 import (
- GenerateServiceV1,
-)
-
-from sharktank.serving_poc.llm.testing.fake_v1_module import (
- create_fake_module,
-)
-
-
-@pytest.fixture
-def cache_params(model_params: ModelParams) -> CacheParams:
- return CacheParams(model=model_params, device_block_count=128, block_pos_stride=16)
-
-
-@pytest.fixture
-def model_params() -> ModelParams:
- return ModelParams(
- module_name="AwesomeLLM",
- module_abi_version=1,
- attn_dtype=HalElementType.FLOAT_16,
- max_seq_len=128,
- transformer_block_count=32,
- attn_head_count=32,
- attn_head_dim=128,
- block_seq_stride=16,
- prefill_batch_sizes=[1, 4, 16],
- decode_batch_sizes=[1, 4, 16],
- )
-
-
-@pytest.fixture
-def uninitialized_session(model_params: ModelParams):
- from iree.runtime._binding import disable_leak_checker # type: ignore
-
- disable_leak_checker()
- session = DeviceSession(uri="local-task", queue_count=2)
- yield session
- session.shutdown()
- del session
-
-
-@pytest.fixture
-def attn_block_cache(
- uninitialized_session: DeviceSession, cache_params: CacheParams
-) -> AttnBlockCache:
- return AttnBlockCache(uninitialized_session, cache_params)
-
-
-@pytest.fixture
-def session(
- model_params: ModelParams,
- uninitialized_session: DeviceSession,
- attn_block_cache: AttnBlockCache,
-):
- session = uninitialized_session
- lms = session.create_module_set("AwesomeLLM", context_count=1)
- lms.add(
- create_attn_block_cache_module(attn_block_cache),
- create_fake_module(session.device, "AwesomeLLM", model_params=model_params),
- )
- lms.initialize()
- return session
-
-
-@pytest.fixture
-def service(
- session: DeviceSession,
- cache_params: CacheParams,
- model_params: ModelParams,
- attn_block_cache: AttnBlockCache,
-):
- params = ServiceParams(cache=cache_params, model=model_params)
- return GenerateServiceV1(session=session, params=params, cache=attn_block_cache)
-
-
-def test_single(service: GenerateServiceV1):
- state = service.start()
-
- async def task():
- await state.set_sequences(
- requests=[
- GenerateRequest(
- "1",
- "hello, tell me a story",
- [3, 4, 5, 12, 23, 88, 10, 2, 5, 9, 12, 13, 99, 56, 33, 124, 73],
- ),
- GenerateRequest("2", "goodbye", [9, 10]),
- ]
- )
- guarded_outputs = await state.prefill()
- prefill_ids = await guarded_outputs.resolve(state.host_context)
- print(
- "PREFILL IDS:",
- prefill_ids,
- ":\n",
- prefill_ids.map().asarray(
- prefill_ids.shape, HalElementType.map_to_dtype(prefill_ids.element_type)
- ),
- )
- await state.recycle()
-
- state.host_context.run_sync(task())
diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt
index f025eccfe..16baa1675 100644
--- a/shortfin/CMakeLists.txt
+++ b/shortfin/CMakeLists.txt
@@ -48,6 +48,7 @@ option(SHORTFIN_BUILD_TESTS "Builds C++ tests" ON)
option(SHORTFIN_BUNDLE_DEPS "Download dependencies instead of using system libraries" ON)
option(SHORTFIN_ENABLE_TRACING "Enable runtime tracing for iree and shortfin" OFF)
option(SHORTFIN_ENABLE_LTO "Enables LTO if supported" ON)
+option(SHORTFIN_ENABLE_TOKENIZERS "Enables integration of native tokenizers library" OFF)
set(SHORTFIN_IREE_SOURCE_DIR "" CACHE FILEPATH "Path to IREE source")
@@ -80,6 +81,7 @@ list(APPEND CMAKE_MODULE_PATH
${CMAKE_CURRENT_LIST_DIR}/build_tools/cmake/
)
include(shortfin_library)
+include(shortfin_testing)
include(CheckCXXCompilerFlag)
include(FetchContent)
@@ -90,7 +92,9 @@ include(FetchContent)
if(SHORTFIN_ENABLE_LTO)
include(CheckIPOSupported)
check_ipo_supported(RESULT SHORTFIN_LTO_SUPPORTED OUTPUT SHORTFIN_LTO_ERROR)
- if(SHORTFIN_LTO_SUPPORTED)
+ if(CMAKE_BUILD_TYPE STREQUAL "Debug")
+ message(STATUS "Not enabling LTO for debug build")
+ elseif(SHORTFIN_LTO_SUPPORTED)
message(STATUS "Shortfin LTO Enabled")
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
else()
@@ -126,7 +130,9 @@ endif()
message(STATUS " - Host")
################################################################################
-# Dependencies
+# Bundled Dependencies
+# These dependencies are either bundled or used via installed packages based
+# on the SHORTFIN_BUNDLE_DEPS option.
################################################################################
if(SHORTFIN_BUNDLE_DEPS)
@@ -164,15 +170,19 @@ if(SHORTFIN_BUNDLE_DEPS)
shortfin_push_bundled_lib_options()
# Enable spdlog shared library options so we can export it.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPDLOG_SHARED_LIB -Dspdlog_EXPORTS")
+ message(STATUS "Fetching bundled projects")
+ list(APPEND CMAKE_MESSAGE_INDENT " ")
FetchContent_MakeAvailable(fmt spdlog xtl xtensor)
shortfin_pop_bundled_lib_options()
+ list(POP_BACK CMAKE_MESSAGE_INDENT)
else()
find_package(spdlog)
find_package(xtensor)
endif()
################################################################################
-# IREE
+# IREE Dependency
+# This is always a source dependency on the IREE runtime.
################################################################################
# Set IREE build flags.
@@ -237,6 +247,65 @@ else()
endif()
shortfin_pop_bundled_lib_options()
+################################################################################
+# Tokenizer Library
+################################################################################
+
+function(shortfin_check_tokenizers)
+ # Make sure that rust/cargo is installed and usable.
+ # Consider switching this to a cached variable once the tokenizers_cpp project
+ # accepts an override instead of running whatever is on the path. For now, just
+ # verify that the path is sane, as that is what will get used.
+ find_program(SHORTFIN_CARGO_PATH NAMES cargo NO_CACHE)
+ if(NOT SHORTFIN_CARGO_PATH)
+ message(SEND_ERROR
+ "Building with -DSHORTFIN_ENABLE_TOKENIZERS=ON requires cargo (Rust's build tool). "
+ "Please follow Rust documentation to install. On Ubuntu, this can typically be accomplished with:\n"
+ " sudo apt install rustup && rustup default stable\n"
+ "See https://www.rust-lang.org/tools/install"
+ )
+ endif()
+
+ # Make sure cargo is functional.
+ execute_process(
+ COMMAND ${SHORTFIN_CARGO_PATH}
+ RESULT_VARIABLE _CARGO_RESULT
+ OUTPUT_VARIABLE _CARGO_OUT
+ ERROR_VARIABLE _CARGO_ERR
+ )
+ if(NOT "${_CARGO_RESULT}" STREQUAL "0")
+ message(SEND_ERROR
+ "Building with -DSHORTFIN_ENABLE_TOKENIZERS=ON requires cargo (Rust's build tool) "
+ "to be configured properly. It was found (${SHORTFIN_CARGO_PATH}) but returned an "
+ "error. Output below:\n"
+ "${_CARGO_OUT}\n"
+ "${_CARGO_ERR}"
+ )
+ endif()
+endfunction()
+
+if(SHORTFIN_ENABLE_TOKENIZERS)
+ # TODO: submit a patch to tokenizers_cpp to allow explicit configuration of the
+ # cargo location and pass that vs relying on environmental alignment.
+ shortfin_check_tokenizers()
+
+ shortfin_push_bundled_lib_options()
+ set(CMAKE_C_VISIBILITY_PRESET "hidden")
+ set(CMAKE_CXX_VISIBILITY_PRESET "hidden")
+ set(CMAKE_VISIBILITY_INLINES_HIDDEN ON)
+ set(MLC_ENABLE_SENTENCEPIECE_TOKENIZER OFF)
+
+ FetchContent_Declare(
+ tokenizers_cpp # From CMake project() declaration
+ GIT_REPOSITORY https://github.com/mlc-ai/tokenizers-cpp.git
+ GIT_TAG 4bb753377680e249345b54c6b10e6d0674c8af03 # 2024 Nov 15
+ EXCLUDE_FROM_ALL
+ )
+ message(STATUS "Fetching tokenizers_cpp")
+ FetchContent_MakeAvailable(tokenizers_cpp)
+ shortfin_pop_bundled_lib_options()
+endif()
+
################################################################################
# Tests
################################################################################
@@ -254,9 +323,9 @@ if(SHORTFIN_BUILD_TESTS)
endif()
include(GoogleTest)
enable_testing()
+ add_custom_target(shortfin_testdata_deps)
endif()
-
add_subdirectory(src)
if(SHORTFIN_BUILD_PYTHON_BINDINGS)
diff --git a/shortfin/build_tools/cmake/shortfin_library.cmake b/shortfin/build_tools/cmake/shortfin_library.cmake
index aaa97a6c1..103fdf1c5 100644
--- a/shortfin/build_tools/cmake/shortfin_library.cmake
+++ b/shortfin/build_tools/cmake/shortfin_library.cmake
@@ -182,7 +182,10 @@ function(shortfin_gtest_test)
GTest::gmock
GTest::gtest_main
)
- gtest_discover_tests(${_RULE_NAME})
+ gtest_discover_tests(
+ ${_RULE_NAME}
+ WORKING_DIRECTORY "${libshortfin_BINARY_DIR}"
+ )
endfunction()
diff --git a/shortfin/build_tools/cmake/shortfin_testing.cmake b/shortfin/build_tools/cmake/shortfin_testing.cmake
new file mode 100644
index 000000000..e462b7023
--- /dev/null
+++ b/shortfin/build_tools/cmake/shortfin_testing.cmake
@@ -0,0 +1,50 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Downloads a test data file as part of configure.
+# This does a download->rename in an attempt to be robust to partial downloads.
+# It should not be used to manage large test data files or anything sensitive
+# enough to require a hash check.
+# The output file is added as an additional clean file on the global
+# shortfin_testdata_deps target, meaning that "ninja clean" will remove it.
+# It is also added to the current directory's list of configure depends, which
+# means that if ninja is run and the file is not present, cmake will be re-invoked.
+function(shortfin_download_test_data)
+ cmake_parse_arguments(
+ _RULE
+ ""
+ "URL;OUTPUT_FILE"
+ ""
+ ${ARGN}
+ )
+ if(NOT SHORTFIN_BUILD_TESTS)
+ return()
+ endif()
+ if(NOT EXISTS "${_RULE_OUTPUT_FILE}")
+ set(_stage_file "${_RULE_OUTPUT_FILE}.stage")
+ message(STATUS "Downloading test data ${_RULE_URL} -> ${_RULE_OUTPUT_FILE}")
+ file(DOWNLOAD "${_RULE_URL}" "${_stage_file}" STATUS _status)
+ list(POP_FRONT _status _status_code)
+ if(_status_code EQUAL "0")
+ file(RENAME "${_stage_file}" "${_RULE_OUTPUT_FILE}")
+ else()
+ message(SEND_ERROR "Error downloading file ${_RULE_URL} -> ${_RULE_OUTPUT_FILE}")
+ endif()
+ endif()
+
+ # Make clean remove it.
+ set_property(
+ TARGET shortfin_testdata_deps
+ APPEND PROPERTY ADDITIONAL_CLEAN_FILES
+ "${_RULE_OUTPUT_FILE}"
+ )
+
+ # And make us reconfigure if it isn't there.
+ set_property(
+ DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
+ APPEND PROPERTY
+ CMAKE_CONFIGURE_DEPENDS "${_RULE_OUTPUT_FILE}")
+endfunction()
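+
+# Example usage (illustrative only; the URL and output path below are hypothetical):
+#
+#   shortfin_download_test_data(
+#     URL "https://example.com/testdata/tokenizer.json"
+#     OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/tokenizer.json"
+#   )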
diff --git a/shortfin/python/array_binding.cc b/shortfin/python/array_binding.cc
index a05232674..08a4071a8 100644
--- a/shortfin/python/array_binding.cc
+++ b/shortfin/python/array_binding.cc
@@ -531,22 +531,26 @@ void BindArray(py::module_ &m) {
->AddAsInvocationArgument(
inv, static_cast(barrier));
})
- .def_static("for_device",
- [](local::ScopedDevice &device, std::span shape,
- DType dtype) {
- return custom_new_keep_alive(
- py::type(),
- /*keep_alive=*/device.fiber(),
- device_array::for_device(device, shape, dtype));
- })
- .def_static("for_host",
- [](local::ScopedDevice &device, std::span shape,
- DType dtype) {
- return custom_new_keep_alive(
- py::type(),
- /*keep_alive=*/device.fiber(),
- device_array::for_host(device, shape, dtype));
- })
+ .def_static(
+ "for_device",
+ [](local::ScopedDevice &device, std::span<const size_t> shape,
+ DType dtype) {
+ return custom_new_keep_alive<device_array>(
+ py::type<device_array>(),
+ /*keep_alive=*/device.fiber(),
+ device_array::for_device(device, shape, dtype));
+ },
+ py::arg("device"), py::arg("shape"), py::arg("dtype"))
+ .def_static(
+ "for_host",
+ [](local::ScopedDevice &device, std::span<const size_t> shape,
+ DType dtype) {
+ return custom_new_keep_alive<device_array>(
+ py::type<device_array>(),
+ /*keep_alive=*/device.fiber(),
+ device_array::for_host(device, shape, dtype));
+ },
+ py::arg("device"), py::arg("shape"), py::arg("dtype"))
.def("for_transfer",
[](device_array &self) {
return custom_new_keep_alive(
diff --git a/shortfin/python/array_host_ops.cc b/shortfin/python/array_host_ops.cc
index 86385cfee..3e2a8ebe3 100644
--- a/shortfin/python/array_host_ops.cc
+++ b/shortfin/python/array_host_ops.cc
@@ -91,6 +91,18 @@ static const char DOCSTRING_RANDOM_GENERATOR[] =
fixed number.
)";
+static const char DOCSTRING_TRANSPOSE[] =
+ R"(Transposes axes of an array according to a permutation vector.
+
+Args:
+ input: Array to transpose.
+ permutation: New sequence of axes. Must have the same number of elements as
+ the rank of input.
+ out: If given, then the results are written to this array.
+ device_visible: Whether to make the result array visible to devices. Defaults
+ to False.
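+
+For example, transposing an input of shape [2, 3] with permutation [1, 0]
+produces an output of shape [3, 2].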
+)";
+
#define SF_UNARY_FUNCTION_CASE(dtype_name, cpp_type) \
case DType::dtype_name(): \
return compute.template operator()<cpp_type>()
@@ -100,6 +112,25 @@ static const char DOCSTRING_RANDOM_GENERATOR[] =
compute.template operator()<cpp_type>(); \
break
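+// Dispatches `compute` on a fixed-width unsigned element type chosen by the
+// dense byte size of the dtype; data movement ops treat elements as opaque
+// words, so e.g. float32 and int32 both use the 4-byte instantiation.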
+#define SF_MOVEMENT_OP_SWITCH(dtype) \
+ if (!dtype.is_byte_aligned()) \
+ throw std::invalid_argument( \
+ "data movement ops are only defined for byte aligned dtypes"); \
+ switch (dtype.dense_byte_count()) { \
+ case 1: \
+ return compute.template operator()<uint8_t>(); \
+ case 2: \
+ return compute.template operator()<uint16_t>(); \
+ case 4: \
+ return compute.template operator()<uint32_t>(); \
+ case 8: \
+ return compute.template operator()<uint64_t>(); \
+ default: \
+ throw std::invalid_argument( \
+ "data movement ops are only defined for dtypes of size 1, 2, " \
+ "4, 8"); \
+ }
+
struct PyRandomGenerator {
public:
using SeedType = xt::random::default_engine_type::result_type;
@@ -374,6 +405,227 @@ struct ConvertTruncFunctor {
}
};
+void OptionalArrayCast(py::handle handle,
+ std::optional<device_array> &maybe_array) {
+ if (py::isinstance<device_array>(handle)) {
+ maybe_array.emplace(py::cast<device_array>(handle));
+ }
+}
+
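+// Assigns a total order used for arithmetic type promotion: bool < integer <
+// float < complex, with ties within a category broken by bit width.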
+int DTypePromotionRank(DType dtype) {
+ int rank = 1;
+ if (dtype.is_boolean())
+ rank *= 1000;
+ else if (dtype.is_integer())
+ rank *= 2000;
+ else if (dtype.is_float())
+ rank *= 4000;
+ else if (dtype.is_complex())
+ rank *= 8000;
+ return rank + dtype.bit_count();
+}
+
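+// Computes the result dtype of a binary elementwise op. For example, the rules
+// below give (uint8, int8) -> int16, (int32, float32) -> float32, and
+// (float16, float64) -> float64.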
+DType PromoteArithmeticTypes(std::optional<DType> lhs_dtype,
+ std::optional<DType> rhs_dtype) {
+ if (!lhs_dtype && !rhs_dtype) {
+ throw std::invalid_argument(
+ "Elementwise operators require at least one argument to be a "
+ "device_array");
+ }
+
+ // One not an array: promote to the array type.
+ if (!lhs_dtype)
+ return *rhs_dtype;
+ else if (!rhs_dtype)
+ return *lhs_dtype;
+
+ int lhs_rank = DTypePromotionRank(*lhs_dtype);
+ int rhs_rank = DTypePromotionRank(*rhs_dtype);
+ DType promoted_dtype = lhs_rank < rhs_rank ? *rhs_dtype : *lhs_dtype;
+
+ // If mismatched signed/unsigned, then need to promote to the next signed
+ // dtype.
+ if (promoted_dtype.is_integer()) {
+ bool lhs_unsigned = iree_all_bits_set(
+ lhs_dtype->numerical_type(), IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED);
+ bool rhs_unsigned = iree_all_bits_set(
+ rhs_dtype->numerical_type(), IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED);
+ if ((lhs_unsigned || rhs_unsigned) && !(lhs_unsigned && rhs_unsigned)) {
+ // Signed/unsigned mismatch. Promote to next.
+ switch (promoted_dtype) {
+ case DType::uint8():
+ case DType::int8():
+ return DType::int16();
+ case DType::uint16():
+ case DType::int16():
+ return DType::int32();
+ case DType::uint32():
+ case DType::int32():
+ return DType::int64();
+ default:
+ // Jax's type promotion chart says this goes to a weak FP type, but
+ // we don't implement such a construct and I don't really see how
+ // that makes sense in a system setting like this, so we just saturate
+ // to 64bit.
+ return DType::int64();
+ }
+ }
+ }
+
+ return promoted_dtype;
+}
+
+// ---------------------------------------------------------------------------//
+// Elementwise support
+// ---------------------------------------------------------------------------//
+
+// Python element type scalar conversion functions.
+uint8_t ConvertPyToEltTy(py::handle py_value, uint8_t zero) {
+ return py::cast<uint8_t>(py_value);
+}
+
+int8_t ConvertPyToEltTy(py::handle py_value, int8_t zero) {
+ return py::cast<int8_t>(py_value);
+}
+
+uint16_t ConvertPyToEltTy(py::handle py_value, uint16_t zero) {
+ return py::cast<uint16_t>(py_value);
+}
+
+int16_t ConvertPyToEltTy(py::handle py_value, int16_t zero) {
+ return py::cast<int16_t>(py_value);
+}
+
+uint32_t ConvertPyToEltTy(py::handle py_value, uint32_t zero) {
+ return py::cast<uint32_t>(py_value);
+}
+
+int32_t ConvertPyToEltTy(py::handle py_value, int32_t zero) {
+ return py::cast<int32_t>(py_value);
+}
+
+uint64_t ConvertPyToEltTy(py::handle py_value, uint64_t zero) {
+ return py::cast<uint64_t>(py_value);
+}
+
+int64_t ConvertPyToEltTy(py::handle py_value, int64_t zero) {
+ return py::cast<int64_t>(py_value);
+}
+
+float ConvertPyToEltTy(py::handle py_value, float zero) {
+ return py::cast<float>(py_value);
+}
+
+double ConvertPyToEltTy(py::handle py_value, double zero) {
+ return py::cast<double>(py_value);
+}
+
+half_float::half ConvertPyToEltTy(py::handle py_value, half_float::half zero) {
+ // Python can't cast directly to half so first go to double.
+ return static_cast<half_float::half>(py::cast<double>(py_value));
+}
+
+struct AddFunctor {
+ template <typename Lhs, typename Rhs>
+ static auto Invoke(Lhs &&lhs, Rhs &&rhs) {
+ return lhs + rhs;
+ }
+};
+
+struct DivideFunctor {
+ template <typename Lhs, typename Rhs>
+ static auto Invoke(Lhs &&lhs, Rhs &&rhs) {
+ return lhs / rhs;
+ }
+};
+
+struct MultiplyFunctor {
+ template <typename Lhs, typename Rhs>
+ static auto Invoke(Lhs &&lhs, Rhs &&rhs) {
+ return lhs * rhs;
+ }
+};
+
+struct SubtractFunctor {
+ template <typename Lhs, typename Rhs>
+ static auto Invoke(Lhs &&lhs, Rhs &&rhs) {
+ return lhs - rhs;
+ }
+};
+
+template <typename ElementwiseFunctor>
+device_array ElementwiseOperation(py::handle lhs, py::handle rhs,
+ std::optional<device_array> out,
+ bool device_visible) {
+ std::optional<device_array> lhs_array;
+ OptionalArrayCast(lhs, lhs_array);
+ std::optional<device_array> rhs_array;
+ OptionalArrayCast(rhs, rhs_array);
+ auto dtype = PromoteArithmeticTypes(
+ lhs_array ? std::optional(lhs_array->dtype()) : std::nullopt,
+ rhs_array ? std::optional(rhs_array->dtype()) : std::nullopt);
+ if (lhs_array && lhs_array->dtype() != dtype) {
+ auto converted = GenericElementwiseConvert(
+ *lhs_array, dtype, /*out=*/std::nullopt,
+ /*device_visible=*/false);
+ lhs_array.reset();
+ lhs_array.emplace(std::move(converted));
+ }
+ if (rhs_array && rhs_array->dtype() != dtype) {
+ auto converted = GenericElementwiseConvert(
+ *rhs_array, dtype, /*out=*/std::nullopt,
+ /*device_visible=*/false);
+ rhs_array.reset();
+ rhs_array.emplace(std::move(converted));
+ }
+
+ auto compute = [&]<typename EltTy>() -> device_array {
+ auto handle_result = [&]<typename D, typename A>(
+ D &&device, A &&result) -> device_array {
+ if (!out) {
+ out.emplace(device_array::for_host(device, result.shape(), dtype,
+ device_visible));
+ }
+ auto out_t = out->map_xtensor_w<EltTy>();
+ *out_t = result;
+ return *out;
+ };
+ if (!rhs_array) {
+ auto lhs_t = lhs_array->map_xtensor<EltTy>();
+ xt::xarray<EltTy> rhs_scalar = ConvertPyToEltTy(rhs, EltTy());
+ return handle_result(lhs_array->device(),
+ ElementwiseFunctor::Invoke(*lhs_t, rhs_scalar));
+ } else if (!lhs_array) {
+ xt::xarray<EltTy> lhs_scalar = ConvertPyToEltTy(lhs, EltTy());
+ auto rhs_t = rhs_array->map_xtensor<EltTy>();
+ return handle_result(rhs_array->device(),
+ ElementwiseFunctor::Invoke(lhs_scalar, *rhs_t));
+ } else {
+ auto lhs_t = lhs_array->map_xtensor<EltTy>();
+ auto rhs_t = rhs_array->map_xtensor<EltTy>();
+ return handle_result(lhs_array->device(),
+ ElementwiseFunctor::Invoke(*lhs_t, *rhs_t));
+ }
+ };
+
+ switch (dtype) {
+ SF_UNARY_FUNCTION_CASE(float16, half_float::half);
+ SF_UNARY_FUNCTION_CASE(float32, float);
+ SF_UNARY_FUNCTION_CASE(float64, double);
+ SF_UNARY_FUNCTION_CASE(uint8, uint8_t);
+ SF_UNARY_FUNCTION_CASE(int8, int8_t);
+ SF_UNARY_FUNCTION_CASE(uint16, uint16_t);
+ SF_UNARY_FUNCTION_CASE(int16, int16_t);
+ SF_UNARY_FUNCTION_CASE(uint32, uint32_t);
+ SF_UNARY_FUNCTION_CASE(int32, int32_t);
+ SF_UNARY_FUNCTION_CASE(uint64, uint64_t);
+ SF_UNARY_FUNCTION_CASE(int64, int64_t);
+ default:
+ throw std::invalid_argument(fmt::format(
+ "Unsupported dtype({}) for in elementwise op", dtype.name()));
+ }
+}
+
} // namespace
void BindArrayHostOps(py::module_ &m) {
@@ -457,6 +709,39 @@ void BindArrayHostOps(py::module_ &m) {
SF_DEF_CONVERT("floor", GenericElementwiseConvert);
SF_DEF_CONVERT("round", GenericElementwiseConvert);
SF_DEF_CONVERT("trunc", GenericElementwiseConvert);
-}
+
+ // Transpose.
+ m.def(
+ "transpose",
+ [](device_array input, std::vector<size_t> permutation,
+ std::optional<device_array> out, bool device_visible) {
+ auto compute = [&]<typename EltTy>() -> device_array {
+ auto input_t = input.map_xtensor<EltTy>();
+ auto permuted_t =
+ xt::transpose(*input_t, permutation, xt::check_policy::full());
+ if (!out) {
+ out.emplace(device_array::for_host(input.device(),
+ permuted_t.shape(),
+ input.dtype(), device_visible));
+ }
+ auto out_t = out->map_xtensor_w<EltTy>();
+ *out_t = permuted_t;
+ return *out;
+ };
+ SF_MOVEMENT_OP_SWITCH(input.dtype());
+ },
+ py::arg("input"), py::arg("permutation"), py::arg("out") = py::none(),
+ py::arg("device_visible") = false, DOCSTRING_TRANSPOSE);
+
+// Elementwise.
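+// For example (Python-level, illustrative): `add(arr, 2)` converts the scalar
+// to the array's dtype, while `add(int8_arr, uint8_arr)` promotes both
+// operands to int16 per PromoteArithmeticTypes above.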
+#define SF_DEF_ELEMENTWISE(py_name, target) \
+ m.def(py_name, target, py::arg("lhs"), py::arg("rhs"), py::kw_only(), \
+ py::arg("out") = py::none(), py::arg("device_visible") = false)
+ SF_DEF_ELEMENTWISE("add", ElementwiseOperation);
+ SF_DEF_ELEMENTWISE("divide", ElementwiseOperation);
+ SF_DEF_ELEMENTWISE("multiply", ElementwiseOperation);
+ SF_DEF_ELEMENTWISE("subtract", ElementwiseOperation);
+
+}
} // namespace shortfin::python
diff --git a/shortfin/python/shortfin/array/__init__.py b/shortfin/python/shortfin/array/__init__.py
index 6079541c8..670102dfe 100644
--- a/shortfin/python/shortfin/array/__init__.py
+++ b/shortfin/python/shortfin/array/__init__.py
@@ -44,11 +44,16 @@
# Ops.
argmax = _sfl.array.argmax
+add = _sfl.array.add
ceil = _sfl.array.ceil
convert = _sfl.array.convert
+divide = _sfl.array.divide
fill_randn = _sfl.array.fill_randn
floor = _sfl.array.floor
+multiply = _sfl.array.multiply
round = _sfl.array.round
+subtract = _sfl.array.subtract
+transpose = _sfl.array.transpose
trunc = _sfl.array.trunc
RandomGenerator = _sfl.array.RandomGenerator
@@ -86,12 +91,17 @@
"storage",
"DType",
# Ops.
+ "add",
"argmax",
"ceil",
"convert",
+ "divide",
"fill_randn",
"floor",
+ "multiply",
"round",
+ "subtract",
+ "transpose",
"trunc",
"RandomGenerator",
]
diff --git a/shortfin/python/shortfin_apps/llm/_deps.py b/shortfin/python/shortfin_apps/llm/_deps.py
index 7123d011e..fb8ca8176 100644
--- a/shortfin/python/shortfin_apps/llm/_deps.py
+++ b/shortfin/python/shortfin_apps/llm/_deps.py
@@ -5,13 +5,23 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from shortfin.support.deps import ShortfinDepNotFoundError
+import sys
-try:
- import tokenizers
-except ModuleNotFoundError as e:
- raise ShortfinDepNotFoundError(__name__, "tokenizers") from e
+deps = [
+ "tokenizers",
+ "dataclasses_json",
+]
-try:
- import dataclasses_json
-except ModuleNotFoundError as e:
- raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e
+for dep in deps:
+ try:
+ __import__(dep)
+ except ModuleNotFoundError as e:
+ if "pytest" in sys.modules:
+ import pytest
+
+ pytest.skip(
+ f"A test imports shortfin_apps.llm; skipping due to unavailable Shortfin LLM dependency: {dep}",
+ allow_module_level=True,
+ )
+ else:
+ raise ShortfinDepNotFoundError(__name__, dep) from e
diff --git a/shortfin/python/shortfin_apps/llm/components/cache.py b/shortfin/python/shortfin_apps/llm/components/cache.py
deleted file mode 100644
index 12794498f..000000000
--- a/shortfin/python/shortfin_apps/llm/components/cache.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from typing import Sequence
-
-import logging
-import math
-import threading
-
-import shortfin as sf
-
-from .config_struct import ModelParams, human_size
-
-logger = logging.getLogger(__name__)
-
-
-class AttnPageEntry:
- __slots__ = [
- "cache",
- "index",
- "in_use",
- ]
-
- def __init__(self, cache: "AttnPageCache", index: int):
- self.cache = cache
- self.index = index
- self.in_use = False
-
- def __repr__(self):
- return f"Block({self.index}, {'FREE' if not self.in_use else 'BUSY'})"
-
-
-class AttnPageCache:
- """Page table based attention cache.
-
- While internal to a model, the cache is organized with additional structure
- per page, outside of the model, it is just a list of pages of a certain
- element type and number of elements (all inner dims are flattened).
-
- One page table is allocated per device in a fiber. Currently, this is a
- dense allocation with committed memory but in the future, we may just
- allocate the address space and lazily populate it with committed memory.
-
- The cache is unique because usage of it can span fibers and concurrency
- is implicitly managed at the block level (i.e. freshly acquired blocks
- are assumed to be uninitialized and available immediately for use).
-
- It is initialized with a discrete list of fiberd devices from a fiber but
- cache usage can be done from any fiber which includes those devices.
- """
-
- def __init__(
- self, *, devices: Sequence[sf.ScopedDevice], model_params: ModelParams
- ):
- self._lock = threading.Lock()
- self.devices = list(devices)
- self.model_params = model_params
- self.page_tables: list[sf.array.device_array] = []
- cache_params = model_params.paged_kv_cache
- alloc_page_count = cache_params.device_block_count
-
- # Setup accounting structs.
- self.attn_page_entries = [
- AttnPageEntry(self, i) for i in range(alloc_page_count)
- ]
- self.attn_page_free = list(self.attn_page_entries)
-
- # Initialize a page table on each device.
- assert cache_params is not None, "Model does not have a paged kv cache"
- page_table_shape = [
- alloc_page_count,
- model_params.paged_kv_block_size_elements,
- ]
- for device in devices:
- logging.info(
- "Allocating page table (shape=%r, dtype=%r, size=%s) on %r",
- page_table_shape,
- model_params.attn_dtype,
- human_size(
- math.prod(page_table_shape)
- * model_params.attn_dtype.dense_byte_count
- ),
- device,
- )
- page_table = sf.array.device_array.for_device(
- device, page_table_shape, model_params.attn_dtype
- )
- self.page_tables.append(page_table)
-
- def acquire_free_pages(self, count: int) -> list[AttnPageEntry] | None:
- with self._lock:
- available = len(self.attn_page_free)
- if count > available:
- return None
- return [self.attn_page_free.pop() for _ in range(count)]
-
- def release_pages(self, pages: list[AttnPageEntry]):
- with self._lock:
- self.attn_page_free.extend(pages)
-
- def __repr__(self):
- # No need to lock for repr (list is internally synchronized).
- free_pages = len(self.attn_page_free)
- total_pages = len(self.attn_page_entries)
- return (
- f"AttnPageCache({total_pages - free_pages}/{total_pages} pages in use: "
- f"{100.0 * free_pages / total_pages}% free)"
- )
diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py
new file mode 100644
index 000000000..0007000bc
--- /dev/null
+++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py
@@ -0,0 +1,80 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""
+Base class for kv caches.
+"""
+
+from typing import List
+from .page_pool import PageInfo
+import math
+
+
+class BasePagedAttentionCache:
+ """
+ Manages lifecycle of pages (using PageInfo as handles).
+
+
+ Page States:
+ Caching - Page can be read by multiple threads
+ - Also maintains a reference count
+ Writing - Page is being modified by a single owner thread
+
+ Transitions:
+ Caching -> Writing: When acquiring an unreferenced LRU leaf page for writing
+ Writing -> Caching: When writing is complete and page is released
+
+ Thread Safety:
+ - Multiple readers allowed in the Caching state
+ - Single writer exclusive access in Writing state
+ - Reference counting prevents eviction of in-use pages
+ """
+
+ def __init__(self, page_pool, tokens_per_page):
+ self.page_pool = page_pool
+ self.tokens_per_page = tokens_per_page
+
+ def acquire_pages_for_tokens(
+ self, tokens: List[int], extra_token_slots: int = 1
+ ) -> tuple[list[PageInfo], int]:
+ """
+ Given a list of tokens, return a list of pages and a start position to continue generation from.
+
+ Parameters:
+ - tokens: all the known tokens for this generation request
+ - extra_token_slots: number of kvcache slots needed in addition to the ones needed to hold the given tokens.
+
+ In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable.
+
+ The pages are returned in order.
+
+ No token at idx < n_cached_tokens should be written to. TODO: consider enforcing this.
+ """
+ token_count = len(tokens)
+ pages_needed = math.ceil(token_count / self.tokens_per_page)
+ pages = self.page_pool.acquire_free_pages(pages_needed)
+
+ n_cached_tokens = 0
+
+ return pages, n_cached_tokens
+
+ def publish_pages(self, tokens, pages) -> None:
+ """
+ Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests.
+
+ Associates the tokens with the pages, and marks them as done writing.
+
+ It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)].
+ """
+
+ pass # the base implementation doesn't cache unfinished requests.
+
+ def release_pages(self, tokens, pages):
+ """
+ Decrement the reference count for these pages. When the reference count reaches zero, they become eligible for eviction.
+ """
+ # In the base implementation, a page is owned by at most one request, so it can be released immediately.
+ self.page_pool.release_pages(pages)
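+
+
+# Typical lifecycle from a request's point of view (illustrative sketch; the
+# `cache` and `tokens` names below are placeholders, not part of this module):
+#
+#   pages, n_cached = cache.acquire_pages_for_tokens(tokens, extra_token_slots=0)
+#   # ... write KV for tokens[n_cached:] into `pages` ...
+#   cache.publish_pages(tokens, pages)  # make the pages visible to other requests
+#   cache.release_pages(tokens, pages)  # drop this request's reference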
diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py
new file mode 100644
index 000000000..1686370c0
--- /dev/null
+++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py
@@ -0,0 +1,159 @@
+from __future__ import annotations
+from typing import List, Tuple, Optional, Sequence
+import threading
+import logging
+import shortfin as sf
+import shortfin.array as sfnp
+from dataclasses import dataclass
+
+from ..config_struct import human_size
+import math
+
+import time
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class PageInfo:
+ """
+ Page index with some metadata about its contents.
+ """
+
+ index: int
+ pool: PagePool
+ token_offset: int # Offset within the page
+ token_count: int # Number of tokens stored in this page
+ writing: bool = False
+ read_ref_count: int = 0 # Number of threads that still need to read this page. When this reaches 0, page is eligible for release
+
+
+@dataclass
+class PagePoolConfig:
+ """
+ Hyperparameters for the page pool.
+ """
+
+ dtype: sf.dtype
+ alloc_page_count: int
+
+ paged_kv_block_size_elements: int # size of a single page as # of elements
+ # (e.g. one configuration for llama3.1 8b has 32x2x16x8x128=1048576 elements where:
+ # 32: number of transformer blocks
+ # 2: one for k + one for v
+ # 16: tokens per page
+ # 8: head count (32 heads, but every 4 heads share the same kv buffer)
+ # 128: hidden dimension
+
+
+class PagePool:
+ """Page table based attention cache.
+
+ While internal to a model, the cache is organized with additional structure
+ per page, outside of the model, it is just a list of pages of a certain
+ element type and number of elements (all inner dims are flattened).
+
+ One page table is allocated per device in a fiber. Currently, this is a
+ dense allocation with committed memory but in the future, we may just
+ allocate the address space and lazily populate it with committed memory.
+
+ The cache is unique because usage of it can span fibers and concurrency
+ is implicitly managed at the block level (i.e. freshly acquired blocks
+ are assumed to be uninitialized and available immediately for use).
+
+ It is initialized with a discrete list of fibered devices from a fiber but
+ cache usage can be done from any fiber which includes those devices.
+
+ In addition to supporting paged attention standalone, this also serves
+ as the array / buffer allocation layer for radix attention described in
+ `radix_tree.py`.
+ """
+
+ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig):
+ self._lock = threading.Lock()
+ self.devices = list(devices)
+ self.config = config
+ self.page_tables: list[sf.array.device_array] = []
+
+ # Setup accounting structs.
+ self.attn_page_entries = [
+ PageInfo(
+ index=i,
+ pool=self,
+ token_offset=0,
+ token_count=0,
+ )
+ for i in range(self.config.alloc_page_count)
+ ]
+
+ self.attn_page_free = list(self.attn_page_entries)
+
+ # Initialize a page table on each device.
+ page_table_shape = [
+ self.config.alloc_page_count,
+ self.config.paged_kv_block_size_elements,
+ ]
+ for device in devices:
+ logging.info(
+ "Allocating page table (shape=%r, dtype=%r, size=%s) on %r",
+ page_table_shape,
+ self.config.dtype,
+ human_size(config.dtype.compute_dense_nd_size(page_table_shape)),
+ device,
+ )
+ page_table = sf.array.device_array.for_device(
+ device, page_table_shape, self.config.dtype
+ )
+ self.page_tables.append(page_table)
+
+ def acquire_free_pages(self, count: int) -> list[PageInfo] | None:
+ with self._lock:
+ available = len(self.attn_page_free)
+ if count > available:
+ return None
+ return [self.attn_page_free.pop() for _ in range(count)]
+
+ def release_pages(self, pages: list[PageInfo]):
+ with self._lock:
+ self.attn_page_free.extend(pages)
+
+ def copy_page(self, src_page: PageInfo) -> PageInfo:
+ """
+ Copy a page's contents to a new page.
+
+ Args:
+ src_page: Source page to copy from
+
+ Returns:
+ New PageInfo containing the copied data
+ """
+ # Allocate new page
+ (dst_page,) = self.acquire_free_pages(1)
+
+ # Copy the data on each device
+ for page_table in self.page_tables:
+ # View of source and destination pages
+ src_view = page_table.view(src_page.index)
+ dst_view = page_table.view(dst_page.index)
+ # Copy the data
+ dst_view.copy_from(src_view)
+
+ # Setup destination page metadata
+ dst_page.token_offset = 0 # Always start at beginning of new page
+
+ return dst_page
+
+ def __repr__(self):
+ # No need to lock for repr (list is internally synchronized).
+ free_pages = len(self.attn_page_free)
+ total_pages = len(self.attn_page_entries)
+ return (
+ f"PagePool({total_pages - free_pages}/{total_pages} pages in use: "
+ f"{100.0 * free_pages / total_pages}% free)"
+ )
+
+
+############################## begin radix attention
diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py
index fdcbeefc1..c3e6fe34b 100644
--- a/shortfin/python/shortfin_apps/llm/components/messages.py
+++ b/shortfin/python/shortfin_apps/llm/components/messages.py
@@ -9,7 +9,8 @@
import shortfin as sf
import shortfin.array as sfnp
-from .cache import AttnPageCache, AttnPageEntry
+from .kvcache.base_attention_cache import BasePagedAttentionCache
+from .kvcache.page_pool import PageInfo
class InferencePhase(Enum):
@@ -41,8 +42,8 @@ def __init__(self, phase: InferencePhase, input_token_ids: list[int]):
self.result_logits: sfnp.device_array | None = None
# Cache pages that have been locked for this request.
- self._cache: AttnPageCache | None = None
- self.locked_pages: list[AttnPageEntry] | None = None
+ self._cache: BasePagedAttentionCache | None = None
+ self.locked_pages: list[PageInfo] | None = None
def reset(self, phase: InferencePhase):
"""Resets all per request state in preparation for an subsequent execution."""
@@ -66,16 +67,18 @@ def free_cache_pages(self):
pages = self.locked_pages
self._cache = None
self.locked_pages = None
- cache.release_pages(pages)
+ cache.release_pages(self.input_token_ids, pages)
def lock_initial_cache_pages(
- self, cache: AttnPageCache, pages: list[AttnPageEntry]
+ self, cache: BasePagedAttentionCache, pages: list[PageInfo]
):
assert not self._cache
self._cache = cache
self.locked_pages = pages
- def lock_new_cache_pages(self, cache: AttnPageCache, pages: list[AttnPageEntry]):
+ def lock_new_cache_pages(
+ self, cache: BasePagedAttentionCache, pages: list[PageInfo]
+ ):
assert self._cache is cache
self.locked_pages.extend(pages)
diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py
index bcd08b756..8d3cc1424 100644
--- a/shortfin/python/shortfin_apps/llm/components/service.py
+++ b/shortfin/python/shortfin_apps/llm/components/service.py
@@ -11,7 +11,8 @@
import shortfin as sf
import shortfin.array as sfnp
-from .cache import AttnPageCache
+from .kvcache.base_attention_cache import BasePagedAttentionCache
+from .kvcache.page_pool import PagePoolConfig, PagePool
from .config_struct import ModelParams
from .manager import SystemManager
from .messages import InferenceExecRequest, InferencePhase, StrobeMessage
@@ -54,8 +55,17 @@ def __init__(
# Scope dependent objects.
self.batcher = BatcherProcess(self)
- self.page_cache = AttnPageCache(
- devices=self.main_fiber.devices_dict.values(), model_params=model_params
+ page_pool_config = PagePoolConfig(
+ dtype=model_params.attn_dtype,
+ alloc_page_count=model_params.paged_kv_cache.device_block_count,
+ paged_kv_block_size_elements=model_params.paged_kv_block_size_elements,
+ )
+ page_pool = PagePool(
+ devices=self.main_fiber.devices_dict.values(), config=page_pool_config
+ )
+ self.page_cache = BasePagedAttentionCache(
+ page_pool=page_pool,
+ tokens_per_page=model_params.paged_kv_cache.block_seq_stride,
)
self.program_isolation = PROG_ISOLATIONS[program_isolation]
@@ -200,7 +210,7 @@ def board_flights(self):
self.pending_prefills.clear()
logger.debug("Post boarding cache state: %r", cache)
- def board_prefills(self, cache: AttnPageCache):
+ def board_prefills(self, cache: BasePagedAttentionCache):
# Fill prefill flights.
pending_prefills = self.pending_prefills
if len(pending_prefills) == 0:
@@ -209,7 +219,7 @@ def board_prefills(self, cache: AttnPageCache):
self.service,
InferencePhase.PREFILL,
self.page_seq_stride,
- cache.page_tables,
+ cache.page_pool.page_tables,
)
for prefill_request in pending_prefills:
assert prefill_request.phase == InferencePhase.PREFILL
@@ -218,7 +228,11 @@ def board_prefills(self, cache: AttnPageCache):
needed_pages = math.ceil(
len(prefill_request.input_token_ids) / self.page_seq_stride
)
- pages = cache.acquire_free_pages(needed_pages)
+ # allocate kv cache pages
+ pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(
+ prefill_request.input_token_ids,
+ extra_token_slots=0, # prefill needs no extra kvcache slots to write to
+ )
if pages is None:
logger.debug("Cannot fulfill request for %d pages", needed_pages)
continue
@@ -236,13 +250,16 @@ def board_prefills(self, cache: AttnPageCache):
# And takeoff.
exec_process.launch()
- def board_decodes(self, cache: AttnPageCache):
+ def board_decodes(self, cache: BasePagedAttentionCache):
# Fill decode flights.
pending_decodes = self.pending_decodes
if len(pending_decodes) == 0:
return
exec_process = InferenceExecutorProcess(
- self.service, InferencePhase.DECODE, self.page_seq_stride, cache.page_tables
+ self.service,
+ InferencePhase.DECODE,
+ self.page_seq_stride,
+ cache.page_pool.page_tables,
)
for decode_request in pending_decodes:
assert decode_request.phase == InferencePhase.DECODE
@@ -254,7 +271,11 @@ def board_decodes(self, cache: AttnPageCache):
/ self.page_seq_stride
)
if needed_pages > len(decode_request.locked_pages):
- pages = cache.acquire_free_pages(needed_pages)
+ # allocate kv cache pages
+ pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(
+ decode_request.input_token_ids,
+ extra_token_slots=1, # need 1 extra slot to write result.
+ )
if pages is None:
logger.debug(
"Cannot fulfill decode request for %d pages", needed_pages
diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py
index 2ab7a1b96..1561803dd 100644
--- a/shortfin/python/shortfin_apps/llm/server.py
+++ b/shortfin/python/shortfin_apps/llm/server.py
@@ -33,6 +33,33 @@
logger = logging.getLogger(__name__)
+UVICORN_LOG_CONFIG = {
+ "version": 1,
+ "disable_existing_loggers": False,
+ "formatters": {
+ "default": {
+ "()": "uvicorn.logging.DefaultFormatter",
+ "format": "[{asctime}] {message}",
+ "datefmt": "%Y-%m-%d %H:%M:%S",
+ "style": "{",
+ "use_colors": True,
+ },
+ },
+ "handlers": {
+ "console": {
+ "class": "logging.StreamHandler",
+ "formatter": "default",
+ },
+ },
+ "loggers": {
+ "uvicorn": {
+ "handlers": ["console"],
+ "level": "INFO",
+ "propagate": False,
+ },
+ },
+}
+
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -211,11 +238,5 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
main(
sys.argv[1:],
# Make logging defer to the default shortfin logging config.
- log_config={
- "version": 1,
- "disable_existing_loggers": False,
- "formatters": {},
- "handlers": {},
- "loggers": {},
- },
+ log_config=UVICORN_LOG_CONFIG,
)
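Hoisting the dictConfig into UVICORN_LOG_CONFIG also makes it reusable outside main(); a small sketch of passing it straight to uvicorn. The FastAPI app here is a stand-in, and the import path assumes the module layout in this patch.

import uvicorn
from fastapi import FastAPI
from shortfin_apps.llm.server import UVICORN_LOG_CONFIG

app = FastAPI()  # stand-in application for illustration

if __name__ == "__main__":
    # Every uvicorn log line is prefixed with "[YYYY-MM-DD HH:MM:SS]".
    uvicorn.run(app, host="0.0.0.0", port=8000, log_config=UVICORN_LOG_CONFIG)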
diff --git a/shortfin/requirements-tests-nogil.txt b/shortfin/requirements-tests-nogil.txt
index 1049b0412..1769467ab 100644
--- a/shortfin/requirements-tests-nogil.txt
+++ b/shortfin/requirements-tests-nogil.txt
@@ -1,4 +1,3 @@
pytest
requests
-fastapi
uvicorn
diff --git a/shortfin/setup.py b/shortfin/setup.py
index cf3762950..e15b38d89 100644
--- a/shortfin/setup.py
+++ b/shortfin/setup.py
@@ -225,6 +225,7 @@ def build_cmake_configuration(CMAKE_BUILD_DIR: Path, extra_cmake_args=[]):
add_env_cmake_setting(cmake_args, "SHORTFIN_ENABLE_LTO", default_value="ON")
add_env_cmake_setting(cmake_args, "SHORTFIN_IREE_SOURCE_DIR")
add_env_cmake_setting(cmake_args, "SHORTFIN_ENABLE_ASAN")
+ add_env_cmake_setting(cmake_args, "SHORTFIN_ENABLE_TOKENIZERS", default_value="OFF")
# Only do a from-scratch configure if not already configured.
cmake_cache_file = os.path.join(CMAKE_BUILD_DIR, "CMakeCache.txt")
diff --git a/shortfin/src/shortfin/CMakeLists.txt b/shortfin/src/shortfin/CMakeLists.txt
index 058e0e336..73df08e7c 100644
--- a/shortfin/src/shortfin/CMakeLists.txt
+++ b/shortfin/src/shortfin/CMakeLists.txt
@@ -5,5 +5,6 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
add_subdirectory(array)
+add_subdirectory(components/tokenizers)
add_subdirectory(local)
add_subdirectory(support)
diff --git a/shortfin/src/shortfin/array/dtype.h b/shortfin/src/shortfin/array/dtype.h
index d746d69bf..de1763698 100644
--- a/shortfin/src/shortfin/array/dtype.h
+++ b/shortfin/src/shortfin/array/dtype.h
@@ -49,6 +49,9 @@ class SHORTFIN_API DType {
bool is_integer_bitwidth(size_t bitwidth) const {
return iree_hal_element_type_is_integer(et_, bitwidth);
}
+ uint32_t numerical_type() const {
+ return iree_hal_element_numerical_type(et_);
+ }
// Computes the size in bytes required to store densely packed nd-dims.
// This presently only supports byte aligned dtypes. In the future, when
diff --git a/shortfin/src/shortfin/components/tokenizers/CMakeLists.txt b/shortfin/src/shortfin/components/tokenizers/CMakeLists.txt
new file mode 100644
index 000000000..6b9f794b1
--- /dev/null
+++ b/shortfin/src/shortfin/components/tokenizers/CMakeLists.txt
@@ -0,0 +1,41 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+if(NOT SHORTFIN_ENABLE_TOKENIZERS)
+ return()
+endif()
+
+shortfin_cc_component(
+ NAME
+ shortfin_tokenizers
+ HDRS
+ tokenizers.h
+ SRCS
+ tokenizers.cc
+ DEFINES
+ SHORTFIN_HAVE_TOKENIZERS
+ COMPONENTS
+ shortfin_support
+ DEPS
+ tokenizers_cpp
+)
+set_property(GLOBAL APPEND
+ PROPERTY SHORTFIN_LIB_OPTIONAL_COMPONENTS
+ shortfin_tokenizers)
+target_compile_definitions(shortfin_public_defs INTERFACE SHORTFIN_HAVE_TOKENIZERS)
+
+# Download test data.
+shortfin_download_test_data(
+ URL "https://huggingface.co/google-bert/bert-base-cased/resolve/cd5ef92a9fb2f889e972770a36d4ed042daf221e/tokenizer.json"
+ OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/tokenizer.json"
+)
+
+# Note that tests run from the binary dir of the project.
+shortfin_gtest_test(
+ NAME shortfin_tokenizers_test
+ SRCS
+ tokenizers_test.cc
+)
diff --git a/shortfin/src/shortfin/components/tokenizers/tokenizers.cc b/shortfin/src/shortfin/components/tokenizers/tokenizers.cc
new file mode 100644
index 000000000..118bc0c1b
--- /dev/null
+++ b/shortfin/src/shortfin/components/tokenizers/tokenizers.cc
@@ -0,0 +1,63 @@
+// Copyright 2024 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "shortfin/components/tokenizers/tokenizers.h"
+
+#include <stdexcept>
+
+#include "shortfin/support/logging.h"
+#include "tokenizers_cpp.h"
+
+namespace shortfin::tokenizers {
+
+namespace {
+
+class AccessibleTokenizer : public Tokenizer {
+ public:
+ using Tokenizer::vendor_tokenizer_;
+};
+
+::tokenizers::Tokenizer *Get(Tokenizer *self) {
+ void *ptr = static_cast<AccessibleTokenizer *>(self)->vendor_tokenizer_;
+ if (!ptr) {
+ throw std::logic_error("Tokenizer is null");
+ }
+ return static_cast<::tokenizers::Tokenizer *>(ptr);
+}
+
+} // namespace
+
+Tokenizer::~Tokenizer() { delete Get(this); }
+
+Tokenizer Tokenizer::FromBlobJSON(const std::string &json_blob) {
+ SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::FromBlobJSON");
+ return Tokenizer(::tokenizers::Tokenizer::FromBlobJSON(json_blob).release());
+}
+
+std::vector<int32_t> Tokenizer::Encode(const std::string &text) {
+ SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::Encode");
+ return Get(this)->Encode(text);
+}
+
+std::vector<std::vector<int32_t>> Tokenizer::EncodeBatch(
+ const std::vector<std::string> &texts) {
+ SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::EncodeBatch");
+ return Get(this)->EncodeBatch(texts);
+}
+
+std::string Tokenizer::Decode(const std::vector<int32_t> &ids) {
+ SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::Decode");
+ return Get(this)->Decode(ids);
+}
+size_t Tokenizer::GetVocabSize() { return Get(this)->GetVocabSize(); }
+std::string Tokenizer::IdToToken(int32_t token_id) {
+ return Get(this)->IdToToken(token_id);
+}
+int32_t Tokenizer::TokenToId(const std::string &token) {
+ return Get(this)->TokenToId(token);
+}
+
+} // namespace shortfin::tokenizers
diff --git a/shortfin/src/shortfin/components/tokenizers/tokenizers.h b/shortfin/src/shortfin/components/tokenizers/tokenizers.h
new file mode 100644
index 000000000..d263eace6
--- /dev/null
+++ b/shortfin/src/shortfin/components/tokenizers/tokenizers.h
@@ -0,0 +1,52 @@
+// Copyright 2024 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H
+#define SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H
+
+#include <string>
+#include <vector>
+
+#include "shortfin/support/api.h"
+
+namespace shortfin::tokenizers {
+
+// A vendored Tokenizer class that does not export the details of the backing
+// implementation. While a little bit gross, this keeps us from needing to
+// re-export a vendor'ed API as part of our public API.
+// The current vendor tokenizer is based on mlc-ai/tokenizers-cpp. The API
+// is fairly close to that implementation.
+// See: https://github.com/mlc-ai/tokenizers-cpp
+class SHORTFIN_API Tokenizer {
+ public:
+ Tokenizer(const Tokenizer &) = delete;
+ Tokenizer &operator=(const Tokenizer &) = delete;
+ Tokenizer(Tokenizer &&other) : vendor_tokenizer_(other.vendor_tokenizer_) {
+ other.vendor_tokenizer_ = nullptr;
+ }
+ ~Tokenizer();
+
+ // Factory functions.
+ static Tokenizer FromBlobJSON(const std::string &json_blob);
+
+ std::vector<int32_t> Encode(const std::string &text);
+ std::vector<std::vector<int32_t>> EncodeBatch(
+ const std::vector<std::string> &texts);
+ std::string Decode(const std::vector<int32_t> &ids);
+ size_t GetVocabSize();
+ std::string IdToToken(int32_t token_id);
+ int32_t TokenToId(const std::string &token);
+
+ private:
+ Tokenizer(void *vendor_tokenizer) : vendor_tokenizer_(vendor_tokenizer) {}
+
+ protected:
+ void *vendor_tokenizer_;
+};
+
+} // namespace shortfin::tokenizers
+
+#endif // SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H
diff --git a/shortfin/src/shortfin/components/tokenizers/tokenizers_test.cc b/shortfin/src/shortfin/components/tokenizers/tokenizers_test.cc
new file mode 100644
index 000000000..674721653
--- /dev/null
+++ b/shortfin/src/shortfin/components/tokenizers/tokenizers_test.cc
@@ -0,0 +1,56 @@
+// Copyright 2024 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "shortfin/components/tokenizers/tokenizers.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <filesystem>
+#include <fstream>
+
+using namespace shortfin::tokenizers;
+
+namespace {
+
+std::string ReadFile(std::filesystem::path path) {
+ std::ifstream in(path);
+ std::ostringstream out;
+ out << in.rdbuf();
+ return out.str();
+}
+
+} // namespace
+
+// TODO: Enable once upstream changes with error handling have landed.
+// Currently aborts.
+// See: https://github.com/mlc-ai/tokenizers-cpp/issues/50
+// TEST(TokenizersTest, FromIllegalBlobJson) {
+// auto tok = Tokenizer::FromBlobJSON("foobar");
+// }
+
+TEST(TokenizersTest, BasicTokenizerJson) {
+ std::filesystem::path tokenizer_path(
+ "src/shortfin/components/tokenizers/tokenizer.json");
+ auto tokenizer_json = ReadFile(tokenizer_path);
+ ASSERT_GT(tokenizer_json.size(), 0)
+ << "reading " << tokenizer_path
+ << " (cwd: " << std::filesystem::current_path() << ")";
+ auto tok = Tokenizer::FromBlobJSON(tokenizer_json);
+ EXPECT_GT(tok.GetVocabSize(), 100); // Sanity check
+ auto encoded = tok.Encode("hello world");
+ EXPECT_THAT(encoded,
+ ::testing::ContainerEq(std::vector<int32_t>{19082, 1362}));
+ auto batch_encoded = tok.EncodeBatch({"hello", "world"});
+ ASSERT_EQ(batch_encoded.size(), 2);
+ EXPECT_THAT(batch_encoded[0],
+ ::testing::ContainerEq(std::vector<int32_t>{19082}));
+ EXPECT_THAT(batch_encoded[1],
+ ::testing::ContainerEq(std::vector<int32_t>{1362}));
+ EXPECT_EQ(tok.TokenToId("hello"), 19082);
+ auto decoded = tok.Decode(encoded);
+ EXPECT_EQ(decoded, "hello world");
+}
diff --git a/shortfin/tests/api/array_ops_test.py b/shortfin/tests/api/array_ops_test.py
index 7c792d92b..164dfb479 100644
--- a/shortfin/tests/api/array_ops_test.py
+++ b/shortfin/tests/api/array_ops_test.py
@@ -268,3 +268,171 @@ def test_nearest_int_conversion(device, dtype, out_dtype, sfnp_func, ref_round_f
assert output.dtype == out_dtype
for ref, actual in zip(ref_rounded, output.items):
assert ref == int(actual)
+
+
+def test_elementwise_forms(device):
+ # All elementwise ops use the same template expansion which enforces
+ # certain common invariants. Here we test these on the multiply op,
+ # relying on a parametric test for actual behavior.
+ with pytest.raises(
+ ValueError,
+ match="Elementwise operators require at least one argument to be a device_array",
+ ):
+ sfnp.multiply(2, 2)
+
+ ary = sfnp.device_array.for_host(device, [2, 3], dtype=sfnp.float32)
+ with ary.map(discard=True) as m:
+ m.fill(42.0)
+
+ # Rhs scalar int accepted.
+ result = sfnp.multiply(ary, 2)
+ assert list(result.items) == [84.0] * 6
+
+ # Rhs scalar float accepted.
+ result = sfnp.multiply(ary, 2.0)
+ assert list(result.items) == [84.0] * 6
+
+ # Lhs scalar int accepted.
+ result = sfnp.multiply(2, ary)
+ assert list(result.items) == [84.0] * 6
+
+ # Lhs scalar float accepted.
+ result = sfnp.multiply(2.0, ary)
+ assert list(result.items) == [84.0] * 6
+
+ # Out.
+ out = sfnp.device_array.for_host(device, [2, 3], dtype=sfnp.float32)
+ sfnp.multiply(2.0, ary, out=out)
+ assert list(out.items) == [84.0] * 6
+
+
+@pytest.mark.parametrize(
+ "lhs_dtype,rhs_dtype,promoted_dtype",
+ [
+ (sfnp.float32, sfnp.float16, sfnp.float32),
+ (sfnp.float16, sfnp.float32, sfnp.float32),
+ (sfnp.float32, sfnp.float64, sfnp.float64),
+ (sfnp.float64, sfnp.float32, sfnp.float64),
+ # Integer promotion.
+ (sfnp.uint8, sfnp.uint16, sfnp.uint16),
+ (sfnp.uint16, sfnp.uint32, sfnp.uint32),
+ (sfnp.uint32, sfnp.uint64, sfnp.uint64),
+ (sfnp.int8, sfnp.int16, sfnp.int16),
+ (sfnp.int16, sfnp.int32, sfnp.int32),
+ (sfnp.int32, sfnp.int64, sfnp.int64),
+ # Signed/unsigned promotion.
+ (sfnp.int8, sfnp.uint8, sfnp.int16),
+ (sfnp.int16, sfnp.uint16, sfnp.int32),
+ (sfnp.int32, sfnp.uint32, sfnp.int64),
+ (sfnp.int8, sfnp.uint32, sfnp.int64),
+ ],
+)
+def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype):
+ # Tests that promotion infers an appropriate result type.
+ lhs = sfnp.device_array.for_host(device, [2, 3], lhs_dtype)
+ rhs = sfnp.device_array.for_host(device, [2, 3], rhs_dtype)
+ result = sfnp.multiply(lhs, rhs)
+ assert result.dtype == promoted_dtype
+
+
+@pytest.mark.parametrize(
+ "dtype,op,check_value",
+ [
+ # Add.
+ (sfnp.int8, sfnp.add, 44.0),
+ (sfnp.int16, sfnp.add, 44.0),
+ (sfnp.int32, sfnp.add, 44.0),
+ (sfnp.int64, sfnp.add, 44.0),
+ (sfnp.uint8, sfnp.add, 44.0),
+ (sfnp.uint16, sfnp.add, 44.0),
+ (sfnp.uint32, sfnp.add, 44.0),
+ (sfnp.uint64, sfnp.add, 44.0),
+ (sfnp.float16, sfnp.add, 44.0),
+ (sfnp.float32, sfnp.add, 44.0),
+ (sfnp.float64, sfnp.add, 44.0),
+ # Divide.
+ (sfnp.int8, sfnp.divide, 21.0),
+ (sfnp.int16, sfnp.divide, 21.0),
+ (sfnp.int32, sfnp.divide, 21.0),
+ (sfnp.int64, sfnp.divide, 21.0),
+ (sfnp.uint8, sfnp.divide, 21.0),
+ (sfnp.uint16, sfnp.divide, 21.0),
+ (sfnp.uint32, sfnp.divide, 21.0),
+ (sfnp.uint64, sfnp.divide, 21.0),
+ (sfnp.float16, sfnp.divide, 21.0),
+ (sfnp.float32, sfnp.divide, 21.0),
+ (sfnp.float64, sfnp.divide, 21.0),
+ # Multiply.
+ (sfnp.int8, sfnp.multiply, 84.0),
+ (sfnp.int16, sfnp.multiply, 84.0),
+ (sfnp.int32, sfnp.multiply, 84.0),
+ (sfnp.int64, sfnp.multiply, 84.0),
+ (sfnp.uint8, sfnp.multiply, 84.0),
+ (sfnp.uint16, sfnp.multiply, 84.0),
+ (sfnp.uint32, sfnp.multiply, 84.0),
+ (sfnp.uint64, sfnp.multiply, 84.0),
+ (sfnp.float16, sfnp.multiply, 84.0),
+ (sfnp.float32, sfnp.multiply, 84.0),
+ (sfnp.float64, sfnp.multiply, 84.0),
+ # Subtract.
+ (sfnp.int8, sfnp.subtract, 40.0),
+ (sfnp.int16, sfnp.subtract, 40.0),
+ (sfnp.int32, sfnp.subtract, 40.0),
+ (sfnp.int64, sfnp.subtract, 40.0),
+ (sfnp.uint8, sfnp.subtract, 40.0),
+ (sfnp.uint16, sfnp.subtract, 40.0),
+ (sfnp.uint32, sfnp.subtract, 40.0),
+ (sfnp.uint64, sfnp.subtract, 40.0),
+ (sfnp.float16, sfnp.subtract, 40.0),
+ (sfnp.float32, sfnp.subtract, 40.0),
+ (sfnp.float64, sfnp.subtract, 40.0),
+ ],
+)
+def test_elementwise_array_correctness(device, dtype, op, check_value):
+ lhs = sfnp.device_array.for_host(device, [2, 2], sfnp.int32)
+ with lhs.map(discard=True) as m:
+ m.fill(42)
+
+ rhs = sfnp.device_array.for_host(device, [2], sfnp.int32)
+ with rhs.map(discard=True) as m:
+ m.fill(2)
+
+ lhs = sfnp.convert(lhs, dtype=dtype)
+ rhs = sfnp.convert(rhs, dtype=dtype)
+ result = op(lhs, rhs)
+ assert result.shape == [2, 2]
+ result = sfnp.convert(result, dtype=sfnp.float32)
+ items = list(result.items)
+ assert items == [check_value] * 4
+
+
+@pytest.mark.parametrize(
+ "dtype",
+ [
+ sfnp.int8,
+ sfnp.int16,
+ sfnp.int32,
+ sfnp.int64,
+ sfnp.uint8,
+ sfnp.uint16,
+ sfnp.uint32,
+ sfnp.uint64,
+ sfnp.float16,
+ sfnp.float32,
+ sfnp.float64,
+ ],
+)
+def test_transpose(device, dtype):
+ input = sfnp.device_array.for_host(device, [3, 2], sfnp.int32)
+ input.items = [0, 1, 2, 3, 4, 5]
+ input = sfnp.convert(input, dtype=dtype)
+ permuted = sfnp.transpose(input, [1, 0])
+ assert permuted.shape == [2, 3]
+ items = list(sfnp.convert(permuted, dtype=sfnp.int32).items)
+ assert items == [0, 2, 4, 1, 3, 5]
+
+ out = sfnp.device_array.for_host(device, [2, 3], dtype)
+ sfnp.transpose(input, [1, 0], out=out)
+ items = list(sfnp.convert(out, dtype=sfnp.int32).items)
+ assert items == [0, 2, 4, 1, 3, 5]
diff --git a/shortfin/tests/api/array_use_case_test.py b/shortfin/tests/api/array_use_case_test.py
new file mode 100644
index 000000000..d4a030d45
--- /dev/null
+++ b/shortfin/tests/api/array_use_case_test.py
@@ -0,0 +1,64 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import array
+import math
+import pytest
+
+import shortfin as sf
+import shortfin.array as sfnp
+
+
+@pytest.fixture
+def lsys():
+ # TODO: Port this test to use memory type independent access. It currently
+ # presumes unified memory.
+ # sc = sf.SystemBuilder()
+ sc = sf.host.CPUSystemBuilder()
+ lsys = sc.create_system()
+ yield lsys
+ lsys.shutdown()
+
+
+@pytest.fixture
+def fiber(lsys):
+ return lsys.create_fiber()
+
+
+@pytest.fixture
+def device(fiber):
+ return fiber.device(0)
+
+
+# Tests a typical image conversion from a model oriented layout to an array
+# of contained images.
+def test_image_to_bytes(device):
+ bs = 2
+ height = 16
+ width = 12
+ images_shape = [bs, 3, height, width]
+ images_planar = sfnp.device_array.for_host(device, images_shape, sfnp.float32)
+ # Band the data so that each channel increases by 0.1 across images.
+ for i in range(bs):
+ for j in range(3):
+ data = [i * 0.3 + j * 0.1 for _ in range(height * width)]
+ images_planar.view(i, j).items = data
+ images_planar = sfnp.convert(images_planar, dtype=sfnp.float16)
+
+ # Extract and convert each image to interleaved RGB bytes.
+ images = []
+ for idx in range(images_planar.shape[0]):
+ image_planar = images_planar.view(idx)
+ assert image_planar.shape == [1, 3, 16, 12]
+ image_interleaved = sfnp.transpose(image_planar, (0, 2, 3, 1))
+ assert image_interleaved.shape == [1, 16, 12, 3]
+ image_scaled = sfnp.multiply(image_interleaved, 255)
+ image = sfnp.round(image_scaled, dtype=sfnp.uint8)
+ image_bytes = bytes(image.map(read=True))
+ images.append(image_bytes)
+
+ assert images[0] == b"\x00\x1a3" * 192
+ assert images[1] == b"Mf\x80" * 192
diff --git a/shortfin/tests/apps/llm/components/cache_test.py b/shortfin/tests/apps/llm/components/cache_test.py
deleted file mode 100644
index 169d082b1..000000000
--- a/shortfin/tests/apps/llm/components/cache_test.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-"""
-Tests for llm kvcache component.
-"""
-
-import pytest
-import time
-import tempfile
-import shortfin as sf
-from _shortfin import lib as sfl
-from shortfin_apps.llm.components import cache
-from shortfin_apps.llm.components import config_struct
-import json
-from pathlib import Path
-
-
-@pytest.fixture
-def lsys():
- sc = sfl.local.host.CPUSystemBuilder()
- ls = sc.create_system()
- yield ls
- ls.shutdown()
-
-
-@pytest.fixture
-def fiber(lsys):
- # TODO: Should adopt the main thread.
- worker = lsys.create_worker("main")
- return lsys.create_fiber(worker)
-
-
-@pytest.fixture
-def device(fiber):
- return fiber.device(0)
-
-
-@pytest.fixture
-def model_params():
- model_params = {
- "module_name": "module",
- "module_abi_version": 1,
- "max_seq_len": 2048,
- "attn_head_count": 32,
- "attn_head_dim": 100,
- "prefill_batch_sizes": [4],
- "decode_batch_sizes": [4],
- "transformer_block_count": 26,
- "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
- }
-
- # Create a temporary file to store the JSON
- with tempfile.NamedTemporaryFile(
- mode="w", suffix=".json", delete=False
- ) as tmp_file:
- json.dump(model_params, tmp_file, indent=4)
- tmp_path = Path(tmp_file.name)
-
- try:
- # Load the JSON using config_struct
- model_params = config_struct.ModelParams.load_json(tmp_path)
- yield model_params
- finally:
- tmp_path.unlink
-
-
-@pytest.fixture
-def cache_fixture(fiber, model_params) -> cache.AttnPageCache:
- # Create and return the cache object
- return cache.AttnPageCache(
- devices=fiber.devices_dict.values(), model_params=model_params
- )
-
-
-@pytest.mark.parametrize("n_allocated", [1, 16, 255])
-def test_alloc(
- cache_fixture: cache.AttnPageCache,
- n_allocated,
- model_params: config_struct.ModelParams,
-):
- alloc_page_count = cache_fixture.page_tables[0].shape[0]
-
- assert alloc_page_count == model_params.paged_kv_cache.device_block_count
-
- pages = cache_fixture.acquire_free_pages(n_allocated)
- last_page = alloc_page_count - 1
- expected_indices = range(last_page, last_page - n_allocated, -1)
- for p, expected_ix in zip(pages, expected_indices):
- assert p.index == expected_ix
- assert p.index > 0
diff --git a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py
new file mode 100644
index 000000000..a1ec00c07
--- /dev/null
+++ b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py
@@ -0,0 +1,57 @@
+import pytest
+import logging
+from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PagePoolConfig
+import shortfin as sf
+import shortfin.host
+import shortfin.array as sfnp
+import shortfin.amdgpu
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture
+def setup_pool(generic_device):
+ pool = PagePool(
+ devices=[generic_device],
+ config=PagePoolConfig(
+ alloc_page_count=256,
+ dtype=sfnp.float16,
+ paged_kv_block_size_elements=393216,
+ ),
+ )
+ return pool
+
+
+def test_page_acquisition(setup_pool):
+ pool = setup_pool
+ logger.info("=== Running page acquisition test ===")
+ page0 = pool.acquire_free_pages(1)
+ assert page0 is not None, "Failed to acquire a free page"
+ logger.info("Successfully acquired page")
+
+
+def test_page_copy(setup_pool):
+ pool = setup_pool
+ logger.info("=== Running page copy test ===")
+ (page0,) = pool.acquire_free_pages(1)
+ page1 = pool.copy_page(page0)
+ assert page1 is not None, "Failed to copy a page"
+ assert page0 != page1, "Copied page should differ from the original"
+ logger.info("Successfully copied page")
+
+
+@pytest.fixture(autouse=True)
+def setup_logging():
+ """Set up logging format to include timestamp and level"""
+ logging.basicConfig(
+ format="%(asctime)s [%(levelname)s] %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO,
+ force=True,
+ )
+
+
+# Add more tests as needed
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/shortfin/tests/conftest.py b/shortfin/tests/conftest.py
index 083698968..b16d5a3c9 100644
--- a/shortfin/tests/conftest.py
+++ b/shortfin/tests/conftest.py
@@ -50,6 +50,17 @@ def pytest_runtest_setup(item):
sf.SystemBuilder.default_system_type = system_type
+# Dynamic Parameterization for lsys Fixture
+def pytest_generate_tests(metafunc):
+ if "generic_lsys" in metafunc.fixturenames:
+ system = metafunc.config.getoption("--system")
+ if system == "amdgpu":
+ params = ["cpu", "amdgpu"]
+ else:
+ params = [system]
+ metafunc.parametrize("generic_lsys", params, indirect=True)
+
+
# Keys that will be cleaned project wide prior to and after each test run.
# Test code can freely modify these.
CLEAN_ENV_KEYS = [
@@ -96,6 +107,28 @@ def kill():
kill()
+@pytest.fixture(scope="session")
+def generic_lsys(request):
+ system_type = request.param
+ if system_type == "cpu" or system_type == "hostcpu":
+ sc = sf.host.CPUSystemBuilder()
+ elif system_type == "amdgpu":
+ sc = sf.amdgpu.SystemBuilder()
+ lsys = sc.create_system()
+ yield lsys
+ lsys.shutdown()
+
+
+@pytest.fixture
+def generic_fiber(generic_lsys):
+ return generic_lsys.create_fiber()
+
+
+@pytest.fixture
+def generic_device(generic_fiber):
+ return generic_fiber.device(0)
+
+
@pytest.fixture
def cpu_lsys():
sc = sf.host.CPUSystemBuilder()
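A test written against the new fixtures needs no system-specific setup; pytest_generate_tests decides whether it runs on CPU only or on both CPU and amdgpu. A minimal sketch, with array calls following the patterns used in the array tests above:

import shortfin.array as sfnp

def test_fill_roundtrip(generic_device):
    ary = sfnp.device_array.for_host(generic_device, [2, 2], sfnp.float32)
    with ary.map(discard=True) as m:
        m.fill(1.0)
    assert list(ary.items) == [1.0] * 4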
diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index 5786a9fff..f09e08888 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -30,6 +30,8 @@
from iree.compiler import ir # type: ignore
+from iree.compiler.dialects import iree_codegen # type: ignore
+
from .common import *
from .dispatch_constraints import *
from .dispatch_parser import *
@@ -50,7 +52,7 @@ def apply_configuration(
expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]")
expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
- repl0 = f"<intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
+ repl0 = f"<intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]'
repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
@@ -117,7 +119,6 @@ def get_transform_function_mmt(
wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
-
return f"""
transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
%mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
@@ -130,7 +131,7 @@ def get_transform_function_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
- intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+ intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -203,7 +204,7 @@ def get_transform_function_conv(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
- intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+ intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -264,7 +265,7 @@ def get_transform_function_broadcast_rhs_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
- intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+ intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -344,7 +345,7 @@ def get_transform_function_batch_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
- intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+ intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -412,7 +413,7 @@ def get_transform_function_batch_matmul(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
- intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+ intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -535,13 +536,19 @@ def tune(
walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry)
+ variant_op_list = iree_codegen.get_executable_variant_ops(mlir_module)
+ assert len(variant_op_list) == 1, "Expect one executable variant op"
+ variant_op = variant_op_list[0]
+ # Get the MMA intrinsic instructions supported by the target.
+ mma_list = iree_codegen.query_mma_intrinsics(variant_op)
+
dispatch_tuner = walk_result.dispatch_tuner
assert dispatch_tuner, "No suitable dispatch tuner found"
problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
tune_logger.debug(str(problem_size))
configs = []
for i, config in enumerate(
- generate_solutions(tune_logger, problem_size, num_subgroups)
+ generate_solutions(tune_logger, problem_size, num_subgroups, mma_list)
):
if i >= limit:
break
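The new target-aware flow in tune() reduces to: parse the module, take the single executable variant op, ask the target which MMA intrinsics it supports, and feed that list into solution generation. A condensed sketch using only the calls that appear in this hunk; the input file name is a hypothetical placeholder.

from iree.compiler import ir
from iree.compiler.dialects import iree_codegen

mlir_text = open("dispatch.mlir").read()  # hypothetical dispatch under tuning

with ir.Context():
    mlir_module = ir.Module.parse(mlir_text)
    variant_ops = iree_codegen.get_executable_variant_ops(mlir_module)
    assert len(variant_ops) == 1, "Expect one executable variant op"
    mma_list = iree_codegen.query_mma_intrinsics(variant_ops[0])
    # mma_list bounds the search space handed to generate_solutions(...).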
diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
index 36fb87cbb..d81278e8c 100644
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ -13,6 +13,7 @@
from typing import Generator
from iree.compiler import ir # type: ignore
+from iree.compiler.dialects import iree_gpu # type: ignore
from . import candidate_gen
from . import common
@@ -45,10 +46,12 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:
M, N, K = 2048, 1280, 1280
+ mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=16,
workgroup_size=[16, 16, 1],
- intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ intrinsic=mma_attr,
tile_sizes=[8, 8, 8],
subgroup_m_count=16,
subgroup_n_count=16,
@@ -97,10 +100,12 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640
+ mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
- intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ intrinsic=mma_attr,
tile_sizes=[464, 320, 16],
subgroup_m_count=1,
subgroup_n_count=4,
@@ -161,10 +166,12 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.contraction,
)
+ mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
- intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+ intrinsic=mma_attr,
tile_sizes=[480, 384, 32],
subgroup_m_count=1,
subgroup_n_count=4,
@@ -208,10 +215,12 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_matmul,
)
+ mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
- intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+ intrinsic=mma_attr,
tile_sizes=[416, 320, 128],
subgroup_m_count=2,
subgroup_n_count=2,
@@ -258,10 +267,12 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_mmt,
)
+ mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
- intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
@@ -306,10 +317,12 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_mmt,
)
+ mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
- intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
+ intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
@@ -377,10 +390,12 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.broadcast_rhs_mmt,
)
+ mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
- intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
+ intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py
index a34f172eb..45ae48c22 100644
--- a/tuner/tuner/common.py
+++ b/tuner/tuner/common.py
@@ -12,6 +12,8 @@
from iree.compiler import ir # type: ignore
+from iree.compiler.dialects import iree_gpu # type: ignore
+
class CommonTypes:
def __init__(self, ctx: ir.Context):
@@ -83,65 +85,24 @@ def MNK(self) -> tuple[int, int, int]:
return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K)
-@dataclass
-class MfmaIntrinsic:
- output_type: ir.IntegerType | ir.FloatType
- m: int
- n: int
- k: int
- input_type: ir.IntegerType | ir.FloatType
-
- def __str__(self) -> str:
- input = str(self.input_type).upper()
- output = str(self.output_type).upper()
- return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}"
-
- @staticmethod
- def mfma_f32_16x16x16_f16():
- f16 = ir.F16Type.get()
- f32 = ir.F32Type.get()
- return MfmaIntrinsic(f32, 16, 16, 16, f16)
-
- @staticmethod
- def mfma_f32_32x32x8_f16():
- f16 = ir.F16Type.get()
- f32 = ir.F32Type.get()
- return MfmaIntrinsic(f32, 32, 32, 8, f16)
-
- @staticmethod
- def mfma_i32_16x16x32_i8():
- i32 = ir.IntegerType.get_signless(32)
- i8 = ir.IntegerType.get_signless(8)
- return MfmaIntrinsic(i32, 16, 16, 32, i8)
-
- @staticmethod
- def mfma_i32_32x32x16_i8():
- i32 = ir.IntegerType.get_signless(32)
- i8 = ir.IntegerType.get_signless(8)
- return MfmaIntrinsic(i32, 32, 32, 16, i8)
-
- @staticmethod
- def all():
- return [
- MfmaIntrinsic.mfma_f32_16x16x16_f16(),
- MfmaIntrinsic.mfma_f32_32x32x8_f16(),
- MfmaIntrinsic.mfma_i32_16x16x32_i8(),
- MfmaIntrinsic.mfma_i32_32x32x16_i8(),
- ]
-
-
-def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]:
- def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
- if problem_size.res_type.element_type != intrinsic.output_type:
+def get_compatible_mfma_intrinsics(
+ problem_size: ProblemSize,
+ mma_intrinsics: list[iree_gpu.MMAIntrinsic],
+) -> list[iree_gpu.MMAIntrinsic]:
+ def is_compatible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool:
+ mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma
+ a_type, b_type, c_type = mma_attr.abc_element_types
+ if problem_size.res_type.element_type != c_type:
return False
if problem_size.dispatch_kind != DispatchKind.batch_matmul:
- if problem_size.lhs_type.element_type != intrinsic.input_type:
- return False
- if problem_size.rhs_type.element_type != intrinsic.input_type:
+ if (
+ problem_size.lhs_type.element_type != a_type
+ or problem_size.rhs_type.element_type != b_type
+ ):
return False
return True
- return list(filter(is_compatible, MfmaIntrinsic.all()))
+ return list(filter(is_compatible, mma_intrinsics))
class ReorderWorkgroupsStrategy(Enum):
@@ -186,7 +147,7 @@ def __str__(self) -> str:
class Configuration:
subgroup_size: int
workgroup_size: list[int]
- intrinsic: MfmaIntrinsic
+ intrinsic: iree_gpu.MMAAttr
tile_sizes: list[int]
subgroup_m_count: int
subgroup_n_count: int
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py
index 891d703e2..ea0a4573d 100644
--- a/tuner/tuner/common_test.py
+++ b/tuner/tuner/common_test.py
@@ -5,7 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
-Usage: python -m pytest candidate_gen_test.py
+Usage: python -m pytest common_test.py
"""
import pytest
@@ -14,6 +14,7 @@
from typing import Generator
from iree.compiler import ir # type: ignore
+from iree.compiler.dialects import iree_gpu # type: ignore
@pytest.fixture
@@ -71,10 +72,12 @@ def test_gpu_pipeline_options() -> None:
def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
+ mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=32,
workgroup_size=[16, 16, 1],
- intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ intrinsic=mma_attr,
tile_sizes=[4, 8, 16],
subgroup_m_count=1,
subgroup_n_count=1,
@@ -96,11 +99,6 @@ def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
)
-def test_mfma_intrinsic_to_str(mlir_ctx: ir.Context) -> None:
- assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16"
- assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8"
-
-
def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
assert common.get_compatible_mfma_intrinsics(
common.ProblemSize(
@@ -109,10 +107,14 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([1280, 1280], tuner_ctx.type.f16),
common.ShapedType([2048, 1280], tuner_ctx.type.f32),
common.DispatchKind.mmt,
- )
+ ),
+ [
+ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+ iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+ ],
) == [
- common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
- common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+ iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
]
assert common.get_compatible_mfma_intrinsics(
@@ -122,10 +124,14 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([1280, 1280], tuner_ctx.type.i8),
common.ShapedType([2048, 1280], tuner_ctx.type.i32),
common.DispatchKind.mmt,
- )
+ ),
+ [
+ iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+ iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
+ ],
) == [
- common.MfmaIntrinsic.mfma_i32_16x16x32_i8(),
- common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
+ iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+ iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
]
assert common.get_compatible_mfma_intrinsics(
@@ -135,8 +141,44 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
common.DispatchKind.batch_matmul,
- )
+ ),
+ [
+ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+ iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+ ],
) == [
- common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
- common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+ iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
]
+
+ assert common.get_compatible_mfma_intrinsics(
+ common.ProblemSize(
+ common.MatmulSize(968, 320, 640, 64),
+ common.ShapedType([64, 968, 640], tuner_ctx.type.f32),
+ common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
+ common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
+ common.DispatchKind.batch_matmul,
+ ),
+ [
+ iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+ ],
+ ) == [
+ iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+ ]
+
+ assert (
+ common.get_compatible_mfma_intrinsics(
+ common.ProblemSize(
+ common.MatmulSize(968, 320, 640, 64),
+ common.ShapedType([64, 968, 640], tuner_ctx.type.f32),
+ common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
+ common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
+ common.DispatchKind.batch_matmul,
+ ),
+ [
+ iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+ iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
+ ],
+ )
+ == []
+ )
diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py
index edd7ccc38..f16b4a241 100644
--- a/tuner/tuner/dispatch_constraints.py
+++ b/tuner/tuner/dispatch_constraints.py
@@ -10,6 +10,9 @@
import z3 # type: ignore
from typing import Iterator
+
+from iree.compiler.dialects import iree_gpu # type: ignore
+
from .common import *
@@ -18,13 +21,22 @@ def get_mfma_intrinsic_constraints(
intrinsic_m: z3.ArithRef,
intrinsic_n: z3.ArithRef,
intrinsic_k: z3.ArithRef,
+ mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> z3.BoolRef:
- compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size)
+ compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics)
assert len(compatible_intrinsics) > 0, "No compatible intrinsics found"
+
+ mma_attrs = [iree_gpu.MMAAttr.get(mfma) for mfma in compatible_intrinsics]
+ mnk_shapes = [mma_attr.mnk_shape for mma_attr in mma_attrs]
+
return z3.Or(
*(
- z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k)
- for mfma in compatible_intrinsics
+ z3.And(
+ intrinsic_m == m,
+ intrinsic_n == n,
+ intrinsic_k == k,
+ )
+ for m, n, k in mnk_shapes
)
)
@@ -68,6 +80,7 @@ def generate_constraints(
subgroup_m_count,
subgroup_n_count,
waves_per_eu,
+ mma_intrinsics: list[iree_gpu.MMAIntrinsic],
):
M, N, K = (
problem_size.matmul_size.M,
@@ -82,7 +95,7 @@ def generate_constraints(
constraints += [subgroup_size == 64, wg_threads <= 1024]
constraints += [
get_mfma_intrinsic_constraints(
- problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k
+ problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k, mma_intrinsics
)
]
subgroup_k_count = 1
@@ -129,8 +142,40 @@ def generate_constraints(
return constraints
+def getMMAAttr(
+ output_type: ir.IntegerType | ir.FloatType,
+ m: int,
+ n: int,
+ k: int,
+ lhs_type: ir.IntegerType | ir.FloatType,
+ rhs_type: ir.IntegerType | ir.FloatType,
+) -> iree_gpu.MMAAttr:
+ for mma_intrinsic in iree_gpu.MMAIntrinsic:
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+ a_type, b_type, c_type = mma_attr.abc_element_types
+ mnk = mma_attr.mnk_shape
+ if (
+ a_type == lhs_type
+ and b_type == rhs_type
+ and c_type == output_type
+ and m == mnk[0]
+ and n == mnk[1]
+ and k == mnk[2]
+ ):
+ return mma_attr
+ # If no matching intrinsic is found, raise an exception
+ raise ValueError(
+ f"No matching MMA intrinsic found for "
+ f"output_type={output_type}, lhs_type={lhs_type}, rhs_type={rhs_type}, "
+ f"m={m}, n={n}, k={k}."
+ )
+
+
def generate_solutions(
- logger: logging.Logger, problem_size: ProblemSize, num_subgrups: int
+ logger: logging.Logger,
+ problem_size: ProblemSize,
+ num_subgrups: int,
+ mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> Iterator[Configuration]:
M, N, K = problem_size.MNK
logger.info(f"{M},{N},{K}")
@@ -168,6 +213,7 @@ def generate_solutions(
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
+ mma_intrinsics,
)
solver.add(z3.simplify(z3.And(constraints)))
logger.debug(f"Initial constraints: {solver}")
@@ -179,12 +225,13 @@ def generate_solutions(
config = Configuration(
lookup(subgroup_size),
[lookup(wg_x), lookup(wg_y), lookup(wg_z)],
- MfmaIntrinsic(
+ getMMAAttr(
problem_size.res_type.element_type,
lookup(intrinsic_mn),
lookup(intrinsic_mn),
lookup(intrinsic_k),
problem_size.lhs_type.element_type,
+ problem_size.rhs_type.element_type,
),
[lookup(m), lookup(n), lookup(k)],
lookup(sg_m_cnt),
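getMMAAttr above brute-forces the lookup over every iree_gpu.MMAIntrinsic and raises if nothing matches. A short sketch of the intended call pattern; the import path and the use of a bare ir.Context follow the tuner tests and are assumptions, not requirements.

from iree.compiler import ir
from iree.compiler.dialects import iree_gpu
from tuner.dispatch_constraints import getMMAAttr  # assumed import path

with ir.Context():
    f16 = ir.F16Type.get()
    f32 = ir.F32Type.get()
    # f32 accumulator, 16x16x16 shape, f16 operands -> MFMA_F32_16x16x16_F16.
    mma_attr = getMMAAttr(f32, 16, 16, 16, f16, f16)
    print(mma_attr.mnk_shape, mma_attr.abc_element_types)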
diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py
index 7e1a5c55d..9de4beeee 100644
--- a/tuner/tuner/dispatch_constraints_test.py
+++ b/tuner/tuner/dispatch_constraints_test.py
@@ -14,6 +14,7 @@
from typing import Generator
from iree.compiler import ir # type: ignore
+from iree.compiler.dialects import iree_gpu # type: ignore
from . import common
from . import dispatch_constraints
@@ -37,7 +38,18 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
problem_size = common.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
)
- configs = dispatch_constraints.generate_solutions(tuner_ctx.logger, problem_size, 4)
+ configs = dispatch_constraints.generate_solutions(
+ tuner_ctx.logger,
+ problem_size,
+ 4,
+ [
+ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+ iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+ iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+ iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
+ ],
+ )
+
assert configs is not None
@@ -115,6 +127,12 @@ def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> Non
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
+ [
+ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+ iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+ iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+ iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
+ ],
)
solver = z3.Solver()
@@ -160,6 +178,12 @@ def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> N
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
+ [
+ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+ iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+ iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+ iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
+ ],
)
constraints.append(m > 1000) # Adding an additional unsatisfiable constraint
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
index d3a99806f..fb10b04bc 100644
--- a/tuner/tuner/dispatch_parser_test.py
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -5,7 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
-Usage: python -m pytest candidate_gen_test.py
+Usage: python -m pytest dispatch_parser_test.py
"""
import pytest
@@ -14,6 +14,7 @@
from iree.compiler import ir # type: ignore
from iree.compiler.dialects import func # type: ignore
+from iree.compiler.dialects import iree_gpu # type: ignore
from . import common
from . import dispatch_parser
@@ -39,10 +40,12 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:
def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
+ mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = dispatch_parser.Configuration(
subgroup_size=0,
workgroup_size=[],
- intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ intrinsic=mma_attr,
tile_sizes=[128, 320, 32],
subgroup_m_count=0,
subgroup_n_count=0,
@@ -53,10 +56,12 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
+ mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = dispatch_parser.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
- intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ intrinsic=mma_attr,
tile_sizes=[464, 320, 16],
subgroup_m_count=1,
subgroup_n_count=4,
@@ -75,10 +80,12 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None:
+ mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+ mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = dispatch_parser.Configuration(
subgroup_size=32,
workgroup_size=[16, 16, 1],
- intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ intrinsic=mma_attr,
tile_sizes=[4, 8, 16],
subgroup_m_count=1,
subgroup_n_count=1,