@@ -563,7 +563,9 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
     return true;
 }
 
-static id<MTLBuffer> ggml_metal_heap_alloc(struct ggml_metal_heap * heap, size_t size, size_t alignment) {
+static id<MTLBuffer> ggml_metal_heap_alloc(struct ggml_metal_heap * heap, size_t size) {
+    const size_t alignment = 1024*1024;
+
     const size_t size_aligned = GGML_PAD(size, alignment);
 
     heap->need += size_aligned;
@@ -1583,7 +1585,8 @@ static bool ggml_metal_encode_node(
                         ggml_backend_t   backend,
                                    int   idx,
          id<MTLComputeCommandEncoder>   encoder,
-               struct ggml_metal_heap * heap) {
+               struct ggml_metal_heap * heap,
+                                  bool   no_compute) {
     struct ggml_backend_metal_context        * ctx     = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
@@ -1621,6 +1624,28 @@ static bool ggml_metal_encode_node(
         GGML_ABORT("unsupported op");
     }
 
+    id<MTLBuffer> h_src0 = nil;
+    switch (dst->op) {
+        case GGML_OP_SOFT_MAX:
+            {
+                h_src0 = ggml_metal_heap_alloc(heap, ggml_nbytes(src0));
+                if (!h_src0) {
+                    //GGML_LOG_ERROR("%s: failed to allocate buffer, idx = %4d, size = %8zu, need = %8zu, max available = %9zu, heap size = %9zu, heap used = %zu\n",
+                    //    __func__, idx, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:0], [heap->obj size], [heap->obj usedSize]);
+                    return false;
+                } else {
+                    //GGML_LOG_ERROR("%s: allocated %zu\n", __func__, ggml_nbytes(src0));
+                }
+            } break;
+        default:
+            {
+            } break;
+    }
+
+    if (no_compute) {
+        return true;
+    }
+
     const int64_t ne00 = src0 ? src0->ne[0] : 0;
     const int64_t ne01 = src0 ? src0->ne[1] : 0;
     const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -2278,23 +2303,14 @@ static bool ggml_metal_encode_node(
                     /*.nb3 =*/ nb03,
                 };
 
-                id<MTLBuffer> id_src0h = ggml_metal_heap_alloc(heap, ggml_nbytes(src0), 64*1024);
-                if (!id_src0h) {
-                    //GGML_LOG_ERROR("%s: failed to allocate buffer, idx = %4d, size = %8zu, need = %8zu, max available = %9zu, heap size = %9zu, heap used = %zu\n",
-                    //    __func__, idx, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:0], [heap->obj size], [heap->obj usedSize]);
-                    return true;
-                } else {
-                    //GGML_LOG_ERROR("%s: allocated %zu\n", __func__, ggml_nbytes(src0));
-                }
-
                 if (src0->type == GGML_TYPE_F16) {
                     [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
                 } else {
                     [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
                 }
                 [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src0h offset:0 atIndex:2];
+                [encoder setBuffer:h_src0 offset:0 atIndex:2];
 
                 GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
                 int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
@@ -2315,11 +2331,11 @@ static bool ggml_metal_encode_node(
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0h offset:0 atIndex:0];
+                [encoder setBuffer:h_src0 offset:0 atIndex:0];
                 if (id_src1) {
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                 } else {
-                    [encoder setBuffer:id_src0h offset:0 atIndex:1];
+                    [encoder setBuffer:h_src0 offset:0 atIndex:1];
                 }
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                 [encoder setBytes:&args length:sizeof(args) atIndex:3];
@@ -4732,6 +4748,12 @@ static enum ggml_status ggml_metal_graph_compute(
         }
     }
 
+    for (int i = 0; i <= n_cb; ++i) {
+        struct ggml_metal_heap * heap = ctx->cmd_bufs[i].heap;
+
+        [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+    }
+
     // the main thread commits the first few commands immediately
     // cmd_buf[n_cb]
     {
@@ -4824,6 +4846,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
             if (heap->fail == 0) {
                 ggml_metal_heap_reset(ctx->cmd_bufs[i].heap);
+                [heap->obj setPurgeableState:MTLPurgeableStateEmpty];
 
                 continue;
             }
@@ -5234,19 +5257,21 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const bool should_capture = ctx->capture_next_compute;
 
+        bool no_compute = false;
+
         for (int idx = node_start; idx < node_end; ++idx) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            const bool res = ggml_metal_encode_node(backend, idx, encoder, heap);
+            const bool res = ggml_metal_encode_node(backend, idx, encoder, heap, no_compute);
 
             if (should_capture) {
                 [encoder popDebugGroup];
             }
 
             if (!res) {
-                break;
+                no_compute = true;
             }
         }
 