Commit aaf08a2

Fix CI testing numbering of cli.py (#60)
* Remove the SetSuccess stage (and the need for it)
  Signed-off-by: Jeremy Fowers <[email protected]>
* Add a comment about deepcopy
  Signed-off-by: Jeremy Fowers <[email protected]>
* Fix CI testing order
* Move large test to bottom

---------

Signed-off-by: Jeremy Fowers <[email protected]>
Co-authored-by: Jeremy Fowers <[email protected]>
1 parent 6214ab3 commit aaf08a2
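
Context for the renumbering (not part of the diff itself): unittest's default TestLoader runs test methods in name order, so the numeric prefixes in test/cli.py decide the order the CLI tests execute in CI, and moving the corpus-wide report test to the end of the run is done here by renaming it test_026_cli_report. A minimal, hypothetical sketch of that ordering behavior (the Demo class below is illustrative only and not from this repository):

import unittest


class Demo(unittest.TestCase):
    # Defined out of order on purpose: the loader sorts test methods by name,
    # so test_001_first still runs before test_002_second.
    def test_002_second(self):
        self.assertTrue(True)

    def test_001_first(self):
        self.assertTrue(True)


if __name__ == "__main__":
    # TestLoader.sortTestMethodsUsing defaults to a name comparison,
    # which is why renumbering methods changes execution order in CI.
    unittest.main(verbosity=2)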

File tree

1 file changed: +119 -119 lines changed


test/cli.py

+119 -119
@@ -261,104 +261,7 @@ def test_003_cli_build_dir(self):

         assert_success_of_builds(test_scripts, cache_dir)

-    def test_021_cli_report(self):
-        # NOTE: this is not a unit test, it relies on other command
-        # If this test is failing, make sure the following tests are passing:
-        # - test_cli_corpus
-
-        test_scripts = common.test_scripts_dot_py.keys()
-
-        # Build the test corpus so we have builds to report
-        testargs = [
-            "turnkey",
-            "benchmark",
-            bash(f"{corpus_dir}/*.py"),
-            "--cache-dir",
-            cache_dir,
-        ]
-        with patch.object(sys, "argv", flatten(testargs)):
-            turnkeycli()
-
-        testargs = [
-            "turnkey",
-            "cache",
-            "report",
-            "--cache-dir",
-            cache_dir,
-        ]
-        with patch.object(sys, "argv", testargs):
-            turnkeycli()
-
-        # Read generated CSV file
-        summary_csv_path = report.get_report_name()
-        with open(summary_csv_path, "r", encoding="utf8") as summary_csv:
-            summary = list(csv.DictReader(summary_csv))
-
-        # Check if csv file contains all expected rows and columns
-        expected_cols = [
-            "model_name",
-            "author",
-            "class",
-            "parameters",
-            "hash",
-            "runtime",
-            "device_type",
-            "device",
-            "mean_latency",
-            "throughput",
-            "all_build_stages",
-            "completed_build_stages",
-        ]
-        linear_summary = summary[1]
-        assert len(summary) == len(test_scripts)
-        for elem in expected_cols:
-            assert (
-                elem in linear_summary
-            ), f"Couldn't find expected key {elem} in results spreadsheet"
-
-        # Check whether all rows we expect to be populated are actually populated
-        assert (
-            linear_summary["model_name"] == "linear2"
-        ), f"Wrong model name found {linear_summary['model_name']}"
-        assert (
-            linear_summary["author"] == "turnkey"
-        ), f"Wrong author name found {linear_summary['author']}"
-        assert (
-            linear_summary["class"] == "TwoLayerModel"
-        ), f"Wrong class found {linear_summary['model_class']}"
-        assert (
-            linear_summary["hash"] == "80b93950"
-        ), f"Wrong hash found {linear_summary['hash']}"
-        assert (
-            linear_summary["runtime"] == "ort"
-        ), f"Wrong runtime found {linear_summary['runtime']}"
-        assert (
-            linear_summary["device_type"] == "x86"
-        ), f"Wrong device type found {linear_summary['device_type']}"
-        assert (
-            float(linear_summary["mean_latency"]) > 0
-        ), f"latency must be >0, got {linear_summary['x86_latency']}"
-        assert (
-            float(linear_summary["throughput"]) > 100
-        ), f"throughput must be >100, got {linear_summary['throughput']}"
-
-        # Make sure the report.get_dict() API works
-        result_dict = report.get_dict(
-            summary_csv_path, ["all_build_stages", "completed_build_stages"]
-        )
-        for result in result_dict.values():
-            # All of the models should have exported to ONNX, so the "onnx_exported" value
-            # should be True for all of them
-            assert "export_pytorch" in yaml.safe_load(result["all_build_stages"])
-            assert (
-                "export_pytorch"
-                in yaml.safe_load(result["completed_build_stages"]).keys()
-            )
-            assert (
-                yaml.safe_load(result["completed_build_stages"])["export_pytorch"] > 0
-            )
-
-    def test_005_cli_list(self):
+    def test_004_cli_list(self):
         # NOTE: this is not a unit test, it relies on other command
         # If this test is failing, make sure the following tests are passing:
         # - test_cli_corpus
@@ -391,7 +294,7 @@ def test_005_cli_list(self):
             script_name = common.strip_dot_py(test_script)
             assert script_name in f.getvalue(), f"{script_name} {f.getvalue()}"

-    def test_006_cli_delete(self):
+    def test_005_cli_delete(self):
         # NOTE: this is not a unit test, it relies on other command
         # If this test is failing, make sure the following tests are passing:
         # - test_cli_corpus
@@ -453,7 +356,7 @@ def test_006_cli_delete(self):
             script_name = common.strip_dot_py(test_script)
             assert script_name not in f.getvalue()

-    def test_007_cli_stats(self):
+    def test_006_cli_stats(self):
         # NOTE: this is not a unit test, it relies on other command
         # If this test is failing, make sure the following tests are passing:
         # - test_cli_corpus
@@ -531,7 +434,7 @@ def test_007_cli_stats(self):
         ]
         assert isinstance(stats_dict["task"], str), stats_dict["task"]

-    def test_008_cli_version(self):
+    def test_007_cli_version(self):
         # Get the version number
         with redirect_stdout(io.StringIO()) as f:
             testargs = [
@@ -544,7 +447,7 @@ def test_008_cli_version(self):
         # Make sure we get back a 3-digit number
         assert len(f.getvalue().split(".")) == 3

-    def test_009_cli_turnkey_args(self):
+    def test_008_cli_turnkey_args(self):
         # NOTE: this is not a unit test, it relies on other command
         # If this test is failing, make sure the following tests are passing:
         # - test_cli_single
@@ -570,7 +473,7 @@ def test_009_cli_turnkey_args(self):

     # TODO: Investigate why this test is failing only on Windows CI failing
     @unittest.skipIf(platform.system() == "Windows", "Windows CI only failure")
-    def test_011_cli_benchmark(self):
+    def test_009_cli_benchmark(self):
         # Test the first model in the corpus
         test_script = list(common.test_scripts_dot_py.keys())[0]

@@ -588,7 +491,7 @@ def test_011_cli_benchmark(self):

     # TODO: Investigate why this test is non-deterministically failing
     @unittest.skip("Flaky test")
-    def test_013_cli_labels(self):
+    def test_010_cli_labels(self):
         # Only build models labels with test_group::a
         testargs = [
             "turnkey",
@@ -638,7 +541,7 @@ def test_013_cli_labels(self):
         assert state_files == ["linear_d5b1df11_state", "linear2_80b93950_state"]

     @unittest.skip("Needs re-implementation")
-    def test_014_report_on_failed_build(self):
+    def test_011_report_on_failed_build(self):
         testargs = [
             "turnkey",
             bash(f"{corpus_dir}/linear.py"),
@@ -680,7 +583,7 @@ def test_014_report_on_failed_build(self):
         ), "Wrong number of parameters found in report"
         assert summary[0]["hash"] == "d5b1df11", "Wrong hash found in report"

-    def test_015_runtimes(self):
+    def test_012_runtimes(self):
         # Attempt to benchmark using an invalid runtime
         with self.assertRaises(exceptions.ArgError):
             testargs = [
@@ -729,7 +632,7 @@ def test_015_runtimes(self):

     # TODO: Investigate why this test is only failing on Windows CI
     @unittest.skipIf(platform.system() == "Windows", "Windows CI only failure")
-    def test_016_cli_onnx_opset(self):
+    def test_013_cli_onnx_opset(self):
         # Test the first model in the corpus
         test_script = list(common.test_scripts_dot_py.keys())[0]

@@ -752,7 +655,7 @@ def test_016_cli_onnx_opset(self):
             [test_script], cache_dir, None, check_perf=True, check_opset=user_opset
         )

-    def test_016_cli_iteration_count(self):
+    def test_014_cli_iteration_count(self):
         # Test the first model in the corpus
         test_script = list(common.test_scripts_dot_py.keys())[0]

@@ -777,7 +680,7 @@ def test_016_cli_iteration_count(self):
             check_iteration_count=test_iterations,
         )

-    def test_017_cli_process_isolation(self):
+    def test_015_cli_process_isolation(self):
         # Test the first model in the corpus
         test_script = list(common.test_scripts_dot_py.keys())[0]

@@ -799,7 +702,7 @@ def test_017_cli_process_isolation(self):
         "Skipping, as torch.compile is not supported on Windows"
         "Revisit when torch.compile for Windows is supported",
     )
-    def test_018_skip_compiled(self):
+    def test_016_skip_compiled(self):
         test_script = "compiled.py"
         testargs = [
             "turnkey",
@@ -817,14 +720,14 @@ def test_018_skip_compiled(self):
         # One of those is compiled and should be skipped.
         assert builds_found == 1

-    def test_019_invalid_file_type(self):
+    def test_017_invalid_file_type(self):
         # Ensure that we get an error when running turnkey with invalid input_files
         with self.assertRaises(exceptions.ArgError):
             testargs = ["turnkey", "gobbledegook"]
             with patch.object(sys, "argv", flatten(testargs)):
                 turnkeycli()

-    def test_020_cli_export_only(self):
+    def test_018_cli_export_only(self):
         # Test the first model in the corpus
         test_script = list(common.test_scripts_dot_py.keys())[0]

@@ -842,7 +745,7 @@ def test_020_cli_export_only(self):

         assert_success_of_builds([test_script], cache_dir, check_onnx_file_count=1)

-    def test_022_cli_onnx_model(self):
+    def test_019_cli_onnx_model(self):
         """
         Manually export an ONNX file, then feed it into the CLI
         """
@@ -871,7 +774,7 @@ def test_022_cli_onnx_model(self):

         assert_success_of_builds([build_name], cache_dir)

-    def test_023_cli_onnx_model_opset(self):
+    def test_020_cli_onnx_model_opset(self):
         """
         Manually export an ONNX file with a non-defualt opset, then feed it into the CLI
         """
@@ -904,7 +807,7 @@ def test_023_cli_onnx_model_opset(self):

         assert_success_of_builds([build_name], cache_dir)

-    def test_024_args_encode_decode(self):
+    def test_021_args_encode_decode(self):
         """
         Test the encoding and decoding of arguments that follow the
         ["arg1::[value1,value2]","arg2::value1","flag_arg"]' format
@@ -916,7 +819,7 @@ def test_024_args_encode_decode(self):
             reencoded_value == encoded_value
         ), f"input: {encoded_value}, decoded: {decoded_value}, reencoded_value: {reencoded_value}"

-    def test_025_benchmark_non_existent_file(self):
+    def test_022_benchmark_non_existent_file(self):
         # Ensure we get an error when benchmarking a non existent file
         with self.assertRaises(exceptions.ArgError):
             filename = "thou_shall_not_exist.py"
@@ -925,7 +828,7 @@ def test_025_benchmark_non_existent_file(self):
             with patch.object(sys, "argv", testargs):
                 turnkeycli()

-    def test_026_benchmark_non_existent_file_prefix(self):
+    def test_023_benchmark_non_existent_file_prefix(self):
         # Ensure we get an error when benchmarking a non existent file
         with self.assertRaises(exceptions.ArgError):
             file_prefix = "non_existent_prefix_*.py"
@@ -934,7 +837,7 @@ def test_026_benchmark_non_existent_file_prefix(self):
             with patch.object(sys, "argv", testargs):
                 turnkeycli()

-    def test_027_input_text_file(self):
+    def test_024_input_text_file(self):
         """
         Ensure that we can intake .txt files
         """
@@ -955,7 +858,7 @@ def test_027_input_text_file(self):
             builds_found == 3
         ), f"Expected 3 builds (1 for linear.py, 2 for linear2.py), but got {builds_found}."

-    def test_028_cli_timeout(self):
+    def test_025_cli_timeout(self):
         """
         Make sure that the --timeout option and its associated reporting features work.

@@ -1009,6 +912,103 @@ def test_028_cli_timeout(self):
             # the stats.yaml was created, which in turn means the CSV is empty
             pass

+    def test_026_cli_report(self):
+        # NOTE: this is not a unit test, it relies on other command
+        # If this test is failing, make sure the following tests are passing:
+        # - test_cli_corpus
+
+        test_scripts = common.test_scripts_dot_py.keys()
+
+        # Build the test corpus so we have builds to report
+        testargs = [
+            "turnkey",
+            "benchmark",
+            bash(f"{corpus_dir}/*.py"),
+            "--cache-dir",
+            cache_dir,
+        ]
+        with patch.object(sys, "argv", flatten(testargs)):
+            turnkeycli()
+
+        testargs = [
+            "turnkey",
+            "cache",
+            "report",
+            "--cache-dir",
+            cache_dir,
+        ]
+        with patch.object(sys, "argv", testargs):
+            turnkeycli()
+
+        # Read generated CSV file
+        summary_csv_path = report.get_report_name()
+        with open(summary_csv_path, "r", encoding="utf8") as summary_csv:
+            summary = list(csv.DictReader(summary_csv))
+
+        # Check if csv file contains all expected rows and columns
+        expected_cols = [
+            "model_name",
+            "author",
+            "class",
+            "parameters",
+            "hash",
+            "runtime",
+            "device_type",
+            "device",
+            "mean_latency",
+            "throughput",
+            "all_build_stages",
+            "completed_build_stages",
+        ]
+        linear_summary = summary[1]
+        assert len(summary) == len(test_scripts)
+        for elem in expected_cols:
+            assert (
+                elem in linear_summary
+            ), f"Couldn't find expected key {elem} in results spreadsheet"
+
+        # Check whether all rows we expect to be populated are actually populated
+        assert (
+            linear_summary["model_name"] == "linear2"
+        ), f"Wrong model name found {linear_summary['model_name']}"
+        assert (
+            linear_summary["author"] == "turnkey"
+        ), f"Wrong author name found {linear_summary['author']}"
+        assert (
+            linear_summary["class"] == "TwoLayerModel"
+        ), f"Wrong class found {linear_summary['model_class']}"
+        assert (
+            linear_summary["hash"] == "80b93950"
+        ), f"Wrong hash found {linear_summary['hash']}"
+        assert (
+            linear_summary["runtime"] == "ort"
+        ), f"Wrong runtime found {linear_summary['runtime']}"
+        assert (
+            linear_summary["device_type"] == "x86"
+        ), f"Wrong device type found {linear_summary['device_type']}"
+        assert (
+            float(linear_summary["mean_latency"]) > 0
+        ), f"latency must be >0, got {linear_summary['x86_latency']}"
+        assert (
+            float(linear_summary["throughput"]) > 100
+        ), f"throughput must be >100, got {linear_summary['throughput']}"
+
+        # Make sure the report.get_dict() API works
+        result_dict = report.get_dict(
+            summary_csv_path, ["all_build_stages", "completed_build_stages"]
+        )
+        for result in result_dict.values():
+            # All of the models should have exported to ONNX, so the "onnx_exported" value
+            # should be True for all of them
+            assert "export_pytorch" in yaml.safe_load(result["all_build_stages"])
+            assert (
+                "export_pytorch"
+                in yaml.safe_load(result["completed_build_stages"]).keys()
+            )
+            assert (
+                yaml.safe_load(result["completed_build_stages"])["export_pytorch"] > 0
+            )
+

 if __name__ == "__main__":
     unittest.main()
