Skip to content

Commit 5a93798

Browse files
authored
chore: use xp.take_along_axis if Array API version >=2024.12 (#4406)
see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced the `xp_take_along_axis` function to utilize newer array API features when available, improving performance and functionality. - **Bug Fixes** - Maintained compatibility with older versions of the array API by providing a fallback implementation. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 2ce3276 commit 5a93798

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

deepmd/dpmodel/array_api.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
"""Utilities for the array API."""
33

44
import array_api_compat
5+
from packaging.version import (
6+
Version,
7+
)
58

69

710
def support_array_api(version: str) -> callable:
@@ -45,6 +48,9 @@ def xp_swapaxes(a, axis1, axis2):
4548

4649
def xp_take_along_axis(arr, indices, axis):
4750
xp = array_api_compat.array_namespace(arr)
51+
if Version(xp.__array_api_version__) >= Version("2024.12"):
52+
# see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39
53+
return xp.take_along_axis(arr, indices, axis=axis)
4854
arr = xp_swapaxes(arr, axis, -1)
4955
indices = xp_swapaxes(indices, axis, -1)
5056

0 commit comments

Comments
 (0)