Skip to content

Commit 664fecd

Browse files
tangent-vectorslangbotjhelferty-nv
authored
Cleanups around around parameter-passing modes (#8754)
The primary goal of this change is to try to fix (read: *remove*) the logic that ovewrites the parameter-passing mode of an `out` parameter of non-copyable type to be a `ref` parameter instead. I need that change in order to implement a more reasonable pass to check for use of uninitialized values in IR, because the semantics of an `out` parameter are quite different from a `ref`, and I need to be able to tell them apart. Trying to make a systematic fix for the issue led me to make some more broad refactoring and cleanup changes to operations that deal with parameter-passing modes. Most notably, I have cleaned up several key functions in `slang-lower-to-ir.cpp` to have more clear comments that explain what they are doing, and why. The general idea is that determining the parameter-passing mode of a parameter breaks down into a few phases: * First, theres the "nominal" (aka "declared") parameter-passing mode. This is either the mode that is explicitly indicated by modifiers in the source code (e.g., `in` or `out` on an explicit parameter declaration, or `[mutating]` to control the implicit `this` parameter of a method), or a mode that is inferred from other context related to the declaration (e.g., the implicit `this` parameter of a `set` accessor defaults to `inout`, but an implicit `this` in the context of a `class` is always `in`). * Next there is the "actual" parameter passing mode, which is based on the nominal mode, but may be adjusted based on the type of the parameter. For example, a parameter that is declared as `in` but that has a non-copyable type will be treated as `borrow in`. There's an additional wrinkled related to how we handle entry-point varying input parameters (another case where a parameter declared as `in` may be translated to use `borrow in`). One of my goals was to try to bottleneck the logic for both explicitly-declared parameters and the implicit `this` parameter through the same code path, at least for computing the actual parameter-passing mode from the nominal mode and parameter type. I don't think the logic is as clean as it can/should be yet, but this is an incremental step in a better direction. The most important concern I have about this change is whether it will break any existing code, because I have changed cases where a parameter of non-copyable type would previously have been modified at the AST level to use `ref` over to use `inout` or `out` instead. It is reasonable to be concerned that this change could result in IR code attempting to copy values of non-copyable type, which could break code that uses non-copyable types like `RayQuery` and `HitObject`. I note that the compiler already allows parameters of non-copyable type as `borrow in` parameters (which allow implicit copies to be introduced, just as for `inout` or `out`), so if the handling of `out` and `inout` for non-copyable types is broken, it would strongly imply that `borrow in` parameters of non-copyable types are already broken. If it turns out that there is no reasonable way to stop passes on Slang IR from introducing copies of non-copyable types when they use `out` or `inout` (or even `borrow in`), then it seems like we will need to introduce even more fine-grained parameter passing modes, to reflect the difference between what we might call "value" borrows and "memory" borrows, where the former allow copies to be introduced and the latter don't. I really hope I don't have to go down that road, though, since it is already challenging to explain to folks what all the existing parameter-passing modes mean. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> Co-authored-by: James Helferty (NVIDIA) <jhelferty@nvidia.com>
1 parent 3838b39 commit 664fecd

20 files changed

+2158
-847
lines changed

source/slang/slang-ast-type.cpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -610,15 +610,15 @@ ParamPassingMode getParamPassingModeFromPossiblyWrappedParamType(Type* paramType
610610
}
611611
}
612612

613-
ParamPassingMode FuncType::getParamDirection(Index index)
613+
ParamPassingMode FuncType::getParamPassingMode(Index index)
614614
{
615-
auto paramType = getParamTypeWithDirectionWrapper(index);
615+
auto paramType = getParamTypeWithModeWrapper(index);
616616
return getParamPassingModeFromPossiblyWrappedParamType(paramType);
617617
}
618618

619619
Type* FuncType::getParamValueType(Index index)
620620
{
621-
auto paramType = getParamTypeWithDirectionWrapper(index);
621+
auto paramType = getParamTypeWithModeWrapper(index);
622622
if (auto wrappedParamType = as<ParamPassingModeType>(paramType))
623623
return wrappedParamType->getValueType();
624624
return paramType;
@@ -635,7 +635,7 @@ void FuncType::_toTextOverride(StringBuilder& out)
635635
{
636636
out << toSlice(", ");
637637
}
638-
out << getParamTypeWithDirectionWrapper(pp);
638+
out << getParamTypeWithModeWrapper(pp);
639639
}
640640
out << ") -> " << getResultType();
641641

