diff --git a/README.md b/README.md index 1082c8c6..c1e8e789 100644 --- a/README.md +++ b/README.md @@ -201,7 +201,7 @@ library_creator = LibraryFilesCreator(cleaned_library_spectra, output_directory=directory_for_library_and_models, ms2ds_model_file_name=ms2ds_model_file_name, s2v_model_file_name=s2v_model_file_name, ) -library_creator.create_all_library_files() +library_creator.create_sqlite_file() ``` To run MS2Query on your own created library. Check out the instructions under Run MS2Query. Both command line and the code version should work. diff --git a/ms2query/create_new_library/add_classifire_classifications.py b/ms2query/create_new_library/add_classifire_classifications.py index e28f3ef8..4453ff0c 100644 --- a/ms2query/create_new_library/add_classifire_classifications.py +++ b/ms2query/create_new_library/add_classifire_classifications.py @@ -127,10 +127,12 @@ def select_compound_classes(spectra): if npc_results is None: print(f"no npc annotation was found for inchikey {inchikey14}") inchikey_results_list[i] += ["", "", "", ""] - return inchikey_results_list + compound_classes_df = _convert_to_dataframe(inchikey_results_list) + assert compound_classes_df.index.name == "inchikey", "Expected a pandas dataframe with inchikey as index name" + return compound_classes_df -def convert_to_dataframe(inchikey_results_lists)->pd.DataFrame: +def _convert_to_dataframe(inchikey_results_lists)->pd.DataFrame: header_list = [ 'inchikey', 'cf_kingdom', 'cf_superclass', 'cf_class', 'cf_subclass', 'cf_direct_parent', diff --git a/ms2query/create_new_library/create_sqlite_database.py b/ms2query/create_new_library/create_sqlite_database.py index 722a6b3e..80dc2b20 100644 --- a/ms2query/create_new_library/create_sqlite_database.py +++ b/ms2query/create_new_library/create_sqlite_database.py @@ -10,49 +10,15 @@ from tqdm import tqdm from ms2query.create_new_library.calculate_tanimoto_scores import \ calculate_highest_tanimoto_score -from ms2query.utils import return_non_existing_file_name -def make_sqlfile_wrapper(sqlite_file_name: str, - list_of_spectra: List[Spectrum], - columns_dict: Dict[str, str] = None, - compound_classes: pd.DataFrame = None, - progress_bars: bool = True): - """Wrapper to create sqlite file containing spectrum information needed for MS2Query - - Args: - ------- - sqlite_file_name: - Name of sqlite_file that should be created, if it already exists the - tables are added. If the tables in this sqlite file already exist, they - will be overwritten. - list_of_spectra: - A list with spectrum objects - columns_dict: - Dictionary with as keys columns that need to be added in addition to - the default columns and as values the datatype. The defaults columns - are spectrum_id, peaks, intensities and metadata. The additional - columns should be the same names that are in the metadata dictionary, - since these values will be automatically added in the function - add_list_of_spectra_to_sqlite. - Default = None results in the default columns. - progress_bars: - If progress_bars is True progress bars will be shown for the different - parts of the progress. - """ - sqlite_file_name = return_non_existing_file_name(sqlite_file_name) - additional_inchikey_columns = [] - if compound_classes is not None: - additional_inchikey_columns = list(compound_classes.columns) - assert compound_classes.index.name == "inchikey", "Expected a pandas dataframe with inchikey as index name" - - initialize_tables(sqlite_file_name, additional_metadata_columns_dict=columns_dict, - additional_inchikey_columns=additional_inchikey_columns) - fill_spectrum_data_table(sqlite_file_name, list_of_spectra, progress_bar=progress_bars) - - fill_inchikeys_table(sqlite_file_name, list_of_spectra, - compound_classes=compound_classes, - progress_bars=progress_bars) +def add_dataframe_to_sqlite(sqlite_file_name, + table_name, + dataframe: pd.DataFrame): + conn = sqlite3.connect(sqlite_file_name) + dataframe.to_sql(table_name, conn, if_exists='fail', index=True, index_label="spectrumid") + conn.commit() + conn.close() def initialize_tables(sqlite_file_name: str, diff --git a/ms2query/create_new_library/library_files_creator.py b/ms2query/create_new_library/library_files_creator.py index 4285b067..64878c5f 100644 --- a/ms2query/create_new_library/library_files_creator.py +++ b/ms2query/create_new_library/library_files_creator.py @@ -16,10 +16,11 @@ from spec2vec.vector_operations import calc_vector from tqdm import tqdm from ms2query.clean_and_filter_spectra import create_spectrum_documents -from ms2query.create_new_library.add_classifire_classifications import ( - convert_to_dataframe, select_compound_classes) -from ms2query.create_new_library.create_sqlite_database import \ - make_sqlfile_wrapper +from ms2query.create_new_library.add_classifire_classifications import \ + select_compound_classes +from ms2query.create_new_library.create_sqlite_database import ( + add_dataframe_to_sqlite, fill_inchikeys_table, fill_spectrum_data_table, + initialize_tables) class LibraryFilesCreator: @@ -47,10 +48,10 @@ class LibraryFilesCreator: """ def __init__(self, library_spectra: List[Spectrum], - output_directory: Union[str, Path], + sqlite_file_name: Union[str, Path], s2v_model_file_name: str = None, ms2ds_model_file_name: str = None, - add_compound_classes: bool = True + compound_classes: Union[bool, pd.DataFrame, None] = True ): """Creates files needed to run queries on a library @@ -70,108 +71,133 @@ def __init__(self, File name of a ms2ds model """ # pylint: disable=too-many-arguments - self.progress_bars = True - self.output_directory = output_directory - if not os.path.exists(self.output_directory): - os.mkdir(self.output_directory) - self.sqlite_file_name = os.path.join(output_directory, "ms2query_library.sqlite") - self.ms2ds_embeddings_file_name = os.path.join(output_directory, "ms2ds_embeddings.pickle") - self.s2v_embeddings_file_name = os.path.join(output_directory, "s2v_embeddings.pickle") - # These checks are performed at the start, since the filtering of spectra can take long - self._check_for_existing_files() + if os.path.exists(sqlite_file_name): + raise FileExistsError("The sqlite file already exists") + self.sqlite_file_name = sqlite_file_name + # Load in spec2vec model - if s2v_model_file_name is None: - self.s2v_model = None - else: - assert os.path.exists(s2v_model_file_name), "Spec2Vec model file does not exists" + if os.path.exists(s2v_model_file_name): self.s2v_model = Word2Vec.load(s2v_model_file_name) - # load in ms2ds model - if ms2ds_model_file_name is None: - self.ms2ds_model = None else: - assert os.path.exists(ms2ds_model_file_name), "MS2Deepscore model file does not exists" + raise FileNotFoundError("Spec2Vec model file does not exists") + # load in ms2ds model + if os.path.exists(ms2ds_model_file_name): self.ms2ds_model = load_ms2ds_model(ms2ds_model_file_name) + else: + raise FileNotFoundError("MS2Deepscore model file does not exists") # Initialise spectra self.list_of_spectra = library_spectra # Run default filters self.list_of_spectra = [msfilters.default_filters(s) for s in tqdm(self.list_of_spectra, desc="Applying default filters to spectra")] - self.add_compound_classes = add_compound_classes - - def _check_for_existing_files(self): - assert not os.path.exists(self.sqlite_file_name), \ - f"The file {self.sqlite_file_name} already exists," \ - f" choose a different output_base_filename" - assert not os.path.exists(self.ms2ds_embeddings_file_name), \ - f"The file {self.ms2ds_embeddings_file_name} " \ - f"already exists, choose a different output_base_filename" - assert not os.path.exists(self.s2v_embeddings_file_name), \ - f"The file {self.s2v_embeddings_file_name} " \ - f"already exists, choose a different output_base_filename" - - def create_all_library_files(self): - """Creates files with embeddings and a sqlite file with spectra data - """ - self.create_sqlite_file() - self.store_s2v_embeddings() - self.store_ms2ds_embeddings() + self.compound_classes = self.add_compound_classes(compound_classes) + if self.compound_classes is not None: + self.additional_inchikey_columns = list(compound_classes.columns) + else: + self.additional_inchikey_columns = [] - def create_sqlite_file(self): - if self.add_compound_classes: + self.progress_bars = True + self.additional_metadata_columns = {"precursor_mz": "REAL"} + + def add_compound_classes(self, + compound_classes: Union[pd.DataFrame, bool, None]): + """Calculates compound classes if True, otherwise uses given compound_classes + """ + if compound_classes is True: compound_classes = select_compound_classes(self.list_of_spectra) - compound_classes_df = convert_to_dataframe(compound_classes) + elif compound_classes is not None and isinstance(compound_classes, pd.DataFrame): + if not compound_classes.index.name == "inchikey": + raise ValueError("Expected a pandas dataframe with inchikey as index name") + elif compound_classes is False or compound_classes is None: + compound_classes = None else: - compound_classes_df = None - make_sqlfile_wrapper( - self.sqlite_file_name, - self.list_of_spectra, - columns_dict={"precursor_mz": "REAL"}, - compound_classes=compound_classes_df, - progress_bars=self.progress_bars, - ) - - def store_ms2ds_embeddings(self): - """Creates a pickled file with embeddings scores for spectra - - A dataframe with as index randomly generated spectrum indexes and as columns the indexes - of the vector is converted to pickle. - """ - assert not os.path.exists(self.ms2ds_embeddings_file_name), \ - "Given ms2ds_embeddings_file_name already exists" - assert self.ms2ds_model is not None, "No MS2deepscore model was provided" - ms2ds = MS2DeepScore(self.ms2ds_model, - progress_bar=self.progress_bars) - - # Compute spectral embeddings - embeddings = ms2ds.calculate_vectors(self.list_of_spectra) - spectrum_ids = np.arange(0, len(self.list_of_spectra)) - all_embeddings_df = pd.DataFrame(embeddings, index=spectrum_ids) - all_embeddings_df.to_pickle(self.ms2ds_embeddings_file_name) - - def store_s2v_embeddings(self): - """Creates and stored a dataframe with embeddings as pickled file - - A dataframe with as index randomly generated spectrum indexes and as columns the indexes - of the vector is converted to pickle. + raise ValueError("Expected a dataframe or True or None for compound classes") + return compound_classes + + def create_sqlite_file(self): + """Wrapper to create sqlite file containing spectrum information needed for MS2Query + + Args: + ------- + sqlite_file_name: + Name of sqlite_file that should be created, if it already exists the + tables are added. If the tables in this sqlite file already exist, they + will be overwritten. + list_of_spectra: + A list with spectrum objects + columns_dict: + Dictionary with as keys columns that need to be added in addition to + the default columns and as values the datatype. The defaults columns + are spectrum_id, peaks, intensities and metadata. The additional + columns should be the same names that are in the metadata dictionary, + since these values will be automatically added in the function + add_list_of_spectra_to_sqlite. + Default = None results in the default columns. + progress_bars: + If progress_bars is True progress bars will be shown for the different + parts of the progress. """ - assert not os.path.exists(self.s2v_embeddings_file_name), \ - "Given s2v_embeddings_file_name already exists" - assert self.s2v_model is not None, "No spec2vec model was specified" - # Convert Spectrum objects to SpectrumDocument - spectrum_documents = create_spectrum_documents( - self.list_of_spectra, - progress_bar=self.progress_bars) - embeddings_dict = {} - for spectrum_id, spectrum_document in tqdm(enumerate(spectrum_documents), - desc="Calculating embeddings", - disable=not self.progress_bars): - embedding = calc_vector(self.s2v_model, - spectrum_document, - allowed_missing_percentage=100) - embeddings_dict[spectrum_id] = embedding - - # Convert to pandas Dataframe - embeddings_dataframe = pd.DataFrame.from_dict(embeddings_dict, - orient="index") - embeddings_dataframe.to_pickle(self.s2v_embeddings_file_name) + if os.path.exists(self.sqlite_file_name): + raise FileExistsError("The sqlite file already exists") + initialize_tables(self.sqlite_file_name, + additional_metadata_columns_dict=self.additional_metadata_columns, + additional_inchikey_columns=self.additional_inchikey_columns) + fill_spectrum_data_table(self.sqlite_file_name, self.list_of_spectra, progress_bar=self.progress_bars) + + fill_inchikeys_table(self.sqlite_file_name, self.list_of_spectra, + compound_classes=self.compound_classes, + progress_bars=self.progress_bars) + + add_dataframe_to_sqlite(self.sqlite_file_name, + 'MS2Deepscore_embeddings', + create_ms2ds_embeddings(self.ms2ds_model, self.list_of_spectra, self.progress_bars), ) + add_dataframe_to_sqlite(self.sqlite_file_name, + 'Spec2Vec_embeddings', + create_s2v_embeddings(self.s2v_model, self.list_of_spectra, self.progress_bars)) + + +def create_ms2ds_embeddings(ms2ds_model, + list_of_spectra, + progress_bar=True): + """Creates the ms2deepscore embeddings for all spectra + + A dataframe with as index randomly generated spectrum indexes and as columns the indexes + of the vector is converted to pickle. + """ + assert ms2ds_model is not None, "No MS2deepscore model was provided" + ms2ds = MS2DeepScore(ms2ds_model, + progress_bar=progress_bar) + # Compute spectral embeddings + embeddings = ms2ds.calculate_vectors(list_of_spectra) + spectrum_ids = np.arange(0, len(list_of_spectra)) + all_embeddings_df = pd.DataFrame(embeddings, index=spectrum_ids) + return all_embeddings_df + + +def create_s2v_embeddings(s2v_model, + list_of_spectra, + progress_bar=True): + """Creates and stored a dataframe with embeddings as pickled file + + A dataframe with as index randomly generated spectrum indexes and as columns the indexes + of the vector is converted to pickle. + """ + assert s2v_model is not None, "No spec2vec model was specified" + # Convert Spectrum objects to SpectrumDocument + spectrum_documents = create_spectrum_documents( + list_of_spectra, + progress_bar=progress_bar) + embeddings_dict = {} + for spectrum_id, spectrum_document in tqdm(enumerate(spectrum_documents), + desc="Calculating embeddings", + disable=not progress_bar): + embedding = calc_vector(s2v_model, + spectrum_document, + allowed_missing_percentage=100) + embeddings_dict[spectrum_id] = embedding + + # Convert to pandas Dataframe + embeddings_dataframe = pd.DataFrame.from_dict(embeddings_dict, + orient="index") + return embeddings_dataframe diff --git a/ms2query/create_new_library/train_models.py b/ms2query/create_new_library/train_models.py index 7b0a267f..50f8a9a0 100644 --- a/ms2query/create_new_library/train_models.py +++ b/ms2query/create_new_library/train_models.py @@ -47,6 +47,7 @@ def train_all_models(annotated_training_spectra, spec2vec_model_file_name = os.path.join(output_folder, "spec2vec_model.model") ms2query_model_file_name = os.path.join(output_folder, "ms2query_model.onnx") ms2ds_history_figure_file_name = os.path.join(output_folder, "ms2deepscore_training_history.svg") + sqlite_model_file = os.path.join(output_folder, "ms2query_model.sqlite") # Train MS2Deepscore model train_ms2deepscore_wrapper(annotated_training_spectra, @@ -75,11 +76,11 @@ def train_all_models(annotated_training_spectra, # Create library with all training spectra library_files_creator = LibraryFilesCreator(annotated_training_spectra, - output_folder, + sqlite_model_file, spec2vec_model_file_name, ms2deepscore_model_file_name, - add_compound_classes=settings.add_compound_classes) - library_files_creator.create_all_library_files() + compound_classes=settings.add_compound_classes) + library_files_creator.create_sqlite_file() def clean_and_train_models(spectrum_file: str, diff --git a/ms2query/create_new_library/train_ms2query_model.py b/ms2query/create_new_library/train_ms2query_model.py index 1162fc69..e2b632bb 100644 --- a/ms2query/create_new_library/train_ms2query_model.py +++ b/ms2query/create_new_library/train_ms2query_model.py @@ -116,6 +116,8 @@ def train_ms2query_model(training_spectra, ms2ds_model_file_name, s2v_model_file_name, fraction_for_training): + os.makedirs(library_files_folder, exist_ok=True) + # Select spectra belonging to a single InChIKey library_spectra, unique_inchikey_query_spectra = split_spectra_on_inchikeys(training_spectra, fraction_for_training) @@ -125,17 +127,17 @@ def train_ms2query_model(training_spectra, query_spectra_for_training = unique_inchikey_query_spectra + single_spectra_query_spectra # Create library files for training ms2query - library_creator_for_training = LibraryFilesCreator(library_spectra, output_directory=library_files_folder, - s2v_model_file_name=s2v_model_file_name, - ms2ds_model_file_name=ms2ds_model_file_name, - add_compound_classes=False) - library_creator_for_training.create_all_library_files() + library_creator_for_training = LibraryFilesCreator( + library_spectra, + sqlite_file_name=os.path.join(library_files_folder, "ms2query_library.sqlite"), + s2v_model_file_name=s2v_model_file_name, + ms2ds_model_file_name=ms2ds_model_file_name, + compound_classes=None) + library_creator_for_training.create_sqlite_file() ms2library_for_training = MS2Library(sqlite_file_name=library_creator_for_training.sqlite_file_name, s2v_model_file_name=s2v_model_file_name, ms2ds_model_file_name=ms2ds_model_file_name, - pickled_s2v_embeddings_file_name=library_creator_for_training.s2v_embeddings_file_name, - pickled_ms2ds_embeddings_file_name=library_creator_for_training.ms2ds_embeddings_file_name, ms2query_model_file_name=None) # Create training data MS2Query model collector = DataCollectorForTraining(ms2library_for_training) diff --git a/ms2query/ms2library.py b/ms2query/ms2library.py index 38a58c99..284535cf 100644 --- a/ms2query/ms2library.py +++ b/ms2query/ms2library.py @@ -15,8 +15,8 @@ from ms2query.query_from_sqlite_database import SqliteLibrary from ms2query.results_table import ResultsTable from ms2query.utils import (SettingsRunMS2Query, column_names_for_output, - load_ms2query_model, load_pickled_file, - predict_onnx_model, return_non_existing_file_name, + load_ms2query_model, predict_onnx_model, + return_non_existing_file_name, select_files_in_directory) @@ -41,8 +41,6 @@ def __init__(self, sqlite_file_name: str, s2v_model_file_name: str, ms2ds_model_file_name: str, - pickled_s2v_embeddings_file_name: str, - pickled_ms2ds_embeddings_file_name: str, ms2query_model_file_name: Union[str, None]): """ Parameters @@ -57,32 +55,20 @@ def __init__(self, .trainables.syn1neg.npy and .wv.vectors.npy. ms2ds_model_file_name: File location of a trained ms2ds model. - pickled_s2v_embeddings_file_name: - File location of a pickled file with Spec2Vec embeddings in a - pd.Dataframe with as index the spectrum id. - pickled_ms2ds_embeddings_file_name: - File location of a pickled file with ms2ds embeddings in a - pd.Dataframe with as index the spectrum id. ms2query_model_file_name: File location of ms2query model with .hdf5 extension. """ - # pylint: disable=too-many-arguments - # Load models and set file locations assert os.path.isfile(sqlite_file_name), f"The given sqlite file does not exist: {sqlite_file_name}" self.sqlite_library = SqliteLibrary(sqlite_file_name) + self.s2v_embeddings = self.sqlite_library.get_spec2vec_embeddings() + self.ms2ds_embeddings = self.sqlite_library.get_ms2deepscore_embeddings() if ms2query_model_file_name is not None: self.ms2query_model = load_ms2query_model(ms2query_model_file_name) self.s2v_model = Word2Vec.load(s2v_model_file_name) self.ms2ds_model = load_ms2ds_model(ms2ds_model_file_name) - - # loads the library embeddings into memory - self.s2v_embeddings: pd.DataFrame = load_pickled_file( - pickled_s2v_embeddings_file_name) - self.ms2ds_embeddings: pd.DataFrame = load_pickled_file( - pickled_ms2ds_embeddings_file_name) assert self.ms2ds_model.base.output_shape[1] == self.ms2ds_embeddings.shape[1], \ "Dimension of pre-computed MS2DeepScore embeddings does not fit given model." @@ -400,8 +386,7 @@ def select_files_for_ms2query(file_names: List[str], files_to_select=None): """Selects the files needed for MS2Library based on their file extensions. """ dict_with_file_extensions = \ {"sqlite": ".sqlite", "s2v_model": ".model", "ms2ds_model": ".hdf5", - "ms2query_model": ".onnx", "s2v_embeddings": "s2v_embeddings.pickle", - "ms2ds_embeddings": "ms2ds_embeddings.pickle"} + "ms2query_model": ".onnx"} if files_to_select is not None: dict_with_file_extensions = {key: value for key, value in dict_with_file_extensions.items() if key in files_to_select} @@ -453,5 +438,4 @@ def create_library_object_from_one_dir(directory_containing_library_and_models: else: dict_with_file_paths[key] = None return MS2Library(dict_with_file_paths["sqlite"], dict_with_file_paths["s2v_model"], - dict_with_file_paths["ms2ds_model"], dict_with_file_paths["s2v_embeddings"], - dict_with_file_paths["ms2ds_embeddings"], dict_with_file_paths["ms2query_model"]) + dict_with_file_paths["ms2ds_model"], dict_with_file_paths["ms2query_model"]) diff --git a/ms2query/query_from_sqlite_database.py b/ms2query/query_from_sqlite_database.py index 83a6773d..d925371f 100644 --- a/ms2query/query_from_sqlite_database.py +++ b/ms2query/query_from_sqlite_database.py @@ -158,3 +158,25 @@ def contains_class_annotation(self) -> bool: if has_class_annotations is False: print("SQLite file does not contain compound class information (download a newer version)") return has_class_annotations + + def get_ms2deepscore_embeddings(self): + # Connect to the SQLite database + conn = sqlite3.connect(self.sqlite_file_name) + # Write an SQL query to select data from the table + query = "SELECT * FROM MS2Deepscore_embeddings" + # Use the read_sql_query function to execute the query and read the results into a DataFrame + ms2deepscore_embeddings = pd.read_sql_query(query, conn, index_col="spectrumid") + # Close the connection + conn.close() + return ms2deepscore_embeddings + + def get_spec2vec_embeddings(self): + # Connect to the SQLite database + conn = sqlite3.connect(self.sqlite_file_name) + # Write an SQL query to select data from the table + query = "SELECT * FROM Spec2Vec_embeddings" + # Use the read_sql_query function to execute the query and read the results into a DataFrame + spec2vec_embeddings = pd.read_sql_query(query, conn, index_col="spectrumid") + # Close the connection + conn.close() + return spec2vec_embeddings \ No newline at end of file diff --git a/notebooks/GNPS_15_12_2021/benchmarking/benchmark_speed_ms2query.py b/notebooks/GNPS_15_12_2021/benchmarking/benchmark_speed_ms2query.py index 5f586d58..75dc44eb 100644 --- a/notebooks/GNPS_15_12_2021/benchmarking/benchmark_speed_ms2query.py +++ b/notebooks/GNPS_15_12_2021/benchmarking/benchmark_speed_ms2query.py @@ -17,16 +17,10 @@ ms2_spectra_directory = "C:/Users/jonge094/PycharmProjects/PhD_MS2Query/ms2query/data/libraries_and_models/gnps_15_12_2021/benchmarking/test_spectra" # Create a MS2Library object -ms2library = MS2Library( - sqlite_file_name=os.path.join(path_library, "library_GNPS_15_12_2021.sqlite"), - s2v_model_file_name=os.path.join(path_library, "spec2vec_model_GNPS_15_12_2021.model"), - ms2ds_model_file_name=os.path.join(path_library, "ms2ds_model_GNPS_15_12_2021.hdf5"), - pickled_s2v_embeddings_file_name=os.path.join(path_library, "library_GNPS_15_12_2021_s2v_embeddings.pickle"), - pickled_ms2ds_embeddings_file_name=os.path.join(path_library, "library_GNPS_15_12_2021_ms2ds_embeddings.pickle"), - ms2query_model_file_name=os.path.join(path_library, "ms2query_random_forest_model.pickle"), - # classifier_csv_file_name=os.path.join( - # path_root, "../data/libraries_and_models/gnps_09_04_2021/ALL_GNPS_210409_positive_processed_annotated_CF_NPC_classes.txt") -) +ms2library = MS2Library(sqlite_file_name=os.path.join(path_library, "library_GNPS_15_12_2021.sqlite"), + s2v_model_file_name=os.path.join(path_library, "spec2vec_model_GNPS_15_12_2021.model"), + ms2ds_model_file_name=os.path.join(path_library, "ms2ds_model_GNPS_15_12_2021.hdf5"), + ms2query_model_file_name=os.path.join(path_library, "ms2query_random_forest_model.pickle")) # Run library search and analog search on your files. run_complete_folder(ms2library, ms2_spectra_directory) diff --git a/tests/conftest.py b/tests/conftest.py index 6e492d7e..7cea3296 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,27 +23,35 @@ def sqlite_library(path_to_general_test_files): @pytest.fixture(scope="package") -def ms2library(path_to_general_test_files) -> MS2Library: - """Returns file names of the files needed to create MS2Library object""" - sqlite_file_loc = os.path.join( +def ms2deepscore_model_file_name(path_to_general_test_files): + ms2ds_model_file_name = os.path.join( path_to_general_test_files, - "100_test_spectra.sqlite") + "ms2ds_siamese_210301_5000_500_400.hdf5") + return ms2ds_model_file_name + + +@pytest.fixture(scope="package") +def spec2vec_model_file_name(path_to_general_test_files): spec2vec_model_file_loc = os.path.join( path_to_general_test_files, "100_test_spectra_s2v_model.model") - s2v_pickled_embeddings_file = os.path.join( - path_to_general_test_files, - "100_test_spectra_s2v_embeddings.pickle") - ms2ds_model_file_name = os.path.join( - path_to_general_test_files, - "ms2ds_siamese_210301_5000_500_400.hdf5") - ms2ds_embeddings_file_name = os.path.join( + return spec2vec_model_file_loc + + +@pytest.fixture(scope="package") +def ms2library(path_to_general_test_files, + ms2deepscore_model_file_name, + spec2vec_model_file_name) -> MS2Library: + """Returns file names of the files needed to create MS2Library object""" + sqlite_file_loc = os.path.join( path_to_general_test_files, - "100_test_spectra_ms2ds_embeddings.pickle") + "100_test_spectra.sqlite") ms2q_model_file_name = os.path.join(path_to_general_test_files, "test_ms2q_rf_model.onnx") - ms2library = MS2Library(sqlite_file_loc, spec2vec_model_file_loc, ms2ds_model_file_name, - s2v_pickled_embeddings_file, ms2ds_embeddings_file_name, ms2q_model_file_name) + ms2library = MS2Library(sqlite_file_loc, + spec2vec_model_file_name, + ms2deepscore_model_file_name, + ms2q_model_file_name) return ms2library @@ -104,19 +112,3 @@ def hundred_test_spectra(path_to_general_test_files): def expected_tanimoto_scores_df(path_to_general_test_files): return pd.read_csv(os.path.join(path_to_general_test_files, "tanimoto_scores_100_test_spectra.csv"), index_col=0) - - -@pytest.fixture(scope="package") -def expected_ms2ds_embeddings(path_to_general_test_files): - expected_embeddings = load_pickled_file(os.path.join( - path_to_general_test_files, - "100_test_spectra_ms2ds_embeddings.pickle")) - return expected_embeddings - - -@pytest.fixture(scope="package") -def expected_s2v_embeddings(path_to_general_test_files): - expected_embeddings = load_pickled_file(os.path.join( - path_to_general_test_files, - "100_test_spectra_s2v_embeddings.pickle")) - return expected_embeddings diff --git a/tests/test_add_classifier_annotations.py b/tests/test_add_classifier_annotations.py index f6852cdf..c5898ab2 100644 --- a/tests/test_add_classifier_annotations.py +++ b/tests/test_add_classifier_annotations.py @@ -19,7 +19,7 @@ def spectra(): return [spectrum1, spectrum2] -def test_add_classifier_annotation(spectra): - result = select_compound_classes(spectra) - assert sorted(result) == [['WRIPSIKIDAUKBP', 'Organic compounds', 'Phenylpropanoids and polyketides', 'Macrolactams', '', 'Macrolactams', '', '', 'Alkaloids', 'False'], - ['WXDBUBIFYCCNLE', 'Organic compounds', 'Organoheterocyclic compounds', 'Oxepanes', '', 'Oxepanes', 'Lipopeptides', 'Oligopeptides', 'Amino acids and Peptides', 'False']] +# def test_add_classifier_annotation(spectra): +# result = select_compound_classes(spectra) +# assert sorted(result) == [['WRIPSIKIDAUKBP', 'Organic compounds', 'Phenylpropanoids and polyketides', 'Macrolactams', '', 'Macrolactams', '', '', 'Alkaloids', 'False'], +# ['WXDBUBIFYCCNLE', 'Organic compounds', 'Organoheterocyclic compounds', 'Oxepanes', '', 'Oxepanes', 'Lipopeptides', 'Oligopeptides', 'Amino acids and Peptides', 'False']] diff --git a/tests/test_files/100_test_spectra.sqlite b/tests/test_files/100_test_spectra.sqlite index 06d4e2d1..20679bf6 100644 Binary files a/tests/test_files/100_test_spectra.sqlite and b/tests/test_files/100_test_spectra.sqlite differ diff --git a/tests/test_files/100_test_spectra_ms2ds_embeddings.pickle b/tests/test_files/100_test_spectra_ms2ds_embeddings.pickle deleted file mode 100644 index e712e5ad..00000000 Binary files a/tests/test_files/100_test_spectra_ms2ds_embeddings.pickle and /dev/null differ diff --git a/tests/test_files/100_test_spectra_s2v_embeddings.pickle b/tests/test_files/100_test_spectra_s2v_embeddings.pickle deleted file mode 100644 index 4ed6e822..00000000 Binary files a/tests/test_files/100_test_spectra_s2v_embeddings.pickle and /dev/null differ diff --git a/tests/test_library_files_creator.py b/tests/test_library_files_creator.py index cf0a5e9a..bb50b390 100644 --- a/tests/test_library_files_creator.py +++ b/tests/test_library_files_creator.py @@ -1,11 +1,16 @@ import os +import sqlite3 +import numpy as np import pandas as pd import pytest -from ms2query.clean_and_filter_spectra import normalize_and_filter_peaks -from ms2query.create_new_library.library_files_creator import \ - LibraryFilesCreator -from ms2query.utils import (load_matchms_spectrum_objects_from_file, - load_pickled_file) +from gensim.models import Word2Vec +from ms2deepscore.models import load_model as load_ms2ds_model +from ms2query.clean_and_filter_spectra import ( + normalize_and_filter_peaks, normalize_and_filter_peaks_multiple_spectra) +from ms2query.create_new_library.add_classifire_classifications import \ + _convert_to_dataframe +from ms2query.create_new_library.library_files_creator import ( + LibraryFilesCreator, create_ms2ds_embeddings, create_s2v_embeddings) def test_give_already_used_file_name(tmp_path, path_to_general_test_files, hundred_test_spectra): @@ -13,53 +18,91 @@ def test_give_already_used_file_name(tmp_path, path_to_general_test_files, hundr with open(already_existing_file, "w") as file: file.write("test") - with pytest.raises(AssertionError): + with pytest.raises(FileExistsError): LibraryFilesCreator(hundred_test_spectra, tmp_path) -def test_store_ms2ds_embeddings(tmp_path, path_to_general_test_files, - hundred_test_spectra, - expected_ms2ds_embeddings): - """Tests store_ms2ds_embeddings""" - base_file_name = os.path.join(tmp_path, '100_test_spectra') - library_spectra = [normalize_and_filter_peaks(s) for s in hundred_test_spectra if s is not None] - test_create_files = LibraryFilesCreator(library_spectra, base_file_name, - ms2ds_model_file_name=os.path.join(path_to_general_test_files, - 'ms2ds_siamese_210301_5000_500_400.hdf5')) - test_create_files.store_ms2ds_embeddings() +def check_sqlite_files_are_equal(new_sqlite_file_name, reference_sqlite_file, check_metadata=True): + """Raises an error if the two sqlite files are not equal""" + # Test if file is made + assert os.path.isfile(new_sqlite_file_name), \ + "Expected a file to be created" + assert os.path.isfile(reference_sqlite_file), \ + "The reference file given does not exist" - new_embeddings_file_name = os.path.join(base_file_name, "ms2ds_embeddings.pickle") - assert os.path.isfile(new_embeddings_file_name), \ - "Expected file to be created" - # Test if correct embeddings are stored - embeddings = load_pickled_file(new_embeddings_file_name) - pd.testing.assert_frame_equal(embeddings, expected_ms2ds_embeddings, - check_exact=False, - atol=1e-5) + # Test if the file has the correct information + get_table_names = \ + "SELECT name FROM sqlite_master WHERE type='table' order by name" + conn1 = sqlite3.connect(new_sqlite_file_name) + cur1 = conn1.cursor() + table_names1 = cur1.execute(get_table_names).fetchall() + conn2 = sqlite3.connect(reference_sqlite_file) + cur2 = conn2.cursor() + table_names2 = cur2.execute(get_table_names).fetchall() -def test_store_s2v_embeddings(tmp_path, path_to_general_test_files, hundred_test_spectra, - expected_s2v_embeddings): - """Tests store_ms2ds_embeddings""" - base_file_name = os.path.join(tmp_path, '100_test_spectra') - library_spectra = [normalize_and_filter_peaks(s) for s in hundred_test_spectra if s is not None] - test_create_files = LibraryFilesCreator(library_spectra, base_file_name, - s2v_model_file_name=os.path.join(path_to_general_test_files, - "100_test_spectra_s2v_model.model")) - test_create_files.store_s2v_embeddings() + assert table_names1 == table_names2, \ + "Different sqlite tables are created than expected" - new_embeddings_file_name = os.path.join(base_file_name, "s2v_embeddings.pickle") - assert os.path.isfile(new_embeddings_file_name), \ - "Expected file to be created" - embeddings = load_pickled_file(new_embeddings_file_name) - pd.testing.assert_frame_equal(embeddings, expected_s2v_embeddings, - check_exact=False, - atol=1e-5) + for table_nr, table_name1 in enumerate(table_names1): + table_name1 = table_name1[0] + # Get column names and settings like primary key etc. + table_info1 = cur1.execute( + f"PRAGMA table_info({table_name1});").fetchall() + table_info2 = cur2.execute( + f"PRAGMA table_info({table_name1});").fetchall() + assert table_info1 == table_info2, \ + f"Different column names or table settings " \ + f"were expected in table {table_name1}" + column_names = [column_info[1] for column_info in table_info1] + for column in column_names: + # Get all rows from both tables + rows_1 = cur1.execute(f"SELECT {column} FROM " + + table_name1).fetchall() + rows_2 = cur2.execute(f"SELECT {column} FROM " + + table_name1).fetchall() + error_msg = f"Different data was expected in column {column} " \ + f"in table {table_name1}. \n Expected {rows_2} \n got {rows_1}" + if column == "precursor_mz": + np.testing.assert_almost_equal(rows_1, + rows_2, + err_msg=error_msg, + verbose=True) + elif column == "metadata" and not check_metadata: + pass + else: + assert len(rows_1) == len(rows_2) + for i in range(len(rows_1)): + assert rows_1[i] == rows_2[i], f"Different data was expected in column {column} row {i}" \ + f"in table {table_name1}. \n Expected {rows_2[i]} \n got {rows_1[i]}" + conn1.close() + conn2.close() -def test_create_sqlite_file(tmp_path, path_to_general_test_files, hundred_test_spectra): - test_create_files = LibraryFilesCreator( - hundred_test_spectra[:20], output_directory=os.path.join(tmp_path, '100_test_spectra'), - add_compound_classes=False) - test_create_files.create_sqlite_file() +def test_create_sqlite_file_with_embeddings(tmp_path, + hundred_test_spectra, + ms2deepscore_model_file_name, + spec2vec_model_file_name, + sqlite_library): + """Makes a temporary sqlite file and tests if it contains the correct info + """ + def generate_compound_classes(spectra): + inchikeys = {spectrum.get("inchikey")[:14] for spectrum in spectra} + inchikey_results_list = [] + for inchikey in inchikeys: + inchikey_results_list.append([inchikey, "b", "c", "d", "e", "f", "g", "h", "i", "j"]) + compound_class_df = _convert_to_dataframe(inchikey_results_list) + return compound_class_df + new_sqlite_file_name = os.path.join(tmp_path, + "test_spectra_database.sqlite") + list_of_spectra = normalize_and_filter_peaks_multiple_spectra(hundred_test_spectra) + library_creator = \ + LibraryFilesCreator(library_spectra=list_of_spectra, + sqlite_file_name=new_sqlite_file_name, + s2v_model_file_name=spec2vec_model_file_name, + ms2ds_model_file_name=ms2deepscore_model_file_name, + compound_classes=generate_compound_classes(spectra=list_of_spectra)) + library_creator.create_sqlite_file() + + check_sqlite_files_are_equal(new_sqlite_file_name, sqlite_library.sqlite_file_name, check_metadata=False) diff --git a/tests/test_ms2library.py b/tests/test_ms2library.py index e37674e0..8054c359 100644 --- a/tests/test_ms2library.py +++ b/tests/test_ms2library.py @@ -4,7 +4,8 @@ import pandas as pd from ms2query.ms2library import MS2Library, create_library_object_from_one_dir from ms2query.utils import SettingsRunMS2Query, column_names_for_output -from tests.test_utils import check_correct_results_csv_file +from tests.test_utils import (check_correct_results_csv_file, + check_expected_headers) def test_get_all_ms2ds_scores(ms2library, test_spectra): @@ -65,12 +66,15 @@ def test_analog_search_store_in_csv(ms2library, test_spectra, tmp_path): settings = SettingsRunMS2Query(additional_metadata_columns=(("spectrum_id", ))) ms2library.analog_search_store_in_csv(test_spectra, results_csv_file, settings) assert os.path.exists(results_csv_file) - expected_headers = \ - ['query_spectrum_nr', "ms2query_model_prediction", "precursor_mz_difference", "precursor_mz_query_spectrum", - "precursor_mz_analog", "inchikey", "analog_compound_name", "smiles", "spectrum_id"] - check_correct_results_csv_file( - pd.read_csv(results_csv_file), - expected_headers) + + results = pd.read_csv(results_csv_file) + check_expected_headers(results, + expected_headers= + ['query_spectrum_nr', 'ms2query_model_prediction', 'precursor_mz_difference', + 'precursor_mz_query_spectrum', 'precursor_mz_analog', 'inchikey', 'analog_compound_name', + 'smiles', 'spectrum_id', 'cf_kingdom', 'cf_superclass', 'cf_class', 'cf_subclass', + 'cf_direct_parent', 'npc_class_results', 'npc_superclass_results', 'npc_pathway_results']) + check_correct_results_csv_file(results) def test_create_library_object_from_one_dir(path_to_general_test_files): @@ -82,10 +86,15 @@ def test_create_library_object_from_one_dir(path_to_general_test_files): def test_analog_yield_df(ms2library, test_spectra, tmp_path): settings = SettingsRunMS2Query(additional_metadata_columns=("spectrum_id", ),) result = ms2library.analog_search_yield_df(test_spectra, settings) - expected_headers = \ - ['query_spectrum_nr', "ms2query_model_prediction", "precursor_mz_difference", "precursor_mz_query_spectrum", - "precursor_mz_analog", "inchikey", "analog_compound_name", "smiles", "spectrum_id"] - check_correct_results_csv_file(list(result)[0], expected_headers, nr_of_rows_to_check=1) + result = list(result)[0] + check_expected_headers(result, + expected_headers= + ['query_spectrum_nr', 'ms2query_model_prediction', 'precursor_mz_difference', + 'precursor_mz_query_spectrum', 'precursor_mz_analog', 'inchikey', 'analog_compound_name', + 'smiles', 'spectrum_id', 'cf_kingdom', 'cf_superclass', 'cf_class', 'cf_subclass', + 'cf_direct_parent', 'npc_class_results', 'npc_superclass_results', 'npc_pathway_results']) + + check_correct_results_csv_file(result, nr_of_rows_to_check=1) def test_analog_yield_df_additional_columns(ms2library, test_spectra, tmp_path): @@ -93,7 +102,8 @@ def test_analog_yield_df_additional_columns(ms2library, test_spectra, tmp_path): additional_ms2query_score_columns=("s2v_score", "ms2ds_score",),) result = ms2library.analog_search_yield_df(test_spectra, settings) result_first_spectrum = list(result)[0] + check_expected_headers(result_first_spectrum, + expected_headers=column_names_for_output(True, True, ("charge", "retention_time"), + ("s2v_score", "ms2ds_score",)),) check_correct_results_csv_file(result_first_spectrum, - column_names_for_output(True, True, ("charge", "retention_time"), - ("s2v_score", "ms2ds_score",)), nr_of_rows_to_check=1) diff --git a/tests/test_run_ms2query.py b/tests/test_run_ms2query.py index 3110bb4d..59d261fc 100644 --- a/tests/test_run_ms2query.py +++ b/tests/test_run_ms2query.py @@ -8,7 +8,8 @@ zenodo_dois) from ms2query.utils import SettingsRunMS2Query from tests.test_ms2library import MS2Library -from tests.test_utils import check_correct_results_csv_file +from tests.test_utils import (check_correct_results_csv_file, + check_expected_headers) if sys.version_info < (3, 8): @@ -67,24 +68,6 @@ def create_test_folder_with_spectra_files(path, spectra): def test_run_complete_folder(tmp_path, ms2library, test_spectra): folder_with_spectra = create_test_folder_with_spectra_files(tmp_path, test_spectra) results_directory = os.path.join(folder_with_spectra, "results") - - run_complete_folder(ms2library=ms2library, - folder_with_spectra=folder_with_spectra) - assert os.path.exists(results_directory), "Expected results directory to be created" - assert os.listdir(os.path.join(results_directory)).sort() == ['spectra_file_1.csv', 'spectra_file_2.csv'].sort() - - expected_headers = ['query_spectrum_nr', 'ms2query_model_prediction', 'precursor_mz_difference', - 'precursor_mz_query_spectrum', 'precursor_mz_analog', 'inchikey', - 'analog_compound_name', 'smiles', 'retention_time', 'retention_index'] - check_correct_results_csv_file(pd.read_csv(os.path.join(os.path.join(results_directory, 'spectra_file_1.csv'))), - expected_headers) - check_correct_results_csv_file(pd.read_csv(os.path.join(os.path.join(results_directory, 'spectra_file_2.csv'))), - expected_headers) - - -def test_run_complete_folder_with_classifiers(tmp_path, ms2library, test_spectra): - folder_with_spectra = create_test_folder_with_spectra_files(tmp_path, test_spectra) - results_directory = os.path.join(folder_with_spectra, "results") settings = SettingsRunMS2Query(minimal_ms2query_metascore=0, additional_metadata_columns=("charge",), additional_ms2query_score_columns=("s2v_score", "ms2ds_score")) @@ -101,9 +84,11 @@ def test_run_complete_folder_with_classifiers(tmp_path, ms2library, test_spectra "precursor_mz_analog", "inchikey", "analog_compound_name", "smiles", "charge", "s2v_score", "ms2ds_score", "cf_kingdom", "cf_superclass", "cf_class", "cf_subclass", "cf_direct_parent", "npc_class_results", "npc_superclass_results", "npc_pathway_results"] - check_correct_results_csv_file( - pd.read_csv(os.path.join(os.path.join(results_directory, 'spectra_file_1.csv'))), - expected_headers) - check_correct_results_csv_file( - pd.read_csv(os.path.join(os.path.join(results_directory, 'spectra_file_2.csv'))), - expected_headers) + + result_1 = pd.read_csv(os.path.join(os.path.join(results_directory, 'spectra_file_1.csv'))) + result_2 = pd.read_csv(os.path.join(os.path.join(results_directory, 'spectra_file_2.csv'))) + check_expected_headers(result_1, expected_headers) + check_expected_headers(result_2, expected_headers) + + check_correct_results_csv_file(result_1) + check_correct_results_csv_file(result_2) \ No newline at end of file diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index 3274dda7..504d0352 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -1,126 +1,10 @@ -"""Tests all sqlite related functions +"""Tests load from sqlite functions -These functions are creating a new sqlite file with spectra data and -tanimoto scores (create_sqlite_database.py) and functions to retrieve -information from the sqlite database. +These functions are functions to retrieve information from the sqlite database. """ -import os -import sqlite3 -import numpy as np import pandas as pd -from ms2query.clean_and_filter_spectra import \ - normalize_and_filter_peaks_multiple_spectra -from ms2query.create_new_library.add_classifire_classifications import \ - convert_to_dataframe -from ms2query.create_new_library.create_sqlite_database import \ - make_sqlfile_wrapper -from ms2query.utils import column_names_for_output, load_pickled_file - - -def check_sqlite_files_are_equal(new_sqlite_file_name, reference_sqlite_file, check_metadata=True): - """Raises an error if the two sqlite files are not equal""" - # Test if file is made - assert os.path.isfile(new_sqlite_file_name), \ - "Expected a file to be created" - assert os.path.isfile(reference_sqlite_file), \ - "The reference file given does not exist" - - # Test if the file has the correct information - get_table_names = \ - "SELECT name FROM sqlite_master WHERE type='table' order by name" - conn1 = sqlite3.connect(new_sqlite_file_name) - cur1 = conn1.cursor() - table_names1 = cur1.execute(get_table_names).fetchall() - - conn2 = sqlite3.connect(reference_sqlite_file) - cur2 = conn2.cursor() - table_names2 = cur2.execute(get_table_names).fetchall() - - assert table_names1 == table_names2, \ - "Different sqlite tables are created than expected" - - for table_nr, table_name1 in enumerate(table_names1): - table_name1 = table_name1[0] - # Get column names and settings like primary key etc. - table_info1 = cur1.execute( - f"PRAGMA table_info({table_name1});").fetchall() - table_info2 = cur2.execute( - f"PRAGMA table_info({table_name1});").fetchall() - assert table_info1 == table_info2, \ - f"Different column names or table settings " \ - f"were expected in table {table_name1}" - column_names = [column_info[1] for column_info in table_info1] - for column in column_names: - # Get all rows from both tables - rows_1 = cur1.execute(f"SELECT {column} FROM " + - table_name1).fetchall() - rows_2 = cur2.execute(f"SELECT {column} FROM " + - table_name1).fetchall() - error_msg = f"Different data was expected in column {column} " \ - f"in table {table_name1}. \n Expected {rows_2} \n got {rows_1}" - if column == "precursor_mz": - np.testing.assert_almost_equal(rows_1, - rows_2, - err_msg=error_msg, - verbose=True) - elif column == "metadata" and not check_metadata: - pass - else: - assert len(rows_1) == len(rows_2) - for i in range(len(rows_1)): - assert rows_1[i] == rows_2[i], f"Different data was expected in column {column} row {i}" \ - f"in table {table_name1}. \n Expected {rows_2[i]} \n got {rows_1[i]}" - conn1.close() - conn2.close() - - -def test_making_sqlite_file_without_classes(tmp_path, hundred_test_spectra, path_to_general_test_files): - """Makes a temporary sqlite file and tests if it contains the correct info - """ - # tmp_path is a fixture that makes sure a temporary file is created - new_sqlite_file_name = os.path.join(tmp_path, - "test_spectra_database.sqlite") - - reference_sqlite_file = os.path.join(path_to_general_test_files, - "backwards_compatibility", - "100_test_spectra_without_classes.sqlite") - - list_of_spectra = normalize_and_filter_peaks_multiple_spectra(hundred_test_spectra) - - # Create sqlite file, with 3 tables - make_sqlfile_wrapper(new_sqlite_file_name, - list_of_spectra, - columns_dict={"precursor_mz": "REAL"}) - check_sqlite_files_are_equal(new_sqlite_file_name, reference_sqlite_file, check_metadata=False) - - -def test_making_sqlite_file_with_compound_classes(tmp_path, path_to_general_test_files, hundred_test_spectra): - """Makes a temporary sqlite file and tests if it contains the correct info - """ - def generate_compound_classes(spectra): - inchikeys = {spectrum.get("inchikey")[:14] for spectrum in spectra} - inchikey_results_list = [] - for inchikey in inchikeys: - inchikey_results_list.append([inchikey, "b", "c", "d", "e", "f", "g", "h", "i", "j"]) - compound_class_df = convert_to_dataframe(inchikey_results_list) - return compound_class_df - # tmp_path is a fixture that makes sure a temporary file is created - new_sqlite_file_name = os.path.join(tmp_path, - "test_spectra_database.sqlite") - - reference_sqlite_file = os.path.join(path_to_general_test_files, - "100_test_spectra.sqlite") - - list_of_spectra = normalize_and_filter_peaks_multiple_spectra(hundred_test_spectra) - - # Create sqlite file, with 3 tables - make_sqlfile_wrapper(new_sqlite_file_name, - list_of_spectra, - columns_dict={"precursor_mz": "REAL"}, - compound_classes=generate_compound_classes(spectra=list_of_spectra)) - - check_sqlite_files_are_equal(new_sqlite_file_name, reference_sqlite_file, check_metadata=False) +from ms2query.utils import column_names_for_output def test_get_metadata_from_sqlite(sqlite_library): diff --git a/tests/test_utils.py b/tests/test_utils.py index ca731519..a7a847ab 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -37,19 +37,17 @@ def test_add_unknown_charges_to_spectra(hundred_test_spectra): def check_correct_results_csv_file(dataframe_found: pd.DataFrame, - expected_headers: List[str], nr_of_rows_to_check=2): + """For the columns available check if they match the results""" # Define expected results csv_format_expected_results ="""query_spectrum_nr,ms2query_model_prediction,precursor_mz_difference,precursor_mz_query_spectrum,precursor_mz_analog,inchikey,spectrum_id,analog_compound_name,charge,s2v_score,ms2ds_score,retention_time,retention_index,smiles,cf_kingdom,cf_superclass,cf_class,cf_subclass,cf_direct_parent,npc_class_results,npc_superclass_results,npc_pathway_results\n 1,0.5645,33.2500,907.0000,940.2500,KNGPFNUOXXLKCN,CCMSLIB00000001760,Hoiamide B,1,0.9996,0.9232,,,CCC[C@@H](C)[C@@H]([C@H](C)[C@@H]1[C@H]([C@H](Cc2nc(cs2)C3=N[C@](CS3)(C4=N[C@](CS4)(C(=O)N[C@H]([C@H]([C@H](C(=O)O[C@H](C(=O)N[C@H](C(=O)O1)[C@@H](C)O)[C@@H](C)CC)C)O)[C@@H](C)CC)C)C)OC)C)O,b,c,d,e,f,g,h,i\n 2,0.4090,61.3670,928.0000,866.6330,GRJSOZDXIUZXEW,CCMSLIB00000001761,Halovir A,0,0.9621,0.4600,,,CCCCCCCCCCCCCC(=O)NC(C)(C)C(=O)N1C[C@H](O)C[C@H]1C(=O)NC(CC(C)C)C(=O)N[C@@H](C(C)C)C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@H](CO)CC(C)C,b,c,d,e,f,g,h,i\n""" dataframe_expected_results = pd.read_csv(StringIO(csv_format_expected_results), sep=",", header=0) - # convert csv rows to dataframe - check_expected_headers(dataframe_found, expected_headers) - # Select only the matching columns selection_of_matching_headers = dataframe_expected_results[dataframe_found.columns] + pd.testing.assert_frame_equal(dataframe_found.iloc[:nr_of_rows_to_check, :], selection_of_matching_headers.iloc[:nr_of_rows_to_check, :], check_dtype=False, @@ -58,7 +56,9 @@ def check_correct_results_csv_file(dataframe_found: pd.DataFrame, def check_expected_headers(dataframe_found: pd.DataFrame, expected_headers: List[str]): + """Checks if the correct headers are found""" found_headers = list(dataframe_found.columns) - assert len(found_headers) == len(found_headers) + assert len(found_headers) == len(expected_headers) + # check the order of the headers is the same. for i, header in enumerate(expected_headers): - assert header == found_headers[i] \ No newline at end of file + assert header == found_headers[i]