Skip to content

Commit d46dbea

Browse files
authored
Expose knobs to create and share (CPU) allocators across sessions in C# and Python (microsoft#5634)
1 parent 26e6ced commit d46dbea

20 files changed

+425
-92
lines changed

csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs

+40-8
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,12 @@ public struct OrtApi
182182
public IntPtr SetGlobalInterOpNumThreads;
183183
public IntPtr SetGlobalSpinControl;
184184
public IntPtr AddInitializer;
185+
public IntPtr CreateEnvWithCustomLoggerAndGlobalThreadPools;
186+
public IntPtr SessionOptionsAppendExecutionProvider_CUDA;
187+
public IntPtr SessionOptionsAppendExecutionProvider_OpenVINO;
188+
public IntPtr SetGlobalDenormalAsZero;
189+
public IntPtr CreateArenaCfg;
190+
public IntPtr ReleaseArenaCfg;
185191
}
186192

187193
internal static class NativeMethods
@@ -260,6 +266,9 @@ static NativeMethods()
260266
OrtRunOptionsSetTerminate = (DOrtRunOptionsSetTerminate)Marshal.GetDelegateForFunctionPointer(api_.RunOptionsSetTerminate, typeof(DOrtRunOptionsSetTerminate));
261267
OrtRunOptionsUnsetTerminate = (DOrtRunOptionsUnsetTerminate)Marshal.GetDelegateForFunctionPointer(api_.RunOptionsUnsetTerminate, typeof(DOrtRunOptionsUnsetTerminate));
262268

269+
OrtCreateArenaCfg = (DOrtCreateArenaCfg)Marshal.GetDelegateForFunctionPointer(api_.CreateArenaCfg, typeof(DOrtCreateArenaCfg));
270+
OrtReleaseArenaCfg = (DOrtReleaseArenaCfg)Marshal.GetDelegateForFunctionPointer(api_.ReleaseArenaCfg, typeof(DOrtReleaseArenaCfg));
271+
OrtReleaseAllocator = (DOrtReleaseAllocator)Marshal.GetDelegateForFunctionPointer(api_.ReleaseAllocator, typeof(DOrtReleaseAllocator));
263272
OrtCreateMemoryInfo = (DOrtCreateMemoryInfo)Marshal.GetDelegateForFunctionPointer(api_.CreateMemoryInfo, typeof(DOrtCreateMemoryInfo));
264273
OrtCreateCpuMemoryInfo = (DOrtCreateCpuMemoryInfo)Marshal.GetDelegateForFunctionPointer(api_.CreateCpuMemoryInfo, typeof(DOrtCreateCpuMemoryInfo));
265274
OrtReleaseMemoryInfo = (DOrtReleaseMemoryInfo)Marshal.GetDelegateForFunctionPointer(api_.ReleaseMemoryInfo, typeof(DOrtReleaseMemoryInfo));
@@ -311,7 +320,6 @@ static NativeMethods()
311320
OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount));
312321
OrtReleaseValue = (DOrtReleaseValue)Marshal.GetDelegateForFunctionPointer(api_.ReleaseValue, typeof(DOrtReleaseValue));
313322

314-
315323
OrtSessionGetModelMetadata = (DOrtSessionGetModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.SessionGetModelMetadata, typeof(DOrtSessionGetModelMetadata));
316324
OrtModelMetadataGetProducerName = (DOrtModelMetadataGetProducerName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetProducerName, typeof(DOrtModelMetadataGetProducerName));
317325
OrtModelMetadataGetGraphName = (DOrtModelMetadataGetGraphName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetGraphName, typeof(DOrtModelMetadataGetGraphName));
@@ -708,6 +716,27 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
708716
public delegate IntPtr /*(OrtStatus*)*/DOrtAllocatorGetInfo(IntPtr /*(const OrtAllocator*)*/ ptr, out IntPtr /*(const struct OrtMemoryInfo**)*/info);
709717
public static DOrtAllocatorGetInfo OrtAllocatorGetInfo;
710718

