Skip to content

Commit

Permalink
cuda implementation of magnetic reflectivity
Browse files Browse the repository at this point in the history
  • Loading branch information
pkienzle committed Jul 30, 2021
1 parent d325c99 commit 261e7ea
Showing 1 changed file with 95 additions and 44 deletions.
139 changes: 95 additions & 44 deletions refl1d/reflectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,31 @@
# delay load so doc build doesn't require compilation
# from . import reflmodule

try:
# raise ImportError() # uncomment to force numba off
from numba import njit, prange
USE_NUMBA = True
except ImportError:
USE_NUMBA = False
# if no numba then njit does nothing

def njit(*args, **kw):
# Check for bare @njit, in which case we just return the function.
if len(args) == 1 and callable(args[0]) and not kw:
return args[0]
# Otherwise we have @njit(...), so return the identity decorator.
return lambda fn: fn

try:
from numba import cuda
USE_CUDA = bool(cuda.list_devices()) # True if cuda device is available
except ImportError:
USE_CUDA = False

if not USE_NUMBA:
import warnings
warnings.warn("Numba missing... magnetism and uniform convolution will run more slowly")

BASE_GUIDE_ANGLE = 270.0


Expand Down Expand Up @@ -298,30 +323,17 @@ def calculate_u1_u3_py(H, rhoM, thetaM, Aguide):
return sld_b, u1, u3


try:
# raise ImportError() # uncomment to force numba off
from numba import njit, prange
USE_NUMBA = True
except ImportError:
USE_NUMBA = False
# if no numba then njit does nothing

def njit(*args, **kw):
# Check for bare @njit, in which case we just return the function.
if len(args) == 1 and callable(args[0]) and not kw:
return args[0]
# Otherwise we have @njit(...), so return the identity decorator.
return lambda fn: fn

MINIMAL_RHO_M = 1e-2 # in units of 1e-6/A^2
EPS = np.finfo(float).eps
B2SLD = 2.31604654 # Scattering factor for B field 1e-6

import math
import cmath
CR4XA_SIG = 'void(i8, f8[:], f8[:], f8, f8[:], f8[:], f8[:], c16[:], c16[:], f8, c16[:])'
@njit(CR4XA_SIG, parallel=False, cache=True)
#@njit(CR4XA_SIG, parallel=False, cache=True)
def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
EPS = 1e-10
PI4 = np.pi * 4.0e-6
PI4 = math.pi * 4.0e-6

