
Commit c108e35

q10 authored and facebook-github-bot committed
Re-organize SLL ops, pt 8 (#3663)
Summary:
X-link: facebookresearch/FBGEMM#738
Pull Request resolved: #3663

- Re-organize the remaining SLL triton ops

Differential Revision: D68970862
1 parent 2cef43a commit c108e35

18 files changed (+206, -197 lines)

.github/scripts/fbgemm_gpu_test.bash

Lines changed: 0 additions & 12 deletions
@@ -83,18 +83,6 @@ __configure_fbgemm_gpu_test_cpu () {
     # These tests have non-CPU operators referenced in @given
     ./uvm/copy_test.py
     ./uvm/uvm_test.py
-    ./sll/triton_sll_test.py
-    ./sll/array_jagged_bmm_jagged_out_test.py
-    ./sll/jagged_dense_elementwise_add_test.py
-    ./sll/jagged_flash_attention_basic_test.py
-    ./sll/jagged_jagged_bmm_jagged_out_test.py
-    ./sll/jagged_dense_flash_attention_test.py
-    ./sll/jagged_dense_bmm_test.py
-    ./sll/jagged_dense_elementwise_mul_jagged_out_test.py
-    ./sll/jagged_jagged_bmm_test.py
-    ./sll/jagged_softmax_test.py
-    ./sll/jagged2_to_padded_dense_test.py
-    ./sll/multi_head_jagged_flash_attention_test.py
   )
 }

fbgemm_gpu/fbgemm_gpu/sll/__init__.py

Lines changed: 1 addition & 20 deletions
@@ -33,11 +33,6 @@
     meta_jagged_self_substraction_jagged_out,
 )

-from fbgemm_gpu.sll.triton_sll import ( # noqa F401
-    jagged_dense_elementwise_mul_jagged_out,
-    triton_jagged_self_substraction_jagged_out,
-)
-
 from fbgemm_gpu.utils import TorchLibraryFragment

 lib = TorchLibraryFragment("fbgemm")
@@ -262,25 +257,11 @@
     },
 }

-# pyre-ignore[5]
-sll_gpu_registrations = {
-    "sll_jagged_self_substraction_jagged_out": {
-        "CUDA": triton_jagged_self_substraction_jagged_out,
-    },
-    "sll_jagged_dense_elementwise_mul_jagged_out": {
-        "CUDA": jagged_dense_elementwise_mul_jagged_out,
-        "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
-    },
-}
-
 for op_name, dispatches in sll_cpu_registrations.items():
     lib.register(op_name, dispatches)

 if torch.cuda.is_available():
-    from fbgemm_gpu.sll.triton import op_registrations
-
-    for op_name, dispatches in op_registrations.items():
-        lib.register(op_name, dispatches)
+    from fbgemm_gpu.sll.triton import op_registrations as sll_gpu_registrations

     for op_name, dispatches in sll_gpu_registrations.items():
         lib.register(op_name, dispatches)

fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -37,6 +37,11 @@
     JaggedDenseAdd, # noqa F401
 )

+from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_mul_jagged_out import ( # noqa F401
+    jagged_dense_elementwise_mul_jagged_out,
+    JaggedDenseElementwiseMul, # noqa F401
+)
+
 from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import ( # noqa F401
     jagged_dense_flash_attention,
     JaggedDenseFlashAttention, # noqa F401
@@ -47,6 +52,10 @@
     JaggedFlashAttentionBasic, # noqa F401
 )

+from fbgemm_gpu.sll.triton.triton_jagged_self_substraction_jagged_out import (
+    triton_jagged_self_substraction_jagged_out,
+)
+
 from fbgemm_gpu.sll.triton.triton_jagged_softmax import ( # noqa F401
     jagged2_softmax,
     Jagged2Softmax, # noqa F401
@@ -108,4 +117,11 @@
         "CUDA": multi_head_jagged_flash_attention,
         "AutogradCUDA": multi_head_jagged_flash_attention,
     },
+    "sll_jagged_self_substraction_jagged_out": {
+        "CUDA": triton_jagged_self_substraction_jagged_out,
+    },
+    "sll_jagged_dense_elementwise_mul_jagged_out": {
+        "CUDA": jagged_dense_elementwise_mul_jagged_out,
+        "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
+    },
 }
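
For orientation (not part of the diff): the op names added to op_registrations above are the keys under which sll/__init__.py registers the Triton implementations on the shared "fbgemm" library fragment. Below is a minimal usage sketch, assuming the registered schema mirrors the Python wrapper triton_jagged_self_substraction_jagged_out(jagged_A, offsets_a, offsets_b, max_seq_len) and that the op is exposed under torch.ops.fbgemm as elsewhere in the SLL suite; the tensor setup is illustrative only.

    import torch
    import fbgemm_gpu.sll  # noqa: F401  # importing triggers the lib.register(...) calls

    device = "cuda"
    # Two jagged sequences of lengths 4 and 3, flattened into a single 1D tensor.
    lengths = torch.tensor([4, 3], device=device)
    offsets_a = torch.cat(
        [torch.zeros(1, dtype=torch.int64, device=device), lengths.cumsum(0)]
    )
    jagged_A = torch.randn(int(lengths.sum()), device=device)

    # Each sequence of length L (clamped to max_seq_len + 1) emits (L - 1) ** 2 values.
    max_seq_len = 8
    out_lengths = (torch.clamp(lengths, max=max_seq_len + 1) - 1) ** 2
    offsets_b = torch.cat(
        [torch.zeros(1, dtype=torch.int64, device=device), out_lengths.cumsum(0)]
    )

    # Assumed call path; the op name matches the registration key above.
    out = torch.ops.fbgemm.sll_jagged_self_substraction_jagged_out(
        jagged_A, offsets_a, offsets_b, max_seq_len
    )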

fbgemm_gpu/fbgemm_gpu/sll/triton/common.py

Lines changed: 22 additions & 0 deletions
@@ -9,6 +9,28 @@
 import torch


+def next_power_of_two(N: int) -> int:
+    if N > 4096:
+        raise Exception(f"{N} is too large that is not supported yet")
+
+    if N > 2048:
+        return 4096
+    elif N > 1024:
+        return 2048
+    elif N > 512:
+        return 1024
+    elif N > 256:
+        return 512
+    elif N > 128:
+        return 256
+    elif N > 64:
+        return 128
+    elif N > 32:
+        return 64
+    else:
+        return 32
+
+
 def expect_contiguous(x: torch.Tensor) -> torch.Tensor:
     if not x.is_contiguous():
         return x.contiguous()
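
As a reading aid (not part of the commit), a quick self-contained check of the relocated helper's behavior; the import path assumes the package layout introduced by this diff.

    from fbgemm_gpu.sll.triton.common import next_power_of_two

    # Inputs of 32 or less round up to the minimum supported block size of 32;
    # anything above 4096 raises.
    assert next_power_of_two(1) == 32
    assert next_power_of_two(33) == 64
    assert next_power_of_two(1025) == 2048
    assert next_power_of_two(4096) == 4096

    # Kernel launchers in this diff derive their Triton block size from it, e.g.
    # BLOCK_SIZE = max(next_power_of_two(max_seq_len), 16) in
    # triton_jagged_self_substraction_jagged_out.py below.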

fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py renamed to fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py

Lines changed: 0 additions & 82 deletions
@@ -11,61 +11,6 @@
 import triton.language as tl


-def next_power_of_two(N: int) -> int:
-    if N > 4096:
-        raise Exception(f"{N} is too large that is not supported yet")
-
-    if N > 2048:
-        return 4096
-    elif N > 1024:
-        return 2048
-    elif N > 512:
-        return 1024
-    elif N > 256:
-        return 512
-    elif N > 128:
-        return 256
-    elif N > 64:
-        return 128
-    elif N > 32:
-        return 64
-    else:
-        return 32
-
-
-@triton.jit
-def jagged_self_substraction_jagged_out_kernel(
-    a_ptr, # jagged
-    b_ptr, # jagged
-    a_offsets_ptr,
-    b_offsets_ptr,
-    max_seq_len,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid_batch = tl.program_id(0)
-    pid_index = tl.program_id(1)
-
-    a_offset = tl.load(a_offsets_ptr + pid_batch)
-    a_length = tl.load(a_offsets_ptr + pid_batch + 1) - a_offset
-    a_length = tl.minimum(a_length, max_seq_len + 1)
-
-    if a_length <= 1:
-        return
-
-    N = a_length - 1
-    if pid_index >= N:
-        return
-
-    a_cur = tl.load(a_ptr + a_offset + pid_index)
-    offs = tl.arange(0, BLOCK_SIZE)
-    mask = offs < N
-    a_row = tl.load(a_ptr + a_offset + offs + 1, mask=mask)
-    b = a_cur - a_row
-
-    b_offset = tl.load(b_offsets_ptr + pid_batch)
-    tl.store(b_ptr + b_offset + pid_index * N + offs, b, mask=mask)
-
-
 @triton.jit
 def jagged_dense_elementwise_mul_jagged_out_kernel(
     a_ptr, # 1d jagged
@@ -123,33 +68,6 @@ def jagged_dense_elementwise_mul_jagged_out_kernel
         c_ptrs += BLOCK_N


-def triton_jagged_self_substraction_jagged_out(
-    jagged_A: torch.Tensor,
-    offsets_a: torch.Tensor,
-    offsets_b: torch.Tensor,
-    max_seq_len,
-) -> torch.Tensor:
-    B = offsets_a.size(0) - 1
-
-    jagged_B = torch.empty(
-        (int(offsets_b[-1].item())), device=jagged_A.device, dtype=jagged_A.dtype
-    )
-
-    BLOCK_SIZE = max(next_power_of_two(max_seq_len), 16)
-    grid = (B, max_seq_len)
-
-    jagged_self_substraction_jagged_out_kernel[grid](
-        jagged_A,
-        jagged_B,
-        offsets_a,
-        offsets_b,
-        max_seq_len,
-        BLOCK_SIZE, # pyre-fixme[6]: For 6th argument expected `constexpr` but got `int`.
-    )
-
-    return jagged_B
-
-
 def triton_jagged_dense_elementwise_mul_jagged_out(
     jagged_A,
     dense_B,

fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+import triton
+import triton.language as tl
+
+from .common import next_power_of_two
+
+
+@triton.jit
+def jagged_self_substraction_jagged_out_kernel(
+    a_ptr, # jagged
+    b_ptr, # jagged
+    a_offsets_ptr,
+    b_offsets_ptr,
+    max_seq_len,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_batch = tl.program_id(0)
+    pid_index = tl.program_id(1)
+
+    a_offset = tl.load(a_offsets_ptr + pid_batch)
+    a_length = tl.load(a_offsets_ptr + pid_batch + 1) - a_offset
+    a_length = tl.minimum(a_length, max_seq_len + 1)
+
+    if a_length <= 1:
+        return
+
+    N = a_length - 1
+    if pid_index >= N:
+        return
+
+    a_cur = tl.load(a_ptr + a_offset + pid_index)
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < N
+    a_row = tl.load(a_ptr + a_offset + offs + 1, mask=mask)
+    b = a_cur - a_row
+
+    b_offset = tl.load(b_offsets_ptr + pid_batch)
+    tl.store(b_ptr + b_offset + pid_index * N + offs, b, mask=mask)
+
+
+def triton_jagged_self_substraction_jagged_out(
+    jagged_A: torch.Tensor,
+    offsets_a: torch.Tensor,
+    offsets_b: torch.Tensor,
+    max_seq_len,
+) -> torch.Tensor:
+    B = offsets_a.size(0) - 1
+
+    jagged_B = torch.empty(
+        (int(offsets_b[-1].item())), device=jagged_A.device, dtype=jagged_A.dtype
+    )
+
+    BLOCK_SIZE = max(next_power_of_two(max_seq_len), 16)
+    grid = (B, max_seq_len)
+
+    jagged_self_substraction_jagged_out_kernel[grid](
+        jagged_A,
+        jagged_B,
+        offsets_a,
+        offsets_b,
+        max_seq_len,
+        BLOCK_SIZE,
+    )
+
+    return jagged_B
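
For readers checking the relocated kernel's semantics, here is a plain-PyTorch sketch (not part of the commit, and the helper name below is hypothetical) of what triton_jagged_self_substraction_jagged_out computes: for each sequence a, clamped to max_seq_len + 1 elements with N = len(a) - 1, it writes the N x N matrix of differences a[i] - a[j + 1], flattened row-major at the sequence's output offset.

    import torch


    def jagged_self_subtraction_reference(
        jagged_A: torch.Tensor,
        offsets_a: torch.Tensor,
        offsets_b: torch.Tensor,
        max_seq_len: int,
    ) -> torch.Tensor:
        # Eager reference for the Triton kernel above; offsets_b is assumed to be
        # sized for N * N outputs per sequence.
        out = torch.empty(
            int(offsets_b[-1].item()), device=jagged_A.device, dtype=jagged_A.dtype
        )
        for b in range(offsets_a.numel() - 1):
            start, end = int(offsets_a[b]), int(offsets_a[b + 1])
            a = jagged_A[start:end][: max_seq_len + 1]
            N = a.numel() - 1
            if N <= 0:
                continue
            # diff[i, j] = a[i] - a[j + 1]
            diff = a[:-1].unsqueeze(1) - a[1:].unsqueeze(0)
            o = int(offsets_b[b])
            out[o : o + N * N] = diff.reshape(-1)
        return out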

fbgemm_gpu/test/sll/array_jagged_bmm_jagged_out_test.py

Lines changed: 3 additions & 3 deletions
@@ -31,7 +31,7 @@ class ArrayJaggedBmmJaggedTest(unittest.TestCase):
     )
     @unittest.skipIf(*gpu_unavailable)
     @unittest.skipIf(*running_on_rocm)
-    @settings(deadline=20000)
+    @settings(deadline=30000)
     def test_triton_array_jagged_bmm_jagged_out(
         self,
         B: int,
@@ -157,7 +157,7 @@ def ref_array_jagged_bmm_jagged_out(
     )
     @unittest.skipIf(*gpu_unavailable)
     @unittest.skipIf(*running_on_rocm)
-    @settings(deadline=20000)
+    @settings(deadline=30000)
     def test_triton_array_jagged_bmm_jagged_out_with_grad(
         self,
         B: int,
@@ -244,7 +244,7 @@ def test_triton_array_jagged_bmm_jagged_out_with_grad(
     )
     @unittest.skipIf(*gpu_unavailable)
     @unittest.skipIf(*running_on_rocm)
-    @settings(deadline=20000)
+    @settings(deadline=30000)
     def test_triton_array_jagged_bmm_jagged_out_meta_backend(
         self,
         B: int,

fbgemm_gpu/test/sll/common.py

Lines changed: 1 addition & 2 deletions
@@ -8,8 +8,7 @@
 # pyre-ignore-all-errors[56]

 import fbgemm_gpu
-import fbgemm_gpu.sll.cpu_sll
-import fbgemm_gpu.sll.triton_sll
+import fbgemm_gpu.sll
 import torch

 # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
