@zhangbilang
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @zhangbilang, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the llmc library's visualization capabilities and refines configuration parameters. It updates numerical precision in sparsification method configurations and introduces a suite of new functions for visualizing attention maps and token embedding similarities, providing more granular insights into model behavior and token reduction strategies.

Highlights

  • Configuration Precision Update: Adjusted reduction_ratio and prune_ratio values in DART and VisPruner configuration files to a higher precision (e.g., 0.778 to 0.7778).
  • Visualizer Hook Adjustment: Modified visualizer_hook to target a different attention head/layer (index 31 instead of 4) for visualization.
  • Improved Image Saving Utility: Enhanced the save_image function to intelligently handle save paths, including creating directories and generating unique filenames when saving to a folder.
  • Refined Patch Visualization Logic: Updated visualize_kept_patches to apply a "white" masking effect to pruned areas, improving visual clarity, and changed the default darken_ratio.
  • Introduction of Advanced Visualizations: Added four new functions (visualize_attention, visualize_attention_v2, visualize_cosin_token, visualize_cosin_token_32p) to enable detailed visualization of attention patterns and token embedding cosine similarities.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request updates the visualizer components, including configuration changes, modifications to existing visualization functions, and the addition of several new complex visualization utilities. The changes improve functionality, but there are several areas for improvement regarding code clarity, robustness, and maintainability. Key issues include the use of magic numbers, potential division-by-zero errors, modification of global state in plotting functions, and brittle logic for handling file paths. I've provided specific comments and suggestions to address these points.

Comment on lines +33 to +42
if not save_path.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')):
    os.makedirs(save_path, exist_ok=True)
    base_path = os.path.join(save_path, '{:04d}_visprunerP.png')
    idx = 0
    while os.path.exists(base_path.format(idx)):
        idx += 1
    save_path = base_path.format(idx)

else:
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

high

The logic to determine if save_path is a directory or a file is based on a hardcoded list of extensions. This is brittle; for example, it would fail for a file without an extension or a file with an unlisted extension (e.g., .svg). A more robust approach is to check for the presence of a file extension using os.path.splitext.

Suggested change
- if not save_path.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')):
-     os.makedirs(save_path, exist_ok=True)
-     base_path = os.path.join(save_path, '{:04d}_visprunerP.png')
-     idx = 0
-     while os.path.exists(base_path.format(idx)):
-         idx += 1
-     save_path = base_path.format(idx)
- else:
-     os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ if not os.path.splitext(save_path)[1]:
+     os.makedirs(save_path, exist_ok=True)
+     base_path = os.path.join(save_path, '{:04d}_visprunerP.png')
+     idx = 0
+     while os.path.exists(base_path.format(idx)):
+         idx += 1
+     save_path = base_path.format(idx)
+ else:
+     os.makedirs(os.path.dirname(save_path), exist_ok=True)
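
The extension check proposed in this comment can be sketched as a standalone helper. This is a minimal sketch, not the code merged in the PR: the name `resolve_save_path` and its `suffix` parameter are hypothetical, and it assumes that any path without an extension is meant to be a directory. It also guards against a bare filename such as `plot.png`, where `os.path.dirname` returns an empty string that `os.makedirs` rejects.

```python
import os


def resolve_save_path(save_path, suffix='_visprunerP.png'):
    """Return a concrete file path, creating parent directories as needed.

    Hypothetical helper: a path without an extension is treated as a
    directory, and the first unused numbered filename inside it is returned.
    """
    if not os.path.splitext(save_path)[1]:
        # No extension: treat save_path as a directory.
        os.makedirs(save_path, exist_ok=True)
        idx = 0
        while os.path.exists(os.path.join(save_path, f'{idx:04d}{suffix}')):
            idx += 1
        return os.path.join(save_path, f'{idx:04d}{suffix}')

    # Has an extension: treat save_path as a file path.
    parent = os.path.dirname(save_path)
    if parent:  # os.makedirs('') raises, so skip bare filenames
        os.makedirs(parent, exist_ok=True)
    return save_path
```

With this shape, an unlisted extension such as `.svg` is handled as a file rather than silently turned into a directory.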

Comment on lines +251 to +302
def visualize_attention_v2(attention, grid_size=24, save_path=None):

    if hasattr(attention, 'detach'):
        attention = attention.detach().cpu().numpy()

    # Partition the sequence into blocks
    block_ranges = []

    # SYS: 2 blocks

    sys_splits = [0, 17, 35]
    for i in range(len(sys_splits) - 1):
        block_ranges.append((sys_splits[i], sys_splits[i + 1]))
    # IMG: 24 blocks of size 24
    for i in range(24):
        start = 35 + i * 24
        end = start + 24
        block_ranges.append((start, end))

    # INS: 6 blocks
    ins_splits = [611 + i * 91 for i in range(7)]  # 611 + 6 * 91 = 1157 → crop to 1155
    ins_splits[-1] = 1155
    for i in range(len(ins_splits) - 1):
        block_ranges.append((ins_splits[i], ins_splits[i + 1]))

    # Average the attention over each block pair
    num_blocks = len(block_ranges)
    block_attention = np.zeros((num_blocks, num_blocks))
    for i in range(num_blocks):
        i_start, i_end = block_ranges[i]
        for j in range(num_blocks):
            j_start, j_end = block_ranges[j]
            block = attention[i_start:i_end, j_start:j_end]
            block_attention[31 - i, j] = block.mean()

    mask = np.triu(np.ones_like(block_attention, dtype=bool), k=1)
    plt.figure(figsize=(10, 10))
    block_attention = block_attention / block_attention.max(axis=1, keepdims=True)
    sns.heatmap(block_attention, mask=mask, cmap='viridis', square=True, cbar=True)
    # sns.heatmap(block_attention, cmap='viridis', square=True, cbar=True)

    section_labels = ['SYS', 'IMG', 'INS']
    section_boundaries = [2, 26, 32]  # block_ranges holds 2 SYS, 24 IMG, and 6 INS blocks
    ticks = np.arange(0, num_blocks)
    plt.xticks(ticks=ticks, labels=[''] * num_blocks)
    plt.yticks(ticks=ticks, labels=[''] * num_blocks)
    plt.xticks(ticks=section_boundaries, labels=section_labels, fontsize=12)
    plt.yticks(ticks=section_boundaries, labels=section_labels, fontsize=12)
    plt.title('Attention Map')
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()

high

This function contains many hardcoded "magic numbers" (e.g., [0, 17, 35], 35 + i * 24, 31 - i, [2, 26, 32]) that seem specific to a particular model architecture. This makes the function rigid and hard to maintain or reuse for other models. These values should be refactored to be passed in as parameters or loaded from a configuration.
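
One way to lift those constants into parameters is sketched below. The helper name `build_block_ranges` and all of its parameter names are hypothetical; the sketch only assumes the layout visible in the reviewed code (system blocks given by explicit split points, equally sized image blocks, and instruction tokens split into equal blocks with the last one cropped to the sequence end).

```python
def build_block_ranges(sys_splits, img_start, img_block_size, num_img_blocks,
                       ins_end, num_ins_blocks):
    """Return (start, end) token ranges for SYS, IMG, and INS blocks."""
    # System tokens: consecutive pairs of the given split points.
    ranges = list(zip(sys_splits[:-1], sys_splits[1:]))

    # Image tokens: equally sized blocks starting at img_start.
    for i in range(num_img_blocks):
        start = img_start + i * img_block_size
        ranges.append((start, start + img_block_size))

    # Instruction tokens: equal blocks, with the last one cropped to ins_end.
    ins_start = img_start + num_img_blocks * img_block_size
    ins_step = -(-(ins_end - ins_start) // num_ins_blocks)  # ceiling division
    for i in range(num_ins_blocks):
        start = ins_start + i * ins_step
        ranges.append((start, min(start + ins_step, ins_end)))
    return ranges


# Reproduces the hardcoded layout from the reviewed function.
block_ranges = build_block_ranges(
    sys_splits=[0, 17, 35], img_start=35, img_block_size=24,
    num_img_blocks=24, ins_end=1155, num_ins_blocks=6,
)
```

With the ranges built this way, the flipped row index `31 - i` can be written as `len(block_ranges) - 1 - i`, and the section boundaries `[2, 26, 32]` follow from the block counts instead of being typed in by hand.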


mask = np.triu(np.ones_like(block_attention, dtype=bool), k=1)
plt.figure(figsize=(10, 10))
block_attention = block_attention / block_attention.max(axis=1, keepdims=True)

high

The normalization block_attention / block_attention.max(...) is unsafe. If block_attention.max(...) is zero for any row, it will result in a division by zero, producing NaN or inf values, which will likely cause sns.heatmap to fail. You should handle this case, for example by adding a small epsilon to the denominator or using np.divide with a where clause.

Suggested change
- block_attention = block_attention / block_attention.max(axis=1, keepdims=True)
+ max_vals = block_attention.max(axis=1, keepdims=True)
+ block_attention = np.divide(block_attention, max_vals, out=np.zeros_like(block_attention), where=max_vals!=0)
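
The effect of the guarded division can be checked on a small array. This is a minimal sketch, assuming NumPy is available; the wrapper name `normalize_rows` is hypothetical and exists only to make the behavior testable. A row whose maximum is zero is left as zeros instead of becoming NaN.

```python
import numpy as np


def normalize_rows(block_attention):
    """Divide each row by its maximum, leaving all-zero rows as zeros."""
    max_vals = block_attention.max(axis=1, keepdims=True)
    return np.divide(
        block_attention, max_vals,
        out=np.zeros_like(block_attention, dtype=float),
        where=max_vals != 0,  # skip rows whose maximum is zero
    )


demo = np.array([[0.0, 0.0], [1.0, 2.0]])
normalized = normalize_rows(demo)
```

Here the first row of `normalized` stays at zero and the second row is scaled so that its largest entry is 1.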

  )
  visualize_grid_to_grid(
-     visual_attention_maps[0, 4, :, :],
+     visual_attention_maps[0, 31, :, :],

medium

The index 31 is a "magic number," making it unclear what it represents. To improve readability and maintainability, it should be replaced with a named constant or derived from a configuration value (e.g., num_attention_heads - 1).

functools.partial(get_attentions_hook, pruning_paras=self.pruning_paras),
)
if idx == (len(self.blocks) - 1):
# self.model.language_model.layers[-1]

medium

This commented-out line appears to be a development note. It should be removed to keep the codebase clean.

Comment on lines 47 to 55
  def visualize_kept_patches(
      image,
-     keep_idx,
+     keep_idx=None,
      mean=[0.48145466, 0.4578275, 0.40821073],
      std=[0.26862954, 0.26130258, 0.27577711],
      patch_size=14,
-     darken_ratio=0.3,
+     darken_ratio=0.8,
      save_path=None,
  ):

medium

The assertion assert image.ndim == 3 and image.shape[0] == 3 was removed. This check is valuable for ensuring the input image has the expected shape and provides a clear error message if it doesn't. Without it, the code might fail later with a less specific ValueError. Please consider adding it back at the beginning of the function body.
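
A shape check of this kind can be sketched as a small guard. Note that this swaps the original `assert` for an explicit `ValueError`, so the check also survives when Python runs with assertions disabled; the helper name `check_chw_image` is hypothetical, and the sketch only assumes the input exposes `ndim` and `shape` the way NumPy arrays and PyTorch tensors do.

```python
def check_chw_image(image):
    """Raise a descriptive ValueError unless image has shape (3, H, W)."""
    ndim = getattr(image, 'ndim', None)
    shape = tuple(getattr(image, 'shape', ()))
    if ndim != 3 or len(shape) != 3 or shape[0] != 3:
        raise ValueError(
            f'expected image of shape (3, H, W), got ndim={ndim}, shape={shape}'
        )
```

Calling it as the first statement of the function body gives the caller a message that names the expected layout instead of a failure deeper in the plotting code.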

Comment on lines +307 to +332
plt.rcParams['font.size'] = 15

x = token_embedding[0, 14: 14 + 196 * 4, :]
x_norm = F.normalize(x, p=2, dim=1)
similarity_matrix = x_norm @ x_norm.T

sim_np = similarity_matrix.cpu().numpy()
sim_np = np.triu(sim_np, k=1)
valid_sim = sim_np[sim_np > 0]
vmin = np.percentile(valid_sim, 90) # 10% min

plt.subplots(figsize=(10, 10))
sns.heatmap(similarity_matrix.cpu().numpy(), cmap='Reds', vmin=vmin, vmax=1)

start = 0
step = 196
ticks = np.arange(start, 196 * 5, step)
plt.xticks(ticks, ticks)
plt.yticks(ticks, ticks)

plt.title('')
plt.xlabel('')
plt.ylabel('')
plt.tight_layout()
plt.savefig(save_path, format='pdf')
plt.rcdefaults()

medium

Modifying matplotlib's global rcParams can have unintended side effects on other plotting functions in the application. It's better practice to use a context manager with plt.rc_context(...) to apply these settings only for this function. This also makes the call to plt.rcdefaults() at the end unnecessary.
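
The scoped variant can be sketched as follows. This is a minimal illustration, not the PR's function: the name `plot_with_scoped_font` and its parameters are hypothetical, the plot itself is a placeholder, and the `Agg` backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_with_scoped_font(values, save_path=None, font_size=15):
    """Draw a placeholder heatmap with a font size that does not leak out."""
    # rcParams changes made here are reverted when the block exits.
    with plt.rc_context({'font.size': font_size}):
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.imshow(values, cmap='Reds')
        ax.set_title('scoped font size')
        if save_path is not None:
            fig.savefig(save_path)  # format inferred from the file extension
        plt.close(fig)
```

Because the context manager restores the previous settings on exit, no call to `plt.rcdefaults()` is needed, and any rcParams the caller customized beforehand are preserved rather than reset.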

plt.xlabel('')
plt.ylabel('')
plt.tight_layout()
plt.savefig(save_path, format='pdf')

medium

The file format is hardcoded to 'pdf'. matplotlib.pyplot.savefig can infer the format from the file extension in save_path. Hardcoding the format prevents saving in other formats (e.g., PNG, SVG) and may be unexpected if save_path has a different extension. It's better to remove the format argument and let savefig infer it.

Suggested change
- plt.savefig(save_path, format='pdf')
+ plt.savefig(save_path)

Comment on lines +338 to +381
plt.rcParams['font.size'] = 20

all_tokens = token_embedding[0, 14:14 + 196 * 32, :]
x_norm = F.normalize(all_tokens, p=2, dim=1)
similarity_matrix = x_norm @ x_norm.T
sim_np = similarity_matrix.cpu().numpy()
sim_np = np.triu(sim_np, k=1)
valid_sim = sim_np[sim_np > 0]
vmin = np.percentile(valid_sim, 90)  # 10% min

group_size = 4
num_groups = 8
tokens_per_group = 196 * group_size
step = 196

fig, axs = plt.subplots(2, 4, figsize=(22, 10))  # 2x4 layout
axs = axs.flatten()

for i in range(num_groups):
    x = all_tokens[i * tokens_per_group: (i + 1) * tokens_per_group, :]
    x_norm = F.normalize(x, p=2, dim=1)
    similarity_matrix = x_norm @ x_norm.T

    ax = axs[i]
    sns.heatmap(
        similarity_matrix.cpu().numpy(), cmap='Reds',
        vmin=vmin, vmax=1, ax=ax, cbar=False
    )

    ticks = np.arange(0, tokens_per_group, step)
    labels = np.arange(i * tokens_per_group, (i + 1) * tokens_per_group, step)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels, rotation=0)
    ax.set_yticklabels(labels)
    start_frame = i * group_size
    end_frame = (i + 1) * group_size - 1
    ax.set_xlabel(f'Frame {start_frame}-{end_frame}', fontsize=17, labelpad=10)

plt.tight_layout()
# plt.savefig(save_path, format='pdf')
# plt.savefig(save_path.replace('.pdf', '.svg'), format='svg', bbox_inches='tight')
plt.savefig(save_path, dpi=300)
plt.rcdefaults()

medium

Similar to visualize_cosin_token, modifying global rcParams is risky. Please use with plt.rc_context(...) to scope the font size change to this function, which will also make plt.rcdefaults() unnecessary.

Comment on lines +378 to +379
# plt.savefig(save_path, format='pdf')
# plt.savefig(save_path.replace('.pdf', '.svg'), format='svg', bbox_inches='tight')

medium

These commented-out lines should be removed to keep the code clean.

@llmc-reviewer llmc-reviewer merged commit e219a71 into ModelTC:main Aug 12, 2025
2 checks passed

2 participants