From ae408cc8a75ed4bee649d319ebf89b8038ee1ced Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Wed, 5 Jul 2023 11:51:23 -0700 Subject: [PATCH 1/2] Fix optional kwarg in function signature This commit fixes the function signature for `take`. Namely, when an input array is one-dimensional, the `axis` kwarg is optional; when the array has more than one dimension, the `axis` kwarg is required. Unfortunately, the type signature cannot encode this duality, and we must rely on the specification text to clarify that the `axis` kwarg is required for arrays having ranks greater than unity. Ref: https://github.com/data-apis/array-api-compat/issues/34 --- src/array_api_stubs/_2022_12/indexing_functions.py | 2 +- src/array_api_stubs/_draft/indexing_functions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_stubs/_2022_12/indexing_functions.py b/src/array_api_stubs/_2022_12/indexing_functions.py index e76ce7484..117daf7cf 100644 --- a/src/array_api_stubs/_2022_12/indexing_functions.py +++ b/src/array_api_stubs/_2022_12/indexing_functions.py @@ -1,7 +1,7 @@ from ._types import Union, array -def take(x: array, indices: array, /, *, axis: int) -> array: +def take(x: array, indices: array, /, *, axis: Optional[int] = None) -> array: """ Returns elements of an array along an axis. diff --git a/src/array_api_stubs/_draft/indexing_functions.py b/src/array_api_stubs/_draft/indexing_functions.py index d57dc91e5..f46167018 100644 --- a/src/array_api_stubs/_draft/indexing_functions.py +++ b/src/array_api_stubs/_draft/indexing_functions.py @@ -1,7 +1,7 @@ from ._types import Union, array -def take(x: array, indices: array, /, *, axis: int) -> array: +def take(x: array, indices: array, /, *, axis: Optional[int] = None) -> array: """ Returns elements of an array along an axis. From 4d93cdcb0a2a797bda678677c671adf7f3ade8e0 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Wed, 5 Jul 2023 12:01:27 -0700 Subject: [PATCH 2/2] Fix imports --- src/array_api_stubs/_2022_12/indexing_functions.py | 2 +- src/array_api_stubs/_draft/indexing_functions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_stubs/_2022_12/indexing_functions.py b/src/array_api_stubs/_2022_12/indexing_functions.py index 117daf7cf..0cbad55ab 100644 --- a/src/array_api_stubs/_2022_12/indexing_functions.py +++ b/src/array_api_stubs/_2022_12/indexing_functions.py @@ -1,4 +1,4 @@ -from ._types import Union, array +from ._types import Union, Optional, array def take(x: array, indices: array, /, *, axis: Optional[int] = None) -> array: diff --git a/src/array_api_stubs/_draft/indexing_functions.py b/src/array_api_stubs/_draft/indexing_functions.py index f46167018..3b218fdac 100644 --- a/src/array_api_stubs/_draft/indexing_functions.py +++ b/src/array_api_stubs/_draft/indexing_functions.py @@ -1,4 +1,4 @@ -from ._types import Union, array +from ._types import Union, Optional, array def take(x: array, indices: array, /, *, axis: Optional[int] = None) -> array: