Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions kt-kernel/python/utils/amx.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
elif self.method == "FP8_PERCHANNEL":
if self.gate_scales[0].dtype != torch.float32:
self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
Comment on lines +452 to +454
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block correctly ensures that the scales are torch.float32. However, the code that casts gate_scales, up_scales, and down_scales is repetitive, and it duplicates the casting already done for the FP8 method on lines 446-448.

To improve maintainability and reduce code duplication, you could use a loop to handle the casting for all three projections.

Suggested change
self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
for proj in ["gate", "up", "down"]:
setattr(self, f"{proj}_scales", [t.to(torch.float32).contiguous() for t in weights[f"{proj}_scale"]])

assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"

t2 = time.time()
Expand Down
88 changes: 69 additions & 19 deletions kt-kernel/python/utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
Supported formats:
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
- Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight

Supported scale formats (auto-detected):
- Block-wise: weight_scale_inv (DeepSeek FP8)
Comment on lines 248 to 249
Copy link

Copilot AI Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "Supported scale formats (auto-detected)" section of the FP8SafeTensorLoader class docstring (lines 248-249) describes weight_scale only as "Per-channel (GLM-4.7-FP8)". However, this PR introduces logic (lines 330-344) under which weight_scale can also be block-wise (as used by Mistral FP8): the loader now inspects the scale tensor's shape to decide between the two. The docstring should be updated to reflect that weight_scale can be either per-channel or block-wise, with the granularity inferred from the scale tensor shape.

Copilot uses AI. Check for mistakes.
Expand All @@ -255,6 +256,7 @@ class FP8SafeTensorLoader(SafeTensorLoader):
MOE_FORMATS = {
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
"mistral": ("{base}.experts", "w1", "w3", "w2"),
}

def __init__(self, file_path: str, scale_suffix: str = None):
Expand Down Expand Up @@ -297,6 +299,10 @@ def _detect_format(self):
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
break
elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key:
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
break
if self._detected_format:
break

Expand All @@ -321,8 +327,21 @@ def _detect_format(self):
return
elif f".{gate}.weight_scale" in key and "weight_scale_inv" not in key:
self._scale_suffix = "weight_scale"
self._is_per_channel = True
print("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)")
# Some models (e.g., Mistral) use block-wise FP8 scales but keep
# the key suffix as `weight_scale` (without `_inv`). Infer format
# from scale tensor shape instead of suffix alone:
# - per-channel: [N] or [N, 1]
# - block-wise: [N_block, K_block] (both dims > 1)
scale_tensor = self.load_tensor(key, device="cpu")
if scale_tensor.dim() == 1:
self._is_per_channel = True
elif scale_tensor.dim() == 2 and scale_tensor.shape[1] == 1:
self._is_per_channel = True
else:
self._is_per_channel = False

scale_kind = "per-channel" if self._is_per_channel else "block-wise"
print(f"[FP8SafeTensorLoader] Detected scale format: {scale_kind} (weight_scale)")
return
# Default to weight_scale_inv
self._scale_suffix = "weight_scale_inv"
Expand All @@ -333,12 +352,20 @@ def _detect_format(self):
scale_type = "per-channel" if self._is_per_channel else "block-wise"
print(f"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})")

def _get_experts_prefix(self, base_key: str) -> str:
"""Get the experts prefix based on detected format."""
def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
"""Get candidate experts prefixes based on detected format and base key variants."""
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
candidates = []
if self._is_vl_model:
base_key = base_key.replace("model.layers", "model.language_model.layers")
return path_tpl.format(base=base_key)
candidates.append(path_tpl.format(base=base_key))

# Some model weights (e.g., Mistral native format) do not have "model." prefix.
if base_key.startswith("model."):
candidates.append(path_tpl.format(base=base_key[len("model.") :]))

# Deduplicate while preserving order.
return list(dict.fromkeys(candidates))

def _get_proj_names(self):
"""Get projection names (gate, up, down) based on detected format."""
Expand All @@ -363,15 +390,21 @@ def load_experts(self, base_key: str, device: str = "cpu"):
Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats.
Per-channel scales are squeezed from [N, 1] to [N] if needed.
"""
experts_prefix = self._get_experts_prefix(base_key)
experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
gate_name, up_name, down_name = self._get_proj_names()

expert_count = 0
while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
experts_prefix = None
for prefix in experts_prefix_candidates:
expert_count = 0
while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
if expert_count > 0:
experts_prefix = prefix
break

if expert_count == 0:
raise ValueError(f"No experts found for key {experts_prefix}")
if expert_count == 0 or experts_prefix is None:
raise ValueError(f"No experts found for keys: {experts_prefix_candidates}")

gate_weights = [None] * expert_count
up_weights = [None] * expert_count
Expand Down Expand Up @@ -423,20 +456,21 @@ def is_per_channel(self) -> bool:
return self._is_per_channel



class BF16SafeTensorLoader(SafeTensorLoader):
"""Loader for native BF16 expert weights (no quantization, no scales).

Supported formats:
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
- Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight

The format is auto-detected during initialization.
"""

MOE_FORMATS = {
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
"mistral": ("{base}.experts", "w1", "w3", "w2"),
}

def __init__(self, file_path: str):
Expand Down Expand Up @@ -466,14 +500,24 @@ def _detect_format(self):
self._detected_format = fmt_name
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
return
elif fmt_name == "mistral" and ".mlp.experts" not in key and ".block_sparse_moe.experts" not in key:
self._detected_format = fmt_name
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
return

self._detected_format = "deepseek"
print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")

def _get_experts_prefix(self, base_key: str) -> str:
"""Get the experts prefix based on detected format."""
def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:
"""Get candidate experts prefixes based on detected format and base key variants."""
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
return path_tpl.format(base=base_key)
candidates = [path_tpl.format(base=base_key)]

# Some model weights (e.g., Mistral native format) do not have "model." prefix.
if base_key.startswith("model."):
candidates.append(path_tpl.format(base=base_key[len("model.") :]))

return list(dict.fromkeys(candidates))
Comment on lines +511 to +520
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _get_experts_prefix_candidates method here is nearly identical to its counterpart in the FP8SafeTensorLoader class (lines 355-368). To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, consider refactoring the common logic into a shared method.

A potential approach would be to move this logic to the SafeTensorLoader base class or a new intermediate base class for MoE loaders, allowing for format-specific customizations where needed (like the _is_vl_model check in FP8SafeTensorLoader).


def _get_proj_names(self):
"""Get projection names (gate, up, down) based on detected format."""
Expand All @@ -497,15 +541,21 @@ def load_experts(self, base_key: str, device: str = "cpu"):
if self._detected_format == "packed":
return self._load_experts_packed(base_key, device)

experts_prefix = self._get_experts_prefix(base_key)
experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
gate_name, up_name, down_name = self._get_proj_names()

expert_count = 0
while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
experts_prefix = None
for prefix in experts_prefix_candidates:
expert_count = 0
while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
if expert_count > 0:
experts_prefix = prefix
break

if expert_count == 0:
raise ValueError(f"No experts found for key {experts_prefix}")
if expert_count == 0 or experts_prefix is None:
raise ValueError(f"No experts found for keys: {experts_prefix_candidates}")

gate_weights = [None] * expert_count
up_weights = [None] * expert_count
Expand Down
Loading