"""
File: models.py
Author: David Dalton
Description: Implements PrimalGraphEmulator GNN Architecture
"""
import jax
import jax.numpy as jnp
import jax.tree_util as tree
from flax import linen as nn
from typing import Sequence, Callable
DTYPE = jnp.float32

class FlaxMLP(nn.Module):
    """Implements an MLP in Flax"""
    features: Sequence[int]
    layer_norm: bool

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, dtype=DTYPE)(x)
            # CELU activation on all but the final layer
            if i != len(self.features) - 1:
                x = nn.celu(x)
        # optionally normalise the final output
        if self.layer_norm:
            x = nn.LayerNorm()(x)
        return x
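
# Example usage (a minimal sketch; the shapes and feature sizes below are
# illustrative assumptions, not values used elsewhere in this module):
def _example_flax_mlp_usage():
    """Illustrative only; not called anywhere in this module."""
    rng = jax.random.PRNGKey(0)
    x = jnp.ones((4, 8), dtype=DTYPE)                     # batch of 4 inputs of width 8
    mlp = FlaxMLP(features=(16, 16, 3), layer_norm=False)
    params = mlp.init(rng, x)                             # initialise weights
    return mlp.apply(params, x)                           # output shape (4, 3)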

def make_mlp(features: Sequence[int]):
    """Makes standard MLP

    With hidden layers defined by features
    """
    def update_fn(inputs):
        return FlaxMLP(features, False)(inputs)
    return update_fn

def make_layernorm_mlp(features: Sequence[int]):
    """Makes MLP followed by LayerNorm

    With hidden layers specified by features
    """
    def update_fn(inputs):
        return FlaxMLP(features, True)(inputs)
    return update_fn
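
# Usage note (a minimal sketch; sizes are assumptions): the factories above
# return closures that construct a fresh FlaxMLP submodule each time they are
# called, so they must be invoked inside an `@nn.compact` method of an
# enclosing module, for example:
class _ExampleEncoder(nn.Module):
    """Illustrative only; not used anywhere in this module."""
    @nn.compact
    def __call__(self, x):
        encode = make_layernorm_mlp((16, 8))  # hidden width 16, output width 8
        return encode(x)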

def aggregate_incoming_messages(messages: jnp.ndarray, receivers: Sequence[int], n_nodes: int):
    r"""Sum aggregates incoming messages to each node

    Performs the sum over incoming messages $\sum_{j \in \mathcal{N}_i} m_{ij}^k$
    from the processor stage of Algorithm 2 of the manuscript, for all nodes $i$ simultaneously
    """
    return jax.ops.segment_sum(messages, receivers, n_nodes)
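
# Example (a minimal sketch with made-up values): for a 3-node graph with
# directed edges 0->2 and 1->2, both edge messages are summed into node 2's
# row, while nodes 0 and 1 receive zero vectors.
def _example_aggregation():
    """Illustrative only; not called anywhere in this module."""
    messages = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # one row per edge
    receivers = jnp.array([2, 2])
    return aggregate_incoming_messages(messages, receivers, 3)
    # -> [[0., 0.], [0., 0.], [1., 1.]]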

def MessagePassingStep(node_update_fn: FlaxMLP, edge_update_fn: FlaxMLP, senders: Sequence[int], receivers: Sequence[int], n_real_nodes: int):
    """Returns function to perform one message passing step

    Function _ApplyMessagePassingStep performs one step of message passing $k$ as
    in the for loop in Algorithm 2 of the manuscript.
    """
    def _ApplyMessagePassingStep(Vold: jnp.ndarray, Eold: jnp.ndarray):
        # calculate messages along each directed edge with an edge feature vector assigned
        messages = edge_update_fn(jnp.hstack((Eold, Vold[receivers], Vold[senders])))

        # aggregate incoming messages m_{ij} from nodes i to j where i > j
        received_messages_ij = aggregate_incoming_messages(messages, receivers, n_real_nodes)

        # aggregate incoming messages m_{ij} from nodes i to j where i < j
        # m_{ij} = -m_{ji} where i < j (momentum conservation property of the message passing)
        received_messages_ji = aggregate_incoming_messages(-messages, senders, n_real_nodes)

        # concatenate node representation with incoming messages and then update node representation
        V = node_update_fn(jnp.hstack((Vold, received_messages_ij + received_messages_ji)))

        # return updated node and edge representations with residual connection
        return Vold + V, Eold + messages

    return _ApplyMessagePassingStep
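
# Example (a minimal sketch; all sizes are assumptions): a message passing
# step must itself run inside a Flax module, since the update functions build
# FlaxMLP submodules when called. Note that the MLP output widths must match
# the node/edge feature widths for the residual connections to be well-defined.
class _ExampleOneStep(nn.Module):
    """Illustrative only; wraps a single message passing step."""
    @nn.compact
    def __call__(self, V, E):  # e.g. V: (3, 4), E: (2, 4)
        step = MessagePassingStep(make_layernorm_mlp((8, 4)),  # node update MLP
                                  make_layernorm_mlp((8, 4)),  # edge update MLP
                                  senders=jnp.array([0, 1]),
                                  receivers=jnp.array([1, 2]),
                                  n_real_nodes=3)
        return step(V, E)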

class PrimalGraphEmulator(nn.Module):
    """PrimalGraphEmulator (varying geometry data)"""
    mlp_features: Sequence[int]
    latent_size: Sequence[int]
    K: int
    receivers: Sequence[int]
    senders: Sequence[int]
    n_total_nodes: int
    output_dim: Sequence[int]
    real_node_indices: Sequence[bool]
    boundary_adjust_fn: Callable = None

    @nn.compact
    def __call__(self, V: jnp.ndarray, E: jnp.ndarray, theta: jnp.ndarray, sow_latents=False):
        r"""Implements Algorithm 2 of the manuscript: forward pass of PrimalGraphEmulator

        Inputs:
        ---------
        V: jnp.ndarray
            Array giving feature vectors of each node (real and virtual)
        E: jnp.ndarray
            Array giving feature vectors of each edge
        theta: jnp.ndarray
            Vector of global graph parameters $\theta$
        sow_latents: bool
            If True, return the latent nodal representations z_local instead
            of displacement predictions

        Outputs:
        ---------
        U: jnp.ndarray
            Array of displacement predictions for each real node in V
        """
        ## Initialise internal MLPs:
        # 3 encoder MLPs
        node_encode_mlp = make_layernorm_mlp(self.mlp_features + self.latent_size)
        edge_encode_mlp = make_layernorm_mlp(self.mlp_features + self.latent_size)
        theta_encode_mlp = make_layernorm_mlp(self.mlp_features + self.latent_size)

        # 2K processor MLPs (one node-update and one edge-update MLP per step)
        message_passing_blocks = [MessagePassingStep(make_layernorm_mlp(self.mlp_features + self.latent_size),
                                                     make_layernorm_mlp(self.mlp_features + self.latent_size),
                                                     self.senders, self.receivers, self.n_total_nodes)
                                  for _ in range(self.K)]

        # output_dim[0] decoder MLPs (one per output dimension)
        node_decode_mlps = [make_mlp(self.mlp_features + (1,)) for _ in range(self.output_dim[0])]

        ## Encoder:
        V = node_encode_mlp(V)
        E = edge_encode_mlp(E)

        ## Processor:
        # perform K rounds of message passing
        for message_pass_block_i in message_passing_blocks:
            V, E = message_pass_block_i(V, E)

        # aggregate incoming messages to each node
        incoming_messages = aggregate_incoming_messages(E, self.receivers, self.n_total_nodes)

        # final local learned representation is a concatenation of vector embedding and incoming messages
        z_local = nn.LayerNorm()(jnp.hstack((V, incoming_messages))[self.real_node_indices])

        # used for rapid evaluations of decoder for fixed geometry
        if sow_latents:
            return z_local

        ## Decoder:
        # encode global parameters theta
        z_theta = theta_encode_mlp(theta)

        # tile global parameter embeddings (z_theta) to each individual real node
        z_theta_array = jnp.tile(z_theta, (z_local.shape[0], 1))

        # final learned representations are comprised of (z_theta, z_local)
        final_learned_representations = jnp.hstack((z_theta_array, z_local))

        # make prediction for forward displacement using a different decoder MLP for each dimension
        individual_mlp_predictions = [decode_mlp(final_learned_representations) for decode_mlp in node_decode_mlps]

        # concatenate the predictions of each individual decoder MLP
        Upred = jnp.hstack(individual_mlp_predictions)

        # adjust predictions to account for displacement boundary conditions
        if self.boundary_adjust_fn is not None:
            Upred = self.boundary_adjust_fn(Upred)

        # return displacement prediction array
        return Upred
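
# Example usage (a minimal sketch on a toy 3-node, 2-edge graph; every size
# below is an illustrative assumption):
def _example_emulator_usage():
    """Illustrative only; not called anywhere in this module."""
    rng = jax.random.PRNGKey(0)
    V = jnp.ones((3, 5), dtype=DTYPE)   # node features
    E = jnp.ones((2, 4), dtype=DTYPE)   # edge features
    theta = jnp.ones(2, dtype=DTYPE)    # global parameters
    model = PrimalGraphEmulator(mlp_features=(16, 16), latent_size=(8,), K=2,
                                receivers=jnp.array([1, 2]), senders=jnp.array([0, 1]),
                                n_total_nodes=3, output_dim=(3,),
                                real_node_indices=jnp.array([True, True, True]))
    params = model.init(rng, V, E, theta)
    return model.apply(params, V, E, theta)  # displacement predictions, shape (3, 3)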

class PrimalGraphEmulatorDecoder(nn.Module):
    """PrimalGraphEmulator (just decoder stage)"""
    mlp_features: Sequence[int]
    output_dim: Sequence[int]
    n_real_nodes: int
    latent_nodal_values: jnp.ndarray
    theta_encode_mlp_fn: Callable
    boundary_adjust_fn: Callable = None

    @nn.compact
    def __call__(self, theta: jnp.ndarray):
        r"""Implements decoder stage of PrimalGraphEmulator

        Input:
        ---------
        theta: jnp.ndarray
            Vector giving the global parameters $\theta$ for the fixed geometry
            being modelled

        Output:
        ---------
        U: jnp.ndarray
            Array of displacement predictions for each real node
        """
        # initialise node-decode MLPs (one per output dimension)
        node_decode_mlps = [make_mlp(self.mlp_features + (1,)) for _ in range(self.output_dim[0])]

        # embed theta to higher dimensional space using pre-trained theta_encode_mlp
        z_theta = self.theta_encode_mlp_fn(theta)

        # tile global parameter embeddings (z_theta) to each individual real finite-element node
        z_theta_array = jnp.tile(z_theta, (self.n_real_nodes, 1))

        # final learned representations are comprised of (z_theta, z_local)
        final_learned_representations = jnp.hstack((z_theta_array, self.latent_nodal_values))

        # make prediction for forward displacement using a different decoder MLP for each dimension
        individual_mlp_predictions = [decode_mlp(final_learned_representations) for decode_mlp in node_decode_mlps]

        # concatenate the predictions of each individual decoder MLP
        Upred = jnp.hstack(individual_mlp_predictions)

        # adjust predictions to account for displacement boundary conditions
        if self.boundary_adjust_fn is not None:
            Upred = self.boundary_adjust_fn(Upred)

        # return displacement prediction array
        return Upred
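
# Example usage (a minimal sketch; `z_local` would be the latents returned by
# a trained PrimalGraphEmulator called with `sow_latents=True`, and `encode_fn`
# a pre-trained theta encoder wrapped as a plain function of theta; both are
# assumptions here):
def _example_decoder_usage(z_local, encode_fn):
    """Illustrative only; not called anywhere in this module."""
    rng = jax.random.PRNGKey(0)
    theta = jnp.ones(2, dtype=DTYPE)
    decoder = PrimalGraphEmulatorDecoder(mlp_features=(16, 16), output_dim=(3,),
                                         n_real_nodes=z_local.shape[0],
                                         latent_nodal_values=z_local,
                                         theta_encode_mlp_fn=encode_fn)
    params = decoder.init(rng, theta)
    return decoder.apply(params, theta)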