Skip to content

Commit dd7e165

Browse files
kjetilkjekaKjetil Kjeka
authored and
Kjetil Kjeka
committed
NVPTX: Add f16 SIMD intrinsics
1 parent a3beb09 commit dd7e165

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

crates/core_arch/src/nvptx/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
1414
use crate::ffi::c_void;
1515

16+
mod packed;
17+
18+
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
19+
pub use packed::*;
20+
1621
#[allow(improper_ctypes)]
1722
extern "C" {
1823
#[link_name = "llvm.nvvm.barrier0"]

crates/core_arch/src/nvptx/packed.rs

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//! NVPTX Packed data types (SIMD)
2+
//!
3+
//! Packed data types are what PTX calls SIMD types. See [PTX ISA (Packed Data Types)](https://docs.nvidia.com/cuda/parallel-thread-execution/#packed-data-types) for a full reference.
4+
5+
// Note: `#[assert_instr]` tests are not actually run on nvptx, because it is a `no_std` target incapable of running tests. Something like FileCheck would be appropriate for verifying that the correct instruction is used.
6+
7+
use crate::intrinsics::simd::*;
8+
9+
#[allow(improper_ctypes)]
10+
extern "C" {
11+
#[link_name = "llvm.minimum.v2f16"]
12+
fn llvm_f16x2_min(a: f16x2, b: f16x2) -> f16x2;
13+
#[link_name = "llvm.maximum.v2f16"]
14+
fn llvm_f16x2_max(a: f16x2, b: f16x2) -> f16x2;
15+
}
16+
17+
types! {
    #![unstable(feature = "stdarch_nvptx", issue = "111199")]

    /// Packed vector of two `f16` lanes (32 bits in total); the PTX `f16x2` type.
    pub struct f16x2(2 x f16);
}
24+
25+
/// Add two values
26+
///
27+
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add>
28+
#[inline]
29+
#[cfg_attr(test, assert_instr(add.rn.f16x22))]
30+
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
31+
pub unsafe fn f16x2_add(a: f16x2, b: f16x2) -> f16x2 {
32+
simd_add(a, b)
33+
}
34+
35+
/// Subtract two values
36+
///
37+
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-sub>
38+
#[inline]
39+
#[cfg_attr(test, assert_instr(sub.rn.f16x2))]
40+
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
41+
pub unsafe fn f16x2_sub(a: f16x2, b: f16x2) -> f16x2 {
42+
simd_sub(a, b)
43+
}
44+
45+
/// Multiply two values
46+
///
47+
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-mul>
48+
#[inline]
49+
#[cfg_attr(test, assert_instr(mul.rn.f16x2))]
50+
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
51+
pub unsafe fn f16x2_mul(a: f16x2, b: f16x2) -> f16x2 {
52+
simd_mul(a, b)
53+
}
54+
55+
/// Fused multiply-add
56+
///
57+
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-fma>
58+
#[inline]
59+
#[cfg_attr(test, assert_instr(fma.rn.f16x2))]
60+
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
61+
pub unsafe fn f16x2_fma(a: f16x2, b: f16x2, c: f16x2) -> f16x2 {
62+
simd_fma(a, b, c)
63+
}
64+
65+
/// Arithmetic negate
66+
///
67+
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-neg>
68+
#[inline]
69+
#[cfg_attr(test, assert_instr(neg.f16x2))]
70+
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
71+
pub unsafe fn f16x2_neg(a: f16x2) -> f16x2 {
72+
simd_neg(a)
73+
}
74+
75+
/// Find the minimum of two values
76+
///
77+
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-min>
78+
#[inline]
79+
#[cfg_attr(test, assert_instr(min.NaN.f16x2))]
80+
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
81+
pub unsafe fn f16x2_min(a: f16x2, b: f16x2) -> f16x2 {
82+
llvm_f16x2_min(a, b)
83+
}
84+
85+
/// Find the maximum of two values
86+
///
87+
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-max>
88+
#[inline]
89+
#[cfg_attr(test, assert_instr(max.NaN.f16x2))]
90+
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
91+
pub unsafe fn f16x2_max(a: f16x2, b: f16x2) -> f16x2 {
92+
llvm_f16x2_max(a, b)
93+
}

0 commit comments

Comments
 (0)