Skip to content

Commit

Permalink
Merge branch 'main' into fix/mixscape_perturb_signature
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilly-May authored Feb 21, 2025
2 parents 4fafc2e + 28bc4e3 commit c80729f
Show file tree
Hide file tree
Showing 8 changed files with 334 additions and 44 deletions.
6 changes: 3 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ Consider citing [scanpy Genome Biology (2018)] along with the original {doc}`ref

# Indices and tables

- {ref}`genindex`
- {ref}`modindex`
- {ref}`search`
- {ref}`genindex`
- {ref}`modindex`
- {ref}`search`

[scanpy genome biology (2018)]: https://doi.org/10.1186/s13059-017-1382-0
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
104 changes: 97 additions & 7 deletions pertpy/preprocessing/_guide_rna.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
from __future__ import annotations

import uuid
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy
from rich.progress import track
from scipy.sparse import issparse

from pertpy._doc import _doc_params, doc_common_plot_args
from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture

if TYPE_CHECKING:
from anndata import AnnData
from matplotlib.pyplot import Figure


class GuideAssignment:
"""Offers simple guide assigment based on count thresholds."""
"""Assign cells to guide RNAs."""

def assign_by_threshold(
self,
Expand All @@ -33,12 +37,12 @@ def assign_by_threshold(
This function expects unnormalized data as input.
Args:
adata: Annotated data matrix containing gRNA values
adata: AnnData object containing gRNA values.
assignment_threshold: The count threshold that is required for an assignment to be viable.
layer: Key to the layer containing raw count values of the gRNAs.
adata.X is used if layer is None. Expects count data.
output_layer: Assigned guide will be saved on adata.layers[output_key].
only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
only_return_results: Whether to input AnnData is not modified and the result is returned as an :class:`np.ndarray`.
Examples:
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
Expand Down Expand Up @@ -67,7 +71,7 @@ def assign_to_max_guide(
assignment_threshold: float,
layer: str | None = None,
output_key: str = "assigned_guide",
no_grna_assigned_key: str = "NT",
no_grna_assigned_key: str = "Negative",
only_return_results: bool = False,
) -> np.ndarray | None:
"""Simple threshold based max gRNA assignment function.
Expand All @@ -76,13 +80,13 @@ def assign_to_max_guide(
This function expects unnormalized data as input.
Args:
adata: Annotated data matrix containing gRNA values
adata: AnnData object containing gRNA values.
assignment_threshold: The count threshold that is required for an assignment to be viable.
layer: Key to the layer containing raw count values of the gRNAs.
adata.X is used if layer is None. Expects count data.
output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
no_grna_assigned_key: The key to return if no gRNA is expressed enough.
only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
only_return_results: Whether to input AnnData is not modified and the result is returned as an np.ndarray.
Examples:
Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
Expand All @@ -109,6 +113,92 @@ def assign_to_max_guide(

return None

def assign_mixture_model(
    self,
    adata: AnnData,
    model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
    assigned_guides_key: str = "assigned_guide",
    no_grna_assigned_key: str = "negative",
    max_assignments_per_cell: int = 5,
    multiple_grna_assigned_key: str = "multiple",
    multiple_grna_assignment_string: str = "+",
    only_return_results: bool = False,
    uns_key: str = "guide_assignment_params",
    show_progress: bool = False,
    **mixture_model_kwargs,
) -> np.ndarray | None:
    """Assigns gRNAs to cells using a mixture model.

    Every guide (column of ``adata.X``) is handled independently: the mixture model is fitted to the
    log2-transformed non-zero values of that guide, and cells labelled ``"Positive"`` by the model are
    marked as carrying the guide. Cells with a zero value for a guide are never positive for it.

    Args:
        adata: AnnData object containing gRNA values.
        model: The model to use for the mixture model. Currently only `poisson_gauss_mixture` is supported.
        assigned_guides_key: Assigned guide will be saved on adata.obs[assigned_guides_key].
        no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
        max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
        multiple_grna_assigned_key: The key to return if more than `max_assignments_per_cell` gRNAs
            are assigned to a cell.
        multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
        only_return_results: Whether to return the assignments as an np.ndarray instead of writing them
            to adata.obs. Note that the fitted parameters are written to adata.uns[uns_key] either way.
        uns_key: Key of the dictionary in adata.uns that receives the fitted model parameters per gRNA.
        show_progress: Whether to show a progress bar.
        mixture_model_kwargs: Are passed to the mixture model.

    Returns:
        The per-cell assignments as an np.ndarray if `only_return_results` is True, otherwise None.

    Raises:
        ValueError: If `model` is not supported, if adata.uns[uns_key] exists but is not a dictionary,
            or if the data of a fitted gRNA contains negative values.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> gdo = mdata.mod["gdo"]
        >>> ga = pt.pp.GuideAssignment()
        >>> ga.assign_mixture_model(gdo)
    """
    if model == "poisson_gauss_mixture":
        mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
    else:
        raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")

    if uns_key not in adata.uns:
        adata.uns[uns_key] = {}
    elif not isinstance(adata.uns[uns_key], dict):
        raise ValueError(f"adata.uns['{uns_key}'] should be a dictionary. Please remove it or change the key.")

    guide_names = np.asarray(adata.var_names)
    # Positional cell x guide indicator matrix; avoids label-based lookups for every guide.
    is_positive = np.zeros((adata.n_obs, len(guide_names)), dtype=bool)

    guides = track(adata.var_names) if show_progress else adata.var_names
    for guide_idx, gene in enumerate(guides):
        # Extract the guide's column once and densify it; it is reused for the mask and the model input.
        column = adata[:, gene].X
        if issparse(column):
            column = column.toarray()
        counts = np.ravel(np.asarray(column))

        is_nonzero = counts != 0
        if np.count_nonzero(is_nonzero) < 2:
            warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=2)
            continue

        # We are only fitting the model to the non-zero values, the rest is
        # automatically assigned to the negative class
        data = counts[is_nonzero]
        if np.any(data < 0):
            raise ValueError(
                "Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
            )

        # Log2 transform the data so positive population is approximately normal
        assignments = mixture_model.run_model(np.log2(data))
        # `assignments` is aligned with the non-zero cells, so map back to positions in adata.
        is_positive[np.flatnonzero(is_nonzero)[assignments == "Positive"], guide_idx] = True
        adata.uns[uns_key][gene] = mixture_model.params

    # Assign guides to cells
    # Some cells might have multiple guides assigned
    num_guides_assigned = is_positive.sum(axis=1)
    within_limit = (num_guides_assigned != 0) & (num_guides_assigned <= max_assignments_per_cell)

    series = pd.Series(no_grna_assigned_key, index=adata.obs_names)
    if within_limit.any():
        series.loc[within_limit] = [
            multiple_grna_assignment_string.join(guide_names[row]) for row in is_positive[within_limit]
        ]
    series.loc[num_guides_assigned > max_assignments_per_cell] = multiple_grna_assigned_key

    if only_return_results:
        return series.values

    adata.obs[assigned_guides_key] = series.values

    return None

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_heatmap(
self,
Expand Down
179 changes: 179 additions & 0 deletions pertpy/preprocessing/_guide_rna_mixture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

ParamsDict = Mapping[str, jnp.ndarray]


class MixtureModel(ABC):
    """Abstract base class for 2-component mixture models.

    Subclasses provide the priors (:meth:`initialize_params`) and the component-wise
    log likelihoods (:meth:`log_likelihood`); this class provides MCMC fitting and the
    assignment of each datapoint to one of the two components.

    Args:
        num_warmup: Number of warmup steps for MCMC sampling.
        num_samples: Number of samples to draw after warmup.
        fraction_positive_expected: Prior belief about fraction of positive components.
        poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
        gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
        gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
    """

    def __init__(
        self,
        num_warmup: int = 50,
        num_samples: int = 100,
        fraction_positive_expected: float = 0.15,
        poisson_rate_prior: float = 0.2,
        gaussian_mean_prior: tuple[float, float] = (3, 2),
        gaussian_std_prior: float = 1,
    ) -> None:
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.fraction_positive_expected = fraction_positive_expected
        self.poisson_rate_prior = poisson_rate_prior
        self.gaussian_mean_prior = gaussian_mean_prior
        self.gaussian_std_prior = gaussian_std_prior

    @abstractmethod
    def initialize_params(self) -> ParamsDict:
        """Initialize model parameters via sampling from priors.

        Returns:
            Dictionary of sampled parameter values.
        """

    @abstractmethod
    def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Calculate log likelihood of data under current parameters.

        Args:
            data: Input data array.
            params: Current parameter values.

        Returns:
            Log likelihood values for each datapoint, with the mixture components on the last axis.
        """

    def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
        """Fit the mixture model using MCMC.

        Args:
            data: Input data to fit.
            seed: Random seed for reproducibility.

        Returns:
            Fitted MCMC object containing samples.
        """
        nuts_kernel = NUTS(self.mixture_model)
        mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
        mcmc.run(random.PRNGKey(seed), data=data)
        return mcmc

    def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
        """Run model fitting and assign components.

        The fitted MCMC object, its samples and the resulting assignments are kept on the
        instance as ``mcmc``, ``samples`` and ``assignments``.

        Args:
            data: Input data array.
            seed: Random seed.

        Returns:
            Array of "Positive"/"Negative" assignments for each datapoint.
        """
        self.mcmc = self.fit_model(data, seed)
        self.samples = self.mcmc.get_samples()
        self.assignments = self.assignment(self.samples, data)
        return self.assignments

    def mixture_model(self, data: jnp.ndarray) -> None:
        """Define mixture model structure for NumPyro.

        Args:
            data: Input data array.
        """
        params = self.initialize_params()

        with numpyro.plate("data", data.shape[0]):
            log_likelihoods = self.log_likelihood(data, params)
            # Marginalize over the two components (last axis) in log space.
            log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)
            # NOTE(review): the log mixture likelihood is used as the *mean* of a unit-variance
            # Normal observed at `data`, rather than being added as a log-density term -
            # confirm this is the intended observation model.
            numpyro.sample("obs", dist.Normal(log_mixture_likelihood, 1.0), obs=data)

    def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
        """Assign data points to mixture components.

        The parameters are summarized by their mean over the first (sample) axis and stored
        on the instance as ``params``. Each datapoint is then assigned to the component with
        the larger log likelihood; component index 0 is reported as "Negative", any other
        index as "Positive".

        Args:
            samples: MCMC samples of parameters.
            data: Input data array.

        Returns:
            Array of component assignments.
        """
        params = {key: value.mean(axis=0) for key, value in samples.items()}
        self.params = params

        log_likelihoods = self.log_likelihood(data, params)
        # Move the result to the host once and label it in a single vectorized step instead of
        # pulling every element out of the JAX array individually.
        component = np.asarray(jnp.argmax(log_likelihoods, axis=-1))
        return np.where(component == 0, "Negative", "Positive")


class PoissonGaussMixture(MixtureModel):
    """Two-component mixture of a Poisson (component 0) and a Gaussian (component 1)."""

    def log_likelihood(self, data: np.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Compute the weighted log likelihood of the data under each component.

        Besides returning the likelihoods, this registers the ``separation_penalty`` factor
        site with NumPyro.

        Args:
            data: Input data array.
            params: Current parameter values; must provide ``poisson_rate``, ``gaussian_mean``,
                ``gaussian_std`` and ``mix_probs``.

        Returns:
            Log likelihood values for each component, stacked along the last axis
            (index 0: Poisson, index 1: Gaussian).
        """
        rate = params["poisson_rate"]
        mean = params["gaussian_mean"]
        std = params["gaussian_std"]
        weights = params["mix_probs"]

        # Heuristic regularization against flipped components: the factor contributes +10
        # whenever the Gaussian mean lies strictly above the Poisson rate, and 0 otherwise,
        # so placing the Poisson component to the right of the Gaussian one is disfavoured.
        mean_above_rate = jnp.heaviside(mean - rate, 0)
        numpyro.factor("separation_penalty", 10 * mean_above_rate)

        poisson_term = jnp.log(weights[0]) + dist.Poisson(rate).log_prob(data)
        gaussian_term = jnp.log(weights[1]) + dist.Normal(mean, std).log_prob(data)

        return jnp.stack([poisson_term, gaussian_term], axis=-1)

    def initialize_params(self) -> ParamsDict:
        """Draw every model parameter from its prior.

        Returns:
            Dictionary of sampled parameter values.
        """
        prior_mean, prior_std = self.gaussian_mean_prior
        # Concentration is ordered like the components: negative (Poisson) first, positive second.
        concentration = jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])

        # The sample sites are declared in a fixed order: rate, mean, std, mixture weights.
        return {
            "poisson_rate": numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior)),
            "gaussian_mean": numpyro.sample("gaussian_mean", dist.Normal(prior_mean, prior_std)),
            "gaussian_std": numpyro.sample("gaussian_std", dist.HalfNormal(self.gaussian_std_prior)),
            "mix_probs": numpyro.sample("mix_probs", dist.Dirichlet(concentration)),
        }
Loading

0 comments on commit c80729f

Please sign in to comment.