Skip to content

Commit

Permalink
docs: better docs for mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
joschif committed May 21, 2024
1 parent c2a2ae1 commit 9b49fa3
Showing 1 changed file with 43 additions and 38 deletions.
81 changes: 43 additions & 38 deletions hnoca/map/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Literal, Optional

import anndata as ad
import numpy as np
import pandas as pd

Expand All @@ -23,31 +24,42 @@ class AtlasMapper:
A class to map a query dataset to a reference dataset using scPoli, scVI or scANVI models.
"""

def __init__(self, ref_model):
def __init__(
self,
ref_model: Union[
scvi.model.SCANVI, scvi.model.SCVI, scarches.models.scpoli.scPoli
],
):
"""
Initialize the AtlasMapper object
Args:
ref_model: scvi.model
The reference model to map the query dataset to.
ref_model: The reference model to map the query dataset to.
"""
self.model_type = self._check_model_type(ref_model)
self.ref_model = ref_model
self.ref_adata = ref_model.adata
self.query_model = None
self.ref_trans_prob = None

def map_query(self, query_adata, retrain="partial", **kwargs):
def map_query(
self,
query_adata: ad.AnnData,
retrain: Literal["partial", "full", "none"],
**kwargs,
):
"""
Map a query dataset to the reference dataset
Args:
query_adata : AnnData
The query dataset to map to the reference dataset
query_model : str
The model to use for the query dataset
retrain : str
Whether to retrain the query model. Options are "partial", "full" or "none"
query_adata: The query dataset to map to the reference dataset
retrain: Whether to retrain the query model.
* `"partial"` will retrain the weights of the new batch key
* `"full"` will retrain the entire model
* `"none"` will use the reference model without retraining
kwargs: Additional keyword arguments to pass to the training function
"""
if retrain in ["partial", "full"]:
if self.model_type == "scanvi":
Expand Down Expand Up @@ -123,7 +135,7 @@ def _get_latent(self, model, adata, **kwargs):

def compute_wknn(
self,
ref_adata=None,
ref_adata: ad.AnnData = None,
k: int = 100,
query2ref: bool = True,
ref2query: bool = False,
Expand All @@ -137,16 +149,11 @@ def compute_wknn(
Args:
k : int
Number of neighbors per cell
query2ref : bool
Consider query-to-ref neighbors
ref2query : bool
Consider ref-to-query neighbors
weighting_scheme : str
How to weight edges in the ref-query neighbor graph
top_n : int
The number of top neighbors to consider
k: Number of neighbors per cell
query2ref: Consider query-to-ref neighbors
ref2query: Consider ref-to-query neighbors
weighting_scheme: How to weight edges in the ref-query neighbor graph
top_n: The number of top neighbors to consider
"""

self.ref_adata = ref_adata if ref_adata is not None else self.ref_adata
Expand All @@ -168,22 +175,22 @@ def compute_wknn(
self.query_adata.obsm["X_latent"] = query_latent

def get_presence_scores(
self, split_by=None, random_walk=True, alpha=0.1, n_rounds=100, log=True
self,
split_by: str = None,
random_walk: bool = True,
alpha: float = 0.1,
n_rounds: int = 100,
log: bool = True,
):
"""
Estimate the presence score of the query dataset
Args:
split_by: str
The column in the query dataset to split by
random_walk: bool
Whether to use random walk to estimate presence score
alpha: float
The heat diffusion parameter for the random walk
n_rounds: int
The number of rounds for the random walk
log: bool
Whether to log the presence score
split_by: The column in the query dataset to split by
random_walk: Whether to use random walk to estimate presence score
alpha: The heat diffusion parameter for the random walk
n_rounds: The number of rounds for the random walk
log: Whether to log the presence score
"""

scores = estimate_presence_score(
Expand All @@ -202,7 +209,7 @@ def get_presence_scores(
self.ref_trans_prob = scores["ref_trans_prop"]
return scores

def transfer_labels(self, label_key):
def transfer_labels(self, label_key: str):
"""
Transfer labels from the reference dataset to the query dataset
Expand All @@ -220,13 +227,11 @@ def transfer_labels(self, label_key):

return scores

def get_matched_expression(self, rescale_factor=1):
def get_matched_expression(self, rescale_factor: int = 1):
"""
Get the expression of reference cells matched to query cells. This can be used for quantitative comparisons like DE analysis.
Args:
layer: str
If not None, uses this as the key in adata.layers to return the reference transcriptome.
rescale_factor: str
Factor to rescale the log-normalized counts
"""
Expand All @@ -239,7 +244,7 @@ def get_matched_expression(self, rescale_factor=1):
self.matched_adata = matched_adata
return matched_adata

def save(self, output_dir):
def save(self, output_dir: str):
"""
Save the mapper object to disk
Expand All @@ -252,7 +257,7 @@ def save(self, output_dir):
cloudpickle.dump(self, f)

@classmethod
def load(cls, input_dir):
def load(cls, input_dir: str):
"""
Load the mapper object from disk
Expand Down

0 comments on commit 9b49fa3

Please sign in to comment.