@@ -659,8 +659,8 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s
659659
List<Type*> substParamTypes;
660660
for (Index pp = 0; pp < getParamCount(); pp++)
661661
{
662-
auto substParamType = as<Type>(
663-
getParamTypeWithDirectionWrapper(pp)->substituteImpl(astBuilder, subst, &diff));
662+
auto substParamType =
663+
as<Type>(getParamTypeWithModeWrapper(pp)->substituteImpl(astBuilder, subst, &diff));
664664
if (auto typePack = as<ConcreteTypePack>(substParamType))
665665
{
666666
// Unwrap the ConcreteTypePack and add each element as a parameter
@@ -695,7 +695,7 @@ Type* FuncType::_createCanonicalTypeOverride()
695695
List<Type*> canParamTypes;
696696
for (Index pp = 0; pp < getParamCount(); pp++)
697697
{
698-
canParamTypes.add(getParamTypeWithDirectionWrapper(pp)->getCanonicalType());
698+
canParamTypes.add(getParamTypeWithModeWrapper(pp)->getCanonicalType());
699699
}
700700

701701
FuncType* canType = getCurrentASTBuilder()->getFuncType(
@@ -1406,14 +1406,9 @@ Val* TextureTypeBase::getFormat()
14061406
return as<Type>(_getGenericTypeArg(this, 8));
14071407
}
14081408

1409-
Type* removeParamDirType(Type* type)
1409+
bool isCopyableType(Type* type)
14101410
{
1411-
for (auto paramDirType = as<ParamPassingModeType>(type); paramDirType;)
1412-
{
1413-
type = paramDirType->getValueType();
1414-
paramDirType = as<ParamPassingModeType>(type);
1415-
}
1416-
return type;
1411+
return !isNonCopyableType(type);
14171412
}
14181413

14191414
bool isNonCopyableType(Type* type)

source/slang/slang-ast-type.h

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,23 @@ class NamedExpressionType : public Type
814814
NamedExpressionType(DeclRef<TypeDefDecl> inDeclRef) { setOperands(inDeclRef); }
815815
};
816816

817+
/// Adjust a parameter-passing mode to account for the type of a parameter.
818+
///
819+
/// The `originalMode` should be the mode that would be used by default;
820+
/// usually this is a mode returned by `getExplicitlyDeclaredParamPassingMode()`
821+
/// or something similar.
822+
///
823+
/// The `paramType` should be the declared type of the parameter, not including
824+
/// any of the wrapper types that are used to represent parameter-passing modes.
825+
///
826+
/// This function is primarily concerned with adjusting a parameter-passing
827+
/// mode to account for non-copyable types, which may need different defaults
828+
/// than a copyable type.
829+
///
830+
ParamPassingMode adjustParamPassingModeBasedOnParamType(
831+
ParamPassingMode originalMode,
832+
Type* paramType);
833+
817834
// A function type is defined by its parameter types
818835
// and its result type.
819836
FIDDLE()
@@ -851,7 +868,7 @@ class FuncType : public Type
851868
/// the possibility of encountering these wrappers, and handle
852869
/// them accordingly.
853870
///
854-
Type* getParamTypeWithDirectionWrapper(Index index) { return as<Type>(getOperand(index)); }
871+
Type* getParamTypeWithModeWrapper(Index index) { return as<Type>(getOperand(index)); }
855872

856873
/// Get the type of one of the function's parameters, by index.
857874
///
@@ -872,14 +889,14 @@ class FuncType : public Type
872889

873890
/// Get the parameter-passing mode of one of the function's parameters, by index.
874891
///
875-
ParamPassingMode getParamDirection(Index index);
892+
ParamPassingMode getParamPassingMode(Index index);
876893

877894
/// Combined information on the type and parameter-passing mode of a parameter.
878895
///
879896
struct ParamInfo
880897
{
881-
/// The parameter-passing mode used for the parameter.
882-
ParamPassingMode direction = ParamPassingMode::In;
898+
/// The parameter-passing mode for the parameter.
899+
ParamPassingMode mode = ParamPassingMode::In;
883900

884901
/// The user-perceived type of the parameter.
885902
Type* type = nullptr;
@@ -890,7 +907,7 @@ class FuncType : public Type
890907
ParamInfo getParamInfo(Index index)
891908
{
892909
ParamInfo info;
893-
info.direction = getParamDirection(index);
910+
info.mode = getParamPassingMode(index);
894911
info.type = getParamValueType(index);
895912
return info;
896913
}
@@ -1162,7 +1179,7 @@ class ModifiedType : public Type
11621179
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
11631180
};
11641181

1165-
Type* removeParamDirType(Type* type);
1182+
bool isCopyableType(Type* type);
11661183
bool isNonCopyableType(Type* type);
11671184

11681185
} // namespace Slang

