|
| 1 | +############################################################################# |
| 2 | +# Copyright (C) 2020-2024 MEmilio |
| 3 | +# |
| 4 | +# Authors: Henrik Zunker, Maximilian Betz |
| 5 | +# |
| 6 | +# Contact: Martin J. Kuehn <[email protected]> |
| 7 | +# |
| 8 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 9 | +# you may not use this file except in compliance with the License. |
| 10 | +# You may obtain a copy of the License at |
| 11 | +# |
| 12 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 13 | +# |
| 14 | +# Unless required by applicable law or agreed to in writing, software |
| 15 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 16 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 17 | +# See the License for the specific language governing permissions and |
| 18 | +# limitations under the License. |
| 19 | +############################################################################# |
| 20 | + |
| 21 | +import datetime as dt |
| 22 | +import os.path |
| 23 | +import imageio |
| 24 | +import tempfile |
| 25 | + |
| 26 | +import numpy as np |
| 27 | +import pandas as pd |
| 28 | +import matplotlib.pyplot as plt |
| 29 | + |
| 30 | +import memilio.epidata.getPopulationData as gpd |
| 31 | +import memilio.plot.plotMap as pm |
| 32 | +from memilio.epidata import geoModificationGermany as geoger |
| 33 | +import memilio.epidata.progress_indicator as progind |
| 34 | +import warnings |
| 35 | +warnings.simplefilter(action='ignore', category=FutureWarning) |
| 36 | + |
| 37 | + |
| 38 | +def create_plot_map(day, filename, files_input, output_path, compartments, file_format='h5', relative=False, age_groups={0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'}): |
| 39 | + """! Plots region-specific information for a single day of the simulation. |
| 40 | + @param[in] day Day of the simulation. |
| 41 | + @param[in] filename Name of the file to be created. |
| 42 | + @param[in] files_input Dictionary of input files. |
| 43 | + @param[in] output_path Output path for the figure. |
| 44 | + @param[in] compartments List of compartments to be plotted. |
| 45 | + @param[in] file_format Format of the file to be created. Either 'h5' or 'json'. |
| 46 | + @param[in] relative Defines if data should be scaled relative to population. |
| 47 | + @param[in] age_groups Dictionary of age groups to be considered. |
| 48 | + """ |
| 49 | + |
| 50 | + if len(age_groups) == 6: |
| 51 | + filter_age = None |
| 52 | + else: |
| 53 | + if file_format == 'json': |
| 54 | + filter_age = [val for val in age_groups.values()] |
| 55 | + else: |
| 56 | + filter_age = ['Group' + str(key) for key in age_groups.keys()] |
| 57 | + |
| 58 | + # In file_input there can be two different files. When we enter two files, |
| 59 | + # both files are plotted side by side in the same figure. |
| 60 | + file_index = 0 |
| 61 | + for file in files_input.values(): |
| 62 | + |
| 63 | + df = pm.extract_data( |
| 64 | + file, region_spec=None, column=None, date=day, |
| 65 | + filters={'Group': filter_age, 'InfectionState': compartments}, |
| 66 | + file_format=file_format) |
| 67 | + |
| 68 | + if relative: |
| 69 | + |
| 70 | + try: |
| 71 | + population = pd.read_json( |
| 72 | + 'data/pydata/Germany/county_current_population.json') |
| 73 | + # pandas>1.5 raise FileNotFoundError instead of ValueError |
| 74 | + except (ValueError, FileNotFoundError): |
| 75 | + print( |
| 76 | + "Population data was not found. Downloading it from the internet.") |
| 77 | + population = gpd.get_population_data( |
| 78 | + read_data=False, file_format=file_format, |
| 79 | + out_folder='data/pydata/Germany/', no_raw=True, merge_eisenach=True) |
| 80 | + |
| 81 | + # For fitting of different age groups we need format ">X". |
| 82 | + age_group_values = list(age_groups.values()) |
| 83 | + age_group_values[-1] = age_group_values[-1].replace( |
| 84 | + '80+', '>79') |
| 85 | + # scale data |
| 86 | + df = pm.scale_dataframe_relative( |
| 87 | + df, age_group_values, population) |
| 88 | + |
| 89 | + if file_index == 0: |
| 90 | + dfs_all = pd.DataFrame(df.iloc[:, 0]) |
| 91 | + |
| 92 | + dfs_all[df.columns[-1] + ' ' + str(file_index)] = df[df.columns[-1]] |
| 93 | + file_index += 1 |
| 94 | + |
| 95 | + dfs_all = dfs_all.apply(pd.to_numeric, errors='coerce') |
| 96 | + |
| 97 | + dfs_all_sorted = dfs_all.sort_values(by='Region') |
| 98 | + dfs_all_sorted = dfs_all_sorted.reset_index(drop=True) |
| 99 | + |
| 100 | + min_val = dfs_all_sorted[dfs_all_sorted.columns[1:]].min().min() |
| 101 | + max_val = dfs_all_sorted[dfs_all_sorted.columns[1:]].max().max() |
| 102 | + |
| 103 | + pm.plot_map( |
| 104 | + dfs_all_sorted, scale_colors=np.array([min_val, max_val]), |
| 105 | + legend=['', ''], |
| 106 | + title='Synthetic data (relative) day ' + f'{day:2d}', plot_colorbar=True, |
| 107 | + output_path=output_path, |
| 108 | + fig_name=filename, dpi=300, |
| 109 | + outercolor=[205 / 255, 238 / 255, 251 / 255]) |
| 110 | + |
| 111 | + |
| 112 | +def create_gif_map_plot(input_data, output_dir, compartments, filename="simulation", relative=True, age_groups={0: '0-4', 1: '5-14', 2: '15-34', |
| 113 | + 3: '35-59', 4: '60-79', 5: '80+'}): |
| 114 | + """! Creates a gif of the simulation results by calling create_plot_map for each day of the simulation and then |
| 115 | + storing the single plots in a temporary directory. Currently only works for the results created by the parameter study. |
| 116 | +
|
| 117 | + @param[in] input_data Path to the input data. The Path should contain a file called 'Results' which contains |
| 118 | + the simulation results. This is the default output folder of the parameter study. |
| 119 | + @param[in] output_dir Path where the gif should be stored. |
| 120 | + @param[in] filename Name of the temporary file. |
| 121 | + @param[in] relative Defines if data should be scaled relative to population. |
| 122 | + @param[in] age_groups Dictionary of age groups to be considered. |
| 123 | + """ |
| 124 | + |
| 125 | + files_input = {'Data set': input_data + '/Results'} |
| 126 | + file_format = 'h5' |
| 127 | + |
| 128 | + if len(age_groups) == 6: |
| 129 | + filter_age = None |
| 130 | + else: |
| 131 | + filter_age = ['Group' + str(key) for key in age_groups.keys()] |
| 132 | + |
| 133 | + num_days = pm.extract_time_steps( |
| 134 | + files_input[list(files_input.keys())[0]], file_format=file_format) |
| 135 | + |
| 136 | + # create gif |
| 137 | + frames = [] |
| 138 | + with progind.Percentage() as indicator: |
| 139 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 140 | + for day in range(0, num_days): |
| 141 | + create_plot_map(day, filename, files_input, tmpdirname, |
| 142 | + compartments, file_format, relative, age_groups) |
| 143 | + |
| 144 | + image = imageio.v2.imread( |
| 145 | + os.path.join(tmpdirname, filename + ".png")) |
| 146 | + frames.append(image) |
| 147 | + |
| 148 | + # Close the current figure to free up memory |
| 149 | + plt.close('all') |
| 150 | + indicator.set_progress((day+1)/num_days) |
| 151 | + |
| 152 | + imageio.mimsave(os.path.join(output_dir, filename + '.gif'), |
| 153 | + frames, # array of input frames |
| 154 | + duration=0.2, # duration of each frame in seconds |
| 155 | + loop=0) # optional: frames per second |
0 commit comments