Skip to content

Commit ef88434

Browse files
committed
Auto merge of #84274 - nagisa:fix-simd, r=nikic
Don't set `fast-math` for the SIMD operations we previously set it for. `fast-math` implies, among other things, that functions cannot accept as an argument or return as a result a value such as `inf`, which made these functions either confusingly named or incorrectly behaved, depending on how you interpret it. It seems that the intended behaviour was to set an `afn` flag instead of `fast-math`. In doing so we also renamed the intrinsics to say `_approx`, so that it is clear they are not precision-oriented and users can act accordingly. Fixes #84268
2 parents b021bee + 487e273 commit ef88434

19 files changed

+154
-131
lines changed

compiler/rustc_codegen_llvm/src/builder.rs

+9-9
Original file line numberDiff line numberDiff line change
@@ -261,39 +261,39 @@ impl BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
261261
fn fadd_fast(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
262262
unsafe {
263263
let instr = llvm::LLVMBuildFAdd(self.llbuilder, lhs, rhs, UNNAMED);
264-
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
264+
llvm::LLVMRustSetFastMath(instr);
265265
instr
266266
}
267267
}
268268

269269
fn fsub_fast(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
270270
unsafe {
271271
let instr = llvm::LLVMBuildFSub(self.llbuilder, lhs, rhs, UNNAMED);
272-
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
272+
llvm::LLVMRustSetFastMath(instr);
273273
instr
274274
}
275275
}
276276

277277
fn fmul_fast(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
278278
unsafe {
279279
let instr = llvm::LLVMBuildFMul(self.llbuilder, lhs, rhs, UNNAMED);
280-
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
280+
llvm::LLVMRustSetFastMath(instr);
281281
instr
282282
}
283283
}
284284

285285
fn fdiv_fast(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
286286
unsafe {
287287
let instr = llvm::LLVMBuildFDiv(self.llbuilder, lhs, rhs, UNNAMED);
288-
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
288+
llvm::LLVMRustSetFastMath(instr);
289289
instr
290290
}
291291
}
292292

