@@ -32,13 +32,19 @@ class Cxx(Language):
32
32
def get (self ) -> str :
33
33
return "cxx"
34
34
35
+ def end_string (self ) -> str :
36
+ return "#pragma tuner stop"
37
+
35
38
36
39
class Fortran (Language ):
37
40
"""Class to represent Fortran code"""
38
41
39
42
def get (self ) -> str :
40
43
return "fortran"
41
44
45
+ def end_string (self ) -> str :
46
+ return "!$tuner stop"
47
+
42
48
43
49
class Code (object ):
44
50
"""Class to represent the directive and host code of the application"""
@@ -356,24 +362,34 @@ def extract_directive_code(code: str, langs: Code, kernel_name: str = None) -> d
356
362
"""Extract explicitly marked directive sections from code"""
357
363
if is_cxx (langs .language ):
358
364
start_string = "#pragma tuner start"
359
- end_string = "#pragma tuner stop"
360
365
elif is_fortran (langs .language ):
361
366
start_string = "!$tuner start"
362
- end_string = "!$tuner stop"
363
367
364
- return extract_code (start_string , end_string , code , langs , kernel_name )
368
+ return extract_code (start_string , langs . language . end_string () , code , langs , kernel_name )
365
369
366
370
367
371
def extract_initialization_code (code : str , langs : Code ) -> str :
368
372
"""Extract the initialization section from code"""
369
373
if is_cxx (langs .language ):
370
374
start_string = "#pragma tuner initialize"
371
- end_string = "#pragma tuner stop"
372
375
elif is_fortran (langs .language ):
373
376
start_string = "!$tuner initialize"
374
- end_string = "!$tuner stop"
375
377
376
- init_code = extract_code (start_string , end_string , code , langs )
378
+ init_code = extract_code (start_string , langs .language .end_string (), code , langs )
379
+ if len (init_code ) >= 1 :
380
+ return "\n " .join (init_code .values ()) + "\n "
381
+ else :
382
+ return ""
383
+
384
+
385
+ def extract_deinitialization_code (code : str , langs : Code ) -> str :
386
+ """Extract the deinitialization section from code"""
387
+ if is_cxx (langs .language ):
388
+ start_string = "#pragma tuner deinitialize"
389
+ elif is_fortran (langs .language ):
390
+ start_string = "!$tuner deinitialize"
391
+
392
+ init_code = extract_code (start_string , langs .language .end_string (), code , langs )
377
393
if len (init_code ) >= 1 :
378
394
return "\n " .join (init_code .values ()) + "\n "
379
395
else :
@@ -508,6 +524,7 @@ def generate_directive_function(
508
524
langs : Code ,
509
525
data : dict = None ,
510
526
initialization : str = "" ,
527
+ deinitialization : str = "" ,
511
528
user_dimensions : dict = None ,
512
529
) -> str :
513
530
"""Generate tunable function for one directive"""
@@ -535,13 +552,17 @@ def generate_directive_function(
535
552
else :
536
553
code += body
537
554
code = end_timing_cxx (code )
555
+ if len (deinitialization ) > 1 :
556
+ code += deinitialization + "\n "
538
557
code += "\n }"
539
558
elif is_fortran (langs .language ):
540
559
body = wrap_timing (body , langs .language )
541
560
if data is not None :
542
561
code += wrap_data (body + "\n " , langs , data , preprocessor , user_dimensions )
543
562
else :
544
- code += body
563
+ code += body + "\n "
564
+ if len (deinitialization ) > 1 :
565
+ code += deinitialization + "\n "
545
566
name = signature .split (" " )[1 ].split ("(" )[0 ]
546
567
code += f"\n end function { name } \n end module kt\n "
547
568
0 commit comments