Skip to content

Commit abed7ea

Browse files
committed
refactor types
1 parent e7825a0 commit abed7ea

File tree

3 files changed

+99
-85
lines changed

3 files changed

+99
-85
lines changed

datacommons_client/endpoints/node.py

Lines changed: 74 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
from datacommons_client.endpoints.payloads import NodeRequestPayload
77
from datacommons_client.endpoints.payloads import normalize_properties_to_string
88
from datacommons_client.endpoints.response import NodeResponse
9-
from datacommons_client.models.node import Node
9+
from datacommons_client.models.graph import Parent
1010
from datacommons_client.utils.graph import build_ancestry_map
1111
from datacommons_client.utils.graph import build_ancestry_tree
1212
from datacommons_client.utils.graph import build_parents_dictionary
1313
from datacommons_client.utils.graph import fetch_parents_lru
1414
from datacommons_client.utils.graph import flatten_ancestry
15-
from datacommons_client.utils.graph import Parent
1615

1716
ANCESTRY_MAX_WORKERS = 20
1817

@@ -197,85 +196,95 @@ def fetch_all_classes(
197196
)
198197

199198
def fetch_entity_parents(
200-
self, entity_dcids: str | list[str]) -> dict[str, list[Parent]]:
199+
self,
200+
entity_dcids: str | list[str],
201+
*,
202+
as_dict: bool = True) -> dict[str, list[Parent | dict]]:
201203
"""Fetches the direct parents of one or more entities using the 'containedInPlace' property.
202204
203205
Args:
204206
entity_dcids (str | list[str]): A single DCID or a list of DCIDs to query.
207+
as_dict (bool): If True, returns a dictionary mapping each input DCID to its
208+
immediate parent entities. If False, returns a dictionary of Parent objects (which
209+
are dataclasses).
205210
206211
Returns:
207-
dict[str, list[Parent]]: A dictionary mapping each input DCID to a list of its
208-
immediate parent entities. Each parent is represented as a Parent object, which
209-
contains the DCID, name, and type of the parent entity.
212+
dict[str, list[Parent | dict]]: A dictionary mapping each input DCID to a list of its
213+
immediate parent entities. Each parent is represented as a Parent object (which
214+
contains the DCID, name, and type of the parent entity) or as a dictionary with
215+
the same data.
210216
"""
211217
# Fetch property values from the API
212218
data = self.fetch_property_values(
213219
node_dcids=entity_dcids,
214220
properties="containedInPlace",
215221
).get_properties()
216222

217-
return build_parents_dictionary(data=data)
223+
result = build_parents_dictionary(data=data)
218224

225+
if as_dict:
226+
return {k: [p.to_dict() for p in v] for k, v in result.items()}
219227

220-
def _fetch_parents_cached(self, dcid: str) -> tuple[Parent, ...]:
221-
"""Returns cached parent nodes for a given entity using an LRU cache.
228+
return result
222229

223-
This private wrapper exists because `@lru_cache` cannot be applied directly
224-
to instance methods. By passing the `NodeEndpoint` instance (`self`) as an
225-
argument caching is preserved while keeping the implementation modular and testable.
230+
def _fetch_parents_cached(self, dcid: str) -> tuple[Parent, ...]:
231+
"""Returns cached parent nodes for a given entity using an LRU cache.
226232
227-
Args:
228-
dcid (str): The DCID of the entity whose parents should be fetched.
233+
This private wrapper exists because `@lru_cache` cannot be applied directly
234+
to instance methods. By passing the `NodeEndpoint` instance (`self`) as an
235+
argument caching is preserved while keeping the implementation modular and testable.
229236
230-
Returns:
231-
tuple[Parent, ...]: A tuple of Parent objects representing the entity's immediate parents.
232-
"""
233-
return fetch_parents_lru(self, dcid)
234-
235-
236-
def fetch_entity_ancestry(
237-
self,
238-
entity_dcids: str | list[str],
239-
as_tree: bool = False) -> dict[str, list[dict[str, str]] | dict]:
240-
"""Fetches the full ancestry (flat or nested) for one or more entities.
241-
For each input DCID, this method builds the complete ancestry graph using a
242-
breadth-first traversal and parallel fetching.
243-
It returns either a flat list of unique parents or a nested tree structure for
244-
each entity, depending on the `as_tree` flag. The flat list matches the structure
245-
of the `/api/place/parent` endpoint of the DC website.
246-
Args:
247-
entity_dcids (str | list[str]): One or more DCIDs of the entities whose ancestry
248-
will be fetched.
249-
as_tree (bool): If True, returns a nested tree structure; otherwise, returns a flat list.
250-
Defaults to False.
251-
Returns:
252-
dict[str, list[dict[str, str]] | dict]: A dictionary mapping each input DCID to either:
253-
- A flat list of parent dictionaries (if `as_tree` is False), or
254-
- A nested ancestry tree (if `as_tree` is True). Each parent is represented by
255-
a dict with 'dcid', 'name', and 'type'.
256-
"""
237+
Args:
238+
dcid (str): The DCID of the entity whose parents should be fetched.
239+
240+
Returns:
241+
tuple[Parent, ...]: A tuple of Parent objects representing the entity's immediate parents.
242+
"""
243+
return fetch_parents_lru(self, dcid)
244+
245+
def fetch_entity_ancestry(
246+
self,
247+
entity_dcids: str | list[str],
248+
as_tree: bool = False) -> dict[str, list[dict[str, str]] | dict]:
249+
"""Fetches the full ancestry (flat or nested) for one or more entities.
250+
For each input DCID, this method builds the complete ancestry graph using a
251+
breadth-first traversal and parallel fetching.
252+
It returns either a flat list of unique parents or a nested tree structure for
253+
each entity, depending on the `as_tree` flag. The flat list matches the structure
254+
of the `/api/place/parent` endpoint of the DC website.
255+
Args:
256+
entity_dcids (str | list[str]): One or more DCIDs of the entities whose ancestry
257+
will be fetched.
258+
as_tree (bool): If True, returns a nested tree structure; otherwise, returns a flat list.
259+
Defaults to False.
260+
Returns:
261+
dict[str, list[dict[str, str]] | dict]: A dictionary mapping each input DCID to either:
262+
- A flat list of parent dictionaries (if `as_tree` is False), or
263+
- A nested ancestry tree (if `as_tree` is True). Each parent is represented by
264+
a dict with 'dcid', 'name', and 'type'.
265+
"""
257266

