update visualizer #434
Summary of Changes
Hello @zhangbilang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enhances the llmc library's visualization capabilities and refines configuration parameters. It updates numerical precision in sparsification method configurations and introduces a suite of new functions for visualizing attention maps and token embedding similarities, providing more granular insights into model behavior and token reduction strategies.
Highlights
- Configuration Precision Update: Adjusted reduction_ratio and prune_ratio values in DART and VisPruner configuration files to a higher precision (e.g., 0.778 to 0.7778).
- Visualizer Hook Adjustment: Modified visualizer_hook to target a different attention head/layer (index 31 instead of 4) for visualization.
- Improved Image Saving Utility: Enhanced the save_image function to intelligently handle save paths, including creating directories and generating unique filenames when saving to a folder.
- Refined Patch Visualization Logic: Updated visualize_kept_patches to apply a "white" masking effect to pruned areas, improving visual clarity, and changed the default darken_ratio.
- Introduction of Advanced Visualizations: Added four new functions (visualize_attention, visualize_attention_v2, visualize_cosin_token, visualize_cosin_token_32p) to enable detailed visualization of attention patterns and token embedding cosine similarities.
Code Review
This pull request updates the visualizer components, including configuration changes, modifications to existing visualization functions, and the addition of several new complex visualization utilities. The changes improve functionality, but there are several areas for improvement regarding code clarity, robustness, and maintainability. Key issues include the use of magic numbers, potential division-by-zero errors, modification of global state in plotting functions, and brittle logic for handling file paths. I've provided specific comments and suggestions to address these points.
```python
if not save_path.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')):
    os.makedirs(save_path, exist_ok=True)
    base_path = os.path.join(save_path, '{:04d}_visprunerP.png')
    idx = 0
    while os.path.exists(base_path.format(idx)):
        idx += 1
    save_path = base_path.format(idx)
else:
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
```
The logic to determine if save_path is a directory or a file is based on a hardcoded list of extensions. This is brittle; for example, it would fail for a file without an extension or a file with an unlisted extension (e.g., .svg). A more robust approach is to check for the presence of a file extension using os.path.splitext.
Suggested change:

```diff
-if not save_path.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')):
+if not os.path.splitext(save_path)[1]:
     os.makedirs(save_path, exist_ok=True)
     base_path = os.path.join(save_path, '{:04d}_visprunerP.png')
     idx = 0
     while os.path.exists(base_path.format(idx)):
         idx += 1
     save_path = base_path.format(idx)
 else:
     os.makedirs(os.path.dirname(save_path), exist_ok=True)
```
```python
def visualize_attention_v2(attention, grid_size=24, save_path=None):
    if hasattr(attention, 'detach'):
        attention = attention.detach().cpu().numpy()

    # Partition the sequence into blocks
    block_ranges = []

    # SYS: 2 blocks
    sys_splits = [0, 17, 35]
    for i in range(len(sys_splits) - 1):
        block_ranges.append((sys_splits[i], sys_splits[i + 1]))
    # IMG: 24 blocks of size 24
    for i in range(24):
        start = 35 + i * 24
        end = start + 24
        block_ranges.append((start, end))

    # INS: 6 blocks
    ins_splits = [611 + i * 91 for i in range(7)]  # 611 + 6 * 91 = 1157 → crop to 1155
    ins_splits[-1] = 1155
    for i in range(len(ins_splits) - 1):
        block_ranges.append((ins_splits[i], ins_splits[i + 1]))

    # Average the attention over each block pair
    num_blocks = len(block_ranges)
    block_attention = np.zeros((num_blocks, num_blocks))
    for i in range(num_blocks):
        i_start, i_end = block_ranges[i]
        for j in range(num_blocks):
            j_start, j_end = block_ranges[j]
            block = attention[i_start:i_end, j_start:j_end]
            block_attention[31 - i, j] = block.mean()

    mask = np.triu(np.ones_like(block_attention, dtype=bool), k=1)
    plt.figure(figsize=(10, 10))
    block_attention = block_attention / block_attention.max(axis=1, keepdims=True)
    sns.heatmap(block_attention, mask=mask, cmap='viridis', square=True, cbar=True)
    # sns.heatmap(block_attention, cmap='viridis', square=True, cbar=True)

    section_labels = ['SYS', 'IMG', 'INS']
    section_boundaries = [2, 26, 32]  # block_ranges holds 2 SYS, 24 IMG, 6 INS blocks
    ticks = np.arange(0, num_blocks)
    plt.xticks(ticks=ticks, labels=[''] * num_blocks)
    plt.yticks(ticks=ticks, labels=[''] * num_blocks)
    plt.xticks(ticks=section_boundaries, labels=section_labels, fontsize=12)
    plt.yticks(ticks=section_boundaries, labels=section_labels, fontsize=12)
    plt.title('Attention Map')
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()
```
This function contains many hardcoded "magic numbers" (e.g., [0, 17, 35], 35 + i * 24, 31 - i, [2, 26, 32]) that seem specific to a particular model architecture. This makes the function rigid and hard to maintain or reuse for other models. These values should be refactored to be passed in as parameters or loaded from a configuration.
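As an illustration of that refactor, the hardcoded boundaries could be lifted into parameters; the sketch below is hypothetical (the helper name and defaults are not in the PR, the defaults simply reproduce the current magic numbers):

```python
def build_block_ranges(sys_splits=(0, 17, 35), img_blocks=24, img_block_size=24,
                       ins_start=611, ins_blocks=6, ins_block_size=91, seq_len=1155):
    """Return (start, end) index pairs for the SYS, IMG and INS token blocks.

    Hypothetical helper: all defaults mirror the values currently hardcoded
    inside visualize_attention_v2, so other models can pass their own layout.
    """
    ranges = []
    # SYS blocks from explicit split points
    for i in range(len(sys_splits) - 1):
        ranges.append((sys_splits[i], sys_splits[i + 1]))
    # IMG blocks: fixed-size blocks starting right after the SYS section
    for i in range(img_blocks):
        start = sys_splits[-1] + i * img_block_size
        ranges.append((start, start + img_block_size))
    # INS blocks: fixed-size blocks, with the last one cropped to the sequence length
    ins_splits = [ins_start + i * ins_block_size for i in range(ins_blocks + 1)]
    ins_splits[-1] = min(ins_splits[-1], seq_len)
    for i in range(len(ins_splits) - 1):
        ranges.append((ins_splits[i], ins_splits[i + 1]))
    return ranges
```

With the defaults this yields 32 ranges (2 + 24 + 6), matching the existing `section_boundaries = [2, 26, 32]`, and the row-flip `31 - i` could then become `num_blocks - 1 - i`.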
```python
    mask = np.triu(np.ones_like(block_attention, dtype=bool), k=1)
    plt.figure(figsize=(10, 10))
    block_attention = block_attention / block_attention.max(axis=1, keepdims=True)
```
The normalization block_attention / block_attention.max(...) is unsafe. If block_attention.max(...) is zero for any row, it will result in a division by zero, producing NaN or inf values, which will likely cause sns.heatmap to fail. You should handle this case, for example by adding a small epsilon to the denominator or using np.divide with a where clause.
Suggested change:

```diff
-block_attention = block_attention / block_attention.max(axis=1, keepdims=True)
+max_vals = block_attention.max(axis=1, keepdims=True)
+block_attention = np.divide(block_attention, max_vals, out=np.zeros_like(block_attention), where=max_vals != 0)
```
```diff
 )
 visualize_grid_to_grid(
-    visual_attention_maps[0, 4, :, :],
+    visual_attention_maps[0, 31, :, :],
```
```python
    functools.partial(get_attentions_hook, pruning_paras=self.pruning_paras),
)
if idx == (len(self.blocks) - 1):
    # self.model.language_model.layers[-1]
```
```diff
 def visualize_kept_patches(
     image,
-    keep_idx,
+    keep_idx=None,
     mean=[0.48145466, 0.4578275, 0.40821073],
     std=[0.26862954, 0.26130258, 0.27577711],
     patch_size=14,
-    darken_ratio=0.3,
+    darken_ratio=0.8,
     save_path=None,
 ):
```
The assertion assert image.ndim == 3 and image.shape[0] == 3 was removed. This check is valuable for ensuring the input image has the expected shape and provides a clear error message if it doesn't. Without it, the code might fail later with a less specific ValueError. Please consider adding it back at the beginning of the function body.
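For reference, the removed guard amounts to a one-line shape check at the top of the function body. A standalone sketch (shown here on a NumPy array purely for illustration; the real function receives a tensor):

```python
import numpy as np

def check_image_shape(image):
    # The removed guard, restored: the function expects a channel-first
    # RGB image of shape (3, H, W) and should fail fast with a clear
    # message rather than raise an opaque ValueError later.
    assert image.ndim == 3 and image.shape[0] == 3, (
        f'expected shape (3, H, W), got {tuple(image.shape)}'
    )

check_image_shape(np.zeros((3, 336, 336)))  # passes silently
```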
```python
    plt.rcParams['font.size'] = 15

    x = token_embedding[0, 14: 14 + 196 * 4, :]
    x_norm = F.normalize(x, p=2, dim=1)
    similarity_matrix = x_norm @ x_norm.T

    sim_np = similarity_matrix.cpu().numpy()
    sim_np = np.triu(sim_np, k=1)
    valid_sim = sim_np[sim_np > 0]
    vmin = np.percentile(valid_sim, 90)  # 90th percentile as the color floor

    plt.subplots(figsize=(10, 10))
    sns.heatmap(similarity_matrix.cpu().numpy(), cmap='Reds', vmin=vmin, vmax=1)

    start = 0
    step = 196
    ticks = np.arange(start, 196 * 5, step)
    plt.xticks(ticks, ticks)
    plt.yticks(ticks, ticks)

    plt.title('')
    plt.xlabel('')
    plt.ylabel('')
    plt.tight_layout()
    plt.savefig(save_path, format='pdf')
    plt.rcdefaults()
```
Modifying matplotlib's global rcParams can have unintended side effects on other plotting functions in the application. It's better practice to use a context manager with plt.rc_context(...) to apply these settings only for this function. This also makes the call to plt.rcdefaults() at the end unnecessary.
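A minimal sketch of that pattern (the figure contents below are placeholders, not the PR's plotting code):

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so this sketch runs headless
import matplotlib.pyplot as plt

default_size = plt.rcParams['font.size']

# The font-size override applies only inside this context; the global
# rcParams are restored automatically on exit, so plt.rcdefaults() is
# no longer needed at the end of the function.
with plt.rc_context({'font.size': 15}):
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_title('Attention Map')  # rendered at font.size 15
    plt.close(fig)
```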
```python
    plt.xlabel('')
    plt.ylabel('')
    plt.tight_layout()
    plt.savefig(save_path, format='pdf')
```
The file format is hardcoded to 'pdf'. matplotlib.pyplot.savefig can infer the format from the file extension in save_path. Hardcoding the format prevents saving in other formats (e.g., PNG, SVG) and may be unexpected if save_path has a different extension. It's better to remove the format argument and let savefig infer it.
Suggested change:

```diff
-plt.savefig(save_path, format='pdf')
+plt.savefig(save_path)
```
```python
    plt.rcParams['font.size'] = 20

    all_tokens = token_embedding[0, 14:14 + 196 * 32, :]
    x_norm = F.normalize(all_tokens, p=2, dim=1)
    similarity_matrix = x_norm @ x_norm.T
    sim_np = similarity_matrix.cpu().numpy()
    sim_np = np.triu(sim_np, k=1)
    valid_sim = sim_np[sim_np > 0]
    vmin = np.percentile(valid_sim, 90)  # 90th percentile as the color floor

    group_size = 4
    num_groups = 8
    tokens_per_group = 196 * group_size
    step = 196

    fig, axs = plt.subplots(2, 4, figsize=(22, 10))  # 2x4 layout
    axs = axs.flatten()

    for i in range(num_groups):
        x = all_tokens[i * tokens_per_group: (i + 1) * tokens_per_group, :]
        x_norm = F.normalize(x, p=2, dim=1)
        similarity_matrix = x_norm @ x_norm.T

        ax = axs[i]
        sns.heatmap(
            similarity_matrix.cpu().numpy(), cmap='Reds',
            vmin=vmin, vmax=1, ax=ax, cbar=False
        )

        ticks = np.arange(0, tokens_per_group, step)
        labels = np.arange(i * tokens_per_group, (i + 1) * tokens_per_group, step)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(labels, rotation=0)
        ax.set_yticklabels(labels)
        start_frame = i * group_size
        end_frame = (i + 1) * group_size - 1
        ax.set_xlabel(f'Frame {start_frame}-{end_frame}', fontsize=17, labelpad=10)

    plt.tight_layout()
    # plt.savefig(save_path, format='pdf')
    # plt.savefig(save_path.replace('.pdf', '.svg'), format='svg', bbox_inches='tight')
    plt.savefig(save_path, dpi=300)
    plt.rcdefaults()
```