# Owner(s): ["module: sdpa"]
import os
import unittest
from collections import namedtuple
from functools import partial

import pytorch_openreg  # noqa: F401

import torch
import torch.utils.cpp_extension
from torch.nn.attention import SDPBackend
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    run_tests,
    skipIfTorchDynamo,
    TEST_XPU,
)

SdpaShape = namedtuple("Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"])
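

# Verifies that the fused scaled_dot_product_attention path can be overridden by a
# PrivateUse1 backend: the tests route CPU tensors to the "openreg" device and
# exercise the OVERRIDEABLE SDPBackend provided by the open-registration extension.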
@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
@unittest.skipIf(
    IS_FBCODE,
    "Ninja is required to load C++ extensions and it's not compatible with Buck ",
)
class TestSDPAPrivateUse1Only(NNTestCase):
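    # setUpClass builds and loads the open-registration C++ extension once so that
    # the "openreg" PrivateUse1 device is available to every test in this class.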
    @classmethod
    def setUpClass(cls):
        torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
        cls.module = torch.utils.cpp_extension.load(
            name="custom_device_extension",
            sources=[
                f"{'test/' if not os.getcwd().endswith('test') else ''}cpp_extensions/open_registration_extension.cpp",
            ],
            extra_include_paths=["cpp_extensions"],
            extra_cflags=["-g"],
            verbose=True,
        )
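
    # _fused_sdp_choice should report the OVERRIDEABLE backend for tensors that
    # live on the openreg (PrivateUse1) device.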
    @skipIfTorchDynamo()
    def test_fused_sdp_choice_privateuseone(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        assert (
            torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1)
            == SDPBackend.OVERRIDEABLE.value
        )
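
    # scaled_dot_product_attention on openreg tensors should dispatch to the
    # overrideable fused kernel without raising.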
    def test_scaled_dot_product_fused_attention_overrideable(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        torch.nn.functional.scaled_dot_product_attention(
            q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0
        )
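
    # Runs the overrideable fused attention forward and then its backward op on
    # openreg tensors, requesting gradients for q, k, v, and the attention mask.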
    def test_scaled_dot_product_fused_attention_overrideable_backward(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(
            torch.rand, device="cpu", dtype=torch.float16, requires_grad=True
        )
        shape = (batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
        q_privateuse1 = q_cpu.to("openreg")
        k_privateuse1 = k_cpu.to("openreg")
        v_privateuse1 = v_cpu.to("openreg")
        attn_mask_privateuse1 = attn_mask.to("openreg")
        # Forward pass returns the attention output plus the bookkeeping tensors
        # (logsumexp, cumulative sequence lengths, max lengths, philox RNG state)
        # that the backward op consumes.
        (
            output,
            logsumexp,
            cum_seq_q,
            cum_seq_k,
            max_q,
            max_k,
            philox_seed,
            philox_offset,
            debug_attn_mask,
        ) = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
            q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1
        )
        rand_upward = torch.rand(
            shape, device="cpu", dtype=torch.float16, requires_grad=False
        )
        rand_upward_privateuse1 = rand_upward.to("openreg")
        # Request gradients for all four inputs: q, k, v, and the attention mask.
        grad_input_mask = [True, True, True, True]
        grad_q, grad_k, grad_v, grad_attn_mask = (
            torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
                rand_upward_privateuse1,
                q_privateuse1,
                k_privateuse1,
                v_privateuse1,
                attn_mask_privateuse1,
                grad_input_mask,
                output,
                logsumexp,
                cum_seq_q,
                cum_seq_k,
                max_q,
                max_k,
                dropout_p=0.0,
                is_causal=False,
                philox_seed=philox_seed,
                philox_offset=philox_offset,
            )
        )


if __name__ == "__main__":
    run_tests()