Skip to content

Commit f665bd6

Browse files
committed
Merge remote-tracking branch 'upstream/main' into opt
2 parents 85932bf + 1331d57 commit f665bd6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+6481
-3568
lines changed

LICENSE

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
THE MIT License
22

3-
Copyright 2020 Jan Beitner
3+
Copyright (c) 2020 - present, the pytorch-forecasting developers
4+
Copyright (c) 2020 Jan Beitner
45

56
Permission is hereby granted, free of charge, to any person obtaining a copy
67
of this software and associated documentation files (the "Software"), to deal

build_tools/changelog.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def fetch_latest_release(): # noqa: D103
6262
"""
6363
import httpx
6464

65-
response = httpx.get(f"{GITHUB_REPOS}/{OWNER}/{REPO}/releases/latest", headers=HEADERS)
65+
response = httpx.get(
66+
f"{GITHUB_REPOS}/{OWNER}/{REPO}/releases/latest", headers=HEADERS
67+
)
6668

6769
if response.status_code == 200:
6870
return response.json()
@@ -91,7 +93,9 @@ def fetch_pull_requests_since_last_release() -> list[dict]:
9193
all_pulls = []
9294
while not is_exhausted:
9395
pulls = fetch_merged_pull_requests(page=page)
94-
all_pulls.extend([p for p in pulls if parser.parse(p["merged_at"]) > published_at])
96+
all_pulls.extend(
97+
[p for p in pulls if parser.parse(p["merged_at"]) > published_at]
98+
)
9599
is_exhausted = any(parser.parse(p["updated_at"]) < published_at for p in pulls)
96100
page += 1
97101
return all_pulls
@@ -101,7 +105,9 @@ def github_compare_tags(tag_left: str, tag_right: str = "HEAD"):
101105
"""Compare commit between two tags."""
102106
import httpx
103107

104-
response = httpx.get(f"{GITHUB_REPOS}/{OWNER}/{REPO}/compare/{tag_left}...{tag_right}")
108+
response = httpx.get(
109+
f"{GITHUB_REPOS}/{OWNER}/{REPO}/compare/{tag_left}...{tag_right}"
110+
)
105111
if response.status_code == 200:
106112
return response.json()
107113
else:
@@ -135,7 +141,9 @@ def assign_prs(prs, categs: list[dict[str, list[str]]]):
135141
# if any(l.startswith("module") for l in pr_labels):
136142
# print(i, pr_labels)
137143

138-
assigned["Other"] = list(set(range(len(prs))) - {i for _, j in assigned.items() for i in j})
144+
assigned["Other"] = list(
145+
set(range(len(prs))) - {i for _, j in assigned.items() for i in j}
146+
)
139147

140148
return assigned
141149

docs/source/conf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ def setup(app: Sphinx):
145145
"navbar_end": ["navbar-icon-links.html", "search-field.html"],
146146
"show_nav_level": 2,
147147
"header_links_before_dropdown": 10,
148-
"external_links": [{"name": "GitHub", "url": "https://github.com/sktime/pytorch-forecasting"}],
148+
"external_links": [
149+
{"name": "GitHub", "url": "https://github.com/sktime/pytorch-forecasting"}
150+
],
149151
}
150152

151153
html_sidebars = {

examples/ar.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,20 @@
5151
stop_randomization=True,
5252
)
5353
batch_size = 64
54-
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
55-
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
54+
train_dataloader = training.to_dataloader(
55+
train=True, batch_size=batch_size, num_workers=0
56+
)
57+
val_dataloader = validation.to_dataloader(
58+
train=False, batch_size=batch_size, num_workers=0
59+
)
5660

5761
# save datasets
5862
training.save("training.pkl")
5963
validation.save("validation.pkl")
6064

61-
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5, verbose=False, mode="min")
65+
early_stop_callback = EarlyStopping(
66+
monitor="val_loss", min_delta=1e-4, patience=5, verbose=False, mode="min"
67+
)
6268
lr_logger = LearningRateMonitor()
6369

6470
trainer = pl.Trainer(

examples/nbeats.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,21 @@
4242
add_target_scales=False,
4343
)
4444

45-
validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training_cutoff)
45+
validation = TimeSeriesDataSet.from_dataset(
46+
training, data, min_prediction_idx=training_cutoff
47+
)
4648
batch_size = 128
47-
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
48-
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=2)
49+
train_dataloader = training.to_dataloader(
50+
train=True, batch_size=batch_size, num_workers=2
51+
)
52+
val_dataloader = validation.to_dataloader(
53+
train=False, batch_size=batch_size, num_workers=2
54+
)
4955

5056

51-
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
57+
early_stop_callback = EarlyStopping(
58+
monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
59+
)
5260
trainer = pl.Trainer(
5361
max_epochs=100,
5462
accelerator="auto",
@@ -63,7 +71,12 @@
6371

6472

6573
net = NBeats.from_dataset(
66-
training, learning_rate=3e-2, log_interval=10, log_val_interval=1, log_gradient_flow=False, weight_decay=1e-2
74+
training,
75+
learning_rate=3e-2,
76+
log_interval=10,
77+
log_val_interval=1,
78+
log_gradient_flow=False,
79+
weight_decay=1e-2,
6780
)
6881
print(f"Number of parameters in network: {net.size() / 1e3:.1f}k")
6982

examples/stallion.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,16 @@
77
import numpy as np
88
from pandas.core.common import SettingWithCopyWarning
99

10-
from pytorch_forecasting import GroupNormalizer, TemporalFusionTransformer, TimeSeriesDataSet
10+
from pytorch_forecasting import (
11+
GroupNormalizer,
12+
TemporalFusionTransformer,
13+
TimeSeriesDataSet,
14+
)
1115
from pytorch_forecasting.data.examples import get_stallion_data
1216
from pytorch_forecasting.metrics import QuantileLoss
13-
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
17+
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import (
18+
optimize_hyperparameters,
19+
)
1420

1521
warnings.simplefilter("error", category=SettingWithCopyWarning)
1622

@@ -22,8 +28,12 @@
2228

2329
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
2430
data["time_idx"] -= data["time_idx"].min()
25-
data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean")
26-
data["avg_volume_by_agency"] = data.groupby(["time_idx", "agency"], observed=True).volume.transform("mean")
31+
data["avg_volume_by_sku"] = data.groupby(
32+
["time_idx", "sku"], observed=True
33+
).volume.transform("mean")
34+
data["avg_volume_by_agency"] = data.groupby(
35+
["time_idx", "agency"], observed=True
36+
).volume.transform("mean")
2737
# data = data[lambda x: (x.sku == data.iloc[0]["sku"]) & (x.agency == data.iloc[0]["agency"])]
2838
special_days = [
2939
"easter_day",
@@ -39,7 +49,9 @@
3949
"beer_capital",
4050
"music_fest",
4151
]
42-
data[special_days] = data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category")
52+
data[special_days] = (
53+
data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category")
54+
)
4355

4456
training_cutoff = data["time_idx"].max() - 6
4557
max_encoder_length = 36
@@ -50,14 +62,17 @@
5062
time_idx="time_idx",
5163
target="volume",
5264
group_ids=["agency", "sku"],
53-
min_encoder_length=max_encoder_length // 2, # allow encoder lengths from 0 to max_prediction_length
65+
min_encoder_length=max_encoder_length
66+
// 2, # allow encoder lengths from 0 to max_prediction_length
5467
max_encoder_length=max_encoder_length,
5568
min_prediction_length=1,
5669
max_prediction_length=max_prediction_length,
5770
static_categoricals=["agency", "sku"],
5871
static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
5972
time_varying_known_categoricals=["special_days", "month"],
60-
variable_groups={"special_days": special_days}, # group of categorical variables can be treated as one variable
73+
variable_groups={
74+
"special_days": special_days
75+
}, # group of categorical variables can be treated as one variable
6176
time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"],
6277
time_varying_unknown_categoricals=[],
6378
time_varying_unknown_reals=[
@@ -78,17 +93,25 @@
7893
)
7994

8095

81-
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
96+
validation = TimeSeriesDataSet.from_dataset(
97+
training, data, predict=True, stop_randomization=True
98+
)
8299
batch_size = 64
83-
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
84-
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
100+
train_dataloader = training.to_dataloader(
101+
train=True, batch_size=batch_size, num_workers=0
102+
)
103+
val_dataloader = validation.to_dataloader(
104+
train=False, batch_size=batch_size, num_workers=0
105+
)
85106

86107

87108
# save datasets
88109
training.save("t raining.pkl")
89110
validation.save("validation.pkl")
90111

91-
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
112+
early_stop_callback = EarlyStopping(
113+
monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
114+
)
92115
lr_logger = LearningRateMonitor()
93116
logger = TensorBoardLogger(log_graph=True)
94117

pyproject.toml

Lines changed: 64 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,3 @@
1-
[tool.ruff]
2-
line-length = 120
3-
exclude = [
4-
"docs/build/",
5-
"node_modules/",
6-
".eggs/",
7-
"versioneer.py",
8-
"venv/",
9-
".venv/",
10-
".git/",
11-
".history/",
12-
]
13-
14-
[tool.ruff.lint]
15-
select = ["E", "F", "W", "C4", "S"]
16-
extend-ignore = [
17-
"E203", # space before : (needed for how black formats slicing)
18-
"E402", # module level import not at top of file
19-
"E731", # do not assign a lambda expression, use a def
20-
"E741", # ignore not easy to read variables like i l I etc.
21-
"C406", # Unnecessary list literal - rewrite as a dict literal.
22-
"C408", # Unnecessary dict call - rewrite as a literal.
23-
"C409", # Unnecessary list passed to tuple() - rewrite as a tuple literal.
24-
"F401", # unused imports
25-
"S101", # use of assert
26-
]
27-
28-
[tool.ruff.lint.isort]
29-
known-first-party = ["pytorch_forecasting"]
30-
combine-as-imports = true
31-
force-sort-within-sections = true
32-
33-
[tool.black]
34-
line-length = 120
35-
include = '\.pyi?$'
36-
exclude = '''
37-
(
38-
/(
39-
\.eggs # exclude a few common directories in the
40-
| \.git # root of the project
41-
| \.hg
42-
| \.mypy_cache
43-
| \.tox
44-
| \.venv
45-
| _build
46-
| buck-out
47-
| build
48-
| dist
49-
)/
50-
| docs/build/
51-
| node_modules/
52-
| venve/
53-
| .venv/
54-
)
55-
'''
56-
57-
[tool.nbqa.mutate]
58-
ruff = 1
59-
black = 1
60-
611
[project]
622
name = "pytorch-forecasting"
633
readme = "README.md" # Markdown files are supported
@@ -184,3 +124,67 @@ build-backend = "setuptools.build_meta"
184124
requires = [
185125
"setuptools>=70.0.0",
186126
]
127+
128+
[tool.ruff]
129+
line-length = 88
130+
exclude = [
131+
"docs/build/",
132+
"node_modules/",
133+
".eggs/",
134+
"versioneer.py",
135+
"venv/",
136+
".venv/",
137+
".git/",
138+
".history/",
139+
]
140+
141+
[tool.ruff.lint]
142+
select = ["E", "F", "W", "C4", "S"]
143+
extend-select = [
144+
"I", # isort
145+
"C4", # https://pypi.org/project/flake8-comprehensions
146+
]
147+
extend-ignore = [
148+
"E203", # space before : (needed for how black formats slicing)
149+
"E402", # module level import not at top of file
150+
"E731", # do not assign a lambda expression, use a def
151+
"E741", # ignore not easy to read variables like i l I etc.
152+
"C406", # Unnecessary list literal - rewrite as a dict literal.
153+
"C408", # Unnecessary dict call - rewrite as a literal.
154+
"C409", # Unnecessary list passed to tuple() - rewrite as a tuple literal.
155+
"F401", # unused imports
156+
"S101", # use of assert
157+
]
158+
159+
[tool.ruff.lint.isort]
160+
known-first-party = ["pytorch_forecasting"]
161+
combine-as-imports = true
162+
force-sort-within-sections = true
163+
164+
[tool.black]
165+
line-length = 88
166+
include = '\.pyi?$'
167+
exclude = '''
168+
(
169+
/(
170+
\.eggs # exclude a few common directories in the
171+
| \.git # root of the project
172+
| \.hg
173+
| \.mypy_cache
174+
| \.tox
175+
| \.venv
176+
| _build
177+
| buck-out
178+
| build
179+
| dist
180+
)/
181+
| docs/build/
182+
| node_modules/
183+
| venve/
184+
| .venv/
185+
)
186+
'''
187+
188+
[tool.nbqa.mutate]
189+
ruff = 1
190+
black = 1

0 commit comments

Comments
 (0)