diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci-libshortfin.yml similarity index 54% rename from .github/workflows/ci_linux_x64-libshortfin.yml rename to .github/workflows/ci-libshortfin.yml index afeca11a6..33a6df72b 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci-libshortfin.yml @@ -10,13 +10,13 @@ on: workflow_dispatch: pull_request: paths: - - '.github/workflows/ci_linux_x64-libshortfin.yml' + - '.github/workflows/ci-libshortfin.yml' - 'shortfin/**' push: branches: - main paths: - - '.github/workflows/ci_linux_x64-libshortfin.yml' + - '.github/workflows/ci-libshortfin.yml' - 'shortfin/**' permissions: @@ -36,17 +36,55 @@ env: jobs: build-and-test: - name: Build and test - runs-on: ubuntu-24.04 + name: "Unit tests :: ${{ matrix.name }} :: ${{ matrix.python-version }}" + runs-on: ${{ matrix.runs-on }} + defaults: + run: + shell: bash strategy: + fail-fast: false matrix: + name: ["Ubuntu (Clang)(full)", "Ubuntu (Clang)(host-only)", "Ubuntu (GCC)", "Windows (MSVC)"] python-version: ["3.10", "3.11", "3.12"] + include: + - name: Ubuntu (Clang)(full) + runs-on: ubuntu-24.04 + cmake-options: + -DCMAKE_C_COMPILER=clang-18 -DCMAKE_CXX_COMPILER=clang++-18 -DCMAKE_LINKER_TYPE=LLD + additional-packages: clang lld + - name: Ubuntu (Clang)(host-only) + runs-on: ubuntu-24.04 + # In this configuration, also build static+dynamic in order to verify + # that path structurally works. + cmake-options: + -DCMAKE_C_COMPILER=clang-18 -DCMAKE_CXX_COMPILER=clang++-18 -DCMAKE_LINKER_TYPE=LLD -DSHORTFIN_HAVE_AMDGPU=OFF -DSHORTFIN_BUILD_STATIC=ON -DSHORTFIN_BUILD_DYNAMIC=ON + additional-packages: clang lld + - name: Ubuntu (GCC) + runs-on: ubuntu-24.04 + - name: Windows (MSVC) + runs-on: windows-2022 + exclude: + # Only test Python 3.12 with GCC + - name: Ubuntu (GCC) + python-version: "3.10" + - name: Ubuntu (GCC) + python-version: "3.11" + # TODO: Include additional Python versions for Windows after build got fixed + - name: Windows (MSVC) + python-version: "3.10" + - name: Windows (MSVC) + python-version: "3.11" steps: - - name: Install dependencies + - name: (Linux) Install dependencies + if: "runner.os == 'Linux'" run: | sudo apt update - sudo apt install clang lld cmake ninja-build + sudo apt install cmake ninja-build ${{matrix.additional-packages}} + + - name: (Windows) Configure MSVC + if: "runner.os == 'Windows'" + uses: ilammy/msvc-dev-cmd@0b201ec74fa43914dc39ae48a89fd1d8cb592756 # v1.13.0 - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -70,56 +108,42 @@ jobs: git submodule update --init --depth 1 -- third_party/googletest git submodule update --init --depth 1 -- third_party/hip-build-deps/ - - name: Setup Python ${{ matrix.python-version }} + - name: "Setup Python ${{ matrix.python-version }}" uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} cache: "pip" + cache-dependency-path: | + 'shortfin/requirements-tests.txt' + 'shortfin/requirements-iree-compiler.txt' - name: Install Python packages - # TODO: Switch to `pip install -r requirements.txt -e shortfin/`. working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | pip install -r requirements-tests.txt pip install -r requirements-iree-compiler.txt pip freeze - - name: Build shortfin (full) + - name: Build shortfin working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | mkdir build cmake -GNinja \ -S. 
\ -Bbuild \ - -DCMAKE_C_COMPILER=clang-18 \ - -DCMAKE_CXX_COMPILER=clang++-18 \ - -DCMAKE_LINKER_TYPE=LLD \ - -DSHORTFIN_BUNDLE_DEPS=ON \ -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ + ${{matrix.cmake-options}} cmake --build build --target all - pip install -v -e build/ - - name: Test shortfin (full) + - name: pip install shortfin + if: ${{ matrix.name != 'Ubuntu (Clang)(host-only)'}} working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - ctest --timeout 30 --output-on-failure --test-dir build - pytest -s + pip install -v -e build/ - - name: Build shortfin (host-only) + - name: Test shortfin + if: ${{ matrix.name != 'Ubuntu (Clang)(host-only)'}} working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - mkdir build-host-only - # In this configuration, also build static+dynamic in order to verify - # that path structurally works. - cmake -GNinja \ - -S. \ - -Bbuild-host-only \ - -DCMAKE_C_COMPILER=clang-18 \ - -DCMAKE_CXX_COMPILER=clang++-18 \ - -DCMAKE_LINKER_TYPE=LLD \ - -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ - -DSHORTFIN_HAVE_AMDGPU=OFF \ - -DSHORTFIN_BUILD_STATIC=ON \ - -DSHORTFIN_BUILD_DYNAMIC=ON - cmake --build build-host-only --target all + ctest --timeout 30 --output-on-failure --test-dir build + pytest -s diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml index 34e91cebb..644066094 100644 --- a/.github/workflows/ci-llama-large-tests.yaml +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -76,14 +76,14 @@ jobs: iree-base-runtime - name: Run llama tests - run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --html=out/index.html + run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --html=out/llm/llama/benchmark/index.html - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 with: github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} - publish_dir: ./out/llm/llama/benchmarks - destination_dir: ./llm/llama/benchmarks + publish_dir: ./out/llm/llama/benchmark + destination_dir: ./llm/llama/benchmark keep_files: true - name: Upload llama executable files diff --git a/.github/workflows/ci-sglang-benchmark.yml b/.github/workflows/ci-sglang-benchmark.yml index 504e7e5e3..f44e2772b 100644 --- a/.github/workflows/ci-sglang-benchmark.yml +++ b/.github/workflows/ci-sglang-benchmark.yml @@ -28,7 +28,7 @@ jobs: matrix: version: [3.11] fail-fast: false - runs-on: llama-mi300x-3 + runs-on: mi300x-4 defaults: run: shell: bash @@ -78,7 +78,7 @@ jobs: run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" - name: Launch Shortfin Server - run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html + run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 diff --git a/.github/workflows/ci-sglang-integration-tests.yml b/.github/workflows/ci-sglang-integration-tests.yml index 1c382617d..c61756d78 100644 --- a/.github/workflows/ci-sglang-integration-tests.yml +++ b/.github/workflows/ci-sglang-integration-tests.yml @@ -29,7 +29,7 @@ 
jobs: matrix: version: [3.11] fail-fast: false - runs-on: llama-mi300x-3 + runs-on: mi300x-4 defaults: run: shell: bash diff --git a/.github/workflows/ci-shark-ai.yml b/.github/workflows/ci-shark-ai.yml index bf8007e65..fc85a76a7 100644 --- a/.github/workflows/ci-shark-ai.yml +++ b/.github/workflows/ci-shark-ai.yml @@ -49,7 +49,7 @@ jobs: id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','shortfin/requirements*.txt','sharktank/requirements*.txt') }} - name: Install pip deps run: | diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 54aa3c763..0164b6cdc 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -name: CI - Perplexity +name: CI - sharktank perplexity on: workflow_dispatch: @@ -21,10 +21,10 @@ concurrency: cancel-in-progress: true jobs: - test_perplexity_vmfb: + test_perplexity_iree: if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} timeout-minutes: 1000 - name: "IREE/vmfb" + name: "Perplexity-IREE" strategy: matrix: version: [3.11] @@ -74,13 +74,21 @@ jobs: iree-base-compiler \ iree-base-runtime - - name: Run perplexity test with vmfb - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json + - name: Run perplexity test with IREE + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html + + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + with: + github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} + publish_dir: ./out/llm/llama/perplexity/iree_perplexity + destination_dir: ./llm/llama/perplexity/iree_perplexity + keep_files: true test_perplexity_torch: if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} timeout-minutes: 1000 - name: "Torch/eager mode" + name: "Perplexity-Torch" strategy: matrix: version: [3.11] @@ -123,5 +131,13 @@ jobs: pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" - - name: Run perplexity test in eager mode - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json + - name: Run perplexity test with Torch + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html + + - name: Deploy to GitHub Pages + uses: 
peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + with: + github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} + publish_dir: ./out/llm/llama/perplexity/torch_perplexity + destination_dir: ./llm/llama/perplexity/torch_perplexity + keep_files: true diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml new file mode 100644 index 000000000..4622f5c57 --- /dev/null +++ b/.github/workflows/ci_eval_short.yaml @@ -0,0 +1,77 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - sharktank perplexity short + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test_perplexity_iree: + name: "Llama3.1 8B FP16" + strategy: + matrix: + version: [3.11] + runs-on: [llama-mi300x-3] + fail-fast: false + runs-on: ${{matrix.runs-on}} + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }} + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + + - name: Install sharktank deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Install latest iree-tubrine. + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + + # Try with the latest IREE nightly releases, not what iree-turbine pins. + # We could also pin to a known working or stable version. + # This should eventually stabilize. Do the best we can for now. 
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler \ + iree-base-runtime + + - name: Run perplexity test with vmfb + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml index 0e0e1db2a..550366e1b 100644 --- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml @@ -98,6 +98,6 @@ jobs: - name: Run shortfin Python tests (full) working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py --ignore=tests/apps/sd + pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/sd # TODO: Enable further tests and switch to # pytest -s diff --git a/.github/workflows/ci_windows_x64-libshortfin.yml b/.github/workflows/ci_windows_x64-libshortfin.yml deleted file mode 100644 index 00873c432..000000000 --- a/.github/workflows/ci_windows_x64-libshortfin.yml +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -name: CI - shortfin - Windows - -on: - workflow_dispatch: - pull_request: - paths: - - '.github/workflows/ci_windows_x64-libshortfin.yml' - - 'shortfin/**' - push: - branches: - - main - paths: - - '.github/workflows/ci_windows_x64-libshortfin.yml' - - 'shortfin/**' - -permissions: - contents: read - -concurrency: - # A PR number if a pull request and otherwise the commit hash. This cancels - # queued and in-progress runs for the same PR (presubmit) or commit - # (postsubmit). The workflow name is prepended to avoid conflicts between - # different workflows. 
- group: ${{ github.workflow }}-${{ github.event.number || github.sha }} - cancel-in-progress: true - -env: - IREE_REPO_DIR: ${{ github.workspace }}/iree - LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/ - -jobs: - build-and-test: - name: Build and test - runs-on: windows-2022 - - steps: - - name: Configure MSVC - uses: ilammy/msvc-dev-cmd@0b201ec74fa43914dc39ae48a89fd1d8cb592756 # v1.13.0 - - - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - submodules: false - - - name: Checkout IREE repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: iree-org/iree - path: ${{ env.IREE_REPO_DIR }} - submodules: false - ref: iree-3.0.0rc20241118 - - - name: Initalize IREE submodules - working-directory: ${{ env.IREE_REPO_DIR }} - run : | - git submodule update --init --depth 1 -- third_party/benchmark - git submodule update --init --depth 1 -- third_party/cpuinfo/ - git submodule update --init --depth 1 -- third_party/flatcc - git submodule update --init --depth 1 -- third_party/googletest - git submodule update --init --depth 1 -- third_party/hip-build-deps/ - - - name: Setup Python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: "3.12" - cache: "pip" - - name: Install Python packages - working-directory: ${{ env.LIBSHORTFIN_DIR }} - run: | - pip install -r requirements-tests.txt - pip install -r requirements-iree-compiler.txt - pip freeze - - - name: Build shortfin (full) - working-directory: ${{ env.LIBSHORTFIN_DIR }} - shell: bash - run: | - mkdir build - cmake -GNinja \ - -S. \ - -Bbuild \ - -DSHORTFIN_BUNDLE_DEPS=ON \ - -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON - cmake --build build --target all - pip install -v -e build/ - - - name: Test shortfin (full) - working-directory: ${{ env.LIBSHORTFIN_DIR }} - run: | - ctest --timeout 30 --output-on-failure --test-dir build - pytest -s diff --git a/README.md b/README.md index 44f1e6113..ae3eac423 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ If you're looking to use SHARK check out our [User Guide](docs/user_guide.md). F -[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml?query=event%3Apush) +[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/shark-ai/actions/workflows/ci-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci-libshortfin.yml?query=event%3Apush) The shortfin sub-project is SHARK's high performance inference library and serving engine. diff --git a/app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/app_tests/benchmark_tests/llm/conftest.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py similarity index 74% rename from app_tests/benchmark_tests/llm/conftest.py rename to app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py index cc354b7eb..1e1c64b24 100644 --- a/app_tests/benchmark_tests/llm/conftest.py +++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py @@ -9,15 +9,22 @@ import pytest import sys -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) -from integration_tests.llm.utils import compile_model, export_paged_llm_v1 +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +) +from integration_tests.llm.utils import ( + compile_model, + export_paged_llm_v1, + download_with_hf_datasets, +) @pytest.fixture(scope="module") def pre_process_model(request, tmp_path_factory): tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test") - model_path = request.param["model_path"] + model_name = request.param["model_name"] + model_param_file_name = request.param["model_param_file_name"] settings = request.param["settings"] batch_sizes = request.param["batch_sizes"] @@ -25,6 +32,9 @@ def pre_process_model(request, tmp_path_factory): config_path = tmp_dir / "config.json" vmfb_path = tmp_dir / "model.vmfb" + model_path = tmp_dir / model_param_file_name + download_with_hf_datasets(tmp_dir, model_name) + export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes) config = { diff --git a/app_tests/benchmark_tests/llm/sglang_benchmark_test.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py similarity index 76% rename from app_tests/benchmark_tests/llm/sglang_benchmark_test.py rename to app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py index 0de775795..b66904570 100644 --- a/app_tests/benchmark_tests/llm/sglang_benchmark_test.py +++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py @@ -4,7 +4,6 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import json import logging import multiprocessing import os @@ -16,14 +15,14 @@ pytest.importorskip("sglang") from sglang import bench_serving -from utils import SGLangBenchmarkArgs +from .utils import SGLangBenchmarkArgs, log_jsonl_result from integration_tests.llm.utils import ( find_available_port, start_llm_server, ) -logger = logging.getLogger("__name__") +logger = logging.getLogger(__name__) device_settings = { "device_flags": [ @@ -33,30 +32,21 @@ "device": "hip", } -# TODO: Download on demand instead of assuming files exist at this path -MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa") -TOKENIZER_DIR = Path("/data/llama3.1/8b/") - - -def log_jsonl_result(file_path): - with open(file_path, "r") as file: - json_string = file.readline().strip() - - json_data = json.loads(json_string) - for key, val in json_data.items(): - logger.info(f"{key.upper()}: {val}") - @pytest.mark.parametrize( - "request_rate", - [1, 2, 4, 8, 16, 32], + "request_rate,model_param_file_name", + [ + (req_rate, "meta-llama-3.1-8b-instruct.f16.gguf") + for req_rate in [1, 2, 4, 8, 16, 32] + ], ) @pytest.mark.parametrize( "pre_process_model", [ ( { - "model_path": MODEL_PATH, + "model_name": "llama3_8B_fp16", + "model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf", "settings": device_settings, "batch_sizes": [1, 4], } @@ -64,7 +54,9 @@ def log_jsonl_result(file_path): ], indirect=True, ) -def test_sglang_benchmark_server(request_rate, pre_process_model): +def test_sglang_benchmark_server( + request_rate, model_param_file_name, pre_process_model +): # TODO: Remove when multi-device is fixed os.environ["ROCR_VISIBLE_DEVICES"] = "1" @@ -72,7 +64,8 @@ def test_sglang_benchmark_server(request_rate, pre_process_model): config_path = tmp_dir / "config.json" vmfb_path = tmp_dir / "model.vmfb" - tokenizer_path = TOKENIZER_DIR / "tokenizer.json" + tokenizer_path = tmp_dir / "tokenizer.json" + model_path = tmp_dir / model_param_file_name # Start shortfin llm server port = find_available_port() @@ -81,7 +74,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model): tokenizer_path, config_path, vmfb_path, - MODEL_PATH, + model_path, device_settings, timeout=30, ) @@ -91,7 +84,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model): backend="shortfin", num_prompt=10, base_url=f"http://localhost:{port}", - tokenizer=TOKENIZER_DIR, + tokenizer=tmp_dir, request_rate=request_rate, ) output_file = ( @@ -116,7 +109,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model): logger.info("======== RESULTS ========") log_jsonl_result(benchmark_args.output_file) except Exception as e: - logger.info(e) + logger.error(e) server_process.terminate() server_process.wait() diff --git a/app_tests/benchmark_tests/llm/utils.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py similarity index 84% rename from app_tests/benchmark_tests/llm/utils.py rename to app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py index 55b01da04..47cea4d76 100644 --- a/app_tests/benchmark_tests/llm/utils.py +++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py @@ -6,8 +6,12 @@ from argparse import Namespace from dataclasses import dataclass +import json +import logging from pathlib import Path +logger = logging.getLogger(__name__) + @dataclass class SGLangBenchmarkArgs: @@ -54,3 +58,12 @@ def __repr__(self): f"Tokenizer: {self.tokenizer}\n" f"Request Rate: {self.request_rate}" ) + + +def log_jsonl_result(file_path): + with 
open(file_path, "r") as file: + json_string = file.readline().strip() + + json_data = json.loads(json_string) + for key, val in json_data.items(): + logger.info(f"{key.upper()}: {val}") diff --git a/app_tests/integration_tests/llm/utils.py b/app_tests/integration_tests/llm/utils.py index 05712039e..80b5b3c09 100644 --- a/app_tests/integration_tests/llm/utils.py +++ b/app_tests/integration_tests/llm/utils.py @@ -15,7 +15,7 @@ import requests from transformers import AutoTokenizer -logger = logging.getLogger("__name__") +logger = logging.getLogger(__name__) class AccuracyValidationException(RuntimeError): diff --git a/docs/amdgpu_kernel_optimization_guide.md b/docs/amdgpu_kernel_optimization_guide.md index 09c5b59f9..91b7f1385 100644 --- a/docs/amdgpu_kernel_optimization_guide.md +++ b/docs/amdgpu_kernel_optimization_guide.md @@ -4,7 +4,7 @@ Author: Jakub Kuderski @kuhar Date: 2024-06-24 -Last Update: 2024-08-22 +Last Update: 2024-11-22 ## Introduction @@ -293,3 +293,124 @@ forms a *clause* that translates to a single data fabric transaction. > [!TIP] > For allocations of 4 GB or less, you can implement predicated loads using the > `buffer` instructions. + +## Data-Parallel Primitives and Warp-level Reduction + +For cross-lane data sharing, the most straightforward way is LDS. Some lanes +write data to some locations on LDS and other lanes read data from LDS. Besides, +there are several instructions can be used to share data cross lanes within a +wavefront/warp. + +Here's a brief introduction of these instructions. Please check out [this +blog](https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/) for +details. + +### ds_permute/ds_bpermute + +`ds_permute`/`ds_bpermute` instructions use LDS hardware for data sharing but +don't actually write to an LDS location. But it still needs `s_waitcnt` +instruction to determine when data is returned to `dest` VGPR. + +Example: +```nasm +ds_bpermute_b32 dest, addr, src [offset:addr_offset] +``` + +### ds_swizzle + +Compared to `ds_bpermute`, the `ds_swizzle` instruction doesn't require an +additional VGPR for offset since it's encoded in the instruction. + +`ds_swizzle` is likely to have less address generation instructions required +than `ds_bpermute`. + +The cons are: +1. It only supports limited patterns. +2. Similar to `ds_bpermute`, `s_waitcnt` is required to wait for the `dest` VGPR. + +Example: +```nasm +ds_swizzle_b32 dest, src offset:ds_pattern +``` + +### Data-Parallel Primitives, DPP + +DPP is a 32-bit instruction modifier appended to the normal VALU instructions. +It allows VALU instructions to access data in neighboring lanes directly, which +means it doesn't need LDS hardware anymore, hence `s_waitcnt` instructions are +**not required**. + +Unfortunately, it also supported limited patterns like `ds_swizzle`. And there +are some instructions that can't be modified by DPP. + +Example: +```nasm +; Normal VALU instruction. +v_add_f32 + +; Instruction modified by DPP. +v_add_f32_dpp +``` + +It's worth mentioning that DPP has different names and syntaxes on different +architectures: +* CDNA: DPP +* RDNA: DPP8/DPP16 + +For details, please check the [MI300 ISA Reference +Guide](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf) +and the [RDNA3 ISA Reference +Guide](https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna3-shader-instruction-set-architecture-feb-2023_0.pdf). 
+ +### How to use them in MLIR + +Each instruction has a corresponding Op in MLIR (except for `ds_permute`, this +one is not implemented at the time of writing): +* `ds_bpermute`: `rocdl.ds_bpermute` +* `ds_swizzle`: `rocdl.ds_swizzle` +* DPP: `rocdl.update.dpp`, `amdgpu.dpp` (a thin wrapper around + `rocdl.update.dpp` with more comprehensive user interface, e.g., replace magic + numbers with enums) + +The first 2 are straightforward, while DPP follows a different fashion. + +Since DPP is an instruction modifier instead of an instruction itself, there are +tremendous number of combinations of VALU instructions and DPP. To solve that, +`rocdl.update.dpp` and `amdgpu.dpp` are designed to be a wrapper of +`v_mov_b32_dpp` instruction. And it depends on LLVM compiler to fuse it with the +subsequent VALU instruction **with best efforts**. + +For example, `v_mov_b32_dpp` + `v_add_f32_e32` might be fused into `v_add_f32_dpp`. + +There are plenty of constraints stopping an instruction from being merged. For +example, if either the `bank_mask` or the `row_mask` is not `0xf`, it can't be +fused. You can check the +[GCNDPPCombine::combineDPPMov](https://github.com/llvm/llvm-project/blob/ab51eccf88f5321e7c60591c5546b254b6afab99/llvm/lib/Target/AMDGPU/GCNDPPCombine.cpp#L522) +function to see how it works. + +### Comparison + +To summarize, there's no free lunch: instruction's expressivity comes at the +expense of performance. + +The relative performance of cross-lane instructions is as follows: + +DPP > `ds_swizzle` >= `ds_permute` > `ds_bpermute` + +while the generality ranking is the reverse: + +DPP < `ds_swizzle` < `ds_permute` < `ds_bpermute` + +This table presents the approximate instruction latency, collected +experimentally on Fused Softmax kernel with +[rocprofv2](https://github.com/ROCm/rocprofiler?tab=readme-ov-file#plugin-support) +on the MI300 GPU: + +| Instructions | MLIR Op | Hardware | latency/#cycles | +| ---------------------- | ---------------------------- | ------------ | --------------- | +| ds_permute/ds_bpermute | rocdl.ds_bpermute | LDS hardware | ~50* | +| ds_swizzle | rocdl.ds_swizzle | LDS hardware | ~50* | +| DPP | rocdl.update.dpp, amdgpu.dpp | VALU | 4~12 | + +*: For `ds_permute`/`ds_bpermute` and `ds_swizzle`, the latency includes the +instruction itself and its corresponding `s_waitcnt` instruction. diff --git a/docs/developer_guide.md b/docs/developer_guide.md index 832466688..73aee61f7 100644 --- a/docs/developer_guide.md +++ b/docs/developer_guide.md @@ -3,6 +3,57 @@ Each sub-project has its own developer guide. If you would like to work across projects, these instructions should help you get started: + +### Install Dependencies + +Install shortfin dependencies +```bash +sudo apt update && sudo apt install -y clang lld +``` + +### Prepare your python environment + +Install: + +```bash +sudo apt install python-is-python3 python3-venv python3-dev +``` + +
+ + Or, alternatively, use `pyenv` to manage a separate python installation for more control over its version: + + +The following instructions are taken from pyenv's guide here: https://github.com/pyenv/pyenv?tab=readme-ov-file#a-getting-pyenv + +First, install pyenv and its dependencies. + +```bash +sudo apt update; sudo apt install build-essential libssl-dev zlib1g-dev \ +libbz2-dev libreadline-dev libsqlite3-dev curl git \ +libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev +curl https://pyenv.run | bash +``` + +Then, make pyenv available by adding the below to your `~/.bashrc`: + +```bash +export PYENV_ROOT="$HOME/.pyenv" +command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH" +eval "$(pyenv init -)" +``` + +Finally, install a pyenv-managed version of python + +```bash +pyenv install 3.12 # or whichever python version you'd like +pyenv local 3.12 +``` + +Now, your python, pip, and venv should be managed by pyenv instead. + +
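A quick optional check (not part of the upstream pyenv guide) to confirm the pyenv-managed interpreter is the one your shell resolves:

```bash
# Both should resolve to pyenv's shims under ~/.pyenv rather than /usr/bin.
which python
python --version
```

If `which python` still points at `/usr/bin/python`, re-source `~/.bashrc` (or open a new shell) so the pyenv init lines above take effect.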
+ ### Setup a venv We recommend setting up a Python @@ -54,8 +105,10 @@ See also: [nightly_releases.md](nightly_releases.md). ### Running tests ```bash +pip install -r shortfin/requirements-tests.txt pytest sharktank pytest shortfin +pytest app_tests/integration_tests ``` ### Optional: pre-commits and developer settings diff --git a/docs/model_cookbook.md b/docs/model_cookbook.md index ddc0cb3bb..03e625b96 100644 --- a/docs/model_cookbook.md +++ b/docs/model_cookbook.md @@ -210,6 +210,8 @@ iree-compile /tmp/open_llama_3b_v2/open-llama-3b-v2-f16.mlir \ -o /tmp/open_llama_3b_v2/open-llama-3b-v2-f16_cpu.vmfb ``` +TODO: replace these instructions with the newer shortfin code + Run via `service_v1_cli.py` (shortfin serving, with tokenizer): * TODO: script (via service CLI?) to dump inputs/outputs to .bin/.npy files diff --git a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md index 5e0749546..4a8423bc8 100644 --- a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md +++ b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md @@ -64,7 +64,7 @@ We will use the `hf_datasets` module in `sharktank` to download a LLama3.1 8b f16 model. ```bash -python -m sharktank.utils.hf_datasets amd-shark/llama3.1-8B --local-dir $EXPORT_DIR +python -m sharktank.utils.hf_datasets llama3_8B_fp16 --local-dir $EXPORT_DIR ``` ### Define environment variables diff --git a/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md new file mode 100644 index 000000000..b63861a56 --- /dev/null +++ b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md @@ -0,0 +1,254 @@ +# Using `shortfin` with `sglang` + +This doc includes basic steps for hooking up sglang with a running Shortfin server. + +## Current Support Status + +| Feature | Description | Enabled | Reference | +| ----------- | ----------- | ---------- | ------------ | +| `gen` | Generate shortfin completion, given a prompt | ✅ | [Shortfin Implementation](https://github.com/nod-ai/sglang/blob/main/python/sglang/lang/backend/shortfin.py) | +| `streaming` | Stream shortfin completion, given a prompt | ✅ | [Streaming](https://sgl-project.github.io/frontend/frontend.html#streaming) | +| `run_batch` | Run batch of disjoint requests with continous batching | ✅ | [Batching](https://sgl-project.github.io/frontend/frontend.html#batching) | +| `fork` | Generate sections of the same prompt in parallel | ✅ | [Fork Docs](https://sgl-project.github.io/frontend/frontend.html#parallelism) | +| `choices` | Given set of choices, generate response based on best log probs | ❌ | [Choices Methods](https://sgl-project.github.io/frontend/choices_methods.html#choices-methods-in-sglang) | +| `image` | Pass image as part of multi-modal prompt | ❌ | [sgl.image](https://sgl-project.github.io/frontend/frontend.html#multi-modality) | +| `regex` | Specify regular expression as decoding constraint | ❌ | [Regex](https://sgl-project.github.io/frontend/frontend.html#constrained-decoding) | + +## Prerequisites + +For this tutorial, you will need to meet the following prerequisites: + +### Software + +- Python >= 3.11 + - You can check out [pyenv](https://github.com/pyenv/pyenv) + as a good tool to be able to manage multiple versions of python + on the same system. +- A running `shortfin` LLM server as described [below](#installstart-shortfin-llm-server) + - We will use the shortfin server as the `backend` to generate completions + from SGLang's `frontend language`. 
In this tutorial, you can think of + `sglang` as the client and `shortfin` as the server. + +### Hardware + +- This tutorial is designed to run on an [AMD MI300X GPU](https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html) + +## Install/Start `shortfin` LLM server + +Follow the steps [here](https://github.com/nod-ai/shark-ai/blob/main/docs/shortfin/llm/user/e2e_llama8b_mi300x.md) +to export a model with `sharktank` and start a `shortfin` LLM server +with that model. + +## Install sglang + +### Install sglang inside of virtual environment + +Currently, we have our SGLang integration located at this [forked repo](https://github.com/nod-ai/sglang). +We can use pip to install it in the same virtual environment that we used +to start our Shortfin LLM Server. + +```bash +pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" +``` + +## Getting started + +You can verify the installation/setup through the following examples: + +- [Multi-Turn Q&A Example](#multi-turn-qa-example) +- [Fork Example](#fork-example) +- [Benchmark Shortfin](#bench-mark-shortfin-w-sglang-bench_serving-script) + +## Multi-Turn Q&A example + +Now that we have sglang installed, we can run an example to show a multi-turn +Q&A flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html): + +### Open python interpreter + +```bash +python +``` + +### Run example + +You can copy and paste the following example into your interpreter: + +```python +import sglang as sgl + +from sglang.lang.chat_template import get_chat_template + +backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000", ) # Change base_url if running at different address + +sgl.set_default_backend(backend) + +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) + +state = multi_turn_question.run(question_1="Name the capital city of the USA.", question_2="The Smithsonian is in this location.") + +for m in state.messages(): + print(m["role"], m["content"]) +``` + +### Shortfin example output + +You should see an output similar to this: + +```text +========== single ========== + +user : Name the capital city of the USA +assistant : The capital city of the United States of America is Washington, D.C. (short for District of Columbia). +user : The Smithsonian is in this location. +assistant : The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes. +``` + +## Fork example + +Now that we have sglang installed, we can run an example to show a `fork` +flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html): + +### Open python interpreter + +```bash +python +``` + +### Run example + +You can copy and paste the following example into your interpreter: + +```python +import sglang as sgl + +from sglang.lang.chat_template import get_chat_template + +backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000") # Change base_url if running at different address + +sgl.set_default_backend(backend) + +@sgl.function +def tip_suggestion(s): + s += ( + "Here are two tips for staying healthy: " + "1. Balanced Diet. 2. 
Regular Exercise.\n\n" + ) + forks = s.fork(2) + for i, f in enumerate(forks): + f += f"Now, expand tip {i+1} into a paragraph:\n" + f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n") + s += "Tip 1:" + forks[0]["detailed_tip"] + "\n" + s += "Tip 2:" + forks[1]["detailed_tip"] + "\n" + s += "In summary" + sgl.gen("summary") + +state = tip_suggestion.run() + +print(state.text()) +``` + +### Shortfin example output + +You should see an output similar to this: + +```text +Here are two tips for staying healthy: 1. Balanced Diet. 2. Regular Exercise. + +Tip 1:A balanced diet is important for maintaining good health. It should +include a variety of foods from all the major food groups, such as fruits, +vegetables, grains, proteins, and dairy. Eating a balanced diet can help +prevent chronic diseases such as heart disease, diabetes, and obesity. + +Now, expand tip 2 into a paragraph: +Regular exercise is also important for maintaining good health. It can help +improve cardiovascular health, strengthen muscles and bones, and reduce the +risk of chronic diseases. Exercise can also help improve mental health by +reducing stress and anxiety. It is recommended that adults get at least 150 +minutes of moderate-intensity exercise or 75 minutes of vigorous-intensity +exercise per week. + +Now, combine the two paragraphs into a single paragraph: +A balanced diet and regular exercise are both important for maintaining good +health. A balanced diet should include a variety of foods from all the major +food groups, such as fruits, vegetables, grains, proteins, and dairy. +Eating a balanced diet can help prevent chronic diseases such as heart disease, +diabetes, and obesity. Regular exercise is also important for maintaining good +health. It can help improve cardiovascular health, strengthen muscles and bones, +and reduce the risk of chronic diseases. Exercise can also help improve mental +health by reducing stress and anxiety. It is recommended that + +Tip 2:Regular exercise is important for maintaining a healthy body and mind. +It can help improve cardiovascular health, strengthen muscles and bones, +and reduce the risk of chronic diseases such as diabetes and heart disease. +Additionally, exercise has been shown to improve mood, reduce stress, +and increase overall well-being. It is recommended that adults engage in +at least 150 minutes of moderate-intensity aerobic activity or 75 minutes of +vigorous-intensity aerobic activity per week, as well as strength training +exercises at least two days per week. + +In summary, a balanced diet and regular exercise are both essential for +maintaining good health. A balanced diet should include a variety of foods from +all the major food groups, while regular exercise can help improve +cardiovascular health, strengthen muscles and bones, reduce the risk of +chronic diseases, and improve mental health. It is recommended that adults +engage in at least 150 minutes of moderate-intensity aerobic activity or +75 minutes of vigorous-intensity aerobic activity per week, +as well as strength training exercises at least two days per week. 
+``` + +## Benchmark shortfin w/ sglang `bench_serving` script + +We can obtain benchmarking metrics using the `bench_serving` script +provided by SGLang: + +**NOTE: Change `--base-url` if running at a different address** + +```bash +python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer /path/to/tokenizer/dir --request-rate 1 +``` + +There are some more metrics captured, but the most relevant are the following: + +- E2E Latency +- TTFT (Time to First Token) +- TPOT (Time per Output Token) +- ITL (Inter-Token Latency) +- Request Throughput +- Benchmark Duration + +When complete, you should see an output similar to this: + +```text +============ Serving Benchmark Result ============ +Backend: shortfin +Traffic request rate: 1.0 +Successful requests: 10 +Benchmark duration (s): 427.91 +Total input tokens: 1960 +Total generated tokens: 2774 +Total generated tokens (retokenized): 63 +Request throughput (req/s): 0.02 +Input token throughput (tok/s): 4.58 +Output token throughput (tok/s): 6.48 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 416268.77 +Median E2E Latency (ms): 417159.14 +---------------Time to First Token---------------- +Mean TTFT (ms): 292404.29 +Median TTFT (ms): 365989.01 +P99 TTFT (ms): 367325.63 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 1359.41 +Median TPOT (ms): 163.96 +P99 TPOT (ms): 6316.12 +---------------Inter-token Latency---------------- +Mean ITL (ms): 2238.99 +Median ITL (ms): 958.75 +P99 ITL (ms): 2719.50 +================================================== +``` diff --git a/sharktank/conftest.py b/sharktank/conftest.py index 475f386be..ddd371198 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -153,16 +153,28 @@ def pytest_addoption(parser): # --outtype=f32 \ # t5-v1_1-small parser.addoption( - "--google-t5-v1-1-small-fp32-model-path", + "--google-t5-v1-1-small-f32-model-path", type=Path, - default="/data/t5/small/google__t5-v1_1-small_fp32.gguf", - help="Google T5 v1.1 small fp32 model path", + default="/data/t5/small/google__t5-v1_1-small_f32.gguf", + help="Google T5 v1.1 small float32 model path", ) parser.addoption( - "--google-t5-v1-1-xxl-fp32-model-path", + "--google-t5-v1-1-small-bf16-model-path", type=Path, - default="/data/t5/xxl/google__t5-v1_1-xxl_fp32.gguf", - help="Google T5 v1.1 XXL fp32 model path", + default="/data/t5/small/google__t5-v1_1-small_bf16.gguf", + help="Google T5 v1.1 small bfloat16 model path", + ) + parser.addoption( + "--google-t5-v1-1-xxl-f32-model-path", + type=Path, + default="/data/t5/xxl/google__t5-v1_1-xxl_f32.gguf", + help="Google T5 v1.1 XXL float32 model path", + ) + parser.addoption( + "--google-t5-v1-1-xxl-bf16-model-path", + type=Path, + default="/data/t5/xxl/google__t5-v1_1-xxl_bf16.gguf", + help="Google T5 v1.1 XXL bfloat16 model path", ) parser.addoption( @@ -288,15 +300,20 @@ def get_model_artifacts(request: FixtureRequest): model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option( request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model" ) - model_path["google__t5_v1_1_small_fp32_model_path"] = set_fixture_from_cli_option( + model_path["google__t5_v1_1_small_f32_model_path"] = set_fixture_from_cli_option( + request, + "--google-t5-v1-1-small-f32-model-path", + "google__t5_v1_1_small_f32_model", + ) + model_path["google__t5_v1_1_small_bf16_model_path"] = set_fixture_from_cli_option( request, - "--google-t5-v1-1-small-fp32-model-path", - "google__t5_v1_1_small_fp32_model", 
+ "--google-t5-v1-1-small-bf16-model-path", + "google__t5_v1_1_small_bf16_model", ) - model_path["google__t5_v1_1_xxl_fp32_model_path"] = set_fixture_from_cli_option( + model_path["google__t5_v1_1_xxl_f32_model_path"] = set_fixture_from_cli_option( request, - "--google-t5-v1-1-xxl-fp32-model-path", - "google__t5_v1_1_xxl_fp32_model", + "--google-t5-v1-1-xxl-f32-model-path", + "google__t5_v1_1_xxl_f32_model", ) return model_path diff --git a/sharktank/integration/models/punet/integration_test.py b/sharktank/integration/models/punet/integration_test.py index 182b37a50..45af24004 100644 --- a/sharktank/integration/models/punet/integration_test.py +++ b/sharktank/integration/models/punet/integration_test.py @@ -89,12 +89,13 @@ def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir): def sdxl_int8_base_files(): from huggingface_hub import hf_hub_download - REPO_ID = "amd-shark/sdxl-quant-models" - REVISION = "942e771bf0c2657a8b33380103d04747a75dfa4a" + REPO_ID = "amd-shark/sdxl-quant-int8" + SUBFOLDER = "mi300_all_sym_8_step14_fp32" + REVISION = "efda8afb35fd72c1769e02370b320b1011622958" def download(filename): return hf_hub_download( - repo_id=REPO_ID, subfolder="unet/int8", filename=filename, revision=REVISION + repo_id=REPO_ID, subfolder=SUBFOLDER, filename=filename, revision=REVISION ) return { diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt index 6b533d977..26d89c59d 100644 --- a/sharktank/requirements.txt +++ b/sharktank/requirements.txt @@ -1,7 +1,7 @@ iree-turbine # Runtime deps. -gguf==0.6.0 +gguf==0.10.0 numpy<2.0 # Needed for newer gguf versions (TODO: remove when gguf package includes this) diff --git a/sharktank/sharktank/evaluate/README.md b/sharktank/sharktank/evaluate/README.md index 784bb24fd..beb0281cd 100644 --- a/sharktank/sharktank/evaluate/README.md +++ b/sharktank/sharktank/evaluate/README.md @@ -9,16 +9,32 @@ pip install -r sharktank/requirements-tests.txt ### Perplexity -Test perplexity for Llama3.1 8B & 405B (FP16 & FP8) models: +Perplexity score measures the ability of a language model to predict the next token in a sequence. A lower score indicates that a model has higher certainty in it's predictions. Perplexity acts as an intrinsic evaluation metric that measures the model quality, independent of any downstream task. + +In SHARK-Platform, we use perplexity to track code regressions and quality loss across quantized models (with FP16 as baseline). We use 100 prompts randomly selected from the Wikitext-2 test set and calculate the mean perplexities shown below. These numbers are neither comparable between models with different tokenizers nor with other projects due to varying implementations. 
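For reference, the reported score is the standard token-level perplexity. For a tokenized prompt $x_1, \dots, x_N$ it is defined as (general definition; the exact batching and averaging follow the sharktank evaluation scripts):

$$
\mathrm{PPL} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log p_\theta\left(x_i \mid x_{<i}\right)\right)
$$

The scoreboard below reports the mean of this value over the 100 Wikitext-2 prompts.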
+ +* Test perplexity for Llama3.1 8B (FP16) model: ```bash pytest sharktank/tests/evaluate/perplexity_test.py --longrun ``` -Get perplexity for a new model: +* Calculate perplexity for a new model: ```bash python -m sharktank.evaluate.perplexity \ --gguf-file=llama3_70b_f16.gguf \ --tokenizer-config-json=tokenizer_config.json ``` + +### Perplexity Scoreboard + +| CPU | GPU | +|:-------------: |:----------:| +| AMD EPYC 9554 | MI300X | + +#### LLaMA 3.1 + +|Models |Model size (GB) |Torch score |IREE score | +|:----------------------|:---------------|:-------------|:-------------| +|8B FP16 TP1 decomposed |16.07 |14.930181 |14.991893 | diff --git a/sharktank/sharktank/evaluate/perplexity_vmfb.py b/sharktank/sharktank/evaluate/perplexity_iree.py similarity index 85% rename from sharktank/sharktank/evaluate/perplexity_vmfb.py rename to sharktank/sharktank/evaluate/perplexity_iree.py index 4f95ae1bd..6060eb91b 100644 --- a/sharktank/sharktank/evaluate/perplexity_vmfb.py +++ b/sharktank/sharktank/evaluate/perplexity_iree.py @@ -9,6 +9,7 @@ import json import time import random +import re from datetime import timedelta from tqdm import tqdm @@ -83,17 +84,24 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - seconds = end - start - time_taken = abs(timedelta(seconds=round(seconds))) - - if seconds < 1: - time_taken = f" {seconds * 1000} ms" + total_seconds = end - start + time_taken = abs(timedelta(seconds=total_seconds)) + hours, minutes, seconds = re.split(":", str(time_taken)) + + if total_seconds < 1: + time_taken = f" {round(total_seconds * 1000, 3)} ms" + elif total_seconds < 60: + time_taken = "{:.2f} secs".format(round(float(total_seconds), 2)) + else: + time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format( + int(hours), int(minutes), round(float(seconds), 2) + ) func_name = func.__name__ if func_name == "get_perplexity": - func_name = f"Total time to calculate perplexity" + func_name = f"Calculate perplexity" elif func_name == "compile_model": - func_name = f"Total time to export and compile" + func_name = f"Export & compile" logger.info(f" {func_name}: {time_taken}") return result @@ -119,7 +127,7 @@ def print_token_comparison(self, i): def compile_model(self, weight_path_str): self.weight_path_str = weight_path_str - logger.info(f"Compiling: {self.weight_path_str}") + logger.info(f" Compiling: {self.weight_path_str}") export_artifacts = ExportArtifacts( irpa_path=self.weight_path_str, @@ -135,7 +143,7 @@ def compile_model(self, weight_path_str): @timeit def load_model(self, weight_path, tokenizer, vmfb_path): - config = LlamaModelConfig( + self.config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(weight_path.properties), block_seq_stride=16, kv_cache_type=self.kv_cache_type, @@ -145,18 +153,18 @@ def load_model(self, weight_path, tokenizer, vmfb_path): tensor_parallelism_size=self.tensor_parallelism_size, ) - if config.tensor_parallelism_size > 1: - weight_path.root_theta = shard_theta(weight_path.root_theta, config) + if self.config.tensor_parallelism_size > 1: + weight_path.root_theta = shard_theta(weight_path.root_theta, self.config) theta = weight_path.root_theta - if config.hp.expert_count: - if config.hp.model_arch == "grok": - model = PagedGrokModelV1(theta, config) + if self.config.hp.expert_count: + if self.config.hp.model_arch == "grok": + model = PagedGrokModelV1(theta, self.config) else: - model = PagedMixtralModelV1(theta, config) + model = PagedMixtralModelV1(theta, self.config) else: - model = 
PagedLlamaModelV1(theta, config) + model = PagedLlamaModelV1(theta, self.config) self.generator = TorchGenerator(model, tokenizer) @@ -169,7 +177,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path): self.haldevice = self.runner.config.device @timeit - def get_prompts(self): + def get_prompts(self, num_prompts): test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[ "text" ] @@ -183,12 +191,15 @@ def get_prompts(self): s.replace("\n", "").rstrip() for s in test_prompts if s != "" and len(s.split()) >= 20 and s.count("=") < 2 - ] + ][0:num_prompts] + + self.test_prompts = test_prompts self.bs = len(test_prompts) - return test_prompts + logger.info(f" Batch size: {self.bs}") + @timeit def prefill_vmfb(self, token_batch, i): seq_block_ids = self.batch.pad_block_ids() @@ -244,25 +255,7 @@ def decode_vmfb(self, token_batch, i): return decode_logits @timeit - def get_logits(self): - - token_ids, seq_lens = self.generator.tokenizer.encode( - self.test_prompts, - pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, - ) - - logger.info(f" Prompts for Evaluation:") - for idx, prompt in enumerate(self.test_prompts): - logger.info( - f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" - ) - - self.max_prompt_length = max(seq_lens) - - self.token_ids = torch.tensor(token_ids, device=self.torch_device) - self.attention_mask = ( - (self.token_ids != 0).int().detach().clone().to(self.torch_device) - ) + def get_logits(self, page_cache_size): is_first_token = True start = 0 @@ -298,6 +291,7 @@ def get_logits(self): token_batch=token_batch, seq_lens_batch=self.seq_lens_batch, bs=self.bs, + page_cache_size=page_cache_size, ) self.cache_state = ireert.asdevicearray( @@ -347,11 +341,31 @@ def compute_perplexity(self): } @timeit - def get_perplexity(self, test_prompts): + def get_perplexity(self): - self.test_prompts = test_prompts + token_ids, seq_lens = self.generator.tokenizer.encode( + self.test_prompts, + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + self.page_cache_size = ( + len(token_ids[0]) // self.config.block_seq_stride + ) * self.bs + 1 + + logger.debug(f" Prompts for Evaluation:") + for idx, prompt in enumerate(self.test_prompts): + logger.debug( + f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" + ) - self.get_logits() + self.max_prompt_length = max(seq_lens) + + self.token_ids = torch.tensor(token_ids, device=self.torch_device) + self.attention_mask = ( + (self.token_ids != 0).int().detach().clone().to(self.torch_device) + ) + + self.get_logits(page_cache_size=self.page_cache_size) self.out_logits = self.out_logits[..., :-1, :].contiguous() self.token_ids = self.token_ids[..., 1:].contiguous() @@ -379,7 +393,9 @@ def run_perplexity( kv_cache_type, tensor_parallelism_size, attention_kernel, + num_prompts, ): + start = time.time() perplexity = Perplexity( torch_device=torch_device, iree_device=iree_device, @@ -390,12 +406,19 @@ def run_perplexity( attention_kernel=attention_kernel, ) - test_prompts = perplexity.get_prompts() - logger.info(f" Total test prompts: {len(test_prompts)}") + perplexity.get_prompts(num_prompts=num_prompts) vmfb_path = perplexity.compile_model(weight_path_str) perplexity.load_model(weight_path, tokenizer, vmfb_path) - ppl = perplexity.get_perplexity(test_prompts) + ppl = perplexity.get_perplexity() + + end = time.time() + total_time = round(end - start, 2) + if total_time < 60: + total_time = str(total_time) + " secs" + else: + total_time = 
str(round(total_time / 60, 2)) + " mins" + logger.info(f" Total time taken: {total_time}") return ppl @@ -404,7 +427,7 @@ def main(argv): parser = cli.create_parser() parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") parser.add_argument("--torch-device", help="Torch device (or default)") - parser.add_argument("--iree-device", help="List an IREE device, eg: 'hip://0'") + parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')") parser.add_argument( "--iree-hip-target", action="store", @@ -429,6 +452,12 @@ def main(argv): default=1, help="Number of devices for tensor parallel sharding", ) + parser.add_argument( + "--num-prompts", + type=int, + default=100, + help="Number of prompts for perplexity test", + ) cli.add_tokenizer_options(parser) cli.add_input_dataset_options(parser) @@ -452,6 +481,7 @@ def main(argv): kv_cache_type=kv_cache_type, tensor_parallelism_size=args.tensor_parallelism_size, attention_kernel=args.attention_kernel, + num_prompts=args.num_prompts, ) logger.info(f"\n{json.dumps(ppl, indent=2)}") diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py index fc3aa5fca..258e8c9a0 100644 --- a/sharktank/sharktank/evaluate/perplexity_torch.py +++ b/sharktank/sharktank/evaluate/perplexity_torch.py @@ -8,6 +8,7 @@ import logging import time import random +import re from datetime import timedelta import json import numpy as np @@ -69,15 +70,22 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - seconds = end - start - time_taken = abs(timedelta(seconds=round(seconds))) - - if seconds < 1: - time_taken = f" {seconds * 1000} ms" + total_seconds = end - start + time_taken = abs(timedelta(seconds=total_seconds)) + hours, minutes, seconds = re.split(":", str(time_taken)) + + if total_seconds < 1: + time_taken = f" {round(total_seconds * 1000, 3)} ms" + elif total_seconds < 60: + time_taken = "{:.2f} secs".format(round(float(total_seconds), 2)) + else: + time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format( + int(hours), int(minutes), round(float(seconds), 2) + ) func_name = func.__name__ if func_name == "get_perplexity": - func_name = "Total time" + func_name = "Calculate perplexity" logger.info(f" {func_name}: {time_taken}") return result @@ -102,7 +110,7 @@ def print_token_comparison(self, i): @timeit def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kernel): - config = LlamaModelConfig( + self.config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(dataset.properties), block_seq_stride=16, kv_cache_type=self.kv_cache_type, @@ -112,23 +120,23 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kern tensor_parallelism_size=tensor_parallelism_size, ) - if config.tensor_parallelism_size > 1: - dataset.root_theta = shard_theta(dataset.root_theta, config) + if self.config.tensor_parallelism_size > 1: + dataset.root_theta = shard_theta(dataset.root_theta, self.config) theta = dataset.root_theta - if config.hp.expert_count: - if config.hp.model_arch == "grok": - model = PagedGrokModelV1(theta, config) + if self.config.hp.expert_count: + if self.config.hp.model_arch == "grok": + model = PagedGrokModelV1(theta, self.config) else: - model = PagedMixtralModelV1(theta, config) + model = PagedMixtralModelV1(theta, self.config) else: - model = PagedLlamaModelV1(theta, config) + model = PagedLlamaModelV1(theta, self.config) self.generator = TorchGenerator(model, 
tokenizer) @timeit - def get_prompts(self): + def get_prompts(self, num_prompts): test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[ "text" @@ -144,34 +152,16 @@ def get_prompts(self): s.replace("\n", "").rstrip() for s in test_prompts if s != "" and len(s.split()) >= 20 and s.count("=") < 2 - ] - - logger.info(f" num_test_prompts: {len(test_prompts)}") - - return test_prompts - - @timeit - def get_logits(self): - - token_ids, seq_lens = self.generator.tokenizer.encode( - self.test_prompts, - pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, - ) + ][0:num_prompts] - logger.info(f" Prompts for Evaluation:") - for idx, prompt in enumerate(self.test_prompts): - logger.info( - f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" - ) + self.test_prompts = test_prompts - self.max_prompt_length = max(seq_lens) + self.bs = len(test_prompts) - self.token_ids = torch.tensor(token_ids, device=self.device) - self.attention_mask = ( - (self.token_ids != 0).int().detach().clone().to(self.device) - ) + logger.info(f" Batch size: {self.bs}") - self.bs = len(self.test_prompts) + @timeit + def get_logits(self, page_cache_size): is_first_token = True start = 0 @@ -204,6 +194,7 @@ def get_logits(self): token_batch=token_batch, seq_lens_batch=seq_lens_batch, bs=self.bs, + page_cache_size=page_cache_size, ) self.batch.prefill() @@ -260,10 +251,31 @@ def compute_perplexity(self): } @timeit - def get_perplexity(self, test_prompts): + def get_perplexity(self): - self.test_prompts = test_prompts - self.get_logits() + token_ids, seq_lens = self.generator.tokenizer.encode( + self.test_prompts, + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + self.page_cache_size = ( + len(token_ids[0]) // self.config.block_seq_stride + ) * self.bs + 1 + + logger.debug(f" Prompts for Evaluation:") + for idx, prompt in enumerate(self.test_prompts): + logger.debug( + f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" + ) + + self.max_prompt_length = max(seq_lens) + + self.token_ids = torch.tensor(token_ids, device=self.device) + self.attention_mask = ( + (self.token_ids != 0).int().detach().clone().to(self.device) + ) + + self.get_logits(page_cache_size=self.page_cache_size) self.out_logits = self.out_logits[..., :-1, :].contiguous() self.token_ids = self.token_ids[..., 1:].contiguous() @@ -287,12 +299,22 @@ def run_perplexity_torch( kv_cache_type, tensor_parallelism_size, attention_kernel, + num_prompts, ): - perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type) + start = time.time() + perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type) + perplexity.get_prompts(num_prompts=num_prompts) perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel) - test_prompts = perplexity.get_prompts() - ppl = perplexity.get_perplexity(test_prompts=test_prompts) + ppl = perplexity.get_perplexity() + + end = time.time() + total_time = round(end - start, 2) + if total_time < 60: + total_time = str(total_time) + " secs" + else: + total_time = str(round(total_time / 60, 2)) + " mins" + logger.info(f" Total time taken: {total_time}") return ppl @@ -314,6 +336,12 @@ def main(argv): default=1, help="Number of devices for tensor parallel sharding.", ) + parser.add_argument( + "--num-prompts", + type=int, + default=100, + help="Number of prompts for perplexity test", + ) cli.add_input_dataset_options(parser) cli.add_tokenizer_options(parser) @@ -331,6 +359,7 @@ def main(argv): 
kv_cache_type=kv_cache_type, tensor_parallelism_size=args.tensor_parallelism_size, attention_kernel=args.attention_kernel, + num_prompts=args.num_prompts, ) logger.info(f"\n{json.dumps(ppl, indent=2)}") diff --git a/sharktank/sharktank/kernels/templates/flash_attention.mlir b/sharktank/sharktank/kernels/templates/flash_attention.mlir index 15d75c372..4085fef9c 100644 --- a/sharktank/sharktank/kernels/templates/flash_attention.mlir +++ b/sharktank/sharktank/kernels/templates/flash_attention.mlir @@ -33,19 +33,16 @@ util.func private @sharktank_flash_attention_{{l}}_{{s}}_{{d}}_{{e}}_{{i_type}}_ %scale = tensor.extract %s[] : !s_type - %init_trans_v = tensor.empty(%b0, %b1) : !trans_v_type - %transpose_v = linalg.transpose ins(%v: !v_type) outs(%init_trans_v: !trans_v_type) permutation = [0, 1, 3, 2] - %empty_dyn = tensor.empty(%b0, %b1, %l, %e) : !o_dyn_type %empty = tensor.cast %empty_dyn : !o_dyn_type to !o_type %atten = iree_linalg_ext.attention {indexing_maps = [ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>]} - ins(%q, %k, %transpose_v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) { + ins(%q, %k, %v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 } -> !o_type diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 35a2ee570..996a92152 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -227,6 +227,8 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs): == properties["t5.attention.layer_norm_rms_epsilon"] ) + all_kwargs = {"vocab_size": None, "feed_forward_proj": None} + gguf_to_config_names_map = { "t5.context_length": ["context_length"], "t5.embedding_length": ["d_model"], @@ -236,11 +238,9 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs): "t5.attention.key_length": ["d_kv"], "t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"], "t5.attention.relative_buckets_count": ["relative_attention_num_buckets"], - "t5.decoder_start_token_id": ["decoder_start_token_id"], "tokenizer.ggml.eos_token_id": ["eos_token_id"], "tokenizer.ggml.padding_token_id": ["pad_token_id"], } - all_kwargs = {"vocab_size": None, "feed_forward_proj": None} all_kwargs.update( { config_name: properties[gguf_name] @@ -248,6 +248,19 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs): for config_name in config_names } ) + + gguf_to_optional_config_names_map = { + "t5.decoder_start_token_id": ["decoder_start_token_id"], + } + all_kwargs.update( + { + config_name: properties[gguf_name] + for gguf_name, config_names in gguf_to_optional_config_names_map.items() + for config_name in config_names + if gguf_name in properties + } + ) + if "tokenizer.ggml.tokens" in properties: all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"]) all_kwargs.update(kwargs) diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index b679dccde..acd9b8a37 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -31,9 +31,8 @@ class LinearLayer(ThetaLayer): x = x * premul_input 
     matmul(x, weight.T) + bias
 
-    fake_quant exists to allow export without adding dequant ops.
-    when fake_quant is True, the op will in quant dequant fashion.
-    When false, it will keep quantized types.
+    fake_quant exists only to allow q_input to act as a quantize-dequantize (QDQ) pair.
+    When fake_quant is False, q_input quantizes normally and keeps quantized types.
     ```
     """
 
@@ -43,7 +42,7 @@ def __init__(
         *,
         weight_name: str = "weight",
         bias_name: str = "bias",
-        fake_quant: bool = True,
+        fake_quant: bool = False,
     ):
         super().__init__(theta)
         self._simulate_native_quant = True
@@ -74,21 +73,23 @@ def forward(self, x):
             x = q_input.quantize(x)
             if self.fake_quant:
                 x = x.unpack().dequant()
-        elif qdq_input is not None and self.fake_quant:
+
+        elif qdq_input is not None:
             x = qdq_input.quantize(x).unpack().dequant()
 
         y = ops.linear(x, weight, bias)
 
         # Unconditionally dequantize.
-        if isinstance(y, QuantizedTensor) and not self.fake_quant:
+        if isinstance(y, QuantizedTensor):
             y = y.unpack().dequant()
         # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32.
         # We can truncate to fp16 in iree, so we do a cast here
         # to account for this in the IR. This is may not be the right
         # level to do this, but for now its here.
-        if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz:
-            y = ops.to(y, torch.float16)
-        return y
-        if qdq_output is not None and self.fake_quant:
+        if not isinstance(y, QuantizedTensor):
+            if y.dtype == torch.float8_e4m3fnuz:
+                y = ops.to(y, torch.float16)
+            return y
+        if qdq_output is not None:
             y = qdq_output.quantize(y).unpack().dequant()
         return y
diff --git a/sharktank/sharktank/layers/token_embedding.py b/sharktank/sharktank/layers/token_embedding.py
index 32e7fec8f..e5e06c0ef 100644
--- a/sharktank/sharktank/layers/token_embedding.py
+++ b/sharktank/sharktank/layers/token_embedding.py
@@ -5,6 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import torch
+from typing import Optional
 
 from .. import ops
 from .base import Theta, ThetaLayer
@@ -16,7 +17,7 @@ def __init__(
         theta: Theta,
         *,
         weight_name: str = "weight",
-        dtype: torch.dtype = torch.float32,
+        dtype: Optional[torch.dtype] = torch.float32,
     ):
         super().__init__(theta)
         self.weight = self.theta_tensor(weight_name)
diff --git a/sharktank/sharktank/models/t5/export.py b/sharktank/sharktank/models/t5/export.py
index 7bd5eef3d..8d5f75db2 100644
--- a/sharktank/sharktank/models/t5/export.py
+++ b/sharktank/sharktank/models/t5/export.py
@@ -4,12 +4,15 @@
 # See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Union +import functools +from typing import Optional, Union from pathlib import Path import torch +from copy import copy from .t5 import T5Config, T5Encoder from ...types import Dataset +from ...transforms.dataset import set_float_dtype from iree.turbine.aot import FxProgramsBuilder, export __all__ = [ @@ -91,7 +94,18 @@ def prune_decoder_parameters(dataset: Dataset): pass -def export_encoder_iree_parameters(model_path: str, output_path: str): - dataset = Dataset.load(model_path) +def export_encoder_iree_parameters( + model_path_or_dataset: str | Dataset, + output_path: str, + dtype: Optional[torch.dtype] = None, +): + if isinstance(model_path_or_dataset, Dataset): + dataset = copy(model_path_or_dataset) + else: + dataset = Dataset.load(model_path_or_dataset) + if dtype: + dataset.root_theta = dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=dtype) + ) prune_decoder_parameters(dataset) dataset.save(output_path) diff --git a/sharktank/sharktank/models/t5/t5.py b/sharktank/sharktank/models/t5/t5.py index 4ae9108d5..88472db1d 100644 --- a/sharktank/sharktank/models/t5/t5.py +++ b/sharktank/sharktank/models/t5/t5.py @@ -684,7 +684,9 @@ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None): self.add_module( "final_layer_norm", RMSNormLayer( - theta(f"{theta_prefix}.output_norm"), epsilon=config.layer_norm_epsilon + theta(f"{theta_prefix}.output_norm"), + epsilon=config.layer_norm_epsilon, + dtype=config.activation_dtype, ), ) @@ -1046,7 +1048,9 @@ def __init__(self, theta: Theta, config: T5Config): super().__init__() self.add_module( "token_embedding", - TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype), + TokenEmbeddingLayer( + theta("token_embd"), dtype=theta("token_embd").tensor("weight").dtype + ), ) encoder_config = copy.deepcopy(config) diff --git a/sharktank/sharktank/ops/attention_impls.py b/sharktank/sharktank/ops/attention_impls.py index 8ffb1f51e..d1353daaa 100644 --- a/sharktank/sharktank/ops/attention_impls.py +++ b/sharktank/sharktank/ops/attention_impls.py @@ -47,7 +47,7 @@ def _extract_linear_scale(t): return unbox_tensor(t), None -def flash_attention(q, k, v, a): +def flash_attention(q, k, v, a, is_causal, scale): scale = torch.scalar_tensor(1.0 / math.sqrt(q.shape[-1]), dtype=torch.float32) q, qscale = _extract_linear_scale(q) diff --git a/sharktank/sharktank/serving_poc/__init__.py b/sharktank/sharktank/serving_poc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/sharktank/sharktank/serving_poc/framework/logging.py b/sharktank/sharktank/serving_poc/framework/logging.py deleted file mode 100644 index fe5ffc069..000000000 --- a/sharktank/sharktank/serving_poc/framework/logging.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import os -import sys - - -# Whether debug assertions are disabled. 
-NDEBUG: bool = False - -_default_log_level = os.getenv("TURBINE_LOG_LEVEL", "DEBUG") - - -class DefaultFormatter(logging.Formatter): - def __init__(self): - super().__init__( - "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s", - "%m-%d %H:%M:%S", - ) - - -def _setup_logger(): - root_logger = logging.getLogger("sharktank.serving_poc") - root_logger.setLevel(logging.DEBUG) - default_handler = logging.StreamHandler(sys.stderr) - default_handler.flush = sys.stderr.flush - default_handler.setLevel(_default_log_level) - default_handler.setFormatter(DefaultFormatter()) - root_logger.addHandler(default_handler) - root_logger.propagate = False - return root_logger, default_handler - - -root_logger, default_handler = _setup_logger() - -logging.getLogger("asyncio").addHandler(default_handler) - - -def get_logger(name: str): - logger = logging.getLogger(name) - logger.setLevel(_default_log_level) - logger.addHandler(default_handler) - logger.propagate = False - return logger diff --git a/sharktank/sharktank/serving_poc/framework/session.py b/sharktank/sharktank/serving_poc/framework/session.py deleted file mode 100644 index 28af0fd44..000000000 --- a/sharktank/sharktank/serving_poc/framework/session.py +++ /dev/null @@ -1,610 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Runtime session constructs. - -Key concepts: - - * DeviceSession: A single HAL device and other process-level globals. Shared global - memory and corresponding synchronization handles are accessible from here. - * WorkQueue: Logical stream of execution, nested under the DeviceSession. Each - queue holds a timeline semaphore which sequences invocations. For these models, - we route workloads of vastly different characteristics to distinct queues (i.e. - prefill vs decode step). - * LoadedModule: Modules that have been loaded but have not yet been instantiated into - a context. - * HostContext: At least one HostContext is created per LoadedModule. It encapsulates - a VMContext and performs invocations on a dedicated thread. Typically, there will - be more that one HostContext per LoadedModule as it helps us load balance the - host side work across multiple OS threads, ensuring faster feeding of the device. 
-""" - -from typing import Any, Callable, Coroutine, Generic, TypeVar, Optional, Union - -import asyncio -import concurrent.futures -import math -import queue -from threading import Lock, Thread -import warnings - -import numpy as np - -from iree.runtime import ( # type: ignore[import-untyped] - create_hal_module, - create_io_parameters_module, - get_driver, - BufferUsage, - HalBufferView, - HalCommandBuffer, - HalDevice, - HalDeviceLoopBridge, - HalDriver, - HalElementType, - HalFence, - HalSemaphore, - MemoryType, - ParameterIndex, - VmFunction, - VmInstance, - VmContext, - VmModule, -) - -from .logging import get_logger, NDEBUG - -T = TypeVar("T") - -logger = get_logger("shark_turbine.serving.session") -_CONFIG_LOCK = Lock() -_GLOBAL_VM_INSTANCE: Optional[VmInstance] = None - - -def get_vm_instance() -> VmInstance: - global _GLOBAL_VM_INSTANCE - if not _GLOBAL_VM_INSTANCE: - with _CONFIG_LOCK: - if not _GLOBAL_VM_INSTANCE: - _GLOBAL_VM_INSTANCE = VmInstance() - return _GLOBAL_VM_INSTANCE - - -class DeviceSession: - """Top-level object associated with a single attached device.""" - - __slots__ = [ - "device", - "driver", - "_module_sets", - "queues", - "_queue_request_count", - "vm_instance", - ] - - def __init__( - self, - *, - uri: Optional[str] = None, - driver: Optional[Union[str, HalDriver]] = None, - device: Optional[HalDevice] = None, - vm_instance: Optional[VmInstance] = None, - queue_count: int = 1, - ): - self._queue_request_count = 0 - self.vm_instance = vm_instance or get_vm_instance() - if uri is not None: - assert ( - driver is None and device is None - ), "If 'uri' is given, 'driver' and 'device' cannot be set" - logger.info("Opening device by uri: %s", uri) - driver = uri_driver = get_driver(uri) - device = uri_driver.create_device_by_uri(uri) - assert driver is not None, "'driver' cannot be None" - self.driver = driver if not isinstance(driver, str) else get_driver(driver) - self.device = device if device else self.driver.create_default_device() - - # Dependent objects. - self._module_sets: dict[str, "ModuleSet"] = {} - self.queues = [WorkQueue(self, i) for i in range(queue_count)] - - def shutdown(self): - for ms in self._module_sets.values(): - ms.shutdown() - - def create_module_set(self, name: str, *, context_count: int = 1) -> "ModuleSet": - assert ( - name not in self._module_sets - ), f"Modules with name {name} already created" - lm = ModuleSet(self, name, context_count=context_count) - self._module_sets[name] = lm - return lm - - def module_set(self, name: str) -> "ModuleSet": - try: - return self._module_sets[name] - except KeyError: - raise KeyError( - f"ModuleSet '{name}' not found. 
Available: {self._module_sets.keys()}" - ) - - def queue(self, index: int = -1) -> "WorkQueue": - """Gets a queue either with an explicit index or in some rotating fashion.""" - if index >= 0: - return self.queues[index] - else: - self._queue_request_count += 1 - qc = self._queue_request_count - return self.queues[qc % len(self.queues)] - - -class ModuleSet: - __slots__ = [ - "contexts", - "modules", - "name", - "session", - "_context_counter", - ] - - def __init__(self, session: DeviceSession, name: str, *, context_count: int): - assert context_count > 0 - self.session = session - self.name = name - self.modules: list[VmModule] = [ - create_hal_module(session.vm_instance, session.device) - ] - self.contexts = [None] * context_count - self._context_counter = 0 - - @property - def initialized(self) -> bool: - return self.contexts[-1] is not None - - def add(self, *modules: VmModule): - for module in modules: - self.modules.append(module) - - def load_vmfb(self, vmfb_path: str): - logger.info("Loading VMFB %s", vmfb_path) - self.add(VmModule.mmap(self.session.vm_instance, vmfb_path)) - - def load_io_module(self, sources_path: str): - logger.info("Loading IO Module %s", sources_path) - index = ParameterIndex() - index.load(sources_path) - par_provider = index.create_provider(scope="model") - self.add(create_io_parameters_module(self.session.vm_instance, par_provider)) - - def initialize(self): - assert not self.initialized, "Already initialized" - count = len(self.contexts) - logger.info("Initializing %s contexts for %s", count, self.name) - for i in range(count): - self.contexts[i] = HostContext( - self.session, self.modules, name=f"HostContext-{self.name}-{i}" - ) - - def shutdown(self): - for hc in self.contexts: - if hc is not None: - hc.shutdown() - - def module(self, name: str) -> VmModule: - for m in self.modules: - if m.name == name: - return m - raise KeyError( - f"Module `{name}` not found. Available: {[m.name for m in self.modules]}" - ) - - def function(self, module_name: str, function_name: str) -> VmFunction: - m = self.module(module_name) - f = m.lookup_function(function_name) - if f is None: - raise KeyError( - f"Function '{function_name}' not found in '{module_name}'. 
" - f"Available: {m.function_names}" - ) - return f - - @property - def host_context(self) -> "HostContext": - """Gets a context, load balancing across available instances.""" - with _CONFIG_LOCK: - self._context_counter += 1 - counter = self._context_counter - contexts = self.contexts - context = contexts[counter % len(contexts)] - assert context is not None, "Module set not initialized" - return context - - -_ThunkQueueT = queue.SimpleQueue[Union[None, Callable[[], None]]] - - -class HostContext: - def __init__(self, session: DeviceSession, modules: list[VmModule], name: str): - self.session = session - self.vm_context = VmContext(session.vm_instance, modules=modules) - self.name = name - self.loop = asyncio.new_event_loop() - self.loop.set_debug(True) - - # def exc_handler(loop, context): - # print("[EXCEPTION]", loop, context) - # self.loop.set_exception_handler(exc_handler) - - self._device_bridge = HalDeviceLoopBridge(session.device, self.loop) - self._shutdown_future = self.loop.create_future() - logger.info(f"Starting asyncio loop thread %s", name) - self._loop_thread = Thread( - target=self.loop.run_until_complete, - args=[self._shutdown_future], - name=name, - daemon=False, - ) - self._loop_thread.start() - - def shutdown(self, join: bool = True): - if self._shutdown_future is None: - return - logger.info("Signalling shutdown of host context %s", self.name) - local_future = self._shutdown_future - del self._shutdown_future - - def _shutdown(): - local_future.set_result(True) - - self.loop.call_soon_threadsafe(_shutdown) - self._device_bridge.stop() - if join: - self._loop_thread.join() - self.loop.close() - - def __del__(self): - if hasattr(self, "_shutdown_future"): - warnings.warn(f"HostContext deallocated without shutdown(): {self}") - self.shutdown(join=False) - - def run_concurrent( - self, coro: Coroutine[Any, Any, T] - ) -> concurrent.futures.Future[T]: - """Runs a coroutine from another thread, returning a concurrent Future. - - This should be used for submitting initial work to the host context from - another thread or event loop. - - Note that the concurrent Future should have its result() retrieved to - ensure that any asynchronous exceptions are propagated. Otherwise, they will - be silently consumed. - """ - return asyncio.run_coroutine_threadsafe(coro, self.loop) - - def run_sync(self, coro: Coroutine[Any, Any, T]) -> T: - """Runs a coroutine on the host context loop from another thread. - - Waits on and returns the result. - This is primarily intended for testing. - """ - return asyncio.run_coroutine_threadsafe(coro, self.loop).result() - - def on_semaphore( - self, sem: HalSemaphore, payload: int, value: Any - ) -> asyncio.Future: - """Returns an awaitable for when the semaphore attains a payload timepoint. - - The resulting Future will take the given `value` once complete. - """ - return self._device_bridge.on_semaphore(sem, payload, value) - - -class WorkQueue: - """Models a queue as a progression of steps against a timeline semaphore.""" - - __slots__ = [ - "_device", - "_lock", - "_semaphore", - "_step", - "index", - ] - - def __init__(self, session: DeviceSession, index: int = 0): - self.index = index - self._device = session.device - self._lock = Lock() - self._semaphore = session.device.create_semaphore(0) - self._step = 0 - - def execute_sequential(self, command_buffer: HalCommandBuffer): - """Executes a list of command buffers at the current step, advancing to the - next. 
- """ - with self._lock: - current_step = self._step - next_step = current_step + 1 - self._step = next_step - sem = self._semaphore - self._device.queue_execute( - command_buffer, [(sem, current_step)], [(sem, next_step)] - ) - - def current_fence(self) -> HalFence: - """Gets a fence representing the current step.""" - with self._lock: - return HalFence.create_at(self._semaphore, self._step) - - def step_fences(self) -> tuple[HalFence, HalFence]: - """Gets two fences, one at the current step and one at the next.""" - with self._lock: - current_step = self._step - next_step = current_step + 1 - self._step = next_step - sem = self._semaphore - return HalFence.create_at(sem, current_step), HalFence.create_at(sem, next_step) - - def sync(self, host_context: HostContext) -> asyncio.Future: - """Awaitable that completes when all work currently queued completed.""" - with self._lock: - current_step = self._step - return host_context.on_semaphore(self._semaphore, current_step, True) - - def guard(self, value: T) -> "TimelineGuarded[T]": - """Guards an arbitrary value as a timeline guard at the current queue - position. The value will become available when the queue is sync'd.""" - return TimelineGuarded(value, self._semaphore, self._step) - - def __repr__(self): - with self._lock: - return f"WorkQueue[{self.index}](semaphore={self._semaphore}, step={self._step}" - - -class TransferBuffer: - """Transfer buffers are pairs of host/device buffers of a specific size. - - They are used for streaming to/from the device. - """ - - __slots__ = [ - "host_buffer", - "device_buffer", - "host_buffer_map", - "_pool", - ] - - def __init__(self, session: DeviceSession, buffer_size_bytes: int): - self.host_buffer = session.device.allocator.allocate_buffer( - memory_type=MemoryType.HOST_LOCAL | MemoryType.DEVICE_VISIBLE, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=buffer_size_bytes, - ) - self.device_buffer = session.device.allocator.allocate_buffer( - memory_type=MemoryType.DEVICE_LOCAL, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=buffer_size_bytes, - ) - self.host_buffer_map = self.host_buffer.map() - self._pool: Optional["TransferBufferPool"] = None - - @staticmethod - def allocate_shaped( - session: DeviceSession, shape: list[int], element_type: HalElementType - ) -> "TransferBuffer": - assert HalElementType.is_byte_aligned(element_type) - buffer_size_bytes = math.prod(shape) * HalElementType.dense_byte_count( - element_type - ) - return TransferBuffer(session, buffer_size_bytes) - - def recycle(self): - pool = self._pool - assert ( - pool is not None - ), f"Cannot recycle a TransferBuffer that was not acquired from a pool ({self})" - self._pool = None - pool.recycle(self) - - def h2d_array( - self, - cb: HalCommandBuffer, - shape: list[int], - element_type: HalElementType, - *, - fill_value: Any = None, - ) -> tuple[np.ndarray, HalBufferView]: - """Performs an h2d transfer on the given CommandBuffer of the given shape and - element type. - - Returns a host array and device buffer view. Because transfers do not start - until the command buffer is submitted, the host array should be populated - between the return from this call and submission. 
- """ - ary = self.host_buffer_map.asarray( - shape, HalElementType.map_to_dtype(element_type) - ) - if fill_value is not None: - ary.fill(fill_value) - bv = HalBufferView(self.device_buffer, shape, element_type) - cb.copy(self.host_buffer, self.device_buffer, length=bv.byte_length) - return ary, bv - - def __repr__(self): - if self._pool is None: - return f"TransferBuffer(FREE)" - else: - return f"TransferBuffer({self._pool})" - - if not NDEBUG: - - def __del__(self): - if self._pool is not None: - warnings.warn( - f"Deallocated TransferBuffer which needed to be recycled: {self}" - ) - - -class TransferBufferPool: - """Pool of transfer buffers of a fixed size.""" - - __slots__ = [ - "_allocator", - "_free_list", - "name", - ] - - def __init__( - self, - allocator: Callable[[], TransferBuffer], - *, - initial_capacity: int, - growable: bool = False, - name: str = "", - ): - self.name = name - if initial_capacity > 0: - self._free_list = [allocator() for _ in range(initial_capacity)] - self._allocator = None - if growable: - self._allocator = allocator - - @staticmethod - def shaped( - session: DeviceSession, - shape: list[int], - element_type: HalElementType, - *, - initial_capacity: int, - growable: bool = False, - name: str = "", - ) -> "TransferBufferPool": - """Allocates a pool of transfer buffers of the given shape.""" - if initial_capacity > 0: - logger.info( - "Allocating initial capacity %s of '%s' transfer buffers: %s x %r", - initial_capacity, - name, - shape, - element_type, - ) - return TransferBufferPool( - lambda: TransferBuffer.allocate_shaped(session, shape, element_type), - initial_capacity=initial_capacity, - growable=growable, - name=name, - ) - - @staticmethod - def sized( - session: DeviceSession, - buffer_byte_size: int, - *, - initial_capacity: int, - growable: bool = False, - name: str = "", - ) -> "TransferBufferPool": - """Allocates a pool of transfer buffers of a given size in bytes.""" - if initial_capacity > 0: - logger.info( - "Allocating initial capacity %s of '%s' transfer buffers: %s bytes", - initial_capacity, - name, - buffer_byte_size, - ) - return TransferBufferPool( - lambda: TransferBuffer(session, buffer_byte_size), - initial_capacity=initial_capacity, - growable=growable, - name=name, - ) - - def acquire(self) -> TransferBuffer: - """Acquires a transfer buffer from the pool. - - Must be returned via recycle() when done. 
- """ - free_list = self._free_list - if len(free_list) > 0: - tb = free_list.pop() - assert tb._pool is None - tb._pool = self - return tb - - allocator = self._allocator - if not allocator: - raise RuntimeError( - f"Transfer buffer pool '%s' exhausted and not growable", self.name - ) - logger.info("Grow transfer buffer pool '%s'", self.name) - tb = allocator() - assert tb._pool is None - tb._pool = self - return tb - - def recycle(self, tb: TransferBuffer): - """Recycles an acquired transfer buffer.""" - self._free_list.append(tb) - - def __repr__(self): - return f"TransferBufferPool({self.name})" - - -class AsyncResources: - """Resources held for some asynchronous scope.""" - - __slots__ = [ - "_resources", - ] - - def __init__(self): - self._resources: list[Union[TransferBuffer, "AsyncResources"]] = [] - - def acquire_transfer_buffer(self, pool: TransferBufferPool) -> TransferBuffer: - tb = pool.acquire() - self._resources.append(tb) - return tb - - def recycle(self): - for r in self._resources: - r.recycle() - self._resources.clear() - - if not NDEBUG: - - def __del__(self): - if len(self._resources) != 0: - warnings.warn( - f"Deallocated AsyncResources that was not recycled: {self}" - ) - self.recycle() - - -class TimelineGuarded(Generic[T]): - """Some form of results that are structurally available now but will not be - populated until some point in the future. - - This is used to encapsulate entities that are guarded by availability of - a timepoint. Note that we only allow a single timepoint guard in order to - simplify subsequent coordination. This will typically be the case when the - guard is derived from a queue of some form (as opposed to a gather). - """ - - __slots__ = [ - "value", - "sem", - "timeline", - ] - - def __init__(self, value: T, sem: HalSemaphore, timeline: int): - self.value = value - self.sem = sem - self.timeline = timeline - - def resolve(self, host_context: HostContext) -> asyncio.Future[T]: - """Produces an awaitable that resolves to the value once available.""" - return host_context.on_semaphore(self.sem, self.timeline, self.value) - - def __repr__(self): - return f"TimelineGuarded[{self.sem} @ {self.timeline}] = {self.value}" diff --git a/sharktank/sharktank/serving_poc/llm/__init__.py b/sharktank/sharktank/serving_poc/llm/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/sharktank/sharktank/serving_poc/llm/api/rest_server.py b/sharktank/sharktank/serving_poc/llm/api/rest_server.py deleted file mode 100644 index 67536173f..000000000 --- a/sharktank/sharktank/serving_poc/llm/api/rest_server.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Heavily adapted from the vllm api_server.py. 
- -from typing import AsyncGenerator, Optional, Sequence - -import argparse -import json - -from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse, Response, StreamingResponse -import sys -import uuid -import uvicorn - -from ...framework.logging import get_logger -from ...framework.session import DeviceSession - - -from ..service import ( - create_mock_generate_service, - GenerateService, - GenerateRequest, -) - -logger = get_logger("sharktank.serving_poc.llm.api_server") -app = FastAPI() -service: Optional[GenerateService] = None - - -def get_service() -> GenerateService: - assert service is not None, "Service was not initialized" - return service - - -@app.get("/health") -async def health() -> Response: - get_service() - return Response(status_code=200) - - -@app.post("/generate") -async def generate(request: Request) -> Response: - service = get_service() - r = await request.json() - prompt = r.pop("prompt") - stream = bool(r.pop("stream", False)) - request_id = uuid.uuid4().hex - - generate_request = GenerateRequest(request_id=request_id, prompt=prompt) - result_parts = service.handle_request(generate_request) - - if stream: - # TODO: This isn't entirely matching how others do it: we should be returning - # the full result on each update. - async def stream_contents() -> AsyncGenerator[bytes, None]: - async for part in result_parts: - response_record = json.dumps({"text": part.text}) - yield (response_record + "\0").encode() - - return StreamingResponse(stream_contents()) - - # Non-streaming just reads to the final. - async for result_part in result_parts: - if await request.is_disconnected(): - # Abort. - await service.abort(generate_request.request_id) - return Response(status_code=499) - - assert result_part is not None, "No results generated!" - return JSONResponse({"text": result_part.text}) - - -def main(clargs: Sequence[str]): - global service - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default=None) - parser.add_argument("--port", type=int, default=8000) - parser.add_argument( - "--root-path", - type=str, - default=None, - help="Root path to use for installing behind path based proxy.", - ) - parser.add_argument( - "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" - ) - parser.add_argument( - "--testing-mock-service", - action="store_true", - help="Enable the mock testing service", - ) - parser.add_argument( - "--device-uri", type=str, default="local-task", help="Device URI to serve on" - ) - - args = parser.parse_args(clargs) - - # Spin up the device machinery. - # Note that in the future, for multi-device, we will need more scaffolding for - # configuration and bringup, obviously. - device_session = DeviceSession(uri=args.device_uri) - - if args.testing_mock_service: - logger.info("Enabling mock LLM generate service") - service = create_mock_generate_service() - - app.root_path = args.root_path - uvicorn.run( - app, - host=args.host, - port=args.port, - log_level="debug", - timeout_keep_alive=args.timeout_keep_alive, - ) - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/sharktank/sharktank/serving_poc/llm/attn_block_cache.py b/sharktank/sharktank/serving_poc/llm/attn_block_cache.py deleted file mode 100644 index a2299c67e..000000000 --- a/sharktank/sharktank/serving_poc/llm/attn_block_cache.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. 
-# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Manages the block cache.""" - -from iree.runtime import ( # type: ignore - HalBufferView, - HalElementType, - BufferUsage, - MemoryType, - PyModuleInterface, - VmModule, -) - -from ..framework.logging import get_logger -from ..framework.session import DeviceSession - -from .config import human_size, CacheParams - - -logger = get_logger("sharktank.serving_poc.llm.cache") - - -class AttnBlockCacheEntry: - __slots__ = [ - "index", - "in_use", - ] - - def __init__(self, index: int): - self.index = index - self.in_use = False - - def __repr__(self): - return f"Block({self.index}, {'FREE' if not self.in_use else 'BUSY'})" - - -class AttnBlockCache: - def __init__(self, session: DeviceSession, cache_params: CacheParams): - self.session = session - self.cache_params = cache_params - self._initialize_block_cache() - - def _initialize_block_cache(self): - model_params = self.cache_params.model - # Allocate the on-device cache slab. - attn_block_count = self.cache_params.device_block_count - attn_block_size_elements = self.cache_params.attn_block_size_elements - attn_block_size_bytes = attn_block_size_elements * model_params.attn_dtype_size - attn_cache_size_bytes = attn_block_count * attn_block_size_bytes - - logger.info("Setting up cache for\n %r", self.cache_params) - logger.info( - "Allocating attention static cache on device of %s " - "(blocks=%s, block_size=%s bytes)", - human_size(attn_cache_size_bytes), - attn_block_count, - attn_block_size_bytes, - ) - self.attn_block_buffer = self.session.device.allocator.allocate_buffer( - memory_type=MemoryType.DEVICE_LOCAL, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=attn_cache_size_bytes, - ) - - # Attn block logical view. - self.attn_block_buffer_view = HalBufferView( - self.attn_block_buffer, - [ - attn_block_count, - attn_block_size_elements, - ], - model_params.attn_dtype, - ) - - # Accounting structs. - self.attn_block_entries = [ - AttnBlockCacheEntry(i) for i in range(attn_block_count) - ] - self.attn_block_free = list(self.attn_block_entries) - - async def acquire_attn_blocks( - self, count: int, into_list: list[AttnBlockCacheEntry] - ): - """Acquires 'count' attention blocks. - - If there are insufficient free blocks, raises an exception. - """ - free_list = self.attn_block_free - assert ( - len(free_list) >= count - ), f"Cache does not contain requested {count} free attn blocks" - for i in range(count): - into_list.append(free_list.pop()) - - async def release_attn_blocks(self, blocks: list[AttnBlockCacheEntry]): - """Releases a list of attention blocks. - - If at all possible, this should be batched to include all blocks that need to - be released at a given time since this will trigger heavy-weight scheduling - that will work better with a view of the new free list as a whole. - """ - free_list = self.attn_block_free - for block in blocks: - free_list.append(block) - - -def create_attn_block_cache_module(attn_block_cache: AttnBlockCache) -> VmModule: - """Creates a VM module that exports the attention block cache. - - For in-system use, we use a dynamic module that can provide the block cache - slab. In other uses, this may be provided by a statically compiled module - that does the same. - - Interface: - Module name: attn_block_cache - Exports: - func @attn_block_cache.get_shared_slab() -> (!hal.buffer_view) - """ - - class Module: - def __init__(self, iface): - ... 
- - def get_shared_slab(self): - return attn_block_cache.attn_block_buffer_view.ref - - iface = PyModuleInterface(module_name="attn_block_cache", ctor=Module) - iface.export("get_shared_slab", "0v_r", Module.get_shared_slab) - return iface.create() diff --git a/sharktank/sharktank/serving_poc/llm/config.py b/sharktank/sharktank/serving_poc/llm/config.py deleted file mode 100644 index df5db5f8f..000000000 --- a/sharktank/sharktank/serving_poc/llm/config.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Configuration objects. - -Parameters that are intrinsic to a specific model. - -In a typical transformer model, the KV cache is organized similar to (mapped to -our parameter names below): - k = tensor.empty(transformer_block_count, batch_size, seq, - attn_head_count, attn_head_dim) - v = ... - -For context, a popular model has parameters of: - attn_dtype_size = 2 # (fp16) - max_seq_len = 2048 - transformer_block_count = 32 - attn_head_count = 32 - attn_head_dim = 128 # (dim / head_count) - -If paging, then we primary care about the organization of a single block, where -a block represents a single position in the sequence for a single item in the batch. -Therefore, it will be organized like: - block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim) - -In this scenario, we declare that one block holds the KV cache for all transformer -block layers because it reduces the accounting. As such, for the above example, -a single position in the sequence will be 524,288 bytes, assuming a 2-byte element -type. If we choose to block by block_stride=16 positions, each block will be 8MiB. -Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536 -blocks for a total number of sequence positions of 24,576. - -These are well-known numbers but are derived above to give a sense of scale. - -In order to indirect through to the block cache, we have to provide the index map -to specific invocations: - -* Prefill: Prefill is only writing to the blocks from [0:prompt_len], so it will - need write indices of [batch_size, prompt_len // block_stride + 1]. -* Decode step: Decode is auto-regressive, and needs to first compute the new kv - row and then attend over all rows in the cache up to this point in the sequence. - -If wanting to avoid dynamic allocation of transients, we can also pool the index -tables based on the maximum batch size and maximum sequence length. Since all -block cache sizes are well within the range of an i16, we will use that for storage. -Therefore, each batch invocation would need a block lookup table of: - - byte_size = max_batch_size * (max_seq_len // block_stride) * sizeof(int16_t) - -For a max_batch_size of 16, this is 4KiB of block index table lookups per -invocation. We don't have to statically allocate this, but the system is more -predictable if we just reserve what we need. Again, numbers are given to give a -sense of scale only: real workloads will vary. -""" - -from dataclasses import dataclass - -from iree.runtime import ( # type: ignore - HalElementType, -) - -import json - - -@dataclass -class ModelParams: - """Parameters for a specific compiled model, sufficient to do cache planning and - invocations.""" - - # The element type of the attention caches. 
- attn_dtype: HalElementType - - # Maximum length of a sequence including prompt and output. - max_seq_len: int - - # Number of transformer blocks. - transformer_block_count: int - - # Number of attention heads per block. - attn_head_count: int - - # Dimensionality of each attention head - attn_head_dim: int - - # Position stride per attention block - block_seq_stride: int - - # Batch sizes that the prefill stage is compiled for. These are expected to be - # functions exported from the model with suffixes of "_bs{batch_size}". Must - # be in ascending order. - prefill_batch_sizes: list[int] - - # Similarly, batch sizes that the decode stage is compiled for. - decode_batch_sizes: list[int] - - # Name of the IREE module implementing the model. - module_name: str = "module" - - # ABI of the module. - module_abi_version: int = 1 - - # Size in bytes of the KV cache dtype. - @property - def attn_dtype_size(self) -> int: - assert HalElementType.is_byte_aligned(self.attn_dtype) - return HalElementType.dense_byte_count(self.attn_dtype) - - @property - def max_prefill_batch_size(self) -> int: - return self.prefill_batch_sizes[-1] - - @property - def max_decode_batch_size(self) -> int: - return self.decode_batch_sizes[-1] - - @property - def max_batch_size(self): - return max(self.max_prefill_batch_size, self.max_decode_batch_size) - - @staticmethod - def load_json(path): - f = open(path) - j = json.load(f) - return ModelParams(attn_dtype=HalElementType.FLOAT_16, **j) - - -@dataclass -class CacheParams: - """Parameters for management of the block cache. - - This is paired with a ModelParams. - - We presently use a static block cache configuration and hand-wave either a tuning - run or pen/paper analysis to derive the parameters. - """ - - model: ModelParams - - # The size of the static block cache on the device. - device_block_count: int - - # The stride of each block in sequence positions. - block_pos_stride: int - - @property - def attn_unit_size_elements(self) -> int: - """Size in bytes of each cache line in the attention cache. - - Each cache line can store a unit position stride. - """ - size = 1 - size *= self.model.transformer_block_count - size *= 2 # K and V cache line - size *= self.model.attn_head_count - size *= self.model.attn_head_dim - return size - - @property - def attn_block_size_elements(self) -> int: - """Size in bytes of each attention block of {block_position_stride} positions.""" - return self.attn_unit_size_elements * self.block_pos_stride - - -@dataclass -class ServiceParams: - """Parameters for the serving service.""" - - cache: CacheParams - model: ModelParams - - -# From: https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size -def human_size(num, suffix="B"): - for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): - if abs(num) < 1024.0: - return f"{num:3.1f}{unit}{suffix}" - num /= 1024.0 - return f"{num:.1f}Yi{suffix}" diff --git a/sharktank/sharktank/serving_poc/llm/impl/service_v1.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1.py deleted file mode 100644 index 8ae0be637..000000000 --- a/sharktank/sharktank/serving_poc/llm/impl/service_v1.py +++ /dev/null @@ -1,495 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Implements the BatchGenerateService for V1 compiled models. 
- -This is far from where we want to land but is intended for first round bootstrapping. -Perhaps the biggest issue is that it wouldn't mate well as-is with samplers. -""" - -import asyncio -from dataclasses import dataclass - -import numpy as np - -from iree.runtime import ( # type: ignore - HalBufferView, - HalCommandBuffer, - HalElementType, - HalFence, - VmFunction, - VmVariantList, -) - -from ...framework.logging import get_logger, NDEBUG -from ...framework.session import ( - AsyncResources, - DeviceSession, - TimelineGuarded, - TransferBufferPool, - WorkQueue, -) - -from ..attn_block_cache import AttnBlockCacheEntry, AttnBlockCache -from ..config import ServiceParams -from ..service import ( - BatchGenerateService, - BatchGenerateState, - GenerateRequest, -) - - -logger = get_logger("sharktank.serving_poc.llm.impl.service_v1") - -EXPECTED_CONCURRENCY = 10 - - -class GenerateServiceV1(BatchGenerateService): - def __init__( - self, *, session: DeviceSession, params: ServiceParams, cache: AttnBlockCache - ): - self.params = params - self.block_pos_stride = params.cache.block_pos_stride - self.batch_sizes = params.model.prefill_batch_sizes - # TODO: Remove distinction between prefill and decode batch sizes. - assert params.model.decode_batch_sizes == self.batch_sizes - self.session = session - self.cache = cache - module_name = params.model.module_name - logger.info("Configuring serving for module set %s", module_name) - self.module_set = session.module_set(params.model.module_name) - - # Initialize prefill entry-points (1 per batch size). - self.prefill_functions: dict[int, VmFunction] = {} - for bs in self.batch_sizes: - assert bs not in self.prefill_functions - symbol_name = f"prefill_bs{bs}" - logger.info("Looking up symbol '%s'", symbol_name) - self.prefill_functions[bs] = self.module_set.function( - module_name, symbol_name - ) - - # Initialize decode entry-points (1 per batch size). - self.decode_functions: dict[int, VmFunction] = {} - for bs in self.batch_sizes: - assert bs not in self.decode_functions - symbol_name = f"decode_bs{bs}" - logger.info("Looking up symbol '%s'", symbol_name) - self.decode_functions[bs] = self.module_set.function( - module_name, symbol_name - ) - - self._initialize_transfer_pools() - - def _initialize_transfer_pools(self): - params = self.params - max_bs = params.model.max_batch_size - max_sl = params.model.max_seq_len - initial_inflight = EXPECTED_CONCURRENCY - - # block_indices_pool: array([max_batch_size, max_attn_blocks], np.int64) - # Suitable to handle the sequence->block mapping for all steps. - self.block_indices_pool = TransferBufferPool.shaped( - self.session, - [ - max_bs, - max_sl // self.block_pos_stride, - ], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="block_cache_indices", - ) - - # Prefill tokens: array([max_batch_size, max_seq_len], np.int64) - # Tokens inputs to prefill. - self.prefill_tokens_pool = TransferBufferPool.shaped( - self.session, - [ - max_bs, - max_sl, - ], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="prefill_tokens", - ) - - # Prefill sequence lengths: array([max_batch_size], np.int64) - # Sequence lengths of input tokens. - self.prefill_seq_lens_pool = TransferBufferPool.shaped( - self.session, - [max_bs], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="prefill_seq_lens", - ) - - # Decode tokens: array([max_batch_size], np.int64) - # Tokens to perform a decode step with. 
- self.decode_tokens_pool = TransferBufferPool.shaped( - self.session, - [max_bs, 1], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="decode_tokens", - ) - - # Decode seq lengths: array([max_batch_size], np.int64) - # Decoder seq length for this step - self.decode_seq_lens_pool = TransferBufferPool.shaped( - self.session, - [max_bs], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="decode_seq_len", - ) - - # Decode start positions: array([max_batch_size], np.int64) - # Tokens to perform a decode step with. - self.decode_start_pos_pool = TransferBufferPool.shaped( - self.session, - [max_bs], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="decode_start_pos", - ) - - def start(self) -> "GenerateState": - return GenerateState(self) - - def shutdown(self): - self.session.shutdown() - - -class _Sequence: - __slots__ = [ - "attn_blocks", - "attn_blocks_needed", - "current_token_ids", - "decode_token_ids", - "request", - "seq_length", - ] - - current_token_ids: list[int] - decode_token_ids: list[int] - - def __init__(self, request: GenerateRequest): - self.request = request - self.seq_length: int = 0 - self.attn_blocks: list[AttnBlockCacheEntry] = [] - self.attn_blocks_needed: int = 0 - self.decode_token_ids = [] - self.current_token_ids = [] - - def attn_blocks_available(self): - return len(self.attn_blocks) - - def resize_attention(self, new_size): - old_size = self.attn_blocks_needed - self.attn_blocks_needed = new_size - return new_size - old_size - - -class GenerateState(BatchGenerateState): - __slots__ = [ - "_bs", - "_decode_function", - "_prefill_function", - "_max_attn_blocks_length", - "_max_seq_length", - "_resources", - "_service", - "_sequences", - "_batch_queue", - ] - - def __init__(self, service: GenerateServiceV1): - super().__init__(service.module_set.host_context) - self._resources = AsyncResources() - self._service = service - self._sequences: list[_Sequence] = [] - self._batch_queue = WorkQueue(service.session) - - async def recycle(self): - """Recycles or releases all resources consumed by this instance.""" - cache = self._service.cache - self._batch_queue.sync(self.host_context) - self._resources.recycle() - all_blocks = [] - for seq in self._sequences: - all_blocks.extend(seq.attn_blocks) - seq.attn_blocks.clear() - self._sequences = [] - await cache.release_attn_blocks(all_blocks) - - async def set_sequences(self, requests: list[GenerateRequest]): - """Initiates processing of a list of sequences that make up a batch. - - This is async because it acquires resources which may not be available. - """ - service = self._service - block_pos_stride = service.block_pos_stride - - # Loop through each request and reserve initial attention blocks. - bs = 0 - sequences = self._sequences - assert not sequences, "set_sequences already called" - max_attn_blocks_length = 0 - max_seq_length = 0 - attn_blocks_required = 0 - - for req in requests: - bs += 1 - seq = _Sequence(req) - sequences.append(seq) - seq.current_token_ids = req.required_prompt_token_ids - seq_length = len(seq.current_token_ids) - seq.seq_length = seq_length - max_seq_length = max(max_seq_length, seq_length) - initial_block_count = seq_length // block_pos_stride + 1 - attn_blocks_required += initial_block_count - seq.attn_blocks_needed = initial_block_count - max_attn_blocks_length = max(max_attn_blocks_length, initial_block_count) - - # Determine the appropriate batched entrypoints. 
- assert bs > 0 - for allowed_bs in service.batch_sizes: - if allowed_bs >= bs: - self._prefill_function = service.prefill_functions[allowed_bs] - self._decode_function = service.decode_functions[allowed_bs] - break - else: - raise AssertionError(f"Unsupported batch size: {bs}") - - # Acquire the needed attention blocks in one batch so as to give the scheduler - # the most visibility into the need. - logger.debug("Acquire prefill attn blocks: %s", attn_blocks_required) - all_attn_blocks: list[AttnBlockCacheEntry] = [] - await service.cache.acquire_attn_blocks(attn_blocks_required, all_attn_blocks) - block_index = 0 - for seq in sequences: - next_block_count = seq.attn_blocks_needed - seq.attn_blocks.extend( - all_attn_blocks[block_index : block_index + seq.attn_blocks_needed] - ) - block_index += next_block_count - - # Save state. - self._bs = allowed_bs - self._max_attn_blocks_length = max_attn_blocks_length - self._max_seq_length = max_seq_length - - async def prefill(self) -> TimelineGuarded[HalBufferView]: - hc = self.host_context - service = self._service - resources = self._resources - bs = self._bs - service = self._service - block_pos_stride = service.block_pos_stride - max_attn_blocks_length = self._max_attn_blocks_length - max_seq_length = max_attn_blocks_length * block_pos_stride - sequences = self._sequences - work_queue = self._batch_queue - - # Record a command buffer for performing h2d transfers. - cb = HalCommandBuffer(hc.session.device) - - # Prepare input tokens, sequence lengths and block indices. - # We acquire a transfer buffer of each from the respective pool, populate its - # host side and enqueue. - # prefill_tokens: array([bs, max_seq_length], np.int32) - prefill_tokens_host, prefill_tokens_device = resources.acquire_transfer_buffer( - service.prefill_tokens_pool - ).h2d_array(cb, [bs, max_seq_length], HalElementType.SINT_64, fill_value=0) - - # prefill_seq_lens: array([bs], np.int32) - ( - prefill_seq_lens_host, - prefill_seq_lens_device, - ) = resources.acquire_transfer_buffer(service.prefill_seq_lens_pool).h2d_array( - cb, [bs], HalElementType.SINT_64, fill_value=0 - ) - - # attn_block_indices: array([bs, max_attn_blocks], np.in16) - ( - prefill_attn_block_indices_host, - prefill_attn_block_indices_device, - ) = resources.acquire_transfer_buffer(service.block_indices_pool).h2d_array( - cb, [bs, max_attn_blocks_length], HalElementType.SINT_64, fill_value=0 - ) - - # Populate host buffers for each sequence. - for i in range(len(sequences)): - seq = sequences[i] - attn_blocks = seq.attn_blocks - current_token_ids = seq.current_token_ids - row_seq_len = len(current_token_ids) - prefill_tokens_host[i, 0:row_seq_len] = current_token_ids - prefill_seq_lens_host[i] = row_seq_len - for j in range(len(seq.attn_blocks)): - prefill_attn_block_indices_host[i, j] = attn_blocks[j].index - - # Perform h2d transfers. - cb.end() - work_queue.execute_sequential(cb) - - # Inputs: - # token_ids - # seq_lens - # attn_block_indices - # attn_block_buffer_view (the entire slab passed as input) - # wait, signal semaphores - # tied attn_block_buffer (for input[2]) - # tied attn_block_buffer (for result[0]) - inputs = VmVariantList(3) - inputs.push_ref(prefill_tokens_device) - inputs.push_ref(prefill_seq_lens_device) - inputs.push_ref(prefill_attn_block_indices_device) - inputs.push_ref(service.cache.attn_block_buffer_view) - - # Outputs: - # attn_block_buffer_view (tied output) - # decode_tokens - outputs = VmVariantList(1) - # TODO: Async invoke. 
- hc.vm_context.invoke(self._prefill_function, inputs, outputs) - return work_queue.guard(outputs.get_as_ref(0).deref(HalBufferView)) - - async def set_decode_step(self, tokens): - """Initiates processing of a list of tokens to decode across each batch - - This is async because it acquires resources which may not be available. - """ - service = self._service - block_pos_stride = service.block_pos_stride - - sequences = self._sequences - assert sequences, "set_sequences was not called yet" - assert len(sequences) == len(tokens), "expected token for each sequence" - - max_attn_blocks_length = 0 - max_seq_length = 0 - attn_blocks_required = 0 - - for tok, seq in zip(tokens, self._sequences): - seq.decode_token_ids.append(tok) - seq.seq_length = seq.seq_length + 1 - - max_seq_length = max(max_seq_length, seq.seq_length) - block_count = seq.seq_length // block_pos_stride + 1 - - seq.attn_blocks_needed = block_count - attn_blocks_required += block_count - seq.attn_blocks_available() - max_attn_blocks_length = max(max_attn_blocks_length, block_count) - - # Acquire the needed attention blocks in one batch so as to give the scheduler - # the most visibility into the need. - logger.debug("Acquire decode attn blocks: %s", attn_blocks_required) - all_attn_blocks: list[AttnBlockCacheEntry] = [] - await service.cache.acquire_attn_blocks(attn_blocks_required, all_attn_blocks) - block_index = 0 - for seq in sequences: - next_block_count = seq.attn_blocks_needed - seq.attn_blocks_available() - seq.attn_blocks.extend( - all_attn_blocks[block_index : block_index + next_block_count] - ) - block_index += next_block_count - - # Save state. - self._max_attn_blocks_length = max_attn_blocks_length - self._max_seq_length = max_seq_length - - async def decode(self) -> TimelineGuarded[HalBufferView]: - hc = self.host_context - service = self._service - resources = self._resources - bs = self._bs - max_attn_blocks_length = self._max_attn_blocks_length - sequences = self._sequences - work_queue = self._batch_queue - - # Record a command buffer for performing h2d transfers. - cb = HalCommandBuffer(hc.session.device) - - # decode_tokens: array([bs, 1], np.int32) - (decode_tokens_host, decode_tokens_device,) = resources.acquire_transfer_buffer( - service.decode_tokens_pool - ).h2d_array(cb, [bs, 1], HalElementType.SINT_64, fill_value=0) - - # decode_seq_lens: array([bs], np.int32) - ( - decode_seq_lens_host, - decode_seq_lens_device, - ) = resources.acquire_transfer_buffer(service.decode_seq_lens_pool).h2d_array( - cb, [bs], HalElementType.SINT_64, fill_value=0 - ) - - # decode_start_pos: array([bs], np.int32) - ( - decode_start_pos_host, - decode_start_pos_device, - ) = resources.acquire_transfer_buffer(service.decode_start_pos_pool).h2d_array( - cb, [bs], HalElementType.SINT_64, fill_value=0 - ) - - # attn_block_indices: array([bs, max_attn_blocks], np.in16) - ( - decode_attn_block_indices_host, - decode_attn_block_indices_device, - ) = resources.acquire_transfer_buffer(service.block_indices_pool).h2d_array( - cb, [bs, max_attn_blocks_length], HalElementType.SINT_64, fill_value=0 - ) - - # Populate host buffers for each sequence. 
- for i in range(len(sequences)): - seq = sequences[i] - attn_blocks = seq.attn_blocks - - tok = seq.decode_token_ids[0] - seq_len = len(seq.current_token_ids) - print(f"seq.current_token_ids: {seq.current_token_ids}") - seq.current_token_ids.append(tok) - seq.decode_token_ids = seq.decode_token_ids[1:] - - decode_tokens_host[i, 0] = tok - decode_start_pos_host[i] = seq_len - decode_seq_lens_host[i] = seq_len - for j in range(len(seq.attn_blocks)): - decode_attn_block_indices_host[i, j] = attn_blocks[j].index - - # Perform h2d transfers. - cb.end() - work_queue.execute_sequential(cb) - - # Inputs: - # token_ids - # seq_lens - # start_pos - # attn_block_indices - # attn_block_buffer_view (the entire slab passed as input) - # wait, signal semaphores - # tied attn_block_buffer (for input[4]) - # tied attn_block_buffer (for result[0]) - inputs = VmVariantList(5) - inputs.push_ref(decode_tokens_device) - inputs.push_ref(decode_seq_lens_device) - inputs.push_ref(decode_start_pos_device) - inputs.push_ref(decode_attn_block_indices_device) - inputs.push_ref(service.cache.attn_block_buffer_view) - - # Outputs: - # attn_block_buffer_view (tied output) - # decode_tokens - outputs = VmVariantList(1) - # TODO: Async invoke. - hc.vm_context.invoke(self._decode_function, inputs, outputs) - return work_queue.guard(outputs.get_as_ref(0).deref(HalBufferView)) diff --git a/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py deleted file mode 100644 index 7895341c9..000000000 --- a/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import asyncio -import argparse -import numpy -import sys - -from transformers import LlamaTokenizer # type: ignore - -from iree.runtime import ( # type: ignore - HalElementType, -) - -from sharktank.serving_poc.framework.session import DeviceSession - -from sharktank.serving_poc.llm.attn_block_cache import ( - create_attn_block_cache_module, - AttnBlockCache, -) - -from sharktank.serving_poc.llm.config import ( - CacheParams, - ModelParams, - ServiceParams, -) - -from sharktank.serving_poc.llm.impl.service_v1 import GenerateServiceV1 -from sharktank.serving_poc.llm.service import GenerateRequest - - -def setup(vmfb_path, config_path, gguf_path): - from iree.runtime._binding import disable_leak_checker # type: ignore - - model_params = ModelParams.load_json(config_path) - - device_block_count = model_params.max_seq_len // model_params.block_seq_stride - cache_params = CacheParams( - model=model_params, - device_block_count=device_block_count, - block_pos_stride=model_params.block_seq_stride, - ) - - disable_leak_checker() - session = DeviceSession(uri="local-sync", queue_count=2) - attn_block_cache = AttnBlockCache(session, cache_params) - - lms = session.create_module_set(model_params.module_name, context_count=1) - lms.load_io_module(gguf_path) - lms.load_vmfb(vmfb_path) - lms.add(create_attn_block_cache_module(attn_block_cache)) - lms.initialize() - - params = ServiceParams(cache=cache_params, model=model_params) - service = GenerateServiceV1(session=session, params=params, cache=attn_block_cache) - return service - - -def map_buffer(value): - mapped = value.map() - return mapped.asarray(value.shape, HalElementType.map_to_dtype(value.element_type)) - - -async def main(argv): - parser = argparse.ArgumentParser() - parser.add_argument("--tokenizer", help="name of hugginface tokenizer to use") - parser.add_argument("--config", help="json config file with hyperparameters") - parser.add_argument("--vmfb", help="vmfb with compiler LLM kernels") - parser.add_argument("--gguf", help="gguf file containing modle coefficients") - parsed = parser.parse_args(argv) - - hf_path = parsed.tokenizer - config_path = parsed.config - vmfb_path = parsed.vmfb - gguf_path = parsed.gguf - - service = setup(vmfb_path, config_path, gguf_path) - tokenizer = LlamaTokenizer.from_pretrained(hf_path) - state = service.start() - - for line in ["one two three four five six seven eight"]: - prompt = line.strip() - if not prompt: - break - - input_ids = tokenizer.encode(prompt, return_tensors="pt")[0].tolist() - print(input_ids) - request = GenerateRequest("request_id", prompt, input_ids) - await state.set_sequences([request]) - logits = await state.prefill() - - seq_len = len(input_ids) - mapped_logits = map_buffer(logits.value) - predicted_tokens = numpy.argmax(mapped_logits[0, :seq_len], axis=-1) - predicted_token = predicted_tokens[-1] - decoded_token = tokenizer.decode(predicted_token) - print(f"Prefill predicted token: {predicted_token}, decoded: '{decoded_token}'") - - # TODO(scotttodd): sanity check tokenizer use, document inputs/outputs - # 'prefill' is for initialization with multiple steps at once - # 'decode' is for hypothesis exploration, one step at a time - await state.set_decode_step([predicted_token]) - logits = await state.decode() - mapped_logits = map_buffer(logits.value) - predicted_tokens = numpy.argmax(mapped_logits, axis=-1) - predicted_token = predicted_tokens[0] - decoded_token = tokenizer.decode(predicted_token) - print(f"Decode 
predicted token: {predicted_token}, decoded: '{decoded_token}'") - await state.recycle() - - service.shutdown() - - -if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) diff --git a/sharktank/sharktank/serving_poc/llm/service.py b/sharktank/sharktank/serving_poc/llm/service.py deleted file mode 100644 index c5d4ffb44..000000000 --- a/sharktank/sharktank/serving_poc/llm/service.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import AsyncIterator, Callable, Optional - -from abc import abstractmethod, ABC -import asyncio -from dataclasses import dataclass - -from ..framework.session import ( - HostContext, -) - -######################################################################################## -# User-level single request service -######################################################################################## - - -@dataclass -class GenerateRequest: - """Encapsulates a request to perform LLM generation. - - Requests are bootstrapped from user values and then pumped through the pipeline, - receiving additional elaboration needed to actually begin generation. - """ - - # Client set fields - request_id: str - prompt: str - - # Fields that are set as the request is processed. - prompt_token_ids: Optional[list[int]] = None - - @property - def required_prompt_token_ids(self) -> list[int]: - ids = self.prompt_token_ids - assert ids is not None - return ids - - -@dataclass -class GenerateResponsePart: - """A response part from an LLM generation request.""" - - request: GenerateRequest - index: int - token_ids: list[int] - - # Fields that can be set as the response is post-processed. - text: Optional[str] = None - finished: bool = False - - -class GenerateService(ABC): - """Asynchronous generator service which processes requests into response parts.""" - - @abstractmethod - def handle_request( - self, - request: GenerateRequest, - ) -> AsyncIterator[GenerateResponsePart]: - """Generates response parts for a request.""" - ... - - @abstractmethod - async def abort(self, request_id: str) -> None: - """Aborts a submitted request.""" - ... - - -######################################################################################## -# Batch generation service -# This service is completely asynchronous and operates on a BatchGenerateRequest as -# a state machine. It is expected to have an external actor stepping it through -# states. -######################################################################################## - - -class BatchGenerateService(ABC): - """Handles generation of a batch of requests.""" - - __slots__ = [] # type: ignore - - # def start_prefill(self, request: BatchGenerateRequest): - # ... - @abstractmethod - def start(self) -> "BatchGenerateState": - ... 
- - -class BatchGenerateState(ABC): - """In-progress batch generation state.""" - - __slots__ = [ - "host_context", - ] - - def __init__(self, host_context: HostContext): - self.host_context = host_context - - -######################################################################################## -# Utilities -######################################################################################## - - -class SyncGenerateFilter(GenerateService): - """GenerateService filter which can synchronously pre/post process.""" - - __slots__ = ["_next"] - - def __init__(self, next: GenerateService): - self._next = next - - def filter_request(self, request: GenerateRequest): - ... - - def filter_response(self, part: GenerateResponsePart): - ... - - async def handle_request( - self, - request: GenerateRequest, - ) -> AsyncIterator[GenerateResponsePart]: - self.filter_request(request) - async for part in self._next.handle_request(request): - self.filter_response(part) - yield part - - async def abort(self, request_id: str) -> None: - """Aborts a submitted request.""" - await self._next.abort(request_id) - - -######################################################################################## -# Testing and mock types -######################################################################################## - - -def create_mock_generate_service() -> GenerateService: - return DummyTokenizerService(EchoGenerateService()) - - -class DummyTokenizerService(SyncGenerateFilter): - """Tokenizer service which will map to code points. - - Useful for testing. - """ - - def filter_request(self, request: GenerateRequest): - if request.prompt_token_ids is None: - request.prompt_token_ids = [ord(c) for c in request.prompt] - - def filter_response(self, part: GenerateResponsePart): - if part.text is None: - part.text = "".join([chr(x) for x in part.token_ids]) - - -class EchoGenerateService(GenerateService): - """Dummy implementation of a generate service. - - It just echoes back the request five times after a delay. - """ - - def __init__(self, delay: float = 0.1): - self._delay = delay - - async def handle_request( - self, - request: GenerateRequest, - ) -> AsyncIterator[GenerateResponsePart]: - next = None - for i in range(5): - if next: - yield next - assert request.prompt_token_ids, "Request lacks prompt tokens" - next = GenerateResponsePart( - request, i, request.prompt_token_ids, finished=False - ) - await asyncio.sleep(self._delay) - if next: - next.finished = True - yield next - - async def abort(self, request_id: str) -> None: - pass diff --git a/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py b/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py deleted file mode 100644 index a36ebe667..000000000 --- a/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Implements a service_v1 compliant module in Python for testing. - -This uses a PyModuleInterface to define a fake VmModule that exposes 'prefill_bs{n}' -and 'decode_bs{n}' such that the call sequence and args/results can be manipulated. 
-""" - -import numpy as np -import textwrap -import threading - -from iree.runtime import ( # type: ignore - BufferUsage, - HalBuffer, - HalBufferView, - HalDevice, - HalElementType, - HalFence, - MemoryType, - PyModuleInterface, - VmModule, - VmRef, -) - -from ..config import ModelParams - - -def create_fake_module( - device: HalDevice, module_name: str, model_params: ModelParams -) -> VmModule: - class ServiceV1Module: - def __init__(self, iface): - ... - print("IFACE:", iface, dir(iface)) - - def prefill( - self, - bs: int, - token_ids_ref: VmRef, - seq_lens_ref: VmRef, - attn_block_indices_ref: VmRef, - attn_block_buffer_view: VmRef, - ): - result_array: np.ndarray = np.ndarray([bs, 1], dtype=np.int32) - - def run(): - print(f"FAKE_V1_MODULE: PREFILL bs={bs} : WAIT") - print(" - READY") - _format_device_buffer_view( - lambda s: print(" token_ids =", s), token_ids_ref - ) - _format_device_buffer_view( - lambda s: print(" seq_lens =", s), seq_lens_ref - ) - _format_device_buffer_view( - lambda s: print(" attn_block_indices =", s), - attn_block_indices_ref, - ) - _format_device_buffer_view( - lambda s: print(" attn_block_buffer_view =", s), - attn_block_buffer_view, - ) - - # Async populate. - device_array = result_bv.map().asarray( - result_array.shape, result_array.dtype - ) - for i in range(bs): - device_array[i, 0] = i + 1 - - threading.Thread(target=run).start() - - result_buffer = device.allocator.allocate_buffer( - memory_type=MemoryType.DEVICE_LOCAL | MemoryType.HOST_VISIBLE, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=result_array.size * result_array.itemsize, - ) - result_bv = HalBufferView( - result_buffer, result_array.shape, HalElementType.INT_32 - ) - return result_bv.ref - - def decode(self, bs: int): - print(f"FAKE_V1_MODULE: DECODE bs={bs}") - - iface = PyModuleInterface(module_name=module_name, ctor=ServiceV1Module) - - # Dynamically define prefill functions. - def add_prefill_bs(bs: int): - def trampoline(self, *args): - return self.prefill(bs, *args) - - iface.export(f"prefill_bs{bs}", "0rrrr_r", trampoline) - - [add_prefill_bs(bs) for bs in model_params.prefill_batch_sizes] - - # Dynamically define decode functions. - def add_decode_bs(bs: int): - def trampoline(self, *args): - return self.decode(bs, *args) - - iface.export(f"decode_bs{bs}", "0v_v", trampoline) - - [add_decode_bs(bs) for bs in model_params.decode_batch_sizes] - - return iface.create() - - -def _format_device_buffer_view(callback, bv_ref: VmRef): - bv = bv_ref.deref(HalBufferView) # type: HalBufferView - value = bv.map().asarray(bv.shape, HalElementType.map_to_dtype(bv.element_type)) - value_indented = textwrap.indent(repr(value), " ") - callback(f"{bv!r}\n{value_indented}") diff --git a/sharktank/sharktank/serving_poc/py.typed b/sharktank/sharktank/serving_poc/py.typed deleted file mode 100644 index 5e43cc13b..000000000 --- a/sharktank/sharktank/serving_poc/py.typed +++ /dev/null @@ -1 +0,0 @@ -# Marker file for PEP 561 inline type checking. 
diff --git a/sharktank/sharktank/transforms/dataset/__init__.py b/sharktank/sharktank/transforms/dataset/__init__.py index b6a2a400a..e2a58ea5d 100644 --- a/sharktank/sharktank/transforms/dataset/__init__.py +++ b/sharktank/sharktank/transforms/dataset/__init__.py @@ -5,3 +5,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .sharding import * +from .dataset import * diff --git a/sharktank/sharktank/transforms/dataset/dataset.py b/sharktank/sharktank/transforms/dataset/dataset.py new file mode 100644 index 000000000..c600865e4 --- /dev/null +++ b/sharktank/sharktank/transforms/dataset/dataset.py @@ -0,0 +1,19 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch + +from ...types.tensors import InferenceTensor, PrimitiveTensor, DefaultPrimitiveTensor +from ... import ops + + +def set_float_dtype(tensor: InferenceTensor, dtype: torch.dtype) -> InferenceTensor: + if isinstance(tensor, PrimitiveTensor) and tensor.dtype.is_floating_point: + return DefaultPrimitiveTensor( + name=tensor.name, data=ops.to(tensor, dtype=dtype) + ) + + return tensor diff --git a/sharktank/sharktank/types/gguf_interop/base.py b/sharktank/sharktank/types/gguf_interop/base.py index 9a7dcf1ee..494607f97 100644 --- a/sharktank/sharktank/types/gguf_interop/base.py +++ b/sharktank/sharktank/types/gguf_interop/base.py @@ -118,6 +118,15 @@ def _wrap_tensor( name=name, data=_externalize_tensor(name, data, logical_shape) ) + if type_name == "BF16": + assert data.dtype == np.uint8 + return DefaultPrimitiveTensor( + name=name, + data=_externalize_tensor(name, data.view(np.int16), logical_shape).view( + dtype=torch.bfloat16 + ), + ) + quantized_type = _quantized_types.get(type_name) if quantized_type is not None: return quantized_type( diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index f870aa101..2c267ac49 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -41,6 +41,7 @@ "AnyTensor", "DefaultPrimitiveTensor", "dtype_to_serialized_name", + "dtype_to_serialized_short_name", "flatten_tensor_tree", "InferenceTensor", "MetaDataValueType", @@ -51,6 +52,7 @@ "register_quantized_layout", "ReplicatedTensor", "serialized_name_to_dtype", + "serialized_short_name_to_dtype", "ShardedTensor", "SplitPrimitiveTensor", "torch_tree_flatten", @@ -1286,6 +1288,15 @@ def dtype_to_serialized_name(dtype: torch.dtype) -> str: ) from e +def dtype_to_serialized_short_name(dtype: torch.dtype) -> str: + try: + return _DTYPE_TO_SHORT_NAME[dtype] + except KeyError as e: + raise KeyError( + f"Missing mapping for dtype {dtype}. Please add to the _SHORT_NAME_TO_DTYPE dict" + ) from e + + def serialized_name_to_dtype(dtype_name: str) -> torch.dtype: try: return _NAME_TO_DTYPE[dtype_name] @@ -1295,6 +1306,15 @@ def serialized_name_to_dtype(dtype_name: str) -> torch.dtype: ) from e +def serialized_short_name_to_dtype(dtype_name: str) -> torch.dtype: + try: + return _SHORT_NAME_TO_DTYPE[dtype_name] + except KeyError as e: + raise KeyError( + f"Missing mapping for dtype '{dtype_name}'. 
Please add to the _SHORT_NAME_TO_DTYPE dict" + ) from e + + _NAME_TO_DTYPE: dict[str, torch.dtype] = { "float32": torch.float32, "float64": torch.float64, @@ -1338,6 +1358,26 @@ def _maybe_dtype(*names: str): _DTYPE_TO_NAME: dict[torch.dtype, str] = {v: k for k, v in _NAME_TO_DTYPE.items()} +_SHORT_NAME_TO_DTYPE: dict[str, torch.dtype] = { + "f32": torch.float32, + "f64": torch.float64, + "c64": torch.complex64, + "c128": torch.complex128, + "f16": torch.float16, + "bf16": torch.bfloat16, + "ui8": torch.uint8, + "i8": torch.int8, + "i16": torch.int16, + "i32": torch.int32, + "i64": torch.int64, + "b": torch.bool, + "f8_e4m3fnuz": torch.float8_e4m3fnuz, +} + +_DTYPE_TO_SHORT_NAME: dict[torch.dtype, str] = { + v: k for k, v in _SHORT_NAME_TO_DTYPE.items() +} + AnyTensor = Union[torch.Tensor, InferenceTensor] ######################################################################################## diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index bd33e1a62..c950a875a 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -9,6 +9,7 @@ import subprocess import logging import time +import re from pathlib import Path from datetime import timedelta from typing import List, Optional @@ -107,11 +108,18 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - seconds = end - start - time_taken = abs(timedelta(seconds=round(seconds))) - - if seconds < 1: - time_taken = f" {seconds * 1000} ms" + total_seconds = end - start + time_taken = abs(timedelta(seconds=total_seconds)) + hours, minutes, seconds = re.split(":", str(time_taken)) + + if total_seconds < 1: + time_taken = f" {round(total_seconds * 1000, 3)} ms" + elif total_seconds < 60: + time_taken = "{:.2f} secs".format(round(float(total_seconds), 2)) + else: + time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format( + int(hours), int(minutes), round(float(seconds), 2) + ) func_name = func.__name__ logger.info(f" {func_name}: {time_taken}") @@ -180,13 +188,13 @@ def export_to_mlir( cwd = self.sharktank_dir cmd = subprocess.list2cmdline(export_args) - logger.info(f"Exporting mlir:\n" f"cd {cwd} && {cmd}") + logger.info(f" Exporting mlir:\n" f"cd {cwd} && {cmd}") proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd, text=True) if proc.returncode != 0: raise ExportMlirException(proc, cwd) else: - logger.info(f"Exported to mlir successfully:\n" f"{proc.stdout}") + logger.info(f" Exported to mlir successfully:\n" f"{proc.stdout}") return proc.returncode @@ -223,7 +231,7 @@ def compile_to_vmfb( compile_args += args cmd = subprocess.list2cmdline(compile_args) - logging.getLogger().info(f"Launching compile command:\n" f"cd {cwd} && {cmd}") + logger.info(f" Launching compile command:\n" f"cd {cwd} && {cmd}") proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd) return_code = proc.returncode if return_code != 0: @@ -277,7 +285,7 @@ def iree_benchmark_vmfb( benchmark_args += devices benchmark_args += args cmd = subprocess.list2cmdline(benchmark_args) - logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}") + logger.info(f" Launching run command:\n" f"cd {cwd} && {cmd}") proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd) return_code = proc.returncode if return_code != 0: diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py index d5976ec48..a9097cf06 100644 --- a/sharktank/sharktank/utils/iree.py 
+++ b/sharktank/sharktank/utils/iree.py @@ -71,6 +71,54 @@ def load_iree_module( return vm_module, vm_context, vm_instance +def promote_bfloat16_to_float32(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype == torch.bfloat16: + return tensor.to(dtype=torch.float32) + else: + return tensor + + +def device_array_to_host(device_array: iree.runtime.DeviceArray) -> torch.Tensor: + def reinterpret_hal_buffer_view_element_type( + buffer_view: iree.runtime.HalBufferView, + element_type: iree.runtime.HalElementType, + ) -> iree.runtime.HalBufferView: + return iree.runtime.HalBufferView( + buffer=buffer_view.get_buffer(), + shape=buffer_view.shape, + element_type=element_type, + ) + + def reinterpret_device_array_dtype( + device_array: iree.runtime.DeviceArray, dtype: np.dtype + ) -> iree.runtime.DeviceArray: + return iree.runtime.DeviceArray( + device=device_array._device, + buffer_view=reinterpret_hal_buffer_view_element_type( + device_array._buffer_view, + iree.runtime.array_interop.map_dtype_to_element_type(dtype), + ), + ) + + # Circumvent the lack of bfloat16 in numpy. + # TODO: This uses private fields _device and _buffer_view in iree.runtime.DeviceArray. + # Improve DeviceArray to provide a hatchet to allow for reinterpretation of + # element type of the underlying buffer. + def bfloat16_device_array_to_torch( + device_array: iree.runtime.DeviceArray, + ) -> torch.Tensor: + device_array_as_int16 = reinterpret_device_array_dtype(device_array, np.int16) + torch_tensor_as_int16 = torch.tensor(device_array_as_int16.to_host()) + return torch_tensor_as_int16.view(dtype=torch.bfloat16) + + if device_array._buffer_view.element_type == int( + iree.runtime.HalElementType.BFLOAT_16 + ): + return bfloat16_device_array_to_torch(device_array) + else: + return torch.tensor(device_array.to_host()) + + def run_iree_module_function( module: iree.runtime.VmModule, vm_context: iree.runtime.VmContext, @@ -88,9 +136,13 @@ def run_iree_module_function( device=iree.runtime.get_device(driver, cache=False), vm_function=vm_function, ) + if trace_path_prefix is not None: for i, arg in enumerate(args): - np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg.to_host()) + np.save( + f"{trace_path_prefix}{function_name}_arg{i}.npy", + promote_bfloat16_to_float32(device_array_to_host(arg)).numpy(), + ) results = invoker(*args) if isinstance(results, iree.runtime.DeviceArray): results = (results,) @@ -99,10 +151,13 @@ def run_iree_module_function( for i, arg in enumerate(args): np.save( f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy", - arg.to_host(), + device_array_to_host(arg).numpy(), ) for i, arg in enumerate(results): - np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", arg.to_host()) + np.save( + f"{trace_path_prefix}{function_name}_result{i}.npy", + promote_bfloat16_to_float32(device_array_to_host(arg)).numpy(), + ) return results @@ -158,7 +213,7 @@ def call_torch_module_function( for i, arg in enumerate(flat_args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", - arg.to("cpu").numpy(), + promote_bfloat16_to_float32(arg.to("cpu")).numpy(), ) res = getattr(module, function_name)(**kwargs) if trace_path_prefix is not None: @@ -166,7 +221,7 @@ def call_torch_module_function( for i, arg in enumerate(flat_args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", - arg.to("cpu").numpy(), + promote_bfloat16_to_float32(arg.to("cpu")).numpy(), ) results = ( (res,) @@ -189,4 +244,4 @@ def call_torch_module_function( def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> 
List[torch.Tensor]: - return [torch.tensor(tensor.to_host()) for tensor in tensors] + return [device_array_to_host(tensor) for tensor in tensors] diff --git a/sharktank/sharktank/utils/load_llm.py b/sharktank/sharktank/utils/load_llm.py index acf56eb1b..47d9f0244 100644 --- a/sharktank/sharktank/utils/load_llm.py +++ b/sharktank/sharktank/utils/load_llm.py @@ -23,24 +23,20 @@ def __init__( self, model: PagedLlamaModelV1, tokenizer: InferenceTokenizer, - page_cache_size: int = 8192, # Need to look at the model more for this. end_token: int = 2, ): self.model = model self.tokenizer = tokenizer - if model.cache.is_paged: - self.shared_cache_state = model.cache.paged.allocate(page_cache_size) - self.free_pages = list(range(1, page_cache_size)) - else: - self.shared_cache_state = None self.end_token = end_token @property def block_seq_stride(self) -> int: return self.model.cache.block_seq_stride - def begin_batch(self, prompts: list[str], add_start_token: bool): + def begin_batch( + self, prompts: list[str], add_start_token: bool, page_cache_size: int = 128 + ): token_ids, seq_lens = self.tokenizer.encode( prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride, @@ -48,8 +44,10 @@ def begin_batch(self, prompts: list[str], add_start_token: bool): ) token_ids = torch.tensor(token_ids, device=self.model.device) seq_lens = torch.tensor(seq_lens, device=self.model.device) - if self.shared_cache_state is not None: - cache_state = self.shared_cache_state + + if self.model.cache.is_paged: + cache_state = self.model.cache.paged.allocate(page_cache_size) + self.free_pages = list(range(1, page_cache_size)) else: cache_state = self.model.cache.direct.allocate(bs=len(prompts)) return Batch(self, token_ids, seq_lens, cache_state) @@ -59,10 +57,11 @@ def begin_eval_batch( token_batch: torch.tensor, seq_lens_batch: torch.tensor, bs: int, + page_cache_size: int = 128, ): - - if self.shared_cache_state is not None: - cache_state = self.shared_cache_state + if self.model.cache.is_paged: + cache_state = self.model.cache.paged.allocate(page_cache_size) + self.free_pages = list(range(1, page_cache_size)) else: cache_state = self.model.cache.direct.allocate(bs=bs) return Batch(self, token_batch, seq_lens_batch, cache_state) diff --git a/sharktank/sharktank/utils/testing.py b/sharktank/sharktank/utils/testing.py index 933bfd2b6..32acec8ac 100644 --- a/sharktank/sharktank/utils/testing.py +++ b/sharktank/sharktank/utils/testing.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Optional import contextlib from pathlib import Path import os @@ -20,7 +21,7 @@ # Range of torch.rand() is [0,1) # Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values -def make_rand_torch(shape, dtype=torch.float32): +def make_rand_torch(shape: list[int], dtype: Optional[torch.dtype] = torch.float32): return torch.rand(shape, dtype=dtype) * 2 - 1 diff --git a/sharktank/tests/evaluate/baseline_perplexity_scores.json b/sharktank/tests/evaluate/baseline_perplexity_scores.json index ac2cd7b83..24511b05f 100644 --- a/sharktank/tests/evaluate/baseline_perplexity_scores.json +++ b/sharktank/tests/evaluate/baseline_perplexity_scores.json @@ -210,7 +210,7 @@ ], "mean_perplexity": 6.060831 }, - "llama3_8B_f16_decomposed_vmfb": { + "llama3_8B_f16_decomposed_iree": { "perplexities": [ 6.651368, 22.059452, diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py similarity index 58% rename from sharktank/tests/evaluate/perplexity_vmfb_test.py rename to sharktank/tests/evaluate/perplexity_iree_test.py index 93ffbe61c..d10d9f5db 100644 --- a/sharktank/tests/evaluate/perplexity_vmfb_test.py +++ b/sharktank/tests/evaluate/perplexity_iree_test.py @@ -7,10 +7,15 @@ import unittest import pytest import json +import numpy as np -from sharktank.evaluate import perplexity_vmfb +from sharktank.evaluate import perplexity_iree -longrun = pytest.mark.skipif("not config.getoption('longrun')") +is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'") +skipif_run_quick_llama_test = pytest.mark.skipif( + 'not config.getoption("run-nightly-llama-tests")', + reason="Run large tests if --run-nightly-llama-tests is passed", +) @pytest.mark.usefixtures( @@ -18,7 +23,9 @@ "get_iree_flags", "tensor_parallelism_size", "baseline_perplexity_scores", + "batch_size", ) +@is_mi300x class PerplexityTest(unittest.TestCase): def setUp(self): self.current_perplexity_all = {} @@ -27,15 +34,14 @@ def setUp(self): with open(self.baseline_perplexity_scores, "r") as f: self.baseline_perplexity = json.load(f) - @longrun def test_llama3_8B_f16_decomposed(self): # Llama 3.1 8B decomposed - model_name = "llama3_8B_f16_decomposed_vmfb" + model_name = "llama3_8B_f16_decomposed_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_8b_f16_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", @@ -44,33 +50,34 @@ def test_llama3_8B_f16_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="Non-decomposed attention is not supported yet", - ) - @longrun + @skipif_run_quick_llama_test 
+ @pytest.mark.xfail(reason="Compile Error") def test_llama3_8B_f16(self): # Llama 3.1 8B non-decomposed - model_name = "llama3_8B_f16_vmfb" + model_name = "llama3_8B_f16_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_8b_f16_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", @@ -79,33 +86,34 @@ def test_llama3_8B_f16(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) - @longrun + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") def test_llama3_8B_fp8_decomposed(self): # Llama 3.1 8B decomposed - model_name = "llama3_8B_fp8_decomposed_vmfb" + model_name = "llama3_8B_fp8_decomposed_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_8b_fp8_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", @@ -114,33 +122,34 @@ def test_llama3_8B_fp8_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) - @longrun + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") def test_llama3_8B_fp8(self): # Llama 3.1 8B non-decomposed - model_name = "llama3_8B_fp8_vmfb" + model_name = "llama3_8B_fp8_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_8b_fp8_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", @@ -149,33 +158,36 @@ def test_llama3_8B_fp8(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + 
np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) + @skipif_run_quick_llama_test @pytest.mark.xfail( reason="Sharding is unsupported", ) - @longrun def test_llama3_405B_f16_decomposed(self): # Llama 3.1 405B decomposed - model_name = "llama3_405B_f16_decomposed_vmfb" + model_name = "llama3_405B_f16_decomposed_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_405b_f16_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", @@ -184,33 +196,34 @@ def test_llama3_405B_f16_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="Non-decomposed attention is not supported yet", - ) - @longrun + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") def test_llama3_405B_f16(self): # Llama 3.1 405B non-decomposed - model_name = "llama3_405B_f16_vmfb" + model_name = "llama3_405B_f16_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_405b_f16_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", @@ -219,33 +232,34 @@ def test_llama3_405B_f16(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) - @longrun + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") def test_llama3_405B_fp8_decomposed(self): # Llama 3.1 405B decomposed - model_name = 
"llama3_405B_fp8_decomposed_vmfb" + model_name = "llama3_405B_fp8_decomposed_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_405b_fp8_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", @@ -254,33 +268,34 @@ def test_llama3_405B_fp8_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) - @longrun + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") def test_llama3_405B_fp8(self): # Llama 3.1 405B non-decomposed - model_name = "llama3_405B_fp8_vmfb" + model_name = "llama3_405B_fp8_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_405b_fp8_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", @@ -289,17 +304,20 @@ def test_llama3_405B_fp8(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index 125a0cfdc..751615a85 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -197,7 +197,6 @@ def testBenchmark8B_f16_Decomposed(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) def testBenchmark8B_f16_Non_Decomposed_Prefill(self): output_file_name = self.dir_path_8b / "f16_torch_prefill" output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file( @@ -780,7 +779,9 @@ def testBenchmark405B_f16_TP8_Decomposed(self): cwd=self.repo_root, ) - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + @pytest.mark.xfail( + reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException + ) def 
testBenchmark405B_f16_TP8_Non_Decomposed(self): output_file_name = self.dir_path_405b / "f16_torch" output_mlir = self.llama405b_f16_torch_sdpa_artifacts.create_file( @@ -828,7 +829,9 @@ def testBenchmark405B_f16_TP8_Non_Decomposed(self): cwd=self.repo_root, ) - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + @pytest.mark.xfail( + reason="KeyError in theta.py", strict=True, raises=ExportMlirException + ) def testBenchmark405B_fp8_TP8_Decomposed(self): output_file_name = self.dir_path_405b / "fp8_decomposed" output_mlir = self.llama405b_fp8_decomposed_artifacts.create_file( @@ -874,7 +877,9 @@ def testBenchmark405B_fp8_TP8_Decomposed(self): cwd=self.repo_root, ) - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + @pytest.mark.xfail( + reason="KeyError in theta.py", strict=True, raises=ExportMlirException + ) def testBenchmark405B_fp8_TP8_Non_Decomposed(self): output_file_name = self.dir_path_405b / "fp8_torch" output_mlir = self.llama405b_fp8_torch_sdpa_artifacts.create_file( diff --git a/sharktank/tests/models/t5/t5_test.py b/sharktank/tests/models/t5/t5_test.py index 076404e5d..1a696ba57 100644 --- a/sharktank/tests/models/t5/t5_test.py +++ b/sharktank/tests/models/t5/t5_test.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools from transformers.models.t5.modeling_t5 import ( T5Attention as ReferenceT5Attention, T5LayerSelfAttention as ReferenceT5LayerSelfAttention, @@ -14,13 +15,21 @@ T5EncoderModel as ReferenceT5EncoderModel, T5Config as ReferenceT5Config, ) +from typing import Optional import os from collections import OrderedDict import pytest import torch +from torch.utils._pytree import tree_map, tree_unflatten, tree_flatten_with_path from unittest import TestCase from parameterized import parameterized -from sharktank.types import Theta, DefaultPrimitiveTensor, unbox_tensor, Dataset +from sharktank.types import ( + Theta, + DefaultPrimitiveTensor, + unbox_tensor, + Dataset, + dtype_to_serialized_short_name, +) from sharktank.models.t5 import ( T5Attention, T5SelfAttention, @@ -41,6 +50,8 @@ flatten_for_iree_signature, iree_to_torch, ) +from sharktank.transforms.dataset import set_float_dtype +from sharktank import ops import iree.compiler with_t5_data = pytest.mark.skipif("not config.getoption('with_t5_data')") @@ -67,45 +78,210 @@ def setUp(self): torch.random.manual_seed(12345) torch.no_grad() - def runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace( - self, huggingface_repo_id: str + @with_t5_data + def testXxlBf16AgainstFluxGolden(self): + """The ground-truth values were acquired from the Flux pipeline.""" + target_model_name = ( + f"{'google/t5-v1_1-xxl'.replace('/', '__').replace('-', '_')}_f32_model" + ) + target_model_path = getattr(self, target_model_name) + dataset = Dataset.load(target_model_path) + dataset.root_theta = dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=torch.bfloat16) + ) + config = T5Config.from_gguf_properties( + dataset.properties, + feed_forward_proj="gated-gelu", + ) + model = T5Encoder(theta=dataset.root_theta, config=config) + model.eval() + + with open( + "/data/t5/xxl/flux_schnell_t5_v1_1_xxl_encoder_bf16_input_ids.pt", "rb" + ) as f: + reference_input_ids = torch.load(f) + + outputs = model( + input_ids=reference_input_ids, + attention_mask=None, + output_hidden_states=False, + ) + + with open( + 
"/data/t5/xxl/flux_schnell_t5_v1_1_xxl_encoder_bf16_output_last_hidden_state.pt", + "rb", + ) as f: + reference_last_hidden_state = torch.load(f) + + torch.testing.assert_close( + outputs["last_hidden_state"], reference_last_hidden_state + ) + + def runTestV1_1CompareTorchEagerHuggingFace( + self, + huggingface_repo_id: str, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, ): get_dataset( huggingface_repo_id, ).download() tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id) - reference_model = ReferenceT5EncoderModel.from_pretrained(huggingface_repo_id) + reference_model = ReferenceT5EncoderModel.from_pretrained( + huggingface_repo_id, torch_dtype=reference_dtype + ) reference_model.eval() + model = ReferenceT5EncoderModel.from_pretrained( + huggingface_repo_id, torch_dtype=target_dtype + ) + model.eval() + input_ids = tokenizer( test_prompts, return_tensors="pt", padding=True, + pad_to_multiple_of=16, ).input_ids + expected_outputs = dict(reference_model(input_ids=input_ids)) + actual_outputs = dict(model(input_ids=input_ids)) + actual_outputs = tree_map( + lambda t: ops.to(t, dtype=reference_dtype), actual_outputs + ) + + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) + + def runTestV1_1CompareTorchEagerAgainstHuggingFace( + self, + huggingface_repo_id: str, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + get_dataset( + huggingface_repo_id, + ).download() + tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id) + reference_model = ReferenceT5EncoderModel.from_pretrained( + huggingface_repo_id, torch_dtype=reference_dtype + ) + reference_model.eval() + target_model_name = ( - f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}_fp32_model" + f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}_f32_model" ) target_model_path = getattr(self, target_model_name) dataset = Dataset.load(target_model_path) + dataset.root_theta = dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=target_dtype) + ) config = T5Config.from_gguf_properties( dataset.properties, feed_forward_proj="gated-gelu", ) + + input_ids = tokenizer( + test_prompts, + return_tensors="pt", + padding=True, + pad_to_multiple_of=config.context_length_padding_block_size, + ).input_ids + model = T5Encoder(theta=dataset.root_theta, config=config) model.eval() expected_outputs = reference_model(input_ids=input_ids) actual_outputs = model(input_ids=input_ids) - torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0) + actual_outputs = tree_map( + lambda t: ops.to(t, dtype=reference_dtype), actual_outputs + ) + + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason=( + "The accuracy is bad, " + "but for XXL we get the same result as the Flux pipeline. " + "This need further investigation how Flux works at all like that." 
+ ), + ) + @with_t5_data + def testV1_1SmallCompareTorchEagerHuggingFaceBf16AgainstF32(self): + self.runTestV1_1CompareTorchEagerHuggingFace( + "google/t5-v1_1-small", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-2, + rtol=1.6e-2, + ) + + @with_t5_data + def testV1_1SmallF32CompareTorchEagerAgainstHuggingFace(self): + self.runTestV1_1CompareTorchEagerAgainstHuggingFace( + "google/t5-v1_1-small", + reference_dtype=torch.float32, + target_dtype=torch.float32, + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason=( + "The accuracy is bad, " + "but for XXL we get the same result as the Flux pipeline. " + "This needs further investigation into how Flux works like that at all." + ), + ) + @with_t5_data + def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFaceF32(self): + self.runTestV1_1CompareTorchEagerAgainstHuggingFace( + "google/t5-v1_1-small", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-2, + rtol=1.6e-2, + ) @with_t5_data - def testV1_1SmallFp32CompareTorchEagerAgainstHuggingFace(self): - self.runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace("google/t5-v1_1-small") + def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFace(self): + self.runTestV1_1CompareTorchEagerAgainstHuggingFace( + "google/t5-v1_1-small", + reference_dtype=torch.bfloat16, + target_dtype=torch.bfloat16, + ) @with_t5_data - def testV1_1XxlFp32CompareTorchEagerAgainstHuggingFace(self): - self.runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace("google/t5-v1_1-xxl") + def testV1_1XxlF32CompareTorchEagerAgainstHuggingFace(self): + self.runTestV1_1CompareTorchEagerAgainstHuggingFace( + "google/t5-v1_1-xxl", + reference_dtype=torch.float32, + target_dtype=torch.float32, + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason=( + "The accuracy is bad, but we get the same result as the Flux pipeline. " + "This needs further investigation into how Flux works like that at all."
+ ), + ) + @with_t5_data + def testV1_1XxlBf16CompareTorchEagerAgainstHuggingFaceF32(self): + self.runTestV1_1CompareTorchEagerAgainstHuggingFace( + "google/t5-v1_1-xxl", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-2, + rtol=1.6e-2, + ) @pytest.mark.usefixtures("caching", "get_model_artifacts", "path_prefix") @@ -115,14 +291,14 @@ def setUp(self): if self.path_prefix is None: self.path_prefix = f"{self._temp_dir}/" - @parameterized.expand( - [ - "google/t5-v1_1-small", - "google/t5-v1_1-xxl", - ] - ) - @with_t5_data - def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str): + def runTestV1_1CompareIreeAgainstTorchEager( + self, + huggingface_repo_id: str, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): get_dataset( huggingface_repo_id, ).download() @@ -131,12 +307,15 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str): huggingface_repo_id_as_path = ( f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}" ) - source_model_name = f"{huggingface_repo_id_as_path}_fp32_model" + source_model_name = f"{huggingface_repo_id_as_path}_f32_model" source_model_path = getattr(self, source_model_name) - dataset = Dataset.load(source_model_path) + reference_dataset = Dataset.load(source_model_path) + reference_dataset.root_theta = reference_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=reference_dtype) + ) config = T5Config.from_gguf_properties( - dataset.properties, + reference_dataset.properties, feed_forward_proj="gated-gelu", ) @@ -149,24 +328,31 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str): input_args = OrderedDict([("input_ids", input_ids)]) batch_size = input_ids.shape[0] - reference_model = T5Encoder(theta=dataset.root_theta, config=config) - reference_result = flatten_for_iree_signature( - call_torch_module_function( - module=reference_model, - function_name="forward", - kwargs=input_args, - trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_torch_", - ) + reference_dtype_name = dtype_to_serialized_short_name(reference_dtype) + target_dtype_name = dtype_to_serialized_short_name(target_dtype) + target_model_path_prefix = f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_{target_dtype_name}" + + reference_model = T5Encoder(theta=reference_dataset.root_theta, config=config) + reference_result_dict = call_torch_module_function( + module=reference_model, + function_name="forward", + kwargs=input_args, + trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_{reference_dtype_name}_torch_", ) + reference_result = flatten_for_iree_signature(reference_result_dict) - mlir_path = f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.mlir" + parameters_path = f"{target_model_path_prefix}.irpa" + if not self.caching or not os.path.exists(parameters_path): + export_encoder_iree_parameters( + source_model_path, parameters_path, dtype=target_dtype + ) + + mlir_path = f"{target_model_path_prefix}.mlir" if not self.caching or not os.path.exists(mlir_path): export_encoder_mlir( - source_model_path, batch_sizes=[batch_size], mlir_output_path=mlir_path + parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path ) - iree_module_path = ( - f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.vmfb" - ) + iree_module_path = f"{target_model_path_prefix}.vmfb" if not self.caching or not os.path.exists(iree_module_path): 
iree.compiler.compile_file( mlir_path, @@ -174,12 +360,6 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str): extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"], ) - - parameters_path = ( - f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.irpa" - ) - if not self.caching or not os.path.exists(parameters_path): - export_encoder_iree_parameters(source_model_path, parameters_path) - iree_devices = get_iree_devices(driver="hip", device_count=1) iree_module, iree_vm_context, iree_vm_instance = load_iree_module( module_path=iree_module_path, @@ -196,12 +376,70 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str): args=iree_args, driver="hip", function_name=f"forward_bs{batch_size}", - trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_iree_", + trace_path_prefix=f"{target_model_path_prefix}_iree_", ) ) + iree_result = [ + ops.to(iree_result[i], dtype=reference_result[i].dtype) + for i in range(len(reference_result)) + ] - torch.testing.assert_close( - reference_result, iree_result, atol=1e-4, rtol=2.0e-3 + torch.testing.assert_close(reference_result, iree_result, atol=atol, rtol=rtol) + + @with_t5_data + def testV1_1CompareSmallIreeF32AgainstTorchEagerF32(self): + self.runTestV1_1CompareIreeAgainstTorchEager( + "google/t5-v1_1-small", + reference_dtype=torch.float32, + target_dtype=torch.float32, + atol=1e-4, + rtol=2.0e-3, + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason=( + "The accuracy is bad, " + "but it is no worse than the accuracy of eager bfloat16. " + "This needs further investigation into how Flux works like that at all." + ), + ) + @with_t5_data + def testV1_1CompareSmallIreeBf16AgainstTorchEagerF32(self): + self.runTestV1_1CompareIreeAgainstTorchEager( + "google/t5-v1_1-small", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-2, + rtol=1.6e-2, + ) + + @with_t5_data + def testV1_1CompareXxlIreeF32AgainstTorchEagerF32(self): + self.runTestV1_1CompareIreeAgainstTorchEager( + "google/t5-v1_1-xxl", + reference_dtype=torch.float32, + target_dtype=torch.float32, + atol=1e-4, + rtol=2.0e-3, + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason=( + "The accuracy is bad, " + "but it is no worse than the accuracy of eager bfloat16. " + "This needs further investigation into how Flux works like that at all."
+ ), + ) + @with_t5_data + def testV1_1CompareXxlIreeBf16AgainstTorchEagerF32(self): + self.runTestV1_1CompareIreeAgainstTorchEager( + "google/t5-v1_1-xxl", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-2, + rtol=1.6e-2, ) @@ -211,8 +449,21 @@ def setUp(self): torch.random.manual_seed(12345) torch.no_grad() - def testCompareAgainstTransformersFp32(self): - dtype = torch.float32 + @parameterized.expand( + [ + [torch.float32, torch.float32], + [torch.bfloat16, torch.bfloat16], + [torch.float32, torch.bfloat16, 1e-2, 1.6e-2], + ] + ) + def testCompareAgainstTransformers( + self, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + torch.set_default_dtype(reference_dtype) batch_size = 19 batch_seq_len = 23 reference_config = ReferenceT5Config( @@ -233,19 +484,21 @@ def testCompareAgainstTransformersFp32(self): theta = Theta( { "attn_q.weight": DefaultPrimitiveTensor( - data=reference_model.q.weight.data + data=reference_model.q.weight.to(dtype=target_dtype) ), "attn_k.weight": DefaultPrimitiveTensor( - data=reference_model.k.weight.data + data=reference_model.k.weight.to(dtype=target_dtype) ), "attn_v.weight": DefaultPrimitiveTensor( - data=reference_model.v.weight.data + data=reference_model.v.weight.to(dtype=target_dtype) ), "attn_o.weight": DefaultPrimitiveTensor( - data=reference_model.o.weight.data + data=reference_model.o.weight.to(dtype=target_dtype) ), "attn_rel_b.weight": DefaultPrimitiveTensor( - data=reference_model.relative_attention_bias.weight.data + data=reference_model.relative_attention_bias.weight.to( + dtype=target_dtype + ) ), } ) @@ -257,24 +510,52 @@ def testCompareAgainstTransformersFp32(self): d_model=reference_config.d_model, d_kv=reference_config.d_kv, num_heads=reference_config.num_heads, - activation_dtype=dtype, + activation_dtype=target_dtype, has_relative_attention_bias=True, ) model.eval() - hidden_states = make_rand_torch( - shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype + reference_hidden_states = make_rand_torch( + shape=[batch_size, batch_seq_len, reference_config.d_model], + dtype=reference_dtype, + ) + reference_mask = make_random_mask( + shape=[batch_size, 1, 1, batch_seq_len], dtype=reference_dtype ) - mask = make_random_mask(shape=[batch_size, 1, 1, batch_seq_len], dtype=dtype) - expected_outputs = reference_model(hidden_states=hidden_states, mask=mask) + expected_outputs = reference_model( + hidden_states=reference_hidden_states, mask=reference_mask + ) + + hidden_states = ops.to(reference_hidden_states, dtype=target_dtype) + mask = ops.to(reference_mask, dtype=target_dtype) actual_outputs = model( hidden_states=DefaultPrimitiveTensor(data=hidden_states), mask=DefaultPrimitiveTensor(data=mask), ) - torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0) + actual_outputs = tree_map( + lambda t: None if t is None else ops.to(t, dtype=reference_dtype), + actual_outputs, + ) - def testCompareSelfAttentionAgainstTransformersFp32(self): - dtype = torch.float32 + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) + + @parameterized.expand( + [ + [torch.float32, torch.float32], + [torch.bfloat16, torch.bfloat16], + [torch.float32, torch.bfloat16, 1e-2, 1.6e-2], + ] + ) + def testCompareSelfAttentionAgainstTransformers( + self, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + 
torch.set_default_dtype(reference_dtype) batch_size = 19 batch_seq_len = 23 reference_config = ReferenceT5Config( @@ -296,22 +577,24 @@ def testCompareSelfAttentionAgainstTransformersFp32(self): theta = Theta( { "attn_q.weight": DefaultPrimitiveTensor( - data=reference_model.SelfAttention.q.weight.data + data=reference_model.SelfAttention.q.weight.to(dtype=target_dtype) ), "attn_k.weight": DefaultPrimitiveTensor( - data=reference_model.SelfAttention.k.weight.data + data=reference_model.SelfAttention.k.weight.to(dtype=target_dtype) ), "attn_v.weight": DefaultPrimitiveTensor( - data=reference_model.SelfAttention.v.weight.data + data=reference_model.SelfAttention.v.weight.to(dtype=target_dtype) ), "attn_o.weight": DefaultPrimitiveTensor( - data=reference_model.SelfAttention.o.weight.data + data=reference_model.SelfAttention.o.weight.to(dtype=target_dtype) ), "attn_rel_b.weight": DefaultPrimitiveTensor( - data=reference_model.SelfAttention.relative_attention_bias.weight.data + data=reference_model.SelfAttention.relative_attention_bias.weight.to( + dtype=target_dtype + ) ), "attn_norm.weight": DefaultPrimitiveTensor( - data=reference_model.layer_norm.weight.data + data=reference_model.layer_norm.weight.to(dtype=target_dtype) ), } ) @@ -323,24 +606,37 @@ def testCompareSelfAttentionAgainstTransformersFp32(self): d_model=reference_config.d_model, d_kv=reference_config.d_kv, num_heads=reference_config.num_heads, - activation_dtype=dtype, + activation_dtype=torch.float32, layer_norm_epsilon=reference_config.layer_norm_epsilon, has_relative_attention_bias=True, ) model.eval() - hidden_states = make_rand_torch( - shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype + reference_hidden_states = make_rand_torch( + shape=[batch_size, batch_seq_len, reference_config.d_model], + dtype=reference_dtype, ) - mask = make_random_mask(shape=[batch_size, 1, 1, batch_seq_len], dtype=dtype) - position_bias = make_rand_torch( - shape=[batch_size, reference_config.num_heads, batch_seq_len, batch_seq_len] + reference_mask = make_random_mask( + shape=[batch_size, 1, 1, batch_seq_len], dtype=reference_dtype + ) + reference_position_bias = make_rand_torch( + shape=[ + batch_size, + reference_config.num_heads, + batch_seq_len, + batch_seq_len, + ], + dtype=reference_dtype, ) expected_outputs = reference_model( - hidden_states=hidden_states, - attention_mask=mask, - position_bias=position_bias, + hidden_states=reference_hidden_states, + attention_mask=reference_mask, + position_bias=reference_position_bias, ) + + hidden_states = ops.to(reference_hidden_states, dtype=target_dtype) + mask = ops.to(reference_mask, dtype=target_dtype) + position_bias = ops.to(reference_position_bias, dtype=target_dtype) actual_outputs = model( hidden_states=DefaultPrimitiveTensor(data=hidden_states), attention_mask=DefaultPrimitiveTensor(data=mask), @@ -349,7 +645,14 @@ def testCompareSelfAttentionAgainstTransformersFp32(self): actual_outputs = [ unbox_tensor(t) if t is not None else t for t in actual_outputs ] - torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0) + actual_outputs = tree_map( + lambda t: None if t is None else ops.to(t, dtype=reference_dtype), + actual_outputs, + ) + + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) class T5LayerFFTest(TestCase): @@ -358,8 +661,21 @@ def setUp(self): torch.random.manual_seed(12345) torch.no_grad() - def testCompareAgainstTransformersFp32(self): - dtype = torch.float32 + @parameterized.expand( + [ + 
[torch.float32, torch.float32], + [torch.bfloat16, torch.bfloat16], + [torch.float32, torch.bfloat16, 1e-2, 1.6e-2], + ] + ) + def testCompareAgainstTransformers( + self, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + torch.set_default_dtype(reference_dtype) batch_size = 19 batch_seq_len = 23 reference_config = ReferenceT5Config( @@ -376,16 +692,20 @@ def testCompareAgainstTransformersFp32(self): theta = Theta( { "ffn_gate.weight": DefaultPrimitiveTensor( - data=reference_model.DenseReluDense.wi_0.weight + data=reference_model.DenseReluDense.wi_0.weight.to( + dtype=target_dtype + ) ), "ffn_up.weight": DefaultPrimitiveTensor( - data=reference_model.DenseReluDense.wi_1.weight + data=reference_model.DenseReluDense.wi_1.weight.to( + dtype=target_dtype + ) ), "ffn_down.weight": DefaultPrimitiveTensor( - data=reference_model.DenseReluDense.wo.weight + data=reference_model.DenseReluDense.wo.weight.to(dtype=target_dtype) ), "ffn_norm.weight": DefaultPrimitiveTensor( - data=reference_model.layer_norm.weight + data=reference_model.layer_norm.weight.to(dtype=target_dtype) ), } ) @@ -394,17 +714,24 @@ def testCompareAgainstTransformersFp32(self): is_gated_act=reference_config.is_gated_act, dense_act_fn=reference_config.dense_act_fn, layer_norm_epsilon=reference_config.layer_norm_epsilon, - activation_dtype=dtype, + activation_dtype=torch.float32, ) - hidden_states = make_rand_torch( - shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype + reference_hidden_states = make_rand_torch( + shape=[batch_size, batch_seq_len, reference_config.d_model], + dtype=reference_dtype, ) - expected_output = reference_model( - hidden_states=hidden_states, + hidden_states=reference_hidden_states, ) + + hidden_states = ops.to(reference_hidden_states, dtype=target_dtype) actual_output = model( hidden_states=DefaultPrimitiveTensor(data=hidden_states), ) - torch.testing.assert_close(actual_output, expected_output, atol=1e-5, rtol=0) + actual_output = tree_map( + lambda t: None if t is None else ops.to(t, dtype=reference_dtype), + actual_output, + ) + + torch.testing.assert_close(actual_output, expected_output, atol=atol, rtol=rtol) diff --git a/sharktank/tests/serving_poc/framework/device_session_test.py b/sharktank/tests/serving_poc/framework/device_session_test.py deleted file mode 100644 index 5dfdd5f46..000000000 --- a/sharktank/tests/serving_poc/framework/device_session_test.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
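
The parameterized T5 comparison tests above all follow the same recipe: build inputs in the reference dtype, cast weights and inputs to the target dtype for the model under test, run both, cast the target outputs back to the reference dtype, and only then call torch.testing.assert_close with tolerances chosen per dtype pair (1e-4/2.0e-3 for f32 vs f32, 1e-2/1.6e-2 when bfloat16 is involved). A minimal sketch of that recipe; the two model arguments and the input shape here are placeholders, not names from the diff:

    import torch

    def compare_against_reference(reference_model, target_model,
                                  reference_dtype=torch.float32,
                                  target_dtype=torch.bfloat16,
                                  atol=1e-2, rtol=1.6e-2):
        # Inputs are created once in the reference dtype so both runs see the
        # same values.
        x = torch.rand(19, 23, 64, dtype=reference_dtype)
        expected = reference_model(x)
        # The model under test runs in its own dtype ...
        actual = target_model(x.to(dtype=target_dtype))
        # ... and its output is cast back before comparison, so assert_close
        # measures numerical error rather than a dtype mismatch (this assumes a
        # single-tensor output; the real tests map the cast over result lists).
        actual = actual.to(dtype=reference_dtype)
        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
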
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import pytest - -from sharktank.serving_poc.framework.session import ( - DeviceSession, -) - - -@pytest.fixture -def local_device_session(): - session = DeviceSession(uri="local-task") - yield session - session.shutdown() - - -def test_start_shutdown_no_host_contexts(local_device_session: DeviceSession): - ms = local_device_session.create_module_set("default") - ms.initialize() - - -def test_host_context_start_stop(local_device_session: DeviceSession): - ms = local_device_session.create_module_set("default") - ms.initialize() - hc = ms.host_context - - -def test_host_context_scheduling(local_device_session: DeviceSession): - device = local_device_session.device - ms = local_device_session.create_module_set("default") - ms.initialize() - hc = ms.host_context - - sem = device.create_semaphore(0) - - async def task1(): - print("[coro1] test_host_context_scheduling.task") - await hc.on_semaphore(sem, 1, True) - print("[coro1] await completed") - sem.signal(2) - - async def task2(): - print("[coro2] waiting for 2") - await hc.on_semaphore(sem, 2, True) - sem.fail("Fail from task2") - - f1 = hc.run_concurrent(task1()) - f2 = hc.run_concurrent(task2()) - sem.signal(1) - print("[main] Waiting for semaphore") - - # Ensure task completion. Important to consume to ensure that exceptions - # propagate. - f1.result() - f2.result() - - print("[main] Waiting on semaphore payload 3") - with pytest.raises(Exception, match="Fail from task2"): - sem.wait(3) diff --git a/sharktank/tests/serving_poc/llm/api_server_test.py b/sharktank/tests/serving_poc/llm/api_server_test.py deleted file mode 100644 index c2d2cc36a..000000000 --- a/sharktank/tests/serving_poc/llm/api_server_test.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import os -from contextlib import closing -from pathlib import Path -import pytest -import requests -import socket -import subprocess -import sys -import time - - -def find_free_port(): - """This tries to find a free port to run a server on for the test. - - Race conditions are possible - the port can be acquired between when this - runs and when the server starts. - - https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number - """ - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(("localhost", 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return s.getsockname()[1] - - -class ServerRunner: - def __init__(self, args): - port = str(find_free_port()) - self.url = "http://localhost:" + port - env = os.environ.copy() - env["PYTHONUNBUFFERED"] = "1" - self.process = subprocess.Popen( - [ - sys.executable, - "-m", - "sharktank.serving_poc.llm.api.rest_server", - "--testing-mock-service", - "--port=" + port, - ] - + args, - env=env, - # TODO: Have a more robust way of forking a subprocess. 
- cwd=str(Path(__file__).resolve().parent.parent.parent), - stdout=sys.stdout, - stderr=sys.stderr, - ) - self._wait_for_ready() - - def _wait_for_ready(self): - start = time.time() - while True: - try: - if requests.get(f"{self.url}/health").status_code == 200: - return - except Exception as e: - if self.process.poll() is not None: - raise RuntimeError("API server processs terminated") from e - time.sleep(1.0) - if (time.time() - start) > 30: - raise RuntimeError("Timeout waiting for server start") - - def __del__(self): - try: - process = self.process - except AttributeError: - pass - else: - process.terminate() - process.wait() - - -@pytest.fixture(scope="session") -def server(): - try: - import fastapi - import uvicorn - except ModuleNotFoundError as e: - pytest.skip(f"Skipping server test because deps are missing: {e}") - runner = ServerRunner([]) - yield runner - - -def test_health(server: ServerRunner): - # Health check is part of getting the fixture. - ... - - -def test_generate_non_streaming(server: ServerRunner): - resp = requests.post( - f"{server.url}/generate", - json={ - "prompt": "Hi Bob", - }, - ) - resp.raise_for_status() - d = resp.json() - assert d["text"] == "Hi Bob", repr(d) - - -def test_generate_streaming(server: ServerRunner): - resp = requests.post( - f"{server.url}/generate", json={"prompt": "Hi Bob!", "stream": True} - ) - resp.raise_for_status() - full_contents = resp.content - expected_contents = b'{"text": "Hi Bob!"}\x00' * 5 - assert ( - full_contents == expected_contents - ), f"Expected {expected_contents!r} vs {full_contents!r}" diff --git a/sharktank/tests/serving_poc/llm/service_v1_test.py b/sharktank/tests/serving_poc/llm/service_v1_test.py deleted file mode 100644 index c010e2034..000000000 --- a/sharktank/tests/serving_poc/llm/service_v1_test.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import pytest - -from iree.runtime import ( # type: ignore - HalElementType, -) - -from sharktank.serving_poc.framework.session import DeviceSession -from sharktank.serving_poc.llm.config import ( - CacheParams, - ModelParams, - ServiceParams, -) - -from sharktank.serving_poc.llm.service import ( - GenerateRequest, - GenerateResponsePart, -) - -from sharktank.serving_poc.llm.attn_block_cache import ( - create_attn_block_cache_module, - AttnBlockCache, -) - -from sharktank.serving_poc.llm.impl.service_v1 import ( - GenerateServiceV1, -) - -from sharktank.serving_poc.llm.testing.fake_v1_module import ( - create_fake_module, -) - - -@pytest.fixture -def cache_params(model_params: ModelParams) -> CacheParams: - return CacheParams(model=model_params, device_block_count=128, block_pos_stride=16) - - -@pytest.fixture -def model_params() -> ModelParams: - return ModelParams( - module_name="AwesomeLLM", - module_abi_version=1, - attn_dtype=HalElementType.FLOAT_16, - max_seq_len=128, - transformer_block_count=32, - attn_head_count=32, - attn_head_dim=128, - block_seq_stride=16, - prefill_batch_sizes=[1, 4, 16], - decode_batch_sizes=[1, 4, 16], - ) - - -@pytest.fixture -def uninitialized_session(model_params: ModelParams): - from iree.runtime._binding import disable_leak_checker # type: ignore - - disable_leak_checker() - session = DeviceSession(uri="local-task", queue_count=2) - yield session - session.shutdown() - del session - - -@pytest.fixture -def attn_block_cache( - uninitialized_session: DeviceSession, cache_params: CacheParams -) -> AttnBlockCache: - return AttnBlockCache(uninitialized_session, cache_params) - - -@pytest.fixture -def session( - model_params: ModelParams, - uninitialized_session: DeviceSession, - attn_block_cache: AttnBlockCache, -): - session = uninitialized_session - lms = session.create_module_set("AwesomeLLM", context_count=1) - lms.add( - create_attn_block_cache_module(attn_block_cache), - create_fake_module(session.device, "AwesomeLLM", model_params=model_params), - ) - lms.initialize() - return session - - -@pytest.fixture -def service( - session: DeviceSession, - cache_params: CacheParams, - model_params: ModelParams, - attn_block_cache: AttnBlockCache, -): - params = ServiceParams(cache=cache_params, model=model_params) - return GenerateServiceV1(session=session, params=params, cache=attn_block_cache) - - -def test_single(service: GenerateServiceV1): - state = service.start() - - async def task(): - await state.set_sequences( - requests=[ - GenerateRequest( - "1", - "hello, tell me a story", - [3, 4, 5, 12, 23, 88, 10, 2, 5, 9, 12, 13, 99, 56, 33, 124, 73], - ), - GenerateRequest("2", "goodbye", [9, 10]), - ] - ) - guarded_outputs = await state.prefill() - prefill_ids = await guarded_outputs.resolve(state.host_context) - print( - "PREFILL IDS:", - prefill_ids, - ":\n", - prefill_ids.map().asarray( - prefill_ids.shape, HalElementType.map_to_dtype(prefill_ids.element_type) - ), - ) - await state.recycle() - - state.host_context.run_sync(task()) diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt index f025eccfe..16baa1675 100644 --- a/shortfin/CMakeLists.txt +++ b/shortfin/CMakeLists.txt @@ -48,6 +48,7 @@ option(SHORTFIN_BUILD_TESTS "Builds C++ tests" ON) option(SHORTFIN_BUNDLE_DEPS "Download dependencies instead of using system libraries" ON) option(SHORTFIN_ENABLE_TRACING "Enable runtime tracing for iree and shortfin" OFF) option(SHORTFIN_ENABLE_LTO "Enables LTO if supported" ON) 
+option(SHORTFIN_ENABLE_TOKENIZERS "Enables integration of native tokenizers library" OFF) set(SHORTFIN_IREE_SOURCE_DIR "" CACHE FILEPATH "Path to IREE source") @@ -80,6 +81,7 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/build_tools/cmake/ ) include(shortfin_library) +include(shortfin_testing) include(CheckCXXCompilerFlag) include(FetchContent) @@ -90,7 +92,9 @@ include(FetchContent) if(SHORTFIN_ENABLE_LTO) include(CheckIPOSupported) check_ipo_supported(RESULT SHORTFIN_LTO_SUPPORTED OUTPUT SHORTFIN_LTO_ERROR) - if(SHORTFIN_LTO_SUPPORTED) + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + message(STATUS "Not enabling LTO for debug build") + elseif(SHORTFIN_LTO_SUPPORTED) message(STATUS "Shortfin LTO Enabled") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) else() @@ -126,7 +130,9 @@ endif() message(STATUS " - Host") ################################################################################ -# Dependencies +# Bundled Dependencies +# These dependencies are either bundled or used via installed packages based +# on the SHORTFIN_BUNDLE_DEPS option. ################################################################################ if(SHORTFIN_BUNDLE_DEPS) @@ -164,15 +170,19 @@ if(SHORTFIN_BUNDLE_DEPS) shortfin_push_bundled_lib_options() # Enable spdlog shared library options so we can export it. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPDLOG_SHARED_LIB -Dspdlog_EXPORTS") + message(STATUS "Fetching bundled projects") + list(APPEND CMAKE_MESSAGE_INDENT " ") FetchContent_MakeAvailable(fmt spdlog xtl xtensor) shortfin_pop_bundled_lib_options() + list(POP_BACK CMAKE_MESSAGE_INDENT) else() find_package(spdlog) find_package(xtensor) endif() ################################################################################ -# IREE +# IREE Dependency +# This is always a source dependency on the IREE runtime. ################################################################################ # Set IREE build flags. @@ -237,6 +247,65 @@ else() endif() shortfin_pop_bundled_lib_options() +################################################################################ +# Tokenizer Library +################################################################################ + +function(shortfin_check_tokenizers) + # Make sure that rust/cargo is installed and usable. + # Consider switching this to a cached variable once the tokenizers_cpp project + # will accept an override vs running whatever is on the path. For now, just + # verify the path is sane as that is what will get used. + find_program(SHORTFIN_CARGO_PATH NAMES cargo NO_CACHE) + if(NOT SHORTFIN_CARGO_PATH) + message(SEND_ERROR + "Building with -DSHORTFIN_ENABLE_TOKENIZERS=ON requires cargo (Rust's build tool). " + "Please follow Rust documentation to install. On Ubuntu, this can typically be accomplished with:\n" + " sudo apt install rustup && rustup default stable\n" + "See https://www.rust-lang.org/tools/install" + ) + endif() + + # Make sure cargo is functional. + execute_process( + COMMAND ${SHORTFIN_CARGO_PATH} + RESULT_VARIABLE _CARGO_RESULT + OUTPUT_VARIABLE _CARGO_OUT + ERROR_VARIABLE _CARGO_ERR + ) + if(NOT "${_CARGO_RESULT}" STREQUAL "0") + message(SEND_ERROR + "Building with -DSHORTFIN_ENABLE_TOKENIZERS=ON requires cargo (Rust's build tool) " + "to be configured properly. It was found (${SHORTFIN_CARGO_PATH}) but returned an " + "error. 
Output below:\n" + "${_CARGO_OUT}\n" + "${_CARGO_ERR}" + ) + endif() +endfunction() + +if(SHORTFIN_ENABLE_TOKENIZERS) + # TODO: submit a patch to tokenizers_cpp to allow explicit configuration of the + # cargo location and pass that vs relying on environmental alignment. + shortfin_check_tokenizers() + + shortfin_push_bundled_lib_options() + set(CMAKE_C_VISIBILITY_PRESET "hidden") + set(CMAKE_CXX_VISIBILITY_PRESET "hidden") + set(CMAKE_VISIBILITY_INLINES_HIDDEN ON) + set(MLC_ENABLE_SENTENCEPIECE_TOKENIZER OFF) + + FetchContent_Declare( + tokenizers_cpp # From CMake project() declaration + GIT_REPOSITORY https://github.com/mlc-ai/tokenizers-cpp.git + GIT_TAG 4bb753377680e249345b54c6b10e6d0674c8af03 # 2024 Nov 15 + EXCLUDE_FROM_ALL + ) + message(STATUS "Fetching tokenizers_cpp") + FetchContent_MakeAvailable(tokenizers_cpp) + shortfin_pop_bundled_lib_options() +endif() + ################################################################################ # Tests ################################################################################ @@ -254,9 +323,9 @@ if(SHORTFIN_BUILD_TESTS) endif() include(GoogleTest) enable_testing() + add_custom_target(shortfin_testdata_deps) endif() - add_subdirectory(src) if(SHORTFIN_BUILD_PYTHON_BINDINGS) diff --git a/shortfin/build_tools/cmake/shortfin_library.cmake b/shortfin/build_tools/cmake/shortfin_library.cmake index aaa97a6c1..103fdf1c5 100644 --- a/shortfin/build_tools/cmake/shortfin_library.cmake +++ b/shortfin/build_tools/cmake/shortfin_library.cmake @@ -182,7 +182,10 @@ function(shortfin_gtest_test) GTest::gmock GTest::gtest_main ) - gtest_discover_tests(${_RULE_NAME}) + gtest_discover_tests( + ${_RULE_NAME} + WORKING_DIRECTORY "${libshortfin_BINARY_DIR}" + ) endfunction() diff --git a/shortfin/build_tools/cmake/shortfin_testing.cmake b/shortfin/build_tools/cmake/shortfin_testing.cmake new file mode 100644 index 000000000..e462b7023 --- /dev/null +++ b/shortfin/build_tools/cmake/shortfin_testing.cmake @@ -0,0 +1,50 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Downloads some test data file as part of configure. +# This does a download->rename in an attempt to be robust to partial downloads. +# It should not be used to manage large test data files or anything sensitive +# enough to require a hash check. +# The output file is added as an additional clean file on the global +# shortfin_testdata_deps target, meaning the "ninja clean" will remove it. +# It is also added to the current directories list of configure depends, which +# means that if ninja is run and it is not present, cmake will be re-invoked. +function(shortfin_download_test_data) + cmake_parse_arguments( + _RULE + "" + "URL;OUTPUT_FILE" + "" + ${ARGN} + ) + if(NOT SHORTFIN_BUILD_TESTS) + return() + endif() + if(NOT EXISTS "${_RULE_OUTPUT_FILE}") + set(_stage_file "${_RULE_OUTPUT_FILE}.stage") + message(STATUS "Downloading test data ${_RULE_URL} -> ${_RULE_OUTPUT_FILE}") + file(DOWNLOAD "${_RULE_URL}" "${_stage_file}" STATUS _status) + list(POP_FRONT _status _status_code) + if(_status_code EQUAL "0") + file(RENAME "${_stage_file}" "${_RULE_OUTPUT_FILE}") + else() + message(SEND_ERROR "Error downloading file ${_RULE_URL} -> ${_RULE_OUTPUT_FILE}") + endif() + endif() + + # Make clean remove it. 
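
The download helper above stages the file under a temporary name and only renames it into place on success, so an interrupted configure never leaves a truncated file behind under the final name. A rough Python equivalent of the same stage-then-rename idea, purely for illustration (the function name and the urllib usage are not part of the build system):

    import os
    import urllib.request

    def download_atomically(url: str, output_file: str) -> None:
        if os.path.exists(output_file):
            return  # Already downloaded; nothing to do.
        stage_file = output_file + ".stage"
        # Write to a staging path first so a partial download never occupies
        # the final name.
        urllib.request.urlretrieve(url, stage_file)
        # The rename is the commit step; os.replace is atomic on a single filesystem.
        os.replace(stage_file, output_file)
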
+ set_property( + TARGET shortfin_testdata_deps + APPEND PROPERTY ADDITIONAL_CLEAN_FILES + "${_RULE_OUTPUT_FILE}" + ) + + # And make us reconfigure if it isn't there. + set_property( + DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" + APPEND PROPERTY + CMAKE_CONFIGURE_DEPENDS "${_RULE_OUTPUT_FILE}") +endfunction() diff --git a/shortfin/python/array_binding.cc b/shortfin/python/array_binding.cc index a05232674..08a4071a8 100644 --- a/shortfin/python/array_binding.cc +++ b/shortfin/python/array_binding.cc @@ -531,22 +531,26 @@ void BindArray(py::module_ &m) { ->AddAsInvocationArgument( inv, static_cast(barrier)); }) - .def_static("for_device", - [](local::ScopedDevice &device, std::span shape, - DType dtype) { - return custom_new_keep_alive( - py::type(), - /*keep_alive=*/device.fiber(), - device_array::for_device(device, shape, dtype)); - }) - .def_static("for_host", - [](local::ScopedDevice &device, std::span shape, - DType dtype) { - return custom_new_keep_alive( - py::type(), - /*keep_alive=*/device.fiber(), - device_array::for_host(device, shape, dtype)); - }) + .def_static( + "for_device", + [](local::ScopedDevice &device, std::span shape, + DType dtype) { + return custom_new_keep_alive( + py::type(), + /*keep_alive=*/device.fiber(), + device_array::for_device(device, shape, dtype)); + }, + py::arg("device"), py::arg("shape"), py::arg("dtype")) + .def_static( + "for_host", + [](local::ScopedDevice &device, std::span shape, + DType dtype) { + return custom_new_keep_alive( + py::type(), + /*keep_alive=*/device.fiber(), + device_array::for_host(device, shape, dtype)); + }, + py::arg("device"), py::arg("shape"), py::arg("dtype")) .def("for_transfer", [](device_array &self) { return custom_new_keep_alive( diff --git a/shortfin/python/array_host_ops.cc b/shortfin/python/array_host_ops.cc index 86385cfee..3e2a8ebe3 100644 --- a/shortfin/python/array_host_ops.cc +++ b/shortfin/python/array_host_ops.cc @@ -91,6 +91,18 @@ static const char DOCSTRING_RANDOM_GENERATOR[] = fixed number. )"; +static const char DOCSTRING_TRANSPOSE[] = + R"(Transposes axes of an array according to a permutation vector. + +Args: + input: Array to transpose. + permutation: New sequence of axes. Must have same number of elements as the + rank of input. + out: If given, then the results are written to this array. + device_visible: Whether to make the result array visible to devices. Defaults + to False. 
+)"; + #define SF_UNARY_FUNCTION_CASE(dtype_name, cpp_type) \ case DType::dtype_name(): \ return compute.template operator()() @@ -100,6 +112,25 @@ static const char DOCSTRING_RANDOM_GENERATOR[] = compute.template operator()(); \ break +#define SF_MOVEMENT_OP_SWITCH(dtype) \ + if (!dtype.is_byte_aligned()) \ + throw std::invalid_argument( \ + "data movement ops are only defined for byte aligned dtypes"); \ + switch (dtype.dense_byte_count()) { \ + case 1: \ + return compute.template operator()(); \ + case 2: \ + return compute.template operator()(); \ + case 4: \ + return compute.template operator()(); \ + case 8: \ + return compute.template operator()(); \ + default: \ + throw std::invalid_argument( \ + "data movement ops are only defined for dtypes of size 1, 2, " \ + "4, 8"); \ + } + struct PyRandomGenerator { public: using SeedType = xt::random::default_engine_type::result_type; @@ -374,6 +405,227 @@ struct ConvertTruncFunctor { } }; +void OptionalArrayCast(py::handle handle, + std::optional &maybe_array) { + if (py::isinstance(handle)) { + maybe_array.emplace(py::cast(handle)); + } +} + +int DTypePromotionRank(DType dtype) { + int rank = 1; + if (dtype.is_boolean()) + rank *= 1000; + else if (dtype.is_integer()) + rank *= 2000; + else if (dtype.is_float()) + rank *= 4000; + else if (dtype.is_complex()) + rank *= 8000; + return rank + dtype.bit_count(); +} + +DType PromoteArithmeticTypes(std::optional lhs_dtype, + std::optional rhs_dtype) { + if (!lhs_dtype && !rhs_dtype) { + throw std::invalid_argument( + "Elementwise operators require at least one argument to be a " + "device_array"); + } + + // One not an array: promote to the array type. + if (!lhs_dtype) + return *rhs_dtype; + else if (!rhs_dtype) + return *lhs_dtype; + + int lhs_rank = DTypePromotionRank(*lhs_dtype); + int rhs_rank = DTypePromotionRank(*rhs_dtype); + DType promoted_dtype = lhs_rank < rhs_rank ? *rhs_dtype : *lhs_dtype; + + // If mismatched signed/unsigned, then need to promote to the next signed + // dtype. + if (promoted_dtype.is_integer()) { + bool lhs_unsigned = iree_all_bits_set( + lhs_dtype->numerical_type(), IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED); + bool rhs_unsigned = iree_all_bits_set( + rhs_dtype->numerical_type(), IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED); + if ((lhs_unsigned || rhs_unsigned) && !(lhs_unsigned && rhs_unsigned)) { + // Signed/unsigned mismatch. Promote to next. + switch (promoted_dtype) { + case DType::uint8(): + case DType::int8(): + return DType::int16(); + case DType::uint16(): + case DType::int16(): + return DType::int32(); + case DType::uint32(): + case DType::int32(): + return DType::int64(); + default: + // Jax's type promotion chart says this goes to a weak FP type, but + // we don't implement such a construct and I don't really see how + // that makes sense in a system setting like this, so we just saturate + // to 64bit. + return DType::int64(); + } + } + } + + return promoted_dtype; +} + +// ---------------------------------------------------------------------------// +// Elementwise support +// ---------------------------------------------------------------------------// + +// Python element type scalar conversion functions. 
+uint8_t ConvertPyToEltTy(py::handle py_value, uint8_t zero) { + return py::cast(py_value); +} + +int8_t ConvertPyToEltTy(py::handle py_value, int8_t zero) { + return py::cast(py_value); +} + +uint16_t ConvertPyToEltTy(py::handle py_value, uint16_t zero) { + return py::cast(py_value); +} + +int16_t ConvertPyToEltTy(py::handle py_value, int16_t zero) { + return py::cast(py_value); +} + +uint32_t ConvertPyToEltTy(py::handle py_value, uint32_t zero) { + return py::cast(py_value); +} + +int32_t ConvertPyToEltTy(py::handle py_value, int32_t zero) { + return py::cast(py_value); +} + +uint64_t ConvertPyToEltTy(py::handle py_value, uint64_t zero) { + return py::cast(py_value); +} + +int64_t ConvertPyToEltTy(py::handle py_value, int64_t zero) { + return py::cast(py_value); +} + +float ConvertPyToEltTy(py::handle py_value, float zero) { + return py::cast(py_value); +} + +double ConvertPyToEltTy(py::handle py_value, double zero) { + return py::cast(py_value); +} + +half_float::half ConvertPyToEltTy(py::handle py_value, half_float::half zero) { + // Python can't cast directly to half so first go to double. + return static_cast(py::cast(py_value)); +} + +struct AddFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs + rhs; + } +}; + +struct DivideFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs / rhs; + } +}; + +struct MultiplyFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs * rhs; + } +}; + +struct SubtractFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs - rhs; + } +}; + +template +device_array ElementwiseOperation(py::handle lhs, py::handle rhs, + std::optional out, + bool device_visible) { + std::optional lhs_array; + OptionalArrayCast(lhs, lhs_array); + std::optional rhs_array; + OptionalArrayCast(rhs, rhs_array); + auto dtype = PromoteArithmeticTypes( + lhs_array ? std::optional(lhs_array->dtype()) : std::nullopt, + rhs_array ? 
std::optional(rhs_array->dtype()) : std::nullopt); + if (lhs_array && lhs_array->dtype() != dtype) { + auto converted = GenericElementwiseConvert( + *lhs_array, dtype, /*out=*/std::nullopt, + /*device_visible=*/false); + lhs_array.reset(); + lhs_array.emplace(std::move(converted)); + } + if (rhs_array && rhs_array->dtype() != dtype) { + auto converted = GenericElementwiseConvert( + *rhs_array, dtype, /*out=*/std::nullopt, + /*device_visible=*/false); + rhs_array.reset(); + rhs_array.emplace(std::move(converted)); + } + + auto compute = [&]() -> device_array { + auto handle_result = [&]( + D &&device, A &&result) -> device_array { + if (!out) { + out.emplace(device_array::for_host(device, result.shape(), dtype, + device_visible)); + } + auto out_t = out->map_xtensor_w(); + *out_t = result; + return *out; + }; + if (!rhs_array) { + auto lhs_t = lhs_array->map_xtensor(); + xt::xarray rhs_scalar = ConvertPyToEltTy(rhs, EltTy()); + return handle_result(lhs_array->device(), + ElementwiseFunctor::Invoke(*lhs_t, rhs_scalar)); + } else if (!lhs_array) { + xt::xarray lhs_scalar = ConvertPyToEltTy(lhs, EltTy()); + auto rhs_t = rhs_array->map_xtensor(); + return handle_result(rhs_array->device(), + ElementwiseFunctor::Invoke(lhs_scalar, *rhs_t)); + } else { + auto lhs_t = lhs_array->map_xtensor(); + auto rhs_t = rhs_array->map_xtensor(); + return handle_result(lhs_array->device(), + ElementwiseFunctor::Invoke(*lhs_t, *rhs_t)); + } + }; + + switch (dtype) { + SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(float32, float); + SF_UNARY_FUNCTION_CASE(float64, double); + SF_UNARY_FUNCTION_CASE(uint8, uint8_t); + SF_UNARY_FUNCTION_CASE(int8, int8_t); + SF_UNARY_FUNCTION_CASE(uint16, uint16_t); + SF_UNARY_FUNCTION_CASE(int16, int16_t); + SF_UNARY_FUNCTION_CASE(uint32, uint32_t); + SF_UNARY_FUNCTION_CASE(int32, uint32_t); + SF_UNARY_FUNCTION_CASE(uint64, uint64_t); + SF_UNARY_FUNCTION_CASE(int64, int64_t); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for in elementwise op", dtype.name())); + } +} + } // namespace void BindArrayHostOps(py::module_ &m) { @@ -457,6 +709,39 @@ void BindArrayHostOps(py::module_ &m) { SF_DEF_CONVERT("floor", GenericElementwiseConvert); SF_DEF_CONVERT("round", GenericElementwiseConvert); SF_DEF_CONVERT("trunc", GenericElementwiseConvert); -} + + // Transpose. + m.def( + "transpose", + [](device_array input, std::vector permutation, + std::optional out, bool device_visible) { + auto compute = [&]() -> device_array { + auto input_t = input.map_xtensor(); + auto permuted_t = + xt::transpose(*input_t, permutation, xt::check_policy::full()); + if (!out) { + out.emplace(device_array::for_host(input.device(), + permuted_t.shape(), + input.dtype(), device_visible)); + } + auto out_t = out->map_xtensor_w(); + *out_t = permuted_t; + return *out; + }; + SF_MOVEMENT_OP_SWITCH(input.dtype()); + }, + py::arg("input"), py::arg("permutation"), py::arg("out") = py::none(), + py::arg("device_visible") = false, DOCSTRING_TRANSPOSE); + +// Elementwise. 
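
Together with the transpose binding just above and the SF_DEF_ELEMENTWISE registrations that follow, these ops surface in Python as shortfin.array functions (see the __init__.py hunk further down). A short usage sketch; the device/fiber setup is assumed to already exist and is not shown:

    import shortfin.array as sfnp

    # `device` is assumed to be a live sf.ScopedDevice from a running system.
    a = sfnp.device_array.for_host(device, [2, 3], dtype=sfnp.float32)
    with a.map(discard=True) as m:
        m.fill(21.0)

    doubled = sfnp.multiply(a, 2)        # scalar rhs broadcasts; values become 42.0
    total = sfnp.add(doubled, doubled)   # array/array form; values become 84.0
    out = sfnp.device_array.for_host(device, [2, 3], dtype=sfnp.float32)
    sfnp.subtract(total, 42.0, out=out)  # out= writes into an existing array

    flipped = sfnp.transpose(a, permutation=[1, 0])  # shape becomes [3, 2]
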
+#define SF_DEF_ELEMENTWISE(py_name, target) \ + m.def(py_name, target, py::arg("lhs"), py::arg("rhs"), py::kw_only(), \ + py::arg("out") = py::none(), py::arg("device_visible") = false) + SF_DEF_ELEMENTWISE("add", ElementwiseOperation); + SF_DEF_ELEMENTWISE("divide", ElementwiseOperation); + SF_DEF_ELEMENTWISE("multiply", ElementwiseOperation); + SF_DEF_ELEMENTWISE("subtract", ElementwiseOperation); + +} // namespace shortfin::python } // namespace shortfin::python diff --git a/shortfin/python/shortfin/array/__init__.py b/shortfin/python/shortfin/array/__init__.py index 6079541c8..670102dfe 100644 --- a/shortfin/python/shortfin/array/__init__.py +++ b/shortfin/python/shortfin/array/__init__.py @@ -44,11 +44,16 @@ # Ops. argmax = _sfl.array.argmax +add = _sfl.array.add ceil = _sfl.array.ceil convert = _sfl.array.convert +divide = _sfl.array.divide fill_randn = _sfl.array.fill_randn floor = _sfl.array.floor +multiply = _sfl.array.multiply round = _sfl.array.round +subtract = _sfl.array.subtract +transpose = _sfl.array.transpose trunc = _sfl.array.trunc RandomGenerator = _sfl.array.RandomGenerator @@ -86,12 +91,17 @@ "storage", "DType", # Ops. + "add", "argmax", "ceil", "convert", + "divide", "fill_randn", "floor", + "multiply", "round", + "subtract", + "transpose", "trunc", "RandomGenerator", ] diff --git a/shortfin/python/shortfin_apps/llm/_deps.py b/shortfin/python/shortfin_apps/llm/_deps.py index 7123d011e..fb8ca8176 100644 --- a/shortfin/python/shortfin_apps/llm/_deps.py +++ b/shortfin/python/shortfin_apps/llm/_deps.py @@ -5,13 +5,23 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from shortfin.support.deps import ShortfinDepNotFoundError +import sys -try: - import tokenizers -except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "tokenizers") from e +deps = [ + "tokenizers", + "dataclasses_json", +] -try: - import dataclasses_json -except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e +for dep in deps: + try: + __import__(dep) + except ModuleNotFoundError as e: + if "pytest" in sys.modules: + import pytest + + pytest.skip( + f"A test imports shortfin_apps.llm; skipping due to unavailable Shortfin LLM dependency: {dep}", + allow_module_level=True, + ) + else: + raise ShortfinDepNotFoundError(__name__, dep) from e diff --git a/shortfin/python/shortfin_apps/llm/components/cache.py b/shortfin/python/shortfin_apps/llm/components/cache.py deleted file mode 100644 index 12794498f..000000000 --- a/shortfin/python/shortfin_apps/llm/components/cache.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Sequence - -import logging -import math -import threading - -import shortfin as sf - -from .config_struct import ModelParams, human_size - -logger = logging.getLogger(__name__) - - -class AttnPageEntry: - __slots__ = [ - "cache", - "index", - "in_use", - ] - - def __init__(self, cache: "AttnPageCache", index: int): - self.cache = cache - self.index = index - self.in_use = False - - def __repr__(self): - return f"Block({self.index}, {'FREE' if not self.in_use else 'BUSY'})" - - -class AttnPageCache: - """Page table based attention cache. 
- - While internal to a model, the cache is organized with additional structure - per page, outside of the model, it is just a list of pages of a certain - element type and number of elements (all inner dims are flattened). - - One page table is allocated per device in a fiber. Currently, this is a - dense allocation with committed memory but in the future, we may just - allocate the address space and lazily populate it with committed memory. - - The cache is unique because usage of it can span fibers and concurrency - is implicitly managed at the block level (i.e. freshly acquired blocks - are assumed to be uninitialized and available immediately for use). - - It is initialized with a discrete list of fiberd devices from a fiber but - cache usage can be done from any fiber which includes those devices. - """ - - def __init__( - self, *, devices: Sequence[sf.ScopedDevice], model_params: ModelParams - ): - self._lock = threading.Lock() - self.devices = list(devices) - self.model_params = model_params - self.page_tables: list[sf.array.device_array] = [] - cache_params = model_params.paged_kv_cache - alloc_page_count = cache_params.device_block_count - - # Setup accounting structs. - self.attn_page_entries = [ - AttnPageEntry(self, i) for i in range(alloc_page_count) - ] - self.attn_page_free = list(self.attn_page_entries) - - # Initialize a page table on each device. - assert cache_params is not None, "Model does not have a paged kv cache" - page_table_shape = [ - alloc_page_count, - model_params.paged_kv_block_size_elements, - ] - for device in devices: - logging.info( - "Allocating page table (shape=%r, dtype=%r, size=%s) on %r", - page_table_shape, - model_params.attn_dtype, - human_size( - math.prod(page_table_shape) - * model_params.attn_dtype.dense_byte_count - ), - device, - ) - page_table = sf.array.device_array.for_device( - device, page_table_shape, model_params.attn_dtype - ) - self.page_tables.append(page_table) - - def acquire_free_pages(self, count: int) -> list[AttnPageEntry] | None: - with self._lock: - available = len(self.attn_page_free) - if count > available: - return None - return [self.attn_page_free.pop() for _ in range(count)] - - def release_pages(self, pages: list[AttnPageEntry]): - with self._lock: - self.attn_page_free.extend(pages) - - def __repr__(self): - # No need to lock for repr (list is internally synchronized). - free_pages = len(self.attn_page_free) - total_pages = len(self.attn_page_entries) - return ( - f"AttnPageCache({total_pages - free_pages}/{total_pages} pages in use: " - f"{100.0 * free_pages / total_pages}% free)" - ) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py new file mode 100644 index 000000000..0007000bc --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -0,0 +1,80 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Base class for kv caches. +""" + +from typing import List +from .page_pool import PageInfo +import math + + +class BasePagedAttentionCache: + """ + Manages lifecycle of pages (using PageInfo as handles). 
+ + + Page States: + Caching - Page can be read by multiple threads + - Also maintains a reference count + Writing - Page is being modified by a single owner thread + + Transitions: + Caching -> Writing: When acquiring an unreferenced LRU leaf page for writing + Writing -> Caching: When writing is complete and page is released + + Thread Safety: + - Multiple readers allowed in ReadableCaching state + - Single writer exclusive access in Writing state + - Reference counting prevents eviction of in-use pages + """ + + def __init__(self, page_pool, tokens_per_page): + self.page_pool = page_pool + self.tokens_per_page = tokens_per_page + + def acquire_pages_for_tokens( + self, tokens: List[int], extra_token_slots: int = 1 + ) -> tuple[list[PageInfo], int]: + """ + Given a list of tokens, return a list of pages and a start position to continue generation from. + + Parameters: + - tokens: all the known tokens for this generation request + - extra_token_slots: number of kvcache slots needed in addition to the ones needed to hold the given tokens. + + In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable. + + The pages are returned in order. + + No token at idx < n_cached_token should be written to. TODO: consider enforcing this. + """ + token_count = len(tokens) + pages_needed = math.ceil(token_count / self.tokens_per_page) + pages = self.page_pool.acquire_free_pages(pages_needed) + + n_cached_tokens = 0 + + return pages, n_cached_tokens + + def publish_pages(self, tokens, pages) -> None: + """ + Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests. + + Associates the tokens with the pages, and mark them as done writing. + + It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. + """ + + pass # the base implementation doesn't cache unfinished requests. + + def release_pages(self, tokens, pages): + """ + Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction. + """ + # in the base implementation, the pages can be owned by 1 request max, so they can be instantly release + self.page_pool.release_pages(pages) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py new file mode 100644 index 000000000..1686370c0 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -0,0 +1,159 @@ +from __future__ import annotations +from typing import List, Tuple, Optional, Sequence +import threading +import logging +import shortfin as sf +import shortfin.array as sfnp +from dataclasses import dataclass + +from ..config_struct import human_size +import math + +import time + +logger = logging.getLogger(__name__) + + +@dataclass +class PageInfo: + """ + Page index with some metadata about its contents. + """ + + index: int + pool: PagePool + token_offset: int # Offset within the page + token_count: int # Number of tokens stored in this page + writing: bool = False + read_ref_count: int = 0 # Number of threads that still need to read this page. When this reaches 0, page is eligible for release + + +@dataclass +class PagePoolConfig: + """ + Hyperparameters for the page pool. + """ + + dtype: sf.dtype + alloc_page_count: int + + paged_kv_block_size_elements: int # size of a single page as # of elements + # (e.g. 
one configuration for llama3.1 8b hax 32x2x16x8x128=1048576 elements where: + # 32: number of transformer blocks + # 2: one for k + one for v + # 16: tokens per page + # 8: head count (32 heads, but every 4 heads share the same kv buffer) + # 128: hidden dimension + + +class PagePool: + """Page table based attention cache. + + While internal to a model, the cache is organized with additional structure + per page, outside of the model, it is just a list of pages of a certain + element type and number of elements (all inner dims are flattened). + + One page table is allocated per device in a fiber. Currently, this is a + dense allocation with committed memory but in the future, we may just + allocate the address space and lazily populate it with committed memory. + + The cache is unique because usage of it can span fibers and concurrency + is implicitly managed at the block level (i.e. freshly acquired blocks + are assumed to be uninitialized and available immediately for use). + + It is initialized with a discrete list of fiberd devices from a fiber but + cache usage can be done from any fiber which includes those devices. + + In addition to supporting paged attention standalone, this also serves + as the array / buffer allocation layer for radix attention described in + `radix_tree.py`. + """ + + def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig): + self._lock = threading.Lock() + self.devices = list(devices) + self.config = config + self.page_tables: list[sf.array.device_array] = [] + + # Setup accounting structs. + self.attn_page_entries = [ + PageInfo( + index=i, + pool=self, + token_offset=0, + token_count=0, + ) + for i in range(self.config.alloc_page_count) + ] + + self.attn_page_free = list(self.attn_page_entries) + + # Initialize a page table on each device. + page_table_shape = [ + self.config.alloc_page_count, + self.config.paged_kv_block_size_elements, + ] + for device in devices: + logging.info( + "Allocating page table (shape=%r, dtype=%r, size=%s) on %r", + page_table_shape, + self.config.dtype, + human_size(config.dtype.compute_dense_nd_size(page_table_shape)), + device, + ) + page_table = sf.array.device_array.for_device( + device, page_table_shape, self.config.dtype + ) + self.page_tables.append(page_table) + + def acquire_free_pages(self, count: int) -> list[PageInfo] | None: + with self._lock: + available = len(self.attn_page_free) + if count > available: + return None + return [self.attn_page_free.pop() for _ in range(count)] + + def release_pages(self, pages: list[PageInfo]): + with self._lock: + self.attn_page_free.extend(pages) + + def copy_page(self, src_page: PageInfo) -> PageInfo: + """ + Copy a page's contents to a new page. + + Args: + src_page: Source page to copy from + token_count: Optional number of tokens to copy. If None, copies all tokens. + + Returns: + New PageInfo containing the copied data + """ + # Allocate new page + (dst_page,) = self.acquire_free_pages(1) + + # fill src page with data + + # Copy the data on each device + for page_table in self.page_tables: + # View of source and destination pages + src_view = page_table.view(src_page.index) + dst_view = page_table.view(dst_page.index) + # Copy the data + dst_view.copy_from(src_view) + + # Setup destination page metadata + dst_page.token_offset = 0 # Always start at beginning of new page + + return dst_page + + def __repr__(self): + # No need to lock for repr (list is internally synchronized). 
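
The sizing comment in PagePoolConfig above is worth making concrete; for the quoted llama3.1 8b layout the element count per page works out as follows (the individual factors are the ones given in the comment, not derived here):

    transformer_blocks = 32   # number of transformer blocks
    k_and_v = 2               # one slab for K plus one for V
    tokens_per_page = 16
    kv_head_count = 8         # 32 attention heads sharing KV across groups of 4
    head_dim = 128

    paged_kv_block_size_elements = (
        transformer_blocks * k_and_v * tokens_per_page * kv_head_count * head_dim
    )
    assert paged_kv_block_size_elements == 1_048_576  # 32 * 2 * 16 * 8 * 128

The byte size of each page table then follows from dtype.compute_dense_nd_size over [alloc_page_count, paged_kv_block_size_elements], which is what the allocation log message above reports.
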
+ free_pages = len(self.attn_page_free) + total_pages = len(self.attn_page_entries) + return ( + f"PagePool({total_pages - free_pages}/{total_pages} pages in use: " + f"{100.0 * free_pages / total_pages}% free)" + ) + + +############################## begin radix attention diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index fdcbeefc1..c3e6fe34b 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -9,7 +9,8 @@ import shortfin as sf import shortfin.array as sfnp -from .cache import AttnPageCache, AttnPageEntry +from .kvcache.base_attention_cache import BasePagedAttentionCache +from .kvcache.page_pool import PageInfo class InferencePhase(Enum): @@ -41,8 +42,8 @@ def __init__(self, phase: InferencePhase, input_token_ids: list[int]): self.result_logits: sfnp.device_array | None = None # Cache pages that have been locked for this request. - self._cache: AttnPageCache | None = None - self.locked_pages: list[AttnPageEntry] | None = None + self._cache: BasePagedAttentionCache | None = None + self.locked_pages: list[PageInfo] | None = None def reset(self, phase: InferencePhase): """Resets all per request state in preparation for an subsequent execution.""" @@ -66,16 +67,18 @@ def free_cache_pages(self): pages = self.locked_pages self._cache = None self.locked_pages = None - cache.release_pages(pages) + cache.release_pages(self.input_token_ids, pages) def lock_initial_cache_pages( - self, cache: AttnPageCache, pages: list[AttnPageEntry] + self, cache: BasePagedAttentionCache, pages: list[PageInfo] ): assert not self._cache self._cache = cache self.locked_pages = pages - def lock_new_cache_pages(self, cache: AttnPageCache, pages: list[AttnPageEntry]): + def lock_new_cache_pages( + self, cache: BasePagedAttentionCache, pages: list[PageInfo] + ): assert self._cache is cache self.locked_pages.extend(pages) diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index bcd08b756..8d3cc1424 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -11,7 +11,8 @@ import shortfin as sf import shortfin.array as sfnp -from .cache import AttnPageCache +from .kvcache.base_attention_cache import BasePagedAttentionCache +from .kvcache.page_pool import PagePoolConfig, PagePool from .config_struct import ModelParams from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage @@ -54,8 +55,17 @@ def __init__( # Scope dependent objects. 
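
The service changes that follow replace the old AttnPageCache with a two-layer stack: a PagePool that owns the per-device page tables, and a BasePagedAttentionCache that hands pages out per request. A condensed sketch of how the pieces fit together; the dtype, device list, page count, and token list are placeholders, and the real values come from ModelParams exactly as the diff below shows:

    from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PagePoolConfig
    from shortfin_apps.llm.components.kvcache.base_attention_cache import (
        BasePagedAttentionCache,
    )

    config = PagePoolConfig(
        dtype=attn_dtype,                        # placeholder
        alloc_page_count=256,                    # placeholder
        paged_kv_block_size_elements=1_048_576,  # placeholder
    )
    pool = PagePool(devices=devices, config=config)
    cache = BasePagedAttentionCache(page_pool=pool, tokens_per_page=16)

    # Prefill: acquire enough pages for the prompt. The base cache never has
    # a prefix hit, so n_cached comes back as 0.
    pages, n_cached = cache.acquire_pages_for_tokens(prompt_tokens, extra_token_slots=0)
    # ... run prefill, writing KV into `pages` ...
    cache.publish_pages(prompt_tokens, pages)   # no-op in the base implementation
    # When the request is done, return the pages to the pool.
    cache.release_pages(prompt_tokens, pages)
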
self.batcher = BatcherProcess(self) - self.page_cache = AttnPageCache( - devices=self.main_fiber.devices_dict.values(), model_params=model_params + page_pool_config = PagePoolConfig( + dtype=model_params.attn_dtype, + alloc_page_count=model_params.paged_kv_cache.device_block_count, + paged_kv_block_size_elements=model_params.paged_kv_block_size_elements, + ) + page_pool = PagePool( + devices=self.main_fiber.devices_dict.values(), config=page_pool_config + ) + self.page_cache = BasePagedAttentionCache( + page_pool=page_pool, + tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) self.program_isolation = PROG_ISOLATIONS[program_isolation] @@ -200,7 +210,7 @@ def board_flights(self): self.pending_prefills.clear() logger.debug("Post boarding cache state: %r", cache) - def board_prefills(self, cache: AttnPageCache): + def board_prefills(self, cache: BasePagedAttentionCache): # Fill prefill flights. pending_prefills = self.pending_prefills if len(pending_prefills) == 0: @@ -209,7 +219,7 @@ def board_prefills(self, cache: AttnPageCache): self.service, InferencePhase.PREFILL, self.page_seq_stride, - cache.page_tables, + cache.page_pool.page_tables, ) for prefill_request in pending_prefills: assert prefill_request.phase == InferencePhase.PREFILL @@ -218,7 +228,11 @@ def board_prefills(self, cache: AttnPageCache): needed_pages = math.ceil( len(prefill_request.input_token_ids) / self.page_seq_stride ) - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + prefill_request.input_token_ids, + extra_token_slots=0, # prefill needs no extra kvcache slots to write to + ) if pages is None: logger.debug("Cannot fulfill request for %d pages", needed_pages) continue @@ -236,13 +250,16 @@ def board_prefills(self, cache: AttnPageCache): # And takeoff. exec_process.launch() - def board_decodes(self, cache: AttnPageCache): + def board_decodes(self, cache: BasePagedAttentionCache): # Fill decode flights. pending_decodes = self.pending_decodes if len(pending_decodes) == 0: return exec_process = InferenceExecutorProcess( - self.service, InferencePhase.DECODE, self.page_seq_stride, cache.page_tables + self.service, + InferencePhase.DECODE, + self.page_seq_stride, + cache.page_pool.page_tables, ) for decode_request in pending_decodes: assert decode_request.phase == InferencePhase.DECODE @@ -254,7 +271,11 @@ def board_decodes(self, cache: AttnPageCache): / self.page_seq_stride ) if needed_pages > len(decode_request.locked_pages): - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + decode_request.input_token_ids, + extra_token_slots=1, # need 1 extra slot to write result. 
+ ) if pages is None: logger.debug( "Cannot fulfill decode request for %d pages", needed_pages diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py index 2ab7a1b96..1561803dd 100644 --- a/shortfin/python/shortfin_apps/llm/server.py +++ b/shortfin/python/shortfin_apps/llm/server.py @@ -33,6 +33,33 @@ logger = logging.getLogger(__name__) +UVICORN_LOG_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "()": "uvicorn.logging.DefaultFormatter", + "format": "[{asctime}] {message}", + "datefmt": "%Y-%m-%d %H:%M:%S", + "style": "{", + "use_colors": True, + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + }, + }, + "loggers": { + "uvicorn": { + "handlers": ["console"], + "level": "INFO", + "propagate": False, + }, + }, +} + @asynccontextmanager async def lifespan(app: FastAPI): @@ -211,11 +238,5 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): main( sys.argv[1:], # Make logging defer to the default shortfin logging config. - log_config={ - "version": 1, - "disable_existing_loggers": False, - "formatters": {}, - "handlers": {}, - "loggers": {}, - }, + log_config=UVICORN_LOG_CONFIG, ) diff --git a/shortfin/requirements-tests-nogil.txt b/shortfin/requirements-tests-nogil.txt index 1049b0412..1769467ab 100644 --- a/shortfin/requirements-tests-nogil.txt +++ b/shortfin/requirements-tests-nogil.txt @@ -1,4 +1,3 @@ pytest requests -fastapi uvicorn diff --git a/shortfin/setup.py b/shortfin/setup.py index cf3762950..e15b38d89 100644 --- a/shortfin/setup.py +++ b/shortfin/setup.py @@ -225,6 +225,7 @@ def build_cmake_configuration(CMAKE_BUILD_DIR: Path, extra_cmake_args=[]): add_env_cmake_setting(cmake_args, "SHORTFIN_ENABLE_LTO", default_value="ON") add_env_cmake_setting(cmake_args, "SHORTFIN_IREE_SOURCE_DIR") add_env_cmake_setting(cmake_args, "SHORTFIN_ENABLE_ASAN") + add_env_cmake_setting(cmake_args, "SHORTFIN_ENABLE_TOKENIZERS", default_value="OFF") # Only do a from-scratch configure if not already configured. cmake_cache_file = os.path.join(CMAKE_BUILD_DIR, "CMakeCache.txt") diff --git a/shortfin/src/shortfin/CMakeLists.txt b/shortfin/src/shortfin/CMakeLists.txt index 058e0e336..73df08e7c 100644 --- a/shortfin/src/shortfin/CMakeLists.txt +++ b/shortfin/src/shortfin/CMakeLists.txt @@ -5,5 +5,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception add_subdirectory(array) +add_subdirectory(components/tokenizers) add_subdirectory(local) add_subdirectory(support) diff --git a/shortfin/src/shortfin/array/dtype.h b/shortfin/src/shortfin/array/dtype.h index d746d69bf..de1763698 100644 --- a/shortfin/src/shortfin/array/dtype.h +++ b/shortfin/src/shortfin/array/dtype.h @@ -49,6 +49,9 @@ class SHORTFIN_API DType { bool is_integer_bitwidth(size_t bitwidth) const { return iree_hal_element_type_is_integer(et_, bitwidth); } + uint32_t numerical_type() const { + return iree_hal_element_numerical_type(et_); + } // Computes the size in bytes required to store densely packed nd-dims. // This presently only supports byte aligned dtypes. In the future, when diff --git a/shortfin/src/shortfin/components/tokenizers/CMakeLists.txt b/shortfin/src/shortfin/components/tokenizers/CMakeLists.txt new file mode 100644 index 000000000..6b9f794b1 --- /dev/null +++ b/shortfin/src/shortfin/components/tokenizers/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +if(NOT SHORTFIN_ENABLE_TOKENIZERS) + return() +endif() + +shortfin_cc_component( + NAME + shortfin_tokenizers + HDRS + tokenizers.h + SRCS + tokenizers.cc + DEFINES + SHORTFIN_HAVE_TOKENIZERS + COMPONENTS + shortfin_support + DEPS + tokenizers_cpp +) +set_property(GLOBAL APPEND + PROPERTY SHORTFIN_LIB_OPTIONAL_COMPONENTS + shortfin_tokenizers) +target_compile_definitions(shortfin_public_defs INTERFACE SHORTFIN_HAVE_TOKENIZERS) + +# Download test data. +shortfin_download_test_data( + URL "https://huggingface.co/google-bert/bert-base-cased/resolve/cd5ef92a9fb2f889e972770a36d4ed042daf221e/tokenizer.json" + OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/tokenizer.json" +) + +# Note that tests run from the binary dir of the project. +shortfin_gtest_test( + NAME shortfin_tokenizers_test + SRCS + tokenizers_test.cc +) diff --git a/shortfin/src/shortfin/components/tokenizers/tokenizers.cc b/shortfin/src/shortfin/components/tokenizers/tokenizers.cc new file mode 100644 index 000000000..118bc0c1b --- /dev/null +++ b/shortfin/src/shortfin/components/tokenizers/tokenizers.cc @@ -0,0 +1,63 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/components/tokenizers/tokenizers.h" + +#include + +#include "shortfin/support/logging.h" +#include "tokenizers_cpp.h" + +namespace shortfin::tokenizers { + +namespace { + +class AccessibleTokenizer : public Tokenizer { + public: + using Tokenizer::vendor_tokenizer_; +}; + +::tokenizers::Tokenizer *Get(Tokenizer *self) { + void *ptr = static_cast(self)->vendor_tokenizer_; + if (!ptr) { + throw std::logic_error("Tokenizer is null"); + } + return static_cast<::tokenizers::Tokenizer *>(ptr); +} + +} // namespace + +Tokenizer::~Tokenizer() { delete Get(this); } + +Tokenizer Tokenizer::FromBlobJSON(const std::string &json_blob) { + SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::FromBlobJSON"); + return Tokenizer(::tokenizers::Tokenizer::FromBlobJSON(json_blob).release()); +} + +std::vector Tokenizer::Encode(const std::string &text) { + SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::Encode"); + return Get(this)->Encode(text); +} + +std::vector> Tokenizer::EncodeBatch( + const std::vector &texts) { + SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::EncodeBatch"); + return Get(this)->EncodeBatch(texts); +} + +std::string Tokenizer::Decode(const std::vector &ids) { + SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::Decode"); + return Get(this)->Decode(ids); +} +size_t Tokenizer::GetVocabSize() { return Get(this)->GetVocabSize(); } +std::string Tokenizer::IdToToken(int32_t token_id) { + return Get(this)->IdToToken(token_id); +} +int32_t Tokenizer::TokenToId(const std::string &token) { + return Get(this)->TokenToId(token); +} + +} // namespace shortfin::tokenizers diff --git a/shortfin/src/shortfin/components/tokenizers/tokenizers.h b/shortfin/src/shortfin/components/tokenizers/tokenizers.h new file mode 100644 index 000000000..d263eace6 --- /dev/null +++ b/shortfin/src/shortfin/components/tokenizers/tokenizers.h @@ -0,0 +1,52 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H
+#define SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H
+
+#include <string>
+#include <vector>
+
+#include "shortfin/support/api.h"
+
+namespace shortfin::tokenizers {
+
+// A vendored Tokenizer class that does not export the details of the backing
+// implementation. While a little bit gross, this keeps us from needing to
+// re-export a vendor'ed API as part of our public API.
+// The current vendor tokenizer is based on mlc-ai/tokenizers-cpp. The API
+// is fairly close to that implementation.
+// See: https://github.com/mlc-ai/tokenizers-cpp
+class SHORTFIN_API Tokenizer {
+ public:
+  Tokenizer(const Tokenizer &) = delete;
+  Tokenizer &operator=(const Tokenizer &) = delete;
+  Tokenizer(Tokenizer &&other) : vendor_tokenizer_(other.vendor_tokenizer_) {
+    other.vendor_tokenizer_ = nullptr;
+  }
+  ~Tokenizer();
+
+  // Factory functions.
+  static Tokenizer FromBlobJSON(const std::string &json_blob);
+
+  std::vector<int32_t> Encode(const std::string &text);
+  std::vector<std::vector<int32_t>> EncodeBatch(
+      const std::vector<std::string> &texts);
+  std::string Decode(const std::vector<int32_t> &ids);
+  size_t GetVocabSize();
+  std::string IdToToken(int32_t token_id);
+  int32_t TokenToId(const std::string &token);
+
+ private:
+  Tokenizer(void *vendor_tokenizer) : vendor_tokenizer_(vendor_tokenizer) {}
+
+ protected:
+  void *vendor_tokenizer_;
+};
+
+}  // namespace shortfin::tokenizers
+
+#endif  // SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H
diff --git a/shortfin/src/shortfin/components/tokenizers/tokenizers_test.cc b/shortfin/src/shortfin/components/tokenizers/tokenizers_test.cc
new file mode 100644
index 000000000..674721653
--- /dev/null
+++ b/shortfin/src/shortfin/components/tokenizers/tokenizers_test.cc
@@ -0,0 +1,56 @@
+// Copyright 2024 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "shortfin/components/tokenizers/tokenizers.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <filesystem>
+#include <fstream>
+
+using namespace shortfin::tokenizers;
+
+namespace {
+
+std::string ReadFile(std::filesystem::path path) {
+  std::ifstream in(path);
+  std::ostringstream out;
+  out << in.rdbuf();
+  return out.str();
+}
+
+}  // namespace
+
+// TODO: Enable once upstream changes with error handling have landed.
+// Currently aborts.
+// See: https://github.com/mlc-ai/tokenizers-cpp/issues/50 +// TEST(TokenizersTest, FromIllegalBlobJson) { +// auto tok = Tokenizer::FromBlobJSON("foobar"); +// } + +TEST(TokenizersTest, BasicTokenizerJson) { + std::filesystem::path tokenizer_path( + "src/shortfin/components/tokenizers/tokenizer.json"); + auto tokenizer_json = ReadFile(tokenizer_path); + ASSERT_GT(tokenizer_json.size(), 0) + << "reading " << tokenizer_path + << " (cwd: " << std::filesystem::current_path() << ")"; + auto tok = Tokenizer::FromBlobJSON(tokenizer_json); + EXPECT_GT(tok.GetVocabSize(), 100); // Sanity check + auto encoded = tok.Encode("hello world"); + EXPECT_THAT(encoded, + ::testing::ContainerEq(std::vector{19082, 1362})); + auto batch_encoded = tok.EncodeBatch({"hello", "world"}); + ASSERT_EQ(batch_encoded.size(), 2); + EXPECT_THAT(batch_encoded[0], + ::testing::ContainerEq(std::vector{19082})); + EXPECT_THAT(batch_encoded[1], + ::testing::ContainerEq(std::vector{1362})); + EXPECT_EQ(tok.TokenToId("hello"), 19082); + auto decoded = tok.Decode(encoded); + EXPECT_EQ(decoded, "hello world"); +} diff --git a/shortfin/tests/api/array_ops_test.py b/shortfin/tests/api/array_ops_test.py index 7c792d92b..164dfb479 100644 --- a/shortfin/tests/api/array_ops_test.py +++ b/shortfin/tests/api/array_ops_test.py @@ -268,3 +268,171 @@ def test_nearest_int_conversion(device, dtype, out_dtype, sfnp_func, ref_round_f assert output.dtype == out_dtype for ref, actual in zip(ref_rounded, output.items): assert ref == int(actual) + + +def test_elementwise_forms(device): + # All elementwise ops use the same template expansion which enforces + # certain common invariants. Here we test these on the multiply op, + # relying on a parametric test for actual behavior. + with pytest.raises( + ValueError, + match="Elementwise operators require at least one argument to be a device_array", + ): + sfnp.multiply(2, 2) + + ary = sfnp.device_array.for_host(device, [2, 3], dtype=sfnp.float32) + with ary.map(discard=True) as m: + m.fill(42.0) + + # Rhs scalar int accepted. + result = sfnp.multiply(ary, 2) + assert list(result.items) == [84.0] * 6 + + # Rhs scalar float accepted. + result = sfnp.multiply(ary, 2.0) + assert list(result.items) == [84.0] * 6 + + # Lhs scalar int accepted. + result = sfnp.multiply(2, ary) + assert list(result.items) == [84.0] * 6 + + # Lhs scalar float accepted. + result = sfnp.multiply(2.0, ary) + assert list(result.items) == [84.0] * 6 + + # Out. + out = sfnp.device_array.for_host(device, [2, 3], dtype=sfnp.float32) + sfnp.multiply(2.0, ary, out=out) + assert list(out.items) == [84.0] * 6 + + +@pytest.mark.parametrize( + "lhs_dtype,rhs_dtype,promoted_dtype", + [ + (sfnp.float32, sfnp.float16, sfnp.float32), + (sfnp.float16, sfnp.float32, sfnp.float32), + (sfnp.float32, sfnp.float64, sfnp.float64), + (sfnp.float64, sfnp.float32, sfnp.float64), + # Integer promotion. + (sfnp.uint8, sfnp.uint16, sfnp.uint16), + (sfnp.uint16, sfnp.uint32, sfnp.uint32), + (sfnp.uint32, sfnp.uint64, sfnp.uint64), + (sfnp.int8, sfnp.int16, sfnp.int16), + (sfnp.int16, sfnp.int32, sfnp.int32), + (sfnp.int32, sfnp.int64, sfnp.int64), + # Signed/unsigned promotion. + (sfnp.int8, sfnp.uint8, sfnp.int16), + (sfnp.int16, sfnp.uint16, sfnp.int32), + (sfnp.int32, sfnp.uint32, sfnp.int64), + (sfnp.int8, sfnp.uint32, sfnp.int64), + ], +) +def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype): + # Tests that promotion infers an appropriate result type. 
+ lhs = sfnp.device_array.for_host(device, [2, 3], lhs_dtype) + rhs = sfnp.device_array.for_host(device, [2, 3], rhs_dtype) + result = sfnp.multiply(lhs, rhs) + assert result.dtype == promoted_dtype + + +@pytest.mark.parametrize( + "dtype,op,check_value", + [ + # Add. + (sfnp.int8, sfnp.add, 44.0), + (sfnp.int16, sfnp.add, 44.0), + (sfnp.int32, sfnp.add, 44.0), + (sfnp.int64, sfnp.add, 44.0), + (sfnp.uint8, sfnp.add, 44.0), + (sfnp.uint16, sfnp.add, 44.0), + (sfnp.uint32, sfnp.add, 44.0), + (sfnp.uint64, sfnp.add, 44.0), + (sfnp.float16, sfnp.add, 44.0), + (sfnp.float32, sfnp.add, 44.0), + (sfnp.float64, sfnp.add, 44.0), + # Divide. + (sfnp.int8, sfnp.divide, 21.0), + (sfnp.int16, sfnp.divide, 21.0), + (sfnp.int32, sfnp.divide, 21.0), + (sfnp.int64, sfnp.divide, 21.0), + (sfnp.uint8, sfnp.divide, 21.0), + (sfnp.uint16, sfnp.divide, 21.0), + (sfnp.uint32, sfnp.divide, 21.0), + (sfnp.uint64, sfnp.divide, 21.0), + (sfnp.float16, sfnp.divide, 21.0), + (sfnp.float32, sfnp.divide, 21.0), + (sfnp.float64, sfnp.divide, 21.0), + # Multiply. + (sfnp.int8, sfnp.multiply, 84.0), + (sfnp.int16, sfnp.multiply, 84.0), + (sfnp.int32, sfnp.multiply, 84.0), + (sfnp.int64, sfnp.multiply, 84.0), + (sfnp.uint8, sfnp.multiply, 84.0), + (sfnp.uint16, sfnp.multiply, 84.0), + (sfnp.uint32, sfnp.multiply, 84.0), + (sfnp.uint64, sfnp.multiply, 84.0), + (sfnp.float16, sfnp.multiply, 84.0), + (sfnp.float32, sfnp.multiply, 84.0), + (sfnp.float64, sfnp.multiply, 84.0), + # Subtract. + (sfnp.int8, sfnp.subtract, 40.0), + (sfnp.int16, sfnp.subtract, 40.0), + (sfnp.int32, sfnp.subtract, 40.0), + (sfnp.int64, sfnp.subtract, 40.0), + (sfnp.uint8, sfnp.subtract, 40.0), + (sfnp.uint16, sfnp.subtract, 40.0), + (sfnp.uint32, sfnp.subtract, 40.0), + (sfnp.uint64, sfnp.subtract, 40.0), + (sfnp.float16, sfnp.subtract, 40.0), + (sfnp.float32, sfnp.subtract, 40.0), + (sfnp.float64, sfnp.subtract, 40.0), + ], +) +def test_elementwise_array_correctness(device, dtype, op, check_value): + lhs = sfnp.device_array.for_host(device, [2, 2], sfnp.int32) + with lhs.map(discard=True) as m: + m.fill(42) + + rhs = sfnp.device_array.for_host(device, [2], sfnp.int32) + with rhs.map(discard=True) as m: + m.fill(2) + + lhs = sfnp.convert(lhs, dtype=dtype) + rhs = sfnp.convert(rhs, dtype=dtype) + result = op(lhs, rhs) + assert result.shape == [2, 2] + result = sfnp.convert(result, dtype=sfnp.float32) + items = list(result.items) + assert items == [check_value] * 4 + + +@pytest.mark.parametrize( + "dtype", + [ + sfnp.int8, + sfnp.int16, + sfnp.int32, + sfnp.int64, + sfnp.uint8, + sfnp.uint16, + sfnp.uint32, + sfnp.uint64, + sfnp.float32, + sfnp.float16, + sfnp.float32, + sfnp.float64, + ], +) +def test_transpose(device, dtype): + input = sfnp.device_array.for_host(device, [3, 2], sfnp.int32) + input.items = [0, 1, 2, 3, 4, 5] + input = sfnp.convert(input, dtype=dtype) + permuted = sfnp.transpose(input, [1, 0]) + assert permuted.shape == [2, 3] + items = list(sfnp.convert(permuted, dtype=sfnp.int32).items) + assert items == [0, 2, 4, 1, 3, 5] + + out = sfnp.device_array.for_host(device, [2, 3], dtype) + sfnp.transpose(input, [1, 0], out=out) + items = list(sfnp.convert(permuted, dtype=sfnp.int32).items) + assert items == [0, 2, 4, 1, 3, 5] diff --git a/shortfin/tests/api/array_use_case_test.py b/shortfin/tests/api/array_use_case_test.py new file mode 100644 index 000000000..d4a030d45 --- /dev/null +++ b/shortfin/tests/api/array_use_case_test.py @@ -0,0 +1,64 @@ +# Copyright 2024 Advanced Micro Devices, Inc. 
+# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import array +import math +import pytest + +import shortfin as sf +import shortfin.array as sfnp + + +@pytest.fixture +def lsys(): + # TODO: Port this test to use memory type independent access. It currently + # presumes unified memory. + # sc = sf.SystemBuilder() + sc = sf.host.CPUSystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def fiber(lsys): + return lsys.create_fiber() + + +@pytest.fixture +def device(fiber): + return fiber.device(0) + + +# Tests a typical image conversion from a model oriented layout to an array +# of contained images. +def test_image_to_bytes(device): + bs = 2 + height = 16 + width = 12 + images_shape = [bs, 3, height, width] + images_planar = sfnp.device_array.for_host(device, images_shape, sfnp.float32) + # Band the data so that each channel increases by 0.1 across images. + for i in range(bs): + for j in range(3): + data = [i * 0.3 + j * 0.1 for _ in range(height * width)] + images_planar.view(i, j).items = data + images_planar = sfnp.convert(images_planar, dtype=sfnp.float16) + + # Extract and convert each image to interleaved RGB bytes. + images = [] + for idx in range(images_planar.shape[0]): + image_planar = images_planar.view(idx) + assert image_planar.shape == [1, 3, 16, 12] + image_interleaved = sfnp.transpose(image_planar, (0, 2, 3, 1)) + assert image_interleaved.shape == [1, 16, 12, 3] + image_scaled = sfnp.multiply(image_interleaved, 255) + image = sfnp.round(image_scaled, dtype=sfnp.uint8) + image_bytes = bytes(image.map(read=True)) + images.append(image_bytes) + + assert images[0] == b"\x00\x1a3" * 192 + assert images[1] == b"Mf\x80" * 192 diff --git a/shortfin/tests/apps/llm/components/cache_test.py b/shortfin/tests/apps/llm/components/cache_test.py deleted file mode 100644 index 169d082b1..000000000 --- a/shortfin/tests/apps/llm/components/cache_test.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -""" -Tests for llm kvcache component. -""" - -import pytest -import time -import tempfile -import shortfin as sf -from _shortfin import lib as sfl -from shortfin_apps.llm.components import cache -from shortfin_apps.llm.components import config_struct -import json -from pathlib import Path - - -@pytest.fixture -def lsys(): - sc = sfl.local.host.CPUSystemBuilder() - ls = sc.create_system() - yield ls - ls.shutdown() - - -@pytest.fixture -def fiber(lsys): - # TODO: Should adopt the main thread. 
- worker = lsys.create_worker("main") - return lsys.create_fiber(worker) - - -@pytest.fixture -def device(fiber): - return fiber.device(0) - - -@pytest.fixture -def model_params(): - model_params = { - "module_name": "module", - "module_abi_version": 1, - "max_seq_len": 2048, - "attn_head_count": 32, - "attn_head_dim": 100, - "prefill_batch_sizes": [4], - "decode_batch_sizes": [4], - "transformer_block_count": 26, - "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, - } - - # Create a temporary file to store the JSON - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as tmp_file: - json.dump(model_params, tmp_file, indent=4) - tmp_path = Path(tmp_file.name) - - try: - # Load the JSON using config_struct - model_params = config_struct.ModelParams.load_json(tmp_path) - yield model_params - finally: - tmp_path.unlink - - -@pytest.fixture -def cache_fixture(fiber, model_params) -> cache.AttnPageCache: - # Create and return the cache object - return cache.AttnPageCache( - devices=fiber.devices_dict.values(), model_params=model_params - ) - - -@pytest.mark.parametrize("n_allocated", [1, 16, 255]) -def test_alloc( - cache_fixture: cache.AttnPageCache, - n_allocated, - model_params: config_struct.ModelParams, -): - alloc_page_count = cache_fixture.page_tables[0].shape[0] - - assert alloc_page_count == model_params.paged_kv_cache.device_block_count - - pages = cache_fixture.acquire_free_pages(n_allocated) - last_page = alloc_page_count - 1 - expected_indices = range(last_page, last_page - n_allocated, -1) - for p, expected_ix in zip(pages, expected_indices): - assert p.index == expected_ix - assert p.index > 0 diff --git a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py new file mode 100644 index 000000000..a1ec00c07 --- /dev/null +++ b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py @@ -0,0 +1,57 @@ +import pytest +import logging +from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PagePoolConfig +import shortfin as sf +import shortfin.host +import shortfin.array as sfnp +import shortfin.amdgpu + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def setup_pool(generic_device): + pool = PagePool( + devices=[generic_device], + config=PagePoolConfig( + alloc_page_count=256, + dtype=sfnp.float16, + paged_kv_block_size_elements=393216, + ), + ) + return pool + + +def test_page_acquisition(setup_pool): + pool = setup_pool + logger.info(f"=== Running page acquisition test on system ===") + page0 = pool.acquire_free_pages(1) + assert page0 is not None, f"Failed to acquire a free page on system" + logger.info(f"Successfully acquired page on system") + + +def test_page_copy(setup_pool): + pool = setup_pool + logger.info(f"=== Running page copy test on system ===") + (page0,) = pool.acquire_free_pages(1) + page1 = pool.copy_page(page0) + assert page1 is not None, f"Failed to copy a page on system" + assert page0 != page1, f"Copied page should be different from original on system" + logger.info(f"Successfully copied page on system") + + +@pytest.fixture(autouse=True) +def setup_logging(): + """Set up logging format to include timestamp and level""" + logging.basicConfig( + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, + force=True, + ) + + +# Add more tests as needed + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/shortfin/tests/conftest.py b/shortfin/tests/conftest.py 
index 083698968..b16d5a3c9 100644 --- a/shortfin/tests/conftest.py +++ b/shortfin/tests/conftest.py @@ -50,6 +50,17 @@ def pytest_runtest_setup(item): sf.SystemBuilder.default_system_type = system_type +# Dynamic Parameterization for lsys Fixture +def pytest_generate_tests(metafunc): + if "generic_lsys" in metafunc.fixturenames: + system = metafunc.config.getoption("--system") + if system == "amdgpu": + params = ["cpu", "amdgpu"] + else: + params = [system] + metafunc.parametrize("generic_lsys", params, indirect=True) + + # Keys that will be cleaned project wide prior to and after each test run. # Test code can freely modify these. CLEAN_ENV_KEYS = [ @@ -96,6 +107,28 @@ def kill(): kill() +@pytest.fixture(scope="session") +def generic_lsys(request): + system_type = request.param + if system_type == "cpu" or system_type == "hostcpu": + sc = sf.host.CPUSystemBuilder() + elif system_type == "amdgpu": + sc = sf.amdgpu.SystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def generic_fiber(generic_lsys): + return generic_lsys.create_fiber() + + +@pytest.fixture +def generic_device(generic_fiber): + return generic_fiber.device(0) + + @pytest.fixture def cpu_lsys(): sc = sf.host.CPUSystemBuilder() diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 5786a9fff..f09e08888 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -30,6 +30,8 @@ from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore + from .common import * from .dispatch_constraints import * from .dispatch_parser import * @@ -50,7 +52,7 @@ def apply_configuration( expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") - repl0 = f", subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>" + repl0 = f"" repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" @@ -117,7 +119,6 @@ def get_transform_function_mmt( wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) - return f""" transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op @@ -130,7 +131,7 @@ def get_transform_function_mmt( translation_info = #iree_codegen.translation_info, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -203,7 +204,7 @@ def get_transform_function_conv( translation_info = #iree_codegen.translation_info, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -264,7 +265,7 @@ def get_transform_function_broadcast_rhs_mmt( translation_info = #iree_codegen.translation_info, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, 
subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -344,7 +345,7 @@ def get_transform_function_batch_mmt( translation_info = #iree_codegen.translation_info, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -412,7 +413,7 @@ def get_transform_function_batch_matmul( translation_info = #iree_codegen.translation_info, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -535,13 +536,19 @@ def tune( walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry) + variant_op_list = iree_codegen.get_executable_variant_ops(mlir_module) + assert len(variant_op_list) == 1, "Expect one executable variant op" + variant_op = variant_op_list[0] + # Get the MMA intrinisic intructions supported by the target. + mma_list = iree_codegen.query_mma_intrinsics(variant_op) + dispatch_tuner = walk_result.dispatch_tuner assert dispatch_tuner, "No suitable dispatch tuner found" problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) tune_logger.debug(str(problem_size)) configs = [] for i, config in enumerate( - generate_solutions(tune_logger, problem_size, num_subgroups) + generate_solutions(tune_logger, problem_size, num_subgroups, mma_list) ): if i >= limit: break diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 36fb87cbb..d81278e8c 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -13,6 +13,7 @@ from typing import Generator from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore from . import candidate_gen from . 
import common @@ -45,10 +46,12 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: M, N, K = 2048, 1280, 1280 + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=16, workgroup_size=[16, 16, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[8, 8, 8], subgroup_m_count=16, subgroup_n_count=16, @@ -97,10 +100,12 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, @@ -161,10 +166,12 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.contraction, ) + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=mma_attr, tile_sizes=[480, 384, 32], subgroup_m_count=1, subgroup_n_count=4, @@ -208,10 +215,12 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_matmul, ) + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=mma_attr, tile_sizes=[416, 320, 128], subgroup_m_count=2, subgroup_n_count=2, @@ -258,10 +267,12 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_mmt, ) + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, @@ -306,10 +317,12 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_mmt, ) + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=mma_attr, tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, @@ -377,10 +390,12 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.broadcast_rhs_mmt, ) + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=mma_attr, tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index a34f172eb..45ae48c22 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -12,6 +12,8 @@ from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore + 
class CommonTypes: def __init__(self, ctx: ir.Context): @@ -83,65 +85,24 @@ def MNK(self) -> tuple[int, int, int]: return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) -@dataclass -class MfmaIntrinsic: - output_type: ir.IntegerType | ir.FloatType - m: int - n: int - k: int - input_type: ir.IntegerType | ir.FloatType - - def __str__(self) -> str: - input = str(self.input_type).upper() - output = str(self.output_type).upper() - return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}" - - @staticmethod - def mfma_f32_16x16x16_f16(): - f16 = ir.F16Type.get() - f32 = ir.F32Type.get() - return MfmaIntrinsic(f32, 16, 16, 16, f16) - - @staticmethod - def mfma_f32_32x32x8_f16(): - f16 = ir.F16Type.get() - f32 = ir.F32Type.get() - return MfmaIntrinsic(f32, 32, 32, 8, f16) - - @staticmethod - def mfma_i32_16x16x32_i8(): - i32 = ir.IntegerType.get_signless(32) - i8 = ir.IntegerType.get_signless(8) - return MfmaIntrinsic(i32, 16, 16, 32, i8) - - @staticmethod - def mfma_i32_32x32x16_i8(): - i32 = ir.IntegerType.get_signless(32) - i8 = ir.IntegerType.get_signless(8) - return MfmaIntrinsic(i32, 32, 32, 16, i8) - - @staticmethod - def all(): - return [ - MfmaIntrinsic.mfma_f32_16x16x16_f16(), - MfmaIntrinsic.mfma_f32_32x32x8_f16(), - MfmaIntrinsic.mfma_i32_16x16x32_i8(), - MfmaIntrinsic.mfma_i32_32x32x16_i8(), - ] - - -def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: - def is_compatible(intrinsic: MfmaIntrinsic) -> bool: - if problem_size.res_type.element_type != intrinsic.output_type: +def get_compatible_mfma_intrinsics( + problem_size: ProblemSize, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], +) -> list[iree_gpu.MMAIntrinsic]: + def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: + mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma + a_type, b_type, c_type = mma_attr.abc_element_types + if problem_size.res_type.element_type != c_type: return False if problem_size.dispatch_kind != DispatchKind.batch_matmul: - if problem_size.lhs_type.element_type != intrinsic.input_type: - return False - if problem_size.rhs_type.element_type != intrinsic.input_type: + if ( + problem_size.lhs_type.element_type != a_type + or problem_size.rhs_type.element_type != b_type + ): return False return True - return list(filter(is_compatible, MfmaIntrinsic.all())) + return list(filter(is_comptible, mma_intrinsics)) class ReorderWorkgroupsStrategy(Enum): @@ -186,7 +147,7 @@ def __str__(self) -> str: class Configuration: subgroup_size: int workgroup_size: list[int] - intrinsic: MfmaIntrinsic + intrinsic: iree_gpu.MMAAttr tile_sizes: list[int] subgroup_m_count: int subgroup_n_count: int diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 891d703e2..ea0a4573d 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception """ -Usage: python -m pytest candidate_gen_test.py +Usage: python -m pytest common_test.py """ import pytest @@ -14,6 +14,7 @@ from typing import Generator from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore @pytest.fixture @@ -71,10 +72,12 @@ def test_gpu_pipeline_options() -> None: def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + 
intrinsic=mma_attr, tile_sizes=[4, 8, 16], subgroup_m_count=1, subgroup_n_count=1, @@ -96,11 +99,6 @@ def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: ) -def test_mfma_intrinsic_to_str(mlir_ctx: ir.Context) -> None: - assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16" - assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8" - - def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: assert common.get_compatible_mfma_intrinsics( common.ProblemSize( @@ -109,10 +107,14 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([1280, 1280], tuner_ctx.type.f16), common.ShapedType([2048, 1280], tuner_ctx.type.f32), common.DispatchKind.mmt, - ) + ), + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + ], ) == [ - common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ] assert common.get_compatible_mfma_intrinsics( @@ -122,10 +124,14 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([1280, 1280], tuner_ctx.type.i8), common.ShapedType([2048, 1280], tuner_ctx.type.i32), common.DispatchKind.mmt, - ) + ), + [ + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], ) == [ - common.MfmaIntrinsic.mfma_i32_16x16x32_i8(), - common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, ] assert common.get_compatible_mfma_intrinsics( @@ -135,8 +141,44 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([64, 640, 320], tuner_ctx.type.f32), common.ShapedType([64, 968, 320], tuner_ctx.type.f32), common.DispatchKind.batch_matmul, - ) + ), + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + ], ) == [ - common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ] + + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], tuner_ctx.type.f32), + common.ShapedType([64, 640, 320], tuner_ctx.type.f32), + common.ShapedType([64, 968, 320], tuner_ctx.type.f32), + common.DispatchKind.batch_matmul, + ), + [ + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + ], + ) == [ + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + ] + + assert ( + common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], tuner_ctx.type.f32), + common.ShapedType([64, 640, 320], tuner_ctx.type.f32), + common.ShapedType([64, 968, 320], tuner_ctx.type.f32), + common.DispatchKind.batch_matmul, + ), + [ + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], + ) + == [] + ) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index edd7ccc38..f16b4a241 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -10,6 +10,9 @@ import z3 # type: ignore from typing import Iterator + +from iree.compiler.dialects import iree_gpu # type: ignore + from .common import * @@ -18,13 +21,22 @@ def 
get_mfma_intrinsic_constraints( intrinsic_m: z3.ArithRef, intrinsic_n: z3.ArithRef, intrinsic_k: z3.ArithRef, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], ) -> z3.BoolRef: - compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size) + compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics) assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" + + mma_attrs = [iree_gpu.MMAAttr.get(mfma) for mfma in compatible_intrinsics] + mnk_shapes = [mma_attr.mnk_shape for mma_attr in mma_attrs] + return z3.Or( *( - z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) - for mfma in compatible_intrinsics + z3.And( + intrinsic_m == m, + intrinsic_n == n, + intrinsic_k == k, + ) + for m, n, k in mnk_shapes ) ) @@ -68,6 +80,7 @@ def generate_constraints( subgroup_m_count, subgroup_n_count, waves_per_eu, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], ): M, N, K = ( problem_size.matmul_size.M, @@ -82,7 +95,7 @@ def generate_constraints( constraints += [subgroup_size == 64, wg_threads <= 1024] constraints += [ get_mfma_intrinsic_constraints( - problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k + problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k, mma_intrinsics ) ] subgroup_k_count = 1 @@ -129,8 +142,40 @@ def generate_constraints( return constraints +def getMMAAttr( + output_type: ir.IntegerType | ir.FloatType, + m: int, + n: int, + k: int, + lhs_type: ir.IntegerType | ir.FloatType, + rhs_type: ir.IntegerType | ir.FloatType, +) -> iree_gpu.MMAAttr: + for mma_intrinsic in iree_gpu.MMAIntrinsic: + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + a_type, b_type, c_type = mma_attr.abc_element_types + mnk = mma_attr.mnk_shape + if ( + a_type == lhs_type + and b_type == rhs_type + and c_type == output_type + and m == mnk[0] + and n == mnk[1] + and k == mnk[2] + ): + return mma_attr + # If no matching intrinsic is found, raise an exception + raise ValueError( + f"No matching MMA intrinsic found for " + f"output_type={output_type}, lhs_type={lhs_type}, rhs_type={rhs_type}, " + f"m={m}, n={n}, k={k}." + ) + + def generate_solutions( - logger: logging.Logger, problem_size: ProblemSize, num_subgrups: int + logger: logging.Logger, + problem_size: ProblemSize, + num_subgrups: int, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], ) -> Iterator[Configuration]: M, N, K = problem_size.MNK logger.info(f"{M},{N},{K}") @@ -168,6 +213,7 @@ def generate_solutions( sg_m_cnt, sg_n_cnt, waves_per_eu, + mma_intrinsics, ) solver.add(z3.simplify(z3.And(constraints))) logger.debug(f"Initial constraints: {solver}") @@ -179,12 +225,13 @@ def generate_solutions( config = Configuration( lookup(subgroup_size), [lookup(wg_x), lookup(wg_y), lookup(wg_z)], - MfmaIntrinsic( + getMMAAttr( problem_size.res_type.element_type, lookup(intrinsic_mn), lookup(intrinsic_mn), lookup(intrinsic_k), problem_size.lhs_type.element_type, + problem_size.rhs_type.element_type, ), [lookup(m), lookup(n), lookup(k)], lookup(sg_m_cnt), diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py index 7e1a5c55d..9de4beeee 100644 --- a/tuner/tuner/dispatch_constraints_test.py +++ b/tuner/tuner/dispatch_constraints_test.py @@ -14,6 +14,7 @@ from typing import Generator from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore from . import common from . 
import dispatch_constraints @@ -37,7 +38,18 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None: problem_size = common.ProblemSize( matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) - configs = dispatch_constraints.generate_solutions(tuner_ctx.logger, problem_size, 4) + configs = dispatch_constraints.generate_solutions( + tuner_ctx.logger, + problem_size, + 4, + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], + ) + assert configs is not None @@ -115,6 +127,12 @@ def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> Non sg_m_cnt, sg_n_cnt, waves_per_eu, + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], ) solver = z3.Solver() @@ -160,6 +178,12 @@ def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> N sg_m_cnt, sg_n_cnt, waves_per_eu, + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], ) constraints.append(m > 1000) # Adding an additional unsatisfiable constraint diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index d3a99806f..fb10b04bc 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception """ -Usage: python -m pytest candidate_gen_test.py +Usage: python -m pytest dispatch_parser_test.py """ import pytest @@ -14,6 +14,7 @@ from iree.compiler import ir # type: ignore from iree.compiler.dialects import func # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore from . import common from . import dispatch_parser @@ -39,10 +40,12 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None: def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = dispatch_parser.Configuration( subgroup_size=0, workgroup_size=[], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[128, 320, 32], subgroup_m_count=0, subgroup_n_count=0, @@ -53,10 +56,12 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = dispatch_parser.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, @@ -75,10 +80,12 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = dispatch_parser.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[4, 8, 16], subgroup_m_count=1, subgroup_n_count=1,