Skip to content

Commit 8d590f2

Browse files
committed
miri: reduce code duplication in SSE2 pavg.b and pavg.w
1 parent 4bb8a8b commit 8d590f2

File tree

2 files changed

+54
-40
lines changed

2 files changed

+54
-40
lines changed

src/tools/miri/src/helpers.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use rustc_middle::mir;
1414
use rustc_middle::ty::{
1515
self,
1616
layout::{IntegerExt as _, LayoutOf, TyAndLayout},
17-
Ty, TyCtxt,
17+
IntTy, Ty, TyCtxt, UintTy,
1818
};
1919
use rustc_span::{def_id::CrateNum, sym, Span, Symbol};
2020
use rustc_target::abi::{Align, FieldIdx, FieldsShape, Integer, Size, Variants};
@@ -1067,6 +1067,24 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
10671067
),
10681068
}
10691069
}
1070+
1071+
/// Returns an integer type that is twice wide as `ty`
1072+
fn get_twice_wide_int_ty(&self, ty: Ty<'tcx>) -> Ty<'tcx> {
1073+
let this = self.eval_context_ref();
1074+
match ty.kind() {
1075+
// Unsigned
1076+
ty::Uint(UintTy::U8) => this.tcx.types.u16,
1077+
ty::Uint(UintTy::U16) => this.tcx.types.u32,
1078+
ty::Uint(UintTy::U32) => this.tcx.types.u64,
1079+
ty::Uint(UintTy::U64) => this.tcx.types.u128,
1080+
// Signed
1081+
ty::Int(IntTy::I8) => this.tcx.types.i16,
1082+
ty::Int(IntTy::I16) => this.tcx.types.i32,
1083+
ty::Int(IntTy::I32) => this.tcx.types.i64,
1084+
ty::Int(IntTy::I64) => this.tcx.types.i128,
1085+
_ => span_bug!(this.cur_span(), "unexpected type: {ty:?}"),
1086+
}
1087+
}
10701088
}
10711089

10721090
impl<'mir, 'tcx> MiriMachine<'mir, 'tcx> {

src/tools/miri/src/shims/x86/sse2.rs

+35-39
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use rustc_apfloat::{
22
ieee::{Double, Single},
33
Float as _,
44
};
5+
use rustc_middle::mir;
56
use rustc_middle::ty::layout::LayoutOf as _;
67
use rustc_middle::ty::Ty;
78
use rustc_span::Symbol;
@@ -36,9 +37,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
3637
// Intrinsincs sufixed with "epiX" or "epuX" operate with X-bit signed or unsigned
3738
// vectors.
3839
match unprefixed_name {
39-
// Used to implement the _mm_avg_epu8 function.
40-
// Averages packed unsigned 8-bit integers in `left` and `right`.
41-
"pavg.b" => {
40+
// Used to implement the _mm_avg_epu8 and _mm_avg_epu16 functions.
41+
// Averages packed unsigned 8/16-bit integers in `left` and `right`.
42+
"pavg.b" | "pavg.w" => {
4243
let [left, right] =
4344
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
4445

@@ -50,46 +51,41 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
5051
assert_eq!(dest_len, right_len);
5152

5253
for i in 0..dest_len {
53-
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u8()?;
54-
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
54+
let left = this.read_immediate(&this.project_index(&left, i)?)?;
55+
let right = this.read_immediate(&this.project_index(&right, i)?)?;
5556
let dest = this.project_index(&dest, i)?;
5657

57-
// Values are expanded from u8 to u16, so adds cannot overflow.
58-
let res = u16::from(left)
59-
.checked_add(u16::from(right))
60-
.unwrap()
61-
.checked_add(1)
62-
.unwrap()
63-
/ 2;
64-
this.write_scalar(Scalar::from_u8(res.try_into().unwrap()), &dest)?;
65-
}
66-
}
67-
// Used to implement the _mm_avg_epu16 function.
68-
// Averages packed unsigned 16-bit integers in `left` and `right`.
69-
"pavg.w" => {
70-
let [left, right] =
71-
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
72-
73-
let (left, left_len) = this.operand_to_simd(left)?;
74-
let (right, right_len) = this.operand_to_simd(right)?;
75-
let (dest, dest_len) = this.place_to_simd(dest)?;
76-
77-
assert_eq!(dest_len, left_len);
78-
assert_eq!(dest_len, right_len);
58+
// Widen the operands to avoid overflow
59+
let twice_wide_ty = this.get_twice_wide_int_ty(left.layout.ty);
60+
let twice_wide_layout = this.layout_of(twice_wide_ty)?;
61+
let left = this.int_to_int_or_float(&left, twice_wide_ty)?;
62+
let right = this.int_to_int_or_float(&right, twice_wide_ty)?;
63+
64+
// Calculate left + right + 1
65+
let (added, _overflow, _ty) = this.overflowing_binary_op(
66+
mir::BinOp::Add,
67+
&ImmTy::from_immediate(left, twice_wide_layout),
68+
&ImmTy::from_immediate(right, twice_wide_layout),
69+
)?;
70+
let (added, _overflow, _ty) = this.overflowing_binary_op(
71+
mir::BinOp::Add,
72+
&ImmTy::from_scalar(added, twice_wide_layout),
73+
&ImmTy::from_uint(1u32, twice_wide_layout),
74+
)?;
7975

80-
for i in 0..dest_len {
81-
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u16()?;
82-
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u16()?;
83-
let dest = this.project_index(&dest, i)?;
76+
// Calculate (left + right + 1) / 2
77+
let (divided, _overflow, _ty) = this.overflowing_binary_op(
78+
mir::BinOp::Div,
79+
&ImmTy::from_scalar(added, twice_wide_layout),
80+
&ImmTy::from_uint(2u32, twice_wide_layout),
81+
)?;
8482

85-
// Values are expanded from u16 to u32, so adds cannot overflow.
86-
let res = u32::from(left)
87-
.checked_add(u32::from(right))
88-
.unwrap()
89-
.checked_add(1)
90-
.unwrap()
91-
/ 2;
92-
this.write_scalar(Scalar::from_u16(res.try_into().unwrap()), &dest)?;
83+
// Narrow back to the original type
84+
let res = this.int_to_int_or_float(
85+
&ImmTy::from_scalar(divided, twice_wide_layout),
86+
dest.layout.ty,
87+
)?;
88+
this.write_immediate(res, &dest)?;
9389
}
9490
}
9591
// Used to implement the _mm_mulhi_epi16 function.

0 commit comments

Comments
 (0)