Skip to content

Commit ea279a5

Browse files
Detect (non-raw) borrows of null ZST pointers in CheckNull
1 parent d4bdd1e commit ea279a5

File tree

6 files changed

+76
-29
lines changed

6 files changed

+76
-29
lines changed

compiler/rustc_mir_transform/src/check_alignment.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use rustc_index::IndexVec;
22
use rustc_middle::mir::interpret::Scalar;
3+
use rustc_middle::mir::visit::PlaceContext;
34
use rustc_middle::mir::*;
45
use rustc_middle::ty::{Ty, TyCtxt};
56
use rustc_session::Session;
@@ -44,6 +45,7 @@ fn insert_alignment_check<'tcx>(
4445
tcx: TyCtxt<'tcx>,
4546
pointer: Place<'tcx>,
4647
pointee_ty: Ty<'tcx>,
48+
_context: PlaceContext,
4749
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
4850
stmts: &mut Vec<Statement<'tcx>>,
4951
source_info: SourceInfo,

compiler/rustc_mir_transform/src/check_null.rs

+46-23
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use rustc_index::IndexVec;
2-
use rustc_middle::mir::interpret::Scalar;
2+
use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext};
33
use rustc_middle::mir::*;
44
use rustc_middle::ty::{Ty, TyCtxt};
55
use rustc_session::Session;
@@ -26,6 +26,7 @@ fn insert_null_check<'tcx>(
2626
tcx: TyCtxt<'tcx>,
2727
pointer: Place<'tcx>,
2828
pointee_ty: Ty<'tcx>,
29+
context: PlaceContext,
2930
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
3031
stmts: &mut Vec<Statement<'tcx>>,
3132
source_info: SourceInfo,
@@ -42,30 +43,51 @@ fn insert_null_check<'tcx>(
4243
let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
4344
stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });
4445

45-
// Get size of the pointee (zero-sized reads and writes are allowed).
46-
let rvalue = Rvalue::NullaryOp(NullOp::SizeOf, pointee_ty);
47-
let sizeof_pointee =
48-
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
49-
stmts.push(Statement {
50-
source_info,
51-
kind: StatementKind::Assign(Box::new((sizeof_pointee, rvalue))),
52-
});
53-
54-
// Check that the pointee is not a ZST.
5546
let zero = Operand::Constant(Box::new(ConstOperand {
5647
span: source_info.span,
5748
user_ty: None,
58-
const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), tcx.types.usize),
49+
const_: Const::Val(ConstValue::from_target_usize(0, &tcx), tcx.types.usize),
5950
}));
60-
let is_pointee_no_zst =
61-
local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
62-
stmts.push(Statement {
63-
source_info,
64-
kind: StatementKind::Assign(Box::new((
65-
is_pointee_no_zst,
66-
Rvalue::BinaryOp(BinOp::Ne, Box::new((Operand::Copy(sizeof_pointee), zero.clone()))),
67-
))),
68-
});
51+
52+
let pointee_should_be_checked = match context {
53+
// Borrows are UB even if the pointee is a ZST.
54+
PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow)
55+
| PlaceContext::MutatingUse(MutatingUseContext::Borrow) => {
56+
// Pointer should be checked unconditionally.
57+
Operand::Constant(Box::new(ConstOperand {
58+
span: source_info.span,
59+
user_ty: None,
60+
const_: Const::Val(ConstValue::from_bool(true), tcx.types.bool),
61+
}))
62+
}
63+
// Other usages only are UB if the pointee is not a ZST.
64+
_ => {
65+
let rvalue = Rvalue::NullaryOp(NullOp::SizeOf, pointee_ty);
66+
let sizeof_pointee =
67+
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
68+
stmts.push(Statement {
69+
source_info,
70+
kind: StatementKind::Assign(Box::new((sizeof_pointee, rvalue))),
71+
});
72+
73+
// Check that the pointee is not a ZST.
74+
let is_pointee_not_zst =
75+
local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
76+
stmts.push(Statement {
77+
source_info,
78+
kind: StatementKind::Assign(Box::new((
79+
is_pointee_not_zst,
80+
Rvalue::BinaryOp(
81+
BinOp::Ne,
82+
Box::new((Operand::Copy(sizeof_pointee), zero.clone())),
83+
),
84+
))),
85+
});
86+
87+
// Pointer needs to be checked only if pointee is not a ZST.
88+
Operand::Copy(is_pointee_not_zst)
89+
}
90+
};
6991

7092
// Check whether the pointer is null.
7193
let is_null = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
@@ -77,7 +99,8 @@ fn insert_null_check<'tcx>(
7799
))),
78100
});
79101