258-
if isinstance(entity_dcids, str):
259-
entity_dcids = [entity_dcids]
260-
261-
result = {}
262-
263-
# Use a thread pool to fetch ancestry graphs in parallel for each input entity
264-
with ThreadPoolExecutor(max_workers=ANCESTRY_MAX_WORKERS) as executor:
265-
futures = [
266-
executor.submit(build_ancestry_map,
267-
root=dcid,
268-
fetch_fn=self._fetch_parents_cached)
269-
for dcid in entity_dcids
270-
]
271-
272-
# Gather ancestry maps and postprocess into flat or nested form
273-
for future in futures:
274-
dcid, ancestry = future.result()
275-
if as_tree:
276-
ancestry = build_ancestry_tree(dcid, ancestry)
277-
else:
278-
ancestry = flatten_ancestry(ancestry)
279-
result[dcid] = ancestry
280-
281-
return result
267+
if isinstance(entity_dcids, str):
268+
entity_dcids = [entity_dcids]
269+
270+
result = {}
271+
272+
# Use a thread pool to fetch ancestry graphs in parallel for each input entity
273+
with ThreadPoolExecutor(max_workers=ANCESTRY_MAX_WORKERS) as executor:
274+
futures = [
275+
executor.submit(build_ancestry_map,
276+
root=dcid,
277+
fetch_fn=self._fetch_parents_cached)
278+
for dcid in entity_dcids
279+
]
280+
281+
# Gather ancestry maps and postprocess into flat or nested form
282+
for future in futures:
283+
dcid, ancestry = future.result()
284+
if as_tree:
285+
ancestry = build_ancestry_tree(dcid, ancestry)
286+
else:
287+
ancestry = flatten_ancestry(ancestry)
288+
result[dcid] = ancestry
289+
290+
return result

datacommons_client/models/graph.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from dataclasses import dataclass
2+
from typing import TypeAlias
3+
4+
from datacommons_client.utils.data_processing import SerializableMixin
5+
6+
7+
@dataclass(frozen=True)
8+
class Parent(SerializableMixin):
9+
"""A class representing a parent node in a graph.
10+
Attributes:
11+
dcid (str): The ID of the parent node.
12+
name (str): The name of the parent node.
13+
type (str | list[str]): The type(s) of the parent node.
14+
"""
15+
16+
dcid: str
17+
name: str
18+
type: str | list[str]
19+
20+
21+
# A dictionary mapping DCIDs to lists of Parent objects.
22+
AncestryMap: TypeAlias = dict[str, list[Parent]]

datacommons_client/utils/graph.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,14 @@
33
from concurrent.futures import Future
44
from concurrent.futures import ThreadPoolExecutor
55
from concurrent.futures import wait
6-
from dataclasses import dataclass
76
from functools import lru_cache
87
from typing import Callable, Optional
98

10-
from datacommons_client.utils.data_processing import SerializableMixin
9+
from datacommons_client.models.graph import AncestryMap
10+
from datacommons_client.models.graph import Parent
1111

1212
PARENTS_MAX_WORKERS = 10
1313

14-
15-
@dataclass(frozen=True)
16-
class Parent(SerializableMixin):
17-
"""A class representing a parent node in a graph.
18-
Attributes:
19-
dcid (str): The ID of the parent node.
20-
name (str): The name of the parent node.
21-
type (str | list[str]): The type(s) of the parent node.
22-
"""
23-
24-
dcid: str
25-
name: str
26-
type: str | list[str]
27-
28-
29-
AncestryMap = dict[str, list[Parent]]
30-
3114
# -- -- Fetch tools -- --
3215

3316

@@ -48,7 +31,7 @@ def _fetch_parents_uncached(endpoint, dcid: str) -> list[Parent]:
4831
Returns:
4932
A list of parent dictionaries, each containing 'dcid', 'name', and 'type'.
5033
"""
51-
return endpoint.fetch_entity_parents(dcid).get(dcid, [])
34+
return endpoint.fetch_entity_parents(dcid, as_dict=False).get(dcid, [])
5235

5336

5437
@lru_cache(maxsize=512)

0 commit comments

Comments
 (0)