
Commit 45a657c
update CMPNN layers and update notebooks.
PatReis committed Dec 2, 2023
1 parent 48a176b commit 45a657c
Showing 4 changed files with 260 additions and 1,266 deletions.
2 changes: 1 addition & 1 deletion docs/source/models.ipynb
@@ -564,7 +564,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.5"
}
},
"nbformat": 4,
46 changes: 42 additions & 4 deletions kgcnn/literature/CMPNN/_layers.py
@@ -4,11 +4,38 @@

class PoolingNodesGRU(Layer):

-    def __init__(self, units, **kwargs):
+    def __init__(self, units, static_output_shape=None,
+                 activation='tanh', recurrent_activation='sigmoid',
+                 use_bias=True, kernel_initializer='glorot_uniform',
+                 recurrent_initializer='orthogonal',
+                 bias_initializer='zeros', kernel_regularizer=None,
+                 recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None,
+                 recurrent_constraint=None, bias_constraint=None, dropout=0.0,
+                 recurrent_dropout=0.0, reset_after=True, seed=None,
+                 **kwargs):
         super(PoolingNodesGRU, self).__init__(**kwargs)
         self.units = units
+        self.static_output_shape = static_output_shape
-        self.cast_layer = CastDisjointToBatchedAttributes(return_mask=True)
-        self.gru = GRU(units=units)
+        self.cast_layer = CastDisjointToBatchedAttributes(
+            static_output_shape=static_output_shape, return_mask=True)
+        self.gru = GRU(
+            units=units,
+            activation=activation,
+            recurrent_activation=recurrent_activation,
+            use_bias=use_bias,
+            kernel_initializer=kernel_initializer,
+            recurrent_initializer=recurrent_initializer,
+            bias_initializer=bias_initializer,
+            kernel_regularizer=kernel_regularizer,
+            recurrent_regularizer=recurrent_regularizer,
+            bias_regularizer=bias_regularizer,
+            kernel_constraint=kernel_constraint,
+            recurrent_constraint=recurrent_constraint,
+            bias_constraint=bias_constraint,
+            dropout=dropout,
+            recurrent_dropout=recurrent_dropout,
+            reset_after=reset_after,
+            seed=seed
+        )

     def call(self, inputs, **kwargs):
         n, mask = self.cast_layer(inputs)
@@ -17,5 +44,16 @@ def call(self, inputs, **kwargs):

     def get_config(self):
         config = super(PoolingNodesGRU, self).get_config()
-        config.update({"units": self.units})
+        config.update({"units": self.units, "static_output_shape": self.static_output_shape})
+        conf_gru = self.gru.get_config()
+        param_list = ["units", "activation", "recurrent_activation",
+                      "use_bias", "kernel_initializer",
+                      "recurrent_initializer",
+                      "bias_initializer", "kernel_regularizer",
+                      "recurrent_regularizer", "bias_regularizer", "kernel_constraint",
+                      "recurrent_constraint", "bias_constraint", "dropout",
+                      "recurrent_dropout", "reset_after"]
+        for x in param_list:
+            if x in conf_gru.keys():
+                config.update({x: conf_gru[x]})
         return config
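
For orientation, a minimal usage sketch of the updated layer. The layer forwards its inputs straight to CastDisjointToBatchedAttributes, so the disjoint input list and its ordering below are an assumption based on the usual kgcnn disjoint convention; the toy tensors are illustrative only.

import numpy as np
from kgcnn.literature.CMPNN._layers import PoolingNodesGRU

# Toy batch of two disjoint graphs with 3 and 2 nodes (5 nodes total).
# NOTE: input list and ordering are assumed; check the docstring of
# CastDisjointToBatchedAttributes for the authoritative signature.
node_attr = np.random.rand(5, 16).astype("float32")   # flat node features
batch_id = np.array([0, 0, 0, 1, 1], dtype="int64")   # graph index per node
node_id = np.array([0, 1, 2, 0, 1], dtype="int64")    # in-graph node index
total_nodes = np.array([3, 2], dtype="int64")         # node count per graph

# The new constructor arguments are forwarded to the underlying keras GRU.
pooling = PoolingNodesGRU(units=32, dropout=0.1)
out = pooling([node_attr, batch_id, node_id, total_nodes])  # expected shape (2, 32)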
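The get_config change follows the standard keras serialization pattern: the wrapper reads the hyperparameters back out of the wrapped GRU's own config and copies them into its top-level config, so the layer can be rebuilt from a plain dict. A small round-trip sketch, reusing the pooling instance from above:

# Serialize to a plain config dict and rebuild; from_config passes the
# saved GRU hyperparameters (activation, dropout, ...) back to __init__.
config = pooling.get_config()
restored = PoolingNodesGRU.from_config(config)
assert restored.units == pooling.units

Note that seed is not included in param_list, so it is not serialized and a restored layer falls back to the __init__ default of None.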
