@@ -268,3 +268,149 @@ def forward(
        # Squash tile dimension back into sequence dimension - tensor has shape [b, s, n_h, h_d]
        x_out = x_out.reshape(bsz, self.max_num_tiles * seq_len, n_h, h_d)
        return x_out.type_as(x)
+
+
+class FireSelfAttention(nn.Module):
+    """
+    This class implements FIRE (Functional Interpolation for Relative Positional Encodings)
+    as described in https://arxiv.org/abs/2310.04418 for causal language modeling tasks. The
+    only modification from the paper is that this implementation uses the GELU activation
+    function instead of ReLU in order to avoid possible problems with "dying" neurons.
+
+    Args:
+        dim_model (int): The embedding dimension of the input vectors.
+        num_heads (int): The number of self-attention heads, set to 1 by default. The dimension
+            of each individual head is computed as ``dim_model // num_heads``.
+        hidden_size (int): The dimension of the MLP layers in each attention head used to
+            compute the bias matrix.
+
+    Note: This module is fundamentally a positional encoding scheme; however, due to the nature
+    of FIRE relative positional encodings, it takes the form of an attention layer.
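+
+    Example (a minimal usage sketch; the sizes and shapes below are illustrative only):
+        >>> fire_attn = FireSelfAttention(dim_model=64, num_heads=4, hidden_size=32)
+        >>> src = torch.randn(2, 10, 64)
+        >>> fire_attn(src).shape
+        torch.Size([2, 10, 64])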
+    """
+
+    def __init__(
+        self, dim_model: int, num_heads: int = 1, hidden_size: int = 32
+    ) -> None:
+        super().__init__()
+
+        # make sure num_heads divides dim_model
+        assert (
+            dim_model % num_heads == 0
+        ), "Number of heads must divide dimension of model"
+
+        # compute kdim = vdim
+        kdim = dim_model // num_heads
+
+        # initialize attention heads
+        self.attention_heads = nn.ModuleList(
+            [
+                self.FireAttentionHead(dim_model, kdim, hidden_size)
+                for _ in range(num_heads)
+            ]
+        )
+
+        # final linear layer
+        self.W_o = nn.Linear(dim_model, dim_model, bias=False)
+
+    class FireAttentionHead(nn.Module):
+        """
+        An inner class that implements a single attention head using the FIRE positional
+        encoding scheme. **Do not** use this class directly; instead use FireSelfAttention
+        with ``num_heads = 1`` if you need it.
+
+        Args:
+            dim_model (int): The embedding dimension of the input vectors, as above.
+            kdim (int): The dimension of the query, key, and value vectors, computed as
+                ``kdim = dim_model // num_heads``.
+            hidden_size (int): The dimension of the MLP layers used to compute the bias matrix.
+        """
+
+        def __init__(self, dim_model: int, kdim: int, hidden_size: int) -> None:
+            super().__init__()
+            self.kdim = kdim
+
+            # initialize parameter matrices
+            self.W_q = nn.Linear(dim_model, kdim, bias=False)
+            self.W_k = nn.Linear(dim_model, kdim, bias=False)
+            self.W_v = nn.Linear(dim_model, kdim, bias=False)
+
+            # initialize learnable scalars to "reasonable" values (these are arbitrary and can be adjusted later on)
+            # c is used to modify the input of the logarithm in the phi function.
+            self.c = nn.Parameter(torch.tensor(1.0))
+            # L is used in the adaptive thresholding mechanism to activate progressive interpolation only for long contexts.
+            self.L = nn.Parameter(torch.tensor(2.0))
+
+            # initialize learnable continuous function
+            self.f_theta = nn.Sequential(
+                nn.Linear(1, hidden_size),
+                nn.GELU(),
+                nn.Linear(hidden_size, hidden_size),
+                nn.GELU(),
+                nn.Linear(hidden_size, 1),
+            )
+
+        # concave function to amplify differences among local positions
+        def phi(self, c: torch.Tensor, x: int | torch.Tensor) -> torch.Tensor:
+            return torch.log1p(c * x)
+
+        def forward(self, src: torch.Tensor) -> torch.Tensor:
+            """
+            Args:
+                src (torch.Tensor): Input tensor with shape ``[batch_size, seq_length, dim_model]``
+
+            Returns:
+                torch.Tensor: Output tensor of shape ``[batch_size, seq_length, kdim]``
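+
+            Note (summary of the computation below, stated here for clarity):
+                Each bias entry with ``j < i`` is computed as
+                ``bias[i, j] = phi(c, i - j) / phi(c, max(L, i + 1))`` with
+                ``phi(c, x) = log(1 + c * x)``, and the full bias matrix is then passed
+                elementwise through the MLP ``f_theta`` before the causal mask is applied.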
+            """
+            # Assuming src has shape (batch_size, seq_length, dim_model)
+            batch_size, seq_length = src.shape[0:2]
+
+            # constrain c to be > 0
+            c = torch.nn.functional.softplus(self.c)
+
+            # compute bias matrix
+            # below, i is the query position and j is the key position, 1 <= i - j <= i
+            bias = torch.zeros(seq_length, seq_length, device=src.device)
+            for i in range(1, seq_length):
+                for j in range(0, i):
+                    # we have to use i + 1 in the denominator to compensate for 0-based indexing
+                    bias[i, j] = self.phi(c, i - j) / self.phi(
+                        c, torch.maximum(self.L, torch.tensor(i + 1, device=src.device))
+                    )
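+            # NOTE (assumed optimization, not part of the original logic): this Python double
+            # loop could be vectorized, e.g. with idx = torch.arange(seq_length, device=src.device),
+            # rel = (idx[:, None] - idx[None, :]).clamp(min=0), and computing
+            # self.phi(c, rel) / self.phi(c, torch.maximum(self.L, (idx + 1).float()))[:, None].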
+            # apply MLP to bias matrix
+            bias = self.f_theta(bias.unsqueeze(2)).squeeze(2)
+            # add causal mask
+            lookahead_mask = torch.ones(
+                seq_length, seq_length, dtype=torch.bool, device=src.device
+            ).triu(diagonal=1)
+            bias.masked_fill_(lookahead_mask, float("-inf"))
+            # repeat bias matrix for batch_size
+            bias = bias.repeat(batch_size, 1, 1)
+
+            # get Query, Key, and Value matrices for each sequence
+            q = self.W_q(src)
+            k = self.W_k(src)
+            v = self.W_v(src)
+
+            # calculate attention scores
+            k_t = torch.transpose(k, 1, 2)
+            attn_logits = torch.bmm(q, k_t) / (self.kdim**0.5)
+            attn_logits = attn_logits + bias
+            attn_weights = torch.nn.functional.softmax(attn_logits, dim=-1)
+            attn_outputs = torch.bmm(attn_weights, v)
+            return attn_outputs
+
+    # End of the inner class for a single attention head
+
+    def forward(self, src: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            src (torch.Tensor): Input tensor with shape ``[batch_size, seq_length, dim_model]``
+
+        Returns:
+            torch.Tensor: Output tensor of shape ``[batch_size, seq_length, dim_model]`` with
+            multi-head attention and FIRE relative positional encoding applied.
+        """
+        # src should have shape (batch_size, seq_length, dim_model)
+        # Pass src through the attention heads
+        attn_results = [attn_head(src) for attn_head in self.attention_heads]
+        # concatenate results
+        attn_results = torch.cat(attn_results, dim=-1)
+        # pass through final linear layer
+        return self.W_o(attn_results)