Skip to content

Commit 25b32a3

Browse files
andyleiserson authored and Vecvec committed
[msl-out] Fix ReadZeroSkipWrite bounds check mode for pointer arguments
Fixes gfx-rs#4541
Co-authored-by: Liam Murphy <[email protected]>
Co-Authored-By: Erich Gubler <[email protected]>
1 parent dec9de3 commit 25b32a3

17 files changed

+994 -4 lines changed

Diff for: naga/src/back/msl/mod.rs

+14
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ holding the result.
2929
[msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
3030
[all-atom]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
3131
32+
## Pointer-typed bounds-checked expressions and OOB locals
33+
34+
MSL (unlike HLSL and GLSL) has native support for pointer-typed function
35+
arguments. When the [`BoundsCheckPolicy`] is `ReadZeroSkipWrite` and an
36+
out-of-bounds index expression is used for such an argument, our strategy is to
37+
pass a pointer to a dummy variable. These dummy variables are called "OOB
38+
locals". We emit at most one OOB local per function for each type, since all
39+
expressions producing a result of that type can share the same OOB local. (Note
40+
that the OOB local mechanism is not actually implementing "skip write", nor even
41+
"read zero" in some cases of read-after-write, but doing so would require
42+
additional effort and the difference is unlikely to matter.)
43+
44+
[`BoundsCheckPolicy`]: crate::proc::BoundsCheckPolicy
45+
3246
*/
3347

3448
use alloc::{

Diff for: naga/src/back/msl/writer.rs

+67 -2
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,17 @@ trait NameKeyExt {
612612
FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
613613
}
614614
}
615+
616+
/// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
617+
/// policy when it needs to produce a pointer-typed result for an OOB access. These
618+
/// are unique per accessed type, so the second argument is a type handle. See docs
619+
/// for [`crate::back::msl`].
620+
fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
621+
match origin {
622+
FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
623+
FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
624+
}
625+
}
615626
}
616627

617628
impl NameKeyExt for NameKey {}
@@ -722,6 +733,11 @@ impl<'a> ExpressionContext<'a> {
722733
index::bounds_check_iter(chain, self.module, self.function, self.info)
723734
}
724735

