|
1 |
| -from typing import List, Union |
| 1 | +from typing import List, Set, Union |
2 | 2 |
|
3 | 3 | import networkx as nx
|
4 | 4 |
|
|
12 | 12 | "is_node_common_cause",
|
13 | 13 | "set_nodes_as_latent_confounders",
|
14 | 14 | "is_valid_mec_graph",
|
| 15 | + "inducing_path", |
15 | 16 | ]
|
16 | 17 |
|
17 | 18 |
|
@@ -333,3 +334,265 @@ def _single_shortest_path_early_stop(G, firstlevel, paths, cutoff, join, valid_p
|
333 | 334 | nextlevel[w] = 1
|
334 | 335 | level += 1
|
335 | 336 | return paths
|
| 337 | + |
| 338 | + |
def _directed_sub_graph_ancestors(G, node: Node):
    """Collect the ancestors of a node within the directed subgraph.

    Parameters
    ----------
    G : Graph
        The graph. Only its ``sub_directed_graph()`` view is consulted.
    node : Node
        The node whose ancestors are requested.

    Returns
    -------
    out : set
        The ancestors of ``node``, as computed by ``networkx.ancestors`` on
        the directed subgraph of ``G``.
    """
    directed_part = G.sub_directed_graph()
    return nx.ancestors(directed_part, node)
| 356 | + |
| 357 | + |
def _directed_sub_graph_parents(G, node: Node):
    """Collect the parents of a node within the directed subgraph.

    Parameters
    ----------
    G : Graph
        The graph. Only its ``sub_directed_graph()`` view is consulted.
    node : Node
        The node whose parents are requested.

    Returns
    -------
    out : set
        The predecessors of ``node`` in the directed subgraph of ``G``.
    """
    directed_part = G.sub_directed_graph()
    incoming = directed_part.predecessors(node)
    return set(incoming)
| 375 | + |
| 376 | + |
def _bidirected_sub_graph_neighbors(G, node: Node):
    """Collect the neighbors of a node within the bidirected subgraph.

    Parameters
    ----------
    G : Graph
        The graph.
    node : Node
        The node whose bidirected neighbors are requested.

    Returns
    -------
    out : set
        The neighbors of ``node`` in the bidirected subgraph of ``G``. The
        result is always empty when ``G`` is a ``CPDAG``, because the
        bidirected subgraph is not queried for that graph type.
    """
    # A CPDAG is never asked for a bidirected subgraph.
    if isinstance(G, CPDAG):
        return set()

    bidirected_part = G.sub_bidirected_graph()
    return set(bidirected_part.neighbors(node))
| 398 | + |
| 399 | + |
def _is_collider(G, prev_node: Node, cur_node: Node, next_node: Node):
    """Tell whether ``cur_node`` is a collider between its two path neighbors.

    Parameters
    ----------
    G : graph
        The graph.
    prev_node : node
        The node preceding ``cur_node`` on the path.
    cur_node : node
        The node being tested.
    next_node : Node
        The node following ``cur_node`` on the path.

    Returns
    -------
    iscollider : bool
        True when both ``prev_node`` and ``next_node`` are either directed
        parents or bidirected neighbors of ``cur_node``, False otherwise.
    """
    directed_parents = _directed_sub_graph_parents(G, cur_node)
    bidirected_nbrs = _bidirected_sub_graph_neighbors(G, cur_node)

    # Every node with an edge whose arrowhead points into cur_node.
    into_cur = directed_parents.union(bidirected_nbrs)

    return prev_node in into_cur and next_node in into_cur
| 426 | + |
| 427 | + |
| 428 | +def _shortest_valid_path( |
| 429 | + G, |
| 430 | + node_x: Node, |
| 431 | + node_y: Node, |
| 432 | + L: Set, |
| 433 | + S: Set, |
| 434 | + visited: Set, |
| 435 | + all_ancestors: Set, |
| 436 | + cur_node: Node, |
| 437 | + prev_node: Node, |
| 438 | +): |
| 439 | + """Recursively explores a graph to find a path. |
| 440 | +
|
| 441 | + Finds path that are compliant with the inducing path requirements. |
| 442 | +
|
| 443 | + Parameters |
| 444 | + ---------- |
| 445 | + G : graph |
| 446 | + The graph. |
| 447 | + node_x : node |
| 448 | + The source node. |
| 449 | + node_y : node |
| 450 | + The destination node |
| 451 | + L : Set |
| 452 | + Set containing all the non-colliders. |
| 453 | + S : Set |
| 454 | + Set containing all the colliders. |
| 455 | + visited : Set |
| 456 | + Set containing all the nodes already visited. |
| 457 | + all_ancestors : Set |
| 458 | + Set containing all the ancestors a collider needs to be checked against. |
| 459 | + cur_node : node |
| 460 | + The current node. |
| 461 | + prev_node : node |
| 462 | + The previous node in the path. |
| 463 | +
|
| 464 | + Returns |
| 465 | + ------- |
| 466 | + path : Tuple[bool, path] |
| 467 | + A tuple containing a bool and a path which is empty if the bool is false. |
| 468 | + """ |
| 469 | + path_exists = False |
| 470 | + path = [] |
| 471 | + visited.add(cur_node) |
| 472 | + neighbors = G.neighbors(cur_node) |
| 473 | + |
| 474 | + if cur_node is node_y: |
| 475 | + return (True, [node_y]) |
| 476 | + |
| 477 | + for elem in neighbors: |
| 478 | + if elem in visited: |
| 479 | + continue |
| 480 | + |
| 481 | + else: |
| 482 | + # If the current node is a collider, check that it is either an |
| 483 | + # ancestor of X, Y or any element of S or that it is |
| 484 | + # the destination node itself. |
| 485 | + if ( |
| 486 | + _is_collider(G, prev_node, cur_node, elem) |
| 487 | + and (cur_node not in all_ancestors) |
| 488 | + and (cur_node not in S) |
| 489 | + and (cur_node is not node_y) |
| 490 | + ): |
| 491 | + continue |
| 492 | + |
| 493 | + # If the current node is not a collider, check that it is |
| 494 | + # either in L or the destination node itself. |
| 495 | + |
| 496 | + elif ( |
| 497 | + not _is_collider(G, prev_node, cur_node, elem) |
| 498 | + and (cur_node not in L) |
| 499 | + and (cur_node is not node_y) |
| 500 | + ): |
| 501 | + continue |
| 502 | + |
| 503 | + # if it is a valid node and not the destination node, |
| 504 | + # check if it has a path to the destination node |
| 505 | + |
| 506 | + path_exists, temp_path = _shortest_valid_path( |
| 507 | + G, node_x, node_y, L, S, visited, all_ancestors, elem, cur_node |
| 508 | + ) |
| 509 | + |
| 510 | + if path_exists: |
| 511 | + path.append(cur_node) |
| 512 | + path.extend(temp_path) |
| 513 | + break |
| 514 | + |
| 515 | + return (path_exists, path) |
| 516 | + |
| 517 | + |
def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):
    """Checks if an inducing path exists between two nodes.

    An inducing path is defined in :footcite:`Zhang2008`.

    Parameters
    ----------
    G : Graph
        The graph.
    node_x : node
        The source node.
    node_y : node
        The destination node.
    L : Set
        Nodes that are ignored on the path. Defaults to an empty set. See Notes for details.
    S: Set
        Nodes that are always conditioned on. Defaults to an empty set. See Notes for details.

    Returns
    -------
    path : Tuple[bool, path]
        A tuple containing a bool and a path if the bool is true, an empty list otherwise.
        A returned path starts at ``node_x``.

    Raises
    ------
    ValueError
        If ``node_x`` or ``node_y`` is not a node of ``G``, or if the two
        nodes are equal.

    Notes
    -----
    Intuitively, an inducing path is a path between two non-adjacent nodes
    that cannot be d-separated, i.e. it stays "active" no matter which
    variables are conditioned on. L holds the non-colliders, which are
    ignored along the path. S holds the nodes that are always conditioned on
    (so a collider that is an ancestor of a member of S keeps its path
    "active").

    References
    ----------
    .. footbibliography::
    """
    ignored = set() if L is None else L
    conditioned = set() if S is None else S

    graph_nodes = set(G.nodes)
    if not (node_x in graph_nodes and node_y in graph_nodes):
        raise ValueError("The provided nodes are not in the graph.")
    if node_x == node_y:
        raise ValueError("The source and destination nodes are the same.")

    # Ancestors of X, of Y and of every element of S: the nodes a collider
    # on the path is checked against.
    all_ancestors = set(_directed_sub_graph_ancestors(G, node_x))
    all_ancestors.update(_directed_sub_graph_ancestors(G, node_y))
    for member in conditioned:
        all_ancestors.update(_directed_sub_graph_ancestors(G, member))

    for first_hop in G.neighbors(node_x):
        # Each attempt starts from a fresh visited set holding only the source.
        seen = {node_x}
        if first_hop in seen:
            continue

        found, tail = _shortest_valid_path(
            G, node_x, node_y, ignored, conditioned, seen, all_ancestors, first_hop, node_x
        )
        if found:
            return (found, [node_x, *tail])

    return (False, [])
0 commit comments