Skip to content

Commit 261e7ea

Browse files
committed
cuda implementation of magnetic reflectivity
1 parent d325c99 commit 261e7ea

File tree

1 file changed

+95
-44
lines changed

1 file changed

+95
-44
lines changed

refl1d/reflectivity.py

Lines changed: 95 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,31 @@
2525
# delay load so doc build doesn't require compilation
2626
# from . import reflmodule
2727

28+
try:
29+
# raise ImportError() # uncomment to force numba off
30+
from numba import njit, prange
31+
USE_NUMBA = True
32+
except ImportError:
33+
USE_NUMBA = False
34+
# if no numba then njit does nothing
35+
36+
def njit(*args, **kw):
37+
# Check for bare @njit, in which case we just return the function.
38+
if len(args) == 1 and callable(args[0]) and not kw:
39+
return args[0]
40+
# Otherwise we have @njit(...), so return the identity decorator.
41+
return lambda fn: fn
42+
43+
try:
44+
from numba import cuda
45+
USE_CUDA = bool(cuda.list_devices()) # True if cuda device is available
46+
except ImportError:
47+
USE_CUDA = False
48+
49+
if not USE_NUMBA:
50+
import warnings
51+
warnings.warn("Numba missing... magnetism and uniform convolution will run more slowly")
52+
2853
BASE_GUIDE_ANGLE = 270.0
2954

3055

@@ -298,30 +323,17 @@ def calculate_u1_u3_py(H, rhoM, thetaM, Aguide):
298323
return sld_b, u1, u3
299324

300325

301-
try:
302-
# raise ImportError() # uncomment to force numba off
303-
from numba import njit, prange
304-
USE_NUMBA = True
305-
except ImportError:
306-
USE_NUMBA = False
307-
# if no numba then njit does nothing
308-
309-
def njit(*args, **kw):
310-
# Check for bare @njit, in which case we just return the function.
311-
if len(args) == 1 and callable(args[0]) and not kw:
312-
return args[0]
313-
# Otherwise we have @njit(...), so return the identity decorator.
314-
return lambda fn: fn
315-
316326
MINIMAL_RHO_M = 1e-2 # in units of 1e-6/A^2
317327
EPS = np.finfo(float).eps
318328
B2SLD = 2.31604654 # Scattering factor for B field 1e-6
319329

330+
import math
331+
import cmath
320332
CR4XA_SIG = 'void(i8, f8[:], f8[:], f8, f8[:], f8[:], f8[:], c16[:], c16[:], f8, c16[:])'
321-
@njit(CR4XA_SIG, parallel=False, cache=True)
333+
#@njit(CR4XA_SIG, parallel=False, cache=True)
322334
def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
323335
EPS = 1e-10
324-
PI4 = np.pi * 4.0e-6
336+
PI4 = math.pi * 4.0e-6
325337

