-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[feat]: add mistral moe loader compatibility #1873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -243,6 +243,7 @@ class FP8SafeTensorLoader(SafeTensorLoader): | |
| Supported formats: | ||
| - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight | ||
| - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight | ||
| - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight | ||
|
|
||
| Supported scale formats (auto-detected): | ||
| - Block-wise: weight_scale_inv (DeepSeek FP8) | ||
|
Comment on lines
248
to
249
|
||
|
|
@@ -255,6 +256,7 @@ class FP8SafeTensorLoader(SafeTensorLoader): | |
| MOE_FORMATS = { | ||
| "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"), | ||
| "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"), | ||
| "mistral": ("{base}.experts", "w1", "w3", "w2"), | ||
| } | ||
|
|
||
| def __init__(self, file_path: str, scale_suffix: str = None): | ||
|
|
@@ -297,6 +299,10 @@ def _detect_format(self): | |
| self._detected_format = fmt_name | ||
| print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}") | ||
| break | ||
| elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key: | ||
| self._detected_format = fmt_name | ||
| print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}") | ||
| break | ||
| if self._detected_format: | ||
| break | ||
|
|
||
|
|
@@ -321,8 +327,21 @@ def _detect_format(self): | |
| return | ||
| elif f".{gate}.weight_scale" in key and "weight_scale_inv" not in key: | ||
| self._scale_suffix = "weight_scale" | ||
| self._is_per_channel = True | ||
| print("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)") | ||
| # Some models (e.g., Mistral) use block-wise FP8 scales but keep | ||
| # the key suffix as `weight_scale` (without `_inv`). Infer format | ||
| # from scale tensor shape instead of suffix alone: | ||
| # - per-channel: [N] or [N, 1] | ||
| # - block-wise: [N_block, K_block] (both dims > 1) | ||
| scale_tensor = self.load_tensor(key, device="cpu") | ||
| if scale_tensor.dim() == 1: | ||
| self._is_per_channel = True | ||
| elif scale_tensor.dim() == 2 and scale_tensor.shape[1] == 1: | ||
| self._is_per_channel = True | ||
| else: | ||
| self._is_per_channel = False | ||
|
|
||
| scale_kind = "per-channel" if self._is_per_channel else "block-wise" | ||
| print(f"[FP8SafeTensorLoader] Detected scale format: {scale_kind} (weight_scale)") | ||
| return | ||
| # Default to weight_scale_inv | ||
| self._scale_suffix = "weight_scale_inv" | ||
|
|
@@ -333,12 +352,20 @@ def _detect_format(self): | |
| scale_type = "per-channel" if self._is_per_channel else "block-wise" | ||
| print(f"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})") | ||
|
|
||
| def _get_experts_prefix(self, base_key: str) -> str: | ||
| """Get the experts prefix based on detected format.""" | ||
| def _get_experts_prefix_candidates(self, base_key: str) -> list[str]: | ||
| """Get candidate experts prefixes based on detected format and base key variants.""" | ||
| path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format] | ||
| candidates = [] | ||
| if self._is_vl_model: | ||
| base_key = base_key.replace("model.layers", "model.language_model.layers") | ||
| return path_tpl.format(base=base_key) | ||
| candidates.append(path_tpl.format(base=base_key)) | ||
|
|
||
| # Some model weights (e.g., Mistral native format) do not have "model." prefix. | ||
| if base_key.startswith("model."): | ||
| candidates.append(path_tpl.format(base=base_key[len("model.") :])) | ||
|
|
||
| # Deduplicate while preserving order. | ||
| return list(dict.fromkeys(candidates)) | ||
|
|
||
| def _get_proj_names(self): | ||
| """Get projection names (gate, up, down) based on detected format.""" | ||
|
|
@@ -363,15 +390,21 @@ def load_experts(self, base_key: str, device: str = "cpu"): | |
| Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats. | ||
| Per-channel scales are squeezed from [N, 1] to [N] if needed. | ||
| """ | ||
| experts_prefix = self._get_experts_prefix(base_key) | ||
| experts_prefix_candidates = self._get_experts_prefix_candidates(base_key) | ||
| gate_name, up_name, down_name = self._get_proj_names() | ||
|
|
||
| expert_count = 0 | ||
| while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"): | ||
| expert_count += 1 | ||
| experts_prefix = None | ||
| for prefix in experts_prefix_candidates: | ||
| expert_count = 0 | ||
| while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"): | ||
| expert_count += 1 | ||
| if expert_count > 0: | ||
| experts_prefix = prefix | ||
| break | ||
|
|
||
| if expert_count == 0: | ||
| raise ValueError(f"No experts found for key {experts_prefix}") | ||
| if expert_count == 0 or experts_prefix is None: | ||
| raise ValueError(f"No experts found for keys: {experts_prefix_candidates}") | ||
|
|
||
| gate_weights = [None] * expert_count | ||
| up_weights = [None] * expert_count | ||
|
|
@@ -423,20 +456,21 @@ def is_per_channel(self) -> bool: | |
| return self._is_per_channel | ||
|
|
||
|
|
||
|
|
||
| class BF16SafeTensorLoader(SafeTensorLoader): | ||
| """Loader for native BF16 expert weights (no quantization, no scales). | ||
|
|
||
| Supported formats: | ||
| - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight | ||
| - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight | ||
| - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight | ||
|
|
||
| The format is auto-detected during initialization. | ||
| """ | ||
|
|
||
| MOE_FORMATS = { | ||
| "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"), | ||
| "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"), | ||
| "mistral": ("{base}.experts", "w1", "w3", "w2"), | ||
| } | ||
|
|
||
| def __init__(self, file_path: str): | ||
|
|
@@ -466,14 +500,24 @@ def _detect_format(self): | |
| self._detected_format = fmt_name | ||
| print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}") | ||
| return | ||
| elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key: | ||
| self._detected_format = fmt_name | ||
| print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}") | ||
| return | ||
|
|
||
| self._detected_format = "deepseek" | ||
| print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek") | ||
|
|
||
| def _get_experts_prefix(self, base_key: str) -> str: | ||
| """Get the experts prefix based on detected format.""" | ||
| def _get_experts_prefix_candidates(self, base_key: str) -> list[str]: | ||
| """Get candidate experts prefixes based on detected format and base key variants.""" | ||
| path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format] | ||
| return path_tpl.format(base=base_key) | ||
| candidates = [path_tpl.format(base=base_key)] | ||
|
|
||
| # Some model weights (e.g., Mistral native format) do not have "model." prefix. | ||
| if base_key.startswith("model."): | ||
| candidates.append(path_tpl.format(base=base_key[len("model.") :])) | ||
|
|
||
| return list(dict.fromkeys(candidates)) | ||
|
Comment on lines
+511
to
+520
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function A potential approach would be to move this logic to the |
||
|
|
||
| def _get_proj_names(self): | ||
| """Get projection names (gate, up, down) based on detected format.""" | ||
|
|
@@ -497,15 +541,21 @@ def load_experts(self, base_key: str, device: str = "cpu"): | |
| if self._detected_format == "packed": | ||
| return self._load_experts_packed(base_key, device) | ||
|
|
||
| experts_prefix = self._get_experts_prefix(base_key) | ||
| experts_prefix_candidates = self._get_experts_prefix_candidates(base_key) | ||
| gate_name, up_name, down_name = self._get_proj_names() | ||
|
|
||
| expert_count = 0 | ||
| while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"): | ||
| expert_count += 1 | ||
| experts_prefix = None | ||
| for prefix in experts_prefix_candidates: | ||
| expert_count = 0 | ||
| while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"): | ||
| expert_count += 1 | ||
| if expert_count > 0: | ||
| experts_prefix = prefix | ||
| break | ||
|
|
||
| if expert_count == 0: | ||
| raise ValueError(f"No experts found for key {experts_prefix}") | ||
| if expert_count == 0 or experts_prefix is None: | ||
| raise ValueError(f"No experts found for keys: {experts_prefix_candidates}") | ||
|
|
||
| gate_weights = [None] * expert_count | ||
| up_weights = [None] * expert_count | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block correctly ensures that scales are
torch.float32. However, the code for castinggate_scales,up_scales, anddown_scalesis repetitive. This logic is also duplicated from theFP8method handling on lines 446-448.To improve maintainability and reduce code duplication, you could use a loop to handle the casting for all three projections.