@@ -1,4 +1,5 @@
 #include <metal_array>
+#include <metal_stdlib>
 
 using namespace metal;
 template <typename T>
@@ -31,6 +32,271 @@ kernel void naive_matmul(
   outputData[x * strides[2].x + y * strides[2].y] = rc;
 }
 
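+// Sum-reduces one float per thread across the threadgroup through a shared
+// scratch buffer. Assumes tpg is a power of two and that sharedScratch has
+// room for tpg floats; every thread in the group must call it, since it
+// executes threadgroup barriers.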
+inline float blockReduceSum(
+    threadgroup float* sharedScratch,
+    float val,
+    uint tid,
+    uint tpg) {
+  sharedScratch[tid] = val;
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint offset = tpg >> 1; offset > 0; offset >>= 1) {
+    if (tid < offset) {
+      sharedScratch[tid] += sharedScratch[tid + offset];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  return sharedScratch[0];
+}
+
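+// Cholesky-factors the k-th NB x NB diagonal block of one batch matrix in
+// place; each threadgroup handles the matrix selected by bid. The tile is
+// factored column by column in threadgroup memory, which assumes NB <= 32
+// and tpg <= 256 so the static buffers below are large enough. If a pivot
+// is not positive the matrix is not positive definite: success[bid] is
+// cleared and the kernel exits without writing the tile back.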
+kernel void factorDiagonalBlock(
+    device float* A [[buffer(0)]],
+    device int* success [[buffer(1)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint tid [[thread_position_in_threadgroup]],
+    uint bid [[threadgroup_position_in_grid]],
+    uint tpg [[threads_per_threadgroup]]) {
+  // Do the subtraction in int64 so the block size cannot wrap around when
+  // k * NB approaches N.
+  const uint actSize = uint(min(int64_t(N) - int64_t(k * NB), int64_t(NB)));
+  const uint batch_offset = bid * N * N;
+
+  const uint row0 = k * NB;
+  const uint col0 = k * NB;
+
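+  // Rows are padded to 33 floats so column-wise accesses to the tile spread
+  // across threadgroup-memory banks; reduceScratch assumes tpg <= 256.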
+  threadgroup float tile[32][33];
+  threadgroup float reduceScratch[256];
+  threadgroup int ok; // shared pivot-check result, written by thread 0
+  const uint tileSize = actSize * actSize;
+
+  for (uint i = tid; i < tileSize; i += tpg) {
+    uint r = i / actSize;
+    uint c = i % actSize;
+    tile[r][c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
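+  // Factor the tile column by column: step kk first finishes the pivot
+  // tile[kk][kk] from the already-computed prefix of row kk, then fills in
+  // column kk below the diagonal.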
+  for (uint kk = 0; kk < actSize; kk++) {
+    float diagElt = 0.0f;
+    if (kk > 0) {
+      float partialSum = 0.0f;
+      for (uint i = tid; i < kk; i += tpg) {
+        float val = tile[kk][i];
+        partialSum = fma(val, val, partialSum);
+      }
+      diagElt = blockReduceSum(reduceScratch, partialSum, tid, tpg);
+    }
+
+    if (tid == 0) {
+      ok = 1;
+      float diagVal = tile[kk][kk] - diagElt;
+      // Check for positive definiteness. Failure is recorded in the shared
+      // flag instead of returning right away: a return by thread 0 alone
+      // would strand the other threads at the barrier below, which is
+      // undefined behavior in Metal.
+      if (diagVal <= 0.0f) {
+        success[bid] = 0; // matrix is not positive definite
+        ok = 0;
+      } else {
+        tile[kk][kk] = sqrt(diagVal);
+      }
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ok == 0) {
+      return; // all threads exit together
+    }
+
+    float pivot = tile[kk][kk];
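+
+    // Threads share the rows below the pivot: subtract the dot product of
+    // the row-j and row-kk prefixes, then divide by the pivot to obtain
+    // L[j][kk].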
+    for (uint j = kk + 1 + tid; j < actSize; j += tpg) {
+      float partialSum = 0.0f;
+      for (uint i = 0; i < kk; i++) {
+        partialSum = fma(tile[j][i], tile[kk][i], partialSum);
+      }
+
+      float val = tile[j][kk];
+      val -= partialSum;
+      val /= pivot;
+      tile[j][kk] = val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
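+  // Write the tile back; only the lower triangle was modified, the upper
+  // part still holds the original values.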
+  for (uint i = tid; i < tileSize; i += tpg) {
+    uint r = i / actSize;
+    uint c = i % actSize;
+    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r][c];
+  }
+}
+
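+// Propagates the factored diagonal block down its panel: for every block
+// row j > k this solves X * L^T = B in place, where L is the lower
+// triangular k-th diagonal block and B is the block at (j, k). Grid x is
+// the batch index and grid y selects the block row; buffer(1) is skipped
+// so the buffer indices match factorDiagonalBlock.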
+kernel void applyTRSM(
+    device float* A [[buffer(0)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]) {
+  uint b = tgid.x;
+  uint idxJ = tgid.y;
+
+  // Widen before subtracting so the unsigned difference cannot wrap.
+  const uint actSize_k = uint(min(int64_t(N) - int64_t(k * NB), int64_t(NB)));
+  const uint batch_offset = b * N * N;
+  const uint j = (k + 1) + idxJ;
+
+  uint row0 = j * NB;
+  uint col0 = k * NB;
+
+  const uint actSize_j = uint(min(int64_t(N) - int64_t(row0), int64_t(NB)));
+  if (actSize_k == 0 || actSize_j == 0) {
+    return;
+  }
+  if (j == k) {
+    return;
+  }
+
+  threadgroup float diag[32 * 32];
+  threadgroup float target[32 * 32];
+
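+  // Stage the factored diagonal block and the target panel block in
+  // threadgroup memory.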
+  for (uint i = tid.x; i < actSize_k * actSize_k; i += tpg.x) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    diag[i] = A[batch_offset + (k * NB + r) * N + (k * NB + c)];
+  }
+  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    target[i] = A[batch_offset + (row0 + r) * N + (col0 + c)];
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
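+  // Forward substitution over the columns of L:
+  //   X[row][col] = (B[row][col] - sum_{p<col} X[row][p] * L[col][p]) / L[col][col]
+  // The column loop is uniform across the threadgroup, so the barrier at
+  // the bottom is hit by every thread.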
+  for (uint col = 0; col < actSize_k; col++) {
+    float diag_val = diag[col * actSize_k + col];
+    // Clamp tiny pivots away from zero to avoid dividing by zero.
+    if (abs(diag_val) < 1e-6f) {
+      diag_val = (diag_val < 0.0f) ? -1e-6f : 1e-6f;
+    }
+
+    for (uint row = tid.x; row < actSize_j; row += tpg.x) {
+      float sum = target[row * actSize_k + col];
+
+      // Kahan-compensated accumulation of -X[row][p] * L[col][p]
+      float c = 0.0f;
+      for (uint p = 0; p < col; p++) {
+        float y = -target[row * actSize_k + p] * diag[col * actSize_k + p] - c;
+        float t = sum + y;
+        c = (t - sum) - y;
+        sum = t;
+      }
+
+      target[row * actSize_k + col] = sum / diag_val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    A[batch_offset + (row0 + r) * N + (col0 + c)] = target[i];
+  }
+}
+
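+// Trailing-submatrix update: for every block pair (j, h) with k < h <= j,
+// computes A[j][h] -= A[j][k] * A[h][k]^T. Grid x is the batch index and
+// grid y flattens the lower-triangular (j, h) pairs.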
+kernel void applySYRK(
+    device float* A [[buffer(0)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]) {
+  uint b = tgid.x;
+  uint pairID = tgid.y;
+
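+  // Invert the triangular-number flattening pairID = jRel*(jRel+1)/2 + hRel
+  // (with 0 <= hRel <= jRel): jRel = floor((sqrt(8*pairID + 1) - 1) / 2).
+  // The float sqrt is exact enough for realistic block counts.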
+  uint jRel = uint((sqrt(8.0f * float(pairID) + 1.0f) - 1.0f) * 0.5f);
+  uint hRel = pairID - ((jRel * (jRel + 1)) >> 1);
+
+  const uint startJ = (k + 1);
+  uint j = startJ + jRel;
+  uint h = startJ + hRel;
+  uint row0 = j * NB;
+  uint col0 = h * NB;
+
+  // Widen before subtracting so the unsigned differences cannot wrap.
+  const uint actSize_k = uint(min(int64_t(N) - int64_t(k * NB), int64_t(NB)));
+  const uint actSize_j = uint(min(int64_t(N) - int64_t(row0), int64_t(NB)));
+  const uint actSize_h = uint(min(int64_t(N) - int64_t(col0), int64_t(NB)));
+  const uint batch_offset = b * N * N;
+
+  if (actSize_j == 0 || actSize_h == 0 || actSize_k == 0) {
+    return;
+  }
+
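+  // Operand and destination tiles, sized for blocks up to 32 x 32 (NB <= 32).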
+  threadgroup float left[32 * 33];
+  threadgroup float right_t[32 * 33];
+  threadgroup float tile[32 * 33];
+
+  const uint threads = min(tpg.x, actSize_j * actSize_k);
+
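+  // Stage both operand panels and the destination tile. The h-panel is
+  // stored transposed, so in the inner product below neighboring threads
+  // read neighboring right_t entries.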
+  for (uint i = tid.x; i < actSize_j * actSize_k; i += threads) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    left[r * actSize_k + c] = A[batch_offset + (j * NB + r) * N + (k * NB + c)];
+  }
+
+  for (uint i = tid.x; i < actSize_h * actSize_k; i += threads) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    right_t[c * actSize_h + r] =
+        A[batch_offset + (h * NB + r) * N + (k * NB + c)];
+  }
+
+  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
+    uint r = i / actSize_h;
+    uint c = i % actSize_h;
+    tile[r * actSize_h + c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
+  }
+
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
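+  // Each thread owns one or more (r, c) entries of the output tile; for a
+  // diagonal block (j == h) only the lower triangle is updated.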
+  for (uint idx = tid.x; idx < actSize_j * actSize_h; idx += threads) {
+    uint r = idx / actSize_h;
+    uint c = idx % actSize_h;
+
+    if ((j == h) && (r < c)) {
+      continue;
+    }
+
+    uint tile_idx = r * actSize_h + c;
+    float sum = tile[tile_idx];
+
+    uint left_row = r * actSize_k;
+    uint right_col = c;
+
+    // Accumulate left[r][.] dot right[c][.] four lanes at a time. kk
+    // indexes the shared dimension (distinct from the kernel argument k).
+    uint kk = 0;
+    float4 sum4 = {0.0f, 0.0f, 0.0f, 0.0f};
+
+    for (; kk + 4 <= actSize_k; kk += 4) {
+      float4 left4 = {
+          left[left_row + kk],
+          left[left_row + kk + 1],
+          left[left_row + kk + 2],
+          left[left_row + kk + 3]};
+
+      float4 right4 = {
+          right_t[(kk + 0) * actSize_h + right_col],
+          right_t[(kk + 1) * actSize_h + right_col],
+          right_t[(kk + 2) * actSize_h + right_col],
+          right_t[(kk + 3) * actSize_h + right_col]};
+
+      sum4 = fma(left4, right4, sum4);
+    }
+
+    // Fold the four partial sums and subtract them from the accumulator.
+    sum -= dot(sum4, float4(1.0f));
+
+    // Scalar tail for actSize_k not divisible by four.
+    for (; kk < actSize_k; kk++) {
+      sum = fma(-left[left_row + kk], right_t[kk * actSize_h + right_col], sum);
+    }
+
+    tile[tile_idx] = sum;
+  }
+
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
+    uint r = i / actSize_h;
+    uint c = i % actSize_h;
+    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r * actSize_h + c];
+  }
+}
+
 #define INSTANTIATE_NAIVE_MM(DTYPE) \
   template [[host_name("naive_matmul_" #DTYPE)]] kernel void \
   naive_matmul<DTYPE>( \