@@ -9,6 +9,144 @@ import DifferentiationUnittest
99
1010var E2EDifferentiablePropertyTests = TestSuite ( " E2EDifferentiableProperty " )
1111
12+ struct TangentSpace : AdditiveArithmetic {
13+ let x , y : Float
14+ }
15+
16+ extension TangentSpace : Differentiable {
17+ typealias TangentVector = TangentSpace
18+ }
19+
20+ struct Space {
21+ /// `x` is a computed property with a custom vjp.
22+ var x : Float {
23+ @differentiable ( reverse)
24+ get { storedX }
25+ set { storedX = newValue }
26+ }
27+
28+ @derivative ( of: x)
29+ func vjpX( ) -> ( value: Float , pullback: ( Float ) -> TangentSpace ) {
30+ return ( x, { v in TangentSpace ( x: v, y: 0 ) } )
31+ }
32+
33+ private var storedX : Float
34+
35+ @differentiable ( reverse)
36+ var y : Float
37+
38+ init ( x: Float , y: Float ) {
39+ self . storedX = x
40+ self . y = y
41+ }
42+ }
43+
44+ extension Space : Differentiable {
45+ typealias TangentVector = TangentSpace
46+ mutating func move( by offset: TangentSpace ) {
47+ x. move ( by: offset. x)
48+ y. move ( by: offset. y)
49+ }
50+ }
51+
52+ E2EDifferentiablePropertyTests . test ( " computed property " ) {
53+ let actualGrad = gradient ( at: Space ( x: 0 , y: 0 ) ) { ( point: Space ) -> Float in
54+ return 2 * point. x
55+ }
56+ let expectedGrad = TangentSpace ( x: 2 , y: 0 )
57+ expectEqual ( expectedGrad, actualGrad)
58+ }
59+
60+ E2EDifferentiablePropertyTests . test ( " stored property " ) {
61+ let actualGrad = gradient ( at: Space ( x: 0 , y: 0 ) ) { ( point: Space ) -> Float in
62+ return 3 * point. y
63+ }
64+ let expectedGrad = TangentSpace ( x: 0 , y: 3 )
65+ expectEqual ( expectedGrad, actualGrad)
66+ }
67+
68+ struct GenericMemberWrapper < T : Differentiable > : Differentiable {
69+ // Stored property.
70+ @differentiable ( reverse)
71+ var x : T
72+
73+ func vjpX( ) -> ( T , ( T . TangentVector ) -> GenericMemberWrapper . TangentVector ) {
74+ return ( x, { TangentVector ( x: $0) } )
75+ }
76+ }
77+
78+ E2EDifferentiablePropertyTests . test ( " generic stored property " ) {
79+ let actualGrad = gradient ( at: GenericMemberWrapper < Float > ( x: 1 ) ) { point in
80+ return 2 * point. x
81+ }
82+ let expectedGrad = GenericMemberWrapper< Float> . TangentVector( x: 2 )
83+ expectEqual ( expectedGrad, actualGrad)
84+ }
85+
86+ struct ProductSpaceSelfTangent : AdditiveArithmetic {
87+ let x , y : Float
88+ }
89+
90+ extension ProductSpaceSelfTangent : Differentiable {
91+ typealias TangentVector = ProductSpaceSelfTangent
92+ }
93+
94+ E2EDifferentiablePropertyTests . test ( " fieldwise product space, self tangent " ) {
95+ let actualGrad = gradient ( at: ProductSpaceSelfTangent ( x: 0 , y: 0 ) ) { ( point: ProductSpaceSelfTangent ) -> Float in
96+ return 5 * point. y
97+ }
98+ let expectedGrad = ProductSpaceSelfTangent ( x: 0 , y: 5 )
99+ expectEqual ( expectedGrad, actualGrad)
100+ }
101+
102+ struct ProductSpaceOtherTangentTangentSpace : AdditiveArithmetic {
103+ let x , y : Float
104+ }
105+
106+ extension ProductSpaceOtherTangentTangentSpace : Differentiable {
107+ typealias TangentVector = ProductSpaceOtherTangentTangentSpace
108+ }
109+
110+ struct ProductSpaceOtherTangent {
111+ var x , y : Float
112+ }
113+
114+ extension ProductSpaceOtherTangent : Differentiable {
115+ typealias TangentVector = ProductSpaceOtherTangentTangentSpace
116+ mutating func move( by offset: ProductSpaceOtherTangentTangentSpace ) {
117+ x. move ( by: offset. x)
118+ y. move ( by: offset. y)
119+ }
120+ }
121+
122+ E2EDifferentiablePropertyTests . test ( " fieldwise product space, other tangent " ) {
123+ let actualGrad = gradient (
124+ at: ProductSpaceOtherTangent ( x: 0 , y: 0 )
125+ ) { ( point: ProductSpaceOtherTangent ) -> Float in
126+ return 7 * point. y
127+ }
128+ let expectedGrad = ProductSpaceOtherTangentTangentSpace ( x: 0 , y: 7 )
129+ expectEqual ( expectedGrad, actualGrad)
130+ }
131+
132+ E2EDifferentiablePropertyTests . test ( " computed property " ) {
133+ struct TF_544 : Differentiable {
134+ var value : Float
135+ @differentiable ( reverse)
136+ var computed : Float {
137+ get { value }
138+ set { value = newValue }
139+ }
140+ }
141+ let actualGrad = gradient ( at: TF_544 ( value: 2.4 ) ) { x in
142+ return x. computed * x. computed
143+ }
144+ let expectedGrad = TF_544 . TangentVector ( value: 4.8 )
145+ expectEqual ( expectedGrad, actualGrad)
146+ }
147+
148+ /* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
149+ We cannot use `Tracked<T>` :(
12150struct TangentSpace : AdditiveArithmetic {
13151 let x, y: Tracked<Float>
14152}
@@ -144,5 +282,6 @@ E2EDifferentiablePropertyTests.testWithLeakChecking("computed property") {
144282 let expectedGrad = TF_544.TangentVector(value: 4.8)
145283 expectEqual(expectedGrad, actualGrad)
146284}
285+ */
147286
148287runAllTests ( )
0 commit comments