Skip to content

Commit c469590

Browse files
HenrZu“HenrikMaxBetzDLR
authored
744 Create gif from simulation results (#760)
Co-authored-by: Henrik <[email protected]> Co-authored-by: MaxBetzDLR <[email protected]>
1 parent f7b8d38 commit c469590

File tree

7 files changed

+309
-18
lines changed

7 files changed

+309
-18
lines changed

pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import memilio.simulation as mio
55
import memilio.simulation.secir as secir
6+
import memilio.plot.createGIF as mp
67

78
from enum import Enum
89
from memilio.simulation.secir import (Model, Simulation,
@@ -426,7 +427,7 @@ def get_graph(self, end_date):
426427

427428
return graph
428429

429-
def run(self, num_days_sim, num_runs=10, save_graph=True):
430+
def run(self, num_days_sim, num_runs=10, save_graph=True, create_gif=True):
430431
mio.set_log_level(mio.LogLevel.Warning)
431432
end_date = self.start_date + datetime.timedelta(days=num_days_sim)
432433

@@ -459,6 +460,11 @@ def run(self, num_days_sim, num_runs=10, save_graph=True):
459460
secir.save_results(
460461
ensemble_results, ensemble_params, node_ids, self.results_dir,
461462
save_single_runs, save_percentiles)
463+
if create_gif:
464+
# any compartments in the model (see InfectionStates)
465+
compartments = [c for c in range(1, 8)]
466+
mp.create_gif_map_plot(
467+
self.results_dir + "/p75", self.results_dir, compartments)
462468
return 0
463469

464470

@@ -468,5 +474,5 @@ def run(self, num_days_sim, num_runs=10, save_graph=True):
468474
data_dir=os.path.join(file_path, "../../../data"),
469475
start_date=datetime.date(year=2020, month=12, day=12),
470476
results_dir=os.path.join(file_path, "../../../results_secir"))
471-
num_days_sim = 30
477+
num_days_sim = 50
472478
sim.run(num_days_sim, num_runs=2)

pycode/memilio-plot/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ Required python packages:
5959
- mapclassify
6060
- geopandas
6161
- h5py
62+
- imageio
63+
- datetime
6264

6365
Testing and Coverage
6466
--------------------
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#############################################################################
2+
# Copyright (C) 2020-2024 MEmilio
3+
#
4+
# Authors: Henrik Zunker, Maximilian Betz
5+
#
6+
# Contact: Martin J. Kuehn <[email protected]>
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License");
9+
# you may not use this file except in compliance with the License.
10+
# You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
#############################################################################
20+
21+
import datetime as dt
22+
import os.path
23+
import imageio
24+
import tempfile
25+
26+
import numpy as np
27+
import pandas as pd
28+
import matplotlib.pyplot as plt
29+
30+
import memilio.epidata.getPopulationData as gpd
31+
import memilio.plot.plotMap as pm
32+
from memilio.epidata import geoModificationGermany as geoger
33+
import memilio.epidata.progress_indicator as progind
34+
import warnings
35+
warnings.simplefilter(action='ignore', category=FutureWarning)
36+
37+
38+
def create_plot_map(day, filename, files_input, output_path, compartments, file_format='h5', relative=False, age_groups={0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'}):
    """! Plots region-specific information for a single day of the simulation.

    The data of every file in files_input is extracted for the given day,
    optionally scaled relative to the local population, collected as one
    column per file in a common data frame and handed to plotMap.plot_map.

    @param[in] day Day of the simulation.
    @param[in] filename Name of the file to be created.
    @param[in] files_input Dictionary of input files. Must not be empty.
    @param[in] output_path Output path for the figure.
    @param[in] compartments List of compartments to be plotted.
    @param[in] file_format Format of the file to be created. Either 'h5' or 'json'.
    @param[in] relative Defines if data should be scaled relative to population.
    @param[in] age_groups Dictionary of age groups to be considered.
    @throws ValueError If files_input contains no input file.
    """

    # Without any input file there is nothing to plot; fail with a clear
    # message instead of an UnboundLocalError further below.
    if not files_input:
        raise ValueError("files_input must contain at least one input file.")

    # All six age groups selected means no age filtering is necessary.
    # age_groups is only read here, never modified.
    if len(age_groups) == 6:
        filter_age = None
    elif file_format == 'json':
        filter_age = list(age_groups.values())
    else:
        filter_age = ['Group' + str(key) for key in age_groups]

    # In files_input there can be two different files. When we enter two files,
    # both files are plotted side by side in the same figure.
    for file_index, file in enumerate(files_input.values()):

        df = pm.extract_data(
            file, region_spec=None, column=None, date=day,
            filters={'Group': filter_age, 'InfectionState': compartments},
            file_format=file_format)

        if relative:

            try:
                population = pd.read_json(
                    'data/pydata/Germany/county_current_population.json')
            # pandas>1.5 raise FileNotFoundError instead of ValueError
            except (ValueError, FileNotFoundError):
                print(
                    "Population data was not found. Downloading it from the internet.")
                population = gpd.get_population_data(
                    read_data=False, file_format=file_format,
                    out_folder='data/pydata/Germany/', no_raw=True, merge_eisenach=True)

            # For fitting of different age groups we need format ">X".
            age_group_values = list(age_groups.values())
            age_group_values[-1] = age_group_values[-1].replace(
                '80+', '>79')
            # scale data
            df = pm.scale_dataframe_relative(
                df, age_group_values, population)

        # The first file provides the region column of the combined frame.
        if file_index == 0:
            dfs_all = pd.DataFrame(df.iloc[:, 0])

        # One value column per input file, made unique by the file index.
        dfs_all[df.columns[-1] + ' ' + str(file_index)] = df[df.columns[-1]]

    dfs_all = dfs_all.apply(pd.to_numeric, errors='coerce')

    dfs_all_sorted = dfs_all.sort_values(by='Region')
    dfs_all_sorted = dfs_all_sorted.reset_index(drop=True)

    # Common color scale over all value columns (all columns but the first).
    value_columns = dfs_all_sorted[dfs_all_sorted.columns[1:]]
    min_val = value_columns.min().min()
    max_val = value_columns.max().max()

    # Name the scaling that was actually applied in the title.
    scale_label = 'relative' if relative else 'absolute'

    pm.plot_map(
        dfs_all_sorted, scale_colors=np.array([min_val, max_val]),
        legend=['', ''],
        title='Synthetic data (' + scale_label + ') day ' + f'{day:2d}',
        plot_colorbar=True,
        output_path=output_path,
        fig_name=filename, dpi=300,
        outercolor=[205 / 255, 238 / 255, 251 / 255])
110+
111+
112+
def create_gif_map_plot(input_data, output_dir, compartments, filename="simulation", relative=True, age_groups={0: '0-4', 1: '5-14', 2: '15-34',
                                                                                                                3: '35-59', 4: '60-79', 5: '80+'}):
    """! Creates a gif of the simulation results by calling create_plot_map for each day of the simulation and then
    storing the single plots in a temporary directory. Currently only works for the results created by the parameter study.

    @param[in] input_data Path to the input data. The Path should contain a file called 'Results' which contains
        the simulation results. This is the default output folder of the parameter study.
    @param[in] output_dir Path where the gif should be stored.
    @param[in] compartments List of compartments to be plotted.
    @param[in] filename Name (without extension) of the temporary single-day plots and of the resulting gif.
    @param[in] relative Defines if data should be scaled relative to population.
    @param[in] age_groups Dictionary of age groups to be considered.
    """

    files_input = {'Data set': os.path.join(input_data, 'Results')}
    file_format = 'h5'

    num_days = pm.extract_time_steps(
        files_input['Data set'], file_format=file_format)

    # create gif
    frames = []
    with progind.Percentage() as indicator:
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Every day is plotted to the same temporary png, which is read
            # back into memory before the next day overwrites it.
            frame_path = os.path.join(tmpdirname, filename + ".png")
            for day in range(0, num_days):
                create_plot_map(day, filename, files_input, tmpdirname,
                                compartments, file_format, relative, age_groups)

                frames.append(imageio.v2.imread(frame_path))

                # Close the current figure to free up memory
                plt.close('all')
                indicator.set_progress((day+1)/num_days)

    # NOTE(review): the unit of 'duration' (seconds vs. milliseconds) may
    # depend on the installed imageio version - confirm.
    imageio.mimsave(os.path.join(output_dir, filename + '.gif'),
                    frames,  # array of input frames
                    duration=0.2,  # duration of each frame
                    loop=0)  # number of loops; 0 repeats the gif forever

pycode/memilio-plot/memilio/plot/plotMap.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import pandas as pd
2828
from matplotlib import pyplot as plt
2929
from matplotlib.gridspec import GridSpec
30+
import matplotlib.colors as mcolors
3031

3132
from memilio.epidata import geoModificationGermany as geoger
3233
from memilio.epidata import getDataIntoPandasDataFrame as gd
@@ -143,11 +144,14 @@ def extract_data(
143144

144145
# Set no filtering if filters were set to None.
145146
if filters == None:
146-
filters['Group'] = list(h5file[regions[i]].keys())[
147+
filters['Group'] = list(h5file[regions[0]].keys())[
147148
:-2] # Remove 'Time' and 'Total'.
148149
filters['InfectionState'] = list(
149150
range(h5file[regions[i]]['Group1'].shape[1]))
150151

152+
if filters['Group'] == None:
153+
filters['Group'] = list(h5file[regions[0]].keys())[:-2]
154+
151155
InfectionStateList = [j for j in filters['InfectionState']]
152156

153157
# Create data frame to store results to plot.
@@ -192,8 +196,8 @@ def extract_data(
192196

193197
k += 1
194198
else:
195-
raise gd.ValueError(
196-
"Time point not found for region " + str(regions[i]) + ".")
199+
raise gd.DataError(
200+
"Time point " + str(date) + " not found for region " + str(regions[i]) + ".")
197201

198202
# Aggregated or matrix output.
199203
if output == 'sum':
@@ -207,6 +211,29 @@ def extract_data(
207211
raise gd.DataError("Data could not be read in.")
208212

209213

214+
def extract_time_steps(file, file_format='json'):
    """! Reads data from a general json or specific hdf5 file as output by the
    MEmilio simulation framework and extracts the number of days used.

    @param[in] file Path and filename of file to be read in, relative from current
        directory. The extension is appended from file_format.
    @param[in] file_format File format; either json or h5.
    @return Number of time steps.
    @throws ValueError If file_format is neither 'json' nor 'h5'.
    """
    input_file = os.path.join(os.getcwd(), str(file))

    if file_format == 'json':
        df = pd.read_json(input_file + '.' + file_format)
        # Without a 'Date' column the file is treated as a single time step.
        if 'Date' in df.columns:
            return df['Date'].nunique()
        return 1

    if file_format == 'h5':
        # Context manager ensures the file handle is closed again.
        with h5py.File(input_file + '.' + file_format, 'r') as h5file:
            regions = list(h5file.keys())
            # The number of time steps is read from the first region only.
            return len(h5file[regions[0]]['Time'])

    # Previously an unknown format ended in an UnboundLocalError.
    raise ValueError(
        "Unsupported file format '" + str(file_format) +
        "'; expected 'json' or 'h5'.")
235+
236+
210237
def scale_dataframe_relative(df, age_groups, df_population):
211238
"""! Scales a population-related data frame relative to the size of the
212239
local populations or subpopulations (e.g., if not all age groups are
@@ -225,7 +252,7 @@ def scale_dataframe_relative(df, age_groups, df_population):
225252
"""
226253

227254
# Merge population data of Eisenach (if counted separately) with Wartburgkreis.
228-
if 16056 in df_population[df.columns[0]].values:
255+
if 16056 in df_population['ID_County'].values:
229256
for i in range(1, len(df_population.columns)):
230257
df_population.loc[df_population[df.columns[0]] == 16063, df_population.columns[i]
231258
] += df_population.loc[df_population.ID_County == 16056, df_population.columns[i]]
@@ -235,12 +262,12 @@ def scale_dataframe_relative(df, age_groups, df_population):
235262
columns=[df_population.columns[0]] + age_groups)
236263
# Extrapolate on oldest age group with maximumg age 100.
237264
for region_id in df.iloc[:, 0]:
238-
df_population_agegroups.loc[len(df_population_agegroups.index), :] = [region_id] + list(
239-
mdfs.fit_age_group_intervals(df_population[df_population.iloc[:, 0] == region_id].iloc[:, 2:], age_groups))
265+
df_population_agegroups.loc[len(df_population_agegroups.index), :] = [int(region_id)] + list(
266+
mdfs.fit_age_group_intervals(df_population[df_population.iloc[:, 0] == int(region_id)].iloc[:, 2:], age_groups))
240267

241268
def scale_row(elem):
242269
population_local_sum = df_population_agegroups[
243-
df_population_agegroups[df.columns[0]] == elem[0]].iloc[
270+
df_population_agegroups['ID_County'] == int(elem[0])].iloc[
244271
:, 1:].sum(axis=1)
245272
return elem['Count'] / population_local_sum.values[0]
246273

@@ -272,7 +299,8 @@ def plot_map(data: pd.DataFrame,
272299
output_path: str = '',
273300
fig_name: str = 'customPlot',
274301
dpi: int = 300,
275-
outercolor='white'):
302+
outercolor='white',
303+
log_scale=False):
276304
"""! Plots region-specific information onto a interactive html map and
277305
returning svg and png image. Allows the comparisons of a variable list of
278306
data sets.
@@ -288,12 +316,14 @@ def plot_map(data: pd.DataFrame,
288316
@param[in] fig_name Name of the figure created.
289317
@param[in] dpi Dots-per-inch value for the exported figure.
290318
@param[in] outercolor Background color of the plot image.
319+
@param[in] log_scale Defines if the colorbar is plotted in log scale.
291320
"""
292321
region_classifier = data.columns[0]
322+
region_data = data[region_classifier].to_numpy().astype(int)
293323

294324
data_columns = data.columns[1:]
295325
# Read and filter map data.
296-
if data[region_classifier].isin(geoger.get_county_ids()).all():
326+
if np.isin(region_data, geoger.get_county_ids()).all():
297327
try:
298328
map_data = geopandas.read_file(
299329
os.path.join(
@@ -342,17 +372,23 @@ def plot_map(data: pd.DataFrame,
342372
# Use top row for title.
343373
tax = fig.add_subplot(gs[0, :])
344374
tax.set_axis_off()
345-
tax.set_title(title, fontsize=18)
375+
tax.set_title(title, fontsize=16)
346376
if plot_colorbar:
347377
# Prepare colorbar.
348378
cax = fig.add_subplot(gs[1, 0])
349379
else:
350380
cax = None
351381

382+
if log_scale:
383+
norm = mcolors.LogNorm(vmin=scale_colors[0], vmax=scale_colors[1])
384+
352385
for i in range(len(data_columns)):
353386

354387
ax = fig.add_subplot(gs[:, i+2])
355-
if cax is not None:
388+
if log_scale:
389+
map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True,
390+
norm=norm)
391+
elif cax is not None:
356392
map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True,
357393
vmin=scale_colors[0], vmax=scale_colors[1])
358394
else:
@@ -364,8 +400,4 @@ def plot_map(data: pd.DataFrame,
364400
ax.set_axis_off()
365401

366402
plt.subplots_adjust(bottom=0.1)
367-
368403
plt.savefig(os.path.join(output_path, fig_name + '.png'), dpi=dpi)
369-
plt.savefig(os.path.join(output_path, fig_name + '.svg'), dpi=dpi)
370-
371-
plt.show()

0 commit comments

Comments
 (0)