Skip to content

Commit baf7a3c

Browse files
authored
[Offload] Properly guard modifications to the RPC device array (#126790)
Summary: If the user deallocates an RPC device this can sometimes fail if the RPC server is still running. This will happen if the modification happens while the server is still checking it. This patch adds a mutex to guard modifications to it.
1 parent e4016bf commit baf7a3c

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

offload/plugins-nextgen/common/include/RPC.h

+9-3
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ struct RPCServerTy {
7272
/// Array of associated devices. These must be alive as long as the server is.
7373
std::unique_ptr<plugin::GenericDeviceTy *[]> Devices;
7474

75+
/// Mutex that guards accesses to the buffers and device array.
76+
std::mutex BufferMutex{};
77+
7578
/// A helper class for running the user thread that handles the RPC interface.
7679
/// Because we only need to check the RPC server while any kernels are
7780
/// working, we track submission / completion events to allow the thread to
@@ -90,6 +93,9 @@ struct RPCServerTy {
9093
std::condition_variable CV;
9194
std::mutex Mutex;
9295

96+
/// A reference to the main server's mutex.
97+
std::mutex &BufferMutex;
98+
9399
/// A reference to all the RPC interfaces that the server is handling.
94100
llvm::ArrayRef<void *> Buffers;
95101

@@ -98,9 +104,9 @@ struct RPCServerTy {
98104

99105
/// Initialize the worker thread to run in the background.
100106
ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[],
101-
size_t Length)
102-
: Running(false), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length),
103-
Devices(Devices, Length) {}
107+
size_t Length, std::mutex &BufferMutex)
108+
: Running(false), NumUsers(0), CV(), Mutex(), BufferMutex(BufferMutex),
109+
Buffers(Buffers, Length), Devices(Devices, Length) {}
104110

105111
~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }
106112

offload/plugins-nextgen/common/src/RPC.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ void RPCServerTy::ServerThread::run() {
128128
Lock.unlock();
129129
while (NumUsers.load(std::memory_order_relaxed) > 0 &&
130130
Running.load(std::memory_order_relaxed)) {
131+
std::lock_guard<decltype(Mutex)> Lock(BufferMutex);
131132
for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) {
132133
if (!Buffer || !Device)
133134
continue;
@@ -146,7 +147,7 @@ RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
146147
Devices(std::make_unique<plugin::GenericDeviceTy *[]>(
147148
Plugin.getNumDevices())),
148149
Thread(new ServerThread(Buffers.get(), Devices.get(),
149-
Plugin.getNumDevices())) {}
150+
Plugin.getNumDevices(), BufferMutex)) {}
150151

151152
llvm::Error RPCServerTy::startThread() {
152153
Thread->startThread();
@@ -187,13 +188,15 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
187188
if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client,
188189
sizeof(rpc::Client), nullptr))
189190
return Err;
191+
std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
190192
Buffers[Device.getDeviceId()] = RPCBuffer;
191193
Devices[Device.getDeviceId()] = &Device;
192194

193195
return Error::success();
194196
}
195197

196198
Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
199+
std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
197200
Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
198201
Buffers[Device.getDeviceId()] = nullptr;
199202
Devices[Device.getDeviceId()] = nullptr;

0 commit comments

Comments
 (0)