|
14 | 14 | #include <ze_api.h> |
15 | 15 |
|
16 | 16 | #include "../common.hpp" |
| 17 | +#include "../ur_interface_loader.hpp" |
17 | 18 | #include "logger/ur_logger.hpp" |
18 | 19 |
|
19 | 20 | namespace v2 { |
@@ -54,8 +55,7 @@ struct ze_handle_wrapper { |
54 | 55 | try { |
55 | 56 | reset(); |
56 | 57 | } catch (...) { |
57 | | - // TODO: add appropriate logging or pass the error |
58 | | - // to the caller (make the dtor noexcept(false) or use tls?) |
| 58 | + // logging already done in reset |
59 | 59 | } |
60 | 60 | } |
61 | 61 |
|
@@ -104,5 +104,95 @@ using ze_context_handle_t = |
104 | 104 | using ze_command_list_handle_t = |
105 | 105 | ze_handle_wrapper<::ze_command_list_handle_t, zeCommandListDestroy>; |
106 | 106 |
|
| 107 | +template <typename URHandle, ur_result_t (*retain)(URHandle), |
| 108 | + ur_result_t (*release)(URHandle)> |
| 109 | +struct ref_counted { |
| 110 | + ref_counted(URHandle handle) : handle(handle) { |
| 111 | + if (handle) { |
| 112 | + retain(handle); |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + ~ref_counted() { |
| 117 | + if (handle) { |
| 118 | + release(handle); |
| 119 | + } |
| 120 | + } |
| 121 | + |
| 122 | + operator URHandle() const { return handle; } |
| 123 | + URHandle operator->() const { return handle; } |
| 124 | + |
| 125 | + ref_counted(const ref_counted &) = delete; |
| 126 | + ref_counted &operator=(const ref_counted &) = delete; |
| 127 | + |
| 128 | + ref_counted(ref_counted &&other) { |
| 129 | + handle = other.handle; |
| 130 | + other.handle = nullptr; |
| 131 | + } |
| 132 | + |
| 133 | + ref_counted &operator=(ref_counted &&other) { |
| 134 | + if (this == &other) { |
| 135 | + return *this; |
| 136 | + } |
| 137 | + |
| 138 | + if (handle) { |
| 139 | + release(handle); |
| 140 | + } |
| 141 | + |
| 142 | + handle = other.handle; |
| 143 | + other.handle = nullptr; |
| 144 | + return *this; |
| 145 | + } |
| 146 | + |
| 147 | + URHandle get() const { return handle; } |
| 148 | + |
| 149 | +private: |
| 150 | + URHandle handle; |
| 151 | +}; |
| 152 | + |
| 153 | +template <typename URHandle> struct ref_counted_traits; |
| 154 | + |
| 155 | +#define DECLARE_REF_COUNTER_TRAITS(URHandle, retainFn, releaseFn) \ |
| 156 | + template <> struct ref_counted_traits<URHandle> { \ |
| 157 | + static ur_result_t retain(URHandle handle) { return retainFn(handle); } \ |
| 158 | + static ur_result_t release(URHandle handle) { return releaseFn(handle); } \ |
| 159 | + static ur_result_t nop(URHandle) { return UR_RESULT_SUCCESS; } \ |
| 160 | + static ur_result_t validate([[maybe_unused]] URHandle handle) { \ |
| 161 | + assert(reinterpret_cast<_ur_object *>(handle)->RefCount.load() != 0); \ |
| 162 | + return UR_RESULT_SUCCESS; \ |
| 163 | + } \ |
| 164 | + }; |
| 165 | + |
| 166 | +// This version of ref_counted calls retain/release functions. |
| 167 | +template <typename URHandle> |
| 168 | +using rc = ref_counted<URHandle, ref_counted_traits<URHandle>::retain, |
| 169 | + ref_counted_traits<URHandle>::release>; |
| 170 | + |
| 171 | +// This version of ref_counted does not call retain/release functions. |
| 172 | +// It is used to avoid circular references, most notably to ur_context_handle_t. |
| 173 | +// This is equivalent to just using URHandle but makes it clear that no ref |
| 174 | +// counting is expected. |
| 175 | +template <typename URHandle> |
| 176 | +using weak = ref_counted<URHandle, ref_counted_traits<URHandle>::nop, |
| 177 | + ref_counted_traits<URHandle>::nop>; |
| 178 | + |
| 179 | +// This version of ref_counted validates that the ref count is not zero on every |
| 180 | +// release and retain in debug mode, and does nothing in the release mode. |
| 181 | +// Used for types that should always be alibe during the adapter lifetime (e.g. |
| 182 | +// devices). |
| 183 | +template <typename URHandle> |
| 184 | +using rc_val_only = |
| 185 | + ref_counted<URHandle, ref_counted_traits<URHandle>::validate, |
| 186 | + ref_counted_traits<URHandle>::validate>; |
| 187 | + |
| 188 | +DECLARE_REF_COUNTER_TRAITS(::ur_device_handle_t, urDeviceRetain, |
| 189 | + urDeviceRelease); |
| 190 | +DECLARE_REF_COUNTER_TRAITS(::ur_context_handle_t, urContextRetain, |
| 191 | + urContextRelease); |
| 192 | +DECLARE_REF_COUNTER_TRAITS(::ur_mem_handle_t, urMemRetain, urMemRelease); |
| 193 | +DECLARE_REF_COUNTER_TRAITS(::ur_program_handle_t, urProgramRetain, |
| 194 | + urProgramRelease); |
| 195 | + |
107 | 196 | } // namespace raii |
| 197 | + |
108 | 198 | } // namespace v2 |
0 commit comments