Skip to content

Add parents/ancestry methods #231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 124 additions & 13 deletions datacommons_client/endpoints/node.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

from datacommons_client.endpoints.base import API
from datacommons_client.endpoints.base import Endpoint
from datacommons_client.endpoints.payloads import NodeRequestPayload
from datacommons_client.endpoints.payloads import normalize_properties_to_string
from datacommons_client.endpoints.response import NodeResponse
from datacommons_client.models.node import Node
from datacommons_client.utils.graph import build_ancestry_map
from datacommons_client.utils.graph import build_ancestry_tree
from datacommons_client.utils.graph import fetch_parents_lru
from datacommons_client.utils.graph import flatten_ancestry

ANCESTRY_MAX_WORKERS = 10


class NodeEndpoint(Endpoint):
Expand Down Expand Up @@ -91,10 +99,12 @@ def fetch_property_labels(
expression = "->" if out else "<-"

# Make the request and return the response.
return self.fetch(node_dcids=node_dcids,
expression=expression,
all_pages=all_pages,
next_token=next_token)
return self.fetch(
node_dcids=node_dcids,
expression=expression,
all_pages=all_pages,
next_token=next_token,
)

def fetch_property_values(
self,
Expand Down Expand Up @@ -143,10 +153,12 @@ def fetch_property_values(
if constraints:
expression += f"{{{constraints}}}"

return self.fetch(node_dcids=node_dcids,
expression=expression,
all_pages=all_pages,
next_token=next_token)
return self.fetch(
node_dcids=node_dcids,
expression=expression,
all_pages=all_pages,
next_token=next_token,
)

def fetch_all_classes(
self,
Expand Down Expand Up @@ -174,8 +186,107 @@ def fetch_all_classes(
```
"""

return self.fetch_property_values(node_dcids="Class",
properties="typeOf",
out=False,
all_pages=all_pages,
next_token=next_token)
return self.fetch_property_values(
node_dcids="Class",
properties="typeOf",
out=False,
all_pages=all_pages,
next_token=next_token,
)

def fetch_entity_parents(
    self,
    entity_dcids: str | list[str],
    *,
    as_dict: bool = True) -> dict[str, list[Node | dict]]:
  """Looks up the immediate parents of entities via 'containedInPlace'.

  Args:
    entity_dcids (str | list[str]): One DCID, or a list of DCIDs, to look up.
    as_dict (bool): When True (default), each parent is serialized to a plain
      dictionary. When False, the Node dataclass instances are returned as-is.

  Returns:
    dict[str, list[Node | dict]]: A mapping from each input DCID to the list
      of its direct parents. Each parent carries the DCID, name, and type of
      the parent entity, either as a Node object or as a dictionary with the
      same fields.
  """
  # Query the node API for the incoming 'containedInPlace' arcs.
  response = self.fetch_property_values(
      node_dcids=entity_dcids,
      properties="containedInPlace",
  )
  parents = response.get_properties()

  if not as_dict:
    return parents

  # Serialize each group of parent nodes into plain dictionaries.
  return {dcid: nodes.to_dict() for dcid, nodes in parents.items()}

def _fetch_parents_cached(self, dcid: str) -> tuple[Node, ...]:
  """Returns cached parent nodes for a given entity using an LRU cache.

  This private wrapper exists because `@lru_cache` cannot be applied directly
  to instance methods (it would key on — and keep alive — `self`). Instead,
  the module-level `fetch_parents_lru` helper carries the cache, and this
  method simply forwards `self` to it, keeping the implementation modular
  and testable.

  Args:
    dcid (str): The DCID of the entity whose parents should be fetched.

  Returns:
    tuple[Node, ...]: A tuple of Node objects representing the entity's
      immediate parents. A tuple (not a list) so the cached value is
      immutable and safe to share between callers.
  """
  return fetch_parents_lru(self, dcid)

def fetch_entity_ancestry(
    self,
    entity_dcids: str | list[str],
    as_tree: bool = False,
    *,
    max_concurrent_requests: Optional[int] = ANCESTRY_MAX_WORKERS
) -> dict[str, list[dict[str, str]] | dict]:
  """Fetches the full ancestry (flat or nested) for one or more entities.

  For every input DCID the complete ancestry graph is built with a
  breadth-first traversal, fetching parents in parallel. The flat form
  matches the structure of the `/api/place/parent` endpoint of the DC
  website.

  Args:
    entity_dcids (str | list[str]): One or more DCIDs of the entities whose
      ancestry will be fetched.
    as_tree (bool): If True, returns a nested tree structure; otherwise a
      flat list of unique parents. Defaults to False.
    max_concurrent_requests (Optional[int]): Maximum number of concurrent
      requests. Defaults to ANCESTRY_MAX_WORKERS.

  Returns:
    dict[str, list[dict[str, str]] | dict]: A mapping from each input DCID to
      either a flat list of parent dictionaries (`as_tree=False`) or a nested
      ancestry tree (`as_tree=True`). Each parent is a dict with 'dcid',
      'name', and 'type'.
  """
  # Normalize a single DCID to a one-element list.
  if isinstance(entity_dcids, str):
    entity_dcids = [entity_dcids]

  ancestry_by_dcid: dict = {}

  # Build each entity's ancestry graph in parallel.
  with ThreadPoolExecutor(max_workers=max_concurrent_requests) as pool:
    tasks = [
        pool.submit(build_ancestry_map,
                    root=dcid,
                    fetch_fn=self._fetch_parents_cached)
        for dcid in entity_dcids
    ]

    # Collect results in submission order and convert each ancestry map
    # into the requested shape (nested tree or flat list).
    for task in tasks:
      dcid, ancestry_map = task.result()
      ancestry_by_dcid[dcid] = (build_ancestry_tree(dcid, ancestry_map)
                                if as_tree else flatten_ancestry(ancestry_map))

  return ancestry_by_dcid
39 changes: 1 addition & 38 deletions datacommons_client/endpoints/response.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from dataclasses import asdict
from dataclasses import dataclass
from dataclasses import field
import json
from typing import Any, Dict, List

from datacommons_client.models.node import Arcs
Expand All @@ -15,42 +13,7 @@
from datacommons_client.models.resolve import Entity
from datacommons_client.utils.data_processing import flatten_properties
from datacommons_client.utils.data_processing import observations_as_records


class SerializableMixin:
  """Provides serialization methods for the Response dataclasses."""

  def to_dict(self, exclude_none: bool = True) -> Dict[str, Any]:
    """Converts the instance to a dictionary.

    Args:
      exclude_none: If True, dictionary entries whose value is None are
        dropped (recursively). Note that only None is pruned — empty
        strings/lists/dicts are kept, and None elements inside lists are
        preserved.

    Returns:
      Dict[str, Any]: The dictionary representation of the instance.
    """

    def _prune(value: Any) -> Any:
      """Recursively drops None-valued entries from dictionaries."""
      if isinstance(value, dict):
        pruned = {}
        for key, item in value.items():
          # Only dict values are filtered; list elements pass through.
          if item is not None:
            pruned[key] = _prune(item)
        return pruned
      if isinstance(value, list):
        return [_prune(item) for item in value]
      return value

    raw = asdict(self)
    if not exclude_none:
      return raw
    return _prune(raw)

  def to_json(self, exclude_none: bool = True) -> str:
    """Converts the instance to a JSON string.

    Args:
      exclude_none: If True, None-valued dictionary entries are omitted
        (see `to_dict`).

    Returns:
      str: The JSON string representation of the instance (2-space indent).
    """
    return json.dumps(self.to_dict(exclude_none=exclude_none), indent=2)
from datacommons_client.utils.data_processing import SerializableMixin


@dataclass
Expand Down
4 changes: 3 additions & 1 deletion datacommons_client/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from dataclasses import field
from typing import Any, Dict, List, Optional, TypeAlias

from datacommons_client.utils.data_processing import SerializableMixin

NextToken: TypeAlias = Optional[str]
NodeDCID: TypeAlias = str
ArcLabel: TypeAlias = str
Expand All @@ -10,7 +12,7 @@


@dataclass
class Node:
class Node(SerializableMixin):
"""Represents an individual node in the Data Commons knowledge graph.

Attributes:
Expand Down
Loading