Skip to content

Commit 68faeff

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
InteractionFeatures input transform (#2560)
Summary: Pull Request resolved: #2560 InteractionFeatures input transform to compute first-order interactions between inputs. Used for feature importance work in conjunction with (warped) linear models. Reviewed By: sdaulton Differential Revision: D63673008 fbshipit-source-id: 1e57431b92f55cf25b711d5a35b8606f77a58c69
1 parent 6d327b9 commit 68faeff

File tree

3 files changed

+80
-1
lines changed

3 files changed

+80
-1
lines changed

botorch/models/transforms/input.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import torch
2525
from botorch.exceptions.errors import BotorchTensorDimensionError
2626
from botorch.exceptions.warnings import UserInputWarning
27-
from botorch.models.transforms.utils import subset_transform
27+
from botorch.models.transforms.utils import interaction_features, subset_transform
2828
from botorch.models.utils import fantasize
2929
from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE
3030
from gpytorch import Module as GPyTorchModule
@@ -1370,6 +1370,30 @@ def transform(self, X: Tensor) -> Tensor:
13701370
return appended_X.view(*X.shape[:-2], -1, appended_X.shape[-1])
13711371

13721372

1373+
class InteractionFeatures(AppendFeatures):
1374+
r"""A transform that appends the first-order interaction terms $x_i * x_j, i < j$,
1375+
for all or a subset of the input variables."""
1376+
1377+
def __init__(
1378+
self,
1379+
indices: Optional[list[int]] = None,
1380+
) -> None:
1381+
r"""Initializes the InteractionFeatures transform.
1382+
1383+
Args:
1384+
indices: Indices of the subset of dimensions to compute interaction
1385+
features on.
1386+
"""
1387+
1388+
super().__init__(
1389+
f=interaction_features,
1390+
indices=indices,
1391+
transform_on_train=True,
1392+
transform_on_eval=True,
1393+
transform_on_fantasize=True,
1394+
)
1395+
1396+
13731397
class FilterFeatures(InputTransform, Module):
13741398
r"""A transform that filters the input with a given set of features indices.
13751399

botorch/models/transforms/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,18 @@ def f(self, X: Tensor) -> Tensor:
126126
return Y
127127

128128
return f
129+
130+
131+
def interaction_features(X: Tensor) -> Tensor:
132+
"""Computes the interaction features between the inputs.
133+
134+
Args:
135+
X: A `batch_shape x q x d`-dim tensor of inputs.
136+
indices: The input dimensions to generate interaction features for.
137+
138+
Returns:
139+
A `n x q x 1 x (d * (d-1) / 2))`-dim tensor of interaction features.
140+
"""
141+
dim = X.shape[-1]
142+
row_idcs, col_idcs = torch.triu_indices(dim, dim, offset=1)
143+
return (X.unsqueeze(-1) @ X.unsqueeze(-2))[..., row_idcs, col_idcs].unsqueeze(-2)

test/models/transforms/test_input.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
InputPerturbation,
2121
InputStandardize,
2222
InputTransform,
23+
InteractionFeatures,
2324
Log10,
2425
Normalize,
2526
OneHotToNumeric,
@@ -1629,6 +1630,45 @@ def f2(x: Tensor, n_f: int = 1) -> Tensor:
16291630
self.assertEqual(X_transformed.shape, torch.Size((10, 4)))
16301631

16311632

1633+
class TestInteractionFeatures(BotorchTestCase):
1634+
def test_interaction_features(self) -> None:
1635+
interaction = InteractionFeatures()
1636+
X = torch.arange(6, dtype=torch.float).reshape(2, 3)
1637+
X_tf = interaction(X)
1638+
self.assertTrue(X_tf.shape, torch.Size([2, 6]))
1639+
1640+
# test correct output values
1641+
self.assertTrue(
1642+
torch.equal(
1643+
X_tf,
1644+
torch.tensor(
1645+
[[0.0, 1.0, 2.0, 0.0, 0.0, 2.0], [3.0, 4.0, 5.0, 12.0, 15.0, 20.0]]
1646+
),
1647+
)
1648+
)
1649+
X = torch.arange(6, dtype=torch.float).reshape(2, 3)
1650+
interaction = InteractionFeatures(indices=[1, 2])
1651+
X_tf = interaction(X)
1652+
self.assertTrue(
1653+
torch.equal(
1654+
X_tf,
1655+
torch.tensor([[0.0, 1.0, 2.0, 2.0], [3.0, 4.0, 5.0, 20.0]]),
1656+
)
1657+
)
1658+
with self.assertRaisesRegex(
1659+
IndexError, "index 2 is out of bounds for dimension 0 with size 2"
1660+
):
1661+
interaction(torch.rand(4, 2))
1662+
1663+
# test batched evaluation
1664+
interaction = InteractionFeatures()
1665+
X_tf = interaction(torch.rand(4, 2, 4))
1666+
self.assertTrue(X_tf.shape, torch.Size([4, 2, 10]))
1667+
1668+
X_tf = interaction(torch.rand(5, 7, 3, 4))
1669+
self.assertTrue(X_tf.shape, torch.Size([5, 7, 3, 10]))
1670+
1671+
16321672
class TestFilterFeatures(BotorchTestCase):
16331673
def test_filter_features(self) -> None:
16341674
with self.assertRaises(ValueError):

0 commit comments

Comments
 (0)