fix vispruner bugs and update holitom_merge #431
Summary of Changes
Hello @zhangbilang, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly refines the token reduction capabilities within the llmc framework, particularly for models handling visual inputs. It addresses several critical bugs in the vispruner module, introduces an advanced attention-based token merging strategy (holitom_merge), and cleans up various codebases by standardizing import error handling and improving parameter consistency. These changes collectively aim to enhance the efficiency, robustness, and configurability of visual token processing.
Highlights
- Bug Fixes in Visual Pruning (VisPruner): Corrected the calculation of the number of visual tokens to prune, ensuring it is based on the actual input image features rather than a fixed configuration. Additionally, a redundant image feature transformation step was removed, and handling of `mm_patch_merge_type` was adjusted.
- Enhanced Token Merging (Holitom Merge): Introduced a sophisticated token merging strategy that prunes visual tokens based on their attention scores. This involves capturing attention from a preceding layer, identifying the most important tokens, and merging less important ones into their most similar counterparts, thereby reducing sequence length while aiming to preserve critical information (see the sketch after this list).
- Improved Import Error Handling: Standardized the handling of optional LLaVA module imports across multiple token reduction files (dycoke, fastvid, holitom, prunevid, utils) by changing `ModuleNotFoundError` to the more general `ImportError` and suppressing verbose logging on failure, indicating these are optional dependencies.
- Code Cleanup and Parameter Consistency: Removed unused imports (`wraps`, `logger`, `F`) and simplified configuration access in `tome.py` and `random.py`. Parameter names related to visual token pruning were also updated for consistency (e.g., `rate` to `prune_ratio`, `image_token_length` to `vision_token_length`).
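To make that merging idea concrete, here is a minimal, self-contained sketch of attention-guided token merging in the spirit of `holitom_merge`. All names (`merge_visual_tokens`, `attn_scores`, `keep_ratio`) and the cosine-similarity matching step are illustrative assumptions for this sketch, not the exact code added in this pull request.

```python
import torch
import torch.nn.functional as F


def merge_visual_tokens(hidden_states, attn_scores, keep_ratio=0.25):
    """Illustrative attention-guided token merging (not this PR's exact code).

    hidden_states: (num_tokens, dim) visual token features
    attn_scores:   (num_tokens,) attention each visual token receives
    """
    num_tokens, _ = hidden_states.shape
    num_keep = max(1, int(num_tokens * keep_ratio))

    # Rank tokens by attention and split them into kept and dropped sets.
    keep_idx = attn_scores.topk(num_keep).indices
    drop_mask = torch.ones(num_tokens, dtype=torch.bool)
    drop_mask[keep_idx] = False
    drop_idx = drop_mask.nonzero(as_tuple=True)[0]

    kept = hidden_states[keep_idx].clone()   # (num_keep, dim)
    dropped = hidden_states[drop_idx]        # (num_drop, dim)

    # For each dropped token, find its most similar kept token.
    sim = F.cosine_similarity(dropped.unsqueeze(1), kept.unsqueeze(0), dim=-1)
    target = sim.argmax(dim=-1)              # (num_drop,)

    # Merge each dropped token into its target, then average.
    counts = torch.ones(num_keep, device=hidden_states.device)
    kept.index_add_(0, target, dropped)
    counts.index_add_(0, target, torch.ones_like(target, dtype=counts.dtype))
    return kept / counts.unsqueeze(-1)
```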
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes

[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This pull request introduces several fixes and updates, primarily within the token reduction modules. Key changes include:
- Fixing bugs in `vispruner` by using dynamic token counts and correcting tensor operations.
- Updating `holitom_merge` logic in the `random` pruning strategy, which adds a new merging method based on attention scores.
- Standardizing exception handling for optional `llava` dependencies across multiple files.
- Cleaning up unused code and logging.
My review focuses on performance, correctness, and maintainability. The new holitom_merge implementation contains loops that can be vectorized for significant performance gains. I've also identified a potential correctness issue in visionzip.py where a safety check for contiguous visual tokens was removed, and a maintainability concern in tome.py regarding the removal of backups for patched methods. I've provided detailed feedback and suggestions for these points.
```python
for b in range(batch_size):
    for i in range(num_non_topk):
        topk_rel_idx = sim_max_index[b, i].item()  # 这是 topk 中的相对索引
        topk_abs_idx = top_attention_rank_index[topk_rel_idx]  # 得到绝对索引
        non_topk_abs_idx = non_topk_indices[i]

        # 累加non-topk到topk token上(就地)
        hidden_states[b, topk_abs_idx, :] += hidden_states[b, non_topk_abs_idx, :]
        # 增加计数
        topk_counter[b, topk_rel_idx] += 1
```
The nested Python loops over `batch_size` and `num_non_topk` used to merge token states can be very inefficient, especially for larger batch sizes, as they prevent vectorized execution on the GPU. This part of the code can be a significant performance bottleneck.
Consider replacing the loops with vectorized PyTorch operations. You can use `torch.scatter_add_` to perform the summation of non-topk states into topk states in a single, efficient operation. This will also require vectorizing the `topk_counter` update.
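As a quick, self-contained illustration of the `torch.scatter_add_` pattern being suggested (the tensors, shapes, and values below are made up for demonstration and are not from this PR):

```python
import torch

kept = torch.randn(3, 4)                # hidden states of the kept (top-k) tokens
source = torch.randn(5, 4)              # hidden states of the tokens being merged away
target = torch.tensor([0, 2, 2, 1, 0])  # kept slot that each source token merges into

# Scatter-add every source row into its kept row in one vectorized call.
kept.scatter_add_(0, target.unsqueeze(-1).expand_as(source), source)

# Count how many tokens landed in each slot (each kept token counts itself once).
counts = torch.ones(3).scatter_add_(0, target, torch.ones(5))

merged = kept / counts.unsqueeze(-1)    # average of each kept token and its merges
```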
```python
for b in range(batch_size):
    for i in range(num_topk):
        topk_abs_idx = top_attention_rank_index[i]
        hidden_states[b, topk_abs_idx, :] /= topk_counter[b, i]
```
```python
first, last = st_idx[0].item(), st_idx[-1].item()
img_mask[first: last + 1] = ~select_mask
img_mask = ~img_mask
```
The logic to detect and handle discontinuous visual tokens has been removed. The previous code explicitly checked for this and would raise a ValueError, which is a safe way to handle unexpected input. The new implementation assumes that the visual tokens are always in a contiguous block by just using the first and last indices.
If it's possible for visual tokens to be non-contiguous, this change could lead to incorrect behavior and silent errors by applying the select_mask to a range that includes non-visual tokens. The raise statement in the original code was followed by unreachable code, which was a bug, but removing the check entirely might be unsafe.
Could you please confirm that visual tokens are guaranteed to be contiguous for qwenv25vl? If not, it would be safer to restore a check for contiguity.
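If contiguity cannot be guaranteed, a minimal fail-loudly check along these lines could be restored (a sketch, assuming `st_idx` is a sorted 1-D tensor of visual-token positions, which this diff does not confirm):

```python
# Fail loudly if the visual-token positions do not form one contiguous block.
if (st_idx[1:] - st_idx[:-1]).ne(1).any():
    raise ValueError(f'Expected contiguous visual tokens, got indices {st_idx.tolist()}')
first, last = st_idx[0].item(), st_idx[-1].item()
```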
```python
last_layer_attention_avg_last_tok_image = \
    last_layer_attention_avg_last_tok[
        image_token_start_index:
        image_token_start_index + image_token_length
    ]
```
The backslash \ for line continuation here is unnecessary. Python implicitly continues lines inside parentheses, brackets, and braces, which is the case here with the square brackets for slicing. Removing the backslash and adjusting the indentation will make the code cleaner and more compliant with PEP 8.
```python
last_layer_attention_avg_last_tok_image = last_layer_attention_avg_last_tok[
    image_token_start_index:
    image_token_start_index + image_token_length
]
```

```python
        topk_rel_idx = sim_max_index[b, i].item()  # 这是 topk 中的相对索引
        topk_abs_idx = top_attention_rank_index[topk_rel_idx]  # 得到绝对索引
        non_topk_abs_idx = non_topk_indices[i]

        # 累加non-topk到topk token上(就地)
        hidden_states[b, topk_abs_idx, :] += hidden_states[b, non_topk_abs_idx, :]
        # 增加计数
        topk_counter[b, topk_rel_idx] += 1

# 平均化所有topk token(包含自己和所有被合并的)
```
This code contains comments in Chinese. To improve code clarity and maintainability for a broader audience, please translate these comments into English.
For example:
- `# 这是 topk 中的相对索引` → `# This is the relative index within topk`
- `# 得到绝对索引` → `# Get the absolute index`
- `# 累加non-topk到topk token上(就地)` → `# Accumulate non-topk token into topk token (in-place)`
```python
block.self_attn.forward = types.MethodType(
    tome_CLIPSdpaAttention_forward,
    block.self_attn
)
```
You've removed the lines that back up the original forward methods before patching them. While this simplifies the code, it's a risky change as it removes the ability to revert the patch or to call the original methods from the patched versions. This can make debugging more difficult and reduces the modularity of the patch.
It is generally a good practice to store the original method before overwriting it. If there's no strong reason to remove them, I'd recommend restoring these backups.
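For reference, a common way to keep such a patch reversible is to stash the original bound method on the module before overwriting it. The fragment below extends the quoted snippet and assumes its surrounding context (`block`, `tome_CLIPSdpaAttention_forward`); the `_original_forward` attribute name is illustrative and not part of this PR.

```python
import types

# Keep a handle to the original forward so the patch can be reverted,
# or so the replacement can delegate to the original behavior.
block.self_attn._original_forward = block.self_attn.forward
block.self_attn.forward = types.MethodType(
    tome_CLIPSdpaAttention_forward,
    block.self_attn,
)
```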
Summary of Changes
Hello @zhangbilang, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly refines and stabilizes the token reduction functionalities within the llmc framework. It primarily focuses on addressing and resolving several identified bugs within the vispruner module to ensure more accurate and reliable visual token pruning. Additionally, it introduces substantial updates to the holitom_merge strategy, enabling a more advanced and attention-driven approach to token merging. The changes also encompass general code hygiene improvements, such as streamlining import error handling and removing redundant or unused code, leading to a cleaner and more efficient codebase.
Highlights
- Bug Fixes in VisPruner: Corrected the calculation of `visual_token_num` to dynamically use the actual number of image features (`N`) instead of a fixed `vision_token_length`. Also removed a redundant flattening operation and a potentially problematic string replacement for `mm_patch_merge_type`, enhancing the robustness of visual token pruning (illustrated in the sketch after this list).
- Enhanced Token Reduction Strategies: Introduced a new `holitom_merge` strategy that merges less important visual tokens based on attention scores, providing a more sophisticated, attention-aware pruning mechanism. This strategy is now dynamically selectable.
- Code Refinement and Consistency: Standardized exception handling for optional `llava` imports across multiple files by switching from `ModuleNotFoundError` to the more general `ImportError` and removing verbose `logger.info` messages. This improves code cleanliness and error handling.
- Cleanup of Unused Code: Removed obsolete imports (`wraps`, `logger`, `torch.nn.functional`) and deprecated LLaVA-specific input-handling hooks, contributing to a leaner and more maintainable codebase.
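As a rough illustration of the first point (the names, shapes, and values below are assumptions for this sketch, not the PR's actual code), deriving the token budget from the features actually produced rather than from a fixed configuration value looks like this:

```python
import torch

prune_ratio = 0.5                               # illustrative value
image_features = torch.randn(576, 1024)         # illustrative: (num_visual_tokens, hidden_dim)

# Derive the budget from the features actually produced for this image,
# rather than from a fixed vision_token_length in the configuration.
N = image_features.shape[0]
visual_token_num = int(N * (1 - prune_ratio))   # visual tokens to keep after pruning
```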
Code Review
This pull request introduces several changes, including bug fixes in VisPruner, an update to holitom_merge, and general code cleanup.
The changes in VisPruner and the refactoring to use ImportError for optional dependencies are good improvements.
I've identified a few areas for improvement:
- In `llmc/compression/token_reduction/random.py`, the new `holitom_merge_hook` contains Chinese comments and loops that could be vectorized for better performance.
- In `llmc/compression/token_reduction/visionzip.py`, a check for contiguous visual tokens has been removed, which could lead to incorrect behavior if the tokens are not contiguous. This is a critical issue that needs to be addressed.
Please see my detailed comments.
```python
first, last = st_idx[0].item(), st_idx[-1].item()
img_mask[first: last + 1] = ~select_mask
```
The previous implementation checked for discontinuous visual tokens and raised an error. This check has been removed. The new logic first, last = st_idx[0].item(), st_idx[-1].item() assumes that the visual tokens are always in a contiguous block. If st_idx can contain indices of multiple, separate blocks of visual tokens, this assumption is incorrect. For example, if st_idx is [10, 11, 100, 101], first will be 10 and last will be 101. The slice img_mask[first: last + 1] would incorrectly include non-visual tokens from index 12 to 99. This could lead to silent and incorrect behavior.
The original check was safer. If discontinuous tokens are not supported, it's better to fail loudly with a ValueError. If they are to be supported, the logic should handle them correctly. The original code had unreachable code after raise, but the check itself was valuable.
Please consider restoring the check for contiguity or implementing logic to handle multiple segments of visual tokens.
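If multiple segments of visual tokens do need to be supported, one possible approach (a sketch, assuming `st_idx` lists every visual-token position and `select_mask` has one entry per visual token) is to index the mask by those positions directly instead of slicing a range:

```python
# Apply the per-visual-token mask at the actual visual-token positions,
# which works whether or not those positions form one contiguous block.
img_mask[st_idx] = ~select_mask
img_mask = ~img_mask
```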
```python
for b in range(batch_size):
    for i in range(num_non_topk):
        topk_rel_idx = sim_max_index[b, i].item()  # 这是 topk 中的相对索引
        topk_abs_idx = top_attention_rank_index[topk_rel_idx]  # 得到绝对索引
        non_topk_abs_idx = non_topk_indices[i]

        # 累加non-topk到topk token上(就地)
        hidden_states[b, topk_abs_idx, :] += hidden_states[b, non_topk_abs_idx, :]
        # 增加计数
        topk_counter[b, topk_rel_idx] += 1

# 平均化所有topk token(包含自己和所有被合并的)
for b in range(batch_size):
    for i in range(num_topk):
        topk_abs_idx = top_attention_rank_index[i]
        hidden_states[b, topk_abs_idx, :] /= topk_counter[b, i]
```
This part of the code has a couple of areas for improvement:
- Performance: The nested Python loops iterating over `batch_size` and the token counts can be a performance bottleneck. These operations can be vectorized using `torch.scatter_add_` for a significant speedup.
- Maintainability: There are comments in Chinese. For better maintainability and to make the code accessible to a wider range of contributors, they should be translated into English.
Here is a suggested vectorized implementation with English comments:
```python
# Vectorized token merging for performance.
# Determine which top-k token each non-top-k token will merge into.
topk_abs_indices_to_add_to = top_attention_rank_index[sim_max_index]

# Add non-top-k hidden states to their corresponding top-k hidden states.
source_states = hidden_states[:, non_topk_indices, :]
index_for_scatter = topk_abs_indices_to_add_to.unsqueeze(-1).expand_as(source_states)
hidden_states.scatter_add_(1, index_for_scatter, source_states)

# Update the merge counters for averaging (one increment per merged token).
topk_counter.scatter_add_(
    1, sim_max_index, torch.ones_like(sim_max_index, dtype=topk_counter.dtype)
)

# Average the hidden states of the merged tokens.
hidden_states[:, top_attention_rank_index, :] /= topk_counter.unsqueeze(-1)
```