@@ -51,94 +51,95 @@ uint32_t _getAlignment(uintptr_t address) {
 inline void CreateCublasHandle(cublasLtHandle_t *handle) {
   NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
 }
-struct GemmParam {
-  void *A;
-  void *B;
-  cublasOperation_t transA;
-  cublasOperation_t transB;
-  transformer_engine::DType Atype;
-  transformer_engine::DType Btype;
-  void *A_scale_inv;
-  void *B_scale_inv;
-  int lda;
-  int ldb;
-
-  GemmParam(cublasOperation_t transA, cublasOperation_t transB)
-      : A(nullptr),
-        B(nullptr),
-        transA(transA),
-        transB(transB),
-        Atype(transformer_engine::DType::kNumTypes),
-        Btype(transformer_engine::DType::kNumTypes),
-        A_scale_inv(nullptr),
-        B_scale_inv(nullptr),
-        lda(0),
-        ldb(0) {}
-};
-
-GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
-                                const transformer_engine::Tensor &B, const cublasOperation_t transB,
-                                const int k, const int lda, const int ldb) {
-  using namespace transformer_engine;
-  NVTE_CHECK(A.scaling_mode == B.scaling_mode,
-             "Inputs A and B to GEMM need to have the same scaling mode!");
-  NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
-  NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
-  GemmParam ret(transA, transB);
-
-  ret.lda = lda;
-  ret.ldb = ldb;
-
-  if (is_tensor_scaling(A.scaling_mode)) {
-    ret.A = A.data.dptr;
-    ret.A_scale_inv = A.scale_inv.dptr;
-    if (transA == CUBLAS_OP_T) {
-      ret.Atype = A.data.dtype;
-    } else {
-      ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype;
-      if (is_fp8_dtype(ret.Atype)) {
-        int arch = cuda::sm_arch(cuda::current_device());
-        if (arch < 100) {
-          // Hopper and Ada - we need to use columnwise_data and change transA
-          NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!");
-          ret.A = A.columnwise_data.dptr;
-          ret.transA = CUBLAS_OP_T;
-          ret.A_scale_inv = A.columnwise_scale_inv.dptr;
-          ret.lda = k;
+  struct GemmParam {
+    void *A;
+    void *B;
+    cublasOperation_t transA;
+    cublasOperation_t transB;
+    transformer_engine::DType Atype;
+    transformer_engine::DType Btype;
+    void *A_scale_inv;
+    void *B_scale_inv;
+    int lda;
+    int ldb;
+
+    GemmParam(cublasOperation_t transA, cublasOperation_t transB)
+        : A(nullptr),
+          B(nullptr),
+          transA(transA),
+          transB(transB),
+          Atype(transformer_engine::DType::kNumTypes),
+          Btype(transformer_engine::DType::kNumTypes),
+          A_scale_inv(nullptr),
+          B_scale_inv(nullptr),
+          lda(0),
+          ldb(0) {}
+  };
+
+  GemmParam CanonicalizeGemmInput(
+      const transformer_engine::Tensor &A, const cublasOperation_t transA,
+      const transformer_engine::Tensor &B, const cublasOperation_t transB, const int k,
+      const int lda, const int ldb) {
+    using namespace transformer_engine;
+    NVTE_CHECK(A.scaling_mode == B.scaling_mode,
+               "Inputs A and B to GEMM need to have the same scaling mode!");
+    NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
+    NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
+    GemmParam ret(transA, transB);
+
+    ret.lda = lda;
+    ret.ldb = ldb;
+
+    if (is_tensor_scaling(A.scaling_mode)) {
+      ret.A = A.data.dptr;
+      ret.A_scale_inv = A.scale_inv.dptr;
+      if (transA == CUBLAS_OP_T) {
+        ret.Atype = A.data.dtype;
+      } else {
+        ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype;
+        if (is_fp8_dtype(ret.Atype)) {
+          int arch = cuda::sm_arch(cuda::current_device());
+          if (arch < 100) {
+            // Hopper and Ada - we need to use columnwise_data and change transA
+            NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!");
+            ret.A = A.columnwise_data.dptr;
+            ret.transA = CUBLAS_OP_T;
+            ret.A_scale_inv = A.columnwise_scale_inv.dptr;
+            ret.lda = k;
+          }
         }
       }
-    }
-    ret.B = B.data.dptr;
-    ret.B_scale_inv = B.scale_inv.dptr;
-    if (transB == CUBLAS_OP_T) {
-      ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype;
-      if (is_fp8_dtype(ret.Btype)) {
-        int arch = cuda::sm_arch(cuda::current_device());
-        if (arch < 100) {
-          // Hopper and Ada - we need to use columnwise_data and change transA
-          NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!");
-          ret.B = B.columnwise_data.dptr;
-          ret.transB = CUBLAS_OP_N;
-          ret.B_scale_inv = B.columnwise_scale_inv.dptr;
-          ret.ldb = k;
+      ret.B = B.data.dptr;
+      ret.B_scale_inv = B.scale_inv.dptr;
+      if (transB == CUBLAS_OP_T) {
+        ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype;
+        if (is_fp8_dtype(ret.Btype)) {
+          int arch = cuda::sm_arch(cuda::current_device());
+          if (arch < 100) {
+            // Hopper and Ada - we need to use columnwise_data and change transA
+            NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!");
+            ret.B = B.columnwise_data.dptr;
+            ret.transB = CUBLAS_OP_N;
+            ret.B_scale_inv = B.columnwise_scale_inv.dptr;
+            ret.ldb = k;
+          }
         }
+      } else {
+        ret.Btype = B.data.dtype;
       }
     } else {
-      ret.Btype = B.data.dtype;
+      // If not tensor scaling (which includes also high precision types), we need to
+      // use the proper version of data
+      // We leave the transA/B values as is, since Blackwell supports transposes
+      ret.A = transA ? A.data.dptr : A.columnwise_data.dptr;
+      ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype;
+      ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
+      ret.B = transB ? B.columnwise_data.dptr : B.data.dptr;
+      ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype;
+      ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
     }
-  } else {
-    // If not tensor scaling (which includes also high precision types), we need to
-    // use the proper version of data
-    // We leave the transA/B values as is, since Blackwell supports transposes
-    ret.A = transA ? A.data.dptr : A.columnwise_data.dptr;
-    ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype;
-    ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
-    ret.B = transB ? B.columnwise_data.dptr : B.data.dptr;
-    ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype;
-    ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
+    return ret;
   }
-  return ret;
-}

 } // namespace

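Context for reviewers (not part of the patch): the sketch below is a standalone, simplified illustration of the arch-based dispatch that CanonicalizeGemmInput applies to the FP8 A operand under tensor scaling, based only on the logic visible in this hunk. Op, FakeTensor, AChoice, and pick_a_operand are hypothetical stand-ins for illustration; they are not Transformer Engine or cuBLAS APIs.

// Standalone sketch only: simplified stand-ins for transformer_engine::Tensor
// and cublasOperation_t; it mirrors just the branching in CanonicalizeGemmInput.
#include <cstdio>

enum class Op { N, T };  // stand-in for cublasOperation_t

struct FakeTensor {          // stand-in for transformer_engine::Tensor
  bool has_columnwise_data;  // a column-wise (transposed) copy is available
  bool is_fp8;               // the row-wise dtype is an FP8 type
};

struct AChoice {
  bool use_columnwise;  // read columnwise_data instead of data
  Op transA;            // effective operation handed to cublasLt
  int lda;              // effective leading dimension
};

// Mirrors the "arch < 100" branch for operand A under tensor scaling: a
// non-transposed FP8 A on Hopper/Ada is replaced by its column-wise copy with
// transA flipped to T and lda set to k; Blackwell (SM >= 100) keeps the
// operand exactly as passed in.
AChoice pick_a_operand(const FakeTensor &A, Op transA, int k, int lda, int sm_arch) {
  if (transA == Op::T || !A.is_fp8) return {false, transA, lda};
  if (sm_arch < 100) {
    // The real code guards this path with NVTE_CHECK(A.has_columnwise_data(), ...).
    return {true, Op::T, k};
  }
  return {false, transA, lda};
}

int main() {
  const FakeTensor A{true, true};
  const int k = 128, lda = 256;
  const int archs[] = {90, 100};  // Hopper vs. Blackwell
  for (int arch : archs) {
    const AChoice c = pick_a_operand(A, Op::N, k, lda, arch);
    std::printf("sm%d: columnwise=%d transA=%s lda=%d\n", arch, c.use_columnwise ? 1 : 0,
                c.transA == Op::T ? "T" : "N", c.lda);
  }
  return 0;
}

Built with any C++11 compiler, the sketch prints that sm90 switches to the column-wise copy with transA=T and lda=k, while sm100 leaves the row-wise operand untouched, which is the behaviour the in-code comments describe.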