Adding conditional monge gap #605

Open. Wants to merge 3 commits into base: main.
91 changes: 91 additions & 0 deletions src/ott/neural/methods/conditional_monge_gap.py
@@ -0,0 +1,91 @@
import collections
import functools

from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterator,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.core import frozen_dict
from flax.training import train_state
from flax.training.orbax_utils import save_args_from_target
from jax.tree_util import tree_map
from orbax.checkpoint import PyTreeCheckpointer
from ott.neural.methods.monge_gap import monge_gap_from_samples
from ott.solvers.linear import sinkhorn
from ott.neural.networks.conditional_perturbation_network import (
ConditionalPerturbationNetwork,
)

T = TypeVar("T", bound="ConditionalMongeGapEstimator")


def cmonge_gap_from_samples(
source: jnp.ndarray,
target: jnp.ndarray,
condition: jnp.ndarray,
return_output: bool = False,
**kwargs: Any,
) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]:
r"""Monge gap, instantiated in terms of samples before / after applying map.

.. math::
\frac{1}{K} \sum_{k=1}^{K} \left[\frac{1}{n_k} \sum_{i=1}^{n_k} c(x_i, y_i) -
W_{c, \varepsilon}\left(\frac{1}{n_k}\sum_i \delta_{x_i},
\frac{1}{n_k}\sum_i \delta_{y_i}\right)\right]

where the inner term is the Monge gap evaluated on the :math:`n_k` source-target
samples of condition :math:`k` and averaged over the :math:`K` conditions, and
:math:`W_{c, \varepsilon}` is an :term:`entropy-regularized optimal transport`
cost, the :attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost`.

Args:
source: samples from the first measure, array of shape ``[n, d]``.
target: samples from the second measure, array of shape ``[n, d]``.
condition: array indicating the condition of each source-target sample pair,
integer array of shape ``[n]``.
return_output: boolean to also return the per-condition
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` objects.
kwargs: keyword arguments passed on to
:func:`~ott.neural.methods.monge_gap.monge_gap_from_samples`.

Returns:
The Monge gap value averaged over all conditions and, if ``return_output``
is ``True``, also the list of per-condition
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput` objects and the list
of per-condition Monge gap values.
"""
all_gaps = []
all_outs = []
for c in jnp.unique(condition):
c_target = target[condition == c]
c_source = source[condition == c]

if return_output:
monge_gap, out = monge_gap_from_samples(
target=c_target, source=c_source, return_output=True, **kwargs
)
all_outs.append(out)
else:
monge_gap = monge_gap_from_samples(
target=c_target, source=c_source, return_output=False, **kwargs
)
all_gaps.append(monge_gap)

condition_monge_gap = sum(all_gaps) / len(all_gaps) # average

return (
(condition_monge_gap, all_outs, all_gaps)
if return_output
else condition_monge_gap
)
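
A minimal usage sketch of ``cmonge_gap_from_samples``; the arrays, shapes and condition labels below are illustrative assumptions, not values from the PR:

import jax
import jax.numpy as jnp

from ott.neural.methods.conditional_monge_gap import cmonge_gap_from_samples

rng_s, rng_t = jax.random.split(jax.random.PRNGKey(0))
# Two conditions (0 and 1) with 64 source/target pairs each, 5 features.
source = jax.random.normal(rng_s, (128, 5))
target = source + 0.1 * jax.random.normal(rng_t, (128, 5))  # stand-in for mapped samples
condition = jnp.concatenate([jnp.zeros(64, dtype=int), jnp.ones(64, dtype=int)])

# Monge gap averaged over both conditions.
gap = cmonge_gap_from_samples(source, target, condition)
# Additionally return per-condition Sinkhorn outputs and per-condition gaps.
gap, outs, gaps = cmonge_gap_from_samples(source, target, condition, return_output=True)
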
117 changes: 117 additions & 0 deletions src/ott/neural/networks/conditional_perturbation_network.py
@@ -0,0 +1,117 @@
from typing import Any, Callable, Iterable, Sequence, Tuple

import flax.linen as nn
import jax.numpy as jnp
import optax
from ott.neural.networks.potentials import (
BasePotential,
PotentialTrainState,
)


class ConditionalPerturbationNetwork(BasePotential):
"""Residual perturbation network mapping ``x`` to ``x + f([x, embed(c)])``, conditioned on a context ``c``."""

dim_hidden: Sequence[int] = None
dim_data: int = None
dim_cond: int = None # Full dimension of all context variables concatenated
# Same length as context_entity_bonds if embed_cond_equal is False
# (if True, first item is size of deep set layer, rest is ignored)
dim_cond_map: Iterable[int] = (50,)
act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.gelu
is_potential: bool = False
layer_norm: bool = False
# Whether all context variables should be treated as a set or not
embed_cond_equal: bool = False
# Start/stop index per modality
context_entity_bonds: Iterable[Tuple[int, int]] = ((0, 10), (0, 11))
num_contexts: int = 2

@nn.compact
def __call__(
self, x: jnp.ndarray, c: jnp.ndarray
) -> jnp.ndarray: # noqa: D102
"""
Args:
x (jnp.ndarray): The input data of shape bs x dim_data
c (jnp.ndarray): The context of shape bs x dim_cond with
possibly different modalities
concatenated, as can be specified via context_entity_bonds.

Returns:
jnp.ndarray: _description_
"""
n_input = x.shape[-1]

# Chunk the inputs
contexts = [
c[:, e[0] : e[1]]
for i, e in enumerate(self.context_entity_bonds)
if i < self.num_contexts
]
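
# Illustration of the chunking above (hypothetical values, not the defaults):
# with context_entity_bonds=((0, 10), (10, 15)) and num_contexts=2, a context
# c of shape [bs, 15] is split into a [bs, 10] chunk and a [bs, 5] chunk,
# each embedded separately below.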

if not self.embed_cond_equal:
# Each context is processed by a different layer,
# good for combining modalities
assert len(self.context_entity_bonds) == len(self.dim_cond_map), (
"Length of context entity bonds and context map sizes have to "
f"match: {self.context_entity_bonds} != {self.dim_cond_map}"
)

layers = [
nn.Dense(self.dim_cond_map[i], use_bias=True)
for i in range(len(contexts))
]
embeddings = [
self.act_fn(layers[i](context))
for i, context in enumerate(contexts)
]
cond_embedding = jnp.concatenate(embeddings, axis=1)
else:
# We can use any number of contexts from the same modality,
# via a permutation-invariant deep set layer.
sizes = [c.shape[-1] for c in contexts]
if len(set(sizes)) != 1:
raise ValueError(
"For embedding a set, all contexts need the same length, "
f"not {sizes}"
)
layer = nn.Dense(self.dim_cond_map[0], use_bias=True)
embeddings = [self.act_fn(layer(context)) for context in contexts]
# Average along stacked dimension
# (alternatives like summing are possible)
cond_embedding = jnp.mean(jnp.stack(embeddings), axis=0)

z = jnp.concatenate((x, cond_embedding), axis=1)
if self.layer_norm:
n = nn.LayerNorm()
z = n(z)

for n_hidden in self.dim_hidden:
wx = nn.Dense(n_hidden, use_bias=True)
z = self.act_fn(wx(z))
wx = nn.Dense(n_input, use_bias=True)

return x + wx(z)

def create_train_state(
self,
rng: jnp.ndarray,
optimizer: optax.OptState,
dim_data: int,
dim_cond: int,
**kwargs: Any,
) -> PotentialTrainState:
"""Create initial `TrainState`."""
c = jnp.ones((1, dim_cond))  # (n_batch, dim_cond)
x = jnp.ones((1, dim_data))  # (n_batch, dim_data)
params = self.init(rng, x=x, c=c)["params"]
return PotentialTrainState.create(
apply_fn=self.apply,
params=params,
tx=optimizer,
potential_value_fn=self.potential_value_fn,
potential_gradient_fn=self.potential_gradient_fn,
**kwargs,
)
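
A short sketch of instantiating the network and creating its train state; the dimensions, context bonds and optimizer below are illustrative assumptions, not values from the PR:

import jax
import jax.numpy as jnp
import optax

from ott.neural.networks.conditional_perturbation_network import (
    ConditionalPerturbationNetwork,
)

model = ConditionalPerturbationNetwork(
    dim_hidden=(128, 128),
    dim_data=5,
    dim_cond=15,
    dim_cond_map=(32, 32),
    context_entity_bonds=((0, 10), (10, 15)),
    num_contexts=2,
)
state = model.create_train_state(
    rng=jax.random.PRNGKey(0),
    optimizer=optax.adam(1e-3),
    dim_data=5,
    dim_cond=15,
)
x = jnp.ones((4, 5))   # batch of inputs
c = jnp.ones((4, 15))  # concatenated context, split per context_entity_bonds
y = state.apply_fn({"params": state.params}, x=x, c=c)  # shape (4, 5)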