25
25
# delay load so doc build doesn't require compilation
26
26
# from . import reflmodule
27
27
28
+ try :
29
+ # raise ImportError() # uncomment to force numba off
30
+ from numba import njit , prange
31
+ USE_NUMBA = True
32
+ except ImportError :
33
+ USE_NUMBA = False
34
+ # if no numba then njit does nothing
35
+
36
+ def njit (* args , ** kw ):
37
+ # Check for bare @njit, in which case we just return the function.
38
+ if len (args ) == 1 and callable (args [0 ]) and not kw :
39
+ return args [0 ]
40
+ # Otherwise we have @njit(...), so return the identity decorator.
41
+ return lambda fn : fn
42
+
43
+ try :
44
+ from numba import cuda
45
+ USE_CUDA = bool (cuda .list_devices ()) # True if cuda device is available
46
+ except ImportError :
47
+ USE_CUDA = False
48
+
49
+ if not USE_NUMBA :
50
+ import warnings
51
+ warnings .warn ("Numba missing... magnetism and uniform convolution will run more slowly" )
52
+
28
53
BASE_GUIDE_ANGLE = 270.0
29
54
30
55
@@ -298,30 +323,17 @@ def calculate_u1_u3_py(H, rhoM, thetaM, Aguide):
298
323
return sld_b , u1 , u3
299
324
300
325
301
- try :
302
- # raise ImportError() # uncomment to force numba off
303
- from numba import njit , prange
304
- USE_NUMBA = True
305
- except ImportError :
306
- USE_NUMBA = False
307
- # if no numba then njit does nothing
308
-
309
- def njit (* args , ** kw ):
310
- # Check for bare @njit, in which case we just return the function.
311
- if len (args ) == 1 and callable (args [0 ]) and not kw :
312
- return args [0 ]
313
- # Otherwise we have @njit(...), so return the identity decorator.
314
- return lambda fn : fn
315
-
316
326
MINIMAL_RHO_M = 1e-2 # in units of 1e-6/A^2
317
327
EPS = np .finfo (float ).eps
318
328
B2SLD = 2.31604654 # Scattering factor for B field 1e-6
319
329
330
+ import math
331
+ import cmath
320
332
CR4XA_SIG = 'void(i8, f8[:], f8[:], f8, f8[:], f8[:], f8[:], c16[:], c16[:], f8, c16[:])'
321
- @njit (CR4XA_SIG , parallel = False , cache = True )
333
+ # @njit(CR4XA_SIG, parallel=False, cache=True)
322
334
def Cr4xa (N , D , SIGMA , IP , RHO , IRHO , RHOM , U1 , U3 , KZ , Y ):
323
335
EPS = 1e-10
324
- PI4 = np .pi * 4.0e-6
336
+ PI4 = math .pi * 4.0e-6
325
337
326
338
if (KZ <= - 1.e-10 ):
327
339
L = N - 1
@@ -365,17 +377,14 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
365
377
# minus to plus (lower to higher potential energy) and the observed R-+ will
366
378
# actually be zero at large distances from the interface.
367
379
#
368
- # In the backing, the -S1 and -S3 waves are explicitly set to be zero amplitude
380
+ # In the backing, the -S1 and -S3 waves are math. explicitly set to be zero amplitude
369
381
# by the boundary conditions (neutrons only incident in the fronting medium - no
370
382
# source of neutrons below).
371
383
#
372
- S1L = - sqrt (complex (PI4 * (RHO [L ]+ RHOM [L ]) - E0 , - PI4 * (fabs (IRHO [L ])+ EPS )))
373
- S3L = - sqrt (complex (PI4 * (RHO [L ]- RHOM [L ]) -
374
- E0 , - PI4 * (fabs (IRHO [L ])+ EPS )))
375
- S1LP = - sqrt (complex (PI4 * (RHO [LP ]+ RHOM [LP ]) -
376
- E0 , - PI4 * (fabs (IRHO [LP ])+ EPS )))
377
- S3LP = - sqrt (complex (PI4 * (RHO [LP ]- RHOM [LP ]) -
378
- E0 , - PI4 * (fabs (IRHO [LP ])+ EPS )))
384
+ S1L = - cmath .sqrt (complex (PI4 * (RHO [L ]+ RHOM [L ]) - E0 , - PI4 * (math .fabs (IRHO [L ])+ EPS )))
385
+ S3L = - cmath .sqrt (complex (PI4 * (RHO [L ]- RHOM [L ]) - E0 , - PI4 * (math .fabs (IRHO [L ])+ EPS )))
386
+ S1LP = - cmath .sqrt (complex (PI4 * (RHO [LP ]+ RHOM [LP ]) - E0 , - PI4 * (math .fabs (IRHO [LP ])+ EPS )))
387
+ S3LP = - cmath .sqrt (complex (PI4 * (RHO [LP ]- RHOM [LP ]) - E0 , - PI4 * (math .fabs (IRHO [LP ])+ EPS )))
379
388
SIGMAL = SIGMA [L + SIGMA_OFFSET ]
380
389
381
390
if (abs (U1 [L ]) <= 1.0 ):
@@ -410,23 +419,23 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
410
419
FS3S3 = S3L / S3LP
411
420
412
421
B11 = DELTA * 1.0 * (1.0 + FS1S1 )
413
- B12 = DELTA * 1.0 * (1.0 - FS1S1 ) * exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
422
+ B12 = DELTA * 1.0 * (1.0 - FS1S1 ) * cmath . exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
414
423
B13 = DELTA * - GLP * (1.0 + FS3S1 )
415
- B14 = DELTA * - GLP * (1.0 - FS3S1 ) * exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
424
+ B14 = DELTA * - GLP * (1.0 - FS3S1 ) * cmath . exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
416
425
417
- B21 = DELTA * 1.0 * (1.0 - FS1S1 ) * exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
426
+ B21 = DELTA * 1.0 * (1.0 - FS1S1 ) * cmath . exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
418
427
B22 = DELTA * 1.0 * (1.0 + FS1S1 )
419
- B23 = DELTA * - GLP * (1.0 - FS3S1 ) * exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
428
+ B23 = DELTA * - GLP * (1.0 - FS3S1 ) * cmath . exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
420
429
B24 = DELTA * - GLP * (1.0 + FS3S1 )
421
430
422
431
B31 = DELTA * - BLP * (1.0 + FS1S3 )
423
- B32 = DELTA * - BLP * (1.0 - FS1S3 ) * exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
432
+ B32 = DELTA * - BLP * (1.0 - FS1S3 ) * cmath . exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
424
433
B33 = DELTA * 1.0 * (1.0 + FS3S3 )
425
- B34 = DELTA * 1.0 * (1.0 - FS3S3 ) * exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
434
+ B34 = DELTA * 1.0 * (1.0 - FS3S3 ) * cmath . exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
426
435
427
- B41 = DELTA * - BLP * (1.0 - FS1S3 ) * exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
436
+ B41 = DELTA * - BLP * (1.0 - FS1S3 ) * cmath . exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
428
437
B42 = DELTA * - BLP * (1.0 + FS1S3 )
429
- B43 = DELTA * 1.0 * (1.0 - FS3S3 ) * exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
438
+ B43 = DELTA * 1.0 * (1.0 - FS3S3 ) * cmath . exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
430
439
B44 = DELTA * 1.0 * (1.0 + FS3S3 )
431
440
432
441
Z += D [LP ]
@@ -440,8 +449,8 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
440
449
S3L = S3LP #
441
450
GL = GLP
442
451
BL = BLP
443
- S1LP = - sqrt (complex (PI4 * (RHO [LP ]+ RHOM [LP ])- E0 , - PI4 * (fabs (IRHO [LP ])+ EPS )))
444
- S3LP = - sqrt (complex (PI4 * (RHO [LP ]- RHOM [LP ])- E0 , - PI4 * (fabs (IRHO [LP ])+ EPS )))
452
+ S1LP = - cmath . sqrt (complex (PI4 * (RHO [LP ]+ RHOM [LP ])- E0 , - PI4 * (math . fabs (IRHO [LP ])+ EPS )))
453
+ S3LP = - cmath . sqrt (complex (PI4 * (RHO [LP ]- RHOM [LP ])- E0 , - PI4 * (math . fabs (IRHO [LP ])+ EPS )))
445
454
SIGMAL = SIGMA [L + SIGMA_OFFSET ]
446
455
447
456
if (abs (U1 [LP ]) <= 1.0 ):
@@ -462,13 +471,13 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
462
471
DGB = (1.0 - GL * BLP ) * DELTA
463
472
DGG = (GL - GLP ) * DELTA
464
473
465
- ES1L = exp (S1L * Z )
474
+ ES1L = cmath . exp (S1L * Z )
466
475
ENS1L = 1.0 / ES1L
467
- ES1LP = exp (S1LP * Z )
476
+ ES1LP = cmath . exp (S1LP * Z )
468
477
ENS1LP = 1.0 / ES1LP
469
- ES3L = exp (S3L * Z )
478
+ ES3L = cmath . exp (S3L * Z )
470
479
ENS3L = 1.0 / ES3L
471
- ES3LP = exp (S3LP * Z )
480
+ ES3LP = cmath . exp (S3LP * Z )
472
481
ENS3LP = 1.0 / ES3LP
473
482
474
483
FS1S1 = S1L / S1LP
@@ -479,26 +488,26 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
479
488
A11 = A22 = DBG * (1.0 + FS1S1 )
480
489
A11 *= ES1L * ENS1LP
481
490
A22 *= ENS1L * ES1LP
482
- A12 = A21 = DBG * (1.0 - FS1S1 ) * exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
491
+ A12 = A21 = DBG * (1.0 - FS1S1 ) * cmath . exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
483
492
A12 *= ENS1L * ENS1LP
484
493
A21 *= ES1L * ES1LP
485
494
A13 = A24 = DGG * (1.0 + FS3S1 )
486
495
A13 *= ES3L * ENS1LP
487
496
A24 *= ENS3L * ES1LP
488
- A14 = A23 = DGG * (1.0 - FS3S1 ) * exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
497
+ A14 = A23 = DGG * (1.0 - FS3S1 ) * cmath . exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
489
498
A14 *= ENS3L * ENS1LP
490
499
A23 *= ES3L * ES1LP
491
500
492
501
A31 = A42 = DBB * (1.0 + FS1S3 )
493
502
A31 *= ES1L * ENS3LP
494
503
A42 *= ENS1L * ES3LP
495
- A32 = A41 = DBB * (1.0 - FS1S3 ) * exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
504
+ A32 = A41 = DBB * (1.0 - FS1S3 ) * cmath . exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
496
505
A32 *= ENS1L * ENS3LP
497
506
A41 *= ES1L * ES3LP
498
507
A33 = A44 = DGB * (1.0 + FS3S3 )
499
508
A33 *= ES3L * ENS3LP
500
509
A44 *= ENS3L * ES3LP
501
- A34 = A43 = DGB * (1.0 - FS3S3 ) * exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
510
+ A34 = A43 = DGB * (1.0 - FS3S3 ) * cmath . exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
502
511
A34 *= ENS3L * ENS3LP
503
512
A43 *= ES3L * ES3LP
504
513
@@ -550,14 +559,15 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
550
559
# Calculate reflectivity coefficients specified by POLSTAT
551
560
# IP = +1 fills in ++, +-, -+, --; IP = -1 only fills in -+, --.
552
561
if IP > 0 :
562
+ # Only return -+ and -- if IP is -1, otherwise return all
553
563
Y [0 ] = (B24 * B41 - B21 * B44 )/ DETW # ++
554
564
Y [1 ] = (B21 * B42 - B41 * B22 )/ DETW # +-
555
565
Y [2 ] = (B24 * B43 - B23 * B44 )/ DETW # -+
556
566
Y [3 ] = (B23 * B42 - B43 * B22 )/ DETW # --
557
567
558
568
#@cc.export('mag_amplitude')
559
569
MAGAMP_SIG = 'void(f8[:], f8[:], f8[:], f8[:], f8[:], c16[:], c16[:], f8[:], c16[:,:])'
560
- @njit (MAGAMP_SIG , parallel = True , cache = True )
570
+ # @njit(MAGAMP_SIG, parallel=True, cache=True)
561
571
def magnetic_amplitude_py (d , sigma , rho , irho , rhoM , u1 , u3 , KZ , R ):
562
572
"""
563
573
python version of calculation
@@ -581,6 +591,47 @@ def magnetic_amplitude_py(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
581
591
for i in prange (points ):
582
592
Cr4xa (layers , d , sigma , - 1.0 , rho , irho , rhoM , u1 , u3 , KZ [i ], R [i ])
583
593
594
+ def magnetic_amplitude_cuda (d , sigma , rho , irho , rhoM , u1 , u3 , KZ , R ):
595
+ i = cuda .grid (1 )
596
+ if i < KZ .size :
597
+ Cr4xa (d .size , d , sigma , 1.0 , rho , irho , rhoM , u1 , u3 , KZ [i ], R [i ])
598
+ def magnetic_amplitude_cuda_B (d , sigma , rho , irho , rhoM , u1 , u3 , KZ , R ):
599
+ i = cuda .grid (1 )
600
+ if i < KZ .size :
601
+ Cr4xa (d .size , d , sigma , 1.0 , rho , irho , rhoM , u1 , u3 , KZ [i ], R [i ])
602
+ Cr4xa (d .size , d , sigma , - 1.0 , rho , irho , rhoM , u1 , u3 , KZ [i ], R [i ])
603
+
604
+ def magnetic_amplitude_driver (d , sigma , rho , irho , rhoM , u1 , u3 , KZ , R ):
605
+ if abs (rhoM [0 ]) <= MINIMAL_RHO_M and abs (rhoM [- 1 ]) <= MINIMAL_RHO_M :
606
+ kernel = magnetic_amplitude_cuda
607
+ else :
608
+ kernel = magnetic_amplitude_cuda_B
609
+ gd , gsigma , grho , girho , grhoM , gKZ = [
610
+ cuda .to_device (np .float32 (v ), stream ) for v in (d , sigma , rho , irho , rhoM , KZ )]
611
+ gu1 , gu3 = [
612
+ cuda .to_device (np .complex64 (v ), stream ) for v in (u1 , u3 )]
613
+ gR = cuda .device_array (R .shape , dtype = np .float32 )
614
+ threadsperblock = 32
615
+ blockspergrid = (KZ .size + (threadsperblock - 1 )) // threadsperblock
616
+ kernel [blockspergrid , threadsperblock , stream ](
617
+ gd , gsigma , grho , girho , grhoM , gu1 , gu3 , gKZ , gR )
618
+ stream .synchronize ()
619
+ R [:] = gR .copy_to_host (stream = stream )
620
+
621
+ if USE_CUDA :
622
+ stream = cuda .stream ()
623
+ CR4XA_SIG_F = CR4XA_SIG .replace ('f8' ,'f4' ).replace ('c16' ,'c8' )
624
+ MAGAMP_SIG_F = MAGAMP_SIG .replace ('f8' ,'f4' ).replace ('c16' ,'c8' )
625
+ Cr4xa = cuda .jit (CR4XA_SIG_F , device = True , fastmath = True , inline = True )(Cr4xa )
626
+ magnetic_amplitude_cuda = cuda .jit (MAGAMP_SIG_F , fastmath = True )(magnetic_amplitude_cuda )
627
+ magnetic_amplitude_cuda_B = cuda .jit (MAGAMP_SIG_F , fastmath = True )(magnetic_amplitude_cuda_B )
628
+ magnetic_amplitude_py = magnetic_amplitude_driver
629
+ elif USE_NUMBA :
630
+ Cr4xa = njit (CR4XA_SIG , parallel = False , cache = True )(Cr4xa )
631
+ magnetic_amplitude_py = njit (MAGAMP_SIG , parallel = True , cache = True )(magnetic_amplitude_py )
632
+ else :
633
+ ... # fall back to pure python (non-vectorized!) edition
634
+
584
635
#try:
585
636
# from .magnetic_amplitude import mag_amplitude as magnetic_amplitude_py
586
637
# print("loaded from compiled module")
0 commit comments