generated from FLAMEGPU/FLAMEGPU2-model-template-cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_publication.py
184 lines (150 loc) · 5.89 KB
/
plot_publication.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
#! /usr/bin/env python3
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.image as mpimg
import argparse
import pathlib
# Default DPI, used as the default value of the --dpi option
DEFAULT_DPI = 300
# Default input directory for the csv files (same value as the --input-dir default)
DEFAULT_INPUT_DIR="."
# Default directory for visualisation images, used as the --vis-dir default
DEFAULT_VISUALISATION_DIR = "./sample/figures/visualisation"
# Visualisation images used in the figure (4 required, one per image panel,
# in the order the panels are filled)
VISUALISATION_IMAGE_FILENAMES = ['0.png', '350.png', '650.png', '2500.png']
# Drift csv filename from simulation output, looked up inside the input directory
DRIFT_CSV_FILENAME = "drift_perStepPerSimulationCSV.csv"
def cli(argv=None):
    """Define and parse the command line interface of the plotting script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads the arguments from sys.argv, which is the
            behaviour when called with no arguments.

    Returns:
        argparse.Namespace: with attributes output_dir, dpi, input_dir and
        vis_dir.
    """
    parser = argparse.ArgumentParser(
        description="Python script to generate figure from csv files"
    )
    parser.add_argument(
        '-o',
        '--output-dir',
        type=str,
        help='directory to output figures into.',
        default='.',
    )
    parser.add_argument(
        '--dpi',
        type=int,
        help='DPI for output file',
        default=DEFAULT_DPI,
    )
    parser.add_argument(
        '-i',
        '--input-dir',
        type=str,
        help='Input directory, containing the csv files',
        # Use the module constant rather than repeating the literal, so the
        # default is defined in exactly one place.
        default=DEFAULT_INPUT_DIR,
    )
    parser.add_argument(
        '-v',
        '--vis-dir',
        type=str,
        help="Input directory, containing the visualisation files",
        default=DEFAULT_VISUALISATION_DIR,
    )
    return parser.parse_args(argv)
def validate_args(args):
    """Check the parsed command line arguments before any plotting is done.

    Every problem found is printed and checking continues, so all errors are
    reported in a single run instead of stopping at the first one. As a side
    effect, the output directory is created if it does not already exist.

    Args:
        args: Parsed arguments namespace; the attributes output_dir, dpi,
            input_dir and vis_dir are read.

    Returns:
        bool: True if every check passed, False if at least one failed.
    """
    valid = True

    # If output_dir is passed, create it, error if can't create it.
    if args.output_dir is not None:
        output_dir = pathlib.Path(args.output_dir)
        try:
            output_dir.mkdir(exist_ok=True, parents=True)
        except (OSError, ValueError) as e:
            print(f"Error: Could not create output directory {output_dir}: {e}")
            valid = False

    # DPI must be a positive value (no upper limit is enforced).
    if args.dpi is not None and args.dpi < 1:
        print(f"Error: --dpi must be a positive value. {args.dpi}")
        valid = False

    # Ensure that the input directory exists, and that all required input is present.
    if args.input_dir is not None:
        input_dir = pathlib.Path(args.input_dir)
        if not input_dir.is_dir():
            print(f"Error: Invalid input_dir provided {args.input_dir}")
            valid = False
        elif not (input_dir / DRIFT_CSV_FILENAME).is_file():
            print(f"Error: {input_dir} does not contain {DRIFT_CSV_FILENAME}:")
            # The csv is required input, so its absence must fail validation
            # (previously the error was printed but validation still passed).
            valid = False

    # Ensure that the visualisation input directory exists, and that all required images are present.
    vis_dir = pathlib.Path(args.vis_dir)
    if vis_dir.is_dir():
        missing_files = [
            vis_dir / vis_filename
            for vis_filename in VISUALISATION_IMAGE_FILENAMES
            if not (vis_dir / vis_filename).is_file()
        ]
        if missing_files:
            print(f"Error: {vis_dir} does not contain required files:")
            for missing_file in missing_files:
                print(f" {missing_file}")
            valid = False
    else:
        print(f"Error: Invalid vis_dir provided {args.vis_dir}")
        valid = False

    # Additional check on number of visualisation files
    if len(VISUALISATION_IMAGE_FILENAMES) != 4:
        print("Error: VISUALISATION_IMAGE_FILENAMES does not contain 4 files")
        valid = False

    return valid
def main():
    """Build the figure and write it as figure.png and figure.pdf.

    The figure has a tall line plot on the left (panel A: 's_drift' against
    'step', one line per communication radius) and a 2x2 grid of
    visualisation images on the right (panels B-E), filled from
    VISUALISATION_IMAGE_FILENAMES in order.

    Returns:
        bool: False if the command line arguments failed validation (nothing
        is plotted), True once both output files have been written.
    """
    # Validate cli
    args = cli()
    if not validate_args(args):
        return False

    # Set figure theme
    sns.set_theme(style='white')

    # setup sub plot using mosaic layout: the drift plot spans both rows
    gs_kw = dict(width_ratios=[2, 1, 1], height_ratios=[1, 1])
    fig, axes = plt.subplot_mosaic(
        [['drift', 'v1', 'v2'],
         ['drift', 'v3', 'v4']],
        gridspec_kw=gs_kw,
        figsize=(10, 5),
        # constrained_layout handles spacing, so no tight_layout() call is needed
        constrained_layout=True,
    )

    # Load per time step data into data frame
    input_dir = pathlib.Path(args.input_dir)
    step_df = pd.read_csv(input_dir / DRIFT_CSV_FILENAME, sep=',', quotechar='"')
    # Strip any white space from column names
    step_df.columns = step_df.columns.str.strip()
    # rename comm_radius to 'r' so the legend title is short
    step_df = step_df.rename(columns={'comm_radius': 'r'})

    # Plot group by communication radius (r)
    drift_ax = axes['drift']
    sns.lineplot(x='step', y='s_drift', hue='r', data=step_df, ax=drift_ax)
    drift_ax.set(xlabel='Simulation steps', ylabel='Mean drift')
    drift_ax.set_title(label='A', loc='left', fontweight="bold")

    # One image panel per visualisation file, labelled B-E in file order
    visualisation_dir = pathlib.Path(args.vis_dir)
    panels = (('v1', 'B'), ('v2', 'C'), ('v3', 'D'), ('v4', 'E'))
    for (panel_key, panel_label), filename in zip(panels, VISUALISATION_IMAGE_FILENAMES):
        panel_ax = axes[panel_key]
        panel_ax.imshow(mpimg.imread(visualisation_dir / filename))
        panel_ax.set_axis_off()
        panel_ax.set_title(label=panel_label, loc='left', fontweight="bold")

    # Save to image
    output_dir = pathlib.Path(args.output_dir)
    fig.savefig(output_dir / "figure.png", dpi=args.dpi)
    fig.savefig(output_dir / "figure.pdf", format='pdf', dpi=args.dpi)
    # Release the figure now that it has been written to disk
    plt.close(fig)
    return True
# Run the main method if this was not included as a module
if __name__ == "__main__":
    import sys

    # main() returns False when argument validation fails; turn that into a
    # non-zero exit status so calling scripts can detect the failure.
    sys.exit(1 if main() is False else 0)