diff --git a/datacommons_client/endpoints/node.py b/datacommons_client/endpoints/node.py
index 56e2428..c5dce96 100644
--- a/datacommons_client/endpoints/node.py
+++ b/datacommons_client/endpoints/node.py
@@ -1,3 +1,4 @@
+from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
 
 from datacommons_client.endpoints.base import API
@@ -5,6 +6,13 @@ from datacommons_client.endpoints.payloads import NodeRequestPayload
 from datacommons_client.endpoints.payloads import normalize_properties_to_string
 from datacommons_client.endpoints.response import NodeResponse
+from datacommons_client.models.node import Node
+from datacommons_client.utils.graph import build_ancestry_map
+from datacommons_client.utils.graph import build_ancestry_tree
+from datacommons_client.utils.graph import fetch_parents_lru
+from datacommons_client.utils.graph import flatten_ancestry
+
+ANCESTRY_MAX_WORKERS = 10
 
 
 class NodeEndpoint(Endpoint):
@@ -91,10 +99,12 @@ def fetch_property_labels(
     expression = "->" if out else "<-"
 
     # Make the request and return the response.
-    return self.fetch(node_dcids=node_dcids,
-                      expression=expression,
-                      all_pages=all_pages,
-                      next_token=next_token)
+    return self.fetch(
+        node_dcids=node_dcids,
+        expression=expression,
+        all_pages=all_pages,
+        next_token=next_token,
+    )
 
   def fetch_property_values(
       self,
@@ -143,10 +153,12 @@ def fetch_property_values(
     if constraints:
       expression += f"{{{constraints}}}"
 
-    return self.fetch(node_dcids=node_dcids,
-                      expression=expression,
-                      all_pages=all_pages,
-                      next_token=next_token)
+    return self.fetch(
+        node_dcids=node_dcids,
+        expression=expression,
+        all_pages=all_pages,
+        next_token=next_token,
+    )
 
   def fetch_all_classes(
       self,
@@ -174,8 +186,107 @@ def fetch_all_classes(
       ```
     """
-    return self.fetch_property_values(node_dcids="Class",
-                                      properties="typeOf",
-                                      out=False,
-                                      all_pages=all_pages,
-                                      next_token=next_token)
+    return self.fetch_property_values(
+        node_dcids="Class",
+        properties="typeOf",
+        out=False,
+        all_pages=all_pages,
+        next_token=next_token,
+    )
+
+  def fetch_entity_parents(
+      self,
+      entity_dcids: str | list[str],
+      *,
+      as_dict: bool = True) -> dict[str, list[Node | dict]]:
+    """Fetches the direct parents of one or more entities using the 'containedInPlace' property.
+
+    Args:
+      entity_dcids (str | list[str]): A single DCID or a list of DCIDs to query.
+      as_dict (bool): If True, each parent is returned as a plain dictionary.
+        If False, parents are returned as Node objects (which are dataclasses).
+
+    Returns:
+      dict[str, list[Node | dict]]: A dictionary mapping each input DCID to a list of its
+        immediate parent entities. Each parent is represented as a Node object (which
+        contains the DCID, name, and type of the parent entity) or as a dictionary with
+        the same data.
+    """
+    # Fetch property values from the API
+    data = self.fetch_property_values(
+        node_dcids=entity_dcids,
+        properties="containedInPlace",
+    ).get_properties()
+
+    if as_dict:
+      return {k: v.to_dict() for k, v in data.items()}
+
+    return data
+
+  def _fetch_parents_cached(self, dcid: str) -> tuple[Node, ...]:
+    """Returns cached parent nodes for a given entity using an LRU cache.
+
+    This private wrapper exists because `@lru_cache` cannot be applied directly
+    to instance methods. By passing the `NodeEndpoint` instance (`self`) as an
+    argument, caching is preserved while keeping the implementation modular and
+    testable.
+
+    Args:
+      dcid (str): The DCID of the entity whose parents should be fetched.
+
+    Returns:
+      tuple[Node, ...]: A tuple of Node objects representing the entity's
+        immediate parents.
+    """
+    return fetch_parents_lru(self, dcid)
+
+  def fetch_entity_ancestry(
+      self,
+      entity_dcids: str | list[str],
+      as_tree: bool = False,
+      *,
+      max_concurrent_requests: Optional[int] = ANCESTRY_MAX_WORKERS
+  ) -> dict[str, list[dict[str, str]] | dict]:
+    """Fetches the full ancestry (flat or nested) for one or more entities.
+
+    For each input DCID, this method builds the complete ancestry graph using a
+    breadth-first traversal and parallel fetching. It returns either a flat list
+    of unique parents or a nested tree structure for each entity, depending on
+    the `as_tree` flag. The flat list matches the structure of the
+    `/api/place/parent` endpoint of the DC website.
+
+    Args:
+      entity_dcids (str | list[str]): One or more DCIDs of the entities whose
+        ancestry will be fetched.
+      as_tree (bool): If True, returns a nested tree structure; otherwise,
+        returns a flat list. Defaults to False.
+      max_concurrent_requests (Optional[int]): The maximum number of concurrent
+        requests to make. Defaults to ANCESTRY_MAX_WORKERS.
+
+    Returns:
+      dict[str, list[dict[str, str]] | dict]: A dictionary mapping each input DCID to either:
+        - A flat list of parent dictionaries (if `as_tree` is False), using the
+          Node fields 'dcid', 'name', and 'types', or
+        - A nested ancestry tree (if `as_tree` is True), in which each node is a
+          dict with 'dcid', 'name', 'type', and 'parents'.
+    """
+    if isinstance(entity_dcids, str):
+      entity_dcids = [entity_dcids]
+
+    result = {}
+
+    # Use a thread pool to fetch ancestry graphs in parallel for each input entity
+    with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor:
+      futures = [
+          executor.submit(build_ancestry_map,
+                          root=dcid,
+                          fetch_fn=self._fetch_parents_cached)
+          for dcid in entity_dcids
+      ]
+
+      # Gather ancestry maps and postprocess into flat or nested form
+      for future in futures:
+        dcid, ancestry = future.result()
+        if as_tree:
+          ancestry = build_ancestry_tree(dcid, ancestry)
+        else:
+          ancestry = flatten_ancestry(ancestry)
+        result[dcid] = ancestry
+
+    return result
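For orientation, here is a minimal usage sketch of the new ancestry API. The `API()` constructor arguments and the DCID are illustrative placeholders, not part of this change:

```python
from datacommons_client.endpoints.base import API
from datacommons_client.endpoints.node import NodeEndpoint

node = NodeEndpoint(api=API())  # API configuration elided; depends on your setup

# Direct parents only, as Node dataclasses rather than dicts.
parents = node.fetch_entity_parents("geoId/06", as_dict=False)

# Flat ancestry: each input DCID maps to a deduplicated list of parent records.
flat = node.fetch_entity_ancestry("geoId/06")

# Nested ancestry: each input DCID maps to a {"dcid", "name", "type", "parents"} tree.
tree = node.fetch_entity_ancestry("geoId/06", as_tree=True)
```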
diff --git a/datacommons_client/endpoints/response.py b/datacommons_client/endpoints/response.py
index bec7ebb..8b76609 100644
--- a/datacommons_client/endpoints/response.py
+++ b/datacommons_client/endpoints/response.py
@@ -1,7 +1,5 @@
-from dataclasses import asdict
 from dataclasses import dataclass
 from dataclasses import field
-import json
 from typing import Any, Dict, List
 
 from datacommons_client.models.node import Arcs
@@ -15,42 +13,7 @@ from datacommons_client.models.resolve import Entity
 from datacommons_client.utils.data_processing import flatten_properties
 from datacommons_client.utils.data_processing import observations_as_records
-
-
-class SerializableMixin:
-  """Provides serialization methods for the Response dataclasses."""
-
-  def to_dict(self, exclude_none: bool = True) -> Dict[str, Any]:
-    """Converts the instance to a dictionary.
-
-    Args:
-      exclude_none: If True, only include non-empty values in the response.
-
-    Returns:
-      Dict[str, Any]: The dictionary representation of the instance.
-    """
-
-    def _remove_none(data: Any) -> Any:
-      """Recursively removes None or empty values from a dictionary or list."""
-      if isinstance(data, dict):
-        return {k: _remove_none(v) for k, v in data.items() if v is not None}
-      elif isinstance(data, list):
-        return [_remove_none(item) for item in data]
-      return data
-
-    result = asdict(self)
-    return _remove_none(result) if exclude_none else result
-
-  def to_json(self, exclude_none: bool = True) -> str:
-    """Converts the instance to a JSON string.
-
-    Args:
-      exclude_none: If True, only include non-empty values in the response.
-
-    Returns:
-      str: The JSON string representation of the instance.
-    """
-    return json.dumps(self.to_dict(exclude_none=exclude_none), indent=2)
+from datacommons_client.utils.data_processing import SerializableMixin
 
 
 @dataclass
diff --git a/datacommons_client/models/node.py b/datacommons_client/models/node.py
index 8d85f45..861cb84 100644
--- a/datacommons_client/models/node.py
+++ b/datacommons_client/models/node.py
@@ -2,6 +2,8 @@
 from dataclasses import field
 from typing import Any, Dict, List, Optional, TypeAlias
 
+from datacommons_client.utils.data_processing import SerializableMixin
+
 NextToken: TypeAlias = Optional[str]
 NodeDCID: TypeAlias = str
 ArcLabel: TypeAlias = str
@@ -10,7 +12,7 @@
 
 
 @dataclass
-class Node:
+class Node(SerializableMixin):
   """Represents an individual node in the Data Commons knowledge graph.
 
   Attributes:
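For reference, a quick sketch of what the relocated mixin now provides on `Node`. The field names follow the tests in this PR, and `to_dict` drops `None`-valued fields by default:

```python
from datacommons_client.models.node import Node

node = Node(dcid="geoId/06", name="California", types=["State"])

node.to_dict()   # e.g. {'dcid': 'geoId/06', 'name': 'California', 'types': ['State']}
node.to_json()   # the same data as an indented JSON string
```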
- """ - - def _remove_none(data: Any) -> Any: - """Recursively removes None or empty values from a dictionary or list.""" - if isinstance(data, dict): - return {k: _remove_none(v) for k, v in data.items() if v is not None} - elif isinstance(data, list): - return [_remove_none(item) for item in data] - return data - - result = asdict(self) - return _remove_none(result) if exclude_none else result - - def to_json(self, exclude_none: bool = True) -> str: - """Converts the instance to a JSON string. - - Args: - exclude_none: If True, only include non-empty values in the response. - - Returns: - str: The JSON string representation of the instance. - """ - return json.dumps(self.to_dict(exclude_none=exclude_none), indent=2) +from datacommons_client.utils.data_processing import SerializableMixin @dataclass diff --git a/datacommons_client/models/node.py b/datacommons_client/models/node.py index 8d85f45..861cb84 100644 --- a/datacommons_client/models/node.py +++ b/datacommons_client/models/node.py @@ -2,6 +2,8 @@ from dataclasses import field from typing import Any, Dict, List, Optional, TypeAlias +from datacommons_client.utils.data_processing import SerializableMixin + NextToken: TypeAlias = Optional[str] NodeDCID: TypeAlias = str ArcLabel: TypeAlias = str @@ -10,7 +12,7 @@ @dataclass -class Node: +class Node(SerializableMixin): """Represents an individual node in the Data Commons knowledge graph. Attributes: diff --git a/datacommons_client/tests/endpoints/test_node_endpoint.py b/datacommons_client/tests/endpoints/test_node_endpoint.py index 8720d57..59468c9 100644 --- a/datacommons_client/tests/endpoints/test_node_endpoint.py +++ b/datacommons_client/tests/endpoints/test_node_endpoint.py @@ -1,8 +1,10 @@ from unittest.mock import MagicMock +from unittest.mock import patch from datacommons_client.endpoints.base import API from datacommons_client.endpoints.node import NodeEndpoint from datacommons_client.endpoints.response import NodeResponse +from datacommons_client.models.node import Node def test_node_endpoint_initialization(): @@ -30,13 +32,15 @@ def test_node_endpoint_fetch(): endpoint = NodeEndpoint(api=api_mock) response = endpoint.fetch(node_dcids="test_node", expression="name") - api_mock.post.assert_called_once_with(payload={ - "nodes": ["test_node"], - "property": "name" - }, - endpoint="node", - all_pages=True, - next_token=None) + api_mock.post.assert_called_once_with( + payload={ + "nodes": ["test_node"], + "property": "name" + }, + endpoint="node", + all_pages=True, + next_token=None, + ) assert isinstance(response, NodeResponse) assert "test_node" in response.data @@ -80,13 +84,15 @@ def test_node_endpoint_fetch_property_values_out(): out=True) expected_expression = "->name{typeOf:City}" - api_mock.post.assert_called_once_with(payload={ - "nodes": ["node1"], - "property": expected_expression - }, - endpoint="node", - all_pages=True, - next_token=None) + api_mock.post.assert_called_once_with( + payload={ + "nodes": ["node1"], + "property": expected_expression + }, + endpoint="node", + all_pages=True, + next_token=None, + ) assert isinstance(response, NodeResponse) assert "node1" in response.data @@ -112,13 +118,15 @@ def test_node_endpoint_fetch_property_values_in(): out=False) expected_expression = "<-name{typeOf:City}" - api_mock.post.assert_called_once_with(payload={ - "nodes": ["node1"], - "property": expected_expression - }, - endpoint="node", - all_pages=True, - next_token=None) + api_mock.post.assert_called_once_with( + payload={ + "nodes": ["node1"], + "property": 
@@ -162,23 +172,100 @@ def test_node_endpoint_fetch_property_values_string_vs_list():
                                             properties="name",
                                             constraints=None,
                                             out=True)
-  api_mock.post.assert_called_with(payload={
-      "nodes": ["node1"],
-      "property": "->name"
-  },
-                                   endpoint="node",
-                                   all_pages=True,
-                                   next_token=None)
+  api_mock.post.assert_called_with(
+      payload={
+          "nodes": ["node1"],
+          "property": "->name"
+      },
+      endpoint="node",
+      all_pages=True,
+      next_token=None,
+  )
 
   # List input
   response = endpoint.fetch_property_values(node_dcids="node1",
                                             properties=["name", "typeOf"],
                                             constraints=None,
                                             out=True)
-  api_mock.post.assert_called_with(payload={
-      "nodes": ["node1"],
-      "property": "->[name, typeOf]"
-  },
-                                   endpoint="node",
-                                   all_pages=True,
-                                   next_token=None)
+  api_mock.post.assert_called_with(
+      payload={
+          "nodes": ["node1"],
+          "property": "->[name, typeOf]"
+      },
+      endpoint="node",
+      all_pages=True,
+      next_token=None,
+  )
+
+
+@patch("datacommons_client.endpoints.node.fetch_parents_lru")
+def test_fetch_parents_cached_delegates_to_lru(mock_fetch_lru):
+  mock_fetch_lru.return_value = (Node("B", "B name", "Region"),)
+  endpoint = NodeEndpoint(api=MagicMock())
+
+  result = endpoint._fetch_parents_cached("X")
+
+  assert isinstance(result, tuple)
+  assert result[0].dcid == "B"
+  mock_fetch_lru.assert_called_once_with(endpoint, "X")
+
+
+@patch("datacommons_client.endpoints.node.flatten_ancestry")
+@patch("datacommons_client.endpoints.node.build_ancestry_map")
+def test_fetch_entity_ancestry_flat(mock_build_map, mock_flatten):
+  """Test fetch_entity_ancestry with flat structure (as_tree=False)."""
+  mock_build_map.return_value = (
+      "X",
+      {
+          "X": [Node("A", "A name", "Country")],
+          "A": [],
+      },
+  )
+  mock_flatten.return_value = [{
+      "dcid": "A",
+      "name": "A name",
+      "type": "Country"
+  }]
+
+  endpoint = NodeEndpoint(api=MagicMock())
+  result = endpoint.fetch_entity_ancestry("X", as_tree=False)
+
+  assert result == {"X": [{"dcid": "A", "name": "A name", "type": "Country"}]}
+  mock_build_map.assert_called_once()
+  mock_flatten.assert_called_once()
+
+
+@patch("datacommons_client.endpoints.node.build_ancestry_tree")
+@patch("datacommons_client.endpoints.node.build_ancestry_map")
+def test_fetch_entity_ancestry_tree(mock_build_map, mock_build_tree):
+  """Test fetch_entity_ancestry with tree structure (as_tree=True)."""
+  mock_build_map.return_value = (
+      "Y",
+      {
+          "Y": [Node("Z", "Z name", "Region")],
+          "Z": [],
+      },
+  )
+  mock_build_tree.return_value = {
+      "dcid": "Y",
+      "name": None,
+      "type": None,
+      "parents": [{
+          "dcid": "Z",
+          "name": "Z name",
+          "type": "Region",
+          "parents": []
+      }],
+  }
+
+  endpoint = NodeEndpoint(api=MagicMock())
+  result = endpoint.fetch_entity_ancestry("Y", as_tree=True)
+
+  assert "Y" in result
+  assert result["Y"]["dcid"] == "Y"
+  assert result["Y"]["parents"][0]["dcid"] == "Z"
+  mock_build_map.assert_called_once()
+  mock_build_tree.assert_called_once_with("Y", mock_build_map.return_value[1])
diff --git a/datacommons_client/tests/utils/test_graph.py b/datacommons_client/tests/utils/test_graph.py
new file mode 100644
index 0000000..d6bfd51
--- /dev/null
+++ b/datacommons_client/tests/utils/test_graph.py
@@ -0,0 +1,238 @@
+from collections import defaultdict
+from unittest.mock import MagicMock
+
+from datacommons_client.models.node import Node
+from datacommons_client.utils.graph import _assemble_tree
+from datacommons_client.utils.graph import _fetch_parents_uncached
+from datacommons_client.utils.graph import _postorder_nodes
+from datacommons_client.utils.graph import build_ancestry_map
+from datacommons_client.utils.graph import build_ancestry_tree
+from datacommons_client.utils.graph import fetch_parents_lru
+from datacommons_client.utils.graph import flatten_ancestry
+
+
+def test_fetch_parents_uncached_returns_data():
+  """Test _fetch_parents_uncached delegates to the endpoint correctly."""
+  endpoint = MagicMock()
+  endpoint.fetch_entity_parents.return_value.get.return_value = [
+      Node(dcid="parent1", name="Parent 1", types=["Country"])
+  ]
+
+  result = _fetch_parents_uncached(endpoint, "test_dcid")
+  assert isinstance(result, list)
+  assert result[0].dcid == "parent1"
+
+  endpoint.fetch_entity_parents.assert_called_once_with("test_dcid",
+                                                        as_dict=False)
+
+
+def test_fetch_parents_lru_caches_results():
+  """Test fetch_parents_lru uses the LRU cache and returns a tuple."""
+  endpoint = MagicMock()
+  endpoint.fetch_entity_parents.return_value.get.return_value = [
+      Node(dcid="parentX", name="Parent X", types=["Region"])
+  ]
+
+  result1 = fetch_parents_lru(endpoint, "nodeA")
+
+  # This should hit the cache
+  result2 = fetch_parents_lru(endpoint, "nodeA")
+  # This should hit the cache again
+  fetch_parents_lru(endpoint, "nodeA")
+
+  assert isinstance(result1, tuple)
+  assert result1[0].dcid == "parentX"
+  assert result1 == result2
+  assert endpoint.fetch_entity_parents.call_count == 1  # Called only once
+
+
+def test_build_ancestry_map_linear_tree():
+  """A -> B -> C"""
+
+  def fetch_mock(dcid):
+    return {
+        "C": (Node("B", "Node B", "Type"),),
+        "B": (Node("A", "Node A", "Type"),),
+        "A": tuple(),
+    }.get(dcid, tuple())
+
+  root, ancestry = build_ancestry_map("C", fetch_mock, max_workers=2)
+
+  assert root == "C"  # Since we start from C
+  assert set(ancestry.keys()) == {"C", "B", "A"}  # All nodes should be present
+  assert ancestry["C"][0].dcid == "B"  # First parent of C is B
+  assert ancestry["B"][0].dcid == "A"  # First parent of B is A
+  assert ancestry["A"] == []  # No parents for A
+
+
+def test_build_ancestry_map_branching_graph():
+  r"""
+  Graph:
+        F
+       / \
+      D   E
+     / \ /
+    B   C
+     \ /
+      A
+  """
+
+  def fetch_mock(dcid):
+    return {
+        "A": (Node("B", "Node B", "Type"), Node("C", "Node C", "Type")),
+        "B": (Node("D", "Node D", "Type"),),
+        "C": (Node("D", "Node D", "Type"), Node("E", "Node E", "Type")),
+        "D": (Node("F", "Node F", "Type"),),
+        "E": (Node("F", "Node F", "Type"),),
+        "F": tuple(),
+    }.get(dcid, tuple())
+
+  root, ancestry = build_ancestry_map("A", fetch_mock, max_workers=4)
+
+  assert root == "A"
+  assert set(ancestry.keys()) == {"A", "B", "C", "D", "E", "F"}
+  assert [p.dcid for p in ancestry["A"]] == ["B", "C"]  # A has two parents
+  assert [p.dcid for p in ancestry["B"]] == ["D"]  # B has one parent
+  assert [p.dcid for p in ancestry["C"]] == ["D", "E"]  # C has two parents
+  assert [p.dcid for p in ancestry["D"]] == ["F"]  # D has one parent
+  assert [p.dcid for p in ancestry["E"]] == ["F"]  # E has one parent
+  assert ancestry["F"] == []  # F has no parents
+
+
+def test_build_ancestry_map_cycle_detection():
+  """
+  Graph with a cycle:
+    A -> B -> C -> A
+  (Should not loop infinitely)
+  """
+
+  call_count = defaultdict(int)
+
+  def fetch_mock(dcid):
+    call_count[dcid] += 1
+    return {
+        "A": (Node("B", "B", "Type"),),
+        "B": (Node("C", "C", "Type"),),
+        "C": (Node("A", "A", "Type"),),  # Cycle back to A
+    }.get(dcid, tuple())
+
+  root, ancestry = build_ancestry_map("A", fetch_mock, max_workers=2)
+
+  assert root == "A"  # Since we start from A
+  assert set(ancestry.keys()) == {"A", "B", "C"}
+  assert [p.dcid for p in ancestry["A"]] == ["B"]  # A points to B
+  assert [p.dcid for p in ancestry["B"]] == ["C"]  # B points to C
+  assert [p.dcid for p in ancestry["C"]] == ["A"]  # C points back to A, which is fine
+
+  # Check that each node was fetched only once (particularly A, to avoid an infinite loop)
+  assert call_count["A"] == 1
+  assert call_count["B"] == 1
+  assert call_count["C"] == 1
+
+
+def test_postorder_nodes_simple_graph():
+  """Test postorder traversal on a simple graph."""
+  ancestry = {
+      "C": [Node("B", "B", "Type")],
+      "B": [Node("A", "A", "Type")],
+      "A": [],
+  }
+
+  order = _postorder_nodes("C", ancestry)
+  assert order == ["A", "B", "C"]
+
+  new_order = _postorder_nodes("B", ancestry)
+  assert new_order == ["A", "B"]
+
+
+def test_assemble_tree_creates_nested_structure():
+  """Test _assemble_tree creates a nested structure."""
+  ancestry = {
+      "C": [Node("B", "Node B", "Type")],
+      "B": [Node("A", "Node A", "Type")],
+      "A": [],
+  }
+  postorder = ["A", "B", "C"]
+  tree = _assemble_tree(postorder, ancestry)
+
+  assert tree["dcid"] == "C"
+  assert tree["parents"][0]["dcid"] == "B"
+  assert tree["parents"][0]["parents"][0]["dcid"] == "A"
+
+
+def test_postorder_nodes_ignores_unreachable_nodes():
+  """
+  Graph:
+    A → B → C
+  Ancestry map also includes D (unconnected)
+  """
+  ancestry = {
+      "A": [Node("B", "B", "Type")],
+      "B": [Node("C", "C", "Type")],
+      "C": [],
+      "D": [Node("X", "X", "Type")],
+  }
+
+  postorder = _postorder_nodes("A", ancestry)
+
+  # Only nodes reachable from A should be included
+  assert postorder == ["C", "B", "A"]
+  assert "D" not in postorder
+
+
+def test_assemble_tree_shared_parent_not_duplicated():
+  """
+  Structure:
+    A → C
+    B → C
+  Both A and B have the same parent C
+  """
+
+  ancestry = {
+      "A": [Node("C", "C name", "City")],
+      "B": [Node("C", "C name", "City")],
+      "C": [],
+  }
+
+  postorder = ["C", "A", "B"]  # C first to allow bottom-up build
+  tree = _assemble_tree(postorder, ancestry)
+
+  assert tree["dcid"] == "B"
+  assert len(tree["parents"]) == 1
+  assert tree["parents"][0]["dcid"] == "C"
+
+  # Confirm C only appears once
+  assert tree["parents"][0] is not None
+  assert tree["parents"][0]["name"] == "C name"
+
+
+def test_build_ancestry_tree_nested_output():
+  """Test build_ancestry_tree creates a nested structure."""
+  ancestry = {
+      "C": [Node("B", "B", "Type")],
+      "B": [Node("A", "A", "Type")],
+      "A": [],
+  }
+
+  tree = build_ancestry_tree("C", ancestry)
+
+  assert tree["dcid"] == "C"
+  assert tree["parents"][0]["dcid"] == "B"
+  assert tree["parents"][0]["parents"][0]["dcid"] == "A"
"B", "name": "B", "types": ["City"]} in flat + assert len(flat) == 2 diff --git a/datacommons_client/utils/data_processing.py b/datacommons_client/utils/data_processing.py index f4d4f92..ef6074c 100644 --- a/datacommons_client/utils/data_processing.py +++ b/datacommons_client/utils/data_processing.py @@ -1,4 +1,5 @@ from dataclasses import asdict +import json from typing import Any, Dict @@ -113,3 +114,39 @@ def group_variables_by_entity( for entity in entities: result.setdefault(entity, []).append(variable) return result + + +class SerializableMixin: + """Provides serialization methods for the Response dataclasses.""" + + def to_dict(self, exclude_none: bool = True) -> Dict[str, Any]: + """Converts the instance to a dictionary. + + Args: + exclude_none: If True, only include non-empty values in the response. + + Returns: + Dict[str, Any]: The dictionary representation of the instance. + """ + + def _remove_none(data: Any) -> Any: + """Recursively removes None or empty values from a dictionary or list.""" + if isinstance(data, dict): + return {k: _remove_none(v) for k, v in data.items() if v is not None} + elif isinstance(data, list): + return [_remove_none(item) for item in data] + return data + + result = asdict(self) + return _remove_none(result) if exclude_none else result + + def to_json(self, exclude_none: bool = True) -> str: + """Converts the instance to a JSON string. + + Args: + exclude_none: If True, only include non-empty values in the response. + + Returns: + str: The JSON string representation of the instance. + """ + return json.dumps(self.to_dict(exclude_none=exclude_none), indent=2) diff --git a/datacommons_client/utils/graph.py b/datacommons_client/utils/graph.py new file mode 100644 index 0000000..b7cd710 --- /dev/null +++ b/datacommons_client/utils/graph.py @@ -0,0 +1,229 @@ +from collections import deque +from concurrent.futures import FIRST_COMPLETED +from concurrent.futures import Future +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import wait +from functools import lru_cache +from typing import Callable, Optional, TypeAlias + +from datacommons_client.models.node import Node + +PARENTS_MAX_WORKERS = 10 + +AncestryMap: TypeAlias = dict[str, list[Node]] + +# -- -- Fetch tools -- -- + + +def _fetch_parents_uncached(endpoint, dcid: str) -> list[Node]: + """Fetches the immediate parents of a given DCID from the endpoint, without caching. + + This function performs a direct, uncached call to the API. It exists + primarily to serve as the internal, cache-free fetch used by `fetch_parents_lru`, which + applies LRU caching on top of this raw access function. + + By isolating the pure fetch logic here, we ensure that caching is handled separately + and cleanly via `@lru_cache` on `fetch_parents_lru`, which requires its wrapped + function to be deterministic and side-effect free. + + Args: + endpoint: A client object with a `fetch_entity_parents` method. + dcid (str): The entity ID for which to fetch parents. + Returns: + A list of parent dictionaries, each containing 'dcid', 'name', and 'type'. + """ + parents = endpoint.fetch_entity_parents(dcid, as_dict=False).get(dcid, []) + + return parents if isinstance(parents, list) else [parents] + + +@lru_cache(maxsize=512) +def fetch_parents_lru(endpoint, dcid: str) -> tuple[Node, ...]: + """Fetches parents of a DCID using an LRU cache for improved performance. + Args: + endpoint: A client object with a `fetch_entity_parents` method. + dcid (str): The entity ID to fetch parents for. 
+
+
+# -- -- Ancestry tools -- --
+
+
+def build_ancestry_map(
+    root: str,
+    fetch_fn: Callable[[str], tuple[Node, ...]],
+    max_workers: Optional[int] = PARENTS_MAX_WORKERS,
+) -> tuple[str, AncestryMap]:
+  """Constructs a complete ancestry map for the root node using parallel
+  Breadth-First Search (BFS).
+
+  Traverses the ancestry graph upward from the root node, discovering all
+  parent relationships by fetching in parallel.
+
+  Args:
+    root (str): The DCID of the root entity to start from.
+    fetch_fn (Callable): A function that takes a DCID and returns a tuple of
+      Node objects.
+    max_workers (Optional[int]): Max number of threads to use for parallel
+      fetching. Optional, defaults to `PARENTS_MAX_WORKERS`.
+
+  Returns:
+    A tuple containing:
+      - The original root DCID.
+      - A dictionary mapping each DCID to a list of its parent `Node`s.
+  """
+  ancestry: AncestryMap = {}
+  visited: set[str] = set()
+  in_progress: dict[str, Future] = {}
+
+  original_root = root
+
+  with ThreadPoolExecutor(max_workers=max_workers) as executor:
+    queue = deque([root])
+
+    # Standard BFS loop, but fetches are executed in parallel threads
+    while queue or in_progress:
+      # Submit fetch tasks for all nodes in the queue
+      while queue:
+        dcid = queue.popleft()
+        # Check if the node has already been visited or is in progress
+        if dcid not in visited and dcid not in in_progress:
+          # Submit the fetch task
+          in_progress[dcid] = executor.submit(fetch_fn, dcid)
+
+      # Check if any futures are still in progress
+      if not in_progress:
+        continue
+
+      # Wait for at least one future to complete
+      done_futures, _ = wait(in_progress.values(), return_when=FIRST_COMPLETED)
+
+      # Find which DCIDs have completed
+      completed_dcids = [
+          dcid for dcid, future in in_progress.items() if future in done_futures
+      ]
+
+      # Process completed fetches and enqueue any unseen parents
+      for dcid in completed_dcids:
+        future = in_progress.pop(dcid)
+        parents = list(future.result())
+        ancestry[dcid] = parents
+        visited.add(dcid)
+
+        for parent in parents:
+          if parent.dcid not in visited and parent.dcid not in in_progress:
+            queue.append(parent.dcid)
+
+  return original_root, ancestry
+
+
+def _postorder_nodes(root: str, ancestry: AncestryMap) -> list[str]:
+  """Generates a postorder list of all nodes reachable from the root.
+
+  Postorder ensures children are processed before their parents. That way the
+  tree is built bottom-up.
+
+  Args:
+    root (str): The root DCID to start traversal from.
+    ancestry (AncestryMap): The ancestry graph.
+
+  Returns:
+    A list of DCIDs in postorder (i.e. children before parents).
+  """
+  # Initialize the stack and postorder list
+  stack, postorder, seen = [root], [], set()
+
+  # Traverse the graph using a stack
+  while stack:
+    node = stack.pop()
+    # Skip if already seen
+    if node in seen:
+      continue
+    seen.add(node)
+    postorder.append(node)
+    # Push all unvisited parents onto the stack (i.e. climb up the graph, child -> parent)
+    for parent in ancestry.get(node, []):
+      parent_dcid = parent.dcid
+      if parent_dcid not in seen:
+        stack.append(parent_dcid)
+
+  # Reverse the list so that parents come after their children (i.e. postorder)
+  return list(reversed(postorder))
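A toy sketch of these two helpers working together. The fixture mirrors the linear-tree test above, and `Node`'s positional argument order is assumed from those tests:

```python
from datacommons_client.models.node import Node
from datacommons_client.utils.graph import _postorder_nodes, build_ancestry_map

def fetch_fn(dcid: str) -> tuple[Node, ...]:
  # Toy graph: C's parent is B, B's parent is A.
  return {
      "C": (Node("B", "Node B", "Type"),),
      "B": (Node("A", "Node A", "Type"),),
  }.get(dcid, tuple())

root, ancestry = build_ancestry_map("C", fetch_fn, max_workers=2)
# root == "C"; ancestry maps every reachable DCID to its parent Nodes:
# {"C": [<Node B>], "B": [<Node A>], "A": []}

_postorder_nodes("C", ancestry)  # ["A", "B", "C"]: furthest ancestor first, root last
```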
+
+
+def _assemble_tree(postorder: list[str], ancestry: AncestryMap) -> dict:
+  """Builds a nested dictionary tree from a postorder node list and ancestry map.
+
+  Constructs a nested representation of the ancestry graph, ensuring that
+  parents are embedded after their children (which is enabled by postorder).
+
+  Args:
+    postorder (list[str]): List of node DCIDs in postorder.
+    ancestry (AncestryMap): Map from DCID to list of parent Node objects.
+
+  Returns:
+    A nested dictionary representing the ancestry tree rooted at the last
+    postorder node.
+  """
+  tree_cache: dict[str, dict] = {}
+
+  for node in postorder:
+    # Initialize the node dictionary.
+    node_dict = {"dcid": node, "name": None, "type": None, "parents": []}
+
+    # For each parent of the current node, fetch its details and add it to the node_dict.
+    for parent in ancestry.get(node, []):
+      parent_dcid = parent.dcid
+      name = parent.name
+      entity_type = parent.types
+
+      # If the parent node is not already in the cache, add it.
+      if parent_dcid not in tree_cache:
+        tree_cache[parent_dcid] = {
+            "dcid": parent_dcid,
+            "name": name,
+            "type": entity_type,
+            "parents": [],
+        }
+
+      parent_node = tree_cache[parent_dcid]
+
+      # Ensure name/type are up to date (in case of duplicates)
+      parent_node["name"] = name
+      parent_node["type"] = entity_type
+      node_dict["parents"].append(parent_node)
+
+    tree_cache[node] = node_dict
+
+  # The root node is the last one in postorder; that's what gets returned
+  return tree_cache[postorder[-1]]
+
+
+def build_ancestry_tree(root: str, ancestry: AncestryMap) -> dict:
+  """Builds a nested ancestry tree from an ancestry map.
+
+  Args:
+    root (str): The DCID of the root node.
+    ancestry (AncestryMap): A flat ancestry map built by `build_ancestry_map`.
+
+  Returns:
+    A nested dictionary tree rooted at the specified DCID.
+  """
+  postorder = _postorder_nodes(root, ancestry)
+  return _assemble_tree(postorder, ancestry)
+
+
+def flatten_ancestry(ancestry: AncestryMap) -> list[dict[str, str]]:
+  """Flattens the ancestry map into a deduplicated list of parent records.
+
+  Args:
+    ancestry (AncestryMap): Ancestry mapping of DCIDs to lists of parent Node
+      objects.
+
+  Returns:
+    A list of dictionaries with keys 'dcid', 'name', and 'types', containing
+    each unique parent in the graph.
+  """
+  flat: list = []
+  seen: set[str] = set()
+
+  for parents in ancestry.values():
+    for parent in parents:
+      if parent.dcid in seen:
+        continue
+      seen.add(parent.dcid)
+      flat.append(parent.to_dict())
+  return flat
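To summarize the two output shapes, a small sketch (values are illustrative; the flat form uses `Node.to_dict()`, which drops `None` fields by default, while the tree form uses the 'type' key populated from `Node.types`):

```python
from datacommons_client.models.node import Node
from datacommons_client.utils.graph import build_ancestry_tree, flatten_ancestry

ancestry = {
    "C": [Node("B", "Node B", ["Type"])],
    "B": [Node("A", "Node A", ["Type"])],
    "A": [],
}

build_ancestry_tree("C", ancestry)
# {"dcid": "C", "name": None, "type": None, "parents": [
#     {"dcid": "B", "name": "Node B", "type": ["Type"], "parents": [
#         {"dcid": "A", "name": "Node A", "type": ["Type"], "parents": []}]}]}

flatten_ancestry(ancestry)
# [{"dcid": "B", "name": "Node B", "types": ["Type"]},
#  {"dcid": "A", "name": "Node A", "types": ["Type"]}]
```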