Skip to content

Commit

Permalink
Create Estimator hook for StructureExporter.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 277371977
  • Loading branch information
pkch authored and mn-robot committed Oct 29, 2019
1 parent 8e00552 commit 35282f6
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 9 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ with tf.Session() as sess:
_, structure_exporter_tensors = sess.run([train_op, exporter.tensors])
if (step % 1000 == 0):
exporter.populate_tensor_values(structure_exporter_tensors)
exporter.create_file_and_save_alive_counts(train_dir, step)
exporter.create_file_and_save_alive_counts(
os.path.join(train_dir, 'learned_structure'), step)
```

## Misc
Expand Down
36 changes: 30 additions & 6 deletions morph_net/tools/structure_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,22 +129,20 @@ def create_file_and_save_alive_counts(self, base_dir: Text,
global_step: int) -> None:
"""Creates and updates files with alive counts.
Creates the directory `{base_dir}/learned_structure/` and saves the current
alive counts to:
`{base_dir}/learned_structure/{ALIVE_FILENAME}_{global_step}`.
Creates the directory `{base_dir}` and saves the current alive counts to:
`{base_dir}/{ALIVE_FILENAME}_{global_step}`.
Args:
base_dir: where to export the alive counts.
global_step: current value of global step, used as a suffix in filename.
"""
current_filename = '%s_%s' % (ALIVE_FILENAME, global_step)
directory = os.path.join(base_dir, 'learned_structure')
try:
tf.gfile.MakeDirs(directory)
tf.gfile.MakeDirs(base_dir)
except tf.errors.OpError:
# Probably already exists. If not, we'll see the error in the next line.
pass
with tf.gfile.Open(os.path.join(directory, current_filename), 'w') as f:
with tf.gfile.Open(os.path.join(base_dir, current_filename), 'w') as f:
self.save_alive_counts(f) # pytype: disable=wrong-arg-types


Expand Down Expand Up @@ -196,3 +194,29 @@ def _compute_alive_counts(

def format_structure(structure: Dict[Text, int]) -> Text:
return json.dumps(structure, indent=2, sort_keys=True, default=str)


class StructureExporterHook(tf.train.SessionRunHook):
"""Estimator hook for StructureExporter.
Usage:
exporter = structure_exporter.StructureExporter(
network_regularizer.op_regularizer_manager)
structure_export_hook = structure_exporter.StructureExporterHook(
exporter, '/path/to/cns')
estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
...,
training_hooks=[structure_export_hook])
"""

def __init__(self, exporter: StructureExporter, export_dir: Text):
self._export_dir = export_dir
self._exporter = exporter

def end(self, session: tf.Session):
global_step = session.run(tf.train.get_global_step())
tf.logging.info('Exporting structure at step %d', global_step)
tensor_to_eval_dict = session.run(self._exporter.tensors)
self._exporter.populate_tensor_values(session.run(tensor_to_eval_dict))
self._exporter.create_file_and_save_alive_counts(self._export_dir,
global_step)
6 changes: 4 additions & 2 deletions morph_net/tools/structure_exporter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,15 @@ def test_create_file_and_save_alive_counts(self):
base_dir = os.path.join(FLAGS.test_tmpdir, 'ee')

self.exporter.populate_tensor_values(self.tensor_value_1)
self.exporter.create_file_and_save_alive_counts(base_dir, 19)
self.exporter.create_file_and_save_alive_counts(
os.path.join(base_dir, 'learned_structure'), 19)
self.assertAllEqual(
_alive_from_file('ee/learned_structure/alive_19'),
self.expected_alive_1)

self.exporter.populate_tensor_values(self.tensor_value_2)
self.exporter.create_file_and_save_alive_counts(base_dir, 1009)
self.exporter.create_file_and_save_alive_counts(
os.path.join(base_dir, 'learned_structure'), 1009)
self.assertAllEqual(
_alive_from_file('ee/learned_structure/alive_1009'),
self.expected_alive_2)
Expand Down

0 comments on commit 35282f6

Please sign in to comment.