Skip to content

Commit f23b817

Browse files
[docs] Use the normalize_tf_graph now in symbolic_pymc
1 parent f0d4198 commit f23b817

4 files changed

+54
-310
lines changed

docs/source/tensorflow-radon-example.org

+5-59
Original file line numberDiff line numberDiff line change
@@ -645,72 +645,18 @@ breadth of relevant operator coverage isn't clear; however, the normalizations
645645
that it does provide are worth using, so we'll make use of them throughout.
646646
:END:
647647

648-
[[grappler-normalize-function]] provides a simple means of
648+
src_python[:eval never]{symbolic_pymc.tensorflow.graph.normalize_tf_graph} provides a simple means of
649649
applying src_python[:eval never]{grappler}.
650650

651-
#+NAME: grappler-normalize-function
652-
#+BEGIN_SRC python :exports code :results silent
653-
from tensorflow.core.protobuf import config_pb2
654-
655-
from tensorflow.python.framework import ops
656-
from tensorflow.python.framework import importer
657-
from tensorflow.python.framework import meta_graph
658-
659-
from tensorflow.python.grappler import cluster
660-
from tensorflow.python.grappler import tf_optimizer
661-
662-
663-
try:
664-
gcluster = cluster.Cluster()
665-
except tf.errors.UnavailableError:
666-
pass
667-
668-
config = config_pb2.ConfigProto()
669-
670-
671-
def normalize_tf_graph(graph_output, new_graph=True, verbose=False):
672-
"""Use grappler to normalize a graph.
673-
674-
Arguments
675-
=========
676-
graph_output: Tensor
677-
A tensor we want to consider as "output" of a FuncGraph.
678-
679-
Returns
680-
=======
681-
The simplified graph.
682-
"""
683-
train_op = graph_output.graph.get_collection_ref(ops.GraphKeys.TRAIN_OP)
684-
train_op.clear()
685-
train_op.extend([graph_output])
686-
687-
metagraph = meta_graph.create_meta_graph_def(graph=graph_output.graph)
688-
689-
optimized_graphdef = tf_optimizer.OptimizeGraph(
690-
config, metagraph, verbose=verbose, cluster=gcluster)
691-
692-
output_name = graph_output.name
693-
694-
if new_graph:
695-
optimized_graph = ops.Graph()
696-
else:
697-
optimized_graph = ops.get_default_graph()
698-
del graph_output
699-
700-
with optimized_graph.as_default():
701-
importer.import_graph_def(optimized_graphdef, name="")
702-
703-
opt_graph_output = optimized_graph.get_tensor_by_name(output_name)
704-
705-
return opt_graph_output
706-
#+END_SRC
707-
708-
In [[grappler-normalize-function]] we
651+
In [[grappler-normalize-test-graph]] we
709652
run src_python[:eval never]{grappler} on the log-likelihood graph for a normal
710653
random variable from [[tfp-normal-log-lik-graph]].
711654

712655
#+NAME: grappler-normalize-test-graph
713656
#+BEGIN_SRC python :exports code :results silent :wrap
657+
from symbolic_pymc.tensorflow.graph import normalize_tf_graph
658+
659+
714660
normal_log_lik_opt = normalize_tf_graph(normal_log_lik)
715661
#+END_SRC
716662

0 commit comments

Comments
 (0)