
Questions about IRLS #26

Open
zhanghua7099 opened this issue Jan 26, 2025 · 2 comments

@zhanghua7099

Hi!

I found that this project may support IRLS.

https://github.com/brentyi/jaxls?tab=readme-ov-file#jaxls

Does this mean we can use a robust kernel, such as the Huber or Cauchy function? I searched the project but couldn't find one. Could you please give me some advice?

Thank you for your help!

@brentyi
Owner

brentyi commented Jan 26, 2025

Hello!

For IRLS we just need to apply a loss-dependent weight term, right?
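
To spell that out: Gauss-Newton minimizes the sum of squared residuals, so scaling each residual by sqrt(w(r)) makes the effective cost w(r) · r², which is exactly the reweighted least-squares step. For Huber, the standard weight is

$$w(r) = \begin{cases} 1, & |r| \le k, \\ k / |r|, & |r| > k. \end{cases}$$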

You could do this by modifying the compute_residual= function that factors take as input. Here's a utility for Huber; I generated it with Claude and haven't tested it exhaustively, but it looks correct to me:

from typing import Callable, Literal, cast

import jax
import jax.numpy as jnp

def huber_loss_wrapper[T: Callable[..., jax.Array]](
    residual_fn: T, k: float = 1.345, mode: Literal["per_term", "norm"] = "per_term"
) -> T:
    """Wraps a residual function with Huber loss for IRLS.

    Args:
        residual_fn: Original residual function
        k: Huber loss parameter (default=1.345 for 95% efficiency on normal distribution)
        mode: Whether to apply weighting per term or based on residual norm

    Returns:
        A new residual function that applies Huber weighting
    """

    def huber_residual(*args):
        residual = residual_fn(*args)
        eps = 1e-8  # epsilon to prevent NaNs

        if mode == "per_term":
            # Element-wise Huber weights: 1 in the quadratic region, k/|r| beyond it.
            abs_r = jnp.abs(residual) + eps
            w = jax.lax.stop_gradient(jnp.where(abs_r > k, k / abs_r, 1.0))
        else:  # norm
            # A single weight computed from the norm of the whole residual vector.
            r_norm = jnp.linalg.norm(residual) + eps
            w = jax.lax.stop_gradient(jnp.where(r_norm > k, k / r_norm, 1.0))
        # Scale by sqrt(w) so the squared cost becomes w * r^2 (the IRLS weight).
        return residual * jnp.sqrt(w)  # type: ignore

    return cast(T, huber_residual)

Usage: jaxls.Factor(lambda ..., ...) can be replaced by jaxls.Factor(huber_loss_wrapper(lambda ...), ...).
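
Since you also mentioned Cauchy: the same pattern should work there. Here's an untested sketch (same caveats as the Huber version), using the standard Cauchy weight w(r) = 1 / (1 + (r/k)²) and the commonly cited 95%-efficiency constant k = 2.3849:

def cauchy_loss_wrapper[T: Callable[..., jax.Array]](
    residual_fn: T, k: float = 2.3849
) -> T:
    """Wraps a residual function with Cauchy (Lorentzian) loss for IRLS.

    Args:
        residual_fn: Original residual function
        k: Cauchy loss parameter (default=2.3849 for 95% efficiency on normal distribution)

    Returns:
        A new residual function that applies Cauchy weighting
    """

    def cauchy_residual(*args):
        residual = residual_fn(*args)
        # Per-term Cauchy weight: w = 1 / (1 + (r / k)^2), applied via sqrt(w)
        # so the squared cost becomes w * r^2.
        w = jax.lax.stop_gradient(1.0 / (1.0 + (residual / k) ** 2))
        return residual * jnp.sqrt(w)  # type: ignore

    return cast(T, cauchy_residual)

It drops in the same way: jaxls.Factor(cauchy_loss_wrapper(lambda ...), ...).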

@zhanghua7099
Author

Hi! Thank you very much for your quick reply.

I will test this function.
