Commit f1a6527

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent: ce754de

1 file changed: transformer_engine/common/gemm/cublaslt_gemm.cu (+82 -81 lines)
@@ -51,94 +51,95 @@ uint32_t _getAlignment(uintptr_t address) {

inline void CreateCublasHandle(cublasLtHandle_t *handle) {
  NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}
struct GemmParam {
  void *A;
  void *B;
  cublasOperation_t transA;
  cublasOperation_t transB;
  transformer_engine::DType Atype;
  transformer_engine::DType Btype;
  void *A_scale_inv;
  void *B_scale_inv;
  int lda;
  int ldb;

  GemmParam(cublasOperation_t transA, cublasOperation_t transB)
      : A(nullptr),
        B(nullptr),
        transA(transA),
        transB(transB),
        Atype(transformer_engine::DType::kNumTypes),
        Btype(transformer_engine::DType::kNumTypes),
        A_scale_inv(nullptr),
        B_scale_inv(nullptr),
        lda(0),
        ldb(0) {}
};

GemmParam CanonicalizeGemmInput(
    const transformer_engine::Tensor &A, const cublasOperation_t transA,
    const transformer_engine::Tensor &B, const cublasOperation_t transB, const int k,
    const int lda, const int ldb) {
  using namespace transformer_engine;
  NVTE_CHECK(A.scaling_mode == B.scaling_mode,
             "Inputs A and B to GEMM need to have the same scaling mode!");
  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
  GemmParam ret(transA, transB);

  ret.lda = lda;
  ret.ldb = ldb;

  if (is_tensor_scaling(A.scaling_mode)) {
    ret.A = A.data.dptr;
    ret.A_scale_inv = A.scale_inv.dptr;
    if (transA == CUBLAS_OP_T) {
      ret.Atype = A.data.dtype;
    } else {
      ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype;
      if (is_fp8_dtype(ret.Atype)) {
        int arch = cuda::sm_arch(cuda::current_device());
        if (arch < 100) {
          // Hopper and Ada - we need to use columnwise_data and change transA
          NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!");
          ret.A = A.columnwise_data.dptr;
          ret.transA = CUBLAS_OP_T;
          ret.A_scale_inv = A.columnwise_scale_inv.dptr;
          ret.lda = k;
        }
      }
    }
    ret.B = B.data.dptr;
    ret.B_scale_inv = B.scale_inv.dptr;
    if (transB == CUBLAS_OP_T) {
      ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype;
      if (is_fp8_dtype(ret.Btype)) {
        int arch = cuda::sm_arch(cuda::current_device());
        if (arch < 100) {
          // Hopper and Ada - we need to use columnwise_data and change transA
          NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!");
          ret.B = B.columnwise_data.dptr;
          ret.transB = CUBLAS_OP_N;
          ret.B_scale_inv = B.columnwise_scale_inv.dptr;
          ret.ldb = k;
        }
      }
    } else {
      ret.Btype = B.data.dtype;
    }
  } else {
    // If not tensor scaling (which includes also high precision types), we need to
    // use the proper version of data
    // We leave the transA/B values as is, since Blackwell supports transposes
    ret.A = transA ? A.data.dptr : A.columnwise_data.dptr;
    ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype;
    ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
    ret.B = transB ? B.columnwise_data.dptr : B.data.dptr;
    ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype;
    ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
  }
  return ret;
}

}  // namespace