|
2 | 2 | from abc import ABC, abstractmethod
|
3 | 3 | import numpy as np
|
4 | 4 |
|
| 5 | +# Function templates |
| 6 | +acc_cpp_template = """ |
| 7 | +<!?PREPROCESSOR?!> |
| 8 | +<!?USER_DEFINES?!> |
| 9 | +#include <chrono> |
| 10 | +
|
| 11 | +extern "C" <!?SIGNATURE?!> { |
| 12 | +<!?INITIALIZATION?!> |
| 13 | +<!?BODY!?> |
| 14 | +<!?DEINITIALIZATION?!> |
| 15 | +} |
| 16 | +""" |
| 17 | + |
| 18 | +acc_f90_template = """ |
| 19 | +<!?PREPROCESSOR?!> |
| 20 | +<!?USER_DEFINES?!> |
| 21 | +
|
| 22 | +module kt |
| 23 | +use iso_c_binding |
| 24 | +contains |
| 25 | +
|
| 26 | +<!?SIGNATURE?!> |
| 27 | +<!?INITIALIZATION?!> |
| 28 | +<!?BODY!?> |
| 29 | +<!?DEINITIALIZATION?!> |
| 30 | +end function <!?NAME?!> |
| 31 | +
|
| 32 | +end module kt |
| 33 | +""" |
| 34 | + |
5 | 35 |
|
6 | 36 | class Directive(ABC):
|
7 | 37 | """Base class for all directives"""
|
@@ -529,42 +559,35 @@ def generate_directive_function(
|
529 | 559 | ) -> str:
|
530 | 560 | """Generate tunable function for one directive"""
|
531 | 561 |
|
532 |
| - code = "\n".join(preprocessor) + "\n" |
533 |
| - if user_dimensions is not None: |
534 |
| - # add user dimensions to preprocessor |
535 |
| - for key, value in user_dimensions.items(): |
536 |
| - code += f"#define {key} {value}\n" |
537 |
| - if is_cxx(langs.language) and "#include <chrono>" not in preprocessor: |
538 |
| - code += "\n#include <chrono>\n" |
539 |
| - if is_cxx(langs.language): |
540 |
| - code += 'extern "C" ' + signature + "{\n" |
541 |
| - elif is_fortran(langs.language): |
542 |
| - code += "\nmodule kt\nuse iso_c_binding\ncontains\n" |
543 |
| - code += "\n" + signature |
544 |
| - if len(initialization) > 1: |
545 |
| - code += initialization + "\n" |
546 |
| - if data is not None: |
547 |
| - body = add_present_openacc(body, langs, data, preprocessor, user_dimensions) |
548 | 562 | if is_cxx(langs.language):
|
| 563 | + code = acc_cpp_template |
549 | 564 | body = start_timing_cxx(body)
|
550 | 565 | if data is not None:
|
551 |
| - code += wrap_data(body + "\n", langs, data, preprocessor, user_dimensions) |
552 |
| - else: |
553 |
| - code += body |
554 |
| - code = end_timing_cxx(code) |
555 |
| - if len(deinitialization) > 1: |
556 |
| - code += deinitialization + "\n" |
557 |
| - code += "\n}" |
| 566 | + body = wrap_data(body + "\n", langs, data, preprocessor, user_dimensions) |
| 567 | + body += end_timing_cxx(body) |
558 | 568 | elif is_fortran(langs.language):
|
| 569 | + code = acc_f90_template |
559 | 570 | body = wrap_timing(body, langs.language)
|
560 | 571 | if data is not None:
|
561 |
| - code += wrap_data(body + "\n", langs, data, preprocessor, user_dimensions) |
562 |
| - else: |
563 |
| - code += body + "\n" |
564 |
| - if len(deinitialization) > 1: |
565 |
| - code += deinitialization + "\n" |
| 572 | + body = wrap_data(body + "\n", langs, data, preprocessor, user_dimensions) |
566 | 573 | name = signature.split(" ")[1].split("(")[0]
|
567 |
| - code += f"\nend function {name}\nend module kt\n" |
| 574 | + code = code.replace("<!?NAME!?>", name) |
| 575 | + code = code.replace("<!?PREPROCESSOR?!>", preprocessor) |
| 576 | + # if present, add user specific dimensions as defines |
| 577 | + if user_dimensions is not None: |
| 578 | + user_defines = "" |
| 579 | + for key, value in user_dimensions.items(): |
| 580 | + user_defines += f"#define {key} {value}\n" |
| 581 | + code = code.replace("<!?USER_DEFINES?!>", user_defines) |
| 582 | + else: |
| 583 | + code = code.replace("<!?USER_DEFINES?!>", "") |
| 584 | + code = code.replace("<!?SIGNATURE?!>", signature) |
| 585 | + if len(initialization) > 1: |
| 586 | + code = code.replace("<!?INITIALIZATION?!>", initialization) |
| 587 | + if len(deinitialization) > 1: |
| 588 | + code = code.replace("<!?DEINITIALIZATION?!>", deinitialization) |
| 589 | + if data is not None: |
| 590 | + body = add_present_openacc(body, langs, data, preprocessor, user_dimensions) |
568 | 591 |
|
569 | 592 | return code
|
570 | 593 |
|
|
0 commit comments