Skip to content

Commit ac1fe16

Browse files
authored
[Allocator] Fix OOM issues in non-DPS mode (#498)
Summary:
- Restore querying of `rankOffset` from the output descriptor, which was inadvertently removed in a previous merge. Update the related test `tensorrt-runtime-to-executor.mlir`.
- Enhance debugging logs to provide clearer error messages.
- Revert `Allocator::track` to its previous implementation for cases where the incoming pointer is not managed internally by `Allocator`.
- Modify `Allocator::track` to correctly handle cases where an incoming pointer is already managed internally. This adjustment is necessary because `Allocator` now tracks pointers internally when they are returned as function results.
- Refine `Allocator::safeDeallocate` to ensure it only releases pointers that are managed internally.
- Correct a typo in `Allocator::safeDeallocate` for `PointerType::pinned_host`, which previously caused an error. That case now logs a message and defers memory deallocation to `PinnedMemoryAllocator`.
- Address an issue in external memref creation by avoiding retracking of pointers already managed by `Allocator`. This prevents such pointers from being redundantly tracked as externally managed.
- Ensure that when populating arguments for an Enqueue function, the session allocator tracks pointers as internally managed if they are managed by the client allocator. Ensure the session tracker does not assume ownership for deallocation.
- Prevent `~Allocator()` from releasing pointers that have already been released internally.

Signed-off-by: Jhalak Patel <[email protected]>
1 parent ac4ad9a commit ac1fe16

File tree

8 files changed

+201
-91
lines changed

8 files changed

+201
-91
lines changed

mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,10 @@ struct ConvertEnqueueAllocToCall
375375
for (auto [idx, result] : llvm::enumerate(op.getResults())) {
376376
MemRefType memrefType = cast<MemRefType>(result.getType());
377377
unsigned rank = memrefType.getRank();
378+
379+
// Skip the rank offset that is populated by the callee.
380+
outputDescOffset++;
381+
378382
Value devicePtrOffset = b.create<executor::GetOffsetOp>(
379383
i64Type, structType,
380384
ArrayRef<OpFoldResult>{

mlir-tensorrt/compiler/test/Conversion/TensorRTRuntimeToExecutor/tensorrt-runtime-to-executor.mlir

+17-15
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,26 @@ func.func @convert_enqueue_alloc(%arg0: memref<?xf32, #device>,
7979
// CHECK: %[[v12:.+]] = executor.table.get %[[v0]][3] : <!executor.ptr<device>, !executor.ptr<device>, i64, i64, i64>
8080
// CHECK: %[[v13:.+]] = executor.table.create(%[[v8]], %[[c0_i64]], %[[c2_i64]], %[[v9]], %[[v10]], %[[v11]], %[[c0_i64]], %[[c1_i64]], %[[v12]] : !executor.ptr<device>, i64, i64, i64, i64, !executor.ptr<device>, i64, i64, i64) : <!executor.ptr<device>, i64, i64, i64, i64, !executor.ptr<device>, i64, i64, i64>
8181
// CHECK: executor.call @_trtrt_enqueue_alloc(%[[v3]], %[[v2]], %[[v4]], %[[v13]]) : (!executor.ptr<host>, !executor.ptr<host>, !executor.ptr<host>, !executor.table<!executor.ptr<device>, i64, i64, i64, i64, !executor.ptr<device>, i64, i64, i64>) -> ()
82-
// CHECK: %[[v14:.+]] = executor.load %[[v4]] + %[[v6]] : (!executor.ptr<host>, i64) -> i64
83-
// CHECK: %[[v15:.+]] = executor.inttoptr %[[v14]] : (i64) -> !executor.ptr<device>
84-
// CHECK: %[[v16:.+]] = executor.getoffset[0, 2] : () -> i64, !executor.table<i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64>
85-
// CHECK: %[[v17:.+]] = executor.load %[[v4]] + %[[v16]] : (!executor.ptr<host>, i64) -> i64
82+
// CHECK: %[[v14:.+]] = executor.getoffset[0, 2] : () -> i64, !executor.table<i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64>
83+
// CHECK: %[[v15:.+]] = executor.load %[[v4]] + %[[v14]] : (!executor.ptr<host>, i64) -> i64
84+
// CHECK: %[[v16:.+]] = executor.inttoptr %[[v15]] : (i64) -> !executor.ptr<device>
8685
// CHECK: %[[v18:.+]] = executor.getoffset[0, 3] : () -> i64, !executor.table<i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64>
8786
// CHECK: %[[v19:.+]] = executor.load %[[v4]] + %[[v18]] : (!executor.ptr<host>, i64) -> i64
88-
// CHECK: %[[v20:.+]] = executor.table.create(%[[v15]], %[[v15]], %[[c0_i64]], %[[v17]], %[[v19]] : !executor.ptr<device>, !executor.ptr<device>, i64, i64, i64) : <!executor.ptr<device>, !executor.ptr<device>, i64, i64, i64>
89-
// CHECK: %[[v21:.+]] = builtin.unrealized_conversion_cast %[[v20]] : !executor.table<!executor.ptr<device>, !executor.ptr<device>, i64, i64, i64> to memref<?xf32, #executor.memory_type<device>>
90-
// CHECK: %[[v22:.+]] = executor.getoffset[0, 4] : () -> i64, !executor.table<i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64>
91-
// CHECK: %[[v23:.+]] = executor.load %[[v4]] + %[[v22]] : (!executor.ptr<host>, i64) -> i64
92-
// CHECK: %[[v24:.+]] = executor.inttoptr %[[v23]] : (i64) -> !executor.ptr<host>
93-
// CHECK: %[[v25:.+]] = executor.load %[[v4]] + %[[v7]] : (!executor.ptr<host>, i64) -> i64
94-
// CHECK: %[[v26:.+]] = executor.getoffset[0, 6] : () -> i64, !executor.table<i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64>
95-
// CHECK: %[[v27:.+]] = executor.load %[[v4]] + %[[v26]] : (!executor.ptr<host>, i64) -> i64
87+
// CHECK: %[[v20:.+]] = executor.getoffset[0, 4] : () -> i64, !executor.table<i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64>
88+
// CHECK: %[[v21:.+]] = executor.load %[[v4]] + %[[v20]] : (!executor.ptr<host>, i64) -> i64
89+
// CHECK: %[[v22:.+]] = executor.table.create(%[[v16]], %[[v16]], %[[c0_i64]], %[[v19]], %[[v21]] : !executor.ptr<device>, !executor.ptr<device>, i64, i64, i64) : <!executor.ptr<device>, !executor.ptr<device>, i64, i64, i64>
90+
// CHECK: %[[v23:.+]] = builtin.unrealized_conversion_cast %[[v22]] : !executor.table<!executor.ptr<device>, !executor.ptr<device>, i64, i64, i64> to memref<?xf32, #executor.memory_type<device>>
91+
// CHECK: %[[v24:.+]] = executor.getoffset[0, 6] : () -> i64, !executor.table<i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64>
92+
// CHECK: %[[v25:.+]] = executor.load %[[v4]] + %[[v24]] : (!executor.ptr<host>, i64) -> i64
93+
// CHECK: %[[v26:.+]] = executor.inttoptr %[[v25]] : (i64) -> !executor.ptr<host>
9694
// CHECK: %[[v28:.+]] = executor.getoffset[0, 7] : () -> i64, !executor.table<i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64>
9795
// CHECK: %[[v29:.+]] = executor.load %[[v4]] + %[[v28]] : (!executor.ptr<host>, i64) -> i64
9896
// CHECK: %[[v30:.+]] = executor.getoffset[0, 8] : () -> i64, !executor.table<i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64>
9997
// CHECK: %[[v31:.+]] = executor.load %[[v4]] + %[[v30]] : (!executor.ptr<host>, i64) -> i64
100-
// CHECK: %[[v32:.+]] = executor.table.create(%[[v24]], %[[v24]], %[[c0_i64]], %[[v25]], %[[v27]], %[[v29]], %[[v31]] : !executor.ptr<host>, !executor.ptr<host>, i64, i64, i64, i64, i64) : <!executor.ptr<host>, !executor.ptr<host>, i64, i64, i64, i64, i64>
101-
// CHECK: %[[v33:.+]] = builtin.unrealized_conversion_cast %[[v32]] : !executor.table<!executor.ptr<host>, !executor.ptr<host>, i64, i64, i64, i64, i64> to memref<?x?xf32, #executor.memory_type<host>>
102-
// CHECK: return %[[v21]], %[[v33]] : memref<?xf32, #executor.memory_type<device>>, memref<?x?xf32, #executor.memory_type<host>>
98+
// CHECK: %[[v32:.+]] = executor.getoffset[0, 9] : () -> i64, !executor.table<i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64>
99+
// CHECK: %[[v33:.+]] = executor.load %[[v4]] + %[[v32]] : (!executor.ptr<host>, i64) -> i64
100+
// CHECK: %[[v34:.+]] = executor.getoffset[0, 10] : () -> i64, !executor.table<i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64>
101+
// CHECK: %[[v35:.+]] = executor.load %[[v4]] + %[[v34]] : (!executor.ptr<host>, i64) -> i64
102+
// CHECK: %[[v36:.+]] = executor.table.create(%[[v26]], %[[v26]], %[[c0_i64]], %[[v29]], %[[v31]], %[[v33]], %[[v35]] : !executor.ptr<host>, !executor.ptr<host>, i64, i64, i64, i64, i64) : <!executor.ptr<host>, !executor.ptr<host>, i64, i64, i64, i64, i64>
103+
// CHECK: %[[v37:.+]] = builtin.unrealized_conversion_cast %[[v36]] : !executor.table<!executor.ptr<host>, !executor.ptr<host>, i64, i64, i64, i64, i64> to memref<?x?xf32, #executor.memory_type<host>>
104+
// CHECK: return %[[v23]], %[[v37]] : memref<?xf32, #executor.memory_type<device>>, memref<?x?xf32, #executor.memory_type<host>>

mlir-tensorrt/compiler/test/python/mlir_tensorrt_runtime/test_create_memref.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -397,24 +397,25 @@ def test_released_internally():
397397
arr = np.array([5.0, 4.0, 2.0])
398398

399399
def memref_alloc():
400+
# The data is externally owned, so the memref will not be released internally.
400401
memref = client.create_host_memref_view(
401402
int(arr.ctypes.data), shape=[3], dtype=runtime.ScalarTypeCode.f64
402403
)
403404
return np.from_dlpack(
404405
memref
405-
) # Ensure we have an externally reference to the pointer.
406+
) # Ensure we have an external reference to the pointer.
406407

407408
_ = memref_alloc()
408409
print(
409410
"Memref released internally: ", client.is_released_internally(arr.ctypes.data)
410411
)
411412

412413

413-
print("Test memref is released internally with an external reference")
414+
print("Test memref is not released internally with an external reference")
414415
test_released_internally()
415416

416-
# CHECK-LABEL: Test memref is released internally with an external reference
417-
# CHECK-NEXT: Memref released internally: True
417+
# CHECK-LABEL: Test memref is not released internally with an external reference
418+
# CHECK-NEXT: Memref released internally: False
418419

419420

420421
def test_memref_lifetime():

mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ class MemRefValue : public RuntimeValue {
689689

690690
const std::optional<ScalarType> &getScalarType() const { return scalarType; }
691691

692-
RuntimeClient *getClient() { return client; }
692+
RuntimeClient *getClient() const { return client; }
693693

694694
private:
695695
MemRefValue(RuntimeClient *client, mlirtrt::runtime::PointerType addressSpace,

mlir-tensorrt/executor/include/mlir-executor/Support/Status.h

+11
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,17 @@ class StatusOr {
239239
} \
240240
} while (false);
241241

242+
// MTRT_ERROR_IF(errexpr, msg): debug-only fatal check.
// When NDEBUG is not defined and `errexpr` evaluates to true, this calls
// llvm::report_fatal_error(msg). When NDEBUG is defined the macro expands to
// nothing, so neither `errexpr` nor `msg` is evaluated; do not rely on side
// effects in either argument.
// NOTE(review): the debug expansion already ends in `;` after `while (false)`,
// so writing `MTRT_ERROR_IF(...);` produces an extra empty statement. Confirm
// this is intended before using the macro as the sole body of an if/else
// branch.
#ifndef NDEBUG
243+
#define MTRT_ERROR_IF(errexpr, msg) \
244+
do { \
245+
if (errexpr) { \
246+
llvm::report_fatal_error(msg); \
247+
} \
248+
} while (false);
249+
#else // In Release mode, compiles to a no-op.
250+
#define MTRT_ERROR_IF(errexpr, msg)
251+
#endif
252+
242253
} // namespace mlirtrt
243254

244255
#endif // MLIR_TENSORRT_SUPPORT_STATUS_H

mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,8 @@ MTRT_Status mtrtGetScalarTypeCodeFromDLDataType(DLDataType dtype,
535535

536536
static void dlpackManagedTensorDeleter(DLManagedTensor *tensor) {
537537
if (tensor) {
538+
MTRT_DBGF("Deleting DLManagedTensor. Data pointer: %p",
539+
tensor->dl_tensor.data);
538540
delete[] tensor->dl_tensor.shape;
539541
delete[] tensor->dl_tensor.strides;
540542
if (tensor->manager_ctx) {

0 commit comments

Comments
 (0)