Commit
add integration test for nan losses cases
gcroci2 committed Dec 20, 2023
1 parent 50bbe82 commit b6b6be7
Showing 1 changed file with 79 additions and 0 deletions.
79 changes: 79 additions & 0 deletions tests/test_integration.py
@@ -1,16 +1,21 @@
import glob
import os
import warnings
from shutil import rmtree
from tempfile import mkdtemp

import h5py
import pandas as pd
import pytest
import torch

from deeprank2.dataset import GraphDataset, GridDataset
from deeprank2.domain import edgestorage as Efeat
from deeprank2.domain import nodestorage as Nfeat
from deeprank2.domain import targetstorage as targets
from deeprank2.neuralnets.cnn.model3d import CnnClassification
from deeprank2.neuralnets.gnn.ginet import GINet
from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
from deeprank2.query import (ProteinProteinInterfaceResidueQuery,
                             QueryCollection)
from deeprank2.tools.target import compute_ppi_scores
@@ -196,3 +201,77 @@ def test_gnn(): # pylint: disable=too-many-locals
    finally:
        rmtree(hdf5_directory)
        rmtree(output_directory)

@pytest.fixture(scope='session')
def hdf5_files_for_nan(tmpdir_factory):
    # For testing cases in which the loss function is NaN for the validation and/or
    # the training set. It doesn't matter whether the dataset is a GraphDataset or a
    # GridDataset, since the NaN-loss warning is functionality of the trainer module,
    # which does not depend on the dataset type. The settings and parameters below
    # have been carefully chosen to result in NaN losses.
    pdb_paths = [
        "tests/data/pdb/3C8P/3C8P.pdb",
        "tests/data/pdb/1A0Z/1A0Z.pdb",
        "tests/data/pdb/1ATN/1ATN_1w.pdb"
    ]
    chain_id1 = "A"
    chain_id2 = "B"
    targets_values = [0, 1, 1]
    prefix = os.path.join(tmpdir_factory.mktemp("data"), "test-queries-process")

    queries = QueryCollection()
    for idx, pdb_path in enumerate(pdb_paths):
        query = ProteinProteinInterfaceResidueQuery(
            pdb_path,
            chain_id1,
            chain_id2,
            # A very low cutoff distance helps prevent the network from learning
            distance_cutoff=3,
            targets={targets.BINARY: targets_values[idx]}
        )
        queries.add(query)

    hdf5_paths = queries.process(prefix=prefix)
    return hdf5_paths

@pytest.mark.parametrize("validate, best_model", [(True, True), (False, True), (False, False), (True, False)])
def test_nan_loss_cases(validate, best_model, hdf5_files_for_nan):
    mols = []
    for fname in hdf5_files_for_nan:
        with h5py.File(fname, 'r') as hdf5:
            for mol in hdf5.keys():
                mols.append(mol)

    dataset_train = GraphDataset(
        hdf5_path=hdf5_files_for_nan,
        subset=mols[1:],
        target=targets.BINARY,
        task=targets.CLASSIF
    )
    dataset_valid = GraphDataset(
        hdf5_path=hdf5_files_for_nan,
        subset=[mols[0]],
        dataset_train=dataset_train,
        train=False
    )

    trainer = Trainer(
        NaiveNetwork,
        dataset_train,
        dataset_valid)

    # An extreme learning rate and weight decay make the weights explode,
    # so the losses overflow to NaN within a few epochs.
    optimizer = torch.optim.SGD
    lr = 10000
    weight_decay = 10000

    trainer.configure_optimizers(optimizer, lr, weight_decay=weight_decay)
    w_msg = "A model has been saved but the validation and/or the training losses were NaN;" + \
        "\n\ttry to increase the cutoff distance during the data processing or the number of data points " + \
        "during the training."
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
        trainer.train(
            nepoch=5, batch_size=1, validate=validate,
            best_model=best_model, filename='test.pth.tar')
        assert len(w) == 1
        assert issubclass(w[-1].category, UserWarning)
        assert w_msg in str(w[-1].message)
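
Why these settings produce NaN losses: with plain SGD, weight decay adds weight_decay * w to each gradient, so one step scales the weights by roughly 1 - lr * weight_decay; with both set to 10000 that factor is about -10^8, and the weights overflow float32 within a handful of updates. A minimal standalone sketch of the same effect, not part of the commit and using a plain torch.nn.Linear stand-in rather than one of deeprank2's networks:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=10000, weight_decay=10000)
loss_fn = torch.nn.CrossEntropyLoss()

x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

for step in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    print(step, loss.item())  # the loss typically overflows to inf and then nan within a few steps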
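
The assertion block expects exactly one recorded UserWarning whose message matches the trainer's NaN warning. A hypothetical sketch, not deeprank2's actual implementation, of the check a trainer could run after saving a model to emit such a warning; warn_if_nan_losses, its signature, and its call site are illustrative assumptions:

import math
import warnings

def warn_if_nan_losses(train_loss, valid_loss=None):
    # Hypothetical helper: emit the UserWarning the test asserts on
    # whenever a saved model's training or validation loss is NaN.
    losses = [train_loss] if valid_loss is None else [train_loss, valid_loss]
    if any(math.isnan(loss) for loss in losses):
        warnings.warn(
            "A model has been saved but the validation and/or the training losses were NaN;"
            "\n\ttry to increase the cutoff distance during the data processing or the number of data points "
            "during the training.")

Running pytest tests/test_integration.py -k nan_loss exercises all four parametrized (validate, best_model) combinations against the session-scoped fixture.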
