Skip to content

Commit 3541f58

Browse files
andyleisersonErichDonGubler
authored andcommitted
[naga wgsl-in] vecN() constructors and let type conversions
* Support `vecN()` constructors (fixes gfx-rs#7356) * Apply automatic conversions to the initializer for `let` bindings
1 parent 1ec5bcf commit 3541f58

14 files changed

+422
-98
lines changed

naga/src/front/wgsl/lower/construction.rs

+18-4
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,25 @@ impl<'source> Lowerer<'source, '_> {
167167
// Empty constructor
168168
(Components::None, dst_ty) => match dst_ty {
169169
Constructor::Type((result_ty, _)) => {
170-
return ctx.append_expression(crate::Expression::ZeroValue(result_ty), span)
170+
expr = crate::Expression::ZeroValue(result_ty);
171171
}
172-
Constructor::PartialVector { .. }
173-
| Constructor::PartialMatrix { .. }
174-
| Constructor::PartialArray => {
172+
Constructor::PartialVector { size } => {
173+
// vec2(), vec3(), vec4() return vectors of abstractInts; the same
174+
// is not true of the similar constructors for matrices or arrays.
175+
// See https://www.w3.org/TR/WGSL/#vec2-builtin et seq.
176+
let result_ty = ctx.module.types.insert(
177+
crate::Type {
178+
name: None,
179+
inner: crate::TypeInner::Vector {
180+
size,
181+
scalar: crate::Scalar::ABSTRACT_INT,
182+
},
183+
},
184+
span,
185+
);
186+
expr = crate::Expression::ZeroValue(result_ty);
187+
}
188+
Constructor::PartialMatrix { .. } | Constructor::PartialArray => {
175189
// We have no arguments from which to infer the result type, so
176190
// partial constructors aren't acceptable here.
177191
return Err(Box::new(Error::TypeNotInferable(ty_span)));

naga/src/front/wgsl/lower/mod.rs

+25-14
Original file line numberDiff line numberDiff line change
@@ -1490,24 +1490,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
14901490
let mut emitter = Emitter::default();
14911491
emitter.start(&ctx.function.expressions);
14921492

1493-
let value =
1494-
self.expression(l.init, &mut ctx.as_expression(block, &mut emitter))?;
1495-
1496-
// The WGSL spec says that any expression that refers to a
1497-
// `let`-bound variable is not a const expression. This
1498-
// affects when errors must be reported, so we can't even
1499-
// treat suitable `let` bindings as constant as an
1500-
// optimization.
1501-
ctx.local_expression_kind_tracker.force_non_const(value);
1502-
15031493
let explicit_ty = l
15041494
.ty
15051495
.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_const(block, &mut emitter)))
15061496
.transpose()?;
15071497

1508-
if let Some(ty) = explicit_ty {
1509-
let mut ctx = ctx.as_expression(block, &mut emitter);
1510-
let init_ty = ctx.register_type(value)?;
1498+
let mut ectx = ctx.as_expression(block, &mut emitter);
1499+
1500+
let value = if let Some(ty) = explicit_ty {
1501+
let (init_ty, lowered_init) = self.type_and_init(
1502+
l.name,
1503+
Some(l.init),
1504+
explicit_ty,
1505+
AbstractRule::Concretize,
1506+
&mut ectx,
1507+
)?;
1508+
15111509
if !ctx.module.types[ty]
15121510
.inner
15131511
.equivalent(&ctx.module.types[init_ty].inner, &ctx.module.types)
@@ -1518,7 +1516,20 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
15181516
got: ctx.type_to_string(init_ty),
15191517
}));
15201518
}
1521-
}
1519+
1520+
// We passed `Some()` to `type_and_init`, so we
1521+
// will get a lowered initializer expression back.
1522+
lowered_init.expect("type_and_init did not return an initializer")
1523+
} else {
1524+
self.expression(l.init, &mut ectx)?
1525+
};
1526+
1527+
// The WGSL spec says that any expression that refers to a
1528+
// `let`-bound variable is not a const expression. This
1529+
// affects when errors must be reported, so we can't even
1530+
// treat suitable `let` bindings as constant as an
1531+
// optimization.
1532+
ctx.local_expression_kind_tracker.force_non_const(value);
15221533

