Skip to content

Commit 4d11055

Browse files
andyleisersonLiamoluckoErichDonGubler
committed
[msl-out] Fix ReadZeroSkipWrite bounds check mode for pointer arguments
Fixes gfx-rs#4541 -- Co-authored-by: Liam Murphy <[email protected]> Co-Authored-By: Erich Gubler <[email protected]>
1 parent a954f13 commit 4d11055

16 files changed

+971
-3
lines changed

naga/src/back/msl/writer.rs

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,13 @@ trait NameKeyExt {
611611
FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
612612
}
613613
}
614+
615+
fn oob_local(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
616+
match origin {
617+
FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
618+
FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
619+
}
620+
}
614621
}
615622

616623
impl NameKeyExt for NameKey {}
@@ -721,6 +728,11 @@ impl<'a> ExpressionContext<'a> {
721728
index::bounds_check_iter(chain, self.module, self.function, self.info)
722729
}
723730

731+
/// See docs for `proc::index::oob_locals`.
732+
fn oob_locals(&self) -> FastHashSet<Handle<crate::Type>> {
733+
index::oob_locals(self.module, self.function, self.info, self.policies)
734+
}
735+
724736
fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
725737
match self.function.expressions[expr_handle] {
726738
crate::Expression::AccessIndex { base, index } => {
@@ -928,7 +940,18 @@ impl<W: Write> Writer<W> {
928940
Ok(())
929941
}
930942

943+
/// Writes the local variables of the given function, as well as any extra
944+
/// out-of-bounds locals that are needed.
945+
///
946+
/// The names of the OOB locals are also added to `self.names` at the same
947+
/// time.
931948
fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
949+
let oob_locals = context.oob_locals();
950+
for &ty in oob_locals.iter() {
951+
let name_key = NameKey::oob_local(context.origin, ty);
952+
self.names.insert(name_key, self.namer.call("oob"));
953+
}
954+
932955
for (name_key, ty, init) in context
933956
.function
934957
.local_variables
@@ -937,6 +960,10 @@ impl<W: Write> Writer<W> {
937960
let name_key = NameKey::local(context.origin, local_handle);
938961
(name_key, local.ty, local.init)
939962
})
963+
.chain(oob_locals.iter().map(|&ty| {
964+
let name_key = NameKey::oob_local(context.origin, ty);
965+
(name_key, ty, None)
966+
}))
940967
{
941968
let ty_name = TypeContext {
942969
handle: ty,
@@ -1741,7 +1768,39 @@ impl<W: Write> Writer<W> {
17411768
{
17421769
write!(self.out, " ? ")?;
17431770
self.put_access_chain(expr_handle, policy, context)?;
1744-
write!(self.out, " : DefaultConstructible()")?;
1771+
write!(self.out, " : ")?;
1772+
1773+
if context.resolve_type(base).pointer_space().is_some() {
1774+
// We can't just use `DefaultConstructible` if this is a pointer.
1775+
// Instead, we create a dummy local variable to serve as pointer
1776+
// target if the access is out of bounds.
1777+
let result_ty = context.info[expr_handle]
1778+
.ty
1779+
.inner_with(&context.module.types)
1780+
.pointer_base_type();
1781+
let result_ty_handle = match result_ty {
1782+
Some(TypeResolution::Handle(handle)) => handle,
1783+
Some(TypeResolution::Value(_)) => {
1784+
// I don't have a succinct argument why this is the case.
1785+
// It's really up to what the source language lets you do
1786+
// with pointers. Note that regular loads and stores don't
1787+
// take this path -- they have dedicated code in `put_load`
1788+
// and `put_store`.
1789+
unreachable!(
1790+
"Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1791+
);
1792+
}
1793+
None => {
1794+
unreachable!(
1795+
"Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1796+
)
1797+
}
1798+
};
1799+
let name_key = NameKey::oob_local(context.origin, result_ty_handle);
1800+
self.out.write_str(&self.names[&name_key])?;
1801+
} else {
1802+
write!(self.out, "DefaultConstructible()")?;
1803+
}
17451804

17461805
if !is_scoped {
17471806
write!(self.out, ")")?;

naga/src/proc/index.rs

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
Definitions for index bounds checking.
33
*/
44

5-
use core::iter;
5+
use core::iter::{self, zip};
66

77
use crate::arena::{Handle, HandleSet, UniqueArena};
8-
use crate::valid;
8+
use crate::{valid, FastHashSet};
99

1010
/// How should code generated by Naga do bounds checks?
1111
///
@@ -389,6 +389,61 @@ pub fn bounds_check_iter<'a>(
389389
})
390390
}
391391

392+
/// Returns all the types which we need out-of-bounds locals for; that is,
393+
/// all of the types which the code might attempt to get an out-of-bounds
394+
/// pointer to, in which case we yield a pointer to the out-of-bounds local
395+
/// of the correct type.
396+
pub fn oob_locals(
397+
module: &crate::Module,
398+
function: &crate::Function,
399+
info: &valid::FunctionInfo,
400+
policies: BoundsCheckPolicies,
401+
) -> FastHashSet<Handle<crate::Type>> {
402+
let mut result = FastHashSet::default();
403+
404+
if policies.index != BoundsCheckPolicy::ReadZeroSkipWrite {
405+
return result;
406+
}
407+
408+
for statement in &function.body {
409+
// The only situation in which we end up actually needing to create an
410+
// out-of-bounds pointer is when passing one to a function.
411+
//
412+
// This is because pointers are never baked; they're just inlined everywhere
413+
// they're used. That means that loads can just return 0, and stores can just do
414+
// nothing; functions are the only case where you actually *have* to produce a
415+
// pointer.
416+
if let crate::Statement::Call {
417+
function: callee,
418+
ref arguments,
419+
..
420+
} = *statement
421+
{
422+
// Now go through the arguments of the function looking for pointers which need bounds checks.
423+
for (arg_info, &arg) in zip(&module.functions[callee].arguments, arguments) {
424+
match module.types[arg_info.ty].inner {
425+
crate::TypeInner::ValuePointer { .. } => {
426+
// `ValuePointer`s should only ever be used when resolving the types of
427+
// expressions, since the arena can no longer be modified at that point; things
428+
// in the arena should always use proper `Pointer`s.
429+
unreachable!("`ValuePointer` found in arena")
430+
}
431+
crate::TypeInner::Pointer { base, .. } => {
432+
if bounds_check_iter(arg, module, function, info)
433+
.next()
434+
.is_some()
435+
{
436+
result.insert(base);
437+
}
438+
}
439+
_ => continue,
440+
};
441+
}
442+
}
443+
}
444+
result
445+
}
446+
392447
impl GuardedIndex {
393448
/// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible.
394449
///

naga/src/proc/namer.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,17 @@ pub enum NameKey {
2121
Function(Handle<crate::Function>),
2222
FunctionArgument(Handle<crate::Function>, u32),
2323
FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>),
24+
25+
/// A local variable used by ReadZeroSkipWrite bounds-check policy
26+
/// when it needs to produce a pointer-typed result for an OOB access.
27+
FunctionOobLocal(Handle<crate::Function>, Handle<crate::Type>),
28+
2429
EntryPoint(EntryPointIndex),
2530
EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>),
2631
EntryPointArgument(EntryPointIndex, u32),
32+
33+
/// Entry point version of `FunctionOobLocal`.
34+
EntryPointOobLocal(EntryPointIndex, Handle<crate::Type>),
2735
}
2836

2937
/// This processor assigns names to all the things in a module
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
targets = "METAL"
2+
3+
[bounds_check_policies]
4+
index = "Restrict"
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
fn takes_ptr(p: ptr<function, i32>) {}
2+
fn takes_array_ptr(p: ptr<function, array<i32, 4>>) {}
3+
fn takes_vec_ptr(p: ptr<function, vec2<i32>>) {}
4+
fn takes_mat_ptr(p: ptr<function, mat2x2<f32>>) {}
5+
6+
fn local_var(i: u32) {
7+
var arr = array(1, 2, 3, 4);
8+
takes_ptr(&arr[i]);
9+
takes_array_ptr(&arr);
10+
11+
}
12+
13+
fn mat_vec_ptrs(
14+
pv: ptr<function, array<vec2<i32>, 4>>,
15+
pm: ptr<function, array<mat2x2<f32>, 4>>,
16+
i: u32,
17+
) {
18+
takes_vec_ptr(&pv[i]);
19+
takes_mat_ptr(&pm[i]);
20+
}
21+
22+
fn argument(v: ptr<function, array<i32, 4>>, i: u32) {
23+
takes_ptr(&v[i]);
24+
}
25+
26+
fn argument_nested_x2(v: ptr<function, array<array<i32, 4>, 4>>, i: u32, j: u32) {
27+
takes_ptr(&v[i][j]);
28+
29+
// Mixing compile and runtime bounds checks
30+
takes_ptr(&v[i][0]);
31+
takes_ptr(&v[0][j]);
32+
33+
takes_array_ptr(&v[i]);
34+
}
35+
36+
fn argument_nested_x3(v: ptr<function, array<array<array<i32, 4>, 4>, 4>>, i: u32, j: u32) {
37+
takes_ptr(&v[i][0][j]);
38+
takes_ptr(&v[i][j][0]);
39+
takes_ptr(&v[0][i][j]);
40+
}
41+
42+
fn index_from_self(v: ptr<function, array<i32, 4>>, i: u32) {
43+
takes_ptr(&v[v[i]]);
44+
}
45+
46+
fn local_var_from_arg(a: array<i32, 4>, i: u32) {
47+
var b = a;
48+
takes_ptr(&b[i]);
49+
}
50+
51+
fn let_binding(a: ptr<function, array<i32, 4>>, i: u32) {
52+
let p0 = &a[i];
53+
takes_ptr(p0);
54+
55+
let p1 = &a[0];
56+
takes_ptr(p1);
57+
}
58+
59+
// Runtime-sized arrays can only appear in storage buffers, while (in the base
60+
// language) pointers can only appear in function or private space, so there
61+
// is no interaction to test.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
targets = "METAL"
2+
3+
[bounds_check_policies]
4+
index = "ReadZeroSkipWrite"
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
fn takes_ptr(p: ptr<function, i32>) {}
2+
fn takes_array_ptr(p: ptr<function, array<i32, 4>>) {}
3+
fn takes_vec_ptr(p: ptr<function, vec2<i32>>) {}
4+
fn takes_mat_ptr(p: ptr<function, mat2x2<f32>>) {}
5+
6+
fn local_var(i: u32) {
7+
var arr = array(1, 2, 3, 4);
8+
takes_ptr(&arr[i]);
9+
takes_array_ptr(&arr);
10+
11+
}
12+
13+
fn mat_vec_ptrs(
14+
pv: ptr<function, array<vec2<i32>, 4>>,
15+
pm: ptr<function, array<mat2x2<f32>, 4>>,
16+
i: u32,
17+
) {
18+
takes_vec_ptr(&pv[i]);
19+
takes_mat_ptr(&pm[i]);
20+
}
21+
22+
fn argument(v: ptr<function, array<i32, 4>>, i: u32) {
23+
takes_ptr(&v[i]);
24+
}
25+
26+
fn argument_nested_x2(v: ptr<function, array<array<i32, 4>, 4>>, i: u32, j: u32) {
27+
takes_ptr(&v[i][j]);
28+
29+
// Mixing compile and runtime bounds checks
30+
takes_ptr(&v[i][0]);
31+
takes_ptr(&v[0][j]);
32+
33+
takes_array_ptr(&v[i]);
34+
}
35+
36+
fn argument_nested_x3(v: ptr<function, array<array<array<i32, 4>, 4>, 4>>, i: u32, j: u32) {
37+
takes_ptr(&v[i][0][j]);
38+
takes_ptr(&v[i][j][0]);
39+
takes_ptr(&v[0][i][j]);
40+
}
41+
42+
fn index_from_self(v: ptr<function, array<i32, 4>>, i: u32) {
43+
takes_ptr(&v[v[i]]);
44+
}
45+
46+
fn local_var_from_arg(a: array<i32, 4>, i: u32) {
47+
var b = a;
48+
takes_ptr(&b[i]);
49+
}
50+
51+
fn let_binding(a: ptr<function, array<i32, 4>>, i: u32) {
52+
let p0 = &a[i];
53+
takes_ptr(p0);
54+
55+
let p1 = &a[0];
56+
takes_ptr(p1);
57+
}
58+
59+
// Runtime-sized arrays can only appear in storage buffers, while (in the base
60+
// language) pointers can only appear in function or private space, so there
61+
// is no interaction to test.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
targets = "METAL | GLSL | HLSL | WGSL"
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
@compute @workgroup_size(1)
2+
fn main() {}
3+
4+
fn takes_ptr(p: ptr<function, i32>) {}
5+
fn takes_array_ptr(p: ptr<function, array<i32, 4>>) {}
6+
fn takes_vec_ptr(p: ptr<function, vec2<i32>>) {}
7+
fn takes_mat_ptr(p: ptr<function, mat2x2<f32>>) {}
8+
9+
fn local_var(i: u32) {
10+
var arr = array(1, 2, 3, 4);
11+
takes_ptr(&arr[i]);
12+
takes_array_ptr(&arr);
13+
14+
}
15+
16+
fn mat_vec_ptrs(
17+
pv: ptr<function, array<vec2<i32>, 4>>,
18+
pm: ptr<function, array<mat2x2<f32>, 4>>,
19+
i: u32,
20+
) {
21+
takes_vec_ptr(&pv[i]);
22+
takes_mat_ptr(&pm[i]);
23+
}
24+
25+
fn argument(v: ptr<function, array<i32, 4>>, i: u32) {
26+
takes_ptr(&v[i]);
27+
}
28+
29+
fn argument_nested_x2(v: ptr<function, array<array<i32, 4>, 4>>, i: u32, j: u32) {
30+
takes_ptr(&v[i][j]);
31+
32+
// Mixing compile and runtime bounds checks
33+
takes_ptr(&v[i][0]);
34+
takes_ptr(&v[0][j]);
35+
36+
takes_array_ptr(&v[i]);
37+
}
38+
39+
fn argument_nested_x3(v: ptr<function, array<array<array<i32, 4>, 4>, 4>>, i: u32, j: u32) {
40+
takes_ptr(&v[i][0][j]);
41+
takes_ptr(&v[i][j][0]);
42+
takes_ptr(&v[0][i][j]);
43+
}
44+
45+
fn index_from_self(v: ptr<function, array<i32, 4>>, i: u32) {
46+
takes_ptr(&v[v[i]]);
47+
}
48+
49+
fn local_var_from_arg(a: array<i32, 4>, i: u32) {
50+
var b = a;
51+
takes_ptr(&b[i]);
52+
}
53+
54+
fn let_binding(a: ptr<function, array<i32, 4>>, i: u32) {
55+
let p0 = &a[i];
56+
takes_ptr(p0);
57+
58+
let p1 = &a[0];
59+
takes_ptr(p1);
60+
}
61+
62+
// Runtime-sized arrays can only appear in storage buffers, while (in the base
63+
// language) pointers can only appear in function or private space, so there
64+
// is no interaction to test.

0 commit comments

Comments
 (0)