Commit cd635b9

Added example for FIRE positional encodings.
1 parent a06aa04 commit cd635b9

File tree

1 file changed (+23, -5)

torchtune/modules/position_embeddings.py

Lines changed: 23 additions & 5 deletions
@@ -277,14 +277,17 @@ class FireSelfAttention(nn.Module):
     only modification from the paper is that this implementation uses the GELU activation function instead
     of ReLU in order to avoid possible problems with "dying" neurons.

+    This module is fundamentally a positional encoding scheme; however, due to the nature of FIRE relative
+    positional encodings, it takes the form of an attention layer.
+
     Args:
         dim_model (int): The embedding dimension of the input vectors.
         num_heads (int): The number of self-attention heads, set to 1 by default. The dimension of each individual head
             is usually computed as ``dim_model // num_heads``.
         hidden_size (int): The dimension of the MLP layers in each attention head used to compute the bias matrix.

-    Note: This module is fundamentally a positional encoding scheme; however, due to the nature of FIRE relative
-        positional encodings, it takes the form of an attention layer.
+    Raises:
+        ValueError: If num_heads does not divide dim_model
     """

     def __init__(
@@ -293,9 +296,8 @@ def __init__(
         super().__init__()

         # make sure num_heads divides dim_model:
-        assert (
-            dim_model % num_heads == 0
-        ), "Number of heads must divide dimension of model"
+        if dim_model % num_heads != 0:
+            raise ValueError("Number of heads must divide dimension of model")

         # compute kdim = vdim
         kdim = dim_model // num_heads
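
What the new check means for callers, as a minimal sketch: the constructor signature and import path are taken from the docstring example added in the next hunk, and the error message from this hunk. Raising ValueError rather than asserting also keeps the check active when Python runs with -O, which strips assert statements.

    from torchtune.modules import FireSelfAttention

    # A head count that divides dim_model constructs normally.
    layer = FireSelfAttention(dim_model=512, num_heads=8, hidden_size=32)

    # A head count that does not divide dim_model now raises ValueError
    # (previously an AssertionError, which `python -O` would have skipped).
    try:
        FireSelfAttention(dim_model=512, num_heads=7, hidden_size=32)
    except ValueError as err:
        print(err)  # Number of heads must divide dimension of model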
@@ -406,6 +408,22 @@ def forward(self, src: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: Output tensor of shape ``[batch_size, seq_length, dim_model]`` with multi-head attention
                 and FIRE relative positional encoding applied.
+
+        Example:
+
+            >>> import torch
+            >>> from torchtune.modules import FireSelfAttention
+            >>>
+            >>> # instantiate module
+            >>> test_layer = FireSelfAttention(dim_model=512, num_heads=8, hidden_size=32)
+            >>>
+            >>> # input tensor; FireSelfAttention expects a format of (batch_size, seq_len, dim_model)
+            >>> x = torch.randn(64, 20, 512)
+            >>>
+            >>> # get output of attention layer with FIRE positional encoding
+            >>> y = test_layer(x)
+            >>> print(y.shape)
+            torch.Size([64, 20, 512])
         """
         # src should have shape (batch_size, seq_length, dim_model)
         # Pass src through the attention heads
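
As a quick local check of the new Example block, the standard-library doctest runner can be pointed at this module; a minimal sketch, assuming torch and torchtune are importable in the current environment and that the module path matches the file in this diff.

    import doctest

    import torchtune.modules.position_embeddings as pe

    # Runs every docstring example in the module, including the new FireSelfAttention one.
    results = doctest.testmod(pe, verbose=False)
    print(f"{results.attempted} examples attempted, {results.failed} failed")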
