Skip to content

Commit

Permalink
(fix): pass pre-commit hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Feb 26, 2024
1 parent 37c3757 commit a0f94a4
Show file tree
Hide file tree
Showing 14 changed files with 163 additions and 137 deletions.
9 changes: 9 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,12 @@ @article{srivatsan:20
year={2020},
publisher={American Association for the Advancement of Science}
}

@misc{klein2023generative,
title={Generative Entropic Neural Optimal Transport To Map Within and Across Spaces},
author={Dominik Klein and Théo Uscidda and Fabian Theis and Marco Cuturi},
year={2023},
eprint={2310.09254},
archivePrefix={arXiv},
primaryClass={stat.ML}
}
21 changes: 3 additions & 18 deletions src/moscot/backends/ott/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,11 @@
from ott.geometry import costs

from moscot.backends.ott._utils import sinkhorn_divergence
from moscot.backends.ott.output import (
OTTOutput,
OTTNeuralOutput,
GraphOTTOutput
)
from moscot.backends.ott.solver import (
GWSolver,
SinkhornSolver,
GENOTLinSolver,
)
from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
from moscot.costs import register_cost

__all__ = [
"OTTOutput",
"GWSolver",
"SinkhornSolver",
"OTTNeuralOutput",
"sinkhorn_divergence",
"GENOTLinSolver"
]
__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "OTTNeuralOutput", "sinkhorn_divergence", "GENOTLinSolver"]


register_cost("euclidean", backend="ott")(costs.Euclidean)
Expand Down
15 changes: 2 additions & 13 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,10 @@
from functools import partial
from typing import (
Any,
Dict,
Mapping,
Optional,
Tuple,
Type,
Union,
Literal
)

import optax
from typing import Any, Literal, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler, geometry, pointcloud
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div

from moscot._logging import logger
Expand Down
60 changes: 36 additions & 24 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union

import jaxlib.xla_extension as xla_ext

import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from ott.problems.linear import potentials
from ott.neural.flow_models.genot import (
GENOTBase, # TODO(ilan-gold): will neeed to update for ICNN's
)
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
from ott.neural.flow_models.genot import GENOTBase # TODO(ilan-gold): will neeed to update for ICNN's

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from moscot._types import ArrayLike, Device_t
from moscot.backends.ott._utils import get_nearest_neighbors
Expand Down Expand Up @@ -245,9 +244,20 @@ def _ones(self, n: int) -> ArrayLike: # noqa: D102


class OTTNeuralOutput(BaseNeuralOutput):
def __init__(self, model: GENOTBase):
"""Output wrapper for GENOTBase."""

def __init__(
self, model: GENOTBase
): # TODO(ilan-gold): Swap out once a more general implemenetation fo the base model is available.
"""Initialize `OTTNeuralOutput`.
Parameters
----------
model : GENOTBase
The OTT-Jax GENOTBase model
"""
self._model = model

def _project_transport_matrix(
self,
src_dist: ArrayLike,
Expand Down Expand Up @@ -289,7 +299,7 @@ def _project_transport_matrix(
if save_transport_matrix:
self._inverse_transport_matrix = tm
return tm

def project_transport_matrix( # type:ignore[override]
self,
src_cells: ArrayLike,
Expand Down Expand Up @@ -341,13 +351,9 @@ def project_transport_matrix( # type:ignore[override]
The projected transport matrix.
"""
src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells)
push = self.push if condition is None else lambda x : self.push(x, condition)
pull = self.pull if condition is None else lambda x : self.pull(x, condition)
func, src_dist, tgt_dist = (
(push, src_cells, tgt_cells)
if forward
else (pull, tgt_cells, src_cells)
)
push = self.push if condition is None else lambda x: self.push(x, condition)
pull = self.pull if condition is None else lambda x: self.pull(x, condition)
func, src_dist, tgt_dist = (push, src_cells, tgt_cells) if forward else (pull, tgt_cells, src_cells)
return self._project_transport_matrix(
src_dist=src_dist,
tgt_dist=tgt_dist,
Expand All @@ -359,8 +365,8 @@ def project_transport_matrix( # type:ignore[override]
length_scale=length_scale,
seed=seed,
)
def push(self, x: ArrayLike, cond: ArrayLike | None = None) -> ArrayLike: # type: ignore[override]

def push(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
"""Push distribution `x` conditioned on condition `cond`.
Parameters
Expand All @@ -378,7 +384,7 @@ def push(self, x: ArrayLike, cond: ArrayLike | None = None) -> ArrayLike: # typ
raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
return self._apply(x, cond=cond, forward=True)

def pull(self, x: ArrayLike, cond: ArrayLike | None = None) -> ArrayLike: # type: ignore[override]
def pull(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
"""Pull distribution `x` conditioned on condition `cond`.
Parameters
Expand All @@ -395,24 +401,30 @@ def pull(self, x: ArrayLike, cond: ArrayLike | None = None) -> ArrayLike: # typ
if x.ndim not in (1, 2):
raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
return self._apply(x, cond=cond, forward=False)
def _apply(self, x: ArrayLike, forward: bool, cond: ArrayLike | None = None) -> ArrayLike:

def _apply(self, x: ArrayLike, forward: bool, cond: Optional[ArrayLike] = None) -> ArrayLike:
return self._model.transport(x, condition=cond, forward=forward)


@property
def is_linear(self) -> bool: # noqa: D102
return True # TODO(ilan-gold): need to contribute something to ott-jax so this is resolvable from GENOTBase

@property
def shape(self) -> Tuple[int, int]:
"""%(shape)s."""
raise NotImplementedError()

def to(
self,
device: Optional[Device_t] = None,
) -> "OTTNeuralOutput":
"""Transfer the output to another device or change its data type.
Parameters
----------
device
If not `None`, the output will be transferred to `device`.
Returns
-------
The output on a saved on `device`.
Expand All @@ -432,8 +444,8 @@ def to(

# out = jax.device_put(self._model, device)
# return OTTNeuralOutput(out)
return self #TODO(ilan-gold) move model to device
return self # TODO(ilan-gold) move model to device

@property
def converged(self) -> bool:
"""%(converged)s."""
Expand Down
Loading

0 comments on commit a0f94a4

Please sign in to comment.