Skip to content

Commit 130bdc0

Browse files
authored
Merge pull request #214 from Atry/hlists
Support differentiable HList
2 parents 06ea8e8 + cf69c2e commit 130bdc0

File tree

8 files changed

+149
-59
lines changed

8 files changed

+149
-59
lines changed

DeepLearning/src/main/scala/com/thoughtworks/deeplearning/DeepLearning.scala

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,79 @@
11
package com.thoughtworks.deeplearning
2+
import java.io.{PrintStream, PrintWriter}
3+
24
import com.thoughtworks.deeplearning.DeepLearning.Tape
35
import com.thoughtworks.continuation._
46
import com.thoughtworks.future._
5-
67
import scalaz.syntax.all._
78
import com.thoughtworks.raii.asynchronous._
89
import simulacrum.typeclass
910

1011
import scala.language.implicitConversions
1112
import algebra.ring.MultiplicativeMonoid
13+
import scalaz.Semigroup
1214

1315
object DeepLearning {
1416

17+
implicit object multipleExceptionThrowableSemigroup extends Semigroup[Throwable] {
18+
override def append(f1: Throwable, f2: => Throwable): Throwable =
19+
f1 match {
20+
case me1: AbstractMultipleException =>
21+
f2 match {
22+
case me2: AbstractMultipleException => MultipleException(me1.throwableSet ++ me2.throwableSet)
23+
case e: Throwable => MultipleException(me1.throwableSet + e)
24+
}
25+
case _: Throwable =>
26+
f2 match {
27+
case me2: AbstractMultipleException => MultipleException(me2.throwableSet + f1)
28+
case `f1` => f1
29+
case e: Throwable => MultipleException(Set(f1, e))
30+
}
31+
}
32+
}
33+
34+
private final case class MultipleException(throwableSet: Set[Throwable])
35+
extends DeepLearning.AbstractMultipleException
36+
37+
abstract class AbstractMultipleException extends RuntimeException("Multiple exceptions found") {
38+
39+
def throwableSet: Set[Throwable]
40+
41+
override def toString: String = throwableSet.mkString("\n")
42+
43+
override def printStackTrace(): Unit = {
44+
for (throwable <- throwableSet) {
45+
throwable.printStackTrace()
46+
}
47+
}
48+
49+
override def printStackTrace(s: PrintStream): Unit = {
50+
for (throwable <- throwableSet) {
51+
throwable.printStackTrace(s)
52+
}
53+
}
54+
55+
override def printStackTrace(s: PrintWriter): Unit = {
56+
for (throwable <- throwableSet) {
57+
throwable.printStackTrace(s)
58+
}
59+
}
60+
61+
override def getStackTrace: Array[StackTraceElement] = synchronized {
62+
super.getStackTrace match {
63+
case null =>
64+
setStackTrace(throwableSet.flatMap(_.getStackTrace)(collection.breakOut))
65+
super.getStackTrace
66+
case stackTrace =>
67+
stackTrace
68+
}
69+
}
70+
71+
override def fillInStackTrace(): this.type = {
72+
this
73+
}
74+
75+
}
76+
1577
/** The node of wengert list created during [[DeepLearning.forward forward]] pass */
1678
final case class Tape[+Data, -Delta](data: Data, backward: Do[Delta] => UnitContinuation[Unit])
1779

build.sbt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ lazy val `plugins-Logging` = project.dependsOn(`plugins-Layers`, `plugins-Weight
1616

1717
lazy val `plugins-Operators` = project
1818

19+
lazy val `plugins-HLists` = project.dependsOn(DeepLearning)
20+
1921
lazy val `plugins-FloatTraining` = project.dependsOn(`plugins-Training`)
2022

2123
lazy val `plugins-FloatLiterals` = project.dependsOn(`DeepLearning`)
@@ -114,6 +116,7 @@ lazy val `plugins-CumulativeDoubleLayers` =
114116

115117
lazy val `plugins-Builtins` =
116118
project.dependsOn(
119+
`plugins-HLists`,
117120
`plugins-ImplicitsSingleton`,
118121
`plugins-Layers`,
119122
`plugins-Weights`,

plugins-DoubleWeights/build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ libraryDependencies += "com.thoughtworks.feature" %% "implicitapply" % "2.3.0-M8
44

55
libraryDependencies += "com.thoughtworks.feature" %% "factory" % "2.3.0-M8"
66

7-
libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.2"
7+
libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.3"

plugins-FloatWeights/build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ libraryDependencies += "com.thoughtworks.feature" %% "implicitapply" % "2.3.0-M8
44

55
libraryDependencies += "com.thoughtworks.feature" %% "factory" % "2.3.0-M8"
66

7-
libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.2"
7+
libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.3"
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package com.thoughtworks.deeplearning.plugins
2+
3+
import com.thoughtworks.continuation._
4+
import com.thoughtworks.deeplearning.DeepLearning
5+
import com.thoughtworks.deeplearning.DeepLearning.Tape
6+
import com.thoughtworks.raii.asynchronous._
7+
import scalaz.Applicative
8+
import scalaz.syntax.all._
9+
import scalaz.Tags.Parallel
10+
import shapeless.{::, HList, HNil}
11+
12+
import java.io.{PrintStream, PrintWriter}
13+
14+
import scalaz.Semigroup
15+
16+
private object HLists {
17+
18+
implicit val doParallelApplicative =
19+
asynchronousDoParallelApplicative(DeepLearning.multipleExceptionThrowableSemigroup)
20+
21+
private val noop: Do[HNil] => UnitContinuation[Unit] = {
22+
Function.const(UnitContinuation.now(()))
23+
}
24+
25+
}
26+
27+
/**
28+
* @author 杨博 (Yang Bo)
29+
*/
30+
trait HLists {
31+
import com.thoughtworks.deeplearning.plugins.HLists._
32+
33+
trait ImplicitsApi {
34+
implicit def hnilDeepLearning[L <: HNil]: DeepLearning.Aux[L, HNil, HNil] = new DeepLearning[L] {
35+
type Data = HNil
36+
type Delta = HNil
37+
38+
def forward(differentiable: L): Do[Tape[Data, Delta]] = {
39+
Do.now(Tape(HNil, noop))
40+
}
41+
}
42+
43+
implicit def hconsDeepLearning[Head, Tail <: HList, HeadData, TailData <: HList, HeadDelta, TailDelta <: HList](
44+
implicit headDeepLearning: DeepLearning.Aux[Head, HeadData, HeadDelta],
45+
tailDeepLearning: DeepLearning.Aux[Tail, TailData, TailDelta])
46+
: DeepLearning.Aux[Head :: Tail, HeadData :: TailData, HeadDelta :: TailDelta] = new DeepLearning[Head :: Tail] {
47+
type Data = HeadData :: TailData
48+
type Delta = HeadDelta :: TailDelta
49+
50+
def forward(differentiable: Head :: Tail): Do[Tape[Data, Delta]] = {
51+
val head :: tail = differentiable
52+
val doHead: ParallelDo[Tape[HeadData, HeadDelta]] = Parallel(headDeepLearning.forward(head))
53+
54+
val doTail: ParallelDo[Tape[TailData, TailDelta]] = Parallel(tailDeepLearning.forward(tail))
55+
56+
Parallel.unwrap(Applicative[ParallelDo].tuple2(doHead, doTail)).map {
57+
case (Tape(headData, headBackward), Tape(tailData, tailBackward)) =>
58+
def backward(doDelta: Do[HeadDelta :: TailDelta]) = {
59+
val continuationHead: ParallelContinuation[Unit] = Parallel(headBackward(doDelta.map(_.head)))
60+
val continuationTail: ParallelContinuation[Unit] = Parallel(tailBackward(doDelta.map(_.tail)))
61+
Parallel.unwrap(continuationParallelApplicative.apply2(continuationHead, continuationTail) {
62+
(_: Unit, _: Unit) =>
63+
()
64+
})
65+
}
66+
Tape(headData :: tailData, backward)
67+
}
68+
69+
}
70+
71+
}
72+
}
73+
74+
type Implicits <: ImplicitsApi
75+
76+
}

plugins-INDArrayLayers/src/main/scala-2.11/com/thoughtworks/deeplearning/plugins/INDArrayLayers.scala

Lines changed: 3 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -23,43 +23,7 @@ import com.thoughtworks.dsl.Dsl
2323

2424
object INDArrayLayers {
2525

26-
final case class MultipleException(throwableSet: Set[Throwable])
27-
extends RuntimeException("Multiple exceptions found") {
28-
override def toString: String = throwableSet.mkString("\n")
29-
30-
override def printStackTrace(): Unit = {
31-
for (throwable <- throwableSet) {
32-
throwable.printStackTrace()
33-
}
34-
}
35-
36-
override def printStackTrace(s: PrintStream): Unit = {
37-
for (throwable <- throwableSet) {
38-
throwable.printStackTrace(s)
39-
}
40-
}
41-
42-
override def printStackTrace(s: PrintWriter): Unit = {
43-
for (throwable <- throwableSet) {
44-
throwable.printStackTrace(s)
45-
}
46-
}
47-
48-
override def getStackTrace: Array[StackTraceElement] = synchronized {
49-
super.getStackTrace match {
50-
case null =>
51-
setStackTrace(throwableSet.flatMap(_.getStackTrace)(collection.breakOut))
52-
super.getStackTrace
53-
case stackTrace =>
54-
stackTrace
55-
}
56-
}
57-
58-
override def fillInStackTrace(): this.type = {
59-
this
60-
}
61-
62-
}
26+
final case class MultipleException(throwableSet: Set[Throwable]) extends DeepLearning.AbstractMultipleException
6327

6428
// Workaround for https://github.com/deeplearning4j/nd4j/issues/1869
6529
private[plugins] implicit final class Nd4jIssues1869Workaround(indArray: INDArray) {
@@ -134,23 +98,8 @@ trait INDArrayLayers extends DoubleLayers with DoubleLiterals with ImplicitsSing
13498
}
13599

136100
@transient
137-
private lazy val doParallelApplicative =
138-
com.thoughtworks.raii.asynchronous.asynchronousDoParallelApplicative(new Semigroup[Throwable] {
139-
override def append(f1: Throwable, f2: => Throwable): Throwable =
140-
f1 match {
141-
case MultipleException(exceptionSet1) =>
142-
f2 match {
143-
case MultipleException(exceptionSet2) => MultipleException(exceptionSet1 ++ exceptionSet2)
144-
case e: Throwable => MultipleException(exceptionSet1 + e)
145-
}
146-
case _: Throwable =>
147-
f2 match {
148-
case MultipleException(exceptionSet2) => MultipleException(exceptionSet2 + f1)
149-
case `f1` => f1
150-
case e: Throwable => MultipleException(Set(f1, e))
151-
}
152-
}
153-
})
101+
implicit private lazy val doParallelApplicative =
102+
asynchronousDoParallelApplicative(DeepLearning.multipleExceptionThrowableSemigroup)
154103

155104
private def parallelApply2[A, B, C](doA: Do[A], doB: Do[B])(f: (A, B) => C): Do[C] = {
156105
Parallel.unwrap(doParallelApplicative.apply2(Parallel(doA), Parallel(doB))(f))

plugins-INDArrayWeights/build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ libraryDependencies += "com.thoughtworks.feature" %% "implicitapply" % "2.3.0-M8
44

55
libraryDependencies += "com.thoughtworks.feature" %% "factory" % "2.3.0-M8"
66

7-
libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.2"
7+
libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.3"
88

99
libraryDependencies ++= {
1010
import Ordering.Implicits._

plugins-Operators/build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.2"
1+
libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.3"

0 commit comments

Comments
 (0)