Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cuda implementation of magnetic reflectivity #135

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 97 additions & 44 deletions refl1d/reflectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,31 @@
# delay load so doc build doesn't require compilation
# from . import reflmodule

# Optional acceleration backends.  numba is not a hard requirement: when it
# is missing, the decorators and loop helpers below degrade to plain Python
# so the rest of the module can still be imported and run (slowly).
try:
    # raise ImportError()  # uncomment to force numba off
    from numba import njit, prange
    USE_NUMBA = True
except ImportError:
    USE_NUMBA = False

    # Without numba there is no parallel range; the builtin range keeps
    # functions written with prange(...) runnable in pure Python.
    prange = range

    def njit(*args, **kw):
        """Stand-in for ``numba.njit`` that leaves the function untouched.

        Supports both the bare ``@njit`` form and the parameterized
        ``@njit(...)`` form; any signature or keyword options are ignored.
        """
        # Check for bare @njit, in which case we just return the function.
        if len(args) == 1 and callable(args[0]) and not kw:
            return args[0]
        # Otherwise we have @njit(...), so return the identity decorator.
        return lambda fn: fn

try:
    from numba import cuda
    # is_available() reports False when no driver/device can be initialized,
    # instead of raising a CUDA support error during module import.
    USE_CUDA = cuda.is_available()
except ImportError:
    USE_CUDA = False

if not USE_NUMBA:
    import warnings
    warnings.warn("Numba missing... magnetism and uniform convolution will run more slowly")

BASE_GUIDE_ANGLE = 270.0


Expand Down Expand Up @@ -298,30 +323,17 @@ def calculate_u1_u3_py(H, rhoM, thetaM, Aguide):
return sld_b, u1, u3


try:
# raise ImportError() # uncomment to force numba off
from numba import njit, prange
USE_NUMBA = True
except ImportError:
USE_NUMBA = False
# if no numba then njit does nothing

def njit(*args, **kw):
# Check for bare @njit, in which case we just return the function.
if len(args) == 1 and callable(args[0]) and not kw:
return args[0]
# Otherwise we have @njit(...), so return the identity decorator.
return lambda fn: fn

MINIMAL_RHO_M = 1e-2 # in units of 1e-6/A^2
EPS = np.finfo(float).eps
B2SLD = 2.31604654 # Scattering factor for B field 1e-6

import math
import cmath
CR4XA_SIG = 'void(i8, f8[:], f8[:], f8, f8[:], f8[:], f8[:], c16[:], c16[:], f8, c16[:])'
@njit(CR4XA_SIG, parallel=False, cache=True)
#@njit(CR4XA_SIG, parallel=False, cache=True)
def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
EPS = 1e-10
PI4 = np.pi * 4.0e-6
PI4 = math.pi * 4.0e-6

