Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
c92bdb8
Extrapolating and saving real world data, initialize model in data ge…
mhheger Aug 11, 2025
0ceefd4
reducing extrapolate gnn to its core functionality, adding comments i…
mhheger Aug 11, 2025
fde4b3b
Add description of workflow
mhheger Aug 11, 2025
6b79df8
Merge remote-tracking branch 'origin/main' into 1139-introduce-improv…
HenrZu Aug 11, 2025
ba90b76
init file gnn
HenrZu Aug 11, 2025
3c9bf9b
[ci skip] Updated data generation and extrapolation, code for trainin…
mhheger Aug 18, 2025
306c752
[ci skip] Include different network architectures and introduce grid …
mhheger Sep 3, 2025
bf6c5f5
[ci skip] Improve model creation and implement grid search functiona…
mhheger Sep 8, 2025
5b24fd9
Add validation checks for dataset and model parameters; enhance test …
mhheger Sep 10, 2025
614082b
Enhance data scaling functionality and validation; add tests for scal…
mhheger Sep 22, 2025
6b32716
Ading necessary imports
mhheger Sep 22, 2025
b451ded
Fix imports?
mhheger Sep 22, 2025
b486d10
Add spektral to requirements
mhheger Sep 22, 2025
ea4cd31
undo scale_data-tests
mhheger Sep 22, 2025
20274ab
Refactor imports in GNN_utils.py and update test_surrogatemodel_GNN.p…
mhheger Sep 24, 2025
35fb822
Add model building and training step to evaluate_and_train.py; update…
mhheger Sep 25, 2025
9f37f74
Update requirements
mhheger Sep 25, 2025
b51623f
Merge branch 'main' into 1139-introduce-improved-gnn-surrogate-models
HenrZu Oct 29, 2025
68fa1c0
formating and fix tests
HenrZu Oct 29, 2025
a47f11e
.
HenrZu Oct 29, 2025
b6f40bb
.
HenrZu Oct 30, 2025
5bc242e
support for py 3.8
HenrZu Oct 30, 2025
70a9b11
.
HenrZu Oct 31, 2025
4d42a2a
[ci skip] start rework data gemeration gnn
HenrZu Nov 4, 2025
8c77ef2
[ci skip] .
HenrZu Nov 4, 2025
4d72f23
[ci skip] complete rework data generation
HenrZu Nov 5, 2025
304a702
.
HenrZu Nov 5, 2025
d19c6ab
[ci skip] rework evaluate and trian
HenrZu Nov 5, 2025
d39917d
.
HenrZu Nov 5, 2025
cf59960
[ci skip] rm extrapolate gnn
HenrZu Nov 5, 2025
d177e4c
[ci skip] rework gnn utils + grid_seach
HenrZu Nov 5, 2025
223b3d5
.
HenrZu Nov 5, 2025
d3f6ca4
add rtd for gnn
HenrZu Nov 5, 2025
e706a61
readme
HenrZu Nov 5, 2025
d077cc1
rm unused functions
HenrZu Nov 26, 2025
a6371d0
.
HenrZu Nov 26, 2025
1ab7b2f
Merge remote-tracking branch 'origin/main' into 1139-introduce-improv…
HenrZu Nov 26, 2025
ec9c932
[ci skip] fix data_generation
HenrZu Nov 26, 2025
1b90974
rework some files
HenrZu Nov 26, 2025
b24e203
fix in doc
HenrZu Nov 27, 2025
f09f38a
more
HenrZu Nov 27, 2025
7d3ff08
formatting doc
HenrZu Nov 27, 2025
96ba5e6
Merge remote-tracking branch 'origin/main' into 1139-introduce-improv…
HenrZu Dec 4, 2025
141a2ee
add deps to surrogate toml
HenrZu Dec 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
329 changes: 315 additions & 14 deletions docs/source/python/m-surrogate.rst
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
.. include:: ../literature.rst

MEmilio Surrogate Model
========================

