Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/spirv-std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ pub mod float;
pub mod image;
pub mod indirect_command;
pub mod integer;
pub mod matrix;
pub mod memory;
pub mod number;
pub mod ray_tracing;
Expand Down
78 changes: 78 additions & 0 deletions crates/spirv-std/src/matrix.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
//! a set of common SPIR-V Matrices, used for intrinsics

use glam::{Affine3A, Mat3, Mat3A, Mat4, Vec3, Vec3A};

/// A Matrix with 4 columns of [`Vec3`], very similar to glam's [`Affine3A`].
///
/// Primarily used in ray tracing extensions to represent object rotation, scale and translation.
#[derive(Clone, Copy, Default, PartialEq)]
#[repr(C)]
#[spirv(matrix)]
#[allow(missing_docs)]
pub struct Matrix4x3 {
pub x: Vec3A,
pub y: Vec3A,
pub z: Vec3A,
pub w: Vec3A,
}

/// The `from_*` fn signatures should match [`Affine3A`], to make it easier to switch to [`Affine3A`] later.
/// The `to_*` fn signatures are custom
impl Matrix4x3 {
/// Convert from glam's [`Affine3A`]
pub fn from_affine3a(affine: Affine3A) -> Self {
Self {
x: affine.x_axis,
y: affine.y_axis,
z: affine.z_axis,
w: affine.w_axis,
}
}

/// Creates an affine transform from a 3x3 matrix (expressing scale, shear and
/// rotation)
pub fn from_mat3(mat3: Mat3) -> Self {
Self::from_affine3a(Affine3A::from_mat3(mat3))
}

/// Creates an affine transform from a 3x3 matrix (expressing scale, shear and rotation)
/// and a translation vector.
///
/// Equivalent to `Affine3A::from_translation(translation) * Affine3A::from_mat3(mat3)`
pub fn from_mat3_translation(mat3: Mat3, translation: Vec3) -> Self {
Self::from_affine3a(Affine3A::from_mat3_translation(mat3, translation))
}

/// The given `Mat4` must be an affine transform,
/// i.e. contain no perspective transform.
pub fn from_mat4(m: Mat4) -> Self {
Self::from_affine3a(Affine3A::from_mat4(m))
}

/// Convert to glam's [`Affine3A`]
pub fn to_affine3a(self) -> Affine3A {
Affine3A {
matrix3: Mat3A {
x_axis: self.x,
y_axis: self.y,
z_axis: self.z,
},
translation: self.w,
}
}

/// Creates a 3x3 matrix representing the rotation and scale, cutting off the translation
pub fn to_mat3a(self) -> Mat3A {
self.to_affine3a().matrix3
}

/// Creates a 3x3 matrix representing the rotation and scale, cutting off the translation
pub fn to_mat3(self) -> Mat3 {
Mat3::from(self.to_mat3a())
}

/// Creates a 4x4 matrix from this affine transform
pub fn to_mat4(self) -> Mat4 {
Mat4::from(self.to_affine3a())
}
}
25 changes: 5 additions & 20 deletions crates/spirv-std/src/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// NOTE(eddyb) "&-masking with zero", likely due to `NONE = 0` in `bitflags!`.
#![allow(clippy::bad_bit_mask)]

