Skip to content

Commit eecc9f9

Browse files
committed
Matrix4x3: add Matrix4x3, needed for raytracing intrinsics
1 parent d1fce90 commit eecc9f9

File tree

6 files changed

+98
-41
lines changed

6 files changed

+98
-41
lines changed

crates/spirv-std/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ pub mod float;
9595
pub mod image;
9696
pub mod indirect_command;
9797
pub mod integer;
98+
pub mod matrix;
9899
pub mod memory;
99100
pub mod number;
100101
pub mod ray_tracing;

crates/spirv-std/src/matrix.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//! a set of common SPIR-V Matrices, used for intrinsics
2+
3+
use glam::{Affine3A, Mat3, Mat3A, Mat4, Vec3, Vec3A};
4+
5+
/// A Matrix with 4 columns of [`Vec3`], very similar to glam's [`Affine3A`].
6+
///
7+
/// Primarily used in ray tracing extensions to represent object rotation, scale and translation.
8+
#[derive(Clone, Copy, Default, PartialEq)]
9+
#[repr(C)]
10+
#[spirv(matrix)]
11+
#[allow(missing_docs)]
12+
pub struct Matrix4x3 {
13+
pub x: Vec3A,
14+
pub y: Vec3A,
15+
pub z: Vec3A,
16+
pub w: Vec3A,
17+
}
18+
19+
/// The `from_*` fn signatures should match [`Affine3A`], to make it easier to switch to [`Affine3A`] later.
20+
/// The `to_*` fn signatures are custom
21+
impl Matrix4x3 {
22+
/// Convert from glam's [`Affine3A`]
23+
pub fn from_affine3a(affine: Affine3A) -> Self {
24+
Self {
25+
x: affine.x_axis,
26+
y: affine.y_axis,
27+
z: affine.z_axis,
28+
w: affine.w_axis,
29+
}
30+
}
31+
32+
/// Creates an affine transform from a 3x3 matrix (expressing scale, shear and
33+
/// rotation)
34+
pub fn from_mat3(mat3: Mat3) -> Self {
35+
Self::from_affine3a(Affine3A::from_mat3(mat3))
36+
}
37+
38+
/// Creates an affine transform from a 3x3 matrix (expressing scale, shear and rotation)
39+
/// and a translation vector.
40+
///
41+
/// Equivalent to `Affine3A::from_translation(translation) * Affine3A::from_mat3(mat3)`
42+
pub fn from_mat3_translation(mat3: Mat3, translation: Vec3) -> Self {
43+
Self::from_affine3a(Affine3A::from_mat3_translation(mat3, translation))
44+
}
45+
46+
/// The given `Mat4` must be an affine transform,
47+
/// i.e. contain no perspective transform.
48+
pub fn from_mat4(m: Mat4) -> Self {
49+
Self::from_affine3a(Affine3A::from_mat4(m))
50+
}
51+
52+
/// Convert to glam's [`Affine3A`]
53+
pub fn to_affine3a(&self) -> Affine3A {
54+
Affine3A {
55+
matrix3: Mat3A {
56+
x_axis: self.x,
57+
y_axis: self.y,
58+
z_axis: self.z,
59+
},
60+
translation: self.w,
61+
}
62+
}
63+
64+
/// Creates a 3x3 matrix representing the rotation and scale, cutting off the translation
65+
pub fn to_mat3a(&self) -> Mat3A {
66+
self.to_affine3a().matrix3
67+
}
68+
69+
/// Creates a 3x3 matrix representing the rotation and scale, cutting off the translation
70+
pub fn to_mat3(&self) -> Mat3 {
71+
Mat3::from(self.to_mat3a())
72+
}
73+
74+
/// Creates a 4x4 matrix from this affine transform
75+
pub fn to_mat4(&self) -> Mat4 {
76+
Mat4::from(self.to_affine3a())
77+
}
78+
}

crates/spirv-std/src/ray_tracing.rs

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// NOTE(eddyb) "&-masking with zero", likely due to `NONE = 0` in `bitflags!`.
44
#![allow(clippy::bad_bit_mask)]
55

