attention cache #13

Open · wants to merge 1 commit into base: main

2 changes: 1 addition & 1 deletion lean_transformer/__init__.py
@@ -1,5 +1,5 @@
from .ffn import LeanFFN
from .attn import LeanSelfAttention, SimpleAttentionCore, RotaryAttentionCore
from .attn import LeanSelfAttention, SimpleAttentionCore, RotaryAttentionCore, AttentionCache, CreateAttentionCache
from .rotary import RotaryEmbeddings, rotate
from .sequence import SequentialWithKwargs, ReversibleWithKwargs, ActiveKwargs
from .config import LeanTransformerConfig
60 changes: 46 additions & 14 deletions lean_transformer/attn.py
@@ -7,6 +7,20 @@

from lean_transformer.rotary import RotaryEmbeddings

from dataclasses import dataclass


@dataclass
class AttentionCache:
"""Per-layer key and value buffers used for incremental decoding."""
key: torch.Tensor
value: torch.Tensor

def CreateAttentionCache(batch_size, num_heads, seq_length, head_size, device):
# keys are stored pre-transposed as [batch, heads, head_size, seq_length] so they can be
# matmul'd with queries directly; values are stored as [batch, heads, seq_length, head_size]
cache = AttentionCache(
key=torch.zeros((batch_size, num_heads, head_size, seq_length), device=device),
value=torch.zeros((batch_size, num_heads, seq_length, head_size), device=device),
)
return cache

class LeanSelfAttention(nn.Module):
def __init__(
@@ -63,12 +77,16 @@ def __init__(
self.output_dropout = nn.Dropout(dropout, inplace=False)
self.residual, self.checkpoint_attention_core = residual, checkpoint_attention_core

def forward(self, hidden_states, attention_mask=None, output_attentions=False):
hidden_states_ln = self.pre_layer_norm(hidden_states)
qkv_output = self.qkv_proj(hidden_states_ln)
def forward(self, hidden_states, attention_mask=None, output_attentions=False, attn_cache=None, seq_index=0):
# attn_cache maps each LeanSelfAttention module to its own AttentionCache entry
attn_local_cache = attn_cache[self] if attn_cache else None
hidden_states_ln = self.pre_layer_norm(hidden_states)
qkv_output = self.qkv_proj(hidden_states_ln)
query, key, value = qkv_output.split(self.hidden_size, dim=qkv_output.ndim - 1)
attention_output, attention_probs = self._maybe_checkpoint(
self.attention_core, query, key, value, attention_mask
self.attention_core, query, key, value, attention_mask, attn_local_cache, seq_index
)
outputs = self.out_proj(attention_output)
if self.post_layer_norm:
@@ -90,7 +108,7 @@ def __init__(self, hidden_size: int, num_attention_heads: int, attention_probs_d
self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads
self.attention_head_size = hidden_size // num_attention_heads

def forward(self, query, key, value, attention_mask):
def forward(self, query, key, value, attention_mask, attn_cache=None, seq_index=0):
"""
:param query: [batch_size, query_seq_len, hidden_size]
:param key: [batch_size, kv_seq_len, hidden_size]
@@ -105,7 +123,7 @@ def forward(self, query, key, value, attention_mask):
assert torch.is_floating_point(attention_mask), "expected float mask with negative values for masked items"
return self._attention_core_forward(
query, key, value, attention_mask, self.num_attention_heads, self.attention_dropout.p,
self.training, scale_inplace=False,
self.training, scale_inplace=False, attn_cache=attn_cache, seq_index=seq_index
)

@staticmethod
@@ -114,7 +132,12 @@ def _attention_core_forward(
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
num_attention_heads: int, attention_dropout: float, training: bool, scale_inplace: bool
num_attention_heads: int,
attention_dropout: float,
training: bool,
scale_inplace: bool,
attn_cache: Optional[AttentionCache] = None,
seq_index: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
# transpose from [batch, seq_length, full_hid_size] to [batch, num_heads, seq_length, head_size]
new_query_shape = query.shape[:-1] + (num_attention_heads, -1)
@@ -128,8 +151,17 @@ def _attention_core_forward(
value = value.view(new_kv_shape).permute(0, 2, 1, 3)
del key # not to confuse with key_transposed

if attn_cache is not None:
# write the single new position into the cache, then attend over all cached positions
attn_cache.key[:, :, :, seq_index] = key_transposed_scaled[:, :, :, 0]
attn_cache.value[:, :, seq_index, :] = value[:, :, 0, :]
key_ref = attn_cache.key[:, :, :, :seq_index + 1]
value_ref = attn_cache.value[:, :, :seq_index + 1, :]
else:
key_ref = key_transposed_scaled
value_ref = value

# Take the dot product between "query" and "key" to get the raw attention scores
attention_scores = torch.matmul(query, key_transposed_scaled)
attention_scores = torch.matmul(query, key_ref)

if attention_mask is not None:
attention_scores += attention_mask
@@ -143,7 +175,7 @@
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = torch.dropout_(attention_probs, attention_dropout, training)

attention_output = torch.matmul(attention_probs, value)
attention_output = torch.matmul(attention_probs, value_ref)
attention_output = attention_output.transpose(2, 1).flatten(2)

return attention_output, attention_probs
@@ -160,12 +192,12 @@ def __init__(
rotary_emb = RotaryEmbeddings(self.attention_head_size)
self.rotary_emb = rotary_emb

def rotate(self, tensor: torch.Tensor):
def rotate(self, tensor: torch.Tensor, seq_index: int = 0):
""":param tensor: query or key, shape: [batch_size, query_seq_len, hidden_size]"""
tensor_split_heads = tensor.view(*(tensor.shape[:-1] + (self.num_attention_heads, self.attention_head_size)))
return self.rotary_emb(tensor_split_heads).view(*tensor.shape)
return self.rotary_emb(tensor_split_heads, offset=seq_index).view(*tensor.shape)

def forward(self, query, key, value, attention_mask):
def forward(self, query, key, value, attention_mask, attn_cache=None, seq_index=0):
return self._attention_core_forward(
self.rotate(query), self.rotate(key), value, attention_mask, self.num_attention_heads,
self.attention_dropout.p, self.training, scale_inplace=True)
self.rotate(query, seq_index=seq_index), self.rotate(key, seq_index=seq_index), value, attention_mask, self.num_attention_heads,
self.attention_dropout.p, self.training, scale_inplace=True, attn_cache=attn_cache, seq_index=seq_index)
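For orientation, here is a minimal standalone sketch of the cache layout and of one cached attention step as implemented above. The sizes are illustrative, the import assumes a checkout with this patch applied, and the `new_key_t` / `new_value` tensors stand in for the projections of a single new token.

import torch
from lean_transformer import CreateAttentionCache

batch_size, num_heads, max_len, head_size = 2, 4, 16, 8
device = torch.device("cpu")

# One AttentionCache per attention layer: keys are kept pre-transposed as
# [batch, heads, head_size, seq] and values as [batch, heads, seq, head_size].
cache = CreateAttentionCache(batch_size, num_heads, max_len, head_size, device)

# The projections of a single new token at position seq_index are written
# exactly the way _attention_core_forward does it above.
seq_index = 3
new_key_t = torch.randn(batch_size, num_heads, head_size, 1, device=device)
new_value = torch.randn(batch_size, num_heads, 1, head_size, device=device)
cache.key[:, :, :, seq_index] = new_key_t[:, :, :, 0]
cache.value[:, :, seq_index, :] = new_value[:, :, 0, :]

# Attention for the new token then runs against the first seq_index + 1 cached positions.
query = torch.randn(batch_size, num_heads, 1, head_size, device=device)
key_ref = cache.key[:, :, :, :seq_index + 1]      # [batch, heads, head_size, seq_index + 1]
value_ref = cache.value[:, :, :seq_index + 1, :]  # [batch, heads, seq_index + 1, head_size]
scores = torch.matmul(query, key_ref)             # [batch, heads, 1, seq_index + 1]
output = torch.matmul(scores.softmax(dim=-1), value_ref)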
74 changes: 72 additions & 2 deletions lean_transformer/models/gpt.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch GPT modules that do not hog your GPU memory """
from typing import Optional
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
@@ -228,6 +228,75 @@ def set_output_embeddings(self, new_lm_head: nn.Linear):

def _init_weights(self, module: nn.Module):
return self.config.init_weights(module)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, attn_cache=None, **kwargs):
# the full sequence is passed on every step; with attn_cache set, the transformer
# only runs the newest position and reads earlier positions from the cache
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"attn_cache": attn_cache,
}

def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
# 1. retrieve all kwargs that are non-None or non-model input related.
# some encoder-decoder models have different names for model and encoder
if (
self.config.is_encoder_decoder
and hasattr(self, "encoder")
and self.encoder.main_input_name != self.main_input_name
):
input_name = self.encoder.main_input_name
else:
input_name = self.main_input_name

model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

model_kwargs["attn_cache"] = self.transformer._init_attn_cache(
inputs.shape[0],
self.config.num_attention_heads,
model_kwargs["max_len"],
self.config.hidden_size // self.config.num_attention_heads,
self.device
)

# 2. check whether model_input_name is passed as kwarg
# if yes and `inputs` is None use kwarg inputs
inputs_kwarg = model_kwargs.pop(input_name, None)
if inputs_kwarg is not None and inputs is not None:
raise ValueError(
f"`inputs`: {inputs}` were passed alongside "
f"{input_name} which is not allowed."
f"Make sure to either pass {inputs} or {input_name}=..."
)
elif inputs_kwarg is not None:
inputs = inputs_kwarg

# 3. models with `input_ids` can also make use of `inputs_embeds`
if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

# 4. Only encoder-decoder models can have non `input_ids` input format
if not self.config.is_encoder_decoder and input_name != "input_ids":
raise ValueError(
f"If {input_name} is passed as model-specific keyword "
"input then model has to be an encoder-decoder and not a "
f"{self.__class__.__name__}."
)

# 5. if `inputs` is still None, try to create `input_ids` from BOS token
if inputs is None:
inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

return inputs, input_name, model_kwargs

def forward(
self,
@@ -241,6 +310,7 @@
output_attentions=None,
output_hidden_states=None,
return_dict=None,
attn_cache=None,
):
assert head_mask is None and output_attentions is None and output_hidden_states is None, "not implemented"
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -282,7 +352,7 @@ def forward(
embedding_output = self.embeddings(
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
transformer_outputs = self.transformer(embedding_output, extended_attention_mask + causal_attention_mask)
transformer_outputs = self.transformer(embedding_output, extended_attention_mask + causal_attention_mask, attn_cache=attn_cache)
lm_logits = self.lm_head(transformer_outputs.last_hidden_state)

loss = None
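To show how the pieces above fit together without relying on generate(), here is a hedged sketch of greedy decoding: `greedy_decode` is an illustrative helper, `model` is assumed to be the GPT LM-head module defined in this file (with the patch applied) and to return an output object with a `.logits` field, and `max_len` sizes the cache buffers exactly as `_init_attn_cache` expects. Because the cached transformer path only writes the last position of each call, the prompt is replayed prefix by prefix to fill the cache before new tokens are chosen.

import torch


def greedy_decode(model, input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
    """Greedy decoding with the per-layer key/value cache introduced in this PR."""
    head_size = model.config.hidden_size // model.config.num_attention_heads
    attn_cache = model.transformer._init_attn_cache(
        input_ids.shape[0], model.config.num_attention_heads, max_len, head_size, model.device
    )
    tokens = input_ids
    with torch.no_grad():
        for step in range(1, max_len):
            # each call caches position step - 1 and attends over positions 0 .. step - 1
            outputs = model(input_ids=tokens[:, :step], attn_cache=attn_cache)
            if step < input_ids.shape[1]:
                continue  # still warming the cache on the prompt
            next_token = outputs.logits[:, -1:].argmax(dim=-1)
            tokens = torch.cat([tokens, next_token], dim=1)
    return tokens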
25 changes: 20 additions & 5 deletions lean_transformer/transformer.py
@@ -1,15 +1,16 @@
from functools import partial
from typing import Tuple, Optional, Union

import torch
from torch import nn as nn
from torch.autograd.graph import saved_tensors_hooks
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

from lean_transformer import LeanFFN, LeanSelfAttention
from lean_transformer.config import LeanTransformerConfig
from lean_transformer.sequence import ActiveKwargs, ReversibleWithKwargs, SequentialWithKwargs
from lean_transformer.blocksparse import GeneralizedMatrix
from lean_transformer import AttentionCache, CreateAttentionCache


class LeanTransformer(nn.Module):
@@ -38,7 +39,7 @@ def _get_sequential(self):
for i in range(self.config.num_hidden_layers):
group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
for layer in self.layer_groups[group_idx].layers:
sequence.append(ActiveKwargs(layer.attention, ("attention_mask",), use_first_output=True))
sequence.append(ActiveKwargs(layer.attention, ("attention_mask", "attn_cache", "seq_index"), use_first_output=True))
sequence.append(ActiveKwargs(layer.ffn, active_keys=()))
sequential_cls = ReversibleWithKwargs if self.config.reversible else SequentialWithKwargs
self._sequential = (sequential_cls(*sequence),)
@@ -75,14 +76,28 @@ def _make_ffn(self, index: int, config: LeanTransformerConfig):
residual=not config.reversible,
)

def forward(self, hidden_states, attention_mask=None):
def _init_attn_cache(self, batch_size, num_heads, seq_length, head_size, device):
# one key/value buffer per attention layer, keyed by the LeanSelfAttention module
# itself so that each layer can look up its own entry in forward
cache = {m: CreateAttentionCache(batch_size, num_heads, seq_length, head_size, device)
for m in self.modules() if isinstance(m, LeanSelfAttention)}
return cache

def forward(self, hidden_states, attention_mask=None, attn_cache=None):
"""
:param hidden_states: input embeddings, batch-first (e.g. [batch_size, seq_length, hidden-size])
:param attention_mask: an additive mask with zeros for active elements and large negative values for masked
:param attn_cache: attention cache holding the keys and values of previously processed tokens
"""
hidden_states = self._get_sequential()(hidden_states, attention_mask=attention_mask)
if attn_cache:
seq_index = hidden_states.size(1) - 1
indices = torch.tensor([seq_index], device=hidden_states.device)
# only the newest position is processed; earlier positions are read from the cache
hidden_state_last = torch.index_select(hidden_states, 1, indices)
attention_mask_last = torch.index_select(attention_mask, 2, indices)
hidden_states = self._get_sequential()(
hidden_state_last, attention_mask=attention_mask_last,
attn_cache=attn_cache, seq_index=seq_index)
else:
hidden_states = self._get_sequential()(hidden_states, attention_mask=attention_mask, attn_cache=attn_cache)
return BaseModelOutput(last_hidden_state=self.post_layer_norm(hidden_states))

def init_weights(self):
self.apply(self.config.init_weights)

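Finally, at the LeanTransformer level, a minimal sketch of one cached step: `cached_step` is an illustrative helper, the cache is assumed to come from `_init_attn_cache` above, and `embeddings` holds the input embeddings of every token produced so far. The all-zero additive mask simply leaves every cached position visible, which matches the causal setting where the newest token may attend to all earlier positions.

import torch

from lean_transformer.transformer import LeanTransformer


def cached_step(model: LeanTransformer, embeddings: torch.Tensor, attn_cache: dict) -> torch.Tensor:
    """Run one incremental step; embeddings is [batch, seq_so_far, hidden_size].

    The forward above writes position seq_so_far - 1 into every layer's cache and
    returns hidden states for that single position only.
    """
    batch_size, seq_so_far, _ = embeddings.shape
    attention_mask = torch.zeros(batch_size, 1, seq_so_far, seq_so_far, device=embeddings.device)
    output = model(embeddings, attention_mask=attention_mask, attn_cache=attn_cache)
    return output.last_hidden_state  # [batch, 1, hidden_size]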