Skip to content

Commit

Permalink
jdekraker edits: untar nnunet model
Browse files Browse the repository at this point in the history
moved from jdekraker PR #353
  • Loading branch information
akhanf committed Jan 22, 2025
1 parent ccf2b76 commit db0de78
Showing 1 changed file with 33 additions and 31 deletions.
64 changes: 33 additions & 31 deletions hippunfold/workflow/rules/nnunet.smk
Original file line number Diff line number Diff line change
Expand Up @@ -42,33 +42,6 @@ def get_nnunet_input(wildcards):
raise ValueError("modality not supported for nnunet!")


def get_model_tar():
if config["force_nnunet_model"]:
model_name = config["force_nnunet_model"]
else:
model_name = config["modality"]

local_tar = config["resource_urls"]["nnunet_model"].get(model_name, None)
if local_tar == None:
print(f"ERROR: {model_name} does not exist in nnunet_model in the config file")

return (Path(download_dir) / "model" / Path(local_tar).name).absolute()


rule download_nnunet_model:
params:
url=config["resource_urls"]["nnunet_model"][config["force_nnunet_model"]]
if config["force_nnunet_model"]
else config["resource_urls"]["nnunet_model"][config["modality"]],
model_dir=Path(download_dir) / "model",
output:
model_tar=get_model_tar(),
container:
config["singularity"]["autotop"]
shell:
"mkdir -p {params.model_dir} && wget https://{params.url} -O {output}"


def parse_task_from_tar(wildcards, input):
match = re.search("Task[0-9]{3}_[\w]+", input.model_tar)
if match:
Expand Down Expand Up @@ -109,6 +82,34 @@ def get_cmd_copy_inputs(wildcards, input):
return " && ".join(cmd)


def get_model_tar():
if config["force_nnunet_model"]:
model_name = config["force_nnunet_model"]
else:
model_name = config["modality"]

local_tar = config["resource_urls"]["nnunet_model"].get(model_name, None)
if local_tar == None:
print(f"ERROR: {model_name} does not exist in nnunet_model in the config file")

return (Path(download_dir) / "model" / Path(local_tar).name).absolute()


rule download_extract_nnunet_model:
params:
url=config["resource_urls"]["nnunet_model"][config["force_nnunet_model"]]
if config["force_nnunet_model"]
else config["resource_urls"]["nnunet_model"][config["modality"]],
model_dir=Path(download_dir) / "model",
output:
model_tar=get_model_tar(),
container:
config["singularity"]["autotop"]
shell:
"mkdir -p {params.model_dir} && wget https://{params.url} -O {output} && "
"tar -xf {output} -C {params.model_dir}"


rule run_inference:
""" This rule uses either GPU or CPU .
It also runs in an isolated folder (shadow), with symlinks to inputs in that folder, copying over outputs once complete, so temp files are not retained"""
Expand All @@ -118,7 +119,8 @@ rule run_inference:
params:
cmd_copy_inputs=get_cmd_copy_inputs,
temp_lbl="templbl/temp.nii.gz",
model_dir="tempmodel",
model_dir=Path(download_dir) / "model" / "nnUNet",
tmp_model_dir="tempmodel",
in_folder="tempimg",
out_folder="templbl",
task=parse_task_from_tar,
Expand Down Expand Up @@ -162,10 +164,10 @@ rule run_inference:
#set threads
# run inference
#copy from temp output folder to final output
"mkdir -p {params.model_dir} {params.in_folder} {params.out_folder} && "
"mkdir -p {params.tmp_model_dir} {params.in_folder} {params.out_folder} && "
"{params.cmd_copy_inputs} && "
"tar -xf {input.model_tar} -C {params.model_dir} && "
"export RESULTS_FOLDER={params.model_dir} && "
"ln -s {params.model_dir} {params.tmp_model_dir} && "
"export RESULTS_FOLDER={params.tmp_model_dir} && "
"export nnUNet_n_proc_DA={threads} && "
"nnUNet_predict -i {params.in_folder} -o {params.out_folder} -t {params.task} -chk {params.chkpnt} -tr {params.trainer} {params.tta} &> {log} && "
"cp {params.temp_lbl} {output.nnunet_seg}"
Expand Down

0 comments on commit db0de78

Please sign in to comment.