ENH: Implicitly-pipelined and modality-agnostic Estimator #62

Merged: 14 commits into main from enh/estimator-refactor on Jan 24, 2025

Conversation

oesteban (Member):

Replaces our current implementation of the estimator with a new architecture that allows stacking (#12 (comment)):

```Python
estimator_level1 = Estimator(model="b0", ...)    # e.g., 6 dof registration
estimator_level2 = Estimator(model="b0", input=estimator_level1, ...) # e.g., 9 dof registration

estimator_level2.fit(dataset_object)  # Checks "input"; it is another estimator, so it runs first
```

This allows for adding *filters* (to be implemented):

```Python
downsample = Filter(...)  # e.g., downsampling filter
estimator_level1 = Estimator(model="b0", input=downsample, ...)    # e.g., 6 dof registration
estimator_level2 = Estimator(model="b0", input=estimator_level1, ...) # e.g., 9 dof registration

estimator_level2.fit(dataset_object)  # Checks "input"; it is another estimator, so it runs first
```

In this case, the filtered dataset only feeds ``estimator_level1``.
The second level will work on a full dataset.
If you want to interleave another downsampling filter:

```Python
downsample1 = Filter(...)  # e.g., downsampling filter 1
estimator_level1 = Estimator(model="b0", input=downsample1, ...)   # e.g., 6 dof registration
downsample2 = Filter(input=estimator_level1, ...)  # e.g., downsampling filter 2
estimator_level2 = Estimator(model="b0", input=downsample2, ...) # e.g., 9 dof registration

estimator_level2.fit(dataset_object)  # Checks "input" (a filter here) and runs the whole upstream chain first
```
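
For reference, a minimal sketch of how the implicit pipelining could be wired (hypothetical; the class body below is only illustrative and not the code merged in this PR):

```Python
class Estimator:
    """Sketch of an implicitly-pipelined step (a Filter would follow the same pattern)."""

    def __init__(self, model=None, input=None, **kwargs):
        self._model = model
        self._input = input  # may be another Estimator, a Filter, or None

    def fit(self, dataset, **kwargs):
        # Check "input": if it is another step, run it first, so calling
        # fit() on the last level unrolls the whole chain.
        if self._input is not None:
            self._input.fit(dataset, **kwargs)
        # ... then run this level's own model/registration on the dataset.
        return dataset
```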

Resolves: #12.
Resolves: #21.
@oesteban marked this pull request as draft January 21, 2025 08:55

codecov bot commented Jan 21, 2025

Codecov Report

Attention: Patch coverage is 64.35185% with 77 lines in your changes missing coverage. Please review.

Project coverage is 68.75%. Comparing base (d853cab) to head (6697c78).
Report is 15 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/nifreeze/model/dmri.py | 45.28% | 28 Missing and 1 partial ⚠️ |
| src/nifreeze/data/dmri.py | 65.30% | 16 Missing and 1 partial ⚠️ |
| src/nifreeze/model/base.py | 61.29% | 9 Missing and 3 partials ⚠️ |
| src/nifreeze/cli/run.py | 12.50% | 7 Missing ⚠️ |
| src/nifreeze/estimator.py | 87.75% | 3 Missing and 3 partials ⚠️ |
| src/nifreeze/registration/ants.py | 76.92% | 5 Missing and 1 partial ⚠️ |
Additional details and impacted files

```
@@            Coverage Diff             @@
##             main      #62      +/-   ##
==========================================
+ Coverage   68.24%   68.75%   +0.51%
==========================================
  Files          20       20
  Lines         995      957      -38
  Branches      130      121       -9
==========================================
- Hits          679      658      -21
+ Misses        263      254       -9
+ Partials       53       45       -8
```
☔ View full report in Codecov by Sentry.

@oesteban force-pushed the enh/estimator-refactor branch from 7a489e3 to 7322e93 on January 23, 2025 11:33
@oesteban force-pushed the enh/estimator-refactor branch from 0f32b1d to 97e61c3 on January 23, 2025 13:19
@oesteban marked this pull request as ready for review January 23, 2025 13:38
```diff
@@ -337,7 +337,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from nifreeze.model.base import AverageModel\n",
+    "from nifreeze.model.base import ExpectationModel\n",
```
oesteban (Member Author):

Average is too specific about the statistic used. If Expectation meets opposition, I'm happy to settle on whatever others think is best (or even revert to Average).

```diff
@@ -134,7 +134,7 @@ async def train_coro(
             moving_path,
             fixedmask_path=brainmask_path,
             **_kwargs,
-        )
+        ).cmdline
```
oesteban (Member Author):

The `generate_command` function has been modified to return the nipype interface, thereby unifying the generation of the antsRegistration command line in a single function.
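
A rough sketch of that pattern (hypothetical; the parameters shown are illustrative, and a real call would also carry the registration settings):

```Python
from nipype.interfaces.ants import Registration


def generate_command(fixed_path, moving_path, **kwargs):
    # Build and return the configured nipype interface rather than a string,
    # so all callers derive the antsRegistration invocation from one place.
    return Registration(
        fixed_image=str(fixed_path),
        moving_image=str(moving_path),
        **kwargs,
    )


# Callers that only need the command line ask the interface for it:
# cmdline = generate_command(fixed, moving).cmdline
```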

```diff
@@ -40,14 +40,19 @@ def main(argv=None) -> None:
     args = parse_args(argv)

     # Open the data with the given file path
-    dwi_dataset: DWI = DWI.from_filename(args.input_file)
+    dataset: BaseDataset = BaseDataset.from_filename(args.input_file)
```
oesteban (Member Author):

Generalize beyond DWI. In the end, we will most likely not expose the HDF5 dataset, but rather accept NIfTIs + metadata (e.g., bvecs/bvals) and build the appropriate object. For a future PR.

Comment on lines +42 to +64
```Python
DEFAULT_CLIP_PERCENTILE = 75
"""Upper percentile threshold for intensity clipping."""

DEFAULT_MIN_S0 = 1e-5
"""Minimum value when considering the :math:`S_{0}` DWI signal."""

DEFAULT_MAX_S0 = 1.0
"""Maximum value when considering the :math:`S_{0}` DWI signal."""

DEFAULT_LOWB_THRESHOLD = 50
"""The lower bound for the b-value so that the orientation is considered a DW volume."""

DEFAULT_HIGHB_THRESHOLD = 8000
"""A b-value cap for DWI data."""

DEFAULT_NUM_BINS = 15
"""Number of bins to classify b-values."""

DEFAULT_MULTISHELL_BIN_COUNT_THR = 7
"""Default bin count to consider a multishell scheme."""

DTI_MIN_ORIENTATIONS = 6
"""Minimum number of nonzero b-values in a DWI dataset."""
```
oesteban (Member Author):

These variables belong in the data representation object rather than in the models (perhaps the first two could be moved back to the model).

```diff
@@ -342,3 +366,87 @@ def load(
     dwi_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)

     return dwi_obj
+
+
+def find_shelling_scheme(
```
oesteban (Member Author):

This felt like a DWI-specific manipulation routine, so I moved it here.
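
For context, one way such a routine could classify b-values using the bin constants defined above (a hypothetical sketch; the thresholding logic and returned labels are assumptions, not the implementation in this PR):

```Python
import numpy as np


def find_shelling_scheme_sketch(bvals, num_bins=15, multishell_thr=7, lowb_thr=50):
    """Hypothetical: bin the diffusion-weighted b-values and count occupied bins."""
    bvals = np.asarray(bvals)
    highb = bvals[bvals > lowb_thr]  # drop b=0 / low-b volumes
    counts, _ = np.histogram(highb, bins=num_bins)
    n_shells = int(np.count_nonzero(counts))
    if n_shells == 1:
        return "single-shell"
    return "multi-shell" if n_shells <= multishell_thr else "non-shelled (DSI-like)"


# e.g., find_shelling_scheme_sketch([0, 1000, 1000, 2000, 3000]) -> "multi-shell"
```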

```diff
-    def predict(self, *args, **kwargs):
-        """Abstract member signature of predict()."""
-        raise NotImplementedError("Cannot call predict() on a BaseModel instance.")
+    def fit_predict(self, *_, **kwargs):
```
oesteban (Member Author):

We don't need to expose a separate fit/predict interface here; we will fit-and-predict every time. Stationary models (meaning those trained on the full dataset, e.g., to fine-tune hyperparameters such as a GPR's) will need to implement this logic.
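
For a stationary model, that could look roughly like this (a sketch; `_fit_full_dataset` and `_predict_volume` are hypothetical helpers, not names from this PR):

```Python
class StationaryModelSketch:
    """Hypothetical model trained once on the full dataset, then reused."""

    def __init__(self, dataset):
        self._dataset = dataset
        self._fitted = None

    def fit_predict(self, index, **kwargs):
        # Fit only on the first call (e.g., tuning GPR hyperparameters),
        # then reuse the fitted state for every leave-one-out prediction.
        if self._fitted is None:
            self._fitted = self._fit_full_dataset(**kwargs)
        return self._predict_volume(index)

    def _fit_full_dataset(self, **kwargs):
        # Placeholder for the expensive, dataset-wide fit.
        return {"hyperparameters": kwargs}

    def _predict_volume(self, index):
        # Placeholder for predicting the left-out volume at ``index``.
        return self._dataset[index]
```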

```Python
predicted
if predicted is not None
# Infer from dataset if not provided at initialization
else getattr(dataset, "reference", getattr(dataset, "bzero", None))
```
oesteban (Member Author):

Reference comes from thinking of, e.g., fMRI's SBRefs.


```Python
# Regress out global signal differences
```
oesteban (Member Author):

If detrending is necessary, the right way to do it in this case is with a filter.
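
Such a filter could be as small as the following (a hypothetical sketch in the spirit of the `Filter` objects proposed above; the class and method names are illustrative):

```Python
import numpy as np


class GlobalSignalDetrendFilter:
    """Hypothetical filter removing global signal differences across volumes."""

    def run(self, data):
        # data: 4D array (x, y, z, volumes). Rescale every volume so its
        # median matches the overall median, instead of regressing the
        # global signal out inside the model.
        medians = np.median(data, axis=(0, 1, 2))
        return data * (np.median(medians) / medians)
```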


```diff
         return retval


-class AverageDWIModel(BaseDWIModel):
+class AverageDWIModel(ExpectationModel):
```
oesteban (Member Author):

I held off on changing Average here until I get a sense of how you all receive "Expectation". The final merge should have consistent naming.

@oesteban force-pushed the enh/estimator-refactor branch from e5d6244 to ce14358 on January 23, 2025 14:26
@oesteban (Member Author):

This refactor opens the door to easily addressing #16 and lets us refrain from having all models derive from a parallelized base.

@effigies (Member) left a comment:

Some comments. As noted previously, I haven't used this, so I don't have a stake in the API.

@oesteban force-pushed the enh/estimator-refactor branch 2 times, most recently from ce6f3b2 to 668e240 on January 24, 2025 08:48
@oesteban force-pushed the enh/estimator-refactor branch from 89f31c7 to 24efee4 on January 24, 2025 11:08
@oesteban force-pushed the enh/estimator-refactor branch 4 times, most recently from 90f4a57 to 6697c78 on January 24, 2025 15:43
@oesteban (Member Author) commented Jan 24, 2025:

Thanks for the useful feedback @effigies :)

I think this is getting over the finish line. This can work as the initial implementation of this approach; we will need to iterate on this PR in the future.

I also think we have removed the dMRI bias for the most part. By passing the dataset structure to the models, they can query the data for the adequate metadata without exposing it to the estimator object, which should remain general.

There's lots to polish, but this PR will facilitate future work, I hope.

@oesteban merged commit 7237ecf into main Jan 24, 2025
11 of 12 checks passed
@oesteban deleted the enh/estimator-refactor branch January 24, 2025 16:03
@jhlegarreta (Contributor):

Sorry for my late review/reply here. Thanks for doing this.

As for #62 (comment), does `StatisticalMoment` or `Moment` fit better than `Expectation`?
