|
| 1 | +import glob |
1 | 2 | import os
|
2 | 3 | import warnings
|
3 | 4 | from shutil import rmtree
|
4 | 5 | from tempfile import mkdtemp
|
5 | 6 |
|
6 | 7 | import h5py
|
| 8 | +import pandas as pd |
| 9 | +import pytest |
| 10 | +import torch |
7 | 11 |
|
8 | 12 | from deeprank2.dataset import GraphDataset, GridDataset
|
9 | 13 | from deeprank2.domain import edgestorage as Efeat
|
10 | 14 | from deeprank2.domain import nodestorage as Nfeat
|
11 | 15 | from deeprank2.domain import targetstorage as targets
|
12 | 16 | from deeprank2.neuralnets.cnn.model3d import CnnClassification
|
13 | 17 | from deeprank2.neuralnets.gnn.ginet import GINet
|
| 18 | +from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork |
14 | 19 | from deeprank2.query import (ProteinProteinInterfaceResidueQuery,
|
15 | 20 | QueryCollection)
|
16 | 21 | from deeprank2.tools.target import compute_ppi_scores
|
@@ -196,3 +201,77 @@ def test_gnn(): # pylint: disable=too-many-locals
|
196 | 201 | finally:
|
197 | 202 | rmtree(hdf5_directory)
|
198 | 203 | rmtree(output_directory)
|
| 204 | + |
| 205 | +@pytest.fixture(scope='session') |
| 206 | +def hdf5_files_for_nan(tmpdir_factory): |
| 207 | + # For testing cases in which the loss function is nan for the validation and/or for |
| 208 | + # the training sets. It doesn't matter if the dataset is a GraphDataset or a GridDataset, |
| 209 | + # since it is a functionality of the trainer module, which does not depend on the dataset type. |
| 210 | + # The settings and the parameters have been carefully chosen to result in nan losses. |
| 211 | + pdb_paths = [ |
| 212 | + "tests/data/pdb/3C8P/3C8P.pdb", |
| 213 | + "tests/data/pdb/1A0Z/1A0Z.pdb", |
| 214 | + "tests/data/pdb/1ATN/1ATN_1w.pdb" |
| 215 | + ] |
| 216 | + chain_id1 = "A" |
| 217 | + chain_id2 = "B" |
| 218 | + targets_values = [0, 1, 1] |
| 219 | + prefix = os.path.join(tmpdir_factory.mktemp("data"), "test-queries-process") |
| 220 | + |
| 221 | + queries = QueryCollection() |
| 222 | + for idx, pdb_path in enumerate(pdb_paths): |
| 223 | + query = ProteinProteinInterfaceResidueQuery( |
| 224 | + pdb_path, |
| 225 | + chain_id1, |
| 226 | + chain_id2, |
| 227 | + # A very low cutoff distance helps for not making the network to learn |
| 228 | + distance_cutoff=3, |
| 229 | + targets = {targets.BINARY: targets_values[idx]} |
| 230 | + ) |
| 231 | + queries.add(query) |
| 232 | + |
| 233 | + hdf5_paths = queries.process(prefix = prefix) |
| 234 | + return hdf5_paths |
| 235 | + |
| 236 | +@pytest.mark.parametrize("validate, best_model", [(True, True), (False, True), (False, False), (True, False)]) |
| 237 | +def test_nan_loss_cases(validate, best_model, hdf5_files_for_nan): |
| 238 | + mols = [] |
| 239 | + for fname in hdf5_files_for_nan: |
| 240 | + with h5py.File(fname, 'r') as hdf5: |
| 241 | + for mol in hdf5.keys(): |
| 242 | + mols.append(mol) |
| 243 | + |
| 244 | + dataset_train = GraphDataset( |
| 245 | + hdf5_path = hdf5_files_for_nan, |
| 246 | + subset = mols[1:], |
| 247 | + target = targets.BINARY, |
| 248 | + task = targets.CLASSIF |
| 249 | + ) |
| 250 | + dataset_valid = GraphDataset( |
| 251 | + hdf5_path = hdf5_files_for_nan, |
| 252 | + subset = [mols[0]], |
| 253 | + dataset_train=dataset_train, |
| 254 | + train=False |
| 255 | + ) |
| 256 | + |
| 257 | + trainer = Trainer( |
| 258 | + NaiveNetwork, |
| 259 | + dataset_train, |
| 260 | + dataset_valid) |
| 261 | + |
| 262 | + optimizer = torch.optim.SGD |
| 263 | + lr = 10000 |
| 264 | + weight_decay = 10000 |
| 265 | + |
| 266 | + trainer.configure_optimizers(optimizer, lr, weight_decay=weight_decay) |
| 267 | + w_msg = "A model has been saved but the validation and/or the training losses were NaN;" + \ |
| 268 | + "\n\ttry to increase the cutoff distance during the data processing or the number of data points " + \ |
| 269 | + "during the training." |
| 270 | + with warnings.catch_warnings(record=True) as w: |
| 271 | + warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning) |
| 272 | + trainer.train( |
| 273 | + nepoch=5, batch_size=1, validate=validate, |
| 274 | + best_model=best_model, filename='test.pth.tar') |
| 275 | + assert len(w) == 1 |
| 276 | + assert issubclass(w[-1].category, UserWarning) |
| 277 | + assert w_msg in str(w[-1].message) |
0 commit comments