1
1
import os
2
- from typing import Any , Callable , Dict , Optional , Sequence , Tuple , Union
2
+ from typing import Any , Dict , Optional , Tuple , Union
3
3
4
4
import numpy as np
5
5
import torch
6
6
import torch .nn as nn
7
7
from ignite .engine import Engine
8
- from ignite .metrics import Metric
9
8
from monai .data import decollate_batch
10
9
from monai .data .nifti_writer import write_nifti
11
10
from monai .engines import SupervisedEvaluator
12
- from monai .engines .utils import IterationEvents , default_prepare_batch
11
+ from monai .engines .utils import IterationEvents
13
12
from monai .inferers import Inferer
14
13
from monai .networks .utils import eval_mode
15
- from monai .transforms import AsDiscrete , Transform
14
+ from monai .transforms import AsDiscrete
16
15
from torch .utils .data import DataLoader
17
16
18
17
from transforms import recovery_prediction
21
20
class DynUNetInferrer (SupervisedEvaluator ):
22
21
"""
23
22
This class inherits from SupervisedEvaluator in MONAI, and is used with DynUNet
24
- on Decathlon datasets.
23
+ on Decathlon datasets. As a customized inferrer, some of the arguments from
24
+ SupervisedEvaluator are not supported. For example, the actual
25
+ post processing method used is hard coded in the `_iteration` function, thus the
26
+ argument `postprocessing` from SupervisedEvaluator is not exist. If you need
27
+ to change the post processing way, please modify the `_iteration` function directly.
25
28
26
29
Args:
27
30
device: an object representing the device on which to run.
@@ -30,22 +33,7 @@ class DynUNetInferrer(SupervisedEvaluator):
30
33
network: use the network to run model forward.
31
34
output_dir: the path to save inferred outputs.
32
35
num_classes: the number of classes (output channels) for the task.
33
- epoch_length: number of iterations for one epoch, default to
34
- `len(val_data_loader)`.
35
- non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
36
- with respect to the host. For other cases, this argument has no effect.
37
- prepare_batch: function to parse image and label for current iteration.
38
- iteration_update: the callable function for every iteration, expect to accept `engine`
39
- and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
40
36
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
41
- postprocessing: execute additional transformation for the model output data.
42
- Typically, several Tensor based transforms composed by `Compose`.
43
- key_val_metric: compute metric when every iteration completed, and save average value to
44
- engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
45
- checkpoint into files.
46
- additional_metrics: more Ignite metrics that also attach to Ignite Engine.
47
- val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
48
- CheckpointHandler, StatsHandler, SegmentationSaver, etc.
49
37
amp: whether to enable auto-mixed-precision evaluation, default is False.
50
38
tta_val: whether to do the 8 flips (8 = 2 ** 3, where 3 represents the three dimensions)
51
39
test time augmentation, default is False.
@@ -59,31 +47,15 @@ def __init__(
59
47
network : torch .nn .Module ,
60
48
output_dir : str ,
61
49
num_classes : Union [str , int ],
62
- epoch_length : Optional [int ] = None ,
63
- non_blocking : bool = False ,
64
- prepare_batch : Callable = default_prepare_batch ,
65
- iteration_update : Optional [Callable ] = None ,
66
50
inferer : Optional [Inferer ] = None ,
67
- postprocessing : Optional [Transform ] = None ,
68
- key_val_metric : Optional [Dict [str , Metric ]] = None ,
69
- additional_metrics : Optional [Dict [str , Metric ]] = None ,
70
- val_handlers : Optional [Sequence ] = None ,
71
51
amp : bool = False ,
72
52
tta_val : bool = False ,
73
53
) -> None :
74
54
super ().__init__ (
75
55
device = device ,
76
56
val_data_loader = val_data_loader ,
77
57
network = network ,
78
- epoch_length = epoch_length ,
79
- non_blocking = non_blocking ,
80
- prepare_batch = prepare_batch ,
81
- iteration_update = iteration_update ,
82
58
inferer = inferer ,
83
- postprocessing = postprocessing ,
84
- key_val_metric = key_val_metric ,
85
- additional_metrics = additional_metrics ,
86
- val_handlers = val_handlers ,
87
59
amp = amp ,
88
60
)
89
61
0 commit comments