Skip to content

Commit 78e74d9

Browse files
committed
Auto merge of #115827 - eduardosm:miri-sse-reduce-code-dup, r=RalfJung
miri: reduce code duplication in some SSE/SSE2 intrinsics Reduces code duplication in the Miri implementation of some SSE and SSE2 using generics and rustc_const_eval helper functions. There are also some other minor changes. r? `@RalfJung`
2 parents ed33e40 + 9fbbfd2 commit 78e74d9

File tree

6 files changed

+336
-525
lines changed

6 files changed

+336
-525
lines changed

compiler/rustc_middle/src/mir/interpret/value.rs

+18-4
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,16 @@ impl<Prov> Scalar<Prov> {
173173
.unwrap_or_else(|| bug!("Signed value {:#x} does not fit in {} bits", i, size.bits()))
174174
}
175175

176+
#[inline]
177+
pub fn from_i8(i: i8) -> Self {
178+
Self::from_int(i, Size::from_bits(8))
179+
}
180+
181+
#[inline]
182+
pub fn from_i16(i: i16) -> Self {
183+
Self::from_int(i, Size::from_bits(16))
184+
}
185+
176186
#[inline]
177187
pub fn from_i32(i: i32) -> Self {
178188
Self::from_int(i, Size::from_bits(32))
@@ -400,15 +410,19 @@ impl<'tcx, Prov: Provenance> Scalar<Prov> {
400410
Ok(i64::try_from(b).unwrap())
401411
}
402412

413+
#[inline]
414+
pub fn to_float<F: Float>(self) -> InterpResult<'tcx, F> {
415+
// Going through `to_uint` to check size and truncation.
416+
Ok(F::from_bits(self.to_uint(Size::from_bits(F::BITS))?))
417+
}
418+
403419
#[inline]
404420
pub fn to_f32(self) -> InterpResult<'tcx, Single> {
405-
// Going through `u32` to check size and truncation.
406-
Ok(Single::from_bits(self.to_u32()?.into()))
421+
self.to_float()
407422
}
408423

409424
#[inline]
410425
pub fn to_f64(self) -> InterpResult<'tcx, Double> {
411-
// Going through `u64` to check size and truncation.
412-
Ok(Double::from_bits(self.to_u64()?.into()))
426+
self.to_float()
413427
}
414428
}

src/tools/miri/src/helpers.rs

+36-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use rustc_middle::mir;
1414
use rustc_middle::ty::{
1515
self,
1616
layout::{IntegerExt as _, LayoutOf, TyAndLayout},
17-
Ty, TyCtxt,
17+
IntTy, Ty, TyCtxt, UintTy,
1818
};
1919
use rustc_span::{def_id::CrateNum, sym, Span, Symbol};
2020
use rustc_target::abi::{Align, FieldIdx, FieldsShape, Integer, Size, Variants};
@@ -1066,6 +1066,24 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
10661066
),
10671067
}
10681068
}
1069+
1070+
/// Returns an integer type that is twice wide as `ty`
1071+
fn get_twice_wide_int_ty(&self, ty: Ty<'tcx>) -> Ty<'tcx> {
1072+
let this = self.eval_context_ref();
1073+
match ty.kind() {
1074+
// Unsigned
1075+
ty::Uint(UintTy::U8) => this.tcx.types.u16,
1076+
ty::Uint(UintTy::U16) => this.tcx.types.u32,
1077+
ty::Uint(UintTy::U32) => this.tcx.types.u64,
1078+
ty::Uint(UintTy::U64) => this.tcx.types.u128,
1079+
// Signed
1080+
ty::Int(IntTy::I8) => this.tcx.types.i16,
1081+
ty::Int(IntTy::I16) => this.tcx.types.i32,
1082+
ty::Int(IntTy::I32) => this.tcx.types.i64,
1083+
ty::Int(IntTy::I64) => this.tcx.types.i128,
1084+
_ => span_bug!(this.cur_span(), "unexpected type: {ty:?}"),
1085+
}
1086+
}
10691087
}
10701088

