Skip to content

Commit 7f6147f

Browse files
Use a single httpx client for each worker thread (#862)
1 parent 7992deb commit 7f6147f

File tree

13 files changed

+183
-90
lines changed

13 files changed

+183
-90
lines changed

backend_py/primary/poetry.lock

+9-8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
from typing import Optional
import logging

import httpx


LOGGER = logging.getLogger(__name__)


class HTTPXAsyncClientWrapper:
    """Global async client wrapper for HTTPX.

    Implemented as a process-wide singleton so that all code shares a single
    httpx.AsyncClient — and therefore a single connection pool — per worker.
    The client's lifetime is managed explicitly: call start() from the FastAPI
    startup hook and stop() from the shutdown hook.
    """

    # Shared singleton state lives on the class, not on instances.
    _instance: Optional["HTTPXAsyncClientWrapper"] = None
    _async_client: Optional[httpx.AsyncClient] = None

    def __new__(cls) -> "HTTPXAsyncClientWrapper":
        # Classic singleton: every construction returns the one shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def client(self) -> httpx.AsyncClient:
        """Get the async client instance.

        Raises:
            RuntimeError: if the wrapper has not been initialized via start().
        """
        if self._async_client is None:
            raise RuntimeError("HTTPXAsyncClientWrapper not initialized. Call start() first.")
        return self._async_client

    def start(self) -> None:
        """Instantiate the client. Call from the FastAPI startup hook."""
        # Idempotent: a second start() call is a no-op.
        if self._async_client is None:
            self._async_client = httpx.AsyncClient()
            LOGGER.info(f"httpx AsyncClient instantiated. Id {id(self._async_client)}")

    async def stop(self) -> None:
        """Gracefully shutdown. Call from FastAPI shutdown hook."""
        if self._async_client is not None:
            # Log is_closed before and after aclose() to make shutdown traceable.
            LOGGER.info(
                f"httpx async_client.is_closed: {self._async_client.is_closed}. " f"Id: {id(self._async_client)}"
            )
            await self._async_client.aclose()
            LOGGER.info(
                f"httpx async_client.is_closed: {self._async_client.is_closed}. " f"Id: {id(self._async_client)}"
            )
            self._async_client = None
            LOGGER.info("httpx AsyncClient closed")


# Create a singleton instance of the async client
httpx_async_client = HTTPXAsyncClientWrapper()

backend_py/primary/primary/main.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,21 @@
3838
from primary.utils.logging_setup import ensure_console_log_handler_is_configured, setup_normal_log_levels
3939

4040
from . import config
41-
41+
from .httpx_client import httpx_async_client
4242

4343
ensure_console_log_handler_is_configured()
4444
setup_normal_log_levels()
4545

4646
# temporarily set some loggers to DEBUG
4747
# logging.getLogger().setLevel(logging.DEBUG)
4848
logging.getLogger("primary.services.sumo_access").setLevel(logging.DEBUG)
49-
logging.getLogger("primary.services.user_session_manager").setLevel(logging.DEBUG)
49+
logging.getLogger("primary.services.smda_access").setLevel(logging.DEBUG)
50+
logging.getLogger("primary.services.ssdl_access").setLevel(logging.DEBUG)
5051
logging.getLogger("primary.services.user_grid3d_service").setLevel(logging.DEBUG)
5152
logging.getLogger("primary.routers.grid3d").setLevel(logging.DEBUG)
5253
logging.getLogger("primary.routers.dev").setLevel(logging.DEBUG)
54+
# logging.getLogger("uvicorn.error").setLevel(logging.DEBUG)
55+
# logging.getLogger("uvicorn.access").setLevel(logging.DEBUG)
5356

5457
LOGGER = logging.getLogger(__name__)
5558

@@ -71,6 +74,17 @@ def custom_generate_unique_id(route: APIRoute) -> str:
7174
LOGGER.warning("Skipping telemetry configuration, APPLICATIONINSIGHTS_CONNECTION_STRING env variable not set.")
7275

7376

77+
# Start the httpx client on startup and stop it on shutdown of the app
78+
@app.on_event("startup")
79+
async def startup_event() -> None:
80+
httpx_async_client.start()
81+
82+
83+
@app.on_event("shutdown")
84+
async def shutdown_event() -> None:
85+
await httpx_async_client.stop()
86+
87+
7488
# The tags we add here will determine the name of the frontend api service for our endpoints as well as
7589
# providing some grouping when viewing the openapi documentation.
7690
app.include_router(explore_router, tags=["explore"])

backend_py/primary/primary/services/graph_access/graph_access.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
from typing import Mapping
33
from urllib.parse import urljoin
44

5-
# Using the same http client as sumo
65
import httpx
76

7+
from primary.httpx_client import httpx_async_client
8+
89

910
class GraphApiAccess:
1011
def __init__(self, access_token: str):
@@ -15,12 +16,12 @@ def _make_headers(self) -> Mapping[str, str]:
1516
return {"Authorization": f"Bearer {self._access_token}"}
1617

1718
async def _request(self, url: str) -> httpx.Response:
18-
async with httpx.AsyncClient() as client:
19-
response = await client.get(
20-
url,
21-
headers=self._make_headers(),
22-
)
23-
return response
19+
20+
response = await httpx_async_client.client.get(
21+
url,
22+
headers=self._make_headers(),
23+
)
24+
return response
2425

2526
async def get_user_profile_photo(self, user_id: str) -> str | None:
2627
request_url = urljoin(self.base_url, "me/photo/$value" if user_id == "me" else f"users/{user_id}/photo/$value")
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import logging
22
from typing import List
33

4-
import httpx
5-
64
from webviz_pkg.core_utils.perf_timer import PerfTimer
75

86
from primary import config
7+
from primary.httpx_client import httpx_async_client
98
from primary.services.service_exceptions import ServiceRequestError, Service
109

1110
LOGGER = logging.getLogger(__name__)
@@ -28,35 +27,34 @@ async def smda_get_request(access_token: str, endpoint: str, params: dict) -> Li
2827
timer = PerfTimer()
2928
single_lap_timer = PerfTimer()
3029

31-
async with httpx.AsyncClient() as client:
32-
results: List[dict] = []
33-
page: int = 1
34-
while True:
35-
response = await client.get(urlstring, params=params, headers=headers, timeout=60)
36-
LOGGER.info(f"TIME SMDA fetch '{endpoint}', page {page}, took {single_lap_timer.lap_s():.2f} seconds")
37-
page += 1
38-
if response.status_code == 200:
39-
result = response.json()["data"]["results"]
40-
if not result:
41-
raise ServiceRequestError(f"No data found for endpoint: '{endpoint}'", Service.SMDA)
42-
43-
results.extend(result)
44-
45-
next_request = response.json()["data"]["next"]
46-
if next_request is None:
47-
break
48-
params["_next"] = next_request
49-
elif response.status_code == 404:
50-
LOGGER.error(f"{str(response.status_code) } {endpoint} either does not exists or can not be found")
51-
raise ServiceRequestError(
52-
f"[{str(response.status_code)}] '{endpoint}' either does not exists or can not be found",
53-
Service.SMDA,
54-
)
55-
else:
56-
raise ServiceRequestError(
57-
f"[{str(response.status_code)}] Cannot fetch data from endpoint: '{endpoint}'", Service.SMDA
58-
)
59-
60-
LOGGER.info(f"TIME SMDA fetch '{endpoint}' took {timer.lap_s():.2f} seconds")
30+
results: List[dict] = []
31+
page: int = 1
32+
while True:
33+
response = await httpx_async_client.client.get(urlstring, params=params, headers=headers, timeout=60)
34+
LOGGER.debug(f"TIME SMDA fetch '{endpoint}', page {page}, took {single_lap_timer.lap_s():.2f} seconds")
35+
page += 1
36+
if response.status_code == 200:
37+
result = response.json()["data"]["results"]
38+
if not result:
39+
raise ServiceRequestError(f"No data found for endpoint: '{endpoint}'", Service.SMDA)
40+
41+
results.extend(result)
42+
43+
next_request = response.json()["data"]["next"]
44+
if next_request is None:
45+
break
46+
params["_next"] = next_request
47+
elif response.status_code == 404:
48+
LOGGER.error(f"{str(response.status_code) } {endpoint} either does not exists or can not be found")
49+
raise ServiceRequestError(
50+
f"[{str(response.status_code)}] '{endpoint}' either does not exists or can not be found",
51+
Service.SMDA,
52+
)
53+
else:
54+
raise ServiceRequestError(
55+
f"[{str(response.status_code)}] Cannot fetch data from endpoint: '{endpoint}'", Service.SMDA
56+
)
57+
58+
LOGGER.debug(f"TIME SMDA fetch '{endpoint}' took {timer.lap_s():.2f} seconds")
6159

6260
return results
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import logging
22
from typing import List, Optional
33

4-
import httpx
5-
64
from webviz_pkg.core_utils.perf_timer import PerfTimer
75

86
from primary import config
7+
from primary.httpx_client import httpx_async_client
98
from primary.services.service_exceptions import (
109
Service,
1110
InvalidDataError,
@@ -30,23 +29,22 @@ async def ssdl_get_request(access_token: str, endpoint: str, params: Optional[di
3029
}
3130
timer = PerfTimer()
3231

33-
async with httpx.AsyncClient() as client:
34-
response = await client.get(urlstring, params=params, headers=headers, timeout=60)
35-
results = []
36-
if response.status_code == 200:
37-
results = response.json()
38-
39-
elif response.status_code == 401:
40-
raise AuthorizationError("Unauthorized access to SSDL", Service.SSDL)
41-
elif response.status_code == 403:
42-
raise AuthorizationError("Forbidden access to SSDL", Service.SSDL)
43-
elif response.status_code == 404:
44-
raise InvalidDataError(f"Endpoint {endpoint} either does not exists or can not be found", Service.SSDL)
45-
46-
# Capture other errors
47-
else:
48-
raise InvalidParameterError(f"Can not fetch data from endpoint {endpoint}", Service.SSDL)
49-
50-
print(f"TIME SSDL fetch {endpoint} took {timer.lap_s():.2f} seconds")
51-
LOGGER.debug(f"TIME SSDL fetch {endpoint} took {timer.lap_s():.2f} seconds")
32+
response = await httpx_async_client.client.get(urlstring, params=params, headers=headers, timeout=60)
33+
results = []
34+
if response.status_code == 200:
35+
results = response.json()
36+
37+
elif response.status_code == 401:
38+
raise AuthorizationError("Unauthorized access to SSDL", Service.SSDL)
39+
elif response.status_code == 403:
40+
raise AuthorizationError("Forbidden access to SSDL", Service.SSDL)
41+
elif response.status_code == 404:
42+
raise InvalidDataError(f"Endpoint {endpoint} either does not exists or can not be found", Service.SSDL)
43+
44+
# Capture other errors
45+
else:
46+
raise InvalidParameterError(f"Can not fetch data from endpoint {endpoint}", Service.SSDL)
47+
48+
print(f"TIME SSDL fetch {endpoint} took {timer.lap_s():.2f} seconds")
49+
LOGGER.debug(f"TIME SSDL fetch {endpoint} took {timer.lap_s():.2f} seconds")
5250
return results

backend_py/primary/primary/services/sumo_access/_helpers.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,51 @@
11
import logging
2+
from typing import Any
23

34
from fmu.sumo.explorer.explorer import SumoClient, Pit
45
from fmu.sumo.explorer.objects import CaseCollection, Case
56
from webviz_pkg.core_utils.perf_timer import PerfTimer
7+
from webviz_pkg.core_utils.perf_metrics import PerfMetrics
68

79
from primary import config
10+
from primary.httpx_client import httpx_async_client
811
from primary.services.service_exceptions import Service, NoDataError, MultipleDataMatchesError
912

1013
LOGGER = logging.getLogger(__name__)
1114

1215

16+
class SynchronousMethodCallError(Exception):
17+
"""Custom error for when synchronous methods are called instead of async."""
18+
19+
20+
class FakeHTTPXClient:
21+
"""A fake HTTPX client to ensure we use async methods instead of sync ones.
22+
This is needed as we do not want to allow any synchronous HTTP calls in the primary service.
23+
Ideally this should be handled by the SumoClient. https://github.com/equinor/fmu-sumo/issues/369"""
24+
25+
def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument
26+
self._error_msg = "🚫 Do not use a synchronous http class!. Use the async http class instead. "
27+
28+
def __getattr__(self, name: str) -> None:
29+
"""Catch any synchronous method calls and raise a helpful error."""
30+
async_methods = {"get", "post", "put", "patch", "delete", "head", "options"}
31+
if name in async_methods:
32+
raise SynchronousMethodCallError(self._error_msg.format(method=name))
33+
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
34+
35+
1336
def create_sumo_client(access_token: str) -> SumoClient:
37+
timer = PerfMetrics()
1438
if access_token == "DUMMY_TOKEN_FOR_TESTING": # nosec bandit B105
1539
sumo_client = SumoClient(env=config.SUMO_ENV, interactive=False)
1640
else:
17-
sumo_client = SumoClient(env=config.SUMO_ENV, token=access_token, interactive=False)
41+
sumo_client = SumoClient(
42+
env=config.SUMO_ENV,
43+
token=access_token,
44+
http_client=FakeHTTPXClient(),
45+
async_http_client=httpx_async_client.client,
46+
)
47+
timer.record_lap("create_sumo_client()")
48+
LOGGER.debug(f"{timer.to_string()}ms")
1849
return sumo_client
1950

2051

backend_py/primary/primary/services/sumo_access/case_inspector.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ async def get_case_name_async(self) -> str:
4040
async def get_iterations_async(self) -> list[IterationInfo]:
4141
case: Case = await self._get_or_create_case_obj()
4242

43-
# Stick with the sync version for now, since there is a bug in the async version of SumoExplorer
44-
# See: https://github.com/equinor/fmu-sumo/issues/326
45-
iterations = case.iterations
43+
iterations = await case.iterations_async
4644

4745
iter_info_arr: list[IterationInfo] = []
4846
for iteration in iterations:

backend_py/primary/primary/services/sumo_access/group_tree_access.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def get_group_tree_table(self, realization: Optional[int]) -> Optional[pd.
4343
if await table_collection.length_async() > 1:
4444
raise ValueError("Multiple tables found.")
4545

46-
group_tree_df = table_collection[0].to_pandas()
46+
group_tree_df = await table_collection[0].to_pandas_async()
4747

4848
_validate_group_tree_df(group_tree_df)
4949

backend_py/primary/primary/services/sumo_access/table_access.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@ async def get_realization_table_async(
5353
iteration=self._iteration_name,
5454
realization=realization,
5555
)
56-
if not table_collection:
56+
table_length = await table_collection.length_async()
57+
if table_length == 0:
5758
raise ValueError(f"No table found for {table_schema=}")
58-
if len(table_collection) > 1:
59+
if table_length > 1:
5960
raise ValueError(f"Multiple tables found for {table_schema=}")
6061

6162
sumo_table = await table_collection.getitem_async(0)

0 commit comments

Comments
 (0)