Skip to content

Commit 587aea2

Browse files
andyleisersonLiamoluckoErichDonGubler
committed
refactor(msl-out): extract bounds_check_iter helper
-- Co-authored-by: Liam Murphy <[email protected]> Co-Authored-By: Erich Gubler <[email protected]>
1 parent aad187f commit 587aea2

File tree

2 files changed

+86
-53
lines changed

2 files changed

+86
-53
lines changed

naga/src/back/msl/writer.rs

Lines changed: 37 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,20 @@ impl<'a> ExpressionContext<'a> {
695695
)
696696
}
697697

698+
/// See docs for [`proc::index::bounds_check_iter`].
699+
fn bounds_check_iter(
700+
&self,
701+
chain: Handle<crate::Expression>,
702+
) -> impl Iterator<
703+
Item = (
704+
Handle<crate::Expression>,
705+
index::GuardedIndex,
706+
index::IndexableLength,
707+
),
708+
> + '_ {
709+
index::bounds_check_iter(chain, self.module, self.function, self.info)
710+
}
711+
698712
fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
699713
match self.function.expressions[expr_handle] {
700714
crate::Expression::AccessIndex { base, index } => {
@@ -2647,68 +2661,38 @@ impl<W: Write> Writer<W> {
26472661
#[allow(unused_variables)]
26482662
fn put_bounds_checks(
26492663
&mut self,
2650-
mut chain: Handle<crate::Expression>,
2664+
chain: Handle<crate::Expression>,
26512665
context: &ExpressionContext,
26522666
level: back::Level,
26532667
prefix: &'static str,
26542668
) -> Result<bool, Error> {
26552669
let mut check_written = false;
26562670

2657-
// Iterate over the access chain, handling each expression.
2658-
loop {
2659-
// Produce a `GuardedIndex`, so we can shared code between the
2660-
// `Access` and `AccessIndex` cases.
2661-
let (base, guarded_index) = match context.function.expressions[chain] {
2662-
crate::Expression::Access { base, index } => {
2663-
(base, Some(index::GuardedIndex::Expression(index)))
2664-
}
2665-
crate::Expression::AccessIndex { base, index } => {
2666-
// Don't try to check indices into structs. Validation already took
2667-
// care of them, and index::needs_guard doesn't handle that case.
2668-
let mut base_inner = context.resolve_type(base);
2669-
if let crate::TypeInner::Pointer { base, .. } = *base_inner {
2670-
base_inner = &context.module.types[base].inner;
2671-
}
2672-
match *base_inner {
2673-
crate::TypeInner::Struct { .. } => (base, None),
2674-
_ => (base, Some(index::GuardedIndex::Known(index))),
2675-
}
2676-
}
2677-
_ => break,
2678-
};
2679-
2680-
if let Some(index) = guarded_index {
2681-
if let Some(length) = context.access_needs_check(base, index) {
2682-
if check_written {
2683-
write!(self.out, " && ")?;
2684-
} else {
2685-
write!(self.out, "{level}{prefix}")?;
2686-
check_written = true;
2687-
}
2671+
// Iterate over the access chain, handling each required bounds check.
2672+
for (base, index, length) in context.bounds_check_iter(chain) {
2673+
if check_written {
2674+
write!(self.out, " && ")?;
2675+
} else {
2676+
write!(self.out, "{level}{prefix}")?;
2677+
check_written = true;
2678+
}
26882679

2689-
// Check that the index falls within bounds. Do this with a single
2690-
// comparison, by casting the index to `uint` first, so that negative
2691-
// indices become large positive values.
2692-
write!(self.out, "uint(")?;
2693-
self.put_index(index, context, true)?;
2694-
self.out.write_str(") < ")?;
2695-
match length {
2696-
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
2697-
index::IndexableLength::Dynamic => {
2698-
let global =
2699-
context.function.originating_global(base).ok_or_else(|| {
2700-
Error::GenericValidation(
2701-
"Could not find originating global".into(),
2702-
)
2703-
})?;
2704-
write!(self.out, "1 + ")?;
2705-
self.put_dynamic_array_max_index(global, context)?
2706-
}
2707-
}
2680+
// Check that the index falls within bounds. Do this with a single
2681+
// comparison, by casting the index to `uint` first, so that negative
2682+
// indices become large positive values.
2683+
write!(self.out, "uint(")?;
2684+
self.put_index(index, context, true)?;
2685+
self.out.write_str(") < ")?;
2686+
match length {
2687+
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
2688+
index::IndexableLength::Dynamic => {
2689+
let global = context.function.originating_global(base).ok_or_else(|| {
2690+
Error::GenericValidation("Could not find originating global".into())
2691+
})?;
2692+
write!(self.out, "1 + ")?;
2693+
self.put_dynamic_array_max_index(global, context)?
27082694
}
27092695
}
2710-
2711-
chain = base
27122696
}
27132697

27142698
Ok(check_written)

naga/src/proc/index.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Definitions for index bounds checking.
33
*/
44

5+
use core::iter;
6+
57
use crate::arena::{Handle, HandleSet, UniqueArena};
68
use crate::valid;
79

@@ -340,6 +342,53 @@ pub fn access_needs_check(
340342
Some(length)
341343
}
342344

345+
/// Returns an iterator of accesses within the chain of `Access` and
346+
/// `AccessIndex` expressions starting from `chain` that may need to be
347+
/// bounds-checked at runtime.
348+
///
349+
/// They're yielded as `(base, index)` pairs, where `base` is the type that the
350+
/// access expression will produce and `index` is the index being used.
351+
///
352+
/// Accesses through a struct are omitted, since you never need a bounds check
353+
/// for accessing a struct field.
354+
///
355+
/// If `chain` isn't an `Access` or `AccessIndex` expression at all, the
356+
/// iterator is empty.
357+
pub(crate) fn bounds_check_iter<'a>(
358+
mut chain: Handle<crate::Expression>,
359+
module: &'a crate::Module,
360+
function: &'a crate::Function,
361+
info: &'a valid::FunctionInfo,
362+
) -> impl Iterator<Item = (Handle<crate::Expression>, GuardedIndex, IndexableLength)> + 'a {
363+
iter::from_fn(move || {
364+
let (next_expr, result) = match function.expressions[chain] {
365+
crate::Expression::Access { base, index } => {
366+
(base, Some((base, GuardedIndex::Expression(index))))
367+
}
368+
crate::Expression::AccessIndex { base, index } => {
369+
// Don't try to check indices into structs. Validation already took
370+
// care of them, and access_needs_check doesn't handle that case.
371+
let mut base_inner = info[base].ty.inner_with(&module.types);
372+
if let crate::TypeInner::Pointer { base, .. } = *base_inner {
373+
base_inner = &module.types[base].inner;
374+
}
375+
match *base_inner {
376+
crate::TypeInner::Struct { .. } => (base, None),
377+
_ => (base, Some((base, GuardedIndex::Known(index)))),
378+
}
379+
}
380+
_ => return None,
381+
};
382+
chain = next_expr;
383+
Some(result)
384+
})
385+
.flatten()
386+
.filter_map(|(base, index)| {
387+
access_needs_check(base, index, module, &function.expressions, info)
388+
.map(|length| (base, index, length))
389+
})
390+
}
391+
343392
impl GuardedIndex {
344393
/// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible.
345394
///

0 commit comments

Comments
 (0)