Allow partial loading for pre-trained MS2 models #226

Open · wants to merge 2 commits into base: development
49 changes: 48 additions & 1 deletion peptdeep/model/ms2.py
@@ -2,7 +2,7 @@
import pandas as pd
import numpy as np
import warnings

from peptdeep.utils import logging
from typing import List, Tuple, IO

from tqdm import tqdm
@@ -425,6 +425,53 @@ def _prepare_train_data_df(
# if np.all(precursor_df['nce'].values > 1):
# precursor_df['nce'] = precursor_df['nce']*self.NCE_factor

def _load_model_from_stream(self, stream: IO):
Collaborator:
There may be two cases:

  1. we save partial model params into the file, and this method allows us to load this partial model.
  2. we save the full model, but we only need to load partial params by specifying param names or the first K layers.

Collaborator Author @mo-sameh — Jan 11, 2025:
Hi, I’m not sure I fully understand your concerns. The current interface saves the complete model weights, so I’m unclear why users would want to save only partial models. Are you suggesting we add this functionality?

As for loading partial weights, the current implementation (in this PR) should already handle this automatically: it matches parameter keys and sizes, loading the matching weights while initializing the remaining parameters from scratch. Are you suggesting we modify the interface to let users explicitly specify which layers to load?

Collaborator:
No, I mean the use cases.

"""
Overriding this function to allow for partial loading of pretrained models.
Helpful incase of just changing the prediction head (fragment types) of the model,
while keeping the rest of the model weights same.

Parameters
----------
stream : IO
A file stream to load the model from
"""
current_model_dict = self.model.state_dict()
to_be_loaded_dict = torch.load(stream, map_location=self.device)
# load same size and key tensors
filtered_params = {}
size_mismatches = []
unexpected_keys = []
for source_key, source_value in to_be_loaded_dict.items():
if source_key in current_model_dict:
if source_value.size() == current_model_dict[source_key].size():
filtered_params[source_key] = source_value
else:
size_mismatches.append(source_key)
else:
unexpected_keys.append(source_key)
missing_keys = set(current_model_dict.keys()) - set(filtered_params.keys())

self.model.load_state_dict(filtered_params, strict=False)
if size_mismatches or unexpected_keys or missing_keys:
warning_msg = "Some layers might be randomly initialized due to a mismatch between the loaded weights and the model architecture. Make sure to train the model or load different weights before prediction."
warning_msg += (
f" The following keys had size mismatches: {size_mismatches}"
if size_mismatches
else ""
)
warning_msg += (
f" The following keys were unexpected: {unexpected_keys}"
if unexpected_keys
else ""
)
warning_msg += (
f" The following keys were missing: {missing_keys}"
if missing_keys
else ""
)
logging.warning(warning_msg)

Contributor:

(nit) This more concise format might be easier on the eye:

    warning_msg += "".join(
        [
            f"\nKeys with size mismatches: {size_mismatches}" if size_mismatches else "",
            f"\nUnexpected keys: {unexpected_keys}" if unexpected_keys else "",
            f"\nMissing keys: {missing_keys}" if missing_keys else "",
        ]
    )
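The filter-then-load pattern in the diff can be seen in isolation in the minimal, self-contained sketch below. It is not the PR's code: `load_partial_state_dict` and the two toy `nn.Sequential` models are hypothetical stand-ins that mimic the key/size matching and the `strict=False` load, simulating a checkpoint whose prediction head (fragment-type output layer) changed shape.

```python
import io

import torch
import torch.nn as nn


def load_partial_state_dict(model: nn.Module, stream) -> dict:
    """Load only the weights whose key AND shape match `model`,
    reporting everything that could not be loaded."""
    current = model.state_dict()
    loaded = torch.load(stream, map_location="cpu")
    filtered, size_mismatches, unexpected = {}, [], []
    for key, value in loaded.items():
        if key not in current:
            unexpected.append(key)
        elif value.size() != current[key].size():
            size_mismatches.append(key)
        else:
            filtered[key] = value
    # strict=False tolerates the keys we deliberately left out.
    model.load_state_dict(filtered, strict=False)
    missing = sorted(set(current) - set(filtered))
    return {
        "loaded": sorted(filtered),
        "size_mismatches": size_mismatches,
        "unexpected": unexpected,
        "missing": missing,
    }


# Toy models: same backbone, but the "head" grows from 4 to 8 outputs
# (analogous to predicting more fragment types).
old = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))
new = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 8))

buf = io.BytesIO()
torch.save(old.state_dict(), buf)
buf.seek(0)

report = load_partial_state_dict(new, buf)
# The shared backbone ("0.*") is loaded from the checkpoint; the head
# ("1.*") keeps its fresh random initialization, as the PR's warning says.
```

The same trade-off as in the PR applies: the partially loaded model must be trained (or given matching weights) before it is used for prediction.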

def _check_predict_in_order(self, precursor_df: pd.DataFrame):
pass

14 changes: 12 additions & 2 deletions peptdeep/pretrained_models.py
@@ -270,6 +270,7 @@ def __init__(
self,
mask_modloss: bool = False,
device: str = "gpu",
charged_frag_types: list[str] = None,
Contributor:
we still have requires-python = ">=3.8.0", so Optional[List[str]] it is ..

@jalew188 should we drop support for 3.8?

Collaborator:
We can, in the next release.
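The reviewer's point can be illustrated with a small sketch: on Python 3.8, an annotation like `charged_frag_types: list[str] = None` raises `TypeError: 'type' object is not subscriptable` when the `def` is evaluated, because subscriptable builtin generics (PEP 585) only arrived in Python 3.9. The `typing` spellings work on every version the project supports. The function and the default list below are hypothetical illustrations, not peptdeep's actual defaults.

```python
from typing import List, Optional


def build_ms2_model_args(charged_frag_types: Optional[List[str]] = None) -> List[str]:
    # Hypothetical stand-in for "use the yaml default when None is passed".
    default_frag_types = ["b_z1", "y_z1"]
    return charged_frag_types if charged_frag_types is not None else default_frag_types


build_ms2_model_args()          # falls back to the default configuration
build_ms2_model_args(["b_z2"])  # explicit override wins
```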

):
"""
Parameters
Expand All @@ -283,11 +284,20 @@ def __init__(
Device for DL models, could be 'gpu' ('cuda') or 'cpu'.
if device=='gpu' but no GPUs are detected, it will automatically switch to 'cpu'.
Defaults to 'gpu'
charged_frag_types : list[str], optional
Charged fragment types for the MS2 model, overriding the default configuration in the yaml file.
If set to None, the default configuration in the yaml file is used.
"""
self._train_psm_logging = True

self.ms2_model: pDeepModel = pDeepModel(
mask_modloss=mask_modloss, device=device
self.ms2_model: pDeepModel = (
pDeepModel(mask_modloss=mask_modloss, device=device)
if charged_frag_types is None
else pDeepModel(
mask_modloss=mask_modloss,
device=device,
charged_frag_types=charged_frag_types,
)
)
self.rt_model: AlphaRTModel = AlphaRTModel(device=device)
self.ccs_model: AlphaCCSModel = AlphaCCSModel(device=device)
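The ternary in the diff calls `pDeepModel(...)` twice with mostly identical arguments. An alternative that avoids the duplication is to build the kwargs conditionally and pass `**kwargs` once. The sketch below only demonstrates the pattern; `build_ms2_model_kwargs` is a hypothetical helper (the real code would end with `pDeepModel(**kwargs)`), and it assumes that omitting `charged_frag_types` is equivalent to the constructor's default.

```python
def build_ms2_model_kwargs(mask_modloss=False, device="gpu", charged_frag_types=None):
    # Collect the always-present arguments, then add the override only
    # when the caller actually supplies one.
    kwargs = {"mask_modloss": mask_modloss, "device": device}
    if charged_frag_types is not None:
        kwargs["charged_frag_types"] = charged_frag_types
    return kwargs  # in the real code: pDeepModel(**kwargs)
```

This keeps a single construction site, so adding another optional parameter later means one new `if`, not another branch of a growing ternary.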