23
23
def _cross_squared_distance_matrix (x : TensorLike , y : TensorLike ) -> tf .Tensor :
24
24
"""Pairwise squared distance between two (batch) matrices' rows (2nd dim).
25
25
26
- Computes the pairwise distances between rows of x and rows of y
26
+ Computes the pairwise distances between rows of x and rows of y.
27
+
27
28
Args:
28
- x: [batch_size, n, d] float `Tensor`
29
- y: [batch_size, m, d] float `Tensor`
29
+ x: ` [batch_size, n, d]` float `Tensor`.
30
+ y: ` [batch_size, m, d]` float `Tensor`.
30
31
31
32
Returns:
32
- squared_dists: [batch_size, n, m] float `Tensor`, where
33
- squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
33
+ squared_dists: ` [batch_size, n, m]` float `Tensor`, where
34
+ ` squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2`.
34
35
"""
35
36
x_norm_squared = tf .reduce_sum (tf .square (x ), 2 )
36
37
y_norm_squared = tf .reduce_sum (tf .square (y ), 2 )
@@ -52,14 +53,14 @@ def _pairwise_squared_distance_matrix(x: TensorLike) -> tf.Tensor:
52
53
"""Pairwise squared distance among a (batch) matrix's rows (2nd dim).
53
54
54
55
This saves a bit of computation vs. using
55
- _cross_squared_distance_matrix(x,x)
56
+ ` _cross_squared_distance_matrix(x, x)`
56
57
57
58
Args:
58
- x: `[batch_size, n, d]` float `Tensor`
59
+ x: `[batch_size, n, d]` float `Tensor`.
59
60
60
61
Returns:
61
62
squared_dists: `[batch_size, n, n]` float `Tensor`, where
62
- squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2
63
+ ` squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2`.
63
64
"""
64
65
65
66
x_x_transpose = tf .matmul (x , x , adjoint_b = True )
@@ -83,17 +84,17 @@ def _solve_interpolation(
83
84
order : int ,
84
85
regularization_weight : FloatTensorLike ,
85
86
) -> TensorLike :
86
- """Solve for interpolation coefficients.
87
+ r """Solve for interpolation coefficients.
87
88
88
89
Computes the coefficients of the polyharmonic interpolant for the
89
- 'training' data defined by (train_points, train_values) using the kernel
90
- phi.
90
+ 'training' data defined by ` (train_points, train_values)` using the kernel
91
+ $\ phi$ .
91
92
92
93
Args:
93
- train_points: `[b, n, d]` interpolation centers
94
- train_values: `[b, n, k]` function values
95
- order: order of the interpolation
96
- regularization_weight: weight to place on smoothness regularization term
94
+ train_points: `[b, n, d]` interpolation centers.
95
+ train_values: `[b, n, k]` function values.
96
+ order: order of the interpolation.
97
+ regularization_weight: weight to place on smoothness regularization term.
97
98
98
99
Returns:
99
100
w: `[b, n, k]` weights on each interpolation center
@@ -173,15 +174,15 @@ def _apply_interpolation(
173
174
interpolated function values at query_points.
174
175
175
176
Args:
176
- query_points: `[b, m, d]` x values to evaluate the interpolation at
177
+ query_points: `[b, m, d]` x values to evaluate the interpolation at.
177
178
train_points: `[b, n, d]` x values that act as the interpolation centers
178
- ( the c variables in the wikipedia article)
179
- w: `[b, n, k]` weights on each interpolation center
180
- v: `[b, d, k]` weights on each input dimension
181
- order: order of the interpolation
179
+ ( the c variables in the wikipedia article).
180
+ w: `[b, n, k]` weights on each interpolation center.
181
+ v: `[b, d, k]` weights on each input dimension.
182
+ order: order of the interpolation.
182
183
183
184
Returns:
184
- Polyharmonic interpolation evaluated at points defined in query_points.
185
+ Polyharmonic interpolation evaluated at points defined in ` query_points` .
185
186
"""
186
187
187
188
# First, compute the contribution from the rbf term.
@@ -207,11 +208,11 @@ def _phi(r: FloatTensorLike, order: int) -> FloatTensorLike:
207
208
See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
208
209
209
210
Args:
210
- r: input op
211
- order: interpolation order
211
+ r: input op.
212
+ order: interpolation order.
212
213
213
214
Returns:
214
- phi_k evaluated coordinate-wise on r , for k = r
215
+ ` phi_k` evaluated coordinate-wise on `r` , for ` k = r`.
215
216
"""
216
217
217
218
# using EPSILON prevents log(0), sqrt0), etc.
0 commit comments