
[BUG]: Fix TimeXer autograd graph detachment and deepcopy crash #2179

Open
Siddhazntx wants to merge 4 commits into sktime:main from Siddhazntx:fix-timexer-grad-bug

Conversation

@Siddhazntx
Contributor

Reference Issues/PRs

Fixes #2110

What does this implement/fix? Explain your changes.

This PR resolves the RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn that caused the automated CI to fail during TimeXer integration tests (specifically [TimeXer-base_params-3-RMSE]).

Root Cause & Fixes:

  1. Autograd Graph Detachment (Shared State Bug): When initializing TimeXer for multivariate forecasting (features="M"), MultiLoss was being initialized using list multiplication ([MAE()] * len(...)). This created multiple pointers to the exact same metric instance. During the forward pass for multiple targets, the internal state was overwritten, breaking the autograd graph and detaching the final loss tensor.
    Change: Updated the initialization to use a list comprehension ([MAE() for _ in range(...)]), ensuring every target receives an independent metric instance.

  2. Deepcopy Crash on Hyperparameter Save: self.save_hyperparameters() was attempting to deepcopy initialized, complex metric objects and live tensors, which PyTorch natively blocks, causing local test crashes.
    Change: Added ignore=["loss", "logging_metrics"] to save_hyperparameters() and ensured it executes before super().__init__() wraps them.
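The aliasing behind fix 1 can be reproduced in plain Python, independent of pytorch-forecasting (the `Metric` class below is a stand-in for `MAE`, not the real metric class):

```python
# Stand-in for a stateful metric such as MAE (illustrative only).
class Metric:
    def __init__(self):
        self.state = None  # mutable per-target state


n_targets = 3

# Buggy: list multiplication copies the *reference*, not the object,
# so every entry aliases one shared instance.
shared = [Metric()] * n_targets
shared[0].state = "target-0"
print(shared[1].state)  # -> "target-0" (overwritten for all targets)

# Fixed: a list comprehension constructs an independent instance per target.
independent = [Metric() for _ in range(n_targets)]
independent[0].state = "target-0"
print(independent[1].state)  # -> None (each target keeps its own state)
```

This is why the forward pass over multiple targets clobbered the shared metric's internal state and detached the loss from the autograd graph.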

What should a reviewer concentrate their feedback on?

  • The updated initialization order in TimeXer.__init__.

  • The list comprehension logic for MultiLoss to ensure it aligns with the expected metric tracking structure in PyTorch Lightning.

Did you add any tests for the change?

No new tests were added, as this PR fixes the existing, failing CI integration tests. I validated the changes locally by running the full test suite in headless mode (pytest pytorch_forecasting/tests/test_all_estimators.py -k "TimeXer"), which resulted in a clean pass for all 42 variants.

Click to expand: Local Test Run Output (All 42 TimeXer tests passing)
(ptf-dev) PS D:\pytorch-forecasting> $env:MPLBACKEND="Agg"; pytest pytorch_forecasting/tests/test_all_estimators.py -k "TimeXer"
Test session starts (platform: win32, Python 3.10.19, pytest 9.0.2, pytest-sugar 1.1.1)
cachedir: .cache
rootdir: D:\pytorch-forecasting
configfile: pyproject.toml
plugins: cov-7.0.0, dotenv-0.5.2, sugar-1.1.1, xdist-3.8.0
collected 317 items / 275 deselected / 42 selected                                                                                                                                                                

 pytorch_forecasting\tests\test_all_estimators.py::TestAllPtForecasters.test_doctest_examples[TimeXer] ✓                                     2% ▎         
 pytorch_forecasting\tests\test_all_estimators.py::TestAllPtForecasters.test_integration[TimeXer-base_params-0-MAE] ✓                        5% ▌         
...
 pytorch_forecasting\tests\test_all_estimators.py::TestAllPtForecasters.test_integration[TimeXer-base_params-3-RMSE] ✓                      24% ██▍       
...
 pytorch_forecasting\tests\test_all_estimators.py::TestAllPtForecasters.test_integration[TimeXer-base_params-4-QuantileLoss] ✓              98% █████████▊
 pytorch_forecasting\tests\test_all_estimators.py::TestAllPtForecasters.test_pkg_linkage[TimeXer-TimeXer] ✓                                100% ██████████
=================================================================== warnings summary =================================================================== 
pytorch_forecasting/tests/test_all_estimators.py: 48 warnings
  C:\Users\hp\anaconda3\envs\ptf-dev\lib\site-packages\lightning\pytorch\utilities\_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html

Results (119.35s (0:01:59)):
      42 passed
     275 deselected

PR checklist

  • The PR title starts with either [ENH], [MNT], [DOC], or [BUG]. [BUG] - bugfix, [MNT] - CI, test framework, [ENH] - adding or improving code, [DOC] - writing or improving documentation or docstrings.
  • Added/modified tests
  • Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.
    To run hooks independent of commit, execute pre-commit run --all-files

@codecov

codecov bot commented Mar 12, 2026

Codecov Report

❌ Patch coverage is 80.00000% with 1 line in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (main@1a19279). Learn more about missing BASE report.

Files with missing lines Patch % Lines
pytorch_forecasting/models/timexer/_timexer.py 80.00% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2179   +/-   ##
=======================================
  Coverage        ?   86.62%           
=======================================
  Files           ?      165           
  Lines           ?     9737           
  Branches        ?        0           
=======================================
  Hits            ?     8435           
  Misses          ?     1302           
  Partials        ?        0           
Flag Coverage Δ
cpu 86.62% <80.00%> (?)
pytest 86.62% <80.00%> (?)

Flags with carried forward coverage won't be shown.


    loss = MultiLoss([MAE()] * len(self.target_positions))
else:
    loss = MAE()
Member
why are you removing the MultiLoss from here?

@Siddhazntx (Contributor, author) commented Mar 14, 2026
Originally took it out because the old way ([MAE()] * len(...)) shared the exact same metric object in memory, which was actually what caused the autograd graph to detach. But you're totally right. I'll add it back, but I'll use a list comprehension ([MAE() for _ in range(...)]) so each target gets its own independent metric. That should fix the bug and keep the graph stable!

    if len(target_positions) == 1:
        prediction = prediction[..., 0, :]
    if len(target_positions) > 1:
        prediction = torch.stack(
Member
why are you using torch.stack? I think for multi-target, it should be a list of tensors?

@Siddhazntx (Contributor, author)

My bad on that one! I was trying to format the 3D tensor output and just defaulted to torch.stack without thinking. A list is definitely the right way to go here. I'll update it!
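The reviewer's suggestion can be sketched without torch, using nested lists as a stand-in for a 3D prediction tensor (the `(batch, horizon, n_targets)` axis order here is an assumption for illustration, not necessarily TimeXer's actual layout):

```python
# Stand-in for a (batch, horizon, n_targets) prediction tensor,
# filled with distinct values so slices are easy to check.
batch, horizon, n_targets = 2, 4, 3
prediction = [
    [[b * 100 + t * 10 + k for k in range(n_targets)] for t in range(horizon)]
    for b in range(batch)
]

# Per-target list instead of stacking: entry k is the (batch, horizon)
# slice for target k, mirroring prediction[..., k] in tensor indexing.
per_target = [
    [[prediction[b][t][k] for t in range(horizon)] for b in range(batch)]
    for k in range(n_targets)
]

print(len(per_target))      # -> 3 (one entry per target)
print(per_target[0][0][0])  # -> 0 (batch 0, step 0, target 0)
```

The point is that multi-target losses such as MultiLoss expect one tensor per target in a list, rather than a single stacked tensor.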

@Siddhazntx (Contributor, author)

Hey @phoeenniixx, Just pushed a quick follow-up commit. The last update tripped a unit test because I tried to check self.target_positions to initialize MultiLoss before super().__init__() actually created it (which threw an AttributeError).

I swapped it to safely check kwargs.get("output_size") instead. All the local unit and integration tests are completely green on my end now! Let me know if everything looks good to you.
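The initialization-order pitfall and the kwargs-based workaround can be sketched in isolation (the class names and `output_size` handling below are illustrative stand-ins, not the actual pytorch-forecasting or Lightning classes):

```python
# Illustrative sketch of the AttributeError described above.
class Base:
    def __init__(self, **kwargs):
        # The parent creates target_positions, so it only exists *after*
        # super().__init__() has run in the subclass.
        self.target_positions = list(range(kwargs.get("output_size", 1)))


class BuggyModel(Base):
    def __init__(self, **kwargs):
        # AttributeError: target_positions is set by Base.__init__,
        # which has not run yet at this point.
        multi = len(self.target_positions) > 1
        super().__init__(**kwargs)


class FixedModel(Base):
    def __init__(self, **kwargs):
        # Read the raw keyword argument instead of the not-yet-set attribute.
        self.multi_target = kwargs.get("output_size", 1) > 1
        super().__init__(**kwargs)


m = FixedModel(output_size=3)
print(m.multi_target)  # -> True
```

Reading `kwargs.get("output_size")` keeps the multi-target check independent of attributes the parent constructor has not created yet.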



Development

Successfully merging this pull request may close these issues.

[BUG] TimeXer integration test failing due to requires_grad
