diff --git a/.gitignore b/.gitignore index cfe2c08a..bcf8c0f8 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,6 @@ project/plugins/project/ # Node node_modules + +# Spark-ec2 boto +tools/spark-ec2/lib diff --git a/build.sbt b/build.sbt index 095c1228..5de79888 100644 --- a/build.sbt +++ b/build.sbt @@ -4,7 +4,7 @@ version := "1.0" scalaVersion := "2.10.4" -scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature", "-Xfatal-warnings") +scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature", "-Xfatal-warnings", "-Xlint", "-Ywarn-dead-code", "-Xmax-classfile-name", "130") ideaExcludeFolders += ".idea" @@ -13,15 +13,13 @@ ideaExcludeFolders += ".idea_modules" // Because we can't run two spark contexts on same VM parallelExecution in Test := false -libraryDependencies += ("org.apache.spark" %% "spark-core" % "1.3.0" % "provided").exclude("org.apache.hadoop", "hadoop-client") +libraryDependencies += ("org.apache.spark" %% "spark-core" % "1.5.1" % "provided") + .exclude("org.apache.hadoop", "hadoop-client") + .exclude("org.slf4j", "slf4j-log4j12") libraryDependencies += ("org.apache.hadoop" % "hadoop-client" % "2.0.0-cdh4.7.1" % "provided") -libraryDependencies += "com.github.nscala-time" %% "nscala-time" % "0.8.0" - -libraryDependencies += "org.scalatest" % "scalatest_2.10" % "2.0" - -libraryDependencies += "org.scalaj" %% "scalaj-http" % "0.3.16" +libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.4" libraryDependencies += "org.scalaz" %% "scalaz-core" % "7.0.6" @@ -29,6 +27,14 @@ libraryDependencies += "com.github.scopt" %% "scopt" % "3.2.0" libraryDependencies += "net.java.dev.jets3t" % "jets3t" % "0.7.1" +libraryDependencies += "joda-time" % "joda-time" % "2.7" + +libraryDependencies += "org.joda" % "joda-convert" % "1.7" + +libraryDependencies += "com.amazonaws" % "aws-java-sdk" % "1.9.6" + +libraryDependencies += "commons-lang" % "commons-lang" % "2.6" + resolvers += "Akka Repository" at "http://repo.akka.io/releases/" resolvers += 
"Sonatype OSS Releases" at "http://oss.sonatype.org/content/repositories/releases/" diff --git a/remote_hook.sh b/remote_hook.sh index 305a0ff6..dd76933a 100755 --- a/remote_hook.sh +++ b/remote_hook.sh @@ -11,6 +11,7 @@ CONTROL_DIR="${5?Please give the Control Directory}" SPARK_MEM_PARAM="${6?Please give the Job Memory Size to use}" USE_YARN="${7?Please tell if we should use YARN (yes/no)}" NOTIFY_ON_ERRORS="${8?Please tell if we will notify on errors (yes/no)}" +DRIVER_HEAP_SIZE="${9?Please tell driver heap size to use}" JOB_WITH_TAG=${JOB_NAME}.${JOB_TAG} JOB_CONTROL_DIR="${CONTROL_DIR}/${JOB_WITH_TAG}" @@ -48,6 +49,23 @@ on_trap_exit() { rm -f "${RUNNING_FILE}" } +install_and_run_zeppelin() { + if [[ ! -d "zeppelin" ]]; then + wget "http://www.us.apache.org/dist/incubator/zeppelin/0.5.6-incubating/zeppelin-0.5.6-incubating-bin-all.tgz" -O zeppelin.tar.gz + mkdir zeppelin + tar xvzf zeppelin.tar.gz -C zeppelin --strip-components 1 > /tmp/zeppelin_install.log + fi + if [[ -f "zeppelin/bin/zeppelin.sh" ]]; then + export MASTER="${JOB_MASTER}" + export ZEPPELIN_PORT="8081" + export SPARK_HOME="/root/spark" + export SPARK_SUBMIT_OPTIONS="--jars ${JAR_PATH} --runner-executor-memory ${SPARK_MEM_PARAM}" + sudo -E zeppelin/bin/zeppelin.sh + else + notify_error_and_exit "Zepellin installation not found" + fi +} + trap "on_trap_exit" EXIT @@ -73,14 +91,15 @@ if [[ "${USE_YARN}" == "yes" ]]; then export SPARK_WORKER_MEMORY=${SPARK_MEM_PARAM} fi - if [[ "${JOB_NAME}" == "shell" ]]; then export ADD_JARS=${JAR_PATH} sudo -E ${SPARK_HOME}/bin/spark-shell || notify_error_and_exit "Execution failed for shell" +elif [[ "${JOB_NAME}" == "zeppelin" ]]; then + install_and_run_zeppelin else JOB_OUTPUT="${JOB_CONTROL_DIR}/output.log" tail -F "${JOB_OUTPUT}" & - sudo -E "${SPARK_HOME}/bin/spark-submit" --master "${JOB_MASTER}" --driver-memory 25000M --driver-java-options "-Djava.io.tmpdir=/mnt -verbose:gc -XX:-PrintGCDetails -XX:+PrintGCTimeStamps" --class "${MAIN_CLASS}" ${JAR_PATH} 
"${JOB_NAME}" --runner-date "${JOB_DATE}" --runner-tag "${JOB_TAG}" --runner-user "${JOB_USER}" --runner-master "${JOB_MASTER}" --runner-executor-memory "${SPARK_MEM_PARAM}" >& "${JOB_OUTPUT}" || notify_error_and_exit "Execution failed for job ${JOB_WITH_TAG}" + sudo -E "${SPARK_HOME}/bin/spark-submit" --master "${JOB_MASTER}" --driver-memory "${DRIVER_HEAP_SIZE}" --driver-java-options "-Djava.io.tmpdir=/mnt -verbose:gc -XX:-PrintGCDetails -XX:+PrintGCTimeStamps" --class "${MAIN_CLASS}" ${JAR_PATH} "${JOB_NAME}" --runner-date "${JOB_DATE}" --runner-tag "${JOB_TAG}" --runner-user "${JOB_USER}" --runner-master "${JOB_MASTER}" --runner-executor-memory "${SPARK_MEM_PARAM}" >& "${JOB_OUTPUT}" || notify_error_and_exit "Execution failed for job ${JOB_WITH_TAG}" fi touch "${JOB_CONTROL_DIR}/SUCCESS" diff --git a/src/main/scala/ignition/core/jobs/CoreJobRunner.scala b/src/main/scala/ignition/core/jobs/CoreJobRunner.scala index aa4dcc76..bbede553 100644 --- a/src/main/scala/ignition/core/jobs/CoreJobRunner.scala +++ b/src/main/scala/ignition/core/jobs/CoreJobRunner.scala @@ -13,9 +13,14 @@ object CoreJobRunner { // Used to provide contextual logging def setLoggingContextValues(config: RunnerConfig): Unit = { - org.slf4j.MDC.put("setupName", config.setupName) - org.slf4j.MDC.put("tag", config.tag) - org.slf4j.MDC.put("user", config.user) + try { // yes, this may fail but we don't want everything to shut down + org.slf4j.MDC.put("setupName", config.setupName) + org.slf4j.MDC.put("tag", config.tag) + org.slf4j.MDC.put("user", config.user) + } catch { + case e: Throwable => + // cry + } } case class RunnerConfig(setupName: String = "nosetup", diff --git a/src/main/scala/ignition/core/jobs/utils/RDDUtils.scala b/src/main/scala/ignition/core/jobs/utils/RDDUtils.scala index fc42ded5..60bddc9a 100644 --- a/src/main/scala/ignition/core/jobs/utils/RDDUtils.scala +++ b/src/main/scala/ignition/core/jobs/utils/RDDUtils.scala @@ -57,6 +57,8 @@ object RDDUtils { def 
incrementCounterIf(cond: (V) => Boolean, acc: spark.Accumulator[Int]): RDD[V] = { rdd.map(x => { if (cond(x)) acc += 1; x }) } + + def filterNot(p: V => Boolean): RDD[V] = rdd.filter(!p(_)) } implicit class PairRDDImprovements[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) { @@ -80,11 +82,15 @@ object RDDUtils { }, preservesPartitioning = true) } + def collectValues[U: ClassTag](f: PartialFunction[V, U]): RDD[(K, U)] = { + rdd.filter { case (k, v) => f.isDefinedAt(v) }.mapValues(f) + } + def groupByKeyAndTake(n: Int): RDD[(K, List[V])] = rdd.aggregateByKey(List.empty[V])( (lst, v) => if (lst.size >= n) { - logger.warn(s"Ignoring value '$v' due aggregation result of size '${lst.size}' is bigger then n = '$n'") + logger.warn(s"Ignoring value '$v' due aggregation result of size '${lst.size}' is bigger than n=$n") lst } else { v :: lst diff --git a/src/main/scala/ignition/core/jobs/utils/SparkContextUtils.scala b/src/main/scala/ignition/core/jobs/utils/SparkContextUtils.scala index 29c32112..4eab7baf 100644 --- a/src/main/scala/ignition/core/jobs/utils/SparkContextUtils.scala +++ b/src/main/scala/ignition/core/jobs/utils/SparkContextUtils.scala @@ -1,23 +1,58 @@ package ignition.core.jobs.utils -import java.util.Date - -import ignition.core.utils.ByteUtils -import org.apache.hadoop.io.LongWritable -import org.apache.spark.SparkContext -import org.apache.hadoop.fs.{FileStatus, Path, FileSystem} -import org.apache.spark.rdd.{UnionRDD, RDD} -import org.joda.time.{DateTimeZone, DateTime} +import java.io.InputStream + +import com.amazonaws.auth.EnvironmentVariableCredentialsProvider +import com.amazonaws.services.s3.AmazonS3Client +import com.amazonaws.services.s3.model.{ListObjectsRequest, ObjectListing, S3ObjectSummary} import ignition.core.utils.DateUtils._ +import ignition.core.utils.{AutoCloseableIterator, ByteUtils, HadoopUtils} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import 
org.apache.hadoop.io.compress.CompressionCodecFactory +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.{Partitioner, SparkContext} +import org.joda.time.DateTime +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.io.{Codec, Source} import scala.reflect.ClassTag -import scala.util.Try +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal +import ignition.core.utils.ExceptionUtils._ object SparkContextUtils { + private case class BigFileSlice(index: Int) + + private case class HadoopFilePartition(size: Long, paths: Seq[String]) + + private case class IndexedPartitioner(numPartitions: Int, index: Map[Any, Int]) extends Partitioner { + override def getPartition(key: Any): Int = index(key) + } + + private lazy val amazonS3ClientFromEnvironmentVariables = new AmazonS3Client(new EnvironmentVariableCredentialsProvider()) + + private def close(inputStream: InputStream, path: String): Unit = { + try { + inputStream.close() + } catch { + case NonFatal(ex) => + println(s"Fail to close resource from '$path': ${ex.getMessage} -- ${ex.getFullStackTraceString}") + } + } + + case class HadoopFile(path: String, isDir: Boolean, size: Long) + implicit class SparkContextImprovements(sc: SparkContext) { + lazy val _hadoopConf = sc.broadcast(sc.hadoopConfiguration.iterator().map { case entry => entry.getKey -> entry.getValue }.toMap) + private def getFileSystem(path: Path): FileSystem = { path.getFileSystem(sc.hadoopConfiguration) } @@ -28,7 +63,7 @@ object SparkContextUtils { for { path <- paths status <- Option(fs.globStatus(path)).getOrElse(Array.empty).toSeq - if status.isDirectory || !removeEmpty || status.getLen > 0 // remove empty files if necessary + if !removeEmpty || status.getLen > 0 || status.isDirectory // remove empty files if necessary 
} yield status } @@ -52,7 +87,7 @@ object SparkContextUtils { if (splittedPaths.size < minimumPaths) throw new Exception(s"Not enough paths found for $paths") - val rdds = splittedPaths.grouped(50).map(pathGroup => f(pathGroup.mkString(","))) + val rdds = splittedPaths.grouped(5000).map(pathGroup => f(pathGroup.mkString(","))) new UnionRDD(sc, rdds.toList) } @@ -95,7 +130,6 @@ object SparkContextUtils { } - def getFilteredPaths(paths: Seq[String], requireSuccess: Boolean, inclusiveStartDate: Boolean, @@ -108,7 +142,6 @@ object SparkContextUtils { filterPaths(paths, requireSuccess, inclusiveStartDate, startDate, inclusiveEndDate, endDate, lastN, ignoreMalformedDates) } - lazy val hdfsPathPrefix = sc.master.replaceFirst("spark://(.*):7077", "hdfs://$1:9000/") def synchToHdfs(paths: Seq[String], pathsToRdd: (Seq[String], Int) => RDD[String], forceSynch: Boolean): Seq[String] = { @@ -130,6 +163,15 @@ object SparkContextUtils { } + @deprecated("It may incur heavy S3 costs and/or be slow with small files, use getParallelTextFiles instead", "2015-10-27") + def getTextFiles(paths: Seq[String], synchLocally: Boolean = false, forceSynch: Boolean = false, minimumPaths: Int = 1): RDD[String] = { + if (synchLocally) + processTextFiles(synchToHdfs(paths, processTextFiles, forceSynch), minimumPaths) + else + processTextFiles(paths, minimumPaths) + } + + @deprecated("It may incur heavy S3 costs and/or be slow with small files, use filterAndGetParallelTextFiles instead", "2015-10-27") def filterAndGetTextFiles(path: String, requireSuccess: Boolean = false, inclusiveStartDate: Boolean = true, @@ -144,10 +186,7 @@ object SparkContextUtils { val paths = getFilteredPaths(Seq(path), requireSuccess, inclusiveStartDate, startDate, inclusiveEndDate, endDate, lastN, ignoreMalformedDates) if (paths.size < minimumPaths) throw new Exception(s"Tried with start/end time equals to $startDate/$endDate for path $path but but the resulting number of paths $paths is less than the required") - else if 
(synchLocally) - processTextFiles(synchToHdfs(paths, processTextFiles, forceSynch), minimumPaths) - else - processTextFiles(paths, minimumPaths) + getTextFiles(paths, synchLocally, forceSynch, minimumPaths) } private def stringHadoopFile(paths: Seq[String], minimumPaths: Int): RDD[Try[String]] = { @@ -190,5 +229,453 @@ object SparkContextUtils { else objectHadoopFile(paths, minimumPaths) } + + case class SizeBasedFileHandling(averageEstimatedCompressionRatio: Int = 8, + compressedExtensions: Set[String] = Set(".gz")) { + + def isBig(f: HadoopFile, uncompressedBigSize: Long): Boolean = estimatedSize(f) >= uncompressedBigSize + + def estimatedSize(f: HadoopFile) = if (isCompressed(f)) + f.size * averageEstimatedCompressionRatio + else + f.size + + def isCompressed(f: HadoopFile): Boolean = compressedExtensions.exists(f.path.endsWith) + } + + + private def readSmallFiles(smallFiles: List[HadoopFile], + maxBytesPerPartition: Long, + minPartitions: Int, + sizeBasedFileHandling: SizeBasedFileHandling): RDD[String] = { + val smallPartitionedFiles = sc.parallelize(smallFiles.map(_.path).map(file => file -> null), 2).partitionBy(createSmallFilesPartitioner(smallFiles, maxBytesPerPartition, minPartitions, sizeBasedFileHandling)) + val hadoopConf = _hadoopConf + smallPartitionedFiles.mapPartitions { files => + val conf = hadoopConf.value.foldLeft(new Configuration()) { case (acc, (k, v)) => acc.set(k, v); acc } + val codecFactory = new CompressionCodecFactory(conf) + files.map { case (path, _) => path } flatMap { path => + val hadoopPath = new Path(path) + val fileSystem = hadoopPath.getFileSystem(conf) + val inputStream = Option(codecFactory.getCodec(hadoopPath)) match { + case Some(compression) => compression.createInputStream(fileSystem.open(hadoopPath)) + case None => fileSystem.open(hadoopPath) + } + try { + Source.fromInputStream(inputStream)(Codec.UTF8).getLines().foldLeft(ArrayBuffer.empty[String])(_ += _) + } catch { + case NonFatal(ex) => + println(s"Failed to read 
resource from '$path': ${ex.getMessage} -- ${ex.getFullStackTraceString}") + throw new Exception(s"Failed to read resource from '$path': ${ex.getMessage} -- ${ex.getFullStackTraceString}") + } finally { + close(inputStream, path) + } + } + } + } + + private def readCompressedBigFile(file: HadoopFile, maxBytesPerPartition: Long, minPartitions: Int, + sizeBasedFileHandling: SizeBasedFileHandling, sampleCount: Int = 100): RDD[String] = { + val estimatedSize = sizeBasedFileHandling.estimatedSize(file) + val totalSlices = (estimatedSize / maxBytesPerPartition + 1).toInt + val slices = (0 until totalSlices).map(BigFileSlice.apply) + + val partitioner = { + val indexedPartitions: Map[Any, Int] = slices.map(s => s -> s.index).toMap + IndexedPartitioner(totalSlices, indexedPartitions) + } + val hadoopConf = _hadoopConf + + val partitionedSlices = sc.parallelize(slices.map(s => s -> null), 2).partitionBy(partitioner) + + partitionedSlices.mapPartitions { slices => + val conf = hadoopConf.value.foldLeft(new Configuration()) { case (acc, (k, v)) => acc.set(k, v); acc } + val codecFactory = new CompressionCodecFactory(conf) + val hadoopPath = new Path(file.path) + val fileSystem = hadoopPath.getFileSystem(conf) + slices.flatMap { case (slice, _) => + val inputStream = Option(codecFactory.getCodec(hadoopPath)) match { + case Some(compression) => compression.createInputStream(fileSystem.open(hadoopPath)) + case None => fileSystem.open(hadoopPath) + } + val lines = Source.fromInputStream(inputStream)(Codec.UTF8).getLines() + + val lineSample = lines.take(sampleCount).toList + val linesPerSlice = { + val sampleSize = lineSample.map(_.size).sum + val estimatedAverageLineSize = Math.round(sampleSize / sampleCount.toFloat) + val estimatedTotalLines = Math.round(estimatedSize / estimatedAverageLineSize.toFloat) + estimatedTotalLines / totalSlices + 1 + } + + val linesAfterSeek = (lineSample.toIterator ++ lines).drop(linesPerSlice * slice.index) + + val finalLines = if (slice.index + 1 
== totalSlices) // last slice, read until the end + linesAfterSeek + else + linesAfterSeek.take(linesPerSlice) + + AutoCloseableIterator.wrap(finalLines, () => close(inputStream, s"${file.path}, slice $slice")) + } + } + } + + private def readBigFiles(bigFiles: List[HadoopFile], + maxBytesPerPartition: Long, + minPartitions: Int, + sizeBasedFileHandling: SizeBasedFileHandling): RDD[String] = { + def confWith(maxSplitSize: Long): Configuration = (_hadoopConf.value ++ Seq( + "mapreduce.input.fileinputformat.split.maxsize" -> maxSplitSize.toString)) + .foldLeft(new Configuration()) { case (acc, (k, v)) => acc.set(k, v); acc } + + def read(file: HadoopFile, conf: Configuration) = sc.newAPIHadoopFile[LongWritable, Text, TextInputFormat](conf = conf, fClass = classOf[TextInputFormat], + kClass = classOf[LongWritable], vClass = classOf[Text], path = file.path).map(pair => pair._2.toString) + + val confUncompressed = confWith(maxBytesPerPartition) + + val union = new UnionRDD(sc, bigFiles.map { file => + + if (sizeBasedFileHandling.isCompressed(file)) + readCompressedBigFile(file, maxBytesPerPartition, minPartitions, sizeBasedFileHandling) + else + read(file, confUncompressed) + }) + + if (union.partitions.size < minPartitions) + union.coalesce(minPartitions) + else + union + } + + def parallelListAndReadTextFiles(paths: List[String], + maxBytesPerPartition: Long, + minPartitions: Int, + sizeBasedFileHandling: SizeBasedFileHandling = SizeBasedFileHandling()) + (implicit dateExtractor: PathDateExtractor): RDD[String] = { + val foundFiles = paths.flatMap(smartList(_)) + parallelReadTextFiles(foundFiles, maxBytesPerPartition = maxBytesPerPartition, minPartitions = minPartitions, sizeBasedFileHandling = sizeBasedFileHandling) + } + + def parallelReadTextFiles(files: List[HadoopFile], + maxBytesPerPartition: Long = 256 * 1000 * 1000, + minPartitions: Int = 100, + sizeBasedFileHandling: SizeBasedFileHandling = SizeBasedFileHandling(), + synchLocally: Option[String] = None, + 
forceSynch: Boolean = false): RDD[String] = { + val filteredFiles = files.filter(_.size > 0) + if (synchLocally.isDefined) + doSync(filteredFiles, maxBytesPerPartition = maxBytesPerPartition, minPartitions = minPartitions, synchLocally = synchLocally.get, + sizeBasedFileHandling = sizeBasedFileHandling, forceSynch = forceSynch) + else { + val (bigFiles, smallFiles) = filteredFiles.partition(f => sizeBasedFileHandling.isBig(f, maxBytesPerPartition)) + sc.union( + readSmallFiles(smallFiles, maxBytesPerPartition, minPartitions, sizeBasedFileHandling), + readBigFiles(bigFiles, maxBytesPerPartition, minPartitions, sizeBasedFileHandling)) + } + } + + private def createSmallFilesPartitioner(files: List[HadoopFile], maxBytesPerPartition: Long, minPartitions: Long, sizeBasedFileHandling: SizeBasedFileHandling): Partitioner = { + implicit val ordering: Ordering[HadoopFilePartition] = Ordering.by(p => -p.size) // Small partitions come first (highest priority) + + val pq: mutable.PriorityQueue[HadoopFilePartition] = mutable.PriorityQueue.empty + + (0L until minPartitions).foreach(_ => pq += HadoopFilePartition(0, Seq.empty)) + + val partitions = files.foldLeft(pq) { + case (acc, file) => + val fileSize = sizeBasedFileHandling.estimatedSize(file) + + acc.headOption.filter(bucket => bucket.size + fileSize < maxBytesPerPartition) match { + case Some(found) => + val updated = found.copy(size = found.size + fileSize, paths = file.path +: found.paths) + acc.tail += updated + case None => acc += HadoopFilePartition(fileSize, Seq(file.path)) + } + }.filter(_.paths.nonEmpty).toList // Remove empty partitions + + val indexedPartitions: Map[Any, Int] = partitions.zipWithIndex.flatMap { + case (bucket, index) => bucket.paths.map(path => path -> index) + }.toMap + + IndexedPartitioner(partitions.size, indexedPartitions) + } + + private def executeDriverList(paths: Seq[String]): List[HadoopFile] = { + val conf = _hadoopConf.value.foldLeft(new Configuration()) { case (acc, (k, v)) => 
acc.set(k, v); acc } + paths.flatMap { path => + val hadoopPath = new Path(path) + val fileSystem = hadoopPath.getFileSystem(conf) + val tryFind = try { + val status = fileSystem.getFileStatus(hadoopPath) + if (status.isDirectory) { + val sanitize = Option(fileSystem.listStatus(hadoopPath)).getOrElse(Array.empty) + Option(sanitize.map(status => HadoopFile(status.getPath.toString, status.isDirectory, status.getLen)).toList) + } else if (status.isFile) { + Option(List(HadoopFile(status.getPath.toString, status.isDirectory, status.getLen))) + } else { + None + } + } catch { + case e: java.io.FileNotFoundException => + None + } + + tryFind.getOrElse { + // Maybe is glob or not found + val sanitize = Option(fileSystem.globStatus(hadoopPath)).getOrElse(Array.empty) + sanitize.map(status => HadoopFile(status.getPath.toString, status.isDirectory, status.getLen)).toList + } + }.toList + } + + private def driverListFiles(path: String): List[HadoopFile] = { + def innerListFiles(remainingDirectories: List[HadoopFile]): List[HadoopFile] = { + if (remainingDirectories.isEmpty) { + Nil + } else { + val (dirs, files) = executeDriverList(remainingDirectories.map(_.path)).partition(_.isDir) + files ++ innerListFiles(dirs) + } + } + innerListFiles(List(HadoopFile(path, isDir = true, 0))) + } + + def s3ListCommonPrefixes(bucket: String, prefix: String, delimiter: String = "/") + (implicit s3: AmazonS3Client): Stream[String] = { + def inner(current: ObjectListing): Stream[String] = + if (current.isTruncated) + current.getCommonPrefixes.toStream ++ inner(s3.listNextBatchOfObjects(current)) + else + current.getCommonPrefixes.toStream + + val request = new ListObjectsRequest(bucket, prefix, null, delimiter, 1000) + inner(s3.listObjects(request)) + } + + def s3ListObjects(bucket: String, prefix: String) + (implicit s3: AmazonS3Client): Stream[S3ObjectSummary] = { + def inner(current: ObjectListing): Stream[S3ObjectSummary] = + if (current.isTruncated) + current.getObjectSummaries.toStream 
++ inner(s3.listNextBatchOfObjects(current)) + else + current.getObjectSummaries.toStream + + inner(s3.listObjects(bucket, prefix)) + } + + def s3NarrowPaths(bucket: String, + prefix: String, + delimiter: String = "/", + inclusiveStartDate: Boolean = true, + startDate: Option[DateTime] = None, + inclusiveEndDate: Boolean = true, + endDate: Option[DateTime] = None, + ignoreHours: Boolean = true) + (implicit s3: AmazonS3Client, pathDateExtractor: PathDateExtractor): Stream[String] = { + + def isGoodDate(date: DateTime): Boolean = { + val startDateToCompare = startDate.map(date => if (ignoreHours) date.withTimeAtStartOfDay() else date) + val endDateToCompare = endDate.map(date => if (ignoreHours) date.withTime(23, 59, 59, 999) else date) + val goodStartDate = startDateToCompare.isEmpty || (inclusiveStartDate && date.saneEqual(startDateToCompare.get) || date.isAfter(startDateToCompare.get)) + val goodEndDate = endDateToCompare.isEmpty || (inclusiveEndDate && date.saneEqual(endDateToCompare.get) || date.isBefore(endDateToCompare.get)) + goodStartDate && goodEndDate + } + + def classifyPath(path: String): Either[String, (String, DateTime)] = + Try(pathDateExtractor.extractFromPath(s"s3n://$bucket/$path")) match { + case Success(date) => Right(path -> date) + case Failure(_) => Left(path) + } + + val commonPrefixes = s3ListCommonPrefixes(bucket, prefix, delimiter).map(classifyPath) + + if (commonPrefixes.isEmpty) + Stream(s"s3n://$bucket/$prefix") + else + commonPrefixes.toStream.flatMap { + case Left(prefixWithoutDate) => s3NarrowPaths(bucket, prefixWithoutDate, delimiter, inclusiveStartDate, startDate, inclusiveEndDate, endDate, ignoreHours) + case Right((prefixWithDate, date)) if isGoodDate(date) => Stream(s"s3n://$bucket/$prefixWithDate") + case Right(_) => Stream.empty + } + } + + private def s3List(path: String, + inclusiveStartDate: Boolean, + startDate: Option[DateTime], + inclusiveEndDate: Boolean, + endDate: Option[DateTime], + exclusionPattern: Option[String]) 
+ (implicit s3: AmazonS3Client, dateExtractor: PathDateExtractor): Stream[S3ObjectSummary] = { + + val s3Pattern = "s3n?://([^/]+)(.+)".r + + def extractBucketAndPrefix(path: String): Option[(String, String)] = path match { + case s3Pattern(bucket, prefix) => Option(bucket -> prefix.dropWhile(_ == '/')) + case _ => None + } + + extractBucketAndPrefix(path) match { + case Some((pathBucket, pathPrefix)) => + s3NarrowPaths(pathBucket, pathPrefix, inclusiveStartDate = inclusiveStartDate, inclusiveEndDate = inclusiveEndDate, + startDate = startDate, endDate = endDate).flatMap(extractBucketAndPrefix).flatMap { + case (bucket, prefix) => s3ListObjects(bucket, prefix) + } + case _ => Stream.empty + } + } + + def listAndFilterFiles(path: String, + requireSuccess: Boolean = false, + inclusiveStartDate: Boolean = true, + startDate: Option[DateTime] = None, + inclusiveEndDate: Boolean = true, + endDate: Option[DateTime] = None, + lastN: Option[Int] = None, + ignoreMalformedDates: Boolean = false, + endsWith: Option[String] = None, + exclusionPattern: Option[String] = Option(".*_temporary.*|.*_\\$folder.*"), + predicate: HadoopFile => Boolean = _ => true) + (implicit dateExtractor: PathDateExtractor): List[HadoopFile] = { + + def isSuccessFile(file: HadoopFile): Boolean = + file.path.endsWith("_SUCCESS") || file.path.endsWith("_FINISHED") + + def extractDateFromFile(file: HadoopFile): Option[DateTime] = + Try(dateExtractor.extractFromPath(file.path)).toOption + + def excludePatternValidation(file: HadoopFile): Option[HadoopFile] = + exclusionPattern match { + case Some(pattern) if file.path.matches(pattern) => None + case Some(_) | None => Option(file) + } + + def endsWithValidation(file: HadoopFile): Option[HadoopFile] = + endsWith match { + case Some(pattern) if file.path.endsWith(pattern) => Option(file) + case Some(_) if isSuccessFile(file) => Option(file) + case Some(_) => None + case None => Option(file) + } + + def applyPredicate(file: HadoopFile): Option[HadoopFile] = + 
if (predicate(file)) Option(file) else None + + def dateValidation(file: HadoopFile): Option[HadoopFile] = { + val tryDate = extractDateFromFile(file) + if (tryDate.isEmpty && ignoreMalformedDates) + None + else { + val date = tryDate.get + val goodStartDate = startDate.isEmpty || (inclusiveStartDate && date.saneEqual(startDate.get) || date.isAfter(startDate.get)) + val goodEndDate = endDate.isEmpty || (inclusiveEndDate && date.saneEqual(endDate.get) || date.isBefore(endDate.get)) + if (goodStartDate && goodEndDate) Some(file) else None + } + } + + val preValidations: HadoopFile => Boolean = hadoopFile => { + val validatedFile = for { + _ <- excludePatternValidation(hadoopFile) + _ <- endsWithValidation(hadoopFile) + _ <- dateValidation(hadoopFile) + valid <- applyPredicate(hadoopFile) + } yield valid + validatedFile.isDefined + } + + val preFilteredFiles = smartList(path, inclusiveStartDate = inclusiveStartDate, inclusiveEndDate = inclusiveEndDate, + startDate = startDate, endDate = endDate, exclusionPattern = exclusionPattern).filter(preValidations) + + val filesByDate = preFilteredFiles.groupBy(extractDateFromFile).collect { + case (Some(date), files) => date -> files + } + + val posFilteredFiles = + if (requireSuccess) + filesByDate.filter { case (_, files) => files.exists(isSuccessFile) } + else + filesByDate + + val allFiles = if (lastN.isDefined) + posFilteredFiles.toList.sortBy(_._1).reverse.take(lastN.get).flatMap(_._2) + else + posFilteredFiles.toList.flatMap(_._2) + + allFiles.sortBy(_.path) + } + + def smartList(path: String, + inclusiveStartDate: Boolean = false, + startDate: Option[DateTime] = None, + inclusiveEndDate: Boolean = false, + endDate: Option[DateTime] = None, + exclusionPattern: Option[String] = None)(implicit pathDateExtractor: PathDateExtractor): Stream[HadoopFile] = { + + def toHadoopFile(s3Object: S3ObjectSummary): HadoopFile = + HadoopFile(s"s3n://${s3Object.getBucketName}/${s3Object.getKey}", isDir = false, s3Object.getSize) + + def 
listPath(path: String): Stream[HadoopFile] = { + if (path.startsWith("s3")) { + s3List(path, inclusiveStartDate = inclusiveStartDate, startDate = startDate, inclusiveEndDate = inclusiveEndDate, + endDate = endDate, exclusionPattern = exclusionPattern)(amazonS3ClientFromEnvironmentVariables, pathDateExtractor ).map(toHadoopFile) + } else { + driverListFiles(path).toStream + } + } + + HadoopUtils.getPathStrings(path).toStream.flatMap(listPath) + } + + def filterAndGetParallelTextFiles(path: String, + requireSuccess: Boolean = false, + inclusiveStartDate: Boolean = true, + startDate: Option[DateTime] = None, + inclusiveEndDate: Boolean = true, + endDate: Option[DateTime] = None, + lastN: Option[Int] = None, + ignoreMalformedDates: Boolean = false, + endsWith: Option[String] = None, + predicate: HadoopFile => Boolean = _ => true, + maxBytesPerPartition: Long = 256 * 1000 * 1000, + minPartitions: Int = 100, + sizeBasedFileHandling: SizeBasedFileHandling = SizeBasedFileHandling(), + minimumFiles: Int = 1, + synchLocally: Option[String] = None, + forceSynch: Boolean = false) + (implicit dateExtractor: PathDateExtractor): RDD[String] = { + + val foundFiles = listAndFilterFiles(path, requireSuccess, inclusiveStartDate, startDate, inclusiveEndDate, + endDate, lastN, ignoreMalformedDates, endsWith, predicate = predicate) + + if (foundFiles.size < minimumFiles) + throw new Exception(s"Tried with start/end time equals to $startDate/$endDate for path $path but but the resulting number of files $foundFiles is less than the required") + + parallelReadTextFiles(foundFiles, maxBytesPerPartition = maxBytesPerPartition, minPartitions = minPartitions, + sizeBasedFileHandling = sizeBasedFileHandling, synchLocally = synchLocally, forceSynch = forceSynch) + } + + private def doSync(hadoopFiles: List[HadoopFile], + synchLocally: String, + forceSynch: Boolean, + maxBytesPerPartition: Long = 256 * 1000 * 1000, + minPartitions: Int = 100, + sizeBasedFileHandling: SizeBasedFileHandling = 
SizeBasedFileHandling()): RDD[String] = { + require(!synchLocally.contains("*"), "Globs are not supported on the sync key") + + def syncPath(suffix: String) = s"$hdfsPathPrefix/_core_ignition_sync_hdfs_cache/$suffix" + + val hashKey = Integer.toHexString(hadoopFiles.toSet.hashCode()) + + lazy val foundLocalPaths = getStatus(syncPath(s"$synchLocally/$hashKey/{_SUCCESS,_FINISHED}"), removeEmpty = false) + + val cacheKey = syncPath(s"$synchLocally/$hashKey") + + if (forceSynch || foundLocalPaths.isEmpty) { + delete(new Path(syncPath(s"$synchLocally/"))) + val data = parallelReadTextFiles(hadoopFiles, maxBytesPerPartition, minPartitions, synchLocally = None) + data.saveAsTextFile(cacheKey) + } + + sc.textFile(cacheKey) + } + } } diff --git a/src/main/scala/ignition/core/utils/AutoCloseableIterator.scala b/src/main/scala/ignition/core/utils/AutoCloseableIterator.scala new file mode 100644 index 00000000..4e3db808 --- /dev/null +++ b/src/main/scala/ignition/core/utils/AutoCloseableIterator.scala @@ -0,0 +1,67 @@ +package ignition.core.utils + +import scala.util.Try +import scala.util.control.NonFatal + +object AutoCloseableIterator { + case object empty extends AutoCloseableIterator[Nothing] { + override def naiveHasNext() = false + override def naiveNext() = throw new Exception("Empty AutoCloseableIterator") + override def naiveClose() = {} + } + + def wrap[T](iterator: Iterator[T], doClose: () => Unit = () => ()): AutoCloseableIterator[T] = new AutoCloseableIterator[T] { + override def naiveClose(): Unit = doClose() + override def naiveHasNext(): Boolean = iterator.hasNext + override def naiveNext(): T = iterator.next() + } +} + +trait AutoCloseableIterator[T] extends Iterator[T] with AutoCloseable { + // Naive functions should be implemented by the user as in a standard Iterator/AutoCloseable + def naiveHasNext(): Boolean + def naiveNext(): T + def naiveClose(): Unit + + var closed = false + + // hasNext closes the iterator and handles the case where it is already 
closed + override def hasNext: Boolean = if (closed) + false + else { + val naiveResult = try { + naiveHasNext + } catch { + case NonFatal(e) => + Try { close } + throw e + } + if (naiveResult) + true + else { + close // auto close when exhausted + false + } + } + + // next closes the iterator and handles the case where it is already closed + override def next(): T = if (closed) + throw new RuntimeException("Trying to get next element on a closed iterator") + else if (hasNext) + try { + naiveNext + } catch { + case NonFatal(e) => + Try { close } + throw e + } + else + throw new RuntimeException("Trying to get next element on an exhausted iterator") + + override def close() = if (!closed) { + closed = true + naiveClose + } + + override def finalize() = Try { close } +} diff --git a/src/main/scala/ignition/core/utils/BetterTrace.scala b/src/main/scala/ignition/core/utils/BetterTrace.scala new file mode 100644 index 00000000..387f49f7 --- /dev/null +++ b/src/main/scala/ignition/core/utils/BetterTrace.scala @@ -0,0 +1,14 @@ +package ignition.core.utils + +import ignition.core.utils.ExceptionUtils._ +// Used mainly to augment scalacheck traces in scalatest +trait BetterTrace { + def fail(message: String): Nothing + def withBetterTrace(block: => Unit): Unit = + try { + block + } catch { + case t: Throwable => fail(s"${t.getMessage}: ${t.getFullStackTraceString}") + } + +} diff --git a/src/main/scala/ignition/core/utils/CollectionUtils.scala b/src/main/scala/ignition/core/utils/CollectionUtils.scala index 27977270..f98fb7ec 100644 --- a/src/main/scala/ignition/core/utils/CollectionUtils.scala +++ b/src/main/scala/ignition/core/utils/CollectionUtils.scala @@ -6,7 +6,32 @@ import scalaz.Validation object CollectionUtils { + + + implicit class SeqImprovements[A](xs: Seq[A]) { + def orElseIfEmpty[B >: A](alternative: => Seq[B]): Seq[B] = { + if (xs.nonEmpty) + xs + else + alternative + } + } + implicit class TraversableOnceImprovements[A](xs: TraversableOnce[A]) { + def 
maxOption(implicit cmp: Ordering[A]): Option[A] = { + if (xs.isEmpty) + None + else + Option(xs.max) + } + + def minOption(implicit cmp: Ordering[A]): Option[A] = { + if (xs.isEmpty) + None + else + Option(xs.min) + } + def maxByOption[B](f: A => B)(implicit cmp: Ordering[B]): Option[A] = { if (xs.isEmpty) None @@ -22,6 +47,12 @@ object CollectionUtils { } } + + + implicit class TraversableOnceLong(xs: TraversableOnce[Long]) { + def toBag(): IntBag = IntBag.from(xs) + } + implicit class TraversableLikeImprovements[A, Repr](xs: TraversableLike[A, Repr]) { def distinctBy[B, That](f: A => B)(implicit cbf: CanBuildFrom[Repr, A, That]) = { val builder = cbf(xs.repr) @@ -59,6 +90,7 @@ object CollectionUtils { builder.result } + } implicit class ValidatedIterableLike[T, R, Repr <: IterableLike[Validation[R, T], Repr]](seq: IterableLike[Validation[R, T], Repr]) { @@ -103,4 +135,10 @@ object CollectionUtils { .toList } } + + + implicit class CollectionMap[K, V <: TraversableOnce[Any]](map: Map[K, V]) { + def removeEmpty(): Map[K, V] = + map.filter { case (k, v) => v.nonEmpty } + } } diff --git a/src/main/scala/ignition/core/utils/DateUtils.scala b/src/main/scala/ignition/core/utils/DateUtils.scala index 231817c7..8ebf3b13 100644 --- a/src/main/scala/ignition/core/utils/DateUtils.scala +++ b/src/main/scala/ignition/core/utils/DateUtils.scala @@ -1,6 +1,6 @@ package ignition.core.utils -import org.joda.time.{Period, DateTimeZone, DateTime} +import org.joda.time.{Seconds, Period, DateTimeZone, DateTime} import org.joda.time.format.ISODateTimeFormat object DateUtils { @@ -20,5 +20,16 @@ object DateUtils { def isEqualOrBefore(other: DateTime) = dateTime.isBefore(other) || dateTime.saneEqual(other) + + def isBetween(start: DateTime, end: DateTime) = + dateTime.isAfter(start) && dateTime.isEqualOrBefore(end) + } + + implicit class SecondsImprovements(val seconds: Seconds) { + + implicit def toScalaDuration: scala.concurrent.duration.FiniteDuration = { + 
scala.concurrent.duration.Duration(seconds.getSeconds, scala.concurrent.duration.SECONDS) + } + } } diff --git a/src/main/scala/ignition/core/utils/ExceptionUtils.scala b/src/main/scala/ignition/core/utils/ExceptionUtils.scala new file mode 100644 index 00000000..1ae33568 --- /dev/null +++ b/src/main/scala/ignition/core/utils/ExceptionUtils.scala @@ -0,0 +1,9 @@ +package ignition.core.utils + +object ExceptionUtils { + + implicit class ExceptionImprovements(e: Throwable) { + def getFullStackTraceString(): String = org.apache.commons.lang.exception.ExceptionUtils.getFullStackTrace(e) + } + +} diff --git a/src/main/scala/ignition/core/utils/FutureUtils.scala b/src/main/scala/ignition/core/utils/FutureUtils.scala index 068d63bc..4523a94f 100644 --- a/src/main/scala/ignition/core/utils/FutureUtils.scala +++ b/src/main/scala/ignition/core/utils/FutureUtils.scala @@ -1,14 +1,33 @@ package ignition.core.utils -import scala.concurrent.{ExecutionContext, Future, Promise} -import scala.util.{Failure, Success} +import scala.concurrent.{ExecutionContext, Future, Promise, blocking, future} +import scala.util.control.NonFatal +import scala.util.{Failure, Success, Try} object FutureUtils { + def blockingFuture[T](body: =>T)(implicit ec: ExecutionContext): Future[T] = Future { blocking { body } } + implicit class FutureImprovements[V](future: Future[V]) { def toOptionOnFailure(errorHandler: (Throwable) => Option[V])(implicit ec: ExecutionContext): Future[Option[V]] = { future.map(Option.apply).recover { case t => errorHandler(t) } } + + /** + * Appear to be redundant. But its the only way to map a future with + * Success and Failure in same algorithm without split it to use map/recover + * or transform. 
+ * + * future.asTry.map { case Success(v) => 1; case Failure(e) => 0 } + * + * instead + * + * future.map(i=>1).recover(case _: Exception => 0) + * + */ + def asTry()(implicit ec: ExecutionContext) : Future[Try[V]] = { + future.map(v => Success(v)).recover { case NonFatal(e) => Failure(e) } + } } implicit class FutureGeneratorImprovements[V](generator: Iterable[() => Future[V]]){ diff --git a/src/main/scala/ignition/core/utils/IntBag.scala b/src/main/scala/ignition/core/utils/IntBag.scala new file mode 100644 index 00000000..a322f6f7 --- /dev/null +++ b/src/main/scala/ignition/core/utils/IntBag.scala @@ -0,0 +1,42 @@ +package ignition.core.utils + +object IntBag { + def from(numbers: TraversableOnce[Long]): IntBag = { + val histogram = scala.collection.mutable.HashMap.empty[Long, Long] + numbers.foreach(n => histogram += (n -> (histogram.getOrElse(n, 0L) + 1))) + IntBag(histogram) + } + + val empty = from(Seq.empty) +} + +case class IntBag(histogram: collection.Map[Long, Long]) { + def ++(other: IntBag): IntBag = { + val newHistogram = scala.collection.mutable.HashMap.empty[Long, Long] + (histogram.keySet ++ other.histogram.keySet).foreach(k => newHistogram += (k -> (histogram.getOrElse(k, 0L) + other.histogram.getOrElse(k, 0L)))) + new IntBag(newHistogram) + } + + + def median: Option[Long] = { + if (histogram.nonEmpty) { + val total = histogram.values.sum + val half = total / 2 + val max = histogram.keys.max + + val accumulatedFrequency = (0L to max).scanLeft(0L) { case (sumFreq, k) => sumFreq + histogram.getOrElse(k, 0L) }.zipWithIndex + accumulatedFrequency.collectFirst { case (sum, k) if sum >= half => k } + } else { + None + } + } + + def avg: Option[Long] = { + if (histogram.nonEmpty) { + val sum = histogram.map { case (k, f) => k * f }.sum + val count = histogram.values.sum + Option(sum / count) + } else + None + } +} diff --git a/src/main/scala/ignition/core/utils/S3Client.scala b/src/main/scala/ignition/core/utils/S3Client.scala index f02d7acd..b806b376 
100644 --- a/src/main/scala/ignition/core/utils/S3Client.scala +++ b/src/main/scala/ignition/core/utils/S3Client.scala @@ -26,9 +26,9 @@ class S3Client { null, null, jets3tProperties ) - def writeContent(bucket: String, key: String, content: String): S3Object = { + def writeContent(bucket: String, key: String, content: String, contentType: String = "text/plain"): S3Object = { val obj = new S3Object(key, content) - obj.setContentType("text/plain") + obj.setContentType(contentType) service.putObject(bucket, obj) } @@ -37,7 +37,18 @@ class S3Client { } def list(bucket: String, key: String): Array[S3Object] = { - service.listObjects(bucket, key, null, 99999L) + service.listObjectsChunked(bucket, key, null, 99999L, null, true).getObjects + } + + def copyFile(sourceBucket: String, sourceKey: String, + destBucket: String, destKey: String, + destContentType: Option[String] = None, + destContentEncoding: Option[String] = None): Unit = { + val destFile = new S3Object(destKey) + val replaceMetaData = destContentType.isDefined || destContentEncoding.isDefined + destContentEncoding.foreach(encoding => destFile.setContentEncoding(encoding)) + destContentType.foreach(contentType => destFile.setContentType(contentType)) + service.copyObject(sourceBucket, sourceKey, destBucket, destFile, replaceMetaData) } def fileExists(bucket: String, key: String): Boolean = { diff --git a/src/test/scala/ignition/core/utils/CollectionUtilsSpec.scala b/src/test/scala/ignition/core/utils/CollectionUtilsSpec.scala index c19579ce..548b2423 100644 --- a/src/test/scala/ignition/core/utils/CollectionUtilsSpec.scala +++ b/src/test/scala/ignition/core/utils/CollectionUtilsSpec.scala @@ -32,7 +32,18 @@ class CollectionUtilsSpec extends FlatSpec with ShouldMatchers { list.compressBy(_.value) shouldBe List(MyObj("p1", "v1"), MyObj("p1", "v2")) } + it should "provide orElseIfEmpty" in { + Seq.empty[String].orElseIfEmpty(Seq("something")) shouldBe Seq("something") + Seq("not 
empty").orElseIfEmpty(Seq("something")) shouldBe Seq("not empty") + } + + it should "provide maxOption and minOption" in { + Seq.empty[Int].maxOption shouldBe None + Seq(1, 3, 2).maxOption shouldBe Some(3) + Seq.empty[Int].minOption shouldBe None + Seq(1, 3, 2).minOption shouldBe Some(1) + } } diff --git a/src/test/scala/ignition/core/utils/IntBagSpec.scala b/src/test/scala/ignition/core/utils/IntBagSpec.scala new file mode 100644 index 00000000..b6694b12 --- /dev/null +++ b/src/test/scala/ignition/core/utils/IntBagSpec.scala @@ -0,0 +1,23 @@ +package ignition.core.utils + +import org.scalatest._ + +import scala.util.Random + +class IntBagSpec extends FlatSpec with ShouldMatchers { + + "IntBag" should "be built from sequence" in { + IntBag.from(Seq(1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 4)).histogram shouldBe Map(1 -> 2, 2 -> 3, 3 -> 1, 4 -> 5) + } + + it should "calculate the median and average" in { + val size = 1000 + val numbers = (0 until 1000).map(_ => Random.nextInt(400).toLong).toList + val bag = IntBag.from(numbers) + + bag.avg.get shouldBe numbers.sum / size + + // TODO: the median is only approximate and it could be better, improve it + } + +} diff --git a/tools/cluster.py b/tools/cluster.py index 3cf1828a..e9ad90f3 100755 --- a/tools/cluster.py +++ b/tools/cluster.py @@ -23,6 +23,7 @@ import getpass import json import glob +import webbrowser log = logging.getLogger() @@ -38,7 +39,9 @@ default_instance_type = 'r3.xlarge' default_spot_price = '0.10' default_worker_instances = '1' +default_executor_instances = '1' default_master_instance_type = 'm3.xlarge' +default_driver_heap_size = '12G' default_region = 'us-east-1' default_zone = default_region + 'b' default_key_id = 'ignition_key' @@ -46,16 +49,19 @@ default_ami = None # will be decided based on spark-ec2 list default_master_ami = None default_env = 'dev' -default_spark_version = '1.3.0' +default_spark_version = '1.5.1' +custom_builds = { +# '1.5.1': 
'https://s3.amazonaws.com/chaordic-ignition-public/spark-1.5.1-bin-cdh4.7.1.tgz' +} default_spark_repo = 'https://github.com/chaordic/spark' default_remote_user = 'ec2-user' default_remote_control_dir = '/tmp/Ignition' default_collect_results_dir = '/tmp' -default_user_data = os.path.join(script_path, 'scripts', 'S05mount-disks') +default_user_data = os.path.join(script_path, 'scripts', 'noop') default_defaults_filename = 'cluster_defaults.json' default_spark_ec2_git_repo = 'https://github.com/chaordic/spark-ec2' -default_spark_ec2_git_branch = 'v4-yarn' +default_spark_ec2_git_branch = 'branch-1.4-merge' master_post_create_commands = [ @@ -201,14 +207,16 @@ def launch(cluster_name, slaves, tag=[], key_id=default_key_id, region=default_region, zone=default_zone, instance_type=default_instance_type, - ondemand=False, spot_price=default_spot_price, + ondemand=False, spot_price=default_spot_price, master_spot=False, user_data=default_user_data, security_group = None, vpc = None, vpc_subnet = None, master_instance_type=default_master_instance_type, wait_time='180', hadoop_major_version='2', - worker_instances=default_worker_instances, retries_on_same_cluster=5, + worker_instances=default_worker_instances, + executor_instances=default_executor_instances, + retries_on_same_cluster=5, max_clusters_to_create=5, minimum_percentage_healthy_slaves=0.9, remote_user=default_remote_user, @@ -251,9 +259,13 @@ def launch(cluster_name, slaves, ]) spot_params = ['--spot-price', spot_price] if not ondemand else [] + master_spot_params = ['--master-spot'] if not ondemand and master_spot else [] + ami_params = ['--ami', ami] if ami else [] master_ami_params = ['--master-ami', master_ami] if master_ami else [] + spark_version = custom_builds.get(spark_version, spark_version) + for i in range(retries_on_same_cluster): log.info('Running script, try %d of %d', i + 1, retries_on_same_cluster) try: @@ -269,12 +281,14 @@ def launch(cluster_name, slaves, '--spark-ec2-git-repo', 
spark_ec2_git_repo, '--spark-ec2-git-branch', spark_ec2_git_branch, '--worker-instances', worker_instances, + '--executor-instances', executor_instances, '--master-opts', '-Dspark.worker.timeout={0}'.format(worker_timeout), '--spark-git-repo', spark_repo, '-v', spark_version, '--user-data', user_data, 'launch', cluster_name] + spot_params + + master_spot_params + resume_param + auth_params + ami_params + @@ -372,7 +386,9 @@ def job_run(cluster_name, job_name, job_mem, disable_assembly_build=False, run_tests=False, kill_on_failure=False, - destroy_cluster=False, region=default_region): + destroy_cluster=False, + region=default_region, + driver_heap_size=default_driver_heap_size): utc_job_date_example = '2014-05-04T13:13:10Z' if utc_job_date and len(utc_job_date) != len(utc_job_date_example): @@ -393,10 +409,10 @@ def job_run(cluster_name, job_name, job_mem, job_date = utc_job_date or datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ') job_tag = job_tag or job_date.replace(':', '_').replace('-', '_').replace('Z', 'UTC') tmux_wait_command = ';(echo Press enter to keep the session open && /bin/bash -c "read -t 5" && sleep 7d)' if not detached else '' - tmux_arg = ". /etc/profile; . ~/.profile;tmux new-session {detached} -s spark.{job_name}.{job_tag} '{aws_vars} {remote_hook} {job_name} {job_date} {job_tag} {job_user} {remote_control_dir} {spark_mem} {yarn_param} {notify_param} {tmux_wait_command}' >& /tmp/commandoutput".format( - aws_vars=get_aws_keys_str(), job_name=job_name, job_date=job_date, job_tag=job_tag, job_user=job_user, remote_control_dir=remote_control_dir, remote_hook=remote_hook, spark_mem=job_mem, detached='-d' if detached else '', yarn_param=yarn_param, notify_param=notify_param, tmux_wait_command=tmux_wait_command) - non_tmux_arg = ". /etc/profile; . 
~/.profile;{aws_vars} {remote_hook} {job_name} {job_date} {job_tag} {job_user} {remote_control_dir} {spark_mem} {yarn_param} {notify_param} >& /tmp/commandoutput".format( - aws_vars=get_aws_keys_str(), job_name=job_name, job_date=job_date, job_tag=job_tag, job_user=job_user, remote_control_dir=remote_control_dir, remote_hook=remote_hook, spark_mem=job_mem, yarn_param=yarn_param, notify_param=notify_param) + tmux_arg = ". /etc/profile; . ~/.profile;tmux new-session {detached} -s spark.{job_name}.{job_tag} '{aws_vars} {remote_hook} {job_name} {job_date} {job_tag} {job_user} {remote_control_dir} {spark_mem} {yarn_param} {notify_param} {driver_heap_size} {tmux_wait_command}' >& /tmp/commandoutput".format( + aws_vars=get_aws_keys_str(), job_name=job_name, job_date=job_date, job_tag=job_tag, job_user=job_user, remote_control_dir=remote_control_dir, remote_hook=remote_hook, spark_mem=job_mem, detached='-d' if detached else '', yarn_param=yarn_param, notify_param=notify_param, driver_heap_size=driver_heap_size, tmux_wait_command=tmux_wait_command) + non_tmux_arg = ". /etc/profile; . 
~/.profile;{aws_vars} {remote_hook} {job_name} {job_date} {job_tag} {job_user} {remote_control_dir} {spark_mem} {yarn_param} {notify_param} {driver_heap_size} >& /tmp/commandoutput".format( + aws_vars=get_aws_keys_str(), job_name=job_name, job_date=job_date, job_tag=job_tag, job_user=job_user, remote_control_dir=remote_control_dir, remote_hook=remote_hook, spark_mem=job_mem, yarn_param=yarn_param, notify_param=notify_param, driver_heap_size=driver_heap_size) if not disable_assembly_build: @@ -421,6 +437,9 @@ def job_run(cluster_name, job_name, job_mem, src_local=remote_hook_local, remote_path=with_leading_slash(remote_path)) + if job_name == "zeppelin": + webbrowser.open("http://{master}:8081".format(master=master)) + log.info('Will run job in remote host') if disable_tmux: ssh_call(user=remote_user, host=master, key_file=key_file, args=[non_tmux_arg], allocate_terminal=False) diff --git a/tools/scripts/S05mount-disks b/tools/scripts/S05mount-disks deleted file mode 100644 index 8f129a30..00000000 --- a/tools/scripts/S05mount-disks +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -echo 'Mounting disks' >> /tmp/mount-disks.log -mkdir -p /mnt -mkdir -p /mnt{2,3,4} -chmod -R 777 /mnt* -[ -r /dev/xvdb ] && mkfs.ext4 /dev/xvdb && mount /dev/xvdb /mnt -[ -r /dev/xvdc ] && mkfs.ext4 /dev/xvdc && mount /dev/xvdc /mnt2 -[ -r /dev/xvdd ] && mkfs.ext4 /dev/xvdd && mount /dev/xvdd /mnt3 -[ -r /dev/xvde ] && mkfs.ext4 /dev/xvde && mount /dev/xvde /mnt4 - diff --git a/tools/scripts/noop b/tools/scripts/noop new file mode 100644 index 00000000..cc1f786e --- /dev/null +++ b/tools/scripts/noop @@ -0,0 +1 @@ +#!/bin/bash \ No newline at end of file diff --git a/tools/spark-ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/tools/spark-ec2/deploy.generic/root/spark-ec2/ec2-variables.sh index 3570891b..bd3b656f 100644 --- a/tools/spark-ec2/deploy.generic/root/spark-ec2/ec2-variables.sh +++ b/tools/spark-ec2/deploy.generic/root/spark-ec2/ec2-variables.sh @@ -25,8 +25,11 @@ export 
MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" export SPARK_LOCAL_DIRS="{{spark_local_dirs}}" export MODULES="{{modules}}" export SPARK_VERSION="{{spark_version}}" -export SHARK_VERSION="{{shark_version}}" +export TACHYON_VERSION="{{tachyon_version}}" export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" export SWAP_MB="{{swap}}" export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}" +export SPARK_EXECUTOR_INSTANCES="{{spark_executor_instances}}" export SPARK_MASTER_OPTS="{{spark_master_opts}}" +export AWS_ACCESS_KEY_ID="{{aws_access_key_id}}" +export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" diff --git a/tools/spark-ec2/spark_ec2.py b/tools/spark-ec2/spark_ec2.py index 5fdf0467..e9442448 100755 --- a/tools/spark-ec2/spark_ec2.py +++ b/tools/spark-ec2/spark_ec2.py @@ -19,9 +19,11 @@ # limitations under the License. # -from __future__ import with_statement +from __future__ import division, print_function, with_statement +import codecs import hashlib +import itertools import logging import os import os.path @@ -36,13 +38,20 @@ import tempfile import textwrap import time -import urllib2 import warnings from datetime import datetime from optparse import OptionParser from sys import stderr -SPARK_EC2_VERSION = "1.3.0" +if sys.version < "3": + from urllib2 import urlopen, Request, HTTPError +else: + from urllib.request import urlopen, Request + from urllib.error import HTTPError + raw_input = input + xrange = range + +SPARK_EC2_VERSION = "1.5.1" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -60,14 +69,90 @@ "1.2.0", "1.2.1", "1.3.0", + "1.3.1", + "1.4.0", + "1.4.1", + "1.5.0", + "1.5.1", ]) +SPARK_TACHYON_MAP = { + "1.0.0": "0.4.1", + "1.0.1": "0.4.1", + "1.0.2": "0.4.1", + "1.1.0": "0.5.0", + "1.1.1": "0.5.0", + "1.2.0": "0.5.0", + "1.2.1": "0.5.0", + "1.3.0": "0.5.0", + "1.3.1": "0.5.0", + "1.4.0": "0.6.4", + "1.4.1": "0.6.4", + "1.5.0": "0.7.1", + "1.5.1": "0.7.1", +} + DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION 
DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" # Default location to get the spark-ec2 scripts (and ami-list) from -DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.3" +DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/amplab/spark-ec2" +DEFAULT_SPARK_EC2_BRANCH = "branch-1.4" + + +def setup_external_libs(libs): + """ + Download external libraries from PyPI to SPARK_EC2_DIR/lib/ and prepend them to our PATH. + """ + PYPI_URL_PREFIX = "https://pypi.python.org/packages/source" + SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib") + + if not os.path.exists(SPARK_EC2_LIB_DIR): + print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format( + path=SPARK_EC2_LIB_DIR + )) + print("This should be a one-time operation.") + os.mkdir(SPARK_EC2_LIB_DIR) + + for lib in libs: + versioned_lib_name = "{n}-{v}".format(n=lib["name"], v=lib["version"]) + lib_dir = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name) + + if not os.path.isdir(lib_dir): + tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz") + print(" - Downloading {lib}...".format(lib=lib["name"])) + download_stream = urlopen( + "{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format( + prefix=PYPI_URL_PREFIX, + first_letter=lib["name"][:1], + lib_name=lib["name"], + lib_version=lib["version"] + ) + ) + with open(tgz_file_path, "wb") as tgz_file: + tgz_file.write(download_stream.read()) + with open(tgz_file_path, "rb") as tar: + if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: + print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) + sys.exit(1) + tar = tarfile.open(tgz_file_path) + tar.extractall(path=SPARK_EC2_LIB_DIR) + tar.close() + os.remove(tgz_file_path) + print(" - Finished downloading {lib}.".format(lib=lib["name"])) + sys.path.insert(1, lib_dir) + + +# Only PyPI libraries are supported. 
+external_libs = [ + { + "name": "boto", + "version": "2.34.0", + "md5": "5556223d2d0cc4d06dd4829e671dcecd" + } +] + +setup_external_libs(external_libs) import boto from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType @@ -107,7 +192,7 @@ def parse_args(): help="Master instance type (leave empty for same as instance-type)") parser.add_option( "-r", "--region", default="us-east-1", - help="EC2 region zone to launch instances in") + help="EC2 region used to launch instances in, or to find them in (default: %default)") parser.add_option( "-z", "--zone", default="", help="Availability zone to launch instances in, or 'all' to spread " + @@ -133,9 +218,19 @@ def parse_args(): "--spark-ec2-git-branch", default=DEFAULT_SPARK_EC2_BRANCH, help="Github repo branch of spark-ec2 to use (default: %default)") + parser.add_option( + "--deploy-root-dir", + default=None, + help="A directory to copy into / on the first master. " + + "Must be absolute. Note that a trailing slash is handled as per rsync: " + + "If you omit it, the last directory of the --deploy-root-dir path will be created " + + "in / before copying its contents. If you append the trailing slash, " + + "the directory is not created and its contents are copied directly into /. " + + "(default: %default).") parser.add_option( "--hadoop-major-version", default="1", - help="Major version of Hadoop (default: %default)") + help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.7.1), yarn " + + "(Hadoop 2.4.0) (default: %default)") parser.add_option( "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + @@ -155,7 +250,7 @@ def parse_args(): help="Number of EBS volumes to attach to each node as /vol[x]. " + "The volumes will be deleted when the instances terminate. " + "Only possible on EBS-backed AMIs. " + - "EBS volumes are only attached if --ebs-vol-size > 0." 
+ + "EBS volumes are only attached if --ebs-vol-size > 0. " + "Only support up to 8 EBS volumes.") parser.add_option( "--placement-group", type="string", default=None, @@ -169,6 +264,10 @@ def parse_args(): "--spot-price", metavar="PRICE", type="float", help="If specified, launch slaves as spot instances with the given " + "maximum price (in dollars)") + parser.add_option( + "--master-spot", action="store_true", default=False, + help="If specified, launch master as spot instance using the same " + + "bid and instance type of the slave ones") parser.add_option( "--ganglia", action="store_true", default=True, help="Setup Ganglia monitoring on cluster (default: %default). NOTE: " + @@ -187,14 +286,19 @@ def parse_args(): help="Launch fresh slaves, but use an existing stopped master if possible") parser.add_option( "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)") + help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. Not used if YARN " + + "is used as Hadoop major version (default: %default)") + parser.add_option( + "--executor-instances", type="int", default=1, + help="Number of executor instances per worker: variable SPARK_EXECUTOR_INSTANCES. 
Not used if YARN " + + "is used as Hadoop major version (default: %default)") parser.add_option( "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + "(e.g -Dspark.worker.timeout=180)") parser.add_option( "--user-data", type="string", default="", - help="Path to a user-data file (most AMI's interpret this as an initialization script)") + help="Path to a user-data file (most AMIs interpret this as an initialization script)") parser.add_option( "--security-group-prefix", type="string", default=None, help="Use this prefix for the security group rather than the cluster name.") @@ -204,6 +308,10 @@ def parse_args(): parser.add_option( "--additional-security-group", type="string", default="", help="Additional security group to place the machines in") + parser.add_option( + "--additional-tags", type="string", default="", + help="Additional tags to set on the machines; tags are comma-separated, while name and " + + "value are colon separated; ex: \"Task:MySparkProject,Env:production\"") parser.add_option( "--copy-aws-credentials", action="store_true", default=False, help="Add AWS credentials to hadoop configuration to allow Spark to access S3") @@ -216,6 +324,17 @@ def parse_args(): parser.add_option( "--spot-timeout", type="int", default=45, help="Maximum amount of time (in minutes) to wait for spot requests to be fulfilled") + parser.add_option( + "--private-ips", action="store_true", default=False, + help="Use private IPs for instances rather than public if VPC/subnet " + + "requires that.") + parser.add_option( + "--instance-initiated-shutdown-behavior", default="terminate", + choices=["stop", "terminate"], + help="Whether instances should terminate when shut down or just stop") + parser.add_option( + "--instance-profile-name", default=None, + help="IAM profile name to launch instances under") (opts, args) = parser.parse_args() if len(args) != 2: @@ -228,14 +347,16 @@ def parse_args(): home_dir = 
os.getenv('HOME') if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print >> stderr, ("ERROR: The environment variable AWS_ACCESS_KEY_ID " + - "must be set") - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print >> stderr, ("ERROR: The environment variable AWS_SECRET_ACCESS_KEY " + - "must be set") - sys.exit(1) + # If there is no boto config, check aws credentials + if not os.path.isfile(home_dir + '/.aws/credentials'): + if os.getenv('AWS_ACCESS_KEY_ID') is None: + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) + sys.exit(1) + if os.getenv('AWS_SECRET_ACCESS_KEY') is None: + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) + sys.exit(1) return (opts, action, cluster_name) @@ -246,19 +367,19 @@ def get_or_make_group(conn, name, vpc_id): if len(group) > 0: return group[0] else: - print "Creating security group " + name + print("Creating security group " + name) return conn.create_security_group(name, "Spark EC2 group", vpc_id) def check_if_http_resource_exists(resource): - request = urllib2.Request(resource) + request = Request(resource) request.get_method = lambda: 'HEAD' try: - response = urllib2.urlopen(request) + response = urlopen(request) if response.getcode() == 200: return True else: raise RuntimeError("Resource {resource} not found. Error: {code}".format(resource, response.getcode())) - except urllib2.HTTPError, e: + except HTTPError as e: - print >> stderr, "Unable to check if HTTP resource {url} exists. 
Error: {code}".format( - url=resource, code=e.code) + print("Unable to check if HTTP resource {url} exists. Error: {code}".format( + url=resource, code=e.code), file=stderr) @@ -270,12 +391,12 @@ def get_validate_spark_version(version, repo): if check_if_http_resource_exists: return version else: - print >> stderr, "Unable to validate pre-built spark version {version}".format(version=version) + print("Unable to validate pre-built spark version {version}".format(version=version), file=stderr) sys.exit(1) elif "." in version: version = version.replace("v", "") if version not in VALID_SPARK_VERSIONS: - print >> stderr, "Don't know about Spark version: {v}".format(v=version) + print("Don't know about Spark version: {v}".format(v=version), file=stderr) sys.exit(1) return version else: @@ -288,70 +409,77 @@ def get_validate_spark_version(version, repo): return version -# Check whether a given EC2 instance object is in a state we consider active, -# i.e. not terminating or terminated. We count both stopping and stopped as -# active since we can restart stopped clusters. -def is_active(instance): - return (instance.state in ['pending', 'running', 'stopping', 'stopped']) - - # Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2014-06-20 +# Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. 
EC2_INSTANCE_TYPES = { "c1.medium": "pvm", "c1.xlarge": "pvm", + "c3.large": "pvm", + "c3.xlarge": "pvm", "c3.2xlarge": "pvm", "c3.4xlarge": "pvm", "c3.8xlarge": "pvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", + "c4.large": "hvm", + "c4.xlarge": "hvm", + "c4.2xlarge": "hvm", + "c4.4xlarge": "hvm", + "c4.8xlarge": "hvm", "cc1.4xlarge": "hvm", "cc2.8xlarge": "hvm", "cg1.4xlarge": "hvm", "cr1.8xlarge": "hvm", + "d2.xlarge": "hvm", + "d2.2xlarge": "hvm", + "d2.4xlarge": "hvm", + "d2.8xlarge": "hvm", + "g2.2xlarge": "hvm", + "g2.8xlarge": "hvm", "hi1.4xlarge": "pvm", "hs1.8xlarge": "pvm", + "i2.xlarge": "hvm", "i2.2xlarge": "hvm", "i2.4xlarge": "hvm", "i2.8xlarge": "hvm", - "i2.xlarge": "hvm", - "m1.large": "pvm", - "m1.medium": "pvm", "m1.small": "pvm", + "m1.medium": "pvm", + "m1.large": "pvm", "m1.xlarge": "pvm", + "m2.xlarge": "pvm", "m2.2xlarge": "pvm", "m2.4xlarge": "pvm", - "m2.xlarge": "pvm", - "m3.2xlarge": "hvm", - "m3.large": "hvm", "m3.medium": "hvm", + "m3.large": "hvm", "m3.xlarge": "hvm", + "m3.2xlarge": "hvm", + "m4.large": "hvm", + "m4.xlarge": "hvm", + "m4.2xlarge": "hvm", + "m4.4xlarge": "hvm", + "m4.10xlarge": "hvm", + "r3.large": "hvm", + "r3.xlarge": "hvm", "r3.2xlarge": "hvm", "r3.4xlarge": "hvm", "r3.8xlarge": "hvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", "t1.micro": "pvm", - "t2.medium": "hvm", "t2.micro": "hvm", "t2.small": "hvm", - "d2.2xlarge": "hvm", - "d2.4xlarge": "hvm", - "d2.8xlarge": "hvm", - "d2.large": "hvm", - "d2.xlarge": "hvm", + "t2.medium": "hvm", + "t2.large": "hvm", } +def get_tachyon_version(spark_version): + return SPARK_TACHYON_MAP.get(spark_version, "") + # Attempt to resolve an appropriate AMI given the architecture and region of the request. 
def get_spark_ami(instance_type, region, spark_ec2_git_repo, spark_ec2_git_branch): if instance_type in EC2_INSTANCE_TYPES: instance_type = EC2_INSTANCE_TYPES[instance_type] else: instance_type = "pvm" - print >> stderr,\ - "Don't recognize %s, assuming type is pvm" % instance_type + print("Don't recognize %s, assuming type is pvm" % instance_type, file=stderr) # URL prefix from which to fetch AMI information ami_prefix = "{r}/{b}/ami-list".format( @@ -359,27 +487,27 @@ def get_spark_ami(instance_type, region, spark_ec2_git_repo, spark_ec2_git_branc b=spark_ec2_git_branch) ami_path = "%s/%s/%s" % (ami_prefix, region, instance_type) + reader = codecs.getreader("ascii") try: - ami = urllib2.urlopen(ami_path).read().strip() - print "Spark AMI for %s: %s" % (instance_type, ami) + ami = reader(urlopen(ami_path)).read().strip() except: - print >> stderr, "Could not resolve AMI at: " + ami_path + print("Could not resolve AMI at: " + ami_path, file=stderr) sys.exit(1) + print("Spark AMI: " + ami) return ami - # Launch a cluster of the given name, by setting up its security groups, # and then starting new instances in them. # Returns a tuple of EC2 reservation objects for the master and slaves # Fails if there already instances running in the cluster's groups. def launch_cluster(conn, opts, cluster_name): if opts.identity_file is None: - print >> stderr, "ERROR: Must provide an identity file (-i) for ssh connections." + print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr) sys.exit(1) if opts.key_pair is None: - print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances." + print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr) sys.exit(1) user_data_content = None @@ -387,7 +515,7 @@ def launch_cluster(conn, opts, cluster_name): with open(opts.user_data) as user_data_file: user_data_content = user_data_file.read() - print "Setting up security groups..." 
+ print("Setting up security groups...") if opts.security_group_prefix is None: master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id) slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id) @@ -421,6 +549,17 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('tcp', 50070, 50070, authorized_address) master_group.authorize('tcp', 60070, 60070, authorized_address) master_group.authorize('tcp', 4040, 4045, authorized_address) + # Rstudio (GUI for R) needs port 8787 for web access + master_group.authorize('tcp', 8787, 8787, authorized_address) + # HDFS NFS gateway requires 111,2049,4242 for tcp & udp + master_group.authorize('tcp', 111, 111, authorized_address) + master_group.authorize('udp', 111, 111, authorized_address) + master_group.authorize('tcp', 2049, 2049, authorized_address) + master_group.authorize('udp', 2049, 2049, authorized_address) + master_group.authorize('tcp', 4242, 4242, authorized_address) + master_group.authorize('udp', 4242, 4242, authorized_address) + # RM in YARN mode uses 8088 + master_group.authorize('tcp', 8088, 8088, authorized_address) if opts.ganglia: master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created @@ -451,8 +590,8 @@ def launch_cluster(conn, opts, cluster_name): existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, die_on_error=False) if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances running in " + - "group %s or %s" % (master_group.name, slave_group.name)) + print("ERROR: There are already instances running in group %s or %s" % + (master_group.name, slave_group.name), file=stderr) sys.exit(1) # Figure out Spark AMI @@ -460,7 +599,7 @@ def launch_cluster(conn, opts, cluster_name): opts.ami = get_spark_ami(opts.instance_type, opts.region, opts.spark_ec2_git_repo, opts.spark_ec2_git_branch) if 
opts.master_ami is None: - opts.master_ami = get_spark_ami(opts.master_instance_type, opts.region, opts.spark_ec2_git_repo, opts.spark_ec2_git_branch) + opts.master_ami = get_spark_ami(opts.master_instance_type, opts.region, opts.spark_ec2_git_repo, opts.spark_ec2_git_branch) # we use group ids to work around https://github.com/boto/boto/issues/350 additional_group_ids = [] @@ -468,12 +607,12 @@ def launch_cluster(conn, opts, cluster_name): additional_group_ids = [sg.id for sg in conn.get_all_security_groups() if opts.additional_security_group in (sg.name, sg.id)] - print "Launching instances..." + print("Launching instances...") try: image = conn.get_all_images(image_ids=[opts.ami])[0] except: - print >> stderr, "Could not find AMI " + opts.ami + print("Could not find AMI " + opts.ami, file=stderr) sys.exit(1) try: @@ -502,8 +641,8 @@ def launch_cluster(conn, opts, cluster_name): # Launch slaves if opts.spot_price is not None: # Launch spot instances with the requested price - print ("Requesting %d slaves as spot instances with price $%.3f" % - (opts.slaves, opts.spot_price)) + print("Requesting %d slaves as spot instances with price $%.3f" % + (opts.slaves, opts.spot_price)) zones = get_zones(conn, opts) num_zones = len(zones) i = 0 @@ -522,12 +661,13 @@ def launch_cluster(conn, opts, cluster_name): block_device_map=block_map, subnet_id=opts.subnet_id, placement_group=opts.placement_group, - user_data=user_data_content) + user_data=user_data_content, + instance_profile_name=opts.instance_profile_name) my_req_ids += [req.id for req in slave_reqs] i += 1 start_time = datetime.now() - print "Waiting for spot instances to be granted... Request IDs: %s " % my_req_ids + print("Waiting for spot instances to be granted... 
Request IDs: %s " % my_req_ids) try: while True: time.sleep(10) @@ -539,28 +679,28 @@ def launch_cluster(conn, opts, cluster_name): raise Exception("Invalid state for spot request: %s - status: %s" % (invalid[0].id, invalid[0].status.message)) if len(active_instance_ids) == opts.slaves: - print "All %d slaves granted" % opts.slaves - reservations = conn.get_all_reservations(active_instance_ids) + print("All %d slaves granted" % opts.slaves) + reservations = conn.get_all_reservations([r.instance_id for r in active_instance_ids]) slave_nodes = [] for r in reservations: slave_nodes += r.instances break else: - print "%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves) + print("%d of %d slaves granted, waiting longer" % ( + len(active_instance_ids), opts.slaves)) if (datetime.now() - start_time).seconds > opts.spot_timeout * 60: raise Exception("Timed out while waiting for spot instances") except: - print "Error: %s" % sys.exc_info()[1] - print "Canceling spot instance requests" + print("Error: %s" % sys.exc_info()[1]) + print("Canceling spot instance requests") conn.cancel_spot_instance_requests(my_req_ids) # Log a warning if any of these requests actually launched instances: (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) running = len(master_nodes) + len(slave_nodes) if running: - print >> stderr, ("WARNING: %d instances are still running" % running) + print(("WARNING: %d instances are still running" % running), file=stderr) sys.exit(0) else: # Launch non-spot instances @@ -571,100 +711,184 @@ def launch_cluster(conn, opts, cluster_name): for zone in zones: num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) if num_slaves_this_zone > 0: - slave_res = image.run(key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - 
block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + slave_res = image.run( + key_name=opts.key_pair, + security_group_ids=[slave_group.id] + additional_group_ids, + instance_type=opts.instance_type, + placement=zone, + min_count=num_slaves_this_zone, + max_count=num_slaves_this_zone, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) slave_nodes += slave_res.instances - print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone, - zone, slave_res.id) + print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( + s=num_slaves_this_zone, + plural_s=('' if num_slaves_this_zone == 1 else 's'), + z=zone, + r=slave_res.id)) i += 1 # Launch or resume masters if existing_masters: - print "Starting master..." 
+ print("Starting master...") for inst in existing_masters: if inst.state not in ["shutting-down", "terminated"]: inst.start() master_nodes = existing_masters else: master_type = opts.master_instance_type - if master_type == "": + if master_type == "" or opts.master_spot: master_type = opts.instance_type if opts.zone == 'all': opts.zone = random.choice(conn.get_all_zones()).name - master_res = master_image.run(key_name=opts.key_pair, - security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) - - master_nodes = master_res.instances - print "Launched master in %s, regid = %s" % (zone, master_res.id) + if opts.master_spot: + # Launch spot master instance with the requested price + # Note: The spot_price*1.5 is present to ensure a higher bid price to + # the master spot instance, so the master instance will be the + # last one to be terminated in a spot market price increase + print("Requesting master as spot instance with price $%.3f" % + (opts.spot_price)) + master_req = conn.request_spot_instances( + price=(opts.spot_price * 1.5), + image_id=opts.master_ami, + placement=opts.zone, + count=1, + key_name=opts.key_pair, + security_group_ids=[master_group.id] + additional_group_ids, + instance_type=master_type, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_profile_name=opts.instance_profile_name) + my_master_req_id = [req.id for req in master_req] + + # TODO: refactor duplicated spot waiting code + start_time = datetime.now() + print("Waiting for master spot instance to be granted... 
Request ID: %s " % my_master_req_id) + try: + while True: + time.sleep(10) + reqs = conn.get_all_spot_instance_requests(my_master_req_id) + active_instance_ids = filter(lambda req: req.state == "active", reqs) + invalid_states = ["capacity-not-available", "capacity-oversubscribed", "price-too-low"] + invalid = filter(lambda req: req.status.code in invalid_states, reqs) + if len(invalid) > 0: + raise Exception("Invalid state for spot request: %s - status: %s" % + (invalid[0].id, invalid[0].status.message)) + if len(active_instance_ids) == 1: + print("Master spot instance granted") + master_res = conn.get_all_reservations([r.instance_id for r in active_instance_ids]) + master_nodes = master_res[0].instances + break + else: + print("Master spot instance not granted yet, waiting longer") + + if (datetime.now() - start_time).seconds > opts.spot_timeout * 60: + raise Exception("Timed out while waiting for master spot instance") + except: + print("Error: %s" % sys.exc_info()[1]) + print("Canceling master spot instance requests") + conn.cancel_spot_instance_requests(my_master_req_id) + # Log a warning if any of these requests actually launched instances: + (master_nodes, slave_nodes) = get_existing_cluster( + conn, opts, cluster_name, die_on_error=False) + running = len(master_nodes) + len(slave_nodes) + if running: + print(("WARNING: %d instances are still running" % running), file=stderr) + sys.exit(0) + else: + # Launch ondemand instance + master_res = master_image.run( + key_name=opts.key_pair, + security_group_ids=[master_group.id] + additional_group_ids, + instance_type=master_type, + placement=opts.zone, + min_count=1, + max_count=1, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) + + master_nodes = master_res.instances + print("Launched master in %s, 
regid = %s" % (zone, master_res.id)) # This wait time corresponds to SPARK-4983 - print "Waiting for AWS to propagate instance metadata..." + print("Waiting for AWS to propagate instance metadata...") time.sleep(5) - # Give the instances descriptive names + + # Give the instances descriptive names and set additional tags + additional_tags = {} + if opts.additional_tags.strip(): + additional_tags = dict( + map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',') + ) + for master in master_nodes: - master.add_tag( - key='Name', - value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + master.add_tags( + dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + ) + for slave in slave_nodes: - slave.add_tag( - key='Name', - value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + slave.add_tags( + dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + ) # Return all the instances return (master_nodes, slave_nodes) -# Get the EC2 instances in an existing cluster if available. -# Returns a tuple of lists of EC2 instance objects for the masters and slaves def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): - print "Searching for existing cluster " + cluster_name + "..." 
- reservations = conn.get_all_reservations() - master_nodes = [] - slave_nodes = [] - for res in reservations: - active = [i for i in res.instances if is_active(i)] - for inst in active: - group_names = [g.name for g in inst.groups] - if (cluster_name + "-master") in group_names: - master_nodes.append(inst) - elif (cluster_name + "-slaves") in group_names: - slave_nodes.append(inst) - if any((master_nodes, slave_nodes)): - print "Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes)) - if master_nodes != [] or not die_on_error: - return (master_nodes, slave_nodes) - else: - if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" - else: - print >> sys.stderr, "ERROR: Could not find any existing cluster" + """ + Get the EC2 instances in an existing cluster if available. + Returns a tuple of lists of EC2 instance objects for the masters and slaves. + """ + print("Searching for existing cluster {c} in region {r}...".format( + c=cluster_name, r=opts.region)) + + def get_instances(group_names): + """ + Get all non-terminated instances that belong to any of the provided security groups. 
+ + EC2 reservation filters and instance states are documented here: + http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options + """ + reservations = conn.get_all_reservations( + filters={"instance.group-name": group_names}) + instances = itertools.chain.from_iterable(r.instances for r in reservations) + return [i for i in instances if i.state not in ["shutting-down", "terminated"]] + + master_instances = get_instances([cluster_name + "-master"]) + slave_instances = get_instances([cluster_name + "-slaves"]) + + if any((master_instances, slave_instances)): + print("Found {m} master{plural_m}, {s} slave{plural_s}.".format( + m=len(master_instances), + plural_m=('' if len(master_instances) == 1 else 's'), + s=len(slave_instances), + plural_s=('' if len(slave_instances) == 1 else 's'))) + + if not master_instances and die_on_error: + print("ERROR: Could not find a master for cluster {c} in region {r}.".format( + c=cluster_name, r=opts.region), file=sys.stderr) sys.exit(1) + return (master_instances, slave_instances) + # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. - - def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): - master = master_nodes[0].public_dns_name + master = get_dns_name(master_nodes[0], opts.private_ips) if deploy_ssh_key: - print "Generating cluster's SSH key on master..." + print("Generating cluster's SSH key on master...") key_setup = """ [ -f ~/.ssh/id_rsa ] || (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && @@ -672,24 +896,29 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): """ ssh(master, opts, key_setup) dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) - print "Transferring cluster's SSH key to slaves..." 
+ print("Transferring cluster's SSH key to slaves...") for slave in slave_nodes: - print slave.public_dns_name - ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar) + slave_address = get_dns_name(slave, opts.private_ips) + print(slave_address) + ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', 'mapreduce', 'spark-standalone', 'tachyon'] if opts.hadoop_major_version == "1": - modules = filter(lambda x: x != "mapreduce", modules) + modules = list(filter(lambda x: x != "mapreduce", modules)) if opts.ganglia: modules.append('ganglia') + # Clear SPARK_WORKER_INSTANCES if running on YARN + if opts.hadoop_major_version == "yarn": + opts.worker_instances = "" + # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - print "Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( - r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch) + print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( + r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)) ssh( host=master, opts=opts, @@ -699,7 +928,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): b=opts.spark_ec2_git_branch) ) - print "Deploying files to master..." + print("Deploying files to master...") deploy_files( conn=conn, root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", @@ -709,18 +938,26 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): modules=modules ) - print "Running setup on master..." + if opts.deploy_root_dir is not None: + print("Deploying {s} to master...".format(s=opts.deploy_root_dir)) + deploy_user_files( + root_dir=opts.deploy_root_dir, + opts=opts, + master_nodes=master_nodes + ) + + print("Running setup on master...") setup_spark_cluster(master, opts) - print "Done!" 
+ print("Done!") def setup_spark_cluster(master, opts): ssh(master, opts, "chmod u+x spark-ec2/setup.sh") ssh(master, opts, "spark-ec2/setup.sh") - print "Spark standalone cluster started at http://%s:8080" % master + print("Spark standalone cluster started at http://%s:8080" % master) if opts.ganglia: - print "Ganglia started at http://%s:5080/ganglia" % master + print("Ganglia started at http://%s:5080/ganglia" % master) def is_ssh_available(host, opts, print_ssh_output=True): @@ -737,7 +974,7 @@ def is_ssh_available(host, opts, print_ssh_output=True): if s.returncode != 0 and print_ssh_output: # extra leading newline is for spacing in wait_for_cluster_state() - print textwrap.dedent("""\n + print(textwrap.dedent("""\n Warning: SSH connection error. (This could be temporary.) Host: {h} SSH return code: {r} @@ -746,7 +983,7 @@ def is_ssh_available(host, opts, print_ssh_output=True): h=host, r=s.returncode, o=cmd_output.strip() - ) + )) return s.returncode == 0 @@ -756,7 +993,8 @@ def is_cluster_ssh_available(cluster_instances, opts): Check if SSH is available on all the instances in a cluster. 
""" for i in cluster_instances: - if not is_ssh_available(host=i.ip_address, opts=opts): + dns_name = get_dns_name(i, opts.private_ips) + if not is_ssh_available(host=dns_name, opts=opts): return False else: return True @@ -786,7 +1024,11 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): for i in cluster_instances: i.update() - statuses = conn.get_all_instance_status(instance_ids=[i.id for i in cluster_instances]) + max_batch = 100 + statuses = [] + for j in xrange(0, len(cluster_instances), max_batch): + batch = [i.id for i in cluster_instances[j:j + max_batch]] + statuses.extend(conn.get_all_instance_status(instance_ids=batch)) if cluster_state == 'ssh-ready': if all(i.state == 'running' for i in cluster_instances) and \ @@ -806,63 +1048,78 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): sys.stdout.write("\n") end_time = datetime.now() - print "Cluster is now in '{s}' state. Waited {t} seconds.".format( + print("Cluster is now in '{s}' state. Waited {t} seconds.".format( s=cluster_state, t=(end_time - start_time).seconds - ) + )) # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2014-06-20 + # Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. 
disks_by_instance = { "c1.medium": 1, "c1.xlarge": 4, + "c3.large": 2, + "c3.xlarge": 2, "c3.2xlarge": 2, "c3.4xlarge": 2, "c3.8xlarge": 2, - "c3.large": 2, - "c3.xlarge": 2, + "c4.large": 0, + "c4.xlarge": 0, + "c4.2xlarge": 0, + "c4.4xlarge": 0, + "c4.8xlarge": 0, "cc1.4xlarge": 2, "cc2.8xlarge": 4, "cg1.4xlarge": 2, "cr1.8xlarge": 2, + "d2.xlarge": 3, + "d2.2xlarge": 6, + "d2.4xlarge": 12, + "d2.8xlarge": 24, "g2.2xlarge": 1, + "g2.8xlarge": 2, "hi1.4xlarge": 2, "hs1.8xlarge": 24, + "i2.xlarge": 1, "i2.2xlarge": 2, "i2.4xlarge": 4, "i2.8xlarge": 8, - "i2.xlarge": 1, - "m1.large": 2, - "m1.medium": 1, "m1.small": 1, + "m1.medium": 1, + "m1.large": 2, "m1.xlarge": 4, + "m2.xlarge": 1, "m2.2xlarge": 1, "m2.4xlarge": 2, - "m2.xlarge": 1, - "m3.2xlarge": 2, - "m3.large": 1, "m3.medium": 1, + "m3.large": 1, "m3.xlarge": 2, + "m3.2xlarge": 2, + "m4.large": 0, + "m4.xlarge": 0, + "m4.2xlarge": 0, + "m4.4xlarge": 0, + "m4.10xlarge": 0, + "r3.large": 1, + "r3.xlarge": 1, "r3.2xlarge": 1, "r3.4xlarge": 1, "r3.8xlarge": 2, - "r3.large": 1, - "r3.xlarge": 1, "t1.micro": 0, - 'd2.xlarge': 3, - 'd2.2xlarge': 6, - 'd2.4xlarge': 12, - 'd2.8xlarge': 24, + "t2.micro": 0, + "t2.small": 0, + "t2.medium": 0, + "t2.large": 0, } if instance_type in disks_by_instance: return disks_by_instance[instance_type] else: - print >> stderr, ("WARNING: Don't know number of disks on instance type %s; assuming 1" - % instance_type) + print("WARNING: Don't know number of disks on instance type %s; assuming 1" + % instance_type, file=stderr) return 1 @@ -874,7 +1131,7 @@ def get_num_disks(instance_type): # # root_dir should be an absolute path to the directory with the files we want to deploy. 
def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): - active_master = master_nodes[0].public_dns_name + active_master = get_dns_name(master_nodes[0], opts.private_ips) num_disks = get_num_disks(opts.instance_type) hdfs_data_dirs = "/mnt/ephemeral-hdfs/data" @@ -891,17 +1148,28 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): if opts.spark_version.startswith("http"): # Custom pre-built spark package spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo) + tachyon_v = "" + print("Deploying Spark via custom bundle; Tachyon won't be set up") + modules = filter(lambda x: x != "tachyon", modules) elif "." in opts.spark_version: # Pre-built Spark deploy spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo) + tachyon_v = get_tachyon_version(spark_v) else: # Spark-only custom deploy spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version) - + tachyon_v = "" + print("Deploying Spark via git hash; Tachyon won't be set up") + modules = filter(lambda x: x != "tachyon", modules) + + master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] + slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] + worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else "" + executor_instances_str = "%d" % opts.executor_instances if opts.executor_instances else "" template_vars = { - "master_list": '\n'.join([i.public_dns_name for i in master_nodes]), "active_master": active_master, - "slave_list": '\n'.join([i.public_dns_name for i in slave_nodes]), + "slave_list": '\n'.join(slave_addresses), "cluster_url": cluster_url, "hdfs_data_dirs": hdfs_data_dirs, "mapred_local_dirs": mapred_local_dirs, @@ -909,8 +1177,10 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): "swap": str(opts.swap), "modules": '\n'.join(modules), "spark_version": spark_v, + 
"tachyon_version": tachyon_v, "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": "%d" % opts.worker_instances, + "spark_worker_instances": worker_instances_str, + "spark_executor_instances": executor_instances_str, "spark_master_opts": opts.master_opts } @@ -953,6 +1223,23 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): shutil.rmtree(tmp_dir) +# Deploy a given local directory to a cluster, WITHOUT parameter substitution. +# Note that unlike deploy_files, this works for binary files. +# Also, it is up to the user to add (or not) the trailing slash in root_dir. +# Files are only deployed to the first master instance in the cluster. +# +# root_dir should be an absolute path. +def deploy_user_files(root_dir, opts, master_nodes): + active_master = get_dns_name(master_nodes[0], opts.private_ips) + command = [ + 'rsync', '-rv', + '-e', stringify_command(ssh_command(opts)), + "%s" % root_dir, + "%s@%s:/" % (opts.user, active_master) + ] + subprocess.check_call(command) + + def stringify_command(parts): if isinstance(parts, str): return parts @@ -986,13 +1273,13 @@ def ssh(host, opts, command): # If this was an ssh failure, provide the user with hints. 
if e.returncode == 255: raise UsageError( - "Failed to SSH to remote host {0}.\n" + - "Please check that you have provided the correct --identity-file and " + + "Failed to SSH to remote host {0}.\n" + "Please check that you have provided the correct --identity-file and " "--key-pair parameters and try again.".format(host)) else: raise e - print >> stderr, \ - "Error executing remote command, retrying after 30 seconds: {0}".format(e) + print("Error executing remote command, retrying after 30 seconds: {0}".format(e), + file=stderr) time.sleep(30) tries = tries + 1 @@ -1031,8 +1318,8 @@ def ssh_write(host, opts, command, arguments): elif tries > 5: raise RuntimeError("ssh_write failed with error %s" % proc.returncode) else: - print >> stderr, \ - "Error {0} while executing remote command, retrying after 30 seconds".format(status) + print("Error {0} while executing remote command, retrying after 30 seconds". + format(status), file=stderr) time.sleep(30) tries = tries + 1 @@ -1048,12 +1335,26 @@ def get_zones(conn, opts): # Gets the number of items in a partition def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total / num_partitions + num_slaves_this_zone = total // num_partitions if (total % num_partitions) - current_partitions > 0: num_slaves_this_zone += 1 return num_slaves_this_zone +# Gets the IP address, taking into account the --private-ips flag +def get_ip_address(instance, private_ips=False): + ip = instance.ip_address if not private_ips else \ + instance.private_ip_address + return ip + + +# Gets the DNS name, taking into account the --private-ips flag +def get_dns_name(instance, private_ips=False): + dns = instance.public_dns_name if not private_ips else \ + instance.private_ip_address + return dns + + def real_main(): (opts, action, cluster_name) = parse_args() @@ -1072,28 +1373,28 @@ def real_main(): if opts.identity_file is not None: if not os.path.exists(opts.identity_file): - print >> stderr,\ - "ERROR: The identity 
file '{f}' doesn't exist.".format(f=opts.identity_file) + print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file), + file=stderr) sys.exit(1) file_mode = os.stat(opts.identity_file).st_mode if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00': - print >> stderr, "ERROR: The identity file must be accessible only by you." - print >> stderr, 'You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file) + print("ERROR: The identity file must be accessible only by you.", file=stderr) + print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file), + file=stderr) sys.exit(1) if opts.instance_type not in EC2_INSTANCE_TYPES: - print >> stderr, "Warning: Unrecognized EC2 instance type for instance-type: {t}".format( - t=opts.instance_type) + print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format( + t=opts.instance_type), file=stderr) if opts.master_instance_type != "": if opts.master_instance_type not in EC2_INSTANCE_TYPES: - print >> stderr, \ - "Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( - t=opts.master_instance_type) + print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( + t=opts.master_instance_type), file=stderr) if opts.ebs_vol_num > 8: - print >> stderr, "ebs-vol-num cannot be greater than 8" + print("ebs-vol-num cannot be greater than 8", file=stderr) sys.exit(1) # Prevent breaking ami_prefix (/, .git and startswith checks) @@ -1102,15 +1403,22 @@ def real_main(): opts.spark_ec2_git_repo.endswith(".git") or \ not opts.spark_ec2_git_repo.startswith("https://github.com") or \ not opts.spark_ec2_git_repo.endswith("spark-ec2"): - print >> stderr, "spark-ec2-git-repo must be a github repo and it must not have a " \ - "trailing / or .git. " \ - "Furthermore, we currently only support forks named spark-ec2." + print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. 
" + "Furthermore, we currently only support forks named spark-ec2.", file=stderr) + sys.exit(1) + + if not (opts.deploy_root_dir is None or + (os.path.isabs(opts.deploy_root_dir) and + os.path.isdir(opts.deploy_root_dir) and + os.path.exists(opts.deploy_root_dir))): + print("--deploy-root-dir must be an absolute path to a directory that exists " + "on the local file system", file=stderr) sys.exit(1) try: conn = ec2.connect_to_region(opts.region) except Exception as e: - print >> stderr, (e) + print((e), file=stderr) sys.exit(1) # Select an AZ at random if it was not specified. @@ -1119,7 +1427,7 @@ def real_main(): if action == "launch": if opts.slaves <= 0: - print >> sys.stderr, "ERROR: You have to start at least 1 slave" + print("ERROR: You have to start at least 1 slave", file=sys.stderr) sys.exit(1) if opts.resume: (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) @@ -1134,26 +1442,27 @@ def real_main(): setup_cluster(conn, master_nodes, slave_nodes, opts, True) elif action == "destroy": - print "Are you sure you want to destroy the cluster %s?" % cluster_name - print "The following instances will be terminated:" (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - for inst in master_nodes + slave_nodes: - print "> %s" % inst.public_dns_name - msg = "ALL DATA ON ALL NODES WILL BE LOST!!\nDestroy cluster %s (y/N): " % cluster_name + if any(master_nodes + slave_nodes): + print("The following instances will be terminated:") + for inst in master_nodes + slave_nodes: + print("> %s" % get_dns_name(inst, opts.private_ips)) + print("ALL DATA ON ALL NODES WILL BE LOST!!") + + msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name) response = raw_input(msg) if response == "y": - print "Terminating master..." + print("Terminating master...") for inst in master_nodes: inst.terminate() - print "Terminating slaves..." 
+ print("Terminating slaves...") for inst in slave_nodes: inst.terminate() # Delete security groups as well if opts.delete_groups: - print "Deleting security groups (this will take some time)..." group_names = [cluster_name + "-master", cluster_name + "-slaves"] wait_for_cluster_state( conn=conn, @@ -1161,15 +1470,16 @@ def real_main(): cluster_instances=(master_nodes + slave_nodes), cluster_state='terminated' ) + print("Deleting security groups (this will take some time)...") attempt = 1 while attempt <= 3: - print "Attempt %d" % attempt + print("Attempt %d" % attempt) groups = [g for g in conn.get_all_security_groups() if g.name in group_names] success = True # Delete individual rules in all groups before deleting groups to # remove dependencies between them for group in groups: - print "Deleting rules in security group " + group.name + print("Deleting rules in security group " + group.name) for rule in group.rules: for grant in rule.grants: success &= group.revoke(ip_protocol=rule.ip_protocol, @@ -1182,11 +1492,12 @@ def real_main(): time.sleep(30) # Yes, it does have to be this long :-( for group in groups: try: - conn.delete_security_group(group.name) - print "Deleted security group " + group.name + # It is needed to use group_id to make it work with VPC + conn.delete_security_group(group_id=group.id) + print("Deleted security group %s" % group.name) except boto.exception.EC2ResponseError: success = False - print "Failed to delete security group " + group.name + print("Failed to delete security group %s" % group.name) # Unfortunately, group.revoke() returns True even if a rule was not # deleted, so this needs to be rerun if something fails @@ -1196,18 +1507,21 @@ def real_main(): attempt += 1 if not success: - print "Failed to delete all security groups after 3 tries." - print "Try re-running in a few minutes." 
+ print("Failed to delete all security groups after 3 tries.") + print("Try re-running in a few minutes.") elif action == "login": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - master = master_nodes[0].public_dns_name - print "Logging into master " + master + "..." - proxy_opt = [] - if opts.proxy_port is not None: - proxy_opt = ['-D', opts.proxy_port] - subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) + if not master_nodes[0].public_dns_name and not opts.private_ips: + print("Master has no public DNS name. Maybe you meant to specify --private-ips?") + else: + master = get_dns_name(master_nodes[0], opts.private_ips) + print("Logging into master " + master + "...") + proxy_opt = [] + if opts.proxy_port is not None: + proxy_opt = ['-D', opts.proxy_port] + subprocess.check_call( + ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) elif action == "reboot-slaves": response = raw_input( @@ -1217,15 +1531,18 @@ def real_main(): if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - print "Rebooting slaves..." + print("Rebooting slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: - print "Rebooting " + inst.id + print("Rebooting " + inst.id) inst.reboot() elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print master_nodes[0].public_dns_name + if not master_nodes[0].public_dns_name and not opts.private_ips: + print("Master has no public DNS name. Maybe you meant to specify --private-ips?") + else: + print(get_dns_name(master_nodes[0], opts.private_ips)) elif action == "stop": response = raw_input( @@ -1238,11 +1555,11 @@ def real_main(): if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - print "Stopping master..." 
+ print("Stopping master...") for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.stop() - print "Stopping slaves..." + print("Stopping slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: if inst.spot_instance_request_id: @@ -1252,11 +1569,11 @@ def real_main(): elif action == "start": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print "Starting slaves..." + print("Starting slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() - print "Starting master..." + print("Starting master...") for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() @@ -1266,18 +1583,29 @@ def real_main(): cluster_instances=(master_nodes + slave_nodes), cluster_state='ssh-ready' ) + + # Determine types of running instances + existing_master_type = master_nodes[0].instance_type + existing_slave_type = slave_nodes[0].instance_type + # Setting opts.master_instance_type to the empty string indicates we + # have the same instance type for the master and the slaves + if existing_master_type == existing_slave_type: + existing_master_type = "" + opts.master_instance_type = existing_master_type + opts.instance_type = existing_slave_type + setup_cluster(conn, master_nodes, slave_nodes, opts, False) else: - print >> stderr, "Invalid action: %s" % action + print("Invalid action: %s" % action, file=stderr) sys.exit(1) def main(): try: real_main() - except UsageError, e: - print >> stderr, "\nError:\n", e + except UsageError as e: + print("\nError:\n", e, file=stderr) sys.exit(1)