 import triton
 import triton.language as tl
 
+from torchao.prototype.moe_training.utils import (
+    _is_column_major,
+    _is_row_major,
+)
+
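+# Sweep square M/N tile sizes, warp counts, and pipeline stages: 3 * 2 * 2 = 12 autotune configs.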
 fp8_gemm_configs_max_autotune = [
-    # Small
-    triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64}, num_warps=2),
-    # Medium
-    triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256}, num_warps=8),
-    # Large
-    triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_warps=8),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_warps=8),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_warps=8),
+    triton.Config(
+        {"BLOCK_SIZE_M": block_size, "BLOCK_SIZE_N": block_size},
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    for block_size in [64, 128, 256]
+    for num_warps in [4, 8]
+    for num_stages in [2, 4]
 ]
 
 # For fast compile times during development.
@@ -57,6 +57,7 @@ def blockwise_fp8_gemm_1x128_128x128_kernel(
     M,
     N: tl.constexpr,
     K: tl.constexpr,
+    out_dtype: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
@@ -81,18 +82,16 @@ def blockwise_fp8_gemm_1x128_128x128_kernel(
     a_s_base_ptr = a_s_ptr + offs_m * a_s_stride_dim_0
     b_s_base_ptr = b_s_ptr + (offs_n // BLOCK_SIZE_K) * b_s_stride_dim_1
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
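+    # The bounds masks do not depend on the loop variable, so they are computed once before the K loop.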
+    a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+    b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
     for k in range(0, k_num_blocks):
-        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
         a = tl.load(a_ptrs, mask=a_mask, other=0.0)
-
-        b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
         b = tl.load(b_ptrs, mask=b_mask, other=0.0)
 
         # Reciprocal scales to scale back to dynamic range of output dtype
         a_s = tl.load(a_s_base_ptr + k * a_s_stride_dim_1)
         b_s = tl.load(b_s_base_ptr + k * b_s_stride_dim_0)
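+        # b_s is a 1D vector over the N tile, so it broadcasts across the (BLOCK_SIZE_M, BLOCK_SIZE_N) accumulator without an explicit [None, :].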
-
-        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s
 
         a_ptrs += BLOCK_SIZE_K * a_stride_dim_1
         b_ptrs += BLOCK_SIZE_K * b_stride_dim_0
@@ -109,14 +108,22 @@ def blockwise_fp8_gemm_1x128_128x128(
     b: torch.Tensor,  # (K, N)
     b_s: torch.Tensor,  # (K // block_size, N // block_size)
     block_size: int = 128,
+    out_dtype: torch.dtype = torch.float32,
 ):
     # 'a' must be in row-major layout, 'b' must be in column-major layout
-    assert a.is_contiguous() and not b.is_contiguous()
-    assert a_s.is_contiguous() and b_s.is_contiguous()
+    assert _is_row_major(a) and _is_column_major(b), (
+        "a must be row-major, b must be column-major"
+    )
+
+    # a_scales must be row-major, b_scales must be column-major
+    assert _is_row_major(a_s) and _is_column_major(b_s), (
+        "a_s must be row-major, b_s must be column-major"
+    )
+
     M = a.size(0)
     K = a.size(1)
     N = b.size(1)
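+    # Output buffer in the caller-specified out_dtype (fp32 by default).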
-    c = a.new_empty(M, N, dtype=torch.bfloat16)
+    c = a.new_empty(M, N, dtype=out_dtype)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -140,6 +147,7 @@ def blockwise_fp8_gemm_1x128_128x128(
         M,
         N,
         K,
+        out_dtype=out_dtype,
         BLOCK_SIZE_K=block_size,
     )
     return c
@@ -217,14 +225,15 @@ def blockwise_fp8_gemm_1x128_128x1(
     b: torch.Tensor,  # (K, N)
     b_s: torch.Tensor,  # (K // block_size, N) reciprocals of scales
     block_size: int = 128,
+    out_dtype: torch.dtype = torch.float32,
 ):
     # 'a' must be in row-major layout, 'b' must be in column-major layout
     assert a.is_contiguous() and not b.is_contiguous()
     assert a_s.is_contiguous() and b_s.is_contiguous()
     M = a.size(0)
     K = a.size(1)
     N = b.size(1)
-    c = a.new_empty(M, N, dtype=torch.bfloat16)
+    c = a.new_empty(M, N, dtype=out_dtype)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -674,8 +683,10 @@ def fp8_blockwise_weight_quant_transposed_rhs(
     M, N = x.size()
     y = torch.empty(N, M, dtype=dtype, device=x.device)
     y = y.as_strided(y.size(), (1, y.size(0)))  # Column major
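+    # Lay out the scale tensor column-major as well, matching the transposed (column-major) output.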
-    s = x.new_empty(
-        triton.cdiv(N, block_size), triton.cdiv(M, block_size), dtype=torch.float32
+    n_blocks, m_blocks = triton.cdiv(N, block_size), triton.cdiv(M, block_size)
+    s = x.new_empty(n_blocks, m_blocks, dtype=torch.float32).as_strided(
+        (n_blocks, m_blocks),  # shape
+        (1, n_blocks),  # stride
     )
     grid = lambda meta: (
         triton.cdiv(M, meta["BLOCK_SIZE"]),