Skip to content

interpret: simplify SIMD type handling #130215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions compiler/rustc_const_eval/src/interpret/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
sym::simd_insert => {
let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
let elem = &args[2];
let (input, input_len) = self.operand_to_simd(&args[0])?;
let (dest, dest_len) = self.mplace_to_simd(dest)?;
let (input, input_len) = self.project_to_simd(&args[0])?;
let (dest, dest_len) = self.project_to_simd(dest)?;
assert_eq!(input_len, dest_len, "Return vector length must match input length");
// Bounds are not checked by typeck so we have to do it ourselves.
if index >= input_len {
Expand All @@ -406,7 +406,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
sym::simd_extract => {
let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
let (input, input_len) = self.operand_to_simd(&args[0])?;
let (input, input_len) = self.project_to_simd(&args[0])?;
// Bounds are not checked by typeck so we have to do it ourselves.
if index >= input_len {
throw_ub_format!(
Expand Down
24 changes: 0 additions & 24 deletions compiler/rustc_const_eval/src/interpret/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,30 +681,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
Ok(str)
}

/// Converts a repr(simd) operand into an operand where `place_index` accesses the SIMD elements.
/// Also returns the number of elements.
///
/// Can (but does not always) trigger UB if `op` is uninitialized.
pub fn operand_to_simd(
&self,
op: &OpTy<'tcx, M::Provenance>,
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::Provenance>, u64)> {
// Basically we just transmute this place into an array following simd_size_and_type.
// This only works in memory, but repr(simd) types should never be immediates anyway.
assert!(op.layout.ty.is_simd());
match op.as_mplace_or_imm() {
Left(mplace) => self.mplace_to_simd(&mplace),
Right(imm) => match *imm {
Immediate::Uninit => {
throw_ub!(InvalidUninitBytes(None))
}
Immediate::Scalar(..) | Immediate::ScalarPair(..) => {
bug!("arrays/slices can never have Scalar/ScalarPair layout")
}
},
}
}

/// Read from a local of the current frame.
/// Will not access memory, instead an indirect `Operand` is returned.
///
Expand Down
45 changes: 22 additions & 23 deletions compiler/rustc_const_eval/src/interpret/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,13 +377,15 @@ where
Prov: Provenance,
M: Machine<'tcx, Provenance = Prov>,
{
pub fn ptr_with_meta_to_mplace(
fn ptr_with_meta_to_mplace(
&self,
ptr: Pointer<Option<M::Provenance>>,
meta: MemPlaceMeta<M::Provenance>,
layout: TyAndLayout<'tcx>,
unaligned: bool,
) -> MPlaceTy<'tcx, M::Provenance> {
let misaligned = self.is_ptr_misaligned(ptr, layout.align.abi);
let misaligned =
if unaligned { None } else { self.is_ptr_misaligned(ptr, layout.align.abi) };
MPlaceTy { mplace: MemPlace { ptr, meta, misaligned }, layout }
}

Expand All @@ -393,7 +395,16 @@ where
layout: TyAndLayout<'tcx>,
) -> MPlaceTy<'tcx, M::Provenance> {
assert!(layout.is_sized());
self.ptr_with_meta_to_mplace(ptr, MemPlaceMeta::None, layout)
self.ptr_with_meta_to_mplace(ptr, MemPlaceMeta::None, layout, /*unaligned*/ false)
}

pub fn ptr_to_mplace_unaligned(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tiif this method should also come handy in rust-lang/miri#3852 to deal with the alignment problem in eventfd

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, thanks!

&self,
ptr: Pointer<Option<M::Provenance>>,
layout: TyAndLayout<'tcx>,
) -> MPlaceTy<'tcx, M::Provenance> {
assert!(layout.is_sized());
self.ptr_with_meta_to_mplace(ptr, MemPlaceMeta::None, layout, /*unaligned*/ true)
}

/// Take a value, which represents a (thin or wide) reference, and make it a place.
Expand All @@ -414,7 +425,7 @@ where
// `ref_to_mplace` is called on raw pointers even if they don't actually get dereferenced;
// we hence can't call `size_and_align_of` since that asserts more validity than we want.
let ptr = ptr.to_pointer(self)?;
Ok(self.ptr_with_meta_to_mplace(ptr, meta, layout))
Ok(self.ptr_with_meta_to_mplace(ptr, meta, layout, /*unaligned*/ false))
}

/// Turn a mplace into a (thin or wide) mutable raw pointer, pointing to the same space.
Expand Down Expand Up @@ -484,23 +495,6 @@ where
Ok(a)
}

