Skip to content

Update the api_version check to allow 2022.12 and warn on 2021.12 #88

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sys
import math
import inspect
import warnings

def is_numpy_array(x):
# Avoid importing NumPy if it isn't already
Expand Down Expand Up @@ -77,8 +78,10 @@ def is_array_api_obj(x):
or hasattr(x, '__array_namespace__')

def _check_api_version(api_version):
if api_version is not None and api_version != '2021.12':
raise ValueError("Only the 2021.12 version of the array API specification is currently supported")
if api_version == '2021.12':
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
elif api_version is not None and api_version != '2022.12':
raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The text of the error isn't right, no? Because 2022.12 and 2021.12 are supported


Can we somehow re-do the if statements? I think something like the below would be easier to read and implement the behaviour Ralf suggested (no warning, just return subset)

if api_version is not in ('2021.12', '2022.12', None):
    raise ValueError(f"The specified Array API version '{api_version}' is not supported.")

WDYT?


def array_namespace(*xs, api_version=None, _use_compat=True):
"""
Expand Down