if (KZ <= -1.e-10):
L = N-1
Expand Down Expand Up @@ -365,17 +377,14 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
# minus to plus (lower to higher potential energy) and the observed R-+ will
# actually be zero at large distances from the interface.
#
# In the backing, the -S1 and -S3 waves are explicitly set to be zero amplitude
# In the backing, the -S1 and -S3 waves are explicitly set to be zero amplitude
# by the boundary conditions (neutrons only incident in the fronting medium - no
# source of neutrons below).
#
S1L = -sqrt(complex(PI4*(RHO[L]+RHOM[L]) - E0, -PI4*(fabs(IRHO[L])+EPS)))
S3L = -sqrt(complex(PI4*(RHO[L]-RHOM[L]) -
E0, -PI4*(fabs(IRHO[L])+EPS)))
S1LP = -sqrt(complex(PI4*(RHO[LP]+RHOM[LP]) -
E0, -PI4*(fabs(IRHO[LP])+EPS)))
S3LP = -sqrt(complex(PI4*(RHO[LP]-RHOM[LP]) -
E0, -PI4*(fabs(IRHO[LP])+EPS)))
S1L = -cmath.sqrt(complex(PI4*(RHO[L]+RHOM[L]) - E0, -PI4*(math.fabs(IRHO[L])+EPS)))
S3L = -cmath.sqrt(complex(PI4*(RHO[L]-RHOM[L]) - E0, -PI4*(math.fabs(IRHO[L])+EPS)))
S1LP = -cmath.sqrt(complex(PI4*(RHO[LP]+RHOM[LP]) - E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
S3LP = -cmath.sqrt(complex(PI4*(RHO[LP]-RHOM[LP]) - E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
SIGMAL = SIGMA[L+SIGMA_OFFSET]

if (abs(U1[L]) <= 1.0):
Expand Down Expand Up @@ -410,23 +419,23 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
FS3S3 = S3L/S3LP

B11 = DELTA * 1.0 * (1.0 + FS1S1)
B12 = DELTA * 1.0 * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
B12 = DELTA * 1.0 * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
B13 = DELTA * -GLP * (1.0 + FS3S1)
B14 = DELTA * -GLP * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
B14 = DELTA * -GLP * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL)

B21 = DELTA * 1.0 * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
B21 = DELTA * 1.0 * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
B22 = DELTA * 1.0 * (1.0 + FS1S1)
B23 = DELTA * -GLP * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
B23 = DELTA * -GLP * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
B24 = DELTA * -GLP * (1.0 + FS3S1)

B31 = DELTA * -BLP * (1.0 + FS1S3)
B32 = DELTA * -BLP * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
B32 = DELTA * -BLP * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
B33 = DELTA * 1.0 * (1.0 + FS3S3)
B34 = DELTA * 1.0 * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
B34 = DELTA * 1.0 * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL)

B41 = DELTA * -BLP * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
B41 = DELTA * -BLP * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
B42 = DELTA * -BLP * (1.0 + FS1S3)
B43 = DELTA * 1.0 * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
B43 = DELTA * 1.0 * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
B44 = DELTA * 1.0 * (1.0 + FS3S3)

Z += D[LP]
Expand All @@ -440,8 +449,8 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
S3L = S3LP #
GL = GLP
BL = BLP
S1LP = -sqrt(complex(PI4*(RHO[LP]+RHOM[LP])-E0, -PI4*(fabs(IRHO[LP])+EPS)))
S3LP = -sqrt(complex(PI4*(RHO[LP]-RHOM[LP])-E0, -PI4*(fabs(IRHO[LP])+EPS)))
S1LP = -cmath.sqrt(complex(PI4*(RHO[LP]+RHOM[LP])-E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
S3LP = -cmath.sqrt(complex(PI4*(RHO[LP]-RHOM[LP])-E0, -PI4*(math.fabs(IRHO[LP])+EPS)))
SIGMAL = SIGMA[L+SIGMA_OFFSET]

if (abs(U1[LP]) <= 1.0):
Expand All @@ -462,13 +471,13 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
DGB = (1.0 - GL*BLP) * DELTA
DGG = (GL - GLP) * DELTA

ES1L = exp(S1L*Z)
ES1L = cmath.exp(S1L*Z)
ENS1L = 1.0 / ES1L
ES1LP = exp(S1LP*Z)
ES1LP = cmath.exp(S1LP*Z)
ENS1LP = 1.0 / ES1LP
ES3L = exp(S3L*Z)
ES3L = cmath.exp(S3L*Z)
ENS3L = 1.0 / ES3L
ES3LP = exp(S3LP*Z)
ES3LP = cmath.exp(S3LP*Z)
ENS3LP = 1.0 / ES3LP

FS1S1 = S1L/S1LP
Expand All @@ -479,26 +488,26 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
A11 = A22 = DBG * (1.0 + FS1S1)
A11 *= ES1L * ENS1LP
A22 *= ENS1L * ES1LP
A12 = A21 = DBG * (1.0 - FS1S1) * exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
A12 = A21 = DBG * (1.0 - FS1S1) * cmath.exp(2.*S1L*S1LP*SIGMAL*SIGMAL)
A12 *= ENS1L * ENS1LP
A21 *= ES1L * ES1LP
A13 = A24 = DGG * (1.0 + FS3S1)
A13 *= ES3L * ENS1LP
A24 *= ENS3L * ES1LP
A14 = A23 = DGG * (1.0 - FS3S1) * exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
A14 = A23 = DGG * (1.0 - FS3S1) * cmath.exp(2.*S3L*S1LP*SIGMAL*SIGMAL)
A14 *= ENS3L * ENS1LP
A23 *= ES3L * ES1LP

A31 = A42 = DBB * (1.0 + FS1S3)
A31 *= ES1L * ENS3LP
A42 *= ENS1L * ES3LP
A32 = A41 = DBB * (1.0 - FS1S3) * exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
A32 = A41 = DBB * (1.0 - FS1S3) * cmath.exp(2.*S1L*S3LP*SIGMAL*SIGMAL)
A32 *= ENS1L * ENS3LP
A41 *= ES1L * ES3LP
A33 = A44 = DGB * (1.0 + FS3S3)
A33 *= ES3L * ENS3LP
A44 *= ENS3L * ES3LP
A34 = A43 = DGB * (1.0 - FS3S3) * exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
A34 = A43 = DGB * (1.0 - FS3S3) * cmath.exp(2.*S3L*S3LP*SIGMAL*SIGMAL)
A34 *= ENS3L * ENS3LP
A43 *= ES3L * ES3LP

Expand Down Expand Up @@ -550,14 +559,15 @@ def Cr4xa(N, D, SIGMA, IP, RHO, IRHO, RHOM, U1, U3, KZ, Y):
# Calculate reflectivity coefficients specified by POLSTAT
# IP = +1 fills in ++, +-, -+, --; IP = -1 only fills in -+, --.
if IP > 0:
# Only return -+ and -- if IP is -1, otherwise return all
Y[0] = (B24*B41 - B21*B44)/DETW # ++
Y[1] = (B21*B42 - B41*B22)/DETW # +-
Y[2] = (B24*B43 - B23*B44)/DETW # -+
Y[3] = (B23*B42 - B43*B22)/DETW # --

#@cc.export('mag_amplitude')
MAGAMP_SIG = 'void(f8[:], f8[:], f8[:], f8[:], f8[:], c16[:], c16[:], f8[:], c16[:,:])'
@njit(MAGAMP_SIG, parallel=True, cache=True)
#@njit(MAGAMP_SIG, parallel=True, cache=True)
def magnetic_amplitude_py(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
"""
python version of calculation
Expand All @@ -581,6 +591,49 @@ def magnetic_amplitude_py(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
for i in prange(points):
Cr4xa(layers, d, sigma, -1.0, rho, irho, rhoM, u1, u3, KZ[i], R[i])

def magnetic_amplitude_cuda(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
    """Kernel body: one thread evaluates one KZ point with a single Cr4xa call.

    The thread's 1-D grid index selects the KZ entry and the matching row
    of R; Cr4xa is invoked once with IP = 1.0 and writes into that row.
    """
    point = cuda.grid(1)
    # Threads past the end of KZ (the launch grid may overshoot) do nothing.
    if point >= KZ.size:
        return
    Cr4xa(d.size, d, sigma, 1.0, rho, irho, rhoM, u1, u3, KZ[point], R[point])
def magnetic_amplitude_cuda_B(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
    """Kernel body: one thread evaluates one KZ point with two Cr4xa calls.

    Cr4xa is run first with IP = 1.0 and then with IP = -1.0; both calls
    receive the same output row R[point].
    """
    point = cuda.grid(1)
    # Threads past the end of KZ (the launch grid may overshoot) do nothing.
    if point >= KZ.size:
        return
    kz = KZ[point]
    out = R[point]
    # NOTE(review): the second call writes into the same row as the first;
    # confirm against Cr4xa which entries each polarization fills in.
    Cr4xa(d.size, d, sigma, 1.0, rho, irho, rhoM, u1, u3, kz, out)
    Cr4xa(d.size, d, sigma, -1.0, rho, irho, rhoM, u1, u3, kz, out)

def magnetic_amplitude_driver(d, sigma, rho, irho, rhoM, u1, u3, KZ, R):
    """Host-side launcher for the CUDA magnetic amplitude kernels.

    Copies the inputs to the device in single precision, launches one
    thread per KZ point on the module-level ``stream``, and copies the
    result back into ``R`` in place.

    Parameters
    ----------
    d, sigma, rho, irho, rhoM : real arrays, sent to the device as float32.
    u1, u3 : complex arrays, sent to the device as complex64.
    KZ : real array of evaluation points; one kernel thread per entry.
    R : output array, overwritten with the kernel result.  The kernels are
        compiled with a complex64 2-D output, so the device buffer uses
        that dtype and the values are cast back to ``R``'s dtype on copy.
    """
    # Nothing to compute; a launch with zero blocks is an invalid configuration.
    if KZ.size == 0:
        return

    # The single-call kernel is used when the first and last layers are
    # both effectively non-magnetic; otherwise the two-call kernel is used.
    if abs(rhoM[0]) <= MINIMAL_RHO_M and abs(rhoM[-1]) <= MINIMAL_RHO_M:
        kernel = magnetic_amplitude_cuda
    else:
        kernel = magnetic_amplitude_cuda_B

    gd, gsigma, grho, girho, grhoM, gKZ = [
        cuda.to_device(np.float32(v), stream)
        for v in (d, sigma, rho, irho, rhoM, KZ)]
    gu1, gu3 = [
        cuda.to_device(np.complex64(v), stream) for v in (u1, u3)]
    # Output buffer must match the kernel's complex64 output argument.
    gR = cuda.device_array(R.shape, dtype=np.complex64)

    threadsperblock = 32
    # Round up so every KZ point gets a thread.
    blockspergrid = (KZ.size + (threadsperblock - 1)) // threadsperblock
    kernel[blockspergrid, threadsperblock, stream](
        gd, gsigma, grho, girho, grhoM, gu1, gu3, gKZ, gR)

    # Queue the device-to-host copy on the same stream, then wait for the
    # stream to drain so the host buffer is complete before it is read.
    host_R = gR.copy_to_host(stream=stream)
    stream.synchronize()
    R[:] = host_R

# Bind the execution backend once, at import time.
#   CUDA  -> single-precision device function + kernels, driven from the host
#   numba -> double-precision CPU jit of the same Python sources
#   else  -> the undecorated pure-Python (non-vectorized!) functions are used
if USE_CUDA:
    dev = cuda.select_device(0)
    print(dev)
    stream = cuda.stream()

    # Single-precision variants of the double-precision signatures.
    CR4XA_SIG_F = CR4XA_SIG.replace('f8', 'f4').replace('c16', 'c8')
    MAGAMP_SIG_F = MAGAMP_SIG.replace('f8', 'f4').replace('c16', 'c8')

    # The device function has to be compiled before the kernels that call it.
    Cr4xa = cuda.jit(
        CR4XA_SIG_F, device=True, fastmath=True, inline=True)(Cr4xa)
    magnetic_amplitude_cuda = cuda.jit(
        MAGAMP_SIG_F, fastmath=True)(magnetic_amplitude_cuda)
    magnetic_amplitude_cuda_B = cuda.jit(
        MAGAMP_SIG_F, fastmath=True)(magnetic_amplitude_cuda_B)

    # Callers keep using the magnetic_amplitude_py name; it now launches CUDA.
    magnetic_amplitude_py = magnetic_amplitude_driver
elif USE_NUMBA:
    Cr4xa = njit(CR4XA_SIG, parallel=False, cache=True)(Cr4xa)
    magnetic_amplitude_py = njit(
        MAGAMP_SIG, parallel=True, cache=True)(magnetic_amplitude_py)
else:
    pass  # fall back to pure python (non-vectorized!) edition

#try:
# from .magnetic_amplitude import mag_amplitude as magnetic_amplitude_py
# print("loaded from compiled module")
Expand Down