From eb7d5bdb16cb935f2bfa03e7c9774dfd1b5c6f3c Mon Sep 17 00:00:00 2001 From: PatReis Date: Mon, 4 Dec 2023 19:00:26 +0100 Subject: [PATCH] update notebooks docs. Add further models (not tested!!) --- AUTHORS | 4 +- changelog.md | 2 +- docs/source/forces.ipynb | 187 ++--- docs/source/layers.ipynb | 3 +- kgcnn/layers/polynom.py | 242 +++++++ kgcnn/literature/DGIN/__init__.py | 0 kgcnn/literature/DGIN/_layers.py | 121 ++++ kgcnn/literature/DGIN/_make.py | 214 ++++++ kgcnn/literature/DGIN/_model.py | 105 +++ kgcnn/literature/DMPNN/_make.py | 3 +- kgcnn/literature/DimeNetPP/__init__.py | 0 kgcnn/literature/DimeNetPP/_layers.py | 390 +++++++++++ kgcnn/literature/DimeNetPP/_make.py | 397 +++++++++++ kgcnn/literature/DimeNetPP/_model.py | 154 +++++ kgcnn/literature/EGNN/__init__.py | 0 kgcnn/literature/EGNN/_make.py | 214 ++++++ kgcnn/literature/EGNN/_model.py | 106 +++ kgcnn/literature/GCN/_make.py | 3 - kgcnn/literature/GCN/_model.py | 4 + kgcnn/literature/GNNFilm/__init__.py | 0 kgcnn/literature/GNNFilm/_make.py | 157 +++++ kgcnn/literature/GNNFilm/_model.py | 49 ++ kgcnn/literature/Megnet/__init__.py | 0 kgcnn/literature/Megnet/_layers.py | 134 ++++ kgcnn/literature/Megnet/_make.py | 366 ++++++++++ kgcnn/literature/Megnet/_model.py | 193 ++++++ kgcnn/literature/RGCN/__init__.py | 0 kgcnn/literature/RGCN/_make.py | 165 +++++ kgcnn/literature/RGCN/_model.py | 51 ++ kgcnn/literature/Schnet/_make.py | 10 +- kgcnn/molecule/dynamics/base.py | 3 +- notebooks/workflow_qm_regression.ipynb | 907 +++++++++++++++++++++++-- 32 files changed, 4052 insertions(+), 132 deletions(-) create mode 100644 kgcnn/layers/polynom.py create mode 100644 kgcnn/literature/DGIN/__init__.py create mode 100644 kgcnn/literature/DGIN/_layers.py create mode 100644 kgcnn/literature/DGIN/_make.py create mode 100644 kgcnn/literature/DGIN/_model.py create mode 100644 kgcnn/literature/DimeNetPP/__init__.py create mode 100644 kgcnn/literature/DimeNetPP/_layers.py create mode 100644 
kgcnn/literature/DimeNetPP/_make.py create mode 100644 kgcnn/literature/DimeNetPP/_model.py create mode 100644 kgcnn/literature/EGNN/__init__.py create mode 100644 kgcnn/literature/EGNN/_make.py create mode 100644 kgcnn/literature/EGNN/_model.py create mode 100644 kgcnn/literature/GNNFilm/__init__.py create mode 100644 kgcnn/literature/GNNFilm/_make.py create mode 100644 kgcnn/literature/GNNFilm/_model.py create mode 100644 kgcnn/literature/Megnet/__init__.py create mode 100644 kgcnn/literature/Megnet/_layers.py create mode 100644 kgcnn/literature/Megnet/_make.py create mode 100644 kgcnn/literature/Megnet/_model.py create mode 100644 kgcnn/literature/RGCN/__init__.py create mode 100644 kgcnn/literature/RGCN/_make.py create mode 100644 kgcnn/literature/RGCN/_model.py diff --git a/AUTHORS b/AUTHORS index 25ff00ad..43b03f5f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,3 +1,5 @@ List of contributors to kgcnn modules. -- GNNExplainer module by Robin Ruff +- GNNExplainer module by robinruff +- DGIN by thegodone + diff --git a/changelog.md b/changelog.md index ca10405b..52f3216a 100644 --- a/changelog.md +++ b/changelog.md @@ -27,7 +27,7 @@ Also be sure to check ``StandardLabelScaler`` if you want to scale regression ta * Input embedding in literature models is now controlled with separate ``input_node_embedding`` or ``input_edge_embedding`` arguments which can be set to `None` for no embedding. Also embedding input tokens must be of dtype int now. No auto-casting from float anymore. * New module ``kgcnn.ops`` with ``kgcnn.backend`` to generalize aggregation functions for graph operations. -* Reduced the models in literature. Will keep bringing all models of kgcnn<4.0.0 back in next versions. +* Reduced the models in literature. Will keep bringing all models of kgcnn<4.0.0 back in next versions and run benchmark training again. 
diff --git a/docs/source/forces.ipynb b/docs/source/forces.ipynb index d5345619..ece89355 100644 --- a/docs/source/forces.ipynb +++ b/docs/source/forces.ipynb @@ -48,8 +48,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "INFO:kgcnn.model.utils:Updated model kwargs:\n", - "INFO:kgcnn.model.utils:{'name': 'SchnetEnergy', 'inputs': [{'shape': [None], 'name': 'z', 'dtype': 'float32', 'ragged': True}, {'shape': [None, 3], 'name': 'R', 'dtype': 'float32', 'ragged': True}, {'shape': [None, 2], 'name': 'range_indices', 'dtype': 'int64', 'ragged': True}], 'input_embedding': {'node': {'input_dim': 95, 'output_dim': 128}}, 'make_distance': True, 'expand_distance': True, 'interaction_args': {'units': 128, 'use_bias': True, 'activation': 'kgcnn>shifted_softplus', 'cfconv_pool': 'sum'}, 'node_pooling_args': {'pooling_method': 'sum'}, 'depth': 6, 'gauss_args': {'bins': 25, 'distance': 5, 'offset': 0.0, 'sigma': 0.4}, 'verbose': 10, 'last_mlp': {'use_bias': [True, True, True], 'units': [128, 64, 1], 'activation': ['kgcnn>shifted_softplus', 'kgcnn>shifted_softplus', 'linear']}, 'output_embedding': 'graph', 'output_to_tensor': True, 'use_output_mlp': False, 'output_mlp': {'use_bias': [True, True], 'units': [64, 1], 'activation': ['kgcnn>shifted_softplus', 'linear']}}\n" + "INFO:kgcnn.models.utils:Updated model kwargs: '{'name': 'SchnetEnergy', 'inputs': [{'shape': [None], 'name': 'z', 'dtype': 'int64'}, {'shape': [None, 3], 'name': 'R', 'dtype': 'float32'}, {'shape': [None, 2], 'name': 'range_indices', 'dtype': 'int64'}, {'shape': (), 'name': 'total_nodes', 'dtype': 'int64'}, {'shape': (), 'name': 'total_ranges', 'dtype': 'int64'}], 'input_tensor_type': 'padded', 'input_embedding': None, 'cast_disjoint_kwargs': {}, 'input_node_embedding': {'input_dim': 95, 'output_dim': 128}, 'make_distance': True, 'expand_distance': True, 'interaction_args': {'units': 128, 'use_bias': True, 'activation': 'kgcnn>shifted_softplus', 'cfconv_pool': 'sum'}, 'node_pooling_args': 
{'pooling_method': 'sum'}, 'depth': 6, 'gauss_args': {'bins': 25, 'distance': 5, 'offset': 0.0, 'sigma': 0.4}, 'verbose': 10, 'last_mlp': {'use_bias': [True, True, True], 'units': [128, 64, 1], 'activation': ['kgcnn>shifted_softplus', 'kgcnn>shifted_softplus', 'linear']}, 'output_embedding': 'graph', 'output_to_tensor': None, 'use_output_mlp': False, 'output_tensor_type': 'padded', 'output_scaling': None, 'output_mlp': {}}'.\n" ] } ], @@ -58,13 +57,14 @@ "config= {\n", " \"name\": \"SchnetEnergy\",\n", " \"inputs\": [\n", - " {\"shape\": [None], \"name\": \"z\", \"dtype\": \"float32\", \"ragged\": True},\n", - " {\"shape\": [None, 3], \"name\": \"R\", \"dtype\": \"float32\", \"ragged\": True},\n", - " {\"shape\": [None, 2], \"name\": \"range_indices\", \"dtype\": \"int64\", \"ragged\": True}\n", + " {\"shape\": [None], \"name\": \"z\", \"dtype\": \"int64\"},\n", + " {\"shape\": [None, 3], \"name\": \"R\", \"dtype\": \"float32\"},\n", + " {\"shape\": [None, 2], \"name\": \"range_indices\", \"dtype\": \"int64\"},\n", + " {\"shape\": (), \"name\": \"total_nodes\", \"dtype\": \"int64\"},\n", + " {\"shape\": (), \"name\": \"total_ranges\", \"dtype\": \"int64\"}\n", " ],\n", - " \"input_embedding\": {\n", - " \"node\": {\"input_dim\": 95, \"output_dim\": 128}\n", - " },\n", + " \"input_tensor_type\": \"padded\",\n", + " \"input_node_embedding\": {\"input_dim\": 95, \"output_dim\": 128},\n", " \"last_mlp\": {\"use_bias\": [True, True, True], \"units\": [128, 64, 1],\n", " \"activation\": ['kgcnn>shifted_softplus', 'kgcnn>shifted_softplus', 'linear']},\n", " \"interaction_args\": {\n", @@ -95,12 +95,24 @@ "metadata": {}, "outputs": [], "source": [ - "from kgcnn.model.force import EnergyForceModel\n", + "from kgcnn.models.force import EnergyForceModel\n", "\n", "model_energy_force = EnergyForceModel(\n", + " inputs=[\n", + " {\"shape\": [None], \"name\": \"z\", \"dtype\": \"int32\"},\n", + " {\"shape\": [None, 3], \"name\": \"node_coordinates\", \"dtype\": \"float32\"},\n", 
+ " {\"shape\": [None, 2], \"name\": \"range_indices\", \"dtype\": \"int64\"},\n", + " {\"shape\": (), \"name\": \"total_nodes\", \"dtype\": \"int64\"},\n", + " {\"shape\": (), \"name\": \"total_ranges\", \"dtype\": \"int64\"}\n", + " ],\n", + " name=\"SchnetForce\",\n", " model_energy = model_energy,\n", " output_to_tensor = False,\n", - " output_squeeze_states = True\n", + " output_squeeze_states = True,\n", + " outputs={\n", + " \"energy\": {\"name\": \"energy\", \"shape\": (1,)},\n", + " \"force\": {\"name\": \"force\", \"shape\": (None, 3)}\n", + " }\n", ")" ] }, @@ -145,7 +157,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Dataset (2000): dict_keys(['R', 'E', 'F', 'z', 'name', 'type', 'md5', 'theory', 'train', 'range_indices', 'range_attributes'])\n" + "Dataset (2000): dict_keys(['R', 'E', 'F', 'z', 'name', 'type', 'md5', 'theory', 'train', 'range_indices', 'range_attributes', 'total_nodes', 'total_ranges'])\n" ] } ], @@ -153,6 +165,8 @@ "from kgcnn.data.datasets.MD17Dataset import MD17Dataset\n", "dataset = MD17Dataset(\"ethanol_ccsd_t\")\n", "dataset.map_list(method=\"set_range\", node_coordinates=\"R\", max_distance=4.0)\n", + "dataset.map_list(method= \"count_nodes_and_edges\", total_edges= \"total_ranges\", count_edges= \"range_indices\", \n", + " count_nodes= \"z\", total_nodes= \"total_nodes\")\n", "# Change units to eV/A from kcal/mol\n", "dataset.set(\"E\", [mol[\"E\"]*0.0433634 for mol in dataset])\n", "dataset.set(\"F\", [mol[\"F\"]*0.0433634 for mol in dataset])\n", @@ -197,11 +211,11 @@ "from kgcnn.data.transform.scaler.force import EnergyForceExtensiveLabelScaler\n", "\n", "# Scaling energy and forces.\n", - "scaler = EnergyForceExtensiveLabelScaler()\n", - "scaler_mapping = {\"atomic_number\": \"z\", \"y\": [\"E\", \"F\"]}\n", - "scaler.fit_dataset(dataset_train, **scaler_mapping)\n", - "scaler.transform_dataset(dataset_train, **scaler_mapping)\n", - "scaler.transform_dataset(dataset_test, **scaler_mapping);" + "scaler_mapping = 
{\"atomic_number\": \"z\", \"energy\": \"E\", \"force\": \"F\"}\n", + "scaler = EnergyForceExtensiveLabelScaler(standardize_scale=False, **scaler_mapping)\n", + "scaler.fit_dataset(dataset_train);\n", + "scaler.transform_dataset(dataset_train)\n", + "scaler.transform_dataset(dataset_test);" ] }, { @@ -221,8 +235,8 @@ "source": [ "# Conversion to tensor input\n", "labels_in_dataset = {\n", - " \"energy\": {\"name\": \"E\", \"ragged\": False},\n", - " \"force\": {\"name\": \"F\", \"shape\": (None, 3), \"ragged\": True}\n", + " \"energy\": {\"name\": \"E\", \"shape\": (1,)},\n", + " \"force\": {\"name\": \"F\", \"shape\": (None, 3)}\n", "}\n", "y_train, y_test = dataset_train.tensor(labels_in_dataset), dataset_test.tensor(labels_in_dataset)\n", "x_train, x_test = dataset_train.tensor(config[\"inputs\"]), dataset_test.tensor(config[\"inputs\"])" @@ -243,14 +257,14 @@ "metadata": {}, "outputs": [], "source": [ - "from kgcnn.metrics.loss import RaggedMeanAbsoluteError\n", - "from tensorflow.keras.optimizers import Adam\n", + "from kgcnn.losses.losses import ForceMeanAbsoluteError\n", + "from keras.optimizers import Adam\n", "\n", "model_energy_force.compile(\n", - " loss={\"energy\": \"mean_absolute_error\", \"force\": RaggedMeanAbsoluteError()},\n", + " loss={\"energy\": \"mean_absolute_error\", \"force\": ForceMeanAbsoluteError()},\n", " optimizer=Adam(learning_rate=1e-03),\n", " metrics=None,\n", - " loss_weights=[1, 49],\n", + " loss_weights={\"energy\": 0.02, \"force\": 0.98},\n", ")" ] }, @@ -284,8 +298,28 @@ " callbacks=[\n", " LinearWarmupExponentialLRScheduler(lr_start=1e-03, gamma=0.995, epo_warmup=1, steps_per_epoch=32, verbose=1)\n", " ]\n", - ");\n", - "plot_train_test_loss([hist])" + ");" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "5c65200d-ab11-45a6-8da1-fce3f71ca28b", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_train_test_loss([hist]);" ] }, { @@ -298,13 +332,13 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "41da557f", "metadata": {}, "outputs": [], "source": [ "# model_energy_force.save(\"model_energy_force\")\n", - "# model_energy_force = tf.keras.models.load_model('model_energy_force')" + "# model_energy_force = keras.models.load_model('model_energy_force')" ] }, { @@ -317,18 +351,18 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "85b8c40d", "metadata": {}, "outputs": [], "source": [ - "scaler.inverse_transform_dataset(dataset, **scaler_mapping)\n", + "scaler.inverse_transform_dataset(dataset)\n", "true_y = dataset_test.get(\"E\"), dataset_test.get(\"F\")" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "851ec0a3", "metadata": {}, "outputs": [], @@ -351,20 +385,18 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "123b82ca", "metadata": {}, "outputs": [ { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], @@ -410,20 +442,23 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 14, "id": "682e083f", "metadata": {}, "outputs": [], "source": [ "from kgcnn.molecule.dynamics.base import MolDynamicsModelPredictor\n", "from kgcnn.graph.postprocessor import ExtensiveEnergyForceScalerPostprocessor\n", - "from kgcnn.graph.preprocessor import SetRange\n", + "from kgcnn.graph.preprocessor import SetRange, CountNodesAndEdges\n", "\n", "dyn_model = MolDynamicsModelPredictor(\n", " model=model_energy_force, \n", " model_inputs=config[\"inputs\"], \n", " model_outputs={\"energy\":\"energy\", \"forces\": \"force\"},\n", - " graph_preprocessors=[SetRange(node_coordinates= \"R\", max_distance=4.0)],\n", + " graph_preprocessors=[\n", + " SetRange(node_coordinates=\"R\", max_distance=4.0),\n", + " CountNodesAndEdges(total_edges=\"total_ranges\", count_edges=\"range_indices\", count_nodes=\"z\", total_nodes=\"total_nodes\")\n", + " ],\n", " graph_postprocessors=[\n", " ExtensiveEnergyForceScalerPostprocessor(\n", " scaler, force=\"forces\", atomic_number=\"z\")]\n", @@ -432,7 +467,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 15, "id": "d118de18", "metadata": {}, "outputs": [], @@ -455,7 +490,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 16, "id": "a17953fb", "metadata": {}, "outputs": [ @@ -465,7 +500,7 @@ "Atoms(symbols='C2OH6', pbc=False)" ] }, - "execution_count": 23, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -480,7 +515,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 17, "id": "26914421", "metadata": {}, "outputs": [ @@ -498,7 +533,7 @@ " [ 1.84879848e+00, -2.86324036e-02, -5.25690230e-01]])} ...]>" ] }, - "execution_count": 24, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -510,7 +545,7 @@ }, { 
"cell_type": "code", - "execution_count": 28, + "execution_count": 18, "id": "bdcb03b7", "metadata": {}, "outputs": [], @@ -533,7 +568,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 19, "id": "0fd134a9", "metadata": {}, "outputs": [ @@ -541,15 +576,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'energy': array(-4210.10608608), 'forces': array([[-0.40810014, 1.44508323, -0.08648853],\n", - " [ 2.57021069, -1.84409073, -0.43325531],\n", - " [ 1.60649781, -0.1155006 , -0.78838767],\n", - " [ 0.05678861, -0.22312301, -0.18917786],\n", - " [ 0.07157569, -0.02854657, 0.07422443],\n", - " [-0.47191066, 0.49506172, -0.81499557],\n", - " [-1.67661084, 1.06380911, 0.68716817],\n", - " [-0.22691964, -0.66149807, 0.58455748],\n", - " [-1.52153114, -0.13119507, 0.96635494]])}\n" + "{'energy': array(-4210.1056349), 'forces': array([[-0.41029084, 1.445263 , -0.09325541],\n", + " [ 2.5735703 , -1.8340608 , -0.41830564],\n", + " [ 1.602669 , -0.08252943, -0.7950784 ],\n", + " [ 0.07904099, -0.23876965, -0.19849436],\n", + " [ 0.05621693, -0.04351348, 0.08634967],\n", + " [-0.46894395, 0.5051009 , -0.827881 ],\n", + " [-1.6850219 , 1.0685279 , 0.7068002 ],\n", + " [-0.21787213, -0.68336594, 0.5771504 ],\n", + " [-1.5293683 , -0.13665256, 0.9627145 ]], dtype=float32)}\n" ] } ], @@ -571,7 +606,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 20, "id": "ade3f668", "metadata": {}, "outputs": [ @@ -579,27 +614,27 @@ "name": "stdout", "output_type": "stream", "text": [ - "Energy per atom: Epot = -467.790eV Ekin = 0.048eV (T=372K) Etot = -467.741eV\n", - "Energy per atom: Epot = -467.801eV Ekin = 0.060eV (T=463K) Etot = -467.741eV\n", - "Energy per atom: Epot = -467.812eV Ekin = 0.071eV (T=546K) Etot = -467.742eV\n", - "Energy per atom: Epot = -467.806eV Ekin = 0.063eV (T=491K) Etot = -467.742eV\n", - "Energy per atom: Epot = -467.797eV Ekin = 0.057eV (T=438K) Etot = -467.741eV\n", - "Energy per atom: Epot = 
-467.805eV Ekin = 0.062eV (T=481K) Etot = -467.742eV\n", - "Energy per atom: Epot = -467.794eV Ekin = 0.053eV (T=408K) Etot = -467.741eV\n", - "Energy per atom: Epot = -467.780eV Ekin = 0.040eV (T=311K) Etot = -467.740eV\n", - "Energy per atom: Epot = -467.791eV Ekin = 0.050eV (T=390K) Etot = -467.741eV\n", - "Energy per atom: Epot = -467.793eV Ekin = 0.051eV (T=394K) Etot = -467.742eV\n", - "Energy per atom: Epot = -467.777eV Ekin = 0.037eV (T=288K) Etot = -467.740eV\n", - "Energy per atom: Epot = -467.793eV Ekin = 0.051eV (T=398K) Etot = -467.742eV\n", - "Energy per atom: Epot = -467.806eV Ekin = 0.065eV (T=501K) Etot = -467.742eV\n", - "Energy per atom: Epot = -467.798eV Ekin = 0.057eV (T=438K) Etot = -467.741eV\n", - "Energy per atom: Epot = -467.797eV Ekin = 0.055eV (T=427K) Etot = -467.742eV\n", - "Energy per atom: Epot = -467.809eV Ekin = 0.067eV (T=520K) Etot = -467.742eV\n", - "Energy per atom: Epot = -467.787eV Ekin = 0.048eV (T=371K) Etot = -467.739eV\n", - "Energy per atom: Epot = -467.799eV Ekin = 0.058eV (T=450K) Etot = -467.741eV\n", - "Energy per atom: Epot = -467.797eV Ekin = 0.055eV (T=424K) Etot = -467.742eV\n", - "Energy per atom: Epot = -467.796eV Ekin = 0.054eV (T=417K) Etot = -467.742eV\n", - "Energy per atom: Epot = -467.793eV Ekin = 0.051eV (T=398K) Etot = -467.742eV\n" + "Energy per atom: Epot = -467.790eV Ekin = 0.047eV (T=364K) Etot = -467.742eV\n", + "Energy per atom: Epot = -467.799eV Ekin = 0.055eV (T=429K) Etot = -467.743eV\n", + "Energy per atom: Epot = -467.804eV Ekin = 0.061eV (T=470K) Etot = -467.743eV\n", + "Energy per atom: Epot = -467.809eV Ekin = 0.066eV (T=511K) Etot = -467.743eV\n", + "Energy per atom: Epot = -467.812eV Ekin = 0.068eV (T=527K) Etot = -467.744eV\n", + "Energy per atom: Epot = -467.797eV Ekin = 0.053eV (T=413K) Etot = -467.744eV\n", + "Energy per atom: Epot = -467.797eV Ekin = 0.054eV (T=416K) Etot = -467.743eV\n", + "Energy per atom: Epot = -467.809eV Ekin = 0.065eV (T=507K) Etot = -467.743eV\n", + "Energy 
per atom: Epot = -467.795eV Ekin = 0.052eV (T=402K) Etot = -467.743eV\n", + "Energy per atom: Epot = -467.786eV Ekin = 0.043eV (T=335K) Etot = -467.742eV\n", + "Energy per atom: Epot = -467.801eV Ekin = 0.058eV (T=450K) Etot = -467.743eV\n", + "Energy per atom: Epot = -467.787eV Ekin = 0.044eV (T=342K) Etot = -467.742eV\n", + "Energy per atom: Epot = -467.798eV Ekin = 0.056eV (T=432K) Etot = -467.742eV\n", + "Energy per atom: Epot = -467.807eV Ekin = 0.064eV (T=495K) Etot = -467.743eV\n", + "Energy per atom: Epot = -467.799eV Ekin = 0.056eV (T=436K) Etot = -467.743eV\n", + "Energy per atom: Epot = -467.801eV Ekin = 0.058eV (T=446K) Etot = -467.743eV\n", + "Energy per atom: Epot = -467.803eV Ekin = 0.061eV (T=471K) Etot = -467.742eV\n", + "Energy per atom: Epot = -467.778eV Ekin = 0.037eV (T=289K) Etot = -467.741eV\n", + "Energy per atom: Epot = -467.791eV Ekin = 0.050eV (T=387K) Etot = -467.741eV\n", + "Energy per atom: Epot = -467.804eV Ekin = 0.063eV (T=484K) Etot = -467.742eV\n", + "Energy per atom: Epot = -467.800eV Ekin = 0.058eV (T=450K) Etot = -467.742eV\n" ] } ], @@ -644,7 +679,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.5" } }, "nbformat": 4, diff --git a/docs/source/layers.ipynb b/docs/source/layers.ipynb index 11d952b1..1abade87 100644 --- a/docs/source/layers.ipynb +++ b/docs/source/layers.ipynb @@ -28,6 +28,7 @@ " * `mlp` Multi-layer perceptron for graphs.\n", " * `modules` Keras layers and modules to support ragged tensor input.\n", " * `norm` Normalization layers for graph tensors.\n", + " * `polynom` Layers for Polynomials.\n", " * `pooling` General layers for standard aggregation and pooling.\n", " * `relational` Relational message processing.\n", " * `scale` Scaling layer to (constantly) rescale e.g. 
graph output.\n", @@ -84,7 +85,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.5" } }, "nbformat": 4, diff --git a/kgcnn/layers/polynom.py b/kgcnn/layers/polynom.py new file mode 100644 index 00000000..53aa18ac --- /dev/null +++ b/kgcnn/layers/polynom.py @@ -0,0 +1,242 @@ +import numpy as np +import scipy as sp +import scipy.special +from keras import ops +from scipy.optimize import brentq + + +def spherical_bessel_jn(r, n): + r"""Compute spherical Bessel function :math:`j_n(r)` via scipy. + The spherical bessel functions and there properties can be looked up at + https://en.wikipedia.org/wiki/Bessel_function#Spherical_Bessel_functions . + + Args: + r (np.ndarray): Argument + n (np.ndarray, int): Order. + + Returns: + np.array: Values of the spherical Bessel function + """ + return np.sqrt(np.pi / (2 * r)) * sp.special.jv(n + 0.5, r) + + +def spherical_bessel_jn_zeros(n, k): + r"""Compute the first :math:`k` zeros of the spherical bessel functions :math:`j_n(r)` up to + order :math:`n` (excluded). + Taken from the original implementation of DimeNet at https://github.com/klicperajo/dimenet. + + Args: + n: Order. + k: Number of zero crossings. + + Returns: + np.ndarray: List of zero crossings of shape (n, k) + """ + zerosj = np.zeros((n, k), dtype="float32") + zerosj[0] = np.arange(1, k + 1) * np.pi + points = np.arange(1, k + n) * np.pi + racines = np.zeros(k + n - 1, dtype="float32") + for i in range(1, n): + for j in range(k + n - 1 - i): + foo = brentq(spherical_bessel_jn, points[j], points[j + 1], (i,)) + racines[j] = foo + points = racines + zerosj[i][:k] = racines[:k] + + return zerosj + + +def spherical_bessel_jn_normalization_prefactor(n, k): + r"""Compute the normalization or rescaling pre-factor for the spherical bessel functions :math:`j_n(r)` up to + order :math:`n` (excluded) and maximum frequency :math:`k` (excluded). 
+ Taken from the original implementation of DimeNet at https://github.com/klicperajo/dimenet. + + Args: + n: Order. + k: frequency. + + Returns: + np.ndarray: Normalization of shape (n, k) + """ + zeros = spherical_bessel_jn_zeros(n, k) + normalizer = [] + for order in range(n): + normalizer_tmp = [] + for i in range(k): + normalizer_tmp += [0.5 * spherical_bessel_jn(zeros[order, i], order + 1) ** 2] + normalizer_tmp = 1 / np.array(normalizer_tmp) ** 0.5 + normalizer += [normalizer_tmp] + return np.array(normalizer) + + +def tf_spherical_bessel_jn_explicit(x, n=0): + r"""Compute spherical bessel functions :math:`j_n(x)` for constant positive integer :math:`n` explicitly. + TensorFlow has to cache the function for each :math:`n`. No gradient through :math:`n` or very large number + of :math:`n`'s is possible. + The spherical bessel functions and there properties can be looked up at + https://en.wikipedia.org/wiki/Bessel_function#Spherical_Bessel_functions. + For this implementation the explicit expression from https://dlmf.nist.gov/10.49 has been used. + The definition is: + + :math:`a_{k}(n+\tfrac{1}{2})=\begin{cases}\dfrac{(n+k)!}{2^{k}k!(n-k)!},&k=0,1,\dotsc,n\\ + 0,&k=n+1,n+2,\dotsc\end{cases}` + + :math:`\mathsf{j}_{n}\left(z\right)=\sin\left(z-\tfrac{1}{2}n\pi\right)\sum_{k=0}^{\left\lfloor n/2\right\rfloor} + (-1)^{k}\frac{a_{2k}(n+\tfrac{1}{2})}{z^{2k+1}}+\cos\left(z-\tfrac{1}{2}n\pi\right) + \sum_{k=0}^{\left\lfloor(n-1)/2\right\rfloor}(-1)^{k}\frac{a_{2k+1}(n+\tfrac{1}{2})}{z^{2k+2}}.` + + Args: + x (tf.Tensor): Values to compute :math:`j_n(x)` for. + n (int): Positive integer for the bessel order :math:`n`. 
+ + Returns: + tf.Tensor: Spherical bessel function of order :math:`n` + """ + sin_x = ops.sin(x - n * np.pi / 2) + cos_x = ops.cos(x - n * np.pi / 2) + sum_sin = ops.zeros_like(x) + sum_cos = ops.zeros_like(x) + for k in range(int(np.floor(n / 2)) + 1): + if 2 * k < n + 1: + prefactor = float(sp.special.factorial(n + 2 * k) / np.power(2, 2 * k) / sp.special.factorial( + 2 * k) / sp.special.factorial(n - 2 * k) * np.power(-1, k)) + sum_sin += prefactor * ops.power(x, - (2 * k + 1)) + for k in range(int(np.floor((n - 1) / 2)) + 1): + if 2 * k + 1 < n + 1: + prefactor = float(sp.special.factorial(n + 2 * k + 1) / np.power(2, 2 * k + 1) / sp.special.factorial( + 2 * k + 1) / sp.special.factorial(n - 2 * k - 1) * np.power(-1, k)) + sum_cos += prefactor * ops.power(x, - (2 * k + 2)) + return sum_sin * sin_x + sum_cos * cos_x + + +def tf_spherical_bessel_jn(x, n=0): + r"""Compute spherical bessel functions :math:`j_n(x)` for constant positive integer :math:`n` via recursion. + TensorFlow has to cache the function for each :math:`n`. No gradient through :math:`n` or very large number + of :math:`n` is possible. + The spherical bessel functions and there properties can be looked up at + https://en.wikipedia.org/wiki/Bessel_function#Spherical_Bessel_functions. + The recursive rule is constructed from https://dlmf.nist.gov/10.51. The recursive definition is: + + :math:`j_{n+1}(z)=((2n+1)/z)j_{n}(z)-j_{n-1}(z)` + + :math:`j_{0}(x)=\frac{\sin x}{x}` + + :math:`j_{1}(x)=\frac{1}{x}\frac{\sin x}{x} - \frac{\cos x}{x}` + + :math:`j_{2}(x)=\left(\frac{3}{x^{2}} - 1\right)\frac{\sin x}{x} - \frac{3}{x}\frac{\cos x}{x}` + + Args: + x (tf.Tensor): Values to compute :math:`j_n(x)` for. + n (int): Positive integer for the bessel order :math:`n`. 
+ + Returns: + tf.tensor: Spherical bessel function of order :math:`n` + """ + if n < 0: + raise ValueError("Order parameter must be >= 0 for this implementation of spherical bessel function.") + if n == 0: + return ops.sin(x) / x + elif n == 1: + return ops.sin(x) / ops.square(x) - ops.cos(x) / x + else: + j_n = ops.sin(x) / x + j_nn = ops.sin(x) / ops.square(x) - ops.cos(x) / x + for i in range(1, n): + temp = j_nn + j_nn = (2 * i + 1) / x * j_nn - j_n + j_n = temp + return j_nn + + +def tf_legendre_polynomial_pn(x, n=0): + r"""Compute the (non-associated) Legendre polynomial :math:`P_n(x)` for constant positive integer :math:`n` + via explicit formula. + TensorFlow has to cache the function for each :math:`n`. No gradient through :math:`n` or very large number + of :math:`n` is possible. + Closed form can be viewed at https://en.wikipedia.org/wiki/Legendre_polynomials. + + :math:`P_n(x)=\sum_{k=0}^{\lfloor n/2\rfloor} (-1)^k \frac{(2n - 2k)! \, }{(n-k)! \, (n-2k)! \, k! \, 2^n} x^{n-2k}` + + Args: + x (tf.Tensor): Values to compute :math:`P_n(x)` for. + n (int): Positive integer for :math:`n` in :math:`P_n(x)`. + + Returns: + tf.tensor: Legendre Polynomial of order :math:`n`. + """ + out_sum = ops.zeros_like(x) + prefactors = [ + float((-1) ** k * sp.special.factorial(2 * n - 2 * k) / sp.special.factorial(n - k) / sp.special.factorial( + n - 2 * k) / sp.special.factorial(k) / 2 ** n) for k in range(0, int(np.floor(n / 2)) + 1)] + powers = [float(n - 2 * k) for k in range(0, int(np.floor(n / 2)) + 1)] + for i in range(len(powers)): + out_sum = out_sum + prefactors[i] * ops.power(x, powers[i]) + return out_sum + + +def tf_spherical_harmonics_yl(theta, l=0): + r"""Compute the spherical harmonics :math:`Y_{ml}(\cos\theta)` for :math:`m=0` and constant non-integer :math:`l`. + TensorFlow has to cache the function for each :math:`l`. No gradient through :math:`l` or very large number + of :math:`n` is possible. 
Uses a simplified formula with :math:`m=0` from + https://en.wikipedia.org/wiki/Spherical_harmonics: + + :math:`Y_{l}^{m}(\theta ,\phi)=\sqrt{\frac{(2l+1)}{4\pi} \frac{(l -m)!}{(l +m)!}} \, P_{l}^{m}(\cos{\theta }) \, + e^{i m \phi}` + + where the associated Legendre polynomial simplifies to :math:`P_l(x)` for :math:`m=0`: + + :math:`P_n(x)=\sum_{k=0}^{\lfloor n/2\rfloor} (-1)^k \frac{(2n - 2k)! \, }{(n-k)! \, (n-2k)! \, k! \, 2^n} x^{n-2k}` + + Args: + theta (tf.Tensor): Values to compute :math:`Y_l(\cos\theta)` for. + l (int): Positive integer for :math:`l` in :math:`Y_l(\cos\theta)`. + + Returns: + tf.tensor: Spherical harmonics for :math:`m=0` and constant non-integer :math:`l`. + """ + x = ops.cos(theta) + out_sum = ops.zeros_like(x) + prefactors = [ + float((-1) ** k * sp.special.factorial(2 * l - 2 * k) / sp.special.factorial(l - k) / sp.special.factorial( + l - 2 * k) / sp.special.factorial(k) / 2 ** l) for k in range(0, int(np.floor(l / 2)) + 1)] + powers = [float(l - 2 * k) for k in range(0, int(np.floor(l / 2)) + 1)] + for i in range(len(powers)): + out_sum = out_sum + prefactors[i] * ops.power(x, powers[i]) + out_sum = out_sum * float(np.sqrt((2 * l + 1) / 4 / np.pi)) + return out_sum + + +def tf_associated_legendre_polynomial(x, l=0, m=0): + r"""Compute the associated Legendre polynomial :math:`P_{l}^{m}(x)` for :math:`m` and constant positive + integer :math:`l` via explicit formula. + Closed Form from taken from https://en.wikipedia.org/wiki/Associated_Legendre_polynomials. + + :math:`P_{l}^{m}(x)=(-1)^{m}\cdot 2^{l}\cdot (1-x^{2})^{m/2}\cdot \sum_{k=m}^{l}\frac{k!}{(k-m)!}\cdot x^{k-m} + \cdot \binom{l}{k}\binom{\frac{l+k-1}{2}}{l}`. + + Args: + x (tf.Tensor): Values to compute :math:`P_{l}^{m}(x)` for. + l (int): Positive integer for :math:`l` in :math:`P_{l}^{m}(x)`. + m (int): Positive/Negative integer for :math:`m` in :math:`P_{l}^{m}(x)`. + + Returns: + tf.tensor: Legendre Polynomial of order n. 
+ """ + if np.abs(m) > l: + raise ValueError("Error: Legendre polynomial must have -l<= m <= l") + if l < 0: + raise ValueError("Error: Legendre polynomial must have l>=0") + if m < 0: + m = -m + neg_m = float(np.power(-1, m) * sp.special.factorial(l - m) / sp.special.factorial(l + m)) + else: + neg_m = 1 + + x_prefactor = ops.power(1 - ops.square(x), m / 2) * float(np.power(-1, m) * np.power(2, l)) + sum_out = ops.zeros_like(x) + for k in range(m, l + 1): + sum_out += ops.power(x, k - m) * float( + sp.special.factorial(k) / sp.special.factorial(k - m) * sp.special.binom(l, k) * + sp.special.binom((l + k - 1) / 2, l)) + + return sum_out * x_prefactor * neg_m \ No newline at end of file diff --git a/kgcnn/literature/DGIN/__init__.py b/kgcnn/literature/DGIN/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kgcnn/literature/DGIN/_layers.py b/kgcnn/literature/DGIN/_layers.py new file mode 100644 index 00000000..16d4f745 --- /dev/null +++ b/kgcnn/literature/DGIN/_layers.py @@ -0,0 +1,121 @@ +from keras import ops +from kgcnn.layers.gather import GatherNodesOutgoing, GatherEdgesPairs +from kgcnn.layers.aggr import AggregateLocalEdges +from keras.layers import Subtract, Add, Layer + + +class DMPNNPPoolingEdgesDirected(Layer): # noqa + """Pooling of edges for around a target node as defined by + + `DMPNN `__ . This is slightly different as the normal node + aggregation from message passing like networks. Requires edge pair indices for this implementation. + """ + + def __init__(self, **kwargs): + """Initialize layer.""" + super(DMPNNPPoolingEdgesDirected, self).__init__(**kwargs) + self.pool_edge_1 = AggregateLocalEdges(pooling_method="scatter_sum") + self.gather_edges = GatherNodesOutgoing() + self.gather_pairs = GatherEdgesPairs() + self.subtract_layer = Subtract() + + def build(self, input_shape): + super(DMPNNPPoolingEdgesDirected, self).build(input_shape) + # Could call build on sub-layers but is not necessary. 
class GIN_D(Layer):
    r"""Convolutional unit of the Graph Isomorphism Network ("How Powerful are
    Graph Neural Networks?"), modified to re-inject the initial node state.

    Computes the graph convolution at step :math:`k` for node embeddings :math:`h_\nu` as:

    .. math::

        h_\nu^{(k)} = \phi^{(k)} \left( (1+\epsilon^{(k)}) h_\nu^{0} + \sum_{u\in N(\nu)} h_u^{(k-1)} \right)

    with optional learnable :math:`\epsilon^{(k)}`.

    .. note::

        The non-linear mapping :math:`\phi^{(k)}`, usually an :obj:`MLP`, is not included in this layer.
    """

    def __init__(self,
                 pooling_method='sum',
                 epsilon_learnable=False,
                 **kwargs):
        """Initialize layer.

        Args:
            epsilon_learnable (bool): If epsilon is learnable or just constant zero. Default is False.
            pooling_method (str): Pooling method for summing edges. Default is 'sum'.
        """
        super().__init__(**kwargs)
        self.pooling_method = pooling_method
        self.epsilon_learnable = epsilon_learnable

        # Sub-layers.
        self.lay_gather = GatherNodesOutgoing()
        self.lay_pool = AggregateLocalEdges(pooling_method=self.pooling_method)
        self.lay_add = Add()

        # Scalar epsilon, zero-initialised; only trained when requested.
        self.eps_k = self.add_weight(
            name="epsilon_k",
            trainable=self.epsilon_learnable,
            initializer="zeros",
            dtype=self.dtype,
        )

    def build(self, input_shape):
        """Build layer."""
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        r"""Forward pass.

        Args:
            inputs: [nodes, edge_index, nodes_0]

                - nodes (Tensor): Node embeddings of shape `([N], F)`
                - edge_index (Tensor): Edge indices referring to nodes of shape `(2, [M])`
                - nodes_0 (Tensor): Node embeddings of shape `([N], F)`

        Returns:
            Tensor: Node embeddings of shape `([N], F)`
        """
        # Need to check if edge_index is full and not half (directed).
        current, edge_index, initial = inputs
        # Neighbour features per edge, then summed for each node connection.
        per_edge = self.lay_gather([current, edge_index], **kwargs)
        neighbour_sum = self.lay_pool([current, per_edge, edge_index], **kwargs)
        # Modified to use the initial node state instead of the current one (equation 7 in paper).
        one = ops.convert_to_tensor(1, dtype=self.eps_k.dtype)
        scaled_initial = (one + self.eps_k) * initial
        return self.lay_add([scaled_initial, neighbour_sum], **kwargs)

    def get_config(self):
        """Update config."""
        config = super().get_config()
        config["pooling_method"] = self.pooling_method
        config["epsilon_learnable"] = self.epsilon_learnable
        return config
