Skip to content

Commit 1331d57

Browse files
d-schmittfnhirwa
andauthored
[MNT] update nbeats/sub_modules.py to remove overhead in tensor creation (#1580)
This commit removes a warning raised by nbeats/sub_module.py: `UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor.` This commit replaces the relevant sections with the appropriate code to remove the warning in the nbeats submodule. (similar to #754) Co-authored-by: Felix Hirwa Nshuti <[email protected]>
1 parent 18153aa commit 1331d57

File tree

1 file changed

+9
-20
lines changed

1 file changed

+9
-20
lines changed

pytorch_forecasting/models/nbeats/sub_modules.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -111,33 +111,21 @@ def __init__(
111111
else (thetas_dim // 2, thetas_dim // 2 + 1)
112112
)
113113
s1_b = torch.tensor(
114-
[
115-
np.cos(2 * np.pi * i * backcast_linspace)
116-
for i in self.get_frequencies(p1)
117-
],
114+
np.cos(2 * np.pi * self.get_frequencies(p1)[:, None] * backcast_linspace),
118115
dtype=torch.float32,
119116
) # H/2-1
120117
s2_b = torch.tensor(
121-
[
122-
np.sin(2 * np.pi * i * backcast_linspace)
123-
for i in self.get_frequencies(p2)
124-
],
118+
np.sin(2 * np.pi * self.get_frequencies(p2)[:, None] * backcast_linspace),
125119
dtype=torch.float32,
126120
)
127121
self.register_buffer("S_backcast", torch.cat([s1_b, s2_b]))
128122

129123
s1_f = torch.tensor(
130-
[
131-
np.cos(2 * np.pi * i * forecast_linspace)
132-
for i in self.get_frequencies(p1)
133-
],
124+
np.cos(2 * np.pi * self.get_frequencies(p1)[:, None] * forecast_linspace),
134125
dtype=torch.float32,
135126
) # H/2-1
136127
s2_f = torch.tensor(
137-
[
138-
np.sin(2 * np.pi * i * forecast_linspace)
139-
for i in self.get_frequencies(p2)
140-
],
128+
np.sin(2 * np.pi * self.get_frequencies(p2)[:, None] * forecast_linspace),
141129
dtype=torch.float32,
142130
)
143131
self.register_buffer("S_forecast", torch.cat([s1_f, s2_f]))
@@ -183,14 +171,15 @@ def __init__(
183171
norm = np.sqrt(
184172
forecast_length / thetas_dim
185173
) # ensure range of predictions is comparable to input
186-
174+
thetas_dims_range = np.array(range(thetas_dim))
187175
coefficients = torch.tensor(
188-
[backcast_linspace**i for i in range(thetas_dim)], dtype=torch.float32
176+
backcast_linspace ** thetas_dims_range[:, None],
177+
dtype=torch.float32,
189178
)
190179
self.register_buffer("T_backcast", coefficients * norm)
191-
192180
coefficients = torch.tensor(
193-
[forecast_linspace**i for i in range(thetas_dim)], dtype=torch.float32
181+
forecast_linspace ** thetas_dims_range[:, None],
182+
dtype=torch.float32,
194183
)
195184
self.register_buffer("T_forecast", coefficients * norm)
196185

0 commit comments

Comments
 (0)