
Commit 3dbcb13

PeixuanZuo authored and Prathik Rao committed
[ROCm] simplify ck data type Adaptor (#15734)
DataTypeAdaptor is defined many times, once in every file that integrates CK. This PR refactors the code to put DataTypeAdaptor in a shared header file.
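For illustration, here is a minimal sketch of how the shared adaptor is consumed after this change. The `ResolvedCKType` alias and the surrounding snippet are hypothetical and only for illustration; the header path and `CKDataTypeAdaptor` itself come from this PR, and the snippet assumes it is compiled inside the ROCm provider with `USE_COMPOSABLE_KERNEL` defined.

```cpp
// Minimal sketch, assuming USE_COMPOSABLE_KERNEL is defined and this
// translation unit is built as part of the ONNX Runtime ROCm provider.
#include "core/providers/rocm/composable_kernel_common.h"

namespace onnxruntime {
namespace rocm {

// Hypothetical helper, for illustration only: every CK integration header
// now resolves its CK element type through the shared adaptor instead of
// re-declaring a local DataTypeAdaptor.
template <typename T>
using ResolvedCKType = typename CKDataTypeAdaptor<T>::type;

// e.g. ResolvedCKType<half>     -> ck::half_t
//      ResolvedCKType<BFloat16> -> ck::bhalf16_t   (per this PR's header)
//      ResolvedCKType<float>    -> float           (identity primary template)

}  // namespace rocm
}  // namespace onnxruntime
```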
1 parent 3ba10e9 commit 3dbcb13

6 files changed: +55 -80 lines changed


onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh

+2 -17

@@ -67,6 +67,7 @@ are in composable kernels. The scale and add logic is performed via Acc0ElementO
 #include "contrib_ops/rocm/bert/attention_softmax.h"
 #ifdef USE_COMPOSABLE_KERNEL
 #include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh"
+#include "core/providers/rocm/composable_kernel_common.h"

 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

@@ -451,22 +452,6 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
 };

 #ifdef USE_COMPOSABLE_KERNEL
-namespace {
-template <typename T>
-struct DataTypeAdaptor {
-  using type = T;
-};
-
-template <>
-struct DataTypeAdaptor<half> {
-  using type = ck::half_t;
-};
-
-template <>
-struct DataTypeAdaptor<BFloat16> {
-  using type = ck::bhalf16_t;
-};
-}  // namespace

 template <typename T, bool USE_BIAS, bool USE_MASK>
 auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {

@@ -475,7 +460,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
   using Nop = ck::tensor_operation::element_wise::PassThrough;
   using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp;

-  using CKDataType = typename DataTypeAdaptor<T>::type;
+  using CKDataType = typename CKDataTypeAdaptor<T>::type;
   using D0DataType = typename ck::detail::tuple_concat<
       std::conditional_t<USE_BIAS, ck::Tuple<CKDataType>, ck::Tuple<>>,
       std::conditional_t<USE_MASK, ck::Tuple<CKDataType>, ck::Tuple<>>>::type;

onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh

+5 -16

@@ -8,6 +8,8 @@
 #include <vector>

 #ifdef USE_COMPOSABLE_KERNEL
+#include "core/providers/rocm/composable_kernel_common.h"
+
 #include "ck/ck.hpp"
 #include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp"
 #include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp"

@@ -28,20 +30,7 @@ namespace internal {

 #ifdef USE_COMPOSABLE_KERNEL

-template <typename T>
-struct DataTypeAdaptor {
-  using type = T;
-};
-
-template <>
-struct DataTypeAdaptor<half> {
-  using type = ck::half_t;
-};
-
-template <>
-struct DataTypeAdaptor<BFloat16> {
-  using type = ck::bhalf16_t;
-};
+using onnxruntime::rocm::CKDataTypeAdaptor;

 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;

@@ -52,7 +41,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;

 template <typename T, typename ALayout, typename BLayout>
 auto GetCKGemmAddFastGeluTypeStringAndOps() {
-  using CKDataType = typename DataTypeAdaptor<T>::type;
+  using CKDataType = typename CKDataTypeAdaptor<T>::type;
   using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD<
       ALayout, BLayout, ck::Tuple<Row>, Row,
       CKDataType, CKDataType, ck::Tuple<CKDataType>, CKDataType,

@@ -89,7 +78,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() {

 template <typename T, typename ALayout, typename BLayout>
 auto GetCKGemmFastGeluTypeStringAndOps() {
-  using CKDataType = typename DataTypeAdaptor<T>::type;
+  using CKDataType = typename CKDataTypeAdaptor<T>::type;
   using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD<
       ALayout, BLayout, ck::Tuple<>, Row,
       CKDataType, CKDataType, ck::Tuple<>, CKDataType,

onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh

+6 -12

@@ -8,6 +8,8 @@
 #include <vector>

 #ifdef USE_COMPOSABLE_KERNEL
+#include "core/providers/rocm/composable_kernel_common.h"
+
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

@@ -22,15 +24,7 @@ namespace rocm {

 #ifdef USE_COMPOSABLE_KERNEL

-template <typename T>
-struct DataTypeAdaptor {
-  using type = T;
-};
-
-template <>
-struct DataTypeAdaptor<half> {
-  using type = ck::half_t;
-};
+using onnxruntime::rocm::CKDataTypeAdaptor;

 using Swish = ck::tensor_operation::element_wise::Swish;
 using Pass = ck::tensor_operation::element_wise::PassThrough;

@@ -40,9 +34,9 @@ constexpr int NumReduceDim = 3;

 template <typename T, typename AccT, bool WithSwish>
 auto GetCKGroupNormNHWCTypeStringAndOps() {
-  using InDataType = typename DataTypeAdaptor<T>::type;
-  using OutDataType = typename DataTypeAdaptor<T>::type;
-  using AccDataType = typename DataTypeAdaptor<AccT>::type;
+  using InDataType = typename CKDataTypeAdaptor<T>::type;
+  using OutDataType = typename CKDataTypeAdaptor<T>::type;
+  using AccDataType = typename CKDataTypeAdaptor<AccT>::type;
   using GammaDataType = float;
   using BetaDataType = float;
onnxruntime/core/providers/rocm/composable_kernel_common.h

+33 -0 (new file)

@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#ifdef USE_COMPOSABLE_KERNEL
+#include "ck/utility/data_type.hpp"
+#endif
+
+#include "core/providers/rocm/rocm_common.h"
+
+namespace onnxruntime {
+namespace rocm {
+
+#ifdef USE_COMPOSABLE_KERNEL
+template <typename T>
+struct CKDataTypeAdaptor {
+  using type = T;
+};
+
+template <>
+struct CKDataTypeAdaptor<half> {
+  using type = ck::half_t;
+};
+
+template <>
+struct CKDataTypeAdaptor<BFloat16> {
+  using type = ck::bhalf16_t;
+};
+#endif
+
+}  // namespace rocm
+}  // namespace onnxruntime

onnxruntime/core/providers/rocm/math/softmax_ck.cuh

+5 -18

@@ -8,6 +8,8 @@
 #include <vector>

 #ifdef USE_COMPOSABLE_KERNEL
+#include "core/providers/rocm/composable_kernel_common.h"
+
 #include "ck/ck.hpp"
 #include "ck/library/tensor_operation_instance/gpu/softmax.hpp"
 #include "ck/tensor_operation/gpu/device/device_softmax.hpp"

@@ -22,30 +24,15 @@ namespace rocm {

 #ifdef USE_COMPOSABLE_KERNEL

-template <typename T>
-struct DataTypeAdaptor {
-  using type = T;
-};
-
-template <>
-struct DataTypeAdaptor<half> {
-  using type = ck::half_t;
-};
-
-template <>
-struct DataTypeAdaptor<BFloat16> {
-  using type = ck::bhalf16_t;
-};
-
 using Nop = ck::tensor_operation::element_wise::PassThrough;
 constexpr int Rank = 4;
 constexpr int NumReduceDim = 1;

 template <typename InputT, typename OutputT, typename AccT>
 auto GetCKSoftmaxTypeStringAndOps() {
-  using InDataType = typename DataTypeAdaptor<InputT>::type;
-  using OutDataType = typename DataTypeAdaptor<OutputT>::type;
-  using AccDataType = typename DataTypeAdaptor<AccT>::type;
+  using InDataType = typename CKDataTypeAdaptor<InputT>::type;
+  using OutDataType = typename CKDataTypeAdaptor<OutputT>::type;
+  using AccDataType = typename CKDataTypeAdaptor<AccT>::type;
   using DeviceSoftmax = ck::tensor_operation::device::
       DeviceSoftmax<InDataType, AccDataType, OutDataType, Nop, Nop, Rank>;
   using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceSoftmax>;

onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh

+4 -17

@@ -8,6 +8,8 @@
 #include <vector>

 #ifdef USE_COMPOSABLE_KERNEL
+#include "core/providers/rocm/composable_kernel_common.h"
+
 #include "ck/ck.hpp"
 #include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
 #include "ck/library/tensor_operation_instance/gpu/gemm.hpp"

@@ -27,29 +29,14 @@ namespace internal {

 #ifdef USE_COMPOSABLE_KERNEL

-template <typename T>
-struct DataTypeAdaptor {
-  using type = T;
-};
-
-template <>
-struct DataTypeAdaptor<half> {
-  using type = ck::half_t;
-};
-
-template <>
-struct DataTypeAdaptor<BFloat16> {
-  using type = ck::bhalf16_t;
-};
-
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;

 using Nop = ck::tensor_operation::element_wise::PassThrough;

 template <typename T, typename ALayout, typename BLayout>
 auto GetCKGemmTypeStringAndOps() {
-  using CKDataType = typename DataTypeAdaptor<T>::type;
+  using CKDataType = typename CKDataTypeAdaptor<T>::type;
   using DeviceGemm = ck::tensor_operation::device::DeviceGemm<
       ALayout, BLayout, Row,
       CKDataType, CKDataType, CKDataType,

@@ -95,7 +82,7 @@ auto GetCKGemmTypeStringAndOps() {

 template <typename T, typename ALayout, typename BLayout>
 auto GetCKStridedBatchedGemmTypeStringAndOps() {
-  using CKDataType = typename DataTypeAdaptor<T>::type;
+  using CKDataType = typename CKDataTypeAdaptor<T>::type;
   using DeviceStridedBatchedGemm = ck::tensor_operation::device::DeviceBatchedGemm<
       ALayout, BLayout, Row,
       CKDataType, CKDataType, CKDataType,
