diff --git a/src/ott/neural/methods/conditional_monge_gap.py b/src/ott/neural/methods/conditional_monge_gap.py
new file mode 100644
index 000000000..3b52a366e
--- /dev/null
+++ b/src/ott/neural/methods/conditional_monge_gap.py
@@ -0,0 +1,91 @@
+import collections
+import functools
+
+from pathlib import Path
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import optax
+from flax.core import frozen_dict
+from flax.training import train_state
+from flax.training.orbax_utils import save_args_from_target
+from jax.tree_util import tree_map
+from orbax.checkpoint import PyTreeCheckpointer
+from ott.neural.methods.monge_gap import monge_gap_from_samples
+from ott.solvers.linear import sinkhorn
+from ott.neural.networks.conditional_perturbation_network import (
+    ConditionalPerturbationNetwork,
+)
+
+T = TypeVar("T", bound="ConditionalMongeGapEstimator")
+
+
+def cmonge_gap_from_samples(
+    source: jnp.ndarray,
+    target: jnp.ndarray,
+    condition: jnp.ndarray,
+    return_output: bool = False,
+    **kwargs: Any,
+) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]:
+    r"""Conditional Monge gap, instantiated from samples before / after applying a map.
+
+    .. math::
+        \frac{1}{K} \sum_{k=1}^{K} \left[\frac{1}{n_k} \sum_{i=1}^{n_k} c(x_i^k, y_i^k) -
+        W_{c, \varepsilon}\left(\frac{1}{n_k}\sum_i \delta_{x_i^k},
+        \frac{1}{n_k}\sum_i \delta_{y_i^k}\right)\right]
+
+    where :math:`W_{c, \varepsilon}` is an :term:`entropy-regularized optimal transport`
+    cost, the :attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost`, :math:`K`
+    the number of distinct conditions and :math:`n_k` the number of samples per condition.
+
+    Args:
+        source: samples from the first measure, array of shape ``[n, d]``.
+        target: samples from the second measure, array of shape ``[n, d]``.
+        condition: condition of each source-target sample pair,
+            an integer array of shape ``[n]``.
+        return_output: boolean to also return the
+            :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` per condition.
+        kwargs: keyword arguments passed on to
+            :func:`~ott.neural.methods.monge_gap.monge_gap_from_samples`.
+
+    Returns:
+        The Monge gap value averaged over all conditions and, if ``return_output``
+        is ``True``, also the per-condition Sinkhorn outputs and Monge gap values.
+    """
+    all_gaps = []
+    all_outs = []
+    for c in jnp.unique(condition):
+        c_target = target[condition == c]
+        c_source = source[condition == c]
+
+        if return_output:
+            monge_gap, out = monge_gap_from_samples(
+                target=c_target, source=c_source, return_output=True, **kwargs
+            )
+            all_outs.append(out)
+        else:
+            monge_gap = monge_gap_from_samples(
+                target=c_target, source=c_source, return_output=False, **kwargs
+            )
+        all_gaps.append(monge_gap)
+
+    condition_monge_gap = sum(all_gaps) / len(all_gaps)  # average over conditions
+
+    return (
+        (condition_monge_gap, all_outs, all_gaps)
+        if return_output
+        else condition_monge_gap
+    )
diff --git a/src/ott/neural/networks/conditional_perturbation_network.py b/src/ott/neural/networks/conditional_perturbation_network.py
new file mode 100644
index 000000000..81fefff76
--- /dev/null
+++ b/src/ott/neural/networks/conditional_perturbation_network.py
@@ -0,0 +1,117 @@
+from typing import Any, Callable, Optional, Sequence, Tuple
+
+import flax.linen as nn
+import jax.numpy as jnp
+import optax
+from ott.neural.networks.potentials import (
+    BasePotential,
+    PotentialTrainState,
+)
+
+
+class ConditionalPerturbationNetwork(BasePotential):
+    dim_hidden: Optional[Sequence[int]] = None
+    dim_data: Optional[int] = None
+    dim_cond: Optional[int] = None  # Total dim of all concatenated contexts
+    # Same length as context_entity_bonds if embed_cond_equal is False
+    # (if True, first item is the size of the deep set layer, rest is ignored)
+    dim_cond_map: Sequence[int] = (50,)
+    act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.gelu
+    is_potential: bool = False
+    layer_norm: bool = False
+    embed_cond_equal: bool = (
+        False  # Whether all context variables should be treated as a set or not
+    )
+    context_entity_bonds: Sequence[Tuple[int, int]] = (
+        (0, 10),
+        (0, 11),
+    )  # Start/stop index per modality
+    num_contexts: int = 2
+
+    @nn.compact
+    def __call__(
+        self, x: jnp.ndarray, c: jnp.ndarray
+    ) -> jnp.ndarray:
+        """Apply the conditional perturbation to the input.
+
+        Args:
+            x: Input data of shape ``[batch, dim_data]``.
+            c: Context of shape ``[batch, dim_cond]``, with possibly several
+                modalities concatenated, as given by ``context_entity_bonds``.
+
+        Returns:
+            The perturbed data of shape ``[batch, dim_data]``.
+        """
+        n_input = x.shape[-1]
+
+        # Chunk the context into one block per modality.
+        contexts = [
+            c[:, e[0] : e[1]]
+            for i, e in enumerate(self.context_entity_bonds)
+            if i < self.num_contexts
+        ]
+
+        if not self.embed_cond_equal:
+            # Each context is processed by a different layer,
+            # good for combining modalities.
+            assert len(self.context_entity_bonds) == len(self.dim_cond_map), (
+                "Length of context entity bonds and context map sizes have to "
+                f"match: {self.context_entity_bonds} != {self.dim_cond_map}"
+            )
+
+            layers = [
+                nn.Dense(self.dim_cond_map[i], use_bias=True)
+                for i in range(len(contexts))
+            ]
+            embeddings = [
+                self.act_fn(layers[i](context))
+                for i, context in enumerate(contexts)
+            ]
+            cond_embedding = jnp.concatenate(embeddings, axis=1)
+        else:
+            # We can use any number of contexts from the same modality,
+            # via a permutation-invariant deep set layer.
+            sizes = [ctx.shape[-1] for ctx in contexts]
+            if len(set(sizes)) != 1:
+                raise ValueError(
+                    "For embedding a set, all contexts need the same length, "
+                    f"not {sizes}"
+                )
+            layer = nn.Dense(self.dim_cond_map[0], use_bias=True)
+            embeddings = [self.act_fn(layer(context)) for context in contexts]
+            # Average along the stacked dimension
+            # (alternatives like summing are possible).
+            cond_embedding = jnp.mean(jnp.stack(embeddings), axis=0)
+
+        z = jnp.concatenate((x, cond_embedding), axis=1)
+        if self.layer_norm:
+            n = nn.LayerNorm()
+            z = n(z)
+
+        for n_hidden in self.dim_hidden:
+            wx = nn.Dense(n_hidden, use_bias=True)
+            z = self.act_fn(wx(z))
+        wx = nn.Dense(n_input, use_bias=True)
+
+        return x + wx(z)
+
+    def create_train_state(
+        self,
+        rng: jnp.ndarray,
+        optimizer: optax.OptState,
+        dim_data: int,
+        dim_cond: int,
+        **kwargs: Any,
+    ) -> PotentialTrainState:
+        """Create the initial `TrainState`."""
+        c = jnp.ones((1, dim_cond))  # (n_batch, dim_cond)
+        x = jnp.ones((1, dim_data))  # (n_batch, dim_data)
+        params = self.init(rng, x=x, c=c)["params"]
+        return PotentialTrainState.create(
+            apply_fn=self.apply,
+            params=params,
+            tx=optimizer,
+            potential_value_fn=self.potential_value_fn,
+            potential_gradient_fn=self.potential_gradient_fn,
+            **kwargs,
+        )
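For reviewers, a minimal usage sketch of how the two new pieces are intended to fit together. It is not part of the change itself; all shapes, context bounds and hyper-parameters below are made up for illustration, and only the APIs added in this diff plus standard jax/optax calls are used.

import jax
import optax

from ott.neural.methods.conditional_monge_gap import cmonge_gap_from_samples
from ott.neural.networks.conditional_perturbation_network import (
    ConditionalPerturbationNetwork,
)

rng = jax.random.PRNGKey(0)
rng, key_x, key_c, key_l = jax.random.split(rng, 4)
n, dim_data, dim_cond = 128, 10, 12

# Toy batch: two context modalities concatenated into c (columns 0:4 and 4:12),
# plus one integer condition label per sample for the conditional Monge gap.
source = jax.random.normal(key_x, (n, dim_data))
cond_features = jax.random.normal(key_c, (n, dim_cond))
cond_labels = jax.random.randint(key_l, (n,), minval=0, maxval=3)

model = ConditionalPerturbationNetwork(
    dim_hidden=[64, 64],
    dim_data=dim_data,
    dim_cond=dim_cond,
    dim_cond_map=(16, 16),
    context_entity_bonds=((0, 4), (4, 12)),
    num_contexts=2,
)
state = model.create_train_state(rng, optax.adam(1e-3), dim_data, dim_cond)

# Push source samples through the perturbation network, conditioned on c.
mapped = state.apply_fn({"params": state.params}, x=source, c=cond_features)

# Monge gap regularizer between the source samples and their images,
# averaged over the distinct condition labels.
gap = cmonge_gap_from_samples(
    source=source, target=mapped, condition=cond_labels
)

Since cmonge_gap_from_samples partitions the batch with jnp.unique and boolean masking, the per-condition shapes are data-dependent, so the sketch keeps that call outside jax.jit.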