|
| 1 | +from concurrent.futures import ThreadPoolExecutor |
1 | 2 | from typing import Optional
|
2 | 3 |
|
3 | 4 | from datacommons_client.endpoints.base import API
|
4 | 5 | from datacommons_client.endpoints.base import Endpoint
|
5 | 6 | from datacommons_client.endpoints.payloads import NodeRequestPayload
|
6 | 7 | from datacommons_client.endpoints.payloads import normalize_properties_to_string
|
7 | 8 | from datacommons_client.endpoints.response import NodeResponse
|
| 9 | +from datacommons_client.models.node import Node |
| 10 | +from datacommons_client.utils.graph import build_ancestry_map |
| 11 | +from datacommons_client.utils.graph import build_ancestry_tree |
| 12 | +from datacommons_client.utils.graph import build_parents_dictionary |
| 13 | +from datacommons_client.utils.graph import fetch_parents_lru |
| 14 | +from datacommons_client.utils.graph import flatten_ancestry |
| 15 | +from datacommons_client.utils.graph import Parent |
| 16 | + |
| 17 | +ANCESTRY_MAX_WORKERS = 20 |
8 | 18 |
|
9 | 19 |
|
10 | 20 | class NodeEndpoint(Endpoint):
|
@@ -91,10 +101,12 @@ def fetch_property_labels(
|
91 | 101 | expression = "->" if out else "<-"
|
92 | 102 |
|
93 | 103 | # Make the request and return the response.
|
94 |
| - return self.fetch(node_dcids=node_dcids, |
95 |
| - expression=expression, |
96 |
| - all_pages=all_pages, |
97 |
| - next_token=next_token) |
| 104 | + return self.fetch( |
| 105 | + node_dcids=node_dcids, |
| 106 | + expression=expression, |
| 107 | + all_pages=all_pages, |
| 108 | + next_token=next_token, |
| 109 | + ) |
98 | 110 |
|
99 | 111 | def fetch_property_values(
|
100 | 112 | self,
|
@@ -143,10 +155,12 @@ def fetch_property_values(
|
143 | 155 | if constraints:
|
144 | 156 | expression += f"{{{constraints}}}"
|
145 | 157 |
|
146 |
| - return self.fetch(node_dcids=node_dcids, |
147 |
| - expression=expression, |
148 |
| - all_pages=all_pages, |
149 |
| - next_token=next_token) |
| 158 | + return self.fetch( |
| 159 | + node_dcids=node_dcids, |
| 160 | + expression=expression, |
| 161 | + all_pages=all_pages, |
| 162 | + next_token=next_token, |
| 163 | + ) |
150 | 164 |
|
151 | 165 | def fetch_all_classes(
|
152 | 166 | self,
|
@@ -174,8 +188,94 @@ def fetch_all_classes(
|
174 | 188 | ```
|
175 | 189 | """
|
176 | 190 |
|
177 |
| - return self.fetch_property_values(node_dcids="Class", |
178 |
| - properties="typeOf", |
179 |
| - out=False, |
180 |
| - all_pages=all_pages, |
181 |
| - next_token=next_token) |
| 191 | + return self.fetch_property_values( |
| 192 | + node_dcids="Class", |
| 193 | + properties="typeOf", |
| 194 | + out=False, |
| 195 | + all_pages=all_pages, |
| 196 | + next_token=next_token, |
| 197 | + ) |
| 198 | + |
def fetch_entity_parents(
        self, entity_dcids: str | list[str]) -> dict[str, list[Parent]]:
    """Look up the immediate parents of one or more entities.

    The lookup queries the 'containedInPlace' property through
    `fetch_property_values`, then passes the response's `get_properties()`
    output to `build_parents_dictionary`, which produces the final mapping.

    Args:
        entity_dcids (str | list[str]): One DCID, or a list of DCIDs, whose
            parents should be looked up.

    Returns:
        dict[str, list[Parent]]: Whatever `build_parents_dictionary` builds
            from the response; per the annotation, a mapping from each
            queried DCID to a list of `Parent` objects.
    """
    # Ask the API for the 'containedInPlace' values of every requested entity.
    response = self.fetch_property_values(
        node_dcids=entity_dcids,
        properties="containedInPlace",
    )

    # Reshape the raw per-node properties into the DCID -> parents mapping.
    return build_parents_dictionary(data=response.get_properties())
| 218 | + |
| 219 | + |
def _fetch_parents_cached(self, dcid: str) -> tuple[Parent, ...]:
    """Return the parents of a single entity via `fetch_parents_lru`.

    This is a thin delegate: the endpoint instance (`self`) and the DCID are
    forwarded unchanged to `fetch_parents_lru`, and its result is returned
    as-is. Any caching is done by that helper (presumably an LRU cache keyed
    on both arguments -- confirm in `datacommons_client.utils.graph`).

    NOTE(review): this function takes `self` and is referenced elsewhere as
    `self._fetch_parents_cached`, so it is expected to be a method of
    `NodeEndpoint`. Confirm it is indented inside the class body.

    Args:
        dcid (str): The DCID of the entity whose parents should be fetched.

    Returns:
        tuple[Parent, ...]: The value returned by `fetch_parents_lru`; per
            the annotation, a tuple of the entity's immediate parents.
    """
    parents = fetch_parents_lru(self, dcid)
    return parents
| 234 | + |
| 235 | + |
def fetch_entity_ancestry(
        self,
        entity_dcids: str | list[str],
        as_tree: bool = False) -> dict[str, list[dict[str, str]] | dict]:
    """Build the ancestry of one or more entities, flat or nested.

    One `build_ancestry_map` job is submitted per input DCID to a thread
    pool (at most `ANCESTRY_MAX_WORKERS` workers), with
    `self._fetch_parents_cached` as the parent-lookup callback. Each job
    returns a `(dcid, ancestry)` pair, which is then post-processed with
    `build_ancestry_tree` (when `as_tree` is True) or `flatten_ancestry`
    (otherwise).

    NOTE(review): this function takes `self` and calls
    `self._fetch_parents_cached`, so both are expected to be methods of
    `NodeEndpoint`. Confirm they are indented inside the class body.

    Args:
        entity_dcids (str | list[str]): One DCID, or a list of DCIDs, whose
            ancestry should be fetched. A bare string is treated as a
            one-element list.
        as_tree (bool): If True, each entry is the output of
            `build_ancestry_tree`; otherwise the output of
            `flatten_ancestry`. Defaults to False.

    Returns:
        dict[str, list[dict[str, str]] | dict]: A dictionary keyed by the
            DCID returned from each `build_ancestry_map` job, holding the
            post-processed ancestry for that entity.
    """
    # Accept a bare DCID by normalising it to a one-element list.
    dcids = [entity_dcids] if isinstance(entity_dcids, str) else entity_dcids

    ancestry_by_dcid = {}

    # One ancestry-map job per entity; the jobs run concurrently in the pool.
    with ThreadPoolExecutor(max_workers=ANCESTRY_MAX_WORKERS) as pool:
        pending = [
            pool.submit(build_ancestry_map,
                        root=root_dcid,
                        fetch_fn=self._fetch_parents_cached)
            for root_dcid in dcids
        ]

        # Collect in submission order; `result()` blocks until the job is
        # done and re-raises any exception raised inside the worker.
        for task in pending:
            root, ancestry_map = task.result()
            ancestry_by_dcid[root] = (build_ancestry_tree(root, ancestry_map)
                                      if as_tree else
                                      flatten_ancestry(ancestry_map))

    return ancestry_by_dcid
0 commit comments