Skip to content

Commit d1e93ca

Browse files
committed
interpret: simplify SIMD type handling
1 parent a9fb00b commit d1e93ca

File tree

15 files changed

+187
-197
lines changed

15 files changed

+187
-197
lines changed

compiler/rustc_const_eval/src/interpret/intrinsics.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
384384
sym::simd_insert => {
385385
let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
386386
let elem = &args[2];
387-
let (input, input_len) = self.operand_to_simd(&args[0])?;
388-
let (dest, dest_len) = self.mplace_to_simd(dest)?;
387+
let (input, input_len) = self.project_to_simd(&args[0])?;
388+
let (dest, dest_len) = self.project_to_simd(dest)?;
389389
assert_eq!(input_len, dest_len, "Return vector length must match input length");
390390
// Bounds are not checked by typeck so we have to do it ourselves.
391391
if index >= input_len {
@@ -406,7 +406,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
406406
}
407407
sym::simd_extract => {
408408
let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
409-
let (input, input_len) = self.operand_to_simd(&args[0])?;
409+
let (input, input_len) = self.project_to_simd(&args[0])?;
410410
// Bounds are not checked by typeck so we have to do it ourselves.
411411
if index >= input_len {
412412
throw_ub_format!(

compiler/rustc_const_eval/src/interpret/operand.rs

-24
Original file line numberDiff line numberDiff line change
@@ -679,30 +679,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
679679
Ok(str)
680680
}
681681

682-
/// Converts a repr(simd) operand into an operand where `place_index` accesses the SIMD elements.
683-
/// Also returns the number of elements.
684-
///
685-
/// Can (but does not always) trigger UB if `op` is uninitialized.
686-
pub fn operand_to_simd(
687-
&self,
688-
op: &OpTy<'tcx, M::Provenance>,
689-
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::Provenance>, u64)> {
690-
// Basically we just transmute this place into an array following simd_size_and_type.
691-
// This only works in memory, but repr(simd) types should never be immediates anyway.
692-
assert!(op.layout.ty.is_simd());
693-
match op.as_mplace_or_imm() {
694-
Left(mplace) => self.mplace_to_simd(&mplace),
695-
Right(imm) => match *imm {
696-
Immediate::Uninit => {
697-
throw_ub!(InvalidUninitBytes(None))
698-
}
699-
Immediate::Scalar(..) | Immediate::ScalarPair(..) => {
700-
bug!("arrays/slices can never have Scalar/ScalarPair layout")
701-
}
702-
},
703-
}
704-
}
705-
706682
/// Read from a local of the current frame.
707683
/// Will not access memory, instead an indirect `Operand` is returned.
708684
///

compiler/rustc_const_eval/src/interpret/place.rs

+22-23
Original file line numberDiff line numberDiff line change
@@ -375,13 +375,15 @@ where
375375
Prov: Provenance,
376376
M: Machine<'tcx, Provenance = Prov>,
377377
{
378-
pub fn ptr_with_meta_to_mplace(
378+
fn ptr_with_meta_to_mplace(
379379
&self,
380380
ptr: Pointer<Option<M::Provenance>>,
381381
meta: MemPlaceMeta<M::Provenance>,
382382
layout: TyAndLayout<'tcx>,
383+
unaligned: bool,
383384
) -> MPlaceTy<'tcx, M::Provenance> {
384-
let misaligned = self.is_ptr_misaligned(ptr, layout.align.abi);
385+
let misaligned =
386+
if unaligned { None } else { self.is_ptr_misaligned(ptr, layout.align.abi) };
385387
MPlaceTy { mplace: MemPlace { ptr, meta, misaligned }, layout }
386388
}
387389

@@ -391,7 +393,16 @@ where
391393
layout: TyAndLayout<'tcx>,
392394
) -> MPlaceTy<'tcx, M::Provenance> {
393395
assert!(layout.is_sized());
394-
self.ptr_with_meta_to_mplace(ptr, MemPlaceMeta::None, layout)
396+
self.ptr_with_meta_to_mplace(ptr, MemPlaceMeta::None, layout, /*unaligned*/ false)
397+
}
398+
399+
pub fn ptr_to_mplace_unaligned(
400+
&self,
401+
ptr: Pointer<Option<M::Provenance>>,
402+
layout: TyAndLayout<'tcx>,
403+
) -> MPlaceTy<'tcx, M::Provenance> {
404+
assert!(layout.is_sized());
405+
self.ptr_with_meta_to_mplace(ptr, MemPlaceMeta::None, layout, /*unaligned*/ true)
395406
}
396407

397408
/// Take a value, which represents a (thin or wide) reference, and make it a place.
@@ -412,7 +423,7 @@ where
412423
// `ref_to_mplace` is called on raw pointers even if they don't actually get dereferenced;
413424
// we hence can't call `size_and_align_of` since that asserts more validity than we want.
414425
let ptr = ptr.to_pointer(self)?;
415-
Ok(self.ptr_with_meta_to_mplace(ptr, meta, layout))
426+
Ok(self.ptr_with_meta_to_mplace(ptr, meta, layout, /*unaligned*/ false))
416427
}
417428