# Version string attached to every model built by this module.
__model_version__ = "2023-10-23"

# Supported backends
__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"]

# Fail at import time when the active keras backend is not in the list above.
if backend_to_use() not in __kgcnn_model_backend_supported__:
    raise NotImplementedError("Backend '%s' for model 'DGIN' is not supported." % backend_to_use())

# Implementation of DGIN in `keras` from paper:
# Analyzing Learned Molecular Representations for Property Prediction
# by Oliver Wieder, Mélaine Kuenemann, Marcus Wieder, Thomas Seidel,
# Christophe Meyer, Sharon D Bryant and Thierry Langer
# https://pubmed.ncbi.nlm.nih.gov/34684766/

# Default hyperparameters; `make_model` merges user kwargs with this dict
# through the `update_model_kwargs` decorator.
model_default = {
    "name": "DGIN",
    # Order of the input specs must match the model definition.
    "inputs": [
        {"shape": (None,), "name": "node_number", "dtype": "int64"},
        {"shape": (None,), "name": "edge_number", "dtype": "int64"},
        {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"},
        {"shape": (None, 1), "name": "edge_indices_reverse", "dtype": "int64"},
        {"shape": (), "name": "total_nodes", "dtype": "int64"},
        {"shape": (), "name": "total_edges", "dtype": "int64"},
        {"shape": (), "name": "total_reverse", "dtype": "int64"}
    ],
    "input_tensor_type": "padded",
    "cast_disjoint_kwargs": {},
    "input_embedding": None,  # deprecated
    "input_node_embedding": {"input_dim": 95, "output_dim": 64},
    "input_edge_embedding": {"input_dim": 5, "output_dim": 64},
    "input_graph_embedding": {"input_dim": 100, "output_dim": 64},
    # MLP applied after each GIN_D convolution.
    "gin_mlp": {"units": [64, 64], "use_bias": True, "activation": ["relu", "linear"],
                "use_normalization": True, "normalization_technique": "graph_batch"},
    "gin_args": {},
    "last_mlp": {"use_bias": [True, True], "units": [64, 64],
                 "activation": ["relu", "relu"]},
    "pooling_args": {"pooling_method": "sum"},
    "use_graph_state": False,
    # Dense layers of the directed message-passing (DMPNN) part.
    "edge_initialize": {"units": 128, "use_bias": True, "activation": "relu"},
    "edge_dense": {"units": 128, "use_bias": True, "activation": "linear"},
    "edge_activation": {"activation": "relu"},
    "verbose": 10,
    "depthDMPNN": 4,
    "depthGIN": 4,
    "dropoutDMPNN": {"rate": 0.15},
    "dropoutGIN": {"rate": 0.15},
    "output_embedding": "graph",
    "node_pooling_kwargs": {},
    "output_to_tensor": None,  # deprecated
    "output_tensor_type": "padded",
    "output_mlp": {"use_bias": True, "units": 1,
                   "activation": "linear"},
    "output_scaling": None,
}
@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
def make_model(name: str = None,
               inputs: list = None,
               input_tensor_type: str = None,
               cast_disjoint_kwargs: dict = None,
               input_embedding: dict = None,  # noqa
               input_node_embedding: dict = None,
               input_edge_embedding: dict = None,
               input_graph_embedding: dict = None,
               pooling_args: dict = None,
               edge_initialize: dict = None,
               edge_dense: dict = None,
               edge_activation: dict = None,
               dropoutDMPNN: dict = None,  # noqa
               dropoutGIN: dict = None,  # noqa
               depthDMPNN: int = None,  # noqa
               depthGIN: int = None,  # noqa
               gin_args: dict = None,
               gin_mlp: dict = None,
               last_mlp: dict = None,
               verbose: int = None,
               node_pooling_kwargs: dict = None,
               use_graph_state: bool = False,
               output_embedding: str = None,
               output_to_tensor: bool = None,  # noqa
               output_tensor_type: str = None,
               output_mlp: dict = None,
               output_scaling: dict = None
               ):
    r"""Make the DGIN graph network via functional API.
    Default parameters can be found in :obj:`kgcnn.literature.DGIN.model_default` .

    **Model inputs**:
    Model uses the list template of inputs and standard output template.
    The supported inputs are :obj:`[nodes, edges, edge_indices, reverse_indices, (graph_state), ...]`
    with '...' indicating mask or id tensors following the template below.
    Here, reverse indices are in place of angle indices and refer to edges. The graph state is optional and
    controlled by the `use_graph_state` parameter.

    %s

    **Model outputs**:
    The standard output template:

    %s

    Args:
        name (str): Name of the model. Should be "DGIN".
        inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition.
        input_tensor_type (str): Input type of graph tensor. Default is "padded".
        cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers if used.
        input_embedding (dict): Deprecated in favour of input_node_embedding etc.
        input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers.
        input_edge_embedding (dict): Dictionary of arguments for edges unpacked in :obj:`Embedding` layers.
        input_graph_embedding (dict): Dictionary of arguments for the graph state unpacked in
            :obj:`Embedding` layers.
        pooling_args (dict): Dictionary of layer arguments unpacked in :obj:`AggregateLocalEdges` layers.
        edge_initialize (dict): Dictionary of layer arguments unpacked in :obj:`Dense` layer for first edge embedding.
        edge_dense (dict): Dictionary of layer arguments unpacked in :obj:`Dense` layer for edge embedding.
        edge_activation (dict): Edge activation after skip connection.
        dropoutDMPNN (dict): Dictionary of layer arguments unpacked in :obj:`Dropout` for the DMPNN subnetwork.
        dropoutGIN (dict): Dictionary of layer arguments unpacked in :obj:`Dropout` for the GIN subnetwork.
        depthDMPNN (int): Number of graph embedding units or depth of the DMPNN subnetwork.
        depthGIN (int): Number of graph embedding units or depth of the GIN subnetwork.
        gin_args (dict): Kwargs unpacked in :obj:`GIN_D` convolutional unit.
        gin_mlp (dict): Kwargs unpacked in :obj:`MLP` for GIN layer.
        last_mlp (dict): Kwargs unpacked in last :obj:`MLP` .
        verbose (int): Level for print information.
        node_pooling_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`PoolingNodes` layer.
        use_graph_state (bool): Whether to use graph state information. Default is False.
        output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
        output_to_tensor (bool): Deprecated in favour of `output_tensor_type` .
        output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".
        output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
            Defines number of model outputs and activation.
        output_scaling (dict): Kwargs for scaling layer, if scaling layer is to be used.

    Returns:
        :obj:`keras.models.Model`
    """
    # Keras input placeholders, one per entry of `inputs`.
    model_inputs = [Input(**spec) for spec in inputs]

    disjoint = list(template_cast_list_input(
        model_inputs, input_tensor_type=input_tensor_type,
        cast_disjoint_kwargs=cast_disjoint_kwargs,
        has_nodes=True, has_edges=True,
        has_graph_state=use_graph_state,
        has_angle_indices=True,  # Treat reverse indices as edge indices
        has_edge_indices=True
    ))
    if not use_graph_state:
        # Without a graph state the template has one entry less; fill its slot with None.
        disjoint.insert(4, None)
    (n, ed, edi, e_pairs, gs, batch_id_node, batch_id_edge, _, node_id, edge_id, _,
     count_nodes, count_edges, _) = disjoint

    # An embedding layer is only used when it is configured and the matching input has an integer dtype.
    embed_nodes = input_node_embedding is not None and "int" in inputs[0]['dtype']
    embed_edges = input_edge_embedding is not None and "int" in inputs[1]['dtype']
    embed_graph = bool(use_graph_state) and input_graph_embedding is not None and "int" in inputs[4]['dtype']

    # Wrapping disjoint model.
    out = model_disjoint(
        [n, ed, edi, batch_id_node, e_pairs, count_nodes, gs],
        use_node_embedding=embed_nodes,
        use_edge_embedding=embed_edges,
        use_graph_embedding=embed_graph,
        use_graph_state=use_graph_state,
        input_node_embedding=input_node_embedding,
        input_edge_embedding=input_edge_embedding,
        input_graph_embedding=input_graph_embedding,
        edge_initialize=edge_initialize,
        edge_activation=edge_activation,
        edge_dense=edge_dense,
        depthDMPNN=depthDMPNN,
        dropoutDMPNN=dropoutDMPNN,
        pooling_args=pooling_args,
        gin_mlp=gin_mlp,
        depthGIN=depthGIN,
        gin_args=gin_args,
        output_embedding=output_embedding,
        node_pooling_kwargs=node_pooling_kwargs,
        last_mlp=last_mlp,
        dropoutGIN=dropoutGIN,
        output_mlp=output_mlp
    )

    use_scaling = output_scaling is not None
    if use_scaling:
        scaler = get_scaler(output_scaling["name"])(**output_scaling)
        out = scaler(out)

    # Output embedding choice
    out = template_cast_output(
        [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges],
        output_embedding=output_embedding, output_tensor_type=output_tensor_type,
        input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs
    )

    model = ks.models.Model(inputs=model_inputs, outputs=out, name=name)
    model.__kgcnn_model_version__ = __model_version__

    if use_scaling:
        # Expose the scaler's `set_scale` on the model itself.
        def set_scale(*args, **kwargs):
            scaler.set_scale(*args, **kwargs)

        setattr(model, "set_scale", set_scale)

    return model
def model_disjoint(
    inputs,
    use_node_embedding,
    use_edge_embedding,
    use_graph_embedding,
    use_graph_state=None,
    input_node_embedding=None,
    input_edge_embedding=None,
    input_graph_embedding=None,
    edge_initialize=None,
    edge_activation=None,
    edge_dense=None,
    depthDMPNN=None,
    dropoutDMPNN=None,
    pooling_args=None,
    gin_mlp=None,
    depthGIN=None,
    gin_args=None,
    output_embedding=None,
    node_pooling_kwargs=None,
    last_mlp=None,
    dropoutGIN=None,
    output_mlp=None
):
    """Build the DGIN computation on disjoint graph tensors.

    A directed message-passing (DMPNN) stage updates edge states, the result is
    aggregated onto nodes, and a stack of :obj:`GIN_D` convolutions follows.

    Args:
        inputs (list): `[nodes, edges, edge_indices, batch_id_node, edge_pairs, count_nodes, graph_state]`.
            `graph_state` is only read when `use_graph_state` is set.
        use_node_embedding (bool): Whether to pass nodes through :obj:`Embedding`.
        use_edge_embedding (bool): Whether to pass edges through :obj:`Embedding`.
        use_graph_embedding (bool): Whether to pass the graph state through :obj:`Embedding`.
        use_graph_state (bool): Whether a graph state is provided and used.
        input_node_embedding (dict): Kwargs for the node :obj:`Embedding`.
        input_edge_embedding (dict): Kwargs for the edge :obj:`Embedding`.
        input_graph_embedding (dict): Kwargs for the graph state :obj:`Embedding`.
        edge_initialize (dict): Kwargs for the :obj:`Dense` layer creating the first edge state.
        edge_activation (dict): Kwargs for the edge :obj:`Activation`.
        edge_dense (dict): Kwargs for the :obj:`Dense` layer shared by all message steps.
        depthDMPNN (int): Number of directed message-passing steps.
        dropoutDMPNN (dict): Kwargs for :obj:`Dropout` after each message step, or None.
        pooling_args (dict): Kwargs for :obj:`AggregateLocalEdges`.
        gin_mlp (dict): Kwargs for the :obj:`GraphMLP` after each GIN convolution.
        depthGIN (int): Number of GIN convolutions.
        gin_args (dict): Kwargs for :obj:`GIN_D`.
        output_embedding (str): Either "graph" or "node".
        node_pooling_kwargs (dict): Kwargs for :obj:`PoolingNodes`.
        last_mlp (dict): Kwargs for the :obj:`MLP` applied to each pooled embedding.
        dropoutGIN (dict): Kwargs for :obj:`Dropout` on the pooled embeddings, or None.
        output_mlp (dict): Kwargs for the final output MLP.

    Returns:
        Tensor: Graph-level or node-level output, depending on `output_embedding`.

    Raises:
        ValueError: If `output_embedding` is neither "graph" nor "node".
    """
    n, ed, edi, batch_id_node, ed_pairs, count_nodes, graph_state = inputs

    # Embedding, if no feature dimension
    if use_node_embedding:
        n = Embedding(**input_node_embedding)(n)
    if use_edge_embedding:
        ed = Embedding(**input_edge_embedding)(ed)
    if use_graph_state and use_graph_embedding:
        graph_state = Embedding(**input_graph_embedding)(graph_state)

    # Make first edge hidden h0 step 1
    h_n0 = GatherNodesOutgoing()([n, edi])
    h0 = Concatenate(axis=-1)([h_n0, ed])
    h0 = Dense(**edge_initialize)(h0)
    h0 = Activation(**edge_activation)(h0)  # relu equation 1

    # One Dense layer is shared by all message steps; in DGIN they are independent.
    edge_dense_all = Dense(**edge_dense)  # see equation 3 comments

    # Model Loop steps 2 & 3
    h = h0
    for _ in range(depthDMPNN):
        # equation 2
        m_vw = DMPNNPPoolingEdgesDirected()([n, h, edi, ed_pairs])  # ed_pairs for Directed Pooling!
        # equation 3
        h = edge_dense_all(m_vw)
        h = Add()([h, h0])
        # remark : dropout before Activation in DGIN code
        h = Activation(**edge_activation)(h)
        if dropoutDMPNN is not None:
            h = Dropout(**dropoutDMPNN)(h)

    # equation 4 & 5
    m_v = AggregateLocalEdges(**pooling_args)([n, h, edi])
    m_v = Concatenate(axis=-1)([n, m_v])
    # equation 5b: hv = Dense(**node_dense)(mv) removed based on the paper

    # GIN_D part (this projection is normally not done in DGIN, but we need to get the correct "dim")
    n_units = gin_mlp["units"][-1] if isinstance(gin_mlp["units"], list) else int(gin_mlp["units"])
    h_v = Dense(n_units, use_bias=True, activation='linear')(m_v)
    h_v_0 = h_v

    list_embeddings = [h_v_0]  # empty in the paper
    for _ in range(depthGIN):
        # equation 6 & 7a: neighbours are pooled via edi, the initial state h_v_0 is re-injected
        h_v = GIN_D(**gin_args)([h_v, edi, h_v_0])
        h_v = GraphMLP(**gin_mlp)([h_v, batch_id_node, count_nodes])  # equation 7b
        list_embeddings.append(h_v)

    # Output embedding choice; the paper seems to use only the last h_v, here all are summed.
    if output_embedding == 'graph':
        # NOTE(review): graph_state is not used in this branch even if use_graph_state is set.
        out = [
            PoolingNodes(**node_pooling_kwargs)([count_nodes, x, batch_id_node]) for x in list_embeddings
        ]  # will return tensor equation 8
        out = [MLP(**last_mlp)(x) for x in out]  # MLP apply per depthGIN
        if dropoutGIN is not None:
            out = [Dropout(**dropoutGIN)(x) for x in out]
        out = Add()(out)
        out = MLP(**output_mlp)(out)
    elif output_embedding == 'node':
        # Use the final GIN node representation. The raw input embedding `n` is never
        # updated above, so feeding it here would bypass the whole network.
        node_out = h_v
        if use_graph_state:
            graph_state_node = GatherState()([graph_state, batch_id_node])
            node_out = Concatenate()([node_out, graph_state_node])
        out = GraphMLP(**output_mlp)([node_out, batch_id_node, count_nodes])
    else:
        raise ValueError("Unsupported graph embedding for mode `DGIN` .")

    return out
class DimNetInteractionPPBlock(Layer):
    """DimNetPP Interaction Block as defined by DimNetPP.

    Args:
        emb_size: Embedding size used for the messages
        int_emb_size (int): Embedding size used for interaction triplets
        basis_emb_size: Embedding size used inside the basis transformation
        num_before_skip: Number of residual layers in interaction block before skip connection
        num_after_skip: Number of residual layers in interaction block after skip connection
        use_bias (bool, optional): Use bias. Defaults to True. Stored for the config only;
            the bias usage of the sub-layers is fixed in this implementation.
        pooling_method (str): Pooling method information for layer. Default is 'sum'.
        activation (str): Activation function. Default is "kgcnn>swish".
        kernel_regularizer: Kernel regularization. Default is None.
        bias_regularizer: Bias regularization. Default is None.
        activity_regularizer: Activity regularization. Default is None.
        kernel_constraint: Kernel constrains. Default is None.
        bias_constraint: Bias constrains. Default is None.
        kernel_initializer: Initializer for kernels. Default is 'kgcnn>glorot_orthogonal'.
        bias_initializer: Initializer for bias. Default is 'zeros'.
    """

    def __init__(self, emb_size,
                 int_emb_size,
                 basis_emb_size,
                 num_before_skip,
                 num_after_skip,
                 use_bias=True,
                 pooling_method="sum",
                 activation='kgcnn>swish',  # default is 'kgcnn>swish'
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 kernel_initializer="kgcnn>glorot_orthogonal",  # default is 'kgcnn>glorot_orthogonal'
                 bias_initializer='zeros',
                 **kwargs):
        super(DimNetInteractionPPBlock, self).__init__(**kwargs)
        self.use_bias = use_bias
        self.pooling_method = pooling_method
        self.emb_size = emb_size
        self.int_emb_size = int_emb_size
        self.basis_emb_size = basis_emb_size
        self.num_before_skip = num_before_skip
        self.num_after_skip = num_after_skip
        kernel_args = {"kernel_regularizer": kernel_regularizer, "activity_regularizer": activity_regularizer,
                       "bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint,
                       "bias_constraint": bias_constraint, "kernel_initializer": kernel_initializer,
                       "bias_initializer": bias_initializer}

        # Transformations of Bessel and spherical basis representations
        self.dense_rbf1 = Dense(basis_emb_size, use_bias=False, **kernel_args)
        self.dense_rbf2 = Dense(emb_size, use_bias=False, **kernel_args)
        self.dense_sbf1 = Dense(basis_emb_size, use_bias=False, **kernel_args)
        self.dense_sbf2 = Dense(int_emb_size, use_bias=False, **kernel_args)

        # Dense transformations of input messages
        self.dense_ji = Dense(emb_size, activation=activation, use_bias=True, **kernel_args)
        self.dense_kj = Dense(emb_size, activation=activation, use_bias=True, **kernel_args)

        # Embedding projections for interaction triplets
        self.down_projection = Dense(int_emb_size, activation=activation, use_bias=False, **kernel_args)
        self.up_projection = Dense(emb_size, activation=activation, use_bias=False, **kernel_args)

        # Residual layers before skip connection
        self.layers_before_skip = [
            ResidualLayer(emb_size, activation=activation, use_bias=True, **kernel_args)
            for _ in range(num_before_skip)
        ]
        self.final_before_skip = Dense(emb_size, activation=activation, use_bias=True, **kernel_args)

        # Residual layers after skip connection
        self.layers_after_skip = [
            ResidualLayer(emb_size, activation=activation, use_bias=True, **kernel_args)
            for _ in range(num_after_skip)
        ]

        self.lay_add1 = Add()
        self.lay_add2 = Add()
        self.lay_mult1 = Multiply()
        self.lay_mult2 = Multiply()

        self.lay_gather = GatherNodesOutgoing()  # Are edges here
        self.lay_pool = AggregateLocalEdges(pooling_method=pooling_method)

    def call(self, inputs, **kwargs):
        """Forward pass.

        Args:
            inputs: [edges, rbf, sbf, angle_index]

                - edges (Tensor): Edge embeddings of shape ([M], F)
                - rbf (Tensor): Radial basis features of shape ([M], F)
                - sbf (Tensor): Spherical basis features of shape ([K], F)
                - angle_index (Tensor): Angle indices referring to two edges of shape (2, [K])

        Returns:
            Tensor: Updated edge embeddings.
        """
        x, rbf, sbf, id_expand = inputs

        # Initial transformation
        x_ji = self.dense_ji(x, **kwargs)
        x_kj = self.dense_kj(x, **kwargs)

        # Transform via Bessel basis
        rbf = self.dense_rbf1(rbf, **kwargs)
        rbf = self.dense_rbf2(rbf, **kwargs)
        x_kj = self.lay_mult1([x_kj, rbf], **kwargs)

        # Down-project embeddings and generate interaction triplet embeddings
        x_kj = self.down_projection(x_kj, **kwargs)
        x_kj = self.lay_gather([x_kj, id_expand], **kwargs)

        # Transform via 2D spherical basis. Uses the dedicated second Multiply
        # layer, which was previously created but never called.
        sbf = self.dense_sbf1(sbf, **kwargs)
        sbf = self.dense_sbf2(sbf, **kwargs)
        x_kj = self.lay_mult2([x_kj, sbf], **kwargs)

        # Aggregate interactions and up-project embeddings; the transformed rbf
        # is passed as the first (reference) input of the pooling layer.
        x_kj = self.lay_pool([rbf, x_kj, id_expand], **kwargs)
        x_kj = self.up_projection(x_kj, **kwargs)

        # Transformations before skip connection
        x2 = self.lay_add1([x_ji, x_kj], **kwargs)
        for layer in self.layers_before_skip:
            x2 = layer(x2, **kwargs)
        x2 = self.final_before_skip(x2, **kwargs)

        # Skip connection
        x = self.lay_add2([x, x2], **kwargs)

        # Transformations after skip connection
        for layer in self.layers_after_skip:
            x = layer(x, **kwargs)

        return x

    def get_config(self):
        """Update config with the block sizes and the shared Dense arguments."""
        config = super(DimNetInteractionPPBlock, self).get_config()
        config.update({"use_bias": self.use_bias, "pooling_method": self.pooling_method, "emb_size": self.emb_size,
                       "int_emb_size": self.int_emb_size, "basis_emb_size": self.basis_emb_size,
                       "num_before_skip": self.num_before_skip, "num_after_skip": self.num_after_skip})
        # All Dense sub-layers share these arguments, so read them from one of them.
        conf_dense = self.dense_ji.get_config()
        for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
                  "bias_constraint", "kernel_initializer", "bias_initializer", "activation"]:
            config.update({x: conf_dense[x]})
        return config
