forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflashpy_xformers-0.0.23.rocm.patch
152 lines (145 loc) · 5.56 KB
/
flashpy_xformers-0.0.23.rocm.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
--- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
+++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
@@ -36,44 +36,44 @@
FLASH_VERSION = "0.0.0"
try:
- try:
- from ... import _C_flashattention # type: ignore[attr-defined]
- from ..._cpp_lib import _build_metadata
-
- if _build_metadata is not None:
- FLASH_VERSION = _build_metadata.flash_version
- except ImportError:
- import flash_attn
- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
- FLASH_VERSION = flash_attn.__version__
- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
- if (
- flash_ver_parsed != (2, 3, 6)
- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
- ):
- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
+ #try:
+ # from ... import _C_flashattention # type: ignore[attr-defined]
+ # from ..._cpp_lib import _build_metadata
+
+ # if _build_metadata is not None:
+ # FLASH_VERSION = _build_metadata.flash_version
+ #except ImportError:
+ import flash_attn
+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+ FLASH_VERSION = flash_attn.__version__
+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+ # if (
+ # flash_ver_parsed != (2, 3, 6)
+ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+ # ):
+ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
# create library so that flash-attn goes through the PyTorch Dispatcher
- _flash_lib = torch.library.Library("xformers_flash", "DEF")
-
- _flash_lib.define(
- "flash_fwd(Tensor query, Tensor key, Tensor value, "
- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, "
- "bool is_causal, int window_left, "
- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
- )
+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
- _flash_lib.define(
- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, bool is_causal, "
- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
- )
+ #_flash_lib.define(
+ # "flash_fwd(Tensor query, Tensor key, Tensor value, "
+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, "
+ # "bool is_causal, int window_left, "
+ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+ #)
+
+ #_flash_lib.define(
+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, bool is_causal, "
+ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+ #)
def _flash_fwd(
query,
@@ -111,8 +111,8 @@
p,
softmax_scale,
is_causal,
- window_left, # window_size_left
- window_right, # window_size_right
+ # window_left, # window_size_left
+ # window_right, # window_size_right
return_softmax,
None, # rng
)
@@ -134,15 +134,15 @@
out,
cu_seq_lens_q,
cu_seq_lens_k,
- seqused_k,
+ # seqused_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
False,
is_causal,
- window_left,
- window_right,
+ # window_left,
+ # window_right,
return_softmax,
None,
)
@@ -184,8 +184,8 @@
p,
softmax_scale,
is_causal,
- window_left,
- window_right,
+ # window_left,
+ # window_right,
None,
rng_state,
)
@@ -208,15 +208,15 @@
softmax_scale,
False, # zero_tensors
is_causal,
- window_left,
- window_right,
+ # window_left,
+ # window_right,
None,
rng_state,
)
return dq, dk, dv
- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
pass
@@ -400,7 +400,7 @@
implementation.
"""
- OPERATOR = get_operator("xformers_flash", "flash_fwd")
+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
SUPPORTED_DEVICES: Set[str] = {"cuda"}
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}