418429
/// Turn a mplace into a (thin or wide) mutable raw pointer, pointing to the same space.
@@ -482,23 +493,6 @@ where
482493
Ok(a)
483494
}
484495

485-
/// Converts a repr(simd) place into a place where `place_index` accesses the SIMD elements.
486-
/// Also returns the number of elements.
487-
pub fn mplace_to_simd(
488-
&self,
489-
mplace: &MPlaceTy<'tcx, M::Provenance>,
490-
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::Provenance>, u64)> {
491-
// Basically we want to transmute this place into an array following simd_size_and_type.
492-
let (len, e_ty) = mplace.layout.ty.simd_size_and_type(*self.tcx);
493-
// Some SIMD types have padding, so `len` many `e_ty` does not cover the entire place.
494-
// Therefore we cannot transmute, and instead we project at offset 0, which side-steps
495-
// the size check.
496-
let array_layout = self.layout_of(Ty::new_array(self.tcx.tcx, e_ty, len))?;
497-
assert!(array_layout.size <= mplace.layout.size);
498-
let mplace = mplace.offset(Size::ZERO, array_layout, self)?;
499-
Ok((mplace, len))
500-
}
501-
502496
/// Turn a local in the current frame into a place.
503497
pub fn local_to_place(
504498
&self,
@@ -983,7 +977,7 @@ where
983977
span_bug!(self.cur_span(), "cannot allocate space for `extern` type, size is not known")
984978
};
985979
let ptr = self.allocate_ptr(size, align, kind)?;
986-
Ok(self.ptr_with_meta_to_mplace(ptr.into(), meta, layout))
980+
Ok(self.ptr_with_meta_to_mplace(ptr.into(), meta, layout, /*unaligned*/ false))
987981
}
988982

989983
pub fn allocate(
@@ -1018,7 +1012,12 @@ where
10181012
};
10191013
let meta = Scalar::from_target_usize(u64::try_from(str.len()).unwrap(), self);
10201014
let layout = self.layout_of(self.tcx.types.str_).unwrap();
1021-
Ok(self.ptr_with_meta_to_mplace(ptr.into(), MemPlaceMeta::Meta(meta), layout))
1015+
Ok(self.ptr_with_meta_to_mplace(
1016+
ptr.into(),
1017+
MemPlaceMeta::Meta(meta),
1018+
layout,
1019+
/*unaligned*/ false,
1020+
))
10221021
}
10231022

10241023
pub fn raw_const_to_mplace(

compiler/rustc_const_eval/src/interpret/projection.rs

+13
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,19 @@ where
244244
base.offset(offset, field_layout, self)
245245
}
246246

247+
/// Converts a repr(simd) value into an array of the right size, such that `project_index`
248+
/// accesses the SIMD elements. Also returns the number of elements.
249+
pub fn project_to_simd<P: Projectable<'tcx, M::Provenance>>(
250+
&self,
251+
base: &P,
252+
) -> InterpResult<'tcx, (P, u64)> {
253+
assert!(base.layout().ty.ty_adt_def().unwrap().repr().simd());
254+
// SIMD types must be newtypes around arrays, so all we have to do is project to their only field.
255+
let array = self.project_field(base, 0)?;
256+
let len = array.len(self)?;
257+
Ok((array, len))
258+
}
259+
247260
fn project_constant_index<P: Projectable<'tcx, M::Provenance>>(
248261
&self,
249262
base: &P,

compiler/rustc_const_eval/src/interpret/util.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use rustc_middle::ty::{
99
};
1010
use tracing::debug;
1111

12-
use super::{throw_inval, InterpCx, MPlaceTy, MemPlaceMeta, MemoryKind};
12+
use super::{throw_inval, InterpCx, MPlaceTy, MemoryKind};
1313
use crate::const_eval::{CompileTimeInterpCx, CompileTimeMachine, InterpretationResult};
1414

1515
/// Checks whether a type contains generic parameters which must be instantiated.
@@ -103,5 +103,5 @@ pub(crate) fn create_static_alloc<'tcx>(
103103
assert_eq!(ecx.machine.static_root_ids, None);
104104
ecx.machine.static_root_ids = Some((alloc_id, static_def_id));
105105
assert!(ecx.memory.alloc_map.insert(alloc_id, (MemoryKind::Stack, alloc)).is_none());
106-
Ok(ecx.ptr_with_meta_to_mplace(Pointer::from(alloc_id).into(), MemPlaceMeta::None, layout))
106+
Ok(ecx.ptr_to_mplace(Pointer::from(alloc_id).into(), layout))
107107
}

