From dc10c52c6e74a97f2806117898d06baa2ef30c3e Mon Sep 17 00:00:00 2001 From: Jorge Rivera Date: Mon, 31 Mar 2025 11:33:29 +0200 Subject: [PATCH 1/8] move serializable mixin --- datacommons_client/endpoints/response.py | 39 +-------------------- datacommons_client/utils/data_processing.py | 37 +++++++++++++++++++ 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/datacommons_client/endpoints/response.py b/datacommons_client/endpoints/response.py index bec7ebb..8b76609 100644 --- a/datacommons_client/endpoints/response.py +++ b/datacommons_client/endpoints/response.py @@ -1,7 +1,5 @@ -from dataclasses import asdict from dataclasses import dataclass from dataclasses import field -import json from typing import Any, Dict, List from datacommons_client.models.node import Arcs @@ -15,42 +13,7 @@ from datacommons_client.models.resolve import Entity from datacommons_client.utils.data_processing import flatten_properties from datacommons_client.utils.data_processing import observations_as_records - - -class SerializableMixin: - """Provides serialization methods for the Response dataclasses.""" - - def to_dict(self, exclude_none: bool = True) -> Dict[str, Any]: - """Converts the instance to a dictionary. - - Args: - exclude_none: If True, only include non-empty values in the response. - - Returns: - Dict[str, Any]: The dictionary representation of the instance. - """ - - def _remove_none(data: Any) -> Any: - """Recursively removes None or empty values from a dictionary or list.""" - if isinstance(data, dict): - return {k: _remove_none(v) for k, v in data.items() if v is not None} - elif isinstance(data, list): - return [_remove_none(item) for item in data] - return data - - result = asdict(self) - return _remove_none(result) if exclude_none else result - - def to_json(self, exclude_none: bool = True) -> str: - """Converts the instance to a JSON string. - - Args: - exclude_none: If True, only include non-empty values in the response. - - Returns: - str: The JSON string representation of the instance. - """ - return json.dumps(self.to_dict(exclude_none=exclude_none), indent=2) +from datacommons_client.utils.data_processing import SerializableMixin @dataclass diff --git a/datacommons_client/utils/data_processing.py b/datacommons_client/utils/data_processing.py index f4d4f92..ef6074c 100644 --- a/datacommons_client/utils/data_processing.py +++ b/datacommons_client/utils/data_processing.py @@ -1,4 +1,5 @@ from dataclasses import asdict +import json from typing import Any, Dict @@ -113,3 +114,39 @@ def group_variables_by_entity( for entity in entities: result.setdefault(entity, []).append(variable) return result + + +class SerializableMixin: + """Provides serialization methods for the Response dataclasses.""" + + def to_dict(self, exclude_none: bool = True) -> Dict[str, Any]: + """Converts the instance to a dictionary. + + Args: + exclude_none: If True, only include non-empty values in the response. + + Returns: + Dict[str, Any]: The dictionary representation of the instance. + """ + + def _remove_none(data: Any) -> Any: + """Recursively removes None or empty values from a dictionary or list.""" + if isinstance(data, dict): + return {k: _remove_none(v) for k, v in data.items() if v is not None} + elif isinstance(data, list): + return [_remove_none(item) for item in data] + return data + + result = asdict(self) + return _remove_none(result) if exclude_none else result + + def to_json(self, exclude_none: bool = True) -> str: + """Converts the instance to a JSON string. + + Args: + exclude_none: If True, only include non-empty values in the response. + + Returns: + str: The JSON string representation of the instance. + """ + return json.dumps(self.to_dict(exclude_none=exclude_none), indent=2) From ae9e5748de7a76742800c7533e2db77d5eaf6225 Mon Sep 17 00:00:00 2001 From: Jorge Rivera Date: Mon, 31 Mar 2025 11:33:47 +0200 Subject: [PATCH 2/8] Add graph utils --- datacommons_client/utils/graph.py | 271 ++++++++++++++++++++++++++++++ 1 file changed, 271 insertions(+) create mode 100644 datacommons_client/utils/graph.py diff --git a/datacommons_client/utils/graph.py b/datacommons_client/utils/graph.py new file mode 100644 index 0000000..3d6b012 --- /dev/null +++ b/datacommons_client/utils/graph.py @@ -0,0 +1,271 @@ +from collections import deque +from concurrent.futures import FIRST_COMPLETED +from concurrent.futures import Future +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import wait +from dataclasses import dataclass +from functools import lru_cache +from typing import Callable, Optional + +from datacommons_client.utils.data_processing import SerializableMixin + +PARENTS_MAX_WORKERS = 10 + + +@dataclass(frozen=True) +class Parent(SerializableMixin): + """A class representing a parent node in a graph. + Attributes: + dcid (str): The ID of the parent node. + name (str): The name of the parent node. + type (str | list[str]): The type(s) of the parent node. + """ + + dcid: str + name: str + type: str | list[str] + + +AncestryMap = dict[str, list[Parent]] + +# -- -- Fetch tools -- -- + + +def _fetch_parents_uncached(endpoint, dcid: str) -> list[Parent]: + """Fetches the immediate parents of a given DCID from the endpoint, without caching. + + This function performs a direct, uncached call to the API. It exists + primarily to serve as the internal, cache-free fetch used by `fetch_parents_lru`, which + applies LRU caching on top of this raw access function. + + By isolating the pure fetch logic here, we ensure that caching is handled separately + and cleanly via `@lru_cache` on `fetch_parents_lru`, which requires its wrapped + function to be deterministic and side-effect free. + + Args: + endpoint: A client object with a `fetch_entity_parents` method. + dcid (str): The entity ID for which to fetch parents. + Returns: + A list of parent dictionaries, each containing 'dcid', 'name', and 'type'. + """ + return endpoint.fetch_entity_parents(dcid).get(dcid, []) + + +@lru_cache(maxsize=512) +def fetch_parents_lru(endpoint, dcid: str) -> tuple[Parent, ...]: + """Fetches parents of a DCID using an LRU cache for improved performance. + Args: + endpoint: A client object with a `fetch_entity_parents` method. + dcid (str): The entity ID to fetch parents for. + Returns: + A tuple of `Parent` objects corresponding to the entity’s parents. + """ + parents = _fetch_parents_uncached(endpoint, dcid) + return tuple(p for p in parents) + + +# -- -- Ancestry tools -- -- +def build_parents_dictionary(data: dict) -> dict[str, list[Parent]]: + """Transforms a dictionary of entities and their parents into a structured + dictionary mapping each entity to its list of Parents. + + Args: + data (dict): The properties dictionary of a Node.fetch_property_values call. + + Returns: + dict[str, list[Parent]]: A dictionary where each key is an entity DCID + and the value is a list of Parent objects representing its parents. + + """ + + result: dict[str, list[Parent]] = {} + + for entity, properties in data.items(): + if not isinstance(properties, list): + properties = [properties] + + for parent in properties: + parent_type = parent.types[0] if len(parent.types) == 1 else parent.types + result.setdefault(entity, []).append( + Parent(dcid=parent.dcid, name=parent.name, type=parent_type)) + return result + + +def build_ancestry_map( + root: str, + fetch_fn: Callable[[str], tuple[Parent, ...]], + max_workers: Optional[int] = PARENTS_MAX_WORKERS, +) -> tuple[str, AncestryMap]: + """Constructs a complete ancestry map for the root node using parallel + Breadth-First Search (BFS). + + Traverses the ancestry graph upward from the root node, discovering all parent + relationships by fetching in parallel. + + Args: + root (str): The DCID of the root entity to start from. + fetch_fn (Callable): A function that takes a DCID and returns a Parent tuple. + max_workers (Optional[int]): Max number of threads to use for parallel fetching. + Optional, defaults to `PARENTS_MAX_WORKERS`. + + Returns: + A tuple containing: + - The original root DCID. + - A dictionary mapping each DCID to a list of its `Parent`s. + """ + ancestry: AncestryMap = {} + visited: set[str] = set() + in_progress: dict[str, Future] = {} + + original_root = root + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + queue = deque([root]) + + # Standard BFS loop, but fetches are executed in parallel threads + while queue or in_progress: + # Submit fetch tasks for all nodes in the queue + while queue: + dcid = queue.popleft() + # Check if the node has already been visited or is in progress + if dcid not in visited and dcid not in in_progress: + # Submit the fetch task + in_progress[dcid] = executor.submit(fetch_fn, dcid) + + # Check if any futures are still in progress + if not in_progress: + continue + + # Wait for at least one future to complete + done_futures, _ = wait(in_progress.values(), return_when=FIRST_COMPLETED) + + # Find which DCIDs have completed + completed_dcids = [ + dcid for dcid, future in in_progress.items() if future in done_futures + ] + + # Process completed fetches and enqueue any unseen parents + for dcid in completed_dcids: + future = in_progress.pop(dcid) + parents = list(future.result()) + ancestry[dcid] = parents + visited.add(dcid) + + for parent in parents: + if parent.dcid not in visited and parent.dcid not in in_progress: + queue.append(parent.dcid) + + return original_root, ancestry + + +def _postorder_nodes(root: str, ancestry: AncestryMap) -> list[str]: + """Generates a postorder list of all nodes reachable from the root. + + Postorder ensures children are processed before their parents. That way the tree + is built bottom-up. + + Args: + root (str): The root DCID to start traversal from. + ancestry (AncestryMap): The ancestry graph. + Returns: + A list of DCIDs in postorder (i.e children before parents). + """ + # Initialize stack and postorder list + stack, postorder, seen = [root], [], set() + + # Traverse the graph using a stack + while stack: + node = stack.pop() + # Skip if already seen + if node in seen: + continue + seen.add(node) + postorder.append(node) + # Push all unvisited parents onto the stack (i.e climb up the graph, child -> parent) + for parent in ancestry.get(node, []): + parent_dcid = parent.dcid + if parent_dcid not in seen: + stack.append(parent_dcid) + + # Reverse the list so that parents come after their children (i.e postorder) + return list(reversed(postorder)) + + +def _assemble_tree(postorder: list[str], ancestry: AncestryMap) -> dict: + """Builds a nested dictionary tree from a postorder node list and ancestry map. + Constructs a nested representation of the ancestry graph, ensuring that parents + are embedded after their children (which is enabled by postorder). + Args: + postorder (list[str]): List of node DCIDs in postorder. + ancestry (AncestryMap): Map from DCID to list of Parent objects. + Returns: + A nested dictionary representing the ancestry tree rooted at the last postorder node. + """ + tree_cache: dict[str, dict] = {} + + for node in postorder: + # Initialize the node dictionary. + node_dict = {"dcid": node, "name": None, "type": None, "parents": []} + + # For each parent of the current node, fetch its details and add it to the node_dict. + for parent in ancestry.get(node, []): + parent_dcid = parent.dcid + name = parent.name + entity_type = parent.type + + # If the parent node is not already in the cache, add it. + if parent_dcid not in tree_cache: + tree_cache[parent_dcid] = { + "dcid": parent_dcid, + "name": name, + "type": entity_type, + "parents": [], + } + + parent_node = tree_cache[parent_dcid] + + # Ensure name/type are up to date (in case of duplicates) + parent_node["name"] = name + parent_node["type"] = entity_type + node_dict["parents"].append(parent_node) + + tree_cache[node] = node_dict + + # The root node is the last one in postorder, that's what gets returned + return tree_cache[postorder[-1]] + + +def build_ancestry_tree(root: str, ancestry: AncestryMap) -> dict: + """Builds a nested ancestry tree from an ancestry map. + Args: + root (str): The DCID of the root node. + ancestry (AncestryMap): A flat ancestry map built from `_build_ancestry_map`. + Returns: + A nested dictionary tree rooted at the specified DCID. + """ + postorder = _postorder_nodes(root, ancestry) + return _assemble_tree(postorder, ancestry) + + +def flatten_ancestry(ancestry: AncestryMap) -> list[dict[str, str]]: + """Flattens the ancestry map into a deduplicated list of parent records. + Args: + ancestry (AncestryMap): Ancestry mapping of DCIDs to lists of Parent objects. + Returns: + A list of dictionaries with keys 'dcid', 'name', and 'type', containing + each unique parent in the graph. + """ + + flat: list = [] + seen: set[str] = set() + for parents in ancestry.values(): + for parent in parents: + if parent.dcid in seen: + continue + seen.add(parent.dcid) + flat.append({ + "dcid": parent.dcid, + "name": parent.name, + "type": parent.type + }) + return flat From 7ed3fff614fff2830d76310483fe8ceab847d59e Mon Sep 17 00:00:00 2001 From: Jorge Rivera Date: Mon, 31 Mar 2025 11:34:04 +0200 Subject: [PATCH 3/8] Update node.py Add ancestry tools --- datacommons_client/endpoints/node.py | 126 ++++++++++++++++++++++++--- 1 file changed, 113 insertions(+), 13 deletions(-) diff --git a/datacommons_client/endpoints/node.py b/datacommons_client/endpoints/node.py index 56e2428..e10dc97 100644 --- a/datacommons_client/endpoints/node.py +++ b/datacommons_client/endpoints/node.py @@ -1,3 +1,4 @@ +from concurrent.futures import ThreadPoolExecutor from typing import Optional from datacommons_client.endpoints.base import API @@ -5,6 +6,15 @@ from datacommons_client.endpoints.payloads import NodeRequestPayload from datacommons_client.endpoints.payloads import normalize_properties_to_string from datacommons_client.endpoints.response import NodeResponse +from datacommons_client.models.node import Node +from datacommons_client.utils.graph import build_ancestry_map +from datacommons_client.utils.graph import build_ancestry_tree +from datacommons_client.utils.graph import build_parents_dictionary +from datacommons_client.utils.graph import fetch_parents_lru +from datacommons_client.utils.graph import flatten_ancestry +from datacommons_client.utils.graph import Parent + +ANCESTRY_MAX_WORKERS = 20 class NodeEndpoint(Endpoint): @@ -91,10 +101,12 @@ def fetch_property_labels( expression = "->" if out else "<-" # Make the request and return the response. - return self.fetch(node_dcids=node_dcids, - expression=expression, - all_pages=all_pages, - next_token=next_token) + return self.fetch( + node_dcids=node_dcids, + expression=expression, + all_pages=all_pages, + next_token=next_token, + ) def fetch_property_values( self, @@ -143,10 +155,12 @@ def fetch_property_values( if constraints: expression += f"{{{constraints}}}" - return self.fetch(node_dcids=node_dcids, - expression=expression, - all_pages=all_pages, - next_token=next_token) + return self.fetch( + node_dcids=node_dcids, + expression=expression, + all_pages=all_pages, + next_token=next_token, + ) def fetch_all_classes( self, @@ -174,8 +188,94 @@ def fetch_all_classes( ``` """ - return self.fetch_property_values(node_dcids="Class", - properties="typeOf", - out=False, - all_pages=all_pages, - next_token=next_token) + return self.fetch_property_values( + node_dcids="Class", + properties="typeOf", + out=False, + all_pages=all_pages, + next_token=next_token, + ) + + def fetch_entity_parents( + self, entity_dcids: str | list[str]) -> dict[str, list[Parent]]: + """Fetches the direct parents of one or more entities using the 'containedInPlace' property. + + Args: + entity_dcids (str | list[str]): A single DCID or a list of DCIDs to query. + + Returns: + dict[str, list[Parent]]: A dictionary mapping each input DCID to a list of its + immediate parent entities. Each parent is represented as a Parent object, which + contains the DCID, name, and type of the parent entity. + """ + # Fetch property values from the API + data = self.fetch_property_values( + node_dcids=entity_dcids, + properties="containedInPlace", + ).get_properties() + + return build_parents_dictionary(data=data) + + +def _fetch_parents_cached(self, dcid: str) -> tuple[Parent, ...]: + """Returns cached parent nodes for a given entity using an LRU cache. + + This private wrapper exists because `@lru_cache` cannot be applied directly + to instance methods. By passing the `NodeEndpoint` instance (`self`) as an + argument caching is preserved while keeping the implementation modular and testable. + + Args: + dcid (str): The DCID of the entity whose parents should be fetched. + + Returns: + tuple[Parent, ...]: A tuple of Parent objects representing the entity's immediate parents. + """ + return fetch_parents_lru(self, dcid) + + +def fetch_entity_ancestry( + self, + entity_dcids: str | list[str], + as_tree: bool = False) -> dict[str, list[dict[str, str]] | dict]: + """Fetches the full ancestry (flat or nested) for one or more entities. + For each input DCID, this method builds the complete ancestry graph using a + breadth-first traversal and parallel fetching. + It returns either a flat list of unique parents or a nested tree structure for + each entity, depending on the `as_tree` flag. The flat list matches the structure + of the `/api/place/parent` endpoint of the DC website. + Args: + entity_dcids (str | list[str]): One or more DCIDs of the entities whose ancestry + will be fetched. + as_tree (bool): If True, returns a nested tree structure; otherwise, returns a flat list. + Defaults to False. + Returns: + dict[str, list[dict[str, str]] | dict]: A dictionary mapping each input DCID to either: + - A flat list of parent dictionaries (if `as_tree` is False), or + - A nested ancestry tree (if `as_tree` is True). Each parent is represented by + a dict with 'dcid', 'name', and 'type'. + """ + + if isinstance(entity_dcids, str): + entity_dcids = [entity_dcids] + + result = {} + + # Use a thread pool to fetch ancestry graphs in parallel for each input entity + with ThreadPoolExecutor(max_workers=ANCESTRY_MAX_WORKERS) as executor: + futures = [ + executor.submit(build_ancestry_map, + root=dcid, + fetch_fn=self._fetch_parents_cached) + for dcid in entity_dcids + ] + + # Gather ancestry maps and postprocess into flat or nested form + for future in futures: + dcid, ancestry = future.result() + if as_tree: + ancestry = build_ancestry_tree(dcid, ancestry) + else: + ancestry = flatten_ancestry(ancestry) + result[dcid] = ancestry + + return result From 1cf83d4e80fbe4ab0c3704f1b0d73e36f846228d Mon Sep 17 00:00:00 2001 From: Jorge Rivera Date: Mon, 31 Mar 2025 18:51:23 +0200 Subject: [PATCH 4/8] refactor types --- datacommons_client/endpoints/node.py | 139 ++++++++++++++------------- datacommons_client/models/graph.py | 22 +++++ datacommons_client/utils/graph.py | 23 +---- 3 files changed, 99 insertions(+), 85 deletions(-) create mode 100644 datacommons_client/models/graph.py diff --git a/datacommons_client/endpoints/node.py b/datacommons_client/endpoints/node.py index e10dc97..0220a9f 100644 --- a/datacommons_client/endpoints/node.py +++ b/datacommons_client/endpoints/node.py @@ -6,13 +6,12 @@ from datacommons_client.endpoints.payloads import NodeRequestPayload from datacommons_client.endpoints.payloads import normalize_properties_to_string from datacommons_client.endpoints.response import NodeResponse -from datacommons_client.models.node import Node +from datacommons_client.models.graph import Parent from datacommons_client.utils.graph import build_ancestry_map from datacommons_client.utils.graph import build_ancestry_tree from datacommons_client.utils.graph import build_parents_dictionary from datacommons_client.utils.graph import fetch_parents_lru from datacommons_client.utils.graph import flatten_ancestry -from datacommons_client.utils.graph import Parent ANCESTRY_MAX_WORKERS = 20 @@ -197,16 +196,23 @@ def fetch_all_classes( ) def fetch_entity_parents( - self, entity_dcids: str | list[str]) -> dict[str, list[Parent]]: + self, + entity_dcids: str | list[str], + *, + as_dict: bool = True) -> dict[str, list[Parent | dict]]: """Fetches the direct parents of one or more entities using the 'containedInPlace' property. Args: entity_dcids (str | list[str]): A single DCID or a list of DCIDs to query. + as_dict (bool): If True, returns a dictionary mapping each input DCID to its + immediate parent entities. If False, returns a dictionary of Parent objects (which + are dataclasses). Returns: - dict[str, list[Parent]]: A dictionary mapping each input DCID to a list of its - immediate parent entities. Each parent is represented as a Parent object, which - contains the DCID, name, and type of the parent entity. + dict[str, list[Parent | dict]]: A dictionary mapping each input DCID to a list of its + immediate parent entities. Each parent is represented as a Parent object (which + contains the DCID, name, and type of the parent entity) or as a dictionary with + the same data. """ # Fetch property values from the API data = self.fetch_property_values( @@ -214,68 +220,71 @@ def fetch_entity_parents( properties="containedInPlace", ).get_properties() - return build_parents_dictionary(data=data) + result = build_parents_dictionary(data=data) + if as_dict: + return {k: [p.to_dict() for p in v] for k, v in result.items()} -def _fetch_parents_cached(self, dcid: str) -> tuple[Parent, ...]: - """Returns cached parent nodes for a given entity using an LRU cache. + return result - This private wrapper exists because `@lru_cache` cannot be applied directly - to instance methods. By passing the `NodeEndpoint` instance (`self`) as an - argument caching is preserved while keeping the implementation modular and testable. + def _fetch_parents_cached(self, dcid: str) -> tuple[Parent, ...]: + """Returns cached parent nodes for a given entity using an LRU cache. - Args: - dcid (str): The DCID of the entity whose parents should be fetched. + This private wrapper exists because `@lru_cache` cannot be applied directly + to instance methods. By passing the `NodeEndpoint` instance (`self`) as an + argument caching is preserved while keeping the implementation modular and testable. - Returns: - tuple[Parent, ...]: A tuple of Parent objects representing the entity's immediate parents. - """ - return fetch_parents_lru(self, dcid) - - -def fetch_entity_ancestry( - self, - entity_dcids: str | list[str], - as_tree: bool = False) -> dict[str, list[dict[str, str]] | dict]: - """Fetches the full ancestry (flat or nested) for one or more entities. - For each input DCID, this method builds the complete ancestry graph using a - breadth-first traversal and parallel fetching. - It returns either a flat list of unique parents or a nested tree structure for - each entity, depending on the `as_tree` flag. The flat list matches the structure - of the `/api/place/parent` endpoint of the DC website. - Args: - entity_dcids (str | list[str]): One or more DCIDs of the entities whose ancestry - will be fetched. - as_tree (bool): If True, returns a nested tree structure; otherwise, returns a flat list. - Defaults to False. - Returns: - dict[str, list[dict[str, str]] | dict]: A dictionary mapping each input DCID to either: - - A flat list of parent dictionaries (if `as_tree` is False), or - - A nested ancestry tree (if `as_tree` is True). Each parent is represented by - a dict with 'dcid', 'name', and 'type'. - """ + Args: + dcid (str): The DCID of the entity whose parents should be fetched. + + Returns: + tuple[Parent, ...]: A tuple of Parent objects representing the entity's immediate parents. + """ + return fetch_parents_lru(self, dcid) + + def fetch_entity_ancestry( + self, + entity_dcids: str | list[str], + as_tree: bool = False) -> dict[str, list[dict[str, str]] | dict]: + """Fetches the full ancestry (flat or nested) for one or more entities. + For each input DCID, this method builds the complete ancestry graph using a + breadth-first traversal and parallel fetching. + It returns either a flat list of unique parents or a nested tree structure for + each entity, depending on the `as_tree` flag. The flat list matches the structure + of the `/api/place/parent` endpoint of the DC website. + Args: + entity_dcids (str | list[str]): One or more DCIDs of the entities whose ancestry + will be fetched. + as_tree (bool): If True, returns a nested tree structure; otherwise, returns a flat list. + Defaults to False. + Returns: + dict[str, list[dict[str, str]] | dict]: A dictionary mapping each input DCID to either: + - A flat list of parent dictionaries (if `as_tree` is False), or + - A nested ancestry tree (if `as_tree` is True). Each parent is represented by + a dict with 'dcid', 'name', and 'type'. + """ - if isinstance(entity_dcids, str): - entity_dcids = [entity_dcids] - - result = {} - - # Use a thread pool to fetch ancestry graphs in parallel for each input entity - with ThreadPoolExecutor(max_workers=ANCESTRY_MAX_WORKERS) as executor: - futures = [ - executor.submit(build_ancestry_map, - root=dcid, - fetch_fn=self._fetch_parents_cached) - for dcid in entity_dcids - ] - - # Gather ancestry maps and postprocess into flat or nested form - for future in futures: - dcid, ancestry = future.result() - if as_tree: - ancestry = build_ancestry_tree(dcid, ancestry) - else: - ancestry = flatten_ancestry(ancestry) - result[dcid] = ancestry - - return result + if isinstance(entity_dcids, str): + entity_dcids = [entity_dcids] + + result = {} + + # Use a thread pool to fetch ancestry graphs in parallel for each input entity + with ThreadPoolExecutor(max_workers=ANCESTRY_MAX_WORKERS) as executor: + futures = [ + executor.submit(build_ancestry_map, + root=dcid, + fetch_fn=self._fetch_parents_cached) + for dcid in entity_dcids + ] + + # Gather ancestry maps and postprocess into flat or nested form + for future in futures: + dcid, ancestry = future.result() + if as_tree: + ancestry = build_ancestry_tree(dcid, ancestry) + else: + ancestry = flatten_ancestry(ancestry) + result[dcid] = ancestry + + return result diff --git a/datacommons_client/models/graph.py b/datacommons_client/models/graph.py new file mode 100644 index 0000000..e7a0b4c --- /dev/null +++ b/datacommons_client/models/graph.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import TypeAlias + +from datacommons_client.utils.data_processing import SerializableMixin + + +@dataclass(frozen=True) +class Parent(SerializableMixin): + """A class representing a parent node in a graph. + Attributes: + dcid (str): The ID of the parent node. + name (str): The name of the parent node. + type (str | list[str]): The type(s) of the parent node. + """ + + dcid: str + name: str + type: str | list[str] + + +# A dictionary mapping DCIDs to lists of Parent objects. +AncestryMap: TypeAlias = dict[str, list[Parent]] diff --git a/datacommons_client/utils/graph.py b/datacommons_client/utils/graph.py index 3d6b012..b617885 100644 --- a/datacommons_client/utils/graph.py +++ b/datacommons_client/utils/graph.py @@ -3,31 +3,14 @@ from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor from concurrent.futures import wait -from dataclasses import dataclass from functools import lru_cache from typing import Callable, Optional -from datacommons_client.utils.data_processing import SerializableMixin +from datacommons_client.models.graph import AncestryMap +from datacommons_client.models.graph import Parent PARENTS_MAX_WORKERS = 10 - -@dataclass(frozen=True) -class Parent(SerializableMixin): - """A class representing a parent node in a graph. - Attributes: - dcid (str): The ID of the parent node. - name (str): The name of the parent node. - type (str | list[str]): The type(s) of the parent node. - """ - - dcid: str - name: str - type: str | list[str] - - -AncestryMap = dict[str, list[Parent]] - # -- -- Fetch tools -- -- @@ -48,7 +31,7 @@ def _fetch_parents_uncached(endpoint, dcid: str) -> list[Parent]: Returns: A list of parent dictionaries, each containing 'dcid', 'name', and 'type'. """ - return endpoint.fetch_entity_parents(dcid).get(dcid, []) + return endpoint.fetch_entity_parents(dcid, as_dict=False).get(dcid, []) @lru_cache(maxsize=512) From 8fccc33327eeca96bef62805a3d1a63f3181fc8b Mon Sep 17 00:00:00 2001 From: Jorge Rivera Date: Mon, 31 Mar 2025 19:45:34 +0200 Subject: [PATCH 5/8] Add tests --- .../tests/endpoints/test_node_endpoint.py | 221 ++++++++++++--- datacommons_client/tests/utils/test_graph.py | 264 ++++++++++++++++++ 2 files changed, 445 insertions(+), 40 deletions(-) create mode 100644 datacommons_client/tests/utils/test_graph.py diff --git a/datacommons_client/tests/endpoints/test_node_endpoint.py b/datacommons_client/tests/endpoints/test_node_endpoint.py index 8720d57..e87c10c 100644 --- a/datacommons_client/tests/endpoints/test_node_endpoint.py +++ b/datacommons_client/tests/endpoints/test_node_endpoint.py @@ -1,8 +1,11 @@ from unittest.mock import MagicMock +from unittest.mock import patch from datacommons_client.endpoints.base import API from datacommons_client.endpoints.node import NodeEndpoint from datacommons_client.endpoints.response import NodeResponse +from datacommons_client.models.graph import Parent +from datacommons_client.models.node import Node def test_node_endpoint_initialization(): @@ -30,13 +33,15 @@ def test_node_endpoint_fetch(): endpoint = NodeEndpoint(api=api_mock) response = endpoint.fetch(node_dcids="test_node", expression="name") - api_mock.post.assert_called_once_with(payload={ - "nodes": ["test_node"], - "property": "name" - }, - endpoint="node", - all_pages=True, - next_token=None) + api_mock.post.assert_called_once_with( + payload={ + "nodes": ["test_node"], + "property": "name" + }, + endpoint="node", + all_pages=True, + next_token=None, + ) assert isinstance(response, NodeResponse) assert "test_node" in response.data @@ -80,13 +85,15 @@ def test_node_endpoint_fetch_property_values_out(): out=True) expected_expression = "->name{typeOf:City}" - api_mock.post.assert_called_once_with(payload={ - "nodes": ["node1"], - "property": expected_expression - }, - endpoint="node", - all_pages=True, - next_token=None) + api_mock.post.assert_called_once_with( + payload={ + "nodes": ["node1"], + "property": expected_expression + }, + endpoint="node", + all_pages=True, + next_token=None, + ) assert isinstance(response, NodeResponse) assert "node1" in response.data @@ -112,13 +119,15 @@ def test_node_endpoint_fetch_property_values_in(): out=False) expected_expression = "<-name{typeOf:City}" - api_mock.post.assert_called_once_with(payload={ - "nodes": ["node1"], - "property": expected_expression - }, - endpoint="node", - all_pages=True, - next_token=None) + api_mock.post.assert_called_once_with( + payload={ + "nodes": ["node1"], + "property": expected_expression + }, + endpoint="node", + all_pages=True, + next_token=None, + ) assert isinstance(response, NodeResponse) assert "node1" in response.data @@ -133,11 +142,13 @@ def test_node_endpoint_fetch_all_classes(): }})) response = endpoint.fetch_all_classes() - endpoint.fetch_property_values.assert_called_once_with(node_dcids="Class", - properties="typeOf", - out=False, - all_pages=True, - next_token=None) + endpoint.fetch_property_values.assert_called_once_with( + node_dcids="Class", + properties="typeOf", + out=False, + all_pages=True, + next_token=None, + ) assert isinstance(response, NodeResponse) assert "Class" in response.data @@ -162,23 +173,153 @@ def test_node_endpoint_fetch_property_values_string_vs_list(): properties="name", constraints=None, out=True) - api_mock.post.assert_called_with(payload={ - "nodes": ["node1"], - "property": "->name" - }, - endpoint="node", - all_pages=True, - next_token=None) + api_mock.post.assert_called_with( + payload={ + "nodes": ["node1"], + "property": "->name" + }, + endpoint="node", + all_pages=True, + next_token=None, + ) # List input response = endpoint.fetch_property_values(node_dcids="node1", properties=["name", "typeOf"], constraints=None, out=True) - api_mock.post.assert_called_with(payload={ - "nodes": ["node1"], - "property": "->[name, typeOf]" - }, - endpoint="node", - all_pages=True, - next_token=None) + api_mock.post.assert_called_with( + payload={ + "nodes": ["node1"], + "property": "->[name, typeOf]" + }, + endpoint="node", + all_pages=True, + next_token=None, + ) + + +@patch("datacommons_client.utils.graph.build_parents_dictionary") +def test_fetch_entity_parents_as_dict(mock_build_parents_dict): + """Test fetch_entity_parents with dictionary output.""" + api_mock = MagicMock() + endpoint = NodeEndpoint(api=api_mock) + + api_mock.post.return_value = { + "data": { + "X": { + "properties": { + "containedInPlace": [] + } + } + } + } + endpoint.fetch_property_values = MagicMock() + endpoint.fetch_property_values.return_value.get_properties.return_value = { + "X": Node("X", "X name", types=["Country"]) + } + + result = endpoint.fetch_entity_parents("X", as_dict=True) + assert result == {"X": [{"dcid": "X", "name": "X name", "type": "Country"}]} + + endpoint.fetch_property_values.assert_called_once_with( + node_dcids="X", properties="containedInPlace") + + +@patch("datacommons_client.utils.graph.build_parents_dictionary") +def test_fetch_entity_parents_as_objects(mock_build_parents_dict): + """Test fetch_entity_parents with raw Parent object output.""" + api_mock = MagicMock() + endpoint = NodeEndpoint(api=api_mock) + + # Simulate what fetch_property_values().get_properties() would return + endpoint.fetch_property_values = MagicMock() + endpoint.fetch_property_values.return_value.get_properties.return_value = { + "X": Node("X", "X name", types=["Country"]) + } + + # Mock output of build_parents_dictionary + parent_obj = Node("X", "X name", types=["Country"]) + mock_build_parents_dict.return_value = {"X": [parent_obj]} + + result = endpoint.fetch_entity_parents("X", as_dict=False) + + assert isinstance(result, dict) + assert "X" in result + assert isinstance(result["X"][0], Parent) + + endpoint.fetch_property_values.assert_called_once_with( + node_dcids="X", properties="containedInPlace") + + +@patch("datacommons_client.endpoints.node.fetch_parents_lru") +def test_fetch_parents_cached_delegates_to_lru(mock_fetch_lru): + mock_fetch_lru.return_value = (Parent("B", "B name", "Region"),) + endpoint = NodeEndpoint(api=MagicMock()) + + result = endpoint._fetch_parents_cached("X") + + assert isinstance(result, tuple) + assert result[0].dcid == "B" + mock_fetch_lru.assert_called_once_with(endpoint, "X") + + +@patch("datacommons_client.endpoints.node.flatten_ancestry") +@patch("datacommons_client.endpoints.node.build_ancestry_map") +def test_fetch_entity_ancestry_flat(mock_build_map, mock_flatten): + """Test fetch_entity_ancestry with flat structure (as_tree=False).""" + mock_build_map.return_value = ( + "X", + { + "X": [Parent("A", "A name", "Country")], + "A": [], + }, + ) + mock_flatten.return_value = [{ + "dcid": "A", + "name": "A name", + "type": "Country" + }] + + endpoint = NodeEndpoint(api=MagicMock()) + result = endpoint.fetch_entity_ancestry("X", as_tree=False) + + assert result == {"X": [{"dcid": "A", "name": "A name", "type": "Country"}]} + mock_build_map.assert_called_once() + mock_flatten.assert_called_once() + + +@patch("datacommons_client.endpoints.node.build_ancestry_tree") +@patch("datacommons_client.endpoints.node.build_ancestry_map") +def test_fetch_entity_ancestry_tree(mock_build_map, mock_build_tree): + """Test fetch_entity_ancestry with tree structure (as_tree=True).""" + mock_build_map.return_value = ( + "Y", + { + "Y": [Parent("Z", "Z name", "Region")], + "Z": [], + }, + ) + mock_build_tree.return_value = { + "dcid": + "Y", + "name": + None, + "type": + None, + "parents": [{ + "dcid": "Z", + "name": "Z name", + "type": "Region", + "parents": [] + }], + } + + endpoint = NodeEndpoint(api=MagicMock()) + result = endpoint.fetch_entity_ancestry("Y", as_tree=True) + + assert "Y" in result + assert result["Y"]["dcid"] == "Y" + assert result["Y"]["parents"][0]["dcid"] == "Z" + mock_build_map.assert_called_once() + mock_build_tree.assert_called_once_with("Y", mock_build_map.return_value[1]) diff --git a/datacommons_client/tests/utils/test_graph.py b/datacommons_client/tests/utils/test_graph.py new file mode 100644 index 0000000..21b5e09 --- /dev/null +++ b/datacommons_client/tests/utils/test_graph.py @@ -0,0 +1,264 @@ +from collections import defaultdict +from unittest.mock import MagicMock + +from datacommons_client.models.node import Node +from datacommons_client.utils.graph import _assemble_tree +from datacommons_client.utils.graph import _fetch_parents_uncached +from datacommons_client.utils.graph import _postorder_nodes +from datacommons_client.utils.graph import build_ancestry_tree +from datacommons_client.utils.graph import build_parents_dictionary +from datacommons_client.utils.graph import fetch_parents_lru +from datacommons_client.utils.graph import flatten_ancestry + + +def test_fetch_parents_uncached_returns_data(): + """Test _fetch_parents_uncached delegates to endpoint correctly.""" + endpoint = MagicMock() + endpoint.fetch_entity_parents.return_value.get.return_value = [ + Node(dcid="parent1", name="Parent 1", types=["Country"]) + ] + + result = _fetch_parents_uncached(endpoint, "test_dcid") + assert isinstance(result, list) + assert result[0].dcid == "parent1" + + endpoint.fetch_entity_parents.assert_called_once_with("test_dcid", + as_dict=False) + + +def test_fetch_parents_lru_caches_results(): + """Test fetch_parents_lru uses LRU cache and returns tuple.""" + endpoint = MagicMock() + endpoint.fetch_entity_parents.return_value.get.return_value = [ + Parent(dcid="parentX", name="Parent X", type="Region") + ] + + result1 = fetch_parents_lru(endpoint, "nodeA") + + # This should hit cache + result2 = fetch_parents_lru(endpoint, "nodeA") + # This should hit cache again + fetch_parents_lru(endpoint, "nodeA") + + assert isinstance(result1, tuple) + assert result1[0].dcid == "parentX" + assert result1 == result2 + assert endpoint.fetch_entity_parents.call_count == 1 # Called only once + + +def test_build_parents_dictionary_single_type(): + """Test build_parents_dictionary with single type.""" + data = {"child1": [Node(dcid="p1", name="Parent 1", types=["Country"])]} + + result = build_parents_dictionary(data) + + assert result["child1"][0].dcid == "p1" + assert result["child1"][0].type == "Country" + + +def test_build_parents_dictionary_multi_type(): + data = { + "child1": [Node(dcid="p1", name="Parent 1", types=["Country", "Region"])], + "child2": [Node(dcid="p2", name="Parent 12", types=["Continent"])], + } + + result = build_parents_dictionary(data) + + assert result["child1"][0].type == ["Country", "Region"] + assert result["child2"][0].type == "Continent" + + +def test_build_ancestry_map_linear_tree(): + """A -> B -> C""" + + def fetch_mock(dcid): + return { + "C": (Parent("B", "Node B", "Type"),), + "B": (Parent("A", "Node A", "Type"),), + "A": tuple(), + }.get(dcid, tuple()) + + root, ancestry = build_ancestry_map("C", fetch_mock, max_workers=2) + + assert root == "C" # Since we start from C + assert set(ancestry.keys()) == {"C", "B", "A"} # All nodes should be present + assert ancestry["C"][0].dcid == "B" # First parent of C is B + assert ancestry["B"][0].dcid == "A" # First parent of B is A + assert ancestry["A"] == [] # No parents for A + + +from datacommons_client.models.graph import Parent +from datacommons_client.utils.graph import build_ancestry_map + + +def test_build_ancestry_map_branching_graph(): + r""" + Graph: + F + / \ + D E + / \ / + B C + \/ + A + """ + + def fetch_mock(dcid): + return { + "A": (Parent("B", "Node B", "Type"), Parent("C", "Node C", "Type")), + "B": (Parent("D", "Node D", "Type"),), + "C": (Parent("D", "Node D", "Type"), Parent("E", "Node E", "Type")), + "D": (Parent("F", "Node F", "Type"),), + "E": (Parent("F", "Node F", "Type"),), + "F": tuple(), + }.get(dcid, tuple()) + + root, ancestry = build_ancestry_map("A", fetch_mock, max_workers=4) + + assert root == "A" + assert set(ancestry.keys()) == {"A", "B", "C", "D", "E", "F"} + assert [p.dcid for p in ancestry["A"]] == ["B", "C"] # A has two parents + assert [p.dcid for p in ancestry["B"]] == ["D"] # B has one parent + assert [p.dcid for p in ancestry["C"]] == ["D", "E"] # C has two parents + assert [p.dcid for p in ancestry["D"]] == ["F"] # D has one parent + assert [p.dcid for p in ancestry["E"]] == ["F"] # E has one parent + assert ancestry["F"] == [] # F has no parents + + +def test_build_ancestry_map_cycle_detection(): + """ + Graph with a cycle: + A -> B -> C -> A + (Should not loop infinitely) + """ + + call_count = defaultdict(int) + + def fetch_mock(dcid): + call_count[dcid] += 1 + return { + "A": (Parent("B", "B", "Type"),), + "B": (Parent("C", "C", "Type"),), + "C": (Parent("A", "A", "Type"),), # Cycle back to A + }.get(dcid, tuple()) + + root, ancestry = build_ancestry_map("A", fetch_mock, max_workers=2) + + assert root == "A" # Since we start from A + assert set(ancestry.keys()) == {"A", "B", "C"} + assert [p.dcid for p in ancestry["A"]] == ["B"] # A points to B + assert [p.dcid for p in ancestry["B"]] == ["C"] # B points to C + assert [p.dcid for p in ancestry["C"]] == ["A" + ] # C points back to A but it's ok + + # Check that each node was fetched only once (particularly for A to avoid infinite loop) + assert call_count["A"] == 1 + assert call_count["B"] == 1 + assert call_count["C"] == 1 + + +def test_postorder_nodes_simple_graph(): + """Test postorder traversal on a simple graph.""" + ancestry = { + "C": [Parent("B", "B", "Type")], + "B": [Parent("A", "A", "Type")], + "A": [], + } + + order = _postorder_nodes("C", ancestry) + assert order == ["A", "B", "C"] + + new_order = _postorder_nodes("B", ancestry) + assert new_order == ["A", "B"] + + +def test_assemble_tree_creates_nested_structure(): + """Test _assemble_tree creates a nested structure.""" + ancestry = { + "C": [Parent("B", "Node B", "Type")], + "B": [Parent("A", "Node A", "Type")], + "A": [], + } + postorder = ["A", "B", "C"] + tree = _assemble_tree(postorder, ancestry) + + assert tree["dcid"] == "C" + assert tree["parents"][0]["dcid"] == "B" + assert tree["parents"][0]["parents"][0]["dcid"] == "A" + + +def test_postorder_nodes_ignores_unreachable_nodes(): + """ + Graph: + A → B → C + Ancestry map also includes D (unconnected) + """ + ancestry = { + "A": [Parent("B", "B", "Type")], + "B": [Parent("C", "C", "Type")], + "C": [], + "D": [Parent("X", "X", "Type")], + } + + postorder = _postorder_nodes("A", ancestry) + + # Only nodes reachable from A should be included + assert postorder == ["C", "B", "A"] + assert "D" not in postorder + + +def test_assemble_tree_shared_parent_not_duplicated(): + """ + Structure: + A → C + B → C + Both A and B have same parent C + """ + + ancestry = { + "A": [Parent("C", "C name", "City")], + "B": [Parent("C", "C name", "City")], + "C": [], + } + + postorder = ["C", "A", "B"] # C first to allow bottom-up build + tree = _assemble_tree(postorder, ancestry) + + assert tree["dcid"] == "B" + assert len(tree["parents"]) == 1 + assert tree["parents"][0]["dcid"] == "C" + + # Confirm C only appears once + assert tree["parents"][0] is not None + assert tree["parents"][0]["name"] == "C name" + + +def test_build_ancestry_tree_nested_output(): + """Test build_ancestry_tree creates a nested structure.""" + ancestry = { + "C": [Parent("B", "B", "Type")], + "B": [Parent("A", "A", "Type")], + "A": [], + } + + tree = build_ancestry_tree("C", ancestry) + + assert tree["dcid"] == "C" + assert tree["parents"][0]["dcid"] == "B" + assert tree["parents"][0]["parents"][0]["dcid"] == "A" + + +def test_flatten_ancestry_deduplicates(): + """Test flatten_ancestry deduplicates parents.""" + + ancestry = { + "X": [Parent("A", "A", "Country")], + "Y": [Parent("A", "A", "Country"), + Parent("B", "B", "City")], + } + + flat = flatten_ancestry(ancestry) + + assert {"dcid": "A", "name": "A", "type": "Country"} in flat + assert {"dcid": "B", "name": "B", "type": "City"} in flat + assert len(flat) == 2 From a9a7ee6638c9c7386fa596b9740b07abc2e90388 Mon Sep 17 00:00:00 2001 From: Jorge Rivera Date: Tue, 1 Apr 2025 16:44:51 +0200 Subject: [PATCH 6/8] Update node.py --- datacommons_client/endpoints/node.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/datacommons_client/endpoints/node.py b/datacommons_client/endpoints/node.py index 0220a9f..042fa4c 100644 --- a/datacommons_client/endpoints/node.py +++ b/datacommons_client/endpoints/node.py @@ -245,7 +245,10 @@ def _fetch_parents_cached(self, dcid: str) -> tuple[Parent, ...]: def fetch_entity_ancestry( self, entity_dcids: str | list[str], - as_tree: bool = False) -> dict[str, list[dict[str, str]] | dict]: + as_tree: bool = False, + *, + max_concurrent_requests: Optional[int] = ANCESTRY_MAX_WORKERS + ) -> dict[str, list[dict[str, str]] | dict]: """Fetches the full ancestry (flat or nested) for one or more entities. For each input DCID, this method builds the complete ancestry graph using a breadth-first traversal and parallel fetching. @@ -257,6 +260,8 @@ def fetch_entity_ancestry( will be fetched. as_tree (bool): If True, returns a nested tree structure; otherwise, returns a flat list. Defaults to False. + max_concurrent_requests (Optional[int]): The maximum number of concurrent requests to make. + Defaults to ANCESTRY_MAX_WORKERS. Returns: dict[str, list[dict[str, str]] | dict]: A dictionary mapping each input DCID to either: - A flat list of parent dictionaries (if `as_tree` is False), or @@ -270,7 +275,7 @@ def fetch_entity_ancestry( result = {} # Use a thread pool to fetch ancestry graphs in parallel for each input entity - with ThreadPoolExecutor(max_workers=ANCESTRY_MAX_WORKERS) as executor: + with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor: futures = [ executor.submit(build_ancestry_map, root=dcid, From 306d28d965e63ba1301aa6fa1322145b2ca7362a Mon Sep 17 00:00:00 2001 From: Jorge Rivera Date: Mon, 7 Apr 2025 20:13:43 +0200 Subject: [PATCH 7/8] Simplify as `Node` --- datacommons_client/endpoints/node.py | 13 +-- datacommons_client/models/graph.py | 22 ---- datacommons_client/models/node.py | 4 +- .../tests/endpoints/test_node_endpoint.py | 60 +---------- datacommons_client/tests/utils/test_graph.py | 100 +++++++----------- datacommons_client/utils/graph.py | 49 +++------ 6 files changed, 60 insertions(+), 188 deletions(-) delete mode 100644 datacommons_client/models/graph.py diff --git a/datacommons_client/endpoints/node.py b/datacommons_client/endpoints/node.py index 042fa4c..c700edf 100644 --- a/datacommons_client/endpoints/node.py +++ b/datacommons_client/endpoints/node.py @@ -6,10 +6,9 @@ from datacommons_client.endpoints.payloads import NodeRequestPayload from datacommons_client.endpoints.payloads import normalize_properties_to_string from datacommons_client.endpoints.response import NodeResponse -from datacommons_client.models.graph import Parent +from datacommons_client.models.node import Node from datacommons_client.utils.graph import build_ancestry_map from datacommons_client.utils.graph import build_ancestry_tree -from datacommons_client.utils.graph import build_parents_dictionary from datacommons_client.utils.graph import fetch_parents_lru from datacommons_client.utils.graph import flatten_ancestry @@ -199,7 +198,7 @@ def fetch_entity_parents( self, entity_dcids: str | list[str], *, - as_dict: bool = True) -> dict[str, list[Parent | dict]]: + as_dict: bool = True) -> dict[str, list[Node | dict]]: """Fetches the direct parents of one or more entities using the 'containedInPlace' property. Args: @@ -220,14 +219,12 @@ def fetch_entity_parents( properties="containedInPlace", ).get_properties() - result = build_parents_dictionary(data=data) - if as_dict: - return {k: [p.to_dict() for p in v] for k, v in result.items()} + return {k: v.to_dict() for k, v in data.items()} - return result + return data - def _fetch_parents_cached(self, dcid: str) -> tuple[Parent, ...]: + def _fetch_parents_cached(self, dcid: str) -> tuple[Node, ...]: """Returns cached parent nodes for a given entity using an LRU cache. This private wrapper exists because `@lru_cache` cannot be applied directly diff --git a/datacommons_client/models/graph.py b/datacommons_client/models/graph.py deleted file mode 100644 index e7a0b4c..0000000 --- a/datacommons_client/models/graph.py +++ /dev/null @@ -1,22 +0,0 @@ -from dataclasses import dataclass -from typing import TypeAlias - -from datacommons_client.utils.data_processing import SerializableMixin - - -@dataclass(frozen=True) -class Parent(SerializableMixin): - """A class representing a parent node in a graph. - Attributes: - dcid (str): The ID of the parent node. - name (str): The name of the parent node. - type (str | list[str]): The type(s) of the parent node. - """ - - dcid: str - name: str - type: str | list[str] - - -# A dictionary mapping DCIDs to lists of Parent objects. -AncestryMap: TypeAlias = dict[str, list[Parent]] diff --git a/datacommons_client/models/node.py b/datacommons_client/models/node.py index 8d85f45..861cb84 100644 --- a/datacommons_client/models/node.py +++ b/datacommons_client/models/node.py @@ -2,6 +2,8 @@ from dataclasses import field from typing import Any, Dict, List, Optional, TypeAlias +from datacommons_client.utils.data_processing import SerializableMixin + NextToken: TypeAlias = Optional[str] NodeDCID: TypeAlias = str ArcLabel: TypeAlias = str @@ -10,7 +12,7 @@ @dataclass -class Node: +class Node(SerializableMixin): """Represents an individual node in the Data Commons knowledge graph. Attributes: diff --git a/datacommons_client/tests/endpoints/test_node_endpoint.py b/datacommons_client/tests/endpoints/test_node_endpoint.py index e87c10c..59468c9 100644 --- a/datacommons_client/tests/endpoints/test_node_endpoint.py +++ b/datacommons_client/tests/endpoints/test_node_endpoint.py @@ -4,7 +4,6 @@ from datacommons_client.endpoints.base import API from datacommons_client.endpoints.node import NodeEndpoint from datacommons_client.endpoints.response import NodeResponse -from datacommons_client.models.graph import Parent from datacommons_client.models.node import Node @@ -199,62 +198,9 @@ def test_node_endpoint_fetch_property_values_string_vs_list(): ) -@patch("datacommons_client.utils.graph.build_parents_dictionary") -def test_fetch_entity_parents_as_dict(mock_build_parents_dict): - """Test fetch_entity_parents with dictionary output.""" - api_mock = MagicMock() - endpoint = NodeEndpoint(api=api_mock) - - api_mock.post.return_value = { - "data": { - "X": { - "properties": { - "containedInPlace": [] - } - } - } - } - endpoint.fetch_property_values = MagicMock() - endpoint.fetch_property_values.return_value.get_properties.return_value = { - "X": Node("X", "X name", types=["Country"]) - } - - result = endpoint.fetch_entity_parents("X", as_dict=True) - assert result == {"X": [{"dcid": "X", "name": "X name", "type": "Country"}]} - - endpoint.fetch_property_values.assert_called_once_with( - node_dcids="X", properties="containedInPlace") - - -@patch("datacommons_client.utils.graph.build_parents_dictionary") -def test_fetch_entity_parents_as_objects(mock_build_parents_dict): - """Test fetch_entity_parents with raw Parent object output.""" - api_mock = MagicMock() - endpoint = NodeEndpoint(api=api_mock) - - # Simulate what fetch_property_values().get_properties() would return - endpoint.fetch_property_values = MagicMock() - endpoint.fetch_property_values.return_value.get_properties.return_value = { - "X": Node("X", "X name", types=["Country"]) - } - - # Mock output of build_parents_dictionary - parent_obj = Node("X", "X name", types=["Country"]) - mock_build_parents_dict.return_value = {"X": [parent_obj]} - - result = endpoint.fetch_entity_parents("X", as_dict=False) - - assert isinstance(result, dict) - assert "X" in result - assert isinstance(result["X"][0], Parent) - - endpoint.fetch_property_values.assert_called_once_with( - node_dcids="X", properties="containedInPlace") - - @patch("datacommons_client.endpoints.node.fetch_parents_lru") def test_fetch_parents_cached_delegates_to_lru(mock_fetch_lru): - mock_fetch_lru.return_value = (Parent("B", "B name", "Region"),) + mock_fetch_lru.return_value = (Node("B", "B name", "Region"),) endpoint = NodeEndpoint(api=MagicMock()) result = endpoint._fetch_parents_cached("X") @@ -271,7 +217,7 @@ def test_fetch_entity_ancestry_flat(mock_build_map, mock_flatten): mock_build_map.return_value = ( "X", { - "X": [Parent("A", "A name", "Country")], + "X": [Node("A", "A name", "Country")], "A": [], }, ) @@ -296,7 +242,7 @@ def test_fetch_entity_ancestry_tree(mock_build_map, mock_build_tree): mock_build_map.return_value = ( "Y", { - "Y": [Parent("Z", "Z name", "Region")], + "Y": [Node("Z", "Z name", "Region")], "Z": [], }, ) diff --git a/datacommons_client/tests/utils/test_graph.py b/datacommons_client/tests/utils/test_graph.py index 21b5e09..d6bfd51 100644 --- a/datacommons_client/tests/utils/test_graph.py +++ b/datacommons_client/tests/utils/test_graph.py @@ -5,8 +5,8 @@ from datacommons_client.utils.graph import _assemble_tree from datacommons_client.utils.graph import _fetch_parents_uncached from datacommons_client.utils.graph import _postorder_nodes +from datacommons_client.utils.graph import build_ancestry_map from datacommons_client.utils.graph import build_ancestry_tree -from datacommons_client.utils.graph import build_parents_dictionary from datacommons_client.utils.graph import fetch_parents_lru from datacommons_client.utils.graph import flatten_ancestry @@ -30,7 +30,7 @@ def test_fetch_parents_lru_caches_results(): """Test fetch_parents_lru uses LRU cache and returns tuple.""" endpoint = MagicMock() endpoint.fetch_entity_parents.return_value.get.return_value = [ - Parent(dcid="parentX", name="Parent X", type="Region") + Node(dcid="parentX", name="Parent X", types=["Region"]) ] result1 = fetch_parents_lru(endpoint, "nodeA") @@ -46,35 +46,13 @@ def test_fetch_parents_lru_caches_results(): assert endpoint.fetch_entity_parents.call_count == 1 # Called only once -def test_build_parents_dictionary_single_type(): - """Test build_parents_dictionary with single type.""" - data = {"child1": [Node(dcid="p1", name="Parent 1", types=["Country"])]} - - result = build_parents_dictionary(data) - - assert result["child1"][0].dcid == "p1" - assert result["child1"][0].type == "Country" - - -def test_build_parents_dictionary_multi_type(): - data = { - "child1": [Node(dcid="p1", name="Parent 1", types=["Country", "Region"])], - "child2": [Node(dcid="p2", name="Parent 12", types=["Continent"])], - } - - result = build_parents_dictionary(data) - - assert result["child1"][0].type == ["Country", "Region"] - assert result["child2"][0].type == "Continent" - - def test_build_ancestry_map_linear_tree(): """A -> B -> C""" def fetch_mock(dcid): return { - "C": (Parent("B", "Node B", "Type"),), - "B": (Parent("A", "Node A", "Type"),), + "C": (Node("B", "Node B", "Type"),), + "B": (Node("A", "Node A", "Type"),), "A": tuple(), }.get(dcid, tuple()) @@ -87,29 +65,25 @@ def fetch_mock(dcid): assert ancestry["A"] == [] # No parents for A -from datacommons_client.models.graph import Parent -from datacommons_client.utils.graph import build_ancestry_map - - def test_build_ancestry_map_branching_graph(): r""" - Graph: - F - / \ - D E - / \ / - B C - \/ - A - """ + Graph: + F + / \ + D E + / \ / + B C + \/ + A + """ def fetch_mock(dcid): return { - "A": (Parent("B", "Node B", "Type"), Parent("C", "Node C", "Type")), - "B": (Parent("D", "Node D", "Type"),), - "C": (Parent("D", "Node D", "Type"), Parent("E", "Node E", "Type")), - "D": (Parent("F", "Node F", "Type"),), - "E": (Parent("F", "Node F", "Type"),), + "A": (Node("B", "Node B", "Type"), Node("C", "Node C", "Type")), + "B": (Node("D", "Node D", "Type"),), + "C": (Node("D", "Node D", "Type"), Node("E", "Node E", "Type")), + "D": (Node("F", "Node F", "Type"),), + "E": (Node("F", "Node F", "Type"),), "F": tuple(), }.get(dcid, tuple()) @@ -137,9 +111,9 @@ def test_build_ancestry_map_cycle_detection(): def fetch_mock(dcid): call_count[dcid] += 1 return { - "A": (Parent("B", "B", "Type"),), - "B": (Parent("C", "C", "Type"),), - "C": (Parent("A", "A", "Type"),), # Cycle back to A + "A": (Node("B", "B", "Type"),), + "B": (Node("C", "C", "Type"),), + "C": (Node("A", "A", "Type"),), # Cycle back to A }.get(dcid, tuple()) root, ancestry = build_ancestry_map("A", fetch_mock, max_workers=2) @@ -160,8 +134,8 @@ def fetch_mock(dcid): def test_postorder_nodes_simple_graph(): """Test postorder traversal on a simple graph.""" ancestry = { - "C": [Parent("B", "B", "Type")], - "B": [Parent("A", "A", "Type")], + "C": [Node("B", "B", "Type")], + "B": [Node("A", "A", "Type")], "A": [], } @@ -175,8 +149,8 @@ def test_postorder_nodes_simple_graph(): def test_assemble_tree_creates_nested_structure(): """Test _assemble_tree creates a nested structure.""" ancestry = { - "C": [Parent("B", "Node B", "Type")], - "B": [Parent("A", "Node A", "Type")], + "C": [Node("B", "Node B", "Type")], + "B": [Node("A", "Node A", "Type")], "A": [], } postorder = ["A", "B", "C"] @@ -194,10 +168,10 @@ def test_postorder_nodes_ignores_unreachable_nodes(): Ancestry map also includes D (unconnected) """ ancestry = { - "A": [Parent("B", "B", "Type")], - "B": [Parent("C", "C", "Type")], + "A": [Node("B", "B", "Type")], + "B": [Node("C", "C", "Type")], "C": [], - "D": [Parent("X", "X", "Type")], + "D": [Node("X", "X", "Type")], } postorder = _postorder_nodes("A", ancestry) @@ -216,8 +190,8 @@ def test_assemble_tree_shared_parent_not_duplicated(): """ ancestry = { - "A": [Parent("C", "C name", "City")], - "B": [Parent("C", "C name", "City")], + "A": [Node("C", "C name", "City")], + "B": [Node("C", "C name", "City")], "C": [], } @@ -236,8 +210,8 @@ def test_assemble_tree_shared_parent_not_duplicated(): def test_build_ancestry_tree_nested_output(): """Test build_ancestry_tree creates a nested structure.""" ancestry = { - "C": [Parent("B", "B", "Type")], - "B": [Parent("A", "A", "Type")], + "C": [Node("B", "B", "Type")], + "B": [Node("A", "A", "Type")], "A": [], } @@ -252,13 +226,13 @@ def test_flatten_ancestry_deduplicates(): """Test flatten_ancestry deduplicates parents.""" ancestry = { - "X": [Parent("A", "A", "Country")], - "Y": [Parent("A", "A", "Country"), - Parent("B", "B", "City")], + "X": [Node("A", "A", types=["Country"])], + "Y": [Node("A", "A", types=["Country"]), + Node("B", "B", types=["City"])], } flat = flatten_ancestry(ancestry) - assert {"dcid": "A", "name": "A", "type": "Country"} in flat - assert {"dcid": "B", "name": "B", "type": "City"} in flat + assert {"dcid": "A", "name": "A", "types": ["Country"]} in flat + assert {"dcid": "B", "name": "B", "types": ["City"]} in flat assert len(flat) == 2 diff --git a/datacommons_client/utils/graph.py b/datacommons_client/utils/graph.py index b617885..b7cd710 100644 --- a/datacommons_client/utils/graph.py +++ b/datacommons_client/utils/graph.py @@ -4,17 +4,18 @@ from concurrent.futures import ThreadPoolExecutor from concurrent.futures import wait from functools import lru_cache -from typing import Callable, Optional +from typing import Callable, Optional, TypeAlias -from datacommons_client.models.graph import AncestryMap -from datacommons_client.models.graph import Parent +from datacommons_client.models.node import Node PARENTS_MAX_WORKERS = 10 +AncestryMap: TypeAlias = dict[str, list[Node]] + # -- -- Fetch tools -- -- -def _fetch_parents_uncached(endpoint, dcid: str) -> list[Parent]: +def _fetch_parents_uncached(endpoint, dcid: str) -> list[Node]: """Fetches the immediate parents of a given DCID from the endpoint, without caching. This function performs a direct, uncached call to the API. It exists @@ -31,11 +32,13 @@ def _fetch_parents_uncached(endpoint, dcid: str) -> list[Parent]: Returns: A list of parent dictionaries, each containing 'dcid', 'name', and 'type'. """ - return endpoint.fetch_entity_parents(dcid, as_dict=False).get(dcid, []) + parents = endpoint.fetch_entity_parents(dcid, as_dict=False).get(dcid, []) + + return parents if isinstance(parents, list) else [parents] @lru_cache(maxsize=512) -def fetch_parents_lru(endpoint, dcid: str) -> tuple[Parent, ...]: +def fetch_parents_lru(endpoint, dcid: str) -> tuple[Node, ...]: """Fetches parents of a DCID using an LRU cache for improved performance. Args: endpoint: A client object with a `fetch_entity_parents` method. @@ -48,35 +51,11 @@ def fetch_parents_lru(endpoint, dcid: str) -> tuple[Parent, ...]: # -- -- Ancestry tools -- -- -def build_parents_dictionary(data: dict) -> dict[str, list[Parent]]: - """Transforms a dictionary of entities and their parents into a structured - dictionary mapping each entity to its list of Parents. - - Args: - data (dict): The properties dictionary of a Node.fetch_property_values call. - - Returns: - dict[str, list[Parent]]: A dictionary where each key is an entity DCID - and the value is a list of Parent objects representing its parents. - - """ - - result: dict[str, list[Parent]] = {} - - for entity, properties in data.items(): - if not isinstance(properties, list): - properties = [properties] - - for parent in properties: - parent_type = parent.types[0] if len(parent.types) == 1 else parent.types - result.setdefault(entity, []).append( - Parent(dcid=parent.dcid, name=parent.name, type=parent_type)) - return result def build_ancestry_map( root: str, - fetch_fn: Callable[[str], tuple[Parent, ...]], + fetch_fn: Callable[[str], tuple[Node, ...]], max_workers: Optional[int] = PARENTS_MAX_WORKERS, ) -> tuple[str, AncestryMap]: """Constructs a complete ancestry map for the root node using parallel @@ -194,7 +173,7 @@ def _assemble_tree(postorder: list[str], ancestry: AncestryMap) -> dict: for parent in ancestry.get(node, []): parent_dcid = parent.dcid name = parent.name - entity_type = parent.type + entity_type = parent.types # If the parent node is not already in the cache, add it. if parent_dcid not in tree_cache: @@ -246,9 +225,5 @@ def flatten_ancestry(ancestry: AncestryMap) -> list[dict[str, str]]: if parent.dcid in seen: continue seen.add(parent.dcid) - flat.append({ - "dcid": parent.dcid, - "name": parent.name, - "type": parent.type - }) + flat.append(parent.to_dict()) return flat From 045e3098ba6ad0a39e714f50e4618ed413507f40 Mon Sep 17 00:00:00 2001 From: Jorge Rivera Date: Mon, 7 Apr 2025 20:25:39 +0200 Subject: [PATCH 8/8] Update node.py --- datacommons_client/endpoints/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datacommons_client/endpoints/node.py b/datacommons_client/endpoints/node.py index c700edf..c5dce96 100644 --- a/datacommons_client/endpoints/node.py +++ b/datacommons_client/endpoints/node.py @@ -12,7 +12,7 @@ from datacommons_client.utils.graph import fetch_parents_lru from datacommons_client.utils.graph import flatten_ancestry -ANCESTRY_MAX_WORKERS = 20 +ANCESTRY_MAX_WORKERS = 10 class NodeEndpoint(Endpoint):