Skip to content

Commit b71f7ee

Browse files
aryan26roy and adam2392 authored
[ENH] Add ability to determine whether an inducing path exists between two nodes (#78)
* Add function definition for inducing path algorithm --------- Signed-off-by: Aryan Roy <[email protected]> Co-authored-by: Adam Li <[email protected]>
1 parent 8908d57 commit b71f7ee

File tree

6 files changed

+461
-7
lines changed

6 files changed

+461
-7
lines changed

.circleci/config.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ jobs:
6464
- run:
6565
name: Install pysal dependencies
6666
command: |
67-
sudo apt install libspatialindex-dev xdg-utils shared-mime-info
67+
sudo apt install libspatialindex-dev xdg-utils shared-mime-info desktop-file-utils
6868
- run:
6969
name: Setup pandoc
7070
command: sudo apt update && sudo apt install -y pandoc optipng

docs/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ causal graph operations.
4343
.. autosummary::
4444
:toctree: generated/
4545

46+
inducing_path
4647
is_valid_mec_graph
4748
possible_ancestors
4849
possible_descendants

docs/glossary.rst

+4-5
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,13 @@ General Concepts
3535
API
3636
Refers to both the *specific* interfaces for graphs implemented in
3737
pywhy-graphs and the *generalized* conventions across types of
38-
graphs as described in this glossary and :ref:`overviewed in the
39-
contributor documentation <api_overview>`.
38+
graphs as described in this glossary.
4039

4140
The specific interfaces that constitute pywhy-graphs's public API are
4241
largely documented in :ref:`api_ref`. However, we less formally consider
4342
anything as public API if none of the identifiers required to access it
44-
begins with ``_``. We generally try to maintain :term:`backwards
45-
compatibility` for all objects in the public API.
43+
begins with ``_``. We generally try to maintain backwards
44+
compatibility for all objects in the public API.
4645

4746
Private API, including functions, modules and methods beginning ``_``
4847
are not assured to be stable.
@@ -85,7 +84,7 @@ General Concepts
8584
experimental
8685
An experimental tool is already usable but its public API, such as
8786
default parameter values or fitted attributes, is still subject to
88-
change in future versions without the usual :term:`deprecation`
87+
change in future versions without the usual deprecation
8988
warning policy.
9089

9190
F-node

docs/whats_new/v0.1.rst

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ Changelog
4545
- |Feature| Implement pre-commit hooks for development, by `Jaron Lee`_ (:pr:`68`)
4646
- |Feature| Implement a new submodule for converting graphs to a functional model, with :func:`pywhy_graphs.functional.make_graph_linear_gaussian`, by `Adam Li`_ (:pr:`75`)
4747
- |Feature| Implement a multidomain linear functional graph, with :func:`pywhy_graphs.functional.make_graph_multidomain`, by `Adam Li`_ (:pr:`77`)
48+
- |Feature| Implement and test functions to find inducing paths between two nodes, by `Aryan Roy`_ (:pr:`78`)
49+
4850

4951
Code and Documentation Contributors
5052
-----------------------------------

pywhy_graphs/algorithms/generic.py

+264-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Union
1+
from typing import List, Set, Union
22

33
import networkx as nx
44

@@ -12,6 +12,7 @@
1212
"is_node_common_cause",
1313
"set_nodes_as_latent_confounders",
1414
"is_valid_mec_graph",
15+
"inducing_path",
1516
]
1617

1718