class DimNetOutputBlock(Layer):
    """DimNetPP Output Block as defined by DimNetPP.

    Args:
        emb_size (int): Units of the dense layer applied to the distance basis.
        out_emb_size (int): Units of the up-projection and of each MLP layer.
        num_dense (int): Number of dense layers in the MLP.
        num_targets (int): Number of output target dimension. Defaults to 12.
        use_bias (bool, optional): Use bias in the MLP. Defaults to True.
        output_kernel_initializer: Initializer for last kernel. Default is 'zeros'.
        kernel_initializer: Initializer for kernels. Default is 'kgcnn>glorot_orthogonal'.
        bias_initializer: Initializer for bias. Default is 'zeros'.
        activation (str): Activation function. Default is 'kgcnn>swish'.
        kernel_regularizer: Kernel regularization. Default is None.
        bias_regularizer: Bias regularization. Default is None.
        activity_regularizer: Activity regularization. Default is None.
        kernel_constraint: Kernel constrains. Default is None.
        bias_constraint: Bias constrains. Default is None.
        pooling_method (str): Pooling method information for layer. Default is 'sum'.
    """

    def __init__(self, emb_size,
                 out_emb_size,
                 num_dense,
                 num_targets=12,
                 use_bias=True,
                 output_kernel_initializer="zeros", kernel_initializer='kgcnn>glorot_orthogonal',
                 bias_initializer='zeros',
                 activation='kgcnn>swish',
                 kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None,
                 kernel_constraint=None, bias_constraint=None,
                 pooling_method="sum",
                 **kwargs):
        """Initialize layer."""
        super().__init__(**kwargs)
        self.pooling_method = pooling_method
        self.emb_size = emb_size
        self.out_emb_size = out_emb_size
        self.num_dense = num_dense
        self.num_targets = num_targets
        self.use_bias = use_bias

        # Arguments shared by every dense sub-layer; the kernel initializer is passed separately.
        shared_args = dict(
            kernel_regularizer=kernel_regularizer, activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint, bias_initializer=bias_initializer,
            bias_regularizer=bias_regularizer, bias_constraint=bias_constraint,
        )

        self.dense_rbf = Dense(emb_size, use_bias=False, kernel_initializer=kernel_initializer, **shared_args)
        self.up_projection = Dense(out_emb_size, use_bias=False, kernel_initializer=kernel_initializer,
                                   **shared_args)
        self.dense_mlp = GraphMLP([out_emb_size] * num_dense, activation=activation,
                                  kernel_initializer=kernel_initializer, use_bias=use_bias, **shared_args)
        self.dimnet_mult = Multiply()
        self.pool = AggregateLocalEdges(pooling_method=self.pooling_method)
        self.dense_final = Dense(num_targets, use_bias=False, kernel_initializer=output_kernel_initializer,
                                 **shared_args)

    def build(self, input_shape):
        """Build layer."""
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        """Forward pass.

        Args:
            inputs: [nodes, edges, rbf, tensor_index]

                - nodes (Tensor): Node embeddings of shape ([N], F)
                - edges (Tensor): Edge or message embeddings of shape ([M], F)
                - rbf (Tensor): Edge distance basis of shape ([M], F)
                - tensor_index (Tensor): Edge indices referring to nodes of shape (2, [M])

        Returns:
            Tensor: Updated node embeddings of shape ([N], F_T).
        """
        nodes, messages, rbf, edge_index = inputs
        # Gate the edge messages with the transformed distance basis.
        gate = self.dense_rbf(rbf, **kwargs)
        gated = self.dimnet_mult([gate, messages], **kwargs)
        # Pool edges onto nodes, then project and map to the targets.
        pooled = self.pool([nodes, gated, edge_index], **kwargs)
        hidden = self.up_projection(pooled, **kwargs)
        hidden = self.dense_mlp(hidden, **kwargs)
        return self.dense_final(hidden, **kwargs)

    def get_config(self):
        """Update config."""
        config = super().get_config()
        # The MLP config stores one entry per layer; the first entry is used.
        mlp_config = self.dense_mlp.get_config()
        for key in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
                    "bias_constraint", "kernel_initializer", "bias_initializer", "activation"]:
            config[key] = mlp_config[key][0]
        config["output_kernel_initializer"] = self.dense_final.get_config()["kernel_initializer"]
        config["pooling_method"] = self.pooling_method
        config["use_bias"] = self.use_bias
        config["emb_size"] = self.emb_size
        config["out_emb_size"] = self.out_emb_size
        config["num_dense"] = self.num_dense
        config["num_targets"] = self.num_targets
        return config
class EmbeddingDimeBlock(Layer):
    """Custom Embedding Block of DimNetPP.

    Naming of inputs here should match keras Embedding layer.

    Args:
        input_dim (int): Integer. Size of the vocabulary, i.e. maximum integer index + 1.
        output_dim (int): Integer. Dimension of the dense embedding.
        embeddings_initializer: Initializer for the embeddings matrix (see keras.initializers).
        embeddings_regularizer: Regularizer function applied to the embeddings matrix (see keras.regularizers).
        embeddings_constraint: Constraint function applied to the embeddings matrix (see keras.constraints).
    """

    def __init__(self,
                 input_dim,  # Vocabulary
                 output_dim,  # Embedding size
                 embeddings_initializer='uniform',
                 embeddings_regularizer=None,
                 embeddings_constraint=None,
                 **kwargs):
        super().__init__(**kwargs)
        self._supports_ragged_inputs = True
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.embeddings_initializer = ks.initializers.get(embeddings_initializer)
        self.embeddings_regularizer = ks.regularizers.get(embeddings_regularizer)
        self.embeddings_constraint = ks.constraints.get(embeddings_constraint)

        # Original implementation used initializer:
        # embeddings_initializer = {'class_name': 'RandomUniform', 'config': {'minval': -1.7320508075688772,
        # 'maxval': 1.7320508075688772, 'seed': None}}
        # The table is created with one row more than `input_dim`.
        table_shape = (self.input_dim + 1, self.output_dim)
        self.embeddings = self.add_weight(
            name="embeddings",
            shape=table_shape,
            dtype=self.dtype,
            initializer=self.embeddings_initializer,
            regularizer=self.embeddings_regularizer,
            constraint=self.embeddings_constraint,
            trainable=True,
        )

    def call(self, inputs, **kwargs):
        """Embedding of inputs. Forward pass: look up rows of the table by index."""
        return ops.take(self.embeddings, inputs, axis=0)

    def get_config(self):
        """Update config with the serialized embedding arguments."""
        config = super().get_config()
        config["input_dim"] = self.input_dim
        config["output_dim"] = self.output_dim
        config["embeddings_initializer"] = ks.initializers.serialize(self.embeddings_initializer)
        config["embeddings_regularizer"] = ks.regularizers.serialize(self.embeddings_regularizer)
        config["embeddings_constraint"] = ks.constraints.serialize(self.embeddings_constraint)
        return config