719+
/// <summary>
720+
/// Create an instance of arena configuration which will be used to create an arena based allocator
721+
/// See docs/C_API.md for details on what the following parameters mean and how to choose these values
722+
/// </summary>
723+
/// <param name="maxMemory">Maximum amount of memory the arena allocates (use 0 to let ORT choose the default)</param>
724+
/// <param name="arenaExtendStrategy">Strategy for arena expansion (use -1 to let ORT choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested)</param>
725+
/// <param name="initialChunkSizeBytes">Size of the region that the arena allocates first (use -1 to let ORT choose the default)</param>
726+
/// <param name="maxDeadBytesPerChunk">Maximum amount of fragmentation allowed per chunk (use -1 to let ORT choose the default)</param>
/// <param name="arenaCfg">(Output) Receives the pointer to the created native OrtArenaCfg instance</param>
727+
/// <returns>Pointer to a native OrtStatus instance indicating success/failure of config creation</returns>
728+
public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateArenaCfg(UIntPtr /*(size_t)*/ maxMemory, int /*(int)*/ arenaExtendStrategy,
729+
int /*(int)*/ initialChunkSizeBytes, int /*(int)*/ maxDeadBytesPerChunk,
730+
out IntPtr /*(OrtArenaCfg**)*/ arenaCfg);
731+
public static DOrtCreateArenaCfg OrtCreateArenaCfg;
732+
733+
/// <summary>
734+
/// Destroy an arena configuration instance
735+
/// </summary>
736+
/// <param name="arenaCfg">arena configuration instance to be destroyed</param>
737+
public delegate void DOrtReleaseArenaCfg(IntPtr /*(OrtArenaCfg*)*/ arenaCfg);
738+
public static DOrtReleaseArenaCfg OrtReleaseArenaCfg;
739+
711740
/// <summary>
712741
/// Create an instance of allocator according to mem_info
713742
/// </summary>
@@ -861,13 +890,16 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
861890

862891
/// <summary>
863892
/// Creates an allocator instance and registers it with the env to enable
864-
///sharing between multiple sessions that use the same env instance.
865-
///Lifetime of the created allocator will be valid for the duration of the environment.
866-
///Returns an error if an allocator with the same OrtMemoryInfo is already registered.
867-
/// </summary>
868-
/// <param name="mem_info">must be non-null</param>
869-
/// <param name="arena_cfg">if nullptr defaults will be used</param>
870-
public delegate void DOrtCreateAndRegisterAllocator(IntPtr /*(OrtIoBinding)*/ io_binding);
893+
/// sharing between multiple sessions that use the same env instance.
894+
/// Lifetime of the created allocator will be valid for the duration of the environment.
895+
/// Returns an error if an allocator with the same OrtMemoryInfo is already registered.
/// </summary>
896+
/// <param name="env">Native OrtEnv instance</param>
897+
/// <param name="memInfo">Native OrtMemoryInfo instance</param>
898+
/// <param name="arenaCfg">Native OrtArenaCfg instance</param>
899+
/// <returns>A pointer to native OrtStatus indicating success/failure</returns>
900+
public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateAndRegisterAllocator(IntPtr /*(OrtEnv*)*/ env,
901+
IntPtr /*(const OrtMemoryInfo*)*/ memInfo,
902+
IntPtr/*(const OrtArenaCfg*)*/ arenaCfg);
871903
public static DOrtCreateAndRegisterAllocator OrtCreateAndRegisterAllocator;
872904

873905
/// <summary>

csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs

+11
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,17 @@ public void DisableTelemetryEvents()
108108
NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableTelemetryEvents(Handle));
109109
}
110110

111+
/// <summary>
112+
/// Create and register an allocator to the OrtEnv instance
113+
/// so as to enable sharing across all sessions using the OrtEnv instance
114+
/// </summary>
115+
/// <param name="memInfo">OrtMemoryInfo instance to be used for allocator creation; must not be null because its Pointer is read</param>
116+
/// <param name="arenaCfg">OrtArenaCfg instance that will be used to define the behavior of the arena based allocator; must not be null because its Pointer is read</param>
117+
public void CreateAndRegisterAllocator(OrtMemoryInfo memInfo, OrtArenaCfg arenaCfg)
118+
{
119+
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateAndRegisterAllocator(Handle, memInfo.Pointer, arenaCfg.Pointer));
120+
}
121+
111122
/// <summary>
112123
/// Queries all the execution providers supported in the native onnxruntime shared library
113124
/// </summary>

csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.cs

+57
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,63 @@ public enum OrtMemType
2727
Default = 0, // the default allocator for execution provider
2828
}
2929

