
[feat]: add mistral moe loader compatibility#1873

Merged
chenht2022 merged 1 commit into main from feat/mistral-loader-support
Feb 28, 2026

Conversation

@chenht2022
Contributor

Summary

  • add Mistral MoE key format support in FP8/BF16 loaders
  • add base-key fallback for weights without the model. prefix
  • infer FP8 scale type by tensor shape when suffix is weight_scale
  • cast FP8_PERCHANNEL scales to float32 before AMX load
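
The base-key fallback in the second bullet can be sketched roughly like this (a simplified, self-contained illustration; the template string and base keys are stand-ins, not the loader's real state):

```python
# Sketch of the base-key fallback: build the experts prefix from the original
# base key first, then (for Mistral native checkpoints) retry with the
# leading "model." stripped.
def experts_prefix_candidates(path_tpl: str, base_key: str) -> list[str]:
    candidates = [path_tpl.format(base=base_key)]
    if base_key.startswith("model."):
        candidates.append(path_tpl.format(base=base_key[len("model."):]))
    # dict.fromkeys keeps order while dropping duplicates
    return list(dict.fromkeys(candidates))

print(experts_prefix_candidates("{base}.experts", "model.layers.0"))
# -> ['model.layers.0.experts', 'layers.0.experts']
```

The loader can then probe each candidate against the checkpoint's key set and use the first prefix that matches.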

Validation

  • py_compile passes
  • Mistral startup and minimal request validation passed in kt env

Copilot AI review requested due to automatic review settings February 28, 2026 05:43
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility and compatibility of the weight loading utilities for Mixture-of-Experts (MoE) models. It introduces specific support for Mistral's MoE key structure and refines the logic for detecting scale types and handling variations in weight key prefixes. These changes collectively improve the robustness of the loading process, making it easier to integrate and run a wider range of quantized and BF16 MoE models.

Highlights

  • Mistral MoE Compatibility: Added support for Mistral-style Mixture-of-Experts (MoE) key formats in both FP8 and BF16 weight loaders, allowing for proper loading of Mistral models.
  • Flexible Weight Key Detection: Implemented a fallback mechanism for model weights that may or may not include the 'model.' prefix in their keys, enhancing compatibility with various model serialization formats.
  • Dynamic FP8 Scale Type Inference: Improved FP8 scale detection by inferring whether a scale is per-channel or block-wise based on the tensor's shape, rather than solely relying on the 'weight_scale' suffix, which addresses ambiguities in some model formats like Mistral.
  • FP8_PERCHANNEL Scale Casting: Ensured that FP8_PERCHANNEL scales are explicitly cast to float32 before being used in AMX loading operations, preventing potential dtype mismatches and ensuring correct processing.
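
The shape-based inference in the third highlight can be sketched as follows, under the assumption that a per-channel scale has one entry per output channel while a block-wise scale covers 128×128 blocks (the block size and the exact rule here are assumptions for illustration, not the loader's verified logic):

```python
import math

# Hedged sketch: infer FP8 scale granularity from the scale tensor's shape
# relative to the weight's shape, instead of trusting the key suffix alone.
def infer_scale_type(weight_shape, scale_shape, block=128):
    n, k = weight_shape
    blockwise = (math.ceil(n / block), math.ceil(k / block))
    if tuple(scale_shape) == blockwise:
        return "FP8_BLOCKWISE"
    if tuple(scale_shape) in ((n, 1), (n,)):
        return "FP8_PERCHANNEL"
    raise ValueError(f"unrecognized scale shape {scale_shape} for weight {weight_shape}")

print(infer_scale_type((4096, 14336), (4096, 1)))  # per-channel: one scale per row
print(infer_scale_type((4096, 14336), (32, 112)))  # 4096/128 x 14336/128 -> block-wise
```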


Changelog
  • kt-kernel/python/utils/amx.py
    • Ensured FP8_PERCHANNEL scales are cast to float32 if their dtype is not already float32.
  • kt-kernel/python/utils/loader.py
    • Added 'Mistral style' to the supported formats for FP8SafeTensorLoader and BF16SafeTensorLoader.
    • Included Mistral MoE format definition in MOE_FORMATS for both FP8 and BF16 loaders.
    • Modified format detection logic to specifically identify Mistral MoE format based on key patterns.
    • Updated FP8 scale detection to infer per-channel or block-wise nature from tensor shape when 'weight_scale' suffix is present.
    • Refactored _get_experts_prefix to _get_experts_prefix_candidates to provide multiple prefix options, including a fallback for keys without 'model.' prefix.
    • Adjusted load_experts in both loaders to iterate through candidate prefixes to find existing expert weights.

@gemini-code-assist bot left a comment

Code Review

This pull request adds support for Mistral MoE weight formats to the FP8 and BF16 loaders. The changes include detecting the new format, handling key prefixes with and without the 'model.' prefix, and inferring FP8 scale types from tensor shapes. Additionally, it ensures FP8 per-channel scales are correctly cast to float32 before being passed to the AMX kernel.

My review focuses on improving code clarity and reducing duplication. I've suggested simplifying a boolean check and refactoring repetitive code into a loop. I also noted an opportunity to reduce code duplication between the FP8SafeTensorLoader and BF16SafeTensorLoader classes for better long-term maintainability.

Comment on lines +452 to +454

        self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
        self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
        self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]

medium

This block correctly ensures that scales are torch.float32. However, the code for casting gate_scales, up_scales, and down_scales is repetitive. This logic is also duplicated from the FP8 method handling on lines 446-448.

To improve maintainability and reduce code duplication, you could use a loop to handle the casting for all three projections.

Suggested change
-        self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
-        self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
-        self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
+        for proj in ["gate", "up", "down"]:
+            setattr(self, f"{proj}_scales", [t.to(torch.float32).contiguous() for t in weights[f"{proj}_scale"]])

Comment on lines +511 to +520

    def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
        """Get candidate experts prefixes based on detected format and base key variants."""
        path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
        candidates = [path_tpl.format(base=base_key)]

        # Some model weights (e.g., Mistral native format) do not have "model." prefix.
        if base_key.startswith("model."):
            candidates.append(path_tpl.format(base=base_key[len("model.") :]))

        return list(dict.fromkeys(candidates))

medium

This function _get_experts_prefix_candidates is very similar to its counterpart in the FP8SafeTensorLoader class (lines 355-368). To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, consider refactoring the common logic into a shared method.

A potential approach would be to move this logic to the SafeTensorLoader base class or a new intermediate base class for MoE loaders, allowing for format-specific customizations where needed (like the _is_vl_model check in FP8SafeTensorLoader).
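
One way the suggested refactor could look (a sketch only: MoESafeTensorLoaderBase, BF16Loader, and the _transform_base_key hook are hypothetical names invented for illustration, with _is_vl_model-style customization left to subclasses):

```python
# Sketch of hoisting the shared candidate logic into a common MoE base class.
# Subclasses override _transform_base_key() for format-specific tweaks.
class MoESafeTensorLoaderBase:
    MOE_FORMATS: dict = {}
    _detected_format: str = ""

    def _transform_base_key(self, base_key: str) -> str:
        return base_key  # hook for format-specific key rewriting

    def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
        base_key = self._transform_base_key(base_key)
        path_tpl = self.MOE_FORMATS[self._detected_format][0]
        candidates = [path_tpl.format(base=base_key)]
        if base_key.startswith("model."):
            candidates.append(path_tpl.format(base=base_key[len("model."):]))
        return list(dict.fromkeys(candidates))

class BF16Loader(MoESafeTensorLoaderBase):
    MOE_FORMATS = {"mistral": ("{base}.experts",)}
    _detected_format = "mistral"

print(BF16Loader()._get_experts_prefix_candidates("model.layers.3"))
# -> ['model.layers.3.experts', 'layers.3.experts']
```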

Copilot AI left a comment

Pull request overview

This PR adds support for loading Mistral-style MoE (Mixture-of-Experts) models in the FP8 and BF16 safetensor weight loaders. Mistral uses a simpler key naming convention ({base}.experts.{id}.{w1,w3,w2}.weight) without the intermediate .mlp.experts or .block_sparse_moe.experts path components used by DeepSeek and Mixtral/MiniMax respectively. Additionally, Mistral's native weight files may omit the model. prefix, and their FP8 scales use weight_scale (without _inv) for what may be block-wise quantization rather than per-channel.
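
The three key conventions described above can be distinguished with simple substring checks (illustrative only; the loader's real detection may be stricter):

```python
# Sketch of MoE format detection from checkpoint key names, using the
# conventions above: DeepSeek uses ".mlp.experts.", Mixtral/MiniMax use
# ".block_sparse_moe.experts.", and Mistral uses a bare ".experts.".
def detect_moe_format(keys):
    for key in keys:
        if ".mlp.experts." in key:
            return "deepseek"
        if ".block_sparse_moe.experts." in key:
            return "mixtral"
        if ".experts." in key:
            return "mistral"
    return None

print(detect_moe_format(["layers.0.experts.5.w1.weight"]))                    # mistral
print(detect_moe_format(["model.layers.0.mlp.experts.5.gate_proj.weight"]))   # deepseek
```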

Changes:

  • Add "mistral" format entry to MOE_FORMATS in both FP8SafeTensorLoader and BF16SafeTensorLoader, with corresponding format detection logic
  • Rename _get_experts_prefix() → _get_experts_prefix_candidates() in both loaders to return a list of prefix candidates (primary + model.-stripped fallback), and update load_experts() to try candidates in order
  • Infer FP8 scale granularity (per-channel vs. block-wise) from tensor shape when the suffix is weight_scale, and cast FP8_PERCHANNEL scales to float32 in amx.py before the dtype assertion

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
kt-kernel/python/utils/loader.py Adds Mistral MoE format detection, base-key prefix fallback for weights without model. prefix, and FP8 scale type inference from tensor shape
kt-kernel/python/utils/amx.py Adds float32 cast for FP8_PERCHANNEL scales before the dtype assertion, mirroring existing FP8 behavior


Comment on lines 248 to 249
Supported scale formats (auto-detected):
- Block-wise: weight_scale_inv (DeepSeek FP8)

Copilot AI Feb 28, 2026


The FP8SafeTensorLoader class docstring section "Supported scale formats (auto-detected)" at lines 248-249 describes weight_scale as only "Per-channel (GLM-4.7-FP8)". However, this PR introduces logic (lines 330-344) that infers weight_scale can also be block-wise (as used by Mistral FP8). The docstring should be updated to reflect that weight_scale can be either per-channel or block-wise, with the granularity inferred from the scale tensor shape.

@chenht2022 chenht2022 force-pushed the feat/mistral-loader-support branch from b36226e to b49af1d Compare February 28, 2026 05:55
@chenht2022 chenht2022 merged commit 9e69fcc into main Feb 28, 2026
7 of 9 checks passed
