@@ -645,72 +645,18 @@ breadth of relevant operator coverage isn't clear; however, the normalizations
645
645
that it does provide are worth using, so we'll make use of them throughout.
646
646
:END:
647
647
648
- [[grappler-normalize-function]] provides a simple means of
648
+ src_python[:eval never]{symbolic_pymc.tensorflow.graph.normalize_tf_graph} provides a simple means of
649
649
applying src_python[:eval never]{grappler}.
650
650
651
- #+NAME: grappler-normalize-function
652
- #+BEGIN_SRC python :exports code :results silent
653
- from tensorflow.core.protobuf import config_pb2
654
-
655
- from tensorflow.python.framework import ops
656
- from tensorflow.python.framework import importer
657
- from tensorflow.python.framework import meta_graph
658
-
659
- from tensorflow.python.grappler import cluster
660
- from tensorflow.python.grappler import tf_optimizer
661
-
662
-
663
- try:
664
- gcluster = cluster.Cluster()
665
- except tf.errors.UnavailableError:
666
- pass
667
-
668
- config = config_pb2.ConfigProto()
669
-
670
-
671
- def normalize_tf_graph(graph_output, new_graph=True, verbose=False):
672
- """Use grappler to normalize a graph.
673
-
674
- Arguments
675
- =========
676
- graph_output: Tensor
677
- A tensor we want to consider as "output" of a FuncGraph.
678
-
679
- Returns
680
- =======
681
- The simplified graph.
682
- """
683
- train_op = graph_output.graph.get_collection_ref(ops.GraphKeys.TRAIN_OP)
684
- train_op.clear()
685
- train_op.extend([graph_output])
686
-
687
- metagraph = meta_graph.create_meta_graph_def(graph=graph_output.graph)
688
-
689
- optimized_graphdef = tf_optimizer.OptimizeGraph(
690
- config, metagraph, verbose=verbose, cluster=gcluster)
691
-
692
- output_name = graph_output.name
693
-
694
- if new_graph:
695
- optimized_graph = ops.Graph()
696
- else:
697
- optimized_graph = ops.get_default_graph()
698
- del graph_output
699
-
700
- with optimized_graph.as_default():
701
- importer.import_graph_def(optimized_graphdef, name="")
702
-
703
- opt_graph_output = optimized_graph.get_tensor_by_name(output_name)
704
-
705
- return opt_graph_output
706
- #+END_SRC
707
-
708
- In [[grappler-normalize-function]] we
651
+ In [[grappler-normalize-test-graph]] we
709
652
run src_python[:eval never]{grappler} on the log-likelihood graph for a normal
710
653
random variable from [[tfp-normal-log-lik-graph]].
711
654
712
655
#+NAME: grappler-normalize-test-graph
713
656
#+BEGIN_SRC python :exports code :results silent :wrap
657
+ from symbolic_pymc.tensorflow.graph import normalize_tf_graph
658
+
659
+
714
660
normal_log_lik_opt = normalize_tf_graph(normal_log_lik)
715
661
#+END_SRC
716
662
0 commit comments