Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions kt-kernel/examples/test_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ def get_torch_tensor_and_type_from_gguf(gguf_weights, name):
return torch.from_numpy(gguf_weights[name].data).contiguous(), gguf_weights[name].tensor_type.name


def type_to_ggml_type(type):
if type == "F32":
def type_to_ggml_type(dtype_str: str) -> ggml_type:
if dtype_str == "F32":
return ggml_type.FP32
elif type == "F16":
elif dtype_str == "F16":
return ggml_type.FP16
elif type == "BF16":
elif dtype_str == "BF16":
return ggml_type.BF16
else:
raise ValueError(f"Unsupported data type: {type}")
raise ValueError(f"Unsupported data type: {dtype_str}")


use_real_weights = True
Expand Down Expand Up @@ -721,4 +721,4 @@ def torch_attn(
print(
f"Diff: ave:{diff.mean()}, max:{diff.max()}, min:{diff.min()}, relative_mean:{diff_relative_mean}, relative_max:{diff_relative.max()}, relative_min:{diff_relative.min()}"
)
assert diff_relative_mean < 2e-1, "CPU and Torch outputs are not close enough!"
assert diff_relative_mean < 2e-1, f"CPU and Torch outputs diverge: relative_mean={diff_relative_mean:.4e} >= 2e-1"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable diff_relative_mean is a PyTorch Tensor (0-dimensional). Directly formatting a PyTorch tensor with float format specifiers like :.4e can raise a TypeError or produce unexpected formatting in older PyTorch versions. It is safer and more idiomatic to call .item() to retrieve the underlying Python float before formatting.

Suggested change
assert diff_relative_mean < 2e-1, f"CPU and Torch outputs diverge: relative_mean={diff_relative_mean:.4e} >= 2e-1"
assert diff_relative_mean < 2e-1, f"CPU and Torch outputs diverge: relative_mean={diff_relative_mean.item():.4e} >= 2e-1"

6 changes: 3 additions & 3 deletions kt-kernel/examples/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
validation_iter = 100


def act_fn(x):
def act_fn(x: torch.Tensor) -> torch.Tensor:
return x / (1.0 + torch.exp(-x))


def mlp_torch(input, gate_proj, up_proj, down_proj):
def mlp_torch(input: torch.Tensor, gate_proj: torch.Tensor, up_proj: torch.Tensor, down_proj: torch.Tensor) -> torch.Tensor:
gate_buf = torch.mm(input, gate_proj.t())
up_buf = torch.mm(input, up_proj.t())
intermediate = act_fn(gate_buf) * up_buf
Expand Down Expand Up @@ -95,4 +95,4 @@ def mlp_torch(input, gate_proj, up_proj, down_proj):

diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
print("diff = ", diff)
assert diff < 0.001
assert diff < 0.001, f"MLP output mismatch: diff={diff:.4e} exceeds threshold 0.001"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable diff is a PyTorch Tensor (0-dimensional). Directly formatting a PyTorch tensor with float format specifiers like :.4e can raise a TypeError or produce unexpected formatting in older PyTorch versions. It is safer and more idiomatic to call .item() to retrieve the underlying Python float before formatting.

Suggested change
assert diff < 0.001, f"MLP output mismatch: diff={diff:.4e} exceeds threshold 0.001"
assert diff < 0.001, f"MLP output mismatch: diff={diff.item():.4e} exceeds threshold 0.001"