Skip to content

Commit d7e11e3

Browse files
ezyangfacebook-github-bot
authored andcommitted
Revert "Move CreateContext to global registry (pytorch#11688)" (pytorch#12049)
Summary: This reverts commit 3ae6ee4. Pull Request resolved: pytorch#12049 Differential Revision: D10030954 Pulled By: ezyang fbshipit-source-id: 6ca9de65b707c5b4c68280fc6f1b8e5ad7251efc
1 parent 3deb479 commit d7e11e3

21 files changed

+94
-122
lines changed

aten/src/ATen/core/context_base.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,5 @@
11
#include <ATen/core/context_base.h>
22

3-
namespace at {
4-
5-
AT_DEFINE_TYPED_REGISTRY(
6-
ContextRegistry,
7-
DeviceType,
8-
BaseContext,
9-
std::unique_ptr,
10-
at::Device);
11-
12-
} // namespace at
13-
143
namespace caffe2 {
154

165
// TODO: rename context.h -> context_cpu.h & context_base.h -> context.h

aten/src/ATen/core/context_base.h

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66
#include <memory>
77
#include <unordered_map>
88

9-
#include <ATen/core/ATenGeneral.h>
10-
#include <ATen/core/Device.h>
9+
#include <ATen/core/DeviceType.h>
1110
#include <ATen/core/Error.h>
12-
#include <ATen/core/Registry.h>
1311
#include <ATen/core/UniqueVoidPtr.h>
1412
#include <ATen/core/typeid.h>
13+
#include <ATen/core/ATenGeneral.h>
1514

1615
namespace caffe2 {
1716
class Event;
@@ -32,6 +31,11 @@ class AT_CORE_API BaseStaticContext {
3231

3332
virtual std::pair<void*, DeleterFnPtr> New(size_t nbytes) const = 0;
3433

34+
virtual std::unique_ptr<BaseContext> CreateContext() = 0;
35+
36+
virtual std::unique_ptr<BaseContext> CreateContext(
37+
const caffe2::DeviceOption&) = 0;
38+
3539
virtual DeviceType GetDeviceType() = 0;
3640

3741
/*
@@ -180,22 +184,6 @@ class AT_CORE_API BaseContext {
180184
}
181185
};
182186

183-
// Context constructor registry
184-
AT_DECLARE_TYPED_REGISTRY(
185-
ContextRegistry,
186-
at::DeviceType,
187-
BaseContext,
188-
std::unique_ptr,
189-
at::Device);
190-
191-
#define REGISTER_CONTEXT(type, ...) \
192-
AT_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__)
193-
194-
inline std::unique_ptr<at::BaseContext> CreateContext(
195-
const at::Device& device) {
196-
return ContextRegistry()->Create(device.type(), device);
197-
}
198-
199187
} // namespace at
200188

201189
namespace caffe2 {

caffe2/core/blob_serialization.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ void TensorSerializer::Serialize(
196196
const TensorProto::DataType data_type = TypeMetaToDataType(input.meta());
197197
proto.set_data_type(data_type);
198198
StoreDeviceDetail(input, &proto);
199-
auto uniq_ptr = CreateContext(input.GetDevice());
199+
auto uniq_ptr = input.GetStaticContext()->CreateContext();
200200
// A lot of copypaste is error prone. Should we create a macro for this?
201201
switch (data_type) {
202202
case TensorProto_DataType_FLOAT:
@@ -370,7 +370,8 @@ void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) {
370370
void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
371371
// We create a local context for deserializing. Since Caffe2 contexts are
372372
// usually lightweight, this should not involve too much overhead.
373-
auto uniq_ptr = CreateContext(OptionToDevice(proto.device_detail()));
373+
auto uniq_ptr =
374+
tensor->GetStaticContext()->CreateContext(proto.device_detail());
374375
auto context = uniq_ptr.get();
375376
context->SwitchToDevice(0);
376377
vector<int64_t> dims;

caffe2/core/context.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@
55
#include <process.h>
66
#endif
77

8-
namespace at {
9-
10-
REGISTER_CONTEXT(DeviceType::CPU, caffe2::CPUContext);
11-
} // namespace at
128
namespace caffe2 {
139

1410
uint32_t RandomNumberSeed() {

caffe2/core/context.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ class CAFFE2_API CPUContext final : public BaseContext {
5050
: RandomNumberSeed()) {
5151
CAFFE_ENFORCE_EQ(option.device_type(), PROTO_CPU);
5252
}
53-
explicit CPUContext(const at::Device& device)
54-
: CPUContext(DeviceToOption(device)) {}
5553

5654
~CPUContext() noexcept override {}
5755

@@ -194,6 +192,15 @@ class CAFFE2_API CPUStaticContext : public BaseStaticContext {
194192
return data_and_deleter;
195193
}
196194

195+
std::unique_ptr<BaseContext> CreateContext() override {
196+
return caffe2::make_unique<CPUContext>();
197+
}
198+
199+
std::unique_ptr<BaseContext> CreateContext(
200+
const DeviceOption& option) override {
201+
return caffe2::make_unique<CPUContext>(option);
202+
}
203+
197204
DeviceType GetDeviceType() override {
198205
return CPU;
199206
}

caffe2/core/context_base.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "context_base.h"
22

33
namespace caffe2 {
4-
54
} // namespace caffe2

caffe2/core/context_base.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,3 @@
55
#include "caffe2/core/common.h"
66
#include "caffe2/core/logging.h"
77
#include "caffe2/proto/caffe2_pb.h"
8-
9-
namespace caffe2 {} // namespace caffe2

caffe2/core/context_gpu.cu

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,6 @@ CAFFE2_DEFINE_int(
5757
128,
5858
"The threshold in MB on how frequently to report memory changes");
5959

60-
namespace at {
61-
62-
REGISTER_CONTEXT(DeviceType::CUDA, caffe2::CUDAContext);
63-
} // namespace at
64-
6560
namespace caffe2 {
6661

6762
ThreadLocalCUDAObjects& CUDAContext::getCudaObjects() {

caffe2/core/context_gpu.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,6 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
142142
// The default cuda context constructor.
143143
explicit CUDAContext(const int gpu_id = -1);
144144
explicit CUDAContext(const DeviceOption& option);
145-
explicit CUDAContext(const at::Device& device)
146-
: CUDAContext(DeviceToOption(device)) {}
147145

148146
~CUDAContext() override {
149147
if (curand_generator_) {
@@ -387,6 +385,19 @@ class CAFFE2_CUDA_API CUDAStaticContext final : public BaseStaticContext {
387385
public:
388386
std::pair<void*, MemoryDeleter> New(size_t nbytes) const override;
389387

388+
std::unique_ptr<BaseContext> CreateContext() override {
389+
return caffe2::make_unique<CUDAContext>();
390+
}
391+
392+
std::unique_ptr<BaseContext> CreateContext(
393+
const DeviceOption& option) override {
394+
return caffe2::make_unique<CUDAContext>(option);
395+
}
396+
397+
std::unique_ptr<BaseContext> CreateContext(int gpu_id = -1) {
398+
return caffe2::make_unique<CUDAContext>(gpu_id);
399+
}
400+
390401
DeviceType GetDeviceType() override {
391402
return CUDA;
392403
}

caffe2/core/hip/context_hip.cc

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,6 @@ CAFFE2_DEFINE_int(caffe2_gpu_memory_report_interval_mb,
5050
128,
5151
"The threshold in MB on how frequently to report memory changes");
5252

53-
namespace at {
54-
55-
REGISTER_CONTEXT(DeviceType::HIP, caffe2::HIPContext);
56-
} // namespace at
57-
5853
namespace caffe2 {
5954

6055
thread_local ThreadLocalHIPObjects HIPContext::hip_objects_;
@@ -413,12 +408,13 @@ void HIPStaticContext::Delete(void* ptr) {
413408
g_hip_device_affiliation.erase(it);
414409
break;
415410
}
416-
case HipMemoryPoolType::THC: {
417-
HIP_ENFORCE(g_thc_allocator->Free(ptr));
418-
if (FLAGS_caffe2_gpu_memory_tracking) {
419-
g_hip_device_affiliation.erase(g_hip_device_affiliation.find(ptr));
420-
}
421-
break;
411+
case HipMemoryPoolType::THC:
412+
{
413+
HIP_ENFORCE(g_thc_allocator->Free(ptr));
414+
if (FLAGS_caffe2_gpu_memory_tracking) {
415+
g_hip_device_affiliation.erase(g_hip_device_affiliation.find(ptr));
416+
}
417+
break;
422418
}
423419
}
424420
}

0 commit comments

Comments
 (0)