diff --git a/CHANGELOG.md b/CHANGELOG.md index 9643107bd8..5c0b31d87c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -150,6 +150,7 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206 - The `STORAGE_READ_ONLY` texture usage is now permitted to coexist with other read-only usages. By @andyleiserson in [#8490](https://github.com/gfx-rs/wgpu/pull/8490). - Validate that buffers are unmapped in `write_buffer` calls. By @ErichDonGubler in [#8454](https://github.com/gfx-rs/wgpu/pull/8454). - Add WGSL parsing for mesh shaders. By @inner-daemons in [#8370](https://github.com/gfx-rs/wgpu/pull/8370). +- Add WGSL writing for mesh shaders. By @Slightlyclueless [#8481](https://github.com/gfx-rs/wgpu/pull/8481). #### DX12 diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index daf32a7116..2b764d8a6d 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -33,6 +33,9 @@ enum Attribute { BlendSrc(u32), Stage(ShaderStage), WorkGroupSize([u32; 3]), + MeshStage(String), + TaskPayload(String), + PerPrimitive, } /// The WGSL form that `write_expr_with_indirection` should use to render a Naga @@ -207,9 +210,37 @@ impl Writer { Attribute::Stage(ShaderStage::Compute), Attribute::WorkGroupSize(ep.workgroup_size), ], - ShaderStage::Mesh | ShaderStage::Task => unreachable!(), + ShaderStage::Mesh => { + let mesh_output_name = module.global_variables + [ep.mesh_info.as_ref().unwrap().output_variable] + .name + .clone() + .unwrap(); + let mut mesh_attrs = vec![ + Attribute::MeshStage(mesh_output_name), + Attribute::WorkGroupSize(ep.workgroup_size), + ]; + if ep.task_payload.is_some() { + let payload_name = module.global_variables[ep.task_payload.unwrap()] + .name + .clone() + .unwrap(); + mesh_attrs.push(Attribute::TaskPayload(payload_name)); + } + mesh_attrs + } + ShaderStage::Task => { + let payload_name = module.global_variables[ep.task_payload.unwrap()] + .name + .clone() + .unwrap(); + vec![ + Attribute::Stage(ShaderStage::Task), + Attribute::TaskPayload(payload_name), + Attribute::WorkGroupSize(ep.workgroup_size), + ] + } }; - self.write_attributes(&attributes)?; // Add a newline after attribute writeln!(self.out)?; @@ -243,6 +274,7 @@ impl Writer { let mut needs_f16 = false; let mut needs_dual_source_blending = false; let mut needs_clip_distances = false; + let mut needs_mesh_shaders = false; // Determine which `enable` declarations are needed for (_, ty) in module.types.iter() { @@ -263,6 +295,25 @@ impl Writer { crate::Binding::BuiltIn(crate::BuiltIn::ClipDistance) => { needs_clip_distances = true; } + crate::Binding::Location { + per_primitive: true, + .. + } => { + needs_mesh_shaders = true; + } + crate::Binding::BuiltIn( + crate::BuiltIn::MeshTaskSize + | crate::BuiltIn::CullPrimitive + | crate::BuiltIn::PointIndex + | crate::BuiltIn::LineIndices + | crate::BuiltIn::TriangleIndices + | crate::BuiltIn::VertexCount + | crate::BuiltIn::Vertices + | crate::BuiltIn::PrimitiveCount + | crate::BuiltIn::Primitives, + ) => { + needs_mesh_shaders = true; + } _ => {} } } @@ -271,6 +322,22 @@ impl Writer { } } + if module + .entry_points + .iter() + .any(|ep| matches!(ep.stage, ShaderStage::Mesh | ShaderStage::Task)) + { + needs_mesh_shaders = true; + } + + if module + .global_variables + .iter() + .any(|gv| gv.1.space == crate::AddressSpace::TaskPayload) + { + needs_mesh_shaders = true; + } + // Write required declarations let mut any_written = false; if needs_f16 { @@ -285,6 +352,10 @@ impl Writer { writeln!(self.out, "enable clip_distances;")?; any_written = true; } + if needs_mesh_shaders { + writeln!(self.out, "enable wgpu_mesh_shader;")?; + any_written = true; + } if any_written { // Empty line for readability writeln!(self.out)?; @@ -403,8 +474,10 @@ impl Writer { ShaderStage::Vertex => "vertex", ShaderStage::Fragment => "fragment", ShaderStage::Compute => "compute", - ShaderStage::Task | ShaderStage::Mesh => unreachable!(), + ShaderStage::Task => "task", + ShaderStage::Mesh => unreachable!(), }; + write!(self.out, "@{stage_str} ")?; } Attribute::WorkGroupSize(size) => { @@ -433,6 +506,13 @@ impl Writer { write!(self.out, "@interpolate({interpolation}) ")?; } } + Attribute::MeshStage(ref name) => { + write!(self.out, "@mesh({name}) ")?; + } + Attribute::TaskPayload(ref payload_name) => { + write!(self.out, "@payload({payload_name}) ")?; + } + Attribute::PerPrimitive => write!(self.out, "@per_primitive ")?, }; } Ok(()) @@ -1822,21 +1902,33 @@ fn map_binding_to_attribute(binding: &crate::Binding) -> Vec { interpolation, sampling, blend_src: None, - per_primitive: _, - } => vec![ - Attribute::Location(location), - Attribute::Interpolate(interpolation, sampling), - ], + per_primitive, + } => { + let mut attrs = vec![ + Attribute::Location(location), + Attribute::Interpolate(interpolation, sampling), + ]; + if per_primitive { + attrs.push(Attribute::PerPrimitive); + } + attrs + } crate::Binding::Location { location, interpolation, sampling, blend_src: Some(blend_src), - per_primitive: _, - } => vec![ - Attribute::Location(location), - Attribute::BlendSrc(blend_src), - Attribute::Interpolate(interpolation, sampling), - ], + per_primitive, + } => { + let mut attrs = vec![ + Attribute::Location(location), + Attribute::BlendSrc(blend_src), + Attribute::Interpolate(interpolation, sampling), + ]; + if per_primitive { + attrs.push(Attribute::PerPrimitive); + } + attrs + } } } diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index 5e6178c049..7140b4883e 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -183,22 +183,23 @@ impl TryToWgsl for crate::BuiltIn { Bi::SubgroupInvocationId => "subgroup_invocation_id", // Non-standard built-ins. + Bi::MeshTaskSize => "mesh_task_size", + Bi::TriangleIndices => "triangle_indices", + Bi::LineIndices => "line_indices", + Bi::PointIndex => "point_index", + Bi::Vertices => "vertices", + Bi::Primitives => "primitives", + Bi::VertexCount => "vertex_count", + Bi::PrimitiveCount => "primitive_count", + Bi::CullPrimitive => "cull_primitive", + Bi::BaseInstance | Bi::BaseVertex | Bi::CullDistance | Bi::PointSize | Bi::DrawID | Bi::PointCoord - | Bi::WorkGroupSize - | Bi::CullPrimitive - | Bi::TriangleIndices - | Bi::LineIndices - | Bi::MeshTaskSize - | Bi::PointIndex - | Bi::VertexCount - | Bi::PrimitiveCount - | Bi::Vertices - | Bi::Primitives => return None, + | Bi::WorkGroupSize => return None, }) } } @@ -362,7 +363,7 @@ pub const fn address_space_str( As::WorkGroup => "workgroup", As::Handle => return (None, None), As::Function => "function", - As::TaskPayload => return (None, None), + As::TaskPayload => "task_payload", }), None, ) diff --git a/naga/tests/in/wgsl/mesh-shader-empty.toml b/naga/tests/in/wgsl/mesh-shader-empty.toml index 8500399f93..ecfa36ccd3 100644 --- a/naga/tests/in/wgsl/mesh-shader-empty.toml +++ b/naga/tests/in/wgsl/mesh-shader-empty.toml @@ -1,2 +1,2 @@ god_mode = true -targets = "IR | ANALYSIS" +targets = "IR | ANALYSIS | WGSL" diff --git a/naga/tests/in/wgsl/mesh-shader-lines.toml b/naga/tests/in/wgsl/mesh-shader-lines.toml index 8500399f93..ecfa36ccd3 100644 --- a/naga/tests/in/wgsl/mesh-shader-lines.toml +++ b/naga/tests/in/wgsl/mesh-shader-lines.toml @@ -1,2 +1,2 @@ god_mode = true -targets = "IR | ANALYSIS" +targets = "IR | ANALYSIS | WGSL" diff --git a/naga/tests/in/wgsl/mesh-shader-points.toml b/naga/tests/in/wgsl/mesh-shader-points.toml index 8500399f93..ecfa36ccd3 100644 --- a/naga/tests/in/wgsl/mesh-shader-points.toml +++ b/naga/tests/in/wgsl/mesh-shader-points.toml @@ -1,2 +1,2 @@ god_mode = true -targets = "IR | ANALYSIS" +targets = "IR | ANALYSIS | WGSL" diff --git a/naga/tests/in/wgsl/mesh-shader.toml b/naga/tests/in/wgsl/mesh-shader.toml index 8500399f93..ecfa36ccd3 100644 --- a/naga/tests/in/wgsl/mesh-shader.toml +++ b/naga/tests/in/wgsl/mesh-shader.toml @@ -1,2 +1,2 @@ god_mode = true -targets = "IR | ANALYSIS" +targets = "IR | ANALYSIS | WGSL" diff --git a/naga/tests/out/wgsl/wgsl-mesh-shader-empty.wgsl b/naga/tests/out/wgsl/wgsl-mesh-shader-empty.wgsl new file mode 100644 index 0000000000..c5e853af26 --- /dev/null +++ b/naga/tests/out/wgsl/wgsl-mesh-shader-empty.wgsl @@ -0,0 +1,33 @@ +enable wgpu_mesh_shader; + +struct TaskPayload { + dummy: u32, +} + +struct VertexOutput { + @builtin(position) position: vec4, +} + +struct PrimitiveOutput { + @builtin(triangle_indices) indices: vec3, +} + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, +} + +var taskPayload: TaskPayload; +var mesh_output: MeshOutput; + +@task @payload(taskPayload) @workgroup_size(1, 1, 1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + return vec3(1u, 1u, 1u); +} + +@mesh(mesh_output) @workgroup_size(1, 1, 1) @payload(taskPayload) +fn ms_main() { + return; +} diff --git a/naga/tests/out/wgsl/wgsl-mesh-shader-lines.wgsl b/naga/tests/out/wgsl/wgsl-mesh-shader-lines.wgsl new file mode 100644 index 0000000000..fe7c341f30 --- /dev/null +++ b/naga/tests/out/wgsl/wgsl-mesh-shader-lines.wgsl @@ -0,0 +1,33 @@ +enable wgpu_mesh_shader; + +struct TaskPayload { + dummy: u32, +} + +struct VertexOutput { + @builtin(position) position: vec4, +} + +struct PrimitiveOutput { + @builtin(line_indices) indices: vec2, +} + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, +} + +var taskPayload: TaskPayload; +var mesh_output: MeshOutput; + +@task @payload(taskPayload) @workgroup_size(1, 1, 1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + return vec3(1u, 1u, 1u); +} + +@mesh(mesh_output) @workgroup_size(1, 1, 1) @payload(taskPayload) +fn ms_main() { + return; +} diff --git a/naga/tests/out/wgsl/wgsl-mesh-shader-points.wgsl b/naga/tests/out/wgsl/wgsl-mesh-shader-points.wgsl new file mode 100644 index 0000000000..b6eea73d08 --- /dev/null +++ b/naga/tests/out/wgsl/wgsl-mesh-shader-points.wgsl @@ -0,0 +1,33 @@ +enable wgpu_mesh_shader; + +struct TaskPayload { + dummy: u32, +} + +struct VertexOutput { + @builtin(position) position: vec4, +} + +struct PrimitiveOutput { + @builtin(point_index) indices: u32, +} + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, +} + +var taskPayload: TaskPayload; +var mesh_output: MeshOutput; + +@task @payload(taskPayload) @workgroup_size(1, 1, 1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + return vec3(1u, 1u, 1u); +} + +@mesh(mesh_output) @workgroup_size(1, 1, 1) @payload(taskPayload) +fn ms_main() { + return; +} diff --git a/naga/tests/out/wgsl/wgsl-mesh-shader.wgsl b/naga/tests/out/wgsl/wgsl-mesh-shader.wgsl new file mode 100644 index 0000000000..5a4a91dce3 --- /dev/null +++ b/naga/tests/out/wgsl/wgsl-mesh-shader.wgsl @@ -0,0 +1,66 @@ +enable wgpu_mesh_shader; + +struct TaskPayload { + colorMask: vec4, + visible: bool, +} + +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) color: vec4, +} + +struct PrimitiveOutput { + @builtin(triangle_indices) indices: vec3, + @builtin(cull_primitive) cull: bool, + @location(1) @per_primitive colorMask: vec4, +} + +struct PrimitiveInput { + @location(1) @per_primitive colorMask: vec4, +} + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, +} + +var taskPayload: TaskPayload; +var workgroupData: f32; +var mesh_output: MeshOutput; + +@task @payload(taskPayload) @workgroup_size(1, 1, 1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + workgroupData = 1f; + taskPayload.colorMask = vec4(1f, 1f, 0f, 1f); + taskPayload.visible = true; + return vec3(1u, 1u, 1u); +} + +@mesh(mesh_output) @workgroup_size(1, 1, 1) @payload(taskPayload) +fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { + mesh_output.vertex_count = 3u; + mesh_output.primitive_count = 1u; + workgroupData = 2f; + mesh_output.vertices[0].position = vec4(0f, 1f, 0f, 1f); + let _e25 = taskPayload.colorMask; + mesh_output.vertices[0].color = (vec4(0f, 1f, 0f, 1f) * _e25); + mesh_output.vertices[1].position = vec4(-1f, -1f, 0f, 1f); + let _e47 = taskPayload.colorMask; + mesh_output.vertices[1].color = (vec4(0f, 0f, 1f, 1f) * _e47); + mesh_output.vertices[2].position = vec4(1f, -1f, 0f, 1f); + let _e69 = taskPayload.colorMask; + mesh_output.vertices[2].color = (vec4(1f, 0f, 0f, 1f) * _e69); + mesh_output.primitives[0].indices = vec3(0u, 1u, 2u); + let _e90 = taskPayload.visible; + mesh_output.primitives[0].cull = !(_e90); + mesh_output.primitives[0].colorMask = vec4(1f, 0f, 1f, 1f); + return; +} + +@fragment +fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { + return (vertex.color * primitive.colorMask); +}