diff --git a/.github/workflows/build_rocm.yaml b/.github/workflows/build_rocm.yaml new file mode 100644 index 00000000..8a9fde49 --- /dev/null +++ b/.github/workflows/build_rocm.yaml @@ -0,0 +1,134 @@ + name: Build and push AMD ROCm docker image to registry + + on: + workflow_dispatch: + push: + branches: + - 'main' + tags: + - 'v*' + pull_request: + paths: + - ".github/workflows/build_rocm.yaml" +# - "integration-tests/**" + - "backends/**" + - "core/**" + - "router/**" + - "Cargo.lock" + - "rust-toolchain.toml" + - "Dockerfile-rocm" + branches: + - 'main' + + jobs: + build-and-push-image: + concurrency: + group: ${{ github.workflow }}-${{ github.job }}-rocm-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci] + permissions: + contents: write + packages: write + # This is used to complete the identity challenge + # with sigstore/fulcio when running outside of PRs. + id-token: write + security-events: write + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Initialize Docker Buildx + uses: docker/setup-buildx-action@v2.0.0 + with: + install: true + - name: Configure sccache + uses: actions/github-script@v6 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: Inject slug/short variables + uses: rlespinasse/github-slug-action@v4.4.1 + - name: Tailscale + uses: huggingface/tailscale-action@v1 + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry + uses: docker/login-action@v2.1.0 + with: + username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} + password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} + registry: registry.internal.huggingface.tech + - name: Extract metadata (tags, labels) for Docker + id: meta-rocm + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=rocm-{{version}} + type=semver,pattern=rocm-{{major}}.{{minor}} + type=raw,value=rocm-latest + type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }} + + - name: Build and push Docker image + id: build-and-push-rocm + uses: docker/build-push-action@v4 + with: + context: . 
+ file: Dockerfile-rocm + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-rocm.outputs.tags }} + labels: ${{ steps.meta-rocm.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max + cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max + + - name: Extract metadata (tags, labels) for Docker + id: meta-rocm-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=rocm-{{version}}-grpc + type=semver,pattern=rocm-{{major}}.{{minor}}-grpc + type=raw,value=rocm-latest-grpc + type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}-grpc + + - name: Build and push Docker image + id: build-and-push-rocm-grpc + uses: docker/build-push-action@v4 + with: + context: . + target: grpc + file: Dockerfile-rocm + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-rocm-grpc.outputs.tags }} + labels: ${{ steps.meta-rocm-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max diff --git a/.gitignore b/.gitignore index ee44a963..6862c2f1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea target +__pycache__/ diff --git a/Dockerfile-rocm b/Dockerfile-rocm new file mode 100644 index 00000000..152fa0a0 --- /dev/null +++ b/Dockerfile-rocm @@ -0,0 +1,135 @@ +FROM rocm/dev-ubuntu-22.04:6.0.2 AS base-builder + +ENV SCCACHE=0.5.4 +ENV RUSTC_WRAPPER=/usr/local/bin/sccache +ENV PATH="/root/.cargo/bin:${PATH}" + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl \ + libssl-dev \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* + +# Download and configure sccache +RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ + chmod +x /usr/local/bin/sccache + +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y +RUN cargo install cargo-chef --locked + +FROM base-builder AS planner + +WORKDIR /usr/src + +COPY backends backends +COPY core core +COPY router router +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +RUN cargo chef prepare --recipe-path recipe.json + +FROM base-builder AS builder + +ARG CUDA_COMPUTE_CAP=80 +ARG GIT_SHA +ARG DOCKER_LABEL + +# sccache specific variables +ARG ACTIONS_CACHE_URL +ARG ACTIONS_RUNTIME_TOKEN +ARG SCCACHE_GHA_ENABLED + +WORKDIR /usr/src + +COPY --from=planner /usr/src/recipe.json recipe.json + +RUN cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s + +COPY backends backends +COPY core core +COPY router router +COPY Cargo.toml ./ +COPY 
Cargo.lock ./ + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + unzip \ + && rm -rf /var/lib/apt/lists/* + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY proto proto + +FROM builder as http-builder + +RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s + +FROM builder as grpc-builder + +RUN cargo build --release --bin text-embeddings-router -F python -F grpc --no-default-features && sccache -s + +FROM rocm/dev-ubuntu-22.04:6.0.2 as base + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + git \ + python3-dev \ + rocthrust-dev \ + hipsparse-dev \ + hipblas-dev \ + hipblaslt-dev \ + rocblas-dev \ + hiprand-dev \ + rocrand-dev \ + && rm -rf /var/lib/apt/lists/* + + +# Keep in sync with `server/pyproject.toml +ARG MAMBA_VERSION=23.1.0-1 +ARG PYTORCH_VERSION='2.3.0' +ARG ROCM_VERSION='6.0.2' +ARG PYTHON_VERSION='3.10.10' +# Automatically set by buildx +ARG TARGETPLATFORM +ENV PATH /opt/conda/bin:$PATH + +RUN curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + mamba init && \ + rm ~/mambaforge.sh + +# Install flash-attention, torch dependencies +RUN pip install numpy einops ninja --no-cache-dir + +RUN pip install torch --index-url https://download.pytorch.org/whl/rocm6.0 + +ARG DEFAULT_USE_FLASH_ATTENTION=True +COPY backends/python/Makefile-flash-att-v2 Makefile-flash-att-v2 +RUN make -f Makefile-flash-att-v2 install-flash-attention-v2-rocm + +# Install python backend +COPY backends/python/server /tei_backends/python/server +COPY backends/proto tei_backends/proto +RUN make -C /tei_backends/python/server install + +ENV HUGGINGFACE_HUB_CACHE=/data \ + PORT=80 \ + USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION + +FROM base as grpc + +COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] + +FROM base + +COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] diff --git a/backends/python/Makefile-flash-att-v2 b/backends/python/Makefile-flash-att-v2 new file mode 100644 index 00000000..ba90a74d --- /dev/null +++ b/backends/python/Makefile-flash-att-v2 @@ -0,0 +1,21 @@ +flash_att_v2_commit_cuda := v2.5.9.post1 +flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 + +build-flash-attention-v2-cuda: + pip install -U packaging wheel + pip install flash-attn==$(flash_att_v2_commit_cuda) + +install-flash-attention-v2-cuda: build-flash-attention-v2-cuda + echo "Flash v2 installed" + +build-flash-attention-v2-rocm: + if [ ! 
-d 'flash-attention-v2' ]; then \ + pip install -U packaging ninja --no-cache-dir && \ + git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \ + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \ + git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \ + fi + +install-flash-attention-v2-rocm: build-flash-attention-v2-rocm + cd flash-attention-v2 && \ + GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install diff --git a/backends/python/server/pyproject.toml b/backends/python/server/pyproject.toml index 96fcaf9e..8fbc0008 100644 --- a/backends/python/server/pyproject.toml +++ b/backends/python/server/pyproject.toml @@ -15,12 +15,13 @@ grpcio-status = "^1.51.1" grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.0" typer = "^0.6.1" -safetensors = "^0.3.2" +safetensors = "^0.4.0" loguru = "^0.6.0" opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" -torch = { version = "^2.0.1" } +torch = { version = "==2.3.1" } +transformers = { version = "^4.39.0"} [tool.poetry.extras] @@ -33,6 +34,11 @@ name = "pytorch-gpu-src" url = "https://download.pytorch.org/whl/cu118" priority = "explicit" +[[tool.poetry.source]] +name = "pytorch-gpu-src-rocm" +url = "https://download.pytorch.org/whl/rocm6.0" +priority = "explicit" + [tool.pytest.ini_options] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] diff --git a/backends/python/server/requirements.txt b/backends/python/server/requirements.txt index 89ca314d..2d089e41 100644 --- a/backends/python/server/requirements.txt +++ b/backends/python/server/requirements.txt @@ -4,20 +4,13 @@ charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13" -fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" idna==3.4 ; python_version >= "3.9" and python_version < "3.13" -jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13" -mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" -networkx==3.1 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" @@ -27,15 +20,10 @@ opentelemetry-instrumentation==0.36b0 ; python_version 
>= "3.9" and python_versi opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==23.1 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13" -pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13" -sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" -torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13" -tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/python/server/text_embeddings_server/cli.py b/backends/python/server/text_embeddings_server/cli.py index 9497dc20..4f423afe 100644 --- a/backends/python/server/text_embeddings_server/cli.py +++ b/backends/python/server/text_embeddings_server/cli.py @@ -24,6 +24,7 @@ def serve( json_output: bool = False, otlp_endpoint: Optional[str] = None, otlp_service_name: str = "text-embeddings-inference.server", + pooling_mode: Optional[str] = None, ): # Remove default handler logger.remove() @@ -48,7 +49,7 @@ def serve( # Downgrade enum into str for easier management later on dtype = None if dtype is None else dtype.value - server.serve(model_path, dtype, uds_path) + server.serve(model_path, dtype, uds_path, pooling_mode) if __name__ == "__main__": diff --git a/backends/python/server/text_embeddings_server/layers/__init__.py b/backends/python/server/text_embeddings_server/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backends/python/server/text_embeddings_server/layers/attention/__init__.py b/backends/python/server/text_embeddings_server/layers/attention/__init__.py new file mode 100644 index 00000000..42aac2bd --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/attention/__init__.py @@ -0,0 +1,14 @@ +from text_embeddings_server.utils.import_utils import SYSTEM +import os + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + class Attention: + def __getattr__(self, name): + raise RuntimeError("TEI is used with USE_FLASH_ATTENTION=false, accessing `attention` is prohibited") + attention = Attention() +elif SYSTEM == "cuda": + from .cuda import attention +elif SYSTEM == "rocm": + from .rocm import attention +else: + raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/backends/python/server/text_embeddings_server/utils/flash_attn.py b/backends/python/server/text_embeddings_server/layers/attention/cuda.py similarity index 100% rename from backends/python/server/text_embeddings_server/utils/flash_attn.py rename to backends/python/server/text_embeddings_server/layers/attention/cuda.py diff --git a/backends/python/server/text_embeddings_server/layers/attention/rocm.py b/backends/python/server/text_embeddings_server/layers/attention/rocm.py new file mode 100644 index 00000000..365e5451 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/attention/rocm.py @@ 
-0,0 +1,45 @@ +import os +import torch +from text_embeddings_server.utils.import_utils import SYSTEM +from loguru import logger + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 + +if SYSTEM == "rocm": + try: + import flash_attn_2_cuda + + logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + except ImportError as e: + if major >= 8 or is_sm75: + architecture_suffix = f"-{SYSTEM}" + raise ImportError(f"Flash Attention V2 is not installed. {e}") + else: + for idx in range(torch.cuda.device_count()): + name = torch.cuda.get_device_name(idx) + if "MI210" not in name and "MI250" not in name and "MI300" not in name: + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + raise ImportError( + f"AMD GPU with ROCm capability {major} {minor} is not supported" + ) from e + +def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False): + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + is_causal, + False, + None, + ) \ No newline at end of file diff --git a/backends/python/server/text_embeddings_server/layers/layernorm.py b/backends/python/server/text_embeddings_server/layers/layernorm.py new file mode 100644 index 00000000..abd9e676 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/layernorm.py @@ -0,0 +1,54 @@ +import torch +from text_embeddings_server.utils.import_utils import SYSTEM + +from transformers.models.bert import BertConfig + +if SYSTEM == "cuda": + import dropout_layer_norm + + class FastLayerNorm: + def __init__(self, prefix, handle, device, dtype, config: BertConfig): + self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) + self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) + self.variance_epsilon = config.layer_norm_eps + + def forward(self, hidden_states, residual=None): + normed_hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + +elif SYSTEM == "rocm": + class FastLayerNorm: + def __init__(self, prefix, handle, device, dtype, config: BertConfig): + self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) + self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) + self.variance_epsilon = config.layer_norm_eps + + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = torch.nn.functional.layer_norm(hidden_states, self.weight.shape, self.weight, self.bias, eps=self.variance_epsilon) + + return hidden_states, residual +else: + raise ValueError("System not recognized") \ No newline at end of file diff --git a/backends/python/server/text_embeddings_server/layers/pooling.py b/backends/python/server/text_embeddings_server/layers/pooling.py new file mode 100644 index 00000000..1bccbc57 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/pooling.py @@ -0,0 +1,22 @@ +import torch +from flash_attn.bert_padding import pad_input + +from loguru import logger + +def mean_pooling(embedding, cu_seqlens, max_s): + # Ideally, rust would pass `indices` to the FlashBatch. 
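+ # `embedding` holds the unpadded token embeddings of the whole batch (total_tokens, hidden_size), + # `cu_seqlens` the cumulative sequence lengths and `max_s` the longest sequence length. + # Recover the per-sequence lengths, rebuild the padding mask, pad the embeddings back to + # (batch_size, max_s, hidden_size) and average each sequence over its non-padded positions.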
+ seqlens = cu_seqlens[1:].clone() + seqlens[0] = cu_seqlens[1] + seqlens[1:] -= cu_seqlens[1:-1] + batch_size = len(seqlens) + + # Example: indices = [0, 1, 2, 3, 7, 8, 9, 10, 11, 12, 13] + mask = torch.zeros(batch_size, max_s, dtype=torch.int32, device=cu_seqlens.device) + mask[torch.arange(max_s) < seqlens[:, None].cpu()] = 1 + indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() + + embedding_padded = pad_input(embedding, indices, batch_size, max_s) + + sum_embeddings = torch.sum(embedding_padded, 1) + + return sum_embeddings / seqlens[:, None] \ No newline at end of file diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 47867187..c606efc9 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -15,17 +15,17 @@ torch.set_grad_enabled(False) FLASH_ATTENTION = True -try: - from text_embeddings_server.models.flash_bert import FlashBert -except ImportError as e: - logger.warning(f"Could not import Flash Attention enabled models: {e}") - FLASH_ATTENTION = False +# try: +from text_embeddings_server.models.flash_bert import FlashBert +# except ImportError as e: +# logger.warning(f"Could not import Flash Attention enabled models: {e}") +# FLASH_ATTENTION = False if FLASH_ATTENTION: __all__.append(FlashBert) -def get_model(model_path: Path, dtype: Optional[str]): +def get_model(model_path: Path, dtype: Optional[str], pooling_mode: str): if dtype == "float32": dtype = torch.float32 elif dtype == "float16": @@ -52,8 +52,8 @@ def get_model(model_path: Path, dtype: Optional[str]): and dtype in [torch.float16, torch.bfloat16] and FLASH_ATTENTION ): - return FlashBert(model_path, device, dtype) + return FlashBert(model_path, device, dtype, pooling_mode) else: - return DefaultModel(model_path, device, dtype) + return DefaultModel(model_path, device, dtype, pooling_mode) raise NotImplementedError diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py index dc39fdc8..17ad4589 100644 --- a/backends/python/server/text_embeddings_server/models/default_model.py +++ b/backends/python/server/text_embeddings_server/models/default_model.py @@ -8,14 +8,16 @@ from text_embeddings_server.models import Model from text_embeddings_server.models.types import PaddedBatch, Embedding +from typing import Optional tracer = trace.get_tracer(__name__) class DefaultModel(Model): - def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): + def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype, pooling_mode: Optional[str]): model = AutoModel.from_pretrained(model_path).to(dtype).to(device) self.hidden_size = model.config.hidden_size + self.pooling_mode = pooling_mode self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 50b8d70d..6ebb70d4 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -8,46 +8,16 @@ from transformers.models.bert import BertConfig from opentelemetry import trace -# Flash attention imports -import dropout_layer_norm - from text_embeddings_server.models import 
Model from text_embeddings_server.models.types import FlashBatch, Embedding -from text_embeddings_server.utils.flash_attn import attention +from text_embeddings_server.layers.attention import attention +from text_embeddings_server.layers.layernorm import FastLayerNorm +from text_embeddings_server.layers.pooling import mean_pooling +from typing import Optional tracer = trace.get_tracer(__name__) -class FastLayerNorm: - def __init__(self, prefix, handle, device, dtype, config: BertConfig): - self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) - self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) - self.variance_epsilon = config.layer_norm_eps - - def forward(self, hidden_states, residual=None): - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - self.bias, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - False, - ) - if res is None: - res = hidden_states - - return normed_hidden_states, res - - class BertEmbeddings: def __init__(self, prefix, handle, device, dtype, config: BertConfig): self.word_embeddings_weight = ( @@ -217,16 +187,17 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s): embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids) encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s) - return encoder_outputs[cu_seqlens[:-1]] + return encoder_outputs class FlashBert(Model): - def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): + def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype, pooling_mode: Optional[str]): config = BertConfig.from_pretrained(model_path) with safe_open(model_path / "model.safetensors", framework="pt") as f: model = FlashBertModel(f, device, dtype, config) self.hidden_size = config.hidden_size + self.pooling_mode = pooling_mode super(FlashBert, self).__init__(model=model, dtype=dtype, device=device) @@ -243,11 +214,24 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: cu_seqlens=batch.cu_seqlens, max_s=batch.max_s, ) - cpu_results = embedding.view(-1).tolist() - return [ - Embedding( - values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] - ) - for i in range(len(batch)) - ] + if self.pooling_mode == "cls": + embedding = embedding[batch.cu_seqlens[:-1]] + cpu_results = embedding.view(-1).tolist() + + return [ + Embedding( + values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] + ) + for i in range(len(batch)) + ] + elif self.pooling_mode == "mean": + res = mean_pooling(embedding, batch.cu_seqlens, batch.max_s) + return [ + Embedding( + values=res[i] + ) + for i in range(len(batch)) + ] + else: + raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend") \ No newline at end of file diff --git a/backends/python/server/text_embeddings_server/server.py b/backends/python/server/text_embeddings_server/server.py index d0a43ace..2c99cf79 100644 --- a/backends/python/server/text_embeddings_server/server.py +++ b/backends/python/server/text_embeddings_server/server.py @@ -37,6 +37,7 @@ def serve( model_path: Path, dtype: Optional[str], uds_path: Path, + pooling_mode: Optional[str], ): async def serve_inner( model_path: Path, @@ -45,7 +46,7 @@ async def serve_inner( unix_socket = f"unix://{uds_path}" try: - model = get_model(model_path, dtype) + model = get_model(model_path, dtype, pooling_mode) except Exception: 
logger.exception("Error when initializing model") raise diff --git a/backends/python/server/text_embeddings_server/utils/import_utils.py b/backends/python/server/text_embeddings_server/utils/import_utils.py new file mode 100644 index 00000000..83394eaa --- /dev/null +++ b/backends/python/server/text_embeddings_server/utils/import_utils.py @@ -0,0 +1,12 @@ +import torch +from loguru import logger + +SYSTEM = None +if torch.version.hip is not None: + SYSTEM = "rocm" +elif torch.version.cuda is not None and torch.cuda.is_available(): + SYSTEM = "cuda" +else: + SYSTEM = "cpu" + +logger.info(f"Python backend: detected system {SYSTEM}") diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 195f1d37..ef33b7d2 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -23,6 +23,7 @@ impl PythonBackend { uds_path: String, otlp_endpoint: Option<String>, otlp_service_name: String, + pooling_mode: String, ) -> Result<Self, BackendError> { match model_type { ModelType::Classifier => { )) } ModelType::Embedding(pool) => { - if pool != Pool::Cls { - return Err(BackendError::Start(format!("{pool:?} is not supported"))); + if pool != Pool::Cls && pool != Pool::Mean { + return Err(BackendError::Start(format!("{pool:?} is not supported in the TEI Python backend. Please open an issue."))); } pool } @@ -44,6 +45,7 @@ impl PythonBackend { &uds_path, otlp_endpoint, otlp_service_name, + pooling_mode, )?; let tokio_runtime = tokio::runtime::Builder::new_current_thread() .enable_all() diff --git a/backends/python/src/management.rs b/backends/python/src/management.rs index 911c6984..2044a3e0 100644 --- a/backends/python/src/management.rs +++ b/backends/python/src/management.rs @@ -22,6 +22,7 @@ impl BackendProcess { uds_path: &str, otlp_endpoint: Option<String>, otlp_service_name: String, + pooling_mode: String, ) -> Result<Self, BackendError> { // Get UDS path let uds = Path::new(uds_path); @@ -52,6 +53,9 @@ impl BackendProcess { python_server_args.push("--otlp-service-name".to_owned()); python_server_args.push(otlp_service_name); + python_server_args.push("--pooling-mode".to_owned()); + python_server_args.push(pooling_mode); + // Copy current process env let envs: Vec<(OsString, OsString)> = env::vars_os().collect(); diff --git a/backends/src/lib.rs b/backends/src/lib.rs index d332b4a7..db27cddc 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -39,6 +39,7 @@ impl Backend { uds_path: String, otlp_endpoint: Option<String>, otlp_service_name: String, + pooling_mode: String, ) -> Result<Self, BackendError> { let (backend_sender, backend_receiver) = mpsc::unbounded_channel(); @@ -49,6 +50,7 @@ impl Backend { uds_path, otlp_endpoint, otlp_service_name, + pooling_mode, )?; let padded_model = backend.is_padded(); let max_batch_size = backend.max_batch_size(); @@ -138,6 +140,7 @@ fn init_backend( uds_path: String, otlp_endpoint: Option<String>, otlp_service_name: String, + pooling_mode: String, ) -> Result<Box<dyn CoreBackend + Send>, BackendError> { if cfg!(feature = "candle") { #[cfg(feature = "candle")] @@ -158,6 +161,7 @@ fn init_backend( uds_path, otlp_endpoint, otlp_service_name, + pooling_mode, ) }) .join() diff --git a/router/src/lib.rs b/router/src/lib.rs index d2023515..14f1dfb3 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -105,7 +105,7 @@ pub async fn run( serde_json::from_str(&config).context("Failed to parse `config.json`")?; // Set model type from config - let backend_model_type = get_backend_model_type(&config, &model_root, pooling)?; + let (backend_model_type, inferred_pooling) = get_backend_model_type(&config, &model_root, 
&pooling)?; // Info model type let model_type = match &backend_model_type { @@ -191,6 +191,11 @@ pub async fn run( } }); + let pooling_str = match inferred_pooling { + Some(pool) => pool.to_string(), + None => "none".to_string(), + }; + // Create backend tracing::info!("Starting model backend"); let backend = text_embeddings_backend::Backend::new( @@ -200,6 +205,7 @@ pub async fn run( uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()), otlp_endpoint.clone(), otlp_service_name.clone(), + pooling_str, ) .context("Could not create backend")?; backend @@ -306,24 +312,24 @@ pub async fn run( fn get_backend_model_type( config: &ModelConfig, model_root: &Path, - pooling: Option<text_embeddings_backend::Pool>, -) -> Result<text_embeddings_backend::ModelType> { + pooling: &Option<text_embeddings_backend::Pool>, +) -> Result<(text_embeddings_backend::ModelType, Option<text_embeddings_backend::Pool>)> { for arch in &config.architectures { - if Some(text_embeddings_backend::Pool::Splade) == pooling && arch.ends_with("MaskedLM") { - return Ok(text_embeddings_backend::ModelType::Embedding( + if Some(text_embeddings_backend::Pool::Splade) == *pooling && arch.ends_with("MaskedLM") { + return Ok((text_embeddings_backend::ModelType::Embedding( text_embeddings_backend::Pool::Splade, - )); + ), Some(text_embeddings_backend::Pool::Splade))); } else if arch.ends_with("Classification") { if pooling.is_some() { tracing::warn!( "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg." ); } - return Ok(text_embeddings_backend::ModelType::Classifier); + return Ok((text_embeddings_backend::ModelType::Classifier, None)); } } - if Some(text_embeddings_backend::Pool::Splade) == pooling { + if Some(text_embeddings_backend::Pool::Splade) == *pooling { return Err(anyhow!( "Splade pooling is not supported: model is not a ForMaskedLM model" )); @@ -331,7 +337,7 @@ fn get_backend_model_type( // Set pooling let pool = match pooling { - Some(pool) => pool, + Some(pool) => pool.clone(), None => { // Load pooling config let config_path = model_root.join("1_Pooling/config.json"); @@ -347,7 +353,7 @@ fn get_backend_model_type( } } }; - Ok(text_embeddings_backend::ModelType::Embedding(pool)) + Ok((text_embeddings_backend::ModelType::Embedding(pool.clone()), Some(pool))) } #[derive(Debug, Deserialize)] diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..c4ff5d0b --- /dev/null +++ b/tests/README.md @@ -0,0 +1,34 @@ +## Testing + +To run the tests, start the container with `--entrypoint "/bin/bash"` and install the test requirements from within it: +``` +pip install -r requirements.txt +``` + +Then, with this directory mounted as a volume, run the tests from within the container: +``` +pytest tests/ -s -vvvvv +``` + +## Reference outputs + +For example, to collect the reference outputs on an RTX 4090 with the Candle backend: +``` +docker run --rm -it --gpus all --net host --entrypoint "/bin/bash" -v $(pwd):/tei ghcr.io/huggingface/text-embeddings-inference:89-1.2.3 +``` +and +``` +text-embeddings-router --model-id sentence-transformers/all-MiniLM-L6-v2 +``` + +and then +``` +python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1 --flash +python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3 --flash +``` + +Restart the server with `USE_FLASH_ATTENTION=0`, and collect the non-flash references: +``` +python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1 +python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3 +``` \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt new file mode 100644 index 00000000..aaf95a92 Binary files /dev/null and b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt differ diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt new file mode 100644 index 00000000..d986e332 Binary files /dev/null and b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt differ diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt new file mode 100644 index 00000000..bea6dca1 Binary files /dev/null and b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt differ diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt new file mode 100644 index 00000000..7bf51879 Binary files /dev/null and b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt differ diff --git a/tests/collect.py b/tests/collect.py new file mode 100644 index 00000000..313c0871 --- /dev/null +++ b/tests/collect.py @@ -0,0 +1,37 @@ + +import requests +import torch +import argparse +import json +import os + +parser = argparse.ArgumentParser(description='Assets collection') +parser.add_argument('--model-id', help='Model id', required=True) +parser.add_argument('--n_inp', help='Number of inputs', required=True, type=int) +parser.add_argument('--flash', action='store_true') + +args = parser.parse_args() + +url = f"http://0.0.0.0:80/embed" + +INPUTS = [ + "What is Deep Learning?", + "Today I am in Paris and I would like to", + "Paris weather is", + "Great job" +] + +data = {"inputs": INPUTS[:args.n_inp]} +headers = {"Content-Type": "application/json"} + +response = requests.post(url, json=data, headers=headers) + +embedding = torch.Tensor(json.loads(response.text)) + +postfix = "" +if not args.flash: + postfix = "_no_flash" + +save_path = f"./assets/{args.model_id.replace('/', '-')}_inp{args.n_inp}{postfix}.pt" +print(f"Saving embedding of shape {embedding.shape} to {save_path}") +torch.save(embedding, save_path) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..6d8ed997 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,113 @@ +import pytest +import asyncio +import contextlib +import random +import os +import tempfile +import subprocess +import shutil +import sys +from typing import Optional +from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError +import requests +import time +from requests.exceptions import ConnectionError as RequestsConnectionError + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + +class ProcessLauncherHandle: + def __init__(self, process, port: int): + self.port = port + self.process = process + + def _inner_health(self) -> bool: + return self.process.poll() is None + + def health(self, timeout: int = 60): + assert timeout > 0 + for _ in range(timeout): + if not self._inner_health(): + raise RuntimeError("Launcher crashed") + + try: + url = f"http://0.0.0.0:{self.port}/health" + headers = {"Content-Type": "application/json"} + + response = requests.post(url, headers=headers) + return + except (ClientConnectorError, ClientOSError, ServerDisconnectedError, 
RequestsConnectionError) as e: print("Connecting") time.sleep(1) raise RuntimeError("Health check failed") +@pytest.fixture(scope="module") +def launcher(event_loop): + @contextlib.contextmanager + def local_launcher( + model_id: str, + trust_remote_code: bool = False, + use_flash_attention: bool = True, + dtype: Optional[str] = None, + revision: Optional[str] = None, + pooling: Optional[str] = None, + ): + port = random.randint(8000, 10_000) + shard_uds_path = ( + f"/tmp/tei-tests-{model_id.split('/')[-1]}-server" + ) + + args = [ + "text-embeddings-router", + "--model-id", + model_id, + "--port", + str(port), + "--uds-path", + shard_uds_path, + ] + + env = os.environ + + if dtype is not None: + args.append("--dtype") + args.append(dtype) + if revision is not None: + args.append("--revision") + args.append(revision) + if trust_remote_code: + args.append("--trust-remote-code") + if pooling: + args.append("--pooling") + args.append(pooling) + + env["LOG_LEVEL"] = "debug" + + if not use_flash_attention: + env["USE_FLASH_ATTENTION"] = "false" + + with tempfile.TemporaryFile("w+") as tmp: + # We'll output stdout/stderr to a temporary file. Using a pipe + # causes the process to block until stdout is read. + print("call subprocess.Popen, with args", args) + with subprocess.Popen( + args, + stdout=tmp, + stderr=subprocess.STDOUT, + env=env, + ) as process: + yield ProcessLauncherHandle(process, port) + + process.terminate() + process.wait(60) + + tmp.seek(0) + shutil.copyfileobj(tmp, sys.stderr) + + if not use_flash_attention: + del env["USE_FLASH_ATTENTION"] + + return local_launcher \ No newline at end of file diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 00000000..2f4c80e3 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 00000000..b1ee0f58 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,3 @@ +pytest +pytest-asyncio +aiohttp \ No newline at end of file diff --git a/tests/test_default_model.py b/tests/test_default_model.py new file mode 100644 index 00000000..595fe6bf --- /dev/null +++ b/tests/test_default_model.py @@ -0,0 +1,28 @@ +import pytest +import requests +import json +import torch + +@pytest.fixture(scope="module") +def default_model_handle(launcher): + with launcher("sentence-transformers/all-MiniLM-L6-v2", use_flash_attention=False) as handle: + yield handle + +@pytest.fixture(scope="module") +async def default_model(default_model_handle): + default_model_handle.health(300) + return default_model_handle + +@pytest.mark.asyncio +@pytest.mark.private +async def test_single_query(default_model): + url = f"http://0.0.0.0:{default_model.port}/embed" + data = {"inputs": "What is Deep Learning?"} + headers = {"Content-Type": "application/json"} + + response = requests.post(url, json=data, headers=headers) + + embedding = torch.Tensor(json.loads(response.text)) + reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt") + + assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3) \ No newline at end of file diff --git a/tests/test_flash_bert.py b/tests/test_flash_bert.py new file mode 100644 index 00000000..3c3fde1c --- /dev/null +++ b/tests/test_flash_bert.py @@ -0,0 +1,28 @@ +import pytest +import requests +import json +import torch + +@pytest.fixture(scope="module") +def default_model_handle(launcher): + with 
launcher("sentence-transformers/all-MiniLM-L6-v2", use_flash_attention=True) as handle: + yield handle + +@pytest.fixture(scope="module") +async def default_model(default_model_handle): + default_model_handle.health(300) + return default_model_handle + +@pytest.mark.asyncio +@pytest.mark.private +async def test_single_query(default_model): + url = f"http://0.0.0.0:{default_model.port}/embed" + data = {"inputs": "What is Deep Learning?"} + headers = {"Content-Type": "application/json"} + + response = requests.post(url, json=data, headers=headers) + + embedding = torch.Tensor(json.loads(response.text)) + reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt") + + assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3) \ No newline at end of file