2525# delay load so doc build doesn't require compilation
2626# from . import reflmodule
2727
28+ try :
29+ # raise ImportError() # uncomment to force numba off
30+ from numba import njit , prange
31+ USE_NUMBA = True
32+ except ImportError :
33+ USE_NUMBA = False
34+ # if no numba then njit does nothing
35+
36+ def njit (* args , ** kw ):
37+ # Check for bare @njit, in which case we just return the function.
38+ if len (args ) == 1 and callable (args [0 ]) and not kw :
39+ return args [0 ]
40+ # Otherwise we have @njit(...), so return the identity decorator.
41+ return lambda fn : fn
42+
43+ try :
44+ from numba import cuda
45+ USE_CUDA = bool (cuda .list_devices ()) # True if cuda device is available
46+ except ImportError :
47+ USE_CUDA = False
48+
49+ if not USE_NUMBA :
50+ import warnings
51+ warnings .warn ("Numba missing... magnetism and uniform convolution will run more slowly" )
52+
2853BASE_GUIDE_ANGLE = 270.0
2954
3055
@@ -298,30 +323,17 @@ def calculate_u1_u3_py(H, rhoM, thetaM, Aguide):
298323 return sld_b , u1 , u3
299324
300325
301- try :
302- # raise ImportError() # uncomment to force numba off
303- from numba import njit , prange
304- USE_NUMBA = True
305- except ImportError :
306- USE_NUMBA = False
307- # if no numba then njit does nothing
308-
309- def njit (* args , ** kw ):
310- # Check for bare @njit, in which case we just return the function.
311- if len (args ) == 1 and callable (args [0 ]) and not kw :
312- return args [0 ]
313- # Otherwise we have @njit(...), so return the identity decorator.
314- return lambda fn : fn
315-
316326MINIMAL_RHO_M = 1e-2 # in units of 1e-6/A^2
317327EPS = np .finfo (float ).eps
318328B2SLD = 2.31604654 # Scattering factor for B field 1e-6
319329
330+ import math
331+ import cmath
320332CR4XA_SIG = 'void(i8, f8[:], f8[:], f8, f8[:], f8[:], f8[:], c16[:], c16[:], f8, c16[:])'
321- @njit (CR4XA_SIG , parallel = False , cache = True )
333+ # @njit(CR4XA_SIG, parallel=False, cache=True)
322334def Cr4xa (N , D , SIGMA , IP , RHO , IRHO , RHOM , U1 , U3 , KZ , Y ):
323335 EPS = 1e-10
324- PI4 = np .pi * 4.0e-6
336+ PI4 = math .pi * 4.0e-6
325337
326338 if (KZ <= - 1.e-10 ):
327339 L = N - 1
@@ -365,17 +377,14 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
365377 # minus to plus (lower to higher potential energy) and the observed R-+ will
366378 # actually be zero at large distances from the interface.
367379 #
368- # In the backing, the -S1 and -S3 waves are explicitly set to be zero amplitude
380+ # In the backing, the -S1 and -S3 waves are math. explicitly set to be zero amplitude
369381 # by the boundary conditions (neutrons only incident in the fronting medium - no
370382 # source of neutrons below).
371383 #
372- S1L = - sqrt (complex (PI4 * (RHO [L ]+ RHOM [L ]) - E0 , - PI4 * (fabs (IRHO [L ])+ EPS )))
373- S3L = - sqrt (complex (PI4 * (RHO [L ]- RHOM [L ]) -
374- E0 , - PI4 * (fabs (IRHO [L ])+ EPS )))
375- S1LP = - sqrt (complex (PI4 * (RHO [LP ]+ RHOM [LP ]) -
376- E0 , - PI4 * (fabs (IRHO [LP ])+ EPS )))
377- S3LP = - sqrt (complex (PI4 * (RHO [LP ]- RHOM [LP ]) -
378- E0 , - PI4 * (fabs (IRHO [LP ])+ EPS )))
384+ S1L = - cmath .sqrt (complex (PI4 * (RHO [L ]+ RHOM [L ]) - E0 , - PI4 * (math .fabs (IRHO [L ])+ EPS )))
385+ S3L = - cmath .sqrt (complex (PI4 * (RHO [L ]- RHOM [L ]) - E0 , - PI4 * (math .fabs (IRHO [L ])+ EPS )))
386+ S1LP = - cmath .sqrt (complex (PI4 * (RHO [LP ]+ RHOM [LP ]) - E0 , - PI4 * (math .fabs (IRHO [LP ])+ EPS )))
387+ S3LP = - cmath .sqrt (complex (PI4 * (RHO [LP ]- RHOM [LP ]) - E0 , - PI4 * (math .fabs (IRHO [LP ])+ EPS )))
379388 SIGMAL = SIGMA [L + SIGMA_OFFSET ]
380389
381390 if (abs (U1 [L ]) <= 1.0 ):
@@ -410,23 +419,23 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
410419 FS3S3 = S3L / S3LP
411420
412421 B11 = DELTA * 1.0 * (1.0 + FS1S1 )
413- B12 = DELTA * 1.0 * (1.0 - FS1S1 ) * exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
422+ B12 = DELTA * 1.0 * (1.0 - FS1S1 ) * cmath . exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
414423 B13 = DELTA * - GLP * (1.0 + FS3S1 )
415- B14 = DELTA * - GLP * (1.0 - FS3S1 ) * exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
424+ B14 = DELTA * - GLP * (1.0 - FS3S1 ) * cmath . exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
416425
417- B21 = DELTA * 1.0 * (1.0 - FS1S1 ) * exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
426+ B21 = DELTA * 1.0 * (1.0 - FS1S1 ) * cmath . exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
418427 B22 = DELTA * 1.0 * (1.0 + FS1S1 )
419- B23 = DELTA * - GLP * (1.0 - FS3S1 ) * exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
428+ B23 = DELTA * - GLP * (1.0 - FS3S1 ) * cmath . exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
420429 B24 = DELTA * - GLP * (1.0 + FS3S1 )
421430
422431 B31 = DELTA * - BLP * (1.0 + FS1S3 )
423- B32 = DELTA * - BLP * (1.0 - FS1S3 ) * exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
432+ B32 = DELTA * - BLP * (1.0 - FS1S3 ) * cmath . exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
424433 B33 = DELTA * 1.0 * (1.0 + FS3S3 )
425- B34 = DELTA * 1.0 * (1.0 - FS3S3 ) * exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
434+ B34 = DELTA * 1.0 * (1.0 - FS3S3 ) * cmath . exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
426435
427- B41 = DELTA * - BLP * (1.0 - FS1S3 ) * exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
436+ B41 = DELTA * - BLP * (1.0 - FS1S3 ) * cmath . exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
428437 B42 = DELTA * - BLP * (1.0 + FS1S3 )
429- B43 = DELTA * 1.0 * (1.0 - FS3S3 ) * exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
438+ B43 = DELTA * 1.0 * (1.0 - FS3S3 ) * cmath . exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
430439 B44 = DELTA * 1.0 * (1.0 + FS3S3 )
431440
432441 Z += D [LP ]
@@ -440,8 +449,8 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
440449 S3L = S3LP #
441450 GL = GLP
442451 BL = BLP
443- S1LP = - sqrt (complex (PI4 * (RHO [LP ]+ RHOM [LP ])- E0 , - PI4 * (fabs (IRHO [LP ])+ EPS )))
444- S3LP = - sqrt (complex (PI4 * (RHO [LP ]- RHOM [LP ])- E0 , - PI4 * (fabs (IRHO [LP ])+ EPS )))
452+ S1LP = - cmath . sqrt (complex (PI4 * (RHO [LP ]+ RHOM [LP ])- E0 , - PI4 * (math . fabs (IRHO [LP ])+ EPS )))
453+ S3LP = - cmath . sqrt (complex (PI4 * (RHO [LP ]- RHOM [LP ])- E0 , - PI4 * (math . fabs (IRHO [LP ])+ EPS )))
445454 SIGMAL = SIGMA [L + SIGMA_OFFSET ]
446455
447456 if (abs (U1 [LP ]) <= 1.0 ):
@@ -462,13 +471,13 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
462471 DGB = (1.0 - GL * BLP ) * DELTA
463472 DGG = (GL - GLP ) * DELTA
464473
465- ES1L = exp (S1L * Z )
474+ ES1L = cmath . exp (S1L * Z )
466475 ENS1L = 1.0 / ES1L
467- ES1LP = exp (S1LP * Z )
476+ ES1LP = cmath . exp (S1LP * Z )
468477 ENS1LP = 1.0 / ES1LP
469- ES3L = exp (S3L * Z )
478+ ES3L = cmath . exp (S3L * Z )
470479 ENS3L = 1.0 / ES3L
471- ES3LP = exp (S3LP * Z )
480+ ES3LP = cmath . exp (S3LP * Z )
472481 ENS3LP = 1.0 / ES3LP
473482
474483 FS1S1 = S1L / S1LP
@@ -479,26 +488,26 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
479488 A11 = A22 = DBG * (1.0 + FS1S1 )
480489 A11 *= ES1L * ENS1LP
481490 A22 *= ENS1L * ES1LP
482- A12 = A21 = DBG * (1.0 - FS1S1 ) * exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
491+ A12 = A21 = DBG * (1.0 - FS1S1 ) * cmath . exp (2. * S1L * S1LP * SIGMAL * SIGMAL )
483492 A12 *= ENS1L * ENS1LP
484493 A21 *= ES1L * ES1LP
485494 A13 = A24 = DGG * (1.0 + FS3S1 )
486495 A13 *= ES3L * ENS1LP
487496 A24 *= ENS3L * ES1LP
488- A14 = A23 = DGG * (1.0 - FS3S1 ) * exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
497+ A14 = A23 = DGG * (1.0 - FS3S1 ) * cmath . exp (2. * S3L * S1LP * SIGMAL * SIGMAL )
489498 A14 *= ENS3L * ENS1LP
490499 A23 *= ES3L * ES1LP
491500
492501 A31 = A42 = DBB * (1.0 + FS1S3 )
493502 A31 *= ES1L * ENS3LP
494503 A42 *= ENS1L * ES3LP
495- A32 = A41 = DBB * (1.0 - FS1S3 ) * exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
504+ A32 = A41 = DBB * (1.0 - FS1S3 ) * cmath . exp (2. * S1L * S3LP * SIGMAL * SIGMAL )
496505 A32 *= ENS1L * ENS3LP
497506 A41 *= ES1L * ES3LP
498507 A33 = A44 = DGB * (1.0 + FS3S3 )
499508 A33 *= ES3L * ENS3LP
500509 A44 *= ENS3L * ES3LP
501- A34 = A43 = DGB * (1.0 - FS3S3 ) * exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
510+ A34 = A43 = DGB * (1.0 - FS3S3 ) * cmath . exp (2. * S3L * S3LP * SIGMAL * SIGMAL )
502511 A34 *= ENS3L * ENS3LP
503512 A43 *= ES3L * ES3LP
504513
@@ -550,14 +559,15 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
550559 # Calculate reflectivity coefficients specified by POLSTAT
551560 # IP = +1 fills in ++, +-, -+, --; IP = -1 only fills in -+, --.
552561 if IP > 0 :
562+ # Only return -+ and -- if IP is -1, otherwise return all
553563 Y [0 ] = (B24 * B41 - B21 * B44 )/ DETW # ++
554564 Y [1 ] = (B21 * B42 - B41 * B22 )/ DETW # +-
555565 Y [2 ] = (B24 * B43 - B23 * B44 )/ DETW # -+
556566 Y [3 ] = (B23 * B42 - B43 * B22 )/ DETW # --
557567
558568#@cc.export('mag_amplitude')
559569MAGAMP_SIG = 'void(f8[:], f8[:], f8[:], f8[:], f8[:], c16[:], c16[:], f8[:], c16[:,:])'
560- @njit (MAGAMP_SIG , parallel = True , cache = True )
570+ # @njit(MAGAMP_SIG, parallel=True, cache=True)
561571def magnetic_amplitude_py (d , sigma , rho , irho , rhoM , u1 , u3 , KZ , R ):
562572 """
563573 python version of calculation
@@ -581,6 +591,47 @@ def magnetic_amplitude_py(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
581591 for i in prange (points ):
582592 Cr4xa (layers , d , sigma , - 1.0 , rho , irho , rhoM , u1 , u3 , KZ [i ], R [i ])
583593
594+ def magnetic_amplitude_cuda (d , sigma , rho , irho , rhoM , u1 , u3 , KZ , R ):
595+ i = cuda .grid (1 )
596+ if i < KZ .size :
597+ Cr4xa (d .size , d , sigma , 1.0 , rho , irho , rhoM , u1 , u3 , KZ [i ], R [i ])
598+ def magnetic_amplitude_cuda_B (d , sigma , rho , irho , rhoM , u1 , u3 , KZ , R ):
599+ i = cuda .grid (1 )
600+ if i < KZ .size :
601+ Cr4xa (d .size , d , sigma , 1.0 , rho , irho , rhoM , u1 , u3 , KZ [i ], R [i ])
602+ Cr4xa (d .size , d , sigma , - 1.0 , rho , irho , rhoM , u1 , u3 , KZ [i ], R [i ])
603+
604+ def magnetic_amplitude_driver (d , sigma , rho , irho , rhoM , u1 , u3 , KZ , R ):
605+ if abs (rhoM [0 ]) <= MINIMAL_RHO_M and abs (rhoM [- 1 ]) <= MINIMAL_RHO_M :
606+ kernel = magnetic_amplitude_cuda
607+ else :
608+ kernel = magnetic_amplitude_cuda_B
609+ gd , gsigma , grho , girho , grhoM , gKZ = [
610+ cuda .to_device (np .float32 (v ), stream ) for v in (d , sigma , rho , irho , rhoM , KZ )]
611+ gu1 , gu3 = [
612+ cuda .to_device (np .complex64 (v ), stream ) for v in (u1 , u3 )]
613+ gR = cuda .device_array (R .shape , dtype = np .float32 )
614+ threadsperblock = 32
615+ blockspergrid = (KZ .size + (threadsperblock - 1 )) // threadsperblock
616+ kernel [blockspergrid , threadsperblock , stream ](
617+ gd , gsigma , grho , girho , grhoM , gu1 , gu3 , gKZ , gR )
618+ stream .synchronize ()
619+ R [:] = gR .copy_to_host (stream = stream )
620+
621+ if USE_CUDA :
622+ stream = cuda .stream ()
623+ CR4XA_SIG_F = CR4XA_SIG .replace ('f8' ,'f4' ).replace ('c16' ,'c8' )
624+ MAGAMP_SIG_F = MAGAMP_SIG .replace ('f8' ,'f4' ).replace ('c16' ,'c8' )
625+ Cr4xa = cuda .jit (CR4XA_SIG_F , device = True , fastmath = True , inline = True )(Cr4xa )
626+ magnetic_amplitude_cuda = cuda .jit (MAGAMP_SIG_F , fastmath = True )(magnetic_amplitude_cuda )
627+ magnetic_amplitude_cuda_B = cuda .jit (MAGAMP_SIG_F , fastmath = True )(magnetic_amplitude_cuda_B )
628+ magnetic_amplitude_py = magnetic_amplitude_driver
629+ elif USE_NUMBA :
630+ Cr4xa = njit (CR4XA_SIG , parallel = False , cache = True )(Cr4xa )
631+ magnetic_amplitude_py = njit (MAGAMP_SIG , parallel = True , cache = True )(magnetic_amplitude_py )
632+ else :
633+ ... # fall back to pure python (non-vectorized!) edition
634+
584635#try:
585636# from .magnetic_amplitude import mag_amplitude as magnetic_amplitude_py
586637# print("loaded from compiled module")
0 commit comments