
Commit ab2c4e8

Merge pull request IBM#54 from Xaenalt/rhoai-2.8.4-granite-attention
RHOAI 2.8.4 granite attention
2 parents: 3c432fb + 135402d · commit ab2c4e8

13 files changed, +2355 −298 lines changed

Dockerfile (+80 −91)
@@ -1,14 +1,15 @@
 ## Global Args #################################################################
-ARG BASE_UBI_IMAGE_TAG=9.4-1181
-ARG PROTOC_VERSION=25.2
+ARG BASE_UBI_IMAGE_TAG=latest
+ARG PROTOC_VERSION=25.3
 ARG PYTORCH_INDEX="https://download.pytorch.org/whl"
 # ARG PYTORCH_INDEX="https://download.pytorch.org/whl/nightly"
+ARG AUTO_GPTQ_VERSION=0.7.1
 
 # match PyTorch version that was used to compile flash-attention v2 pre-built wheels
 # e.g. flash-attn v2.5.2 => torch ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240126']
 # https://github.com/Dao-AILab/flash-attention/blob/v2.5.2/.github/workflows/publish.yml#L47
 # use nightly build index for torch .dev pre-release versions
-ARG PYTORCH_VERSION=2.2.0
+ARG PYTORCH_VERSION=2.2.1
 
 ARG PYTHON_VERSION=3.11

@@ -35,18 +36,19 @@ ENV LANG=C.UTF-8 \
 ## CUDA Base ###################################################################
 FROM base as cuda-base
 
-ENV CUDA_VERSION=11.8.0 \
-    NV_CUDA_LIB_VERSION=11.8.0-1 \
+# Ref: https://docs.nvidia.com/cuda/archive/12.1.0/cuda-toolkit-release-notes/
+ENV CUDA_VERSION=12.1.0 \
+    NV_CUDA_LIB_VERSION=12.1.0-1 \
     NVIDIA_VISIBLE_DEVICES=all \
     NVIDIA_DRIVER_CAPABILITIES=compute,utility \
-    NV_CUDA_CUDART_VERSION=11.8.89-1 \
-    NV_CUDA_COMPAT_VERSION=520.61.05-1
+    NV_CUDA_CUDART_VERSION=12.1.55-1 \
+    NV_CUDA_COMPAT_VERSION=530.30.02-1
 
 RUN dnf config-manager \
        --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
    && dnf install -y \
-       cuda-cudart-11-8-${NV_CUDA_CUDART_VERSION} \
-       cuda-compat-11-8-${NV_CUDA_COMPAT_VERSION} \
+       cuda-cudart-12-1-${NV_CUDA_CUDART_VERSION} \
+       cuda-compat-12-1-${NV_CUDA_COMPAT_VERSION} \
    && echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf \
    && echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf \
    && dnf clean all
@@ -56,53 +58,35 @@ ENV CUDA_HOME="/usr/local/cuda" \
     LD_LIBRARY_PATH="/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$CUDA_HOME/lib64:$CUDA_HOME/extras/CUPTI/lib64:${LD_LIBRARY_PATH}"
 
 
-## CUDA Runtime ################################################################
-FROM cuda-base as cuda-runtime
-
-ENV NV_NVTX_VERSION=11.8.86-1 \
-    NV_LIBNPP_VERSION=11.8.0.86-1 \
-    NV_LIBCUBLAS_VERSION=11.11.3.6-1 \
-    NV_LIBNCCL_PACKAGE_VERSION=2.15.5-1+cuda11.8
-
-RUN dnf config-manager \
-       --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
-   && dnf install -y \
-       cuda-libraries-11-8-${NV_CUDA_LIB_VERSION} \
-       cuda-nvtx-11-8-${NV_NVTX_VERSION} \
-       libnpp-11-8-${NV_LIBNPP_VERSION} \
-       libcublas-11-8-${NV_LIBCUBLAS_VERSION} \
-       libnccl-${NV_LIBNCCL_PACKAGE_VERSION} \
-   && dnf clean all
-
-
 ## CUDA Development ############################################################
 FROM cuda-base as cuda-devel
 
-ENV NV_CUDA_CUDART_DEV_VERSION=11.8.89-1 \
-    NV_NVML_DEV_VERSION=11.8.86-1 \
-    NV_LIBCUBLAS_DEV_VERSION=11.11.3.6-1 \
-    NV_LIBNPP_DEV_VERSION=11.8.0.86-1 \
-    NV_LIBNCCL_DEV_PACKAGE_VERSION=2.15.5-1+cuda11.8
+# Ref: https://developer.nvidia.com/nccl/nccl-legacy-downloads
+ENV NV_CUDA_CUDART_DEV_VERSION=12.1.55-1 \
+    NV_NVML_DEV_VERSION=12.1.55-1 \
+    NV_LIBCUBLAS_DEV_VERSION=12.1.0.26-1 \
+    NV_LIBNPP_DEV_VERSION=12.0.2.50-1 \
+    NV_LIBNCCL_DEV_PACKAGE_VERSION=2.18.3-1+cuda12.1
 
 RUN dnf config-manager \
        --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
    && dnf install -y \
-       cuda-command-line-tools-11-8-${NV_CUDA_LIB_VERSION} \
-       cuda-libraries-devel-11-8-${NV_CUDA_LIB_VERSION} \
-       cuda-minimal-build-11-8-${NV_CUDA_LIB_VERSION} \
-       cuda-cudart-devel-11-8-${NV_CUDA_CUDART_DEV_VERSION} \
-       cuda-nvml-devel-11-8-${NV_NVML_DEV_VERSION} \
-       libcublas-devel-11-8-${NV_LIBCUBLAS_DEV_VERSION} \
-       libnpp-devel-11-8-${NV_LIBNPP_DEV_VERSION} \
+       cuda-command-line-tools-12-1-${NV_CUDA_LIB_VERSION} \
+       cuda-libraries-devel-12-1-${NV_CUDA_LIB_VERSION} \
+       cuda-minimal-build-12-1-${NV_CUDA_LIB_VERSION} \
+       cuda-cudart-devel-12-1-${NV_CUDA_CUDART_DEV_VERSION} \
+       cuda-nvml-devel-12-1-${NV_NVML_DEV_VERSION} \
+       libcublas-devel-12-1-${NV_LIBCUBLAS_DEV_VERSION} \
+       libnpp-devel-12-1-${NV_LIBNPP_DEV_VERSION} \
        libnccl-devel-${NV_LIBNCCL_DEV_PACKAGE_VERSION} \
    && dnf clean all
 
 ENV LIBRARY_PATH="$CUDA_HOME/lib64/stubs"
 
 
 ## Rust builder ################################################################
-# Specific debian version so that compatible glibc version is used
-FROM rust:1.76-bullseye as rust-builder
+# Using bookworm for compilation so the rust binaries get linked against libssl.so.3
+FROM rust:1.78-bookworm as rust-builder
 ARG PROTOC_VERSION
 
 ENV CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -180,6 +164,9 @@ RUN cd server && \
     make gen-server && \
     pip install ".[accelerate]" --no-cache-dir
 
+# temp: install newer transformers lib that optimum clashes with
+RUN pip install transformers==4.40.0 tokenizers==0.19.1 --no-cache-dir
+
 # Patch codegen model changes into transformers
 RUN cp server/transformers_patch/modeling_codegen.py ${SITE_PACKAGES}/transformers/models/codegen/modeling_codegen.py

@@ -218,12 +205,12 @@ ENV PATH=/opt/tgis/bin/:$PATH
 # Install specific version of torch
 RUN pip install ninja==1.11.1.1 --no-cache-dir
 RUN pip install packaging --no-cache-dir
-RUN pip install torch==$PYTORCH_VERSION+cu118 --index-url "${PYTORCH_INDEX}/cu118" --no-cache-dir
+RUN pip install torch==$PYTORCH_VERSION+cu121 --index-url "${PYTORCH_INDEX}/cu121" --no-cache-dir
 
 
 ## Build flash attention v2 ####################################################
 FROM python-builder as flash-att-v2-builder
-ARG FLASH_ATT_VERSION=v2.5.2
+ARG FLASH_ATT_VERSION=v2.5.6
 
 WORKDIR /usr/src/flash-attention-v2

@@ -237,14 +224,15 @@ RUN MAX_JOBS=2 pip --verbose wheel --no-deps flash-attn==${FLASH_ATT_VERSION} \
 
 
 ## Install auto-gptq ###########################################################
-FROM python-builder as auto-gptq-installer
-ARG AUTO_GPTQ_REF=ccb6386ebfde63c17c45807d38779a93cd25846f
-
-WORKDIR /usr/src/auto-gptq-wheel
-
-# numpy is required to run auto-gptq's setup.py
-RUN pip install numpy
-RUN DISABLE_QIGEN=1 pip wheel git+https://github.com/AutoGPTQ/AutoGPTQ@${AUTO_GPTQ_REF} --no-cache-dir --no-deps --verbose
+## Uncomment if a custom autogptq build is required
+#FROM python-builder as auto-gptq-installer
+#ARG AUTO_GPTQ_REF=896d8204bc89a7cfbda42bf3314e13cf4ce20b02
+#
+#WORKDIR /usr/src/auto-gptq-wheel
+#
+## numpy is required to run auto-gptq's setup.py
+#RUN pip install numpy
+#RUN DISABLE_QIGEN=1 pip wheel git+https://github.com/AutoGPTQ/AutoGPTQ@${AUTO_GPTQ_REF} --no-cache-dir --no-deps --verbose
 
 ## Build libraries #############################################################
 FROM python-builder as build
@@ -254,75 +242,76 @@ COPY server/custom_kernels/ /usr/src/.
 RUN cd /usr/src && python setup.py build_ext && python setup.py install
 
 
-## Build transformers exllama kernels ##########################################
-FROM python-builder as exllama-kernels-builder
-
-WORKDIR /usr/src
-
-COPY server/exllama_kernels/ .
-RUN python setup.py build
-
-
-## Build transformers exllamav2 kernels ########################################
-FROM python-builder as exllamav2-kernels-builder
-
-WORKDIR /usr/src
-
-COPY server/exllamav2_kernels/ .
-RUN python setup.py build
-
-
 ## Flash attention v2 cached build image #######################################
 FROM base as flash-att-v2-cache
 
 # Copy just the wheels we built for flash-attention
 COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2 /usr/src/flash-attention-v2
 
 
-## Auto gptq cached build image
-FROM base as auto-gptq-cache
+## Auto gptq cached build image ################################################
+## Uncomment if a custom autogptq build is required
+#FROM base as auto-gptq-cache
+#
+## Copy just the wheel we built for auto-gptq
+#COPY --from=auto-gptq-installer /usr/src/auto-gptq-wheel /usr/src/auto-gptq-wheel
 
-# Copy just the wheel we built for auto-gptq
-COPY --from=auto-gptq-installer /usr/src/auto-gptq-wheel /usr/src/auto-gptq-wheel
 
+## Full set of python installations for server release #########################
+
+FROM python-builder as python-installations
 
-## Final Inference Server image ################################################
-FROM cuda-runtime as server-release
 ARG PYTHON_VERSION
+ARG AUTO_GPTQ_VERSION
 ARG SITE_PACKAGES=/opt/tgis/lib/python${PYTHON_VERSION}/site-packages
 
-# Install C++ compiler (required at runtime when PT2_COMPILE is enabled)
-RUN dnf install -y gcc-c++ git && dnf clean all \
-    && useradd -u 2000 tgis -m -g 0
-
-SHELL ["/bin/bash", "-c"]
-
 COPY --from=build /opt/tgis /opt/tgis
 
+# `pip` is installed in the venv here
 ENV PATH=/opt/tgis/bin:$PATH
 
 # Install flash attention v2 from the cache build
 RUN --mount=type=bind,from=flash-att-v2-cache,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
     pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
 
-# Copy build artifacts from exllama kernels builder
-COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-* ${SITE_PACKAGES}
-
-# Copy build artifacts from exllamav2 kernels builder
-COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-* ${SITE_PACKAGES}
-
 # Copy over the auto-gptq wheel and install it
-RUN --mount=type=bind,from=auto-gptq-cache,src=/usr/src/auto-gptq-wheel,target=/usr/src/auto-gptq-wheel \
-    pip install /usr/src/auto-gptq-wheel/*.whl --no-cache-dir
+#RUN --mount=type=bind,from=auto-gptq-cache,src=/usr/src/auto-gptq-wheel,target=/usr/src/auto-gptq-wheel \
+#    pip install /usr/src/auto-gptq-wheel/*.whl --no-cache-dir
+
+# We only need to install a custom-built auto-gptq version if we need a pre-release
+# or are using a PyTorch nightly version
+RUN pip install auto-gptq=="${AUTO_GPTQ_VERSION}" --no-cache-dir
 
 # Install server
+# git is required to pull the fms-extras dependency
+RUN dnf install -y git && dnf clean all
 COPY proto proto
 COPY server server
-RUN cd server && make gen-server && pip install ".[accelerate, ibm-fms, onnx-gpu, quantize]" --no-cache-dir
+# Extra url is required to install cuda-12 version of onnxruntime-gpu
+# Ref: https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x
+RUN cd server && make gen-server && pip install ".[accelerate, ibm-fms, onnx-gpu, quantize]" --no-cache-dir --extra-index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
+
+# temp: install newer transformers lib that optimum clashes with
+RUN pip install transformers==4.40.0 tokenizers==0.19.1 --no-cache-dir
 
 # Patch codegen model changes into transformers 4.35
 RUN cp server/transformers_patch/modeling_codegen.py ${SITE_PACKAGES}/transformers/models/codegen/modeling_codegen.py
 
+
+## Final Inference Server image ################################################
+FROM base as server-release
+ARG PYTHON_VERSION
+ARG SITE_PACKAGES=/opt/tgis/lib/python${PYTHON_VERSION}/site-packages
+
+# Install C++ compiler (required at runtime when PT2_COMPILE is enabled)
+RUN dnf install -y gcc-c++ && dnf clean all \
+    && useradd -u 2000 tgis -m -g 0
+
+# Copy in the full python environment
+COPY --from=python-installations /opt/tgis /opt/tgis
+
+ENV PATH=/opt/tgis/bin:$PATH
+
 # Print a list of all installed packages and versions
 RUN pip list -v --disable-pip-version-check --no-python-version-warning
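One quick way to confirm that an image built from this Dockerfile picked up the intended toolchain is to query the installed packages at runtime. The check below is not part of the commit; it is a minimal sketch that only assumes the versions pinned above (torch 2.2.1 built for CUDA 12.1, transformers 4.40.0) and standard torch/transformers version attributes.

# Hypothetical post-build sanity check (not in this commit): run inside the
# server-release image to confirm the versions pinned in the Dockerfile.
import torch
import transformers

# torch 2.2.1 for CUDA 12.1, per PYTORCH_VERSION and the cu121 wheel index
assert torch.__version__.startswith("2.2.1"), torch.__version__
assert torch.version.cuda == "12.1", torch.version.cuda
# transformers 4.40.0 is force-installed over optimum's older pin
assert transformers.__version__ == "4.40.0", transformers.__version__

print("torch", torch.__version__, "cuda", torch.version.cuda)
print("transformers", transformers.__version__)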

integration_tests/sample_client.py (+85 −0)
@@ -0,0 +1,85 @@
+import time
+import grpc
+from google.protobuf import json_format
+from text_generation_tests.pb import generation_pb2_grpc as gpb2, generation_pb2 as pb2
+
+
+def get_streaming_response_tgis(response):
+    stop = False
+    generated_tokens = 0
+    while not stop:
+        try:
+            x = next(response)
+            timestamp = time.time_ns()
+            data = json_format.MessageToDict(x)
+            # skip first response (tokenizer output only)
+            if "inputTokenCount" not in data:
+                n_tokens = data["generatedTokenCount"] - generated_tokens
+                generated_tokens = data["generatedTokenCount"]
+                yield data, n_tokens, timestamp, True, None
+        except Exception as e:
+            timestamp = time.time_ns()
+            yield None, 0, timestamp, False, e
+
+
+channel = grpc.insecure_channel("localhost:8033")
+stub = gpb2.GenerationServiceStub(channel)
+max_new_tokens = 100
+
+template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"
+num_req = 0
+while True:
+    prompt_input = input(f"\n{num_req}) Enter a prompt:\n")
+
+    print("-" * 40)
+    print("Output:")
+    prompt = template.format(prompt_input)
+    sample_request = {
+        "model_id": "dummy-model-name",
+        "request": {"text": prompt},
+        "params": {
+            "method": "GREEDY",
+            "stopping": {
+                "max_new_tokens": max_new_tokens,
+                "min_new_tokens": max_new_tokens,
+            },
+        },
+    }
+    message = json_format.ParseDict(sample_request, pb2.SingleGenerationRequest())
+    output = []
+    total_time = 0
+    response = stub.GenerateStream(message)
+    response_generator = get_streaming_response_tgis(response)
+    t0 = time.time_ns()
+    response = ""
+    stop = False
+    while not stop:
+        r, n_tokens, t, ok, err = next(response_generator)
+
+        if not ok:
+            stop = True
+            # check if we have reached end of stream
+            if type(err) is StopIteration:
+                continue
+        duration = (t - t0) / 1000.0 / 1000.0
+        record = {
+            "response": r,
+            "ok": ok,
+            "error": str(err),
+            "timestamp": t,
+            "duration_ms": duration,
+            "n_tokens": n_tokens,
+        }
+        total_time += duration
+        response += r["text"]
+        output.append(record)
+        t0 = t
+
+    # print(json.dumps(output, indent=4))
+    print("-" * 40)
+    print(response)
+    print("-" * 40)
+    print(f"Total_time : {total_time}ms")
+    print(f"Time_per_token : {total_time/max_new_tokens}ms")
+    print("-" * 40)
+    num_req += 1
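If you want more than the total-time figures the script prints, the per-token records it accumulates in `output` can be aggregated. The helper below is a hypothetical addition, not part of the commit; `summarize` is an invented name, and the field names are the ones used in the `record` dict above.

# Hypothetical helper (not in the commit): aggregate the per-token records
# collected in `output` by the streaming loop above.
from statistics import mean, median

def summarize(records):
    # keep only successful responses that actually carried generated tokens
    durations = [r["duration_ms"] for r in records if r["ok"] and r["n_tokens"] > 0]
    if not durations:
        return None
    return {
        "responses": len(durations),
        "mean_ms": mean(durations),
        "median_ms": median(durations),
        "max_ms": max(durations),
    }

# usage, after the inner while loop finishes: print(summarize(output))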

router/src/batcher.rs (+13 −2)
@@ -723,8 +723,19 @@ impl<'a> TokenProcessor<'a> {
         let request_id = output.request_id;
         let next_token_id = output.token_id;
 
-        let e = self.entries.get_mut(&request_id)
-            .expect("ID not found. This is a bug.");
+        let e = self.entries.get_mut(&request_id);
+
+        // if a client cancelled a request and speculative decoding is
+        // enabled, it's possible that the request will get removed
+        // from entries table, but there can still be tokens in outputs stream
+        // corresponding to that request. ideally we could defer removing
+        // the request_id from the entries table until all tokens have been
+        // processed...but for now let's just ignore them.
+        if e.is_none() {
+            continue;
+        }
+
+        let e = e.unwrap();
 
         let is_stream = e.stream_tx.is_some();
         let stop_seqs = &e.request.parameters.stop_seqs;
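The comment in this change describes a race: with speculative decoding enabled, a cancelled request can be removed from the entries table while some of its tokens are still in the output stream, so the token processor now skips those outputs instead of panicking. A minimal Python sketch of that behaviour (illustration only, not the server code; the dict and tuples stand in for the entries table and the output stream):

# entries maps request_id -> in-flight request state; request 2 was cancelled
# and already removed, but its speculative tokens are still in the stream.
entries = {1: "request-1"}
outputs = [(1, "tok_a"), (2, "tok_b"), (1, "tok_c")]

for request_id, token in outputs:
    entry = entries.get(request_id)
    if entry is None:
        # stale output for a request that no longer exists: ignore it
        continue
    print(f"routing {token!r} to {entry}")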
