Skip to content

Commit e046cfd

Browse files
authored
Merge pull request #267 from KernelTuner/directives
ESiWACE3 hackathon
2 parents abd8de0 + 9b88aaa commit e046cfd

File tree

3 files changed

+36
-8
lines changed

3 files changed

+36
-8
lines changed

Diff for: kernel_tuner/utils/directives.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,19 @@ class Cxx(Language):
3232
def get(self) -> str:
3333
return "cxx"
3434

35+
def end_string(self) -> str:
36+
return "#pragma tuner stop"
37+
3538

3639
class Fortran(Language):
3740
"""Class to represent Fortran code"""
3841

3942
def get(self) -> str:
4043
return "fortran"
4144

45+
def end_string(self) -> str:
46+
return "!$tuner stop"
47+
4248

4349
class Code(object):
4450
"""Class to represent the directive and host code of the application"""
@@ -356,24 +362,34 @@ def extract_directive_code(code: str, langs: Code, kernel_name: str = None) -> d
356362
"""Extract explicitly marked directive sections from code"""
357363
if is_cxx(langs.language):
358364
start_string = "#pragma tuner start"
359-
end_string = "#pragma tuner stop"
360365
elif is_fortran(langs.language):
361366
start_string = "!$tuner start"
362-
end_string = "!$tuner stop"
363367

364-
return extract_code(start_string, end_string, code, langs, kernel_name)
368+
return extract_code(start_string, langs.language.end_string(), code, langs, kernel_name)
365369

366370

367371
def extract_initialization_code(code: str, langs: Code) -> str:
368372
"""Extract the initialization section from code"""
369373
if is_cxx(langs.language):
370374
start_string = "#pragma tuner initialize"
371-
end_string = "#pragma tuner stop"
372375
elif is_fortran(langs.language):
373376
start_string = "!$tuner initialize"
374-
end_string = "!$tuner stop"
375377

376-
init_code = extract_code(start_string, end_string, code, langs)
378+
init_code = extract_code(start_string, langs.language.end_string(), code, langs)
379+
if len(init_code) >= 1:
380+
return "\n".join(init_code.values()) + "\n"
381+
else:
382+
return ""
383+
384+
385+
def extract_deinitialization_code(code: str, langs: Code) -> str:
386+
"""Extract the deinitialization section from code"""
387+
if is_cxx(langs.language):
388+
start_string = "#pragma tuner deinitialize"
389+
elif is_fortran(langs.language):
390+
start_string = "!$tuner deinitialize"
391+
392+
init_code = extract_code(start_string, langs.language.end_string(), code, langs)
377393
if len(init_code) >= 1:
378394
return "\n".join(init_code.values()) + "\n"
379395
else:
@@ -508,6 +524,7 @@ def generate_directive_function(
508524
langs: Code,
509525
data: dict = None,
510526
initialization: str = "",
527+
deinitialization: str = "",
511528
user_dimensions: dict = None,
512529
) -> str:
513530
"""Generate tunable function for one directive"""
@@ -535,13 +552,17 @@ def generate_directive_function(
535552
else:
536553
code += body
537554
code = end_timing_cxx(code)
555+
if len(deinitialization) > 1:
556+
code += deinitialization + "\n"
538557
code += "\n}"
539558
elif is_fortran(langs.language):
540559
body = wrap_timing(body, langs.language)
541560
if data is not None:
542561
code += wrap_data(body + "\n", langs, data, preprocessor, user_dimensions)
543562
else:
544-
code += body
563+
code += body + "\n"
564+
if len(deinitialization) > 1:
565+
code += deinitialization + "\n"
545566
name = signature.split(" ")[1].split("(")[0]
546567
code += f"\nend function {name}\nend module kt\n"
547568

Diff for: pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ generate-setup-file = false
5757
# ATTENTION: if anything is changed here, run `poetry update`
5858
[tool.poetry.dependencies]
5959
python = ">=3.9,<3.13" # NOTE when changing the supported Python versions, also change the test versions in the noxfile
60-
numpy = ">=1.26.0" # Python 3.12 requires numpy at least 1.26
60+
numpy = "^1.26.0" # Python 3.12 requires numpy at least 1.26
6161
scipy = ">=1.11.0"
6262
packaging = "*" # required by file_utils
6363
jsonschema = "*"

Diff for: test/utils/test_directives.py

+7
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,13 @@ def test_extract_initialization_code():
326326
assert extract_initialization_code(code_f90, Code(OpenACC(), Fortran())) == "integer :: value\n"
327327

328328

329+
def test_extract_deinitialization_code():
330+
code_cpp = "#pragma tuner deinitialize\nconst int value = 42;\n#pragma tuner stop\n"
331+
code_f90 = "!$tuner deinitialize\ninteger :: value\n!$tuner stop\n"
332+
assert extract_deinitialization_code(code_cpp, Code(OpenACC(), Cxx())) == "const int value = 42;\n"
333+
assert extract_deinitialization_code(code_f90, Code(OpenACC(), Fortran())) == "integer :: value\n"
334+
335+
329336
def test_add_present_openacc():
330337
acc_cxx = Code(OpenACC(), Cxx())
331338
acc_f90 = Code(OpenACC(), Fortran())

0 commit comments

Comments
 (0)