@@ -1579,7 +1579,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1579
1579
}
1580
1580
}
1581
1581
1582
- static void ggml_metal_encode_node (
1582
+ static bool ggml_metal_encode_node (
1583
1583
ggml_backend_t backend,
1584
1584
int idx,
1585
1585
id <MTLComputeCommandEncoder > encoder,
@@ -1599,7 +1599,7 @@ static void ggml_metal_encode_node(
1599
1599
struct ggml_tensor * dst = node;
1600
1600
1601
1601
if (ggml_is_empty (dst)) {
1602
- return ;
1602
+ return true ;
1603
1603
}
1604
1604
1605
1605
switch (dst->op ) {
@@ -1610,7 +1610,7 @@ static void ggml_metal_encode_node(
1610
1610
case GGML_OP_PERMUTE:
1611
1611
{
1612
1612
// noop -> next node
1613
- } return ;
1613
+ } return true ;
1614
1614
default :
1615
1615
{
1616
1616
} break ;
@@ -2214,6 +2214,8 @@ static void ggml_metal_encode_node(
2214
2214
{
2215
2215
GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
2216
2216
2217
+ GGML_ASSERT (ggml_is_contiguous (src0));
2218
+
2217
2219
int nth = 32 ; // SIMD width
2218
2220
2219
2221
id <MTLComputePipelineState > pipeline = nil ;
@@ -2278,7 +2280,9 @@ static void ggml_metal_encode_node(
2278
2280
2279
2281
id <MTLBuffer > id_src0h = ggml_metal_heap_alloc (heap, ggml_nbytes (src0), 32 );
2280
2282
if (!id_src0h) {
2281
- break ;
2283
+ // GGML_LOG_ERROR("%s: failed to allocate buffer for cpy, size = %zu, need = %zu, max available = %zu\n",
2284
+ // __func__, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:32]);
2285
+ return false ;
2282
2286
}
2283
2287
2284
2288
if (src0->type == GGML_TYPE_F16) {
@@ -4669,6 +4673,8 @@ static void ggml_metal_encode_node(
4669
4673
GGML_ABORT (" fatal error" );
4670
4674
}
4671
4675
}
4676
+
4677
+ return true ;
4672
4678
}
4673
4679
4674
4680
static enum ggml_status ggml_metal_graph_compute (
@@ -4683,13 +4689,16 @@ static enum ggml_status ggml_metal_graph_compute(
4683
4689
// number of threads in addition to the main thread
4684
4690
const int n_cb = ctx->n_cb ;
4685
4691
4692
+ int n_try = 64 ;
4693
+
4686
4694
// submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
4687
4695
// the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
4688
4696
// while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
4689
4697
// each thread creates it's own command buffer and enqueues the ops in parallel
4690
4698
//
4691
4699
// tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2
4692
4700
4701
+ while (n_try-- > 0 ) {
4693
4702
@autoreleasepool {
4694
4703
ctx->gf = gf;
4695
4704
@@ -4752,8 +4761,6 @@ static enum ggml_status ggml_metal_graph_compute(
4752
4761
id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [n_cb].obj ;
4753
4762
[cmd_buf waitUntilCompleted ];
4754
4763
4755
- ggml_metal_heap_reset (ctx->cmd_bufs [n_cb].heap );
4756
-
4757
4764
MTLCommandBufferStatus status = [cmd_buf status ];
4758
4765
if (status != MTLCommandBufferStatusCompleted ) {
4759
4766
GGML_LOG_INFO (" %s : command buffer %d failed with status %lu \n " , __func__, n_cb, status);
@@ -4769,8 +4776,6 @@ static enum ggml_status ggml_metal_graph_compute(
4769
4776
id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [i].obj ;
4770
4777
[cmd_buf waitUntilCompleted ];
4771
4778
4772
- ggml_metal_heap_reset (ctx->cmd_bufs [i].heap );
4773
-
4774
4779
MTLCommandBufferStatus status = [cmd_buf status ];
4775
4780
if (status != MTLCommandBufferStatusCompleted ) {
4776
4781
GGML_LOG_INFO (" %s : command buffer %d failed with status %lu \n " , __func__, i, status);
@@ -4805,6 +4810,54 @@ static enum ggml_status ggml_metal_graph_compute(
4805
4810
}
4806
4811
}
4807
4812
4813
+ bool retry = false ;
4814
+
4815
+ // check heap statuses
4816
+ for (int i = 0 ; i <= n_cb; ++i) {
4817
+ struct ggml_metal_heap * heap = ctx->cmd_bufs [i].heap ;
4818
+
4819
+ const size_t need = 4 *heap->need ;
4820
+
4821
+ // printf("\nXXXXXXXXXXXXXXXXX cb %d, need = %zu, fail = %d, size = %zu\n", i, need, heap->fail, [heap->obj currentAllocatedSize]);
4822
+
4823
+ if (heap->fail == 0 ) {
4824
+ ggml_metal_heap_reset (ctx->cmd_bufs [i].heap );
4825
+
4826
+ continue ;
4827
+ }
4828
+
4829
+ if (heap->fail == 2 ) {
4830
+ GGML_LOG_ERROR (" %s : command buffer %d , MTLHeap ran out of buffers, max = %d \n " , __func__, i, heap->n );
4831
+ return GGML_STATUS_ALLOC_FAILED;
4832
+ }
4833
+
4834
+ if (heap->fail == 3 ) {
4835
+ GGML_LOG_ERROR (" %s : command buffer %d , MTLHeap failed to allocate buffer, max = %d \n " , __func__, i, heap->n );
4836
+ return GGML_STATUS_ALLOC_FAILED;
4837
+ }
4838
+
4839
+ // GGML_LOG_INFO("%s: command buffer %d, MTLHeap need = %zu\n", __func__, i, need);
4840
+
4841
+ if (!ggml_metal_heap_resize (heap, need)) {
4842
+ GGML_LOG_ERROR (" %s : failed to increase heap size to %zu \n " , __func__, need);
4843
+ return GGML_STATUS_ALLOC_FAILED;
4844
+ }
4845
+
4846
+ retry = true ;
4847
+ }
4848
+
4849
+ if (!retry) {
4850
+ break ;
4851
+ }
4852
+
4853
+ // printf("XXXXXXXXXXXXXXXXXXXXXXX retry\n");
4854
+
4855
+ if (n_try == 0 ) {
4856
+ GGML_LOG_ERROR (" %s : failed to allocate heap memory\n " , __func__);
4857
+ return GGML_STATUS_ALLOC_FAILED;
4858
+ }
4859
+ }
4860
+
4808
4861
return GGML_STATUS_SUCCESS;
4809
4862
}
4810
4863
@@ -5167,64 +5220,36 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
5167
5220
id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [cb_idx].obj ;
5168
5221
struct ggml_metal_heap * heap = ctx->cmd_bufs [cb_idx].heap ;
5169
5222
5170
- int n_try = 2 ;
5171
-
5172
- while (n_try-- > 0 ) {
5173
- id <MTLComputeCommandEncoder > encoder = [cmd_buf computeCommandEncoder ];
5174
-
5175
- int node_start = 0 ;
5176
- int node_end = n_nodes_0;
5177
-
5178
- if (cb_idx < n_cb_l) {
5179
- node_start = n_nodes_0 + ( (cb_idx + 0 ) * n_nodes_per_cb);
5180
- node_end = n_nodes_0 + (MIN ((cb_idx == n_cb_l - 1 ) ? n_nodes_1 : (cb_idx + 1 ) * n_nodes_per_cb, n_nodes_1));
5181
- }
5182
-
5183
- const bool should_capture = ctx->capture_next_compute ;
5184
-
5185
- for (int idx = node_start; idx < node_end; ++idx) {
5186
- if (should_capture) {
5187
- [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
5188
- }
5189
-
5190
- ggml_metal_encode_node (backend, idx, encoder, heap);
5223
+ id <MTLComputeCommandEncoder > encoder = [cmd_buf computeCommandEncoder ];
5191
5224
5192
- if (should_capture) {
5193
- [encoder popDebugGroup ];
5194
- }
5195
- }
5225
+ int node_start = 0 ;
5226
+ int node_end = n_nodes_0;
5196
5227
5197
- [encoder endEncoding ];
5228
+ if (cb_idx < n_cb_l) {
5229
+ node_start = n_nodes_0 + ( (cb_idx + 0 ) * n_nodes_per_cb);
5230
+ node_end = n_nodes_0 + (MIN ((cb_idx == n_cb_l - 1 ) ? n_nodes_1 : (cb_idx + 1 ) * n_nodes_per_cb, n_nodes_1));
5231
+ }
5198
5232
5199
- if (heap->fail == 0 ) {
5200
- break ;
5201
- }
5233
+ const bool should_capture = ctx->capture_next_compute ;
5202
5234
5203
- if (heap-> fail == 2 ) {
5204
- GGML_LOG_ERROR ( " %s : MTLHeap ran out of buffers, max = %d \n " , __func__, heap-> n );
5205
- break ;
5235
+ for ( int idx = node_start; idx < node_end; ++idx ) {
5236
+ if (should_capture) {
5237
+ [encoder pushDebugGroup: [ NSString stringWithCString: ggml_op_desc ( ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]] ;
5206
5238
}
5207
5239
5208
- if (heap->fail == 3 ) {
5209
- GGML_LOG_ERROR (" %s : MTLHeap failed to allocate buffer\n " , __func__);
5210
- break ;
5211
- }
5240
+ const bool res = ggml_metal_encode_node (backend, idx, encoder, heap);
5212
5241
5213
- if (n_try == 0 ) {
5214
- GGML_LOG_ERROR (" %s : failed to allocate heap memory\n " , __func__);
5215
- break ;
5242
+ if (should_capture) {
5243
+ [encoder popDebugGroup ];
5216
5244
}
5217
5245
5218
- const size_t need = heap->need ;
5219
-
5220
- GGML_LOG_INFO (" %s : increasing heap size to %zu \n " , __func__, need);
5221
-
5222
- if (!ggml_metal_heap_resize (heap, need)) {
5223
- GGML_LOG_ERROR (" %s : failed to increase heap size to %zu \n " , __func__, need);
5246
+ if (!res) {
5224
5247
break ;
5225
5248
}
5226
5249
}
5227
5250
5251
+ [encoder endEncoding ];
5252
+
5228
5253
if (cb_idx < 2 || ctx->abort_callback == NULL ) {
5229
5254
[cmd_buf commit ];
5230
5255
}
0 commit comments