Skip to content

Commit d1da613

Browse files
committed
Make the unsafe_sizeof_count_copies lint work with more functions
Specifically: - find std::ptr::write_bytes - find std::ptr::swap_nonoverlapping - find std::ptr::slice_from_raw_parts - find std::ptr::slice_from_raw_parts_mut - pointer_primitive::write_bytes
1 parent 336e41d commit d1da613

File tree

4 files changed

+126
-54
lines changed

4 files changed

+126
-54
lines changed

clippy_lints/src/unsafe_sizeof_count_copies.rs

+30-12
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ declare_clippy_lint! {
4141
declare_lint_pass!(UnsafeSizeofCountCopies => [UNSAFE_SIZEOF_COUNT_COPIES]);
4242

4343
fn get_size_of_ty(cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) -> Option<Ty<'tcx>> {
44-
match &expr.kind {
45-
ExprKind::Call(ref count_func, _func_args) => {
44+
match expr.kind {
45+
ExprKind::Call(count_func, _func_args) => {
4646
if_chain! {
4747
if let ExprKind::Path(ref count_func_qpath) = count_func.kind;
4848
if let Some(def_id) = cx.qpath_res(count_func_qpath, count_func.hir_id).opt_def_id();
@@ -56,21 +56,24 @@ fn get_size_of_ty(cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) -> Option<Ty<'tc
5656
}
5757
},
5858
ExprKind::Binary(op, left, right) if BinOpKind::Mul == op.node || BinOpKind::Div == op.node => {
59-
get_size_of_ty(cx, &*left).or_else(|| get_size_of_ty(cx, &*right))
59+
get_size_of_ty(cx, left).or_else(|| get_size_of_ty(cx, right))
6060
},
6161
_ => None,
6262
}
6363
}
6464

6565
fn get_pointee_ty_and_count_expr(cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) -> Option<(Ty<'tcx>, &'tcx Expr<'tcx>)> {
6666
if_chain! {
67-
// Find calls to ptr::copy and copy_nonoverlapping
68-
if let ExprKind::Call(ref func, ref args) = expr.kind;
69-
if let [_src, _dest, count] = &**args;
67+
// Find calls to ptr::{copy, copy_nonoverlapping}
68+
// and ptr::{swap_nonoverlapping, write_bytes},
69+
if let ExprKind::Call(func, args) = expr.kind;
70+
if let [_, _, count] = args;
7071
if let ExprKind::Path(ref func_qpath) = func.kind;
7172
if let Some(def_id) = cx.qpath_res(func_qpath, func.hir_id).opt_def_id();
7273
if match_def_path(cx, def_id, &paths::COPY_NONOVERLAPPING)
73-
|| match_def_path(cx, def_id, &paths::COPY);
74+
|| match_def_path(cx, def_id, &paths::COPY)
75+
|| match_def_path(cx, def_id, &paths::WRITE_BYTES)
76+
|| match_def_path(cx, def_id, &paths::PTR_SWAP_NONOVERLAPPING);
7477

7578
// Get the pointee type
7679
if let Some(pointee_ty) = cx.typeck_results().node_substs(func.hir_id).types().next();
@@ -79,11 +82,11 @@ fn get_pointee_ty_and_count_expr(cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) -
7982
}
8083
};
8184
if_chain! {
82-
// Find calls to copy_{from,to}{,_nonoverlapping}
83-
if let ExprKind::MethodCall(ref method_path, _, ref args, _) = expr.kind;
84-
if let [ptr_self, _, count] = &**args;
85+
// Find calls to copy_{from,to}{,_nonoverlapping} and write_bytes methods
86+
if let ExprKind::MethodCall(method_path, _, args, _) = expr.kind;
87+
if let [ptr_self, _, count] = args;
8588
let method_ident = method_path.ident.as_str();
86-
if method_ident== "copy_to" || method_ident == "copy_from"
89+
if method_ident == "write_bytes" || method_ident == "copy_to" || method_ident == "copy_from"
8790
|| method_ident == "copy_to_nonoverlapping" || method_ident == "copy_from_nonoverlapping";
8891

8992
// Get the pointee type
@@ -93,6 +96,21 @@ fn get_pointee_ty_and_count_expr(cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) -
9396
return Some((pointee_ty, count));
9497
}
9598
};
99+
if_chain! {
100+
// Find calls to ptr::copy and copy_nonoverlapping
101+
if let ExprKind::Call(func, args) = expr.kind;
102+
if let [_data, count] = args;
103+
if let ExprKind::Path(ref func_qpath) = func.kind;
104+
if let Some(def_id) = cx.qpath_res(func_qpath, func.hir_id).opt_def_id();
105+
if match_def_path(cx, def_id, &paths::PTR_SLICE_FROM_RAW_PARTS)
106+
|| match_def_path(cx, def_id, &paths::PTR_SLICE_FROM_RAW_PARTS_MUT);
107+
108+
// Get the pointee type
109+
if let Some(pointee_ty) = cx.typeck_results().node_substs(func.hir_id).types().next();
110+
then {
111+
return Some((pointee_ty, count));
112+
}
113+
};
96114
None
97115
}
98116

