[MM][Core] Decouple ViT backend from LM backend #27061
base: main
Conversation
Signed-off-by: Roger Wang <[email protected]>
💡 Codex Review
Lines 316 to 341 in 6a588c0
```python
def compute_hash(self) -> str:
    """
    WARNING: Whenever a new field is added to this config,
    ensure that it is included in the factors list if
    it affects the computation graph.

    Provide a hash that uniquely identifies all the configs
    that affect the structure of the computation
    graph from input ids/embeddings to the final hidden states,
    excluding anything before input ids/embeddings and after
    the final hidden states.
    """
    factors: list[Any] = []
    factors.append(self.model)
    factors.append(self.dtype)
    factors.append(self.quantization)
    factors.append(self.revision)
    factors.append(self.code_revision)
    factors.append(self.max_model_len)
    factors.append(self.max_logprobs)
    factors.append(self.disable_sliding_window)
    factors.append(self.trust_remote_code)
    factors.append(self.generation_config)
    factors.append(self.model_impl)
    factors.append(self.override_generation_config)
    factors.append(self.rope_scaling)
```
The new `mm_encoder_attn_backend` affects which attention kernel the vision encoder uses, but `ModelConfig.compute_hash()` still ignores the entire multimodal configuration. `VllmConfig.compute_hash()` relies on this value when keying compilation caches. As a result, running once with the default backend and later with `--mm-encoder-attn-backend=FLASH_ATTN` will produce identical hashes and may reuse compiled artifacts for the wrong backend. Please incorporate `self.multimodal_config.compute_hash()` (when present) into the model hash so that changing the ViT backend invalidates cached compilations.
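A minimal sketch of the suggested fix, continuing the `compute_hash()` body shown above (it assumes the multimodal config exposes a `compute_hash()` method of its own, as the review implies; attribute names may not match vLLM's actual layout):

```python
    # ...existing factors as above...
    factors.append(self.rope_scaling)

    # Sketch only: fold the multimodal config (which now carries
    # mm_encoder_attn_backend) into the hash so that changing the ViT
    # backend invalidates cached compilations.
    if self.multimodal_config is not None:
        factors.append(self.multimodal_config.compute_hash())
```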
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
💡 Codex Review
Here are some automated review suggestions for this pull request.
```diff
 self.attn_backend = get_vit_attn_backend(
-    head_size=head_dim, dtype=torch.get_default_dtype()
+    head_size=head_dim,
+    dtype=torch.get_default_dtype(),
+    attn_backend_override=attn_backend_override,
 )
 if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
     torch.get_default_dtype()
```
Override still forced to upstream flash attention

The newly added `mm_encoder_attn_backend` override is returned by `get_vit_attn_backend`, but the following block still promotes any non-FLASH backend to `_Backend.FLASH_ATTN` whenever upstream flash attention is available. This means a user who explicitly passes `--mm-encoder-attn-backend=TRITON_ATTN` (or any other backend) will still run with flash attention on systems where it is present, defeating the purpose of the override and keeping ViT and LM backends coupled. Consider skipping this auto-upgrade when an explicit override was provided.
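One possible shape of that fix, sketched against the diff above (names mirror the snippet; the exact surrounding code in vLLM may differ):

```python
self.attn_backend = get_vit_attn_backend(
    head_size=head_dim,
    dtype=torch.get_default_dtype(),
    attn_backend_override=attn_backend_override,
)
# Only auto-upgrade to upstream flash attention when the user did NOT
# explicitly request a ViT backend; otherwise honor the override.
if (
    attn_backend_override is None
    and self.attn_backend != _Backend.FLASH_ATTN
    and check_upstream_fa_availability(torch.get_default_dtype())
):
    self.attn_backend = _Backend.FLASH_ATTN
```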
```python
# Determine the attention backend
backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
attn_backend_override = None
```
Should we rename this layer to `VisionAttention` btw?
This layer will be renamed to `MMEncoderAttention` in #27147. But we can also rename it here.
Let's do that in the other PR then
```python
if candidate is not None:
    return candidate

valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
```
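For context, a self-contained sketch of how such an override string might be resolved and validated (the enum members and helper name below are illustrative, not vLLM's exact definitions):

```python
from enum import Enum, auto


class BackendEnum(Enum):
    # Illustrative members only; the real registry has more backends.
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    TRITON_ATTN = auto()


def resolve_backend_override(value: str | None) -> BackendEnum | None:
    """Return the requested backend, or None when no override was given."""
    if value is None:
        return None
    candidate = BackendEnum.__members__.get(value.upper())
    if candidate is not None:
        return candidate
    valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
    raise ValueError(
        f"Invalid mm encoder attention backend {value!r}. "
        f"Valid options: {valid_backends}"
    )
```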
Perhaps we can add a `supported_vit_backend` hook to the Platform interface in a follow-up PR to detect an invalid backend for a specific platform before initializing the model.
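A rough illustration of what that hook could look like on a Platform subclass (method names and the backend set are hypothetical, not an existing vLLM API):

```python
@classmethod
def supported_vit_backends(cls) -> set["_Backend"]:
    # Hypothetical: the backends this platform can actually run for the
    # multimodal encoder (the set here is illustrative).
    from vllm.attention.backends.registry import _Backend

    return {_Backend.FLASH_ATTN, _Backend.TORCH_SDPA}


@classmethod
def validate_vit_attn_backend(cls, backend: "_Backend") -> None:
    # Fail fast, before model initialization, if the requested override
    # cannot be served by this platform.
    if backend not in cls.supported_vit_backends():
        raise ValueError(
            f"{backend.name} is not a supported ViT attention backend "
            f"on {cls.__name__}."
        )
```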
Yeah, right now it'll just show all possible `_Backend` values, with some of them getting auto-resolved inside their corresponding `platform.get_vit_attn_backend`. For example:

Lines 203 to 211 in 9fce7be
```python
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
    from vllm.attention.backends.registry import _Backend

    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
        return _Backend.ROCM_AITER_FA
    if on_gfx9():
        return _Backend.FLASH_ATTN
    return _Backend.TORCH_SDPA
```
I think we can shrink this selection by just having a specific `_MHA_Backend` enum.
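For illustration, such a narrowed enum might look like the following (the member set is hypothetical and would need auditing against what each platform's encoder path actually supports):

```python
from enum import Enum, auto


class _MHA_Backend(Enum):
    """Hypothetical enum restricted to backends the multimodal encoder
    (plain multi-head attention) can use, so platform selection cannot
    hand back an LM-only backend."""

    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    TORCH_SDPA = auto()
```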
Purpose
The multimodal encoder attention backend for a selection of models has been adopting the same backend as the language model backbone (via env var). This has exposed challenges and inflexibilities, since some backends may work for the LM but not for the multimodal encoder.
This PR refactors the ViT backend selection for models that use `get_vit_attn_backend` so that users can specify `--mm-encoder-attn-backend` as an override. Model-specific attention post-selection logic will stay as is (for example, using upstream FA whenever it's available).
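For example, the override can be supplied on the command line (the model name below is only an illustration):

```bash
# Keep the LM attention backend auto-selected, but pin the ViT/encoder
# attention backend explicitly.
vllm serve Qwen/Qwen2.5-VL-7B-Instruct \
    --mm-encoder-attn-backend=FLASH_ATTN
```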
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
- Update `supported_models.md` and `examples` for a new model.