|
| 1 | +from concurrent.futures import ThreadPoolExecutor |
1 | 2 | from typing import Optional
|
2 | 3 |
|
3 | 4 | from datacommons_client.endpoints.base import API
|
4 | 5 | from datacommons_client.endpoints.base import Endpoint
|
5 | 6 | from datacommons_client.endpoints.payloads import NodeRequestPayload
|
6 | 7 | from datacommons_client.endpoints.payloads import normalize_properties_to_string
|
7 | 8 | from datacommons_client.endpoints.response import NodeResponse
|
| 9 | +from datacommons_client.models.node import Node |
| 10 | +from datacommons_client.utils.graph import build_ancestry_map |
| 11 | +from datacommons_client.utils.graph import build_ancestry_tree |
| 12 | +from datacommons_client.utils.graph import fetch_parents_lru |
| 13 | +from datacommons_client.utils.graph import flatten_ancestry |
| 14 | + |
| 15 | +ANCESTRY_MAX_WORKERS = 10 |
8 | 16 |
|
9 | 17 |
|
10 | 18 | class NodeEndpoint(Endpoint):
|
@@ -91,10 +99,12 @@ def fetch_property_labels(
|
91 | 99 | expression = "->" if out else "<-"
|
92 | 100 |
|
93 | 101 | # Make the request and return the response.
|
94 |
| - return self.fetch(node_dcids=node_dcids, |
95 |
| - expression=expression, |
96 |
| - all_pages=all_pages, |
97 |
| - next_token=next_token) |
| 102 | + return self.fetch( |
| 103 | + node_dcids=node_dcids, |
| 104 | + expression=expression, |
| 105 | + all_pages=all_pages, |
| 106 | + next_token=next_token, |
| 107 | + ) |
98 | 108 |
|
99 | 109 | def fetch_property_values(
|
100 | 110 | self,
|
@@ -143,10 +153,12 @@ def fetch_property_values(
|
143 | 153 | if constraints:
|
144 | 154 | expression += f"{{{constraints}}}"
|
145 | 155 |
|
146 |
| - return self.fetch(node_dcids=node_dcids, |
147 |
| - expression=expression, |
148 |
| - all_pages=all_pages, |
149 |
| - next_token=next_token) |
| 156 | + return self.fetch( |
| 157 | + node_dcids=node_dcids, |
| 158 | + expression=expression, |
| 159 | + all_pages=all_pages, |
| 160 | + next_token=next_token, |
| 161 | + ) |
150 | 162 |
|
151 | 163 | def fetch_all_classes(
|
152 | 164 | self,
|
@@ -174,8 +186,107 @@ def fetch_all_classes(
|
174 | 186 | ```
|
175 | 187 | """
|
176 | 188 |
|
177 |
| - return self.fetch_property_values(node_dcids="Class", |
178 |
| - properties="typeOf", |
179 |
| - out=False, |
180 |
| - all_pages=all_pages, |
181 |
| - next_token=next_token) |
| 189 | + return self.fetch_property_values( |
| 190 | + node_dcids="Class", |
| 191 | + properties="typeOf", |
| 192 | + out=False, |
| 193 | + all_pages=all_pages, |
| 194 | + next_token=next_token, |
| 195 | + ) |
| 196 | + |
| 197 | + def fetch_entity_parents( |
| 198 | + self, |
| 199 | + entity_dcids: str | list[str], |
| 200 | + *, |
| 201 | + as_dict: bool = True) -> dict[str, list[Node | dict]]: |
| 202 | + """Fetches the direct parents of one or more entities using the 'containedInPlace' property. |
| 203 | +
|
| 204 | + Args: |
| 205 | + entity_dcids (str | list[str]): A single DCID or a list of DCIDs to query. |
| 206 | + as_dict (bool): If True, returns a dictionary mapping each input DCID to its |
| 207 | + immediate parent entities. If False, returns a dictionary of Parent objects (which |
| 208 | + are dataclasses). |
| 209 | +
|
| 210 | + Returns: |
| 211 | + dict[str, list[Parent | dict]]: A dictionary mapping each input DCID to a list of its |
| 212 | + immediate parent entities. Each parent is represented as a Parent object (which |
| 213 | + contains the DCID, name, and type of the parent entity) or as a dictionary with |
| 214 | + the same data. |
| 215 | + """ |
| 216 | + # Fetch property values from the API |
| 217 | + data = self.fetch_property_values( |
| 218 | + node_dcids=entity_dcids, |
| 219 | + properties="containedInPlace", |
| 220 | + ).get_properties() |
| 221 | + |
| 222 | + if as_dict: |
| 223 | + return {k: v.to_dict() for k, v in data.items()} |
| 224 | + |
| 225 | + return data |
| 226 | + |
| 227 | + def _fetch_parents_cached(self, dcid: str) -> tuple[Node, ...]: |
| 228 | + """Returns cached parent nodes for a given entity using an LRU cache. |
| 229 | +
|
| 230 | + This private wrapper exists because `@lru_cache` cannot be applied directly |
| 231 | + to instance methods. By passing the `NodeEndpoint` instance (`self`) as an |
| 232 | + argument caching is preserved while keeping the implementation modular and testable. |
| 233 | +
|
| 234 | + Args: |
| 235 | + dcid (str): The DCID of the entity whose parents should be fetched. |
| 236 | +
|
| 237 | + Returns: |
| 238 | + tuple[Parent, ...]: A tuple of Parent objects representing the entity's immediate parents. |
| 239 | + """ |
| 240 | + return fetch_parents_lru(self, dcid) |
| 241 | + |
| 242 | + def fetch_entity_ancestry( |
| 243 | + self, |
| 244 | + entity_dcids: str | list[str], |
| 245 | + as_tree: bool = False, |
| 246 | + *, |
| 247 | + max_concurrent_requests: Optional[int] = ANCESTRY_MAX_WORKERS |
| 248 | + ) -> dict[str, list[dict[str, str]] | dict]: |
| 249 | + """Fetches the full ancestry (flat or nested) for one or more entities. |
| 250 | + For each input DCID, this method builds the complete ancestry graph using a |
| 251 | + breadth-first traversal and parallel fetching. |
| 252 | + It returns either a flat list of unique parents or a nested tree structure for |
| 253 | + each entity, depending on the `as_tree` flag. The flat list matches the structure |
| 254 | + of the `/api/place/parent` endpoint of the DC website. |
| 255 | + Args: |
| 256 | + entity_dcids (str | list[str]): One or more DCIDs of the entities whose ancestry |
| 257 | + will be fetched. |
| 258 | + as_tree (bool): If True, returns a nested tree structure; otherwise, returns a flat list. |
| 259 | + Defaults to False. |
| 260 | + max_concurrent_requests (Optional[int]): The maximum number of concurrent requests to make. |
| 261 | + Defaults to ANCESTRY_MAX_WORKERS. |
| 262 | + Returns: |
| 263 | + dict[str, list[dict[str, str]] | dict]: A dictionary mapping each input DCID to either: |
| 264 | + - A flat list of parent dictionaries (if `as_tree` is False), or |
| 265 | + - A nested ancestry tree (if `as_tree` is True). Each parent is represented by |
| 266 | + a dict with 'dcid', 'name', and 'type'. |
| 267 | + """ |
| 268 | + |
| 269 | + if isinstance(entity_dcids, str): |
| 270 | + entity_dcids = [entity_dcids] |
| 271 | + |
| 272 | + result = {} |
| 273 | + |
| 274 | + # Use a thread pool to fetch ancestry graphs in parallel for each input entity |
| 275 | + with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor: |
| 276 | + futures = [ |
| 277 | + executor.submit(build_ancestry_map, |
| 278 | + root=dcid, |
| 279 | + fetch_fn=self._fetch_parents_cached) |
| 280 | + for dcid in entity_dcids |
| 281 | + ] |
| 282 | + |
| 283 | + # Gather ancestry maps and postprocess into flat or nested form |
| 284 | + for future in futures: |
| 285 | + dcid, ancestry = future.result() |
| 286 | + if as_tree: |
| 287 | + ancestry = build_ancestry_tree(dcid, ancestry) |
| 288 | + else: |
| 289 | + ancestry = flatten_ancestry(ancestry) |
| 290 | + result[dcid] = ancestry |
| 291 | + |
| 292 | + return result |
0 commit comments