@@ -77,8 +77,9 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data);
77
77
// /
78
78
79
79
static ur_result_t
80
- CreateHostMemoryProvider (ur_device_handle_t_ *DeviceHandle,
81
- umf_memory_provider_handle_t *MemoryProviderHost) {
80
+ CreateHostMemoryProviderPool (ur_device_handle_t_ *DeviceHandle,
81
+ umf_memory_provider_handle_t *MemoryProviderHost,
82
+ umf_memory_pool_handle_t *MemoryPoolHost) {
82
83
umf_cuda_memory_provider_params_handle_t CUMemoryProviderParams = nullptr ;
83
84
84
85
*MemoryProviderHost = nullptr ;
@@ -91,10 +92,20 @@ CreateHostMemoryProvider(ur_device_handle_t_ *DeviceHandle,
91
92
umf::cuda_params_unique_handle_t CUMemoryProviderParamsUnique (
92
93
CUMemoryProviderParams, umfCUDAMemoryProviderParamsDestroy);
93
94
94
- // create UMF CUDA memory provider for the host memory (UMF_MEMORY_TYPE_HOST)
95
- UmfResult = umf::createMemoryProvider (
96
- CUMemoryProviderParamsUnique.get (), 0 /* cuDevice */ , context,
97
- UMF_MEMORY_TYPE_HOST, MemoryProviderHost);
95
+ UmfResult = umf::setCUMemoryProviderParams (CUMemoryProviderParamsUnique.get (),
96
+ 0 /* cuDevice */ , context,
97
+ UMF_MEMORY_TYPE_HOST);
98
+ UMF_RETURN_UR_ERROR (UmfResult);
99
+
100
+ // create UMF CUDA memory provider and pool for the host memory
101
+ // (UMF_MEMORY_TYPE_HOST)
102
+ UmfResult = umfMemoryProviderCreate (umfCUDAMemoryProviderOps (),
103
+ CUMemoryProviderParamsUnique.get (),
104
+ MemoryProviderHost);
105
+ UMF_RETURN_UR_ERROR (UmfResult);
106
+
107
+ UmfResult = umfPoolCreate (umfProxyPoolOps (), *MemoryProviderHost, nullptr , 0 ,
108
+ MemoryPoolHost);
98
109
UMF_RETURN_UR_ERROR (UmfResult);
99
110
100
111
return UR_RESULT_SUCCESS;
@@ -112,8 +123,10 @@ struct ur_context_handle_t_ {
112
123
std::vector<ur_device_handle_t > Devices;
113
124
std::atomic_uint32_t RefCount;
114
125
115
- // UMF CUDA memory provider for the host memory (UMF_MEMORY_TYPE_HOST)
126
+ // UMF CUDA memory provider and pool for the host memory
127
+ // (UMF_MEMORY_TYPE_HOST)
116
128
umf_memory_provider_handle_t MemoryProviderHost = nullptr ;
129
+ umf_memory_pool_handle_t MemoryPoolHost = nullptr ;
117
130
118
131
ur_context_handle_t_ (const ur_device_handle_t *Devs, uint32_t NumDevices)
119
132
: Devices{Devs, Devs + NumDevices}, RefCount{1 } {
@@ -124,10 +137,14 @@ struct ur_context_handle_t_ {
124
137
// Create UMF CUDA memory provider for the host memory
125
138
// (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because
126
139
// it is guaranteed to exist).
127
- UR_CHECK_ERROR (CreateHostMemoryProvider (Devices[0 ], &MemoryProviderHost));
140
+ UR_CHECK_ERROR (CreateHostMemoryProviderPool (Devices[0 ], &MemoryProviderHost,
141
+ &MemoryPoolHost));
128
142
};
129
143
130
144
~ur_context_handle_t_ () {
145
+ if (MemoryPoolHost) {
146
+ umfPoolDestroy (MemoryPoolHost);
147
+ }
131
148
if (MemoryProviderHost) {
132
149
umfMemoryProviderDestroy (MemoryProviderHost);
133
150
}
0 commit comments