Skip to content

[WIP][ENH] TTM data pipeline: _dataset_class hook + TTMDataModule#2206

Draft
StrikerEureka34 wants to merge 2 commits intosktime:mainfrom
StrikerEureka34:Experiments/FM
Draft

[WIP][ENH] TTM data pipeline: _dataset_class hook + TTMDataModule#2206
StrikerEureka34 wants to merge 2 commits intosktime:mainfrom
StrikerEureka34:Experiments/FM

Conversation

@StrikerEureka34
Copy link

Reference Issues/PRs

WIP towards #2184. Related: #1959 #2051

What does this implement/fix? Explain your changes.

WIP draft for #2184, that is integrating IBM's TinyTimeMixer (TTM) into pytorch-forecasting's v2 pipeline.
TTM is a pretrained foundation model for zero-shot and fine-tuned time series forecasting, and it expects a specific input format that our current v2 pipeline doesn't produce.

This draft/WIP PR covers the data layer only. No model weights, no tsfm_public dependency yet, as the main aim of it was to just getting the pipeline to speak TTM's language.

Issues I ran into

TslibDataModule.setup() had four hardcoded _TslibDataset(...) calls.
So there was no clean way to swap in a different dataset class without duplicating the entire method.

Changes made

Added _dataset_class = _TslibDataset to TslibDataModule and replaced those four calls with self._dataset_class(...).
Which gives a zero behavioral change for existing code.

On top of that hook:

  • _TTMDataset now overrides __getitem__ to produce TTM-format batches:
    past_values, past_observed_mask, future_values, prediction_channel_indices
  • TTMDataModule now sets _dataset_class = _TTMDataset, derives channel indices
    from metadata, and provides a collate_fn for the mixed Tensor/non-Tensor batch

So the channel ordering follows TTM's spec:
[targets | past-only covariates | known-future covariates]

Does it work?

Yes, with a MockTTM standing in for the real model:

test_ttm_data_module_uses_ttm_dataset  PASSED
test_ttm_dataset_captures_stage        PASSED
test_channel_indices_derived           PASSED
test_batch_shapes                      PASSED
test_no_future_values_in_predict       PASSED
test_mock_ttm_forward                  PASSED

One pre-existing failure in test_multivariate_target ('list' object has no attribute 'shape') which was reproducible on main, so unrelated to this PR and the changes made by it, worth a separate issue if not already tracked?

What's left

  • Caching: _preprocess_data runs per __getitem__ so it needs a per-series cache before this is usable at scale (I mentioned it as a TODO in the code and using a short term fix for current scope)
  • Model wrapper: thin wrapper around tsfm_public.models.tinytimemixer, as an optional dependency.
  • Training glue: loss, forward, predict steps wired through TslibBaseModel

The _dataset_class hook is intentionally generic, with TTMDataModule as the first user, but the same pattern should work for Chronos, Moirai, and others with their own __getitem__ overrides.
More discussion and testing is needed for that I believe, but initial findings seem promising.

Looking for early feedback on this hook design and TTM channel format assumptions before going further, thanks!
@phoeenniixx @PranavBhatP @agrob

@codecov
Copy link

codecov bot commented Mar 17, 2026

Codecov Report

❌ Patch coverage is 9.03955% with 161 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (main@edbdeb4). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...orch_forecasting/data/tests/test_fm_data_module.py 0.00% 85 Missing ⚠️
pytorch_forecasting/data/_fm_data_module.py 15.71% 59 Missing ⚠️
...h_forecasting/data/tests/test_tslib_data_module.py 0.00% 17 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2206   +/-   ##
=======================================
  Coverage        ?   85.23%           
=======================================
  Files           ?      167           
  Lines           ?     9909           
  Branches        ?        0           
=======================================
  Hits            ?     8446           
  Misses          ?     1463           
  Partials        ?        0           
Flag Coverage Δ
cpu 85.23% <9.03%> (?)
pytest 85.23% <9.03%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant