
Commit d06d569

drisspg authored and pytorchmergebot committed
Update the sdp benchmark to work with nested tensors (pytorch#87215)
# Summary

Update the sdp benchmark to work with nested tensors

Pull Request resolved: pytorch#87215
Approved by: https://github.com/cpuhrsch
1 parent e8c4adf commit d06d569
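In plain terms, the benchmark now builds variable-length batches as NestedTensors (via the new `generate_rand_batch` helper) and feeds them through both the stock `nn.MultiheadAttention` and the composite scaled-dot-product-attention module. A rough, self-contained sketch of that input pattern is below; the sizes are illustrative and not taken from the commit, and it assumes a PyTorch build whose MultiheadAttention fast path accepts NestedTensor inputs (the benchmark itself runs in fp16 on CUDA):

```python
import torch

# Illustrative sizes only; the benchmark uses batch_size=64, embed_dim=1024, fp16 on CUDA.
embed_dim, num_heads = 64, 4
seq_lens = [5, 9, 7]  # ragged sequence lengths, hence a NestedTensor instead of padding

# Pack the ragged batch as a NestedTensor: one (seq_len, embed_dim) tensor per sequence.
x = torch.nested.nested_tensor([torch.randn(s, embed_dim) for s in seq_lens])

mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()
with torch.inference_mode():
    out, _ = mha(x, x, x, need_weights=False)

# The output is nested as well; unbind() recovers the per-sequence results.
for t in out.unbind():
    print(t.shape)  # (seq_len, embed_dim) for each sequence
```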

2 files changed (+57, -35 lines)

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -399,7 +399,7 @@ std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
 
   // If the physical layout of the NestedTensor's storage
   // is not: batch, {seq_len}, num_heads, head_dim then we need
-  // to call contiguous
+  // to call contiguous
   if (!is_safe_to_get_storage_as_tensor(query_impl, key_impl, value_impl)) {
     q_t = q_t.contiguous();
     k_t = k_t.contiguous();
```
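Loosely illustrating the layout concern in the comment above with plain dense tensors (the commit's check operates on the NestedTensor's underlying storage, so this is only an analogy):

```python
import torch

q = torch.randn(2, 8, 4, 16)  # batch, seq_len, num_heads, head_dim
q_t = q.transpose(1, 2)       # batch, num_heads, seq_len, head_dim -- a non-contiguous view
print(q_t.is_contiguous())    # False: storage no longer matches the viewed layout
q_t = q_t.contiguous()        # materialize the permuted layout, analogous to the q_t/k_t calls above
print(q_t.is_contiguous())    # True
```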

benchmarks/transformer/sdp.py

Lines changed: 56 additions & 34 deletions

```diff
@@ -3,8 +3,10 @@
 import numpy as np
 import sys
 import csv
+import random
 
-
+import warnings
+warnings.filterwarnings("ignore")
 class CompositeMHA(torch.nn.Module):
     def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
         super().__init__()
@@ -25,14 +27,15 @@ def forward(self, query, key, value, mask):
             query, self.in_proj_weight, self.in_proj_bias
         )
 
-        batch_size, seq_len, embed_dim = query_projected.size()
+        batch_size = query_projected.size(0)
+        embed_dim = query_projected.size(2)
         head_dim = embed_dim // (self.num_heads * 3)
 
-        # Transpose seq_len and num_heads dim
-        query_projected = query_projected.view(
-            batch_size, seq_len, 3 * self.num_heads, head_dim
-        ).transpose(1, 2)
-        query, key, value = query_projected.chunk(3, 1)
+        query, key, value = query_projected.chunk(3, -1)
+
+        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         attn, _ = torch.nn.functional._scaled_dot_product_attention(
@@ -46,7 +49,7 @@ def forward(self, query, key, value, mask):
         )
 
         attn = attn.transpose(1, 2).reshape(
-            batch_size, seq_len, self.num_heads * head_dim
+            batch_size, -1, self.num_heads * head_dim
         )
         # Match return signature of nn.MHA
         return self.out_proj(attn), None
@@ -60,6 +63,18 @@ def build_composite_mha_from_nn_mha(pt):
     return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj)
 
 
+def generate_rand_batch(batch_size, max_sequence_len, embed_dimension, pad_percentage=None, dtype=torch.float16, device="cuda"):
+    if not pad_percentage:
+        return torch.randn(batch_size, max_sequence_len, embed_dimension, dtype=dtype, device=device), None
+    # Really slow but should work
+    seq_len_list = [int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01))) for _ in range(batch_size)]
+    # Make random ele max length
+    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
+    # print(f"Theoretical padding: {pad_percentage} actual: {1 - (sum(seq_len_list) / (batch_size * max_sequence_len))}")
+    return torch.nested.nested_tensor([
+        torch.randn(seq_len, embed_dimension, dtype=dtype, device=device) for seq_len in seq_len_list]), seq_len_list
+
+
 def benchmark_torch_function(iters, f, *args, **kwargs):
     if f is None:
         return None
@@ -75,50 +90,57 @@ def benchmark_torch_function(iters, f, *args, **kwargs):
     return (start_event.elapsed_time(end_event) * 1.0e-3) / iters
 
 
-def run_timing(batch_size, D, H, L, writer):
-    dropout_p = 0.0
-    mask = None
-
-    pt = torch.nn.MultiheadAttention(
-        embed_dim=D, num_heads=H, batch_first=True, dropout=dropout_p
-    )
-    npt = pt.eval().half().cuda()
-    cpt = build_composite_mha_from_nn_mha(npt)
-
-    x = torch.randn(batch_size, L, D)
-    x = x.half().cuda()
-
-    pt_output, _ = pt(x, x, x, mask)
-    cp_output, _ = cpt(x, x, x, mask)
+def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len, pad_percentage, writer):
+    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True):
+        with torch.inference_mode():
+            dropout_p = 0.0
+            mask = None
 
-    # First order sanity check. Not a replacement for rigorous tests.
-    assert torch.allclose(pt_output, cp_output, atol=1e-3, rtol=1e-3)
+            pt = torch.nn.MultiheadAttention(
+                embed_dim=embed_dimension, num_heads=num_heads, batch_first=True, dropout=dropout_p
+            )
+            npt = pt.eval().half().cuda()
+            cpt = build_composite_mha_from_nn_mha(npt)
+            x, lengths = generate_rand_batch(batch_size, max_sequence_len, embed_dimension, pad_percentage)
+            pt_output, _ = pt(x, x, x, mask)
+            cpt_output, _ = cpt(x, x, x, mask)
+
+            # First order sanity check. Not a replacement for rigorous tests.
+            if pt_output.is_nested and cpt_output.is_nested:
+                for a, b in zip(pt_output.unbind(), cpt_output.unbind()):
+                    assert torch.allclose(a, b, atol=1e-3, rtol=1e-3)
+            else:
+                assert torch.allclose(pt_output, cpt_output, atol=1e-3, rtol=1e-3)
 
-    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=True):
-        with torch.inference_mode():
             pt_time = benchmark_torch_function(iters, npt, x, x, x, mask) * 1e3
             cp_time = benchmark_torch_function(iters, cpt, x, x, x, mask) * 1e3
             results = {}
-            results["L"] = L
-            results["H"] = H
-            results["D"] = D
+            results["max_sequence_len"] = max_sequence_len
+            results["num_heads"] = num_heads
+            results["embed_dimension"] = embed_dimension
            results["pt_time"] = pt_time
            results["cp_time"] = cp_time
            results["speedup"] = pt_time / cp_time
            results["dtype"] = str(x.dtype)
            writer.writerow(results)
 
 
-if __name__ == "__main__":
+def main():
     iters = 100
     seed = 123
     np.random.seed(seed)
     torch.manual_seed(seed)
 
-    headers = ["L", "H", "D", "pt_time", "cp_time", "speedup", "dtype"]
+    headers = ["max_sequence_len", "num_heads", "embed_dimension", "pt_time", "cp_time", "speedup", "dtype"]
     writer = csv.DictWriter(sys.stdout, headers)
     writer.writeheader()
 
     batch_size = 64
-    for H, L in itertools.product([1, 2, 4, 8, 16, 32], [64, 128, 256]):
-        run_timing(batch_size, 1024, H, L, writer)
+    pad_percentage = 0.5
+
+    for num_heads, max_seq_len in itertools.product([2, 4, 8, 16, 32], [64, 128, 256]):
+        run_timing(iters, batch_size, 1024, num_heads, max_seq_len, pad_percentage, writer)
+
+
+if __name__ == "__main__":
+    main()
```
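For anyone poking at the new helper outside the benchmark loop, something like the following could be run on a CUDA machine (generate_rand_batch defaults to dtype=torch.float16 and device="cuda"); the values are illustrative only:

```python
# Assumes generate_rand_batch from benchmarks/transformer/sdp.py is defined/importable.
# With pad_percentage set, it returns a NestedTensor plus the per-sequence lengths.
nt, lengths = generate_rand_batch(
    batch_size=8, max_sequence_len=128, embed_dimension=64, pad_percentage=0.5
)
print(nt.is_nested)  # True: one (seq_len, 64) tensor per sequence
print(lengths)       # sequence lengths; one entry is forced to max_sequence_len

# Without pad_percentage, it falls back to a dense (batch, seq, embed) tensor and None.
dense, no_lengths = generate_rand_batch(8, 128, 64)
print(dense.shape, no_lengths)  # torch.Size([8, 128, 64]) None
```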