293293
fn frem_fast(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
294294
unsafe {
295295
let instr = llvm::LLVMBuildFRem(self.llbuilder, lhs, rhs, UNNAMED);
296-
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
296+
llvm::LLVMRustSetFastMath(instr);
297297
instr
298298
}
299299
}
@@ -1242,14 +1242,14 @@ impl Builder<'a, 'll, 'tcx> {
12421242
pub fn vector_reduce_fadd_fast(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
12431243
unsafe {
12441244
let instr = llvm::LLVMRustBuildVectorReduceFAdd(self.llbuilder, acc, src);
1245-
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
1245+
llvm::LLVMRustSetFastMath(instr);
12461246
instr
12471247
}
12481248
}
12491249
pub fn vector_reduce_fmul_fast(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
12501250
unsafe {
12511251
let instr = llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src);
1252-
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
1252+
llvm::LLVMRustSetFastMath(instr);
12531253
instr
12541254
}
12551255
}
@@ -1282,15 +1282,15 @@ impl Builder<'a, 'll, 'tcx> {
12821282
unsafe {
12831283
let instr =
12841284
llvm::LLVMRustBuildVectorReduceFMin(self.llbuilder, src, /*NoNaNs:*/ true);
1285-
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
1285+
llvm::LLVMRustSetFastMath(instr);
12861286
instr
12871287
}
12881288
}
12891289
pub fn vector_reduce_fmax_fast(&mut self, src: &'ll Value) -> &'ll Value {
12901290
unsafe {
12911291
let instr =
12921292
llvm::LLVMRustBuildVectorReduceFMax(self.llbuilder, src, /*NoNaNs:*/ true);
1293-
llvm::LLVMRustSetHasUnsafeAlgebra(instr);
1293+
llvm::LLVMRustSetFastMath(instr);
12941294
instr
12951295
}
12961296
}

compiler/rustc_codegen_llvm/src/intrinsic.rs

+19-21
Original file line numberDiff line numberDiff line change
@@ -1053,50 +1053,48 @@ fn generic_simd_intrinsic(
10531053
let vec_ty = bx.type_vector(elem_ty, in_len);
10541054

10551055
let (intr_name, fn_ty) = match name {
1056-
sym::simd_fsqrt => ("sqrt", bx.type_func(&[vec_ty], vec_ty)),
1057-
sym::simd_fsin => ("sin", bx.type_func(&[vec_ty], vec_ty)),
1058-
sym::simd_fcos => ("cos", bx.type_func(&[vec_ty], vec_ty)),
1059-
sym::simd_fabs => ("fabs", bx.type_func(&[vec_ty], vec_ty)),
10601056
sym::simd_ceil => ("ceil", bx.type_func(&[vec_ty], vec_ty)),
1061-
sym::simd_floor => ("floor", bx.type_func(&[vec_ty], vec_ty)),
1062-
sym::simd_round => ("round", bx.type_func(&[vec_ty], vec_ty)),
1063-
sym::simd_trunc => ("trunc", bx.type_func(&[vec_ty], vec_ty)),
1064-
sym::simd_fexp => ("exp", bx.type_func(&[vec_ty], vec_ty)),
1057+
sym::simd_fabs => ("fabs", bx.type_func(&[vec_ty], vec_ty)),
1058+
sym::simd_fcos => ("cos", bx.type_func(&[vec_ty], vec_ty)),
10651059
sym::simd_fexp2 => ("exp2", bx.type_func(&[vec_ty], vec_ty)),
1060+
sym::simd_fexp => ("exp", bx.type_func(&[vec_ty], vec_ty)),
10661061
sym::simd_flog10 => ("log10", bx.type_func(&[vec_ty], vec_ty)),
10671062
sym::simd_flog2 => ("log2", bx.type_func(&[vec_ty], vec_ty)),
10681063
sym::simd_flog => ("log", bx.type_func(&[vec_ty], vec_ty)),
1064+
sym::simd_floor => ("floor", bx.type_func(&[vec_ty], vec_ty)),
1065+
sym::simd_fma => ("fma", bx.type_func(&[vec_ty, vec_ty, vec_ty], vec_ty)),
10691066
sym::simd_fpowi => ("powi", bx.type_func(&[vec_ty, bx.type_i32()], vec_ty)),
10701067
sym::simd_fpow => ("pow", bx.type_func(&[vec_ty, vec_ty], vec_ty)),
1071-
sym::simd_fma => ("fma", bx.type_func(&[vec_ty, vec_ty, vec_ty], vec_ty)),
1068+
sym::simd_fsin => ("sin", bx.type_func(&[vec_ty], vec_ty)),
1069+
sym::simd_fsqrt => ("sqrt", bx.type_func(&[vec_ty], vec_ty)),
1070+
sym::simd_round => ("round", bx.type_func(&[vec_ty], vec_ty)),
1071+
sym::simd_trunc => ("trunc", bx.type_func(&[vec_ty], vec_ty)),
10721072
_ => return_error!("unrecognized intrinsic `{}`", name),
10731073
};
1074-
10751074
let llvm_name = &format!("llvm.{0}.v{1}{2}", intr_name, in_len, elem_ty_str);
10761075
let f = bx.declare_cfn(&llvm_name, llvm::UnnamedAddr::No, fn_ty);
10771076
let c = bx.call(f, &args.iter().map(|arg| arg.immediate()).collect::<Vec<_>>(), None);
1078-
unsafe { llvm::LLVMRustSetHasUnsafeAlgebra(c) };
10791077
Ok(c)
10801078
}
10811079

10821080
if std::matches!(
10831081
name,
1084-
sym::simd_fsqrt
1085-
| sym::simd_fsin
1086-
| sym::simd_fcos
1082+
sym::simd_ceil
10871083
| sym::simd_fabs
1088-
| sym::simd_ceil
1089-
| sym::simd_floor
1090-
| sym::simd_round
1091-
| sym::simd_trunc
1092-
| sym::simd_fexp
1084+
| sym::simd_fcos
10931085
| sym::simd_fexp2
1086+
| sym::simd_fexp
10941087
| sym::simd_flog10
10951088
| sym::simd_flog2
10961089
| sym::simd_flog
1097-
| sym::simd_fpowi
1098-
| sym::simd_fpow
1090+
| sym::simd_floor
10991091
| sym::simd_fma
1092+
| sym::simd_fpow
1093+
| sym::simd_fpowi
1094+
| sym::simd_fsin
1095+
| sym::simd_fsqrt
1096+
| sym::simd_round
1097+
| sym::simd_trunc
11001098
) {
11011099
return simd_simple_float_intrinsic(name, in_elem, in_ty, in_len, bx, span, args);
11021100
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1354,7 +1354,7 @@ extern "C" {
13541354
pub fn LLVMBuildNeg(B: &Builder<'a>, V: &'a Value, Name: *const c_char) -> &'a Value;
13551355
pub fn LLVMBuildFNeg(B: &Builder<'a>, V: &'a Value, Name: *const c_char) -> &'a Value;
13561356
pub fn LLVMBuildNot(B: &Builder<'a>, V: &'a Value, Name: *const c_char) -> &'a Value;
1357-
pub fn LLVMRustSetHasUnsafeAlgebra(Instr: &Value);
1357+
pub fn LLVMRustSetFastMath(Instr: &Value);
13581358

13591359
// Memory
13601360
pub fn LLVMBuildAlloca(B: &Builder<'a>, Ty: &'a Type, Name: *const c_char) -> &'a Value;

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,10 @@ extern "C" void LLVMRustRemoveFunctionAttributes(LLVMValueRef Fn,
349349
F->setAttributes(PALNew);
350350
}
351351

352-
// enable fpmath flag UnsafeAlgebra
353-
extern "C" void LLVMRustSetHasUnsafeAlgebra(LLVMValueRef V) {
352+
// Enable a fast-math flag
353+
//
354+
// https://llvm.org/docs/LangRef.html#fast-math-flags
355+
extern "C" void LLVMRustSetFastMath(LLVMValueRef V) {
354356
if (auto I = dyn_cast<Instruction>(unwrap<Value>(V))) {
355357
I->setFast(true);
356358
}

src/test/codegen/issue-84268.rs

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// compile-flags: -O --crate-type=rlib
2+
#![feature(platform_intrinsics, repr_simd)]
3+
4+
extern "platform-intrinsic" {
5+
fn simd_fabs<T>(x: T) -> T;
6+
fn simd_eq<T, U>(x: T, y: T) -> U;
7+
}
8+
9+
#[repr(simd)]
10+
pub struct V([f32; 4]);
11+
12+
#[repr(simd)]
13+
pub struct M([i32; 4]);
14+
15+
#[no_mangle]
16+
// CHECK-LABEL: @is_infinite
17+
pub fn is_infinite(v: V) -> M {
18+
// CHECK: fabs
19+
// CHECK: cmp oeq
20+
unsafe {
21+
simd_eq(simd_fabs(v), V([f32::INFINITY; 4]))
22+
}
23+
}

src/test/codegen/simd-intrinsic/simd-intrinsic-float-abs.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,28 @@ extern "platform-intrinsic" {
3232
// CHECK-LABEL: @fabs_32x2
3333
#[no_mangle]
3434
pub unsafe fn fabs_32x2(a: f32x2) -> f32x2 {
35-
// CHECK: call fast <2 x float> @llvm.fabs.v2f32
35+
// CHECK: call <2 x float> @llvm.fabs.v2f32
3636
simd_fabs(a)
3737
}
3838

3939
// CHECK-LABEL: @fabs_32x4
4040
#[no_mangle]
4141
pub unsafe fn fabs_32x4(a: f32x4) -> f32x4 {
42-
// CHECK: call fast <4 x float> @llvm.fabs.v4f32
42+
// CHECK: call <4 x float> @llvm.fabs.v4f32
4343
simd_fabs(a)
4444
}
4545

4646
// CHECK-LABEL: @fabs_32x8
4747
#[no_mangle]
4848
pub unsafe fn fabs_32x8(a: f32x8) -> f32x8 {
49-
// CHECK: call fast <8 x float> @llvm.fabs.v8f32
49+
// CHECK: call <8 x float> @llvm.fabs.v8f32
5050
simd_fabs(a)
5151
}
5252

5353
// CHECK-LABEL: @fabs_32x16
5454
#[no_mangle]
5555
pub unsafe fn fabs_32x16(a: f32x16) -> f32x16 {
56-
// CHECK: call fast <16 x float> @llvm.fabs.v16f32
56+
// CHECK: call <16 x float> @llvm.fabs.v16f32
5757
simd_fabs(a)
5858
}
5959

@@ -73,20 +73,20 @@ pub struct f64x8(pub f64, pub f64, pub f64, pub f64,
7373
// CHECK-LABEL: @fabs_64x4
7474
#[no_mangle]
7575
pub unsafe fn fabs_64x4(a: f64x4) -> f64x4 {
76-
// CHECK: call fast <4 x double> @llvm.fabs.v4f64
76+
// CHECK: call <4 x double> @llvm.fabs.v4f64
7777
simd_fabs(a)
7878
}
7979

8080
// CHECK-LABEL: @fabs_64x2
8181
#[no_mangle]
8282
pub unsafe fn fabs_64x2(a: f64x2) -> f64x2 {
83-
// CHECK: call fast <2 x double> @llvm.fabs.v2f64
83+
// CHECK: call <2 x double> @llvm.fabs.v2f64
8484
simd_fabs(a)
8585
}
8686

8787
// CHECK-LABEL: @fabs_64x8
8888
#[no_mangle]
8989
pub unsafe fn fabs_64x8(a: f64x8) -> f64x8 {
90-
// CHECK: call fast <8 x double> @llvm.fabs.v8f64
90+
// CHECK: call <8 x double> @llvm.fabs.v8f64
9191
simd_fabs(a)
9292
}

src/test/codegen/simd-intrinsic/simd-intrinsic-float-ceil.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,28 @@ extern "platform-intrinsic" {
3232
// CHECK-LABEL: @ceil_32x2
3333
#[no_mangle]
3434
pub unsafe fn ceil_32x2(a: f32x2) -> f32x2 {
35-
// CHECK: call fast <2 x float> @llvm.ceil.v2f32
35+
// CHECK: call <2 x float> @llvm.ceil.v2f32
3636
simd_ceil(a)
3737
}
3838

3939
// CHECK-LABEL: @ceil_32x4
4040
#[no_mangle]
4141
pub unsafe fn ceil_32x4(a: f32x4) -> f32x4 {
42-
// CHECK: call fast <4 x float> @llvm.ceil.v4f32
42+
// CHECK: call <4 x float> @llvm.ceil.v4f32
4343
simd_ceil(a)
4444
}
4545

4646
// CHECK-LABEL: @ceil_32x8
4747
#[no_mangle]
4848
pub unsafe fn ceil_32x8(a: f32x8) -> f32x8 {
49-
// CHECK: call fast <8 x float> @llvm.ceil.v8f32
49+
// CHECK: call <8 x float> @llvm.ceil.v8f32
5050
simd_ceil(a)
5151
}
5252

5353
// CHECK-LABEL: @ceil_32x16
5454
#[no_mangle]
5555
pub unsafe fn ceil_32x16(a: f32x16) -> f32x16 {
56-
// CHECK: call fast <16 x float> @llvm.ceil.v16f32
56+
// CHECK: call <16 x float> @llvm.ceil.v16f32
5757
simd_ceil(a)
5858
}
5959

@@ -73,20 +73,20 @@ pub struct f64x8(pub f64, pub f64, pub f64, pub f64,
7373
// CHECK-LABEL: @ceil_64x4
7474
#[no_mangle]
7575
pub unsafe fn ceil_64x4(a: f64x4) -> f64x4 {
76-
// CHECK: call fast <4 x double> @llvm.ceil.v4f64
76+
// CHECK: call <4 x double> @llvm.ceil.v4f64
7777
simd_ceil(a)
7878
}
7979

8080
// CHECK-LABEL: @ceil_64x2
8181
#[no_mangle]
8282
pub unsafe fn ceil_64x2(a: f64x2) -> f64x2 {
83-
// CHECK: call fast <2 x double> @llvm.ceil.v2f64
83+
// CHECK: call <2 x double> @llvm.ceil.v2f64
8484
simd_ceil(a)
8585
}
8686

8787
// CHECK-LABEL: @ceil_64x8
8888
#[no_mangle]
8989
pub unsafe fn ceil_64x8(a: f64x8) -> f64x8 {
90-
// CHECK: call fast <8 x double> @llvm.ceil.v8f64
90+
// CHECK: call <8 x double> @llvm.ceil.v8f64
9191
simd_ceil(a)
9292
}

src/test/codegen/simd-intrinsic/simd-intrinsic-float-cos.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,28 @@ extern "platform-intrinsic" {
3232
// CHECK-LABEL: @fcos_32x2
3333
#[no_mangle]
3434
pub unsafe fn fcos_32x2(a: f32x2) -> f32x2 {
35-
// CHECK: call fast <2 x float> @llvm.cos.v2f32
35+
// CHECK: call <2 x float> @llvm.cos.v2f32
3636
simd_fcos(a)
3737
}
3838

3939
// CHECK-LABEL: @fcos_32x4
4040
#[no_mangle]
4141
pub unsafe fn fcos_32x4(a: f32x4) -> f32x4 {
42-
// CHECK: call fast <4 x float> @llvm.cos.v4f32
42+
// CHECK: call <4 x float> @llvm.cos.v4f32
4343
simd_fcos(a)
4444
}
4545

4646
// CHECK-LABEL: @fcos_32x8
4747
#[no_mangle]
4848
pub unsafe fn fcos_32x8(a: f32x8) -> f32x8 {
49-
// CHECK: call fast <8 x float> @llvm.cos.v8f32
49+
// CHECK: call <8 x float> @llvm.cos.v8f32
5050
simd_fcos(a)
5151
}
5252

5353
// CHECK-LABEL: @fcos_32x16
5454
#[no_mangle]
5555
pub unsafe fn fcos_32x16(a: f32x16) -> f32x16 {
56-
// CHECK: call fast <16 x float> @llvm.cos.v16f32
56+
// CHECK: call <16 x float> @llvm.cos.v16f32
5757
simd_fcos(a)
5858
}
5959

@@ -73,20 +73,20 @@ pub struct f64x8(pub f64, pub f64, pub f64, pub f64,
7373
// CHECK-LABEL: @fcos_64x4
7474
#[no_mangle]
7575
pub unsafe fn fcos_64x4(a: f64x4) -> f64x4 {
76-
// CHECK: call fast <4 x double> @llvm.cos.v4f64
76+
// CHECK: call <4 x double> @llvm.cos.v4f64
7777
simd_fcos(a)
7878
}
7979

8080
// CHECK-LABEL: @fcos_64x2
8181
#[no_mangle]
8282
pub unsafe fn fcos_64x2(a: f64x2) -> f64x2 {
83-
// CHECK: call fast <2 x double> @llvm.cos.v2f64
83+
// CHECK: call <2 x double> @llvm.cos.v2f64
8484
simd_fcos(a)
8585
}
8686

8787
// CHECK-LABEL: @fcos_64x8
8888
#[no_mangle]
8989
pub unsafe fn fcos_64x8(a: f64x8) -> f64x8 {
90-
// CHECK: call fast <8 x double> @llvm.cos.v8f64
90+
// CHECK: call <8 x double> @llvm.cos.v8f64
9191
simd_fcos(a)
9292
}

0 commit comments

Comments
 (0)