Skip to content

Commit a02013c

Browse files
authored
Merge pull request #574 from DeepRank/570_new_patch_release_gcroci2
release: patch 3.0.1
2 parents c45f828 + 39cb127 commit a02013c

File tree

9 files changed

+42
-24
lines changed

9 files changed

+42
-24
lines changed

.bumpversion.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 3.0.0
2+
current_version = 3.0.1
33

44
[comment]
55
comment = The contents of this file cannot be merged with that of setup.cfg until https://github.com/c4urself/bump2version/issues/185 is resolved

CITATION.cff

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,5 @@ keywords:
6868
- DeepRank
6969
license: Apache-2.0
7070
commit: 4e8823758ba03f824b4281f5689cb6a335ab2f6c
71-
version: "3.0.0"
72-
date-released: "2024-01-25"
71+
version: "3.0.1"
72+
date-released: "2024-02-22"

CONTRIBUTING.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ You want to make some kind of change to the code base
3737
#. if needed, fork the repository to your own Github profile and create your own feature branch off of the latest main commit. While working on your feature branch, make sure to stay up to date with the main branch by pulling in changes, possibly from the 'upstream' repository (follow the instructions `here <https://help.github.com/articles/configuring-a-remote-for-a-fork/>`__ and `here <https://help.github.com/articles/syncing-a-fork/>`__);
3838
#. make sure the existing tests still work by running ``python setup.py test``;
3939
#. add your own tests (if necessary);
40-
#. ensure the code is correctly linted (`ruff .`) and formatted (`ruff format .`);
41-
#. see our `developer's readme <README.dev.md>`` for detailed information on our style conventions, etc.;
40+
#. ensure the code is correctly linted (``ruff .``) and formatted (``ruff format .``);
41+
#. see our `developer's readme <README.dev.md>`_ for detailed information on our style conventions, etc.;
4242
#. update or expand the documentation;
4343
#. `push <http://rogerdudler.github.io/git-guide/>`_ your feature branch to (your fork of) the DeepRank2 repository on GitHub;
4444
#. create the pull request, e.g. following the instructions `here <https://help.github.com/articles/creating-a-pull-request/>`__.

README.dev.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ During the development cycle, three main supporting branches are used:
7979

8080
## Making a release
8181

82-
1. Branch from `dev` and prepare the branch for the release (e.g., removing the unnecessary dev files such as the current one, fix minor bugs if necessary).
82+
1. Branch from `dev` and prepare the branch for the release (e.g., removing the unnecessary dev files, fix minor bugs if necessary).
8383
2. [Bump the version](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#versioning).
8484
3. Verify that the information in `CITATION.cff` is correct (update the release date), and that `.zenodo.json` contains equivalent data.
8585
4. Merge the release branch into `main` (and `dev`), and [run the tests](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#running-the-tests).

deeprank2/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.0.0"
1+
__version__ = "3.0.1"

deeprank2/trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,7 @@ def _save_model(self) -> dict[str, Any]:
947947
if key["transform"] is None:
948948
continue
949949
str_expr = inspect.getsource(key["transform"])
950-
match = re.search(r"\'transform\':.*(lambda.*).*,.*\'standardize\'.*", str_expr).group(1)
950+
match = re.search(r"[\"|\']transform[\"|\']:.*(lambda.*).*,.*[\"|\']standardize[\"|\'].*", str_expr).group(1)
951951
key["transform"] = match
952952

953953
state = {

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "deeprank2"
7-
version = "3.0.0"
7+
version = "3.0.1"
88
description = "DeepRank2 is an open-source deep learning framework for data mining of protein-protein interfaces or single-residue missense variants."
99
readme = "README.md"
1010
requires-python = ">=3.10"

tests/features/test_irc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_irc_atom() -> None:
3232
pdb_path = "tests/data/pdb/1A0Z/1A0Z.pdb"
3333
graph, _ = build_testgraph(
3434
pdb_path=pdb_path,
35-
detail="residue",
35+
detail="atom",
3636
influence_radius=4.5,
3737
max_edge_length=4.5,
3838
)

tests/test_trainer.py

+32-14
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def _model_base_test(
8686
dataset_train,
8787
dataset_val,
8888
dataset_test,
89+
cuda=use_cuda,
8990
output_exporters=output_exporters,
9091
)
9192

@@ -94,20 +95,6 @@ def _model_base_test(
9495
for parameter in trainer.model.parameters():
9596
assert parameter.is_cuda, f"{parameter} is not cuda"
9697

97-
data = dataset_train.get(0)
98-
99-
for name, data_tensor in (
100-
("x", data.x),
101-
("y", data.y),
102-
(Efeat.INDEX, data.edge_index),
103-
("edge_attr", data.edge_attr),
104-
(Nfeat.POSITION, data.pos),
105-
("cluster0", data.cluster0),
106-
("cluster1", data.cluster1),
107-
):
108-
if data_tensor is not None:
109-
assert data_tensor.is_cuda, f"data.{name} is not cuda"
110-
11198
with warnings.catch_warnings(record=UserWarning):
11299
trainer.train(
113100
nepoch=3,
@@ -774,6 +761,37 @@ def test_test_method_pretrained_model_on_dataset_without_target(self) -> None:
774761
assert output.target.unique().tolist()[0] is None
775762
assert output.loss.unique().tolist()[0] is None
776763

764+
def test_graph_save_and_load_model(self) -> None:
765+
test_data_graph = "tests/data/hdf5/test.hdf5"
766+
n = 10
767+
features_transform = {
768+
Nfeat.RESTYPE: {"transform": lambda x: x / 2, "standardize": True},
769+
Nfeat.BSA: {"transform": None, "standardize": False},
770+
}
771+
772+
dataset = GraphDataset(
773+
hdf5_path=test_data_graph,
774+
node_features=[Nfeat.RESTYPE, Nfeat.POLARITY, Nfeat.BSA],
775+
target=targets.BINARY,
776+
task=targets.CLASSIF,
777+
features_transform=features_transform,
778+
)
779+
trainer = Trainer(NaiveNetwork, dataset)
780+
# during the training the model is saved
781+
trainer.train(nepoch=2, batch_size=2, filename=self.save_path)
782+
assert trainer.features_transform == features_transform
783+
784+
# load the model into a new GraphDataset instance
785+
dataset_test = GraphDataset(
786+
hdf5_path="tests/data/hdf5/test.hdf5",
787+
train_source=self.save_path,
788+
)
789+
790+
# Check if the features_transform is correctly loaded from the saved model
791+
assert dataset_test.features_transform[Nfeat.RESTYPE]["transform"](n) == n / 2 # the only way to test the transform in this case is to apply it
792+
assert dataset_test.features_transform[Nfeat.RESTYPE]["standardize"] == features_transform[Nfeat.RESTYPE]["standardize"]
793+
assert dataset_test.features_transform[Nfeat.BSA] == features_transform[Nfeat.BSA]
794+
777795

778796
if __name__ == "__main__":
779797
unittest.main()

0 commit comments

Comments
 (0)