30+
/// <summary>
31+
/// This class encapsulates arena configuration information that will be used to define the behavior
32+
/// of an arena based allocator. It is a SafeHandle that holds the native OrtArenaCfg pointer
33+
/// and releases it through OrtReleaseArenaCfg. See docs/C_API.md for more details
34+
/// </summary>
35+
public class OrtArenaCfg : SafeHandle
36+
{
37+
/// <summary>
38+
/// Create an instance of arena configuration which will be used to create an arena based allocator
39+
/// See docs/C_API.md for details on what the following parameters mean and how to choose these values
40+
/// </summary>
41+
/// <param name="maxMemory">Maximum amount of memory the arena allocates</param>
42+
/// <param name="arenaExtendStrategy">Strategy for arena expansion</param>
43+
/// <param name="initialChunkSizeBytes">Size of the region that the arena allocates first</param>
44+
/// <param name="maxDeadBytesPerChunk">Maximum amount of fragmentation allowed per chunk</param>
45+
public OrtArenaCfg(uint maxMemory, int arenaExtendStrategy, int initialChunkSizeBytes, int maxDeadBytesPerChunk)
46+
: base(IntPtr.Zero, true)
47+
{
// maxMemory is a 32-bit uint here and is widened to UIntPtr for the native call;
// the native pointer is written directly into the SafeHandle's handle field.
48+
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateArenaCfg((UIntPtr)maxMemory,
49+
arenaExtendStrategy,
50+
initialChunkSizeBytes,
51+
maxDeadBytesPerChunk,
52+
out handle));
53+
}
54+
55+
/// <summary>
/// Raw native OrtArenaCfg pointer stored in this SafeHandle, for passing to native calls
/// </summary>
internal IntPtr Pointer
56+
{
57+
get
58+
{
59+
return handle;
60+
}
61+
}
62+
63+
#region SafeHandle
64+
65+
/// <summary>
66+
/// Overrides SafeHandle.IsInvalid
67+
/// </summary>
68+
/// <value>returns true if handle is equal to Zero</value>
69+
public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
70+
71+
/// <summary>
72+
/// Overrides SafeHandle.ReleaseHandle() to properly dispose of
73+
/// the native instance of OrtArenaCfg
74+
/// </summary>
75+
/// <returns>always returns true</returns>
76+
protected override bool ReleaseHandle()
77+
{
78+
NativeMethods.OrtReleaseArenaCfg(handle);
79+
handle = IntPtr.Zero;
80+
return true;
81+
}
82+
83+
#endregion
84+
85+
}
86+
3087
/// <summary>
3188
/// This class encapsulates and most of the time owns the underlying native OrtMemoryInfo instance.
3289
/// Instance returned from OrtAllocator will not own OrtMemoryInfo, the class must be disposed

csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs

+75
Original file line numberDiff line numberDiff line change
@@ -2091,6 +2091,81 @@ private void TestWeightSharingBetweenSessions()
20912091
}
20922092
}
20932093

