forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathadjust_batch_op.h
76 lines (67 loc) · 2.46 KB
/
adjust_batch_op.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#pragma once
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
template <class Context>
class AdjustBatchOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit AdjustBatchOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
max_batch_size_(
this->template GetSingleArgument<int64_t>("max_batch_size", -1)) {}
bool RunOnDevice() override {
auto& input = Input(0);
vector<int64_t> output_dims(input.sizes().vec());
CAFFE_ENFORCE(!output_dims.empty());
if (InputSize() > 1) {
// TODO: if we have a second input and we have max_batch_size set, check
// the batch size of the two inputs for consistency
auto& batch_size = Input(1);
int64_t real_batch_size = *batch_size.template data<int64_t>();
int64_t max_batch_size = output_dims[0];
CAFFE_ENFORCE_GE(max_batch_size, real_batch_size);
output_dims[0] = real_batch_size;
auto* output = Output(0, output_dims, input.dtype());
this->context_.template CopyItems<Context, Context>(
input.dtype(),
input.numel() * real_batch_size / max_batch_size,
input.raw_data(),
output->raw_mutable_data(input.dtype()));
} else {
// Pad to max batch size
CAFFE_ENFORCE_GT(
max_batch_size_,
0,
"max_batch_size should be larger than 0. Got ",
max_batch_size_);
// TODO: ideally we can support the case when input batch is larger than
// the max_batch_size, as we can just pad to the multiple of
// max_batch_size.
CAFFE_ENFORCE_GE(max_batch_size_, output_dims.front());
int64_t real_batch_size = output_dims[0];
output_dims[0] = max_batch_size_;
auto* output = Output(0, output_dims, input.dtype());
math::Set(
output->nbytes(),
static_cast<char>(0),
static_cast<char*>(output->raw_data()),
&context_);
this->context_.template CopyItems<Context, Context>(
input.dtype(),
input.numel(),
input.raw_data(),
output->raw_mutable_data(input.dtype()));
if (OutputSize() > 1) {
auto* real_batch_tensor = Output(1, {1}, at::dtype<int64_t>());
real_batch_tensor->template mutable_data<int64_t>()[0] =
real_batch_size;
}
}
return true;
}
private:
int64_t max_batch_size_;
};
} // namespace caffe2