diff --git a/contrib/swig/CMakeLists.txt b/contrib/swig/CMakeLists.txt
index b564318c9..7c345fefc 100644
--- a/contrib/swig/CMakeLists.txt
+++ b/contrib/swig/CMakeLists.txt
@@ -21,30 +21,30 @@ set(SCALA_VERSION "2.11.11" CACHE STRING "Scala version to compile the swig wrap
 string(REGEX MATCH "[0-9]+\\.[0-9]+" SCALA_BIN_VERSION "${SCALA_VERSION}")
 
 # Set java package (+cuda flag, if appropriate)
-if(WITH_CUDA_BACKEND)
-  set(CMAKE_SWIG_FLAGS -package edu.cmu.dynet.internal -DSWIG_USE_CUDA)
-else(WITH_CUDA_BACKEND)
-  set(CMAKE_SWIG_FLAGS -package edu.cmu.dynet.internal)
-endif(WITH_CUDA_BACKEND)
+set(CMAKE_SWIG_FLAGS -package edu.cmu.dynet.internal -Dfinal)
 
 # Run swig
 set_source_files_properties(dynet_swig.i PROPERTIES CPLUSPLUS ON)
-swig_add_module(dynet_swig java dynet_swig.i)
+if(${CMAKE_VERSION} VERSION_LESS "3.8.0")
+  swig_add_module(dynet_swig java dynet_swig.i)
+else()
+  swig_add_library(dynet_swig LANGUAGE java SOURCES dynet_swig.i)
+endif()
 
 # add C++ compiler flags
 if(WITH_CUDA_BACKEND)
   set_target_properties(dynet_swig PROPERTIES COMPILE_DEFINITIONS HAVE_CUDA)
-endif(WITH_CUDA_BACKEND)
+endif()
 
 # Link with dynet library
 if(WITH_CUDA_BACKEND)
   MESSAGE("-- swig link with GPU library")
   swig_link_libraries(dynet_swig dynet)
-else(WITH_CUDA_BACKEND)
+else()
   MESSAGE("-- swig link with CPU library")
   swig_link_libraries(dynet_swig dynet)
-endif(WITH_CUDA_BACKEND)
+endif()
 
 # Create jar file
 add_jar(
@@ -60,6 +60,7 @@ add_jar(
   "${CMAKE_SWIG_OUTDIR}/CompactVanillaLSTMBuilder.java"
   "${CMAKE_SWIG_OUTDIR}/CyclicalSGDTrainer.java"
   "${CMAKE_SWIG_OUTDIR}/Device.java"
+  "${CMAKE_SWIG_OUTDIR}/DeviceManager.java"
   "${CMAKE_SWIG_OUTDIR}/DeviceMempool.java"
   "${CMAKE_SWIG_OUTDIR}/DeviceMempoolSizes.java"
   "${CMAKE_SWIG_OUTDIR}/DeviceType.java"
@@ -112,6 +113,7 @@ add_jar(
   "${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_p_p_char.java"
   "${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_size_t.java"
   "${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_std__vectorT_dynet__AlignedMemoryPool_p_t.java"
+  "${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_std__vectorT_dynet__Device_p_t.java"
   "${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_std__vectorT_dynet__Node_p_t.java"
   "${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_std__vectorT_dynet__Tensor_t.java"
   "${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_std__vectorT_dynet__VariableIndex_t.java"
diff --git a/contrib/swig/build.sbt b/contrib/swig/build.sbt
index db6185b02..27152bcf0 100644
--- a/contrib/swig/build.sbt
+++ b/contrib/swig/build.sbt
@@ -1,26 +1,27 @@
 lazy val root = (project in file("."))
-  .settings(
-    name := "dynet_scala_helpers",
-    organization := "edu.cmu.dynet",
-    version := "0.0.1-SNAPSHOT"
-  )
+  .settings(
+    name := "dynet_scala_helpers",
+    organization := "edu.cmu.dynet",
+    version := "0.0.1-SNAPSHOT"
+  )
 
 val DEFAULT_BUILD_PATH = "../../build/contrib/swig"
 
 // The default scala version to use if none was specified from
 // outside. When building with cmake, the scalaversion property
 // should always be set; this is only a fallback for other cases.
-val DEFAULT_SCALA_VERSION = "2.11.11"
+val DEFAULT_SCALA_VERSION = "2.12.8"
 
-scalaVersion := { sys.props.get("scalaversion") match {
+scalaVersion := {
+  sys.props.get("scalaversion") match {
     case Some(p) => p
-    case None => {
+    case None =>
      println(s"using default scala version ${DEFAULT_SCALA_VERSION}")
      DEFAULT_SCALA_VERSION
-    }
-}}
-
+  }
+}
+javaOptions in Test ++= Seq("-Xms1G","-XX:+CMSClassUnloadingEnabled","-XX:+UseConcMarkSweepGC")
 
 // This is where `make` does all its work, and it's where we'll do all our work as well.
@@ -29,10 +30,9 @@ lazy val buildPath = settingKey[String]("Build Path")
 buildPath := {
   val bp = sys.props.get("buildpath") match {
     case Some(p) => p
-    case None => {
+    case None =>
      println(s"using default buildpath ${DEFAULT_BUILD_PATH}")
      DEFAULT_BUILD_PATH
-    }
  }
  if (new File(bp).exists) {
    bp
@@ -93,6 +93,6 @@ assemblyMergeStrategy in assembly := {
 // Don't include Scala libraries in the jar
 // see https://github.com/sbt/sbt-assembly/issues/3
 // and http://stackoverflow.com/questions/15856739/assembling-a-jar-containing-only-the-provided-dependencies
-assembleArtifact in packageScala := false
+assembleArtifact in assemblyPackageScala := false
 
-libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.0" % "test"
+libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.8" % "test"
diff --git a/contrib/swig/dynet_swig.i b/contrib/swig/dynet_swig.i
index d2bd96e02..00feaeaf7 100644
--- a/contrib/swig/dynet_swig.i
+++ b/contrib/swig/dynet_swig.i
@@ -97,7 +97,6 @@ VECTORCONSTRUCTOR(std::vector<unsigned>, UnsignedVector, UnsignedVectorVector)
 VECTORCONSTRUCTOR(std::vector<dynet::Expression>, ExpressionVector, ExpressionVectorVector)
 VECTORCONSTRUCTOR(std::vector<dynet::Parameter>, ParameterVector, ParameterVectorVector)
-
 // Useful SWIG libraries
 %include "std_vector.i"
 %include "std_string.i"
@@ -107,11 +106,17 @@ VECTORCONSTRUCTOR(std::vector<dynet::Parameter>, ParameterVector, ParameterVecto
 %shared_ptr(dynet::ParameterStorage)
 %shared_ptr(dynet::LookupParameterStorage)
+%shared_ptr(dynet::ParameterStorageBase)
 
 // Convert C++ exceptions into Java exceptions. This provides
 // nice error messages for each listed exception, and a default
 // "unknown error" message for all others.
-%catches(std::invalid_argument, ...);
+%catches(std::invalid_argument,
+         std::runtime_error,
+         std::domain_error,
+         dynet::out_of_memory,
+         dynet::cuda_exception,
+         ...);
 
 %pointer_functions(unsigned, uintp);
 %pointer_functions(int, intp);
@@ -155,6 +160,8 @@ struct Node;
 struct ParameterStorage;
 struct LookupParameterStorage;
 
+struct Device;
+
 ///////////////////////////////////
 // declarations from dynet/dim.h //
 ///////////////////////////////////
@@ -322,9 +329,12 @@ private:
 struct ParameterStorageBase {
   virtual void scale_parameters(float a) = 0;
+  virtual void scale_gradient(float a) = 0;
   virtual void zero() = 0;
   virtual void squared_l2norm(float* sqnorm) const = 0;
   virtual void g_squared_l2norm(float* sqnorm) const = 0;
+  virtual bool is_updated() const = 0;
+  virtual bool has_grad() const = 0;
   virtual size_t size() const = 0;
   virtual ~ParameterStorageBase();
 };
@@ -385,11 +395,16 @@ class ParameterCollection {
   float gradient_l2_norm() const;
   void reset_gradient();
 
-  Parameter add_parameters(const Dim& d, float scale = 0.0f);
-  Parameter add_parameters(const Dim& d, const ParameterInit & init);
-  LookupParameter add_lookup_parameters(unsigned n, const Dim& d);
-  LookupParameter add_lookup_parameters(unsigned n, const Dim& d, const ParameterInit & init);
-
+  Parameter add_parameters(const Dim& d, float scale = 0.0f,
+                           const std::string & name = "", Device *device = dynet::default_device);
+  Parameter add_parameters(const Dim& d, Device *device);
+  Parameter add_parameters(const Dim& d, const std::string & name, Device *device = dynet::default_device);
+  Parameter add_parameters(const Dim& d, const ParameterInit & init,
+                           const std::string & name = "", Device *device = dynet::default_device);
+  LookupParameter add_lookup_parameters(unsigned n, const Dim& d,
+                                        const std::string & name = "", Device *device = dynet::default_device);
+  LookupParameter add_lookup_parameters(unsigned n, const Dim& d, const ParameterInit & init,
+                                        const std::string & name = "", Device *device = dynet::default_device);
 
   void project_weights(float radius = 1.0f);
   void set_weight_decay_lambda(float lambda);
@@ -434,6 +449,7 @@ struct Expression {
   ComputationGraph *pg;
   VariableIndex i;
   Expression(ComputationGraph *pg, VariableIndex i) : pg(pg), i(i) { };
+  std::string get_device_name();
   const Tensor& value();
   const Dim& dim() const { return pg->get_dimension(i); }
 };
@@ -448,10 +464,13 @@ Expression f(const T& xs, const T1& arg1);
 
 /* INPUT OPERATIONS */
 
-Expression input(ComputationGraph& g, real s);
-Expression input(ComputationGraph& g, const real *ps);
-Expression input(ComputationGraph& g, const Dim& d, const std::vector<float>* pdata);
-Expression input(ComputationGraph& g, const Dim& d, const std::vector<unsigned int>& ids, const std::vector<float>& data, float defdata = 0.f);
+Expression input(ComputationGraph& g, real s, Device *device = dynet::default_device);
+Expression input(ComputationGraph& g, const real *ps, Device *device = dynet::default_device);
+Expression input(ComputationGraph& g, const Dim& d, const std::vector<float>& data, Device *device = dynet::default_device);
+// Expression input(ComputationGraph& g, const Dim& d, const std::vector<float>* pdata, Device *device = dynet::default_device);
+Expression input(ComputationGraph& g, const Dim& d, const std::vector<unsigned int>& ids, const std::vector<float>& data, float defdata = 0.f, Device *device = dynet::default_device);
+Expression one_hot(ComputationGraph& g, unsigned int d, unsigned int idx, Device *device = dynet::default_device);
+Expression one_hot(ComputationGraph& g, unsigned int d, const std::vector<unsigned int>& ids, Device *device = dynet::default_device);
 Expression parameter(ComputationGraph& g, Parameter p);
 Expression parameter(ComputationGraph& g, LookupParameter lp);
 Expression const_parameter(ComputationGraph& g, Parameter p);
@@ -465,14 +484,14 @@ Expression lookup(ComputationGraph& g, LookupParameter p, const std::vector<unsigned>& indices);
 //Expression const_lookup(ComputationGraph& g, LookupParameter p, const std::vector<unsigned>* pindices);
 
-Expression zeros(ComputationGraph& g, const Dim& d);
-Expression zeroes(ComputationGraph& g, const Dim& d);
-Expression ones(ComputationGraph& g, const Dim& d);
-Expression constant(ComputationGraph& g, const Dim& d, float val);
-Expression random_normal(ComputationGraph& g, const Dim& d);
-Expression random_bernoulli(ComputationGraph& g, const Dim& d, real p, real scale = 1.0f);
-Expression random_uniform(ComputationGraph& g, const Dim& d, real left, real right);
-Expression random_gumbel(ComputationGraph& g, const Dim& d, real mu = 0.0, real beta = 1.0);
+Expression zeros(ComputationGraph& g, const Dim& d, Device *device = dynet::default_device);
+Expression zeroes(ComputationGraph& g, const Dim& d, Device *device = dynet::default_device);
+Expression ones(ComputationGraph& g, const Dim& d, Device *device = dynet::default_device);
+Expression constant(ComputationGraph& g, const Dim& d, float val, Device *device = dynet::default_device);
+Expression random_normal(ComputationGraph& g, const Dim& d, float mean=0.f, float stddev=1.0, Device *device = dynet::default_device);
+Expression random_bernoulli(ComputationGraph& g, const Dim& d, real p, real scale = 1.0f, Device *device = dynet::default_device);
+Expression random_uniform(ComputationGraph& g, const Dim& d, real left, real right, Device *device = dynet::default_device);
+Expression random_gumbel(ComputationGraph& g, const Dim& d, real mu = 0.0, real beta = 1.0, Device *device = dynet::default_device);
 
 /* ARITHMETIC OPERATIONS */
 
@@ -677,6 +696,8 @@ Expression trace_of_product(const Expression& x, const Expression& y);
 Expression layer_norm(const Expression& x, const Expression& g, const Expression& b);
 Expression weight_norm(const Expression& w, const Expression& g);
 
+Expression to_device(const Expression & x, Device *device);
+
 /////////////////////////////////////
 // declarations from dynet/dynet.h //
 /////////////////////////////////////
@@ -770,6 +791,7 @@ class Device {
   Device& operator=(const Device&) = delete;
   virtual ~Device();
  public:
+  void reset_rng(unsigned seed) {};
   int device_id;
   DeviceType type;
   MemAllocator* mem;
@@ -785,6 +807,36 @@ class Device {
 
 extern Device* default_device; // where parameters go by default
 
+class DeviceManager final {
+ public:
+  DeviceManager();
+  ~DeviceManager();
+
+  void clear();
+
+  void add(Device* d);
+
+  Device* get(size_t i) { return devices[i]; }
+
+  size_t num_devices() const { return devices.size(); }
+
+  const std::vector<Device*>& get_devices() const { return devices; }
+
+  Device* get_global_device(const std::string & name);
+
+  // no copying allowed
+  DeviceManager(const DeviceManager &) = delete;
+  void operator=(const DeviceManager &) = delete;
+
+ private:
+  std::vector<Device*> devices;
+  std::unordered_map<std::string, Device*> devices_map;
+};
+
+DeviceManager* get_device_manager();
+
+inline void show_pool_mem_info();
+
 ////////////////////////////////////////
 // declarations from dynet/training.h //
 ////////////////////////////////////////
@@ -1233,12 +1285,10 @@ struct DynetParams {
   int profiling = 0; /**< Whether to show profiling info or not */
   bool shared_parameters = false; /**< TO DOCUMENT */
 
-#ifdef SWIG_USE_CUDA
   bool ngpus_requested = false; /**< GPUs requested by number */
   bool ids_requested = false; /**< GPUs requested by ids */
   int requested_gpus = -1; /**< Number of requested GPUs */
   std::vector<int> gpu_mask; /**< List of required GPUs by ids */
-#endif
 };
 
diff --git a/contrib/swig/project/assembly.sbt b/contrib/swig/project/assembly.sbt
index 15a88b093..9c014713d 100644
--- a/contrib/swig/project/assembly.sbt
+++ b/contrib/swig/project/assembly.sbt
@@ -1 +1 @@
-addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.5")
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9")
diff --git a/contrib/swig/src/main/scala/edu/cmu/dynet/ComputationGraph.scala b/contrib/swig/src/main/scala/edu/cmu/dynet/ComputationGraph.scala
index de2b1e7a3..3b945b344 100644
--- a/contrib/swig/src/main/scala/edu/cmu/dynet/ComputationGraph.scala
+++ b/contrib/swig/src/main/scala/edu/cmu/dynet/ComputationGraph.scala
@@ -6,7 +6,7 @@ package edu.cmu.dynet
 object ComputationGraph {
   private[dynet] var cg: internal.ComputationGraph = internal.ComputationGraph.getNew
   var version: Long = 0L
-  private var defaultDevice: internal.Device = internal.dynet_swig.getDefault_device()
+  private val defaultDevice: internal.Device = internal.dynet_swig.getDefault_device()
 
   /** Gets rid of the singleton Computation Graph and replaces it with a fresh one. Increments
     * `version` to make sure we don't use any stale expressions.
diff --git a/contrib/swig/src/main/scala/edu/cmu/dynet/Device.scala b/contrib/swig/src/main/scala/edu/cmu/dynet/Device.scala
new file mode 100644
index 000000000..3584fc1ab
--- /dev/null
+++ b/contrib/swig/src/main/scala/edu/cmu/dynet/Device.scala
@@ -0,0 +1,15 @@
+package edu.cmu.dynet
+
+object Device {
+  def apply(str:String): internal.Device = {
+    if(str == "" || str == "default") internal.dynet_swig.getDefault_device
+    else DeviceManager.getGlobalDevice(str)
+  }
+
+  lazy val default: internal.Device = internal.dynet_swig.getDefault_device
+
+  lazy val available: Vector[internal.Device] = {
+    val tmp = for(l <- 0L until DeviceManager.numDevices()) yield DeviceManager.get(l)
+    tmp.toVector
+  }
+}
diff --git a/contrib/swig/src/main/scala/edu/cmu/dynet/DeviceManager.scala b/contrib/swig/src/main/scala/edu/cmu/dynet/DeviceManager.scala
new file mode 100644
index 000000000..266bd7012
--- /dev/null
+++ b/contrib/swig/src/main/scala/edu/cmu/dynet/DeviceManager.scala
@@ -0,0 +1,17 @@
+package edu.cmu.dynet
+
+object DeviceManager {
+  private[dynet] val dm: internal.DeviceManager = internal.dynet_swig.get_device_manager()
+
+  def add(d: internal.Device): Unit = dm.add(d)
+
+  def get(l: Long): internal.Device = dm.get(l)
+
+  def numDevices(): Long = dm.num_devices()
+
+  def getGlobalDevice(name: String): internal.Device = dm.get_global_device(name)
+
+  def getDefaultDevice: internal.Device = internal.dynet_swig.getDefault_device
+
+  def showMemPoolInfo(): Unit = internal.dynet_swig.show_pool_mem_info()
+}
diff --git a/contrib/swig/src/main/scala/edu/cmu/dynet/Dim.scala b/contrib/swig/src/main/scala/edu/cmu/dynet/Dim.scala
index c2711126e..b2fac8c13 100644
--- a/contrib/swig/src/main/scala/edu/cmu/dynet/Dim.scala
+++ b/contrib/swig/src/main/scala/edu/cmu/dynet/Dim.scala
@@ -14,7 +14,7 @@ class Dim private[dynet] (private[dynet] val dim: internal.Dim) {
   def truncate(): Dim = new Dim(dim.truncate())
   def singleBatch(): Dim = new Dim(dim.single_batch())
 
-  def resize(i: Long) = dim.resize(i)
+  def resize(i: Long): Unit = dim.resize(i)
   def nDims(): Long = dim.ndims()
   def rows(): Long = dim.rows()
   def cols(): Long = dim.cols()
@@ -31,7 +31,7 @@ class Dim private[dynet] (private[dynet] val dim: internal.Dim) {
   /** We override `equals` so that `Dim` objects should be equal whenever all of their dimension
     * sizes match.
     */
-  override def equals(that: Any) = that match {
+  override def equals(that: Any): Boolean = that match {
     case that: Dim => dim == that.dim
     case _ => false
   }
@@ -39,7 +39,7 @@ class Dim private[dynet] (private[dynet] val dim: internal.Dim) {
 
   override def toString: String = "Dim(" + (0 until nDims.toInt).map(get(_)).mkString(", ") + ")"
 
-  def debugString(): String = s"(Dim: ${size} ${nDims} ${(0 until nDims.toInt).map(get(_))} )"
+  def debugString(): String = s"(Dim: $size $nDims ${(0 until nDims.toInt).map(get(_))} )"
 }
 
 /** Factory for [[edu.cmu.dynet.Dim]] instances.
  */
 
diff --git a/contrib/swig/src/main/scala/edu/cmu/dynet/Expression.scala b/contrib/swig/src/main/scala/edu/cmu/dynet/Expression.scala
index 4516124c8..147bc0044 100644
--- a/contrib/swig/src/main/scala/edu/cmu/dynet/Expression.scala
+++ b/contrib/swig/src/main/scala/edu/cmu/dynet/Expression.scala
@@ -59,13 +59,27 @@ object Expression {
     new Expression(expr, references)
   }
 
-  def input(s: Float): Expression = makeExpr(cg => dn.input(ComputationGraph.cg, s))
+  def input(s: Float): Expression = makeExpr(cg => dn.input(ComputationGraph.cg, s, Device.default))
+  def input(s: Float, device: internal.Device): Expression = makeExpr(cg => dn.input(ComputationGraph.cg, s, device))
   def input(fp: FloatPointer): Expression =
-    makeExpr(cg => dn.input(ComputationGraph.cg, fp.floatp), Seq(fp))
+    makeExpr(cg => dn.input(ComputationGraph.cg, fp.floatp, Device.default), Seq(fp))
+  def input(fp: FloatPointer, device: internal.Device): Expression =
+    makeExpr(cg => dn.input(ComputationGraph.cg, fp.floatp, device), Seq(fp))
   def input(d: Dim, pdata: FloatVector): Expression =
-    makeExpr(cg => dn.input(cg, d.dim, pdata.vector), Seq(d, pdata))
-  def input(d: Dim, ids: UnsignedVector, data: FloatVector, defdata: Float = 0f) =
-    makeExpr(cg => dn.input(cg, d.dim, ids.vector, data.vector, defdata), Seq(d, ids, data))
+    makeExpr(cg => dn.input(cg, d.dim, pdata.vector, Device.default), Seq(d, pdata))
+  def input(d: Dim, pdata: FloatVector, device: internal.Device): Expression =
+    makeExpr(cg => dn.input(cg, d.dim, pdata.vector, device), Seq(d, pdata))
+  def input(d: Dim, ids: UnsignedVector, data: FloatVector, defdata: Float = 0f,
+            device: internal.Device = Device.default): Expression =
+    makeExpr(cg => dn.input(cg, d.dim, ids.vector, data.vector, defdata, device), Seq(d, ids, data))
+  def oneHot(d:Long, idx:Long): Expression =
+    makeExpr(cg => dn.one_hot(cg, d, idx, Device.default))
+  def oneHot(d:Long, idx:Long, device: internal.Device): Expression =
+    makeExpr(cg => dn.one_hot(cg, d, idx, device))
+  def oneHot(d:Long, ids:UnsignedVector): Expression =
+    makeExpr(cg => dn.one_hot(cg, d, ids.vector, Device.default))
+  def oneHot(d:Long, ids:UnsignedVector, device: internal.Device): Expression =
+    makeExpr(cg => dn.one_hot(cg, d, ids.vector, device))
   def parameter(p: Parameter): Expression = makeExpr(cg => dn.parameter(cg, p.parameter), Seq(p))
   def parameter(lp: LookupParameter): Expression = makeExpr(cg => dn.parameter(cg, lp.lookupParameter), Seq(lp))
@@ -74,30 +88,31 @@ object Expression {
   def constParameter(lp: LookupParameter): Expression = makeExpr(cg => dn.const_parameter(cg, lp.lookupParameter), Seq(lp))
 
-  def lookup(p: LookupParameter, index: Long) =
+  def lookup(p: LookupParameter, index: Long): Expression =
     makeExpr(cg => dn.lookup(cg, p.lookupParameter, index), Seq(p))
-  def lookup(p: LookupParameter, pindex: UnsignedPointer) =
+  def lookup(p: LookupParameter, pindex: UnsignedPointer): Expression =
     makeExpr(cg => dn.lookup(cg, p.lookupParameter, pindex.uintp), Seq(p, pindex))
-  def constLookup(p: LookupParameter, index: Long) =
+  def constLookup(p: LookupParameter, index: Long): Expression =
     makeExpr(cg => dn.const_lookup(cg, p.lookupParameter, index), Seq(p))
-  def constLookup(p: LookupParameter, pindex: UnsignedPointer) =
+  def constLookup(p: LookupParameter, pindex: UnsignedPointer): Expression =
     makeExpr(cg => dn.const_lookup(cg, p.lookupParameter, pindex.uintp), Seq(p, pindex))
-  def lookup(p: LookupParameter, indices: UnsignedVector) =
+  def lookup(p: LookupParameter, indices: UnsignedVector): Expression =
    makeExpr(cg => dn.lookup(cg, p.lookupParameter, indices.vector), Seq(p, indices))
-  def constLookup(p: LookupParameter, indices: UnsignedVector) =
+  def constLookup(p: LookupParameter, indices: UnsignedVector): Expression =
    makeExpr(cg => dn.const_lookup(cg, p.lookupParameter, indices.vector), Seq(p, indices))
 
-  def zeros(d: Dim) = makeExpr(cg => dn.zeros(cg, d.dim), Seq(d))
-  def zeroes(d: Dim) = makeExpr(cg => dn.zeros(cg, d.dim), Seq(d))
-  def ones(d: Dim) = makeExpr(cg => dn.ones(cg, d.dim), Seq(d))
-  def constant(d: Dim, v: Float) = makeExpr(cg => dn.constant(cg, d.dim, v), Seq(d))
-  def randomNormal(d: Dim) = makeExpr(cg => dn.random_normal(cg, d.dim), Seq(d))
-  def randomBernoulli(d: Dim, p: Float, scale: Float = 1.0f) = makeExpr(
-    cg => dn.random_bernoulli(cg, d.dim, p, scale), Seq(d))
-  def randomUniform(d: Dim, left: Float, right: Float) = makeExpr(
-    cg => dn.random_uniform(cg, d.dim, left, right), Seq(d))
-  def randomGumbel(d: Dim, mu: Float, beta: Float) = makeExpr(
-    cg => dn.random_gumbel(cg, d.dim, mu, beta), Seq(d))
+  def zeros(d: Dim, device: internal.Device = Device.default): Expression = makeExpr(cg => dn.zeros(cg, d.dim, device), Seq(d))
+  def zeroes(d: Dim, device: internal.Device = Device.default): Expression = makeExpr(cg => dn.zeros(cg, d.dim, device), Seq(d))
+  def ones(d: Dim, device: internal.Device = Device.default): Expression = makeExpr(cg => dn.ones(cg, d.dim, device), Seq(d))
+  def constant(d: Dim, v: Float, device: internal.Device = Device.default): Expression = makeExpr(cg => dn.constant(cg, d.dim, v, device), Seq(d))
+  def randomNormal(d: Dim, mean: Float = 0f, stdDev: Float = 1f, device: internal.Device = Device.default): Expression =
+    makeExpr(cg => dn.random_normal(cg, d.dim, mean, stdDev, device), Seq(d))
+  def randomBernoulli(d: Dim, p: Float, scale: Float = 1.0f, device: internal.Device = Device.default): Expression =
+    makeExpr(cg => dn.random_bernoulli(cg, d.dim, p, scale, device), Seq(d))
+  def randomUniform(d: Dim, left: Float, right: Float, device: internal.Device = Device.default): Expression =
+    makeExpr(cg => dn.random_uniform(cg, d.dim, left, right, device), Seq(d))
+  def randomGumbel(d: Dim, mu: Float, beta: Float, device: internal.Device = Device.default): Expression =
+    makeExpr(cg => dn.random_gumbel(cg, d.dim, mu, beta, device), Seq(d))
 
   /* ARITHMETIC OPERATIONS */
 
@@ -111,7 +126,7 @@ object Expression {
   }
 
   private type UnaryTransform = internal.Expression => internal.Expression
-  private def unary(e: Expression, transformer: UnaryTransform) = {
+  private def unary(e: Expression, transformer: UnaryTransform): Expression = {
    e.ensureFresh()
    // Specify e as reference so it can't get prematurely garbage collected.
    new Expression(transformer(e.expr), Seq(e))
@@ -146,7 +161,7 @@ object Expression {
   def sum(exprs: Expression*): Expression = sum(new ExpressionVector(exprs))
 
   def sumElems(e: Expression): Expression = unary(e, dn.sum_elems)
-  def momentElems(e: Expression, r: Long) = unary(e, e => dn.moment_elems(e, r))
+  def momentElems(e: Expression, r: Long): Expression = unary(e, e => dn.moment_elems(e, r))
   def meanElems(e: Expression): Expression = unary(e, dn.mean_elems)
   def stdElems(e: Expression): Expression = unary(e, dn.std_elems)
@@ -270,7 +285,12 @@ object Expression {
   def concatenateCols(v: ExpressionVector): Expression = vectory(v, dn.concatenate_cols)
   def concatenateCols(exprs: Expression*): Expression = concatenateCols(new ExpressionVector(exprs))
 
-  def concatenate(v: ExpressionVector): Expression = vectory(v, dn.concatenate)
+  def concatenate(v: ExpressionVector, d: Int = 0): Expression = {
+    assert(v.nonEmpty, "Operation requires > 0 expression arguments")
+    v.ensureFresh()
+    new Expression(dn.concatenate(v.vector, d), Seq(v))
+  }
+
   def concatenate(exprs: Expression*): Expression = concatenate(new ExpressionVector(exprs))
 
   /* NOISE OPERATIONS */
@@ -337,6 +357,7 @@ object Expression {
     new Expression(dn.layer_norm(x.expr, g.expr, b.expr), Seq(x, g, b))
   }
   def weightNorm(w: Expression, g: Expression): Expression = binary(w, g, dn.weight_norm)
+  def toDevice(x: Expression, device: internal.Device): Expression = unary(x, x => dn.to_device(x, device))
 
   /** Augment numbers so that they can do arithmetic with expressions. */
   implicit class ImplicitNumerics[T](x: T)(implicit n: Numeric[T]) {
diff --git a/contrib/swig/src/main/scala/edu/cmu/dynet/Initialize.scala b/contrib/swig/src/main/scala/edu/cmu/dynet/Initialize.scala
index ec9cb3f98..4380d7d3a 100644
--- a/contrib/swig/src/main/scala/edu/cmu/dynet/Initialize.scala
+++ b/contrib/swig/src/main/scala/edu/cmu/dynet/Initialize.scala
@@ -29,7 +29,27 @@ object Initialize {
       .foreach(arg => params.setAutobatch(arg.asInstanceOf[Int]))
 
     args.get("profiling")
-      .foreach(arg => params.setProfiling(arg.asInstanceOf[Int]))
+      .foreach(arg => params.setProfiling(arg.asInstanceOf[Int]))
+
+    args.get("gpus")
+      .foreach(arg => params.setRequested_gpus(arg.asInstanceOf[Int]))
+
+    if(args.contains("devices")){
+      require(!params.getIds_requested)
+      params.setIds_requested(true)
+      args.get("devices")
+        .foreach(arg => arg.asInstanceOf[String].split(',').foreach(
+          s =>
+            if(s.startsWith("CPU:")){
+              Console.err.println("DyNet doesn't support specifying CPU id")
+            }else if(s.startsWith("GPU:")){
+              val gpuID = s.split(":")(1).toInt
+              params.getGpu_mask.set(gpuID, params.getGpu_mask.get(gpuID) + 1)
+              params.setRequested_gpus(params.getRequested_gpus + 1)
+              require(params.getGpu_mask.get(gpuID) == 1)
+            }
+        ))
+    }
 
     initialize(params)
   }
diff --git a/contrib/swig/src/main/scala/edu/cmu/dynet/ParameterCollection.scala b/contrib/swig/src/main/scala/edu/cmu/dynet/ParameterCollection.scala
index 284778bb5..577c70639 100644
--- a/contrib/swig/src/main/scala/edu/cmu/dynet/ParameterCollection.scala
+++ b/contrib/swig/src/main/scala/edu/cmu/dynet/ParameterCollection.scala
@@ -15,15 +15,36 @@ class ParameterCollection private[dynet] (private[dynet] val model: internal.Par
   }
   */
 
-  def addParameters(d: Dim, scale: Float = 0.0f): Parameter =
-    new Parameter(model.add_parameters(d.dim, scale))
-  def addParameters(d: Dim, init: ParameterInit): Parameter =
-    new Parameter(model.add_parameters(d.dim, init.parameterInit))
-
-  def addLookupParameters(n: Long, d: Dim): LookupParameter =
-    new LookupParameter(model.add_lookup_parameters(n, d.dim))
-  def addLookupParameters(n: Long, d: Dim, init: ParameterInit) =
-    new LookupParameter(model.add_lookup_parameters(n, d.dim, init.parameterInit))
+  def addParameters(d: Dim,
+                    init: ParameterInit = ParameterInit.glorot(),
+                    name: String = "",
+                    device: internal.Device = Device.default): Parameter =
+    new Parameter(model.add_parameters(d.dim, init.parameterInit, name, device))
+
+  // Scala compiler does not allow multiple overloaded method with default arguments
+  def addParameters(d: Dim, scale: Float, name:String, device: internal.Device): Parameter =
+    new Parameter(model.add_parameters(d.dim, scale, name, device))
+  def addParameters(d: Dim, scale: Float, name:String): Parameter =
+    new Parameter(model.add_parameters(d.dim, scale, name, Device.default))
+  def addParameters(d: Dim, scale: Float, device: internal.Device): Parameter =
+    new Parameter(model.add_parameters(d.dim, scale, "", device))
+  def addParameters(d: Dim, scale: Float): Parameter =
+    new Parameter(model.add_parameters(d.dim, scale, "", Device.default))
+//  def addParameters(d: Dim, device: internal.Device): Parameter =
+//    new Parameter(model.add_parameters(d.dim, device))
+//  def addParameters(d: Dim, name:String, device: internal.Device): Parameter =
+//    new Parameter(model.add_parameters(d.dim, name, device))
+//  def addParameters(d: Dim, init: ParameterInit): Parameter =
+//    new Parameter(model.add_parameters(d.dim, init.parameterInit))
+//  def addParameters(d: Dim, init:ParameterInit, name: String): Parameter =
+//    new Parameter(model.add_parameters(d.dim, init.parameterInit, name))
+
+  def addLookupParameters(n: Long,
+                          d: Dim,
+                          init: ParameterInit = ParameterInit.glorot(),
+                          name: String = "",
+                          device: internal.Device = Device.default): LookupParameter =
+    new LookupParameter(model.add_lookup_parameters(n, d.dim, init.parameterInit, name, device))
 
   def projectWeights(radius: Float = 0.0f) = model.project_weights(radius)
   def setWeightDecayLambda(lambda: Float) = model.set_weight_decay_lambda(lambda)
diff --git a/contrib/swig/src/main/scala/edu/cmu/dynet/package.scala b/contrib/swig/src/main/scala/edu/cmu/dynet/package.scala
new file mode 100644
index 000000000..c339e565f
--- /dev/null
+++ b/contrib/swig/src/main/scala/edu/cmu/dynet/package.scala
@@ -0,0 +1,9 @@
+package edu.cmu
+
+package object dynet {
+  implicit class RichDevice(self: internal.Device) {
+    def name(): String = self.getName
+    def deviceID(): Int = self.getDevice_id
+    def resetRNG(seed:Long): Unit = self.reset_rng(seed)
+  }
+}
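
Usage note (not part of the patch): a minimal sketch of how the device-placement API added by this change is meant to be used from the Scala side. It assumes a CUDA build with at least one GPU; the "GPU:0" device name, the `devices` initialization key value, and the tiny dimensions are illustrative assumptions rather than values taken from this diff.

import edu.cmu.dynet._

object MultiDeviceSketch {
  def main(args: Array[String]): Unit = {
    // Request one GPU by id, using the same "GPU:<id>" syntax parsed in Initialize.scala above.
    Initialize.initialize(Map("devices" -> "GPU:0"))

    val model = new ParameterCollection()
    // Explicit placement via the new device argument on addParameters.
    val w = model.addParameters(Dim(8, 8), ParameterInit.glorot(), "", Device("GPU:0"))

    ComputationGraph.renew()
    // Inputs and random expressions now also accept an optional device.
    val x = Expression.randomNormal(Dim(8), device = Device("GPU:0"))
    val onGpu = Expression.parameter(w) * x
    // Copy the result back to the default device before mixing it with CPU-resident expressions.
    val onCpu = Expression.toDevice(onGpu, Device.default)
    println(s"devices available: ${Device.available.size}, result dim: ${onCpu.dim()}")
  }
}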