@@ -277,14 +277,17 @@ class FireSelfAttention(nn.Module):
     only modification from the paper is that this implementation uses the GELU activation function instead
     of ReLU in order to avoid possible problems with "dying" neurons.
 
+    This module is fundamentally a positional encoding scheme; however, due to the nature of FIRE relative
+    positional encodings, it takes the form of an attention layer.
+
     Args:
         dim_model (int): The embedding dimension of the input vectors.
         num_heads (int): The number of self-attention heads, set to 1 by default. The dimension of each individual head
            is usually computed as ``dim_model // num_heads``.
         hidden_size (int): The dimension of the MLP layers in each attention head used to compute the bias matrix.
 
-    Note: This module is fundamentally a positional encoding scheme; however, due to the nature of FIRE relative
-    positional encodings, it takes the form of an attention layer.
+    Raises:
+        ValueError: If num_heads does not divide dim_model
     """
 
     def __init__(
@@ -293,9 +296,8 @@ def __init__(
         super().__init__()
 
         # make sure num_heads divides dim_model:
-        assert (
-            dim_model % num_heads == 0
-        ), "Number of heads must divide dimension of model"
+        if dim_model % num_heads != 0:
+            raise ValueError("Number of heads must divide dimension of model")
 
         # compute kdim = vdim
         kdim = dim_model // num_heads
@@ -406,6 +408,22 @@ def forward(self, src: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: Output tensor of shape ``[batch_size, seq_length, dim_model]`` with multi-head attention
             and FIRE relative positional encoding applied.
+
+        Example:
+
+            >>> import torch
+            >>> from torchtune.modules import FireSelfAttention
+            >>>
+            >>> # instantiate module
+            >>> test_layer = FireSelfAttention(dim_model=512, num_heads=8, hidden_size=32)
+            >>>
+            >>> # input tensor; FireSelfAttention expects input of shape (batch_size, seq_length, dim_model)
+            >>> x = torch.randn(64, 20, 512)
+            >>>
+            >>> # get output of attention layer with FIRE positional encoding
+            >>> y = test_layer(x)
+            >>> print(y.shape)
+            torch.Size([64, 20, 512])
         """
         # src should have shape (batch_size, seq_length, dim_model)
         # Pass src through the attention heads
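A minimal sketch of the validation path introduced above, assuming the constructor signature and import path shown in the docstring example; constructing the module with a head count that does not divide the model dimension should now raise ValueError rather than trip an assert:

    from torchtune.modules import FireSelfAttention  # import path as used in the docstring example

    try:
        # 512 is not divisible by 7, so __init__ is expected to raise
        FireSelfAttention(dim_model=512, num_heads=7, hidden_size=32)
    except ValueError as err:
        print(err)  # Number of heads must divide dimension of model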