Skip to content

Commit 82f9dc0

Browse files
committed
Refactoring the function generation using bespoke templates.
1 parent 9b88aaa commit 82f9dc0

File tree

1 file changed

+52
-29
lines changed

1 file changed

+52
-29
lines changed

kernel_tuner/utils/directives.py

+52-29
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,36 @@
22
from abc import ABC, abstractmethod
33
import numpy as np
44

5+
# Function templates
6+
acc_cpp_template = """
7+
<!?PREPROCESSOR?!>
8+
<!?USER_DEFINES?!>
9+
#include <chrono>
10+
11+
extern "C" <!?SIGNATURE?!> {
12+
<!?INITIALIZATION?!>
13+
<!?BODY!?>
14+
<!?DEINITIALIZATION?!>
15+
}
16+
"""
17+
18+
acc_f90_template = """
19+
<!?PREPROCESSOR?!>
20+
<!?USER_DEFINES?!>
21+
22+
module kt
23+
use iso_c_binding
24+
contains
25+
26+
<!?SIGNATURE?!>
27+
<!?INITIALIZATION?!>
28+
<!?BODY!?>
29+
<!?DEINITIALIZATION?!>
30+
end function <!?NAME?!>
31+
32+
end module kt
33+
"""
34+
535

636
class Directive(ABC):
737
"""Base class for all directives"""
@@ -529,42 +559,35 @@ def generate_directive_function(
529559
) -> str:
530560
"""Generate tunable function for one directive"""
531561

532-
code = "\n".join(preprocessor) + "\n"
533-
if user_dimensions is not None:
534-
# add user dimensions to preprocessor
535-
for key, value in user_dimensions.items():
536-
code += f"#define {key} {value}\n"
537-
if is_cxx(langs.language) and "#include <chrono>" not in preprocessor:
538-
code += "\n#include <chrono>\n"
539-
if is_cxx(langs.language):
540-
code += 'extern "C" ' + signature + "{\n"
541-
elif is_fortran(langs.language):
542-
code += "\nmodule kt\nuse iso_c_binding\ncontains\n"
543-
code += "\n" + signature
544-
if len(initialization) > 1:
545-
code += initialization + "\n"
546-
if data is not None:
547-
body = add_present_openacc(body, langs, data, preprocessor, user_dimensions)
548562
if is_cxx(langs.language):
563+
code = acc_cpp_template
549564
body = start_timing_cxx(body)
550565
if data is not None:
551-
code += wrap_data(body + "\n", langs, data, preprocessor, user_dimensions)
552-
else:
553-
code += body
554-
code = end_timing_cxx(code)
555-
if len(deinitialization) > 1:
556-
code += deinitialization + "\n"
557-
code += "\n}"
566+
body = wrap_data(body + "\n", langs, data, preprocessor, user_dimensions)
567+
body += end_timing_cxx(body)
558568
elif is_fortran(langs.language):
569+
code = acc_f90_template
559570
body = wrap_timing(body, langs.language)
560571
if data is not None:
561-
code += wrap_data(body + "\n", langs, data, preprocessor, user_dimensions)
562-
else:
563-
code += body + "\n"
564-
if len(deinitialization) > 1:
565-
code += deinitialization + "\n"
572+
body = wrap_data(body + "\n", langs, data, preprocessor, user_dimensions)
566573
name = signature.split(" ")[1].split("(")[0]
567-
code += f"\nend function {name}\nend module kt\n"
574+
code = code.replace("<!?NAME!?>", name)
575+
code = code.replace("<!?PREPROCESSOR?!>", preprocessor)
576+
# if present, add user specific dimensions as defines
577+
if user_dimensions is not None:
578+
user_defines = ""
579+
for key, value in user_dimensions.items():
580+
user_defines += f"#define {key} {value}\n"
581+
code = code.replace("<!?USER_DEFINES?!>", user_defines)
582+
else:
583+
code = code.replace("<!?USER_DEFINES?!>", "")
584+
code = code.replace("<!?SIGNATURE?!>", signature)
585+
if len(initialization) > 1:
586+
code = code.replace("<!?INITIALIZATION?!>", initialization)
587+
if len(deinitialization) > 1:
588+
code = code.replace("<!?DEINITIALIZATION?!>", deinitialization)
589+
if data is not None:
590+
body = add_present_openacc(body, langs, data, preprocessor, user_dimensions)
568591

569592
return code
570593

0 commit comments

Comments
 (0)