diff --git a/refl1d/reflectivity.py b/refl1d/reflectivity.py index 27facf14..4dc62a85 100644 --- a/refl1d/reflectivity.py +++ b/refl1d/reflectivity.py @@ -25,6 +25,31 @@ # delay load so doc build doesn't require compilation # from . import reflmodule +try: + # raise ImportError() # uncomment to force numba off + from numba import njit, prange + USE_NUMBA = True +except ImportError: + USE_NUMBA = False + # if no numba then njit does nothing + + def njit(*args, **kw): + # Check for bare @njit, in which case we just return the function. + if len(args) == 1 and callable(args[0]) and not kw: + return args[0] + # Otherwise we have @njit(...), so return the identity decorator. + return lambda fn: fn + +try: + from numba import cuda + USE_CUDA = bool(cuda.list_devices()) # True if cuda device is available +except ImportError: + USE_CUDA = False + +if not USE_NUMBA: + import warnings + warnings.warn("Numba missing... magnetism and uniform convolution will run more slowly") + BASE_GUIDE_ANGLE = 270.0 @@ -298,30 +323,17 @@ def calculate_u1_u3_py(H, rhoM, thetaM, Aguide): return sld_b, u1, u3 -try: - # raise ImportError() # uncomment to force numba off - from numba import njit, prange - USE_NUMBA = True -except ImportError: - USE_NUMBA = False - # if no numba then njit does nothing - - def njit(*args, **kw): - # Check for bare @njit, in which case we just return the function. - if len(args) == 1 and callable(args[0]) and not kw: - return args[0] - # Otherwise we have @njit(...), so return the identity decorator. - return lambda fn: fn - MINIMAL_RHO_M = 1e-2 # in units of 1e-6/A^2 EPS = np.finfo(float).eps B2SLD = 2.31604654 # Scattering factor for B field 1e-6 +import math +import cmath CR4XA_SIG = 'void(i8, f8[:], f8[:], f8, f8[:], f8[:], f8[:], c16[:], c16[:], f8, c16[:])' -@njit(CR4XA_SIG, parallel=False, cache=True) +#@njit(CR4XA_SIG, parallel=False, cache=True) def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y): EPS = 1e-10 - PI4 = np.pi * 4.0e-6 + PI4 = math.pi * 4.0e-6 if (KZ <= -1.e-10): L = N-1 @@ -365,17 +377,14 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y): # minus to plus (lower to higher potential energy) and the observed R-+ will # actually be zero at large distances from the interface. # - # In the backing, the -S1 and -S3 waves are explicitly set to be zero amplitude + # In the backing, the -S1 and -S3 waves are math.explicitly set to be zero amplitude # by the boundary conditions (neutrons only incident in the fronting medium - no # source of neutrons below). # - S1L = -sqrt(complex(PI4*(RHO[L]+RHOM[L]) - E0, -PI4*(fabs(IRHO[L])+EPS))) - S3L = -sqrt(complex(PI4*(RHO[L]-RHOM[L]) - - E0, -PI4*(fabs(IRHO[L])+EPS))) - S1LP = -sqrt(complex(PI4*(RHO[LP]+RHOM[LP]) - - E0, -PI4*(fabs(IRHO[LP])+EPS))) - S3LP = -sqrt(complex(PI4*(RHO[LP]-RHOM[LP]) - - E0, -PI4*(fabs(IRHO[LP])+EPS))) + S1L = -cmath.sqrt(complex(PI4*(RHO[L]+RHOM[L]) - E0, -PI4*(math.fabs(IRHO[L])+EPS))) + S3L = -cmath.sqrt(complex(PI4*(RHO[L]-RHOM[L]) - E0, -PI4*(math.fabs(IRHO[L])+EPS))) + S1LP = -cmath.sqrt(complex(PI4*(RHO[LP]+RHOM[LP]) - E0, -PI4*(math.fabs(IRHO[LP])+EPS))) + S3LP = -cmath.sqrt(complex(PI4*(RHO[LP]-RHOM[LP]) - E0, -PI4*(math.fabs(IRHO[LP])+EPS))) SIGMAL = SIGMA[L+SIGMA_OFFSET] if (abs(U1[L]) <= 1.0): @@ -410,23 +419,23 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y): FS3S3 = S3L/S3LP B11 = DELTA * 1.0 * (1.0 + FS1S1) - B12 = DELTA * 1.0 * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL) + B12 = DELTA * 1.0 * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL) B13 = DELTA * -GLP * (1.0 + FS3S1) - B14 = DELTA * -GLP * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL) + B14 = DELTA * -GLP * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL) - B21 = DELTA * 1.0 * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL) + B21 = DELTA * 1.0 * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL) B22 = DELTA * 1.0 * (1.0 + FS1S1) - B23 = DELTA * -GLP * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL) + B23 = DELTA * -GLP * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL) B24 = DELTA * -GLP * (1.0 + FS3S1) B31 = DELTA * -BLP * (1.0 + FS1S3) - B32 = DELTA * -BLP * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL) + B32 = DELTA * -BLP * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL) B33 = DELTA * 1.0 * (1.0 + FS3S3) - B34 = DELTA * 1.0 * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL) + B34 = DELTA * 1.0 * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL) - B41 = DELTA * -BLP * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL) + B41 = DELTA * -BLP * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL) B42 = DELTA * -BLP * (1.0 + FS1S3) - B43 = DELTA * 1.0 * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL) + B43 = DELTA * 1.0 * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL) B44 = DELTA * 1.0 * (1.0 + FS3S3) Z += D[LP] @@ -440,8 +449,8 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y): S3L = S3LP # GL = GLP BL = BLP - S1LP = -sqrt(complex(PI4*(RHO[LP]+RHOM[LP])-E0, -PI4*(fabs(IRHO[LP])+EPS))) - S3LP = -sqrt(complex(PI4*(RHO[LP]-RHOM[LP])-E0, -PI4*(fabs(IRHO[LP])+EPS))) + S1LP = -cmath.sqrt(complex(PI4*(RHO[LP]+RHOM[LP])-E0, -PI4*(math.fabs(IRHO[LP])+EPS))) + S3LP = -cmath.sqrt(complex(PI4*(RHO[LP]-RHOM[LP])-E0, -PI4*(math.fabs(IRHO[LP])+EPS))) SIGMAL = SIGMA[L+SIGMA_OFFSET] if (abs(U1[LP]) <= 1.0): @@ -462,13 +471,13 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y): DGB = (1.0 - GL*BLP) * DELTA DGG = (GL - GLP) * DELTA - ES1L = exp(S1L*Z) + ES1L = cmath.exp(S1L*Z) ENS1L = 1.0 / ES1L - ES1LP = exp(S1LP*Z) + ES1LP = cmath.exp(S1LP*Z) ENS1LP = 1.0 / ES1LP - ES3L = exp(S3L*Z) + ES3L = cmath.exp(S3L*Z) ENS3L = 1.0 / ES3L - ES3LP = exp(S3LP*Z) + ES3LP = cmath.exp(S3LP*Z) ENS3LP = 1.0 / ES3LP FS1S1 = S1L/S1LP @@ -479,26 +488,26 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y): A11 = A22 = DBG * (1.0 + FS1S1) A11 *= ES1L * ENS1LP A22 *= ENS1L * ES1LP - A12 = A21 = DBG * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL) + A12 = A21 = DBG * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL) A12 *= ENS1L * ENS1LP A21 *= ES1L * ES1LP A13 = A24 = DGG * (1.0 + FS3S1) A13 *= ES3L * ENS1LP A24 *= ENS3L * ES1LP - A14 = A23 = DGG * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL) + A14 = A23 = DGG * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL) A14 *= ENS3L * ENS1LP A23 *= ES3L * ES1LP A31 = A42 = DBB * (1.0 + FS1S3) A31 *= ES1L * ENS3LP A42 *= ENS1L * ES3LP - A32 = A41 = DBB * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL) + A32 = A41 = DBB * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL) A32 *= ENS1L * ENS3LP A41 *= ES1L * ES3LP A33 = A44 = DGB * (1.0 + FS3S3) A33 *= ES3L * ENS3LP A44 *= ENS3L * ES3LP - A34 = A43 = DGB * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL) + A34 = A43 = DGB * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL) A34 *= ENS3L * ENS3LP A43 *= ES3L * ES3LP @@ -550,6 +559,7 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y): # Calculate reflectivity coefficients specified by POLSTAT # IP = +1 fills in ++, +-, -+, --; IP = -1 only fills in -+, --. if IP > 0: + # Only return -+ and -- if IP is -1, otherwise return all Y[0] = (B24*B41 - B21*B44)/DETW # ++ Y[1] = (B21*B42 - B41*B22)/DETW # +- Y[2] = (B24*B43 - B23*B44)/DETW # -+ @@ -557,7 +567,7 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y): #@cc.export('mag_amplitude') MAGAMP_SIG = 'void(f8[:], f8[:], f8[:], f8[:], f8[:], c16[:], c16[:], f8[:], c16[:,:])' -@njit(MAGAMP_SIG, parallel=True, cache=True) +#@njit(MAGAMP_SIG, parallel=True, cache=True) def magnetic_amplitude_py(d, sigma, rho, irho, rhoM, u1, u3, KZ, R): """ python version of calculation @@ -581,6 +591,47 @@ def magnetic_amplitude_py(d, sigma, rho, irho, rhoM, u1, u3, KZ, R): for i in prange(points): Cr4xa(layers, d, sigma, -1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i]) +def magnetic_amplitude_cuda(d, sigma, rho, irho, rhoM, u1, u3, KZ, R): + i = cuda.grid(1) + if i < KZ.size: + Cr4xa(d.size, d, sigma, 1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i]) +def magnetic_amplitude_cuda_B(d, sigma, rho, irho, rhoM, u1, u3, KZ, R): + i = cuda.grid(1) + if i < KZ.size: + Cr4xa(d.size, d, sigma, 1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i]) + Cr4xa(d.size, d, sigma, -1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i]) + +def magnetic_amplitude_driver(d, sigma, rho, irho, rhoM, u1, u3, KZ, R): + if abs(rhoM[0]) <= MINIMAL_RHO_M and abs(rhoM[-1]) <= MINIMAL_RHO_M: + kernel = magnetic_amplitude_cuda + else: + kernel = magnetic_amplitude_cuda_B + gd, gsigma, grho, girho, grhoM, gKZ = [ + cuda.to_device(np.float32(v), stream) for v in (d, sigma, rho, irho, rhoM, KZ)] + gu1, gu3 = [ + cuda.to_device(np.complex64(v), stream) for v in (u1, u3)] + gR = cuda.device_array(R.shape, dtype=np.float32) + threadsperblock = 32 + blockspergrid = (KZ.size + (threadsperblock - 1)) // threadsperblock + kernel[blockspergrid, threadsperblock, stream]( + gd, gsigma, grho, girho, grhoM, gu1, gu3, gKZ, gR) + stream.synchronize() + R[:] = gR.copy_to_host(stream=stream) + +if USE_CUDA: + stream = cuda.stream() + CR4XA_SIG_F = CR4XA_SIG.replace('f8','f4').replace('c16','c8') + MAGAMP_SIG_F = MAGAMP_SIG.replace('f8','f4').replace('c16','c8') + Cr4xa = cuda.jit(CR4XA_SIG_F, device=True, fastmath=True, inline=True)(Cr4xa) + magnetic_amplitude_cuda = cuda.jit(MAGAMP_SIG_F, fastmath=True)(magnetic_amplitude_cuda) + magnetic_amplitude_cuda_B = cuda.jit(MAGAMP_SIG_F, fastmath=True)(magnetic_amplitude_cuda_B) + magnetic_amplitude_py = magnetic_amplitude_driver +elif USE_NUMBA: + Cr4xa = njit(CR4XA_SIG, parallel=False, cache=True)(Cr4xa) + magnetic_amplitude_py = njit(MAGAMP_SIG, parallel=True, cache=True)(magnetic_amplitude_py) +else: + ... # fall back to pure python (non-vectorized!) edition + #try: # from .magnetic_amplitude import mag_amplitude as magnetic_amplitude_py # print("loaded from compiled module")