src/tools/miri/src/intrinsics/simd.rs

+39-39
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
5050
| "bitreverse"
5151
=> {
5252
let [op] = check_arg_count(args)?;
53-
let (op, op_len) = this.operand_to_simd(op)?;
54-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
53+
let (op, op_len) = this.project_to_simd(op)?;
54+
let (dest, dest_len) = this.project_to_simd(dest)?;
5555

5656
assert_eq!(dest_len, op_len);
5757

@@ -200,9 +200,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
200200
use mir::BinOp;
201201

202202
let [left, right] = check_arg_count(args)?;
203-
let (left, left_len) = this.operand_to_simd(left)?;
204-
let (right, right_len) = this.operand_to_simd(right)?;
205-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
203+
let (left, left_len) = this.project_to_simd(left)?;
204+
let (right, right_len) = this.project_to_simd(right)?;
205+
let (dest, dest_len) = this.project_to_simd(dest)?;
206206

207207
assert_eq!(dest_len, left_len);
208208
assert_eq!(dest_len, right_len);
@@ -291,10 +291,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
291291
}
292292
"fma" => {
293293
let [a, b, c] = check_arg_count(args)?;
294-
let (a, a_len) = this.operand_to_simd(a)?;
295-
let (b, b_len) = this.operand_to_simd(b)?;
296-
let (c, c_len) = this.operand_to_simd(c)?;
297-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
294+
let (a, a_len) = this.project_to_simd(a)?;
295+
let (b, b_len) = this.project_to_simd(b)?;
296+
let (c, c_len) = this.project_to_simd(c)?;
297+
let (dest, dest_len) = this.project_to_simd(dest)?;
298298

299299
assert_eq!(dest_len, a_len);
300300
assert_eq!(dest_len, b_len);
@@ -345,7 +345,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
345345
use mir::BinOp;
346346

347347
let [op] = check_arg_count(args)?;
348-
let (op, op_len) = this.operand_to_simd(op)?;
348+
let (op, op_len) = this.project_to_simd(op)?;
349349

350350
let imm_from_bool =
351351
|b| ImmTy::from_scalar(Scalar::from_bool(b), this.machine.layouts.bool);
@@ -408,7 +408,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
408408
use mir::BinOp;
409409

410410
let [op, init] = check_arg_count(args)?;
411-
let (op, op_len) = this.operand_to_simd(op)?;
411+
let (op, op_len) = this.project_to_simd(op)?;
412412
let init = this.read_immediate(init)?;
413413

414414
let mir_op = match intrinsic_name {
@@ -426,10 +426,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
426426
}
427427
"select" => {
428428
let [mask, yes, no] = check_arg_count(args)?;
429-
let (mask, mask_len) = this.operand_to_simd(mask)?;
430-
let (yes, yes_len) = this.operand_to_simd(yes)?;
431-
let (no, no_len) = this.operand_to_simd(no)?;
432-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
429+
let (mask, mask_len) = this.project_to_simd(mask)?;
430+
let (yes, yes_len) = this.project_to_simd(yes)?;
431+
let (no, no_len) = this.project_to_simd(no)?;
432+
let (dest, dest_len) = this.project_to_simd(dest)?;
433433

434434
assert_eq!(dest_len, mask_len);
435435
assert_eq!(dest_len, yes_len);
@@ -448,9 +448,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
448448
// Variant of `select` that takes a bitmask rather than a "vector of bool".
449449
"select_bitmask" => {
450450
let [mask, yes, no] = check_arg_count(args)?;
451-
let (yes, yes_len) = this.operand_to_simd(yes)?;
452-
let (no, no_len) = this.operand_to_simd(no)?;
453-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
451+
let (yes, yes_len) = this.project_to_simd(yes)?;
452+
let (no, no_len) = this.project_to_simd(no)?;
453+
let (dest, dest_len) = this.project_to_simd(dest)?;
454454
let bitmask_len = dest_len.next_multiple_of(8);
455455
if bitmask_len > 64 {
456456
throw_unsup_format!(
@@ -522,7 +522,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
522522
// Converts a "vector of bool" into a bitmask.
523523
"bitmask" => {
524524
let [op] = check_arg_count(args)?;
525-
let (op, op_len) = this.operand_to_simd(op)?;
525+
let (op, op_len) = this.project_to_simd(op)?;
526526
let bitmask_len = op_len.next_multiple_of(8);
527527
if bitmask_len > 64 {
528528
throw_unsup_format!(
@@ -570,8 +570,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
570570
}
571571
"cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => {
572572
let [op] = check_arg_count(args)?;
573-
let (op, op_len) = this.operand_to_simd(op)?;
574-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
573+
let (op, op_len) = this.project_to_simd(op)?;
574+
let (dest, dest_len) = this.project_to_simd(dest)?;
575575

576576
assert_eq!(dest_len, op_len);
577577

@@ -627,9 +627,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
627627
}
628628
"shuffle_generic" => {
629629
let [left, right] = check_arg_count(args)?;
630-
let (left, left_len) = this.operand_to_simd(left)?;
631-
let (right, right_len) = this.operand_to_simd(right)?;
632-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
630+
let (left, left_len) = this.project_to_simd(left)?;
631+
let (right, right_len) = this.project_to_simd(right)?;
632+
let (dest, dest_len) = this.project_to_simd(dest)?;
633633

634634
let index = generic_args[2]
635635
.expect_const()
@@ -662,9 +662,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
662662
}
663663
"shuffle" => {
664664
let [left, right, index] = check_arg_count(args)?;
665-
let (left, left_len) = this.operand_to_simd(left)?;
666-
let (right, right_len) = this.operand_to_simd(right)?;
667-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
665+
let (left, left_len) = this.project_to_simd(left)?;
666+
let (right, right_len) = this.project_to_simd(right)?;
667+
let (dest, dest_len) = this.project_to_simd(dest)?;
668668

669669
// `index` is an array, not a SIMD type
670670
let ty::Array(_, index_len) = index.layout.ty.kind() else {
@@ -702,10 +702,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
702702
}
703703
"gather" => {
704704
let [passthru, ptrs, mask] = check_arg_count(args)?;
705-
let (passthru, passthru_len) = this.operand_to_simd(passthru)?;
706-
let (ptrs, ptrs_len) = this.operand_to_simd(ptrs)?;
707-
let (mask, mask_len) = this.operand_to_simd(mask)?;
708-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
705+
let (passthru, passthru_len) = this.project_to_simd(passthru)?;
706+
let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
707+
let (mask, mask_len) = this.project_to_simd(mask)?;
708+
let (dest, dest_len) = this.project_to_simd(dest)?;
709709

710710
assert_eq!(dest_len, passthru_len);
711711
assert_eq!(dest_len, ptrs_len);
@@ -728,9 +728,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
728728
}
729729
"scatter" => {
730730
let [value, ptrs, mask] = check_arg_count(args)?;
731-
let (value, value_len) = this.operand_to_simd(value)?;
732-
let (ptrs, ptrs_len) = this.operand_to_simd(ptrs)?;
733-
let (mask, mask_len) = this.operand_to_simd(mask)?;
731+
let (value, value_len) = this.project_to_simd(value)?;
732+
let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
733+
let (mask, mask_len) = this.project_to_simd(mask)?;
734734

735735
assert_eq!(ptrs_len, value_len);
736736
assert_eq!(ptrs_len, mask_len);
@@ -748,10 +748,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
748748
}
749749
"masked_load" => {
750750
let [mask, ptr, default] = check_arg_count(args)?;
751-
let (mask, mask_len) = this.operand_to_simd(mask)?;
751+
let (mask, mask_len) = this.project_to_simd(mask)?;
752752
let ptr = this.read_pointer(ptr)?;
753-
let (default, default_len) = this.operand_to_simd(default)?;
754-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
753+
let (default, default_len) = this.project_to_simd(default)?;
754+
let (dest, dest_len) = this.project_to_simd(dest)?;
755755

756756
assert_eq!(dest_len, mask_len);
757757
assert_eq!(dest_len, default_len);
@@ -775,9 +775,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
775775
}
776776
"masked_store" => {
777777
let [mask, ptr, vals] = check_arg_count(args)?;
778-
let (mask, mask_len) = this.operand_to_simd(mask)?;
778+
let (mask, mask_len) = this.project_to_simd(mask)?;
779779
let ptr = this.read_pointer(ptr)?;
780-
let (vals, vals_len) = this.operand_to_simd(vals)?;
780+
let (vals, vals_len) = this.project_to_simd(vals)?;
781781

782782
assert_eq!(mask_len, vals_len);
783783

src/tools/miri/src/shims/foreign_items.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -903,8 +903,8 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
903903
name if name.starts_with("llvm.ctpop.v") => {
904904
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
905905

906-
let (op, op_len) = this.operand_to_simd(op)?;
907-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
906+
let (op, op_len) = this.project_to_simd(op)?;
907+
let (dest, dest_len) = this.project_to_simd(dest)?;
908908

909909
assert_eq!(dest_len, op_len);
910910

0 commit comments

Comments
 (0)