Skip to content

Commit 8be109c

Browse files
Project-MONAI#413 add a function to calculate task params for DynUNet pipeline (Project-MONAI#414)
* add a function to calculate task params Signed-off-by: Yiheng Wang <[email protected]> * update readme Signed-off-by: Yiheng Wang <[email protected]> * update readme Signed-off-by: Yiheng Wang <[email protected]>
1 parent 34dbe15 commit 8be109c

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

modules/dynunet_pipeline/README.md

+8
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ My running environment:
1414

1515
To prevent the inconsistency, all json files are included in `config/` already.
1616

17+
## Task parameters
18+
All task specific parameters refer from `nnUNet` pipeline (as shown in `task_params.py`).
19+
During these parameters, `patch_size`, `batch_size` and `deep_supr_num` are achieved from some heuristic rules, and this part is not covered in this pipeline. `spacing`, `clip_values` and `normalize_values` are based on the statistics of the dataset (only on training set). Here we provide `calculate_task_params.py` and you can try to calculate them. The command is like:
20+
21+
```
22+
python calculate_task_params.py -task_id 09
23+
```
24+
1725
## Training
1826
Please run `train.py` for training. Please modify the command line arguments according
1927
to the actual situation, such as `determinism_flag` for deterministic training, `amp` for automatic mixed precision.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import os
2+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
3+
4+
from monai.data import (
5+
Dataset,
6+
DatasetSummary,
7+
load_decathlon_datalist,
8+
load_decathlon_properties,
9+
)
10+
from monai.transforms import LoadImaged
11+
12+
from task_params import task_name
13+
14+
15+
def get_task_params(args):
16+
"""
17+
This function is used to achieve the spacings of decathlon dataset.
18+
In addition, for CT images (task 03, 06, 07, 08, 09 and 10), this function
19+
also prints the mean and std values (used for normalization), and the min (0.5 percentile)
20+
and max(99.5 percentile) values (used for clip).
21+
22+
"""
23+
task_id = args.task_id
24+
root_dir = args.root_dir
25+
datalist_path = args.datalist_path
26+
dataset_path = os.path.join(root_dir, task_name[task_id])
27+
datalist_name = "dataset_task{}.json".format(task_id)
28+
29+
# get all training data
30+
datalist = load_decathlon_datalist(
31+
os.path.join(datalist_path, datalist_name), True, "training", dataset_path
32+
)
33+
34+
# get modality info.
35+
properties = load_decathlon_properties(
36+
os.path.join(datalist_path, datalist_name), "modality"
37+
)
38+
39+
dataset = Dataset(
40+
data=datalist,
41+
transform=LoadImaged(keys=["image", "label"]),
42+
)
43+
44+
calculator = DatasetSummary(dataset, num_workers=4)
45+
target_spacing = calculator.get_target_spacing()
46+
print("spacing: ", target_spacing)
47+
if properties["modality"]["0"] == "CT":
48+
print("CT input, calculate statistics:")
49+
calculator.calculate_statistics()
50+
print("mean: ", calculator.data_mean, " std: ", calculator.data_std)
51+
calculator.calculate_percentiles(
52+
sampling_flag=True, interval=10, min_percentile=0.5, max_percentile=99.5
53+
)
54+
print(
55+
"min: ",
56+
calculator.data_min_percentile,
57+
" max: ",
58+
calculator.data_max_percentile,
59+
)
60+
else:
61+
print("non CT input, skip calculating.")
62+
63+
64+
if __name__ == "__main__":
65+
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
66+
parser.add_argument(
67+
"-task_id", "--task_id", type=str, default="04", help="task 01 to 10"
68+
)
69+
parser.add_argument(
70+
"-root_dir",
71+
"--root_dir",
72+
type=str,
73+
default="/workspace/data/medical/",
74+
help="dataset path",
75+
)
76+
parser.add_argument(
77+
"-datalist_path",
78+
"--datalist_path",
79+
type=str,
80+
default="config/",
81+
)
82+
83+
args = parser.parse_args()
84+
get_task_params(args)

0 commit comments

Comments
 (0)