@@ -19,6 +19,7 @@
 from invoke.context import Context
 
 from . import DEFAULT_REGION, IAM_INSTANCE_PROFILE, AMI_ID, LOGGER, S3_BUCKET_BENCHMARK_ARTIFACTS
+from . import cloudwatch as cloudwatch_utils
 
 TMP_DIR = "/home/ubuntu"
 LOCAL_TMP_DIR = "/tmp"
@@ -65,7 +66,7 @@ def run_apache_bench(self, requests, concurrency, input_file, is_workflow=False,
             self.connection.run(f"cp {file_name} {os.path.join(TMP_DIR, 'benchmark/input')}")
         else:
             self.connection.run(f"cp {input_file} {os.path.join(TMP_DIR, 'benchmark/input')}")
-
+
         predict_flag = "predictions"
         model_name = "benchmark"
         if is_workflow:
@@ -100,7 +101,7 @@ def extract_metrics(self, connection=None):
         temp_uuid = uuid.uuid4()
 
         time.sleep(5)
-
+
         # Upload to s3 and fetch back to local instance: more reliable than using self.connection.get()
         connection.run(f"aws s3 cp {self.result_file} {S3_BUCKET_BENCHMARK_ARTIFACTS}/{temp_uuid}/result.txt")
         time.sleep(2)
@@ -136,7 +137,7 @@ def extract_entity(self, data, pattern, index, delim=" "):
             if pattern.search(line):
                 return line.split(delim)[index].strip()
 
-    def generate_csv_output(self, requests, concurrency, connection=None):
+    def generate_csv_output(self, requests, concurrency, batch_size, mode, connection=None):
         LOGGER.info("*Generating CSV output...")
 
         batched_requests = requests / concurrency
@@ -147,6 +148,8 @@ def generate_csv_output(self, requests, concurrency, connection=None):
         with open(f"{self.local_tmp_dir}/result.txt") as f:
             data = f.readlines()
         artifacts["Benchmark"] = "AB"
+        artifacts["Batch Size"] = batch_size
+        artifacts["Mode"] = mode  # This is the pytorch mode i.e. eager or scripted, as specified in <model>.yaml
         artifacts["Model"] = self.model_name
         artifacts["Concurrency"] = concurrency
         artifacts["Requests"] = requests
@@ -178,6 +181,28 @@ def generate_csv_output(self, requests, concurrency, connection=None):
 
         return artifacts
 
-    def generate_report(self, requests, concurrency, connection=None):
+    def push_benchmark_metrics(self, artifacts, connection=None):
+        curr_instance_type = connection.run(
+            f"curl http://169.254.169.254/latest/meta-data/instance-type", warn=True
+        ).stdout
+
+        artifacts["instance_type"] = curr_instance_type
+
+        # 'BENCHMARK_CONTEXT' is set internally at AWS for certain benchmark jobs.
+        # When it's not available, 'DevTest' is used.
+        dashboard_context = os.getenv("BENCHMARK_CONTEXT", "DevTest")
+
+        cloudwatchMetricsHandler = cloudwatch_utils.CloudWatchMetricsHandler(
+            context=dashboard_context, sub_namespace=f"{self.model_name}/{artifacts.get('Mode')}"
+        )
+
+        cloudwatchMetricsHandler.push_benchmark_metrics(artifacts)
+
+    def generate_report(self, requests, concurrency, batch_size, mode, connection=None):
         self.extract_metrics(connection=connection)
-        self.generate_csv_output(requests, concurrency, connection=connection)
+        artifacts = self.generate_csv_output(
+            requests, concurrency, batch_size=batch_size, mode=mode, connection=connection
+        )
+
+        # Push metrics to cloudwatch
+        self.push_benchmark_metrics(artifacts, connection=connection)
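Callers of generate_report now have to thread the batch size and mode through. A minimal usage sketch, assuming an existing handler instance of this class and an open connection (both names are illustrative, not from this diff):

    # Hypothetical driver code; 'handler' and 'conn' are assumed to exist.
    handler.generate_report(
        requests=1000,
        concurrency=10,
        batch_size=1,
        mode="eager",  # becomes the 'Mode' field pushed to CloudWatch
        connection=conn,
    )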
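The CloudWatchMetricsHandler implementation itself lives in the new cloudwatch module and is not shown in this diff. As a rough sketch of what such a handler could look like, assuming boto3 and that only numeric artifact entries are emitted as metric values (the namespace layout and region default here are assumptions, not the PR's actual code):

    import boto3

    class CloudWatchMetricsHandler:
        def __init__(self, context, sub_namespace, region="us-west-2"):
            # Assumed namespace layout: <context>/<model_name>/<mode>
            self.namespace = f"{context}/{sub_namespace}"
            self.client = boto3.client("cloudwatch", region_name=region)

        def push_benchmark_metrics(self, artifacts):
            # CloudWatch metric values must be numeric; string fields such as
            # instance_type travel as dimensions instead.
            dimensions = [
                {"Name": "instance_type",
                 "Value": str(artifacts.get("instance_type", "unknown")).strip()}
            ]
            metric_data = [
                {"MetricName": key, "Value": float(value), "Dimensions": dimensions}
                for key, value in artifacts.items()
                if isinstance(value, (int, float))
            ]
            if metric_data:
                self.client.put_metric_data(Namespace=self.namespace, MetricData=metric_data)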