Skip to content

Commit e6c68bb

Browse files
committed
Fix cyclic dependency between clients and sub clients
While the Python GC can handle those cyclic dependencies, they can cause latency spikes when .options() is used for each query: this causes many clients to be garbage collected, and collection can be slower in the presence of cycles.
1 parent 7ae3235 commit e6c68bb

File tree

4 files changed

+250
-240
lines changed

4 files changed

+250
-240
lines changed

elasticsearch/_async/client/__init__.py

+106-72
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
import typing as t
2121

2222
from elastic_transport import (
23+
ApiResponse,
2324
AsyncTransport,
2425
BaseNode,
2526
BinaryApiResponse,
2627
HeadApiResponse,
28+
HttpHeaders,
2729
NodeConfig,
2830
NodePool,
2931
NodeSelector,
@@ -97,7 +99,7 @@
9799
SelfType = t.TypeVar("SelfType", bound="AsyncElasticsearch")
98100

99101

100-
class AsyncElasticsearch(BaseClient):
102+
class AsyncElasticsearch:
101103
"""
102104
Elasticsearch low-level client. Provides a straightforward mapping from
103105
Python to Elasticsearch REST APIs.
@@ -224,6 +226,18 @@ def __init__(
224226
):
225227
sniff_callback = default_sniff_callback
226228

229+
headers = HttpHeaders()
230+
if headers is not DEFAULT and headers is not None:
231+
headers.update(headers)
232+
if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap]
233+
headers["x-opaque-id"] = opaque_id
234+
headers = resolve_auth_headers(
235+
headers,
236+
api_key=api_key,
237+
basic_auth=basic_auth,
238+
bearer_auth=bearer_auth,
239+
)
240+
227241
if _transport is None:
228242
node_configs = client_node_configs(
229243
hosts,
@@ -295,72 +309,92 @@ def __init__(
295309
**transport_kwargs,
296310
)
297311

298-
super().__init__(_transport)
312+
self._base_client = BaseClient(_transport, headers=headers)
299313

300314
# These are set per-request so are stored separately.
301-
self._request_timeout = request_timeout
302-
self._max_retries = max_retries
303-
self._retry_on_timeout = retry_on_timeout
315+
self._base_client._request_timeout = request_timeout
316+
self._base_client._max_retries = max_retries
317+
self._base_client._retry_on_timeout = retry_on_timeout
304318
if isinstance(retry_on_status, int):
305319
retry_on_status = (retry_on_status,)
306-
self._retry_on_status = retry_on_status
320+
self._base_client._retry_on_status = retry_on_status
307321

308322
else:
309-
super().__init__(_transport)
310-
311-
if headers is not DEFAULT and headers is not None:
312-
self._headers.update(headers)
313-
if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap]
314-
self._headers["x-opaque-id"] = opaque_id
315-
self._headers = resolve_auth_headers(
316-
self._headers,
317-
api_key=api_key,
318-
basic_auth=basic_auth,
319-
bearer_auth=bearer_auth,
320-
)
323+
self._base_client = BaseClient(_transport, headers=headers)
321324

322325
# namespaced clients for compatibility with API names
323-
self.async_search = AsyncSearchClient(self)
324-
self.autoscaling = AutoscalingClient(self)
325-
self.cat = CatClient(self)
326-
self.cluster = ClusterClient(self)
327-
self.connector = ConnectorClient(self)
328-
self.fleet = FleetClient(self)
329-
self.features = FeaturesClient(self)
330-
self.indices = IndicesClient(self)
331-
self.inference = InferenceClient(self)
332-
self.ingest = IngestClient(self)
333-
self.nodes = NodesClient(self)
334-
self.snapshot = SnapshotClient(self)
335-
self.tasks = TasksClient(self)
336-
337-
self.xpack = XPackClient(self)
338-
self.ccr = CcrClient(self)
339-
self.dangling_indices = DanglingIndicesClient(self)
340-
self.enrich = EnrichClient(self)
341-
self.eql = EqlClient(self)
342-
self.esql = EsqlClient(self)
343-
self.graph = GraphClient(self)
344-
self.ilm = IlmClient(self)
345-
self.license = LicenseClient(self)
346-
self.logstash = LogstashClient(self)
347-
self.migration = MigrationClient(self)
348-
self.ml = MlClient(self)
349-
self.monitoring = MonitoringClient(self)
350-
self.query_rules = QueryRulesClient(self)
351-
self.rollup = RollupClient(self)
352-
self.search_application = SearchApplicationClient(self)
353-
self.searchable_snapshots = SearchableSnapshotsClient(self)
354-
self.security = SecurityClient(self)
355-
self.slm = SlmClient(self)
356-
self.simulate = SimulateClient(self)
357-
self.shutdown = ShutdownClient(self)
358-
self.sql = SqlClient(self)
359-
self.ssl = SslClient(self)
360-
self.synonyms = SynonymsClient(self)
361-
self.text_structure = TextStructureClient(self)
362-
self.transform = TransformClient(self)
363-
self.watcher = WatcherClient(self)
326+
self.async_search = AsyncSearchClient(self._base_client)
327+
self.autoscaling = AutoscalingClient(self._base_client)
328+
self.cat = CatClient(self._base_client)
329+
self.cluster = ClusterClient(self._base_client)
330+
self.connector = ConnectorClient(self._base_client)
331+
self.fleet = FleetClient(self._base_client)
332+
self.features = FeaturesClient(self._base_client)
333+
self.indices = IndicesClient(self._base_client)
334+
self.inference = InferenceClient(self._base_client)
335+
self.ingest = IngestClient(self._base_client)
336+
self.nodes = NodesClient(self._base_client)
337+
self.snapshot = SnapshotClient(self._base_client)
338+
self.tasks = TasksClient(self._base_client)
339+
340+
self.xpack = XPackClient(self._base_client)
341+
self.ccr = CcrClient(self._base_client)
342+
self.dangling_indices = DanglingIndicesClient(self._base_client)
343+
self.enrich = EnrichClient(self._base_client)
344+
self.eql = EqlClient(self._base_client)
345+
self.esql = EsqlClient(self._base_client)
346+
self.graph = GraphClient(self._base_client)
347+
self.ilm = IlmClient(self._base_client)
348+
self.license = LicenseClient(self._base_client)
349+
self.logstash = LogstashClient(self._base_client)
350+
self.migration = MigrationClient(self._base_client)
351+
self.ml = MlClient(self._base_client)
352+
self.monitoring = MonitoringClient(self._base_client)
353+
self.query_rules = QueryRulesClient(self._base_client)
354+
self.rollup = RollupClient(self._base_client)
355+
self.search_application = SearchApplicationClient(self._base_client)
356+
self.searchable_snapshots = SearchableSnapshotsClient(self._base_client)
357+
self.security = SecurityClient(self._base_client)
358+
self.slm = SlmClient(self._base_client)
359+
self.simulate = SimulateClient(self._base_client)
360+
self.shutdown = ShutdownClient(self._base_client)
361+
self.sql = SqlClient(self._base_client)
362+
self.ssl = SslClient(self._base_client)
363+
self.synonyms = SynonymsClient(self._base_client)
364+
self.text_structure = TextStructureClient(self._base_client)
365+
self.transform = TransformClient(self._base_client)
366+
self.watcher = WatcherClient(self._base_client)
367+
368+
@property
369+
def transport(self) -> AsyncTransport:
370+
return self._base_client._transport
371+
372+
async def perform_request(
373+
self,
374+
method: str,
375+
path: str,
376+
*,
377+
params: t.Optional[t.Mapping[str, t.Any]] = None,
378+
headers: t.Optional[t.Mapping[str, str]] = None,
379+
body: t.Optional[t.Any] = None,
380+
endpoint_id: t.Optional[str] = None,
381+
path_parts: t.Optional[t.Mapping[str, t.Any]] = None,
382+
) -> ApiResponse[t.Any]:
383+
with self._base_client._otel.span(
384+
method,
385+
endpoint_id=endpoint_id,
386+
path_parts=path_parts or {},
387+
) as otel_span:
388+
response = await self._base_client._perform_request(
389+
method,
390+
path,
391+
params=params,
392+
headers=headers,
393+
body=body,
394+
otel_span=otel_span,
395+
)
396+
otel_span.set_elastic_cloud_metadata(response.meta.headers)
397+
return response
364398

365399
def __repr__(self) -> str:
366400
try:
@@ -413,44 +447,44 @@ def options(
413447
resolved_headers["x-opaque-id"] = resolved_opaque_id
414448

415449
if resolved_headers:
416-
new_headers = self._headers.copy()
450+
new_headers = self._base_client._headers.copy()
417451
new_headers.update(resolved_headers)
418-
client._headers = new_headers
452+
client._base_client._headers = new_headers
419453
else:
420-
client._headers = self._headers.copy()
454+
client._base_client._headers = self._headers.copy()
421455

422456
if request_timeout is not DEFAULT:
423-
client._request_timeout = request_timeout
457+
client._base_client._request_timeout = request_timeout
424458
else:
425-
client._request_timeout = self._request_timeout
459+
client._base_client._request_timeout = self._base_client._request_timeout
426460

427461
if ignore_status is not DEFAULT:
428462
if isinstance(ignore_status, int):
429463
ignore_status = (ignore_status,)
430-
client._ignore_status = ignore_status
464+
client._base_client._ignore_status = ignore_status
431465
else:
432-
client._ignore_status = self._ignore_status
466+
client._base_client._ignore_status = self._base_client._ignore_status
433467

434468
if max_retries is not DEFAULT:
435469
if not isinstance(max_retries, int):
436470
raise TypeError("'max_retries' must be of type 'int'")
437-
client._max_retries = max_retries
471+
client._base_client._max_retries = max_retries
438472
else:
439-
client._max_retries = self._max_retries
473+
client._base_client._max_retries = self._base_client._max_retries
440474

441475
if retry_on_status is not DEFAULT:
442476
if isinstance(retry_on_status, int):
443477
retry_on_status = (retry_on_status,)
444-
client._retry_on_status = retry_on_status
478+
client._base_client._retry_on_status = retry_on_status
445479
else:
446-
client._retry_on_status = self._retry_on_status
480+
client._base_client._retry_on_status = self._base_client._retry_on_status
447481

448482
if retry_on_timeout is not DEFAULT:
449483
if not isinstance(retry_on_timeout, bool):
450484
raise TypeError("'retry_on_timeout' must be of type 'bool'")
451-
client._retry_on_timeout = retry_on_timeout
485+
client._base_client._retry_on_timeout = retry_on_timeout
452486
else:
453-
client._retry_on_timeout = self._retry_on_timeout
487+
client._base_client._retry_on_timeout = self._base_client._retry_on_timeout
454488

455489
return client
456490

elasticsearch/_async/client/_base.py

+19-48
Original file line numberDiff line numberDiff line change
@@ -210,49 +210,17 @@ def _default_sniffed_node_callback(
210210

211211

212212
class BaseClient:
213-
def __init__(self, _transport: AsyncTransport) -> None:
213+
def __init__(self, _transport: AsyncTransport, headers: HttpHeaders) -> None:
214214
self._transport = _transport
215215
self._client_meta: Union[DefaultType, Tuple[Tuple[str, str], ...]] = DEFAULT
216-
self._headers = HttpHeaders()
216+
self._headers = headers
217217
self._request_timeout: Union[DefaultType, Optional[float]] = DEFAULT
218218
self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT
219219
self._max_retries: Union[DefaultType, int] = DEFAULT
220-
self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT
221220
self._retry_on_status: Union[DefaultType, Collection[int]] = DEFAULT
222221
self._verified_elasticsearch = False
223222
self._otel = OpenTelemetry()
224223

225-
@property
226-
def transport(self) -> AsyncTransport:
227-
return self._transport
228-
229-
async def perform_request(
230-
self,
231-
method: str,
232-
path: str,
233-
*,
234-
params: Optional[Mapping[str, Any]] = None,
235-
headers: Optional[Mapping[str, str]] = None,
236-
body: Optional[Any] = None,
237-
endpoint_id: Optional[str] = None,
238-
path_parts: Optional[Mapping[str, Any]] = None,
239-
) -> ApiResponse[Any]:
240-
with self._otel.span(
241-
method,
242-
endpoint_id=endpoint_id,
243-
path_parts=path_parts or {},
244-
) as otel_span:
245-
response = await self._perform_request(
246-
method,
247-
path,
248-
params=params,
249-
headers=headers,
250-
body=body,
251-
otel_span=otel_span,
252-
)
253-
otel_span.set_elastic_cloud_metadata(response.meta.headers)
254-
return response
255-
256224
async def _perform_request(
257225
self,
258226
method: str,
@@ -287,7 +255,7 @@ def mimetype_header_to_compat(header: str) -> None:
287255
else:
288256
target = path
289257

290-
meta, resp_body = await self.transport.perform_request(
258+
meta, resp_body = await self._transport.perform_request(
291259
method,
292260
target,
293261
headers=request_headers,
@@ -376,10 +344,9 @@ def mimetype_header_to_compat(header: str) -> None:
376344
return response
377345

378346

379-
class NamespacedClient(BaseClient):
380-
def __init__(self, client: "BaseClient") -> None:
381-
self._client = client
382-
super().__init__(self._client.transport)
347+
class NamespacedClient:
348+
def __init__(self, client: BaseClient) -> None:
349+
self._base_client = client
383350

384351
async def perform_request(
385352
self,
@@ -392,14 +359,18 @@ async def perform_request(
392359
endpoint_id: Optional[str] = None,
393360
path_parts: Optional[Mapping[str, Any]] = None,
394361
) -> ApiResponse[Any]:
395-
# Use the internal clients .perform_request() implementation
396-
# so we take advantage of their transport options.
397-
return await self._client.perform_request(
362+
with self._base_client._otel.span(
398363
method,
399-
path,
400-
params=params,
401-
headers=headers,
402-
body=body,
403364
endpoint_id=endpoint_id,
404-
path_parts=path_parts,
405-
)
365+
path_parts=path_parts or {},
366+
) as otel_span:
367+
response = await self._base_client._perform_request(
368+
method,
369+
path,
370+
params=params,
371+
headers=headers,
372+
body=body,
373+
otel_span=otel_span,
374+
)
375+
otel_span.set_elastic_cloud_metadata(response.meta.headers)
376+
return response

0 commit comments

Comments
 (0)