
Commit 5653a91

VitalyFedyunin authored and facebook-github-bot committed
Implement reference counting for shared IPC CUDA tensors (pytorch#16854)
Summary:
This is to fix pytorch#16141 and similar issues.

The idea is to track a reference to every shared CUDA Storage and deallocate memory only after a consumer process deallocates received Storage.

ezyang Done with cleanup. Same (insignificantly better) performance as in file-per-share solution, but handles millions of shared tensors easily.

Note: [ ] documentation in progress

Pull Request resolved: pytorch#16854

Differential Revision: D13994490

Pulled By: VitalyFedyunin

fbshipit-source-id: 565148ec3ac4fafb32d37fde0486b325bed6fbd1
1 parent f5ea528 commit 5653a91

15 files changed: +841 -88 lines changed
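The summary sketches the approach only briefly, so here is a toy model of the reference-counting idea in Python. It is purely illustrative, not the commit's actual C++ implementation; the SharedBlockRef class and its method names are invented for this example.

::

    # Toy model only: every storage shared over CUDA IPC gets a counter that
    # keeps the producer-side memory alive until all consumers release it.
    class SharedBlockRef:
        def __init__(self, free_fn):
            self.refcount = 0
            self.free_fn = free_fn     # returns the cached CUDA block to the allocator

        def on_sent(self):
            self.refcount += 1         # a consumer process received the storage

        def on_released(self):
            self.refcount -= 1         # a consumer deleted its copy of the storage
            if self.refcount == 0:
                self.free_fn()         # only now is the producer's memory freed

    # usage sketch
    ref = SharedBlockRef(free_fn=lambda: print("block freed"))
    ref.on_sent()        # producer shares the storage with a consumer
    ref.on_released()    # consumer is done -> "block freed"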

c10/core/StorageImpl.h

+14
@@ -19,6 +19,7 @@ struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
         data_ptr_(std::move(data_ptr)),
         numel_(numel),
         resizable_(resizable),
+        received_cuda_(false),
         allocator_(allocator) {
     if (resizable) {
       AT_ASSERTM(
@@ -210,11 +211,24 @@ struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
     resizable_ = false;
   }
 
+  // This method can be used only after storage construction and cannot be used
+  // to modify storage status
+  void set_received_cuda(bool received_cuda) {
+    received_cuda_ = received_cuda;
+  }
+
+  bool received_cuda() {
+    return received_cuda_;
+  }
+
  private:
   caffe2::TypeMeta data_type_;
   DataPtr data_ptr_;
   int64_t numel_;
   bool resizable_;
+  // Identifies that Storage was received from another process and doesn't have
+  // local to process cuda memory allocation
+  bool received_cuda_;
   Allocator* allocator_;
 };
 } // namespace c10

c10/cuda/CUDACachingAllocator.cpp

+37 -16

@@ -16,8 +16,10 @@
 #include <vector>
 
 namespace c10 {
-namespace cuda {
 
+C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
+
+namespace cuda {
 namespace CUDACachingAllocator {
 
 //
@@ -47,6 +49,8 @@ namespace CUDACachingAllocator {
 // work.
 //
 
+
+
 namespace {
 
 using stream_set = std::unordered_set<cuda::CUDAStream>;
@@ -154,7 +158,7 @@ struct THCCachingAllocator
   std::vector<DeviceStats> device_stats;
 
   // lock around all operations
-  std::mutex mutex;
+  std::recursive_mutex mutex;
 
   // lock around calls to cudaFree (to prevent deadlocks with NCCL)
   std::mutex cuda_free_mutex;
@@ -186,7 +190,7 @@ struct THCCachingAllocator
   /** allocates a block which is safe to use from the provided stream */
   void malloc(void** devPtr, size_t size, cudaStream_t stream)
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
 
     int device;
     C10_CUDA_CHECK(cudaGetDevice(&device));
@@ -201,14 +205,29 @@
     Block search_key(device, stream, size);
     auto& pool = get_pool(size);
 
-    Block* block = nullptr;
-    Block* remaining = nullptr;
-
-    auto it = pool.lower_bound(&search_key);
-    if (it != pool.end() && (*it)->device == device && (*it)->stream == stream) {
-      block = *it;
-      pool.erase(it);
-    } else {
+    auto find_free_block = [&]()->Block*{
+      auto it = pool.lower_bound(&search_key);
+      if (it != pool.end() && (*it)->device == device &&
+          (*it)->stream == stream) {
+        Block* block = *it;
+        pool.erase(it);
+        return block;
+      }
+      return nullptr;
+    };
+
+    Block* block = find_free_block();
+    if (block == nullptr) {
+      bool freed_memory = false;
+      for (const auto& name : FreeCudaMemoryCallbacksRegistry()->Keys()) {
+        freed_memory |=
+            FreeCudaMemoryCallbacksRegistry()->Create(name)->Execute();
+      }
+      if (freed_memory) {
+        block = find_free_block();
+      }
+    }
+    if (block == nullptr) {
       void* ptr;
       size_t alloc_size = get_allocation_size(size);
       cudaError_t err = cuda_malloc_retry(device, &ptr, alloc_size);
@@ -253,8 +272,10 @@
       block = new Block(device, stream, alloc_size, &pool, ptr);
     }
 
+    Block* remaining = nullptr;
     AT_ASSERT(block);
     if (should_split(block, size)) {
+
       remaining = block;
 
       block = new Block(device, stream, size, &pool, block->ptr);
@@ -280,7 +301,7 @@
 
   void free(void* ptr)
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
     if (!ptr) {
       return;
     }
@@ -305,14 +326,14 @@
   /** returns cached blocks to the system allocator */
   void emptyCache()
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
     free_blocks(large_blocks, large_blocks.begin(), large_blocks.end());
     free_blocks(small_blocks, small_blocks.begin(), small_blocks.end());
   }
 
   void* getBaseAllocation(void* ptr, size_t* outSize)
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
     Block* block = find_allocated_block(ptr);
     if (!block) {
       AT_ERROR("invalid device pointer: %p", ptr);
@@ -348,14 +369,14 @@
 
   void cacheInfo(int dev_id, size_t* total, size_t* largest)
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
     cacheInfoAux(large_blocks, dev_id, total, largest);
     cacheInfoAux(small_blocks, dev_id, total, largest);
   }
 
   void recordStream(void* ptr, cuda::CUDAStream stream)
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
     Block* block = find_allocated_block(ptr);
     if (!block) {
       AT_ERROR("invalid device pointer: %p", ptr);

c10/cuda/CUDACachingAllocator.h

+14
@@ -4,10 +4,24 @@
 #include <c10/cuda/CUDAStream.h>
 #include <c10/core/Allocator.h>
 #include <c10/cuda/CUDAMacros.h>
+#include <c10/util/Registry.h>
 
 #include <mutex>
 
 namespace c10 {
+
+// Caching allocator will execute every registered callback if it unable to find
+// block inside of already allocated area.
+class C10_CUDA_API FreeMemoryCallback {
+ public:
+  virtual ~FreeMemoryCallback() {};
+  virtual bool Execute() = 0;
+};
+
+C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
+#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \
+  C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__);
+
 namespace cuda {
 
 // TODO: Turn this into an honest to goodness class. I briefly attempted to do

docs/source/multiprocessing.rst

+55 -47

@@ -28,57 +28,65 @@ Python 2 can only create subprocesses using ``fork``, and it's not supported
 by the CUDA runtime.
 
 Unlike CPU tensors, the sending process is required to keep the original tensor
-as long as the receiving process retains a copy of the tensor.
-This shouldn't be a problem for sharing model parameters (which stay live
-for the entire execution of the model), but passing other
-kinds of data should be done with care.
+as long as the receiving process retains a copy of the tensor. It is implemented
+under the hood but requires users to follow the next best practices.
 
-Here is an example program which handles these requirements correctly:
+1. Release memory ASAP in the consumer.
 
 ::
 
-    import torch
-    import torch.multiprocessing as mp
-
-    torch.set_default_tensor_type(torch.cuda.FloatTensor)
-
-    def sender(q, e):
-        for i in range(10):
-            s_sample = [torch.zeros(1), torch.ones(1)]
-            q.put(s_sample)
-            e.wait()
-            del s_sample
-            e.clear()
-
-    if __name__ == "__main__":
-        ctx = mp.get_context("spawn")
-        q = ctx.Queue()
-        e = ctx.Event()
-        p = ctx.Process(target=sender, args=(q, e))
-        p.start()
-
-        for i in range(10):
-            print('=== ITER {} ===".format(i))
-            r_sample = q.get()
-            del r_sample
-            e.set()
-
-        p.join()
-
-In the example above, calling `e.wait()`
-on sender side ensures tensor `s_sample` doesn't get deleted while
-receiver is working on it. The receiver signals when it is done
-with the tensor using `e.set()`, being careful to `del` its reference
-to the received tensor first. It is INSUFFICIENT to promise never to call
-`r_sample` again; while `r_sample` is live, it may be confused with
-any subsequent tensors allocated by the source process at the same address.
-
-If a receiver wants to save the data of `r_sample` for future use while
-letting the source process deallocate the original, it must
-`clone()` it.
-
-This behavior is very confusing, and we are tracking a fix for it
-at https://github.com/pytorch/pytorch/issues/16141
+    ## Good
+    x = queue.get()
+    # do somethings with x
+    del x
 
+::
+
+    ## Bad
+    x = queue.get()
+    # do somethings with x
+    # do everything else (producer have to keep x in memory)
+
+2. Keep producer process running until all consumers exits. This will prevent
+the situation when the producer process releasing memory which is still in use
+by the consumer.
+
+::
+
+    ## producer
+    # send tensors, do something
+    event.wait()
+
+
+::
+
+    ## consumer
+    # receive tensors and use them
+    event.set()
+
+3. Don't pass received tensors.
+
+::
+
+    # not going to work
+    x = queue.get()
+    queue_2.put(x)
+
+
+::
+
+    # you need to create a process-local copy
+    x = queue.get()
+    x_clone = x.clone()
+    queue_2.put(x_clone)
+
+
+::
+
+    # putting and getting from the same queue in the same process will likely end up with segfault
+    queue.put(tensor)
+    x = queue.get()
+
 
 Sharing strategies
 ------------------
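For reference, the following sketch shows how the three practices added above might be combined in one program, in the spirit of the example that this change removes. It is illustrative only and not part of the commit; the tensor sizes, iteration count, and variable names are arbitrary.

::

    import torch
    import torch.multiprocessing as mp

    def producer(queue, done):
        for i in range(10):
            queue.put(torch.zeros(2, device='cuda'))
        # Practice 2: keep the producer alive until the consumer has
        # finished with every shared tensor.
        done.wait()

    if __name__ == '__main__':
        ctx = mp.get_context('spawn')   # CUDA sharing requires spawn, not fork
        queue = ctx.Queue()
        done = ctx.Event()
        p = ctx.Process(target=producer, args=(queue, done))
        p.start()

        for i in range(10):
            x = queue.get()
            y = x.clone()   # Practice 3: copy before storing or re-sending
            del x           # Practice 1: release the shared storage ASAP
            # ... keep working with y, which owns process-local memory ...

        done.set()
        p.join()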
