Add a padding operator. #95
Changes from 4 commits
.gitignore
@@ -1,3 +1,6 @@
tags
build/
experiments/
smaug/operators/padding_op_test.h
*.pyc
__pycache__
*.swp
*.swo
smaug/operators/padding_op.h
@@ -4,13 +4,14 @@
#include "smaug/core/backend.h"
#include "smaug/core/operator.h"
#include "smaug/core/tensor.h"
// #include "smaug/core/tensor_utils.h"
#include "smaug/core/workspace.h"
+#include <google/protobuf/repeated_field.h>
+using namespace google::protobuf;

namespace smaug {

/** \ingroup Operators
- * \brief Pad a given tensor in different dimension.
+ * \brief Pad a given tensor in any number of dimensions with arbitrary size.
 *
 * This has a software-based implementation.
 *
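As a quick illustration of how the operator is meant to be used, here is a sketch only: `ReferenceBackend`, the default-constructed `Workspace`, and the header path are assumptions on my part, not something this diff pins down.

```cpp
#include <vector>

#include "smaug/core/workspace.h"
#include "smaug/operators/padding_op.h"

using namespace smaug;

// Sketch: build the op and describe the padding for a 4-D NCHW tensor.
void buildPaddingOp(Workspace* workspace) {
    auto* padOp = new PaddingOp<ReferenceBackend>("pad", workspace);
    // One <before, after> pair per dimension: N and C untouched,
    // H and W padded by 1 element on each side.
    std::vector<int> paddingSize = {0, 0, 0, 0, 1, 1, 1, 1};
    padOp->setPaddingSize(paddingSize);
}
```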
@@ -19,95 +20,77 @@ namespace smaug {
template <typename Backend>
class PaddingOp : public Operator {
  public:
-    PaddingOp(const std::string& name,
-              Workspace* workspace)
-            : Operator(name, OpType::Repeat, workspace){
+    PaddingOp(const std::string& name, Workspace* workspace)
+            : Operator(name, OpType::Padding, workspace) {
        inputs.resize(kNumInputs, nullptr);
        outputs.resize(kNumOutputs, nullptr);
    }

-    PaddingOp(const std::string& name,
-              Workspace* workspace,
-              int val)
-            : Operator(name, OpType::Repeat, workspace), padder(val){
-        inputs.resize(kNumInputs, nullptr);
-        outputs.resize(kNumOutputs, nullptr);
+    /**
+     * Set the paddingSize of the Tensor along each dimension.
+     * The paddingSize is organized as <dim1_forward, dim1_backward, ...,
+     * dimk_backward>
+     */
+    void setPaddingSize(RepeatedField<google::protobuf::int32> const& val) {
+        paddingSize.assign(val.begin(), val.end());
    }
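Read together with run() and createAllTensors() below, "forward" and "backward" appear to mean the padding added before and after a dimension, i.e. the vector is laid out as {dim1_before, dim1_after, dim2_before, dim2_after, ...}. Under that reading, the old setPadder(p) behavior on an NCHW input corresponds to paddingSize = {0, 0, 0, 0, p, p, p, p} (pad H and W by p on both sides); the last review comment on this file asks for exactly this to be spelled out in the doc comment.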

-    /** Set the number of padders of the Tensor along each dimension. */
-    void setPadder(const int& val) {
-        padder = val;
-        // set output size?
-    }
+    void setPaddingSize(std::vector<int> const& val) { paddingSize = val; }

-    int getPadder() { return padder; }
+    std::vector<int> getPaddingSize() const { return paddingSize; }
Review comment: return const std::vector& to avoid making a copy of paddingSize.
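The change being suggested here is small; a sketch:

```cpp
// Return a const reference so callers can inspect paddingSize without copying it.
const std::vector<int>& getPaddingSize() const { return paddingSize; }
```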
    void run() override {
-        Tensor* input = getInput(0);
-        Tensor* output = getOutput(0);
-        int ndims = input->ndims();
-        std::vector<int> inputDims = input->getShape().dims();
-        std::vector<int> outputDims = output->getShape().dims();
-        int total_dim = 1;
-        for (int i: outputDims){
-            total_dim *= i;
-        }
-        std::vector<float> vf(total_dim, 0);
-        output->fillData(vf.data(), vf.size());
-        /*
-        copyTensorRegion(Tensor* dest,
-                         Tensor* src,
-                         const std::vector<int>& destOrigin,
-                         const std::vector<int>& srcOrigin,
-                         const std::vector<int>& regionSize
-        */
-        std::vector<int> destOrigin;
-        if (input->getShape().getLayout() == DataLayout::NCHW){
-            destOrigin = std::vector<int>({0, 0, padder, padder});
-        }
-        else if(input->getShape().getLayout() == DataLayout::NHWC){
-            destOrigin = std::vector<int>({0, padder, padder, 0});
-        }
-        else{
-            assert(false && "Invalid padding data type!");
-        }
-        std::vector<int> srcOrigin = std::vector<int>({0, 0, 0, 0});
-        std::vector<int> regionSize = inputDims;
-        copyTensorRegion(output, input, destOrigin, srcOrigin, regionSize);
+        Tensor* input = getInput(0);
Review comment: Use the enums defined below (kInputs, kOutputs) instead of hardcoded constants.
+        Tensor* output = getOutput(0);
+        int ndims = input->ndims();
+        const std::vector<int> inputDims = input->getShape().dims();
Review comment: both inputDims and outputDims could be const references here to avoid the copies.
+        const std::vector<int> outputDims = output->getShape().dims();
+        int total_dim = 1;
+        for (int i : outputDims) {
+            total_dim *= i;
+        }
+        std::vector<float> vf(total_dim, 0);
+        output->fillData(vf.data(), vf.size());
+        std::vector<int> destOrigin, paddingBegin, srcOrigin;
+        for (int i = 0; i < ndims; i++) {
+            paddingBegin.push_back(paddingSize[2 * i]);
Review comment: use …
+            srcOrigin.push_back(0);
+        }
+        destOrigin = std::vector<int>(paddingBegin);
Review comment: There's no need to declare destOrigin earlier and then re-initialize it here - that causes the vector to be constructed twice. Just declare it here directly.
+        std::vector<int> regionSize = inputDims;
Review comment: Same here - no need to make a copy of inputDims, just use it directly. I think you're trying to make it more clear what each vector represents in the copyTensorRegion call, but there's no need since the API documents it very clearly already.
+        copyTensorRegion(output, input, destOrigin, srcOrigin, regionSize);
    }
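Pulling the review comments above together (enums instead of hardcoded indices, no extra copies, destOrigin declared where it is built), a possible shape for the revised run() - a sketch against the APIs visible in this diff, not the actual follow-up patch:

```cpp
void run() override {
    Tensor* input = getInput(kInputs);
    Tensor* output = getOutput(kOutputs);
    int ndims = input->ndims();
    const std::vector<int>& inputDims = input->getShape().dims();
    const std::vector<int>& outputDims = output->getShape().dims();
    // Zero-fill the output so the padded border ends up as zeros.
    int totalSize = 1;
    for (int i : outputDims)
        totalSize *= i;
    std::vector<float> zeros(totalSize, 0);
    output->fillData(zeros.data(), zeros.size());
    // The copied region starts at the "forward" padding offset in every dimension.
    std::vector<int> destOrigin;
    std::vector<int> srcOrigin(ndims, 0);
    for (int i = 0; i < ndims; i++)
        destOrigin.push_back(paddingSize[2 * i]);
    copyTensorRegion(output, input, destOrigin, srcOrigin, inputDims);
}
```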

    // Optional override for testing purposes.
    void createAllTensors() override {
        Tensor* input = getInput(0);
        int ndims = input->ndims();
        std::vector<int> dims = input->getShape().dims();
-        if (input->getShape().getLayout() == DataLayout::NCHW){
-            dims[2] += 2*padder;
-            dims[3] += 2*padder;
-        }
-        else if (input->getShape().getLayout() == DataLayout::NHWC){
-            dims[1] += 2*padder;
-            dims[2] += 2*padder;
+        for (int i = 0; i < ndims; i++) {
+            dims[i] += (paddingSize[2 * i] + paddingSize[2 * i + 1]);
        }
        TensorShape shape(
                dims, input->getShape().getLayout(), Backend::Alignment);
        Tensor* output = new Tensor(name, shape);
        workspace->addTensor(output);
        outputs.at(0) = output;
Review comment: Same here - use enums instead of hardcoded constants.
    }
    }
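To make the output-shape arithmetic concrete, a standalone example with made-up dimensions (plain C++, not SMAUG code):

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical 1x3x28x28 NCHW input, padded by 1 on H and 2 on W (both sides).
    std::vector<int> dims = {1, 3, 28, 28};
    std::vector<int> paddingSize = {0, 0, 0, 0, 1, 1, 2, 2};
    for (size_t i = 0; i < dims.size(); i++)
        dims[i] += paddingSize[2 * i] + paddingSize[2 * i + 1];
    // Prints 1 3 30 32, matching what createAllTensors() would produce.
    std::printf("%d %d %d %d\n", dims[0], dims[1], dims[2], dims[3]);
    return 0;
}
```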

    // Optional but recommended function to verify operator parameters.
    bool validate() override {
-        if (padder < 0){
-            return false;
-        }
+        Tensor* input = getInput(0);
+        int ndims = input->ndims();
+        if (paddingSize.size() != 2 * ndims) {
+            return false;
+        }
        return Operator::validate();
    }

    enum { kInputs, kNumInputs };
Review comment: nit: …
    enum { kOutputs, kNumOutputs };

-  private:
-    int padder = 0;
+  private:
+    std::vector<int> paddingSize = {};
};

} // namespace smaug
Review comment (on setPaddingSize above): This should be const RepeatedField&, not RepeatedField const&. Also, what do "forward" and "backward" mean?
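For reference, the signature with the preferred const placement is functionally identical, just a style change:

```cpp
void setPaddingSize(const RepeatedField<google::protobuf::int32>& val) {
    paddingSize.assign(val.begin(), val.end());
}
```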
Reply: I'll update this.