@@ -102,7 +120,7 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeSizeofCountCopies {
102120
for the count parameter, it already gets multiplied by the size of the pointed to type";
103121

104122
const LINT_MSG: &str = "unsafe memory copying using a byte count \
105-
(Multiplied by size_of::<T>) instead of a count of T";
123+
(multiplied by size_of/size_of_val::<T>) instead of a count of T";
106124

107125
if_chain! {
108126
// Find calls to unsafe copy functions and get

clippy_lints/src/utils/paths.rs

+4
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ pub const POLL_READY: [&str; 5] = ["core", "task", "poll", "Poll", "Ready"];
101101
pub const PTR_EQ: [&str; 3] = ["core", "ptr", "eq"];
102102
pub const PTR_NULL: [&str; 3] = ["core", "ptr", "null"];
103103
pub const PTR_NULL_MUT: [&str; 3] = ["core", "ptr", "null_mut"];
104+
pub const PTR_SLICE_FROM_RAW_PARTS: [&str; 3] = ["core", "ptr", "slice_from_raw_parts"];
105+
pub const PTR_SLICE_FROM_RAW_PARTS_MUT: [&str; 3] = ["core", "ptr", "slice_from_raw_parts_mut"];
106+
pub const PTR_SWAP_NONOVERLAPPING: [&str; 3] = ["core", "ptr", "swap_nonoverlapping"];
104107
pub const PUSH_STR: [&str; 4] = ["alloc", "string", "String", "push_str"];
105108
pub const RANGE_ARGUMENT_TRAIT: [&str; 3] = ["core", "ops", "RangeBounds"];
106109
pub const RC: [&str; 3] = ["alloc", "rc", "Rc"];
@@ -154,3 +157,4 @@ pub const VEC_NEW: [&str; 4] = ["alloc", "vec", "Vec", "new"];
154157
pub const VEC_RESIZE: [&str; 4] = ["alloc", "vec", "Vec", "resize"];
155158
pub const WEAK_ARC: [&str; 3] = ["alloc", "sync", "Weak"];
156159
pub const WEAK_RC: [&str; 3] = ["alloc", "rc", "Weak"];
160+
pub const WRITE_BYTES: [&str; 3] = ["core", "intrinsics", "write_bytes"];

tests/ui/unsafe_sizeof_count_copies.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#![warn(clippy::unsafe_sizeof_count_copies)]
22

33
use std::mem::{size_of, size_of_val};
4-
use std::ptr::{copy, copy_nonoverlapping};
4+
use std::ptr::{
5+
copy, copy_nonoverlapping, slice_from_raw_parts, slice_from_raw_parts_mut, swap_nonoverlapping, write_bytes,
6+
};
57

68
fn main() {
79
const SIZE: usize = 128;
@@ -22,6 +24,14 @@ fn main() {
2224
unsafe { copy(x.as_ptr(), y.as_mut_ptr(), size_of::<u8>()) };
2325
unsafe { copy(x.as_ptr(), y.as_mut_ptr(), size_of_val(&x[0])) };
2426

27+
unsafe { y.as_mut_ptr().write_bytes(0u8, size_of::<u8>() * SIZE) };
28+
unsafe { write_bytes(y.as_mut_ptr(), 0u8, size_of::<u8>() * SIZE) };
29+
30+
unsafe { swap_nonoverlapping(y.as_mut_ptr(), x.as_mut_ptr(), size_of::<u8>() * SIZE) };
31+
32+
unsafe { slice_from_raw_parts_mut(y.as_mut_ptr(), size_of::<u8>() * SIZE) };
33+
unsafe { slice_from_raw_parts(y.as_ptr(), size_of::<u8>() * SIZE) };
34+
2535
// Count expression involving multiplication of size_of (Should trigger the lint)
2636
unsafe { copy_nonoverlapping(x.as_ptr(), y.as_mut_ptr(), size_of::<u8>() * SIZE) };
2737
unsafe { copy_nonoverlapping(x.as_ptr(), y.as_mut_ptr(), size_of_val(&x[0]) * SIZE) };

0 commit comments

Comments
 (0)