Skip to content

Commit 2f2faa2

Browse files
alukachbitner
andauthored
Add tooling to make customization of DB Connection possible (#538)
* Add functionality to customize db connection retrieval * Add test for customizing the connection_getter * Cleanup * isort fixes * flake8 fix * Update typing --------- Co-authored-by: David Bitner <[email protected]>
1 parent 42b5588 commit 2f2faa2

File tree

4 files changed

+142
-93
lines changed

4 files changed

+142
-93
lines changed

Diff for: stac_fastapi/pgstac/stac_fastapi/pgstac/core.py

+27-30
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
import orjson
99
from asyncpg.exceptions import InvalidDatetimeFormatError
1010
from buildpg import render
11-
from fastapi import HTTPException
11+
from fastapi import HTTPException, Request
1212
from pydantic import ValidationError
1313
from pygeofilter.backends.cql2_json import to_cql2
1414
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
1515
from pypgstac.hydration import hydrate
1616
from stac_pydantic.links import Relations
1717
from stac_pydantic.shared import MimeTypes
18-
from starlette.requests import Request
1918

2019
from stac_fastapi.pgstac.config import Settings
2120
from stac_fastapi.pgstac.models.links import (
@@ -38,13 +37,11 @@
3837
class CoreCrudClient(AsyncBaseCoreClient):
3938
"""Client for core endpoints defined by stac."""
4039

41-
async def all_collections(self, **kwargs) -> Collections:
40+
async def all_collections(self, request: Request, **kwargs) -> Collections:
4241
"""Read all collections from the database."""
43-
request: Request = kwargs["request"]
4442
base_url = get_base_url(request)
45-
pool = request.app.state.readpool
4643

47-
async with pool.acquire() as conn:
44+
async with request.app.state.get_connection(request, "r") as conn:
4845
collections = await conn.fetchval(
4946
"""
5047
SELECT * FROM all_collections();
@@ -80,7 +77,9 @@ async def all_collections(self, **kwargs) -> Collections:
8077
collection_list = Collections(collections=linked_collections or [], links=links)
8178
return collection_list
8279

83-
async def get_collection(self, collection_id: str, **kwargs) -> Collection:
80+
async def get_collection(
81+
self, collection_id: str, request: Request, **kwargs
82+
) -> Collection:
8483
"""Get collection by id.
8584
8685
Called with `GET /collections/{collection_id}`.
@@ -93,9 +92,7 @@ async def get_collection(self, collection_id: str, **kwargs) -> Collection:
9392
"""
9493
collection: Optional[Dict[str, Any]]
9594

96-
request: Request = kwargs["request"]
97-
pool = request.app.state.readpool
98-
async with pool.acquire() as conn:
95+
async with request.app.state.get_connection(request, "r") as conn:
9996
q, p = render(
10097
"""
10198
SELECT * FROM get_collection(:id::text);
@@ -125,8 +122,7 @@ async def _get_base_item(
125122
"""
126123
item: Optional[Dict[str, Any]]
127124

128-
pool = request.app.state.readpool
129-
async with pool.acquire() as conn:
125+
async with request.app.state.get_connection(request, "r") as conn:
130126
q, p = render(
131127
"""
132128
SELECT * FROM collection_base_item(:collection_id::text);
@@ -143,7 +139,7 @@ async def _get_base_item(
143139
async def _search_base(
144140
self,
145141
search_request: PgstacSearch,
146-
**kwargs: Any,
142+
request: Request,
147143
) -> ItemCollection:
148144
"""Cross catalog search (POST).
149145
@@ -157,21 +153,19 @@ async def _search_base(
157153
"""
158154
items: Dict[str, Any]
159155

160-
request: Request = kwargs["request"]
161156
settings: Settings = request.app.state.settings
162-
pool = request.app.state.readpool
163157

164158
search_request.conf = search_request.conf or {}
165159
search_request.conf["nohydrate"] = settings.use_api_hydrate
166-
req = search_request.json(exclude_none=True, by_alias=True)
160+
search_request_json = search_request.json(exclude_none=True, by_alias=True)
167161

168162
try:
169-
async with pool.acquire() as conn:
163+
async with request.app.state.get_connection(request, "r") as conn:
170164
q, p = render(
171165
"""
172166
SELECT * FROM search(:req::text::jsonb);
173167
""",
174-
req=req,
168+
req=search_request_json,
175169
)
176170
items = await conn.fetchval(q, *p)
177171
except InvalidDatetimeFormatError:
@@ -253,6 +247,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]:
253247
async def item_collection(
254248
self,
255249
collection_id: str,
250+
request: Request,
256251
bbox: Optional[List[NumType]] = None,
257252
datetime: Optional[Union[str, datetime]] = None,
258253
limit: Optional[int] = None,
@@ -272,7 +267,7 @@ async def item_collection(
272267
An ItemCollection.
273268
"""
274269
# If collection does not exist, NotFoundError wil be raised
275-
await self.get_collection(collection_id, **kwargs)
270+
await self.get_collection(collection_id, request)
276271

277272
base_args = {
278273
"collections": [collection_id],
@@ -287,17 +282,19 @@ async def item_collection(
287282
if v is not None and v != []:
288283
clean[k] = v
289284

290-
req = self.post_request_model(
285+
search_request = self.post_request_model(
291286
**clean,
292287
)
293-
item_collection = await self._search_base(req, **kwargs)
288+
item_collection = await self._search_base(search_request, request)
294289
links = await ItemCollectionLinks(
295-
collection_id=collection_id, request=kwargs["request"]
290+
collection_id=collection_id, request=request
296291
).get_links(extra_links=item_collection["links"])
297292
item_collection["links"] = links
298293
return item_collection
299294

300-
async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
295+
async def get_item(
296+
self, item_id: str, collection_id: str, request: Request, **kwargs
297+
) -> Item:
301298
"""Get item by id.
302299
303300
Called with `GET /collections/{collection_id}/items/{item_id}`.
@@ -310,12 +307,12 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
310307
Item.
311308
"""
312309
# If collection does not exist, NotFoundError wil be raised
313-
await self.get_collection(collection_id, **kwargs)
310+
await self.get_collection(collection_id, request)
314311

315-
req = self.post_request_model(
312+
search_request = self.post_request_model(
316313
ids=[item_id], collections=[collection_id], limit=1
317314
)
318-
item_collection = await self._search_base(req, **kwargs)
315+
item_collection = await self._search_base(search_request, request)
319316
if not item_collection["features"]:
320317
raise NotFoundError(
321318
f"Item {item_id} in Collection {collection_id} does not exist."
@@ -324,7 +321,7 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
324321
return Item(**item_collection["features"][0])
325322

326323
async def post_search(
327-
self, search_request: PgstacSearch, **kwargs
324+
self, search_request: PgstacSearch, request: Request, **kwargs
328325
) -> ItemCollection:
329326
"""Cross catalog search (POST).
330327
@@ -336,11 +333,12 @@ async def post_search(
336333
Returns:
337334
ItemCollection containing items which match the search criteria.
338335
"""
339-
item_collection = await self._search_base(search_request, **kwargs)
336+
item_collection = await self._search_base(search_request, request)
340337
return ItemCollection(**item_collection)
341338

342339
async def get_search(
343340
self,
341+
request: Request,
344342
collections: Optional[List[str]] = None,
345343
ids: Optional[List[str]] = None,
346344
bbox: Optional[List[NumType]] = None,
@@ -362,7 +360,6 @@ async def get_search(
362360
Returns:
363361
ItemCollection containing items which match the search criteria.
364362
"""
365-
request = kwargs["request"]
366363
query_params = str(request.query_params)
367364

368365
# Kludgy fix because using factory does not allow alias for filter-lang
@@ -432,4 +429,4 @@ async def get_search(
432429
raise HTTPException(
433430
status_code=400, detail=f"Invalid parameters provided {e}"
434431
)
435-
return await self.post_search(search_request, request=kwargs["request"])
432+
return await self.post_search(search_request, request=request)

Diff for: stac_fastapi/pgstac/stac_fastapi/pgstac/db.py

+41-25
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Database connection handling."""
22

33
import json
4-
from contextlib import contextmanager
5-
from typing import Dict, Generator, Union
4+
from contextlib import asynccontextmanager, contextmanager
5+
from typing import AsyncIterator, Callable, Dict, Generator, Literal, Union
66

77
import attr
88
import orjson
9-
from asyncpg import exceptions, pool
9+
from asyncpg import Connection, exceptions
1010
from buildpg import V, asyncpg, render
11-
from fastapi import FastAPI
11+
from fastapi import FastAPI, Request
1212

1313
from stac_fastapi.types.errors import (
1414
ConflictError,
@@ -34,8 +34,11 @@ async def con_init(conn):
3434
)
3535

3636

37-
async def connect_to_db(app: FastAPI) -> None:
38-
"""Connect to Database."""
37+
ConnectionGetter = Callable[[Request, Literal["r", "w"]], AsyncIterator[Connection]]
38+
39+
40+
async def connect_to_db(app: FastAPI, get_conn: ConnectionGetter = None) -> None:
41+
"""Create connection pools & connection retriever on application."""
3942
settings = app.state.settings
4043
if app.state.settings.testing:
4144
readpool = writepool = settings.testing_connection_string
@@ -45,6 +48,7 @@ async def connect_to_db(app: FastAPI) -> None:
4548
db = DB()
4649
app.state.readpool = await db.create_pool(readpool, settings)
4750
app.state.writepool = await db.create_pool(writepool, settings)
51+
app.state.get_connection = get_conn if get_conn else get_connection
4852

4953

5054
async def close_db_connection(app: FastAPI) -> None:
@@ -53,7 +57,21 @@ async def close_db_connection(app: FastAPI) -> None:
5357
await app.state.writepool.close()
5458

5559

56-
async def dbfunc(pool: pool, func: str, arg: Union[str, Dict]):
60+
@asynccontextmanager
61+
async def get_connection(
62+
request: Request,
63+
readwrite: Literal["r", "w"] = "r",
64+
) -> AsyncIterator[Connection]:
65+
"""Retrieve connection from database conection pool."""
66+
pool = (
67+
request.app.state.writepool if readwrite == "w" else request.app.state.readpool
68+
)
69+
with translate_pgstac_errors():
70+
async with pool.acquire() as conn:
71+
yield conn
72+
73+
74+
async def dbfunc(conn: Connection, func: str, arg: Union[str, Dict]):
5775
"""Wrap PLPGSQL Functions.
5876
5977
Keyword arguments:
@@ -64,25 +82,23 @@ async def dbfunc(pool: pool, func: str, arg: Union[str, Dict]):
6482
"""
6583
with translate_pgstac_errors():
6684
if isinstance(arg, str):
67-
async with pool.acquire() as conn:
68-
q, p = render(
69-
"""
70-
SELECT * FROM :func(:item::text);
71-
""",
72-
func=V(func),
73-
item=arg,
74-
)
75-
return await conn.fetchval(q, *p)
85+
q, p = render(
86+
"""
87+
SELECT * FROM :func(:item::text);
88+
""",
89+
func=V(func),
90+
item=arg,
91+
)
92+
return await conn.fetchval(q, *p)
7693
else:
77-
async with pool.acquire() as conn:
78-
q, p = render(
79-
"""
80-
SELECT * FROM :func(:item::text::jsonb);
81-
""",
82-
func=V(func),
83-
item=json.dumps(arg),
84-
)
85-
return await conn.fetchval(q, *p)
94+
q, p = render(
95+
"""
96+
SELECT * FROM :func(:item::text::jsonb);
97+
""",
98+
func=V(func),
99+
item=json.dumps(arg),
100+
)
101+
return await conn.fetchval(q, *p)
86102

87103

88104
@contextmanager

0 commit comments

Comments
 (0)