MEmilio Surrogate Model contains machine learning based surrogate models that make predictions based on the MEmilio simulation models.
Currently there are only surrogate models for ODE-type models. The simulations of these models are used for data generation.
The goal is to create a powerful tool that predicts the infection dynamics faster than a simulation of an expert model,
e.g., a metapopulation or agent-based model while still having acceptable errors with respect to the original simulations.

The package can be found in `pycode/memilio-surrogatemodel <https://github.com/SciCompMod/memilio/blob/main/pycode/memilio-surrogatemodel>`_.

For more details, we refer to: Schmidt A, Zunker H, Heinlein A, Kühn MJ. (2025). *Graph Neural Network Surrogates to leverage Mechanistic Expert Knowledge towards Reliable and Immediate Pandemic Response*. Submitted for publication. `arXiv:2411.06500 <https://arxiv.org/abs/2411.06500>`_
For more details, we refer to:

|Graph_Neural_Network_Surrogates|

Dependencies
------------
Expand All @@ -30,17 +34,19 @@ Usage
The package currently provides the following modules:

- `models`: models for different tasks
Currently we have the following models:
- `ode_secir_simple`: A simple model allowing for asymptomatic as well as symptomatic infection states not stratified by age groups.
- `ode_secir_groups`: A model allowing for asymptomatic as well as symptomatic infection states stratified by age groups and including one damping.

Each model folder contains the following files:
- `data_generation`: Data generation from expert model simulation outputs.
- `model`: Training and evaluation of the model.
- `network_architectures`: Contains multiple network architectures.
- `grid_search`: Utilities for hyperparameter optimization.
- `hyperparameter_tuning`: Scripts for tuning model hyperparameters.
Currently we have the following models:

- `ode_secir_simple`: A simple model allowing for asymptomatic as well as symptomatic infection states not stratified by age groups.
- `ode_secir_groups`: A model allowing for asymptomatic as well as symptomatic infection states stratified by age groups and including one damping.

Each model folder contains the following files:

- `data_generation`: Data generation from expert model simulation outputs.
- `model`: Training and evaluation of the model.
- `network_architectures`: Contains multiple network architectures.
- `grid_search`: Utilities for hyperparameter optimization.
- `hyperparameter_tuning`: Scripts for tuning model hyperparameters.

- `tests`: This file contains all tests.

Expand Down Expand Up @@ -163,7 +169,302 @@ The `grid_search.py` and `hyperparameter_tuning.py` modules provide tools for sy
- Visualization of hyperparameter importance
- Selection of optimal model configurations

SECIR Groups Model
------------------

To be added...

Graph Neural Network (GNN) Surrogate Models
--------------------------------------------

The Graph Neural Network (GNN) module provides advanced surrogate models that leverage spatial connectivity and age-stratified epidemiological dynamics. These models are designed for immediate and reliable pandemic response by combining mechanistic expert knowledge with machine learning efficiency.

Overview and Scientific Foundation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The GNN surrogate models are based on the research presented in:

|Graph_Neural_Network_Surrogates|

The implementation leverages the mechanistic ODE-SECIR model (see :doc:`ODE-SECIR documentation <cpp/models/osecir>`) as the underlying expert model, using Python bindings to the C++ backend for efficient simulation during data generation.

Module Structure
~~~~~~~~~~~~~~~~

The GNN module is located in `pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN>`_ and consists of:

- **data_generation.py**: Generates training and evaluation data by simulating epidemiological scenarios with the mechanistic SECIR model
- **network_architectures.py**: Defines various GNN architectures (ARMAConv, GCSConv, GATConv, GCNConv, APPNPConv) with configurable depth and channels
- **evaluate_and_train.py**: Implements training and evaluation pipelines for GNN models
- **grid_search.py**: Provides hyperparameter optimization through systematic grid search
- **GNN_utils.py**: Contains utility functions for data preprocessing, graph construction, and population data handling

Data Generation
~~~~~~~~~~~~~~~

The data generation process in ``data_generation.py`` creates graph-structured training data through mechanistic simulations. Use ``generate_data`` to run multiple simulations and persist a pickle with inputs, labels, damping info, and contact matrices:

