|
| 1 | +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +""" |
| 16 | +FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default. |
| 17 | +Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask. |
| 18 | +""" |
| 19 | + |
| 20 | + |
| 21 | +import os |
| 22 | + |
| 23 | +import math |
| 24 | +from typing import Optional, Tuple |
| 25 | +import mindspore |
| 26 | +from mindspore.ops import flash_attention_score |
| 27 | +from mindspore import nn |
| 28 | +from mindnlp.core import ops |
| 29 | + |
| 30 | + |
| 31 | +# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default. |
| 32 | +# Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask. |
| 33 | +TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2 |
| 34 | +DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3 |
| 35 | + |
| 36 | +SPARSE_MODE = int(os.getenv("NPU_FA2_SPARSE_MODE", default=str(DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE))) |
| 37 | +if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE]: |
| 38 | + raise ValueError( |
| 39 | + "Environment variable `NPU_FA2_SPARSE_MODE` can only be set as 2 (top-left aligned causal mask) " |
| 40 | + "or 3 (down-right aligned causal mask)." |
| 41 | + ) |
| 42 | + |
| 43 | + |
| 44 | +def is_npu_fa2_top_left_aligned_causal_mask(): |
| 45 | + return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE |
| 46 | + |
| 47 | + |
| 48 | +class IndexFirstAxis(nn.Cell): |
| 49 | + def __init__(self): |
| 50 | + super(IndexFirstAxis, self).__init__() |
| 51 | + |
| 52 | + def construct(self, input: mindspore.Tensor, indices: mindspore.Tensor): |
| 53 | + assert input.ndim >= 2 |
| 54 | + first_axis_dim, other_shape = input.shape[0], input.shape[1:] |
| 55 | + input_flat = input.reshape(first_axis_dim, -1) |
| 56 | + indices_expanded = ops.expand_dims(indices, -1) |
| 57 | + indices_expanded = ops.broadcast_to(indices_expanded, (-1, input_flat.shape[1])) |
| 58 | + output_flat = ops.gather(input_flat, 0, indices_expanded) |
| 59 | + output = output_flat.reshape(-1, *other_shape) |
| 60 | + return output |
| 61 | + |
| 62 | + def bprop(self, input, indices, out, dout): |
| 63 | + assert dout.ndim >= 2 |
| 64 | + other_shape = dout.shape[1:] |
| 65 | + grad_output = dout |
| 66 | + |
| 67 | + grad_flat = grad_output.reshape(grad_output.shape[0], -1) |
| 68 | + grad_shape = (input.shape[0], grad_flat.shape[1]) |
| 69 | + grad_input = ops.zeros(grad_shape, grad_flat.dtype) |
| 70 | + |
| 71 | + indices_expanded = ops.expand_dims(indices, -1) |
| 72 | + indices_expanded = ops.broadcast_to(indices_expanded, (-1, grad_flat.shape[1])) |
| 73 | + grad_input.scatter_(0, indices_expanded, grad_flat) |
| 74 | + |
| 75 | + return grad_input.reshape(input.shape[0], *other_shape), None |
| 76 | + |
| 77 | + |
| 78 | +index_first_axis = IndexFirstAxis() |
| 79 | + |
| 80 | + |
| 81 | +class IndexPutFirstAxis(nn.Cell): |
| 82 | + def __init__(self): |
| 83 | + super(IndexPutFirstAxis, self).__init__() |
| 84 | + |
| 85 | + def construct(self, values: mindspore.Tensor, indices: mindspore.Tensor, first_axis_dim: int): |
| 86 | + assert indices.ndim == 1 |
| 87 | + assert values.ndim >= 2 |
| 88 | + output = ops.zeros( |
| 89 | + (first_axis_dim, *values.shape[1:]), |
| 90 | + values.dtype |
| 91 | + ) |
| 92 | + output[indices] = values |
| 93 | + return output |
| 94 | + |
| 95 | + def bprop(self, values, indices, first_axis_dim, out, dout): |
| 96 | + grad_values = dout[indices] |
| 97 | + return grad_values, None, None |
| 98 | + |
| 99 | + |
| 100 | +index_put_first_axis = IndexPutFirstAxis() |
| 101 | + |
| 102 | + |
| 103 | +def pad_input( |
| 104 | + hidden_states: mindspore.Tensor, |
| 105 | + indices: mindspore.Tensor, |
| 106 | + batch: int, |
| 107 | + seqlen: int |
| 108 | +): |
| 109 | + """ |
| 110 | + Arguments: |
| 111 | + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. |
| 112 | + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. |
| 113 | + batch: int, batch size for the padded sequence. |
| 114 | + seqlen: int, maximum sequence length for the padded sequence. |
| 115 | + Return: |
| 116 | + hidden_states: (batch, seqlen, ...) |
| 117 | + """ |
| 118 | + output = index_put_first_axis(hidden_states, indices, batch * seqlen) |
| 119 | + return output.reshape(batch, seqlen, *hidden_states.shape[1:]) |
| 120 | + |
| 121 | + |
| 122 | +def unpad_input( |
| 123 | + hidden_states: mindspore.Tensor, |
| 124 | + attention_mask: mindspore.Tensor, |
| 125 | + unused_mask: Optional[mindspore.Tensor] = None, |
| 126 | +): |
| 127 | + """ |
| 128 | + Arguments: |
| 129 | + hidden_states: (batch, seqlen, ...) |
| 130 | + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. |
| 131 | + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. |
| 132 | + Return: |
| 133 | + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. |
| 134 | + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. |
| 135 | + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. |
| 136 | + max_seqlen_in_batch: int |
| 137 | + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. |
| 138 | + """ |
| 139 | + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask |
| 140 | + seqlens_in_batch = all_masks.sum(dim=-1, dtype=mindspore.int32) |
| 141 | + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=mindspore.int32) |
| 142 | + indices = ops.nonzero(all_masks.flatten(), as_tuple=False).flatten() |
| 143 | + max_seqlen_in_batch = seqlens_in_batch.max().item() |
| 144 | + cu_seqlens = ops.pad(ops.cumsum(seqlens_in_batch, dim=0, dtype=mindspore.int32), (1, 0)) |
| 145 | + |
| 146 | + hidden_states_flat = hidden_states.reshape(-1, *hidden_states.shape[2:]) |
| 147 | + hidden_states = index_first_axis(hidden_states_flat, indices) |
| 148 | + return ( |
| 149 | + hidden_states, |
| 150 | + indices, |
| 151 | + cu_seqlens, |
| 152 | + max_seqlen_in_batch, |
| 153 | + used_seqlens_in_batch, |
| 154 | + ) |
| 155 | + |
| 156 | + |
| 157 | +def create_attn_mask(causal: bool, sparse_mode: int) -> Tuple[int, mindspore.Tensor]: |
| 158 | + """ |
| 159 | + Create a causal mask for the attention scores. |
| 160 | +
|
| 161 | + Args: |
| 162 | + causal (`bool`): |
| 163 | + If `True`, the mask will be causal. |
| 164 | + sparse_mode (`bool`): |
| 165 | + If `True`, the mask will be top-left |
| 166 | + aligned, otherwise it will be bottom-right aligned. |
| 167 | + Returns: |
| 168 | + `Tuple[bool, mindspore.Tensor]`: |
| 169 | + A tuple containing sparse_mode and the mask tensor. |
| 170 | + """ |
| 171 | + if not causal: |
| 172 | + sparse_mode = 0 |
| 173 | + attn_mask = None |
| 174 | + else: |
| 175 | + if sparse_mode == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE: |
| 176 | + attn_mask = ops.tril(ops.ones((2048, 2048)), diagonal=-1).bool() |
| 177 | + else: |
| 178 | + attn_mask = ops.triu(ops.ones((2048, 2048)), diagonal=1).bool() |
| 179 | + return sparse_mode, attn_mask |
| 180 | + |
| 181 | + |
| 182 | +def npu_flash_attn_func( |
| 183 | + q: mindspore.Tensor, |
| 184 | + k: mindspore.Tensor, |
| 185 | + v: mindspore.Tensor, |
| 186 | + dropout_p: float = 0.0, |
| 187 | + softmax_scale: Optional[float] = None, |
| 188 | + causal: bool = False, |
| 189 | + **kwargs, |
| 190 | +): |
| 191 | + head_num = q.shape[2] |
| 192 | + sparse_mode, attn_mask = create_attn_mask(causal, SPARSE_MODE) |
| 193 | + if softmax_scale is None: |
| 194 | + softmax_scale = 1.0 / math.sqrt(q.shape[-1]) |
| 195 | + output = flash_attention_score( |
| 196 | + q, |
| 197 | + k, |
| 198 | + v, |
| 199 | + head_num, |
| 200 | + keep_prob=1.0 - dropout_p, |
| 201 | + scalar_value=softmax_scale, |
| 202 | + attn_mask=attn_mask, |
| 203 | + input_layout="BSND", |
| 204 | + sparse_mode=sparse_mode, |
| 205 | + prefix=None, |
| 206 | + ) |
| 207 | + |
| 208 | + return output |
| 209 | + |
| 210 | + |
| 211 | +def npu_flash_attn_varlen_func( |
| 212 | + q: mindspore.Tensor, |
| 213 | + k: mindspore.Tensor, |
| 214 | + v: mindspore.Tensor, |
| 215 | + cu_seqlens_q: Optional[mindspore.Tensor] = None, |
| 216 | + cu_seqlens_k: Optional[mindspore.Tensor] = None, |
| 217 | + dropout_p: float = 0.0, |
| 218 | + softmax_scale: Optional[float] = None, |
| 219 | + causal: bool = False, |
| 220 | + **kwargs, |
| 221 | +): |
| 222 | + head_num = q.shape[1] |
| 223 | + sparse_mode, attn_mask = create_attn_mask(causal, SPARSE_MODE) |
| 224 | + if softmax_scale is None: |
| 225 | + softmax_scale = 1.0 / math.sqrt(q.shape[-1]) |
| 226 | + |
| 227 | + output = flash_attention_score( |
| 228 | + q, |
| 229 | + k, |
| 230 | + v, |
| 231 | + head_num, |
| 232 | + keep_prob=1.0 - dropout_p, |
| 233 | + scalar_value=softmax_scale, |
| 234 | + attn_mask=attn_mask, |
| 235 | + input_layout="TND", |
| 236 | + actual_seq_qlen=cu_seqlens_q[1:].asnumpy().tolist(), |
| 237 | + actual_seq_kvlen=cu_seqlens_k[1:].asnumpy().tolist(), |
| 238 | + sparse_mode=sparse_mode, |
| 239 | + prefix=None, |
| 240 | + ) |
| 241 | + |
| 242 | + return output |
0 commit comments