1
- # Copyright (c) 2018-2019 , Intel Corporation
1
+ # Copyright (c) 2018-2024 , Intel Corporation
2
2
#
3
3
# Redistribution and use in source and binary forms, with or without
4
4
# modification, are permitted provided that the following conditions are met:
28
28
import mkl
29
29
30
30
31
- def test_get_version ():
32
- v = mkl .get_version ()
33
- assert isinstance (v , dict )
34
- assert 'MajorVersion' in v
35
- assert 'MinorVersion' in v
36
- assert 'UpdateVersion' in v
37
-
38
-
39
- def test_get_version_string ():
40
- v = mkl .get_version_string ()
41
- assert isinstance (v , str )
42
- assert 'Math Kernel Library' in v
43
-
44
-
45
31
def test_set_num_threads ():
46
32
saved = mkl .get_max_threads ()
47
- half_nt = int ( (1 + saved ) / 2 )
33
+ half_nt = int ( (1 + saved ) / 2 )
48
34
mkl .set_num_threads (half_nt )
49
35
assert mkl .get_max_threads () == half_nt
50
36
mkl .set_num_threads (saved )
51
37
38
+
52
39
def test_domain_set_num_threads_blas ():
53
40
saved_blas_nt = mkl .domain_get_max_threads (domain = 'blas' )
54
41
saved_fft_nt = mkl .domain_get_max_threads (domain = 'fft' )
@@ -75,22 +62,27 @@ def test_domain_set_num_threads_blas():
75
62
status = mkl .domain_set_num_threads (saved_vml_nt , domain = 'vml' )
76
63
assert status == 'success'
77
64
65
+
78
66
def test_domain_set_num_threads_fft ():
79
67
status = mkl .domain_set_num_threads (4 , domain = 'fft' )
80
68
assert status == 'success'
81
69
70
+
82
71
def test_domain_set_num_threads_vml ():
83
72
status = mkl .domain_set_num_threads (4 , domain = 'vml' )
84
73
assert status == 'success'
85
74
75
+
86
76
def test_domain_set_num_threads_pardiso ():
87
77
status = mkl .domain_set_num_threads (4 , domain = 'pardiso' )
88
78
assert status == 'success'
89
79
80
+
90
81
def test_domain_set_num_threads_all ():
91
82
status = mkl .domain_set_num_threads (4 , domain = 'all' )
92
83
assert status == 'success'
93
84
85
+
94
86
def test_set_num_threads_local ():
95
87
mkl .set_num_threads (1 )
96
88
status = mkl .set_num_threads_local (2 )
@@ -102,27 +94,35 @@ def test_set_num_threads_local():
102
94
status = mkl .set_num_threads_local (8 )
103
95
assert status == 'global_num_threads'
104
96
97
+
105
98
def test_set_dynamic ():
106
99
mkl .set_dynamic (True )
107
100
101
+
108
102
def test_get_max_threads ():
109
103
mkl .get_max_threads ()
110
104
105
+
111
106
def test_domain_get_max_threads_blas ():
112
107
mkl .domain_get_max_threads (domain = 'blas' )
113
108
109
+
114
110
def test_domain_get_max_threads_fft ():
115
111
mkl .domain_get_max_threads (domain = 'fft' )
116
112
113
+
117
114
def test_domain_get_max_threads_vml ():
118
115
mkl .domain_get_max_threads (domain = 'vml' )
119
116
117
+
120
118
def test_domain_get_max_threads_pardiso ():
121
119
mkl .domain_get_max_threads (domain = 'pardiso' )
122
120
121
+
123
122
def test_domain_get_max_threads_all ():
124
123
mkl .domain_get_max_threads (domain = 'all' )
125
124
125
+
126
126
def test_get_dynamic ():
127
127
mkl .get_dynamic ()
128
128
@@ -134,54 +134,80 @@ def test_second():
134
134
delta = s2 - s1
135
135
assert delta >= 0
136
136
137
+
137
138
def test_dsecnd ():
138
139
d1 = mkl .dsecnd ()
139
140
d2 = mkl .dsecnd ()
140
141
delta = d2 - d1
141
142
assert delta >= 0
142
143
144
+
143
145
def test_get_cpu_clocks ():
144
146
c1 = mkl .get_cpu_clocks ()
145
147
c2 = mkl .get_cpu_clocks ()
146
148
delta = c2 - c1
147
149
assert delta >= 0
148
150
151
+
149
152
def test_get_cpu_frequency ():
150
153
assert mkl .get_cpu_frequency () >= 0
151
154
155
+
152
156
def test_get_max_cpu_frequency ():
153
157
assert mkl .get_max_cpu_frequency () >= 0
154
158
159
+
155
160
def test_get_clocks_frequency ():
156
161
assert mkl .get_clocks_frequency () >= 0
157
162
158
163
159
164
def test_free_buffers ():
160
165
mkl .free_buffers ()
161
166
167
+
162
168
def test_thread_free_buffers ():
163
169
mkl .thread_free_buffers ()
164
170
171
+
165
172
def test_disable_fast_mm ():
166
173
mkl .disable_fast_mm ()
167
174
175
+
168
176
def test_mem_stat ():
169
177
mkl .mem_stat ()
170
178
179
+
171
180
def test_peak_mem_usage_enable ():
172
181
mkl .peak_mem_usage ('enable' )
173
182
183
+
174
184
def test_peak_mem_usage_disable ():
175
185
mkl .peak_mem_usage ('disable' )
176
186
187
+
177
188
def test_peak_mem_usage_peak_mem ():
178
189
mkl .peak_mem_usage ('peak_mem' )
179
190
191
+
180
192
def test_peak_mem_usage_peak_mem_reset ():
181
193
mkl .peak_mem_usage ('peak_mem_reset' )
182
194
195
+
183
196
def test_set_memory_limit ():
184
- mkl .set_memory_limit (128 )
197
+ mkl .set_memory_limit (2 ** 16 )
198
+
199
+
200
+ def check_cbwr (branch , cnr_const ):
201
+ status = mkl .cbwr_set (branch = branch )
202
+ if status == 'success' :
203
+ expected_value = 'branch_off' if branch == 'off' else branch
204
+ actual_value = mkl .cbwr_get (cnr_const = cnr_const )
205
+ assert actual_value == expected_value , \
206
+ f"Round-trip failure for CNR branch '{ branch } ', CNR const '{ cnr_const } "
207
+ elif status not in ['err_unsupported_branch' , 'err_mode_change_failure' ]:
208
+ # if MKL has been initialized already,
209
+ # setting CBWR will error with mode_change_failure
210
+ pytest .fail (status )
185
211
186
212
187
213
branches = [
@@ -200,29 +226,25 @@ def test_set_memory_limit():
200
226
'avx512_mic_e1' ,
201
227
'avx512_e1' ,
202
228
]
229
+
230
+
203
231
strict = [
204
232
'avx2,strict' ,
205
233
'avx512_mic,strict' ,
206
234
'avx512,strict' ,
207
235
'avx512_e1,strict' ,
208
236
]
237
+
238
+
209
239
@pytest .mark .parametrize ('branch' , branches )
210
240
def test_cbwr_branch (branch ):
211
241
check_cbwr (branch , 'branch' )
212
242
243
+
213
244
@pytest .mark .parametrize ('branch' , branches + strict )
214
245
def test_cbwr_all (branch ):
215
246
check_cbwr (branch , 'all' )
216
247
217
- def check_cbwr (branch , cnr_const ):
218
- status = mkl .cbwr_set (branch = branch )
219
- if status == 'success' :
220
- expected_value = 'branch_off' if branch == 'off' else branch
221
- actual_value = mkl .cbwr_get (cnr_const = cnr_const )
222
- assert actual_value == expected_value , \
223
- f"Round-trip failure for CNR branch '{ branch } ', CNR const '{ cnr_const } "
224
- elif status != 'err_unsupported_branch' :
225
- pytest .fail (status )
226
248
227
249
def test_cbwr_get_auto_branch ():
228
250
mkl .cbwr_get_auto_branch ()
@@ -231,45 +253,58 @@ def test_cbwr_get_auto_branch():
231
253
def test_enable_instructions_avx512_mic_e1 ():
232
254
mkl .enable_instructions ('avx512_mic_e1' )
233
255
256
+
234
257
def test_enable_instructions_avx512 ():
235
258
mkl .enable_instructions ('avx512' )
236
259
260
+
237
261
def test_enable_instructions_avx512_mic ():
238
262
mkl .enable_instructions ('avx512_mic' )
239
263
264
+
240
265
def test_enable_instructions_avx2 ():
241
266
mkl .enable_instructions ('avx2' )
242
267
268
+
243
269
def test_enable_instructions_avx ():
244
270
mkl .enable_instructions ('avx' )
245
271
272
+
246
273
def test_enable_instructions_sse4_2 ():
247
274
mkl .enable_instructions ('sse4_2' )
248
275
276
+
249
277
def test_set_env_mode ():
250
278
mkl .set_env_mode ()
251
279
280
+
252
281
def test_get_env_mode ():
253
282
mkl .get_env_mode ()
254
283
284
+
255
285
def test_verbose_false ():
256
286
mkl .verbose (False )
257
287
288
+
258
289
def test_verbose_true ():
259
290
mkl .verbose (True )
260
291
292
+
261
293
@pytest .mark .skip (reason = "Skipping MPI-related test" )
262
294
def test_set_mpi_custom ():
263
295
mkl .set_mpi ('custom' , 'custom_library_name' )
264
296
297
+
265
298
@pytest .mark .skip (reason = "Skipping MPI-related test" )
266
299
def test_set_mpi_msmpi ():
267
300
mkl .set_mpi ('msmpi' )
268
301
302
+
269
303
@pytest .mark .skip (reason = "Skipping MPI-related test" )
270
304
def test_set_mpi_intelmpi ():
271
305
mkl .set_mpi ('intelmpi' )
272
306
307
+
273
308
@pytest .mark .skip (reason = "Skipping MPI-related test" )
274
309
def test_set_mpi_mpich2 ():
275
310
mkl .set_mpi ('mpich2' )
@@ -279,53 +314,91 @@ def test_vml_set_get_mode_roundtrip():
279
314
saved = mkl .vml_get_mode ()
280
315
mkl .vml_set_mode (* saved ) # should not raise errors
281
316
317
+
282
318
def test_vml_set_mode_ha_on_ignore ():
283
319
mkl .vml_set_mode ('ha' , 'on' , 'ignore' )
284
320
321
+
285
322
def test_vml_set_mode_ha_on_errno ():
286
323
mkl .vml_set_mode ('ha' , 'on' , 'errno' )
287
324
325
+
288
326
def test_vml_set_mode_la_on_stderr ():
289
327
mkl .vml_set_mode ('la' , 'on' , 'stderr' )
290
328
329
+
291
330
def test_vml_set_mode_la_off_except ():
292
331
mkl .vml_set_mode ('la' , 'off' , 'except' )
293
332
333
+
294
334
def test_vml_set_mode_op_off_callback ():
295
335
mkl .vml_set_mode ('ep' , 'off' , 'callback' )
296
336
337
+
297
338
def test_vml_set_mode_ep_off_default ():
298
339
mkl .vml_set_mode ('ep' , 'off' , 'default' )
299
340
341
+
300
342
def test_vml_get_mode ():
301
343
mkl .vml_get_mode ()
302
344
345
+
303
346
def test_vml_set_err_status_ok ():
304
347
mkl .vml_set_err_status ('ok' )
305
348
349
+
306
350
def test_vml_set_err_status_accuracywarning ():
307
351
mkl .vml_set_err_status ('accuracywarning' )
308
352
353
+
309
354
def test_vml_set_err_status_badsize ():
310
355
mkl .vml_set_err_status ('badsize' )
311
356
357
+
312
358
def test_vml_set_err_status_badmem ():
313
359
mkl .vml_set_err_status ('badmem' )
314
360
361
+
315
362
def test_vml_set_err_status_errdom ():
316
363
mkl .vml_set_err_status ('errdom' )
317
364
365
+
318
366
def test_vml_set_err_status_sing ():
319
367
mkl .vml_set_err_status ('sing' )
320
368
369
+
321
370
def test_vml_set_err_status_overflow ():
322
371
mkl .vml_set_err_status ('overflow' )
323
372
373
+
324
374
def test_vml_set_err_status_underflow ():
325
375
mkl .vml_set_err_status ('underflow' )
326
376
377
+
327
378
def test_vml_get_err_status ():
328
379
mkl .vml_get_err_status ()
329
380
381
+
330
382
def test_vml_clear_err_status ():
331
383
mkl .vml_clear_err_status ()
384
+
385
+
386
+ def test_get_version ():
387
+ """
388
+ Version info sets mode of MKL library, such as
389
+ instruction pathways and conditional numerical
390
+ reproducibility regime. This test is moved to
391
+ the bottom to allow proper testing of functions
392
+ controllign those.
393
+ """
394
+ v = mkl .get_version ()
395
+ assert isinstance (v , dict )
396
+ assert 'MajorVersion' in v
397
+ assert 'MinorVersion' in v
398
+ assert 'UpdateVersion' in v
399
+
400
+
401
+ def test_get_version_string ():
402
+ v = mkl .get_version_string ()
403
+ assert isinstance (v , str )
404
+ assert 'Math Kernel Library' in v
0 commit comments