From 060eedc12cc7048c554d477dfb5f62161e2a8e3b Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Tue, 14 Jan 2025 02:24:12 -0800 Subject: [PATCH 1/3] vmap False by default --- src/scvi/external/mrvi/_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index e7e7b7c2fa..d6662adf57 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -309,7 +309,7 @@ def compute_local_statistics( adata: AnnData | None = None, indices: npt.ArrayLike | None = None, batch_size: int | None = None, - use_vmap: bool = True, + use_vmap: bool = False, norm: str = "l2", mc_samples: int = 10, ) -> xr.Dataset: @@ -619,7 +619,7 @@ def get_local_sample_representation( indices: npt.ArrayLike | None = None, batch_size: int = 256, use_mean: bool = True, - use_vmap: bool = True, + use_vmap: bool = False, ) -> xr.DataArray: """Compute the local sample representation of the cells in the ``adata`` object. 
@@ -660,7 +660,7 @@ def get_local_sample_distances( batch_size: int = 256, use_mean: bool = True, normalize_distances: bool = False, - use_vmap: bool = True, + use_vmap: bool = False, groupby: list[str] | str | None = None, keep_cell: bool = True, norm: str = "l2", @@ -1053,7 +1053,7 @@ def differential_expression( sample_cov_keys: list[str] | None = None, sample_subset: list[str] | None = None, batch_size: int = 128, - use_vmap: bool = True, + use_vmap: bool = False, normalize_design_matrix: bool = True, add_batch_specific_offsets: bool = False, mc_samples: int = 100, From d685dbb8336f91d9629346c37b9f7acc84f4aef6 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Tue, 14 Jan 2025 10:13:01 -0800 Subject: [PATCH 2/3] more informative tracebacks + auto vmap --- docs/tutorials/notebooks | 2 +- src/scvi/external/mrvi/_model.py | 73 +++++++++++++++++++++----------- 2 files changed, 49 insertions(+), 26 deletions(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index b5890651da..c2fc6d100e 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit b5890651da3ad734cc12e7d54b39395aa6e9137d +Subproject commit c2fc6d100ecc28e716f9ffc96bc68af48a7733b4 diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index d6662adf57..68f4ffcdf3 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -309,7 +309,7 @@ def compute_local_statistics( adata: AnnData | None = None, indices: npt.ArrayLike | None = None, batch_size: int | None = None, - use_vmap: bool = False, + use_vmap: Literal["auto", True, False] = "auto", norm: str = "l2", mc_samples: int = 10, ) -> xr.Dataset: @@ -330,7 +330,8 @@ def compute_local_statistics( batch_size Batch size to use for computing the local statistics. use_vmap - Whether to use vmap to compute the local statistics. + Whether to use vmap to compute the local statistics. 
If "auto", vmap will be used if + the number of samples is less than 500. norm Norm to use for computing the distances. mc_samples @@ -341,6 +342,8 @@ def compute_local_statistics( from scvi.external.mrvi._utils import _parse_local_statistics_requirements + use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500 + if not reductions or len(reductions) == 0: raise ValueError("At least one reduction must be provided.") @@ -418,13 +421,22 @@ def per_sample_inference_fn(pair): # OK to use stacked rngs here since there is no stochasticity for mean rep. if reqs.needs_mean_representations: - mean_zs_ = mapped_inference_fn( - stacked_rngs=stacked_rngs, - x=jnp.array(inf_inputs["x"]), - sample_index=jnp.array(inf_inputs["sample_index"]), - cf_sample=jnp.array(cf_sample), - use_mean=True, - ) + try: + mean_zs_ = mapped_inference_fn( + stacked_rngs=stacked_rngs, + x=jnp.array(inf_inputs["x"]), + sample_index=jnp.array(inf_inputs["sample_index"]), + cf_sample=jnp.array(cf_sample), + use_mean=True, + ) + except jax.errors.JaxRuntimeError as e: + if use_vmap: + raise RuntimeError( + "JAX ran out of memory. Try setting use_vmap=False." + ) from e + else: + raise e + mean_zs = xr.DataArray( mean_zs_, dims=["cell_name", "sample", "latent_dim"], @@ -619,7 +631,7 @@ def get_local_sample_representation( indices: npt.ArrayLike | None = None, batch_size: int = 256, use_mean: bool = True, - use_vmap: bool = False, + use_vmap: Literal["auto", True, False] = "auto", ) -> xr.DataArray: """Compute the local sample representation of the cells in the ``adata`` object. 
@@ -660,7 +672,7 @@ def get_local_sample_distances( batch_size: int = 256, use_mean: bool = True, normalize_distances: bool = False, - use_vmap: bool = False, + use_vmap: Literal["auto", True, False] = "auto", groupby: list[str] | str | None = None, keep_cell: bool = True, norm: str = "l2", @@ -698,6 +710,8 @@ get_local_sample_distances( Number of Monte Carlo samples to use for computing the local sample distances. Only relevant if ``use_mean=False``. """ + use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500 + input = "mean_distances" if use_mean else "sampled_distances" if normalize_distances: if use_mean: @@ -1053,7 +1067,7 @@ def differential_expression( sample_cov_keys: list[str] | None = None, sample_subset: list[str] | None = None, batch_size: int = 128, - use_vmap: bool = False, + use_vmap: Literal["auto", True, False] = "auto", normalize_design_matrix: bool = True, add_batch_specific_offsets: bool = False, mc_samples: int = 100, @@ -1142,6 +1156,8 @@ def differential_expression( from scipy.stats import false_discovery_control + use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500 + if sample_cov_keys is None: # Hack: kept as kwarg to maintain order of arguments. 
raise ValueError("Must assign `sample_cov_keys`") @@ -1371,19 +1387,26 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): Amat = jax.device_put(Amat, self.device) prefactor = jax.device_put(prefactor, self.device) - res = mapped_inference_fn( - stacked_rngs=stacked_rngs, - x=jnp.array(inf_inputs["x"]), - sample_index=jnp.array(inf_inputs["sample_index"]), - cf_sample=jnp.array(cf_sample), - Amat=Amat, - prefactor=prefactor, - n_samples_per_cell=n_samples_per_cell, - admissible_samples_mat=admissible_samples_mat, - use_mean=False, - rngs_de=rngs_de, - mc_samples=mc_samples, - ) + try: + res = mapped_inference_fn( + stacked_rngs=stacked_rngs, + x=jnp.array(inf_inputs["x"]), + sample_index=jnp.array(inf_inputs["sample_index"]), + cf_sample=jnp.array(cf_sample), + Amat=Amat, + prefactor=prefactor, + n_samples_per_cell=n_samples_per_cell, + admissible_samples_mat=admissible_samples_mat, + use_mean=False, + rngs_de=rngs_de, + mc_samples=mc_samples, + ) + except jax.errors.JaxRuntimeError as e: + if use_vmap: + raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from e + else: + raise e + beta.append(np.array(res["beta"])) effect_size.append(np.array(res["effect_size"])) pvalue.append(np.array(res["pvalue"])) From e3e5ed61a203d5f2a10389816524097a2df61f12 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Mon, 27 Jan 2025 15:52:25 -0800 Subject: [PATCH 3/3] store dmats as numpy arrays instead of jax. 
--- src/scvi/external/mrvi/_model.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index 68f4ffcdf3..d8333ceb32 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -438,7 +438,7 @@ def per_sample_inference_fn(pair): raise e mean_zs = xr.DataArray( - mean_zs_, + np.array(mean_zs_), dims=["cell_name", "sample", "latent_dim"], coords={ "cell_name": self.adata.obs_names[indices].values, @@ -457,7 +457,7 @@ def per_sample_inference_fn(pair): ) # (n_mc_samples, n_cells, n_samples, n_latent) sampled_zs_ = sampled_zs_.transpose((1, 0, 2, 3)) sampled_zs = xr.DataArray( - sampled_zs_, + np.array(sampled_zs_), dims=["cell_name", "mc_sample", "sample", "latent_dim"], coords={ "cell_name": self.adata.obs_names[indices].values, @@ -468,12 +468,12 @@ def per_sample_inference_fn(pair): if reqs.needs_mean_distances: mean_dists = self._compute_distances_from_representations( - mean_zs_, indices, norm=norm + mean_zs_, indices, norm=norm, return_numpy=True ) if reqs.needs_sampled_distances or reqs.needs_normalized_distances: sampled_dists = self._compute_distances_from_representations( - sampled_zs_, indices, norm=norm + sampled_zs_, indices, norm=norm, return_numpy=True ) if reqs.needs_normalized_distances: @@ -582,6 +582,7 @@ def _compute_distances_from_representations( reps: jax.typing.ArrayLike, indices: jax.typing.ArrayLike, norm: Literal["l2", "l1", "linf"] = "l2", + return_numpy: bool = True, ) -> xr.DataArray: if norm not in ("l2", "l1", "linf"): raise ValueError(f"`norm` {norm} not supported") @@ -600,6 +601,8 @@ def _compute_distance(rep: jax.typing.ArrayLike): if reps.ndim == 3: dists = jax.vmap(_compute_distance)(reps) + if return_numpy: + dists = np.array(dists) return xr.DataArray( dists, dims=["cell_name", "sample_x", "sample_y"], @@ -613,6 +616,8 @@ def _compute_distance(rep: jax.typing.ArrayLike): else: # Case with sampled 
representations dists = jax.vmap(jax.vmap(_compute_distance))(reps) + if return_numpy: + dists = np.array(dists) return xr.DataArray( dists, dims=["cell_name", "mc_sample", "sample_x", "sample_y"],