Questions about the IRLS #26
Hello! For IRLS we just need to apply a loss-dependent weight term, right? You could do this by modifying the residual function:

```python
from typing import Callable, Literal, cast

import jax
import jax.numpy as jnp


# Requires Python 3.12+ (PEP 695 type-parameter syntax).
def huber_loss_wrapper[T: Callable[..., jax.Array]](
    residual_fn: T, k: float = 1.345, mode: Literal["per_term", "norm"] = "per_term"
) -> T:
    """Wraps a residual function with Huber loss for IRLS.

    Args:
        residual_fn: Original residual function.
        k: Huber loss parameter (the default of 1.345 gives 95% efficiency
            on normally distributed residuals).
        mode: Whether to apply weighting per residual term or based on the
            residual norm.

    Returns:
        A new residual function that applies Huber weighting.
    """

    def huber_residual(*args):
        residual = residual_fn(*args)
        eps = 1e-8  # Epsilon to prevent NaNs from division by zero.
        if mode == "per_term":
            abs_r = jnp.abs(residual) + eps
            w = jax.lax.stop_gradient(jnp.where(abs_r > k, k / abs_r, 1.0))
        else:  # norm
            r_norm = jnp.linalg.norm(residual) + eps
            w = jax.lax.stop_gradient(jnp.where(r_norm > k, k / r_norm, 1.0))
        # Scaling by sqrt(w) makes the squared residual equal to w * r**2.
        return residual * jnp.sqrt(w)  # type: ignore

    return cast(T, huber_residual)
```

Usage:
Hi! Thank you very much for your quick reply. I will test this function.
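The wrapper behaves as IRLS assuming the solver re-evaluates the residual function at every iteration (as Gauss-Newton and Levenberg-Marquardt solvers do): `stop_gradient` keeps the weights out of the Jacobian, so each step solves a least-squares problem weighted by the current `w(r) = 1` if `|r| <= k`, else `k / |r|`. To test that idea outside of jaxls, here is a standalone sketch of the same loop fitting a line to data with one outlier. The names `huber_weights` and `irls_line_fit` and the toy data are hypothetical, and residuals are assumed to already be on unit scale (no robust scale estimate is computed).

```python
import jax
import jax.numpy as jnp


def huber_weights(r: jax.Array, k: float = 1.345, eps: float = 1e-8) -> jax.Array:
    """Huber IRLS weights: 1 for |r| <= k, k / |r| otherwise."""
    abs_r = jnp.abs(r) + eps  # eps avoids division by zero, as in the wrapper
    return jnp.where(abs_r > k, k / abs_r, 1.0)


def irls_line_fit(
    x: jax.Array, y: jax.Array, k: float = 1.345, num_iters: int = 50
) -> jax.Array:
    """Robust fit of y ~ a * x + b by iteratively reweighted least squares."""
    A = jnp.stack([x, jnp.ones_like(x)], axis=1)  # design matrix with columns [x, 1]
    params = jnp.linalg.lstsq(A, y)[0]  # start from the ordinary least-squares fit
    for _ in range(num_iters):
        r = A @ params - y
        sqrt_w = jnp.sqrt(huber_weights(r, k))
        # Scaling each row by sqrt(w) turns sum(w * r**2) into a plain least-squares problem.
        params = jnp.linalg.lstsq(A * sqrt_w[:, None], y * sqrt_w)[0]
    return params


# Hypothetical toy data: points on y = 2x + 1 with one gross outlier.
x = jnp.arange(10.0)
y = (2.0 * x + 1.0).at[9].set(100.0)

ols = jnp.linalg.lstsq(jnp.stack([x, jnp.ones_like(x)], axis=1), y)[0]
robust = irls_line_fit(x, y)
print("OLS [a, b]:", ols)
print("IRLS (Huber) [a, b]:", robust)
```

The single outlier drags the ordinary least-squares slope far from 2, while the Huber fit stays close to (2, 1). It does not land exactly on (2, 1) because the Huber loss still lets the outlier pull with a bounded force of `k`.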
Hi!
I noticed that this project may support IRLS:
https://github.com/brentyi/jaxls?tab=readme-ov-file#jaxls
Does this mean we can use a robust kernel, such as the Huber or Cauchy function? I searched the project but couldn't find one. Could you please give me some advice?
Thank you for your help!
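On the Cauchy question: assuming the same approach as `huber_loss_wrapper` (hold the weight fixed with `stop_gradient` and scale the residual by the square root of the weight), only the weight function changes. For the Cauchy loss `rho(r) = (c**2 / 2) * log(1 + (r / c)**2)`, the IRLS weight is `w = 1 / (1 + (r / c)**2)`. A minimal per-term sketch follows. `cauchy_loss_wrapper` is a hypothetical name, not part of jaxls, and `c = 2.3849` is the commonly quoted tuning constant for 95% efficiency under Gaussian noise.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def cauchy_loss_wrapper(
    residual_fn: Callable[..., jax.Array], c: float = 2.3849
) -> Callable[..., jax.Array]:
    """Hypothetical Cauchy (Lorentzian) counterpart of huber_loss_wrapper.

    Scales each residual term by sqrt(w), with w = 1 / (1 + (r / c)**2).
    """

    def cauchy_residual(*args):
        residual = residual_fn(*args)
        # No epsilon needed: the weight never divides by |r|.
        # stop_gradient treats the weight as a constant within each solver step.
        w = jax.lax.stop_gradient(1.0 / (1.0 + (residual / c) ** 2))
        return residual * jnp.sqrt(w)

    return cauchy_residual


# Toy check: small residuals pass nearly unchanged, large ones are capped near c.
toy = cauchy_loss_wrapper(lambda r: r)
out = toy(jnp.array([0.1, 1.0, 100.0]))
print(out)
```

Unlike Huber, the Cauchy loss is non-convex, so IRLS can settle in a local minimum; a reasonable initialization (for example, a Huber solve first) helps.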