if (KZ <= -1.e-10):
L = N-1
Expand Down Expand Up @@ -365,17 +377,14 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
# minus to plus (lower to higher potential energy) and the observed R-+ will
# actually be zero at large distances from the interface.
#
# In the backing, the -S1 and -S3 waves are explicitly set to be zero amplitude
# In the backing, the -S1 and -S3 waves are math.explicitly set to be zero amplitude
# by the boundary conditions (neutrons only incident in the fronting medium - no
# source of neutrons below).
#
S1L = -sqrt(complex(PI4*(RHO[L]+RHOM[L]) - E0, -PI4*(fabs(IRHO[L])+EPS)))
S3L = -sqrt(complex(PI4*(RHO[L]-RHOM[L]) -
E0, -PI4*(fabs(IRHO[L])+EPS)))
S1LP = -sqrt(complex(PI4*(RHO[LP]+RHOM[LP]) -
E0, -PI4*(fabs(IRHO[LP])+EPS)))
S3LP = -sqrt(complex(PI4*(RHO[LP]-RHOM[LP]) -
E0, -PI4*(fabs(IRHO[LP])+EPS)))
S1L = -cmath.sqrt(complex(PI4*(RHO[L]+RHOM[L]) - E0, -PI4*(math.fabs(IRHO[L])+EPS)))
S3L = -cmath.sqrt(complex(PI4*(RHO[L]-RHOM[L]) - E0, -PI4*(math.fabs(IRHO[L])+EPS)))
S1LP = -cmath.sqrt(complex(PI4*(RHO[LP]+RHOM[LP]) - E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
S3LP = -cmath.sqrt(complex(PI4*(RHO[LP]-RHOM[LP]) - E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
SIGMAL = SIGMA[L+SIGMA_OFFSET]

if (abs(U1[L]) <= 1.0):
Expand Down Expand Up @@ -410,23 +419,23 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
FS3S3 = S3L/S3LP

B11 = DELTA * 1.0 * (1.0 + FS1S1)
B12 = DELTA * 1.0 * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
B12 = DELTA * 1.0 * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
B13 = DELTA * -GLP * (1.0 + FS3S1)
B14 = DELTA * -GLP * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
B14 = DELTA * -GLP * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL)

B21 = DELTA * 1.0 * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
B21 = DELTA * 1.0 * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
B22 = DELTA * 1.0 * (1.0 + FS1S1)
B23 = DELTA * -GLP * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
B23 = DELTA * -GLP * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
B24 = DELTA * -GLP * (1.0 + FS3S1)

B31 = DELTA * -BLP * (1.0 + FS1S3)
B32 = DELTA * -BLP * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
B32 = DELTA * -BLP * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
B33 = DELTA * 1.0 * (1.0 + FS3S3)
B34 = DELTA * 1.0 * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
B34 = DELTA * 1.0 * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL)

B41 = DELTA * -BLP * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
B41 = DELTA * -BLP * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
B42 = DELTA * -BLP * (1.0 + FS1S3)
B43 = DELTA * 1.0 * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
B43 = DELTA * 1.0 * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
B44 = DELTA * 1.0 * (1.0 + FS3S3)

Z += D[LP]
Expand All @@ -440,8 +449,8 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
S3L = S3LP #
GL = GLP
BL = BLP
S1LP = -sqrt(complex(PI4*(RHO[LP]+RHOM[LP])-E0, -PI4*(fabs(IRHO[LP])+EPS)))
S3LP = -sqrt(complex(PI4*(RHO[LP]-RHOM[LP])-E0, -PI4*(fabs(IRHO[LP])+EPS)))
S1LP = -cmath.sqrt(complex(PI4*(RHO[LP]+RHOM[LP])-E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
S3LP = -cmath.sqrt(complex(PI4*(RHO[LP]-RHOM[LP])-E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
SIGMAL = SIGMA[L+SIGMA_OFFSET]

if (abs(U1[LP]) <= 1.0):
Expand All @@ -462,13 +471,13 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
DGB = (1.0 - GL*BLP) * DELTA
DGG = (GL - GLP) * DELTA

ES1L = exp(S1L*Z)
ES1L = cmath.exp(S1L*Z)
ENS1L = 1.0 / ES1L
ES1LP = exp(S1LP*Z)
ES1LP = cmath.exp(S1LP*Z)
ENS1LP = 1.0 / ES1LP
ES3L = exp(S3L*Z)
ES3L = cmath.exp(S3L*Z)
ENS3L = 1.0 / ES3L
ES3LP = exp(S3LP*Z)
ES3LP = cmath.exp(S3LP*Z)
ENS3LP = 1.0 / ES3LP

FS1S1 = S1L/S1LP
Expand All @@ -479,26 +488,26 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
A11 = A22 = DBG * (1.0 + FS1S1)
A11 *= ES1L * ENS1LP
A22 *= ENS1L * ES1LP
A12 = A21 = DBG * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
A12 = A21 = DBG * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
A12 *= ENS1L * ENS1LP
A21 *= ES1L * ES1LP
A13 = A24 = DGG * (1.0 + FS3S1)
A13 *= ES3L * ENS1LP
A24 *= ENS3L * ES1LP
A14 = A23 = DGG * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
A14 = A23 = DGG * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
A14 *= ENS3L * ENS1LP
A23 *= ES3L * ES1LP

A31 = A42 = DBB * (1.0 + FS1S3)
A31 *= ES1L * ENS3LP
A42 *= ENS1L * ES3LP
A32 = A41 = DBB * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
A32 = A41 = DBB * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
A32 *= ENS1L * ENS3LP
A41 *= ES1L * ES3LP
A33 = A44 = DGB * (1.0 + FS3S3)
A33 *= ES3L * ENS3LP
A44 *= ENS3L * ES3LP
A34 = A43 = DGB * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
A34 = A43 = DGB * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
A34 *= ENS3L * ENS3LP
A43 *= ES3L * ES3LP

Expand Down Expand Up @@ -550,14 +559,15 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
# Calculate reflectivity coefficients specified by POLSTAT
# IP = +1 fills in ++, +-, -+, --; IP = -1 only fills in -+, --.
if IP > 0:
# Only return -+ and -- if IP is -1, otherwise return all
Y[0] = (B24*B41 - B21*B44)/DETW # ++
Y[1] = (B21*B42 - B41*B22)/DETW # +-
Y[2] = (B24*B43 - B23*B44)/DETW # -+
Y[3] = (B23*B42 - B43*B22)/DETW # --

#@cc.export('mag_amplitude')
MAGAMP_SIG = 'void(f8[:], f8[:], f8[:], f8[:], f8[:], c16[:], c16[:], f8[:], c16[:,:])'
@njit(MAGAMP_SIG, parallel=True, cache=True)
#@njit(MAGAMP_SIG, parallel=True, cache=True)
def magnetic_amplitude_py(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
"""
python version of calculation
Expand All @@ -581,6 +591,47 @@ def magnetic_amplitude_py(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
for i in prange(points):
Cr4xa(layers, d, sigma, -1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i])

def magnetic_amplitude_cuda(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
i = cuda.grid(1)
if i < KZ.size:
Cr4xa(d.size, d, sigma, 1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i])
def magnetic_amplitude_cuda_B(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
i = cuda.grid(1)
if i < KZ.size:
Cr4xa(d.size, d, sigma, 1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i])
Cr4xa(d.size, d, sigma, -1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i])

def magnetic_amplitude_driver(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
if abs(rhoM[0]) <= MINIMAL_RHO_M and abs(rhoM[-1]) <= MINIMAL_RHO_M:
kernel = magnetic_amplitude_cuda
else:
kernel = magnetic_amplitude_cuda_B
gd, gsigma, grho, girho, grhoM, gKZ = [
cuda.to_device(np.float32(v), stream) for v in (d, sigma, rho, irho, rhoM, KZ)]
gu1, gu3 = [
cuda.to_device(np.complex64(v), stream) for v in (u1, u3)]
gR = cuda.device_array(R.shape, dtype=np.float32)
threadsperblock = 32
blockspergrid = (KZ.size + (threadsperblock - 1)) // threadsperblock
kernel[blockspergrid, threadsperblock, stream](
gd, gsigma, grho, girho, grhoM, gu1, gu3, gKZ, gR)
stream.synchronize()
R[:] = gR.copy_to_host(stream=stream)

if USE_CUDA:
stream = cuda.stream()
CR4XA_SIG_F = CR4XA_SIG.replace('f8','f4').replace('c16','c8')
MAGAMP_SIG_F = MAGAMP_SIG.replace('f8','f4').replace('c16','c8')
Cr4xa = cuda.jit(CR4XA_SIG_F, device=True, fastmath=True, inline=True)(Cr4xa)
magnetic_amplitude_cuda = cuda.jit(MAGAMP_SIG_F, fastmath=True)(magnetic_amplitude_cuda)
magnetic_amplitude_cuda_B = cuda.jit(MAGAMP_SIG_F, fastmath=True)(magnetic_amplitude_cuda_B)
magnetic_amplitude_py = magnetic_amplitude_driver
elif USE_NUMBA:
Cr4xa = njit(CR4XA_SIG, parallel=False, cache=True)(Cr4xa)
magnetic_amplitude_py = njit(MAGAMP_SIG, parallel=True, cache=True)(magnetic_amplitude_py)
else:
... # fall back to pure python (non-vectorized!) edition

#try:
# from .magnetic_amplitude import mag_amplitude as magnetic_amplitude_py
# print("loaded from compiled module")
Expand Down

1 comment on commit 261e7ea

@bmaranville
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice - so much to play with and so little time.

Please sign in to comment.