Fix cyclic dependency between clients and sub clients #2858

Open · wants to merge 6 commits into base: main
162 changes: 97 additions & 65 deletions elasticsearch/_async/client/__init__.py
@@ -20,6 +20,7 @@
import typing as t

from elastic_transport import (
ApiResponse,
AsyncTransport,
BaseNode,
BinaryApiResponse,
@@ -97,7 +98,7 @@
SelfType = t.TypeVar("SelfType", bound="AsyncElasticsearch")


class AsyncElasticsearch(BaseClient):
class AsyncElasticsearch:
"""
Elasticsearch low-level client. Provides a straightforward mapping from
Python to Elasticsearch REST APIs.
@@ -295,72 +296,103 @@ def __init__(
**transport_kwargs,
)

super().__init__(_transport)
self._base_client = BaseClient(_transport)

# These are set per-request so are stored separately.
self._request_timeout = request_timeout
self._max_retries = max_retries
self._retry_on_timeout = retry_on_timeout
self._base_client._request_timeout = request_timeout
self._base_client._max_retries = max_retries
self._base_client._retry_on_timeout = retry_on_timeout
if isinstance(retry_on_status, int):
retry_on_status = (retry_on_status,)
self._retry_on_status = retry_on_status
self._base_client._retry_on_status = retry_on_status

else:
super().__init__(_transport)
self._base_client = BaseClient(_transport)

if headers is not DEFAULT and headers is not None:
self._headers.update(headers)
self._base_client._headers.update(headers)
if opaque_id is not DEFAULT and opaque_id is not None: # type: ignore[comparison-overlap]
self._headers["x-opaque-id"] = opaque_id
self._headers = resolve_auth_headers(
self._headers,
self._base_client._headers["x-opaque-id"] = opaque_id
self._base_client._headers = resolve_auth_headers(
self._base_client._headers,
api_key=api_key,
basic_auth=basic_auth,
bearer_auth=bearer_auth,
)

# namespaced clients for compatibility with API names
self.async_search = AsyncSearchClient(self)
self.autoscaling = AutoscalingClient(self)
self.cat = CatClient(self)
self.cluster = ClusterClient(self)
self.connector = ConnectorClient(self)
self.fleet = FleetClient(self)
self.features = FeaturesClient(self)
self.indices = IndicesClient(self)
self.inference = InferenceClient(self)
self.ingest = IngestClient(self)
self.nodes = NodesClient(self)
self.snapshot = SnapshotClient(self)
self.tasks = TasksClient(self)

self.xpack = XPackClient(self)
self.ccr = CcrClient(self)
self.dangling_indices = DanglingIndicesClient(self)
self.enrich = EnrichClient(self)
self.eql = EqlClient(self)
self.esql = EsqlClient(self)
self.graph = GraphClient(self)
self.ilm = IlmClient(self)
self.license = LicenseClient(self)
self.logstash = LogstashClient(self)
self.migration = MigrationClient(self)
self.ml = MlClient(self)
self.monitoring = MonitoringClient(self)
self.query_rules = QueryRulesClient(self)
self.rollup = RollupClient(self)
self.search_application = SearchApplicationClient(self)
self.searchable_snapshots = SearchableSnapshotsClient(self)
self.security = SecurityClient(self)
self.slm = SlmClient(self)
self.simulate = SimulateClient(self)
self.shutdown = ShutdownClient(self)
self.sql = SqlClient(self)
self.ssl = SslClient(self)
self.synonyms = SynonymsClient(self)
self.text_structure = TextStructureClient(self)
self.transform = TransformClient(self)
self.watcher = WatcherClient(self)
self.async_search = AsyncSearchClient(self._base_client)
self.autoscaling = AutoscalingClient(self._base_client)
self.cat = CatClient(self._base_client)
self.cluster = ClusterClient(self._base_client)
self.connector = ConnectorClient(self._base_client)
self.fleet = FleetClient(self._base_client)
self.features = FeaturesClient(self._base_client)
self.indices = IndicesClient(self._base_client)
self.inference = InferenceClient(self._base_client)
self.ingest = IngestClient(self._base_client)
self.nodes = NodesClient(self._base_client)
self.snapshot = SnapshotClient(self._base_client)
self.tasks = TasksClient(self._base_client)

self.xpack = XPackClient(self._base_client)
self.ccr = CcrClient(self._base_client)
self.dangling_indices = DanglingIndicesClient(self._base_client)
self.enrich = EnrichClient(self._base_client)
self.eql = EqlClient(self._base_client)
self.esql = EsqlClient(self._base_client)
self.graph = GraphClient(self._base_client)
self.ilm = IlmClient(self._base_client)
self.license = LicenseClient(self._base_client)
self.logstash = LogstashClient(self._base_client)
self.migration = MigrationClient(self._base_client)
self.ml = MlClient(self._base_client)
self.monitoring = MonitoringClient(self._base_client)
self.query_rules = QueryRulesClient(self._base_client)
self.rollup = RollupClient(self._base_client)
self.search_application = SearchApplicationClient(self._base_client)
self.searchable_snapshots = SearchableSnapshotsClient(self._base_client)
self.security = SecurityClient(self._base_client)
self.slm = SlmClient(self._base_client)
self.simulate = SimulateClient(self._base_client)
self.shutdown = ShutdownClient(self._base_client)
self.sql = SqlClient(self._base_client)
self.ssl = SslClient(self._base_client)
self.synonyms = SynonymsClient(self._base_client)
self.text_structure = TextStructureClient(self._base_client)
self.transform = TransformClient(self._base_client)
self.watcher = WatcherClient(self._base_client)

@property
def transport(self) -> AsyncTransport:
return self._base_client._transport

async def perform_request(
self,
method: str,
path: str,
*,
params: t.Optional[t.Mapping[str, t.Any]] = None,
headers: t.Optional[t.Mapping[str, str]] = None,
body: t.Optional[t.Any] = None,
endpoint_id: t.Optional[str] = None,
path_parts: t.Optional[t.Mapping[str, t.Any]] = None,
) -> ApiResponse[t.Any]:
with self._base_client._otel.span(
method,
endpoint_id=endpoint_id,
path_parts=path_parts or {},
) as otel_span:
response = await self._base_client._perform_request(
method,
path,
params=params,
headers=headers,
body=body,
otel_span=otel_span,
)
otel_span.set_elastic_cloud_metadata(response.meta.headers)
return response

def __repr__(self) -> str:
try:
@@ -413,44 +445,44 @@ def options(
resolved_headers["x-opaque-id"] = resolved_opaque_id

if resolved_headers:
new_headers = self._headers.copy()
new_headers = self._base_client._headers.copy()
new_headers.update(resolved_headers)
client._headers = new_headers
client._base_client._headers = new_headers
else:
client._headers = self._headers.copy()
client._base_client._headers = self._base_client._headers.copy()

if request_timeout is not DEFAULT:
client._request_timeout = request_timeout
client._base_client._request_timeout = request_timeout
else:
client._request_timeout = self._request_timeout
client._base_client._request_timeout = self._base_client._request_timeout

if ignore_status is not DEFAULT:
if isinstance(ignore_status, int):
ignore_status = (ignore_status,)
client._ignore_status = ignore_status
client._base_client._ignore_status = ignore_status
else:
client._ignore_status = self._ignore_status
client._base_client._ignore_status = self._base_client._ignore_status

if max_retries is not DEFAULT:
if not isinstance(max_retries, int):
raise TypeError("'max_retries' must be of type 'int'")
client._max_retries = max_retries
client._base_client._max_retries = max_retries
else:
client._max_retries = self._max_retries
client._base_client._max_retries = self._base_client._max_retries

if retry_on_status is not DEFAULT:
if isinstance(retry_on_status, int):
retry_on_status = (retry_on_status,)
client._retry_on_status = retry_on_status
client._base_client._retry_on_status = retry_on_status
else:
client._retry_on_status = self._retry_on_status
client._base_client._retry_on_status = self._base_client._retry_on_status

if retry_on_timeout is not DEFAULT:
if not isinstance(retry_on_timeout, bool):
raise TypeError("'retry_on_timeout' must be of type 'bool'")
client._retry_on_timeout = retry_on_timeout
client._base_client._retry_on_timeout = retry_on_timeout
else:
client._retry_on_timeout = self._retry_on_timeout
client._base_client._retry_on_timeout = self._base_client._retry_on_timeout

return client

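The shape of the change in this file is inheritance to composition: `AsyncElasticsearch` no longer subclasses `BaseClient`; it wraps one and passes that same instance to every namespaced sub client, so the sub clients never need to reference the top-level client type. Below is a minimal sketch of the resulting dependency direction, using simplified stand-in classes rather than the real `elastic_transport`/`elasticsearch` types.

```python
# Simplified stand-ins to illustrate the dependency direction after this PR;
# these are not the library's real classes or signatures.
from typing import Dict


class Transport:
    """Placeholder for elastic_transport.AsyncTransport."""

    def perform_request(self, method: str, path: str) -> str:
        return f"{method} {path}"


class BaseClient:
    """Owns the transport, headers, and per-request options.
    Knows nothing about the top-level client or the sub clients."""

    def __init__(self, transport: Transport) -> None:
        self._transport = transport
        self._headers: Dict[str, str] = {}


class IndicesClient:
    """A namespaced sub client now receives the shared BaseClient,
    not the full Elasticsearch client, so the cycle is gone."""

    def __init__(self, base_client: BaseClient) -> None:
        self._base_client = base_client


class Elasticsearch:
    """Top-level client composes a BaseClient instead of inheriting from it."""

    def __init__(self, transport: Transport) -> None:
        self._base_client = BaseClient(transport)
        self.indices = IndicesClient(self._base_client)

    @property
    def transport(self) -> Transport:
        return self._base_client._transport


es = Elasticsearch(Transport())
assert es.indices._base_client is es._base_client  # one shared BaseClient
```

Because every sub client holds the same `BaseClient`, headers, timeouts, and retry settings stored on that object are visible to all of them.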
64 changes: 18 additions & 46 deletions elasticsearch/_async/client/_base.py
@@ -217,42 +217,11 @@ def __init__(self, _transport: AsyncTransport) -> None:
self._request_timeout: Union[DefaultType, Optional[float]] = DEFAULT
self._ignore_status: Union[DefaultType, Collection[int]] = DEFAULT
self._max_retries: Union[DefaultType, int] = DEFAULT
self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT
self._retry_on_status: Union[DefaultType, Collection[int]] = DEFAULT
self._retry_on_timeout: Union[DefaultType, bool] = DEFAULT
self._verified_elasticsearch = False
self._otel = OpenTelemetry()

@property
def transport(self) -> AsyncTransport:
return self._transport

async def perform_request(
self,
method: str,
path: str,
*,
params: Optional[Mapping[str, Any]] = None,
headers: Optional[Mapping[str, str]] = None,
body: Optional[Any] = None,
endpoint_id: Optional[str] = None,
path_parts: Optional[Mapping[str, Any]] = None,
) -> ApiResponse[Any]:
with self._otel.span(
method,
endpoint_id=endpoint_id,
path_parts=path_parts or {},
) as otel_span:
response = await self._perform_request(
method,
path,
params=params,
headers=headers,
body=body,
otel_span=otel_span,
)
otel_span.set_elastic_cloud_metadata(response.meta.headers)
return response

async def _perform_request(
self,
method: str,
@@ -287,7 +256,7 @@ def mimetype_header_to_compat(header: str) -> None:
else:
target = path

meta, resp_body = await self.transport.perform_request(
meta, resp_body = await self._transport.perform_request(
method,
target,
headers=request_headers,
@@ -376,10 +345,9 @@ def mimetype_header_to_compat(header: str) -> None:
return response


class NamespacedClient(BaseClient):
def __init__(self, client: "BaseClient") -> None:
self._client = client
super().__init__(self._client.transport)
class NamespacedClient:
def __init__(self, client: BaseClient) -> None:
self._base_client = client

async def perform_request(
self,
@@ -392,14 +360,18 @@ async def perform_request(
endpoint_id: Optional[str] = None,
path_parts: Optional[Mapping[str, Any]] = None,
) -> ApiResponse[Any]:
# Use the internal clients .perform_request() implementation
# so we take advantage of their transport options.
return await self._client.perform_request(
with self._base_client._otel.span(
method,
path,
params=params,
headers=headers,
body=body,
endpoint_id=endpoint_id,
path_parts=path_parts,
)
path_parts=path_parts or {},
) as otel_span:
response = await self._base_client._perform_request(
method,
path,
params=params,
headers=headers,
body=body,
otel_span=otel_span,
)
otel_span.set_elastic_cloud_metadata(response.meta.headers)
return response
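Since `NamespacedClient` no longer inherits from `BaseClient`, the OpenTelemetry wrapper that used to live only on `BaseClient.perform_request()` is now repeated on both `AsyncElasticsearch` and `NamespacedClient`: open a span, call the shared `_perform_request()`, then attach the Elastic Cloud metadata from the response headers. A hedged sketch of that flow, with stub span/response objects standing in for the real OpenTelemetry and transport types:

```python
# Stub objects below (FakeSpan, FakeOTel, FakeResponse, FakeBaseClient) are
# illustrative only; they mimic the call shape in the diff, not the real API.
import asyncio
from contextlib import contextmanager
from typing import Any, Dict, Iterator


class FakeSpan:
    def set_elastic_cloud_metadata(self, headers: Dict[str, str]) -> None:
        print("cloud cluster:", headers.get("x-found-handling-cluster"))


class FakeOTel:
    @contextmanager
    def span(self, method: str, *, endpoint_id: Any, path_parts: Dict[str, Any]) -> Iterator[FakeSpan]:
        # The real helper would also record request details on the span.
        yield FakeSpan()


class FakeResponse:
    class meta:
        headers = {"x-found-handling-cluster": "my-cluster"}


class FakeBaseClient:
    _otel = FakeOTel()

    async def _perform_request(self, method: str, path: str, *, otel_span: FakeSpan) -> FakeResponse:
        return FakeResponse()


async def perform_request(base_client: FakeBaseClient, method: str, path: str) -> FakeResponse:
    # The same wrapper now appears on AsyncElasticsearch and NamespacedClient.
    with base_client._otel.span(method, endpoint_id=None, path_parts={}) as otel_span:
        response = await base_client._perform_request(method, path, otel_span=otel_span)
        otel_span.set_elastic_cloud_metadata(response.meta.headers)
        return response


asyncio.run(perform_request(FakeBaseClient(), "GET", "/_cluster/health"))
```

The trade-off is a small amount of duplicated wrapper code in exchange for sub clients that depend only on `BaseClient`.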
4 changes: 2 additions & 2 deletions elasticsearch/_async/helpers.py
@@ -216,7 +216,7 @@ async def async_streaming_bulk(
"""

client = client.options()
client._client_meta = (("h", "bp"),)
client._base_client._client_meta = (("h", "bp"),)

if isinstance(retry_on_status, int):
retry_on_status = (retry_on_status,)
@@ -429,7 +429,7 @@ def pop_transport_kwargs(kw: MutableMapping[str, Any]) -> MutableMapping[str, An
client = client.options(
request_timeout=request_timeout, **pop_transport_kwargs(kwargs)
)
client._client_meta = (("h", "s"),)
client._base_client._client_meta = (("h", "s"),)

# Setting query={"from": ...} would make 'from' be used
# as a keyword argument instead of 'from_'. We handle that here.
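The helper changes are mechanical: `options()` hands back a per-call copy of the client, and the bulk and scan helpers now set their client-meta tag on that copy's wrapped `BaseClient` rather than on the client itself. A small sketch of the copy-then-override semantics, again with simplified stand-ins (the real `options()` handles many more settings than shown here):

```python
# Simplified stand-ins; options() here only models the copy semantics relevant
# to the helpers, not the full option handling in the real client.
import copy
from typing import Optional, Tuple


class BaseClient:
    def __init__(self) -> None:
        self._request_timeout: Optional[float] = None
        self._client_meta: Tuple[Tuple[str, str], ...] = ()


class Elasticsearch:
    def __init__(self) -> None:
        self._base_client = BaseClient()

    def options(self, *, request_timeout: Optional[float] = None) -> "Elasticsearch":
        client = Elasticsearch()
        client._base_client = copy.copy(self._base_client)
        if request_timeout is not None:
            client._base_client._request_timeout = request_timeout
        return client


es = Elasticsearch()
bulk_client = es.options(request_timeout=30)
bulk_client._base_client._client_meta = (("h", "bp"),)  # helper marker, as in the diff

assert es._base_client._client_meta == ()               # original client untouched
assert bulk_client._base_client._request_timeout == 30  # override lives on the copy
```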