Skip to content

Commit c93f02e

Browse files
committed
Fix union field access by representing unions as structs in SPIR-V
When a union has multiple fields, represent it as a struct with all fields at offset 0 instead of just using the largest field. This allows pointer casts between union fields to work correctly by enabling the recover_access_chain_from_offset function to find valid access chains. Fixes #241.
1 parent e4375b1 commit c93f02e

File tree

2 files changed

+52
-14
lines changed

2 files changed

+52
-14
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -718,21 +718,27 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
718718
FieldsShape::Union(_) => {
719719
assert!(!ty.is_unsized(), "{ty:#?}");
720720

721-
// Represent the `union` with its largest case, which should work
722-
// for at least `MaybeUninit<T>` (which is between `T` and `()`),
723-
// but also potentially some other ones as well.
724-
// NOTE(eddyb) even if long-term this may become a byte array, that
725-
// only works for "data types" and not "opaque handles" (images etc.).
726-
let largest_case = (0..ty.fields.count())
727-
.map(|i| ty.field(cx, i))
728-
.max_by_key(|case| case.size);
729-
730-
if let Some(case) = largest_case {
731-
assert_eq!(ty.size, case.size);
732-
case.spirv_type(span, cx)
721+
// If the union has no fields or only one field, represent the `union` with
722+
// its largest case, which should work for at least `MaybeUninit<T>` (which
723+
// is between `T` and `()`), but also potentially some other ones as well.
724+
//
725+
// NOTE(eddyb) even if long-term this may become a byte array, that only
726+
// works for "data types" and not "opaque handles" (images etc.).
727+
if ty.fields.count() <= 1 {
728+
let largest_case = (0..ty.fields.count())
729+
.map(|i| ty.field(cx, i))
730+
.max_by_key(|case| case.size);
731+
732+
if let Some(case) = largest_case {
733+
assert_eq!(ty.size, case.size);
734+
case.spirv_type(span, cx)
735+
} else {
736+
assert_eq!(ty.size, Size::ZERO);
737+
create_zst(cx, span, ty)
738+
}
733739
} else {
734-
assert_eq!(ty.size, Size::ZERO);
735-
create_zst(cx, span, ty)
740+
// For unions with multiple fields, represent as struct with all fields at offset 0
741+
trans_struct(cx, span, ty)
736742
}
737743
}
738744
FieldsShape::Array { stride, count } => {

tests/ui/lang/core/union_cast.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// build-pass
2+
3+
use spirv_std::spirv;
4+
5+
#[repr(C)]
6+
#[derive(Clone, Copy)]
7+
struct Data {
8+
a: f32,
9+
b: [f32; 3],
10+
c: f32,
11+
}
12+
13+
#[repr(C)]
14+
union DataOrArray {
15+
arr: [f32; 5],
16+
str: Data,
17+
}
18+
19+
impl DataOrArray {
20+
fn arr(&self) -> [f32; 5] {
21+
unsafe { self.arr }
22+
}
23+
fn new(arr: [f32; 5]) -> Self {
24+
Self { arr }
25+
}
26+
}
27+
28+
#[spirv(fragment)]
29+
pub fn main() {
30+
let dora = DataOrArray::new([0.0, 0.0, 0.0, 0.0, 0.0]);
31+
let _arr = dora.arr();
32+
}

0 commit comments

Comments
 (0)