diff --git a/README.md b/README.md index d5a75ec..726bca0 100644 --- a/README.md +++ b/README.md @@ -199,7 +199,8 @@ with tf.Session() as sess: _, structure_exporter_tensors = sess.run([train_op, exporter.tensors]) if (step % 1000 == 0): exporter.populate_tensor_values(structure_exporter_tensors) - exporter.create_file_and_save_alive_counts(train_dir, step) + exporter.create_file_and_save_alive_counts( + os.path.join(train_dir, 'learned_structure'), step) ``` ## Misc diff --git a/morph_net/tools/structure_exporter.py b/morph_net/tools/structure_exporter.py index 3d1422c..c8c00bf 100644 --- a/morph_net/tools/structure_exporter.py +++ b/morph_net/tools/structure_exporter.py @@ -129,22 +129,20 @@ def create_file_and_save_alive_counts(self, base_dir: Text, global_step: int) -> None: """Creates and updates files with alive counts. - Creates the directory `{base_dir}/learned_structure/` and saves the current - alive counts to: - `{base_dir}/learned_structure/{ALIVE_FILENAME}_{global_step}`. + Creates the directory `{base_dir}` and saves the current alive counts to: + `{base_dir}/{ALIVE_FILENAME}_{global_step}`. Args: base_dir: where to export the alive counts. global_step: current value of global step, used as a suffix in filename. """ current_filename = '%s_%s' % (ALIVE_FILENAME, global_step) - directory = os.path.join(base_dir, 'learned_structure') try: - tf.gfile.MakeDirs(directory) + tf.gfile.MakeDirs(base_dir) except tf.errors.OpError: # Probably already exists. If not, we'll see the error in the next line. pass - with tf.gfile.Open(os.path.join(directory, current_filename), 'w') as f: + with tf.gfile.Open(os.path.join(base_dir, current_filename), 'w') as f: self.save_alive_counts(f) # pytype: disable=wrong-arg-types @@ -196,3 +194,29 @@ def _compute_alive_counts( def format_structure(structure: Dict[Text, int]) -> Text: return json.dumps(structure, indent=2, sort_keys=True, default=str) + + +class StructureExporterHook(tf.train.SessionRunHook): + """Estimator hook for StructureExporter. + + Usage: + exporter = structure_exporter.StructureExporter( + network_regularizer.op_regularizer_manager) + structure_export_hook = structure_exporter.StructureExporterHook( + exporter, '/path/to/cns') + estimator_spec = tf.contrib.tpu.TPUEstimatorSpec( + ..., + training_hooks=[structure_export_hook]) + """ + + def __init__(self, exporter: StructureExporter, export_dir: Text): + self._export_dir = export_dir + self._exporter = exporter + + def end(self, session: tf.Session): + global_step = session.run(tf.train.get_global_step()) + tf.logging.info('Exporting structure at step %d', global_step) + tensor_to_eval_dict = session.run(self._exporter.tensors) + self._exporter.populate_tensor_values(session.run(tensor_to_eval_dict)) + self._exporter.create_file_and_save_alive_counts(self._export_dir, + global_step) diff --git a/morph_net/tools/structure_exporter_test.py b/morph_net/tools/structure_exporter_test.py index 00dd262..75875f9 100644 --- a/morph_net/tools/structure_exporter_test.py +++ b/morph_net/tools/structure_exporter_test.py @@ -129,13 +129,15 @@ def test_create_file_and_save_alive_counts(self): base_dir = os.path.join(FLAGS.test_tmpdir, 'ee') self.exporter.populate_tensor_values(self.tensor_value_1) - self.exporter.create_file_and_save_alive_counts(base_dir, 19) + self.exporter.create_file_and_save_alive_counts( + os.path.join(base_dir, 'learned_structure'), 19) self.assertAllEqual( _alive_from_file('ee/learned_structure/alive_19'), self.expected_alive_1) self.exporter.populate_tensor_values(self.tensor_value_2) - self.exporter.create_file_and_save_alive_counts(base_dir, 1009) + self.exporter.create_file_and_save_alive_counts( + os.path.join(base_dir, 'learned_structure'), 1009) self.assertAllEqual( _alive_from_file('ee/learned_structure/alive_1009'), self.expected_alive_2)