Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vmap False by default #3146

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
86 changes: 57 additions & 29 deletions src/scvi/external/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def compute_local_statistics(
adata: AnnData | None = None,
indices: npt.ArrayLike | None = None,
batch_size: int | None = None,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
norm: str = "l2",
mc_samples: int = 10,
) -> xr.Dataset:
Expand All @@ -330,7 +330,8 @@ def compute_local_statistics(
batch_size
Batch size to use for computing the local statistics.
use_vmap
Whether to use vmap to compute the local statistics.
Whether to use vmap to compute the local statistics. If ``"auto"``, vmap is used only
when the number of samples is less than 500; an explicit ``True`` or ``False`` is
used as given.
norm
Norm to use for computing the distances.
mc_samples
Expand All @@ -341,6 +342,8 @@ def compute_local_statistics(

from scvi.external.mrvi._utils import _parse_local_statistics_requirements

use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500

if not reductions or len(reductions) == 0:
raise ValueError("At least one reduction must be provided.")

Expand Down Expand Up @@ -418,15 +421,24 @@ def per_sample_inference_fn(pair):

# OK to use stacked rngs here since there is no stochasticity for mean rep.
if reqs.needs_mean_representations:
mean_zs_ = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
use_mean=True,
)
try:
mean_zs_ = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
use_mean=True,
)
except jax.errors.JaxRuntimeError as e:
if use_vmap:
raise RuntimeError(
"JAX ran out of memory. Try setting use_vmap=False."
) from e
else:
raise e

mean_zs = xr.DataArray(
mean_zs_,
np.array(mean_zs_),
dims=["cell_name", "sample", "latent_dim"],
coords={
"cell_name": self.adata.obs_names[indices].values,
Expand All @@ -445,7 +457,7 @@ def per_sample_inference_fn(pair):
) # (n_mc_samples, n_cells, n_samples, n_latent)
sampled_zs_ = sampled_zs_.transpose((1, 0, 2, 3))
sampled_zs = xr.DataArray(
sampled_zs_,
np.array(sampled_zs_),
dims=["cell_name", "mc_sample", "sample", "latent_dim"],
coords={
"cell_name": self.adata.obs_names[indices].values,
Expand All @@ -456,12 +468,12 @@ def per_sample_inference_fn(pair):

if reqs.needs_mean_distances:
mean_dists = self._compute_distances_from_representations(
mean_zs_, indices, norm=norm
mean_zs_, indices, norm=norm, return_numpy=True
)

if reqs.needs_sampled_distances or reqs.needs_normalized_distances:
sampled_dists = self._compute_distances_from_representations(
sampled_zs_, indices, norm=norm
sampled_zs_, indices, norm=norm, return_numpy=True
)

if reqs.needs_normalized_distances:
Expand Down Expand Up @@ -570,6 +582,7 @@ def _compute_distances_from_representations(
reps: jax.typing.ArrayLike,
indices: jax.typing.ArrayLike,
norm: Literal["l2", "l1", "linf"] = "l2",
return_numpy: bool = True,
) -> xr.DataArray:
if norm not in ("l2", "l1", "linf"):
raise ValueError(f"`norm` {norm} not supported")
Expand All @@ -588,6 +601,8 @@ def _compute_distance(rep: jax.typing.ArrayLike):

if reps.ndim == 3:
dists = jax.vmap(_compute_distance)(reps)
if return_numpy:
dists = np.array(dists)
return xr.DataArray(
dists,
dims=["cell_name", "sample_x", "sample_y"],
Expand All @@ -601,6 +616,8 @@ def _compute_distance(rep: jax.typing.ArrayLike):
else:
# Case with sampled representations
dists = jax.vmap(jax.vmap(_compute_distance))(reps)
if return_numpy:
dists = np.array(dists)
return xr.DataArray(
dists,
dims=["cell_name", "mc_sample", "sample_x", "sample_y"],
Expand All @@ -619,7 +636,7 @@ def get_local_sample_representation(
indices: npt.ArrayLike | None = None,
batch_size: int = 256,
use_mean: bool = True,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
) -> xr.DataArray:
"""Compute the local sample representation of the cells in the ``adata`` object.

Expand Down Expand Up @@ -660,7 +677,7 @@ def get_local_sample_distances(
batch_size: int = 256,
use_mean: bool = True,
normalize_distances: bool = False,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
groupby: list[str] | str | None = None,
keep_cell: bool = True,
norm: str = "l2",
Expand Down Expand Up @@ -698,6 +715,8 @@ def get_local_sample_distances(
Number of Monte Carlo samples to use for computing the local sample distances. Only
relevant if ``use_mean=False``.
"""
use_vmap = "auto" if use_vmap == "auto" else use_vmap

input = "mean_distances" if use_mean else "sampled_distances"
if normalize_distances:
if use_mean:
Expand Down Expand Up @@ -1053,7 +1072,7 @@ def differential_expression(
sample_cov_keys: list[str] | None = None,
sample_subset: list[str] | None = None,
batch_size: int = 128,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
normalize_design_matrix: bool = True,
add_batch_specific_offsets: bool = False,
mc_samples: int = 100,
Expand Down Expand Up @@ -1142,6 +1161,8 @@ def differential_expression(

from scipy.stats import false_discovery_control

use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500

if sample_cov_keys is None:
# Hack: kept as kwarg to maintain order of arguments.
raise ValueError("Must assign `sample_cov_keys`")
Expand Down Expand Up @@ -1371,19 +1392,26 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
Amat = jax.device_put(Amat, self.device)
prefactor = jax.device_put(prefactor, self.device)

res = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
Amat=Amat,
prefactor=prefactor,
n_samples_per_cell=n_samples_per_cell,
admissible_samples_mat=admissible_samples_mat,
use_mean=False,
rngs_de=rngs_de,
mc_samples=mc_samples,
)
try:
res = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
Amat=Amat,
prefactor=prefactor,
n_samples_per_cell=n_samples_per_cell,
admissible_samples_mat=admissible_samples_mat,
use_mean=False,
rngs_de=rngs_de,
mc_samples=mc_samples,
)
except jax.errors.JaxRuntimeError as e:
if use_vmap:
raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from e
else:
raise e

beta.append(np.array(res["beta"]))
effect_size.append(np.array(res["effect_size"]))
pvalue.append(np.array(res["pvalue"]))
Expand Down