10711089
impl<'mir, 'tcx> MiriMachine<'mir, 'tcx> {
@@ -1151,3 +1169,20 @@ pub fn get_local_crates(tcx: TyCtxt<'_>) -> Vec<CrateNum> {
11511169
pub fn target_os_is_unix(target_os: &str) -> bool {
11521170
matches!(target_os, "linux" | "macos" | "freebsd" | "android")
11531171
}
1172+
1173+
pub(crate) fn bool_to_simd_element(b: bool, size: Size) -> Scalar<Provenance> {
1174+
// SIMD uses all-1 as pattern for "true". In two's complement,
1175+
// -1 has all its bits set to one and `from_int` will truncate or
1176+
// sign-extend it to `size` as required.
1177+
let val = if b { -1 } else { 0 };
1178+
Scalar::from_int(val, size)
1179+
}
1180+
1181+
pub(crate) fn simd_element_to_bool(elem: ImmTy<'_, Provenance>) -> InterpResult<'_, bool> {
1182+
let val = elem.to_scalar().to_int(elem.layout.size)?;
1183+
Ok(match val {
1184+
0 => false,
1185+
-1 => true,
1186+
_ => throw_ub_format!("each element of a SIMD mask must be all-0-bits or all-1-bits"),
1187+
})
1188+
}

src/tools/miri/src/shims/intrinsics/simd.rs

+2-17
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
use rustc_apfloat::{Float, Round};
22
use rustc_middle::ty::layout::{HasParamEnv, LayoutOf};
33
use rustc_middle::{mir, ty, ty::FloatTy};
4-
use rustc_target::abi::{Endian, HasDataLayout, Size};
4+
use rustc_target::abi::{Endian, HasDataLayout};
55

66
use crate::*;
7-
use helpers::check_arg_count;
7+
use helpers::{bool_to_simd_element, check_arg_count, simd_element_to_bool};
88

99
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
1010
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
@@ -612,21 +612,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
612612
}
613613
}
614614

615-
fn bool_to_simd_element(b: bool, size: Size) -> Scalar<Provenance> {
616-
// SIMD uses all-1 as pattern for "true"
617-
let val = if b { -1 } else { 0 };
618-
Scalar::from_int(val, size)
619-
}
620-
621-
fn simd_element_to_bool(elem: ImmTy<'_, Provenance>) -> InterpResult<'_, bool> {
622-
let val = elem.to_scalar().to_int(elem.layout.size)?;
623-
Ok(match val {
624-
0 => false,
625-
-1 => true,
626-
_ => throw_ub_format!("each element of a SIMD mask must be all-0-bits or all-1-bits"),
627-
})
628-
}
629-
630615
fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
631616
assert!(idx < vec_len);
632617
match endianness {

src/tools/miri/src/shims/x86/mod.rs

+157-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
use crate::InterpResult;
1+
use rustc_middle::mir;
2+
use rustc_target::abi::Size;
3+
4+
use crate::*;
5+
use helpers::bool_to_simd_element;
26

37
pub(super) mod sse;
48
pub(super) mod sse2;
@@ -43,3 +47,155 @@ impl FloatCmpOp {
4347
}
4448
}
4549
}
50+
51+
#[derive(Copy, Clone)]
52+
enum FloatBinOp {
53+
/// Arithmetic operation
54+
Arith(mir::BinOp),
55+
/// Comparison
56+
Cmp(FloatCmpOp),
57+
/// Minimum value (with SSE semantics)
58+
///
59+
/// <https://www.felixcloutier.com/x86/minss>
60+
/// <https://www.felixcloutier.com/x86/minps>
61+
/// <https://www.felixcloutier.com/x86/minsd>
62+
/// <https://www.felixcloutier.com/x86/minpd>
63+
Min,
64+
/// Maximum value (with SSE semantics)
65+
///
66+
/// <https://www.felixcloutier.com/x86/maxss>
67+
/// <https://www.felixcloutier.com/x86/maxps>
68+
/// <https://www.felixcloutier.com/x86/maxsd>
69+
/// <https://www.felixcloutier.com/x86/maxpd>
70+
Max,
71+
}
72+
73+
/// Performs `which` scalar operation on `left` and `right` and returns
74+
/// the result.
75+
fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
76+
this: &crate::MiriInterpCx<'_, 'tcx>,
77+
which: FloatBinOp,
78+
left: &ImmTy<'tcx, Provenance>,
79+
right: &ImmTy<'tcx, Provenance>,
80+
) -> InterpResult<'tcx, Scalar<Provenance>> {
81+
match which {
82+
FloatBinOp::Arith(which) => {
83+
let (res, _overflow, _ty) = this.overflowing_binary_op(which, left, right)?;
84+
Ok(res)
85+
}
86+
FloatBinOp::Cmp(which) => {
87+
let left = left.to_scalar().to_float::<F>()?;
88+
let right = right.to_scalar().to_float::<F>()?;
89+
// FIXME: Make sure that these operations match the semantics
90+
// of cmpps/cmpss/cmppd/cmpsd
91+
let res = match which {
92+
FloatCmpOp::Eq => left == right,
93+
FloatCmpOp::Lt => left < right,
94+
FloatCmpOp::Le => left <= right,
95+
FloatCmpOp::Unord => left.is_nan() || right.is_nan(),
96+
FloatCmpOp::Neq => left != right,
97+
FloatCmpOp::Nlt => !(left < right),
98+
FloatCmpOp::Nle => !(left <= right),
99+
FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(),
100+
};
101+
Ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
102+
}
103+
FloatBinOp::Min => {
104+
let left_scalar = left.to_scalar();
105+
let left = left_scalar.to_float::<F>()?;
106+
let right_scalar = right.to_scalar();
107+
let right = right_scalar.to_float::<F>()?;
108+
// SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
109+
// is true when `x` is either +0 or -0.
110+
if (left == F::ZERO && right == F::ZERO)
111+
|| left.is_nan()
112+
|| right.is_nan()
113+
|| left >= right
114+
{
115+
Ok(right_scalar)
116+
} else {
117+
Ok(left_scalar)
118+
}
119+
}
120+
FloatBinOp::Max => {
121+
let left_scalar = left.to_scalar();
122+
let left = left_scalar.to_float::<F>()?;
123+
let right_scalar = right.to_scalar();
124+
let right = right_scalar.to_float::<F>()?;
125+
// SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
126+
// is true when `x` is either +0 or -0.
127+
if (left == F::ZERO && right == F::ZERO)
128+
|| left.is_nan()
129+
|| right.is_nan()
130+
|| left <= right
131+
{
132+
Ok(right_scalar)
133+
} else {
134+
Ok(left_scalar)
135+
}
136+
}
137+
}
138+
}
139+
140+
/// Performs `which` operation on the first component of `left` and `right`
141+
/// and copies the other components from `left`. The result is stored in `dest`.
142+
fn bin_op_simd_float_first<'tcx, F: rustc_apfloat::Float>(
143+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
144+
which: FloatBinOp,
145+
left: &OpTy<'tcx, Provenance>,
146+
right: &OpTy<'tcx, Provenance>,
147+
dest: &PlaceTy<'tcx, Provenance>,
148+
) -> InterpResult<'tcx, ()> {
149+
let (left, left_len) = this.operand_to_simd(left)?;
150+
let (right, right_len) = this.operand_to_simd(right)?;
151+
let (dest, dest_len) = this.place_to_simd(dest)?;
152+
153+
assert_eq!(dest_len, left_len);
154+
assert_eq!(dest_len, right_len);
155+
156+
let res0 = bin_op_float::<F>(
157+
this,
158+
which,
159+
&this.read_immediate(&this.project_index(&left, 0)?)?,
160+
&this.read_immediate(&this.project_index(&right, 0)?)?,
161+
)?;
162+
this.write_scalar(res0, &this.project_index(&dest, 0)?)?;
163+
164+
for i in 1..dest_len {
165+
this.copy_op(
166+
&this.project_index(&left, i)?,
167+
&this.project_index(&dest, i)?,
168+
/*allow_transmute*/ false,
169+
)?;
170+
}
171+
172+
Ok(())
173+
}
174+
175+
/// Performs `which` operation on each component of `left` and
176+
/// `right`, storing the result is stored in `dest`.
177+
fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
178+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
179+
which: FloatBinOp,
180+
left: &OpTy<'tcx, Provenance>,
181+
right: &OpTy<'tcx, Provenance>,
182+
dest: &PlaceTy<'tcx, Provenance>,
183+
) -> InterpResult<'tcx, ()> {
184+
let (left, left_len) = this.operand_to_simd(left)?;
185+
let (right, right_len) = this.operand_to_simd(right)?;
186+
let (dest, dest_len) = this.place_to_simd(dest)?;
187+
188+
assert_eq!(dest_len, left_len);
189+
assert_eq!(dest_len, right_len);
190+
191+
for i in 0..dest_len {
192+
let left = this.read_immediate(&this.project_index(&left, i)?)?;
193+
let right = this.read_immediate(&this.project_index(&right, i)?)?;
194+
let dest = this.project_index(&dest, i)?;
195+
196+
let res = bin_op_float::<F>(this, which, &left, &right)?;
197+
this.write_scalar(res, &dest)?;
198+
}
199+
200+
Ok(())
201+
}

0 commit comments

Comments
 (0)