Skip to content

Commit 05bf9da

Browse files
authored
Merge pull request #1442 from swiftwasm/master
[pull] swiftwasm from master
2 parents 91a93c9 + bf47403 commit 05bf9da

File tree

14 files changed

+355
-23
lines changed

14 files changed

+355
-23
lines changed

include/swift/AST/Attr.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,13 +1715,6 @@ class OriginallyDefinedInAttr: public DeclAttribute {
17151715
}
17161716
};
17171717

1718-
/// A declaration name with location.
1719-
struct DeclNameRefWithLoc {
1720-
DeclNameRef Name;
1721-
DeclNameLoc Loc;
1722-
Optional<AccessorKind> AccessorKind;
1723-
};
1724-
17251718
/// Attribute that marks a function as differentiable.
17261719
///
17271720
/// Examples:
@@ -1847,6 +1840,18 @@ class DifferentiableAttr final
18471840
}
18481841
};
18491842

1843+
/// A declaration name with location.
1844+
struct DeclNameRefWithLoc {
1845+
/// The declaration name.
1846+
DeclNameRef Name;
1847+
/// The declaration name location.
1848+
DeclNameLoc Loc;
1849+
/// An optional accessor kind.
1850+
Optional<AccessorKind> AccessorKind;
1851+
1852+
void print(ASTPrinter &Printer) const;
1853+
};
1854+
18501855
/// The `@derivative(of:)` attribute registers a function as a derivative of
18511856
/// another function-like declaration: a 'func', 'init', 'subscript', or 'var'
18521857
/// computed property declaration.

lib/AST/Attr.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,9 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
10521052
Printer.printAttrName("@derivative");
10531053
Printer << "(of: ";
10541054
auto *attr = cast<DerivativeAttr>(this);
1055-
Printer << attr->getOriginalFunctionName().Name;
1055+
if (auto *baseType = attr->getBaseTypeRepr())
1056+
baseType->print(Printer, Options);
1057+
attr->getOriginalFunctionName().print(Printer);
10561058
auto *derivative = cast<AbstractFunctionDecl>(D);
10571059
auto diffParamsString = getDifferentiationParametersClauseString(
10581060
derivative, attr->getParameterIndices(), attr->getParsedParameters(),
@@ -1067,7 +1069,9 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
10671069
Printer.printAttrName("@transpose");
10681070
Printer << "(of: ";
10691071
auto *attr = cast<TransposeAttr>(this);
1070-
Printer << attr->getOriginalFunctionName().Name;
1072+
if (auto *baseType = attr->getBaseTypeRepr())
1073+
baseType->print(Printer, Options);
1074+
attr->getOriginalFunctionName().print(Printer);
10711075
auto *transpose = cast<AbstractFunctionDecl>(D);
10721076
auto transParamsString = getDifferentiationParametersClauseString(
10731077
transpose, attr->getParameterIndices(), attr->getParsedParameters(),
@@ -1719,6 +1723,12 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment(
17191723
return original->getGenericEnvironment();
17201724
}
17211725

1726+
void DeclNameRefWithLoc::print(ASTPrinter &Printer) const {
1727+
Printer << Name;
1728+
if (AccessorKind)
1729+
Printer << '.' << getAccessorLabel(*AccessorKind);
1730+
}
1731+
17221732
void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
17231733
bool omitWrtClause) const {
17241734
StreamPrinter P(OS);

lib/AST/AutoDiff.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,14 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
422422
}
423423
}
424424

