From a8ac75ff36809d986b4648b6773ffc34d30227b1 Mon Sep 17 00:00:00 2001 From: Noa Levi <275430404+lphuc2250gma@users.noreply.github.com> Date: Fri, 5 Jun 2026 11:36:48 +0000 Subject: [PATCH] chore: add type hints and descriptive assert messages to kt-kernel examples --- kt-kernel/examples/test_mla.py | 12 ++++++------ kt-kernel/examples/test_mlp.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/kt-kernel/examples/test_mla.py b/kt-kernel/examples/test_mla.py index bb9b512d8..d341b2338 100644 --- a/kt-kernel/examples/test_mla.py +++ b/kt-kernel/examples/test_mla.py @@ -58,15 +58,15 @@ def get_torch_tensor_and_type_from_gguf(gguf_weights, name): return torch.from_numpy(gguf_weights[name].data).contiguous(), gguf_weights[name].tensor_type.name -def type_to_ggml_type(type): - if type == "F32": +def type_to_ggml_type(dtype_str: str) -> ggml_type: + if dtype_str == "F32": return ggml_type.FP32 - elif type == "F16": + elif dtype_str == "F16": return ggml_type.FP16 - elif type == "BF16": + elif dtype_str == "BF16": return ggml_type.BF16 else: - raise ValueError(f"Unsupported data type: {type}") + raise ValueError(f"Unsupported data type: {dtype_str}") use_real_weights = True @@ -721,4 +721,4 @@ def torch_attn( print( f"Diff: ave:{diff.mean()}, max:{diff.max()}, min:{diff.min()}, relative_mean:{diff_relative_mean}, relative_max:{diff_relative.max()}, relative_min:{diff_relative.min()}" ) -assert diff_relative_mean < 2e-1, "CPU and Torch outputs are not close enough!" 
+assert diff_relative_mean < 2e-1, f"CPU and Torch outputs diverge: relative_mean={diff_relative_mean:.4e} >= 2e-1" diff --git a/kt-kernel/examples/test_mlp.py b/kt-kernel/examples/test_mlp.py index 832053bf3..6525df85b 100644 --- a/kt-kernel/examples/test_mlp.py +++ b/kt-kernel/examples/test_mlp.py @@ -30,11 +30,11 @@ validation_iter = 100 -def act_fn(x): +def act_fn(x: torch.Tensor) -> torch.Tensor: return x / (1.0 + torch.exp(-x)) -def mlp_torch(input, gate_proj, up_proj, down_proj): +def mlp_torch(input: torch.Tensor, gate_proj: torch.Tensor, up_proj: torch.Tensor, down_proj: torch.Tensor) -> torch.Tensor: gate_buf = torch.mm(input, gate_proj.t()) up_buf = torch.mm(input, up_proj.t()) intermediate = act_fn(gate_buf) * up_buf @@ -95,4 +95,4 @@ def mlp_torch(input, gate_proj, up_proj, down_proj): diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output)) print("diff = ", diff) - assert diff < 0.001 + assert diff < 0.001, f"MLP output mismatch: diff={diff:.4e} exceeds threshold 0.001"