From ab2a914301e2ff3fe2d231153e4f8b785fc9996a Mon Sep 17 00:00:00 2001 From: PatReis Date: Sun, 5 Jan 2025 14:20:53 +0100 Subject: [PATCH] Simplify casting base layer. Inputs are ignored for ragged inputs. --- kgcnn/layers/casting.py | 52 +++++++++++++++-------------------------- 1 file changed, 19 insertions(+), 33 deletions(-) diff --git a/kgcnn/layers/casting.py b/kgcnn/layers/casting.py index 6502fd50..92ce9397 100644 --- a/kgcnn/layers/casting.py +++ b/kgcnn/layers/casting.py @@ -16,8 +16,12 @@ def _cat_one(t): class _CastBatchedDisjointBase(Layer): - def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dtype_index=None, - padded_disjoint: bool = False, uses_mask: bool = False, + def __init__(self, + reverse_indices: bool = False, + dtype_batch: str = "int64", + dtype_index=None, + padded_disjoint: bool = False, + uses_mask: bool = False, static_batched_node_output_shape: tuple = None, static_batched_edge_output_shape: tuple = None, remove_padded_disjoint_from_batched_output: bool = True, @@ -29,12 +33,17 @@ def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dt dtype_batch (str): Dtype for batch ID tensor. Default is 'int64'. dtype_index (str): Dtype for index tensor. Default is None. padded_disjoint (bool): Whether to keep padding in disjoint representation. Default is False. + Not used for ragged arguments. uses_mask (bool): Whether the padding is marked by a boolean mask or by a length tensor, counting the non-padded nodes from index 0. Default is False. + Not used for ragged arguments. static_batched_node_output_shape (tuple): Statical output shape of nodes. Default is None. + Not used for ragged arguments. static_batched_edge_output_shape (tuple): Statical output shape of edges. Default is None. + Not used for ragged arguments. remove_padded_disjoint_from_batched_output (bool): Whether to remove the first element on batched output in case of padding. + Not used for ragged arguments. """ super(_CastBatchedDisjointBase, self).__init__(**kwargs) self.reverse_indices = reverse_indices @@ -42,7 +51,8 @@ def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dt self.dtype_batch = dtype_batch self.uses_mask = uses_mask self.padded_disjoint = padded_disjoint - self.supports_jit = padded_disjoint + if padded_disjoint: + self.supports_jit = True self.static_batched_node_output_shape = static_batched_node_output_shape self.static_batched_edge_output_shape = static_batched_edge_output_shape self.remove_padded_disjoint_from_batched_output = remove_padded_disjoint_from_batched_output @@ -536,31 +546,7 @@ def call(self, inputs: list, **kwargs): CastBatchedGraphStateToDisjoint.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__ -class _CastRaggedToDisjointBase(Layer): - - def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dtype_index=None, **kwargs): - r"""Initialize layer. - - Args: - reverse_indices (bool): Whether to reverse index order. Default is False. - dtype_batch (str): Dtype for batch ID tensor. Default is 'int64'. - dtype_index (str): Dtype for index tensor. Default is None. - """ - super(_CastRaggedToDisjointBase, self).__init__(**kwargs) - self.reverse_indices = reverse_indices - self.dtype_index = dtype_index - self.dtype_batch = dtype_batch - # self.supports_jit = False - - def get_config(self): - """Get config dictionary for this layer.""" - config = super(_CastRaggedToDisjointBase, self).get_config() - config.update({"reverse_indices": self.reverse_indices, "dtype_batch": self.dtype_batch, - "dtype_index": self.dtype_index}) - return config - - -class CastRaggedAttributesToDisjoint(_CastRaggedToDisjointBase): +class CastRaggedAttributesToDisjoint(_CastBatchedDisjointBase): def __init__(self, **kwargs): super(CastRaggedAttributesToDisjoint, self).__init__(**kwargs) @@ -598,10 +584,10 @@ def call(self, inputs, **kwargs): return decompose_ragged_tensor(inputs, batch_dtype=self.dtype_batch) -CastRaggedAttributesToDisjoint.__init__.__doc__ = _CastRaggedToDisjointBase.__init__.__doc__ +CastRaggedAttributesToDisjoint.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__ -class CastRaggedIndicesToDisjoint(_CastRaggedToDisjointBase): +class CastRaggedIndicesToDisjoint(_CastBatchedDisjointBase): def __init__(self, **kwargs): super(CastRaggedIndicesToDisjoint, self).__init__(**kwargs) @@ -685,10 +671,10 @@ def call(self, inputs, **kwargs): return [nodes_flatten, disjoint_indices, graph_id_node, graph_id_edge, node_id, edge_id, node_len, edge_len] -CastRaggedIndicesToDisjoint.__init__.__doc__ = _CastRaggedToDisjointBase.__init__.__doc__ +CastRaggedIndicesToDisjoint.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__ -class CastDisjointToRaggedAttributes(_CastRaggedToDisjointBase): +class CastDisjointToRaggedAttributes(_CastBatchedDisjointBase): def __init__(self, **kwargs): super(CastDisjointToRaggedAttributes, self).__init__(**kwargs) @@ -713,4 +699,4 @@ def call(self, inputs, **kwargs): raise NotImplementedError() -CastDisjointToRaggedAttributes.__init__.__doc__ = CastDisjointToRaggedAttributes.__init__.__doc__ \ No newline at end of file +CastDisjointToRaggedAttributes.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__ \ No newline at end of file