use crate::matrix::Matrix4x3;
use crate::vector::Vector;
#[cfg(target_arch = "spirv")]
use core::arch::asm;
Expand Down Expand Up @@ -1002,22 +1003,14 @@ impl RayQuery {
#[spirv_std_macros::gpu_only]
#[doc(alias = "OpRayQueryGetIntersectionObjectToWorldKHR")]
#[inline]
pub unsafe fn get_candidate_intersection_object_to_world<V: Vector<f32, 3>>(&self) -> [V; 4] {
pub unsafe fn get_candidate_intersection_object_to_world(&self) -> Matrix4x3 {
unsafe {
let mut result = Default::default();

asm! {
"%u32 = OpTypeInt 32 0",
"%f32 = OpTypeFloat 32",
"%f32x3 = OpTypeVector %f32 3",
"%f32x3x4 = OpTypeMatrix %f32x3 4",
"%intersection = OpConstant %u32 0",
"%matrix = OpRayQueryGetIntersectionObjectToWorldKHR %f32x3x4 {ray_query} %intersection",
"%col0 = OpCompositeExtract %f32x3 %matrix 0",
"%col1 = OpCompositeExtract %f32x3 %matrix 1",
"%col2 = OpCompositeExtract %f32x3 %matrix 2",
"%col3 = OpCompositeExtract %f32x3 %matrix 3",
"%result = OpCompositeConstruct typeof*{result} %col0 %col1 %col2 %col3",
"%result = OpRayQueryGetIntersectionObjectToWorldKHR typeof*{result} {ray_query} %intersection",
"OpStore {result} %result",
ray_query = in(reg) self,
result = in(reg) &mut result,
Expand All @@ -1037,22 +1030,14 @@ impl RayQuery {
#[spirv_std_macros::gpu_only]
#[doc(alias = "OpRayQueryGetIntersectionObjectToWorldKHR")]
#[inline]
pub unsafe fn get_committed_intersection_object_to_world<V: Vector<f32, 3>>(&self) -> [V; 4] {
pub unsafe fn get_committed_intersection_object_to_world(&self) -> Matrix4x3 {
unsafe {
let mut result = Default::default();

asm! {
"%u32 = OpTypeInt 32 0",
"%f32 = OpTypeFloat 32",
"%f32x3 = OpTypeVector %f32 3",
"%f32x3x4 = OpTypeMatrix %f32x3 4",
"%intersection = OpConstant %u32 1",
"%matrix = OpRayQueryGetIntersectionObjectToWorldKHR %f32x3x4 {ray_query} %intersection",
"%col0 = OpCompositeExtract %f32x3 %matrix 0",
"%col1 = OpCompositeExtract %f32x3 %matrix 1",
"%col2 = OpCompositeExtract %f32x3 %matrix 2",
"%col3 = OpCompositeExtract %f32x3 %matrix 3",
"%result = OpCompositeConstruct typeof*{result} %col0 %col1 %col2 %col3",
"%result = OpRayQueryGetIntersectionObjectToWorldKHR typeof*{result} {ray_query} %intersection",
"OpStore {result} %result",
ray_query = in(reg) self,
result = in(reg) &mut result,
Expand Down
20 changes: 10 additions & 10 deletions tests/compiletests/ui/arch/debug_printf_type_checking.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ help: the return type of this call is `u32` due to the type of the argument pass
| |
| this argument influences the return type of `debug_printf_assert_is_type`
note: function defined here
--> $SPIRV_STD_SRC/lib.rs:133:8
--> $SPIRV_STD_SRC/lib.rs:134:8
|
133 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
134 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)
help: change the type of the numeric literal from `u32` to `f32`
Expand All @@ -103,9 +103,9 @@ help: the return type of this call is `f32` due to the type of the argument pass
| |
| this argument influences the return type of `debug_printf_assert_is_type`
note: function defined here
--> $SPIRV_STD_SRC/lib.rs:133:8
--> $SPIRV_STD_SRC/lib.rs:134:8
|
133 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
134 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)
help: change the type of the numeric literal from `f32` to `u32`
Expand All @@ -131,12 +131,12 @@ error[E0277]: the trait bound `{float}: Vector<f32, 2>` is not satisfied
`UVec3` implements `Vector<u32, 3>`
and 5 others
note: required by a bound in `debug_printf_assert_is_vector`
--> $SPIRV_STD_SRC/lib.rs:140:8
--> $SPIRV_STD_SRC/lib.rs:141:8
|
138 | pub fn debug_printf_assert_is_vector<
139 | pub fn debug_printf_assert_is_vector<
| ----------------------------- required by a bound in this function
139 | TY: crate::scalar::Scalar,
140 | V: crate::vector::Vector<TY, SIZE>,
140 | TY: crate::scalar::Scalar,
141 | V: crate::vector::Vector<TY, SIZE>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `debug_printf_assert_is_vector`
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)

Expand All @@ -157,9 +157,9 @@ help: the return type of this call is `Vec2` due to the type of the argument pas
| |
| this argument influences the return type of `debug_printf_assert_is_type`
note: function defined here
--> $SPIRV_STD_SRC/lib.rs:133:8
--> $SPIRV_STD_SRC/lib.rs:134:8
|
133 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
134 | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
= note: this error originates in the macro `debug_printf` (in Nightly builds, run with -Z macro-backtrace for more info)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query

use glam::Vec3;
use spirv_std::matrix::Matrix4x3;
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
use spirv_std::spirv;

Expand All @@ -10,7 +11,7 @@ pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStruct
unsafe {
spirv_std::ray_query!(let mut handle);
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
let matrix: [glam::Vec3; 4] = handle.get_candidate_intersection_object_to_world();
let matrix: [glam::Vec3; 4] = handle.get_committed_intersection_object_to_world();
let matrix: Matrix4x3 = handle.get_candidate_intersection_object_to_world();
let matrix: Matrix4x3 = handle.get_committed_intersection_object_to_world();
}
}
10 changes: 1 addition & 9 deletions tests/compiletests/ui/spirv-attr/all-builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,9 @@
// compile-flags: -Ctarget-feature=+DeviceGroup,+DrawParameters,+FragmentBarycentricNV,+FragmentBarycentricKHR,+FragmentDensityEXT,+FragmentFullyCoveredEXT,+Geometry,+GroupNonUniform,+GroupNonUniformBallot,+MeshShadingNV,+MultiView,+MultiViewport,+RayTracingKHR,+SampleRateShading,+ShaderSMBuiltinsNV,+ShaderStereoViewNV,+StencilExportEXT,+Tessellation,+ext:SPV_AMD_shader_explicit_vertex_parameter,+ext:SPV_EXT_fragment_fully_covered,+ext:SPV_EXT_fragment_invocation_density,+ext:SPV_EXT_shader_stencil_export,+ext:SPV_KHR_ray_tracing,+ext:SPV_NV_fragment_shader_barycentric,+ext:SPV_NV_mesh_shader,+ext:SPV_NV_shader_sm_builtins,+ext:SPV_NV_stereo_view_rendering

use spirv_std::glam::*;
use spirv_std::matrix::Matrix4x3;
use spirv_std::spirv;

#[derive(Clone, Copy)]
#[spirv(matrix)]
pub struct Matrix4x3 {
pub x: glam::Vec3,
pub y: glam::Vec3,
pub z: glam::Vec3,
pub w: glam::Vec3,
}

#[spirv(tessellation_control)]
pub fn tessellation_control(
#[spirv(invocation_id)] invocation_id: u32,
Expand Down