From bde9a92a920d3824c713cab7020856004b8018e9 Mon Sep 17 00:00:00 2001
From: Lukas Heumos
Date: Thu, 20 Feb 2025 20:38:36 +0100
Subject: [PATCH] Polish

Signed-off-by: Lukas Heumos
---
 pertpy/preprocessing/_guide_rna.py         |  19 ++--
 pertpy/preprocessing/_guide_rna_mixture.py | 122 ++++++++++++++-------
 2 files changed, 91 insertions(+), 50 deletions(-)

diff --git a/pertpy/preprocessing/_guide_rna.py b/pertpy/preprocessing/_guide_rna.py
index 38f9917f..f69f2d9c 100644
--- a/pertpy/preprocessing/_guide_rna.py
+++ b/pertpy/preprocessing/_guide_rna.py
@@ -13,7 +13,7 @@
 from scipy.sparse import issparse
 
 from pertpy._doc import _doc_params, doc_common_plot_args
-from pertpy.preprocessing._guide_rna_mixture import Poisson_Gauss_Mixture
+from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
 
 if TYPE_CHECKING:
     from anndata import AnnData
@@ -37,12 +37,12 @@ def assign_by_threshold(
         This function expects unnormalized data as input.
 
         Args:
-            adata: Annotated data matrix containing gRNA values
+            adata: AnnData object containing gRNA values.
             assignment_threshold: The count threshold that is required for an assignment to be viable.
             layer: Key to the layer containing raw count values of the gRNAs.
                    adata.X is used if layer is None. Expects count data.
             output_layer: Assigned guide will be saved on adata.layers[output_key].
-            only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
+            only_return_results: Whether to return the result as an :class:`np.ndarray` instead of modifying the input AnnData.
 
         Examples:
             Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
@@ -80,13 +80,13 @@ def assign_to_max_guide(
         This function expects unnormalized data as input.
 
         Args:
-            adata: Annotated data matrix containing gRNA values
+            adata: AnnData object containing gRNA values.
             assignment_threshold: The count threshold that is required for an assignment to be viable.
             layer: Key to the layer containing raw count values of the gRNAs.
                    adata.X is used if layer is None. Expects count data.
             output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
             no_grna_assigned_key: The key to return if no gRNA is expressed enough.
-            only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
+            only_return_results: Whether to return the result as an np.ndarray instead of modifying the input AnnData.
 
         Examples:
             Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
@@ -130,15 +130,15 @@ def assign_mixture_model(
         """Assigns gRNAs to cells using a mixture model.
 
         Args:
-            adata: Annotated data matrix containing gRNA values
-            model: The model to use for the mixture model. Currently only `Poisson_Gauss_Mixture` is supported.
+            adata: AnnData object containing gRNA values.
+            model: The model to use for the mixture model. Currently only `poisson_gauss_mixture` is supported.
             output_key: Assigned guide will be saved on adata.obs[output_key].
             no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
             max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
             multiple_grna_assigned_key: The key to return if multiple gRNAs are assigned to a cell.
             multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
-            only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
-            show_progress: If True, shows progress bar.
+            only_return_results: Whether to return the result as an np.ndarray instead of modifying the input AnnData.
+            show_progress: Whether to show a progress bar.
             mixture_model_kwargs: Are passed to the mixture model.
 
         Examples:
             >>> import pertpy as pt
             >>> mdata = pt.dt.papalexi_2021()
             >>> gdo = mdata.mod["gdo"]
             >>> ga = pt.pp.GuideAssignment()
             >>> ga.assign_mixture_model(gdo)
         """
         if model == "poisson_gauss_mixture":
-            mixture_model = Poisson_Gauss_Mixture(**mixture_model_kwargs)
+            mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
         else:
             raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")
 
@@ -196,6 +196,7 @@ def assign_mixture_model(
             return series.values
 
         adata.obs[assigned_guides_key] = series.values
+        return None
 
 
 @_doc_params(common_plot_args=doc_common_plot_args)
diff --git a/pertpy/preprocessing/_guide_rna_mixture.py b/pertpy/preprocessing/_guide_rna_mixture.py
index ec9e7bf9..46df3c7e 100644
--- a/pertpy/preprocessing/_guide_rna_mixture.py
+++ b/pertpy/preprocessing/_guide_rna_mixture.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from collections.abc import Mapping
 
 import jax
 import jax.numpy as jnp
@@ -10,25 +11,19 @@
 from jax import random
 from numpyro.infer import MCMC, NUTS
 
+ParamsDict = Mapping[str, jnp.ndarray]
+
 
 class MixtureModel(ABC):
-    """Template class for 2-component mixture models for guide assignment.
-
-    It handles:
-    - Fitting the model to the data
-    - Running the model on the data and assigning each data point to a component
-
-    The user needs to implement the following methods:
-    - initialize: Initialize the model parameters
-    - log_likelihood: Calculate the log-likelihood of the data under the model
-
-    This class has the following parameters:
-    - num_warmup: Number of warmup or "burn-in" steps in MCMC
-    - num_samples: Number of samples in MCMC. Recommended to be at least 100
-    - fraction_positive_expected: Expected fraction of gRNA positive data points
-    - poisson_rate_prior: Prior for the Poisson rate of the negative component
-    - gaussian_mean_prior: Prior for the Gaussian mean of the positive component
-    - gaussian_std_prior: Prior for the Gaussian standard deviation of the positive component
+    """Abstract base class for 2-component mixture models.
+
+    Args:
+        num_warmup: Number of warmup steps for MCMC sampling.
+        num_samples: Number of samples to draw after warmup.
+        fraction_positive_expected: Prior belief about the fraction of gRNA-positive data points.
+        poisson_rate_prior: Rate parameter for the exponential prior on the Poisson component.
+        gaussian_mean_prior: Mean and standard deviation for the Gaussian prior on the positive component mean.
+        gaussian_std_prior: Scale parameter for the half-normal prior on the positive component std.
     """
 
     def __init__(
@@ -37,9 +32,9 @@ def __init__(
         num_samples: int = 100,
         fraction_positive_expected: float = 0.15,
         poisson_rate_prior: float = 0.2,
-        gaussian_mean_prior: tuple[int, int] = (3, 2),
+        gaussian_mean_prior: tuple[float, float] = (3, 2),
         gaussian_std_prior: float = 1,
-    ):
+    ) -> None:
         self.num_warmup = num_warmup
         self.num_samples = num_samples
         self.fraction_positive_expected = fraction_positive_expected
@@ -48,44 +43,80 @@ def __init__(
         self.gaussian_std_prior = gaussian_std_prior
 
     @abstractmethod
-    def initialize_params(self) -> dict:
+    def initialize_params(self) -> ParamsDict:
+        """Initialize model parameters via sampling from priors.
+
+        Returns:
+            Dictionary of sampled parameter values.
+        """
         pass
 
     @abstractmethod
-    def log_likelihood(self, data: jnp.ndarray, params: dict) -> jnp.ndarray:
+    def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
+        """Calculate the log likelihood of the data under current parameters.
+
+        Args:
+            data: Input data array.
+            params: Current parameter values.
+
+        Returns:
+            Log likelihood values for each datapoint.
+        """
         pass
 
-    def fit_model(self, data: jnp.ndarray, seed: int = 0):
+    def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
+        """Fit the mixture model using MCMC.
+
+        Args:
+            data: Input data to fit.
+            seed: Random seed for reproducibility.
+
+        Returns:
+            Fitted MCMC object containing samples.
+        """
         nuts_kernel = NUTS(self.mixture_model)
         mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
         mcmc.run(random.PRNGKey(seed), data=data)
         return mcmc
 
     def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
-        # Runs MCMS on the model and returns the assignments of the data points
+        """Run model fitting and assign components.
+
+        Args:
+            data: Input data array.
+            seed: Random seed.
+
+        Returns:
+            Array of "Positive"/"Negative" assignments for each datapoint.
+        """
         self.mcmc = self.fit_model(data, seed)
         self.samples = self.mcmc.get_samples()
         self.assignments = self.assignment(self.samples, data)
         return self.assignments
 
     def mixture_model(self, data: jnp.ndarray) -> None:
-        # Note: numpyro does not natively support discrete latent variables.
-        # Hence here we manually marginalize out the discrete latent variable,
-        # which requires us to use a log-likelihood formulation.
+        """Define the mixture model structure for NumPyro.
+
+        NumPyro does not natively support discrete latent variables, so the
+        discrete component indicator is marginalized out via a log-likelihood
+        formulation.
+
+        Args:
+            data: Input data array.
+        """
         params = self.initialize_params()
 
         with numpyro.plate("data", data.shape[0]):
             log_likelihoods = self.log_likelihood(data, params)
-
-            # Use logsumexp for numerical stability
             log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)
-
-            # Sample the data from the mixture distribution
            numpyro.sample("obs", dist.Normal(log_mixture_likelihood, 1.0), obs=data)
 
-    def assignment(self, samples: dict, data: jnp.ndarray) -> np.ndarray:
-        # Assigns each data point to a component based on the highest log-likelihood
+    def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
+        """Assign data points to mixture components.
+
+        Args:
+            samples: MCMC samples of parameters.
+            data: Input data array.
+
+        Returns:
+            Array of component assignments.
+        """
         params = {key: samples[key].mean(axis=0) for key in samples.keys()}
         self.params = params
@@ -96,25 +127,29 @@ def assignment(self, samples: dict, data: jnp.ndarray) -> np.ndarray:
 
         return np.array(assignments)
 
 
-class Poisson_Gauss_Mixture(MixtureModel):
-    def log_likelihood(self, data: np.ndarray, params: dict) -> jnp.ndarray:
-        # Defines how to calculate the log-likelihood of the data under the model
+class PoissonGaussMixture(MixtureModel):
+    """Mixture model combining Poisson and Gaussian distributions."""
+
+    def log_likelihood(self, data: np.ndarray, params: ParamsDict) -> jnp.ndarray:
+        """Calculate component-wise log likelihoods.
+
+        Args:
+            data: Input data array.
+            params: Current parameter values.
+
+        Returns:
+            Log likelihood values for each component.
+        """
         poisson_rate = params["poisson_rate"]
         gaussian_mean = params["gaussian_mean"]
         gaussian_std = params["gaussian_std"]
         mix_probs = params["mix_probs"]
 
-        # We penalize the model for positioning the Poisson component to the right of the Gaussian component
-        # by imposing a soft constraint to penalize the Poisson rate being larger than the Gaussian mean
-        # Heuristic regularization term to prevent flipping of the components
+        # Heuristic separation penalty: softly constrain the Poisson rate to lie
+        # below the Gaussian mean so the two components cannot flip during sampling.
         numpyro.factor("separation_penalty", +10 * jnp.heaviside(-poisson_rate + gaussian_mean, 0))
 
         log_likelihoods = jnp.stack(
             [
-                # Poisson component
                 jnp.log(mix_probs[0]) + dist.Poisson(poisson_rate).log_prob(data),
-                # Gaussian component
                 jnp.log(mix_probs[1]) + dist.Normal(gaussian_mean, gaussian_std).log_prob(data),
             ],
             axis=-1,
@@ -122,7 +157,12 @@ def log_likelihood(self, data: np.ndarray, params: dict) -> jnp.ndarray:
 
         return log_likelihoods
 
-    def initialize_params(self) -> dict:
+    def initialize_params(self) -> ParamsDict:
+        """Initialize model parameters via prior sampling.
+
+        Returns:
+            Dictionary of sampled parameter values.
+        """
         params = {}
         params["poisson_rate"] = numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior))
         params["gaussian_mean"] = numpyro.sample("gaussian_mean", dist.Normal(*self.gaussian_mean_prior))
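
Usage sketch, not part of the patch: a minimal example of exercising the renamed PoissonGaussMixture directly through run_model, mirroring the docstrings added above. The toy counts and MCMC settings are invented for illustration, and the module-internal import path is taken from this diff; the documented entry point remains GuideAssignment.assign_mixture_model, as shown in the Examples section of the docstring.

    import numpy as np
    import jax.numpy as jnp

    from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture

    # Toy gRNA counts: background cells drawn from a low-rate Poisson plus a
    # small transduced population with clearly higher counts (Gaussian component).
    rng = np.random.default_rng(0)
    counts = jnp.asarray(
        np.concatenate([rng.poisson(0.5, size=180), rng.normal(15.0, 2.0, size=20)]).astype(float)
    )

    # Explicit, modest MCMC settings keep the sketch fast.
    model = PoissonGaussMixture(num_warmup=50, num_samples=100)
    assignments = model.run_model(counts, seed=0)  # "Positive"/"Negative" per data point
    print(np.unique(assignments, return_counts=True))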