/// Converts a repr(simd) place into a place where `place_index` accesses the SIMD elements.
/// Also returns the number of elements.
pub fn mplace_to_simd(
&self,
mplace: &MPlaceTy<'tcx, M::Provenance>,
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::Provenance>, u64)> {
// Basically we want to transmute this place into an array following simd_size_and_type.
let (len, e_ty) = mplace.layout.ty.simd_size_and_type(*self.tcx);
// Some SIMD types have padding, so `len` many `e_ty` does not cover the entire place.
// Therefore we cannot transmute, and instead we project at offset 0, which side-steps
// the size check.
let array_layout = self.layout_of(Ty::new_array(self.tcx.tcx, e_ty, len))?;
assert!(array_layout.size <= mplace.layout.size);
let mplace = mplace.offset(Size::ZERO, array_layout, self)?;
Ok((mplace, len))
}

/// Turn a local in the current frame into a place.
pub fn local_to_place(
&self,
Expand Down Expand Up @@ -986,7 +980,7 @@ where
span_bug!(self.cur_span(), "cannot allocate space for `extern` type, size is not known")
};
let ptr = self.allocate_ptr(size, align, kind)?;
Ok(self.ptr_with_meta_to_mplace(ptr.into(), meta, layout))
Ok(self.ptr_with_meta_to_mplace(ptr.into(), meta, layout, /*unaligned*/ false))
}

pub fn allocate(
Expand Down Expand Up @@ -1021,7 +1015,12 @@ where
};
let meta = Scalar::from_target_usize(u64::try_from(str.len()).unwrap(), self);
let layout = self.layout_of(self.tcx.types.str_).unwrap();
Ok(self.ptr_with_meta_to_mplace(ptr.into(), MemPlaceMeta::Meta(meta), layout))
Ok(self.ptr_with_meta_to_mplace(
ptr.into(),
MemPlaceMeta::Meta(meta),
layout,
/*unaligned*/ false,
))
}

pub fn raw_const_to_mplace(
Expand Down
13 changes: 13 additions & 0 deletions compiler/rustc_const_eval/src/interpret/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,19 @@ where
base.offset(offset, field_layout, self)
}

/// Converts a repr(simd) value into an array of the right size, such that `project_index`
/// accesses the SIMD elements. Also returns the number of elements.
pub fn project_to_simd<P: Projectable<'tcx, M::Provenance>>(
&self,
base: &P,
) -> InterpResult<'tcx, (P, u64)> {
assert!(base.layout().ty.ty_adt_def().unwrap().repr().simd());
// SIMD types must be newtypes around arrays, so all we have to do is project to their only field.
let array = self.project_field(base, 0)?;
let len = array.len(self)?;
Ok((array, len))
}

fn project_constant_index<P: Projectable<'tcx, M::Provenance>>(
&self,
base: &P,
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_const_eval/src/interpret/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rustc_middle::ty::{
};
use tracing::debug;

use super::{throw_inval, InterpCx, MPlaceTy, MemPlaceMeta, MemoryKind};
use super::{throw_inval, InterpCx, MPlaceTy, MemoryKind};
use crate::const_eval::{CompileTimeInterpCx, CompileTimeMachine, InterpretationResult};

/// Checks whether a type contains generic parameters which must be instantiated.
Expand Down Expand Up @@ -103,5 +103,5 @@ pub(crate) fn create_static_alloc<'tcx>(
assert_eq!(ecx.machine.static_root_ids, None);
ecx.machine.static_root_ids = Some((alloc_id, static_def_id));
assert!(ecx.memory.alloc_map.insert(alloc_id, (MemoryKind::Stack, alloc)).is_none());
Ok(ecx.ptr_with_meta_to_mplace(Pointer::from(alloc_id).into(), MemPlaceMeta::None, layout))
Ok(ecx.ptr_to_mplace(Pointer::from(alloc_id).into(), layout))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this change have anything to do w the pr? it seems unrelated lol

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, i see you changed the privacy of that fn.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I had to add ptr_to_mplace_unaligned for this PR and so I had to make the lower-level ptr_with_meta_to_mplace more flexible and didn't want to expose this somewhat "dangerous" operation too far so I checked if I could make it private, and turns out yes I could.

}
82 changes: 41 additions & 41 deletions src/tools/miri/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
| "bitreverse"
=> {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (op, op_len) = this.project_to_simd(op)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, op_len);

Expand Down Expand Up @@ -200,9 +200,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
use mir::BinOp;

let [left, right] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
Expand Down Expand Up @@ -291,10 +291,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"fma" => {
let [a, b, c] = check_arg_count(args)?;
let (a, a_len) = this.operand_to_simd(a)?;
let (b, b_len) = this.operand_to_simd(b)?;
let (c, c_len) = this.operand_to_simd(c)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (a, a_len) = this.project_to_simd(a)?;
let (b, b_len) = this.project_to_simd(b)?;
let (c, c_len) = this.project_to_simd(c)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, a_len);
assert_eq!(dest_len, b_len);
Expand Down Expand Up @@ -345,7 +345,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
use mir::BinOp;

