Skip to content

Commit a416ec0

Browse files
Queryables landing page and collection links (stac-utils#267)
**Related Issue(s):** - stac-utils#260 **Description:** - Code to add the `queryables` link included in the Filter extension to the landing page and collections - Only adds the `queryables` collection link when the Filter extension is enabled. It does this by passing an extensions list to the DatabaseLogic class. This could be used to have other conditions for when certain extensions are disabled/enabled in the app. please let me know if you have any suggestions for this approach - Some improvements to `data_loader.py` **PR Checklist:** - [x] Code is formatted and linted (run `pre-commit run --all-files`) - [x] Tests pass (run `make test`) - [ ] Documentation has been updated to reflect changes, if applicable - [x] Changes are added to the changelog --------- Co-authored-by: Jonathan Healy <[email protected]>
1 parent 009754e commit a416ec0

File tree

12 files changed

+135
-27
lines changed

12 files changed

+135
-27
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
77

88
## [Unreleased]
99

10+
### Added
11+
- Queryables landing page and collection links when the Filter Extension is enabled [#267](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/267)
12+
1013
### Changed
1114

1215
- Updated stac-fastapi libraries to v3.0.0a1 [#265](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/265)

data_loader.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,17 @@ def load_collection(base_url, collection_id, data_dir):
2222
collection["id"] = collection_id
2323
try:
2424
resp = requests.post(f"{base_url}/collections", json=collection)
25-
if resp.status_code == 200:
25+
if resp.status_code == 200 or resp.status_code == 201:
2626
click.echo(f"Status code: {resp.status_code}")
2727
click.echo(f"Added collection: {collection['id']}")
2828
elif resp.status_code == 409:
2929
click.echo(f"Status code: {resp.status_code}")
3030
click.echo(f"Collection: {collection['id']} already exists")
31+
else:
32+
click.echo(f"Status code: {resp.status_code}")
33+
click.echo(
34+
f"Error writing {collection['id']} collection. Message: {resp.text}"
35+
)
3136
except requests.ConnectionError:
3237
click.secho("Failed to connect", fg="red", err=True)
3338

sample_data/collection.json

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{
22
"id":"sentinel-s2-l2a-cogs-test",
33
"stac_version":"1.0.0",
4+
"type": "Collection",
45
"description":"Sentinel-2a and Sentinel-2b imagery, processed to Level 2A (Surface Reflectance) and converted to Cloud-Optimized GeoTIFFs",
56
"links":[
67
{"rel":"self","href":"https://earth-search.aws.element84.com/v0/collections/sentinel-s2-l2a-cogs"},

stac_fastapi/core/stac_fastapi/core/core.py

+32-13
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,19 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
153153
conformance_classes=self.conformance_classes(),
154154
extension_schemas=[],
155155
)
156+
157+
if self.extension_is_enabled("FilterExtension"):
158+
landing_page["links"].append(
159+
{
160+
# TODO: replace this with Relations.queryables.value,
161+
"rel": "queryables",
162+
# TODO: replace this with MimeTypes.jsonschema,
163+
"type": "application/schema+json",
164+
"title": "Queryables",
165+
"href": urljoin(base_url, "queryables"),
166+
}
167+
)
168+
156169
collections = await self.all_collections(request=kwargs["request"])
157170
for collection in collections["collections"]:
158171
landing_page["links"].append(
@@ -205,7 +218,7 @@ async def all_collections(self, **kwargs) -> stac_types.Collections:
205218
token = request.query_params.get("token")
206219

207220
collections, next_token = await self.database.get_all_collections(
208-
token=token, limit=limit, base_url=base_url
221+
token=token, limit=limit, request=request
209222
)
210223

211224
links = [
@@ -239,10 +252,12 @@ async def get_collection(
239252
Raises:
240253
NotFoundError: If the collection with the given id cannot be found in the database.
241254
"""
242-
base_url = str(kwargs["request"].base_url)
255+
request = kwargs["request"]
243256
collection = await self.database.find_collection(collection_id=collection_id)
244257
return self.collection_serializer.db_to_stac(
245-
collection=collection, base_url=base_url
258+
collection=collection,
259+
request=request,
260+
extensions=[type(ext).__name__ for ext in self.extensions],
246261
)
247262

248263
async def item_collection(
@@ -748,12 +763,14 @@ async def create_collection(
748763
ConflictError: If the collection already exists.
749764
"""
750765
collection = collection.model_dump(mode="json")
751-
base_url = str(kwargs["request"].base_url)
752-
collection = self.database.collection_serializer.stac_to_db(
753-
collection, base_url
754-
)
766+
request = kwargs["request"]
767+
collection = self.database.collection_serializer.stac_to_db(collection, request)
755768
await self.database.create_collection(collection=collection)
756-
return CollectionSerializer.db_to_stac(collection, base_url)
769+
return CollectionSerializer.db_to_stac(
770+
collection,
771+
request,
772+
extensions=[type(ext).__name__ for ext in self.database.extensions],
773+
)
757774

758775
@overrides
759776
async def update_collection(
@@ -780,16 +797,18 @@ async def update_collection(
780797
"""
781798
collection = collection.model_dump(mode="json")
782799

783-
base_url = str(kwargs["request"].base_url)
800+
request = kwargs["request"]
784801

785-
collection = self.database.collection_serializer.stac_to_db(
786-
collection, base_url
787-
)
802+
collection = self.database.collection_serializer.stac_to_db(collection, request)
788803
await self.database.update_collection(
789804
collection_id=collection_id, collection=collection
790805
)
791806

792-
return CollectionSerializer.db_to_stac(collection, base_url)
807+
return CollectionSerializer.db_to_stac(
808+
collection,
809+
request,
810+
extensions=[type(ext).__name__ for ext in self.database.extensions],
811+
)
793812

794813
@overrides
795814
async def delete_collection(

stac_fastapi/core/stac_fastapi/core/models/links.py

+33
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,39 @@ async def get_links(
107107
return links
108108

109109

110+
@attr.s
111+
class CollectionLinks(BaseLinks):
112+
"""Create inferred links specific to collections."""
113+
114+
collection_id: str = attr.ib()
115+
extensions: List[str] = attr.ib(default=attr.Factory(list))
116+
117+
def link_parent(self) -> Dict[str, Any]:
118+
"""Create the `parent` link."""
119+
return dict(rel=Relations.parent, type=MimeTypes.json.value, href=self.base_url)
120+
121+
def link_items(self) -> Dict[str, Any]:
122+
"""Create the `items` link."""
123+
return dict(
124+
rel="items",
125+
type=MimeTypes.geojson.value,
126+
href=urljoin(self.base_url, f"collections/{self.collection_id}/items"),
127+
)
128+
129+
def link_queryables(self) -> Dict[str, Any]:
130+
"""Create the `queryables` link."""
131+
if "FilterExtension" in self.extensions:
132+
return dict(
133+
rel="queryables",
134+
type=MimeTypes.json.value,
135+
href=urljoin(
136+
self.base_url, f"collections/{self.collection_id}/queryables"
137+
),
138+
)
139+
else:
140+
return None
141+
142+
110143
@attr.s
111144
class PagingLinks(BaseLinks):
112145
"""Create links for paging."""

stac_fastapi/core/stac_fastapi/core/serializers.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Serializers."""
22
import abc
33
from copy import deepcopy
4-
from typing import Any
4+
from typing import Any, List, Optional
55

66
import attr
7+
from starlette.requests import Request
78

89
from stac_fastapi.core.datetime_utils import now_to_rfc3339_str
10+
from stac_fastapi.core.models.links import CollectionLinks
911
from stac_fastapi.types import stac as stac_types
10-
from stac_fastapi.types.links import CollectionLinks, ItemLinks, resolve_links
12+
from stac_fastapi.types.links import ItemLinks, resolve_links
1113

1214

1315
@attr.s
@@ -109,29 +111,34 @@ class CollectionSerializer(Serializer):
109111

110112
@classmethod
111113
def stac_to_db(
112-
cls, collection: stac_types.Collection, base_url: str
114+
cls, collection: stac_types.Collection, request: Request
113115
) -> stac_types.Collection:
114116
"""
115117
Transform STAC Collection to database-ready STAC collection.
116118
117119
Args:
118120
stac_data: the STAC Collection object to be transformed
119-
base_url: the base URL for the STAC API
121+
starlette.requests.Request: the API request
120122
121123
Returns:
122124
stac_types.Collection: The database-ready STAC Collection object.
123125
"""
124126
collection = deepcopy(collection)
125-
collection["links"] = resolve_links(collection.get("links", []), base_url)
127+
collection["links"] = resolve_links(
128+
collection.get("links", []), str(request.base_url)
129+
)
126130
return collection
127131

128132
@classmethod
129-
def db_to_stac(cls, collection: dict, base_url: str) -> stac_types.Collection:
133+
def db_to_stac(
134+
cls, collection: dict, request: Request, extensions: Optional[List[str]] = []
135+
) -> stac_types.Collection:
130136
"""Transform database model to STAC collection.
131137
132138
Args:
133139
collection (dict): The collection data in dictionary form, extracted from the database.
134-
base_url (str): The base URL for the collection.
140+
starlette.requests.Request: the API request
141+
extensions: A list of the extension class names (`ext.__name__`) or all enabled STAC API extensions.
135142
136143
Returns:
137144
stac_types.Collection: The STAC collection object.
@@ -157,13 +164,13 @@ def db_to_stac(cls, collection: dict, base_url: str) -> stac_types.Collection:
157164

158165
# Create the collection links using CollectionLinks
159166
collection_links = CollectionLinks(
160-
collection_id=collection_id, base_url=base_url
167+
collection_id=collection_id, request=request, extensions=extensions
161168
).create_links()
162169

163170
# Add any additional links from the collection dictionary
164171
original_links = collection.get("links")
165172
if original_links:
166-
collection_links += resolve_links(original_links, base_url)
173+
collection_links += resolve_links(original_links, str(request.base_url))
167174
collection["links"] = collection_links
168175

169176
# Return the stac_types.Collection object

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py

+2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
filter_extension,
6060
]
6161

62+
database_logic.extensions = [type(ext).__name__ for ext in extensions]
63+
6264
post_request_model = create_post_request_model(extensions)
6365

6466
api = StacApi(

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import attr
99
from elasticsearch_dsl import Q, Search
10+
from starlette.requests import Request
1011

1112
from elasticsearch import exceptions, helpers # type: ignore
1213
from stac_fastapi.core.extensions import filter
@@ -312,10 +313,12 @@ class DatabaseLogic:
312313
default=CollectionSerializer
313314
)
314315

316+
extensions: List[str] = attr.ib(default=attr.Factory(list))
317+
315318
"""CORE LOGIC"""
316319

317320
async def get_all_collections(
318-
self, token: Optional[str], limit: int, base_url: str
321+
self, token: Optional[str], limit: int, request: Request
319322
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
320323
"""Retrieve a list of all collections from Elasticsearch, supporting pagination.
321324
@@ -342,7 +345,7 @@ async def get_all_collections(
342345
hits = response["hits"]["hits"]
343346
collections = [
344347
self.collection_serializer.db_to_stac(
345-
collection=hit["_source"], base_url=base_url
348+
collection=hit["_source"], request=request, extensions=self.extensions
346349
)
347350
for hit in hits
348351
]

stac_fastapi/opensearch/stac_fastapi/opensearch/app.py

+2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
filter_extension,
6060
]
6161

62+
database_logic.extensions = [type(ext).__name__ for ext in extensions]
63+
6264
post_request_model = create_post_request_model(extensions)
6365

6466
api = StacApi(

stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from opensearchpy.exceptions import TransportError
1111
from opensearchpy.helpers.query import Q
1212
from opensearchpy.helpers.search import Search
13+
from starlette.requests import Request
1314

1415
from stac_fastapi.core import serializers
1516
from stac_fastapi.core.extensions import filter
@@ -333,10 +334,12 @@ class DatabaseLogic:
333334
default=serializers.CollectionSerializer
334335
)
335336

337+
extensions: List[str] = attr.ib(default=attr.Factory(list))
338+
336339
"""CORE LOGIC"""
337340

338341
async def get_all_collections(
339-
self, token: Optional[str], limit: int, base_url: str
342+
self, token: Optional[str], limit: int, request: Request
340343
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
341344
"""
342345
Retrieve a list of all collections from Opensearch, supporting pagination.
@@ -366,7 +369,7 @@ async def get_all_collections(
366369
hits = response["hits"]["hits"]
367370
collections = [
368371
self.collection_serializer.db_to_stac(
369-
collection=hit["_source"], base_url=base_url
372+
collection=hit["_source"], request=request, extensions=self.extensions
370373
)
371374
for hit in hits
372375
]

stac_fastapi/tests/conftest.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(self, item, collection):
5858

5959
class MockRequest:
6060
base_url = "http://test-server"
61+
url = "http://test-server/test"
6162
query_params = {}
6263

6364
def __init__(

stac_fastapi/tests/extensions/test_filter.py

+29
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,35 @@
88
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
99

1010

11+
@pytest.mark.asyncio
12+
async def test_filter_extension_landing_page_link(app_client, ctx):
13+
resp = await app_client.get("/")
14+
assert resp.status_code == 200
15+
16+
resp_json = resp.json()
17+
keys = [link["rel"] for link in resp_json["links"]]
18+
19+
assert "queryables" in keys
20+
21+
22+
@pytest.mark.asyncio
23+
async def test_filter_extension_collection_link(app_client, load_test_data):
24+
"""Test creation and deletion of a collection"""
25+
test_collection = load_test_data("test_collection.json")
26+
test_collection["id"] = "test"
27+
28+
resp = await app_client.post("/collections", json=test_collection)
29+
assert resp.status_code == 201
30+
31+
resp = await app_client.get(f"/collections/{test_collection['id']}")
32+
resp_json = resp.json()
33+
keys = [link["rel"] for link in resp_json["links"]]
34+
assert "queryables" in keys
35+
36+
resp = await app_client.delete(f"/collections/{test_collection['id']}")
37+
assert resp.status_code == 204
38+
39+
1140
@pytest.mark.asyncio
1241
async def test_search_filters_post(app_client, ctx):
1342

0 commit comments

Comments
 (0)