class SphericalBasisLayer(Layer):
    r"""Expand distances and angles into a combined spherical Bessel / spherical
    harmonics basis (harmonics with :math:`m=0`), according to Klicpera et al. 2020.

    Args:
        num_spherical (int): Number of spherical basis functions
        num_radial (int): Number of radial basis functions
        cutoff (float): Cutoff distance c
        envelope_exponent (int): Degree of the envelope to smoothen at cutoff. Default is 5.
    """

    def __init__(self, num_spherical,
                 num_radial,
                 cutoff,
                 envelope_exponent=5,
                 **kwargs):
        super(SphericalBasisLayer, self).__init__(**kwargs)
        # The helpers below are not imported at module level in this file.
        # NOTE(review): assumed to be provided by `kgcnn.layers.polynom` - confirm.
        from kgcnn.layers.polynom import (
            spherical_bessel_jn_zeros, spherical_bessel_jn_normalization_prefactor)

        if num_radial > 64:
            raise ValueError("`num_radial` must be <= 64, but got %s." % num_radial)
        self.num_radial = int(num_radial)
        self.num_spherical = num_spherical
        self.cutoff = cutoff
        self.inv_cutoff = ops.convert_to_tensor(1.0 / cutoff, dtype=self.dtype)
        self.envelope_exponent = envelope_exponent

        # Zeros and normalization of the spherical Bessel functions, indexed [n][k].
        self.bessel_n_zeros = spherical_bessel_jn_zeros(num_spherical, num_radial)
        self.bessel_norm = spherical_bessel_jn_normalization_prefactor(num_spherical, num_radial)

        self.layer_gather_out = GatherNodesOutgoing()

    def envelope(self, inputs):
        """Polynomial envelope of the scaled distance; zero for inputs >= 1."""
        p = self.envelope_exponent + 1
        a = -(p + 1) * (p + 2) / 2
        b = p * (p + 2)
        c = -p * (p + 1) / 2
        env_val = 1 / inputs + a * inputs ** (p - 1) + b * inputs ** p + c * inputs ** (p + 1)
        return ops.where(inputs < 1, env_val, ops.zeros_like(inputs))

    def call(self, inputs, **kwargs):
        """Forward pass.

        Args:
            inputs: [distance, angles, angle_index]

                - distance (Tensor): Edge distance of shape ([M], 1)
                - angles (Tensor): Angle list of shape ([K], 1)
                - angle_index (Tensor): Angle indices referring to edges of shape (2, [K])

        Returns:
            Tensor: Expanded angle/distance basis. Shape is ([K], #Radial * #Spherical)
        """
        # NOTE(review): assumed to be provided by `kgcnn.layers.polynom` - confirm.
        from kgcnn.layers.polynom import tf_spherical_bessel_jn, tf_spherical_harmonics_yl

        # Inputs are plain tensors; no ragged `.values` / `.row_splits` access.
        edge, angles, angle_index = inputs

        d_scaled = edge[:, 0] * self.inv_cutoff
        # One radial function per (spherical order n, radial index k), n-major order.
        rbf = [
            self.bessel_norm[n, k] * tf_spherical_bessel_jn(d_scaled * self.bessel_n_zeros[n][k], n)
            for n in range(self.num_spherical)
            for k in range(self.num_radial)
        ]
        rbf = ops.stack(rbf, axis=1)

        d_cutoff = self.envelope(d_scaled)
        rbf_env = d_cutoff[:, None] * rbf
        # Bring the per-edge radial part onto the angles.
        rbf_env = self.layer_gather_out([rbf_env, angle_index], **kwargs)

        cbf = [tf_spherical_harmonics_yl(angles[:, 0], n) for n in range(self.num_spherical)]
        cbf = ops.stack(cbf, axis=1)
        # Repeat each angular function num_radial times to match the n-major radial layout.
        cbf = ops.repeat(cbf, self.num_radial, axis=1)
        out = rbf_env * cbf

        return out

    def get_config(self):
        """Update config."""
        config = super(SphericalBasisLayer, self).get_config()
        config.update({"num_radial": self.num_radial, "cutoff": self.cutoff,
                       "envelope_exponent": self.envelope_exponent, "num_spherical": self.num_spherical})
        return config
# Version string for models built by this module.
__model_version__ = "2023-12-04"

# Supported backends
__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"]
# Fail at import time when the active keras backend is not in the list above.
if backend_to_use() not in __kgcnn_model_backend_supported__:
    raise NotImplementedError("Backend '%s' for model 'DimeNetPP' is not supported." % backend_to_use())

# Implementation of DimeNet++ in `keras` from paper:
# Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules
# Johannes Klicpera, Shankari Giri, Johannes T. Margraf, Stephan Günnemann
# https://arxiv.org/abs/2011.14115
# Original code: https://github.com/gasteigerjo/dimenet