6+
use crate::matrix::Matrix4x3;
67
use crate::vector::Vector;
78
#[cfg(target_arch = "spirv")]
89
use core::arch::asm;
@@ -1002,22 +1003,14 @@ impl RayQuery {
10021003
#[spirv_std_macros::gpu_only]
10031004
#[doc(alias = "OpRayQueryGetIntersectionObjectToWorldKHR")]
10041005
#[inline]
1005-
pub unsafe fn get_candidate_intersection_object_to_world<V: Vector<f32, 3>>(&self) -> [V; 4] {
1006+
pub unsafe fn get_candidate_intersection_object_to_world(&self) -> Matrix4x3 {
10061007
unsafe {
10071008
let mut result = Default::default();
10081009

10091010
asm! {
10101011
"%u32 = OpTypeInt 32 0",
1011-
"%f32 = OpTypeFloat 32",
1012-
"%f32x3 = OpTypeVector %f32 3",
1013-
"%f32x3x4 = OpTypeMatrix %f32x3 4",
10141012
"%intersection = OpConstant %u32 0",
1015-
"%matrix = OpRayQueryGetIntersectionObjectToWorldKHR %f32x3x4 {ray_query} %intersection",
1016-
"%col0 = OpCompositeExtract %f32x3 %matrix 0",
1017-
"%col1 = OpCompositeExtract %f32x3 %matrix 1",
1018-
"%col2 = OpCompositeExtract %f32x3 %matrix 2",
1019-
"%col3 = OpCompositeExtract %f32x3 %matrix 3",
1020-
"%result = OpCompositeConstruct typeof*{result} %col0 %col1 %col2 %col3",
1013+
"%result = OpRayQueryGetIntersectionObjectToWorldKHR typeof*{result} {ray_query} %intersection",
10211014
"OpStore {result} %result",
10221015
ray_query = in(reg) self,
10231016
result = in(reg) &mut result,
@@ -1037,22 +1030,14 @@ impl RayQuery {
10371030
#[spirv_std_macros::gpu_only]
10381031
#[doc(alias = "OpRayQueryGetIntersectionObjectToWorldKHR")]
10391032
#[inline]
1040-
pub unsafe fn get_committed_intersection_object_to_world<V: Vector<f32, 3>>(&self) -> [V; 4] {
1033+
pub unsafe fn get_committed_intersection_object_to_world(&self) -> Matrix4x3 {
10411034
unsafe {
10421035
let mut result = Default::default();
10431036

10441037
asm! {
10451038
"%u32 = OpTypeInt 32 0",
1046-
"%f32 = OpTypeFloat 32",
1047-
"%f32x3 = OpTypeVector %f32 3",
1048-
"%f32x3x4 = OpTypeMatrix %f32x3 4",
10491039
"%intersection = OpConstant %u32 1",
1050-
"%matrix = OpRayQueryGetIntersectionObjectToWorldKHR %f32x3x4 {ray_query} %intersection",
1051-
"%col0 = OpCompositeExtract %f32x3 %matrix 0",
1052-
"%col1 = OpCompositeExtract %f32x3 %matrix 1",
1053-
"%col2 = OpCompositeExtract %f32x3 %matrix 2",
1054-
"%col3 = OpCompositeExtract %f32x3 %matrix 3",
1055-
"%result = OpCompositeConstruct typeof*{result} %col0 %col1 %col2 %col3",
1040+
"%result = OpRayQueryGetIntersectionObjectToWorldKHR typeof*{result} {ray_query} %intersection",
10561041
"OpStore {result} %result",
10571042
ray_query = in(reg) self,
10581043
result = in(reg) &mut result,

tests/compiletests/ui/arch/debug_printf_type_checking.stderr

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ help: the return type of this call is `u32` due to the type of the argument pass
7575
| |
7676
| this argument influences the return type of `debug_printf_assert_is_type`
7777
note: function defined here
78-
--> $SPIRV_STD_SRC/lib.rs:133:8
78+
--> $SPIRV_STD_SRC/lib.rs:134:8
7979
|
80-
133 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
80+
134 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
8181
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
8282
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)
8383
help: change the type of the numeric literal from `u32` to `f32`
@@ -103,9 +103,9 @@ help: the return type of this call is `f32` due to the type of the argument pass
103103
| |
104104
| this argument influences the return type of `debug_printf_assert_is_type`
105105
note: function defined here
106-
--> $SPIRV_STD_SRC/lib.rs:133:8
106+
--> $SPIRV_STD_SRC/lib.rs:134:8
107107
|
108-
133 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
108+
134 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
109109
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
110110
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)
111111
help: change the type of the numeric literal from `f32` to `u32`
@@ -131,12 +131,12 @@ error[E0277]: the trait bound `{float}: Vector<f32, 2>` is not satisfied
131131
`UVec3` implements `Vector<u32, 3>`
132132
and 5 others
133133
note: required by a bound in `debug_printf_assert_is_vector`
134-
--> $SPIRV_STD_SRC/lib.rs:140:8
134+
--> $SPIRV_STD_SRC/lib.rs:141:8
135135
|
136-
138 | pub fn debug_printf_assert_is_vector<
136+
139 | pub fn debug_printf_assert_is_vector<
137137
| ----------------------------- required by a bound in this function
138-
139 | TY: crate::scalar::Scalar,
139-
140 | V: crate::vector::Vector<TY, SIZE>,
138+
140 | TY: crate::scalar::Scalar,
139+
141 | V: crate::vector::Vector<TY, SIZE>,
140140
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `debug_printf_assert_is_vector`
141141
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)
142142

@@ -157,9 +157,9 @@ help: the return type of this call is `Vec2` due to the type of the argument pas
157157
| |
158158
| this argument influences the return type of `debug_printf_assert_is_type`
159159
note: function defined here
160-
--> $SPIRV_STD_SRC/lib.rs:133:8
160+
--> $SPIRV_STD_SRC/lib.rs:134:8
161161
|
162-
133 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
162+
134 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
163163
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
164164
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)
165165

tests/compiletests/ui/arch/ray_query_get_intersection_object_to_world_khr.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
33

44
use glam::Vec3;
5+
use spirv_std::matrix::Matrix4x3;
56
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
67
use spirv_std::spirv;
78

@@ -10,7 +11,7 @@ pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStruct
1011
unsafe {
1112
spirv_std::ray_query!(let mut handle);
1213
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
13-
let matrix: [glam::Vec3; 4] = handle.get_candidate_intersection_object_to_world();
14-
let matrix: [glam::Vec3; 4] = handle.get_committed_intersection_object_to_world();
14+
let matrix: Matrix4x3 = handle.get_candidate_intersection_object_to_world();
15+
let matrix: Matrix4x3 = handle.get_committed_intersection_object_to_world();
1516
}
1617
}

tests/compiletests/ui/spirv-attr/all-builtins.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,9 @@
33
// compile-flags: -Ctarget-feature=+DeviceGroup,+DrawParameters,+FragmentBarycentricNV,+FragmentBarycentricKHR,+FragmentDensityEXT,+FragmentFullyCoveredEXT,+Geometry,+GroupNonUniform,+GroupNonUniformBallot,+MeshShadingNV,+MultiView,+MultiViewport,+RayTracingKHR,+SampleRateShading,+ShaderSMBuiltinsNV,+ShaderStereoViewNV,+StencilExportEXT,+Tessellation,+ext:SPV_AMD_shader_explicit_vertex_parameter,+ext:SPV_EXT_fragment_fully_covered,+ext:SPV_EXT_fragment_invocation_density,+ext:SPV_EXT_shader_stencil_export,+ext:SPV_KHR_ray_tracing,+ext:SPV_NV_fragment_shader_barycentric,+ext:SPV_NV_mesh_shader,+ext:SPV_NV_shader_sm_builtins,+ext:SPV_NV_stereo_view_rendering
44

55
use spirv_std::glam::*;
6+
use spirv_std::matrix::Matrix4x3;
67
use spirv_std::spirv;
78

8-
#[derive(Clone, Copy)]
9-
#[spirv(matrix)]
10-
pub struct Matrix4x3 {
11-
pub x: glam::Vec3,
12-
pub y: glam::Vec3,
13-
pub z: glam::Vec3,
14-
pub w: glam::Vec3,
15-
}
16-
179
#[spirv(tessellation_control)]
1810
pub fn tessellation_control(
1911
#[spirv(invocation_id)] invocation_id: u32,

0 commit comments

Comments
 (0)