fix: guard TorchTruncatedSVD against non-finite input and SVD non-convergence #1076
devangpratap wants to merge 3 commits into
…vergence

TorchTruncatedSVD.fit masked only NaN before the SVD, so +/-inf values flowed into torch.linalg.svd / torch.svd_lowrank. On datasets with degenerate row subsamples, this either raised a non-convergence LinAlgError (crashing preprocessing) or silently produced non-finite components, depending on the LAPACK backend. Mask all non-finite values (NaN and +/-inf) to 0 before the decomposition, matching how TorchSafeStandardScaler already treats inf. Also add a float64 retry for when the exact SVD fails to converge; on the same error, the randomized (lowrank) path falls back to the exact decomposition.

Fixes PriorLabs#1044
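A minimal sketch of two of the mechanisms this commit message describes (masking every non-finite value, and falling back from the randomized to the exact SVD), assuming real-valued 2-D float input. `truncated_svd_sketch` and its parameters are hypothetical and are not TabPFN's `TorchTruncatedSVD`; the float64 retry is left out to keep it short.

```python
import torch


def truncated_svd_sketch(x: torch.Tensor, k: int, use_lowrank: bool, niter: int = 4):
    """Hypothetical sketch (not TabPFN's code): mask non-finite values, then
    compute a rank-k SVD, falling back from the randomized to the exact path."""
    # torch.isfinite is False for NaN, +inf and -inf alike, so all three become 0.
    # Masking with torch.isnan alone would let +/-inf through to the SVD.
    x = torch.where(torch.isfinite(x), x, torch.zeros_like(x))
    if use_lowrank:
        try:
            # svd_lowrank returns V rather than V^H; transpose to match linalg.svd.
            U, S, V = torch.svd_lowrank(x, q=k, niter=niter)
            return U, S, V.mT
        except torch.linalg.LinAlgError:
            pass  # fall through to the exact decomposition below
    U, S, Vh = torch.linalg.svd(x, full_matrices=False)
    # Keep only the top-k triplets (singular values come back in descending order).
    return U[:, :k], S[:k], Vh[:k]


x = torch.randn(200, 30)
x[0, 0], x[1, 1], x[2, 2] = float("inf"), float("-inf"), float("nan")
U, S, Vh = truncated_svd_sketch(x, k=5, use_lowrank=True)
```

The masking step rebinds a local name instead of writing into the caller's tensor, so the original input keeps its non-finite entries.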
Code Review
This pull request improves the robustness of SVD computations by replacing non-finite values (NaN and infinities) with zeros and introducing a fallback mechanism (_exact_svd) that retries in double precision (float64) upon encountering a convergence error (LinAlgError). It also adds corresponding unit tests. The feedback suggests optimizing the fallback mechanism to avoid a redundant and expensive SVD computation if the input tensor is already in double precision.
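The review's suggestion (skip the float64 retry when the input is already float64) can be sketched as follows. This is a hypothetical helper, not the PR's `_exact_svd`; the injectable `svd_fn` parameter is an assumption added only so the failure path can be exercised without a real non-convergence, and real (non-complex) input is assumed.

```python
import torch


def exact_svd_with_retry(x: torch.Tensor, svd_fn=torch.linalg.svd):
    """Hypothetical helper (not TabPFN's _exact_svd): exact thin SVD of a real
    matrix, retried once in float64 if the first attempt fails to converge."""
    try:
        return tuple(svd_fn(x, full_matrices=False))
    except torch.linalg.LinAlgError:
        if x.dtype == torch.float64:
            # Already double precision: a retry would repeat the same expensive
            # decomposition on identical input, so surface the error instead.
            raise
        U, S, Vh = svd_fn(x.to(torch.float64), full_matrices=False)
        # Cast back so callers get the dtype they passed in.
        return U.to(x.dtype), S.to(x.dtype), Vh.to(x.dtype)


def flaky_svd(a, full_matrices=True):
    """Stand-in backend that 'fails to converge' in float32 only."""
    if a.dtype == torch.float32:
        raise torch.linalg.LinAlgError("simulated non-convergence")
    return torch.linalg.svd(a, full_matrices=full_matrices)


x = torch.randn(6, 4)
U, S, Vh = exact_svd_with_retry(x, svd_fn=flaky_svd)
print(U.dtype, S.dtype)
```

Re-raising on float64 input keeps the retry a strict precision upgrade; a caller that wants a further fallback can catch the same `LinAlgError` one level up.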
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@devangpratap thank you for your contribution! Seems useful! We will review this and get back to you!
Hello @devangpratap, we iterated on this internally, and we believe the current code already covers this problem. Were you able to reproduce the reported issue? Could you provide a working script that fails and that motivated your fix? Thanks!
Went back through this properly. Here is a script that shows the actual behavior on main:

```python
import torch
from tabpfn.preprocessing.torch.torch_svd import TorchTruncatedSVD

# Small matrix, exact SVD path: no crash, but silently returns NaN
x = torch.randn(8, 5)
x[0, 0] = float("inf")
out = TorchTruncatedSVD(n_components=4).fit(x)
print(torch.isnan(out["singular_values"]).any())  # True

# Large matrix, svd_lowrank path: raises
x2 = torch.randn(5000, 300)
x2[0, 0] = float("inf")
TorchTruncatedSVD(n_components=128).fit(x2)
# _LinAlgError: input matrix is ill-conditioned or has too many repeated singular values
```

But you're right that this doesn't matter in practice. Through the actual pipeline it's dead code: TorchSafeStandardScaler already sanitizes inf before it reaches the SVD (steps.py:414-423, torch_svd.py:340-341). I checked this by running the pre-fix torch_svd.py end to end through TorchAddSVDFeaturesStep with an inf value: no crash, no NaNs. So the isfinite() guard isn't fixing anything reachable today; it only matters if TorchTruncatedSVD is called directly, which nothing in the codebase does.

The other piece, the float64 retry when the SVD fails to converge, is still worth keeping. I tried reproducing the SBDSDC error from #1044 on finite degenerate data (near-constant columns, rank-1, all-zeros, extreme scale) and couldn't trigger it either; it seems LAPACK/hardware dependent rather than something a script can force. I'll cut the isfinite() masking and keep just the float64 retry, since it's a no-op on the normal path and covers the actual failure mode in #1044.
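The search for a reproduction on finite degenerate data can be sketched as a small harness. This is a hypothetical illustration, not the script used in this thread: the function names, shape ranges, noise level and the `1e18` scale are all assumptions, and the harness only tallies what happens without implying any particular outcome on a given LAPACK build.

```python
import torch


def degenerate_matrices(n_rows: int, n_cols: int, g: torch.Generator):
    """Yield (label, matrix) pairs for the degenerate cases named in the comment."""
    noise = torch.randn(n_rows, n_cols, generator=g)
    # Near-constant columns: every entry is 1.0 plus tiny noise.
    yield "near_constant", 1.0 + 1e-6 * noise
    # Exactly rank 1: outer product of two random vectors.
    u = torch.randn(n_rows, 1, generator=g)
    v = torch.randn(1, n_cols, generator=g)
    yield "rank_1", u @ v
    yield "all_zeros", torch.zeros(n_rows, n_cols)
    # Extreme but still finite scale for float32.
    yield "extreme_scale", 1e18 * noise


def stress_exact_svd(n_rounds: int, seed: int = 0) -> dict:
    """Run the exact SVD on random degenerate shapes and tally failures."""
    g = torch.Generator().manual_seed(seed)
    tally = {"trials": 0, "lin_alg_errors": 0, "non_finite": 0}
    for _ in range(n_rounds):
        n_rows = int(torch.randint(2, 65, (1,), generator=g))
        n_cols = int(torch.randint(2, 33, (1,), generator=g))
        for _label, x in degenerate_matrices(n_rows, n_cols, g):
            tally["trials"] += 1
            try:
                _, S, _ = torch.linalg.svd(x, full_matrices=False)
            except torch.linalg.LinAlgError:
                tally["lin_alg_errors"] += 1
                continue
            if not torch.isfinite(S).all():
                tally["non_finite"] += 1
    return tally


print(stress_exact_svd(n_rounds=50))
```

To compare single- and multi-threaded runs, call `torch.set_num_threads(1)` before the loop and rerun.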
Closing this. Neither half of the fix protects against anything reachable.

- isfinite() guard: dead code. TorchSafeStandardScaler always sanitizes input before it reaches the SVD, verified end to end. Even after deliberately overflowing the scaler's own mean/std to inf, the output stayed finite.
- float64 retry: couldn't reproduce the SBDSDC error from #1044. ~31k trials across degenerate matrix shapes, on macOS and Linux, single- and multi-threaded, with zero failures.

Reopen if the reporter comes back with a concrete traceback.
Issue
Fixes #1044
Motivation and Context
TorchTruncatedSVD.fit replaced only NaN values with 0 before computing the SVD, leaving any +/-inf in place. Those infinities flow straight into torch.linalg.svd / torch.svd_lowrank, which then either raise a non-convergence LinAlgError (crashing the preprocessing pipeline) or silently return non-finite components, depending on the LAPACK backend. This shows up on datasets with degenerate row subsamples (for example, near-constant columns at small row caps), as reported in #1044.

This PR masks all non-finite values (NaN and +/-inf) to 0 before the decomposition, matching how TorchSafeStandardScaler already treats inf. It also adds a numerical-stability fallback: if the single-precision SVD fails to converge, it retries the exact decomposition in float64, and the randomized (lowrank) path falls back to the exact decomposition on the same error.
Public API Changes
How Has This Been Tested?
Added two regression tests:

- test__fit__non_finite_input_does_not_crash feeds +inf, -inf, and NaN into fit and asserts that the returned components and singular values are finite.
- test__exact_svd__retries_in_float64_on_non_convergence forces the single-precision SVD to raise LinAlgError and asserts that the float64 fallback runs and round-trips the dtype.

Ran the test_torch_svd.py suite locally (passes); pre-commit (ruff, ruff-format, mypy) is clean.
Checklist
changelog/README.md), or "no changelog needed" label requested.