diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 317954157e4..7c851f3869a 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -2231,6 +2231,24 @@ bool SemanticsVisitor::canCoerce( return rs; } +bool SemanticsVisitor::canCoerce( + CoercionSite site, + Type* toType, + QualType fromType, + Expr* fromExpr, + ConversionCost* outCost) +{ + // For site-specific coercion, we bypass the cache since the cache + // doesn't account for CoercionSite differences + ConversionCost cost; + bool rs = _coerce(site, toType, nullptr, fromType, fromExpr, getSink(), &cost); + + if (outCost) + *outCost = cost; + + return rs; +} + TypeCastExpr* SemanticsVisitor::createImplicitCastExpr() { return m_astBuilder->create(); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index f7dc4d714fb..9478e78f005 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1775,6 +1775,10 @@ struct SemanticsVisitor : public SemanticsContext /// bool canCoerce(Type* toType, QualType fromType, Expr* fromExpr, ConversionCost* outCost = 0); + /// Version of `canCoerce` that accepts a `CoercionSite` parameter. + /// This allows for site-specific conversion rules (e.g., sized array to unsized array conversion for function arguments). + bool canCoerce(CoercionSite site, Type* toType, QualType fromType, Expr* fromExpr, ConversionCost* outCost = 0); + TypeCastExpr* createImplicitCastExpr(); Expr* CreateImplicitCastExpr(Type* toType, Expr* fromExpr); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 24b48091155..61f336632bd 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -744,7 +744,7 @@ bool SemanticsVisitor::TryCheckOverloadCandidateTypes( if (!paramType->equals(argType)) return {nullptr, nullptr}; } - else if (!canCoerce(paramType, argType, arg.argExpr, &cost)) + else if (!canCoerce(CoercionSite::Argument, paramType, argType, arg.argExpr, &cost)) { return {nullptr, nullptr}; } diff --git a/tests/language-feature/array-overload-resolution.slang b/tests/language-feature/array-overload-resolution.slang new file mode 100644 index 00000000000..ff6b0ab2fca --- /dev/null +++ b/tests/language-feature/array-overload-resolution.slang @@ -0,0 +1,40 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -cpu + +// Test that sized arrays can be passed to unsized array parameters +// during function overload resolution. + +float4 Function(float4[] f) { return f[0]; } +float4 Function(RWStructuredBuffer f) { return f[0]; } + +// Test with different array sizes +int TestFunc(int[] arr) { return arr[0]; } +int TestFunc(float val) { return (int)val; } + +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + static float4[16] input1 = { + float4(1, 2, 3, 4), + float4(5, 6, 7, 8) + }; + static float4[8] input2 = { + float4(10, 11, 12, 13) + }; + + // Test int arrays too + static int[5] intArray1 = { 42 }; + static int[3] intArray2 = { 100 }; + + float4 result1 = Function(input1); // Should call array version + float4 result2 = Function(input2); // Should call array version + int intResult1 = TestFunc(intArray1); // Should call array version + int intResult2 = TestFunc(intArray2); // Should call array version + + // Store results + outputBuffer[0] = result1 + result2 + float4(intResult1, intResult2, 0, 0); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=16):out +//CHECK: 113.000000 \ No newline at end of file