@@ -371,7 +371,9 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
371
371
assert (!seai->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>() &&
372
372
" `@noDerivative` struct projections should never be active" );
373
373
auto adjSource = getAdjointBuffer (origBB, seai->getOperand ());
374
- auto *tanField = getTangentStoredProperty (getContext (), seai, getInvoker ());
374
+ auto structType = remapType (seai->getOperand ()->getType ()).getASTType ();
375
+ auto *tanField =
376
+ getTangentStoredProperty (getContext (), seai, structType, getInvoker ());
375
377
assert (tanField && " Invalid projections should have been diagnosed" );
376
378
return builder.createStructElementAddr (seai->getLoc (), adjSource, tanField);
377
379
}
@@ -400,7 +402,10 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
400
402
auto loc = reai->getLoc ();
401
403
// Get the class operand, stripping `begin_borrow`.
402
404
auto classOperand = stripBorrow (reai->getOperand ());
403
- auto *tanField = getTangentStoredProperty (getContext (), reai, getInvoker ());
405
+ auto classType = remapType (reai->getOperand ()->getType ()).getASTType ();
406
+ auto *tanField =
407
+ getTangentStoredProperty (getContext (), reai->getField (), classType,
408
+ reai->getLoc (), getInvoker ());
404
409
assert (tanField && " Invalid projections should have been diagnosed" );
405
410
// Create a local allocation for the element adjoint buffer.
406
411
auto eltTanType = tanField->getValueInterfaceType ()->getCanonicalType ();
@@ -666,8 +671,9 @@ bool PullbackEmitter::runForSemanticMemberGetter() {
666
671
667
672
// Look up the corresponding field in the tangent space.
668
673
auto *origField = cast<VarDecl>(accessor->getStorage ());
669
- auto *tanField =
670
- getTangentStoredProperty (getContext (), origField, pbLoc, getInvoker ());
674
+ auto baseType = remapType (origSelf->getType ()).getASTType ();
675
+ auto *tanField = getTangentStoredProperty (getContext (), origField, baseType,
676
+ pbLoc, getInvoker ());
671
677
if (!tanField) {
672
678
errorOccurred = true ;
673
679
return true ;
@@ -772,8 +778,9 @@ bool PullbackEmitter::runForSemanticMemberSetter() {
772
778
773
779
// Look up the corresponding field in the tangent space.
774
780
auto *origField = cast<VarDecl>(accessor->getStorage ());
775
- auto *tanField =
776
- getTangentStoredProperty (getContext (), origField, pbLoc, getInvoker ());
781
+ auto baseType = remapType (origSelf->getType ()).getASTType ();
782
+ auto *tanField = getTangentStoredProperty (getContext (), origField, baseType,
783
+ pbLoc, getInvoker ());
777
784
if (!tanField) {
778
785
errorOccurred = true ;
779
786
return true ;
@@ -882,7 +889,10 @@ bool PullbackEmitter::run() {
882
889
}
883
890
// Diagnose unsupported stored property projections.
884
891
if (auto *inst = dyn_cast<FieldIndexCacheBase>(v)) {
885
- if (!getTangentStoredProperty (getContext (), inst, getInvoker ())) {
892
+ assert (inst->getNumOperands () == 1 );
893
+ auto baseType = remapType (inst->getOperand (0 )->getType ()).getASTType ();
894
+ if (!getTangentStoredProperty (getContext (), inst, baseType,
895
+ getInvoker ())) {
886
896
errorOccurred = true ;
887
897
return true ;
888
898
}
@@ -1699,8 +1709,8 @@ void PullbackEmitter::visitStructInst(StructInst *si) {
1699
1709
if (field->getAttrs ().hasAttribute <NoDerivativeAttr>())
1700
1710
continue ;
1701
1711
// Find the corresponding field in the tangent space.
1702
- auto *tanField =
1703
- getTangentStoredProperty ( getContext (), field, loc, getInvoker ());
1712
+ auto *tanField = getTangentStoredProperty ( getContext (), field, structTy,
1713
+ loc, getInvoker ());
1704
1714
if (!tanField) {
1705
1715
errorOccurred = true ;
1706
1716
return ;
@@ -1732,21 +1742,23 @@ void PullbackEmitter::visitBeginApplyInst(BeginApplyInst *bai) {
1732
1742
1733
1743
void PullbackEmitter::visitStructExtractInst (StructExtractInst *sei) {
1734
1744
auto *bb = sei->getParent ();
1745
+ auto loc = getValidLocation (sei);
1735
1746
auto structTy = remapType (sei->getOperand ()->getType ()).getASTType ();
1736
1747
auto tangentVectorTy = getTangentSpace (structTy)->getCanonicalType ();
1737
1748
assert (!getTypeLowering (tangentVectorTy).isAddressOnly ());
1738
1749
auto tangentVectorSILTy = SILType::getPrimitiveObjectType (tangentVectorTy);
1739
1750
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct ();
1740
1751
assert (tangentVectorDecl);
1741
1752
// Find the corresponding field in the tangent space.
1742
- auto *tanField = getTangentStoredProperty (getContext (), sei, getInvoker ());
1753
+ auto *tanField =
1754
+ getTangentStoredProperty (getContext (), sei, structTy, getInvoker ());
1743
1755
assert (tanField && " Invalid projections should have been diagnosed" );
1744
1756
// Accumulate adjoint for the `struct_extract` operand.
1745
1757
auto av = getAdjointValue (bb, sei);
1746
1758
switch (av.getKind ()) {
1747
1759
case AdjointValueKind::Zero:
1748
1760
addAdjointValue (bb, sei->getOperand (),
1749
- makeZeroAdjointValue (tangentVectorSILTy), sei-> getLoc () );
1761
+ makeZeroAdjointValue (tangentVectorSILTy), loc );
1750
1762
break ;
1751
1763
case AdjointValueKind::Concrete:
1752
1764
case AdjointValueKind::Aggregate: {
@@ -1765,7 +1777,7 @@ void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) {
1765
1777
}
1766
1778
addAdjointValue (bb, sei->getOperand (),
1767
1779
makeAggregateAdjointValue (tangentVectorSILTy, eltVals),
1768
- sei-> getLoc () );
1780
+ loc );
1769
1781
}
1770
1782
}
1771
1783
}
@@ -1775,7 +1787,9 @@ void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) {
1775
1787
auto loc = reai->getLoc ();
1776
1788
auto adjBuf = getAdjointBuffer (bb, reai);
1777
1789
auto classOperand = reai->getOperand ();
1778
- auto *tanField = getTangentStoredProperty (getContext (), reai, getInvoker ());
1790
+ auto classType = remapType (reai->getOperand ()->getType ()).getASTType ();
1791
+ auto *tanField =
1792
+ getTangentStoredProperty (getContext (), reai, classType, getInvoker ());
1779
1793
assert (tanField && " Invalid projections should have been diagnosed" );
1780
1794
switch (getTangentValueCategory (classOperand)) {
1781
1795
case SILValueCategory::Object: {
0 commit comments