allow partial loading for pre trained ms2 models #226
base: development
@@ -2,7 +2,7 @@
import pandas as pd
import numpy as np
import warnings

from peptdeep.utils import logging
from typing import List, Tuple, IO

from tqdm import tqdm
@@ -425,6 +425,53 @@ def _prepare_train_data_df(
        # if np.all(precursor_df['nce'].values > 1):
        #     precursor_df['nce'] = precursor_df['nce']*self.NCE_factor

    def _load_model_from_stream(self, stream: IO):
        """
        Overriding this function to allow for partial loading of pretrained models.
        Helpful in case of just changing the prediction head (fragment types) of the model,
        while keeping the rest of the model weights the same.

        Parameters
        ----------
        stream : IO
            A file stream to load the model from
        """
        current_model_dict = self.model.state_dict()
        to_be_loaded_dict = torch.load(stream, map_location=self.device)
        # load same size and key tensors
        filtered_params = {}
        size_mismatches = []
        unexpected_keys = []
        for source_key, source_value in to_be_loaded_dict.items():
            if source_key in current_model_dict:
                if source_value.size() == current_model_dict[source_key].size():
                    filtered_params[source_key] = source_value
                else:
                    size_mismatches.append(source_key)
            else:
                unexpected_keys.append(source_key)
        missing_keys = set(current_model_dict.keys()) - set(filtered_params.keys())

        self.model.load_state_dict(filtered_params, strict=False)
        if size_mismatches or unexpected_keys or missing_keys:
            warning_msg = "Some layers might be randomly initialized due to a mismatch between the loaded weights and the model architecture. Make sure to train the model or load different weights before prediction."
            warning_msg += (
                f" The following keys had size mismatches: {size_mismatches}"
                if size_mismatches
                else ""
            )
            warning_msg += (
                f" The following keys were unexpected: {unexpected_keys}"
                if unexpected_keys
                else ""
            )
            warning_msg += (
                f" The following keys were missing: {missing_keys}"
                if missing_keys
                else ""
            )
            logging.warning(warning_msg)

Review comment (on the warning-message construction): (nit) This more concise format might be easier on the eye:
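The reviewer's suggested snippet is not part of this excerpt. Purely as an illustration of what a more compact construction could look like (a hypothetical helper, not code from the PR), the three conditional concatenations could be collapsed into a single join:

def _format_partial_load_warning(size_mismatches, unexpected_keys, missing_keys) -> str:
    # Hypothetical helper: builds the same warning text in one pass.
    msg_parts = [
        "Some layers might be randomly initialized due to a mismatch between "
        "the loaded weights and the model architecture. Make sure to train the "
        "model or load different weights before prediction."
    ]
    if size_mismatches:
        msg_parts.append(f"The following keys had size mismatches: {size_mismatches}")
    if unexpected_keys:
        msg_parts.append(f"The following keys were unexpected: {unexpected_keys}")
    if missing_keys:
        msg_parts.append(f"The following keys were missing: {missing_keys}")
    return " ".join(msg_parts)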
    def _check_predict_in_order(self, precursor_df: pd.DataFrame):
        pass
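A note on why the key and size filtering above is needed: PyTorch's load_state_dict(strict=False) tolerates missing and unexpected keys, but it still raises an error when a key matches and the tensor shape does not, so shape-mismatched tensors (for example a prediction head sized for a different set of fragment types) have to be dropped before loading. A minimal, self-contained sketch of the same pattern on toy modules (module names and sizes are made up for illustration):

import torch

pretrained_model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),  # shared "backbone" layer
    torch.nn.Linear(32, 4),   # old head: 4 output fragment types
)
new_model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),  # same shape -> weights can be copied
    torch.nn.Linear(32, 6),   # new head: 6 output fragment types -> size mismatch
)

pretrained_dict = pretrained_model.state_dict()
current_dict = new_model.state_dict()

# Keep only tensors whose key exists in the new model and whose shape matches.
filtered = {
    key: value
    for key, value in pretrained_dict.items()
    if key in current_dict and value.size() == current_dict[key].size()
}

# strict=False tolerates the now-missing head keys; the head keeps its fresh
# random initialization and has to be (re)trained before prediction.
new_model.load_state_dict(filtered, strict=False)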
@@ -270,6 +270,7 @@ def __init__(
         self,
         mask_modloss: bool = False,
         device: str = "gpu",
+        charged_frag_types: list[str] = None,
     ):
         """
         Parameters

Review comment: we still have @jalew188 should we drop support for 3.8?
Reply: We can, in the next release.
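For background (not part of the PR): the built-in generic syntax list[str] used in the new annotation requires Python 3.9, or a `from __future__ import annotations` import, which is presumably why Python 3.8 support comes up here. A 3.8-compatible spelling of the same annotation would use the typing module, for example:

from typing import List, Optional

# Python 3.9+ spelling, as written in the PR:
#     charged_frag_types: list[str] = None
# Python 3.8-compatible spelling, which also makes the None default explicit:
charged_frag_types: Optional[List[str]] = None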
@@ -283,11 +284,20 @@ def __init__(
             Device for DL models, could be 'gpu' ('cuda') or 'cpu'.
             if device=='gpu' but no GPUs are detected, it will automatically switch to 'cpu'.
             Defaults to 'gpu'
+        charged_frag_types : list[str], optional
+            Charged fragment types for the MS2 model to override the default configuration in the yaml file.
+            If set to None, it will use the default configuration in the yaml file.
         """
         self._train_psm_logging = True

-        self.ms2_model: pDeepModel = pDeepModel(
-            mask_modloss=mask_modloss, device=device
+        self.ms2_model: pDeepModel = (
+            pDeepModel(mask_modloss=mask_modloss, device=device)
+            if charged_frag_types is None
+            else pDeepModel(
+                mask_modloss=mask_modloss,
+                device=device,
+                charged_frag_types=charged_frag_types,
+            )
         )
         self.rt_model: AlphaRTModel = AlphaRTModel(device=device)
         self.ccs_model: AlphaCCSModel = AlphaCCSModel(device=device)
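For context, a rough sketch of how the new argument is intended to be used. This is illustrative only: it assumes the constructor above belongs to peptdeep's ModelManager and that the pretrained weights are loaded afterwards through the existing loading path (load_installed_models), both of which are assumptions of this sketch rather than something stated in the diff. The fragment type names are examples only.

from peptdeep.pretrained_models import ModelManager  # assumed import path

# Ask for a custom set of charged fragment types for the MS2 model.
model_mgr = ModelManager(
    mask_modloss=False,
    device="gpu",
    charged_frag_types=["b_z1", "b_z2", "y_z1", "y_z2"],
)

# With the partial loading above, only tensors whose names and shapes match are
# copied from the pretrained checkpoint; the fragment-type-dependent output
# layer keeps its fresh initialization and should be fine-tuned before use.
model_mgr.load_installed_models()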
Review comment: There may be two cases:
Reply: Hi, I’m not sure I fully understand your concerns. The current interface saves the complete model weights, so I’m unclear about why users would want to save only partial models. Are you suggesting we add this functionality?
As for loading partial weights, the current implementation (in this PR) should already handle this automatically. It matches parameter keys and sizes, loading the matching weights while initializing the remaining parameters from scratch. Are you suggesting we modify the interface to let users explicitly specify which layers to load?
Reply: No, I mean the use cases.