[feat]: add mistral moe loader compatibility #1873
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the flexibility and compatibility of the weight-loading utilities for Mixture-of-Experts (MoE) models. It introduces specific support for Mistral's MoE key structure and refines the logic for detecting scale types and handling variations in weight key prefixes. These changes collectively improve the robustness of the loading process, making it easier to integrate and run a wider range of quantized and BF16 MoE models.
Code Review
This pull request adds support for Mistral MoE weight formats to the FP8 and BF16 loaders. The changes include detecting the new format, handling key prefixes with and without the 'model.' prefix, and inferring FP8 scale types from tensor shapes. Additionally, it ensures FP8 per-channel scales are correctly cast to float32 before being passed to the AMX kernel.
My review focuses on improving code clarity and reducing duplication. I've suggested simplifying a boolean check and refactoring repetitive code into a loop. I also noted an opportunity to reduce code duplication between the FP8SafeTensorLoader and BF16SafeTensorLoader classes for better long-term maintainability.
```python
self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
```
This block correctly ensures that the scales are torch.float32. However, the casting code for gate_scales, up_scales, and down_scales is repetitive, and the same logic is duplicated in the FP8 handling on lines 446-448.
To improve maintainability and reduce code duplication, you could use a loop to handle the casting for all three projections.
```diff
-self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
-self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
-self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
+for proj in ["gate", "up", "down"]:
+    setattr(self, f"{proj}_scales", [t.to(torch.float32).contiguous() for t in weights[f"{proj}_scale"]])
```
```diff
 def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
     """Get candidate experts prefixes based on detected format and base key variants."""
     path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
-    return path_tpl.format(base=base_key)
+    candidates = [path_tpl.format(base=base_key)]
+
+    # Some model weights (e.g., Mistral native format) do not have "model." prefix.
+    if base_key.startswith("model."):
+        candidates.append(path_tpl.format(base=base_key[len("model.") :]))
+
+    return list(dict.fromkeys(candidates))
```
This function _get_experts_prefix_candidates is very similar to its counterpart in the FP8SafeTensorLoader class (lines 355-368). To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, consider refactoring the common logic into a shared method.
A potential approach would be to move this logic to the SafeTensorLoader base class or a new intermediate base class for MoE loaders, allowing for format-specific customizations where needed (like the _is_vl_model check in FP8SafeTensorLoader).
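The suggested refactor could be sketched roughly as follows. This is an illustrative shape only, not the PR's code: `MoESafeTensorLoaderBase`, the constructor, and the single-string `MOE_FORMATS` values are assumptions (the real entries are 4-tuples, and the real classes carry far more state).

```python
# Hypothetical sketch: hoist the shared candidate logic into a common base
# class so FP8SafeTensorLoader and BF16SafeTensorLoader stop duplicating it.
class MoESafeTensorLoaderBase:
    # Simplified: real MOE_FORMATS entries are 4-tuples, not bare templates.
    MOE_FORMATS: dict[str, str] = {}

    def __init__(self, detected_format: str):
        self._detected_format = detected_format

    def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
        """Primary candidate plus a 'model.'-stripped fallback, deduplicated."""
        path_tpl = self.MOE_FORMATS[self._detected_format]
        candidates = [path_tpl.format(base=base_key)]
        # Some model weights (e.g., Mistral native format) lack the "model." prefix.
        if base_key.startswith("model."):
            candidates.append(path_tpl.format(base=base_key[len("model."):]))
        return list(dict.fromkeys(candidates))

class BF16Loader(MoESafeTensorLoaderBase):
    MOE_FORMATS = {"mistral": "{base}.experts"}
```

With this shape, `BF16Loader("mistral")._get_experts_prefix_candidates("model.layers.0")` yields both the primary and the stripped prefix, while format-specific hooks (like the `_is_vl_model` check) can stay as overrides in the subclasses.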
Pull request overview
This PR adds support for loading Mistral-style MoE (Mixture-of-Experts) models in the FP8 and BF16 safetensor weight loaders. Mistral uses a simpler key naming convention ({base}.experts.{id}.{w1,w3,w2}.weight) without the intermediate .mlp.experts or .block_sparse_moe.experts path components used by DeepSeek and Mixtral/MiniMax respectively. Additionally, Mistral's native weight files may omit the model. prefix, and their FP8 scales use weight_scale (without _inv) for what may be block-wise quantization rather than per-channel.
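For concreteness, the key conventions described above can be illustrated with a small sketch. The template dict and helper below are hypothetical simplifications: the loader's actual `MOE_FORMATS` entries carry additional detection metadata, and the non-Mistral projection suffixes vary by model family.

```python
# Illustrative expert-key templates per MoE format (hypothetical; simplified
# from the loader's MOE_FORMATS, whose entries also carry detection metadata).
EXPERT_PATH_TEMPLATES = {
    "deepseek": "{base}.mlp.experts",
    "mixtral": "{base}.block_sparse_moe.experts",
    "mistral": "{base}.experts",  # no intermediate path component
}

def expert_weight_key(fmt: str, base: str, expert_id: int, proj: str) -> str:
    """Build a full safetensor key such as 'model.layers.0.experts.3.w1.weight'."""
    prefix = EXPERT_PATH_TEMPLATES[fmt].format(base=base)
    return f"{prefix}.{expert_id}.{proj}.weight"
```

For Mistral native weights, the same key may also appear without the `model.` prefix (e.g. `layers.0.experts.3.w1.weight`), which is what the prefix-candidate fallback in this PR handles.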
Changes:
- Add a `"mistral"` format entry to `MOE_FORMATS` in both `FP8SafeTensorLoader` and `BF16SafeTensorLoader`, with corresponding format detection logic
- Rename `_get_experts_prefix()` → `_get_experts_prefix_candidates()` in both loaders to return a list of prefix candidates (primary + `model.`-stripped fallback), and update `load_experts()` to try candidates in order
- Infer FP8 scale granularity (per-channel vs. block-wise) from tensor shape when the suffix is `weight_scale`, and cast `FP8_PERCHANNEL` scales to float32 in `amx.py` before the dtype assertion
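One way to read the shape-based inference: for a weight of shape `(out_features, in_features)`, a per-channel scale has one element per output channel, while a block-wise scale has one element per tile. A minimal sketch, assuming a 128x128 block size; the function name and exact rules are illustrative, not the PR's code:

```python
# Hypothetical sketch of inferring FP8 scale granularity from tensor shapes.
def infer_scale_type(weight_shape: tuple, scale_shape: tuple, block: int = 128) -> str:
    out_f, in_f = weight_shape
    # Per-channel: one scale per output channel.
    if scale_shape in ((out_f,), (out_f, 1)):
        return "per_channel"
    # Block-wise: one scale per (block x block) tile, dimensions rounded up.
    blocks = (-(-out_f // block), -(-in_f // block))
    if scale_shape == blocks:
        return "block_wise"
    raise ValueError(f"unrecognized scale shape {scale_shape} for weight {weight_shape}")
```

For example, a `(512, 256)` weight with a `(512, 1)` scale would classify as per-channel, while a `(4, 2)` scale would classify as block-wise.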
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `kt-kernel/python/utils/loader.py` | Adds Mistral MoE format detection, base-key prefix fallback for weights without the `model.` prefix, and FP8 scale type inference from tensor shape |
| `kt-kernel/python/utils/amx.py` | Adds float32 cast for `FP8_PERCHANNEL` scales before the dtype assertion, mirroring existing FP8 behavior |
```
Supported scale formats (auto-detected):
- Block-wise: weight_scale_inv (DeepSeek FP8)
```
The FP8SafeTensorLoader class docstring section "Supported scale formats (auto-detected)" at lines 248-249 describes weight_scale as only "Per-channel (GLM-4.7-FP8)". However, this PR introduces logic (lines 330-344) that infers weight_scale can also be block-wise (as used by Mistral FP8). The docstring should be updated to reflect that weight_scale can be either per-channel or block-wise, with the granularity inferred from the scale tensor shape.
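A possible wording for the updated docstring section, stored here as a plain string for illustration; the phrasing is a suggestion, not the PR's text:

```python
# Hypothetical revision of the docstring excerpt discussed above.
SUPPORTED_SCALE_FORMATS_DOC = """\
Supported scale formats (auto-detected):
- Block-wise: weight_scale_inv (DeepSeek FP8)
- Per-channel or block-wise: weight_scale (per-channel for GLM-4.7-FP8,
  block-wise for Mistral FP8; granularity inferred from the scale tensor shape)
"""
```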
Force-pushed from b36226e to b49af1d.
Summary
Validation