diff --git a/peptdeep/model/ms2.py b/peptdeep/model/ms2.py
index fa7b1035..d808ab8d 100644
--- a/peptdeep/model/ms2.py
+++ b/peptdeep/model/ms2.py
@@ -2,7 +2,7 @@
 import pandas as pd
 import numpy as np
 import warnings
-
+from peptdeep.utils import logging
 from typing import List, Tuple, IO
 
 from tqdm import tqdm
@@ -425,6 +425,64 @@ def _prepare_train_data_df(
         # if np.all(precursor_df['nce'].values > 1):
         #     precursor_df['nce'] = precursor_df['nce']*self.NCE_factor
 
+    def _load_model_from_stream(self, stream: IO):
+        """
+        Load model weights from a stream, tolerating partial mismatches.
+
+        Overriding this function to allow for partial loading of pretrained models.
+        Helpful in case of just changing the prediction head (fragment types) of the model,
+        while keeping the rest of the model weights same.
+
+        Only tensors whose key exists in the current model with the same size
+        are loaded; all other keys are skipped and reported in one warning.
+
+        Parameters
+        ----------
+        stream : IO
+            A file stream to load the model from
+        """
+        current_model_dict = self.model.state_dict()
+        # NOTE(review): torch.load relies on pickle (unless weights_only is in
+        # effect), so only checkpoints from trusted sources should be passed in.
+        to_be_loaded_dict = torch.load(stream, map_location=self.device)
+
+        # Keep only tensors matching the current model in both key and size.
+        filtered_params = {}
+        size_mismatches = []
+        unexpected_keys = []
+        for source_key, source_value in to_be_loaded_dict.items():
+            if source_key not in current_model_dict:
+                unexpected_keys.append(source_key)
+            elif source_value.size() != current_model_dict[source_key].size():
+                size_mismatches.append(source_key)
+            else:
+                filtered_params[source_key] = source_value
+        # A list in model order (instead of a set) keeps the warning text
+        # deterministic and formatted like the other two key lists.
+        missing_keys = [
+            key for key in current_model_dict if key not in filtered_params
+        ]
+
+        self.model.load_state_dict(filtered_params, strict=False)
+
+        if size_mismatches or unexpected_keys or missing_keys:
+            warning_parts = [
+                "Some layers might be randomly initialized due to a mismatch between the loaded weights and the model architecture. Make sure to train the model or load different weights before prediction."
+            ]
+            if size_mismatches:
+                warning_parts.append(
+                    f"The following keys had size mismatches: {size_mismatches}"
+                )
+            if unexpected_keys:
+                warning_parts.append(
+                    f"The following keys were unexpected: {unexpected_keys}"
+                )
+            if missing_keys:
+                warning_parts.append(
+                    f"The following keys were missing: {missing_keys}"
+                )
+            logging.warning(" ".join(warning_parts))
+
     def _check_predict_in_order(self, precursor_df: pd.DataFrame):
         pass
 
diff --git a/peptdeep/pretrained_models.py b/peptdeep/pretrained_models.py
index 3d6f8ef2..f6d63a7b 100644
--- a/peptdeep/pretrained_models.py
+++ b/peptdeep/pretrained_models.py
@@ -270,6 +270,7 @@ def __init__(
         self,
         mask_modloss: bool = False,
         device: str = "gpu",
+        charged_frag_types: list[str] = None,
     ):
         """
         Parameters
@@ -283,11 +284,20 @@ def __init__(
             Device for DL models, could be 'gpu' ('cuda') or 'cpu'.
             if device=='gpu' but no GPUs are detected, it will automatically switch to 'cpu'.
             Defaults to 'gpu'
+        charged_frag_types : list[str], optional
+            Charged fragment types for MS2 model to override the default configuration in the yaml file.
+            If set to None, it will use the default configuration in the yaml file.
         """
         self._train_psm_logging = True
 
-        self.ms2_model: pDeepModel = pDeepModel(
-            mask_modloss=mask_modloss, device=device
+        self.ms2_model: pDeepModel = (
+            pDeepModel(mask_modloss=mask_modloss, device=device)
+            if charged_frag_types is None
+            else pDeepModel(
+                mask_modloss=mask_modloss,
+                device=device,
+                charged_frag_types=charged_frag_types,
+            )
         )
         self.rt_model: AlphaRTModel = AlphaRTModel(device=device)
         self.ccs_model: AlphaCCSModel = AlphaCCSModel(device=device)