425+
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
426+
const DeclNameRefWithLoc &name) {
427+
os << name.Name;
428+
if (auto accessorKind = name.AccessorKind)
429+
os << '.' << getAccessorLabel(*accessorKind);
430+
return os;
431+
}
432+
425433
bool swift::operator==(const TangentPropertyInfo::Error &lhs,
426434
const TangentPropertyInfo::Error &rhs) {
427435
if (lhs.kind != rhs.kind)

lib/SILOptimizer/Analysis/EscapeAnalysis.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2079,6 +2079,10 @@ void EscapeAnalysis::analyzeInstruction(SILInstruction *I,
20792079
return;
20802080
}
20812081

2082+
// Incidental uses produce no values and have no effect on their operands.
2083+
if (isIncidentalUse(I))
2084+
return;
2085+
20822086
// Instructions which return the address of non-writable memory cannot have
20832087
// an effect on escaping.
20842088
if (isNonWritableMemoryAddress(I))

lib/SILOptimizer/PassManager/PassPipeline.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,10 +457,6 @@ static void addPerfEarlyModulePassPipeline(SILPassPipelinePlan &P) {
457457
// optimization.
458458
P.addGlobalOpt();
459459

460-
// We earlier eliminated ownership if we are not compiling the stdlib. Now
461-
// handle the stdlib functions.
462-
P.addNonTransparentFunctionOwnershipModelEliminator();
463-
464460
// Add the outliner pass (Osize).
465461
P.addOutliner();
466462

@@ -485,6 +481,11 @@ static void addHighLevelFunctionPipeline(SILPassPipelinePlan &P) {
485481
P.startPipeline("HighLevel,Function+EarlyLoopOpt");
486482
// FIXME: update EagerSpecializer to be a function pass!
487483
P.addEagerSpecializer();
484+
485+
// We earlier eliminated ownership if we are not compiling the stdlib. Now
486+
// handle the stdlib functions.
487+
P.addNonTransparentFunctionOwnershipModelEliminator();
488+
488489
addFunctionPasses(P, OptimizationLevelKind::HighLevel);
489490

490491
addHighLevelLoopOptPasses(P);
@@ -714,6 +715,8 @@ SILPassPipelinePlan::getPerformancePassPipeline(const SILOptions &Options) {
714715
//
715716
// FIXME: When *not* emitting a .swiftmodule, skip the high-level function
716717
// pipeline to save compile time.
718+
//
719+
// NOTE: Ownership is now stripped within this function!
717720
addHighLevelFunctionPipeline(P);
718721

719722
addHighLevelModulePipeline(P);

lib/Serialization/Deserialization.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4371,16 +4371,26 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
43714371
case decls_block::Derivative_DECL_ATTR: {
43724372
bool isImplicit;
43734373
uint64_t origNameId;
4374+
bool hasAccessorKind;
4375+
uint64_t rawAccessorKind;
43744376
DeclID origDeclId;
43754377
uint64_t rawDerivativeKind;
43764378
ArrayRef<uint64_t> parameters;
43774379

43784380
serialization::decls_block::DerivativeDeclAttrLayout::readRecord(
4379-
scratch, isImplicit, origNameId, origDeclId, rawDerivativeKind,
4380-
parameters);
4381+
scratch, isImplicit, origNameId, hasAccessorKind, rawAccessorKind,
4382+
origDeclId, rawDerivativeKind, parameters);
4383+
4384+
Optional<AccessorKind> accessorKind = None;
4385+
if (hasAccessorKind) {
4386+
auto maybeAccessorKind = getActualAccessorKind(rawAccessorKind);
4387+
if (!maybeAccessorKind)
4388+
MF.fatal();
4389+
accessorKind = *maybeAccessorKind;
4390+
}
43814391

43824392
DeclNameRefWithLoc origName{DeclNameRef(MF.getDeclBaseName(origNameId)),
4383-
DeclNameLoc(), None};
4393+
DeclNameLoc(), accessorKind};
43844394
auto derivativeKind =
43854395
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
43864396
if (!derivativeKind)

lib/Serialization/ModuleFormat.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5555
/// describe what change you made. The content of this comment isn't important;
5656
/// it just ensures a conflict if two people change the module format.
5757
/// Don't worry about adhering to the 80-column limit for this line.
58-
const uint16_t SWIFTMODULE_VERSION_MINOR = 563; // unchecked_value_cast
58+
const uint16_t SWIFTMODULE_VERSION_MINOR = 564; // `@derivative` attribute accessor kind
5959

6060
/// A standard hash seed used for all string hashes in a serialized module.
6161
///
@@ -1848,6 +1848,8 @@ namespace decls_block {
18481848
Derivative_DECL_ATTR,
18491849
BCFixed<1>, // Implicit flag.
18501850
IdentifierIDField, // Original name.
1851+
BCFixed<1>, // Has original accessor kind?
1852+
AccessorKindField, // Original accessor kind.
18511853
DeclIDField, // Original function declaration.
18521854
AutoDiffDerivativeFunctionKindField, // Derivative function kind.
18531855
BCArray<BCFixed<1>> // Differentiation parameter indices' bitvector.

lib/Serialization/Serialization.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2431,19 +2431,25 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
24312431
assert(attr->getOriginalFunction(ctx) &&
24322432
"`@derivative` attribute should have original declaration set "
24332433
"during construction or parsing");
2434-
auto origName = attr->getOriginalFunctionName().Name.getBaseName();
2434+
auto origDeclNameRef = attr->getOriginalFunctionName();
2435+
auto origName = origDeclNameRef.Name.getBaseName();
24352436
IdentifierID origNameId = S.addDeclBaseNameRef(origName);
24362437
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction(ctx));
24372438
auto derivativeKind =
24382439
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
2440+
uint8_t rawAccessorKind = 0;
2441+
auto origAccessorKind = origDeclNameRef.AccessorKind;
2442+
if (origAccessorKind)
2443+
rawAccessorKind = uint8_t(getStableAccessorKind(*origAccessorKind));
24392444
auto *parameterIndices = attr->getParameterIndices();
24402445
assert(parameterIndices && "Parameter indices must be resolved");
24412446
SmallVector<bool, 4> paramIndicesVector;
24422447
for (unsigned i : range(parameterIndices->getCapacity()))
24432448
paramIndicesVector.push_back(parameterIndices->contains(i));
24442449
DerivativeDeclAttrLayout::emitRecord(
24452450
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
2446-
origDeclID, derivativeKind, paramIndicesVector);
2451+
origAccessorKind.hasValue(), rawAccessorKind, origDeclID,
2452+
derivativeKind, paramIndicesVector);
24472453
return;
24482454
}
24492455

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ where Element: Differentiable {
4242
return (base, { $0 })
4343
}
4444

45+
@usableFromInline
46+
@derivative(of: base)
47+
func _jvpBase() -> (
48+
value: [Element], differential: (Array<Element>.TangentVector) -> TangentVector
49+
) {
50+
return (base, { $0 })
51+
}
52+
4553
/// Creates a differentiable view of the given array.
4654
public init(_ base: [Element]) { self._base = base }
4755

@@ -53,6 +61,14 @@ where Element: Differentiable {
5361
return (Array.DifferentiableView(base), { $0 })
5462
}
5563

64+
@usableFromInline
65+
@derivative(of: init(_:))
66+
static func _jvpInit(_ base: [Element]) -> (
67+
value: Array.DifferentiableView, differential: (TangentVector) -> TangentVector
68+
) {
69+
return (Array.DifferentiableView(base), { $0 })
70+
}
71+
5672
public typealias TangentVector =
5773
Array<Element.TangentVector>.DifferentiableView
5874

@@ -191,6 +207,17 @@ extension Array where Element: Differentiable {
191207
return (self[index], pullback)
192208
}
193209

210+
@usableFromInline
211+
@derivative(of: subscript)
212+
func _jvpSubscript(index: Int) -> (
213+
value: Element, differential: (TangentVector) -> Element.TangentVector
214+
) {
215+
func differential(_ v: TangentVector) -> Element.TangentVector {
216+
return v[index]
217+
}
218+
return (self[index], differential)
219+
}
220+
194221
@usableFromInline
195222
@derivative(of: +)
196223
static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> (
@@ -210,8 +237,26 @@ extension Array where Element: Differentiable {
210237
}
211238
return (lhs + rhs, pullback)
212239
}
240+
241+
@usableFromInline
242+
@derivative(of: +)
243+
static func _jvpConcatenate(_ lhs: Self, _ rhs: Self) -> (
244+
value: Self,
245+
differential: (TangentVector, TangentVector) -> TangentVector
246+
) {
247+
func differential(_ l: TangentVector, _ r: TangentVector) -> TangentVector {
248+
precondition(
249+
l.base.count == lhs.count && r.base.count == rhs.count, """
250+
Tangent vectors with invalid count; expected to equal the \
251+
operand counts \(lhs.count) and \(rhs.count)
252+
""")
253+
return .init(l.base + r.base)
254+
}
255+
return (lhs + rhs, differential)
256+
}
213257
}
214258

