diff --git a/lean_transformer/__init__.py b/lean_transformer/__init__.py
index d3d8d80..726f5b7 100644
--- a/lean_transformer/__init__.py
+++ b/lean_transformer/__init__.py
@@ -1,5 +1,5 @@
 from .ffn import LeanFFN
-from .attn import LeanSelfAttention, SimpleAttentionCore, RotaryAttentionCore
+from .attn import LeanSelfAttention, SimpleAttentionCore, RotaryAttentionCore, AttentionCache, CreateAttentionCache
 from .rotary import RotaryEmbeddings, rotate
 from .sequence import SequentialWithKwargs, ReversibleWithKwargs, ActiveKwargs
 from .config import LeanTransformerConfig
diff --git a/lean_transformer/attn.py b/lean_transformer/attn.py
index b36a614..01e4efc 100644
--- a/lean_transformer/attn.py
+++ b/lean_transformer/attn.py
@@ -7,6 +7,20 @@
 
 from lean_transformer.rotary import RotaryEmbeddings
+from dataclasses import dataclass
+
+
+@dataclass
+class AttentionCache:
+    key: torch.Tensor
+    value: torch.Tensor
+
+def CreateAttentionCache(batch_size, num_heads, seq_length, head_size, device):
+    cache = AttentionCache(
+        key=torch.zeros((batch_size, num_heads, head_size, seq_length), device=device),
+        value=torch.zeros((batch_size, num_heads, seq_length, head_size), device=device),
+    )
+    return cache
 
 
 class LeanSelfAttention(nn.Module):
     def __init__(
@@ -63,12 +77,16 @@ def __init__(
         self.output_dropout = nn.Dropout(dropout, inplace=False)
         self.residual, self.checkpoint_attention_core = residual, checkpoint_attention_core
 
-    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
-        hidden_states_ln = self.pre_layer_norm(hidden_states)
-        qkv_output = self.qkv_proj(hidden_states_ln)
+    def forward(self, hidden_states, attention_mask=None, output_attentions=False, attn_cache=None, seq_index=0):
+        if attn_cache:
+            attn_local_cache = attn_cache[self]
+        else:
+            attn_local_cache = None
+        hidden_states_ln = self.pre_layer_norm(hidden_states)
+        qkv_output = self.qkv_proj(hidden_states_ln)
         query, key, value = qkv_output.split(self.hidden_size, dim=qkv_output.ndim - 1)
         attention_output, attention_probs = self._maybe_checkpoint(
-            self.attention_core, query, key, value, attention_mask
+            self.attention_core, query, key, value, attention_mask, attn_local_cache, seq_index
         )
         outputs = self.out_proj(attention_output)
         if self.post_layer_norm:
@@ -90,7 +108,7 @@ def __init__(self, hidden_size: int, num_attention_heads: int, attention_probs_d
         self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads
         self.attention_head_size = hidden_size // num_attention_heads
 
-    def forward(self, query, key, value, attention_mask):
+    def forward(self, query, key, value, attention_mask, attn_cache=None, seq_index=0):
         """
         :param query: [batch_size, query_seq_len, hidden_size]
         :param key: [batch_size, kv_seq_len, hidden_size]
@@ -105,7 +123,7 @@ def forward(self, query, key, value, attention_mask):
         assert torch.is_floating_point(attention_mask), "expected float mask with negative values for masked items"
         return self._attention_core_forward(
             query, key, value, attention_mask, self.num_attention_heads, self.attention_dropout.p,
-            self.training, scale_inplace=False,
+            self.training, scale_inplace=False, attn_cache=attn_cache, seq_index=seq_index
         )
 
     @staticmethod
@@ -114,7 +132,12 @@ def _attention_core_forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
-        num_attention_heads: int, attention_dropout: float, training: bool, scale_inplace: bool
+        num_attention_heads: int,
+        attention_dropout: float,
+        training: bool,
+        scale_inplace: bool,
+        attn_cache: Optional[AttentionCache] = None,
+        seq_index: int = 0
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # transpose from [batch, seq_length, full_hid_size] to [batch, num_heads, seq_length, head_size]
         new_query_shape = query.shape[:-1] + (num_attention_heads, -1)
@@ -128,8 +151,17 @@ def _attention_core_forward(
         value = value.view(new_kv_shape).permute(0, 2, 1, 3)
         del key  # not to confuse with key_transposed
 
+        if attn_cache:
+            attn_cache.key[:, :, :, seq_index] = key_transposed_scaled[:, :, :, 0]
+            attn_cache.value[:, :, seq_index, :] = value[:, :, 0, :]
+            key_ref = attn_cache.key[:, :, :, :seq_index + 1]
+            value_ref = attn_cache.value[:, :, :seq_index + 1, :]
+        else:
+            key_ref = key_transposed_scaled
+            value_ref = value
+
         # Take the dot product between "query" and "key" to get the raw attention scores
-        attention_scores = torch.matmul(query, key_transposed_scaled)
+        attention_scores = torch.matmul(query, key_ref)
 
         if attention_mask is not None:
             attention_scores += attention_mask
@@ -143,7 +175,7 @@ def _attention_core_forward(
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = torch.dropout_(attention_probs, attention_dropout, training)
 
-        attention_output = torch.matmul(attention_probs, value)
+        attention_output = torch.matmul(attention_probs, value_ref)
         attention_output = attention_output.transpose(2, 1).flatten(2)
 
         return attention_output, attention_probs
@@ -160,12 +192,12 @@ def __init__(
             rotary_emb = RotaryEmbeddings(self.attention_head_size)
         self.rotary_emb = rotary_emb
 
-    def rotate(self, tensor: torch.Tensor):
+    def rotate(self, tensor: torch.Tensor, seq_index=0):
         """:param tensor: query or key, shape: [batch_size, query_seq_len, hidden_size]"""
         tensor_split_heads = tensor.view(*(tensor.shape[:-1] + (self.num_attention_heads, self.attention_head_size)))
-        return self.rotary_emb(tensor_split_heads).view(*tensor.shape)
+        return self.rotary_emb(tensor_split_heads, offset=seq_index).view(*tensor.shape)
 
-    def forward(self, query, key, value, attention_mask):
+    def forward(self, query, key, value, attention_mask, attn_cache=None, seq_index=0):
         return self._attention_core_forward(
-            self.rotate(query), self.rotate(key), value, attention_mask, self.num_attention_heads,
-            self.attention_dropout.p, self.training, scale_inplace=True)
+            self.rotate(query, seq_index=seq_index), self.rotate(key, seq_index=seq_index), value, attention_mask, self.num_attention_heads,
+            self.attention_dropout.p, self.training, scale_inplace=True, attn_cache=attn_cache, seq_index=seq_index)
diff --git a/lean_transformer/models/gpt.py b/lean_transformer/models/gpt.py
index 643e4e1..2556eef 100644
--- a/lean_transformer/models/gpt.py
+++ b/lean_transformer/models/gpt.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
"""PyTorch GPT modules that do not hog your GPU memory """ -from typing import Optional +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -228,6 +228,75 @@ def set_output_embeddings(self, new_lm_head: nn.Linear): def _init_weights(self, module: nn.Module): return self.config.init_weights(module) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, attn_cache=None, **kwargs): + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "attn_cache": attn_cache, + } + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: + """ + This function extracts the model-specific `inputs` for generation. + """ + # 1. retrieve all kwargs that are non-None or non-model input related. + # some encoder-decoder models have different names for model and encoder + if ( + self.config.is_encoder_decoder + and hasattr(self, "encoder") + and self.encoder.main_input_name != self.main_input_name + ): + input_name = self.encoder.main_input_name + else: + input_name = self.main_input_name + + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} + + model_kwargs["attn_cache"] = self.transformer._init_attn_cache( + inputs.shape[0], + self.config.num_attention_heads, + model_kwargs["max_len"], + self.config.hidden_size // self.config.num_attention_heads, + self.device + ) + + # 2. check whether model_input_name is passed as kwarg + # if yes and `inputs` is None use kwarg inputs + inputs_kwarg = model_kwargs.pop(input_name, None) + if inputs_kwarg is not None and inputs is not None: + raise ValueError( + f"`inputs`: {inputs}` were passed alongside " + f"{input_name} which is not allowed." + f"Make sure to either pass {inputs} or {input_name}=..." + ) + elif inputs_kwarg is not None: + inputs = inputs_kwarg + + # 3. models with `input_ids` can also make use of `inputs_embeds` + if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs): + inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" + + # 4. Only encoder-decoder models can have non `input_ids` input format + if not self.config.is_encoder_decoder and input_name != "input_ids": + raise ValueError( + f"If {input_name} is passed as model-specific keyword " + "input then model has to be an encoder-decoder and not a " + f"{self.__class__.__name__}." + ) + + # 5. 
+        if inputs is None:
+            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
+
+        return inputs, input_name, model_kwargs
 
     def forward(
         self,
@@ -241,6 +310,7 @@
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        attn_cache=None,
     ):
         assert head_mask is None and output_attentions is None and output_hidden_states is None, "not implemented"
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -282,7 +352,7 @@ def forward(
         embedding_output = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
         )
-        transformer_outputs = self.transformer(embedding_output, extended_attention_mask + causal_attention_mask)
+        transformer_outputs = self.transformer(embedding_output, extended_attention_mask + causal_attention_mask, attn_cache=attn_cache)
         lm_logits = self.lm_head(transformer_outputs.last_hidden_state)
 
         loss = None
diff --git a/lean_transformer/transformer.py b/lean_transformer/transformer.py
index fd66302..682a048 100644
--- a/lean_transformer/transformer.py
+++ b/lean_transformer/transformer.py
@@ -1,15 +1,16 @@
 from functools import partial
 from typing import Tuple, Optional, Union
+import torch
 
 from torch import nn as nn
 from torch.autograd.graph import saved_tensors_hooks
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput
 
-from lean_transformer import LeanFFN, LeanSelfAttention
 from lean_transformer.config import LeanTransformerConfig
 from lean_transformer.sequence import ActiveKwargs, ReversibleWithKwargs, SequentialWithKwargs
 from lean_transformer.blocksparse import GeneralizedMatrix
+from lean_transformer import LeanFFN, LeanSelfAttention, AttentionCache, CreateAttentionCache
 
 
 class LeanTransformer(nn.Module):
@@ -38,7 +39,7 @@ def _get_sequential(self):
         for i in range(self.config.num_hidden_layers):
             group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
             for layer in self.layer_groups[group_idx].layers:
-                sequence.append(ActiveKwargs(layer.attention, ("attention_mask",), use_first_output=True))
+                sequence.append(ActiveKwargs(layer.attention, ("attention_mask", "attn_cache", "seq_index"), use_first_output=True))
                 sequence.append(ActiveKwargs(layer.ffn, active_keys=()))
         sequential_cls = ReversibleWithKwargs if self.config.reversible else SequentialWithKwargs
         self._sequential = (sequential_cls(*sequence),)
@@ -75,14 +76,28 @@ def _make_ffn(self, index: int, config: LeanTransformerConfig):
             residual=not config.reversible,
         )
 
-    def forward(self, hidden_states, attention_mask=None):
+    def _init_attn_cache(self, batch_size, num_heads, seq_length, head_size, device):
+        cache = {m: CreateAttentionCache(batch_size, num_heads, seq_length, head_size, device) for m in self.modules() if isinstance(m, LeanSelfAttention)}
+        return cache
+
+    def forward(self, hidden_states, attention_mask=None, attn_cache=None):
         """
         :param hidden_states: input embeddings, batch-first (e.g. [batch_size, seq_length, hidden-size])
         :param attention_mask: an additive mask with zeros for active elements and large negative values for masked
+        :param attn_cache: attention cache with previous sequence tokens for K and V
         """
-        hidden_states = self._get_sequential()(hidden_states, attention_mask=attention_mask)
+        if attn_cache:
+            seq_length = hidden_states.size(1)
+            indices = torch.tensor([seq_length - 1], device=hidden_states.device)
+            hidden_state_last = torch.index_select(hidden_states, 1, indices)
+            attention_mask_last = torch.index_select(attention_mask, 2, indices)
+            hidden_states = self._get_sequential()(
+                hidden_state_last, attention_mask=attention_mask_last,
+                attn_cache=attn_cache, seq_index=seq_length - 1)
+        else:
+            hidden_states = self._get_sequential()(hidden_states, attention_mask=attention_mask, attn_cache=attn_cache)
         return BaseModelOutput(last_hidden_state=self.post_layer_norm(hidden_states))
 
-    
+
     def init_weights(self):
         self.apply(self.config.init_weights)
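Usage sketch (reviewer note, not part of the patch): the snippet below shows one way the cache added above could be driven for greedy decoding without going through `generate()`. The model class name `LeanGPTForPreTraining` and the `.logits` field on its output are assumptions; everything else only uses APIs introduced in this diff (`transformer._init_attn_cache` and the `attn_cache` forward argument). Because the cached branch of `LeanTransformer.forward` processes only the last input position, the prompt is replayed one token per step to warm the cache before new tokens are sampled.

import torch

from lean_transformer.models.gpt import LeanGPTForPreTraining  # hypothetical class name


@torch.no_grad()
def greedy_decode(model, input_ids: torch.LongTensor, max_len: int) -> torch.LongTensor:
    """Decode up to max_len tokens, reusing cached keys/values for all past positions."""
    device = input_ids.device
    head_size = model.config.hidden_size // model.config.num_attention_heads
    # one AttentionCache per LeanSelfAttention module, pre-allocated for max_len positions
    attn_cache = model.transformer._init_attn_cache(
        input_ids.shape[0], model.config.num_attention_heads, max_len, head_size, device
    )
    tokens = input_ids[:, :1]  # start from the first prompt token
    for step in range(max_len - 1):
        # only the last position of `tokens` is processed and written into the cache
        outputs = model(
            input_ids=tokens,
            attention_mask=torch.ones_like(tokens),
            attn_cache=attn_cache,
        )
        if step + 1 < input_ids.shape[1]:
            next_token = input_ids[:, step + 1 : step + 2]  # still replaying the prompt
        else:
            # assumes the LM head output is exposed as `.logits`
            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)
    return tokens

If generation is driven through `generate()` instead, note that `_prepare_model_inputs` above sizes the cache from a `max_len` keyword argument, so callers are expected to pass `max_len` alongside the usual generation arguments.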