Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Not passing epsilon with kernel matrix causes recursion error #349

Open
michalk8 opened this issue Mar 30, 2023 · 2 comments
Open

Not passing epsilon with kernel matrix causes recursion error #349

michalk8 opened this issue Mar 30, 2023 · 2 comments
Assignees
Labels
bug Something isn't working

Comments

@michalk8
Copy link
Collaborator

michalk8 commented Mar 30, 2023

Code to reproduce; most likely introduced in #310:

import jax.numpy as jnp
import ott
x = jnp.ones((10, 12))
ott.geometry.geometry.Geometry(kernel_matrix=x).cost_matrix

Traceback:

RecursionError                            Traceback (most recent call last)
Cell In [1], line 4
      2 import ott
      3 x = jnp.ones((10, 12))
----> 4 ott.geometry.geometry.Geometry(kernel_matrix=x).cost_matrix

File ~/Projects/ott/src/ott/geometry/geometry.py:111, in Geometry.cost_matrix(self)
    109   cost = -jnp.log(self._kernel_matrix + eps)
    110   cost *= self.inv_scale_cost
--> 111   return cost if self._epsilon_init is None else self.epsilon * cost
    112 return self._cost_matrix * self.inv_scale_cost

File ~/Projects/ott/src/ott/geometry/geometry.py:155, in Geometry.epsilon(self)
    152 @property
    153 def epsilon(self) -> float:
    154   """Epsilon regularization value."""
--> 155   return self._epsilon.target

File ~/Projects/ott/src/ott/geometry/geometry.py:143, in Geometry._epsilon(self)
    141 use_mean_scale = rel is True or (rel is None and target is None)
    142 if scale_eps is None and use_mean_scale:
--> 143   scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
    145 if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
    146   return self._epsilon_init.set(scale_epsilon=scale_eps)

File ~/Projects/ott/src/ott/geometry/geometry.py:123, in Geometry.mean_cost_matrix(self)
    120 @property
    121 def mean_cost_matrix(self) -> float:
    122   """Mean of the :attr:`cost_matrix`."""
--> 123   tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
    124   return jnp.sum(tmp * self._m_normed_ones)

File ~/Projects/ott/src/ott/geometry/geometry.py:576, in Geometry.apply_cost(self, arr, axis, fn, **kwargs)
    573   arr = arr.reshape(-1, 1)
    575 app = functools.partial(self._apply_cost_to_vec, axis=axis, fn=fn, **kwargs)
--> 576 return jax.vmap(app, in_axes=1, out_axes=1)(arr)

    [... skipping hidden 3 frame]

File ~/Projects/ott/src/ott/geometry/geometry.py:596, in Geometry._apply_cost_to_vec(self, vec, axis, fn, **_)
    578 def _apply_cost_to_vec(
    579     self,
    580     vec: jnp.ndarray,
   (...)
    583     **_: Any,
    584 ) -> jnp.ndarray:
    585   """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector.
    586 
    587   Args:
   (...)
    594     A jnp.ndarray corresponding to cost x vector
    595   """
--> 596   matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix
    597   matrix = fn(matrix) if fn is not None else matrix
    598   return jnp.dot(matrix, vec)

File ~/Projects/ott/src/ott/geometry/geometry.py:111, in Geometry.cost_matrix(self)
    109   cost = -jnp.log(self._kernel_matrix + eps)
    110   cost *= self.inv_scale_cost
--> 111   return cost if self._epsilon_init is None else self.epsilon * cost
    112 return self._cost_matrix * self.inv_scale_cost

File ~/Projects/ott/src/ott/geometry/geometry.py:155, in Geometry.epsilon(self)
    152 @property
    153 def epsilon(self) -> float:
    154   """Epsilon regularization value."""
--> 155   return self._epsilon.target

File ~/Projects/ott/src/ott/geometry/geometry.py:143, in Geometry._epsilon(self)
    141 use_mean_scale = rel is True or (rel is None and target is None)
    142 if scale_eps is None and use_mean_scale:
--> 143   scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
    145 if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
    146   return self._epsilon_init.set(scale_epsilon=scale_eps)

File ~/Projects/ott/src/ott/geometry/geometry.py:123, in Geometry.mean_cost_matrix(self)
    120 @property
    121 def mean_cost_matrix(self) -> float:
    122   """Mean of the :attr:`cost_matrix`."""
--> 123   tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
    124   return jnp.sum(tmp * self._m_normed_ones)

File ~/Projects/ott/src/ott/geometry/geometry.py:576, in Geometry.apply_cost(self, arr, axis, fn, **kwargs)
    573   arr = arr.reshape(-1, 1)
    575 app = functools.partial(self._apply_cost_to_vec, axis=axis, fn=fn, **kwargs)
