Commit 48df618

feat(gaudi): add all the changes from tgi-gaudi fork up to PR #289
1 parent f91434e commit 48df618

File tree: 10 files changed, +907 -1236 lines changed


Dockerfile_gaudi (+2 -3)

@@ -1,6 +1,6 @@
 # Those arguments are required to build the image
-ARG HABANA_VERSION=1.19.0
-ARG PYTORCH_VERSION=2.5.1
+ARG HABANA_VERSION=1.20.0
+ARG PYTORCH_VERSION=2.6.0

 # Rust builder
 FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
@@ -92,7 +92,6 @@ RUN cd server && \
     make gen-server && \
     pip install --no-deps -r requirements.txt && \
     bash ./dill-0.3.8-patch.sh && \
-    pip install outlines~=0.0.34 && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir

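The version bumps above (HABANA_VERSION 1.20.0, PYTORCH_VERSION 2.6.0) and the removal of the ad-hoc outlines install can be sanity-checked from inside the built image. A minimal sketch, not part of the commit; the "habana-torch-plugin" distribution name for the Gaudi PyTorch bridge is an assumption:

    from importlib.metadata import version

    # Versions expected after this commit; adjust distribution names if the
    # Habana base image packages them differently.
    print(version("torch"))                # should start with 2.6.0
    print(version("optimum-habana"))       # pinned to 1.16.0 further down
    print(version("habana-torch-plugin"))  # assumed name of the Gaudi PyTorch bridge
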
backends/gaudi/Makefile (+2 -2)

@@ -2,8 +2,8 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 root_dir := "${mkfile_dir}/../.."

-HABANA_VERSION := 1.19.0
-PYTORCH_VERSION := 2.5.1
+HABANA_VERSION := 1.20.0
+PYTORCH_VERSION := 2.6.0

 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install

backends/gaudi/server/pyproject.toml (+1 -1)

@@ -22,7 +22,7 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 peft = "^0.10"
-optimum-habana = "1.15.0"
+optimum-habana = "1.16.0"
 transformers = "4.45.2"
 numpy = "1.26.4"
 accelerate = "0.33.0"

backends/gaudi/server/requirements.txt (+16 -1)

@@ -46,7 +46,7 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-optimum-habana==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+optimum-habana==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 optimum==1.23.2 ; python_version >= "3.9" and python_version < "3.13"
 packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
 pandas==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
@@ -87,3 +87,18 @@ wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
+outlines==0.0.34 ; python_version >= "3.9" and python_version < "3.13"
+interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
+lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
+cloudpickle==3.1.0 ; python_version >= "3.9" and python_version < "3.13"
+diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
+numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
+llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
+jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
+annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
+jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
+nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
+pydantic==2.10.6 ; python_version >= "3.9" and python_version < "3.13"
+pydantic-core==2.27.2 ; python_version >= "3.9" and python_version < "3.13"
+referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
+rpds-py==0.22.3 ; python_version >= "3.9" and python_version < "3.13"

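With the ad-hoc `pip install outlines~=0.0.34` dropped from the Dockerfile, the guided-decoding stack is now pinned here instead. A quick import smoke test, as a sketch (not part of the commit), to confirm the pins resolve together in the server environment:

    # Assumes the requirements pinned above are installed in the server venv.
    from importlib.metadata import version

    for dist in ("outlines", "interegular", "lark", "numba", "pydantic"):
        print(dist, version(dist))  # e.g. outlines 0.0.34, pydantic 2.10.6
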
backends/gaudi/server/text_generation_server/models/__init__.py (+17 -5)

@@ -17,16 +17,14 @@
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-
-# from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
+from text_generation_server.models.custom_modeling.mllama import (
+    MllamaForConditionalGeneration,
+)
 from text_generation_server.models.custom_modeling.llava_next import (
     LlavaNextForConditionalGeneration,
 )

 # from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
-# from text_generation_server.models.custom_modeling.mllama import (
-#     MllamaForConditionalGeneration,
-# )
 from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
@@ -39,6 +37,7 @@
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


+SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
 # Disable gradients
 torch.set_grad_enabled(False)

@@ -55,6 +54,8 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     adapt_transformers_to_gaudi()
+    if SDP_ON_BF16 == 1:
+        torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)

     if speculate is not None:
         set_speculate(speculate)
@@ -199,6 +200,17 @@ def get_model(
         trust_remote_code=trust_remote_code,
     )

+    if model_type == "mllama":
+        return VlmCausalLM(
+            model_class=MllamaForConditionalGeneration,
+            model_id=model_id,
+            revision=revision,
+            quantize=None,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
             model_id,

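The new SDP_ON_BF16 switch is read once at module import time, so it must be exported before text_generation_server.models is imported. A minimal sketch of the intended, environment-driven usage (not part of the commit):

    import os

    # Opt in to fp16/bf16 accumulation inside math scaled-dot-product attention.
    os.environ["SDP_ON_BF16"] = "1"

    # The flag is captured when the module is imported; get_model() then calls
    # torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) before dispatching,
    # e.g. routing model_type == "mllama" to VlmCausalLM with
    # MllamaForConditionalGeneration as added above.
    from text_generation_server.models import get_model  # noqa: E402
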
backends/gaudi/server/text_generation_server/models/causal_lm.py (+3)

@@ -704,6 +704,9 @@ def __init__(
         htorch.core.hpu_set_env()

         if world_size > 1:
+            os.environ.setdefault(
+                "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
+            )
             model = self.get_deepspeed_model(model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:

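The deterministic-API variable is injected with os.environ.setdefault, so it only takes effect when the operator has not already chosen a value; an explicit export always wins. A small illustration (not part of the commit):

    import os

    # Operator explicitly disables the deterministic API before launch...
    os.environ["DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API"] = "0"
    # ...so the server's setdefault call above becomes a no-op.
    os.environ.setdefault("DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1")
    print(os.environ["DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API"])  # -> "0"
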
backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py (+118 -35)

@@ -14,10 +14,11 @@
 # limitations under the License.
 """ PyTorch Llava-NeXT model."""

-from typing import List, Optional
+from typing import List, Optional, Union

 import torch
 import torch.utils.checkpoint
+import numpy as np

 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
@@ -49,6 +50,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size


+# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
+def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
+    """
+    Calculate the number of patches after the preprocessing for images of any resolution.
+
+    Args:
+        image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
+            The size of the input image in the format (height, width). ?
+        grid_pinpoints (`List`):
+            A list containing possible resolutions. Each item in the list should be a tuple or list
+            of the form `(height, width)`.
+        patch_size (`int`):
+            The size of each image patch.
+
+    Returns:
+        int: the number of patches
+    """
+    if not isinstance(grid_pinpoints, list):
+        raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
+    if not isinstance(image_size, (list, tuple)):
+        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+            raise TypeError(
+                f"image_size invalid type {type(image_size)} with value {image_size}"
+            )
+        image_size = image_size.tolist()
+
+    best_resolution = select_best_resolution(image_size, grid_pinpoints)
+    height, width = best_resolution
+    num_patches = 0
+    # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
+    for i in range(0, height, patch_size):
+        for j in range(0, width, patch_size):
+            num_patches += 1
+    # add the base patch
+    num_patches += 1
+    return num_patches
+
+
 class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

     def _merge_input_ids_with_image_features(
@@ -128,6 +169,76 @@ def forward(

         return outputs

+    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        image_sizes: torch.Tensor,
+        vision_feature_layer: Union[int, List[int]],
+        vision_feature_select_strategy: str,
+    ):
+        """
+        Obtains image last hidden states from the vision tower and apply multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
+                The tensors corresponding to the input images.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each images (H, W).
+            vision_feature_layer (`Union[int, List[int]]`):
+                The index of the layer to select the vision feature. If multiple indices are provided,
+                the vision feature of the corresponding indices will be concatenated to form the
+                vision features.
+            vision_feature_select_strategy (`str`):
+                The feature selection strategy used to select the vision feature from the vision backbone.
+                Can be one of `"default"` or `"full"`
+        Returns:
+            image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
+            and are of shape `(num_patches, image_length, embed_dim)`).
+        """
+        # ! infer image_num_patches from image_sizes
+        image_num_patches = [
+            image_size_to_num_patches(
+                image_size=imsize,
+                grid_pinpoints=self.config.image_grid_pinpoints,
+                patch_size=self.config.vision_config.image_size,
+            )
+            for imsize in image_sizes
+        ]
+        if pixel_values.dim() == 5:
+            # stacked if input is (batch_size, num_patches, num_channels, height, width)
+            _pixel_values_list = [
+                pix_val[:num_patch]
+                for pix_val, num_patch in zip(pixel_values, image_num_patches)
+            ]
+            pixel_values = torch.cat(_pixel_values_list, dim=0)
+        elif pixel_values.dim() != 4:
+            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
+            raise ValueError(
+                f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
+            )
+
+        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
+        # If we have one vision feature layer, return the corresponding hidden states,
+        # otherwise, select the hidden states of each feature layer and concatenate them
+        if isinstance(vision_feature_layer, int):
+            selected_image_feature = image_features.hidden_states[vision_feature_layer]
+        else:
+            hs_pool = [
+                image_features.hidden_states[layer_idx]
+                for layer_idx in vision_feature_layer
+            ]
+            selected_image_feature = torch.cat(hs_pool, dim=-1)
+
+        if vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+
+        image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = torch.split(image_features, image_num_patches, dim=0)
+        return image_features
+
     def prepare_inputs_for_generation(
         self,
         input_ids,
@@ -184,35 +295,12 @@ def prepare_inputs_for_generation(
             # 1. Extract the input embeddings
             inputs_embeds = self.get_input_embeddings()(input_ids)
             # 2. Merge text and images
-            batch_size, num_patches, num_channels, height, width = (
-                pixel_values.shape
+            image_features = self.get_image_features(
+                pixel_values,
+                image_sizes,
+                vision_feature_layer=vision_feature_layer,
+                vision_feature_select_strategy=vision_feature_select_strategy,
             )
-            reshaped_pixel_values = pixel_values.view(
-                batch_size * num_patches, num_channels, height, width
-            )
-            image_features = self.vision_tower(
-                reshaped_pixel_values,
-                output_hidden_states=True,
-                use_flash_attention=use_flash_attention,
-                flash_attention_recompute=flash_attention_recompute,
-            )
-
-            selected_image_feature = image_features.hidden_states[
-                vision_feature_layer
-            ]
-
-            if vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-
-            image_features = self.multi_modal_projector(selected_image_feature)
-
-            # split up image_features for each of the individual images
-            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-            # if we assume each image has 5 image features (base image + 4 patches)
-            split_sizes = [image.shape[0] for image in pixel_values]
-            image_features = torch.split(image_features, split_sizes, dim=0)

             # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
             height = width = (
@@ -266,13 +354,10 @@ def prepare_inputs_for_generation(
                        (image_feature, self.image_newline[None]), dim=0
                    )
                new_image_features.append(image_feature)
-            image_features = torch.stack(new_image_features, dim=0)
+            image_features = torch.cat(new_image_features, dim=0)
            inputs_embeds = self._merge_input_ids_with_image_features(
                inputs_embeds, image_features, input_ids
            )
-            self.image_offset = (
-                image_features.shape[1] - 1
-            ) # image_token has occupied 1 token position.
        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
        # generation with cache
        elif past_key_values is not None:
@@ -282,12 +367,10 @@ def prepare_inputs_for_generation(
            # Retrieve the first layer to inspect the logits and mask out the hidden states
            # that are set to 0
            first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
            # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
            batch_index, non_attended_tokens = torch.where(
                first_layer_past_key_value.float().sum(-2) == 0
            )
-
            # Get the target length
            past_length = first_layer_past_key_value.shape[-1]
            extended_attention_mask = torch.ones(

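For reference, the patch counting in image_size_to_num_patches reduces to ceil(height / patch_size) * ceil(width / patch_size) + 1 over the best-fit resolution, as its own comment notes. A worked sketch (not part of the commit) using hypothetical but typical Llava-NeXT settings of 336-pixel vision patches:

    import math

    def approx_num_patches(best_resolution, patch_size=336):
        # One patch per patch_size x patch_size tile of the best-fit resolution,
        # plus one base patch for the resized full image.
        height, width = best_resolution
        return math.ceil(height / patch_size) * math.ceil(width / patch_size) + 1

    print(approx_num_patches((672, 672)))   # 2 * 2 + 1 = 5
    print(approx_num_patches((1008, 336)))  # 3 * 1 + 1 = 4

Because each image can now contribute a different number of patches, get_image_features splits the projected features with torch.split(image_features, image_num_patches, dim=0), and prepare_inputs_for_generation concatenates the per-image features with torch.cat instead of torch.stack, which would require equal-sized tensors.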