You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
gather(): Address indices validation and other algorithm nits (#642)
* gather(): Address indices validation and other algorithm nits
* #486 points out that indices can't be validated at build-time,
and clamping behavior with an implementation note is given
instead.
* Fix a typo in the steps.
* Replace several map-like iterations over lists with list iteration.
Fixes#486
* Update index.bs
Co-authored-by: Ningxin Hu <[email protected]>
* Add note about negative indices. fixes#484
* Update index.bs
Co-authored-by: Dwayne Robinson <[email protected]>
* Update index.bs
Co-authored-by: Dwayne Robinson <[email protected]>
* Fix grammar glitch
---------
Co-authored-by: Ningxin Hu <[email protected]>
Co-authored-by: Dwayne Robinson <[email protected]>
- *input*: an {{MLOperand}}. The input N-D tensor from which the values are gathered.
2718
-
- *indices*: an {{MLOperand}}. The indices N-D tensor of the input values to gather. The values must be of type {{MLOperandDataType/"uint32"}} or {{MLOperandDataType/"int64"}}in the range [0, N-1]where N is the size of the input dimension indexed by *options.axis*.
2718
+
- *indices*: an {{MLOperand}}. The indices N-D tensor of the input values to gather. The values must be of type {{MLOperandDataType/"uint32"}} or {{MLOperandDataType/"int64"}}, and must be in the range -N (inclusive) to N (exclusive) where N is the size of the input dimension indexed by *options.axis*, and a negative index means indexing from the end of the dimension.
2719
2719
- *options*: an optional {{MLGatherOptions}}. The optional parameters of the operation.
2720
2720
2721
2721
**Returns:** an {{MLOperand}}. The output N-D tensor of [=MLOperand/rank=] equal to the [=MLOperand/rank=] of *input* + the [=MLOperand/rank=] of *indices* - 1.
2722
2722
</div>
2723
2723
2724
+
<div class="note">
2725
+
The {{MLGraphBuilder/gather(input, indices, options)/indices}} parameter to {{MLGraphBuilder/gather()}} can not be clamped to the allowed range when the graph is built because the inputs are not known until execution. Implementations can introduce {{MLGraphBuilder/clamp()}} in the compiled graph if the required clamping behavior is not provided by the underlying platform. Similarly, if the underlying platform does not support negative indices, the implementation can introduce operations in the compiled graph to transform a negative index from the end of the dimension into a positive index.
2726
+
</div>
2727
+
2724
2728
<details open algorithm>
2725
2729
<summary>
2726
2730
The <dfn method for=MLGraphBuilder>gather(|input|, |indices|, |options|)</dfn> method steps are:
2727
2731
</summary>
2728
-
1. If [=MLGraphBuilder/validating operand=] with [=this=] and any of |input| abd |indices| returns false, then [=exception/throw=] a {{TypeError}}.
2732
+
1. If [=MLGraphBuilder/validating operand=] with [=this=] and any of |input| and |indices| returns false, then [=exception/throw=] a {{TypeError}}.
2729
2733
1. If |indices|'s [=MLOperand/dataType=] is neither {{MLOperandDataType/"uint32"}} nor {{MLOperandDataType/"int64"}}, then [=exception/throw=] a {{TypeError}}.
2730
2734
1. Let |shapeInput| be |input|'s [=MLOperand/shape=] and |rankInput| be |shapeInput|'s [=MLOperand/rank=].
2731
2735
1. Let |shapeIndices| be |indices|'s [=MLOperand/shape=].
2732
2736
1. Let |axis| be |options|.{{MLGatherOptions/axis}}.
2733
-
1. Let |axisSize| be |input|'s [=MLOperand/shape=][|axis|]
2734
2737
1. If |axis| is greater than or equal to |rankInput|, then [=exception/throw=] a {{TypeError}}.
2735
-
1. [=map/For each=] |index| → |value| of |indices|:
2736
-
1. If |index| is greater than or equal to |axisSize|, then [=exception/throw=] a {{TypeError}}.
2737
2738
1. Let |dimCount| be zero.
2738
2739
1. Let |rankOutput| be zero.
2739
2740
1. Let |shapeOutput| be an empty list.
2740
-
1. [=map/For each=] |size| → |value| of |shapeInput|:
2741
+
1. [=list/For each=] |size| of |shapeInput|:
2741
2742
1. If |dimCount| is equal to |axis| then [=iteration/break=].
2742
2743
1. Set |shapeOutput|[|dimCount|] to |size|.
2743
2744
1. Increment |dimCount| by one.
2744
2745
1. Set |rankOutput| to |dimCount|.
2745
2746
1. Let |dimCount| be zero.
2746
-
1. [=map/For each=] |size| → |value| of |shapeIndices|:
2747
+
1. [=list/For each=] |size| of |shapeIndices|:
2747
2748
1. Set |shapeOutput|[|rankOutput| + |dimCount|] to |size|.
2748
2749
1. Increment |dimCount| by one.
2749
2750
1. Set |rankOutput| to |rankOutput| + |dimCount|.
2750
2751
1. Let |dimCount| be zero.
2751
-
1. [=map/For each=] |size| → |value| of |shapeInput|:
2752
+
1. [=list/For each=] |size| of |shapeInput|:
2752
2753
1. If |dimCount| is less than or equal to |axis| then [=iteration/continue=].
2753
2754
1. Set |shapeOutput|[|rankOutput| + |dimCount| - |axis| - 1] to |size|.
0 commit comments