55
66import torch
77
8- from executorch .backends .nxp .backend .data_format import DataFormat
9- from executorch .backends .nxp .backend .ir .converter .conversion import translator
108from executorch .backends .nxp .backend .ir .converter .conversion .common import OpsList
11- from executorch .backends .nxp .backend .ir .converter .conversion .translator import (
12- create_channels_last_to_channels_first_permutation ,
13- )
149from executorch .backends .nxp .backend .ir .converter .node_converter import (
1510 CustomDelegationOptions ,
1611 NodeConverter ,
1712)
1813from executorch .backends .nxp .backend .ir .converter .node_converters .shared .reduce_utils import (
1914 convert_axes_from_attribute ,
15+ get_dim_and_handle_io_formats ,
16+ get_reduce_node_attrs ,
2017)
2118from executorch .backends .nxp .backend .ir .tflite_generator .builtin_options import (
2219 mean_options ,
@@ -38,7 +35,7 @@ def supports_partitioning_result(
3835 neutron_target_spec : NeutronTargetSpec ,
3936 parameters_mapping : dict [str , Parameter ],
4037 ) -> bool :
41- dim , keepdim = MeanDimConverter . _get_attrs (node )
38+ dim , keepdim = get_reduce_node_attrs (node )
4239 input_shape = node .args [0 ].meta ["val" ].shape
4340
4441 is_alone_in_partition = cls .is_node_alone_in_partition (node , partition_list )
@@ -85,140 +82,6 @@ def _is_supported_in_IR(
8582
8683 return True
8784
88- @staticmethod
89- def _to_pos_dim (d : int , rank : int ):
90- return d + rank if d < 0 else d
91-
92- @staticmethod
93- def _normalize_dim (dim : list [int ], rank : int ) -> list [int ]:
94- # convert negative index to positive
95- return [MeanDimConverter ._to_pos_dim (d , rank ) for d in dim ]
96-
97- @staticmethod
98- def _normalize_and_to_channel_last_dim (dim : list [int ], rank : int ) -> list [int ]:
99- # convert negative index to positive
100- dim = MeanDimConverter ._normalize_dim (dim , rank )
101-
102- perm = create_channels_last_to_channels_first_permutation (rank , True )
103- dim = [perm [d ] for d in dim ]
104-
105- # noinspection PyTypeChecker
106- return dim
107-
108- @staticmethod
109- def _get_attrs (node : Node ) -> tuple [list [int ], bool ]:
110- dim = node .args [1 ]
111- keepdim = node .args [2 ] if len (node .args ) >= 3 else False
112- return dim , keepdim
113-
114- def _get_dim_and_handle_io_formats (
115- self , ops : OpsList , dim : list [int ], keep_dim : bool
116- ):
117- t_op = ops .middle_op
118- x = t_op .tmp_inputs [0 ]
119- y = t_op .tmp_outputs [0 ]
120-
121- channels_last_input = x .tensor_format .is_channels_last ()
122- channels_last_output = y .tensor_format .is_channels_last ()
123- formatless_input = not channels_last_input
124- formatless_output = not channels_last_output
125-
126- dim = self ._normalize_dim (dim , x .rank )
127-
128- if keep_dim :
129- # The rank is preserved and the io formats should always be equal.
130- assert (
131- x .tensor_format == y .tensor_format
132- ), "NXP backend: There is a bug in `mean.dim` format inference."
133-
134- # Just adjust the dim to match the input format.
135- if channels_last_input :
136- dim = self ._normalize_and_to_channel_last_dim (dim , x .rank )
137-
138- else :
139- # `keep_dim = False`, so the output rank != input rank, and the operator changes the tensor format.
140-
141- if channels_last_input and formatless_output :
142- if 1 in dim :
143- # If we are reducing over the channels, the channels dimension gets removed and the output ends up
144- # exactly equal in channels last and channels first, regardless of which other dimensions are
145- # removed. Therefore, we can just adjust the `dim` and we don't need to insert any `Transpose` ops.
146- dim = self ._normalize_and_to_channel_last_dim (dim , x .rank )
147- elif all (spatial_dim in dim for spatial_dim in range (2 , x .rank )):
148- # All spatial dims are reduced, leaving only batch and channels (both optionally). So the result is
149- # equal in channels first and channels last as long as we adjust the `dim` to match a channels last
150- # input (similarly to the case above).
151- dim = self ._normalize_and_to_channel_last_dim (dim , x .rank )
152- else :
153- # If the channels dimension is preserved, we must transpose the input to channels first (to match
154- # the edge model) and we must keep the `dim` unchanged (referencing channels first dimensions).
155- # Otherwise, the output would not match the input.
156- to_channels_first_perm = (
157- translator .create_channels_last_to_channels_first_permutation (
158- x .rank
159- )
160- )
161- ops .add_pre (
162- self .builder .create_transpose_operator_before (
163- t_op , 0 , to_channels_first_perm
164- )
165- )
166- t_op .tmp_inputs [0 ].tensor_format = DataFormat .CHANNELS_FIRST
167-
168- elif formatless_input and channels_last_output :
169- # We need apply the `mean` with the original `dim`, which will produce a channels first output. Then,
170- # we need to append a `Transpose` operator to make the output channels last.
171- to_channels_last_perm = (
172- translator .create_channels_first_to_channels_last_permutation (
173- y .rank , True
174- )
175- )
176- ops .add_post (
177- self .builder .create_transpose_operator_after (
178- t_op , 0 , to_channels_last_perm
179- )
180- )
181- t_op .tmp_outputs [0 ].tensor_format = DataFormat .CHANNELS_FIRST
182-
183- elif formatless_input and formatless_output :
184- # No action needed.
185- pass
186-
187- else : # channels_last_input and channels_last_output
188- # This case cannot currently occur, as it would require the case:
189- # channels last 4D -> mean -> channels_last 3D
190- # which cannot currently happen as the 3D conv/pooling/... is supported by adding `view_copy` nodes in
191- # the edge dialect and converting the node to 4D, and the `view_copy` nodes prevent the propagation of
192- # the format to the `mean.dim` output.
193- # Therefore, the implementation cannot be tested. But from experience with other operators, it should
194- # work correctly. We just need to add 2 `Transpose` ops to make the IO channels first, and keep the
195- # `dim` unchanged.
196- to_channels_first_perm = (
197- translator .create_channels_last_to_channels_first_permutation (
198- x .rank
199- )
200- )
201- ops .add_pre (
202- self .builder .create_transpose_operator_before (
203- t_op , 0 , to_channels_first_perm
204- )
205- )
206- t_op .tmp_inputs [0 ].tensor_format = DataFormat .CHANNELS_FIRST
207-
208- to_channels_last_perm = (
209- translator .create_channels_first_to_channels_last_permutation (
210- y .rank , True
211- )
212- )
213- ops .add_post (
214- self .builder .create_transpose_operator_after (
215- t_op , 0 , to_channels_last_perm
216- )
217- )
218- t_op .tmp_outputs [0 ].tensor_format = DataFormat .CHANNELS_FIRST
219-
220- return dim
221-
22285 def convert (self , node : Node ):
22386 """Convert the 'mean.dim' operator to NeutronIR 'Mean'.
22487 The ExecuTorch schema is:
@@ -232,13 +95,13 @@ def convert(self, node: Node):
23295 """
23396 self .assert_convertible (node )
23497
235- dim , keepdim = self . _get_attrs (node )
98+ dim , keepdim = get_reduce_node_attrs (node )
23699
237100 t_op = self ._create_tflite_op_with_io_tensors (node )
238101 t_op .builtin_options = mean_options .Mean (keepdim )
239102
240103 ops = OpsList (middle_op = t_op )
241- dim = self ._get_dim_and_handle_io_formats ( ops , dim , keepdim )
104+ dim = get_dim_and_handle_io_formats ( self .builder , ops , dim , keepdim )
242105
243106 convert_axes_from_attribute (t_op , self .builder , dim )
244107 self .builder .append_operators (ops .flatten ())
0 commit comments