diff --git a/changelog/1076.fixed.md b/changelog/1076.fixed.md new file mode 100644 index 000000000..328b59fc5 --- /dev/null +++ b/changelog/1076.fixed.md @@ -0,0 +1 @@ +Fix crash in `TorchTruncatedSVD` on non-finite input by masking NaN and +/-inf before the SVD, and retry in float64 when the decomposition fails to converge on degenerate data. diff --git a/src/tabpfn/preprocessing/torch/torch_svd.py b/src/tabpfn/preprocessing/torch/torch_svd.py index 58ba9c384..0b7958131 100644 --- a/src/tabpfn/preprocessing/torch/torch_svd.py +++ b/src/tabpfn/preprocessing/torch/torch_svd.py @@ -7,6 +7,29 @@ import torch +def _exact_svd( + x: torch.Tensor, + n_components: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Exact truncated SVD, retrying in float64 if it fails to converge. + + On degenerate (near rank-deficient) data the default LAPACK/cuSOLVER + ``gesdd`` driver can raise a non-convergence ``LinAlgError``. Recomputing + in double precision is more numerically stable and succeeds where the + single-precision decomposition does not. + + Returns the components truncated to the top ``n_components``. + """ + try: + u, s, vh = torch.linalg.svd(x, full_matrices=False) + except torch.linalg.LinAlgError: + if x.dtype == torch.float64: + raise + u, s, vh = torch.linalg.svd(x.double(), full_matrices=False) + u, s, vh = u.to(x.dtype), s.to(x.dtype), vh.to(x.dtype) + return u[:, :n_components], s[:n_components], vh[:n_components, :] + + def _svd_flip_stable( u: torch.Tensor, v: torch.Tensor, @@ -84,9 +107,10 @@ def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]: n_samples, n_features = x.shape - # Handle NaN values by replacing with 0 for SVD computation - nan_mask = torch.isnan(x) - x_filled = torch.where(nan_mask, torch.zeros_like(x), x) + # Replace non-finite values (NaN and +/-inf) with 0 for the SVD + # computation. Leaving infinities in deterministically crashes + # torch.linalg.svd / torch.svd_lowrank with a non-convergence error. + x_filled = torch.where(torch.isfinite(x), x, torch.zeros_like(x)) # Clamp n_components to valid range n_components = min(self.n_components, n_samples, n_features) @@ -111,17 +135,20 @@ def fit(self, x: torch.Tensor) -> dict[str, torch.Tensor]: ) if use_lowrank: - # torch.svd_lowrank returns (U, S, V) with A ≈ U diag(S) V^T - u, s, v = torch.svd_lowrank(x_filled, q=q, niter=2) - # Truncate oversampling dimensions - u = u[:, :n_components] - s = s[:n_components] - vh = v[:, :n_components].T # V [n_features, q] → V^T [n_comp, n_features] + try: + # torch.svd_lowrank returns (U, S, V) with A ≈ U diag(S) V^T + u, s, v = torch.svd_lowrank(x_filled, q=q, niter=2) + # Truncate oversampling dimensions + u = u[:, :n_components] + s = s[:n_components] + vh = v[:, :n_components].T # V [n_feat, q] → V^T [n_comp, n_feat] + except torch.linalg.LinAlgError: + # Randomized SVD can fail to converge on degenerate (near + # rank-deficient) data; fall back to the exact, more robust + # decomposition. + u, s, vh = _exact_svd(x_filled, n_components) else: - u, s, vh = torch.linalg.svd(x_filled, full_matrices=False) - u = u[:, :n_components] - s = s[:n_components] - vh = vh[:n_components, :] + u, s, vh = _exact_svd(x_filled, n_components) # Apply sign flip for deterministic output. # We use the same convention as sklearn (u_based_decision=False: diff --git a/tests/test_torch_preprocessing/test_torch_svd.py b/tests/test_torch_preprocessing/test_torch_svd.py index a5b85b7ba..243655044 100644 --- a/tests/test_torch_preprocessing/test_torch_svd.py +++ b/tests/test_torch_preprocessing/test_torch_svd.py @@ -23,6 +23,7 @@ from tabpfn.preprocessing.torch.torch_svd import ( TorchSafeStandardScaler, TorchTruncatedSVD, + _exact_svd, _svd_flip_stable, ) @@ -215,6 +216,54 @@ def test__transform__nan_handling(self): # Second row (with NaN input) should be all NaN assert torch.isnan(x_transformed[1]).all() + def test__fit__non_finite_input_does_not_crash(self): + """Fit must not crash on +/-inf input and must return finite components. + + Regression for #1044: fit previously masked only NaN, so infinities + flowed into torch.linalg.svd / torch.svd_lowrank and raised a + non-convergence LinAlgError. + """ + svd = TorchTruncatedSVD(n_components=4) + + x = torch.randn(20, 8) + x[0, 0] = float("inf") + x[1, 3] = float("-inf") + x[2, 5] = float("nan") + + cache = svd.fit(x) + + assert torch.isfinite(cache["components"]).all() + assert torch.isfinite(cache["singular_values"]).all() + assert cache["components"].shape[0] == 4 + + def test__exact_svd__retries_in_float64_on_non_convergence(self, monkeypatch): + """_exact_svd falls back to a float64 recompute when the SVD fails. + + Regression for #1044: degenerate data can make the default driver + raise a non-convergence LinAlgError. Forcing the float32 call to + raise must trigger the double-precision retry, which succeeds and + casts results back to the input dtype. + """ + real_svd = torch.linalg.svd + calls = {"n": 0} + + def flaky_svd(x: torch.Tensor, *args: object, **kwargs: object) -> object: + calls["n"] += 1 + if x.dtype == torch.float32: + raise torch.linalg.LinAlgError("forced non-convergence") + return real_svd(x, *args, **kwargs) + + monkeypatch.setattr(torch.linalg, "svd", flaky_svd) + + x = torch.randn(12, 6, dtype=torch.float32) + u, s, vh = _exact_svd(x, n_components=3) + + assert calls["n"] == 2 # float32 raised, float64 retried + assert u.dtype == torch.float32 + assert s.dtype == torch.float32 + assert vh.shape == (3, 6) + assert torch.isfinite(s).all() + def test__transform__missing_cache_raises(self): """Test that transform raises ValueError with invalid cache.""" svd = TorchTruncatedSVD(n_components=5)