Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 26, 2024
1 parent 85ac491 commit 25844d7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 30 deletions.
1 change: 0 additions & 1 deletion src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from functools import partial
from typing import Any, Literal, Optional, Tuple, Union


import jax
import jax.numpy as jnp
import scipy.sparse as sp
Expand Down
59 changes: 31 additions & 28 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
problem_kwargs |= {"alpha"}
return geom_kwargs | problem_kwargs, {"epsilon"}


class GENOTLinSolver(OTSolver[OTTOutput]):

def __init__(self, **kwargs: Any) -> None:
Expand Down Expand Up @@ -526,18 +527,11 @@ def _prepare( # type: ignore[override]
conditions=distributions[source_key].conditions,
)
source_loader = DataLoader(
source_ds,
batch_size=batch_size,
sampler=RandomSampler(source_ds, replacement=True)
)
target_ds = OTDataset(
lin=distributions[target_key].xy,
conditions=distributions[target_key].conditions
source_ds, batch_size=batch_size, sampler=RandomSampler(source_ds, replacement=True)
)
target_ds = OTDataset(lin=distributions[target_key].xy, conditions=distributions[target_key].conditions)
target_loader = DataLoader(
target_ds,
batch_size=batch_size,
sampler=RandomSampler(target_ds, replacement=True)
target_ds, batch_size=batch_size, sampler=RandomSampler(target_ds, replacement=True)
)
train_loaders.append((source_loader, target_loader))
validate_loaders.append((source_loader, target_loader))
Expand Down Expand Up @@ -575,8 +569,7 @@ def _prepare( # type: ignore[override]
sampler=RandomSampler(source_ds_train, replacement=True),
)
target_ds_train = OTDataset(
lin=target_split_data.data_train,
conditions=target_split_data.conditions_train
lin=target_split_data.data_train, conditions=target_split_data.conditions_train
)
target_train_loader = DataLoader(
target_ds_train,
Expand All @@ -590,16 +583,15 @@ def _prepare( # type: ignore[override]
source_validate_loader = DataLoader(
source_ds_validate,
batch_size=batch_size,
sampler=RandomSampler(source_ds_validate, replacement=True)
sampler=RandomSampler(source_ds_validate, replacement=True),
)
target_ds_validate = OTDataset(
lin=target_split_data.data_valid,
conditions=target_split_data.conditions_valid
lin=target_split_data.data_valid, conditions=target_split_data.conditions_valid
)
target_validate_loader = DataLoader(
target_ds_validate,
batch_size=batch_size,
sampler=RandomSampler(target_ds_validate, replacement=True)
sampler=RandomSampler(target_ds_validate, replacement=True),
)
train_loaders.append((source_train_loader, target_train_loader))
validate_loaders.append((source_validate_loader, target_validate_loader))
Expand All @@ -612,19 +604,27 @@ def _prepare( # type: ignore[override]
latent_embed_dim=self._neural_kwargs.pop("latent_embed_dim", 5),
)
ot_solver = sinkhorn.Sinkhorn(**self._neural_kwargs.pop("valid_sinkhorn_kwargs", {}))
tau_a=self._neural_kwargs.pop("tau_a", 1)
tau_b=self._neural_kwargs.pop("tau_b", 1)
tau_a = self._neural_kwargs.pop("tau_a", 1)
tau_b = self._neural_kwargs.pop("tau_b", 1)
rescaling_a = self._neural_kwargs.pop("rescaling_a", RescalingMLP(hidden_dim=4, condition_dim=condition_dim))
rescaling_b = self._neural_kwargs.pop("rescaling_b", RescalingMLP(hidden_dim=4, condition_dim=condition_dim))
seed = self._neural_kwargs.pop("seed", 0)
rng = jax.random.PRNGKey(seed)
ot_matcher = self._neural_kwargs.pop("ot_matcher", OTMatcherLinear(
ot_solver, tau_a=tau_a, tau_b=tau_b
))
ot_matcher = self._neural_kwargs.pop("ot_matcher", OTMatcherLinear(ot_solver, tau_a=tau_a, tau_b=tau_b))
time_sampler = self._neural_kwargs.pop("time_sampler", uniform_sampler)
unbalancedness_handler = self._neural_kwargs.pop("unbalancedness_handler", UnbalancednessHandler(
rng=rng, source_dim=source_dim, target_dim=target_dim, cond_dim=condition_dim, tau_a=tau_a, tau_b=tau_b, rescaling_a=rescaling_a, rescaling_b=rescaling_b,
))
unbalancedness_handler = self._neural_kwargs.pop(
"unbalancedness_handler",
UnbalancednessHandler(
rng=rng,
source_dim=source_dim,
target_dim=target_dim,
cond_dim=condition_dim,
tau_a=tau_a,
tau_b=tau_b,
rescaling_a=rescaling_a,
rescaling_b=rescaling_b,
),
)
optimizer = self._neural_kwargs.pop("optimizer", optax.adam(learning_rate=1e-3))
self._solver = GENOTLin(
velocity_field=neural_vf,
Expand All @@ -636,12 +636,15 @@ def _prepare( # type: ignore[override]
optimizer=optimizer,
time_sampler=time_sampler,
rng=rng,
matcher_latent_to_data=OTMatcherLinear(sinkhorn.Sinkhorn()) if self._neural_kwargs.pop("solver_latent_to_data", True) else None,
matcher_latent_to_data=(
OTMatcherLinear(sinkhorn.Sinkhorn()) if self._neural_kwargs.pop("solver_latent_to_data", True) else None
),
k_samples_per_x=self._neural_kwargs.pop("k_samples_per_x", 1),
**self._neural_kwargs
**self._neural_kwargs,
)
return ConditionalOTDataset(datasets=train_loaders, seed=seed), ConditionalOTDataset(
datasets=validate_loaders, seed=seed
)
return ConditionalOTDataset(datasets=train_loaders, seed=seed), ConditionalOTDataset(datasets=validate_loaders, seed=seed)


@staticmethod
def _assert2d(arr: ArrayLike, *, allow_reshape: bool = True) -> jnp.ndarray:
Expand Down
1 change: 0 additions & 1 deletion src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from types import MappingProxyType
from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Type, Union


from anndata import AnnData

from moscot import _constants
Expand Down

0 comments on commit 25844d7

Please sign in to comment.