diff --git a/crates/spirv-std/src/lib.rs b/crates/spirv-std/src/lib.rs index 4a4774823c..14e887b70f 100644 --- a/crates/spirv-std/src/lib.rs +++ b/crates/spirv-std/src/lib.rs @@ -95,6 +95,7 @@ pub mod float; pub mod image; pub mod indirect_command; pub mod integer; +pub mod matrix; pub mod memory; pub mod number; pub mod ray_tracing; diff --git a/crates/spirv-std/src/matrix.rs b/crates/spirv-std/src/matrix.rs new file mode 100644 index 0000000000..1e6e7ca79d --- /dev/null +++ b/crates/spirv-std/src/matrix.rs @@ -0,0 +1,78 @@ +//! a set of common SPIR-V Matrices, used for intrinsics + +use glam::{Affine3A, Mat3, Mat3A, Mat4, Vec3, Vec3A}; + +/// A Matrix with 4 columns of [`Vec3`], very similar to glam's [`Affine3A`]. +/// +/// Primarily used in ray tracing extensions to represent object rotation, scale and translation. +#[derive(Clone, Copy, Default, PartialEq)] +#[repr(C)] +#[spirv(matrix)] +#[allow(missing_docs)] +pub struct Matrix4x3 { + pub x: Vec3A, + pub y: Vec3A, + pub z: Vec3A, + pub w: Vec3A, +} + +/// The `from_*` fn signatures should match [`Affine3A`], to make it easier to switch to [`Affine3A`] later. +/// The `to_*` fn signatures are custom +impl Matrix4x3 { + /// Convert from glam's [`Affine3A`] + pub fn from_affine3a(affine: Affine3A) -> Self { + Self { + x: affine.x_axis, + y: affine.y_axis, + z: affine.z_axis, + w: affine.w_axis, + } + } + + /// Creates an affine transform from a 3x3 matrix (expressing scale, shear and + /// rotation) + pub fn from_mat3(mat3: Mat3) -> Self { + Self::from_affine3a(Affine3A::from_mat3(mat3)) + } + + /// Creates an affine transform from a 3x3 matrix (expressing scale, shear and rotation) + /// and a translation vector. + /// + /// Equivalent to `Affine3A::from_translation(translation) * Affine3A::from_mat3(mat3)` + pub fn from_mat3_translation(mat3: Mat3, translation: Vec3) -> Self { + Self::from_affine3a(Affine3A::from_mat3_translation(mat3, translation)) + } + + /// The given `Mat4` must be an affine transform, + /// i.e. contain no perspective transform. + pub fn from_mat4(m: Mat4) -> Self { + Self::from_affine3a(Affine3A::from_mat4(m)) + } + + /// Convert to glam's [`Affine3A`] + pub fn to_affine3a(self) -> Affine3A { + Affine3A { + matrix3: Mat3A { + x_axis: self.x, + y_axis: self.y, + z_axis: self.z, + }, + translation: self.w, + } + } + + /// Creates a 3x3 matrix representing the rotation and scale, cutting off the translation + pub fn to_mat3a(self) -> Mat3A { + self.to_affine3a().matrix3 + } + + /// Creates a 3x3 matrix representing the rotation and scale, cutting off the translation + pub fn to_mat3(self) -> Mat3 { + Mat3::from(self.to_mat3a()) + } + + /// Creates a 4x4 matrix from this affine transform + pub fn to_mat4(self) -> Mat4 { + Mat4::from(self.to_affine3a()) + } +} diff --git a/crates/spirv-std/src/ray_tracing.rs b/crates/spirv-std/src/ray_tracing.rs index ff18a2e9d2..eb00cb991a 100644 --- a/crates/spirv-std/src/ray_tracing.rs +++ b/crates/spirv-std/src/ray_tracing.rs @@ -3,6 +3,7 @@ // NOTE(eddyb) "&-masking with zero", likely due to `NONE = 0` in `bitflags!`. #![allow(clippy::bad_bit_mask)] +use crate::matrix::Matrix4x3; use crate::vector::Vector; #[cfg(target_arch = "spirv")] use core::arch::asm; @@ -1002,22 +1003,14 @@ impl RayQuery { #[spirv_std_macros::gpu_only] #[doc(alias = "OpRayQueryGetIntersectionObjectToWorldKHR")] #[inline] - pub unsafe fn get_candidate_intersection_object_to_world>(&self) -> [V; 4] { + pub unsafe fn get_candidate_intersection_object_to_world(&self) -> Matrix4x3 { unsafe { let mut result = Default::default(); asm! { "%u32 = OpTypeInt 32 0", - "%f32 = OpTypeFloat 32", - "%f32x3 = OpTypeVector %f32 3", - "%f32x3x4 = OpTypeMatrix %f32x3 4", "%intersection = OpConstant %u32 0", - "%matrix = OpRayQueryGetIntersectionObjectToWorldKHR %f32x3x4 {ray_query} %intersection", - "%col0 = OpCompositeExtract %f32x3 %matrix 0", - "%col1 = OpCompositeExtract %f32x3 %matrix 1", - "%col2 = OpCompositeExtract %f32x3 %matrix 2", - "%col3 = OpCompositeExtract %f32x3 %matrix 3", - "%result = OpCompositeConstruct typeof*{result} %col0 %col1 %col2 %col3", + "%result = OpRayQueryGetIntersectionObjectToWorldKHR typeof*{result} {ray_query} %intersection", "OpStore {result} %result", ray_query = in(reg) self, result = in(reg) &mut result, @@ -1037,22 +1030,14 @@ impl RayQuery { #[spirv_std_macros::gpu_only] #[doc(alias = "OpRayQueryGetIntersectionObjectToWorldKHR")] #[inline] - pub unsafe fn get_committed_intersection_object_to_world>(&self) -> [V; 4] { + pub unsafe fn get_committed_intersection_object_to_world(&self) -> Matrix4x3 { unsafe { let mut result = Default::default(); asm! { "%u32 = OpTypeInt 32 0", - "%f32 = OpTypeFloat 32", - "%f32x3 = OpTypeVector %f32 3", - "%f32x3x4 = OpTypeMatrix %f32x3 4", "%intersection = OpConstant %u32 1", - "%matrix = OpRayQueryGetIntersectionObjectToWorldKHR %f32x3x4 {ray_query} %intersection", - "%col0 = OpCompositeExtract %f32x3 %matrix 0", - "%col1 = OpCompositeExtract %f32x3 %matrix 1", - "%col2 = OpCompositeExtract %f32x3 %matrix 2", - "%col3 = OpCompositeExtract %f32x3 %matrix 3", - "%result = OpCompositeConstruct typeof*{result} %col0 %col1 %col2 %col3", + "%result = OpRayQueryGetIntersectionObjectToWorldKHR typeof*{result} {ray_query} %intersection", "OpStore {result} %result", ray_query = in(reg) self, result = in(reg) &mut result, diff --git a/tests/compiletests/ui/arch/debug_printf_type_checking.stderr b/tests/compiletests/ui/arch/debug_printf_type_checking.stderr index faec0bc0c7..5c1b487bde 100644 --- a/tests/compiletests/ui/arch/debug_printf_type_checking.stderr +++ b/tests/compiletests/ui/arch/debug_printf_type_checking.stderr @@ -75,9 +75,9 @@ help: the return type of this call is `u32` due to the type of the argument pass | | | this argument influences the return type of `debug_printf_assert_is_type` note: function defined here - --> $SPIRV_STD_SRC/lib.rs:133:8 + --> $SPIRV_STD_SRC/lib.rs:134:8 | -133 | pub fn debug_printf_assert_is_type(ty: T) -> T { +134 | pub fn debug_printf_assert_is_type(ty: T) -> T { | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ = note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info) help: change the type of the numeric literal from `u32` to `f32` @@ -103,9 +103,9 @@ help: the return type of this call is `f32` due to the type of the argument pass | | | this argument influences the return type of `debug_printf_assert_is_type` note: function defined here - --> $SPIRV_STD_SRC/lib.rs:133:8 + --> $SPIRV_STD_SRC/lib.rs:134:8 | -133 | pub fn debug_printf_assert_is_type(ty: T) -> T { +134 | pub fn debug_printf_assert_is_type(ty: T) -> T { | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ = note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info) help: change the type of the numeric literal from `f32` to `u32` @@ -131,12 +131,12 @@ error[E0277]: the trait bound `{float}: Vector` is not satisfied `UVec3` implements `Vector` and 5 others note: required by a bound in `debug_printf_assert_is_vector` - --> $SPIRV_STD_SRC/lib.rs:140:8 + --> $SPIRV_STD_SRC/lib.rs:141:8 | -138 | pub fn debug_printf_assert_is_vector< +139 | pub fn debug_printf_assert_is_vector< | ----------------------------- required by a bound in this function -139 | TY: crate::scalar::Scalar, -140 | V: crate::vector::Vector, +140 | TY: crate::scalar::Scalar, +141 | V: crate::vector::Vector, | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `debug_printf_assert_is_vector` = note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info) @@ -157,9 +157,9 @@ help: the return type of this call is `Vec2` due to the type of the argument pas | | | this argument influences the return type of `debug_printf_assert_is_type` note: function defined here - --> $SPIRV_STD_SRC/lib.rs:133:8 + --> $SPIRV_STD_SRC/lib.rs:134:8 | -133 | pub fn debug_printf_assert_is_type(ty: T) -> T { +134 | pub fn debug_printf_assert_is_type(ty: T) -> T { | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ = note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/compiletests/ui/arch/ray_query_get_intersection_object_to_world_khr.rs b/tests/compiletests/ui/arch/ray_query_get_intersection_object_to_world_khr.rs index ac3f9c3ac7..20b8fc5b2f 100644 --- a/tests/compiletests/ui/arch/ray_query_get_intersection_object_to_world_khr.rs +++ b/tests/compiletests/ui/arch/ray_query_get_intersection_object_to_world_khr.rs @@ -2,6 +2,7 @@ // compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query use glam::Vec3; +use spirv_std::matrix::Matrix4x3; use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery}; use spirv_std::spirv; @@ -10,7 +11,7 @@ pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStruct unsafe { spirv_std::ray_query!(let mut handle); handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0); - let matrix: [glam::Vec3; 4] = handle.get_candidate_intersection_object_to_world(); - let matrix: [glam::Vec3; 4] = handle.get_committed_intersection_object_to_world(); + let matrix: Matrix4x3 = handle.get_candidate_intersection_object_to_world(); + let matrix: Matrix4x3 = handle.get_committed_intersection_object_to_world(); } } diff --git a/tests/compiletests/ui/spirv-attr/all-builtins.rs b/tests/compiletests/ui/spirv-attr/all-builtins.rs index cd83259d39..396fd850f4 100644 --- a/tests/compiletests/ui/spirv-attr/all-builtins.rs +++ b/tests/compiletests/ui/spirv-attr/all-builtins.rs @@ -3,17 +3,9 @@ // compile-flags: -Ctarget-feature=+DeviceGroup,+DrawParameters,+FragmentBarycentricNV,+FragmentBarycentricKHR,+FragmentDensityEXT,+FragmentFullyCoveredEXT,+Geometry,+GroupNonUniform,+GroupNonUniformBallot,+MeshShadingNV,+MultiView,+MultiViewport,+RayTracingKHR,+SampleRateShading,+ShaderSMBuiltinsNV,+ShaderStereoViewNV,+StencilExportEXT,+Tessellation,+ext:SPV_AMD_shader_explicit_vertex_parameter,+ext:SPV_EXT_fragment_fully_covered,+ext:SPV_EXT_fragment_invocation_density,+ext:SPV_EXT_shader_stencil_export,+ext:SPV_KHR_ray_tracing,+ext:SPV_NV_fragment_shader_barycentric,+ext:SPV_NV_mesh_shader,+ext:SPV_NV_shader_sm_builtins,+ext:SPV_NV_stereo_view_rendering use spirv_std::glam::*; +use spirv_std::matrix::Matrix4x3; use spirv_std::spirv; -#[derive(Clone, Copy)] -#[spirv(matrix)] -pub struct Matrix4x3 { - pub x: glam::Vec3, - pub y: glam::Vec3, - pub z: glam::Vec3, - pub w: glam::Vec3, -} - #[spirv(tessellation_control)] pub fn tessellation_control( #[spirv(invocation_id)] invocation_id: u32,