Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/1076.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a crash in `TorchTruncatedSVD` on non-finite input by replacing NaN and +/-inf with 0 before the SVD, and retry the decomposition in float64 when it fails to converge on degenerate data.
53 changes: 40 additions & 13 deletions src/tabpfn/preprocessing/torch/torch_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,29 @@
import torch


def _exact_svd(
    x: torch.Tensor,
    n_components: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute an exact SVD of ``x`` and keep the leading ``n_components``.

    The decomposition is first attempted in the dtype of ``x``. The default
    LAPACK/cuSOLVER ``gesdd`` driver can report non-convergence on degenerate
    (near rank-deficient) input; when that happens the decomposition is
    recomputed once in float64, which is numerically more stable, and the
    factors are cast back to the original dtype.

    Args:
        x: Matrix to decompose.
        n_components: Number of leading singular triplets to keep.

    Returns:
        ``(u, s, vh)`` where ``u`` keeps the first ``n_components`` columns,
        ``s`` the first ``n_components`` singular values and ``vh`` the first
        ``n_components`` rows.

    Raises:
        torch.linalg.LinAlgError: If ``x`` is already float64 and the
            decomposition fails, or if the float64 retry fails as well.
    """
    try:
        factors = torch.linalg.svd(x, full_matrices=False)
    except torch.linalg.LinAlgError:
        # Double precision is the last resort; nothing wider to retry with.
        if x.dtype == torch.float64:
            raise
        retried = torch.linalg.svd(x.double(), full_matrices=False)
        # Hand results back in the caller's dtype so downstream code is
        # unaffected by the temporary promotion.
        factors = tuple(factor.to(x.dtype) for factor in retried)

    left, sigma, right_h = factors
    top = slice(None, n_components)
    return left[:, top], sigma[top], right_h[top, :]
Comment thread
devangpratap marked this conversation as resolved.


def _svd_flip_stable(
u: torch.Tensor,
v: torch.Tensor,
Expand Down Expand Up @@ -84,9 +107,10 @@ def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]:

n_samples, n_features = x.shape

# Handle NaN values by replacing with 0 for SVD computation
nan_mask = torch.isnan(x)
x_filled = torch.where(nan_mask, torch.zeros_like(x), x)
# Replace non-finite values (NaN and +/-inf) with 0 for the SVD
# computation. Leaving infinities in deterministically crashes
# torch.linalg.svd / torch.svd_lowrank with a non-convergence error.
x_filled = torch.where(torch.isfinite(x), x, torch.zeros_like(x))

# Clamp n_components to valid range
n_components = min(self.n_components, n_samples, n_features)
Expand All @@ -111,17 +135,20 @@ def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
)

if use_lowrank:
# torch.svd_lowrank returns (U, S, V) with A ≈ U diag(S) V^T
u, s, v = torch.svd_lowrank(x_filled, q=q, niter=2)
# Truncate oversampling dimensions
u = u[:, :n_components]
s = s[:n_components]
vh = v[:, :n_components].T # V [n_features, q] → V^T [n_comp, n_features]
try:
# torch.svd_lowrank returns (U, S, V) with A ≈ U diag(S) V^T
u, s, v = torch.svd_lowrank(x_filled, q=q, niter=2)
# Truncate oversampling dimensions
u = u[:, :n_components]
s = s[:n_components]
vh = v[:, :n_components].T # V [n_feat, q] → V^T [n_comp, n_feat]
except torch.linalg.LinAlgError:
# Randomized SVD can fail to converge on degenerate (near
# rank-deficient) data; fall back to the exact, more robust
# decomposition.
u, s, vh = _exact_svd(x_filled, n_components)
else:
u, s, vh = torch.linalg.svd(x_filled, full_matrices=False)
u = u[:, :n_components]
s = s[:n_components]
vh = vh[:n_components, :]
u, s, vh = _exact_svd(x_filled, n_components)

# Apply sign flip for deterministic output.
# We use the same convention as sklearn (u_based_decision=False:
Expand Down
49 changes: 49 additions & 0 deletions tests/test_torch_preprocessing/test_torch_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tabpfn.preprocessing.torch.torch_svd import (
TorchSafeStandardScaler,
TorchTruncatedSVD,
_exact_svd,
_svd_flip_stable,
)

Expand Down Expand Up @@ -215,6 +216,54 @@ def test__transform__nan_handling(self):
# Second row (with NaN input) should be all NaN
assert torch.isnan(x_transformed[1]).all()

def test__fit__non_finite_input_does_not_crash(self):
    """Fitting on data with inf / -inf / NaN entries yields finite output.

    Regression for #1044: only NaN used to be masked in fit, so infinite
    entries reached torch.linalg.svd / torch.svd_lowrank and triggered a
    non-convergence LinAlgError.
    """
    reducer = TorchTruncatedSVD(n_components=4)

    # One entry of each non-finite kind, in otherwise random data.
    data = torch.randn(20, 8)
    data[0, 0] = float("inf")
    data[1, 3] = float("-inf")
    data[2, 5] = float("nan")

    fitted = reducer.fit(data)

    assert torch.isfinite(fitted["components"]).all()
    assert torch.isfinite(fitted["singular_values"]).all()
    # The requested number of components is returned.
    assert fitted["components"].shape[0] == 4

def test__exact_svd__retries_in_float64_on_non_convergence(self, monkeypatch):
    """A failing float32 SVD makes _exact_svd recompute in float64.

    Regression for #1044: the default driver may raise a non-convergence
    LinAlgError on degenerate data. Here every float32 call is forced to
    raise, so the only way to succeed is the double-precision retry, whose
    results must come back in the input dtype.
    """
    genuine_svd = torch.linalg.svd
    attempts = {"n": 0}

    def flaky_svd(matrix: torch.Tensor, *args: object, **kwargs: object) -> object:
        # Count every call, reject float32, delegate anything else.
        attempts["n"] += 1
        if matrix.dtype == torch.float32:
            raise torch.linalg.LinAlgError("forced non-convergence")
        return genuine_svd(matrix, *args, **kwargs)

    monkeypatch.setattr(torch.linalg, "svd", flaky_svd)

    sample = torch.randn(12, 6, dtype=torch.float32)
    left, sigma, right_h = _exact_svd(sample, n_components=3)

    assert attempts["n"] == 2  # one failed float32 call + one float64 retry
    assert left.dtype == torch.float32
    assert sigma.dtype == torch.float32
    assert right_h.shape == (3, 6)
    assert torch.isfinite(sigma).all()

def test__transform__missing_cache_raises(self):
"""Test that transform raises ValueError with invalid cache."""
svd = TorchTruncatedSVD(n_components=5)
Expand Down
Loading