Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions slangpy/tests/device/test_shader_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class Var:
"u_float2x4": Var(kind="matrix", type="float", value=[[0, 1, 2, 3], [4, 5, 6, 7]]),
"u_float3x4": Var(kind="matrix", type="float", value=[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
"u_float4x4": Var(kind="matrix", type="float", value=[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]),
"pb_float4x4": Var(kind="matrix", type="float", value=[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]),
"pb_float4x3": Var(kind="matrix", type="float", value=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),
"pb_float3x4": Var(kind="matrix", type="float", value=[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# float16_t
"u_float16_t": Var(kind="scalar", type="float16_t", value=1.2345),
"u_float16_t_min": Var(kind="scalar", type="float16_t", value=FLOAT16_MIN),
Expand Down Expand Up @@ -162,6 +165,31 @@ class Var:
"f_float2": Var(kind="vector", type="float", value=[1.23, 1.234]),
"f_float3": Var(kind="vector", type="float", value=[1.23, 1.234, 1.2345]),
"f_float4": Var(kind="vector", type="float", value=[1.23, 1.235, 1.23456, 1.234567]),
"f_float4x3": Var(kind="matrix", type="float", value=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),
"f_bool_array": Var(kind="array", type="bool", value=[False]),
"f_int_array": Var(kind="array", type="int", value=[-10, 10]),
"f_uint_array": Var(kind="array", type="uint", value=[0, 10, 20]),
"f_float_array": Var(kind="array", type="float", value=[0.1, 0.2, 0.3, 0.4]),
},
# pb_struct
"pb_struct": {
"f_bool": Var(kind="scalar", type="bool", value=True),
"f_bool2": Var(kind="vector", type="bool", value=[False, True]),
"f_bool3": Var(kind="vector", type="bool", value=[False, True, False]),
"f_bool4": Var(kind="vector", type="bool", value=[False, True, False, True]),
"f_int": Var(kind="scalar", type="int", value=-123),
"f_int2": Var(kind="vector", type="int", value=[-123, 123]),
"f_int3": Var(kind="vector", type="int", value=[-123, 123, -1234]),
"f_int4": Var(kind="vector", type="int", value=[-123, 123, -1234, 1234]),
"f_uint": Var(kind="scalar", type="uint", value=12),
"f_uint2": Var(kind="vector", type="uint", value=[123, 1234]),
"f_uint3": Var(kind="vector", type="uint", value=[123, 1234, 12345]),
"f_uint4": Var(kind="vector", type="uint", value=[123, 1235, 123456, 1234567]),
"f_float": Var(kind="scalar", type="float", value=1.2),
"f_float2": Var(kind="vector", type="float", value=[1.23, 1.234]),
"f_float3": Var(kind="vector", type="float", value=[1.23, 1.234, 1.2345]),
"f_float4": Var(kind="vector", type="float", value=[1.23, 1.235, 1.23456, 1.234567]),
"f_float4x3": Var(kind="matrix", type="float", value=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),
"f_bool_array": Var(kind="array", type="bool", value=[False]),
"f_int_array": Var(kind="array", type="int", value=[-10, 10]),
"f_uint_array": Var(kind="array", type="uint", value=[0, 10, 20]),
Expand Down Expand Up @@ -225,6 +253,7 @@ def convert_matrix(type: str, rows: int, cols: int, values: Any):
TABLE = {
("float", 2, 2): spy.float2x2,
("float", 3, 3): spy.float3x3,
("float", 4, 3): spy.float4x3,
("float", 2, 4): spy.float2x4,
("float", 3, 4): spy.float3x4,
("float", 4, 4): spy.float4x4,
Expand Down Expand Up @@ -320,6 +349,10 @@ def write_vars(
):
if isinstance(vars, dict):
for key, var in vars.items():
# Disabling ParameterBlock tests on Vulkan, due to issue:
# https://github.com/shader-slang/slang/issues/7431
if device_type == spy.DeviceType.vulkan and key.startswith("pb_"):
continue
if isinstance(var, dict):
write_vars(device_type, cursor[key], var, key + ".")
elif isinstance(var, list):
Expand Down
51 changes: 50 additions & 1 deletion slangpy/tests/device/test_shader_cursor.slang
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Enable (and remove the define) after this issue has been addressed:
// https://github.com/shader-slang/slang/issues/7431
// Slang currently generates somewhat invalid code, and at the same time,
// the tests fail in a weird fashion (u_struct_array[1:] are not set).
#define ENABLE_PARAMBLOCK_TESTS !defined(__TARGET_VULKAN__)

struct Test {
bool f_bool;
bool2 f_bool2;
Expand All @@ -17,6 +23,7 @@ struct Test {
float2 f_float2;
float3 f_float3;
float4 f_float4;
float4x3 f_float4x3;

bool f_bool_array[1];
int f_int_array[2];
Expand Down Expand Up @@ -85,6 +92,11 @@ uniform float3x3 u_float3x3;
uniform float2x4 u_float2x4;
uniform float3x4 u_float3x4;
uniform float4x4 u_float4x4;
#if ENABLE_PARAMBLOCK_TESTS
ParameterBlock<float4x4> pb_float4x4;
ParameterBlock<float4x3> pb_float4x3;
ParameterBlock<float3x4> pb_float3x4;
#endif

uniform float16_t u_float16_t;
uniform float16_t u_float16_t_min;
Expand All @@ -108,12 +120,14 @@ uniform uint u_uint_array[4];
uniform float u_float_array[4];

uniform Test u_struct;
#if ENABLE_PARAMBLOCK_TESTS
ParameterBlock<Test> pb_struct;
#endif
uniform int u_int_array_2[4];
uniform Test u_struct_array[4];

Buffer<uint> u_buffer;


RWStructuredBuffer<uint> results;

struct Writer {
Expand Down Expand Up @@ -260,6 +274,11 @@ void compute_main(uint3 tid: SV_DispatchThreadID)
writer.write(u_float2x4);
writer.write(u_float3x4);
writer.write(u_float4x4);
#if ENABLE_PARAMBLOCK_TESTS
writer.write(pb_float4x4);
writer.write(pb_float4x3);
writer.write(pb_float3x4);
#endif

writer.write(u_float16_t);
writer.write(u_float16_t_min);
Expand Down Expand Up @@ -302,6 +321,7 @@ void compute_main(uint3 tid: SV_DispatchThreadID)
writer.write(u_struct.f_float2);
writer.write(u_struct.f_float3);
writer.write(u_struct.f_float4);
writer.write(u_struct.f_float4x3);

for (uint i = 0; i < 1; ++i)
writer.write(u_struct.f_bool_array[i]);
Expand All @@ -312,6 +332,35 @@ void compute_main(uint3 tid: SV_DispatchThreadID)
for (uint i = 0; i < 4; ++i)
writer.write(u_struct.f_float_array[i]);

#if ENABLE_PARAMBLOCK_TESTS
writer.write(pb_struct.f_bool);
writer.write(pb_struct.f_bool2);
writer.write(pb_struct.f_bool3);
writer.write(pb_struct.f_bool4);
writer.write(pb_struct.f_int);
writer.write(pb_struct.f_int2);
writer.write(pb_struct.f_int3);
writer.write(pb_struct.f_int4);
writer.write(pb_struct.f_uint);
writer.write(pb_struct.f_uint2);
writer.write(pb_struct.f_uint3);
writer.write(pb_struct.f_uint4);
writer.write(pb_struct.f_float);
writer.write(pb_struct.f_float2);
writer.write(pb_struct.f_float3);
writer.write(pb_struct.f_float4);
writer.write(pb_struct.f_float4x3);

for (uint i = 0; i < 1; ++i)
writer.write(pb_struct.f_bool_array[i]);
for (uint i = 0; i < 2; ++i)
writer.write(pb_struct.f_int_array[i]);
for (uint i = 0; i < 3; ++i)
writer.write(pb_struct.f_uint_array[i]);
for (uint i = 0; i < 4; ++i)
writer.write(pb_struct.f_float_array[i]);
#endif

for (uint i = 0; i < 4; ++i)
writer.write(u_int_array_2[i]);

Expand Down
8 changes: 7 additions & 1 deletion src/slangpy_ext/device/cursor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,14 @@ class WriteConverterTable {
}
case TypeReflection::Kind::constant_buffer:
case TypeReflection::Kind::parameter_block:
if constexpr (requires { self.dereference(); }) {
// Unwrap constant buffers or parameter blocks for shader cursors
auto child = self.dereference();
write_internal(child, nbval);
return;
} else
SGL_THROW("constant_buffer and param_block not expected in BufferElementCursor");
case TypeReflection::Kind::struct_: {
// Unwrap constant buffers or parameter blocks
if (kind != TypeReflection::Kind::struct_)
type_layout = type_layout->getElementTypeLayout();

Expand Down
Loading