diff --git a/morph_net/tools/json_tensor_exporter_op.cc b/morph_net/tools/json_tensor_exporter_op.cc
new file mode 100644
index 0000000..6902127
--- /dev/null
+++ b/morph_net/tools/json_tensor_exporter_op.cc
@@ -0,0 +1,23 @@
+#include "third_party/tensorflow/core/framework/op.h"
+
+
+REGISTER_OP("JsonTensorExporter")
+    .Attr("filename: string")
+    .Attr("T: {float, double, int32, bool}")
+    .Attr("N: int")
+    .Attr("keys: list(string)")
+    .Input("values: N * T")
+    .Input("save: bool")
+    .Doc(R"doc(
+Saves the content of tensors to a file as a JSON dictionary.
+
+filename: Filename to which the JSON is to be saved.
+N: Number of tensors expected.
+keys: The list of keys of the dictionary. Must be of length N.
+values: A list of tensors that will be the respective values. The order of the
+  values is expected to match that of the keys. Must be of length N. Currently
+  only vectors and scalars (rank 1 and 0) are supported.
+save: If false, the op is a no-op. This mechanism is introduced because
+  tf.cond can execute both the if and the else branch, and we don't want to
+  write files unnecessarily.
+)doc");
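Note: as a point of reference, the op above can be driven directly from Python
through the generated wrapper that structure_exporter.py imports below
(morph_net.tools.ops.gen_json_tensor_exporter_op_py). This is a minimal sketch
only, assuming the op library has been built; the file path and tensor values
are placeholders:

    import tensorflow as tf
    from morph_net.tools.ops import gen_json_tensor_exporter_op_py

    # Two entries: a rank-1 value and a rank-0 value, both float (attr T).
    export_op = gen_json_tensor_exporter_op_py.json_tensor_exporter(
        filename='/tmp/example.json',  # placeholder path
        keys=['conv1', 'conv2'],
        values=[tf.constant([1.0, 0.0, 1.0]), tf.constant(7.0)],
        save=tf.constant(True))
    with tf.Session() as sess:
      sess.run(export_op)
    # The written file is a JSON dictionary, roughly:
    # { "conv1" : [ 1.0, 0.0, 1.0 ], "conv2" : 7.0 }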
diff --git a/morph_net/tools/json_tensor_exporter_op_kernel.cc b/morph_net/tools/json_tensor_exporter_op_kernel.cc
new file mode 100644
index 0000000..d7c24f2
--- /dev/null
+++ b/morph_net/tools/json_tensor_exporter_op_kernel.cc
@@ -0,0 +1,108 @@
+#include <vector>
+
+#include "file/base/file.h"
+#include "file/base/helpers.h"
+
+#include "file/base/options.h"
+#include "third_party/jsoncpp/json.h"
+#include "third_party/tensorflow/core/framework/op_kernel.h"
+#include "third_party/tensorflow/core/framework/tensor.h"
+#include "third_party/tensorflow/core/framework/tensor_shape.h"
+#include "third_party/tensorflow/core/lib/core/errors.h"
+
+
+namespace morph_net {
+
+using ::tensorflow::errors::InvalidArgument;
+using ::tensorflow::OpInputList;
+using ::tensorflow::OpKernelConstruction;
+using ::tensorflow::OpKernel;
+using ::tensorflow::OpKernelContext;
+using ::tensorflow::Tensor;
+
+
+template <class T>
+class JsonTensorExporterOpKernel : public OpKernel {
+ public:
+  explicit JsonTensorExporterOpKernel(OpKernelConstruction* context)
+      : OpKernel(context) {
+    int number_of_keys;
+    OP_REQUIRES_OK(context, context->GetAttr("N", &number_of_keys));
+    OP_REQUIRES_OK(context, context->GetAttr("keys", &keys_));
+    OP_REQUIRES_OK(context, context->GetAttr("filename", &filename_));
+
+    OP_REQUIRES(context, keys_.size() == number_of_keys,
+                InvalidArgument("Number of keys (", keys_.size(), ") must match"
+                                " N (", number_of_keys, ")."));
+
+    OP_REQUIRES_OK(context, WriteFile(""));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    OpInputList values;
+    const Tensor* save;
+    OP_REQUIRES_OK(context, context->input_list("values", &values));
+    OP_REQUIRES_OK(context, context->input("save", &save));
+    if (!save->scalar<bool>()()) return;
+
+    CHECK_EQ(values.size(), keys_.size());  // Enforced by REGISTER_OP
+
+    Json::Value json;
+    int ikey = 0;
+    for (const Tensor& tensor : values) {
+      OP_REQUIRES(context, tensor.dims() <= 1, InvalidArgument(
+          "Only scalars and vectors are currently supported, but a tensor "
+          "with rank ", tensor.dims(), " was found."));
+
+      const string& key = keys_[ikey++];
+      if (tensor.dims() == 0) {  // Scalar
+        json[key] = tensor.scalar<T>()();
+        continue;
+      }
+
+      // Vector
+      for (int ielement = 0; ielement < tensor.NumElements(); ++ielement) {
+        json[key][ielement] = tensor.vec<T>()(ielement);
+      }
+    }
+
+    Json::StyledWriter writer;
+    OP_REQUIRES_OK(context, WriteFile(writer.write(json)));
+  }
+
+ private:
+  ::tensorflow::Status WriteFile(const string& content) const {
+    ::util::Status status =
+        ::file::SetContents(filename_, content, ::file::Defaults());
+    if (status.ok()) {
+      return ::tensorflow::Status::OK();
+    }
+    return InvalidArgument("Unable to write to file ", filename_,
+                           ". Error message: ", status.error_message());
+  }
+
+  std::vector<string> keys_;
+  string filename_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("JsonTensorExporter")
+                            .Device(::tensorflow::DEVICE_CPU)
+                            .TypeConstraint<float>("T"),
+                        JsonTensorExporterOpKernel<float>);
+
+REGISTER_KERNEL_BUILDER(Name("JsonTensorExporter")
+                            .Device(::tensorflow::DEVICE_CPU)
+                            .TypeConstraint<double>("T"),
+                        JsonTensorExporterOpKernel<double>);
+
+REGISTER_KERNEL_BUILDER(Name("JsonTensorExporter")
+                            .Device(::tensorflow::DEVICE_CPU)
+                            .TypeConstraint<::tensorflow::int32>("T"),
+                        JsonTensorExporterOpKernel<::tensorflow::int32>);
+
+REGISTER_KERNEL_BUILDER(Name("JsonTensorExporter")
+                            .Device(::tensorflow::DEVICE_CPU)
+                            .TypeConstraint<bool>("T"),
+                        JsonTensorExporterOpKernel<bool>);
+
+}  // namespace morph_net
diff --git a/morph_net/tools/json_tensor_exporter_op_test.cc b/morph_net/tools/json_tensor_exporter_op_test.cc
new file mode 100644
index 0000000..bbd1598
--- /dev/null
+++ b/morph_net/tools/json_tensor_exporter_op_test.cc
@@ -0,0 +1,118 @@
+#include <vector>
+
+#include "file/base/file.h"
+#include "file/base/helpers.h"
+#include "file/base/path.h"
+
+#include "testing/base/public/gmock.h"
+#include "testing/base/public/gunit.h"
+#include "third_party/jsoncpp/json.h"
+#include "third_party/tensorflow/core/framework/fake_input.h"
+#include "third_party/tensorflow/core/framework/node_def_builder.h"
+#include "third_party/tensorflow/core/framework/tensor.h"
+#include "third_party/tensorflow/core/framework/tensor_testutil.h"
+#include "third_party/tensorflow/core/kernels/ops_testutil.h"
+#include "third_party/tensorflow/core/lib/core/status_test_util.h"
+
+namespace morph_net {
+
+using ::tensorflow::DT_INT32;
+using ::tensorflow::FakeInput;
+using ::tensorflow::NodeDefBuilder;
+using ::tensorflow::OpsTestBase;
+using ::tensorflow::Status;
+using ::tensorflow::TensorShape;
+using ::testing::ElementsAre;
+
+
+std::vector<int> ToVector(const Json::Value& json) {
+  std::vector<int> v;
+  for (const Json::Value& item : json) {
+    v.push_back(item.asInt());
+  }
+  return v;
+}
+
+
+class JsonTensorExporterTest : public OpsTestBase {};
+
+TEST_F(JsonTensorExporterTest, Success) {
+  const int kLength = 3;
+  const string filename = ::file::JoinPath(FLAGS_test_tmpdir, "success.json");
+  TF_ASSERT_OK(
+      NodeDefBuilder("exporter", "JsonTensorExporter")
+          .Attr("T", DT_INT32)
+          .Attr("N", kLength)
+          .Attr("keys", {"k1", "k2", "k3"})
+          .Attr("filename", filename)
+          .Input(FakeInput(kLength, ::tensorflow::DT_INT32))
+          .Input(FakeInput(::tensorflow::DT_BOOL))
+          .Finalize(node_def()));
+
+  TF_ASSERT_OK(InitOp());
+  // The initialization of the op creates an empty file at `filename`. We
+  // delete it, both to verify that it was created and to clean it up for the
+  // next steps of the test.
+  ASSERT_OK(::file::Delete(filename, ::file::Defaults()));
+
+  AddInputFromArray<int>(TensorShape({3}), {3, 5, 7});
+  AddInputFromArray<int>(TensorShape({2}), {6, 4});
+  AddInputFromArray<int>(TensorShape({}), {9});
+
+  // Set the `save` flag initially to false - so the op should be a no-op.
+  AddInputFromArray<bool>(TensorShape({}), {false});
+  TF_ASSERT_OK(RunOpKernel());
+  // Verify that indeed no file was created.
+  EXPECT_EQ(absl::StatusCode::kNotFound,
+            ::file::Exists(filename, ::file::Defaults()).code());
+
+  // Flip the `save` flag to true and test the content of the saved file.
+  tensors_[3]->scalar<bool>()() = true;
+  TF_ASSERT_OK(RunOpKernel());
+
+  string contents;
+  ASSERT_OK(::file::GetContents(filename, &contents, ::file::Defaults()));
+  Json::Reader reader;
+  Json::Value json;
+  reader.parse(contents, json);
+  EXPECT_THAT(json.getMemberNames(), ElementsAre("k1", "k2", "k3"));
+  EXPECT_TRUE(json["k1"].isArray());
+  EXPECT_THAT(ToVector(json["k1"]), ElementsAre(3, 5, 7));
+  EXPECT_TRUE(json["k2"].isArray());
+  EXPECT_THAT(ToVector(json["k2"]), ElementsAre(6, 4));
+  EXPECT_EQ(9, json["k3"].asInt());
+}
+
+TEST_F(JsonTensorExporterTest, WrongNumberOfKeys) {
+  const int kLength = 3;
+  const string filename = ::file::JoinPath(FLAGS_test_tmpdir, "failure.json");
+  TF_ASSERT_OK(
+      NodeDefBuilder("exporter", "JsonTensorExporter")
+          .Attr("T", DT_INT32)
+          .Attr("N", kLength)
+          .Attr("keys", {"k1", "k2"})  // Two keys only, even though kLength = 3.
+          .Attr("filename", filename)
+          .Input(FakeInput(kLength, ::tensorflow::DT_INT32))
+          .Input(FakeInput(::tensorflow::DT_BOOL))
+          .Finalize(node_def()));
+
+  EXPECT_FALSE(InitOp().ok());
+}
+
+TEST_F(JsonTensorExporterTest, BadFileName) {
+  const int kLength = 3;
+  const string filename = "**bad";
+  TF_ASSERT_OK(
+      NodeDefBuilder("exporter", "JsonTensorExporter")
+          .Attr("T", DT_INT32)
+          .Attr("N", kLength)
+          .Attr("keys", {"k1", "k2", "k3"})
+          .Attr("filename", filename)
+          .Input(FakeInput(kLength, ::tensorflow::DT_INT32))
+          .Input(FakeInput(::tensorflow::DT_BOOL))
+          .Finalize(node_def()));
+
+  EXPECT_FALSE(InitOp().ok());
+}
+
+}  // namespace morph_net
diff --git a/morph_net/tools/structure_exporter.py b/morph_net/tools/structure_exporter.py
index d17b0bf..e046647 100644
--- a/morph_net/tools/structure_exporter.py
+++ b/morph_net/tools/structure_exporter.py
@@ -1,4 +1,14 @@
-"""Helper module for calculating and saving learned structures."""
+"""Helper module for calculating and saving learned structures.
+
+TODO(e1)
+Ops for exporting OpRegularizer values to JSON.
+
+When training with a network regularizer, the emerging structure of the
+network is encoded in the `alive_vector`s and `regularization_vector`s of the
+`OpRegularizers` of the ops in the graph. This module offers a way to create
+ops that save those vectors as JSON files during training.
+"""
+
 from __future__ import absolute_import
 from __future__ import division
 # [internal] enable type annotations
@@ -6,34 +16,17 @@
 import json
 import os
+from enum import Enum
 from morph_net.framework import op_regularizer_manager as orm
+from morph_net.tools.ops import gen_json_tensor_exporter_op_py
 import numpy as np
 import tensorflow as tf
 from typing import Text, Sequence, Dict, Optional, IO, Iterable, Callable
+
 
 _SUPPORTED_OPS = ['Conv2D', 'Conv2DBackpropInput']
 _ALIVE_FILENAME = 'alive'
-
-
-def format_structure(structure: Dict[Text, int]) -> Text:
-  return json.dumps(structure, indent=2, sort_keys=True, default=str)
-
-
-def compute_alive_counts(
-    alive_vectors: Dict[Text, Sequence[bool]]) -> Dict[Text, int]:
-  """Computes alive counts.
-
-  Args:
-    alive_vectors: A mapping from op_name to a vector where each element
-      indicates whether the corresponding output activation is alive.
-
-  Returns:
-    Mapping from op_name to the number of its alive output activations.
- """ - return { - op_name: int(np.sum(alive_vector)) - for op_name, alive_vector in alive_vectors.items() - } +_REG_FILENAME = 'regularizer' class StructureExporter(object): @@ -130,7 +123,7 @@ def get_alive_counts(self) -> Dict[Text, int]: if self._alive_vectors is None: raise RuntimeError('Tensor values not populated.') - return compute_alive_counts(self._alive_vectors) + return _compute_alive_counts(self._alive_vectors) def save_alive_counts(self, f: IO[bytes]) -> None: """Saves live counts to a file. @@ -194,3 +187,136 @@ def get_remove_common_prefix_fn(iterable: Iterable[Text] if not all(k.startswith(prefix) for k in iterable): return lambda x: x return lambda item: item[len(prefix):] + + +class ExportInfo(Enum): + """ExportInfo for selecting file to be exported.""" + # Export alive count of op. + alive = 'alive' + # Export regularization vector of op. + regularization = 'regularization' + # Export both alive count and regularization vector. + both = 'both' + + +class StructureExporterOp(object): + """Manages the export of the alive and regularization json files.""" + + def __init__(self, + directory, + save, + opreg_manager, + alive_file=_ALIVE_FILENAME, + regularization_file=_REG_FILENAME): + """Init an object with all the exporter vars. + + Args: + directory: A string, directory to write the json files to. + save: A scalar `tf.Tensor` of type boolean. If `False`, the exporting is a + no-op. + opreg_manager: An OpRegularizerManager that manages the OpRegularizers. + alive_file: A string with a file name that will contain the alive counts + (in json format). + regularization_file: A string with a file name that will contain the + regularization vectors (in json format). + """ + self._save = save + self._opreg_manager = opreg_manager + self._alive_file = os.path.join(directory, alive_file) + self._regularization_file = os.path.join(directory, regularization_file) + + def export(self, info=ExportInfo.both): + """Returns an `tf.Operation` that saves the ExportInfo in json files. + + Args: + info: An 'ExportInfo' enum that defines the data to be exported. + Returns: + A `tf.Operation` that executes the exporting. + Raises: + ValueError: If info is not 'ExportInfo' enum. + """ + if not isinstance(info, ExportInfo): + raise ValueError('`info` must be an ExportInfo enum value.') + + op = None + if info == ExportInfo.regularization or info == ExportInfo.both: + op = self._export_helper( + self._regularization_file, lambda x: x.regularization_vector, + 'ExportRegularization') + if info == ExportInfo.alive or info == ExportInfo.both: + alive_op = self._export_helper(self._alive_file, _alive_count, + 'ExportAlive') + op = alive_op if op is None else tf.group(op, alive_op) + return op + + def export_state_every_n(self, n, info=ExportInfo.both): + """Returns an `tf.Operation` that saves the ExportInfo every `n` steps. + + Args: + n: An integer. Actual export will happen once in every `n` calls, all + other calls will be no-ops. + info: an 'ExportInfo' enum that defined what data to export. + + Returns: + A `tf.Operation` that executes the export. + """ + with tf.name_scope('ExportEveryN'): + counter = tf.get_variable( + 'counter', dtype=tf.int32, initializer=-1, trainable=False) + counter = counter.assign_add(1) + + return tf.cond( + tf.equal(counter % n, 0), + lambda: self.export(info), + lambda: tf.no_op(name='DontSave')).op + + def _export_helper(self, filename, getter, name): + """Helper function for OpRegularizer state vectors as JSON. 
+
+    Args:
+      filename: A string with the filename to save.
+      getter: A single-argument function, which receives an OpRegularizer
+        object and returns the value that needs to be exported (alive count or
+        regularization vector).
+      name: Name for the StructureExporterOp op.
+
+    Returns:
+      An op that exports the state if `save` evaluates to `True`.
+    """
+    op_to_reg = {
+        op: getter(self._opreg_manager.get_regularizer(op))
+        for op in self._opreg_manager.ops
+        if op.type in _SUPPORTED_OPS and self._opreg_manager.get_regularizer(op)
+    }
+
+    # Sort by name, to make the exported file more easily human-readable.
+    sorted_ops = sorted(op_to_reg.keys(), key=lambda x: x.name)
+    keys = [op.name for op in sorted_ops]
+    values = [op_to_reg[op] for op in sorted_ops]
+    return gen_json_tensor_exporter_op_py.json_tensor_exporter(
+        filename=filename, keys=keys, values=values, save=self._save,
+        name=name)
+
+
+def format_structure(structure: Dict[Text, int]) -> Text:
+  return json.dumps(structure, indent=2, sort_keys=True, default=str)
+
+
+def _compute_alive_counts(
+    alive_vectors: Dict[Text, Sequence[bool]]) -> Dict[Text, int]:
+  """Computes alive counts.
+
+  Args:
+    alive_vectors: A mapping from op_name to a vector where each element
+      indicates whether the corresponding output activation is alive.
+
+  Returns:
+    Mapping from op_name to the number of its alive output activations.
+  """
+  return {
+      op_name: int(np.sum(alive_vector))
+      for op_name, alive_vector in alive_vectors.items()
+  }
+
+
+def _alive_count(op_regularizer):
+  return tf.reduce_sum(tf.cast(op_regularizer.alive_vector, tf.int32))
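Note: a rough end-to-end sketch of how the new StructureExporterOp is meant to
be used during training (names like opreg_manager, train_op and num_steps are
placeholders; the OpRegularizerManager would be constructed for the actual
model, e.g. as in the test below):

    import tensorflow as tf
    from morph_net.tools import structure_exporter

    exporter = structure_exporter.StructureExporterOp(
        directory='/tmp/morphnet_structure',  # placeholder directory
        save=tf.constant(True),
        opreg_manager=opreg_manager)
    # Export both JSON files once every 1000 training steps.
    export_op = exporter.export_state_every_n(
        1000, info=structure_exporter.ExportInfo.both)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(num_steps):
        sess.run([train_op, export_op])
    # Afterwards <directory>/alive and <directory>/regularizer contain JSON
    # dictionaries keyed by op name.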
diff --git a/morph_net/tools/structure_exporter_test.py b/morph_net/tools/structure_exporter_test.py
index 2e62586..629d602 100644
--- a/morph_net/tools/structure_exporter_test.py
+++ b/morph_net/tools/structure_exporter_test.py
@@ -4,15 +4,20 @@
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import json
 import os
 
 from absl import flags
 from absl.testing import parameterized
+
+from morph_net.framework import batch_norm_source_op_handler
+from morph_net.framework import concat_op_handler
 from morph_net.framework import generic_regularizers
+from morph_net.framework import grouping_op_handler
 from morph_net.framework import op_regularizer_manager as orm
+from morph_net.framework import output_non_passthrough_op_handler
 from morph_net.tools import structure_exporter as se
-
 import tensorflow as tf
 
 
@@ -110,13 +115,13 @@ def test_populate_tensor_values(self):
 
   def test_compute_alive_count(self):
     self.assertAllEqual(
-        se.compute_alive_counts({'a': [True, False, False]}), {'a': 1})
+        se._compute_alive_counts({'a': [True, False, False]}), {'a': 1})
     self.assertAllEqual(
-        se.compute_alive_counts({'b': [False, False]}), {'b': 0})
+        se._compute_alive_counts({'b': [False, False]}), {'b': 0})
     self.assertAllEqual(
-        se.compute_alive_counts(self.tensor_value_1), self.expected_alive_1)
+        se._compute_alive_counts(self.tensor_value_1), self.expected_alive_1)
     self.assertAllEqual(
-        se.compute_alive_counts(self.tensor_value_2), self.expected_alive_2)
+        se._compute_alive_counts(self.tensor_value_2), self.expected_alive_2)
 
   def test_save_alive_counts(self):
     filename = 'alive007'
@@ -165,5 +170,124 @@ def test_find_common_prefix_size(self, iterable, expected_result):
     self.assertEqual(expected_result, list(map(rename_op, iterable)))
 
 
+arg_scope = tf.contrib.framework.arg_scope
+conv2d_transpose = tf.contrib.layers.conv2d_transpose
+conv2d = tf.contrib.layers.conv2d
+FLAGS = flags.FLAGS
+
+
+def assign_to_gamma(scope, value):
+  name_to_var = {v.op.name: v for v in tf.global_variables()}
+  gamma = name_to_var[scope + '/BatchNorm/gamma']
+  gamma.assign(value).eval()
+
+
+def jsons_exist_in_tempdir():
+  for f in tf.gfile.ListDirectory(FLAGS.test_tmpdir):
+    if f.startswith('alive') or f.startswith('reg'):
+      return True
+  return False
+
+
+class StructureExporterOpTest(tf.test.TestCase):
+
+  def empty_test_dir(self):
+    for f in tf.gfile.ListDirectory(FLAGS.test_tmpdir):
+      if f.startswith('alive') or f.startswith('reg'):
+        print('found f', f)
+        tf.gfile.Remove(os.path.join(FLAGS.test_tmpdir, f))
+
+  def setUp(self):
+    super(StructureExporterOpTest, self).setUp()
+    self.empty_test_dir()
+    params = {
+        'trainable': True,
+        'normalizer_fn': tf.contrib.layers.batch_norm,
+        'normalizer_params': {
+            'scale': True
+        },
+        'padding': 'SAME'
+    }
+
+    image = tf.zeros([3, 10, 10, 3])
+    with arg_scope([conv2d, conv2d_transpose], **params):
+      conv1 = conv2d(image, 5, 3, scope='conv1')
+      conv2 = conv2d(image, 5, 3, scope='conv2')
+      add = conv1 + conv2
+      conv3 = conv2d(add, 4, 1, scope='conv3')
+      convt = conv2d_transpose(conv3, 3, 2, scope='convt')
+    # Create OpHandler dict for the test.
+    op_handler_dict = collections.defaultdict(
+        grouping_op_handler.GroupingOpHandler)
+    op_handler_dict.update({
+        'FusedBatchNorm':
+            batch_norm_source_op_handler.BatchNormSourceOpHandler(0.1),
+        'Conv2D':
+            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
+        'Conv2DBackpropInput':
+            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
+        'ConcatV2':
+            concat_op_handler.ConcatOpHandler(),
+    })
+
+    # Create OpRegularizerManager and StructureExporterOp for the test.
+    opreg_manager = orm.OpRegularizerManager(
+        [convt.op], op_handler_dict)
+    self.exporter = se.StructureExporterOp(
+        directory=FLAGS.test_tmpdir,
+        save=True,
+        opreg_manager=opreg_manager)
+
+  def test_simple_export(self):
+    export_op = self.exporter.export()
+    with self.cached_session():
+      tf.global_variables_initializer().run()
+      assign_to_gamma('conv1', [0, 1.5, 1, 0, 1])
+      assign_to_gamma('conv2', [1, 1, .1, 0, 1])
+      assign_to_gamma('conv3', [0, .8, 1, .25])
+      assign_to_gamma('convt', [3, .3, .03])
+      export_op.run()
+
+      regularizers = self._read_file(reg=True)
+      grouped_conv1_conv2_reg = [1, 1.5, 1, 0, 1]
+      self.assertAllClose(grouped_conv1_conv2_reg, regularizers['conv1/Conv2D'])
+      self.assertAllClose(grouped_conv1_conv2_reg, regularizers['conv2/Conv2D'])
+      self.assertAllClose([0, .8, 1, .25], regularizers['conv3/Conv2D'])
+      self.assertAllClose([3, .3, .03], regularizers['convt/conv2d_transpose'])
+
+      alive = self._read_file(reg=False)
+      self.assertAllEqual(4, alive['conv1/Conv2D'])
+      self.assertAllEqual(4, alive['conv2/Conv2D'])
+      self.assertAllEqual(3, alive['conv3/Conv2D'])
+      self.assertAllEqual(2, alive['convt/conv2d_transpose'])
+
+  def test_export_every_n(self):
+    export_op = self.exporter.export_state_every_n(
+        4, se.ExportInfo.alive)
+    with self.cached_session():
+      tf.global_variables_initializer().run()
+      # Initially no jsons.
+      self.assertFalse(jsons_exist_in_tempdir())
+      export_op.run()
+      # 0th iteration: jsons are saved, verified and deleted.
+      self.assertTrue(jsons_exist_in_tempdir())
+      self.empty_test_dir()
+      for _ in range(3):
+        # Iterations 1, 2, 3: jsons are not saved.
+        export_op.run()
+        self.assertFalse(jsons_exist_in_tempdir())
+      # 4th iteration: saved again.
+      export_op.run()
+      self.assertTrue(jsons_exist_in_tempdir())
+      self.empty_test_dir()
+      # 5th: not saved.
+      export_op.run()
+      self.assertFalse(jsons_exist_in_tempdir())
+
+  def _read_file(self, reg):
+    filename = se._REG_FILENAME if reg else se._ALIVE_FILENAME
+    with tf.gfile.Open(os.path.join(FLAGS.test_tmpdir, filename)) as f:
+      data = f.read()
+    return json.loads(data)
+
+
 if __name__ == '__main__':
   tf.test.main()
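Note: the exported files are plain JSON dictionaries keyed by op name (file
names given by _ALIVE_FILENAME and _REG_FILENAME above), so downstream tooling
can read them back with the standard json module. A small sketch with a
placeholder directory:

    import json
    import os

    export_dir = '/tmp/morphnet_structure'  # placeholder
    with open(os.path.join(export_dir, 'alive')) as f:
      alive_counts = json.load(f)   # e.g. {'conv1/Conv2D': 4, ...}
    with open(os.path.join(export_dir, 'regularizer')) as f:
      reg_vectors = json.load(f)    # e.g. {'conv1/Conv2D': [1.0, 1.5, ...], ...}
    for op_name in sorted(alive_counts):
      print(op_name, alive_counts[op_name], reg_vectors.get(op_name))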