
Commit aebf3b4

jerryzh168 authored and facebook-github-bot committed
Remove template parameter from Tensor (pytorch#9939)
Summary:
Pull Request resolved: pytorch#9939
Pull Request resolved: https://github.com/facebookresearch/weakly-supervised-action-detection/pull/13
Pull Request resolved: pytorch/translate#166
Pull Request resolved: pytorch#9125
Closes pytorch#9125

Use inheritance for polymorphism, and remove the template parameter. This diff changes the templating at call sites; the core implementations will change later.

Before, the Caffe2 Tensor class was fixed at compile time to bind to a particular device/context. With this change, the device becomes a runtime property (stored inside the tensor), but the semantics are preserved. For example, one still has to specify a device type in order to create a Tensor - there are no uninitialized tensors.

More specifically, the changes are:

1. An extra *DeviceType* argument is added to most of the Tensor constructors, e.g. Tensor(DeviceType type).
2. The semantics of the constructor Tensor(const Tensor<SrcContext>& src, ContextForCopy* context) have changed. The second context is passed in to let us call the templated Copy function. Previously it could be for a different device than the source and target; now, if it is provided, we enforce that it has the same device type as src.
3. To preserve the get-or-construct semantics of Blob, a specialized getter Blob::GetMutableTensor is added that verifies both that the Blob contains a Tensor and that the Tensor has the correct device type.
4. Tensor is no longer default-constructible (as there are no unknown-device tensors), so some of the code handling STL containers needs to change.

Note: Some changes are postponed just to keep this diff a bit smaller. Please see `TODO`s.

Reviewed By: ezyang, houseroad

Differential Revision: D9024330

fbshipit-source-id: e0b8295d2dc6ebe2963383ded5af799ad17164ba
1 parent 94439d7 commit aebf3b4
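Editor's note: to make the call-site migration concrete, here is a minimal sketch of the old pattern against the new one, assuming this commit's headers and API; the function `migration_sketch` and its blob argument are hypothetical, and the pre-commit lines are kept as comments.

#include "caffe2/core/blob.h"
#include "caffe2/core/tensor.h"

using namespace caffe2;

void migration_sketch(Blob* blob) {
  // Before: device bound at compile time by the template parameter.
  //   Tensor<CPUContext> t;                        // was default-constructible
  //   auto* tp = blob->GetMutable<TensorCPU>();    // old blob getter

  // After: device is a runtime property and must be supplied.
  Tensor t(CPU);                                    // no default constructor any more
  Tensor t_cuda(vector<TIndex>{1, 2, 3, 4}, CUDA);  // dims plus device
  auto* tp = blob->GetMutableTensor(CPU);           // get-or-construct with device check
  tp->Resize(32, 32);
}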

File tree

365 files changed: +4187 -3515 lines


binaries/benchmark_helper.cc

+3 -3

@@ -160,7 +160,7 @@ void loadInput(
       CAFFE_THROW("Not support GPU on mobile.");
 #endif
     } else {
-      caffe2::TensorCPU* tensor = blob->GetMutable<caffe2::TensorCPU>();
+      caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
       CHECK_NOTNULL(tensor);
       tensor->Resize(input_dims);
       if (input_type_list[i] == "uint8_t") {
@@ -197,7 +197,7 @@ void fillInputBlob(
     int protos_size = tensor_kv.second.protos_size();
     caffe2::TensorProto* tensor_proto =
         tensor_kv.second.mutable_protos(iteration % protos_size);
-    caffe2::TensorCPU* tensor = blob->GetMutable<caffe2::TensorCPU>();
+    caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
     tensor->Resize(std::vector<caffe2::TIndex>());
     if (tensor_proto->data_type() == caffe2::TensorProto::STRING) {
       (tensor->mutable_data<std::string>())[0] = tensor_proto->string_data(0);
@@ -290,7 +290,7 @@ void writeOutput(
 #endif
   } else {
     writeTextOutput<caffe2::CPUContext, caffe2::TensorCPU>(
-        workspace->GetBlob(name)->GetMutable<caffe2::TensorCPU>(),
+        workspace->GetBlob(name)->GetMutableTensor(caffe2::CPU),
         output_prefix,
         name);
   }

binaries/benchmark_helper.h

+1 -1

@@ -35,7 +35,7 @@ void writeTextOutput(
     const string& output_prefix,
     const string& name) {
   string output_name = output_prefix + "/" + name + ".txt";
-  caffe2::TensorSerializer<ContextType> ser;
+  caffe2::TensorSerializer ser;
   caffe2::BlobProto blob_proto;
   ser.Serialize(
       *tensor, output_name, blob_proto.mutable_tensor(), 0, tensor->size());

binaries/core_overhead_benchmark.cc

+3 -3

@@ -139,7 +139,7 @@ BENCHMARK(BM_cudaStreamWaitEventThenStreamSynchronize);

 static void BM_CudaPointerAffinity(benchmark::State& state) {
   CAFFE2_SKIP_IF_NO_GPU;
-  TensorCUDA tensor(vector<TIndex>{1, 2, 3, 4});
+  Tensor tensor(vector<TIndex>{1, 2, 3, 4}, CUDA);
   float* ptr = tensor.mutable_data<float>();
   while (state.KeepRunning()) {
     volatile int id = GetGPUIDForPointer(ptr);
@@ -198,7 +198,7 @@ static void BM_RawAllocDeallocCPU(benchmark::State& state) {
 BENCHMARK(BM_RawAllocDeallocCPU);

 static void BM_TensorAllocDeallocCPU(benchmark::State& state) {
-  Tensor<CPUContext> tensor;
+  Tensor tensor(CPU);
   // small allocation
   tensor.Resize(32, 32);
   while (state.KeepRunning()) {
@@ -210,7 +210,7 @@ BENCHMARK(BM_TensorAllocDeallocCPU);

 static void BM_TensorAllocDeallocCUDA(benchmark::State& state) {
   CAFFE2_SKIP_IF_NO_GPU;
-  Tensor<CUDAContext> tensor;
+  Tensor tensor(CUDA);
   // small allocation
   tensor.Resize(32, 32);
   while (state.KeepRunning()) {

binaries/print_core_object_sizes.cc

+1 -2

@@ -28,8 +28,7 @@

 int main(int /* unused */, char** /* unused */) {
   PRINT_SIZE(caffe2::Blob);
-  PRINT_SIZE(caffe2::Tensor<caffe2::CPUContext>);
-  PRINT_SIZE(caffe2::Tensor<caffe2::CUDAContext>);
+  PRINT_SIZE(caffe2::Tensor);
   PRINT_SIZE(caffe2::CPUContext);
   PRINT_SIZE(caffe2::CUDAContext);
   PRINT_SIZE(caffe2::OperatorBase);

binaries/speed_benchmark.cc

+1 -1

@@ -136,7 +136,7 @@ int main(int argc, char** argv) {
     if (blob == nullptr) {
       blob = workspace->CreateBlob(input_names[i]);
     }
-    caffe2::TensorCPU* tensor = blob->GetMutable<caffe2::TensorCPU>();
+    caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
     CHECK_NOTNULL(tensor);
     tensor->Resize(input_dims);
     if (input_type_list[i] == "uint8_t") {

caffe2/contrib/aten/aten_op_template.h

+7 -7

@@ -54,11 +54,11 @@ class ATenOp : public Operator<Context> {
 #undef DEFINE_CASE
   }

-  at::Type & typeFor(const Tensor<Context> & ten) {
+  at::Type& typeFor(const Tensor& ten) {
     return at::getType(backend(), atScalarTypeFor(ten.meta()));
   }
-  at::Tensor tensorWrapping(const Tensor<Context>& ten_) {
-    auto& ten = const_cast<Tensor<Context>&>(ten_);
+  at::Tensor tensorWrapping(const Tensor& ten_) {
+    auto& ten = const_cast<Tensor&>(ten_);
     return typeFor(ten).tensorFromBlob(ten.raw_mutable_data(), ten.dims());
   }

@@ -88,7 +88,7 @@ class ATenOp : public Operator<Context> {
    }
    CAFFE_THROW("Unknown type meta"); // TODO: improve error message...
  }
-  void assignTo(Tensor<Context> * dst, const at::Tensor & src_) {
+  void assignTo(Tensor* dst, const at::Tensor& src_) {
    at::Tensor src = src_.contiguous();
    auto at_sizes = src.sizes();
    std::vector<int64_t> dims(at_sizes.begin(),at_sizes.end());
@@ -121,7 +121,7 @@ class ATenOp : public Operator<Context> {
     return s.toLong();
   }

-  void assignTo(Tensor<Context> * dst, at::Type & inferred_type, at::Scalar scalar) {
+  void assignTo(Tensor* dst, at::Type& inferred_type, at::Scalar scalar) {
     switch(inferred_type.scalarType()) {
 #define DEFINE_CASE(ctype,aten_name,native) \
       case at::k##aten_name: { \
@@ -134,8 +134,8 @@ class ATenOp : public Operator<Context> {
       CAFFE_THROW("Unknown ATen Type");
     }
   }
-  template<typename T>
-  void assignToValue(Tensor<Context> * dst, T v) {
+  template <typename T>
+  void assignToValue(Tensor* dst, T v) {
     dst->Resize(std::vector<TIndex>());
     math::Set(1, v, dst->template mutable_data<T>(), &context_);
   }

caffe2/contrib/gloo/common.cc

+1 -1

@@ -12,7 +12,7 @@ namespace caffe2 {
 namespace gloo {

 void signalFailure(Blob* status_blob, std::exception& /* unused */) {
-  auto* res = status_blob->GetMutable<TensorCPU>();
+  auto* res = status_blob->GetMutableTensor(CPU);
   res->Resize(1);
   res->template mutable_data<int32_t>()[0] = 1;
 }

caffe2/contrib/nccl/cuda_nccl_op_gpu.cc

+5 -5

@@ -17,17 +17,17 @@ nccl::NCCLExecution getNCCLElements(
   ex.elements.resize(op->InputSize());
   for (auto i = 0; i < op->InputSize(); ++i) {
     auto& el = ex.elements[i];
-    el.src = &(op->Input<TensorCUDA>(i));
+    el.src = &(op->Input<Tensor>(i, CUDA));
     if (op->OutputSize() == 1) {
       // Reduce op
       if (i == ex.root) {
-        el.dst = op->Output<TensorCUDA>(0);
+        el.dst = op->Output<Tensor>(0, CUDA);
       }
     } else if (i < op->OutputSize()) {
-      el.dst = op->Output<TensorCUDA>(i);
+      el.dst = op->Output<Tensor>(i, CUDA);
     }
     // TODO - expensive (>1ms) - cache these.
-    el.device = GetGPUIDForPointer(op->Input<TensorCUDA>(i).raw_data());
+    el.device = GetGPUIDForPointer(op->Input<Tensor>(i, CUDA).raw_data());
   }

   return ex;
@@ -38,7 +38,7 @@ namespace {
 template <typename T>
 bool AllInputsAre(OperatorBase* op) {
   for (auto i = 0; i < op->InputSize(); ++i) {
-    if (op->Input<TensorCUDA>(i).IsType<T>()) {
+    if (op->Input<Tensor>(i, CUDA).IsType<T>()) {
       continue;
     } else {
       return false;

caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc

+3 -3

@@ -22,7 +22,7 @@ static void AddConstInput(const std::vector<int>& shape, const float value,
   option.set_device_type(CUDA);
   CUDAContext context(option);
   Blob* blob = ws->CreateBlob(name);
-  auto* tensor = blob->GetMutable<Tensor<CUDAContext>>();
+  auto* tensor = blob->GetMutableTensor(CUDA);
   tensor->Resize(shape);
   math::Set<float, CUDAContext>(tensor->size(), value,
                                 tensor->mutable_data<float>(),
@@ -54,8 +54,8 @@ TEST(NervanaFullyConnectedTest, Test) {
   EXPECT_TRUE(op->Run());
   Blob* Yblob = ws.GetBlob("Y");
   EXPECT_NE(nullptr, Yblob);
-  auto& Y = Yblob->Get<Tensor<CUDAContext>>();
-  TensorCPU Y_cpu(Y);
+  auto& Y = Yblob->Get<Tensor>();
+  Tensor Y_cpu(Y, CPU);
  EXPECT_EQ(Y.size(), 5 * 6);
   for (int i = 0; i < Y.size(); ++i) {
     CHECK_LT(Y_cpu.data<float>()[i], 10.11);
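Editor's note: the test above also shows the new cross-device copy construction: `Tensor Y_cpu(Y, CPU)` replaces `TensorCPU Y_cpu(Y)`, so the destination device is now passed at runtime instead of being encoded in the TensorCPU type. A hedged sketch, assuming this commit's API; the function and tensor names are hypothetical.

#include "caffe2/core/tensor.h"

using namespace caffe2;

void copy_sketch() {
  Tensor src(vector<TIndex>{2, 3}, CUDA); // device passed at runtime
  src.mutable_data<float>();              // allocate storage on the GPU
  // Before: TensorCPU dst(src);          // destination device in the type
  Tensor dst(src, CPU);                   // destination device as an argument
}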

caffe2/contrib/warpctc/ctc_op.h

+7 -7

@@ -47,26 +47,26 @@ class CTCOp final : public Operator<Context> {
     const auto& inputs = Input(INPUTS);
     const auto minibatchSize = inputs.dim(1);
     const auto alphabetSize = inputs.dim(2);
-    const auto& labels = OperatorBase::template Input<TensorCPU>(LABELS);
+    const auto& labels = OperatorBase::template Input<Tensor>(LABELS, CPU);
     const auto& labelLengths =
-        OperatorBase::template Input<TensorCPU>(LABEL_LENGTHS);
+        OperatorBase::template Input<Tensor>(LABEL_LENGTHS, CPU);
     const auto& inputLengths =
-        OperatorBase::template Input<TensorCPU>(INPUT_LENGTHS);
+        OperatorBase::template Input<Tensor>(INPUT_LENGTHS, CPU);

     // outputs
-    Tensor<Context>* gradients = nullptr;
+    Tensor* gradients = nullptr;
     TensorCPU* costs;
-    Tensor<Context>* workspace;
+    Tensor* workspace;
     if (!is_test_) {
       // [grads, costs, workspace] to maintain backward compatibility
       gradients = Output(0);
       gradients->ResizeLike(inputs);
-      costs = OperatorBase::template Output<TensorCPU>(1);
+      costs = OperatorBase::template Output<Tensor>(1, CPU);
       costs->ResizeLike(labelLengths);
       workspace = Output(2);
     } else {
       // [costs, workspace]
-      costs = OperatorBase::template Output<TensorCPU>(0);
+      costs = OperatorBase::template Output<Tensor>(0, CPU);
       costs->ResizeLike(labelLengths);
       workspace = Output(1);
     }
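Editor's note: the accessor migration above recurs in most operator code in this commit: Input/Output calls templated on TensorCPU or TensorCUDA become Input<Tensor>/Output<Tensor> with an explicit DeviceType argument. A hedged sketch inside a hypothetical operator; `ExampleOp` and its input/output layout are made up for illustration.

#include "caffe2/core/operator.h"

namespace caffe2 {

// Hypothetical operator, for illustration only.
template <class Context>
class ExampleOp final : public Operator<Context> {
 public:
  using Operator<Context>::Operator;
  bool RunOnDevice() override {
    // Before: OperatorBase::template Input<TensorCPU>(0)
    const auto& in = OperatorBase::template Input<Tensor>(0, CPU);
    // Before: OperatorBase::template Output<TensorCPU>(0)
    auto* out = OperatorBase::template Output<Tensor>(0, CPU);
    out->ResizeLike(in);
    return true;
  }
};

} // namespace caffe2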

caffe2/core/allocator.cc

+1 -1

@@ -26,7 +26,7 @@ void SetCPUAllocator(CPUAllocator* alloc) {
   g_cpu_allocator.reset(alloc);
 }

-MemoryAllocationReporter CPUContext::reporter_;
+MemoryAllocationReporter CPUStaticContext::reporter_;

 void MemoryAllocationReporter::New(void* ptr, size_t nbytes) {
   std::lock_guard<std::mutex> guard(mutex_);

caffe2/core/blob.h

+28 -1

@@ -9,8 +9,9 @@

 #include "caffe2/core/blob_serializer_base.h"
 #include "caffe2/core/common.h"
-#include "caffe2/core/typeid.h"
 #include "caffe2/core/logging.h"
+#include "caffe2/core/tensor.h"
+#include "caffe2/core/typeid.h"
 #include "caffe2/proto/caffe2.pb.h"

 namespace caffe2 {
@@ -60,6 +61,20 @@ class Blob {
   template <class T>
   bool IsType() const { return meta_.Match<T>(); }

+  // TODO(jerryzh): Remove template
+  template <class T>
+  bool IsType(DeviceType device_type) const {
+    static_assert(
+        std::is_same<T, Tensor>::value,
+        "IsType(DeviceType) only available on "
+        "Tensor types.");
+    auto* tensor = static_cast<Tensor*>(pointer_);
+    if (tensor && tensor->GetDeviceType() == device_type) {
+      return true;
+    }
+    return false;
+  }
+
   /**
    * Returns the meta info of the blob.
    */
@@ -74,6 +89,7 @@ class Blob {
    * @brief Gets the const reference of the stored object. The code checks if
    * the stored object is of the desired type.
    */
+  // TODO(jerryzh): add a Get(DeviceType) function?
   template <class T>
   const T& Get() const {
     CAFFE_ENFORCE(
@@ -123,6 +139,17 @@ class Blob {
     }
   }

+  inline Tensor* GetMutableTensor(DeviceType device_type) {
+    if (IsType<Tensor>() &&
+        static_cast<Tensor*>(pointer_)->GetDeviceType() == device_type) {
+      return static_cast<Tensor*>(pointer_);
+    } else {
+      VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
+              << " DeviceType:" << device_type;
+      return Reset<Tensor>(new Tensor(device_type));
+    }
+  }
+
   /**
    * Sets the underlying object to the allocated one. The Blob then takes over
    * the ownership of the passed in pointer. If there is already an object in
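Editor's note: taken together, the two additions above give Blob a device-aware get-or-construct path. A minimal usage sketch, assuming this commit's API; the workspace and blob name are hypothetical.

#include "caffe2/core/workspace.h"

using namespace caffe2;

void blob_api_sketch(Workspace* ws) {
  Blob* blob = ws->CreateBlob("example");

  // Get-or-construct: returns the stored Tensor if it already lives on CPU,
  // otherwise resets the blob to a fresh Tensor(CPU).
  Tensor* t = blob->GetMutableTensor(CPU);
  t->Resize(4, 4);
  t->mutable_data<float>();

  // The device-aware type check added in this commit.
  bool is_cpu_tensor = blob->IsType<Tensor>(CPU);   // true here
  bool is_cuda_tensor = blob->IsType<Tensor>(CUDA); // false here
}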
