diff --git a/q2_amr/amrfinderplus/feature_data.py b/q2_amr/amrfinderplus/feature_data.py index 2ba921e..0c135f3 100644 --- a/q2_amr/amrfinderplus/feature_data.py +++ b/q2_amr/amrfinderplus/feature_data.py @@ -32,24 +32,34 @@ def _validate_inputs(mags, loci, proteins): def _get_file_paths(file, mags, proteins, loci): + # If mags is provided, mag_id is extracted from the file name. if mags: mag_id = os.path.splitext(os.path.basename(file))[0] + + # If proteins are provided, construct the expected protein file path. if proteins: protein_path = os.path.join(str(proteins), f"{mag_id}_proteins.fasta") + + # Raise an error if the expected protein file does not exist. if not os.path.exists(protein_path): raise ValueError( f"Proteins file for ID '{mag_id}' is missing in proteins input." ) else: protein_path = None - elif proteins: + + # If only proteins are provided (without mags), determine mag_id and protein path. + else: + # Extract mag_id from the file name, excluding the last 9 characters + # '_proteins'. mag_id = os.path.splitext(os.path.basename(file))[0][:-9] protein_path = file - else: - raise ValueError("Either mags or proteins must be provided.") + # If loci are provided, construct the expected GFF file path. if loci: gff_path = os.path.join(str(loci), f"{mag_id}_loci.gff") + + # Raise an error if the expected GFF file does not exist. if not os.path.exists(gff_path): raise ValueError(f"GFF file for ID '{mag_id}' is missing in loci input.") else: @@ -58,13 +68,17 @@ def _get_file_paths(file, mags, proteins, loci): return mag_id, protein_path, gff_path -def _move_or_create_files(src_dir: str, mag_id: str, file_operations: dict): +def _move_or_create_files(src_dir: str, mag_id: str, file_operations: list): + # Loop through all files. for file_name, target_dir in file_operations: + # If the file exists move it to the destination dir and attach mag_id. 
if os.path.exists(os.path.join(src_dir, file_name)): shutil.move( os.path.join(src_dir, file_name), os.path.join(str(target_dir), f"{mag_id}_{file_name}"), ) + # If the file does not exist, create empty placeholder file in the + # destination dir. else: with open(os.path.join(str(target_dir), f"{mag_id}_{file_name}"), "w"): pass @@ -92,22 +106,26 @@ def annotate_feature_data_amrfinderplus( GenesDirectoryFormat, ProteinsDirectoryFormat, ): - # Check for unallowed input combinations + # Check for unallowed input combinations. _validate_inputs(mags, loci, proteins) - # Create all output directory formats + # Create all output directories. amr_annotations = AMRFinderPlusAnnotationsDirFmt() amr_all_mutations = AMRFinderPlusAnnotationsDirFmt() amr_genes = GenesDirectoryFormat() amr_proteins = ProteinsDirectoryFormat() + # Create list of files to loop over, if mags is provided then files in mags will be + # used if only proteins is provided then files in proteins will be used if mags: files = glob.glob(os.path.join(str(mags), "*")) - elif proteins: + else: files = glob.glob(os.path.join(str(proteins), "*")) with tempfile.TemporaryDirectory() as tmp: + # Loop over all files for file in files: + # Get paths to protein and gff files, and get mag_id mag_id, protein_path, gff_path = _get_file_paths(file, mags, proteins, loci) # Run amrfinderplus @@ -130,9 +148,7 @@ def annotate_feature_data_amrfinderplus( threads=threads, ) - # Move mutations, genes and proteins files from tmp dir to the output - # directory format, if organism, dna_sequence and proteins parameters - # are specified. Else create empty placeholder files. 
+ # Output filenames and output directories file_operations = [ ("amr_annotations.tsv", amr_annotations), ("amr_all_mutations.tsv", amr_all_mutations), @@ -140,7 +156,7 @@ def annotate_feature_data_amrfinderplus( ("amr_proteins.fasta", amr_proteins), ] - # Loop through each file operation + # Move the files or create placeholder files _move_or_create_files(tmp, mag_id, file_operations) return amr_annotations, amr_all_mutations, amr_genes, amr_proteins diff --git a/q2_amr/amrfinderplus/tests/test_feature_data.py b/q2_amr/amrfinderplus/tests/test_feature_data.py index 0825780..e077e52 100644 --- a/q2_amr/amrfinderplus/tests/test_feature_data.py +++ b/q2_amr/amrfinderplus/tests/test_feature_data.py @@ -1,93 +1,220 @@ import os +from pathlib import Path from unittest.mock import MagicMock, patch -from q2_types.genome_data import LociDirectoryFormat +from q2_types.feature_data_mag import MAGSequencesDirFmt +from q2_types.genome_data import ProteinsDirectoryFormat from qiime2.plugin.testing import TestPluginBase -from q2_amr.amrfinderplus.sequences import annotate_feature_data_amrfinderplus -from q2_amr.amrfinderplus.tests.test_sample_data import mock_run_amrfinderplus_n +from q2_amr.amrfinderplus.feature_data import ( + _get_file_paths, + _move_or_create_files, + _validate_inputs, + annotate_feature_data_amrfinderplus, +) -class TestAnnotateSequencesAMRFinderPlus(TestPluginBase): +class TestValidateInputs(TestPluginBase): package = "q2_amr.amrfinderplus.tests" - def test_annotate_sequences_amrfinderplus_dna(self): - # dna_sequences = DNASequencesDirectoryFormat() - # with open(os.path.join(str(dna_sequences), "dna-sequences.fasta"), "w"): - # pass - dna_sequences = MagicMock() - self._helper( - dna_sequences=dna_sequences, proteins=None, gff=None, organism=None + def test_loci_mags(self): + with self.assertRaisesRegex( + ValueError, + "Loci input can only be given in combination with proteins input.", + ): + _validate_inputs(mags="mags", loci="loci", proteins=None) + + 
def test_no_loci_protein_mags(self): + with self.assertRaisesRegex( + ValueError, + "MAGs and proteins inputs together can only be given in combination with " + "loci input.", + ): + _validate_inputs(mags="mags", loci=None, proteins="proteins") + + def test_no_protein_no_mags(self): + with self.assertRaisesRegex( + ValueError, "MAGs or proteins input has to be provided." + ): + _validate_inputs(mags=None, loci="loci_directory", proteins=None) + + +class TestMoveOrCreateFiles(TestPluginBase): + package = "q2_amr.amrfinderplus.tests" + + def setUp(self): + super().setUp() + + self.tmp = self.temp_dir.name + self.src_dir = os.path.join(self.tmp, "src_dir") + self.target_dir = os.path.join(self.tmp, "target_dir") + os.mkdir(self.src_dir) + os.mkdir(self.target_dir) + + def test_move_file(self): + # Create a dummy file in the source directory + with open(os.path.join(self.src_dir, "test_file.txt"), "w"): + pass + + # Define the file operations + file_operations = [("test_file.txt", self.target_dir)] + + # Run the function + _move_or_create_files( + src_dir=self.src_dir, + mag_id="mag", + file_operations=file_operations, ) - def test_annotate_sequences_amrfinderplus_prot_gff(self): - proteins = MagicMock() - gff = LociDirectoryFormat() - gff_content = "##gff-version 3\nchr1\t.\tgene\t1\t1000\t.\t+\t.\tID=gene1" - with open(os.path.join(str(gff), "loci.gff"), "w") as file: - file.write(gff_content) - self._helper( - dna_sequences=None, - proteins=proteins, - gff=gff, - organism="Escherichia", + # Assert the file was moved + self.assertTrue( + os.path.exists(os.path.join(self.target_dir, "mag_test_file.txt")) ) - def test_annotate_sequences_amrfinderplus_dna_gff(self): - dna_sequences = MagicMock() - gff = MagicMock() - amrfinderplus_db = MagicMock() + def test_file_missing_create_placeholder(self): + # Define the file operations + file_operations = [("test_file.txt", self.target_dir)] + + # Run the function + _move_or_create_files( + src_dir=self.src_dir, + mag_id="mag", 
+ file_operations=file_operations, + ) + + # Assert the placeholder file was created + self.assertTrue( + os.path.exists(os.path.join(self.target_dir, "mag_test_file.txt")) + ) + + def test_with_mags_and_proteins_file_missing(self): with self.assertRaisesRegex( - ValueError, - "GFF input can only be given in combination with proteis input.", + ValueError, "Proteins file for ID 'mag_id' is missing in proteins input." ): - annotate_feature_data_amrfinderplus( - mags=dna_sequences, - loci=gff, - amrfinderplus_db=amrfinderplus_db, - ) + _get_file_paths("path/mag_id.fasta", "path/mags", "path/proteins", None) + - def test_annotate_sequences_amrfinderplus_dna_prot(self): - dna_sequences = MagicMock() - proteins = MagicMock() - amrfinderplus_db = MagicMock() +class TestGetFilePaths(TestPluginBase): + package = "q2_amr.amrfinderplus.tests" + + def setUp(self): + super().setUp() + + self.test_dir = self.temp_dir + self.test_dir_path = Path(self.test_dir.name) + self.file_path = self.test_dir_path / "test_file.fasta" + self.file_path.touch() # Create an empty test file + + def test_with_mags_and_proteins_file_exists(self): + protein_file_path = self.test_dir_path / "test_file_proteins.fasta" + protein_file_path.touch() # Create an empty protein file + + mag_id, protein_path, gff_path = _get_file_paths( + file=self.file_path, + mags=self.test_dir_path, + proteins=self.test_dir_path, + loci=None, + ) + self.assertEqual(mag_id, "test_file") + self.assertEqual(protein_path, str(protein_file_path)) + self.assertIsNone(gff_path) + + def test_with_mags_and_proteins_file_missing(self): with self.assertRaisesRegex( ValueError, - "DNA-sequence and protein-sequence inputs together can only be given in " - "combination with GFF input.", + "Proteins file for ID 'test_file' is missing in proteins input.", ): - annotate_feature_data_amrfinderplus( - mags=dna_sequences, - proteins=proteins, - amrfinderplus_db=amrfinderplus_db, + _get_file_paths( + file=self.file_path, + mags=self.test_dir_path, + 
proteins=self.test_dir_path, + loci=None, ) - def _helper(self, dna_sequences, proteins, gff, organism): - amrfinderplus_db = MagicMock() - with patch( - "q2_amr.amrfinderplus.sequences.run_amrfinderplus_n", - side_effect=mock_run_amrfinderplus_n, + def test_with_proteins_only(self): + protein_file_path = self.test_dir_path / "test_file_proteins.fasta" + protein_file_path.touch() # Create an empty protein file + + mag_id, protein_path, gff_path = _get_file_paths( + file=protein_file_path, mags=None, proteins=self.test_dir_path, loci=None + ) + self.assertEqual(mag_id, "test_file") + self.assertEqual(protein_path, protein_file_path) + self.assertIsNone(gff_path) + + def test_with_loci_file_exists(self): + gff_file_path = self.test_dir_path / "test_file_loci.gff" + gff_file_path.touch() # Create an empty GFF file + + mag_id, protein_path, gff_path = _get_file_paths( + file=self.file_path, + mags=self.test_dir_path, + proteins=None, + loci=self.test_dir_path, + ) + self.assertEqual(mag_id, "test_file") + self.assertIsNone(protein_path) + self.assertEqual(gff_path, str(gff_file_path)) + + def test_with_loci_file_missing(self): + with self.assertRaisesRegex( + ValueError, "GFF file for ID 'test_file' is missing in loci input." 
): - result = annotate_feature_data_amrfinderplus( - mags=dna_sequences, - proteins=proteins, - loci=gff, - amrfinderplus_db=amrfinderplus_db, - organism=organism, + _get_file_paths( + file=self.file_path, + mags=self.test_dir_path, + proteins=None, + loci=self.test_dir_path, ) - self.assertTrue( - os.path.exists(os.path.join(str(result[0]), "amr_annotations.tsv")) - ) - if organism: - self.assertTrue( - os.path.exists( - os.path.join(str(result[1]), "amr_all_mutations.tsv") - ) - ) - self.assertTrue( - os.path.exists(os.path.join(str(result[2]), "amr_genes.fasta")) - ) - self.assertTrue( - os.path.exists(os.path.join(str(result[3]), "amr_proteins.fasta")) - ) + def test_with_mags_proteins_and_loci_all_files_exist(self): + protein_file_path = self.test_dir_path / "test_file_proteins.fasta" + gff_file_path = self.test_dir_path / "test_file_loci.gff" + protein_file_path.touch() # Create an empty protein file + gff_file_path.touch() # Create an empty GFF file + + mag_id, protein_path, gff_path = _get_file_paths( + file=self.file_path, + mags=self.test_dir_path, + proteins=self.test_dir_path, + loci=self.test_dir_path, + ) + self.assertEqual(mag_id, "test_file") + self.assertEqual(protein_path, str(protein_file_path)) + self.assertEqual(gff_path, str(gff_file_path)) + + +class TestAnnotateFeatureDataAMRFinderPlus(TestPluginBase): + package = "q2_amr.amrfinderplus.tests" + + @patch("q2_amr.amrfinderplus.feature_data._validate_inputs") + @patch( + "q2_amr.amrfinderplus.feature_data._get_file_paths", + return_value=("mag_id", "protein_path", "gff_path"), + ) + @patch("q2_amr.amrfinderplus.feature_data.run_amrfinderplus_n") + @patch("q2_amr.amrfinderplus.feature_data._move_or_create_files") + def test_annotate_feature_data_amrfinderplus_mags( + self, mock_validate, mock_paths, mock_run, mock_move + ): + mags = MAGSequencesDirFmt() + with open(os.path.join(str(mags), "mag.fasta"), "w"): + pass + annotate_feature_data_amrfinderplus(amrfinderplus_db=MagicMock(), mags=mags) + 
+    # Stacked @patch decorators are applied bottom-up: the decorator closest to
+    # the function supplies the first mock argument. The parameter names below
+    # therefore run in the reverse of the decorator order.
+    @patch("q2_amr.amrfinderplus.feature_data._validate_inputs")
+    @patch(
+        "q2_amr.amrfinderplus.feature_data._get_file_paths",
+        return_value=("mag_id", "protein_path", "gff_path"),
+    )
+    @patch("q2_amr.amrfinderplus.feature_data.run_amrfinderplus_n")
+    @patch("q2_amr.amrfinderplus.feature_data._move_or_create_files")
+    def test_annotate_feature_data_amrfinderplus_proteins(
+        self, mock_move, mock_run, mock_paths, mock_validate
+    ):
+        # Proteins-only input: the action should iterate over the files in the
+        # proteins directory, which here holds exactly one (empty) FASTA file.
+        proteins = ProteinsDirectoryFormat()
+        with open(os.path.join(str(proteins), "proteins.fasta"), "w"):
+            pass
+
+        annotate_feature_data_amrfinderplus(
+            amrfinderplus_db=MagicMock(), proteins=proteins
+        )
+
+        # Inputs are validated once, before any file is processed.
+        mock_validate.assert_called_once()
+
+        # One input file means one path lookup and one move/placeholder step.
+        mock_paths.assert_called_once()
+        mock_move.assert_called_once()
+
+        # The mag_id returned by the patched _get_file_paths is passed on as the
+        # second positional argument of _move_or_create_files.
+        self.assertEqual(mock_move.call_args[0][1], "mag_id")
+
+        # NOTE(review): the run_amrfinderplus_n call site is not visible in this
+        # diff; add an assertion on mock_run once its call signature is
+        # confirmed.