8
8
import orjson
9
9
from asyncpg .exceptions import InvalidDatetimeFormatError
10
10
from buildpg import render
11
- from fastapi import HTTPException
11
+ from fastapi import HTTPException , Request
12
12
from pydantic import ValidationError
13
13
from pygeofilter .backends .cql2_json import to_cql2
14
14
from pygeofilter .parsers .cql2_text import parse as parse_cql2_text
15
15
from pypgstac .hydration import hydrate
16
16
from stac_pydantic .links import Relations
17
17
from stac_pydantic .shared import MimeTypes
18
- from starlette .requests import Request
19
18
20
19
from stac_fastapi .pgstac .config import Settings
21
20
from stac_fastapi .pgstac .models .links import (
38
37
class CoreCrudClient (AsyncBaseCoreClient ):
39
38
"""Client for core endpoints defined by stac."""
40
39
41
- async def all_collections (self , ** kwargs ) -> Collections :
40
+ async def all_collections (self , request : Request , ** kwargs ) -> Collections :
42
41
"""Read all collections from the database."""
43
- request : Request = kwargs ["request" ]
44
42
base_url = get_base_url (request )
45
- pool = request .app .state .readpool
46
43
47
- async with pool . acquire ( ) as conn :
44
+ async with request . app . state . get_connection ( request , "r" ) as conn :
48
45
collections = await conn .fetchval (
49
46
"""
50
47
SELECT * FROM all_collections();
@@ -80,7 +77,9 @@ async def all_collections(self, **kwargs) -> Collections:
80
77
collection_list = Collections (collections = linked_collections or [], links = links )
81
78
return collection_list
82
79
83
- async def get_collection (self , collection_id : str , ** kwargs ) -> Collection :
80
+ async def get_collection (
81
+ self , collection_id : str , request : Request , ** kwargs
82
+ ) -> Collection :
84
83
"""Get collection by id.
85
84
86
85
Called with `GET /collections/{collection_id}`.
@@ -93,9 +92,7 @@ async def get_collection(self, collection_id: str, **kwargs) -> Collection:
93
92
"""
94
93
collection : Optional [Dict [str , Any ]]
95
94
96
- request : Request = kwargs ["request" ]
97
- pool = request .app .state .readpool
98
- async with pool .acquire () as conn :
95
+ async with request .app .state .get_connection (request , "r" ) as conn :
99
96
q , p = render (
100
97
"""
101
98
SELECT * FROM get_collection(:id::text);
@@ -125,8 +122,7 @@ async def _get_base_item(
125
122
"""
126
123
item : Optional [Dict [str , Any ]]
127
124
128
- pool = request .app .state .readpool
129
- async with pool .acquire () as conn :
125
+ async with request .app .state .get_connection (request , "r" ) as conn :
130
126
q , p = render (
131
127
"""
132
128
SELECT * FROM collection_base_item(:collection_id::text);
@@ -143,7 +139,7 @@ async def _get_base_item(
143
139
async def _search_base (
144
140
self ,
145
141
search_request : PgstacSearch ,
146
- ** kwargs : Any ,
142
+ request : Request ,
147
143
) -> ItemCollection :
148
144
"""Cross catalog search (POST).
149
145
@@ -157,21 +153,19 @@ async def _search_base(
157
153
"""
158
154
items : Dict [str , Any ]
159
155
160
- request : Request = kwargs ["request" ]
161
156
settings : Settings = request .app .state .settings
162
- pool = request .app .state .readpool
163
157
164
158
search_request .conf = search_request .conf or {}
165
159
search_request .conf ["nohydrate" ] = settings .use_api_hydrate
166
- req = search_request .json (exclude_none = True , by_alias = True )
160
+ search_request_json = search_request .json (exclude_none = True , by_alias = True )
167
161
168
162
try :
169
- async with pool . acquire ( ) as conn :
163
+ async with request . app . state . get_connection ( request , "r" ) as conn :
170
164
q , p = render (
171
165
"""
172
166
SELECT * FROM search(:req::text::jsonb);
173
167
""" ,
174
- req = req ,
168
+ req = search_request_json ,
175
169
)
176
170
items = await conn .fetchval (q , * p )
177
171
except InvalidDatetimeFormatError :
@@ -253,6 +247,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]:
253
247
async def item_collection (
254
248
self ,
255
249
collection_id : str ,
250
+ request : Request ,
256
251
bbox : Optional [List [NumType ]] = None ,
257
252
datetime : Optional [Union [str , datetime ]] = None ,
258
253
limit : Optional [int ] = None ,
@@ -272,7 +267,7 @@ async def item_collection(
272
267
An ItemCollection.
273
268
"""
274
269
# If collection does not exist, NotFoundError wil be raised
275
- await self .get_collection (collection_id , ** kwargs )
270
+ await self .get_collection (collection_id , request )
276
271
277
272
base_args = {
278
273
"collections" : [collection_id ],
@@ -287,17 +282,19 @@ async def item_collection(
287
282
if v is not None and v != []:
288
283
clean [k ] = v
289
284
290
- req = self .post_request_model (
285
+ search_request = self .post_request_model (
291
286
** clean ,
292
287
)
293
- item_collection = await self ._search_base (req , ** kwargs )
288
+ item_collection = await self ._search_base (search_request , request )
294
289
links = await ItemCollectionLinks (
295
- collection_id = collection_id , request = kwargs [ " request" ]
290
+ collection_id = collection_id , request = request
296
291
).get_links (extra_links = item_collection ["links" ])
297
292
item_collection ["links" ] = links
298
293
return item_collection
299
294
300
- async def get_item (self , item_id : str , collection_id : str , ** kwargs ) -> Item :
295
+ async def get_item (
296
+ self , item_id : str , collection_id : str , request : Request , ** kwargs
297
+ ) -> Item :
301
298
"""Get item by id.
302
299
303
300
Called with `GET /collections/{collection_id}/items/{item_id}`.
@@ -310,12 +307,12 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
310
307
Item.
311
308
"""
312
309
# If collection does not exist, NotFoundError wil be raised
313
- await self .get_collection (collection_id , ** kwargs )
310
+ await self .get_collection (collection_id , request )
314
311
315
- req = self .post_request_model (
312
+ search_request = self .post_request_model (
316
313
ids = [item_id ], collections = [collection_id ], limit = 1
317
314
)
318
- item_collection = await self ._search_base (req , ** kwargs )
315
+ item_collection = await self ._search_base (search_request , request )
319
316
if not item_collection ["features" ]:
320
317
raise NotFoundError (
321
318
f"Item { item_id } in Collection { collection_id } does not exist."
@@ -324,7 +321,7 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
324
321
return Item (** item_collection ["features" ][0 ])
325
322
326
323
async def post_search (
327
- self , search_request : PgstacSearch , ** kwargs
324
+ self , search_request : PgstacSearch , request : Request , ** kwargs
328
325
) -> ItemCollection :
329
326
"""Cross catalog search (POST).
330
327
@@ -336,11 +333,12 @@ async def post_search(
336
333
Returns:
337
334
ItemCollection containing items which match the search criteria.
338
335
"""
339
- item_collection = await self ._search_base (search_request , ** kwargs )
336
+ item_collection = await self ._search_base (search_request , request )
340
337
return ItemCollection (** item_collection )
341
338
342
339
async def get_search (
343
340
self ,
341
+ request : Request ,
344
342
collections : Optional [List [str ]] = None ,
345
343
ids : Optional [List [str ]] = None ,
346
344
bbox : Optional [List [NumType ]] = None ,
@@ -362,7 +360,6 @@ async def get_search(
362
360
Returns:
363
361
ItemCollection containing items which match the search criteria.
364
362
"""
365
- request = kwargs ["request" ]
366
363
query_params = str (request .query_params )
367
364
368
365
# Kludgy fix because using factory does not allow alias for filter-lang
@@ -432,4 +429,4 @@ async def get_search(
432
429
raise HTTPException (
433
430
status_code = 400 , detail = f"Invalid parameters provided { e } "
434
431
)
435
- return await self .post_search (search_request , request = kwargs [ " request" ] )
432
+ return await self .post_search (search_request , request = request )
0 commit comments