diff --git a/slangpy/builtin/array.py b/slangpy/builtin/array.py index c295e220a..f712c8b3e 100644 --- a/slangpy/builtin/array.py +++ b/slangpy/builtin/array.py @@ -7,15 +7,12 @@ from slangpy.builtin.value import ValueMarshall from slangpy.reflection import SlangType, SlangProgramLayout from slangpy.bindings import ( - PYTHON_SIGNATURES, - PYTHON_TYPES, BindContext, BoundVariable, - BoundVariableRuntime, CodeGenBlock, ) from slangpy import ShaderCursor, ShaderObject -from slangpy.core.native import AccessType, CallContext, NativeValueMarshall, unpack_arg +from slangpy.core.native import AccessType, unpack_arg import slangpy.reflection as kfr diff --git a/slangpy/builtin/tensor.py b/slangpy/builtin/tensor.py index 20ee161cd..29e7f470f 100644 --- a/slangpy/builtin/tensor.py +++ b/slangpy/builtin/tensor.py @@ -3,7 +3,7 @@ from typing import Any, Optional, cast -from slangpy.core.native import AccessType, Shape +from slangpy.core.native import AccessType, Shape, CallMode from slangpy.reflection.reflectiontypes import is_matching_array_type, VectorType from slangpy.types.tensor import Tensor @@ -199,18 +199,33 @@ def resolve_type(self, context: BindContext, bound_type: SlangType): f"to tensor with element type {bound_type.dtype.full_name}" ) - # Atomic tensors are special, they must be passed as-is - if bound_type.name == "AtomicTensor": - return bound_type - - return build_tensor_type( - self.layout, - bound_type.dtype, - bound_type.dims, - bound_type.writable, - self.d_in is not None, - self.d_out is not None, - ) + # If binding to an interface, need to decide on tensor type based on call mode + if bound_type.name in ("ITensor", "RWTensor"): + if context.call_mode == CallMode.prim: + # In forwards pass, bind a Tensor or RWTensor depending on writability + return build_tensor_type( + self.layout, + bound_type.dtype, + bound_type.dims, + bound_type.writable, + False, + False, + ) + else: + # If we are in a backward pass, no choice but to bind the full tensor + # type as
we don't have context at this point to know if it should be + # GradIn/GradOut/GradInOutTensor + return build_tensor_type( + self.layout, + bound_type.dtype, + bound_type.dims, + bound_type.writable, + self.d_in is not None, + self.d_out is not None, + ) + + # None-interfaces have to be bound to the exact type + return bound_type # if implicit element casts enabled, allow conversion from type to element type if context.options["implicit_element_casts"]: @@ -256,21 +271,25 @@ return self.dims + len(self.slang_element_type.shape) - len(vector_target_type.shape) def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: BoundVariable): - if isinstance(binding.vector_type, ITensorType): - writable = binding.vector_type.writable - else: - writable = binding.access[0] in (AccessType.write, AccessType.readwrite) - # Atomic tensors are special, they must be passed as-is - if binding.vector_type.name == "AtomicTensor": - type_name = binding.vector_type.full_name + if isinstance(binding.vector_type, ITensorType): + # If binding to a tensor type, we need to use the same basic tensor type. However + # dimensionality may differ, we still need to generate the full tensor type name. + assert not binding.vector_type.name in ("ITensor", "RWTensor") + type_name = ( + f"{binding.vector_type.name}<{self.slang_element_type.full_name}, {self.dims}>" + ) else: + # If binding to another type (eg vectorizing to a scalar), the tensor type to use + # is based on writability and existence of gradients.
type_name = build_tensor_name( self.slang_element_type, self.dims, - writable, - self.d_in is not None, - self.d_out is not None, + binding.access[0] in (AccessType.write, AccessType.readwrite), + self.d_in is not None + and binding.access[1] in (AccessType.read, AccessType.readwrite), + self.d_out is not None + and binding.access[1] in (AccessType.write, AccessType.readwrite), ) cgb.type_alias(f"_t_{binding.variable_name}", type_name) diff --git a/src/slangpy_ext/utils/slangpytensor.cpp b/src/slangpy_ext/utils/slangpytensor.cpp index b26296a20..742a561f6 100644 --- a/src/slangpy_ext/utils/slangpytensor.cpp +++ b/src/slangpy_ext/utils/slangpytensor.cpp @@ -141,18 +141,44 @@ void NativeTensorMarshall::write_shader_cursor_pre_dispatch( const ref& grad_in = primal->grad_in(); const ref& grad_out = primal->grad_out(); - if (!has_derivative()) { - write_shader_cursor_fields(context, binding, field, primal, read_back); - } else { - write_shader_cursor_fields(context, binding, field["primal"], primal, read_back); + ShaderCursor primal_field = field.find_field("primal"); + if (primal_field.is_valid()) { + // Record these pointers for debug checks + Buffer* bound_primal_buffer = primal->storage().get(); + Buffer* bound_grad_in_buffer = nullptr; + Buffer* bound_grad_out_buffer = nullptr; + + // Binding to a Tensor object that contains child primal and derivative Tensors.
+ write_shader_cursor_fields(context, binding, primal_field, primal, read_back); + if (m_d_in) { - SGL_CHECK(grad_in, "Missing required input gradients"); - write_shader_cursor_fields(context, binding, field["d_in"], grad_in.get(), read_back); + ShaderCursor d_in_field = field.find_field("d_in"); + if (d_in_field.is_valid()) { + bound_grad_in_buffer = grad_in->storage().get(); + write_shader_cursor_fields(context, binding, d_in_field, grad_in.get(), read_back); + } } + if (m_d_out) { - SGL_CHECK(grad_out, "Missing required input gradients"); - write_shader_cursor_fields(context, binding, field["d_out"], grad_out.get(), read_back); + ShaderCursor d_out_field = field.find_field("d_out"); + if (d_out_field.is_valid()) { + bound_grad_out_buffer = grad_out->storage().get(); + write_shader_cursor_fields(context, binding, d_out_field, grad_out.get(), read_back); + } } + + if (bound_primal_buffer == bound_grad_in_buffer || bound_primal_buffer == bound_grad_out_buffer) { + log_warn("Binding the same storage for primal and gradient on the same tensor. This will have serious " + "performance impacts."); + } + if (bound_grad_in_buffer != nullptr && bound_grad_in_buffer == bound_grad_out_buffer) { + log_warn("Binding the same storage for grad in and grad out on the same tensor. This will have serious " + "performance impacts."); + } + + } else { + // Binding to a single Tensor object that represents the primal. + write_shader_cursor_fields(context, binding, field, primal, read_back); } if (context->call_mode() != CallMode::prim && grad_in && grad_in == grad_out) {