
Commit e5e0bf4

Yinghai Lu authored and facebook-github-bot committed
Add AdjustBatch Op (pytorch#16676)
Summary:
Pull Request resolved: pytorch#16676

This op is used for changing the batch size (first dimension) of a tensor.

Reviewed By: bertmaher, ipiszy
Differential Revision: D13929200
fbshipit-source-id: 4f2c3faec072d468be8301bf00c80d33adb3b5b3
1 parent 100aa07 commit e5e0bf4

File tree

caffe2/operators/adjust_batch_op.cc
caffe2/operators/adjust_batch_op.h
caffe2/python/operator_test/adjust_batch_op_test.py

3 files changed: +170 -0 lines changed

caffe2/operators/adjust_batch_op.cc

Lines changed: 20 additions & 0 deletions
#include "caffe2/operators/adjust_batch_op.h"

namespace caffe2 {
REGISTER_CPU_OPERATOR(AdjustBatch, AdjustBatchOp<CPUContext>);
OPERATOR_SCHEMA(AdjustBatch)
    .NumInputs(1, 2)
    .NumOutputs(1, 2)
    .Input(0, "Input", "Input data")
    .Input(1, "RealBatchSizeIn", "[Optional] Real batch size")
    .Output(0, "Output", "Data with adjusted batch size")
    .Output(1, "RealBatchSizeOut", "[Optional] Real batch size")
    .Arg("max_batch_size", "(*int*): max batch size")
    .SetDoc(R"DOC(
Adjust the batch size of `input` tensor. When the op has a single input, it
pads the batch dimension up to the `max_batch_size` argument; if it also has
two outputs, it records the original input batch size in the second output.
When the op has two inputs, the second input carries the batch size to adjust
to, and the input data is truncated accordingly.

Github Links:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/adjust_batch_op.cc

)DOC");
} // namespace caffe2
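
The schema above defines two calling conventions: one input (pad to `max_batch_size`, optionally recording the real batch size) and two inputs (truncate to the batch size carried by the second input). A minimal sketch of the padding path, assuming a Caffe2 build with this operator registered; blob names are illustrative:

import numpy as np
from caffe2.python import core, workspace

# Assumes a Caffe2 build where the AdjustBatch op is registered.
# Pad a batch of 3 rows up to max_batch_size=8; the optional second
# output records the original batch size.
workspace.FeedBlob("X", np.random.rand(3, 4).astype(np.float32))
op = core.CreateOperator(
    "AdjustBatch",
    ["X"],
    ["Y", "RealBatch"],
    max_batch_size=8,
)
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("Y").shape)    # (8, 4); rows 3..7 are zero
print(workspace.FetchBlob("RealBatch"))  # [3]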

caffe2/operators/adjust_batch_op.h

Lines changed: 75 additions & 0 deletions
#pragma once

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <class Context>
class AdjustBatchOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  AdjustBatchOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        max_batch_size_(
            this->template GetSingleArgument<int64_t>("max_batch_size", -1)) {}

  bool RunOnDevice() override {
    auto& input = Input(0);
    vector<int64_t> output_dims(input.sizes().vec());
    CAFFE_ENFORCE(!output_dims.empty());
    if (InputSize() > 1) {
      // Truncate to the real batch size carried by the second input.
      // TODO: if we have a second input and we have max_batch_size set, check
      // the batch size of the two inputs for consistency
      auto& batch_size = Input(1);
      int64_t real_batch_size = *batch_size.template data<int64_t>();
      int64_t max_batch_size = output_dims[0];
      CAFFE_ENFORCE_GE(max_batch_size, real_batch_size);
      output_dims[0] = real_batch_size;
      auto* output = Output(0, output_dims, input.dtype());
      // Copy only the elements belonging to the first real_batch_size rows.
      this->context_.template CopyItems<Context, Context>(
          input.dtype(),
          input.numel() * real_batch_size / max_batch_size,
          input.raw_data(),
          output->raw_mutable_data(input.dtype()));
    } else {
      // Pad to max batch size
      CAFFE_ENFORCE_GT(
          max_batch_size_,
          0,
          "max_batch_size should be larger than 0. Got ",
          max_batch_size_);

      // TODO: ideally we can support the case when input batch is larger than
      // the max_batch_size, as we can just pad to the multiple of
      // max_batch_size.
      CAFFE_ENFORCE_GE(max_batch_size_, output_dims.front());

      int64_t real_batch_size = output_dims[0];
      output_dims[0] = max_batch_size_;
      auto* output = Output(0, output_dims, input.dtype());
      // Zero-fill the output, then copy the real rows into the front.
      math::Set(
          output->nbytes(),
          static_cast<char>(0),
          static_cast<char*>(output->raw_mutable_data(input.dtype())),
          &context_);
      this->context_.template CopyItems<Context, Context>(
          input.dtype(),
          input.numel(),
          input.raw_data(),
          output->raw_mutable_data(input.dtype()));

      if (OutputSize() > 1) {
        // Record the original batch size in the second output.
        auto* real_batch_tensor = Output(1, {1}, at::dtype<int64_t>());
        real_batch_tensor->template mutable_data<int64_t>()[0] =
            real_batch_size;
      }
    }

    return true;
  }

 private:
  int64_t max_batch_size_;
};
} // namespace caffe2
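
Both branches copy whole rows: since the batch is the first dimension, `input.numel() * real_batch_size / max_batch_size` in the truncate branch is exactly the element count of the first `real_batch_size` rows, with no remainder. A quick NumPy check of that arithmetic, with illustrative shapes:

import numpy as np

max_batch_size, real_batch_size = 8, 3
x = np.zeros((max_batch_size, 5, 7))

# Element count handed to CopyItems by the truncate branch...
copied = x.size * real_batch_size // max_batch_size

# ...equals the size of the first real_batch_size rows: 3 * 5 * 7.
assert copied == x[:real_batch_size].size == 105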
caffe2/python/operator_test/adjust_batch_op_test.py

Lines changed: 75 additions & 0 deletions
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, workspace
from hypothesis import given, assume
import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np

import unittest
import os


class TestAdjustBatchOp(hu.HypothesisTestCase):
    @given(d=st.integers(1, 4), n=st.integers(1, 20),
           seed=st.integers(0, 1000), **hu.gcs_cpu_only)
    def test_pad(self, d, n, gc, dc, seed):
        for dtype in [np.float32, np.int8, np.int64]:
            np.random.seed(seed)
            dims = [n] * d
            X = np.random.rand(*dims).astype(dtype)
            max_batch_size = n + 8

            def ref_op(X):
                shape = list(X.shape)
                out = np.zeros((1), dtype=np.int64)
                out[0] = shape[0]
                shape[0] = max_batch_size
                Y = np.zeros(shape, dtype=dtype)
                Y[:n] = X
                return [Y, out]

            op = core.CreateOperator(
                "AdjustBatch",
                ["X"],
                ["Y", "RealBatch"],
                max_batch_size=max_batch_size,
            )

            self.assertReferenceChecks(
                device_option=gc,
                op=op,
                inputs=[X],
                reference=ref_op,
            )

    @given(d=st.integers(1, 4), n=st.integers(8, 20),
           seed=st.integers(0, 1000), **hu.gcs_cpu_only)
    def test_truncate(self, d, n, gc, dc, seed):
        for dtype in [np.float32, np.int8, np.int64]:
            np.random.seed(seed)
            dims = [n] * d
            X = np.random.rand(*dims).astype(dtype)
            real_batch_size = n - 8
            R = np.zeros((1), dtype=np.int64)
            R[0] = real_batch_size

            def ref_op(X, R):
                r = R[0]
                return [X[:r]]

            op = core.CreateOperator(
                "AdjustBatch",
                ["X", "RealBatch"],
                ["Y"],
            )

            self.assertReferenceChecks(
                device_option=gc,
                op=op,
                inputs=[X, R],
                reference=ref_op,
            )
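
These tests exercise each branch in isolation; in practice the batch size recorded by the padding call is what drives a later truncating call, so the original rows round-trip. A hedged sketch of that pairing (blob names are illustrative):

import numpy as np
from caffe2.python import core, workspace

X = np.random.rand(3, 4).astype(np.float32)
workspace.FeedBlob("X", X)

# Pad to a fixed batch of 8 and record the real batch size.
workspace.RunOperatorOnce(core.CreateOperator(
    "AdjustBatch", ["X"], ["Y", "RealBatch"], max_batch_size=8))

# Truncate back using the recorded batch size.
workspace.RunOperatorOnce(core.CreateOperator(
    "AdjustBatch", ["Y", "RealBatch"], ["X_restored"]))

np.testing.assert_array_equal(workspace.FetchBlob("X_restored"), X)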
