
Commit 814b71b

[Feat] Add WhisperFlashAttention2 (#2018)
1 parent b771d0a commit 814b71b

8 files changed: +769 -38 lines changed

mindnlp/core/nn/modules/conv.py (+12 -35)
@@ -3,6 +3,8 @@
 import math
 from typing import Optional, Tuple, Union, List
 from mindspore import Tensor, ops as mops
+from mindspore.ops.auto_generate.gen_ops_prim import conv1d_ext_op, conv1d_padding_op
+from mindspore.ops.function.nn_func import pad_ext
 from ..parameter import Parameter
 from .module import Module
 from ..common_types import _size_2_t, _size_1_t
@@ -182,43 +184,18 @@ def __init__(
         super().__init__(
             in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
             False, _single(0), groups, bias, padding_mode, **factory_kwargs)
-
-        pad_mode = 'valid'
-        pad = padding
-        if isinstance(padding, tuple):
-            if padding[0] != 0:
-                pad_mode = 'pad'
-                pad = (0, 0, padding[0], padding[0])
-        elif isinstance(padding, int):
-            if padding != 0:
-                pad_mode = 'pad'
-                pad = (0, 0) + (padding,) * 2
-        if not isinstance(padding, (int, tuple)):
-            pad_mode = padding
-            pad = (0,) * 4
-
-        if self.padding_mode != 'zeros':
-            pad_mode = 'valid'
-            pad = (0,) * 4
-        self.conv2d = mops.Conv2D(out_channel=self.out_channels,
-                                  kernel_size=(1,) + self.kernel_size,
-                                  mode=1,
-                                  pad_mode=pad_mode,
-                                  pad=pad,
-                                  stride=(1,) + self.stride,
-                                  dilation=(1,) + self.dilation,
-                                  group=self.groups)
+
+        if isinstance(padding, str) and padding_mode == "zeros":
+            self.conv1d = conv1d_padding_op
+        else:
+            self.conv1d = conv1d_ext_op

     def forward(self, input):
-        if self.padding_mode != 'zeros':
-            input = F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
-        input = input.expand_dims(2)
-        output = self.conv2d(input, self.weight.expand_dims(2))
-
-        if self.bias is not None:
-            output = mops.bias_add(output, self.bias)
-
-        output = output.squeeze(2)
+        if self.padding_mode != "zeros":
+            output = self.conv1d(pad_ext(input, self._reversed_padding, mode=self.padding_mode), self.weight,
+                                 self.bias, self.stride, (0,), self.dilation, self.groups)
+        else:
+            output = self.conv1d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return output
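For context, a minimal usage sketch of the rewritten Conv1d forward path; it assumes a MindSpore build that exposes conv1d_ext_op / conv1d_padding_op and that the patched class is importable as below, with illustrative shapes rather than values from the commit.

# Hypothetical sketch (not part of the commit): the layer now dispatches to native
# conv1d kernels instead of emulating Conv1d with a (1, kernel_size) Conv2D.
from mindspore import ops
from mindnlp.core.nn.modules.conv import Conv1d  # module path from the diff; class name assumed

x = ops.randn(2, 80, 3000)                        # (batch, in_channels, seq_len), e.g. Whisper log-mel features
conv = Conv1d(80, 384, kernel_size=3, padding=1)
y = conv(x)                                       # forward() now calls conv1d_ext_op / conv1d_padding_op
print(y.shape)                                    # (2, 384, 3000) with stride 1 and padding 1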

mindnlp/core/ops/array.py (+1 -1)

@@ -130,7 +130,7 @@ def narrow(input, dim, start, length):
 has_nonzero = hasattr(mindspore.mint, 'nonzero')
 def nonzero(input, *, as_tuple=False):
     if use_pyboost() and has_nonzero:
-        return mindspore.mint.nonzero(input, as_tuple)
+        return mindspore.mint.nonzero(input, as_tuple=as_tuple)
     _nonzero = _get_cache_prim(ops.NonZero)()
     out = _nonzero(input)
     if as_tuple:
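The one-line change above forwards as_tuple by keyword; a small sketch of the wrapper's two return conventions (module path taken from the diff, values illustrative).

# Sketch only: the wrapper keeps as_tuple keyword-only and now passes it through
# to mindspore.mint.nonzero as a keyword argument.
import mindspore
from mindnlp.core.ops.array import nonzero

x = mindspore.Tensor([[0, 1], [2, 0]])
idx = nonzero(x)                        # Tensor of shape (num_nonzero, ndim)
rows, cols = nonzero(x, as_tuple=True)  # one index tensor per dimension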

mindnlp/core/ops/other.py (+7 -0)

@@ -556,6 +556,13 @@ def einsum(equation, *operands):
     return result


+# expand_dims
+has_expand_dims = hasattr(mindspore.mint, 'expand_dims')
+def expand_dims(input, axis):
+    if use_pyboost() and has_expand_dims:
+        return mindspore.mint.expand_dims(input, axis)
+    return ops.expand_dims(input, axis)
+

 # flatten
 has_flatten = hasattr(mindspore.mint, 'flatten')
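A short sketch of the new expand_dims wrapper in use; like the neighbouring wrappers it prefers the mint kernel under pyboost and otherwise falls back to ops.expand_dims (module path taken from the diff, shapes illustrative).

# Sketch only: insert a trailing axis with the wrapper added above.
import mindspore
from mindspore import ops as msops
from mindnlp.core.ops.other import expand_dims

x = msops.ones((4, 8), mindspore.float32)
y = expand_dims(x, -1)   # shape (4, 8, 1); the FA2 helpers use it to broadcast gather indices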

mindnlp/transformers/configuration_utils.py (+1 -0)

@@ -342,6 +342,7 @@ def __init__(self, **kwargs):
 
         # Attention implementation to use, if relevant.
         self._attn_implementation_internal = kwargs.pop("attn_implementation", None)
+        self._attn_implementation_autoset = False
 
         # Drop the transformers version info
         self.transformers_version = kwargs.pop("transformers_version", None)
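A hedged illustration of how a flag like this is typically consulted; only the attribute name comes from the diff, and the helper below is hypothetical rather than the library's actual dispatch code.

# Hypothetical sketch (not from this commit): once the attention backend has been
# chosen automatically, the flag keeps it from being re-derived on later calls.
def _maybe_autoset_attn_implementation(config, requested=None):
    if getattr(config, "_attn_implementation_autoset", False):
        return config                                    # already decided, keep it
    config._attn_implementation_internal = requested or "flash_attention_2"
    config._attn_implementation_autoset = True
    return config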
New file (+242 -0)

@@ -0,0 +1,242 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
+Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask.
+"""
+
+
+import os
+
+import math
+from typing import Optional, Tuple
+import mindspore
+from mindspore.ops import flash_attention_score
+from mindspore import nn
+from mindnlp.core import ops
+
+
+# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
+# Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask.
+TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
+DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3
+
+SPARSE_MODE = int(os.getenv("NPU_FA2_SPARSE_MODE", default=str(DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE)))
+if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE]:
+    raise ValueError(
+        "Environment variable `NPU_FA2_SPARSE_MODE` can only be set as 2 (top-left aligned causal mask) "
+        "or 3 (down-right aligned causal mask)."
+    )
+
+
+def is_npu_fa2_top_left_aligned_causal_mask():
+    return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE
+
+
+class IndexFirstAxis(nn.Cell):
+    def __init__(self):
+        super(IndexFirstAxis, self).__init__()
+
+    def construct(self, input: mindspore.Tensor, indices: mindspore.Tensor):
+        assert input.ndim >= 2
+        first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        input_flat = input.reshape(first_axis_dim, -1)
+        indices_expanded = ops.expand_dims(indices, -1)
+        indices_expanded = ops.broadcast_to(indices_expanded, (-1, input_flat.shape[1]))
+        output_flat = ops.gather(input_flat, 0, indices_expanded)
+        output = output_flat.reshape(-1, *other_shape)
+        return output
+
+    def bprop(self, input, indices, out, dout):
+        assert dout.ndim >= 2
+        other_shape = dout.shape[1:]
+        grad_output = dout
+
+        grad_flat = grad_output.reshape(grad_output.shape[0], -1)
+        grad_shape = (input.shape[0], grad_flat.shape[1])
+        grad_input = ops.zeros(grad_shape, grad_flat.dtype)
+
+        indices_expanded = ops.expand_dims(indices, -1)
+        indices_expanded = ops.broadcast_to(indices_expanded, (-1, grad_flat.shape[1]))
+        grad_input.scatter_(0, indices_expanded, grad_flat)
+
+        return grad_input.reshape(input.shape[0], *other_shape), None
+
+
+index_first_axis = IndexFirstAxis()
+
+
+class IndexPutFirstAxis(nn.Cell):
+    def __init__(self):
+        super(IndexPutFirstAxis, self).__init__()
+
+    def construct(self, values: mindspore.Tensor, indices: mindspore.Tensor, first_axis_dim: int):
+        assert indices.ndim == 1
+        assert values.ndim >= 2
+        output = ops.zeros(
+            (first_axis_dim, *values.shape[1:]),
+            values.dtype
+        )
+        output[indices] = values
+        return output
+
+    def bprop(self, values, indices, first_axis_dim, out, dout):
+        grad_values = dout[indices]
+        return grad_values, None, None
+
+
+index_put_first_axis = IndexPutFirstAxis()
+
+
+def pad_input(
+    hidden_states: mindspore.Tensor,
+    indices: mindspore.Tensor,
+    batch: int,
+    seqlen: int
+):
+    """
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
+    return output.reshape(batch, seqlen, *hidden_states.shape[1:])
+
+
+def unpad_input(
+    hidden_states: mindspore.Tensor,
+    attention_mask: mindspore.Tensor,
+    unused_mask: Optional[mindspore.Tensor] = None,
+):
+    """
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+    """
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=mindspore.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=mindspore.int32)
+    indices = ops.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = ops.pad(ops.cumsum(seqlens_in_batch, dim=0, dtype=mindspore.int32), (1, 0))
+
+    hidden_states_flat = hidden_states.reshape(-1, *hidden_states.shape[2:])
+    hidden_states = index_first_axis(hidden_states_flat, indices)
+    return (
+        hidden_states,
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+        used_seqlens_in_batch,
+    )
+
+
+def create_attn_mask(causal: bool, sparse_mode: int) -> Tuple[int, mindspore.Tensor]:
+    """
+    Create a causal mask for the attention scores.
+
+    Args:
+        causal (`bool`):
+            If `True`, the mask will be causal.
+        sparse_mode (`bool`):
+            If `True`, the mask will be top-left
+            aligned, otherwise it will be bottom-right aligned.
+    Returns:
+        `Tuple[bool, mindspore.Tensor]`:
+            A tuple containing sparse_mode and the mask tensor.
+    """
+    if not causal:
+        sparse_mode = 0
+        attn_mask = None
+    else:
+        if sparse_mode == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE:
+            attn_mask = ops.tril(ops.ones((2048, 2048)), diagonal=-1).bool()
+        else:
+            attn_mask = ops.triu(ops.ones((2048, 2048)), diagonal=1).bool()
+    return sparse_mode, attn_mask
+
+
+def npu_flash_attn_func(
+    q: mindspore.Tensor,
+    k: mindspore.Tensor,
+    v: mindspore.Tensor,
+    dropout_p: float = 0.0,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    **kwargs,
+):
+    head_num = q.shape[2]
+    sparse_mode, attn_mask = create_attn_mask(causal, SPARSE_MODE)
+    if softmax_scale is None:
+        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
+    output = flash_attention_score(
+        q,
+        k,
+        v,
+        head_num,
+        keep_prob=1.0 - dropout_p,
+        scalar_value=softmax_scale,
+        attn_mask=attn_mask,
+        input_layout="BSND",
+        sparse_mode=sparse_mode,
+        prefix=None,
+    )
+
+    return output
+
+
+def npu_flash_attn_varlen_func(
+    q: mindspore.Tensor,
+    k: mindspore.Tensor,
+    v: mindspore.Tensor,
+    cu_seqlens_q: Optional[mindspore.Tensor] = None,
+    cu_seqlens_k: Optional[mindspore.Tensor] = None,
+    dropout_p: float = 0.0,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    **kwargs,
+):
+    head_num = q.shape[1]
+    sparse_mode, attn_mask = create_attn_mask(causal, SPARSE_MODE)
+    if softmax_scale is None:
+        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
+
+    output = flash_attention_score(
+        q,
+        k,
+        v,
+        head_num,
+        keep_prob=1.0 - dropout_p,
+        scalar_value=softmax_scale,
+        attn_mask=attn_mask,
+        input_layout="TND",
+        actual_seq_qlen=cu_seqlens_q[1:].asnumpy().tolist(),
+        actual_seq_kvlen=cu_seqlens_k[1:].asnumpy().tolist(),
+        sparse_mode=sparse_mode,
+        prefix=None,
+    )
+
+    return output
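To close, a sketch of how the new helpers compose on an Ascend NPU; the function names come from the diff, but the import path is hypothetical (the new file's path is not shown above) and the shapes are illustrative.

# Hypothetical usage sketch; the module path below is an assumption, not taken from the commit.
import mindspore
from mindspore import ops as msops
from mindnlp.transformers.integrations.npu_flash_attention import (  # hypothetical path
    npu_flash_attn_func, npu_flash_attn_varlen_func, unpad_input, pad_input)

batch, seqlen, heads, head_dim = 2, 16, 8, 64
q = msops.randn(batch, seqlen, heads, head_dim, dtype=mindspore.float16)  # BSND layout
k = msops.randn(batch, seqlen, heads, head_dim, dtype=mindspore.float16)
v = msops.randn(batch, seqlen, heads, head_dim, dtype=mindspore.float16)

# Dense path: causal flash attention with the default softmax scale 1/sqrt(head_dim).
out = npu_flash_attn_func(q, k, v, causal=True)              # (batch, seqlen, heads, head_dim)

# Variable-length path: drop padded tokens, run in TND layout, then restore padding.
attention_mask = msops.ones((batch, seqlen), mindspore.int32)
q_unpad, indices, cu_seqlens, max_seqlen, _ = unpad_input(q, attention_mask)
out_unpad = npu_flash_attn_varlen_func(
    q_unpad, q_unpad, q_unpad,                               # self-attention over the unpadded tokens
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, causal=True)
out_padded = pad_input(out_unpad, indices, batch, seqlen)    # back to (batch, seqlen, heads, head_dim)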