259+
215260
extension Array where Element: Differentiable {
216261
@usableFromInline
217262
@derivative(of: append)
@@ -277,6 +322,17 @@ extension Array where Element: Differentiable {
277322
}
278323
)
279324
}
325+
326+
@usableFromInline
327+
@derivative(of: init(repeating:count:))
328+
static func _jvpInit(repeating repeatedValue: Element, count: Int) -> (
329+
value: Self, differential: (Element.TangentVector) -> TangentVector
330+
) {
331+
(
332+
value: Self(repeating: repeatedValue, count: count),
333+
differential: { v in TangentVector(.init(repeating: v, count: count)) }
334+
)
335+
}
280336
}
281337

282338
//===----------------------------------------------------------------------===//
@@ -312,6 +368,27 @@ extension Array where Element: Differentiable {
312368
}
313369
return (value: values, pullback: pullback)
314370
}
371+
372+
@inlinable
373+
@derivative(of: differentiableMap)
374+
internal func _jvpDifferentiableMap<Result: Differentiable>(
375+
_ body: @differentiable (Element) -> Result
376+
) -> (
377+
value: [Result],
378+
differential: (Array.TangentVector) -> Array<Result>.TangentVector
379+
) {
380+
var values: [Result] = []
381+
var differentials: [(Element.TangentVector) -> Result.TangentVector] = []
382+
for x in self {
383+
let (y, df) = valueWithDifferential(at: x, in: body)
384+
values.append(y)
385+
differentials.append(df)
386+
}
387+
func differential(_ tans: Array.TangentVector) -> Array<Result>.TangentVector {
388+
.init(zip(tans.base, differentials).map { tan, df in df(tan) })
389+
}
390+
return (value: values, differential: differential)
391+
}
315392
}
316393