326338
if (KZ <= -1.e-10):
327339
L = N-1
@@ -365,17 +377,14 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
365377
# minus to plus (lower to higher potential energy) and the observed R-+ will
366378
# actually be zero at large distances from the interface.
367379
#
368-
# In the backing, the -S1 and -S3 waves are explicitly set to be zero amplitude
380+
# In the backing, the -S1 and -S3 waves are math.explicitly set to be zero amplitude
369381
# by the boundary conditions (neutrons only incident in the fronting medium - no
370382
# source of neutrons below).
371383
#
372-
S1L = -sqrt(complex(PI4*(RHO[L]+RHOM[L]) - E0, -PI4*(fabs(IRHO[L])+EPS)))
373-
S3L = -sqrt(complex(PI4*(RHO[L]-RHOM[L]) -
374-
E0, -PI4*(fabs(IRHO[L])+EPS)))
375-
S1LP = -sqrt(complex(PI4*(RHO[LP]+RHOM[LP]) -
376-
E0, -PI4*(fabs(IRHO[LP])+EPS)))
377-
S3LP = -sqrt(complex(PI4*(RHO[LP]-RHOM[LP]) -
378-
E0, -PI4*(fabs(IRHO[LP])+EPS)))
384+
S1L = -cmath.sqrt(complex(PI4*(RHO[L]+RHOM[L]) - E0, -PI4*(math.fabs(IRHO[L])+EPS)))
385+
S3L = -cmath.sqrt(complex(PI4*(RHO[L]-RHOM[L]) - E0, -PI4*(math.fabs(IRHO[L])+EPS)))
386+
S1LP = -cmath.sqrt(complex(PI4*(RHO[LP]+RHOM[LP]) - E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
387+
S3LP = -cmath.sqrt(complex(PI4*(RHO[LP]-RHOM[LP]) - E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
379388
SIGMAL = SIGMA[L+SIGMA_OFFSET]
380389

381390
if (abs(U1[L]) <= 1.0):
@@ -410,23 +419,23 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
410419
FS3S3 = S3L/S3LP
411420

412421
B11 = DELTA * 1.0 * (1.0 + FS1S1)
413-
B12 = DELTA * 1.0 * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
422+
B12 = DELTA * 1.0 * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
414423
B13 = DELTA * -GLP * (1.0 + FS3S1)
415-
B14 = DELTA * -GLP * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
424+
B14 = DELTA * -GLP * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
416425

417-
B21 = DELTA * 1.0 * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
426+
B21 = DELTA * 1.0 * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
418427
B22 = DELTA * 1.0 * (1.0 + FS1S1)
419-
B23 = DELTA * -GLP * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
428+
B23 = DELTA * -GLP * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
420429
B24 = DELTA * -GLP * (1.0 + FS3S1)
421430

422431
B31 = DELTA * -BLP * (1.0 + FS1S3)
423-
B32 = DELTA * -BLP * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
432+
B32 = DELTA * -BLP * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
424433
B33 = DELTA * 1.0 * (1.0 + FS3S3)
425-
B34 = DELTA * 1.0 * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
434+
B34 = DELTA * 1.0 * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
426435

427-
B41 = DELTA * -BLP * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
436+
B41 = DELTA * -BLP * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
428437
B42 = DELTA * -BLP * (1.0 + FS1S3)
429-
B43 = DELTA * 1.0 * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
438+
B43 = DELTA * 1.0 * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
430439
B44 = DELTA * 1.0 * (1.0 + FS3S3)
431440

432441
Z += D[LP]
@@ -440,8 +449,8 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
440449
S3L = S3LP #
441450
GL = GLP
442451
BL = BLP
443-
S1LP = -sqrt(complex(PI4*(RHO[LP]+RHOM[LP])-E0, -PI4*(fabs(IRHO[LP])+EPS)))
444-
S3LP = -sqrt(complex(PI4*(RHO[LP]-RHOM[LP])-E0, -PI4*(fabs(IRHO[LP])+EPS)))
452+
S1LP = -cmath.sqrt(complex(PI4*(RHO[LP]+RHOM[LP])-E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
453+
S3LP = -cmath.sqrt(complex(PI4*(RHO[LP]-RHOM[LP])-E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
445454
SIGMAL = SIGMA[L+SIGMA_OFFSET]
446455

447456
if (abs(U1[LP]) <= 1.0):
@@ -462,13 +471,13 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
462471
DGB = (1.0 - GL*BLP) * DELTA
463472
DGG = (GL - GLP) * DELTA
464473

465-
ES1L = exp(S1L*Z)
474+
ES1L = cmath.exp(S1L*Z)
466475
ENS1L = 1.0 / ES1L
467-
ES1LP = exp(S1LP*Z)
476+
ES1LP = cmath.exp(S1LP*Z)
468477
ENS1LP = 1.0 / ES1LP
469-
ES3L = exp(S3L*Z)
478+
ES3L = cmath.exp(S3L*Z)
470479
ENS3L = 1.0 / ES3L
471-
ES3LP = exp(S3LP*Z)
480+
ES3LP = cmath.exp(S3LP*Z)
472481
ENS3LP = 1.0 / ES3LP
473482

474483
FS1S1 = S1L/S1LP
@@ -479,26 +488,26 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
479488
A11 = A22 = DBG * (1.0 + FS1S1)
480489
A11 *= ES1L * ENS1LP
481490
A22 *= ENS1L * ES1LP
482-
A12 = A21 = DBG * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
491+
A12 = A21 = DBG * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
483492
A12 *= ENS1L * ENS1LP
484493
A21 *= ES1L * ES1LP
485494
A13 = A24 = DGG * (1.0 + FS3S1)
486495
A13 *= ES3L * ENS1LP
487496
A24 *= ENS3L * ES1LP
488-
A14 = A23 = DGG * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
497+
A14 = A23 = DGG * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
489498
A14 *= ENS3L * ENS1LP
490499
A23 *= ES3L * ES1LP
491500

492501
A31 = A42 = DBB * (1.0 + FS1S3)
493502
A31 *= ES1L * ENS3LP
494503
A42 *= ENS1L * ES3LP
495-
A32 = A41 = DBB * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
504+
A32 = A41 = DBB * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
496505
A32 *= ENS1L * ENS3LP
497506
A41 *= ES1L * ES3LP
498507
A33 = A44 = DGB * (1.0 + FS3S3)
499508
A33 *= ES3L * ENS3LP
500509
A44 *= ENS3L * ES3LP
501-
A34 = A43 = DGB * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
510+
A34 = A43 = DGB * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
502511
A34 *= ENS3L * ENS3LP
503512
A43 *= ES3L * ES3LP
504513

@@ -550,14 +559,15 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
550559
# Calculate reflectivity coefficients specified by POLSTAT
551560
# IP = +1 fills in ++, +-, -+, --; IP = -1 only fills in -+, --.
552561
if IP > 0:
562+
# Only return -+ and -- if IP is -1, otherwise return all
553563
Y[0] = (B24*B41 - B21*B44)/DETW # ++
554564
Y[1] = (B21*B42 - B41*B22)/DETW # +-
555565
Y[2] = (B24*B43 - B23*B44)/DETW # -+
556566
Y[3] = (B23*B42 - B43*B22)/DETW # --
557567

558568
#@cc.export('mag_amplitude')
559569
MAGAMP_SIG = 'void(f8[:], f8[:], f8[:], f8[:], f8[:], c16[:], c16[:], f8[:], c16[:,:])'
560-
@njit(MAGAMP_SIG, parallel=True, cache=True)
570+
#@njit(MAGAMP_SIG, parallel=True, cache=True)
561571
def magnetic_amplitude_py(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
562572
"""
563573
python version of calculation
@@ -581,6 +591,47 @@ def magnetic_amplitude_py(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
581591
for i in prange(points):
582592
Cr4xa(layers, d, sigma, -1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i])
583593

594+
def magnetic_amplitude_cuda(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
595+
i = cuda.grid(1)
596+
if i < KZ.size:
597+
Cr4xa(d.size, d, sigma, 1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i])
598+
def magnetic_amplitude_cuda_B(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
599+
i = cuda.grid(1)
600+
if i < KZ.size:
601+
Cr4xa(d.size, d, sigma, 1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i])
602+
Cr4xa(d.size, d, sigma, -1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i])
603+
604+
def magnetic_amplitude_driver(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
605+
if abs(rhoM[0]) <= MINIMAL_RHO_M and abs(rhoM[-1]) <= MINIMAL_RHO_M:
606+
kernel = magnetic_amplitude_cuda
607+
else:
608+
kernel = magnetic_amplitude_cuda_B
609+
gd, gsigma, grho, girho, grhoM, gKZ = [
610+
cuda.to_device(np.float32(v), stream) for v in (d, sigma, rho, irho, rhoM, KZ)]
611+
gu1, gu3 = [
612+
cuda.to_device(np.complex64(v), stream) for v in (u1, u3)]
613+
gR = cuda.device_array(R.shape, dtype=np.float32)
614+
threadsperblock = 32
615+
blockspergrid = (KZ.size + (threadsperblock - 1)) // threadsperblock
616+
kernel[blockspergrid, threadsperblock, stream](
617+
gd, gsigma, grho, girho, grhoM, gu1, gu3, gKZ, gR)
618+
stream.synchronize()
619+
R[:] = gR.copy_to_host(stream=stream)
620+
621+
if USE_CUDA:
622+
stream = cuda.stream()
623+
CR4XA_SIG_F = CR4XA_SIG.replace('f8','f4').replace('c16','c8')
624+
MAGAMP_SIG_F = MAGAMP_SIG.replace('f8','f4').replace('c16','c8')
625+
Cr4xa = cuda.jit(CR4XA_SIG_F, device=True, fastmath=True, inline=True)(Cr4xa)
626+
magnetic_amplitude_cuda = cuda.jit(MAGAMP_SIG_F, fastmath=True)(magnetic_amplitude_cuda)
627+
magnetic_amplitude_cuda_B = cuda.jit(MAGAMP_SIG_F, fastmath=True)(magnetic_amplitude_cuda_B)
628+
magnetic_amplitude_py = magnetic_amplitude_driver
629+
elif USE_NUMBA:
630+
Cr4xa = njit(CR4XA_SIG, parallel=False, cache=True)(Cr4xa)
631+
magnetic_amplitude_py = njit(MAGAMP_SIG, parallel=True, cache=True)(magnetic_amplitude_py)
632+
else:
633+
... # fall back to pure python (non-vectorized!) edition
634+
584635
#try:
585636
# from .magnetic_amplitude import mag_amplitude as magnetic_amplitude_py
586637
# print("loaded from compiled module")

0 commit comments

Comments
 (0)