You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
gather(): Address indices validation and other algorithm nits
* webmachinelearning#486 points out that indices can't be validated at build-time,
and clamping behavior with an implementation note is given
instead.
* Fix a typo in the steps.
* Replace several map-like iterations over lists with list iteration.
Fixeswebmachinelearning#486
- *input*: an {{MLOperand}}. The input N-D tensor from which the values are gathered.
2717
-
- *indices*: an {{MLOperand}}. The indices N-D tensor of the input values to gather. The values must be of type {{MLOperandDataType/"uint32"}} or {{MLOperandDataType/"int64"}} in the range [0, N-1]where N is the size of the input dimension indexed by *options.axis*.
2717
+
- *indices*: an {{MLOperand}}. The indices N-D tensor of the input values to gather. The values must be of type {{MLOperandDataType/"uint32"}} or {{MLOperandDataType/"int64"}}, and are clamped to the range -N (inclusive) to N (exclusive) where N is the size of the input dimension indexed by *options.axis*, and a negative index means indexing from the end of the dimension.
2718
2718
- *options*: an optional {{MLGatherOptions}}. The optional parameters of the operation.
2719
2719
2720
2720
**Returns:** an {{MLOperand}}. The output N-D tensor of [=MLOperand/rank=] equal to the [=MLOperand/rank=] of *input* + the [=MLOperand/rank=] of *indices* - 1.
2721
2721
</div>
2722
2722
2723
+
<div class="note">
2724
+
The {{MLGraphBuilder/gather(input, indices, options)/indices}} parameter to {{MLGraphBuilder/gather()}} can not be clamped to the allowed range when the graph is built. Implementations can introduce clamping operands in the compiled graph if the required clamping behavior is not provided by the underlying platform.
2725
+
</div>
2726
+
2723
2727
<details open algorithm>
2724
2728
<summary>
2725
2729
The <dfn method for=MLGraphBuilder>gather(|input|, |indices|, |options|)</dfn> method steps are:
2726
2730
</summary>
2727
-
1. If [=MLGraphBuilder/validating operand=] with [=this=] and any of |input| abd |indices| returns false, then [=exception/throw=] a {{TypeError}}.
2731
+
1. If [=MLGraphBuilder/validating operand=] with [=this=] and any of |input| and |indices| returns false, then [=exception/throw=] a {{TypeError}}.
2728
2732
1. If |indices|'s [=MLOperand/dataType=] is neither {{MLOperandDataType/"uint32"}} nor {{MLOperandDataType/"int64"}}, then [=exception/throw=] a {{TypeError}}.
2729
2733
1. Let |shapeInput| be |input|'s [=MLOperand/shape=] and |rankInput| be |shapeInput|'s [=MLOperand/rank=].
2730
2734
1. Let |shapeIndices| be |indices|'s [=MLOperand/shape=].
2731
2735
1. Let |axis| be |options|.{{MLGatherOptions/axis}}.
2732
-
1. Let |axisSize| be |input|'s [=MLOperand/shape=][|axis|]
2733
2736
1. If |axis| is greater than or equal to |rankInput|, then [=exception/throw=] a {{TypeError}}.
2734
-
1. [=map/For each=] |index| → |value| of |indices|:
2735
-
1. If |index| is greater than or equal to |axisSize|, then [=exception/throw=] a {{TypeError}}.
2736
2737
1. Let |dimCount| be zero.
2737
2738
1. Let |rankOutput| be zero.
2738
2739
1. Let |shapeOutput| be an empty list.
2739
-
1. [=map/For each=] |size| → |value| of |shapeInput|:
2740
+
1. [=list/For each=] |size| of |shapeInput|:
2740
2741
1. If |dimCount| is equal to |axis| then [=iteration/break=].
2741
2742
1. Set |shapeOutput|[|dimCount|] to |size|.
2742
2743
1. Increment |dimCount| by one.
2743
2744
1. Set |rankOutput| to |dimCount|.
2744
2745
1. Let |dimCount| be zero.
2745
-
1. [=map/For each=] |size| → |value| of |shapeIndices|:
2746
+
1. [=list/For each=] |size| of |shapeIndices|:
2746
2747
1. Set |shapeOutput|[|rankOutput| + |dimCount|] to |size|.
2747
2748
1. Increment |dimCount| by one.
2748
2749
1. Set |rankOutput| to |rankOutput| + |dimCount|.
2749
2750
1. Let |dimCount| be zero.
2750
-
1. [=map/For each=] |size| → |value| of |shapeInput|:
2751
+
1. [=list/For each=] |size| of |shapeInput|:
2751
2752
1. If |dimCount| is less than or equal to |axis| then [=iteration/continue=].
2752
2753
1. Set |shapeOutput|[|rankOutput| + |dimCount| - |axis| - 1] to |size|.
0 commit comments