diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 28410960..eb708984 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,13 +10,16 @@ on: branches: [master] jobs: - build: + pytest: + name: Run pytest runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest, macOS-latest] # add windows-2019 when poetry allows installation with `-f` flag python-version: ["3.8", "3.9", "3.10"] - + defaults: + run: + shell: bash steps: - uses: actions/checkout@v2 @@ -30,18 +33,11 @@ jobs: run: | brew install libomp # https://github.com/pytorch/pytorch/issues/20030 - - name: Setup macOS - if: runner.os == 'windows' - run: | - brew install libomp # https://github.com/pytorch/pytorch/issues/20030 - - name: Get full Python version id: full-python-version - shell: bash run: echo ::set-output name=version::$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))") - name: Install poetry - shell: bash run: | curl -sSL https://install.python-poetry.org | python3 - @@ -49,7 +45,6 @@ jobs: run: echo "/Users/runner/.local/bin" >> $GITHUB_PATH - name: Configure poetry - shell: bash run: poetry config virtualenvs.in-project true - name: Set up cache @@ -61,15 +56,12 @@ jobs: - name: Ensure cache is healthy if: steps.cache.outputs.cache-hit == 'true' - shell: bash run: poetry run pip --version >/dev/null 2>&1 || rm -rf .venv - name: Upgrade pip - shell: bash run: poetry run python -m pip install pip -U - name: Install dependencies - shell: bash run: poetry install -E "github-actions graph mqf2" # - name: Install pytorch geometric dependencies @@ -77,11 +69,9 @@ jobs: # run: poetry run pip install pyg_lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.0+cpu.html - name: Run pytest - shell: bash run: poetry run pytest tests - name: Statistics - if: success() run: | pip install coverage coverage report @@ -99,9 +89,11 @@ jobs: fail_ci_if_error: false docs: - name: Test docs build + name: Docs build runs-on: ubuntu-latest - + defaults: + run: + shell: bash steps: - name: Check out Git repository uses: actions/checkout@v2 @@ -116,21 +108,17 @@ jobs: with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('docs/requirements.txt') }} - restore-keys: | - ${{ runner.os }}-pip- + restore-keys: ${{ runner.os }}-pip- - name: Install dependencies run: | - sudo apt-get update && sudo apt-get install -y pandoc - python -m pip install --upgrade pip + sudo apt-get update --fix-missing + sudo apt-get install -y pandoc pip install -r docs/requirements.txt - shell: bash - name: Build sphinx documentation - run: | - cd docs - make clean - make html --debug --jobs 2 SPHINXOPTS="-W" + working-directory: docs/ + run: make html --debug --jobs 2 SPHINXOPTS="-W" - name: Upload built docs uses: actions/upload-artifact@v2 @@ -138,4 +126,4 @@ jobs: name: docs-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} path: docs/build/html/ # Use always() to always run this step to publish test results when there are test failures - if: success() + #if: success() diff --git a/README.md b/README.md index 0c14e2c8..5b6c42a0 100755 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [conda-url]: https://anaconda.org/conda-forge/pytorch-forecasting [build-image]: https://github.com/jdb78/pytorch-forecasting/actions/workflows/test.yml/badge.svg?branch=master [build-url]: https://github.com/jdb78/pytorch-forecasting/actions/workflows/test.yml?query=branch%3Amaster -[linter-image]: https://github.com/jdb78/pytorch-forecasting/actions/workflows/code_quality.yml/badge.svg?branch=master +[linter-image]: https://github.com/jdb78/pytorch-forecasting/actions/workflows/lint.yml/badge.svg?event=push [linter-url]: https://github.com/jdb78/pytorch-forecasting/actions/workflows/code_quality.yml?query=branch%3Amaster [docs-image]: https://readthedocs.org/projects/pytorch-forecasting/badge/?version=latest [docs-url]: https://pytorch-forecasting.readthedocs.io diff --git a/docs/requirements.txt b/docs/requirements.txt index 17a82e77..09435203 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,19 +1,19 @@ -sphinx>3.2 +sphinx >3.2 nbsphinx pandoc docutils pydata-sphinx-theme -lightning>=2.0.0 +lightning >=2.0.0 cloudpickle -torch>=2.0,!=2.0.1 -optuna>=3.1.0 +torch >=2.0,!=2.0.1 +optuna >=3.1.0 scipy -pandas>=1.3 -scikit-learn>1.2 +pandas >=1.3 +scikit-learn >1.2 matplotlib statsmodels ipython -nbconvert>=6.3.0 -recommonmark>=0.7.1 -pytorch-optimizer>=2.5.1 -fastapi>0.80 +nbconvert >=6.3.0 +recommonmark >=0.7.1 +pytorch-optimizer >=2.5.1 +fastapi >0.80 diff --git a/docs/source/tutorials/stallion.ipynb b/docs/source/tutorials/stallion.ipynb index fb600176..a8da6d87 100644 --- a/docs/source/tutorials/stallion.ipynb +++ b/docs/source/tutorials/stallion.ipynb @@ -978,7 +978,7 @@ " dropout=0.1, # between 0.1 and 0.3 are good values\n", " hidden_continuous_size=8, # set to <= hidden_size\n", " loss=QuantileLoss(),\n", - " optimizer=\"Ranger\"\n", + " optimizer=\"Ranger\",\n", " # reduce learning rate if no improvement in validation loss after x epochs\n", " # reduce_on_plateau_patience=1000,\n", ")\n", diff --git a/poetry.lock b/poetry.lock index c8270929..7899ab19 100644 --- a/poetry.lock +++ b/poetry.lock @@ -965,23 +965,22 @@ tests = ["asttokens", "littleutils", "pytest", "rich"] [[package]] name = "fastapi" -version = "0.104.1" +version = "0.109.2" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.104.1-py3-none-any.whl", hash = "sha256:752dc31160cdbd0436bb93bad51560b57e525cbb1d4bbf6f4904ceee75548241"}, - {file = "fastapi-0.104.1.tar.gz", hash = "sha256:e5e4540a7c5e1dcfbbcf5b903c234feddcdcd881f191977a1c5dfd917487e7ae"}, + {file = "fastapi-0.109.2-py3-none-any.whl", hash = "sha256:2c9bab24667293b501cad8dd388c05240c850b58ec5876ee3283c47d6e1e3a4d"}, + {file = "fastapi-0.109.2.tar.gz", hash = "sha256:f3817eac96fe4f65a2ebb4baa000f394e55f5fccdaf7f75250804bc58f354f73"}, ] [package.dependencies] -anyio = ">=3.7.1,<4.0.0" pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" -starlette = ">=0.27.0,<0.28.0" +starlette = ">=0.36.3,<0.37.0" typing-extensions = ">=4.8.0" [package.extras] -all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] [[package]] name = "fastjsonschema" @@ -2623,12 +2622,12 @@ nvidia-nvjitlink-cu12 = "*" [[package]] name = "nvidia-nccl-cu12" -version = "2.18.1" +version = "2.19.3" description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, + {file = "nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:a9734707a2c96443331c1e48c717024aa6678a0e2a4cb66b2c364d18cee6b48d"}, ] [[package]] @@ -3691,6 +3690,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4409,13 +4409,13 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] [[package]] name = "starlette" -version = "0.27.0" +version = "0.36.3" description = "The little ASGI library that shines." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "starlette-0.27.0-py3-none-any.whl", hash = "sha256:918416370e846586541235ccd38a474c08b80443ed31c578a418e2209b3eef91"}, - {file = "starlette-0.27.0.tar.gz", hash = "sha256:6a6b0d042acb8d469a01eba54e9cda6cbd24ac602c4cd016723117d6a7e73b75"}, + {file = "starlette-0.36.3-py3-none-any.whl", hash = "sha256:13d429aa93a61dc40bf503e8c801db1f1bca3dc706b10ef2434a36123568f044"}, + {file = "starlette-0.36.3.tar.gz", hash = "sha256:90a671733cfb35771d8cc605e0b679d23b992f8dcfad48cc60b38cb29aeb7080"}, ] [package.dependencies] @@ -4423,7 +4423,7 @@ anyio = ">=3.4.0,<5" typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} [package.extras] -full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] [[package]] name = "statsmodels" @@ -4618,31 +4618,36 @@ files = [ [[package]] name = "torch" -version = "2.1.1" +version = "2.2.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.1.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:5ebc43f5355a9b7be813392b3fb0133991f0380f6f0fcc8218d5468dc45d1071"}, - {file = "torch-2.1.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:84fefd63356416c0cd20578637ccdbb82164993400ed17b57c951dd6376dcee8"}, - {file = "torch-2.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:0a7a9da0c324409bcb5a7bdad1b4e94e936d21c2590aaa7ac2f63968da8c62f7"}, - {file = "torch-2.1.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:1e1e5faddd43a8f2c0e0e22beacd1e235a2e447794d807483c94a9e31b54a758"}, - {file = "torch-2.1.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:e76bf3c5c354874f1da465c852a2fb60ee6cbce306e935337885760f080f9baa"}, - {file = "torch-2.1.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:98fea993639b0bb432dfceb7b538f07c0f1c33386d63f635219f49254968c80f"}, - {file = "torch-2.1.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:61b51b33c61737c287058b0c3061e6a9d3c363863e4a094f804bc486888a188a"}, - {file = "torch-2.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:1d70920da827e2276bf07f7ec46958621cad18d228c97da8f9c19638474dbd52"}, - {file = "torch-2.1.1-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:a70593806f1d7e6b53657d96810518da0f88ef2608c98a402955765b8c79d52c"}, - {file = "torch-2.1.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e312f7e82e49565f7667b0bbf9559ab0c597063d93044740781c02acd5a87978"}, - {file = "torch-2.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:1e3cbecfa5a7314d828f4a37b0c286714dc9aa2e69beb7a22f7aca76567ed9f4"}, - {file = "torch-2.1.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9ca0fcbf3d5ba644d6a8572c83a9abbdf5f7ff575bc38529ef6c185a3a71bde9"}, - {file = "torch-2.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:2dc9f312fc1fa0d61a565a0292ad73119d4b74c9f8b5031b55f8b4722abca079"}, - {file = "torch-2.1.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:d56b032176458e2af4709627bbd2c20fe2917eff8cd087a7fe313acccf5ce2f1"}, - {file = "torch-2.1.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:29e3b90a8c281f6660804a939d1f4218604c80162e521e1e6d8c8557325902a0"}, - {file = "torch-2.1.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:bd95cee8511584b67ddc0ba465c3f1edeb5708d833ee02af1206b4486f1d9096"}, - {file = "torch-2.1.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b31230bd058424e56dba7f899280dbc6ac8b9948e43902e0c84a44666b1ec151"}, - {file = "torch-2.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:403f1095e665e4f35971b43797a920725b8b205723aa68254a4050c6beca29b6"}, - {file = "torch-2.1.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:715b50d8c1de5da5524a68287eb000f73e026e74d5f6b12bc450ef6995fcf5f9"}, - {file = "torch-2.1.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:db67e8725c76f4c7f4f02e7551bb16e81ba1a1912867bc35d7bb96d2be8c78b4"}, + {file = "torch-2.2.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d366158d6503a3447e67f8c0ad1328d54e6c181d88572d688a625fac61b13a97"}, + {file = "torch-2.2.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:707f2f80402981e9f90d0038d7d481678586251e6642a7a6ef67fc93511cb446"}, + {file = "torch-2.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:15c8f0a105c66b28496092fca1520346082e734095f8eaf47b5786bac24b8a31"}, + {file = "torch-2.2.0-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:0ca4df4b728515ad009b79f5107b00bcb2c63dc202d991412b9eb3b6a4f24349"}, + {file = "torch-2.2.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:3d3eea2d5969b9a1c9401429ca79efc668120314d443d3463edc3289d7f003c7"}, + {file = "torch-2.2.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:0d1c580e379c0d48f0f0a08ea28d8e373295aa254de4f9ad0631f9ed8bc04c24"}, + {file = "torch-2.2.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:9328e3c1ce628a281d2707526b4d1080eae7c4afab4f81cea75bde1f9441dc78"}, + {file = "torch-2.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:03c8e660907ac1b8ee07f6d929c4e15cd95be2fb764368799cca02c725a212b8"}, + {file = "torch-2.2.0-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:da0cefe7f84ece3e3b56c11c773b59d1cb2c0fd83ddf6b5f7f1fd1a987b15c3e"}, + {file = "torch-2.2.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:f81d23227034221a4a4ff8ef24cc6cec7901edd98d9e64e32822778ff01be85e"}, + {file = "torch-2.2.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:dcbfb2192ac41ca93c756ebe9e2af29df0a4c14ee0e7a0dd78f82c67a63d91d4"}, + {file = "torch-2.2.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:9eeb42971619e24392c9088b5b6d387d896e267889d41d267b1fec334f5227c5"}, + {file = "torch-2.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:c718b2ca69a6cac28baa36d86d8c0ec708b102cebd1ceb1b6488e404cd9be1d1"}, + {file = "torch-2.2.0-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:f11d18fceb4f9ecb1ac680dde7c463c120ed29056225d75469c19637e9f98d12"}, + {file = "torch-2.2.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:ee1da852bfd4a7e674135a446d6074c2da7194c1b08549e31eae0b3138c6b4d2"}, + {file = "torch-2.2.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0d819399819d0862268ac531cf12a501c253007df4f9e6709ede8a0148f1a7b8"}, + {file = "torch-2.2.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:08f53ccc38c49d839bc703ea1b20769cc8a429e0c4b20b56921a9f64949bf325"}, + {file = "torch-2.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:93bffe3779965a71dab25fc29787538c37c5d54298fd2f2369e372b6fb137d41"}, + {file = "torch-2.2.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:c17ec323da778efe8dad49d8fb534381479ca37af1bfc58efdbb8607a9d263a3"}, + {file = "torch-2.2.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:c02685118008834e878f676f81eab3a952b7936fa31f474ef8a5ff4b5c78b36d"}, + {file = "torch-2.2.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:d9f39d6f53cec240a0e3baa82cb697593340f9d4554cee6d3d6ca07925c2fac0"}, + {file = "torch-2.2.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:51770c065206250dc1222ea7c0eff3f88ab317d3e931cca2aee461b85fbc2472"}, + {file = "torch-2.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:008e4c6ad703de55af760c73bf937ecdd61a109f9b08f2bbb9c17e7c7017f194"}, + {file = "torch-2.2.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:de8680472dd14e316f42ceef2a18a301461a9058cd6e99a1f1b20f78f11412f1"}, + {file = "torch-2.2.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:99e1dcecb488e3fd25bcaac56e48cdb3539842904bdc8588b0b255fde03a254c"}, ] [package.dependencies] @@ -4659,15 +4664,15 @@ nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linu nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu12 = {version = "2.18.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -typing-extensions = "*" +triton = {version = "2.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +typing-extensions = ">=4.8.0" [package.extras] -dynamo = ["jinja2"] opt-einsum = ["opt-einsum (>=3.3)"] +optree = ["optree (>=0.9.1)"] [[package]] name = "torchmetrics" @@ -4700,38 +4705,43 @@ visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.2.0)"] [[package]] name = "torchvision" -version = "0.16.1" +version = "0.17.0" description = "image and video datasets and models for torch deep learning" optional = true python-versions = ">=3.8" files = [ - {file = "torchvision-0.16.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:987132795e5c037cb74e7be35a693999fdb2f603152266ee15b80206e83a5b0c"}, - {file = "torchvision-0.16.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:25da6a7b22ea0348f62c45ec0daf157731096babcae65d222404081af96e085c"}, - {file = "torchvision-0.16.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:c82e291e674a18b67f92ddb476ae18498fb46d7032ae914f3fda90c955e7d86f"}, - {file = "torchvision-0.16.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:632887b22e67ce32a3ede806b868bba4057601e46d680de14b32a391eac1b483"}, - {file = "torchvision-0.16.1-cp310-cp310-win_amd64.whl", hash = "sha256:92c76a5092b4033efdb183b11fa4854a7630e23c46f4a1c3ffd70c30cb5be4fc"}, - {file = "torchvision-0.16.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:4aea5cf491c6c21b1cbdbb1bf2a3838a59d4db93ad5f49019a6564d3ca7127c7"}, - {file = "torchvision-0.16.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3391757167637ace3ef33a67c9d5ef86b1f8cbd93eaa5bad45eebcf266ea6089"}, - {file = "torchvision-0.16.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:4f9d5b192b336982e6dbe32c070b05606f0b53e87d722ae332a02909fbf988ed"}, - {file = "torchvision-0.16.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:3d34601614958c4e30f53ec0eb7bf3f282ee72bb747734be2d75422831a43384"}, - {file = "torchvision-0.16.1-cp311-cp311-win_amd64.whl", hash = "sha256:e11af530585574eb5ca837b8f151bcdd57c10e35c3af56c76a10f3281d2a2f2c"}, - {file = "torchvision-0.16.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:4f2cad621fb96cf10e29af93e16c98b3226bdd53ae712b57e873c3deaf061617"}, - {file = "torchvision-0.16.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1d614b3c9e2de9cd75cc0e4e1923fcfbbcd9fdb9f08a0bbbbf7e135e4a0a1cfa"}, - {file = "torchvision-0.16.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:73271e930501a008fe24ba38945b2a75b25a6098f4c2f4402e39a9d0dd305ca6"}, - {file = "torchvision-0.16.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:fab67ddc4809fcc2a04610b13cac5193b9d3be2896b77538bfdff401b13022e5"}, - {file = "torchvision-0.16.1-cp38-cp38-win_amd64.whl", hash = "sha256:13782d574033efec6646d1a2f5d85f4c59fcf3f403367bb407b15df07adc87e0"}, - {file = "torchvision-0.16.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:f14d201c37176dc4106eec76b229d6585a1505266b8cea99d3366fd38897b7c0"}, - {file = "torchvision-0.16.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a15e88a93a7501cc75b761a2dcd07aaedaaf9cbfaf48c8affa8c98989ecbb19d"}, - {file = "torchvision-0.16.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:72fde5fdb462e66ebe25ae42d2ee11434cbc395f74cad0d3b22cf60524345cc5"}, - {file = "torchvision-0.16.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:153f753f14eba58969cdc86360893a57f8bf63f8136c7d1cd4388108560b5446"}, - {file = "torchvision-0.16.1-cp39-cp39-win_amd64.whl", hash = "sha256:75e33b198b1265f61d822aa66d646ec3df67a712470ffec1e0c37ff46d4103c1"}, + {file = "torchvision-0.17.0-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:153882cd8ff8e3dbef5c5054fdd15df64e85420546805a90c0b2221f2f119c4a"}, + {file = "torchvision-0.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c55c2f86e3f3a21ddd92739a972366244e9b17916e836ec47167b0a0c083c65f"}, + {file = "torchvision-0.17.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605950cdcefe6c5aef85709ade17b1525bcf171e122cce1df09e666d96525b90"}, + {file = "torchvision-0.17.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:3d86c212fc6379e9bec3ac647d062e34c2cf36c26b98840b66573eb9fbe1f1d9"}, + {file = "torchvision-0.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:71b314813faf13cecb09a4a635b5e4b274e8df0b1921681038d491c529555bb6"}, + {file = "torchvision-0.17.0-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:10d276821f115fb369e6cf1f1b77b2cca60cda12cbb39a41513a9d3d0f2a93ae"}, + {file = "torchvision-0.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3eef2daddadb5c21e802e0550dd7e3ee3d98c430f4aed212ae3ba0358558be1"}, + {file = "torchvision-0.17.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:acc0d098ab8c295a750f0218bf5bf7bfc2f2c21f9c2fe3fc30b695cd94f4c759"}, + {file = "torchvision-0.17.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:3d2e9552d72e4037f2db6f7d97989a2e2f95763aa1861963a3faf521bb1610c4"}, + {file = "torchvision-0.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:f8e542cf71e1294fcb5635038eae6702df543dc90706f0836ec80e75efc511fc"}, + {file = "torchvision-0.17.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:816ae1a4506b1cb0f638e1827cae7ab768c731369ab23e86839f177926197143"}, + {file = "torchvision-0.17.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:be39874c239215a39b3c431c7016501f1a45bfbbebf2fe8e11d8339b5ea23bca"}, + {file = "torchvision-0.17.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:8fe14d580557aef2c45dd462c069ff936b6507b215c4b496f30973ae8cff917d"}, + {file = "torchvision-0.17.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:4608ba3246c45c968ede40e7640e4eed64556210faa154cf1ffccb1cadabe445"}, + {file = "torchvision-0.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:b755d6d3e021239d2408bf3794d0d3dcffbc629f1fd808c43d8b346045a098c4"}, + {file = "torchvision-0.17.0-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:870d7cda57420e44d20eb07bfe37bf5344a06434a7a6195b4c7f3dd55838587d"}, + {file = "torchvision-0.17.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:477f6e64a9d798c0f5adefc300acc220da6f17ef5c1e110d20108f66554fee4d"}, + {file = "torchvision-0.17.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:a54a15bd6f3dbb04ebd36c5a87530b2e090ee4b9b15eb89eda558ab3e50396a0"}, + {file = "torchvision-0.17.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e041ce3336364413bab051a3966d884bab25c200f98ca8a065f0abe758c3005e"}, + {file = "torchvision-0.17.0-cp38-cp38-win_amd64.whl", hash = "sha256:7887f767670c72aa20f5237042d0ca1462da18f66a3ea8c36b6ba67ce26b82fc"}, + {file = "torchvision-0.17.0-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:b1ced438b81ef662a71c8c81debaf0c80455b35b811ca55a4c3c593d721b560a"}, + {file = "torchvision-0.17.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b53569c52bd4bd1176a1e49d8ea55883bcf57e1614cb97e2e8ce372768299b70"}, + {file = "torchvision-0.17.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:7f373507afcd9022ebd9f50b31da8dbac1ea6783ffb77d1f1ab8806425c0a83b"}, + {file = "torchvision-0.17.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:085251ab36340206dc7e1be59a15fa5e307d45ccd66889f5d7bf1ba5e7ecdc57"}, + {file = "torchvision-0.17.0-cp39-cp39-win_amd64.whl", hash = "sha256:4c0d4c0af58af2752aad235150bd794d0f324e6eeac5cd13c440bda5dce622d3"}, ] [package.dependencies] numpy = "*" pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" requests = "*" -torch = "2.1.1" +torch = "2.2.0" [package.extras] scipy = ["scipy"] @@ -4793,28 +4803,26 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] [[package]] name = "triton" -version = "2.1.0" +version = "2.2.0" description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" files = [ - {file = "triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, - {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, - {file = "triton-2.1.0-0-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ae4bb8a91de790e1866405211c4d618379781188f40d5c4c399766914e84cd94"}, - {file = "triton-2.1.0-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39f6fb6bdccb3e98f3152e3fbea724f1aeae7d749412bbb1fa9c441d474eba26"}, - {file = "triton-2.1.0-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21544e522c02005a626c8ad63d39bdff2f31d41069592919ef281e964ed26446"}, - {file = "triton-2.1.0-0-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:143582ca31dd89cd982bd3bf53666bab1c7527d41e185f9e3d8a3051ce1b663b"}, - {file = "triton-2.1.0-0-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:82fc5aeeedf6e36be4e4530cbdcba81a09d65c18e02f52dc298696d45721f3bd"}, - {file = "triton-2.1.0-0-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:81a96d110a738ff63339fc892ded095b31bd0d205e3aace262af8400d40b6fa8"}, + {file = "triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2294514340cfe4e8f4f9e5c66c702744c4a117d25e618bd08469d0bfed1e2e5"}, + {file = "triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da58a152bddb62cafa9a857dd2bc1f886dbf9f9c90a2b5da82157cd2b34392b0"}, + {file = "triton-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af58716e721460a61886668b205963dc4d1e4ac20508cc3f623aef0d70283d5"}, + {file = "triton-2.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8fe46d3ab94a8103e291bd44c741cc294b91d1d81c1a2888254cbf7ff846dab"}, + {file = "triton-2.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8ce26093e539d727e7cf6f6f0d932b1ab0574dc02567e684377630d86723ace"}, + {file = "triton-2.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:227cc6f357c5efcb357f3867ac2a8e7ecea2298cd4606a8ba1e931d1d5a947df"}, ] [package.dependencies] filelock = "*" [package.extras] -build = ["cmake (>=3.18)", "lit"] -tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] -tutorials = ["matplotlib", "pandas", "tabulate"] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] +tutorials = ["matplotlib", "pandas", "tabulate", "torch"] [[package]] name = "typing-extensions" diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py index f2edefd9..1344c0b7 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py @@ -26,7 +26,7 @@ # need to inherit from callback for this to work -class PyTorchLightningPruningCallbackAdjusted(pl.Callback, PyTorchLightningPruningCallback): +class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback): pass