
Commit 1e6acc6

ezyang authored and facebook-github-bot committed
Replace caffe2::DeviceGuard with c10::cuda::CUDAGuard (pytorch#17623)
Summary: Pull Request resolved: pytorch#17623

Despite its generic-sounding name, caffe2::DeviceGuard actually only worked on CUDA devices. Rename it to something that more clearly spells out its applicability.

I'm not sure if it's the right call, but in this patch I added 'using CUDAGuard = c10::cuda::CUDAGuard', as this seems to be more in line with how the Caffe2 codebase is currently written. More idiomatic c10 namespace style would be to say cuda::CUDAGuard. Willing to change this if people shout.

This is a respin of D13156470 (pytorch#14284)

Reviewed By: dzhulgakov

Differential Revision: D14285504

fbshipit-source-id: 93b8ab938b064572b3b010c307e1261fde0fff3d
1 parent: e9eb18a · commit: 1e6acc6

File tree: 11 files changed (+31 additions, -42 deletions)
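For context, a minimal sketch of the RAII contract involved (the function name is hypothetical; it assumes only that c10::cuda::CUDAGuard sets the current CUDA device on construction and restores the previous one on destruction, which is the contract the removed caffe2::DeviceGuard also had):

#include <c10/cuda/CUDAGuard.h>

void work_on(int device_index) {
  // Switch the current CUDA device to `device_index`; the previous
  // device becomes current again when `g` is destroyed.
  c10::cuda::CUDAGuard g(device_index);
  // ... allocate memory, launch kernels, record events here ...
} // scope exit: previous device restored

Every per-device loop in the diffs below follows this pattern: construct a guard at the top of the loop body so each iteration's CUDA calls target the right device.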

caffe2/contrib/nccl/cuda_nccl_gpu.cc

Lines changed: 9 additions & 9 deletions
@@ -26,7 +26,7 @@ class NCCLContext {
     streams_.resize(devices_.size());
     events_.resize(devices_.size());
     for (auto i = 0; i < devices_.size(); ++i) {
-      DeviceGuard g(devices_[i]);
+      CUDAGuard g(devices_[i]);
       // get stream priorities
       int lo_pri, hi_pri;
       CUDA_ENFORCE(cudaDeviceGetStreamPriorityRange(&lo_pri, &hi_pri));
@@ -35,18 +35,18 @@ class NCCLContext {
       CUDA_ENFORCE(cudaEventCreateWithFlags(
           &events_[i], cudaEventDefault | cudaEventDisableTiming));
     }
-    DeviceGuard g(master_gpu_id_);
+    CUDAGuard g(master_gpu_id_);
     CUDA_ENFORCE(cudaEventCreateWithFlags(
         &master_event_, cudaEventDefault | cudaEventDisableTiming));
   }

   ~NCCLContext() {
     for (auto i = 0; i < devices_.size(); ++i) {
-      DeviceGuard g(devices_[i]);
+      CUDAGuard g(devices_[i]);
       CUDA_ENFORCE(cudaStreamDestroy(streams_[i]));
       CUDA_ENFORCE(cudaEventDestroy(events_[i]));
     }
-    DeviceGuard g(master_gpu_id_);
+    CUDAGuard g(master_gpu_id_);
     CUDA_ENFORCE(cudaEventDestroy(master_event_));

   /*
@@ -137,7 +137,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
   // do initialization
   for (auto i = 0; i < ex.elements.size(); ++i) {
     auto& ctx = ex.elements[i];
-    DeviceGuard g(ctx.device);
+    CUDAGuard g(ctx.device);
     init_f(ex.elements[i]);
   }

@@ -150,7 +150,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
   // children streams, so the children streams are synchronized WRT
   // the original stream.
   {
-    DeviceGuard g(ex.stream_gpu_id);
+    CUDAGuard g(ex.stream_gpu_id);
     CUDA_ENFORCE(cudaEventRecord(context->master_event_, ex.stream));
   }

@@ -164,7 +164,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {

   for (auto i = 0; i < ex.elements.size(); ++i) {
     auto& ctx = ex.elements[i];
-    DeviceGuard g(ctx.device);
+    CUDAGuard g(ctx.device);
     auto& comm = comms[i];
     auto& stream = streams[i];
     auto& event = events[i];
@@ -180,7 +180,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {

   for (auto i = 0; i < ex.elements.size(); ++i) {
     auto& ctx = ex.elements[i];
-    DeviceGuard g(ctx.device);
+    CUDAGuard g(ctx.device);
     auto& comm = comms[i];
     auto& stream = streams[i];
     auto& event = events[i];
@@ -192,7 +192,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
   }

   // Now, wait on all the events in the original stream.
-  DeviceGuard dg(ex.stream_gpu_id);
+  CUDAGuard dg(ex.stream_gpu_id);
   for (auto& event : events) {
     CUDA_ENFORCE(cudaStreamWaitEvent(CHECK_NOTNULL(ex.stream), event, 0));
   }
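Besides the rename, these hunks show the fork/join stream pattern the guards wrap: an event recorded on the original stream makes the child streams wait for it, and per-child events recorded afterwards make the original stream wait for the children. A standalone sketch of that pattern, using plain CUDA runtime calls rather than Caffe2's CUDA_ENFORCE wrappers (function names are hypothetical):

#include <cuda_runtime.h>

// Fork: `child` waits for all work queued on `parent` so far.
void forkStream(cudaStream_t parent, cudaStream_t child, cudaEvent_t ev) {
  cudaEventRecord(ev, parent);        // mark the current tail of `parent`
  cudaStreamWaitEvent(child, ev, 0);  // `child` blocks until the mark passes
}

// Join: `parent` waits for all work queued on `child` so far.
void joinStream(cudaStream_t parent, cudaStream_t child, cudaEvent_t ev) {
  cudaEventRecord(ev, child);
  cudaStreamWaitEvent(parent, ev, 0);
}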

caffe2/core/blob_gpu_test.cc

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,7 @@ TEST(TensorTest, TensorSerializationMultiDevices) {
     tensor.mutable_data<float>()[i] = i;
   }
   for (int gpu_id = 0; gpu_id < NumCudaDevices(); ++gpu_id) {
-    DeviceGuard guard(gpu_id);
+    CUDAGuard guard(gpu_id);
     CUDAContext context(gpu_id); // switch to the current gpu
     blob.Reset(new Tensor(tensor, CUDA));
     string serialized = SerializeBlob(blob, "test");

caffe2/core/blob_serialization.cc

Lines changed: 1 addition & 1 deletion
@@ -223,7 +223,7 @@ void TensorSerializer::Serialize(
   const TensorProto::DataType data_type = TypeMetaToDataType(input.dtype());
   proto.set_data_type(data_type);
   StoreDeviceDetail(input, &proto);
-  // TODO: use DeviceGuard here instead of context and employ explicit sync
+  // TODO: use CUDAGuard here instead of context and employ explicit sync
   // copy
   auto uniq_ptr = CreateContext(input.GetDevice());
   // A lot of copypaste is error prone. Should we create a macro for this?

caffe2/core/common_gpu.h

Lines changed: 2 additions & 15 deletions
@@ -27,6 +27,7 @@

 #include "c10/cuda/CUDAMacros.h"
 #include "c10/cuda/CUDAMathCompat.h"
+#include <c10/cuda/CUDAGuard.h>

 // Defines CAFFE2_CUDA_EXPORT and CAFFE2_CUDA_IMPORT. On Windows, this
 // corresponds to different declarations (dllexport and dllimport). On
@@ -371,21 +372,7 @@ inline dim3 CAFFE_GET_BLOCKS_2D(const int N, const int /* M */) {
   return grid;
 }

-class DeviceGuard {
- public:
-  explicit DeviceGuard(int newDevice) : previous_(CaffeCudaGetDevice()) {
-    if (previous_ != newDevice) {
-      CaffeCudaSetDevice(newDevice);
-    }
-  }
-
-  ~DeviceGuard() noexcept {
-    CaffeCudaSetDevice(previous_);
-  }
-
- private:
-  int previous_;
-};
+using CUDAGuard = c10::cuda::CUDAGuard;

 template <typename T, int N>
 struct SimpleArray {
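Because the alias is declared where the deleted class used to live, existing call sites keep compiling unchanged, which is why the rest of this commit is a mechanical rename. A sketch of a typical call site after the change (hypothetical function, in the style of the per-device loops touched elsewhere in this commit):

#include "caffe2/core/common_gpu.h"

namespace caffe2 {
void InitPerDevice(int num_devices) {
  for (int i = 0; i < num_devices; ++i) {
    CUDAGuard g(i); // resolves to c10::cuda::CUDAGuard via the alias above
    // per-device setup (streams, events, handles) would go here
  }
}
} // namespace caffe2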

caffe2/core/context_gpu.cu

Lines changed: 2 additions & 2 deletions
@@ -112,7 +112,7 @@ void CUDAContext::CopyBytesSync(
   // This emulates Caffe2 original behavior where sync copy doesn't change the
   // device. It's probably better for clarity to switch to the target device
   // explicitly here, but in the worst case CUDA would sync for us.
-  // TODO: change it to DeviceGuard
+  // TODO: change it to CUDAGuard
   CUDAContext context(-1); // take current device
   CUDA_ENFORCE(cudaMemcpyAsync(
       dst, src, nbytes, cudaMemcpyDefault, context.cuda_stream()));
@@ -212,7 +212,7 @@ static void Caffe2InitializeCuda() {
       "). Increase that and recompile.");

   for (DeviceIndex i = 0; i < NumCudaDevices(); ++i) {
-    DeviceGuard g(i);
+    CUDAGuard g(i);
     // Enable peer access.
     const int peer_group = i / CAFFE2_CUDA_MAX_PEER_SIZE;
     const int peer_start = peer_group * CAFFE2_CUDA_MAX_PEER_SIZE;

caffe2/core/context_gpu.h

Lines changed: 3 additions & 3 deletions
@@ -103,7 +103,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
   }

   cublasHandle_t GetHandle(c10::cuda::CUDAStream cuda_stream) {
-    DeviceGuard guard(cuda_stream.device_index());
+    CUDAGuard guard(cuda_stream.device_index());
     // Default construct in the map if it doesn't exist, and return a mutable
     // refernce to it.
     auto& r = cublas_handles_[cuda_stream];
@@ -127,7 +127,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
   }

   cudnnHandle_t GetCudnnHandle(c10::cuda::CUDAStream cuda_stream) {
-    DeviceGuard guard(cuda_stream.device_index());
+    CUDAGuard guard(cuda_stream.device_index());
     auto& r = cudnn_handles_[cuda_stream];
     if (r == nullptr) {
       CUDNN_ENFORCE(cudnnCreate(&r));
@@ -234,7 +234,7 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {

   curandGenerator_t& curand_generator() {
     if (!curand_generator_) {
-      DeviceGuard guard(gpu_id_);
+      CUDAGuard guard(gpu_id_);
       CURAND_ENFORCE(
           curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT));
       CURAND_ENFORCE(

caffe2/core/context_gpu_test.cc

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ TEST(CUDAContextTest, MemoryPoolAllocateDealloc) {
   const int nbytes = 1048576;
   for (int i = 0; i < NumCudaDevices(); ++i) {
     LOG(INFO) << "Device " << i << " of " << NumCudaDevices();
-    DeviceGuard guard(i);
+    CUDAGuard guard(i);
     auto allocated = CUDAContext::New(nbytes);
     EXPECT_NE(allocated, nullptr);
     cudaPointerAttributes attr;

caffe2/core/cudnn_wrappers.h

Lines changed: 3 additions & 3 deletions
@@ -87,7 +87,7 @@ struct CuDNNWorkspace {
 class CuDNNState {
  public:
   explicit CuDNNState(size_t gpu_id) : gpu_id_(gpu_id) {
-    DeviceGuard g(gpu_id_);
+    CUDAGuard g(gpu_id_);
     CUDNN_ENFORCE(cudnnCreate(&cudnn_handle_));
     CUDA_ENFORCE(cudaEventCreate(&before_));
     CUDA_ENFORCE(cudaEventCreate(&after_));
@@ -96,7 +96,7 @@ class CuDNNState {
   }

   ~CuDNNState() noexcept {
-    DeviceGuard g(gpu_id_);
+    CUDAGuard g(gpu_id_);
     CUDNN_CHECK(cudnnDestroy(cudnn_handle_));
     CUDA_CHECK(cudaStreamDestroy(stream_));
     CUDA_CHECK(cudaEventDestroy(after_));
@@ -162,7 +162,7 @@ class CuDNNWrapper {
         state_idx < CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES, "Invalid state_idx");
     auto& sync_state = cudnn_states()[context_->device_id()][state_idx];

-    DeviceGuard dg(context_->device_id());
+    CUDAGuard dg(context_->device_id());

     // We need to serialize execution on the CuDNNState as we can't
     // allow multiple threads to race through the cudaEventRecord

caffe2/core/event_gpu.cc

Lines changed: 3 additions & 3 deletions
@@ -12,12 +12,12 @@ struct CudaEventWrapper {
         device_id_(option.device_id()),
         status_(EventStatus::EVENT_INITIALIZED) {
     CAFFE_ENFORCE(option.device_type(), PROTO_CUDA);
-    DeviceGuard g(device_id_);
+    CUDAGuard g(device_id_);
     CUDA_ENFORCE(cudaEventCreateWithFlags(
         &cuda_event_, cudaEventDefault | cudaEventDisableTiming));
   }
   ~CudaEventWrapper() {
-    DeviceGuard g(device_id_);
+    CUDAGuard g(device_id_);
     CUDA_CHECK(cudaEventDestroy(cuda_event_));
   }

@@ -96,7 +96,7 @@ void EventFinishCUDA(const Event* event) {

   if (wrapper->status_ == EventStatus::EVENT_SCHEDULED) {
     // ok, even if event is already completed and status was not yet updated
-    DeviceGuard g(wrapper->device_id_);
+    CUDAGuard g(wrapper->device_id_);
     auto cudaResult = cudaEventSynchronize(wrapper->cuda_event_);
     if (cudaResult == cudaSuccess) {
       wrapper->status_ = EventStatus::EVENT_SUCCESS;

caffe2/core/hip/miopen_wrapper.h

Lines changed: 5 additions & 3 deletions
@@ -5,6 +5,8 @@
 #include "caffe2/core/hip/common_miopen.h"
 #include "caffe2/core/hip/context_gpu.h"

+#include <c10/hip/HIPGuard.h>
+
 namespace caffe2 {

 class MIOPENWrapper;
@@ -53,7 +55,7 @@ class MIOPENState
  public:
   explicit MIOPENState(size_t gpu_id) : gpu_id_(gpu_id)
   {
-      DeviceGuard g(gpu_id_);
+      HIPGuard g(gpu_id_);
       MIOPEN_ENFORCE(miopenCreate(&miopen_handle_));
       HIP_ENFORCE(hipEventCreate(&before_));
       HIP_ENFORCE(hipEventCreate(&after_));
@@ -63,7 +65,7 @@ class MIOPENState

   ~MIOPENState() noexcept
   {
-      DeviceGuard g(gpu_id_);
+      HIPGuard g(gpu_id_);
       MIOPEN_CHECK(miopenDestroy(miopen_handle_));
       HIP_CHECK(hipStreamDestroy(stream_));
       HIP_CHECK(hipEventDestroy(after_));
@@ -125,7 +127,7 @@ class MIOPENWrapper
     CAFFE_ENFORCE(state_idx < CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES, "Invalid state_idx");
     auto& sync_state = miopen_states()[context_->device_id()][state_idx];

-    DeviceGuard dg(context_->device_id());
+    HIPGuard dg(context_->device_id());

     // We need to serialize execution on the MIOPENState as we can't
     // allow multiple threads to race through the cudaEventRecord
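The HIP hunks mirror the CUDA ones: <c10/hip/HIPGuard.h> is included and HIPGuard replaces the old guard at each call site. Assuming c10::hip::HIPGuard offers the same interface as its CUDA counterpart (as this diff's usage suggests), the RAII sketch near the top carries over directly (hypothetical function name):

#include <c10/hip/HIPGuard.h>

void workOnHip(int device_index) {
  // Same contract as the CUDA sketch: switch to `device_index`,
  // restore the previous HIP device on scope exit.
  c10::hip::HIPGuard g(device_index);
  // ... HIP work targeting `device_index` ...
}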
