From 67aa030e06346ffef6fb9cf5352ee36d928c4de2 Mon Sep 17 00:00:00 2001 From: lfrancioli Date: Thu, 17 Nov 2016 14:24:03 -0500 Subject: [PATCH 01/51] Fixed bug causing type mismatch when using --csq on a VDS with existing csq annotation (#1101) --- src/main/scala/org/broadinstitute/hail/driver/VEP.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/org/broadinstitute/hail/driver/VEP.scala b/src/main/scala/org/broadinstitute/hail/driver/VEP.scala index 74e6c9a664a..c7598b61ed6 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/VEP.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/VEP.scala @@ -211,7 +211,7 @@ object VEP extends Command { val rootType = vds.vaSignature.getOption(root) .filter { t => - val r = t == vepSignature + val r = t == (if(csq) TString else vepSignature) if (!r) { if (options.force) warn(s"type for ${ options.root } does not match vep signature, overwriting.") From 550c4159494940af9cf0c2ea3205db8557c66d46 Mon Sep 17 00:00:00 2001 From: cseed Date: Thu, 17 Nov 2016 15:13:30 -0500 Subject: [PATCH 02/51] Added gtj, gtk, gtIndex to expr. (#1102) --- docs/reference/HailExpressionLanguage.md | 4 ++++ .../scala/org/broadinstitute/hail/expr/FunctionRegistry.scala | 4 ++++ .../scala/org/broadinstitute/hail/methods/ExprSuite.scala | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/docs/reference/HailExpressionLanguage.md b/docs/reference/HailExpressionLanguage.md index ba3d6876a47..fea1043b55c 100644 --- a/docs/reference/HailExpressionLanguage.md +++ b/docs/reference/HailExpressionLanguage.md @@ -139,6 +139,10 @@ Several Hail commands provide the ability to perform a broad array of computatio - range: `range(end)` or `range(start, end)`. This function will produce an `Array[Int]`. `range(3)` produces `[0, 1, 2]`. `range(-2, 2)` produces `[-2, -1, 0, 1]`. + - `gtj(i)` and `gtk(i)`. Convert from genotype index (triangular numbers) to `j/k` pairs. + + - `gtIndex(j, k)`. 
Convert from `j/k` pair to genotype index (triangular numbers). + **Note:** - All variables and values are case sensitive diff --git a/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala b/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala index cbb6675d631..bef2a399d04 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala @@ -360,4 +360,8 @@ object FunctionRegistry { registerConversion { (x: Long) => x.toDouble } registerConversion { (x: Int) => x.toLong } registerConversion { (x: Float) => x.toDouble } + + register("gtj", (i: Int) => Genotype.gtPair(i).j) + register("gtk", (i: Int) => Genotype.gtPair(i).k) + register("gtIndex", (j: Int, k: Int) => Genotype.gtIndex(j, k)) } diff --git a/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala index c81a992d21d..75e6ce76471 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala @@ -508,6 +508,10 @@ class ExprSuite extends SparkSuite { assert(eval[Boolean]("rnorm(2.0, 4.0).abs > -1.0").contains(true)) assert(eval[Any]("if (true) NA: Double else 0.0").isEmpty) + + assert(eval[Int]("gtIndex(3, 5)").contains(18)) + assert(eval[Int]("gtj(18)").contains(3)) + assert(eval[Int]("gtk(18)").contains(5)) } @Test def testParseTypes() { From 01803da488db33d5a5291b63bcdc2ab7f4671505 Mon Sep 17 00:00:00 2001 From: cseed Date: Thu, 17 Nov 2016 16:19:09 -0500 Subject: [PATCH 03/51] Added keep=True to filter methods. 
(#1104) --- python/pyhail/dataset.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py index e921bb6dcc8..a7bbfac1f42 100644 --- a/python/pyhail/dataset.py +++ b/python/pyhail/dataset.py @@ -513,14 +513,16 @@ def write(self, output, overwrite=False): pargs.append('--overwrite') return self.hc.run_command(self, pargs) - def filter_genotypes(self, condition): + def filter_genotypes(self, condition, keep=True): """Filter variants based on expression. :param str condition: Expression for filter condition. """ - pargs = ['filtergenotypes', '--keep', '-c', condition] + pargs = ['filtergenotypes', + '--keep' if keep else '--remove', + '-c', condition] return self.hc.run_command(self, pargs) def filter_multi(self): @@ -539,24 +541,28 @@ def filter_samples_all(self): pargs = ['filtersamples', 'all'] return self.hc.run_command(self, pargs) - def filter_samples_expr(self, condition): + def filter_samples_expr(self, condition, keep=True): """Filter samples based on expression. :param str condition: Expression for filter condition. """ - pargs = ['filtersamples', 'expr', '--keep', '-c', condition] + pargs = ['filtersamples', 'expr', + '--keep' if keep else '--remove', + '-c', condition] return self.hc.run_command(self, pargs) - def filter_samples_list(self, input): + def filter_samples_list(self, input, keep=True): """Filter samples with a sample list file. :param str input: Path to sample list file. 
""" - pargs = ['filtersamples', 'list', '--keep', '-i', input] + pargs = ['filtersamples', 'list', + '--keep' if keep else '--remove', + '-i', input] return self.hc.run_command(self, pargs) def filter_variants_all(self): @@ -565,34 +571,40 @@ def filter_variants_all(self): pargs = ['filtervariants', 'all'] return self.hc.run_command(self, pargs) - def filter_variants_expr(self, condition): + def filter_variants_expr(self, condition, keep=True): """Filter variants based on expression. :param str condition: Expression for filter condition. """ - pargs = ['filtervariants', 'expr', '--keep', '-c', condition] + pargs = ['filtervariants', 'expr', + '--keep' if keep else '--remove', + '-c', condition] return self.hc.run_command(self, pargs) - def filter_variants_intervals(self, input): + def filter_variants_intervals(self, input, keep=True): """Filter variants with an .interval_list file. :param str input: Path to .interval_list file. """ - pargs = ['filtervariants', 'intervals', '--keep', '-i', input] + pargs = ['filtervariants', 'intervals', + '--keep' if keep else '--remove', + '-i', input] return self.hc.run_command(self, pargs) - def filter_variants_list(self, input): + def filter_variants_list(self, input, keep=True): """Filter variants with a list of variants. :param str input: Path to variant list file. """ - pargs = ['filtervariants', 'list', '--keep', '-i', input] + pargs = ['filtervariants', 'list', + '--keep' if keep else '--remove', + '-i', input] return self.hc.run_command(self, pargs) def grm(self, format, output, id_file=None, n_file=None): From 3b45e751c56f0ddb1b30fadd9375733855873099 Mon Sep 17 00:00:00 2001 From: cseed Date: Fri, 18 Nov 2016 17:08:27 -0500 Subject: [PATCH 04/51] Updated python export_plink. 
(#1109) --- python/pyhail/dataset.py | 4 +- .../hail/driver/ExportPlink.scala | 68 +++++++++++++++++-- .../org/broadinstitute/hail/expr/Parser.scala | 13 ++++ .../hail/io/plink/ExportBedBimFam.scala | 4 -- .../hail/io/ExportPlinkSuite.scala | 20 ++++++ 5 files changed, 99 insertions(+), 10 deletions(-) diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py index a7bbfac1f42..984a23f6e6a 100644 --- a/python/pyhail/dataset.py +++ b/python/pyhail/dataset.py @@ -381,14 +381,14 @@ def export_genotypes(self, output, condition, types=None, export_ref=False, expo pargs.append('--print-missing') return self.hc.run_command(self, pargs) - def export_plink(self, output): + def export_plink(self, output, fam_expr = 'id = s.id'): """Export as PLINK .bed/.bim/.fam :param str output: Output file base. Will write .bed, .bim and .fam files. """ - pargs = ['exportplink', '--output', output] + pargs = ['exportplink', '--output', output, '--fam-expr', fam_expr] return self.hc.run_command(self, pargs) def export_samples(self, output, condition, types=None): diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala index 104f3db7724..f7b898f77a6 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala @@ -1,6 +1,7 @@ package org.broadinstitute.hail.driver import org.apache.spark.storage.StorageLevel +import org.broadinstitute.hail.expr.{BaseAggregable, EvalContext, Parser, TBoolean, TDouble, TGenotype, TSample, TString, Type} import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.io.plink.ExportBedBimFam import org.kohsuke.args4j.{Option => Args4jOption} @@ -11,6 +12,10 @@ object ExportPlink extends Command { @Args4jOption(required = true, name = "-o", aliases = Array("--output"), usage = "Output file base (will generate .bed, .bim, .fam)") var output: String = _ + + @Args4jOption(name = 
"-f", aliases = Array("--fam-expr"), + usage = "Expression for .fam file values, in sample context only (global, s, sa in scope), assignable fields: famID, id, matID, patID (String), isFemale (Boolean), isCase (Boolean) or qPheno (Double)") + var famExpr: String = "id = s.id" } def newOptions = new Options @@ -26,12 +31,64 @@ object ExportPlink extends Command { def run(state: State, options: Options): State = { val vds = state.vds + val symTab = Map( + "s" -> (0, TSample), + "sa" -> (1, vds.saSignature), + "global" -> (2, vds.globalSignature)) + + val ec = EvalContext(symTab) + ec.set(2, vds.globalAnnotation) + + type Formatter = (() => Option[Any]) => () => String + + val formatID: Formatter = f => () => f().map(_.asInstanceOf[String]).getOrElse("0") + val formatIsFemale: Formatter = f => () => f().map { + _.asInstanceOf[Boolean] match { + case true => "2" + case false => "1" + } + }.getOrElse("0") + val formatIsCase: Formatter = f => () => f().map { + _.asInstanceOf[Boolean] match { + case true => "2" + case false => "1" + } + }.getOrElse("-9") + val formatQPheno: Formatter = f => () => f().map(_.toString).getOrElse("-9") + + val famColumns: Map[String, (Type, Int, Formatter)] = Map( + "famID" -> (TString, 0, formatID), + "id" -> (TString, 1, formatID), + "patID" -> (TString, 2, formatID), + "matID" -> (TString, 3, formatID), + "isFemale" -> (TBoolean, 4, formatIsFemale), + "qPheno" -> (TDouble, 5, formatQPheno), + "isCase" -> (TBoolean, 5, formatIsCase)) + + val exprs = Parser.parseNamedExprs(options.famExpr, ec) + + val famFns: Array[() => String] = Array( + () => "0", () => "0", () => "0", () => "0", () => "-9", () => "-9") + + exprs.foreach { case (name, t, f) => + famColumns.get(name) match { + case Some((colt, i, formatter)) => + if (colt != t) + fatal("invalid type for .fam file column $h: expected $colt, got $t") + famFns(i) = formatter(f) + + case None => + fatal(s"no .fam file column $name") + } + } + val spaceRegex = """\s+""".r val badSampleIds = 
vds.sampleIds.filter(id => spaceRegex.findFirstIn(id).isDefined) if (badSampleIds.nonEmpty) { - fatal(s"""Found ${ badSampleIds.length } sample IDs with whitespace + fatal( + s"""Found ${ badSampleIds.length } sample IDs with whitespace | Please run `renamesamples' to fix this problem before exporting to plink format - | Bad sample IDs: @1 """.stripMargin, badSampleIds) + | Bad sample IDs: @1 """.stripMargin, badSampleIds) } val bedHeader = Array[Byte](108, 27, 1) @@ -49,8 +106,11 @@ object ExportPlink extends Command { plinkRDD.unpersist() val famRows = vds - .sampleIds - .map(ExportBedBimFam.makeFamRow) + .sampleIdsAndAnnotations + .map { case (s, sa) => + ec.setAll(s, sa) + famFns.map(_()).mkString("\t") + } state.hadoopConf.writeTextFile(options.output + ".fam")(out => famRows.foreach(line => { diff --git a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala index 3b3713fe8dc..6c368174695 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala @@ -186,6 +186,19 @@ object Parser extends JavaTokenParsers { path.tail } + def parseNamedExprs(code: String, ec: EvalContext): Array[(String, BaseType, () => Option[Any])] = { + val parsed = parseAll(named_args, code) match { + case Success(result, _) => result.asInstanceOf[Array[(String, AST)]] + case NoSuccess(msg, _) => fatal(msg) + } + + parsed.map { case (name, ast) => + ast.typecheck(ec) + val f = ast.eval(ec) + (name, ast.`type`, () => Option(f())) + } + } + def parseExprs(code: String, ec: EvalContext): (Array[(BaseType, () => Option[Any])]) = { if (code.matches("""\s*""")) diff --git a/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala b/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala index 005f7cea76f..3dd37176696 100644 --- a/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala +++ 
b/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala @@ -35,8 +35,4 @@ object ExportBedBimFam { val id = s"${v.contig}:${v.start}:${v.ref}:${v.alt}" s"""${v.contig}\t$id\t0\t${v.start}\t${v.alt}\t${v.ref}""" } - - def makeFamRow(s: String): String = { - s"0\t$s\t0\t0\t0\t-9" - } } diff --git a/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala b/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala index 815972b617e..bb4046583f3 100644 --- a/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala @@ -66,4 +66,24 @@ class ExportPlinkSuite extends SparkSuite { } ) } + + @Test def testFamExport() { + val plink = tmpDir.createTempFile("mendel") + + var s = State(sc, sqlContext) + s = ImportVCF.run(s, Array("src/test/resources/mendel.vcf")) + s = SplitMulti.run(s) + s = HardCalls.run(s) + s = AnnotateSamplesFam.run(s, Array("-i", "src/test/resources/mendel.fam", "-d", "\\\\s+")) + s = AnnotateSamplesExpr.run(s, Array("-c", "sa = sa.fam")) + s = AnnotateVariantsExpr.run(s, Array("-c", "va.rsid = str(v)")) + s = AnnotateVariantsExpr.run(s, Array("-c", "va = select(va, rsid)")) + + s = ExportPlink.run(s, Array("-o", plink, "-f", + "famID = sa.famID, id = s.id, matID = sa.matID, patID = sa.patID, isFemale = sa.isFemale, isCase = sa.isCase")) + + var s2 = ImportPlink.run(s, Array("--bfile", plink)) + + assert(s.vds.same(s2.vds)) + } } From d623bb3bde52532bc1b332acf2e6f804e3ae8414 Mon Sep 17 00:00:00 2001 From: jbloom22 Date: Mon, 21 Nov 2016 08:06:31 -0500 Subject: [PATCH 05/51] Added statistical functions useful for meta-analysis, genomic control lambda (#1111) * Added statistical functions useful for meta-analysis and genomic control lambda - added pnorm, qnorm, inverseChiSquaredTailOneDF to stats package - added tests to StatsSuite - added pnorm, qnorm, pchisq1tail, qchisq1tail to function registry - added tests to ExprSuite - added docs to 
HailExpressionLanguage.md * changed C to Z^2 --- docs/faq/ExpressionLanguage.md | 3 +++ docs/reference/HailExpressionLanguage.md | 6 ++++++ .../hail/expr/FunctionRegistry.scala | 6 ++++++ .../broadinstitute/hail/stats/package.scala | 18 +++++++++++++++-- .../hail/methods/ExprSuite.scala | 6 ++++++ .../hail/stats/StatsSuite.scala | 20 ++++++++++++++++++- 6 files changed, 56 insertions(+), 3 deletions(-) diff --git a/docs/faq/ExpressionLanguage.md b/docs/faq/ExpressionLanguage.md index 3a009907f25..bada94b4cd4 100644 --- a/docs/faq/ExpressionLanguage.md +++ b/docs/faq/ExpressionLanguage.md @@ -1 +1,4 @@ ## Expression Language + + + diff --git a/docs/reference/HailExpressionLanguage.md b/docs/reference/HailExpressionLanguage.md index fea1043b55c..249ed0574ea 100644 --- a/docs/reference/HailExpressionLanguage.md +++ b/docs/reference/HailExpressionLanguage.md @@ -58,6 +58,12 @@ Several Hail commands provide the ability to perform a broad array of computatio - pcoin(p) -- returns `true` with probability `p`. `p` should be between 0.0 and 1.0 - runif(min, max) -- returns a random draw from a uniform distribution on \[`min`, `max`). `min` should be less than or equal to `max` - rnorm(mean, sd) -- returns a random draw from a normal distribution with mean `mean` and standard deviation `sd`. `sd` should be non-negative + + - Statistics + - pnorm(x) -- Returns left-tail probability p for which p = Prob($Z$ < x) with $Z$ a standard normal random variable + - qnorm(p) -- Returns left-quantile x for which p = Prob($Z$ < x) with $Z$ a standard normal random variable. `p` must satisfy `0 < p < 1`. Inverse of `pnorm` + - pchisq1tail(x) -- Returns right-tail probability p for which p = Prob($Z^2$ > x) with $Z^2$ a chi-squared random variable with one degree of freedom. `x` must be positive + - qchisq1tail(p) -- Returns right-quantile x for which p = Prob($Z^2$ > x) with $Z^2$ a chi-squared RV with one degree of freedom. `p` must satisfy `0 < p <= 1`. 
Inverse of `pchisq1tail` - Array Operations: - constructor: `[element1, element2, ...]` -- Create a new array from elements of the same type. diff --git a/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala b/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala index bef2a399d04..c7bc99ea53f 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala @@ -356,6 +356,12 @@ object FunctionRegistry { register("runif", { (min: Double, max: Double) => min + (max - min) * math.random }) register("rnorm", { (mean: Double, sd: Double) => mean + sd * scala.util.Random.nextGaussian() }) + register("pnorm", { (x: Double) => pnorm(x) }) + register("qnorm", { (p: Double) => qnorm(p) }) + + register("pchisq1tail", { (x: Double) => chiSquaredTail(1.0, x) }) + register("qchisq1tail", { (p: Double) => inverseChiSquaredTailOneDF(p) }) + registerConversion((x: Int) => x.toDouble, priority = 2) registerConversion { (x: Long) => x.toDouble } registerConversion { (x: Int) => x.toLong } diff --git a/src/main/scala/org/broadinstitute/hail/stats/package.scala b/src/main/scala/org/broadinstitute/hail/stats/package.scala index 6abffa56048..2e43d01fb4b 100644 --- a/src/main/scala/org/broadinstitute/hail/stats/package.scala +++ b/src/main/scala/org/broadinstitute/hail/stats/package.scala @@ -2,10 +2,9 @@ package org.broadinstitute.hail import breeze.linalg.Matrix import org.apache.commons.math3.distribution.HypergeometricDistribution -import org.apache.commons.math3.special.Gamma +import org.apache.commons.math3.special.{Erf, Gamma} import org.apache.spark.SparkContext import org.broadinstitute.hail.annotations.Annotation -import org.broadinstitute.hail.expr.{TDouble, TInt, TStruct} import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.variant.{Genotype, Variant, VariantDataset, VariantMetadata, VariantSampleMatrix} @@ -271,11 +270,26 @@ package object stats 
{ Array(Option(pvalue), oddsRatioEstimate, confInterval._1, confInterval._2) } + val sqrt2 = math.sqrt(2) + + // Returns the p for which p = Prob(Z < x) with Z a standard normal RV + def pnorm(x: Double) = 0.5 * (1 + Erf.erf(x / sqrt2)) + + // Returns the x for which p = Prob(Z < x) with Z a standard normal RV + def qnorm(p: Double) = sqrt2 * Erf.erfInv(2 * p - 1) + + // Returns the p for which p = Prob(Z^2 > x) with Z^2 a chi-squared RV with one degree of freedom // This implementation avoids the round-off error truncation issue in // org.apache.commons.math3.distribution.ChiSquaredDistribution, // which computes the CDF with regularizedGammaP and p = 1 - CDF. def chiSquaredTail(df: Double, x: Double) = Gamma.regularizedGammaQ(df / 2, x / 2) + // Returns the x for which p = Prob(Z^2 > x) with Z^2 a chi-squared RV with one degree of freedom + def inverseChiSquaredTailOneDF(p: Double) = { + val q = qnorm(0.5 * p) + q * q + } + def uninitialized[T]: T = { class A { var x: T = _ diff --git a/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala index 75e6ce76471..2140bc830aa 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala @@ -507,6 +507,12 @@ class ExprSuite extends SparkSuite { assert(eval[Boolean]("rnorm(2.0, 4.0).abs > -1.0").contains(true)) + assert(D_==(eval[Double]("pnorm(qnorm(0.5))").get, 0.5)) + assert(D_==(eval[Double]("qnorm(pnorm(0.5))").get, 0.5)) + + assert(D_==(eval[Double]("qchisq1tail(pchisq1tail(0.5))").get, 0.5)) + assert(D_==(eval[Double]("pchisq1tail(qchisq1tail(0.5))").get, 0.5)) + assert(eval[Any]("if (true) NA: Double else 0.0").isEmpty) assert(eval[Int]("gtIndex(3, 5)").contains(18)) diff --git a/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala b/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala index 59a3fa40ac9..eeca50781e9 100644 --- 
a/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala @@ -1,7 +1,7 @@ package org.broadinstitute.hail.stats import breeze.linalg.DenseMatrix -import org.apache.commons.math3.distribution.ChiSquaredDistribution +import org.apache.commons.math3.distribution.{ChiSquaredDistribution, NormalDistribution} import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.variant.Variant import org.broadinstitute.hail.SparkSuite @@ -21,6 +21,24 @@ class StatsSuite extends SparkSuite { val chiSq5 = new ChiSquaredDistribution(5.2) assert(D_==(chiSquaredTail(5.2, 1), 1 - chiSq5.cumulativeProbability(1))) assert(D_==(chiSquaredTail(5.2, 5.52341), 1 - chiSq5.cumulativeProbability(5.52341))) + + assert(D_==(inverseChiSquaredTailOneDF(.1), chiSq1.inverseCumulativeProbability(1 - .1))) + assert(D_==(inverseChiSquaredTailOneDF(.0001), chiSq1.inverseCumulativeProbability(1 - .0001))) + + val a = List(.0000000001, .5, .9999999999, 1.0) + a.foreach(p => println(p, inverseChiSquaredTailOneDF(p))) + a.foreach(p => assert(D_==(chiSquaredTail(1.0, inverseChiSquaredTailOneDF(p)), p))) + } + + @Test def normTest() = { + val normalDist = new NormalDistribution() + assert(D_==(pnorm(1), normalDist.cumulativeProbability(1))) + assert(D_==(pnorm(-10), normalDist.cumulativeProbability(-10))) + assert(D_==(qnorm(.6), normalDist.inverseCumulativeProbability(.6))) + assert(D_==(qnorm(.0001), normalDist.inverseCumulativeProbability(.0001))) + + val a = List(0.0, .0000000001, .5, .9999999999, 1.0) + assert(a.forall(p => D_==(qnorm(pnorm(qnorm(p))), qnorm(p)))) } @Test def vdsFromMatrixTest() { From 60a197e7f286e85f7bc2de728c790fff0f64945f Mon Sep 17 00:00:00 2001 From: jbloom22 Date: Mon, 21 Nov 2016 10:38:43 -0500 Subject: [PATCH 06/51] changed nSmaller to nLess for histograms (#1112) * changed nSmaller to nLess for histograms * fixed formatting and docs * changed gqDensities to gqHist --- 
docs/reference/HailExpressionLanguage.md | 8 ++--- .../hail/stats/HistogramCombiner.scala | 31 ++++++++++--------- .../hail/methods/AggregatorSuite.scala | 12 +++---- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/docs/reference/HailExpressionLanguage.md b/docs/reference/HailExpressionLanguage.md index 249ed0574ea..b20a49982e9 100644 --- a/docs/reference/HailExpressionLanguage.md +++ b/docs/reference/HailExpressionLanguage.md @@ -365,7 +365,7 @@ The resulting array is sorted by count in descending order (the most common elem .hist( start, end, bins ) ``` -This aggregator is used to compute density distributions of numeric parameters. The start, end, and bins params are no-scope parameters, which means that while computations like `100 / 4` are acceptable, variable references like `global.nBins` are not. +This aggregator is used to compute frequency distributions of numeric parameters. The start, end, and bins params are no-scope parameters, which means that while computations like `100 / 4` are acceptable, variable references like `global.nBins` are not. The result of a `hist` invocation is a struct: @@ -373,7 +373,7 @@ The result of a `hist` invocation is a struct: Struct { binEdges: Array[Double], binFrequencies: Array[Long], - nSmaller: Long, + nLess: Long, nGreater: Long } ``` @@ -384,7 +384,7 @@ Important properties: - (bins + 1) breakpoints are generated from the range `(start to end by binsize)` - `binEdges` stores an array of bin cutoffs. Each bin is left-inclusive, right-exclusive except the last bin, which includes the maximum value. This means that if there are N total bins, there will be N + 1 elements in binEdges. For the invocation `hist(0, 3, 3)`, `binEdges` would be `[0, 1, 2, 3]` where the bins are `[0, 1)`, `[1, 2)`, `[2, 3]`. - `binFrequencies` stores the number of elements in the aggregable that fall in each bin. It contains one element for each bin. 
- - Elements greater than the max bin or smaller than the min bin will be tracked separately by `nSmaller` and `nGreater` + - Elements greater than the max bin or less than the min bin will be tracked separately by `nLess` and `nGreater` **Examples:** @@ -398,7 +398,7 @@ Or, extend the above to compute a global gq histogram: ``` annotatevariants expr -c 'va.gqHist = gs.map(g => g.gq).hist(0, 100, 20)' -annotateglobal expr -c 'global.gqDensity = variants.map(v => va.gqHist.densities).sum()' +annotateglobal expr -c 'global.gqHist = variants.map(v => va.gqHist.binFrequencies).sum()' ``` ### Collect diff --git a/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala b/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala index 55d926153a0..5d52748adef 100644 --- a/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala +++ b/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala @@ -9,7 +9,7 @@ object HistogramCombiner { def schema: Type = TStruct( "binEdges" -> TArray(TDouble), "binFrequencies" -> TArray(TLong), - "nSmaller" -> TLong, + "nLess" -> TLong, "nGreater" -> TLong) } @@ -18,13 +18,13 @@ class HistogramCombiner(indices: Array[Double]) extends Serializable { val min = indices.head val max = indices(indices.length - 1) - var nSmaller = 0L + var nLess = 0L var nGreater = 0L - val density = Array.fill(indices.length - 1)(0L) + val frequency = Array.fill(indices.length - 1)(0L) def merge(d: Double): HistogramCombiner = { if (d < min) - nSmaller += 1 + nLess += 1 else if (d > max) nGreater += 1 else if (!d.isNaN) { @@ -32,27 +32,28 @@ class HistogramCombiner(indices: Array[Double]) extends Serializable { val ind = if (bs < 0) -bs - 2 else - math.min(bs, density.length - 1) - assert(ind < density.length && ind >= 0, s"""found out of bounds index $ind - | Resulted from trying to merge $d - | Indices are [${indices.mkString(", ")}] - | Binary search index was $bs""".stripMargin) - density(ind) += 1 + math.min(bs, 
frequency.length - 1) + assert(ind < frequency.length && ind >= 0, + s"""found out of bounds index $ind + | Resulted from trying to merge $d + | Indices are [${ indices.mkString(", ") }] + | Binary search index was $bs""".stripMargin) + frequency(ind) += 1 } this } def merge(that: HistogramCombiner): HistogramCombiner = { - require(density.length == that.density.length) + require(frequency.length == that.frequency.length) - nSmaller += that.nSmaller + nLess += that.nLess nGreater += that.nGreater - for (i <- density.indices) - density(i) += that.density(i) + for (i <- frequency.indices) + frequency(i) += that.frequency(i) this } - def toAnnotation: Annotation = Annotation(indices: IndexedSeq[Double], density: IndexedSeq[Long], nSmaller, nGreater) + def toAnnotation: Annotation = Annotation(indices: IndexedSeq[Double], frequency: IndexedSeq[Long], nLess, nGreater) } diff --git a/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala index 674357ce3ca..7f18f5205d5 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala @@ -140,24 +140,24 @@ class AggregatorSuite extends SparkSuite { s2.vds.rdd.collect.foreach { case (v, (va, gs)) => val r = va.asInstanceOf[Row] - val densities = r.getAs[IndexedSeq[Long]](1) + val frequencies = r.getAs[IndexedSeq[Long]](1) val definedGq = gs.flatMap(_.gq) - assert(densities(0) == definedGq.count(gq => gq < 5)) - assert(densities(1) == definedGq.count(gq => gq >= 5 && gq < 10)) - assert(densities.last == definedGq.count(gq => gq >= 95)) + assert(frequencies(0) == definedGq.count(gq => gq < 5)) + assert(frequencies(1) == definedGq.count(gq => gq >= 5 && gq < 10)) + assert(frequencies.last == definedGq.count(gq => gq >= 95)) } val s3 = AnnotateVariantsExpr.run(s, Array("-c", "va = gs.map(g => g.gq).hist(22, 80, 5)")) s3.vds.rdd.collect.foreach { case (v, 
(va, gs)) => val r = va.asInstanceOf[Row] - val nSmaller = r.getAs[Long](2) + val nLess = r.getAs[Long](2) val nGreater = r.getAs[Long](3) val definedGq = gs.flatMap(_.gq) - assert(nSmaller == definedGq.count(_ < 22)) + assert(nLess == definedGq.count(_ < 22)) assert(nGreater == definedGq.count(_ > 80)) } From 969032a22f1494b64cb7afd93459211a682c07bd Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Mon, 21 Nov 2016 14:51:22 -0500 Subject: [PATCH 07/51] Fix python methods, expose samples_to_pandas, integrate logging (#1085) * Fix python methods, expose samples_to_pandas * Better logging support * More better * made exception handling messages private * fix tests * call configure from main * Address comments * Added configure to tests * Fix bgzip test --- python/pyhail/context.py | 67 ++++++++++++++---- python/pyhail/dataset.py | 48 +++++++++---- .../broadinstitute/hail/driver/Command.scala | 2 +- .../org/broadinstitute/hail/driver/Main.scala | 49 +------------ .../broadinstitute/hail/driver/package.scala | 68 +++++++++++++++++++ .../org/broadinstitute/hail/expr/AST.scala | 5 +- .../org/broadinstitute/hail/expr/Parser.scala | 11 +++ .../broadinstitute/hail/utils/package.scala | 39 ++++++++++- .../hail/variant/VariantSampleMatrix.scala | 12 ++++ .../org/broadinstitute/hail/SparkSuite.scala | 3 + .../hail/io/compress/BGzipCodecSuite.scala | 2 +- 11 files changed, 229 insertions(+), 77 deletions(-) diff --git a/python/pyhail/context.py b/python/pyhail/context.py index 2cb6b6a307a..b1d3c03ff77 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -1,16 +1,47 @@ import pyspark from pyhail.dataset import VariantDataset -from pyhail.java import jarray, scala_object +from pyhail.java import jarray, scala_object, scala_package_object +from py4j.protocol import Py4JJavaError + + +class FatalError(Exception): + """:class:`.FatalError` is an error thrown by Hail method failures""" + + def __init__(self, message, java_exception): + self.msg = message + 
self.java_exception = java_exception + super(FatalError) + + def __str__(self): + return self.msg + class HailContext(object): """:class:`.HailContext` is the main entrypoint for PyHail functionality. :param SparkContext sc: The pyspark context. + + :param str log: Log file. + + :param bool quiet: Don't write log file. + + :param bool append: Append to existing log file. + + :param long block_size: Minimum size of file splits in MB. + + :param str parquet_compression: Parquet compression codec. + + :param int branching_factor: Branching factor to use in tree aggregate. + + :param str tmp_dir: Temporary directory for file merging. """ - def __init__(self, sc): + def __init__(self, sc=None, log='hail.log', quiet=False, append=False, + block_size=1, parquet_compression='uncompressed', + branching_factor=50, tmp_dir='/tmp'): + self.sc = sc self.gateway = sc._gateway @@ -23,26 +54,37 @@ def __init__(self, sc): self.sql_context = pyspark.sql.SQLContext(sc, self.jsql_context) - self.jsc.hadoopConfiguration().set( - 'io.compression.codecs', - 'org.apache.hadoop.io.compress.DefaultCodec,org.broadinstitute.hail.io.compress.BGzipCodec,org.apache.hadoop.io.compress.GzipCodec') + scala_package_object(self.jvm.org.broadinstitute.hail.driver).configure( + self.jsc, + log, + quiet, + append, + parquet_compression, + block_size, + branching_factor, + tmp_dir) - logger = sc._jvm.org.apache.log4j - logger.LogManager.getLogger("org"). 
setLevel(logger.Level.ERROR) - logger.LogManager.getLogger("akka").setLevel(logger.Level.ERROR) def _jstate(self, jvds): return self.jvm.org.broadinstitute.hail.driver.State( self.jsc, self.jsql_context, jvds, scala_object(self.jvm.scala.collection.immutable, 'Map').empty()) + def _raise_py4j_exception(self, e): + msg = scala_package_object(self.jvm.org.broadinstitute.hail.utils).getMinimalMessage(e.java_exception) + raise FatalError(msg, e.java_exception) + def run_command(self, vds, pargs): jargs = jarray(self.gateway, self.jvm.java.lang.String, pargs) t = self.jvm.org.broadinstitute.hail.driver.ToplevelCommands.lookup(jargs) cmd = t._1() cmd_args = t._2() jstate = self._jstate(vds.jvds if vds != None else None) - result = cmd.run(jstate, - cmd_args) + + try: + result = cmd.run(jstate, cmd_args) + except Py4JJavaError as e: + self._raise_py4j_exception(e) + return VariantDataset(self, result.vds()) def grep(self, regex, path, max_count=100): @@ -74,7 +116,7 @@ def import_annotations_table(self, path, variant_expr, code=None, npartitions=No # text table options types=None, missing="NA", delimiter="\\t", comment=None, header=True, impute=False): - """Import variants and variant annotaitons from a delimited text file + """Import variants and variant annotations from a delimited text file (text table) as a sites-only VariantDataset. :param path: The files to import. 
@@ -427,7 +469,8 @@ def balding_nichols_model(self, populations, samples, variants, npartitions, :rtype: :class:`.VariantDataset` """ - pargs = ['baldingnichols', '-k', str(populations), '-n', str(samples), '-m', str(variants), '--npartitions', str(npartitions), + pargs = ['baldingnichols', '-k', str(populations), '-n', str(samples), '-m', str(variants), '--npartitions', + str(npartitions), '--root', root] if population_dist: pargs.append('-d') diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py index 984a23f6e6a..8e88e2ac6f5 100644 --- a/python/pyhail/dataset.py +++ b/python/pyhail/dataset.py @@ -1,12 +1,17 @@ from pyhail.java import scala_package_object import pyspark +from py4j.protocol import Py4JJavaError + class VariantDataset(object): def __init__(self, hc, jvds): self.hc = hc self.jvds = jvds + def _raise_py4j_exception(self, e): + self.hc._raise_py4j_exception(e) + def aggregate_intervals(self, input, condition, output): """Aggregate over intervals and export. @@ -324,9 +329,12 @@ def count(self, genotypes=False): """ - return (scala_package_object(self.hc.jvm.org.broadinstitute.hail.driver) - .count(self.jvds, genotypes) - .toJavaMap()) + try: + return (scala_package_object(self.hc.jvm.org.broadinstitute.hail.driver) + .count(self.jvds, genotypes) + .toJavaMap()) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def deduplicate(self): """Remove duplicate variants.""" @@ -505,7 +513,7 @@ def write(self, output, overwrite=False): :param str output: Path of .vds file to write. :param bool overwrite: If True, overwrite any existing .vds file. - + """ pargs = ['write', '-o', output] @@ -710,11 +718,10 @@ def join(self, right): and global annotations from self. 
""" - - return VariantDataset( - self.hc, - self.hc.jvm.org.broadinstitute.hail.driver.Join.join(self.jvds, - right.jvds)) + try: + return VariantDataset(self.hc, self.hc.jvm.org.broadinstitute.hail.driver.Join.join(self.jvds, right.jvds)) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def linreg(self, y, covariates='', root='va.linreg', minac=1, minaf=None): """Test each variant for association using the linear regression @@ -877,8 +884,10 @@ def same(self, other): :rtype: bool """ - - return self.jvds.same(other.jvds, 1e-6) + try: + return self.jvds.same(other.jvds, 1e-6) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def sample_qc(self, branching_factor=None): """Compute per-sample QC metrics. @@ -926,7 +935,7 @@ def split_multi(self, propagate_gq=False): pargs.append('--propagate-gq') return self.hc.run_command(self, pargs) - def tdt(self, fam, root = 'va.tdt'): + def tdt(self, fam, root='va.tdt'): """Find transmitted and untransmitted variants; count per variant and nuclear family. 
@@ -987,5 +996,16 @@ def vep(self, config, block_size=None, root=None, force=False, csq=False): def variants_to_pandas(self): """Convert variants and variant annotations to Pandas dataframe.""" - return pyspark.sql.DataFrame(self.jvds.variantsDF(self.hc.jsql_context), - self.hc.sql_context).toPandas() + try: + return pyspark.sql.DataFrame(self.jvds.variantsDF(self.hc.jsql_context), + self.hc.sql_context).toPandas() + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def samples_to_pandas(self): + """Convert samples and sample annotations to Pandas dataframe.""" + try: + return pyspark.sql.DataFrame(self.jvds.samplesDF(self.hc.jsql_context), + self.hc.sql_context).toPandas() + except Py4JJavaError as e: + self._raise_py4j_exception(e) diff --git a/src/main/scala/org/broadinstitute/hail/driver/Command.scala b/src/main/scala/org/broadinstitute/hail/driver/Command.scala index 436f6420500..e40ef4554d5 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/Command.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/Command.scala @@ -254,7 +254,7 @@ abstract class Command { fatal("this module does not support multiallelic variants.\n Please run `splitmulti' first.") else { if (requiresVDS) - log.info(s"sparkinfo: $name, ${state.vds.nPartitions} partitions, ${state.vds.rdd.getStorageLevel.toReadableString()}") + log.info(s"sparkinfo: $name, ${ state.vds.nPartitions } partitions, ${ state.vds.rdd.getStorageLevel.toReadableString() }") run(state, options) } } diff --git a/src/main/scala/org/broadinstitute/hail/driver/Main.scala b/src/main/scala/org/broadinstitute/hail/driver/Main.scala index a9aa78c025b..6d5d2d9a317 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/Main.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/Main.scala @@ -34,10 +34,7 @@ object SparkManager { conf.setMaster(local) } - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") _sc = new SparkContext(conf) - 
_sc.hadoopConfiguration.set("io.compression.codecs", - "org.apache.hadoop.io.compress.DefaultCodec,org.broadinstitute.hail.io.compress.BGzipCodec,org.apache.hadoop.io.compress.GzipCodec") } _sc @@ -241,29 +238,6 @@ object Main { sys.exit(1) } - val logProps = new Properties() - if (options.logQuiet) { - logProps.put("log4j.rootLogger", "OFF, stderr") - - logProps.put("log4j.appender.stderr", "org.apache.log4j.ConsoleAppender") - logProps.put("log4j.appender.stderr.Target", "System.err") - logProps.put("log4j.appender.stderr.threshold", "OFF") - logProps.put("log4j.appender.stderr.layout", "org.apache.log4j.PatternLayout") - logProps.put("log4j.appender.stderr.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n") - } else { - logProps.put("log4j.rootLogger", "INFO, logfile") - - logProps.put("log4j.appender.logfile", "org.apache.log4j.FileAppender") - logProps.put("log4j.appender.logfile.append", options.logAppend.toString) - logProps.put("log4j.appender.logfile.file", options.logFile) - logProps.put("log4j.appender.logfile.threshold", "INFO") - logProps.put("log4j.appender.logfile.layout", "org.apache.log4j.PatternLayout") - logProps.put("log4j.appender.logfile.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n") - } - - LogManager.resetConfiguration() - PropertyConfigurator.configure(logProps) - if (splitArgs.length == 1) fail(s"hail: fatal: no commands given") @@ -288,23 +262,9 @@ object Main { val sc = SparkManager.createSparkContext("Hail", Option(options.master), "local[*]") - val conf = sc.getConf - conf.set("spark.ui.showConsoleProgress", "false") - val progressBar = ProgressBarBuilder.build(sc) - - conf.set("spark.sql.parquet.compression.codec", options.parquetCompression) - - sc.hadoopConfiguration.setLong("mapreduce.input.fileinputformat.split.minsize", options.blockSize * 1024L * 1024L) - - /* `DataFrame.write` writes one file per partition. 
Without this, read will split files larger than the default - * parquet block size into multiple partitions. This causes `OrderedRDD` to fail since the per-partition range - * no longer line up with the RDD partitions. - * - * For reasons we don't understand, the DataFrame code uses `SparkHadoopUtil.get.conf` instead of the Hadoop - * configuration in the SparkContext. Set both for consistency. - */ - SparkHadoopUtil.get.conf.setLong("parquet.block.size", 1099511627776L) - sc.hadoopConfiguration.setLong("parquet.block.size", 1099511627776L) + configure(sc, logFile = options.logFile, quiet = options.logQuiet, append = options.logAppend, + parquetCompression = options.parquetCompression, blockSize = options.blockSize, + branchingFactor = options.branchingFactor, tmpDir = options.tmpDir) val sqlContext = SparkManager.createSQLContext() @@ -313,12 +273,9 @@ object Main { sc.addJar(jar) HailConfiguration.installDir = new File(jar).getParent + "/.." - HailConfiguration.tmpDir = options.tmpDir - HailConfiguration.branchingFactor = options.branchingFactor runCommands(sc, sqlContext, invocations) sc.stop() - progressBar.stop() } } diff --git a/src/main/scala/org/broadinstitute/hail/driver/package.scala b/src/main/scala/org/broadinstitute/hail/driver/package.scala index d564e84199b..a2dced1a88c 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/package.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/package.scala @@ -1,8 +1,15 @@ package org.broadinstitute.hail +import java.io.File import java.util +import java.util.Properties + +import org.apache.log4j.{Level, LogManager, PropertyConfigurator} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{ProgressBarBuilder, SparkContext} import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.variant.VariantDataset + import scala.collection.JavaConverters._ package object driver { @@ -39,4 +46,65 @@ package object driver { CountResult(vds.nSamples, nVariants, nCalled) } + + def 
configure(sc: SparkContext, logFile: String, quiet: Boolean, append: Boolean, + parquetCompression: String, blockSize: Long, branchingFactor: Int, tmpDir: String) { + require(blockSize > 0) + require(branchingFactor > 0) + + val logProps = new Properties() + if (quiet) { + logProps.put("log4j.rootLogger", "OFF, stderr") + logProps.put("log4j.appender.stderr", "org.apache.log4j.ConsoleAppender") + logProps.put("log4j.appender.stderr.Target", "System.err") + logProps.put("log4j.appender.stderr.threshold", "OFF") + logProps.put("log4j.appender.stderr.layout", "org.apache.log4j.PatternLayout") + logProps.put("log4j.appender.stderr.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n") + } else { + logProps.put("log4j.rootLogger", "INFO, logfile") + logProps.put("log4j.appender.logfile", "org.apache.log4j.FileAppender") + logProps.put("log4j.appender.logfile.append", append.toString) + logProps.put("log4j.appender.logfile.file", logFile) + logProps.put("log4j.appender.logfile.threshold", "INFO") + logProps.put("log4j.appender.logfile.layout", "org.apache.log4j.PatternLayout") + logProps.put("log4j.appender.logfile.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n") + } + + LogManager.resetConfiguration() + PropertyConfigurator.configure(logProps) + + val conf = sc.getConf + + conf.set("spark.ui.showConsoleProgress", "false") + val progressBar = ProgressBarBuilder.build(sc) + + sc.hadoopConfiguration.set( + "io.compression.codecs", + "org.apache.hadoop.io.compress.DefaultCodec," + + "org.broadinstitute.hail.io.compress.BGzipCodec," + + "org.apache.hadoop.io.compress.GzipCodec") + + conf.set("spark.sql.parquet.compression.codec", parquetCompression) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + + sc.hadoopConfiguration.setLong("mapreduce.input.fileinputformat.split.minsize", blockSize * 1024L * 1024L) + + /* `DataFrame.write` writes one file per partition. 
Without this, read will split files larger than the default + * parquet block size into multiple partitions. This causes `OrderedRDD` to fail since the per-partition range + * no longer line up with the RDD partitions. + * + * For reasons we don't understand, the DataFrame code uses `SparkHadoopUtil.get.conf` instead of the Hadoop + * configuration in the SparkContext. Set both for consistency. + */ + SparkHadoopUtil.get.conf.setLong("parquet.block.size", 1099511627776L) + sc.hadoopConfiguration.setLong("parquet.block.size", 1099511627776L) + + + val jar = getClass.getProtectionDomain.getCodeSource.getLocation.toURI.getPath + sc.addJar(jar) + + HailConfiguration.installDir = new File(jar).getParent + "/.." + HailConfiguration.tmpDir = tmpDir + HailConfiguration.branchingFactor = branchingFactor + } } diff --git a/src/main/scala/org/broadinstitute/hail/expr/AST.scala b/src/main/scala/org/broadinstitute/hail/expr/AST.scala index 9967d896468..08b199ff3bc 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/AST.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/AST.scala @@ -1649,9 +1649,10 @@ case class IndexOp(posn: Position, f: AST, idx: AST) extends AST(posn, Array(f, } catch { case e: java.lang.IndexOutOfBoundsException => ParserUtils.error(localPos, - s"""Tried to access index [$i] on array ${ JsonMethods.compact(localT.toJSON(a)) } of length ${ a.length } + s"""Invalid array index: tried to access index [$i] on array `@1' of length ${ a.length } | Hint: All arrays in Hail are zero-indexed (`array[0]' is the first element) - | Hint: For accessing `A'-numbered info fields in split variants, `va.info.field[va.aIndex - 1]' is correct""".stripMargin) + | Hint: For accessing `A'-numbered info fields in split variants, `va.info.field[va.aIndex - 1]' is correct""".stripMargin, + JsonMethods.compact(localT.toJSON(a))) case e: Throwable => throw e }) diff --git a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala 
b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala index 6c368174695..bd1d3186fb9 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala @@ -20,6 +20,17 @@ object ParserUtils { lineContents.take(pos.column - 1).map { c => if (c == '\t') c else ' ' } }^""".stripMargin) } + + def error(pos: Position, msg: String, tr: Truncatable): Nothing = { + val lineContents = pos.longString.split("\n").head + val prefix = s":${ pos.line }:" + fatal( + s"""$msg + |$prefix$lineContents + |${ " " * prefix.length }${ + lineContents.take(pos.column - 1).map { c => if (c == '\t') c else ' ' } + }^""".stripMargin, tr) + } } object Parser extends JavaTokenParsers { diff --git a/src/main/scala/org/broadinstitute/hail/utils/package.scala b/src/main/scala/org/broadinstitute/hail/utils/package.scala index 3c476e45829..4c41172fe5c 100644 --- a/src/main/scala/org/broadinstitute/hail/utils/package.scala +++ b/src/main/scala/org/broadinstitute/hail/utils/package.scala @@ -15,7 +15,44 @@ package object utils extends Logging with richUtils.Implicits with utils.NumericImplicits { - class FatalException(msg: String, logMsg: Option[String] = None) extends RuntimeException(msg) + class FatalException(val msg: String, val logMsg: Option[String] = None) extends RuntimeException(msg) + + def digForFatal(e: Throwable): Option[String] = { + val r = e match { + case f: FatalException => + println(s"found fatal $f") + Some(s"${ e.getMessage }") + case _ => + Option(e.getCause).flatMap(c => digForFatal(c)) + } + r + } + + def deepestMessage(e: Throwable): String = { + var iterE = e + while (iterE.getCause != null) + iterE = iterE.getCause + + s"${ e.getClass.getSimpleName }: ${ e.getLocalizedMessage }" + } + + def expandException(e: Throwable): String = { + val msg = e match { + case f: FatalException => f.logMsg.getOrElse(f.msg) + case _ => e.getLocalizedMessage + } + s"${ e.getClass.getName }: $msg\n\tat ${ 
e.getStackTrace.mkString("\n\tat ") }${ + Option(e.getCause).map(exception => expandException(exception)).getOrElse("") + }" + } + + def getMinimalMessage(e: Exception): String = { + val fatalOption = digForFatal(e) + val prefix = if (fatalOption.isDefined) "fatal" else "caught exception" + val msg = fatalOption.getOrElse(deepestMessage(e)) + log.error(s"hail: $prefix: $msg\nFrom ${ expandException(e) }") + msg + } trait Truncatable { def truncate: String diff --git a/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala b/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala index 97d1acc8591..3e7ada75e2b 100644 --- a/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala +++ b/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala @@ -923,6 +923,18 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata, sqlContext.createDataFrame(rowRDD, schema) } + def samplesDF(sqlContext: SQLContext): DataFrame = { + val rowRDD = sparkContext.parallelize( + sampleIdsAndAnnotations.map { case (s, sa) => + Row(s, SparkAnnotationImpex.exportAnnotation(sa, saSignature)) + }) + val schema = StructType(Array( + StructField("sample", StringType, nullable = false), + StructField("sa", saSignature.schema, nullable = true) + )) + + sqlContext.createDataFrame(rowRDD, schema) + } } // FIXME AnyVal Scala 2.11 diff --git a/src/test/scala/org/broadinstitute/hail/SparkSuite.scala b/src/test/scala/org/broadinstitute/hail/SparkSuite.scala index 94bfd70cfdf..bbe2cddac4a 100644 --- a/src/test/scala/org/broadinstitute/hail/SparkSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/SparkSuite.scala @@ -34,6 +34,9 @@ class SparkSuite extends TestNGSuite { val jar = getClass.getProtectionDomain.getCodeSource.getLocation.toURI.getPath HailConfiguration.installDir = new File(jar).getParent + "/.." 
HailConfiguration.tmpDir = "/tmp" + + driver.configure(sc, logFile = "hail.log", quiet = true, append = false, + parquetCompression = "uncompressed", blockSize = 1L, branchingFactor = 50, tmpDir = "/tmp") } @AfterClass(alwaysRun = true) diff --git a/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala b/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala index 5e797a4ffc7..86a848508c9 100644 --- a/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala @@ -44,7 +44,7 @@ class TestFileInputFormat extends hd.mapreduce.lib.input.TextInputFormat { class BGzipCodecSuite extends SparkSuite { @Test def test() { - sc.hadoopConfiguration.set("io.compression.codecs", "org.apache.hadoop.io.compress.DefaultCodec,org.broadinstitute.hail.io.compress.BGzipCodec,org.apache.hadoop.io.compress.GzipCodec") + sc.hadoopConfiguration.setLong("mapreduce.input.fileinputformat.split.minsize", 1L) val uncompPath = "src/test/resources/sample.vcf" From 36de318e6a1939d52ed7e49f2a4f5a7abaf6de43 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 18 Oct 2016 23:13:02 -0400 Subject: [PATCH 08/51] experiment with parser --- .../hail/driver/AddKeyTable.scala | 117 ++++++++++++++++++ .../hail/driver/AggregateIntervals.scala | 2 - .../hail/methods/AddKeyTableSuite.scala | 17 +++ 3 files changed, 134 insertions(+), 2 deletions(-) create mode 100644 src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala create mode 100644 src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala diff --git a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala new file mode 100644 index 00000000000..43d91df0554 --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala @@ -0,0 +1,117 @@ +package org.broadinstitute.hail.driver + +import 
org.broadinstitute.hail.annotations.Annotation +import org.broadinstitute.hail.expr._ +import org.broadinstitute.hail.methods.Aggregators +import org.broadinstitute.hail.utils._ +import org.broadinstitute.hail.variant.Variant +import org.kohsuke.args4j.{Option => Args4jOption} + +object AddKeyTable extends Command { + class Options extends BaseOptions with TextTableOptions { + @Args4jOption(required = true, name = "-k", aliases = Array("--key-cond"), + usage = "Struct with expr defining keys") + var keyCond: String = _ + + @Args4jOption(required = false, name = "-c", aliases = Array("--cond"), + usage = "Aggregation condition") + var cond: String = _ + + @Args4jOption(required = false, name = "-o", aliases = Array("--output"), + usage = "output file") + var outFile: String = _ + } + + def newOptions = new Options + + def name = "addkeytable" + + def description = "Creates new key table with key determined by an expression" + + def supportsMultiallelic = true + + def requiresVDS = true + + override def hidden = true + + def run(state: State, options: Options): State = { + + val vds = state.vds + val splat = false + +// val cond = options.cond + val keyCond = options.keyCond + + val aggregationEC = EvalContext(Map( + "v" -> (0, TVariant), + "va" -> (1, vds.vaSignature), + "s" -> (2, TSample), + "sa" -> (3, vds.saSignature), + "global" -> (4, vds.globalSignature))) + + val symTab = Map( + "v" -> (0, TVariant), + "va" -> (1, vds.vaSignature), + "s" -> (2, TSample), + "sa" -> (3, vds.saSignature), + "global" -> (4, vds.globalSignature), + "gs" -> (-1, BaseAggregable(aggregationEC, TGenotype))) + + val ec = EvalContext(symTab) + val a = ec.a + + ec.set(4, vds.globalAnnotation) + aggregationEC.set(4, vds.globalAnnotation) + + val (header, parseTypes, f) = Parser.parseNamedArgs(keyCond, ec) + + if (header.isEmpty) + fatal("this module requires one or more named expr arguments") + + + + println(header.mkString("\n")) + println(parseTypes.mkString("\n")) + 
println(f().mkString("\n")) + + val foo = vds.rdd.map{case (v, (va, gs)) => + ec.set(0, v) + ec.set(1, va) + val (header, parseTypes, f) = Parser.parseNamedArgs(keyCond, ec) + f() + } + +// println(foo.collect().map(_.mkString(",")).mkString("\n")) + + + +// val (zVals, seqOp, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) +// +// val zvf = () => zVals.indices.map(zVals).toArray +// +// val results = vds.variantsAndAnnotations.flatMap { case (v, va) => i => (i, (v, va)) } +// } +// .aggregateByKey(zvf())(seqOp, combOp) +// .collectAsMap() + +// println(parseTypes.mkString("\n")) + + +// val groups = vds.rdd.flatMap { case (v, (va, gs)) => +// val key = qGroupKey(va) +// val genotypes = gs.map { g => g.nNonRefAlleles.getOrElse(9) } //SKAT-O null value is +9 +// key match { +// case Some(x) => +// if (splat) +// for (k <- x.asInstanceOf[Iterable[_]]) yield (k, genotypes) +// else +// Some((x, genotypes)) +// case None => None +// } +// }.groupByKey() + + + + state + } +} diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala index af89ccb2842..dee21c588f1 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala @@ -67,8 +67,6 @@ object AggregateIntervals extends Command { val zvf = () => zVals.indices.map(zVals).toArray - val variantAggregations = Aggregators.buildVariantAggregations(vds, aggregationEC) - val iList = IntervalListAnnotator.read(options.input, sc.hadoopConfiguration) val iListBc = sc.broadcast(iList) diff --git a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala new file mode 100644 index 00000000000..61718607dea --- /dev/null +++ b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala @@ -0,0 +1,17 @@ +package 
org.broadinstitute.hail.methods + +import org.broadinstitute.hail.SparkSuite +import org.testng.annotations.Test +import org.broadinstitute.hail.driver._ + +class AddKeyTableSuite extends SparkSuite { + @Test def test1() { + var s = State(sc, sqlContext, null) + s = ImportVCF.run(s, Array("-i", "src/test/resources/sample.vcf")) + s = AnnotateVariantsExpr.run(s, Array("-c", "va.foo = gs.filter(g => g.isHet).count()")) + s = AnnotateSamplesExpr.run(s, Array("-c", "sa.foo = gs.filter(g => g.isHet).count()")) + s = AnnotateGlobalExpr.run(s, Array("-c", "global.foo = variants.count()")) + s = PrintSchema.run(s, Array.empty[String]) + s = AddKeyTable.run(s, Array("-k", "foo = va.foo, foo1 = global.foo, foo2 = sa.foo")) + } +} From 09ae742568dbb283504d88f2d49866690c9eadea Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Wed, 19 Oct 2016 16:34:18 -0400 Subject: [PATCH 09/51] try to refactor ExportAggregate --- .../hail/driver/AddKeyTable.scala | 22 ++- .../hail/driver/ExportAggregate.scala | 166 ++++++++++++++++++ .../hail/methods/AddKeyTableSuite.scala | 4 +- 3 files changed, 190 insertions(+), 2 deletions(-) create mode 100644 src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala diff --git a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala index 43d91df0554..53293ef6f6f 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala @@ -4,7 +4,7 @@ import org.broadinstitute.hail.annotations.Annotation import org.broadinstitute.hail.expr._ import org.broadinstitute.hail.methods.Aggregators import org.broadinstitute.hail.utils._ -import org.broadinstitute.hail.variant.Variant +import org.broadinstitute.hail.variant._ import org.kohsuke.args4j.{Option => Args4jOption} object AddKeyTable extends Command { @@ -68,6 +68,26 @@ object AddKeyTable extends Command { if (header.isEmpty) fatal("this 
module requires one or more named expr arguments") +// def buildKeyAggregations(vds: VariantDataset, ec: EvalContext) = { +// val aggregators = ec.aggregationFunctions.toArray +// val aggregatorA = ec.a +// +// if (aggregators.isEmpty) +// None +// else { +// +// val localSamplesBc = vds.sampleIdsBc +// val localAnnotationsBc = vds.sampleAnnotationsBc +// +// val nAggregations = aggregators.length +// val nSamples = vds.nSamples +// val depth = HailConfiguration.treeAggDepth(vds.nPartitions) +// +// val baseArray = MultiArray2.fill[Aggregator](nSamples, nAggregations)(null) +// for (i <- 0 until nSamples; j <- 0 until nAggregations) { +// baseArray.update(i, j, aggregators(j).copy()) +// } +// } println(header.mkString("\n")) diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala new file mode 100644 index 00000000000..d0a30d33b32 --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala @@ -0,0 +1,166 @@ +package org.broadinstitute.hail.driver + +import org.broadinstitute.hail.utils._ +import org.broadinstitute.hail.annotations._ +import org.broadinstitute.hail.expr._ +import org.broadinstitute.hail.utils.{MultiArray2} +import org.broadinstitute.hail.variant._ +import org.kohsuke.args4j.{Option => Args4jOption} + +object ExportAggregate extends Command { + + class Options extends BaseOptions { + + @Args4jOption(required = false, name = "-o", aliases = Array("--output"), + usage = "path of output file") + var output: String = _ + + @Args4jOption(required = true, name = "-k", aliases = Array("--key-condition"), + usage = "named expression for which keys to aggregate on (variant and sample)") + var keyCondition: String = _ + + @Args4jOption(required = true, name = "-a", usage = "named expression for item to compute") + var aggCondition: String = _ + } + + def newOptions = new Options + + def name = "exportaggregate" + + def description = 
"Aggregate and export samples information grouped by a given variant annnotation" + + def supportsMultiallelic = true + + def requiresVDS = true + + def run(state: State, options: Options): State = { + val vds = state.vds + val sc = vds.sparkContext + val keyCond = options.keyCondition + val aggCond = options.aggCondition + val output = options.output + val vas = vds.vaSignature + val sas = vds.saSignature + val localSamplesBc = vds.sampleIdsBc + val localAnnotationsBc = vds.sampleAnnotationsBc + + val aggregationEC = EvalContext(Map( + "v" -> (0, TVariant), + "va" -> (1, vds.vaSignature), + "s" -> (2, TSample), + "sa" -> (3, vds.saSignature), + "global" -> (4, vds.globalSignature))) + + val ec = EvalContext(Map( + "v" -> (0, TVariant), + "va" -> (1, vds.vaSignature), + "s" -> (2, TSample), + "sa" -> (3, vds.saSignature), + "global" -> (4, vds.globalSignature), + "gs" -> (-1, BaseAggregable(aggregationEC, TGenotype)))) + + aggregationEC.set(4, vds.globalAnnotation) + ec.set(4, vds.globalAnnotation) + + val (aggNames, aggTypes, aggF) = Parser.parseNamedArgs(aggCond, ec) + + if (aggNames.isEmpty) + fatal("need at least 1 aggregation argument") + + val aggregators = aggregationEC.aggregationFunctions.toArray + val aggregatorA = aggregationEC.a + val nAggregations = aggregators.length + + val keyParseResult = Parser.parseNamedArgs(keyCond, ec) + + val sampleGroups = vds.sampleIdsAndAnnotations.map { case (s, sa) => + ec.set(2, s) + ec.set(3, sa) + + keyParseResult._3.apply().toIndexedSeq + } + + val distinctSampleGroupMap = sampleGroups.distinct.zipWithIndex.toMap + val siToGroupIndex = sampleGroups.map(distinctSampleGroupMap) + val nSampleGroups = distinctSampleGroupMap.size + + def zero() = { + val baseArray = MultiArray2.fill[Aggregator](nSampleGroups, nAggregations)(null) + for (i <- 0 until nSampleGroups; j <- 0 until nAggregations) { + baseArray.update(i, j, aggregators(j).copy()) + } + baseArray + } + + val mapOp : (Variant, Annotation) => IndexedSeq[Any] = 
{case (v, va) => + ec.set(0, v) + ec.set(1, va) + keyParseResult._3.apply().toIndexedSeq + } + + val seqOp : (MultiArray2[Aggregator], (Variant, (Annotation, Iterable[Genotype]))) => MultiArray2[Aggregator] = { + case (arr, (v, (va, gs))) => + aggregatorA(0) = v + aggregatorA(1) = va + for ((g, i) <- gs.zipWithIndex) + for (j <- 0 until nAggregations) { + aggregatorA(2) = localSamplesBc.value(i) + aggregatorA(3) = localAnnotationsBc.value(i) + val sampleGroup = siToGroupIndex(i) + arr(sampleGroup, j).seqOp(g) + } + + arr + } + + val combOp : (MultiArray2[Aggregator], MultiArray2[Aggregator]) => MultiArray2[Aggregator] = { + case (arr1, arr2) => + for ((i, j) <- arr1.indices) { + val a1 = arr1(i, j) + a1.combOp(arr2(i, j).asInstanceOf[a1.type]) + } + arr1 + } + + val res = vds.rdd.map { case (v, (va, gs)) => (mapOp(v, va), (v, (va, gs))) } + .aggregateByKey(zero())(seqOp, combOp) + +// +// def getLine(sampleGroupIndex: Integer, values: MultiArray2[Any], sb:StringBuilder) : String = { +// for (j <- 0 until nAggregations) { +// aggregatorA(aggregators(j).idx) = values(sampleGroupIndex, j) +// } +// +// aggregationParseResult.foreachBetween { case (t, f) => +// sb.append(f().map(TableAnnotationImpex.exportAnnotation(_, t)).getOrElse("NA")) +// } { sb += '\t' } +// sb.result() +// } +// +// res.map({ +// case (variantGroup, values) => +// +// val sb = new StringBuilder() +// val lines = for ((sampleGroup, i) <- distinctSampleGroupMap.keys.zipWithIndex) yield { +// sb.clear() +// sb.append(sampleGroup.map(_.getOrElse("NA").toString).mkString("\t") + "\t") +// getLine(i,values,sb) +// } +// lines.map(variantGroup.map(_.getOrElse("NA").toString).mkString("\t") + "\t" + _).mkString("\n") +// }) +// .writeTable(options.output, +// header = Some(variantGroupParseResult.map(_._1).mkString("\t") + "\t" + +// sampleGroupsParseResult.map(_._1).mkString("\t") + "\t" + +// aggregationHeader.mkString("\t"))) +// +// val variantGroupEC = EvalContext( Map( +// "v" -> (0, TVariant), 
+// "va" -> (1, vds.vaSignature), +// "global" -> (2, vds.globalSignature))) +// variantGroupEC.set(2,vds.globalSignature) +// +// val variantGroupParseResult = Parser.parseNamedArgs(options.byV ,variantGroupEC) +// + state + } +} diff --git a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala index 61718607dea..3faf044e667 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala @@ -12,6 +12,8 @@ class AddKeyTableSuite extends SparkSuite { s = AnnotateSamplesExpr.run(s, Array("-c", "sa.foo = gs.filter(g => g.isHet).count()")) s = AnnotateGlobalExpr.run(s, Array("-c", "global.foo = variants.count()")) s = PrintSchema.run(s, Array.empty[String]) - s = AddKeyTable.run(s, Array("-k", "foo = va.foo, foo1 = global.foo, foo2 = sa.foo")) + s = ExportAggregate.run(s, Array("-k", "foo = va.foo, foo1 = global.foo, foo2 = sa.foo, foo3 = 5", "-a", "nHet = gs.filter(g => g.isHet).count()")) + s = Count.run(s, Array.empty[String]) +// s = AddKeyTable.run(s, Array("-k", "foo = va.foo, foo1 = global.foo, foo2 = sa.foo")) } } From 99d2987e1b46229846f8b336cfb219ef0a6167b7 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Thu, 20 Oct 2016 14:52:45 -0400 Subject: [PATCH 10/51] test --- .../hail/driver/AddKeyTable.scala | 5 + .../hail/driver/ExportAggregate.scala | 2 + .../hail/methods/Aggregators.scala | 146 ++++++++++++++++++ 3 files changed, 153 insertions(+) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala index 53293ef6f6f..94a5d56e8f3 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala @@ -68,6 +68,11 @@ object AddKeyTable extends Command { if (header.isEmpty) fatal("this module requires one or more 
named expr arguments") + val aggregateOption = Aggregators.buildVariantAggregationsByGroup(vds, aggregationEC, f) + + vds.rdd.map{ case (v, (va, gs)) => + aggregateOption.foreach(f => f(v, va, gs)) + } // def buildKeyAggregations(vds: VariantDataset, ec: EvalContext) = { // val aggregators = ec.aggregationFunctions.toArray // val aggregatorA = ec.a diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala index d0a30d33b32..3449a8f7758 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala @@ -125,6 +125,8 @@ object ExportAggregate extends Command { val res = vds.rdd.map { case (v, (va, gs)) => (mapOp(v, va), (v, (va, gs))) } .aggregateByKey(zero())(seqOp, combOp) + res.map{case (key, agg) => key.mkString(",")}.collect().foreach(println(_)) + // // def getLine(sampleGroupIndex: Integer, values: MultiArray2[Any], sb:StringBuilder) : String = { // for (j <- 0 until nAggregations) { diff --git a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala index 00be90c835f..fd98e1439be 100644 --- a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala +++ b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala @@ -43,6 +43,49 @@ object Aggregators { } else None } + def buildVariantAggregationsByGroup(vds: VariantDataset, ec: EvalContext, keyFn: () => Array[_]): Option[(Variant, Annotation, Iterable[Genotype]) => Unit] = { + val aggregators = ec.aggregationFunctions.toArray + val aggregatorA = ec.a + val nAggregators = aggregators.length + + if (aggregators.nonEmpty) { + + val localSamplesBc = vds.sampleIdsBc + val localAnnotationsBc = vds.sampleAnnotationsBc + + val sampleGroups = vds.sampleIdsAndAnnotations.map { case (s, sa) => + ec.set(2, s) + ec.set(3, sa) + + keyFn() + } + + val 
distinctSampleGroupMap = sampleGroups.distinct.zipWithIndex.toMap + val siToGroupIndex = sampleGroups.map(distinctSampleGroupMap) + val nGroups = distinctSampleGroupMap.size + + val f = (v: Variant, va: Annotation, gs: Iterable[Genotype]) => { + val baseArray = MultiArray2.fill[Aggregator](nGroups, nAggregators)(null) + + aggregatorA(0) = v + aggregatorA(1) = va + + gs.zip(localSamplesBc.value).zip(localAnnotationsBc.value).zip(siToGroupIndex) +// (gs, localSamplesBc.value, localAnnotationsBc.value).zipped + .foreach { + case ((((g, s), sa), gi)) => + aggregatorA(2) = s + aggregatorA(3) = sa + for (j <- 0 until baseArray.n2) + baseArray(gi, j).seqOp(g) + } + + baseArray.foreach { agg => aggregatorA(agg.idx) = agg.result } + } + Some(f) + } else None + } + def buildSampleAggregations(vds: VariantDataset, ec: EvalContext): Option[(String) => Unit] = { val aggregators = ec.aggregationFunctions.toArray val aggregatorA = ec.a @@ -97,6 +140,109 @@ object Aggregators { } } + def buildGroupedSampleAggregations(vds: VariantDataset, ec: EvalContext, keyFn: () => Array[_]): Option[(Array[_]) => Unit] = { + val aggregators = ec.aggregationFunctions.toArray + val aggregatorA = ec.a + + if (aggregators.isEmpty) + None + else { + + val localSamplesBc = vds.sampleIdsBc + val localAnnotationsBc = vds.sampleAnnotationsBc + + val sampleGroups = vds.sampleIdsAndAnnotations.map { case (s, sa) => + ec.set(2, s) + ec.set(3, sa) + + keyFn() + } + + val distinctSampleGroupMap = sampleGroups.distinct.zipWithIndex.toMap + val siToGroupIndex = sampleGroups.map(distinctSampleGroupMap) + val nGroups = distinctSampleGroupMap.size + + val nAggregations = aggregators.length + val depth = HailConfiguration.treeAggDepth(vds.nPartitions) + + val baseArray = MultiArray2.fill[Aggregator](nGroups, nAggregations)(null) + for (i <- 0 until nGroups; j <- 0 until nAggregations) { + baseArray.update(i, j, aggregators(j).copy()) + } + + val result = vds.rdd.treeAggregate(baseArray)({ case (arr, (v, (va, gs))) 
=> + aggregatorA(0) = v + aggregatorA(1) = va + var i = 0 + gs.foreach { g => + aggregatorA(2) = localSamplesBc.value(i) + aggregatorA(3) = localAnnotationsBc.value(i) + + val gi = siToGroupIndex(i) + var ai = 0 + while (ai < nAggregations) { + arr(gi, ai).seqOp(g) + ai += 1 + } + i += 1 + } + arr + }, { case (arr1, arr2) => + for (i <- 0 until nGroups; j <- 0 until nAggregations) { + val a1 = arr1(i, j) + a1.combOp(arr2(i, j).asInstanceOf[a1.type]) + } + arr1 + }, depth = depth) + + Some((s: Array[_]) => { + val i = distinctSampleGroupMap(s) + for (j <- 0 until nAggregations) { + aggregatorA(aggregators(j).idx) = result(i, j).result + } + }) + } + } +// def makeGroupedFunctions(ec: EvalContext, keyFn: () => Array[String]): (MultiArray2[Aggregator], (MultiArray2[Aggregator], (Any, Any)) => MultiArray2[Aggregator], +// (MultiArray2[Aggregator], MultiArray2[Aggregator]) => MultiArray2[Aggregator], (MultiArray2[Aggregator]) => Unit) = { +// +// val aggregators = ec.aggregationFunctions.toArray +// val nAggregators = aggregators.length +// +// val nGroups = ??? 
+// +// val arr = ec.a +// +// val baseArray = MultiArray2.fill[Aggregator](nGroups, nAggregators)(null) +// +// val zero = { +// for ((i, j) <- baseArray.indices) +// baseArray(i, j) = aggregators(j).copy() +// baseArray +// } +// +// val seqOp = (array: MultiArray2[Aggregator], b: (Any, Any)) => { +// val (aggT, annotation) = b +// ec.set(0, annotation) +// for ((i, j) <- array.indices) { +// array(i, j).seqOp(aggT) +// } +// array +// } +// +// val combOp = (arr1: MultiArray2[Aggregator], arr2: MultiArray2[Aggregator]) => { +// for ((i, j) <- arr1.indices) { +// val a1 = arr1(i, j) +// a1.combOp(arr2(i, j).asInstanceOf[a1.type]) +// } +// arr1 +// } +// +// val resultOp = (array: MultiArray2[Aggregator]) => array.foreach { res => arr(res.idx) = res.result } +// +// (zero, seqOp, combOp, resultOp) +// } + def makeFunctions(ec: EvalContext): (Array[Aggregator], (Array[Aggregator], (Any, Any)) => Array[Aggregator], (Array[Aggregator], Array[Aggregator]) => Array[Aggregator], (Array[Aggregator]) => Unit) = { From e567c499a6090430a33b6fdf5a3c69aac11ef772 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Fri, 21 Oct 2016 12:32:04 -0400 Subject: [PATCH 11/51] trying out things --- .../hail/driver/ExportAggregate.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala index 3449a8f7758..cb49c6944c3 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala @@ -80,6 +80,14 @@ object ExportAggregate extends Command { keyParseResult._3.apply().toIndexedSeq } + // val variantGroupEC = EvalContext( Map( + // "v" -> (0, TVariant), + // "va" -> (1, vds.vaSignature), + // "global" -> (2, vds.globalSignature))) + // variantGroupEC.set(2,vds.globalSignature) + // + // val variantGroupParseResult = 
Parser.parseNamedArgs(options.byV ,variantGroupEC) + val distinctSampleGroupMap = sampleGroups.distinct.zipWithIndex.toMap val siToGroupIndex = sampleGroups.map(distinctSampleGroupMap) val nSampleGroups = distinctSampleGroupMap.size @@ -125,7 +133,7 @@ object ExportAggregate extends Command { val res = vds.rdd.map { case (v, (va, gs)) => (mapOp(v, va), (v, (va, gs))) } .aggregateByKey(zero())(seqOp, combOp) - res.map{case (key, agg) => key.mkString(",")}.collect().foreach(println(_)) +// res.map{case (key, agg) => key.mkString(",")}.collect().foreach(println(_)) // // def getLine(sampleGroupIndex: Integer, values: MultiArray2[Any], sb:StringBuilder) : String = { @@ -155,13 +163,7 @@ object ExportAggregate extends Command { // sampleGroupsParseResult.map(_._1).mkString("\t") + "\t" + // aggregationHeader.mkString("\t"))) // -// val variantGroupEC = EvalContext( Map( -// "v" -> (0, TVariant), -// "va" -> (1, vds.vaSignature), -// "global" -> (2, vds.globalSignature))) -// variantGroupEC.set(2,vds.globalSignature) -// -// val variantGroupParseResult = Parser.parseNamedArgs(options.byV ,variantGroupEC) + // state } From 78b67cbcbb9561223e100d2fd3a039dd70d93a4f Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Fri, 21 Oct 2016 14:54:29 -0400 Subject: [PATCH 12/51] working wout tests; needs improvements --- .../hail/driver/AddKeyTable.scala | 109 ++++------- .../hail/methods/Aggregators.scala | 170 ++++-------------- .../hail/methods/AddKeyTableSuite.scala | 5 +- 3 files changed, 69 insertions(+), 215 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala index 94a5d56e8f3..520ed1e8982 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala @@ -1,10 +1,8 @@ package org.broadinstitute.hail.driver -import org.broadinstitute.hail.annotations.Annotation import 
org.broadinstitute.hail.expr._ import org.broadinstitute.hail.methods.Aggregators import org.broadinstitute.hail.utils._ -import org.broadinstitute.hail.variant._ import org.kohsuke.args4j.{Option => Args4jOption} object AddKeyTable extends Command { @@ -13,13 +11,13 @@ object AddKeyTable extends Command { usage = "Struct with expr defining keys") var keyCond: String = _ - @Args4jOption(required = false, name = "-c", aliases = Array("--cond"), + @Args4jOption(required = true, name = "-a", aliases = Array("--agg-cond"), usage = "Aggregation condition") - var cond: String = _ + var aggCond: String = _ - @Args4jOption(required = false, name = "-o", aliases = Array("--output"), + @Args4jOption(required = true, name = "-o", aliases = Array("--output"), usage = "output file") - var outFile: String = _ + var output: String = _ } def newOptions = new Options @@ -37,9 +35,9 @@ object AddKeyTable extends Command { def run(state: State, options: Options): State = { val vds = state.vds - val splat = false + val sc = state.sc -// val cond = options.cond + val aggCond = options.aggCond val keyCond = options.keyCond val aggregationEC = EvalContext(Map( @@ -63,79 +61,44 @@ object AddKeyTable extends Command { ec.set(4, vds.globalAnnotation) aggregationEC.set(4, vds.globalAnnotation) - val (header, parseTypes, f) = Parser.parseNamedArgs(keyCond, ec) + val (keyNames, keyParseTypes, keyF) = Parser.parseNamedArgs(keyCond, ec) + val (aggNames, aggParseTypes, aggF) = Parser.parseNamedArgs(aggCond, ec) - if (header.isEmpty) - fatal("this module requires one or more named expr arguments") + if (keyNames.isEmpty) + fatal("this module requires one or more named expr arguments as keys") + if (aggNames.isEmpty) + fatal("this module requires one or more named expr arguments to aggregate by key") - val aggregateOption = Aggregators.buildVariantAggregationsByGroup(vds, aggregationEC, f) - - vds.rdd.map{ case (v, (va, gs)) => - aggregateOption.foreach(f => f(v, va, gs)) - } -// def 
buildKeyAggregations(vds: VariantDataset, ec: EvalContext) = { -// val aggregators = ec.aggregationFunctions.toArray -// val aggregatorA = ec.a -// -// if (aggregators.isEmpty) -// None -// else { -// -// val localSamplesBc = vds.sampleIdsBc -// val localAnnotationsBc = vds.sampleAnnotationsBc -// -// val nAggregations = aggregators.length -// val nSamples = vds.nSamples -// val depth = HailConfiguration.treeAggDepth(vds.nPartitions) -// -// val baseArray = MultiArray2.fill[Aggregator](nSamples, nAggregations)(null) -// for (i <- 0 until nSamples; j <- 0 until nAggregations) { -// baseArray.update(i, j, aggregators(j).copy()) -// } -// } - - - println(header.mkString("\n")) - println(parseTypes.mkString("\n")) - println(f().mkString("\n")) - - val foo = vds.rdd.map{case (v, (va, gs)) => - ec.set(0, v) - ec.set(1, va) - val (header, parseTypes, f) = Parser.parseNamedArgs(keyCond, ec) - f() - } - -// println(foo.collect().map(_.mkString(",")).mkString("\n")) + val (zVals, seqOp, combOp, resultOp) = Aggregators.makeKeyFunctions(aggregationEC) + val zvf = () => zVals.indices.map(zVals).toArray + val results = vds.mapPartitionsWithAll{ it => + it.map { case (v, va, s, sa, g) => + ec.setAll(v, va, s, sa, g) + val key = keyF().toIndexedSeq + (key, (v, va, s, sa, g)) + } + }.aggregateByKey(zvf())(seqOp, combOp).collectAsMap() + sc.hadoopConfiguration.writeTextFile(options.output) { out => + val sb = new StringBuilder + val headerNames = keyNames ++ aggNames + headerNames.foreachBetween(k => sb.append(k))(sb += '\t') + sb += '\n' -// val (zVals, seqOp, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) -// -// val zvf = () => zVals.indices.map(zVals).toArray -// -// val results = vds.variantsAndAnnotations.flatMap { case (v, va) => i => (i, (v, va)) } -// } -// .aggregateByKey(zvf())(seqOp, combOp) -// .collectAsMap() + results.foreachBetween { case (key, agg) => + key.foreachBetween(k => sb.append(k))(sb += '\t') -// println(parseTypes.mkString("\n")) - - -// val 
groups = vds.rdd.flatMap { case (v, (va, gs)) => -// val key = qGroupKey(va) -// val genotypes = gs.map { g => g.nNonRefAlleles.getOrElse(9) } //SKAT-O null value is +9 -// key match { -// case Some(x) => -// if (splat) -// for (k <- x.asInstanceOf[Iterable[_]]) yield (k, genotypes) -// else -// Some((x, genotypes)) -// case None => None -// } -// }.groupByKey() + resultOp(agg) + aggF().foreach { field => + sb += '\t' + sb.append(field) + } + }(sb += '\n') + out.write(sb.result()) + } state } diff --git a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala index fd98e1439be..76806bfabdf 100644 --- a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala +++ b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala @@ -43,49 +43,6 @@ object Aggregators { } else None } - def buildVariantAggregationsByGroup(vds: VariantDataset, ec: EvalContext, keyFn: () => Array[_]): Option[(Variant, Annotation, Iterable[Genotype]) => Unit] = { - val aggregators = ec.aggregationFunctions.toArray - val aggregatorA = ec.a - val nAggregators = aggregators.length - - if (aggregators.nonEmpty) { - - val localSamplesBc = vds.sampleIdsBc - val localAnnotationsBc = vds.sampleAnnotationsBc - - val sampleGroups = vds.sampleIdsAndAnnotations.map { case (s, sa) => - ec.set(2, s) - ec.set(3, sa) - - keyFn() - } - - val distinctSampleGroupMap = sampleGroups.distinct.zipWithIndex.toMap - val siToGroupIndex = sampleGroups.map(distinctSampleGroupMap) - val nGroups = distinctSampleGroupMap.size - - val f = (v: Variant, va: Annotation, gs: Iterable[Genotype]) => { - val baseArray = MultiArray2.fill[Aggregator](nGroups, nAggregators)(null) - - aggregatorA(0) = v - aggregatorA(1) = va - - gs.zip(localSamplesBc.value).zip(localAnnotationsBc.value).zip(siToGroupIndex) -// (gs, localSamplesBc.value, localAnnotationsBc.value).zipped - .foreach { - case ((((g, s), sa), gi)) => - aggregatorA(2) = s - 
aggregatorA(3) = sa - for (j <- 0 until baseArray.n2) - baseArray(gi, j).seqOp(g) - } - - baseArray.foreach { agg => aggregatorA(agg.idx) = agg.result } - } - Some(f) - } else None - } - def buildSampleAggregations(vds: VariantDataset, ec: EvalContext): Option[(String) => Unit] = { val aggregators = ec.aggregationFunctions.toArray val aggregatorA = ec.a @@ -140,110 +97,44 @@ object Aggregators { } } - def buildGroupedSampleAggregations(vds: VariantDataset, ec: EvalContext, keyFn: () => Array[_]): Option[(Array[_]) => Unit] = { + def makeFunctions(ec: EvalContext): (Array[Aggregator], (Array[Aggregator], (Any, Any)) => Array[Aggregator], + (Array[Aggregator], Array[Aggregator]) => Array[Aggregator], (Array[Aggregator]) => Unit) = { + val aggregators = ec.aggregationFunctions.toArray - val aggregatorA = ec.a - if (aggregators.isEmpty) - None - else { + val arr = ec.a - val localSamplesBc = vds.sampleIdsBc - val localAnnotationsBc = vds.sampleAnnotationsBc + val baseArray = Array.fill[Aggregator](aggregators.length)(null) - val sampleGroups = vds.sampleIdsAndAnnotations.map { case (s, sa) => - ec.set(2, s) - ec.set(3, sa) + val zero = { + for (i <- baseArray.indices) + baseArray(i) = aggregators(i).copy() + baseArray + } - keyFn() + val seqOp = (array: Array[Aggregator], b: (Any, Any)) => { + val (aggT, annotation) = b + ec.set(0, annotation) + for (i <- array.indices) { + array(i).seqOp(aggT) } + array + } - val distinctSampleGroupMap = sampleGroups.distinct.zipWithIndex.toMap - val siToGroupIndex = sampleGroups.map(distinctSampleGroupMap) - val nGroups = distinctSampleGroupMap.size - - val nAggregations = aggregators.length - val depth = HailConfiguration.treeAggDepth(vds.nPartitions) - - val baseArray = MultiArray2.fill[Aggregator](nGroups, nAggregations)(null) - for (i <- 0 until nGroups; j <- 0 until nAggregations) { - baseArray.update(i, j, aggregators(j).copy()) + val combOp = (arr1: Array[Aggregator], arr2: Array[Aggregator]) => { + for (i <- arr1.indices) { + 
val a1 = arr1(i) + a1.combOp(arr2(i).asInstanceOf[a1.type]) } + arr1 + } - val result = vds.rdd.treeAggregate(baseArray)({ case (arr, (v, (va, gs))) => - aggregatorA(0) = v - aggregatorA(1) = va - var i = 0 - gs.foreach { g => - aggregatorA(2) = localSamplesBc.value(i) - aggregatorA(3) = localAnnotationsBc.value(i) - - val gi = siToGroupIndex(i) - var ai = 0 - while (ai < nAggregations) { - arr(gi, ai).seqOp(g) - ai += 1 - } - i += 1 - } - arr - }, { case (arr1, arr2) => - for (i <- 0 until nGroups; j <- 0 until nAggregations) { - val a1 = arr1(i, j) - a1.combOp(arr2(i, j).asInstanceOf[a1.type]) - } - arr1 - }, depth = depth) + val resultOp = (array: Array[Aggregator]) => array.foreach { res => arr(res.idx) = res.result } - Some((s: Array[_]) => { - val i = distinctSampleGroupMap(s) - for (j <- 0 until nAggregations) { - aggregatorA(aggregators(j).idx) = result(i, j).result - } - }) - } + (zero, seqOp, combOp, resultOp) } -// def makeGroupedFunctions(ec: EvalContext, keyFn: () => Array[String]): (MultiArray2[Aggregator], (MultiArray2[Aggregator], (Any, Any)) => MultiArray2[Aggregator], -// (MultiArray2[Aggregator], MultiArray2[Aggregator]) => MultiArray2[Aggregator], (MultiArray2[Aggregator]) => Unit) = { -// -// val aggregators = ec.aggregationFunctions.toArray -// val nAggregators = aggregators.length -// -// val nGroups = ??? 
-// -// val arr = ec.a -// -// val baseArray = MultiArray2.fill[Aggregator](nGroups, nAggregators)(null) -// -// val zero = { -// for ((i, j) <- baseArray.indices) -// baseArray(i, j) = aggregators(j).copy() -// baseArray -// } -// -// val seqOp = (array: MultiArray2[Aggregator], b: (Any, Any)) => { -// val (aggT, annotation) = b -// ec.set(0, annotation) -// for ((i, j) <- array.indices) { -// array(i, j).seqOp(aggT) -// } -// array -// } -// -// val combOp = (arr1: MultiArray2[Aggregator], arr2: MultiArray2[Aggregator]) => { -// for ((i, j) <- arr1.indices) { -// val a1 = arr1(i, j) -// a1.combOp(arr2(i, j).asInstanceOf[a1.type]) -// } -// arr1 -// } -// -// val resultOp = (array: MultiArray2[Aggregator]) => array.foreach { res => arr(res.idx) = res.result } -// -// (zero, seqOp, combOp, resultOp) -// } - def makeFunctions(ec: EvalContext): (Array[Aggregator], (Array[Aggregator], (Any, Any)) => Array[Aggregator], + def makeKeyFunctions(ec: EvalContext): (Array[Aggregator], (Array[Aggregator], (Any, Any, Any, Any, Any)) => Array[Aggregator], (Array[Aggregator], Array[Aggregator]) => Array[Aggregator], (Array[Aggregator]) => Unit) = { val aggregators = ec.aggregationFunctions.toArray @@ -258,9 +149,12 @@ object Aggregators { baseArray } - val seqOp = (array: Array[Aggregator], b: (Any, Any)) => { - val (aggT, annotation) = b - ec.set(0, annotation) + val seqOp = (array: Array[Aggregator], b: (Any, Any, Any, Any, Any)) => { + val (v, va, s, sa, aggT) = b + ec.set(0, v) + ec.set(1, va) + ec.set(2, s) + ec.set(3, sa) for (i <- array.indices) { array(i).seqOp(aggT) } diff --git a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala index 3faf044e667..50b2ce4e61c 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala @@ -11,9 +11,6 @@ class AddKeyTableSuite extends SparkSuite { s 
= AnnotateVariantsExpr.run(s, Array("-c", "va.foo = gs.filter(g => g.isHet).count()")) s = AnnotateSamplesExpr.run(s, Array("-c", "sa.foo = gs.filter(g => g.isHet).count()")) s = AnnotateGlobalExpr.run(s, Array("-c", "global.foo = variants.count()")) - s = PrintSchema.run(s, Array.empty[String]) - s = ExportAggregate.run(s, Array("-k", "foo = va.foo, foo1 = global.foo, foo2 = sa.foo, foo3 = 5", "-a", "nHet = gs.filter(g => g.isHet).count()")) - s = Count.run(s, Array.empty[String]) -// s = AddKeyTable.run(s, Array("-k", "foo = va.foo, foo1 = global.foo, foo2 = sa.foo")) + s = AddKeyTable.run(s, Array("-k", "foo1 = va.foo, foo2 = sa.foo", "-a", "hetCount = gs.filter(g => g.isHet).count(), totalCount = gs.count()", "-o", "testKeyTable.tsv")) } } From 02d2d4b7dae5b5f2e85a1f7ea01dcbebac363fb0 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Mon, 24 Oct 2016 23:37:54 -0400 Subject: [PATCH 13/51] working tests --- .../hail/methods/AddKeyTableSuite.scala | 239 +++++++++++++++++- 1 file changed, 232 insertions(+), 7 deletions(-) diff --git a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala index 50b2ce4e61c..b3e2da71e19 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala @@ -3,14 +3,239 @@ package org.broadinstitute.hail.methods import org.broadinstitute.hail.SparkSuite import org.testng.annotations.Test import org.broadinstitute.hail.driver._ +import org.broadinstitute.hail.expr._ +import org.broadinstitute.hail.annotations._ +import org.broadinstitute.hail.check.Arbitrary._ +import org.broadinstitute.hail.variant.{Genotype, VSMSubgen, VariantDataset, VariantSampleMatrix} +import org.broadinstitute.hail.check.Gen +import org.broadinstitute.hail.check.Prop._ +import org.broadinstitute.hail.check.Properties +import org.broadinstitute.hail.utils._ +import 
org.broadinstitute.hail.utils.TextTableReader class AddKeyTableSuite extends SparkSuite { - @Test def test1() { - var s = State(sc, sqlContext, null) - s = ImportVCF.run(s, Array("-i", "src/test/resources/sample.vcf")) - s = AnnotateVariantsExpr.run(s, Array("-c", "va.foo = gs.filter(g => g.isHet).count()")) - s = AnnotateSamplesExpr.run(s, Array("-c", "sa.foo = gs.filter(g => g.isHet).count()")) - s = AnnotateGlobalExpr.run(s, Array("-c", "global.foo = variants.count()")) - s = AddKeyTable.run(s, Array("-k", "foo1 = va.foo, foo2 = sa.foo", "-a", "hetCount = gs.filter(g => g.isHet).count(), totalCount = gs.count()", "-o", "testKeyTable.tsv")) + + def createKey(nItems: Int, nCategories: Int) = + Gen.buildableOfN[Array, Option[String]](nItems, Gen.option(Gen.oneOfSeq((0 until nCategories).map("group" + _)), 0.95)) + + def createKeys(nKeys: Int, nItems: Int) = + Gen.buildableOfN[Array, Array[Option[String]]](nKeys, createKey(nItems, Gen.choose(1, 10).sample())) + +// @Test def test1() { +// var s = State(sc, sqlContext, null) +// s = ImportVCF.run(s, Array("-i", "src/test/resources/sample.vcf")) +// s = AnnotateVariantsExpr.run(s, Array("-c", "va.foo = gs.filter(g => g.isHet).count()")) +// s = AnnotateSamplesExpr.run(s, Array("-c", "sa.foo = gs.filter(g => g.isHet).count()")) +// s = AnnotateGlobalExpr.run(s, Array("-c", "global.foo = variants.count()")) +// s = AddKeyTable.run(s, Array("-k", "foo1 = va.foo, foo2 = sa.foo", "-a", "hetCount = gs.filter(g => g.isHet).count(), totalCount = gs.count()", "-o", "testKeyTable.tsv")) +// } + + object Spec extends Properties("CreateKeyTable") { + val compGen = for (vds: VariantDataset <- VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0); + nKeys <- Gen.choose(1, 5); + nSampleKeys <- Gen.choose(0, nKeys); + nVariantKeys <- Gen.const(nKeys - nSampleKeys); + sampleGroups <- createKeys(nSampleKeys, vds.nSamples); + variantGroups <- createKeys(nVariantKeys, 
vds.nVariants.toInt) + ) yield (vds, sampleGroups, variantGroups) + + val compGenSample = for (vds: VariantDataset <- VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0); + nSampleKeys <- Gen.choose(2, 5); + sampleGroups <- createKeys(nSampleKeys, vds.nSamples) + ) yield (vds, sampleGroups) + + val compGenVariant = for (vds: VariantDataset <- VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0); + nVariantKeys <- Gen.choose(2, 5); + variantGroups <- createKeys(nVariantKeys, vds.nVariants.toInt) + ) yield (vds, variantGroups) + + def getKeyTableResults(fileName: String, keyNames: IndexedSeq[String]) = { + val ktr = hadoopConf.readLines(fileName)(_.map(_.map { line => + line.trim.split("\\s+") + }.value).toIndexedSeq) + + val header = ktr.take(1) + ktr.drop(1).map(r => (header(0), r).zipped.toMap) + .map { x => (keyNames.map { k => x(k) }, (x("nHet"), x("nCalled"), x("nTotal"))) }.toMap + } + + def keyTableEqualAnnExpr(annExprResult: scala.collection.Map[IndexedSeq[String], (Long, Long, Long)], keyTableResults: scala.collection.Map[IndexedSeq[String], (String, String, String)]) = + annExprResult.forall{ case (keys, (nHet, nCalled, nTotal)) => + val (ktHet, ktCalled, ktTotal) = keyTableResults(keys.map(k => if (k != null) k else "NA").toIndexedSeq) + ktHet.toLong == nHet && + ktCalled.toLong == nCalled && + ktTotal.toLong == nTotal + } + + property("group by variant id same as variant aggregations") = forAll(VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0)) {case (vds: VariantDataset) => + val outputFile = tmpDir.createTempFile("aggByVariant", ".tsv") + + var s = State(sc, sqlContext, vds) + + s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count(), va.nCalled = gs.filter(g => g.isCalled).count(), va.nTotal = gs.count()")) + + val (_, nHetQuery) = 
s.vds.queryVA("va.nHet") + val (_, nCalledQuery) = s.vds.queryVA("va.nCalled") + val (_, nTotalQuery) = s.vds.queryVA("va.nTotal") + + val truthResult = s.vds.variantsAndAnnotations.map{ case (v, va) => + (IndexedSeq(v.toString), (nHetQuery(va).get.asInstanceOf[Long], nCalledQuery(va).get.asInstanceOf[Long], nTotalQuery(va).get.asInstanceOf[Long])) + }.collectAsMap() + + s = AddKeyTable.run(s, Array("-k", "Variant = v", + "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", + "-o", outputFile)) + + val keyTableResults = getKeyTableResults(outputFile, IndexedSeq("Variant")) + + keyTableEqualAnnExpr(truthResult, keyTableResults) + } + + property("group by sample id same as sample aggregations") = forAll(VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0)) {case (vds: VariantDataset) => + val outputFile = tmpDir.createTempFile("aggBySample", ".tsv") + + var s = State(sc, sqlContext, vds) + + s = AnnotateSamplesExpr.run(s, Array("-c", "sa.nHet = gs.filter(g => g.isHet).count(), sa.nCalled = gs.filter(g => g.isCalled).count(), sa.nTotal = gs.count()")) + + val (_, nHetQuery) = s.vds.querySA("sa.nHet") + val (_, nCalledQuery) = s.vds.querySA("sa.nCalled") + val (_, nTotalQuery) = s.vds.querySA("sa.nTotal") + + val truthResult = s.vds.sampleIdsAndAnnotations.map{ case (sid, sa) => + (IndexedSeq(sid), (nHetQuery(sa).get.asInstanceOf[Long], nCalledQuery(sa).get.asInstanceOf[Long], nTotalQuery(sa).get.asInstanceOf[Long])) + }.toMap + + s = AddKeyTable.run(s, Array("-k", "Sample = s", + "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", + "-o", outputFile)) + + val keyTableResults = getKeyTableResults(outputFile, IndexedSeq("Sample")) + + keyTableEqualAnnExpr(truthResult, keyTableResults) + } + + property("aggregate by variant groups same") = forAll(compGenVariant) { case (vds, varGroups) => + 
val outputFile = tmpDir.createTempFile("aggByVariantGroup", ".tsv") + + val nKeys = varGroups.length + val keyNames = (1 to nKeys).map("key" + _) + + var signature = TStruct() + keyNames.foreach(k => signature = signature.appendKey(k, TString)) + + val variantAnnotations = sc.parallelize(vds.variants.collect().zipWithIndex.map { case (v, i) => + (v, Annotation(varGroups.map(_ (i).orNull).toSeq: _*)) + }).toOrderedRDD + + var s = State(sc, sqlContext, vds.annotateVariants(variantAnnotations, signature, "va.keys")) + + s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count(), va.nCalled = gs.filter(g => g.isCalled).count(), va.nTotal = gs.count()")) + + val (_, nHetQuery) = s.vds.queryVA("va.nHet") + val (_, nCalledQuery) = s.vds.queryVA("va.nCalled") + val (_, nTotalQuery) = s.vds.queryVA("va.nTotal") + val (_, keyQuery) = s.vds.queryVA("va.keys.*") + + val truthResult = s.vds.variantsAndAnnotations.map{ case (v, va) => + (keyQuery(va).get.asInstanceOf[IndexedSeq[String]], (nHetQuery(va).get.asInstanceOf[Long], nCalledQuery(va).get.asInstanceOf[Long], nTotalQuery(va).get.asInstanceOf[Long])) + }.aggregateByKey((0L, 0L, 0L))((comb, counts) => (comb._1 + counts._1, comb._2 + counts._2, comb._3 + counts._3), + (comb1, comb2) => (comb1._1 + comb2._1, comb1._2 + comb2._2, comb1._3 + comb2._3)).collectAsMap() + + s = AddKeyTable.run(s, Array("-k", keyNames.map( k => k + " = " + "va.keys." 
+ k).mkString(","), + "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", + "-o", outputFile)) + + val keyTableResults = getKeyTableResults(outputFile, keyNames) + + keyTableEqualAnnExpr(truthResult, keyTableResults) + } + + property("aggregate by sample groups same") = forAll(compGenSample) { case (vds, phenotypes) => + val outputFile = tmpDir.createTempFile("aggBySampleGroup", ".tsv") + + val nPhenotypes = phenotypes.length + val keyNames = (1 to nPhenotypes).map("key" + _) + + var signature = TStruct() + keyNames.foreach(k => signature = signature.appendKey(k, TString)) + + val phenoMap = vds.sampleIds.zipWithIndex.map{ case (sid, i) => + (sid, Annotation(phenotypes.map(_(i).orNull).toSeq : _*)) + }.toMap + + var s = State(sc, sqlContext, vds.annotateSamples(phenoMap, signature, "sa.pheno")) + + s = AnnotateSamplesExpr.run(s, Array("-c", "sa.nHet = gs.filter(g => g.isHet).count(), sa.nCalled = gs.filter(g => g.isCalled).count(), sa.nTotal = gs.count()")) + + val (_, nHetQuery) = s.vds.querySA("sa.nHet") + val (_, nCalledQuery) = s.vds.querySA("sa.nCalled") + val (_, nTotalQuery) = s.vds.querySA("sa.nTotal") + val (_, phenoQuery) = s.vds.querySA("sa.pheno.*") + + val truthResult = sc.parallelize(s.vds.sampleIdsAndAnnotations).map{ case (sid, sa) => + (phenoQuery(sa).get.asInstanceOf[IndexedSeq[String]], (nHetQuery(sa).get.asInstanceOf[Long], nCalledQuery(sa).get.asInstanceOf[Long], nTotalQuery(sa).get.asInstanceOf[Long])) + }.aggregateByKey((0L, 0L, 0L))((comb, counts) => (comb._1 + counts._1, comb._2 + counts._2, comb._3 + counts._3), + (comb1, comb2) => (comb1._1 + comb2._1, comb1._2 + comb2._2, comb1._3 + comb2._3)).collectAsMap() + + s = AddKeyTable.run(s, Array("-k", keyNames.map( k => k + " = " + "sa.pheno." 
+ k).mkString(","), + "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", + "-o", outputFile)) + + val keyTableResults = getKeyTableResults(outputFile, keyNames) + + keyTableEqualAnnExpr(truthResult, keyTableResults) + } + + property("aggregate by sample and variants same") = forAll(compGen) { case (vds, sampleGroups, variantGroups) => + val outputFile = tmpDir.createTempFile("aggBySampleVariantGroup", ".tsv") + + val nKeys = sampleGroups.length + variantGroups.length + val keyNames = (1 to nKeys).map("key" + _) + val sampleKeyNames = (1 to sampleGroups.length).map("key" + _) + val variantKeyNames = (sampleGroups.length + 1 to nKeys).map("key" + _) + + var sampleSignature = TStruct() + sampleKeyNames.foreach(k => sampleSignature = sampleSignature.appendKey(k, TString)) + + var variantSignature = TStruct() + variantKeyNames.foreach(k => variantSignature = variantSignature.appendKey(k, TString)) + + val sampleMap = vds.sampleIds.zipWithIndex.map{ case (sid, i) => + (sid, Annotation(sampleGroups.map(_(i).orNull).toSeq : _*)) + }.toMap + + val variantAnnotations = sc.parallelize(vds.variants.collect().zipWithIndex.map { case (v, i) => + (v, Annotation(variantGroups.map(_ (i).orNull).toSeq: _*)) + }).toOrderedRDD + + var s = State(sc, sqlContext, vds.annotateSamples(sampleMap, sampleSignature, "sa.keys") + .annotateVariants(variantAnnotations, variantSignature, "va.keys")) + + val (_, sampleKeyQuery) = s.vds.querySA("sa.keys.*") + val (_, variantKeyQuery) = s.vds.queryVA("va.keys.*") + + val keyGenotypeRDD = s.vds.mapWithAll{case (v, va, sid, sa, g) => + val key = sampleKeyQuery(sa).get.asInstanceOf[IndexedSeq[String]] ++ variantKeyQuery(va).get.asInstanceOf[IndexedSeq[String]] + (key, g) + } + + val result = keyGenotypeRDD.aggregateByKey((0L, 0L, 0L))( + (comb, gt) => (comb._1 + gt.isHet.toInt.toInt, comb._2 + gt.isCalled.toInt.toInt, comb._3 + 1), + (comb1, comb2) => (comb1._1 + comb2._1, comb1._2 + 
comb2._2, comb1._3 + comb2._3)).collectAsMap() + + s = AddKeyTable.run(s, Array("-k", (sampleKeyNames.map(k => k + " = " + "sa.keys." + k) ++ variantKeyNames.map(k => k + " = " + "va.keys." + k)).mkString(","), + "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", + "-o", outputFile)) + + val keyTableResults = getKeyTableResults(outputFile, keyNames) + + keyTableEqualAnnExpr(result, keyTableResults) + } } + + @Test def testAddKeyTable() { + Spec.check() + } + } From bf37d532fca1b3a77f926d24eb81d7809d097892 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 25 Oct 2016 00:20:27 -0400 Subject: [PATCH 14/51] Added hidden command for creating a key-table - takes in a named expression specifying the keys of the new table - takes named aggregator expressions specifying how the columns should be computed --- .../hail/driver/AddKeyTable.scala | 50 +++-- .../broadinstitute/hail/driver/Command.scala | 1 + .../hail/methods/Aggregators.scala | 40 ---- .../hail/methods/AddKeyTableSuite.scala | 194 +++--------------- 4 files changed, 48 insertions(+), 237 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala index 520ed1e8982..327be889f2b 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala @@ -6,13 +6,14 @@ import org.broadinstitute.hail.utils._ import org.kohsuke.args4j.{Option => Args4jOption} object AddKeyTable extends Command { + class Options extends BaseOptions with TextTableOptions { @Args4jOption(required = true, name = "-k", aliases = Array("--key-cond"), - usage = "Struct with expr defining keys") + usage = "Named key condition", metaVar = "EXPR") var keyCond: String = _ @Args4jOption(required = true, name = "-a", aliases = Array("--agg-cond"), - usage = "Aggregation condition") + usage = "Named 
aggregation condition", metaVar = "EXPR") var aggCond: String = _ @Args4jOption(required = true, name = "-o", aliases = Array("--output"), @@ -24,7 +25,7 @@ object AddKeyTable extends Command { def name = "addkeytable" - def description = "Creates new key table with key determined by an expression" + def description = "Creates a new key table with key(s) determined by named expressions and additional columns determined by named aggregator expressions" def supportsMultiallelic = true @@ -69,36 +70,33 @@ object AddKeyTable extends Command { if (aggNames.isEmpty) fatal("this module requires one or more named expr arguments to aggregate by key") - val (zVals, seqOp, combOp, resultOp) = Aggregators.makeKeyFunctions(aggregationEC) + val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) val zvf = () => zVals.indices.map(zVals).toArray - val results = vds.mapPartitionsWithAll{ it => + val seqOp = (array: Array[Aggregator], b: (Any, Any, Any, Any, Any)) => { + val (v, va, s, sa, aggT) = b + ec.set(0, v) + ec.set(1, va) + ec.set(2, s) + ec.set(3, sa) + for (i <- array.indices) { + array(i).seqOp(aggT) + } + array + } + + vds.mapPartitionsWithAll { it => it.map { case (v, va, s, sa, g) => ec.setAll(v, va, s, sa, g) val key = keyF().toIndexedSeq (key, (v, va, s, sa, g)) - } - }.aggregateByKey(zvf())(seqOp, combOp).collectAsMap() - - sc.hadoopConfiguration.writeTextFile(options.output) { out => - val sb = new StringBuilder - val headerNames = keyNames ++ aggNames - headerNames.foreachBetween(k => sb.append(k))(sb += '\t') - sb += '\n' - - results.foreachBetween { case (key, agg) => - key.foreachBetween(k => sb.append(k))(sb += '\t') - + } + }.aggregateByKey(zvf())(seqOp, combOp) + .map { case (k, agg) => resultOp(agg) - - aggF().foreach { field => - sb += '\t' - sb.append(field) - } - }(sb += '\n') - - out.write(sb.result()) - } + (k ++ aggF()).mkString("\t") + } + .writeTable(options.output, Option((keyNames ++ aggNames).mkString("\t"))) state } diff 
--git a/src/main/scala/org/broadinstitute/hail/driver/Command.scala b/src/main/scala/org/broadinstitute/hail/driver/Command.scala index e40ef4554d5..fdc6caa51a3 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/Command.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/Command.scala @@ -45,6 +45,7 @@ object ToplevelCommands { + cmd.description)) } + register(AddKeyTable) register(AggregateIntervals) register(AnnotateSamples) register(AnnotateVariants) diff --git a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala index 76806bfabdf..00be90c835f 100644 --- a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala +++ b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala @@ -133,46 +133,6 @@ object Aggregators { (zero, seqOp, combOp, resultOp) } - - def makeKeyFunctions(ec: EvalContext): (Array[Aggregator], (Array[Aggregator], (Any, Any, Any, Any, Any)) => Array[Aggregator], - (Array[Aggregator], Array[Aggregator]) => Array[Aggregator], (Array[Aggregator]) => Unit) = { - - val aggregators = ec.aggregationFunctions.toArray - - val arr = ec.a - - val baseArray = Array.fill[Aggregator](aggregators.length)(null) - - val zero = { - for (i <- baseArray.indices) - baseArray(i) = aggregators(i).copy() - baseArray - } - - val seqOp = (array: Array[Aggregator], b: (Any, Any, Any, Any, Any)) => { - val (v, va, s, sa, aggT) = b - ec.set(0, v) - ec.set(1, va) - ec.set(2, s) - ec.set(3, sa) - for (i <- array.indices) { - array(i).seqOp(aggT) - } - array - } - - val combOp = (arr1: Array[Aggregator], arr2: Array[Aggregator]) => { - for (i <- arr1.indices) { - val a1 = arr1(i) - a1.combOp(arr2(i).asInstanceOf[a1.type]) - } - arr1 - } - - val resultOp = (array: Array[Aggregator]) => array.foreach { res => arr(res.idx) = res.result } - - (zero, seqOp, combOp, resultOp) - } } class CountAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[Long] { diff 
--git a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala index b3e2da71e19..a9ed5dfb2a1 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala @@ -1,17 +1,14 @@ package org.broadinstitute.hail.methods import org.broadinstitute.hail.SparkSuite -import org.testng.annotations.Test -import org.broadinstitute.hail.driver._ -import org.broadinstitute.hail.expr._ import org.broadinstitute.hail.annotations._ -import org.broadinstitute.hail.check.Arbitrary._ -import org.broadinstitute.hail.variant.{Genotype, VSMSubgen, VariantDataset, VariantSampleMatrix} -import org.broadinstitute.hail.check.Gen import org.broadinstitute.hail.check.Prop._ -import org.broadinstitute.hail.check.Properties +import org.broadinstitute.hail.check.{Gen, Properties} +import org.broadinstitute.hail.driver._ +import org.broadinstitute.hail.expr._ import org.broadinstitute.hail.utils._ -import org.broadinstitute.hail.utils.TextTableReader +import org.broadinstitute.hail.variant.{Genotype, VSMSubgen, VariantDataset, VariantSampleMatrix} +import org.testng.annotations.Test class AddKeyTableSuite extends SparkSuite { @@ -21,15 +18,6 @@ class AddKeyTableSuite extends SparkSuite { def createKeys(nKeys: Int, nItems: Int) = Gen.buildableOfN[Array, Array[Option[String]]](nKeys, createKey(nItems, Gen.choose(1, 10).sample())) -// @Test def test1() { -// var s = State(sc, sqlContext, null) -// s = ImportVCF.run(s, Array("-i", "src/test/resources/sample.vcf")) -// s = AnnotateVariantsExpr.run(s, Array("-c", "va.foo = gs.filter(g => g.isHet).count()")) -// s = AnnotateSamplesExpr.run(s, Array("-c", "sa.foo = gs.filter(g => g.isHet).count()")) -// s = AnnotateGlobalExpr.run(s, Array("-c", "global.foo = variants.count()")) -// s = AddKeyTable.run(s, Array("-k", "foo1 = va.foo, foo2 = sa.foo", "-a", "hetCount = gs.filter(g 
=> g.isHet).count(), totalCount = gs.count()", "-o", "testKeyTable.tsv")) -// } - object Spec extends Properties("CreateKeyTable") { val compGen = for (vds: VariantDataset <- VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0); nKeys <- Gen.choose(1, 5); @@ -39,156 +27,8 @@ class AddKeyTableSuite extends SparkSuite { variantGroups <- createKeys(nVariantKeys, vds.nVariants.toInt) ) yield (vds, sampleGroups, variantGroups) - val compGenSample = for (vds: VariantDataset <- VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0); - nSampleKeys <- Gen.choose(2, 5); - sampleGroups <- createKeys(nSampleKeys, vds.nSamples) - ) yield (vds, sampleGroups) - - val compGenVariant = for (vds: VariantDataset <- VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0); - nVariantKeys <- Gen.choose(2, 5); - variantGroups <- createKeys(nVariantKeys, vds.nVariants.toInt) - ) yield (vds, variantGroups) - - def getKeyTableResults(fileName: String, keyNames: IndexedSeq[String]) = { - val ktr = hadoopConf.readLines(fileName)(_.map(_.map { line => - line.trim.split("\\s+") - }.value).toIndexedSeq) - - val header = ktr.take(1) - ktr.drop(1).map(r => (header(0), r).zipped.toMap) - .map { x => (keyNames.map { k => x(k) }, (x("nHet"), x("nCalled"), x("nTotal"))) }.toMap - } - - def keyTableEqualAnnExpr(annExprResult: scala.collection.Map[IndexedSeq[String], (Long, Long, Long)], keyTableResults: scala.collection.Map[IndexedSeq[String], (String, String, String)]) = - annExprResult.forall{ case (keys, (nHet, nCalled, nTotal)) => - val (ktHet, ktCalled, ktTotal) = keyTableResults(keys.map(k => if (k != null) k else "NA").toIndexedSeq) - ktHet.toLong == nHet && - ktCalled.toLong == nCalled && - ktTotal.toLong == nTotal - } - - property("group by variant id same as variant aggregations") = forAll(VariantSampleMatrix.gen[Genotype](sc, 
VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0)) {case (vds: VariantDataset) => - val outputFile = tmpDir.createTempFile("aggByVariant", ".tsv") - - var s = State(sc, sqlContext, vds) - - s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count(), va.nCalled = gs.filter(g => g.isCalled).count(), va.nTotal = gs.count()")) - - val (_, nHetQuery) = s.vds.queryVA("va.nHet") - val (_, nCalledQuery) = s.vds.queryVA("va.nCalled") - val (_, nTotalQuery) = s.vds.queryVA("va.nTotal") - - val truthResult = s.vds.variantsAndAnnotations.map{ case (v, va) => - (IndexedSeq(v.toString), (nHetQuery(va).get.asInstanceOf[Long], nCalledQuery(va).get.asInstanceOf[Long], nTotalQuery(va).get.asInstanceOf[Long])) - }.collectAsMap() - - s = AddKeyTable.run(s, Array("-k", "Variant = v", - "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", - "-o", outputFile)) - - val keyTableResults = getKeyTableResults(outputFile, IndexedSeq("Variant")) - - keyTableEqualAnnExpr(truthResult, keyTableResults) - } - - property("group by sample id same as sample aggregations") = forAll(VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0)) {case (vds: VariantDataset) => - val outputFile = tmpDir.createTempFile("aggBySample", ".tsv") - - var s = State(sc, sqlContext, vds) - - s = AnnotateSamplesExpr.run(s, Array("-c", "sa.nHet = gs.filter(g => g.isHet).count(), sa.nCalled = gs.filter(g => g.isCalled).count(), sa.nTotal = gs.count()")) - - val (_, nHetQuery) = s.vds.querySA("sa.nHet") - val (_, nCalledQuery) = s.vds.querySA("sa.nCalled") - val (_, nTotalQuery) = s.vds.querySA("sa.nTotal") - - val truthResult = s.vds.sampleIdsAndAnnotations.map{ case (sid, sa) => - (IndexedSeq(sid), (nHetQuery(sa).get.asInstanceOf[Long], nCalledQuery(sa).get.asInstanceOf[Long], nTotalQuery(sa).get.asInstanceOf[Long])) - }.toMap - - s = AddKeyTable.run(s, 
Array("-k", "Sample = s", - "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", - "-o", outputFile)) - - val keyTableResults = getKeyTableResults(outputFile, IndexedSeq("Sample")) - - keyTableEqualAnnExpr(truthResult, keyTableResults) - } - - property("aggregate by variant groups same") = forAll(compGenVariant) { case (vds, varGroups) => - val outputFile = tmpDir.createTempFile("aggByVariantGroup", ".tsv") - - val nKeys = varGroups.length - val keyNames = (1 to nKeys).map("key" + _) - - var signature = TStruct() - keyNames.foreach(k => signature = signature.appendKey(k, TString)) - - val variantAnnotations = sc.parallelize(vds.variants.collect().zipWithIndex.map { case (v, i) => - (v, Annotation(varGroups.map(_ (i).orNull).toSeq: _*)) - }).toOrderedRDD - - var s = State(sc, sqlContext, vds.annotateVariants(variantAnnotations, signature, "va.keys")) - - s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count(), va.nCalled = gs.filter(g => g.isCalled).count(), va.nTotal = gs.count()")) - - val (_, nHetQuery) = s.vds.queryVA("va.nHet") - val (_, nCalledQuery) = s.vds.queryVA("va.nCalled") - val (_, nTotalQuery) = s.vds.queryVA("va.nTotal") - val (_, keyQuery) = s.vds.queryVA("va.keys.*") - - val truthResult = s.vds.variantsAndAnnotations.map{ case (v, va) => - (keyQuery(va).get.asInstanceOf[IndexedSeq[String]], (nHetQuery(va).get.asInstanceOf[Long], nCalledQuery(va).get.asInstanceOf[Long], nTotalQuery(va).get.asInstanceOf[Long])) - }.aggregateByKey((0L, 0L, 0L))((comb, counts) => (comb._1 + counts._1, comb._2 + counts._2, comb._3 + counts._3), - (comb1, comb2) => (comb1._1 + comb2._1, comb1._2 + comb2._2, comb1._3 + comb2._3)).collectAsMap() - - s = AddKeyTable.run(s, Array("-k", keyNames.map( k => k + " = " + "va.keys." 
+ k).mkString(","), - "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", - "-o", outputFile)) - - val keyTableResults = getKeyTableResults(outputFile, keyNames) - - keyTableEqualAnnExpr(truthResult, keyTableResults) - } - - property("aggregate by sample groups same") = forAll(compGenSample) { case (vds, phenotypes) => - val outputFile = tmpDir.createTempFile("aggBySampleGroup", ".tsv") - - val nPhenotypes = phenotypes.length - val keyNames = (1 to nPhenotypes).map("key" + _) - - var signature = TStruct() - keyNames.foreach(k => signature = signature.appendKey(k, TString)) - - val phenoMap = vds.sampleIds.zipWithIndex.map{ case (sid, i) => - (sid, Annotation(phenotypes.map(_(i).orNull).toSeq : _*)) - }.toMap - - var s = State(sc, sqlContext, vds.annotateSamples(phenoMap, signature, "sa.pheno")) - - s = AnnotateSamplesExpr.run(s, Array("-c", "sa.nHet = gs.filter(g => g.isHet).count(), sa.nCalled = gs.filter(g => g.isCalled).count(), sa.nTotal = gs.count()")) - - val (_, nHetQuery) = s.vds.querySA("sa.nHet") - val (_, nCalledQuery) = s.vds.querySA("sa.nCalled") - val (_, nTotalQuery) = s.vds.querySA("sa.nTotal") - val (_, phenoQuery) = s.vds.querySA("sa.pheno.*") - - val truthResult = sc.parallelize(s.vds.sampleIdsAndAnnotations).map{ case (sid, sa) => - (phenoQuery(sa).get.asInstanceOf[IndexedSeq[String]], (nHetQuery(sa).get.asInstanceOf[Long], nCalledQuery(sa).get.asInstanceOf[Long], nTotalQuery(sa).get.asInstanceOf[Long])) - }.aggregateByKey((0L, 0L, 0L))((comb, counts) => (comb._1 + counts._1, comb._2 + counts._2, comb._3 + counts._3), - (comb1, comb2) => (comb1._1 + comb2._1, comb1._2 + comb2._2, comb1._3 + comb2._3)).collectAsMap() - - s = AddKeyTable.run(s, Array("-k", keyNames.map( k => k + " = " + "sa.pheno." 
+ k).mkString(","), - "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", - "-o", outputFile)) - - val keyTableResults = getKeyTableResults(outputFile, keyNames) - - keyTableEqualAnnExpr(truthResult, keyTableResults) - } - property("aggregate by sample and variants same") = forAll(compGen) { case (vds, sampleGroups, variantGroups) => - val outputFile = tmpDir.createTempFile("aggBySampleVariantGroup", ".tsv") + val outputFile = tmpDir.createTempFile("keyTableTest", "tsv") val nKeys = sampleGroups.length + variantGroups.length val keyNames = (1 to nKeys).map("key" + _) @@ -201,8 +41,8 @@ class AddKeyTableSuite extends SparkSuite { var variantSignature = TStruct() variantKeyNames.foreach(k => variantSignature = variantSignature.appendKey(k, TString)) - val sampleMap = vds.sampleIds.zipWithIndex.map{ case (sid, i) => - (sid, Annotation(sampleGroups.map(_(i).orNull).toSeq : _*)) + val sampleMap = vds.sampleIds.zipWithIndex.map { case (sid, i) => + (sid, Annotation(sampleGroups.map(_ (i).orNull).toSeq: _*)) }.toMap val variantAnnotations = sc.parallelize(vds.variants.collect().zipWithIndex.map { case (v, i) => @@ -215,7 +55,7 @@ class AddKeyTableSuite extends SparkSuite { val (_, sampleKeyQuery) = s.vds.querySA("sa.keys.*") val (_, variantKeyQuery) = s.vds.queryVA("va.keys.*") - val keyGenotypeRDD = s.vds.mapWithAll{case (v, va, sid, sa, g) => + val keyGenotypeRDD = s.vds.mapWithAll { case (v, va, sid, sa, g) => val key = sampleKeyQuery(sa).get.asInstanceOf[IndexedSeq[String]] ++ variantKeyQuery(va).get.asInstanceOf[IndexedSeq[String]] (key, g) } @@ -228,9 +68,21 @@ class AddKeyTableSuite extends SparkSuite { "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", "-o", outputFile)) - val keyTableResults = getKeyTableResults(outputFile, keyNames) + val ktr = hadoopConf.readLines(outputFile)(_.map(_.map { line => + line.trim.split("\\s+") + 
}.value).toIndexedSeq) + + val header = ktr.take(1) + + val keyTableResults = ktr.drop(1).map(r => (header(0), r).zipped.toMap) + .map { x => (keyNames.map { k => x(k) }, (x("nHet"), x("nCalled"), x("nTotal"))) }.toMap - keyTableEqualAnnExpr(result, keyTableResults) + result.forall { case (keys, (nHet, nCalled, nTotal)) => + val (ktHet, ktCalled, ktTotal) = keyTableResults(keys.map(k => if (k != null) k else "NA").toIndexedSeq) + ktHet.toLong == nHet && + ktCalled.toLong == nCalled && + ktTotal.toLong == nTotal + } } } From 1d792dd505eb53ec45bd64b19cca77d95954e8b3 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 25 Oct 2016 00:25:52 -0400 Subject: [PATCH 15/51] removed exportaggregator --- .../hail/driver/ExportAggregate.scala | 170 ------------------ 1 file changed, 170 deletions(-) delete mode 100644 src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala deleted file mode 100644 index cb49c6944c3..00000000000 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportAggregate.scala +++ /dev/null @@ -1,170 +0,0 @@ -package org.broadinstitute.hail.driver - -import org.broadinstitute.hail.utils._ -import org.broadinstitute.hail.annotations._ -import org.broadinstitute.hail.expr._ -import org.broadinstitute.hail.utils.{MultiArray2} -import org.broadinstitute.hail.variant._ -import org.kohsuke.args4j.{Option => Args4jOption} - -object ExportAggregate extends Command { - - class Options extends BaseOptions { - - @Args4jOption(required = false, name = "-o", aliases = Array("--output"), - usage = "path of output file") - var output: String = _ - - @Args4jOption(required = true, name = "-k", aliases = Array("--key-condition"), - usage = "named expression for which keys to aggregate on (variant and sample)") - var keyCondition: String = _ - - @Args4jOption(required = true, name = "-a", usage = "named 
expression for item to compute") - var aggCondition: String = _ - } - - def newOptions = new Options - - def name = "exportaggregate" - - def description = "Aggregate and export samples information grouped by a given variant annnotation" - - def supportsMultiallelic = true - - def requiresVDS = true - - def run(state: State, options: Options): State = { - val vds = state.vds - val sc = vds.sparkContext - val keyCond = options.keyCondition - val aggCond = options.aggCondition - val output = options.output - val vas = vds.vaSignature - val sas = vds.saSignature - val localSamplesBc = vds.sampleIdsBc - val localAnnotationsBc = vds.sampleAnnotationsBc - - val aggregationEC = EvalContext(Map( - "v" -> (0, TVariant), - "va" -> (1, vds.vaSignature), - "s" -> (2, TSample), - "sa" -> (3, vds.saSignature), - "global" -> (4, vds.globalSignature))) - - val ec = EvalContext(Map( - "v" -> (0, TVariant), - "va" -> (1, vds.vaSignature), - "s" -> (2, TSample), - "sa" -> (3, vds.saSignature), - "global" -> (4, vds.globalSignature), - "gs" -> (-1, BaseAggregable(aggregationEC, TGenotype)))) - - aggregationEC.set(4, vds.globalAnnotation) - ec.set(4, vds.globalAnnotation) - - val (aggNames, aggTypes, aggF) = Parser.parseNamedArgs(aggCond, ec) - - if (aggNames.isEmpty) - fatal("need at least 1 aggregation argument") - - val aggregators = aggregationEC.aggregationFunctions.toArray - val aggregatorA = aggregationEC.a - val nAggregations = aggregators.length - - val keyParseResult = Parser.parseNamedArgs(keyCond, ec) - - val sampleGroups = vds.sampleIdsAndAnnotations.map { case (s, sa) => - ec.set(2, s) - ec.set(3, sa) - - keyParseResult._3.apply().toIndexedSeq - } - - // val variantGroupEC = EvalContext( Map( - // "v" -> (0, TVariant), - // "va" -> (1, vds.vaSignature), - // "global" -> (2, vds.globalSignature))) - // variantGroupEC.set(2,vds.globalSignature) - // - // val variantGroupParseResult = Parser.parseNamedArgs(options.byV ,variantGroupEC) - - val distinctSampleGroupMap = 
sampleGroups.distinct.zipWithIndex.toMap - val siToGroupIndex = sampleGroups.map(distinctSampleGroupMap) - val nSampleGroups = distinctSampleGroupMap.size - - def zero() = { - val baseArray = MultiArray2.fill[Aggregator](nSampleGroups, nAggregations)(null) - for (i <- 0 until nSampleGroups; j <- 0 until nAggregations) { - baseArray.update(i, j, aggregators(j).copy()) - } - baseArray - } - - val mapOp : (Variant, Annotation) => IndexedSeq[Any] = {case (v, va) => - ec.set(0, v) - ec.set(1, va) - keyParseResult._3.apply().toIndexedSeq - } - - val seqOp : (MultiArray2[Aggregator], (Variant, (Annotation, Iterable[Genotype]))) => MultiArray2[Aggregator] = { - case (arr, (v, (va, gs))) => - aggregatorA(0) = v - aggregatorA(1) = va - for ((g, i) <- gs.zipWithIndex) - for (j <- 0 until nAggregations) { - aggregatorA(2) = localSamplesBc.value(i) - aggregatorA(3) = localAnnotationsBc.value(i) - val sampleGroup = siToGroupIndex(i) - arr(sampleGroup, j).seqOp(g) - } - - arr - } - - val combOp : (MultiArray2[Aggregator], MultiArray2[Aggregator]) => MultiArray2[Aggregator] = { - case (arr1, arr2) => - for ((i, j) <- arr1.indices) { - val a1 = arr1(i, j) - a1.combOp(arr2(i, j).asInstanceOf[a1.type]) - } - arr1 - } - - val res = vds.rdd.map { case (v, (va, gs)) => (mapOp(v, va), (v, (va, gs))) } - .aggregateByKey(zero())(seqOp, combOp) - -// res.map{case (key, agg) => key.mkString(",")}.collect().foreach(println(_)) - -// -// def getLine(sampleGroupIndex: Integer, values: MultiArray2[Any], sb:StringBuilder) : String = { -// for (j <- 0 until nAggregations) { -// aggregatorA(aggregators(j).idx) = values(sampleGroupIndex, j) -// } -// -// aggregationParseResult.foreachBetween { case (t, f) => -// sb.append(f().map(TableAnnotationImpex.exportAnnotation(_, t)).getOrElse("NA")) -// } { sb += '\t' } -// sb.result() -// } -// -// res.map({ -// case (variantGroup, values) => -// -// val sb = new StringBuilder() -// val lines = for ((sampleGroup, i) <- 
distinctSampleGroupMap.keys.zipWithIndex) yield { -// sb.clear() -// sb.append(sampleGroup.map(_.getOrElse("NA").toString).mkString("\t") + "\t") -// getLine(i,values,sb) -// } -// lines.map(variantGroup.map(_.getOrElse("NA").toString).mkString("\t") + "\t" + _).mkString("\n") -// }) -// .writeTable(options.output, -// header = Some(variantGroupParseResult.map(_._1).mkString("\t") + "\t" + -// sampleGroupsParseResult.map(_._1).mkString("\t") + "\t" + -// aggregationHeader.mkString("\t"))) -// - -// - state - } -} From 8fe6718f179984df7aab450fa84abbd97b5a09ab Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 25 Oct 2016 12:15:48 -0400 Subject: [PATCH 16/51] removed explicit casting to IndexedSeq --- src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala index 327be889f2b..365aba3c981 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala @@ -88,7 +88,7 @@ object AddKeyTable extends Command { vds.mapPartitionsWithAll { it => it.map { case (v, va, s, sa, g) => ec.setAll(v, va, s, sa, g) - val key = keyF().toIndexedSeq + val key: IndexedSeq[String] = keyF() (key, (v, va, s, sa, g)) } }.aggregateByKey(zvf())(seqOp, combOp) From 34443f3f0510b2e09aeba9e45359743686456012 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 25 Oct 2016 17:20:06 -0400 Subject: [PATCH 17/51] export and kt in state --- .../hail/driver/AddKeyTable.scala | 22 +++-- .../broadinstitute/hail/driver/ClearKT.scala | 28 ++++++ .../broadinstitute/hail/driver/Command.scala | 4 +- .../hail/driver/ExportKeyTable.scala | 92 +++++++++++++++++++ .../hail/driver/ExportVariants.scala | 1 + .../hail/keytable/KeyTable.scala | 9 ++ .../hail/methods/AddKeyTableSuite.scala | 5 +- 7 files changed, 149 
insertions(+), 12 deletions(-) create mode 100644 src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala create mode 100644 src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala create mode 100644 src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala diff --git a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala index 365aba3c981..f2e277265d2 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala @@ -1,6 +1,8 @@ package org.broadinstitute.hail.driver +import org.broadinstitute.hail.annotations.Annotation import org.broadinstitute.hail.expr._ +import org.broadinstitute.hail.keytable.KeyTable import org.broadinstitute.hail.methods.Aggregators import org.broadinstitute.hail.utils._ import org.kohsuke.args4j.{Option => Args4jOption} @@ -16,9 +18,9 @@ object AddKeyTable extends Command { usage = "Named aggregation condition", metaVar = "EXPR") var aggCond: String = _ - @Args4jOption(required = true, name = "-o", aliases = Array("--output"), - usage = "output file") - var output: String = _ + @Args4jOption(required = true, name = "-n", aliases = Array("--name"), + usage = "Name of new key table") + var name: String = _ } def newOptions = new Options @@ -65,6 +67,9 @@ object AddKeyTable extends Command { val (keyNames, keyParseTypes, keyF) = Parser.parseNamedArgs(keyCond, ec) val (aggNames, aggParseTypes, aggF) = Parser.parseNamedArgs(aggCond, ec) + val keySignature = TStruct(keyNames.zip(keyParseTypes): _*) + val aggSignature = TStruct(aggNames.zip(aggParseTypes): _*) + if (keyNames.isEmpty) fatal("this module requires one or more named expr arguments as keys") if (aggNames.isEmpty) @@ -85,19 +90,18 @@ object AddKeyTable extends Command { array } - vds.mapPartitionsWithAll { it => + val kt = KeyTable(vds.mapPartitionsWithAll { it => it.map { case (v, va, s, sa, g) => 
ec.setAll(v, va, s, sa, g) - val key: IndexedSeq[String] = keyF() + val key = Annotation.fromSeq(keyF()) (key, (v, va, s, sa, g)) } }.aggregateByKey(zvf())(seqOp, combOp) .map { case (k, agg) => resultOp(agg) - (k ++ aggF()).mkString("\t") - } - .writeTable(options.output, Option((keyNames ++ aggNames).mkString("\t"))) + (k, Annotation.fromSeq(aggF())) + }, keySignature, aggSignature) - state + state.copy(ktEnv = state.ktEnv + (options.name -> kt)) } } diff --git a/src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala b/src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala new file mode 100644 index 00000000000..edf1b34a864 --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala @@ -0,0 +1,28 @@ +package org.broadinstitute.hail.driver + +import org.kohsuke.args4j.{Option => Args4jOption} + +object ClearKT extends Command { + + class Options extends BaseOptions { + @Args4jOption(required = true, name = "-n", aliases = Array("--name"), + usage = "Name of key table to clear") + var name: String = _ + } + + def newOptions = new Options + + def name = "ktclear" + + def description = "Clear key table from environment" + + def supportsMultiallelic = true + + def requiresVDS = false + + def run(state: State, options: Options): State = { + val name = options.name + state.copy( + ktEnv = state.ktEnv - name) + } +} diff --git a/src/main/scala/org/broadinstitute/hail/driver/Command.scala b/src/main/scala/org/broadinstitute/hail/driver/Command.scala index fdc6caa51a3..447059923e6 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/Command.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/Command.scala @@ -2,6 +2,7 @@ package org.broadinstitute.hail.driver import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext +import org.broadinstitute.hail.keytable.KeyTable import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.variant.VariantDataset import org.kohsuke.args4j.{Argument, CmdLineException, 
CmdLineParser, Option => Args4jOption} @@ -13,7 +14,8 @@ case class State(sc: SparkContext, sqlContext: SQLContext, // FIXME make option vds: VariantDataset = null, - env: Map[String, VariantDataset] = Map.empty) { + env: Map[String, VariantDataset] = Map.empty, + ktEnv: Map[String, KeyTable] = Map.empty) { def hadoopConf = sc.hadoopConfiguration } diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala new file mode 100644 index 00000000000..75637b92b85 --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala @@ -0,0 +1,92 @@ +package org.broadinstitute.hail.driver + +import org.broadinstitute.hail.utils._ +import org.broadinstitute.hail.expr.{EvalContext, _} +import org.broadinstitute.hail.io.TextExporter +import org.kohsuke.args4j.{Option => Args4jOption} + +object ExportKeyTable extends Command with TextExporter { + + class Options extends BaseOptions { + + @Args4jOption(required = true, name = "-o", aliases = Array("--output"), + usage = "path of output tsv") + var output: String = _ + + @Args4jOption(required = true, name = "-c", aliases = Array("--condition"), + usage = ".columns file, or comma-separated list of fields/computations to be printed to tsv") + var condition: String = _ + + @Args4jOption(required = true, name = "-n", aliases = Array("--name"), + usage = "name of key table to be printed to tsv") + var name: String = _ + + @Args4jOption(required = false, name = "-t", aliases = Array("--types"), + usage = "Write the types of parse expressions to a file at the given path") + var typesFile: String = _ + + } + + def newOptions = new Options + + def name = "ktexport" + + def description = "Export information from key table to tsv" + + def supportsMultiallelic = true + + def requiresVDS = false + + override def hidden = true + + def run(state: State, options: Options): State = { + + val kt = state.ktEnv.get(options.name) match { + 
case Some(newKT) => + newKT + case None => + fatal("no such key table $name in environment") + } + + val ks = kt.keySignature + val vs = kt.valueSignature + + val cond = options.condition + val output = options.output + + val symTab = Map( + "k" -> (0, kt.keySignature), + "ka" -> (1, kt.valueSignature), + "global" -> (2, state.vds.globalSignature) + ) + + val ec = EvalContext(symTab) + + val (header, types, f) = Parser.parseExportArgs(cond, ec) + + Option(options.typesFile).foreach { file => + val typeInfo = header + .getOrElse(types.indices.map(i => s"_$i").toArray) + .zip(types) + exportTypes(file, state.hadoopConf, typeInfo) + } + + state.hadoopConf.delete(output, recursive = true) + + kt.rdd + .mapPartitions { it => + val sb = new StringBuilder() + it.map { case (k, v) => + sb.clear() + + ec.setAll(k, v) + + f().foreachBetween(x => sb.append(x))(sb += '\t') + sb.result() + } + }.writeTable(output, header.map(_.mkString("\t"))) + + state + } +} + diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala index 0e2aaaa7b20..256397b33c4 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala @@ -48,6 +48,7 @@ object ExportVariants extends Command with TextExporter { "sa" -> (3, vds.saSignature), "g" -> (4, TGenotype), "global" -> (5, vds.globalSignature))) + val symTab = Map( "v" -> (0, TVariant), "va" -> (1, vds.vaSignature), diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala new file mode 100644 index 00000000000..dedee733a47 --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -0,0 +1,9 @@ +package org.broadinstitute.hail.keytable + +import org.apache.spark.rdd.RDD +import org.broadinstitute.hail.annotations._ +import org.broadinstitute.hail.expr.Type 
+import org.broadinstitute.hail.utils._ + + +case class KeyTable (rdd: RDD[(Annotation, Annotation)], keySignature: Type, valueSignature: Type) diff --git a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala index a9ed5dfb2a1..a2886f3b17f 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala @@ -65,8 +65,9 @@ class AddKeyTableSuite extends SparkSuite { (comb1, comb2) => (comb1._1 + comb2._1, comb1._2 + comb2._2, comb1._3 + comb2._3)).collectAsMap() s = AddKeyTable.run(s, Array("-k", (sampleKeyNames.map(k => k + " = " + "sa.keys." + k) ++ variantKeyNames.map(k => k + " = " + "va.keys." + k)).mkString(","), - "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", - "-o", outputFile)) + "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", "-n", "foo")) + + s = ExportKeyTable.run(s, Array("-o", outputFile, "-c", "k.*, ka.*", "-n", "foo")) val ktr = hadoopConf.readLines(outputFile)(_.map(_.map { line => line.trim.split("\\s+") From 3ed00b3c338e6f3c2b1027a281b879ed872a6e3f Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 25 Oct 2016 17:22:25 -0400 Subject: [PATCH 18/51] made command for clear kt hidden --- src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala b/src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala index edf1b34a864..f5ea193f71a 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala @@ -20,6 +20,8 @@ object ClearKT extends Command { def requiresVDS = false + override def hidden = true + def run(state: State, options: Options): 
State = { val name = options.name state.copy( From 60378da1e3c97a682efc53746bb4541222682bca Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 1 Nov 2016 15:49:14 -0400 Subject: [PATCH 19/51] started filterExpr --- ...AddKeyTable.scala => AggregateByKey.scala} | 12 +--- .../broadinstitute/hail/driver/Command.scala | 2 +- .../hail/driver/ExportKeyTable.scala | 21 ++----- .../hail/driver/FilterKeyTableExpr.scala | 60 +++++++++++++++++++ .../hail/keytable/KeyTable.scala | 32 +++++++++- ...eyTableSuite.scala => KeyTableSuite.scala} | 6 +- 6 files changed, 101 insertions(+), 32 deletions(-) rename src/main/scala/org/broadinstitute/hail/driver/{AddKeyTable.scala => AggregateByKey.scala} (88%) create mode 100644 src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala rename src/test/scala/org/broadinstitute/hail/methods/{AddKeyTableSuite.scala => KeyTableSuite.scala} (93%) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala similarity index 88% rename from src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala rename to src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala index f2e277265d2..d59dc993702 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AddKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala @@ -4,12 +4,11 @@ import org.broadinstitute.hail.annotations.Annotation import org.broadinstitute.hail.expr._ import org.broadinstitute.hail.keytable.KeyTable import org.broadinstitute.hail.methods.Aggregators -import org.broadinstitute.hail.utils._ import org.kohsuke.args4j.{Option => Args4jOption} -object AddKeyTable extends Command { +object AggregateByKey extends Command { - class Options extends BaseOptions with TextTableOptions { + class Options extends BaseOptions { @Args4jOption(required = true, name = "-k", aliases = Array("--key-cond"), usage = "Named key condition", metaVar = 
"EXPR") var keyCond: String = _ @@ -25,7 +24,7 @@ object AddKeyTable extends Command { def newOptions = new Options - def name = "addkeytable" + def name = "aggregatebykey" def description = "Creates a new key table with key(s) determined by named expressions and additional columns determined by named aggregator expressions" @@ -70,11 +69,6 @@ object AddKeyTable extends Command { val keySignature = TStruct(keyNames.zip(keyParseTypes): _*) val aggSignature = TStruct(aggNames.zip(aggParseTypes): _*) - if (keyNames.isEmpty) - fatal("this module requires one or more named expr arguments as keys") - if (aggNames.isEmpty) - fatal("this module requires one or more named expr arguments to aggregate by key") - val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) val zvf = () => zVals.indices.map(zVals).toArray diff --git a/src/main/scala/org/broadinstitute/hail/driver/Command.scala b/src/main/scala/org/broadinstitute/hail/driver/Command.scala index 447059923e6..40b1efe2097 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/Command.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/Command.scala @@ -47,7 +47,7 @@ object ToplevelCommands { + cmd.description)) } - register(AddKeyTable) + register(AggregateByKey) register(AggregateIntervals) register(AnnotateSamples) register(AnnotateVariants) diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala index 75637b92b85..61a5abaf03c 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala @@ -13,10 +13,6 @@ object ExportKeyTable extends Command with TextExporter { usage = "path of output tsv") var output: String = _ - @Args4jOption(required = true, name = "-c", aliases = Array("--condition"), - usage = ".columns file, or comma-separated list of fields/computations to be printed to tsv") - var condition: 
String = _ - @Args4jOption(required = true, name = "-n", aliases = Array("--name"), usage = "name of key table to be printed to tsv") var name: String = _ @@ -29,9 +25,9 @@ object ExportKeyTable extends Command with TextExporter { def newOptions = new Options - def name = "ktexport" + def name = "exportkeytable" - def description = "Export information from key table to tsv" + def description = "Export key table to tsv" def supportsMultiallelic = true @@ -48,21 +44,14 @@ object ExportKeyTable extends Command with TextExporter { fatal("no such key table $name in environment") } - val ks = kt.keySignature - val vs = kt.valueSignature - - val cond = options.condition val output = options.output - val symTab = Map( - "k" -> (0, kt.keySignature), - "ka" -> (1, kt.valueSignature), - "global" -> (2, state.vds.globalSignature) - ) + val symTab = Map("k" -> (0, kt.keySignature), + "v" -> (1, kt.valueSignature)) val ec = EvalContext(symTab) - val (header, types, f) = Parser.parseExportArgs(cond, ec) + val (header, types, f) = Parser.parseExportArgs("k.*, v.*", ec) Option(options.typesFile).foreach { file => val typeInfo = header diff --git a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala new file mode 100644 index 00000000000..64344c0139c --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala @@ -0,0 +1,60 @@ +package org.broadinstitute.hail.driver + + +import org.broadinstitute.hail.expr.EvalContext +import org.broadinstitute.hail.utils._ +import org.kohsuke.args4j.{Option => Args4jOption} + +object FilterKeyTableExpr extends Command { + class Options extends BaseOptions { + @Args4jOption(required = true, name = "-c", aliases = Array("--cond"), + usage = "Boolean expression for filtering", metaVar = "EXPR") + var condition: String = _ + + @Args4jOption(required = true, name = "-n", aliases = Array("--name"), + usage = "Name of source key 
table") + var name: String = _ + + @Args4jOption(required = false, name = "-d", aliases = Array("--dest"), + usage = "Name of destination key table (can be same as source)") + var dest: String = _ + + @Args4jOption(required = false, name = "--keep", usage = "Keep variants matching condition") + var keep: Boolean = false + + @Args4jOption(required = false, name = "--remove", usage = "Remove variants matching condition") + var remove: Boolean = false + } + + def newOptions = new Options + + def name = "filterkeytable expr" + + def description = "Filter key table using a boolean expression" + + def supportsMultiallelic = true + + def requiresVDS = true + + override def hidden = true + + def run(state: State, options: Options): State = { + val kt = state.ktEnv.get(options.name) match { + case Some(newKT) => + newKT + case None => + fatal(s"no such key table `${options.name}' in environment") + } + + if (!(options.keep ^ options.remove)) + fatal("either `--keep' or `--remove' required, but not both") + + val cond = options.condition + val keep = options.keep + val dest = options.dest + + state.copy(ktEnv = state.ktEnv + ( dest -> kt.filterRowsExpr(cond, keep))) + + state + } +} diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index dedee733a47..4f7dbc707e1 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -2,8 +2,34 @@ package org.broadinstitute.hail.keytable import org.apache.spark.rdd.RDD import org.broadinstitute.hail.annotations._ -import org.broadinstitute.hail.expr.Type -import org.broadinstitute.hail.utils._ +import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TStruct} +import org.broadinstitute.hail.methods.Filter +case class KeyTable (rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, valueSignature: TStruct) { -case class KeyTable (rdd: RDD[(Annotation, 
Annotation)], keySignature: Type, valueSignature: Type) + val fieldNames = (keySignature.fields ++ valueSignature.fields).map(_.name) + + assert(fieldNames.distinct.length == fieldNames.length) + + def queryKey(code: String): (BaseType, Querier) = ??? + + def queryValue(code: String): (BaseType, Querier) = ??? + + def filter(p: (Annotation, Annotation) => Boolean): KeyTable = copy(rdd = rdd.filter{ case (k, v) => p(k, v)}) + + def filterRowsExpr(cond: String, keep: Boolean): KeyTable = { + val symTab = (keySignature.fields ++ valueSignature.fields) + .zipWithIndex.map{case (fd, i) => (fd.name, (i, fd.`type`))}.toMap + + val ec = EvalContext(symTab) + + val f: () => Option[Boolean] = Parser.parse[Boolean](cond, ec, TBoolean) + + val p = (k: Annotation, v: Annotation) => { + ec.setAll(k, v) + Filter.keepThis(f(), keep) + } + + filter(p) + } +} diff --git a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala similarity index 93% rename from src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala rename to src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index a2886f3b17f..98fe127a7fb 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/AddKeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -10,7 +10,7 @@ import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.variant.{Genotype, VSMSubgen, VariantDataset, VariantSampleMatrix} import org.testng.annotations.Test -class AddKeyTableSuite extends SparkSuite { +class KeyTableSuite extends SparkSuite { def createKey(nItems: Int, nCategories: Int) = Gen.buildableOfN[Array, Option[String]](nItems, Gen.option(Gen.oneOfSeq((0 until nCategories).map("group" + _)), 0.95)) @@ -64,10 +64,10 @@ class AddKeyTableSuite extends SparkSuite { (comb, gt) => (comb._1 + gt.isHet.toInt.toInt, comb._2 + gt.isCalled.toInt.toInt, comb._3 + 1), (comb1, comb2) => 
(comb1._1 + comb2._1, comb1._2 + comb2._2, comb1._3 + comb2._3)).collectAsMap() - s = AddKeyTable.run(s, Array("-k", (sampleKeyNames.map(k => k + " = " + "sa.keys." + k) ++ variantKeyNames.map(k => k + " = " + "va.keys." + k)).mkString(","), + s = AggregateByKey.run(s, Array("-k", (sampleKeyNames.map(k => k + " = " + "sa.keys." + k) ++ variantKeyNames.map(k => k + " = " + "va.keys." + k)).mkString(","), "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", "-n", "foo")) - s = ExportKeyTable.run(s, Array("-o", outputFile, "-c", "k.*, ka.*", "-n", "foo")) + s = ExportKeyTable.run(s, Array("-o", outputFile, "-n", "foo")) val ktr = hadoopConf.readLines(outputFile)(_.map(_.map { line => line.trim.split("\\s+") From 488a2a61a1da4029000e727764c1f5472539e644 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 1 Nov 2016 16:14:51 -0400 Subject: [PATCH 20/51] filter works with key table being a pair rdd --- .../hail/driver/FilterKeyTableExpr.scala | 6 ++---- .../org/broadinstitute/hail/keytable/KeyTable.scala | 12 +++++------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala index 64344c0139c..4659c284997 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala @@ -1,7 +1,5 @@ package org.broadinstitute.hail.driver - -import org.broadinstitute.hail.expr.EvalContext import org.broadinstitute.hail.utils._ import org.kohsuke.args4j.{Option => Args4jOption} @@ -34,7 +32,7 @@ object FilterKeyTableExpr extends Command { def supportsMultiallelic = true - def requiresVDS = true + def requiresVDS = false override def hidden = true @@ -53,7 +51,7 @@ object FilterKeyTableExpr extends Command { val keep = options.keep val dest = options.dest - state.copy(ktEnv = 
state.ktEnv + ( dest -> kt.filterRowsExpr(cond, keep))) + state.copy(ktEnv = state.ktEnv + ( dest -> kt.filterExpr(cond, keep))) state } diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 4f7dbc707e1..8a72614b41e 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -2,22 +2,20 @@ package org.broadinstitute.hail.keytable import org.apache.spark.rdd.RDD import org.broadinstitute.hail.annotations._ -import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TStruct} +import org.broadinstitute.hail.expr.{EvalContext, Parser, TBoolean, TStruct} import org.broadinstitute.hail.methods.Filter case class KeyTable (rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, valueSignature: TStruct) { val fieldNames = (keySignature.fields ++ valueSignature.fields).map(_.name) - assert(fieldNames.distinct.length == fieldNames.length) + require(fieldNames.distinct.length == fieldNames.length) - def queryKey(code: String): (BaseType, Querier) = ??? - - def queryValue(code: String): (BaseType, Querier) = ??? 
+ def length = rdd.count() def filter(p: (Annotation, Annotation) => Boolean): KeyTable = copy(rdd = rdd.filter{ case (k, v) => p(k, v)}) - def filterRowsExpr(cond: String, keep: Boolean): KeyTable = { + def filterExpr(cond: String, keep: Boolean): KeyTable = { val symTab = (keySignature.fields ++ valueSignature.fields) .zipWithIndex.map{case (fd, i) => (fd.name, (i, fd.`type`))}.toMap @@ -26,7 +24,7 @@ case class KeyTable (rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, val f: () => Option[Boolean] = Parser.parse[Boolean](cond, ec, TBoolean) val p = (k: Annotation, v: Annotation) => { - ec.setAll(k, v) + ec.setAll(Seq(k, v): _*) Filter.keepThis(f(), keep) } From 1fb740846af1d873fa1a92aef13e08db03620fb5 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 1 Nov 2016 23:06:29 -0400 Subject: [PATCH 21/51] added import, export, filter --- .../hail/driver/AggregateByKey.scala | 9 ++- .../broadinstitute/hail/driver/Command.scala | 3 + .../hail/driver/ExportKeyTable.scala | 15 +++-- .../hail/driver/FilterKeyTable.scala | 9 +++ .../hail/driver/FilterKeyTableExpr.scala | 6 +- .../hail/driver/ImportKeyTable.scala | 67 +++++++++++++++++++ .../hail/keytable/KeyTable.scala | 23 ++++--- .../hail/methods/KeyTableSuite.scala | 50 ++++++++++++-- 8 files changed, 154 insertions(+), 28 deletions(-) create mode 100644 src/main/scala/org/broadinstitute/hail/driver/FilterKeyTable.scala create mode 100644 src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala index d59dc993702..19d8a155bef 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala @@ -66,8 +66,7 @@ object AggregateByKey extends Command { val (keyNames, keyParseTypes, keyF) = Parser.parseNamedArgs(keyCond, ec) val (aggNames, aggParseTypes, aggF) = 
Parser.parseNamedArgs(aggCond, ec) - val keySignature = TStruct(keyNames.zip(keyParseTypes): _*) - val aggSignature = TStruct(aggNames.zip(aggParseTypes): _*) + val signature = TStruct((keyNames ++ aggNames).zip(keyParseTypes ++ aggParseTypes): _*) val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) val zvf = () => zVals.indices.map(zVals).toArray @@ -87,14 +86,14 @@ object AggregateByKey extends Command { val kt = KeyTable(vds.mapPartitionsWithAll { it => it.map { case (v, va, s, sa, g) => ec.setAll(v, va, s, sa, g) - val key = Annotation.fromSeq(keyF()) + val key = keyF().toIndexedSeq (key, (v, va, s, sa, g)) } }.aggregateByKey(zvf())(seqOp, combOp) .map { case (k, agg) => resultOp(agg) - (k, Annotation.fromSeq(aggF())) - }, keySignature, aggSignature) + Annotation.fromSeq(k ++ aggF()) + }, signature, keyNames) state.copy(ktEnv = state.ktEnv + (options.name -> kt)) } diff --git a/src/main/scala/org/broadinstitute/hail/driver/Command.scala b/src/main/scala/org/broadinstitute/hail/driver/Command.scala index 40b1efe2097..6b0e3545fa2 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/Command.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/Command.scala @@ -63,6 +63,7 @@ object ToplevelCommands { register(CountBytes) register(Deduplicate) register(DownsampleVariants) + register(ExportKeyTable) register(ExportPlink) register(ExportGEN) register(ExportGenotypes) @@ -73,6 +74,7 @@ object ToplevelCommands { register(ExportVCF) register(FilterAlleles) register(FilterGenotypes) + register(FilterKeyTable) register(Filtermulti) register(FilterSamples) register(FilterVariants) @@ -87,6 +89,7 @@ object ToplevelCommands { register(ImportAnnotations) register(ImportBGEN) register(ImportGEN) + register(ImportKeyTable) register(ImportPlink) register(ImportVCF) register(ImputeSex) diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala index 
61a5abaf03c..70b07c503bc 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala @@ -1,5 +1,6 @@ package org.broadinstitute.hail.driver +import org.apache.spark.sql.Row import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.expr.{EvalContext, _} import org.broadinstitute.hail.io.TextExporter @@ -46,12 +47,11 @@ object ExportKeyTable extends Command with TextExporter { val output = options.output - val symTab = Map("k" -> (0, kt.keySignature), - "v" -> (1, kt.valueSignature)) + val symTab = kt.signature.fields.zipWithIndex.map{case (fd, i) => (fd.name, (i, fd.`type`))}.toMap val ec = EvalContext(symTab) - val (header, types, f) = Parser.parseExportArgs("k.*, v.*", ec) + val (header, types, f) = Parser.parseExportArgs(kt.fieldNames.map(n => n + " = " + n).mkString(","), ec) Option(options.typesFile).foreach { file => val typeInfo = header @@ -62,13 +62,18 @@ object ExportKeyTable extends Command with TextExporter { state.hadoopConf.delete(output, recursive = true) + val signature = kt.signature + kt.rdd .mapPartitions { it => val sb = new StringBuilder() - it.map { case (k, v) => + it.map { a => sb.clear() - ec.setAll(k, v) + Option(a).map(_.asInstanceOf[Row]) match { + case Some(r) => ec.setAll(r.toSeq: _*) + case None => ec.setAll(Seq.fill(signature.size)(null)) + } f().foreachBetween(x => sb.append(x))(sb += '\t') sb.result() diff --git a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTable.scala new file mode 100644 index 00000000000..81ccba1a1d8 --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTable.scala @@ -0,0 +1,9 @@ +package org.broadinstitute.hail.driver + +object FilterKeyTable extends SuperCommand { + def name = "filterkeytable" + + def description = "Filter key tables" + + register(FilterKeyTableExpr) +} diff --git 
a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala index 4659c284997..4df9ad68011 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala @@ -13,7 +13,7 @@ object FilterKeyTableExpr extends Command { usage = "Name of source key table") var name: String = _ - @Args4jOption(required = true, name = "-d", aliases = Array("--dest"), + @Args4jOption(required = false, name = "-d", aliases = Array("--dest"), usage = "Name of destination key table (can be same as source)") var dest: String = _ @@ -49,10 +49,8 @@ object FilterKeyTableExpr extends Command { val cond = options.condition val keep = options.keep - val dest = options.dest + val dest = if (options.dest != null) options.dest else options.name state.copy(ktEnv = state.ktEnv + ( dest -> kt.filterExpr(cond, keep))) - - state } } diff --git a/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala new file mode 100644 index 00000000000..148b63be6ab --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala @@ -0,0 +1,67 @@ +package org.broadinstitute.hail.driver + +import org.broadinstitute.hail.expr.Parser +import org.broadinstitute.hail.keytable.KeyTable +import org.broadinstitute.hail.utils._ +import org.kohsuke.args4j.{Argument, Option => Args4jOption} + +import scala.collection.JavaConverters._ + +object ImportKeyTable extends Command { + + class Options extends BaseOptions with TextTableOptions { + @Argument(usage = "") + var arguments: java.util.ArrayList[String] = new java.util.ArrayList[String]() + + @Args4jOption(required = true, name = "-n", aliases = Array("--name"), + usage = "name of key table") + var name: String = _ + + @Args4jOption(required = true, name = "-k", aliases = Array("--key-names"), + 
usage = "comma-separated list of columns to be considered as keys") + var keyNames: String = _ + + @Args4jOption(name = "--npartition", usage = "Number of partitions") + var nPartitions: java.lang.Integer = _ + } + + def newOptions = new Options + + def name = "importkeytable" + + def description = "import key table from tsv" + + def supportsMultiallelic = true + + def requiresVDS = false + + override def hidden = true + + def run(state: State, options: Options): State = { + val files = state.hadoopConf.globAll(options.arguments.asScala) + if (files.isEmpty) + fatal("Arguments referred to no files") + + val keyNames = Parser.parseIdentifierList(options.keyNames) + + val (struct, rdd) = + if (options.nPartitions != null) { + if (options.nPartitions < 1) + fatal("requested number of partitions in -n/--npartitions must be positive") + TextTableReader.read(state.sc)(files, options.config, options.nPartitions) + } else + TextTableReader.read(state.sc)(files, options.config) + + val keyNamesValid = keyNames.forall{ k => + val res = struct.selfField(k).isDefined + if (!res) + println(s"Key `$k' is not present in input table") + res + } + if (!keyNamesValid) + fatal("Invalid key names given") + + state.copy(ktEnv = state.ktEnv + (options.name -> KeyTable(rdd.map(_.value), struct, keyNames))) + } +} + diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 8a72614b41e..6483b2a2716 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -1,30 +1,33 @@ package org.broadinstitute.hail.keytable import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row import org.broadinstitute.hail.annotations._ import org.broadinstitute.hail.expr.{EvalContext, Parser, TBoolean, TStruct} import org.broadinstitute.hail.methods.Filter -case class KeyTable(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[String]) { -case class KeyTable (rdd: RDD[(Annotation, Annotation)], keySignature: 
TStruct, valueSignature: TStruct) { +case class KeyTable(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[String]) { - val fieldNames = (keySignature.fields ++ valueSignature.fields).map(_.name) + val fieldNames = signature.fields.map(_.name) require(fieldNames.distinct.length == fieldNames.length) + require(keyNames.forall(k => signature.selfField(k).isDefined)) - def length = rdd.count() + def nRows = rdd.count() - def filter(p: (Annotation, Annotation) => Boolean): KeyTable = copy(rdd = rdd.filter{ case (k, v) => p(k, v)}) + def filter(p: (Annotation) => Boolean): KeyTable = copy(rdd = rdd.filter { a => p(a) }) def filterExpr(cond: String, keep: Boolean): KeyTable = { - val symTab = (keySignature.fields ++ valueSignature.fields) - .zipWithIndex.map{case (fd, i) => (fd.name, (i, fd.`type`))}.toMap - - val ec = EvalContext(symTab) + val ec = EvalContext(signature.fields.map(f => (f.name, f.`type`)): _*) val f: () => Option[Boolean] = Parser.parse[Boolean](cond, ec, TBoolean) - val p = (k: Annotation, v: Annotation) => { - ec.setAll(Seq(k, v): _*) + val p = (a: Annotation) => { + Option(a).map(_.asInstanceOf[Row]) match { + case Some(r) => ec.setAll(r.toSeq: _*) + case None => ec.setAll(Seq.fill(signature.size)(null)) + } + Filter.keepThis(f(), keep) } diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index 98fe127a7fb..e6b3d0e16b8 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -6,12 +6,53 @@ import org.broadinstitute.hail.check.Prop._ import org.broadinstitute.hail.check.{Gen, Properties} import org.broadinstitute.hail.driver._ import org.broadinstitute.hail.expr._ +import org.broadinstitute.hail.keytable.KeyTable import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.variant.{Genotype, VSMSubgen, VariantDataset, VariantSampleMatrix} 
import org.testng.annotations.Test class KeyTableSuite extends SparkSuite { + @Test def testImportExport() = { + val inputFile = "src/test/resources/sampleAnnotations.tsv" + val outputFile = tmpDir.createTempFile("ktImpExp", "tsv") + var s = State(sc, sqlContext) + s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample, Status", inputFile)) + s = ExportKeyTable.run(s, Array("-n", "kt1", "-o", outputFile)) + + val importedData = sc.hadoopConfiguration.readLines(inputFile)(_.map(_.value).toIndexedSeq) + val exportedData = sc.hadoopConfiguration.readLines(outputFile)(_.map(_.value).toIndexedSeq) + + intercept[FatalException] { + s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample, Status, BadKeyName", inputFile)) + } + + assert(importedData == exportedData) + } + + @Test def testFilter() = { + val data = Array(Array(5, 9, 0), Array(2, 3, 4), Array(1, 2, 3)) + val rdd = sc.parallelize(data.map(Annotation.fromSeq(_))) + val signature = TStruct(("field1", TInt), ("field2", TInt), ("field3", TInt)) + val keyNames = Array("field1") + val kt = KeyTable(rdd, signature, keyNames) + + var s = State(sc, sqlContext, ktEnv = Map("kt1" -> kt)) + + s = FilterKeyTableExpr.run(s, Array("-n", "kt1", "-c", "field1 < 3", "-d", "kt2", "--keep")) + assert(s.ktEnv.contains("kt2") && s.ktEnv("kt2").nRows == 2) + + s = FilterKeyTableExpr.run(s, Array("-n", "kt1", "-c", "field1 < 3 && field3 == 4", "-d", "kt3", "--keep")) + assert(s.ktEnv.contains("kt3") && s.ktEnv("kt3").nRows == 1) + + s = FilterKeyTableExpr.run(s, Array("-n", "kt1", "-c", "field1 == 5 && field2 == 9 && field3 == 0", "-d", "kt3", "--remove")) + assert(s.ktEnv.contains("kt3") && s.ktEnv("kt3").nRows == 2) + + s = FilterKeyTableExpr.run(s, Array("-n", "kt1", "-c", "field1 < -5 && field3 == 100", "--keep")) + assert(s.ktEnv.contains("kt1") && s.ktEnv("kt1").nRows == 0) + + } + def createKey(nItems: Int, nCategories: Int) = Gen.buildableOfN[Array, Option[String]](nItems, Gen.option(Gen.oneOfSeq((0 until 
nCategories).map("group" + _)), 0.95)) @@ -68,6 +109,7 @@ class KeyTableSuite extends SparkSuite { "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", "-n", "foo")) s = ExportKeyTable.run(s, Array("-o", outputFile, "-n", "foo")) + println(outputFile) val ktr = hadoopConf.readLines(outputFile)(_.map(_.map { line => line.trim.split("\\s+") @@ -86,9 +128,9 @@ class KeyTableSuite extends SparkSuite { } } } - - @Test def testAddKeyTable() { - Spec.check() - } +// +// @Test def testAddKeyTable() { +// Spec.check() +// } } From 3dc5ffb6bd320b7c73afe2f958a1b3dcb14a9b6f Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Wed, 2 Nov 2016 12:37:34 -0400 Subject: [PATCH 22/51] started writing join --- .../hail/driver/AggregateByKey.scala | 7 +- .../hail/driver/JoinKeyTable.scala | 72 ++++++++ .../hail/expr/JoinAnnotator.scala | 2 +- .../hail/keytable/KeyTable.scala | 26 ++- .../hail/driver/AggregateByKeySuite.scala | 170 ++++++++++++++++++ .../hail/methods/KeyTableSuite.scala | 84 --------- 6 files changed, 274 insertions(+), 87 deletions(-) create mode 100644 src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala create mode 100644 src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala index 19d8a155bef..9ea18ca9a22 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala @@ -66,6 +66,11 @@ object AggregateByKey extends Command { val (keyNames, keyParseTypes, keyF) = Parser.parseNamedArgs(keyCond, ec) val (aggNames, aggParseTypes, aggF) = Parser.parseNamedArgs(aggCond, ec) + val (testPT, testF) = Parser.parseAnnotationArgs(aggCond, ec) + + println(testPT.mkString("\n")) + + val signature = TStruct((keyNames ++ aggNames).zip(keyParseTypes ++ 
aggParseTypes): _*) val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) @@ -86,7 +91,7 @@ object AggregateByKey extends Command { val kt = KeyTable(vds.mapPartitionsWithAll { it => it.map { case (v, va, s, sa, g) => ec.setAll(v, va, s, sa, g) - val key = keyF().toIndexedSeq + val key = keyF(): IndexedSeq[String] (key, (v, va, s, sa, g)) } }.aggregateByKey(zvf())(seqOp, combOp) diff --git a/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala new file mode 100644 index 00000000000..22dc74c4619 --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala @@ -0,0 +1,72 @@ +package org.broadinstitute.hail.driver + +import org.apache.spark.sql.Row +import org.broadinstitute.hail.expr.{EvalContext, _} +import org.broadinstitute.hail.io.TextExporter +import org.broadinstitute.hail.utils._ +import org.kohsuke.args4j.{Option => Args4jOption} + +object JoinKeyTable extends Command { + + class Options extends BaseOptions { + + @Args4jOption(required = true, name = "-d", aliases = Array("--dest"), + usage = "name of joined key-table") + var destName: String = _ + + @Args4jOption(required = true, name = "-l", aliases = Array("--left-name"), + usage = "name of key-table on left") + var leftName: String = _ + + @Args4jOption(required = true, name = "-r", aliases = Array("--right-name"), + usage = "name of key-table on right") + var rightName: String = _ + + @Args4jOption(required = false, name = "-t", aliases = Array("--join-type"), + usage = "type of join") + var joinType: String = "left" + + @Args4jOption(required = true, name = "-t", aliases = Array("--join-keys"), + usage = "name of columns to join on") + var joinKeys: String = _ + } + + def newOptions = new Options + + def name = "joinkeytable" + + def description = "Join two key tables together to produce new key table" + + def supportsMultiallelic = true + + def requiresVDS = false + + override def 
hidden = true + + def run(state: State, options: Options): State = { + val ktEnv = state.ktEnv + + val ktLeft = ktEnv.get(options.leftName) match { + case Some(kt) => + kt + case None => + fatal("no such key table $name in environment") + } + + val ktRight = ktEnv.get(options.rightName) match { + case Some(kt) => + kt + case None => + fatal("no such key table $name in environment") + } + + if (ktEnv.contains(options.destName)) + warn("destination name already exists -- overwriting previous key-table") + + + + + state + } +} + diff --git a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala index 2291d823750..c81568e7cfd 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala @@ -27,7 +27,7 @@ trait JoinAnnotator { } def buildInserter(code: String, t: Type, ec: EvalContext, expectedHead: String): (Type, Inserter) = { - val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, expectedHead) + val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, Option(expectedHead)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finaltype = parseTypes.foldLeft(t) { case (t, (ids, signature)) => diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 6483b2a2716..80da322761c 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -3,7 +3,7 @@ package org.broadinstitute.hail.keytable import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.broadinstitute.hail.annotations._ -import org.broadinstitute.hail.expr.{EvalContext, Parser, TBoolean, TStruct} +import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TStruct} import org.broadinstitute.hail.methods.Filter case class 
KeyTable(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[String]) { @@ -15,6 +15,30 @@ case class KeyTable(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[St def nRows = rdd.count() + def leftJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = ??? + def rightJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = ??? + def outerJoin(other: KeyTable): KeyTable = ??? + def innerJoin(other: KeyTable): KeyTable = ??? +// require(keyNames.toSet == other.keyNames.toSet) + // function to make key order the same + + + def query(code: String): (BaseType, Querier) = { + val ec = EvalContext(signature.fields.map(f => (f.name, f.`type`)): _*) + + val (t, f) = Parser.parse(code, ec) + + val f2: Annotation => Option[Any] = { annotation => + Option(annotation).map(_.asInstanceOf[Row]) match { + case Some(r) => ec.setAll(r.toSeq: _*) + case None => ec.setAll(Seq.fill(signature.size)(null)) + } + f() + } + + (t, f2) + } + def filter(p: (Annotation) => Boolean): KeyTable = copy(rdd = rdd.filter { a => p(a) }) def filterExpr(cond: String, keep: Boolean): KeyTable = { diff --git a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala new file mode 100644 index 00000000000..4fa694e7f3e --- /dev/null +++ b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala @@ -0,0 +1,170 @@ +package org.broadinstitute.hail.driver + +import org.broadinstitute.hail.SparkSuite +import org.broadinstitute.hail.annotations.Annotation +import org.broadinstitute.hail.check.Prop._ +import org.broadinstitute.hail.check.{Gen, Properties} +import org.broadinstitute.hail.expr.{TString, TStruct} +import org.broadinstitute.hail.variant.{Genotype, VSMSubgen, VariantSampleMatrix, _} +import org.testng.annotations.Test + +class AggregateByKeySuite extends SparkSuite { + @Test def replicateSampleAggregation() = { + val inputVCF = "src/test/resources/sample.vcf" + var s = State(sc, 
sqlContext) + s = ImportVCF.run(s, Array("-i", inputVCF)) + s = AnnotateSamplesExpr.run(s, Array("-c", "sa.nHet = gs.filter(g => g.isHet).count()")) + s = AggregateByKey.run(s, Array("-k", "Sample = s", "-a", "nHet = gs.filter(g => g.isHet).count()", "-n", "kt")) + + val kt = s.ktEnv("kt") + val vds = s.vds + println(kt.signature.schema) + val (_, ktHetQuery) = kt.query("nHet") + val (_, ktSampleQuery) = kt.query("Sample") + val (_, saHetQuery) = vds.querySA("sa.nHet") + + val ktSampleResults = kt.rdd.map { a => + (ktSampleQuery(a).get.asInstanceOf[String], ktHetQuery(a)) + }.collectAsMap() + ktSampleResults.head._2.foreach(x => println(x.getClass.getName)) + + assert( vds.sampleIdsAndAnnotations.forall{ case (sid, sa) => + (saHetQuery(sa), ktSampleResults(sid)) match { + case (None, None) => true + case (Some(x), Some(y)) => x.asInstanceOf[Long] == y.asInstanceOf[String].toLong + case _ => false + } + }) + } + + @Test def replicateVariantAggregation() = { + val inputVCF = "src/test/resources/sample.vcf" + var s = State(sc, sqlContext) + s = ImportVCF.run(s, Array("-i", inputVCF)) + s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()")) + s = AggregateByKey.run(s, Array("-k", "Variant = v", "-a", "nHet = gs.filter(g => g.isHet).count()", "-n", "kt")) + + val kt = s.ktEnv("kt") + val vds = s.vds + + val (_, ktHetQuery) = kt.query("nHet") + val (_, ktVariantQuery) = kt.query("Variant") + val (_, vaHetQuery) = vds.queryVA("va.nHet") + + val ktVariantResults = kt.rdd.map { a => + (ktVariantQuery(a).get.asInstanceOf[String], ktHetQuery(a)) + }.collectAsMap() + + assert( vds.variantsAndAnnotations.map{ case (v, va) => + (vaHetQuery(va), ktVariantResults(v.toString)) match { + case (None, None) => true + case (Some(x), Some(y)) => x.asInstanceOf[Long] == y.asInstanceOf[String].toLong + case _ => false + } + }.fold(true)(_ && _)) + } + +// @Test def replicateGlobalAggregation() = { +// val inputVCF = "src/test/resources/sample.vcf" +// 
var s = State(sc, sqlContext) +// s = ImportVCF.run(s, Array("-i", inputVCF)) +// s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()")) +// s = AnnotateGlobalExpr.run(s, Array("-c", "global.nHet = variants.map(v => va.nHet).sum()")) +// s = AggregateByKey.run(s, Array("-k", " ", "-a", "nHet = gs.filter(g => g.isHet).count()", "-n", "kt")) +// +// val kt = s.ktEnv("kt") +// val vds = s.vds +// +// val (_, ktHetQuery) = kt.query("nHet") +// val (_, globalHet) = vds.queryGlobal("global.nHet") +// +// val ktGlobalResults = kt.rdd.map{ a => ktHetQuery(a).asInstanceOf[String].toLong}.collect() +// assert(ktGlobalResults.length == 1 && globalHet.get.asInstanceOf[Long] == ktGlobalResults(0)) +// } + + +// def createKey(nItems: Int, nCategories: Int) = +// Gen.buildableOfN[Array, Option[String]](nItems, Gen.option(Gen.oneOfSeq((0 until nCategories).map("group" + _)), 0.95)) +// +// def createKeys(nKeys: Int, nItems: Int) = +// Gen.buildableOfN[Array, Array[Option[String]]](nKeys, createKey(nItems, Gen.choose(1, 10).sample())) +// +// object Spec extends Properties("CreateKeyTable") { +// val compGen = for (vds: VariantDataset <- VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0); +// nKeys <- Gen.choose(1, 5); +// nSampleKeys <- Gen.choose(0, nKeys); +// nVariantKeys <- Gen.const(nKeys - nSampleKeys); +// sampleGroups <- createKeys(nSampleKeys, vds.nSamples); +// variantGroups <- createKeys(nVariantKeys, vds.nVariants.toInt) +// ) yield (vds, sampleGroups, variantGroups) +// +// property("aggregate by sample and variants same") = forAll(compGen) { case (vds, sampleGroups, variantGroups) => +// val outputFile = tmpDir.createTempFile("keyTableTest", "tsv") +// +// val nKeys = sampleGroups.length + variantGroups.length +// val keyNames = (1 to nKeys).map("key" + _) +// val sampleKeyNames = (1 to sampleGroups.length).map("key" + _) +// val variantKeyNames = (sampleGroups.length + 1 to 
nKeys).map("key" + _) +// +// var sampleSignature = TStruct() +// sampleKeyNames.foreach(k => sampleSignature = sampleSignature.appendKey(k, TString)) +// +// var variantSignature = TStruct() +// variantKeyNames.foreach(k => variantSignature = variantSignature.appendKey(k, TString)) +// +// val sampleMap = vds.sampleIds.zipWithIndex.map { case (sid, i) => +// (sid, Annotation(sampleGroups.map(_ (i).orNull).toSeq: _*)) +// }.toMap +// +// val variantAnnotations = sc.parallelize(vds.variants.collect().zipWithIndex.map { case (v, i) => +// (v, Annotation(variantGroups.map(_ (i).orNull).toSeq: _*)) +// }).toOrderedRDD +// +// var s = State(sc, sqlContext, vds.annotateSamples(sampleMap, sampleSignature, "sa.keys") +// .annotateVariants(variantAnnotations, variantSignature, "va.keys")) +// +// val (_, sampleKeyQuery) = s.vds.querySA("sa.keys.*") +// val (_, variantKeyQuery) = s.vds.queryVA("va.keys.*") +// +// val keyGenotypeRDD = s.vds.mapWithAll { case (v, va, sid, sa, g) => +// val key = sampleKeyQuery(sa).get.asInstanceOf[IndexedSeq[String]] ++ variantKeyQuery(va).get.asInstanceOf[IndexedSeq[String]] +// (key, g) +// } +// +// val result = keyGenotypeRDD.aggregateByKey((0L, 0L, 0L))( +// (comb, gt) => (comb._1 + gt.isHet.toInt.toInt, comb._2 + gt.isCalled.toInt.toInt, comb._3 + 1), +// (comb1, comb2) => (comb1._1 + comb2._1, comb1._2 + comb2._2, comb1._3 + comb2._3)).collectAsMap() +// +// s = AggregateByKey.run(s, Array("-k", (sampleKeyNames.map(k => k + " = " + "sa.keys." + k) ++ variantKeyNames.map(k => k + " = " + "va.keys." 
+ k)).mkString(","), +// "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", "-n", "foo")) +// +// s = ExportKeyTable.run(s, Array("-o", outputFile, "-n", "foo")) +// println(outputFile) +// +// val ktr = hadoopConf.readLines(outputFile)(_.map(_.map { line => +// line.trim.split("\\s+") +// }.value).toIndexedSeq) +// +// val header = ktr.take(1) +// +// val keyTableResults = ktr.drop(1).map(r => (header(0), r).zipped.toMap) +// .map { x => (keyNames.map { k => x(k) }, (x("nHet"), x("nCalled"), x("nTotal"))) }.toMap +// +// result.forall { case (keys, (nHet, nCalled, nTotal)) => +// val (ktHet, ktCalled, ktTotal) = keyTableResults(keys.map(k => if (k != null) k else "NA").toIndexedSeq) +// ktHet.toLong == nHet && +// ktCalled.toLong == nCalled && +// ktTotal.toLong == nTotal +// } +// } +// } +// // +// // @Test def testAddKeyTable() { +// // Spec.check() +// // } +// +// +// @Test def test() { +// +// } +} diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index e6b3d0e16b8..46099a6c8e8 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -2,13 +2,10 @@ package org.broadinstitute.hail.methods import org.broadinstitute.hail.SparkSuite import org.broadinstitute.hail.annotations._ -import org.broadinstitute.hail.check.Prop._ -import org.broadinstitute.hail.check.{Gen, Properties} import org.broadinstitute.hail.driver._ import org.broadinstitute.hail.expr._ import org.broadinstitute.hail.keytable.KeyTable import org.broadinstitute.hail.utils._ -import org.broadinstitute.hail.variant.{Genotype, VSMSubgen, VariantDataset, VariantSampleMatrix} import org.testng.annotations.Test class KeyTableSuite extends SparkSuite { @@ -52,85 +49,4 @@ class KeyTableSuite extends SparkSuite { assert(s.ktEnv.contains("kt1") && 
s.ktEnv("kt1").nRows == 0) } - - def createKey(nItems: Int, nCategories: Int) = - Gen.buildableOfN[Array, Option[String]](nItems, Gen.option(Gen.oneOfSeq((0 until nCategories).map("group" + _)), 0.95)) - - def createKeys(nKeys: Int, nItems: Int) = - Gen.buildableOfN[Array, Array[Option[String]]](nKeys, createKey(nItems, Gen.choose(1, 10).sample())) - - object Spec extends Properties("CreateKeyTable") { - val compGen = for (vds: VariantDataset <- VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0); - nKeys <- Gen.choose(1, 5); - nSampleKeys <- Gen.choose(0, nKeys); - nVariantKeys <- Gen.const(nKeys - nSampleKeys); - sampleGroups <- createKeys(nSampleKeys, vds.nSamples); - variantGroups <- createKeys(nVariantKeys, vds.nVariants.toInt) - ) yield (vds, sampleGroups, variantGroups) - - property("aggregate by sample and variants same") = forAll(compGen) { case (vds, sampleGroups, variantGroups) => - val outputFile = tmpDir.createTempFile("keyTableTest", "tsv") - - val nKeys = sampleGroups.length + variantGroups.length - val keyNames = (1 to nKeys).map("key" + _) - val sampleKeyNames = (1 to sampleGroups.length).map("key" + _) - val variantKeyNames = (sampleGroups.length + 1 to nKeys).map("key" + _) - - var sampleSignature = TStruct() - sampleKeyNames.foreach(k => sampleSignature = sampleSignature.appendKey(k, TString)) - - var variantSignature = TStruct() - variantKeyNames.foreach(k => variantSignature = variantSignature.appendKey(k, TString)) - - val sampleMap = vds.sampleIds.zipWithIndex.map { case (sid, i) => - (sid, Annotation(sampleGroups.map(_ (i).orNull).toSeq: _*)) - }.toMap - - val variantAnnotations = sc.parallelize(vds.variants.collect().zipWithIndex.map { case (v, i) => - (v, Annotation(variantGroups.map(_ (i).orNull).toSeq: _*)) - }).toOrderedRDD - - var s = State(sc, sqlContext, vds.annotateSamples(sampleMap, sampleSignature, "sa.keys") - .annotateVariants(variantAnnotations, variantSignature, 
"va.keys")) - - val (_, sampleKeyQuery) = s.vds.querySA("sa.keys.*") - val (_, variantKeyQuery) = s.vds.queryVA("va.keys.*") - - val keyGenotypeRDD = s.vds.mapWithAll { case (v, va, sid, sa, g) => - val key = sampleKeyQuery(sa).get.asInstanceOf[IndexedSeq[String]] ++ variantKeyQuery(va).get.asInstanceOf[IndexedSeq[String]] - (key, g) - } - - val result = keyGenotypeRDD.aggregateByKey((0L, 0L, 0L))( - (comb, gt) => (comb._1 + gt.isHet.toInt.toInt, comb._2 + gt.isCalled.toInt.toInt, comb._3 + 1), - (comb1, comb2) => (comb1._1 + comb2._1, comb1._2 + comb2._2, comb1._3 + comb2._3)).collectAsMap() - - s = AggregateByKey.run(s, Array("-k", (sampleKeyNames.map(k => k + " = " + "sa.keys." + k) ++ variantKeyNames.map(k => k + " = " + "va.keys." + k)).mkString(","), - "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", "-n", "foo")) - - s = ExportKeyTable.run(s, Array("-o", outputFile, "-n", "foo")) - println(outputFile) - - val ktr = hadoopConf.readLines(outputFile)(_.map(_.map { line => - line.trim.split("\\s+") - }.value).toIndexedSeq) - - val header = ktr.take(1) - - val keyTableResults = ktr.drop(1).map(r => (header(0), r).zipped.toMap) - .map { x => (keyNames.map { k => x(k) }, (x("nHet"), x("nCalled"), x("nTotal"))) }.toMap - - result.forall { case (keys, (nHet, nCalled, nTotal)) => - val (ktHet, ktCalled, ktTotal) = keyTableResults(keys.map(k => if (k != null) k else "NA").toIndexedSeq) - ktHet.toLong == nHet && - ktCalled.toLong == nCalled && - ktTotal.toLong == nTotal - } - } - } -// -// @Test def testAddKeyTable() { -// Spec.check() -// } - } From 2e256c8131e4bf5886167c96c1917102910c2e88 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Wed, 2 Nov 2016 16:07:31 -0400 Subject: [PATCH 23/51] implemented interface --- .../hail/driver/AggregateByKey.scala | 14 +-- .../hail/driver/ExportKeyTable.scala | 17 +-- .../hail/driver/ImportKeyTable.scala | 4 +- .../org/broadinstitute/hail/expr/AST.scala | 2 + 
.../hail/expr/JoinAnnotator.scala | 2 +- .../hail/keytable/KeyTable.scala | 103 ++++++++++++++---- .../hail/driver/AggregateByKeySuite.scala | 8 +- 7 files changed, 101 insertions(+), 49 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala index 9ea18ca9a22..de8ad4a7983 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala @@ -66,12 +66,8 @@ object AggregateByKey extends Command { val (keyNames, keyParseTypes, keyF) = Parser.parseNamedArgs(keyCond, ec) val (aggNames, aggParseTypes, aggF) = Parser.parseNamedArgs(aggCond, ec) - val (testPT, testF) = Parser.parseAnnotationArgs(aggCond, ec) - - println(testPT.mkString("\n")) - - - val signature = TStruct((keyNames ++ aggNames).zip(keyParseTypes ++ aggParseTypes): _*) + val keySignature = TStruct(keyNames.zip(keyParseTypes): _*) + val valueSignature = TStruct(aggNames.zip(aggParseTypes): _*) val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) val zvf = () => zVals.indices.map(zVals).toArray @@ -91,14 +87,14 @@ object AggregateByKey extends Command { val kt = KeyTable(vds.mapPartitionsWithAll { it => it.map { case (v, va, s, sa, g) => ec.setAll(v, va, s, sa, g) - val key = keyF(): IndexedSeq[String] + val key = Annotation.fromSeq(keyF()) (key, (v, va, s, sa, g)) } }.aggregateByKey(zvf())(seqOp, combOp) .map { case (k, agg) => resultOp(agg) - Annotation.fromSeq(k ++ aggF()) - }, signature, keyNames) + (k, Annotation.fromSeq(aggF())) + }, keySignature, valueSignature) state.copy(ktEnv = state.ktEnv + (options.name -> kt)) } diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala index 70b07c503bc..b152dd4fad8 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala +++ 
b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala @@ -1,9 +1,9 @@ package org.broadinstitute.hail.driver -import org.apache.spark.sql.Row -import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.expr.{EvalContext, _} import org.broadinstitute.hail.io.TextExporter +import org.broadinstitute.hail.keytable.KeyTable +import org.broadinstitute.hail.utils._ import org.kohsuke.args4j.{Option => Args4jOption} object ExportKeyTable extends Command with TextExporter { @@ -47,7 +47,7 @@ object ExportKeyTable extends Command with TextExporter { val output = options.output - val symTab = kt.signature.fields.zipWithIndex.map{case (fd, i) => (fd.name, (i, fd.`type`))}.toMap + val symTab = kt.fields.zipWithIndex.map{case (fd, i) => (fd.name, (i, fd.`type`))}.toMap val ec = EvalContext(symTab) @@ -62,19 +62,12 @@ object ExportKeyTable extends Command with TextExporter { state.hadoopConf.delete(output, recursive = true) - val signature = kt.signature - kt.rdd .mapPartitions { it => val sb = new StringBuilder() - it.map { a => + it.map { case (k, v) => sb.clear() - - Option(a).map(_.asInstanceOf[Row]) match { - case Some(r) => ec.setAll(r.toSeq: _*) - case None => ec.setAll(Seq.fill(signature.size)(null)) - } - + KeyTable.setEvalContext(ec, k, v) f().foreachBetween(x => sb.append(x))(sb += '\t') sb.result() } diff --git a/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala index 148b63be6ab..cfe13fa755f 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala @@ -1,6 +1,6 @@ package org.broadinstitute.hail.driver -import org.broadinstitute.hail.expr.Parser +import org.broadinstitute.hail.expr.{Parser, TStruct} import org.broadinstitute.hail.keytable.KeyTable import org.broadinstitute.hail.utils._ import org.kohsuke.args4j.{Argument, Option => Args4jOption} @@ -55,7 +55,7 @@ 
object ImportKeyTable extends Command { val keyNamesValid = keyNames.forall{ k => val res = struct.selfField(k).isDefined if (!res) - println("Key `$k' is not present in input table") + println(s"Key `$k' is not present in input table") res } if (!keyNamesValid) diff --git a/src/main/scala/org/broadinstitute/hail/expr/AST.scala b/src/main/scala/org/broadinstitute/hail/expr/AST.scala index 08b199ff3bc..537e57eb9cd 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/AST.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/AST.scala @@ -26,6 +26,8 @@ case class EvalContext(st: SymbolTable, a: ArrayBuffer[Any], aggregationFunction def set(index: Int, arg: Any) { a(index) = arg } + + def clear() = a.indices.foreach { i => a(i) = null } } object EvalContext { diff --git a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala index c81568e7cfd..2291d823750 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala @@ -27,7 +27,7 @@ trait JoinAnnotator { } def buildInserter(code: String, t: Type, ec: EvalContext, expectedHead: String): (Type, Inserter) = { - val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, Option(expectedHead)) + val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, expectedHead) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finaltype = parseTypes.foldLeft(t) { case (t, (ids, signature)) => diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 80da322761c..39289eeeef9 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -6,52 +6,113 @@ import org.broadinstitute.hail.annotations._ import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TStruct} 
import org.broadinstitute.hail.methods.Filter -case class KeyTable(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[String]) { +object KeyTable extends Serializable { + def setEvalContext(ec: EvalContext, k: Annotation, v: Annotation) = { + (Option(k).map(_.asInstanceOf[Row]), Option(v).map(_.asInstanceOf[Row])) match { + case (Some(kr), Some(vr)) => ec.setAll(kr.toSeq ++ vr.toSeq: _*) + case _ => ec.clear() + } + } + + def setEvalContext(ec: EvalContext, a: Annotation) = { + Option(a).map(_.asInstanceOf[Row]) match { + case Some(r) => ec.setAll(r.toSeq: _*) + case _ => ec.clear() + } + } + + def pairSignature(signature: TStruct, keyNames: Array[String]): (TStruct, TStruct) = { + val keyNameSet = keyNames.toSet + (TStruct(signature.fields.filter(fd => keyNameSet.contains(fd.name))), + TStruct(signature.fields.filterNot(fd => keyNameSet.contains(fd.name)))) + } + + def singleSignature(keySignature: TStruct, valueSignature: TStruct): (TStruct, Array[String]) = + (TStruct(keySignature.fields ++ valueSignature.fields), keySignature.fields.map(_.name).toArray) + + def toSingleRDD(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, valueSignature: TStruct): RDD[Annotation] = + rdd.map{ case (k, v) => Annotation(Option(k).map(_.asInstanceOf[Row]).toSeq ++ Option(v).map(_.asInstanceOf[Row]).toSeq: _*) } + + def toPairRDD(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[String]): RDD[(Annotation, Annotation)] = { + val keyNameSet = keyNames.toSet + val keyIndices = signature.fields.filter(fd => keyNames.contains(fd.name)).map(_.index).toSet + val valueIndices = signature.fields.filterNot(fd => keyNames.contains(fd.name)).map(_.index).toSet - val fieldNames = signature.fields.map(_.name) + rdd.map { a => + val r = Option(a).map(_.asInstanceOf[Row].toSeq).getOrElse(Seq.fill(signature.size)(null)).zipWithIndex + val keyRow = r.filter{ case (ann, i) => keyIndices.contains(i) }.map(_._1) + val valueRow = r.filter{ case (ann, i) => 
valueIndices.contains(i) }.map(_._1) + (Annotation.fromSeq(keyRow), Annotation.fromSeq(valueRow)) + } + } + + def apply(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[String]): KeyTable = { + val (keySignature, valueSignature) = pairSignature(signature, keyNames) + KeyTable(toPairRDD(rdd, signature, keyNames), keySignature, valueSignature) + } +} - require(fieldNames.distinct.length == fieldNames.length) - require(keyNames.forall(k => signature.selfField(k).isDefined)) +case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, valueSignature: TStruct) { + + require(fieldNames.toSet.size == fieldNames.length) + + def signature = KeyTable.singleSignature(keySignature, valueSignature)._1 + def fields = signature.fields + + def keySchema = keySignature.schema + def valueSchema = valueSignature.schema + def schema = signature.schema + + def keyNames = keySignature.fields.map(_.name) + def valueNames = valueSignature.fields.map(_.name) + def fieldNames = keyNames ++ valueNames def nRows = rdd.count() + def nFields = fields.length - def leftJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = ??? - def rightJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = ??? + def leftJoin(other: KeyTable): KeyTable = ??? + def rightJoin(other: KeyTable): KeyTable = ??? def outerJoin(other: KeyTable): KeyTable = ??? def innerJoin(other: KeyTable): KeyTable = ??? 
// require(keyNames.toSet == other.keyNames.toSet) // function to make key order the same - def query(code: String): (BaseType, Querier) = { - val ec = EvalContext(signature.fields.map(f => (f.name, f.`type`)): _*) + def query(code: String): (BaseType, (Annotation, Annotation) => Option[Any]) = { + val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) val (t, f) = Parser.parse(code, ec) - val f2: Annotation => Option[Any] = { annotation => - Option(annotation).map(_.asInstanceOf[Row]) match { - case Some(r) => ec.setAll(r.toSeq: _*) - case None => ec.setAll(Seq.fill(signature.size)(null)) - } + val f2: (Annotation, Annotation) => Option[Any] = { case (k, v) => + KeyTable.setEvalContext(ec, k, v) f() } (t, f2) } - def filter(p: (Annotation) => Boolean): KeyTable = copy(rdd = rdd.filter { a => p(a) }) +// def query(code: String): (BaseType, Querier) = { +// val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) +// +// val (t, f) = Parser.parse(code, ec) +// +// val f2: (Annotation) => Option[Any] = { a => +// KeyTable.setEvalContext(ec, a) +// f() +// } +// +// (t, f2) +// } + + def filter(p: (Annotation, Annotation) => Boolean): KeyTable = copy(rdd = rdd.filter { case (k, v) => p(k, v) }) def filterExpr(cond: String, keep: Boolean): KeyTable = { - val ec = EvalContext(signature.fields.map(f => (f.name, f.`type`)): _*) + val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) val f: () => Option[Boolean] = Parser.parse[Boolean](cond, ec, TBoolean) - val p = (a: Annotation) => { - Option(a).map(_.asInstanceOf[Row]) match { - case Some(r) => ec.setAll(r.toSeq: _*) - case None => ec.setAll(Seq.fill(signature.size)(null)) - } - + val p = (k: Annotation, v: Annotation) => { + KeyTable.setEvalContext(ec, k, v) Filter.keepThis(f(), keep) } diff --git a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala index 4fa694e7f3e..ec7878a4d91 100644 --- 
a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala +++ b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala @@ -23,8 +23,8 @@ class AggregateByKeySuite extends SparkSuite { val (_, ktSampleQuery) = kt.query("Sample") val (_, saHetQuery) = vds.querySA("sa.nHet") - val ktSampleResults = kt.rdd.map { a => - (ktSampleQuery(a).get.asInstanceOf[String], ktHetQuery(a)) + val ktSampleResults = kt.rdd.map { case (k, v) => + (ktSampleQuery(k, v).get.asInstanceOf[String], ktHetQuery(k, v)) }.collectAsMap() ktSampleResults.head._2.foreach(x => println(x.getClass.getName)) @@ -51,8 +51,8 @@ class AggregateByKeySuite extends SparkSuite { val (_, ktVariantQuery) = kt.query("Variant") val (_, vaHetQuery) = vds.queryVA("va.nHet") - val ktVariantResults = kt.rdd.map { a => - (ktVariantQuery(a).get.asInstanceOf[String], ktHetQuery(a)) + val ktVariantResults = kt.rdd.map { case (k, v) => + (ktVariantQuery(k, v).get.asInstanceOf[String], ktHetQuery(k, v)) }.collectAsMap() assert( vds.variantsAndAnnotations.map{ case (v, va) => From ee4f2db38bc0df6d227e03862af6998341e51749 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Thu, 3 Nov 2016 11:48:15 -0400 Subject: [PATCH 24/51] got aggregateByKey working --- .../hail/driver/AggregateByKey.scala | 28 +++++--- .../hail/driver/AggregateIntervals.scala | 2 +- .../hail/driver/AnnotateGlobalExpr.scala | 4 +- .../hail/driver/AnnotateSamplesExpr.scala | 4 +- .../hail/driver/AnnotateVariantsExpr.scala | 4 +- .../hail/driver/ExportGenotypes.scala | 2 +- .../hail/driver/ExportKeyTable.scala | 6 +- .../hail/driver/ExportSamples.scala | 2 +- .../hail/driver/ExportVariants.scala | 2 +- .../hail/driver/ExportVariantsCass.scala | 4 +- .../hail/driver/FilterAlleles.scala | 4 +- .../hail/expr/JoinAnnotator.scala | 4 +- .../org/broadinstitute/hail/expr/Parser.scala | 33 ++++++---- .../hail/keytable/KeyTable.scala | 16 +++-- .../hail/driver/AggregateByKeySuite.scala | 66 ++++++++----------- 15 files changed, 98 
insertions(+), 83 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala index de8ad4a7983..537db5ca6ee 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala @@ -9,11 +9,11 @@ import org.kohsuke.args4j.{Option => Args4jOption} object AggregateByKey extends Command { class Options extends BaseOptions { - @Args4jOption(required = true, name = "-k", aliases = Array("--key-cond"), + @Args4jOption(required = false, name = "-k", aliases = Array("--key-cond"), usage = "Named key condition", metaVar = "EXPR") var keyCond: String = _ - @Args4jOption(required = true, name = "-a", aliases = Array("--agg-cond"), + @Args4jOption(required = false, name = "-a", aliases = Array("--agg-cond"), usage = "Named aggregation condition", metaVar = "EXPR") var aggCond: String = _ @@ -63,11 +63,23 @@ object AggregateByKey extends Command { ec.set(4, vds.globalAnnotation) aggregationEC.set(4, vds.globalAnnotation) - val (keyNames, keyParseTypes, keyF) = Parser.parseNamedArgs(keyCond, ec) - val (aggNames, aggParseTypes, aggF) = Parser.parseNamedArgs(aggCond, ec) + val (keyNameParseTypes, keyF) = + if (keyCond != null) + Parser.parseAnnotationArgs(keyCond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) - val keySignature = TStruct(keyNames.zip(keyParseTypes): _*) - val valueSignature = TStruct(aggNames.zip(aggParseTypes): _*) + val (aggNameParseTypes, aggF) = + if (aggCond != null) + Parser.parseAnnotationArgs(aggCond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val keyNames = keyNameParseTypes.map(_._1.head) + val aggNames = aggNameParseTypes.map(_._1.head) + + val keySignature = TStruct(keyNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + val valueSignature = TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, 
t) }: _*) val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) val zvf = () => zVals.indices.map(zVals).toArray @@ -87,13 +99,13 @@ object AggregateByKey extends Command { val kt = KeyTable(vds.mapPartitionsWithAll { it => it.map { case (v, va, s, sa, g) => ec.setAll(v, va, s, sa, g) - val key = Annotation.fromSeq(keyF()) + val key = Annotation.fromSeq(keyF.map(_ ())) (key, (v, va, s, sa, g)) } }.aggregateByKey(zvf())(seqOp, combOp) .map { case (k, agg) => resultOp(agg) - (k, Annotation.fromSeq(aggF())) + (k, Annotation.fromSeq(aggF.map(_ ()))) }, keySignature, valueSignature) state.copy(ktEnv = state.ktEnv + (options.name -> kt)) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala index dee21c588f1..e94a6689865 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala @@ -58,7 +58,7 @@ object AggregateIntervals extends Command { ec.set(1, vds.globalAnnotation) aggregationEC.set(1, vds.globalAnnotation) - val (header, _, f) = Parser.parseNamedArgs(cond, ec) + val (header, _, f) = Parser.parseExportArgs(cond, ec) if (header.isEmpty) fatal("this module requires one or more named expr arguments") diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala index f0b9e62a9e4..802d1400f35 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala @@ -48,7 +48,7 @@ object AnnotateGlobalExpr extends Command { aggECS.set(1, vds.globalAnnotation) aggECV.set(1, vds.globalAnnotation) - val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Annotation.GLOBAL_HEAD) + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, 
Option(Annotation.GLOBAL_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] @@ -82,7 +82,7 @@ object AnnotateGlobalExpr extends Command { val ga = inserters .zip(fns.map(_ ())) .foldLeft(vds.globalAnnotation) { case (a, (ins, res)) => - ins(a, res) + ins(a, Option(res)) } state.copy(vds = vds.copy( diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala index 4bbfa9663a4..ce2a40af47d 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala @@ -47,7 +47,7 @@ object AnnotateSamplesExpr extends Command { ec.set(2, vds.globalAnnotation) aggregationEC.set(4, vds.globalAnnotation) - val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Annotation.SAMPLE_HEAD) + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Option(Annotation.SAMPLE_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finalType = parseTypes.foldLeft(vds.saSignature) { case (sas, (ids, signature)) => @@ -70,7 +70,7 @@ object AnnotateSamplesExpr extends Command { fns.zip(inserters) .foldLeft(sa) { case (sa, (fn, inserter)) => - inserter(sa, fn()) + inserter(sa, Option(fn())) } } state.copy(vds = vds.copy( diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala index d10b20810f7..dd82abe473f 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala @@ -49,7 +49,7 @@ object AnnotateVariantsExpr extends Command { ec.set(2, vds.globalAnnotation) aggregationEC.set(4, vds.globalAnnotation) - val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Annotation.VARIANT_HEAD) + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, 
Option(Annotation.VARIANT_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finalType = parseTypes.foldLeft(vds.vaSignature) { case (vas, (ids, signature)) => @@ -67,7 +67,7 @@ object AnnotateVariantsExpr extends Command { aggregateOption.foreach(f => f(v, va, gs)) fns.zip(inserters) .foldLeft(va) { case (va, (fn, inserter)) => - inserter(va, fn()) + inserter(va, Option(fn())) } }.copy(vaSignature = finalType) state.copy(vds = annotated) diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala index b51ef90d540..f9ef1f99e75 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala @@ -59,7 +59,7 @@ object ExportGenotypes extends Command with TextExporter { val ec = EvalContext(symTab) ec.set(5, vds.globalAnnotation) - val (header, ts, f) = Parser.parseExportArgs(cond, ec) + val (header, ts, f) = Parser.parseNamedArgs(cond, ec) Option(options.typesFile).foreach { file => val typeInfo = header diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala index b152dd4fad8..0fa1e9f79d9 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala @@ -51,7 +51,7 @@ object ExportKeyTable extends Command with TextExporter { val ec = EvalContext(symTab) - val (header, types, f) = Parser.parseExportArgs(kt.fieldNames.map(n => n + " = " + n).mkString(","), ec) + val (header, types, f) = Parser.parseNamedArgs(kt.fieldNames.map(n => n + " = " + n).mkString(","), ec) Option(options.typesFile).foreach { file => val typeInfo = header @@ -62,12 +62,14 @@ object ExportKeyTable extends Command with TextExporter { state.hadoopConf.delete(output, recursive = true) + val nKeys = kt.nKeys + kt.rdd 
.mapPartitions { it => val sb = new StringBuilder() it.map { case (k, v) => sb.clear() - KeyTable.setEvalContext(ec, k, v) + KeyTable.setEvalContext(ec, k, v, nKeys) f().foreachBetween(x => sb.append(x))(sb += '\t') sb.result() } diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala index ea5d75f974e..333c2f7a7ba 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala @@ -58,7 +58,7 @@ object ExportSamples extends Command with TextExporter { ec.set(2, vds.globalAnnotation) aggregationEC.set(5, vds.globalAnnotation) - val (header, types, f) = Parser.parseExportArgs(cond, ec) + val (header, types, f) = Parser.parseNamedArgs(cond, ec) Option(options.typesFile).foreach { file => val typeInfo = header .getOrElse(types.indices.map(i => s"_$i").toArray) diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala index 256397b33c4..46389d657c4 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala @@ -60,7 +60,7 @@ object ExportVariants extends Command with TextExporter { ec.set(2, vds.globalAnnotation) aggregationEC.set(5, vds.globalAnnotation) - val (header, types, f) = Parser.parseExportArgs(cond, ec) + val (header, types, f) = Parser.parseNamedArgs(cond, ec) Option(options.typesFile).foreach { file => val typeInfo = header diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala index 482b7ab9bf0..9c039890401 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala @@ -153,7 +153,7 @@ object 
ExportVariantsCass extends Command { val vEC = EvalContext(vSymTab) val vA = vEC.a - val (vHeader, vTypes, vf) = Parser.parseNamedArgs(vCond, vEC) + val (vHeader, vTypes, vf) = Parser.parseExportArgs(vCond, vEC) val gSymTab = Map( "v" -> (0, TVariant), @@ -164,7 +164,7 @@ object ExportVariantsCass extends Command { val gEC = EvalContext(gSymTab) val gA = gEC.a - val (gHeader, gTypes, gf) = Parser.parseNamedArgs(gCond, gEC) + val (gHeader, gTypes, gf) = Parser.parseExportArgs(gCond, gEC) val symTab = Map( "v" -> (0, TVariant), diff --git a/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala b/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala index 6939be79ad3..62d00cb7ca5 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala @@ -72,7 +72,7 @@ object FilterAlleles extends Command { "v" -> (0, TVariant), "va" -> (1, state.vds.vaSignature), "aIndices" -> (2, TArray(TInt)))) - val (types, generators) = Parser.parseAnnotationArgs(options.annotation, annotationEC, Annotation.VARIANT_HEAD) + val (types, generators) = Parser.parseAnnotationArgs(options.annotation, annotationEC, Option(Annotation.VARIANT_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finalType = types.foldLeft(state.vds.vaSignature) { case (vas, (path, signature)) => val (newVas, i) = vas.insert(signature, path) @@ -120,7 +120,7 @@ object FilterAlleles extends Command { def updateAnnotation(v: Variant, va: Annotation, newToOld: IndexedSeq[Int]): Annotation = { annotationEC.setAll(v, va, newToOld) - generators.zip(inserters).foldLeft(va) { case (va, (fn, inserter)) => inserter(va, fn()) } + generators.zip(inserters).foldLeft(va) { case (va, (fn, inserter)) => inserter(va, Option(fn())) } } def updateGenotypes(gs: Iterable[Genotype], oldToNew: Array[Int], newCount: Int): Iterable[Genotype] = { diff --git 
a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala index 2291d823750..3bc56290624 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala @@ -27,7 +27,7 @@ trait JoinAnnotator { } def buildInserter(code: String, t: Type, ec: EvalContext, expectedHead: String): (Type, Inserter) = { - val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, expectedHead) + val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, Option(expectedHead)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finaltype = parseTypes.foldLeft(t) { case (t, (ids, signature)) => @@ -45,7 +45,7 @@ trait JoinAnnotator { val queries = fns.map(_ ()) var newAnnotation = left queries.indices.foreach { i => - newAnnotation = inserters(i)(newAnnotation, queries(i)) + newAnnotation = inserters(i)(newAnnotation, Option(queries(i))) } newAnnotation } diff --git a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala index bd1d3186fb9..2ce6839c490 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala @@ -93,8 +93,8 @@ object Parser extends JavaTokenParsers { case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) } } - - def parseExportArgs(code: String, ec: EvalContext): (Option[Array[String]], Array[Type], () => Array[String]) = { + + def parseNamedArgs(code: String, ec: EvalContext): (Option[Array[String]], Array[Type], () => Array[String]) = { val result = parseAll(export_args, code) match { case Success(r, _) => r case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) @@ -144,8 +144,8 @@ object Parser extends JavaTokenParsers { (someIf(names.nonEmpty, names), tb.result(), () => computations.flatMap(_ ())) } - def parseNamedArgs(code: String, ec: EvalContext): 
(Array[String], Array[Type], () => Array[String]) = { - val (headerOption, ts, f) = parseExportArgs(code, ec) + def parseExportArgs(code: String, ec: EvalContext): (Array[String], Array[Type], () => Array[String]) = { + val (headerOption, ts, f) = parseNamedArgs(code, ec) val header = headerOption match { case Some(h) => h case None => fatal( @@ -155,22 +155,24 @@ object Parser extends JavaTokenParsers { (header, ts, f) } - def parseAnnotationArgs(code: String, ec: EvalContext, expectedHead: String): (Array[(List[String], Type)], Array[() => Option[Any]]) = { + def parseAnnotationArgs(code: String, ec: EvalContext, expectedHead: Option[String]): (Array[(List[String], Type)], Array[() => Any]) = { val arr = parseAll(annotationExpressions, code) match { case Success(result, _) => result.asInstanceOf[Array[(List[String], AST)]] case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) } def checkType(l: List[String], t: BaseType): Type = { - if (l.head == expectedHead) - t match { - case t: Type => t - case bt => fatal( - s"""Got invalid type `$t' from the result of `${ l.mkString(".") }'""".stripMargin) - } else fatal( - s"""invalid annotation path `${ l.map(prettyIdentifier).mkString(".") }' - | Path should begin with `$expectedHead' + if (expectedHead.isDefined && l.head != expectedHead.get) + fatal( + s"""invalid annotation path `${ l.map(prettyIdentifier).mkString(".") }' + | Path should begin with `$expectedHead' """.stripMargin) + + t match { + case t: Type => t + case bt => fatal( + s"""Got invalid type `$t' from the result of `${ l.mkString(".") }'""".stripMargin) + } } val all = arr.map { @@ -178,8 +180,11 @@ object Parser extends JavaTokenParsers { ast.typecheck(ec) val t = checkType(path, ast.`type`) val f = ast.eval(ec) - ((path.tail, t), () => Option(f())) + val name = if (expectedHead.isDefined) path.tail else path +// ((name, t), () => Option(f())) + ((name, t), () => f()) } + (all.map(_._1), all.map(_._2)) } diff --git 
a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 39289eeeef9..104ae18e42f 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -7,10 +7,16 @@ import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TS import org.broadinstitute.hail.methods.Filter object KeyTable extends Serializable { - def setEvalContext(ec: EvalContext, k: Annotation, v: Annotation) = { + def setEvalContext(ec: EvalContext, k: Annotation, v: Annotation, nKeys: Int) = { (Option(k).map(_.asInstanceOf[Row]), Option(v).map(_.asInstanceOf[Row])) match { case (Some(kr), Some(vr)) => ec.setAll(kr.toSeq ++ vr.toSeq: _*) - case _ => ec.clear() + case (Some(kr), None) => + ec.clear() + ec.setAll(kr.toSeq: _*) + case (None, Some(vr)) => + ec.clear() + vr.toSeq.zipWithIndex.foreach{ case (a, i) => ec.set(i + nKeys, a)} + case (None, None) => ec.clear() } } @@ -69,6 +75,8 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def nRows = rdd.count() def nFields = fields.length + def nKeys = keySignature.size + def nValues = valueSignature.size def leftJoin(other: KeyTable): KeyTable = ??? def rightJoin(other: KeyTable): KeyTable = ??? 
@@ -84,7 +92,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val (t, f) = Parser.parse(code, ec) val f2: (Annotation, Annotation) => Option[Any] = { case (k, v) => - KeyTable.setEvalContext(ec, k, v) + KeyTable.setEvalContext(ec, k, v, nKeys) f() } @@ -112,7 +120,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val f: () => Option[Boolean] = Parser.parse[Boolean](cond, ec, TBoolean) val p = (k: Annotation, v: Annotation) => { - KeyTable.setEvalContext(ec, k, v) + KeyTable.setEvalContext(ec, k, v, nKeys) Filter.keepThis(f(), keep) } diff --git a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala index ec7878a4d91..3232e20c2ef 100644 --- a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala +++ b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala @@ -1,7 +1,7 @@ package org.broadinstitute.hail.driver import org.broadinstitute.hail.SparkSuite -import org.broadinstitute.hail.annotations.Annotation +import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.check.Prop._ import org.broadinstitute.hail.check.{Gen, Properties} import org.broadinstitute.hail.expr.{TString, TStruct} @@ -12,35 +12,28 @@ class AggregateByKeySuite extends SparkSuite { @Test def replicateSampleAggregation() = { val inputVCF = "src/test/resources/sample.vcf" var s = State(sc, sqlContext) - s = ImportVCF.run(s, Array("-i", inputVCF)) + s = ImportVCF.run(s, Array(inputVCF)) s = AnnotateSamplesExpr.run(s, Array("-c", "sa.nHet = gs.filter(g => g.isHet).count()")) s = AggregateByKey.run(s, Array("-k", "Sample = s", "-a", "nHet = gs.filter(g => g.isHet).count()", "-n", "kt")) val kt = s.ktEnv("kt") val vds = s.vds - println(kt.signature.schema) + val (_, ktHetQuery) = kt.query("nHet") val (_, ktSampleQuery) = kt.query("Sample") val (_, saHetQuery) = vds.querySA("sa.nHet") val 
ktSampleResults = kt.rdd.map { case (k, v) => - (ktSampleQuery(k, v).get.asInstanceOf[String], ktHetQuery(k, v)) + (ktSampleQuery(k, v).map(_.asInstanceOf[String]), ktHetQuery(k, v).map(_.asInstanceOf[Long])) }.collectAsMap() - ktSampleResults.head._2.foreach(x => println(x.getClass.getName)) - - assert( vds.sampleIdsAndAnnotations.forall{ case (sid, sa) => - (saHetQuery(sa), ktSampleResults(sid)) match { - case (None, None) => true - case (Some(x), Some(y)) => x.asInstanceOf[Long] == y.asInstanceOf[String].toLong - case _ => false - } - }) + + assert( vds.sampleIdsAndAnnotations.forall{ case (sid, sa) => saHetQuery(sa) == ktSampleResults(Option(sid))}) } @Test def replicateVariantAggregation() = { val inputVCF = "src/test/resources/sample.vcf" var s = State(sc, sqlContext) - s = ImportVCF.run(s, Array("-i", inputVCF)) + s = ImportVCF.run(s, Array(inputVCF)) s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()")) s = AggregateByKey.run(s, Array("-k", "Variant = v", "-a", "nHet = gs.filter(g => g.isHet).count()", "-n", "kt")) @@ -52,36 +45,31 @@ class AggregateByKeySuite extends SparkSuite { val (_, vaHetQuery) = vds.queryVA("va.nHet") val ktVariantResults = kt.rdd.map { case (k, v) => - (ktVariantQuery(k, v).get.asInstanceOf[String], ktHetQuery(k, v)) + (ktVariantQuery(k, v).map(_.asInstanceOf[Variant]), ktHetQuery(k, v).map(_.asInstanceOf[Long])) }.collectAsMap() - assert( vds.variantsAndAnnotations.map{ case (v, va) => - (vaHetQuery(va), ktVariantResults(v.toString)) match { - case (None, None) => true - case (Some(x), Some(y)) => x.asInstanceOf[Long] == y.asInstanceOf[String].toLong - case _ => false - } - }.fold(true)(_ && _)) + assert( vds.variantsAndAnnotations.forall{ case (v, va) => vaHetQuery(va) == ktVariantResults(Option(v))}) } -// @Test def replicateGlobalAggregation() = { -// val inputVCF = "src/test/resources/sample.vcf" -// var s = State(sc, sqlContext) -// s = ImportVCF.run(s, Array("-i", inputVCF)) -// s = 
AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()")) -// s = AnnotateGlobalExpr.run(s, Array("-c", "global.nHet = variants.map(v => va.nHet).sum()")) -// s = AggregateByKey.run(s, Array("-k", " ", "-a", "nHet = gs.filter(g => g.isHet).count()", "-n", "kt")) -// -// val kt = s.ktEnv("kt") -// val vds = s.vds -// -// val (_, ktHetQuery) = kt.query("nHet") -// val (_, globalHet) = vds.queryGlobal("global.nHet") -// -// val ktGlobalResults = kt.rdd.map{ a => ktHetQuery(a).asInstanceOf[String].toLong}.collect() -// assert(ktGlobalResults.length == 1 && globalHet.get.asInstanceOf[Long] == ktGlobalResults(0)) -// } + @Test def replicateGlobalAggregation() = { + val inputVCF = "src/test/resources/sample.vcf" + var s = State(sc, sqlContext) + s = ImportVCF.run(s, Array(inputVCF)) + s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()")) + s = AnnotateGlobalExpr.run(s, Array("-c", "global.nHet = variants.map(v => va.nHet).sum().toLong")) + s = AggregateByKey.run(s, Array("-a", "nHet = gs.filter(g => g.isHet).count()", "-n", "kt")) + + val kt = s.ktEnv("kt") + val vds = s.vds + val (_, ktHetQuery) = kt.query("nHet") + val (_, globalHetResult) = vds.queryGlobal("global.nHet") + + val ktGlobalResult = kt.rdd.map{ case (k, v) => ktHetQuery(k, v).map(_.asInstanceOf[Long])}.collect().head + val vdsGlobalResult = globalHetResult.map(_.asInstanceOf[Long]) + + assert( ktGlobalResult == vdsGlobalResult ) + } // def createKey(nItems: Int, nCategories: Int) = // Gen.buildableOfN[Array, Option[String]](nItems, Gen.option(Gen.oneOfSeq((0 until nCategories).map("group" + _)), 0.95)) From 40403d08405b50a87c09983c6baead8439c0c011 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Thu, 3 Nov 2016 13:53:18 -0400 Subject: [PATCH 25/51] added to funct registry and annotate insert values working --- .../hail/driver/AnnotateKeyTableExpr.scala | 76 ++++++++++++++++ .../org/broadinstitute/hail/expr/Parser.scala | 1 - 
.../hail/keytable/KeyTable.scala | 4 + .../hail/driver/AggregateByKeySuite.scala | 90 +------------------ .../hail/methods/KeyTableSuite.scala | 35 ++++++++ 5 files changed, 116 insertions(+), 90 deletions(-) create mode 100644 src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala new file mode 100644 index 00000000000..98f8ef6d2c5 --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala @@ -0,0 +1,76 @@ +package org.broadinstitute.hail.driver + +import org.broadinstitute.hail.annotations._ +import org.broadinstitute.hail.expr.{EvalContext, Parser, TStruct} +import org.broadinstitute.hail.utils._ +import org.kohsuke.args4j.{Option => Args4jOption} +import org.broadinstitute.hail.keytable.KeyTable + +import scala.collection.mutable + +object AnnotateKeyTableExpr extends Command { + class Options extends BaseOptions { + @Args4jOption(required = true, name = "-c", aliases = Array("--cond"), + usage = "Boolean expression for annotating", metaVar = "EXPR") + var condition: String = _ + + @Args4jOption(required = true, name = "-n", aliases = Array("--name"), + usage = "Name of source key table") + var name: String = _ + + @Args4jOption(required = false, name = "-d", aliases = Array("--dest"), + usage = "Name of destination key table (can be same as source)") + var dest: String = _ + } + + def newOptions = new Options + + def name = "annotatekeytable expr" + + def description = "Annotate key table using an expression" + + def supportsMultiallelic = true + + def requiresVDS = false + + override def hidden = true + + def run(state: State, options: Options): State = { + val cond = options.condition + val dest = options.dest + + val kt = state.ktEnv.get(options.name) match { + case Some(newKT) => + newKT + case None => + fatal("no such key table $name in 
environment") + } + + val symTab = kt.fields.zipWithIndex.map{case (fd, i) => (fd.name, (i, fd.`type`))}.toMap + val ec = EvalContext(symTab) + + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, None) + + val inserterBuilder = mutable.ArrayBuilder.make[Inserter] + + val finalValueSignature = parseTypes.foldLeft(kt.valueSignature) { case (vs, (ids, signature)) => + val (s: TStruct, i) = vs.insert(signature, ids) + inserterBuilder += i + s + } + + val inserters = inserterBuilder.result() + val nKeys = kt.nKeys + + val annotated = kt.mapAnnotations{ case (k, v) => + KeyTable.setEvalContext(ec, k, v, nKeys) + + fns.zip(inserters) + .foldLeft(v) { case (va, (fn, inserter)) => + inserter(va, Option(fn())) + } + }.copy(valueSignature = finalValueSignature) + + state.copy(ktEnv = state.ktEnv + ( dest -> annotated)) + } +} diff --git a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala index 2ce6839c490..dbe3d1f1e23 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala @@ -181,7 +181,6 @@ object Parser extends JavaTokenParsers { val t = checkType(path, ast.`type`) val f = ast.eval(ec) val name = if (expectedHead.isDefined) path.tail else path -// ((name, t), () => Option(f())) ((name, t), () => f()) } diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 104ae18e42f..6a42eb72ff1 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -5,6 +5,7 @@ import org.apache.spark.sql.Row import org.broadinstitute.hail.annotations._ import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TStruct} import org.broadinstitute.hail.methods.Filter +import org.broadinstitute.hail.utils._ object KeyTable extends Serializable { def 
setEvalContext(ec: EvalContext, k: Annotation, v: Annotation, nKeys: Int) = { @@ -82,9 +83,12 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def rightJoin(other: KeyTable): KeyTable = ??? def outerJoin(other: KeyTable): KeyTable = ??? def innerJoin(other: KeyTable): KeyTable = ??? + // require(keyNames.toSet == other.keyNames.toSet) // function to make key order the same + def mapAnnotations(f: (Annotation, Annotation) => Annotation): KeyTable = + copy(rdd = rdd.mapValuesWithKey{ case (k, v) => f(k, v) }) def query(code: String): (BaseType, (Annotation, Annotation) => Option[Any]) = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) diff --git a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala index 3232e20c2ef..d5e38e6d939 100644 --- a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala +++ b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala @@ -2,10 +2,7 @@ package org.broadinstitute.hail.driver import org.broadinstitute.hail.SparkSuite import org.broadinstitute.hail.utils._ -import org.broadinstitute.hail.check.Prop._ -import org.broadinstitute.hail.check.{Gen, Properties} -import org.broadinstitute.hail.expr.{TString, TStruct} -import org.broadinstitute.hail.variant.{Genotype, VSMSubgen, VariantSampleMatrix, _} +import org.broadinstitute.hail.variant._ import org.testng.annotations.Test class AggregateByKeySuite extends SparkSuite { @@ -70,89 +67,4 @@ class AggregateByKeySuite extends SparkSuite { assert( ktGlobalResult == vdsGlobalResult ) } - -// def createKey(nItems: Int, nCategories: Int) = -// Gen.buildableOfN[Array, Option[String]](nItems, Gen.option(Gen.oneOfSeq((0 until nCategories).map("group" + _)), 0.95)) -// -// def createKeys(nKeys: Int, nItems: Int) = -// Gen.buildableOfN[Array, Array[Option[String]]](nKeys, createKey(nItems, Gen.choose(1, 10).sample())) -// -// 
object Spec extends Properties("CreateKeyTable") { -// val compGen = for (vds: VariantDataset <- VariantSampleMatrix.gen[Genotype](sc, VSMSubgen.random).filter(vds => vds.nVariants > 0 && vds.nSamples > 0); -// nKeys <- Gen.choose(1, 5); -// nSampleKeys <- Gen.choose(0, nKeys); -// nVariantKeys <- Gen.const(nKeys - nSampleKeys); -// sampleGroups <- createKeys(nSampleKeys, vds.nSamples); -// variantGroups <- createKeys(nVariantKeys, vds.nVariants.toInt) -// ) yield (vds, sampleGroups, variantGroups) -// -// property("aggregate by sample and variants same") = forAll(compGen) { case (vds, sampleGroups, variantGroups) => -// val outputFile = tmpDir.createTempFile("keyTableTest", "tsv") -// -// val nKeys = sampleGroups.length + variantGroups.length -// val keyNames = (1 to nKeys).map("key" + _) -// val sampleKeyNames = (1 to sampleGroups.length).map("key" + _) -// val variantKeyNames = (sampleGroups.length + 1 to nKeys).map("key" + _) -// -// var sampleSignature = TStruct() -// sampleKeyNames.foreach(k => sampleSignature = sampleSignature.appendKey(k, TString)) -// -// var variantSignature = TStruct() -// variantKeyNames.foreach(k => variantSignature = variantSignature.appendKey(k, TString)) -// -// val sampleMap = vds.sampleIds.zipWithIndex.map { case (sid, i) => -// (sid, Annotation(sampleGroups.map(_ (i).orNull).toSeq: _*)) -// }.toMap -// -// val variantAnnotations = sc.parallelize(vds.variants.collect().zipWithIndex.map { case (v, i) => -// (v, Annotation(variantGroups.map(_ (i).orNull).toSeq: _*)) -// }).toOrderedRDD -// -// var s = State(sc, sqlContext, vds.annotateSamples(sampleMap, sampleSignature, "sa.keys") -// .annotateVariants(variantAnnotations, variantSignature, "va.keys")) -// -// val (_, sampleKeyQuery) = s.vds.querySA("sa.keys.*") -// val (_, variantKeyQuery) = s.vds.queryVA("va.keys.*") -// -// val keyGenotypeRDD = s.vds.mapWithAll { case (v, va, sid, sa, g) => -// val key = sampleKeyQuery(sa).get.asInstanceOf[IndexedSeq[String]] ++ 
variantKeyQuery(va).get.asInstanceOf[IndexedSeq[String]] -// (key, g) -// } -// -// val result = keyGenotypeRDD.aggregateByKey((0L, 0L, 0L))( -// (comb, gt) => (comb._1 + gt.isHet.toInt.toInt, comb._2 + gt.isCalled.toInt.toInt, comb._3 + 1), -// (comb1, comb2) => (comb1._1 + comb2._1, comb1._2 + comb2._2, comb1._3 + comb2._3)).collectAsMap() -// -// s = AggregateByKey.run(s, Array("-k", (sampleKeyNames.map(k => k + " = " + "sa.keys." + k) ++ variantKeyNames.map(k => k + " = " + "va.keys." + k)).mkString(","), -// "-a", "nHet = gs.filter(g => g.isHet).count(), nCalled = gs.filter(g => g.isCalled).count(), nTotal = gs.count()", "-n", "foo")) -// -// s = ExportKeyTable.run(s, Array("-o", outputFile, "-n", "foo")) -// println(outputFile) -// -// val ktr = hadoopConf.readLines(outputFile)(_.map(_.map { line => -// line.trim.split("\\s+") -// }.value).toIndexedSeq) -// -// val header = ktr.take(1) -// -// val keyTableResults = ktr.drop(1).map(r => (header(0), r).zipped.toMap) -// .map { x => (keyNames.map { k => x(k) }, (x("nHet"), x("nCalled"), x("nTotal"))) }.toMap -// -// result.forall { case (keys, (nHet, nCalled, nTotal)) => -// val (ktHet, ktCalled, ktTotal) = keyTableResults(keys.map(k => if (k != null) k else "NA").toIndexedSeq) -// ktHet.toLong == nHet && -// ktCalled.toLong == nCalled && -// ktTotal.toLong == nTotal -// } -// } -// } -// // -// // @Test def testAddKeyTable() { -// // Spec.check() -// // } -// -// -// @Test def test() { -// -// } } diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index 46099a6c8e8..51f5bffcc5a 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -27,6 +27,41 @@ class KeyTableSuite extends SparkSuite { assert(importedData == exportedData) } + @Test def testAnnotate() = { + val inputFile = "src/test/resources/sampleAnnotations.tsv" 
+ var s = State(sc, sqlContext) + s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample", inputFile)) + s = AnnotateKeyTableExpr.run(s, Array("-n", "kt1", "-d", "kt2", "-c", "RandomBool = pcoin(0.4), RandomQP = rnorm(0, 1), RandomNum = runif(0, 1)")) + + val kt1 = s.ktEnv("kt1") + val kt2 = s.ktEnv("kt2") + + val kt1ValueNames = kt1.valueNames.toSet + val kt2ValueNames = kt2.valueNames.toSet + + assert(kt1.nKeys == kt2.nKeys && + kt1.nValues == 2 && kt2.nValues == 5 && + kt1.keySignature == kt2.keySignature && + kt1ValueNames ++ Set("RandomBool", "RandomQP", "RandomNum") == kt2ValueNames + ) + } + + @Test def testLeftJoin() = { + + } + + @Test def testRightJoin() = { + + } + + @Test def testInnerJoin() = { + + } + + @Test def testOuterJoin() = { + + } + @Test def testFilter() = { val data = Array(Array(5, 9, 0), Array(2, 3, 4), Array(1, 2, 3)) val rdd = sc.parallelize(data.map(Annotation.fromSeq(_))) From 219f36d523c8439659e71cacf932dd935a891612 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Thu, 3 Nov 2016 14:23:54 -0400 Subject: [PATCH 26/51] annotate working with only adding values --- .../hail/driver/AnnotateKeyTableExpr.scala | 14 ++++++--- .../hail/keytable/KeyTable.scala | 3 ++ .../hail/methods/KeyTableSuite.scala | 31 +++++++++---------- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala index 98f8ef6d2c5..2c84e355be3 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala @@ -10,10 +10,6 @@ import scala.collection.mutable object AnnotateKeyTableExpr extends Command { class Options extends BaseOptions { - @Args4jOption(required = true, name = "-c", aliases = Array("--cond"), - usage = "Boolean expression for annotating", metaVar = "EXPR") - var condition: String = _ - 
@Args4jOption(required = true, name = "-n", aliases = Array("--name"), usage = "Name of source key table") var name: String = _ @@ -21,6 +17,14 @@ object AnnotateKeyTableExpr extends Command { @Args4jOption(required = false, name = "-d", aliases = Array("--dest"), usage = "Name of destination key table (can be same as source)") var dest: String = _ + + @Args4jOption(required = true, name = "-c", aliases = Array("--cond"), + usage = "Expression for annotating", metaVar = "EXPR") + var condition: String = _ + + @Args4jOption(required = false, name = "-k", aliases = Array("--key-names"), + usage = "Names of key in new table", metaVar = "EXPR") + var keyNames: String = _ } def newOptions = new Options @@ -37,7 +41,7 @@ object AnnotateKeyTableExpr extends Command { def run(state: State, options: Options): State = { val cond = options.condition - val dest = options.dest + val dest = if (options.dest != null) options.dest else options.name val kt = state.ktEnv.get(options.name) match { case Some(newKT) => diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 6a42eb72ff1..0228edc04eb 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -87,6 +87,9 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v // require(keyNames.toSet == other.keyNames.toSet) // function to make key order the same +// def mapAnnotations(f: (Annotation) => Annotation): KeyTable = +// copy(rdd = KeyTable.toSingleRDD(rdd).map{ a => f(a)}) + def mapAnnotations(f: (Annotation, Annotation) => Annotation): KeyTable = copy(rdd = rdd.mapValuesWithKey{ case (k, v) => f(k, v) }) diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index 51f5bffcc5a..73e1b673ca3 100644 --- 
a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -46,22 +46,6 @@ class KeyTableSuite extends SparkSuite { ) } - @Test def testLeftJoin() = { - - } - - @Test def testRightJoin() = { - - } - - @Test def testInnerJoin() = { - - } - - @Test def testOuterJoin() = { - - } - @Test def testFilter() = { val data = Array(Array(5, 9, 0), Array(2, 3, 4), Array(1, 2, 3)) val rdd = sc.parallelize(data.map(Annotation.fromSeq(_))) @@ -82,6 +66,21 @@ class KeyTableSuite extends SparkSuite { s = FilterKeyTableExpr.run(s, Array("-n", "kt1", "-c", "field1 < -5 && field3 == 100", "--keep")) assert(s.ktEnv.contains("kt1") && s.ktEnv("kt1").nRows == 0) + } + + @Test def testLeftJoin() = { + + } + + @Test def testRightJoin() = { + + } + + @Test def testInnerJoin() = { + + } + + @Test def testOuterJoin() = { } } From 85b75f6ab72b61ff110e27fe83450e80f219748b Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Thu, 3 Nov 2016 18:34:54 -0400 Subject: [PATCH 27/51] import, export, aggByKey, filter, annotate done --- .../hail/driver/AnnotateKeyTable.scala | 9 ++ .../hail/driver/AnnotateKeyTableExpr.scala | 42 +++--- .../broadinstitute/hail/driver/Command.scala | 1 + .../hail/driver/ExportKeyTable.scala | 14 +- .../hail/driver/FilterKeyTableExpr.scala | 16 ++- .../hail/driver/ImportKeyTable.scala | 2 +- .../org/broadinstitute/hail/expr/AST.scala | 2 - .../org/broadinstitute/hail/expr/Type.scala | 2 +- .../hail/keytable/KeyTable.scala | 123 +++++++++++------- .../hail/methods/KeyTableSuite.scala | 43 +++++- 10 files changed, 168 insertions(+), 86 deletions(-) create mode 100644 src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTable.scala diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTable.scala new file mode 100644 index 00000000000..899b0f12ded --- /dev/null +++ 
b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTable.scala @@ -0,0 +1,9 @@ +package org.broadinstitute.hail.driver + +object AnnotateKeyTable extends SuperCommand { + def name = "annotatekeytable" + + def description = "Annotate key tables" + + register(AnnotateKeyTableExpr) +} diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala index 2c84e355be3..f1655397849 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala @@ -1,7 +1,7 @@ package org.broadinstitute.hail.driver import org.broadinstitute.hail.annotations._ -import org.broadinstitute.hail.expr.{EvalContext, Parser, TStruct} +import org.broadinstitute.hail.expr.{EvalContext, Parser, TStruct, Type} import org.broadinstitute.hail.utils._ import org.kohsuke.args4j.{Option => Args4jOption} import org.broadinstitute.hail.keytable.KeyTable @@ -9,6 +9,7 @@ import org.broadinstitute.hail.keytable.KeyTable import scala.collection.mutable object AnnotateKeyTableExpr extends Command { + class Options extends BaseOptions { @Args4jOption(required = true, name = "-n", aliases = Array("--name"), usage = "Name of source key table") @@ -18,12 +19,12 @@ object AnnotateKeyTableExpr extends Command { usage = "Name of destination key table (can be same as source)") var dest: String = _ - @Args4jOption(required = true, name = "-c", aliases = Array("--cond"), - usage = "Expression for annotating", metaVar = "EXPR") + @Args4jOption(required = false, name = "-c", aliases = Array("--cond"), + usage = "Named expression for adding fields to the table", metaVar = "EXPR") var condition: String = _ @Args4jOption(required = false, name = "-k", aliases = Array("--key-names"), - usage = "Names of key in new table", metaVar = "EXPR") + usage = "Names of key in new table (default is existing key names)", metaVar = 
"EXPR") var keyNames: String = _ } @@ -41,40 +42,47 @@ object AnnotateKeyTableExpr extends Command { def run(state: State, options: Options): State = { val cond = options.condition - val dest = if (options.dest != null) options.dest else options.name + val name = options.name + val dest = if (options.dest != null) options.dest else name - val kt = state.ktEnv.get(options.name) match { + val kt = state.ktEnv.get(name) match { case Some(newKT) => newKT case None => fatal("no such key table $name in environment") } - val symTab = kt.fields.zipWithIndex.map{case (fd, i) => (fd.name, (i, fd.`type`))}.toMap - val ec = EvalContext(symTab) + val ec = EvalContext(kt.fields.map(fd => (fd.name, fd.`type`)): _*) - val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, None) + val (parseTypes, fns) = + if (cond != null) + Parser.parseAnnotationArgs(cond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] - val finalValueSignature = parseTypes.foldLeft(kt.valueSignature) { case (vs, (ids, signature)) => + val finalSignature = parseTypes.foldLeft(kt.signature) { case (vs, (ids, signature)) => val (s: TStruct, i) = vs.insert(signature, ids) inserterBuilder += i s } val inserters = inserterBuilder.result() - val nKeys = kt.nKeys - val annotated = kt.mapAnnotations{ case (k, v) => - KeyTable.setEvalContext(ec, k, v, nKeys) + val keyNames = if (options.keyNames != null) Parser.parseIdentifierList(options.keyNames) else kt.keyNames.toArray + + val nFields = kt.nFields + + val f: Annotation => Annotation = { a => + KeyTable.setEvalContext(ec, a, nFields) fns.zip(inserters) - .foldLeft(v) { case (va, (fn, inserter)) => - inserter(va, Option(fn())) + .foldLeft(a) { case (a1, (fn, inserter)) => + inserter(a1, Option(fn())) } - }.copy(valueSignature = finalValueSignature) + } - state.copy(ktEnv = state.ktEnv + ( dest -> annotated)) + state.copy(ktEnv = state.ktEnv + (dest -> kt.mapAnnotations(f, 
finalSignature, keyNames))) } } diff --git a/src/main/scala/org/broadinstitute/hail/driver/Command.scala b/src/main/scala/org/broadinstitute/hail/driver/Command.scala index 6b0e3545fa2..a462e67a113 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/Command.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/Command.scala @@ -49,6 +49,7 @@ object ToplevelCommands { register(AggregateByKey) register(AggregateIntervals) + register(AnnotateKeyTable) register(AnnotateSamples) register(AnnotateVariants) register(AnnotateGlobal) diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala index 0fa1e9f79d9..4bb665aad35 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala @@ -38,18 +38,17 @@ object ExportKeyTable extends Command with TextExporter { def run(state: State, options: Options): State = { - val kt = state.ktEnv.get(options.name) match { + val name = options.name + val output = options.output + + val kt = state.ktEnv.get(name) match { case Some(newKT) => newKT case None => fatal("no such key table $name in environment") } - val output = options.output - - val symTab = kt.fields.zipWithIndex.map{case (fd, i) => (fd.name, (i, fd.`type`))}.toMap - - val ec = EvalContext(symTab) + val ec = EvalContext(kt.fields.map(fd => (fd.name, fd.`type`)): _*) val (header, types, f) = Parser.parseNamedArgs(kt.fieldNames.map(n => n + " = " + n).mkString(","), ec) @@ -63,13 +62,14 @@ object ExportKeyTable extends Command with TextExporter { state.hadoopConf.delete(output, recursive = true) val nKeys = kt.nKeys + val nValues = kt.nValues kt.rdd .mapPartitions { it => val sb = new StringBuilder() it.map { case (k, v) => sb.clear() - KeyTable.setEvalContext(ec, k, v, nKeys) + KeyTable.setEvalContext(ec, k, v, nKeys, nValues) f().foreachBetween(x => sb.append(x))(sb += '\t') sb.result() } 
diff --git a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala index 4df9ad68011..1965b2229a5 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala @@ -4,9 +4,10 @@ import org.broadinstitute.hail.utils._ import org.kohsuke.args4j.{Option => Args4jOption} object FilterKeyTableExpr extends Command { + class Options extends BaseOptions { @Args4jOption(required = true, name = "-c", aliases = Array("--cond"), - usage = "Boolean expression for filtering", metaVar = "EXPR") + usage = "Boolean expression for filtering", metaVar = "EXPR") var condition: String = _ @Args4jOption(required = true, name = "-n", aliases = Array("--name"), @@ -37,7 +38,12 @@ object FilterKeyTableExpr extends Command { override def hidden = true def run(state: State, options: Options): State = { - val kt = state.ktEnv.get(options.name) match { + val name = options.name + val cond = options.condition + val keep = options.keep + val dest = if (options.dest != null) options.dest else name + + val kt = state.ktEnv.get(name) match { case Some(newKT) => newKT case None => @@ -47,10 +53,6 @@ object FilterKeyTableExpr extends Command { if (!(options.keep ^ options.remove)) fatal("either `--keep' or `--remove' required, but not both") - val cond = options.condition - val keep = options.keep - val dest = if (options.dest != null) options.dest else options.name - - state.copy(ktEnv = state.ktEnv + ( dest -> kt.filterExpr(cond, keep))) + state.copy(ktEnv = state.ktEnv + (dest -> kt.filterExpr(cond, keep))) } } diff --git a/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala index cfe13fa755f..a1a88d77e34 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala +++ 
b/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala @@ -52,7 +52,7 @@ object ImportKeyTable extends Command { } else TextTableReader.read(state.sc)(files, options.config) - val keyNamesValid = keyNames.forall{ k => + val keyNamesValid = keyNames.forall { k => val res = struct.selfField(k).isDefined if (!res) println(s"Key `$k' is not present in input table") diff --git a/src/main/scala/org/broadinstitute/hail/expr/AST.scala b/src/main/scala/org/broadinstitute/hail/expr/AST.scala index 537e57eb9cd..08b199ff3bc 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/AST.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/AST.scala @@ -26,8 +26,6 @@ case class EvalContext(st: SymbolTable, a: ArrayBuffer[Any], aggregationFunction def set(index: Int, arg: Any) { a(index) = arg } - - def clear() = a.indices.foreach { i => a(i) = null } } object EvalContext { diff --git a/src/main/scala/org/broadinstitute/hail/expr/Type.scala b/src/main/scala/org/broadinstitute/hail/expr/Type.scala index e495bb0c429..6abd5d0260f 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Type.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Type.scala @@ -310,7 +310,7 @@ case class TArray(elementType: Type) extends TIterable { override def str(a: Annotation): String = JsonMethods.compact(toJSON(a)) - override def genValue: Gen[Annotation] = Gen.buildableOf[IndexedSeq, Annotation](elementType.genValue) + override def genValue: Gen[Annotation] = Gen.buildableOf[Array, Annotation](elementType.genValue).map(x => x: IndexedSeq[Annotation]) } case class TSet(elementType: Type) extends TIterable { diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 0228edc04eb..a55f1ff4a5d 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -1,32 +1,22 @@ package org.broadinstitute.hail.keytable 
+import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.broadinstitute.hail.annotations._ -import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TStruct} +import org.broadinstitute.hail.check.Gen +import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TStruct, Type} import org.broadinstitute.hail.methods.Filter import org.broadinstitute.hail.utils._ object KeyTable extends Serializable { - def setEvalContext(ec: EvalContext, k: Annotation, v: Annotation, nKeys: Int) = { - (Option(k).map(_.asInstanceOf[Row]), Option(v).map(_.asInstanceOf[Row])) match { - case (Some(kr), Some(vr)) => ec.setAll(kr.toSeq ++ vr.toSeq: _*) - case (Some(kr), None) => - ec.clear() - ec.setAll(kr.toSeq: _*) - case (None, Some(vr)) => - ec.clear() - vr.toSeq.zipWithIndex.foreach{ case (a, i) => ec.set(i + nKeys, a)} - case (None, None) => ec.clear() - } - } + def annotationToSeq(a: Annotation, nFields: Int) = Option(a).map(_.asInstanceOf[Row].toSeq).getOrElse(Seq.fill[Any](nFields)(null)) - def setEvalContext(ec: EvalContext, a: Annotation) = { - Option(a).map(_.asInstanceOf[Row]) match { - case Some(r) => ec.setAll(r.toSeq: _*) - case _ => ec.clear() - } - } + def setEvalContext(ec: EvalContext, k: Annotation, v: Annotation, nKeys: Int, nValues: Int) = + ec.setAll(annotationToSeq(k, nKeys) ++ annotationToSeq(v, nValues): _*) + + def setEvalContext(ec: EvalContext, a: Annotation, nFields: Int) = + ec.setAll(annotationToSeq(a, nFields): _*) def pairSignature(signature: TStruct, keyNames: Array[String]): (TStruct, TStruct) = { val keyNameSet = keyNames.toSet @@ -37,16 +27,20 @@ object KeyTable extends Serializable { def singleSignature(keySignature: TStruct, valueSignature: TStruct): (TStruct, Array[String]) = (TStruct(keySignature.fields ++ valueSignature.fields), keySignature.fields.map(_.name).toArray) - def toSingleRDD(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, 
valueSignature: TStruct): RDD[Annotation] = - rdd.map{ case (k, v) => Annotation(Option(k).map(_.asInstanceOf[Row]).toSeq ++ Option(v).map(_.asInstanceOf[Row]).toSeq: _*) } + def toSingleRDD(rdd: RDD[(Annotation, Annotation)], nKeys: Int, nValues: Int): RDD[Annotation] = + rdd.map{ case (k, v) => + val x = Annotation.fromSeq(annotationToSeq(k, nKeys) ++ annotationToSeq(v, nValues)) + x + } def toPairRDD(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[String]): RDD[(Annotation, Annotation)] = { val keyNameSet = keyNames.toSet val keyIndices = signature.fields.filter(fd => keyNames.contains(fd.name)).map(_.index).toSet val valueIndices = signature.fields.filterNot(fd => keyNames.contains(fd.name)).map(_.index).toSet + val nFields = signature.size rdd.map { a => - val r = Option(a).map(_.asInstanceOf[Row].toSeq).getOrElse(Seq.fill(signature.size)(null)).zipWithIndex + val r = annotationToSeq(a, nFields).zipWithIndex val keyRow = r.filter{ case (ann, i) => keyIndices.contains(i) }.map(_._1) val valueRow = r.filter{ case (ann, i) => valueIndices.contains(i) }.map(_._1) (Annotation.fromSeq(keyRow), Annotation.fromSeq(valueRow)) @@ -59,11 +53,14 @@ object KeyTable extends Serializable { } } + + case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, valueSignature: TStruct) { require(fieldNames.toSet.size == fieldNames.length) def signature = KeyTable.singleSignature(keySignature, valueSignature)._1 + def fields = signature.fields def keySchema = keySignature.schema @@ -79,19 +76,40 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def nKeys = keySignature.size def nValues = valueSignature.size - def leftJoin(other: KeyTable): KeyTable = ??? - def rightJoin(other: KeyTable): KeyTable = ??? - def outerJoin(other: KeyTable): KeyTable = ??? - def innerJoin(other: KeyTable): KeyTable = ??? 
- -// require(keyNames.toSet == other.keyNames.toSet) - // function to make key order the same + def same(other: KeyTable): Boolean = { + if (fields.toSet != other.fields.toSet) { + println(s"signature: this=${ schema } other=${ other.schema }") + false + } else if (keyNames.toSet != other.keyNames.toSet) { + println(s"keyNames: this=${ keyNames.mkString(",") } other=${ other.keyNames.mkString(",")}") + false + } else { + val thisFieldNames = valueNames + val otherFieldNames = other.valueNames + + rdd.groupByKey().fullOuterJoin(other.rdd.groupByKey()).forall { case (k, (v1, v2)) => + (v1, v2) match { + case (None, None) => true + case (Some(x), Some(y)) => + val r1 = x.map(r => thisFieldNames.zip(r.asInstanceOf[Row].toSeq).toMap).toSet + val r2 = y.map(r => otherFieldNames.zip(r.asInstanceOf[Row].toSeq).toMap).toSet + val res = r1 == r2 + if (!res) + println(s"k=$k r1=${r1.mkString(",")} r2=${r2.mkString(",")}") + res + case _ => + println(s"k=$k v1=$v1 v2=$v2") + false + } + } + } + } -// def mapAnnotations(f: (Annotation) => Annotation): KeyTable = -// copy(rdd = KeyTable.toSingleRDD(rdd).map{ a => f(a)}) + def mapAnnotations(f: (Annotation) => Annotation, newSignature: TStruct, newKeyNames: Array[String]): KeyTable = + KeyTable(KeyTable.toSingleRDD(rdd, nKeys, nValues).map(a => f(a)), newSignature, newKeyNames) - def mapAnnotations(f: (Annotation, Annotation) => Annotation): KeyTable = - copy(rdd = rdd.mapValuesWithKey{ case (k, v) => f(k, v) }) + def mapAnnotations(f: (Annotation, Annotation) => Annotation, newValueSignature: TStruct): KeyTable = + copy(rdd = rdd.mapValuesWithKey{ case (k, v) => f(k, v) }, valueSignature = newValueSignature) def query(code: String): (BaseType, (Annotation, Annotation) => Option[Any]) = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) @@ -99,25 +117,25 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val (t, f) = Parser.parse(code, ec) val f2: (Annotation, Annotation) => 
Option[Any] = { case (k, v) => - KeyTable.setEvalContext(ec, k, v, nKeys) + KeyTable.setEvalContext(ec, k, v, nKeys, nValues) f() } (t, f2) } -// def query(code: String): (BaseType, Querier) = { -// val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) -// -// val (t, f) = Parser.parse(code, ec) -// -// val f2: (Annotation) => Option[Any] = { a => -// KeyTable.setEvalContext(ec, a) -// f() -// } -// -// (t, f2) -// } + def querySingle(code: String): (BaseType, Querier) = { + val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + + val (t, f) = Parser.parse(code, ec) + + val f2: (Annotation) => Option[Any] = { a => + KeyTable.setEvalContext(ec, a, nFields) + f() + } + + (t, f2) + } def filter(p: (Annotation, Annotation) => Boolean): KeyTable = copy(rdd = rdd.filter { case (k, v) => p(k, v) }) @@ -127,10 +145,19 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val f: () => Option[Boolean] = Parser.parse[Boolean](cond, ec, TBoolean) val p = (k: Annotation, v: Annotation) => { - KeyTable.setEvalContext(ec, k, v, nKeys) + KeyTable.setEvalContext(ec, k, v, nKeys, nValues) Filter.keepThis(f(), keep) } filter(p) } -} + + + def leftJoin(other: KeyTable): KeyTable = ??? + def rightJoin(other: KeyTable): KeyTable = ??? + def outerJoin(other: KeyTable): KeyTable = ??? + def innerJoin(other: KeyTable): KeyTable = ??? 
+ + // require(keyNames.toSet == other.keyNames.toSet) + // function to make key order the same +} \ No newline at end of file diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index 73e1b673ca3..404cf7135c1 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -10,6 +10,21 @@ import org.testng.annotations.Test class KeyTableSuite extends SparkSuite { + @Test def testSingleToPairRDD() = { + val inputFile = "src/test/resources/sampleAnnotations.tsv" + var s = State(sc, sqlContext) + s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample, Status", inputFile)) + val kt = s.ktEnv("kt1") + val kt2 = KeyTable.toPairRDD(KeyTable.toSingleRDD(kt.rdd, kt.nKeys, kt.nValues), kt.signature, kt.keyNames.toArray) + + assert(kt.rdd.fullOuterJoin(kt2).forall { case (k, (v1, v2)) => + val res = v1 == v2 + if (!res) + println(s"k=$k v1=$v1 v2=$v2 res=${ v1 == v2 }") + res + }) + } + @Test def testImportExport() = { val inputFile = "src/test/resources/sampleAnnotations.tsv" val outputFile = tmpDir.createTempFile("ktImpExp", "tsv") @@ -30,11 +45,16 @@ class KeyTableSuite extends SparkSuite { @Test def testAnnotate() = { val inputFile = "src/test/resources/sampleAnnotations.tsv" var s = State(sc, sqlContext) - s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample", inputFile)) - s = AnnotateKeyTableExpr.run(s, Array("-n", "kt1", "-d", "kt2", "-c", "RandomBool = pcoin(0.4), RandomQP = rnorm(0, 1), RandomNum = runif(0, 1)")) + + s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample", "--impute", inputFile)) + s = AnnotateKeyTableExpr.run(s, Array("-n", "kt1", "-d", "kt2", "-c", """qPhen2 = pow(qPhen, 2), NotStatus = Status == "CASE", X = qPhen == 5""")) + s = AnnotateKeyTableExpr.run(s, Array("-n", "kt2", "-d", "kt3")) + s = AnnotateKeyTableExpr.run(s, Array("-n", "kt3", "-d", 
"kt4", "-k", "qPhen, NotStatus")) val kt1 = s.ktEnv("kt1") val kt2 = s.ktEnv("kt2") + val kt3 = s.ktEnv("kt3") + val kt4 = s.ktEnv("kt4") val kt1ValueNames = kt1.valueNames.toSet val kt2ValueNames = kt2.valueNames.toSet @@ -42,7 +62,24 @@ class KeyTableSuite extends SparkSuite { assert(kt1.nKeys == kt2.nKeys && kt1.nValues == 2 && kt2.nValues == 5 && kt1.keySignature == kt2.keySignature && - kt1ValueNames ++ Set("RandomBool", "RandomQP", "RandomNum") == kt2ValueNames + kt1ValueNames ++ Set("qPhen2", "NotStatus", "X") == kt2ValueNames + ) + + assert(kt2 same kt3) + + def getDataAsMap(kt: KeyTable) = { + val fieldNames = kt.fieldNames + val nFields = kt.nFields + KeyTable.toSingleRDD(kt.rdd, kt.nKeys, kt.nValues) + .map { a => fieldNames.zip(KeyTable.annotationToSeq(a, nFields)).toMap }.collect().toSet + } + + val kt3data = getDataAsMap(kt3) + val kt4data = getDataAsMap(kt4) + + assert(kt4.keyNames.toSet == Set("qPhen", "NotStatus") && + kt4.valueNames.toSet == Set("qPhen2", "X", "Sample", "Status") && + kt3data == kt4data ) } From 5e8cb00e7527eac70110d13593c242901b2cbc84 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Thu, 3 Nov 2016 20:53:51 -0400 Subject: [PATCH 28/51] started join --- .../hail/driver/JoinKeyTable.scala | 48 +++++++++++++------ .../hail/keytable/KeyTable.scala | 11 +++-- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala index 22dc74c4619..d476508c421 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala @@ -45,28 +45,48 @@ object JoinKeyTable extends Command { def run(state: State, options: Options): State = { val ktEnv = state.ktEnv + val leftName = options.leftName + val rightName = options.rightName + val dest = options.destName - val ktLeft = ktEnv.get(options.leftName) match { - case Some(kt) => - kt 
- case None => - fatal("no such key table $name in environment") + if (options.joinKeys == null) + fatal("Must specify at least one join key name (eg: `Phenotype, ...'") + + val joinKeys = Parser.parseIdentifierList(options.joinKeys) + + val ktLeft = ktEnv.get(leftName) match { + case Some(kt) => kt + case None => fatal("no such key table $leftName in environment") } - val ktRight = ktEnv.get(options.rightName) match { - case Some(kt) => - kt - case None => - fatal("no such key table $name in environment") + val ktRight = ktEnv.get(rightName) match { + case Some(kt) => kt + case None => fatal("no such key table $rightName in environment") } - if (ktEnv.contains(options.destName)) + if (ktEnv.contains(dest)) warn("destination name already exists -- overwriting previous key-table") + val ktLeftFieldSet = ktLeft.fieldNames.toSet + val ktRightFieldSet = ktRight.fieldNames.toSet + + if (!joinKeys.forall(k => ktLeftFieldSet.contains(k)) || !joinKeys.forall(k => ktRightFieldSet.contains(k))) + fatal( + s"""Join keys not present in both key-tables. + |Keys found: ${ joinKeys.mkString(",") } + |Left KeyTable Schema: ${ ktLeft.schema } + |Right KeyTable Schema: ${ ktRight.schema } + """.stripMargin) + + val joinedKT = options.joinType match { + case "left" => ktLeft.leftJoin(ktRight, joinKeys) + case "right" => ktLeft.rightJoin(ktRight, joinKeys) + case "inner" => ktLeft.innerJoin(ktRight, joinKeys) + case "outer" => ktLeft.outerJoin(ktRight, joinKeys) + case _ => fatal("Did not recognize join type. 
Pick one of [left, right, inner, outer].") + } - - - state + state.copy(ktEnv = state.ktEnv + (dest -> joinedKT)) } } diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index a55f1ff4a5d..cd90b15d7b7 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -153,10 +153,13 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v } - def leftJoin(other: KeyTable): KeyTable = ??? - def rightJoin(other: KeyTable): KeyTable = ??? - def outerJoin(other: KeyTable): KeyTable = ??? - def innerJoin(other: KeyTable): KeyTable = ??? + def leftJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = { + keySignature.merge(valueSignature) + } + + def rightJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = ??? + def outerJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = ??? + def innerJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = ??? 
// require(keyNames.toSet == other.keyNames.toSet) // function to make key order the same From 1e8fbd3f610c1ab87dc37162aa6e7828b4136ed6 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Fri, 4 Nov 2016 14:46:14 -0400 Subject: [PATCH 29/51] 50% done with join --- .../hail/driver/JoinKeyTable.scala | 9 +- .../hail/keytable/KeyTable.scala | 92 +++++++++++++------ .../hail/methods/KeyTableSuite.scala | 34 +++++-- 3 files changed, 91 insertions(+), 44 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala index d476508c421..2807534e0b7 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala @@ -26,7 +26,7 @@ object JoinKeyTable extends Command { usage = "type of join") var joinType: String = "left" - @Args4jOption(required = true, name = "-t", aliases = Array("--join-keys"), + @Args4jOption(required = true, name = "-k", aliases = Array("--join-keys"), usage = "name of columns to join on") var joinKeys: String = _ } @@ -49,11 +49,6 @@ object JoinKeyTable extends Command { val rightName = options.rightName val dest = options.destName - if (options.joinKeys == null) - fatal("Must specify at least one join key name (eg: `Phenotype, ...'") - - val joinKeys = Parser.parseIdentifierList(options.joinKeys) - val ktLeft = ktEnv.get(leftName) match { case Some(kt) => kt case None => fatal("no such key table $leftName in environment") @@ -70,6 +65,8 @@ object JoinKeyTable extends Command { val ktLeftFieldSet = ktLeft.fieldNames.toSet val ktRightFieldSet = ktRight.fieldNames.toSet + val joinKeys = if (options.joinKeys == null) ktLeft.keyNames.toArray else Parser.parseIdentifierList(options.joinKeys) + if (!joinKeys.forall(k => ktLeftFieldSet.contains(k)) || !joinKeys.forall(k => ktRightFieldSet.contains(k))) fatal( s"""Join keys not present in both key-tables. 
diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index cd90b15d7b7..9ce4b88c08d 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -9,6 +9,8 @@ import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TS import org.broadinstitute.hail.methods.Filter import org.broadinstitute.hail.utils._ +case class Table(rdd: RDD[Annotation], signature: TStruct) + object KeyTable extends Serializable { def annotationToSeq(a: Annotation, nFields: Int) = Option(a).map(_.asInstanceOf[Row].toSeq).getOrElse(Seq.fill[Any](nFields)(null)) @@ -18,38 +20,34 @@ object KeyTable extends Serializable { def setEvalContext(ec: EvalContext, a: Annotation, nFields: Int) = ec.setAll(annotationToSeq(a, nFields): _*) - def pairSignature(signature: TStruct, keyNames: Array[String]): (TStruct, TStruct) = { - val keyNameSet = keyNames.toSet - (TStruct(signature.fields.filter(fd => keyNameSet.contains(fd.name))), - TStruct(signature.fields.filterNot(fd => keyNameSet.contains(fd.name)))) - } - - def singleSignature(keySignature: TStruct, valueSignature: TStruct): (TStruct, Array[String]) = - (TStruct(keySignature.fields ++ valueSignature.fields), keySignature.fields.map(_.name).toArray) - def toSingleRDD(rdd: RDD[(Annotation, Annotation)], nKeys: Int, nValues: Int): RDD[Annotation] = rdd.map{ case (k, v) => val x = Annotation.fromSeq(annotationToSeq(k, nKeys) ++ annotationToSeq(v, nValues)) x } - def toPairRDD(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[String]): RDD[(Annotation, Annotation)] = { - val keyNameSet = keyNames.toSet - val keyIndices = signature.fields.filter(fd => keyNames.contains(fd.name)).map(_.index).toSet - val valueIndices = signature.fields.filterNot(fd => keyNames.contains(fd.name)).map(_.index).toSet + def apply(rdd: RDD[Annotation], signature: TStruct, keyNames: 
Array[String]): KeyTable = { + val keyFields = signature.fields.filter(fd => keyNames.contains(fd.name)) + val keyIndices = keyFields.map(_.index) + + val valueFields = signature.fields.filterNot(fd => keyNames.contains(fd.name)) + val valueIndices = valueFields.map(_.index) + + assert(keyIndices.toSet.intersect(valueIndices.toSet).isEmpty) + val nFields = signature.size - rdd.map { a => + val newKeySignature = TStruct(keyFields.map(fd => (fd.name, fd.`type`)): _*) + val newValueSignature = TStruct(valueFields.map(fd => (fd.name, fd.`type`)): _*) + + val newRDD = rdd.map { a => val r = annotationToSeq(a, nFields).zipWithIndex - val keyRow = r.filter{ case (ann, i) => keyIndices.contains(i) }.map(_._1) - val valueRow = r.filter{ case (ann, i) => valueIndices.contains(i) }.map(_._1) + val keyRow = keyIndices.map( i => r(i)._1) + val valueRow = valueIndices.map( i => r(i)._1) (Annotation.fromSeq(keyRow), Annotation.fromSeq(valueRow)) } - } - def apply(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[String]): KeyTable = { - val (keySignature, valueSignature) = pairSignature(signature, keyNames) - KeyTable(toPairRDD(rdd, signature, keyNames), keySignature, valueSignature) + KeyTable(newRDD, newKeySignature, newValueSignature) } } @@ -59,8 +57,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v require(fieldNames.toSet.size == fieldNames.length) - def signature = KeyTable.singleSignature(keySignature, valueSignature)._1 - + def signature = keySignature.merge(valueSignature)._1 def fields = signature.fields def keySchema = keySignature.schema @@ -152,15 +149,54 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v filter(p) } + def changeKey(newKeyNames: Array[String]): KeyTable = KeyTable.apply(KeyTable.toSingleRDD(rdd, nKeys, nValues), signature, newKeyNames) def leftJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = { - keySignature.merge(valueSignature) + val ktL = changeKey(joinKeys) + 
val ktR = other.changeKey(joinKeys) + + require(ktL.keySignature == ktR.keySignature) + + val (newValueSignature, merger) = ktL.valueSignature.merge(ktR.valueSignature) + val newRDD = ktL.rdd.leftOuterJoin(ktR.rdd).map{ case (k, (vl, vr)) => (k, merger(vl, vr.orNull)) } + + KeyTable(newRDD, ktL.keySignature, newValueSignature) } - def rightJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = ??? - def outerJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = ??? - def innerJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = ??? + def rightJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = { + val ktL = changeKey(joinKeys) + val ktR = other.changeKey(joinKeys) + + require(ktL.keySignature == ktR.keySignature) + + val (newValueSignature, merger) = ktL.valueSignature.merge(ktR.valueSignature) + val newRDD = ktL.rdd.rightOuterJoin(ktR.rdd).map{ case (k, (vl, vr)) => (k, merger(vl.orNull, vr)) } + + KeyTable(newRDD, ktL.keySignature, newValueSignature) + } + + def outerJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = { + val ktL = changeKey(joinKeys) + val ktR = other.changeKey(joinKeys) + + require(ktL.keySignature == ktR.keySignature) + + val (newValueSignature, merger) = ktL.valueSignature.merge(ktR.valueSignature) + val newRDD = ktL.rdd.fullOuterJoin(ktR.rdd).map{ case (k, (vl, vr)) => (k, merger(vl.orNull, vr.orNull)) } + + KeyTable(newRDD, ktL.keySignature, newValueSignature) + } + + def innerJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = { + val ktL = changeKey(joinKeys) + val ktR = other.changeKey(joinKeys) + + require(ktL.keySignature == ktR.keySignature) + + val (newValueSignature, merger) = ktL.valueSignature.merge(ktR.valueSignature) + val newRDD = ktL.rdd.join(ktR.rdd).map{ case (k, (vl, vr)) => (k, merger(vl, vr)) } + + KeyTable(newRDD, ktL.keySignature, newValueSignature) + } - // require(keyNames.toSet == other.keyNames.toSet) - // function to make key order the same } \ No newline at end of file diff 
--git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index 404cf7135c1..f93b07dfe4f 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -15,9 +15,9 @@ class KeyTableSuite extends SparkSuite { var s = State(sc, sqlContext) s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample, Status", inputFile)) val kt = s.ktEnv("kt1") - val kt2 = KeyTable.toPairRDD(KeyTable.toSingleRDD(kt.rdd, kt.nKeys, kt.nValues), kt.signature, kt.keyNames.toArray) + val kt2 = kt.changeKey(kt.keyNames.toArray) - assert(kt.rdd.fullOuterJoin(kt2).forall { case (k, (v1, v2)) => + assert(kt.rdd.fullOuterJoin(kt2.rdd).forall { case (k, (v1, v2)) => val res = v1 == v2 if (!res) println(s"k=$k v1=$v1 v2=$v2 res=${ v1 == v2 }") @@ -105,19 +105,33 @@ class KeyTableSuite extends SparkSuite { assert(s.ktEnv.contains("kt1") && s.ktEnv("kt1").nRows == 0) } - @Test def testLeftJoin() = { - - } + @Test def testJoin() = { + val inputFile = "src/test/resources/sampleAnnotations.tsv" + val outputFile = tmpDir.createTempFile("ktImpExp", "tsv") + var s = State(sc, sqlContext) + s = ImportKeyTable.run(s, Array("-n", "ktLeft", "-k", "Sample", "--impute", inputFile)) + s = FilterKeyTableExpr.run(s, Array("-n", "ktLeft", "-c", """Status == "CASE"""", "--keep")) + s = AnnotateKeyTableExpr.run(s, Array("-n", "ktLeft", "-d", "ktRight", "-c", "FakeValue = qPhen * 3, FooVar = qPhen / 5")) - @Test def testRightJoin() = { + s = JoinKeyTable.run(s, Array("-l", "ktLeft", "-r", "ktRight", "-d", "ktLeftJoin", "-t", "left")) + s = JoinKeyTable.run(s, Array("-l", "ktLeft", "-r", "ktRight", "-d", "ktRightJoin", "-t", "right")) + s = JoinKeyTable.run(s, Array("-l", "ktLeft", "-r", "ktRight", "-d", "ktInnerJoin", "-t", "inner")) + s = JoinKeyTable.run(s, Array("-l", "ktLeft", "-r", "ktRight", "-d", "ktOuterJoin", "-t", "outer")) - } + 
val ktLeft = s.ktEnv("ktLeft") + val ktRight = s.ktEnv("ktRight") - @Test def testInnerJoin() = { + assert(!(ktLeft same ktRight)) - } + val ktLeftJoin = s.ktEnv("ktLeftJoin") + val ktRightJoin = s.ktEnv("ktRightJoin") + val ktInnerJoin = s.ktEnv("ktInnerJoin") + val ktOuterJoin = s.ktEnv("ktOuterJoin") - @Test def testOuterJoin() = { + assert(ktRightJoin same ktRight) + assert(ktLeftJoin.nRows == ktLeft.nRows && + ktLeftJoin.nKeys == ktLeft.nKeys && + ktLeftJoin.nFields == ktLeft.nFields) } } From ece4fd001d9937890406b937085cd3bd3f9e09bd Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Fri, 4 Nov 2016 16:26:42 -0400 Subject: [PATCH 30/51] join almost done --- .../hail/driver/JoinKeyTable.scala | 28 ++++----- .../hail/keytable/KeyTable.scala | 59 +++++++------------ src/test/resources/sampleAnnotations2.tsv | 42 +++++++++++++ .../hail/driver/AggregateByKeySuite.scala | 8 +-- .../hail/methods/KeyTableSuite.scala | 42 +++++++++---- 5 files changed, 113 insertions(+), 66 deletions(-) create mode 100644 src/test/resources/sampleAnnotations2.tsv diff --git a/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala index 2807534e0b7..398caaf7585 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala @@ -25,10 +25,6 @@ object JoinKeyTable extends Command { @Args4jOption(required = false, name = "-t", aliases = Array("--join-type"), usage = "type of join") var joinType: String = "left" - - @Args4jOption(required = true, name = "-k", aliases = Array("--join-keys"), - usage = "name of columns to join on") - var joinKeys: String = _ } def newOptions = new Options @@ -65,21 +61,25 @@ object JoinKeyTable extends Command { val ktLeftFieldSet = ktLeft.fieldNames.toSet val ktRightFieldSet = ktRight.fieldNames.toSet - val joinKeys = if (options.joinKeys == null) ktLeft.keyNames.toArray else 
Parser.parseIdentifierList(options.joinKeys) + if (ktLeft.keySignature != ktRight.keySignature) + fatal( + s"""Key schemas are not Identical. + |Left KeyTable Schema: ${ ktLeft.keySchema } + |Right KeyTable Schema: ${ ktRight.keySchema } + """.stripMargin) - if (!joinKeys.forall(k => ktLeftFieldSet.contains(k)) || !joinKeys.forall(k => ktRightFieldSet.contains(k))) + val valueDuplicates = ktLeft.valueNames.intersect(ktRight.valueNames) + if (valueDuplicates.nonEmpty) fatal( - s"""Join keys not present in both key-tables. - |Keys found: ${ joinKeys.mkString(",") } - |Left KeyTable Schema: ${ ktLeft.schema } - |Right KeyTable Schema: ${ ktRight.schema } + s"""Invalid join operation: cannot merge key-tables with same-name fields. + |Found these fields in both tables: [ ${ valueDuplicates.mkString(", ") } ] """.stripMargin) val joinedKT = options.joinType match { - case "left" => ktLeft.leftJoin(ktRight, joinKeys) - case "right" => ktLeft.rightJoin(ktRight, joinKeys) - case "inner" => ktLeft.innerJoin(ktRight, joinKeys) - case "outer" => ktLeft.outerJoin(ktRight, joinKeys) + case "left" => ktLeft.leftJoin(ktRight) + case "right" => ktLeft.rightJoin(ktRight) + case "inner" => ktLeft.innerJoin(ktRight) + case "outer" => ktLeft.outerJoin(ktRight) case _ => fatal("Did not recognize join type. 
Pick one of [left, right, inner, outer].") } diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 9ce4b88c08d..575dd905610 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -9,7 +9,6 @@ import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TS import org.broadinstitute.hail.methods.Filter import org.broadinstitute.hail.utils._ -case class Table(rdd: RDD[Annotation], signature: TStruct) object KeyTable extends Serializable { def annotationToSeq(a: Annotation, nFields: Int) = Option(a).map(_.asInstanceOf[Row].toSeq).getOrElse(Seq.fill[Any](nFields)(null)) @@ -105,8 +104,8 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def mapAnnotations(f: (Annotation) => Annotation, newSignature: TStruct, newKeyNames: Array[String]): KeyTable = KeyTable(KeyTable.toSingleRDD(rdd, nKeys, nValues).map(a => f(a)), newSignature, newKeyNames) - def mapAnnotations(f: (Annotation, Annotation) => Annotation, newValueSignature: TStruct): KeyTable = - copy(rdd = rdd.mapValuesWithKey{ case (k, v) => f(k, v) }, valueSignature = newValueSignature) +// def mapAnnotations(f: (Annotation, Annotation) => Annotation): RDD[(Annotation, Annotation)] = +// rdd.mapValuesWithKey{ case (k, v) => f(k, v) } def query(code: String): (BaseType, (Annotation, Annotation) => Option[Any]) = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) @@ -149,54 +148,40 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v filter(p) } - def changeKey(newKeyNames: Array[String]): KeyTable = KeyTable.apply(KeyTable.toSingleRDD(rdd, nKeys, nValues), signature, newKeyNames) + def leftJoin(other: KeyTable): KeyTable = { + require(keySignature == other.keySignature) - def leftJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = { - val 
ktL = changeKey(joinKeys) - val ktR = other.changeKey(joinKeys) + val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) + val newRDD = rdd.leftOuterJoin(other.rdd).map{ case (k, (vl, vr)) => (k, merger(vl, vr.orNull)) } - require(ktL.keySignature == ktR.keySignature) - - val (newValueSignature, merger) = ktL.valueSignature.merge(ktR.valueSignature) - val newRDD = ktL.rdd.leftOuterJoin(ktR.rdd).map{ case (k, (vl, vr)) => (k, merger(vl, vr.orNull)) } - - KeyTable(newRDD, ktL.keySignature, newValueSignature) + KeyTable(newRDD, keySignature, newValueSignature) } - def rightJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = { - val ktL = changeKey(joinKeys) - val ktR = other.changeKey(joinKeys) - - require(ktL.keySignature == ktR.keySignature) + def rightJoin(other: KeyTable): KeyTable = { + require(keySignature == other.keySignature) - val (newValueSignature, merger) = ktL.valueSignature.merge(ktR.valueSignature) - val newRDD = ktL.rdd.rightOuterJoin(ktR.rdd).map{ case (k, (vl, vr)) => (k, merger(vl.orNull, vr)) } + val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) + val newRDD = rdd.rightOuterJoin(other.rdd).map{ case (k, (vl, vr)) => (k, merger(vl.orNull, vr)) } - KeyTable(newRDD, ktL.keySignature, newValueSignature) + KeyTable(newRDD, keySignature, newValueSignature) } - def outerJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = { - val ktL = changeKey(joinKeys) - val ktR = other.changeKey(joinKeys) + def outerJoin(other: KeyTable): KeyTable = { + require(keySignature == other.keySignature) - require(ktL.keySignature == ktR.keySignature) + val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) + val newRDD = rdd.fullOuterJoin(other.rdd).map{ case (k, (vl, vr)) => (k, merger(vl.orNull, vr.orNull)) } - val (newValueSignature, merger) = ktL.valueSignature.merge(ktR.valueSignature) - val newRDD = ktL.rdd.fullOuterJoin(ktR.rdd).map{ case (k, (vl, vr)) => (k, merger(vl.orNull, 
vr.orNull)) } - - KeyTable(newRDD, ktL.keySignature, newValueSignature) + KeyTable(newRDD, keySignature, newValueSignature) } - def innerJoin(other: KeyTable, joinKeys: Array[String]): KeyTable = { - val ktL = changeKey(joinKeys) - val ktR = other.changeKey(joinKeys) - - require(ktL.keySignature == ktR.keySignature) + def innerJoin(other: KeyTable): KeyTable = { + require(keySignature == other.keySignature) - val (newValueSignature, merger) = ktL.valueSignature.merge(ktR.valueSignature) - val newRDD = ktL.rdd.join(ktR.rdd).map{ case (k, (vl, vr)) => (k, merger(vl, vr)) } + val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) + val newRDD = rdd.join(other.rdd).map{ case (k, (vl, vr)) => (k, merger(vl, vr)) } - KeyTable(newRDD, ktL.keySignature, newValueSignature) + KeyTable(newRDD, keySignature, newValueSignature) } } \ No newline at end of file diff --git a/src/test/resources/sampleAnnotations2.tsv b/src/test/resources/sampleAnnotations2.tsv new file mode 100644 index 00000000000..16452d99b0c --- /dev/null +++ b/src/test/resources/sampleAnnotations2.tsv @@ -0,0 +1,42 @@ +Sample qPhen2 qPhen3 +HG00096 5540.8 27694 +HG00097 3327.2 16626 +HG00099 1451.2 7246 +HG00100 5714.8 28564 +HG00101 2417.6 12078 +HG00102 3948 19730 +HG00103 372.2 1851 +HG00105 4455.6 22268 +HG00106 5296.8 26474 +HG00107 5945.2 29716 +HG00108 3295 16465 +HG00109 6519 32585 +HG00110 4163.2 20806 +HG00111 6013 30055 +HG00112 4918 24580 +HG00113 1769 8835 +HG00114 6251 31245 +HG00115 5638 28180 +HG00116 2548.4 12732 +HG00117 4724.4 23612 +HG00118 3573.4 17857 +HG00119 6177.2 30876 +HG00120 3919.8 19589 +HG00121 966.4 4822 +HG00122 0 -10 +HG00123 5662.2 28301 +HG00124 538.2 2681 +HG00125 2893.2 14456 +HG00126 5506 27520 +HG00127 2044.8 10214 +HG00128 561.4 2797 +HG00129 1630.2 8141 +HG00130 5212 26050 +HG00131 4312.4 21552 +HG00132 2222.4 11102 +HG00133 4943.2 24706 +HG00136 2469.6 12338 +HG00137 3757.2 18776 +HG00138 1799 8985 +HG00139 385.6 1918 +HG00140 0 -10 diff --git 
a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala index d5e38e6d939..f78aebcb311 100644 --- a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala +++ b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala @@ -24,7 +24,7 @@ class AggregateByKeySuite extends SparkSuite { (ktSampleQuery(k, v).map(_.asInstanceOf[String]), ktHetQuery(k, v).map(_.asInstanceOf[Long])) }.collectAsMap() - assert( vds.sampleIdsAndAnnotations.forall{ case (sid, sa) => saHetQuery(sa) == ktSampleResults(Option(sid))}) + assert(vds.sampleIdsAndAnnotations.forall { case (sid, sa) => saHetQuery(sa) == ktSampleResults(Option(sid)) }) } @Test def replicateVariantAggregation() = { @@ -45,7 +45,7 @@ class AggregateByKeySuite extends SparkSuite { (ktVariantQuery(k, v).map(_.asInstanceOf[Variant]), ktHetQuery(k, v).map(_.asInstanceOf[Long])) }.collectAsMap() - assert( vds.variantsAndAnnotations.forall{ case (v, va) => vaHetQuery(va) == ktVariantResults(Option(v))}) + assert(vds.variantsAndAnnotations.forall { case (v, va) => vaHetQuery(va) == ktVariantResults(Option(v)) }) } @Test def replicateGlobalAggregation() = { @@ -62,9 +62,9 @@ class AggregateByKeySuite extends SparkSuite { val (_, ktHetQuery) = kt.query("nHet") val (_, globalHetResult) = vds.queryGlobal("global.nHet") - val ktGlobalResult = kt.rdd.map{ case (k, v) => ktHetQuery(k, v).map(_.asInstanceOf[Long])}.collect().head + val ktGlobalResult = kt.rdd.map { case (k, v) => ktHetQuery(k, v).map(_.asInstanceOf[Long]) }.collect().head val vdsGlobalResult = globalHetResult.map(_.asInstanceOf[Long]) - assert( ktGlobalResult == vdsGlobalResult ) + assert(ktGlobalResult == vdsGlobalResult) } } diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index f93b07dfe4f..cfc7f56cfab 100644 --- 
a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -15,7 +15,7 @@ class KeyTableSuite extends SparkSuite { var s = State(sc, sqlContext) s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample, Status", inputFile)) val kt = s.ktEnv("kt1") - val kt2 = kt.changeKey(kt.keyNames.toArray) + val kt2 = KeyTable(KeyTable.toSingleRDD(kt.rdd, kt.nKeys, kt.nValues), kt.signature, kt.keyNames.toArray) assert(kt.rdd.fullOuterJoin(kt2.rdd).forall { case (k, (v1, v2)) => val res = v1 == v2 @@ -106,12 +106,12 @@ class KeyTableSuite extends SparkSuite { } @Test def testJoin() = { - val inputFile = "src/test/resources/sampleAnnotations.tsv" - val outputFile = tmpDir.createTempFile("ktImpExp", "tsv") + val inputFile1 = "src/test/resources/sampleAnnotations.tsv" + val inputFile2 = "src/test/resources/sampleAnnotations2.tsv" + var s = State(sc, sqlContext) - s = ImportKeyTable.run(s, Array("-n", "ktLeft", "-k", "Sample", "--impute", inputFile)) - s = FilterKeyTableExpr.run(s, Array("-n", "ktLeft", "-c", """Status == "CASE"""", "--keep")) - s = AnnotateKeyTableExpr.run(s, Array("-n", "ktLeft", "-d", "ktRight", "-c", "FakeValue = qPhen * 3, FooVar = qPhen / 5")) + s = ImportKeyTable.run(s, Array("-n", "ktLeft", "-k", "Sample", "--impute", inputFile1)) + s = ImportKeyTable.run(s, Array("-n", "ktRight", "-k", "Sample", "--impute", inputFile2)) s = JoinKeyTable.run(s, Array("-l", "ktLeft", "-r", "ktRight", "-d", "ktLeftJoin", "-t", "left")) s = JoinKeyTable.run(s, Array("-l", "ktLeft", "-r", "ktRight", "-d", "ktRightJoin", "-t", "right")) @@ -121,17 +121,37 @@ class KeyTableSuite extends SparkSuite { val ktLeft = s.ktEnv("ktLeft") val ktRight = s.ktEnv("ktRight") - assert(!(ktLeft same ktRight)) - val ktLeftJoin = s.ktEnv("ktLeftJoin") val ktRightJoin = s.ktEnv("ktRightJoin") val ktInnerJoin = s.ktEnv("ktInnerJoin") val ktOuterJoin = s.ktEnv("ktOuterJoin") - assert(ktRightJoin same ktRight) + 
val nExpectedValues = ktLeft.nValues + ktRight.nValues + + val (_, leftKeyQuery) = ktLeft.query("Sample") + val (_, rightKeyQuery) = ktRight.query("Sample") + + val leftKeys = ktLeft.rdd.map { case (k, v) => leftKeyQuery(k, v).map(_.asInstanceOf[String]) }.collect().toSet + val rightKeys = ktRight.rdd.map { case (k, v) => rightKeyQuery(k, v).map(_.asInstanceOf[String]) }.collect().toSet + + val nIntersectRows = leftKeys.intersect(rightKeys).size + val nUnionRows = rightKeys.union(leftKeys).size + val nExpectedKeys = ktLeft.nKeys assert(ktLeftJoin.nRows == ktLeft.nRows && - ktLeftJoin.nKeys == ktLeft.nKeys && - ktLeftJoin.nFields == ktLeft.nFields) + ktLeftJoin.nKeys == nExpectedKeys && + ktLeftJoin.nValues == nExpectedValues) + + assert(ktRightJoin.nRows == ktRight.nRows && + ktRightJoin.nKeys == nExpectedKeys && + ktRightJoin.nValues == nExpectedValues) + + assert(ktOuterJoin.nRows == nUnionRows && + ktOuterJoin.nKeys == ktLeft.nKeys && + ktOuterJoin.nValues == nExpectedValues) + + assert(ktInnerJoin.nRows == nIntersectRows && + ktInnerJoin.nKeys == nExpectedKeys && + ktInnerJoin.nValues == nExpectedValues) } } From 3ae0e330f4b0548b8d3d298beec8c4975ef2a28f Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Fri, 4 Nov 2016 17:12:44 -0400 Subject: [PATCH 31/51] started working on aggregate --- .../hail/driver/AggregateByKey.scala | 3 +- .../hail/driver/AggregateKeyTable.scala | 103 ++++++++++++++++++ .../hail/driver/AnnotateKeyTableExpr.scala | 2 +- .../hail/keytable/KeyTable.scala | 10 +- .../hail/methods/KeyTableSuite.scala | 6 +- 5 files changed, 116 insertions(+), 8 deletions(-) create mode 100644 src/main/scala/org/broadinstitute/hail/driver/AggregateKeyTable.scala diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala index 537db5ca6ee..d3731225561 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala +++ 
b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala @@ -82,7 +82,6 @@ object AggregateByKey extends Command { val valueSignature = TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) - val zvf = () => zVals.indices.map(zVals).toArray val seqOp = (array: Array[Aggregator], b: (Any, Any, Any, Any, Any)) => { val (v, va, s, sa, aggT) = b @@ -102,7 +101,7 @@ object AggregateByKey extends Command { val key = Annotation.fromSeq(keyF.map(_ ())) (key, (v, va, s, sa, g)) } - }.aggregateByKey(zvf())(seqOp, combOp) + }.aggregateByKey(zVals)(seqOp, combOp) .map { case (k, agg) => resultOp(agg) (k, Annotation.fromSeq(aggF.map(_ ()))) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateKeyTable.scala new file mode 100644 index 00000000000..eae10ea9852 --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/driver/AggregateKeyTable.scala @@ -0,0 +1,103 @@ +package org.broadinstitute.hail.driver + +import org.broadinstitute.hail.annotations.Annotation +import org.broadinstitute.hail.expr._ +import org.broadinstitute.hail.keytable.KeyTable +import org.broadinstitute.hail.methods.Aggregators +import org.broadinstitute.hail.utils._ +import org.kohsuke.args4j.{Option => Args4jOption} + +object AggregateKeyTable extends Command { + + class Options extends BaseOptions { + + @Args4jOption(required = false, name = "-d", aliases = Array("--dest"), + usage = "name of joined key-table") + var dest: String = _ + + @Args4jOption(required = true, name = "-n", aliases = Array("--name"), + usage = "name of key-table to aggregate") + var name: String = _ + + @Args4jOption(required = false, name = "-k", aliases = Array("--key-cond"), + usage = "Named key condition") + var keyCond: String = _ + + @Args4jOption(required = false, name = "-a", aliases = Array("--agg-cond"), + usage = "Named aggregation 
condition") + var aggCond: String = "left" + } + + def newOptions = new Options + + def name = "aggregatekeytable" + + def description = "Aggregate over fields of key-table to produce new key table" + + def supportsMultiallelic = true + + def requiresVDS = false + + override def hidden = true + + def run(state: State, options: Options): State = { + val ktEnv = state.ktEnv + val name = options.name + val dest = if (options.dest != null) options.dest else name + + val aggCond = options.aggCond + val keyCond = options.keyCond + + val kt = ktEnv.get(name) match { + case Some(x) => x + case None => fatal("no such key table $name in environment") + } + + if (ktEnv.contains(dest)) + warn("destination name already exists -- overwriting previous key-table") + + val ec = EvalContext(kt.fields.map(fd => (fd.name, fd.`type`)): _*) + + val (keyNameParseTypes, keyF) = + if (keyCond != null) + Parser.parseAnnotationArgs(keyCond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val (aggNameParseTypes, aggF) = + if (aggCond != null) + Parser.parseAnnotationArgs(aggCond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val keyNames = keyNameParseTypes.map(_._1.head) + val aggNames = aggNameParseTypes.map(_._1.head) + + val keySignature = TStruct(keyNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + val valueSignature = TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + + val nKeys = kt.nKeys + val nValues = kt.nValues + +// val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(ec.copy()) +// +// val seqOp = (array: Array[Aggregator], b: (Any, Any, Any)) => { +// val (k, v, aggT) = b +// KeyTable.setEvalContext(ec, k, v, nKeys, nValues) +// for (i <- array.indices) { +// array(i).seqOp(aggT) +// } +// array +// } +// +// kt.mapAnnotations { (k, v) => +// KeyTable.setEvalContext(ec, k, v, nKeys, nValues) +// val key = Annotation.fromSeq(keyF.map(_ ())) +// (key, (k, v)) +// 
}.aggregateByKey(zVals)(seqOp, combOp) // FIXME: need to aggregate .aggregateByKey() + + val ktAgg = kt // FIXME: place holder for now + state.copy(ktEnv = state.ktEnv + (dest -> ktAgg)) + } +} + diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala index f1655397849..05e6283279d 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala @@ -83,6 +83,6 @@ object AnnotateKeyTableExpr extends Command { } } - state.copy(ktEnv = state.ktEnv + (dest -> kt.mapAnnotations(f, finalSignature, keyNames))) + state.copy(ktEnv = state.ktEnv + (dest -> KeyTable(kt.mapAnnotations(f), finalSignature, keyNames))) } } diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 575dd905610..87872131452 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -9,6 +9,8 @@ import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TS import org.broadinstitute.hail.methods.Filter import org.broadinstitute.hail.utils._ +import scala.reflect.ClassTag + object KeyTable extends Serializable { def annotationToSeq(a: Annotation, nFields: Int) = Option(a).map(_.asInstanceOf[Row].toSeq).getOrElse(Seq.fill[Any](nFields)(null)) @@ -101,11 +103,11 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v } } - def mapAnnotations(f: (Annotation) => Annotation, newSignature: TStruct, newKeyNames: Array[String]): KeyTable = - KeyTable(KeyTable.toSingleRDD(rdd, nKeys, nValues).map(a => f(a)), newSignature, newKeyNames) + def mapAnnotations[T](f: (Annotation) => T)(implicit tct: ClassTag[T]): RDD[T] = + KeyTable.toSingleRDD(rdd, nKeys, nValues).map(a => f(a)) -// def 
mapAnnotations(f: (Annotation, Annotation) => Annotation): RDD[(Annotation, Annotation)] = -// rdd.mapValuesWithKey{ case (k, v) => f(k, v) } + def mapAnnotations[T](f: (Annotation, Annotation) => T)(implicit tct: ClassTag[T]): RDD[T] = + rdd.map{ case (k, v) => f(k, v)} def query(code: String): (BaseType, (Annotation, Annotation) => Option[Any]) = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index cfc7f56cfab..c38d8015f08 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -130,6 +130,7 @@ class KeyTableSuite extends SparkSuite { val (_, leftKeyQuery) = ktLeft.query("Sample") val (_, rightKeyQuery) = ktRight.query("Sample") + val (_, leftJoinKeyQuery) = ktLeftJoin.query("Sample") val leftKeys = ktLeft.rdd.map { case (k, v) => leftKeyQuery(k, v).map(_.asInstanceOf[String]) }.collect().toSet val rightKeys = ktRight.rdd.map { case (k, v) => rightKeyQuery(k, v).map(_.asInstanceOf[String]) }.collect().toSet @@ -140,7 +141,8 @@ class KeyTableSuite extends SparkSuite { assert(ktLeftJoin.nRows == ktLeft.nRows && ktLeftJoin.nKeys == nExpectedKeys && - ktLeftJoin.nValues == nExpectedValues) + ktLeftJoin.nValues == nExpectedValues + ) assert(ktRightJoin.nRows == ktRight.nRows && ktRightJoin.nKeys == nExpectedKeys && @@ -154,4 +156,6 @@ class KeyTableSuite extends SparkSuite { ktInnerJoin.nKeys == nExpectedKeys && ktInnerJoin.nValues == nExpectedValues) } + + @Test def testAggregate() {} } From 79b1e026e110ba56d6891178ae326ef63a0158fd Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 8 Nov 2016 00:42:29 -0500 Subject: [PATCH 32/51] Added most key table operations to pyhail --- python/pyhail/TextTableConfig.py | 30 ++++ python/pyhail/__init__.py | 3 +- python/pyhail/context.py | 42 ++++- 
python/pyhail/keytable.py | 125 +++++++++++++ .../hail/driver/FilterKeyTableExpr.scala | 2 +- .../hail/keytable/KeyTable.scala | 167 ++++++++++++++++-- .../hail/utils/TextTableReader.scala | 5 + 7 files changed, 351 insertions(+), 23 deletions(-) create mode 100644 python/pyhail/TextTableConfig.py create mode 100644 python/pyhail/keytable.py diff --git a/python/pyhail/TextTableConfig.py b/python/pyhail/TextTableConfig.py new file mode 100644 index 00000000000..473fb44d17d --- /dev/null +++ b/python/pyhail/TextTableConfig.py @@ -0,0 +1,30 @@ + +from pyhail.java import scala_object + +class TextTableConfig: + def __init__(self, noheader = False, impute = False, + comment = None, delimiter = "\t", missing = "NA", types = None): + self.noheader = noheader + self.impute = impute + self.comment = comment + self.delimiter = delimiter + self.missing = missing + self.types = types + + def asString(self): + res = ["--comment", self.comment, "--delimiter", self.delimiter, + "--missing", self.missing] + + if self.noheader: + res.append("--no-header") + + if self.impute: + res.append("--impute") + + return " ".join(res) + + def asJavaObject(self, hc): + return hc.jvm.org.broadinstitute.hail.utils.TextTableConfiguration.apply(self.types, self.comment, + self.delimiter, self.missing, + self.noheader, self.impute) + diff --git a/python/pyhail/__init__.py b/python/pyhail/__init__.py index fc8f8d53054..555e30c509a 100644 --- a/python/pyhail/__init__.py +++ b/python/pyhail/__init__.py @@ -1,4 +1,5 @@ from pyhail.context import HailContext from pyhail.dataset import VariantDataset +from pyhail.keytable import KeyTable -__all__ = ["HailContext", "VariantDataset"] +__all__ = ["HailContext", "VariantDataset", "KeyTable"] diff --git a/python/pyhail/context.py b/python/pyhail/context.py index b1d3c03ff77..1fc1ddd32dd 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -2,9 +2,10 @@ from pyhail.dataset import VariantDataset from pyhail.java import jarray, scala_object, 
scala_package_object +from pyhail.keytable import KeyTable +from pyhail.TextTableConfig import TextTableConfig from py4j.protocol import Py4JJavaError - class FatalError(Exception): """:class:`.FatalError` is an error thrown by Hail method failures""" @@ -16,7 +17,6 @@ def __init__(self, message, java_exception): def __str__(self): return self.msg - class HailContext(object): """:class:`.HailContext` is the main entrypoint for PyHail functionality. @@ -67,6 +67,7 @@ def __init__(self, sc=None, log='hail.log', quiet=False, append=False, def _jstate(self, jvds): return self.jvm.org.broadinstitute.hail.driver.State( +<<<<<<< 3ae0e330f4b0548b8d3d298beec8c4975ef2a28f self.jsc, self.jsql_context, jvds, scala_object(self.jvm.scala.collection.immutable, 'Map').empty()) def _raise_py4j_exception(self, e): @@ -74,6 +75,12 @@ def _raise_py4j_exception(self, e): raise FatalError(msg, e.java_exception) def run_command(self, vds, pargs): +======= + self.jsc, self.jsqlContext, jvds, scala_object(self.jvm.scala.collection.immutable, 'Map').empty(), + scala_object(self.jvm.scala.collection.immutable, 'Map').empty()) + + def _run_command(self, vds, pargs): +>>>>>>> Added most key table operations to pyhail jargs = jarray(self.gateway, self.jvm.java.lang.String, pargs) t = self.jvm.org.broadinstitute.hail.driver.ToplevelCommands.lookup(jargs) cmd = t._1() @@ -187,9 +194,35 @@ def import_annotations_table(self, path, variant_expr, code=None, npartitions=No if impute: pargs.append('--impute') +<<<<<<< 488ab167ba42a286d23df1a1d6ed47c0be48d831 return self.run_command(None, pargs) def import_bgen(self, path, tolerance=0.2, sample_file=None, npartitions=None): +======= + def import_keytable(self, path, key_names, npartition = None, config = None): + pathArgs = [] + if isinstance(path, str): + pathArgs.append(path) + else: + for p in path: + pathArgs.append(p) + + if not isinstance(key_names, str): + key_names = ",".join(key_names) + + if not npartition: + npartition = 1 + + if not config: 
+ config = TextTableConfig().asJavaObject(self) + elif isinstance(key_names, TextTableConfig): + config = config.asJavaObject(self) + + return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable(self.jsc, jarray(self.gateway, self.jvm.java.lang.String, pathArgs), + key_names, npartition, config)) + + def import_bgen(self, path, tolerance = 0.2, sample_file = None, npartition = None): +>>>>>>> Added most key table operations to pyhail """Import .bgen files as VariantDataset :param path: .bgen files to import. @@ -226,6 +259,11 @@ def import_bgen(self, path, tolerance=0.2, sample_file=None, npartitions=None): pargs.append('--tolerance') pargs.append(str(tolerance)) +<<<<<<< 488ab167ba42a286d23df1a1d6ed47c0be48d831 +======= + + return self._run_command(None, pargs) +>>>>>>> Added most key table operations to pyhail return self.run_command(None, pargs) diff --git a/python/pyhail/keytable.py b/python/pyhail/keytable.py new file mode 100644 index 00000000000..73d256838ab --- /dev/null +++ b/python/pyhail/keytable.py @@ -0,0 +1,125 @@ +from pyhail.java import scala_object + +class KeyTable: + """:class:`.KeyTable` ... + + :param SparkContext sc: The pyspark context. + :param JavaKeyTable jkt: The java key table object. + """ + + def __init__(self, hc, jkt): + self.hc = hc + self.jkt = jkt + + # FIXME schema stuff... 
+ def nKeys(self): + return self.jkt.nKeys() + + def nValues(self): + return self.jkt.nValues() + + def nFields(self): + return self.jkt.nFields() + + def schema(self): + return self.jkt.schema() + + def keyNames(self): + return self.jkt.keyNames() + + def valueNames(self): + return self.jkt.valueNames() + + def nRows(self): + """Number of rows in the key-table + + :return: long + """ + return self.jkt.nRows() + + def same(self, other): + """Compares two key-tables + + :param KeyTable other: KeyTable to compare to + + :return: bool + """ + return self.jkt.same(other.jkt) + + def export(self, output, types_file = None): + """Export key-table to a tsv file. + + :param str output: Output file path + + :param str types_file: Output path of types file + + :return: Nothing. + """ + self.jkt.export(self.hc.jsc, output, types_file) + + def filter(self, code, keep = True): + """Filter rows from key-table. + + :param str code: Annotation expression. + + :param bool keep: Keep rows where annotation expression evaluates to True + + :return: KeyTable + """ + return KeyTable(self.hc, self.jkt.filter(code, keep)) + + def annotate(self, code, key_names = None): + """Add fields to key-table. + + :param str code: Annotation expression. + + :param bool keep: Keep rows where annotation expression evaluates to True + + :return: KeyTable + """ + return KeyTable(self.hc, self.jkt.annotate(code, key_names)) + + def join(self, right, how = 'inner'): + """Join two key-tables together. Both key-tables must have identical key schemas + and non-overlapping fields in order to be joined. + + :param KeyTable right: key-table to join + + :param str how: Method for joining two tables together. One of "inner", "outer", "left", "right". 
+ + :return: KeyTable + """ + ## Check keys are same + + ## Check fields do not overlap + + if how == "inner": + return KeyTable(self.hc, self.jkt.innerJoin(right.jkt)) + elif how == "outer": + return KeyTable(self.hc, self.jkt.outerJoin(right.jkt)) + elif how == "left": + return KeyTable(self.hc, self.jkt.leftJoin(right.jkt)) + elif how == "right": + return KeyTable(self.hc, self.jkt.rightJoin(right.jkt)) + else: + pass + + + +# def import_fam(hc, path, ...): +# pass + + + # kt.select(star().except('a'), expr('sum', 'x + b'), expr('a', 'a.b.c = 9')) + # kt.select(star().except('a'), {'sum': 'x + b', 'a': 'update(a.b.c, 9)'}) + + + # def for_all(self, condition): + # pass + + # FIXME returns TypedValue + # def aggregate(value expressions...): + # pass + # + # def aggregate_by_key(self, value expressions...): + # pass diff --git a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala index 1965b2229a5..4c97889a76e 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala @@ -53,6 +53,6 @@ object FilterKeyTableExpr extends Command { if (!(options.keep ^ options.remove)) fatal("either `--keep' or `--remove' required, but not both") - state.copy(ktEnv = state.ktEnv + (dest -> kt.filterExpr(cond, keep))) + state.copy(ktEnv = state.ktEnv + (dest -> kt.filter(cond, keep))) } } diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 87872131452..233e67f7a0d 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -8,11 +8,40 @@ import org.broadinstitute.hail.check.Gen import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TStruct, Type} import org.broadinstitute.hail.methods.Filter 
import org.broadinstitute.hail.utils._ +import org.broadinstitute.hail.io.TextExporter +import scala.collection.mutable import scala.reflect.ClassTag -object KeyTable extends Serializable { +object KeyTable extends Serializable with TextExporter { + + def importTextTable(sc: SparkContext, path: Array[String], keyNames: String, nPartitions: Int, config: TextTableConfiguration) = { + val files = sc.hadoopConfiguration.globAll(path) + if (files.isEmpty) + fatal("Arguments referred to no files") + + val keyNameArray = Parser.parseIdentifierList(keyNames) + + val (struct, rdd) = + if (nPartitions < 1) + fatal("requested number of partitions in -n/--npartitions must be positive") + else + TextTableReader.read(sc)(files, config, nPartitions) + + + val keyNamesValid = keyNameArray.forall { k => + val res = struct.selfField(k).isDefined + if (!res) + println(s"Key `$k' is not present in input table") + res + } + if (!keyNamesValid) + fatal("Invalid key names given") + + KeyTable(rdd.map(_.value), struct, keyNameArray) + } + def annotationToSeq(a: Annotation, nFields: Int) = Option(a).map(_.asInstanceOf[Row].toSeq).getOrElse(Seq.fill[Any](nFields)(null)) def setEvalContext(ec: EvalContext, k: Annotation, v: Annotation, nKeys: Int, nValues: Int) = @@ -22,7 +51,7 @@ object KeyTable extends Serializable { ec.setAll(annotationToSeq(a, nFields): _*) def toSingleRDD(rdd: RDD[(Annotation, Annotation)], nKeys: Int, nValues: Int): RDD[Annotation] = - rdd.map{ case (k, v) => + rdd.map { case (k, v) => val x = Annotation.fromSeq(annotationToSeq(k, nKeys) ++ annotationToSeq(v, nValues)) x } @@ -43,8 +72,8 @@ object KeyTable extends Serializable { val newRDD = rdd.map { a => val r = annotationToSeq(a, nFields).zipWithIndex - val keyRow = keyIndices.map( i => r(i)._1) - val valueRow = valueIndices.map( i => r(i)._1) + val keyRow = keyIndices.map(i => r(i)._1) + val valueRow = valueIndices.map(i => r(i)._1) (Annotation.fromSeq(keyRow), Annotation.fromSeq(valueRow)) } @@ -52,26 +81,32 @@ 
object KeyTable extends Serializable { } } - - case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, valueSignature: TStruct) { require(fieldNames.toSet.size == fieldNames.length) def signature = keySignature.merge(valueSignature)._1 + def fields = signature.fields def keySchema = keySignature.schema + def valueSchema = valueSignature.schema + def schema = signature.schema - def keyNames = keySignature.fields.map(_.name) - def valueNames = valueSignature.fields.map(_.name) + def keyNames = keySignature.fields.map(_.name).toArray + + def valueNames = valueSignature.fields.map(_.name).toArray + def fieldNames = keyNames ++ valueNames def nRows = rdd.count() + def nFields = fields.length + def nKeys = keySignature.size + def nValues = valueSignature.size def same(other: KeyTable): Boolean = { @@ -79,7 +114,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v println(s"signature: this=${ schema } other=${ other.schema }") false } else if (keyNames.toSet != other.keyNames.toSet) { - println(s"keyNames: this=${ keyNames.mkString(",") } other=${ other.keyNames.mkString(",")}") + println(s"keyNames: this=${ keyNames.mkString(",") } other=${ other.keyNames.mkString(",") }") false } else { val thisFieldNames = valueNames @@ -93,7 +128,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val r2 = y.map(r => otherFieldNames.zip(r.asInstanceOf[Row].toSeq).toMap).toSet val res = r1 == r2 if (!res) - println(s"k=$k r1=${r1.mkString(",")} r2=${r2.mkString(",")}") + println(s"k=$k r1=${ r1.mkString(",") } r2=${ r2.mkString(",") }") res case _ => println(s"k=$k v1=$v1 v2=$v2") @@ -107,16 +142,17 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v KeyTable.toSingleRDD(rdd, nKeys, nValues).map(a => f(a)) def mapAnnotations[T](f: (Annotation, Annotation) => T)(implicit tct: ClassTag[T]): RDD[T] = - rdd.map{ case (k, v) => f(k, v)} + rdd.map { case (k, v) => f(k, v) 
} def query(code: String): (BaseType, (Annotation, Annotation) => Option[Any]) = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) val (t, f) = Parser.parse(code, ec) - val f2: (Annotation, Annotation) => Option[Any] = { case (k, v) => - KeyTable.setEvalContext(ec, k, v, nKeys, nValues) - f() + val f2: (Annotation, Annotation) => Option[Any] = { + case (k, v) => + KeyTable.setEvalContext(ec, k, v, nKeys, nValues) + f() } (t, f2) @@ -135,9 +171,44 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v (t, f2) } + def annotate(cond: String, keyNameString: String): KeyTable = { + val ec = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) + + val (parseTypes, fns) = + if (cond != null) + Parser.parseAnnotationArgs(cond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val inserterBuilder = mutable.ArrayBuilder.make[Inserter] + + val finalSignature = parseTypes.foldLeft(signature) { case (vs, (ids, signature)) => + val (s: TStruct, i) = vs.insert(signature, ids) + inserterBuilder += i + s + } + + val inserters = inserterBuilder.result() + + val keyNameArray = if (keyNameString != null) Parser.parseIdentifierList(keyNameString) else keyNames.toArray + + // val nFields = nFields + + val f: Annotation => Annotation = { a => + KeyTable.setEvalContext(ec, a, nFields) + + fns.zip(inserters) + .foldLeft(a) { case (a1, (fn, inserter)) => + inserter(a1, Option(fn())) + } + } + + KeyTable(mapAnnotations(f), finalSignature, keyNameArray) + } + def filter(p: (Annotation, Annotation) => Boolean): KeyTable = copy(rdd = rdd.filter { case (k, v) => p(k, v) }) - def filterExpr(cond: String, keep: Boolean): KeyTable = { + def filter(cond: String, keep: Boolean): KeyTable = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) val f: () => Option[Boolean] = Parser.parse[Boolean](cond, ec, TBoolean) @@ -154,7 +225,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v 
require(keySignature == other.keySignature) val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) - val newRDD = rdd.leftOuterJoin(other.rdd).map{ case (k, (vl, vr)) => (k, merger(vl, vr.orNull)) } + val newRDD = rdd.leftOuterJoin(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl, vr.orNull)) } KeyTable(newRDD, keySignature, newValueSignature) } @@ -163,7 +234,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v require(keySignature == other.keySignature) val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) - val newRDD = rdd.rightOuterJoin(other.rdd).map{ case (k, (vl, vr)) => (k, merger(vl.orNull, vr)) } + val newRDD = rdd.rightOuterJoin(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl.orNull, vr)) } KeyTable(newRDD, keySignature, newValueSignature) } @@ -172,7 +243,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v require(keySignature == other.keySignature) val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) - val newRDD = rdd.fullOuterJoin(other.rdd).map{ case (k, (vl, vr)) => (k, merger(vl.orNull, vr.orNull)) } + val newRDD = rdd.fullOuterJoin(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl.orNull, vr.orNull)) } KeyTable(newRDD, keySignature, newValueSignature) } @@ -181,9 +252,67 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v require(keySignature == other.keySignature) val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) - val newRDD = rdd.join(other.rdd).map{ case (k, (vl, vr)) => (k, merger(vl, vr)) } + val newRDD = rdd.join(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl, vr)) } KeyTable(newRDD, keySignature, newValueSignature) } + def forall(code: String): Boolean = { + val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + + val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean) + + val p = (k: Annotation, v: 
Annotation) => { + KeyTable.setEvalContext(ec, k, v, nKeys, nValues) + f().getOrElse(false) + } + + rdd.forall { case (k, v) => p(k, v) } + } + + def exists(code: String): Boolean = { + val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + + val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean) + + val p = (k: Annotation, v: Annotation) => { + KeyTable.setEvalContext(ec, k, v, nKeys, nValues) + f().getOrElse(false) + } + + rdd.exists { case (k, v) => p(k, v) } + } + + def export(sc: SparkContext, output: String, typesFile: String) = { + val hConf = sc.hadoopConfiguration + + val ec = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) + + val (header, types, f) = Parser.parseNamedArgs(fieldNames.map(n => n + " = " + n).mkString(","), ec) + + Option(typesFile).foreach { file => + val typeInfo = header + .getOrElse(types.indices.map(i => s"_$i").toArray) + .zip(types) + + KeyTable.exportTypes(file, hConf, typeInfo) + } + + hConf.delete(output, recursive = true) + // + // val nKeys = nKeys + // val nValues = nValues + + rdd + .mapPartitions { it => + val sb = new StringBuilder() + it.map { case (k, v) => + sb.clear() + KeyTable.setEvalContext(ec, k, v, nKeys, nValues) + f().foreachBetween(x => sb.append(x))(sb += '\t') + sb.result() + } + }.writeTable(output, header.map(_.mkString("\t"))) + } + } \ No newline at end of file diff --git a/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala b/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala index 93e35866f2a..19589520f0f 100644 --- a/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala +++ b/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala @@ -46,6 +46,11 @@ trait TextTableOptions { ) } +object TextTableConfiguration { + def apply(types: String, commentChar: String, separator: String, missing: String, noHeader: Boolean, impute: Boolean): TextTableConfiguration = + 
TextTableConfiguration(Parser.parseAnnotationTypes(Option(types).getOrElse("")), Option(commentChar), separator, missing, noHeader, impute) +} + case class TextTableConfiguration( types: Map[String, Type] = Map.empty[String, Type], commentChar: Option[String] = None, From 9239956d40343e3017f88cea129502401403bb39 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 8 Nov 2016 13:01:15 -0500 Subject: [PATCH 33/51] removed commands from tests --- .../hail/keytable/KeyTable.scala | 2 +- .../hail/methods/KeyTableSuite.scala | 68 ++++++------------- 2 files changed, 22 insertions(+), 48 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 233e67f7a0d..83bcd013ece 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -190,7 +190,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val inserters = inserterBuilder.result() - val keyNameArray = if (keyNameString != null) Parser.parseIdentifierList(keyNameString) else keyNames.toArray + val keyNameArray = if (keyNameString != null) Parser.parseIdentifierList(keyNameString) else keyNames // val nFields = nFields diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index c38d8015f08..ef96b4e115c 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -12,10 +12,8 @@ class KeyTableSuite extends SparkSuite { @Test def testSingleToPairRDD() = { val inputFile = "src/test/resources/sampleAnnotations.tsv" - var s = State(sc, sqlContext) - s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample, Status", inputFile)) - val kt = s.ktEnv("kt1") - val kt2 = KeyTable(KeyTable.toSingleRDD(kt.rdd, kt.nKeys, 
kt.nValues), kt.signature, kt.keyNames.toArray) + val kt = KeyTable.importTextTable(sc, Array(inputFile), "Sample, Status", sc.defaultMinPartitions, TextTableConfiguration()) + val kt2 = KeyTable(KeyTable.toSingleRDD(kt.rdd, kt.nKeys, kt.nValues), kt.signature, kt.keyNames) assert(kt.rdd.fullOuterJoin(kt2.rdd).forall { case (k, (v1, v2)) => val res = v1 == v2 @@ -28,15 +26,14 @@ class KeyTableSuite extends SparkSuite { @Test def testImportExport() = { val inputFile = "src/test/resources/sampleAnnotations.tsv" val outputFile = tmpDir.createTempFile("ktImpExp", "tsv") - var s = State(sc, sqlContext) - s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample, Status", inputFile)) - s = ExportKeyTable.run(s, Array("-n", "kt1", "-o", outputFile)) + val kt = KeyTable.importTextTable(sc, Array(inputFile), "Sample, Status", sc.defaultMinPartitions, TextTableConfiguration()) + kt.export(sc, outputFile, null) val importedData = sc.hadoopConfiguration.readLines(inputFile)(_.map(_.value).toIndexedSeq) val exportedData = sc.hadoopConfiguration.readLines(outputFile)(_.map(_.value).toIndexedSeq) intercept[FatalException] { - s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample, Status, BadKeyName", inputFile)) + val kt2 = KeyTable.importTextTable(sc, Array(inputFile), "Sample, Status, BadKeyName", sc.defaultMinPartitions, TextTableConfiguration()) } assert(importedData == exportedData) @@ -44,17 +41,10 @@ class KeyTableSuite extends SparkSuite { @Test def testAnnotate() = { val inputFile = "src/test/resources/sampleAnnotations.tsv" - var s = State(sc, sqlContext) - - s = ImportKeyTable.run(s, Array("-n", "kt1", "-k", "Sample", "--impute", inputFile)) - s = AnnotateKeyTableExpr.run(s, Array("-n", "kt1", "-d", "kt2", "-c", """qPhen2 = pow(qPhen, 2), NotStatus = Status == "CASE", X = qPhen == 5""")) - s = AnnotateKeyTableExpr.run(s, Array("-n", "kt2", "-d", "kt3")) - s = AnnotateKeyTableExpr.run(s, Array("-n", "kt3", "-d", "kt4", "-k", "qPhen, NotStatus")) - - val kt1 = 
s.ktEnv("kt1") - val kt2 = s.ktEnv("kt2") - val kt3 = s.ktEnv("kt3") - val kt4 = s.ktEnv("kt4") + val kt1 = KeyTable.importTextTable(sc, Array(inputFile), "Sample", sc.defaultMinPartitions, TextTableConfiguration(impute = true)) + val kt2 = kt1.annotate("""qPhen2 = pow(qPhen, 2), NotStatus = Status == "CASE", X = qPhen == 5""", null) + val kt3 = kt2.annotate(null, null) + val kt4 = kt3.annotate(null, "qPhen, NotStatus") val kt1ValueNames = kt1.valueNames.toSet val kt2ValueNames = kt2.valueNames.toSet @@ -88,43 +78,27 @@ class KeyTableSuite extends SparkSuite { val rdd = sc.parallelize(data.map(Annotation.fromSeq(_))) val signature = TStruct(("field1", TInt), ("field2", TInt), ("field3", TInt)) val keyNames = Array("field1") - val kt = KeyTable(rdd, signature, keyNames) - - var s = State(sc, sqlContext, ktEnv = Map("kt1" -> kt)) - - s = FilterKeyTableExpr.run(s, Array("-n", "kt1", "-c", "field1 < 3", "-d", "kt2", "--keep")) - assert(s.ktEnv.contains("kt2") && s.ktEnv("kt2").nRows == 2) - s = FilterKeyTableExpr.run(s, Array("-n", "kt1", "-c", "field1 < 3 && field3 == 4", "-d", "kt3", "--keep")) - assert(s.ktEnv.contains("kt3") && s.ktEnv("kt3").nRows == 1) + val kt1 = KeyTable(rdd, signature, keyNames) + val kt2 = kt1.filter("field1 < 3", keep = true) + val kt3 = kt1.filter("field1 < 3 && field3 == 4", keep = true) + val kt4 = kt1.filter("field1 == 5 && field2 == 9 && field3 == 0", keep = false) + val kt5 = kt1.filter("field1 < -5 && field3 == 100", keep = true) - s = FilterKeyTableExpr.run(s, Array("-n", "kt1", "-c", "field1 == 5 && field2 == 9 && field3 == 0", "-d", "kt3", "--remove")) - assert(s.ktEnv.contains("kt3") && s.ktEnv("kt3").nRows == 2) - - s = FilterKeyTableExpr.run(s, Array("-n", "kt1", "-c", "field1 < -5 && field3 == 100", "--keep")) - assert(s.ktEnv.contains("kt1") && s.ktEnv("kt1").nRows == 0) + assert(kt1.nRows == 3 && kt2.nRows == 2 && kt3.nRows == 1 && kt4.nRows == 2 && kt5.nRows == 0) } @Test def testJoin() = { val inputFile1 = 
"src/test/resources/sampleAnnotations.tsv" val inputFile2 = "src/test/resources/sampleAnnotations2.tsv" - var s = State(sc, sqlContext) - s = ImportKeyTable.run(s, Array("-n", "ktLeft", "-k", "Sample", "--impute", inputFile1)) - s = ImportKeyTable.run(s, Array("-n", "ktRight", "-k", "Sample", "--impute", inputFile2)) - - s = JoinKeyTable.run(s, Array("-l", "ktLeft", "-r", "ktRight", "-d", "ktLeftJoin", "-t", "left")) - s = JoinKeyTable.run(s, Array("-l", "ktLeft", "-r", "ktRight", "-d", "ktRightJoin", "-t", "right")) - s = JoinKeyTable.run(s, Array("-l", "ktLeft", "-r", "ktRight", "-d", "ktInnerJoin", "-t", "inner")) - s = JoinKeyTable.run(s, Array("-l", "ktLeft", "-r", "ktRight", "-d", "ktOuterJoin", "-t", "outer")) - - val ktLeft = s.ktEnv("ktLeft") - val ktRight = s.ktEnv("ktRight") + val ktLeft = KeyTable.importTextTable(sc, Array(inputFile1), "Sample", sc.defaultMinPartitions, TextTableConfiguration(impute = true)) + val ktRight = KeyTable.importTextTable(sc, Array(inputFile2), "Sample", sc.defaultMinPartitions, TextTableConfiguration(impute = true)) - val ktLeftJoin = s.ktEnv("ktLeftJoin") - val ktRightJoin = s.ktEnv("ktRightJoin") - val ktInnerJoin = s.ktEnv("ktInnerJoin") - val ktOuterJoin = s.ktEnv("ktOuterJoin") + val ktLeftJoin = ktLeft.leftJoin(ktRight) + val ktRightJoin = ktLeft.rightJoin(ktRight) + val ktInnerJoin = ktLeft.innerJoin(ktRight) + val ktOuterJoin = ktLeft.outerJoin(ktRight) val nExpectedValues = ktLeft.nValues + ktRight.nValues From 4832394395cf3b2a669285c4029a7c39309ab30e Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Tue, 8 Nov 2016 14:26:03 -0500 Subject: [PATCH 34/51] removed key table commands from scala --- python/pyhail/context.py | 7 -- python/pyhail/dataset.py | 15 ++- .../hail/driver/AggregateByKey.scala | 112 ------------------ .../hail/driver/AggregateKeyTable.scala | 103 ---------------- .../hail/driver/AnnotateKeyTable.scala | 9 -- .../hail/driver/AnnotateKeyTableExpr.scala | 88 -------------- 
.../broadinstitute/hail/driver/ClearKT.scala | 30 ----- .../broadinstitute/hail/driver/Command.scala | 8 +- .../hail/driver/ExportKeyTable.scala | 81 ------------- .../hail/driver/FilterKeyTable.scala | 9 -- .../hail/driver/FilterKeyTableExpr.scala | 58 --------- .../hail/driver/ImportKeyTable.scala | 67 ----------- .../hail/driver/JoinKeyTable.scala | 89 -------------- .../hail/variant/VariantSampleMatrix.scala | 69 +++++++++++ .../hail/driver/AggregateByKeySuite.scala | 25 ++-- 15 files changed, 92 insertions(+), 678 deletions(-) delete mode 100644 src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala delete mode 100644 src/main/scala/org/broadinstitute/hail/driver/AggregateKeyTable.scala delete mode 100644 src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTable.scala delete mode 100644 src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala delete mode 100644 src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala delete mode 100644 src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala delete mode 100644 src/main/scala/org/broadinstitute/hail/driver/FilterKeyTable.scala delete mode 100644 src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala delete mode 100644 src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala delete mode 100644 src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala diff --git a/python/pyhail/context.py b/python/pyhail/context.py index 1fc1ddd32dd..440d647bf9f 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -67,7 +67,6 @@ def __init__(self, sc=None, log='hail.log', quiet=False, append=False, def _jstate(self, jvds): return self.jvm.org.broadinstitute.hail.driver.State( -<<<<<<< 3ae0e330f4b0548b8d3d298beec8c4975ef2a28f self.jsc, self.jsql_context, jvds, scala_object(self.jvm.scala.collection.immutable, 'Map').empty()) def _raise_py4j_exception(self, e): @@ -75,12 +74,6 @@ def _raise_py4j_exception(self, e): raise 
FatalError(msg, e.java_exception) def run_command(self, vds, pargs): -======= - self.jsc, self.jsqlContext, jvds, scala_object(self.jvm.scala.collection.immutable, 'Map').empty(), - scala_object(self.jvm.scala.collection.immutable, 'Map').empty()) - - def _run_command(self, vds, pargs): ->>>>>>> Added most key table operations to pyhail jargs = jarray(self.gateway, self.jvm.java.lang.String, pargs) t = self.jvm.org.broadinstitute.hail.driver.ToplevelCommands.lookup(jargs) cmd = t._1() diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py index 8e88e2ac6f5..635ae7d1737 100644 --- a/python/pyhail/dataset.py +++ b/python/pyhail/dataset.py @@ -1,17 +1,30 @@ from pyhail.java import scala_package_object +from pyhail.keytable import KeyTable import pyspark from py4j.protocol import Py4JJavaError - class VariantDataset(object): def __init__(self, hc, jvds): self.hc = hc self.jvds = jvds + def _raise_py4j_exception(self, e): self.hc._raise_py4j_exception(e) + def aggregate_by_key(self, key_condition = None, agg_condition = None): + """Aggregate by user-defined key and aggregation expressions + + :param str key_condition: Named expression for how key + + :param str agg_condition: Named aggregation expression. + + :return: KeyTable + + """ + return KeyTable(self.hc, self.jvds.aggregateByKey(key_cond, agg_condition)) + def aggregate_intervals(self, input, condition, output): """Aggregate over intervals and export. 
diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala deleted file mode 100644 index d3731225561..00000000000 --- a/src/main/scala/org/broadinstitute/hail/driver/AggregateByKey.scala +++ /dev/null @@ -1,112 +0,0 @@ -package org.broadinstitute.hail.driver - -import org.broadinstitute.hail.annotations.Annotation -import org.broadinstitute.hail.expr._ -import org.broadinstitute.hail.keytable.KeyTable -import org.broadinstitute.hail.methods.Aggregators -import org.kohsuke.args4j.{Option => Args4jOption} - -object AggregateByKey extends Command { - - class Options extends BaseOptions { - @Args4jOption(required = false, name = "-k", aliases = Array("--key-cond"), - usage = "Named key condition", metaVar = "EXPR") - var keyCond: String = _ - - @Args4jOption(required = false, name = "-a", aliases = Array("--agg-cond"), - usage = "Named aggregation condition", metaVar = "EXPR") - var aggCond: String = _ - - @Args4jOption(required = true, name = "-n", aliases = Array("--name"), - usage = "Name of new key table") - var name: String = _ - } - - def newOptions = new Options - - def name = "aggregatebykey" - - def description = "Creates a new key table with key(s) determined by named expressions and additional columns determined by named aggregator expressions" - - def supportsMultiallelic = true - - def requiresVDS = true - - override def hidden = true - - def run(state: State, options: Options): State = { - - val vds = state.vds - val sc = state.sc - - val aggCond = options.aggCond - val keyCond = options.keyCond - - val aggregationEC = EvalContext(Map( - "v" -> (0, TVariant), - "va" -> (1, vds.vaSignature), - "s" -> (2, TSample), - "sa" -> (3, vds.saSignature), - "global" -> (4, vds.globalSignature))) - - val symTab = Map( - "v" -> (0, TVariant), - "va" -> (1, vds.vaSignature), - "s" -> (2, TSample), - "sa" -> (3, vds.saSignature), - "global" -> (4, vds.globalSignature), - "gs" -> 
(-1, BaseAggregable(aggregationEC, TGenotype))) - - val ec = EvalContext(symTab) - val a = ec.a - - ec.set(4, vds.globalAnnotation) - aggregationEC.set(4, vds.globalAnnotation) - - val (keyNameParseTypes, keyF) = - if (keyCond != null) - Parser.parseAnnotationArgs(keyCond, ec, None) - else - (Array.empty[(List[String], Type)], Array.empty[() => Any]) - - val (aggNameParseTypes, aggF) = - if (aggCond != null) - Parser.parseAnnotationArgs(aggCond, ec, None) - else - (Array.empty[(List[String], Type)], Array.empty[() => Any]) - - val keyNames = keyNameParseTypes.map(_._1.head) - val aggNames = aggNameParseTypes.map(_._1.head) - - val keySignature = TStruct(keyNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) - val valueSignature = TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) - - val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) - - val seqOp = (array: Array[Aggregator], b: (Any, Any, Any, Any, Any)) => { - val (v, va, s, sa, aggT) = b - ec.set(0, v) - ec.set(1, va) - ec.set(2, s) - ec.set(3, sa) - for (i <- array.indices) { - array(i).seqOp(aggT) - } - array - } - - val kt = KeyTable(vds.mapPartitionsWithAll { it => - it.map { case (v, va, s, sa, g) => - ec.setAll(v, va, s, sa, g) - val key = Annotation.fromSeq(keyF.map(_ ())) - (key, (v, va, s, sa, g)) - } - }.aggregateByKey(zVals)(seqOp, combOp) - .map { case (k, agg) => - resultOp(agg) - (k, Annotation.fromSeq(aggF.map(_ ()))) - }, keySignature, valueSignature) - - state.copy(ktEnv = state.ktEnv + (options.name -> kt)) - } -} diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateKeyTable.scala deleted file mode 100644 index eae10ea9852..00000000000 --- a/src/main/scala/org/broadinstitute/hail/driver/AggregateKeyTable.scala +++ /dev/null @@ -1,103 +0,0 @@ -package org.broadinstitute.hail.driver - -import org.broadinstitute.hail.annotations.Annotation -import 
org.broadinstitute.hail.expr._ -import org.broadinstitute.hail.keytable.KeyTable -import org.broadinstitute.hail.methods.Aggregators -import org.broadinstitute.hail.utils._ -import org.kohsuke.args4j.{Option => Args4jOption} - -object AggregateKeyTable extends Command { - - class Options extends BaseOptions { - - @Args4jOption(required = false, name = "-d", aliases = Array("--dest"), - usage = "name of joined key-table") - var dest: String = _ - - @Args4jOption(required = true, name = "-n", aliases = Array("--name"), - usage = "name of key-table to aggregate") - var name: String = _ - - @Args4jOption(required = false, name = "-k", aliases = Array("--key-cond"), - usage = "Named key condition") - var keyCond: String = _ - - @Args4jOption(required = false, name = "-a", aliases = Array("--agg-cond"), - usage = "Named aggregation condition") - var aggCond: String = "left" - } - - def newOptions = new Options - - def name = "aggregatekeytable" - - def description = "Aggregate over fields of key-table to produce new key table" - - def supportsMultiallelic = true - - def requiresVDS = false - - override def hidden = true - - def run(state: State, options: Options): State = { - val ktEnv = state.ktEnv - val name = options.name - val dest = if (options.dest != null) options.dest else name - - val aggCond = options.aggCond - val keyCond = options.keyCond - - val kt = ktEnv.get(name) match { - case Some(x) => x - case None => fatal("no such key table $name in environment") - } - - if (ktEnv.contains(dest)) - warn("destination name already exists -- overwriting previous key-table") - - val ec = EvalContext(kt.fields.map(fd => (fd.name, fd.`type`)): _*) - - val (keyNameParseTypes, keyF) = - if (keyCond != null) - Parser.parseAnnotationArgs(keyCond, ec, None) - else - (Array.empty[(List[String], Type)], Array.empty[() => Any]) - - val (aggNameParseTypes, aggF) = - if (aggCond != null) - Parser.parseAnnotationArgs(aggCond, ec, None) - else - (Array.empty[(List[String], Type)], 
Array.empty[() => Any]) - - val keyNames = keyNameParseTypes.map(_._1.head) - val aggNames = aggNameParseTypes.map(_._1.head) - - val keySignature = TStruct(keyNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) - val valueSignature = TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) - - val nKeys = kt.nKeys - val nValues = kt.nValues - -// val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(ec.copy()) -// -// val seqOp = (array: Array[Aggregator], b: (Any, Any, Any)) => { -// val (k, v, aggT) = b -// KeyTable.setEvalContext(ec, k, v, nKeys, nValues) -// for (i <- array.indices) { -// array(i).seqOp(aggT) -// } -// array -// } -// -// kt.mapAnnotations { (k, v) => -// KeyTable.setEvalContext(ec, k, v, nKeys, nValues) -// val key = Annotation.fromSeq(keyF.map(_ ())) -// (key, (k, v)) -// }.aggregateByKey(zVals)(seqOp, combOp) // FIXME: need to aggregate .aggregateByKey() - - val ktAgg = kt // FIXME: place holder for now - state.copy(ktEnv = state.ktEnv + (dest -> ktAgg)) - } -} - diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTable.scala deleted file mode 100644 index 899b0f12ded..00000000000 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTable.scala +++ /dev/null @@ -1,9 +0,0 @@ -package org.broadinstitute.hail.driver - -object AnnotateKeyTable extends SuperCommand { - def name = "annotatekeytable" - - def description = "Annotate key tables" - - register(AnnotateKeyTableExpr) -} diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala deleted file mode 100644 index 05e6283279d..00000000000 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateKeyTableExpr.scala +++ /dev/null @@ -1,88 +0,0 @@ -package org.broadinstitute.hail.driver - -import org.broadinstitute.hail.annotations._ -import 
org.broadinstitute.hail.expr.{EvalContext, Parser, TStruct, Type} -import org.broadinstitute.hail.utils._ -import org.kohsuke.args4j.{Option => Args4jOption} -import org.broadinstitute.hail.keytable.KeyTable - -import scala.collection.mutable - -object AnnotateKeyTableExpr extends Command { - - class Options extends BaseOptions { - @Args4jOption(required = true, name = "-n", aliases = Array("--name"), - usage = "Name of source key table") - var name: String = _ - - @Args4jOption(required = false, name = "-d", aliases = Array("--dest"), - usage = "Name of destination key table (can be same as source)") - var dest: String = _ - - @Args4jOption(required = false, name = "-c", aliases = Array("--cond"), - usage = "Named expression for adding fields to the table", metaVar = "EXPR") - var condition: String = _ - - @Args4jOption(required = false, name = "-k", aliases = Array("--key-names"), - usage = "Names of key in new table (default is existing key names)", metaVar = "EXPR") - var keyNames: String = _ - } - - def newOptions = new Options - - def name = "annotatekeytable expr" - - def description = "Annotate key table using an expression" - - def supportsMultiallelic = true - - def requiresVDS = false - - override def hidden = true - - def run(state: State, options: Options): State = { - val cond = options.condition - val name = options.name - val dest = if (options.dest != null) options.dest else name - - val kt = state.ktEnv.get(name) match { - case Some(newKT) => - newKT - case None => - fatal("no such key table $name in environment") - } - - val ec = EvalContext(kt.fields.map(fd => (fd.name, fd.`type`)): _*) - - val (parseTypes, fns) = - if (cond != null) - Parser.parseAnnotationArgs(cond, ec, None) - else - (Array.empty[(List[String], Type)], Array.empty[() => Any]) - - val inserterBuilder = mutable.ArrayBuilder.make[Inserter] - - val finalSignature = parseTypes.foldLeft(kt.signature) { case (vs, (ids, signature)) => - val (s: TStruct, i) = vs.insert(signature, ids) 
- inserterBuilder += i - s - } - - val inserters = inserterBuilder.result() - - val keyNames = if (options.keyNames != null) Parser.parseIdentifierList(options.keyNames) else kt.keyNames.toArray - - val nFields = kt.nFields - - val f: Annotation => Annotation = { a => - KeyTable.setEvalContext(ec, a, nFields) - - fns.zip(inserters) - .foldLeft(a) { case (a1, (fn, inserter)) => - inserter(a1, Option(fn())) - } - } - - state.copy(ktEnv = state.ktEnv + (dest -> KeyTable(kt.mapAnnotations(f), finalSignature, keyNames))) - } -} diff --git a/src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala b/src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala deleted file mode 100644 index f5ea193f71a..00000000000 --- a/src/main/scala/org/broadinstitute/hail/driver/ClearKT.scala +++ /dev/null @@ -1,30 +0,0 @@ -package org.broadinstitute.hail.driver - -import org.kohsuke.args4j.{Option => Args4jOption} - -object ClearKT extends Command { - - class Options extends BaseOptions { - @Args4jOption(required = true, name = "-n", aliases = Array("--name"), - usage = "Name of key table to clear") - var name: String = _ - } - - def newOptions = new Options - - def name = "ktclear" - - def description = "Clear key table from environment" - - def supportsMultiallelic = true - - def requiresVDS = false - - override def hidden = true - - def run(state: State, options: Options): State = { - val name = options.name - state.copy( - ktEnv = state.ktEnv - name) - } -} diff --git a/src/main/scala/org/broadinstitute/hail/driver/Command.scala b/src/main/scala/org/broadinstitute/hail/driver/Command.scala index a462e67a113..c46c690bb66 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/Command.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/Command.scala @@ -14,8 +14,7 @@ case class State(sc: SparkContext, sqlContext: SQLContext, // FIXME make option vds: VariantDataset = null, - env: Map[String, VariantDataset] = Map.empty, - ktEnv: Map[String, KeyTable] = Map.empty) { + 
env: Map[String, VariantDataset] = Map.empty) { def hadoopConf = sc.hadoopConfiguration } @@ -47,9 +46,7 @@ object ToplevelCommands { + cmd.description)) } - register(AggregateByKey) register(AggregateIntervals) - register(AnnotateKeyTable) register(AnnotateSamples) register(AnnotateVariants) register(AnnotateGlobal) @@ -64,7 +61,6 @@ object ToplevelCommands { register(CountBytes) register(Deduplicate) register(DownsampleVariants) - register(ExportKeyTable) register(ExportPlink) register(ExportGEN) register(ExportGenotypes) @@ -75,7 +71,6 @@ object ToplevelCommands { register(ExportVCF) register(FilterAlleles) register(FilterGenotypes) - register(FilterKeyTable) register(Filtermulti) register(FilterSamples) register(FilterVariants) @@ -90,7 +85,6 @@ object ToplevelCommands { register(ImportAnnotations) register(ImportBGEN) register(ImportGEN) - register(ImportKeyTable) register(ImportPlink) register(ImportVCF) register(ImputeSex) diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala deleted file mode 100644 index 4bb665aad35..00000000000 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportKeyTable.scala +++ /dev/null @@ -1,81 +0,0 @@ -package org.broadinstitute.hail.driver - -import org.broadinstitute.hail.expr.{EvalContext, _} -import org.broadinstitute.hail.io.TextExporter -import org.broadinstitute.hail.keytable.KeyTable -import org.broadinstitute.hail.utils._ -import org.kohsuke.args4j.{Option => Args4jOption} - -object ExportKeyTable extends Command with TextExporter { - - class Options extends BaseOptions { - - @Args4jOption(required = true, name = "-o", aliases = Array("--output"), - usage = "path of output tsv") - var output: String = _ - - @Args4jOption(required = true, name = "-n", aliases = Array("--name"), - usage = "name of key table to be printed to tsv") - var name: String = _ - - @Args4jOption(required = false, name = "-t", aliases = 
Array("--types"), - usage = "Write the types of parse expressions to a file at the given path") - var typesFile: String = _ - - } - - def newOptions = new Options - - def name = "exportkeytable" - - def description = "Export key table to tsv" - - def supportsMultiallelic = true - - def requiresVDS = false - - override def hidden = true - - def run(state: State, options: Options): State = { - - val name = options.name - val output = options.output - - val kt = state.ktEnv.get(name) match { - case Some(newKT) => - newKT - case None => - fatal("no such key table $name in environment") - } - - val ec = EvalContext(kt.fields.map(fd => (fd.name, fd.`type`)): _*) - - val (header, types, f) = Parser.parseNamedArgs(kt.fieldNames.map(n => n + " = " + n).mkString(","), ec) - - Option(options.typesFile).foreach { file => - val typeInfo = header - .getOrElse(types.indices.map(i => s"_$i").toArray) - .zip(types) - exportTypes(file, state.hadoopConf, typeInfo) - } - - state.hadoopConf.delete(output, recursive = true) - - val nKeys = kt.nKeys - val nValues = kt.nValues - - kt.rdd - .mapPartitions { it => - val sb = new StringBuilder() - it.map { case (k, v) => - sb.clear() - KeyTable.setEvalContext(ec, k, v, nKeys, nValues) - f().foreachBetween(x => sb.append(x))(sb += '\t') - sb.result() - } - }.writeTable(output, header.map(_.mkString("\t"))) - - state - } -} - diff --git a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTable.scala deleted file mode 100644 index 81ccba1a1d8..00000000000 --- a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTable.scala +++ /dev/null @@ -1,9 +0,0 @@ -package org.broadinstitute.hail.driver - -object FilterKeyTable extends SuperCommand { - def name = "filterkeytable" - - def description = "Filter key tables" - - register(FilterKeyTableExpr) -} diff --git a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala 
b/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala deleted file mode 100644 index 4c97889a76e..00000000000 --- a/src/main/scala/org/broadinstitute/hail/driver/FilterKeyTableExpr.scala +++ /dev/null @@ -1,58 +0,0 @@ -package org.broadinstitute.hail.driver - -import org.broadinstitute.hail.utils._ -import org.kohsuke.args4j.{Option => Args4jOption} - -object FilterKeyTableExpr extends Command { - - class Options extends BaseOptions { - @Args4jOption(required = true, name = "-c", aliases = Array("--cond"), - usage = "Boolean expression for filtering", metaVar = "EXPR") - var condition: String = _ - - @Args4jOption(required = true, name = "-n", aliases = Array("--name"), - usage = "Name of source key table") - var name: String = _ - - @Args4jOption(required = false, name = "-d", aliases = Array("--dest"), - usage = "Name of destination key table (can be same as source)") - var dest: String = _ - - @Args4jOption(required = false, name = "--keep", usage = "Keep variants matching condition") - var keep: Boolean = false - - @Args4jOption(required = false, name = "--remove", usage = "Remove variants matching condition") - var remove: Boolean = false - } - - def newOptions = new Options - - def name = "filterkeytable expr" - - def description = "Filter key table using a boolean expression" - - def supportsMultiallelic = true - - def requiresVDS = false - - override def hidden = true - - def run(state: State, options: Options): State = { - val name = options.name - val cond = options.condition - val keep = options.keep - val dest = if (options.dest != null) options.dest else name - - val kt = state.ktEnv.get(name) match { - case Some(newKT) => - newKT - case None => - fatal("no such key table $name in environment") - } - - if (!(options.keep ^ options.remove)) - fatal("either `--keep' or `--remove' required, but not both") - - state.copy(ktEnv = state.ktEnv + (dest -> kt.filter(cond, keep))) - } -} diff --git 
a/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala deleted file mode 100644 index a1a88d77e34..00000000000 --- a/src/main/scala/org/broadinstitute/hail/driver/ImportKeyTable.scala +++ /dev/null @@ -1,67 +0,0 @@ -package org.broadinstitute.hail.driver - -import org.broadinstitute.hail.expr.{Parser, TStruct} -import org.broadinstitute.hail.keytable.KeyTable -import org.broadinstitute.hail.utils._ -import org.kohsuke.args4j.{Argument, Option => Args4jOption} - -import scala.collection.JavaConverters._ - -object ImportKeyTable extends Command { - - class Options extends BaseOptions with TextTableOptions { - @Argument(usage = "") - var arguments: java.util.ArrayList[String] = new java.util.ArrayList[String]() - - @Args4jOption(required = true, name = "-n", aliases = Array("--name"), - usage = "name of key table") - var name: String = _ - - @Args4jOption(required = true, name = "-k", aliases = Array("--key-names"), - usage = "comma-separated list of columns to be considered as keys") - var keyNames: String = _ - - @Args4jOption(name = "--npartition", usage = "Number of partitions") - var nPartitions: java.lang.Integer = _ - } - - def newOptions = new Options - - def name = "importkeytable" - - def description = "import key table from tsv" - - def supportsMultiallelic = true - - def requiresVDS = false - - override def hidden = true - - def run(state: State, options: Options): State = { - val files = state.hadoopConf.globAll(options.arguments.asScala) - if (files.isEmpty) - fatal("Arguments referred to no files") - - val keyNames = Parser.parseIdentifierList(options.keyNames) - - val (struct, rdd) = - if (options.nPartitions != null) { - if (options.nPartitions < 1) - fatal("requested number of partitions in -n/--npartitions must be positive") - TextTableReader.read(state.sc)(files, options.config, options.nPartitions) - } else - TextTableReader.read(state.sc)(files, options.config) - - 
val keyNamesValid = keyNames.forall { k => - val res = struct.selfField(k).isDefined - if (!res) - println(s"Key `$k' is not present in input table") - res - } - if (!keyNamesValid) - fatal("Invalid key names given") - - state.copy(ktEnv = state.ktEnv + (options.name -> KeyTable(rdd.map(_.value), struct, keyNames))) - } -} - diff --git a/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala b/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala deleted file mode 100644 index 398caaf7585..00000000000 --- a/src/main/scala/org/broadinstitute/hail/driver/JoinKeyTable.scala +++ /dev/null @@ -1,89 +0,0 @@ -package org.broadinstitute.hail.driver - -import org.apache.spark.sql.Row -import org.broadinstitute.hail.expr.{EvalContext, _} -import org.broadinstitute.hail.io.TextExporter -import org.broadinstitute.hail.utils._ -import org.kohsuke.args4j.{Option => Args4jOption} - -object JoinKeyTable extends Command { - - class Options extends BaseOptions { - - @Args4jOption(required = true, name = "-d", aliases = Array("--dest"), - usage = "name of joined key-table") - var destName: String = _ - - @Args4jOption(required = true, name = "-l", aliases = Array("--left-name"), - usage = "name of key-table on left") - var leftName: String = _ - - @Args4jOption(required = true, name = "-r", aliases = Array("--right-name"), - usage = "name of key-table on right") - var rightName: String = _ - - @Args4jOption(required = false, name = "-t", aliases = Array("--join-type"), - usage = "type of join") - var joinType: String = "left" - } - - def newOptions = new Options - - def name = "joinkeytable" - - def description = "Join two key tables together to produce new key table" - - def supportsMultiallelic = true - - def requiresVDS = false - - override def hidden = true - - def run(state: State, options: Options): State = { - val ktEnv = state.ktEnv - val leftName = options.leftName - val rightName = options.rightName - val dest = options.destName - - val ktLeft = 
ktEnv.get(leftName) match { - case Some(kt) => kt - case None => fatal("no such key table $leftName in environment") - } - - val ktRight = ktEnv.get(rightName) match { - case Some(kt) => kt - case None => fatal("no such key table $rightName in environment") - } - - if (ktEnv.contains(dest)) - warn("destination name already exists -- overwriting previous key-table") - - val ktLeftFieldSet = ktLeft.fieldNames.toSet - val ktRightFieldSet = ktRight.fieldNames.toSet - - if (ktLeft.keySignature != ktRight.keySignature) - fatal( - s"""Key schemas are not Identical. - |Left KeyTable Schema: ${ ktLeft.keySchema } - |Right KeyTable Schema: ${ ktRight.keySchema } - """.stripMargin) - - val valueDuplicates = ktLeft.valueNames.intersect(ktRight.valueNames) - if (valueDuplicates.nonEmpty) - fatal( - s"""Invalid join operation: cannot merge key-tables with same-name fields. - |Found these fields in both tables: [ ${ valueDuplicates.mkString(", ") } ] - """.stripMargin) - - val joinedKT = options.joinType match { - case "left" => ktLeft.leftJoin(ktRight) - case "right" => ktLeft.rightJoin(ktRight) - case "inner" => ktLeft.innerJoin(ktRight) - case "outer" => ktLeft.outerJoin(ktRight) - case _ => fatal("Did not recognize join type. 
Pick one of [left, right, inner, outer].") - } - - state.copy(ktEnv = state.ktEnv + (dest -> joinedKT)) - } -} - diff --git a/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala b/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala index 3e7ada75e2b..f3d08f01e77 100644 --- a/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala +++ b/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala @@ -23,6 +23,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.kududb.spark.kudu.{KuduContext, _} import Variant.orderedKey +import org.broadinstitute.hail.keytable.KeyTable import org.broadinstitute.hail.methods.{Aggregators, Filter} import org.broadinstitute.hail.utils @@ -598,6 +599,74 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata, */ } + def aggregateByKey(keyCond: String, aggCond: String): KeyTable = { + val aggregationEC = EvalContext(Map( + "v" -> (0, TVariant), + "va" -> (1, vaSignature), + "s" -> (2, TSample), + "sa" -> (3, saSignature), + "global" -> (4, globalSignature))) + + val symTab = Map( + "v" -> (0, TVariant), + "va" -> (1, vaSignature), + "s" -> (2, TSample), + "sa" -> (3, saSignature), + "global" -> (4, globalSignature), + "gs" -> (-1, BaseAggregable(aggregationEC, TGenotype))) + + val ec = EvalContext(symTab) + + ec.set(4, globalAnnotation) + aggregationEC.set(4, globalAnnotation) + + val (keyNameParseTypes, keyF) = + if (keyCond != null) + Parser.parseAnnotationArgs(keyCond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val (aggNameParseTypes, aggF) = + if (aggCond != null) + Parser.parseAnnotationArgs(aggCond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val keyNames = keyNameParseTypes.map(_._1.head) + val aggNames = aggNameParseTypes.map(_._1.head) + + val keySignature = TStruct(keyNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + val valueSignature = 
TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + + val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) + + val seqOp = (array: Array[Aggregator], b: (Any, Any, Any, Any, Any)) => { + val (v, va, s, sa, aggT) = b + ec.set(0, v) + ec.set(1, va) + ec.set(2, s) + ec.set(3, sa) + for (i <- array.indices) { + array(i).seqOp(aggT) + } + array + } + + val ktRDD = mapPartitionsWithAll { it => + it.map { case (v, va, s, sa, g) => + ec.setAll(v, va, s, sa, g) + val key = Annotation.fromSeq(keyF.map(_ ())) + (key, (v, va, s, sa, g)) + } + }.aggregateByKey(zVals)(seqOp, combOp) + .map { case (k, agg) => + resultOp(agg) + (k, Annotation.fromSeq(aggF.map(_ ()))) + } + + KeyTable(ktRDD, keySignature, valueSignature) + } + def foldBySample(zeroValue: T)(combOp: (T, T) => T): RDD[(String, T)] = { val localtct = tct diff --git a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala index f78aebcb311..79269efe1e1 100644 --- a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala +++ b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala @@ -11,20 +11,17 @@ class AggregateByKeySuite extends SparkSuite { var s = State(sc, sqlContext) s = ImportVCF.run(s, Array(inputVCF)) s = AnnotateSamplesExpr.run(s, Array("-c", "sa.nHet = gs.filter(g => g.isHet).count()")) - s = AggregateByKey.run(s, Array("-k", "Sample = s", "-a", "nHet = gs.filter(g => g.isHet).count()", "-n", "kt")) - - val kt = s.ktEnv("kt") - val vds = s.vds + val kt = s.vds.aggregateByKey("Sample = s", "nHet = gs.filter(g => g.isHet).count()") val (_, ktHetQuery) = kt.query("nHet") val (_, ktSampleQuery) = kt.query("Sample") - val (_, saHetQuery) = vds.querySA("sa.nHet") + val (_, saHetQuery) = s.vds.querySA("sa.nHet") val ktSampleResults = kt.rdd.map { case (k, v) => (ktSampleQuery(k, v).map(_.asInstanceOf[String]), ktHetQuery(k, v).map(_.asInstanceOf[Long])) 
}.collectAsMap() - assert(vds.sampleIdsAndAnnotations.forall { case (sid, sa) => saHetQuery(sa) == ktSampleResults(Option(sid)) }) + assert(s.vds.sampleIdsAndAnnotations.forall { case (sid, sa) => saHetQuery(sa) == ktSampleResults(Option(sid)) }) } @Test def replicateVariantAggregation() = { @@ -32,20 +29,17 @@ class AggregateByKeySuite extends SparkSuite { var s = State(sc, sqlContext) s = ImportVCF.run(s, Array(inputVCF)) s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()")) - s = AggregateByKey.run(s, Array("-k", "Variant = v", "-a", "nHet = gs.filter(g => g.isHet).count()", "-n", "kt")) - - val kt = s.ktEnv("kt") - val vds = s.vds + val kt = s.vds.aggregateByKey("Variant = v", "nHet = gs.filter(g => g.isHet).count()") val (_, ktHetQuery) = kt.query("nHet") val (_, ktVariantQuery) = kt.query("Variant") - val (_, vaHetQuery) = vds.queryVA("va.nHet") + val (_, vaHetQuery) = s.vds.queryVA("va.nHet") val ktVariantResults = kt.rdd.map { case (k, v) => (ktVariantQuery(k, v).map(_.asInstanceOf[Variant]), ktHetQuery(k, v).map(_.asInstanceOf[Long])) }.collectAsMap() - assert(vds.variantsAndAnnotations.forall { case (v, va) => vaHetQuery(va) == ktVariantResults(Option(v)) }) + assert(s.vds.variantsAndAnnotations.forall { case (v, va) => vaHetQuery(va) == ktVariantResults(Option(v)) }) } @Test def replicateGlobalAggregation() = { @@ -54,13 +48,10 @@ class AggregateByKeySuite extends SparkSuite { s = ImportVCF.run(s, Array(inputVCF)) s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()")) s = AnnotateGlobalExpr.run(s, Array("-c", "global.nHet = variants.map(v => va.nHet).sum().toLong")) - s = AggregateByKey.run(s, Array("-a", "nHet = gs.filter(g => g.isHet).count()", "-n", "kt")) - - val kt = s.ktEnv("kt") - val vds = s.vds + val kt = s.vds.aggregateByKey(null, "nHet = gs.filter(g => g.isHet).count()") val (_, ktHetQuery) = kt.query("nHet") - val (_, globalHetResult) = vds.queryGlobal("global.nHet") + val 
(_, globalHetResult) = s.vds.queryGlobal("global.nHet") val ktGlobalResult = kt.rdd.map { case (k, v) => ktHetQuery(k, v).map(_.asInstanceOf[Long]) }.collect().head val vdsGlobalResult = globalHetResult.map(_.asInstanceOf[Long]) From 3bcf900a1c8162544db120c0359aa1835e480df6 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Thu, 10 Nov 2016 15:43:46 -0500 Subject: [PATCH 35/51] lots of debug statements for aggregate; finished join/forall/exists tests --- python/pyhail/context.py | 21 +++- python/pyhail/dataset.py | 10 +- python/pyhail/keytable.py | 96 ++++++++++--------- .../pyhail/{TextTableConfig.py => utils.py} | 23 ++++- .../org/broadinstitute/hail/expr/AST.scala | 31 +++++- .../org/broadinstitute/hail/expr/Type.scala | 15 ++- .../hail/keytable/KeyTable.scala | 86 ++++++++++++++++- .../hail/methods/Aggregators.scala | 4 +- src/test/resources/sampleAnnotations2.tsv | 82 ++++++++++++++++ .../org/broadinstitute/hail/SparkSuite.scala | 2 +- .../hail/methods/KeyTableSuite.scala | 50 +++++++++- 11 files changed, 353 insertions(+), 67 deletions(-) rename python/pyhail/{TextTableConfig.py => utils.py} (56%) diff --git a/python/pyhail/context.py b/python/pyhail/context.py index 440d647bf9f..a0a4a5e75b8 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -193,6 +193,21 @@ def import_annotations_table(self, path, variant_expr, code=None, npartitions=No def import_bgen(self, path, tolerance=0.2, sample_file=None, npartitions=None): ======= def import_keytable(self, path, key_names, npartition = None, config = None): + """Import tabular file as KeyTable + + :param path: .tsv files to import. + :type path: str or list of str + + :param key_names: The name(s) of fields to be considered keys + :type key_names: str or list[str] + + :param npartition: Number of partitions. 
+ :type npartition: int or None + + :param :class:`.TextTableConfig` config: Configuration options for importing text files + + :rtype: :class:`.KeyTable` + """ pathArgs = [] if isinstance(path, str): pathArgs.append(path) @@ -204,12 +219,12 @@ def import_keytable(self, path, key_names, npartition = None, config = None): key_names = ",".join(key_names) if not npartition: - npartition = 1 + npartition = self.sc.defaultMinPartitions if not config: - config = TextTableConfig().asJavaObject(self) + config = TextTableConfig().toJavaObject(self) elif isinstance(key_names, TextTableConfig): - config = config.asJavaObject(self) + config = config.toJavaObject(self) return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable(self.jsc, jarray(self.gateway, self.jvm.java.lang.String, pathArgs), key_names, npartition, config)) diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py index 635ae7d1737..c8a09a82362 100644 --- a/python/pyhail/dataset.py +++ b/python/pyhail/dataset.py @@ -9,21 +9,19 @@ def __init__(self, hc, jvds): self.hc = hc self.jvds = jvds - def _raise_py4j_exception(self, e): self.hc._raise_py4j_exception(e) - def aggregate_by_key(self, key_condition = None, agg_condition = None): + def aggregate_by_key(self, key_condition=None, agg_condition=None): """Aggregate by user-defined key and aggregation expressions - :param str key_condition: Named expression for how key + :param str key_condition: Named expression for which fields are keys :param str agg_condition: Named aggregation expression. - :return: KeyTable - + :rtype: :class`.KeyTable` """ - return KeyTable(self.hc, self.jvds.aggregateByKey(key_cond, agg_condition)) + return KeyTable(self.hc, self.jvds.aggregateByKey(key_condition, agg_condition)) def aggregate_intervals(self, input, condition, output): """Aggregate over intervals and export. 
diff --git a/python/pyhail/keytable.py b/python/pyhail/keytable.py index 73d256838ab..656b35aa071 100644 --- a/python/pyhail/keytable.py +++ b/python/pyhail/keytable.py @@ -1,34 +1,44 @@ from pyhail.java import scala_object class KeyTable: - """:class:`.KeyTable` ... + """:class:`.KeyTable` is Hail's version of a SQL + table where fields can be designated as keys. - :param SparkContext sc: The pyspark context. - :param JavaKeyTable jkt: The java key table object. + :param :class:`.HailContext` hc: Hail spark context. + :param JavaKeyTable jkt: Java KeyTable object. """ def __init__(self, hc, jkt): self.hc = hc self.jkt = jkt - # FIXME schema stuff... - def nKeys(self): - return self.jkt.nKeys() - - def nValues(self): - return self.jkt.nValues() - def nFields(self): + """Number of fields in the key-table + + :return: int + """ return self.jkt.nFields() def schema(self): + """Key-table schema + + :return: ??? + """ return self.jkt.schema() def keyNames(self): + """Field names that are keys + + :return: list[str] + """ return self.jkt.keyNames() - def valueNames(self): - return self.jkt.valueNames() + def fieldNames(self): + """Field names + + :return: list[str] + """ + return self.jkt.fieldNames() def nRows(self): """Number of rows in the key-table @@ -38,22 +48,22 @@ def nRows(self): return self.jkt.nRows() def same(self, other): - """Compares two key-tables + """Test whether two key-tables are identical - :param KeyTable other: KeyTable to compare to + :param :class:`.KeyTable` other: KeyTable to compare to :return: bool """ return self.jkt.same(other.jkt) def export(self, output, types_file = None): - """Export key-table to a tsv file. + """Export key-table to a TSV file. :param str output: Output file path :param str types_file: Output path of types file - :return: Nothing. + :rtype: Nothing. 
""" self.jkt.export(self.hc.jsc, output, types_file) @@ -64,7 +74,7 @@ def filter(self, code, keep = True): :param bool keep: Keep rows where annotation expression evaluates to True - :return: KeyTable + :return: :class:`.KeyTable` """ return KeyTable(self.hc, self.jkt.filter(code, keep)) @@ -73,9 +83,9 @@ def annotate(self, code, key_names = None): :param str code: Annotation expression. - :param bool keep: Keep rows where annotation expression evaluates to True + :param str key_names: Comma separated list of field names to be treated as a key - :return: KeyTable + :return: :class:`.KeyTable` """ return KeyTable(self.hc, self.jkt.annotate(code, key_names)) @@ -83,43 +93,39 @@ def join(self, right, how = 'inner'): """Join two key-tables together. Both key-tables must have identical key schemas and non-overlapping fields in order to be joined. - :param KeyTable right: key-table to join + :param :class:`.KeyTable` right: Key-table to join :param str how: Method for joining two tables together. One of "inner", "outer", "left", "right". 
- :return: KeyTable + :return: :class:`.KeyTable` """ - ## Check keys are same + return KeyTable(self.hc, self.jkt.join(right.jkt, how)) - ## Check fields do not overlap + def aggregate(self, key_cond, agg_cond): + """Group by key condition and aggregate results - if how == "inner": - return KeyTable(self.hc, self.jkt.innerJoin(right.jkt)) - elif how == "outer": - return KeyTable(self.hc, self.jkt.outerJoin(right.jkt)) - elif how == "left": - return KeyTable(self.hc, self.jkt.leftJoin(right.jkt)) - elif how == "right": - return KeyTable(self.hc, self.jkt.rightJoin(right.jkt)) - else: - pass + :param str key_cond: Named expression defining keys in the new key-table + :param str agg_cond: Named expression specifying how new fields are computed + :return: :class:`.KeyTable` + """ + return KeyTable(self.hc, self.jkt.aggregate(key_cond, agg_cond)) -# def import_fam(hc, path, ...): -# pass + def forall(self, code): + """Tests whether a condition is true for all rows + :param str code: Boolean expression - # kt.select(star().except('a'), expr('sum', 'x + b'), expr('a', 'a.b.c = 9')) - # kt.select(star().except('a'), {'sum': 'x + b', 'a': 'update(a.b.c, 9)'}) + :return: bool + """ + return self.jkt.forall(cond) + def exists(self, code): + """Tests whether a condition is true for any row - # def for_all(self, condition): - # pass + :param str code: Boolean expression - # FIXME returns TypedValue - # def aggregate(value expressions...): - # pass - # - # def aggregate_by_key(self, value expressions...): - # pass + :return: bool + """ + return self.jkt.exists(cond) \ No newline at end of file diff --git a/python/pyhail/TextTableConfig.py b/python/pyhail/utils.py similarity index 56% rename from python/pyhail/TextTableConfig.py rename to python/pyhail/utils.py index 473fb44d17d..c4a9da4256a 100644 --- a/python/pyhail/TextTableConfig.py +++ b/python/pyhail/utils.py @@ -1,7 +1,20 @@ -from pyhail.java import scala_object class TextTableConfig: + """:class:`.TextTableConfig` 
specifies additional options for importing TSV files. + + :param bool noheader: File has no header and columns should be indicated by `_1, _2, ... _N' (0-indexed) + + :param bool impute: Impute column types from the file + + :param str comment: Skip lines beginning with the given pattern + + :param str delimiter: Field delimiter regex + + :param str missing: Specify identifier to be treated as missing + + :param str types: Define types of fields in annotations files + """ def __init__(self, noheader = False, impute = False, comment = None, delimiter = "\t", missing = "NA", types = None): self.noheader = noheader @@ -11,7 +24,7 @@ def __init__(self, noheader = False, impute = False, self.missing = missing self.types = types - def asString(self): + def __str__(self): res = ["--comment", self.comment, "--delimiter", self.delimiter, "--missing", self.missing] @@ -23,7 +36,11 @@ def asString(self): return " ".join(res) - def asJavaObject(self, hc): + def toJavaObject(self, hc): + """Convert to java TextTableConfiguration object + + :param :class:`.HailContext` hc: Hail spark context. 
+ """ return hc.jvm.org.broadinstitute.hail.utils.TextTableConfiguration.apply(self.types, self.comment, self.delimiter, self.missing, self.noheader, self.impute) diff --git a/src/main/scala/org/broadinstitute/hail/expr/AST.scala b/src/main/scala/org/broadinstitute/hail/expr/AST.scala index 08b199ff3bc..7a7312b54fa 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/AST.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/AST.scala @@ -20,7 +20,15 @@ import org.broadinstitute.hail.utils.EitherIsAMonad._ case class EvalContext(st: SymbolTable, a: ArrayBuffer[Any], aggregationFunctions: ArrayBuffer[Aggregator]) { def setAll(args: Any*) { - args.zipWithIndex.foreach { case (arg, i) => a(i) = arg } + try { + args.zipWithIndex.foreach { case (arg, i) => +// println(s"$arg, $i a=$a") + a(i) = arg +// println(s"$arg, $i a=$a st=$st") + } + } catch { + case _: IndexOutOfBoundsException => println("error") + } } def set(index: Int, arg: Any) { @@ -661,11 +669,25 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST body.typecheck(agg.ec.copy(st = st)) `type` = body.`type` match { case t: Type => + println(s"maptypecheck bodyType: ${body.getClass}") + println(s"maptypecheck param: ${param}") + println(s"maptypecheck localIdx: $localIdx") + println(s"maptypecheck elementtype: ${agg.elementType}") + println(s"maptypecheck localA: ${localA}") + println(s"maptypecheck localA identityCode: ${System.identityHashCode(localA)}") + println(s"maptypecheck ec: ${agg.ec}") + println(s"maptypecheck ec identityCode: ${System.identityHashCode(agg.ec)}") val fn = body.eval(agg.ec.copy(st = st)) val mapF = (a: Any) => { localA(localIdx) = a + println(s"mapF ec: ${agg.ec}") + println(s"mapF LocalA: ${localA}") + println(s"mapF ec identity code: ${System.identityHashCode(agg.ec)}") + println(s"mapF LocalA identity code: ${System.identityHashCode(localA)}") + println(s"mapF result eval fn(): ${fn()}") fn() } + MappedAggregable(agg, t, mapF) case error => 
parseError(s"method `$method' expects a lambda function (param => Any), got invalid mapping (param => $error)") @@ -1696,8 +1718,15 @@ case class SliceArray(posn: Position, f: AST, idx1: Option[AST], idx2: Option[AS case class SymRef(posn: Position, symbol: String) extends AST(posn) { def eval(ec: EvalContext): () => Any = { + println(s"symref posn: $posn") + println(s"symref symbol: $symbol") val localI = ec.st(symbol)._1 val localA = ec.a + println(s"symref ec: ${ec}") + println(s"symref localI: ${ec.st(symbol)._1}") + println(s"symref localA: ${ec.a}") + println(s"symref ec identityhash: ${System.identityHashCode(ec)}") + println(s"symref ec.a identityhash: ${System.identityHashCode(ec.a)}") if (localI < 0) () => 0 // FIXME placeholder else diff --git a/src/main/scala/org/broadinstitute/hail/expr/Type.scala b/src/main/scala/org/broadinstitute/hail/expr/Type.scala index 6abd5d0260f..6b2a73a2941 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Type.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Type.scala @@ -6,6 +6,7 @@ import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.annotations.{Annotation, AnnotationPathException, _} import org.broadinstitute.hail.check.Arbitrary._ import org.broadinstitute.hail.check.{Gen, _} +import org.broadinstitute.hail.keytable.KeyTable import org.broadinstitute.hail.utils import org.broadinstitute.hail.utils.{Interval, StringEscapeUtils} import org.broadinstitute.hail.variant.{AltAllele, Genotype, Locus, Variant} @@ -249,6 +250,14 @@ abstract class TAggregable extends BaseType { def f: (Any) => Any } +case class KeyTableAggregable(ec: EvalContext, elementType: Type, idx: Int) extends TAggregable { + def f: (Any) => Any = { + (a: Any) => { + KeyTable.annotationToSeq(a, ec.st.size)(idx) + } + } +} + case class BaseAggregable(ec: EvalContext, elementType: Type) extends TAggregable { def f: (Any) => Any = identity } @@ -274,9 +283,13 @@ case class MappedAggregable(parent: TAggregable, elementType: Type, 
mapF: (Any) def f: (Any) => Any = { val parentF = parent.f (a: Any) => { + println(s"mapagg a: $a") val prev = parentF(a) - if (prev != null) + println(s"mapagg prev: $prev") + if (prev != null) { + println(s"mapagg result: ${mapF(prev)}") mapF(prev) + } else null } diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 83bcd013ece..cd5032a6adf 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -5,8 +5,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.broadinstitute.hail.annotations._ import org.broadinstitute.hail.check.Gen -import org.broadinstitute.hail.expr.{BaseType, EvalContext, Parser, TBoolean, TStruct, Type} -import org.broadinstitute.hail.methods.Filter +import org.broadinstitute.hail.expr._ +import org.broadinstitute.hail.methods.{Aggregators, Filter} import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.io.TextExporter @@ -21,6 +21,8 @@ object KeyTable extends Serializable with TextExporter { if (files.isEmpty) fatal("Arguments referred to no files") + sc.defaultMinPartitions + val keyNameArray = Parser.parseIdentifierList(keyNames) val (struct, rdd) = @@ -206,7 +208,8 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v KeyTable(mapAnnotations(f), finalSignature, keyNameArray) } - def filter(p: (Annotation, Annotation) => Boolean): KeyTable = copy(rdd = rdd.filter { case (k, v) => p(k, v) }) + def filter(p: (Annotation, Annotation) => Boolean): KeyTable = + copy(rdd = rdd.filter { case (k, v) => p(k, v) }) def filter(cond: String, keep: Boolean): KeyTable = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) @@ -221,6 +224,26 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v filter(p) } + def join(other: KeyTable, joinType: String): KeyTable = { 
+ if (keySignature != other.keySignature) + fatal(s"""Key signatures must be identical. + |Left signature: ${keySignature} + |Right signature: ${other.keySignature}""".stripMargin) + + val overlappingFields = valueNames.toSet.intersect(other.valueNames.toSet) + if (overlappingFields.nonEmpty) + fatal(s"""Fields that are not keys cannot be present in both key-tables. + |Overlapping fields: ${overlappingFields.mkString(", ")}""".stripMargin) + + joinType match { + case "left" => leftJoin(other) + case "right" => rightJoin(other) + case "inner" => innerJoin(other) + case "outer" => outerJoin(other) + case _ => fatal("Invalid join type specified. Choose one of `left', `right', `inner', `outer'") + } + } + def leftJoin(other: KeyTable): KeyTable = { require(keySignature == other.keySignature) @@ -315,4 +338,61 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v }.writeTable(output, header.map(_.mkString("\t"))) } + def aggregate(keyCond: String, aggCond: String): KeyTable = { + + val aggregationEC = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) + val ec = EvalContext(fields.zipWithIndex.map{ case (fd, i) => (fd.name, (-1, KeyTableAggregable(aggregationEC, fd.`type`, i)))}.toMap) + + val (keyNameParseTypes, keyF) = + if (keyCond != null) + Parser.parseAnnotationArgs(keyCond, aggregationEC, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val (aggNameParseTypes, aggF) = + if (aggCond != null) + Parser.parseAnnotationArgs(aggCond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val keyNames = keyNameParseTypes.map(_._1.head) + val aggNames = aggNameParseTypes.map(_._1.head) + + val keySignature = TStruct(keyNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + val valueSignature = TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + + val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) + + val seqOp = (array: Array[Aggregator], 
b: Any) => { + println(s"values inside b = " + KeyTable.annotationToSeq(b, nFields)) + println(s"keytable seqop pre-setec ec.a: ${aggregationEC.a}") + println(s"keytable seqop pre-setec ec: ${aggregationEC}") + println(s"keytable seqop pre-setec ec.a pointer: ${System.identityHashCode(aggregationEC.a)}") + println(s"keytable seqop pre-setec ec pointer: ${System.identityHashCode(aggregationEC)}") + KeyTable.setEvalContext(aggregationEC, b, nFields) + println(s"keytable seqop post-setec ec.a: ${aggregationEC.a}") + println(s"keytable seqop post-setec ec: ${aggregationEC}") + println(s"keytable seqop post-setec ec.a pointer: ${System.identityHashCode(aggregationEC.a)}") + println(s"keytable seqop post-setec ec pointer: ${System.identityHashCode(aggregationEC)}") + for (i <- array.indices) { + println(s"keytable seqop array($i): ${array(i)}") + array(i).seqOp(b) + } + array + } + + val newRDD = KeyTable.toSingleRDD(rdd, nKeys, nValues).mapPartitions{ it => + it.map { a => + KeyTable.setEvalContext(aggregationEC, a, nFields) + val key = Annotation.fromSeq(keyF.map(_ ())) + (key, a) + } + }.aggregateByKey(zVals)(seqOp, combOp) + .map { case (k, agg) => + resultOp(agg) + (k, Annotation.fromSeq(aggF.map(_ ()))) + } + + KeyTable(newRDD, keySignature, valueSignature) + } } \ No newline at end of file diff --git a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala index 00be90c835f..1e57876a3ad 100644 --- a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala +++ b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala @@ -303,8 +303,10 @@ class SumAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[Doubl def result = _state - def seqOp(x: Any) { + override def seqOp(x: Any) { + println(s"sumagg seqop input: $x") val r = f(x) + println(s"sumagg seqop result: $r") if (r != null) _state += DoubleNumericConversion.to(r) } diff --git 
a/src/test/resources/sampleAnnotations2.tsv b/src/test/resources/sampleAnnotations2.tsv index 16452d99b0c..c4ded00778e 100644 --- a/src/test/resources/sampleAnnotations2.tsv +++ b/src/test/resources/sampleAnnotations2.tsv @@ -40,3 +40,85 @@ HG00137 3757.2 18776 HG00138 1799 8985 HG00139 385.6 1918 HG00140 0 -10 +HG00096_B 5540.8 27694 +HG00097_B 3327.2 16626 +HG00099_B 1451.2 7246 +HG00100_B 5714.8 28564 +HG00101_B 2417.6 12078 +HG00102_B 3948 19730 +HG00103_B 372.2 1851 +HG00105_B 4455.6 22268 +HG00106_B 5296.8 26474 +HG00107_B 5945.2 29716 +HG00108_B 3295 16465 +HG00109_B 6519 32585 +HG00110_B 4163.2 20806 +HG00111_B 6013 30055 +HG00112_B 4918 24580 +HG00113_B 1769 8835 +HG00114_B 6251 31245 +HG00115_B 5638 28180 +HG00116_B 2548.4 12732 +HG00117_B 4724.4 23612 +HG00118_B 3573.4 17857 +HG00119_B 6177.2 30876 +HG00120_B 3919.8 19589 +HG00121_B 966.4 4822 +HG00122_B 0 -10 +HG00123_B 5662.2 28301 +HG00124_B 538.2 2681 +HG00125_B 2893.2 14456 +HG00126_B 5506 27520 +HG00127_B 2044.8 10214 +HG00128_B 561.4 2797 +HG00129_B 1630.2 8141 +HG00130_B 5212 26050 +HG00131_B 4312.4 21552 +HG00132_B 2222.4 11102 +HG00133_B 4943.2 24706 +HG00136_B 2469.6 12338 +HG00137_B 3757.2 18776 +HG00138_B 1799 8985 +HG00139_B 385.6 1918 +HG00140_B 0 -10 +HG00096_B_B 5540.8 27694 +HG00097_B_B 3327.2 16626 +HG00099_B_B 1451.2 7246 +HG00100_B_B 5714.8 28564 +HG00101_B_B 2417.6 12078 +HG00102_B_B 3948 19730 +HG00103_B_B 372.2 1851 +HG00105_B_B 4455.6 22268 +HG00106_B_B 5296.8 26474 +HG00107_B_B 5945.2 29716 +HG00108_B_B 3295 16465 +HG00109_B_B 6519 32585 +HG00110_B_B 4163.2 20806 +HG00111_B_B 6013 30055 +HG00112_B_B 4918 24580 +HG00113_B_B 1769 8835 +HG00114_B_B 6251 31245 +HG00115_B_B 5638 28180 +HG00116_B_B 2548.4 12732 +HG00117_B_B 4724.4 23612 +HG00118_B_B 3573.4 17857 +HG00119_B_B 6177.2 30876 +HG00120_B_B 3919.8 19589 +HG00121_B_B 966.4 4822 +HG00122_B_B 0 -10 +HG00123_B_B 5662.2 28301 +HG00124_B_B 538.2 2681 +HG00125_B_B 2893.2 14456 +HG00126_B_B 5506 27520 +HG00127_B_B 2044.8 10214 
+HG00128_B_B 561.4 2797 +HG00129_B_B 1630.2 8141 +HG00130_B_B 5212 26050 +HG00131_B_B 4312.4 21552 +HG00132_B_B 2222.4 11102 +HG00133_B_B 4943.2 24706 +HG00136_B_B 2469.6 12338 +HG00137_B_B 3757.2 18776 +HG00138_B_B 1799 8985 +HG00139_B_B 385.6 1918 +HG00140_B_B 0 -10 diff --git a/src/test/scala/org/broadinstitute/hail/SparkSuite.scala b/src/test/scala/org/broadinstitute/hail/SparkSuite.scala index bbe2cddac4a..3f35a119643 100644 --- a/src/test/scala/org/broadinstitute/hail/SparkSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/SparkSuite.scala @@ -24,7 +24,7 @@ class SparkSuite extends TestNGSuite { @BeforeClass def startSpark() { val master = System.getProperty("hail.master") - sc = SparkManager.createSparkContext("Hail.TestNG", Option(master), "local[2]") + sc = SparkManager.createSparkContext("Hail.TestNG", Option(master), "local[1]") sqlContext = SparkManager.createSQLContext() diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index ef96b4e115c..3412c8d8d9f 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -105,6 +105,7 @@ class KeyTableSuite extends SparkSuite { val (_, leftKeyQuery) = ktLeft.query("Sample") val (_, rightKeyQuery) = ktRight.query("Sample") val (_, leftJoinKeyQuery) = ktLeftJoin.query("Sample") + val (_, rightJoinKeyQuery) = ktRightJoin.query("Sample") val leftKeys = ktLeft.rdd.map { case (k, v) => leftKeyQuery(k, v).map(_.asInstanceOf[String]) }.collect().toSet val rightKeys = ktRight.rdd.map { case (k, v) => rightKeyQuery(k, v).map(_.asInstanceOf[String]) }.collect().toSet @@ -115,12 +116,18 @@ class KeyTableSuite extends SparkSuite { assert(ktLeftJoin.nRows == ktLeft.nRows && ktLeftJoin.nKeys == nExpectedKeys && - ktLeftJoin.nValues == nExpectedValues + ktLeftJoin.nValues == nExpectedValues && + ktLeftJoin.filter{ case (k, v) => + 
!rightKeys.contains(leftJoinKeyQuery(k, v).map(_.asInstanceOf[String])) + }.forall("isMissing(qPhen2) && isMissing(qPhen3)") ) assert(ktRightJoin.nRows == ktRight.nRows && ktRightJoin.nKeys == nExpectedKeys && - ktRightJoin.nValues == nExpectedValues) + ktRightJoin.nValues == nExpectedValues && + ktRightJoin.filter{ case (k, v) => + !leftKeys.contains(rightJoinKeyQuery(k, v).map(_.asInstanceOf[String])) + }.forall("isMissing(Status) && isMissing(qPhen)")) assert(ktOuterJoin.nRows == nUnionRows && ktOuterJoin.nKeys == ktLeft.nKeys && @@ -131,5 +138,42 @@ class KeyTableSuite extends SparkSuite { ktInnerJoin.nValues == nExpectedValues) } - @Test def testAggregate() {} + @Test def testAggregate() { + val data = Array(Array("Case", 9, 0), Array("Case", 3, 4), Array("Control", 2, 3), Array("Control", 1, 5)) + val rdd = sc.parallelize(data.map(Annotation.fromSeq(_))) + val signature = TStruct(("field1", TString), ("field2", TInt), ("field3", TInt)) + val keyNames = Array("field1") + + val kt1 = KeyTable(rdd, signature, keyNames) +// val kt2 = kt1.aggregate("Status = field1", "field4 = field2.sum(), field5 = field2.map(f => field2 + field3).sum()") + //val kt2 = kt1.aggregate("Status = field1", "field5 = field2.map(f => field2 + field3).sum()") +// val kt2 = kt1.aggregate("Status = field1", "X = field2.map(f => field2).sum(), Y = field2.sum(), Z = field2.map(f => f).sum()") +// val result = Array(Array("Case", 12.0, 12.0, 12.0), Array("Control", 3.0, 3.0, 3.0)) +// val resRDD = sc.parallelize(result.map(Annotation.fromSeq(_))) +// val resSignature = TStruct(("Status", TString), ("X", TDouble), ("Y", TDouble), ("Z", TDouble)) +// val ktResult = KeyTable(resRDD, resSignature, keyNames = Array("Status")) + + val kt2 = kt1.aggregate("Status = field1", "X = field2.map(f => field2).sum()") + val result = Array(Array("Case", 12.0), Array("Control", 3.0)) + val resRDD = sc.parallelize(result.map(Annotation.fromSeq(_))) + val resSignature = TStruct(("Status", TString), ("X", 
TDouble)) + val ktResult = KeyTable(resRDD, resSignature, keyNames = Array("Status")) + + + assert(kt2 same ktResult) + } + + @Test def testForallExists() { + val data = Array(Array("Sample1", 9, 5), Array("Sample2", 3, 5), Array("Sample3", 2, 5), Array("Sample4", 1, 5)) + val rdd = sc.parallelize(data.map(Annotation.fromSeq(_))) + val signature = TStruct(("Sample", TString), ("field1", TInt), ("field2", TInt)) + val keyNames = Array("Sample") + + val kt = KeyTable(rdd, signature, keyNames) + assert(kt.forall("field2 == 5 && field1 != 0")) + assert(!kt.forall("field2 == 0 && field1 == 5")) + assert(kt.exists("""Sample == "Sample1" && field1 == 9 && field2 == 5""")) + assert(!kt.exists("""Sample == "Sample1" && field1 == 13 && field2 == 2""")) + } + } From fb9cadf95c4d0270a80197a481f74cf2e8733c9e Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Sun, 13 Nov 2016 14:01:23 -0500 Subject: [PATCH 36/51] started getting docs to work --- python/pyhail/__init__.py | 3 +- python/pyhail/context.py | 85 ++++++++++++++++-------------------- python/pyhail/docs/index.rst | 6 +++ python/pyhail/keytable.py | 46 ++++++++++--------- python/pyhail/utils.py | 4 +- 5 files changed, 73 insertions(+), 71 deletions(-) diff --git a/python/pyhail/__init__.py b/python/pyhail/__init__.py index 555e30c509a..fd18b118917 100644 --- a/python/pyhail/__init__.py +++ b/python/pyhail/__init__.py @@ -1,5 +1,6 @@ from pyhail.context import HailContext from pyhail.dataset import VariantDataset from pyhail.keytable import KeyTable +from pyhail.utils import TextTableConfig -__all__ = ["HailContext", "VariantDataset", "KeyTable"] +__all__ = ["HailContext", "VariantDataset", "KeyTable", "TextTableConfig"] diff --git a/python/pyhail/context.py b/python/pyhail/context.py index a0a4a5e75b8..36bc4d234fb 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -187,50 +187,9 @@ def import_annotations_table(self, path, variant_expr, code=None, npartitions=No if impute: pargs.append('--impute') 
-<<<<<<< 488ab167ba42a286d23df1a1d6ed47c0be48d831 return self.run_command(None, pargs) - def import_bgen(self, path, tolerance=0.2, sample_file=None, npartitions=None): -======= - def import_keytable(self, path, key_names, npartition = None, config = None): - """Import tabular file as KeyTable - - :param path: .tsv files to import. - :type path: str or list of str - - :param key_names: The name(s) of fields to be considered keys - :type key_names: str or list[str] - - :param npartition: Number of partitions. - :type npartition: int or None - - :param :class:`.TextTableConfig` config: Configuration options for importing text files - - :rtype: :class:`.KeyTable` - """ - pathArgs = [] - if isinstance(path, str): - pathArgs.append(path) - else: - for p in path: - pathArgs.append(p) - - if not isinstance(key_names, str): - key_names = ",".join(key_names) - - if not npartition: - npartition = self.sc.defaultMinPartitions - - if not config: - config = TextTableConfig().toJavaObject(self) - elif isinstance(key_names, TextTableConfig): - config = config.toJavaObject(self) - - return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable(self.jsc, jarray(self.gateway, self.jvm.java.lang.String, pathArgs), - key_names, npartition, config)) - def import_bgen(self, path, tolerance = 0.2, sample_file = None, npartition = None): ->>>>>>> Added most key table operations to pyhail """Import .bgen files as VariantDataset :param path: .bgen files to import. 
@@ -267,13 +226,8 @@ def import_bgen(self, path, tolerance = 0.2, sample_file = None, npartition = No pargs.append('--tolerance') pargs.append(str(tolerance)) -<<<<<<< 488ab167ba42a286d23df1a1d6ed47c0be48d831 -======= - - return self._run_command(None, pargs) ->>>>>>> Added most key table operations to pyhail - return self.run_command(None, pargs) + return self._run_command(None, pargs) def import_gen(self, path, tolerance=0.2, sample_file=None, npartitions=None, chromosome=None): """Import .bgen files as VariantDataset @@ -323,6 +277,43 @@ def import_gen(self, path, tolerance=0.2, sample_file=None, npartitions=None, ch return self.run_command(None, pargs) + def import_keytable(self, path, key_names, npartition = None, config = None): + """Import tabular file as KeyTable + + :param path: .tsv files to import. + :type path: str or list of str + + :param key_names: The name(s) of fields to be considered keys + :type key_names: str or list[str] + + :param npartition: Number of partitions. + :type npartition: int or None + + :param :class:`.TextTableConfig` config: Configuration options for importing text files + + :rtype: :class:`.KeyTable` + """ + pathArgs = [] + if isinstance(path, str): + pathArgs.append(path) + else: + for p in path: + pathArgs.append(p) + + if not isinstance(key_names, str): + key_names = ",".join(key_names) + + if not npartition: + npartition = self.sc.defaultMinPartitions + + if not config: + config = TextTableConfig()._toJavaObject(self) + elif isinstance(key_names, TextTableConfig): + config = config._toJavaObject(self) + + return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable(self.jsc, jarray(self.gateway, self.jvm.java.lang.String, pathArgs), + key_names, npartition, config)) + def import_plink(self, bed, bim, fam, npartitions=None, delimiter='\\\\s+', missing="NA", quantpheno=False): """ Import PLINK binary file (.bed, .bim, .fam) as VariantDataset diff --git a/python/pyhail/docs/index.rst 
b/python/pyhail/docs/index.rst index c743f23687e..228c112559c 100644 --- a/python/pyhail/docs/index.rst +++ b/python/pyhail/docs/index.rst @@ -17,6 +17,12 @@ Contents: .. autoclass:: pyhail.VariantDataset :members: +.. autoclass:: pyhail.KeyTable + :members: + +.. autoclass:: pyhail.TextTableConfig + :members: + Indices and tables ================== diff --git a/python/pyhail/keytable.py b/python/pyhail/keytable.py index 656b35aa071..dbe6efdf85e 100644 --- a/python/pyhail/keytable.py +++ b/python/pyhail/keytable.py @@ -1,58 +1,61 @@ -from pyhail.java import scala_object class KeyTable: """:class:`.KeyTable` is Hail's version of a SQL - table where fields can be designated as keys. - - :param :class:`.HailContext` hc: Hail spark context. - :param JavaKeyTable jkt: Java KeyTable object. + table where fields can be designated as keys. """ def __init__(self, hc, jkt): + """ + :param hc: Hail spark context. + :type hc: :class:`.HailContext` + + :param JavaKeyTable jkt: Java KeyTable object. + """ self.hc = hc self.jkt = jkt def nFields(self): """Number of fields in the key-table - :return: int + :rtype: int """ return self.jkt.nFields() def schema(self): """Key-table schema - :return: ??? + :rtype: ??? 
""" return self.jkt.schema() def keyNames(self): """Field names that are keys - :return: list[str] + :rtype: list[str] """ return self.jkt.keyNames() def fieldNames(self): - """Field names + """Names of all fields in the key-table - :return: list[str] + :rtype: list[str] """ return self.jkt.fieldNames() def nRows(self): """Number of rows in the key-table - :return: long + :rtype: long """ return self.jkt.nRows() def same(self, other): """Test whether two key-tables are identical - :param :class:`.KeyTable` other: KeyTable to compare to + :param other: KeyTable to compare to + :type other: :class:`.KeyTable` - :return: bool + :rtype: bool """ return self.jkt.same(other.jkt) @@ -74,7 +77,7 @@ def filter(self, code, keep = True): :param bool keep: Keep rows where annotation expression evaluates to True - :return: :class:`.KeyTable` + :rtype: :class:`.KeyTable` """ return KeyTable(self.hc, self.jkt.filter(code, keep)) @@ -85,19 +88,20 @@ def annotate(self, code, key_names = None): :param str key_names: Comma separated list of field names to be treated as a key - :return: :class:`.KeyTable` + :rtype: :class:`.KeyTable` """ return KeyTable(self.hc, self.jkt.annotate(code, key_names)) def join(self, right, how = 'inner'): """Join two key-tables together. Both key-tables must have identical key schemas - and non-overlapping fields in order to be joined. + and non-overlapping field names. - :param :class:`.KeyTable` right: Key-table to join + :param right: Key-table to join + :type right: :class:`.KeyTable` :param str how: Method for joining two tables together. One of "inner", "outer", "left", "right". 
- :return: :class:`.KeyTable` + :rtype: :class:`.KeyTable` """ return KeyTable(self.hc, self.jkt.join(right.jkt, how)) @@ -108,7 +112,7 @@ def aggregate(self, key_cond, agg_cond): :param str agg_cond: Named expression specifying how new fields are computed - :return: :class:`.KeyTable` + :rtype: :class:`.KeyTable` """ return KeyTable(self.hc, self.jkt.aggregate(key_cond, agg_cond)) @@ -117,7 +121,7 @@ def forall(self, code): :param str code: Boolean expression - :return: bool + :rtype: bool """ return self.jkt.forall(cond) @@ -126,6 +130,6 @@ def exists(self, code): :param str code: Boolean expression - :return: bool + :rtype: bool """ return self.jkt.exists(cond) \ No newline at end of file diff --git a/python/pyhail/utils.py b/python/pyhail/utils.py index c4a9da4256a..19a04ff3d45 100644 --- a/python/pyhail/utils.py +++ b/python/pyhail/utils.py @@ -16,7 +16,7 @@ class TextTableConfig: :param str types: Define types of fields in annotations files """ def __init__(self, noheader = False, impute = False, - comment = None, delimiter = "\t", missing = "NA", types = None): + comment = None, delimiter = "\\\\t", missing = "NA", types = None): self.noheader = noheader self.impute = impute self.comment = comment @@ -36,7 +36,7 @@ def __str__(self): return " ".join(res) - def toJavaObject(self, hc): + def _toJavaObject(self, hc): """Convert to java TextTableConfiguration object :param :class:`.HailContext` hc: Hail spark context. 
From 2980da394a0b7740a9a2fdf881776ba2b9940364 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Mon, 14 Nov 2016 12:47:39 -0500 Subject: [PATCH 37/51] done with python bindings --- python/pyhail/context.py | 7 ++++--- python/pyhail/dataset.py | 22 ++++++++++++++++------ python/pyhail/keytable.py | 32 ++++++++++++++++++++++---------- python/pyhail/utils.py | 13 +++++++++++-- 4 files changed, 53 insertions(+), 21 deletions(-) diff --git a/python/pyhail/context.py b/python/pyhail/context.py index 36bc4d234fb..d342b212a56 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -280,16 +280,17 @@ def import_gen(self, path, tolerance=0.2, sample_file=None, npartitions=None, ch def import_keytable(self, path, key_names, npartition = None, config = None): """Import tabular file as KeyTable - :param path: .tsv files to import. + :param path: files to import. :type path: str or list of str :param key_names: The name(s) of fields to be considered keys - :type key_names: str or list[str] + :type key_names: str or list of str :param npartition: Number of partitions. :type npartition: int or None - :param :class:`.TextTableConfig` config: Configuration options for importing text files + :param config: Configuration options for importing text files + :type config: :class:`.TextTableConfig` :rtype: :class:`.KeyTable` """ diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py index c8a09a82362..453d2064fe0 100644 --- a/python/pyhail/dataset.py +++ b/python/pyhail/dataset.py @@ -12,15 +12,26 @@ def __init__(self, hc, jvds): def _raise_py4j_exception(self, e): self.hc._raise_py4j_exception(e) - def aggregate_by_key(self, key_condition=None, agg_condition=None): - """Aggregate by user-defined key and aggregation expressions + def aggregate_by_key(self, key_code=None, agg_code=None): + """Aggregate by user-defined key and aggregation expressions. + Equivalent of a group-by operation in SQL. 
- :param str key_condition: Named expression for which fields are keys + Example: + >>> hc = HailContext(sc) + >>> vds = hc.import_vcf("/path/to/file.vcf") + >>> vds.aggregate_by_key("pheno = sa.pheno, gene = va.gene", "nHet = gs.filter(g => g.isHet).count(), nAlleles = gs.filter(g => g.isCalled).count() * 2") - :param str agg_condition: Named aggregation expression. + The resulting key-table will have four fields [pheno, gene, nHet, nAlleles] where pheno and gene are the keys. - :rtype: :class`.KeyTable` + :param key_code: Named expression(s) for which fields are keys. + :type key_code: str or list of str + + :param agg_code: Named aggregation expression(s). + :type agg_code: str or list of str + + :rtype: :class:`.KeyTable` """ + return KeyTable(self.hc, self.jvds.aggregateByKey(key_condition, agg_condition)) def aggregate_intervals(self, input, condition, output): @@ -57,7 +68,6 @@ def annotate_global_list(self, input, root, as_set=False): :param bool as_set: If True, load text file as Set[String], otherwise, load as Array[String]. - """ pargs = ['annotateglobal', 'list', '-i', input, '-r', root] diff --git a/python/pyhail/keytable.py b/python/pyhail/keytable.py index dbe6efdf85e..5ba23eb988d 100644 --- a/python/pyhail/keytable.py +++ b/python/pyhail/keytable.py @@ -1,5 +1,6 @@ +from pyhail.utils import Type -class KeyTable: +class KeyTable(object): """:class:`.KeyTable` is Hail's version of a SQL table where fields can be designated as keys. """ @@ -14,6 +15,9 @@ def __init__(self, hc, jkt): self.hc = hc self.jkt = jkt + def __repr__(self): + return self.jkt.toString() + def nFields(self): """Number of fields in the key-table @@ -24,21 +28,21 @@ def nFields(self): def schema(self): """Key-table schema - :rtype: ??? 
+ :rtype: :class:`.Type` """ - return self.jkt.schema() + return Type(self.jkt.signature()) def keyNames(self): """Field names that are keys - :rtype: list[str] + :rtype: list of str """ return self.jkt.keyNames() def fieldNames(self): """Names of all fields in the key-table - :rtype: list[str] + :rtype: list of str """ return self.jkt.fieldNames() @@ -105,15 +109,23 @@ def join(self, right, how = 'inner'): """ return KeyTable(self.hc, self.jkt.join(right.jkt, how)) - def aggregate(self, key_cond, agg_cond): + def _aggregate(self, key_code, agg_code): """Group by key condition and aggregate results - :param str key_cond: Named expression defining keys in the new key-table + :param key_code: Named expression(s) for which fields are keys. + :type key_code: str or list of str - :param str agg_cond: Named expression specifying how new fields are computed + :param agg_code: Named aggregation expression(s). + :type agg_code: str or list of str :rtype: :class:`.KeyTable` """ + if isinstance(key_code, list): + key_code = ", ".join([str(l) for l in list]) + + if isinstance(agg_code, list): + agg_code = ", ".join([str(l) for l in list]) + return KeyTable(self.hc, self.jkt.aggregate(key_cond, agg_cond)) def forall(self, code): @@ -123,7 +135,7 @@ def forall(self, code): :rtype: bool """ - return self.jkt.forall(cond) + return self.jkt.forall(code) def exists(self, code): """Tests whether a condition is true for any row @@ -132,4 +144,4 @@ def exists(self, code): :rtype: bool """ - return self.jkt.exists(cond) \ No newline at end of file + return self.jkt.exists(code) \ No newline at end of file diff --git a/python/pyhail/utils.py b/python/pyhail/utils.py index 19a04ff3d45..682f29593b4 100644 --- a/python/pyhail/utils.py +++ b/python/pyhail/utils.py @@ -1,6 +1,15 @@ +class Type(object): + def __init__(self, jtype): + self.jtype = jtype -class TextTableConfig: + def __repr__(self): + return self.jtype.toString() + + def __str__(self): + return self.jtype.toPrettyString(False, 
False) + +class TextTableConfig(object): """:class:`.TextTableConfig` specifies additional options for importing TSV files. :param bool noheader: File has no header and columns should be indicated by `_1, _2, ... _N' (0-indexed) @@ -16,7 +25,7 @@ class TextTableConfig: :param str types: Define types of fields in annotations files """ def __init__(self, noheader = False, impute = False, - comment = None, delimiter = "\\\\t", missing = "NA", types = None): + comment = None, delimiter = "\t", missing = "NA", types = None): self.noheader = noheader self.impute = impute self.comment = comment From 8bdcb24aa39e471d108c68ca95a987763cc34449 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Mon, 14 Nov 2016 15:25:07 -0500 Subject: [PATCH 38/51] tried aggregator with rows --- python/pyhail/__init__.py | 2 +- .../hail/keytable/KeyTable.scala | 58 ++++++++++++++++++- .../hail/methods/KeyTableSuite.scala | 26 ++++++--- 3 files changed, 76 insertions(+), 10 deletions(-) diff --git a/python/pyhail/__init__.py b/python/pyhail/__init__.py index fd18b118917..2543b1978f2 100644 --- a/python/pyhail/__init__.py +++ b/python/pyhail/__init__.py @@ -3,4 +3,4 @@ from pyhail.keytable import KeyTable from pyhail.utils import TextTableConfig -__all__ = ["HailContext", "VariantDataset", "KeyTable", "TextTableConfig"] +__all__ = ["HailContext", "VariantDataset", "KeyTable", "TextTableConfig", "Type"] diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index cd5032a6adf..97cf84b51a6 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -227,7 +227,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def join(other: KeyTable, joinType: String): KeyTable = { if (keySignature != other.keySignature) fatal(s"""Key signatures must be identical. 
- |Left signature: ${keySignature} + |Left signature: $keySignature |Right signature: ${other.keySignature}""".stripMargin) val overlappingFields = valueNames.toSet.intersect(other.valueNames.toSet) @@ -395,4 +395,60 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v KeyTable(newRDD, keySignature, valueSignature) } + + def aggregateRows(keyCond: String, aggCond: String): KeyTable = { + + val aggregationEC = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) + val st = fields.zipWithIndex.map{ case (fd, i) => (fd.name, (i, fd.`type`)) }.toMap ++ Map("rows" -> (-1, BaseAggregable(aggregationEC, TStruct(fields.map(fd => (fd.name, fd.`type`)): _*)))) + val ec = EvalContext(st) + + val (keyNameParseTypes, keyF) = + if (keyCond != null) + Parser.parseAnnotationArgs(keyCond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val (aggNameParseTypes, aggF) = + if (aggCond != null) + Parser.parseAnnotationArgs(aggCond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val keyNames = keyNameParseTypes.map(_._1.head) + val aggNames = aggNameParseTypes.map(_._1.head) + + val keySignature = TStruct(keyNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + val valueSignature = TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + + val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) + + val seqOp = (array: Array[Aggregator], b: Any) => { + KeyTable.setEvalContext(ec, b, nFields) + val row = Option(b).map(_.asInstanceOf[Row]).orNull + + println(s"keytable aggregateRow seqop ec: $ec") + println(s"keytable aggregateRow seqop ec identity: ${System.identityHashCode(ec)}") + println(s"keytable aggregateRow seqop row: $row") + + for (i <- array.indices) { + array(i).seqOp(row) + } + array + } + + val newRDD = KeyTable.toSingleRDD(rdd, nKeys, nValues).mapPartitions{ it => + it.map { a => + KeyTable.setEvalContext(ec, a, nFields) + val key = 
Annotation.fromSeq(keyF.map(_ ())) + println(s"keytable aggregateRow keymap key: $key") + (key, a) + } + }.aggregateByKey(zVals)(seqOp, combOp) + .map { case (k, agg) => + resultOp(agg) + (k, Annotation.fromSeq(aggF.map(_ ()))) + } + + KeyTable(newRDD, keySignature, valueSignature) + } } \ No newline at end of file diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index 3412c8d8d9f..e07a855992a 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -145,15 +145,8 @@ class KeyTableSuite extends SparkSuite { val keyNames = Array("field1") val kt1 = KeyTable(rdd, signature, keyNames) -// val kt2 = kt1.aggregate("Status = field1", "field4 = field2.sum(), field5 = field2.map(f => field2 + field3).sum()") - //val kt2 = kt1.aggregate("Status = field1", "field5 = field2.map(f => field2 + field3).sum()") -// val kt2 = kt1.aggregate("Status = field1", "X = field2.map(f => field2).sum(), Y = field2.sum(), Z = field2.map(f => f).sum()") -// val result = Array(Array("Case", 12.0, 12.0, 12.0), Array("Control", 3.0, 3.0, 3.0)) -// val resRDD = sc.parallelize(result.map(Annotation.fromSeq(_))) -// val resSignature = TStruct(("Status", TString), ("X", TDouble), ("Y", TDouble), ("Z", TDouble)) -// val ktResult = KeyTable(resRDD, resSignature, keyNames = Array("Status")) - val kt2 = kt1.aggregate("Status = field1", "X = field2.map(f => field2).sum()") + val result = Array(Array("Case", 12.0), Array("Control", 3.0)) val resRDD = sc.parallelize(result.map(Annotation.fromSeq(_))) val resSignature = TStruct(("Status", TString), ("X", TDouble)) @@ -163,6 +156,23 @@ class KeyTableSuite extends SparkSuite { assert(kt2 same ktResult) } + @Test def testAggregateRows() { + val data = Array(Array("Case", 9, 0), Array("Case", 3, 4), Array("Control", 2, 3), Array("Control", 1, 5)) + val rdd = 
sc.parallelize(data.map(Annotation.fromSeq(_))) + val signature = TStruct(("field1", TString), ("field2", TInt), ("field3", TInt)) + val keyNames = Array("field1") + + val kt1 = KeyTable(rdd, signature, keyNames) + val kt2 = kt1.aggregateRows("Status = field1", "X = rows.map(r => field2).sum()") + + val result = Array(Array("Case", 12.0), Array("Control", 3.0)) + val resRDD = sc.parallelize(result.map(Annotation.fromSeq(_))) + val resSignature = TStruct(("Status", TString), ("X", TDouble)) + val ktResult = KeyTable(resRDD, resSignature, keyNames = Array("Status")) + + assert(kt2 same ktResult) + } + @Test def testForallExists() { val data = Array(Array("Sample1", 9, 5), Array("Sample2", 3, 5), Array("Sample3", 2, 5), Array("Sample4", 1, 5)) val rdd = sc.parallelize(data.map(Annotation.fromSeq(_))) From d0b45bc75bfa84fc0e868b025fd89c6fce5828fb Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Mon, 14 Nov 2016 15:43:27 -0500 Subject: [PATCH 39/51] fixed formatting of arguments --- python/pyhail/context.py | 16 ++++++++-------- python/pyhail/dataset.py | 5 ----- python/pyhail/keytable.py | 10 +++++----- .../org/broadinstitute/hail/driver/Command.scala | 1 - 4 files changed, 13 insertions(+), 19 deletions(-) diff --git a/python/pyhail/context.py b/python/pyhail/context.py index d342b212a56..66823314659 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -189,7 +189,7 @@ def import_annotations_table(self, path, variant_expr, code=None, npartitions=No return self.run_command(None, pargs) - def import_bgen(self, path, tolerance = 0.2, sample_file = None, npartition = None): + def import_bgen(self, path, tolerance=0.2, sample_file=None, npartitions=None): """Import .bgen files as VariantDataset :param path: .bgen files to import. 
@@ -227,7 +227,7 @@ def import_bgen(self, path, tolerance = 0.2, sample_file = None, npartition = No pargs.append('--tolerance') pargs.append(str(tolerance)) - return self._run_command(None, pargs) + return self.run_command(None, pargs) def import_gen(self, path, tolerance=0.2, sample_file=None, npartitions=None, chromosome=None): """Import .bgen files as VariantDataset @@ -277,7 +277,7 @@ def import_gen(self, path, tolerance=0.2, sample_file=None, npartitions=None, ch return self.run_command(None, pargs) - def import_keytable(self, path, key_names, npartition = None, config = None): + def import_keytable(self, path, key_names, npartitions=None, config=None): """Import tabular file as KeyTable :param path: files to import. @@ -286,8 +286,8 @@ def import_keytable(self, path, key_names, npartition = None, config = None): :param key_names: The name(s) of fields to be considered keys :type key_names: str or list of str - :param npartition: Number of partitions. - :type npartition: int or None + :param npartitions: Number of partitions. 
+ :type npartitions: int or None :param config: Configuration options for importing text files :type config: :class:`.TextTableConfig` @@ -304,8 +304,8 @@ def import_keytable(self, path, key_names, npartition = None, config = None): if not isinstance(key_names, str): key_names = ",".join(key_names) - if not npartition: - npartition = self.sc.defaultMinPartitions + if not npartitions: + npartitions = self.sc.defaultMinPartitions if not config: config = TextTableConfig()._toJavaObject(self) @@ -313,7 +313,7 @@ def import_keytable(self, path, key_names, npartition = None, config = None): config = config._toJavaObject(self) return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable(self.jsc, jarray(self.gateway, self.jvm.java.lang.String, pathArgs), - key_names, npartition, config)) + key_names, npartitions, config)) def import_plink(self, bed, bim, fam, npartitions=None, delimiter='\\\\s+', missing="NA", quantpheno=False): """ diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py index 453d2064fe0..279ce0ea441 100644 --- a/python/pyhail/dataset.py +++ b/python/pyhail/dataset.py @@ -16,11 +16,6 @@ def aggregate_by_key(self, key_code=None, agg_code=None): """Aggregate by user-defined key and aggregation expressions. Equivalent of a group-by operation in SQL. - Example: - >>> hc = HailContext(sc) - >>> vds = hc.import_vcf("/path/to/file.vcf") - >>> vds.aggregate_by_key("pheno = sa.pheno, gene = va.gene", "nHet = gs.filter(g => g.isHet).count(), nAlleles = gs.filter(g => g.isCalled).count() * 2") - The resulting key-table will have four fields [pheno, gene, nHet, nAlleles] where pheno and gene are the keys. :param key_code: Named expression(s) for which fields are keys. 
diff --git a/python/pyhail/keytable.py b/python/pyhail/keytable.py index 5ba23eb988d..e0657bc4d47 100644 --- a/python/pyhail/keytable.py +++ b/python/pyhail/keytable.py @@ -63,7 +63,7 @@ def same(self, other): """ return self.jkt.same(other.jkt) - def export(self, output, types_file = None): + def export(self, output, types_file=None): """Export key-table to a TSV file. :param str output: Output file path @@ -74,7 +74,7 @@ def export(self, output, types_file = None): """ self.jkt.export(self.hc.jsc, output, types_file) - def filter(self, code, keep = True): + def filter(self, code, keep=True): """Filter rows from key-table. :param str code: Annotation expression. @@ -85,7 +85,7 @@ def filter(self, code, keep = True): """ return KeyTable(self.hc, self.jkt.filter(code, keep)) - def annotate(self, code, key_names = None): + def annotate(self, code, key_names=None): """Add fields to key-table. :param str code: Annotation expression. @@ -96,7 +96,7 @@ def annotate(self, code, key_names = None): """ return KeyTable(self.hc, self.jkt.annotate(code, key_names)) - def join(self, right, how = 'inner'): + def join(self, right, how='inner'): """Join two key-tables together. Both key-tables must have identical key schemas and non-overlapping field names. 
@@ -126,7 +126,7 @@ def _aggregate(self, key_code, agg_code): if isinstance(agg_code, list): agg_code = ", ".join([str(l) for l in list]) - return KeyTable(self.hc, self.jkt.aggregate(key_cond, agg_cond)) + return KeyTable(self.hc, self.jkt.aggregate(key_code, agg_code)) def forall(self, code): """Tests whether a condition is true for all rows diff --git a/src/main/scala/org/broadinstitute/hail/driver/Command.scala b/src/main/scala/org/broadinstitute/hail/driver/Command.scala index c46c690bb66..e40ef4554d5 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/Command.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/Command.scala @@ -2,7 +2,6 @@ package org.broadinstitute.hail.driver import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.broadinstitute.hail.keytable.KeyTable import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.variant.VariantDataset import org.kohsuke.args4j.{Argument, CmdLineException, CmdLineParser, Option => Args4jOption} From 2927c5aed9960260f0e79f2ccdb5ffe8d1d60ac3 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Mon, 14 Nov 2016 15:49:02 -0500 Subject: [PATCH 40/51] reformatted code --- .../org/broadinstitute/hail/expr/AST.scala | 44 +++++++---------- .../hail/keytable/KeyTable.scala | 48 ++++++++++--------- .../hail/methods/KeyTableSuite.scala | 4 +- 3 files changed, 45 insertions(+), 51 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/expr/AST.scala b/src/main/scala/org/broadinstitute/hail/expr/AST.scala index 7a7312b54fa..c0e60af126f 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/AST.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/AST.scala @@ -20,15 +20,7 @@ import org.broadinstitute.hail.utils.EitherIsAMonad._ case class EvalContext(st: SymbolTable, a: ArrayBuffer[Any], aggregationFunctions: ArrayBuffer[Aggregator]) { def setAll(args: Any*) { - try { - args.zipWithIndex.foreach { case (arg, i) => -// println(s"$arg, $i a=$a") - a(i) = 
arg -// println(s"$arg, $i a=$a st=$st") - } - } catch { - case _: IndexOutOfBoundsException => println("error") - } + args.zipWithIndex.foreach { case (arg, i) => a(i) = arg } } def set(index: Int, arg: Any) { @@ -669,22 +661,22 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST body.typecheck(agg.ec.copy(st = st)) `type` = body.`type` match { case t: Type => - println(s"maptypecheck bodyType: ${body.getClass}") - println(s"maptypecheck param: ${param}") + println(s"maptypecheck bodyType: ${ body.getClass }") + println(s"maptypecheck param: ${ param }") println(s"maptypecheck localIdx: $localIdx") - println(s"maptypecheck elementtype: ${agg.elementType}") - println(s"maptypecheck localA: ${localA}") - println(s"maptypecheck localA identityCode: ${System.identityHashCode(localA)}") - println(s"maptypecheck ec: ${agg.ec}") - println(s"maptypecheck ec identityCode: ${System.identityHashCode(agg.ec)}") + println(s"maptypecheck elementtype: ${ agg.elementType }") + println(s"maptypecheck localA: ${ localA }") + println(s"maptypecheck localA identityCode: ${ System.identityHashCode(localA) }") + println(s"maptypecheck ec: ${ agg.ec }") + println(s"maptypecheck ec identityCode: ${ System.identityHashCode(agg.ec) }") val fn = body.eval(agg.ec.copy(st = st)) val mapF = (a: Any) => { localA(localIdx) = a - println(s"mapF ec: ${agg.ec}") - println(s"mapF LocalA: ${localA}") - println(s"mapF ec identity code: ${System.identityHashCode(agg.ec)}") - println(s"mapF LocalA identity code: ${System.identityHashCode(localA)}") - println(s"mapF result eval fn(): ${fn()}") + println(s"mapF ec: ${ agg.ec }") + println(s"mapF LocalA: ${ localA }") + println(s"mapF ec identity code: ${ System.identityHashCode(agg.ec) }") + println(s"mapF LocalA identity code: ${ System.identityHashCode(localA) }") + println(s"mapF result eval fn(): ${ fn() }") fn() } @@ -1722,11 +1714,11 @@ case class SymRef(posn: Position, symbol: String) extends AST(posn) { 
println(s"symref symbol: $symbol") val localI = ec.st(symbol)._1 val localA = ec.a - println(s"symref ec: ${ec}") - println(s"symref localI: ${ec.st(symbol)._1}") - println(s"symref localA: ${ec.a}") - println(s"symref ec identityhash: ${System.identityHashCode(ec)}") - println(s"symref ec.a identityhash: ${System.identityHashCode(ec.a)}") + println(s"symref ec: ${ ec }") + println(s"symref localI: ${ ec.st(symbol)._1 }") + println(s"symref localA: ${ ec.a }") + println(s"symref ec identityhash: ${ System.identityHashCode(ec) }") + println(s"symref ec.a identityhash: ${ System.identityHashCode(ec.a) }") if (localI < 0) () => 0 // FIXME placeholder else diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 97cf84b51a6..a927dc660d9 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -226,14 +226,16 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def join(other: KeyTable, joinType: String): KeyTable = { if (keySignature != other.keySignature) - fatal(s"""Key signatures must be identical. - |Left signature: $keySignature - |Right signature: ${other.keySignature}""".stripMargin) + fatal( + s"""Key signatures must be identical. + |Left signature: $keySignature + |Right signature: ${ other.keySignature }""".stripMargin) val overlappingFields = valueNames.toSet.intersect(other.valueNames.toSet) if (overlappingFields.nonEmpty) - fatal(s"""Fields that are not keys cannot be present in both key-tables. - |Overlapping fields: ${overlappingFields.mkString(", ")}""".stripMargin) + fatal( + s"""Fields that are not keys cannot be present in both key-tables. 
+ |Overlapping fields: ${ overlappingFields.mkString(", ") }""".stripMargin) joinType match { case "left" => leftJoin(other) @@ -341,7 +343,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def aggregate(keyCond: String, aggCond: String): KeyTable = { val aggregationEC = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) - val ec = EvalContext(fields.zipWithIndex.map{ case (fd, i) => (fd.name, (-1, KeyTableAggregable(aggregationEC, fd.`type`, i)))}.toMap) + val ec = EvalContext(fields.zipWithIndex.map { case (fd, i) => (fd.name, (-1, KeyTableAggregable(aggregationEC, fd.`type`, i))) }.toMap) val (keyNameParseTypes, keyF) = if (keyCond != null) @@ -358,30 +360,30 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val keyNames = keyNameParseTypes.map(_._1.head) val aggNames = aggNameParseTypes.map(_._1.head) - val keySignature = TStruct(keyNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) - val valueSignature = TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + val keySignature = TStruct(keyNameParseTypes.map { case (n, t) => (n.head, t) }: _*) + val valueSignature = TStruct(aggNameParseTypes.map { case (n, t) => (n.head, t) }: _*) val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) val seqOp = (array: Array[Aggregator], b: Any) => { println(s"values inside b = " + KeyTable.annotationToSeq(b, nFields)) - println(s"keytable seqop pre-setec ec.a: ${aggregationEC.a}") - println(s"keytable seqop pre-setec ec: ${aggregationEC}") - println(s"keytable seqop pre-setec ec.a pointer: ${System.identityHashCode(aggregationEC.a)}") - println(s"keytable seqop pre-setec ec pointer: ${System.identityHashCode(aggregationEC)}") + println(s"keytable seqop pre-setec ec.a: ${ aggregationEC.a }") + println(s"keytable seqop pre-setec ec: ${ aggregationEC }") + println(s"keytable seqop pre-setec ec.a pointer: ${ System.identityHashCode(aggregationEC.a) }") + 
println(s"keytable seqop pre-setec ec pointer: ${ System.identityHashCode(aggregationEC) }") KeyTable.setEvalContext(aggregationEC, b, nFields) - println(s"keytable seqop post-setec ec.a: ${aggregationEC.a}") - println(s"keytable seqop post-setec ec: ${aggregationEC}") - println(s"keytable seqop post-setec ec.a pointer: ${System.identityHashCode(aggregationEC.a)}") - println(s"keytable seqop post-setec ec pointer: ${System.identityHashCode(aggregationEC)}") + println(s"keytable seqop post-setec ec.a: ${ aggregationEC.a }") + println(s"keytable seqop post-setec ec: ${ aggregationEC }") + println(s"keytable seqop post-setec ec.a pointer: ${ System.identityHashCode(aggregationEC.a) }") + println(s"keytable seqop post-setec ec pointer: ${ System.identityHashCode(aggregationEC) }") for (i <- array.indices) { - println(s"keytable seqop array($i): ${array(i)}") + println(s"keytable seqop array($i): ${ array(i) }") array(i).seqOp(b) } array } - val newRDD = KeyTable.toSingleRDD(rdd, nKeys, nValues).mapPartitions{ it => + val newRDD = KeyTable.toSingleRDD(rdd, nKeys, nValues).mapPartitions { it => it.map { a => KeyTable.setEvalContext(aggregationEC, a, nFields) val key = Annotation.fromSeq(keyF.map(_ ())) @@ -399,7 +401,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def aggregateRows(keyCond: String, aggCond: String): KeyTable = { val aggregationEC = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) - val st = fields.zipWithIndex.map{ case (fd, i) => (fd.name, (i, fd.`type`)) }.toMap ++ Map("rows" -> (-1, BaseAggregable(aggregationEC, TStruct(fields.map(fd => (fd.name, fd.`type`)): _*)))) + val st = fields.zipWithIndex.map { case (fd, i) => (fd.name, (i, fd.`type`)) }.toMap ++ Map("rows" -> (-1, BaseAggregable(aggregationEC, TStruct(fields.map(fd => (fd.name, fd.`type`)): _*)))) val ec = EvalContext(st) val (keyNameParseTypes, keyF) = @@ -417,8 +419,8 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: 
TStruct, v val keyNames = keyNameParseTypes.map(_._1.head) val aggNames = aggNameParseTypes.map(_._1.head) - val keySignature = TStruct(keyNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) - val valueSignature = TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + val keySignature = TStruct(keyNameParseTypes.map { case (n, t) => (n.head, t) }: _*) + val valueSignature = TStruct(aggNameParseTypes.map { case (n, t) => (n.head, t) }: _*) val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) @@ -427,7 +429,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val row = Option(b).map(_.asInstanceOf[Row]).orNull println(s"keytable aggregateRow seqop ec: $ec") - println(s"keytable aggregateRow seqop ec identity: ${System.identityHashCode(ec)}") + println(s"keytable aggregateRow seqop ec identity: ${ System.identityHashCode(ec) }") println(s"keytable aggregateRow seqop row: $row") for (i <- array.indices) { @@ -436,7 +438,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v array } - val newRDD = KeyTable.toSingleRDD(rdd, nKeys, nValues).mapPartitions{ it => + val newRDD = KeyTable.toSingleRDD(rdd, nKeys, nValues).mapPartitions { it => it.map { a => KeyTable.setEvalContext(ec, a, nFields) val key = Annotation.fromSeq(keyF.map(_ ())) diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index e07a855992a..6347e3e0c0d 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -117,7 +117,7 @@ class KeyTableSuite extends SparkSuite { assert(ktLeftJoin.nRows == ktLeft.nRows && ktLeftJoin.nKeys == nExpectedKeys && ktLeftJoin.nValues == nExpectedValues && - ktLeftJoin.filter{ case (k, v) => + ktLeftJoin.filter { case (k, v) => !rightKeys.contains(leftJoinKeyQuery(k, 
v).map(_.asInstanceOf[String])) }.forall("isMissing(qPhen2) && isMissing(qPhen3)") ) @@ -125,7 +125,7 @@ class KeyTableSuite extends SparkSuite { assert(ktRightJoin.nRows == ktRight.nRows && ktRightJoin.nKeys == nExpectedKeys && ktRightJoin.nValues == nExpectedValues && - ktRightJoin.filter{ case (k, v) => + ktRightJoin.filter { case (k, v) => !leftKeys.contains(rightJoinKeyQuery(k, v).map(_.asInstanceOf[String])) }.forall("isMissing(Status) && isMissing(qPhen)")) From 026631ab10dea4a6fc2d0a5eccbcedde2894905d Mon Sep 17 00:00:00 2001 From: Cotton Seed Date: Wed, 16 Nov 2016 17:46:23 -0500 Subject: [PATCH 41/51] Fixed aggregation serialization sharing bug. Now both aggregate and aggregateRows KT tests pass. --- .../org/broadinstitute/hail/expr/AST.scala | 28 ++-- .../hail/keytable/KeyTable.scala | 26 +-- .../hail/methods/Aggregators.scala | 150 +++++++++--------- 3 files changed, 93 insertions(+), 111 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/expr/AST.scala b/src/main/scala/org/broadinstitute/hail/expr/AST.scala index c0e60af126f..3677371bd1e 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/AST.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/AST.scala @@ -17,7 +17,9 @@ import scala.language.existentials import scala.reflect.ClassTag import org.broadinstitute.hail.utils.EitherIsAMonad._ -case class EvalContext(st: SymbolTable, a: ArrayBuffer[Any], aggregationFunctions: ArrayBuffer[Aggregator]) { +case class EvalContext(st: SymbolTable, + a: ArrayBuffer[Any], + aggregationFunctions: ArrayBuffer[((Any) => Any, Aggregator)]) { def setAll(args: Any*) { args.zipWithIndex.foreach { case (arg, i) => a(i) = arg } @@ -31,7 +33,7 @@ case class EvalContext(st: SymbolTable, a: ArrayBuffer[Any], aggregationFunction object EvalContext { def apply(symTab: SymbolTable): EvalContext = { val a = new ArrayBuffer[Any]() - val af = new ArrayBuffer[Aggregator]() + val af = new ArrayBuffer[((Any) => Any, Aggregator)]() for ((i, t) <- symTab.values) { 
if (i >= 0) a += null @@ -1157,7 +1159,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f - agg.ec.aggregationFunctions += new CountAggregator(aggF, localIdx) + agg.ec.aggregationFunctions += ((aggF, new CountAggregator(localIdx))) () => localA(localIdx) case (agg: TAggregable, "fraction", Array(Lambda(_, param, body))) => @@ -1171,7 +1173,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f - agg.ec.aggregationFunctions += new FractionAggregator(aggF, localIdx, localA, bodyFn, lambdaIdx) + agg.ec.aggregationFunctions += ((aggF, new FractionAggregator(localIdx, localA, bodyFn, lambdaIdx))) () => localA(localIdx) case (agg: TAggregable, "stats", Array()) => @@ -1182,7 +1184,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val t = agg.elementType val aggF = agg.f - agg.ec.aggregationFunctions += new StatAggregator(aggF, localIdx) + agg.ec.aggregationFunctions += ((aggF, new StatAggregator(localIdx))) val getOp = (a: Any) => { val sc = a.asInstanceOf[StatCounter] @@ -1225,7 +1227,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val vf = vAST.eval(ec) - agg.ec.aggregationFunctions += new CallStatsAggregator(aggF, localIdx, vf) + agg.ec.aggregationFunctions += ((aggF, new CallStatsAggregator(localIdx, vf))) () => { val cs = localA(localIdx).asInstanceOf[CallStats] @@ -1279,7 +1281,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f - agg.ec.aggregationFunctions += new HistAggregator(aggF, localIdx, indices) + agg.ec.aggregationFunctions += ((aggF, new HistAggregator(localIdx, indices))) () => localA(localIdx).asInstanceOf[HistogramCombiner].toAnnotation @@ -1290,7 +1292,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f - agg.ec.aggregationFunctions += new CollectAggregator(aggF, localIdx) + 
agg.ec.aggregationFunctions += ((aggF, new CollectAggregator(localIdx))) () => localA(localIdx).asInstanceOf[ArrayBuffer[Any]].toIndexedSeq case (agg: TAggregable, "infoScore", Array()) => @@ -1301,7 +1303,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val localPos = posn val aggF = agg.f - agg.ec.aggregationFunctions += new InfoScoreAggregator(aggF, localIdx) + agg.ec.aggregationFunctions += ((aggF, new InfoScoreAggregator(localIdx))) () => localA(localIdx).asInstanceOf[InfoScoreCombiner].asAnnotation case (agg: TAggregable, "inbreeding", Array(mafAST)) => @@ -1313,7 +1315,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f val maf = mafAST.eval(agg.ec) - agg.ec.aggregationFunctions += new InbreedingAggregator(aggF, localIdx, maf) + agg.ec.aggregationFunctions += ((aggF, new InbreedingAggregator(localIdx, maf))) () => localA(localIdx).asInstanceOf[InbreedingCombiner].asAnnotation case (agg: TAggregable, "hardyWeinberg", Array()) => @@ -1324,7 +1326,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val localPos = posn val aggF = agg.f - agg.ec.aggregationFunctions += new HWEAggregator(aggF, localIdx) + agg.ec.aggregationFunctions += ((aggF, new HWEAggregator(localIdx))) () => localA(localIdx).asInstanceOf[HWECombiner].asAnnotation case (agg: TAggregable, "sum", Array()) => @@ -1336,8 +1338,8 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f (`type`: @unchecked) match { - case TDouble => agg.ec.aggregationFunctions += new SumAggregator(aggF, localIdx) - case TArray(TDouble) => agg.ec.aggregationFunctions += new SumArrayAggregator(aggF, localIdx, localPos) + case TDouble => agg.ec.aggregationFunctions += ((aggF, new SumAggregator(localIdx))) + case TArray(TDouble) => agg.ec.aggregationFunctions += ((aggF, new SumArrayAggregator(localIdx, localPos))) } () => localA(localIdx) diff --git 
a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index a927dc660d9..82eb5ae4227 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -364,21 +364,14 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val valueSignature = TStruct(aggNameParseTypes.map { case (n, t) => (n.head, t) }: _*) val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) + val aggFunctions = aggregationEC.aggregationFunctions.map(_._1) + + assert(zVals.length == aggFunctions.length) val seqOp = (array: Array[Aggregator], b: Any) => { - println(s"values inside b = " + KeyTable.annotationToSeq(b, nFields)) - println(s"keytable seqop pre-setec ec.a: ${ aggregationEC.a }") - println(s"keytable seqop pre-setec ec: ${ aggregationEC }") - println(s"keytable seqop pre-setec ec.a pointer: ${ System.identityHashCode(aggregationEC.a) }") - println(s"keytable seqop pre-setec ec pointer: ${ System.identityHashCode(aggregationEC) }") KeyTable.setEvalContext(aggregationEC, b, nFields) - println(s"keytable seqop post-setec ec.a: ${ aggregationEC.a }") - println(s"keytable seqop post-setec ec: ${ aggregationEC }") - println(s"keytable seqop post-setec ec.a pointer: ${ System.identityHashCode(aggregationEC.a) }") - println(s"keytable seqop post-setec ec pointer: ${ System.identityHashCode(aggregationEC) }") for (i <- array.indices) { - println(s"keytable seqop array($i): ${ array(i) }") - array(i).seqOp(b) + array(i).seqOp(aggFunctions(i)(b)) } array } @@ -423,17 +416,12 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val valueSignature = TStruct(aggNameParseTypes.map { case (n, t) => (n.head, t) }: _*) val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) + val aggFunctions = aggregationEC.aggregationFunctions.map(_._1) val seqOp 
= (array: Array[Aggregator], b: Any) => { - KeyTable.setEvalContext(ec, b, nFields) - val row = Option(b).map(_.asInstanceOf[Row]).orNull - - println(s"keytable aggregateRow seqop ec: $ec") - println(s"keytable aggregateRow seqop ec identity: ${ System.identityHashCode(ec) }") - println(s"keytable aggregateRow seqop row: $row") - + KeyTable.setEvalContext(aggregationEC, b, nFields) for (i <- array.indices) { - array(i).seqOp(row) + array(i).seqOp(aggFunctions(i)(b)) } array } diff --git a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala index 1e57876a3ad..19a241a01de 100644 --- a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala +++ b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala @@ -15,7 +15,10 @@ import scala.util.parsing.input.Position object Aggregators { def buildVariantAggregations(vds: VariantDataset, ec: EvalContext): Option[(Variant, Annotation, Iterable[Genotype]) => Unit] = { - val aggregators = ec.aggregationFunctions.toArray + + val aggFunctions = ec.aggregationFunctions.map(_._1) + val aggregators = ec.aggregationFunctions.map(_._2) + val aggregatorA = ec.a if (aggregators.nonEmpty) { @@ -28,29 +31,29 @@ object Aggregators { aggregatorA(0) = v aggregatorA(1) = va (gs, localSamplesBc.value, localAnnotationsBc.value).zipped - .foreach { - case (g, s, sa) => - aggregatorA(2) = s - aggregatorA(3) = sa - baseArray.foreach { - _.seqOp(g) - } + .foreach { case (g, s, sa) => + aggregatorA(2) = s + aggregatorA(3) = sa + (baseArray, aggFunctions).zipped.foreach { case (agg, aggF) => + agg.seqOp(aggF(g)) + } } baseArray.foreach { agg => aggregatorA(agg.idx) = agg.result } } Some(f) - } else None + } else + None } def buildSampleAggregations(vds: VariantDataset, ec: EvalContext): Option[(String) => Unit] = { - val aggregators = ec.aggregationFunctions.toArray + val aggFunctions = ec.aggregationFunctions.map(_._1) + val aggregators = 
ec.aggregationFunctions.map(_._2) val aggregatorA = ec.a if (aggregators.isEmpty) None else { - val localSamplesBc = vds.sampleIdsBc val localAnnotationsBc = vds.sampleAnnotationsBc @@ -73,7 +76,7 @@ object Aggregators { var j = 0 while (j < nAggregations) { - arr(i, j).seqOp(g) + arr(i, j).seqOp(aggFunctions(j)(g)) j += 1 } i += 1 @@ -100,7 +103,8 @@ object Aggregators { def makeFunctions(ec: EvalContext): (Array[Aggregator], (Array[Aggregator], (Any, Any)) => Array[Aggregator], (Array[Aggregator], Array[Aggregator]) => Array[Aggregator], (Array[Aggregator]) => Unit) = { - val aggregators = ec.aggregationFunctions.toArray + val aggFunctions = ec.aggregationFunctions.map(_._1) + val aggregators = ec.aggregationFunctions.map(_._2) val arr = ec.a @@ -116,7 +120,7 @@ object Aggregators { val (aggT, annotation) = b ec.set(0, annotation) for (i <- array.indices) { - array(i).seqOp(aggT) + array(i).seqOp(aggFunctions(i)(aggT)) } array } @@ -135,15 +139,14 @@ object Aggregators { } } -class CountAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[Long] { +class CountAggregator(val idx: Int) extends TypedAggregator[Long] { var _state = 0L def result = _state - def seqOp(x: Any) { - val v = f(x) - if (f(x) != null) + override def seqOp(x: Any) { + if (x != null) _state += 1 } @@ -151,10 +154,10 @@ class CountAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[Lon _state += agg2._state } - def copy() = new CountAggregator(f, idx) + override def copy() = new CountAggregator(idx) } -class FractionAggregator(f: (Any) => Any, val idx: Int, localA: ArrayBuffer[Any], bodyFn: () => Any, lambdaIdx: Int) +class FractionAggregator(val idx: Int, localA: ArrayBuffer[Any], bodyFn: () => Any, lambdaIdx: Int) extends TypedAggregator[java.lang.Double] { var _num = 0L @@ -166,11 +169,10 @@ class FractionAggregator(f: (Any) => Any, val idx: Int, localA: ArrayBuffer[Any] else _num.toDouble / _denom - def seqOp(x: Any) { - val r = f(x) - if (r != null) { + override 
def seqOp(x: Any) { + if (x != null) { _denom += 1 - localA(lambdaIdx) = r + localA(lambdaIdx) = x if (bodyFn().asInstanceOf[Boolean]) _num += 1 } @@ -181,29 +183,28 @@ class FractionAggregator(f: (Any) => Any, val idx: Int, localA: ArrayBuffer[Any] _denom += agg2._denom } - def copy() = new FractionAggregator(f, idx, localA, bodyFn, lambdaIdx) + override def copy() = new FractionAggregator(idx, localA, bodyFn, lambdaIdx) } -class StatAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[StatCounter] { +class StatAggregator(val idx: Int) extends TypedAggregator[StatCounter] { var _state = new StatCounter() def result = _state - def seqOp(x: Any) { - val r = f(x) - if (r != null) - _state.merge(DoubleNumericConversion.to(r)) + override def seqOp(x: Any) { + if (x != null) + _state.merge(DoubleNumericConversion.to(x)) } def combOp(agg2: this.type) { _state.merge(agg2._state) } - def copy() = new StatAggregator(f, idx) + override def copy() = new StatAggregator(idx) } -class CounterAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[mutable.HashMap[Any, Long]] { +class CounterAggregator(val idx: Int) extends TypedAggregator[mutable.HashMap[Any, Long]] { var m = new mutable.HashMap[Any, Long] def result = m @@ -220,111 +221,104 @@ class CounterAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[m } } - def copy() = new CounterAggregator(f, idx) + override def copy() = new StatAggregator(idx) } -class HistAggregator(f: (Any) => Any, val idx: Int, indices: Array[Double]) +class HistAggregator(val idx: Int, indices: Array[Double]) extends TypedAggregator[HistogramCombiner] { var _state = new HistogramCombiner(indices) def result = _state - def seqOp(x: Any) { - val r = f(x) - if (r != null) - _state.merge(DoubleNumericConversion.to(r)) + override def seqOp(x: Any) { + if (x != null) + _state.merge(DoubleNumericConversion.to(x)) } def combOp(agg2: this.type) { _state.merge(agg2._state) } - def copy() = new HistAggregator(f, idx, 
indices) + override def copy() = new HistAggregator(idx, indices) } -class CollectAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[ArrayBuffer[Any]] { +class CollectAggregator(val idx: Int) extends TypedAggregator[ArrayBuffer[Any]] { var _state = new ArrayBuffer[Any] def result = _state - def seqOp(x: Any) { - val r = f(x) - if (r != null) - _state += f(x) + override def seqOp(x: Any) { + if (x != null) + _state += x } def combOp(agg2: this.type) = _state ++= agg2._state - def copy() = new CollectAggregator(f, idx) + override def copy() = new CollectAggregator(idx) } -class InfoScoreAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[InfoScoreCombiner] { +class InfoScoreAggregator(val idx: Int) extends TypedAggregator[InfoScoreCombiner] { var _state = new InfoScoreCombiner() def result = _state - def seqOp(x: Any) { - val r = f(x) - if (r != null) - _state.merge(r.asInstanceOf[Genotype]) + override def seqOp(x: Any) { + if (x != null) + _state.merge(x.asInstanceOf[Genotype]) } def combOp(agg2: this.type) { _state.merge(agg2._state) } - def copy() = new InfoScoreAggregator(f, idx) + override def copy() = new InfoScoreAggregator(idx) } -class HWEAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[HWECombiner] { +class HWEAggregator(val idx: Int) extends TypedAggregator[HWECombiner] { var _state = new HWECombiner() def result = _state - def seqOp(x: Any) { - val r = f(x) - if (r != null) - _state.merge(r.asInstanceOf[Genotype]) + override def seqOp(x: Any) { + if (x != null) + _state.merge(x.asInstanceOf[Genotype]) } def combOp(agg2: this.type) { _state.merge(agg2._state) } - def copy() = new HWEAggregator(f, idx) + override def copy() = new HWEAggregator(idx) } -class SumAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[Double] { +class SumAggregator(val idx: Int) extends TypedAggregator[Double] { var _state = 0d def result = _state override def seqOp(x: Any) { - println(s"sumagg seqop input: $x") - val r = 
f(x) - println(s"sumagg seqop result: $r") - if (r != null) - _state += DoubleNumericConversion.to(r) + if (x != null) + _state += DoubleNumericConversion.to(x) } def combOp(agg2: this.type) = _state += agg2._state - def copy() = new SumAggregator(f, idx) + override def copy() = new SumAggregator(idx) } -class SumArrayAggregator(f: (Any) => Any, val idx: Int, localPos: Position) +class SumArrayAggregator(val idx: Int, localPos: Position) extends TypedAggregator[IndexedSeq[Double]] { var _state: Array[Double] = _ def result = _state - def seqOp(x: Any) { - val r = f(x).asInstanceOf[IndexedSeq[Any]] + override def seqOp(x: Any) { + val r = x.asInstanceOf[IndexedSeq[Any]] if (r != null) { if (_state == null) _state = r.map(x => if (x == null) 0d else DoubleNumericConversion.to(x)).toArray @@ -355,10 +349,10 @@ class SumArrayAggregator(f: (Any) => Any, val idx: Int, localPos: Position) _state(i) += agg2state(i) } - def copy() = new SumArrayAggregator(f, idx, localPos) + override def copy() = new SumArrayAggregator(idx, localPos) } -class CallStatsAggregator(f: (Any) => Any, val idx: Int, variantF: () => Any) +class CallStatsAggregator(val idx: Int, variantF: () => Any) extends TypedAggregator[CallStats] { var first = true @@ -380,9 +374,8 @@ class CallStatsAggregator(f: (Any) => Any, val idx: Int, variantF: () => Any) } if (combiner != null) { - val r = f(x) - if (r != null) - combiner.merge(r.asInstanceOf[Genotype]) + if (x != null) + combiner.merge(x.asInstanceOf[Genotype]) } } @@ -394,26 +387,25 @@ class CallStatsAggregator(f: (Any) => Any, val idx: Int, variantF: () => Any) combiner.merge(agg2.combiner) } - def copy(): TypedAggregator[CallStats] = new CallStatsAggregator(f, idx, variantF) + def copy(): TypedAggregator[CallStats] = new CallStatsAggregator(idx, variantF) } -class InbreedingAggregator(f: (Any) => Any, localIdx: Int, getAF: () => Any) extends TypedAggregator[InbreedingCombiner] { +class InbreedingAggregator(localIdx: Int, getAF: () => Any) extends 
TypedAggregator[InbreedingCombiner] { var _state = new InbreedingCombiner() def result = _state - def seqOp(x: Any) = { - val r = f(x) + override def seqOp(x: Any) = { val af = getAF() - if (r != null && af != null) - _state.merge(r.asInstanceOf[Genotype], DoubleNumericConversion.to(af)) + if (x != null && af != null) + _state.merge(x.asInstanceOf[Genotype], DoubleNumericConversion.to(af)) } def combOp(agg2: this.type) = _state.merge(agg2.asInstanceOf[InbreedingAggregator]._state) - def copy() = new InbreedingAggregator(f, localIdx, getAF) + override def copy() = new InbreedingAggregator(localIdx, getAF) def idx = localIdx } From 2b3b6d2247f8ee6bc34abaf528691ac6ca7f56c1 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Thu, 17 Nov 2016 14:10:22 -0500 Subject: [PATCH 42/51] removed debugging statements --- .../org/broadinstitute/hail/expr/AST.scala | 21 +---- .../org/broadinstitute/hail/expr/Type.scala | 3 - .../hail/keytable/KeyTable.scala | 84 +++++-------------- .../org/broadinstitute/hail/SparkSuite.scala | 2 +- .../hail/methods/KeyTableSuite.scala | 30 ++----- 5 files changed, 33 insertions(+), 107 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/expr/AST.scala b/src/main/scala/org/broadinstitute/hail/expr/AST.scala index 3677371bd1e..4b237aa4d5e 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/AST.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/AST.scala @@ -663,22 +663,9 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST body.typecheck(agg.ec.copy(st = st)) `type` = body.`type` match { case t: Type => - println(s"maptypecheck bodyType: ${ body.getClass }") - println(s"maptypecheck param: ${ param }") - println(s"maptypecheck localIdx: $localIdx") - println(s"maptypecheck elementtype: ${ agg.elementType }") - println(s"maptypecheck localA: ${ localA }") - println(s"maptypecheck localA identityCode: ${ System.identityHashCode(localA) }") - println(s"maptypecheck ec: ${ agg.ec }") - 
println(s"maptypecheck ec identityCode: ${ System.identityHashCode(agg.ec) }") val fn = body.eval(agg.ec.copy(st = st)) val mapF = (a: Any) => { localA(localIdx) = a - println(s"mapF ec: ${ agg.ec }") - println(s"mapF LocalA: ${ localA }") - println(s"mapF ec identity code: ${ System.identityHashCode(agg.ec) }") - println(s"mapF LocalA identity code: ${ System.identityHashCode(localA) }") - println(s"mapF result eval fn(): ${ fn() }") fn() } @@ -1712,15 +1699,9 @@ case class SliceArray(posn: Position, f: AST, idx1: Option[AST], idx2: Option[AS case class SymRef(posn: Position, symbol: String) extends AST(posn) { def eval(ec: EvalContext): () => Any = { - println(s"symref posn: $posn") - println(s"symref symbol: $symbol") val localI = ec.st(symbol)._1 val localA = ec.a - println(s"symref ec: ${ ec }") - println(s"symref localI: ${ ec.st(symbol)._1 }") - println(s"symref localA: ${ ec.a }") - println(s"symref ec identityhash: ${ System.identityHashCode(ec) }") - println(s"symref ec.a identityhash: ${ System.identityHashCode(ec.a) }") + if (localI < 0) () => 0 // FIXME placeholder else diff --git a/src/main/scala/org/broadinstitute/hail/expr/Type.scala b/src/main/scala/org/broadinstitute/hail/expr/Type.scala index 6b2a73a2941..67b2032345f 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Type.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Type.scala @@ -283,11 +283,8 @@ case class MappedAggregable(parent: TAggregable, elementType: Type, mapF: (Any) def f: (Any) => Any = { val parentF = parent.f (a: Any) => { - println(s"mapagg a: $a") val prev = parentF(a) - println(s"mapagg prev: $prev") if (prev != null) { - println(s"mapagg result: ${mapF(prev)}") mapF(prev) } else diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 82eb5ae4227..5dba21859c2 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ 
b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -148,12 +148,14 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def query(code: String): (BaseType, (Annotation, Annotation) => Option[Any]) = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + val nKeysLocal = nKeys + val nValuesLocal = nValues val (t, f) = Parser.parse(code, ec) val f2: (Annotation, Annotation) => Option[Any] = { case (k, v) => - KeyTable.setEvalContext(ec, k, v, nKeys, nValues) + KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) f() } @@ -162,11 +164,12 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def querySingle(code: String): (BaseType, Querier) = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + val nFieldsLocal = nFields val (t, f) = Parser.parse(code, ec) val f2: (Annotation) => Option[Any] = { a => - KeyTable.setEvalContext(ec, a, nFields) + KeyTable.setEvalContext(ec, a, nFieldsLocal) f() } @@ -194,10 +197,10 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val keyNameArray = if (keyNameString != null) Parser.parseIdentifierList(keyNameString) else keyNames - // val nFields = nFields + val nFieldsLocal = nFields val f: Annotation => Annotation = { a => - KeyTable.setEvalContext(ec, a, nFields) + KeyTable.setEvalContext(ec, a, nFieldsLocal) fns.zip(inserters) .foldLeft(a) { case (a1, (fn, inserter)) => @@ -213,11 +216,13 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def filter(cond: String, keep: Boolean): KeyTable = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + val nKeysLocal = nKeys + val nValuesLocal = nValues val f: () => Option[Boolean] = Parser.parse[Boolean](cond, ec, TBoolean) val p = (k: Annotation, v: Annotation) => { - KeyTable.setEvalContext(ec, k, v, nKeys, nValues) + KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) Filter.keepThis(f(), keep) } 
@@ -284,11 +289,13 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def forall(code: String): Boolean = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + val nKeysLocal = nKeys + val nValuesLocal = nValues val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean) val p = (k: Annotation, v: Annotation) => { - KeyTable.setEvalContext(ec, k, v, nKeys, nValues) + KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) f().getOrElse(false) } @@ -297,11 +304,13 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def exists(code: String): Boolean = { val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + val nKeysLocal = nKeys + val nValuesLocal = nValues val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean) val p = (k: Annotation, v: Annotation) => { - KeyTable.setEvalContext(ec, k, v, nKeys, nValues) + KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) f().getOrElse(false) } @@ -324,16 +333,16 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v } hConf.delete(output, recursive = true) - // - // val nKeys = nKeys - // val nValues = nValues + + val nKeysLocal = nKeys + val nValuesLocal = nValues rdd .mapPartitions { it => val sb = new StringBuilder() it.map { case (k, v) => sb.clear() - KeyTable.setEvalContext(ec, k, v, nKeys, nValues) + KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) f().foreachBetween(x => sb.append(x))(sb += '\t') sb.result() } @@ -376,61 +385,12 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v array } - val newRDD = KeyTable.toSingleRDD(rdd, nKeys, nValues).mapPartitions { it => - it.map { a => - KeyTable.setEvalContext(aggregationEC, a, nFields) - val key = Annotation.fromSeq(keyF.map(_ ())) - (key, a) - } - }.aggregateByKey(zVals)(seqOp, combOp) - .map { case (k, agg) => - resultOp(agg) - (k, Annotation.fromSeq(aggF.map(_ ()))) - } - - 
KeyTable(newRDD, keySignature, valueSignature) - } - - def aggregateRows(keyCond: String, aggCond: String): KeyTable = { - - val aggregationEC = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) - val st = fields.zipWithIndex.map { case (fd, i) => (fd.name, (i, fd.`type`)) }.toMap ++ Map("rows" -> (-1, BaseAggregable(aggregationEC, TStruct(fields.map(fd => (fd.name, fd.`type`)): _*)))) - val ec = EvalContext(st) - - val (keyNameParseTypes, keyF) = - if (keyCond != null) - Parser.parseAnnotationArgs(keyCond, ec, None) - else - (Array.empty[(List[String], Type)], Array.empty[() => Any]) - - val (aggNameParseTypes, aggF) = - if (aggCond != null) - Parser.parseAnnotationArgs(aggCond, ec, None) - else - (Array.empty[(List[String], Type)], Array.empty[() => Any]) - - val keyNames = keyNameParseTypes.map(_._1.head) - val aggNames = aggNameParseTypes.map(_._1.head) - - val keySignature = TStruct(keyNameParseTypes.map { case (n, t) => (n.head, t) }: _*) - val valueSignature = TStruct(aggNameParseTypes.map { case (n, t) => (n.head, t) }: _*) - - val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) - val aggFunctions = aggregationEC.aggregationFunctions.map(_._1) - - val seqOp = (array: Array[Aggregator], b: Any) => { - KeyTable.setEvalContext(aggregationEC, b, nFields) - for (i <- array.indices) { - array(i).seqOp(aggFunctions(i)(b)) - } - array - } + val nFieldsLocal = nFields val newRDD = KeyTable.toSingleRDD(rdd, nKeys, nValues).mapPartitions { it => it.map { a => - KeyTable.setEvalContext(ec, a, nFields) + KeyTable.setEvalContext(aggregationEC, a, nFieldsLocal) val key = Annotation.fromSeq(keyF.map(_ ())) - println(s"keytable aggregateRow keymap key: $key") (key, a) } }.aggregateByKey(zVals)(seqOp, combOp) diff --git a/src/test/scala/org/broadinstitute/hail/SparkSuite.scala b/src/test/scala/org/broadinstitute/hail/SparkSuite.scala index 3f35a119643..bbe2cddac4a 100644 --- a/src/test/scala/org/broadinstitute/hail/SparkSuite.scala +++ 
b/src/test/scala/org/broadinstitute/hail/SparkSuite.scala @@ -24,7 +24,7 @@ class SparkSuite extends TestNGSuite { @BeforeClass def startSpark() { val master = System.getProperty("hail.master") - sc = SparkManager.createSparkContext("Hail.TestNG", Option(master), "local[1]") + sc = SparkManager.createSparkContext("Hail.TestNG", Option(master), "local[2]") sqlContext = SparkManager.createSQLContext() diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index 6347e3e0c0d..4e4a2676c7b 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -145,29 +145,17 @@ class KeyTableSuite extends SparkSuite { val keyNames = Array("field1") val kt1 = KeyTable(rdd, signature, keyNames) - val kt2 = kt1.aggregate("Status = field1", "X = field2.map(f => field2).sum()") - - val result = Array(Array("Case", 12.0), Array("Control", 3.0)) - val resRDD = sc.parallelize(result.map(Annotation.fromSeq(_))) - val resSignature = TStruct(("Status", TString), ("X", TDouble)) - val ktResult = KeyTable(resRDD, resSignature, keyNames = Array("Status")) - - - assert(kt2 same ktResult) - } - - @Test def testAggregateRows() { - val data = Array(Array("Case", 9, 0), Array("Case", 3, 4), Array("Control", 2, 3), Array("Control", 1, 5)) - val rdd = sc.parallelize(data.map(Annotation.fromSeq(_))) - val signature = TStruct(("field1", TString), ("field2", TInt), ("field3", TInt)) - val keyNames = Array("field1") - - val kt1 = KeyTable(rdd, signature, keyNames) - val kt2 = kt1.aggregateRows("Status = field1", "X = rows.map(r => field2).sum()") + val kt2 = kt1.aggregate("Status = field1", + "A = field2.sum(), " + + "B = field2.map(f => field2).sum(), " + + "C = field2.map(f => field2 + field3).sum(), " + + "D = field2.count(), " + + "E = field2.filter(f => field2 == 3).count()" + ) - val result = Array(Array("Case", 
12.0), Array("Control", 3.0)) + val result = Array(Array("Case", 12.0, 12.0, 16.0, 2L, 1L), Array("Control", 3.0, 3.0, 11.0, 2L, 0L)) val resRDD = sc.parallelize(result.map(Annotation.fromSeq(_))) - val resSignature = TStruct(("Status", TString), ("X", TDouble)) + val resSignature = TStruct(("Status", TString), ("A", TDouble), ("B", TDouble), ("C", TDouble), ("D", TLong), ("E", TLong)) val ktResult = KeyTable(resRDD, resSignature, keyNames = Array("Status")) assert(kt2 same ktResult) From 6acfd206c021d0c553e12c12204eb4abf9185813 Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Thu, 17 Nov 2016 14:36:55 -0500 Subject: [PATCH 43/51] rebased with master --- .../org/broadinstitute/hail/expr/AST.scala | 2 +- .../hail/keytable/KeyTable.scala | 4 +- .../hail/methods/Aggregators.scala | 47 +++++++++---------- 3 files changed, 25 insertions(+), 28 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/expr/AST.scala b/src/main/scala/org/broadinstitute/hail/expr/AST.scala index 4b237aa4d5e..9018dd3de82 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/AST.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/AST.scala @@ -1196,7 +1196,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f - agg.ec.aggregationFunctions += new CounterAggregator(aggF, localIdx) + agg.ec.aggregationFunctions += ((aggF, new CounterAggregator(localIdx))) () => { val m = localA(localIdx).asInstanceOf[mutable.HashMap[Any, Long]] diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 5dba21859c2..48d025dd1ec 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -20,9 +20,7 @@ object KeyTable extends Serializable with TextExporter { val files = sc.hadoopConfiguration.globAll(path) if (files.isEmpty) fatal("Arguments referred to no files") - - 
sc.defaultMinPartitions - + val keyNameArray = Parser.parseIdentifierList(keyNames) val (struct, rdd) = diff --git a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala index 19a241a01de..b07b3044e1c 100644 --- a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala +++ b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala @@ -145,7 +145,7 @@ class CountAggregator(val idx: Int) extends TypedAggregator[Long] { def result = _state - override def seqOp(x: Any) { + def seqOp(x: Any) { if (x != null) _state += 1 } @@ -154,7 +154,7 @@ class CountAggregator(val idx: Int) extends TypedAggregator[Long] { _state += agg2._state } - override def copy() = new CountAggregator(idx) + def copy() = new CountAggregator(idx) } class FractionAggregator(val idx: Int, localA: ArrayBuffer[Any], bodyFn: () => Any, lambdaIdx: Int) @@ -169,7 +169,7 @@ class FractionAggregator(val idx: Int, localA: ArrayBuffer[Any], bodyFn: () => A else _num.toDouble / _denom - override def seqOp(x: Any) { + def seqOp(x: Any) { if (x != null) { _denom += 1 localA(lambdaIdx) = x @@ -183,7 +183,7 @@ class FractionAggregator(val idx: Int, localA: ArrayBuffer[Any], bodyFn: () => A _denom += agg2._denom } - override def copy() = new FractionAggregator(idx, localA, bodyFn, lambdaIdx) + def copy() = new FractionAggregator(idx, localA, bodyFn, lambdaIdx) } class StatAggregator(val idx: Int) extends TypedAggregator[StatCounter] { @@ -192,7 +192,7 @@ class StatAggregator(val idx: Int) extends TypedAggregator[StatCounter] { def result = _state - override def seqOp(x: Any) { + def seqOp(x: Any) { if (x != null) _state.merge(DoubleNumericConversion.to(x)) } @@ -201,7 +201,7 @@ class StatAggregator(val idx: Int) extends TypedAggregator[StatCounter] { _state.merge(agg2._state) } - override def copy() = new StatAggregator(idx) + def copy() = new StatAggregator(idx) } class CounterAggregator(val idx: Int) extends 
TypedAggregator[mutable.HashMap[Any, Long]] { @@ -210,9 +210,8 @@ class CounterAggregator(val idx: Int) extends TypedAggregator[mutable.HashMap[An def result = m def seqOp(x: Any) { - val r = f(x) - if (r != null) - m.updateValue(r, 0L, _ + 1) + if (x != null) + m.updateValue(x, 0L, _ + 1) } def combOp(agg2: this.type) { @@ -221,7 +220,7 @@ class CounterAggregator(val idx: Int) extends TypedAggregator[mutable.HashMap[An } } - override def copy() = new StatAggregator(idx) + def copy() = new CounterAggregator(idx) } class HistAggregator(val idx: Int, indices: Array[Double]) @@ -231,7 +230,7 @@ class HistAggregator(val idx: Int, indices: Array[Double]) def result = _state - override def seqOp(x: Any) { + def seqOp(x: Any) { if (x != null) _state.merge(DoubleNumericConversion.to(x)) } @@ -240,7 +239,7 @@ class HistAggregator(val idx: Int, indices: Array[Double]) _state.merge(agg2._state) } - override def copy() = new HistAggregator(idx, indices) + def copy() = new HistAggregator(idx, indices) } class CollectAggregator(val idx: Int) extends TypedAggregator[ArrayBuffer[Any]] { @@ -249,14 +248,14 @@ class CollectAggregator(val idx: Int) extends TypedAggregator[ArrayBuffer[Any]] def result = _state - override def seqOp(x: Any) { + def seqOp(x: Any) { if (x != null) _state += x } def combOp(agg2: this.type) = _state ++= agg2._state - override def copy() = new CollectAggregator(idx) + def copy() = new CollectAggregator(idx) } class InfoScoreAggregator(val idx: Int) extends TypedAggregator[InfoScoreCombiner] { @@ -265,7 +264,7 @@ class InfoScoreAggregator(val idx: Int) extends TypedAggregator[InfoScoreCombine def result = _state - override def seqOp(x: Any) { + def seqOp(x: Any) { if (x != null) _state.merge(x.asInstanceOf[Genotype]) } @@ -274,7 +273,7 @@ class InfoScoreAggregator(val idx: Int) extends TypedAggregator[InfoScoreCombine _state.merge(agg2._state) } - override def copy() = new InfoScoreAggregator(idx) + def copy() = new InfoScoreAggregator(idx) } class 
HWEAggregator(val idx: Int) extends TypedAggregator[HWECombiner] { @@ -283,7 +282,7 @@ class HWEAggregator(val idx: Int) extends TypedAggregator[HWECombiner] { def result = _state - override def seqOp(x: Any) { + def seqOp(x: Any) { if (x != null) _state.merge(x.asInstanceOf[Genotype]) } @@ -292,7 +291,7 @@ class HWEAggregator(val idx: Int) extends TypedAggregator[HWECombiner] { _state.merge(agg2._state) } - override def copy() = new HWEAggregator(idx) + def copy() = new HWEAggregator(idx) } class SumAggregator(val idx: Int) extends TypedAggregator[Double] { @@ -300,14 +299,14 @@ class SumAggregator(val idx: Int) extends TypedAggregator[Double] { def result = _state - override def seqOp(x: Any) { + def seqOp(x: Any) { if (x != null) _state += DoubleNumericConversion.to(x) } def combOp(agg2: this.type) = _state += agg2._state - override def copy() = new SumAggregator(idx) + def copy() = new SumAggregator(idx) } class SumArrayAggregator(val idx: Int, localPos: Position) @@ -317,7 +316,7 @@ class SumArrayAggregator(val idx: Int, localPos: Position) def result = _state - override def seqOp(x: Any) { + def seqOp(x: Any) { val r = x.asInstanceOf[IndexedSeq[Any]] if (r != null) { if (_state == null) @@ -349,7 +348,7 @@ class SumArrayAggregator(val idx: Int, localPos: Position) _state(i) += agg2state(i) } - override def copy() = new SumArrayAggregator(idx, localPos) + def copy() = new SumArrayAggregator(idx, localPos) } class CallStatsAggregator(val idx: Int, variantF: () => Any) @@ -396,7 +395,7 @@ class InbreedingAggregator(localIdx: Int, getAF: () => Any) extends TypedAggrega def result = _state - override def seqOp(x: Any) = { + def seqOp(x: Any) = { val af = getAF() if (x != null && af != null) @@ -405,7 +404,7 @@ class InbreedingAggregator(localIdx: Int, getAF: () => Any) extends TypedAggrega def combOp(agg2: this.type) = _state.merge(agg2.asInstanceOf[InbreedingAggregator]._state) - override def copy() = new InbreedingAggregator(localIdx, getAF) + def copy() = new 
InbreedingAggregator(localIdx, getAF) def idx = localIdx } From 8b2f0b8f8ef81028f0406052e6f7b8740f8551eb Mon Sep 17 00:00:00 2001 From: Jackie Goldstein Date: Thu, 17 Nov 2016 16:03:14 -0500 Subject: [PATCH 44/51] added python tests --- python/pyhail/__init__.py | 2 +- python/pyhail/context.py | 6 ++--- python/pyhail/keytable.py | 10 ++++----- python/pyhail/tests.py | 46 ++++++++++++++++++++++++++++++++++++++- python/pyhail/utils.py | 2 +- 5 files changed, 55 insertions(+), 11 deletions(-) diff --git a/python/pyhail/__init__.py b/python/pyhail/__init__.py index 2543b1978f2..e4c87402c4e 100644 --- a/python/pyhail/__init__.py +++ b/python/pyhail/__init__.py @@ -1,6 +1,6 @@ from pyhail.context import HailContext from pyhail.dataset import VariantDataset from pyhail.keytable import KeyTable -from pyhail.utils import TextTableConfig +from pyhail.utils import TextTableConfig, Type __all__ = ["HailContext", "VariantDataset", "KeyTable", "TextTableConfig", "Type"] diff --git a/python/pyhail/context.py b/python/pyhail/context.py index 66823314659..348e8025e2e 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -308,9 +308,9 @@ def import_keytable(self, path, key_names, npartitions=None, config=None): npartitions = self.sc.defaultMinPartitions if not config: - config = TextTableConfig()._toJavaObject(self) - elif isinstance(key_names, TextTableConfig): - config = config._toJavaObject(self) + config = TextTableConfig()._jobj(self) + elif isinstance(config, TextTableConfig): + config = config._jobj(self) return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable(self.jsc, jarray(self.gateway, self.jvm.java.lang.String, pathArgs), key_names, npartitions, config)) diff --git a/python/pyhail/keytable.py b/python/pyhail/keytable.py index e0657bc4d47..eb810672e45 100644 --- a/python/pyhail/keytable.py +++ b/python/pyhail/keytable.py @@ -18,7 +18,7 @@ def __init__(self, hc, jkt): def __repr__(self): return self.jkt.toString() - def 
nFields(self): + def nfields(self): """Number of fields in the key-table :rtype: int @@ -32,21 +32,21 @@ def schema(self): """ return Type(self.jkt.signature()) - def keyNames(self): + def key_names(self): """Field names that are keys :rtype: list of str """ return self.jkt.keyNames() - def fieldNames(self): + def field_names(self): """Names of all fields in the key-table :rtype: list of str """ return self.jkt.fieldNames() - def nRows(self): + def nrows(self): """Number of rows in the key-table :rtype: long @@ -109,7 +109,7 @@ def join(self, right, how='inner'): """ return KeyTable(self.hc, self.jkt.join(right.jkt, how)) - def _aggregate(self, key_code, agg_code): + def aggregate_by_key(self, key_code, agg_code): """Group by key condition and aggregate results :param key_code: Named expression(s) for which fields are keys. diff --git a/python/pyhail/tests.py b/python/pyhail/tests.py index 158d77220e7..68668a4bb6f 100644 --- a/python/pyhail/tests.py +++ b/python/pyhail/tests.py @@ -6,7 +6,7 @@ import unittest from pyspark import SparkContext -from pyhail import HailContext +from pyhail import HailContext, TextTableConfig class ContextTests(unittest.TestCase): @@ -208,6 +208,50 @@ def test_dataset(self): sample2_split.variant_qc().print_schema() sample2.variants_to_pandas() + + sample_split.annotate_variants_expr("va.nHet = gs.filter(g => g.isHet).count()") + kt = sample_split.aggregate_by_key("Variant = v", "nHet = gs.filter(g => g.isHet).count()") + + def test_keytable(self): + # Import + kt = self.hc.import_keytable(self.test_resources + '/sampleAnnotations.tsv', 'Sample', config = TextTableConfig(impute = True)) + kt2 = self.hc.import_keytable(self.test_resources + '/sampleAnnotations2.tsv', 'Sample', config = TextTableConfig(impute = True)) + + # Variables + self.assertEqual(kt.nfields(), 3) + self.assertEqual(kt.key_names()[0], "Sample") + self.assertEqual(kt.field_names()[2], "qPhen") + self.assertEqual(kt.nrows(), 100) + kt.schema() + + # Export + 
kt.export('/tmp/testExportKT.tsv') + + # Filter, Same + ktcase = kt.filter('Status == "CASE"', True) + ktcase2 = kt.filter('Status == "CTRL"', False) + self.assertTrue(ktcase.same(ktcase2)) + + # Annotate + kt4 = kt.annotate('X = Status', 'Sample, Status') + + # Join + kt5 = kt.join(kt2, 'left') + + # AggregateByKey + kt6 = kt.aggregate_by_key("Status = Status", "Sum = qPhen.sum()") + + # Forall, Exists + self.assertFalse(kt.forall('Status == "CASE"')) + self.assertTrue(kt.exists('Status == "CASE"')) + + + + + + + + def tearDown(self): self.sc.stop() diff --git a/python/pyhail/utils.py b/python/pyhail/utils.py index 682f29593b4..165c980fe9b 100644 --- a/python/pyhail/utils.py +++ b/python/pyhail/utils.py @@ -45,7 +45,7 @@ def __str__(self): return " ".join(res) - def _toJavaObject(self, hc): + def _jobj(self, hc): """Convert to java TextTableConfiguration object :param :class:`.HailContext` hc: Hail spark context. From 3cfb11be4f75e44abaa21177f1d8559951317a7f Mon Sep 17 00:00:00 2001 From: Cotton Seed Date: Sun, 20 Nov 2016 00:49:10 -0500 Subject: [PATCH 45/51] Fixed additional failures. 
--- python/pyhail/tests.py | 11 ++----- .../hail/variant/VariantSampleMatrix.scala | 32 ++++++++++--------- .../hail/driver/AggregateByKeySuite.scala | 6 ++-- 3 files changed, 22 insertions(+), 27 deletions(-) diff --git a/python/pyhail/tests.py b/python/pyhail/tests.py index 68668a4bb6f..9d58237ed07 100644 --- a/python/pyhail/tests.py +++ b/python/pyhail/tests.py @@ -210,7 +210,8 @@ def test_dataset(self): sample2.variants_to_pandas() sample_split.annotate_variants_expr("va.nHet = gs.filter(g => g.isHet).count()") - kt = sample_split.aggregate_by_key("Variant = v", "nHet = gs.filter(g => g.isHet).count()") + + kt = sample_split.aggregate_by_key("Variant = v", "nHet = g.map(g => g.isHet.toInt).sum().toLong") def test_keytable(self): # Import @@ -245,13 +246,5 @@ def test_keytable(self): self.assertFalse(kt.forall('Status == "CASE"')) self.assertTrue(kt.exists('Status == "CASE"')) - - - - - - - - def tearDown(self): self.sc.stop() diff --git a/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala b/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala index f3d08f01e77..d1e3bdbd98f 100644 --- a/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala +++ b/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala @@ -605,17 +605,20 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata, "va" -> (1, vaSignature), "s" -> (2, TSample), "sa" -> (3, saSignature), - "global" -> (4, globalSignature))) + "global" -> (4, globalSignature), + "g" -> (5, TGenotype))) - val symTab = Map( + val ec = EvalContext(Map( "v" -> (0, TVariant), "va" -> (1, vaSignature), "s" -> (2, TSample), "sa" -> (3, saSignature), "global" -> (4, globalSignature), - "gs" -> (-1, BaseAggregable(aggregationEC, TGenotype))) + "gs" -> (-1, BaseAggregable(aggregationEC, TGenotype)))) - val ec = EvalContext(symTab) + val ktEC = EvalContext( + aggregationEC.st.map { case (name, (i, t)) => name -> (-1, 
KeyTableAggregable(aggregationEC, t.asInstanceOf[Type], i)) } + ) ec.set(4, globalAnnotation) aggregationEC.set(4, globalAnnotation) @@ -628,26 +631,25 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata, val (aggNameParseTypes, aggF) = if (aggCond != null) - Parser.parseAnnotationArgs(aggCond, ec, None) + Parser.parseAnnotationArgs(aggCond, ktEC, None) else (Array.empty[(List[String], Type)], Array.empty[() => Any]) val keyNames = keyNameParseTypes.map(_._1.head) val aggNames = aggNameParseTypes.map(_._1.head) - val keySignature = TStruct(keyNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) - val valueSignature = TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, t) }: _*) + val keySignature = TStruct(keyNameParseTypes.map { case (n, t) => (n.head, t) }: _*) + val valueSignature = TStruct(aggNameParseTypes.map { case (n, t) => (n.head, t) }: _*) val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) + val aggFunctions = aggregationEC.aggregationFunctions.map(_._1) + + val localGlobalAnnotation = globalAnnotation - val seqOp = (array: Array[Aggregator], b: (Any, Any, Any, Any, Any)) => { - val (v, va, s, sa, aggT) = b - ec.set(0, v) - ec.set(1, va) - ec.set(2, s) - ec.set(3, sa) + val seqOp = (array: Array[Aggregator], r: Annotation) => { + KeyTable.setEvalContext(aggregationEC, r, 6) for (i <- array.indices) { - array(i).seqOp(aggT) + array(i).seqOp(aggFunctions(i)(r)) } array } @@ -656,7 +658,7 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata, it.map { case (v, va, s, sa, g) => ec.setAll(v, va, s, sa, g) val key = Annotation.fromSeq(keyF.map(_ ())) - (key, (v, va, s, sa, g)) + (key, Annotation(v, va, s, sa, localGlobalAnnotation, g)) } }.aggregateByKey(zVals)(seqOp, combOp) .map { case (k, agg) => diff --git a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala index 79269efe1e1..c72b145e678 100644 --- 
a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala +++ b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala @@ -11,7 +11,7 @@ class AggregateByKeySuite extends SparkSuite { var s = State(sc, sqlContext) s = ImportVCF.run(s, Array(inputVCF)) s = AnnotateSamplesExpr.run(s, Array("-c", "sa.nHet = gs.filter(g => g.isHet).count()")) - val kt = s.vds.aggregateByKey("Sample = s", "nHet = gs.filter(g => g.isHet).count()") + val kt = s.vds.aggregateByKey("Sample = s", "nHet = g.map(g => g.isHet.toInt).sum().toLong") val (_, ktHetQuery) = kt.query("nHet") val (_, ktSampleQuery) = kt.query("Sample") @@ -29,7 +29,7 @@ class AggregateByKeySuite extends SparkSuite { var s = State(sc, sqlContext) s = ImportVCF.run(s, Array(inputVCF)) s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()")) - val kt = s.vds.aggregateByKey("Variant = v", "nHet = gs.filter(g => g.isHet).count()") + val kt = s.vds.aggregateByKey("Variant = v", "nHet = g.map(g => g.isHet.toInt).sum().toLong") val (_, ktHetQuery) = kt.query("nHet") val (_, ktVariantQuery) = kt.query("Variant") @@ -48,7 +48,7 @@ class AggregateByKeySuite extends SparkSuite { s = ImportVCF.run(s, Array(inputVCF)) s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()")) s = AnnotateGlobalExpr.run(s, Array("-c", "global.nHet = variants.map(v => va.nHet).sum().toLong")) - val kt = s.vds.aggregateByKey(null, "nHet = gs.filter(g => g.isHet).count()") + val kt = s.vds.aggregateByKey(null, "nHet = g.map(g => g.isHet.toInt).sum().toLong") val (_, ktHetQuery) = kt.query("nHet") val (_, globalHetResult) = s.vds.queryGlobal("global.nHet") From 75594faf35a20ccff5d4d554068f70034561adba Mon Sep 17 00:00:00 2001 From: Cotton Seed Date: Sun, 20 Nov 2016 13:41:09 -0500 Subject: [PATCH 46/51] Minor edits. 
--- python/pyhail/context.py | 16 ++++++------ python/pyhail/dataset.py | 2 +- python/pyhail/keytable.py | 9 ++++--- python/pyhail/type.py | 10 ++++++++ python/pyhail/utils.py | 25 ++++++------------- .../org/broadinstitute/hail/expr/Type.scala | 5 ++-- .../hail/keytable/KeyTable.scala | 11 +++----- 7 files changed, 36 insertions(+), 42 deletions(-) create mode 100644 python/pyhail/type.py diff --git a/python/pyhail/context.py b/python/pyhail/context.py index 348e8025e2e..8b5837c9a34 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -278,7 +278,7 @@ def import_gen(self, path, tolerance=0.2, sample_file=None, npartitions=None, ch return self.run_command(None, pargs) def import_keytable(self, path, key_names, npartitions=None, config=None): - """Import tabular file as KeyTable + """Import delimited text file (text table) as KeyTable. :param path: files to import. :type path: str or list of str @@ -290,7 +290,7 @@ def import_keytable(self, path, key_names, npartitions=None, config=None): :type npartitions: int or None :param config: Configuration options for importing text files - :type config: :class:`.TextTableConfig` + :type config: :class:`.TextTableConfig` or None :rtype: :class:`.KeyTable` """ @@ -302,20 +302,18 @@ def import_keytable(self, path, key_names, npartitions=None, config=None): pathArgs.append(p) if not isinstance(key_names, str): - key_names = ",".join(key_names) + key_names = ','.join(key_names) if not npartitions: npartitions = self.sc.defaultMinPartitions if not config: - config = TextTableConfig()._jobj(self) - elif isinstance(config, TextTableConfig): - config = config._jobj(self) + config = TextTableConfig() - return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable(self.jsc, jarray(self.gateway, self.jvm.java.lang.String, pathArgs), - key_names, npartitions, config)) + return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable( + self.jsc, jarray(self.gateway, 
self.jvm.java.lang.String, pathArgs), key_names, npartitions, config.to_java(self))) - def import_plink(self, bed, bim, fam, npartitions=None, delimiter='\\\\s+', missing="NA", quantpheno=False): + def import_plink(self, bed, bim, fam, npartitions=None, delimiter='\\\\s+', missing='NA', quantpheno=False): """ Import PLINK binary file (.bed, .bim, .fam) as VariantDataset diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py index 279ce0ea441..c4e51e1b226 100644 --- a/python/pyhail/dataset.py +++ b/python/pyhail/dataset.py @@ -12,7 +12,7 @@ def __init__(self, hc, jvds): def _raise_py4j_exception(self, e): self.hc._raise_py4j_exception(e) - def aggregate_by_key(self, key_code=None, agg_code=None): + def aggregate_by_key(self, key_code, agg_code): """Aggregate by user-defined key and aggregation expressions. Equivalent of a group-by operation in SQL. diff --git a/python/pyhail/keytable.py b/python/pyhail/keytable.py index eb810672e45..5556755db13 100644 --- a/python/pyhail/keytable.py +++ b/python/pyhail/keytable.py @@ -1,8 +1,9 @@ from pyhail.utils import Type class KeyTable(object): - """:class:`.KeyTable` is Hail's version of a SQL - table where fields can be designated as keys. + """:class:`.KeyTable` is Hail's version of a SQL table where fields + can be designated as keys. + """ def __init__(self, hc, jkt): @@ -85,7 +86,7 @@ def filter(self, code, keep=True): """ return KeyTable(self.hc, self.jkt.filter(code, keep)) - def annotate(self, code, key_names=None): + def annotate(self, code, key_names=''): """Add fields to key-table. :param str code: Annotation expression. 
@@ -144,4 +145,4 @@ def exists(self, code): :rtype: bool """ - return self.jkt.exists(code) \ No newline at end of file + return self.jkt.exists(code) diff --git a/python/pyhail/type.py b/python/pyhail/type.py new file mode 100644 index 00000000000..b7bbe373292 --- /dev/null +++ b/python/pyhail/type.py @@ -0,0 +1,10 @@ + +class Type(object): + def __init__(self, jtype): + self.jtype = jtype + + def __repr__(self): + return self.jtype.toString() + + def __str__(self): + return self.jtype.toPrettyString(False, False) diff --git a/python/pyhail/utils.py b/python/pyhail/utils.py index 165c980fe9b..2a6906dc31b 100644 --- a/python/pyhail/utils.py +++ b/python/pyhail/utils.py @@ -1,28 +1,20 @@ -class Type(object): - def __init__(self, jtype): - self.jtype = jtype - - def __repr__(self): - return self.jtype.toString() - - def __str__(self): - return self.jtype.toPrettyString(False, False) - class TextTableConfig(object): - """:class:`.TextTableConfig` specifies additional options for importing TSV files. + """Configuration for delimited (text table) files. :param bool noheader: File has no header and columns should be indicated by `_1, _2, ... _N' (0-indexed) :param bool impute: Impute column types from the file - :param str comment: Skip lines beginning with the given pattern + :param comment: Skip lines beginning with the given pattern + :type comment: str or None :param str delimiter: Field delimiter regex :param str missing: Specify identifier to be treated as missing - :param str types: Define types of fields in annotations files + :param types: Define types of fields in annotations files + :type types: str or None """ def __init__(self, noheader = False, impute = False, comment = None, delimiter = "\t", missing = "NA", types = None): @@ -45,12 +37,11 @@ def __str__(self): return " ".join(res) - def _jobj(self, hc): - """Convert to java TextTableConfiguration object + def to_java(self, hc): + """Convert to Java TextTableConfiguration object. 
- :param :class:`.HailContext` hc: Hail spark context. + :param :class:`.HailContext` The Hail context. """ return hc.jvm.org.broadinstitute.hail.utils.TextTableConfiguration.apply(self.types, self.comment, self.delimiter, self.missing, self.noheader, self.impute) - diff --git a/src/main/scala/org/broadinstitute/hail/expr/Type.scala b/src/main/scala/org/broadinstitute/hail/expr/Type.scala index 67b2032345f..baeef97045a 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Type.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Type.scala @@ -284,9 +284,8 @@ case class MappedAggregable(parent: TAggregable, elementType: Type, mapF: (Any) val parentF = parent.f (a: Any) => { val prev = parentF(a) - if (prev != null) { + if (prev != null) mapF(prev) - } else null } @@ -320,7 +319,7 @@ case class TArray(elementType: Type) extends TIterable { override def str(a: Annotation): String = JsonMethods.compact(toJSON(a)) - override def genValue: Gen[Annotation] = Gen.buildableOf[Array, Annotation](elementType.genValue).map(x => x: IndexedSeq[Annotation]) + override def genValue: Gen[Annotation] = Gen.buildableOf[IndexedSeq, Annotation](elementType.genValue) } case class TSet(elementType: Type) extends TIterable { diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 48d025dd1ec..761ddd2f547 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -29,7 +29,6 @@ object KeyTable extends Serializable with TextExporter { else TextTableReader.read(sc)(files, config, nPartitions) - val keyNamesValid = keyNameArray.forall { k => val res = struct.selfField(k).isDefined if (!res) @@ -83,7 +82,7 @@ object KeyTable extends Serializable with TextExporter { case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, valueSignature: TStruct) { - require(fieldNames.toSet.size == 
fieldNames.length) + require(fieldNames.areDistinct()) def signature = keySignature.merge(valueSignature)._1 @@ -292,12 +291,10 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean) - val p = (k: Annotation, v: Annotation) => { + rdd.forall { case (k, v) => KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) f().getOrElse(false) } - - rdd.forall { case (k, v) => p(k, v) } } def exists(code: String): Boolean = { @@ -307,12 +304,10 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean) - val p = (k: Annotation, v: Annotation) => { + rdd.exists { case (k, v) => KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) f().getOrElse(false) } - - rdd.exists { case (k, v) => p(k, v) } } def export(sc: SparkContext, output: String, typesFile: String) = { From 3211b7a5b553f1184ff8b594d239f40ce83a112d Mon Sep 17 00:00:00 2001 From: Cotton Seed Date: Mon, 21 Nov 2016 14:12:10 -0500 Subject: [PATCH 47/51] Fixed python imports. 
--- python/pyhail/__init__.py | 3 ++- python/pyhail/keytable.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyhail/__init__.py b/python/pyhail/__init__.py index e4c87402c4e..a27e3d6f47d 100644 --- a/python/pyhail/__init__.py +++ b/python/pyhail/__init__.py @@ -1,6 +1,7 @@ from pyhail.context import HailContext from pyhail.dataset import VariantDataset from pyhail.keytable import KeyTable -from pyhail.utils import TextTableConfig, Type +from pyhail.utils import TextTableConfig +from pyhail.type import Type __all__ = ["HailContext", "VariantDataset", "KeyTable", "TextTableConfig", "Type"] diff --git a/python/pyhail/keytable.py b/python/pyhail/keytable.py index 5556755db13..ddacad436af 100644 --- a/python/pyhail/keytable.py +++ b/python/pyhail/keytable.py @@ -1,4 +1,4 @@ -from pyhail.utils import Type +from pyhail.type import Type class KeyTable(object): """:class:`.KeyTable` is Hail's version of a SQL table where fields From 260a8560e2161e96f331b03fb0fc811119c05c25 Mon Sep 17 00:00:00 2001 From: Cotton Seed Date: Mon, 21 Nov 2016 16:42:29 -0500 Subject: [PATCH 48/51] Finalize rebase. 
--- python/pyhail/context.py | 2 +- python/pyhail/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyhail/context.py b/python/pyhail/context.py index 8b5837c9a34..55e315c0a5b 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -3,7 +3,7 @@ from pyhail.dataset import VariantDataset from pyhail.java import jarray, scala_object, scala_package_object from pyhail.keytable import KeyTable -from pyhail.TextTableConfig import TextTableConfig +from pyhail.utils import TextTableConfig from py4j.protocol import Py4JJavaError class FatalError(Exception): diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py index c4e51e1b226..b0dadbf177c 100644 --- a/python/pyhail/dataset.py +++ b/python/pyhail/dataset.py @@ -27,7 +27,7 @@ def aggregate_by_key(self, key_code, agg_code): :rtype: :class:`.KeyTable` """ - return KeyTable(self.hc, self.jvds.aggregateByKey(key_condition, agg_condition)) + return KeyTable(self.hc, self.jvds.aggregateByKey(key_code, agg_code)) def aggregate_intervals(self, input, condition, output): """Aggregate over intervals and export. From 9417222b090beeb6e0c1d975f98fd7b79f7b2ca9 Mon Sep 17 00:00:00 2001 From: Cotton Seed Date: Mon, 21 Nov 2016 18:14:54 -0500 Subject: [PATCH 49/51] Addressed comments. --- python/pyhail/dataset.py | 3 +- python/pyhail/keytable.py | 77 +++++++++++++++++++++++++++++++-------- 2 files changed, 62 insertions(+), 18 deletions(-) diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py index b0dadbf177c..8fd18373f1e 100644 --- a/python/pyhail/dataset.py +++ b/python/pyhail/dataset.py @@ -16,8 +16,6 @@ def aggregate_by_key(self, key_code, agg_code): """Aggregate by user-defined key and aggregation expressions. Equivalent of a group-by operation in SQL. - The resulting key-table will have four fields [pheno, gene, nHet, nAlleles] where pheno and gene are the keys. - :param key_code: Named expression(s) for which fields are keys. 
:type key_code: str or list of str @@ -25,6 +23,7 @@ def aggregate_by_key(self, key_code, agg_code): :type agg_code: str or list of str :rtype: :class:`.KeyTable` + """ return KeyTable(self.hc, self.jvds.aggregateByKey(key_code, agg_code)) diff --git a/python/pyhail/keytable.py b/python/pyhail/keytable.py index ddacad436af..ec294f4c556 100644 --- a/python/pyhail/keytable.py +++ b/python/pyhail/keytable.py @@ -8,51 +8,72 @@ class KeyTable(object): def __init__(self, hc, jkt): """ - :param hc: Hail spark context. - :type hc: :class:`.HailContext` + :param HailContext hc: Hail spark context. :param JavaKeyTable jkt: Java KeyTable object. """ self.hc = hc self.jkt = jkt + def _raise_py4j_exception(self, e): + self.hc._raise_py4j_exception(e) + def __repr__(self): - return self.jkt.toString() + try: + return self.jkt.toString() + except Py4JJavaError as e: + self._raise_py4j_exception(e) def nfields(self): """Number of fields in the key-table :rtype: int """ - return self.jkt.nFields() + try: + return self.jkt.nFields() + except Py4JJavaError as e: + self._raise_py4j_exception(e) def schema(self): """Key-table schema :rtype: :class:`.Type` """ - return Type(self.jkt.signature()) + try: + return Type(self.jkt.signature()) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def key_names(self): """Field names that are keys :rtype: list of str """ - return self.jkt.keyNames() + try: + return self.jkt.keyNames() + except Py4JJavaError as e: + self._raise_py4j_exception(e) def field_names(self): """Names of all fields in the key-table :rtype: list of str """ - return self.jkt.fieldNames() + try: + return self.jkt.fieldNames() + except Py4JJavaError as e: + self._raise_py4j_exception(e) def nrows(self): """Number of rows in the key-table :rtype: long """ - return self.jkt.nRows() + try: + return self.jkt.nRows() + except Py4JJavaError as e: + self._raise_py4j_exception(e) + def same(self, other): """Test whether two key-tables are identical @@ -62,7 +83,10 @@ def 
same(self, other): :rtype: bool """ - return self.jkt.same(other.jkt) + try: + return self.jkt.same(other.jkt) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def export(self, output, types_file=None): """Export key-table to a TSV file. @@ -73,7 +97,10 @@ def export(self, output, types_file=None): :rtype: Nothing. """ - self.jkt.export(self.hc.jsc, output, types_file) + try: + self.jkt.export(self.hc.jsc, output, types_file) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def filter(self, code, keep=True): """Filter rows from key-table. @@ -84,7 +111,10 @@ def filter(self, code, keep=True): :rtype: :class:`.KeyTable` """ - return KeyTable(self.hc, self.jkt.filter(code, keep)) + try: + return KeyTable(self.hc, self.jkt.filter(code, keep)) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def annotate(self, code, key_names=''): """Add fields to key-table. @@ -95,7 +125,10 @@ def annotate(self, code, key_names=''): :rtype: :class:`.KeyTable` """ - return KeyTable(self.hc, self.jkt.annotate(code, key_names)) + try: + return KeyTable(self.hc, self.jkt.annotate(code, key_names)) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def join(self, right, how='inner'): """Join two key-tables together. 
Both key-tables must have identical key schemas @@ -108,7 +141,10 @@ def join(self, right, how='inner'): :rtype: :class:`.KeyTable` """ - return KeyTable(self.hc, self.jkt.join(right.jkt, how)) + try: + return KeyTable(self.hc, self.jkt.join(right.jkt, how)) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def aggregate_by_key(self, key_code, agg_code): """Group by key condition and aggregate results @@ -127,7 +163,10 @@ def aggregate_by_key(self, key_code, agg_code): if isinstance(agg_code, list): agg_code = ", ".join([str(l) for l in list]) - return KeyTable(self.hc, self.jkt.aggregate(key_code, agg_code)) + try: + return KeyTable(self.hc, self.jkt.aggregate(key_code, agg_code)) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def forall(self, code): """Tests whether a condition is true for all rows @@ -136,7 +175,10 @@ def forall(self, code): :rtype: bool """ - return self.jkt.forall(code) + try: + return self.jkt.forall(code) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def exists(self, code): """Tests whether a condition is true for any row @@ -145,4 +187,7 @@ def exists(self, code): :rtype: bool """ - return self.jkt.exists(code) + try: + return self.jkt.exists(code) + except Py4JJavaError as e: + self._raise_py4j_exception(e) From f520b46de5543a0a00df0cefa8d5ab3a5102a227 Mon Sep 17 00:00:00 2001 From: Cotton Seed Date: Tue, 22 Nov 2016 01:37:53 -0500 Subject: [PATCH 50/51] Addressed comments. 
--- python/pyhail/context.py | 8 ++-- python/pyhail/tests.py | 8 ++-- python/pyhail/type.py | 2 + python/pyhail/utils.py | 4 +- .../hail/driver/AnnotateGlobalExpr.scala | 2 +- .../hail/driver/AnnotateSamplesExpr.scala | 2 +- .../hail/driver/AnnotateVariantsExpr.scala | 2 +- .../hail/driver/FilterAlleles.scala | 2 +- .../hail/expr/JoinAnnotator.scala | 2 +- .../org/broadinstitute/hail/expr/Parser.scala | 2 +- .../org/broadinstitute/hail/expr/Type.scala | 2 + .../hail/keytable/KeyTable.scala | 43 ++++++------------- 12 files changed, 35 insertions(+), 44 deletions(-) diff --git a/python/pyhail/context.py b/python/pyhail/context.py index 55e315c0a5b..87070f3b10f 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -294,12 +294,12 @@ def import_keytable(self, path, key_names, npartitions=None, config=None): :rtype: :class:`.KeyTable` """ - pathArgs = [] + path_args = [] if isinstance(path, str): - pathArgs.append(path) + path_args.append(path) else: for p in path: - pathArgs.append(p) + path_args.append(p) if not isinstance(key_names, str): key_names = ','.join(key_names) @@ -311,7 +311,7 @@ def import_keytable(self, path, key_names, npartitions=None, config=None): config = TextTableConfig() return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable( - self.jsc, jarray(self.gateway, self.jvm.java.lang.String, pathArgs), key_names, npartitions, config.to_java(self))) + self.jsc, jarray(self.gateway, self.jvm.java.lang.String, path_args), key_names, npartitions, config.to_java(self))) def import_plink(self, bed, bim, fam, npartitions=None, delimiter='\\\\s+', missing='NA', quantpheno=False): """ diff --git a/python/pyhail/tests.py b/python/pyhail/tests.py index 9d58237ed07..e83627ddff1 100644 --- a/python/pyhail/tests.py +++ b/python/pyhail/tests.py @@ -234,13 +234,15 @@ def test_keytable(self): self.assertTrue(ktcase.same(ktcase2)) # Annotate - kt4 = kt.annotate('X = Status', 'Sample, Status') + (kt.annotate('X = Status', 
'Sample, Status') + .nrows()) # Join - kt5 = kt.join(kt2, 'left') + kt.join(kt2, 'left').nrows() # AggregateByKey - kt6 = kt.aggregate_by_key("Status = Status", "Sum = qPhen.sum()") + (kt.aggregate_by_key("Status = Status", "Sum = qPhen.sum()") + .nrows()) # Forall, Exists self.assertFalse(kt.forall('Status == "CASE"')) diff --git a/python/pyhail/type.py b/python/pyhail/type.py index b7bbe373292..ea85d1f4c84 100644 --- a/python/pyhail/type.py +++ b/python/pyhail/type.py @@ -1,5 +1,7 @@ class Type(object): + """Type of values.""" + def __init__(self, jtype): self.jtype = jtype diff --git a/python/pyhail/utils.py b/python/pyhail/utils.py index 2a6906dc31b..85ee9ff36a9 100644 --- a/python/pyhail/utils.py +++ b/python/pyhail/utils.py @@ -16,8 +16,8 @@ class TextTableConfig(object): :param types: Define types of fields in annotations files :type types: str or None """ - def __init__(self, noheader = False, impute = False, - comment = None, delimiter = "\t", missing = "NA", types = None): + def __init__(self, noheader=False, impute=False, + comment=None, delimiter="\t", missing="NA", types=None): self.noheader = noheader self.impute = impute self.comment = comment diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala index 802d1400f35..6359054501f 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala @@ -48,7 +48,7 @@ object AnnotateGlobalExpr extends Command { aggECS.set(1, vds.globalAnnotation) aggECV.set(1, vds.globalAnnotation) - val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Option(Annotation.GLOBAL_HEAD)) + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Some(Annotation.GLOBAL_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala 
b/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala index ce2a40af47d..c2ea5b0eb59 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala @@ -47,7 +47,7 @@ object AnnotateSamplesExpr extends Command { ec.set(2, vds.globalAnnotation) aggregationEC.set(4, vds.globalAnnotation) - val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Option(Annotation.SAMPLE_HEAD)) + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Some(Annotation.SAMPLE_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finalType = parseTypes.foldLeft(vds.saSignature) { case (sas, (ids, signature)) => diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala index dd82abe473f..dcfce9e75f9 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala @@ -49,7 +49,7 @@ object AnnotateVariantsExpr extends Command { ec.set(2, vds.globalAnnotation) aggregationEC.set(4, vds.globalAnnotation) - val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Option(Annotation.VARIANT_HEAD)) + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Some(Annotation.VARIANT_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finalType = parseTypes.foldLeft(vds.vaSignature) { case (vas, (ids, signature)) => diff --git a/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala b/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala index 62d00cb7ca5..37f7f38332c 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala @@ -72,7 +72,7 @@ object FilterAlleles extends Command { "v" -> (0, TVariant), "va" -> (1, 
state.vds.vaSignature), "aIndices" -> (2, TArray(TInt)))) - val (types, generators) = Parser.parseAnnotationArgs(options.annotation, annotationEC, Option(Annotation.VARIANT_HEAD)) + val (types, generators) = Parser.parseAnnotationArgs(options.annotation, annotationEC, Some(Annotation.VARIANT_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finalType = types.foldLeft(state.vds.vaSignature) { case (vas, (path, signature)) => val (newVas, i) = vas.insert(signature, path) diff --git a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala index 3bc56290624..fed27b945b7 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala @@ -27,7 +27,7 @@ trait JoinAnnotator { } def buildInserter(code: String, t: Type, ec: EvalContext, expectedHead: String): (Type, Inserter) = { - val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, Option(expectedHead)) + val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, Some(expectedHead)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finaltype = parseTypes.foldLeft(t) { case (t, (ids, signature)) => diff --git a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala index dbe3d1f1e23..59655294c48 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala @@ -162,7 +162,7 @@ object Parser extends JavaTokenParsers { } def checkType(l: List[String], t: BaseType): Type = { - if (expectedHead.isDefined && l.head != expectedHead.get) + if (expectedHead.exists(l.head != _)) fatal( s"""invalid annotation path `${ l.map(prettyIdentifier).mkString(".") }' | Path should begin with `$expectedHead' diff --git a/src/main/scala/org/broadinstitute/hail/expr/Type.scala 
b/src/main/scala/org/broadinstitute/hail/expr/Type.scala index baeef97045a..a0351d29d88 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Type.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Type.scala @@ -464,6 +464,8 @@ case class TStruct(fields: IndexedSeq[Field]) extends Type { def selfField(name: String): Option[Field] = fieldIdx.get(name).map(i => fields(i)) + def hasField(name: String): Boolean = fieldIdx.contains(name) + def size: Int = fields.length override def getOption(path: List[String]): Option[Type] = diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index 761ddd2f547..abf4d41dbd8 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -16,29 +16,23 @@ import scala.reflect.ClassTag object KeyTable extends Serializable with TextExporter { - def importTextTable(sc: SparkContext, path: Array[String], keyNames: String, nPartitions: Int, config: TextTableConfiguration) = { + def importTextTable(sc: SparkContext, path: Array[String], keysStr: String, nPartitions: Int, config: TextTableConfiguration) = { + require(nPartitions > 1) + val files = sc.hadoopConfiguration.globAll(path) if (files.isEmpty) fatal("Arguments referred to no files") - - val keyNameArray = Parser.parseIdentifierList(keyNames) + + val keys = Parser.parseIdentifierList(keysStr) val (struct, rdd) = - if (nPartitions < 1) - fatal("requested number of partitions in -n/--npartitions must be positive") - else - TextTableReader.read(sc)(files, config, nPartitions) + TextTableReader.read(sc)(files, config, nPartitions) - val keyNamesValid = keyNameArray.forall { k => - val res = struct.selfField(k).isDefined - if (!res) - println(s"Key `$k' is not present in input table") - res - } - if (!keyNamesValid) - fatal("Invalid key names given") + val invalidKeys = keys.filter(!struct.hasField(_)) + if 
(invalidKeys.nonEmpty) + fatal(s"invalid keys: ${ invalidKeys.mkString(", ") }") - KeyTable(rdd.map(_.value), struct, keyNameArray) + KeyTable(rdd.map(_.value), struct, keys) } def annotationToSeq(a: Annotation, nFields: Int) = Option(a).map(_.asInstanceOf[Row].toSeq).getOrElse(Seq.fill[Any](nFields)(null)) @@ -81,7 +75,6 @@ object KeyTable extends Serializable with TextExporter { } case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, valueSignature: TStruct) { - require(fieldNames.areDistinct()) def signature = keySignature.merge(valueSignature)._1 @@ -176,11 +169,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v def annotate(cond: String, keyNameString: String): KeyTable = { val ec = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) - val (parseTypes, fns) = - if (cond != null) - Parser.parseAnnotationArgs(cond, ec, None) - else - (Array.empty[(List[String], Type)], Array.empty[() => Any]) + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, None) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] @@ -230,8 +219,8 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v if (keySignature != other.keySignature) fatal( s"""Key signatures must be identical. 
- |Left signature: $keySignature - |Right signature: ${ other.keySignature }""".stripMargin) + |Left signature: ${ keySignature.toPrettyString(compact = true) } + |Right signature: ${ other.keySignature.toPrettyString(compact = true) }""".stripMargin) val overlappingFields = valueNames.toSet.intersect(other.valueNames.toSet) if (overlappingFields.nonEmpty) @@ -353,11 +342,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v else (Array.empty[(List[String], Type)], Array.empty[() => Any]) - val (aggNameParseTypes, aggF) = - if (aggCond != null) - Parser.parseAnnotationArgs(aggCond, ec, None) - else - (Array.empty[(List[String], Type)], Array.empty[() => Any]) + val (aggNameParseTypes, aggF) = Parser.parseAnnotationArgs(aggCond, ec, None) val keyNames = keyNameParseTypes.map(_._1.head) val aggNames = aggNameParseTypes.map(_._1.head) From 4007b895743c2ab3655b7960aa7d680e28464d84 Mon Sep 17 00:00:00 2001 From: Cotton Seed Date: Tue, 22 Nov 2016 11:51:11 -0500 Subject: [PATCH 51/51] Fixed test failures. 
--- .../org/broadinstitute/hail/expr/Parser.scala | 17 +++++++---------- .../broadinstitute/hail/keytable/KeyTable.scala | 6 +++--- .../hail/methods/KeyTableSuite.scala | 17 ++++++++--------- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala index 59655294c48..fbe15682002 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala @@ -75,13 +75,10 @@ object Parser extends JavaTokenParsers { } def parseIdentifierList(code: String): Array[String] = { - if (code.matches("""\s*""")) - Array.empty[String] - else - parseAll(identifierList, code) match { - case Success(result, _) => result - case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) - } + parseAll(identifierList, code) match { + case Success(result, _) => result + case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) + } } def withPos[T](p: => Parser[T]): Parser[Positioned[T]] = @@ -93,7 +90,7 @@ object Parser extends JavaTokenParsers { case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) } } - + def parseNamedArgs(code: String, ec: EvalContext): (Option[Array[String]], Array[Type], () => Array[String]) = { val result = parseAll(export_args, code) match { case Success(r, _) => r @@ -301,7 +298,7 @@ object Parser extends JavaTokenParsers { tsvIdentifier ~ "=" ~ expr ^^ { case id ~ _ ~ expr => (id, expr) } def annotationExpressions: Parser[Array[(List[String], AST)]] = - rep1sep(annotationExpression, ",") ^^ { + repsep(annotationExpression, ",") ^^ { _.toArray } @@ -323,7 +320,7 @@ object Parser extends JavaTokenParsers { def identifier = backtickLiteral | ident - def identifierList: Parser[Array[String]] = rep1sep(identifier, ",") ^^ { + def identifierList: Parser[Array[String]] = repsep(identifier, ",") ^^ { _.toArray } diff --git 
a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala index abf4d41dbd8..1f096289cee 100644 --- a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -166,7 +166,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v (t, f2) } - def annotate(cond: String, keyNameString: String): KeyTable = { + def annotate(cond: String, keysStr: String): KeyTable = { val ec = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, None) @@ -181,7 +181,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v val inserters = inserterBuilder.result() - val keyNameArray = if (keyNameString != null) Parser.parseIdentifierList(keyNameString) else keyNames + val keys = Parser.parseIdentifierList(keysStr) val nFieldsLocal = nFields @@ -194,7 +194,7 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v } } - KeyTable(mapAnnotations(f), finalSignature, keyNameArray) + KeyTable(mapAnnotations(f), finalSignature, keys) } def filter(p: (Annotation, Annotation) => Boolean): KeyTable = diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala index 4e4a2676c7b..18aefd093ee 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -42,19 +42,18 @@ class KeyTableSuite extends SparkSuite { @Test def testAnnotate() = { val inputFile = "src/test/resources/sampleAnnotations.tsv" val kt1 = KeyTable.importTextTable(sc, Array(inputFile), "Sample", sc.defaultMinPartitions, TextTableConfiguration(impute = true)) - val kt2 = kt1.annotate("""qPhen2 = pow(qPhen, 2), NotStatus = Status == "CASE", X = qPhen == 5""", null) 
- val kt3 = kt2.annotate(null, null) - val kt4 = kt3.annotate(null, "qPhen, NotStatus") + val kt2 = kt1.annotate("""qPhen2 = pow(qPhen, 2), NotStatus = Status == "CASE", X = qPhen == 5""", kt1.keyNames.mkString(",")) + val kt3 = kt2.annotate("", kt2.keyNames.mkString(",")) + val kt4 = kt3.annotate("", "qPhen, NotStatus") val kt1ValueNames = kt1.valueNames.toSet val kt2ValueNames = kt2.valueNames.toSet - assert(kt1.nKeys == kt2.nKeys && - kt1.nValues == 2 && kt2.nValues == 5 && - kt1.keySignature == kt2.keySignature && - kt1ValueNames ++ Set("qPhen2", "NotStatus", "X") == kt2ValueNames - ) - + assert(kt1.nKeys == 1) + assert(kt2.nKeys == 1) + assert(kt1.nValues == 2 && kt2.nValues == 5) + assert(kt1.keySignature == kt2.keySignature) + assert(kt1ValueNames ++ Set("qPhen2", "NotStatus", "X") == kt2ValueNames) assert(kt2 same kt3) def getDataAsMap(kt: KeyTable) = {