Skip to content

Commit 3f6d9b1

Browse files
committed
add tensorflow.gather_nd
1 parent 4e678f2 commit 3f6d9b1

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

Diff for: stubs/tensorflow/tensorflow/__init__.pyi

+12-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,15 @@ from tensorflow import (
2020
math as math,
2121
types as types,
2222
)
23-
from tensorflow._aliases import AnyArray, DTypeLike, ScalarTensorCompatible, ShapeLike, Slice, TensorCompatible
23+
from tensorflow._aliases import (
24+
AnyArray,
25+
DTypeLike,
26+
ScalarTensorCompatible,
27+
ShapeLike,
28+
Slice,
29+
TensorCompatible,
30+
UIntTensorCompatible,
31+
)
2432
from tensorflow.autodiff import GradientTape as GradientTape
2533
from tensorflow.core.protobuf import struct_pb2
2634
from tensorflow.dtypes import *
@@ -415,4 +423,7 @@ def shape(input: TensorCompatible, out_type: DTypeLike | None = None, name: str
415423
def where(
416424
condition: TensorCompatible, x: TensorCompatible | None = None, y: TensorCompatible | None = None, name: str | None = None
417425
) -> Tensor: ...
426+
def gather_nd(
427+
params: TensorCompatible, indices: UIntTensorCompatible, batch_dims: UIntTensorCompatible = 0, name: str | None = None
428+
) -> Tensor: ...
418429
def __getattr__(name: str) -> Incomplete: ...

0 commit comments

Comments
 (0)