diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 00000000..79ea469e --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,15 @@ +version: 2.1 + +# Define the jobs we want to run for this project +jobs: + build: + docker: + - image: openjdk:8-jdk-oraclelinux7 + steps: + - run: echo "build job is not implemented" + +# Orchestrate our job run sequence +workflows: + build: + jobs: + - build \ No newline at end of file diff --git a/.gitignore b/.gitignore index cfe2c08a..bcf8c0f8 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,6 @@ project/plugins/project/ # Node node_modules + +# Spark-ec2 boto +tools/spark-ec2/lib diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..35ab3b28 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "tools/flintrock"] + path = tools/flintrock + url = git@github.com:chaordic/flintrock.git + branch = ignition_v1 diff --git a/README.md b/README.md index 8f30027e..8b395319 100644 --- a/README.md +++ b/README.md @@ -8,4 +8,4 @@ It also provides many utilities for Spark jobs and Scala programs in general. It should be used inside a project as a submodule. See https://github.com/chaordic/ignition-template for an example. 
# Getting started -See http://monkeys.chaordic.com.br/start-using-spark-with-ignition/ for quick-start tutorial +See [Start using Spark with Ignition!](http://monkeys.chaordic.com.br/2015/03/22/start-using-spark-with-ignition.html) for quick-start tutorial diff --git a/build.sbt b/build.sbt index 095c1228..c6d7dbdb 100644 --- a/build.sbt +++ b/build.sbt @@ -2,37 +2,31 @@ name := "Ignition-Core" version := "1.0" -scalaVersion := "2.10.4" +scalaVersion := "2.11.12" -scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature", "-Xfatal-warnings") - -ideaExcludeFolders += ".idea" - -ideaExcludeFolders += ".idea_modules" +scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature", "-Xfatal-warnings", "-Xlint", "-Ywarn-dead-code", "-Xmax-classfile-name", "130") // Because we can't run two spark contexts on same VM parallelExecution in Test := false -libraryDependencies += ("org.apache.spark" %% "spark-core" % "1.3.0" % "provided").exclude("org.apache.hadoop", "hadoop-client") - -libraryDependencies += ("org.apache.hadoop" % "hadoop-client" % "2.0.0-cdh4.7.1" % "provided") +test in assembly := {} -libraryDependencies += "com.github.nscala-time" %% "nscala-time" % "0.8.0" +libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.3" % "provided" -libraryDependencies += "org.scalatest" % "scalatest_2.10" % "2.0" +libraryDependencies += "org.apache.hadoop" % "hadoop-client" % "2.7.6" % "provided" -libraryDependencies += "org.scalaj" %% "scalaj-http" % "0.3.16" +libraryDependencies += "org.apache.hadoop" % "hadoop-aws" % "2.7.6" % "provided" -libraryDependencies += "org.scalaz" %% "scalaz-core" % "7.0.6" +libraryDependencies += "com.amazonaws" % "aws-java-sdk" % "1.7.4" % "provided" -libraryDependencies += "com.github.scopt" %% "scopt" % "3.2.0" +libraryDependencies += "org.scalaz" %% "scalaz-core" % "7.2.27" -libraryDependencies += "net.java.dev.jets3t" % "jets3t" % "0.7.1" +libraryDependencies += "com.github.scopt" %% "scopt" % "3.6.0" -resolvers += "Akka 
Repository" at "http://repo.akka.io/releases/" +libraryDependencies += "joda-time" % "joda-time" % "2.9.9" -resolvers += "Sonatype OSS Releases" at "http://oss.sonatype.org/content/repositories/releases/" +libraryDependencies += "org.joda" % "joda-convert" % "1.8.2" -resolvers += "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/" +libraryDependencies += "org.slf4j" % "slf4j-api" % "1.7.25" -resolvers += Resolver.sonatypeRepo("public") +libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.3" diff --git a/project/build.properties b/project/build.properties index be6c454f..7c58a83a 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=0.13.5 +sbt.version=1.2.6 diff --git a/project/plugins.sbt b/project/plugins.sbt deleted file mode 100644 index d5f371ab..00000000 --- a/project/plugins.sbt +++ /dev/null @@ -1,5 +0,0 @@ -addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.4.0") - -addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.10.2") - -addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") diff --git a/remote_hook.sh b/remote_hook.sh index 305a0ff6..903d618e 100755 --- a/remote_hook.sh +++ b/remote_hook.sh @@ -1,5 +1,6 @@ #!/bin/bash + # We suppose we are in a subdirectory of the root project DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" @@ -11,6 +12,8 @@ CONTROL_DIR="${5?Please give the Control Directory}" SPARK_MEM_PARAM="${6?Please give the Job Memory Size to use}" USE_YARN="${7?Please tell if we should use YARN (yes/no)}" NOTIFY_ON_ERRORS="${8?Please tell if we will notify on errors (yes/no)}" +DRIVER_HEAP_SIZE="${9?Please tell driver heap size to use}" +shift 9 JOB_WITH_TAG=${JOB_NAME}.${JOB_TAG} JOB_CONTROL_DIR="${CONTROL_DIR}/${JOB_WITH_TAG}" @@ -20,10 +23,18 @@ MY_USER=$(whoami) sudo mkdir -p "${JOB_CONTROL_DIR}" sudo chown $MY_USER "${JOB_CONTROL_DIR}" - RUNNING_FILE="${JOB_CONTROL_DIR}/RUNNING" +# This should be the first thing in the script to 
avoid the wait remote job thinking we died echo $$ > "${RUNNING_FILE}" + + +# Let us read the spark home even when the image doesn't give us the permission +sudo chmod o+rx /home/ec2-user +sudo chmod -R o+rx /home/ec2-user/spark + +mkdir -p /media/tmp/spark-events + notify_error_and_exit() { description="${1}" echo "Exiting because: ${description}" @@ -48,6 +59,37 @@ on_trap_exit() { rm -f "${RUNNING_FILE}" } +install_and_run_zeppelin() { + if [[ ! -d "zeppelin" ]]; then + wget "http://www-us.apache.org/dist/zeppelin/zeppelin-0.8.0/zeppelin-0.8.0-bin-all.tgz" -O zeppelin.tar.gz + mkdir zeppelin + tar xvzf zeppelin.tar.gz -C zeppelin --strip-components 1 > /tmp/zeppelin_install.log + fi + if [[ -f "zeppelin/bin/zeppelin.sh" ]]; then + export MASTER="${JOB_MASTER}" + export ZEPPELIN_PORT="8081" + export SPARK_HOME=$(get_first_present /root/spark /opt/spark ~/spark*/) + export SPARK_SUBMIT_OPTIONS="--jars ${JAR_PATH} --executor-memory ${SPARK_MEM_PARAM}" + zeppelin/bin/zeppelin.sh + else + notify_error_and_exit "Zeppelin installation not found" + fi +} + +install_and_run_jupyter() { + sudo yum -y install python3 python3-pip + sudo pip3 install jupyter pandas boto3 matplotlib numpy sklearn scipy toree + export SPARK_HOME=$(get_first_present /root/spark /opt/spark ~/spark*/) + export HADOOP_HOME=$(get_first_present /root/hadoop /opt/hadoop ~/hadoop*/) + export SPARK_CONF_DIR="${SPARK_HOME}/conf" + export HADOOP_CONF_DIR="${HADOOP_HOME}/conf" + export JOB_MASTER=${MASTER:-spark://${SPARK_MASTER_HOST}:7077} + export PYSPARK_PYTHON=$(which python3) + export PYSPARK_DRIVER_PYTHON=$(which jupyter) + export PYSPARK_DRIVER_PYTHON_OPTS="notebook --allow-root --ip=${SPARK_MASTER_HOST} --no-browser --port=8888" + sudo $(which jupyter) toree install --spark_home="${SPARK_HOME}" --spark_opts="--master ${JOB_MASTER} --executor-memory ${SPARK_MEM_PARAM} --driver-memory ${DRIVER_HEAP_SIZE}" + ${SPARK_HOME}/bin/pyspark --master "${JOB_MASTER}" --executor-memory "${SPARK_MEM_PARAM}" 
--driver-memory "${DRIVER_HEAP_SIZE}" +} trap "on_trap_exit" EXIT @@ -58,12 +100,14 @@ MAIN_CLASS="ignition.jobs.Runner" cd "${DIR}" || notify_error_and_exit "Internal script error for job ${JOB_WITH_TAG}" -JAR_PATH_SRC=$(echo "${DIR}"/*assembly*.jar) +JAR_PATH_SRC=$(ls -t "${DIR}"/*assembly*.jar | head -1) # most recent jar JAR_PATH="${JOB_CONTROL_DIR}/Ignition.jar" cp ${JAR_PATH_SRC} ${JAR_PATH} -export JOB_MASTER=${MASTER} +# If no $MASTER, then build a url using $SPARK_MASTER_HOST +export JOB_MASTER=${MASTER:-spark://${SPARK_MASTER_HOST}:7077} + if [[ "${USE_YARN}" == "yes" ]]; then export YARN_MODE=true @@ -73,14 +117,16 @@ if [[ "${USE_YARN}" == "yes" ]]; then export SPARK_WORKER_MEMORY=${SPARK_MEM_PARAM} fi - if [[ "${JOB_NAME}" == "shell" ]]; then - export ADD_JARS=${JAR_PATH} - sudo -E ${SPARK_HOME}/bin/spark-shell || notify_error_and_exit "Execution failed for shell" + ${SPARK_HOME}/bin/spark-shell --master "${JOB_MASTER}" --jars ${JAR_PATH} --driver-memory "${DRIVER_HEAP_SIZE}" --driver-java-options "-Djava.io.tmpdir=/media/tmp -verbose:gc -XX:-PrintGCDetails -XX:+PrintGCTimeStamps" --executor-memory "${SPARK_MEM_PARAM}" || notify_error_and_exit "Execution failed for shell" +elif [[ "${JOB_NAME}" == "zeppelin" ]]; then + install_and_run_zeppelin +elif [[ "${JOB_NAME}" == "jupyter" ]]; then + install_and_run_jupyter else JOB_OUTPUT="${JOB_CONTROL_DIR}/output.log" tail -F "${JOB_OUTPUT}" & - sudo -E "${SPARK_HOME}/bin/spark-submit" --master "${JOB_MASTER}" --driver-memory 25000M --driver-java-options "-Djava.io.tmpdir=/mnt -verbose:gc -XX:-PrintGCDetails -XX:+PrintGCTimeStamps" --class "${MAIN_CLASS}" ${JAR_PATH} "${JOB_NAME}" --runner-date "${JOB_DATE}" --runner-tag "${JOB_TAG}" --runner-user "${JOB_USER}" --runner-master "${JOB_MASTER}" --runner-executor-memory "${SPARK_MEM_PARAM}" >& "${JOB_OUTPUT}" || notify_error_and_exit "Execution failed for job ${JOB_WITH_TAG}" + ${SPARK_HOME}/bin/spark-submit --master "${JOB_MASTER}" --driver-memory 
"${DRIVER_HEAP_SIZE}" --driver-java-options "-Djava.io.tmpdir=/media/tmp -verbose:gc -XX:-PrintGCDetails -XX:+PrintGCTimeStamps " --class "${MAIN_CLASS}" ${JAR_PATH} "${JOB_NAME}" --runner-date "${JOB_DATE}" --runner-tag "${JOB_TAG}" --runner-user "${JOB_USER}" --runner-master "${JOB_MASTER}" --runner-executor-memory "${SPARK_MEM_PARAM}" "$@" >& "${JOB_OUTPUT}" || notify_error_and_exit "Execution failed for job ${JOB_WITH_TAG}" fi touch "${JOB_CONTROL_DIR}/SUCCESS" diff --git a/src/main/scala/ignition/core/jobs/CoreJobRunner.scala b/src/main/scala/ignition/core/jobs/CoreJobRunner.scala index aa4dcc76..eb1c7014 100644 --- a/src/main/scala/ignition/core/jobs/CoreJobRunner.scala +++ b/src/main/scala/ignition/core/jobs/CoreJobRunner.scala @@ -1,21 +1,31 @@ package ignition.core.jobs -import org.apache.spark.{SparkConf, SparkContext} -import org.joda.time.{DateTimeZone, DateTime} +import org.apache.spark.SparkContext +import org.apache.spark.sql.SparkSession +import org.joda.time.{DateTime, DateTimeZone} +import org.slf4j.{Logger, LoggerFactory} -import scala.util.Try +import scala.concurrent.Future object CoreJobRunner { + val logger: Logger = LoggerFactory.getLogger(getClass) + case class RunnerContext(sparkContext: SparkContext, + sparkSession: SparkSession, config: RunnerConfig) // Used to provide contextual logging def setLoggingContextValues(config: RunnerConfig): Unit = { - org.slf4j.MDC.put("setupName", config.setupName) - org.slf4j.MDC.put("tag", config.tag) - org.slf4j.MDC.put("user", config.user) + try { // yes, this may fail but we don't want everything to shut down + org.slf4j.MDC.put("setupName", config.setupName) + org.slf4j.MDC.put("tag", config.tag) + org.slf4j.MDC.put("user", config.user) + } catch { + case e: Throwable => + // cry + } } case class RunnerConfig(setupName: String = "nosetup", @@ -24,7 +34,7 @@ object CoreJobRunner { user: String = "nouser", master: String = "local[*]", executorMemory: String = "2G", - additionalArgs: Map[String, String] 
= Map.empty) + extraArgs: Map[String, String] = Map.empty) def runJobSetup(args: Array[String], jobsSetups: Map[String, (CoreJobRunner.RunnerContext => Unit, Map[String, String])], defaultSparkConfMap: Map[String, String]) { val parser = new scopt.OptionParser[RunnerConfig]("Runner") { @@ -49,8 +59,8 @@ object CoreJobRunner { c.copy(executorMemory = x) } - opt[(String, String)]('w', "runner-with-arg") unbounded() action { (x, c) => - c.copy(additionalArgs = c.additionalArgs ++ Map(x)) + opt[(String, String)]('w', "runner-extra") unbounded() action { (x, c) => + c.copy(extraArgs = c.extraArgs ++ Map(x)) } } @@ -65,27 +75,39 @@ object CoreJobRunner { val appName = s"${config.setupName}.${config.tag}" - val sparkConf = new SparkConf() - sparkConf.set("spark.executor.memory", config.executorMemory) + val builder = SparkSession.builder + builder.config("spark.executor.memory", config.executorMemory) + + builder.config("spark.eventLog.dir", "file:///media/tmp/spark-events") + + builder.master(config.master) + builder.appName(appName) - sparkConf.setMaster(config.master) - sparkConf.setAppName(appName) - - defaultSparkConfMap.foreach { case (k, v) => sparkConf.set(k, v) } + builder.config("spark.hadoop.mapred.output.committer.class", classOf[DirectOutputCommitter].getName()) - jobConf.foreach { case (k, v) => sparkConf.set(k, v) } + defaultSparkConfMap.foreach { case (k, v) => builder.config(k, v) } + + jobConf.foreach { case (k, v) => builder.config(k, v) } // Add logging context to driver setLoggingContextValues(config) - - val sc = new SparkContext(sparkConf) + try { + builder.enableHiveSupport() + } catch { + case t: Throwable => logger.warn("Failed to enable HIVE support", t) + } + + val session = builder.getOrCreate() + + val sc = session.sparkContext // Also try to propagate logging context to workers // TODO: find a more efficient and bullet-proof way val configBroadCast = sc.broadcast(config) + sc.parallelize(Range(1, 2000), numSlices = 2000).foreachPartition(_ 
=> setLoggingContextValues(configBroadCast.value)) - val context = RunnerContext(sc, config) + val context = RunnerContext(sc, session, config) try { jobSetup.apply(context) @@ -94,8 +116,14 @@ object CoreJobRunner { t.printStackTrace() System.exit(1) // force exit of all threads } - Try { sc.stop() } - System.exit(0) // force exit of all threads + + import scala.concurrent.ExecutionContext.Implicits.global + Future { + // If everything is fine, the system will shut down without the help of this thread and YARN will report success + // But sometimes it gets stuck, then it's necessary to use the force, but this may finish the job as failed on YARN + Thread.sleep(30 * 1000) + System.exit(0) // force exit of all threads + } } } } diff --git a/src/main/scala/ignition/core/jobs/DirectOutputCommitter.scala b/src/main/scala/ignition/core/jobs/DirectOutputCommitter.scala new file mode 100644 index 00000000..63611cf4 --- /dev/null +++ b/src/main/scala/ignition/core/jobs/DirectOutputCommitter.scala @@ -0,0 +1,75 @@ +package ignition.core.jobs + +// Code from: https://gist.github.com/aarondav/c513916e72101bbe14ec +// Suggested by: http://tech.grammarly.com/blog/posts/Petabyte-Scale-Text-Processing-with-Spark.html + +/* + * Copyright 2015 Databricks, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. You may obtain + * a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred._ + +/** + * OutputCommitter suitable for S3 workloads. 
Unlike the usual FileOutputCommitter, which + * writes files to a _temporary/ directory before renaming them to their final location, this + * simply writes directly to the final location. + * + * The FileOutputCommitter is required for HDFS + speculation, which allows only one writer at + * a time for a file (so two people racing to write the same file would not work). However, S3 + * supports multiple writers outputting to the same file, where visibility is guaranteed to be + * atomic. This is a monotonic operation: all writers should be writing the same data, so which + * one wins is immaterial. + * + * Code adapted from Ian Hummel's code from this PR: + * https://github.com/themodernlife/spark/commit/4359664b1d557d55b0579023df809542386d5b8c + */ +class DirectOutputCommitter extends OutputCommitter { + override def setupJob(jobContext: JobContext): Unit = { } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { } + + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = { + // We return true here to guard against implementations that do not handle false correctly. + // The meaning of returning false is not entirely clear, so it's possible to be interpreted + // as an error. Returning true just means that commitTask() will be called, which is a no-op. + true + } + + override def commitTask(taskContext: TaskAttemptContext): Unit = { } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { } + + /** + * Creates a _SUCCESS file to indicate the entire job was successful. + * This mimics the behavior of FileOutputCommitter, reusing the same file name and conf option. 
+ */ + override def commitJob(context: JobContext): Unit = { + val conf = context.getJobConf + if (shouldCreateSuccessFile(conf)) { + val outputPath = FileOutputFormat.getOutputPath(conf) + if (outputPath != null) { + val fileSys = outputPath.getFileSystem(conf) + val filePath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fileSys.create(filePath).close() + } + } + } + + /** By default, we do create the _SUCCESS file, but we allow it to be turned off. */ + private def shouldCreateSuccessFile(conf: JobConf): Boolean = { + conf.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true) + } +} \ No newline at end of file diff --git a/src/main/scala/ignition/core/jobs/ExecutionRetry.scala b/src/main/scala/ignition/core/jobs/ExecutionRetry.scala index 61daa523..7e5a3953 100644 --- a/src/main/scala/ignition/core/jobs/ExecutionRetry.scala +++ b/src/main/scala/ignition/core/jobs/ExecutionRetry.scala @@ -2,6 +2,8 @@ package ignition.core.jobs import scala.util.Try +object ExecutionRetry extends ExecutionRetry + trait ExecutionRetry { def executeRetrying[T](code: => T, maxExecutions: Int = 3): T = { diff --git a/src/main/scala/ignition/core/jobs/utils/RDDUtils.scala b/src/main/scala/ignition/core/jobs/utils/RDDUtils.scala index fc42ded5..ab08c3c7 100644 --- a/src/main/scala/ignition/core/jobs/utils/RDDUtils.scala +++ b/src/main/scala/ignition/core/jobs/utils/RDDUtils.scala @@ -1,16 +1,12 @@ package ignition.core.jobs.utils +import org.apache.spark.rdd.RDD import org.slf4j.LoggerFactory +import scalaz.{Success, Validation} +import scala.collection.mutable import scala.reflect._ -import org.apache.spark.rdd.{PairRDDFunctions, CoGroupedRDD, RDD} -import org.apache.spark.SparkContext._ -import org.apache.spark.Partitioner -import org.apache.spark -import org.joda.time.DateTime -import org.joda.time.format.DateTimeFormat - -import scalaz.{Success, Validation} +import scala.util.Random object RDDUtils { @@ -29,6 +25,12 @@ object RDDUtils { } } + 
implicit class SetRDDImprovements[V: ClassTag](rdd: RDD[Set[V]]) { + def flatten: RDD[V] = { + rdd.flatMap(x => x) + } + } + implicit class ValidatedRDDImprovements[A: ClassTag, B: ClassTag](rdd: RDD[Validation[A, B]]) { def mapSuccess(f: B => Validation[A, B]): RDD[Validation[A, B]] = { @@ -50,23 +52,10 @@ object RDDUtils { } implicit class RDDImprovements[V: ClassTag](rdd: RDD[V]) { - def incrementCounter(acc: spark.Accumulator[Int]): RDD[V] = { - rdd.map(x => { acc += 1; x }) - } - - def incrementCounterIf(cond: (V) => Boolean, acc: spark.Accumulator[Int]): RDD[V] = { - rdd.map(x => { if (cond(x)) acc += 1; x }) - } + def filterNot(p: V => Boolean): RDD[V] = rdd.filter(!p(_)) } implicit class PairRDDImprovements[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) { - def incrementCounter(acc: spark.Accumulator[Int]): RDD[(K, V)] = { - rdd.mapValues(x => { acc += 1; x }) - } - - def incrementCounterIf(cond: (K, V) => Boolean, acc: spark.Accumulator[Int]): RDD[(K, V)] = { - rdd.mapPreservingPartitions(x => { if(cond(x._1, x._2)) acc += 1; x._2 }) - } def flatMapPreservingPartitions[U: ClassTag](f: ((K, V)) => Seq[U]): RDD[(K, U)] = { rdd.mapPartitions[(K, U)](kvs => { @@ -80,14 +69,22 @@ object RDDUtils { }, preservesPartitioning = true) } - def groupByKeyAndTake(n: Int): RDD[(K, List[V])] = - rdd.aggregateByKey(List.empty[V])( + def collectValues[U: ClassTag](f: PartialFunction[V, U]): RDD[(K, U)] = { + rdd.filter { case (k, v) => f.isDefinedAt(v) }.mapValues(f) + } + + // loggingFactor: fraction (between 0 and 1) of the potential log messages that will actually be printed + // Big jobs produce too much logging and may eat up cluster disk space + def groupByKeyAndTake(n: Int, loggingFactor: Double = 0.5): RDD[(K, List[V])] = + rdd.aggregateByKey(mutable.ListBuffer.empty[V])( (lst, v) => if (lst.size >= n) { - logger.warn(s"Ignoring value '$v' due aggregation result of size '${lst.size}' is bigger then n = '$n'") + if (Random.nextDouble() < loggingFactor) + logger.warn(s"Ignoring value '$v' 
due aggregation result of size '${lst.size}' is bigger than n=$n") lst } else { - v :: lst + lst += v + lst }, (lstA, lstB) => if (lstA.size >= n) @@ -96,12 +93,16 @@ object RDDUtils { lstB else { if (lstA.size + lstB.size > n) { - logger.warn(s"Merging partition1=${lstA.size} with partition2=${lstB.size} and taking the first n=$n, sample1='${lstA.take(5)}', sample2='${lstB.take(5)}'") - (lstA ++ lstB).take(n) - } else - lstA ++ lstB + if (Random.nextDouble() < loggingFactor) + logger.warn(s"Merging partition1=${lstA.size} with partition2=${lstB.size} and taking the first n=$n, sample1='${lstA.take(5)}', sample2='${lstB.take(5)}'") + lstA ++= lstB + lstA.take(n) + } else { + lstA ++= lstB + lstA + } } - ) + ).mapValues(_.toList) // Note: completely unoptimized. We could use instead for better performance: // 1) sortByKey @@ -113,4 +114,4 @@ object RDDUtils { (lstA, lstB) => (lstA ++ lstB).sorted(ord).take(n)) } } -} +} \ No newline at end of file diff --git a/src/main/scala/ignition/core/jobs/utils/SparkContextUtils.scala b/src/main/scala/ignition/core/jobs/utils/SparkContextUtils.scala index 29c32112..e5155340 100644 --- a/src/main/scala/ignition/core/jobs/utils/SparkContextUtils.scala +++ b/src/main/scala/ignition/core/jobs/utils/SparkContextUtils.scala @@ -1,23 +1,89 @@ package ignition.core.jobs.utils -import java.util.Date - -import ignition.core.utils.ByteUtils -import org.apache.hadoop.io.LongWritable -import org.apache.spark.SparkContext -import org.apache.hadoop.fs.{FileStatus, Path, FileSystem} -import org.apache.spark.rdd.{UnionRDD, RDD} -import org.joda.time.{DateTimeZone, DateTime} +import java.io.InputStream + +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain +import com.amazonaws.services.s3.model.{ListObjectsRequest, ObjectListing, S3ObjectSummary} +import com.amazonaws.services.s3.{AmazonS3, AmazonS3Client} +import ignition.core.utils.CollectionUtils._ import ignition.core.utils.DateUtils._ +import ignition.core.utils.ExceptionUtils._ 
+import ignition.core.utils.{AutoCloseableIterator, ByteUtils, HadoopUtils} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.io.compress.CompressionCodecFactory +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.{Partitioner, SparkContext} +import org.joda.time.DateTime +import org.slf4j.LoggerFactory +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.io.{Codec, Source} import scala.reflect.ClassTag -import scala.util.Try +import scala.util.control.NonFatal +import scala.util.{Failure, Success, Try} object SparkContextUtils { + private case class BigFileSlice(index: Int) + + private case class HadoopFilePartition(size: Long, paths: Seq[String]) + + private case class IndexedPartitioner(numPartitions: Int, index: Map[Any, Int]) extends Partitioner { + override def getPartition(key: Any): Int = index(key) + } + + case class SizeBasedFileHandling(averageEstimatedCompressionRatio: Int = 8, compressedExtensions: Set[String] = Set(".gz")) { + + def isBig(f: HadoopFile, uncompressedBigSize: Long): Boolean = estimatedSize(f) >= uncompressedBigSize + + def estimatedSize(f: HadoopFile): Long = if (isCompressed(f)) + f.size * averageEstimatedCompressionRatio + else + f.size + + def isCompressed(f: HadoopFile): Boolean = compressedExtensions.exists(f.path.endsWith) + } + + private lazy val amazonS3ClientFromEnvironmentVariables: AmazonS3 = new AmazonS3Client(new DefaultAWSCredentialsProviderChain()) + + private def close(inputStream: InputStream, path: String): Unit = { + try { + inputStream.close() + } catch { + case NonFatal(ex) => + println(s"Fail to close resource from '$path': ${ex.getMessage} -- ${ex.getFullStackTraceString}") + } + } + + object S3SplittedPath { + val s3Pattern = 
"s3[an]?://([^/]+)(.+)".r + + def from(fullPath: String): Option[S3SplittedPath] = + fullPath match { + case s3Pattern(bucket, prefix) => Option(S3SplittedPath(bucket, prefix.dropWhile(_ == '/'))) + case _ => None + } + } + + case class S3SplittedPath(bucket: String, key: String) { + def join: String = s"s3a://$bucket/$key" + } + + case class HadoopFile(path: String, isDir: Boolean, size: Long) + + case class WithOptDate[E](date: Option[DateTime], value: E) implicit class SparkContextImprovements(sc: SparkContext) { + private lazy val logger = LoggerFactory.getLogger(getClass) + + lazy val _hadoopConf = sc.broadcast(sc.hadoopConfiguration.iterator().asScala.map { case entry => entry.getKey -> entry.getValue }.toMap) + private def getFileSystem(path: Path): FileSystem = { path.getFileSystem(sc.hadoopConfiguration) } @@ -28,7 +94,7 @@ object SparkContextUtils { for { path <- paths status <- Option(fs.globStatus(path)).getOrElse(Array.empty).toSeq - if status.isDirectory || !removeEmpty || status.getLen > 0 // remove empty files if necessary + if !removeEmpty || status.getLen > 0 || status.isDirectory // remove empty files if necessary } yield status } @@ -38,7 +104,7 @@ object SparkContextUtils { } // This call is equivalent to a ls -d in shell, but won't fail if part of a path matches nothing, - // For instance, given path = s3n://bucket/{a,b}, it will work fine if a exists but b is missing + // For instance, given path = s3a://bucket/{a,b}, it will work fine if a exists but b is missing def sortedGlobPath(_paths: Seq[String], removeEmpty: Boolean = true): Seq[String] = { val paths = _paths.flatMap(path => ignition.core.utils.HadoopUtils.getPathStrings(path)) paths.flatMap(p => getStatus(p, removeEmpty)).map(_.getPath.toString).distinct.sorted @@ -52,7 +118,7 @@ object SparkContextUtils { if (splittedPaths.size < minimumPaths) throw new Exception(s"Not enough paths found for $paths") - val rdds = splittedPaths.grouped(50).map(pathGroup => f(pathGroup.mkString(","))) 
+ val rdds = splittedPaths.grouped(5000).map(pathGroup => f(pathGroup.mkString(","))) new UnionRDD(sc, rdds.toList) } @@ -95,7 +161,6 @@ object SparkContextUtils { } - def getFilteredPaths(paths: Seq[String], requireSuccess: Boolean, inclusiveStartDate: Boolean, @@ -108,16 +173,14 @@ object SparkContextUtils { filterPaths(paths, requireSuccess, inclusiveStartDate, startDate, inclusiveEndDate, endDate, lastN, ignoreMalformedDates) } - lazy val hdfsPathPrefix = sc.master.replaceFirst("spark://(.*):7077", "hdfs://$1:9000/") def synchToHdfs(paths: Seq[String], pathsToRdd: (Seq[String], Int) => RDD[String], forceSynch: Boolean): Seq[String] = { val filesToOutput = 1500 def mapPaths(actionWhenNeedsSynching: (String, String) => Unit): Seq[String] = { paths.map(p => { - val hdfsPath = p.replace("s3n://", hdfsPathPrefix) + val hdfsPath = p.replaceFirst("s3[an]://", hdfsPathPrefix) if (forceSynch || getStatus(hdfsPath, false).isEmpty || getStatus(s"$hdfsPath/*", true).filterNot(_.isDirectory).size != filesToOutput) { - val _hdfsPath = new Path(hdfsPath) actionWhenNeedsSynching(p, hdfsPath) } hdfsPath @@ -130,6 +193,15 @@ object SparkContextUtils { } + @deprecated("It may incur heavy S3 costs and/or be slow with small files, use getParallelTextFiles instead", "2015-10-27") + def getTextFiles(paths: Seq[String], synchLocally: Boolean = false, forceSynch: Boolean = false, minimumPaths: Int = 1): RDD[String] = { + if (synchLocally) + processTextFiles(synchToHdfs(paths, processTextFiles, forceSynch), minimumPaths) + else + processTextFiles(paths, minimumPaths) + } + + @deprecated("It may incur heavy S3 costs and/or be slow with small files, use filterAndGetParallelTextFiles instead", "2015-10-27") def filterAndGetTextFiles(path: String, requireSuccess: Boolean = false, inclusiveStartDate: Boolean = true, @@ -144,10 +216,7 @@ object SparkContextUtils { val paths = getFilteredPaths(Seq(path), requireSuccess, inclusiveStartDate, startDate, inclusiveEndDate, endDate, lastN, 
ignoreMalformedDates) if (paths.size < minimumPaths) throw new Exception(s"Tried with start/end time equals to $startDate/$endDate for path $path but but the resulting number of paths $paths is less than the required") - else if (synchLocally) - processTextFiles(synchToHdfs(paths, processTextFiles, forceSynch), minimumPaths) - else - processTextFiles(paths, minimumPaths) + getTextFiles(paths, synchLocally, forceSynch, minimumPaths) } private def stringHadoopFile(paths: Seq[String], minimumPaths: Int): RDD[Try[String]] = { @@ -190,5 +259,441 @@ object SparkContextUtils { else objectHadoopFile(paths, minimumPaths) } + + private def readSmallFiles(smallFiles: List[HadoopFile], + maxBytesPerPartition: Long, + minPartitions: Int, + sizeBasedFileHandling: SizeBasedFileHandling): RDD[String] = { + val smallPartitionedFiles = sc.parallelize(smallFiles.map(_.path).map(file => file -> null), 2).partitionBy(createSmallFilesPartitioner(smallFiles, maxBytesPerPartition, minPartitions, sizeBasedFileHandling)) + val hadoopConf = _hadoopConf + smallPartitionedFiles.mapPartitions { files => + val conf = hadoopConf.value.foldLeft(new Configuration()) { case (acc, (k, v)) => acc.set(k, v); acc } + val codecFactory = new CompressionCodecFactory(conf) + files.map { case (path, _) => path } flatMap { path => + val hadoopPath = new Path(path) + val fileSystem = hadoopPath.getFileSystem(conf) + val inputStream = Option(codecFactory.getCodec(hadoopPath)) match { + case Some(compression) => compression.createInputStream(fileSystem.open(hadoopPath)) + case None => fileSystem.open(hadoopPath) + } + try { + Source.fromInputStream(inputStream)(Codec.UTF8).getLines().foldLeft(ArrayBuffer.empty[String])(_ += _) + } catch { + case NonFatal(ex) => + println(s"Failed to read resource from '$path': ${ex.getMessage} -- ${ex.getFullStackTraceString}") + throw new Exception(s"Failed to read resource from '$path': ${ex.getMessage} -- ${ex.getFullStackTraceString}") + } finally { + close(inputStream, 
path) + } + } + } + } + + private def readCompressedBigFile(file: HadoopFile, maxBytesPerPartition: Long, minPartitions: Int, + sizeBasedFileHandling: SizeBasedFileHandling, sampleCount: Int = 100): RDD[String] = { + val estimatedSize = sizeBasedFileHandling.estimatedSize(file) + val totalSlices = (estimatedSize / maxBytesPerPartition + 1).toInt + val slices = (0 until totalSlices).map(BigFileSlice.apply) + + val partitioner = { + val indexedPartitions: Map[Any, Int] = slices.map(s => s -> s.index).toMap + IndexedPartitioner(totalSlices, indexedPartitions) + } + val hadoopConf = _hadoopConf + + val partitionedSlices = sc.parallelize(slices.map(s => s -> null), 2).partitionBy(partitioner) + partitionedSlices.mapPartitions { slices => + val conf = hadoopConf.value.foldLeft(new Configuration()) { case (acc, (k, v)) => acc.set(k, v); acc } + val codecFactory = new CompressionCodecFactory(conf) + val hadoopPath = new Path(file.path) + val fileSystem = hadoopPath.getFileSystem(conf) + slices.flatMap { case (slice, _) => + try { + val inputStream = Option(codecFactory.getCodec(hadoopPath)) match { + case Some(compression) => compression.createInputStream(fileSystem.open(hadoopPath)) + case None => fileSystem.open(hadoopPath) + } + val lines = Source.fromInputStream(inputStream)(Codec.UTF8).getLines() + + val lineSample = lines.take(sampleCount).toList + val linesPerSlice = { + val sampleSize = lineSample.map(_.size).sum + val estimatedAverageLineSize = Math.round(sampleSize / sampleCount.toFloat) + val estimatedTotalLines = Math.round(estimatedSize / estimatedAverageLineSize.toFloat) + estimatedTotalLines / totalSlices + 1 + } + + val linesAfterSeek = (lineSample.toIterator ++ lines).drop(linesPerSlice * slice.index) + + val finalLines = if (slice.index + 1 == totalSlices) // last slice, read until the end + linesAfterSeek + else + linesAfterSeek.take(linesPerSlice) + + AutoCloseableIterator.wrap(finalLines, () => close(inputStream, s"${file.path}, slice $slice")) + } 
catch { + case NonFatal(e) => + throw new Exception(s"Error on read compressed big file, slice=$slice, file=$file", e) + } + } + } + } + + private def readBigFiles(bigFiles: List[HadoopFile], + maxBytesPerPartition: Long, + minPartitions: Int, + sizeBasedFileHandling: SizeBasedFileHandling): RDD[String] = { + def confWith(maxSplitSize: Long): Configuration = (_hadoopConf.value ++ Seq( + "mapreduce.input.fileinputformat.split.minsize" -> maxSplitSize.toString, + "mapreduce.input.fileinputformat.split.maxsize" -> maxSplitSize.toString)) + .foldLeft(new Configuration()) { case (acc, (k, v)) => acc.set(k, v); acc } + + def read(file: HadoopFile, conf: Configuration) = sc.newAPIHadoopFile[LongWritable, Text, TextInputFormat](conf = conf, fClass = classOf[TextInputFormat], + kClass = classOf[LongWritable], vClass = classOf[Text], path = file.path).map(pair => pair._2.toString) + + val confUncompressed = confWith(maxBytesPerPartition) + + val union = new UnionRDD(sc, bigFiles.map { file => + + if (sizeBasedFileHandling.isCompressed(file)) + readCompressedBigFile(file, maxBytesPerPartition, minPartitions, sizeBasedFileHandling) + else + read(file, confUncompressed) + }) + + if (union.partitions.size < minPartitions) + union.coalesce(minPartitions) + else + union + } + + def parallelReadTextFiles(files: List[HadoopFile], + maxBytesPerPartition: Long = 128 * 1000 * 1000, + minPartitions: Int = 100, + sizeBasedFileHandling: SizeBasedFileHandling = SizeBasedFileHandling(), + synchLocally: Option[String] = None, + forceSynch: Boolean = false): RDD[String] = { + val filteredFiles = files.filter(_.size > 0) + if (synchLocally.isDefined) + doSync(filteredFiles, maxBytesPerPartition = maxBytesPerPartition, minPartitions = minPartitions, synchLocally = synchLocally.get, + sizeBasedFileHandling = sizeBasedFileHandling, forceSynch = forceSynch) + else { + val (bigFiles, smallFiles) = filteredFiles.partition(f => sizeBasedFileHandling.isBig(f, maxBytesPerPartition)) + sc.union( + 
readSmallFiles(smallFiles, maxBytesPerPartition, minPartitions, sizeBasedFileHandling), + readBigFiles(bigFiles, maxBytesPerPartition, minPartitions, sizeBasedFileHandling)) + } + } + + private def createSmallFilesPartitioner(files: List[HadoopFile], maxBytesPerPartition: Long, minPartitions: Long, sizeBasedFileHandling: SizeBasedFileHandling): Partitioner = { + implicit val ordering: Ordering[HadoopFilePartition] = Ordering.by(p => -p.size) // Small partitions come first (highest priority) + + val pq: mutable.PriorityQueue[HadoopFilePartition] = mutable.PriorityQueue.empty + + (0L until minPartitions).foreach(_ => pq += HadoopFilePartition(0, Seq.empty)) + + val partitions = files.foldLeft(pq) { + case (acc, file) => + val fileSize = sizeBasedFileHandling.estimatedSize(file) + + acc.headOption.filter(bucket => bucket.size + fileSize < maxBytesPerPartition) match { + case Some(found) => + val updated = found.copy(size = found.size + fileSize, paths = file.path +: found.paths) + acc.tail += updated + case None => acc += HadoopFilePartition(fileSize, Seq(file.path)) + } + }.filter(_.paths.nonEmpty).toList // Remove empty partitions + + val indexedPartitions: Map[Any, Int] = partitions.zipWithIndex.flatMap { + case (bucket, index) => bucket.paths.map(path => path -> index) + }.toMap + + IndexedPartitioner(partitions.size, indexedPartitions) + } + + private def executeDriverList(paths: Seq[String]): List[HadoopFile] = { + val conf = _hadoopConf.value.foldLeft(new Configuration()) { case (acc, (k, v)) => acc.set(k, v); acc } + paths.flatMap { path => + val hadoopPath = new Path(path) + val fileSystem = hadoopPath.getFileSystem(conf) + val tryFind = try { + val status = fileSystem.getFileStatus(hadoopPath) + if (status.isDirectory) { + val sanitize = Option(fileSystem.listStatus(hadoopPath)).getOrElse(Array.empty) + Option(sanitize.map(status => HadoopFile(status.getPath.toString, status.isDirectory, status.getLen)).toList) + } else if (status.isFile) { + 
Option(List(HadoopFile(status.getPath.toString, status.isDirectory, status.getLen))) + } else { + None + } + } catch { + case e: java.io.FileNotFoundException => + None + } + + tryFind.getOrElse { + // Maybe is glob or not found + val sanitize = Option(fileSystem.globStatus(hadoopPath)).getOrElse(Array.empty) + sanitize.map(status => HadoopFile(status.getPath.toString, status.isDirectory, status.getLen)).toList + } + }.toList + } + + private def driverListFiles(path: String): List[HadoopFile] = { + def innerListFiles(remainingDirectories: List[HadoopFile]): List[HadoopFile] = { + if (remainingDirectories.isEmpty) { + Nil + } else { + val (dirs, files) = executeDriverList(remainingDirectories.map(_.path)).partition(_.isDir) + files ++ innerListFiles(dirs) + } + } + innerListFiles(List(HadoopFile(path, isDir = true, 0))) + } + + def s3ListCommonPrefixes(path: S3SplittedPath, delimiter: String = "/") + (implicit s3: AmazonS3): Stream[S3SplittedPath] = { + def inner(current: ObjectListing): Stream[String] = + if (current.isTruncated) { + logger.trace(s"list common prefixed truncated for ${path.bucket} ${path.key}: ${current.getCommonPrefixes}") + current.getCommonPrefixes.asScala.toStream ++ inner(s3.listNextBatchOfObjects(current)) + } else { + logger.trace(s"list common prefixed finished for ${path.bucket} ${path.key}: ${current.getCommonPrefixes}") + current.getCommonPrefixes.asScala.toStream + } + + val request = new ListObjectsRequest(path.bucket, path.key, null, delimiter, 1000) + inner(s3.listObjects(request)).map(prefix => path.copy(key = prefix)) + } + + def s3ListObjects(path: S3SplittedPath) + (implicit s3: AmazonS3): Stream[S3ObjectSummary] = { + def inner(current: ObjectListing): Stream[S3ObjectSummary] = + if (current.isTruncated) { + logger.trace(s"list objects truncated for ${path.bucket} ${path.key}: $current") + current.getObjectSummaries.asScala.toStream ++ inner(s3.listNextBatchOfObjects(current)) + } else { + logger.trace(s"list objects finished 
for ${path.bucket} ${path.key}") + current.getObjectSummaries.asScala.toStream + } + + inner(s3.listObjects(path.bucket, path.key)) + } + + def s3NarrowPaths(splittedPath: S3SplittedPath, + inclusiveStartDate: Boolean = true, + startDate: Option[DateTime] = None, + inclusiveEndDate: Boolean = true, + endDate: Option[DateTime] = None, + ignoreHours: Boolean = true) + (implicit s3: AmazonS3, pathDateExtractor: PathDateExtractor): Stream[WithOptDate[S3SplittedPath]] = { + + def isGoodDate(date: DateTime): Boolean = { + val startDateToCompare = startDate.map(date => if (ignoreHours) date.withTimeAtStartOfDay() else date) + val endDateToCompare = endDate.map(date => if (ignoreHours) date.withTime(23, 59, 59, 999) else date) + val goodStartDate = startDateToCompare.isEmpty || (inclusiveStartDate && date.saneEqual(startDateToCompare.get) || date.isAfter(startDateToCompare.get)) + val goodEndDate = endDateToCompare.isEmpty || (inclusiveEndDate && date.saneEqual(endDateToCompare.get) || date.isBefore(endDateToCompare.get)) + goodStartDate && goodEndDate + } + + def classifyPath(path: S3SplittedPath): Either[S3SplittedPath, (S3SplittedPath, DateTime)] = + Try(pathDateExtractor.extractFromPath(path.join)) match { + case Success(date) => Right(path -> date) + case Failure(_) => Left(path) + } + + val commonPrefixes = s3ListCommonPrefixes(splittedPath).map(classifyPath) + + logger.trace(s"s3NarrowPaths for $splittedPath, common prefixes: $commonPrefixes") + if (commonPrefixes.isEmpty) + Stream(WithOptDate(Try(pathDateExtractor.extractFromPath(splittedPath.join)).toOption, splittedPath)) + else + commonPrefixes.toStream.flatMap { + case Left(prefixWithoutDate) => + logger.trace(s"s3NarrowPaths prefixWithoutDate: $prefixWithoutDate") + s3NarrowPaths(prefixWithoutDate, inclusiveStartDate, startDate, inclusiveEndDate, endDate, ignoreHours) + case Right((prefixWithDate, date)) if isGoodDate(date) => Stream(WithOptDate(Option(date), prefixWithDate)) + case Right(_) => Stream.empty + 
} + } + + // Sorted from most recent to least recent path + private def sortPaths[P](paths: Stream[WithOptDate[P]]): Stream[WithOptDate[P]] = { + paths.sortBy { p => p.date.getOrElse(new DateTime(1970, 1, 1, 1, 1)) }(Ordering[DateTime].reverse) + } + + private def sortedS3List(path: String, + inclusiveStartDate: Boolean, + startDate: Option[DateTime], + inclusiveEndDate: Boolean, + endDate: Option[DateTime], + exclusionPattern: Option[String]) + (implicit s3: AmazonS3, dateExtractor: PathDateExtractor): Stream[WithOptDate[Array[S3ObjectSummary]]] = { + + + S3SplittedPath.from(path) match { + case Some(splittedPath) => + val prefixes: Stream[WithOptDate[S3SplittedPath]] = + s3NarrowPaths(splittedPath, inclusiveStartDate = inclusiveStartDate, inclusiveEndDate = inclusiveEndDate, + startDate = startDate, endDate = endDate) + + sortPaths(prefixes) + .map { case WithOptDate(date, path) => WithOptDate(date, s3ListObjects(path).toArray) } // Will list the most recent path first and only if needed the others + case _ => Stream.empty + } + } + + + def listAndFilterFiles(path: String, + requireSuccess: Boolean = false, + inclusiveStartDate: Boolean = true, + startDate: Option[DateTime] = None, + inclusiveEndDate: Boolean = true, + endDate: Option[DateTime] = None, + lastN: Option[Int] = None, + ignoreMalformedDates: Boolean = false, + endsWith: Option[String] = None, + exclusionPattern: Option[String] = Option(".*_temporary.*|.*_\\$folder.*"), + predicate: HadoopFile => Boolean = _ => true) + (implicit dateExtractor: PathDateExtractor): List[HadoopFile] = { + + def isSuccessFile(file: HadoopFile): Boolean = + file.path.endsWith("_SUCCESS") || file.path.endsWith("_FINISHED") + + def excludePatternValidation(file: HadoopFile): Boolean = + exclusionPattern.map(pattern => !file.path.matches(pattern)).getOrElse(true) + + def endsWithValidation(file: HadoopFile): Boolean = + endsWith.map { pattern => + file.path.endsWith(pattern) || isSuccessFile(file) + }.getOrElse(true) + + def 
dateValidation(files: WithOptDate[Array[HadoopFile]]): Boolean = { + val tryDate = files.date + if (tryDate.isEmpty && ignoreMalformedDates) + true + else if (tryDate.isEmpty) + throw new Exception(s"Not date found for path $path, expanded files: ${files.value.toList}, consider using ignoreMalformedDates=true if not date is expected on this path") + else { + val date = tryDate.get + val goodStartDate = startDate.isEmpty || (inclusiveStartDate && date.saneEqual(startDate.get) || date.isAfter(startDate.get)) + def goodEndDate = endDate.isEmpty || (inclusiveEndDate && date.saneEqual(endDate.get) || date.isBefore(endDate.get)) + goodStartDate && goodEndDate + } + } + + def successFileValidation(files: WithOptDate[Array[HadoopFile]]): Boolean = { + if (requireSuccess) + files.value.exists(isSuccessFile) + else + true + } + + def preValidations(files: WithOptDate[Array[HadoopFile]]): Option[WithOptDate[Array[HadoopFile]]] = { + if (!successFileValidation(files)) + None + else { + val filtered = files.copy(value = files.value + .filter(excludePatternValidation).filter(endsWithValidation).filter(predicate)) + if (filtered.value.isEmpty || !dateValidation(filtered)) + None + else + Option(filtered) + } + } + + val groupedAndSortedByDateFiles = sortedSmartList(path, inclusiveStartDate = inclusiveStartDate, inclusiveEndDate = inclusiveEndDate, + startDate = startDate, endDate = endDate, exclusionPattern = exclusionPattern).flatMap(preValidations) + + val allFiles = if (lastN.isDefined) + groupedAndSortedByDateFiles.take(lastN.get).flatMap(_.value) + else + groupedAndSortedByDateFiles.flatMap(_.value) + + allFiles.sortBy(_.path).toList + } + + def sortedSmartList(path: String, + inclusiveStartDate: Boolean = false, + startDate: Option[DateTime] = None, + inclusiveEndDate: Boolean = false, + endDate: Option[DateTime] = None, + exclusionPattern: Option[String] = None)(implicit pathDateExtractor: PathDateExtractor): Stream[WithOptDate[Array[HadoopFile]]] = { + + def 
toHadoopFile(s3Object: S3ObjectSummary): HadoopFile = + HadoopFile(s"s3a://${s3Object.getBucketName}/${s3Object.getKey}", isDir = false, s3Object.getSize) + + def listPath(path: String): Stream[WithOptDate[Array[HadoopFile]]] = { + if (path.startsWith("s3")) { + sortedS3List(path, inclusiveStartDate = inclusiveStartDate, startDate = startDate, inclusiveEndDate = inclusiveEndDate, + endDate = endDate, exclusionPattern = exclusionPattern)(amazonS3ClientFromEnvironmentVariables, pathDateExtractor).map { + case WithOptDate(date, paths) => WithOptDate(date, paths.map(toHadoopFile).toArray) + } + } else { + val pathsWithDate: Stream[WithOptDate[Iterable[HadoopFile]]] = driverListFiles(path) + .map(p => (Try { pathDateExtractor.extractFromPath(p.path) }.toOption, p)) + .groupByKey() + .map { case (date, path) => WithOptDate(date, path) } + .toStream + sortPaths(pathsWithDate).map { case WithOptDate(date, paths) => WithOptDate(date, paths.toArray) } + } + } + + HadoopUtils.getPathStrings(path).toStream.flatMap(listPath) + } + + def filterAndGetParallelTextFiles(path: String, + requireSuccess: Boolean = false, + inclusiveStartDate: Boolean = true, + startDate: Option[DateTime] = None, + inclusiveEndDate: Boolean = true, + endDate: Option[DateTime] = None, + lastN: Option[Int] = None, + ignoreMalformedDates: Boolean = false, + endsWith: Option[String] = None, + predicate: HadoopFile => Boolean = _ => true, + maxBytesPerPartition: Long = 128 * 1000 * 1000, + minPartitions: Int = 100, + sizeBasedFileHandling: SizeBasedFileHandling = SizeBasedFileHandling(), + minimumFiles: Int = 1, + synchLocally: Option[String] = None, + forceSynch: Boolean = false) + (implicit dateExtractor: PathDateExtractor): RDD[String] = { + + val foundFiles = listAndFilterFiles(path, requireSuccess, inclusiveStartDate, startDate, inclusiveEndDate, + endDate, lastN, ignoreMalformedDates, endsWith, predicate = predicate) + + if (foundFiles.size < minimumFiles) + throw new Exception(s"Tried with start/end 
time equals to $startDate/$endDate for path $path but but the resulting number of files $foundFiles is less than the required") + + parallelReadTextFiles(foundFiles, maxBytesPerPartition = maxBytesPerPartition, minPartitions = minPartitions, + sizeBasedFileHandling = sizeBasedFileHandling, synchLocally = synchLocally, forceSynch = forceSynch) + } + + private def doSync(hadoopFiles: List[HadoopFile], + synchLocally: String, + forceSynch: Boolean, + maxBytesPerPartition: Long, + minPartitions: Int, + sizeBasedFileHandling: SizeBasedFileHandling): RDD[String] = { + require(!synchLocally.contains("*"), "Globs are not supported on the sync key") + + def syncPath(suffix: String) = s"$hdfsPathPrefix/_core_ignition_sync_hdfs_cache/$suffix" + + val hashKey = Integer.toHexString(hadoopFiles.toSet.hashCode()) + + lazy val foundLocalPaths = getStatus(syncPath(s"$synchLocally/$hashKey/{_SUCCESS,_FINISHED}"), removeEmpty = false) + + val cacheKey = syncPath(s"$synchLocally/$hashKey") + + if (forceSynch || foundLocalPaths.isEmpty) { + delete(new Path(syncPath(s"$synchLocally/"))) + val data = parallelReadTextFiles(hadoopFiles, maxBytesPerPartition, minPartitions, sizeBasedFileHandling = sizeBasedFileHandling, synchLocally = None) + data.saveAsTextFile(cacheKey) + } + + sc.textFile(cacheKey) + } + } } diff --git a/src/main/scala/ignition/core/testsupport/spark/LocalSparkContext.scala b/src/main/scala/ignition/core/testsupport/spark/LocalSparkContext.scala index 2edb28e7..814f565d 100644 --- a/src/main/scala/ignition/core/testsupport/spark/LocalSparkContext.scala +++ b/src/main/scala/ignition/core/testsupport/spark/LocalSparkContext.scala @@ -21,13 +21,12 @@ import _root_.io.netty.util.internal.logging.{InternalLoggerFactory, Slf4JLogger import org.apache.spark.SparkContext import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Suite} -/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. 
*/ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ override def beforeAll() { - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) + InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) super.beforeAll() } diff --git a/src/main/scala/ignition/core/testsupport/spark/SharedSparkContext.scala b/src/main/scala/ignition/core/testsupport/spark/SharedSparkContext.scala index 314d5442..4fa5756b 100644 --- a/src/main/scala/ignition/core/testsupport/spark/SharedSparkContext.scala +++ b/src/main/scala/ignition/core/testsupport/spark/SharedSparkContext.scala @@ -33,6 +33,7 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => //Logger.getRootLogger().removeAllAppenders(); //Logger.getRootLogger().addAppender(new NullAppender()); _sc = new SparkContext("local", "test", conf) + _sc.setLogLevel("OFF") super.beforeAll() } diff --git a/src/main/scala/ignition/core/utils/AutoCloseableIterator.scala b/src/main/scala/ignition/core/utils/AutoCloseableIterator.scala new file mode 100644 index 00000000..4e3db808 --- /dev/null +++ b/src/main/scala/ignition/core/utils/AutoCloseableIterator.scala @@ -0,0 +1,67 @@ +package ignition.core.utils + +import scala.util.Try +import scala.util.control.NonFatal + +object AutoCloseableIterator { + case object empty extends AutoCloseableIterator[Nothing] { + override def naiveHasNext() = false + override def naiveNext() = throw new Exception("Empty AutoCloseableIterator") + override def naiveClose() = {} + } + + def wrap[T](iterator: Iterator[T], doClose: () => Unit = () => ()): AutoCloseableIterator[T] = new AutoCloseableIterator[T] { + override def naiveClose(): Unit = doClose() + override def naiveHasNext(): Boolean = iterator.hasNext + override def naiveNext(): T = iterator.next() + } +} + +trait AutoCloseableIterator[T] extends Iterator[T] with AutoCloseable { + // Naive functions should be implemented by the user as in a 
standard Iterator/AutoCloseable + def naiveHasNext(): Boolean + def naiveNext(): T + def naiveClose(): Unit + + var closed = false + + // hasNext closes the iterator and handles the case where it is already closed + override def hasNext: Boolean = if (closed) + false + else { + val naiveResult = try { + naiveHasNext + } catch { + case NonFatal(e) => + Try { close } + throw e + } + if (naiveResult) + true + else { + close // auto close when exhausted + false + } + } + + // next closes the iterator and handles the case where it is already closed + override def next(): T = if (closed) + throw new RuntimeException("Trying to get next element on a closed iterator") + else if (hasNext) + try { + naiveNext + } catch { + case NonFatal(e) => + Try { close } + throw e + } + else + throw new RuntimeException("Trying to get next element on an exhausted iterator") + + override def close() = if (!closed) { + closed = true + naiveClose + } + + override def finalize() = Try { close } +} diff --git a/src/main/scala/ignition/core/utils/BetterTrace.scala b/src/main/scala/ignition/core/utils/BetterTrace.scala new file mode 100644 index 00000000..9c91ca05 --- /dev/null +++ b/src/main/scala/ignition/core/utils/BetterTrace.scala @@ -0,0 +1,15 @@ +package ignition.core.utils + +import ignition.core.utils.ExceptionUtils._ +// Used mainly to augment scalacheck traces in scalatest +trait BetterTrace { + def fail(message: String): Nothing = throw new NotImplementedError(message) + + def withBetterTrace(block: => Unit): Unit = + try { + block + } catch { + case t: Throwable => fail(s"${t.getMessage}: ${t.getFullStackTraceString}") + } + +} diff --git a/src/main/scala/ignition/core/utils/CollectionUtils.scala b/src/main/scala/ignition/core/utils/CollectionUtils.scala index 27977270..2405c7ef 100644 --- a/src/main/scala/ignition/core/utils/CollectionUtils.scala +++ b/src/main/scala/ignition/core/utils/CollectionUtils.scala @@ -1,12 +1,39 @@ package ignition.core.utils -import 
scala.collection.{TraversableLike, IterableLike} -import scala.collection.generic.CanBuildFrom -import scala.language.implicitConversions import scalaz.Validation +import scala.collection.generic.CanBuildFrom +import scala.collection.{IterableLike, TraversableLike} + object CollectionUtils { + implicit class SeqImprovements[A](xs: Seq[A]) { + def orElseIfEmpty[B >: A](alternative: => Seq[B]): Seq[B] = { + if (xs.nonEmpty) + xs + else + alternative + } + + def mostFrequentOption: Option[A] = { + xs.groupBy(identity).maxByOption(_._2.size).map(_._1) + } + } + implicit class TraversableOnceImprovements[A](xs: TraversableOnce[A]) { + def maxOption(implicit cmp: Ordering[A]): Option[A] = { + if (xs.isEmpty) + None + else + Option(xs.max) + } + + def minOption(implicit cmp: Ordering[A]): Option[A] = { + if (xs.isEmpty) + None + else + Option(xs.min) + } + def maxByOption[B](f: A => B)(implicit cmp: Ordering[B]): Option[A] = { if (xs.isEmpty) None @@ -20,6 +47,13 @@ object CollectionUtils { else Option(xs.minBy(f)) } + + } + + + + implicit class TraversableOnceLong(xs: TraversableOnce[Long]) { + def toBag(): IntBag = IntBag.from(xs) } implicit class TraversableLikeImprovements[A, Repr](xs: TraversableLike[A, Repr]) { @@ -59,6 +93,7 @@ object CollectionUtils { builder.result } + } implicit class ValidatedIterableLike[T, R, Repr <: IterableLike[Validation[R, T], Repr]](seq: IterableLike[Validation[R, T], Repr]) { @@ -102,5 +137,13 @@ object CollectionUtils { .mapValues(_.map { case (k, v) => v }.reduce(fn)) .toList } + def values: List[V] = + iterable.map { case (k, v) => v }.toList + } + + + implicit class CollectionMap[K, V <: TraversableOnce[Any]](map: Map[K, V]) { + def removeEmpty(): Map[K, V] = + map.filter { case (k, v) => v.nonEmpty } } } diff --git a/src/main/scala/ignition/core/utils/DateUtils.scala b/src/main/scala/ignition/core/utils/DateUtils.scala index 231817c7..71ec771f 100644 --- a/src/main/scala/ignition/core/utils/DateUtils.scala +++ 
b/src/main/scala/ignition/core/utils/DateUtils.scala @@ -1,6 +1,8 @@ package ignition.core.utils -import org.joda.time.{Period, DateTimeZone, DateTime} +import java.sql.Timestamp + +import org.joda.time.{DateTime, DateTimeZone, Period, Seconds} import org.joda.time.format.ISODateTimeFormat object DateUtils { @@ -9,6 +11,10 @@ object DateUtils { implicit def dateTimeOrdering: Ordering[DateTime] = Ordering.fromLessThan(_ isBefore _) implicit def periodOrdering: Ordering[Period] = Ordering.fromLessThan(_.toStandardSeconds.getSeconds < _.toStandardSeconds.getSeconds) + implicit def timestampOrdering: Ordering[Timestamp] = new Ordering[Timestamp] { + def compare(x: Timestamp, y: Timestamp): Int = x compareTo y + } + implicit class DateTimeImprovements(val dateTime: DateTime) { def toIsoString = isoDateTimeFormatter.print(dateTime) @@ -20,5 +26,16 @@ object DateUtils { def isEqualOrBefore(other: DateTime) = dateTime.isBefore(other) || dateTime.saneEqual(other) + + def isBetween(start: DateTime, end: DateTime) = + dateTime.isAfter(start) && dateTime.isEqualOrBefore(end) + } + + implicit class SecondsImprovements(val seconds: Seconds) { + + implicit def toScalaDuration: scala.concurrent.duration.FiniteDuration = { + scala.concurrent.duration.Duration(seconds.getSeconds, scala.concurrent.duration.SECONDS) + } + } } diff --git a/src/main/scala/ignition/core/utils/ExceptionUtils.scala b/src/main/scala/ignition/core/utils/ExceptionUtils.scala new file mode 100644 index 00000000..1ae33568 --- /dev/null +++ b/src/main/scala/ignition/core/utils/ExceptionUtils.scala @@ -0,0 +1,9 @@ +package ignition.core.utils + +object ExceptionUtils { + + implicit class ExceptionImprovements(e: Throwable) { + def getFullStackTraceString(): String = org.apache.commons.lang.exception.ExceptionUtils.getFullStackTrace(e) + } + +} diff --git a/src/main/scala/ignition/core/utils/FutureUtils.scala b/src/main/scala/ignition/core/utils/FutureUtils.scala index 068d63bc..4054f750 100644 --- 
a/src/main/scala/ignition/core/utils/FutureUtils.scala +++ b/src/main/scala/ignition/core/utils/FutureUtils.scala @@ -1,18 +1,51 @@ package ignition.core.utils -import scala.concurrent.{ExecutionContext, Future, Promise} -import scala.util.{Failure, Success} +import scala.concurrent.{ExecutionContext, Future, Promise, blocking} +import scala.util.control.NonFatal +import scala.util.{Failure, Success, Try} object FutureUtils { + def blockingFuture[T](body: =>T)(implicit ec: ExecutionContext): Future[T] = Future { blocking { body } } + + implicit class FutureImprovements[V](future: Future[V]) { def toOptionOnFailure(errorHandler: (Throwable) => Option[V])(implicit ec: ExecutionContext): Future[Option[V]] = { future.map(Option.apply).recover { case t => errorHandler(t) } } + + /** + * Appear to be redundant. But its the only way to map a future with + * Success and Failure in same algorithm without split it to use map/recover + * or transform. + * + * future.asTry.map { case Success(v) => 1; case Failure(e) => 0 } + * + * instead + * + * future.map(i=>1).recover(case _: Exception => 0) + * + */ + def asTry()(implicit ec: ExecutionContext) : Future[Try[V]] = { + future.map(v => Success(v)).recover { case NonFatal(e) => Failure(e) } + } + + } + + implicit class TryFutureImprovements[V](future: Try[Future[V]]) { + // Works like asTry(), but will also wrap the outer Try inside the Future + def asFutureTry()(implicit ec: ExecutionContext): Future[Try[V]] = { + future match { + case Success(f) => + f.asTry() + case Failure(e) => + Future.successful(Failure(e)) + } + } } implicit class FutureGeneratorImprovements[V](generator: Iterable[() => Future[V]]){ - def toLazyIterable(batchSize: Int = 1)(implicit ec: ExecutionContext): Iterable[Future[V]] = new Iterable[Future[V]] { + def toLazyIterable(batchSize: Int = 1): Iterable[Future[V]] = new Iterable[Future[V]] { override def iterator = new Iterator[Future[V]] { val generatorIterator = generator.toIterator var currentBatch: 
List[Future[V]] = List.empty diff --git a/src/main/scala/ignition/core/utils/IntBag.scala b/src/main/scala/ignition/core/utils/IntBag.scala new file mode 100644 index 00000000..1dfce82a --- /dev/null +++ b/src/main/scala/ignition/core/utils/IntBag.scala @@ -0,0 +1,63 @@ +package ignition.core.utils + +import ignition.core.utils.CollectionUtils._ + +object IntBag { + def from(numbers: TraversableOnce[Long]): IntBag = { + val histogram = scala.collection.mutable.HashMap.empty[Long, Long] + numbers.foreach(n => histogram += (n -> (histogram.getOrElse(n, 0L) + 1))) + IntBag(histogram) + } + + val empty = from(Seq.empty) +} + +case class IntBag(histogram: collection.Map[Long, Long]) { + + def +(n: Long) = + this ++ IntBag.from(n :: Nil) + + def ++(other: IntBag): IntBag = { + val newHistogram = scala.collection.mutable.HashMap.empty[Long, Long] + (histogram.keySet ++ other.histogram.keySet).foreach(k => newHistogram += (k -> (histogram.getOrElse(k, 0L) + other.histogram.getOrElse(k, 0L)))) + new IntBag(newHistogram) + } + + + def median: Option[Long] = { + percentile(50) + } + + def percentile(n: Double): Option[Long] = { + require(n > 0 && n <= 100) + histogram.keys.maxOption.flatMap { max => + val total = histogram.values.sum + val position = total * (n / 100) + + val accumulatedFrequency = (0L to max).scanLeft(0L) { case (sumFreq, k) => sumFreq + histogram.getOrElse(k, 0L) }.zipWithIndex + accumulatedFrequency.collectFirst { case (sum, k) if sum >= position => k - 1 } + } + } + + def count: Long = histogram.values.sum + + def sum: Long = histogram.map { case (k, f) => k * f }.sum + + def avg: Option[Long] = { + if (histogram.nonEmpty) + Option(sum / count) + else + None + } + + def min: Option[Long] = { + histogram.keys.minOption + } + + def max: Option[Long] = { + histogram.keys.maxOption + } + + override def toString: String = s"IntBag(median=$median, count=$count, sum=$sum, avg=$avg, min=$min, max=$max)" + +} diff --git 
a/src/main/scala/ignition/core/utils/S3Client.scala b/src/main/scala/ignition/core/utils/S3Client.scala deleted file mode 100644 index f02d7acd..00000000 --- a/src/main/scala/ignition/core/utils/S3Client.scala +++ /dev/null @@ -1,51 +0,0 @@ -package ignition.core.utils - -import java.util.Properties - -import org.jets3t.service.impl.rest.httpclient.RestS3Service -import org.jets3t.service.model.S3Object -import org.jets3t.service.security.AWSCredentials -import org.jets3t.service.{Constants, Jets3tProperties} - - -class S3Client { - - val jets3tProperties = { - val jets3tProperties = Jets3tProperties.getInstance(Constants.JETS3T_PROPERTIES_FILENAME) - val properties = new Properties() -// properties.put("httpclient.max-connections", "2") // The maximum number of simultaneous connections to allow globally -// properties.put("httpclient.retry-max", "10") // How many times to retry connections when they fail with IO errors -// properties.put("httpclient.socket-timeout-ms", "30000") // How many milliseconds to wait before a connection times out. 0 means infinity. 
- - jets3tProperties.loadAndReplaceProperties(properties, "ignition'") - jets3tProperties - } - - val service = new RestS3Service( - new AWSCredentials(System.getenv("AWS_ACCESS_KEY_ID"), System.getenv("AWS_SECRET_ACCESS_KEY")), - null, null, jets3tProperties - ) - - def writeContent(bucket: String, key: String, content: String): S3Object = { - val obj = new S3Object(key, content) - obj.setContentType("text/plain") - service.putObject(bucket, obj) - } - - def readContent(bucket: String, key: String): S3Object = { - service.getObject(bucket, key, null, null, null, null, null, null) - } - - def list(bucket: String, key: String): Array[S3Object] = { - service.listObjects(bucket, key, null, 99999L) - } - - def fileExists(bucket: String, key: String): Boolean = { - try { - service.getObjectDetails(bucket, key, null, null, null, null) - true - } catch { - case e: org.jets3t.service.S3ServiceException if e.getResponseCode == 404 => false - } - } -} diff --git a/src/main/scala/ignition/core/utils/URLUtils.scala b/src/main/scala/ignition/core/utils/URLUtils.scala new file mode 100644 index 00000000..4a0ae28c --- /dev/null +++ b/src/main/scala/ignition/core/utils/URLUtils.scala @@ -0,0 +1,24 @@ +package ignition.core.utils + +import java.net.{URLDecoder, URLEncoder} + +import org.apache.http.client.utils.URIBuilder + +import scala.util.Try + +object URLUtils { + + // Due to ancient standards, Java will encode space as + instead of using percent. 
+ // + // See: + // http://stackoverflow.com/questions/1634271/url-encoding-the-space-character-or-20 + // https://docs.oracle.com/javase/7/docs/api/java/net/URLEncoder.html#encode(java.lang.String,%20java.lang.String) + def sanitizePathSegment(segment: String): Try[String] = + Try { URLEncoder.encode(URLDecoder.decode(segment, "UTF-8"), "UTF-8").replace("+", "%20") } + + def addParametersToUrl(url: String, partnerParams: Map[String, String]): String = { + val builder = new URIBuilder(url.trim) + partnerParams.foreach { case (k, v) => builder.addParameter(k, v) } + builder.build().toString + } +} diff --git a/src/test/resources/log4j.properties b/src/test/resources/log4j.properties new file mode 100644 index 00000000..8455c4cf --- /dev/null +++ b/src/test/resources/log4j.properties @@ -0,0 +1,21 @@ +log4j.rootCategory=ERROR, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Ignition! 
+log4j.logger.ignition=ERROR + +# Disable annoying logger that is always logging an error message on ExpiringMultipleLevelCacheSpec test +log4j.logger.ignition.core.cache.ExpiringMultiLevelCache=OFF + +# Spark, Hadoop, etc +log4j.logger.org.apache=ERROR + +# Akka +log4j.logger.Remoting=ERROR + +# Jetty +log4j.logger.org.eclipse.jetty=ERROR +log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +org.eclipse.jetty.LEVEL=ERROR diff --git a/src/test/scala/ignition/core/jobs/utils/RDDUtilsSpec.scala b/src/test/scala/ignition/core/jobs/utils/RDDUtilsSpec.scala index a00e5de8..eed298b6 100644 --- a/src/test/scala/ignition/core/jobs/utils/RDDUtilsSpec.scala +++ b/src/test/scala/ignition/core/jobs/utils/RDDUtilsSpec.scala @@ -6,20 +6,21 @@ import org.scalatest._ import scala.util.Random -class RDDUtilsSpec extends FlatSpec with ShouldMatchers with SharedSparkContext { +class RDDUtilsSpec extends FlatSpec with Matchers with SharedSparkContext { "RDDUtils" should "provide groupByKeyAndTake" in { - val take = 5 - val rdd = sc.parallelize((1 to Random.nextInt(40) + 10).map(x => "a" -> Random.nextInt()) ++ (1 to Random.nextInt(40) + 10).map(x => "b" -> Random.nextInt())) - val result = rdd.groupByKeyAndTake(take).collect().toMap - result("a").length shouldBe take - result("b").length shouldBe take + (10 to 60 by 10).foreach { take => + val rdd = sc.parallelize((1 to 400).map(x => "a" -> Random.nextInt()) ++ (1 to 400).map(x => "b" -> Random.nextInt()), 60) + val result = rdd.groupByKeyAndTake(take).collect().toMap + result("a").length shouldBe take + result("b").length shouldBe take + } } it should "provide groupByKeyAndTakeOrdered" in { - val take = 5 - val aList = (1 to Random.nextInt(40) + 10).map(x => "a" -> Random.nextInt()).toList - val bList = (1 to Random.nextInt(40) + 10).map(x => "b" -> Random.nextInt()).toList + val take = 50 + val aList = (1 to Random.nextInt(400) + 100).map(x => "a" -> Random.nextInt()).toList + val bList = (1 to Random.nextInt(400) 
+ 100).map(x => "b" -> Random.nextInt()).toList val rdd = sc.parallelize(aList ++ bList) val result = rdd.groupByKeyAndTakeOrdered(take).collect().toMap result("a") shouldBe aList.map(_._2).sorted.take(take) diff --git a/src/test/scala/ignition/core/utils/CollectionUtilsSpec.scala b/src/test/scala/ignition/core/utils/CollectionUtilsSpec.scala index c19579ce..26757c26 100644 --- a/src/test/scala/ignition/core/utils/CollectionUtilsSpec.scala +++ b/src/test/scala/ignition/core/utils/CollectionUtilsSpec.scala @@ -3,7 +3,7 @@ package ignition.core.utils import org.scalatest._ import CollectionUtils._ -class CollectionUtilsSpec extends FlatSpec with ShouldMatchers { +class CollectionUtilsSpec extends FlatSpec with Matchers { case class MyObj(property: String, value: String) "CollectionUtils" should "provide distinctBy" in { @@ -32,7 +32,23 @@ class CollectionUtilsSpec extends FlatSpec with ShouldMatchers { list.compressBy(_.value) shouldBe List(MyObj("p1", "v1"), MyObj("p1", "v2")) } + it should "provide orElseIfEmpty" in { + Seq.empty[String].orElseIfEmpty(Seq("something")) shouldBe Seq("something") + Seq("not empty").orElseIfEmpty(Seq("something")) shouldBe Seq("not empty") + } + + it should "provide maxOption and minOption" in { + Seq.empty[Int].maxOption shouldBe None + Seq(1, 3, 2).maxOption shouldBe Some(3) + Seq.empty[Int].minOption shouldBe None + Seq(1, 3, 2).minOption shouldBe Some(1) + } + + it should "provide mostFrequentOption" in { + Seq.empty[String].mostFrequentOption shouldBe None + Seq("a", "b", "b", "c", "a", "b").mostFrequentOption shouldBe Option("b") + } } diff --git a/src/test/scala/ignition/core/utils/FutureUtilsSpec.scala b/src/test/scala/ignition/core/utils/FutureUtilsSpec.scala index 8c2b3270..c10b50d5 100644 --- a/src/test/scala/ignition/core/utils/FutureUtilsSpec.scala +++ b/src/test/scala/ignition/core/utils/FutureUtilsSpec.scala @@ -1,43 +1,56 @@ package ignition.core.utils -import FutureUtils._ +import ignition.core.utils.FutureUtils._ 
import org.scalatest._ +import org.scalatest.concurrent.ScalaFutures -import scala.concurrent.{Await, Future} -import scala.concurrent.duration._ import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future +import scala.concurrent.duration._ -class FutureUtilsSpec extends FlatSpec with ShouldMatchers { +class FutureUtilsSpec extends FlatSpec with Matchers with ScalaFutures { "FutureUtils" should "provide toLazyIterable" in { val timesCalled = collection.mutable.Map.empty[Int, Int].withDefaultValue(0) - val generators = (0 until 20).map { i => () => Future { timesCalled(i) += 1 ; i } } + val generators = (0 until 20).map { i => () => Future { synchronized { timesCalled(i) += 1 } ; i } } val iterable = generators.toLazyIterable() val iterator = iterable.toIterator - timesCalled.forall { case (key, count) => count == 0 } shouldBe true + timesCalled.forall { case (_, count) => count == 0 } shouldBe true - Await.result(iterator.next(), 2.seconds) + whenReady(iterator.next(), timeout(2.seconds)) { _ => () } timesCalled(0) shouldBe 1 (1 until 20).foreach { i => timesCalled(i) shouldBe 0 } - Await.result(Future.sequence(iterator), 5.seconds).toList shouldBe (1 until 20).toList + whenReady(Future.sequence(iterator), timeout(5.seconds)) { result => + result.toList shouldBe (1 until 20).toList + } (0 until 20).foreach { i => timesCalled(i) shouldBe 1 } } it should "provide collectAndTake" in { val timesCalled = collection.mutable.Map.empty[Int, Int].withDefaultValue(0) - val iterable = (0 until 30).map { i => () => Future { timesCalled(i) += 1 ; i } }.toLazyIterable() + val iterable = (0 until 30).map { i => + () => + Future { + synchronized { + timesCalled(i) += 1 + } + i + } + }.toLazyIterable() val expectedRange = Range(5, 15) - val result = Await.result(iterable.collectAndTake({ case i if expectedRange.contains(i) => i }, n = expectedRange.size), 5.seconds) - result shouldBe expectedRange.toList + val f: Future[List[Int]] = 
iterable.collectAndTake({ case i if expectedRange.contains(i) => i }, n = expectedRange.size) + + whenReady(f, timeout(5.seconds)) { result => + result shouldBe expectedRange.toList + } (0 until 20).foreach { i => timesCalled(i) shouldBe 1 } // 2 batches of size 10 (20 until 30).foreach { i => timesCalled(i) shouldBe 0 } // last batch won't be ran - } } diff --git a/src/test/scala/ignition/core/utils/IntBagSpec.scala b/src/test/scala/ignition/core/utils/IntBagSpec.scala new file mode 100644 index 00000000..f577237e --- /dev/null +++ b/src/test/scala/ignition/core/utils/IntBagSpec.scala @@ -0,0 +1,33 @@ +package ignition.core.utils + +import org.scalatest._ + +import scala.util.Random + +class IntBagSpec extends FlatSpec with Matchers { + + "IntBag" should "be built from sequence" in { + IntBag.from(Seq(1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 4)).histogram shouldBe Map(1 -> 2, 2 -> 3, 3 -> 1, 4 -> 5) + } + + it should "calculate the average" in { + val size = 1000 + val numbers = (0 until size).map(_ => Random.nextInt(400).toLong).toList + val bag = IntBag.from(numbers) + + bag.avg.get shouldBe numbers.sum / size + } + + it should "calculate the percentile, min and max" in { + val size = 3 // anything different is hard to guess because of the approximation + val numbers = (0 until size).map(_ => Random.nextInt(400).toLong).toList + val bag = IntBag.from(numbers) + + bag.min.get shouldBe numbers.min + bag.percentile(0.1).get shouldBe numbers.min + bag.median.get shouldBe numbers.sorted.apply(1) + bag.percentile(99.9).get shouldBe numbers.max + bag.max.get shouldBe numbers.max + } + +} diff --git a/src/test/scala/ignition/core/utils/URLUtilsSpec.scala b/src/test/scala/ignition/core/utils/URLUtilsSpec.scala new file mode 100644 index 00000000..61781903 --- /dev/null +++ b/src/test/scala/ignition/core/utils/URLUtilsSpec.scala @@ -0,0 +1,46 @@ +package ignition.core.utils + +import org.scalatest.{FlatSpec, Matchers} + +class URLUtilsSpec extends FlatSpec with Matchers { + + 
"URLUtils" should "add parameters to url with encoded params in base url and not be double encoded" in { + val baseUrl: String = "https://tracker.client.com/product=1?email=user%40mail.com" + val params = Map("cc" -> "second@mail.com") + + val result: String = URLUtils.addParametersToUrl(baseUrl, params) + result shouldEqual "https://tracker.client.com/product=1?email=user%40mail.com&cc=second%40mail.com" + } + + it should "add multiples params with the same name" in { + val baseUrl: String = "https://tracker.client.com/product=1?email=user%40mail.com&cc=second%40mail.com" + val params = Map("cc" -> "third@mail.com") + + val result: String = URLUtils.addParametersToUrl(baseUrl, params) + result shouldEqual "https://tracker.client.com/product=1?email=user%40mail.com&cc=second%40mail.com&cc=third%40mail.com" + } + + it should "works with Fragment in original URL" in { + + val baseUrl = "https://www.petlove.com.br/carrinho?utm_campanha=internalmkt#/add/variant_sku/310178,31012214/quantity/1?t=1" + val params: Map[String, String] = Map( + "utm_campaign" -> "abandonodecarrinho", + "utm_source" -> "chaordic-mail", + "utm_medium" -> "emailmkt", + "cc" -> "second@mail.com" + ) + + val result = URLUtils.addParametersToUrl(baseUrl, params) + + val expected = "https://www.petlove.com.br/carrinho?utm_campanha=internalmkt&utm_campaign=abandonodecarrinho&utm_source=chaordic-mail&utm_medium=emailmkt&cc=second%40mail.com#/add/variant_sku/310178,31012214/quantity/1?t=1" + + result shouldEqual expected + } + + it should "handle urls with new line character at the edges" in { + val url = "\n\t\n\thttps://www.petlove.com.br/carrinho#/add/variant_sku/3105748-1,3107615/quantity/1?t=1\n\t" + val finalUrl = URLUtils.addParametersToUrl(url, Map("test" -> "true")) + finalUrl shouldEqual "https://www.petlove.com.br/carrinho?test=true#/add/variant_sku/3105748-1,3107615/quantity/1?t=1" + } + +} diff --git a/tools/cluster.py b/tools/cluster.py index 3cf1828a..ca33835e 100755 --- 
a/tools/cluster.py +++ b/tools/cluster.py @@ -7,14 +7,14 @@ """ +from cgitb import reset import argh from argh import ArghParser, CommandError from argh.decorators import named, arg import subprocess from subprocess import check_output, check_call -from itertools import chain -from utils import tag_instances, get_masters, get_active_nodes -from utils import check_call_with_timeout, ProcessTimeoutException +from utils import tag_instances, get_masters, get_active_nodes, get_active_nodes_by_tag +from utils import check_call_with_timeout, check_call_with_timeout_describe, destroy_by_fleet_id import os import sys from datetime import datetime @@ -23,7 +23,8 @@ import getpass import json import glob - +import webbrowser +import ssl log = logging.getLogger() log.setLevel(logging.INFO) @@ -38,28 +39,32 @@ default_instance_type = 'r3.xlarge' default_spot_price = '0.10' default_worker_instances = '1' -default_master_instance_type = 'm3.xlarge' +default_executor_instances = '1' +default_master_instance_type = '' +default_driver_heap_size = '12G' +default_min_root_ebs_size_gb = '30' default_region = 'us-east-1' default_zone = default_region + 'b' default_key_id = 'ignition_key' default_key_file = os.path.expanduser('~/.ssh/ignition_key.pem') -default_ami = None # will be decided based on spark-ec2 list -default_master_ami = None +default_ami = 'ami-611e7976' +default_master_ami = '' default_env = 'dev' -default_spark_version = '1.3.0' -default_spark_repo = 'https://github.com/chaordic/spark' +default_spark_version = '2.4.3' +default_hdfs_version = '2.7.6' +default_spark_download_source = 'https://s3.amazonaws.com/chaordic-ignition-public/spark-{v}-bin-hadoop2.7.tgz' +default_hdfs_download_source = 'https://s3.amazonaws.com/chaordic-ignition-public/hadoop-{v}.tar.gz' default_remote_user = 'ec2-user' +default_installation_user = 'ec2-user' default_remote_control_dir = '/tmp/Ignition' default_collect_results_dir = '/tmp' -default_user_data = os.path.join(script_path, 'scripts', 
'S05mount-disks') +default_user_data = os.path.join(script_path, 'scripts', 'noop') default_defaults_filename = 'cluster_defaults.json' - -default_spark_ec2_git_repo = 'https://github.com/chaordic/spark-ec2' -default_spark_ec2_git_branch = 'v4-yarn' +default_vpc='vpc-94215df1' master_post_create_commands = [ - 'sudo', 'yum', '-y', 'install', 'tmux' + ['sudo', 'yum', '-y', 'install', 'tmux'], ] @@ -111,8 +116,10 @@ def logged_call(args, tries=1): return logged_call_base(check_call, args, tries) -def ssh_call(user, host, key_file, args=(), allocate_terminal=True, get_output=False): - base = ['ssh', '-q'] +def ssh_call(user, host, key_file, args=(), allocate_terminal=True, get_output=False, quiet=False): + base = ['ssh'] + if quiet: + base += ['-q'] if allocate_terminal: base += ['-tt'] base += ['-i', key_file, @@ -120,24 +127,34 @@ def ssh_call(user, host, key_file, args=(), allocate_terminal=True, get_output=F '{0}@{1}'.format(user, host)] base += args if get_output: - return logged_call_output(base) + return logged_call_output(base).decode("utf-8") else: return logged_call(base) +def ec2_script_base_path(): + return os.path.join(script_path, 'flintrock') def chdir_to_ec2_script_and_get_path(): - ec2_script_base = os.path.join(script_path, 'spark-ec2') + ec2_script_base = ec2_script_base_path() os.chdir(ec2_script_base) - ec2_script_path = os.path.join(ec2_script_base, 'spark_ec2.py') + ec2_script_path = os.path.join(ec2_script_base, 'standalone.py') return ec2_script_path -def call_ec2_script(args, timeout_total_minutes, timeout_inactivity_minutes): +def call_ec2_script(args, timeout_total_minutes, timeout_inactivity_minutes, stdout=None): + ec2_script_path = chdir_to_ec2_script_and_get_path() + return check_call_with_timeout(['/usr/bin/env', 'python3', '-u', + ec2_script_path] + args, + stdout=stdout, + timeout_total_minutes=timeout_total_minutes, + timeout_inactivity_minutes=timeout_inactivity_minutes) +def call_ec2_script_describe(args, timeout_total_minutes, 
timeout_inactivity_minutes, stdout=None): ec2_script_path = chdir_to_ec2_script_and_get_path() - return check_call_with_timeout(['/usr/bin/env', 'python', '-u', + return check_call_with_timeout_describe(['/usr/bin/env', 'python3', '-u', ec2_script_path] + args, - timeout_total_minutes=timeout_total_minutes, - timeout_inactivity_minutes=timeout_inactivity_minutes) + stdout=stdout, + timeout_total_minutes=timeout_total_minutes, + timeout_inactivity_minutes=timeout_inactivity_minutes) def cluster_exists(cluster_name, region): @@ -164,18 +181,22 @@ def save_cluster_args(master, key_file, remote_user, all_args): args=["echo '{}' > /tmp/cluster_args.json".format(json.dumps(all_args))]) def load_cluster_args(master, key_file, remote_user): - return json.loads(ssh_call(user=remote_user, host=master, key_file=key_file, + return json.loads(ssh_call(user=remote_user, host=master, key_file=key_file, allocate_terminal=False, args=["cat", "/tmp/cluster_args.json"], get_output=True)) # Util to be used by external scripts def save_extra_data(data_str, cluster_name, region=default_region, key_file=default_key_file, remote_user=default_remote_user, master=None): master = master or get_master(cluster_name, region=region) - ssh_call(user=remote_user, host=master, key_file=key_file, - args=["echo '{}' > /tmp/cluster_extra_data.txt".format(data_str)]) + cmd = ['ssh', '-o', 'StrictHostKeyChecking=no', remote_user + '@' + master , '-i', key_file, '/bin/bash', '-c', 'cat > /tmp/cluster_extra_data.txt'] + p = subprocess.Popen(cmd, stdin=subprocess.PIPE) + p.communicate(data_str) + if p.wait() != 0: + raise Exception('Error saving extra data on master') + def load_extra_data(cluster_name, region=default_region, key_file=default_key_file, remote_user=default_remote_user, master=None): master = master or get_master(cluster_name, region=region) - return ssh_call(user=remote_user, host=master, key_file=key_file, + return ssh_call(user=remote_user, host=master, key_file=key_file, 
allocate_terminal=False, args=["cat", "/tmp/cluster_extra_data.txt"], get_output=True) @@ -201,90 +222,103 @@ def launch(cluster_name, slaves, tag=[], key_id=default_key_id, region=default_region, zone=default_zone, instance_type=default_instance_type, - ondemand=False, spot_price=default_spot_price, + # TODO: implement it in flintrock + ondemand=False, + spot_price=default_spot_price, + # TODO: implement it in flintrock + master_spot=False, user_data=default_user_data, - security_group = None, - vpc = None, - vpc_subnet = None, + security_group=None, + vpc=None, + vpc_subnet=None, + # TODO: consider implementing in flintrock master_instance_type=default_master_instance_type, - wait_time='180', hadoop_major_version='2', - worker_instances=default_worker_instances, retries_on_same_cluster=5, + executor_instances=default_executor_instances, + min_root_ebs_size_gb=default_min_root_ebs_size_gb, + retries_on_same_cluster=5, max_clusters_to_create=5, minimum_percentage_healthy_slaves=0.9, remote_user=default_remote_user, + installation_user=default_installation_user, script_timeout_total_minutes=55, script_timeout_inactivity_minutes=10, - resume=False, just_ignore_existing=False, worker_timeout=240, - spark_repo=default_spark_repo, + just_ignore_existing=False, + spark_download_source=default_spark_download_source, spark_version=default_spark_version, - spark_ec2_git_repo=default_spark_ec2_git_repo, - spark_ec2_git_branch=default_spark_ec2_git_branch, - ami=default_ami, master_ami=default_master_ami): + hdfs_download_source=default_hdfs_download_source, + hdfs_version=default_hdfs_version, + ami=default_ami, + # TODO: consider implementing in flintrock + master_ami=default_master_ami, + instance_profile_name=None): + + assert not master_instance_type or master_instance_type == instance_type, 'Different master instance type is currently unsupported' + assert not master_ami or master_ami == ami, 'Different master ami is currently unsupported' + assert not ondemand, 'On 
demand is unsupported' + assert master_spot, 'On demand master is currently unsupported' all_args = locals() - if cluster_exists(cluster_name, region=region) and not resume: + if cluster_exists(cluster_name, region=region): if just_ignore_existing: log.info('Cluster exists but that is ok') return '' else: - raise CommandError('Cluster already exists, pick another name or resume the setup using --resume') + raise CommandError('Cluster already exists, pick another name') for j in range(max_clusters_to_create): log.info('Creating new cluster {0}, try {1}'.format(cluster_name, j+1)) success = False - resume_param = ['--resume'] if resume else [] auth_params = [] - if security_group: - auth_params.extend([ - '--authorized-address', '127.0.0.1/32', - '--additional-security-group', security_group - ]) # '--vpc-id', default_vpc, # '--subnet-id', default_vpc_subnet, if vpc and vpc_subnet: auth_params.extend([ - '--vpc-id', vpc, - '--subnet-id', vpc_subnet, + '--ec2-vpc-id', vpc, + '--ec2-subnet-id', vpc_subnet, ]) - spot_params = ['--spot-price', spot_price] if not ondemand else [] - ami_params = ['--ami', ami] if ami else [] - master_ami_params = ['--master-ami', master_ami] if master_ami else [] + spot_params = ['--ec2-spot-price', spot_price] if not ondemand else [] + #master_spot_params = ['--master-spot'] if not ondemand and master_spot else [] + + ami_params = ['--ec2-ami', ami] if ami else [] + #master_ami_params = ['--master-ami', master_ami] if master_ami else [] + + iam_params = ['--ec2-instance-profile-name', instance_profile_name] if instance_profile_name else [] for i in range(retries_on_same_cluster): log.info('Running script, try %d of %d', i + 1, retries_on_same_cluster) try: - call_ec2_script(['--identity-file', key_file, - '--key-pair', key_id, - '--slaves', slaves, - '--region', region, - '--zone', zone, - '--instance-type', instance_type, - '--master-instance-type', master_instance_type, - '--wait', wait_time, - '--hadoop-major-version', 
hadoop_major_version, - '--spark-ec2-git-repo', spark_ec2_git_repo, - '--spark-ec2-git-branch', spark_ec2_git_branch, - '--worker-instances', worker_instances, - '--master-opts', '-Dspark.worker.timeout={0}'.format(worker_timeout), - '--spark-git-repo', spark_repo, - '-v', spark_version, - '--user-data', user_data, - 'launch', cluster_name] + + call_ec2_script(['--debug', + 'launch', + '--ec2-identity-file', key_file, + '--ec2-key-name', key_id, + '--num-slaves', slaves, + '--ec2-region', region, + '--ec2-instance-type', instance_type, + '--ec2-min-root-ebs-size-gb', min_root_ebs_size_gb, + '--assume-yes', + '--install-spark', + '--install-hdfs', + '--spark-version', spark_version, + '--hdfs-version', hdfs_version, + '--spark-download-source', spark_download_source, + '--hdfs-download-source', hdfs_download_source, + '--spark-executor-instances', executor_instances, + '--ec2-security-group', security_group, + '--ec2-user', installation_user, + '--ec2-user-data', user_data, + '--launch-template-name', cluster_name, + cluster_name] + spot_params + - resume_param + auth_params + ami_params + - master_ami_params, + iam_params, timeout_total_minutes=script_timeout_total_minutes, timeout_inactivity_minutes=script_timeout_inactivity_minutes) success = True - except subprocess.CalledProcessError as e: - resume_param = ['--resume'] - log.warn('Failed with: %s', e) except Exception as e: # Probably a timeout log.exception('Fatal error calling EC2 script') @@ -300,33 +334,102 @@ def launch(cluster_name, slaves, master = get_master(cluster_name, region=region) save_cluster_args(master, key_file, remote_user, all_args) health_check(cluster_name=cluster_name, key_file=key_file, master=master, remote_user=remote_user, region=region) - ssh_call(user=remote_user, host=master, key_file=key_file, args=master_post_create_commands) + for command in master_post_create_commands: + ssh_call(user=remote_user, host=master, key_file=key_file, args=command) return master except Exception as 
e: log.exception('Got exception on last steps of cluster configuration') log.warn('Destroying unsuccessful cluster') - destroy(cluster_name=cluster_name, region=region) - raise CommandError('Failed to created cluster {} after failures'.format(cluster_name)) + destroy(cluster_name=cluster_name, region=region, wait_termination=True) + raise CommandError('Failed to created cluster {0} after failures'.format(cluster_name)) + + +def destroy_by_flyntrock(region, cluster_name, vpc=default_vpc, script_timeout_total_minutes=55, script_timeout_inactivity_minutes=10, wait_termination=False, wait_timeout_minutes=10): + # create a variable to store the result + result = False + + try: # create a try catch to manage the possible erros + cluster = call_ec2_script_describe(['describe', cluster_name,'--ec2-vpc-id',vpc],timeout_total_minutes=script_timeout_total_minutes, timeout_inactivity_minutes=script_timeout_inactivity_minutes) + if cluster == cluster_name: + call_ec2_script(['destroy','--assume-yes', cluster_name,'--ec2-vpc-id',vpc],timeout_total_minutes=script_timeout_total_minutes, timeout_inactivity_minutes=script_timeout_inactivity_minutes) + result = True + except Exception as e: + #log.info('Error to destroy cluster {0} by flintrock'.format(cluster_name)) + destroy_by_cluster_name_tag(region, 'spark_cluster_name', cluster_name, wait_termination, wait_timeout_minutes) + pass + return result -def destroy(cluster_name, delete_groups=False, region=default_region): - delete_sg_param = ['--delete-groups'] if delete_groups else [] +def destroy_by_cluster_name_tag(region, tag_name, cluster_name, wait_termination, wait_timeout_minutes): + instances = get_active_nodes_by_tag(region, tag_name, cluster_name) + + if instances: + #log.info('Trying to terminate remain instances by id.') - ec2_script_path = chdir_to_ec2_script_and_get_path() - p = subprocess.Popen(['/usr/bin/env', 'python', '-u', - ec2_script_path, - 'destroy', cluster_name, - '--region', region] + delete_sg_param, - 
stdin=subprocess.PIPE, - stdout=sys.stdout, universal_newlines=True) - p.communicate('y') + for instance in instances: + #log.info('Terminate instance {0}'.format(instance.id)) + instance.terminate() + log.info('Instance {0} is terminating.'.format(instance.id)) + + # call this function to wait instances to terminate + wait_for_intances_to_terminate(cluster_name, wait_termination, wait_timeout_minutes, instances) + + return instances + + +def destroy(cluster_name, wait_termination=False, vpc=default_vpc, wait_timeout_minutes=10, delete_groups=False, region=default_region,script_timeout_total_minutes=55,script_timeout_inactivity_minutes=10): + assert not delete_groups, 'Delete groups is deprecated and unsupported' + masters, slaves = get_active_nodes(cluster_name, region=region) + + try: # First we test if exist the cluster with the function cluster_exists + # get instances ids by json return and cancel the requests + + # if in dev environment, will delete the flintrock SG rules of the machine running this script + if os.getenv('ENVIRONMENT') == 'development': + revoke_sg_script = os.path.join(script_path, 'revoke_sg_rules.py') + process = subprocess.Popen(["python3", revoke_sg_script, region, vpc], stdout=subprocess.PIPE) + stdout_str = process.communicate()[0] + log.info(stdout_str) + + wait_for_intances_to_terminate(cluster_name, wait_termination, wait_timeout_minutes, destroy_by_fleet_id(region, cluster_name)) + + # test if the cluster exists and call destroy by fintorock to destroy it + if destroy_by_flyntrock(region, cluster_name, vpc, script_timeout_total_minutes, script_timeout_inactivity_minutes, wait_termination, wait_timeout_minutes): + # Here we use the script to destroy the cluster using the name of it + all_instances = masters + slaves + # To better view about what the script is doing i choose to let the same code of the destroy i have updated + wait_for_intances_to_terminate(cluster_name, wait_termination, wait_timeout_minutes, all_instances) + # Here 
is the exception of the try if we don't find the cluster + except Exception as e: + log.info('Does not exist %s', cluster_name) + pass + +def wait_for_intances_to_terminate(cluster_name, wait_termination=False, wait_timeout_minutes=10, all_instances=[]): + # To better view about what the script is doing i choose to let the same code of the destroy i have updated + if all_instances: + log.info('The %s will be terminated:', cluster_name) + for i in all_instances: + log.info('-> %s' % (i.public_dns_name or i.private_dns_name or i.id)) + + if wait_termination: + log.info('Waiting for instances termination...') + termination_timeout = wait_timeout_minutes*60 + termination_start = time.time() + + while wait_termination and all_instances and time.time() < termination_start+termination_timeout: + all_instances = [i for i in all_instances if i.state != 'terminated'] + time.sleep(5) + for i in all_instances: + i.update() + # The log says the destruction is Done but is still running, just chill and enjoy the ride + log.info('Done.') def get_master(cluster_name, region=default_region): masters = get_masters(cluster_name, region=region) if not masters: raise CommandError("No master on {}".format(cluster_name)) - return masters[0].public_dns_name + return masters[0].public_dns_name or masters[0].private_dns_name def ssh_master(cluster_name, key_file=default_key_file, user=default_remote_user, region=default_region, *args): @@ -334,6 +437,24 @@ def ssh_master(cluster_name, key_file=default_key_file, user=default_remote_user ssh_call(user=user, host=master, key_file=key_file, args=args) +def exec_shell(cluster_name, command, key_file=default_key_file, user=default_remote_user, region=default_region, sudo=False): + import subprocess + masters, slaves = get_active_nodes(cluster_name, region=region) + if not masters: + log.warn('No master found') + for node in masters + slaves: + host = node.public_dns_name or node.private_dns_name + log.info("exec output of host %s\n", host) + cmd = 
['ssh', '-t', '-o', 'StrictHostKeyChecking=no', user + '@' + host ,'-i', key_file] + if sudo: + cmd += ['sudo'] + cmd += ['bash'] + p = subprocess.Popen(cmd, stdin=subprocess.PIPE) + p.communicate(command.encode('utf-8')) + if p.wait() != 0: + log.warn('\nError executing command on host: %s', host) + + def rsync_call(user, host, key_file, args=[], src_local='', dest_local='', remote_path='', tries=3): rsync_args = ['rsync', '--timeout', '60', '-azvP'] rsync_args += ['-e', 'ssh -i {} -o StrictHostKeyChecking=no'.format(key_file)] @@ -349,7 +470,8 @@ def build_assembly(): def get_assembly_path(): paths = glob.glob(get_project_path() + '/target/scala-*/*assembly*.jar') if paths: - return paths[0] + paths.sort(key=os.path.getmtime) + return paths[-1] else: return None @@ -359,6 +481,8 @@ def get_assembly_path(): @arg('--disable-tmux', help='Do not use tmux. Warning: many features will not work without tmux. Use only if the tmux is missing on the master.') @arg('--detached', help='Run job in background, requires tmux') @arg('--destroy-cluster', help='Will destroy cluster after finishing the job') +@arg('--extra', action='append', type=str, help='Additional arguments for the job in the format k=v') +@arg('--disable-propagate-aws-credentials', help='Setting this to true will not propagate your AWS credentials from your environment to the master') @named('run') def job_run(cluster_name, job_name, job_mem, key_file=default_key_file, disable_tmux=False, @@ -370,9 +494,13 @@ def job_run(cluster_name, job_name, job_mem, remote_control_dir = default_remote_control_dir, remote_path=None, master=None, disable_assembly_build=False, - run_tests=False, kill_on_failure=False, - destroy_cluster=False, region=default_region): + destroy_cluster=False, + region=default_region, + driver_heap_size=default_driver_heap_size, + remove_files=True, + disable_propagate_aws_credentials=False, + extra=[]): utc_job_date_example = '2014-05-04T13:13:10Z' if utc_job_date and len(utc_job_date) != 
len(utc_job_date_example): @@ -383,20 +511,21 @@ def job_run(cluster_name, job_name, job_mem, project_path = get_project_path() project_name = os.path.basename(project_path) - module_name = os.path.basename(get_module_path()) # Use job user on remote path to avoid too many conflicts for different local users remote_path = remote_path or '/home/%s/%s.%s' % (default_remote_user, job_user, project_name) remote_hook_local = '{module_path}/remote_hook.sh'.format(module_path=get_module_path()) remote_hook = '{remote_path}/remote_hook.sh'.format(remote_path=remote_path) notify_param = 'yes' if notify_on_errors else 'no' yarn_param = 'yes' if yarn else 'no' + aws_vars = get_aws_keys_str() if not disable_propagate_aws_credentials else '' job_date = utc_job_date or datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ') job_tag = job_tag or job_date.replace(':', '_').replace('-', '_').replace('Z', 'UTC') + runner_extra_args = ' '.join('--runner-extra "%s"' % arg for arg in extra) tmux_wait_command = ';(echo Press enter to keep the session open && /bin/bash -c "read -t 5" && sleep 7d)' if not detached else '' - tmux_arg = ". /etc/profile; . ~/.profile;tmux new-session {detached} -s spark.{job_name}.{job_tag} '{aws_vars} {remote_hook} {job_name} {job_date} {job_tag} {job_user} {remote_control_dir} {spark_mem} {yarn_param} {notify_param} {tmux_wait_command}' >& /tmp/commandoutput".format( - aws_vars=get_aws_keys_str(), job_name=job_name, job_date=job_date, job_tag=job_tag, job_user=job_user, remote_control_dir=remote_control_dir, remote_hook=remote_hook, spark_mem=job_mem, detached='-d' if detached else '', yarn_param=yarn_param, notify_param=notify_param, tmux_wait_command=tmux_wait_command) - non_tmux_arg = ". /etc/profile; . 
~/.profile;{aws_vars} {remote_hook} {job_name} {job_date} {job_tag} {job_user} {remote_control_dir} {spark_mem} {yarn_param} {notify_param} >& /tmp/commandoutput".format( - aws_vars=get_aws_keys_str(), job_name=job_name, job_date=job_date, job_tag=job_tag, job_user=job_user, remote_control_dir=remote_control_dir, remote_hook=remote_hook, spark_mem=job_mem, yarn_param=yarn_param, notify_param=notify_param) + tmux_arg = ". /etc/profile; . ~/.profile;tmux new-session {detached} -s spark.{job_name}.{job_tag} '{aws_vars} {remote_hook} {job_name} {job_date} {job_tag} {job_user} {remote_control_dir} {spark_mem} {yarn_param} {notify_param} {driver_heap_size} {runner_extra_args} {tmux_wait_command}' >& /tmp/commandoutput".format( + aws_vars=aws_vars, job_name=job_name, job_date=job_date, job_tag=job_tag, job_user=job_user, remote_control_dir=remote_control_dir, remote_hook=remote_hook, spark_mem=job_mem, detached='-d' if detached else '', yarn_param=yarn_param, notify_param=notify_param, driver_heap_size=driver_heap_size, runner_extra_args=runner_extra_args, tmux_wait_command=tmux_wait_command) + non_tmux_arg = ". /etc/profile; . 
~/.profile;{aws_vars} {remote_hook} {job_name} {job_date} {job_tag} {job_user} {remote_control_dir} {spark_mem} {yarn_param} {notify_param} {driver_heap_size} {runner_extra_args} >& /tmp/commandoutput".format( + aws_vars=aws_vars, job_name=job_name, job_date=job_date, job_tag=job_tag, job_user=job_user, remote_control_dir=remote_control_dir, remote_hook=remote_hook, spark_mem=job_mem, yarn_param=yarn_param, notify_param=notify_param, driver_heap_size=driver_heap_size, runner_extra_args=runner_extra_args) if not disable_assembly_build: @@ -421,6 +550,9 @@ def job_run(cluster_name, job_name, job_mem, src_local=remote_hook_local, remote_path=with_leading_slash(remote_path)) + if job_name == "zeppelin": + webbrowser.open("http://{master}:8081".format(master=master)) + log.info('Will run job in remote host') if disable_tmux: ssh_call(user=remote_user, host=master, key_file=key_file, args=[non_tmux_arg], allocate_terminal=False) @@ -428,6 +560,7 @@ def job_run(cluster_name, job_name, job_mem, ssh_call(user=remote_user, host=master, key_file=key_file, args=[tmux_arg], allocate_terminal=True) if wait_completion: + time.sleep(5) # wait job to set up before checking it failed = False failed_exception = None try: @@ -436,7 +569,7 @@ def job_run(cluster_name, job_name, job_mem, region=region, job_timeout_minutes=job_timeout_minutes, remote_user=remote_user, remote_control_dir=remote_control_dir, - collect_results_dir=collect_results_dir) + collect_results_dir=collect_results_dir, remove_files=remove_files) except JobFailure as e: failed = True failed_exception = e @@ -467,6 +600,82 @@ def job_run(cluster_name, job_name, job_mem, raise failed_exception or Exception('Failed!?') return (job_name, job_tag) +@argh.arg('-c', '--conf', action='append', type=str) +@arg('job-mem', help='The amount of memory to use for this job (like: 80G)') +@named('local-yarn-run') +def job_local_yarn_run(job_name, job_mem, queue, + job_user=getpass.getuser(), + utc_job_date=None, job_tag=None, + 
disable_assembly_build=False, + executor_cores=5, + spark_submit='spark-submit', + deploy_mode='cluster', + yarn_memory_overhead=0.2, + driver_heap_size=default_driver_heap_size, + driver_java_options='-verbose:gc -XX:-PrintGCDetails -XX:+PrintGCTimeStamps', + conf=[]): + + def parse_memory(s): + import re + match = re.match(r'([0-9]+)([a-zA-Z]+)', s) + if match is None or len(match.groups()) != 2: + raise Exception('Invalid memory size: ' + s) + return match.groups() + + def calculate_overhead(s): + from math import ceil + (n, unit) = parse_memory(s) + return str(int(ceil(float(n) * yarn_memory_overhead))) + unit + + driver_overhead = calculate_overhead(driver_heap_size) + executor_overhead = calculate_overhead(job_mem) + + utc_job_date_example = '2014-05-04T13:13:10Z' + if utc_job_date and len(utc_job_date) != len(utc_job_date_example): + raise CommandError('UTC Job Date should be given as in the following example: {}'.format(utc_job_date_example)) + + job_date = utc_job_date or datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ') + job_tag = job_tag or job_date.replace(':', '_').replace('-', '_').replace('Z', 'UTC') + + if not disable_assembly_build: + build_assembly() + + assembly_path = get_assembly_path() + if assembly_path is None: + raise Exception('Something is wrong: no assembly found') + + + log.info('Will run job using local installation of yarn') + confs = [ + spark_submit, + '--class', 'ignition.jobs.Runner', + '--master', 'yarn', + '--driver-java-options', driver_java_options, + '--deploy-mode', deploy_mode, + '--queue', queue, + '--conf', 'spark.executor.cores=' + str(executor_cores), + '--driver-memory', driver_heap_size, + '--conf', 'spark.yarn.am.memory=' + driver_heap_size, + '--executor-memory', job_mem, + '--conf', 'spark.yarn.am.memoryOverhead=' + driver_overhead, + '--conf', 'spark.driver.memoryOverhead=' + driver_overhead, + '--conf', 'spark.executor.memoryOverhead=' + executor_overhead + ] + + for c in conf: + confs.extend(['--conf', c]) + + 
check_call( + confs + [ + assembly_path, + job_name, + '--runner-master', 'yarn', + '--runner-executor-memory', job_mem, + '--runner-user', job_user, + '--runner-tag', job_tag, + '--runner-date', job_date + ]) + @named('attach') def job_attach(cluster_name, key_file=default_key_file, job_name=None, job_tag=None, @@ -483,15 +692,25 @@ def job_attach(cluster_name, key_file=default_key_file, job_name=None, job_tag=N class NotHealthyCluster(Exception): pass @named('health-check') -def health_check(cluster_name, key_file=default_key_file, master=None, remote_user=default_remote_user, region=default_region): - master = master or get_master(cluster_name, region=region) - all_args = load_cluster_args(master, key_file, remote_user) - nslaves = int(all_args['slaves']) - minimum_percentage_healthy_slaves = all_args['minimum_percentage_healthy_slaves'] - masters, slaves = get_active_nodes(cluster_name, region=region) - if nslaves == 0 or float(len(slaves)) / nslaves < minimum_percentage_healthy_slaves: - raise NotHealthyCluster('Not enough healthy slaves: {0}/{1}'.format(len(slaves), nslaves)) - +def health_check(cluster_name, key_file=default_key_file, master=None, remote_user=default_remote_user, region=default_region, retries=3): + for i in range(retries): + try: + master = master or get_master(cluster_name, region=region) + all_args = load_cluster_args(master, key_file, remote_user) + nslaves = int(all_args['slaves']) + minimum_percentage_healthy_slaves = all_args['minimum_percentage_healthy_slaves'] + masters, slaves = get_active_nodes(cluster_name, region=region) + if nslaves == 0 or float(len(slaves)) / nslaves < minimum_percentage_healthy_slaves: + raise NotHealthyCluster('Not enough healthy slaves: {0}/{1}'.format(len(slaves), nslaves)) + if not masters: + raise NotHealthyCluster('No master found') + except NotHealthyCluster as e: + raise e + except Exception as e: + log.warning("Failed to check cluster health, cluster: %s, retries %s" % (cluster_name, i), 
exc_info=True) + if i >= retries - 1: + log.critical("Failed to check cluster health, cluster: %s, giveup!" % (cluster_name)) + raise e class JobFailure(Exception): pass @@ -513,16 +732,18 @@ def collect_job_results(cluster_name, job_name, job_tag, region=default_region, master=None, remote_user=default_remote_user, remote_control_dir=default_remote_control_dir, - collect_results_dir=default_collect_results_dir): + collect_results_dir=default_collect_results_dir, + remove_files=False): master = master or get_master(cluster_name, region=region) job_with_tag = get_job_with_tag(job_name, job_tag) job_control_dir = get_job_control_dir(remote_control_dir, job_with_tag) + # Keep the RUNNING file so we can kill the job if needed + args = ['--remove-source-files', '--exclude', 'RUNNING'] if remove_files else [] rsync_call(user=remote_user, host=master, - # Keep the RUNNING file so we can kill the job if needed - args=['--remove-source-files', '--exclude', 'RUNNING'], + args=args, key_file=key_file, dest_local=with_leading_slash(collect_results_dir), remote_path=job_control_dir) @@ -530,13 +751,35 @@ def collect_job_results(cluster_name, job_name, job_tag, return os.path.join(collect_results_dir, os.path.basename(job_control_dir)) +@named('collect-all-results') +def collect_all_job_results(cluster_name, + key_file=default_key_file, + region=default_region, + master=None, remote_user=default_remote_user, + remote_control_dir=default_remote_control_dir, + collect_results_dir=default_collect_results_dir, + remove_files=False): + master = master or get_master(cluster_name, region=region) + + # Keep the RUNNING file so we can kill the job if needed + args = ['--remove-source-files', '--exclude', 'RUNNING'] if remove_files else [] + rsync_call(user=remote_user, + host=master, + args=args, + key_file=key_file, + dest_local=with_leading_slash(collect_results_dir), + remote_path=with_leading_slash(remote_control_dir)) + + return collect_results_dir + + @named('wait-for') def 
wait_for_job(cluster_name, job_name, job_tag, key_file=default_key_file, master=None, remote_user=default_remote_user, region=default_region, remote_control_dir=default_remote_control_dir, collect_results_dir=default_collect_results_dir, - job_timeout_minutes=0, max_failures=5, seconds_to_sleep=60): + job_timeout_minutes=0, max_failures=5, seconds_to_sleep=60, remove_files=True): master = master or get_master(cluster_name, region=region) @@ -561,7 +804,7 @@ def collect(show_tail): key_file=key_file, region=region, master=master, remote_user=remote_user, remote_control_dir=remote_control_dir, - collect_results_dir=collect_results_dir) + collect_results_dir=collect_results_dir, remove_files=remove_files) log.info('Jobs results saved on: {}'.format(dest_log_dir)) if show_tail: output_log = os.path.join(dest_log_dir, 'output.log') @@ -622,7 +865,7 @@ def collect(show_tail): failures += 1 last_failure = 'Unexpected response: {}'.format(output) health_check(cluster_name=cluster_name, key_file=key_file, master=master, remote_user=remote_user, region=region) - except subprocess.CalledProcessError as e: + except (subprocess.CalledProcessError, ssl.SSLError) as e: failures += 1 log.exception('Got exception') last_failure = 'Exception: {}'.format(e) @@ -671,13 +914,36 @@ def killall_jobs(cluster_name, key_file=default_key_file, done >& /dev/null || true'''.format(remote_control_dir=remote_control_dir) ]) - +def check_flintrock_installation(): + try: + with open('/dev/null', 'w') as devnull: + call_ec2_script(['--help'], 1 , 1, stdout=devnull) + except: + setup = os.path.join(ec2_script_base_path(), 'setup.py') + if not os.path.exists(setup): + log.error(''' +Flintrock is missing (or the wrong version is being used). +Check if you have checked out the submodule. Try: + git submode update --init --recursive +Or checkout ignition with: + git clone --recursive .... +''') + else: + log.error(''' +Some dependencies are missing. 
For an Ubuntu system, try the following: +sudo apt-get install python3-yaml libyaml-dev python3-pip +sudo python3 -m pip install -U pip packaging setuptools +cd {flintrock} +sudo pip3 install -r requirements/user.pip + '''.format(flintrock=ec2_script_base_path())) + sys.exit(1) parser = ArghParser() -parser.add_commands([launch, destroy, get_master, ssh_master, tag_cluster_instances, health_check]) -parser.add_commands([job_run, job_attach, wait_for_job, - kill_job, killall_jobs, collect_job_results], namespace="jobs") +parser.add_commands([launch, destroy, get_master, ssh_master, tag_cluster_instances, health_check, exec_shell]) +parser.add_commands([job_run, job_local_yarn_run, job_attach, wait_for_job, + kill_job, killall_jobs, collect_job_results, collect_all_job_results], namespace="jobs") if __name__ == '__main__': + check_flintrock_installation() parser.dispatch() diff --git a/tools/create_image.sh b/tools/create_image.sh new file mode 100644 index 00000000..852b861c --- /dev/null +++ b/tools/create_image.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# Creates an AMI for the Spark EC2 scripts starting with a stock Amazon Linux AMI. 
+ +# This script was adapted from: +# https://github.com/amplab/spark-ec2/blob/branch-1.6/create_image.sh + +set -e + +if [ "$(id -u)" != "0" ]; then + echo "This script must be run as root" 1>&2 + exit 1 +fi + +# Dev tools +sudo yum install -y java-1.8.0-openjdk-devel +# Perf tools +sudo yum install -y dstat iotop strace sysstat htop perf +sudo debuginfo-install -q -y glibc +sudo debuginfo-install -q -y kernel +sudo yum --enablerepo='*-debug*' install -q -y java-1.8.0-openjdk-debuginfo.x86_64 + +# Root ssh config +sudo sed -i 's/PermitRootLogin.*/PermitRootLogin without-password/g' \ + /etc/ssh/sshd_config +sudo sed -i 's/disable_root.*/disable_root: 0/g' /etc/cloud/cloud.cfg + +# Edit bash profile +echo "export PS1=\"\\u@\\h \\W]\\$ \"" >> ~/.bash_profile +echo "export JAVA_HOME=/usr/lib/jvm/java-1.8.0" >> ~/.bash_profile + +source ~/.bash_profile + +# Global JAVA_HOME env +echo "export JAVA_HOME=/usr/lib/jvm/java-1.8.0" >> /etc/environment + +# Install Snappy lib (for Hadoop) +yum install -y snappy + +# Install netlib-java native dependencies +yum install -y blas atlas lapack + +# Install python3 and pip3 +yum install -y python3 python3-pip + +# Install tmux +yum install -y tmux + +# Create /usr/bin/realpath which is used by R to find Java installations +# NOTE: /usr/bin/realpath is missing in CentOS AMIs. 
See +# http://superuser.com/questions/771104/usr-bin-realpath-not-found-in-centos-6-5 +echo '#!/bin/bash' > /usr/bin/realpath +echo 'readlink -e "$@"' >> /usr/bin/realpath +chmod a+x /usr/bin/realpath \ No newline at end of file diff --git a/tools/delete_fleet.py b/tools/delete_fleet.py new file mode 100644 index 00000000..70f1ce0f --- /dev/null +++ b/tools/delete_fleet.py @@ -0,0 +1,49 @@ +import sys +from time import sleep + +import boto3 +from botocore.exceptions import ClientError + +def describe_fleets(region, fleet_id): + ec2 = boto3.client('ec2', region_name=region) + response = ec2.describe_fleets( + FleetIds=[ + fleet_id + ], + ) + errors = response['Fleets'][0]['Errors'] + instances = response['Fleets'][0]['Instances'] + # to ensure we are returning an array anyway + if len(errors) > 0 and len(instances) == 0: + return [''] + return instances[0]['InstanceIds'] + +def delete_fleet(region, fleet_id): + ec2 = boto3.client('ec2', region_name=region) + response = ec2.delete_fleets( + FleetIds=[ + fleet_id, + ], + TerminateInstances=True + ) + + return response['SuccessfulFleetDeletions'][0]['CurrentFleetState'] + + +if __name__ == '__main__': + region = sys.argv[1] + fleet_id = sys.argv[2] + try: + # Delete the fleet + fleet_deleted_states = ["deleted", "deleted_running", "deleted_terminating"] + fleet_state = None + while fleet_state not in fleet_deleted_states: + sleep(5) + fleet_state = delete_fleet(region=region, fleet_id=fleet_id) + print(f"Fleet deleted. 
Fleet state: {fleet_state}") + + # get the instance ids from the fleet + print(describe_fleets(region=region, fleet_id=fleet_id)) + except (ClientError, Exception) as e: + print(e) + \ No newline at end of file diff --git a/tools/flintrock b/tools/flintrock new file mode 160000 index 00000000..e5b3b9b2 --- /dev/null +++ b/tools/flintrock @@ -0,0 +1 @@ +Subproject commit e5b3b9b2a6ac66536ba6e105cd42f988f9d8bb7e diff --git a/tools/revoke_sg_rules.py b/tools/revoke_sg_rules.py new file mode 100644 index 00000000..d4ab31a0 --- /dev/null +++ b/tools/revoke_sg_rules.py @@ -0,0 +1,122 @@ +import urllib.request +import sys + +from botocore.exceptions import ClientError +import boto3 + + +def _get_security_group(region, vpc_id, sg_name): + ec2 = boto3.client('ec2', region_name=region) + response = ec2.describe_security_groups( + Filters=[ + { + 'Name': 'vpc-id', + 'Values': [ + vpc_id, + ] + }, + ], + ) + desired_sg = None + security_groups = response['SecurityGroups'] + for security_group in security_groups: + if security_group['GroupName'] == sg_name: + desired_sg = security_group + + return desired_sg + + +def _client_cidr(): + flintrock_client_ip = ( + urllib.request.urlopen('http://checkip.amazonaws.com/') + .read().decode('utf-8').strip()) + flintrock__client_cidr = '{ip}/32'.format(ip=flintrock_client_ip) + return flintrock__client_cidr + + +def _exists_cidr_in_sg(region, cidr, sg_id): + """Boolean function to return `true` if a given cidr + exists in a given security group id. Otherwise returns + `false`. 
+ """ + ec2 = boto3.client('ec2', region_name=region) + response = ec2.describe_security_group_rules( + Filters=[ + { + 'Name': 'group-id', + 'Values': [ + sg_id, + ] + }, + ] + ) + rules = response['SecurityGroupRules'] + for rule in rules: + if rule['CidrIpv4'] == cidr: + return True + return False + + +def _delete_rule(cidr_ip, ip_protocol, from_port, to_port, group_id, region): + ec2 = boto3.client('ec2', region_name=region) + ec2.revoke_security_group_ingress( + CidrIp=cidr_ip, + GroupId=group_id, + IpProtocol=ip_protocol, + FromPort=from_port, + ToPort=to_port + ) + + +def revoke_flintrock_sg_ingress(region, vpc_id): + """Revoke Flintrock Security Group's Rules matched with the IP from + the current machine given the Region and VPC ID + :param `region`: The AWS region where the VPC is located + :type `region`: str + :param `vpc_id`: The VPC ID where flintrock Security Group was created + :type `vpc_id`: str + """ + + flintrock_security_group = _get_security_group(region=region, vpc_id=vpc_id, sg_name='flintrock') + cidr_to_revoke_rules = _client_cidr() + flintrock_group_id = flintrock_security_group['GroupId'] + + if flintrock_security_group['GroupName'] != 'flintrock': + print('Flintrock security groups doesn\'t exist in this vpc {} at region {}'.format(vpc_id, region)) + return # we don't want the script to ``raise`` an error, to not mess with the job_runner.py logs + + # check if the local IP is in some rule or not + if not _exists_cidr_in_sg(region=region, cidr=cidr_to_revoke_rules, sg_id=flintrock_group_id): + print('There is no rules with the IP of this client in Flintrock security group.') + return + + for ip_permission in flintrock_security_group['IpPermissions']: + for ip_range in ip_permission['IpRanges']: + group_id = flintrock_group_id + from_port = ip_permission['FromPort'] + ip_protocol = ip_permission['IpProtocol'] + to_port = ip_permission['ToPort'] + + if 'FromPort' in ip_permission and ip_range['CidrIp'] == cidr_to_revoke_rules: + try: + 
_delete_rule( + cidr_ip=ip_range['CidrIp'], + ip_protocol=ip_protocol, + from_port=from_port, + to_port=to_port, + group_id=group_id, + region=region + ) + except ClientError as error: + print(error) + + # check again to confirm if the rules were revoked + if not _exists_cidr_in_sg(region=region, cidr=cidr_to_revoke_rules, sg_id=flintrock_group_id): + print('Successfully deleted rules of this client from flintrock security group at vpc {}'.format(vpc_id)) + + +if __name__ == '__main__': + region = sys.argv[1] + vpc_id = sys.argv[2] + revoke_flintrock_sg_ingress(region=region, vpc_id=vpc_id) + \ No newline at end of file diff --git a/tools/scripts/S05mount-disks b/tools/scripts/S05mount-disks deleted file mode 100644 index 8f129a30..00000000 --- a/tools/scripts/S05mount-disks +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -echo 'Mounting disks' >> /tmp/mount-disks.log -mkdir -p /mnt -mkdir -p /mnt{2,3,4} -chmod -R 777 /mnt* -[ -r /dev/xvdb ] && mkfs.ext4 /dev/xvdb && mount /dev/xvdb /mnt -[ -r /dev/xvdc ] && mkfs.ext4 /dev/xvdc && mount /dev/xvdc /mnt2 -[ -r /dev/xvdd ] && mkfs.ext4 /dev/xvdd && mount /dev/xvdd /mnt3 -[ -r /dev/xvde ] && mkfs.ext4 /dev/xvde && mount /dev/xvde /mnt4 - diff --git a/tools/scripts/noop b/tools/scripts/noop new file mode 100644 index 00000000..0e872836 --- /dev/null +++ b/tools/scripts/noop @@ -0,0 +1,3 @@ +#!/bin/bash + +echo '* - nofile 1000000' >> /etc/security/limits.conf \ No newline at end of file diff --git a/tools/spark-ec2/README b/tools/spark-ec2/README deleted file mode 100644 index 72434f24..00000000 --- a/tools/spark-ec2/README +++ /dev/null @@ -1,4 +0,0 @@ -This folder contains a script, spark-ec2, for launching Spark clusters on -Amazon EC2. 
Usage instructions are available online at: - -http://spark.apache.org/docs/latest/ec2-scripts.html diff --git a/tools/spark-ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/tools/spark-ec2/deploy.generic/root/spark-ec2/ec2-variables.sh deleted file mode 100644 index 3570891b..00000000 --- a/tools/spark-ec2/deploy.generic/root/spark-ec2/ec2-variables.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# These variables are automatically filled in by the spark-ec2 script. 
-export MASTERS="{{master_list}}" -export SLAVES="{{slave_list}}" -export HDFS_DATA_DIRS="{{hdfs_data_dirs}}" -export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" -export SPARK_LOCAL_DIRS="{{spark_local_dirs}}" -export MODULES="{{modules}}" -export SPARK_VERSION="{{spark_version}}" -export SHARK_VERSION="{{shark_version}}" -export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" -export SWAP_MB="{{swap}}" -export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}" -export SPARK_MASTER_OPTS="{{spark_master_opts}}" diff --git a/tools/spark-ec2/spark-ec2 b/tools/spark-ec2/spark-ec2 deleted file mode 100755 index 31f97712..00000000 --- a/tools/spark-ec2/spark-ec2 +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/sh - -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -cd "`dirname $0`" -PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@" diff --git a/tools/spark-ec2/spark_ec2.py b/tools/spark-ec2/spark_ec2.py deleted file mode 100755 index 5fdf0467..00000000 --- a/tools/spark-ec2/spark_ec2.py +++ /dev/null @@ -1,1286 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import with_statement - -import hashlib -import logging -import os -import os.path -import pipes -import random -import shutil -import string -from stat import S_IRUSR -import subprocess -import sys -import tarfile -import tempfile -import textwrap -import time -import urllib2 -import warnings -from datetime import datetime -from optparse import OptionParser -from sys import stderr - -SPARK_EC2_VERSION = "1.3.0" -SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) - -VALID_SPARK_VERSIONS = set([ - "0.7.3", - "0.8.0", - "0.8.1", - "0.9.0", - "0.9.1", - "0.9.2", - "1.0.0", - "1.0.1", - "1.0.2", - "1.1.0", - "1.1.1", - "1.2.0", - "1.2.1", - "1.3.0", -]) - -DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION -DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" - -# Default location to get the spark-ec2 scripts (and ami-list) from -DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.3" - -import boto -from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType -from boto import ec2 - - -class UsageError(Exception): - pass - - -# Configure and parse our command-line arguments -def parse_args(): - parser = OptionParser( - prog="spark-ec2", - version="%prog 
{v}".format(v=SPARK_EC2_VERSION), - usage="%prog [options] \n\n" - + " can be: launch, destroy, login, stop, start, get-master, reboot-slaves") - - parser.add_option( - "-s", "--slaves", type="int", default=1, - help="Number of slaves to launch (default: %default)") - parser.add_option( - "-w", "--wait", type="int", - help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start") - parser.add_option( - "-k", "--key-pair", - help="Key pair to use on instances") - parser.add_option( - "-i", "--identity-file", - help="SSH private key file to use for logging into instances") - parser.add_option( - "-t", "--instance-type", default="m1.large", - help="Type of instance to launch (default: %default). " + - "WARNING: must be 64-bit; small instances won't work") - parser.add_option( - "-m", "--master-instance-type", default="", - help="Master instance type (leave empty for same as instance-type)") - parser.add_option( - "-r", "--region", default="us-east-1", - help="EC2 region zone to launch instances in") - parser.add_option( - "-z", "--zone", default="", - help="Availability zone to launch instances in, or 'all' to spread " + - "slaves across multiple (an additional $0.01/Gb for bandwidth" + - "between zones applies) (default: a single zone chosen at random)") - parser.add_option( - "-a", "--ami", - help="Amazon Machine Image ID to use") - parser.add_option("--master-ami", - help="Amazon Machine Image ID to use for the Master") - parser.add_option( - "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, - help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)") - parser.add_option( - "--spark-git-repo", - default=DEFAULT_SPARK_GITHUB_REPO, - help="Github repo from which to checkout supplied commit hash (default: %default)") - parser.add_option( - "--spark-ec2-git-repo", - default=DEFAULT_SPARK_EC2_GITHUB_REPO, - help="Github repo from which to checkout spark-ec2 (default: %default)") - parser.add_option( - 
"--spark-ec2-git-branch", - default=DEFAULT_SPARK_EC2_BRANCH, - help="Github repo branch of spark-ec2 to use (default: %default)") - parser.add_option( - "--hadoop-major-version", default="1", - help="Major version of Hadoop (default: %default)") - parser.add_option( - "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", - help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + - "the given local address (for use with login)") - parser.add_option( - "--resume", action="store_true", default=False, - help="Resume installation on a previously launched cluster " + - "(for debugging)") - parser.add_option( - "--ebs-vol-size", metavar="SIZE", type="int", default=0, - help="Size (in GB) of each EBS volume.") - parser.add_option( - "--ebs-vol-type", default="standard", - help="EBS volume type (e.g. 'gp2', 'standard').") - parser.add_option( - "--ebs-vol-num", type="int", default=1, - help="Number of EBS volumes to attach to each node as /vol[x]. " + - "The volumes will be deleted when the instances terminate. " + - "Only possible on EBS-backed AMIs. " + - "EBS volumes are only attached if --ebs-vol-size > 0." + - "Only support up to 8 EBS volumes.") - parser.add_option( - "--placement-group", type="string", default=None, - help="Which placement group to try and launch " + - "instances into. Assumes placement group is already " + - "created.") - parser.add_option( - "--swap", metavar="SWAP", type="int", default=1024, - help="Swap space to set up per node, in MB (default: %default)") - parser.add_option( - "--spot-price", metavar="PRICE", type="float", - help="If specified, launch slaves as spot instances with the given " + - "maximum price (in dollars)") - parser.add_option( - "--ganglia", action="store_true", default=True, - help="Setup Ganglia monitoring on cluster (default: %default). 
NOTE: " + - "the Ganglia page will be publicly accessible") - parser.add_option( - "--no-ganglia", action="store_false", dest="ganglia", - help="Disable Ganglia monitoring for the cluster") - parser.add_option( - "-u", "--user", default="root", - help="The SSH user you want to connect as (default: %default)") - parser.add_option( - "--delete-groups", action="store_true", default=False, - help="When destroying a cluster, delete the security groups that were created") - parser.add_option( - "--use-existing-master", action="store_true", default=False, - help="Launch fresh slaves, but use an existing stopped master if possible") - parser.add_option( - "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)") - parser.add_option( - "--master-opts", type="string", default="", - help="Extra options to give to master through SPARK_MASTER_OPTS variable " + - "(e.g -Dspark.worker.timeout=180)") - parser.add_option( - "--user-data", type="string", default="", - help="Path to a user-data file (most AMI's interpret this as an initialization script)") - parser.add_option( - "--security-group-prefix", type="string", default=None, - help="Use this prefix for the security group rather than the cluster name.") - parser.add_option( - "--authorized-address", type="string", default="0.0.0.0/0", - help="Address to authorize on created security groups (default: %default)") - parser.add_option( - "--additional-security-group", type="string", default="", - help="Additional security group to place the machines in") - parser.add_option( - "--copy-aws-credentials", action="store_true", default=False, - help="Add AWS credentials to hadoop configuration to allow Spark to access S3") - parser.add_option( - "--subnet-id", default=None, - help="VPC subnet to launch instances in") - parser.add_option( - "--vpc-id", default=None, - help="VPC to launch instances in") - parser.add_option( - "--spot-timeout", type="int", 
default=45, - help="Maximum amount of time (in minutes) to wait for spot requests to be fulfilled") - - (opts, args) = parser.parse_args() - if len(args) != 2: - parser.print_help() - sys.exit(1) - (action, cluster_name) = args - - # Boto config check - # http://boto.cloudhackers.com/en/latest/boto_config_tut.html - home_dir = os.getenv('HOME') - if home_dir is None or not os.path.isfile(home_dir + '/.boto'): - if not os.path.isfile('/etc/boto.cfg'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print >> stderr, ("ERROR: The environment variable AWS_ACCESS_KEY_ID " + - "must be set") - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print >> stderr, ("ERROR: The environment variable AWS_SECRET_ACCESS_KEY " + - "must be set") - sys.exit(1) - return (opts, action, cluster_name) - - -# Get the EC2 security group of the given name, creating it if it doesn't exist -def get_or_make_group(conn, name, vpc_id): - groups = conn.get_all_security_groups() - group = [g for g in groups if g.name == name] - if len(group) > 0: - return group[0] - else: - print "Creating security group " + name - return conn.create_security_group(name, "Spark EC2 group", vpc_id) - -def check_if_http_resource_exists(resource): - request = urllib2.Request(resource) - request.get_method = lambda: 'HEAD' - try: - response = urllib2.urlopen(request) - if response.getcode() == 200: - return True - else: - raise RuntimeError("Resource {resource} not found. Error: {code}".format(resource, response.getcode())) - except urllib2.HTTPError, e: - print >> stderr, "Unable to check if HTTP resource {url} exists. Error: {code}".format( - url=resource, - code=e.code) - return False - -def get_validate_spark_version(version, repo): - if version.startswith("http"): - #check if custom package URL exists - if check_if_http_resource_exists: - return version - else: - print >> stderr, "Unable to validate pre-built spark version {version}".format(version=version) - sys.exit(1) - elif "." 
in version: - version = version.replace("v", "") - if version not in VALID_SPARK_VERSIONS: - print >> stderr, "Don't know about Spark version: {v}".format(v=version) - sys.exit(1) - return version - else: - github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version) - if not check_if_http_resource_exists(github_commit_url): - print >> stderr, "Couldn't validate Spark commit: {repo} / {commit}".format( - repo=repo, commit=version) - sys.exit(1) - else: - return version - - -# Check whether a given EC2 instance object is in a state we consider active, -# i.e. not terminating or terminated. We count both stopping and stopped as -# active since we can restart stopped clusters. -def is_active(instance): - return (instance.state in ['pending', 'running', 'stopping', 'stopped']) - - -# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2014-06-20 -# For easy maintainability, please keep this manually-inputted dictionary sorted by key. 
-EC2_INSTANCE_TYPES = { - "c1.medium": "pvm", - "c1.xlarge": "pvm", - "c3.2xlarge": "pvm", - "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", - "cc1.4xlarge": "hvm", - "cc2.8xlarge": "hvm", - "cg1.4xlarge": "hvm", - "cr1.8xlarge": "hvm", - "hi1.4xlarge": "pvm", - "hs1.8xlarge": "pvm", - "i2.2xlarge": "hvm", - "i2.4xlarge": "hvm", - "i2.8xlarge": "hvm", - "i2.xlarge": "hvm", - "m1.large": "pvm", - "m1.medium": "pvm", - "m1.small": "pvm", - "m1.xlarge": "pvm", - "m2.2xlarge": "pvm", - "m2.4xlarge": "pvm", - "m2.xlarge": "pvm", - "m3.2xlarge": "hvm", - "m3.large": "hvm", - "m3.medium": "hvm", - "m3.xlarge": "hvm", - "r3.2xlarge": "hvm", - "r3.4xlarge": "hvm", - "r3.8xlarge": "hvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", - "t1.micro": "pvm", - "t2.medium": "hvm", - "t2.micro": "hvm", - "t2.small": "hvm", - "d2.2xlarge": "hvm", - "d2.4xlarge": "hvm", - "d2.8xlarge": "hvm", - "d2.large": "hvm", - "d2.xlarge": "hvm", -} - - -# Attempt to resolve an appropriate AMI given the architecture and region of the request. -def get_spark_ami(instance_type, region, spark_ec2_git_repo, spark_ec2_git_branch): - if instance_type in EC2_INSTANCE_TYPES: - instance_type = EC2_INSTANCE_TYPES[instance_type] - else: - instance_type = "pvm" - print >> stderr,\ - "Don't recognize %s, assuming type is pvm" % instance_type - - # URL prefix from which to fetch AMI information - ami_prefix = "{r}/{b}/ami-list".format( - r=spark_ec2_git_repo.replace("https://github.com", "https://raw.github.com", 1), - b=spark_ec2_git_branch) - - ami_path = "%s/%s/%s" % (ami_prefix, region, instance_type) - try: - ami = urllib2.urlopen(ami_path).read().strip() - print "Spark AMI for %s: %s" % (instance_type, ami) - except: - print >> stderr, "Could not resolve AMI at: " + ami_path - sys.exit(1) - - return ami - - -# Launch a cluster of the given name, by setting up its security groups, -# and then starting new instances in them. 
-# Returns a tuple of EC2 reservation objects for the master and slaves -# Fails if there already instances running in the cluster's groups. -def launch_cluster(conn, opts, cluster_name): - if opts.identity_file is None: - print >> stderr, "ERROR: Must provide an identity file (-i) for ssh connections." - sys.exit(1) - - if opts.key_pair is None: - print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances." - sys.exit(1) - - user_data_content = None - if opts.user_data: - with open(opts.user_data) as user_data_file: - user_data_content = user_data_file.read() - - print "Setting up security groups..." - if opts.security_group_prefix is None: - master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id) - slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id) - else: - master_group = get_or_make_group(conn, opts.security_group_prefix + "-master", opts.vpc_id) - slave_group = get_or_make_group(conn, opts.security_group_prefix + "-slaves", opts.vpc_id) - - authorized_address = opts.authorized_address - if master_group.rules == []: # Group was just now created - if opts.vpc_id is None: - master_group.authorize(src_group=master_group) - master_group.authorize(src_group=slave_group) - else: - master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=master_group) - master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=master_group) - master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=master_group) - master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=slave_group) - master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=slave_group) - master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=slave_group) - master_group.authorize('tcp', 22, 22, authorized_address) - master_group.authorize('tcp', 8080, 8081, authorized_address) - 
master_group.authorize('tcp', 18080, 18080, authorized_address) - master_group.authorize('tcp', 19999, 19999, authorized_address) - master_group.authorize('tcp', 50030, 50030, authorized_address) - master_group.authorize('tcp', 50070, 50070, authorized_address) - master_group.authorize('tcp', 60070, 60070, authorized_address) - master_group.authorize('tcp', 4040, 4045, authorized_address) - if opts.ganglia: - master_group.authorize('tcp', 5080, 5080, authorized_address) - if slave_group.rules == []: # Group was just now created - if opts.vpc_id is None: - slave_group.authorize(src_group=master_group) - slave_group.authorize(src_group=slave_group) - else: - slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=master_group) - slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=master_group) - slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=master_group) - slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=slave_group) - slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=slave_group) - slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=slave_group) - slave_group.authorize('tcp', 22, 22, authorized_address) - slave_group.authorize('tcp', 8080, 8081, authorized_address) - slave_group.authorize('tcp', 50060, 50060, authorized_address) - slave_group.authorize('tcp', 50075, 50075, authorized_address) - slave_group.authorize('tcp', 60060, 60060, authorized_address) - slave_group.authorize('tcp', 60075, 60075, authorized_address) - - # Check if instances are already running in our groups - existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, - die_on_error=False) - if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances running in " + - "group %s or %s" % (master_group.name, slave_group.name)) - 
sys.exit(1) - - # Figure out Spark AMI - if opts.ami is None: - opts.ami = get_spark_ami(opts.instance_type, opts.region, opts.spark_ec2_git_repo, opts.spark_ec2_git_branch) - - if opts.master_ami is None: - opts.master_ami = get_spark_ami(opts.master_instance_type, opts.region, opts.spark_ec2_git_repo, opts.spark_ec2_git_branch) - - # we use group ids to work around https://github.com/boto/boto/issues/350 - additional_group_ids = [] - if opts.additional_security_group: - additional_group_ids = [sg.id - for sg in conn.get_all_security_groups() - if opts.additional_security_group in (sg.name, sg.id)] - print "Launching instances..." - - try: - image = conn.get_all_images(image_ids=[opts.ami])[0] - except: - print >> stderr, "Could not find AMI " + opts.ami - sys.exit(1) - - try: - master_image = conn.get_all_images(image_ids=[opts.master_ami])[0] - except: - print >> stderr, "Could not find AMI " + opts.master_ami - sys.exit(1) - - # Create block device mapping so that we can add EBS volumes if asked to. - # The first drive is attached as /dev/sds, 2nd as /dev/sdt, ... 
/dev/sdz - block_map = BlockDeviceMapping() - if opts.ebs_vol_size > 0: - for i in range(opts.ebs_vol_num): - device = EBSBlockDeviceType() - device.size = opts.ebs_vol_size - device.volume_type = opts.ebs_vol_type - device.delete_on_termination = True - block_map["/dev/sd" + chr(ord('s') + i)] = device - - for i in range(get_num_disks(opts.instance_type)): - dev = BlockDeviceType() - dev.ephemeral_name = 'ephemeral%d' % i - name = '/dev/xvd' + string.letters[i + 1] - block_map[name] = dev - - # Launch slaves - if opts.spot_price is not None: - # Launch spot instances with the requested price - print ("Requesting %d slaves as spot instances with price $%.3f" % - (opts.slaves, opts.spot_price)) - zones = get_zones(conn, opts) - num_zones = len(zones) - i = 0 - my_req_ids = [] - for zone in zones: - num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) - slave_reqs = conn.request_spot_instances( - price=opts.spot_price, - image_id=opts.ami, - launch_group="launch-group-%s" % cluster_name, - placement=zone, - count=num_slaves_this_zone, - key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) - my_req_ids += [req.id for req in slave_reqs] - i += 1 - - start_time = datetime.now() - print "Waiting for spot instances to be granted... 
Request IDs: %s " % my_req_ids - try: - while True: - time.sleep(10) - reqs = conn.get_all_spot_instance_requests(my_req_ids) - active_instance_ids = filter(lambda req: req.state == "active", reqs) - invalid_states = ["capacity-not-available", "capacity-oversubscribed", "price-too-low"] - invalid = filter(lambda req: req.status.code in invalid_states, reqs) - if len(invalid) > 0: - raise Exception("Invalid state for spot request: %s - status: %s" % - (invalid[0].id, invalid[0].status.message)) - if len(active_instance_ids) == opts.slaves: - print "All %d slaves granted" % opts.slaves - reservations = conn.get_all_reservations(active_instance_ids) - slave_nodes = [] - for r in reservations: - slave_nodes += r.instances - break - else: - print "%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves) - - if (datetime.now() - start_time).seconds > opts.spot_timeout * 60: - raise Exception("Timed out while waiting for spot instances") - except: - print "Error: %s" % sys.exc_info()[1] - print "Canceling spot instance requests" - conn.cancel_spot_instance_requests(my_req_ids) - # Log a warning if any of these requests actually launched instances: - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - running = len(master_nodes) + len(slave_nodes) - if running: - print >> stderr, ("WARNING: %d instances are still running" % running) - sys.exit(0) - else: - # Launch non-spot instances - zones = get_zones(conn, opts) - num_zones = len(zones) - i = 0 - slave_nodes = [] - for zone in zones: - num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) - if num_slaves_this_zone > 0: - slave_res = image.run(key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - block_device_map=block_map, - subnet_id=opts.subnet_id, - 
placement_group=opts.placement_group, - user_data=user_data_content) - slave_nodes += slave_res.instances - print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone, - zone, slave_res.id) - i += 1 - - # Launch or resume masters - if existing_masters: - print "Starting master..." - for inst in existing_masters: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - master_nodes = existing_masters - else: - master_type = opts.master_instance_type - if master_type == "": - master_type = opts.instance_type - if opts.zone == 'all': - opts.zone = random.choice(conn.get_all_zones()).name - master_res = master_image.run(key_name=opts.key_pair, - security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) - - master_nodes = master_res.instances - print "Launched master in %s, regid = %s" % (zone, master_res.id) - - # This wait time corresponds to SPARK-4983 - print "Waiting for AWS to propagate instance metadata..." - time.sleep(5) - # Give the instances descriptive names - for master in master_nodes: - master.add_tag( - key='Name', - value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) - for slave in slave_nodes: - slave.add_tag( - key='Name', - value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) - - # Return all the instances - return (master_nodes, slave_nodes) - - -# Get the EC2 instances in an existing cluster if available. -# Returns a tuple of lists of EC2 instance objects for the masters and slaves -def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): - print "Searching for existing cluster " + cluster_name + "..." 
- reservations = conn.get_all_reservations() - master_nodes = [] - slave_nodes = [] - for res in reservations: - active = [i for i in res.instances if is_active(i)] - for inst in active: - group_names = [g.name for g in inst.groups] - if (cluster_name + "-master") in group_names: - master_nodes.append(inst) - elif (cluster_name + "-slaves") in group_names: - slave_nodes.append(inst) - if any((master_nodes, slave_nodes)): - print "Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes)) - if master_nodes != [] or not die_on_error: - return (master_nodes, slave_nodes) - else: - if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" - else: - print >> sys.stderr, "ERROR: Could not find any existing cluster" - sys.exit(1) - - -# Deploy configuration files and run setup scripts on a newly launched -# or started EC2 cluster. - - -def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): - master = master_nodes[0].public_dns_name - if deploy_ssh_key: - print "Generating cluster's SSH key on master..." - key_setup = """ - [ -f ~/.ssh/id_rsa ] || - (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && - cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys) - """ - ssh(master, opts, key_setup) - dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) - print "Transferring cluster's SSH key to slaves..." 
- for slave in slave_nodes: - print slave.public_dns_name - ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar) - - modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone', 'tachyon'] - - if opts.hadoop_major_version == "1": - modules = filter(lambda x: x != "mapreduce", modules) - - if opts.ganglia: - modules.append('ganglia') - - # NOTE: We should clone the repository before running deploy_files to - # prevent ec2-variables.sh from being overwritten - print "Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( - r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch) - ssh( - host=master, - opts=opts, - command="rm -rf spark-ec2" - + " && " - + "git clone {r} -b {b} spark-ec2".format(r=opts.spark_ec2_git_repo, - b=opts.spark_ec2_git_branch) - ) - - print "Deploying files to master..." - deploy_files( - conn=conn, - root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", - opts=opts, - master_nodes=master_nodes, - slave_nodes=slave_nodes, - modules=modules - ) - - print "Running setup on master..." - setup_spark_cluster(master, opts) - print "Done!" - - -def setup_spark_cluster(master, opts): - ssh(master, opts, "chmod u+x spark-ec2/setup.sh") - ssh(master, opts, "spark-ec2/setup.sh") - print "Spark standalone cluster started at http://%s:8080" % master - - if opts.ganglia: - print "Ganglia started at http://%s:5080/ganglia" % master - - -def is_ssh_available(host, opts, print_ssh_output=True): - """ - Check if SSH is available on a host. 
- """ - s = subprocess.Popen( - ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', - '%s@%s' % (opts.user, host), stringify_command('true')], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order - ) - cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout - - if s.returncode != 0 and print_ssh_output: - # extra leading newline is for spacing in wait_for_cluster_state() - print textwrap.dedent("""\n - Warning: SSH connection error. (This could be temporary.) - Host: {h} - SSH return code: {r} - SSH output: {o} - """).format( - h=host, - r=s.returncode, - o=cmd_output.strip() - ) - - return s.returncode == 0 - - -def is_cluster_ssh_available(cluster_instances, opts): - """ - Check if SSH is available on all the instances in a cluster. - """ - for i in cluster_instances: - if not is_ssh_available(host=i.ip_address, opts=opts): - return False - else: - return True - - -def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): - """ - Wait for all the instances in the cluster to reach a designated state. - - cluster_instances: a list of boto.ec2.instance.Instance - cluster_state: a string representing the desired state of all the instances in the cluster - value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as - 'running', 'terminated', etc. 
- (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250) - """ - sys.stdout.write( - "Waiting for cluster to enter '{s}' state.".format(s=cluster_state) - ) - sys.stdout.flush() - - start_time = datetime.now() - num_attempts = 0 - - while True: - time.sleep(5 * num_attempts) # seconds - - for i in cluster_instances: - i.update() - - statuses = conn.get_all_instance_status(instance_ids=[i.id for i in cluster_instances]) - - if cluster_state == 'ssh-ready': - if all(i.state == 'running' for i in cluster_instances) and \ - all(s.system_status.status == 'ok' for s in statuses) and \ - all(s.instance_status.status == 'ok' for s in statuses) and \ - is_cluster_ssh_available(cluster_instances, opts): - break - else: - if all(i.state == cluster_state for i in cluster_instances): - break - - num_attempts += 1 - - sys.stdout.write(".") - sys.stdout.flush() - - sys.stdout.write("\n") - - end_time = datetime.now() - print "Cluster is now in '{s}' state. Waited {t} seconds.".format( - s=cluster_state, - t=(end_time - start_time).seconds - ) - - -# Get number of local disks available for a given EC2 instance type. -def get_num_disks(instance_type): - # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2014-06-20 - # For easy maintainability, please keep this manually-inputted dictionary sorted by key. 
- disks_by_instance = { - "c1.medium": 1, - "c1.xlarge": 4, - "c3.2xlarge": 2, - "c3.4xlarge": 2, - "c3.8xlarge": 2, - "c3.large": 2, - "c3.xlarge": 2, - "cc1.4xlarge": 2, - "cc2.8xlarge": 4, - "cg1.4xlarge": 2, - "cr1.8xlarge": 2, - "g2.2xlarge": 1, - "hi1.4xlarge": 2, - "hs1.8xlarge": 24, - "i2.2xlarge": 2, - "i2.4xlarge": 4, - "i2.8xlarge": 8, - "i2.xlarge": 1, - "m1.large": 2, - "m1.medium": 1, - "m1.small": 1, - "m1.xlarge": 4, - "m2.2xlarge": 1, - "m2.4xlarge": 2, - "m2.xlarge": 1, - "m3.2xlarge": 2, - "m3.large": 1, - "m3.medium": 1, - "m3.xlarge": 2, - "r3.2xlarge": 1, - "r3.4xlarge": 1, - "r3.8xlarge": 2, - "r3.large": 1, - "r3.xlarge": 1, - "t1.micro": 0, - 'd2.xlarge': 3, - 'd2.2xlarge': 6, - 'd2.4xlarge': 12, - 'd2.8xlarge': 24, - } - if instance_type in disks_by_instance: - return disks_by_instance[instance_type] - else: - print >> stderr, ("WARNING: Don't know number of disks on instance type %s; assuming 1" - % instance_type) - return 1 - - -# Deploy the configuration file templates in a given local directory to -# a cluster, filling in any template parameters with information about the -# cluster (e.g. lists of masters and slaves). Files are only deployed to -# the first master instance in the cluster, and we expect the setup -# script to be run on that instance to copy them to other nodes. -# -# root_dir should be an absolute path to the directory with the files we want to deploy. 
-def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): - active_master = master_nodes[0].public_dns_name - - num_disks = get_num_disks(opts.instance_type) - hdfs_data_dirs = "/mnt/ephemeral-hdfs/data" - mapred_local_dirs = "/mnt/hadoop/mrlocal" - spark_local_dirs = "/mnt/spark" - if num_disks > 1: - for i in range(2, num_disks + 1): - hdfs_data_dirs += ",/mnt%d/ephemeral-hdfs/data" % i - mapred_local_dirs += ",/mnt%d/hadoop/mrlocal" % i - spark_local_dirs += ",/mnt%d/spark" % i - - cluster_url = "%s:7077" % active_master - - if opts.spark_version.startswith("http"): - # Custom pre-built spark package - spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo) - elif "." in opts.spark_version: - # Pre-built Spark deploy - spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo) - else: - # Spark-only custom deploy - spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version) - - template_vars = { - "master_list": '\n'.join([i.public_dns_name for i in master_nodes]), - "active_master": active_master, - "slave_list": '\n'.join([i.public_dns_name for i in slave_nodes]), - "cluster_url": cluster_url, - "hdfs_data_dirs": hdfs_data_dirs, - "mapred_local_dirs": mapred_local_dirs, - "spark_local_dirs": spark_local_dirs, - "swap": str(opts.swap), - "modules": '\n'.join(modules), - "spark_version": spark_v, - "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": "%d" % opts.worker_instances, - "spark_master_opts": opts.master_opts - } - - if opts.copy_aws_credentials: - template_vars["aws_access_key_id"] = conn.aws_access_key_id - template_vars["aws_secret_access_key"] = conn.aws_secret_access_key - else: - template_vars["aws_access_key_id"] = "" - template_vars["aws_secret_access_key"] = "" - - # Create a temp directory in which we will place all the files to be - # deployed after we substitue template parameters in them - tmp_dir = tempfile.mkdtemp() - for path, dirs, files in 
os.walk(root_dir): - if path.find(".svn") == -1: - dest_dir = os.path.join('/', path[len(root_dir):]) - local_dir = tmp_dir + dest_dir - if not os.path.exists(local_dir): - os.makedirs(local_dir) - for filename in files: - if filename[0] not in '#.~' and filename[-1] != '~': - dest_file = os.path.join(dest_dir, filename) - local_file = tmp_dir + dest_file - with open(os.path.join(path, filename)) as src: - with open(local_file, "w") as dest: - text = src.read() - for key in template_vars: - text = text.replace("{{" + key + "}}", template_vars[key]) - dest.write(text) - dest.close() - # rsync the whole directory over to the master machine - command = [ - 'rsync', '-rv', - '-e', stringify_command(ssh_command(opts)), - "%s/" % tmp_dir, - "%s@%s:/" % (opts.user, active_master) - ] - subprocess.check_call(command) - # Remove the temp directory we created above - shutil.rmtree(tmp_dir) - - -def stringify_command(parts): - if isinstance(parts, str): - return parts - else: - return ' '.join(map(pipes.quote, parts)) - - -def ssh_args(opts): - parts = ['-o', 'StrictHostKeyChecking=no'] - parts += ['-o', 'UserKnownHostsFile=/dev/null'] - if opts.identity_file is not None: - parts += ['-i', opts.identity_file] - return parts - - -def ssh_command(opts): - return ['ssh'] + ssh_args(opts) - - -# Run a command on a host through ssh, retrying up to five times -# and then throwing an exception if ssh continues to fail. -def ssh(host, opts, command): - tries = 0 - while True: - try: - return subprocess.check_call( - ssh_command(opts) + ['-t', '-t', '%s@%s' % (opts.user, host), - stringify_command(command)]) - except subprocess.CalledProcessError as e: - if tries > 5: - # If this was an ssh failure, provide the user with hints. 
- if e.returncode == 255: - raise UsageError( - "Failed to SSH to remote host {0}.\n" + - "Please check that you have provided the correct --identity-file and " + - "--key-pair parameters and try again.".format(host)) - else: - raise e - print >> stderr, \ - "Error executing remote command, retrying after 30 seconds: {0}".format(e) - time.sleep(30) - tries = tries + 1 - - -# Backported from Python 2.7 for compatiblity with 2.6 (See SPARK-1990) -def _check_output(*popenargs, **kwargs): - if 'stdout' in kwargs: - raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) - output, unused_err = process.communicate() - retcode = process.poll() - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise subprocess.CalledProcessError(retcode, cmd, output=output) - return output - - -def ssh_read(host, opts, command): - return _check_output( - ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) - - -def ssh_write(host, opts, command, arguments): - tries = 0 - while True: - proc = subprocess.Popen( - ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)], - stdin=subprocess.PIPE) - proc.stdin.write(arguments) - proc.stdin.close() - status = proc.wait() - if status == 0: - break - elif tries > 5: - raise RuntimeError("ssh_write failed with error %s" % proc.returncode) - else: - print >> stderr, \ - "Error {0} while executing remote command, retrying after 30 seconds".format(status) - time.sleep(30) - tries = tries + 1 - - -# Gets a list of zones to launch instances in -def get_zones(conn, opts): - if opts.zone == 'all': - zones = [z.name for z in conn.get_all_zones()] - else: - zones = [opts.zone] - return zones - - -# Gets the number of items in a partition -def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total / num_partitions - if (total % num_partitions) - 
current_partitions > 0: - num_slaves_this_zone += 1 - return num_slaves_this_zone - - -def real_main(): - (opts, action, cluster_name) = parse_args() - - # Input parameter validation - get_validate_spark_version(opts.spark_version, opts.spark_git_repo) - - if opts.wait is not None: - # NOTE: DeprecationWarnings are silent in 2.7+ by default. - # To show them, run Python with the -Wdefault switch. - # See: https://docs.python.org/3.5/whatsnew/2.7.html - warnings.warn( - "This option is deprecated and has no effect. " - "spark-ec2 automatically waits as long as necessary for clusters to start up.", - DeprecationWarning - ) - - if opts.identity_file is not None: - if not os.path.exists(opts.identity_file): - print >> stderr,\ - "ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file) - sys.exit(1) - - file_mode = os.stat(opts.identity_file).st_mode - if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00': - print >> stderr, "ERROR: The identity file must be accessible only by you." - print >> stderr, 'You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file) - sys.exit(1) - - if opts.instance_type not in EC2_INSTANCE_TYPES: - print >> stderr, "Warning: Unrecognized EC2 instance type for instance-type: {t}".format( - t=opts.instance_type) - - if opts.master_instance_type != "": - if opts.master_instance_type not in EC2_INSTANCE_TYPES: - print >> stderr, \ - "Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( - t=opts.master_instance_type) - - if opts.ebs_vol_num > 8: - print >> stderr, "ebs-vol-num cannot be greater than 8" - sys.exit(1) - - # Prevent breaking ami_prefix (/, .git and startswith checks) - # Prevent forks with non spark-ec2 names for now. 
- if opts.spark_ec2_git_repo.endswith("/") or \ - opts.spark_ec2_git_repo.endswith(".git") or \ - not opts.spark_ec2_git_repo.startswith("https://github.com") or \ - not opts.spark_ec2_git_repo.endswith("spark-ec2"): - print >> stderr, "spark-ec2-git-repo must be a github repo and it must not have a " \ - "trailing / or .git. " \ - "Furthermore, we currently only support forks named spark-ec2." - sys.exit(1) - - try: - conn = ec2.connect_to_region(opts.region) - except Exception as e: - print >> stderr, (e) - sys.exit(1) - - # Select an AZ at random if it was not specified. - if opts.zone == "": - opts.zone = random.choice(conn.get_all_zones()).name - - if action == "launch": - if opts.slaves <= 0: - print >> sys.stderr, "ERROR: You have to start at least 1 slave" - sys.exit(1) - if opts.resume: - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - else: - (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name) - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready' - ) - setup_cluster(conn, master_nodes, slave_nodes, opts, True) - - elif action == "destroy": - print "Are you sure you want to destroy the cluster %s?" % cluster_name - print "The following instances will be terminated:" - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - for inst in master_nodes + slave_nodes: - print "> %s" % inst.public_dns_name - - msg = "ALL DATA ON ALL NODES WILL BE LOST!!\nDestroy cluster %s (y/N): " % cluster_name - response = raw_input(msg) - if response == "y": - print "Terminating master..." - for inst in master_nodes: - inst.terminate() - print "Terminating slaves..." - for inst in slave_nodes: - inst.terminate() - - # Delete security groups as well - if opts.delete_groups: - print "Deleting security groups (this will take some time)..." 
- group_names = [cluster_name + "-master", cluster_name + "-slaves"] - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='terminated' - ) - attempt = 1 - while attempt <= 3: - print "Attempt %d" % attempt - groups = [g for g in conn.get_all_security_groups() if g.name in group_names] - success = True - # Delete individual rules in all groups before deleting groups to - # remove dependencies between them - for group in groups: - print "Deleting rules in security group " + group.name - for rule in group.rules: - for grant in rule.grants: - success &= group.revoke(ip_protocol=rule.ip_protocol, - from_port=rule.from_port, - to_port=rule.to_port, - src_group=grant) - - # Sleep for AWS eventual-consistency to catch up, and for instances - # to terminate - time.sleep(30) # Yes, it does have to be this long :-( - for group in groups: - try: - conn.delete_security_group(group.name) - print "Deleted security group " + group.name - except boto.exception.EC2ResponseError: - success = False - print "Failed to delete security group " + group.name - - # Unfortunately, group.revoke() returns True even if a rule was not - # deleted, so this needs to be rerun if something fails - if success: - break - - attempt += 1 - - if not success: - print "Failed to delete all security groups after 3 tries." - print "Try re-running in a few minutes." - - elif action == "login": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - master = master_nodes[0].public_dns_name - print "Logging into master " + master + "..." 
- proxy_opt = [] - if opts.proxy_port is not None: - proxy_opt = ['-D', opts.proxy_port] - subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) - - elif action == "reboot-slaves": - response = raw_input( - "Are you sure you want to reboot the cluster " + - cluster_name + " slaves?\n" + - "Reboot cluster slaves " + cluster_name + " (y/N): ") - if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - print "Rebooting slaves..." - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - print "Rebooting " + inst.id - inst.reboot() - - elif action == "get-master": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print master_nodes[0].public_dns_name - - elif action == "stop": - response = raw_input( - "Are you sure you want to stop the cluster " + - cluster_name + "?\nDATA ON EPHEMERAL DISKS WILL BE LOST, " + - "BUT THE CLUSTER WILL KEEP USING SPACE ON\n" + - "AMAZON EBS IF IT IS EBS-BACKED!!\n" + - "All data on spot-instance slaves will be lost.\n" + - "Stop cluster " + cluster_name + " (y/N): ") - if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - print "Stopping master..." - for inst in master_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.stop() - print "Stopping slaves..." - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - if inst.spot_instance_request_id: - inst.terminate() - else: - inst.stop() - - elif action == "start": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print "Starting slaves..." - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - print "Starting master..." 
- for inst in master_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready' - ) - setup_cluster(conn, master_nodes, slave_nodes, opts, False) - - else: - print >> stderr, "Invalid action: %s" % action - sys.exit(1) - - -def main(): - try: - real_main() - except UsageError, e: - print >> stderr, "\nError:\n", e - sys.exit(1) - - -if __name__ == "__main__": - logging.basicConfig() - main() diff --git a/tools/utils.py b/tools/utils.py index bac56029..8c97d23f 100644 --- a/tools/utils.py +++ b/tools/utils.py @@ -1,10 +1,19 @@ #!/usr/bin/env python +import ast import logging +from pprint import pprint import boto.ec2 import sys import subprocess import select import time +import json +from os.path import exists +from os import makedirs +import os + +# get a folder_log_path from env variable +folder_log_path = os.getenv('LOG_FOLDER') logging.basicConfig(level=logging.INFO) @@ -12,7 +21,7 @@ def get_active_instances(conn): active = [instance for res in conn.get_all_instances() for instance in res.instances if instance.state in set(['pending', 'running', - 'stopping', 'stopped'])] + 'stopping', 'stopped', 'shutting-down'])] return active def parse_nodes(active_instances, cluster_name): @@ -20,9 +29,10 @@ def parse_nodes(active_instances, cluster_name): slave_nodes = [] for instance in active_instances: group_names = [g.name for g in instance.groups] - if (cluster_name + "-master") in group_names: + # This can handle both spark-ec2 and flintrock clusters + if (cluster_name + "-master") in group_names or (("flintrock-" + cluster_name) in group_names and instance.tags.get('flintrock-role') == 'master'): master_nodes.append(instance) - elif (cluster_name + "-slaves") in group_names: + elif (cluster_name + "-slaves") in group_names or (("flintrock-" + cluster_name) in group_names and instance.tags.get('flintrock-role') in 
('slave', None)): slave_nodes.append(instance) return (master_nodes, slave_nodes) @@ -39,6 +49,93 @@ def get_active_nodes(cluster_name, region): return parse_nodes(active, cluster_name) +def get_active_nodes_by_tag(region, tag_name, tag_value): + conn = boto.ec2.connect_to_region(region) + filter = {"tag:{0}".format(tag_name):["{0}".format(tag_value)], "instance-state-name":["running"]} + return conn.get_only_instances(filters=filter) + +def get_fleet_id_by_cluster_name(cluster_name): + # create a array with the requests ids + fleet_id = '' + file_name = '{0}.json'.format(cluster_name) + + if folder_log_path: + # check if the folder exists and if not create it + if not exists(folder_log_path): + makedirs(folder_log_path) + + file_name = '{0}/{1}.json'.format(folder_log_path, cluster_name) + + # verify if the file exists + if exists(file_name): + # open a json log file if exists + with open(file_name) as json_file:# deserialize the json file to object + json_content = json.load(json_file) + + # create a array with the requests ids + for request in json_content: + fleet_id = str(request['FleetId']) + + return fleet_id + + +def destroy_by_fleet_id(region, cluster_name): + conn = boto.ec2.connect_to_region(region) + fleet_instances_ids = [] + instances = [] + + try: + # get fleet id from json log file + fleet_id = get_fleet_id_by_cluster_name(cluster_name) + + if fleet_id in [None, '']: + raise Exception('There is no fleet id to delete. 
Keep going.') + + logging.info('The fleet id found in json log file: {0}'.format(fleet_id)) + + # call an external script to delete the fleet and retrieve the list of instances + delete_fleet_script = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'delete_fleet.py') + process = subprocess.Popen(["python3", delete_fleet_script, region, fleet_id], stdout=subprocess.PIPE) + stdout_str = process.communicate()[0] + + # the subprocess return a string with the character '\n' separating the delete message and the list of instances + stdout_str_split = stdout_str.split('\n') + + # message of fleet deletion + deleted_fleet = stdout_str_split[0] + logging.info(deleted_fleet) + + # getting the list of the string containing the list of istances + # e.g."['i-0e90a67a64693dc39', 'i-00889275ebe58bb7b', 'i-0982e3e6728044bef']" + fleet_instances = ast.literal_eval(stdout_str_split[1]) + fleet_instances_ids.extend(fleet_instances) + + # test if the instance id is not empty and contains an instance id for sure + if len(fleet_instances_ids) > 0 and fleet_instances_ids[0].startswith('i-'): + instances_requested = conn.get_only_instances(fleet_instances_ids) + + # terminate instances from request spot + for instance in instances_requested: + # checking again if the object is in the list to not terminate wrong machines + if fleet_instances_ids.index(instance.id) > -1: + if instance.state == 'running': + logging.info('Terminating instance: {0}'.format(instance.id)) + # add only instances that are running to return list + instances.append(instance) + # terminate the instance + instance.terminate() + elif instance.state == 'shutting-down': + # add the instance to the wait list + instances.append(instance) + + except Exception as e: + logging.error(e) + logging.error('Error to destroy cluster {0} by fleet id.'.format(cluster_name)) + pass + + return instances + + def tag_instances(cluster_name, tags, region): conn = boto.ec2.connect_to_region(region) @@ -78,7 +175,7 @@ def 
def check_call_with_timeout_describe(args, stdin=None, stdout=None,
                                     stderr=None, shell=False,
                                     timeout_total_minutes=0,
                                     timeout_inactivity_minutes=0):
    """Run ``args`` as a subprocess, forwarding its output, with timeouts.

    The child's stdout/stderr pipes are drained through ``read_from_to`` into
    ``stdout``/``stderr`` every half second until the process exits or one of
    the timeouts expires.

    Args:
        args: Command passed to ``subprocess.Popen``.
        stdin: Passed through to ``subprocess.Popen``.
        stdout: Destination for the child's stdout; defaults to ``sys.stdout``.
        stderr: Destination for the child's stderr; defaults to ``sys.stderr``.
        shell: Passed through to ``subprocess.Popen``.
        timeout_total_minutes: Maximum total run time in minutes; ``0``
            disables the check.
        timeout_inactivity_minutes: Maximum time in minutes without
            ``read_from_to`` reporting activity; ``0`` disables the check.

    Returns:
        ``args[5]`` when ``args`` has more than five elements, otherwise
        ``None``.
        NOTE(review): presumably the sixth argument identifies what is being
        described -- confirm against callers.

    Raises:
        ProcessTimeoutException: When a timeout expires. The process is asked
            to terminate and is killed if still alive after about 10 seconds.
        subprocess.CalledProcessError: When the process exits with a non-zero
            return code.
    """
    stdout = stdout or sys.stdout
    stderr = stderr or sys.stderr
    begin_time_total = time.time()
    begin_time_inactivity = time.time()
    p = subprocess.Popen(args,
                         stdin=stdin,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE,
                         shell=shell,
                         universal_newlines=False)
    while True:
        # A truthy result from read_from_to restarts the inactivity clock.
        if read_from_to(p.stdout, stdout):
            begin_time_inactivity = time.time()
        if read_from_to(p.stderr, stderr):
            begin_time_inactivity = time.time()
        if p.poll() is not None:
            break
        terminate_by_total_timeout = timeout_total_minutes > 0 and time.time() - begin_time_total > (timeout_total_minutes * 60)
        terminate_by_inactivity_timeout = timeout_inactivity_minutes > 0 and time.time() - begin_time_inactivity > (timeout_inactivity_minutes * 60)
        if terminate_by_inactivity_timeout or terminate_by_total_timeout:
            p.terminate()
            # Give the process up to ~10 seconds to exit gracefully. poll()
            # must be *called*: the bare attribute ``p.poll`` is never None.
            for _ in range(100):
                if p.poll() is not None:
                    break
                time.sleep(0.1)
            # Only force-kill when the graceful termination did not work.
            if p.poll() is None:
                p.kill()
            message = 'Terminated by inactivity' if terminate_by_inactivity_timeout else 'Terminated by total timeout'
            raise ProcessTimeoutException(message)
        time.sleep(0.5)
    # Drain whatever was written between the last read and process exit.
    read_from_to(p.stdout, stdout)
    read_from_to(p.stderr, stderr)
    if p.returncode != 0:
        # NOTE(review): this formats the destination stream objects, not the
        # captured text; the actual output has already been forwarded above.
        stdall = 'STDOUT:\n{}\nSTDERR:\n{}'.format(stdout, stderr)
        raise subprocess.CalledProcessError(p.returncode, args, output=stdall)
    if len(args) > 5:
        return args[5]