diff --git a/morph_net/framework/op_handler_util.py b/morph_net/framework/op_handler_util.py index f846dd2..45ccf92 100644 --- a/morph_net/framework/op_handler_util.py +++ b/morph_net/framework/op_handler_util.py @@ -31,6 +31,8 @@ def get_input_ops(op, op_reg_manager, whitelist_indices=None): Returns: List of tf.Operation that are the inputs to op. """ + if 'GumbelPrefix' in op.type: + return [] # Ignore scalar or 1-D constant inputs. def is_const(tensor): return tensor.op.type == 'Const' diff --git a/morph_net/tools/configurable_ops.py b/morph_net/tools/configurable_ops.py index 6534467..f4fe19a 100644 --- a/morph_net/tools/configurable_ops.py +++ b/morph_net/tools/configurable_ops.py @@ -159,6 +159,11 @@ def __init__(self, self._default_to_zero = fallback_rule == FallbackRule.zero self._strict = fallback_rule == FallbackRule.strict + @property + def parameterization(self): + """Returns the parameterization dict mapping op names to num_outputs.""" + return self._parameterization + @tf.contrib.framework.add_arg_scope def conv2d(self, *args, **kwargs): """Masks num_outputs from the function pointed to by 'conv2d'.