Skip to content

Commit f6d3162

Browse files
author
Rahul Batra
committed
[ROCm]: Add support to continue on fail
1 parent 8a11b40 commit f6d3162

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

build/rocm/run_multi_gpu.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ run_tests() {
2020
local base_dir=./logs
2121
local gpu_devices="$1"
2222
export HIP_VISIBLE_DEVICES=$gpu_devices
23-
python3 -m pytest --html=$base_dir/multi_gpu_pmap_test_log.html --reruns 3 -x tests/pmap_test.py
24-
python3 -m pytest --html=$base_dir/multi_gpu_multi_device_test_log.html --reruns 3 -x tests/multi_device_test.py
23+
python3 -m pytest --html=$base_dir/multi_gpu_pmap_test_log.html --reruns 3 tests/pmap_test.py
24+
python3 -m pytest --html=$base_dir/multi_gpu_multi_device_test_log.html --reruns 3 tests/multi_device_test.py
2525
python3 -m pytest_html_merger -i $base_dir/ -o $base_dir/final_compiled_report.html
2626
}
2727

build/rocm/run_single_gpu.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,11 @@ def collect_testmodules():
6868
print("Test module discovery failed.")
6969
exit(return_code)
7070
for line in stdout.split("\n"):
71-
match = re.match("<Module (.*)>", line)
71+
match = re.match("<Module (.*)>", line.strip())
7272
if match:
7373
test_file = match.group(1)
74+
if "/" not in test_file:
75+
test_file = os.path.join("tests",test_file)
7476
all_test_files.append(test_file)
7577
print("---------- collected test modules ----------")
7678
print("Found %d test modules." % (len(all_test_files)))
@@ -79,7 +81,7 @@ def collect_testmodules():
7981
return all_test_files
8082

8183

82-
def run_test(testmodule, gpu_tokens):
84+
def run_test(testmodule, gpu_tokens, continue_on_fail):
8385
global LAST_CODE
8486
with GPU_LOCK:
8587
if LAST_CODE != 0:
@@ -90,39 +92,43 @@ def run_test(testmodule, gpu_tokens):
9092
"XLA_PYTHON_CLIENT_ALLOCATOR": "default",
9193
}
9294
testfile = extract_filename(testmodule)
93-
cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", testmodule]
95+
if continue_on_fail:
96+
cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-v", testmodule]
97+
else:
98+
cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", "-v", testmodule]
9499
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
95100
with GPU_LOCK:
96101
gpu_tokens.append(target_gpu)
97102
if LAST_CODE == 0:
98103
print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
99104
print(stdout)
100105
print(stderr)
101-
LAST_CODE = return_code
106+
if continue_on_fail == False:
107+
LAST_CODE = return_code
102108
return
103109

104110

105-
def run_parallel(all_testmodules, p):
106-
print("Running tests with parallelism=", p)
111+
def run_parallel(all_testmodules, p, c):
112+
print(f"Running tests with parallelism=", p)
107113
available_gpu_tokens = list(range(p))
108114
executor = ThreadPoolExecutor(max_workers=p)
109115
# walking through test modules
110116
for testmodule in all_testmodules:
111-
executor.submit(run_test, testmodule, available_gpu_tokens)
117+
executor.submit(run_test, testmodule, available_gpu_tokens, c)
112118
# waiting for all modules to finish
113119
executor.shutdown(wait=True) # wait for all jobs to finish
114120
return
115121

116122

117123
def find_num_gpus():
118-
cmd = ["lspci|grep 'controller'|grep 'AMD/ATI'|wc -l"]
124+
cmd = ["lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
119125
_, _, stdout = run_shell_command(cmd, shell=True)
120126
return int(stdout)
121127

122128

123129
def main(args):
124130
all_testmodules = collect_testmodules()
125-
run_parallel(all_testmodules, args.parallel)
131+
run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
126132
generate_final_report()
127133
exit(LAST_CODE)
128134

@@ -134,7 +140,13 @@ def main(args):
134140
"--parallel",
135141
type=int,
136142
help="number of tests to run in parallel")
143+
parser.add_argument("-c",
144+
"--continue_on_fail",
145+
action='store_true',
146+
help="continue on failure")
137147
args = parser.parse_args()
148+
if args.continue_on_fail:
149+
print("continue on fail is set")
138150
if args.parallel is None:
139151
sys_gpu_count = find_num_gpus()
140152
args.parallel = sys_gpu_count

0 commit comments

Comments
 (0)