 // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
 // All matrix tiles have ne physical 32 bit elements per warp.
 //
-// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes; unaligned pointers are considered undefined behavior.

 #include "common.cuh"

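To make the new alignment note concrete, here is a minimal call-site sketch (my illustration, not code from this PR; the kernel and buffer names are hypothetical, and the include path is assumed to sit alongside common.cuh):

```cpp
// Minimal sketch (not code from this PR): what the 16-byte alignment
// requirement means at a call site. Kernel and buffer names are hypothetical.
#include "mma.cuh"  // assumed include path, alongside common.cuh

static __global__ void alignment_example() {
    // Explicitly 16-byte aligned shared memory, as the comment above requires.
    __shared__ __align__(16) half2 buf[16 * 8];

    ggml_cuda_mma::tile<16, 8, half2> t;
    ggml_cuda_mma::load_ldmatrix(t, buf, /*stride=*/8);  // OK: shared memory, 16-byte aligned
    // ggml_cuda_mma::load_ldmatrix(t, buf + 1, 8);      // not OK: pointer only 4-byte aligned
}
```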
@@ -66,7 +67,44 @@ namespace ggml_cuda_mma {
     struct tile {
         static constexpr int I  = I_;
         static constexpr int J  = J_;
-        static constexpr int ne = I * J / WARP_SIZE;
+
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+        static constexpr int ne = I * J / 64;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+                return threadIdx.x % 16;
+            } else if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 32 && J == 4) {
+                return threadIdx.x % 32;
+            } else if constexpr (I == 16 && J == 16) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 32 && J == 32) {
+                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+                return (2 * ((threadIdx.x / 16) % 2) + l);
+            } else if constexpr (I == 16 && J == 8) {
+                return 2 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 32 && J == 4) {
+                return 2 * (threadIdx.x / 32) + l;
+            } else if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 32 && J == 32) {
+                return threadIdx.x % 32;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+#else
+        static constexpr int ne = I * J / 32;
         T x[ne] = {0};

         static __device__ __forceinline__ int get_i(const int l) {
@@ -94,6 +132,7 @@ namespace ggml_cuda_mma {
             static_assert(I == -1 && J == -1, "template specialization not implemented");
         }
     }
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     };

     template <int I_, int J_>
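On the HIP/AMD branch, ne is I * J / 64 because a CDNA wavefront has 64 lanes, and the get_i/get_j specializations above describe which tile element each lane owns. As a sanity check, here is a small host-side sketch (an illustration I added, not part of the PR) that replays the tile<16, 16, int> mapping and verifies full, non-overlapping coverage:

```cpp
// Host-side sketch (illustrative, not part of the PR): replay the
// tile<16, 16, int> mapping above and check that the 64 lanes of a wavefront
// cover all 256 elements exactly once.
#include <cstdio>

int main() {
    int owners[16][16] = {};
    const int warp_size = 64;            // AMD wavefront size
    const int ne        = 16 * 16 / 64;  // elements per lane, as in the tile struct

    for (int lane = 0; lane < warp_size; ++lane) {
        for (int l = 0; l < ne; ++l) {
            const int i = 4 * (lane / 16) + l;  // get_i for I == 16 && J == 16
            const int j = lane % 16;            // get_j for I == 16 && J == 16
            owners[i][j]++;
        }
    }

    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            if (owners[i][j] != 1) {
                printf("element (%d, %d) owned %d times\n", i, j, owners[i][j]);
                return 1;
            }
        }
    }
    printf("tile<16, 16>: every element is owned by exactly one (lane, l) pair\n");
    return 0;
}
```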
@@ -148,10 +187,23 @@ namespace ggml_cuda_mma {

     template <int I, int J, typename T>
     static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
+#if defined(AMD_MFMA_AVAILABLE)
+        if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
+        } else {
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+        }
+#else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
             t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
         }
+#endif // defined(AMD_MFMA_AVAILABLE)
     }

     template <typename T>
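For the non-special tiles this replaces the per-element loop with a single 8-byte load per lane. A host-side sketch (again my illustration, not part of the PR) that checks, for tile<16, 8, int>, that the vectorized address math touches exactly the elements the scalar get_i/get_j loop would:

```cpp
// Host-side sketch (illustrative, not part of the PR): for tile<16, 8, int>
// on a 64-lane wavefront, the single int64_t load in load_generic covers the
// same two ints per lane as the scalar get_i/get_j loop.
#include <cstdio>

int main() {
    const int I = 16, J = 8, warp_size = 64;
    const int ne     = I * J / warp_size;  // 2 ints per lane
    const int stride = 32;                 // arbitrary row stride in ints

    for (int lane = 0; lane < warp_size; ++lane) {
        // Vectorized path: one 8-byte load starting at this int offset.
        const int vec_base = (lane % I) * stride + 2 * (lane / I);

        for (int l = 0; l < ne; ++l) {
            const int i = lane % 16;            // get_i for I == 16 && J == 8
            const int j = 2 * (lane / 16) + l;  // get_j for I == 16 && J == 8
            if (i * stride + j != vec_base + l) {
                printf("mismatch at lane %d, l %d\n", lane, l);
                return 1;
            }
        }
    }
    printf("vectorized and scalar load_generic paths address the same ints\n");
    return 0;
}
```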
@@ -186,7 +238,7 @@ namespace ggml_cuda_mma {
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(NEW_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -393,4 +445,60 @@ namespace ggml_cuda_mma {
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+        int32x4_t * acc = (int32x4_t *) D.x;
+#if defined(CDNA3)
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
+                                                       ((int64_t *) B.x)[0],
+                                                       acc[0],
+                                                       0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
+                                                      B.x[0],
+                                                      acc[0],
+                                                      0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
+                                                      B.x[1],
+                                                      acc[0],
+                                                      0, 0, 0);
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
+        int32x16_t * acc = (int32x16_t *) D.x;
+#if defined(CDNA3)
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
+                                                       ((int64_t *) B.x)[0],
+                                                       acc[0],
+                                                       0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
+                                                     B.x[0],
+                                                     acc[0],
+                                                     0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
+                                                     B.x[1],
+                                                     acc[0],
+                                                     0, 0, 0);
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
 }
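Putting the pieces together, here is a usage sketch (hypothetical, not code from this PR) of how a kernel built against this header might drive the new 16x16 int8 MFMA path on a HIP/AMD build. The tile types, load_generic, mma, and the index helpers come from this file; the include path, kernel name, and launch assumption (one 64-thread block) are mine:

```cpp
// Usage sketch (hypothetical): one wavefront loads two 16x8 int tiles (each
// holding 16x32 packed int8 values) from 16-byte aligned shared memory and
// accumulates their product into a 16x16 int tile via the new MFMA overload.
#include "mma.cuh"  // assumed include path

using namespace ggml_cuda_mma;

static __global__ void mfma_i8_smoke(const int * __restrict__ src, int * __restrict__ dst) {
    __shared__ __align__(16) int A_sh[16 * 8];
    __shared__ __align__(16) int B_sh[16 * 8];

    // Stage the operands in shared memory; each int carries four packed int8 values.
    for (int k = threadIdx.x; k < 16 * 8; k += blockDim.x) {
        A_sh[k] = src[k];
        B_sh[k] = src[16 * 8 + k];
    }
    __syncthreads();

    tile<16,  8, int> A;
    tile<16,  8, int> B;
    tile<16, 16, int> D;  // accumulator, zero-initialized by the tile definition

    load_generic(A, A_sh, /*stride=*/8);
    load_generic(B, B_sh, /*stride=*/8);
    mma(D, A, B);  // int8 dot products accumulated into 32-bit D

    // Write each lane's results back using the same index helpers.
#pragma unroll
    for (int l = 0; l < D.ne; ++l) {
        dst[D.get_i(l) * 16 + D.get_j(l)] = D.x[l];
    }
}
```

The exact row/column convention of D follows the MFMA instructions wrapped above; the sketch is only meant to show how the tile types, load_generic, and the new mma overloads fit together.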