# Default hyperparameters; `make_model` merges user kwargs with this dict
# through the `update_model_kwargs` decorator.
model_default = {
    "name": "DimeNetPP",
    # NOTE(review): input_tensor_type is "padded" but, unlike the DGIN defaults in this
    # patch, no total_nodes/total_edges/total_angles count inputs are listed - confirm.
    "inputs": [
        {"shape": [None], "name": "node_number", "dtype": "int64"},
        {"shape": [None, 3], "name": "node_coordinates", "dtype": "float32"},
        {"shape": [None, 2], "name": "edge_indices", "dtype": "int64"},
        {"shape": [None, 2], "name": "angle_indices", "dtype": "int64"},

    ],
    "input_tensor_type": "padded",
    "input_embedding": None,  # deprecated
    "cast_disjoint_kwargs": {},
    # The uniform range +-1.7320508075688772 equals +-sqrt(3).
    "input_node_embedding": {
        "input_dim": 95, "output_dim": 128, "embeddings_initializer": {
            "class_name": "RandomUniform",
            "config": {"minval": -1.7320508075688772, "maxval": 1.7320508075688772}}
    },
    "emb_size": 128, "out_emb_size": 256, "int_emb_size": 64, "basis_emb_size": 8,
    "num_blocks": 4, "num_spherical": 7, "num_radial": 6,
    "cutoff": 5.0, "envelope_exponent": 5,
    "num_before_skip": 1, "num_after_skip": 2, "num_dense_output": 3,
    "num_targets": 64, "extensive": True, "output_init": "zeros",
    "activation": "swish", "verbose": 10,
    "output_embedding": "graph",
    "use_output_mlp": True,
    "output_tensor_type": "padded",
    "output_scaling": None,
    "output_mlp": {"use_bias": [True, False],
                   "units": [64, 12], "activation": ["swish", "linear"]}
}
+ cast_disjoint_kwargs: dict = None, + input_embedding: dict = None, # noqa + input_node_embedding: dict = None, + emb_size: int = None, + out_emb_size: int = None, + int_emb_size: int = None, + basis_emb_size: int = None, + num_blocks: int = None, + num_spherical: int = None, + num_radial: int = None, + cutoff: float = None, + envelope_exponent: int = None, + num_before_skip: int = None, + num_after_skip: int = None, + num_dense_output: int = None, + num_targets: int = None, + activation: str = None, + extensive: bool = None, + output_init: str = None, + verbose: int = None, # noqa + name: str = None, + output_embedding: str = None, + output_tensor_type: str = None, + use_output_mlp: bool = None, + output_mlp: dict = None, + output_scaling: dict = None + ): + """Make `DimeNetPP `_ graph network via functional API. + Default parameters can be found in :obj:`kgcnn.literature.DimeNetPP.model_default`. + + Model inputs: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, coordinates, edge_indices, angle_indices...]` + with '...' indicating mask or ID tensors following the template below. + Note that you must supply angle indices as index pairs that refer to two edges. + + %s + + Model outputs: + The standard output template: + + %s + + + Args: + inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. + input_tensor_type (str): Input type of graph tensor. Default is "padded". + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer. + input_embedding (dict): Deprecated in favour of input_node_embedding etc. + input_node_embedding (dict): Dictionary of embedding arguments for nodes unpacked in :obj:`Embedding` layers. + emb_size (int): Overall embedding size used for the messages. + out_emb_size (int): Embedding size for output of :obj:`DimNetOutputBlock`. + int_emb_size (int): Embedding size used for interaction triplets. 
+ basis_emb_size (int): Embedding size used inside the basis transformation. + num_blocks (int): Number of graph embedding blocks or depth of the network. + num_spherical (int): Number of spherical components in :obj:`SphericalBasisLayer`. + num_radial (int): Number of radial components in basis layer. + cutoff (float): Distance cutoff for basis layer. + envelope_exponent (int): Exponent in envelope function for basis layer. + num_before_skip (int): Number of residual layers in interaction block before skip connection + num_after_skip (int): Number of residual layers in interaction block after skip connection + num_dense_output (int): Number of dense units in output :obj:`DimNetOutputBlock`. + num_targets (int): Number of targets or output embedding dimension of the model. + activation (str, dict): Activation to use. + extensive (bool): Graph output for extensive target to apply sum for pooling or mean otherwise. + output_init (str, dict): Output initializer for kernel. + verbose (int): Level of verbosity. + name (str): Name of the model. + output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". + use_output_mlp (bool): Whether to use the final output MLP. Possibility to skip final :obj:`MLP`. + output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. + Defines number of model outputs and activation. Note that DimeNetPP originally defines the output dimension + via `num_targets`. But this can be set to `out_emb_size` and the `output_mlp` be used for more + specific control. + output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. + output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". 
+ + Returns: + :obj:`keras.models.Model` + """ + # Make input + model_inputs = [Input(**x) for x in inputs] + + dj = template_cast_list_input( + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + has_edges=False, + has_nodes=2, + has_angles=True, + ) + + n, x, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj + + out = model_disjoint( + [n, x, disjoint_indices, batch_id_node, count_nodes], + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + input_node_embedding=input_node_embedding, + emb_size=emb_size, + out_emb_size=out_emb_size, + int_emb_size=int_emb_size, + basis_emb_size=basis_emb_size, + num_blocks=num_blocks, + num_spherical=num_spherical, + num_radial=num_radial, + cutoff=cutoff, + envelope_exponent=envelope_exponent, + num_before_skip=num_before_skip, + num_after_skip=num_after_skip, + num_dense_output=num_dense_output, + num_targets=num_targets, + activation=activation, + extensive=extensive, + output_init=output_init, + use_output_mlp=use_output_mlp, + output_embedding=output_embedding, + output_mlp=output_mlp + ) + + if output_scaling is not None: + scaler = get_scaler(output_scaling["name"])(**output_scaling) + if scaler.extensive: + # Node information must be numbers, or we need an additional input. 
+ out = scaler([out, n, batch_id_node]) + else: + out = scaler(out) + + # Output embedding choice + out = template_cast_output( + [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges], + output_embedding=output_embedding, output_tensor_type=output_tensor_type, + input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, + ) + + model = ks.models.Model(inputs=model_inputs, outputs=out, name=name) + + model.__kgcnn_model_version__ = __model_version__ + + if output_scaling is not None: + def set_scale(*args, **kwargs): + scaler.set_scale(*args, **kwargs) + + setattr(model, "set_scale", set_scale) + + return model + + +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) + +model_crystal_default = { + "name": "DimeNetPP", + "inputs": [ + {"shape": [None], "name": "node_number", "dtype": "int64", "ragged": True}, + {"shape": [None, 3], "name": "node_coordinates", "dtype": "float32", "ragged": True}, + {"shape": [None, 2], "name": "edge_indices", "dtype": "int64", "ragged": True}, + {"shape": [None, 2], "name": "angle_indices", "dtype": "int64", "ragged": True}, + {'shape': (None, 3), 'name': "edge_image", 'dtype': 'int64', 'ragged': True}, + {'shape': (3, 3), 'name': "graph_lattice", 'dtype': 'float32', 'ragged': False} + ], + "input_tensor_type": "ragged", + "input_embedding": None, # deprecated + "cast_disjoint_kwargs": {}, + "input_node_embedding": { + "input_dim": 95, "output_dim": 128, "embeddings_initializer": { + "class_name": "RandomUniform", + "config": {"minval": -1.7320508075688772, "maxval": 1.7320508075688772}} + }, + "emb_size": 128, "out_emb_size": 256, "int_emb_size": 64, "basis_emb_size": 8, + "num_blocks": 4, "num_spherical": 7, "num_radial": 6, + "cutoff": 5.0, "envelope_exponent": 5, + "num_before_skip": 1, "num_after_skip": 2, "num_dense_output": 3, + "num_targets": 64, "extensive": True, "output_init": "zeros", + "activation": "swish", "verbose": 
@update_model_kwargs(model_crystal_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
def make_crystal_model(inputs: list = None,
                       input_tensor_type: str = None,
                       cast_disjoint_kwargs: dict = None,
                       input_embedding: dict = None,  # noqa
                       input_node_embedding: dict = None,
                       emb_size: int = None,
                       out_emb_size: int = None,
                       int_emb_size: int = None,
                       basis_emb_size: int = None,
                       num_blocks: int = None,
                       num_spherical: int = None,
                       num_radial: int = None,
                       cutoff: float = None,
                       envelope_exponent: int = None,
                       num_before_skip: int = None,
                       num_after_skip: int = None,
                       num_dense_output: int = None,
                       num_targets: int = None,
                       activation: str = None,
                       extensive: bool = None,
                       output_init: str = None,
                       verbose: int = None,  # noqa
                       name: str = None,
                       output_embedding: str = None,
                       output_tensor_type: str = None,
                       use_output_mlp: bool = None,
                       output_mlp: dict = None,
                       output_scaling: dict = None
                       ):
    """Make `DimeNetPP <https://arxiv.org/abs/2011.14115>`__ graph network via functional API.
    Default parameters can be found in :obj:`kgcnn.literature.DimeNetPP.model_crystal_default`.

    .. note::

        DimeNetPP does require a large amount of memory for this implementation, which increase quickly with
        the number of connections in a batch. Use ragged input or dataloader if possible.

    **Model inputs**:
    Model uses the list template of inputs and standard output template.
    The supported inputs are :obj:`[nodes, coordinates, edge_indices, angle_indices, image_translation, lattice, ...]`
    with '...' indicating mask or ID tensors following the template below.
    Note that you must supply angle indices as index pairs that refer to two edges.

    %s

    **Model outputs**:
    The standard output template:

    %s


    Args:
        inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition.
        input_tensor_type (str): Input type of graph tensor. Default is "ragged".
        cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer.
        input_embedding (dict): Deprecated in favour of input_node_embedding etc.
        input_node_embedding (dict): Dictionary of embedding arguments for nodes unpacked in :obj:`Embedding` layers.
        emb_size (int): Overall embedding size used for the messages.
        out_emb_size (int): Embedding size for output of :obj:`DimNetOutputBlock`.
        int_emb_size (int): Embedding size used for interaction triplets.
        basis_emb_size (int): Embedding size used inside the basis transformation.
        num_blocks (int): Number of graph embedding blocks or depth of the network.
        num_spherical (int): Number of spherical components in :obj:`SphericalBasisLayer`.
        num_radial (int): Number of radial components in basis layer.
        cutoff (float): Distance cutoff for basis layer.
        envelope_exponent (int): Exponent in envelope function for basis layer.
        num_before_skip (int): Number of residual layers in interaction block before skip connection.
        num_after_skip (int): Number of residual layers in interaction block after skip connection.
        num_dense_output (int): Number of dense units in output :obj:`DimNetOutputBlock`.
        num_targets (int): Number of targets or output embedding dimension of the model.
        activation (str, dict): Activation to use.
        extensive (bool): Graph output for extensive target to apply sum for pooling or mean otherwise.
        output_init (str, dict): Output initializer for kernel.
        verbose (int): Level of verbosity.
        name (str): Name of the model.
        output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
        use_output_mlp (bool): Whether to use the final output MLP. Possibility to skip final :obj:`MLP`.
        output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
            Defines number of model outputs and activation. Note that DimeNetPP originally defines the output dimension
            via `num_targets`. But this can be set to `out_emb_size` and the `output_mlp` be used for more
            specific control.
        output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None.
        output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".

    Returns:
        :obj:`keras.models.Model`
    """
    # Make input
    model_inputs = [Input(**x) for x in inputs]

    # Cast the list of model inputs (including edge image and lattice) into the disjoint representation.
    disjoint_inputs = template_cast_list_input(
        model_inputs, input_tensor_type=input_tensor_type,
        cast_disjoint_kwargs=cast_disjoint_kwargs,
        has_edges=False,
        has_nodes=2,
        has_angles=True,
        has_crystal_input=2
    )

    # NOTE(review): no angle index tensor is unpacked although `has_angles=True` is requested -
    # confirm the return order of `template_cast_list_input`.
    n, x, djx, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = disjoint_inputs

    # Wrap disjoint model
    # NOTE(review): eight tensors are passed here, while `model_disjoint_crystal` in `_model.py` unpacks nine
    # (`n, x, edi, adi, edge_image, lattice, batch_id_node, batch_id_edge, count_nodes`) - TODO confirm.
    out = model_disjoint_crystal(
        [n, x, djx, img, lattice, batch_id_node, batch_id_edge, count_nodes],
        # Embedding is only applied to integer node input and only if embedding arguments are given.
        use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False,
        input_node_embedding=input_node_embedding,
        emb_size=emb_size,
        out_emb_size=out_emb_size,
        int_emb_size=int_emb_size,
        basis_emb_size=basis_emb_size,
        num_blocks=num_blocks,
        num_spherical=num_spherical,
        num_radial=num_radial,
        cutoff=cutoff,
        envelope_exponent=envelope_exponent,
        num_before_skip=num_before_skip,
        num_after_skip=num_after_skip,
        num_dense_output=num_dense_output,
        num_targets=num_targets,
        activation=activation,
        extensive=extensive,
        output_init=output_init,
        use_output_mlp=use_output_mlp,
        output_embedding=output_embedding,
        output_mlp=output_mlp
    )

    if output_scaling is not None:
        # The scaler class is looked up by the "name" entry; the full dictionary is passed to its constructor.
        scaler = get_scaler(output_scaling["name"])(**output_scaling)
        if scaler.extensive:
            # Node information must be numbers, or we need an additional input.
            out = scaler([out, n, batch_id_node])
        else:
            out = scaler(out)

    # Output embedding choice
    out = template_cast_output(
        [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges],
        output_embedding=output_embedding, output_tensor_type=output_tensor_type,
        input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs,
    )

    model = ks.models.Model(inputs=model_inputs, outputs=out, name=name)

    model.__kgcnn_model_version__ = __model_version__

    if output_scaling is not None:
        # Expose the scaler's `set_scale` on the model so it can be reached without the layer handle.
        def set_scale(*args, **kwargs):
            scaler.set_scale(*args, **kwargs)

        setattr(model, "set_scale", set_scale)

    return model
= None, + output_embedding: str = None, + output_mlp: dict = None +): + n, x, edi, adi, batch_id_node, count_nodes = inputs + + # Atom embedding + if use_node_embedding: + n = EmbeddingDimeBlock(**input_node_embedding)(n) + + # Calculate distances + pos1, pos2 = NodePosition()([x, edi]) + d = NodeDistanceEuclidean()([pos1, pos2]) + rbf = BesselBasisLayer(num_radial=num_radial, cutoff=cutoff, envelope_exponent=envelope_exponent)(d) + + # Calculate angles + v12 = Subtract()([pos1, pos2]) + a = EdgeAngle()([v12, adi]) + sbf = SphericalBasisLayer(num_spherical=num_spherical, num_radial=num_radial, cutoff=cutoff, + envelope_exponent=envelope_exponent)([d, a, adi]) + + # Embedding block + rbf_emb = Dense(emb_size, use_bias=True, activation=activation, + kernel_initializer="kgcnn>glorot_orthogonal")(rbf) + n_pairs = GatherNodes()([n, edi]) + x = Concatenate(axis=-1)([n_pairs, rbf_emb]) + x = Dense(emb_size, use_bias=True, activation=activation, kernel_initializer="kgcnn>glorot_orthogonal")(x) + ps = DimNetOutputBlock(emb_size, out_emb_size, num_dense_output, num_targets=num_targets, + output_kernel_initializer=output_init)([n, x, rbf, edi]) + + # Interaction blocks + add_xp = Add() + for i in range(num_blocks): + x = DimNetInteractionPPBlock(emb_size, int_emb_size, basis_emb_size, num_before_skip, num_after_skip)( + [x, rbf, sbf, adi]) + p_update = DimNetOutputBlock(emb_size, out_emb_size, num_dense_output, num_targets=num_targets, + output_kernel_initializer=output_init)([n, x, rbf, edi]) + ps = add_xp([ps, p_update]) + + if extensive: + out = PoolingNodes(pooling_method="sum")(ps) + else: + out = PoolingNodes(pooling_method="mean")(ps) + + if use_output_mlp: + out = MLP(**output_mlp)(out) + + if output_embedding != "graph": + raise ValueError("Unsupported output embedding for mode `DimeNetPP`. 
def model_disjoint_crystal(
    inputs,
    use_node_embedding,
    input_node_embedding: dict = None,
    emb_size: int = None,
    out_emb_size: int = None,
    int_emb_size: int = None,
    basis_emb_size: int = None,
    num_blocks: int = None,
    num_spherical: int = None,
    num_radial: int = None,
    cutoff: float = None,
    envelope_exponent: int = None,
    num_before_skip: int = None,
    num_after_skip: int = None,
    num_dense_output: int = None,
    num_targets: int = None,
    activation: str = None,
    extensive: bool = None,
    output_init: str = None,
    use_output_mlp: bool = None,
    output_embedding: str = None,
    output_mlp: dict = None
):
    """Build the DimeNet++ computation on disjoint tensors with periodic image shifts.

    Args:
        inputs (list): Nine tensors in the order `[n, x, edi, adi, edge_image, lattice, batch_id_node,
            batch_id_edge, count_nodes]`, i.e. nodes, coordinates, edge indices, angle indices, edge image,
            lattice, batch assignment of nodes and edges and node counts.
        use_node_embedding (bool): Whether to pass the node input through :obj:`EmbeddingDimeBlock`.
        input_node_embedding (dict): Arguments unpacked in :obj:`EmbeddingDimeBlock`.
        emb_size (int): Embedding size of the edge messages.
        out_emb_size (int): Embedding size passed to :obj:`DimNetOutputBlock`.
        int_emb_size (int): Embedding size passed to :obj:`DimNetInteractionPPBlock`.
        basis_emb_size (int): Basis embedding size passed to :obj:`DimNetInteractionPPBlock`.
        num_blocks (int): Number of interaction blocks.
        num_spherical (int): Number of spherical components in :obj:`SphericalBasisLayer`.
        num_radial (int): Number of radial components in the basis layers.
        cutoff (float): Distance cutoff of the basis layers.
        envelope_exponent (int): Envelope exponent of the basis layers.
        num_before_skip (int): Passed to :obj:`DimNetInteractionPPBlock`.
        num_after_skip (int): Passed to :obj:`DimNetInteractionPPBlock`.
        num_dense_output (int): Passed to :obj:`DimNetOutputBlock`.
        num_targets (int): Output dimension of :obj:`DimNetOutputBlock`.
        activation (str, dict): Activation of the embedding dense layers.
        extensive (bool): Pool node outputs with "sum" if true and "mean" otherwise.
        output_init (str, dict): Kernel initializer of the output blocks.
        use_output_mlp (bool): Whether to apply the final :obj:`MLP`.
        output_embedding (str): Only "graph" is supported.
        output_mlp (dict): Arguments unpacked in the final :obj:`MLP`.

    Returns:
        Tensor: Graph-level output.

    Raises:
        ValueError: If `output_embedding` is not "graph".
    """
    # Fail before any layer is created; only graph-level output is implemented.
    if output_embedding != "graph":
        raise ValueError("Unsupported output embedding for mode `DimeNetPP`. ")

    # NOTE(review): `make_crystal_model` currently passes eight tensors while nine are unpacked here - confirm.
    n, x, edi, adi, edge_image, lattice, batch_id_node, batch_id_edge, count_nodes = inputs

    # Atom embedding
    if use_node_embedding:
        n = EmbeddingDimeBlock(**input_node_embedding)(n)

    # Calculate distances; the second edge position is shifted by its periodic image.
    pos1, pos2 = NodePosition()([x, edi])
    pos2 = ShiftPeriodicLattice()([pos2, edge_image, lattice, batch_id_edge])
    d = NodeDistanceEuclidean()([pos1, pos2])
    rbf = BesselBasisLayer(num_radial=num_radial, cutoff=cutoff, envelope_exponent=envelope_exponent)(d)

    # Calculate angles
    v12 = Subtract()([pos1, pos2])
    a = EdgeAngle()([v12, adi])
    sbf = SphericalBasisLayer(num_spherical=num_spherical, num_radial=num_radial, cutoff=cutoff,
                              envelope_exponent=envelope_exponent)([d, a, adi])

    # Embedding block
    rbf_emb = Dense(emb_size, use_bias=True, activation=activation,
                    kernel_initializer="kgcnn>glorot_orthogonal")(rbf)
    n_pairs = GatherNodes()([n, edi])
    x = Concatenate(axis=-1)([n_pairs, rbf_emb])
    x = Dense(emb_size, use_bias=True, activation=activation, kernel_initializer="kgcnn>glorot_orthogonal")(x)
    ps = DimNetOutputBlock(emb_size, out_emb_size, num_dense_output, num_targets=num_targets,
                           output_kernel_initializer=output_init)([n, x, rbf, edi])

    # Interaction blocks; the same `Add` layer accumulates the output of every block.
    add_xp = Add()
    for _ in range(num_blocks):
        x = DimNetInteractionPPBlock(emb_size, int_emb_size, basis_emb_size, num_before_skip, num_after_skip)(
            [x, rbf, sbf, adi])
        p_update = DimNetOutputBlock(emb_size, out_emb_size, num_dense_output, num_targets=num_targets,
                                     output_kernel_initializer=output_init)([n, x, rbf, edi])
        ps = add_xp([ps, p_update])

    # Pool node outputs per graph. The layer is given reference, attributes and batch assignment,
    # matching how `PoolingNodes` is called for disjoint tensors in the EGNN model of this package.
    pooling_method = "sum" if extensive else "mean"
    out = PoolingNodes(pooling_method=pooling_method)([count_nodes, ps, batch_id_node])

    if use_output_mlp:
        out = MLP(**output_mlp)(out)

    return out
% backend_to_use()) + +# Implementation of EGNN in `keras` from paper: +# E(n) Equivariant Graph Neural Networks +# by Victor Garcia Satorras, Emiel Hoogeboom, Max Welling (2021) +# https://arxiv.org/abs/2102.09844 + + +model_default = { + "name": "EGNN", + "inputs": [ + {"shape": (None,), "name": "node_number", "dtype": "int64", "ragged": True}, + {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32", "ragged": True}, + {"shape": (None, 10), "name": "edge_attributes", "dtype": "float32", "ragged": True}, + {"shape": (None, 2), "name": "edge_indices", "dtype": "int64", "ragged": True}, + ], + "input_tensor_type": "padded", + "cast_disjoint_kwargs": {}, + "input_embedding": None, + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_edge_embedding": {"input_dim": 95, "output_dim": 64}, + "depth": 4, + "node_mlp_initialize": None, + "euclidean_norm_kwargs": {"keepdims": True, "axis": 2}, + "use_edge_attributes": True, + "edge_mlp_kwargs": {"units": [64, 64], "activation": ["swish", "linear"]}, + "edge_attention_kwargs": None, # {"units: 1", "activation": "sigmoid"} + "use_normalized_difference": False, + "expand_distance_kwargs": None, + "coord_mlp_kwargs": {"units": [64, 1], "activation": ["swish", "linear"]}, # option: "tanh" at the end. 
@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
def make_model(name: str = None,
               inputs: list = None,
               input_tensor_type: str = None,
               cast_disjoint_kwargs: dict = None,
               input_embedding: dict = None,  # noqa
               input_node_embedding: dict = None,
               input_edge_embedding: dict = None,
               depth: int = None,
               euclidean_norm_kwargs: dict = None,
               node_mlp_initialize: dict = None,
               use_edge_attributes: bool = None,
               edge_mlp_kwargs: dict = None,
               edge_attention_kwargs: dict = None,
               use_normalized_difference: bool = None,
               expand_distance_kwargs: dict = None,
               coord_mlp_kwargs: dict = None,
               pooling_coord_kwargs: dict = None,
               pooling_edge_kwargs: dict = None,
               node_normalize_kwargs: dict = None,
               use_node_attributes: bool = None,
               node_mlp_kwargs: dict = None,
               use_skip: bool = None,
               verbose: int = None,  # noqa
               node_decoder_kwargs: dict = None,
               node_pooling_kwargs: dict = None,
               output_embedding: str = None,
               output_to_tensor: bool = None,  # noqa
               output_mlp: dict = None,
               output_tensor_type: str = None,
               output_scaling: dict = None
               ):
    r"""Make `EGNN <https://arxiv.org/abs/2102.09844>`__ graph network via functional API.
    Default parameters can be found in :obj:`kgcnn.literature.EGNN.model_default`.

    **Model inputs**:
    Model uses the list template of inputs and standard output template.
    The supported inputs are :obj:`[nodes, node_coordinates, edge_attributes, edge_indices, ...]`
    with '...' indicating mask or ID tensors following the template below.

    %s

    **Model outputs**:
    The standard output template:

    %s

    Args:
        name (str): Name of the model. Default is "EGNN".
        inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition.
        input_tensor_type (str): Input type of graph tensor. Default is "padded".
        cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers if used.
        input_embedding (dict): Deprecated in favour of input_node_embedding etc.
        input_node_embedding (dict): Dictionary of arguments for nodes unpacked in :obj:`Embedding` layers.
        input_edge_embedding (dict): Dictionary of arguments for edge unpacked in :obj:`Embedding` layers.
        depth (int): Number of graph embedding units or depth of the network.
        euclidean_norm_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`EuclideanNorm`.
        node_mlp_initialize (dict): Dictionary of layer arguments unpacked in :obj:`GraphMLP` layer for start embedding.
        use_edge_attributes (bool): Whether to use edge attributes including for example further edge information.
        edge_mlp_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`GraphMLP` layer.
        edge_attention_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`GraphMLP` layer.
        use_normalized_difference (bool): Whether to use a normalized difference vector for nodes.
        expand_distance_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`PositionEncodingBasisLayer`.
        coord_mlp_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`GraphMLP` layer.
        pooling_coord_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`AggregateLocalEdges` that
            aggregates the coordinate updates.
        pooling_edge_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`AggregateLocalEdges` that
            aggregates the edge messages.
        node_normalize_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`GraphLayerNormalization` layer.
        use_node_attributes (bool): Whether to add node attributes before node MLP.
        node_mlp_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`GraphMLP` layer of the node update.
        use_skip (bool): Whether to add the node update to the previous node embedding.
        verbose (int): Level of verbosity.
        node_decoder_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`MLP` layer after graph network.
        node_pooling_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`PoolingNodes` layers.
        output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
        output_to_tensor (bool): Deprecated in favour of `output_tensor_type` .
        output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".
        output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
            Defines number of model outputs and activation.
        output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers.

    Returns:
        :obj:`keras.models.Model`
    """
    # One symbolic input per entry of the input definition.
    model_inputs = [Input(**input_kwargs) for input_kwargs in inputs]

    # Disjoint representation of the inputs plus the bookkeeping tensors needed to cast back.
    disjoint_inputs = template_cast_list_input(
        model_inputs,
        input_tensor_type=input_tensor_type,
        cast_disjoint_kwargs=cast_disjoint_kwargs,
        has_edges=True,
        has_nodes=2
    )
    (n, x, ed, disjoint_indices, batch_id_node, batch_id_edge,
     node_id, edge_id, count_nodes, count_edges) = disjoint_inputs

    # Embeddings are used for integer inputs only, and only if embedding arguments were given.
    use_node_embedding = input_node_embedding is not None and "int" in inputs[0]['dtype']
    use_edge_embedding = input_edge_embedding is not None and "int" in inputs[2]['dtype']

    out = model_disjoint(
        [n, x, ed, disjoint_indices, batch_id_node, batch_id_edge, count_nodes, count_edges],
        use_node_embedding=use_node_embedding,
        use_edge_embedding=use_edge_embedding,
        input_node_embedding=input_node_embedding,
        input_edge_embedding=input_edge_embedding,
        depth=depth,
        euclidean_norm_kwargs=euclidean_norm_kwargs,
        node_mlp_initialize=node_mlp_initialize,
        use_edge_attributes=use_edge_attributes,
        edge_mlp_kwargs=edge_mlp_kwargs,
        edge_attention_kwargs=edge_attention_kwargs,
        use_normalized_difference=use_normalized_difference,
        expand_distance_kwargs=expand_distance_kwargs,
        coord_mlp_kwargs=coord_mlp_kwargs,
        pooling_coord_kwargs=pooling_coord_kwargs,
        pooling_edge_kwargs=pooling_edge_kwargs,
        node_normalize_kwargs=node_normalize_kwargs,
        use_node_attributes=use_node_attributes,
        node_mlp_kwargs=node_mlp_kwargs,
        use_skip=use_skip,
        node_decoder_kwargs=node_decoder_kwargs,
        node_pooling_kwargs=node_pooling_kwargs,
        output_embedding=output_embedding,
        output_mlp=output_mlp
    )

    if output_scaling is not None:
        scaler = get_scaler(output_scaling["name"])(**output_scaling)
        # Extensive scalers additionally take the node input and its batch assignment.
        out = scaler([out, n, batch_id_node]) if scaler.extensive else scaler(out)

    # Cast the disjoint output back to the requested output tensor type.
    out = template_cast_output(
        [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges],
        output_embedding=output_embedding, output_tensor_type=output_tensor_type,
        input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs,
    )

    model = ks.models.Model(inputs=model_inputs, outputs=out, name=name)
    model.__kgcnn_model_version__ = __model_version__

    if output_scaling is not None:
        # Forward `set_scale` of the scaling layer through the model object.
        def set_scale(*args, **kwargs):
            scaler.set_scale(*args, **kwargs)

        model.set_scale = set_scale

    return model
def model_disjoint(
    inputs,
    use_node_embedding,
    use_edge_embedding,
    input_node_embedding: dict = None,
    input_edge_embedding: dict = None,
    depth: int = None,
    euclidean_norm_kwargs: dict = None,
    node_mlp_initialize: dict = None,
    use_edge_attributes: bool = None,
    edge_mlp_kwargs: dict = None,
    edge_attention_kwargs: dict = None,
    use_normalized_difference: bool = None,
    expand_distance_kwargs: dict = None,
    coord_mlp_kwargs: dict = None,
    pooling_coord_kwargs: dict = None,
    pooling_edge_kwargs: dict = None,
    node_normalize_kwargs: dict = None,
    use_node_attributes: bool = None,
    node_mlp_kwargs: dict = None,
    use_skip: bool = None,
    node_decoder_kwargs: dict = None,
    node_pooling_kwargs: dict = None,
    output_embedding: str = None,
    output_mlp: dict = None
):
    """Build the EGNN computation on disjoint tensors.

    Args:
        inputs (list): Eight tensors in the order `[h0, x, ed, edi, batch_id_node, batch_id_edge, count_nodes,
            count_edges]`, i.e. node input, coordinates, edge attributes, edge indices, batch assignment of
            nodes and edges and the node and edge counts.
        use_node_embedding (bool): Whether to pass the node input through :obj:`Embedding`.
        use_edge_embedding (bool): Whether to pass the edge input through :obj:`Embedding`.
        input_node_embedding (dict): Arguments unpacked in the node :obj:`Embedding`.
        input_edge_embedding (dict): Arguments unpacked in the edge :obj:`Embedding`.
        depth (int): Number of update iterations.
        euclidean_norm_kwargs (dict): Arguments unpacked in :obj:`EuclideanNorm`.
        node_mlp_initialize (dict): Arguments unpacked in the initial :obj:`GraphMLP`; skipped if empty or None.
        use_edge_attributes (bool): Whether to concatenate the edge attributes to the edge message input.
        edge_mlp_kwargs (dict): Arguments unpacked in the edge :obj:`GraphMLP`; skipped if empty or None.
        edge_attention_kwargs (dict): Arguments unpacked in the attention :obj:`GraphMLP`; skipped if empty or None.
        use_normalized_difference (bool): Whether to replace the coordinate difference by
            :obj:`EdgeDirectionNormalized`.
        expand_distance_kwargs (dict): Arguments unpacked in :obj:`PositionEncodingBasisLayer`; the distance is
            not expanded if empty or None.
        coord_mlp_kwargs (dict): Arguments unpacked in the coordinate :obj:`GraphMLP`; coordinates are not
            updated if empty or None.
        pooling_coord_kwargs (dict): Arguments unpacked in :obj:`AggregateLocalEdges` for coordinate updates.
        pooling_edge_kwargs (dict): Arguments unpacked in :obj:`AggregateLocalEdges` for edge messages.
        node_normalize_kwargs (dict): Arguments unpacked in :obj:`GraphLayerNormalization`; skipped if empty or None.
        use_node_attributes (bool): Whether to concatenate the initial node input before the node MLP.
        node_mlp_kwargs (dict): Arguments unpacked in the node :obj:`GraphMLP`; skipped if empty or None.
        use_skip (bool): Whether to add the node update to the current node embedding.
        node_decoder_kwargs (dict): Arguments unpacked in the decoder :obj:`GraphMLP`; skipped if empty or None.
        node_pooling_kwargs (dict): Arguments unpacked in :obj:`PoolingNodes`.
        output_embedding (str): Either "graph" or "node".
        output_mlp (dict): Arguments unpacked in the final :obj:`MLP` or :obj:`GraphMLP`.

    Returns:
        Tensor: Graph-level or node-level output.

    Raises:
        ValueError: If `output_embedding` is neither "graph" nor "node".
    """
    h0, x, ed, edi, batch_id_node, batch_id_edge, count_nodes, count_edges = inputs

    # Embedding, if no feature dimension
    if use_node_embedding:
        h0 = Embedding(**input_node_embedding)(h0)
    if use_edge_embedding:
        ed = Embedding(**input_edge_embedding)(ed)

    # Model
    h = GraphMLP(**node_mlp_initialize)([h0, batch_id_node, count_nodes]) if node_mlp_initialize else h0
    for _ in range(depth):
        pos1, pos2 = NodePosition()([x, edi])
        diff_x = Subtract()([pos1, pos2])
        norm_x = EuclideanNorm(**euclidean_norm_kwargs)(diff_x)
        # Original code has a normalize option for coord-differences.
        if use_normalized_difference:
            diff_x = EdgeDirectionNormalized()([pos1, pos2])
        if expand_distance_kwargs:
            # Forward the user arguments to the expansion layer instead of silently using its defaults.
            norm_x = PositionEncodingBasisLayer(**expand_distance_kwargs)(norm_x)

        # Edge model
        h_i, h_j = GatherNodes([0, 1], concat_axis=None)([h, edi])
        if use_edge_attributes:
            m_ij = Concatenate()([h_i, h_j, norm_x, ed])
        else:
            m_ij = Concatenate()([h_i, h_j, norm_x])
        if edge_mlp_kwargs:
            m_ij = GraphMLP(**edge_mlp_kwargs)([m_ij, batch_id_edge, count_edges])
        if edge_attention_kwargs:
            m_att = GraphMLP(**edge_attention_kwargs)([m_ij, batch_id_edge, count_edges])
            m_ij = Multiply()([m_att, m_ij])

        # Coord model
        if coord_mlp_kwargs:
            m_ij_weights = GraphMLP(**coord_mlp_kwargs)([m_ij, batch_id_edge, count_edges])
            x_trans = Multiply()([m_ij_weights, diff_x])
            agg = AggregateLocalEdges(**pooling_coord_kwargs)([h, x_trans, edi])
            x = Add()([x, agg])

        # Node model
        m_i = AggregateLocalEdges(**pooling_edge_kwargs)([h, m_ij, edi])
        if node_mlp_kwargs:
            m_i = Concatenate()([h, m_i])
            if use_node_attributes:
                m_i = Concatenate()([m_i, h0])
            m_i = GraphMLP(**node_mlp_kwargs)([m_i, batch_id_node, count_nodes])
        if node_normalize_kwargs:
            h = GraphLayerNormalization(**node_normalize_kwargs)([h, batch_id_node, count_nodes])
        if use_skip:
            h = Add()([h, m_i])
        else:
            h = m_i

    # Output embedding choice
    if node_decoder_kwargs:
        # The decoder has its own arguments; it must not reuse those of the node update MLP.
        n = GraphMLP(**node_decoder_kwargs)([h, batch_id_node, count_nodes])
    else:
        n = h

    # Final step.
    if output_embedding == 'graph':
        out = PoolingNodes(**node_pooling_kwargs)([count_nodes, n, batch_id_node])
        out = MLP(**output_mlp)(out)
    elif output_embedding == 'node':
        out = GraphMLP(**output_mlp)([n, batch_id_node, count_nodes])
    else:
        raise ValueError("Unsupported output embedding for mode `EGNN`")

    return out
kgcnn.models.casting import template_cast_output, template_cast_list_input +from keras.backend import backend as backend_to_use + + +# Keep track of model version from commit date in literature. +__kgcnn_model_version__ = "2023-12-04" + +# Supported backends +__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"] +if backend_to_use() not in __kgcnn_model_backend_supported__: + raise NotImplementedError("Backend '%s' for model 'GNNFilm' is not supported." % backend_to_use()) + +# Implementation of GNNFilm in `keras` from paper: +# GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation +# Marc Brockschmidt +# https://arxiv.org/abs/1906.12192 + + +model_default = { + "name": "GNNFilm", + "inputs": [ + {"shape": (None,), "name": "node_attributes", "dtype": "int64"}, + {"shape": (None, ), "name": "edge_relations", "dtype": "int64"}, + {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"} + ], + "input_tensor_type": "padded", + "input_embedding": None, # deprecated + "cast_disjoint_kwargs": {}, + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "dense_relation_kwargs": {"units": 64, "num_relations": 20}, + "dense_modulation_kwargs": {"units": 64, "num_relations": 20, "activation": "sigmoid"}, + "activation_kwargs": {"activation": "swish"}, + "depth": 3, + "verbose": 10, + "node_pooling_kwargs": {}, + "output_embedding": 'graph', + "output_scaling": None, + "output_tensor_type": "padded", + "output_to_tensor": None, # deprecated + "output_mlp": {"use_bias": True, "units": 1, + "activation": "softmax"} +} + + +@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"]) +def make_model(inputs: list = None, + input_tensor_type: str = None, + cast_disjoint_kwargs: dict = None, + input_embedding: dict = None, # noqa + input_node_embedding: dict = None, + depth: int = 
None, + dense_relation_kwargs: dict = None, + dense_modulation_kwargs: dict = None, + activation_kwargs: dict = None, + name: str = None, + verbose: int = None, # noqa + node_pooling_kwargs: dict = None, + output_embedding: str = None, + output_to_tensor: bool = None, # noqa + output_scaling: dict = None, + output_tensor_type: str = None, + output_mlp: dict = None + ): + r"""Make `GNNFilm <https://arxiv.org/abs/1906.12192>`__ graph network via functional API. + Default parameters can be found in :obj:`kgcnn.literature.GNNFilm.model_default`. + + **Model inputs**: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, edge_relations, edge_indices, ...]` + with '...' indicating mask or ID tensors following the template below. + The edge relations do not have a feature dimension and specify the relation of each edge of type 'int'. + Edges are actually edge single weight values which are entries of the pre-scaled adjacency matrix. + + %s + + **Model outputs**: + The standard output template: + + %s + + Args: + inputs (list): List of dictionaries unpacked in :obj:`tf.keras.layers.Input`. Order must match model definition. + input_tensor_type (str): Input type of graph tensor. Default is "padded". + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers if used. + input_embedding (dict): Deprecated in favour of input_node_embedding etc. + input_node_embedding (dict): Dictionary of embedding arguments unpacked in :obj:`Embedding` layers. + depth (int): Number of graph embedding units or depth of the network. + dense_relation_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`RelationalDense` layer. + dense_modulation_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`RelationalDense` layer. + activation_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`Activation` layer. + name (str): Name of the model. + verbose (int): Level of print output. 
+ node_pooling_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`PoolingNodes` layer. + output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". + output_to_tensor (bool): Whether to cast model output to :obj:`tf.Tensor`. + output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. + Defines number of model outputs and activation. + output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". + output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. + + Returns: + :obj:`keras.models.Model` + """ + # Make input + model_inputs = [Input(**x) for x in inputs] + + dj_inputs = template_cast_list_input( + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs + ) + + n, er, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs + + out = model_disjoint( + [n, er, disjoint_indices, batch_id_node, count_nodes], + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + input_node_embedding=input_node_embedding, + depth=depth, + dense_modulation_kwargs=dense_modulation_kwargs, + dense_relation_kwargs=dense_relation_kwargs, + activation_kwargs=activation_kwargs, + output_embedding=output_embedding, + output_mlp=output_mlp, + node_pooling_kwargs=node_pooling_kwargs + ) + + if output_scaling is not None: + scaler = get_scaler(output_scaling["name"])(**output_scaling) + out = scaler(out) + + # Output embedding choice + out = template_cast_output( + [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges], + output_embedding=output_embedding, output_tensor_type=output_tensor_type, + input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs + ) + + model = ks.models.Model(inputs=model_inputs, outputs=out, name=name) + 
model.__kgcnn_model_version__ = __kgcnn_model_version__ + + if output_scaling is not None: + def set_scale(*args, **kwargs): + scaler.set_scale(*args, **kwargs) + + setattr(model, "set_scale", set_scale) + return model + + +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) diff --git a/kgcnn/literature/GNNFilm/_model.py b/kgcnn/literature/GNNFilm/_model.py new file mode 100644 index 00000000..ff830eaa --- /dev/null +++ b/kgcnn/literature/GNNFilm/_model.py @@ -0,0 +1,49 @@ +from keras.layers import Dense, Add, Multiply, Activation +from kgcnn.layers.modules import Embedding +from kgcnn.layers.gather import GatherNodes +from kgcnn.layers.relational import RelationalDense +from kgcnn.layers.aggr import AggregateLocalEdges +from kgcnn.layers.pooling import PoolingNodes +from kgcnn.layers.mlp import MLP, GraphMLP + + +def model_disjoint( + inputs, + use_node_embedding, + input_node_embedding=None, + depth=None, + dense_modulation_kwargs=None, + dense_relation_kwargs=None, + activation_kwargs=None, + output_embedding=None, + output_mlp=None, + node_pooling_kwargs=None +): + n, edge_relations, edi, batch_id_node, count_nodes = inputs + + # Embedding, if no feature dimension + if use_node_embedding: + n = Embedding(**input_node_embedding)(n) + + # Model + for i in range(0, depth): + n_i, n_j = GatherNodes(selection_index=[0, 1], concat_axis=None)([n, edi]) + # Note: This maybe could be done more efficiently. 
+ gamma = RelationalDense(**dense_modulation_kwargs)([n_i, edge_relations]) + beta = RelationalDense(**dense_modulation_kwargs)([n_i, edge_relations]) + h_j = RelationalDense(**dense_relation_kwargs)([n_j, edge_relations]) + m = Multiply()([h_j, gamma]) + m = Add()([m, beta]) + h = AggregateLocalEdges(pooling_method="sum")([n, m, edi]) + n = Activation(**activation_kwargs)(h) + + # Output embedding choice + if output_embedding == "graph": + out = PoolingNodes(**node_pooling_kwargs)([count_nodes, n, batch_id_node]) # will return tensor + out = MLP(**output_mlp)(out) + elif output_embedding == "node": # Node labeling + out = GraphMLP(**output_mlp)([n, batch_id_node, count_nodes]) + else: + raise ValueError("Unsupported output embedding for mode `GNNFilm`") + + return out diff --git a/kgcnn/literature/Megnet/__init__.py b/kgcnn/literature/Megnet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kgcnn/literature/Megnet/_layers.py b/kgcnn/literature/Megnet/_layers.py new file mode 100644 index 00000000..c3e71102 --- /dev/null +++ b/kgcnn/literature/Megnet/_layers.py @@ -0,0 +1,134 @@ +from keras.layers import Layer, Dense, Concatenate +from kgcnn.layers.gather import GatherNodes, GatherState +from kgcnn.layers.aggr import AggregateLocalEdges +from kgcnn.layers.pooling import PoolingNodes + + +PoolingGlobalEdges = PoolingNodes + + +class MEGnetBlock(Layer): + r"""Convolutional unit of `MegNet `_ called MegNet Block.""" + + def __init__(self, + node_embed=None, + edge_embed=None, + env_embed=None, + pooling_method="mean", + use_bias=True, + activation='kgcnn>softplus2', + kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, + kernel_constraint=None, bias_constraint=None, + kernel_initializer='glorot_uniform', bias_initializer='zeros', + **kwargs): + """Initialize layer. + + Args: + node_embed (list, optional): List of node embedding dimension. Defaults to [16,16,16]. 
+ edge_embed (list, optional): List of edge embedding dimension. Defaults to [16,16,16]. + env_embed (list, optional): List of environment embedding dimension. Defaults to [16,16,16]. + pooling_method (str): Pooling method information for layer. Default is 'mean'. + use_bias (bool, optional): Use bias. Defaults to True. + activation (str): Activation function. Default is 'kgcnn>softplus2'. + kernel_regularizer: Kernel regularization. Default is None. + bias_regularizer: Bias regularization. Default is None. + activity_regularizer: Activity regularization. Default is None. + kernel_constraint: Kernel constrains. Default is None. + bias_constraint: Bias constrains. Default is None. + kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'. + bias_initializer: Initializer for bias. Default is 'zeros'. + """ + super(MEGnetBlock, self).__init__(**kwargs) + self.pooling_method = pooling_method + if node_embed is None: + node_embed = [16, 16, 16] + if env_embed is None: + env_embed = [16, 16, 16] + if edge_embed is None: + edge_embed = [16, 16, 16] + self.node_embed = node_embed + self.edge_embed = edge_embed + self.env_embed = env_embed + self.use_bias = use_bias + kernel_args = {"kernel_regularizer": kernel_regularizer, "activity_regularizer": activity_regularizer, + "bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint, + "bias_constraint": bias_constraint, "kernel_initializer": kernel_initializer, + "bias_initializer": bias_initializer, "use_bias": use_bias} + + # Node + self.lay_phi_n = Dense(units=self.node_embed[0], activation=activation, **kernel_args) + self.lay_phi_n_1 = Dense(units=self.node_embed[1], activation=activation, **kernel_args) + self.lay_phi_n_2 = Dense(units=self.node_embed[2], activation='linear', **kernel_args) + self.lay_esum = AggregateLocalEdges(pooling_method=self.pooling_method) + self.lay_gather_un = GatherState() + self.lay_conc_nu = Concatenate(axis=-1) + # Edge + self.lay_phi_e = 
Dense(units=self.edge_embed[0], activation=activation, **kernel_args) + self.lay_phi_e_1 = Dense(units=self.edge_embed[1], activation=activation, **kernel_args) + self.lay_phi_e_2 = Dense(units=self.edge_embed[2], activation='linear', **kernel_args) + self.lay_gather_n = GatherNodes() + self.lay_gather_ue = GatherState() + self.lay_conc_enu = Concatenate(axis=-1) + # Environment + self.lay_usum_e = PoolingGlobalEdges(pooling_method=self.pooling_method) + self.lay_usum_n = PoolingNodes(pooling_method=self.pooling_method) + self.lay_conc_u = Concatenate(axis=-1) + self.lay_phi_u = Dense(units=self.env_embed[0], activation=activation, **kernel_args) + self.lay_phi_u_1 = Dense(units=self.env_embed[1], activation=activation, **kernel_args) + self.lay_phi_u_2 = Dense(units=self.env_embed[2], activation='linear', **kernel_args) + + def build(self, input_shape): + """Build layer.""" + super(MEGnetBlock, self).build(input_shape) + + def call(self, inputs, **kwargs): + """Forward pass. + + Args: + inputs: [nodes, edges, tensor_index, state, batch_id_node, batch_id_edge, count_nodes, count_edges] + + - nodes (Tensor): Node embeddings of shape ([N], F) + - edges (Tensor): Edge or message embeddings of shape ([M], F) + - tensor_index (Tensor): Edge indices referring to nodes of shape (2, [M]) + - state (Tensor): State information for the graph, a single tensor of shape (batch, F) + - graph_id_node (Tensor): ID tensor of batch assignment in disjoint graph of shape `([N], )` . + - graph_id_edge (Tensor): ID tensor of batch assignment in disjoint graph of shape `([M], )` . + - nodes_count (Tensor): Tensor of number of nodes for each graph of shape `(batch, )` . + - edges_count (Tensor): Tensor of number of edges for each graph of shape `(batch, )` . 
+ + Returns: + Tensor: Updated node embeddings of shape ([N], F) + """ + # Calculate edge Update + node_input, edge_input, edge_index_input, env_input, batch_id_node, batch_id_edge, count_nodes, count_edges = inputs + e_n = self.lay_gather_n([node_input, edge_index_input], **kwargs) + e_u = self.lay_gather_ue([env_input, batch_id_edge], **kwargs) + ec = self.lay_conc_enu([e_n, edge_input, e_u], **kwargs) + ep = self.lay_phi_e(ec, **kwargs) # Learning of Update Functions + ep = self.lay_phi_e_1(ep, **kwargs) # Learning of Update Functions + ep = self.lay_phi_e_2(ep, **kwargs) # Learning of Update Functions + # Calculate Node update + vb = self.lay_esum([node_input, ep, edge_index_input], **kwargs) # Summing for each node connections + v_u = self.lay_gather_un([env_input, batch_id_node], **kwargs) + vc = self.lay_conc_nu([vb, node_input, v_u], **kwargs) # LazyConcatenate node features with new edge updates + vp = self.lay_phi_n(vc, **kwargs) # Learning of Update Functions + vp = self.lay_phi_n_1(vp, **kwargs) # Learning of Update Functions + vp = self.lay_phi_n_2(vp, **kwargs) # Learning of Update Functions + # Calculate environment update + es = self.lay_usum_e([count_edges, ep, batch_id_edge], **kwargs) + vs = self.lay_usum_n([count_nodes, vp, batch_id_node], **kwargs) + ub = self.lay_conc_u([es, vs, env_input], **kwargs) + up = self.lay_phi_u(ub, **kwargs) + up = self.lay_phi_u_1(up, **kwargs) + up = self.lay_phi_u_2(up, **kwargs) # Learning of Update Functions + return vp, ep, up + + def get_config(self): + config = super(MEGnetBlock, self).get_config() + config.update({"pooling_method": self.pooling_method, "node_embed": self.node_embed, "use_bias": self.use_bias, + "edge_embed": self.edge_embed, "env_embed": self.env_embed}) + config_dense = self.lay_phi_n.get_config() + for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint", + "bias_constraint", "kernel_initializer", "bias_initializer", "activation"]: + 
config.update({x: config_dense[x]}) + return config diff --git a/kgcnn/literature/Megnet/_make.py b/kgcnn/literature/Megnet/_make.py new file mode 100644 index 00000000..fac956eb --- /dev/null +++ b/kgcnn/literature/Megnet/_make.py @@ -0,0 +1,366 @@ +import keras as ks +from kgcnn.layers.scale import get as get_scaler +from ._model import model_disjoint, model_disjoint_crystal +from kgcnn.layers.modules import Input +from kgcnn.models.casting import template_cast_output, template_cast_list_input +from kgcnn.models.utils import update_model_kwargs +from keras.backend import backend as backend_to_use + +# To be updated if model is changed in a significant way. +__model_version__ = "2023-12-05" + +# Supported backends +__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"] +if backend_to_use() not in __kgcnn_model_backend_supported__: + raise NotImplementedError("Backend '%s' for model 'Megnet' is not supported." % backend_to_use()) + +# Implementation of Megnet in `tf.keras` from paper: +# Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals +# by Chi Chen, Weike Ye, Yunxing Zuo, Chen Zheng, and Shyue Ping Ong* +# https://github.com/materialsvirtuallab/megnet +# https://pubs.acs.org/doi/10.1021/acs.chemmater.9b01294 + + +model_default = { + "name": "Megnet", + "inputs": [ + {"shape": (None,), "name": "node_number", "dtype": "int64"}, + {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32"}, + {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, + {"shape": (), "name": "graph_number", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"} + ], + "input_tensor_type": "padded", + "input_embedding": None, # deprecated + "cast_disjoint_kwargs": {}, + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_graph_embedding": {"input_dim": 100, "output_dim": 64}, + "make_distance": True, "expand_distance": True, + 
"gauss_args": {"bins": 20, "distance": 4, "offset": 0.0, "sigma": 0.4}, + "meg_block_args": {"node_embed": [64, 32, 32], "edge_embed": [64, 32, 32], + "env_embed": [64, 32, 32], "activation": "kgcnn>softplus2"}, + "set2set_args": {"channels": 16, "T": 3, "pooling_method": "sum", "init_qstar": "0"}, + "node_ff_args": {"units": [64, 32], "activation": "kgcnn>softplus2"}, + "edge_ff_args": {"units": [64, 32], "activation": "kgcnn>softplus2"}, + "state_ff_args": {"units": [64, 32], "activation": "kgcnn>softplus2"}, + "nblocks": 3, "has_ff": True, "dropout": None, "use_set2set": True, + "verbose": 10, + "output_embedding": "graph", + "output_mlp": {"use_bias": [True, True, True], "units": [32, 16, 1], + "activation": ["kgcnn>softplus2", "kgcnn>softplus2", "linear"]}, + "output_scaling": None +} + + +@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"]) +def make_model(inputs: list = None, + input_tensor_type: str = None, + cast_disjoint_kwargs: dict = None, + input_embedding: dict = None, # noqa + input_node_embedding: dict = None, + input_graph_embedding: dict = None, + expand_distance: bool = None, + make_distance: bool = None, + gauss_args: dict = None, + meg_block_args: dict = None, + set2set_args: dict = None, + node_ff_args: dict = None, + edge_ff_args: dict = None, + state_ff_args: dict = None, + use_set2set: bool = None, + nblocks: int = None, + has_ff: bool = None, + dropout: float = None, + name: str = None, + verbose: int = None, # noqa + output_embedding: str = None, + output_mlp: dict = None, + output_tensor_type: str = None, + output_scaling: dict = None + ): + r"""Make `MegNet `__ graph network via functional API. + Default parameters can be found in :obj:`kgcnn.literature.Megnet.model_default`. + + **Model inputs**: + Model uses the list template of inputs and standard output template. 
The supported inputs are :obj:`[nodes, coordinates, edge_indices, graph_state, ...]` with `make_distance` and + with '...' indicating mask or ID tensors following the template below. + Note that you could also supply edge features with `make_distance` set to False, which would make the input + :obj:`[nodes, edges, edge_indices, graph_state...]` . + + %s + + **Model outputs**: + The standard output template: + + %s + + Args: + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. + input_tensor_type (str): Input type of graph tensor. Default is "padded". + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer. + input_embedding (dict): Deprecated in favour of input_node_embedding etc. + input_node_embedding (dict): Dictionary of embedding arguments for nodes unpacked in :obj:`Embedding` layers. + input_graph_embedding (dict): Dictionary of embedding arguments for graph unpacked in :obj:`Embedding` layers. + make_distance (bool): Whether the second input holds node coordinates (True) instead of edge features (False). + expand_distance (bool): Whether the edge input (actual edges, or distances derived from node coordinates given + edge indices) is expanded in a Gauss distance basis. Expansion uses `gauss_args`. + gauss_args (dict): Dictionary of layer arguments unpacked in :obj:`GaussBasisLayer` layer. + meg_block_args (dict): Dictionary of layer arguments unpacked in :obj:`MEGnetBlock` layer. + set2set_args (dict): Dictionary of layer arguments unpacked in :obj:`PoolingSet2SetEncoder` layer. + node_ff_args (dict): Dictionary of layer arguments unpacked in :obj:`MLP` feed-forward layer. + edge_ff_args (dict): Dictionary of layer arguments unpacked in :obj:`MLP` feed-forward layer. + state_ff_args (dict): Dictionary of layer arguments unpacked in :obj:`MLP` feed-forward layer. + use_set2set (bool): Whether to use :obj:`PoolingSet2SetEncoder` layer. + nblocks (int): Number of graph embedding blocks or depth of the network. 
+ has_ff (bool): Use feed-forward MLP in each block. + dropout (int): Dropout to use. Default is None. + name (str): Name of the model. + verbose (int): Verbosity level of print. + output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". + output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. + Defines number of model outputs and activation. + output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None. + output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". + + Returns: + :obj:`keras.models.Model` + """ + # Make input + model_inputs = [Input(**x) for x in inputs] + + dj = template_cast_list_input( + model_inputs, + input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + has_edges=(not make_distance), + has_nodes=1 + int(make_distance), + has_graph_state=True + ) + + n, x, disjoint_indices, gs, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj + + out = model_disjoint( + [n, x, disjoint_indices, gs, batch_id_node, batch_id_edge, count_nodes, count_edges], + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + use_graph_embedding=("int" in inputs[3]['dtype']) if input_graph_embedding is not None else False, + input_node_embedding=input_node_embedding, + input_graph_embedding=input_graph_embedding, + expand_distance=expand_distance, + make_distance=make_distance, + gauss_args=gauss_args, + meg_block_args=meg_block_args, + set2set_args=set2set_args, + node_ff_args=node_ff_args, + edge_ff_args=edge_ff_args, + state_ff_args=state_ff_args, + use_set2set=use_set2set, + nblocks=nblocks, + has_ff=has_ff, + dropout=dropout, + output_embedding=output_embedding, + output_mlp=output_mlp + ) + + if output_scaling is not None: + scaler = get_scaler(output_scaling["name"])(**output_scaling) + if 
scaler.extensive: + # Node information must be numbers, or we need an additional input. + out = scaler([out, n, batch_id_node]) + else: + out = scaler(out) + + # Output embedding choice + out = template_cast_output( + [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges], + output_embedding=output_embedding, output_tensor_type=output_tensor_type, + input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs, + ) + + model = ks.models.Model(inputs=model_inputs, outputs=out, name=name) + + model.__kgcnn_model_version__ = __model_version__ + + if output_scaling is not None: + def set_scale(*args, **kwargs): + scaler.set_scale(*args, **kwargs) + + setattr(model, "set_scale", set_scale) + + return model + + +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) + +model_crystal_default = { + 'name': "Megnet", + 'inputs': [ + {'shape': (None,), 'name': "node_attributes", 'dtype': 'float32', 'ragged': True}, + {'shape': (None, 3), 'name': "node_coordinates", 'dtype': 'float32', 'ragged': True}, + {'shape': (None, 2), 'name': "edge_indices", 'dtype': 'int64', 'ragged': True}, + {'shape': [1], 'name': "charge", 'dtype': 'float32', 'ragged': False}, + {'shape': (None, 3), 'name': "edge_image", 'dtype': 'int64', 'ragged': True}, + {'shape': (3, 3), 'name': "graph_lattice", 'dtype': 'float32', 'ragged': False} + ], + "input_tensor_type": "ragged", + "input_embedding": None, # deprecated + "cast_disjoint_kwargs": {}, + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_graph_embedding": {"input_dim": 100, "output_dim": 64}, + "make_distance": True, "expand_distance": True, + 'gauss_args': {"bins": 20, "distance": 4, "offset": 0.0, "sigma": 0.4}, + 'meg_block_args': {'node_embed': [64, 32, 32], 'edge_embed': [64, 32, 32], + 'env_embed': [64, 32, 32], 'activation': 'kgcnn>softplus2'}, + 'set2set_args': {'channels': 16, 'T': 3, "pooling_method": "sum", "init_qstar": 
"0"}, + 'node_ff_args': {"units": [64, 32], "activation": "kgcnn>softplus2"}, + 'edge_ff_args': {"units": [64, 32], "activation": "kgcnn>softplus2"}, + 'state_ff_args': {"units": [64, 32], "activation": "kgcnn>softplus2"}, + 'nblocks': 3, 'has_ff': True, 'dropout': None, 'use_set2set': True, + 'verbose': 10, + 'output_embedding': 'graph', + 'output_mlp': {"use_bias": [True, True, True], "units": [32, 16, 1], + "activation": ['kgcnn>softplus2', 'kgcnn>softplus2', 'linear']}, + "output_scaling": None +} + + +@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"]) +def make_crystal_model(inputs: list = None, + input_tensor_type: str = None, + cast_disjoint_kwargs: dict = None, + input_embedding: dict = None, # noqa + input_node_embedding: dict = None, + input_graph_embedding: dict = None, + expand_distance: bool = None, + make_distance: bool = None, + gauss_args: dict = None, + meg_block_args: dict = None, + set2set_args: dict = None, + node_ff_args: dict = None, + edge_ff_args: dict = None, + state_ff_args: dict = None, + use_set2set: bool = None, + nblocks: int = None, + has_ff: bool = None, + dropout: float = None, + name: str = None, + verbose: int = None, # noqa + output_embedding: str = None, + output_mlp: dict = None, + output_tensor_type: str = None, + output_scaling: dict = None + ): + r"""Make `MegNet `__ graph network via functional API. + Default parameters can be found in :obj:`kgcnn.literature.Megnet.model_crystal_default`. + + **Model inputs**: + Model uses the list template of inputs and standard output template. + The supported inputs are :obj:`[nodes, coordinates, edge_indices, graph_state, image_translation, lattice, ...]` + with '...' indicating mask or ID tensors following the template below. + + %s + + **Model outputs**: + The standard output template: + + %s + + Args: + inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition. 
input_tensor_type (str): Input type of graph tensor. Default is "padded". + cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer. + input_embedding (dict): Deprecated in favour of input_node_embedding etc. + input_node_embedding (dict): Dictionary of embedding arguments for nodes unpacked in :obj:`Embedding` layers. + input_graph_embedding (dict): Dictionary of embedding arguments for graph unpacked in :obj:`Embedding` layers. + make_distance (bool): Whether the second input holds node coordinates (True) instead of edge features (False). + expand_distance (bool): Whether the edge input (actual edges, or distances derived from node coordinates given + edge indices) is expanded in a Gauss distance basis. Expansion uses `gauss_args`. + gauss_args (dict): Dictionary of layer arguments unpacked in :obj:`GaussBasisLayer` layer. + meg_block_args (dict): Dictionary of layer arguments unpacked in :obj:`MEGnetBlock` layer. + set2set_args (dict): Dictionary of layer arguments unpacked in :obj:`PoolingSet2SetEncoder` layer. + node_ff_args (dict): Dictionary of layer arguments unpacked in :obj:`MLP` feed-forward layer. + edge_ff_args (dict): Dictionary of layer arguments unpacked in :obj:`MLP` feed-forward layer. + state_ff_args (dict): Dictionary of layer arguments unpacked in :obj:`MLP` feed-forward layer. + use_set2set (bool): Whether to use :obj:`PoolingSet2SetEncoder` layer. + nblocks (int): Number of graph embedding blocks or depth of the network. + has_ff (bool): Use feed-forward MLP in each block. + dropout (float): Dropout to use. Default is None. + name (str): Name of the model. + verbose (int): Verbosity level of print. + output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph". + output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block. + Defines number of model outputs and activation. + output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. 
Default is None. + output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded". + + Returns: + :obj:`keras.models.Model` + """ + # Make input + model_inputs = [Input(**x) for x in inputs] + + dj = template_cast_list_input( + model_inputs, input_tensor_type=input_tensor_type, + cast_disjoint_kwargs=cast_disjoint_kwargs, + has_edges=(not make_distance), + has_nodes=1 + int(make_distance), + has_graph_state=True, + has_crystal_input=2 + ) + + n, x, djx, gs, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj + + # Wrapp disjoint model + out = model_disjoint_crystal( + [n, x, djx, gs, img, lattice, batch_id_node, batch_id_edge, count_nodes, count_edges], + use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False, + use_graph_embedding=("int" in inputs[3]['dtype']) if input_graph_embedding is not None else False, + input_node_embedding=input_node_embedding, + input_graph_embedding=input_graph_embedding, + expand_distance=expand_distance, + make_distance=make_distance, + gauss_args=gauss_args, + meg_block_args=meg_block_args, + set2set_args=set2set_args, + node_ff_args=node_ff_args, + edge_ff_args=edge_ff_args, + state_ff_args=state_ff_args, + use_set2set=use_set2set, + nblocks=nblocks, + has_ff=has_ff, + dropout=dropout, + output_embedding=output_embedding, + output_mlp=output_mlp + ) + + if output_scaling is not None: + scaler = get_scaler(output_scaling["name"])(**output_scaling) + if scaler.extensive: + # Node information must be numbers, or we need an additional input. 
# Edges are pooled with the same layer class that is used for nodes.
PoolingGlobalEdges = PoolingNodes


def model_disjoint(
    inputs,
    use_node_embedding,
    use_graph_embedding,
    input_node_embedding: dict = None,
    input_graph_embedding: dict = None,
    expand_distance: bool = None,
    make_distance: bool = None,
    gauss_args: dict = None,
    meg_block_args: dict = None,
    set2set_args: dict = None,
    node_ff_args: dict = None,
    edge_ff_args: dict = None,
    state_ff_args: dict = None,
    use_set2set: bool = None,
    nblocks: int = None,
    has_ff: bool = None,
    dropout: float = None,
    output_embedding: str = None,
    output_mlp: dict = None,
):
    """Assemble the MEGNet graph on disjoint tensors and return the graph-level output.

    Args:
        inputs (list): Eight tensors, unpacked in this order: node attributes, `x`
            (used as coordinates if `make_distance` is set, otherwise taken as edge features),
            edge indices, graph state, node batch ids, edge batch ids, node counts, edge counts.
        use_node_embedding (bool): Whether to pass nodes through :obj:`Embedding`.
        use_graph_embedding (bool): Whether to pass the graph state through :obj:`Embedding`.
        input_node_embedding (dict): Arguments unpacked in the node :obj:`Embedding` layer.
        input_graph_embedding (dict): Arguments unpacked in the state :obj:`Embedding` layer.
        expand_distance (bool): Whether to expand the edge input with :obj:`GaussBasisLayer`.
        make_distance (bool): Whether to compute Euclidean distances from `x` and the edge indices.
        gauss_args (dict): Arguments unpacked in :obj:`GaussBasisLayer`.
        meg_block_args (dict): Arguments unpacked in every :obj:`MEGnetBlock`.
        set2set_args (dict): Arguments unpacked in :obj:`PoolingSet2SetEncoder`. The entry
            "channels" is also used as the unit count of the preceding linear :obj:`Dense` layers.
        node_ff_args (dict): Arguments unpacked in the node :obj:`GraphMLP` layers.
        edge_ff_args (dict): Arguments unpacked in the edge :obj:`GraphMLP` layers.
        state_ff_args (dict): Arguments unpacked in the state :obj:`MLP` layers.
        use_set2set (bool): Pool with :obj:`PoolingSet2SetEncoder` instead of :obj:`PoolingNodes`.
        nblocks (int): Number of :obj:`MEGnetBlock` iterations.
        has_ff (bool): Whether to apply feed-forward layers ahead of every block but the first.
        dropout (float): Dropout rate. No :obj:`Dropout` layers are created if None.
        output_embedding (str): Must be "graph".
        output_mlp (dict): Arguments unpacked in the final :obj:`MLP`.

    Returns:
        Output of the final :obj:`MLP` applied to the concatenated node, edge and state vectors.

    Raises:
        ValueError: If `output_embedding` is not "graph".
    """
    nodes, x, edi, state, batch_id_node, batch_id_edge, count_nodes, count_edges = inputs

    # Optional embedding of node and graph-state inputs.
    if use_node_embedding:
        nodes = Embedding(**input_node_embedding)(nodes)
    if use_graph_embedding:
        state = Embedding(**input_graph_embedding)(state)

    # Edge features: either distances derived from `x`, or `x` itself.
    if make_distance:
        pos_first, pos_second = NodePosition()([x, edi])
        edges = NodeDistanceEuclidean()([pos_first, pos_second])
    else:
        edges = x

    if expand_distance:
        edges = GaussBasisLayer(**gauss_args)(edges)

    # Initial feed-forward on all three feature streams.
    nodes = GraphMLP(**node_ff_args)([nodes, batch_id_node, count_nodes])
    edges = GraphMLP(**edge_ff_args)([edges, batch_id_edge, count_edges])
    state = MLP(**state_ff_args)(state)

    # `*_in` are the block inputs/outputs; `nodes`, `edges`, `state` carry the residual stream.
    node_in, edge_in, state_in = nodes, edges, state
    for block in range(nblocks):
        if has_ff and block > 0:
            node_in = GraphMLP(**node_ff_args)([nodes, batch_id_node, count_nodes])
            edge_in = GraphMLP(**edge_ff_args)([edges, batch_id_edge, count_edges])
            state_in = MLP(**state_ff_args)(state)

        node_in, edge_in, state_in = MEGnetBlock(**meg_block_args)(
            [node_in, edge_in, edi, state_in, batch_id_node, batch_id_edge, count_nodes, count_edges])

        if dropout is not None:
            node_in = Dropout(dropout, name='dropout_atom_%d' % block)(node_in)
            edge_in = Dropout(dropout, name='dropout_bond_%d' % block)(edge_in)
            state_in = Dropout(dropout, name='dropout_state_%d' % block)(state_in)

        # Skip connection around the block.
        nodes = Add()([node_in, nodes])
        edges = Add()([edge_in, edges])
        state = Add()([state_in, state])

    # Graph-level readout of nodes and edges.
    if use_set2set:
        nodes = Dense(set2set_args["channels"], activation='linear')(nodes)  # to match units
        edges = Dense(set2set_args["channels"], activation='linear')(edges)  # to match units
        nodes = PoolingSet2SetEncoder(**set2set_args)([count_nodes, nodes, batch_id_node])
        edges = PoolingSet2SetEncoder(**set2set_args)([count_edges, edges, batch_id_edge])
    else:
        nodes = PoolingNodes()([count_nodes, nodes, batch_id_node])
        edges = PoolingGlobalEdges()([count_edges, edges, batch_id_edge])

    edges = Flatten()(edges)
    nodes = Flatten()(nodes)
    graph_vector = Concatenate(axis=-1)([nodes, edges, state])

    if dropout is not None:
        graph_vector = Dropout(dropout, name='dropout_final')(graph_vector)

    # Only graph embedding for Megnet
    if output_embedding != "graph":
        raise ValueError("Unsupported output embedding for mode `Megnet`.")

    return MLP(**output_mlp)(graph_vector)
def model_disjoint_crystal(
    inputs,
    use_node_embedding,
    use_graph_embedding,
    input_node_embedding: dict = None,
    input_graph_embedding: dict = None,
    expand_distance: bool = None,
    make_distance: bool = None,
    gauss_args: dict = None,
    meg_block_args: dict = None,
    set2set_args: dict = None,
    node_ff_args: dict = None,
    edge_ff_args: dict = None,
    state_ff_args: dict = None,
    use_set2set: bool = None,
    nblocks: int = None,
    has_ff: bool = None,
    dropout: float = None,
    output_embedding: str = None,
    output_mlp: dict = None,
):
    """Build the MEGNet graph for periodic (crystal) input on disjoint tensors.

    Follows the same layer sequence as the non-periodic MEGNet model, except that the second
    node position of each edge is passed through :obj:`ShiftPeriodicLattice` together with the
    edge image and the lattice before distances are computed.

    Args:
        inputs (list): Ten tensors, unpacked in this order: node attributes, `x`
            (used as coordinates if `make_distance` is set, otherwise taken as edge features),
            edge indices, graph state, edge image, lattice, node batch ids, edge batch ids,
            node counts, edge counts.
        use_node_embedding (bool): Whether to pass nodes through :obj:`Embedding`.
        use_graph_embedding (bool): Whether to pass the graph state through :obj:`Embedding`.
        input_node_embedding (dict): Arguments unpacked in the node :obj:`Embedding` layer.
        input_graph_embedding (dict): Arguments unpacked in the state :obj:`Embedding` layer.
        expand_distance (bool): Whether to expand the edge input with :obj:`GaussBasisLayer`.
        make_distance (bool): Whether to compute Euclidean distances from `x`, the edge indices,
            the edge image and the lattice.
        gauss_args (dict): Arguments unpacked in :obj:`GaussBasisLayer`.
        meg_block_args (dict): Arguments unpacked in every :obj:`MEGnetBlock`.
        set2set_args (dict): Arguments unpacked in :obj:`PoolingSet2SetEncoder`. The entry
            "channels" is also used as the unit count of the preceding linear :obj:`Dense` layers.
        node_ff_args (dict): Arguments unpacked in the node :obj:`GraphMLP` layers.
        edge_ff_args (dict): Arguments unpacked in the edge :obj:`GraphMLP` layers.
        state_ff_args (dict): Arguments unpacked in the state :obj:`MLP` layers.
        use_set2set (bool): Pool with :obj:`PoolingSet2SetEncoder` instead of :obj:`PoolingNodes`.
        nblocks (int): Number of :obj:`MEGnetBlock` iterations.
        has_ff (bool): Whether to apply feed-forward layers ahead of every block but the first.
        dropout (float): Dropout rate. No :obj:`Dropout` layers are created if None.
        output_embedding (str): Must be "graph".
        output_mlp (dict): Arguments unpacked in the final :obj:`MLP`.

    Returns:
        Output of the final :obj:`MLP` applied to the concatenated node, edge and state vectors.

    Raises:
        ValueError: If `output_embedding` is not "graph".
    """
    vp, x, edi, up, edge_image, lattice, batch_id_node, batch_id_edge, count_nodes, count_edges = inputs

    # Embedding, if no feature dimension
    if use_node_embedding:
        vp = Embedding(**input_node_embedding)(vp)
    if use_graph_embedding:
        up = Embedding(**input_graph_embedding)(up)

    # Edge distance as Gauss-Basis
    if make_distance:
        pos1, pos2 = NodePosition()([x, edi])
        # Periodic case: adjust the second position with the edge image and lattice first.
        pos2 = ShiftPeriodicLattice()([pos2, edge_image, lattice])
        ep = NodeDistanceEuclidean()([pos1, pos2])
    else:
        ep = x

    if expand_distance:
        ep = GaussBasisLayer(**gauss_args)(ep)

    # Model
    # Node features must be paired with the node batch ids and node counts. The original
    # passed the edge ids/counts here, unlike every other node GraphMLP call in this module.
    vp = GraphMLP(**node_ff_args)([vp, batch_id_node, count_nodes])
    ep = GraphMLP(**edge_ff_args)([ep, batch_id_edge, count_edges])
    up = MLP(**state_ff_args)(up)
    # vp2/ep2/up2 are the block inputs and outputs; vp/ep/up carry the residual stream.
    vp2 = vp
    ep2 = ep
    up2 = up
    for i in range(0, nblocks):
        if has_ff and i > 0:
            vp2 = GraphMLP(**node_ff_args)([vp, batch_id_node, count_nodes])
            ep2 = GraphMLP(**edge_ff_args)([ep, batch_id_edge, count_edges])
            up2 = MLP(**state_ff_args)(up)

        # MEGnetBlock
        vp2, ep2, up2 = MEGnetBlock(**meg_block_args)(
            [vp2, ep2, edi, up2, batch_id_node, batch_id_edge, count_nodes, count_edges])

        # skip connection
        if dropout is not None:
            vp2 = Dropout(dropout, name='dropout_atom_%d' % i)(vp2)
            ep2 = Dropout(dropout, name='dropout_bond_%d' % i)(ep2)
            up2 = Dropout(dropout, name='dropout_state_%d' % i)(up2)

        vp = Add()([vp2, vp])
        ep = Add()([ep2, ep])
        up = Add()([up2, up])

    # Graph-level readout of nodes and edges.
    if use_set2set:
        vp = Dense(set2set_args["channels"], activation='linear')(vp)  # to match units
        ep = Dense(set2set_args["channels"], activation='linear')(ep)  # to match units
        vp = PoolingSet2SetEncoder(**set2set_args)([count_nodes, vp, batch_id_node])
        ep = PoolingSet2SetEncoder(**set2set_args)([count_edges, ep, batch_id_edge])
    else:
        vp = PoolingNodes()([count_nodes, vp, batch_id_node])
        ep = PoolingGlobalEdges()([count_edges, ep, batch_id_edge])

    ep = Flatten()(ep)
    vp = Flatten()(vp)
    final_vec = Concatenate(axis=-1)([vp, ep, up])

    if dropout is not None:
        final_vec = Dropout(dropout, name='dropout_final')(final_vec)

    # Only graph embedding for Megnet
    if output_embedding != "graph":
        raise ValueError("Unsupported output embedding for mode `Megnet`.")

    main_output = MLP(**output_mlp)(final_vec)

    return main_output
# Keep track of model version from commit date in literature.
__kgcnn_model_version__ = "2023-12-04"

# Supported backends
__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"]
if backend_to_use() not in __kgcnn_model_backend_supported__:
    raise NotImplementedError("Backend '%s' for model 'RGCN' is not supported." % backend_to_use())

# Implementation of RGCN in `keras` from paper:
# Modeling Relational Data with Graph Convolutional Networks
# Michael Schlichtkrull, Thomas N. Kipf, Peter Bloem, Rianne van den Berg, Ivan Titov and Max Welling
# https://arxiv.org/abs/1703.06103


model_default = {
    "name": "RGCN",
    "inputs": [
        {"shape": (None,), "name": "node_number", "dtype": "int64"},
        {"shape": (None, 1), "name": "edge_weights", "dtype": "float32"},
        {"shape": (None, ), "name": "edge_relations", "dtype": "int64"},
        {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"},
        {"shape": (), "name": "total_nodes", "dtype": "int64"},
        {"shape": (), "name": "total_edges", "dtype": "int64"}
    ],
    "input_tensor_type": "padded",
    "input_embedding": None,  # deprecated
    "cast_disjoint_kwargs": {},
    "input_node_embedding": {"input_dim": 95, "output_dim": 64},
    "input_edge_embedding": {"input_dim": 25, "output_dim": 1},
    "dense_relation_kwargs": {"units": 64, "num_relations": 20},
    "dense_kwargs": {"units": 64},
    "activation_kwargs": {"activation": "swish"},
    "depth": 3,
    "verbose": 10,
    "output_embedding": 'graph',
    "node_pooling_kwargs": {},
    "output_to_tensor": None,  # deprecated
    "output_tensor_type": "padded",
    # NOTE(review): softmax over a single unit is constant 1.0 - confirm this default is intended.
    "output_mlp": {"use_bias": True, "units": 1,
                   "activation": "softmax"},
    "output_scaling": None,
}


@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
def make_model(inputs: list = None,
               input_tensor_type: str = None,
               cast_disjoint_kwargs: dict = None,
               input_embedding: dict = None,  # noqa
               input_node_embedding: dict = None,
               input_edge_embedding: dict = None,
               depth: int = None,
               dense_relation_kwargs: dict = None,
               dense_kwargs: dict = None,
               activation_kwargs: dict = None,
               name: str = None,
               verbose: int = None,
               output_embedding: str = None,
               output_tensor_type: str = None,
               output_scaling: dict = None,
               output_to_tensor: bool = None,  # noqa
               node_pooling_kwargs: dict = None,
               output_mlp: dict = None
               ):
    r"""Make `RGCN <https://arxiv.org/abs/1703.06103>`__ graph network via functional API.
    Default parameters can be found in :obj:`kgcnn.literature.RGCN.model_default`.

    **Model inputs**:
    Model uses the list template of inputs and standard output template.
    The supported inputs are :obj:`[nodes, edges, edge_relations, edge_indices, ...]`
    with '...' indicating mask or ID tensors following the template below.
    The edge relations do not have a feature dimension and specify the relation of each edge of type 'int'.
    Edges are actually edge single weight values which are entries of the pre-scaled adjacency matrix.

    %s

    **Model outputs**:
    The standard output template:

    %s

    Args:
        inputs (list): List of dictionaries unpacked in :obj:`Input`. Order must match model definition.
        input_tensor_type (str): Input type of graph tensor. Default is "padded".
        cast_disjoint_kwargs (dict): Dictionary of arguments for casting layers if used.
        input_embedding (dict): Deprecated in favour of input_node_embedding etc.
        input_node_embedding (dict): Dictionary of embedding arguments unpacked in :obj:`Embedding` layers.
        input_edge_embedding (dict): Dictionary of embedding arguments unpacked in :obj:`Embedding` layers.
        depth (int): Number of graph embedding units or depth of the network.
        dense_relation_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`RelationalDense` layer.
        dense_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`Dense` layer.
        activation_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`Activation` layer.
        name (str): Name of the model.
        verbose (int): Level of print output.
        output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
        output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".
        output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None.
        output_to_tensor (bool): Deprecated in favour of `output_tensor_type`.
        node_pooling_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`PoolingNodes` layer.
        output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
            Defines number of model outputs and activation.

    Returns:
        :obj:`keras.models.Model`
    """
    # Make input
    model_inputs = [Input(**x) for x in inputs]

    # Cast the template inputs to the disjoint representation used by the model body.
    dj_inputs = template_cast_list_input(
        model_inputs,
        input_tensor_type=input_tensor_type,
        cast_disjoint_kwargs=cast_disjoint_kwargs,
        has_nodes=True,
        has_edges=2,
        has_edge_indices=True
    )

    n, ed, er, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = dj_inputs

    # Embeddings are only used for integer-typed inputs and when embedding arguments are given.
    out = model_disjoint(
        [n, ed, er, disjoint_indices, batch_id_node, count_nodes],
        use_node_embedding=("int" in inputs[0]['dtype']) if input_node_embedding is not None else False,
        use_edge_embedding=("int" in inputs[1]['dtype']) if input_edge_embedding is not None else False,
        input_node_embedding=input_node_embedding,
        input_edge_embedding=input_edge_embedding,
        depth=depth,
        dense_kwargs=dense_kwargs,
        dense_relation_kwargs=dense_relation_kwargs,
        activation_kwargs=activation_kwargs,
        node_pooling_kwargs=node_pooling_kwargs,
        output_mlp=output_mlp,
        output_embedding=output_embedding
    )

    if output_scaling is not None:
        scaler = get_scaler(output_scaling["name"])(**output_scaling)
        out = scaler(out)

    # Output embedding choice
    out = template_cast_output(
        [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges],
        output_embedding=output_embedding, output_tensor_type=output_tensor_type,
        input_tensor_type=input_tensor_type, cast_disjoint_kwargs=cast_disjoint_kwargs
    )

    model = ks.models.Model(inputs=model_inputs, outputs=out, name=name)
    model.__kgcnn_model_version__ = __kgcnn_model_version__

    if output_scaling is not None:
        # Expose the scaler's `set_scale` on the model itself.
        def set_scale(*args, **kwargs):
            scaler.set_scale(*args, **kwargs)

        setattr(model, "set_scale", set_scale)
    return model


# Fill the two placeholders of the docstring with the shared input/output template descriptions.
make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__)
def model_disjoint(
    inputs,
    use_node_embedding,
    use_edge_embedding,
    input_node_embedding=None,
    input_edge_embedding=None,
    depth=None,
    dense_kwargs=None,
    dense_relation_kwargs=None,
    activation_kwargs=None,
    node_pooling_kwargs=None,
    output_mlp=None,
    output_embedding=None
):
    """Assemble the RGCN graph on disjoint tensors.

    Args:
        inputs (list): Six tensors, unpacked in this order: nodes, edge weights, edge relations,
            edge indices, node batch ids, node counts.
        use_node_embedding (bool): Whether to pass nodes through :obj:`Embedding`.
        use_edge_embedding (bool): Whether to pass edge weights through :obj:`Embedding`.
        input_node_embedding (dict): Arguments unpacked in the node :obj:`Embedding` layer.
        input_edge_embedding (dict): Arguments unpacked in the edge :obj:`Embedding` layer.
        depth (int): Number of message passing iterations.
        dense_kwargs (dict): Arguments unpacked in the :obj:`Dense` layer applied to the nodes.
        dense_relation_kwargs (dict): Arguments unpacked in the :obj:`RelationalDense` layer.
        activation_kwargs (dict): Arguments unpacked in the :obj:`Activation` layer.
        node_pooling_kwargs (dict): Arguments unpacked in :obj:`PoolingNodes` ("graph" output only).
        output_mlp (dict): Arguments unpacked in the output :obj:`MLP` or :obj:`GraphMLP`.
        output_embedding (str): Either "graph" or "node".

    Returns:
        Output tensor of the final MLP, per graph for "graph" and per node for "node".

    Raises:
        ValueError: If `output_embedding` is neither "graph" nor "node".
    """
    nodes, weights, relations, indices, batch_id_node, count_nodes = inputs

    # Optional embedding of the raw inputs.
    if use_node_embedding:
        nodes = Embedding(**input_node_embedding)(nodes)
    if use_edge_embedding:
        weights = Embedding(**input_edge_embedding)(weights)

    for _ in range(depth):
        gathered = GatherNodesOutgoing()([nodes, indices])
        self_term = Dense(**dense_kwargs)(nodes)
        relation_term = RelationalDense(**dense_relation_kwargs)([gathered, relations])
        # Scale every message by its edge weight before summing per node.
        messages = Multiply()([relation_term, weights])
        aggregated = AggregateLocalEdges(pooling_method="sum")([nodes, messages, indices])
        nodes = Add()([aggregated, self_term])
        nodes = Activation(**activation_kwargs)(nodes)

    # Output embedding choice
    if output_embedding == "graph":
        pooled = PoolingNodes(**node_pooling_kwargs)([count_nodes, nodes, batch_id_node])
        return MLP(**output_mlp)(pooled)

    if output_embedding == "node":
        # Node labeling
        return GraphMLP(**output_mlp)([nodes, batch_id_node, count_nodes])

    raise ValueError("Unsupported output embedding for mode `RGCN`")
@@ -245,8 +244,7 @@ def make_crystal_model(inputs: list = None, input_tensor_type (str): Input type of graph tensor. Default is "padded". cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer. input_embedding (dict): Deprecated in favour of input_node_embedding etc. - input_node_embedding (dict): Dictionary of embedding arguments for nodes etc. - unpacked in :obj:`Embedding` layers. + input_node_embedding (dict): Dictionary of embedding arguments for nodes unpacked in :obj:`Embedding` layers. make_distance (bool): Whether input is distance or coordinates at in place of edges. expand_distance (bool): If the edge input are actual edges or node coordinates instead that are expanded to form edges with a gauss distance basis given edge indices. Expansion uses `gauss_args`. diff --git a/kgcnn/molecule/dynamics/base.py b/kgcnn/molecule/dynamics/base.py index e09fa41d..45b96689 100644 --- a/kgcnn/molecule/dynamics/base.py +++ b/kgcnn/molecule/dynamics/base.py @@ -1,5 +1,6 @@ import time import keras as ks +from keras import ops import numpy as np from typing import Union, List, Callable, Dict from kgcnn.data.base import MemoryGraphList @@ -146,7 +147,7 @@ def __call__(self, graph_list: MemoryGraphList) -> MemoryGraphList: output_list = [] for i in range(num_samples): temp_dict = { - key: np.array(value[i]) for key, value in tensor_dict.items() + key: ops.convert_to_numpy(value[i]) for key, value in tensor_dict.items() } temp_dict = GraphDict(temp_dict) for mp in self.graph_postprocessors: diff --git a/notebooks/workflow_qm_regression.ipynb b/notebooks/workflow_qm_regression.ipynb index e7c87a5b..7f0b2860 100644 --- a/notebooks/workflow_qm_regression.ipynb +++ b/notebooks/workflow_qm_regression.ipynb @@ -291,7 +291,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "7a24751c", "metadata": {}, "outputs": [ @@ -299,7 +299,607 @@ "name": "stdout", "output_type": "stream", "text": [ - "Running training on fold: 0\n" + 
"Running training on fold: 0\n" + ] + }, + { + "data": { + "text/html": [ + "
Model: \"Schnet\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"Schnet\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                   Output Shape                   Param #  Connected to                   ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ node_number (InputLayer)      │ (None, None)              │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ range_indices (InputLayer)    │ (None, None, 2)           │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ total_nodes (InputLayer)      │ (None)                    │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ total_ranges (InputLayer)     │ (None)                    │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_coordinates (InputLayer) │ (None, None, 3)           │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_indices_to_disj… │ [(None), (2, None),       │           0 │ node_number[0][0],             │\n",
+       "│ (CastBatchedIndicesToDisjoin… │ (None), (None), (None),   │             │ range_indices[0][0],           │\n",
+       "│                               │ (None), (None), (None)]   │             │ total_nodes[0][0],             │\n",
+       "│                               │                           │             │ total_ranges[0][0]             │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_attributes_to_d… │ [(None, 3), (None),       │           0 │ node_coordinates[0][0],        │\n",
+       "│ (CastBatchedAttributesToDisj… │ (None), (None)]           │             │ total_nodes[0][0]              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_attributes_to_d… │ [(None), (None), (None),  │           0 │ node_number[0][0],             │\n",
+       "│ (CastBatchedAttributesToDisj… │ (None)]                   │             │ total_nodes[0][0]              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_position (NodePosition)  │ [(None, 3), (None, 3)]    │           0 │ cast_batched_attributes_to_di… │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ embedding (Embedding)         │ (None, 64)                │       6,080 │ cast_batched_attributes_to_di… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_distance_euclidean       │ (None, 1)                 │           0 │ node_position[0][0],           │\n",
+       "│ (NodeDistanceEuclidean)       │                           │             │ node_position[0][1]            │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ dense (Dense)                 │ (None, 128)               │       8,320 │ embedding[0][0]                │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ gauss_basis_layer             │ (None, 20)                │           0 │ node_distance_euclidean[0][0]  │\n",
+       "│ (GaussBasisLayer)             │                           │             │                                │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction           │ (None, 128)               │      68,608 │ dense[0][0],                   │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer[0][0],       │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_1         │ (None, 128)               │      68,608 │ sch_net_interaction[0][0],     │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer[0][0],       │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_2         │ (None, 128)               │      68,608 │ sch_net_interaction_1[0][0],   │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer[0][0],       │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_3         │ (None, 128)               │      68,608 │ sch_net_interaction_2[0][0],   │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer[0][0],       │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ mlp (MLP)                     │ (None, 1)                 │      24,833 │ sch_net_interaction_3[0][0],   │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ pooling_nodes (PoolingNodes)  │ (None, 1)                 │           0 │ cast_batched_indices_to_disjo… │\n",
+       "│                               │                           │             │ mlp[0][0],                     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_disjoint_to_batched_gra… │ (None, 1)                 │           0 │ pooling_nodes[0][0]            │\n",
+       "│ (CastDisjointToBatchedGraphS… │                           │             │                                │\n",
+       "└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ node_number (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ range_indices (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ total_nodes (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ total_ranges (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_coordinates (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_indices_to_disj… │ [(\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;34m2\u001b[0m, 
\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_number[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedIndicesToDisjoin…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ │ range_indices[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ total_ranges[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_attributes_to_d… │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_coordinates[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedAttributesToDisj…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_attributes_to_d… │ [(\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_number[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedAttributesToDisj…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_position (\u001b[38;5;33mNodePosition\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, 
\u001b[38;5;34m3\u001b[0m)] │ \u001b[38;5;34m0\u001b[0m │ cast_batched_attributes_to_di… │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ embedding (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m6,080\u001b[0m │ cast_batched_attributes_to_di… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_distance_euclidean │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_position[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mNodeDistanceEuclidean\u001b[0m) │ │ │ node_position[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m8,320\u001b[0m │ embedding[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ gauss_basis_layer │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m20\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_distance_euclidean[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mGaussBasisLayer\u001b[0m) │ │ │ │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ dense[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ 
gauss_basis_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ mlp (\u001b[38;5;33mMLP\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m24,833\u001b[0m │ sch_net_interaction_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ 
cast_batched_indices_to_disjo… │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ pooling_nodes (\u001b[38;5;33mPoolingNodes\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ cast_batched_indices_to_disjo… │\n", + "│ │ │ │ mlp[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_disjoint_to_batched_gra… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ pooling_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mCastDisjointToBatchedGraphS…\u001b[0m │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 313,665 (1.20 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m313,665\u001b[0m (1.20 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 313,665 (1.20 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m313,665\u001b[0m (1.20 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Compiled with jit: False\n", + "Print Time for training: 1:46:01.859375\n", + "Running training on fold: 1\n" + ] + }, + { + "data": { + "text/html": [ + "
Model: \"Schnet\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"Schnet\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                   Output Shape                   Param #  Connected to                   ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ node_number (InputLayer)      │ (None, None)              │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ range_indices (InputLayer)    │ (None, None, 2)           │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ total_nodes (InputLayer)      │ (None)                    │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ total_ranges (InputLayer)     │ (None)                    │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_coordinates (InputLayer) │ (None, None, 3)           │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_indices_to_disj… │ [(None), (2, None),       │           0 │ node_number[0][0],             │\n",
+       "│ (CastBatchedIndicesToDisjoin… │ (None), (None), (None),   │             │ range_indices[0][0],           │\n",
+       "│                               │ (None), (None), (None)]   │             │ total_nodes[0][0],             │\n",
+       "│                               │                           │             │ total_ranges[0][0]             │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_attributes_to_d… │ [(None, 3), (None),       │           0 │ node_coordinates[0][0],        │\n",
+       "│ (CastBatchedAttributesToDisj… │ (None), (None)]           │             │ total_nodes[0][0]              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_attributes_to_d… │ [(None), (None), (None),  │           0 │ node_number[0][0],             │\n",
+       "│ (CastBatchedAttributesToDisj… │ (None)]                   │             │ total_nodes[0][0]              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_position_1               │ [(None, 3), (None, 3)]    │           0 │ cast_batched_attributes_to_di… │\n",
+       "│ (NodePosition)                │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ embedding_1 (Embedding)       │ (None, 64)                │       6,080 │ cast_batched_attributes_to_di… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_distance_euclidean_1     │ (None, 1)                 │           0 │ node_position_1[0][0],         │\n",
+       "│ (NodeDistanceEuclidean)       │                           │             │ node_position_1[0][1]          │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ dense_21 (Dense)              │ (None, 128)               │       8,320 │ embedding_1[0][0]              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ gauss_basis_layer_1           │ (None, 20)                │           0 │ node_distance_euclidean_1[0][ │\n",
+       "│ (GaussBasisLayer)             │                           │             │                                │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_4         │ (None, 128)               │      68,608 │ dense_21[0][0],                │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_1[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_5         │ (None, 128)               │      68,608 │ sch_net_interaction_4[0][0],   │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_1[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_6         │ (None, 128)               │      68,608 │ sch_net_interaction_5[0][0],   │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_1[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_7         │ (None, 128)               │      68,608 │ sch_net_interaction_6[0][0],   │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_1[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ mlp_1 (MLP)                   │ (None, 1)                 │      24,833 │ sch_net_interaction_7[0][0],   │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ pooling_nodes_1               │ (None, 1)                 │           0 │ cast_batched_indices_to_disjo… │\n",
+       "│ (PoolingNodes)                │                           │             │ mlp_1[0][0],                   │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_disjoint_to_batched_gra… │ (None, 1)                 │           0 │ pooling_nodes_1[0][0]          │\n",
+       "│ (CastDisjointToBatchedGraphS… │                           │             │                                │\n",
+       "└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ node_number (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ range_indices (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ total_nodes (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ total_ranges (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_coordinates (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_indices_to_disj… │ [(\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;34m2\u001b[0m, 
\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_number[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedIndicesToDisjoin…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ │ range_indices[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ total_ranges[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_attributes_to_d… │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_coordinates[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedAttributesToDisj…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_attributes_to_d… │ [(\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_number[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedAttributesToDisj…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_position_1 │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m)] │ \u001b[38;5;34m0\u001b[0m │ 
cast_batched_attributes_to_di… │\n", + "│ (\u001b[38;5;33mNodePosition\u001b[0m) │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ embedding_1 (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m6,080\u001b[0m │ cast_batched_attributes_to_di… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_distance_euclidean_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_position_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mNodeDistanceEuclidean\u001b[0m) │ │ │ node_position_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ dense_21 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m8,320\u001b[0m │ embedding_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ gauss_basis_layer_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m20\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_distance_euclidean_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", + "│ (\u001b[38;5;33mGaussBasisLayer\u001b[0m) │ │ │ │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ dense_21[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ 
gauss_basis_layer_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_5 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_6 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_7 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_6[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ mlp_1 (\u001b[38;5;33mMLP\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m24,833\u001b[0m │ sch_net_interaction_7[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ 
cast_batched_indices_to_disjo… │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ pooling_nodes_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ cast_batched_indices_to_disjo… │\n", + "│ (\u001b[38;5;33mPoolingNodes\u001b[0m) │ │ │ mlp_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_disjoint_to_batched_gra… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ pooling_nodes_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mCastDisjointToBatchedGraphS…\u001b[0m │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 313,665 (1.20 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m313,665\u001b[0m (1.20 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 313,665 (1.20 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m313,665\u001b[0m (1.20 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Compiled with jit: False\n", + "Print Time for training: 1:44:55.859375\n", + "Running training on fold: 2\n" + ] + }, + { + "data": { + "text/html": [ + "
Model: \"Schnet\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"Schnet\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                   Output Shape                   Param #  Connected to                   ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ node_number (InputLayer)      │ (None, None)              │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ range_indices (InputLayer)    │ (None, None, 2)           │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ total_nodes (InputLayer)      │ (None)                    │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ total_ranges (InputLayer)     │ (None)                    │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_coordinates (InputLayer) │ (None, None, 3)           │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_indices_to_disj… │ [(None), (2, None),       │           0 │ node_number[0][0],             │\n",
+       "│ (CastBatchedIndicesToDisjoin… │ (None), (None), (None),   │             │ range_indices[0][0],           │\n",
+       "│                               │ (None), (None), (None)]   │             │ total_nodes[0][0],             │\n",
+       "│                               │                           │             │ total_ranges[0][0]             │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_attributes_to_d… │ [(None, 3), (None),       │           0 │ node_coordinates[0][0],        │\n",
+       "│ (CastBatchedAttributesToDisj… │ (None), (None)]           │             │ total_nodes[0][0]              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_attributes_to_d… │ [(None), (None), (None),  │           0 │ node_number[0][0],             │\n",
+       "│ (CastBatchedAttributesToDisj… │ (None)]                   │             │ total_nodes[0][0]              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_position_2               │ [(None, 3), (None, 3)]    │           0 │ cast_batched_attributes_to_di… │\n",
+       "│ (NodePosition)                │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ embedding_2 (Embedding)       │ (None, 64)                │       6,080 │ cast_batched_attributes_to_di… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_distance_euclidean_2     │ (None, 1)                 │           0 │ node_position_2[0][0],         │\n",
+       "│ (NodeDistanceEuclidean)       │                           │             │ node_position_2[0][1]          │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ dense_42 (Dense)              │ (None, 128)               │       8,320 │ embedding_2[0][0]              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ gauss_basis_layer_2           │ (None, 20)                │           0 │ node_distance_euclidean_2[0][ │\n",
+       "│ (GaussBasisLayer)             │                           │             │                                │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_8         │ (None, 128)               │      68,608 │ dense_42[0][0],                │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_2[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_9         │ (None, 128)               │      68,608 │ sch_net_interaction_8[0][0],   │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_2[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_10        │ (None, 128)               │      68,608 │ sch_net_interaction_9[0][0],   │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_2[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_11        │ (None, 128)               │      68,608 │ sch_net_interaction_10[0][0],  │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_2[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ mlp_2 (MLP)                   │ (None, 1)                 │      24,833 │ sch_net_interaction_11[0][0],  │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ pooling_nodes_2               │ (None, 1)                 │           0 │ cast_batched_indices_to_disjo… │\n",
+       "│ (PoolingNodes)                │                           │             │ mlp_2[0][0],                   │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_disjoint_to_batched_gra… │ (None, 1)                 │           0 │ pooling_nodes_2[0][0]          │\n",
+       "│ (CastDisjointToBatchedGraphS… │                           │             │                                │\n",
+       "└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ node_number (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ range_indices (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ total_nodes (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ total_ranges (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_coordinates (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_indices_to_disj… │ [(\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;34m2\u001b[0m, 
\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_number[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedIndicesToDisjoin…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ │ range_indices[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ total_ranges[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_attributes_to_d… │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_coordinates[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedAttributesToDisj…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_attributes_to_d… │ [(\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_number[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedAttributesToDisj…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_position_2 │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m)] │ \u001b[38;5;34m0\u001b[0m │ 
cast_batched_attributes_to_di… │\n", + "│ (\u001b[38;5;33mNodePosition\u001b[0m) │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ embedding_2 (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m6,080\u001b[0m │ cast_batched_attributes_to_di… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_distance_euclidean_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_position_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mNodeDistanceEuclidean\u001b[0m) │ │ │ node_position_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ dense_42 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m8,320\u001b[0m │ embedding_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ gauss_basis_layer_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m20\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_distance_euclidean_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", + "│ (\u001b[38;5;33mGaussBasisLayer\u001b[0m) │ │ │ │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_8 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ dense_42[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ 
gauss_basis_layer_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_9 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_8[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_10 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_9[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_11 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_10[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ mlp_2 (\u001b[38;5;33mMLP\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m24,833\u001b[0m │ sch_net_interaction_11[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ 
cast_batched_indices_to_disjo… │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ pooling_nodes_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ cast_batched_indices_to_disjo… │\n", + "│ (\u001b[38;5;33mPoolingNodes\u001b[0m) │ │ │ mlp_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_disjoint_to_batched_gra… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ pooling_nodes_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mCastDisjointToBatchedGraphS…\u001b[0m │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 313,665 (1.20 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m313,665\u001b[0m (1.20 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 313,665 (1.20 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m313,665\u001b[0m (1.20 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Compiled with jit: False\n", + "Print Time for training: 1:45:27.593750\n", + "Running training on fold: 3\n" ] }, { @@ -342,44 +942,44 @@ "│ cast_batched_attributes_to_d… │ [(None), (None), (None), │ 0 │ node_number[0][0], │\n", "│ (CastBatchedAttributesToDisj… │ (None)] │ │ total_nodes[0][0] │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ node_position (NodePosition) │ [(None, 3), (None, 3)] │ 0 │ cast_batched_attributes_to_di… │\n", - "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "│ node_position_3 │ [(None, 3), (None, 3)] │ 0 │ cast_batched_attributes_to_di… │\n", + "│ (NodePosition) │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ embedding (Embedding) │ (None, 64) │ 6,080 │ cast_batched_attributes_to_di… │\n", + "│ embedding_3 (Embedding) │ (None, 64) │ 6,080 │ cast_batched_attributes_to_di… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ node_distance_euclidean │ (None, 1) │ 0 │ node_position[0][0], │\n", - "│ (NodeDistanceEuclidean) │ │ │ node_position[0][1] │\n", + "│ node_distance_euclidean_3 │ (None, 1) │ 0 │ node_position_3[0][0], │\n", + "│ (NodeDistanceEuclidean) │ │ │ node_position_3[0][1] │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ dense (Dense) │ (None, 128) │ 8,320 │ embedding[0][0] │\n", + "│ dense_63 (Dense) │ (None, 128) │ 8,320 │ embedding_3[0][0] │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", 
- "│ gauss_basis_layer │ (None, 20) │ 0 │ node_distance_euclidean[0][0] │\n", + "│ gauss_basis_layer_3 │ (None, 20) │ 0 │ node_distance_euclidean_3[0][ │\n", "│ (GaussBasisLayer) │ │ │ │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ sch_net_interaction │ (None, 128) │ 68,608 │ dense[0][0], │\n", - "│ (SchNetInteraction) │ │ │ gauss_basis_layer[0][0], │\n", + "│ sch_net_interaction_12 │ (None, 128) │ 68,608 │ dense_63[0][0], │\n", + "│ (SchNetInteraction) │ │ │ gauss_basis_layer_3[0][0], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ sch_net_interaction_1 │ (None, 128) │ 68,608 │ sch_net_interaction[0][0], │\n", - "│ (SchNetInteraction) │ │ │ gauss_basis_layer[0][0], │\n", + "│ sch_net_interaction_13 │ (None, 128) │ 68,608 │ sch_net_interaction_12[0][0], │\n", + "│ (SchNetInteraction) │ │ │ gauss_basis_layer_3[0][0], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ sch_net_interaction_2 │ (None, 128) │ 68,608 │ sch_net_interaction_1[0][0], │\n", - "│ (SchNetInteraction) │ │ │ gauss_basis_layer[0][0], │\n", + "│ sch_net_interaction_14 │ (None, 128) │ 68,608 │ sch_net_interaction_13[0][0], │\n", + "│ (SchNetInteraction) │ │ │ gauss_basis_layer_3[0][0], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ sch_net_interaction_3 │ (None, 128) │ 68,608 │ sch_net_interaction_2[0][0], │\n", - "│ (SchNetInteraction) │ │ │ gauss_basis_layer[0][0], │\n", + "│ sch_net_interaction_15 │ (None, 128) │ 68,608 │ sch_net_interaction_14[0][0], │\n", + "│ (SchNetInteraction) │ │ │ gauss_basis_layer_3[0][0], │\n", "│ │ │ │ cast_batched_indices_to_disjo… 
│\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ mlp (MLP) │ (None, 1) │ 24,833 │ sch_net_interaction_3[0][0], │\n", + "│ mlp_3 (MLP) │ (None, 1) │ 24,833 │ sch_net_interaction_15[0][0], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ pooling_nodes (PoolingNodes) │ (None, 1) │ 0 │ cast_batched_indices_to_disjo… │\n", - "│ │ │ │ mlp[0][0], │\n", + "│ pooling_nodes_3 │ (None, 1) │ 0 │ cast_batched_indices_to_disjo… │\n", + "│ (PoolingNodes) │ │ │ mlp_3[0][0], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ cast_disjoint_to_batched_gra… │ (None, 1) │ 0 │ pooling_nodes[0][0] │\n", + "│ cast_disjoint_to_batched_gra… │ (None, 1) │ 0 │ pooling_nodes_3[0][0] │\n", "│ (CastDisjointToBatchedGraphS… │ │ │ │\n", "└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘\n", "\n" @@ -409,44 +1009,244 @@ "│ cast_batched_attributes_to_d… │ [(\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_number[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ (\u001b[38;5;33mCastBatchedAttributesToDisj…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ node_position (\u001b[38;5;33mNodePosition\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m)] │ \u001b[38;5;34m0\u001b[0m │ cast_batched_attributes_to_di… │\n", + "│ 
node_position_3 │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m)] │ \u001b[38;5;34m0\u001b[0m │ cast_batched_attributes_to_di… │\n", + "│ (\u001b[38;5;33mNodePosition\u001b[0m) │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ embedding_3 (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m6,080\u001b[0m │ cast_batched_attributes_to_di… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_distance_euclidean_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_position_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mNodeDistanceEuclidean\u001b[0m) │ │ │ node_position_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ dense_63 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m8,320\u001b[0m │ embedding_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ gauss_basis_layer_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m20\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_distance_euclidean_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", + "│ (\u001b[38;5;33mGaussBasisLayer\u001b[0m) │ │ │ │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ sch_net_interaction_12 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ 
dense_63[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ embedding (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m6,080\u001b[0m │ cast_batched_attributes_to_di… │\n", + "│ sch_net_interaction_13 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_12[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ node_distance_euclidean │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_position[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", - "│ (\u001b[38;5;33mNodeDistanceEuclidean\u001b[0m) │ │ │ node_position[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m] │\n", + "│ sch_net_interaction_14 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_13[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ 
\u001b[38;5;34m8,320\u001b[0m │ embedding[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ sch_net_interaction_15 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_14[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ gauss_basis_layer │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m20\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_distance_euclidean[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ mlp_3 (\u001b[38;5;33mMLP\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m24,833\u001b[0m │ sch_net_interaction_15[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ pooling_nodes_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ cast_batched_indices_to_disjo… │\n", + "│ (\u001b[38;5;33mPoolingNodes\u001b[0m) │ │ │ mlp_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_disjoint_to_batched_gra… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ pooling_nodes_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mCastDisjointToBatchedGraphS…\u001b[0m │ │ │ │\n", + 
"└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 313,665 (1.20 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m313,665\u001b[0m (1.20 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 313,665 (1.20 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m313,665\u001b[0m (1.20 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Compiled with jit: False\n", + "Print Time for training: 1:45:28.109375\n", + "Running training on fold: 4\n" + ] + }, + { + "data": { + "text/html": [ + "
Model: \"Schnet\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"Schnet\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                   Output Shape                   Param #  Connected to                   ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ node_number (InputLayer)      │ (None, None)              │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ range_indices (InputLayer)    │ (None, None, 2)           │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ total_nodes (InputLayer)      │ (None)                    │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ total_ranges (InputLayer)     │ (None)                    │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_coordinates (InputLayer) │ (None, None, 3)           │           0 │ -                              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_indices_to_disj… │ [(None), (2, None),       │           0 │ node_number[0][0],             │\n",
+       "│ (CastBatchedIndicesToDisjoin… │ (None), (None), (None),   │             │ range_indices[0][0],           │\n",
+       "│                               │ (None), (None), (None)]   │             │ total_nodes[0][0],             │\n",
+       "│                               │                           │             │ total_ranges[0][0]             │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_attributes_to_d… │ [(None, 3), (None),       │           0 │ node_coordinates[0][0],        │\n",
+       "│ (CastBatchedAttributesToDisj… │ (None), (None)]           │             │ total_nodes[0][0]              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_batched_attributes_to_d… │ [(None), (None), (None),  │           0 │ node_number[0][0],             │\n",
+       "│ (CastBatchedAttributesToDisj… │ (None)]                   │             │ total_nodes[0][0]              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_position_4               │ [(None, 3), (None, 3)]    │           0 │ cast_batched_attributes_to_di… │\n",
+       "│ (NodePosition)                │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ embedding_4 (Embedding)       │ (None, 64)                │       6,080 │ cast_batched_attributes_to_di… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ node_distance_euclidean_4     │ (None, 1)                 │           0 │ node_position_4[0][0],         │\n",
+       "│ (NodeDistanceEuclidean)       │                           │             │ node_position_4[0][1]          │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ dense_84 (Dense)              │ (None, 128)               │       8,320 │ embedding_4[0][0]              │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ gauss_basis_layer_4           │ (None, 20)                │           0 │ node_distance_euclidean_4[0][ │\n",
+       "│ (GaussBasisLayer)             │                           │             │                                │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_16        │ (None, 128)               │      68,608 │ dense_84[0][0],                │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_4[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_17        │ (None, 128)               │      68,608 │ sch_net_interaction_16[0][0],  │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_4[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_18        │ (None, 128)               │      68,608 │ sch_net_interaction_17[0][0],  │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_4[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ sch_net_interaction_19        │ (None, 128)               │      68,608 │ sch_net_interaction_18[0][0],  │\n",
+       "│ (SchNetInteraction)           │                           │             │ gauss_basis_layer_4[0][0],     │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ mlp_4 (MLP)                   │ (None, 1)                 │      24,833 │ sch_net_interaction_19[0][0],  │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ pooling_nodes_4               │ (None, 1)                 │           0 │ cast_batched_indices_to_disjo… │\n",
+       "│ (PoolingNodes)                │                           │             │ mlp_4[0][0],                   │\n",
+       "│                               │                           │             │ cast_batched_indices_to_disjo… │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n",
+       "│ cast_disjoint_to_batched_gra… │ (None, 1)                 │           0 │ pooling_nodes_4[0][0]          │\n",
+       "│ (CastDisjointToBatchedGraphS… │                           │             │                                │\n",
+       "└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ node_number (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ range_indices (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ total_nodes (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ total_ranges (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_coordinates (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_indices_to_disj… │ [(\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;34m2\u001b[0m, 
\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_number[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedIndicesToDisjoin…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ │ range_indices[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ total_ranges[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_attributes_to_d… │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_coordinates[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedAttributesToDisj…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ cast_batched_attributes_to_d… │ [(\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), (\u001b[38;5;45mNone\u001b[0m), │ \u001b[38;5;34m0\u001b[0m │ node_number[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mCastBatchedAttributesToDisj…\u001b[0m │ (\u001b[38;5;45mNone\u001b[0m)] │ │ total_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_position_4 │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m)] │ \u001b[38;5;34m0\u001b[0m │ 
cast_batched_attributes_to_di… │\n", + "│ (\u001b[38;5;33mNodePosition\u001b[0m) │ │ │ cast_batched_indices_to_disjo… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ embedding_4 (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m6,080\u001b[0m │ cast_batched_attributes_to_di… │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ node_distance_euclidean_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_position_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mNodeDistanceEuclidean\u001b[0m) │ │ │ node_position_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ dense_84 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m8,320\u001b[0m │ embedding_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", + "│ gauss_basis_layer_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m20\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ node_distance_euclidean_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "│ (\u001b[38;5;33mGaussBasisLayer\u001b[0m) │ │ │ │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ sch_net_interaction │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ dense[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", - "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ 
gauss_basis_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ sch_net_interaction_16 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ dense_84[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ sch_net_interaction_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", - "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ sch_net_interaction_17 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_16[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ sch_net_interaction_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", - "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ sch_net_interaction_18 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_17[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ 
(\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ sch_net_interaction_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", - "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ sch_net_interaction_19 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m68,608\u001b[0m │ sch_net_interaction_18[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mSchNetInteraction\u001b[0m) │ │ │ gauss_basis_layer_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ mlp (\u001b[38;5;33mMLP\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m24,833\u001b[0m │ sch_net_interaction_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ mlp_4 (\u001b[38;5;33mMLP\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m24,833\u001b[0m │ sch_net_interaction_19[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ pooling_nodes (\u001b[38;5;33mPoolingNodes\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ cast_batched_indices_to_disjo… │\n", - "│ │ │ │ 
mlp[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ pooling_nodes_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ cast_batched_indices_to_disjo… │\n", + "│ (\u001b[38;5;33mPoolingNodes\u001b[0m) │ │ │ mlp_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ │ │ │ cast_batched_indices_to_disjo… │\n", "├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤\n", - "│ cast_disjoint_to_batched_gra… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ pooling_nodes[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ cast_disjoint_to_batched_gra… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ pooling_nodes_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mCastDisjointToBatchedGraphS…\u001b[0m │ │ │ │\n", "└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘\n" ] @@ -497,7 +1297,8 @@ "name": "stdout", "output_type": "stream", "text": [ - " Compiled with jit: False\n" + " Compiled with jit: False\n", + "Print Time for training: 1:45:01.765625\n" ] } ], @@ -575,10 +1376,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "c1d034b3", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from kgcnn.utils.plots import plot_train_test_loss, plot_predict_true\n", "\n",