diff --git a/elasticsearch/_async/client/__init__.py b/elasticsearch/_async/client/__init__.py index f88bb0190..e43e1d6a7 100644 --- a/elasticsearch/_async/client/__init__.py +++ b/elasticsearch/_async/client/__init__.py @@ -20,6 +20,7 @@ import typing as t from elastic_transport import ( + ApiResponse, AsyncTransport, BaseNode, BinaryApiResponse, @@ -97,7 +98,7 @@ SelfType = t.TypeVar("SelfType", bound="AsyncElasticsearch") -class AsyncElasticsearch(BaseClient): +class AsyncElasticsearch: """ Elasticsearch low-level client. Provides a straightforward mapping from Python to Elasticsearch REST APIs. @@ -295,72 +296,103 @@ def __init__( **transport_kwargs, ) - super().__init__(_transport) + self._base_client = BaseClient(_transport) # These are set per-request so are stored separately. - self._request_timeout = request_timeout - self._max_retries = max_retries - self._retry_on_timeout = retry_on_timeout + self._base_client._request_timeout = request_timeout + self._base_client._max_retries = max_retries + self._base_client._retry_on_timeout = retry_on_timeout if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) - self._retry_on_status = retry_on_status + self._base_client._retry_on_status = retry_on_status else: - super().__init__(_transport) + self._base_client = BaseClient(_transport) if headers is not DEFAULT and headers is not None: - self._headers.update(headers) + self._base_client._headers.update(headers) if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap] - self._headers["x-opaque-id"] = opaque_id - self._headers = resolve_auth_headers( - self._headers, + self._base_client._headers["x-opaque-id"] = opaque_id + self._base_client._headers = resolve_auth_headers( + self._base_client._headers, api_key=api_key, basic_auth=basic_auth, bearer_auth=bearer_auth, ) # namespaced clients for compatibility with API names - self.async_search = AsyncSearchClient(self) - self.autoscaling = AutoscalingClient(self) - self.cat = CatClient(self) - self.cluster = ClusterClient(self) - self.connector = ConnectorClient(self) - self.fleet = FleetClient(self) - self.features = FeaturesClient(self) - self.indices = IndicesClient(self) - self.inference = InferenceClient(self) - self.ingest = IngestClient(self) - self.nodes = NodesClient(self) - self.snapshot = SnapshotClient(self) - self.tasks = TasksClient(self) - - self.xpack = XPackClient(self) - self.ccr = CcrClient(self) - self.dangling_indices = DanglingIndicesClient(self) - self.enrich = EnrichClient(self) - self.eql = EqlClient(self) - self.esql = EsqlClient(self) - self.graph = GraphClient(self) - self.ilm = IlmClient(self) - self.license = LicenseClient(self) - self.logstash = LogstashClient(self) - self.migration = MigrationClient(self) - self.ml = MlClient(self) - self.monitoring = MonitoringClient(self) - self.query_rules = QueryRulesClient(self) - self.rollup = RollupClient(self) - self.search_application = SearchApplicationClient(self) - self.searchable_snapshots = SearchableSnapshotsClient(self) - self.security = SecurityClient(self) - self.slm = SlmClient(self) - self.simulate = SimulateClient(self) - self.shutdown = ShutdownClient(self) - self.sql = SqlClient(self) - self.ssl = SslClient(self) - self.synonyms = SynonymsClient(self) - self.text_structure = TextStructureClient(self) - self.transform = TransformClient(self) - self.watcher = WatcherClient(self) + self.async_search = AsyncSearchClient(self._base_client) + self.autoscaling = AutoscalingClient(self._base_client) + self.cat = CatClient(self._base_client) + self.cluster = ClusterClient(self._base_client) + self.connector = ConnectorClient(self._base_client) + self.fleet = FleetClient(self._base_client) + self.features = FeaturesClient(self._base_client) + self.indices = IndicesClient(self._base_client) + self.inference = InferenceClient(self._base_client) + self.ingest = IngestClient(self._base_client) + self.nodes = NodesClient(self._base_client) + self.snapshot = SnapshotClient(self._base_client) + self.tasks = TasksClient(self._base_client) + + self.xpack = XPackClient(self._base_client) + self.ccr = CcrClient(self._base_client) + self.dangling_indices = DanglingIndicesClient(self._base_client) + self.enrich = EnrichClient(self._base_client) + self.eql = EqlClient(self._base_client) + self.esql = EsqlClient(self._base_client) + self.graph = GraphClient(self._base_client) + self.ilm = IlmClient(self._base_client) + self.license = LicenseClient(self._base_client) + self.logstash = LogstashClient(self._base_client) + self.migration = MigrationClient(self._base_client) + self.ml = MlClient(self._base_client) + self.monitoring = MonitoringClient(self._base_client) + self.query_rules = QueryRulesClient(self._base_client) + self.rollup = RollupClient(self._base_client) + self.search_application = SearchApplicationClient(self._base_client) + self.searchable_snapshots = SearchableSnapshotsClient(self._base_client) + self.security = SecurityClient(self._base_client) + self.slm = SlmClient(self._base_client) + self.simulate = SimulateClient(self._base_client) + self.shutdown = ShutdownClient(self._base_client) + self.sql = SqlClient(self._base_client) + self.ssl = SslClient(self._base_client) + self.synonyms = SynonymsClient(self._base_client) + self.text_structure = TextStructureClient(self._base_client) + self.transform = TransformClient(self._base_client) + self.watcher = WatcherClient(self._base_client) + + @property + def transport(self) -> AsyncTransport: + return self._base_client._transport + + async def perform_request( + self, + method: str, + path: str, + *, + params: t.Optional[t.Mapping[str, t.Any]] = None, + headers: t.Optional[t.Mapping[str, str]] = None, + body: t.Optional[t.Any] = None, + endpoint_id: t.Optional[str] = None, + path_parts: t.Optional[t.Mapping[str, t.Any]] = None, + ) -> ApiResponse[t.Any]: + with self._base_client._otel.span( + method, + endpoint_id=endpoint_id, + path_parts=path_parts or {}, + ) as otel_span: + response = await self._base_client._perform_request( + method, + path, + params=params, + headers=headers, + body=body, + otel_span=otel_span, + ) + otel_span.set_elastic_cloud_metadata(response.meta.headers) + return response def __repr__(self) -> str: try: @@ -413,44 +445,44 @@ def options( resolved_headers["x-opaque-id"] = resolved_opaque_id if resolved_headers: - new_headers = self._headers.copy() + new_headers = self._base_client._headers.copy() new_headers.update(resolved_headers) - client._headers = new_headers + client._base_client._headers = new_headers else: - client._headers = self._headers.copy() + client._base_client._headers = self._base_client._headers.copy() if request_timeout is not DEFAULT: - client._request_timeout = request_timeout + client._base_client._request_timeout = request_timeout else: - client._request_timeout = self._request_timeout + client._base_client._request_timeout = self._base_client._request_timeout if ignore_status is not DEFAULT: if isinstance(ignore_status, int): ignore_status = (ignore_status,) - client._ignore_status = ignore_status + client._base_client._ignore_status = ignore_status else: - client._ignore_status = self._ignore_status + client._base_client._ignore_status = self._base_client._ignore_status if max_retries is not DEFAULT: if not isinstance(max_retries, int): raise TypeError("'max_retries' must be of type 'int'") - client._max_retries = max_retries + client._base_client._max_retries = max_retries else: - client._max_retries = self._max_retries + client._base_client._max_retries = self._base_client._max_retries if retry_on_status is not DEFAULT: if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) - client._retry_on_status = retry_on_status + client._base_client._retry_on_status = retry_on_status else: - client._retry_on_status = self._retry_on_status + client._base_client._retry_on_status = self._base_client._retry_on_status if retry_on_timeout is not DEFAULT: if not isinstance(retry_on_timeout, bool): raise TypeError("'retry_on_timeout' must be of type 'bool'") - client._retry_on_timeout = retry_on_timeout + client._base_client._retry_on_timeout = retry_on_timeout else: - client._retry_on_timeout = self._retry_on_timeout + client._base_client._retry_on_timeout = self._base_client._retry_on_timeout return client diff --git a/elasticsearch/_async/client/_base.py b/elasticsearch/_async/client/_base.py index ed61f7bc4..8e6e5fbd9 100644 --- a/elasticsearch/_async/client/_base.py +++ b/elasticsearch/_async/client/_base.py @@ -217,42 +217,11 @@ def __init__(self, _transport: AsyncTransport) -> None: self._request_timeout: Union[DefaultType, Optional[float]] = DEFAULT self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT self._max_retries: Union[DefaultType, int] = DEFAULT - self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT self._retry_on_status: Union[DefaultType, Collection[int]] = DEFAULT + self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT self._verified_elasticsearch = False self._otel = OpenTelemetry() - @property - def transport(self) -> AsyncTransport: - return self._transport - - async def perform_request( - self, - method: str, - path: str, - *, - params: Optional[Mapping[str, Any]] = None, - headers: Optional[Mapping[str, str]] = None, - body: Optional[Any] = None, - endpoint_id: Optional[str] = None, - path_parts: Optional[Mapping[str, Any]] = None, - ) -> ApiResponse[Any]: - with self._otel.span( - method, - endpoint_id=endpoint_id, - path_parts=path_parts or {}, - ) as otel_span: - response = await self._perform_request( - method, - path, - params=params, - headers=headers, - body=body, - otel_span=otel_span, - ) - otel_span.set_elastic_cloud_metadata(response.meta.headers) - return response - async def _perform_request( self, method: str, @@ -287,7 +256,7 @@ def mimetype_header_to_compat(header: str) -> None: else: target = path - meta, resp_body = await self.transport.perform_request( + meta, resp_body = await self._transport.perform_request( method, target, headers=request_headers, @@ -376,10 +345,9 @@ def mimetype_header_to_compat(header: str) -> None: return response -class NamespacedClient(BaseClient): - def __init__(self, client: "BaseClient") -> None: - self._client = client - super().__init__(self._client.transport) +class NamespacedClient: + def __init__(self, client: BaseClient) -> None: + self._base_client = client async def perform_request( self, @@ -392,14 +360,18 @@ async def perform_request( endpoint_id: Optional[str] = None, path_parts: Optional[Mapping[str, Any]] = None, ) -> ApiResponse[Any]: - # Use the internal clients .perform_request() implementation - # so we take advantage of their transport options. - return await self._client.perform_request( + with self._base_client._otel.span( method, - path, - params=params, - headers=headers, - body=body, endpoint_id=endpoint_id, - path_parts=path_parts, - ) + path_parts=path_parts or {}, + ) as otel_span: + response = await self._base_client._perform_request( + method, + path, + params=params, + headers=headers, + body=body, + otel_span=otel_span, + ) + otel_span.set_elastic_cloud_metadata(response.meta.headers) + return response diff --git a/elasticsearch/_async/helpers.py b/elasticsearch/_async/helpers.py index 7acc41ecd..51b15e0ba 100644 --- a/elasticsearch/_async/helpers.py +++ b/elasticsearch/_async/helpers.py @@ -216,7 +216,7 @@ async def async_streaming_bulk( """ client = client.options() - client._client_meta = (("h", "bp"),) + client._base_client._client_meta = (("h", "bp"),) if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) @@ -429,7 +429,7 @@ def pop_transport_kwargs(kw: MutableMapping[str, Any]) -> MutableMapping[str, An client = client.options( request_timeout=request_timeout, **pop_transport_kwargs(kwargs) ) - client._client_meta = (("h", "s"),) + client._base_client._client_meta = (("h", "s"),) # Setting query={"from": ...} would make 'from' be used # as a keyword argument instead of 'from_'. We handle that here. diff --git a/elasticsearch/_sync/client/__init__.py b/elasticsearch/_sync/client/__init__.py index b39cbae26..0c6db746e 100644 --- a/elasticsearch/_sync/client/__init__.py +++ b/elasticsearch/_sync/client/__init__.py @@ -20,6 +20,7 @@ import typing as t from elastic_transport import ( + ApiResponse, BaseNode, BinaryApiResponse, HeadApiResponse, @@ -97,7 +98,7 @@ SelfType = t.TypeVar("SelfType", bound="Elasticsearch") -class Elasticsearch(BaseClient): +class Elasticsearch: """ Elasticsearch low-level client. Provides a straightforward mapping from Python to Elasticsearch REST APIs. @@ -295,72 +296,103 @@ def __init__( **transport_kwargs, ) - super().__init__(_transport) + self._base_client = BaseClient(_transport) # These are set per-request so are stored separately. - self._request_timeout = request_timeout - self._max_retries = max_retries - self._retry_on_timeout = retry_on_timeout + self._base_client._request_timeout = request_timeout + self._base_client._max_retries = max_retries + self._base_client._retry_on_timeout = retry_on_timeout if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) - self._retry_on_status = retry_on_status + self._base_client._retry_on_status = retry_on_status else: - super().__init__(_transport) + self._base_client = BaseClient(_transport) if headers is not DEFAULT and headers is not None: - self._headers.update(headers) + self._base_client._headers.update(headers) if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap] - self._headers["x-opaque-id"] = opaque_id - self._headers = resolve_auth_headers( - self._headers, + self._base_client._headers["x-opaque-id"] = opaque_id + self._base_client._headers = resolve_auth_headers( + self._base_client._headers, api_key=api_key, basic_auth=basic_auth, bearer_auth=bearer_auth, ) # namespaced clients for compatibility with API names - self.async_search = AsyncSearchClient(self) - self.autoscaling = AutoscalingClient(self) - self.cat = CatClient(self) - self.cluster = ClusterClient(self) - self.connector = ConnectorClient(self) - self.fleet = FleetClient(self) - self.features = FeaturesClient(self) - self.indices = IndicesClient(self) - self.inference = InferenceClient(self) - self.ingest = IngestClient(self) - self.nodes = NodesClient(self) - self.snapshot = SnapshotClient(self) - self.tasks = TasksClient(self) - - self.xpack = XPackClient(self) - self.ccr = CcrClient(self) - self.dangling_indices = DanglingIndicesClient(self) - self.enrich = EnrichClient(self) - self.eql = EqlClient(self) - self.esql = EsqlClient(self) - self.graph = GraphClient(self) - self.ilm = IlmClient(self) - self.license = LicenseClient(self) - self.logstash = LogstashClient(self) - self.migration = MigrationClient(self) - self.ml = MlClient(self) - self.monitoring = MonitoringClient(self) - self.query_rules = QueryRulesClient(self) - self.rollup = RollupClient(self) - self.search_application = SearchApplicationClient(self) - self.searchable_snapshots = SearchableSnapshotsClient(self) - self.security = SecurityClient(self) - self.slm = SlmClient(self) - self.simulate = SimulateClient(self) - self.shutdown = ShutdownClient(self) - self.sql = SqlClient(self) - self.ssl = SslClient(self) - self.synonyms = SynonymsClient(self) - self.text_structure = TextStructureClient(self) - self.transform = TransformClient(self) - self.watcher = WatcherClient(self) + self.async_search = AsyncSearchClient(self._base_client) + self.autoscaling = AutoscalingClient(self._base_client) + self.cat = CatClient(self._base_client) + self.cluster = ClusterClient(self._base_client) + self.connector = ConnectorClient(self._base_client) + self.fleet = FleetClient(self._base_client) + self.features = FeaturesClient(self._base_client) + self.indices = IndicesClient(self._base_client) + self.inference = InferenceClient(self._base_client) + self.ingest = IngestClient(self._base_client) + self.nodes = NodesClient(self._base_client) + self.snapshot = SnapshotClient(self._base_client) + self.tasks = TasksClient(self._base_client) + + self.xpack = XPackClient(self._base_client) + self.ccr = CcrClient(self._base_client) + self.dangling_indices = DanglingIndicesClient(self._base_client) + self.enrich = EnrichClient(self._base_client) + self.eql = EqlClient(self._base_client) + self.esql = EsqlClient(self._base_client) + self.graph = GraphClient(self._base_client) + self.ilm = IlmClient(self._base_client) + self.license = LicenseClient(self._base_client) + self.logstash = LogstashClient(self._base_client) + self.migration = MigrationClient(self._base_client) + self.ml = MlClient(self._base_client) + self.monitoring = MonitoringClient(self._base_client) + self.query_rules = QueryRulesClient(self._base_client) + self.rollup = RollupClient(self._base_client) + self.search_application = SearchApplicationClient(self._base_client) + self.searchable_snapshots = SearchableSnapshotsClient(self._base_client) + self.security = SecurityClient(self._base_client) + self.slm = SlmClient(self._base_client) + self.simulate = SimulateClient(self._base_client) + self.shutdown = ShutdownClient(self._base_client) + self.sql = SqlClient(self._base_client) + self.ssl = SslClient(self._base_client) + self.synonyms = SynonymsClient(self._base_client) + self.text_structure = TextStructureClient(self._base_client) + self.transform = TransformClient(self._base_client) + self.watcher = WatcherClient(self._base_client) + + @property + def transport(self) -> Transport: + return self._base_client._transport + + def perform_request( + self, + method: str, + path: str, + *, + params: t.Optional[t.Mapping[str, t.Any]] = None, + headers: t.Optional[t.Mapping[str, str]] = None, + body: t.Optional[t.Any] = None, + endpoint_id: t.Optional[str] = None, + path_parts: t.Optional[t.Mapping[str, t.Any]] = None, + ) -> ApiResponse[t.Any]: + with self._base_client._otel.span( + method, + endpoint_id=endpoint_id, + path_parts=path_parts or {}, + ) as otel_span: + response = self._base_client._perform_request( + method, + path, + params=params, + headers=headers, + body=body, + otel_span=otel_span, + ) + otel_span.set_elastic_cloud_metadata(response.meta.headers) + return response def __repr__(self) -> str: try: @@ -413,44 +445,44 @@ def options( resolved_headers["x-opaque-id"] = resolved_opaque_id if resolved_headers: - new_headers = self._headers.copy() + new_headers = self._base_client._headers.copy() new_headers.update(resolved_headers) - client._headers = new_headers + client._base_client._headers = new_headers else: - client._headers = self._headers.copy() + client._base_client._headers = self._base_client._headers.copy() if request_timeout is not DEFAULT: - client._request_timeout = request_timeout + client._base_client._request_timeout = request_timeout else: - client._request_timeout = self._request_timeout + client._base_client._request_timeout = self._base_client._request_timeout if ignore_status is not DEFAULT: if isinstance(ignore_status, int): ignore_status = (ignore_status,) - client._ignore_status = ignore_status + client._base_client._ignore_status = ignore_status else: - client._ignore_status = self._ignore_status + client._base_client._ignore_status = self._base_client._ignore_status if max_retries is not DEFAULT: if not isinstance(max_retries, int): raise TypeError("'max_retries' must be of type 'int'") - client._max_retries = max_retries + client._base_client._max_retries = max_retries else: - client._max_retries = self._max_retries + client._base_client._max_retries = self._base_client._max_retries if retry_on_status is not DEFAULT: if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) - client._retry_on_status = retry_on_status + client._base_client._retry_on_status = retry_on_status else: - client._retry_on_status = self._retry_on_status + client._base_client._retry_on_status = self._base_client._retry_on_status if retry_on_timeout is not DEFAULT: if not isinstance(retry_on_timeout, bool): raise TypeError("'retry_on_timeout' must be of type 'bool'") - client._retry_on_timeout = retry_on_timeout + client._base_client._retry_on_timeout = retry_on_timeout else: - client._retry_on_timeout = self._retry_on_timeout + client._base_client._retry_on_timeout = self._base_client._retry_on_timeout return client diff --git a/elasticsearch/_sync/client/_base.py b/elasticsearch/_sync/client/_base.py index 7d4617f74..02542eb43 100644 --- a/elasticsearch/_sync/client/_base.py +++ b/elasticsearch/_sync/client/_base.py @@ -217,42 +217,11 @@ def __init__(self, _transport: Transport) -> None: self._request_timeout: Union[DefaultType, Optional[float]] = DEFAULT self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT self._max_retries: Union[DefaultType, int] = DEFAULT - self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT self._retry_on_status: Union[DefaultType, Collection[int]] = DEFAULT + self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT self._verified_elasticsearch = False self._otel = OpenTelemetry() - @property - def transport(self) -> Transport: - return self._transport - - def perform_request( - self, - method: str, - path: str, - *, - params: Optional[Mapping[str, Any]] = None, - headers: Optional[Mapping[str, str]] = None, - body: Optional[Any] = None, - endpoint_id: Optional[str] = None, - path_parts: Optional[Mapping[str, Any]] = None, - ) -> ApiResponse[Any]: - with self._otel.span( - method, - endpoint_id=endpoint_id, - path_parts=path_parts or {}, - ) as otel_span: - response = self._perform_request( - method, - path, - params=params, - headers=headers, - body=body, - otel_span=otel_span, - ) - otel_span.set_elastic_cloud_metadata(response.meta.headers) - return response - def _perform_request( self, method: str, @@ -287,7 +256,7 @@ def mimetype_header_to_compat(header: str) -> None: else: target = path - meta, resp_body = self.transport.perform_request( + meta, resp_body = self._transport.perform_request( method, target, headers=request_headers, @@ -376,10 +345,9 @@ def mimetype_header_to_compat(header: str) -> None: return response -class NamespacedClient(BaseClient): - def __init__(self, client: "BaseClient") -> None: - self._client = client - super().__init__(self._client.transport) +class NamespacedClient: + def __init__(self, client: BaseClient) -> None: + self._base_client = client def perform_request( self, @@ -392,14 +360,18 @@ def perform_request( endpoint_id: Optional[str] = None, path_parts: Optional[Mapping[str, Any]] = None, ) -> ApiResponse[Any]: - # Use the internal clients .perform_request() implementation - # so we take advantage of their transport options. - return self._client.perform_request( + with self._base_client._otel.span( method, - path, - params=params, - headers=headers, - body=body, endpoint_id=endpoint_id, - path_parts=path_parts, - ) + path_parts=path_parts or {}, + ) as otel_span: + response = self._base_client._perform_request( + method, + path, + params=params, + headers=headers, + body=body, + otel_span=otel_span, + ) + otel_span.set_elastic_cloud_metadata(response.meta.headers) + return response diff --git a/elasticsearch/dsl/connections.py b/elasticsearch/dsl/connections.py index 8acd80c6e..350a250fd 100644 --- a/elasticsearch/dsl/connections.py +++ b/elasticsearch/dsl/connections.py @@ -117,15 +117,15 @@ def get_connection(self, alias: Union[str, _T] = "default") -> _T: def _with_user_agent(self, conn: _T) -> _T: # try to inject our user agent - if hasattr(conn, "_headers"): - is_frozen = conn._headers.frozen + if hasattr(conn, "_base_client") and hasattr(conn._base_client, "_headers"): + is_frozen = conn._base_client._headers.frozen if is_frozen: - conn._headers = conn._headers.copy() - conn._headers.update( + conn._base_client._headers = conn._base_client._headers.copy() + conn._base_client._headers.update( {"user-agent": f"elasticsearch-dsl-py/{__versionstr__}"} ) if is_frozen: - conn._headers.freeze() + conn._base_client._headers.freeze() return conn diff --git a/elasticsearch/helpers/actions.py b/elasticsearch/helpers/actions.py index 25c21cdd4..f1a095c9f 100644 --- a/elasticsearch/helpers/actions.py +++ b/elasticsearch/helpers/actions.py @@ -334,7 +334,7 @@ def _process_bulk_chunk( """ Send a bulk request to elasticsearch and process the output. """ - with client._otel.use_span(otel_span): + with client._base_client._otel.use_span(otel_span): if isinstance(ignore_status, int): ignore_status = (ignore_status,) @@ -416,9 +416,9 @@ def streaming_bulk( :arg yield_ok: if set to False will skip successful documents in the output :arg ignore_status: list of HTTP status code that you want to ignore """ - with client._otel.helpers_span(span_name) as otel_span: + with client._base_client._otel.helpers_span(span_name) as otel_span: client = client.options() - client._client_meta = (("h", "bp"),) + client._base_client._client_meta = (("h", "bp"),) if isinstance(retry_on_status, int): retry_on_status = (retry_on_status,) @@ -608,7 +608,7 @@ def _setup_queues(self) -> None: ] = Queue(max(queue_size, thread_count)) self._quick_put = self._inqueue.put - with client._otel.helpers_span("helpers.parallel_bulk") as otel_span: + with client._base_client._otel.helpers_span("helpers.parallel_bulk") as otel_span: pool = BlockingPool(thread_count) try: @@ -711,7 +711,7 @@ def pop_transport_kwargs(kw: MutableMapping[str, Any]) -> Dict[str, Any]: client = client.options( request_timeout=request_timeout, **pop_transport_kwargs(kwargs) ) - client._client_meta = (("h", "s"),) + client._base_client._client_meta = (("h", "s"),) # Setting query={"from": ...} would make 'from' be used # as a keyword argument instead of 'from_'. We handle that here. diff --git a/test_elasticsearch/test_client/test_options.py b/test_elasticsearch/test_client/test_options.py index c2050d186..05486cab4 100644 --- a/test_elasticsearch/test_client/test_options.py +++ b/test_elasticsearch/test_client/test_options.py @@ -290,7 +290,7 @@ def test_default_node_configs(self): headers={"key": "val"}, basic_auth=("username", "password"), ) - assert client._headers == { + assert client._base_client._headers == { "key": "val", "authorization": "Basic dXNlcm5hbWU6cGFzc3dvcmQ=", } @@ -347,7 +347,7 @@ def test_http_headers_overrides(self): "authorization": "Basic dXNlcm5hbWU6cGFzc3dvcmQ=", "user-agent": USER_AGENT, } - assert client._headers == {"key": "val"} + assert client._base_client._headers == {"key": "val"} def test_user_agent_override(self): client = Elasticsearch( diff --git a/test_elasticsearch/test_dsl/test_connections.py b/test_elasticsearch/test_dsl/test_connections.py index dcaa59a98..f7825f9b0 100644 --- a/test_elasticsearch/test_dsl/test_connections.py +++ b/test_elasticsearch/test_dsl/test_connections.py @@ -123,19 +123,23 @@ def test_connection_has_correct_user_agent() -> None: c.create_connection("testing", hosts=["https://es.com:9200"]) assert ( c.get_connection("testing") - ._headers["user-agent"] + ._base_client._headers["user-agent"] .startswith("elasticsearch-dsl-py/") ) my_client = Elasticsearch(hosts=["http://localhost:9200"]) my_client = my_client.options(headers={"user-agent": "my-user-agent/1.0"}) c.add_connection("default", my_client) - assert c.get_connection()._headers["user-agent"].startswith("elasticsearch-dsl-py/") + assert ( + c.get_connection() + ._base_client._headers["user-agent"] + .startswith("elasticsearch-dsl-py/") + ) my_client = Elasticsearch(hosts=["http://localhost:9200"]) assert ( c.get_connection(my_client) - ._headers["user-agent"] + ._base_client._headers["user-agent"] .startswith("elasticsearch-dsl-py/") ) diff --git a/test_elasticsearch/test_otel.py b/test_elasticsearch/test_otel.py index 48eb9ea58..682ee1f86 100644 --- a/test_elasticsearch/test_otel.py +++ b/test_elasticsearch/test_otel.py @@ -112,11 +112,11 @@ def test_forward_otel_context_to_subthreads( ): tracer, memory_exporter = setup_tracing() es_client = Elasticsearch("http://localhost:9200") - es_client._otel = OpenTelemetry(enabled=True, tracer=tracer) + es_client._base_client._otel = OpenTelemetry(enabled=True, tracer=tracer) _call_bulk_mock.return_value = mock.Mock() actions = ({"x": i} for i in range(100)) list(helpers.parallel_bulk(es_client, actions, chunk_size=4)) # Ensures that the OTEL context has been forwarded to all chunks - assert es_client._otel.helpers_span.call_count == 1 - assert es_client._otel.use_span.call_count == 25 + assert es_client._base_client._otel.helpers_span.call_count == 1 + assert es_client._base_client._otel.use_span.call_count == 25 diff --git a/test_elasticsearch/test_server/test_otel.py b/test_elasticsearch/test_server/test_otel.py index 3f8033d7b..11b7642eb 100644 --- a/test_elasticsearch/test_server/test_otel.py +++ b/test_elasticsearch/test_server/test_otel.py @@ -34,7 +34,7 @@ def test_otel_end_to_end(sync_client): tracer, memory_exporter = setup_tracing() - sync_client._otel.tracer = tracer + sync_client._base_client._otel.tracer = tracer resp = sync_client.search(index="logs-*", query={"match_all": {}}) assert resp.meta.status == 200 @@ -61,7 +61,7 @@ def test_otel_bulk(sync_client, elasticsearch_url, bulk_helper_name): # Create a new client with our tracer sync_client = sync_client.options() - sync_client._otel.tracer = tracer + sync_client._base_client._otel.tracer = tracer # "Disable" options to keep our custom tracer sync_client.options = lambda: sync_client