
Commit b75afa2

Isalia20 authored and pytorchmergebot committed
[MPS] cholesky implementation (pytorch#145701)
Requested in pytorch#77764. Closed pytorch#144193 due to a lot of conflicts when rebasing.

Pull Request resolved: pytorch#145701
Approved by: https://github.com/malfet
1 parent c6ad083 commit b75afa2
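
For context on what this commit enables: torch.linalg.cholesky (and the legacy torch.cholesky) now dispatch to the native MPS kernels below instead of erroring out or falling back to the CPU. A minimal smoke test, assuming a Mac with an MPS-capable GPU and a PyTorch build that includes this change; shapes and tolerances are illustrative only:

import torch

# Batch of symmetric positive-definite float32 matrices on the MPS device.
a = torch.randn(4, 64, 64, device="mps")
spd = a @ a.transpose(-2, -1) + 64 * torch.eye(64, device="mps")

# Lower-triangular L with spd == L @ L^T (upper=True returns the transpose).
L = torch.linalg.cholesky(spd)

# Round-trip check; float32 on GPU, so tolerances are deliberately loose.
print(torch.allclose(L @ L.transpose(-2, -1), spd, rtol=1e-4, atol=1e-4))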

File tree

5 files changed: +401 −2 lines changed


aten/src/ATen/native/mps/kernels/LinearAlgebra.metal

Lines changed: 266 additions & 0 deletions
@@ -1,4 +1,5 @@
 #include <metal_array>
+#include <metal_stdlib>
 
 using namespace metal;
 template <typename T>
@@ -31,6 +32,271 @@ kernel void naive_matmul(
   outputData[x * strides[2].x + y * strides[2].y] = rc;
 }
 
+inline float blockReduceSum(
+    threadgroup float* sharedScratch,
+    float val,
+    uint tid,
+    uint tpg) {
+  sharedScratch[tid] = val;
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint offset = tpg >> 1; offset > 0; offset >>= 1) {
+    if (tid < offset) {
+      sharedScratch[tid] += sharedScratch[tid + offset];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  return sharedScratch[0];
+}
+
+kernel void factorDiagonalBlock(
+    device float* A [[buffer(0)]],
+    device int* success [[buffer(1)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint tid [[thread_position_in_threadgroup]],
+    uint bid [[threadgroup_position_in_grid]],
+    uint tpg [[threads_per_threadgroup]]) {
+  const uint actSize = min(N - k * NB, NB);
+  const uint batch_offset = bid * N * N;
+
+  const uint row0 = k * NB;
+  const uint col0 = k * NB;
+
+  threadgroup float tile[32][33];
+  threadgroup float reduceScratch[256];
+  const uint tileSize = actSize * actSize;
+
+  for (uint i = tid; i < tileSize; i += tpg) {
+    uint r = i / actSize;
+    uint c = i % actSize;
+    tile[r][c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint kk = 0; kk < actSize; kk++) {
+    float diagElt = 0.0f;
+    if (kk > 0) {
+      float partialSum = 0.0f;
+      for (uint i = tid; i < kk; i += tpg) {
+        float val = tile[kk][i];
+        partialSum = fma(val, val, partialSum);
+      }
+      diagElt = blockReduceSum(reduceScratch, partialSum, tid, tpg);
+    }
+
+    if (tid == 0) {
+      float diagVal = tile[kk][kk] - diagElt;
+      // Check for positive definiteness
+      if (diagVal <= 0.0f) {
+        success[bid] = 0; // matrix is not positive definite
+        return;
+      }
+      tile[kk][kk] = sqrt(diagVal);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float pivot = tile[kk][kk];
+
+    for (uint j = kk + 1 + tid; j < actSize; j += tpg) {
+      float partialSum = 0.0f;
+      for (uint i = 0; i < kk; i++) {
+        partialSum = fma(tile[j][i], tile[kk][i], partialSum);
+      }
+
+      float val = tile[j][kk];
+      val -= partialSum;
+      val /= pivot;
+      tile[j][kk] = val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  for (uint i = tid; i < tileSize; i += tpg) {
+    uint r = i / actSize;
+    uint c = i % actSize;
+    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r][c];
+  }
+}
+
+kernel void applyTRSM(
+    device float* A [[buffer(0)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]) {
+  uint b = tgid.x;
+  uint idxJ = tgid.y;
+
+  const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
+  const uint batch_offset = b * N * N;
+  const uint j = (k + 1) + idxJ;
+
+  uint row0 = j * NB;
+  uint col0 = k * NB;
+
+  uint actSize_j = (uint)min((int)(N - row0), (int)NB);
+  if (actSize_k == 0 || actSize_j == 0) {
+    return;
+  }
+  if (j == k) {
+    return;
+  }
+
+  threadgroup float diag[32 * 32];
+  threadgroup float target[32 * 32];
+
+  for (uint i = tid.x; i < actSize_k * actSize_k; i += tpg.x) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    diag[i] = A[batch_offset + (k * NB + r) * N + (k * NB + c)];
+  }
+  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    target[i] = A[batch_offset + (row0 + r) * N + (col0 + c)];
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint col = 0; col < actSize_k; col++) {
+    float diag_val = diag[col * actSize_k + col];
+    if (abs(diag_val) < 1e-6f) {
+      diag_val = (diag_val < 0.0f) ? -1e-6f : 1e-6f;
+    }
+
+    for (uint row = tid.x; row < actSize_j; row += tpg.x) {
+      float sum = target[row * actSize_k + col];
+
+      // kahan sum
+      float c = 0.0f;
+      for (uint p = 0; p < col; p++) {
+        float y = -target[row * actSize_k + p] * diag[col * actSize_k + p] - c;
+        float t = sum + y;
+        c = (t - sum) - y;
+        sum = t;
+      }
+
+      target[row * actSize_k + col] = sum / diag_val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    A[batch_offset + (row0 + r) * N + (col0 + c)] = target[i];
+  }
+}
+
+kernel void applySYRK(
+    device float* A [[buffer(0)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]) {
+  uint b = tgid.x;
+  uint pairID = tgid.y;
+
+  uint jRel = (-1 + sqrt(1 + 8 * float(pairID))) / 2;
+  uint hRel = pairID - (jRel * (jRel + 1) >> 1);
+
+  const uint startJ = (k + 1);
+  uint j = startJ + jRel;
+  uint h = startJ + hRel;
+  uint row0 = j * NB;
+  uint col0 = h * NB;
+
+  const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
+  const uint actSize_j = min((uint)(N - row0), NB);
+  const uint actSize_h = min((uint)(N - col0), NB);
+  const uint batch_offset = b * N * N;
+
+  if (actSize_j == 0 || actSize_h == 0 || actSize_k == 0)
+    return;
+
+  threadgroup float left[32 * 33];
+  threadgroup float right_t[32 * 33];
+  threadgroup float tile[32 * 33];
+
+  const uint threads = min(tpg.x, actSize_j * actSize_k);
+
+  for (uint i = tid.x; i < actSize_j * actSize_k; i += threads) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    left[r * actSize_k + c] = A[batch_offset + (j * NB + r) * N + (k * NB + c)];
+  }
+
+  for (uint i = tid.x; i < actSize_h * actSize_k; i += threads) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    right_t[c * actSize_h + r] =
+        A[batch_offset + (h * NB + r) * N + (k * NB + c)];
+  }
+
+  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
+    uint r = i / actSize_h;
+    uint c = i % actSize_h;
+    tile[r * actSize_h + c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
+  }
+
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint idx = tid.x; idx < actSize_j * actSize_h; idx += threads) {
+    uint r = idx / actSize_h;
+    uint c = idx % actSize_h;
+
+    if ((j == h) && (r < c))
+      continue;
+
+    uint tile_idx = r * actSize_h + c;
+    float sum = tile[tile_idx];
+
+    uint left_row = r * actSize_k;
+    uint right_col = c;
+
+    uint k = 0;
+    float4 sum4 = {0.0f, 0.0f, 0.0f, 0.0f};
+
+    for (; k + 4 <= actSize_k; k += 4) {
+      float4 left4 = {
+          left[left_row + k],
+          left[left_row + k + 1],
+          left[left_row + k + 2],
+          left[left_row + k + 3]};
+
+      float4 right4 = {
+          right_t[(k + 0) * actSize_h + right_col],
+          right_t[(k + 1) * actSize_h + right_col],
+          right_t[(k + 2) * actSize_h + right_col],
+          right_t[(k + 3) * actSize_h + right_col]};
+
+      sum4 = fma(left4, right4, sum4);
+    }
+
+    sum -= dot(sum4, 1.0);
+
+    for (; k < actSize_k; k++) {
+      sum = fma(-left[left_row + k], right_t[k * actSize_h + right_col], sum);
+    }
+
+    tile[tile_idx] = sum;
+  }
+
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
+    uint r = i / actSize_h;
+    uint c = i % actSize_h;
+    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r * actSize_h + c];
+  }
+}
+
 #define INSTANTIATE_NAIVE_MM(DTYPE) \
   template [[host_name("naive_matmul_" #DTYPE)]] kernel void \
   naive_matmul<DTYPE>( \
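
Taken together, the three kernels above perform one step of a blocked, right-looking Cholesky: factorDiagonalBlock does an unblocked factorization of the current NB×NB diagonal block (using a threadgroup reduction for the dot products), applyTRSM forward-substitutes the column panel beneath it against that block, and applySYRK subtracts the panel's outer products from the trailing blocks, one threadgroup per block pair. A rough NumPy sketch of the same schedule, offered as a reviewer's mental model rather than the shipped algorithm (blocked_cholesky is an illustrative helper, and the dense np.linalg.solve stands in for the kernel's triangular solve):

import numpy as np

def blocked_cholesky(A, NB=32):
    # Right-looking blocked Cholesky (lower triangular), one matrix at a time.
    A = A.copy()
    N = A.shape[0]
    for k0 in range(0, N, NB):
        k1 = min(k0 + NB, N)
        # factorDiagonalBlock: unblocked Cholesky of the diagonal block.
        A[k0:k1, k0:k1] = np.linalg.cholesky(A[k0:k1, k0:k1])
        Lkk = A[k0:k1, k0:k1]
        if k1 < N:
            # applyTRSM: panel = panel @ inv(Lkk).T, done here as a solve.
            A[k1:N, k0:k1] = np.linalg.solve(Lkk, A[k1:N, k0:k1].T).T
            # applySYRK: rank-NB update of the trailing submatrix.
            P = A[k1:N, k0:k1]
            A[k1:N, k1:N] -= P @ P.T
    return np.tril(A)

# Sanity check against NumPy's own factorization.
M = np.random.randn(100, 100)
S = M @ M.T + 100 * np.eye(100)
assert np.allclose(blocked_cholesky(S), np.linalg.cholesky(S))

The 32×33 threadgroup tiles in the Metal code are the padded equivalents of these NB×NB blocks; the extra column is the usual trick for avoiding shared-memory bank conflicts on strided accesses.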

aten/src/ATen/native/mps/operations/LinearAlgebra.mm

Lines changed: 98 additions & 0 deletions
@@ -18,6 +18,8 @@
 #include <ATen/ops/addr_native.h>
 #include <ATen/ops/baddbmm_native.h>
 #include <ATen/ops/bmm_native.h>
+#include <ATen/ops/cholesky_native.h>
+#include <ATen/ops/linalg_cholesky_native.h>
 #include <ATen/ops/linalg_lu_factor_native.h>
 #include <ATen/ops/linalg_solve_triangular_native.h>
 #include <ATen/ops/mm_native.h>
@@ -780,6 +782,83 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
   return out;
 }
 
+static Tensor& linalg_cholesky_mps_impl(const Tensor& input, bool upper, Tensor& out) {
+  using namespace mps;
+
+  TORCH_CHECK(out.is_mps());
+  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float, "linalg.cholesky: Input tensor must be float32");
+  TORCH_CHECK(input.dim() >= 2, "linalg.cholesky: Input tensor must be at least 2D");
+  TORCH_CHECK(input.size(-2) == input.size(-1), "linalg.cholesky: Input tensor must be square");
+
+  if (input.numel() == 0 || out.numel() == 0) {
+    out.zero_();
+    return out;
+  }
+  resize_output(out, input.sizes());
+  out.copy_(input);
+
+  int64_t ndim = out.dim();
+  int64_t N = out.size(-1);
+  int64_t B = 1;
+  for (int64_t i = 0; i < ndim - 2; i++) {
+    B *= out.size(i);
+  }
+
+  auto stream = getCurrentMPSStream();
+  auto device = MPSDevice::getInstance()->device();
+
+  auto factorDiagonalPSO = lib.getPipelineStateForFunc("factorDiagonalBlock");
+  auto applyTRSMPSO = lib.getPipelineStateForFunc("applyTRSM");
+  auto applySYRKPSO = lib.getPipelineStateForFunc("applySYRK");
+
+  int64_t NB = std::min<int64_t>(32, N);
+  int64_t numBlocks = (N + NB - 1) / NB;
+
+  Tensor success = at::empty({B}, input.options().dtype(kInt)).fill_(1);
+  id<MTLBuffer> successBuffer = getMTLBufferStorage(success);
+
+  MTLSize threadGroupSize = MTLSizeMake(256, 1, 1);
+  id<MTLBuffer> outBuffer = getMTLBufferStorage(out);
+  id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
+  [computeEncoder setBuffer:outBuffer offset:0 atIndex:0];
+  [computeEncoder setBytes:&N length:sizeof(int64_t) atIndex:2];
+  [computeEncoder setBytes:&NB length:sizeof(int64_t) atIndex:3];
+
+  @autoreleasepool {
+    dispatch_sync_with_rethrow(stream->queue(), ^() {
+      for (int64_t k = 0; k < numBlocks; k++) {
+        [computeEncoder setComputePipelineState:factorDiagonalPSO];
+        [computeEncoder setBuffer:successBuffer offset:0 atIndex:1];
+        [computeEncoder setBytes:&k length:sizeof(int64_t) atIndex:4];
+        MTLSize gridSize = MTLSizeMake(B, 1, 1);
+        [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
+
+        // process all remaining blocks in this row/column in parallel
+        if (k < numBlocks - 1) {
+          int64_t startJ = k + 1;
+          int64_t nBlocksJ = (numBlocks - startJ);
+
+          if (nBlocksJ > 0) {
+            // TRSM for all blocks in parallel
+            MTLSize trsmGridSize = MTLSizeMake(B, nBlocksJ, 1);
+            [computeEncoder setComputePipelineState:applyTRSMPSO];
+            [computeEncoder dispatchThreadgroups:trsmGridSize threadsPerThreadgroup:threadGroupSize];
+
+            // SYRK for all independent block pairs in parallel
+            uint32_t nPairs = nBlocksJ * (nBlocksJ + 1) / 2;
+            MTLSize syrkGridSize = MTLSizeMake(B, nPairs, 1);
+            [computeEncoder setComputePipelineState:applySYRKPSO];
+            [computeEncoder dispatchThreadgroups:syrkGridSize threadsPerThreadgroup:threadGroupSize];
+          }
+        }
+      }
+    });
+  }
+
+  TORCH_CHECK(success.all().item<bool>(), "linalg.cholesky: Input matrix is not positive definite");
+  out.tril_();
+  return upper ? out.transpose_(ndim - 2, ndim - 1) : out;
+}
 } // namespace mps
 
 Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
@@ -940,6 +1019,25 @@ Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, cons
   return result;
 }
 
+Tensor cholesky_mps(const Tensor& self, bool upper) {
+  auto out = at::empty_like(self, MemoryFormat::Contiguous);
+  mps::linalg_cholesky_mps_impl(self, upper, out);
+  return out;
+}
+
+Tensor& cholesky_mps_out(const Tensor& self, bool upper, Tensor& out) {
+  return mps::linalg_cholesky_mps_impl(self, upper, out);
+}
+
+Tensor& linalg_cholesky_out_mps(const Tensor& self, bool upper, Tensor& out) {
+  return mps::linalg_cholesky_mps_impl(self, upper, out);
+}
+
+Tensor linalg_cholesky_mps(const Tensor& self, bool upper) {
+  auto out = at::empty_like(self, MemoryFormat::Contiguous);
+  return mps::linalg_cholesky_mps_impl(self, upper, out);
+}
+
 Tensor addbmm_mps(const Tensor& self,
                   const Tensor& batch1,
                   const Tensor& batch2,
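
One detail worth double-checking in review is the flattened pair indexing: the host enqueues nPairs = nBlocksJ * (nBlocksJ + 1) / 2 threadgroups along the grid's y axis, and applySYRK inverts pairID = jRel * (jRel + 1) / 2 + hRel with the triangular-root formula to recover a block pair with hRel <= jRel. A quick check of that decoding in plain Python (decode_pair is an illustrative helper, not part of the change):

import math

def decode_pair(pair_id):
    # Same arithmetic as applySYRK: invert pairID = j*(j+1)/2 + h, h <= j.
    j_rel = int((-1 + math.sqrt(1 + 8 * pair_id)) / 2)
    h_rel = pair_id - (j_rel * (j_rel + 1)) // 2
    return j_rel, h_rel

# For nBlocksJ = 4 trailing blocks, nPairs = 4 * 5 / 2 = 10 threadgroups.
pairs = [decode_pair(p) for p in range(10)]
assert pairs == [(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2),
                 (3, 0), (3, 1), (3, 2), (3, 3)]

The decoded (jRel, hRel) then become the (j, h) block coordinates after offsetting by k + 1, which is exactly how the kernel picks the trailing block pair to update.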
