Skip to content

Commit eea2ee6

Browse files
jerryzh168 authored and facebook-github-bot committed
Renaming size() to numel() - 1/17
Summary: Codemod generated with clangr shard mode, 25 files per diff Reviewed By: li-roy Differential Revision: D10866237 fbshipit-source-id: 020fcfdf52083430c5b674eda8e07ad3adfcc838
1 parent 06392bd commit eea2ee6

21 files changed

+132
-107
lines changed

binaries/benchmark_helper.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void writeTextOutput(
4141
caffe2::BlobProto blob_proto;
4242

4343
ser.Serialize(
44-
*tensor, output_name, blob_proto.mutable_tensor(), 0, tensor->size());
44+
*tensor, output_name, blob_proto.mutable_tensor(), 0, tensor->numel());
4545
blob_proto.set_name(output_name);
4646
blob_proto.set_type("Tensor");
4747
CAFFE_ENFORCE(blob_proto.has_tensor());

caffe2/contrib/gloo/allgather_ops.h

+5-5
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ class AllgatherOp final : public Operator<Context> {
7474
CAFFE_ENFORCE_EQ(OutputSize(), 1);
7575
auto comm_size =
7676
OperatorBase::Input<std::shared_ptr<::gloo::Context>>(0)->size;
77-
const auto dims =
78-
std::vector<int64_t>(1, (InputSize() - 1) * Input(1).size() * comm_size);
77+
const auto dims = std::vector<int64_t>(
78+
1, (InputSize() - 1) * Input(1).numel() * comm_size);
7979
Output(0)->Resize(dims);
8080

8181
// Store which inputs/outputs this instance initialized with
@@ -84,9 +84,9 @@ class AllgatherOp final : public Operator<Context> {
8484
CAFFE_ENFORCE_EQ(init_.outputs.size(), 1);
8585

8686
// Verify tensors all have same size
87-
size_t size = Input(1).size();
87+
size_t size = Input(1).numel();
8888
for (auto i = 2; i < InputSize(); i++) {
89-
CAFFE_ENFORCE_EQ(Input(i).size(), size);
89+
CAFFE_ENFORCE_EQ(Input(i).numel(), size);
9090
}
9191

9292
// Verify tensors all have same type
@@ -111,7 +111,7 @@ class AllgatherOp final : public Operator<Context> {
111111
void update(GlooParameters& params) {
112112
params.context = OperatorBase::Input<std::shared_ptr<::gloo::Context>>(0);
113113
params.inputs.resize(InputSize() - 1);
114-
params.size = Input(1).size();
114+
params.size = Input(1).numel();
115115
params.meta = Input(1).meta();
116116
for (auto i = 0; i < params.inputs.size(); i++) {
117117
params.inputs[i] = Input(i + 1).raw_data();

caffe2/contrib/gloo/allreduce_ops.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ class AllreduceOp final : public Operator<Context> {
7171
}
7272

7373
// Verify tensors all have same size
74-
size_t size = Input(1).size();
74+
size_t size = Input(1).numel();
7575
for (auto i = 2; i < InputSize(); i++) {
76-
CAFFE_ENFORCE_EQ(Input(i).size(), size);
76+
CAFFE_ENFORCE_EQ(Input(i).numel(), size);
7777
}
7878

7979
// Verify tensors all have same type
@@ -120,7 +120,7 @@ class AllreduceOp final : public Operator<Context> {
120120
params.inputs[i] = Input(i + 1).raw_data();
121121
params.outputs[i] = Output(i)->raw_mutable_data();
122122
}
123-
params.size = Output(0)->size();
123+
params.size = Output(0)->numel();
124124
params.meta = Output(0)->meta();
125125
}
126126

caffe2/contrib/gloo/broadcast_ops.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ class BroadcastOp final : public Operator<Context> {
6565
}
6666

6767
// Verify tensors all have same size
68-
size_t size = Input(1).size();
68+
size_t size = Input(1).numel();
6969
for (auto i = 2; i < InputSize(); i++) {
70-
CAFFE_ENFORCE_EQ(Input(i).size(), size);
70+
CAFFE_ENFORCE_EQ(Input(i).numel(), size);
7171
}
7272

7373
// Verify tensors all have same size
@@ -98,7 +98,7 @@ class BroadcastOp final : public Operator<Context> {
9898
params.inputs[i] = Input(i + 1).raw_data();
9999
params.outputs[i] = Output(i)->raw_mutable_data();
100100
}
101-
params.size = Output(0)->size();
101+
params.size = Output(0)->numel();
102102
params.meta = Output(0)->meta();
103103
}
104104

caffe2/contrib/gloo/reduce_scatter_ops.h

+5-5
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ class ReduceScatterOp final : public Operator<Context> {
8080
}
8181

8282
// Verify tensors all have same size
83-
size_t size = Input(1).size();
83+
size_t size = Input(1).numel();
8484
for (auto i = 2; i < InputSize() - 1; i++) {
85-
CAFFE_ENFORCE_EQ(Input(i).size(), size);
85+
CAFFE_ENFORCE_EQ(Input(i).numel(), size);
8686
}
8787

8888
// Verify tensors all have same type
@@ -111,13 +111,13 @@ class ReduceScatterOp final : public Operator<Context> {
111111
params.inputs[i] = Input(i + 1).raw_data();
112112
params.outputs[i] = Output(i)->raw_mutable_data();
113113
}
114-
params.size = Output(0)->size();
114+
params.size = Output(0)->numel();
115115
params.meta = Output(0)->meta();
116116

117117
// Verify recvCountsSize == comm_size
118-
CAFFE_ENFORCE_EQ(Input(InputSize() - 1).size(), params.context->size);
118+
CAFFE_ENFORCE_EQ(Input(InputSize() - 1).numel(), params.context->size);
119119
int* recvCounts = (int*)Input(InputSize() - 1).raw_data();
120-
recvCounts_.assign(recvCounts, recvCounts + Input(InputSize() - 1).size());
120+
recvCounts_.assign(recvCounts, recvCounts + Input(InputSize() - 1).numel());
121121
}
122122

123123
GlooParameters init_;

caffe2/contrib/nccl/cuda_nccl_gpu.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ void NCCL<T>::AllReduce(const NCCLExecution& ex) {
212212
CAFFE_NCCL_CHECK(ncclAllReduce(
213213
ctx.src->raw_data(),
214214
ctx.dst->raw_mutable_data(),
215-
ctx.dst->size(),
215+
ctx.dst->numel(),
216216
ncclTypeWrapper<T>::type,
217217
ncclSum,
218218
comm,
@@ -231,7 +231,7 @@ void NCCL<T>::Broadcast(const NCCLExecution& ex) {
231231
[&ex](const NCCLElement& ctx, ncclComm_t comm, cudaStream_t stream) {
232232
CAFFE_NCCL_CHECK(ncclBcast(
233233
ctx.dst->raw_mutable_data(),
234-
ctx.dst->size(),
234+
ctx.dst->numel(),
235235
ncclTypeWrapper<T>::type,
236236
ex.root,
237237
comm,
@@ -315,7 +315,7 @@ void NCCL<T>::ReduceScatter(const NCCLExecution& ex) {
315315
CAFFE_NCCL_CHECK(ncclReduceScatter(
316316
ctx.src->raw_data(),
317317
ctx.dst->raw_mutable_data(),
318-
ctx.dst->size(),
318+
ctx.dst->numel(),
319319
ncclTypeWrapper<T>::type,
320320
ncclSum,
321321
comm,

caffe2/contrib/nnpack/nnpack_ops.cc

+7-7
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class NNPACKConvOp final : public ConvPoolOpBase<CPUContext> {
131131
CAFFE_ENFORCE(filter.dim32(1) == C / this->group_, "");
132132
CAFFE_ENFORCE(filter.dim32(2) == this->kernel_h(), "");
133133
CAFFE_ENFORCE(filter.dim32(3) == this->kernel_w(), "");
134-
CAFFE_ENFORCE(bias.size() == M, "");
134+
CAFFE_ENFORCE(bias.numel() == M, "");
135135

136136
ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
137137
const int oH = Y->dim32(2), oW = Y->dim32(3);
@@ -180,8 +180,8 @@ class NNPACKConvOp final : public ConvPoolOpBase<CPUContext> {
180180
kernel_size,
181181
output_subsample,
182182
X.template data<float>() + g * H * W * (C / group_),
183-
filter.template data<float>() + filter.size() / group_ * g,
184-
bias.template data<float>() + bias.size() / group_ * g,
183+
filter.template data<float>() + filter.numel() / group_ * g,
184+
bias.template data<float>() + bias.numel() / group_ * g,
185185
Y->template mutable_data<float>() + g * oH * oW * (M / group_),
186186
nnpack_threadpool(),
187187
nullptr);
@@ -199,8 +199,8 @@ class NNPACKConvOp final : public ConvPoolOpBase<CPUContext> {
199199
padding,
200200
kernel_size,
201201
X.template data<float>() + g * H * W * (C / group_),
202-
filter.template data<float>() + filter.size() / group_ * g,
203-
bias.template data<float>() + bias.size() / group_ * g,
202+
filter.template data<float>() + filter.numel() / group_ * g,
203+
bias.template data<float>() + bias.numel() / group_ * g,
204204
Y->template mutable_data<float>() + g * oH * oW * (M / group_),
205205
nnpack_threadpool(),
206206
nullptr);
@@ -306,7 +306,7 @@ class NNPACKReluOp final : public Operator<CPUContext> {
306306
auto* Y = Output(0);
307307
const auto status = nnp_relu_output(
308308
1,
309-
X.size(),
309+
X.numel(),
310310
X.template data<float>(),
311311
Y->template mutable_data<float>(),
312312
0.0,
@@ -332,7 +332,7 @@ class NNPACKLeakyReluOp final : public LeakyReluOp<float, CPUContext> {
332332
auto* Y = Output(0);
333333
const auto status = nnp_relu_output(
334334
1,
335-
X.size(),
335+
X.numel(),
336336
X.template data<float>(),
337337
Y->template mutable_data<float>(),
338338
alpha_,

caffe2/contrib/warpctc/ctc_op.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class CTCOp final : public Operator<Context> {
9696
if (is_test_ && labels.dim(0) == 0) {
9797
// compute_ctc_loss doesn't handle empty labels well
9898
T* costsData = costs->template mutable_data<T>();
99-
for (int i = 0; i < costs->size(); ++i) {
99+
for (int i = 0; i < costs->numel(); ++i) {
100100
costsData[i] = 0;
101101
}
102102
return true;

caffe2/core/blob_gpu_test.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ TYPED_TEST_CASE(TensorGPUDeathTest, TensorTypes);
1919
TYPED_TEST(TensorGPUTest, TensorInitializedEmpty) {
2020
if (!caffe2::HasCudaGPU()) return;
2121
Tensor tensor(CUDA);
22-
EXPECT_EQ(tensor.size(), 0);
22+
EXPECT_EQ(tensor.numel(), 0);
2323
EXPECT_EQ(tensor.ndim(), 1);
2424
vector<int> dims(3);
2525
dims[0] = 2;
@@ -119,7 +119,7 @@ TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) {
119119
::testing::FLAGS_gtest_death_test_style = "threadsafe";
120120
Tensor tensor(CUDA);
121121
EXPECT_EQ(tensor.ndim(), 1);
122-
EXPECT_EQ(tensor.size(), 0);
122+
EXPECT_EQ(tensor.numel(), 0);
123123
EXPECT_THROW(tensor.data<TypeParam>(), EnforceNotMet);
124124
}
125125

caffe2/core/blob_serialization.cc

+11-11
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ void TensorSerializer::SerializeWithChunkSize(
119119
CAFFE_ENFORCE(typeMeta.Match<Tensor>());
120120
const auto& tensor = *static_cast<const Tensor*>(pointer);
121121
if (chunk_size == kNoChunking) {
122-
chunk_size = tensor.size() + 1; // to account for empty tensors
122+
chunk_size = tensor.numel() + 1; // to account for empty tensors
123123
} else if (chunk_size == kDefaultChunkSize) {
124124
chunk_size = FLAGS_caffe2_tensor_chunk_size;
125125
}
@@ -147,7 +147,7 @@ void TensorSerializer::SerializeWithChunkSize(
147147
processChunk(chunkStart);
148148
}
149149
};
150-
if (tensor.size() > chunk_size) {
150+
if (tensor.numel() > chunk_size) {
151151
for (int i = 0; i < FLAGS_caffe2_max_tensor_serializer_threads; ++i) {
152152
futures.emplace_back(std::async(std::launch::async, task));
153153
}
@@ -158,11 +158,11 @@ void TensorSerializer::SerializeWithChunkSize(
158158
// Serialize whole vector. If vector is empty, it's shape still needs to be
159159
// serialized in empty proto
160160
for (size_t chunkBegin = 0;
161-
chunkBegin < std::max(tensor.size(), static_cast<int64_t>(1));
161+
chunkBegin < std::max(tensor.numel(), static_cast<int64_t>(1));
162162
chunkBegin += chunk_size) {
163163
VLOG(2) << "Starting a chunk at " << chunkBegin;
164164
#ifndef __ANDROID__
165-
if (tensor.size() > chunk_size) {
165+
if (tensor.numel() > chunk_size) {
166166
chunkQueue.Push(chunkBegin);
167167
} else {
168168
// Sync mode for small tensors
@@ -189,13 +189,13 @@ void TensorSerializer::Serialize(
189189
size_t chunkBegin,
190190
int32_t chunkSize) {
191191
CAFFE_ENFORCE(
192-
chunkBegin <= input.size(),
192+
chunkBegin <= input.numel(),
193193
"Chunk begin is out of tensor: ",
194194
chunkBegin,
195195
' ',
196-
input.size());
197-
if (chunkBegin + chunkSize > input.size()) {
198-
chunkSize = input.size() - chunkBegin;
196+
input.numel());
197+
if (chunkBegin + chunkSize > input.numel()) {
198+
chunkSize = input.numel() - chunkBegin;
199199
}
200200

201201
if (chunkSize != 0) {
@@ -408,19 +408,19 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
408408
tensor->Resize(dims);
409409

410410
int64_t chunkBegin = 0;
411-
auto chunkEnd = tensor->size();
411+
auto chunkEnd = tensor->numel();
412412
if (proto.has_segment()) {
413413
chunkBegin = proto.segment().begin();
414414
chunkEnd = proto.segment().end();
415415
}
416416
CAFFE_ENFORCE(
417-
0 <= chunkBegin && chunkBegin <= chunkEnd && chunkEnd <= tensor->size(),
417+
0 <= chunkBegin && chunkBegin <= chunkEnd && chunkEnd <= tensor->numel(),
418418
"Invalid chunk ",
419419
chunkBegin,
420420
' ',
421421
chunkEnd,
422422
" with total tensor size ",
423-
tensor->size());
423+
tensor->numel());
424424
auto chunkSize = chunkEnd - chunkBegin;
425425

426426
switch (proto.data_type()) {

0 commit comments

Comments (0)