15231534
block.extend(emitter.finish(&ctx.function.expressions));
15241535
ctx.local_table

naga/src/proc/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ impl crate::Literal {
118118
(value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
119119
(1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
120120
(0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
121+
(value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
122+
(value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
121123
_ => None,
122124
}
123125
}

naga/tests/in/wgsl/constructors.wgsl

+5-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ struct Foo {
33
b: i32,
44
}
55

6-
// const const1 = vec3<f32>(0.0); // TODO: this is now a splat and we need to const eval it
6+
const const1 = vec3<f32>(0.0);
77
const const2 = vec3(0.0, 1.0, 2.0);
88
const const3 = mat2x2<f32>(0.0, 1.0, 2.0, 3.0);
99
const const4 = array<mat2x2<f32>, 1>(mat2x2<f32>(0.0, 1.0, 2.0, 3.0));
@@ -19,9 +19,8 @@ const cz6 = array<Foo, 3>();
1919
const cz7 = Foo();
2020

2121
// constructors that infer their type from their parameters
22-
// TODO: these also contain splats
23-
// const cp1 = vec2(0u);
24-
// const cp2 = mat2x2(vec2(0.), vec2(0.));
22+
const cp1 = vec2(0u);
23+
const cp2 = mat2x2(vec2(0.), vec2(0.));
2524
const cp3 = array(0, 1, 2, 3);
2625

2726
@compute @workgroup_size(1)
@@ -49,6 +48,8 @@ fn main() {
4948
let zvc5 = mat2x2<f32>();
5049
let zvc6 = array<Foo, 3>();
5150
let zvc7 = Foo();
51+
let zvc8: vec2<u32> = vec2();
52+
let zvc9: vec2<f32> = vec2();
5253

5354
// constructors that infer their type from their parameters
5455
let cit0 = vec2(0u);

naga/tests/in/wgsl/conversions.wgsl

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// Conversion of initializer expressions
2+
const ic00: vec2<u32> = vec2();
3+
const ic01: vec4i = vec4();
4+
const ic02: vec4i = vec4(1);
5+
const ic03: vec4u = vec4();
6+
const ic04: vec4u = vec4(1);
7+
const ic05: vec4f = vec4();
8+
const ic06: vec4f = vec4(1);
9+
const ic07: vec2i = vec2(1, 1);
10+
const ic08: vec3i = vec3(1, 1, 1);
11+
const ic09: vec4i = vec4(1, 1, 1, 1);
12+
const ic10: vec2u = vec2(1, 1);
13+
const ic11: vec3u = vec3(1, 1, 1);
14+
const ic12: vec4u = vec4(1, 1, 1, 1);
15+
const ic13: vec2f = vec2(1, 1);
16+
const ic14: vec3f = vec3(1, 1, 1);
17+
const ic15: vec4f = vec4(1, 1, 1, 1);
18+
const ic16: vec2f = vec2(1.0, 1.0);
19+
const ic17: vec3f = vec3(1.0, 1.0, 1.0);
20+
const ic18: vec4f = vec4(1.0, 1.0, 1.0, 1.0);
21+
const ic19: vec2f = vec2(1, 1) + vec2(1.0, 1.0);
22+
const ic20: mat2x2f = mat2x2(vec2(), vec2());
23+
const ic21: array<u32, 4> = array(1, 2, 3, 4);
24+
25+
// Conversion by value constructors
26+
//let vc0 = i32(1.0); // https://github.com/gfx-rs/wgpu/issues/7312
27+
// etc. (also create the locals versions below)
28+
29+
@compute @workgroup_size(1)
30+
fn main() {
31+
const ic00: vec2<u32> = vec2();
32+
const ic01: vec4i = vec4();
33+
const ic02: vec4i = vec4(1);
34+
const ic03: vec4u = vec4();
35+
const ic04: vec4u = vec4(1);
36+
const ic05: vec4f = vec4();
37+
const ic06: vec4f = vec4(1);
38+
const ic07: vec2i = vec2(1, 1);
39+
const ic08: vec3i = vec3(1, 1, 1);
40+
const ic09: vec4i = vec4(1, 1, 1, 1);
41+
const ic10: vec2u = vec2(1, 1);
42+
const ic11: vec3u = vec3(1, 1, 1);
43+
const ic12: vec4u = vec4(1, 1, 1, 1);
44+
const ic13: vec2f = vec2(1, 1);
45+
const ic14: vec3f = vec3(1, 1, 1);
46+
const ic15: vec4f = vec4(1, 1, 1, 1);
47+
const ic16: vec2f = vec2(1.0, 1.0);
48+
const ic17: vec3f = vec3(1.0, 1.0, 1.0);
49+
const ic18: vec4f = vec4(1.0, 1.0, 1.0, 1.0);
50+
const ic19: vec2f = vec2(1, 1) + vec2(1.0, 1.0);
51+
const ic20: mat2x2f = mat2x2(vec2(), vec2());
52+
const ic21: array<u32, 4> = array(1, 2, 3, 4);
53+
54+
let lc00: vec2<u32> = vec2();
55+
let lc01: vec4i = vec4();
56+
let lc02: vec4i = vec4(1);
57+
let lc03: vec4u = vec4();
58+
let lc04: vec4u = vec4(1);
59+
let lc05: vec4f = vec4();
60+
let lc06: vec4f = vec4(1);
61+
let lc07: vec2i = vec2(1, 1);
62+
let lc08: vec3i = vec3(1, 1, 1);
63+
let lc09: vec4i = vec4(1, 1, 1, 1);
64+
let lc10: vec2u = vec2(1, 1);
65+
let lc11: vec3u = vec3(1, 1, 1);
66+
let lc12: vec4u = vec4(1, 1, 1, 1);
67+
let lc13: vec2f = vec2(1, 1);
68+
let lc14: vec3f = vec3(1, 1, 1);
69+
let lc15: vec4f = vec4(1, 1, 1, 1);
70+
let lc16: vec2f = vec2(1.0, 1.0);
71+
let lc17: vec3f = vec3(1.0, 1.0, 1.0);
72+
let lc18: vec4f = vec4(1.0, 1.0, 1.0, 1.0);
73+
let lc19: vec2f = vec2(1, 1) + vec2(1.0, 1.0);
74+
let lc20: mat2x2f = mat2x2(vec2(), vec2());
75+
let lc21: array<u32, 4> = array(1, 2, 3, 4);
76+
77+
var vc00: vec2<u32> = vec2();
78+
var vc01: vec4i = vec4();
79+
var vc02: vec4i = vec4(1);
80+
var vc03: vec4u = vec4();
81+
var vc04: vec4u = vec4(1);
82+
var vc05: vec4f = vec4();
83+
var vc06: vec4f = vec4(1);
84+
var vc07: vec2i = vec2(1, 1);
85+
var vc08: vec3i = vec3(1, 1, 1);
86+
var vc09: vec4i = vec4(1, 1, 1, 1);
87+
var vc10: vec2u = vec2(1, 1);
88+
var vc11: vec3u = vec3(1, 1, 1);
89+
var vc12: vec4u = vec4(1, 1, 1, 1);
90+
var vc13: vec2f = vec2(1, 1);
91+
var vc14: vec3f = vec3(1, 1, 1);
92+
var vc15: vec4f = vec4(1, 1, 1, 1);
93+
var vc16: vec2f = vec2(1.0, 1.0);
94+
var vc17: vec3f = vec3(1.0, 1.0, 1.0);
95+
var vc18: vec4f = vec4(1.0, 1.0, 1.0, 1.0);
96+
var vc19: vec2f = vec2(1, 1) + vec2(1.0, 1.0);
97+
var vc20: mat2x2f = mat2x2(vec2(), vec2());
98+
var vc21: array<u32, 4> = array(1, 2, 3, 4);
99+
}
+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
const g0 = 1;
2+
const g1 = 1u;
3+
const g2 = 1.0;
4+
const g3 = 1.0f;
5+
const g4 = vec4<i32>();
6+
const g5 = vec4(1i);
7+
const g6 = mat2x2<f32>(vec2(), vec2());
8+
const g7 = mat2x2(vec2(1.0, 1), vec2(1, 1));
9+
10+
@compute @workgroup_size(1)
11+
fn main() {
12+
// Expose some constants that wouldn't otherwise be in the output
13+
// because they don't have concrete types.
14+
var g0x = g0;
15+
var g2x = g2;
16+
var g7x = g7;
17+
18+
const c0 = 1;
19+
const c1 = 1u;
20+
const c2 = 1.0;
21+
const c3 = 1.0f;
22+
const c4 = vec4<i32>();
23+
const c5 = vec4(1i);
24+
const c6 = mat2x2<f32>(vec2(), vec2());
25+
const c7 = mat2x2(vec2(1.0, 1), vec2(1, 1));
26+
27+
// Local constants are not emitted in most cases.
28+
// See logic for `Statement::Emit` in `back::wgsl::Writer::write_stmt`.
29+
var c0x = c0;
30+
var c1x = c1;
31+
var c2x = c2;
32+
var c3x = c3;
33+
var c4x = c4;
34+
var c5x = c5;
35+
var c6x = c6;
36+
var c7x = c7;
37+
38+
let l0 = 1;
39+
let l1 = 1u;
40+
let l2 = 1.0;
41+
let l3 = 1.0f;
42+
let l4 = vec4<i32>();
43+
let l5 = vec4(1i);
44+
let l6 = mat2x2<f32>(vec2(), vec2());
45+
let l7 = mat2x2(vec2(1.0, 1), vec2(1, 1));
46+
47+
// Let bindings that evaluate to literals or a `ZeroValue` expression are
48+
// not emitted. See `ConstantEvaluator::append_expr`. `vec4(1i)` is emitted
49+
// because it is translated to a `Splat` expression.
50+
var l0x = l0;
51+
var l1x = l1;
52+
var l2x = l2;
53+
var l3x = l3;
54+
var l4x = l4;
55+
56+
var v0 = 1;
57+
var v1 = 1u;
58+
var v2 = 1.0;
59+
var v3 = 1.0f;
60+
var v4 = vec4<i32>();
61+
var v5 = vec4(1i);
62+
var v6 = mat2x2<f32>(vec2(), vec2());
63+
var v7 = mat2x2(vec2(1.0, 1), vec2(1, 1));
64+
}

naga/tests/out/glsl/constructors.main.Compute.glsl

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ struct Foo {
99
vec4 a;
1010
int b;
1111
};
12+
const vec3 const1_ = vec3(0.0);
1213
const mat2x2 const3_ = mat2x2(vec2(0.0, 1.0), vec2(2.0, 3.0));
1314
const mat2x2 const4_[1] = mat2x2[1](mat2x2(vec2(0.0, 1.0), vec2(2.0, 3.0)));
1415
const bool cz0_ = false;
@@ -19,13 +20,16 @@ const uvec2 cz4_ = uvec2(0u);
1920
const mat2x2 cz5_ = mat2x2(0.0);
2021
const Foo cz6_[3] = Foo[3](Foo(vec4(0.0), 0), Foo(vec4(0.0), 0), Foo(vec4(0.0), 0));
2122
const Foo cz7_ = Foo(vec4(0.0), 0);
23+
const uvec2 cp1_ = uvec2(0u);
2224

2325

2426
void main() {
2527
Foo foo = Foo(vec4(0.0), 0);
2628
foo = Foo(vec4(1.0), 1);
2729
mat2x2 m0_ = mat2x2(vec2(1.0, 0.0), vec2(0.0, 1.0));
2830
mat4x4 m1_ = mat4x4(vec4(1.0, 0.0, 0.0, 0.0), vec4(0.0, 1.0, 0.0, 0.0), vec4(0.0, 0.0, 1.0, 0.0), vec4(0.0, 0.0, 0.0, 1.0));
31+
uvec2 zvc8_ = uvec2(0u, 0u);
32+
vec2 zvc9_ = vec2(0.0, 0.0);
2933
uvec2 cit0_ = uvec2(0u);
3034
mat2x2 cit1_ = mat2x2(vec2(0.0), vec2(0.0));
3135
int cit2_[4] = int[4](0, 1, 2, 3);

naga/tests/out/hlsl/constructors.hlsl

+4
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ Foo ZeroValueFoo() {
4545
return (Foo)0;
4646
}
4747

48+
static const float3 const1_ = (0.0).xxx;
4849
static const float2x2 const3_ = float2x2(float2(0.0, 1.0), float2(2.0, 3.0));
4950
static const float2x2 const4_[1] = Constructarray1_float2x2_(float2x2(float2(0.0, 1.0), float2(2.0, 3.0)));
5051
static const bool cz0_ = ZeroValuebool();
@@ -55,6 +56,7 @@ static const uint2 cz4_ = ZeroValueuint2();
5556
static const float2x2 cz5_ = ZeroValuefloat2x2();
5657
static const Foo cz6_[3] = ZeroValuearray3_Foo_();
5758
static const Foo cz7_ = ZeroValueFoo();
59+
static const uint2 cp1_ = (0u).xx;
5860

5961
Foo ConstructFoo(float4 arg0, int arg1) {
6062
Foo ret = (Foo)0;
@@ -81,6 +83,8 @@ void main()
8183
foo = ConstructFoo((1.0).xxxx, int(1));
8284
float2x2 m0_ = float2x2(float2(1.0, 0.0), float2(0.0, 1.0));
8385
float4x4 m1_ = float4x4(float4(1.0, 0.0, 0.0, 0.0), float4(0.0, 1.0, 0.0, 0.0), float4(0.0, 0.0, 1.0, 0.0), float4(0.0, 0.0, 0.0, 1.0));
86+
uint2 zvc8_ = uint2(0u, 0u);
87+
float2 zvc9_ = float2(0.0, 0.0);
8488
uint2 cit0_ = (0u).xx;
8589
float2x2 cit1_ = float2x2((0.0).xx, (0.0).xx);
8690
int cit2_[4] = Constructarray4_int_(int(0), int(1), int(2), int(3));

naga/tests/out/msl/constructors.msl

+10-6
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,39 @@ struct Foo {
88
metal::float4 a;
99
int b;
1010
};
11-
struct type_5 {
11+
struct type_6 {
1212
metal::float2x2 inner[1];
1313
};
14-
struct type_9 {
14+
struct type_10 {
1515
Foo inner[3];
1616
};
17-
struct type_11 {
17+
struct type_12 {
1818
int inner[4];
1919
};
20+
constant metal::float3 const1_ = metal::float3(0.0);
2021
constant metal::float2x2 const3_ = metal::float2x2(metal::float2(0.0, 1.0), metal::float2(2.0, 3.0));
21-
constant type_5 const4_ = type_5 {metal::float2x2(metal::float2(0.0, 1.0), metal::float2(2.0, 3.0))};
22+
constant type_6 const4_ = type_6 {metal::float2x2(metal::float2(0.0, 1.0), metal::float2(2.0, 3.0))};
2223
constant bool cz0_ = bool {};
2324
constant int cz1_ = int {};
2425
constant uint cz2_ = uint {};
2526
constant float cz3_ = float {};
2627
constant metal::uint2 cz4_ = metal::uint2 {};
2728
constant metal::float2x2 cz5_ = metal::float2x2 {};
28-
constant type_9 cz6_ = type_9 {};
29+
constant type_10 cz6_ = type_10 {};
2930
constant Foo cz7_ = Foo {};
31+
constant metal::uint2 cp1_ = metal::uint2(0u);
3032

3133
kernel void main_(
3234
) {
3335
Foo foo = {};
3436
foo = Foo {metal::float4(1.0), 1};
3537
metal::float2x2 m0_ = metal::float2x2(metal::float2(1.0, 0.0), metal::float2(0.0, 1.0));
3638
metal::float4x4 m1_ = metal::float4x4(metal::float4(1.0, 0.0, 0.0, 0.0), metal::float4(0.0, 1.0, 0.0, 0.0), metal::float4(0.0, 0.0, 1.0, 0.0), metal::float4(0.0, 0.0, 0.0, 1.0));
39+
metal::uint2 zvc8_ = metal::uint2(0u, 0u);
40+
metal::float2 zvc9_ = metal::float2(0.0, 0.0);
3741
metal::uint2 cit0_ = metal::uint2(0u);
3842
metal::float2x2 cit1_ = metal::float2x2(metal::float2(0.0), metal::float2(0.0));
39-
type_11 cit2_ = type_11 {0, 1, 2, 3};
43+
type_12 cit2_ = type_12 {0, 1, 2, 3};
4044
metal::uint2 ic4_ = metal::uint2(0u, 0u);
4145
metal::float2x3 ic5_ = metal::float2x3(metal::float3(0.0, 0.0, 0.0), metal::float3(0.0, 0.0, 0.0));
4246
return;

0 commit comments

Comments
 (0)