80-
// We want to throw an exception if the pointer is null and doesn't point to a ZST.
102+
// We want to throw an exception if the pointer is null and the pointee is not unconditionally
103+
// allowed (which for all non-borrow place uses, is when the pointee is ZST).
81104
let should_throw_exception =
82105
local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
83106
stmts.push(Statement {
@@ -86,7 +109,7 @@ fn insert_null_check<'tcx>(
86109
should_throw_exception,
87110
Rvalue::BinaryOp(
88111
BinOp::BitAnd,
89-
Box::new((Operand::Copy(is_null), Operand::Copy(is_pointee_no_zst))),
112+
Box::new((Operand::Copy(is_null), pointee_should_be_checked)),
90113
),
91114
))),
92115
});

compiler/rustc_mir_transform/src/check_pointers.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,18 @@ pub(crate) enum BorrowCheckMode {
4040
/// success and fail the check otherwise.
4141
/// This utility will insert a terminator block that asserts on the condition
4242
/// and panics on failure.
43-
pub(crate) fn check_pointers<'a, 'tcx, F>(
43+
pub(crate) fn check_pointers<'tcx, F>(
4444
tcx: TyCtxt<'tcx>,
4545
body: &mut Body<'tcx>,
46-
excluded_pointees: &'a [Ty<'tcx>],
46+
excluded_pointees: &[Ty<'tcx>],
4747
on_finding: F,
4848
borrow_check_mode: BorrowCheckMode,
4949
) where
5050
F: Fn(
5151
/* tcx: */ TyCtxt<'tcx>,
5252
/* pointer: */ Place<'tcx>,
5353
/* pointee_ty: */ Ty<'tcx>,
54+
/* context: */ PlaceContext,
5455
/* local_decls: */ &mut IndexVec<Local, LocalDecl<'tcx>>,
5556
/* stmts: */ &mut Vec<Statement<'tcx>>,
5657
/* source_info: */ SourceInfo,
@@ -86,7 +87,7 @@ pub(crate) fn check_pointers<'a, 'tcx, F>(
8687
);
8788
finder.visit_statement(statement, location);
8889

89-
for (local, ty) in finder.into_found_pointers() {
90+
for (local, ty, context) in finder.into_found_pointers() {
9091
debug!("Inserting check for {:?}", ty);
9192
let new_block = split_block(basic_blocks, location);
9293

@@ -98,6 +99,7 @@ pub(crate) fn check_pointers<'a, 'tcx, F>(
9899
tcx,
99100
local,
100101
ty,
102+
context,
101103
local_decls,
102104
&mut block_data.statements,
103105
source_info,
@@ -125,7 +127,7 @@ struct PointerFinder<'a, 'tcx> {
125127
tcx: TyCtxt<'tcx>,
126128
local_decls: &'a mut LocalDecls<'tcx>,
127129
typing_env: ty::TypingEnv<'tcx>,
128-
pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
130+
pointers: Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)>,
129131
excluded_pointees: &'a [Ty<'tcx>],
130132
borrow_check_mode: BorrowCheckMode,
131133
}
@@ -148,7 +150,7 @@ impl<'a, 'tcx> PointerFinder<'a, 'tcx> {
148150
}
149151
}
150152

151-
fn into_found_pointers(self) -> Vec<(Place<'tcx>, Ty<'tcx>)> {
153+
fn into_found_pointers(self) -> Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)> {
152154
self.pointers
153155
}
154156

@@ -211,7 +213,7 @@ impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> {
211213
return;
212214
}
213215

214-
self.pointers.push((pointer, pointee_ty));
216+
self.pointers.push((pointer, pointee_ty, context));
215217

216218
self.super_place(place, context, location);
217219
}
+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
//@ run-fail
2+
//@ compile-flags: -C debug-assertions
3+
//@ error-pattern: null pointer dereference occured
4+
5+
fn main() {
6+
let ptr: *const () = std::ptr::null();
7+
let _ptr: &() = unsafe { &*ptr };
8+
}

tests/ui/mir/null/place_without_read.rs

+1
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ fn main() {
66
let ptr: *const u16 = std::ptr::null();
77
unsafe {
88
let _ = *ptr;
9+
let _ = &raw const *ptr;
910
}
1011
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Make sure that we don't insert a check for places that do not read.
2+
//@ run-pass
3+
//@ compile-flags: -C debug-assertions
4+
5+
fn main() {
6+
let ptr: *const () = std::ptr::null();
7+
unsafe {
8+
let _ = *ptr;
9+
let _ = &raw const *ptr;
10+
}
11+
}

0 commit comments

Comments
 (0)