.. code-block:: python

from memilio.surrogatemodel.GNN import data_generation
import memilio.simulation as mio

data = data_generation.generate_data(
num_runs=5,
data_dir="/path/to/memilio/data",
output_path="/tmp/generated_datasets",
input_width=5,
label_width=30,
start_date=mio.Date(2020, 10, 1),
end_date=mio.Date(2021, 10, 31),
mobility_file="commuter_mobility.txt", # or commuter_mobility_2022.txt
transform=True,
save_data=True
)

**Data Generation Workflow:**

1. **Parameter Sampling**: Randomly sample epidemiological parameters (transmission rates, incubation periods, recovery rates) from predefined distributions to create diverse scenarios.

2. **Compartment Initialization**: Initialize epidemic compartments for each age group in each region based on realistic demographic data. Compartments are initialized using shared base factors.

3. **Mobility Graph Construction**: Build a spatial graph where:

- Nodes represent geographic regions (e.g., German counties)
- Edges represent mobility connections with weights from commuting data
- Node features include age-stratified population sizes

4. **Contact Matrix Configuration**: Load and configure baseline contact matrices for different location types (home, school, work, other) stratified by age groups.

5. **Damping Application**: Apply time-varying dampings to contact matrices to simulate NPIs:

- Multiple damping periods with random start days
- Location-specific damping factors (e.g., stronger school closures, moderate workplace restrictions)
- Realistic parameter ranges based on observed intervention strengths

6. **Simulation Execution**: Run the mechanistic ODE-SECIR model using MEmilio's C++ backend through Python bindings to generate the dataset.

7. **Data Processing**: Transform simulation results into graph-structured format:

- Extract compartment time series for each node (region) and age group
- Apply logarithmic transformation for numerical stability
- Store graph topology, node features, and temporal sequences

Network Architectures
~~~~~~~~~~~~~~~~~~~~~

The ``network_architectures.py`` module provides flexible GNN model construction for supported layer types (ARMAConv, GCSConv, GATConv, GCNConv, APPNPConv).

.. code-block:: python

from memilio.surrogatemodel.GNN import network_architectures

model = network_architectures.get_model(
layer_type="GCNConv",
num_layers=3,
num_channels=64,
activation="relu",
num_output=48 # outputs per node
)


Training and Evaluation
~~~~~~~~~~~~~~~~~~~~~~~

The ``evaluate_and_train.py`` module provides the training functionality:

.. code-block:: python

from tensorflow.keras.losses import MeanAbsolutePercentageError
from tensorflow.keras.optimizers import Adam
from memilio.surrogatemodel.GNN import evaluate_and_train, network_architectures

dataset = evaluate_and_train.load_gnn_dataset(
"/tmp/generated_datasets/GNN_data_30days_3dampings_classic5.pickle",
"/path/to/memilio/data/Germany/mobility",
number_of_nodes=400
)

model = network_architectures.get_model(
layer_type="GCNConv",
num_layers=3,
num_channels=32,
activation="relu",
num_output=48
)

results = evaluate_and_train.train_and_evaluate(
data=dataset,
batch_size=32,
epochs=50,
model=model,
loss_fn=MeanAbsolutePercentageError(),
optimizer=Adam(learning_rate=0.001),
es_patience=10,
save_dir="/tmp/model_results",
save_name="gnn_model"
)

**Training Features:**

1. **Mini-batch Training**: Graph batching for efficient training on large datasets
2. **Custom Loss Functions**: MSE, MAE, MAPE, or custom compartment-weighted losses
3. **Early Stopping**: Monitors validation loss to prevent overfitting
4. **Save Best Weights**: Saves best model weights based on validation performance

**Evaluation Metrics:**

- **Mean Absolute Error (MAE)**: Average absolute prediction error per compartment
- **Mean Absolute Percentage Error (MAPE)**: Mean absolute error as percentage
- **R² Score**: Coefficient of determination for prediction quality

**Data Splitting:**

- **Training Set (70%)**: For model parameter optimization
- **Validation Set (15%)**: For hyperparameter tuning and early stopping
- **Test Set (15%)**: For final performance evaluation