@@ -333,3 +334,265 @@ def _single_shortest_path_early_stop(G, firstlevel, paths, cutoff, join, valid_p
333334
nextlevel[w] = 1
334335
level += 1
335336
return paths
337+
338+
339+
def _directed_sub_graph_ancestors(G, node: Node):
    """Collect the ancestors of a node using only the directed edges.

    Parameters
    ----------
    G : Graph
        The graph. It must provide ``sub_directed_graph()``.
    node : Node
        The node whose ancestors are requested.

    Returns
    -------
    out : set
        The ancestors of ``node`` in the directed subgraph of ``G``.
    """
    directed_subgraph = G.sub_directed_graph()
    return nx.ancestors(directed_subgraph, node)
356+
357+
358+
def _directed_sub_graph_parents(G, node: Node):
359+
"""Finds the parents of a node in the directed subgraph.
360+
361+
Parameters
362+
----------
363+
G : Graph
364+
The graph.
365+
node : Node
366+
The node for which we have to find the parents.
367+
368+
Returns
369+
-------
370+
out : set
371+
The parents of the provided node.
372+
"""
373+
374+
return set(G.sub_directed_graph().predecessors(node))
375+
376+
377+
def _bidirected_sub_graph_neighbors(G, node: Node):
    """Collect the neighbors of a node using only the bidirected edges.

    Parameters
    ----------
    G : Graph
        The graph.
    node : Node
        The node whose bidirected neighbors are requested.

    Returns
    -------
    out : set
        The neighbors of ``node`` in the bidirected subgraph of ``G``. The
        set is always empty when ``G`` is a ``CPDAG``.
    """
    # A CPDAG is never asked for a bidirected subgraph here
    # (presumably because it has no bidirected edges -- confirm against CPDAG).
    if isinstance(G, CPDAG):
        return set()

    bidirected_subgraph = G.sub_bidirected_graph()
    return set(bidirected_subgraph.neighbors(node))
398+
399+
400+
def _is_collider(G, prev_node: Node, cur_node: Node, next_node: Node):
    """Tell whether ``cur_node`` is a collider between its two path neighbors.

    ``cur_node`` counts as a collider when both ``prev_node`` and
    ``next_node`` are either directed parents of ``cur_node`` or joined to
    it by a bidirected edge.

    Parameters
    ----------
    G : graph
        The graph.
    prev_node : node
        The node preceding ``cur_node`` on the path.
    cur_node : node
        The node to be checked.
    next_node : node
        The node following ``cur_node`` on the path.

    Returns
    -------
    iscollider : bool
        True if ``cur_node`` is a collider on the triple, False otherwise.
    """
    directed_parents = _directed_sub_graph_parents(G, cur_node)
    bidirected_neighbors = _bidirected_sub_graph_neighbors(G, cur_node)
    arrow_sources = directed_parents | bidirected_neighbors

    return prev_node in arrow_sources and next_node in arrow_sources
426+
427+
428+
def _shortest_valid_path(
429+
G,
430+
node_x: Node,
431+
node_y: Node,
432+
L: Set,
433+
S: Set,
434+
visited: Set,
435+
all_ancestors: Set,
436+
cur_node: Node,
437+
prev_node: Node,
438+
):
439+
"""Recursively explores a graph to find a path.
440+
441+
Finds path that are compliant with the inducing path requirements.
442+
443+
Parameters
444+
----------
445+
G : graph
446+
The graph.
447+
node_x : node
448+
The source node.
449+
node_y : node
450+
The destination node
451+
L : Set
452+
Set containing all the non-colliders.
453+
S : Set
454+
Set containing all the colliders.
455+
visited : Set
456+
Set containing all the nodes already visited.
457+
all_ancestors : Set
458+
Set containing all the ancestors a collider needs to be checked against.
459+
cur_node : node
460+
The current node.
461+
prev_node : node
462+
The previous node in the path.
463+
464+
Returns
465+
-------
466+
path : Tuple[bool, path]
467+
A tuple containing a bool and a path which is empty if the bool is false.
468+
"""
469+
path_exists = False
470+
path = []
471+
visited.add(cur_node)
472+
neighbors = G.neighbors(cur_node)
473+
474+
if cur_node is node_y:
475+
return (True, [node_y])
476+
477+
for elem in neighbors:
478+
if elem in visited:
479+
continue
480+
481+
else:
482+
# If the current node is a collider, check that it is either an
483+
# ancestor of X, Y or any element of S or that it is
484+
# the destination node itself.
485+
if (
486+
_is_collider(G, prev_node, cur_node, elem)
487+
and (cur_node not in all_ancestors)
488+
and (cur_node not in S)
489+
and (cur_node is not node_y)
490+
):
491+
continue
492+
493+
# If the current node is not a collider, check that it is
494+
# either in L or the destination node itself.
495+
496+
elif (
497+
not _is_collider(G, prev_node, cur_node, elem)
498+
and (cur_node not in L)
499+
and (cur_node is not node_y)
500+
):
501+
continue
502+
503+
# if it is a valid node and not the destination node,
504+
# check if it has a path to the destination node
505+
506+
path_exists, temp_path = _shortest_valid_path(
507+
G, node_x, node_y, L, S, visited, all_ancestors, elem, cur_node
508+
)
509+
510+
if path_exists:
511+
path.append(cur_node)
512+
path.extend(temp_path)
513+
break
514+
515+
return (path_exists, path)
516+
517+
518+
def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):
    """Check whether an inducing path exists between two nodes.

    An inducing path is defined in :footcite:`Zhang2008`.

    Parameters
    ----------
    G : Graph
        The graph.
    node_x : node
        The source node.
    node_y : node
        The destination node.
    L : Set
        Nodes that are ignored on the path. Defaults to an empty set. See Notes for details.
    S : Set
        Nodes that are always conditioned on. Defaults to an empty set. See Notes for details.

    Returns
    -------
    path : Tuple[bool, path]
        ``(True, path)`` with the path found from ``node_x`` to ``node_y``,
        or ``(False, [])`` when no inducing path was found.

    Raises
    ------
    ValueError
        If either node is not in the graph, or if both nodes are equal.

    Notes
    -----
    An inducing path intuitively is a path between two non-adjacent nodes that
    cannot be d-separated. Therefore, the path is always "active" regardless of
    what variables we condition on. L contains all the non-colliders, these nodes
    are ignored along the path. S contains nodes that are always conditioned on
    (hence if the ancestors of colliders are in S, then those collider
    paths are always "active").

    References
    ----------
    .. footbibliography::
    """
    L = set() if L is None else L
    S = set() if S is None else S

    graph_nodes = set(G.nodes)
    if not (node_x in graph_nodes and node_y in graph_nodes):
        raise ValueError("The provided nodes are not in the graph.")
    if node_x == node_y:
        raise ValueError("The source and destination nodes are the same.")

    # Ancestors of X, Y and every element of S: the set a collider on the
    # path is checked against.
    all_ancestors = _directed_sub_graph_ancestors(G, node_x).union(
        _directed_sub_graph_ancestors(G, node_y)
    )
    for conditioned in S:
        all_ancestors = all_ancestors.union(_directed_sub_graph_ancestors(G, conditioned))

    found = False
    path = []
    for start in G.neighbors(node_x):
        # Every starting neighbor gets a fresh visited set seeded with X.
        visited = {node_x}
        if start in visited:
            continue

        found, tail = _shortest_valid_path(
            G, node_x, node_y, L, S, visited, all_ancestors, start, node_x
        )
        if found:
            path = [node_x, *tail]
            break

    return (found, path)

0 commit comments

Comments
 (0)