diff --git a/elasticsearch/_async/client/__init__.py b/elasticsearch/_async/client/__init__.py index 84d1f3d4d..587a4c3e4 100644 --- a/elasticsearch/_async/client/__init__.py +++ b/elasticsearch/_async/client/__init__.py @@ -89,6 +89,8 @@ client_node_configs, is_requests_http_auth, is_requests_node_class, + is_httpx_http_auth, + is_httpx_node_class, ) from .watcher import WatcherClient from .xpack import XPackClient @@ -244,6 +246,20 @@ def __init__( requests_session_auth = http_auth http_auth = DEFAULT + if is_httpx_http_auth(http_auth): + # If we're using custom httpx authentication + # then we need to alert the user that they also + # need to use 'node_class=httpxasync'. + if not is_httpx_node_class(node_class): + raise ValueError( + "Using a custom 'httpx.Auth' class for " + "'http_auth' must be used with node_class='httpxasync'" + ) + + # Reset 'http_auth' to DEFAULT so it's not consumed below. + requests_session_auth = http_auth + http_auth = DEFAULT + node_configs = client_node_configs( hosts, cloud_id=cloud_id, diff --git a/elasticsearch/_async/client/utils.py b/elasticsearch/_async/client/utils.py index 97918d9e4..aefc753de 100644 --- a/elasticsearch/_async/client/utils.py +++ b/elasticsearch/_async/client/utils.py @@ -15,6 +15,15 @@ # specific language governing permissions and limitations # under the License. +import inspect +from typing import ( + TYPE_CHECKING, + Any, +) + +from elastic_transport import HttpxAsyncHttpNode +from elastic_transport.client_utils import DEFAULT + from ..._sync.client.utils import ( _TYPE_ASYNC_SNIFF_CALLBACK, _TYPE_HOSTS, @@ -31,6 +40,33 @@ is_requests_node_class, ) + +def is_httpx_http_auth(http_auth: Any) -> bool: + """Detect if an http_auth value is a custom Httpx auth object""" + try: + from httpx import Auth + + return isinstance(http_auth, Auth) + except ImportError: + pass + return False + + +def is_httpx_node_class(node_class: Any) -> bool: + """Detect if 'HttpxAsyncHttpNode' would be used given the setting of 'node_class'""" + return ( + node_class is not None + and node_class is not DEFAULT + and ( + node_class == "httpxasync" + or ( + inspect.isclass(node_class) + and issubclass(node_class, HttpxAsyncHttpNode) + ) + ) + ) + + __all__ = [ "CLIENT_META_SERVICE", "_TYPE_ASYNC_SNIFF_CALLBACK", @@ -45,4 +81,6 @@ "_stability_warning", "is_requests_http_auth", "is_requests_node_class", + "is_httpx_http_auth", + "is_httpx_node_class", ]