@@ -12,12 +12,13 @@ class DeepConvexNet(DeepConvexFlow):
12
12
r"""
13
13
Class that takes a partially input convex neural network (picnn)
14
14
as input and equips it with functions of logdet
15
- computation (both estimation and exact computation)
15
+ computation (both estimation and exact computation).
16
16
This class is based on DeepConvexFlow of the CP-Flow
17
17
repo (https://github.com/CW-Huang/CP-Flow)
18
18
For details of the logdet estimator, see
19
19
``Convex potential flows: Universal probability distributions
20
20
with optimal transport and convex optimization``
21
+
21
22
Parameters
22
23
----------
23
24
picnn
@@ -94,6 +95,7 @@ class SequentialNet(SequentialFlow):
94
95
layers and provides energy score computation
95
96
This class is based on SequentialFlow of the CP-Flow repo
96
97
(https://github.com/CW-Huang/CP-Flow)
98
+
97
99
Parameters
98
100
----------
99
101
networks
@@ -116,6 +118,7 @@ def es_sample(self, hidden_state: torch.Tensor, dimension: int) -> torch.Tensor:
116
118
"""
117
119
Auxiliary function for energy score computation
118
120
Drawing samples conditioned on the hidden state
121
+
119
122
Parameters
120
123
----------
121
124
hidden_state
@@ -159,6 +162,7 @@ def energy_score(
159
162
h_i is the hidden state associated with z_i,
160
163
and es_num_samples is the number of samples drawn
161
164
for each of w, w', w'' in energy score approximation
165
+
162
166
Parameters
163
167
----------
164
168
z
@@ -224,6 +228,7 @@ class MQF2Distribution(Distribution):
224
228
Distribution class for the model MQF2 proposed in the paper
225
229
``Multivariate Quantile Function Forecaster``
226
230
by Kan, Aubet, Januschowski, Park, Benidis, Ruthotto, Gasthaus
231
+
227
232
Parameters
228
233
----------
229
234
picnn
@@ -290,6 +295,7 @@ def stack_sliding_view(self, z: torch.Tensor) -> torch.Tensor:
290
295
over the observations z
291
296
Then, reshapes the observations into a 2-dimensional tensor for
292
297
further computation
298
+
293
299
Parameters
294
300
----------
295
301
z
@@ -317,6 +323,7 @@ def log_prob(self, z: torch.Tensor) -> torch.Tensor:
317
323
"""
318
324
Computes the log likelihood log(g(z)) + logdet(dg(z)/dz),
319
325
where g is the gradient of the picnn
326
+
320
327
Parameters
321
328
----------
322
329
z
@@ -346,6 +353,7 @@ def energy_score(self, z: torch.Tensor) -> torch.Tensor:
346
353
h_i is the hidden state associated with z_i,
347
354
and es_num_samples is the number of samples drawn
348
355
for each of w, w', w'' in energy score approximation
356
+
349
357
Parameters
350
358
----------
351
359
z
@@ -370,14 +378,15 @@ def energy_score(self, z: torch.Tensor) -> torch.Tensor:
370
378
def rsample (self , sample_shape : torch .Size = torch .Size ()) -> torch .Tensor :
371
379
"""
372
380
Generates the sample paths
381
+
373
382
Parameters
374
383
----------
375
384
sample_shape
376
385
Shape of the samples
377
386
Returns
378
387
-------
379
388
sample_paths
380
- Tesnor of shape (batch_size, *sample_shape, prediction_length)
389
+ Tesnor of shape (batch_size, * sample_shape, prediction_length)
381
390
"""
382
391
383
392
numel_batch = self .numel_batch
@@ -407,6 +416,7 @@ def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
407
416
def quantile (self , alpha : torch .Tensor , hidden_state : Optional [torch .Tensor ] = None ) -> torch .Tensor :
408
417
"""
409
418
Generates the predicted paths associated with the quantile levels alpha
419
+
410
420
Parameters
411
421
----------
412
422
alpha
0 commit comments