source/slang/slang-check-constraint.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,8 +1110,8 @@ bool SemanticsVisitor::TryUnifyTypesByStructuralMatch(
11101110
if (!TryUnifyTypes(
11111111
constraints,
11121112
unifyCtx,
1113-
fstFunType->getParamTypeWithDirectionWrapper(i),
1114-
sndFunType->getParamTypeWithDirectionWrapper(i)))
1113+
fstFunType->getParamTypeWithModeWrapper(i),
1114+
sndFunType->getParamTypeWithModeWrapper(i)))
11151115
return false;
11161116
}
11171117
return TryUnifyTypes(

source/slang/slang-check-conversion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2103,8 +2103,8 @@ bool SemanticsVisitor::tryCoerceLambdaToFuncType(
21032103
Index paramId = 0;
21042104
for (auto param : invokeFunc->getParameters())
21052105
{
2106-
auto paramType = getParamTypeWithDirectionWrapper(m_astBuilder, param);
2107-
auto toParamType = toFuncType->getParamTypeWithDirectionWrapper(paramId);
2106+
auto paramType = getParamTypeWithModeWrapper(m_astBuilder, param);
2107+
auto toParamType = toFuncType->getParamTypeWithModeWrapper(paramId);
21082108
if (!paramType->equals(toParamType))
21092109
{
21102110
return false;

source/slang/slang-check-decl.cpp

Lines changed: 26 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4209,8 +4209,8 @@ bool SemanticsVisitor::doesSignatureMatchRequirement(
42094209
{
42104210
auto requiredParam = requiredParams[paramIndex];
42114211
auto satisfyingParam = satisfyingParams[paramIndex];
4212-
if (getParameterDirection(requiredParam.getDecl()) !=
4213-
getParameterDirection(satisfyingParam.getDecl()))
4212+
if (getParamPassingMode(requiredParam.getDecl()) !=
4213+
getParamPassingMode(satisfyingParam.getDecl()))
42144214
return false;
42154215
auto requiredParamType = getType(m_astBuilder, requiredParam);
42164216
auto satisfyingParamType = getType(m_astBuilder, satisfyingParam);
@@ -5697,16 +5697,16 @@ bool SemanticsVisitor::trySynthesizeMethodRequirementWitness(
56975697
auto synParam = *synParamIter;
56985698
auto calleeParam = *calleeParamIter;
56995699
if (!matchParamDirection(
5700-
getParameterDirection(calleeParam),
5701-
getParameterDirection(synParam)))
5700+
getParamPassingMode(calleeParam),
5701+
getParamPassingMode(synParam)))
57025702
{
57035703
if (outFailureDetails)
57045704
{
57055705
outFailureDetails->reason =
57065706
WitnessSynthesisFailureReason::ParameterDirMismatch;
57075707
outFailureDetails->candidateMethod = declRefExpr->declRef;
5708-
outFailureDetails->actualDir = getParameterDirection(calleeParam);
5709-
outFailureDetails->expectedDir = getParameterDirection(synParam);
5708+
outFailureDetails->actualDir = getParamPassingMode(calleeParam);
5709+
outFailureDetails->expectedDir = getParamPassingMode(synParam);
57105710
outFailureDetails->paramDecl = calleeParam;
57115711
}
57125712
return false;
@@ -9553,50 +9553,7 @@ void SemanticsDeclHeaderVisitor::visitParamDecl(ParamDecl* paramDecl)
95539553
checkMeshOutputDecl(paramDecl);
95549554
}
95559555

9556-
if (auto declRefType = as<DeclRefType>(paramDecl->type.type))
9557-
{
9558-
if (declRefType->getDeclRef().getDecl()->findModifier<NonCopyableTypeAttribute>())
9559-
{
9560-
// Always pass a non-copyable type by reference.
9561-
// Remove all existing direction modifiers, and replace them with a single Ref modifier.
9562-
List<Modifier*> newModifiers;
9563-
bool hasRefModifier = false;
9564-
bool isMutable = false;
9565-
for (auto modifier : paramDecl->modifiers)
9566-
{
9567-
if (as<InModifier>(modifier))
9568-
{
9569-
continue;
9570-
}
9571-
else if (as<InOutModifier>(modifier) || as<OutModifier>(modifier))
9572-
{
9573-
isMutable = true;
9574-
continue;
9575-
}
9576-
if (as<RefModifier>(modifier) || as<BorrowModifier>(modifier))
9577-
{
9578-
hasRefModifier = true;
9579-
}
9580-
newModifiers.add(modifier);
9581-
}
9582-
if (!hasRefModifier)
9583-
{
9584-
if (isMutable)
9585-
newModifiers.add(this->getASTBuilder()->create<RefModifier>());
9586-
else
9587-
newModifiers.add(this->getASTBuilder()->create<BorrowModifier>());
9588-
}
9589-
paramDecl->modifiers.first = newModifiers.getFirst();
9590-
for (Index i = 0; i < newModifiers.getCount(); i++)
9591-
{
9592-
if (i < newModifiers.getCount() - 1)
9593-
newModifiers[i]->next = newModifiers[i + 1];
9594-
else
9595-
newModifiers[i]->next = nullptr;
9596-
}
9597-
}
9598-
}
9599-
else if (isTypePack(paramDecl->type.type))
9556+
if (isTypePack(paramDecl->type.type))
96009557
{
96019558
// For now, we only allow parameter packs to be `const`.
96029559
bool hasConstModifier = false;
@@ -9605,13 +9562,26 @@ void SemanticsDeclHeaderVisitor::visitParamDecl(ParamDecl* paramDecl)
96059562
if (as<OutModifier>(modifier) || as<InOutModifier>(modifier) ||
96069563
as<RefModifier>(modifier) || as<BorrowModifier>(modifier))
96079564
{
9565+
// TODO(tfoley): The diagnostic in this case should probably not refer
9566+
// to the `const` modifier at all (since that is not actually what is
9567+
// required), and should instead note that a parameter pack may only
9568+
// be declared as a pure input parameter to a function (`in` or
9569+
// `borrow`).
9570+
//
96089571
getSink()->diagnose(modifier, Diagnostics::parameterPackMustBeConst);
96099572
}
96109573
else if (as<ConstModifier>(modifier))
96119574
{
96129575
hasConstModifier = true;
96139576
}
96149577
}
9578+
9579+
// TODO(tfoley): Rather than actually changing the modifiers
9580+
// on the parameter itself, this kind of logic should probably
9581+
// be folded into whatever logic computes the `QualType` for a
9582+
// parameter, and ensure that parameter packs are never treated
9583+
// as l-values.
9584+
//
96159585
if (!hasConstModifier)
96169586
{
96179587
auto constModifier = this->getASTBuilder()->create<ConstModifier>();
@@ -10093,7 +10063,7 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl(
1009310063
{
1009410064
auto paramInfo = funcType->getParamInfo(i);
1009510065
auto paramType = paramInfo.type;
10096-
auto paramDir = paramInfo.direction;
10066+
auto paramDir = paramInfo.mode;
1009710067

1009810068
auto param = m_astBuilder->create<ParamDecl>();
1009910069
param->type.type = paramType;
@@ -12720,14 +12690,14 @@ void checkDerivativeAttributeImpl(
1272012690
//
1272112691
if (resolvedInvoke->arguments[ii]->type.type->equals(
1272212692
ctx.getASTBuilder()->getErrorType()) ||
12723-
funcType->getParamDirection(ii) != paramDirections[ii])
12693+
funcType->getParamPassingMode(ii) != paramDirections[ii])
1272412694
{
1272512695
visitor->getSink()->diagnose(
1272612696
attr,
1272712697
Diagnostics::customDerivativeSignatureMismatchAtPosition,
1272812698
ii,
1272912699
qualTypeToString(argList[ii]->type),
12730-
funcType->getParamTypeWithDirectionWrapper(ii)->toString());
12700+
funcType->getParamTypeWithModeWrapper(ii)->toString());
1273112701
}
1273212702
}
1273312703
// The `imaginaryArguments` list does not include the `this` parameter.
@@ -12882,7 +12852,7 @@ ArgsWithDirectionInfo getImaginaryArgsToFunc(
1288212852
arg->type.type = param->getType();
1288312853
arg->loc = loc;
1288412854
imaginaryArguments.add(arg);
12885-
directions.add(getParameterDirection(param));
12855+
directions.add(getParamPassingMode(param));
1288612856
}
1288712857
return {imaginaryArguments, directions, nullptr, ParamPassingMode::In};
1288812858
}
@@ -12938,7 +12908,7 @@ ArgsWithDirectionInfo getImaginaryArgsToForwardDerivative(
1293812908
List<ParamPassingMode> expectedParamDirections;
1293912909
for (auto param : originalFuncDecl->getParameters())
1294012910
{
12941-
expectedParamDirections.add(getParameterDirection(param));
12911+
expectedParamDirections.add(getParamPassingMode(param));
1294212912
}
1294312913

1294412914
return {imaginaryArguments, expectedParamDirections, thisArgExpr, thisTypeDirection};
@@ -12995,7 +12965,7 @@ ArgsWithDirectionInfo getImaginaryArgsToBackwardDerivative(
1299512965
arg->type.type = param->getType();
1299612966
arg->loc = loc;
1299712967

12998-
ParamPassingMode direction = getParameterDirection(param);
12968+
ParamPassingMode direction = getParamPassingMode(param);
1299912969

1300012970
bool isDiffParam = (!param->findModifier<NoDiffModifier>());
1300112971
if (isDiffParam)

source/slang/slang-check-expr.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2970,7 +2970,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr)
29702970
Index paramCount = funcType->getParamCount();
29712971
for (Index pp = 0; pp < paramCount; ++pp)
29722972
{
2973-
auto paramType = funcType->getParamTypeWithDirectionWrapper(pp);
2973+
auto paramType = funcType->getParamTypeWithModeWrapper(pp);
29742974
Expr* argExpr = nullptr;
29752975
ParamDecl* paramDecl = nullptr;
29762976
if (pp < invoke->arguments.getCount())
@@ -3727,11 +3727,10 @@ Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType)
37273727
for (Index i = 0; i < originalType->getParamCount(); i++)
37283728
{
37293729
if (auto jvpParamType =
3730-
_toDifferentialParamType(originalType->getParamTypeWithDirectionWrapper(i)))
3730+
_toDifferentialParamType(originalType->getParamTypeWithModeWrapper(i)))
37313731
paramTypes.add(jvpParamType);
37323732
}
3733-
FuncType* jvpType =
3734-
m_astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType);
3733+
FuncType* jvpType = m_astBuilder->getFuncType(paramTypes.getArrayView(), resultType, errorType);
37353734

37363735
return jvpType;
37373736
}
@@ -3753,7 +3752,7 @@ Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType)
37533752

37543753
for (Index i = 0; i < originalType->getParamCount(); i++)
37553754
{
3756-
auto originalParamType = originalType->getParamTypeWithDirectionWrapper(i);
3755+
auto originalParamType = originalType->getParamTypeWithModeWrapper(i);
37573756

37583757
if (auto outType = as<OutType>(originalParamType))
37593758
{
@@ -3792,7 +3791,7 @@ Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType)
37923791
if (dOutType)
37933792
paramTypes.add(dOutType);
37943793

3795-
return m_astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType);
3794+
return m_astBuilder->getFuncType(paramTypes.getArrayView(), resultType, errorType);
37963795
}
37973796

37983797
struct HigherOrderInvokeExprCheckingActions
@@ -4837,7 +4836,7 @@ Expr* SemanticsExprVisitor::visitLambdaExpr(LambdaExpr* lambdaExpr)
48374836
genApp->arguments.add(returnTypeExp);
48384837
for (auto param : getMembersOfType<ParamDecl>(m_astBuilder, lambdaExpr->paramScopeDecl))
48394838
{
4840-
auto paramType = getParamTypeWithDirectionWrapper(m_astBuilder, param);
4839+
auto paramType = getParamTypeWithModeWrapper(m_astBuilder, param);
48414840
auto paramTypeExp = synthesizer.emitStaticTypeExpr(paramType);
48424841
genApp->arguments.add(paramTypeExp);
48434842
}

0 commit comments

Comments
 (0)