Skip to content

Commit fb25044

Browse files
shraman-rcmn-robot
authored andcommitted
Make TPU summaries work with non-variable BatchNorm gammas (and other non-variable tensors).
PiperOrigin-RevId: 262224145
1 parent 41c1724 commit fb25044

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

morph_net/framework/tpu_util.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def maybe_convert_to_variable(tensor):
5858
return the original input tensor.
5959
"""
6060
op = tensor.op
61+
if is_on_cpu() and tensor in var_store:
62+
return var_store[tensor]
6163
if op.type != 'ReadVariableOp':
6264
# No need to convert.
6365
return tensor
@@ -102,7 +104,9 @@ def write_to_variable(tensor):
102104
use_resource=True)
103105
var_store[tensor] = variable
104106
with tf.control_dependencies([variable.assign(tensor)]):
105-
return tf.identity(tensor)
107+
tensor_copy = tf.identity(tensor)
108+
var_store[tensor_copy] = variable
109+
return tensor_copy
106110

107111

108112
def read_from_variable(tensor):
@@ -113,3 +117,8 @@ def read_from_variable(tensor):
113117
else:
114118
# Current read, but only works on TPU.
115119
return tensor
120+
121+
122+
def is_intermediate_var(v):
123+
"""Returns True if `v` was created by `write_to_variable` above."""
124+
return v in var_store.values()

0 commit comments

Comments
 (0)