2094+
// Registers an arena based CPU allocator with the OrtEnv, then runs the same model through
// two sessions configured with "session.use_env_allocators" and one session created with
// default options, validating the output of each run.
[Fact]
2095+
private void TestSharedAllocatorUsingCreateAndRegisterAllocator()
2096+
{
2097+
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "mul_1.onnx");
2098+
2099+
using (var memInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU,
2100+
OrtAllocatorType.ArenaAllocator, 0, OrtMemType.Default))
2101+
using (var arenaCfg = new OrtArenaCfg(0, -1, -1, -1))
2102+
{
2103+
var env = OrtEnv.Instance();
2104+
// Create and register the arena based allocator
2105+
env.CreateAndRegisterAllocator(memInfo, arenaCfg);
2106+
2107+
using (var sessionOptions = new SessionOptions())
2108+
{
2109+
// Key must match kOrtSessionOptionsConfigUseEnvAllocators in onnxruntime_session_options_config_keys.h
2110+
sessionOptions.AddSessionConfigEntry("session.use_env_allocators", "1");
2111+
2112+
// Create two sessions to share the allocator
2113+
// Create a third session that DOES NOT use the allocator in the environment
2114+
using (var session1 = new InferenceSession(modelPath, sessionOptions))
2115+
using (var session2 = new InferenceSession(modelPath, sessionOptions))
2116+
using (var session3 = new InferenceSession(modelPath)) // Use the default SessionOptions instance
2117+
{
2118+
// Input data
// NOTE: inputDims is not referenced below; the tensor shape is taken from the model's input metadata
2119+
var inputDims = new long[] { 3, 2 };
2120+
var input = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F };
2121+
2122+
// Output data
2123+
int[] outputDims = { 3, 2 };
2124+
float[] output = { 1.0F, 4.0F, 9.0F, 16.0F, 25.0F, 36.0F };
2125+
2126+
// Run inference on all three sessions, building the inputs once from session1's metadata
2127+
var inputMeta = session1.InputMetadata;
2128+
var container = new List<NamedOnnxValue>();
2129+
2130+
foreach (var name in inputMeta.Keys)
2131+
{
2132+
Assert.Equal(typeof(float), inputMeta[name].ElementType);
2133+
Assert.True(inputMeta[name].IsTensor);
2134+
var tensor = new DenseTensor<float>(input, inputMeta[name].Dimensions);
2135+
container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
2136+
}
2137+
2138+
// Run inference with named inputs and outputs created within Run()
2139+
using (var results = session1.Run(container)) // results is an IReadOnlyList<NamedOnnxValue> container
2140+
{
2141+
foreach (var r in results)
2142+
{
2143+
validateRunResultData(r.AsTensor<float>(), output, outputDims);
2144+
}
2145+
}
2146+
2147+
// Run inference with named inputs and outputs created within Run()
2148+
using (var results = session2.Run(container)) // results is an IReadOnlyList<NamedOnnxValue> container
2149+
{
2150+
foreach (var r in results)
2151+
{
2152+
validateRunResultData(r.AsTensor<float>(), output, outputDims);
2153+
}
2154+
}
2155+
2156+
// Run inference with named inputs and outputs created within Run()
2157+
using (var results = session3.Run(container)) // results is an IReadOnlyList<NamedOnnxValue> container
2158+
{
2159+
foreach (var r in results)
2160+
{
2161+
validateRunResultData(r.AsTensor<float>(), output, outputDims);
2162+
}
2163+
}
2164+
}
2165+
}
2166+
}
2167+
}
2168+
20942169
[DllImport("kernel32", SetLastError = true)]
20952170
static extern IntPtr LoadLibrary(string lpFileName);
20962171

include/onnxruntime/core/framework/allocator.h

+9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@
99
#include "ortdevice.h"
1010
#include "ortmemoryinfo.h"
1111

