Skip to content

Commit

Permalink
feat: Add flag to set TRT fallback behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
effigies committed Jan 29, 2025
1 parent 5cc8d56 commit a535ac2
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 2 deletions.
20 changes: 20 additions & 0 deletions fmriprep/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ def _slice_time_ref(value, parser):
raise parser.error(f'Slice time reference must be in range 0-1. Received {value}.')
return value

def _fallback_trt(value, parser):
if value == 'estimated':
return value
try:
return float(value)
except ValueError:
raise parser.error(
f'Falling back to TRT must be a number or "estimated". Received {value}.'
) from None

verstr = f'fMRIPrep v{config.environment.version}'
currentv = Version(config.environment.version)
is_release = not any((currentv.is_devrelease, currentv.is_prerelease, currentv.is_postrelease))
Expand All @@ -165,6 +175,7 @@ def _slice_time_ref(value, parser):
PositiveInt = partial(_min_one, parser=parser)
BIDSFilter = partial(_bids_filter, parser=parser)
SliceTimeRef = partial(_slice_time_ref, parser=parser)
FallbackTRT = partial(_fallback_trt, parser=parser)

# Arguments as specified by BIDS-Apps
# required, positional arguments
Expand Down Expand Up @@ -423,6 +434,15 @@ def _slice_time_ref(value, parser):
type=int,
help='Number of nonsteady-state volumes. Overrides automatic detection.',
)
g_conf.add_argument(
'--fallback-total-readout-time',
required=False,
action='store',
default=None,
type=FallbackTRT,
help='Fallback value for Total Readout Time (TRT) calculation. '
'May be a number or "estimated".',
)
g_conf.add_argument(
'--random-seed',
dest='_random_seed',
Expand Down
3 changes: 3 additions & 0 deletions fmriprep/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,9 @@ class workflow(_Config):
"""Remove the mean from fieldmaps."""
force_syn = None
"""Run *fieldmap-less* susceptibility-derived distortions estimation."""
fallback_total_readout_time = None
"""Infer the total readout time if unavailable from authoritative metadata.
This may be a number or the string "estimated"."""
hires = None
"""Run FreeSurfer ``recon-all`` with the ``-hires`` flag."""
fs_no_resume = None
Expand Down
9 changes: 9 additions & 0 deletions fmriprep/interfaces/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ def _run_interface(self, runtime):
class DistortionParametersInputSpec(TraitedSpec):
in_file = File(exists=True, desc='EPI image corresponding to the metadata')
metadata = traits.Dict(mandatory=True, desc='metadata corresponding to the inputs')
fallback = traits.Either(
None,
'estimated',
traits.Float,
usedefault=True,
desc='Fallback value for missing metadata',
)


class DistortionParametersOutputSpec(TraitedSpec):
Expand All @@ -208,6 +215,8 @@ def _run_interface(self, runtime):
self._results['readout_time'] = get_trt(
self.inputs.metadata,
self.inputs.in_file or None,
use_estimate=self.inputs.fallback == 'estimated',
fallback=self.inputs.fallback if isinstance(self.inputs.fallback, float) else None,
)
self._results['pe_direction'] = self.inputs.metadata['PhaseEncodingDirection']
except (KeyError, ValueError):
Expand Down
6 changes: 5 additions & 1 deletion fmriprep/workflows/bold/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def init_bold_volumetric_resample_wf(
metadata: dict,
mem_gb: dict[str, float],
jacobian: bool,
fallback_total_readout_time: str | float | None = None,
fieldmap_id: str | None = None,
omp_nthreads: int = 1,
name: str = 'bold_volumetric_resample_wf',
Expand Down Expand Up @@ -161,7 +162,10 @@ def init_bold_volumetric_resample_wf(
run_without_submitting=True,
)
distortion_params = pe.Node(
DistortionParameters(metadata=metadata),
DistortionParameters(
metadata=metadata,
fallback=fallback_total_readout_time,
),
name='distortion_params',
run_without_submitting=True,
)
Expand Down
1 change: 1 addition & 0 deletions fmriprep/workflows/bold/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def init_bold_wf(
# Resample to anatomical space
bold_anat_wf = init_bold_volumetric_resample_wf(
metadata=all_metadata[0],
fallback_total_readout_time=config.workflow.fallback_total_readout_time,
fieldmap_id=fieldmap_id if not multiecho else None,
omp_nthreads=omp_nthreads,
mem_gb=mem_gb,
Expand Down
6 changes: 5 additions & 1 deletion fmriprep/workflows/bold/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,11 @@ def init_bold_native_wf(
)

distortion_params = pe.Node(
DistortionParameters(metadata=metadata, in_file=bold_file),
DistortionParameters(
metadata=metadata,
in_file=bold_file,
fallback=config.worfklow.fallback_total_readout_time,
),
name='distortion_params',
run_without_submitting=True,
)
Expand Down

0 comments on commit a535ac2

Please sign in to comment.