Skip to content

Commit 2973404

Browse files
eladebanmn-robot
authored andcommitted
Add OpStructureExporter.
Allows to create an op which exports learned structure to file. PiperOrigin-RevId: 249396125
1 parent 77d6ac4 commit 2973404

5 files changed

+528
-29
lines changed
+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include "third_party/tensorflow/core/framework/op.h"
2+
3+
4+
REGISTER_OP("JsonTensorExporter")
5+
.Attr("filename: string")
6+
.Attr("T: {float, double, int32, bool}")
7+
.Attr("N: int")
8+
.Attr("keys: list(string)")
9+
.Input("values: N * T")
10+
.Input("save: bool")
11+
.Doc(R"doc(
12+
Saves the content of tensors on file as JSON dictionary.
13+
14+
filename: Filename to which the JSON is to be saved.
15+
N: Number of tensors expected.
16+
keys: The list of keys of the dictionary. Must be of length N.
17+
values: A list of tensors, will be the respective values. The order of the
18+
values is expected to match that of the keys. Must be of length N. Currently
19+
only vectors and scalars (rank 1 and 0) are supported.
20+
save: If false, the op would be a no-op. This mechanism is introduced because
21+
tf.cond can execute both the if and the else, and we don't want to write files
22+
unnecessarily.
23+
)doc");
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#include <string>
2+
3+
#include "file/base/file.h"
4+
#include "file/base/helpers.h"
5+
6+
#include "file/base/options.h"
7+
#include "third_party/jsoncpp/json.h"
8+
#include "third_party/tensorflow/core/framework/op_kernel.h"
9+
#include "third_party/tensorflow/core/framework/tensor.h"
10+
#include "third_party/tensorflow/core/framework/tensor_shape.h"
11+
#include "third_party/tensorflow/core/lib/core/errors.h"
12+
13+
14+
namespace morph_net {
15+
16+
using ::tensorflow::errors::InvalidArgument;
17+
using ::tensorflow::OpInputList;
18+
using ::tensorflow::OpKernelConstruction;
19+
using ::tensorflow::OpKernel;
20+
using ::tensorflow::OpKernelContext;
21+
using ::tensorflow::Tensor;
22+
23+
24+
template <typename T>
25+
class JsonTensorExporterOpKernel : public OpKernel {
26+
public:
27+
explicit JsonTensorExporterOpKernel(OpKernelConstruction* context)
28+
: OpKernel(context) {
29+
int number_of_keys;
30+
OP_REQUIRES_OK(context, context->GetAttr("N", &number_of_keys));
31+
OP_REQUIRES_OK(context, context->GetAttr("keys", &keys_));
32+
OP_REQUIRES_OK(context, context->GetAttr("filename", &filename_));
33+
34+
OP_REQUIRES(context, keys_.size() == number_of_keys,
35+
InvalidArgument("Number of keys (", keys_.size(), ") must match"
36+
" N (", number_of_keys, ")."));
37+
38+
OP_REQUIRES_OK(context, WriteFile(""));
39+
}
40+
41+
void Compute(OpKernelContext* context) override {
42+
OpInputList values;
43+
const Tensor* save;
44+
OP_REQUIRES_OK(context, context->input_list("values", &values));
45+
OP_REQUIRES_OK(context, context->input("save", &save));
46+
if (!save->scalar<bool>()()) return;
47+
48+
CHECK_EQ(values.size(), keys_.size()); // Enforced by REGISTER_OP
49+
50+
Json::Value json;
51+
int ikey = 0;
52+
for (const Tensor& tensor : values) {
53+
OP_REQUIRES(context, tensor.dims() <= 1, InvalidArgument(
54+
"Only scalars and vectors are currnetly supported, but a tensor "
55+
"with rank ", tensor.dims(), "was found."));
56+
57+
const string& key = keys_[ikey++];
58+
if (tensor.dims() == 0) { // Scalar
59+
json[key] = tensor.scalar<T>()();
60+
continue;
61+
}
62+
63+
// Vector
64+
for (int ielement = 0; ielement < tensor.NumElements(); ++ielement) {
65+
json[key][ielement] = tensor.vec<T>()(ielement);
66+
}
67+
}
68+
69+
Json::StyledWriter writer;
70+
OP_REQUIRES_OK(context, WriteFile(writer.write(json)));
71+
}
72+
73+
private:
74+
::tensorflow::Status WriteFile(const string& content) const {
75+
::util::Status status =
76+
::file::SetContents(filename_, content, ::file::Defaults());
77+
if (status.ok()){
78+
return ::tensorflow::Status::OK();
79+
}
80+
return InvalidArgument("Unable to write to file ", filename_,
81+
". Error message: ", status.error_message());
82+
}
83+
84+
std::vector<string> keys_;
85+
string filename_;
86+
};
87+
88+
REGISTER_KERNEL_BUILDER(Name("JsonTensorExporter")
89+
.Device(::tensorflow::DEVICE_CPU)
90+
.TypeConstraint<int32>("T"),
91+
JsonTensorExporterOpKernel<int32>);
92+
93+
REGISTER_KERNEL_BUILDER(Name("JsonTensorExporter")
94+
.Device(::tensorflow::DEVICE_CPU)
95+
.TypeConstraint<float>("T"),
96+
JsonTensorExporterOpKernel<float>);
97+
98+
REGISTER_KERNEL_BUILDER(Name("JsonTensorExporter")
99+
.Device(::tensorflow::DEVICE_CPU)
100+
.TypeConstraint<double>("T"),
101+
JsonTensorExporterOpKernel<double>);
102+
103+
REGISTER_KERNEL_BUILDER(Name("JsonTensorExporter")
104+
.Device(::tensorflow::DEVICE_CPU)
105+
.TypeConstraint<bool>("T"),
106+
JsonTensorExporterOpKernel<bool>);
107+
108+
} // namespace morph_net
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#include <string>
2+
3+
#include "file/base/file.h"
4+
#include "file/base/helpers.h"
5+
#include "file/base/path.h"
6+
7+
#include "testing/base/public/gmock.h"
8+
#include "testing/base/public/gunit.h"
9+
#include "third_party/jsoncpp/json.h"
10+
#include "third_party/tensorflow/core/framework/fake_input.h"
11+
#include "third_party/tensorflow/core/framework/node_def_builder.h"
12+
#include "third_party/tensorflow/core/framework/tensor.h"
13+
#include "third_party/tensorflow/core/framework/tensor_testutil.h"
14+
#include "third_party/tensorflow/core/kernels/ops_testutil.h"
15+
#include "third_party/tensorflow/core/lib/core/status_test_util.h"
16+
17+
namespace morph_net {
18+
19+
using ::tensorflow::DT_INT32;
20+
using ::tensorflow::FakeInput;
21+
using ::tensorflow::NodeDefBuilder;
22+
using ::tensorflow::OpsTestBase;
23+
using ::tensorflow::Status;
24+
using ::tensorflow::TensorShape;
25+
using ::testing::ElementsAre;
26+
27+
28+
std::vector<int> ToVector(const Json::Value& json) {
29+
std::vector<int> v;
30+
for (const Json::Value& item : json) {
31+
v.push_back(item.asInt());
32+
}
33+
return v;
34+
}
35+
36+
37+
class JsonTensorExporterTest : public OpsTestBase {};
38+
39+
TEST_F(JsonTensorExporterTest, Success) {
40+
const int kLength = 3;
41+
const string filename = ::file::JoinPath(FLAGS_test_tmpdir, "success.json");
42+
TF_ASSERT_OK(
43+
NodeDefBuilder("exporter", "JsonTensorExporter")
44+
.Attr("T", DT_INT32)
45+
.Attr("N", kLength)
46+
.Attr("keys", {"k1", "k2", "k3"})
47+
.Attr("filename", filename)
48+
.Input(FakeInput(kLength, ::tensorflow::DT_INT32))
49+
.Input(FakeInput(::tensorflow::DT_BOOL))
50+
.Finalize(node_def()));
51+
52+
TF_ASSERT_OK(InitOp());
53+
// The initialization of the op creates an empty file at `filename`. We delete
54+
// both to verify it was created, and to clean it up for the next steps of the
55+
// test.
56+
ASSERT_OK(::file::Delete(filename, ::file::Defaults()));
57+
58+
AddInputFromArray<int>(TensorShape({3}), {3, 5, 7});
59+
AddInputFromArray<int>(TensorShape({2}), {6, 4});
60+
AddInputFromArray<int>(TensorShape({}), {9});
61+
62+
// Set the `save` flag initially to false - so the op should be a no-op.
63+
AddInputFromArray<bool>(TensorShape({}), {false});
64+
TF_ASSERT_OK(RunOpKernel());
65+
// Verify that indeed no file was created.
66+
EXPECT_EQ(absl::StatusCode::kNotFound,
67+
::file::Exists(filename, ::file::Defaults()).code());
68+
69+
// Flip the `save` flag to true and test the content of the savef file.
70+
tensors_[3]->scalar<bool>()() = true;
71+
TF_ASSERT_OK(RunOpKernel());
72+
73+
string contents;
74+
ASSERT_OK(::file::GetContents(filename, &contents, ::file::Defaults()));
75+
Json::Reader reader;
76+
Json::Value json;
77+
reader.parse(contents, json);
78+
EXPECT_THAT(json.getMemberNames(), ElementsAre("k1", "k2", "k3"));
79+
EXPECT_TRUE(json["k1"].isArray());
80+
EXPECT_THAT(ToVector(json["k1"]), ElementsAre(3, 5, 7));
81+
EXPECT_TRUE(json["k2"].isArray());
82+
EXPECT_THAT(ToVector(json["k2"]), ElementsAre(6, 4));
83+
EXPECT_EQ(9, json["k3"].asInt());
84+
}
85+
86+
TEST_F(JsonTensorExporterTest, WrongNumberOfKeys) {
87+
const int kLength = 3;
88+
const string filename = ::file::JoinPath(FLAGS_test_tmpdir, "failure.json");
89+
TF_ASSERT_OK(
90+
NodeDefBuilder("exporter", "JsonTensorExporter")
91+
.Attr("T", DT_INT32)
92+
.Attr("N", kLength)
93+
.Attr("keys", {"k1", "k2"}) // Two keys only, even though kLength = 3.
94+
.Attr("filename", filename)
95+
.Input(FakeInput(kLength, ::tensorflow::DT_INT32))
96+
.Input(FakeInput(::tensorflow::DT_BOOL))
97+
.Finalize(node_def()));
98+
99+
EXPECT_FALSE(InitOp().ok());
100+
}
101+
102+
TEST_F(JsonTensorExporterTest, BadFileName) {
103+
const int kLength = 3;
104+
const string filename = "**bad";
105+
TF_ASSERT_OK(
106+
NodeDefBuilder("exporter", "JsonTensorExporter")
107+
.Attr("T", DT_INT32)
108+
.Attr("N", kLength)
109+
.Attr("keys", {"k1", "k2", "k3"})
110+
.Attr("filename", filename)
111+
.Input(FakeInput(kLength, ::tensorflow::DT_INT32))
112+
.Input(FakeInput(::tensorflow::DT_BOOL))
113+
.Finalize(node_def()));
114+
115+
EXPECT_FALSE(InitOp().ok());
116+
}
117+
118+
} // namespace morph_net

0 commit comments

Comments
 (0)