12+
// This configures the arena based allocator used by ORT
13+
// See docs/C_API.md for details on what these mean and how to choose these values
// This is the concrete definition of the OrtArenaCfg type; the public C API header only
// forward-declares it via ORT_RUNTIME_CLASS(ArenaCfg), and instances are obtained through
// OrtApi::CreateArenaCfg and released through OrtApi::ReleaseArenaCfg.
14+
struct OrtArenaCfg {
15+
size_t max_mem; // use 0 to allow ORT to choose the default
16+
int arena_extend_strategy; // use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
17+
int initial_chunk_size_bytes; // use -1 to allow ORT to choose the default
18+
int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default
19+
};
20+
1221
namespace onnxruntime {
1322
constexpr const char* CPU = "Cpu";
1423
constexpr const char* CUDA = "Cuda";

include/onnxruntime/core/session/environment.h

+6
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ class Environment {
6161
*/
6262
Status RegisterAllocator(AllocatorPtr allocator);
6363

64+
/**
65+
* Creates and registers an allocator for sharing between multiple sessions.
66+
* Returns an error if an allocator with the same OrtMemoryInfo is already registered.
* @param mem_info Memory info describing the allocator to create.
* @param arena_cfg Optional arena configuration; defaults to nullptr.
*                  NOTE(review): presumably ORT's default arena settings apply when this is
*                  nullptr - confirm against the implementation.
67+
*/
68+
Status CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, const OrtArenaCfg* arena_cfg = nullptr);
69+
6470
/**
6571
* Returns the list of registered allocators in this env.
6672
*/

include/onnxruntime/core/session/onnxruntime_c_api.h

+17-9
Original file line numberDiff line numberDiff line change
@@ -143,15 +143,6 @@ typedef enum OrtErrorCode {
143143
ORT_EP_FAIL,
144144
} OrtErrorCode;
145145

146-
// This configures the arena based allocator used by ORT
147-
// See docs/C_API.md for details on what these mean and how to choose these values
148-
typedef struct OrtArenaCfg {
149-
size_t max_mem; // use 0 to allow ORT to choose the default
150-
int arena_extend_strategy; // use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
151-
int initial_chunk_size_bytes; // use -1 to allow ORT to choose the default
152-
int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default
153-
} OrtArenaCfg;
154-
155146
#define ORT_RUNTIME_CLASS(X) \
156147
struct Ort##X; \
157148
typedef struct Ort##X Ort##X;
@@ -173,6 +164,7 @@ ORT_RUNTIME_CLASS(SequenceTypeInfo);
173164
ORT_RUNTIME_CLASS(ModelMetadata);
174165
ORT_RUNTIME_CLASS(ThreadPoolParams);
175166
ORT_RUNTIME_CLASS(ThreadingOptions);
167+
ORT_RUNTIME_CLASS(ArenaCfg);
176168

177169
#ifdef _WIN32
178170
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
@@ -1126,6 +1118,22 @@ struct OrtApi {
11261118
* and that's recommended because turning this option on may hurt model accuracy.
11271119
*/
11281120
ORT_API2_STATUS(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* tp_options);
1121+
1122+
/**
1123+
* Use this API to create the configuration of an arena that can eventually be used to define
1124+
* an arena based allocator's behavior
1125+
* \param max_mem - use 0 to allow ORT to choose the default
1126+
* \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
1127+
* \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
1128+
* \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
1129+
* \param out - a pointer to an OrtArenaCfg instance
1130+
* \return a nullptr in case of success or a pointer to an OrtStatus instance in case of failure
1131+
* See docs/C_API.md for details on what the above parameters mean and how to choose these values
1132+
*/
1133+
ORT_API2_STATUS(CreateArenaCfg, _In_ size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes,
1134+
int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out);
1135+
1136+
ORT_CLASS_RELEASE(ArenaCfg);
11291137
};
11301138

11311139
/*

include/onnxruntime/core/session/onnxruntime_cxx_api.h

+18
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ ORT_DEFINE_RELEASE(Value);
9898
ORT_DEFINE_RELEASE(ModelMetadata);
9999
ORT_DEFINE_RELEASE(ThreadingOptions);
100100
ORT_DEFINE_RELEASE(IoBinding);
101+
ORT_DEFINE_RELEASE(ArenaCfg);
101102

102103
/*! \class Ort::Float16_t
103104
* \brief it is a structure that represents float16 data.
@@ -551,6 +552,23 @@ struct IoBinding : public Base<OrtIoBinding> {
551552
void ClearBoundOutputs();
552553
};
553554

555+
/*! \struct Ort::ArenaCfg
556+
* \brief it is a structure that represents the configuration of an arena based allocator
557+
* \details Please see docs/C_API.md for details
558+
*/
559+
struct ArenaCfg : Base<OrtArenaCfg> {
560+
// NOTE(review): presumably constructs an empty wrapper that holds no native OrtArenaCfg -
// confirm against Base<T>.
explicit ArenaCfg(std::nullptr_t) {}
561+
/**
562+
* Constructs an ArenaCfg by creating a native OrtArenaCfg with the given settings.
* \param max_mem - use 0 to allow ORT to choose the default
563+
* \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
564+
* \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
565+
* \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
566+
* (This is a constructor, so there is no return value; the result is the constructed ArenaCfg.)
567+
* See docs/C_API.md for details on what the above parameters mean and how to choose these values
568+
*/
569+
ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
570+
};
571+
554572
//
555573
// Custom OPs (only needed to implement custom OPs)
556574
//

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

+4
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,10 @@ inline void IoBinding::ClearBoundOutputs() {
290290
GetApi().ClearBoundOutputs(p_);
291291
}
292292

293+
// Creates the native OrtArenaCfg through OrtApi::CreateArenaCfg and stores the resulting
// pointer in p_; the returned status is handed to ThrowOnError.
inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
294+
ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
295+
}
296+
293297
inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
294298
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
295299
if (strcmp(logid, "onnxruntime-node") == 0) {

0 commit comments

Comments
 (0)