Skip to content

Commit 48705dd

Browse files
authored
Merge pull request #618 from DeepRank/617_clarify_output_col_gcroci2
docs: clarify usage of `output` column generated by the exporter
2 parents c7513c6 + c822e1b commit 48705dd

File tree

3 files changed

+28
-27
lines changed

3 files changed

+28
-27
lines changed

deeprank2/dataset.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -151,32 +151,37 @@ def _check_hdf5_files(self) -> None:
151151
self.hdf5_paths.remove(hdf5_path)
152152

153153
def _check_task_and_classes(self, task: str, classes: str | None = None) -> None:
154-
if self.target in [targets.IRMSD, targets.LRMSD, targets.FNAT, targets.DOCKQ]:
155-
self.task = targets.REGRESS
156-
157-
elif self.target in [targets.BINARY, targets.CAPRI]:
158-
self.task = targets.CLASSIF
159-
154+
# Determine the task based on the target or use the provided task
155+
if task is None:
156+
target_to_task_map = {
157+
targets.IRMSD: targets.REGRESS,
158+
targets.LRMSD: targets.REGRESS,
159+
targets.FNAT: targets.REGRESS,
160+
targets.DOCKQ: targets.REGRESS,
161+
targets.BINARY: targets.CLASSIF,
162+
targets.CAPRI: targets.CLASSIF,
163+
}
164+
self.task = target_to_task_map.get(self.target)
160165
else:
161166
self.task = task
162167

168+
# Validate the task
163169
if self.task not in [targets.CLASSIF, targets.REGRESS] and self.target is not None:
164170
msg = f"User target detected: {self.target} -> The task argument must be 'classif' or 'regress', currently set as {self.task}"
165171
raise ValueError(msg)
166172

167-
if task != self.task and task is not None:
173+
# Warn if the user-set task does not match the determined task
174+
if task and task != self.task:
168175
warnings.warn(
169-
f"Target {self.target} expects {self.task}, but was set to task {task} by user.\nUser set task is ignored and {self.task} will be used.",
176+
f"Target {self.target} expects {self.task}, but was set to task {task} by user. User set task is ignored and {self.task} will be used.",
170177
)
171178

179+
# Handle classification task
172180
if self.task == targets.CLASSIF:
173181
if classes is None:
174-
self.classes = [0, 1]
175-
_log.info(f"Target classes set to: {self.classes}")
176-
else:
177-
self.classes = classes
178-
182+
self.classes = [0, 1, 2, 3, 4, 5] if self.target == targets.CAPRI else [0, 1]
179183
self.classes_to_index = {class_: index for index, class_ in enumerate(self.classes)}
184+
_log.info(f"Target classes set to: {self.classes}")
180185
else:
181186
self.classes = None
182187
self.classes_to_index = None

docs/getstarted.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ output_test = pd.read_hdf(os.path.join("<output_folder_path>", "output_exporter.
391391

392392
The dataframes contain `phase`, `epoch`, `entry`, `output`, `target`, and `loss` columns, and can be easily used to visualize the results.
393393

394+
For classification tasks, the `output` column contains a list of probabilities that each class occurs, and each list sums to 1 (for more details, please see documentation on the [softmax function](https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html)). Note that the order of the classes in the list depends on the `classes` attribute of the DeeprankDataset instances. For classification tasks, if `classes` is not specified (as in this example case), it is defaulted to [0, 1].
395+
394396
Example for plotting training loss curves using [Plotly Express](https://plotly.com/python/plotly-express/):
395397

396398
```python

tutorials/training.ipynb

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -420,12 +420,8 @@
420420
"metadata": {},
421421
"outputs": [],
422422
"source": [
423-
"output_train = pd.read_hdf(\n",
424-
" os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"training\"\n",
425-
")\n",
426-
"output_test = pd.read_hdf(\n",
427-
" os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\"\n",
428-
")\n",
423+
"output_train = pd.read_hdf(os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"training\")\n",
424+
"output_test = pd.read_hdf(os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\")\n",
429425
"output_train.head()"
430426
]
431427
},
@@ -436,7 +432,9 @@
436432
"source": [
437433
"The dataframes contain `phase`, `epoch`, `entry`, `output`, `target`, and `loss` columns, and can be easily used to visualize the results.\n",
438434
"\n",
439-
"For example, the loss across the epochs can be plotted for the training and the validation sets:\n"
435+
"For classification tasks, the `output` column contains a list of probabilities that each class occurs, and each list sums to 1 (for more details, please see documentation on the [softmax function](https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html)). Note that the order of the classes in the list depends on the `classes` attribute of the DeeprankDataset instances. For classification tasks, if `classes` is not specified (as in this example case), it is defaulted to [0, 1].\n",
436+
"\n",
437+
"The loss across the epochs can be plotted for the training and the validation sets:\n"
440438
]
441439
},
442440
{
@@ -671,12 +669,8 @@
671669
"metadata": {},
672670
"outputs": [],
673671
"source": [
674-
"output_train = pd.read_hdf(\n",
675-
" os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"training\"\n",
676-
")\n",
677-
"output_test = pd.read_hdf(\n",
678-
" os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\"\n",
679-
")\n",
672+
"output_train = pd.read_hdf(os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"training\")\n",
673+
"output_test = pd.read_hdf(os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\")\n",
680674
"output_train.head()"
681675
]
682676
},
@@ -767,7 +761,7 @@
767761
"name": "python",
768762
"nbconvert_exporter": "python",
769763
"pygments_lexer": "ipython3",
770-
"version": "3.10.13"
764+
"version": "3.10.12"
771765
},
772766
"orig_nbformat": 4
773767
},

0 commit comments

Comments
 (0)