317394
extension Array where Element: Differentiable {
@@ -361,4 +438,33 @@ extension Array where Element: Differentiable {
361438
}
362439
)
363440
}
441+
442+
@inlinable
443+
@derivative(of: differentiableReduce, wrt: (self, initialResult))
444+
func _jvpDifferentiableReduce<Result: Differentiable>(
445+
_ initialResult: Result,
446+
_ nextPartialResult: @differentiable (Result, Element) -> Result
447+
) -> (value: Result,
448+
differential: (Array.TangentVector, Result.TangentVector)
449+
-> Result.TangentVector) {
450+
var differentials:
451+
[(Result.TangentVector, Element.TangentVector) -> Result.TangentVector]
452+
= []
453+
let count = self.count
454+
differentials.reserveCapacity(count)
455+
var result = initialResult
456+
for element in self {
457+
let (y, df) =
458+
valueWithDifferential(at: result, element, in: nextPartialResult)
459+
result = y
460+
differentials.append(df)
461+
}
462+
return (value: result, differential: { dSelf, dInitial in
463+
var dResult = dInitial
464+
for (dElement, df) in zip(dSelf.base, differentials) {
465+
dResult = df(dResult, dElement)
466+
}
467+
return dResult
468+
})
469+
}
364470
}

0 commit comments

Comments
 (0)