
Commit e7825a0

Update node.py
Add ancestry tools
1 parent 0bd5c50

datacommons_client/endpoints/node.py

Lines changed: 113 additions & 13 deletions
@@ -1,10 +1,20 @@
+from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
 
 from datacommons_client.endpoints.base import API
 from datacommons_client.endpoints.base import Endpoint
 from datacommons_client.endpoints.payloads import NodeRequestPayload
 from datacommons_client.endpoints.payloads import normalize_properties_to_string
 from datacommons_client.endpoints.response import NodeResponse
+from datacommons_client.models.node import Node
+from datacommons_client.utils.graph import build_ancestry_map
+from datacommons_client.utils.graph import build_ancestry_tree
+from datacommons_client.utils.graph import build_parents_dictionary
+from datacommons_client.utils.graph import fetch_parents_lru
+from datacommons_client.utils.graph import flatten_ancestry
+from datacommons_client.utils.graph import Parent
+
+ANCESTRY_MAX_WORKERS = 20
 
 
 class NodeEndpoint(Endpoint):
@@ -91,10 +101,12 @@ def fetch_property_labels(
     expression = "->" if out else "<-"
 
     # Make the request and return the response.
-    return self.fetch(node_dcids=node_dcids,
-                      expression=expression,
-                      all_pages=all_pages,
-                      next_token=next_token)
+    return self.fetch(
+        node_dcids=node_dcids,
+        expression=expression,
+        all_pages=all_pages,
+        next_token=next_token,
+    )
 
   def fetch_property_values(
       self,
@@ -143,10 +155,12 @@ def fetch_property_values(
     if constraints:
       expression += f"{{{constraints}}}"
 
-    return self.fetch(node_dcids=node_dcids,
-                      expression=expression,
-                      all_pages=all_pages,
-                      next_token=next_token)
+    return self.fetch(
+        node_dcids=node_dcids,
+        expression=expression,
+        all_pages=all_pages,
+        next_token=next_token,
+    )
 
   def fetch_all_classes(
       self,
@@ -174,8 +188,94 @@ def fetch_all_classes(
     ```
     """
 
-    return self.fetch_property_values(node_dcids="Class",
-                                      properties="typeOf",
-                                      out=False,
-                                      all_pages=all_pages,
-                                      next_token=next_token)
+    return self.fetch_property_values(
+        node_dcids="Class",
+        properties="typeOf",
+        out=False,
+        all_pages=all_pages,
+        next_token=next_token,
+    )
+
+  def fetch_entity_parents(
+      self, entity_dcids: str | list[str]) -> dict[str, list[Parent]]:
+    """Fetches the direct parents of one or more entities using the 'containedInPlace' property.
+
+    Args:
+      entity_dcids (str | list[str]): A single DCID or a list of DCIDs to query.
+
+    Returns:
+      dict[str, list[Parent]]: A dictionary mapping each input DCID to a list of its
+        immediate parent entities. Each parent is represented as a Parent object, which
+        contains the DCID, name, and type of the parent entity.
+    """
+    # Fetch property values from the API
+    data = self.fetch_property_values(
+        node_dcids=entity_dcids,
+        properties="containedInPlace",
+    ).get_properties()
+
+    return build_parents_dictionary(data=data)
+
+
+  def _fetch_parents_cached(self, dcid: str) -> tuple[Parent, ...]:
+    """Returns cached parent nodes for a given entity using an LRU cache.
+
+    This private wrapper exists because `@lru_cache` cannot be applied directly
+    to instance methods. By passing the `NodeEndpoint` instance (`self`) as an
+    argument, caching is preserved while keeping the implementation modular and testable.
+
+    Args:
+      dcid (str): The DCID of the entity whose parents should be fetched.
+
+    Returns:
+      tuple[Parent, ...]: A tuple of Parent objects representing the entity's immediate parents.
+    """
+    return fetch_parents_lru(self, dcid)
+
+
+  def fetch_entity_ancestry(
+      self,
+      entity_dcids: str | list[str],
+      as_tree: bool = False) -> dict[str, list[dict[str, str]] | dict]:
+    """Fetches the full ancestry (flat or nested) for one or more entities.
+    For each input DCID, this method builds the complete ancestry graph using a
+    breadth-first traversal and parallel fetching.
+    It returns either a flat list of unique parents or a nested tree structure for
+    each entity, depending on the `as_tree` flag. The flat list matches the structure
+    of the `/api/place/parent` endpoint of the DC website.
+    Args:
+      entity_dcids (str | list[str]): One or more DCIDs of the entities whose ancestry
+        will be fetched.
+      as_tree (bool): If True, returns a nested tree structure; otherwise, returns a flat list.
+        Defaults to False.
+    Returns:
+      dict[str, list[dict[str, str]] | dict]: A dictionary mapping each input DCID to either:
+        - A flat list of parent dictionaries (if `as_tree` is False), or
+        - A nested ancestry tree (if `as_tree` is True). Each parent is represented by
+          a dict with 'dcid', 'name', and 'type'.
+    """
+
+    if isinstance(entity_dcids, str):
+      entity_dcids = [entity_dcids]
+
+    result = {}
+
+    # Use a thread pool to fetch ancestry graphs in parallel for each input entity
+    with ThreadPoolExecutor(max_workers=ANCESTRY_MAX_WORKERS) as executor:
+      futures = [
+          executor.submit(build_ancestry_map,
+                          root=dcid,
+                          fetch_fn=self._fetch_parents_cached)
+          for dcid in entity_dcids
+      ]
+
+      # Gather ancestry maps and postprocess into flat or nested form
+      for future in futures:
+        dcid, ancestry = future.result()
+        if as_tree:
+          ancestry = build_ancestry_tree(dcid, ancestry)
+        else:
+          ancestry = flatten_ancestry(ancestry)
+        result[dcid] = ancestry
+
+    return result
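
The new `fetch_entity_parents` shown above is a one-hop lookup over the `containedInPlace` property, keyed by input DCID. A minimal usage sketch; the construction of `API` and `NodeEndpoint` is assumed here rather than taken from this commit, so the exact setup may differ:

```python
from datacommons_client.endpoints.base import API
from datacommons_client.endpoints.node import NodeEndpoint

# Assumed construction; check the actual API/NodeEndpoint signatures.
node = NodeEndpoint(api=API())

# Direct parents via 'containedInPlace', keyed by the input DCID.
parents = node.fetch_entity_parents("geoId/06")
for parent in parents.get("geoId/06", []):
  print(parent.dcid, parent.name, parent.type)
```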
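
`_fetch_parents_cached` documents the usual workaround for caching on instance methods: keep the `@lru_cache`-decorated function at module level and pass the endpoint in as an argument. The real `fetch_parents_lru` lives in `datacommons_client.utils.graph` and is not part of this diff; the following is only a hypothetical sketch of the pattern the docstring describes:

```python
from functools import lru_cache

@lru_cache(maxsize=512)  # cache size is illustrative, not from the commit
def fetch_parents_lru(endpoint, dcid: str) -> tuple:
  """Fetch and memoize the immediate parents of `dcid` via `endpoint`.

  Returns a tuple rather than a list so the cached value stays immutable.
  """
  parents = endpoint.fetch_entity_parents(dcid).get(dcid, [])
  return tuple(parents)
```

Because `self` is passed as an argument, the cache key includes the endpoint instance, so cached results are not shared between differently configured endpoints.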
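
`fetch_entity_ancestry` then composes those cached one-hop lookups into the full ancestry for each input, submitting one `build_ancestry_map` traversal per entity to the thread pool (capped at `ANCESTRY_MAX_WORKERS`). A short sketch of the two output shapes, reusing the hypothetical `node` endpoint from the first example:

```python
# Flat form (default): a de-duplicated list of parents per input DCID,
# each entry a dict with 'dcid', 'name', and 'type'.
flat = node.fetch_entity_ancestry("geoId/06085")
for ancestor in flat["geoId/06085"]:
  print(ancestor["dcid"], ancestor["name"], ancestor["type"])

# Nested form: an ancestry tree rooted at the entity, with parents nested.
tree = node.fetch_entity_ancestry("geoId/06085", as_tree=True)
```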
