chore: improve ktransformers maintenance path #2033
Code Review
This pull request introduces type annotations to functions in test_mla.py and test_mlp.py, and enhances assertion error messages with formatted difference values. The review feedback points out that formatting 0-dimensional PyTorch tensors directly with float specifiers can cause issues, and suggests calling .item() to retrieve the underlying Python float first.
  f"Diff: ave:{diff.mean()}, max:{diff.max()}, min:{diff.min()}, relative_mean:{diff_relative_mean}, relative_max:{diff_relative.max()}, relative_min:{diff_relative.min()}"
  )
- assert diff_relative_mean < 2e-1, "CPU and Torch outputs are not close enough!"
+ assert diff_relative_mean < 2e-1, f"CPU and Torch outputs diverge: relative_mean={diff_relative_mean:.4e} >= 2e-1"
The variable diff_relative_mean is a 0-dimensional PyTorch tensor, not a Python float. Formatting a tensor directly with a float format specifier such as :.4e can raise a TypeError or produce unexpected output in older PyTorch versions. Calling .item() first retrieves the underlying Python float, which is safer and more idiomatic.
Suggested change:
- assert diff_relative_mean < 2e-1, f"CPU and Torch outputs diverge: relative_mean={diff_relative_mean:.4e} >= 2e-1"
+ assert diff_relative_mean < 2e-1, f"CPU and Torch outputs diverge: relative_mean={diff_relative_mean.item():.4e} >= 2e-1"
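As a minimal sketch of the point above, the helper below unwraps a scalar-like value into a Python float before it is formatted. The names `to_python_float` and `ScalarStandIn` are hypothetical and not part of this PR. `ScalarStandIn` only imitates an object that exposes `.item()` but does not implement float format specifiers, which is the failure mode described for older PyTorch versions; it is used here so the example runs without PyTorch installed.

```python
def to_python_float(value):
    """Return a plain Python float for a scalar-like value.

    Objects exposing .item() (for example 0-dimensional torch tensors or
    NumPy scalars) are unwrapped first; anything else goes through float().
    """
    item = getattr(value, "item", None)
    return float(item()) if callable(item) else float(value)


class ScalarStandIn:
    """Hypothetical stand-in for a 0-dim tensor without __format__ support."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


diff_relative_mean = ScalarStandIn(0.25)

# Formatting the wrapper directly fails, because the default
# object.__format__ rejects any non-empty format spec.
try:
    f"{diff_relative_mean:.4e}"
    direct_format_failed = False
except TypeError:
    direct_format_failed = True

# Unwrapping first yields a value that accepts float format specifiers.
message = (
    "CPU and Torch outputs diverge: "
    f"relative_mean={to_python_float(diff_relative_mean):.4e} >= 2e-1"
)
print(message)
```

In the test itself, calling `.item()` inline as suggested is enough; a helper like this would mainly be useful if the same message is built from values that may be either tensors or plain floats.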
  diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
  print("diff = ", diff)
- assert diff < 0.001
+ assert diff < 0.001, f"MLP output mismatch: diff={diff:.4e} exceeds threshold 0.001"
The variable diff is a 0-dimensional PyTorch tensor, not a Python float. Formatting a tensor directly with a float format specifier such as :.4e can raise a TypeError or produce unexpected output in older PyTorch versions. Calling .item() first retrieves the underlying Python float, which is safer and more idiomatic.
Suggested change:
- assert diff < 0.001, f"MLP output mismatch: diff={diff:.4e} exceeds threshold 0.001"
+ assert diff < 0.001, f"MLP output mismatch: diff={diff.item():.4e} exceeds threshold 0.001"
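The check in this test divides the mean absolute error by the mean absolute value of the reference output. The sketch below reproduces that arithmetic on plain Python lists to make the metric concrete; `relative_mean_diff` and the sample values are hypothetical, and the real test operates on torch tensors rather than lists.

```python
def relative_mean_diff(output, reference):
    """Compute mean(|output - reference|) / mean(|reference|)."""
    if len(output) != len(reference) or not reference:
        raise ValueError("inputs must be non-empty and of equal length")
    count = len(reference)
    mean_abs_err = sum(abs(o - r) for o, r in zip(output, reference)) / count
    mean_abs_ref = sum(abs(r) for r in reference) / count
    return mean_abs_err / mean_abs_ref


# Hypothetical values: only the last element differs from the reference.
reference = [1.0, -2.0, 4.0, -8.0]
output = [1.0, -2.0, 4.0, -8.0078125]

diff = relative_mean_diff(output, reference)
threshold = 0.001

# The value is already a Python float here, so :.4e formats it directly.
assert diff < threshold, (
    f"MLP output mismatch: diff={diff:.4e} exceeds threshold {threshold}"
)
```

Because the error is normalized by the magnitude of the reference, the same threshold can be applied to outputs of different scales, though the metric is undefined when the reference is all zeros.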