@@ -12,12 +12,13 @@ class DeepConvexNet(DeepConvexFlow):
1212 r"""
1313 Class that takes a partially input convex neural network (picnn)
1414 as input and equips it with functions of logdet
15- computation (both estimation and exact computation)
15+ computation (both estimation and exact computation).
1616 This class is based on DeepConvexFlow of the CP-Flow
1717 repo (https://github.com/CW-Huang/CP-Flow)
1818 For details of the logdet estimator, see
1919 ``Convex potential flows: Universal probability distributions
2020 with optimal transport and convex optimization``
21+
2122 Parameters
2223 ----------
2324 picnn
@@ -94,6 +95,7 @@ class SequentialNet(SequentialFlow):
9495 layers and provides energy score computation
9596 This class is based on SequentialFlow of the CP-Flow repo
9697 (https://github.com/CW-Huang/CP-Flow)
98+
9799 Parameters
98100 ----------
99101 networks
@@ -116,6 +118,7 @@ def es_sample(self, hidden_state: torch.Tensor, dimension: int) -> torch.Tensor:
116118 """
117119 Auxiliary function for energy score computation
118120 Drawing samples conditioned on the hidden state
121+
119122 Parameters
120123 ----------
121124 hidden_state
@@ -159,6 +162,7 @@ def energy_score(
159162 h_i is the hidden state associated with z_i,
160163 and es_num_samples is the number of samples drawn
161164 for each of w, w', w'' in energy score approximation
165+
162166 Parameters
163167 ----------
164168 z
@@ -224,6 +228,7 @@ class MQF2Distribution(Distribution):
224228 Distribution class for the model MQF2 proposed in the paper
225229 ``Multivariate Quantile Function Forecaster``
226230 by Kan, Aubet, Januschowski, Park, Benidis, Ruthotto, Gasthaus
231+
227232 Parameters
228233 ----------
229234 picnn
@@ -290,6 +295,7 @@ def stack_sliding_view(self, z: torch.Tensor) -> torch.Tensor:
290295 over the observations z
291296 Then, reshapes the observations into a 2-dimensional tensor for
292297 further computation
298+
293299 Parameters
294300 ----------
295301 z
@@ -317,6 +323,7 @@ def log_prob(self, z: torch.Tensor) -> torch.Tensor:
317323 """
318324 Computes the log likelihood log(g(z)) + logdet(dg(z)/dz),
319325 where g is the gradient of the picnn
326+
320327 Parameters
321328 ----------
322329 z
@@ -346,6 +353,7 @@ def energy_score(self, z: torch.Tensor) -> torch.Tensor:
346353 h_i is the hidden state associated with z_i,
347354 and es_num_samples is the number of samples drawn
348355 for each of w, w', w'' in energy score approximation
356+
349357 Parameters
350358 ----------
351359 z
@@ -370,14 +378,15 @@ def energy_score(self, z: torch.Tensor) -> torch.Tensor:
370378 def rsample (self , sample_shape : torch .Size = torch .Size ()) -> torch .Tensor :
371379 """
372380 Generates the sample paths
381+
373382 Parameters
374383 ----------
375384 sample_shape
376385 Shape of the samples
377386 Returns
378387 -------
379388 sample_paths
380- Tesnor of shape (batch_size, *sample_shape, prediction_length)
389+ Tesnor of shape (batch_size, * sample_shape, prediction_length)
381390 """
382391
383392 numel_batch = self .numel_batch
@@ -407,6 +416,7 @@ def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
407416 def quantile (self , alpha : torch .Tensor , hidden_state : Optional [torch .Tensor ] = None ) -> torch .Tensor :
408417 """
409418 Generates the predicted paths associated with the quantile levels alpha
419+
410420 Parameters
411421 ----------
412422 alpha
0 commit comments