Hyperparameter Optimization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``grid_search.py`` module enables systematic exploration of hyperparameter space:

.. code-block:: python

from pathlib import Path
from memilio.surrogatemodel.GNN import grid_search, evaluate_and_train

data = evaluate_and_train.create_dataset(
"/tmp/generated_datasets/GNN_data_30days_3dampings_classic5.pickle",
"/path/to/memilio/data/Germany/mobility",
number_of_nodes=400
)

parameter_grid = grid_search.generate_parameter_grid(
layer_types=["GCNConv", "GATConv"],
num_layers_options=[2, 3],
num_channels_options=[16, 32],
activation_functions=["relu", "elu"]
)

grid_search.perform_grid_search(
data=data,
parameter_grid=parameter_grid,
save_dir=str(Path("/tmp/grid_results")),
batch_size=32,
max_epochs=50,
es_patience=10,
learning_rate=0.001
)

Utility Functions
~~~~~~~~~~~~~~~~~

The ``GNN_utils.py`` module provides essential helper functions used throughout the GNN workflow:

**Data Preprocessing:**

.. code-block:: python

from memilio.surrogatemodel.GNN import GNN_utils

# Remove confirmed compartments (simplify model)
simplified_data = GNN_utils.remove_confirmed_compartments(
dataset_entries=dataset,
num_groups=6
)

# Apply logarithmic scaling
scaled_inputs, scaled_labels = GNN_utils.scale_data(
data=dataset,
transform=True
)

**Graph Construction:**

.. code-block:: python

# Create mobility graph from commuting data
graph = GNN_utils.create_mobility_graph(
mobility_dir='path/to/mobility',
num_regions=401, # German counties
county_ids=county_list,
models=models_per_region # SECIR models for each region
)

# Get baseline contact matrix
contact_matrix = GNN_utils.get_baseline_contact_matrix(
data_dir='path/to/contact_matrices'
)

Practical Usage Example
~~~~~~~~~~~~~~~~~~~~~~~

Here is a complete example workflow from data generation to model evaluation:

.. code-block:: python

import memilio.simulation as mio
from tensorflow.keras.losses import MeanAbsolutePercentageError
from tensorflow.keras.optimizers import Adam
from memilio.surrogatemodel.GNN import (
data_generation,
network_architectures,
evaluate_and_train
)

# Step 1: Generate and save training data
data_generation.generate_data(
num_runs=100,
data_dir="/path/to/memilio/data",
output_path="/tmp/generated_datasets",
input_width=5,
label_width=30,
start_date=mio.Date(2020, 10, 1),
end_date=mio.Date(2021, 10, 31),
save_data=True,
mobility_file="commuter_mobility.txt"
)

# Step 2: Load dataset and build model
dataset = evaluate_and_train.load_gnn_dataset(
"/tmp/generated_datasets/GNN_data_30days_3dampings_classic100.pickle",
"/path/to/memilio/data/Germany/mobility",
number_of_nodes=400
)

model = network_architectures.get_model(
layer_type="GCNConv",
num_layers=4,
num_channels=128,
activation="relu",
num_output=48
)

# Step 3: Train and evaluate
results = evaluate_and_train.train_and_evaluate(
data=dataset,
batch_size=32,
epochs=100,
model=model,
loss_fn=MeanAbsolutePercentageError(),
optimizer=Adam(learning_rate=0.001),
es_patience=20,
save_dir="/tmp/model_results",
save_name="gnn_weights_best"
)

**GPU Acceleration:**

- TensorFlow automatically uses GPU when available
- Spektral layers are optimized for GPU execution
- Training time can be heavily reduced with appropriate GPU hardware

Additional Resources
~~~~~~~~~~~~~~~~~~~~

**Code and Examples:**

- `GNN Module <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN>`_
- `GNN README <https://github.com/SciCompMod/memilio/blob/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/README.md>`_

**Related Documentation:**

- :doc:`MEmilio Simulation Package <m-simulation>`
Loading
Loading