@@ -32,13 +32,19 @@ class Cxx(Language):
32
32
def get (self ) -> str :
33
33
return "cxx"
34
34
35
+ def end_string (self ) -> str :
36
+ return "#pragma tuner stop"
37
+
35
38
36
39
class Fortran (Language ):
37
40
"""Class to represent Fortran code"""
38
41
39
42
def get (self ) -> str :
40
43
return "fortran"
41
44
45
+ def end_string (self ) -> str :
46
+ return "!$tuner stop"
47
+
42
48
43
49
class Code (object ):
44
50
"""Class to represent the directive and host code of the application"""
@@ -48,6 +54,48 @@ def __init__(self, directive: Directive, lang: Language):
48
54
self .language = lang
49
55
50
56
57
+ class ArraySize (object ):
58
+ """Size of an array"""
59
+
60
+ def __init__ (self ):
61
+ self .size = list ()
62
+
63
+ def __iter__ (self ):
64
+ for i in self .size :
65
+ yield i
66
+
67
+ def __len__ (self ):
68
+ return len (self .size )
69
+
70
+ def clear (self ):
71
+ self .size .clear ()
72
+
73
+ def get (self ) -> int :
74
+ length = len (self .size )
75
+ if length == 0 :
76
+ return 0
77
+ elif length == 1 :
78
+ return self .size [0 ]
79
+ else :
80
+ product = 1
81
+ for i in self .size :
82
+ product *= i
83
+ return product
84
+
85
+ def add (self , dim : int ) -> None :
86
+ # Only allow adding valid dimensions
87
+ if dim >= 1 :
88
+ self .size .append (dim )
89
+
90
+
91
+ def fortran_md_size (size : ArraySize ) -> list :
92
+ """Format a multidimensional size into the correct Fortran string"""
93
+ md_size = list ()
94
+ for dim in size :
95
+ md_size .append (f":{ dim } " )
96
+ return md_size
97
+
98
+
51
99
def is_openacc (directive : Directive ) -> bool :
52
100
"""Check if a directive is OpenACC"""
53
101
return isinstance (directive , OpenACC )
@@ -120,7 +168,7 @@ def openacc_directive_contains_data_clause(line: str) -> bool:
120
168
return openacc_directive_contains_clause (line , data_clauses )
121
169
122
170
123
- def create_data_directive_openacc (name : str , size : int , lang : Language ) -> str :
171
+ def create_data_directive_openacc (name : str , size : ArraySize , lang : Language ) -> str :
124
172
"""Create a data directive for a given language"""
125
173
if is_cxx (lang ):
126
174
return create_data_directive_openacc_cxx (name , size )
@@ -129,17 +177,23 @@ def create_data_directive_openacc(name: str, size: int, lang: Language) -> str:
129
177
return ""
130
178
131
179
132
- def create_data_directive_openacc_cxx (name : str , size : int ) -> str :
180
+ def create_data_directive_openacc_cxx (name : str , size : ArraySize ) -> str :
133
181
"""Create C++ OpenACC code to allocate and copy data"""
134
- return f"#pragma acc enter data create({ name } [:{ size } ])\n #pragma acc update device({ name } [:{ size } ])\n "
182
+ return f"#pragma acc enter data create({ name } [:{ size . get () } ])\n #pragma acc update device({ name } [:{ size . get () } ])\n "
135
183
136
184
137
- def create_data_directive_openacc_fortran (name : str , size : int ) -> str :
185
+ def create_data_directive_openacc_fortran (name : str , size : ArraySize ) -> str :
138
186
"""Create Fortran OpenACC code to allocate and copy data"""
139
- return f"!$acc enter data create({ name } (:{ size } ))\n !$acc update device({ name } (:{ size } ))\n "
187
+ if len (size ) == 1 :
188
+ return f"!$acc enter data create({ name } (:{ size .get ()} ))\n !$acc update device({ name } (:{ size .get ()} ))\n "
189
+ else :
190
+ md_size = fortran_md_size (size )
191
+ return (
192
+ f"!$acc enter data create({ name } ({ ',' .join (md_size )} ))\n !$acc update device({ name } ({ ',' .join (md_size )} ))\n "
193
+ )
140
194
141
195
142
- def exit_data_directive_openacc (name : str , size : int , lang : Language ) -> str :
196
+ def exit_data_directive_openacc (name : str , size : ArraySize , lang : Language ) -> str :
143
197
"""Create code to copy data back for a given language"""
144
198
if is_cxx (lang ):
145
199
return exit_data_directive_openacc_cxx (name , size )
@@ -148,14 +202,18 @@ def exit_data_directive_openacc(name: str, size: int, lang: Language) -> str:
148
202
return ""
149
203
150
204
151
- def exit_data_directive_openacc_cxx (name : str , size : int ) -> str :
205
+ def exit_data_directive_openacc_cxx (name : str , size : ArraySize ) -> str :
152
206
"""Create C++ OpenACC code to copy back data"""
153
- return f"#pragma acc exit data copyout({ name } [:{ size } ])\n "
207
+ return f"#pragma acc exit data copyout({ name } [:{ size . get () } ])\n "
154
208
155
209
156
- def exit_data_directive_openacc_fortran (name : str , size : int ) -> str :
210
+ def exit_data_directive_openacc_fortran (name : str , size : ArraySize ) -> str :
157
211
"""Create Fortran OpenACC code to copy back data"""
158
- return f"!$acc exit data copyout({ name } (:{ size } ))\n "
212
+ if len (size ) == 1 :
213
+ return f"!$acc exit data copyout({ name } (:{ size .get ()} ))\n "
214
+ else :
215
+ md_size = fortran_md_size (size )
216
+ return f"!$acc exit data copyout({ name } ({ ',' .join (md_size )} ))\n "
159
217
160
218
161
219
def correct_kernel (kernel_name : str , line : str ) -> bool :
@@ -165,7 +223,7 @@ def correct_kernel(kernel_name: str, line: str) -> bool:
165
223
166
224
def find_size_in_preprocessor (dimension : str , preprocessor : list ) -> int :
167
225
"""Find the dimension of a directive defined value in the preprocessor"""
168
- ret_size = None
226
+ ret_size = 0
169
227
for line in preprocessor :
170
228
if f"#define { dimension } " in line :
171
229
try :
@@ -209,45 +267,43 @@ def extract_code(start: str, stop: str, code: str, langs: Code, kernel_name: str
209
267
return sections
210
268
211
269
212
- def parse_size (size : Any , preprocessor : list = None , dimensions : dict = None ) -> int :
270
+ def parse_size (size : Any , preprocessor : list = None , dimensions : dict = None ) -> ArraySize :
213
271
"""Converts an arbitrary object into an integer representing memory size"""
214
- ret_size = None
272
+ ret_size = ArraySize ()
215
273
if type (size ) is not int :
216
274
try :
217
275
# Try to convert the size to an integer
218
- ret_size = int (size )
276
+ ret_size . add ( int (size ) )
219
277
except ValueError :
220
278
# If size cannot be natively converted to an int, we try to derive it from the preprocessor
221
- if preprocessor is not None :
222
- try :
279
+ try :
280
+ if preprocessor is not None :
223
281
if "," in size :
224
- ret_size = 1
225
282
for dimension in size .split ("," ):
226
- ret_size *= find_size_in_preprocessor (dimension , preprocessor )
283
+ ret_size . add ( find_size_in_preprocessor (dimension , preprocessor ) )
227
284
else :
228
- ret_size = find_size_in_preprocessor (size , preprocessor )
229
- except TypeError :
230
- # preprocessor is available but does not contain the dimensions
231
- pass
285
+ ret_size . add ( find_size_in_preprocessor (size , preprocessor ) )
286
+ except TypeError :
287
+ # At least one of the dimension cannot be derived from the preprocessor
288
+ pass
232
289
# If size cannot be natively converted, nor retrieved from the preprocessor, we check user provided values
233
290
if dimensions is not None :
234
291
if size in dimensions .keys ():
235
292
try :
236
- ret_size = int (dimensions [size ])
293
+ ret_size . add ( int (dimensions [size ]) )
237
294
except ValueError :
238
295
# User error, no mitigation
239
296
return ret_size
240
297
elif "," in size :
241
- ret_size = 1
242
298
for dimension in size .split ("," ):
243
299
try :
244
- ret_size *= int (dimensions [dimension ])
300
+ ret_size . add ( int (dimensions [dimension ]) )
245
301
except ValueError :
246
302
# User error, no mitigation
247
- return None
303
+ return ret_size
248
304
else :
249
305
# size is already an int. no need for conversion
250
- ret_size = size
306
+ ret_size . add ( size )
251
307
252
308
return ret_size
253
309
@@ -306,24 +362,34 @@ def extract_directive_code(code: str, langs: Code, kernel_name: str = None) -> d
306
362
"""Extract explicitly marked directive sections from code"""
307
363
if is_cxx (langs .language ):
308
364
start_string = "#pragma tuner start"
309
- end_string = "#pragma tuner stop"
310
365
elif is_fortran (langs .language ):
311
366
start_string = "!$tuner start"
312
- end_string = "!$tuner stop"
313
367
314
- return extract_code (start_string , end_string , code , langs , kernel_name )
368
+ return extract_code (start_string , langs . language . end_string () , code , langs , kernel_name )
315
369
316
370
317
371
def extract_initialization_code (code : str , langs : Code ) -> str :
318
372
"""Extract the initialization section from code"""
319
373
if is_cxx (langs .language ):
320
374
start_string = "#pragma tuner initialize"
321
- end_string = "#pragma tuner stop"
322
375
elif is_fortran (langs .language ):
323
376
start_string = "!$tuner initialize"
324
- end_string = "!$tuner stop"
325
377
326
- init_code = extract_code (start_string , end_string , code , langs )
378
+ init_code = extract_code (start_string , langs .language .end_string (), code , langs )
379
+ if len (init_code ) >= 1 :
380
+ return "\n " .join (init_code .values ()) + "\n "
381
+ else :
382
+ return ""
383
+
384
+
385
+ def extract_deinitialization_code (code : str , langs : Code ) -> str :
386
+ """Extract the deinitialization section from code"""
387
+ if is_cxx (langs .language ):
388
+ start_string = "#pragma tuner deinitialize"
389
+ elif is_fortran (langs .language ):
390
+ start_string = "!$tuner deinitialize"
391
+
392
+ init_code = extract_code (start_string , langs .language .end_string (), code , langs )
327
393
if len (init_code ) >= 1 :
328
394
return "\n " .join (init_code .values ()) + "\n "
329
395
else :
@@ -458,6 +524,7 @@ def generate_directive_function(
458
524
langs : Code ,
459
525
data : dict = None ,
460
526
initialization : str = "" ,
527
+ deinitialization : str = "" ,
461
528
user_dimensions : dict = None ,
462
529
) -> str :
463
530
"""Generate tunable function for one directive"""
@@ -485,13 +552,17 @@ def generate_directive_function(
485
552
else :
486
553
code += body
487
554
code = end_timing_cxx (code )
555
+ if len (deinitialization ) > 1 :
556
+ code += deinitialization + "\n "
488
557
code += "\n }"
489
558
elif is_fortran (langs .language ):
490
559
body = wrap_timing (body , langs .language )
491
560
if data is not None :
492
561
code += wrap_data (body + "\n " , langs , data , preprocessor , user_dimensions )
493
562
else :
494
- code += body
563
+ code += body + "\n "
564
+ if len (deinitialization ) > 1 :
565
+ code += deinitialization + "\n "
495
566
name = signature .split (" " )[1 ].split ("(" )[0 ]
496
567
code += f"\n end function { name } \n end module kt\n "
497
568
@@ -537,9 +608,9 @@ def allocate_signature_memory(data: dict, preprocessor: list = None, user_dimens
537
608
p_type = data [parameter ][0 ]
538
609
size = parse_size (data [parameter ][1 ], preprocessor , user_dimensions )
539
610
if "*" in p_type :
540
- args .append (allocate_array (p_type , size ))
611
+ args .append (allocate_array (p_type , size . get () ))
541
612
else :
542
- args .append (allocate_scalar (p_type , size ))
613
+ args .append (allocate_scalar (p_type , size . get () ))
543
614
544
615
return args
545
616
@@ -579,11 +650,15 @@ def add_present_openacc(
579
650
return new_body
580
651
581
652
582
- def add_present_openacc_cxx (name : str , size : int ) -> str :
653
+ def add_present_openacc_cxx (name : str , size : ArraySize ) -> str :
583
654
"""Create present clause for C++ OpenACC directive"""
584
- return f" present({ name } [:{ size } ]) "
655
+ return f" present({ name } [:{ size . get () } ]) "
585
656
586
657
587
- def add_present_openacc_fortran (name : str , size : int ) -> str :
658
+ def add_present_openacc_fortran (name : str , size : ArraySize ) -> str :
588
659
"""Create present clause for Fortran OpenACC directive"""
589
- return f" present({ name } (:{ size } )) "
660
+ if len (size ) == 1 :
661
+ return f" present({ name } (:{ size .get ()} )) "
662
+ else :
663
+ md_size = fortran_md_size (size )
664
+ return f" present({ name } ({ ',' .join (md_size )} )) "
0 commit comments