| 
2 | 2 | from abc import ABC, abstractmethod  | 
3 | 3 | import numpy as np  | 
4 | 4 | 
 
  | 
 | 5 | +# Function templates  | 
 | 6 | +acc_cpp_template = """  | 
 | 7 | +<!?PREPROCESSOR?!>  | 
 | 8 | +<!?USER_DEFINES?!>  | 
 | 9 | +#include <chrono>  | 
 | 10 | +
  | 
 | 11 | +extern "C" <!?SIGNATURE?!> {  | 
 | 12 | +<!?INITIALIZATION?!>  | 
 | 13 | +<!?BODY!?>  | 
 | 14 | +<!?DEINITIALIZATION?!>  | 
 | 15 | +}  | 
 | 16 | +"""  | 
 | 17 | + | 
 | 18 | +acc_f90_template = """  | 
 | 19 | +<!?PREPROCESSOR?!>  | 
 | 20 | +<!?USER_DEFINES?!>  | 
 | 21 | +
  | 
 | 22 | +module kt  | 
 | 23 | +use iso_c_binding  | 
 | 24 | +contains  | 
 | 25 | +
  | 
 | 26 | +<!?SIGNATURE?!>  | 
 | 27 | +<!?INITIALIZATION?!>  | 
 | 28 | +<!?BODY!?>  | 
 | 29 | +<!?DEINITIALIZATION?!>  | 
 | 30 | +end function <!?NAME?!>  | 
 | 31 | +
  | 
 | 32 | +end module kt  | 
 | 33 | +"""  | 
 | 34 | + | 
5 | 35 | 
 
  | 
6 | 36 | class Directive(ABC):  | 
7 | 37 |     """Base class for all directives"""  | 
@@ -529,42 +559,35 @@ def generate_directive_function(  | 
529 | 559 | ) -> str:  | 
530 | 560 |     """Generate tunable function for one directive"""  | 
531 | 561 | 
 
  | 
532 |  | -    code = "\n".join(preprocessor) + "\n"  | 
533 |  | -    if user_dimensions is not None:  | 
534 |  | -        # add user dimensions to preprocessor  | 
535 |  | -        for key, value in user_dimensions.items():  | 
536 |  | -            code += f"#define {key} {value}\n"  | 
537 |  | -    if is_cxx(langs.language) and "#include <chrono>" not in preprocessor:  | 
538 |  | -        code += "\n#include <chrono>\n"  | 
539 |  | -    if is_cxx(langs.language):  | 
540 |  | -        code += 'extern "C" ' + signature + "{\n"  | 
541 |  | -    elif is_fortran(langs.language):  | 
542 |  | -        code += "\nmodule kt\nuse iso_c_binding\ncontains\n"  | 
543 |  | -        code += "\n" + signature  | 
544 |  | -    if len(initialization) > 1:  | 
545 |  | -        code += initialization + "\n"  | 
546 |  | -    if data is not None:  | 
547 |  | -        body = add_present_openacc(body, langs, data, preprocessor, user_dimensions)  | 
548 | 562 |     if is_cxx(langs.language):  | 
 | 563 | +        code = acc_cpp_template  | 
549 | 564 |         body = start_timing_cxx(body)  | 
550 | 565 |         if data is not None:  | 
551 |  | -            code += wrap_data(body + "\n", langs, data, preprocessor, user_dimensions)  | 
552 |  | -        else:  | 
553 |  | -            code += body  | 
554 |  | -        code = end_timing_cxx(code)  | 
555 |  | -        if len(deinitialization) > 1:  | 
556 |  | -            code += deinitialization + "\n"  | 
557 |  | -        code += "\n}"  | 
 | 566 | +            body = wrap_data(body + "\n", langs, data, preprocessor, user_dimensions)  | 
 | 567 | +        body += end_timing_cxx(body)  | 
558 | 568 |     elif is_fortran(langs.language):  | 
 | 569 | +        code = acc_f90_template  | 
559 | 570 |         body = wrap_timing(body, langs.language)  | 
560 | 571 |         if data is not None:  | 
561 |  | -            code += wrap_data(body + "\n", langs, data, preprocessor, user_dimensions)  | 
562 |  | -        else:  | 
563 |  | -            code += body + "\n"  | 
564 |  | -        if len(deinitialization) > 1:  | 
565 |  | -            code += deinitialization + "\n"  | 
 | 572 | +            body = wrap_data(body + "\n", langs, data, preprocessor, user_dimensions)  | 
566 | 573 |         name = signature.split(" ")[1].split("(")[0]  | 
567 |  | -        code += f"\nend function {name}\nend module kt\n"  | 
 | 574 | +        code = code.replace("<!?NAME!?>", name)  | 
 | 575 | +    code = code.replace("<!?PREPROCESSOR?!>", preprocessor)  | 
 | 576 | +    # if present, add user specific dimensions as defines  | 
 | 577 | +    if user_dimensions is not None:  | 
 | 578 | +        user_defines = ""  | 
 | 579 | +        for key, value in user_dimensions.items():  | 
 | 580 | +            user_defines += f"#define {key} {value}\n"  | 
 | 581 | +        code = code.replace("<!?USER_DEFINES?!>", user_defines)  | 
 | 582 | +    else:  | 
 | 583 | +        code = code.replace("<!?USER_DEFINES?!>", "")  | 
 | 584 | +    code = code.replace("<!?SIGNATURE?!>", signature)  | 
 | 585 | +    if len(initialization) > 1:  | 
 | 586 | +        code = code.replace("<!?INITIALIZATION?!>", initialization)  | 
 | 587 | +    if len(deinitialization) > 1:  | 
 | 588 | +        code = code.replace("<!?DEINITIALIZATION?!>", deinitialization)  | 
 | 589 | +    if data is not None:  | 
 | 590 | +        body = add_present_openacc(body, langs, data, preprocessor, user_dimensions)  | 
568 | 591 | 
 
  | 
569 | 592 |     return code  | 
570 | 593 | 
 
  | 
 | 
0 commit comments