Skip to content

Commit 9cc85d9

Browse files
Improve EmbeddingBagCollection documentation (#2569)
Summary: Pull Request resolved: #2569 Make `EmbeddingBagCollection` documentation more explicit and precise by defining its output by math rather than example. Improve example by showing how the data is organized, using spatial formatting. Reviewed By: colin2328 Differential Revision: D66149087 fbshipit-source-id: 1686037b553ba8ccaa4c6be0495ab8f2f1969cea
1 parent 4b5a8a3 commit 9cc85d9

File tree

1 file changed

+32
-19
lines changed

1 file changed

+32
-19
lines changed

torchrec/modules/embedding_modules.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,25 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface):
102102
For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingBagCollection.
103103
104104
105-
It processes sparse data in the form of `KeyedJaggedTensor` with values of the form
106-
[F X B X L] where:
105+
It is callable on arguments representing sparse data in the form of `KeyedJaggedTensor` with values of the shape
106+
`(F, B, L_{f,i})` where:
107107
108-
* F: features (keys)
109-
* B: batch size
110-
* L: length of sparse features (jagged)
108+
* `F`: number of features (keys)
109+
* `B`: batch size
110+
* `L_{f,i}`: length of sparse features (potentially distinct for each feature `f` and batch index `i`, that is, jagged)
111111
112-
and outputs a `KeyedTensor` with values of the form [B * (F * D)] where:
112+
and outputs a `KeyedTensor` with values of shape `(B, D)` where:
113113
114-
* F: features (keys)
115-
* D: each feature's (key's) embedding dimension
116-
* B: batch size
114+
* `B`: batch size
115+
* `D`: sum of embedding dimensions of all embedding tables, that is, `sum([config.embedding_dim for config in tables])`
116+
117+
Assuming the argument is a `KeyedJaggedTensor` `J` with `F` features, batch size `B` and `L_{f,i}` sparse lengths
118+
such that `J[f][i]` is the bag for feature `f` and batch index `i`, the output `KeyedTensor` `KT` is defined as follows:
119+
`KT[i]` = `torch.cat([emb[f](J[f][i]) for f in J.keys()])` where `emb[f]` is the `EmbeddingBag` corresponding to the feature `f`.
120+
121+
Note that `J[f][i]` is a variable-length list of integer values (a bag), and `emb[f](J[f][i])` is the pooled embedding
122+
produced by reducing the embeddings of each of the values in `J[f][i]`
123+
using the `EmbeddingBag` `emb[f]`'s mode (default is the mean).
117124
118125
Args:
119126
tables (List[EmbeddingBagConfig]): list of embedding tables.
@@ -131,28 +138,34 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface):
131138
132139
ebc = EmbeddingBagCollection(tables=[table_0, table_1])
133140
134-
# 0 1 2 <-- batch
135-
# "f1" [0,1] None [2]
136-
# "f2" [3] [4] [5,6,7]
141+
# i = 0 i = 1 i = 2 <-- batch indices
142+
# "f1" [0,1] None [2]
143+
# "f2" [3] [4] [5,6,7]
137144
# ^
138-
# feature
145+
# features
139146
140147
features = KeyedJaggedTensor(
141148
keys=["f1", "f2"],
142-
values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
143-
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
149+
values=torch.tensor([0, 1, 2, # feature 'f1'
150+
3, 4, 5, 6, 7]), # feature 'f2'
151+
# i = 0 i = 1 i = 2 <--- batch indices
152+
offsets=torch.tensor([
153+
0, 2, 2, # 'f1' bags are values[0:2], values[2:2], and values[2:3]
154+
3, 4, 5, 8]), # 'f2' bags are values[3:4], values[4:5], and values[5:8]
144155
)
145156
146157
pooled_embeddings = ebc(features)
147158
print(pooled_embeddings.values())
148-
tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
149-
[ 0.0000, 0.0000, 0.0000, 0.1598, 0.0695, 1.3265, -0.1011],
150-
[-0.4256, -1.1846, -2.1648, -1.0893, 0.3590, -1.9784, -0.7681]],
159+
tensor([
160+
# f1 pooled embeddings from bags (dim 3) f2 pooled embeddings from bags (dim 4)
161+
[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783], # batch index 0
162+
[ 0.0000, 0.0000, 0.0000, 0.1598, 0.0695, 1.3265, -0.1011], # batch index 1
163+
[-0.4256, -1.1846, -2.1648, -1.0893, 0.3590, -1.9784, -0.7681]], # batch index 2
151164
grad_fn=<CatBackward0>)
152165
print(pooled_embeddings.keys())
153166
['f1', 'f2']
154167
print(pooled_embeddings.offset_per_key())
155-
tensor([0, 3, 7])
168+
tensor([0, 3, 7]) # embeddings have dimensions 3 and 4, so embeddings are at [0, 3) and [3, 7).
156169
"""
157170

158171
def __init__(

0 commit comments

Comments
 (0)