Commit a06aa04

Implemented FIRE positional encoding module, added test, updated docs.

1 parent 7f3e70e commit a06aa04


4 files changed: 179 additions & 0 deletions


docs/source/api_ref_modules.rst

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ Modeling Components and Building Blocks
     VisionTransformer
     LayerDropout
     prepare_layer_dropout
+    FireSelfAttention

 Losses
 ------

tests/torchtune/modules/test_position_embeddings.py

Lines changed: 30 additions & 0 deletions
@@ -11,6 +11,7 @@
 from torch import tensor

 from torchtune.modules.position_embeddings import (
+    FireSelfAttention,
     RotaryPositionalEmbeddings,
     VisionRotaryPositionalEmbeddings,
 )
@@ -198,3 +199,32 @@ def test_rope_init_meta_device(self, input_params):
         meta_rope.rope_init()
         for p1, p2 in zip(rope_on_device.buffers(), meta_rope.buffers()):
             torch.testing.assert_close(p1, p2)
+
+
+class TestFireSelfAttention:
+    """
+    Tests for the FIRE positional embeddings module. As far as I am aware, there is no
+    open-source reference implementation available beyond the mathematical description of
+    the algorithm in the paper (https://arxiv.org/abs/2310.04418), and since the module
+    contains learnable weights, evaluating it against a reference would require actually
+    training it (as others have done; see e.g. https://arxiv.org/abs/2402.09371).
+    For now, these tests only check the format and content of the output to ensure the
+    module can be used safely within a transformer model without breaking anything.
+    """
+
+    # @mps_ignored_test()
+    def test_format(self):
+        # instantiate the module
+        test_layer = FireSelfAttention(dim_model=512, num_heads=8, hidden_size=32)
+        # input tensor; FireSelfAttention expects a shape of (batch_size, seq_len, dim_model)
+        x = torch.randn(64, 20, 512)
+        # get the output
+        y = test_layer(x)
+
+        # validate the output
+        # make sure it has the right shape
+        assert y.shape == x.shape
+
+        # The input tensor was not all zeros, so if the module is working properly (even
+        # though it has not been trained yet) the output should differ from the input.
+        assert not torch.equal(y, x)
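
Since the new layer applies a causal mask internally (the bias for key positions j > i is set to -inf before the softmax), a natural follow-up check is that outputs at a given position do not depend on later tokens. A minimal sketch of such a test, assuming the module and export added in this commit; it is not part of the diff above:

import torch
from torchtune.modules import FireSelfAttention

def test_causality():
    # Perturbing only the last token should leave outputs at all earlier
    # positions unchanged, because masked positions receive zero attention weight.
    torch.manual_seed(0)
    layer = FireSelfAttention(dim_model=64, num_heads=4, hidden_size=16)
    x = torch.randn(2, 10, 64)
    x_perturbed = x.clone()
    x_perturbed[:, -1, :] += 1.0  # modify only the last position
    y, y_perturbed = layer(x), layer(x_perturbed)
    # every position except the last should be unaffected
    torch.testing.assert_close(y[:, :-1, :], y_perturbed[:, :-1, :])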

torchtune/modules/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 from .layer_norm import Fp32LayerNorm  # noqa
 from .low_precision import FrozenNF4Linear  # noqa
 from .position_embeddings import (  # noqa
+    FireSelfAttention,
     RotaryPositionalEmbeddings,
     VisionRotaryPositionalEmbeddings,
 )
@@ -57,4 +58,5 @@
     "disable_kv_cache",
     "LayerDropout",
     "prepare_layer_dropout",
+    "FireSelfAttention",
 ]

torchtune/modules/position_embeddings.py

Lines changed: 146 additions & 0 deletions
@@ -268,3 +268,149 @@ def forward(
         # Squash tile dimension back into sequence dimension - tensor has shape [b, s, n_h, h_d]
         x_out = x_out.reshape(bsz, self.max_num_tiles * seq_len, n_h, h_d)
         return x_out.type_as(x)
+
+
+class FireSelfAttention(nn.Module):
+    """
+    This class implements FIRE (Functional Interpolation for Relative Positional Encodings)
+    as described in https://arxiv.org/abs/2310.04418 for causal language modeling tasks. The
+    only modification from the paper is that this implementation uses the GELU activation
+    function instead of ReLU in order to avoid possible problems with "dying" neurons.
+
+    Args:
+        dim_model (int): The embedding dimension of the input vectors.
+        num_heads (int): The number of self-attention heads, 1 by default. The dimension of
+            each individual head is computed as ``dim_model // num_heads``.
+        hidden_size (int): The hidden dimension of the MLP used in each attention head to compute the bias matrix.
+
+    Note: This module is fundamentally a positional encoding scheme; however, due to the nature of FIRE
+        relative positional encodings, it takes the form of an attention layer.
+    """
+
+    def __init__(
+        self, dim_model: int, num_heads: int = 1, hidden_size: int = 32
+    ) -> None:
+        super().__init__()
+
+        # make sure num_heads divides dim_model
+        assert (
+            dim_model % num_heads == 0
+        ), "Number of heads must divide dimension of model"
+
+        # compute kdim = vdim
+        kdim = dim_model // num_heads
+
+        # initialize attention heads
+        self.attention_heads = nn.ModuleList(
+            [
+                self.FireAttentionHead(dim_model, kdim, hidden_size)
+                for _ in range(num_heads)
+            ]
+        )
+
+        # final linear output projection
+        self.W_o = nn.Linear(dim_model, dim_model, bias=False)
+
+    class FireAttentionHead(nn.Module):
+        """
+        An inner class implementing a single attention head with the FIRE positional encoding scheme.
+        **Do not** use this class directly; use FireSelfAttention with ``num_heads = 1`` instead.
+
+        Args:
+            dim_model (int): The embedding dimension of the input vectors, as above.
+            kdim (int): The dimension of the query, key, and value vectors, computed as ``dim_model // num_heads``.
+            hidden_size (int): The hidden dimension of the MLP used to compute the bias matrix.
+        """
+
+        def __init__(self, dim_model: int, kdim: int, hidden_size: int) -> None:
+            super().__init__()
+            self.kdim = kdim
+
+            # initialize projection matrices
+            self.W_q = nn.Linear(dim_model, kdim, bias=False)
+            self.W_k = nn.Linear(dim_model, kdim, bias=False)
+            self.W_v = nn.Linear(dim_model, kdim, bias=False)
+
+            # initialize learnable scalars to "reasonable" values (these are arbitrary and can be adjusted later on).
+            # c scales the input of the logarithm in the phi function.
+            self.c = nn.Parameter(torch.tensor(1.0))
+            # L is used in the adaptive thresholding mechanism to activate progressive interpolation only for long contexts.
+            self.L = nn.Parameter(torch.tensor(2.0))
+
+            # initialize the learnable continuous function f_theta
+            self.f_theta = nn.Sequential(
+                nn.Linear(1, hidden_size),
+                nn.GELU(),
+                nn.Linear(hidden_size, hidden_size),
+                nn.GELU(),
+                nn.Linear(hidden_size, 1),
+            )
+
+        # concave function to amplify differences among local positions
+        def phi(self, c: torch.Tensor, x: int | torch.Tensor) -> torch.Tensor:
+            return torch.log1p(c * x)
+
+        def forward(self, src: torch.Tensor) -> torch.Tensor:
+            """
+            Args:
+                src (torch.Tensor): Input tensor with shape ``[batch_size, seq_length, dim_model]``
+
+            Returns:
+                torch.Tensor: Output tensor of shape ``[batch_size, seq_length, kdim]``
+            """
+            # src has shape (batch_size, seq_length, dim_model)
+            batch_size, seq_length = src.shape[0:2]
+
+            # constrain c to be > 0
+            c = torch.nn.functional.softplus(self.c)
+
+            # compute the bias matrix
+            # below, i is the query position and j is the key position, so 0 < i - j <= i
+            bias = torch.zeros(seq_length, seq_length, device=src.device)
+            for i in range(1, seq_length):
+                for j in range(0, i):
+                    # use i + 1 in the denominator to compensate for 0-based indexing
+                    bias[i, j] = self.phi(c, i - j) / self.phi(
+                        c, torch.maximum(self.L, torch.tensor(i + 1.0, device=src.device))
+                    )
+            # apply the MLP to the bias matrix
+            bias = self.f_theta(bias.unsqueeze(2)).squeeze(2)
+            # apply the causal mask
+            lookahead_mask = torch.ones(
+                seq_length, seq_length, dtype=torch.bool, device=src.device
+            ).triu(diagonal=1)
+            bias.masked_fill_(lookahead_mask, float("-inf"))
+            # repeat the bias matrix across the batch dimension
+            bias = bias.repeat(batch_size, 1, 1)
+
+            # get query, key, and value matrices for each sequence
+            q = self.W_q(src)
+            k = self.W_k(src)
+            v = self.W_v(src)
+
+            # calculate attention scores
+            k_t = torch.transpose(k, 1, 2)
+            attn_logits = torch.bmm(q, k_t) / (self.kdim**0.5)
+            attn_logits = attn_logits + bias
+            attn_weights = torch.nn.functional.softmax(attn_logits, dim=-1)
+            attn_outputs = torch.bmm(attn_weights, v)
+            return attn_outputs
+
+    # End of the inner class for a single attention head
+
+    def forward(self, src: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            src (torch.Tensor): Input tensor with shape ``[batch_size, seq_length, dim_model]``
+
+        Returns:
+            torch.Tensor: Output tensor of shape ``[batch_size, seq_length, dim_model]`` with
+            multi-head attention and FIRE relative positional encoding applied.
+        """
+        # src has shape (batch_size, seq_length, dim_model)
+        # pass src through the attention heads
+        attn_results = [attn_head(src) for attn_head in self.attention_heads]
+        # concatenate the per-head results along the feature dimension
+        attn_results = torch.cat(attn_results, dim=-1)
+        # pass through the final output projection
+        return self.W_o(attn_results)
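
For context, a minimal usage sketch of the module added above (shapes follow the docstrings; the comment restates the bias formula implemented in FireAttentionHead.forward):

import torch
from torchtune.modules import FireSelfAttention

# FIRE bias as implemented above: b[i, j] = f_theta(phi(i - j) / phi(max(i + 1, L)))
# with phi(x) = log(1 + c * x); c and L are learned per head, f_theta is a small MLP.
layer = FireSelfAttention(dim_model=512, num_heads=8, hidden_size=32)

x = torch.randn(4, 128, 512)  # (batch_size, seq_len, dim_model); dim_model must be divisible by num_heads
y = layer(x)                  # output has the same shape: (4, 128, 512)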
