
Commit 0bd5c50

Add graph utils
1 parent 8eeeff6 commit 0bd5c50

File tree

1 file changed: +271 -0 lines changed


datacommons_client/utils/graph.py

Lines changed: 271 additions & 0 deletions
@@ -0,0 +1,271 @@
from collections import deque
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable, Optional

from datacommons_client.utils.data_processing import SerializableMixin

PARENTS_MAX_WORKERS = 10


@dataclass(frozen=True)
class Parent(SerializableMixin):
  """A class representing a parent node in a graph.
  Attributes:
    dcid (str): The ID of the parent node.
    name (str): The name of the parent node.
    type (str | list[str]): The type(s) of the parent node.
  """

  dcid: str
  name: str
  type: str | list[str]


AncestryMap = dict[str, list[Parent]]

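# Example (sketch, using hypothetical data): an AncestryMap for a small place
# hierarchy could look like
#
#   {
#       "geoId/06": [Parent(dcid="country/USA", name="United States", type="Country")],
#       "country/USA": [],
#   }
#
# i.e. each key is a DCID and each value lists that entity's immediate parents.
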
# -- -- Fetch tools -- --


def _fetch_parents_uncached(endpoint, dcid: str) -> list[Parent]:
  """Fetches the immediate parents of a given DCID from the endpoint, without caching.

  This function performs a direct, uncached call to the API. It exists
  primarily to serve as the internal, cache-free fetch used by `fetch_parents_lru`, which
  applies LRU caching on top of this raw access function.

  By isolating the pure fetch logic here, we ensure that caching is handled separately
  and cleanly via `@lru_cache` on `fetch_parents_lru`, which requires its wrapped
  function to be deterministic and side-effect free.

  Args:
    endpoint: A client object with a `fetch_entity_parents` method.
    dcid (str): The entity ID for which to fetch parents.
  Returns:
    A list of parents for the given DCID, each carrying 'dcid', 'name', and 'type'.
  """
  return endpoint.fetch_entity_parents(dcid).get(dcid, [])


@lru_cache(maxsize=512)
def fetch_parents_lru(endpoint, dcid: str) -> tuple[Parent, ...]:
  """Fetches parents of a DCID using an LRU cache for improved performance.
  Args:
    endpoint: A client object with a `fetch_entity_parents` method.
    dcid (str): The entity ID to fetch parents for.
  Returns:
    A tuple of `Parent` objects corresponding to the entity's parents.
  """
  parents = _fetch_parents_uncached(endpoint, dcid)
  return tuple(p for p in parents)

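# Example (sketch): `fetch_parents_lru` caches on the (endpoint, dcid) pair, so
# repeated lookups of the same DCID are served from the cache. Assuming
# `endpoint` is any client exposing `fetch_entity_parents`, a fetch function
# suitable for `build_ancestry_map` below can be obtained by binding the
# endpoint:
#
#   fetch_fn = lambda dcid: fetch_parents_lru(endpoint, dcid)
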
# -- -- Ancestry tools -- --
def build_parents_dictionary(data: dict) -> dict[str, list[Parent]]:
  """Transforms a dictionary of entities and their parents into a structured
  dictionary mapping each entity to its list of Parents.

  Args:
    data (dict): The properties dictionary of a Node.fetch_property_values call.

  Returns:
    dict[str, list[Parent]]: A dictionary where each key is an entity DCID
      and the value is a list of Parent objects representing its parents.
  """

  result: dict[str, list[Parent]] = {}

  for entity, properties in data.items():
    if not isinstance(properties, list):
      properties = [properties]

    for parent in properties:
      parent_type = parent.types[0] if len(parent.types) == 1 else parent.types
      result.setdefault(entity, []).append(
          Parent(dcid=parent.dcid, name=parent.name, type=parent_type))
  return result


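# Example (sketch): assuming each value in `data` is a node-like object with
# `dcid`, `name` and `types` attributes (as the loop above expects), a payload
# such as
#
#   {"geoId/06": [node(dcid="country/USA", name="United States", types=["Country"])]}
#
# is transformed into
#
#   {"geoId/06": [Parent(dcid="country/USA", name="United States", type="Country")]}
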
def build_ancestry_map(
    root: str,
    fetch_fn: Callable[[str], tuple[Parent, ...]],
    max_workers: Optional[int] = PARENTS_MAX_WORKERS,
) -> tuple[str, AncestryMap]:
  """Constructs a complete ancestry map for the root node using parallel
  Breadth-First Search (BFS).

  Traverses the ancestry graph upward from the root node, discovering all parent
  relationships by fetching in parallel.

  Args:
    root (str): The DCID of the root entity to start from.
    fetch_fn (Callable): A function that takes a DCID and returns a Parent tuple.
    max_workers (Optional[int]): Max number of threads to use for parallel fetching.
      Optional, defaults to `PARENTS_MAX_WORKERS`.

  Returns:
    A tuple containing:
      - The original root DCID.
      - A dictionary mapping each DCID to a list of its `Parent`s.
  """
  ancestry: AncestryMap = {}
  visited: set[str] = set()
  in_progress: dict[str, Future] = {}

  original_root = root

  with ThreadPoolExecutor(max_workers=max_workers) as executor:
    queue = deque([root])

    # Standard BFS loop, but fetches are executed in parallel threads
    while queue or in_progress:
      # Submit fetch tasks for all nodes in the queue
      while queue:
        dcid = queue.popleft()
        # Check if the node has already been visited or is in progress
        if dcid not in visited and dcid not in in_progress:
          # Submit the fetch task
          in_progress[dcid] = executor.submit(fetch_fn, dcid)

      # Check if any futures are still in progress
      if not in_progress:
        continue

      # Wait for at least one future to complete
      done_futures, _ = wait(in_progress.values(), return_when=FIRST_COMPLETED)

      # Find which DCIDs have completed
      completed_dcids = [
          dcid for dcid, future in in_progress.items() if future in done_futures
      ]

      # Process completed fetches and enqueue any unseen parents
      for dcid in completed_dcids:
        future = in_progress.pop(dcid)
        parents = list(future.result())
        ancestry[dcid] = parents
        visited.add(dcid)

        for parent in parents:
          if parent.dcid not in visited and parent.dcid not in in_progress:
            queue.append(parent.dcid)

  return original_root, ancestry


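# Example (sketch): building the full ancestry map for a place, with `endpoint`
# assumed to expose `fetch_entity_parents` as above:
#
#   fetch_fn = lambda dcid: fetch_parents_lru(endpoint, dcid)
#   root, ancestry = build_ancestry_map("geoId/06", fetch_fn)
#
# `ancestry` then maps "geoId/06" and every discovered ancestor to its parents.
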
def _postorder_nodes(root: str, ancestry: AncestryMap) -> list[str]:
  """Generates a postorder list of all nodes reachable from the root.

  Postorder here means that each node's parents appear before the node itself,
  so the tree can be assembled bottom-up, finishing with the root.

  Args:
    root (str): The root DCID to start traversal from.
    ancestry (AncestryMap): The ancestry graph.
  Returns:
    A list of DCIDs in postorder (ancestors first, the root last).
  """
  # Initialize stack and postorder list
  stack, postorder, seen = [root], [], set()

  # Traverse the graph using a stack
  while stack:
    node = stack.pop()
    # Skip if already seen
    if node in seen:
      continue
    seen.add(node)
    postorder.append(node)
    # Push all unvisited parents onto the stack (i.e. climb up the graph, child -> parent)
    for parent in ancestry.get(node, []):
      parent_dcid = parent.dcid
      if parent_dcid not in seen:
        stack.append(parent_dcid)

  # Reverse the list so that each node's parents come before it (i.e. postorder, root last)
  return list(reversed(postorder))


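# Example (sketch): for the small hypothetical map used above, where "geoId/06"
# has the single parent "country/USA" and "country/USA" has none,
#
#   _postorder_nodes("geoId/06", ancestry)
#
# returns ["country/USA", "geoId/06"]: parents first, the root last, so each
# node is assembled only after the entries it embeds already exist.
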
def _assemble_tree(postorder: list[str], ancestry: AncestryMap) -> dict:
  """Builds a nested dictionary tree from a postorder node list and ancestry map.
  Constructs a nested representation of the ancestry graph, embedding each
  parent into its child only after the parent's own entry has been created
  (which is what the postorder list provides).
  Args:
    postorder (list[str]): List of node DCIDs in postorder.
    ancestry (AncestryMap): Map from DCID to list of Parent objects.
  Returns:
    A nested dictionary representing the ancestry tree rooted at the last postorder node.
  """
  tree_cache: dict[str, dict] = {}

  for node in postorder:
    # Initialize the node dictionary.
    node_dict = {"dcid": node, "name": None, "type": None, "parents": []}

    # For each parent of the current node, fetch its details and add it to the node_dict.
    for parent in ancestry.get(node, []):
      parent_dcid = parent.dcid
      name = parent.name
      entity_type = parent.type

      # If the parent node is not already in the cache, add it.
      if parent_dcid not in tree_cache:
        tree_cache[parent_dcid] = {
            "dcid": parent_dcid,
            "name": name,
            "type": entity_type,
            "parents": [],
        }

      parent_node = tree_cache[parent_dcid]

      # Ensure name/type are up to date (in case of duplicates)
      parent_node["name"] = name
      parent_node["type"] = entity_type
      node_dict["parents"].append(parent_node)

    tree_cache[node] = node_dict

  # The root node is the last one in postorder; that's what gets returned
  return tree_cache[postorder[-1]]


def build_ancestry_tree(root: str, ancestry: AncestryMap) -> dict:
  """Builds a nested ancestry tree from an ancestry map.
  Args:
    root (str): The DCID of the root node.
    ancestry (AncestryMap): A flat ancestry map built by `build_ancestry_map`.
  Returns:
    A nested dictionary tree rooted at the specified DCID.
  """
  postorder = _postorder_nodes(root, ancestry)
  return _assemble_tree(postorder, ancestry)


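# Example (sketch): continuing the same two-node example,
#
#   build_ancestry_tree("geoId/06", ancestry)
#
# returns a nested dictionary of the form
#
#   {"dcid": "geoId/06", "name": None, "type": None,
#    "parents": [{"dcid": "country/USA", "name": "United States",
#                 "type": "Country", "parents": []}]}
#
# The root's own name and type stay None because the ancestry map only records
# details about parents.
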
def flatten_ancestry(ancestry: AncestryMap) -> list[dict[str, str]]:
  """Flattens the ancestry map into a deduplicated list of parent records.
  Args:
    ancestry (AncestryMap): Ancestry mapping of DCIDs to lists of Parent objects.
  Returns:
    A list of dictionaries with keys 'dcid', 'name', and 'type', containing
      each unique parent in the graph.
  """

  flat: list = []
  seen: set[str] = set()
  for parents in ancestry.values():
    for parent in parents:
      if parent.dcid in seen:
        continue
      seen.add(parent.dcid)
      flat.append({
          "dcid": parent.dcid,
          "name": parent.name,
          "type": parent.type
      })
  return flat
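
# Example (sketch): flattening the same two-node ancestry map yields
#
#   [{"dcid": "country/USA", "name": "United States", "type": "Country"}]
#
# i.e. one record per unique parent; the root itself is not included because it
# never appears as anyone's parent.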