736+
/// See docs for [`proc::index::oob_local_types`].
737+
fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
738+
index::oob_local_types(self.module, self.function, self.info, self.policies)
739+
}
740+
725741
fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
726742
match self.function.expressions[expr_handle] {
727743
crate::Expression::AccessIndex { base, index } => {
@@ -929,8 +945,18 @@ impl<W: Write> Writer<W> {
929945
Ok(())
930946
}
931947

932-
/// Writes the local variables of the given function.
948+
/// Writes the local variables of the given function, as well as any extra
949+
/// out-of-bounds locals that are needed.
950+
///
951+
/// The names of the OOB locals are also added to `self.names` at the same
952+
/// time.
933953
fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
954+
let oob_local_types = context.oob_local_types();
955+
for &ty in oob_local_types.iter() {
956+
let name_key = NameKey::oob_local_for_type(context.origin, ty);
957+
self.names.insert(name_key, self.namer.call("oob"));
958+
}
959+
934960
for (name_key, ty, init) in context
935961
.function
936962
.local_variables
@@ -939,6 +965,10 @@ impl<W: Write> Writer<W> {
939965
let name_key = NameKey::local(context.origin, local_handle);
940966
(name_key, local.ty, local.init)
941967
})
968+
.chain(oob_local_types.iter().map(|&ty| {
969+
let name_key = NameKey::oob_local_for_type(context.origin, ty);
970+
(name_key, ty, None)
971+
}))
942972
{
943973
let ty_name = TypeContext {
944974
handle: ty,
@@ -1761,7 +1791,42 @@ impl<W: Write> Writer<W> {
17611791
{
17621792
write!(self.out, " ? ")?;
17631793
self.put_access_chain(expr_handle, policy, context)?;
1764-
write!(self.out, " : DefaultConstructible()")?;
1794+
write!(self.out, " : ")?;
1795+
1796+
if context.resolve_type(base).pointer_space().is_some() {
1797+
// We can't just use `DefaultConstructible` if this is a pointer.
1798+
// Instead, we create a dummy local variable to serve as pointer
1799+
// target if the access is out of bounds.
1800+
let result_ty = context.info[expr_handle]
1801+
.ty
1802+
.inner_with(&context.module.types)
1803+
.pointer_base_type();
1804+
let result_ty_handle = match result_ty {
1805+
Some(TypeResolution::Handle(handle)) => handle,
1806+
Some(TypeResolution::Value(_)) => {
1807+
// As long as the result of a pointer access expression is
1808+
// passed to a function or stored in a let binding, the
1809+
// type will be in the arena. If additional uses of
1810+
// pointers become valid, this assumption might no longer
1811+
// hold. Note that the LHS of a load or store doesn't
1812+
// take this path -- there is dedicated code in `put_load`
1813+
// and `put_store`.
1814+
unreachable!(
1815+
"Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1816+
);
1817+
}
1818+
None => {
1819+
unreachable!(
1820+
"Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1821+
)
1822+
}
1823+
};
1824+
let name_key =
1825+
NameKey::oob_local_for_type(context.origin, result_ty_handle);
1826+
self.out.write_str(&self.names[&name_key])?;
1827+
} else {
1828+
write!(self.out, "DefaultConstructible()")?;
1829+
}
17651830

17661831
if !is_scoped {
17671832
write!(self.out, ")")?;

Diff for: naga/src/proc/index.rs

+57 -2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
Definitions for index bounds checking.
33
*/
44

5-
use core::iter;
5+
use core::iter::{self, zip};
66

77
use crate::arena::{Handle, HandleSet, UniqueArena};
8-
use crate::valid;
8+
use crate::{valid, FastHashSet};
99

1010
/// How should code generated by Naga do bounds checks?
1111
///
@@ -389,6 +389,61 @@ pub(crate) fn bounds_check_iter<'a>(
389389
})
390390
}
391391

392+
/// Returns all the types which we need out-of-bounds locals for; that is,
393+
/// all of the types which the code might attempt to get an out-of-bounds
394+
/// pointer to, in which case we yield a pointer to the out-of-bounds local
395+
/// of the correct type.
396+
pub fn oob_local_types(
397+
module: &crate::Module,
398+
function: &crate::Function,
399+
info: &valid::FunctionInfo,
400+
policies: BoundsCheckPolicies,
401+
) -> FastHashSet<Handle<crate::Type>> {
402+
let mut result = FastHashSet::default();
403+
404+
if policies.index != BoundsCheckPolicy::ReadZeroSkipWrite {
405+
return result;
406+
}
407+
408+
for statement in &function.body {
409+
// The only situation in which we end up actually needing to create an
410+
// out-of-bounds pointer is when passing one to a function.
411+
//
412+
// This is because pointers are never baked; they're just inlined everywhere
413+
// they're used. That means that loads can just return 0, and stores can just do
414+
// nothing; functions are the only case where you actually *have* to produce a
415+
// pointer.
416+
if let crate::Statement::Call {
417+
function: callee,
418+
ref arguments,
419+
..
420+
} = *statement
421+
{
422+
// Now go through the arguments of the function looking for pointers which need bounds checks.
423+
for (arg_info, &arg) in zip(&module.functions[callee].arguments, arguments) {
424+
match module.types[arg_info.ty].inner {
425+
crate::TypeInner::ValuePointer { .. } => {
426+
// `ValuePointer`s should only ever be used when resolving the types of
427+
// expressions, since the arena can no longer be modified at that point; things
428+
// in the arena should always use proper `Pointer`s.
429+
unreachable!("`ValuePointer` found in arena")
430+
}
431+
crate::TypeInner::Pointer { base, .. } => {
432+
if bounds_check_iter(arg, module, function, info)
433+
.next()
434+
.is_some()
435+
{
436+
result.insert(base);
437+
}
438+
}
439+
_ => continue,
440+
};
441+
}
442+
}
443+
}
444+
result
445+
}
446+
392447
impl GuardedIndex {
393448
/// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible.
394449
///

Diff for: naga/src/proc/namer.rs

+10
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,19 @@ pub enum NameKey {
2121
Function(Handle<crate::Function>),
2222
FunctionArgument(Handle<crate::Function>, u32),
2323
FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>),
24+
25+
/// A local variable used by ReadZeroSkipWrite bounds-check policy
26+
/// when it needs to produce a pointer-typed result for an OOB access.
27+
/// These are unique per accessed type, so the second element is a
28+
/// type handle. See docs for [`crate::back::msl`].
29+
FunctionOobLocal(Handle<crate::Function>, Handle<crate::Type>),
30+
2431
EntryPoint(EntryPointIndex),
2532
EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>),
2633
EntryPointArgument(EntryPointIndex, u32),
34+
35+
/// Entry point version of `FunctionOobLocal`.
36+
EntryPointOobLocal(EntryPointIndex, Handle<crate::Type>),
2737
}
2838

2939
/// This processor assigns names to all the things in a module
+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
targets = "METAL"
2+
3+
[bounds_check_policies]
4+
index = "Restrict"
+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
fn takes_ptr(p: ptr<function, i32>) {}
2+
fn takes_array_ptr(p: ptr<function, array<i32, 4>>) {}
3+
fn takes_vec_ptr(p: ptr<function, vec2<i32>>) {}
4+
fn takes_mat_ptr(p: ptr<function, mat2x2<f32>>) {}
5+
6+
fn local_var(i: u32) {
7+
var arr = array(1, 2, 3, 4);
8+
takes_ptr(&arr[i]);
9+
takes_array_ptr(&arr);
10+
11+
}
12+
13+
fn mat_vec_ptrs(
14+
pv: ptr<function, array<vec2<i32>, 4>>,
15+
pm: ptr<function, array<mat2x2<f32>, 4>>,
16+
i: u32,
17+
) {
18+
takes_vec_ptr(&pv[i]);
19+
takes_mat_ptr(&pm[i]);
20+
}
21+
22+
fn argument(v: ptr<function, array<i32, 4>>, i: u32) {
23+
takes_ptr(&v[i]);
24+
}
25+
26+
fn argument_nested_x2(v: ptr<function, array<array<i32, 4>, 4>>, i: u32, j: u32) {
27+
takes_ptr(&v[i][j]);
28+
29+
// Mixing compile and runtime bounds checks
30+
takes_ptr(&v[i][0]);
31+
takes_ptr(&v[0][j]);
32+
33+
takes_array_ptr(&v[i]);
34+
}
35+
36+
fn argument_nested_x3(v: ptr<function, array<array<array<i32, 4>, 4>, 4>>, i: u32, j: u32) {
37+
takes_ptr(&v[i][0][j]);
38+
takes_ptr(&v[i][j][0]);
39+
takes_ptr(&v[0][i][j]);
40+
}
41+
42+
fn index_from_self(v: ptr<function, array<i32, 4>>, i: u32) {
43+
takes_ptr(&v[v[i]]);
44+
}
45+
46+
fn local_var_from_arg(a: array<i32, 4>, i: u32) {
47+
var b = a;
48+
takes_ptr(&b[i]);
49+
}
50+
51+
fn let_binding(a: ptr<function, array<i32, 4>>, i: u32) {
52+
let p0 = &a[i];
53+
takes_ptr(p0);
54+
55+
let p1 = &a[0];
56+
takes_ptr(p1);
57+
}
58+
59+
// Runtime-sized arrays can only appear in storage buffers, while (in the base
60+
// language) pointers can only appear in function or private space, so there
61+
// is no interaction to test.

Diff for: naga/tests/in/wgsl/pointer-function-arg-rzsw.toml

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
targets = "METAL"
2+
3+
[bounds_check_policies]
4+
index = "ReadZeroSkipWrite"

Diff for: naga/tests/in/wgsl/pointer-function-arg-rzsw.wgsl

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
fn takes_ptr(p: ptr<function, i32>) {}
2+
fn takes_array_ptr(p: ptr<function, array<i32, 4>>) {}
3+
fn takes_vec_ptr(p: ptr<function, vec2<i32>>) {}
4+
fn takes_mat_ptr(p: ptr<function, mat2x2<f32>>) {}
5+
6+
fn local_var(i: u32) {
7+
var arr = array(1, 2, 3, 4);
8+
takes_ptr(&arr[i]);
9+
takes_array_ptr(&arr);
10+
11+
}
12+
13+
fn mat_vec_ptrs(
14+
pv: ptr<function, array<vec2<i32>, 4>>,
15+
pm: ptr<function, array<mat2x2<f32>, 4>>,
16+
i: u32,
17+
) {
18+
takes_vec_ptr(&pv[i]);
19+
takes_mat_ptr(&pm[i]);
20+
}
21+
22+
fn argument(v: ptr<function, array<i32, 4>>, i: u32) {
23+
takes_ptr(&v[i]);
24+
}
25+
26+
fn argument_nested_x2(v: ptr<function, array<array<i32, 4>, 4>>, i: u32, j: u32) {
27+
takes_ptr(&v[i][j]);
28+
29+
// Mixing compile and runtime bounds checks
30+
takes_ptr(&v[i][0]);
31+
takes_ptr(&v[0][j]);
32+
33+
takes_array_ptr(&v[i]);
34+
}
35+
36+
fn argument_nested_x3(v: ptr<function, array<array<array<i32, 4>, 4>, 4>>, i: u32, j: u32) {
37+
takes_ptr(&v[i][0][j]);
38+
takes_ptr(&v[i][j][0]);
39+
takes_ptr(&v[0][i][j]);
40+
}
41+
42+
fn index_from_self(v: ptr<function, array<i32, 4>>, i: u32) {
43+
takes_ptr(&v[v[i]]);
44+
}
45+
46+
fn local_var_from_arg(a: array<i32, 4>, i: u32) {
47+
var b = a;
48+
takes_ptr(&b[i]);
49+
}
50+
51+
fn let_binding(a: ptr<function, array<i32, 4>>, i: u32) {
52+
let p0 = &a[i];
53+
takes_ptr(p0);
54+
55+
let p1 = &a[0];
56+
takes_ptr(p1);
57+
}
58+
59+
// Runtime-sized arrays can only appear in storage buffers, while (in the base
60+
// language) pointers can only appear in function or private space, so there
61+
// is no interaction to test.

Diff for: naga/tests/in/wgsl/pointer-function-arg.toml

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
targets = "METAL | GLSL | HLSL | WGSL"

0 commit comments

Comments
 (0)