Skip to content

Commit 4d6b4da

Browse files
bors[bot]grovesNL
andcommitted
Merge #2399
2399: [mtl] Conditionally enable rasterization based on vertex shader requirements r=kvark a=grovesNL - Updates SPIRV-Cross - Sets `MTLRenderPipelineDescriptor.rasterizationEnabled` flag by querying the vertex shader to check whether rasterization should be enabled - Currently we query all shaders and not just the vertex shader, we could optimize this further if it's preferable Overall fixes a few of the CTS tests mentioned in #2394. There is another remaining issue to be fixed upstream in SPIRV-Cross, but we can get that later. PR checklist: - [ ] `make` succeeds (on *nix) - [ ] `make reftests` succeeds - [X] tested examples with the following backends: Metal - [ ] `rustfmt` run on changed code Co-authored-by: Joshua Groves <[email protected]>
2 parents a381a26 + f8f9b10 commit 4d6b4da

File tree

3 files changed

+18
-7
lines changed

3 files changed

+18
-7
lines changed

src/backend/metal/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,6 @@ block = "0.1"
3131
cocoa = "0.18"
3232
core-graphics = "0.17"
3333
smallvec = "0.6"
34-
spirv_cross = "0.10"
34+
spirv_cross = "0.11.1"
3535
parking_lot = "0.6.3"
3636
storage-map = "0.1.2"

src/backend/metal/src/device.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,10 @@ impl Device {
656656
}
657657

658658
pub fn create_shader_library_from_source<S>(
659-
&self, source: S, version: LanguageVersion,
659+
&self,
660+
source: S,
661+
version: LanguageVersion,
662+
rasterization_enabled: bool,
660663
) -> Result<n::ShaderModule, ShaderError> where S: AsRef<str> {
661664
let options = metal::CompileOptions::new();
662665
let msl_version = match version {
@@ -677,6 +680,7 @@ impl Device {
677680
.map(|library| n::ShaderModule::Compiled(n::ModuleInfo {
678681
library,
679682
entry_point_map: n::EntryPointMap::default(),
683+
rasterization_enabled,
680684
}))
681685
.map_err(|e| ShaderError::CompilationFailed(e.into()))
682686
}
@@ -741,6 +745,9 @@ impl Device {
741745
});
742746
}
743747

748+
let rasterization_enabled = ast.is_rasterization_enabled()
749+
.map_err(|_| ShaderError::CompilationFailed("Unknown compile error".into()))?;
750+
744751
// done
745752
debug!("SPIRV-Cross generated shader:\n{}", shader_code);
746753

@@ -755,6 +762,7 @@ impl Device {
755762
Ok(n::ModuleInfo {
756763
library,
757764
entry_point_map,
765+
rasterization_enabled,
758766
})
759767
}
760768

@@ -764,7 +772,7 @@ impl Device {
764772
layout: &n::PipelineLayout,
765773
primitive_class: MTLPrimitiveTopologyClass,
766774
pipeline_cache: Option<&n::PipelineCache>,
767-
) -> Result<(metal::Library, metal::Function, metal::MTLSize), pso::CreationError> {
775+
) -> Result<(metal::Library, metal::Function, metal::MTLSize, bool), pso::CreationError> {
768776
let device = &self.shared.device;
769777
let msl_version = self.private_caps.msl_version;
770778
let module_map;
@@ -816,7 +824,7 @@ impl Device {
816824
pso::CreationError::Other
817825
})?;
818826

819-
Ok((lib, mtl_function, wg_size))
827+
Ok((lib, mtl_function, wg_size, info.rasterization_enabled))
820828
}
821829

822830
fn describe_argument(
@@ -1213,7 +1221,7 @@ impl hal::Device<Backend> for Device {
12131221
pipeline.set_input_primitive_topology(primitive_class);
12141222

12151223
// Vertex shader
1216-
let (vs_lib, vs_function, _) = self.load_shader(
1224+
let (vs_lib, vs_function, _, enable_rasterization) = self.load_shader(
12171225
&pipeline_desc.shaders.vertex,
12181226
pipeline_layout,
12191227
primitive_class,
@@ -1225,7 +1233,7 @@ impl hal::Device<Backend> for Device {
12251233
let fs_function;
12261234
let fs_lib = match pipeline_desc.shaders.fragment {
12271235
Some(ref ep) => {
1228-
let (lib, fun, _) = self.load_shader(ep, pipeline_layout, primitive_class, cache)?;
1236+
let (lib, fun, _, _) = self.load_shader(ep, pipeline_layout, primitive_class, cache)?;
12291237
fs_function = fun;
12301238
pipeline.set_fragment_function(Some(&fs_function));
12311239
Some(lib)
@@ -1251,6 +1259,8 @@ impl hal::Device<Backend> for Device {
12511259
return Err(pso::CreationError::Shader(ShaderError::UnsupportedStage(pso::Stage::Geometry)));
12521260
}
12531261

1262+
pipeline.set_rasterization_enabled(enable_rasterization);
1263+
12541264
// Assign target formats
12551265
let blend_targets = pipeline_desc.blender.targets
12561266
.iter()
@@ -1432,7 +1442,7 @@ impl hal::Device<Backend> for Device {
14321442
debug!("create_compute_pipeline {:?}", pipeline_desc);
14331443
let pipeline = metal::ComputePipelineDescriptor::new();
14341444

1435-
let (cs_lib, cs_function, work_group_size) = self.load_shader(
1445+
let (cs_lib, cs_function, work_group_size, _) = self.load_shader(
14361446
&pipeline_desc.shader,
14371447
&pipeline_desc.layout,
14381448
MTLPrimitiveTopologyClass::Unspecified,

src/backend/metal/src/native.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ impl PipelineLayout {
210210
pub struct ModuleInfo {
211211
pub library: metal::Library,
212212
pub entry_point_map: EntryPointMap,
213+
pub rasterization_enabled: bool,
213214
}
214215

215216
pub struct PipelineCache {

0 commit comments

Comments
 (0)