Skip to content

Commit 5320ba4

Browse files
remove unused args for inferrer (Project-MONAI#605)
Signed-off-by: Yiheng Wang <[email protected]>
1 parent 8f7939c commit 5320ba4

File tree

1 file changed

+8
-36
lines changed

1 file changed

+8
-36
lines changed

modules/dynunet_pipeline/inferrer.py

+8-36
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
import os
2-
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
2+
from typing import Any, Dict, Optional, Tuple, Union
33

44
import numpy as np
55
import torch
66
import torch.nn as nn
77
from ignite.engine import Engine
8-
from ignite.metrics import Metric
98
from monai.data import decollate_batch
109
from monai.data.nifti_writer import write_nifti
1110
from monai.engines import SupervisedEvaluator
12-
from monai.engines.utils import IterationEvents, default_prepare_batch
11+
from monai.engines.utils import IterationEvents
1312
from monai.inferers import Inferer
1413
from monai.networks.utils import eval_mode
15-
from monai.transforms import AsDiscrete, Transform
14+
from monai.transforms import AsDiscrete
1615
from torch.utils.data import DataLoader
1716

1817
from transforms import recovery_prediction
@@ -21,7 +20,11 @@
2120
class DynUNetInferrer(SupervisedEvaluator):
2221
"""
2322
This class inherits from SupervisedEvaluator in MONAI, and is used with DynUNet
24-
on Decathlon datasets.
23+
on Decathlon datasets. As a customized inferrer, some of the arguments from
24+
SupervisedEvaluator are not supported. For example, the actual
25+
post processing method used is hard coded in the `_iteration` function, thus the
26+
argument `postprocessing` from SupervisedEvaluator is not exist. If you need
27+
to change the post processing way, please modify the `_iteration` function directly.
2528
2629
Args:
2730
device: an object representing the device on which to run.
@@ -30,22 +33,7 @@ class DynUNetInferrer(SupervisedEvaluator):
3033
network: use the network to run model forward.
3134
output_dir: the path to save inferred outputs.
3235
num_classes: the number of classes (output channels) for the task.
33-
epoch_length: number of iterations for one epoch, default to
34-
`len(val_data_loader)`.
35-
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
36-
with respect to the host. For other cases, this argument has no effect.
37-
prepare_batch: function to parse image and label for current iteration.
38-
iteration_update: the callable function for every iteration, expect to accept `engine`
39-
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
4036
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
41-
postprocessing: execute additional transformation for the model output data.
42-
Typically, several Tensor based transforms composed by `Compose`.
43-
key_val_metric: compute metric when every iteration completed, and save average value to
44-
engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
45-
checkpoint into files.
46-
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
47-
val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
48-
CheckpointHandler, StatsHandler, SegmentationSaver, etc.
4937
amp: whether to enable auto-mixed-precision evaluation, default is False.
5038
tta_val: whether to do the 8 flips (8 = 2 ** 3, where 3 represents the three dimensions)
5139
test time augmentation, default is False.
@@ -59,31 +47,15 @@ def __init__(
5947
network: torch.nn.Module,
6048
output_dir: str,
6149
num_classes: Union[str, int],
62-
epoch_length: Optional[int] = None,
63-
non_blocking: bool = False,
64-
prepare_batch: Callable = default_prepare_batch,
65-
iteration_update: Optional[Callable] = None,
6650
inferer: Optional[Inferer] = None,
67-
postprocessing: Optional[Transform] = None,
68-
key_val_metric: Optional[Dict[str, Metric]] = None,
69-
additional_metrics: Optional[Dict[str, Metric]] = None,
70-
val_handlers: Optional[Sequence] = None,
7151
amp: bool = False,
7252
tta_val: bool = False,
7353
) -> None:
7454
super().__init__(
7555
device=device,
7656
val_data_loader=val_data_loader,
7757
network=network,
78-
epoch_length=epoch_length,
79-
non_blocking=non_blocking,
80-
prepare_batch=prepare_batch,
81-
iteration_update=iteration_update,
8258
inferer=inferer,
83-
postprocessing=postprocessing,
84-
key_val_metric=key_val_metric,
85-
additional_metrics=additional_metrics,
86-
val_handlers=val_handlers,
8759
amp=amp,
8860
)
8961

0 commit comments

Comments
 (0)