Skip to content

Commit d5eaa16

Browse files
committed
separate bounds-check from alignment check
1 parent 0682894 commit d5eaa16

File tree

13 files changed

+89
-136
lines changed

13 files changed

+89
-136
lines changed

compiler/rustc_const_eval/src/interpret/intern.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -259,15 +259,15 @@ impl<'rt, 'mir, 'tcx: 'mir, M: CompileTimeMachine<'mir, 'tcx, const_eval::Memory
259259
// to avoid could be expensive: on the potentially larger types, arrays and slices,
260260
// rather than on all aggregates unconditionally.
261261
if matches!(mplace.layout.ty.kind(), ty::Array(..) | ty::Slice(..)) {
262-
let Some((size, align)) = self.ecx.size_and_align_of_mplace(&mplace)? else {
262+
let Some((size, _align)) = self.ecx.size_and_align_of_mplace(&mplace)? else {
263263
// We do the walk if we can't determine the size of the mplace: we may be
264264
// dealing with extern types here in the future.
265265
return Ok(true);
266266
};
267267

268268
// If there is no provenance in this allocation, it does not contain references
269269
// that point to another allocation, and we can avoid the interning walk.
270-
if let Some(alloc) = self.ecx.get_ptr_alloc(mplace.ptr(), size, align)? {
270+
if let Some(alloc) = self.ecx.get_ptr_alloc(mplace.ptr(), size)? {
271271
if !alloc.has_provenance() {
272272
return Ok(false);
273273
}

compiler/rustc_const_eval/src/interpret/intrinsics.rs

+8-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use rustc_middle::ty::layout::{LayoutOf as _, ValidityRequirement};
1313
use rustc_middle::ty::GenericArgsRef;
1414
use rustc_middle::ty::{Ty, TyCtxt};
1515
use rustc_span::symbol::{sym, Symbol};
16-
use rustc_target::abi::{Abi, Align, Primitive, Size};
16+
use rustc_target::abi::{Abi, Primitive, Size};
1717

1818
use super::{
1919
util::ensure_monomorphic_enough, CheckInAllocMsg, ImmTy, InterpCx, Machine, OpTy, PlaceTy,
@@ -349,10 +349,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
349349
// Check that the range between them is dereferenceable ("in-bounds or one past the
350350
// end of the same allocation"). This is like the check in ptr_offset_inbounds.
351351
let min_ptr = if dist >= 0 { b } else { a };
352-
self.check_ptr_access_align(
352+
self.check_ptr_access(
353353
min_ptr,
354354
Size::from_bytes(dist.unsigned_abs()),
355-
Align::ONE,
356355
CheckInAllocMsg::OffsetFromTest,
357356
)?;
358357

@@ -581,10 +580,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
581580
// pointers to be properly aligned (unlike a read/write operation).
582581
let min_ptr = if offset_bytes >= 0 { ptr } else { offset_ptr };
583582
// This call handles checking for integer/null pointers.
584-
self.check_ptr_access_align(
583+
self.check_ptr_access(
585584
min_ptr,
586585
Size::from_bytes(offset_bytes.unsigned_abs()),
587-
Align::ONE,
588586
CheckInAllocMsg::PointerArithmeticTest,
589587
)?;
590588
Ok(offset_ptr)
@@ -613,7 +611,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
613611
let src = self.read_pointer(src)?;
614612
let dst = self.read_pointer(dst)?;
615613

616-
self.mem_copy(src, align, dst, align, size, nonoverlapping)
614+
self.check_ptr_align(src, align)?;
615+
self.check_ptr_align(dst, align)?;
616+
617+
self.mem_copy(src, dst, size, nonoverlapping)
617618
}
618619

619620
pub(crate) fn write_bytes_intrinsic(
@@ -669,7 +670,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
669670
size|
670671
-> InterpResult<'tcx, &[u8]> {
671672
let ptr = this.read_pointer(op)?;
672-
let Some(alloc_ref) = self.get_ptr_alloc(ptr, size, Align::ONE)? else {
673+
let Some(alloc_ref) = self.get_ptr_alloc(ptr, size)? else {
673674
// zero-sized access
674675
return Ok(&[]);
675676
};

compiler/rustc_const_eval/src/interpret/memory.rs

+48-81
Original file line numberDiff line numberDiff line change
@@ -258,14 +258,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
258258
None => self.get_alloc_raw(alloc_id)?.size(),
259259
};
260260
// This will also call the access hooks.
261-
self.mem_copy(
262-
ptr,
263-
Align::ONE,
264-
new_ptr.into(),
265-
Align::ONE,
266-
old_size.min(new_size),
267-
/*nonoverlapping*/ true,
268-
)?;
261+
self.mem_copy(ptr, new_ptr.into(), old_size.min(new_size), /*nonoverlapping*/ true)?;
269262
self.deallocate_ptr(ptr, old_size_and_align, kind)?;
270263

271264
Ok(new_ptr)
@@ -367,12 +360,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
367360
&self,
368361
ptr: Pointer<Option<M::Provenance>>,
369362
size: Size,
370-
align: Align,
371363
) -> InterpResult<'tcx, Option<(AllocId, Size, M::ProvenanceExtra)>> {
372364
self.check_and_deref_ptr(
373365
ptr,
374366
size,
375-
align,
376367
CheckInAllocMsg::MemoryAccessTest,
377368
|alloc_id, offset, prov| {
378369
let (size, align) = self
@@ -382,17 +373,16 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
382373
)
383374
}
384375

385-
/// Check if the given pointer points to live memory of given `size` and `align`.
376+
/// Check if the given pointer points to live memory of the given `size`.
386377
/// The caller can control the error message for the out-of-bounds case.
387378
#[inline(always)]
388-
pub fn check_ptr_access_align(
379+
pub fn check_ptr_access(
389380
&self,
390381
ptr: Pointer<Option<M::Provenance>>,
391382
size: Size,
392-
align: Align,
393383
msg: CheckInAllocMsg,
394384
) -> InterpResult<'tcx> {
395-
self.check_and_deref_ptr(ptr, size, align, msg, |alloc_id, _, _| {
385+
self.check_and_deref_ptr(ptr, size, msg, |alloc_id, _, _| {
396386
let (size, align) = self.get_live_alloc_size_and_align(alloc_id, msg)?;
397387
Ok((size, align, ()))
398388
})?;
@@ -408,7 +398,6 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
408398
&self,
409399
ptr: Pointer<Option<M::Provenance>>,
410400
size: Size,
411-
align: Align,
412401
msg: CheckInAllocMsg,
413402
alloc_size: impl FnOnce(
414403
AllocId,
@@ -423,17 +412,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
423412
if size.bytes() > 0 || addr == 0 {
424413
throw_ub!(DanglingIntPointer(addr, msg));
425414
}
426-
// Must be aligned.
427-
if M::enforce_alignment(self) && align.bytes() > 1 {
428-
self.check_misalign(
429-
Self::offset_misalignment(addr, align),
430-
CheckAlignMsg::AccessedPtr,
431-
)?;
432-
}
433415
None
434416
}
435417
Ok((alloc_id, offset, prov)) => {
436-
let (alloc_size, alloc_align, ret_val) = alloc_size(alloc_id, offset, prov)?;
418+
let (alloc_size, _alloc_align, ret_val) = alloc_size(alloc_id, offset, prov)?;
437419
// Test bounds. This also ensures non-null.
438420
// It is sufficient to check this for the end pointer. Also check for overflow!
439421
if offset.checked_add(size, &self.tcx).map_or(true, |end| end > alloc_size) {
@@ -449,14 +431,6 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
449431
if M::Provenance::OFFSET_IS_ADDR {
450432
assert_ne!(ptr.addr(), Size::ZERO);
451433
}
452-
// Test align. Check this last; if both bounds and alignment are violated
453-
// we want the error to be about the bounds.
454-
if M::enforce_alignment(self) && align.bytes() > 1 {
455-
self.check_misalign(
456-
self.alloc_misalignment(ptr, offset, align, alloc_align),
457-
CheckAlignMsg::AccessedPtr,
458-
)?;
459-
}
460434

461435
// We can still be zero-sized in this branch, in which case we have to
462436
// return `None`.
@@ -465,7 +439,6 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
465439
})
466440
}
467441

468-
#[inline(always)]
469442
pub(super) fn check_misalign(
470443
&self,
471444
misaligned: Option<Misalignment>,
@@ -477,54 +450,55 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
477450
Ok(())
478451
}
479452

480-
#[must_use]
481-
fn offset_misalignment(offset: u64, align: Align) -> Option<Misalignment> {
482-
if offset % align.bytes() == 0 {
483-
None
484-
} else {
485-
// The biggest power of two through which `offset` is divisible.
486-
let offset_pow2 = 1 << offset.trailing_zeros();
487-
Some(Misalignment { has: Align::from_bytes(offset_pow2).unwrap(), required: align })
488-
}
489-
}
490-
491-
#[must_use]
492-
fn alloc_misalignment(
453+
pub(super) fn is_ptr_misaligned(
493454
&self,
494455
ptr: Pointer<Option<M::Provenance>>,
495-
offset: Size,
496456
align: Align,
497-
alloc_align: Align,
498457
) -> Option<Misalignment> {
499-
if M::use_addr_for_alignment_check(self) {
500-
// `use_addr_for_alignment_check` can only be true if `OFFSET_IS_ADDR` is true.
501-
Self::offset_misalignment(ptr.addr().bytes(), align)
502-
} else {
503-
// Check allocation alignment and offset alignment.
504-
if alloc_align.bytes() < align.bytes() {
505-
Some(Misalignment { has: alloc_align, required: align })
458+
if !M::enforce_alignment(self) || align.bytes() == 1 {
459+
return None;
460+
}
461+
462+
#[inline]
463+
fn offset_misalignment(offset: u64, align: Align) -> Option<Misalignment> {
464+
if offset % align.bytes() == 0 {
465+
None
506466
} else {
507-
Self::offset_misalignment(offset.bytes(), align)
467+
// The biggest power of two through which `offset` is divisible.
468+
let offset_pow2 = 1 << offset.trailing_zeros();
469+
Some(Misalignment { has: Align::from_bytes(offset_pow2).unwrap(), required: align })
508470
}
509471
}
510-
}
511472

512-
pub(super) fn is_ptr_misaligned(
513-
&self,
514-
ptr: Pointer<Option<M::Provenance>>,
515-
align: Align,
516-
) -> Option<Misalignment> {
517-
if !M::enforce_alignment(self) {
518-
return None;
519-
}
520473
match self.ptr_try_get_alloc_id(ptr) {
521-
Err(addr) => Self::offset_misalignment(addr, align),
474+
Err(addr) => offset_misalignment(addr, align),
522475
Ok((alloc_id, offset, _prov)) => {
523476
let (_size, alloc_align, _kind) = self.get_alloc_info(alloc_id);
524-
self.alloc_misalignment(ptr, offset, align, alloc_align)
477+
if M::use_addr_for_alignment_check(self) {
478+
// `use_addr_for_alignment_check` can only be true if `OFFSET_IS_ADDR` is true.
479+
offset_misalignment(ptr.addr().bytes(), align)
480+
} else {
481+
// Check allocation alignment and offset alignment.
482+
if alloc_align.bytes() < align.bytes() {
483+
Some(Misalignment { has: alloc_align, required: align })
484+
} else {
485+
offset_misalignment(offset.bytes(), align)
486+
}
487+
}
525488
}
526489
}
527490
}
491+
492+
/// Checks a pointer for misalignment.
493+
///
494+
/// The error assumes this is checking the pointer used directly for an access.
495+
pub fn check_ptr_align(
496+
&self,
497+
ptr: Pointer<Option<M::Provenance>>,
498+
align: Align,
499+
) -> InterpResult<'tcx> {
500+
self.check_misalign(self.is_ptr_misaligned(ptr, align), CheckAlignMsg::AccessedPtr)
501+
}
528502
}
529503

530504
/// Allocation accessors
@@ -629,18 +603,16 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
629603
}
630604
}
631605

632-
/// "Safe" (bounds and align-checked) allocation access.
606+
/// Bounds-checked *but not align-checked* allocation access.
633607
pub fn get_ptr_alloc<'a>(
634608
&'a self,
635609
ptr: Pointer<Option<M::Provenance>>,
636610
size: Size,
637-
align: Align,
638611
) -> InterpResult<'tcx, Option<AllocRef<'a, 'tcx, M::Provenance, M::AllocExtra, M::Bytes>>>
639612
{
640613
let ptr_and_alloc = self.check_and_deref_ptr(
641614
ptr,
642615
size,
643-
align,
644616
CheckInAllocMsg::MemoryAccessTest,
645617
|alloc_id, offset, prov| {
646618
let alloc = self.get_alloc_raw(alloc_id)?;
@@ -701,15 +673,14 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
701673
Ok((alloc, &mut self.machine))
702674
}
703675

704-
/// "Safe" (bounds and align-checked) allocation access.
676+
/// Bounds-checked *but not align-checked* allocation access.
705677
pub fn get_ptr_alloc_mut<'a>(
706678
&'a mut self,
707679
ptr: Pointer<Option<M::Provenance>>,
708680
size: Size,
709-
align: Align,
710681
) -> InterpResult<'tcx, Option<AllocRefMut<'a, 'tcx, M::Provenance, M::AllocExtra, M::Bytes>>>
711682
{
712-
let parts = self.get_ptr_access(ptr, size, align)?;
683+
let parts = self.get_ptr_access(ptr, size)?;
713684
if let Some((alloc_id, offset, prov)) = parts {
714685
let tcx = *self.tcx;
715686
// FIXME: can we somehow avoid looking up the allocation twice here?
@@ -1066,7 +1037,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
10661037
ptr: Pointer<Option<M::Provenance>>,
10671038
size: Size,
10681039
) -> InterpResult<'tcx, &[u8]> {
1069-
let Some(alloc_ref) = self.get_ptr_alloc(ptr, size, Align::ONE)? else {
1040+
let Some(alloc_ref) = self.get_ptr_alloc(ptr, size)? else {
10701041
// zero-sized access
10711042
return Ok(&[]);
10721043
};
@@ -1092,7 +1063,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
10921063
assert_eq!(lower, len, "can only write iterators with a precise length");
10931064

10941065
let size = Size::from_bytes(len);
1095-
let Some(alloc_ref) = self.get_ptr_alloc_mut(ptr, size, Align::ONE)? else {
1066+
let Some(alloc_ref) = self.get_ptr_alloc_mut(ptr, size)? else {
10961067
// zero-sized access
10971068
assert_matches!(src.next(), None, "iterator said it was empty but returned an element");
10981069
return Ok(());
@@ -1117,29 +1088,25 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
11171088
pub fn mem_copy(
11181089
&mut self,
11191090
src: Pointer<Option<M::Provenance>>,
1120-
src_align: Align,
11211091
dest: Pointer<Option<M::Provenance>>,
1122-
dest_align: Align,
11231092
size: Size,
11241093
nonoverlapping: bool,
11251094
) -> InterpResult<'tcx> {
1126-
self.mem_copy_repeatedly(src, src_align, dest, dest_align, size, 1, nonoverlapping)
1095+
self.mem_copy_repeatedly(src, dest, size, 1, nonoverlapping)
11271096
}
11281097

11291098
pub fn mem_copy_repeatedly(
11301099
&mut self,
11311100
src: Pointer<Option<M::Provenance>>,
1132-
src_align: Align,
11331101
dest: Pointer<Option<M::Provenance>>,
1134-
dest_align: Align,
11351102
size: Size,
11361103
num_copies: u64,
11371104
nonoverlapping: bool,
11381105
) -> InterpResult<'tcx> {
11391106
let tcx = self.tcx;
11401107
// We need to do our own bounds-checks.
1141-
let src_parts = self.get_ptr_access(src, size, src_align)?;
1142-
let dest_parts = self.get_ptr_access(dest, size * num_copies, dest_align)?; // `Size` multiplication
1108+
let src_parts = self.get_ptr_access(src, size)?;
1109+
let dest_parts = self.get_ptr_access(dest, size * num_copies)?; // `Size` multiplication
11431110

11441111
// FIXME: we look up both allocations twice here, once before for the `check_ptr_access`
11451112
// and once below to get the underlying `&[mut] Allocation`.

compiler/rustc_const_eval/src/interpret/place.rs

+3-10
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ where
460460
.unwrap_or((mplace.layout.size, mplace.layout.align.abi));
461461
// We check alignment separately, and *after* checking everything else.
462462
// If an access is both OOB and misaligned, we want to see the bounds error.
463-
let a = self.get_ptr_alloc(mplace.ptr(), size, Align::ONE)?;
463+
let a = self.get_ptr_alloc(mplace.ptr(), size)?;
464464
self.check_misalign(mplace.mplace.misaligned, CheckAlignMsg::BasedOn)?;
465465
Ok(a)
466466
}
@@ -478,7 +478,7 @@ where
478478
// If an access is both OOB and misaligned, we want to see the bounds error.
479479
// However we have to call `check_misalign` first to make the borrow checker happy.
480480
let misalign_err = self.check_misalign(mplace.mplace.misaligned, CheckAlignMsg::BasedOn);
481-
let a = self.get_ptr_alloc_mut(mplace.ptr(), size, Align::ONE)?;
481+
let a = self.get_ptr_alloc_mut(mplace.ptr(), size)?;
482482
misalign_err?;
483483
Ok(a)
484484
}
@@ -873,14 +873,7 @@ where
873873
// non-overlapping.)
874874
// We check alignment separately, and *after* checking everything else.
875875
// If an access is both OOB and misaligned, we want to see the bounds error.
876-
self.mem_copy(
877-
src.ptr(),
878-
Align::ONE,
879-
dest.ptr(),
880-
Align::ONE,
881-
dest_size,
882-
/*nonoverlapping*/ true,
883-
)?;
876+
self.mem_copy(src.ptr(), dest.ptr(), dest_size, /*nonoverlapping*/ true)?;
884877
self.check_misalign(src.mplace.misaligned, CheckAlignMsg::BasedOn)?;
885878
self.check_misalign(dest.mplace.misaligned, CheckAlignMsg::BasedOn)?;
886879
Ok(())

0 commit comments

Comments
 (0)