 import numpy as np
 import sys
 import csv
+import random
 
-
+import warnings
+warnings.filterwarnings("ignore")
 
 class CompositeMHA(torch.nn.Module):
     def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
         super().__init__()
@@ -25,14 +27,15 @@ def forward(self, query, key, value, mask):
             query, self.in_proj_weight, self.in_proj_bias
         )
 
-        batch_size, seq_len, embed_dim = query_projected.size()
+        batch_size = query_projected.size(0)
+        embed_dim = query_projected.size(2)
         head_dim = embed_dim // (self.num_heads * 3)
 
-        # Transpose seq_len and num_heads dim
-        query_projected = query_projected.view(
-            batch_size, seq_len, 3 * self.num_heads, head_dim
-        ).transpose(1, 2)
-        query, key, value = query_projected.chunk(3, 1)
+        query, key, value = query_projected.chunk(3, -1)
+
+        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         attn, _ = torch.nn.functional._scaled_dot_product_attention(
@@ -46,7 +49,7 @@ def forward(self, query, key, value, mask):
         )
 
         attn = attn.transpose(1, 2).reshape(
-            batch_size, seq_len, self.num_heads * head_dim
+            batch_size, -1, self.num_heads * head_dim
         )
         # Match return signature of nn.MHA
         return self.out_proj(attn), None
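
The reshape reordering above is what makes forward nested-tensor friendly: nested tensors support only a limited set of view and transpose patterns, so the old packed view into (batch, seq_len, 3 * num_heads, head_dim) followed by a chunk along the head dimension is replaced by chunking q/k/v off the last dimension first and then splitting heads per tensor, with -1 standing in for the ragged seq_len. For dense inputs the two orderings are interchangeable; a minimal standalone sketch with hypothetical sizes that checks the equivalence:

import torch

B, S, H, hd = 2, 5, 4, 8  # hypothetical batch, seq_len, num_heads, head_dim
qkv = torch.randn(B, S, 3 * H * hd)  # packed projection output: [q | k | v] on the last dim

# Old ordering: view into 3*num_heads heads, then chunk along the head dim.
q_old, k_old, v_old = qkv.view(B, S, 3 * H, hd).transpose(1, 2).chunk(3, 1)

# New ordering: chunk q/k/v off the last dim, then split heads per tensor.
q_new, k_new, v_new = [
    t.view(B, -1, H, hd).transpose(1, 2) for t in qkv.chunk(3, -1)
]

# Both yield (B, H, S, hd) and agree element-wise on dense tensors.
assert torch.equal(q_old, q_new) and torch.equal(k_old, k_new) and torch.equal(v_old, v_new)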
@@ -60,6 +63,18 @@ def build_composite_mha_from_nn_mha(pt):
     return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj)
 
 
+def generate_rand_batch(batch_size, max_sequence_len, embed_dimension, pad_percentage=None, dtype=torch.float16, device="cuda"):
+    if not pad_percentage:
+        return torch.randn(batch_size, max_sequence_len, embed_dimension, dtype=dtype, device=device), None
+    # Really slow, but it works: draw each sequence length from a Gaussian around the target padding ratio.
+    seq_len_list = [int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01))) for _ in range(batch_size)]
+    # Force one random element to max length so the batch spans the full sequence dimension.
+    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
+    # print(f"Theoretical padding: {pad_percentage} actual: {1 - (sum(seq_len_list) / (batch_size * max_sequence_len))}")
+    return torch.nested.nested_tensor([
+        torch.randn(seq_len, embed_dimension, dtype=dtype, device=device) for seq_len in seq_len_list]), seq_len_list
+
+
 def benchmark_torch_function(iters, f, *args, **kwargs):
     if f is None:
         return None
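
With pad_percentage set, generate_rand_batch returns a torch.nested.nested_tensor whose entries keep their own sequence lengths, so the attention kernels never compute over padding tokens. A small usage sketch (hypothetical sizes; assumes a CUDA device, since the default device is "cuda"):

x, lengths = generate_rand_batch(4, 64, 256, pad_percentage=0.5)
print(x.is_nested)          # True: one entry per sequence, no padding stored
print(lengths)              # e.g. [33, 31, 64, 32]; one entry is forced to max_sequence_len
print(x.unbind()[0].shape)  # torch.Size([33, 256]) for the first sequence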
@@ -75,50 +90,57 @@ def benchmark_torch_function(iters, f, *args, **kwargs):
     return (start_event.elapsed_time(end_event) * 1.0e-3) / iters
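
The hunk elides the body of benchmark_torch_function, but the return line shows the measurement is taken with CUDA events and reported as seconds per iteration. A generic sketch of that recipe, not necessarily the file's exact body:

def benchmark_torch_function_sketch(iters, f, *args, **kwargs):
    if f is None:
        return None
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    f(*args, **kwargs)  # warm-up call so one-time setup is not timed
    start_event.record()
    for _ in range(iters):
        f(*args, **kwargs)
    end_event.record()
    torch.cuda.synchronize()  # elapsed_time is only valid once both events have completed
    return (start_event.elapsed_time(end_event) * 1.0e-3) / iters  # ms -> s, per iteration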
-def run_timing(batch_size, D, H, L, writer):
-    dropout_p = 0.0
-    mask = None
-
-    pt = torch.nn.MultiheadAttention(
-        embed_dim=D, num_heads=H, batch_first=True, dropout=dropout_p
-    )
-    npt = pt.eval().half().cuda()
-    cpt = build_composite_mha_from_nn_mha(npt)
-
-    x = torch.randn(batch_size, L, D)
-    x = x.half().cuda()
-
-    pt_output, _ = pt(x, x, x, mask)
-    cp_output, _ = cpt(x, x, x, mask)
+def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len, pad_percentage, writer):
+    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True):
+        with torch.inference_mode():
+            dropout_p = 0.0
+            mask = None
 
-    # First order sanity check. Not a replacement for rigorous tests.
-    assert torch.allclose(pt_output, cp_output, atol=1e-3, rtol=1e-3)
+            pt = torch.nn.MultiheadAttention(
+                embed_dim=embed_dimension, num_heads=num_heads, batch_first=True, dropout=dropout_p
+            )
+            npt = pt.eval().half().cuda()
+            cpt = build_composite_mha_from_nn_mha(npt)
+            x, lengths = generate_rand_batch(batch_size, max_sequence_len, embed_dimension, pad_percentage)
+            pt_output, _ = pt(x, x, x, mask)
+            cpt_output, _ = cpt(x, x, x, mask)
+
+            # First order sanity check. Not a replacement for rigorous tests.
+            if pt_output.is_nested and cpt_output.is_nested:
+                for a, b in zip(pt_output.unbind(), cpt_output.unbind()):
+                    assert torch.allclose(a, b, atol=1e-3, rtol=1e-3)
+            else:
+                assert torch.allclose(pt_output, cpt_output, atol=1e-3, rtol=1e-3)
 
-    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=True):
-        with torch.inference_mode():
             pt_time = benchmark_torch_function(iters, npt, x, x, x, mask) * 1e3
             cp_time = benchmark_torch_function(iters, cpt, x, x, x, mask) * 1e3
             results = {}
-            results["L"] = L
-            results["H"] = H
-            results["D"] = D
+            results["max_sequence_len"] = max_sequence_len
+            results["num_heads"] = num_heads
+            results["embed_dimension"] = embed_dimension
             results["pt_time"] = pt_time
             results["cp_time"] = cp_time
             results["speedup"] = pt_time / cp_time
             results["dtype"] = str(x.dtype)
             writer.writerow(results)
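
run_timing now does everything under torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True): with the math fallback disabled, the dispatcher must pick the flash kernel, and inputs flash cannot handle raise instead of silently running the slow path. A hedged sketch of that behavior (hypothetical shapes; flash attention at the time of this change required half-precision inputs):

q = k = v = torch.randn(2, 8, 64, 32, device="cuda")  # float32 on purpose
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True):
    try:
        torch.nn.functional._scaled_dot_product_attention(q, k, v)
    except RuntimeError as err:
        print("no eligible kernel:", err)  # float32 rules out flash; math is disabled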
 
 
-if __name__ == "__main__":
+def main():
     iters = 100
     seed = 123
     np.random.seed(seed)
     torch.manual_seed(seed)
 
-    headers = ["L", "H", "D", "pt_time", "cp_time", "speedup", "dtype"]
+    headers = ["max_sequence_len", "num_heads", "embed_dimension", "pt_time", "cp_time", "speedup", "dtype"]
     writer = csv.DictWriter(sys.stdout, headers)
     writer.writeheader()
 
     batch_size = 64
-    for H, L in itertools.product([1, 2, 4, 8, 16, 32], [64, 128, 256]):
-        run_timing(batch_size, 1024, H, L, writer)
+    pad_percentage = 0.5
+
+    for num_heads, max_seq_len in itertools.product([2, 4, 8, 16, 32], [64, 128, 256]):
+        run_timing(iters, batch_size, 1024, num_heads, max_seq_len, pad_percentage, writer)
+
+
+if __name__ == "__main__":
+    main()
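
Because the CSV writer targets sys.stdout, a run can be redirected straight to a file, e.g. python sdp.py > results.csv (the script name here is assumed); the first row is the header max_sequence_len,num_heads,embed_dimension,pt_time,cp_time,speedup,dtype, and pt_time/cp_time are reported in milliseconds.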