--> 576 return jax.vmap(app, in_axes=1, out_axes=1)(arr)

    [... skipping hidden 3 frame]

File ~/Projects/ott/src/ott/geometry/geometry.py:596, in Geometry._apply_cost_to_vec(self, vec, axis, fn, **_)
    578 def _apply_cost_to_vec(
    579     self,
    580     vec: jnp.ndarray,
   (...)
    583     **_: Any,
    584 ) -> jnp.ndarray:
    585   """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector.
    586 
    587   Args:
   (...)
    594     A jnp.ndarray corresponding to cost x vector
    595   """
--> 596   matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix
    597   matrix = fn(matrix) if fn is not None else matrix
    598   return jnp.dot(matrix, vec)

    [... skipping similar frames: Geometry._epsilon at line 143 (294 times), Geometry.cost_matrix at line 111 (294 times), Geometry.epsilon at line 155 (294 times), Geometry.mean_cost_matrix at line 123 (294 times), Geometry._apply_cost_to_vec at line 596 (293 times), Geometry.apply_cost at line 576 (293 times), WrappedFun.call_wrapped at line 165 (293 times), api_boundary.<locals>.reraise_with_filtered_traceback at line 166 (293 times), vmap.<locals>.vmap_f at line 1773 (293 times)]

File ~/Projects/ott/src/ott/geometry/geometry.py:576, in Geometry.apply_cost(self, arr, axis, fn, **kwargs)
    573   arr = arr.reshape(-1, 1)
    575 app = functools.partial(self._apply_cost_to_vec, axis=axis, fn=fn, **kwargs)
--> 576 return jax.vmap(app, in_axes=1, out_axes=1)(arr)

    [... skipping hidden 3 frame]

File ~/Projects/ott/src/ott/geometry/geometry.py:596, in Geometry._apply_cost_to_vec(self, vec, axis, fn, **_)
    578 def _apply_cost_to_vec(
    579     self,
    580     vec: jnp.ndarray,
   (...)
    583     **_: Any,
    584 ) -> jnp.ndarray:
    585   """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector.
    586 
    587   Args:
   (...)
    594     A jnp.ndarray corresponding to cost x vector
    595   """
--> 596   matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix
    597   matrix = fn(matrix) if fn is not None else matrix
    598   return jnp.dot(matrix, vec)

File ~/Projects/ott/src/ott/geometry/geometry.py:111, in Geometry.cost_matrix(self)
    109   cost = -jnp.log(self._kernel_matrix + eps)
    110   cost *= self.inv_scale_cost
--> 111   return cost if self._epsilon_init is None else self.epsilon * cost
    112 return self._cost_matrix * self.inv_scale_cost

File ~/Projects/ott/src/ott/geometry/geometry.py:155, in Geometry.epsilon(self)
    152 @property
    153 def epsilon(self) -> float:
    154   """Epsilon regularization value."""
--> 155   return self._epsilon.target

File ~/Projects/ott/src/ott/geometry/geometry.py:143, in Geometry._epsilon(self)
    141 use_mean_scale = rel is True or (rel is None and target is None)
    142 if scale_eps is None and use_mean_scale:
--> 143   scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
    145 if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
    146   return self._epsilon_init.set(scale_epsilon=scale_eps)

File ~/Projects/ott/src/ott/geometry/geometry.py:123, in Geometry.mean_cost_matrix(self)
    120 @property
    121 def mean_cost_matrix(self) -> float:
    122   """Mean of the :attr:`cost_matrix`."""
--> 123   tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
    124   return jnp.sum(tmp * self._m_normed_ones)

File ~/Projects/ott/src/ott/geometry/geometry.py:862, in Geometry._n_normed_ones(self)
    860 """Normalized array of shape ``[num_a,]``."""
    861 mask = self.src_mask
--> 862 arr = jnp.ones(self.shape[0]) if mask is None else mask
    863 return arr / jnp.sum(arr)

File ~/.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2150, in ones(shape, dtype)
   2148 shape = canonicalize_shape(shape)
   2149 dtypes.check_user_dtype_supported(dtype, "ones")
-> 2150 return lax.full(shape, 1, _jnp_dtype(dtype))

    [... skipping hidden 17 frame]

File ~/.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/config.py:241, in Config.define_bool_state.<locals>.get_state(self)
    240 def get_state(self):
--> 241   val = _thread_local_state.__dict__.get(name, unset)
    242   return val if val is not unset else self._read(name)
@michalk8 michalk8 added the bug Something isn't working label Mar 30, 2023
@michalk8 michalk8 self-assigned this Mar 30, 2023
@marcocuturi
Copy link
Contributor

Maybe epsilon should be set by default to 1.0 in that case?

@michalk8
Copy link
Collaborator Author

Geoms will be refactored in a future release.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

2 participants