From b8f7359b83f48ba0a7dafa91de3529294a105177 Mon Sep 17 00:00:00 2001 From: zethson Date: Tue, 6 Feb 2024 18:00:35 +0100 Subject: [PATCH] Fix ruff config Signed-off-by: zethson --- pertpy/metadata/_cell_line.py | 16 ++++--- pertpy/plot/_augur.py | 12 +++-- pertpy/plot/_coda.py | 4 +- pertpy/plot/_guide_rna.py | 2 +- pertpy/plot/_milopy.py | 2 +- pertpy/plot/_mixscape.py | 38 ++++++++------- pertpy/preprocessing/_guide_rna.py | 6 +-- pertpy/tools/_augur.py | 16 +++++-- pertpy/tools/_coda/_base_coda.py | 4 +- pertpy/tools/_coda/_tasccoda.py | 2 +- pertpy/tools/_dialogue.py | 4 +- pertpy/tools/_distances/_distance_tests.py | 16 +++---- pertpy/tools/_milo.py | 6 +-- pertpy/tools/_mixscape.py | 48 ++++++++++--------- .../tools/_perturbation_space/_clustering.py | 4 +- .../_discriminator_classifier.py | 6 +-- .../_perturbation_space.py | 8 ++-- pertpy/tools/_perturbation_space/_simple.py | 2 +- pertpy/tools/_scgen/_scgen.py | 2 +- pertpy/tools/_scgen/_utils.py | 4 +- pyproject.toml | 9 +++- 21 files changed, 126 insertions(+), 85 deletions(-) diff --git a/pertpy/metadata/_cell_line.py b/pertpy/metadata/_cell_line.py index 614a9363..f23edda3 100644 --- a/pertpy/metadata/_cell_line.py +++ b/pertpy/metadata/_cell_line.py @@ -218,7 +218,7 @@ def annotate( Examples: >>> import pertpy as pt >>> adata = pt.dt.dialogue_example() - >>> adata.obs['cell_line_name'] = 'MCF7' + >>> adata.obs["cell_line_name"] = "MCF7" >>> pt_metadata = pt.md.CellLine() >>> adata_annotated = pt_metadata.annotate(adata=adata, >>> reference_id='cell_line_name', @@ -332,9 +332,11 @@ def annotate_bulk_rna( Examples: >>> import pertpy as pt >>> adata = pt.dt.dialogue_example() - >>> adata.obs['cell_line_name'] = 'MCF7' + >>> adata.obs["cell_line_name"] = "MCF7" >>> pt_metadata = pt.md.CellLine() - >>> adata_annotated = pt_metadata.annotate(adata=adata, reference_id='cell_line_name', query_id='cell_line_name', copy=True) + >>> adata_annotated = pt_metadata.annotate( + ... adata=adata, reference_id="cell_line_name", query_id="cell_line_name", copy=True + ... ) >>> pt_metadata.annotate_bulk_rna(adata_annotated) """ if copy: @@ -433,9 +435,11 @@ def annotate_protein_expression( Examples: >>> import pertpy as pt >>> adata = pt.dt.dialogue_example() - >>> adata.obs['cell_line_name'] = 'MCF7' + >>> adata.obs["cell_line_name"] = "MCF7" >>> pt_metadata = pt.md.CellLine() - >>> adata_annotated = pt_metadata.annotate(adata=adata, reference_id='cell_line_name', query_id='cell_line_name', copy=True) + >>> adata_annotated = pt_metadata.annotate( + ... adata=adata, reference_id="cell_line_name", query_id="cell_line_name", copy=True + ... ) >>> pt_metadata.annotate_protein_expression(adata_annotated) """ if copy: @@ -520,7 +524,7 @@ def annotate_from_gdsc( >>> import pertpy as pt >>> adata = pt.dt.mcfarland_2020() >>> pt_metadata = pt.md.CellLine() - >>> pt_metadata.annotate_from_gdsc(adata, query_id='cell_line') + >>> pt_metadata.annotate_from_gdsc(adata, query_id="cell_line") """ if copy: adata = adata.copy() diff --git a/pertpy/plot/_augur.py b/pertpy/plot/_augur.py index f318bfe8..8b818647 100644 --- a/pertpy/plot/_augur.py +++ b/pertpy/plot/_augur.py @@ -79,7 +79,9 @@ def important_features( >>> adata = pt.dt.sc_sim_augur() >>> ag_rfc = pt.tl.Augur("random_forest_classifier") >>> loaded_data = ag_rfc.load(adata) - >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) + >>> v_adata, v_results = ag_rfc.predict( + ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4 + ... ) >>> ag_rfc.plot_important_features(v_results) """ warnings.warn( @@ -115,7 +117,9 @@ def lollipop( >>> adata = pt.dt.sc_sim_augur() >>> ag_rfc = pt.tl.Augur("random_forest_classifier") >>> loaded_data = ag_rfc.load(adata) - >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) + >>> v_adata, v_results = ag_rfc.predict( + ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4 + ... ) >>> ag_rfc.plot_lollipop(v_results) """ warnings.warn( @@ -152,7 +156,9 @@ def scatterplot( >>> ag_rfc = pt.tl.Augur("random_forest_classifier") >>> loaded_data = ag_rfc.load(adata) >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4) - >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) + >>> v_adata, v_results = ag_rfc.predict( + ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4 + ... ) >>> ag_rfc.plot_scatterplot(v_results, h_results) """ warnings.warn( diff --git a/pertpy/plot/_coda.py b/pertpy/plot/_coda.py index e2dbeb11..51059671 100644 --- a/pertpy/plot/_coda.py +++ b/pertpy/plot/_coda.py @@ -594,7 +594,9 @@ def effects_umap( # pragma: no cover >>> pen_args={"phi": 0, "lambda_1": 3.5}, >>> tree_key="tree" >>> ) - >>> tasccoda_model.run_nuts(tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000) + >>> tasccoda_model.run_nuts( + ... tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000 + ... ) >>> tasccoda_model.plot_effects_umap(tasccoda_data, >>> effect_name=["effect_df_condition[T.Salmonella]", >>> "effect_df_condition[T.Hpoly.Day3]", diff --git a/pertpy/plot/_guide_rna.py b/pertpy/plot/_guide_rna.py index 9763257c..ffda2bf7 100644 --- a/pertpy/plot/_guide_rna.py +++ b/pertpy/plot/_guide_rna.py @@ -46,7 +46,7 @@ def heatmap( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() - >>> gdo = mdata.mod['gdo'] + >>> gdo = mdata.mod["gdo"] >>> ga = pt.pp.GuideAssignment() >>> ga.assign_by_threshold(gdo, assignment_threshold=5) >>> ga.plot_heatmap(gdo) diff --git a/pertpy/plot/_milopy.py b/pertpy/plot/_milopy.py index 3ecb3e8c..34c61bce 100644 --- a/pertpy/plot/_milopy.py +++ b/pertpy/plot/_milopy.py @@ -155,7 +155,7 @@ def da_beeswarm( >>> milo.make_nhoods(mdata["rna"]) >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") >>> milo.da_nhoods(mdata, design="~label") - >>> milo.annotate_nhoods(mdata, anno_col='cell_type') + >>> milo.annotate_nhoods(mdata, anno_col="cell_type") >>> milo.plot_da_beeswarm(mdata) """ warnings.warn( diff --git a/pertpy/plot/_mixscape.py b/pertpy/plot/_mixscape.py index 6156643e..d7203178 100644 --- a/pertpy/plot/_mixscape.py +++ b/pertpy/plot/_mixscape.py @@ -47,9 +47,9 @@ def barplot( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms = pt.tl.Mixscape() - >>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms.plot_barplot(mdata['rna'], guide_rna_column='NT') + >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms.plot_barplot(mdata["rna"], guide_rna_column="NT") """ warnings.warn( "This function is deprecated and will be removed in pertpy 0.8.0!" @@ -109,9 +109,11 @@ def heatmap( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms = pt.tl.Mixscape() - >>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms.plot_heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT') + >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms.plot_heatmap( + ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT" + ... ) """ warnings.warn( "This function is deprecated and will be removed in pertpy 0.8.0!" @@ -173,9 +175,11 @@ def perturbscore( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> mixscape_identifier = pt.tl.Mixscape() - >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> mixscape_identifier.perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange') + >>> mixscape_identifier.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> mixscape_identifier.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> mixscape_identifier.perturbscore( + ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange" + ... ) """ warnings.warn( "This function is deprecated and will be removed in pertpy 0.8.0!" @@ -247,9 +251,11 @@ def violin( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms = pt.tl.Mixscape() - >>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms.plot_violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class') + >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms.plot_violin( + ... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class" + ... ) """ warnings.warn( "This function is deprecated and will be removed in pertpy 0.8.0!" @@ -319,10 +325,10 @@ def lda( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms = pt.tl.Mixscape() - >>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert') - >>> ms.plot_lda(adata=mdata['rna'], control='NT') + >>> ms.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms.plot_lda(adata=mdata["rna"], control="NT") """ warnings.warn( "This function is deprecated and will be removed in pertpy 0.8.0!" diff --git a/pertpy/preprocessing/_guide_rna.py b/pertpy/preprocessing/_guide_rna.py index 87c6f96e..e16b7322 100644 --- a/pertpy/preprocessing/_guide_rna.py +++ b/pertpy/preprocessing/_guide_rna.py @@ -41,7 +41,7 @@ def assign_by_threshold( >>> import pertpy as pt >>> mdata = pt.data.papalexi_2021() - >>> gdo = mdata.mod['gdo'] + >>> gdo = mdata.mod["gdo"] >>> ga = pt.pp.GuideAssignment() >>> ga.assign_by_threshold(gdo, assignment_threshold=5) """ @@ -86,7 +86,7 @@ def assign_to_max_guide( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() - >>> gdo = mdata.mod['gdo'] + >>> gdo = mdata.mod["gdo"] >>> ga = pt.pp.GuideAssignment() >>> ga.assign_to_max_guide(gdo, assignment_threshold=5) """ @@ -143,7 +143,7 @@ def plot_heatmap( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() - >>> gdo = mdata.mod['gdo'] + >>> gdo = mdata.mod["gdo"] >>> ga = pt.pp.GuideAssignment() >>> ga.assign_by_threshold(gdo, assignment_threshold=5) >>> ga.heatmap(gdo) diff --git a/pertpy/tools/_augur.py b/pertpy/tools/_augur.py index 38889f0e..89e86ec2 100644 --- a/pertpy/tools/_augur.py +++ b/pertpy/tools/_augur.py @@ -220,7 +220,9 @@ def sample(self, adata: AnnData, categorical: bool, subsample_size: int, random_ >>> loaded_data = ag_rfc.load(adata) >>> ag_rfc.select_highly_variable(loaded_data) >>> features = loaded_data.var_names - >>> subsample = ag_rfc.sample(loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names) + >>> subsample = ag_rfc.sample( + ... loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names + ... ) """ # export subsampling. random.seed(random_state) @@ -1051,7 +1053,9 @@ def plot_important_features( >>> adata = pt.dt.sc_sim_augur() >>> ag_rfc = pt.tl.Augur("random_forest_classifier") >>> loaded_data = ag_rfc.load(adata) - >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) + >>> v_adata, v_results = ag_rfc.predict( + ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4 + ... ) >>> ag_rfc.plot_important_features(v_results) """ if isinstance(data, AnnData): @@ -1101,7 +1105,9 @@ def plot_lollipop( >>> adata = pt.dt.sc_sim_augur() >>> ag_rfc = pt.tl.Augur("random_forest_classifier") >>> loaded_data = ag_rfc.load(adata) - >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) + >>> v_adata, v_results = ag_rfc.predict( + ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4 + ... ) >>> ag_rfc.plot_lollipop(v_results) """ if isinstance(data, AnnData): @@ -1149,7 +1155,9 @@ def plot_scatterplot( >>> ag_rfc = pt.tl.Augur("random_forest_classifier") >>> loaded_data = ag_rfc.load(adata) >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4) - >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) + >>> v_adata, v_results = ag_rfc.predict( + ... loaded_data, subsample_size=20, select_variance_features=True, n_threads=4 + ... ) >>> ag_rfc.plot_scatterplot(v_results, h_results) """ cell_types = results1["summary_metrics"].columns diff --git a/pertpy/tools/_coda/_base_coda.py b/pertpy/tools/_coda/_base_coda.py index b431e931..b7a68e11 100644 --- a/pertpy/tools/_coda/_base_coda.py +++ b/pertpy/tools/_coda/_base_coda.py @@ -2141,7 +2141,9 @@ def plot_effects_umap( # pragma: no cover >>> pen_args={"phi": 0, "lambda_1": 3.5}, >>> tree_key="tree" >>> ) - >>> tasccoda_model.run_nuts(tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000) + >>> tasccoda_model.run_nuts( + ... tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000 + ... ) >>> tasccoda_model.plot_effects_umap(tasccoda_data, >>> effect_name=["effect_df_condition[T.Salmonella]", >>> "effect_df_condition[T.Hpoly.Day3]", diff --git a/pertpy/tools/_coda/_tasccoda.py b/pertpy/tools/_coda/_tasccoda.py index faf2e56a..6e7aa393 100644 --- a/pertpy/tools/_coda/_tasccoda.py +++ b/pertpy/tools/_coda/_tasccoda.py @@ -334,7 +334,7 @@ def set_init_mcmc_states(self, rng_key: None, ref_index: np.ndarray, sample_adat >>> mdata = tasccoda.prepare( >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0} >>> ) - >>> adata = tasccoda.set_init_mcmc_states(rng_key=42, ref_index=[0,1], sample_adata=mdata['coda']) + >>> adata = tasccoda.set_init_mcmc_states(rng_key=42, ref_index=[0, 1], sample_adata=mdata["coda"]) """ N, D = sample_adata.obsm["covariate_matrix"].shape P = sample_adata.X.shape[1] diff --git a/pertpy/tools/_dialogue.py b/pertpy/tools/_dialogue.py index c6e9f9cb..4bf8417d 100644 --- a/pertpy/tools/_dialogue.py +++ b/pertpy/tools/_dialogue.py @@ -623,7 +623,9 @@ def calculate_multifactor_PMD( >>> import scanpy as sc >>> adata = pt.dt.dialogue_example() >>> sc.pp.pca(adata) - >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", n_counts_key = "nCount_RNA", n_mpcs = 3) + >>> dl = pt.tl.Dialogue( + ... sample_id="clinical.status", celltype_key="cell.subtypes", n_counts_key="nCount_RNA", n_mpcs=3 + ... ) >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True) """ # IMPORTANT NOTE: the order in which matrices are passed to multicca matters. diff --git a/pertpy/tools/_distances/_distance_tests.py b/pertpy/tools/_distances/_distance_tests.py index ba6f004c..3f4770d5 100644 --- a/pertpy/tools/_distances/_distance_tests.py +++ b/pertpy/tools/_distances/_distance_tests.py @@ -37,8 +37,8 @@ class DistanceTest: Examples: >>> import pertpy as pt >>> adata = pt.dt.distance_example_data() - >>> distance_test = pt.tl.DistanceTest('edistance', n_perms=1000) - >>> tab = distance_test(adata, groupby='perturbation', contrast='control') + >>> distance_test = pt.tl.DistanceTest("edistance", n_perms=1000) + >>> tab = distance_test(adata, groupby="perturbation", contrast="control") """ def __init__( @@ -100,8 +100,8 @@ def __call__( Examples: >>> import pertpy as pt >>> adata = pt.dt.distance_example_data() - >>> distance_test = pt.tl.DistanceTest('edistance', n_perms=1000) - >>> tab = distance_test(adata, groupby='perturbation', contrast='control') + >>> distance_test = pt.tl.DistanceTest("edistance", n_perms=1000) + >>> tab = distance_test(adata, groupby="perturbation", contrast="control") """ if self.distance.metric_fct.accepts_precomputed: # Much faster if the metric can be called on the precomputed @@ -134,8 +134,8 @@ def test_xy(self, adata: AnnData, groupby: str, contrast: str, show_progressbar: Examples: >>> import pertpy as pt >>> adata = pt.dt.distance_example_data() - >>> distance_test = pt.tl.DistanceTest('edistance', n_perms=1000) - >>> test_results = distance_test.test_xy(adata, groupby='perturbation', contrast='control') + >>> distance_test = pt.tl.DistanceTest("edistance", n_perms=1000) + >>> test_results = distance_test.test_xy(adata, groupby="perturbation", contrast="control") """ groups = adata.obs[groupby].unique() if contrast not in groups: @@ -226,8 +226,8 @@ def test_precomputed(self, adata: AnnData, groupby: str, contrast: str, verbose: Examples: >>> import pertpy as pt >>> adata = pt.dt.distance_example_data() - >>> distance_test = pt.tl.DistanceTest('edistance', n_perms=1000) - >>> test_results = distance_test.test_precomputed(adata, groupby='perturbation', contrast='control') + >>> distance_test = pt.tl.DistanceTest("edistance", n_perms=1000) + >>> test_results = distance_test.test_precomputed(adata, groupby="perturbation", contrast="control") """ if not self.distance.metric_fct.accepts_precomputed: raise ValueError(f"Metric {self.metric} does not accept precomputed distances.") diff --git a/pertpy/tools/_milo.py b/pertpy/tools/_milo.py index 8a911005..60382378 100644 --- a/pertpy/tools/_milo.py +++ b/pertpy/tools/_milo.py @@ -429,7 +429,7 @@ def annotate_nhoods( >>> sc.pp.neighbors(mdata["rna"]) >>> milo.make_nhoods(mdata["rna"]) >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") - >>> milo.annotate_nhoods(mdata, anno_col='cell_type') + >>> milo.annotate_nhoods(mdata, anno_col="cell_type") """ try: sample_adata = mdata["milo"] @@ -480,7 +480,7 @@ def annotate_nhoods_continuous(self, mdata: MuData, anno_col: str, feature_key: >>> sc.pp.neighbors(mdata["rna"]) >>> milo.make_nhoods(mdata["rna"]) >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") - >>> milo.annotate_nhoods_continuous(mdata, anno_col='nUMI') + >>> milo.annotate_nhoods_continuous(mdata, anno_col="nUMI") """ if "milo" not in mdata.mod: raise ValueError( @@ -845,7 +845,7 @@ def plot_da_beeswarm( >>> milo.make_nhoods(mdata["rna"]) >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident") >>> milo.da_nhoods(mdata, design="~label") - >>> milo.annotate_nhoods(mdata, anno_col='cell_type') + >>> milo.annotate_nhoods(mdata, anno_col="cell_type") >>> milo.plot_da_beeswarm(mdata) """ try: diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index f4d2b348..7f2d6a62 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -80,7 +80,7 @@ def perturbation_signature( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") """ if copy: adata = adata.copy() @@ -200,8 +200,8 @@ def mixscape( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") """ if copy: adata = adata.copy() @@ -367,9 +367,9 @@ def lda( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms_pt.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert') + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") """ if copy: adata = adata.copy() @@ -530,9 +530,9 @@ def plot_barplot( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms_pt.plot_barplot(mdata['rna'], guide_rna_column='NT') + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT") Preview: .. image:: /_static/docstring_previews/mixscape_barplot.png @@ -637,9 +637,11 @@ def plot_heatmap( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms_pt.plot_heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT') + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.plot_heatmap( + ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT" + ... ) Preview: .. image:: /_static/docstring_previews/mixscape_heatmap.png @@ -701,9 +703,9 @@ def plot_perturbscore( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms_pt.plot_perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange') + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange") Preview: .. image:: /_static/docstring_previews/mixscape_perturbscore.png @@ -870,9 +872,11 @@ def plot_violin( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms_pt.plot_violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class') + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.plot_violin( + ... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class" + ... ) Preview: .. image:: /_static/docstring_previews/mixscape_violin.png @@ -1051,10 +1055,10 @@ def plot_lda( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms_pt.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert') - >>> ms_pt.plot_lda(adata=mdata['rna'], control='NT') + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.plot_lda(adata=mdata["rna"], control="NT") Preview: .. image:: /_static/docstring_previews/mixscape_lda.png diff --git a/pertpy/tools/_perturbation_space/_clustering.py b/pertpy/tools/_perturbation_space/_clustering.py index b86acf14..9aa8d326 100644 --- a/pertpy/tools/_perturbation_space/_clustering.py +++ b/pertpy/tools/_perturbation_space/_clustering.py @@ -41,7 +41,9 @@ def evaluate_clustering( >>> mdata = pt.dt.papalexi_2021() >>> kmeans = pt.tl.KMeansSpace() >>> kmeans_adata = kmeans.compute(mdata["rna"], n_clusters=26) - >>> results = kmeans.evaluate_clustering(kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=['nmi']) + >>> results = kmeans.evaluate_clustering( + ... kmeans_adata, true_label_col="gene_target", cluster_col="k-means", metrics=["nmi"] + ... ) """ if metrics is None: metrics = ["nmi", "ari", "asw"] diff --git a/pertpy/tools/_perturbation_space/_discriminator_classifier.py b/pertpy/tools/_perturbation_space/_discriminator_classifier.py index 12d2569b..2deddb69 100644 --- a/pertpy/tools/_perturbation_space/_discriminator_classifier.py +++ b/pertpy/tools/_perturbation_space/_discriminator_classifier.py @@ -60,7 +60,7 @@ def load( # type: ignore Examples: >>> import pertpy as pt - >>> adata = pt.dt.papalexi_2021()['rna'] + >>> adata = pt.dt.papalexi_2021()["rna"] >>> dcs = pt.tl.DiscriminatorClassifierSpace() >>> dcs.load(adata, target_col="gene_target") """ @@ -135,7 +135,7 @@ def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int = Examples: >>> import pertpy as pt - >>> adata = pt.dt.papalexi_2021()['rna'] + >>> adata = pt.dt.papalexi_2021()["rna"] >>> dcs = pt.tl.DiscriminatorClassifierSpace() >>> dcs.load(adata, target_col="gene_target") >>> dcs.train(max_epochs=5) @@ -164,7 +164,7 @@ def get_embeddings(self) -> AnnData: Examples: >>> import pertpy as pt - >>> adata = pt.dt.papalexi_2021()['rna'] + >>> adata = pt.dt.papalexi_2021()["rna"] >>> dcs = pt.tl.DiscriminatorClassifierSpace() >>> dcs.load(adata, target_col="gene_target") >>> dcs.train() diff --git a/pertpy/tools/_perturbation_space/_perturbation_space.py b/pertpy/tools/_perturbation_space/_perturbation_space.py index 8f832589..af6c0bd2 100644 --- a/pertpy/tools/_perturbation_space/_perturbation_space.py +++ b/pertpy/tools/_perturbation_space/_perturbation_space.py @@ -59,7 +59,7 @@ def compute_control_diff( # type: ignore >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ps = pt.tl.PseudobulkSpace() - >>> diff_adata = ps.compute_control_diff(mdata["rna"], target_col="gene_target", reference_key='NT') + >>> diff_adata = ps.compute_control_diff(mdata["rna"], target_col="gene_target", reference_key="NT") """ if reference_key not in adata.obs[target_col].unique(): raise ValueError( @@ -171,7 +171,7 @@ def add( >>> mdata = pt.dt.papalexi_2021() >>> ps = pt.tl.PseudobulkSpace() >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target") - >>> new_perturbation = ps.add(ps_adata, perturbations=["ATF2", "CD86"], reference_key='NT') + >>> new_perturbation = ps.add(ps_adata, perturbations=["ATF2", "CD86"], reference_key="NT") """ new_pert_name = "" for perturbation in perturbations: @@ -383,7 +383,9 @@ def label_transfer( >>> import numpy as np >>> adata = sc.datasets.pbmc68k_reduced() >>> rng = np.random.default_rng() - >>> adata.obs["perturbation"] = rng.choice(["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01]) + >>> adata.obs["perturbation"] = rng.choice( + ... ["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01] + ... ) >>> sc.pp.neighbors(adata) >>> sc.tl.umap(adata) >>> ps = pt.tl.PseudobulkSpace() diff --git a/pertpy/tools/_perturbation_space/_simple.py b/pertpy/tools/_perturbation_space/_simple.py index ae4f467a..42649496 100644 --- a/pertpy/tools/_perturbation_space/_simple.py +++ b/pertpy/tools/_perturbation_space/_simple.py @@ -41,7 +41,7 @@ def compute( >>> import scanpy as sc >>> mdata = pt.dt.papalexi_2021() >>> sc.pp.pca(mdata["rna"]) - >>> sc.pp.neighbors(mdata['rna']) + >>> sc.pp.neighbors(mdata["rna"]) >>> sc.tl.umap(mdata["rna"]) >>> cs = pt.tl.CentroidSpace() >>> cs_adata = cs.compute(mdata["rna"], target_col="gene_target") diff --git a/pertpy/tools/_scgen/_scgen.py b/pertpy/tools/_scgen/_scgen.py index 278c2e66..0e81e39c 100644 --- a/pertpy/tools/_scgen/_scgen.py +++ b/pertpy/tools/_scgen/_scgen.py @@ -86,7 +86,7 @@ def predict( >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type") >>> model = pt.tl.SCGEN(data) >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5) - >>> pred, delta = model.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells') + >>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells") """ # use keys registered from `setup_anndata()` cell_type_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key diff --git a/pertpy/tools/_scgen/_utils.py b/pertpy/tools/_scgen/_utils.py index 18887530..1ff6be9a 100644 --- a/pertpy/tools/_scgen/_utils.py +++ b/pertpy/tools/_scgen/_utils.py @@ -32,9 +32,7 @@ def extractor( train_data = anndata.read("./data/train.h5ad") test_data = anndata.read("./data/test.h5ad") - train_data_extracted_list = extractor( - train_data, "CD4T", "conditions", "cell_type", "control", "stimulated" - ) + train_data_extracted_list = extractor(train_data, "CD4T", "conditions", "cell_type", "control", "stimulated") """ cell_with_both_condition = data[data.obs[cell_type_key] == cell_type] condition_1 = data[(data.obs[cell_type_key] == cell_type) & (data.obs[condition_key] == ctrl_key)] diff --git a/pyproject.toml b/pyproject.toml index 55ddf088..7ea006d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,7 +130,12 @@ norecursedirs = [ '.*', 'build', 'dist', '*.egg', 'data', '__pycache__'] [tool.ruff] src = ["src"] line-length = 120 -lint.select = [ + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] +select = [ "F", # Errors detected by Pyflakes "E", # Error detected by Pycodestyle "W", # Warning detected by Pycodestyle @@ -146,7 +151,7 @@ lint.select = [ "NPY", # Numpy specific rules "PTH" # Use pathlib ] -lint.ignore = [ +ignore = [ # line too long -> we accept long comment lines; black gets rid of long code lines "E501", # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient