Commit b6b6be7

add integration test for nan losses cases
1 parent 50bbe82 commit b6b6be7

File tree

1 file changed: +79 −0 lines changed


tests/test_integration.py

Lines changed: 79 additions & 0 deletions
@@ -1,16 +1,21 @@
+import glob
 import os
 import warnings
 from shutil import rmtree
 from tempfile import mkdtemp
 
 import h5py
+import pandas as pd
+import pytest
+import torch
 
 from deeprank2.dataset import GraphDataset, GridDataset
 from deeprank2.domain import edgestorage as Efeat
 from deeprank2.domain import nodestorage as Nfeat
 from deeprank2.domain import targetstorage as targets
 from deeprank2.neuralnets.cnn.model3d import CnnClassification
 from deeprank2.neuralnets.gnn.ginet import GINet
+from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork
 from deeprank2.query import (ProteinProteinInterfaceResidueQuery,
                              QueryCollection)
 from deeprank2.tools.target import compute_ppi_scores
@@ -196,3 +201,77 @@ def test_gnn(): # pylint: disable=too-many-locals
     finally:
         rmtree(hdf5_directory)
         rmtree(output_directory)
+
+@pytest.fixture(scope='session')
+def hdf5_files_for_nan(tmpdir_factory):
+    # For testing cases in which the loss function is NaN for the validation and/or
+    # the training set. It doesn't matter whether the dataset is a GraphDataset or a
+    # GridDataset, since this is trainer-module functionality that does not depend on
+    # the dataset type. The settings and parameters are chosen to produce NaN losses.
+    pdb_paths = [
+        "tests/data/pdb/3C8P/3C8P.pdb",
+        "tests/data/pdb/1A0Z/1A0Z.pdb",
+        "tests/data/pdb/1ATN/1ATN_1w.pdb"
+    ]
+    chain_id1 = "A"
+    chain_id2 = "B"
+    targets_values = [0, 1, 1]
+    prefix = os.path.join(tmpdir_factory.mktemp("data"), "test-queries-process")
+
+    queries = QueryCollection()
+    for idx, pdb_path in enumerate(pdb_paths):
+        query = ProteinProteinInterfaceResidueQuery(
+            pdb_path,
+            chain_id1,
+            chain_id2,
+            # A very low cutoff distance helps to prevent the network from learning
+            distance_cutoff=3,
+            targets={targets.BINARY: targets_values[idx]}
+        )
+        queries.add(query)
+
+    hdf5_paths = queries.process(prefix=prefix)
+    return hdf5_paths
+
+@pytest.mark.parametrize("validate, best_model", [(True, True), (False, True), (False, False), (True, False)])
+def test_nan_loss_cases(validate, best_model, hdf5_files_for_nan):
+    mols = []
+    for fname in hdf5_files_for_nan:
+        with h5py.File(fname, 'r') as hdf5:
+            for mol in hdf5.keys():
+                mols.append(mol)
+
+    dataset_train = GraphDataset(
+        hdf5_path=hdf5_files_for_nan,
+        subset=mols[1:],
+        target=targets.BINARY,
+        task=targets.CLASSIF
+    )
+    dataset_valid = GraphDataset(
+        hdf5_path=hdf5_files_for_nan,
+        subset=[mols[0]],
+        dataset_train=dataset_train,
+        train=False
+    )
+
+    trainer = Trainer(
+        NaiveNetwork,
+        dataset_train,
+        dataset_valid)
+
+    optimizer = torch.optim.SGD
+    lr = 10000
+    weight_decay = 10000
+
+    trainer.configure_optimizers(optimizer, lr, weight_decay=weight_decay)
+    w_msg = "A model has been saved but the validation and/or the training losses were NaN;" + \
+        "\n\ttry to increase the cutoff distance during the data processing or the number of data points " + \
+        "during the training."
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
+        trainer.train(
+            nepoch=5, batch_size=1, validate=validate,
+            best_model=best_model, filename='test.pth.tar')
+        assert len(w) == 1
+        assert issubclass(w[-1].category, UserWarning)
+        assert w_msg in str(w[-1].message)
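Why these hyperparameters trigger NaN losses: with SGD, a learning rate and weight decay of 10000 make the parameter updates so large that the logits overflow and the cross-entropy loss becomes inf/NaN within a few steps. A minimal, self-contained sketch of the effect (the toy model and random data below are illustrative only, not part of deeprank2):

import torch

# Toy binary classifier with the same deliberately explosive hyperparameters.
model = torch.nn.Linear(4, 2)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=10000, weight_decay=10000)

x = torch.randn(8, 4)          # random inputs
y = torch.randint(0, 2, (8,))  # random binary labels

for step in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    # The loss typically overflows to inf and then NaN within a step or two.
    print(step, loss.item())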

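Each of the four parametrized cases asserts that exactly one UserWarning carrying the message above is recorded. As a rough illustration of the behavior under test (a hypothetical sketch, not the actual deeprank2.trainer code; the function name and arguments are invented for this example):

import math
import warnings

def save_model_with_nan_check(train_loss, valid_loss, save_fn):
    # Hypothetical helper mirroring the asserted behavior: the model is still
    # saved, but a UserWarning (the default warnings.warn category) is emitted
    # when the training and/or validation loss is NaN.
    if math.isnan(train_loss) or (valid_loss is not None and math.isnan(valid_loss)):
        warnings.warn(
            "A model has been saved but the validation and/or the training losses were NaN;"
            "\n\ttry to increase the cutoff distance during the data processing or the number "
            "of data points during the training.")
    save_fn()

Because warnings.catch_warnings(record=True) collects each emitted warning as an object exposing .category and .message, the test can check both the warning type and the exact message text, as the three assertions at the end do.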