Skip to content

Commit 3de911e

Browse files
authored
[ENH] Add functional submodule (#75)
* Add functional submodule * update lock file --------- Signed-off-by: Adam Li <[email protected]>
1 parent 8658c5b commit 3de911e

File tree

7 files changed

+154
-4
lines changed

7 files changed

+154
-4
lines changed

docs/api.rst

+11
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,17 @@ a SCM and their data starting from the causal graph.
147147
simulate.simulate_data_from_var
148148
simulate.simulate_var_process_from_summary_graph
149149

150+
Converting graphs to functional models
151+
======================================
152+
An experimental submodule for converting graphs to functional models, such as
153+
linear structural equation Gaussian models (SEMs).
154+
155+
.. currentmodule:: pywhy_graphs.functional
156+
157+
.. autosummary::
158+
:toctree: generated/
159+
160+
make_graph_linear_gaussian
150161

151162
Visualization of causal graphs
152163
==============================

docs/whats_new/v0.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Changelog
4343
- |Feature| Implement export/import functions to go to/from pywhy-graphs to pcalg and tetrad, by `Adam Li`_ (:pr:`60`)
4444
- |Feature| Implement export/import functions to go to/from pywhy-graphs to ananke-causal, by `Jaron Lee`_ (:pr:`63`)
4545
- |Feature| Implement pre-commit hooks for development, by `Jaron Lee`_ (:pr:`68`)
46+
- |Feature| Implement a new submodule for converting graphs to a functional model, with :func:`pywhy_graphs.functional.make_graph_linear_gaussian`, by `Adam Li`_ (:pr:`75`)
4647

4748
Code and Documentation Contributors
4849
-----------------------------------

poetry.lock

+4-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pywhy_graphs/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
from . import classes
2121
from . import networkx
2222
from . import simulate
23+
from . import functional

pywhy_graphs/functional/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .linear import make_graph_linear_gaussian

pywhy_graphs/functional/linear.py

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from typing import Callable, List, Optional
2+
3+
import networkx as nx
4+
import numpy as np
5+
6+
7+
def make_graph_linear_gaussian(
8+
G: nx.DiGraph,
9+
node_mean_lims: Optional[List[float]] = None,
10+
node_std_lims: Optional[List[float]] = None,
11+
edge_functions: List[Callable[[float], float]] = None,
12+
edge_weight_lims: Optional[List[float]] = None,
13+
random_state=None,
14+
):
15+
r"""Convert an existing DAG to a linear Gaussian graphical model.
16+
17+
All nodes are sampled from a normal distribution with parametrizations
18+
defined uniformly at random between the limits set by the input parameters.
19+
The edges apply then a weight and a function based on the inputs in an additive fashion.
20+
For node :math:`X_i`, we have:
21+
22+
.. math::
23+
24+
X_i = \\sum_{j \in parents} w_j f_j(X_j) + \\epsilon_i
25+
26+
where:
27+
28+
- :math:`\\epsilon_i \sim N(\mu_i, \sigma_i)`, where :math:`\mu_i` is sampled
29+
uniformly at random from `node_mean_lims` and :math:`\sigma_i` is sampled
30+
uniformly at random from `node_std_lims`.
31+
- :math:`w_j \sim U(\\text{edge_weight_lims})`
32+
- :math:`f_j` is a function sampled uniformly at random
33+
from `edge_functions`
34+
35+
Parameters
36+
----------
37+
G : NetworkX DiGraph
38+
The graph to sample data from. The graph will be modified in-place
39+
to get the weights and functions of the edges.
40+
node_mean_lims : Optional[List[float]], optional
41+
The lower and upper bounds of the mean of the Gaussian random variable, by default None,
42+
which defaults to a mean of 0.
43+
node_std_lims : Optional[List[float]], optional
44+
The lower and upper bounds of the std of the Gaussian random variable, by default None,
45+
which defaults to a std of 1.
46+
edge_functions : List[Callable[float]], optional
47+
The set of edge functions that take in an iid sample from the parent and computes
48+
a transformation (possibly nonlinear), such as ``(lambda x: x**2, lambda x: x)``,
49+
by default None, which defaults to the identity function ``lambda x: x``.
50+
edge_weight_lims : Optional[List[float]], optional
51+
The lower and upper bounds of the edge weight, by default None,
52+
which defaults to a weight of 1.
53+
random_state : int, optional
54+
Random seed, by default None.
55+
56+
Returns
57+
-------
58+
G : NetworkX DiGraph
59+
NetworkX graph with the edge weights and functions set with node attributes
60+
set with ``'parent_functions'``, and ``'gaussian_noise_function'``. Moreover
61+
the graph attribute ``'linear_gaussian'`` is set to ``True``.
62+
"""
63+
if not nx.is_directed_acyclic_graph(G):
64+
raise ValueError("The input graph must be a DAG.")
65+
rng = np.random.default_rng(random_state)
66+
67+
if node_mean_lims is None:
68+
node_mean_lims = [0, 0]
69+
elif len(node_mean_lims) != 2:
70+
raise ValueError("node_mean_lims must be a list of length 2.")
71+
if node_std_lims is None:
72+
node_std_lims = [1, 1]
73+
elif len(node_std_lims) != 2:
74+
raise ValueError("node_std_lims must be a list of length 2.")
75+
if edge_functions is None:
76+
edge_functions = [lambda x: x]
77+
if edge_weight_lims is None:
78+
edge_weight_lims = [1, 1]
79+
elif len(edge_weight_lims) != 2:
80+
raise ValueError("edge_weight_lims must be a list of length 2.")
81+
82+
# Create list of topologically sorted nodes
83+
top_sort_idx = list(nx.topological_sort(G))
84+
85+
for node_idx in top_sort_idx:
86+
# get all parents
87+
parents = sorted(list(G.predecessors(node_idx)))
88+
89+
# sample noise
90+
mean = rng.uniform(low=node_mean_lims[0], high=node_mean_lims[1])
91+
std = rng.uniform(low=node_std_lims[0], high=node_std_lims[1])
92+
93+
# sample weight and edge function for each parent
94+
node_function = dict()
95+
for parent in parents:
96+
weight = rng.uniform(low=edge_weight_lims[0], high=edge_weight_lims[1])
97+
func = rng.choice(edge_functions)
98+
node_function.update({parent: {"weight": weight, "func": func}})
99+
100+
# set the node attribute "functions" to hold the weight and function wrt each parent
101+
nx.set_node_attributes(G, {node_idx: node_function}, "parent_functions")
102+
nx.set_node_attributes(G, {node_idx: {"mean": mean, "std": std}}, "gaussian_noise_function")
103+
G.graph["linear_gaussian"] = True
104+
return G
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import networkx as nx
2+
import pytest
3+
4+
from pywhy_graphs.functional import make_graph_linear_gaussian
5+
from pywhy_graphs.simulate import simulate_random_er_dag
6+
7+
8+
def test_make_linear_gaussian_graph():
9+
G = simulate_random_er_dag(n_nodes=5, seed=12345, ensure_acyclic=True)
10+
11+
G = make_graph_linear_gaussian(G, random_state=12345)
12+
13+
assert all(key in nx.get_node_attributes(G, "parent_functions") for key in G.nodes)
14+
assert all(key in nx.get_node_attributes(G, "gaussian_noise_function") for key in G.nodes)
15+
16+
17+
def test_make_linear_gaussian_graph_errors():
18+
G = simulate_random_er_dag(n_nodes=2, seed=12345, ensure_acyclic=True)
19+
20+
with pytest.raises(ValueError, match="must be a list of length 2."):
21+
G = make_graph_linear_gaussian(G, node_mean_lims=[0], random_state=12345)
22+
23+
with pytest.raises(ValueError, match="must be a list of length 2."):
24+
G = make_graph_linear_gaussian(G, node_std_lims=[0], random_state=12345)
25+
26+
with pytest.raises(ValueError, match="must be a list of length 2."):
27+
G = make_graph_linear_gaussian(G, edge_weight_lims=[0], random_state=12345)
28+
29+
with pytest.raises(ValueError, match="The input graph must be a DAG."):
30+
G = make_graph_linear_gaussian(
31+
nx.cycle_graph(4, create_using=nx.DiGraph), random_state=12345
32+
)

0 commit comments

Comments
 (0)