Commit c125f26

committed: wip
1 parent 500b8b7 commit c125f26

File tree

1 file changed: +78 -53 lines

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 78 additions & 53 deletions
@@ -1579,7 +1579,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static void ggml_metal_encode_node(
+static bool ggml_metal_encode_node(
         ggml_backend_t backend,
         int idx,
         id<MTLComputeCommandEncoder> encoder,
@@ -1599,7 +1599,7 @@ static void ggml_metal_encode_node(
     struct ggml_tensor * dst = node;
 
     if (ggml_is_empty(dst)) {
-        return;
+        return true;
     }
 
     switch (dst->op) {
@@ -1610,7 +1610,7 @@ static void ggml_metal_encode_node(
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return;
+            } return true;
         default:
             {
             } break;
@@ -2214,6 +2214,8 @@ static void ggml_metal_encode_node(
             {
                 GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
 
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
                 int nth = 32; // SIMD width
 
                 id<MTLComputePipelineState> pipeline = nil;
@@ -2278,7 +2280,9 @@ static void ggml_metal_encode_node(
 
                 id<MTLBuffer> id_src0h = ggml_metal_heap_alloc(heap, ggml_nbytes(src0), 32);
                 if (!id_src0h) {
-                    break;
+                    //GGML_LOG_ERROR("%s: failed to allocate buffer for cpy, size = %zu, need = %zu, max available = %zu\n",
+                    //        __func__, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:32]);
+                    return false;
                 }
 
                 if (src0->type == GGML_TYPE_F16) {
@@ -4669,6 +4673,8 @@ static void ggml_metal_encode_node(
                 GGML_ABORT("fatal error");
             }
     }
+
+    return true;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
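
Taken together with the earlier hunks, this completes the signature change: ggml_metal_encode_node now returns bool instead of void, with empty and no-op nodes returning true, a failed ggml_metal_heap_alloc returning false, and every other path reaching the new "return true" above. What follows is a minimal, self-contained sketch of that contract; it is illustrative only, and the fake_heap struct and the heap_alloc/encode_node helpers are invented for the example rather than taken from ggml-metal.

// Sketch of the new encode contract (not the real ggml API): encode_node()
// returns true on success and for no-op nodes, and false only when a
// transient scratch allocation fails, so the caller can stop and retry.
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>

struct fake_heap { size_t capacity, used, need; };

static bool heap_alloc(struct fake_heap * heap, size_t size) {
    heap->need += size;                 // track how much the graph really needs
    if (heap->used + size > heap->capacity) {
        return false;                   // transient failure - caller must retry
    }
    heap->used += size;
    return true;
}

static bool encode_node(struct fake_heap * heap, size_t scratch_size, bool is_noop) {
    if (is_noop) {
        return true;                    // noop -> next node (as in the diff)
    }
    if (!heap_alloc(heap, scratch_size)) {
        return false;                   // mirrors the new "return false;" path
    }
    // ... encode the kernel dispatch here ...
    return true;
}

int main(void) {
    struct fake_heap heap = { .capacity = 64, .used = 0, .need = 0 };
    const size_t sizes[] = { 32, 16, 64 }; // the last one will not fit

    for (int i = 0; i < 3; ++i) {
        if (!encode_node(&heap, sizes[i], /*is_noop=*/false)) {
            printf("node %d failed, need at least %zu bytes\n", i, heap.need);
            break;                      // stop encoding, let the caller resize and retry
        }
    }
    return 0;
}
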
@@ -4683,13 +4689,16 @@ static enum ggml_status ggml_metal_graph_compute(
     // number of threads in addition to the main thread
     const int n_cb = ctx->n_cb;
 
+    int n_try = 64;
+
     // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
     // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
     // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
     // each thread creates it's own command buffer and enqueues the ops in parallel
     //
     // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2
 
+    while (n_try-- > 0) {
     @autoreleasepool {
         ctx->gf = gf;
 
@@ -4752,8 +4761,6 @@ static enum ggml_status ggml_metal_graph_compute(
             id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
             [cmd_buf waitUntilCompleted];
 
-            ggml_metal_heap_reset(ctx->cmd_bufs[n_cb].heap);
-
             MTLCommandBufferStatus status = [cmd_buf status];
             if (status != MTLCommandBufferStatusCompleted) {
                 GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
@@ -4769,8 +4776,6 @@ static enum ggml_status ggml_metal_graph_compute(
             id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
             [cmd_buf waitUntilCompleted];
 
-            ggml_metal_heap_reset(ctx->cmd_bufs[i].heap);
-
             MTLCommandBufferStatus status = [cmd_buf status];
             if (status != MTLCommandBufferStatusCompleted) {
                 GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
@@ -4805,6 +4810,54 @@ static enum ggml_status ggml_metal_graph_compute(
         }
     }
 
+    bool retry = false;
+
+    // check heap statuses
+    for (int i = 0; i <= n_cb; ++i) {
+        struct ggml_metal_heap * heap = ctx->cmd_bufs[i].heap;
+
+        const size_t need = 4*heap->need;
+
+        //printf("\nXXXXXXXXXXXXXXXXX cb %d, need = %zu, fail = %d, size = %zu\n", i, need, heap->fail, [heap->obj currentAllocatedSize]);
+
+        if (heap->fail == 0) {
+            ggml_metal_heap_reset(ctx->cmd_bufs[i].heap);
+
+            continue;
+        }
+
+        if (heap->fail == 2) {
+            GGML_LOG_ERROR("%s: command buffer %d, MTLHeap ran out of buffers, max = %d\n", __func__, i, heap->n);
+            return GGML_STATUS_ALLOC_FAILED;
+        }
+
+        if (heap->fail == 3) {
+            GGML_LOG_ERROR("%s: command buffer %d, MTLHeap failed to allocate buffer, max = %d\n", __func__, i, heap->n);
+            return GGML_STATUS_ALLOC_FAILED;
+        }
+
+        //GGML_LOG_INFO("%s: command buffer %d, MTLHeap need = %zu\n", __func__, i, need);
+
+        if (!ggml_metal_heap_resize(heap, need)) {
+            GGML_LOG_ERROR("%s: failed to increase heap size to %zu\n", __func__, need);
+            return GGML_STATUS_ALLOC_FAILED;
+        }
+
+        retry = true;
+    }
+
+    if (!retry) {
+        break;
+    }
+
+    //printf("XXXXXXXXXXXXXXXXXXXXXXX retry\n");
+
+    if (n_try == 0) {
+        GGML_LOG_ERROR("%s: failed to allocate heap memory\n", __func__);
+        return GGML_STATUS_ALLOC_FAILED;
+    }
+    }
+
     return GGML_STATUS_SUCCESS;
 }
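
Together with the n_try = 64 counter and the while (n_try-- > 0) wrapper added around @autoreleasepool earlier in this diff, this hunk is the heart of the new retry strategy: after the command buffers complete, each heap is inspected; a heap that did not fail is reset for reuse, a hard failure (fail == 2 or 3) aborts with GGML_STATUS_ALLOC_FAILED, and an out-of-space heap is resized to four times the recorded need before the whole graph is encoded and run again. The sketch below is a self-contained illustration of that control flow under simplified assumptions; fake_heap, run_graph, heap_reset, and heap_resize are stand-ins invented for the example, not the real ggml-metal helpers.

// Grow-and-retry sketch (illustrative only): run the graph, then either reset
// the heap on success or grow it to 4x the recorded need and try again, with
// an upper bound of 64 attempts.
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>

struct fake_heap { size_t capacity, used, need; int fail; };

static void heap_reset (struct fake_heap * h)             { h->used = 0; h->need = 0; h->fail = 0; }
static bool heap_resize(struct fake_heap * h, size_t cap)  { h->capacity = cap; heap_reset(h); return true; }

// pretend "compute": the graph needs 100 bytes of scratch in total
static void run_graph(struct fake_heap * h) {
    h->need = 100;
    h->fail = h->capacity < h->need ? 1 : 0;
}

int main(void) {
    struct fake_heap heap = { .capacity = 16 };

    int n_try = 64;
    while (n_try-- > 0) {
        run_graph(&heap);

        bool retry = false;
        if (heap.fail == 0) {
            heap_reset(&heap);          // success: keep the heap for the next graph
        } else {
            const size_t need = 4*heap.need;
            if (!heap_resize(&heap, need)) {   // mirrors the diff's error handling
                fprintf(stderr, "failed to increase heap size to %zu\n", need);
                return EXIT_FAILURE;
            }
            retry = true;
        }

        if (!retry) {
            break;
        }
        if (n_try == 0) {
            fprintf(stderr, "failed to allocate heap memory\n");
            return EXIT_FAILURE;
        }
    }

    printf("graph completed with a %zu-byte heap\n", heap.capacity);
    return 0;
}

The real code does this per command-buffer heap (the for (int i = 0; i <= n_cb; ++i) loop above), but the grow-and-retry shape is the same.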

@@ -5167,64 +5220,36 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
         struct ggml_metal_heap * heap = ctx->cmd_bufs[cb_idx].heap;
 
-        int n_try = 2;
-
-        while (n_try-- > 0) {
-            id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
-
-            int node_start = 0;
-            int node_end   = n_nodes_0;
-
-            if (cb_idx < n_cb_l) {
-                node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
-                node_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
-            }
-
-            const bool should_capture = ctx->capture_next_compute;
-
-            for (int idx = node_start; idx < node_end; ++idx) {
-                if (should_capture) {
-                    [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
-                }
-
-                ggml_metal_encode_node(backend, idx, encoder, heap);
+        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
 
-                if (should_capture) {
-                    [encoder popDebugGroup];
-                }
-            }
+        int node_start = 0;
+        int node_end   = n_nodes_0;
 
-            [encoder endEncoding];
+        if (cb_idx < n_cb_l) {
+            node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
+            node_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
+        }
 
-            if (heap->fail == 0) {
-                break;
-            }
+        const bool should_capture = ctx->capture_next_compute;
 
-            if (heap->fail == 2) {
-                GGML_LOG_ERROR("%s: MTLHeap ran out of buffers, max = %d\n", __func__, heap->n);
-                break;
+        for (int idx = node_start; idx < node_end; ++idx) {
+            if (should_capture) {
+                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            if (heap->fail == 3) {
-                GGML_LOG_ERROR("%s: MTLHeap failed to allocate buffer\n", __func__);
-                break;
-            }
+            const bool res = ggml_metal_encode_node(backend, idx, encoder, heap);
 
-            if (n_try == 0) {
-                GGML_LOG_ERROR("%s: failed to allocate heap memory\n", __func__);
-                break;
+            if (should_capture) {
+                [encoder popDebugGroup];
             }
 
-            const size_t need = heap->need;
-
-            GGML_LOG_INFO("%s: increasing heap size to %zu\n", __func__, need);
-
-            if (!ggml_metal_heap_resize(heap, need)) {
-                GGML_LOG_ERROR("%s: failed to increase heap size to %zu\n", __func__, need);
+            if (!res) {
                 break;
            }
        }
 
+        [encoder endEncoding];
+
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
             [cmd_buf commit];
         }
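
With the retry now handled centrally in ggml_metal_graph_compute, the per-command-buffer encode block above no longer loops on its own: it encodes nodes until ggml_metal_encode_node returns false, ends encoding once, and leaves heap resizing and the re-run of the graph to the caller. That appears to be why the two ggml_metal_heap_reset calls were removed from the status-check loops earlier in this diff: heaps are now reset only when their fail flag is 0, after the command buffers have completed.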
