diff --git a/q2_amr/card/reads.py b/q2_amr/card/reads.py index 4b5db7f..5a97bca 100644 --- a/q2_amr/card/reads.py +++ b/q2_amr/card/reads.py @@ -74,13 +74,24 @@ def annotate_reads_card( path=path_allele, samp_bin_name=samp, data_type="reads" ) allele_frequency_list.append(allele_frequency) + path_gene = os.path.join(samp_input_dir, "output.gene_mapping_data.txt") gene_frequency = read_in_txt( path=path_gene, samp_bin_name=samp, data_type="reads" ) gene_frequency_list.append(gene_frequency) - move_files(samp_input_dir, samp_allele_dir, "allele") - move_files(samp_input_dir, samp_gene_dir, "gene") + + # Move mapping and stats files to the sample allele and gene directories + for map_type, des_dir in zip( + ["allele", "gene"], [samp_allele_dir, samp_gene_dir] + ): + files = [f"{map_type}_mapping_data.txt", "overall_mapping_stats.txt"] + + for file in files: + shutil.copy( + os.path.join(samp_input_dir, "output." + file), + os.path.join(des_dir, file), + ) allele_feature_table = create_count_table(allele_frequency_list) gene_feature_table = create_count_table(gene_frequency_list) @@ -92,17 +103,6 @@ def annotate_reads_card( ) -def move_files(source_dir: str, des_dir: str, map_type: str): - shutil.move( - os.path.join(source_dir, f"output.{map_type}_mapping_data.txt"), - os.path.join(des_dir, f"{map_type}_mapping_data.txt"), - ) - shutil.copy( - os.path.join(source_dir, "output.overall_mapping_stats.txt"), - os.path.join(des_dir, "overall_mapping_stats.txt"), - ) - - def run_rgi_bwt( cwd: str, samp: str, diff --git a/q2_amr/tests/card/test_reads.py b/q2_amr/tests/card/test_reads.py index 5ef00a4..4487edf 100644 --- a/q2_amr/tests/card/test_reads.py +++ b/q2_amr/tests/card/test_reads.py @@ -13,7 +13,6 @@ from q2_amr.card.reads import ( annotate_reads_card, extract_sample_stats, - move_files, plot_sample_stats, run_rgi_bwt, visualize_annotation_stats, @@ -147,11 +146,9 @@ def annotate_reads_card_test_body(self, read_type): # resulting CARD annotation objects for num in [0, 1]: map_type = "allele" if num == 0 else "gene" + files = [f"{map_type}_mapping_data.txt", "overall_mapping_stats.txt"] for samp in ["sample1", "sample2"]: - for file in [ - f"{map_type}_mapping_data.txt", - "overall_mapping_stats.txt", - ]: + for file in files: self.assertTrue( os.path.exists(os.path.join(str(result[num]), samp, file)) ) @@ -204,43 +201,6 @@ def test_exception_raised(self): run_rgi_bwt() self.assertEqual(str(cm.exception), expected_message) - def test_move_files_allele(self): - self.move_files_test_body("allele") - - def test_move_files_gene(self): - self.move_files_test_body("gene") - - def move_files_test_body(self, map_type): - with tempfile.TemporaryDirectory() as tmp: - source_dir = os.path.join(tmp, "source_dir") - des_dir = os.path.join( - tmp, - "des_dir", - ) - os.makedirs(os.path.join(source_dir)) - os.makedirs(os.path.join(des_dir)) - mapping_data = self.get_data_path(f"output.{map_type}_mapping_data.txt") - mapping_stats = self.get_data_path("output.overall_mapping_stats.txt") - shutil.copy(mapping_data, source_dir) - shutil.copy(mapping_stats, source_dir) - move_files(source_dir, des_dir, map_type) - self.assertTrue( - os.path.exists( - os.path.join( - des_dir, - f"{map_type}_mapping_data.txt", - ) - ) - ) - self.assertTrue( - os.path.exists( - os.path.join( - des_dir, - "overall_mapping_stats.txt", - ) - ) - ) - def test_extract_sample_stats(self): with tempfile.TemporaryDirectory() as tmp: mapping_stats_path = self.get_data_path("output.overall_mapping_stats.txt")