From c8738b728f4bb9bf19f105d743b16c9e852d609b Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Thu, 21 Nov 2024 18:44:58 -0500 Subject: [PATCH 01/25] Add Linux prep instructions to developer guide (#575) [skip ci] There are some apt packages needed before it would work. --- docs/developer_guide.md | 51 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/docs/developer_guide.md b/docs/developer_guide.md index 832466688..6cd5f83a8 100644 --- a/docs/developer_guide.md +++ b/docs/developer_guide.md @@ -3,6 +3,55 @@ Each sub-project has its own developer guide. If you would like to work across projects, these instructions should help you get started: + +### Install Dependencies + +Install shortfin dependencies +```bash +sudo apt update && sudo apt install -y clang lld +``` + +### Prepare your python environment + +Install: + +``` +python-is-python3 python3-venv python3-dev +``` + +
+ + Or, alternatively, use `pyenv` to manage a separate python installation for more control over its version: + + +First, install pyenv and its dependencies. + +```bash +sudo apt update; sudo apt install build-essential libssl-dev zlib1g-dev \ +libbz2-dev libreadline-dev libsqlite3-dev curl git \ +libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev +curl https://pyenv.run | bash +``` + +Then, make pyenv available by adding the below to your `~/.bashrc`: + +```bash +export PYENV_ROOT="$HOME/.pyenv" +command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH" +eval "$(pyenv init -)" +``` + +Finally, install a pyenv-managed version of python + +```bash +pyenv install 3.12 # or whichever python version you'd like +pyenv local 3.12 +``` + +Now, your python, pip, and venv should be managed by pyenv instead. + +
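As an optional sanity check (assuming the `pyenv local 3.12` example above), you can confirm that `python` now resolves to the pyenv-managed interpreter:

```python
# Run this with the `python` command that is now on PATH; the version should
# match the one selected with `pyenv local` above.
import sys

print(sys.version)     # e.g. 3.12.x
print(sys.executable)  # should live under the pyenv root
```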
+ ### Setup a venv We recommend setting up a Python @@ -54,8 +103,10 @@ See also: [nightly_releases.md](nightly_releases.md). ### Running tests ```bash +pip install -r shortfin/requirements-tests.txt pytest sharktank pytest shortfin +pytest app_tests/integration_tests ``` ### Optional: pre-commits and developer settings From 395f2f13e0cda7aeddc6fe8c39558537d5b5333f Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 21 Nov 2024 19:02:02 -0500 Subject: [PATCH 02/25] [shortfin] Implement transpose and elementwise host ops. (#578) * Adds `sfnp.transpose`, `sfnp.add`, `sfnp.divide`, `sfnp.multiply`, `sfnp.subtract` * Adds a use case test showing how to convert planar fp images to interleaved RGB pixel values Fixes #314 --- shortfin/python/array_binding.cc | 36 +-- shortfin/python/array_host_ops.cc | 287 ++++++++++++++++++++- shortfin/python/shortfin/array/__init__.py | 10 + shortfin/src/shortfin/array/dtype.h | 3 + shortfin/tests/api/array_ops_test.py | 168 ++++++++++++ shortfin/tests/api/array_use_case_test.py | 64 +++++ 6 files changed, 551 insertions(+), 17 deletions(-) create mode 100644 shortfin/tests/api/array_use_case_test.py diff --git a/shortfin/python/array_binding.cc b/shortfin/python/array_binding.cc index a05232674..08a4071a8 100644 --- a/shortfin/python/array_binding.cc +++ b/shortfin/python/array_binding.cc @@ -531,22 +531,26 @@ void BindArray(py::module_ &m) { ->AddAsInvocationArgument( inv, static_cast(barrier)); }) - .def_static("for_device", - [](local::ScopedDevice &device, std::span shape, - DType dtype) { - return custom_new_keep_alive( - py::type(), - /*keep_alive=*/device.fiber(), - device_array::for_device(device, shape, dtype)); - }) - .def_static("for_host", - [](local::ScopedDevice &device, std::span shape, - DType dtype) { - return custom_new_keep_alive( - py::type(), - /*keep_alive=*/device.fiber(), - device_array::for_host(device, shape, dtype)); - }) + .def_static( + "for_device", + [](local::ScopedDevice &device, std::span shape, + DType dtype) { + return custom_new_keep_alive( + py::type(), + /*keep_alive=*/device.fiber(), + device_array::for_device(device, shape, dtype)); + }, + py::arg("device"), py::arg("shape"), py::arg("dtype")) + .def_static( + "for_host", + [](local::ScopedDevice &device, std::span shape, + DType dtype) { + return custom_new_keep_alive( + py::type(), + /*keep_alive=*/device.fiber(), + device_array::for_host(device, shape, dtype)); + }, + py::arg("device"), py::arg("shape"), py::arg("dtype")) .def("for_transfer", [](device_array &self) { return custom_new_keep_alive( diff --git a/shortfin/python/array_host_ops.cc b/shortfin/python/array_host_ops.cc index 86385cfee..3e2a8ebe3 100644 --- a/shortfin/python/array_host_ops.cc +++ b/shortfin/python/array_host_ops.cc @@ -91,6 +91,18 @@ static const char DOCSTRING_RANDOM_GENERATOR[] = fixed number. )"; +static const char DOCSTRING_TRANSPOSE[] = + R"(Transposes axes of an array according to a permutation vector. + +Args: + input: Array to transpose. + permutation: New sequence of axes. Must have same number of elements as the + rank of input. + out: If given, then the results are written to this array. + device_visible: Whether to make the result array visible to devices. Defaults + to False. 
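+
+Returns:
+  The transposed array: `out` if it was provided, otherwise a newly
+  allocated array.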
+)"; + #define SF_UNARY_FUNCTION_CASE(dtype_name, cpp_type) \ case DType::dtype_name(): \ return compute.template operator()() @@ -100,6 +112,25 @@ static const char DOCSTRING_RANDOM_GENERATOR[] = compute.template operator()(); \ break +#define SF_MOVEMENT_OP_SWITCH(dtype) \ + if (!dtype.is_byte_aligned()) \ + throw std::invalid_argument( \ + "data movement ops are only defined for byte aligned dtypes"); \ + switch (dtype.dense_byte_count()) { \ + case 1: \ + return compute.template operator()(); \ + case 2: \ + return compute.template operator()(); \ + case 4: \ + return compute.template operator()(); \ + case 8: \ + return compute.template operator()(); \ + default: \ + throw std::invalid_argument( \ + "data movement ops are only defined for dtypes of size 1, 2, " \ + "4, 8"); \ + } + struct PyRandomGenerator { public: using SeedType = xt::random::default_engine_type::result_type; @@ -374,6 +405,227 @@ struct ConvertTruncFunctor { } }; +void OptionalArrayCast(py::handle handle, + std::optional &maybe_array) { + if (py::isinstance(handle)) { + maybe_array.emplace(py::cast(handle)); + } +} + +int DTypePromotionRank(DType dtype) { + int rank = 1; + if (dtype.is_boolean()) + rank *= 1000; + else if (dtype.is_integer()) + rank *= 2000; + else if (dtype.is_float()) + rank *= 4000; + else if (dtype.is_complex()) + rank *= 8000; + return rank + dtype.bit_count(); +} + +DType PromoteArithmeticTypes(std::optional lhs_dtype, + std::optional rhs_dtype) { + if (!lhs_dtype && !rhs_dtype) { + throw std::invalid_argument( + "Elementwise operators require at least one argument to be a " + "device_array"); + } + + // One not an array: promote to the array type. + if (!lhs_dtype) + return *rhs_dtype; + else if (!rhs_dtype) + return *lhs_dtype; + + int lhs_rank = DTypePromotionRank(*lhs_dtype); + int rhs_rank = DTypePromotionRank(*rhs_dtype); + DType promoted_dtype = lhs_rank < rhs_rank ? *rhs_dtype : *lhs_dtype; + + // If mismatched signed/unsigned, then need to promote to the next signed + // dtype. + if (promoted_dtype.is_integer()) { + bool lhs_unsigned = iree_all_bits_set( + lhs_dtype->numerical_type(), IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED); + bool rhs_unsigned = iree_all_bits_set( + rhs_dtype->numerical_type(), IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED); + if ((lhs_unsigned || rhs_unsigned) && !(lhs_unsigned && rhs_unsigned)) { + // Signed/unsigned mismatch. Promote to next. + switch (promoted_dtype) { + case DType::uint8(): + case DType::int8(): + return DType::int16(); + case DType::uint16(): + case DType::int16(): + return DType::int32(); + case DType::uint32(): + case DType::int32(): + return DType::int64(); + default: + // Jax's type promotion chart says this goes to a weak FP type, but + // we don't implement such a construct and I don't really see how + // that makes sense in a system setting like this, so we just saturate + // to 64bit. + return DType::int64(); + } + } + } + + return promoted_dtype; +} + +// ---------------------------------------------------------------------------// +// Elementwise support +// ---------------------------------------------------------------------------// + +// Python element type scalar conversion functions. 
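+// The ConvertPyToEltTy overloads below give the elementwise templates one
+// uniform way to coerce a Python scalar operand into the element type chosen
+// by dtype promotion; the unused `zero` argument only selects the overload.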
+uint8_t ConvertPyToEltTy(py::handle py_value, uint8_t zero) { + return py::cast(py_value); +} + +int8_t ConvertPyToEltTy(py::handle py_value, int8_t zero) { + return py::cast(py_value); +} + +uint16_t ConvertPyToEltTy(py::handle py_value, uint16_t zero) { + return py::cast(py_value); +} + +int16_t ConvertPyToEltTy(py::handle py_value, int16_t zero) { + return py::cast(py_value); +} + +uint32_t ConvertPyToEltTy(py::handle py_value, uint32_t zero) { + return py::cast(py_value); +} + +int32_t ConvertPyToEltTy(py::handle py_value, int32_t zero) { + return py::cast(py_value); +} + +uint64_t ConvertPyToEltTy(py::handle py_value, uint64_t zero) { + return py::cast(py_value); +} + +int64_t ConvertPyToEltTy(py::handle py_value, int64_t zero) { + return py::cast(py_value); +} + +float ConvertPyToEltTy(py::handle py_value, float zero) { + return py::cast(py_value); +} + +double ConvertPyToEltTy(py::handle py_value, double zero) { + return py::cast(py_value); +} + +half_float::half ConvertPyToEltTy(py::handle py_value, half_float::half zero) { + // Python can't cast directly to half so first go to double. + return static_cast(py::cast(py_value)); +} + +struct AddFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs + rhs; + } +}; + +struct DivideFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs / rhs; + } +}; + +struct MultiplyFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs * rhs; + } +}; + +struct SubtractFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs - rhs; + } +}; + +template +device_array ElementwiseOperation(py::handle lhs, py::handle rhs, + std::optional out, + bool device_visible) { + std::optional lhs_array; + OptionalArrayCast(lhs, lhs_array); + std::optional rhs_array; + OptionalArrayCast(rhs, rhs_array); + auto dtype = PromoteArithmeticTypes( + lhs_array ? std::optional(lhs_array->dtype()) : std::nullopt, + rhs_array ? 
std::optional(rhs_array->dtype()) : std::nullopt); + if (lhs_array && lhs_array->dtype() != dtype) { + auto converted = GenericElementwiseConvert( + *lhs_array, dtype, /*out=*/std::nullopt, + /*device_visible=*/false); + lhs_array.reset(); + lhs_array.emplace(std::move(converted)); + } + if (rhs_array && rhs_array->dtype() != dtype) { + auto converted = GenericElementwiseConvert( + *rhs_array, dtype, /*out=*/std::nullopt, + /*device_visible=*/false); + rhs_array.reset(); + rhs_array.emplace(std::move(converted)); + } + + auto compute = [&]() -> device_array { + auto handle_result = [&]( + D &&device, A &&result) -> device_array { + if (!out) { + out.emplace(device_array::for_host(device, result.shape(), dtype, + device_visible)); + } + auto out_t = out->map_xtensor_w(); + *out_t = result; + return *out; + }; + if (!rhs_array) { + auto lhs_t = lhs_array->map_xtensor(); + xt::xarray rhs_scalar = ConvertPyToEltTy(rhs, EltTy()); + return handle_result(lhs_array->device(), + ElementwiseFunctor::Invoke(*lhs_t, rhs_scalar)); + } else if (!lhs_array) { + xt::xarray lhs_scalar = ConvertPyToEltTy(lhs, EltTy()); + auto rhs_t = rhs_array->map_xtensor(); + return handle_result(rhs_array->device(), + ElementwiseFunctor::Invoke(lhs_scalar, *rhs_t)); + } else { + auto lhs_t = lhs_array->map_xtensor(); + auto rhs_t = rhs_array->map_xtensor(); + return handle_result(lhs_array->device(), + ElementwiseFunctor::Invoke(*lhs_t, *rhs_t)); + } + }; + + switch (dtype) { + SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(float32, float); + SF_UNARY_FUNCTION_CASE(float64, double); + SF_UNARY_FUNCTION_CASE(uint8, uint8_t); + SF_UNARY_FUNCTION_CASE(int8, int8_t); + SF_UNARY_FUNCTION_CASE(uint16, uint16_t); + SF_UNARY_FUNCTION_CASE(int16, int16_t); + SF_UNARY_FUNCTION_CASE(uint32, uint32_t); + SF_UNARY_FUNCTION_CASE(int32, uint32_t); + SF_UNARY_FUNCTION_CASE(uint64, uint64_t); + SF_UNARY_FUNCTION_CASE(int64, int64_t); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for in elementwise op", dtype.name())); + } +} + } // namespace void BindArrayHostOps(py::module_ &m) { @@ -457,6 +709,39 @@ void BindArrayHostOps(py::module_ &m) { SF_DEF_CONVERT("floor", GenericElementwiseConvert); SF_DEF_CONVERT("round", GenericElementwiseConvert); SF_DEF_CONVERT("trunc", GenericElementwiseConvert); -} + + // Transpose. + m.def( + "transpose", + [](device_array input, std::vector permutation, + std::optional out, bool device_visible) { + auto compute = [&]() -> device_array { + auto input_t = input.map_xtensor(); + auto permuted_t = + xt::transpose(*input_t, permutation, xt::check_policy::full()); + if (!out) { + out.emplace(device_array::for_host(input.device(), + permuted_t.shape(), + input.dtype(), device_visible)); + } + auto out_t = out->map_xtensor_w(); + *out_t = permuted_t; + return *out; + }; + SF_MOVEMENT_OP_SWITCH(input.dtype()); + }, + py::arg("input"), py::arg("permutation"), py::arg("out") = py::none(), + py::arg("device_visible") = false, DOCSTRING_TRANSPOSE); + +// Elementwise. 
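+// Each SF_DEF_ELEMENTWISE expansion below registers one binary op with the
+// Python signature op(lhs, rhs, *, out=None, device_visible=False); either
+// operand may be a device_array or a Python scalar, but not both scalars.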
+#define SF_DEF_ELEMENTWISE(py_name, target) \ + m.def(py_name, target, py::arg("lhs"), py::arg("rhs"), py::kw_only(), \ + py::arg("out") = py::none(), py::arg("device_visible") = false) + SF_DEF_ELEMENTWISE("add", ElementwiseOperation); + SF_DEF_ELEMENTWISE("divide", ElementwiseOperation); + SF_DEF_ELEMENTWISE("multiply", ElementwiseOperation); + SF_DEF_ELEMENTWISE("subtract", ElementwiseOperation); + +} // namespace shortfin::python } // namespace shortfin::python diff --git a/shortfin/python/shortfin/array/__init__.py b/shortfin/python/shortfin/array/__init__.py index 6079541c8..670102dfe 100644 --- a/shortfin/python/shortfin/array/__init__.py +++ b/shortfin/python/shortfin/array/__init__.py @@ -44,11 +44,16 @@ # Ops. argmax = _sfl.array.argmax +add = _sfl.array.add ceil = _sfl.array.ceil convert = _sfl.array.convert +divide = _sfl.array.divide fill_randn = _sfl.array.fill_randn floor = _sfl.array.floor +multiply = _sfl.array.multiply round = _sfl.array.round +subtract = _sfl.array.subtract +transpose = _sfl.array.transpose trunc = _sfl.array.trunc RandomGenerator = _sfl.array.RandomGenerator @@ -86,12 +91,17 @@ "storage", "DType", # Ops. + "add", "argmax", "ceil", "convert", + "divide", "fill_randn", "floor", + "multiply", "round", + "subtract", + "transpose", "trunc", "RandomGenerator", ] diff --git a/shortfin/src/shortfin/array/dtype.h b/shortfin/src/shortfin/array/dtype.h index d746d69bf..de1763698 100644 --- a/shortfin/src/shortfin/array/dtype.h +++ b/shortfin/src/shortfin/array/dtype.h @@ -49,6 +49,9 @@ class SHORTFIN_API DType { bool is_integer_bitwidth(size_t bitwidth) const { return iree_hal_element_type_is_integer(et_, bitwidth); } + uint32_t numerical_type() const { + return iree_hal_element_numerical_type(et_); + } // Computes the size in bytes required to store densely packed nd-dims. // This presently only supports byte aligned dtypes. In the future, when diff --git a/shortfin/tests/api/array_ops_test.py b/shortfin/tests/api/array_ops_test.py index 7c792d92b..164dfb479 100644 --- a/shortfin/tests/api/array_ops_test.py +++ b/shortfin/tests/api/array_ops_test.py @@ -268,3 +268,171 @@ def test_nearest_int_conversion(device, dtype, out_dtype, sfnp_func, ref_round_f assert output.dtype == out_dtype for ref, actual in zip(ref_rounded, output.items): assert ref == int(actual) + + +def test_elementwise_forms(device): + # All elementwise ops use the same template expansion which enforces + # certain common invariants. Here we test these on the multiply op, + # relying on a parametric test for actual behavior. + with pytest.raises( + ValueError, + match="Elementwise operators require at least one argument to be a device_array", + ): + sfnp.multiply(2, 2) + + ary = sfnp.device_array.for_host(device, [2, 3], dtype=sfnp.float32) + with ary.map(discard=True) as m: + m.fill(42.0) + + # Rhs scalar int accepted. + result = sfnp.multiply(ary, 2) + assert list(result.items) == [84.0] * 6 + + # Rhs scalar float accepted. + result = sfnp.multiply(ary, 2.0) + assert list(result.items) == [84.0] * 6 + + # Lhs scalar int accepted. + result = sfnp.multiply(2, ary) + assert list(result.items) == [84.0] * 6 + + # Lhs scalar float accepted. + result = sfnp.multiply(2.0, ary) + assert list(result.items) == [84.0] * 6 + + # Out. 
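+    # When `out=` is given, the result is written into that array (which is
+    # also returned) instead of allocating a new result array.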
+ out = sfnp.device_array.for_host(device, [2, 3], dtype=sfnp.float32) + sfnp.multiply(2.0, ary, out=out) + assert list(out.items) == [84.0] * 6 + + +@pytest.mark.parametrize( + "lhs_dtype,rhs_dtype,promoted_dtype", + [ + (sfnp.float32, sfnp.float16, sfnp.float32), + (sfnp.float16, sfnp.float32, sfnp.float32), + (sfnp.float32, sfnp.float64, sfnp.float64), + (sfnp.float64, sfnp.float32, sfnp.float64), + # Integer promotion. + (sfnp.uint8, sfnp.uint16, sfnp.uint16), + (sfnp.uint16, sfnp.uint32, sfnp.uint32), + (sfnp.uint32, sfnp.uint64, sfnp.uint64), + (sfnp.int8, sfnp.int16, sfnp.int16), + (sfnp.int16, sfnp.int32, sfnp.int32), + (sfnp.int32, sfnp.int64, sfnp.int64), + # Signed/unsigned promotion. + (sfnp.int8, sfnp.uint8, sfnp.int16), + (sfnp.int16, sfnp.uint16, sfnp.int32), + (sfnp.int32, sfnp.uint32, sfnp.int64), + (sfnp.int8, sfnp.uint32, sfnp.int64), + ], +) +def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype): + # Tests that promotion infers an appropriate result type. + lhs = sfnp.device_array.for_host(device, [2, 3], lhs_dtype) + rhs = sfnp.device_array.for_host(device, [2, 3], rhs_dtype) + result = sfnp.multiply(lhs, rhs) + assert result.dtype == promoted_dtype + + +@pytest.mark.parametrize( + "dtype,op,check_value", + [ + # Add. + (sfnp.int8, sfnp.add, 44.0), + (sfnp.int16, sfnp.add, 44.0), + (sfnp.int32, sfnp.add, 44.0), + (sfnp.int64, sfnp.add, 44.0), + (sfnp.uint8, sfnp.add, 44.0), + (sfnp.uint16, sfnp.add, 44.0), + (sfnp.uint32, sfnp.add, 44.0), + (sfnp.uint64, sfnp.add, 44.0), + (sfnp.float16, sfnp.add, 44.0), + (sfnp.float32, sfnp.add, 44.0), + (sfnp.float64, sfnp.add, 44.0), + # Divide. + (sfnp.int8, sfnp.divide, 21.0), + (sfnp.int16, sfnp.divide, 21.0), + (sfnp.int32, sfnp.divide, 21.0), + (sfnp.int64, sfnp.divide, 21.0), + (sfnp.uint8, sfnp.divide, 21.0), + (sfnp.uint16, sfnp.divide, 21.0), + (sfnp.uint32, sfnp.divide, 21.0), + (sfnp.uint64, sfnp.divide, 21.0), + (sfnp.float16, sfnp.divide, 21.0), + (sfnp.float32, sfnp.divide, 21.0), + (sfnp.float64, sfnp.divide, 21.0), + # Multiply. + (sfnp.int8, sfnp.multiply, 84.0), + (sfnp.int16, sfnp.multiply, 84.0), + (sfnp.int32, sfnp.multiply, 84.0), + (sfnp.int64, sfnp.multiply, 84.0), + (sfnp.uint8, sfnp.multiply, 84.0), + (sfnp.uint16, sfnp.multiply, 84.0), + (sfnp.uint32, sfnp.multiply, 84.0), + (sfnp.uint64, sfnp.multiply, 84.0), + (sfnp.float16, sfnp.multiply, 84.0), + (sfnp.float32, sfnp.multiply, 84.0), + (sfnp.float64, sfnp.multiply, 84.0), + # Subtract. 
+ (sfnp.int8, sfnp.subtract, 40.0), + (sfnp.int16, sfnp.subtract, 40.0), + (sfnp.int32, sfnp.subtract, 40.0), + (sfnp.int64, sfnp.subtract, 40.0), + (sfnp.uint8, sfnp.subtract, 40.0), + (sfnp.uint16, sfnp.subtract, 40.0), + (sfnp.uint32, sfnp.subtract, 40.0), + (sfnp.uint64, sfnp.subtract, 40.0), + (sfnp.float16, sfnp.subtract, 40.0), + (sfnp.float32, sfnp.subtract, 40.0), + (sfnp.float64, sfnp.subtract, 40.0), + ], +) +def test_elementwise_array_correctness(device, dtype, op, check_value): + lhs = sfnp.device_array.for_host(device, [2, 2], sfnp.int32) + with lhs.map(discard=True) as m: + m.fill(42) + + rhs = sfnp.device_array.for_host(device, [2], sfnp.int32) + with rhs.map(discard=True) as m: + m.fill(2) + + lhs = sfnp.convert(lhs, dtype=dtype) + rhs = sfnp.convert(rhs, dtype=dtype) + result = op(lhs, rhs) + assert result.shape == [2, 2] + result = sfnp.convert(result, dtype=sfnp.float32) + items = list(result.items) + assert items == [check_value] * 4 + + +@pytest.mark.parametrize( + "dtype", + [ + sfnp.int8, + sfnp.int16, + sfnp.int32, + sfnp.int64, + sfnp.uint8, + sfnp.uint16, + sfnp.uint32, + sfnp.uint64, + sfnp.float32, + sfnp.float16, + sfnp.float32, + sfnp.float64, + ], +) +def test_transpose(device, dtype): + input = sfnp.device_array.for_host(device, [3, 2], sfnp.int32) + input.items = [0, 1, 2, 3, 4, 5] + input = sfnp.convert(input, dtype=dtype) + permuted = sfnp.transpose(input, [1, 0]) + assert permuted.shape == [2, 3] + items = list(sfnp.convert(permuted, dtype=sfnp.int32).items) + assert items == [0, 2, 4, 1, 3, 5] + + out = sfnp.device_array.for_host(device, [2, 3], dtype) + sfnp.transpose(input, [1, 0], out=out) + items = list(sfnp.convert(permuted, dtype=sfnp.int32).items) + assert items == [0, 2, 4, 1, 3, 5] diff --git a/shortfin/tests/api/array_use_case_test.py b/shortfin/tests/api/array_use_case_test.py new file mode 100644 index 000000000..d4a030d45 --- /dev/null +++ b/shortfin/tests/api/array_use_case_test.py @@ -0,0 +1,64 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import array +import math +import pytest + +import shortfin as sf +import shortfin.array as sfnp + + +@pytest.fixture +def lsys(): + # TODO: Port this test to use memory type independent access. It currently + # presumes unified memory. + # sc = sf.SystemBuilder() + sc = sf.host.CPUSystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def fiber(lsys): + return lsys.create_fiber() + + +@pytest.fixture +def device(fiber): + return fiber.device(0) + + +# Tests a typical image conversion from a model oriented layout to an array +# of contained images. +def test_image_to_bytes(device): + bs = 2 + height = 16 + width = 12 + images_shape = [bs, 3, height, width] + images_planar = sfnp.device_array.for_host(device, images_shape, sfnp.float32) + # Band the data so that each channel increases by 0.1 across images. + for i in range(bs): + for j in range(3): + data = [i * 0.3 + j * 0.1 for _ in range(height * width)] + images_planar.view(i, j).items = data + images_planar = sfnp.convert(images_planar, dtype=sfnp.float16) + + # Extract and convert each image to interleaved RGB bytes. 
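+    # Each planar [1, 3, H, W] image is transposed to [1, H, W, 3], scaled
+    # from the 0..1 value range up to 0..255, and rounded to uint8 before
+    # being serialized to bytes.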
+ images = [] + for idx in range(images_planar.shape[0]): + image_planar = images_planar.view(idx) + assert image_planar.shape == [1, 3, 16, 12] + image_interleaved = sfnp.transpose(image_planar, (0, 2, 3, 1)) + assert image_interleaved.shape == [1, 16, 12, 3] + image_scaled = sfnp.multiply(image_interleaved, 255) + image = sfnp.round(image_scaled, dtype=sfnp.uint8) + image_bytes = bytes(image.map(read=True)) + images.append(image_bytes) + + assert images[0] == b"\x00\x1a3" * 192 + assert images[1] == b"Mf\x80" * 192 From a6cb4423bb0a2a221e82019e5db6adf8b215684f Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Thu, 21 Nov 2024 19:21:14 -0500 Subject: [PATCH 03/25] Fix some linux developer_guide.md comments that didn't show up at time of merge (#588) [skip ci] Clicked merge before some of #575 's comments showed up. --- docs/developer_guide.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/developer_guide.md b/docs/developer_guide.md index 6cd5f83a8..73aee61f7 100644 --- a/docs/developer_guide.md +++ b/docs/developer_guide.md @@ -15,8 +15,8 @@ sudo apt update && sudo apt install -y clang lld Install: -``` -python-is-python3 python3-venv python3-dev +```bash +sudo apt install python-is-python3 python3-venv python3-dev ```
@@ -24,6 +24,8 @@ python-is-python3 python3-venv python3-dev Or, alternatively, use `pyenv` to manage a separate python installation for more control over its version: +The following instructions are taken from pyenv's guide here: https://github.com/pyenv/pyenv?tab=readme-ov-file#a-getting-pyenv + First, install pyenv and its dependencies. ```bash From fd15aa5a30f842472faf819e60285c38e8adfa7b Mon Sep 17 00:00:00 2001 From: Avinash Sharma Date: Thu, 21 Nov 2024 17:26:09 -0800 Subject: [PATCH 04/25] Fix publish dir of llama tests, update xpass 8b test and 405b fp8 test failures (#580) Fixes publish dir of llama tests to `out/llm/llama/benchmarks`, update xpass 8b test (`testBenchmark8B_f16_Non_Decomposed_Prefill`) and 405b fp8 test failures (`testBenchmark405B_fp8_TP8_Decomposed` and `testBenchmark405B_fp8_TP8_Non_Decomposed`). --------- Signed-off-by: aviator19941 Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com> --- .github/workflows/ci-llama-large-tests.yaml | 7 ++++--- .../tests/models/llama/benchmark_amdgpu_test.py | 13 +++++++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml index 34e91cebb..ae53d3f38 100644 --- a/.github/workflows/ci-llama-large-tests.yaml +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -8,6 +8,7 @@ name: Llama Benchmarking Tests on: workflow_dispatch: + pull_request: schedule: # Weekdays at 4:00 AM UTC = 9:00 PM PST. - cron: "0 4 * * 1-5" @@ -76,14 +77,14 @@ jobs: iree-base-runtime - name: Run llama tests - run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --html=out/index.html + run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --html=out/llm/llama/benchmark/index.html - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 with: github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} - publish_dir: ./out/llm/llama/benchmarks - destination_dir: ./llm/llama/benchmarks + publish_dir: ./out/llm/llama/benchmark + destination_dir: ./llm/llama/benchmark keep_files: true - name: Upload llama executable files diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index 125a0cfdc..751615a85 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -197,7 +197,6 @@ def testBenchmark8B_f16_Decomposed(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) def testBenchmark8B_f16_Non_Decomposed_Prefill(self): output_file_name = self.dir_path_8b / "f16_torch_prefill" output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file( @@ -780,7 +779,9 @@ def testBenchmark405B_f16_TP8_Decomposed(self): cwd=self.repo_root, ) - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + @pytest.mark.xfail( + reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException + ) def testBenchmark405B_f16_TP8_Non_Decomposed(self): output_file_name = self.dir_path_405b / "f16_torch" output_mlir = self.llama405b_f16_torch_sdpa_artifacts.create_file( @@ -828,7 +829,9 @@ def testBenchmark405B_f16_TP8_Non_Decomposed(self): cwd=self.repo_root, ) - @pytest.mark.xfail(reason="Compile Error", strict=True, 
raises=IreeCompileException) + @pytest.mark.xfail( + reason="KeyError in theta.py", strict=True, raises=ExportMlirException + ) def testBenchmark405B_fp8_TP8_Decomposed(self): output_file_name = self.dir_path_405b / "fp8_decomposed" output_mlir = self.llama405b_fp8_decomposed_artifacts.create_file( @@ -874,7 +877,9 @@ def testBenchmark405B_fp8_TP8_Decomposed(self): cwd=self.repo_root, ) - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + @pytest.mark.xfail( + reason="KeyError in theta.py", strict=True, raises=ExportMlirException + ) def testBenchmark405B_fp8_TP8_Non_Decomposed(self): output_file_name = self.dir_path_405b / "fp8_torch" output_mlir = self.llama405b_fp8_torch_sdpa_artifacts.create_file( From 83f2d1b2b9c8e4abf703c1fd774fd9d1558c4af9 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Thu, 21 Nov 2024 19:07:46 -0800 Subject: [PATCH 05/25] remove presubmit for large llama tests (#591) Accidentally added as part of https://github.com/nod-ai/shark-ai/pull/580 --- .github/workflows/ci-llama-large-tests.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml index ae53d3f38..644066094 100644 --- a/.github/workflows/ci-llama-large-tests.yaml +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -8,7 +8,6 @@ name: Llama Benchmarking Tests on: workflow_dispatch: - pull_request: schedule: # Weekdays at 4:00 AM UTC = 9:00 PM PST. - cron: "0 4 * * 1-5" From 779adc3d5add624cb80d5bcb579705dd42956f4b Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Thu, 21 Nov 2024 20:16:21 -0800 Subject: [PATCH 06/25] [sharktank] Add perplexity CI to sharktank dashboard (#466) Add perplexity CI to sharktank dashboard Rename _vmfb to _iree Add Perplexity scoreboard & description Add descriptive errors for better logging --------- Co-authored-by: Marius Brehler --- .github/workflows/ci_eval.yaml | 30 ++++++++++++---- sharktank/sharktank/evaluate/README.md | 20 +++++++++-- ...{perplexity_vmfb.py => perplexity_iree.py} | 18 +++++++--- .../sharktank/evaluate/perplexity_torch.py | 18 +++++++--- sharktank/sharktank/utils/export_artifacts.py | 18 +++++++--- .../evaluate/baseline_perplexity_scores.json | 2 +- ...y_vmfb_test.py => perplexity_iree_test.py} | 34 +++++++++---------- 7 files changed, 98 insertions(+), 42 deletions(-) rename sharktank/sharktank/evaluate/{perplexity_vmfb.py => perplexity_iree.py} (96%) rename sharktank/tests/evaluate/{perplexity_vmfb_test.py => perplexity_iree_test.py} (92%) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 54aa3c763..e6794bff9 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -21,10 +21,10 @@ concurrency: cancel-in-progress: true jobs: - test_perplexity_vmfb: + test_perplexity_iree: if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} timeout-minutes: 1000 - name: "IREE/vmfb" + name: "Perplexity-IREE" strategy: matrix: version: [3.11] @@ -74,13 +74,21 @@ jobs: iree-base-compiler \ iree-base-runtime - - name: Run perplexity test with vmfb - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json + - name: Run perplexity test 
with IREE + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html + + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + with: + github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} + publish_dir: ./out/llm/llama/perplexity/iree_perplexity + destination_dir: ./llm/llama/perplexity/iree_perplexity + keep_files: true test_perplexity_torch: if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} timeout-minutes: 1000 - name: "Torch/eager mode" + name: "Perplexity-Torch" strategy: matrix: version: [3.11] @@ -123,5 +131,13 @@ jobs: pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" - - name: Run perplexity test in eager mode - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json + - name: Run perplexity test with Torch + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html + + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + with: + github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} + publish_dir: ./out/llm/llama/perplexity/torch_perplexity + destination_dir: ./llm/llama/perplexity/torch_perplexity + keep_files: true diff --git a/sharktank/sharktank/evaluate/README.md b/sharktank/sharktank/evaluate/README.md index 784bb24fd..beb0281cd 100644 --- a/sharktank/sharktank/evaluate/README.md +++ b/sharktank/sharktank/evaluate/README.md @@ -9,16 +9,32 @@ pip install -r sharktank/requirements-tests.txt ### Perplexity -Test perplexity for Llama3.1 8B & 405B (FP16 & FP8) models: +Perplexity score measures the ability of a language model to predict the next token in a sequence. A lower score indicates that a model has higher certainty in it's predictions. Perplexity acts as an intrinsic evaluation metric that measures the model quality, independent of any downstream task. + +In SHARK-Platform, we use perplexity to track code regressions and quality loss across quantized models (with FP16 as baseline). We use 100 prompts randomly selected from the Wikitext-2 test set and calculate the mean perplexities shown below. These numbers are neither comparable between models with different tokenizers nor with other projects due to varying implementations. 
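For intuition, perplexity is the exponential of the mean negative log-likelihood that the model assigns to each next token. The snippet below is only a minimal illustration of that arithmetic with made-up log-probabilities; it is not the sharktank implementation.

```python
import math

def perplexity(token_logprobs: list[float]) -> float:
    """exp of the mean negative log-likelihood over the predicted tokens."""
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)

# Confident predictions (log-probs close to 0) yield low perplexity ...
print(perplexity([-0.2, -0.4, -0.1]))  # ~1.26
# ... while uncertain predictions yield high perplexity.
print(perplexity([-2.3, -3.1, -1.9]))  # ~11.4
```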
+ +* Test perplexity for Llama3.1 8B (FP16) model: ```bash pytest sharktank/tests/evaluate/perplexity_test.py --longrun ``` -Get perplexity for a new model: +* Calculate perplexity for a new model: ```bash python -m sharktank.evaluate.perplexity \ --gguf-file=llama3_70b_f16.gguf \ --tokenizer-config-json=tokenizer_config.json ``` + +### Perplexity Scoreboard + +| CPU | GPU | +|:-------------: |:----------:| +| AMD EPYC 9554 | MI300X | + +#### LLaMA 3.1 + +|Models |Model size (GB) |Torch score |IREE score | +|:----------------------|:---------------|:-------------|:-------------| +|8B FP16 TP1 decomposed |16.07 |14.930181 |14.991893 | diff --git a/sharktank/sharktank/evaluate/perplexity_vmfb.py b/sharktank/sharktank/evaluate/perplexity_iree.py similarity index 96% rename from sharktank/sharktank/evaluate/perplexity_vmfb.py rename to sharktank/sharktank/evaluate/perplexity_iree.py index 4f95ae1bd..9701bed34 100644 --- a/sharktank/sharktank/evaluate/perplexity_vmfb.py +++ b/sharktank/sharktank/evaluate/perplexity_iree.py @@ -9,6 +9,7 @@ import json import time import random +import re from datetime import timedelta from tqdm import tqdm @@ -83,11 +84,18 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - seconds = end - start - time_taken = abs(timedelta(seconds=round(seconds))) - - if seconds < 1: - time_taken = f" {seconds * 1000} ms" + total_seconds = end - start + time_taken = abs(timedelta(seconds=total_seconds)) + hours, minutes, seconds = re.split(":", str(time_taken)) + + if total_seconds < 1: + time_taken = f" {round(total_seconds * 1000, 3)} ms" + elif total_seconds < 60: + time_taken = "{:.2f} secs".format(round(float(total_seconds), 2)) + else: + time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format( + int(hours), int(minutes), round(float(seconds), 2) + ) func_name = func.__name__ if func_name == "get_perplexity": diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py index fc3aa5fca..da5fc104a 100644 --- a/sharktank/sharktank/evaluate/perplexity_torch.py +++ b/sharktank/sharktank/evaluate/perplexity_torch.py @@ -8,6 +8,7 @@ import logging import time import random +import re from datetime import timedelta import json import numpy as np @@ -69,11 +70,18 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - seconds = end - start - time_taken = abs(timedelta(seconds=round(seconds))) - - if seconds < 1: - time_taken = f" {seconds * 1000} ms" + total_seconds = end - start + time_taken = abs(timedelta(seconds=total_seconds)) + hours, minutes, seconds = re.split(":", str(time_taken)) + + if total_seconds < 1: + time_taken = f" {round(total_seconds * 1000, 3)} ms" + elif total_seconds < 60: + time_taken = "{:.2f} secs".format(round(float(total_seconds), 2)) + else: + time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format( + int(hours), int(minutes), round(float(seconds), 2) + ) func_name = func.__name__ if func_name == "get_perplexity": diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index bd33e1a62..e7851ac37 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -9,6 +9,7 @@ import subprocess import logging import time +import re from pathlib import Path from datetime import timedelta from typing import List, Optional @@ -107,11 +108,18 @@ def wrapper(*args, **kwargs): start = time.time() result = 
func(*args, **kwargs) end = time.time() - seconds = end - start - time_taken = abs(timedelta(seconds=round(seconds))) - - if seconds < 1: - time_taken = f" {seconds * 1000} ms" + total_seconds = end - start + time_taken = abs(timedelta(seconds=total_seconds)) + hours, minutes, seconds = re.split(":", str(time_taken)) + + if total_seconds < 1: + time_taken = f" {round(total_seconds * 1000, 3)} ms" + elif total_seconds < 60: + time_taken = "{:.2f} secs".format(round(float(total_seconds), 2)) + else: + time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format( + int(hours), int(minutes), round(float(seconds), 2) + ) func_name = func.__name__ logger.info(f" {func_name}: {time_taken}") diff --git a/sharktank/tests/evaluate/baseline_perplexity_scores.json b/sharktank/tests/evaluate/baseline_perplexity_scores.json index ac2cd7b83..24511b05f 100644 --- a/sharktank/tests/evaluate/baseline_perplexity_scores.json +++ b/sharktank/tests/evaluate/baseline_perplexity_scores.json @@ -210,7 +210,7 @@ ], "mean_perplexity": 6.060831 }, - "llama3_8B_f16_decomposed_vmfb": { + "llama3_8B_f16_decomposed_iree": { "perplexities": [ 6.651368, 22.059452, diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py similarity index 92% rename from sharktank/tests/evaluate/perplexity_vmfb_test.py rename to sharktank/tests/evaluate/perplexity_iree_test.py index 93ffbe61c..8cf2055c9 100644 --- a/sharktank/tests/evaluate/perplexity_vmfb_test.py +++ b/sharktank/tests/evaluate/perplexity_iree_test.py @@ -8,7 +8,7 @@ import pytest import json -from sharktank.evaluate import perplexity_vmfb +from sharktank.evaluate import perplexity_iree longrun = pytest.mark.skipif("not config.getoption('longrun')") @@ -32,10 +32,10 @@ def test_llama3_8B_f16_decomposed(self): # Llama 3.1 8B decomposed - model_name = "llama3_8B_f16_decomposed_vmfb" + model_name = "llama3_8B_f16_decomposed_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_8b_f16_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", @@ -67,10 +67,10 @@ def test_llama3_8B_f16(self): # Llama 3.1 8B non-decomposed - model_name = "llama3_8B_f16_vmfb" + model_name = "llama3_8B_f16_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_8b_f16_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", @@ -102,10 +102,10 @@ def test_llama3_8B_fp8_decomposed(self): # Llama 3.1 8B decomposed - model_name = "llama3_8B_fp8_decomposed_vmfb" + model_name = "llama3_8B_fp8_decomposed_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_8b_fp8_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", @@ -137,10 +137,10 @@ def test_llama3_8B_fp8(self): # Llama 3.1 8B non-decomposed - model_name = "llama3_8B_fp8_vmfb" + model_name = "llama3_8B_fp8_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_8b_fp8_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", @@ -172,10 +172,10 @@ def test_llama3_405B_f16_decomposed(self): # Llama 3.1 405B decomposed - model_name = 
"llama3_405B_f16_decomposed_vmfb" + model_name = "llama3_405B_f16_decomposed_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_405b_f16_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", @@ -207,10 +207,10 @@ def test_llama3_405B_f16(self): # Llama 3.1 405B non-decomposed - model_name = "llama3_405B_f16_vmfb" + model_name = "llama3_405B_f16_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_405b_f16_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", @@ -242,10 +242,10 @@ def test_llama3_405B_fp8_decomposed(self): # Llama 3.1 405B decomposed - model_name = "llama3_405B_fp8_decomposed_vmfb" + model_name = "llama3_405B_fp8_decomposed_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_405b_fp8_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", @@ -277,10 +277,10 @@ def test_llama3_405B_fp8(self): # Llama 3.1 405B non-decomposed - model_name = "llama3_405B_fp8_vmfb" + model_name = "llama3_405B_fp8_iree" baseline_perplexity = self.baseline_perplexity[model_name] - current_perplexity = perplexity_vmfb.main( + current_perplexity = perplexity_iree.main( [ f"--irpa-file={self.llama3_405b_fp8_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", From 530f4bdb9ce4e2b886ea34265441ff1ea3f9ee85 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Fri, 22 Nov 2024 11:51:23 -0500 Subject: [PATCH 07/25] [tuner]: use python binding to select mma intrinsics (#586) This PR is relevant to the task in https://github.com/nod-ai/shark-ai/issues/453: " Use IREE attributes for MFMA intrinsics in the tuner". --------- Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 10 ++++- tuner/tuner/common.py | 13 +++++- tuner/tuner/common_test.py | 51 ++++++++++++++++++++++-- tuner/tuner/dispatch_constraints.py | 15 +++++-- tuner/tuner/dispatch_constraints_test.py | 26 +++++++++++- 5 files changed, 106 insertions(+), 9 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 5786a9fff..38696e6db 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -30,6 +30,8 @@ from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore + from .common import * from .dispatch_constraints import * from .dispatch_parser import * @@ -535,13 +537,19 @@ def tune( walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry) + variant_op_list = iree_codegen.get_executable_variant_ops(mlir_module) + assert len(variant_op_list) == 1, "Expect one executable variant op" + variant_op = variant_op_list[0] + # Get the MMA intrinisic intructions supported by the target. 
+ mma_list = iree_codegen.query_mma_intrinsics(variant_op) + dispatch_tuner = walk_result.dispatch_tuner assert dispatch_tuner, "No suitable dispatch tuner found" problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) tune_logger.debug(str(problem_size)) configs = [] for i, config in enumerate( - generate_solutions(tune_logger, problem_size, num_subgroups) + generate_solutions(tune_logger, problem_size, num_subgroups, mma_list) ): if i >= limit: break diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index a34f172eb..b6e31768e 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -12,6 +12,8 @@ from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore + class CommonTypes: def __init__(self, ctx: ir.Context): @@ -130,7 +132,12 @@ def all(): ] -def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: +def get_compatible_mfma_intrinsics( + problem_size: ProblemSize, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], +) -> list[MfmaIntrinsic]: + available_mma_intrinsics = [str(mma) for mma in mma_intrinsics] + def is_compatible(intrinsic: MfmaIntrinsic) -> bool: if problem_size.res_type.element_type != intrinsic.output_type: return False @@ -139,6 +146,10 @@ def is_compatible(intrinsic: MfmaIntrinsic) -> bool: return False if problem_size.rhs_type.element_type != intrinsic.input_type: return False + + if str(intrinsic) not in available_mma_intrinsics: + return False + return True return list(filter(is_compatible, MfmaIntrinsic.all())) diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 891d703e2..297ac95a2 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -14,6 +14,7 @@ from typing import Generator from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore @pytest.fixture @@ -109,7 +110,11 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([1280, 1280], tuner_ctx.type.f16), common.ShapedType([2048, 1280], tuner_ctx.type.f32), common.DispatchKind.mmt, - ) + ), + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + ], ) == [ common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), @@ -122,7 +127,11 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([1280, 1280], tuner_ctx.type.i8), common.ShapedType([2048, 1280], tuner_ctx.type.i32), common.DispatchKind.mmt, - ) + ), + [ + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], ) == [ common.MfmaIntrinsic.mfma_i32_16x16x32_i8(), common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), @@ -135,8 +144,44 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([64, 640, 320], tuner_ctx.type.f32), common.ShapedType([64, 968, 320], tuner_ctx.type.f32), common.DispatchKind.batch_matmul, - ) + ), + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + ], ) == [ common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), ] + + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], tuner_ctx.type.f32), + common.ShapedType([64, 640, 320], tuner_ctx.type.f32), + common.ShapedType([64, 968, 320], tuner_ctx.type.f32), + common.DispatchKind.batch_matmul, 
+ ), + [ + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + ], + ) == [ + common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + ] + + assert ( + common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], tuner_ctx.type.f32), + common.ShapedType([64, 640, 320], tuner_ctx.type.f32), + common.ShapedType([64, 968, 320], tuner_ctx.type.f32), + common.DispatchKind.batch_matmul, + ), + [ + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], + ) + == [] + ) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index edd7ccc38..85039a1e8 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -10,6 +10,9 @@ import z3 # type: ignore from typing import Iterator + +from iree.compiler.dialects import iree_gpu # type: ignore + from .common import * @@ -18,8 +21,9 @@ def get_mfma_intrinsic_constraints( intrinsic_m: z3.ArithRef, intrinsic_n: z3.ArithRef, intrinsic_k: z3.ArithRef, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], ) -> z3.BoolRef: - compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size) + compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics) assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" return z3.Or( *( @@ -68,6 +72,7 @@ def generate_constraints( subgroup_m_count, subgroup_n_count, waves_per_eu, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], ): M, N, K = ( problem_size.matmul_size.M, @@ -82,7 +87,7 @@ def generate_constraints( constraints += [subgroup_size == 64, wg_threads <= 1024] constraints += [ get_mfma_intrinsic_constraints( - problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k + problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k, mma_intrinsics ) ] subgroup_k_count = 1 @@ -130,7 +135,10 @@ def generate_constraints( def generate_solutions( - logger: logging.Logger, problem_size: ProblemSize, num_subgrups: int + logger: logging.Logger, + problem_size: ProblemSize, + num_subgrups: int, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], ) -> Iterator[Configuration]: M, N, K = problem_size.MNK logger.info(f"{M},{N},{K}") @@ -168,6 +176,7 @@ def generate_solutions( sg_m_cnt, sg_n_cnt, waves_per_eu, + mma_intrinsics, ) solver.add(z3.simplify(z3.And(constraints))) logger.debug(f"Initial constraints: {solver}") diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py index 7e1a5c55d..9de4beeee 100644 --- a/tuner/tuner/dispatch_constraints_test.py +++ b/tuner/tuner/dispatch_constraints_test.py @@ -14,6 +14,7 @@ from typing import Generator from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore from . import common from . 
import dispatch_constraints @@ -37,7 +38,18 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None: problem_size = common.ProblemSize( matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) - configs = dispatch_constraints.generate_solutions(tuner_ctx.logger, problem_size, 4) + configs = dispatch_constraints.generate_solutions( + tuner_ctx.logger, + problem_size, + 4, + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], + ) + assert configs is not None @@ -115,6 +127,12 @@ def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> Non sg_m_cnt, sg_n_cnt, waves_per_eu, + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], ) solver = z3.Solver() @@ -160,6 +178,12 @@ def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> N sg_m_cnt, sg_n_cnt, waves_per_eu, + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], ) constraints.append(m > 1000) # Adding an additional unsatisfiable constraint From e2c2f013199f3cf7de57df291e837cec2373f90d Mon Sep 17 00:00:00 2001 From: Stephen Baione <109226581+stbaione@users.noreply.github.com> Date: Fri, 22 Nov 2024 11:35:35 -0600 Subject: [PATCH 08/25] Sglang User Doc (#498) # Description Adds documentation for running SGLang with Shortfin LLM Server. Currently, only focus on the sglang docs. I created this PR from the same repo as the other shortfin llm server docs. Those diffs should go away once that is merged. It links to the existing `Shortfin LLM Server User Doc` to setup and run shortfin. It then shows how to install SGLang inside of the same virtual environment. From there it has instructions for running a `Multi-Turn Q&A Flow`, `Fork Flow`, and how to run the `Benchmark` script against the shortfin server. --- docs/shortfin/llm/user/e2e_llama8b_mi300x.md | 2 +- .../shortfin_with_sglang_frontend_language.md | 254 ++++++++++++++++++ 2 files changed, 255 insertions(+), 1 deletion(-) create mode 100644 docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md diff --git a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md index 5e0749546..4a8423bc8 100644 --- a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md +++ b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md @@ -64,7 +64,7 @@ We will use the `hf_datasets` module in `sharktank` to download a LLama3.1 8b f16 model. ```bash -python -m sharktank.utils.hf_datasets amd-shark/llama3.1-8B --local-dir $EXPORT_DIR +python -m sharktank.utils.hf_datasets llama3_8B_fp16 --local-dir $EXPORT_DIR ``` ### Define environment variables diff --git a/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md new file mode 100644 index 000000000..b63861a56 --- /dev/null +++ b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md @@ -0,0 +1,254 @@ +# Using `shortfin` with `sglang` + +This doc includes basic steps for hooking up sglang with a running Shortfin server. 
+ +## Current Support Status + +| Feature | Description | Enabled | Reference | +| ----------- | ----------- | ---------- | ------------ | +| `gen` | Generate shortfin completion, given a prompt | ✅ | [Shortfin Implementation](https://github.com/nod-ai/sglang/blob/main/python/sglang/lang/backend/shortfin.py) | +| `streaming` | Stream shortfin completion, given a prompt | ✅ | [Streaming](https://sgl-project.github.io/frontend/frontend.html#streaming) | +| `run_batch` | Run batch of disjoint requests with continous batching | ✅ | [Batching](https://sgl-project.github.io/frontend/frontend.html#batching) | +| `fork` | Generate sections of the same prompt in parallel | ✅ | [Fork Docs](https://sgl-project.github.io/frontend/frontend.html#parallelism) | +| `choices` | Given set of choices, generate response based on best log probs | ❌ | [Choices Methods](https://sgl-project.github.io/frontend/choices_methods.html#choices-methods-in-sglang) | +| `image` | Pass image as part of multi-modal prompt | ❌ | [sgl.image](https://sgl-project.github.io/frontend/frontend.html#multi-modality) | +| `regex` | Specify regular expression as decoding constraint | ❌ | [Regex](https://sgl-project.github.io/frontend/frontend.html#constrained-decoding) | + +## Prerequisites + +For this tutorial, you will need to meet the following prerequisites: + +### Software + +- Python >= 3.11 + - You can check out [pyenv](https://github.com/pyenv/pyenv) + as a good tool to be able to manage multiple versions of python + on the same system. +- A running `shortfin` LLM server as described [below](#installstart-shortfin-llm-server) + - We will use the shortfin server as the `backend` to generate completions + from SGLang's `frontend language`. In this tutorial, you can think of + `sglang` as the client and `shortfin` as the server. + +### Hardware + +- This tutorial is designed to run on an [AMD MI300X GPU](https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html) + +## Install/Start `shortfin` LLM server + +Follow the steps [here](https://github.com/nod-ai/shark-ai/blob/main/docs/shortfin/llm/user/e2e_llama8b_mi300x.md) +to export a model with `sharktank` and start a `shortfin` LLM server +with that model. + +## Install sglang + +### Install sglang inside of virtual environment + +Currently, we have our SGLang integration located at this [forked repo](https://github.com/nod-ai/sglang). +We can use pip to install it in the same virtual environment that we used +to start our Shortfin LLM Server. 
+ +```bash +pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" +``` + +## Getting started + +You can verify the installation/setup through the following examples: + +- [Multi-Turn Q&A Example](#multi-turn-qa-example) +- [Fork Example](#fork-example) +- [Benchmark Shortfin](#bench-mark-shortfin-w-sglang-bench_serving-script) + +## Multi-Turn Q&A example + +Now that we have sglang installed, we can run an example to show a multi-turn +Q&A flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html): + +### Open python interpreter + +```bash +python +``` + +### Run example + +You can copy and paste the following example into your interpreter: + +```python +import sglang as sgl + +from sglang.lang.chat_template import get_chat_template + +backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000", ) # Change base_url if running at different address + +sgl.set_default_backend(backend) + +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) + +state = multi_turn_question.run(question_1="Name the capital city of the USA.", question_2="The Smithsonian is in this location.") + +for m in state.messages(): + print(m["role"], m["content"]) +``` + +### Shortfin example output + +You should see an output similar to this: + +```text +========== single ========== + +user : Name the capital city of the USA +assistant : The capital city of the United States of America is Washington, D.C. (short for District of Columbia). +user : The Smithsonian is in this location. +assistant : The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes. +``` + +## Fork example + +Now that we have sglang installed, we can run an example to show a `fork` +flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html): + +### Open python interpreter + +```bash +python +``` + +### Run example + +You can copy and paste the following example into your interpreter: + +```python +import sglang as sgl + +from sglang.lang.chat_template import get_chat_template + +backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000") # Change base_url if running at different address + +sgl.set_default_backend(backend) + +@sgl.function +def tip_suggestion(s): + s += ( + "Here are two tips for staying healthy: " + "1. Balanced Diet. 2. Regular Exercise.\n\n" + ) + forks = s.fork(2) + for i, f in enumerate(forks): + f += f"Now, expand tip {i+1} into a paragraph:\n" + f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n") + s += "Tip 1:" + forks[0]["detailed_tip"] + "\n" + s += "Tip 2:" + forks[1]["detailed_tip"] + "\n" + s += "In summary" + sgl.gen("summary") + +state = tip_suggestion.run() + +print(state.text()) +``` + +### Shortfin example output + +You should see an output similar to this: + +```text +Here are two tips for staying healthy: 1. Balanced Diet. 2. Regular Exercise. + +Tip 1:A balanced diet is important for maintaining good health. It should +include a variety of foods from all the major food groups, such as fruits, +vegetables, grains, proteins, and dairy. Eating a balanced diet can help +prevent chronic diseases such as heart disease, diabetes, and obesity. 
+ +Now, expand tip 2 into a paragraph: +Regular exercise is also important for maintaining good health. It can help +improve cardiovascular health, strengthen muscles and bones, and reduce the +risk of chronic diseases. Exercise can also help improve mental health by +reducing stress and anxiety. It is recommended that adults get at least 150 +minutes of moderate-intensity exercise or 75 minutes of vigorous-intensity +exercise per week. + +Now, combine the two paragraphs into a single paragraph: +A balanced diet and regular exercise are both important for maintaining good +health. A balanced diet should include a variety of foods from all the major +food groups, such as fruits, vegetables, grains, proteins, and dairy. +Eating a balanced diet can help prevent chronic diseases such as heart disease, +diabetes, and obesity. Regular exercise is also important for maintaining good +health. It can help improve cardiovascular health, strengthen muscles and bones, +and reduce the risk of chronic diseases. Exercise can also help improve mental +health by reducing stress and anxiety. It is recommended that + +Tip 2:Regular exercise is important for maintaining a healthy body and mind. +It can help improve cardiovascular health, strengthen muscles and bones, +and reduce the risk of chronic diseases such as diabetes and heart disease. +Additionally, exercise has been shown to improve mood, reduce stress, +and increase overall well-being. It is recommended that adults engage in +at least 150 minutes of moderate-intensity aerobic activity or 75 minutes of +vigorous-intensity aerobic activity per week, as well as strength training +exercises at least two days per week. + +In summary, a balanced diet and regular exercise are both essential for +maintaining good health. A balanced diet should include a variety of foods from +all the major food groups, while regular exercise can help improve +cardiovascular health, strengthen muscles and bones, reduce the risk of +chronic diseases, and improve mental health. It is recommended that adults +engage in at least 150 minutes of moderate-intensity aerobic activity or +75 minutes of vigorous-intensity aerobic activity per week, +as well as strength training exercises at least two days per week. 
+``` + +## Benchmark shortfin w/ sglang `bench_serving` script + +We can obtain benchmarking metrics using the `bench_serving` script +provided by SGLang: + +**NOTE: Change `--base-url` if running at a different address** + +```bash +python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer /path/to/tokenizer/dir --request-rate 1 +``` + +There are some more metrics captured, but the most relevant are the following: + +- E2E Latency +- TTFT (Time to First Token) +- TPOT (Time per Output Token) +- ITL (Inter-Token Latency) +- Request Throughput +- Benchmark Duration + +When complete, you should see an output similar to this: + +```text +============ Serving Benchmark Result ============ +Backend: shortfin +Traffic request rate: 1.0 +Successful requests: 10 +Benchmark duration (s): 427.91 +Total input tokens: 1960 +Total generated tokens: 2774 +Total generated tokens (retokenized): 63 +Request throughput (req/s): 0.02 +Input token throughput (tok/s): 4.58 +Output token throughput (tok/s): 6.48 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 416268.77 +Median E2E Latency (ms): 417159.14 +---------------Time to First Token---------------- +Mean TTFT (ms): 292404.29 +Median TTFT (ms): 365989.01 +P99 TTFT (ms): 367325.63 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 1359.41 +Median TPOT (ms): 163.96 +P99 TPOT (ms): 6316.12 +---------------Inter-token Latency---------------- +Mean ITL (ms): 2238.99 +Median ITL (ms): 958.75 +P99 ITL (ms): 2719.50 +================================================== +``` From e37b934a7eb3841f2a609cf94fa6b684838729ba Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 22 Nov 2024 20:34:00 +0100 Subject: [PATCH 09/25] [shortfin] Remove fastapi from 3.13t requirements (#596) This removes fastapi as a requirements as it pulls in pydantic which depends on pydantic-core. The latter cannot be build for a free-threaded Python and as part of the 0.27.1 release bumped PyO3 to 0.22.6 to prevent accidental installs on free-threaded Python. Supersedes and closes #595. --- shortfin/requirements-tests-nogil.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/shortfin/requirements-tests-nogil.txt b/shortfin/requirements-tests-nogil.txt index 1049b0412..1769467ab 100644 --- a/shortfin/requirements-tests-nogil.txt +++ b/shortfin/requirements-tests-nogil.txt @@ -1,4 +1,3 @@ pytest requests -fastapi uvicorn From eacbd9bc2319392cbf6cf5ed1bc086168860361c Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:22:36 -0800 Subject: [PATCH 10/25] [sharktank] Add Perplexity pre-submit test (#579) Add a perplexity pre-submit test for llama3.1 8b fp16 with 5 prompts --- .github/workflows/ci_eval.yaml | 4 +- .github/workflows/ci_eval_short.yaml | 77 +++++++++ .../sharktank/evaluate/perplexity_iree.py | 102 +++++++----- .../sharktank/evaluate/perplexity_torch.py | 101 +++++++----- sharktank/sharktank/utils/export_artifacts.py | 8 +- sharktank/sharktank/utils/load_llm.py | 23 ++- .../tests/evaluate/perplexity_iree_test.py | 152 ++++++++++-------- 7 files changed, 302 insertions(+), 165 deletions(-) create mode 100644 .github/workflows/ci_eval_short.yaml diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index e6794bff9..0164b6cdc 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -name: CI - Perplexity +name: CI - sharktank perplexity on: workflow_dispatch: @@ -75,7 +75,7 @@ jobs: iree-base-runtime - name: Run perplexity test with IREE - run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml new file mode 100644 index 000000000..4622f5c57 --- /dev/null +++ b/.github/workflows/ci_eval_short.yaml @@ -0,0 +1,77 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - sharktank perplexity short + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test_perplexity_iree: + name: "Llama3.1 8B FP16" + strategy: + matrix: + version: [3.11] + runs-on: [llama-mi300x-3] + fail-fast: false + runs-on: ${{matrix.runs-on}} + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }} + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + + - name: Install sharktank deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Install latest iree-tubrine. 
+ pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + + # Try with the latest IREE nightly releases, not what iree-turbine pins. + # We could also pin to a known working or stable version. + # This should eventually stabilize. Do the best we can for now. + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler \ + iree-base-runtime + + - name: Run perplexity test with vmfb + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json diff --git a/sharktank/sharktank/evaluate/perplexity_iree.py b/sharktank/sharktank/evaluate/perplexity_iree.py index 9701bed34..6060eb91b 100644 --- a/sharktank/sharktank/evaluate/perplexity_iree.py +++ b/sharktank/sharktank/evaluate/perplexity_iree.py @@ -99,9 +99,9 @@ def wrapper(*args, **kwargs): func_name = func.__name__ if func_name == "get_perplexity": - func_name = f"Total time to calculate perplexity" + func_name = f"Calculate perplexity" elif func_name == "compile_model": - func_name = f"Total time to export and compile" + func_name = f"Export & compile" logger.info(f" {func_name}: {time_taken}") return result @@ -127,7 +127,7 @@ def print_token_comparison(self, i): def compile_model(self, weight_path_str): self.weight_path_str = weight_path_str - logger.info(f"Compiling: {self.weight_path_str}") + logger.info(f" Compiling: {self.weight_path_str}") export_artifacts = ExportArtifacts( irpa_path=self.weight_path_str, @@ -143,7 +143,7 @@ def compile_model(self, weight_path_str): @timeit def load_model(self, weight_path, tokenizer, vmfb_path): - config = LlamaModelConfig( + self.config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(weight_path.properties), block_seq_stride=16, kv_cache_type=self.kv_cache_type, @@ -153,18 +153,18 @@ def load_model(self, weight_path, tokenizer, vmfb_path): tensor_parallelism_size=self.tensor_parallelism_size, ) - if config.tensor_parallelism_size > 1: - weight_path.root_theta = shard_theta(weight_path.root_theta, config) + if self.config.tensor_parallelism_size > 1: + weight_path.root_theta = shard_theta(weight_path.root_theta, self.config) theta = weight_path.root_theta - if config.hp.expert_count: - if config.hp.model_arch == "grok": - model = PagedGrokModelV1(theta, config) + if self.config.hp.expert_count: + if self.config.hp.model_arch == "grok": + model = PagedGrokModelV1(theta, self.config) else: - model = PagedMixtralModelV1(theta, config) + model = PagedMixtralModelV1(theta, self.config) else: - model = PagedLlamaModelV1(theta, config) + model = PagedLlamaModelV1(theta, self.config) self.generator = TorchGenerator(model, tokenizer) @@ -177,7 +177,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path): self.haldevice = self.runner.config.device @timeit - def get_prompts(self): + def get_prompts(self, num_prompts): test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[ "text" ] @@ -191,12 +191,15 @@ def get_prompts(self): s.replace("\n", "").rstrip() for s in test_prompts if s != "" and len(s.split()) >= 20 and s.count("=") < 2 - ] + ][0:num_prompts] + + self.test_prompts = test_prompts self.bs = len(test_prompts) - return test_prompts + logger.info(f" Batch size: {self.bs}") + @timeit def prefill_vmfb(self, 
token_batch, i): seq_block_ids = self.batch.pad_block_ids() @@ -252,25 +255,7 @@ def decode_vmfb(self, token_batch, i): return decode_logits @timeit - def get_logits(self): - - token_ids, seq_lens = self.generator.tokenizer.encode( - self.test_prompts, - pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, - ) - - logger.info(f" Prompts for Evaluation:") - for idx, prompt in enumerate(self.test_prompts): - logger.info( - f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" - ) - - self.max_prompt_length = max(seq_lens) - - self.token_ids = torch.tensor(token_ids, device=self.torch_device) - self.attention_mask = ( - (self.token_ids != 0).int().detach().clone().to(self.torch_device) - ) + def get_logits(self, page_cache_size): is_first_token = True start = 0 @@ -306,6 +291,7 @@ def get_logits(self): token_batch=token_batch, seq_lens_batch=self.seq_lens_batch, bs=self.bs, + page_cache_size=page_cache_size, ) self.cache_state = ireert.asdevicearray( @@ -355,11 +341,31 @@ def compute_perplexity(self): } @timeit - def get_perplexity(self, test_prompts): + def get_perplexity(self): - self.test_prompts = test_prompts + token_ids, seq_lens = self.generator.tokenizer.encode( + self.test_prompts, + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + self.page_cache_size = ( + len(token_ids[0]) // self.config.block_seq_stride + ) * self.bs + 1 + + logger.debug(f" Prompts for Evaluation:") + for idx, prompt in enumerate(self.test_prompts): + logger.debug( + f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" + ) + + self.max_prompt_length = max(seq_lens) + + self.token_ids = torch.tensor(token_ids, device=self.torch_device) + self.attention_mask = ( + (self.token_ids != 0).int().detach().clone().to(self.torch_device) + ) - self.get_logits() + self.get_logits(page_cache_size=self.page_cache_size) self.out_logits = self.out_logits[..., :-1, :].contiguous() self.token_ids = self.token_ids[..., 1:].contiguous() @@ -387,7 +393,9 @@ def run_perplexity( kv_cache_type, tensor_parallelism_size, attention_kernel, + num_prompts, ): + start = time.time() perplexity = Perplexity( torch_device=torch_device, iree_device=iree_device, @@ -398,12 +406,19 @@ def run_perplexity( attention_kernel=attention_kernel, ) - test_prompts = perplexity.get_prompts() - logger.info(f" Total test prompts: {len(test_prompts)}") + perplexity.get_prompts(num_prompts=num_prompts) vmfb_path = perplexity.compile_model(weight_path_str) perplexity.load_model(weight_path, tokenizer, vmfb_path) - ppl = perplexity.get_perplexity(test_prompts) + ppl = perplexity.get_perplexity() + + end = time.time() + total_time = round(end - start, 2) + if total_time < 60: + total_time = str(total_time) + " secs" + else: + total_time = str(round(total_time / 60, 2)) + " mins" + logger.info(f" Total time taken: {total_time}") return ppl @@ -412,7 +427,7 @@ def main(argv): parser = cli.create_parser() parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") parser.add_argument("--torch-device", help="Torch device (or default)") - parser.add_argument("--iree-device", help="List an IREE device, eg: 'hip://0'") + parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')") parser.add_argument( "--iree-hip-target", action="store", @@ -437,6 +452,12 @@ def main(argv): default=1, help="Number of devices for tensor parallel sharding", ) + parser.add_argument( + "--num-prompts", + type=int, + default=100, + help="Number of prompts for 
perplexity test", + ) cli.add_tokenizer_options(parser) cli.add_input_dataset_options(parser) @@ -460,6 +481,7 @@ def main(argv): kv_cache_type=kv_cache_type, tensor_parallelism_size=args.tensor_parallelism_size, attention_kernel=args.attention_kernel, + num_prompts=args.num_prompts, ) logger.info(f"\n{json.dumps(ppl, indent=2)}") diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py index da5fc104a..258e8c9a0 100644 --- a/sharktank/sharktank/evaluate/perplexity_torch.py +++ b/sharktank/sharktank/evaluate/perplexity_torch.py @@ -85,7 +85,7 @@ def wrapper(*args, **kwargs): func_name = func.__name__ if func_name == "get_perplexity": - func_name = "Total time" + func_name = "Calculate perplexity" logger.info(f" {func_name}: {time_taken}") return result @@ -110,7 +110,7 @@ def print_token_comparison(self, i): @timeit def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kernel): - config = LlamaModelConfig( + self.config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(dataset.properties), block_seq_stride=16, kv_cache_type=self.kv_cache_type, @@ -120,23 +120,23 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kern tensor_parallelism_size=tensor_parallelism_size, ) - if config.tensor_parallelism_size > 1: - dataset.root_theta = shard_theta(dataset.root_theta, config) + if self.config.tensor_parallelism_size > 1: + dataset.root_theta = shard_theta(dataset.root_theta, self.config) theta = dataset.root_theta - if config.hp.expert_count: - if config.hp.model_arch == "grok": - model = PagedGrokModelV1(theta, config) + if self.config.hp.expert_count: + if self.config.hp.model_arch == "grok": + model = PagedGrokModelV1(theta, self.config) else: - model = PagedMixtralModelV1(theta, config) + model = PagedMixtralModelV1(theta, self.config) else: - model = PagedLlamaModelV1(theta, config) + model = PagedLlamaModelV1(theta, self.config) self.generator = TorchGenerator(model, tokenizer) @timeit - def get_prompts(self): + def get_prompts(self, num_prompts): test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[ "text" @@ -152,34 +152,16 @@ def get_prompts(self): s.replace("\n", "").rstrip() for s in test_prompts if s != "" and len(s.split()) >= 20 and s.count("=") < 2 - ] - - logger.info(f" num_test_prompts: {len(test_prompts)}") - - return test_prompts + ][0:num_prompts] - @timeit - def get_logits(self): - - token_ids, seq_lens = self.generator.tokenizer.encode( - self.test_prompts, - pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, - ) - - logger.info(f" Prompts for Evaluation:") - for idx, prompt in enumerate(self.test_prompts): - logger.info( - f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" - ) + self.test_prompts = test_prompts - self.max_prompt_length = max(seq_lens) + self.bs = len(test_prompts) - self.token_ids = torch.tensor(token_ids, device=self.device) - self.attention_mask = ( - (self.token_ids != 0).int().detach().clone().to(self.device) - ) + logger.info(f" Batch size: {self.bs}") - self.bs = len(self.test_prompts) + @timeit + def get_logits(self, page_cache_size): is_first_token = True start = 0 @@ -212,6 +194,7 @@ def get_logits(self): token_batch=token_batch, seq_lens_batch=seq_lens_batch, bs=self.bs, + page_cache_size=page_cache_size, ) self.batch.prefill() @@ -268,10 +251,31 @@ def compute_perplexity(self): } @timeit - def get_perplexity(self, test_prompts): + def get_perplexity(self): - 
self.test_prompts = test_prompts - self.get_logits() + token_ids, seq_lens = self.generator.tokenizer.encode( + self.test_prompts, + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + self.page_cache_size = ( + len(token_ids[0]) // self.config.block_seq_stride + ) * self.bs + 1 + + logger.debug(f" Prompts for Evaluation:") + for idx, prompt in enumerate(self.test_prompts): + logger.debug( + f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" + ) + + self.max_prompt_length = max(seq_lens) + + self.token_ids = torch.tensor(token_ids, device=self.device) + self.attention_mask = ( + (self.token_ids != 0).int().detach().clone().to(self.device) + ) + + self.get_logits(page_cache_size=self.page_cache_size) self.out_logits = self.out_logits[..., :-1, :].contiguous() self.token_ids = self.token_ids[..., 1:].contiguous() @@ -295,12 +299,22 @@ def run_perplexity_torch( kv_cache_type, tensor_parallelism_size, attention_kernel, + num_prompts, ): - perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type) + start = time.time() + perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type) + perplexity.get_prompts(num_prompts=num_prompts) perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel) - test_prompts = perplexity.get_prompts() - ppl = perplexity.get_perplexity(test_prompts=test_prompts) + ppl = perplexity.get_perplexity() + + end = time.time() + total_time = round(end - start, 2) + if total_time < 60: + total_time = str(total_time) + " secs" + else: + total_time = str(round(total_time / 60, 2)) + " mins" + logger.info(f" Total time taken: {total_time}") return ppl @@ -322,6 +336,12 @@ def main(argv): default=1, help="Number of devices for tensor parallel sharding.", ) + parser.add_argument( + "--num-prompts", + type=int, + default=100, + help="Number of prompts for perplexity test", + ) cli.add_input_dataset_options(parser) cli.add_tokenizer_options(parser) @@ -339,6 +359,7 @@ def main(argv): kv_cache_type=kv_cache_type, tensor_parallelism_size=args.tensor_parallelism_size, attention_kernel=args.attention_kernel, + num_prompts=args.num_prompts, ) logger.info(f"\n{json.dumps(ppl, indent=2)}") diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index e7851ac37..c950a875a 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -188,13 +188,13 @@ def export_to_mlir( cwd = self.sharktank_dir cmd = subprocess.list2cmdline(export_args) - logger.info(f"Exporting mlir:\n" f"cd {cwd} && {cmd}") + logger.info(f" Exporting mlir:\n" f"cd {cwd} && {cmd}") proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd, text=True) if proc.returncode != 0: raise ExportMlirException(proc, cwd) else: - logger.info(f"Exported to mlir successfully:\n" f"{proc.stdout}") + logger.info(f" Exported to mlir successfully:\n" f"{proc.stdout}") return proc.returncode @@ -231,7 +231,7 @@ def compile_to_vmfb( compile_args += args cmd = subprocess.list2cmdline(compile_args) - logging.getLogger().info(f"Launching compile command:\n" f"cd {cwd} && {cmd}") + logger.info(f" Launching compile command:\n" f"cd {cwd} && {cmd}") proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd) return_code = proc.returncode if return_code != 0: @@ -285,7 +285,7 @@ def iree_benchmark_vmfb( benchmark_args += devices benchmark_args += args cmd = subprocess.list2cmdline(benchmark_args) - 
logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}") + logger.info(f" Launching run command:\n" f"cd {cwd} && {cmd}") proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd) return_code = proc.returncode if return_code != 0: diff --git a/sharktank/sharktank/utils/load_llm.py b/sharktank/sharktank/utils/load_llm.py index acf56eb1b..47d9f0244 100644 --- a/sharktank/sharktank/utils/load_llm.py +++ b/sharktank/sharktank/utils/load_llm.py @@ -23,24 +23,20 @@ def __init__( self, model: PagedLlamaModelV1, tokenizer: InferenceTokenizer, - page_cache_size: int = 8192, # Need to look at the model more for this. end_token: int = 2, ): self.model = model self.tokenizer = tokenizer - if model.cache.is_paged: - self.shared_cache_state = model.cache.paged.allocate(page_cache_size) - self.free_pages = list(range(1, page_cache_size)) - else: - self.shared_cache_state = None self.end_token = end_token @property def block_seq_stride(self) -> int: return self.model.cache.block_seq_stride - def begin_batch(self, prompts: list[str], add_start_token: bool): + def begin_batch( + self, prompts: list[str], add_start_token: bool, page_cache_size: int = 128 + ): token_ids, seq_lens = self.tokenizer.encode( prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride, @@ -48,8 +44,10 @@ def begin_batch(self, prompts: list[str], add_start_token: bool): ) token_ids = torch.tensor(token_ids, device=self.model.device) seq_lens = torch.tensor(seq_lens, device=self.model.device) - if self.shared_cache_state is not None: - cache_state = self.shared_cache_state + + if self.model.cache.is_paged: + cache_state = self.model.cache.paged.allocate(page_cache_size) + self.free_pages = list(range(1, page_cache_size)) else: cache_state = self.model.cache.direct.allocate(bs=len(prompts)) return Batch(self, token_ids, seq_lens, cache_state) @@ -59,10 +57,11 @@ def begin_eval_batch( token_batch: torch.tensor, seq_lens_batch: torch.tensor, bs: int, + page_cache_size: int = 128, ): - - if self.shared_cache_state is not None: - cache_state = self.shared_cache_state + if self.model.cache.is_paged: + cache_state = self.model.cache.paged.allocate(page_cache_size) + self.free_pages = list(range(1, page_cache_size)) else: cache_state = self.model.cache.direct.allocate(bs=bs) return Batch(self, token_batch, seq_lens_batch, cache_state) diff --git a/sharktank/tests/evaluate/perplexity_iree_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py index 8cf2055c9..d10d9f5db 100644 --- a/sharktank/tests/evaluate/perplexity_iree_test.py +++ b/sharktank/tests/evaluate/perplexity_iree_test.py @@ -7,10 +7,15 @@ import unittest import pytest import json +import numpy as np from sharktank.evaluate import perplexity_iree -longrun = pytest.mark.skipif("not config.getoption('longrun')") +is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'") +skipif_run_quick_llama_test = pytest.mark.skipif( + 'not config.getoption("run-nightly-llama-tests")', + reason="Run large tests if --run-nightly-llama-tests is passed", +) @pytest.mark.usefixtures( @@ -18,7 +23,9 @@ "get_iree_flags", "tensor_parallelism_size", "baseline_perplexity_scores", + "batch_size", ) +@is_mi300x class PerplexityTest(unittest.TestCase): def setUp(self): self.current_perplexity_all = {} @@ -27,7 +34,6 @@ def setUp(self): with open(self.baseline_perplexity_scores, "r") as f: self.baseline_perplexity = json.load(f) - @longrun def test_llama3_8B_f16_decomposed(self): # Llama 3.1 8B decomposed @@ -44,25 +50,26 @@ def 
test_llama3_8B_f16_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="Non-decomposed attention is not supported yet", - ) - @longrun + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") def test_llama3_8B_f16(self): # Llama 3.1 8B non-decomposed @@ -79,25 +86,26 @@ def test_llama3_8B_f16(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) - @longrun + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") def test_llama3_8B_fp8_decomposed(self): # Llama 3.1 8B decomposed @@ -114,25 +122,26 @@ def test_llama3_8B_fp8_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) - @longrun + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") def test_llama3_8B_fp8(self): # Llama 3.1 8B non-decomposed @@ -149,25 +158,28 @@ def test_llama3_8B_fp8(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size=1", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 
: self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) + @skipif_run_quick_llama_test @pytest.mark.xfail( reason="Sharding is unsupported", ) - @longrun def test_llama3_405B_f16_decomposed(self): # Llama 3.1 405B decomposed @@ -184,25 +196,26 @@ def test_llama3_405B_f16_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="Non-decomposed attention is not supported yet", - ) - @longrun + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") def test_llama3_405B_f16(self): # Llama 3.1 405B non-decomposed @@ -219,25 +232,26 @@ def test_llama3_405B_f16(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) - @longrun + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") def test_llama3_405B_fp8_decomposed(self): # Llama 3.1 405B decomposed @@ -254,25 +268,26 @@ def test_llama3_405B_fp8_decomposed(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + 
baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) - @longrun + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") def test_llama3_405B_fp8(self): # Llama 3.1 405B non-decomposed @@ -289,17 +304,20 @@ def test_llama3_405B_fp8(self): f"--iree-hip-target={self.iree_hip_target}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) From 9d9f0d309edd03fa3e5ddb7f7f6a6dd4ee6b87ef Mon Sep 17 00:00:00 2001 From: Stephen Baione <109226581+stbaione@users.noreply.github.com> Date: Mon, 25 Nov 2024 09:01:55 -0600 Subject: [PATCH 11/25] Move SGLang related tests (#601) Split from this PR: https://github.com/nod-ai/shark-ai/pull/590 We have too many tests running on `mi300x-3` and need to move the SGLang related ones to `mi300x-4`. This PR moves the workflows for `sglang_integration_tests` and `sglang_benchmark_tests` to mi300x-4, along with removing the assumption of static MODEL_PATH and TOKENIZER_PATH, downloading them on demand instead. 
--- .github/workflows/ci-sglang-benchmark.yml | 4 +- .../workflows/ci-sglang-integration-tests.yml | 2 +- .../llm/sglang_benchmarks/__init__.py | 5 +++ .../llm/{ => sglang_benchmarks}/conftest.py | 16 ++++++-- .../sglang_benchmark_test.py | 41 ++++++++----------- .../llm/{ => sglang_benchmarks}/utils.py | 13 ++++++ app_tests/integration_tests/llm/utils.py | 2 +- 7 files changed, 52 insertions(+), 31 deletions(-) create mode 100644 app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py rename app_tests/benchmark_tests/llm/{ => sglang_benchmarks}/conftest.py (74%) rename app_tests/benchmark_tests/llm/{ => sglang_benchmarks}/sglang_benchmark_test.py (76%) rename app_tests/benchmark_tests/llm/{ => sglang_benchmarks}/utils.py (84%) diff --git a/.github/workflows/ci-sglang-benchmark.yml b/.github/workflows/ci-sglang-benchmark.yml index 504e7e5e3..f44e2772b 100644 --- a/.github/workflows/ci-sglang-benchmark.yml +++ b/.github/workflows/ci-sglang-benchmark.yml @@ -28,7 +28,7 @@ jobs: matrix: version: [3.11] fail-fast: false - runs-on: llama-mi300x-3 + runs-on: mi300x-4 defaults: run: shell: bash @@ -78,7 +78,7 @@ jobs: run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" - name: Launch Shortfin Server - run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html + run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 diff --git a/.github/workflows/ci-sglang-integration-tests.yml b/.github/workflows/ci-sglang-integration-tests.yml index 1c382617d..c61756d78 100644 --- a/.github/workflows/ci-sglang-integration-tests.yml +++ b/.github/workflows/ci-sglang-integration-tests.yml @@ -29,7 +29,7 @@ jobs: matrix: version: [3.11] fail-fast: false - runs-on: llama-mi300x-3 + runs-on: mi300x-4 defaults: run: shell: bash diff --git a/app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/app_tests/benchmark_tests/llm/conftest.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py similarity index 74% rename from app_tests/benchmark_tests/llm/conftest.py rename to app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py index cc354b7eb..1e1c64b24 100644 --- a/app_tests/benchmark_tests/llm/conftest.py +++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py @@ -9,15 +9,22 @@ import pytest import sys -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) -from integration_tests.llm.utils import compile_model, export_paged_llm_v1 +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +) +from integration_tests.llm.utils import ( + compile_model, + export_paged_llm_v1, + download_with_hf_datasets, +) @pytest.fixture(scope="module") def pre_process_model(request, tmp_path_factory): tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test") - model_path = request.param["model_path"] + model_name = request.param["model_name"] + model_param_file_name = request.param["model_param_file_name"] settings = request.param["settings"] batch_sizes = request.param["batch_sizes"] @@ -25,6 +32,9 @@ def pre_process_model(request, tmp_path_factory): config_path = tmp_dir / "config.json" vmfb_path = tmp_dir / "model.vmfb" + model_path = tmp_dir / model_param_file_name + download_with_hf_datasets(tmp_dir, model_name) + export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes) config = { diff --git a/app_tests/benchmark_tests/llm/sglang_benchmark_test.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py similarity index 76% rename from app_tests/benchmark_tests/llm/sglang_benchmark_test.py rename to app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py index 0de775795..b66904570 100644 --- a/app_tests/benchmark_tests/llm/sglang_benchmark_test.py +++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py @@ -4,7 +4,6 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import json import logging import multiprocessing import os @@ -16,14 +15,14 @@ pytest.importorskip("sglang") from sglang import bench_serving -from utils import SGLangBenchmarkArgs +from .utils import SGLangBenchmarkArgs, log_jsonl_result from integration_tests.llm.utils import ( find_available_port, start_llm_server, ) -logger = logging.getLogger("__name__") +logger = logging.getLogger(__name__) device_settings = { "device_flags": [ @@ -33,30 +32,21 @@ "device": "hip", } -# TODO: Download on demand instead of assuming files exist at this path -MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa") -TOKENIZER_DIR = Path("/data/llama3.1/8b/") - - -def log_jsonl_result(file_path): - with open(file_path, "r") as file: - json_string = file.readline().strip() - - json_data = json.loads(json_string) - for key, val in json_data.items(): - logger.info(f"{key.upper()}: {val}") - @pytest.mark.parametrize( - "request_rate", - [1, 2, 4, 8, 16, 32], + "request_rate,model_param_file_name", + [ + (req_rate, "meta-llama-3.1-8b-instruct.f16.gguf") + for req_rate in [1, 2, 4, 8, 16, 32] + ], ) @pytest.mark.parametrize( "pre_process_model", [ ( { - "model_path": MODEL_PATH, + "model_name": "llama3_8B_fp16", + "model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf", "settings": device_settings, "batch_sizes": [1, 4], } @@ -64,7 +54,9 @@ def log_jsonl_result(file_path): ], indirect=True, ) -def test_sglang_benchmark_server(request_rate, pre_process_model): +def test_sglang_benchmark_server( + request_rate, model_param_file_name, pre_process_model +): # TODO: Remove when multi-device is fixed os.environ["ROCR_VISIBLE_DEVICES"] = "1" @@ -72,7 +64,8 @@ def test_sglang_benchmark_server(request_rate, pre_process_model): config_path = tmp_dir / "config.json" vmfb_path = tmp_dir / "model.vmfb" - tokenizer_path = TOKENIZER_DIR / "tokenizer.json" + tokenizer_path = tmp_dir / "tokenizer.json" + model_path = tmp_dir / model_param_file_name # Start shortfin llm server port = find_available_port() @@ -81,7 +74,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model): tokenizer_path, config_path, vmfb_path, - MODEL_PATH, + model_path, device_settings, timeout=30, ) @@ -91,7 +84,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model): backend="shortfin", num_prompt=10, base_url=f"http://localhost:{port}", - tokenizer=TOKENIZER_DIR, + tokenizer=tmp_dir, request_rate=request_rate, ) output_file = ( @@ -116,7 +109,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model): logger.info("======== RESULTS ========") log_jsonl_result(benchmark_args.output_file) except Exception as e: - logger.info(e) + logger.error(e) server_process.terminate() server_process.wait() diff --git a/app_tests/benchmark_tests/llm/utils.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py similarity index 84% rename from app_tests/benchmark_tests/llm/utils.py rename to app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py index 55b01da04..47cea4d76 100644 --- a/app_tests/benchmark_tests/llm/utils.py +++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py @@ -6,8 +6,12 @@ from argparse import Namespace from dataclasses import dataclass +import json +import logging from pathlib import Path +logger = logging.getLogger(__name__) + @dataclass class SGLangBenchmarkArgs: @@ -54,3 +58,12 @@ def __repr__(self): f"Tokenizer: {self.tokenizer}\n" f"Request Rate: {self.request_rate}" ) + + +def log_jsonl_result(file_path): + with 
open(file_path, "r") as file: + json_string = file.readline().strip() + + json_data = json.loads(json_string) + for key, val in json_data.items(): + logger.info(f"{key.upper()}: {val}") diff --git a/app_tests/integration_tests/llm/utils.py b/app_tests/integration_tests/llm/utils.py index 05712039e..80b5b3c09 100644 --- a/app_tests/integration_tests/llm/utils.py +++ b/app_tests/integration_tests/llm/utils.py @@ -15,7 +15,7 @@ import requests from transformers import AutoTokenizer -logger = logging.getLogger("__name__") +logger = logging.getLogger(__name__) class AccuracyValidationException(RuntimeError): From bf8540f482f453f4fa04ac750381135e99815bb1 Mon Sep 17 00:00:00 2001 From: Kyle Wang Date: Mon, 25 Nov 2024 23:14:27 +0800 Subject: [PATCH 12/25] Add a DPP intro in amdgpu_kernel_optimization_guide.md (#598) --- docs/amdgpu_kernel_optimization_guide.md | 123 ++++++++++++++++++++++- 1 file changed, 122 insertions(+), 1 deletion(-) diff --git a/docs/amdgpu_kernel_optimization_guide.md b/docs/amdgpu_kernel_optimization_guide.md index 09c5b59f9..91b7f1385 100644 --- a/docs/amdgpu_kernel_optimization_guide.md +++ b/docs/amdgpu_kernel_optimization_guide.md @@ -4,7 +4,7 @@ Author: Jakub Kuderski @kuhar Date: 2024-06-24 -Last Update: 2024-08-22 +Last Update: 2024-11-22 ## Introduction @@ -293,3 +293,124 @@ forms a *clause* that translates to a single data fabric transaction. > [!TIP] > For allocations of 4 GB or less, you can implement predicated loads using the > `buffer` instructions. + +## Data-Parallel Primitives and Warp-level Reduction + +For cross-lane data sharing, the most straightforward way is LDS. Some lanes +write data to some locations on LDS and other lanes read data from LDS. Besides, +there are several instructions can be used to share data cross lanes within a +wavefront/warp. + +Here's a brief introduction of these instructions. Please check out [this +blog](https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/) for +details. + +### ds_permute/ds_bpermute + +`ds_permute`/`ds_bpermute` instructions use LDS hardware for data sharing but +don't actually write to an LDS location. But it still needs `s_waitcnt` +instruction to determine when data is returned to `dest` VGPR. + +Example: +```nasm +ds_bpermute_b32 dest, addr, src [offset:addr_offset] +``` + +### ds_swizzle + +Compared to `ds_bpermute`, the `ds_swizzle` instruction doesn't require an +additional VGPR for offset since it's encoded in the instruction. + +`ds_swizzle` is likely to have less address generation instructions required +than `ds_bpermute`. + +The cons are: +1. It only supports limited patterns. +2. Similar to `ds_bpermute`, `s_waitcnt` is required to wait for the `dest` VGPR. + +Example: +```nasm +ds_swizzle_b32 dest, src offset:ds_pattern +``` + +### Data-Parallel Primitives, DPP + +DPP is a 32-bit instruction modifier appended to the normal VALU instructions. +It allows VALU instructions to access data in neighboring lanes directly, which +means it doesn't need LDS hardware anymore, hence `s_waitcnt` instructions are +**not required**. + +Unfortunately, it also supported limited patterns like `ds_swizzle`. And there +are some instructions that can't be modified by DPP. + +Example: +```nasm +; Normal VALU instruction. +v_add_f32 + +; Instruction modified by DPP. 
+v_add_f32_dpp +``` + +It's worth mentioning that DPP has different names and syntaxes on different +architectures: +* CDNA: DPP +* RDNA: DPP8/DPP16 + +For details, please check the [MI300 ISA Reference +Guide](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf) +and the [RDNA3 ISA Reference +Guide](https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna3-shader-instruction-set-architecture-feb-2023_0.pdf). + +### How to use them in MLIR + +Each instruction has a corresponding Op in MLIR (except for `ds_permute`, this +one is not implemented at the time of writing): +* `ds_bpermute`: `rocdl.ds_bpermute` +* `ds_swizzle`: `rocdl.ds_swizzle` +* DPP: `rocdl.update.dpp`, `amdgpu.dpp` (a thin wrapper around + `rocdl.update.dpp` with more comprehensive user interface, e.g., replace magic + numbers with enums) + +The first 2 are straightforward, while DPP follows a different fashion. + +Since DPP is an instruction modifier instead of an instruction itself, there are +tremendous number of combinations of VALU instructions and DPP. To solve that, +`rocdl.update.dpp` and `amdgpu.dpp` are designed to be a wrapper of +`v_mov_b32_dpp` instruction. And it depends on LLVM compiler to fuse it with the +subsequent VALU instruction **with best efforts**. + +For example, `v_mov_b32_dpp` + `v_add_f32_e32` might be fused into `v_add_f32_dpp`. + +There are plenty of constraints stopping an instruction from being merged. For +example, if either the `bank_mask` or the `row_mask` is not `0xf`, it can't be +fused. You can check the +[GCNDPPCombine::combineDPPMov](https://github.com/llvm/llvm-project/blob/ab51eccf88f5321e7c60591c5546b254b6afab99/llvm/lib/Target/AMDGPU/GCNDPPCombine.cpp#L522) +function to see how it works. + +### Comparison + +To summarize, there's no free lunch: instruction's expressivity comes at the +expense of performance. + +The relative performance of cross-lane instructions is as follows: + +DPP > `ds_swizzle` >= `ds_permute` > `ds_bpermute` + +while the generality ranking is the reverse: + +DPP < `ds_swizzle` < `ds_permute` < `ds_bpermute` + +This table presents the approximate instruction latency, collected +experimentally on Fused Softmax kernel with +[rocprofv2](https://github.com/ROCm/rocprofiler?tab=readme-ov-file#plugin-support) +on the MI300 GPU: + +| Instructions | MLIR Op | Hardware | latency/#cycles | +| ---------------------- | ---------------------------- | ------------ | --------------- | +| ds_permute/ds_bpermute | rocdl.ds_bpermute | LDS hardware | ~50* | +| ds_swizzle | rocdl.ds_swizzle | LDS hardware | ~50* | +| DPP | rocdl.update.dpp, amdgpu.dpp | VALU | 4~12 | + +*: For `ds_permute`/`ds_bpermute` and `ds_swizzle`, the latency includes the +instruction itself and its corresponding `s_waitcnt` instruction. From e906b669b071d1804704b619910b6774f0604070 Mon Sep 17 00:00:00 2001 From: Stephen Baione <109226581+stbaione@users.noreply.github.com> Date: Mon, 25 Nov 2024 10:53:50 -0600 Subject: [PATCH 13/25] See if stale dependencies are causing shortfin server to fail to start (#604) # Description We started seeing a failure in `Shortfin CPU LLM Integration Test` after merging #601. 
However, the only aspect of the integration test that that PR touches is a fix in the logger: Old ```python logger = logging.getLogger("__name__") ``` New ```python logger = logging.getLogger(__name__) ``` That shouldn't have an impact on the test, and while reading the output of the workflow, it didn't seem to be the line that caused the server to not start. When testing locally in a fresh environment, the test ran fine, which made me think that it may be related to stale dependencies. I updated the hash of cached pip to take into account requirement changes in `sharktank` and `shortfin`, which appears to fix the test. --- .github/workflows/ci-shark-ai.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-shark-ai.yml b/.github/workflows/ci-shark-ai.yml index bf8007e65..fc85a76a7 100644 --- a/.github/workflows/ci-shark-ai.yml +++ b/.github/workflows/ci-shark-ai.yml @@ -49,7 +49,7 @@ jobs: id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','shortfin/requirements*.txt','sharktank/requirements*.txt') }} - name: Install pip deps run: | From 0e74c394037784e46a1f898078d3142b04d91662 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 25 Nov 2024 19:26:30 +0100 Subject: [PATCH 14/25] [shortfin] Merge Windows and Linux workflows (#603) Merges the Windows and Linux workflows by implementing a matrix strategy. In addition, this adds compiling the shortfin code with GCC to the CI and makes the pip cache more explicit as it is now defined which requirements file to include in the hash. --- ...x64-libshortfin.yml => ci-libshortfin.yml} | 92 ++++++++++------- .../workflows/ci_windows_x64-libshortfin.yml | 98 ------------------- 2 files changed, 58 insertions(+), 132 deletions(-) rename .github/workflows/{ci_linux_x64-libshortfin.yml => ci-libshortfin.yml} (54%) delete mode 100644 .github/workflows/ci_windows_x64-libshortfin.yml diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci-libshortfin.yml similarity index 54% rename from .github/workflows/ci_linux_x64-libshortfin.yml rename to .github/workflows/ci-libshortfin.yml index afeca11a6..33a6df72b 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci-libshortfin.yml @@ -10,13 +10,13 @@ on: workflow_dispatch: pull_request: paths: - - '.github/workflows/ci_linux_x64-libshortfin.yml' + - '.github/workflows/ci-libshortfin.yml' - 'shortfin/**' push: branches: - main paths: - - '.github/workflows/ci_linux_x64-libshortfin.yml' + - '.github/workflows/ci-libshortfin.yml' - 'shortfin/**' permissions: @@ -36,17 +36,55 @@ env: jobs: build-and-test: - name: Build and test - runs-on: ubuntu-24.04 + name: "Unit tests :: ${{ matrix.name }} :: ${{ matrix.python-version }}" + runs-on: ${{ matrix.runs-on }} + defaults: + run: + shell: bash strategy: + fail-fast: false matrix: + name: ["Ubuntu (Clang)(full)", "Ubuntu (Clang)(host-only)", "Ubuntu (GCC)", "Windows (MSVC)"] python-version: ["3.10", "3.11", "3.12"] + include: + - name: Ubuntu (Clang)(full) + runs-on: ubuntu-24.04 + cmake-options: + -DCMAKE_C_COMPILER=clang-18 -DCMAKE_CXX_COMPILER=clang++-18 -DCMAKE_LINKER_TYPE=LLD + additional-packages: clang lld + - name: Ubuntu (Clang)(host-only) + runs-on: ubuntu-24.04 + # In this configuration, also build static+dynamic in order to verify + # that path structurally works. 
+ cmake-options: + -DCMAKE_C_COMPILER=clang-18 -DCMAKE_CXX_COMPILER=clang++-18 -DCMAKE_LINKER_TYPE=LLD -DSHORTFIN_HAVE_AMDGPU=OFF -DSHORTFIN_BUILD_STATIC=ON -DSHORTFIN_BUILD_DYNAMIC=ON + additional-packages: clang lld + - name: Ubuntu (GCC) + runs-on: ubuntu-24.04 + - name: Windows (MSVC) + runs-on: windows-2022 + exclude: + # Only test Python 3.12 with GCC + - name: Ubuntu (GCC) + python-version: "3.10" + - name: Ubuntu (GCC) + python-version: "3.11" + # TODO: Include additional Python versions for Windows after build got fixed + - name: Windows (MSVC) + python-version: "3.10" + - name: Windows (MSVC) + python-version: "3.11" steps: - - name: Install dependencies + - name: (Linux) Install dependencies + if: "runner.os == 'Linux'" run: | sudo apt update - sudo apt install clang lld cmake ninja-build + sudo apt install cmake ninja-build ${{matrix.additional-packages}} + + - name: (Windows) Configure MSVC + if: "runner.os == 'Windows'" + uses: ilammy/msvc-dev-cmd@0b201ec74fa43914dc39ae48a89fd1d8cb592756 # v1.13.0 - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -70,56 +108,42 @@ jobs: git submodule update --init --depth 1 -- third_party/googletest git submodule update --init --depth 1 -- third_party/hip-build-deps/ - - name: Setup Python ${{ matrix.python-version }} + - name: "Setup Python ${{ matrix.python-version }}" uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} cache: "pip" + cache-dependency-path: | + 'shortfin/requirements-tests.txt' + 'shortfin/requirements-iree-compiler.txt' - name: Install Python packages - # TODO: Switch to `pip install -r requirements.txt -e shortfin/`. working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | pip install -r requirements-tests.txt pip install -r requirements-iree-compiler.txt pip freeze - - name: Build shortfin (full) + - name: Build shortfin working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | mkdir build cmake -GNinja \ -S. \ -Bbuild \ - -DCMAKE_C_COMPILER=clang-18 \ - -DCMAKE_CXX_COMPILER=clang++-18 \ - -DCMAKE_LINKER_TYPE=LLD \ - -DSHORTFIN_BUNDLE_DEPS=ON \ -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ + ${{matrix.cmake-options}} cmake --build build --target all - pip install -v -e build/ - - name: Test shortfin (full) + - name: pip install shortfin + if: ${{ matrix.name != 'Ubuntu (Clang)(host-only)'}} working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - ctest --timeout 30 --output-on-failure --test-dir build - pytest -s + pip install -v -e build/ - - name: Build shortfin (host-only) + - name: Test shortfin + if: ${{ matrix.name != 'Ubuntu (Clang)(host-only)'}} working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - mkdir build-host-only - # In this configuration, also build static+dynamic in order to verify - # that path structurally works. - cmake -GNinja \ - -S. 
\ - -Bbuild-host-only \ - -DCMAKE_C_COMPILER=clang-18 \ - -DCMAKE_CXX_COMPILER=clang++-18 \ - -DCMAKE_LINKER_TYPE=LLD \ - -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ - -DSHORTFIN_HAVE_AMDGPU=OFF \ - -DSHORTFIN_BUILD_STATIC=ON \ - -DSHORTFIN_BUILD_DYNAMIC=ON - cmake --build build-host-only --target all + ctest --timeout 30 --output-on-failure --test-dir build + pytest -s diff --git a/.github/workflows/ci_windows_x64-libshortfin.yml b/.github/workflows/ci_windows_x64-libshortfin.yml deleted file mode 100644 index 00873c432..000000000 --- a/.github/workflows/ci_windows_x64-libshortfin.yml +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -name: CI - shortfin - Windows - -on: - workflow_dispatch: - pull_request: - paths: - - '.github/workflows/ci_windows_x64-libshortfin.yml' - - 'shortfin/**' - push: - branches: - - main - paths: - - '.github/workflows/ci_windows_x64-libshortfin.yml' - - 'shortfin/**' - -permissions: - contents: read - -concurrency: - # A PR number if a pull request and otherwise the commit hash. This cancels - # queued and in-progress runs for the same PR (presubmit) or commit - # (postsubmit). The workflow name is prepended to avoid conflicts between - # different workflows. - group: ${{ github.workflow }}-${{ github.event.number || github.sha }} - cancel-in-progress: true - -env: - IREE_REPO_DIR: ${{ github.workspace }}/iree - LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/ - -jobs: - build-and-test: - name: Build and test - runs-on: windows-2022 - - steps: - - name: Configure MSVC - uses: ilammy/msvc-dev-cmd@0b201ec74fa43914dc39ae48a89fd1d8cb592756 # v1.13.0 - - - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - submodules: false - - - name: Checkout IREE repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: iree-org/iree - path: ${{ env.IREE_REPO_DIR }} - submodules: false - ref: iree-3.0.0rc20241118 - - - name: Initalize IREE submodules - working-directory: ${{ env.IREE_REPO_DIR }} - run : | - git submodule update --init --depth 1 -- third_party/benchmark - git submodule update --init --depth 1 -- third_party/cpuinfo/ - git submodule update --init --depth 1 -- third_party/flatcc - git submodule update --init --depth 1 -- third_party/googletest - git submodule update --init --depth 1 -- third_party/hip-build-deps/ - - - name: Setup Python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: "3.12" - cache: "pip" - - name: Install Python packages - working-directory: ${{ env.LIBSHORTFIN_DIR }} - run: | - pip install -r requirements-tests.txt - pip install -r requirements-iree-compiler.txt - pip freeze - - - name: Build shortfin (full) - working-directory: ${{ env.LIBSHORTFIN_DIR }} - shell: bash - run: | - mkdir build - cmake -GNinja \ - -S. 
\ - -Bbuild \ - -DSHORTFIN_BUNDLE_DEPS=ON \ - -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON - cmake --build build --target all - pip install -v -e build/ - - - name: Test shortfin (full) - working-directory: ${{ env.LIBSHORTFIN_DIR }} - run: | - ctest --timeout 30 --output-on-failure --test-dir build - pytest -s From ddc3091f0be1b962e7fcaa8da2056728febf761d Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Mon, 25 Nov 2024 19:37:13 -0500 Subject: [PATCH 15/25] Replace AttnPagedCache with BasePagedAttentionCache (#565) Creates space for #593 (prefix-sharing) Coming next: #607 , which should be the last thing I do before I can check in my blocktrie implementation. Summary of changes: - copied over stella's cache.py and renamed it to page_pool.py - each inference request now notifies the cache when its pages are done written to --- .../ci_linux_x64_nogil-libshortfin.yml | 2 +- shortfin/python/shortfin_apps/llm/_deps.py | 26 ++- .../shortfin_apps/llm/components/cache.py | 111 ------------ .../kvcache/base_attention_cache.py | 80 +++++++++ .../llm/components/kvcache/page_pool.py | 159 ++++++++++++++++++ .../shortfin_apps/llm/components/messages.py | 15 +- .../shortfin_apps/llm/components/service.py | 39 ++++- .../tests/apps/llm/components/cache_test.py | 94 ----------- .../llm/components/kvcache/page_pool_test.py | 57 +++++++ shortfin/tests/conftest.py | 33 ++++ 10 files changed, 387 insertions(+), 229 deletions(-) delete mode 100644 shortfin/python/shortfin_apps/llm/components/cache.py create mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py create mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py delete mode 100644 shortfin/tests/apps/llm/components/cache_test.py create mode 100644 shortfin/tests/apps/llm/components/kvcache/page_pool_test.py diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml index 0e0e1db2a..550366e1b 100644 --- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml @@ -98,6 +98,6 @@ jobs: - name: Run shortfin Python tests (full) working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py --ignore=tests/apps/sd + pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/sd # TODO: Enable further tests and switch to # pytest -s diff --git a/shortfin/python/shortfin_apps/llm/_deps.py b/shortfin/python/shortfin_apps/llm/_deps.py index 7123d011e..fb8ca8176 100644 --- a/shortfin/python/shortfin_apps/llm/_deps.py +++ b/shortfin/python/shortfin_apps/llm/_deps.py @@ -5,13 +5,23 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from shortfin.support.deps import ShortfinDepNotFoundError +import sys -try: - import tokenizers -except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "tokenizers") from e +deps = [ + "tokenizers", + "dataclasses_json", +] -try: - import dataclasses_json -except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e +for dep in deps: + try: + __import__(dep) + except ModuleNotFoundError as e: + if "pytest" in sys.modules: + import pytest + + pytest.skip( + f"A test imports shortfin_apps.llm; skipping due to unavailable Shortfin LLM dependency: {dep}", + allow_module_level=True, + ) + else: + raise ShortfinDepNotFoundError(__name__, dep) 
from e diff --git a/shortfin/python/shortfin_apps/llm/components/cache.py b/shortfin/python/shortfin_apps/llm/components/cache.py deleted file mode 100644 index 12794498f..000000000 --- a/shortfin/python/shortfin_apps/llm/components/cache.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Sequence - -import logging -import math -import threading - -import shortfin as sf - -from .config_struct import ModelParams, human_size - -logger = logging.getLogger(__name__) - - -class AttnPageEntry: - __slots__ = [ - "cache", - "index", - "in_use", - ] - - def __init__(self, cache: "AttnPageCache", index: int): - self.cache = cache - self.index = index - self.in_use = False - - def __repr__(self): - return f"Block({self.index}, {'FREE' if not self.in_use else 'BUSY'})" - - -class AttnPageCache: - """Page table based attention cache. - - While internal to a model, the cache is organized with additional structure - per page, outside of the model, it is just a list of pages of a certain - element type and number of elements (all inner dims are flattened). - - One page table is allocated per device in a fiber. Currently, this is a - dense allocation with committed memory but in the future, we may just - allocate the address space and lazily populate it with committed memory. - - The cache is unique because usage of it can span fibers and concurrency - is implicitly managed at the block level (i.e. freshly acquired blocks - are assumed to be uninitialized and available immediately for use). - - It is initialized with a discrete list of fiberd devices from a fiber but - cache usage can be done from any fiber which includes those devices. - """ - - def __init__( - self, *, devices: Sequence[sf.ScopedDevice], model_params: ModelParams - ): - self._lock = threading.Lock() - self.devices = list(devices) - self.model_params = model_params - self.page_tables: list[sf.array.device_array] = [] - cache_params = model_params.paged_kv_cache - alloc_page_count = cache_params.device_block_count - - # Setup accounting structs. - self.attn_page_entries = [ - AttnPageEntry(self, i) for i in range(alloc_page_count) - ] - self.attn_page_free = list(self.attn_page_entries) - - # Initialize a page table on each device. - assert cache_params is not None, "Model does not have a paged kv cache" - page_table_shape = [ - alloc_page_count, - model_params.paged_kv_block_size_elements, - ] - for device in devices: - logging.info( - "Allocating page table (shape=%r, dtype=%r, size=%s) on %r", - page_table_shape, - model_params.attn_dtype, - human_size( - math.prod(page_table_shape) - * model_params.attn_dtype.dense_byte_count - ), - device, - ) - page_table = sf.array.device_array.for_device( - device, page_table_shape, model_params.attn_dtype - ) - self.page_tables.append(page_table) - - def acquire_free_pages(self, count: int) -> list[AttnPageEntry] | None: - with self._lock: - available = len(self.attn_page_free) - if count > available: - return None - return [self.attn_page_free.pop() for _ in range(count)] - - def release_pages(self, pages: list[AttnPageEntry]): - with self._lock: - self.attn_page_free.extend(pages) - - def __repr__(self): - # No need to lock for repr (list is internally synchronized). 
- free_pages = len(self.attn_page_free) - total_pages = len(self.attn_page_entries) - return ( - f"AttnPageCache({total_pages - free_pages}/{total_pages} pages in use: " - f"{100.0 * free_pages / total_pages}% free)" - ) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py new file mode 100644 index 000000000..0007000bc --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -0,0 +1,80 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Base class for kv caches. +""" + +from typing import List +from .page_pool import PageInfo +import math + + +class BasePagedAttentionCache: + """ + Manages lifecycle of pages (using PageInfo as handles). + + + Page States: + Caching - Page can be read by multiple threads + - Also maintains a reference count + Writing - Page is being modified by a single owner thread + + Transitions: + Caching -> Writing: When acquiring an unreferenced LRU leaf page for writing + Writing -> Caching: When writing is complete and page is released + + Thread Safety: + - Multiple readers allowed in ReadableCaching state + - Single writer exclusive access in Writing state + - Reference counting prevents eviction of in-use pages + """ + + def __init__(self, page_pool, tokens_per_page): + self.page_pool = page_pool + self.tokens_per_page = tokens_per_page + + def acquire_pages_for_tokens( + self, tokens: List[int], extra_token_slots: int = 1 + ) -> tuple[list[PageInfo], int]: + """ + Given a list of tokens, return a list of pages and a start position to continue generation from. + + Parameters: + - tokens: all the known tokens for this generation request + - extra_token_slots: number of kvcache slots needed in addition to the ones needed to hold the given tokens. + + In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable. + + The pages are returned in order. + + No token at idx < n_cached_token should be written to. TODO: consider enforcing this. + """ + token_count = len(tokens) + pages_needed = math.ceil(token_count / self.tokens_per_page) + pages = self.page_pool.acquire_free_pages(pages_needed) + + n_cached_tokens = 0 + + return pages, n_cached_tokens + + def publish_pages(self, tokens, pages) -> None: + """ + Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests. + + Associates the tokens with the pages, and mark them as done writing. + + It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. + """ + + pass # the base implementation doesn't cache unfinished requests. + + def release_pages(self, tokens, pages): + """ + Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction. 
+ """ + # in the base implementation, the pages can be owned by 1 request max, so they can be instantly release + self.page_pool.release_pages(pages) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py new file mode 100644 index 000000000..1686370c0 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -0,0 +1,159 @@ +from __future__ import annotations +from typing import List, Tuple, Optional, Sequence +import threading +import logging +import shortfin as sf +import shortfin.array as sfnp +from dataclasses import dataclass + +from ..config_struct import human_size +import math + +import time + +logger = logging.getLogger(__name__) + + +@dataclass +class PageInfo: + """ + Page index with some metadata about its contents. + """ + + index: int + pool: PagePool + token_offset: int # Offset within the page + token_count: int # Number of tokens stored in this page + writing: bool = False + read_ref_count: int = 0 # Number of threads that still need to read this page. When this reaches 0, page is eligible for release + + +@dataclass +class PagePoolConfig: + """ + Hyperparameters for the page pool. + """ + + dtype: sf.dtype + alloc_page_count: int + + paged_kv_block_size_elements: int # size of a single page as # of elements + # (e.g. one configuration for llama3.1 8b hax 32x2x16x8x128=1048576 elements where: + # 32: number of transformer blocks + # 2: one for k + one for v + # 16: tokens per page + # 8: head count (32 heads, but every 4 heads share the same kv buffer) + # 128: hidden dimension + + +class PagePool: + """Page table based attention cache. + + While internal to a model, the cache is organized with additional structure + per page, outside of the model, it is just a list of pages of a certain + element type and number of elements (all inner dims are flattened). + + One page table is allocated per device in a fiber. Currently, this is a + dense allocation with committed memory but in the future, we may just + allocate the address space and lazily populate it with committed memory. + + The cache is unique because usage of it can span fibers and concurrency + is implicitly managed at the block level (i.e. freshly acquired blocks + are assumed to be uninitialized and available immediately for use). + + It is initialized with a discrete list of fiberd devices from a fiber but + cache usage can be done from any fiber which includes those devices. + + In addition to supporting paged attention standalone, this also serves + as the array / buffer allocation layer for radix attention described in + `radix_tree.py`. + """ + + def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig): + self._lock = threading.Lock() + self.devices = list(devices) + self.config = config + self.page_tables: list[sf.array.device_array] = [] + + # Setup accounting structs. + self.attn_page_entries = [ + PageInfo( + index=i, + pool=self, + token_offset=0, + token_count=0, + ) + for i in range(self.config.alloc_page_count) + ] + + self.attn_page_free = list(self.attn_page_entries) + + # Initialize a page table on each device. 
+ page_table_shape = [ + self.config.alloc_page_count, + self.config.paged_kv_block_size_elements, + ] + for device in devices: + logging.info( + "Allocating page table (shape=%r, dtype=%r, size=%s) on %r", + page_table_shape, + self.config.dtype, + human_size(config.dtype.compute_dense_nd_size(page_table_shape)), + device, + ) + page_table = sf.array.device_array.for_device( + device, page_table_shape, self.config.dtype + ) + self.page_tables.append(page_table) + + def acquire_free_pages(self, count: int) -> list[PageInfo] | None: + with self._lock: + available = len(self.attn_page_free) + if count > available: + return None + return [self.attn_page_free.pop() for _ in range(count)] + + def release_pages(self, pages: list[PageInfo]): + with self._lock: + self.attn_page_free.extend(pages) + + def copy_page(self, src_page: PageInfo) -> PageInfo: + """ + Copy a page's contents to a new page. + + Args: + src_page: Source page to copy from + token_count: Optional number of tokens to copy. If None, copies all tokens. + + Returns: + New PageInfo containing the copied data + """ + # Allocate new page + (dst_page,) = self.acquire_free_pages(1) + + # fill src page with data + + # Copy the data on each device + for page_table in self.page_tables: + # View of source and destination pages + src_view = page_table.view(src_page.index) + dst_view = page_table.view(dst_page.index) + # Copy the data + dst_view.copy_from(src_view) + + # Setup destination page metadata + dst_page.token_offset = 0 # Always start at beginning of new page + + return dst_page + + def __repr__(self): + # No need to lock for repr (list is internally synchronized). + free_pages = len(self.attn_page_free) + total_pages = len(self.attn_page_entries) + return ( + f"PagePool({total_pages - free_pages}/{total_pages} pages in use: " + f"{100.0 * free_pages / total_pages}% free)" + ) + + +############################## begin radix attention diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index fdcbeefc1..c3e6fe34b 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -9,7 +9,8 @@ import shortfin as sf import shortfin.array as sfnp -from .cache import AttnPageCache, AttnPageEntry +from .kvcache.base_attention_cache import BasePagedAttentionCache +from .kvcache.page_pool import PageInfo class InferencePhase(Enum): @@ -41,8 +42,8 @@ def __init__(self, phase: InferencePhase, input_token_ids: list[int]): self.result_logits: sfnp.device_array | None = None # Cache pages that have been locked for this request. 
- self._cache: AttnPageCache | None = None - self.locked_pages: list[AttnPageEntry] | None = None + self._cache: BasePagedAttentionCache | None = None + self.locked_pages: list[PageInfo] | None = None def reset(self, phase: InferencePhase): """Resets all per request state in preparation for an subsequent execution.""" @@ -66,16 +67,18 @@ def free_cache_pages(self): pages = self.locked_pages self._cache = None self.locked_pages = None - cache.release_pages(pages) + cache.release_pages(self.input_token_ids, pages) def lock_initial_cache_pages( - self, cache: AttnPageCache, pages: list[AttnPageEntry] + self, cache: BasePagedAttentionCache, pages: list[PageInfo] ): assert not self._cache self._cache = cache self.locked_pages = pages - def lock_new_cache_pages(self, cache: AttnPageCache, pages: list[AttnPageEntry]): + def lock_new_cache_pages( + self, cache: BasePagedAttentionCache, pages: list[PageInfo] + ): assert self._cache is cache self.locked_pages.extend(pages) diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index bcd08b756..8d3cc1424 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -11,7 +11,8 @@ import shortfin as sf import shortfin.array as sfnp -from .cache import AttnPageCache +from .kvcache.base_attention_cache import BasePagedAttentionCache +from .kvcache.page_pool import PagePoolConfig, PagePool from .config_struct import ModelParams from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage @@ -54,8 +55,17 @@ def __init__( # Scope dependent objects. self.batcher = BatcherProcess(self) - self.page_cache = AttnPageCache( - devices=self.main_fiber.devices_dict.values(), model_params=model_params + page_pool_config = PagePoolConfig( + dtype=model_params.attn_dtype, + alloc_page_count=model_params.paged_kv_cache.device_block_count, + paged_kv_block_size_elements=model_params.paged_kv_block_size_elements, + ) + page_pool = PagePool( + devices=self.main_fiber.devices_dict.values(), config=page_pool_config + ) + self.page_cache = BasePagedAttentionCache( + page_pool=page_pool, + tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) self.program_isolation = PROG_ISOLATIONS[program_isolation] @@ -200,7 +210,7 @@ def board_flights(self): self.pending_prefills.clear() logger.debug("Post boarding cache state: %r", cache) - def board_prefills(self, cache: AttnPageCache): + def board_prefills(self, cache: BasePagedAttentionCache): # Fill prefill flights. 
pending_prefills = self.pending_prefills if len(pending_prefills) == 0: @@ -209,7 +219,7 @@ def board_prefills(self, cache: AttnPageCache): self.service, InferencePhase.PREFILL, self.page_seq_stride, - cache.page_tables, + cache.page_pool.page_tables, ) for prefill_request in pending_prefills: assert prefill_request.phase == InferencePhase.PREFILL @@ -218,7 +228,11 @@ def board_prefills(self, cache: AttnPageCache): needed_pages = math.ceil( len(prefill_request.input_token_ids) / self.page_seq_stride ) - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + prefill_request.input_token_ids, + extra_token_slots=0, # prefill needs no extra kvcache slots to write to + ) if pages is None: logger.debug("Cannot fulfill request for %d pages", needed_pages) continue @@ -236,13 +250,16 @@ def board_prefills(self, cache: AttnPageCache): # And takeoff. exec_process.launch() - def board_decodes(self, cache: AttnPageCache): + def board_decodes(self, cache: BasePagedAttentionCache): # Fill decode flights. pending_decodes = self.pending_decodes if len(pending_decodes) == 0: return exec_process = InferenceExecutorProcess( - self.service, InferencePhase.DECODE, self.page_seq_stride, cache.page_tables + self.service, + InferencePhase.DECODE, + self.page_seq_stride, + cache.page_pool.page_tables, ) for decode_request in pending_decodes: assert decode_request.phase == InferencePhase.DECODE @@ -254,7 +271,11 @@ def board_decodes(self, cache: AttnPageCache): / self.page_seq_stride ) if needed_pages > len(decode_request.locked_pages): - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + decode_request.input_token_ids, + extra_token_slots=1, # need 1 extra slot to write result. + ) if pages is None: logger.debug( "Cannot fulfill decode request for %d pages", needed_pages diff --git a/shortfin/tests/apps/llm/components/cache_test.py b/shortfin/tests/apps/llm/components/cache_test.py deleted file mode 100644 index 169d082b1..000000000 --- a/shortfin/tests/apps/llm/components/cache_test.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -""" -Tests for llm kvcache component. -""" - -import pytest -import time -import tempfile -import shortfin as sf -from _shortfin import lib as sfl -from shortfin_apps.llm.components import cache -from shortfin_apps.llm.components import config_struct -import json -from pathlib import Path - - -@pytest.fixture -def lsys(): - sc = sfl.local.host.CPUSystemBuilder() - ls = sc.create_system() - yield ls - ls.shutdown() - - -@pytest.fixture -def fiber(lsys): - # TODO: Should adopt the main thread. 
- worker = lsys.create_worker("main") - return lsys.create_fiber(worker) - - -@pytest.fixture -def device(fiber): - return fiber.device(0) - - -@pytest.fixture -def model_params(): - model_params = { - "module_name": "module", - "module_abi_version": 1, - "max_seq_len": 2048, - "attn_head_count": 32, - "attn_head_dim": 100, - "prefill_batch_sizes": [4], - "decode_batch_sizes": [4], - "transformer_block_count": 26, - "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, - } - - # Create a temporary file to store the JSON - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as tmp_file: - json.dump(model_params, tmp_file, indent=4) - tmp_path = Path(tmp_file.name) - - try: - # Load the JSON using config_struct - model_params = config_struct.ModelParams.load_json(tmp_path) - yield model_params - finally: - tmp_path.unlink - - -@pytest.fixture -def cache_fixture(fiber, model_params) -> cache.AttnPageCache: - # Create and return the cache object - return cache.AttnPageCache( - devices=fiber.devices_dict.values(), model_params=model_params - ) - - -@pytest.mark.parametrize("n_allocated", [1, 16, 255]) -def test_alloc( - cache_fixture: cache.AttnPageCache, - n_allocated, - model_params: config_struct.ModelParams, -): - alloc_page_count = cache_fixture.page_tables[0].shape[0] - - assert alloc_page_count == model_params.paged_kv_cache.device_block_count - - pages = cache_fixture.acquire_free_pages(n_allocated) - last_page = alloc_page_count - 1 - expected_indices = range(last_page, last_page - n_allocated, -1) - for p, expected_ix in zip(pages, expected_indices): - assert p.index == expected_ix - assert p.index > 0 diff --git a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py new file mode 100644 index 000000000..a1ec00c07 --- /dev/null +++ b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py @@ -0,0 +1,57 @@ +import pytest +import logging +from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PagePoolConfig +import shortfin as sf +import shortfin.host +import shortfin.array as sfnp +import shortfin.amdgpu + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def setup_pool(generic_device): + pool = PagePool( + devices=[generic_device], + config=PagePoolConfig( + alloc_page_count=256, + dtype=sfnp.float16, + paged_kv_block_size_elements=393216, + ), + ) + return pool + + +def test_page_acquisition(setup_pool): + pool = setup_pool + logger.info(f"=== Running page acquisition test on system ===") + page0 = pool.acquire_free_pages(1) + assert page0 is not None, f"Failed to acquire a free page on system" + logger.info(f"Successfully acquired page on system") + + +def test_page_copy(setup_pool): + pool = setup_pool + logger.info(f"=== Running page copy test on system ===") + (page0,) = pool.acquire_free_pages(1) + page1 = pool.copy_page(page0) + assert page1 is not None, f"Failed to copy a page on system" + assert page0 != page1, f"Copied page should be different from original on system" + logger.info(f"Successfully copied page on system") + + +@pytest.fixture(autouse=True) +def setup_logging(): + """Set up logging format to include timestamp and level""" + logging.basicConfig( + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, + force=True, + ) + + +# Add more tests as needed + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/shortfin/tests/conftest.py b/shortfin/tests/conftest.py 
index 083698968..b16d5a3c9 100644 --- a/shortfin/tests/conftest.py +++ b/shortfin/tests/conftest.py @@ -50,6 +50,17 @@ def pytest_runtest_setup(item): sf.SystemBuilder.default_system_type = system_type +# Dynamic Parameterization for lsys Fixture +def pytest_generate_tests(metafunc): + if "generic_lsys" in metafunc.fixturenames: + system = metafunc.config.getoption("--system") + if system == "amdgpu": + params = ["cpu", "amdgpu"] + else: + params = [system] + metafunc.parametrize("generic_lsys", params, indirect=True) + + # Keys that will be cleaned project wide prior to and after each test run. # Test code can freely modify these. CLEAN_ENV_KEYS = [ @@ -96,6 +107,28 @@ def kill(): kill() +@pytest.fixture(scope="session") +def generic_lsys(request): + system_type = request.param + if system_type == "cpu" or system_type == "hostcpu": + sc = sf.host.CPUSystemBuilder() + elif system_type == "amdgpu": + sc = sf.amdgpu.SystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def generic_fiber(generic_lsys): + return generic_lsys.create_fiber() + + +@pytest.fixture +def generic_device(generic_fiber): + return generic_fiber.device(0) + + @pytest.fixture def cpu_lsys(): sc = sf.host.CPUSystemBuilder() From 6bb24a3e475793edcb58d51f2bc5692ba8dd2259 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Tue, 26 Nov 2024 10:34:36 -0500 Subject: [PATCH 16/25] [tuner]: use iree_gpu.MMAIntrinsic and iree_gpu.MMAAttr (#605) Remove the data class `MfmaIntrinsic` from the codebase, and use IREE attributes (` iree_gpu.MMAIntrinsic` and `iree_gpu.MMAAttr` ) for MFMA intrinsics in the tuner. **Motivation for this PR**: The original MLIR processing relies heavily on string-based operations, making it fragile and prone to breaking with updates to the IREE Compiler. To address this, we aim to leverage key attributes directly through IREE Python bindings, enabled by exposing these attributes. For more details, refer to [this issue](https://github.com/nod-ai/shark-ai/issues/453). 
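For illustration, here is a rough sketch of the new attribute-based queries (not code from this patch; it assumes the `iree.compiler` Python bindings with an active MLIR context, and uses the `MMAAttr.get`, `abc_element_types`, and `mnk_shape` APIs that the changes below rely on):

```python
# Sketch only: querying intrinsic metadata straight from the IREE attributes
# that replace the removed MfmaIntrinsic dataclass. Assumes the iree.compiler
# bindings register the iree_gpu dialect in the context they create.
from iree.compiler import ir
from iree.compiler.dialects import iree_gpu

with ir.Context():
    mma_attr = iree_gpu.MMAAttr.get(iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16)
    a_type, b_type, c_type = mma_attr.abc_element_types  # lhs / rhs / result types
    m, n, k = mma_attr.mnk_shape
    print(mma_attr, (m, n, k), a_type, b_type, c_type)
```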
--- tuner/tuner/candidate_gen.py | 13 +++--- tuner/tuner/candidate_gen_test.py | 29 +++++++++--- tuner/tuner/common.py | 72 +++++------------------------ tuner/tuner/common_test.py | 25 +++++----- tuner/tuner/dispatch_constraints.py | 44 ++++++++++++++++-- tuner/tuner/dispatch_parser_test.py | 15 ++++-- 6 files changed, 102 insertions(+), 96 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 38696e6db..f09e08888 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -52,7 +52,7 @@ def apply_configuration( expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") - repl0 = f", subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>" + repl0 = f"" repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" @@ -119,7 +119,6 @@ def get_transform_function_mmt( wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) - return f""" transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op @@ -132,7 +131,7 @@ def get_transform_function_mmt( translation_info = #iree_codegen.translation_info, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -205,7 +204,7 @@ def get_transform_function_conv( translation_info = #iree_codegen.translation_info, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -266,7 +265,7 @@ def get_transform_function_broadcast_rhs_mmt( translation_info = #iree_codegen.translation_info, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -346,7 +345,7 @@ def get_transform_function_batch_mmt( translation_info = #iree_codegen.translation_info, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -414,7 +413,7 @@ def get_transform_function_batch_matmul( translation_info = #iree_codegen.translation_info, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 36fb87cbb..d81278e8c 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -13,6 +13,7 @@ from typing import Generator from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore from . import candidate_gen from . 
import common @@ -45,10 +46,12 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: M, N, K = 2048, 1280, 1280 + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=16, workgroup_size=[16, 16, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[8, 8, 8], subgroup_m_count=16, subgroup_n_count=16, @@ -97,10 +100,12 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, @@ -161,10 +166,12 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.contraction, ) + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=mma_attr, tile_sizes=[480, 384, 32], subgroup_m_count=1, subgroup_n_count=4, @@ -208,10 +215,12 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_matmul, ) + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=mma_attr, tile_sizes=[416, 320, 128], subgroup_m_count=2, subgroup_n_count=2, @@ -258,10 +267,12 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_mmt, ) + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, @@ -306,10 +317,12 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_mmt, ) + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=mma_attr, tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, @@ -377,10 +390,12 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.broadcast_rhs_mmt, ) + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=mma_attr, tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index b6e31768e..45ae48c22 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -85,74 +85,24 @@ def MNK(self) -> tuple[int, int, int]: return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) 
-@dataclass -class MfmaIntrinsic: - output_type: ir.IntegerType | ir.FloatType - m: int - n: int - k: int - input_type: ir.IntegerType | ir.FloatType - - def __str__(self) -> str: - input = str(self.input_type).upper() - output = str(self.output_type).upper() - return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}" - - @staticmethod - def mfma_f32_16x16x16_f16(): - f16 = ir.F16Type.get() - f32 = ir.F32Type.get() - return MfmaIntrinsic(f32, 16, 16, 16, f16) - - @staticmethod - def mfma_f32_32x32x8_f16(): - f16 = ir.F16Type.get() - f32 = ir.F32Type.get() - return MfmaIntrinsic(f32, 32, 32, 8, f16) - - @staticmethod - def mfma_i32_16x16x32_i8(): - i32 = ir.IntegerType.get_signless(32) - i8 = ir.IntegerType.get_signless(8) - return MfmaIntrinsic(i32, 16, 16, 32, i8) - - @staticmethod - def mfma_i32_32x32x16_i8(): - i32 = ir.IntegerType.get_signless(32) - i8 = ir.IntegerType.get_signless(8) - return MfmaIntrinsic(i32, 32, 32, 16, i8) - - @staticmethod - def all(): - return [ - MfmaIntrinsic.mfma_f32_16x16x16_f16(), - MfmaIntrinsic.mfma_f32_32x32x8_f16(), - MfmaIntrinsic.mfma_i32_16x16x32_i8(), - MfmaIntrinsic.mfma_i32_32x32x16_i8(), - ] - - def get_compatible_mfma_intrinsics( problem_size: ProblemSize, mma_intrinsics: list[iree_gpu.MMAIntrinsic], -) -> list[MfmaIntrinsic]: - available_mma_intrinsics = [str(mma) for mma in mma_intrinsics] - - def is_compatible(intrinsic: MfmaIntrinsic) -> bool: - if problem_size.res_type.element_type != intrinsic.output_type: +) -> list[iree_gpu.MMAIntrinsic]: + def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: + mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma + a_type, b_type, c_type = mma_attr.abc_element_types + if problem_size.res_type.element_type != c_type: return False if problem_size.dispatch_kind != DispatchKind.batch_matmul: - if problem_size.lhs_type.element_type != intrinsic.input_type: - return False - if problem_size.rhs_type.element_type != intrinsic.input_type: + if ( + problem_size.lhs_type.element_type != a_type + or problem_size.rhs_type.element_type != b_type + ): return False - - if str(intrinsic) not in available_mma_intrinsics: - return False - return True - return list(filter(is_compatible, MfmaIntrinsic.all())) + return list(filter(is_comptible, mma_intrinsics)) class ReorderWorkgroupsStrategy(Enum): @@ -197,7 +147,7 @@ def __str__(self) -> str: class Configuration: subgroup_size: int workgroup_size: list[int] - intrinsic: MfmaIntrinsic + intrinsic: iree_gpu.MMAAttr tile_sizes: list[int] subgroup_m_count: int subgroup_n_count: int diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 297ac95a2..ea0a4573d 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception """ -Usage: python -m pytest candidate_gen_test.py +Usage: python -m pytest common_test.py """ import pytest @@ -72,10 +72,12 @@ def test_gpu_pipeline_options() -> None: def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[4, 8, 16], subgroup_m_count=1, subgroup_n_count=1, @@ -97,11 +99,6 @@ def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: ) -def test_mfma_intrinsic_to_str(mlir_ctx: ir.Context) -> None: - assert 
str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16" - assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8" - - def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: assert common.get_compatible_mfma_intrinsics( common.ProblemSize( @@ -116,8 +113,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ], ) == [ - common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ] assert common.get_compatible_mfma_intrinsics( @@ -133,8 +130,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, ], ) == [ - common.MfmaIntrinsic.mfma_i32_16x16x32_i8(), - common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, ] assert common.get_compatible_mfma_intrinsics( @@ -150,8 +147,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ], ) == [ - common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ] assert common.get_compatible_mfma_intrinsics( @@ -166,7 +163,7 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ], ) == [ - common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ] assert ( diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 85039a1e8..f16b4a241 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -25,10 +25,18 @@ def get_mfma_intrinsic_constraints( ) -> z3.BoolRef: compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics) assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" + + mma_attrs = [iree_gpu.MMAAttr.get(mfma) for mfma in compatible_intrinsics] + mnk_shapes = [mma_attr.mnk_shape for mma_attr in mma_attrs] + return z3.Or( *( - z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) - for mfma in compatible_intrinsics + z3.And( + intrinsic_m == m, + intrinsic_n == n, + intrinsic_k == k, + ) + for m, n, k in mnk_shapes ) ) @@ -134,6 +142,35 @@ def generate_constraints( return constraints +def getMMAAttr( + output_type: ir.IntegerType | ir.FloatType, + m: int, + n: int, + k: int, + lhs_type: ir.IntegerType | ir.FloatType, + rhs_type: ir.IntegerType | ir.FloatType, +) -> iree_gpu.MMAAttr: + for mma_intrinsic in iree_gpu.MMAIntrinsic: + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + a_type, b_type, c_type = mma_attr.abc_element_types + mnk = mma_attr.mnk_shape + if ( + a_type == lhs_type + and b_type == rhs_type + and c_type == output_type + and m == mnk[0] + and n == mnk[1] + and k == mnk[2] + ): + return mma_attr + # If no matching intrinsic is found, raise an exception + raise ValueError( + f"No matching MMA intrinsic found for " + f"output_type={output_type}, lhs_type={lhs_type}, rhs_type={rhs_type}, " + f"m={m}, n={n}, k={k}." 
+    )
+
+
 def generate_solutions(
     logger: logging.Logger,
     problem_size: ProblemSize,
@@ -188,12 +225,13 @@ def generate_solutions(
         config = Configuration(
             lookup(subgroup_size),
             [lookup(wg_x), lookup(wg_y), lookup(wg_z)],
-            MfmaIntrinsic(
+            getMMAAttr(
                 problem_size.res_type.element_type,
                 lookup(intrinsic_mn),
                 lookup(intrinsic_mn),
                 lookup(intrinsic_k),
                 problem_size.lhs_type.element_type,
+                problem_size.rhs_type.element_type,
             ),
             [lookup(m), lookup(n), lookup(k)],
             lookup(sg_m_cnt),
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
index d3a99806f..fb10b04bc 100644
--- a/tuner/tuner/dispatch_parser_test.py
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -5,7 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 """
-Usage: python -m pytest candidate_gen_test.py
+Usage: python -m pytest dispatch_parser_test.py
 """

 import pytest
@@ -14,6 +14,7 @@

 from iree.compiler import ir  # type: ignore
 from iree.compiler.dialects import func  # type: ignore
+from iree.compiler.dialects import iree_gpu  # type: ignore

 from . import common
 from . import dispatch_parser
@@ -39,10 +40,12 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:


 def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = dispatch_parser.Configuration(
         subgroup_size=0,
         workgroup_size=[],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[128, 320, 32],
         subgroup_m_count=0,
         subgroup_n_count=0,
@@ -53,10 +56,12 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:


 def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = dispatch_parser.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[464, 320, 16],
         subgroup_m_count=1,
         subgroup_n_count=4,
@@ -75,10 +80,12 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:


 def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None:
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = dispatch_parser.Configuration(
         subgroup_size=32,
         workgroup_size=[16, 16, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[4, 8, 16],
         subgroup_m_count=1,
         subgroup_n_count=1,

From 26bf8ce55c9fc3c03597a12ec177e647aceab318 Mon Sep 17 00:00:00 2001
From: Stephen Baione <109226581+stbaione@users.noreply.github.com>
Date: Tue, 26 Nov 2024 11:52:58 -0600
Subject: [PATCH 17/25] Fix logging for shortfin LLM server (#612)

# Description

While troubleshooting tests, I realized that we weren't getting any
logging output from the server, which we used to get: the shortfin LLM
server wasn't logging anything at all. Copied the UVICORN_LOG_CONFIG
over from sdxl, which appears to fix the problem.
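As a minimal illustration (not part of this patch), a dict-based log config like the one added below is handed straight to uvicorn; the app, host, and port here are placeholders rather than values from the server code:

```python
# Sketch only: a console handler wired to the "uvicorn" logger restores server
# output; the previous empty handlers/loggers mapping left it silent.
import uvicorn
from fastapi import FastAPI

app = FastAPI()

LOG_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "handlers": {"console": {"class": "logging.StreamHandler"}},
    "loggers": {"uvicorn": {"handlers": ["console"], "level": "INFO"}},
}

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000, log_config=LOG_CONFIG)
```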
---
 shortfin/python/shortfin_apps/llm/server.py | 35 ++++++++++++++++-----
 1 file changed, 28 insertions(+), 7 deletions(-)

diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py
index 2ab7a1b96..1561803dd 100644
--- a/shortfin/python/shortfin_apps/llm/server.py
+++ b/shortfin/python/shortfin_apps/llm/server.py
@@ -33,6 +33,33 @@
 logger = logging.getLogger(__name__)

+UVICORN_LOG_CONFIG = {
+    "version": 1,
+    "disable_existing_loggers": False,
+    "formatters": {
+        "default": {
+            "()": "uvicorn.logging.DefaultFormatter",
+            "format": "[{asctime}] {message}",
+            "datefmt": "%Y-%m-%d %H:%M:%S",
+            "style": "{",
+            "use_colors": True,
+        },
+    },
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "formatter": "default",
+        },
+    },
+    "loggers": {
+        "uvicorn": {
+            "handlers": ["console"],
+            "level": "INFO",
+            "propagate": False,
+        },
+    },
+}
+

 @asynccontextmanager
 async def lifespan(app: FastAPI):
@@ -211,11 +238,5 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
     main(
         sys.argv[1:],
         # Make logging defer to the default shortfin logging config.
-        log_config={
-            "version": 1,
-            "disable_existing_loggers": False,
-            "formatters": {},
-            "handlers": {},
-            "loggers": {},
-        },
+        log_config=UVICORN_LOG_CONFIG,
     )

From 10cd58f25171998f59943c465e129d6afff0733c Mon Sep 17 00:00:00 2001
From: Boian Petkantchin
Date: Tue, 26 Nov 2024 16:22:23 -0800
Subject: [PATCH 18/25] Add T5 encoder bfloat16 support (#614)

Adds various tests for verifying bfloat16 execution.
The bfloat16 eager execution matches Flux's output. One test verifies
this against golden values from the Flux pipeline.

I would say the eager numerical error is still quite high. My initial
idea was to compare the numerical error of IREE bfloat16 against eager
float32; its error profile should match the one we get when comparing
eager bfloat16 against eager float32. The problem is that eager
bfloat16 itself has a high error that needs further investigation.
Because of this, some tests are marked as xfail, since we don't have a
good metric to evaluate the IREE results.
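For illustration only, a rough sketch of the kind of error-profile comparison described above; the helper, tensors, and tolerance factor are made up rather than taken from the actual tests:

```python
# Sketch only: compare IREE bf16 error vs. eager f32 against eager bf16 error
# vs. eager f32, using an illustrative mean relative error metric.
import torch


def mean_relative_error(actual: torch.Tensor, reference: torch.Tensor) -> float:
    # Promote to float32 so bfloat16 results can be compared against the reference.
    actual = actual.to(torch.float32)
    reference = reference.to(torch.float32)
    return ((actual - reference).abs() / (reference.abs() + 1e-6)).mean().item()


# Stand-ins for: eager float32 reference, eager bfloat16 run, IREE bfloat16 run.
reference_f32 = torch.randn(4, 16)
eager_bf16 = reference_f32.to(torch.bfloat16)
iree_bf16 = reference_f32.to(torch.bfloat16)

# The idea: the IREE bf16 error should be in the same ballpark as the eager bf16 error.
assert mean_relative_error(iree_bf16, reference_f32) <= 2.0 * mean_relative_error(
    eager_bf16, reference_f32
) + 1e-6
```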
--- sharktank/conftest.py | 41 +- sharktank/requirements.txt | 2 +- .../sharktank/layers/configs/llm_configs.py | 17 +- sharktank/sharktank/layers/token_embedding.py | 3 +- sharktank/sharktank/models/t5/export.py | 20 +- sharktank/sharktank/models/t5/t5.py | 8 +- .../sharktank/transforms/dataset/__init__.py | 1 + .../sharktank/transforms/dataset/dataset.py | 19 + .../sharktank/types/gguf_interop/base.py | 9 + sharktank/sharktank/types/tensors.py | 40 ++ sharktank/sharktank/utils/iree.py | 67 ++- sharktank/sharktank/utils/testing.py | 3 +- sharktank/tests/models/t5/t5_test.py | 499 +++++++++++++++--- 13 files changed, 615 insertions(+), 114 deletions(-) create mode 100644 sharktank/sharktank/transforms/dataset/dataset.py diff --git a/sharktank/conftest.py b/sharktank/conftest.py index 475f386be..ddd371198 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -153,16 +153,28 @@ def pytest_addoption(parser): # --outtype=f32 \ # t5-v1_1-small parser.addoption( - "--google-t5-v1-1-small-fp32-model-path", + "--google-t5-v1-1-small-f32-model-path", type=Path, - default="/data/t5/small/google__t5-v1_1-small_fp32.gguf", - help="Google T5 v1.1 small fp32 model path", + default="/data/t5/small/google__t5-v1_1-small_f32.gguf", + help="Google T5 v1.1 small float32 model path", ) parser.addoption( - "--google-t5-v1-1-xxl-fp32-model-path", + "--google-t5-v1-1-small-bf16-model-path", type=Path, - default="/data/t5/xxl/google__t5-v1_1-xxl_fp32.gguf", - help="Google T5 v1.1 XXL fp32 model path", + default="/data/t5/small/google__t5-v1_1-small_bf16.gguf", + help="Google T5 v1.1 small bfloat16 model path", + ) + parser.addoption( + "--google-t5-v1-1-xxl-f32-model-path", + type=Path, + default="/data/t5/xxl/google__t5-v1_1-xxl_f32.gguf", + help="Google T5 v1.1 XXL float32 model path", + ) + parser.addoption( + "--google-t5-v1-1-xxl-bf16-model-path", + type=Path, + default="/data/t5/xxl/google__t5-v1_1-xxl_bf16.gguf", + help="Google T5 v1.1 XXL bfloat16 model path", ) parser.addoption( @@ -288,15 +300,20 @@ def get_model_artifacts(request: FixtureRequest): model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option( request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model" ) - model_path["google__t5_v1_1_small_fp32_model_path"] = set_fixture_from_cli_option( + model_path["google__t5_v1_1_small_f32_model_path"] = set_fixture_from_cli_option( + request, + "--google-t5-v1-1-small-f32-model-path", + "google__t5_v1_1_small_f32_model", + ) + model_path["google__t5_v1_1_small_bf16_model_path"] = set_fixture_from_cli_option( request, - "--google-t5-v1-1-small-fp32-model-path", - "google__t5_v1_1_small_fp32_model", + "--google-t5-v1-1-small-bf16-model-path", + "google__t5_v1_1_small_bf16_model", ) - model_path["google__t5_v1_1_xxl_fp32_model_path"] = set_fixture_from_cli_option( + model_path["google__t5_v1_1_xxl_f32_model_path"] = set_fixture_from_cli_option( request, - "--google-t5-v1-1-xxl-fp32-model-path", - "google__t5_v1_1_xxl_fp32_model", + "--google-t5-v1-1-xxl-f32-model-path", + "google__t5_v1_1_xxl_f32_model", ) return model_path diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt index 6b533d977..26d89c59d 100644 --- a/sharktank/requirements.txt +++ b/sharktank/requirements.txt @@ -1,7 +1,7 @@ iree-turbine # Runtime deps. 
-gguf==0.6.0 +gguf==0.10.0 numpy<2.0 # Needed for newer gguf versions (TODO: remove when gguf package includes this) diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 35a2ee570..996a92152 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -227,6 +227,8 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs): == properties["t5.attention.layer_norm_rms_epsilon"] ) + all_kwargs = {"vocab_size": None, "feed_forward_proj": None} + gguf_to_config_names_map = { "t5.context_length": ["context_length"], "t5.embedding_length": ["d_model"], @@ -236,11 +238,9 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs): "t5.attention.key_length": ["d_kv"], "t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"], "t5.attention.relative_buckets_count": ["relative_attention_num_buckets"], - "t5.decoder_start_token_id": ["decoder_start_token_id"], "tokenizer.ggml.eos_token_id": ["eos_token_id"], "tokenizer.ggml.padding_token_id": ["pad_token_id"], } - all_kwargs = {"vocab_size": None, "feed_forward_proj": None} all_kwargs.update( { config_name: properties[gguf_name] @@ -248,6 +248,19 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs): for config_name in config_names } ) + + gguf_to_optional_config_names_map = { + "t5.decoder_start_token_id": ["decoder_start_token_id"], + } + all_kwargs.update( + { + config_name: properties[gguf_name] + for gguf_name, config_names in gguf_to_optional_config_names_map.items() + for config_name in config_names + if gguf_name in properties + } + ) + if "tokenizer.ggml.tokens" in properties: all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"]) all_kwargs.update(kwargs) diff --git a/sharktank/sharktank/layers/token_embedding.py b/sharktank/sharktank/layers/token_embedding.py index 32e7fec8f..e5e06c0ef 100644 --- a/sharktank/sharktank/layers/token_embedding.py +++ b/sharktank/sharktank/layers/token_embedding.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import torch +from typing import Optional from .. import ops from .base import Theta, ThetaLayer @@ -16,7 +17,7 @@ def __init__( theta: Theta, *, weight_name: str = "weight", - dtype: torch.dtype = torch.float32, + dtype: Optional[torch.dtype] = torch.float32, ): super().__init__(theta) self.weight = self.theta_tensor(weight_name) diff --git a/sharktank/sharktank/models/t5/export.py b/sharktank/sharktank/models/t5/export.py index 7bd5eef3d..8d5f75db2 100644 --- a/sharktank/sharktank/models/t5/export.py +++ b/sharktank/sharktank/models/t5/export.py @@ -4,12 +4,15 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Union +import functools +from typing import Optional, Union from pathlib import Path import torch +from copy import copy from .t5 import T5Config, T5Encoder from ...types import Dataset +from ...transforms.dataset import set_float_dtype from iree.turbine.aot import FxProgramsBuilder, export __all__ = [ @@ -91,7 +94,18 @@ def prune_decoder_parameters(dataset: Dataset): pass -def export_encoder_iree_parameters(model_path: str, output_path: str): - dataset = Dataset.load(model_path) +def export_encoder_iree_parameters( + model_path_or_dataset: str | Dataset, + output_path: str, + dtype: Optional[torch.dtype] = None, +): + if isinstance(model_path_or_dataset, Dataset): + dataset = copy(model_path_or_dataset) + else: + dataset = Dataset.load(model_path_or_dataset) + if dtype: + dataset.root_theta = dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=dtype) + ) prune_decoder_parameters(dataset) dataset.save(output_path) diff --git a/sharktank/sharktank/models/t5/t5.py b/sharktank/sharktank/models/t5/t5.py index 4ae9108d5..88472db1d 100644 --- a/sharktank/sharktank/models/t5/t5.py +++ b/sharktank/sharktank/models/t5/t5.py @@ -684,7 +684,9 @@ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None): self.add_module( "final_layer_norm", RMSNormLayer( - theta(f"{theta_prefix}.output_norm"), epsilon=config.layer_norm_epsilon + theta(f"{theta_prefix}.output_norm"), + epsilon=config.layer_norm_epsilon, + dtype=config.activation_dtype, ), ) @@ -1046,7 +1048,9 @@ def __init__(self, theta: Theta, config: T5Config): super().__init__() self.add_module( "token_embedding", - TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype), + TokenEmbeddingLayer( + theta("token_embd"), dtype=theta("token_embd").tensor("weight").dtype + ), ) encoder_config = copy.deepcopy(config) diff --git a/sharktank/sharktank/transforms/dataset/__init__.py b/sharktank/sharktank/transforms/dataset/__init__.py index b6a2a400a..e2a58ea5d 100644 --- a/sharktank/sharktank/transforms/dataset/__init__.py +++ b/sharktank/sharktank/transforms/dataset/__init__.py @@ -5,3 +5,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .sharding import * +from .dataset import * diff --git a/sharktank/sharktank/transforms/dataset/dataset.py b/sharktank/sharktank/transforms/dataset/dataset.py new file mode 100644 index 000000000..c600865e4 --- /dev/null +++ b/sharktank/sharktank/transforms/dataset/dataset.py @@ -0,0 +1,19 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch + +from ...types.tensors import InferenceTensor, PrimitiveTensor, DefaultPrimitiveTensor +from ... 
import ops + + +def set_float_dtype(tensor: InferenceTensor, dtype: torch.dtype) -> InferenceTensor: + if isinstance(tensor, PrimitiveTensor) and tensor.dtype.is_floating_point: + return DefaultPrimitiveTensor( + name=tensor.name, data=ops.to(tensor, dtype=dtype) + ) + + return tensor diff --git a/sharktank/sharktank/types/gguf_interop/base.py b/sharktank/sharktank/types/gguf_interop/base.py index 9a7dcf1ee..494607f97 100644 --- a/sharktank/sharktank/types/gguf_interop/base.py +++ b/sharktank/sharktank/types/gguf_interop/base.py @@ -118,6 +118,15 @@ def _wrap_tensor( name=name, data=_externalize_tensor(name, data, logical_shape) ) + if type_name == "BF16": + assert data.dtype == np.uint8 + return DefaultPrimitiveTensor( + name=name, + data=_externalize_tensor(name, data.view(np.int16), logical_shape).view( + dtype=torch.bfloat16 + ), + ) + quantized_type = _quantized_types.get(type_name) if quantized_type is not None: return quantized_type( diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index f870aa101..2c267ac49 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -41,6 +41,7 @@ "AnyTensor", "DefaultPrimitiveTensor", "dtype_to_serialized_name", + "dtype_to_serialized_short_name", "flatten_tensor_tree", "InferenceTensor", "MetaDataValueType", @@ -51,6 +52,7 @@ "register_quantized_layout", "ReplicatedTensor", "serialized_name_to_dtype", + "serialized_short_name_to_dtype", "ShardedTensor", "SplitPrimitiveTensor", "torch_tree_flatten", @@ -1286,6 +1288,15 @@ def dtype_to_serialized_name(dtype: torch.dtype) -> str: ) from e +def dtype_to_serialized_short_name(dtype: torch.dtype) -> str: + try: + return _DTYPE_TO_SHORT_NAME[dtype] + except KeyError as e: + raise KeyError( + f"Missing mapping for dtype {dtype}. Please add to the _SHORT_NAME_TO_DTYPE dict" + ) from e + + def serialized_name_to_dtype(dtype_name: str) -> torch.dtype: try: return _NAME_TO_DTYPE[dtype_name] @@ -1295,6 +1306,15 @@ def serialized_name_to_dtype(dtype_name: str) -> torch.dtype: ) from e +def serialized_short_name_to_dtype(dtype_name: str) -> torch.dtype: + try: + return _SHORT_NAME_TO_DTYPE[dtype_name] + except KeyError as e: + raise KeyError( + f"Missing mapping for dtype '{dtype_name}'. 
Please add to the _SHORT_NAME_TO_DTYPE dict" + ) from e + + _NAME_TO_DTYPE: dict[str, torch.dtype] = { "float32": torch.float32, "float64": torch.float64, @@ -1338,6 +1358,26 @@ def _maybe_dtype(*names: str): _DTYPE_TO_NAME: dict[torch.dtype, str] = {v: k for k, v in _NAME_TO_DTYPE.items()} +_SHORT_NAME_TO_DTYPE: dict[str, torch.dtype] = { + "f32": torch.float32, + "f64": torch.float64, + "c64": torch.complex64, + "c128": torch.complex128, + "f16": torch.float16, + "bf16": torch.bfloat16, + "ui8": torch.uint8, + "i8": torch.int8, + "i16": torch.int16, + "i32": torch.int32, + "i64": torch.int64, + "b": torch.bool, + "f8_e4m3fnuz": torch.float8_e4m3fnuz, +} + +_DTYPE_TO_SHORT_NAME: dict[torch.dtype, str] = { + v: k for k, v in _SHORT_NAME_TO_DTYPE.items() +} + AnyTensor = Union[torch.Tensor, InferenceTensor] ######################################################################################## diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py index d5976ec48..a9097cf06 100644 --- a/sharktank/sharktank/utils/iree.py +++ b/sharktank/sharktank/utils/iree.py @@ -71,6 +71,54 @@ def load_iree_module( return vm_module, vm_context, vm_instance +def promote_bfloat16_to_float32(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype == torch.bfloat16: + return tensor.to(dtype=torch.float32) + else: + return tensor + + +def device_array_to_host(device_array: iree.runtime.DeviceArray) -> torch.Tensor: + def reinterpret_hal_buffer_view_element_type( + buffer_view: iree.runtime.HalBufferView, + element_type: iree.runtime.HalElementType, + ) -> iree.runtime.HalBufferView: + return iree.runtime.HalBufferView( + buffer=buffer_view.get_buffer(), + shape=buffer_view.shape, + element_type=element_type, + ) + + def reinterpret_device_array_dtype( + device_array: iree.runtime.DeviceArray, dtype: np.dtype + ) -> iree.runtime.DeviceArray: + return iree.runtime.DeviceArray( + device=device_array._device, + buffer_view=reinterpret_hal_buffer_view_element_type( + device_array._buffer_view, + iree.runtime.array_interop.map_dtype_to_element_type(dtype), + ), + ) + + # Circumvent the lack of bfloat16 in numpy. + # TODO: This uses private fields _device and _buffer_view in iree.runtime.DeviceArray. + # Improve DeviceArray to provide a hatchet to allow for reinterpretation of + # element type of the underlying buffer. 
+ def bfloat16_device_array_to_torch( + device_array: iree.runtime.DeviceArray, + ) -> torch.Tensor: + device_array_as_int16 = reinterpret_device_array_dtype(device_array, np.int16) + torch_tensor_as_int16 = torch.tensor(device_array_as_int16.to_host()) + return torch_tensor_as_int16.view(dtype=torch.bfloat16) + + if device_array._buffer_view.element_type == int( + iree.runtime.HalElementType.BFLOAT_16 + ): + return bfloat16_device_array_to_torch(device_array) + else: + return torch.tensor(device_array.to_host()) + + def run_iree_module_function( module: iree.runtime.VmModule, vm_context: iree.runtime.VmContext, @@ -88,9 +136,13 @@ def run_iree_module_function( device=iree.runtime.get_device(driver, cache=False), vm_function=vm_function, ) + if trace_path_prefix is not None: for i, arg in enumerate(args): - np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg.to_host()) + np.save( + f"{trace_path_prefix}{function_name}_arg{i}.npy", + promote_bfloat16_to_float32(device_array_to_host(arg)).numpy(), + ) results = invoker(*args) if isinstance(results, iree.runtime.DeviceArray): results = (results,) @@ -99,10 +151,13 @@ def run_iree_module_function( for i, arg in enumerate(args): np.save( f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy", - arg.to_host(), + device_array_to_host(arg).numpy(), ) for i, arg in enumerate(results): - np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", arg.to_host()) + np.save( + f"{trace_path_prefix}{function_name}_result{i}.npy", + promote_bfloat16_to_float32(device_array_to_host(arg)).numpy(), + ) return results @@ -158,7 +213,7 @@ def call_torch_module_function( for i, arg in enumerate(flat_args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", - arg.to("cpu").numpy(), + promote_bfloat16_to_float32(arg.to("cpu")).numpy(), ) res = getattr(module, function_name)(**kwargs) if trace_path_prefix is not None: @@ -166,7 +221,7 @@ def call_torch_module_function( for i, arg in enumerate(flat_args): np.save( f"{trace_path_prefix}{function_name}_arg{i}.npy", - arg.to("cpu").numpy(), + promote_bfloat16_to_float32(arg.to("cpu")).numpy(), ) results = ( (res,) @@ -189,4 +244,4 @@ def call_torch_module_function( def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]: - return [torch.tensor(tensor.to_host()) for tensor in tensors] + return [device_array_to_host(tensor) for tensor in tensors] diff --git a/sharktank/sharktank/utils/testing.py b/sharktank/sharktank/utils/testing.py index 933bfd2b6..32acec8ac 100644 --- a/sharktank/sharktank/utils/testing.py +++ b/sharktank/sharktank/utils/testing.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Optional import contextlib from pathlib import Path import os @@ -20,7 +21,7 @@ # Range of torch.rand() is [0,1) # Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values -def make_rand_torch(shape, dtype=torch.float32): +def make_rand_torch(shape: list[int], dtype: Optional[torch.dtype] = torch.float32): return torch.rand(shape, dtype=dtype) * 2 - 1 diff --git a/sharktank/tests/models/t5/t5_test.py b/sharktank/tests/models/t5/t5_test.py index 076404e5d..1a696ba57 100644 --- a/sharktank/tests/models/t5/t5_test.py +++ b/sharktank/tests/models/t5/t5_test.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools from transformers.models.t5.modeling_t5 import ( T5Attention as ReferenceT5Attention, T5LayerSelfAttention as ReferenceT5LayerSelfAttention, @@ -14,13 +15,21 @@ T5EncoderModel as ReferenceT5EncoderModel, T5Config as ReferenceT5Config, ) +from typing import Optional import os from collections import OrderedDict import pytest import torch +from torch.utils._pytree import tree_map, tree_unflatten, tree_flatten_with_path from unittest import TestCase from parameterized import parameterized -from sharktank.types import Theta, DefaultPrimitiveTensor, unbox_tensor, Dataset +from sharktank.types import ( + Theta, + DefaultPrimitiveTensor, + unbox_tensor, + Dataset, + dtype_to_serialized_short_name, +) from sharktank.models.t5 import ( T5Attention, T5SelfAttention, @@ -41,6 +50,8 @@ flatten_for_iree_signature, iree_to_torch, ) +from sharktank.transforms.dataset import set_float_dtype +from sharktank import ops import iree.compiler with_t5_data = pytest.mark.skipif("not config.getoption('with_t5_data')") @@ -67,45 +78,210 @@ def setUp(self): torch.random.manual_seed(12345) torch.no_grad() - def runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace( - self, huggingface_repo_id: str + @with_t5_data + def testXxlBf16AgainstFluxGolden(self): + """The ground-truth values were acquired from the Flux pipeline.""" + target_model_name = ( + f"{'google/t5-v1_1-xxl'.replace('/', '__').replace('-', '_')}_f32_model" + ) + target_model_path = getattr(self, target_model_name) + dataset = Dataset.load(target_model_path) + dataset.root_theta = dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=torch.bfloat16) + ) + config = T5Config.from_gguf_properties( + dataset.properties, + feed_forward_proj="gated-gelu", + ) + model = T5Encoder(theta=dataset.root_theta, config=config) + model.eval() + + with open( + "/data/t5/xxl/flux_schnell_t5_v1_1_xxl_encoder_bf16_input_ids.pt", "rb" + ) as f: + reference_input_ids = torch.load(f) + + outputs = model( + input_ids=reference_input_ids, + attention_mask=None, + output_hidden_states=False, + ) + + with open( + "/data/t5/xxl/flux_schnell_t5_v1_1_xxl_encoder_bf16_output_last_hidden_state.pt", + "rb", + ) as f: + reference_last_hidden_state = torch.load(f) + + torch.testing.assert_close( + outputs["last_hidden_state"], reference_last_hidden_state + ) + + def runTestV1_1CompareTorchEagerHuggingFace( + self, + huggingface_repo_id: str, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, ): get_dataset( huggingface_repo_id, ).download() tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id) - reference_model = ReferenceT5EncoderModel.from_pretrained(huggingface_repo_id) + reference_model = ReferenceT5EncoderModel.from_pretrained( + huggingface_repo_id, torch_dtype=reference_dtype + ) reference_model.eval() + model = ReferenceT5EncoderModel.from_pretrained( + huggingface_repo_id, torch_dtype=target_dtype + ) + model.eval() + input_ids = tokenizer( test_prompts, return_tensors="pt", padding=True, + pad_to_multiple_of=16, ).input_ids + expected_outputs = dict(reference_model(input_ids=input_ids)) + actual_outputs = dict(model(input_ids=input_ids)) + actual_outputs = tree_map( + lambda t: ops.to(t, dtype=reference_dtype), actual_outputs + ) + + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) + + def runTestV1_1CompareTorchEagerAgainstHuggingFace( + self, + 
huggingface_repo_id: str, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + get_dataset( + huggingface_repo_id, + ).download() + tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id) + reference_model = ReferenceT5EncoderModel.from_pretrained( + huggingface_repo_id, torch_dtype=reference_dtype + ) + reference_model.eval() + target_model_name = ( - f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}_fp32_model" + f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}_f32_model" ) target_model_path = getattr(self, target_model_name) dataset = Dataset.load(target_model_path) + dataset.root_theta = dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=target_dtype) + ) config = T5Config.from_gguf_properties( dataset.properties, feed_forward_proj="gated-gelu", ) + + input_ids = tokenizer( + test_prompts, + return_tensors="pt", + padding=True, + pad_to_multiple_of=config.context_length_padding_block_size, + ).input_ids + model = T5Encoder(theta=dataset.root_theta, config=config) model.eval() expected_outputs = reference_model(input_ids=input_ids) actual_outputs = model(input_ids=input_ids) - torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0) + actual_outputs = tree_map( + lambda t: ops.to(t, dtype=reference_dtype), actual_outputs + ) + + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason=( + "The accuracy is bad, " + "but for XXL we get the same result as the Flux pipeline. " + "This need further investigation how Flux works at all like that." + ), + ) + @with_t5_data + def testV1_1SmallCompareTorchEagerHuggingFaceBf16AgainstF32(self): + self.runTestV1_1CompareTorchEagerHuggingFace( + "google/t5-v1_1-small", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-2, + rtol=1.6e-2, + ) + + @with_t5_data + def testV1_1SmallF32CompareTorchEagerAgainstHuggingFace(self): + self.runTestV1_1CompareTorchEagerAgainstHuggingFace( + "google/t5-v1_1-small", + reference_dtype=torch.float32, + target_dtype=torch.float32, + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason=( + "The accuracy is bad, " + "but for XXL we get the same result as the Flux pipeline. " + "This need further investigation how Flux works at all like that." 
+ ), + ) + @with_t5_data + def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFaceF32(self): + self.runTestV1_1CompareTorchEagerAgainstHuggingFace( + "google/t5-v1_1-small", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-2, + rtol=1.6e-2, + ) @with_t5_data - def testV1_1SmallFp32CompareTorchEagerAgainstHuggingFace(self): - self.runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace("google/t5-v1_1-small") + def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFace(self): + self.runTestV1_1CompareTorchEagerAgainstHuggingFace( + "google/t5-v1_1-small", + reference_dtype=torch.bfloat16, + target_dtype=torch.bfloat16, + ) @with_t5_data - def testV1_1XxlFp32CompareTorchEagerAgainstHuggingFace(self): - self.runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace("google/t5-v1_1-xxl") + def testV1_1XxlF32CompareTorchEagerAgainstHuggingFace(self): + self.runTestV1_1CompareTorchEagerAgainstHuggingFace( + "google/t5-v1_1-xxl", + reference_dtype=torch.float32, + target_dtype=torch.float32, + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason=( + "The accuracy is bad, but we get the same result as the Flux pipeline. " + "This need further investigation how Flux works at all like that." + ), + ) + @with_t5_data + def testV1_1XxlBf16CompareTorchEagerAgainstHuggingFaceF32(self): + self.runTestV1_1CompareTorchEagerAgainstHuggingFace( + "google/t5-v1_1-xxl", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-2, + rtol=1.6e-2, + ) @pytest.mark.usefixtures("caching", "get_model_artifacts", "path_prefix") @@ -115,14 +291,14 @@ def setUp(self): if self.path_prefix is None: self.path_prefix = f"{self._temp_dir}/" - @parameterized.expand( - [ - "google/t5-v1_1-small", - "google/t5-v1_1-xxl", - ] - ) - @with_t5_data - def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str): + def runTestV1_1CompareIreeAgainstTorchEager( + self, + huggingface_repo_id: str, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): get_dataset( huggingface_repo_id, ).download() @@ -131,12 +307,15 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str): huggingface_repo_id_as_path = ( f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}" ) - source_model_name = f"{huggingface_repo_id_as_path}_fp32_model" + source_model_name = f"{huggingface_repo_id_as_path}_f32_model" source_model_path = getattr(self, source_model_name) - dataset = Dataset.load(source_model_path) + reference_dataset = Dataset.load(source_model_path) + reference_dataset.root_theta = reference_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=reference_dtype) + ) config = T5Config.from_gguf_properties( - dataset.properties, + reference_dataset.properties, feed_forward_proj="gated-gelu", ) @@ -149,24 +328,31 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str): input_args = OrderedDict([("input_ids", input_ids)]) batch_size = input_ids.shape[0] - reference_model = T5Encoder(theta=dataset.root_theta, config=config) - reference_result = flatten_for_iree_signature( - call_torch_module_function( - module=reference_model, - function_name="forward", - kwargs=input_args, - trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_torch_", - ) + reference_dtype_name = dtype_to_serialized_short_name(reference_dtype) + target_dtype_name = dtype_to_serialized_short_name(target_dtype) + target_model_path_prefix = 
f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_{target_dtype_name}" + + reference_model = T5Encoder(theta=reference_dataset.root_theta, config=config) + reference_result_dict = call_torch_module_function( + module=reference_model, + function_name="forward", + kwargs=input_args, + trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_{reference_dtype_name}_torch_", ) + reference_result = flatten_for_iree_signature(reference_result_dict) - mlir_path = f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.mlir" + parameters_path = f"{target_model_path_prefix}.irpa" + if not self.caching or not os.path.exists(parameters_path): + export_encoder_iree_parameters( + source_model_path, parameters_path, dtype=target_dtype + ) + + mlir_path = f"{target_model_path_prefix}.mlir" if not self.caching or not os.path.exists(mlir_path): export_encoder_mlir( - source_model_path, batch_sizes=[batch_size], mlir_output_path=mlir_path + parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path ) - iree_module_path = ( - f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.vmfb" - ) + iree_module_path = f"{target_model_path_prefix}.vmfb" if not self.caching or not os.path.exists(iree_module_path): iree.compiler.compile_file( mlir_path, @@ -174,12 +360,6 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str): extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"], ) - parameters_path = ( - f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.irpa" - ) - if not self.caching or not os.path.exists(parameters_path): - export_encoder_iree_parameters(source_model_path, parameters_path) - iree_devices = get_iree_devices(driver="hip", device_count=1) iree_module, iree_vm_context, iree_vm_instance = load_iree_module( module_path=iree_module_path, @@ -196,12 +376,70 @@ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str): args=iree_args, driver="hip", function_name=f"forward_bs{batch_size}", - trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_iree_", + trace_path_prefix=f"{target_model_path_prefix}_iree_", ) ) + iree_result = [ + ops.to(iree_result[i], dtype=reference_result[i].dtype) + for i in range(len(reference_result)) + ] - torch.testing.assert_close( - reference_result, iree_result, atol=1e-4, rtol=2.0e-3 + torch.testing.assert_close(reference_result, iree_result, atol=atol, rtol=rtol) + + @with_t5_data + def testV1_1CompareSmallIreeF32AgainstTorchEagerF32(self): + self.runTestV1_1CompareIreeAgainstTorchEager( + "google/t5-v1_1-small", + reference_dtype=torch.float32, + target_dtype=torch.float32, + atol=1e-4, + rtol=2.0e-3, + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason=( + "The accuracy is bad, " + "but but it is no worse than the accuracy for of eager bfloat16. " + "This need further investigation how Flux works at all like that." 
+ ), + ) + @with_t5_data + def testV1_1CompareSmallIreeBf16AgainstTorchEagerF32(self): + self.runTestV1_1CompareIreeAgainstTorchEager( + "google/t5-v1_1-small", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-2, + rtol=1.6e-2, + ) + + @with_t5_data + def testV1_1CompareXxlIreeF32AgainstTorchEagerF32(self): + self.runTestV1_1CompareIreeAgainstTorchEager( + "google/t5-v1_1-xxl", + reference_dtype=torch.float32, + target_dtype=torch.float32, + atol=1e-4, + rtol=2.0e-3, + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason=( + "The accuracy is bad, " + "but but it is no worse than the accuracy for of eager bfloat16. " + "This need further investigation how Flux works at all like that." + ), + ) + @with_t5_data + def testV1_1CompareXxlIreeBf16AgainstTorchEagerF32(self): + self.runTestV1_1CompareIreeAgainstTorchEager( + "google/t5-v1_1-xxl", + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-2, + rtol=1.6e-2, ) @@ -211,8 +449,21 @@ def setUp(self): torch.random.manual_seed(12345) torch.no_grad() - def testCompareAgainstTransformersFp32(self): - dtype = torch.float32 + @parameterized.expand( + [ + [torch.float32, torch.float32], + [torch.bfloat16, torch.bfloat16], + [torch.float32, torch.bfloat16, 1e-2, 1.6e-2], + ] + ) + def testCompareAgainstTransformers( + self, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + torch.set_default_dtype(reference_dtype) batch_size = 19 batch_seq_len = 23 reference_config = ReferenceT5Config( @@ -233,19 +484,21 @@ def testCompareAgainstTransformersFp32(self): theta = Theta( { "attn_q.weight": DefaultPrimitiveTensor( - data=reference_model.q.weight.data + data=reference_model.q.weight.to(dtype=target_dtype) ), "attn_k.weight": DefaultPrimitiveTensor( - data=reference_model.k.weight.data + data=reference_model.k.weight.to(dtype=target_dtype) ), "attn_v.weight": DefaultPrimitiveTensor( - data=reference_model.v.weight.data + data=reference_model.v.weight.to(dtype=target_dtype) ), "attn_o.weight": DefaultPrimitiveTensor( - data=reference_model.o.weight.data + data=reference_model.o.weight.to(dtype=target_dtype) ), "attn_rel_b.weight": DefaultPrimitiveTensor( - data=reference_model.relative_attention_bias.weight.data + data=reference_model.relative_attention_bias.weight.to( + dtype=target_dtype + ) ), } ) @@ -257,24 +510,52 @@ def testCompareAgainstTransformersFp32(self): d_model=reference_config.d_model, d_kv=reference_config.d_kv, num_heads=reference_config.num_heads, - activation_dtype=dtype, + activation_dtype=target_dtype, has_relative_attention_bias=True, ) model.eval() - hidden_states = make_rand_torch( - shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype + reference_hidden_states = make_rand_torch( + shape=[batch_size, batch_seq_len, reference_config.d_model], + dtype=reference_dtype, + ) + reference_mask = make_random_mask( + shape=[batch_size, 1, 1, batch_seq_len], dtype=reference_dtype ) - mask = make_random_mask(shape=[batch_size, 1, 1, batch_seq_len], dtype=dtype) - expected_outputs = reference_model(hidden_states=hidden_states, mask=mask) + expected_outputs = reference_model( + hidden_states=reference_hidden_states, mask=reference_mask + ) + + hidden_states = ops.to(reference_hidden_states, dtype=target_dtype) + mask = ops.to(reference_mask, dtype=target_dtype) actual_outputs = model( hidden_states=DefaultPrimitiveTensor(data=hidden_states), mask=DefaultPrimitiveTensor(data=mask), ) - 
torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0) + actual_outputs = tree_map( + lambda t: None if t is None else ops.to(t, dtype=reference_dtype), + actual_outputs, + ) - def testCompareSelfAttentionAgainstTransformersFp32(self): - dtype = torch.float32 + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) + + @parameterized.expand( + [ + [torch.float32, torch.float32], + [torch.bfloat16, torch.bfloat16], + [torch.float32, torch.bfloat16, 1e-2, 1.6e-2], + ] + ) + def testCompareSelfAttentionAgainstTransformers( + self, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + torch.set_default_dtype(reference_dtype) batch_size = 19 batch_seq_len = 23 reference_config = ReferenceT5Config( @@ -296,22 +577,24 @@ def testCompareSelfAttentionAgainstTransformersFp32(self): theta = Theta( { "attn_q.weight": DefaultPrimitiveTensor( - data=reference_model.SelfAttention.q.weight.data + data=reference_model.SelfAttention.q.weight.to(dtype=target_dtype) ), "attn_k.weight": DefaultPrimitiveTensor( - data=reference_model.SelfAttention.k.weight.data + data=reference_model.SelfAttention.k.weight.to(dtype=target_dtype) ), "attn_v.weight": DefaultPrimitiveTensor( - data=reference_model.SelfAttention.v.weight.data + data=reference_model.SelfAttention.v.weight.to(dtype=target_dtype) ), "attn_o.weight": DefaultPrimitiveTensor( - data=reference_model.SelfAttention.o.weight.data + data=reference_model.SelfAttention.o.weight.to(dtype=target_dtype) ), "attn_rel_b.weight": DefaultPrimitiveTensor( - data=reference_model.SelfAttention.relative_attention_bias.weight.data + data=reference_model.SelfAttention.relative_attention_bias.weight.to( + dtype=target_dtype + ) ), "attn_norm.weight": DefaultPrimitiveTensor( - data=reference_model.layer_norm.weight.data + data=reference_model.layer_norm.weight.to(dtype=target_dtype) ), } ) @@ -323,24 +606,37 @@ def testCompareSelfAttentionAgainstTransformersFp32(self): d_model=reference_config.d_model, d_kv=reference_config.d_kv, num_heads=reference_config.num_heads, - activation_dtype=dtype, + activation_dtype=torch.float32, layer_norm_epsilon=reference_config.layer_norm_epsilon, has_relative_attention_bias=True, ) model.eval() - hidden_states = make_rand_torch( - shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype + reference_hidden_states = make_rand_torch( + shape=[batch_size, batch_seq_len, reference_config.d_model], + dtype=reference_dtype, ) - mask = make_random_mask(shape=[batch_size, 1, 1, batch_seq_len], dtype=dtype) - position_bias = make_rand_torch( - shape=[batch_size, reference_config.num_heads, batch_seq_len, batch_seq_len] + reference_mask = make_random_mask( + shape=[batch_size, 1, 1, batch_seq_len], dtype=reference_dtype + ) + reference_position_bias = make_rand_torch( + shape=[ + batch_size, + reference_config.num_heads, + batch_seq_len, + batch_seq_len, + ], + dtype=reference_dtype, ) expected_outputs = reference_model( - hidden_states=hidden_states, - attention_mask=mask, - position_bias=position_bias, + hidden_states=reference_hidden_states, + attention_mask=reference_mask, + position_bias=reference_position_bias, ) + + hidden_states = ops.to(reference_hidden_states, dtype=target_dtype) + mask = ops.to(reference_mask, dtype=target_dtype) + position_bias = ops.to(reference_position_bias, dtype=target_dtype) actual_outputs = model( hidden_states=DefaultPrimitiveTensor(data=hidden_states), 
attention_mask=DefaultPrimitiveTensor(data=mask), @@ -349,7 +645,14 @@ def testCompareSelfAttentionAgainstTransformersFp32(self): actual_outputs = [ unbox_tensor(t) if t is not None else t for t in actual_outputs ] - torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0) + actual_outputs = tree_map( + lambda t: None if t is None else ops.to(t, dtype=reference_dtype), + actual_outputs, + ) + + torch.testing.assert_close( + actual_outputs, expected_outputs, atol=atol, rtol=rtol + ) class T5LayerFFTest(TestCase): @@ -358,8 +661,21 @@ def setUp(self): torch.random.manual_seed(12345) torch.no_grad() - def testCompareAgainstTransformersFp32(self): - dtype = torch.float32 + @parameterized.expand( + [ + [torch.float32, torch.float32], + [torch.bfloat16, torch.bfloat16], + [torch.float32, torch.bfloat16, 1e-2, 1.6e-2], + ] + ) + def testCompareAgainstTransformers( + self, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + torch.set_default_dtype(reference_dtype) batch_size = 19 batch_seq_len = 23 reference_config = ReferenceT5Config( @@ -376,16 +692,20 @@ def testCompareAgainstTransformersFp32(self): theta = Theta( { "ffn_gate.weight": DefaultPrimitiveTensor( - data=reference_model.DenseReluDense.wi_0.weight + data=reference_model.DenseReluDense.wi_0.weight.to( + dtype=target_dtype + ) ), "ffn_up.weight": DefaultPrimitiveTensor( - data=reference_model.DenseReluDense.wi_1.weight + data=reference_model.DenseReluDense.wi_1.weight.to( + dtype=target_dtype + ) ), "ffn_down.weight": DefaultPrimitiveTensor( - data=reference_model.DenseReluDense.wo.weight + data=reference_model.DenseReluDense.wo.weight.to(dtype=target_dtype) ), "ffn_norm.weight": DefaultPrimitiveTensor( - data=reference_model.layer_norm.weight + data=reference_model.layer_norm.weight.to(dtype=target_dtype) ), } ) @@ -394,17 +714,24 @@ def testCompareAgainstTransformersFp32(self): is_gated_act=reference_config.is_gated_act, dense_act_fn=reference_config.dense_act_fn, layer_norm_epsilon=reference_config.layer_norm_epsilon, - activation_dtype=dtype, + activation_dtype=torch.float32, ) - hidden_states = make_rand_torch( - shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype + reference_hidden_states = make_rand_torch( + shape=[batch_size, batch_seq_len, reference_config.d_model], + dtype=reference_dtype, ) - expected_output = reference_model( - hidden_states=hidden_states, + hidden_states=reference_hidden_states, ) + + hidden_states = ops.to(reference_hidden_states, dtype=target_dtype) actual_output = model( hidden_states=DefaultPrimitiveTensor(data=hidden_states), ) - torch.testing.assert_close(actual_output, expected_output, atol=1e-5, rtol=0) + actual_output = tree_map( + lambda t: None if t is None else ops.to(t, dtype=reference_dtype), + actual_output, + ) + + torch.testing.assert_close(actual_output, expected_output, atol=atol, rtol=rtol) From 9e1e1217898da3353b4477783d423a112fdcf1dc Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Tue, 26 Nov 2024 18:29:23 -0600 Subject: [PATCH 19/25] Fix linear (#613) Fixes punet regression from linear.py I went a little crazy with the fake-quant arg and the logic was ugly. Now the fake_quant arg only affects q_input. 
The purpose in having it at all is to allow exporting and fake eager using the same irpa file --------- Co-authored-by: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com> --- .github/workflows/ci-sharktank.yml | 39 +++++++++++++++++++ .../models/punet/integration_test.py | 7 ++-- sharktank/sharktank/layers/linear.py | 21 +++++----- 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index 1d3960b43..6e8cee3bb 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -22,6 +22,45 @@ concurrency: cancel-in-progress: true jobs: + test_punet: + name: "Integration Tests - punet" + runs-on: nodai-amdgpu-mi250-x86-64 + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: 3.11 + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + + - name: Install pip deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Update to the latest iree packages. + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler iree-base-runtime --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + - name: Run punet tests + run: | + pytest -v sharktank/ -m model_punet + test: name: "Unit Tests and Type Checking" strategy: diff --git a/sharktank/integration/models/punet/integration_test.py b/sharktank/integration/models/punet/integration_test.py index 182b37a50..45af24004 100644 --- a/sharktank/integration/models/punet/integration_test.py +++ b/sharktank/integration/models/punet/integration_test.py @@ -89,12 +89,13 @@ def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir): def sdxl_int8_base_files(): from huggingface_hub import hf_hub_download - REPO_ID = "amd-shark/sdxl-quant-models" - REVISION = "942e771bf0c2657a8b33380103d04747a75dfa4a" + REPO_ID = "amd-shark/sdxl-quant-int8" + SUBFOLDER = "mi300_all_sym_8_step14_fp32" + REVISION = "efda8afb35fd72c1769e02370b320b1011622958" def download(filename): return hf_hub_download( - repo_id=REPO_ID, subfolder="unet/int8", filename=filename, revision=REVISION + repo_id=REPO_ID, subfolder=SUBFOLDER, filename=filename, revision=REVISION ) return { diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index b679dccde..acd9b8a37 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -31,9 +31,8 @@ class LinearLayer(ThetaLayer): x = x * premul_input matmul(x, weight.T) + bias - fake_quant exists to allow export without adding dequant ops. - when fake_quant is True, the op will in quant dequant fashion. 
- When false, it will keep quantized types. + fake quant only exists in order to allow for q_input to act as qdq. + when fake quant is false, q_input will quantize normally. ``` """ @@ -43,7 +42,7 @@ def __init__( *, weight_name: str = "weight", bias_name: str = "bias", - fake_quant: bool = True, + fake_quant: bool = False, ): super().__init__(theta) self._simulate_native_quant = True @@ -74,21 +73,23 @@ def forward(self, x): x = q_input.quantize(x) if self.fake_quant: x = x.unpack().dequant() - elif qdq_input is not None and self.fake_quant: + + elif qdq_input is not None: x = qdq_input.quantize(x).unpack().dequant() y = ops.linear(x, weight, bias) # Unconditionally dequantize. - if isinstance(y, QuantizedTensor) and not self.fake_quant: + if isinstance(y, QuantizedTensor): y = y.unpack().dequant() # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32. # We can truncate to fp16 in iree, so we do a cast here # to account for this in the IR. This is may not be the right # level to do this, but for now its here. - if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz: - y = ops.to(y, torch.float16) - return y - if qdq_output is not None and self.fake_quant: + if not isinstance(y, QuantizedTensor): + if y.dtype == torch.float8_e4m3fnuz: + y = ops.to(y, torch.float16) + return y + if qdq_output is not None: y = qdq_output.quantize(y).unpack().dequant() return y From cdb4ccd2e6682223d8321892d3ee9707528f68a2 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 26 Nov 2024 19:01:29 -0800 Subject: [PATCH 20/25] [shortfin] Add C++ tokenizer wrapper library. (#610) * This is gated by SHORTFIN_ENABLE_TOKENIZERS (presently off). * I'd like to either take over the wrapper or get https://github.com/mlc-ai/tokenizers-cpp/issues/50 before putting much weight on this. * There is no great C++ option for this component, so we go to the trouble of integrating a Rust component. We will need to do a bit of prep on our CI systems to enable this by default. * Python API will be added in a subsequent commit. This should be more efficient than the tokenizers Python API since we will allow direct access to the tokens vs doing a lot of conversions. * Size analysis: Prior to this patch, libshortfin was 1.8MB, which gave us an entire GPU and CPU runtime stack. After this patch (stripped) it is 8.4MB. Given how important the use case is, I'm willing to tolerate this for the moment. It seems like there is room for something better here, which is why I did not expose the underlying vendor'd API directly (edit: by switching to a nightly rust and activating a bunch of options from https://github.com/johnthagen/min-sized-rust, I was able to produce a binary that is 4.2MB, which is more reasonable). 
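* For orientation, a minimal usage sketch of the C++ API added below (not part of this change; it assumes a build with `-DSHORTFIN_ENABLE_TOKENIZERS=ON`, a HuggingFace-style `tokenizer.json` on disk at an illustrative path, and `int32_t` token ids per the underlying tokenizers-cpp API):

```c++
// Hypothetical sketch only: mirrors what tokenizers_test.cc exercises.
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "shortfin/components/tokenizers/tokenizers.h"

int main() {
  // Load the serialized tokenizer definition (path is illustrative).
  std::ifstream in("tokenizer.json");
  std::ostringstream blob;
  blob << in.rdbuf();

  auto tok = shortfin::tokenizers::Tokenizer::FromBlobJSON(blob.str());

  // Round-trip a string through token ids.
  std::vector<int32_t> ids = tok.Encode("hello world");
  std::string text = tok.Decode(ids);

  std::cout << "vocab=" << tok.GetVocabSize() << " tokens=" << ids.size()
            << " decoded=\"" << text << "\"\n";
  return 0;
}
```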
--- shortfin/CMakeLists.txt | 77 ++++++++++++++++++- .../build_tools/cmake/shortfin_library.cmake | 5 +- .../build_tools/cmake/shortfin_testing.cmake | 50 ++++++++++++ shortfin/setup.py | 1 + shortfin/src/shortfin/CMakeLists.txt | 1 + .../components/tokenizers/CMakeLists.txt | 41 ++++++++++ .../components/tokenizers/tokenizers.cc | 63 +++++++++++++++ .../components/tokenizers/tokenizers.h | 52 +++++++++++++ .../components/tokenizers/tokenizers_test.cc | 56 ++++++++++++++ 9 files changed, 341 insertions(+), 5 deletions(-) create mode 100644 shortfin/build_tools/cmake/shortfin_testing.cmake create mode 100644 shortfin/src/shortfin/components/tokenizers/CMakeLists.txt create mode 100644 shortfin/src/shortfin/components/tokenizers/tokenizers.cc create mode 100644 shortfin/src/shortfin/components/tokenizers/tokenizers.h create mode 100644 shortfin/src/shortfin/components/tokenizers/tokenizers_test.cc diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt index f025eccfe..16baa1675 100644 --- a/shortfin/CMakeLists.txt +++ b/shortfin/CMakeLists.txt @@ -48,6 +48,7 @@ option(SHORTFIN_BUILD_TESTS "Builds C++ tests" ON) option(SHORTFIN_BUNDLE_DEPS "Download dependencies instead of using system libraries" ON) option(SHORTFIN_ENABLE_TRACING "Enable runtime tracing for iree and shortfin" OFF) option(SHORTFIN_ENABLE_LTO "Enables LTO if supported" ON) +option(SHORTFIN_ENABLE_TOKENIZERS "Enables integration of native tokenizers library" OFF) set(SHORTFIN_IREE_SOURCE_DIR "" CACHE FILEPATH "Path to IREE source") @@ -80,6 +81,7 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/build_tools/cmake/ ) include(shortfin_library) +include(shortfin_testing) include(CheckCXXCompilerFlag) include(FetchContent) @@ -90,7 +92,9 @@ include(FetchContent) if(SHORTFIN_ENABLE_LTO) include(CheckIPOSupported) check_ipo_supported(RESULT SHORTFIN_LTO_SUPPORTED OUTPUT SHORTFIN_LTO_ERROR) - if(SHORTFIN_LTO_SUPPORTED) + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + message(STATUS "Not enabling LTO for debug build") + elseif(SHORTFIN_LTO_SUPPORTED) message(STATUS "Shortfin LTO Enabled") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) else() @@ -126,7 +130,9 @@ endif() message(STATUS " - Host") ################################################################################ -# Dependencies +# Bundled Dependencies +# These dependencies are either bundled or used via installed packages based +# on the SHORTFIN_BUNDLE_DEPS option. ################################################################################ if(SHORTFIN_BUNDLE_DEPS) @@ -164,15 +170,19 @@ if(SHORTFIN_BUNDLE_DEPS) shortfin_push_bundled_lib_options() # Enable spdlog shared library options so we can export it. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPDLOG_SHARED_LIB -Dspdlog_EXPORTS") + message(STATUS "Fetching bundled projects") + list(APPEND CMAKE_MESSAGE_INDENT " ") FetchContent_MakeAvailable(fmt spdlog xtl xtensor) shortfin_pop_bundled_lib_options() + list(POP_BACK CMAKE_MESSAGE_INDENT) else() find_package(spdlog) find_package(xtensor) endif() ################################################################################ -# IREE +# IREE Dependency +# This is always a source dependency on the IREE runtime. ################################################################################ # Set IREE build flags. 
@@ -237,6 +247,65 @@ else() endif() shortfin_pop_bundled_lib_options() +################################################################################ +# Tokenizer Library +################################################################################ + +function(shortfin_check_tokenizers) + # Make sure that rust/cargo is installed and usable. + # Consider switching this to a cached variable once the tokenizers_cpp project + # will accept an override vs running whatever is on the path. For now, just + # verify the path is sane as that is what will get used. + find_program(SHORTFIN_CARGO_PATH NAMES cargo NO_CACHE) + if(NOT SHORTFIN_CARGO_PATH) + message(SEND_ERROR + "Building with -DSHORTFIN_ENABLE_TOKENIZERS=ON requires cargo (Rust's build tool). " + "Please follow Rust documentation to install. On Ubuntu, this can typically be accomplished with:\n" + " sudo apt install rustup && rustup default stable\n" + "See https://www.rust-lang.org/tools/install" + ) + endif() + + # Make sure cargo is functional. + execute_process( + COMMAND ${SHORTFIN_CARGO_PATH} + RESULT_VARIABLE _CARGO_RESULT + OUTPUT_VARIABLE _CARGO_OUT + ERROR_VARIABLE _CARGO_ERR + ) + if(NOT "${_CARGO_RESULT}" STREQUAL "0") + message(SEND_ERROR + "Building with -DSHORTFIN_ENABLE_TOKENIZERS=ON requires cargo (Rust's build tool) " + "to be configured properly. It was found (${SHORTFIN_CARGO_PATH}) but returned an " + "error. Output below:\n" + "${_CARGO_OUT}\n" + "${_CARGO_ERR}" + ) + endif() +endfunction() + +if(SHORTFIN_ENABLE_TOKENIZERS) + # TODO: submit a patch to tokenizers_cpp to allow explicit configuration of the + # cargo location and pass that vs relying on environmental alignment. + shortfin_check_tokenizers() + + shortfin_push_bundled_lib_options() + set(CMAKE_C_VISIBILITY_PRESET "hidden") + set(CMAKE_CXX_VISIBILITY_PRESET "hidden") + set(CMAKE_VISIBILITY_INLINES_HIDDEN ON) + set(MLC_ENABLE_SENTENCEPIECE_TOKENIZER OFF) + + FetchContent_Declare( + tokenizers_cpp # From CMake project() declaration + GIT_REPOSITORY https://github.com/mlc-ai/tokenizers-cpp.git + GIT_TAG 4bb753377680e249345b54c6b10e6d0674c8af03 # 2024 Nov 15 + EXCLUDE_FROM_ALL + ) + message(STATUS "Fetching tokenizers_cpp") + FetchContent_MakeAvailable(tokenizers_cpp) + shortfin_pop_bundled_lib_options() +endif() + ################################################################################ # Tests ################################################################################ @@ -254,9 +323,9 @@ if(SHORTFIN_BUILD_TESTS) endif() include(GoogleTest) enable_testing() + add_custom_target(shortfin_testdata_deps) endif() - add_subdirectory(src) if(SHORTFIN_BUILD_PYTHON_BINDINGS) diff --git a/shortfin/build_tools/cmake/shortfin_library.cmake b/shortfin/build_tools/cmake/shortfin_library.cmake index aaa97a6c1..103fdf1c5 100644 --- a/shortfin/build_tools/cmake/shortfin_library.cmake +++ b/shortfin/build_tools/cmake/shortfin_library.cmake @@ -182,7 +182,10 @@ function(shortfin_gtest_test) GTest::gmock GTest::gtest_main ) - gtest_discover_tests(${_RULE_NAME}) + gtest_discover_tests( + ${_RULE_NAME} + WORKING_DIRECTORY "${libshortfin_BINARY_DIR}" + ) endfunction() diff --git a/shortfin/build_tools/cmake/shortfin_testing.cmake b/shortfin/build_tools/cmake/shortfin_testing.cmake new file mode 100644 index 000000000..e462b7023 --- /dev/null +++ b/shortfin/build_tools/cmake/shortfin_testing.cmake @@ -0,0 +1,50 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Downloads some test data file as part of configure. +# This does a download->rename in an attempt to be robust to partial downloads. +# It should not be used to manage large test data files or anything sensitive +# enough to require a hash check. +# The output file is added as an additional clean file on the global +# shortfin_testdata_deps target, meaning the "ninja clean" will remove it. +# It is also added to the current directories list of configure depends, which +# means that if ninja is run and it is not present, cmake will be re-invoked. +function(shortfin_download_test_data) + cmake_parse_arguments( + _RULE + "" + "URL;OUTPUT_FILE" + "" + ${ARGN} + ) + if(NOT SHORTFIN_BUILD_TESTS) + return() + endif() + if(NOT EXISTS "${_RULE_OUTPUT_FILE}") + set(_stage_file "${_RULE_OUTPUT_FILE}.stage") + message(STATUS "Downloading test data ${_RULE_URL} -> ${_RULE_OUTPUT_FILE}") + file(DOWNLOAD "${_RULE_URL}" "${_stage_file}" STATUS _status) + list(POP_FRONT _status _status_code) + if(_status_code EQUAL "0") + file(RENAME "${_stage_file}" "${_RULE_OUTPUT_FILE}") + else() + message(SEND_ERROR "Error downloading file ${_RULE_URL} -> ${_RULE_OUTPUT_FILE}") + endif() + endif() + + # Make clean remove it. + set_property( + TARGET shortfin_testdata_deps + APPEND PROPERTY ADDITIONAL_CLEAN_FILES + "${_RULE_OUTPUT_FILE}" + ) + + # And make us reconfigure if it isn't there. + set_property( + DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" + APPEND PROPERTY + CMAKE_CONFIGURE_DEPENDS "${_RULE_OUTPUT_FILE}") +endfunction() diff --git a/shortfin/setup.py b/shortfin/setup.py index cf3762950..e15b38d89 100644 --- a/shortfin/setup.py +++ b/shortfin/setup.py @@ -225,6 +225,7 @@ def build_cmake_configuration(CMAKE_BUILD_DIR: Path, extra_cmake_args=[]): add_env_cmake_setting(cmake_args, "SHORTFIN_ENABLE_LTO", default_value="ON") add_env_cmake_setting(cmake_args, "SHORTFIN_IREE_SOURCE_DIR") add_env_cmake_setting(cmake_args, "SHORTFIN_ENABLE_ASAN") + add_env_cmake_setting(cmake_args, "SHORTFIN_ENABLE_TOKENIZERS", default_value="OFF") # Only do a from-scratch configure if not already configured. cmake_cache_file = os.path.join(CMAKE_BUILD_DIR, "CMakeCache.txt") diff --git a/shortfin/src/shortfin/CMakeLists.txt b/shortfin/src/shortfin/CMakeLists.txt index 058e0e336..73df08e7c 100644 --- a/shortfin/src/shortfin/CMakeLists.txt +++ b/shortfin/src/shortfin/CMakeLists.txt @@ -5,5 +5,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception add_subdirectory(array) +add_subdirectory(components/tokenizers) add_subdirectory(local) add_subdirectory(support) diff --git a/shortfin/src/shortfin/components/tokenizers/CMakeLists.txt b/shortfin/src/shortfin/components/tokenizers/CMakeLists.txt new file mode 100644 index 000000000..6b9f794b1 --- /dev/null +++ b/shortfin/src/shortfin/components/tokenizers/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +if(NOT SHORTFIN_ENABLE_TOKENIZERS) + return() +endif() + +shortfin_cc_component( + NAME + shortfin_tokenizers + HDRS + tokenizers.h + SRCS + tokenizers.cc + DEFINES + SHORTFIN_HAVE_TOKENIZERS + COMPONENTS + shortfin_support + DEPS + tokenizers_cpp +) +set_property(GLOBAL APPEND + PROPERTY SHORTFIN_LIB_OPTIONAL_COMPONENTS + shortfin_tokenizers) +target_compile_definitions(shortfin_public_defs INTERFACE SHORTFIN_HAVE_TOKENIZERS) + +# Download test data. +shortfin_download_test_data( + URL "https://huggingface.co/google-bert/bert-base-cased/resolve/cd5ef92a9fb2f889e972770a36d4ed042daf221e/tokenizer.json" + OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/tokenizer.json" +) + +# Note that tests run from the binary dir of the project. +shortfin_gtest_test( + NAME shortfin_tokenizers_test + SRCS + tokenizers_test.cc +) diff --git a/shortfin/src/shortfin/components/tokenizers/tokenizers.cc b/shortfin/src/shortfin/components/tokenizers/tokenizers.cc new file mode 100644 index 000000000..118bc0c1b --- /dev/null +++ b/shortfin/src/shortfin/components/tokenizers/tokenizers.cc @@ -0,0 +1,63 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/components/tokenizers/tokenizers.h" + +#include + +#include "shortfin/support/logging.h" +#include "tokenizers_cpp.h" + +namespace shortfin::tokenizers { + +namespace { + +class AccessibleTokenizer : public Tokenizer { + public: + using Tokenizer::vendor_tokenizer_; +}; + +::tokenizers::Tokenizer *Get(Tokenizer *self) { + void *ptr = static_cast(self)->vendor_tokenizer_; + if (!ptr) { + throw std::logic_error("Tokenizer is null"); + } + return static_cast<::tokenizers::Tokenizer *>(ptr); +} + +} // namespace + +Tokenizer::~Tokenizer() { delete Get(this); } + +Tokenizer Tokenizer::FromBlobJSON(const std::string &json_blob) { + SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::FromBlobJSON"); + return Tokenizer(::tokenizers::Tokenizer::FromBlobJSON(json_blob).release()); +} + +std::vector Tokenizer::Encode(const std::string &text) { + SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::Encode"); + return Get(this)->Encode(text); +} + +std::vector> Tokenizer::EncodeBatch( + const std::vector &texts) { + SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::EncodeBatch"); + return Get(this)->EncodeBatch(texts); +} + +std::string Tokenizer::Decode(const std::vector &ids) { + SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::Decode"); + return Get(this)->Decode(ids); +} +size_t Tokenizer::GetVocabSize() { return Get(this)->GetVocabSize(); } +std::string Tokenizer::IdToToken(int32_t token_id) { + return Get(this)->IdToToken(token_id); +} +int32_t Tokenizer::TokenToId(const std::string &token) { + return Get(this)->TokenToId(token); +} + +} // namespace shortfin::tokenizers diff --git a/shortfin/src/shortfin/components/tokenizers/tokenizers.h b/shortfin/src/shortfin/components/tokenizers/tokenizers.h new file mode 100644 index 000000000..d263eace6 --- /dev/null +++ b/shortfin/src/shortfin/components/tokenizers/tokenizers.h @@ -0,0 +1,52 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H +#define SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H + +#include +#include + +#include "shortfin/support/api.h" + +namespace shortfin::tokenizers { + +// A vendored Tokenizer class that does not export the details of the backing +// implementation. While a little bit gross, this keeps us from needing to +// re-export a vendor'ed API as part of our public API. +// The current vendor tokenizer is based on mlc-ai/tokenizers-cpp. The API +// is fairly close to that implementation. +// See: https://github.com/mlc-ai/tokenizers-cpp +class SHORTFIN_API Tokenizer { + public: + Tokenizer(const Tokenizer &) = delete; + Tokenizer &operator=(const Tokenizer &) = delete; + Tokenizer(Tokenizer &&other) : vendor_tokenizer_(other.vendor_tokenizer_) { + vendor_tokenizer_ = nullptr; + } + ~Tokenizer(); + + // Factory functions. + static Tokenizer FromBlobJSON(const std::string &json_blob); + + std::vector Encode(const std::string &text); + std::vector> EncodeBatch( + const std::vector &texts); + std::string Decode(const std::vector &ids); + size_t GetVocabSize(); + std::string IdToToken(int32_t token_id); + int32_t TokenToId(const std::string &token); + + private: + Tokenizer(void *vendor_tokenizer) : vendor_tokenizer_(vendor_tokenizer) {} + + protected: + void *vendor_tokenizer_; +}; + +} // namespace shortfin::tokenizers + +#endif // SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H diff --git a/shortfin/src/shortfin/components/tokenizers/tokenizers_test.cc b/shortfin/src/shortfin/components/tokenizers/tokenizers_test.cc new file mode 100644 index 000000000..674721653 --- /dev/null +++ b/shortfin/src/shortfin/components/tokenizers/tokenizers_test.cc @@ -0,0 +1,56 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/components/tokenizers/tokenizers.h" + +#include +#include + +#include +#include + +using namespace shortfin::tokenizers; + +namespace { + +std::string ReadFile(std::filesystem::path path) { + std::ifstream in(path); + std::ostringstream out; + out << in.rdbuf(); + return out.str(); +} + +} // namespace + +// TODO: Enable once upstream changes with error handling have landed. +// Currently aborts. 
+// See: https://github.com/mlc-ai/tokenizers-cpp/issues/50 +// TEST(TokenizersTest, FromIllegalBlobJson) { +// auto tok = Tokenizer::FromBlobJSON("foobar"); +// } + +TEST(TokenizersTest, BasicTokenizerJson) { + std::filesystem::path tokenizer_path( + "src/shortfin/components/tokenizers/tokenizer.json"); + auto tokenizer_json = ReadFile(tokenizer_path); + ASSERT_GT(tokenizer_json.size(), 0) + << "reading " << tokenizer_path + << " (cwd: " << std::filesystem::current_path() << ")"; + auto tok = Tokenizer::FromBlobJSON(tokenizer_json); + EXPECT_GT(tok.GetVocabSize(), 100); // Sanity check + auto encoded = tok.Encode("hello world"); + EXPECT_THAT(encoded, + ::testing::ContainerEq(std::vector{19082, 1362})); + auto batch_encoded = tok.EncodeBatch({"hello", "world"}); + ASSERT_EQ(batch_encoded.size(), 2); + EXPECT_THAT(batch_encoded[0], + ::testing::ContainerEq(std::vector{19082})); + EXPECT_THAT(batch_encoded[1], + ::testing::ContainerEq(std::vector{1362})); + EXPECT_EQ(tok.TokenToId("hello"), 19082); + auto decoded = tok.Decode(encoded); + EXPECT_EQ(decoded, "hello world"); +} From d6be43f7276e7c552ff409739b6041942217ed13 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Tue, 26 Nov 2024 19:14:43 -0800 Subject: [PATCH 21/25] Update flash attention op (#616) This commit updates the flash attention op to adhere to the addition of is_causal and scale args added by this commit: https://github.com/nod-ai/SHARK-Platform/commit/2d46caaa932fe41f765315752665e87718aef5b3. Without this, we are seeing a fp8 attention export failure --- sharktank/sharktank/ops/attention_impls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sharktank/sharktank/ops/attention_impls.py b/sharktank/sharktank/ops/attention_impls.py index 8ffb1f51e..d1353daaa 100644 --- a/sharktank/sharktank/ops/attention_impls.py +++ b/sharktank/sharktank/ops/attention_impls.py @@ -47,7 +47,7 @@ def _extract_linear_scale(t): return unbox_tensor(t), None -def flash_attention(q, k, v, a): +def flash_attention(q, k, v, a, is_causal, scale): scale = torch.scalar_tensor(1.0 / math.sqrt(q.shape[-1]), dtype=torch.float32) q, qscale = _extract_linear_scale(q) From b7201f3f4347d61b15c037d64b1b53be7f6cd472 Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:26:51 -0800 Subject: [PATCH 22/25] [kernel] remove transposeV in favor of generic method in iree compile (#618) Now that we have landed https://github.com/iree-org/iree/commit/41115bba05960e563791ce6ed1af26093f4fab1e on top of main, we should not need explicit transpose as part of kernel anymore. 
Although this PR still maintains the order that reduction dim is the fastest dim per which is done in https://github.com/nod-ai/shark-ai/commit/a7feae8bce788f865548f7f5fda547050f3097a1 Signed-off-by: Stanley Winata --- sharktank/sharktank/kernels/templates/flash_attention.mlir | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sharktank/sharktank/kernels/templates/flash_attention.mlir b/sharktank/sharktank/kernels/templates/flash_attention.mlir index 15d75c372..4085fef9c 100644 --- a/sharktank/sharktank/kernels/templates/flash_attention.mlir +++ b/sharktank/sharktank/kernels/templates/flash_attention.mlir @@ -33,19 +33,16 @@ util.func private @sharktank_flash_attention_{{l}}_{{s}}_{{d}}_{{e}}_{{i_type}}_ %scale = tensor.extract %s[] : !s_type - %init_trans_v = tensor.empty(%b0, %b1) : !trans_v_type - %transpose_v = linalg.transpose ins(%v: !v_type) outs(%init_trans_v: !trans_v_type) permutation = [0, 1, 3, 2] - %empty_dyn = tensor.empty(%b0, %b1, %l, %e) : !o_dyn_type %empty = tensor.cast %empty_dyn : !o_dyn_type to !o_type %atten = iree_linalg_ext.attention {indexing_maps = [ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>]} - ins(%q, %k, %transpose_v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) { + ins(%q, %k, %v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 } -> !o_type From 4a14a66631507742c9f14e59d6aad06ea2f90176 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Wed, 27 Nov 2024 10:34:27 +0100 Subject: [PATCH 23/25] Update `CI - shortfin` link (#615) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 44f1e6113..ae3eac423 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ If you're looking to use SHARK check out our [User Guide](docs/user_guide.md). F -[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml?query=event%3Apush) +[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/shark-ai/actions/workflows/ci-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci-libshortfin.yml?query=event%3Apush) The shortfin sub-project is SHARK's high performance inference library and serving engine. From 2ee7676b0b57002e1f815419178ca36d8adcac79 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Wed, 27 Nov 2024 08:54:53 -0800 Subject: [PATCH 24/25] Delete sharktank/serving_poc code that is replaced with shortfin. (#429) I believe this has been fully replaced by https://github.com/nod-ai/SHARK-Platform/tree/main/shortfin/python/shortfin_apps/llm (hard to tell, docs aren't up to date). If developers are still using it, we can keep it around for a bit longer. 
--- docs/model_cookbook.md | 2 + sharktank/sharktank/serving_poc/__init__.py | 0 .../serving_poc/framework/logging.py | 48 -- .../serving_poc/framework/session.py | 610 ------------------ .../sharktank/serving_poc/llm/__init__.py | 0 .../serving_poc/llm/api/rest_server.py | 123 ---- .../serving_poc/llm/attn_block_cache.py | 133 ---- sharktank/sharktank/serving_poc/llm/config.py | 181 ------ .../serving_poc/llm/impl/service_v1.py | 495 -------------- .../serving_poc/llm/impl/service_v1_cli.py | 118 ---- .../sharktank/serving_poc/llm/service.py | 189 ------ .../serving_poc/llm/testing/fake_v1_module.py | 118 ---- sharktank/sharktank/serving_poc/py.typed | 1 - .../framework/device_session_test.py | 63 -- .../tests/serving_poc/llm/api_server_test.py | 115 ---- .../tests/serving_poc/llm/service_v1_test.py | 131 ---- 16 files changed, 2 insertions(+), 2325 deletions(-) delete mode 100644 sharktank/sharktank/serving_poc/__init__.py delete mode 100644 sharktank/sharktank/serving_poc/framework/logging.py delete mode 100644 sharktank/sharktank/serving_poc/framework/session.py delete mode 100644 sharktank/sharktank/serving_poc/llm/__init__.py delete mode 100644 sharktank/sharktank/serving_poc/llm/api/rest_server.py delete mode 100644 sharktank/sharktank/serving_poc/llm/attn_block_cache.py delete mode 100644 sharktank/sharktank/serving_poc/llm/config.py delete mode 100644 sharktank/sharktank/serving_poc/llm/impl/service_v1.py delete mode 100644 sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py delete mode 100644 sharktank/sharktank/serving_poc/llm/service.py delete mode 100644 sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py delete mode 100644 sharktank/sharktank/serving_poc/py.typed delete mode 100644 sharktank/tests/serving_poc/framework/device_session_test.py delete mode 100644 sharktank/tests/serving_poc/llm/api_server_test.py delete mode 100644 sharktank/tests/serving_poc/llm/service_v1_test.py diff --git a/docs/model_cookbook.md b/docs/model_cookbook.md index ddc0cb3bb..03e625b96 100644 --- a/docs/model_cookbook.md +++ b/docs/model_cookbook.md @@ -210,6 +210,8 @@ iree-compile /tmp/open_llama_3b_v2/open-llama-3b-v2-f16.mlir \ -o /tmp/open_llama_3b_v2/open-llama-3b-v2-f16_cpu.vmfb ``` +TODO: replace these instructions with the newer shortfin code + Run via `service_v1_cli.py` (shortfin serving, with tokenizer): * TODO: script (via service CLI?) to dump inputs/outputs to .bin/.npy files diff --git a/sharktank/sharktank/serving_poc/__init__.py b/sharktank/sharktank/serving_poc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/sharktank/sharktank/serving_poc/framework/logging.py b/sharktank/sharktank/serving_poc/framework/logging.py deleted file mode 100644 index fe5ffc069..000000000 --- a/sharktank/sharktank/serving_poc/framework/logging.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import os -import sys - - -# Whether debug assertions are disabled. 
-NDEBUG: bool = False - -_default_log_level = os.getenv("TURBINE_LOG_LEVEL", "DEBUG") - - -class DefaultFormatter(logging.Formatter): - def __init__(self): - super().__init__( - "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s", - "%m-%d %H:%M:%S", - ) - - -def _setup_logger(): - root_logger = logging.getLogger("sharktank.serving_poc") - root_logger.setLevel(logging.DEBUG) - default_handler = logging.StreamHandler(sys.stderr) - default_handler.flush = sys.stderr.flush - default_handler.setLevel(_default_log_level) - default_handler.setFormatter(DefaultFormatter()) - root_logger.addHandler(default_handler) - root_logger.propagate = False - return root_logger, default_handler - - -root_logger, default_handler = _setup_logger() - -logging.getLogger("asyncio").addHandler(default_handler) - - -def get_logger(name: str): - logger = logging.getLogger(name) - logger.setLevel(_default_log_level) - logger.addHandler(default_handler) - logger.propagate = False - return logger diff --git a/sharktank/sharktank/serving_poc/framework/session.py b/sharktank/sharktank/serving_poc/framework/session.py deleted file mode 100644 index 28af0fd44..000000000 --- a/sharktank/sharktank/serving_poc/framework/session.py +++ /dev/null @@ -1,610 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Runtime session constructs. - -Key concepts: - - * DeviceSession: A single HAL device and other process-level globals. Shared global - memory and corresponding synchronization handles are accessible from here. - * WorkQueue: Logical stream of execution, nested under the DeviceSession. Each - queue holds a timeline semaphore which sequences invocations. For these models, - we route workloads of vastly different characteristics to distinct queues (i.e. - prefill vs decode step). - * LoadedModule: Modules that have been loaded but have not yet been instantiated into - a context. - * HostContext: At least one HostContext is created per LoadedModule. It encapsulates - a VMContext and performs invocations on a dedicated thread. Typically, there will - be more that one HostContext per LoadedModule as it helps us load balance the - host side work across multiple OS threads, ensuring faster feeding of the device. 
-""" - -from typing import Any, Callable, Coroutine, Generic, TypeVar, Optional, Union - -import asyncio -import concurrent.futures -import math -import queue -from threading import Lock, Thread -import warnings - -import numpy as np - -from iree.runtime import ( # type: ignore[import-untyped] - create_hal_module, - create_io_parameters_module, - get_driver, - BufferUsage, - HalBufferView, - HalCommandBuffer, - HalDevice, - HalDeviceLoopBridge, - HalDriver, - HalElementType, - HalFence, - HalSemaphore, - MemoryType, - ParameterIndex, - VmFunction, - VmInstance, - VmContext, - VmModule, -) - -from .logging import get_logger, NDEBUG - -T = TypeVar("T") - -logger = get_logger("shark_turbine.serving.session") -_CONFIG_LOCK = Lock() -_GLOBAL_VM_INSTANCE: Optional[VmInstance] = None - - -def get_vm_instance() -> VmInstance: - global _GLOBAL_VM_INSTANCE - if not _GLOBAL_VM_INSTANCE: - with _CONFIG_LOCK: - if not _GLOBAL_VM_INSTANCE: - _GLOBAL_VM_INSTANCE = VmInstance() - return _GLOBAL_VM_INSTANCE - - -class DeviceSession: - """Top-level object associated with a single attached device.""" - - __slots__ = [ - "device", - "driver", - "_module_sets", - "queues", - "_queue_request_count", - "vm_instance", - ] - - def __init__( - self, - *, - uri: Optional[str] = None, - driver: Optional[Union[str, HalDriver]] = None, - device: Optional[HalDevice] = None, - vm_instance: Optional[VmInstance] = None, - queue_count: int = 1, - ): - self._queue_request_count = 0 - self.vm_instance = vm_instance or get_vm_instance() - if uri is not None: - assert ( - driver is None and device is None - ), "If 'uri' is given, 'driver' and 'device' cannot be set" - logger.info("Opening device by uri: %s", uri) - driver = uri_driver = get_driver(uri) - device = uri_driver.create_device_by_uri(uri) - assert driver is not None, "'driver' cannot be None" - self.driver = driver if not isinstance(driver, str) else get_driver(driver) - self.device = device if device else self.driver.create_default_device() - - # Dependent objects. - self._module_sets: dict[str, "ModuleSet"] = {} - self.queues = [WorkQueue(self, i) for i in range(queue_count)] - - def shutdown(self): - for ms in self._module_sets.values(): - ms.shutdown() - - def create_module_set(self, name: str, *, context_count: int = 1) -> "ModuleSet": - assert ( - name not in self._module_sets - ), f"Modules with name {name} already created" - lm = ModuleSet(self, name, context_count=context_count) - self._module_sets[name] = lm - return lm - - def module_set(self, name: str) -> "ModuleSet": - try: - return self._module_sets[name] - except KeyError: - raise KeyError( - f"ModuleSet '{name}' not found. 
Available: {self._module_sets.keys()}" - ) - - def queue(self, index: int = -1) -> "WorkQueue": - """Gets a queue either with an explicit index or in some rotating fashion.""" - if index >= 0: - return self.queues[index] - else: - self._queue_request_count += 1 - qc = self._queue_request_count - return self.queues[qc % len(self.queues)] - - -class ModuleSet: - __slots__ = [ - "contexts", - "modules", - "name", - "session", - "_context_counter", - ] - - def __init__(self, session: DeviceSession, name: str, *, context_count: int): - assert context_count > 0 - self.session = session - self.name = name - self.modules: list[VmModule] = [ - create_hal_module(session.vm_instance, session.device) - ] - self.contexts = [None] * context_count - self._context_counter = 0 - - @property - def initialized(self) -> bool: - return self.contexts[-1] is not None - - def add(self, *modules: VmModule): - for module in modules: - self.modules.append(module) - - def load_vmfb(self, vmfb_path: str): - logger.info("Loading VMFB %s", vmfb_path) - self.add(VmModule.mmap(self.session.vm_instance, vmfb_path)) - - def load_io_module(self, sources_path: str): - logger.info("Loading IO Module %s", sources_path) - index = ParameterIndex() - index.load(sources_path) - par_provider = index.create_provider(scope="model") - self.add(create_io_parameters_module(self.session.vm_instance, par_provider)) - - def initialize(self): - assert not self.initialized, "Already initialized" - count = len(self.contexts) - logger.info("Initializing %s contexts for %s", count, self.name) - for i in range(count): - self.contexts[i] = HostContext( - self.session, self.modules, name=f"HostContext-{self.name}-{i}" - ) - - def shutdown(self): - for hc in self.contexts: - if hc is not None: - hc.shutdown() - - def module(self, name: str) -> VmModule: - for m in self.modules: - if m.name == name: - return m - raise KeyError( - f"Module `{name}` not found. Available: {[m.name for m in self.modules]}" - ) - - def function(self, module_name: str, function_name: str) -> VmFunction: - m = self.module(module_name) - f = m.lookup_function(function_name) - if f is None: - raise KeyError( - f"Function '{function_name}' not found in '{module_name}'. 
" - f"Available: {m.function_names}" - ) - return f - - @property - def host_context(self) -> "HostContext": - """Gets a context, load balancing across available instances.""" - with _CONFIG_LOCK: - self._context_counter += 1 - counter = self._context_counter - contexts = self.contexts - context = contexts[counter % len(contexts)] - assert context is not None, "Module set not initialized" - return context - - -_ThunkQueueT = queue.SimpleQueue[Union[None, Callable[[], None]]] - - -class HostContext: - def __init__(self, session: DeviceSession, modules: list[VmModule], name: str): - self.session = session - self.vm_context = VmContext(session.vm_instance, modules=modules) - self.name = name - self.loop = asyncio.new_event_loop() - self.loop.set_debug(True) - - # def exc_handler(loop, context): - # print("[EXCEPTION]", loop, context) - # self.loop.set_exception_handler(exc_handler) - - self._device_bridge = HalDeviceLoopBridge(session.device, self.loop) - self._shutdown_future = self.loop.create_future() - logger.info(f"Starting asyncio loop thread %s", name) - self._loop_thread = Thread( - target=self.loop.run_until_complete, - args=[self._shutdown_future], - name=name, - daemon=False, - ) - self._loop_thread.start() - - def shutdown(self, join: bool = True): - if self._shutdown_future is None: - return - logger.info("Signalling shutdown of host context %s", self.name) - local_future = self._shutdown_future - del self._shutdown_future - - def _shutdown(): - local_future.set_result(True) - - self.loop.call_soon_threadsafe(_shutdown) - self._device_bridge.stop() - if join: - self._loop_thread.join() - self.loop.close() - - def __del__(self): - if hasattr(self, "_shutdown_future"): - warnings.warn(f"HostContext deallocated without shutdown(): {self}") - self.shutdown(join=False) - - def run_concurrent( - self, coro: Coroutine[Any, Any, T] - ) -> concurrent.futures.Future[T]: - """Runs a coroutine from another thread, returning a concurrent Future. - - This should be used for submitting initial work to the host context from - another thread or event loop. - - Note that the concurrent Future should have its result() retrieved to - ensure that any asynchronous exceptions are propagated. Otherwise, they will - be silently consumed. - """ - return asyncio.run_coroutine_threadsafe(coro, self.loop) - - def run_sync(self, coro: Coroutine[Any, Any, T]) -> T: - """Runs a coroutine on the host context loop from another thread. - - Waits on and returns the result. - This is primarily intended for testing. - """ - return asyncio.run_coroutine_threadsafe(coro, self.loop).result() - - def on_semaphore( - self, sem: HalSemaphore, payload: int, value: Any - ) -> asyncio.Future: - """Returns an awaitable for when the semaphore attains a payload timepoint. - - The resulting Future will take the given `value` once complete. - """ - return self._device_bridge.on_semaphore(sem, payload, value) - - -class WorkQueue: - """Models a queue as a progression of steps against a timeline semaphore.""" - - __slots__ = [ - "_device", - "_lock", - "_semaphore", - "_step", - "index", - ] - - def __init__(self, session: DeviceSession, index: int = 0): - self.index = index - self._device = session.device - self._lock = Lock() - self._semaphore = session.device.create_semaphore(0) - self._step = 0 - - def execute_sequential(self, command_buffer: HalCommandBuffer): - """Executes a list of command buffers at the current step, advancing to the - next. 
- """ - with self._lock: - current_step = self._step - next_step = current_step + 1 - self._step = next_step - sem = self._semaphore - self._device.queue_execute( - command_buffer, [(sem, current_step)], [(sem, next_step)] - ) - - def current_fence(self) -> HalFence: - """Gets a fence representing the current step.""" - with self._lock: - return HalFence.create_at(self._semaphore, self._step) - - def step_fences(self) -> tuple[HalFence, HalFence]: - """Gets two fences, one at the current step and one at the next.""" - with self._lock: - current_step = self._step - next_step = current_step + 1 - self._step = next_step - sem = self._semaphore - return HalFence.create_at(sem, current_step), HalFence.create_at(sem, next_step) - - def sync(self, host_context: HostContext) -> asyncio.Future: - """Awaitable that completes when all work currently queued completed.""" - with self._lock: - current_step = self._step - return host_context.on_semaphore(self._semaphore, current_step, True) - - def guard(self, value: T) -> "TimelineGuarded[T]": - """Guards an arbitrary value as a timeline guard at the current queue - position. The value will become available when the queue is sync'd.""" - return TimelineGuarded(value, self._semaphore, self._step) - - def __repr__(self): - with self._lock: - return f"WorkQueue[{self.index}](semaphore={self._semaphore}, step={self._step}" - - -class TransferBuffer: - """Transfer buffers are pairs of host/device buffers of a specific size. - - They are used for streaming to/from the device. - """ - - __slots__ = [ - "host_buffer", - "device_buffer", - "host_buffer_map", - "_pool", - ] - - def __init__(self, session: DeviceSession, buffer_size_bytes: int): - self.host_buffer = session.device.allocator.allocate_buffer( - memory_type=MemoryType.HOST_LOCAL | MemoryType.DEVICE_VISIBLE, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=buffer_size_bytes, - ) - self.device_buffer = session.device.allocator.allocate_buffer( - memory_type=MemoryType.DEVICE_LOCAL, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=buffer_size_bytes, - ) - self.host_buffer_map = self.host_buffer.map() - self._pool: Optional["TransferBufferPool"] = None - - @staticmethod - def allocate_shaped( - session: DeviceSession, shape: list[int], element_type: HalElementType - ) -> "TransferBuffer": - assert HalElementType.is_byte_aligned(element_type) - buffer_size_bytes = math.prod(shape) * HalElementType.dense_byte_count( - element_type - ) - return TransferBuffer(session, buffer_size_bytes) - - def recycle(self): - pool = self._pool - assert ( - pool is not None - ), f"Cannot recycle a TransferBuffer that was not acquired from a pool ({self})" - self._pool = None - pool.recycle(self) - - def h2d_array( - self, - cb: HalCommandBuffer, - shape: list[int], - element_type: HalElementType, - *, - fill_value: Any = None, - ) -> tuple[np.ndarray, HalBufferView]: - """Performs an h2d transfer on the given CommandBuffer of the given shape and - element type. - - Returns a host array and device buffer view. Because transfers do not start - until the command buffer is submitted, the host array should be populated - between the return from this call and submission. 
- """ - ary = self.host_buffer_map.asarray( - shape, HalElementType.map_to_dtype(element_type) - ) - if fill_value is not None: - ary.fill(fill_value) - bv = HalBufferView(self.device_buffer, shape, element_type) - cb.copy(self.host_buffer, self.device_buffer, length=bv.byte_length) - return ary, bv - - def __repr__(self): - if self._pool is None: - return f"TransferBuffer(FREE)" - else: - return f"TransferBuffer({self._pool})" - - if not NDEBUG: - - def __del__(self): - if self._pool is not None: - warnings.warn( - f"Deallocated TransferBuffer which needed to be recycled: {self}" - ) - - -class TransferBufferPool: - """Pool of transfer buffers of a fixed size.""" - - __slots__ = [ - "_allocator", - "_free_list", - "name", - ] - - def __init__( - self, - allocator: Callable[[], TransferBuffer], - *, - initial_capacity: int, - growable: bool = False, - name: str = "", - ): - self.name = name - if initial_capacity > 0: - self._free_list = [allocator() for _ in range(initial_capacity)] - self._allocator = None - if growable: - self._allocator = allocator - - @staticmethod - def shaped( - session: DeviceSession, - shape: list[int], - element_type: HalElementType, - *, - initial_capacity: int, - growable: bool = False, - name: str = "", - ) -> "TransferBufferPool": - """Allocates a pool of transfer buffers of the given shape.""" - if initial_capacity > 0: - logger.info( - "Allocating initial capacity %s of '%s' transfer buffers: %s x %r", - initial_capacity, - name, - shape, - element_type, - ) - return TransferBufferPool( - lambda: TransferBuffer.allocate_shaped(session, shape, element_type), - initial_capacity=initial_capacity, - growable=growable, - name=name, - ) - - @staticmethod - def sized( - session: DeviceSession, - buffer_byte_size: int, - *, - initial_capacity: int, - growable: bool = False, - name: str = "", - ) -> "TransferBufferPool": - """Allocates a pool of transfer buffers of a given size in bytes.""" - if initial_capacity > 0: - logger.info( - "Allocating initial capacity %s of '%s' transfer buffers: %s bytes", - initial_capacity, - name, - buffer_byte_size, - ) - return TransferBufferPool( - lambda: TransferBuffer(session, buffer_byte_size), - initial_capacity=initial_capacity, - growable=growable, - name=name, - ) - - def acquire(self) -> TransferBuffer: - """Acquires a transfer buffer from the pool. - - Must be returned via recycle() when done. 
- """ - free_list = self._free_list - if len(free_list) > 0: - tb = free_list.pop() - assert tb._pool is None - tb._pool = self - return tb - - allocator = self._allocator - if not allocator: - raise RuntimeError( - f"Transfer buffer pool '%s' exhausted and not growable", self.name - ) - logger.info("Grow transfer buffer pool '%s'", self.name) - tb = allocator() - assert tb._pool is None - tb._pool = self - return tb - - def recycle(self, tb: TransferBuffer): - """Recycles an acquired transfer buffer.""" - self._free_list.append(tb) - - def __repr__(self): - return f"TransferBufferPool({self.name})" - - -class AsyncResources: - """Resources held for some asynchronous scope.""" - - __slots__ = [ - "_resources", - ] - - def __init__(self): - self._resources: list[Union[TransferBuffer, "AsyncResources"]] = [] - - def acquire_transfer_buffer(self, pool: TransferBufferPool) -> TransferBuffer: - tb = pool.acquire() - self._resources.append(tb) - return tb - - def recycle(self): - for r in self._resources: - r.recycle() - self._resources.clear() - - if not NDEBUG: - - def __del__(self): - if len(self._resources) != 0: - warnings.warn( - f"Deallocated AsyncResources that was not recycled: {self}" - ) - self.recycle() - - -class TimelineGuarded(Generic[T]): - """Some form of results that are structurally available now but will not be - populated until some point in the future. - - This is used to encapsulate entities that are guarded by availability of - a timepoint. Note that we only allow a single timepoint guard in order to - simplify subsequent coordination. This will typically be the case when the - guard is derived from a queue of some form (as opposed to a gather). - """ - - __slots__ = [ - "value", - "sem", - "timeline", - ] - - def __init__(self, value: T, sem: HalSemaphore, timeline: int): - self.value = value - self.sem = sem - self.timeline = timeline - - def resolve(self, host_context: HostContext) -> asyncio.Future[T]: - """Produces an awaitable that resolves to the value once available.""" - return host_context.on_semaphore(self.sem, self.timeline, self.value) - - def __repr__(self): - return f"TimelineGuarded[{self.sem} @ {self.timeline}] = {self.value}" diff --git a/sharktank/sharktank/serving_poc/llm/__init__.py b/sharktank/sharktank/serving_poc/llm/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/sharktank/sharktank/serving_poc/llm/api/rest_server.py b/sharktank/sharktank/serving_poc/llm/api/rest_server.py deleted file mode 100644 index 67536173f..000000000 --- a/sharktank/sharktank/serving_poc/llm/api/rest_server.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Heavily adapted from the vllm api_server.py. 
- -from typing import AsyncGenerator, Optional, Sequence - -import argparse -import json - -from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse, Response, StreamingResponse -import sys -import uuid -import uvicorn - -from ...framework.logging import get_logger -from ...framework.session import DeviceSession - - -from ..service import ( - create_mock_generate_service, - GenerateService, - GenerateRequest, -) - -logger = get_logger("sharktank.serving_poc.llm.api_server") -app = FastAPI() -service: Optional[GenerateService] = None - - -def get_service() -> GenerateService: - assert service is not None, "Service was not initialized" - return service - - -@app.get("/health") -async def health() -> Response: - get_service() - return Response(status_code=200) - - -@app.post("/generate") -async def generate(request: Request) -> Response: - service = get_service() - r = await request.json() - prompt = r.pop("prompt") - stream = bool(r.pop("stream", False)) - request_id = uuid.uuid4().hex - - generate_request = GenerateRequest(request_id=request_id, prompt=prompt) - result_parts = service.handle_request(generate_request) - - if stream: - # TODO: This isn't entirely matching how others do it: we should be returning - # the full result on each update. - async def stream_contents() -> AsyncGenerator[bytes, None]: - async for part in result_parts: - response_record = json.dumps({"text": part.text}) - yield (response_record + "\0").encode() - - return StreamingResponse(stream_contents()) - - # Non-streaming just reads to the final. - async for result_part in result_parts: - if await request.is_disconnected(): - # Abort. - await service.abort(generate_request.request_id) - return Response(status_code=499) - - assert result_part is not None, "No results generated!" - return JSONResponse({"text": result_part.text}) - - -def main(clargs: Sequence[str]): - global service - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default=None) - parser.add_argument("--port", type=int, default=8000) - parser.add_argument( - "--root-path", - type=str, - default=None, - help="Root path to use for installing behind path based proxy.", - ) - parser.add_argument( - "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" - ) - parser.add_argument( - "--testing-mock-service", - action="store_true", - help="Enable the mock testing service", - ) - parser.add_argument( - "--device-uri", type=str, default="local-task", help="Device URI to serve on" - ) - - args = parser.parse_args(clargs) - - # Spin up the device machinery. - # Note that in the future, for multi-device, we will need more scaffolding for - # configuration and bringup, obviously. - device_session = DeviceSession(uri=args.device_uri) - - if args.testing_mock_service: - logger.info("Enabling mock LLM generate service") - service = create_mock_generate_service() - - app.root_path = args.root_path - uvicorn.run( - app, - host=args.host, - port=args.port, - log_level="debug", - timeout_keep_alive=args.timeout_keep_alive, - ) - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/sharktank/sharktank/serving_poc/llm/attn_block_cache.py b/sharktank/sharktank/serving_poc/llm/attn_block_cache.py deleted file mode 100644 index a2299c67e..000000000 --- a/sharktank/sharktank/serving_poc/llm/attn_block_cache.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. 
-# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Manages the block cache.""" - -from iree.runtime import ( # type: ignore - HalBufferView, - HalElementType, - BufferUsage, - MemoryType, - PyModuleInterface, - VmModule, -) - -from ..framework.logging import get_logger -from ..framework.session import DeviceSession - -from .config import human_size, CacheParams - - -logger = get_logger("sharktank.serving_poc.llm.cache") - - -class AttnBlockCacheEntry: - __slots__ = [ - "index", - "in_use", - ] - - def __init__(self, index: int): - self.index = index - self.in_use = False - - def __repr__(self): - return f"Block({self.index}, {'FREE' if not self.in_use else 'BUSY'})" - - -class AttnBlockCache: - def __init__(self, session: DeviceSession, cache_params: CacheParams): - self.session = session - self.cache_params = cache_params - self._initialize_block_cache() - - def _initialize_block_cache(self): - model_params = self.cache_params.model - # Allocate the on-device cache slab. - attn_block_count = self.cache_params.device_block_count - attn_block_size_elements = self.cache_params.attn_block_size_elements - attn_block_size_bytes = attn_block_size_elements * model_params.attn_dtype_size - attn_cache_size_bytes = attn_block_count * attn_block_size_bytes - - logger.info("Setting up cache for\n %r", self.cache_params) - logger.info( - "Allocating attention static cache on device of %s " - "(blocks=%s, block_size=%s bytes)", - human_size(attn_cache_size_bytes), - attn_block_count, - attn_block_size_bytes, - ) - self.attn_block_buffer = self.session.device.allocator.allocate_buffer( - memory_type=MemoryType.DEVICE_LOCAL, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=attn_cache_size_bytes, - ) - - # Attn block logical view. - self.attn_block_buffer_view = HalBufferView( - self.attn_block_buffer, - [ - attn_block_count, - attn_block_size_elements, - ], - model_params.attn_dtype, - ) - - # Accounting structs. - self.attn_block_entries = [ - AttnBlockCacheEntry(i) for i in range(attn_block_count) - ] - self.attn_block_free = list(self.attn_block_entries) - - async def acquire_attn_blocks( - self, count: int, into_list: list[AttnBlockCacheEntry] - ): - """Acquires 'count' attention blocks. - - If there are insufficient free blocks, raises an exception. - """ - free_list = self.attn_block_free - assert ( - len(free_list) >= count - ), f"Cache does not contain requested {count} free attn blocks" - for i in range(count): - into_list.append(free_list.pop()) - - async def release_attn_blocks(self, blocks: list[AttnBlockCacheEntry]): - """Releases a list of attention blocks. - - If at all possible, this should be batched to include all blocks that need to - be released at a given time since this will trigger heavy-weight scheduling - that will work better with a view of the new free list as a whole. - """ - free_list = self.attn_block_free - for block in blocks: - free_list.append(block) - - -def create_attn_block_cache_module(attn_block_cache: AttnBlockCache) -> VmModule: - """Creates a VM module that exports the attention block cache. - - For in-system use, we use a dynamic module that can provide the block cache - slab. In other uses, this may be provided by a statically compiled module - that does the same. - - Interface: - Module name: attn_block_cache - Exports: - func @attn_block_cache.get_shared_slab() -> (!hal.buffer_view) - """ - - class Module: - def __init__(self, iface): - ... 
- - def get_shared_slab(self): - return attn_block_cache.attn_block_buffer_view.ref - - iface = PyModuleInterface(module_name="attn_block_cache", ctor=Module) - iface.export("get_shared_slab", "0v_r", Module.get_shared_slab) - return iface.create() diff --git a/sharktank/sharktank/serving_poc/llm/config.py b/sharktank/sharktank/serving_poc/llm/config.py deleted file mode 100644 index df5db5f8f..000000000 --- a/sharktank/sharktank/serving_poc/llm/config.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Configuration objects. - -Parameters that are intrinsic to a specific model. - -In a typical transformer model, the KV cache is organized similar to (mapped to -our parameter names below): - k = tensor.empty(transformer_block_count, batch_size, seq, - attn_head_count, attn_head_dim) - v = ... - -For context, a popular model has parameters of: - attn_dtype_size = 2 # (fp16) - max_seq_len = 2048 - transformer_block_count = 32 - attn_head_count = 32 - attn_head_dim = 128 # (dim / head_count) - -If paging, then we primary care about the organization of a single block, where -a block represents a single position in the sequence for a single item in the batch. -Therefore, it will be organized like: - block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim) - -In this scenario, we declare that one block holds the KV cache for all transformer -block layers because it reduces the accounting. As such, for the above example, -a single position in the sequence will be 524,288 bytes, assuming a 2-byte element -type. If we choose to block by block_stride=16 positions, each block will be 8MiB. -Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536 -blocks for a total number of sequence positions of 24,576. - -These are well-known numbers but are derived above to give a sense of scale. - -In order to indirect through to the block cache, we have to provide the index map -to specific invocations: - -* Prefill: Prefill is only writing to the blocks from [0:prompt_len], so it will - need write indices of [batch_size, prompt_len // block_stride + 1]. -* Decode step: Decode is auto-regressive, and needs to first compute the new kv - row and then attend over all rows in the cache up to this point in the sequence. - -If wanting to avoid dynamic allocation of transients, we can also pool the index -tables based on the maximum batch size and maximum sequence length. Since all -block cache sizes are well within the range of an i16, we will use that for storage. -Therefore, each batch invocation would need a block lookup table of: - - byte_size = max_batch_size * (max_seq_len // block_stride) * sizeof(int16_t) - -For a max_batch_size of 16, this is 4KiB of block index table lookups per -invocation. We don't have to statically allocate this, but the system is more -predictable if we just reserve what we need. Again, numbers are given to give a -sense of scale only: real workloads will vary. -""" - -from dataclasses import dataclass - -from iree.runtime import ( # type: ignore - HalElementType, -) - -import json - - -@dataclass -class ModelParams: - """Parameters for a specific compiled model, sufficient to do cache planning and - invocations.""" - - # The element type of the attention caches. 
- attn_dtype: HalElementType - - # Maximum length of a sequence including prompt and output. - max_seq_len: int - - # Number of transformer blocks. - transformer_block_count: int - - # Number of attention heads per block. - attn_head_count: int - - # Dimensionality of each attention head - attn_head_dim: int - - # Position stride per attention block - block_seq_stride: int - - # Batch sizes that the prefill stage is compiled for. These are expected to be - # functions exported from the model with suffixes of "_bs{batch_size}". Must - # be in ascending order. - prefill_batch_sizes: list[int] - - # Similarly, batch sizes that the decode stage is compiled for. - decode_batch_sizes: list[int] - - # Name of the IREE module implementing the model. - module_name: str = "module" - - # ABI of the module. - module_abi_version: int = 1 - - # Size in bytes of the KV cache dtype. - @property - def attn_dtype_size(self) -> int: - assert HalElementType.is_byte_aligned(self.attn_dtype) - return HalElementType.dense_byte_count(self.attn_dtype) - - @property - def max_prefill_batch_size(self) -> int: - return self.prefill_batch_sizes[-1] - - @property - def max_decode_batch_size(self) -> int: - return self.decode_batch_sizes[-1] - - @property - def max_batch_size(self): - return max(self.max_prefill_batch_size, self.max_decode_batch_size) - - @staticmethod - def load_json(path): - f = open(path) - j = json.load(f) - return ModelParams(attn_dtype=HalElementType.FLOAT_16, **j) - - -@dataclass -class CacheParams: - """Parameters for management of the block cache. - - This is paired with a ModelParams. - - We presently use a static block cache configuration and hand-wave either a tuning - run or pen/paper analysis to derive the parameters. - """ - - model: ModelParams - - # The size of the static block cache on the device. - device_block_count: int - - # The stride of each block in sequence positions. - block_pos_stride: int - - @property - def attn_unit_size_elements(self) -> int: - """Size in bytes of each cache line in the attention cache. - - Each cache line can store a unit position stride. - """ - size = 1 - size *= self.model.transformer_block_count - size *= 2 # K and V cache line - size *= self.model.attn_head_count - size *= self.model.attn_head_dim - return size - - @property - def attn_block_size_elements(self) -> int: - """Size in bytes of each attention block of {block_position_stride} positions.""" - return self.attn_unit_size_elements * self.block_pos_stride - - -@dataclass -class ServiceParams: - """Parameters for the serving service.""" - - cache: CacheParams - model: ModelParams - - -# From: https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size -def human_size(num, suffix="B"): - for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): - if abs(num) < 1024.0: - return f"{num:3.1f}{unit}{suffix}" - num /= 1024.0 - return f"{num:.1f}Yi{suffix}" diff --git a/sharktank/sharktank/serving_poc/llm/impl/service_v1.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1.py deleted file mode 100644 index 8ae0be637..000000000 --- a/sharktank/sharktank/serving_poc/llm/impl/service_v1.py +++ /dev/null @@ -1,495 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Implements the BatchGenerateService for V1 compiled models. 
- -This is far from where we want to land but is intended for first round bootstrapping. -Perhaps the biggest issue is that it wouldn't mate well as-is with samplers. -""" - -import asyncio -from dataclasses import dataclass - -import numpy as np - -from iree.runtime import ( # type: ignore - HalBufferView, - HalCommandBuffer, - HalElementType, - HalFence, - VmFunction, - VmVariantList, -) - -from ...framework.logging import get_logger, NDEBUG -from ...framework.session import ( - AsyncResources, - DeviceSession, - TimelineGuarded, - TransferBufferPool, - WorkQueue, -) - -from ..attn_block_cache import AttnBlockCacheEntry, AttnBlockCache -from ..config import ServiceParams -from ..service import ( - BatchGenerateService, - BatchGenerateState, - GenerateRequest, -) - - -logger = get_logger("sharktank.serving_poc.llm.impl.service_v1") - -EXPECTED_CONCURRENCY = 10 - - -class GenerateServiceV1(BatchGenerateService): - def __init__( - self, *, session: DeviceSession, params: ServiceParams, cache: AttnBlockCache - ): - self.params = params - self.block_pos_stride = params.cache.block_pos_stride - self.batch_sizes = params.model.prefill_batch_sizes - # TODO: Remove distinction between prefill and decode batch sizes. - assert params.model.decode_batch_sizes == self.batch_sizes - self.session = session - self.cache = cache - module_name = params.model.module_name - logger.info("Configuring serving for module set %s", module_name) - self.module_set = session.module_set(params.model.module_name) - - # Initialize prefill entry-points (1 per batch size). - self.prefill_functions: dict[int, VmFunction] = {} - for bs in self.batch_sizes: - assert bs not in self.prefill_functions - symbol_name = f"prefill_bs{bs}" - logger.info("Looking up symbol '%s'", symbol_name) - self.prefill_functions[bs] = self.module_set.function( - module_name, symbol_name - ) - - # Initialize decode entry-points (1 per batch size). - self.decode_functions: dict[int, VmFunction] = {} - for bs in self.batch_sizes: - assert bs not in self.decode_functions - symbol_name = f"decode_bs{bs}" - logger.info("Looking up symbol '%s'", symbol_name) - self.decode_functions[bs] = self.module_set.function( - module_name, symbol_name - ) - - self._initialize_transfer_pools() - - def _initialize_transfer_pools(self): - params = self.params - max_bs = params.model.max_batch_size - max_sl = params.model.max_seq_len - initial_inflight = EXPECTED_CONCURRENCY - - # block_indices_pool: array([max_batch_size, max_attn_blocks], np.int64) - # Suitable to handle the sequence->block mapping for all steps. - self.block_indices_pool = TransferBufferPool.shaped( - self.session, - [ - max_bs, - max_sl // self.block_pos_stride, - ], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="block_cache_indices", - ) - - # Prefill tokens: array([max_batch_size, max_seq_len], np.int64) - # Tokens inputs to prefill. - self.prefill_tokens_pool = TransferBufferPool.shaped( - self.session, - [ - max_bs, - max_sl, - ], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="prefill_tokens", - ) - - # Prefill sequence lengths: array([max_batch_size], np.int64) - # Sequence lengths of input tokens. - self.prefill_seq_lens_pool = TransferBufferPool.shaped( - self.session, - [max_bs], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="prefill_seq_lens", - ) - - # Decode tokens: array([max_batch_size], np.int64) - # Tokens to perform a decode step with. 
- self.decode_tokens_pool = TransferBufferPool.shaped( - self.session, - [max_bs, 1], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="decode_tokens", - ) - - # Decode seq lengths: array([max_batch_size], np.int64) - # Decoder seq length for this step - self.decode_seq_lens_pool = TransferBufferPool.shaped( - self.session, - [max_bs], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="decode_seq_len", - ) - - # Decode start positions: array([max_batch_size], np.int64) - # Tokens to perform a decode step with. - self.decode_start_pos_pool = TransferBufferPool.shaped( - self.session, - [max_bs], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="decode_start_pos", - ) - - def start(self) -> "GenerateState": - return GenerateState(self) - - def shutdown(self): - self.session.shutdown() - - -class _Sequence: - __slots__ = [ - "attn_blocks", - "attn_blocks_needed", - "current_token_ids", - "decode_token_ids", - "request", - "seq_length", - ] - - current_token_ids: list[int] - decode_token_ids: list[int] - - def __init__(self, request: GenerateRequest): - self.request = request - self.seq_length: int = 0 - self.attn_blocks: list[AttnBlockCacheEntry] = [] - self.attn_blocks_needed: int = 0 - self.decode_token_ids = [] - self.current_token_ids = [] - - def attn_blocks_available(self): - return len(self.attn_blocks) - - def resize_attention(self, new_size): - old_size = self.attn_blocks_needed - self.attn_blocks_needed = new_size - return new_size - old_size - - -class GenerateState(BatchGenerateState): - __slots__ = [ - "_bs", - "_decode_function", - "_prefill_function", - "_max_attn_blocks_length", - "_max_seq_length", - "_resources", - "_service", - "_sequences", - "_batch_queue", - ] - - def __init__(self, service: GenerateServiceV1): - super().__init__(service.module_set.host_context) - self._resources = AsyncResources() - self._service = service - self._sequences: list[_Sequence] = [] - self._batch_queue = WorkQueue(service.session) - - async def recycle(self): - """Recycles or releases all resources consumed by this instance.""" - cache = self._service.cache - self._batch_queue.sync(self.host_context) - self._resources.recycle() - all_blocks = [] - for seq in self._sequences: - all_blocks.extend(seq.attn_blocks) - seq.attn_blocks.clear() - self._sequences = [] - await cache.release_attn_blocks(all_blocks) - - async def set_sequences(self, requests: list[GenerateRequest]): - """Initiates processing of a list of sequences that make up a batch. - - This is async because it acquires resources which may not be available. - """ - service = self._service - block_pos_stride = service.block_pos_stride - - # Loop through each request and reserve initial attention blocks. - bs = 0 - sequences = self._sequences - assert not sequences, "set_sequences already called" - max_attn_blocks_length = 0 - max_seq_length = 0 - attn_blocks_required = 0 - - for req in requests: - bs += 1 - seq = _Sequence(req) - sequences.append(seq) - seq.current_token_ids = req.required_prompt_token_ids - seq_length = len(seq.current_token_ids) - seq.seq_length = seq_length - max_seq_length = max(max_seq_length, seq_length) - initial_block_count = seq_length // block_pos_stride + 1 - attn_blocks_required += initial_block_count - seq.attn_blocks_needed = initial_block_count - max_attn_blocks_length = max(max_attn_blocks_length, initial_block_count) - - # Determine the appropriate batched entrypoints. 
- assert bs > 0 - for allowed_bs in service.batch_sizes: - if allowed_bs >= bs: - self._prefill_function = service.prefill_functions[allowed_bs] - self._decode_function = service.decode_functions[allowed_bs] - break - else: - raise AssertionError(f"Unsupported batch size: {bs}") - - # Acquire the needed attention blocks in one batch so as to give the scheduler - # the most visibility into the need. - logger.debug("Acquire prefill attn blocks: %s", attn_blocks_required) - all_attn_blocks: list[AttnBlockCacheEntry] = [] - await service.cache.acquire_attn_blocks(attn_blocks_required, all_attn_blocks) - block_index = 0 - for seq in sequences: - next_block_count = seq.attn_blocks_needed - seq.attn_blocks.extend( - all_attn_blocks[block_index : block_index + seq.attn_blocks_needed] - ) - block_index += next_block_count - - # Save state. - self._bs = allowed_bs - self._max_attn_blocks_length = max_attn_blocks_length - self._max_seq_length = max_seq_length - - async def prefill(self) -> TimelineGuarded[HalBufferView]: - hc = self.host_context - service = self._service - resources = self._resources - bs = self._bs - service = self._service - block_pos_stride = service.block_pos_stride - max_attn_blocks_length = self._max_attn_blocks_length - max_seq_length = max_attn_blocks_length * block_pos_stride - sequences = self._sequences - work_queue = self._batch_queue - - # Record a command buffer for performing h2d transfers. - cb = HalCommandBuffer(hc.session.device) - - # Prepare input tokens, sequence lengths and block indices. - # We acquire a transfer buffer of each from the respective pool, populate its - # host side and enqueue. - # prefill_tokens: array([bs, max_seq_length], np.int32) - prefill_tokens_host, prefill_tokens_device = resources.acquire_transfer_buffer( - service.prefill_tokens_pool - ).h2d_array(cb, [bs, max_seq_length], HalElementType.SINT_64, fill_value=0) - - # prefill_seq_lens: array([bs], np.int32) - ( - prefill_seq_lens_host, - prefill_seq_lens_device, - ) = resources.acquire_transfer_buffer(service.prefill_seq_lens_pool).h2d_array( - cb, [bs], HalElementType.SINT_64, fill_value=0 - ) - - # attn_block_indices: array([bs, max_attn_blocks], np.in16) - ( - prefill_attn_block_indices_host, - prefill_attn_block_indices_device, - ) = resources.acquire_transfer_buffer(service.block_indices_pool).h2d_array( - cb, [bs, max_attn_blocks_length], HalElementType.SINT_64, fill_value=0 - ) - - # Populate host buffers for each sequence. - for i in range(len(sequences)): - seq = sequences[i] - attn_blocks = seq.attn_blocks - current_token_ids = seq.current_token_ids - row_seq_len = len(current_token_ids) - prefill_tokens_host[i, 0:row_seq_len] = current_token_ids - prefill_seq_lens_host[i] = row_seq_len - for j in range(len(seq.attn_blocks)): - prefill_attn_block_indices_host[i, j] = attn_blocks[j].index - - # Perform h2d transfers. - cb.end() - work_queue.execute_sequential(cb) - - # Inputs: - # token_ids - # seq_lens - # attn_block_indices - # attn_block_buffer_view (the entire slab passed as input) - # wait, signal semaphores - # tied attn_block_buffer (for input[2]) - # tied attn_block_buffer (for result[0]) - inputs = VmVariantList(3) - inputs.push_ref(prefill_tokens_device) - inputs.push_ref(prefill_seq_lens_device) - inputs.push_ref(prefill_attn_block_indices_device) - inputs.push_ref(service.cache.attn_block_buffer_view) - - # Outputs: - # attn_block_buffer_view (tied output) - # decode_tokens - outputs = VmVariantList(1) - # TODO: Async invoke. 
- hc.vm_context.invoke(self._prefill_function, inputs, outputs) - return work_queue.guard(outputs.get_as_ref(0).deref(HalBufferView)) - - async def set_decode_step(self, tokens): - """Initiates processing of a list of tokens to decode across each batch - - This is async because it acquires resources which may not be available. - """ - service = self._service - block_pos_stride = service.block_pos_stride - - sequences = self._sequences - assert sequences, "set_sequences was not called yet" - assert len(sequences) == len(tokens), "expected token for each sequence" - - max_attn_blocks_length = 0 - max_seq_length = 0 - attn_blocks_required = 0 - - for tok, seq in zip(tokens, self._sequences): - seq.decode_token_ids.append(tok) - seq.seq_length = seq.seq_length + 1 - - max_seq_length = max(max_seq_length, seq.seq_length) - block_count = seq.seq_length // block_pos_stride + 1 - - seq.attn_blocks_needed = block_count - attn_blocks_required += block_count - seq.attn_blocks_available() - max_attn_blocks_length = max(max_attn_blocks_length, block_count) - - # Acquire the needed attention blocks in one batch so as to give the scheduler - # the most visibility into the need. - logger.debug("Acquire decode attn blocks: %s", attn_blocks_required) - all_attn_blocks: list[AttnBlockCacheEntry] = [] - await service.cache.acquire_attn_blocks(attn_blocks_required, all_attn_blocks) - block_index = 0 - for seq in sequences: - next_block_count = seq.attn_blocks_needed - seq.attn_blocks_available() - seq.attn_blocks.extend( - all_attn_blocks[block_index : block_index + next_block_count] - ) - block_index += next_block_count - - # Save state. - self._max_attn_blocks_length = max_attn_blocks_length - self._max_seq_length = max_seq_length - - async def decode(self) -> TimelineGuarded[HalBufferView]: - hc = self.host_context - service = self._service - resources = self._resources - bs = self._bs - max_attn_blocks_length = self._max_attn_blocks_length - sequences = self._sequences - work_queue = self._batch_queue - - # Record a command buffer for performing h2d transfers. - cb = HalCommandBuffer(hc.session.device) - - # decode_tokens: array([bs, 1], np.int32) - (decode_tokens_host, decode_tokens_device,) = resources.acquire_transfer_buffer( - service.decode_tokens_pool - ).h2d_array(cb, [bs, 1], HalElementType.SINT_64, fill_value=0) - - # decode_seq_lens: array([bs], np.int32) - ( - decode_seq_lens_host, - decode_seq_lens_device, - ) = resources.acquire_transfer_buffer(service.decode_seq_lens_pool).h2d_array( - cb, [bs], HalElementType.SINT_64, fill_value=0 - ) - - # decode_start_pos: array([bs], np.int32) - ( - decode_start_pos_host, - decode_start_pos_device, - ) = resources.acquire_transfer_buffer(service.decode_start_pos_pool).h2d_array( - cb, [bs], HalElementType.SINT_64, fill_value=0 - ) - - # attn_block_indices: array([bs, max_attn_blocks], np.in16) - ( - decode_attn_block_indices_host, - decode_attn_block_indices_device, - ) = resources.acquire_transfer_buffer(service.block_indices_pool).h2d_array( - cb, [bs, max_attn_blocks_length], HalElementType.SINT_64, fill_value=0 - ) - - # Populate host buffers for each sequence. 
- for i in range(len(sequences)): - seq = sequences[i] - attn_blocks = seq.attn_blocks - - tok = seq.decode_token_ids[0] - seq_len = len(seq.current_token_ids) - print(f"seq.current_token_ids: {seq.current_token_ids}") - seq.current_token_ids.append(tok) - seq.decode_token_ids = seq.decode_token_ids[1:] - - decode_tokens_host[i, 0] = tok - decode_start_pos_host[i] = seq_len - decode_seq_lens_host[i] = seq_len - for j in range(len(seq.attn_blocks)): - decode_attn_block_indices_host[i, j] = attn_blocks[j].index - - # Perform h2d transfers. - cb.end() - work_queue.execute_sequential(cb) - - # Inputs: - # token_ids - # seq_lens - # start_pos - # attn_block_indices - # attn_block_buffer_view (the entire slab passed as input) - # wait, signal semaphores - # tied attn_block_buffer (for input[4]) - # tied attn_block_buffer (for result[0]) - inputs = VmVariantList(5) - inputs.push_ref(decode_tokens_device) - inputs.push_ref(decode_seq_lens_device) - inputs.push_ref(decode_start_pos_device) - inputs.push_ref(decode_attn_block_indices_device) - inputs.push_ref(service.cache.attn_block_buffer_view) - - # Outputs: - # attn_block_buffer_view (tied output) - # decode_tokens - outputs = VmVariantList(1) - # TODO: Async invoke. - hc.vm_context.invoke(self._decode_function, inputs, outputs) - return work_queue.guard(outputs.get_as_ref(0).deref(HalBufferView)) diff --git a/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py deleted file mode 100644 index 7895341c9..000000000 --- a/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import asyncio -import argparse -import numpy -import sys - -from transformers import LlamaTokenizer # type: ignore - -from iree.runtime import ( # type: ignore - HalElementType, -) - -from sharktank.serving_poc.framework.session import DeviceSession - -from sharktank.serving_poc.llm.attn_block_cache import ( - create_attn_block_cache_module, - AttnBlockCache, -) - -from sharktank.serving_poc.llm.config import ( - CacheParams, - ModelParams, - ServiceParams, -) - -from sharktank.serving_poc.llm.impl.service_v1 import GenerateServiceV1 -from sharktank.serving_poc.llm.service import GenerateRequest - - -def setup(vmfb_path, config_path, gguf_path): - from iree.runtime._binding import disable_leak_checker # type: ignore - - model_params = ModelParams.load_json(config_path) - - device_block_count = model_params.max_seq_len // model_params.block_seq_stride - cache_params = CacheParams( - model=model_params, - device_block_count=device_block_count, - block_pos_stride=model_params.block_seq_stride, - ) - - disable_leak_checker() - session = DeviceSession(uri="local-sync", queue_count=2) - attn_block_cache = AttnBlockCache(session, cache_params) - - lms = session.create_module_set(model_params.module_name, context_count=1) - lms.load_io_module(gguf_path) - lms.load_vmfb(vmfb_path) - lms.add(create_attn_block_cache_module(attn_block_cache)) - lms.initialize() - - params = ServiceParams(cache=cache_params, model=model_params) - service = GenerateServiceV1(session=session, params=params, cache=attn_block_cache) - return service - - -def map_buffer(value): - mapped = value.map() - return mapped.asarray(value.shape, HalElementType.map_to_dtype(value.element_type)) - - -async def main(argv): - parser = argparse.ArgumentParser() - parser.add_argument("--tokenizer", help="name of hugginface tokenizer to use") - parser.add_argument("--config", help="json config file with hyperparameters") - parser.add_argument("--vmfb", help="vmfb with compiler LLM kernels") - parser.add_argument("--gguf", help="gguf file containing modle coefficients") - parsed = parser.parse_args(argv) - - hf_path = parsed.tokenizer - config_path = parsed.config - vmfb_path = parsed.vmfb - gguf_path = parsed.gguf - - service = setup(vmfb_path, config_path, gguf_path) - tokenizer = LlamaTokenizer.from_pretrained(hf_path) - state = service.start() - - for line in ["one two three four five six seven eight"]: - prompt = line.strip() - if not prompt: - break - - input_ids = tokenizer.encode(prompt, return_tensors="pt")[0].tolist() - print(input_ids) - request = GenerateRequest("request_id", prompt, input_ids) - await state.set_sequences([request]) - logits = await state.prefill() - - seq_len = len(input_ids) - mapped_logits = map_buffer(logits.value) - predicted_tokens = numpy.argmax(mapped_logits[0, :seq_len], axis=-1) - predicted_token = predicted_tokens[-1] - decoded_token = tokenizer.decode(predicted_token) - print(f"Prefill predicted token: {predicted_token}, decoded: '{decoded_token}'") - - # TODO(scotttodd): sanity check tokenizer use, document inputs/outputs - # 'prefill' is for initialization with multiple steps at once - # 'decode' is for hypothesis exploration, one step at a time - await state.set_decode_step([predicted_token]) - logits = await state.decode() - mapped_logits = map_buffer(logits.value) - predicted_tokens = numpy.argmax(mapped_logits, axis=-1) - predicted_token = predicted_tokens[0] - decoded_token = tokenizer.decode(predicted_token) - print(f"Decode 
predicted token: {predicted_token}, decoded: '{decoded_token}'") - await state.recycle() - - service.shutdown() - - -if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) diff --git a/sharktank/sharktank/serving_poc/llm/service.py b/sharktank/sharktank/serving_poc/llm/service.py deleted file mode 100644 index c5d4ffb44..000000000 --- a/sharktank/sharktank/serving_poc/llm/service.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import AsyncIterator, Callable, Optional - -from abc import abstractmethod, ABC -import asyncio -from dataclasses import dataclass - -from ..framework.session import ( - HostContext, -) - -######################################################################################## -# User-level single request service -######################################################################################## - - -@dataclass -class GenerateRequest: - """Encapsulates a request to perform LLM generation. - - Requests are bootstrapped from user values and then pumped through the pipeline, - receiving additional elaboration needed to actually begin generation. - """ - - # Client set fields - request_id: str - prompt: str - - # Fields that are set as the request is processed. - prompt_token_ids: Optional[list[int]] = None - - @property - def required_prompt_token_ids(self) -> list[int]: - ids = self.prompt_token_ids - assert ids is not None - return ids - - -@dataclass -class GenerateResponsePart: - """A response part from an LLM generation request.""" - - request: GenerateRequest - index: int - token_ids: list[int] - - # Fields that can be set as the response is post-processed. - text: Optional[str] = None - finished: bool = False - - -class GenerateService(ABC): - """Asynchronous generator service which processes requests into response parts.""" - - @abstractmethod - def handle_request( - self, - request: GenerateRequest, - ) -> AsyncIterator[GenerateResponsePart]: - """Generates response parts for a request.""" - ... - - @abstractmethod - async def abort(self, request_id: str) -> None: - """Aborts a submitted request.""" - ... - - -######################################################################################## -# Batch generation service -# This service is completely asynchronous and operates on a BatchGenerateRequest as -# a state machine. It is expected to have an external actor stepping it through -# states. -######################################################################################## - - -class BatchGenerateService(ABC): - """Handles generation of a batch of requests.""" - - __slots__ = [] # type: ignore - - # def start_prefill(self, request: BatchGenerateRequest): - # ... - @abstractmethod - def start(self) -> "BatchGenerateState": - ... 
- - -class BatchGenerateState(ABC): - """In-progress batch generation state.""" - - __slots__ = [ - "host_context", - ] - - def __init__(self, host_context: HostContext): - self.host_context = host_context - - -######################################################################################## -# Utilities -######################################################################################## - - -class SyncGenerateFilter(GenerateService): - """GenerateService filter which can synchronously pre/post process.""" - - __slots__ = ["_next"] - - def __init__(self, next: GenerateService): - self._next = next - - def filter_request(self, request: GenerateRequest): - ... - - def filter_response(self, part: GenerateResponsePart): - ... - - async def handle_request( - self, - request: GenerateRequest, - ) -> AsyncIterator[GenerateResponsePart]: - self.filter_request(request) - async for part in self._next.handle_request(request): - self.filter_response(part) - yield part - - async def abort(self, request_id: str) -> None: - """Aborts a submitted request.""" - await self._next.abort(request_id) - - -######################################################################################## -# Testing and mock types -######################################################################################## - - -def create_mock_generate_service() -> GenerateService: - return DummyTokenizerService(EchoGenerateService()) - - -class DummyTokenizerService(SyncGenerateFilter): - """Tokenizer service which will map to code points. - - Useful for testing. - """ - - def filter_request(self, request: GenerateRequest): - if request.prompt_token_ids is None: - request.prompt_token_ids = [ord(c) for c in request.prompt] - - def filter_response(self, part: GenerateResponsePart): - if part.text is None: - part.text = "".join([chr(x) for x in part.token_ids]) - - -class EchoGenerateService(GenerateService): - """Dummy implementation of a generate service. - - It just echoes back the request five times after a delay. - """ - - def __init__(self, delay: float = 0.1): - self._delay = delay - - async def handle_request( - self, - request: GenerateRequest, - ) -> AsyncIterator[GenerateResponsePart]: - next = None - for i in range(5): - if next: - yield next - assert request.prompt_token_ids, "Request lacks prompt tokens" - next = GenerateResponsePart( - request, i, request.prompt_token_ids, finished=False - ) - await asyncio.sleep(self._delay) - if next: - next.finished = True - yield next - - async def abort(self, request_id: str) -> None: - pass diff --git a/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py b/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py deleted file mode 100644 index a36ebe667..000000000 --- a/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Implements a service_v1 compliant module in Python for testing. - -This uses a PyModuleInterface to define a fake VmModule that exposes 'prefill_bs{n}' -and 'decode_bs{n}' such that the call sequence and args/results can be manipulated. 
-""" - -import numpy as np -import textwrap -import threading - -from iree.runtime import ( # type: ignore - BufferUsage, - HalBuffer, - HalBufferView, - HalDevice, - HalElementType, - HalFence, - MemoryType, - PyModuleInterface, - VmModule, - VmRef, -) - -from ..config import ModelParams - - -def create_fake_module( - device: HalDevice, module_name: str, model_params: ModelParams -) -> VmModule: - class ServiceV1Module: - def __init__(self, iface): - ... - print("IFACE:", iface, dir(iface)) - - def prefill( - self, - bs: int, - token_ids_ref: VmRef, - seq_lens_ref: VmRef, - attn_block_indices_ref: VmRef, - attn_block_buffer_view: VmRef, - ): - result_array: np.ndarray = np.ndarray([bs, 1], dtype=np.int32) - - def run(): - print(f"FAKE_V1_MODULE: PREFILL bs={bs} : WAIT") - print(" - READY") - _format_device_buffer_view( - lambda s: print(" token_ids =", s), token_ids_ref - ) - _format_device_buffer_view( - lambda s: print(" seq_lens =", s), seq_lens_ref - ) - _format_device_buffer_view( - lambda s: print(" attn_block_indices =", s), - attn_block_indices_ref, - ) - _format_device_buffer_view( - lambda s: print(" attn_block_buffer_view =", s), - attn_block_buffer_view, - ) - - # Async populate. - device_array = result_bv.map().asarray( - result_array.shape, result_array.dtype - ) - for i in range(bs): - device_array[i, 0] = i + 1 - - threading.Thread(target=run).start() - - result_buffer = device.allocator.allocate_buffer( - memory_type=MemoryType.DEVICE_LOCAL | MemoryType.HOST_VISIBLE, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=result_array.size * result_array.itemsize, - ) - result_bv = HalBufferView( - result_buffer, result_array.shape, HalElementType.INT_32 - ) - return result_bv.ref - - def decode(self, bs: int): - print(f"FAKE_V1_MODULE: DECODE bs={bs}") - - iface = PyModuleInterface(module_name=module_name, ctor=ServiceV1Module) - - # Dynamically define prefill functions. - def add_prefill_bs(bs: int): - def trampoline(self, *args): - return self.prefill(bs, *args) - - iface.export(f"prefill_bs{bs}", "0rrrr_r", trampoline) - - [add_prefill_bs(bs) for bs in model_params.prefill_batch_sizes] - - # Dynamically define decode functions. - def add_decode_bs(bs: int): - def trampoline(self, *args): - return self.decode(bs, *args) - - iface.export(f"decode_bs{bs}", "0v_v", trampoline) - - [add_decode_bs(bs) for bs in model_params.decode_batch_sizes] - - return iface.create() - - -def _format_device_buffer_view(callback, bv_ref: VmRef): - bv = bv_ref.deref(HalBufferView) # type: HalBufferView - value = bv.map().asarray(bv.shape, HalElementType.map_to_dtype(bv.element_type)) - value_indented = textwrap.indent(repr(value), " ") - callback(f"{bv!r}\n{value_indented}") diff --git a/sharktank/sharktank/serving_poc/py.typed b/sharktank/sharktank/serving_poc/py.typed deleted file mode 100644 index 5e43cc13b..000000000 --- a/sharktank/sharktank/serving_poc/py.typed +++ /dev/null @@ -1 +0,0 @@ -# Marker file for PEP 561 inline type checking. diff --git a/sharktank/tests/serving_poc/framework/device_session_test.py b/sharktank/tests/serving_poc/framework/device_session_test.py deleted file mode 100644 index 5dfdd5f46..000000000 --- a/sharktank/tests/serving_poc/framework/device_session_test.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import pytest - -from sharktank.serving_poc.framework.session import ( - DeviceSession, -) - - -@pytest.fixture -def local_device_session(): - session = DeviceSession(uri="local-task") - yield session - session.shutdown() - - -def test_start_shutdown_no_host_contexts(local_device_session: DeviceSession): - ms = local_device_session.create_module_set("default") - ms.initialize() - - -def test_host_context_start_stop(local_device_session: DeviceSession): - ms = local_device_session.create_module_set("default") - ms.initialize() - hc = ms.host_context - - -def test_host_context_scheduling(local_device_session: DeviceSession): - device = local_device_session.device - ms = local_device_session.create_module_set("default") - ms.initialize() - hc = ms.host_context - - sem = device.create_semaphore(0) - - async def task1(): - print("[coro1] test_host_context_scheduling.task") - await hc.on_semaphore(sem, 1, True) - print("[coro1] await completed") - sem.signal(2) - - async def task2(): - print("[coro2] waiting for 2") - await hc.on_semaphore(sem, 2, True) - sem.fail("Fail from task2") - - f1 = hc.run_concurrent(task1()) - f2 = hc.run_concurrent(task2()) - sem.signal(1) - print("[main] Waiting for semaphore") - - # Ensure task completion. Important to consume to ensure that exceptions - # propagate. - f1.result() - f2.result() - - print("[main] Waiting on semaphore payload 3") - with pytest.raises(Exception, match="Fail from task2"): - sem.wait(3) diff --git a/sharktank/tests/serving_poc/llm/api_server_test.py b/sharktank/tests/serving_poc/llm/api_server_test.py deleted file mode 100644 index c2d2cc36a..000000000 --- a/sharktank/tests/serving_poc/llm/api_server_test.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import os -from contextlib import closing -from pathlib import Path -import pytest -import requests -import socket -import subprocess -import sys -import time - - -def find_free_port(): - """This tries to find a free port to run a server on for the test. - - Race conditions are possible - the port can be acquired between when this - runs and when the server starts. - - https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number - """ - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(("localhost", 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return s.getsockname()[1] - - -class ServerRunner: - def __init__(self, args): - port = str(find_free_port()) - self.url = "http://localhost:" + port - env = os.environ.copy() - env["PYTHONUNBUFFERED"] = "1" - self.process = subprocess.Popen( - [ - sys.executable, - "-m", - "sharktank.serving_poc.llm.api.rest_server", - "--testing-mock-service", - "--port=" + port, - ] - + args, - env=env, - # TODO: Have a more robust way of forking a subprocess. 
- cwd=str(Path(__file__).resolve().parent.parent.parent), - stdout=sys.stdout, - stderr=sys.stderr, - ) - self._wait_for_ready() - - def _wait_for_ready(self): - start = time.time() - while True: - try: - if requests.get(f"{self.url}/health").status_code == 200: - return - except Exception as e: - if self.process.poll() is not None: - raise RuntimeError("API server processs terminated") from e - time.sleep(1.0) - if (time.time() - start) > 30: - raise RuntimeError("Timeout waiting for server start") - - def __del__(self): - try: - process = self.process - except AttributeError: - pass - else: - process.terminate() - process.wait() - - -@pytest.fixture(scope="session") -def server(): - try: - import fastapi - import uvicorn - except ModuleNotFoundError as e: - pytest.skip(f"Skipping server test because deps are missing: {e}") - runner = ServerRunner([]) - yield runner - - -def test_health(server: ServerRunner): - # Health check is part of getting the fixture. - ... - - -def test_generate_non_streaming(server: ServerRunner): - resp = requests.post( - f"{server.url}/generate", - json={ - "prompt": "Hi Bob", - }, - ) - resp.raise_for_status() - d = resp.json() - assert d["text"] == "Hi Bob", repr(d) - - -def test_generate_streaming(server: ServerRunner): - resp = requests.post( - f"{server.url}/generate", json={"prompt": "Hi Bob!", "stream": True} - ) - resp.raise_for_status() - full_contents = resp.content - expected_contents = b'{"text": "Hi Bob!"}\x00' * 5 - assert ( - full_contents == expected_contents - ), f"Expected {expected_contents!r} vs {full_contents!r}" diff --git a/sharktank/tests/serving_poc/llm/service_v1_test.py b/sharktank/tests/serving_poc/llm/service_v1_test.py deleted file mode 100644 index c010e2034..000000000 --- a/sharktank/tests/serving_poc/llm/service_v1_test.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import pytest - -from iree.runtime import ( # type: ignore - HalElementType, -) - -from sharktank.serving_poc.framework.session import DeviceSession -from sharktank.serving_poc.llm.config import ( - CacheParams, - ModelParams, - ServiceParams, -) - -from sharktank.serving_poc.llm.service import ( - GenerateRequest, - GenerateResponsePart, -) - -from sharktank.serving_poc.llm.attn_block_cache import ( - create_attn_block_cache_module, - AttnBlockCache, -) - -from sharktank.serving_poc.llm.impl.service_v1 import ( - GenerateServiceV1, -) - -from sharktank.serving_poc.llm.testing.fake_v1_module import ( - create_fake_module, -) - - -@pytest.fixture -def cache_params(model_params: ModelParams) -> CacheParams: - return CacheParams(model=model_params, device_block_count=128, block_pos_stride=16) - - -@pytest.fixture -def model_params() -> ModelParams: - return ModelParams( - module_name="AwesomeLLM", - module_abi_version=1, - attn_dtype=HalElementType.FLOAT_16, - max_seq_len=128, - transformer_block_count=32, - attn_head_count=32, - attn_head_dim=128, - block_seq_stride=16, - prefill_batch_sizes=[1, 4, 16], - decode_batch_sizes=[1, 4, 16], - ) - - -@pytest.fixture -def uninitialized_session(model_params: ModelParams): - from iree.runtime._binding import disable_leak_checker # type: ignore - - disable_leak_checker() - session = DeviceSession(uri="local-task", queue_count=2) - yield session - session.shutdown() - del session - - -@pytest.fixture -def attn_block_cache( - uninitialized_session: DeviceSession, cache_params: CacheParams -) -> AttnBlockCache: - return AttnBlockCache(uninitialized_session, cache_params) - - -@pytest.fixture -def session( - model_params: ModelParams, - uninitialized_session: DeviceSession, - attn_block_cache: AttnBlockCache, -): - session = uninitialized_session - lms = session.create_module_set("AwesomeLLM", context_count=1) - lms.add( - create_attn_block_cache_module(attn_block_cache), - create_fake_module(session.device, "AwesomeLLM", model_params=model_params), - ) - lms.initialize() - return session - - -@pytest.fixture -def service( - session: DeviceSession, - cache_params: CacheParams, - model_params: ModelParams, - attn_block_cache: AttnBlockCache, -): - params = ServiceParams(cache=cache_params, model=model_params) - return GenerateServiceV1(session=session, params=params, cache=attn_block_cache) - - -def test_single(service: GenerateServiceV1): - state = service.start() - - async def task(): - await state.set_sequences( - requests=[ - GenerateRequest( - "1", - "hello, tell me a story", - [3, 4, 5, 12, 23, 88, 10, 2, 5, 9, 12, 13, 99, 56, 33, 124, 73], - ), - GenerateRequest("2", "goodbye", [9, 10]), - ] - ) - guarded_outputs = await state.prefill() - prefill_ids = await guarded_outputs.resolve(state.host_context) - print( - "PREFILL IDS:", - prefill_ids, - ":\n", - prefill_ids.map().asarray( - prefill_ids.shape, HalElementType.map_to_dtype(prefill_ids.element_type) - ), - ) - await state.recycle() - - state.host_context.run_sync(task()) From 5aae18d1e48afcc06ce421aa7e14ab12424d9127 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Wed, 27 Nov 2024 09:37:00 -0800 Subject: [PATCH 25/25] Remove test_punet job from ci-sharktank.yml. (#621) Partial revert of https://github.com/nod-ai/shark-ai/pull/613. This job takes over an hour to run (specifically `test_punet_eager_fp16_validation `) and is not suitable for presubmit. 
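For anyone who still needs these checks, the same tests can be run locally with the marker the removed job used; a minimal sketch, assuming the sharktank package and its test requirements (the pip steps in the removed job below) are already installed:

```bash
# Invoke the punet-marked integration tests directly, mirroring the removed CI step.
# Assumes sharktank/ and the sharktank/requirements-tests.txt dependencies are installed.
pytest -v sharktank/ -m model_punet
```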
This is likely also not the right file for such tests, and if it was then the specialized tests should go under the generic `test` job, not at the top of the file. --- .github/workflows/ci-sharktank.yml | 39 ------------------------------ 1 file changed, 39 deletions(-) diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index 6e8cee3bb..1d3960b43 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -22,45 +22,6 @@ concurrency: cancel-in-progress: true jobs: - test_punet: - name: "Integration Tests - punet" - runs-on: nodai-amdgpu-mi250-x86-64 - env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" - steps: - - name: "Setting up Python" - id: setup_python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: 3.11 - - - name: "Checkout Code" - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Cache Pip Packages - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} - - - name: Install pip deps - run: | - python -m pip install --no-compile --upgrade pip - # Note: We install in three steps in order to satisfy requirements - # from non default locations first. Installing the PyTorch CPU - # wheels saves multiple minutes and a lot of bandwidth on runner setup. - pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ - - # Update to the latest iree packages. - pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ - iree-base-compiler iree-base-runtime --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" - - name: Run punet tests - run: | - pytest -v sharktank/ -m model_punet - test: name: "Unit Tests and Type Checking" strategy: