8
8
#include < vector>
9
9
10
10
#ifdef USE_COMPOSABLE_KERNEL
11
+ #include " core/providers/rocm/composable_kernel_common.h"
12
+
11
13
#include " ck/ck.hpp"
12
14
#include " ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp"
13
15
#include " ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp"
@@ -28,20 +30,7 @@ namespace internal {
28
30
29
31
#ifdef USE_COMPOSABLE_KERNEL
30
32
31
- template <typename T>
32
- struct DataTypeAdaptor {
33
- using type = T;
34
- };
35
-
36
- template <>
37
- struct DataTypeAdaptor <half> {
38
- using type = ck::half_t ;
39
- };
40
-
41
- template <>
42
- struct DataTypeAdaptor <BFloat16> {
43
- using type = ck::bhalf16_t ;
44
- };
33
+ using onnxruntime::rocm::CKDataTypeAdaptor;
45
34
46
35
using Row = ck::tensor_layout::gemm::RowMajor;
47
36
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -52,7 +41,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;
52
41
53
42
template <typename T, typename ALayout, typename BLayout>
54
43
auto GetCKGemmAddFastGeluTypeStringAndOps () {
55
- using CKDataType = typename DataTypeAdaptor <T>::type;
44
+ using CKDataType = typename CKDataTypeAdaptor <T>::type;
56
45
using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD<
57
46
ALayout, BLayout, ck::Tuple<Row>, Row,
58
47
CKDataType, CKDataType, ck::Tuple<CKDataType>, CKDataType,
@@ -89,7 +78,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() {
89
78
90
79
template <typename T, typename ALayout, typename BLayout>
91
80
auto GetCKGemmFastGeluTypeStringAndOps () {
92
- using CKDataType = typename DataTypeAdaptor <T>::type;
81
+ using CKDataType = typename CKDataTypeAdaptor <T>::type;
93
82
using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD<
94
83
ALayout, BLayout, ck::Tuple<>, Row,
95
84
CKDataType, CKDataType, ck::Tuple<>, CKDataType,
0 commit comments