let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (op, op_len) = this.project_to_simd(op)?;

let imm_from_bool =
|b| ImmTy::from_scalar(Scalar::from_bool(b), this.machine.layouts.bool);
Expand Down Expand Up @@ -408,7 +408,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
use mir::BinOp;

let [op, init] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (op, op_len) = this.project_to_simd(op)?;
let init = this.read_immediate(init)?;

let mir_op = match intrinsic_name {
Expand All @@ -426,10 +426,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"select" => {
let [mask, yes, no] = check_arg_count(args)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (yes, yes_len) = this.operand_to_simd(yes)?;
let (no, no_len) = this.operand_to_simd(no)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let (yes, yes_len) = this.project_to_simd(yes)?;
let (no, no_len) = this.project_to_simd(no)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, mask_len);
assert_eq!(dest_len, yes_len);
Expand All @@ -448,9 +448,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// Variant of `select` that takes a bitmask rather than a "vector of bool".
"select_bitmask" => {
let [mask, yes, no] = check_arg_count(args)?;
let (yes, yes_len) = this.operand_to_simd(yes)?;
let (no, no_len) = this.operand_to_simd(no)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (yes, yes_len) = this.project_to_simd(yes)?;
let (no, no_len) = this.project_to_simd(no)?;
let (dest, dest_len) = this.project_to_simd(dest)?;
let bitmask_len = dest_len.next_multiple_of(8);
if bitmask_len > 64 {
throw_unsup_format!(
Expand Down Expand Up @@ -522,7 +522,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// Converts a "vector of bool" into a bitmask.
"bitmask" => {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (op, op_len) = this.project_to_simd(op)?;
let bitmask_len = op_len.next_multiple_of(8);
if bitmask_len > 64 {
throw_unsup_format!(
Expand Down Expand Up @@ -570,8 +570,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (op, op_len) = this.project_to_simd(op)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, op_len);

Expand Down Expand Up @@ -627,9 +627,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"shuffle_generic" => {
let [left, right] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

let index = generic_args[2]
.expect_const()
Expand Down Expand Up @@ -662,15 +662,15 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"shuffle" => {
let [left, right, index] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

// `index` is an array or a SIMD type
let (index, index_len) = match index.layout.ty.kind() {
// FIXME: remove this once `index` must always be a SIMD vector.
ty::Array(..) => (index.assert_mem_place(), index.len(this)?),
_ => this.operand_to_simd(index)?,
ty::Array(..) => (index.clone(), index.len(this)?),
_ => this.project_to_simd(index)?,
};

assert_eq!(left_len, right_len);
Expand Down Expand Up @@ -699,10 +699,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"gather" => {
let [passthru, ptrs, mask] = check_arg_count(args)?;
let (passthru, passthru_len) = this.operand_to_simd(passthru)?;
let (ptrs, ptrs_len) = this.operand_to_simd(ptrs)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (passthru, passthru_len) = this.project_to_simd(passthru)?;
let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, passthru_len);
assert_eq!(dest_len, ptrs_len);
Expand All @@ -725,9 +725,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"scatter" => {
let [value, ptrs, mask] = check_arg_count(args)?;
let (value, value_len) = this.operand_to_simd(value)?;
let (ptrs, ptrs_len) = this.operand_to_simd(ptrs)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (value, value_len) = this.project_to_simd(value)?;
let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
let (mask, mask_len) = this.project_to_simd(mask)?;

assert_eq!(ptrs_len, value_len);
assert_eq!(ptrs_len, mask_len);
Expand All @@ -745,10 +745,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"masked_load" => {
let [mask, ptr, default] = check_arg_count(args)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let ptr = this.read_pointer(ptr)?;
let (default, default_len) = this.operand_to_simd(default)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (default, default_len) = this.project_to_simd(default)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, mask_len);
assert_eq!(dest_len, default_len);
Expand All @@ -772,9 +772,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"masked_store" => {
let [mask, ptr, vals] = check_arg_count(args)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let ptr = this.read_pointer(ptr)?;
let (vals, vals_len) = this.operand_to_simd(vals)?;
let (vals, vals_len) = this.project_to_simd(vals)?;

assert_eq!(mask_len, vals_len);

Expand Down
4 changes: 2 additions & 2 deletions src/tools/miri/src/shims/foreign_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,8 +903,8 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
name if name.starts_with("llvm.ctpop.v") => {
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (op, op_len) = this.project_to_simd(op)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, op_len);

Expand Down
Loading
Loading