Skip to content

Commit

Permalink
Polish
Browse files Browse the repository at this point in the history
Signed-off-by: Lukas Heumos <[email protected]>
  • Loading branch information
Zethson committed Feb 20, 2025
1 parent 09af7d9 commit bde9a92
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 50 deletions.
19 changes: 10 additions & 9 deletions pertpy/preprocessing/_guide_rna.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.sparse import issparse

from pertpy._doc import _doc_params, doc_common_plot_args
from pertpy.preprocessing._guide_rna_mixture import Poisson_Gauss_Mixture
from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture

if TYPE_CHECKING:
from anndata import AnnData
Expand All @@ -37,12 +37,12 @@ def assign_by_threshold(
This function expects unnormalized data as input.
Args:
adata: Annotated data matrix containing gRNA values
adata: AnnData object containing gRNA values.
assignment_threshold: The count threshold that is required for an assignment to be viable.
layer: Key to the layer containing raw count values of the gRNAs.
adata.X is used if layer is None. Expects count data.
output_layer: Assigned guide will be saved on adata.layers[output_layer].
only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
only_return_results: Whether to leave the input AnnData unmodified and return the result as an :class:`np.ndarray`.
Examples:
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
Expand Down Expand Up @@ -80,13 +80,13 @@ def assign_to_max_guide(
This function expects unnormalized data as input.
Args:
adata: Annotated data matrix containing gRNA values
adata: AnnData object containing gRNA values.
assignment_threshold: The count threshold that is required for an assignment to be viable.
layer: Key to the layer containing raw count values of the gRNAs.
adata.X is used if layer is None. Expects count data.
output_key: Assigned guide will be saved on adata.obs[output_key]. Defaults to `assigned_guide`.
no_grna_assigned_key: The key to return if no gRNA is expressed enough.
only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
only_return_results: Whether to leave the input AnnData unmodified and return the result as an np.ndarray.
Examples:
Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
Expand Down Expand Up @@ -130,15 +130,15 @@ def assign_mixture_model(
"""Assigns gRNAs to cells using a mixture model.
Args:
adata: Annotated data matrix containing gRNA values
adata: AnnData object containing gRNA values.
model: The model to use for the mixture model. Currently only `poisson_gauss_mixture` is supported.
output_key: Assigned guide will be saved on adata.obs[output_key].
no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
multiple_grna_assigned_key: The key to return if multiple gRNAs are assigned to a cell.
multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
show_progress: If True, shows progress bar.
only_return_results: Whether to leave the input AnnData unmodified and return the result as an np.ndarray.
show_progress: Whether to show a progress bar.
mixture_model_kwargs: Are passed to the mixture model.
Examples:
Expand All @@ -149,7 +149,7 @@ def assign_mixture_model(
>>> ga.assign_mixture_model(gdo)
"""
if model == "poisson_gauss_mixture":
mixture_model = Poisson_Gauss_Mixture(**mixture_model_kwargs)
mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
else:
raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")

Expand Down Expand Up @@ -196,6 +196,7 @@ def assign_mixture_model(
return series.values

adata.obs[assigned_guides_key] = series.values

return None

@_doc_params(common_plot_args=doc_common_plot_args)
Expand Down
122 changes: 81 additions & 41 deletions pertpy/preprocessing/_guide_rna_mixture.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping

import jax
import jax.numpy as jnp
Expand All @@ -10,25 +11,19 @@
from jax import random
from numpyro.infer import MCMC, NUTS

ParamsDict = Mapping[str, jnp.ndarray]


class MixtureModel(ABC):
"""Template class for 2-component mixture models for guide assignment.
It handles:
- Fitting the model to the data
- Running the model on the data and assigning each data point to a component
The user needs to implement the following methods:
- initialize: Initialize the model parameters
- log_likelihood: Calculate the log-likelihood of the data under the model
This class has the following parameters:
- num_warmup: Number of warmup or "burn-in" steps in MCMC
- num_samples: Number of samples in MCMC. Recommended to be at least 100
- fraction_positive_expected: Expected fraction of gRNA positive data points
- poisson_rate_prior: Prior for the Poisson rate of the negative component
- gaussian_mean_prior: Prior for the Gaussian mean of the positive component
- gaussian_std_prior: Prior for the Gaussian standard deviation of the positive component
"""Abstract base class for 2-component mixture models.
Args:
num_warmup: Number of warmup steps for MCMC sampling.
num_samples: Number of samples to draw after warmup.
fraction_positive_expected: Prior belief about fraction of positive components.
poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
"""

def __init__(
Expand All @@ -37,9 +32,9 @@ def __init__(
num_samples: int = 100,
fraction_positive_expected: float = 0.15,
poisson_rate_prior: float = 0.2,
gaussian_mean_prior: tuple[int, int] = (3, 2),
gaussian_mean_prior: tuple[float, float] = (3, 2),
gaussian_std_prior: float = 1,
):
) -> None:
self.num_warmup = num_warmup
self.num_samples = num_samples
self.fraction_positive_expected = fraction_positive_expected
Expand All @@ -48,44 +43,80 @@ def __init__(
self.gaussian_std_prior = gaussian_std_prior

@abstractmethod
def initialize_params(self) -> dict:
def initialize_params(self) -> ParamsDict:
"""Initialize model parameters via sampling from priors.
Returns:
Dictionary of sampled parameter values.
"""
pass

@abstractmethod
def log_likelihood(self, data: jnp.ndarray, params: dict) -> jnp.ndarray:
def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
"""Calculate log likelihood of data under current parameters.
Args:
data: Input data array.
params: Current parameter values.
Returns:
Log likelihood values for each datapoint.
"""
pass

def fit_model(self, data: jnp.ndarray, seed: int = 0):
def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
"""Fit the mixture model using MCMC.
Args:
data: Input data to fit.
seed: Random seed for reproducibility.
Returns:
Fitted MCMC object containing samples.
"""
nuts_kernel = NUTS(self.mixture_model)
mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
mcmc.run(random.PRNGKey(seed), data=data)
return mcmc

def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
# Runs MCMC on the model and returns the assignments of the data points
"""Run model fitting and assign components.
Args:
data: Input data array.
seed: Random seed.
Returns:
Array of "Positive"/"Negative" assignments for each datapoint.
"""
self.mcmc = self.fit_model(data, seed)
self.samples = self.mcmc.get_samples()
self.assignments = self.assignment(self.samples, data)
return self.assignments

def mixture_model(self, data: jnp.ndarray) -> None:
# Note: numpyro does not natively support discrete latent variables.
# Hence here we manually marginalize out the discrete latent variable,
# which requires us to use a log-likelihood formulation.
"""Define mixture model structure for NumPyro.
Args:
data: Input data array.
"""
params = self.initialize_params()

with numpyro.plate("data", data.shape[0]):
log_likelihoods = self.log_likelihood(data, params)

# Use logsumexp for numerical stability
log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)

# Sample the data from the mixture distribution
numpyro.sample("obs", dist.Normal(log_mixture_likelihood, 1.0), obs=data)

def assignment(self, samples: dict, data: jnp.ndarray) -> np.ndarray:
# Assigns each data point to a component based on the highest log-likelihood
def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
"""Assign data points to mixture components.
Args:
samples: MCMC samples of parameters.
data: Input data array.
Returns:
Array of component assignments.
"""
params = {key: samples[key].mean(axis=0) for key in samples.keys()}
self.params = params

Expand All @@ -96,33 +127,42 @@ def assignment(self, samples: dict, data: jnp.ndarray) -> np.ndarray:
return np.array(assignments)


class Poisson_Gauss_Mixture(MixtureModel):
def log_likelihood(self, data: np.ndarray, params: dict) -> jnp.ndarray:
# Defines how to calculate the log-likelihood of the data under the model
class PoissonGaussMixture(MixtureModel):
"""Mixture model combining Poisson and Gaussian distributions."""

def log_likelihood(self, data: np.ndarray, params: ParamsDict) -> jnp.ndarray:
"""Calculate component-wise log likelihoods.
Args:
data: Input data array.
params: Current parameter values.
Returns:
Log likelihood values for each component.
"""
poisson_rate = params["poisson_rate"]
gaussian_mean = params["gaussian_mean"]
gaussian_std = params["gaussian_std"]
mix_probs = params["mix_probs"]

# We penalize the model for positioning the Poisson component to the right of the Gaussian component
# by imposing a soft constraint to penalize the Poisson rate being larger than the Gaussian mean
# Heuristic regularization term to prevent flipping of the components
numpyro.factor("separation_penalty", +10 * jnp.heaviside(-poisson_rate + gaussian_mean, 0))

log_likelihoods = jnp.stack(
[
# Poisson component
jnp.log(mix_probs[0]) + dist.Poisson(poisson_rate).log_prob(data),
# Gaussian component
jnp.log(mix_probs[1]) + dist.Normal(gaussian_mean, gaussian_std).log_prob(data),
],
axis=-1,
)

return log_likelihoods

def initialize_params(self) -> dict:
def initialize_params(self) -> ParamsDict:
"""Initialize model parameters via prior sampling.
Returns:
Dictionary of sampled parameter values.
"""
params = {}
params["poisson_rate"] = numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior))
params["gaussian_mean"] = numpyro.sample("gaussian_mean", dist.Normal(*self.gaussian_mean_prior))
Expand Down

0 comments on commit bde9a92

Please sign in to comment.