From fdf4a93e52bb00c87cf74ebe5d95d8d4ccc29366 Mon Sep 17 00:00:00 2001
From: kta-intel
Date: Wed, 22 Nov 2023 16:47:31 -0600
Subject: [PATCH 1/3] enable IPEX optimization

Signed-off-by: kta-intel
---
 Dockerfile                                    | 14 +++++
 README.md                                     | 50 ++++++++++++++++++
 pb/client.py                                  | 33 ++++++++++++
 .../inference_engine/hf_transformers_ipex.py  | 51 +++++++++++++++++++
 .../models/causal_lm.py                       |  3 ++
 5 files changed, 151 insertions(+)
 create mode 100644 pb/client.py
 create mode 100644 server/text_generation_server/inference_engine/hf_transformers_ipex.py

diff --git a/Dockerfile b/Dockerfile
index d1134085..1ab5658c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,8 +2,10 @@
 ARG BASE_UBI_IMAGE_TAG=9.3-1361.1699548029
 ARG PROTOC_VERSION=25.0
 ARG PYTORCH_INDEX="https://download.pytorch.org/whl"
+ARG IPEX_INDEX="https://pytorch-extension.intel.com/release-whl/stable/cpu/us/"
 #ARG PYTORCH_INDEX="https://download.pytorch.org/whl/nightly"
 ARG PYTORCH_VERSION=2.1.0
+ARG IPEX_VERSION=2.1.0
 
 ## Base Layer ##################################################################
 FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} as base
@@ -148,6 +150,7 @@ WORKDIR /usr/src
 
 # Install specific version of torch
 RUN pip install torch=="$PYTORCH_VERSION+cpu" --index-url "${PYTORCH_INDEX}/cpu" --no-cache-dir
+RUN pip install intel-extension-for-pytorch=="$IPEX_VERSION" --extra-index-url "${IPEX_INDEX}" --no-cache-dir
 
 COPY server/Makefile server/Makefile
 
@@ -174,6 +177,8 @@ RUN cd integration_tests && make install
 
 FROM cuda-devel as python-builder
 ARG PYTORCH_INDEX
 ARG PYTORCH_VERSION
+ARG IPEX_INDEX
+ARG IPEX_VERSION
 
 RUN dnf install -y unzip git ninja-build && dnf clean all
@@ -187,6 +192,7 @@ ENV PATH=/opt/miniconda/bin:$PATH
 
 # Install specific version of torch
 RUN pip install ninja==1.11.1.1 --no-cache-dir
 RUN pip install torch==$PYTORCH_VERSION+cu118 --index-url "${PYTORCH_INDEX}/cu118" --no-cache-dir
+RUN pip install intel-extension-for-pytorch~="$IPEX_VERSION" --extra-index-url "${IPEX_INDEX}" --no-cache-dir
 
 ## Build flash attention v2 ####################################################
@@ -241,6 +247,14 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build /usr/sr
 
 FROM base as flash-att-v2-cache
 COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build /usr/src/flash-attention-v2/build
+## Setup environment variables for performance on Xeon
+ENV KMP_BLOCKTIME=INF
+ENV KMP_TPAUSE=0
+ENV KMP_SETTINGS=1
+ENV KMP_AFFINITY=granularity=fine,compact,1,0
+ENV KMP_FORJOIN_BARRIER_PATTERN=dist,dist
+ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
+ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
 
 ## Final Inference Server image ################################################
 FROM cuda-runtime as server-release
diff --git a/README.md b/README.md
index 41626221..d9ab354c 100644
--- a/README.md
+++ b/README.md
@@ -158,3 +158,53 @@ They are all prefixed with `tgi_`. Descriptions will be added to the table below
 | `tgi_tokenize_request_input_count` | `counter` | | |
 | `tgi_tokenize_request_tokens` | `histogram` | | |
 | `tgi_tokenize_request_duration` | `histogram` | | |
+
+### Run Inference Locally with Intel(R) Extension for PyTorch*
+
+#### 0. Build the image
+
+```
+make build
+```
+
+This command will print the Docker image id for `text-gen-server`. Set `IMAGE_ID` in the commands below to this.
+
+#### 1. Run the server
+
+```
+export IMAGE_ID=
+export MODEL=
+export volume=$PWD/data
+mkdir $volume
+chmod 777 $volume
+```
+
+It's possible to use `text-generation-server download-weights`, but in this example we use a model that we download locally with `transformers-cli`.
+
+```
+transformers-cli download $MODEL
+```
+
+Move the model from `~/.cache/huggingface/hub/` to `$volume`. You can then run the inference server with:
+
+```
+docker run -p 8033:8033 -p 3000:3000 -e TRANSFORMERS_CACHE=/data -e HUGGINGFACE_HUB_CACHE=/data -e DEPLOYMENT_FRAMEWORK=hf_transformers_ipex -e MODEL_NAME=$MODEL -v $volume:/data $IMAGE_ID text-generation-launcher --dtype-str bfloat16
+```
+
+#### 2. Prepare the client
+
+Install gRPC in a Python environment: `pip install grpcio grpcio-tools`
+
+In the repository root, run:
+```
+python -m grpc_tools.protoc -Iproto --python_out=pb --pyi_out=pb --grpc_python_out=pb proto/generate.proto
+python -m grpc_tools.protoc -Iproto --python_out=pb --pyi_out=pb --grpc_python_out=pb proto/generation.proto
+```
+This generates the necessary files in the `pb` directory.
+
+Then to run inference:
+```
+python pb/client.py
+```
+
+Edit `pb/client.py` to change the prompts.
\ No newline at end of file
diff --git a/pb/client.py b/pb/client.py
new file mode 100644
index 00000000..157fa94e
--- /dev/null
+++ b/pb/client.py
@@ -0,0 +1,33 @@
+import json
+import time
+
+import grpc
+import requests
+from google.protobuf import json_format
+
+import generation_pb2 as pb2
+import generation_pb2_grpc as gpb2
+
+port = 8033
+channel = grpc.insecure_channel(f"localhost:{port}")
+stub = gpb2.GenerationServiceStub(channel)
+
+# warmup inference
+for i in range (5):
+    text = "hello world"
+    message = json_format.ParseDict(
+        {"requests": [{"text": text}]}, pb2.BatchedGenerationRequest()
+    )
+    response = stub.Generate(message)
+
+# time inference
+for prompt in ["The weather is", "The cat is walking on", "I would like to"]:
+# for prompt in ["def hello_world():"]:
+    message = json_format.ParseDict(
+        {"requests": [{"text": prompt}]}, pb2.BatchedGenerationRequest()
+    )
+    start = time.perf_counter()
+    response = stub.Generate(message)
+    end = time.perf_counter()
+    print(prompt, response)
+    print(f"Duration: {end-start:.2f}")
\ No newline at end of file
diff --git a/server/text_generation_server/inference_engine/hf_transformers_ipex.py b/server/text_generation_server/inference_engine/hf_transformers_ipex.py
new file mode 100644
index 00000000..ee002e9a
--- /dev/null
+++ b/server/text_generation_server/inference_engine/hf_transformers_ipex.py
@@ -0,0 +1,51 @@
+import os
+import torch
+import intel_extension_for_pytorch as ipex
+from transformers.models.auto.auto_factory import _BaseAutoModelClass
+
+from text_generation_server.inference_engine.engine import BaseInferenceEngine
+from text_generation_server.utils.hub import TRUST_REMOTE_CODE
+from typing import Any, Optional
+
+
+class InferenceEngine(BaseInferenceEngine):
+    def __init__(
+        self,
+        model_path: str,
+        model_class: type[_BaseAutoModelClass],
+        dtype: torch.dtype,
+        model_config: Optional[Any]
+    ) -> None:
+        super().__init__(model_path, model_config)
+
+        kwargs = {
+            "pretrained_model_name_or_path": model_path,
+            "local_files_only": True,
+            "trust_remote_code": TRUST_REMOTE_CODE,
+            "torchscript": 'jit',
+            "torch_dtype": dtype
+        }
+
+        if model_config.model_type == "mpt":
+            model_config.init_device = str(self.device)
+            kwargs["config"] = model_config
+
+        try:
+            ipex._C.disable_jit_linear_repack()
+        except Exception:
+            pass
+
+        torch._C._jit_set_texpr_fuser_enabled(False)
+
+        slow_but_exact = os.getenv('BLOOM_SLOW_BUT_EXACT', 'false').lower() == 'true'
+        if slow_but_exact:
+            kwargs["slow_but_exact"] = True
+
+        with self.device:
+            self.model = model_class.from_pretrained(**kwargs).requires_grad_(False).eval()
+
+        self.model = self.model.to(memory_format=torch.channels_last)
+        self.model = ipex.optimize_transformers(self.model, dtype=dtype, inplace=True)
+        print('Intel(R) Extension for PyTorch* enabled')
+
+        self.model.to(self.device)
\ No newline at end of file
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 7d2e856e..a0406e77 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -575,6 +575,9 @@ def __init__(
         _, past_key_values, _ = self.forward(input_ids=one_token, attention_mask=one_token)
         if torch.is_tensor(past_key_values[0]):
             self.batch_type = CombinedKVCausalLMBatch
+        elif 'ipex' in deployment_framework:
+            print(deployment_framework)
+            self.batch_type = CausalLMBatch
         else:
             # check the ordering of the key tensor dimensions
             key_past, value_past = past_key_values[0]

From f04ae08b35635af69d6e93428eafb7d6648f9150 Mon Sep 17 00:00:00 2001
From: kta-intel
Date: Wed, 17 Jan 2024 09:51:31 -0800
Subject: [PATCH 2/3] upgrade to latest IPEX 2.1.100

Signed-off-by: kta-intel
---
 Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Dockerfile b/Dockerfile
index c6a4e290..e16b3c6c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -6,7 +6,7 @@ ARG PYTORCH_VERSION=2.1.0
 # ARG PYTORCH_INDEX="https://download.pytorch.org/whl/nightly"
 # ARG PYTORCH_VERSION=2.3.0.dev20231221
 ARG IPEX_INDEX="https://pytorch-extension.intel.com/release-whl/stable/cpu/us/"
-ARG IPEX_VERSION=2.1.0
+ARG IPEX_VERSION=2.1.100
 
 
 ## Base Layer ##################################################################

From 3ccf76bfc44d64e2c8dd011010cabe734a29f4fb Mon Sep 17 00:00:00 2001
From: kta-intel
Date: Wed, 31 Jan 2024 10:19:48 -0800
Subject: [PATCH 3/3] fix input to deployment framework

Signed-off-by: kta-intel
---
 .../inference_engine/hf_transformers_ipex.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/server/text_generation_server/inference_engine/hf_transformers_ipex.py b/server/text_generation_server/inference_engine/hf_transformers_ipex.py
index ee002e9a..79153ae6 100644
--- a/server/text_generation_server/inference_engine/hf_transformers_ipex.py
+++ b/server/text_generation_server/inference_engine/hf_transformers_ipex.py
@@ -14,6 +14,7 @@ def __init__(
         model_path: str,
         model_class: type[_BaseAutoModelClass],
         dtype: torch.dtype,
+        quantize: Optional[str],
         model_config: Optional[Any]
     ) -> None:
         super().__init__(model_path, model_config)
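
As a usage sketch (not part of the patches above), the stubs generated from `generation.proto` can be driven the same way `pb/client.py` does it, but with prompts taken from the command line. The file name `pb/sketch_client.py`, the default prompt list, and the summary output are illustrative assumptions; the gRPC calls themselves mirror those used in the patch.

```python
# pb/sketch_client.py -- hypothetical variant of pb/client.py.
# Assumes generation_pb2 / generation_pb2_grpc were generated by the
# grpc_tools.protoc commands from the README and are importable.
import sys
import time

import grpc
from google.protobuf import json_format

import generation_pb2 as pb2
import generation_pb2_grpc as gpb2


def generate(stub, text):
    # Build a single-prompt BatchedGenerationRequest, as pb/client.py does.
    request = json_format.ParseDict(
        {"requests": [{"text": text}]}, pb2.BatchedGenerationRequest()
    )
    start = time.perf_counter()
    response = stub.Generate(request)
    return response, time.perf_counter() - start


if __name__ == "__main__":
    # Prompts can be passed on the command line; otherwise fall back to a fixed list.
    prompts = sys.argv[1:] or ["The weather is", "The cat is walking on"]

    # Port 8033 is the one published by the `docker run -p 8033:8033 ...` command in the README.
    channel = grpc.insecure_channel("localhost:8033")
    stub = gpb2.GenerationServiceStub(channel)

    for prompt in prompts:
        response, seconds = generate(stub, prompt)
        print(prompt, response)
        print(f"Duration: {seconds:.2f}s")
```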