
fix: guard TorchTruncatedSVD against non-finite input and SVD non-convergence#1076

Closed
devangpratap wants to merge 3 commits into PriorLabs:main from devangpratap:fix/svd-non-finite-fallback


@devangpratap devangpratap commented Jun 26, 2026

Contributor

Issue

Fixes #1044

Motivation and Context

TorchTruncatedSVD.fit replaced only NaN values with 0 before computing the SVD, leaving any +/-inf in place. Those infinities flow straight into torch.linalg.svd / torch.svd_lowrank, which then either raise a non-convergence LinAlgError (crashing the preprocessing pipeline) or silently return non-finite components, depending on the LAPACK backend. This shows up on datasets with degenerate row-subsamples (for example near-constant columns at small row caps), as reported in #1044.

This masks all non-finite values (NaN and +/-inf) to 0 before the decomposition, matching how TorchSafeStandardScaler already treats inf. It also adds a numerical-stability fallback: if the single-precision SVD fails to converge, it retries the exact decomposition in float64; the randomized (lowrank) path falls back to the exact decomposition on the same error.
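
A minimal sketch of the two ideas described above, assuming plain PyTorch. The helper names `mask_non_finite` and `lowrank_or_exact_svd` are hypothetical and are not taken from this PR's diff; they only illustrate masking non-finite entries to 0 and falling back from the randomized path to the exact decomposition.

```python
import torch


def mask_non_finite(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: replace NaN, +inf and -inf with 0.
    # Finite entries are left unchanged.
    return torch.where(torch.isfinite(x), x, torch.zeros_like(x))


def lowrank_or_exact_svd(x: torch.Tensor, k: int):
    # Hypothetical helper: try the randomized SVD first and fall back
    # to the exact decomposition if it raises LinAlgError.
    try:
        u, s, v = torch.svd_lowrank(x, q=k)
        return u, s, v.T  # svd_lowrank returns V, not V^T
    except torch.linalg.LinAlgError:
        u, s, vh = torch.linalg.svd(x, full_matrices=False)
        return u[:, :k], s[:k], vh[:k]


x = torch.randn(8, 5)
x[0, 0] = float("inf")
x[1, 2] = float("-inf")
x[2, 3] = float("nan")

clean = mask_non_finite(x)
u, s, vh = lowrank_or_exact_svd(clean, k=4)
```

With the masking applied first, the decomposition only ever sees finite input, so the fallback branch is a safety net rather than the common path.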


Public API Changes

  • No Public API changes
  • Yes, Public API changes (Details below)

How Has This Been Tested?

Added two regression tests:

  • test__fit__non_finite_input_does_not_crash feeds +inf, -inf, and NaN into fit and asserts the returned components and singular values are finite.
  • test__exact_svd__retries_in_float64_on_non_convergence forces the single-precision SVD to raise LinAlgError and asserts the float64 fallback runs and round-trips the dtype.
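
For illustration, the second test's approach (forcing the single-precision SVD to fail) could look roughly like the sketch below. This is not the test code from the PR: `exact_svd`, `flaky_svd`, and `check_float64_retry` are hypothetical names, and the failure is simulated with `unittest.mock` rather than a real LAPACK non-convergence.

```python
from unittest import mock

import torch


def exact_svd(x: torch.Tensor):
    # Hypothetical stand-in for the helper under test:
    # retry in float64 if the first attempt does not converge.
    try:
        return torch.linalg.svd(x, full_matrices=False)
    except torch.linalg.LinAlgError:
        u, s, vh = torch.linalg.svd(x.double(), full_matrices=False)
        return u.to(x.dtype), s.to(x.dtype), vh.to(x.dtype)


def check_float64_retry():
    real_svd = torch.linalg.svd
    seen = []  # dtypes of every SVD call, in order

    def flaky_svd(a, *args, **kwargs):
        seen.append(a.dtype)
        if a.dtype == torch.float32:
            # Simulated non-convergence on the single-precision attempt.
            raise torch.linalg.LinAlgError("forced non-convergence")
        return real_svd(a, *args, **kwargs)

    x = torch.randn(6, 4)
    with mock.patch.object(torch.linalg, "svd", side_effect=flaky_svd):
        _, s, _ = exact_svd(x)
    return seen, s.dtype
```

Recording the dtype of each call makes it possible to assert both that the float64 retry actually ran and that the result was cast back to the input dtype.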

Ran the test_torch_svd.py suite locally (passes), plus pre-commit (ruff, ruff-format, mypy) clean.


Checklist

  • The changes have been tested locally.
  • Documentation has been updated (if the public API or usage changes). No public API or usage change.
  • A changelog entry has been added (see changelog/README.md), or "no changelog needed" label requested.
  • The code follows the project's style guidelines.
  • I have considered the impact of these changes on the public API.

…vergence

TorchTruncatedSVD.fit masked only NaN before the SVD, so +/-inf values
flowed into torch.linalg.svd / torch.svd_lowrank. Depending on the LAPACK
backend this either raised a non-convergence LinAlgError (crashing
preprocessing) or silently produced non-finite components, on datasets with
degenerate row-subsamples.

Mask all non-finite values (NaN and +/-inf) to 0 before the decomposition,
matching how TorchSafeStandardScaler already treats inf. Also add a float64
retry when the exact SVD fails to converge; the randomized (lowrank) path
falls back to the exact decomposition on the same error.

Fixes PriorLabs#1044
@devangpratap devangpratap requested a review from a team as a code owner June 26, 2026 19:32
@devangpratap devangpratap requested review from alanprior and removed request for a team June 26, 2026 19:32

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request improves the robustness of SVD computations by replacing non-finite values (NaN and infinities) with zeros and introducing a fallback mechanism (_exact_svd) that retries in double precision (float64) upon encountering a convergence error (LinAlgError). It also adds corresponding unit tests. The feedback suggests optimizing the fallback mechanism to avoid a redundant and expensive SVD computation if the input tensor is already in double precision.
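
The suggested optimization can be sketched as follows, assuming a helper shaped like the one the review describes. `exact_svd_guarded` is a hypothetical name, not the PR's `_exact_svd`; the point is only that a float64 input that already failed should re-raise instead of repeating the same decomposition.

```python
import torch


def exact_svd_guarded(x: torch.Tensor):
    # Hypothetical helper: float64 retry with a dtype guard.
    try:
        return torch.linalg.svd(x, full_matrices=False)
    except torch.linalg.LinAlgError:
        if x.dtype == torch.float64:
            # Already double precision: a retry would redo identical work.
            raise
        u, s, vh = torch.linalg.svd(x.to(torch.float64), full_matrices=False)
        # Cast back so callers see the dtype they passed in.
        return u.to(x.dtype), s.to(x.dtype), vh.to(x.dtype)


x = torch.randn(6, 4)
u, s, vh = exact_svd_guarded(x)
```

On the normal path the `try` succeeds on the first call, so the guard costs nothing.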


Comment thread src/tabpfn/preprocessing/torch/torch_svd.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@alanprior

Contributor

@devangpratap thank you for your contribution! Seems useful! We will review this and get back to you!

@alanprior

Contributor

Hello @devangpratap, we iterated on this internally, and we believe that the current code already covers this problem. Were you able to replicate the reported issue? Could you provide a failing script that motivated your fix?

Thanks!

@devangpratap

Contributor Author

Went back through this properly. Script that shows the actual behavior on main:

import torch
from tabpfn.preprocessing.torch.torch_svd import TorchTruncatedSVD

# small matrix, exact svd path: no crash, but silently returns NaN
x = torch.randn(8, 5)
x[0, 0] = float("inf")
out = TorchTruncatedSVD(n_components=4).fit(x)
print(torch.isnan(out["singular_values"]).any())  # True

# large matrix, svd_lowrank path: raises
x2 = torch.randn(5000, 300)
x2[0, 0] = float("inf")
TorchTruncatedSVD(n_components=128).fit(x2)
# _LinAlgError: input matrix is ill-conditioned or has too many repeated singular values

But you're right that this doesn't matter in practice. Through the actual pipeline it's dead code: TorchSafeStandardScaler already sanitizes inf before it reaches the SVD (steps.py:414-423, torch_svd.py:340-341). I checked by running the pre-fix torch_svd.py with an inf value through TorchAddSVDFeaturesStep end to end: no crash, no NaNs. So the isfinite() guard isn't fixing anything reachable today. It only matters if TorchTruncatedSVD is called directly, which nothing in the codebase does.

The other piece, the float64 retry when SVD fails to converge, is still worth keeping. I tried reproducing the SBDSDC error from #1044 on finite degenerate data (near-constant columns, rank-1, all-zeros, extreme scale) and couldn't trigger it either. It seems to be LAPACK/hardware dependent rather than something a script can force.

I'll cut the isfinite() masking and keep just the float64 retry, since that one's a no-op on the normal path and covers the actual failure mode in #1044.
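
A rough sketch of the kind of reproduction attempt described above, assuming plain PyTorch on CPU. This is not the script that was actually run: `degenerate_cases` and `count_failures` are hypothetical names, and the matrix sizes, scales, and trial count are illustrative choices.

```python
import torch


def degenerate_cases(n_rows: int = 64, n_cols: int = 8, seed: int = 0):
    # Build one matrix per degenerate shape mentioned in the comment.
    g = torch.Generator().manual_seed(seed)
    base = torch.randn(n_rows, n_cols, generator=g)
    return {
        "near_constant": torch.ones(n_rows, n_cols) + 1e-7 * base,
        "rank_one": torch.outer(base[:, 0], base[0, :]),
        "all_zeros": torch.zeros(n_rows, n_cols),
        "extreme_scale": base * 1e30,
    }


def count_failures(n_trials: int = 20):
    # Count LinAlgError occurrences per case across seeds.
    failures = {}
    for seed in range(n_trials):
        for name, x in degenerate_cases(seed=seed).items():
            try:
                torch.linalg.svd(x, full_matrices=False)
            except torch.linalg.LinAlgError:
                failures[name] = failures.get(name, 0) + 1
    return failures


failures = count_failures()
```

An empty `failures` dict on a given machine only says the error did not occur there; it does not rule out a backend-specific failure elsewhere.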

@devangpratap

Contributor Author

Closing this. Neither half of the fix protects against something reachable.

isfinite() guard: dead code. TorchSafeStandardScaler always sanitizes input before it reaches the SVD, verified end to end. I even tried overflowing the scaler's own mean/std to inf, and the output stayed finite.

float64 retry: couldn't reproduce the SBDSDC error from #1044. ~31k trials across degenerate matrix shapes, macOS and Linux, single and multi-threaded. Zero failures.

Reopen if the reporter comes back with a concrete traceback.


Development

Successfully merging this pull request may close these issues.

[Bug] torch.linalg.svd fails to converge in TorchTruncatedSVD on degenerate data

2 participants