diff --git a/docs/faq/ExpressionLanguage.md b/docs/faq/ExpressionLanguage.md index 3a009907f25..bada94b4cd4 100644 --- a/docs/faq/ExpressionLanguage.md +++ b/docs/faq/ExpressionLanguage.md @@ -1 +1,4 @@ ## Expression Language + + + diff --git a/docs/reference/HailExpressionLanguage.md b/docs/reference/HailExpressionLanguage.md index ba3d6876a47..b20a49982e9 100644 --- a/docs/reference/HailExpressionLanguage.md +++ b/docs/reference/HailExpressionLanguage.md @@ -58,6 +58,12 @@ Several Hail commands provide the ability to perform a broad array of computatio - pcoin(p) -- returns `true` with probability `p`. `p` should be between 0.0 and 1.0 - runif(min, max) -- returns a random draw from a uniform distribution on \[`min`, `max`). `min` should be less than or equal to `max` - rnorm(mean, sd) -- returns a random draw from a normal distribution with mean `mean` and standard deviation `sd`. `sd` should be non-negative + + - Statistics + - pnorm(x) -- Returns left-tail probability p for which p = Prob($Z$ < x) with $Z$ a standard normal random variable + - qnorm(p) -- Returns left-quantile x for which p = Prob($Z$ < x) with $Z$ a standard normal random variable. `p` must satisfy `0 < p < 1`. Inverse of `pnorm` + - pchisq1tail(x) -- Returns right-tail probability p for which p = Prob($Z^2$ > x) with $Z^2$ a chi-squared random variable with one degree of freedom. `x` must be positive + - qchisq1tail(p) -- Returns right-quantile x for which p = Prob($Z^2$ > x) with $Z^2$ a chi-squared RV with one degree of freedom. `p` must satisfy `0 < p <= 1`. Inverse of `pchisq1tail` - Array Operations: - constructor: `[element1, element2, ...]` -- Create a new array from elements of the same type. @@ -139,6 +145,10 @@ Several Hail commands provide the ability to perform a broad array of computatio - range: `range(end)` or `range(start, end)`. This function will produce an `Array[Int]`. `range(3)` produces `[0, 1, 2]`. `range(-2, 2)` produces `[-2, -1, 0, 1]`. + - `gtj(i)` and `gtk(i)`. 
Convert from genotype index (triangular numbers) to `j/k` pairs. + + - `gtIndex(j, k)`. Convert from `j/k` pair to genotype index (triangular numbers). + **Note:** - All variables and values are case sensitive @@ -355,7 +365,7 @@ The resulting array is sorted by count in descending order (the most common elem .hist( start, end, bins ) ``` -This aggregator is used to compute density distributions of numeric parameters. The start, end, and bins params are no-scope parameters, which means that while computations like `100 / 4` are acceptable, variable references like `global.nBins` are not. +This aggregator is used to compute frequency distributions of numeric parameters. The start, end, and bins params are no-scope parameters, which means that while computations like `100 / 4` are acceptable, variable references like `global.nBins` are not. The result of a `hist` invocation is a struct: @@ -363,7 +373,7 @@ The result of a `hist` invocation is a struct: Struct { binEdges: Array[Double], binFrequencies: Array[Long], - nSmaller: Long, + nLess: Long, nGreater: Long } ``` @@ -374,7 +384,7 @@ Important properties: - (bins + 1) breakpoints are generated from the range `(start to end by binsize)` - `binEdges` stores an array of bin cutoffs. Each bin is left-inclusive, right-exclusive except the last bin, which includes the maximum value. This means that if there are N total bins, there will be N + 1 elements in binEdges. For the invocation `hist(0, 3, 3)`, `binEdges` would be `[0, 1, 2, 3]` where the bins are `[0, 1)`, `[1, 2)`, `[2, 3]`. - `binFrequencies` stores the number of elements in the aggregable that fall in each bin. It contains one element for each bin. 
- - Elements greater than the max bin or smaller than the min bin will be tracked separately by `nSmaller` and `nGreater` + - Elements greater than the max bin or less than the min bin will be tracked separately by `nLess` and `nGreater` **Examples:** @@ -388,7 +398,7 @@ Or, extend the above to compute a global gq histogram: ``` annotatevariants expr -c 'va.gqHist = gs.map(g => g.gq).hist(0, 100, 20)' -annotateglobal expr -c 'global.gqDensity = variants.map(v => va.gqHist.densities).sum()' +annotateglobal expr -c 'global.gqHist = variants.map(v => va.gqHist.binFrequencies).sum()' ``` ### Collect diff --git a/python/pyhail/__init__.py b/python/pyhail/__init__.py index fc8f8d53054..a27e3d6f47d 100644 --- a/python/pyhail/__init__.py +++ b/python/pyhail/__init__.py @@ -1,4 +1,7 @@ from pyhail.context import HailContext from pyhail.dataset import VariantDataset +from pyhail.keytable import KeyTable +from pyhail.utils import TextTableConfig +from pyhail.type import Type -__all__ = ["HailContext", "VariantDataset"] +__all__ = ["HailContext", "VariantDataset", "KeyTable", "TextTableConfig", "Type"] diff --git a/python/pyhail/context.py b/python/pyhail/context.py index 2cb6b6a307a..87070f3b10f 100644 --- a/python/pyhail/context.py +++ b/python/pyhail/context.py @@ -1,16 +1,47 @@ import pyspark from pyhail.dataset import VariantDataset -from pyhail.java import jarray, scala_object +from pyhail.java import jarray, scala_object, scala_package_object +from pyhail.keytable import KeyTable +from pyhail.utils import TextTableConfig +from py4j.protocol import Py4JJavaError + +class FatalError(Exception): + """:class:`.FatalError` is an error thrown by Hail method failures""" + + def __init__(self, message, java_exception): + self.msg = message + self.java_exception = java_exception + super(FatalError) + + def __str__(self): + return self.msg class HailContext(object): """:class:`.HailContext` is the main entrypoint for PyHail functionality. 
:param SparkContext sc: The pyspark context. + + :param str log: Log file. + + :param bool quiet: Don't write log file. + + :param bool append: Append to existing log file. + + :param long block_size: Minimum size of file splits in MB. + + :param str parquet_compression: Parquet compression codec. + + :param int branching_factor: Branching factor to use in tree aggregate. + + :param str tmp_dir: Temporary directory for file merging. """ - def __init__(self, sc): + def __init__(self, sc=None, log='hail.log', quiet=False, append=False, + block_size=1, parquet_compression='uncompressed', + branching_factor=50, tmp_dir='/tmp'): + self.sc = sc self.gateway = sc._gateway @@ -23,26 +54,37 @@ def __init__(self, sc): self.sql_context = pyspark.sql.SQLContext(sc, self.jsql_context) - self.jsc.hadoopConfiguration().set( - 'io.compression.codecs', - 'org.apache.hadoop.io.compress.DefaultCodec,org.broadinstitute.hail.io.compress.BGzipCodec,org.apache.hadoop.io.compress.GzipCodec') + scala_package_object(self.jvm.org.broadinstitute.hail.driver).configure( + self.jsc, + log, + quiet, + append, + parquet_compression, + block_size, + branching_factor, + tmp_dir) - logger = sc._jvm.org.apache.log4j - logger.LogManager.getLogger("org"). 
setLevel(logger.Level.ERROR) - logger.LogManager.getLogger("akka").setLevel(logger.Level.ERROR) def _jstate(self, jvds): return self.jvm.org.broadinstitute.hail.driver.State( self.jsc, self.jsql_context, jvds, scala_object(self.jvm.scala.collection.immutable, 'Map').empty()) + def _raise_py4j_exception(self, e): + msg = scala_package_object(self.jvm.org.broadinstitute.hail.utils).getMinimalMessage(e.java_exception) + raise FatalError(msg, e.java_exception) + def run_command(self, vds, pargs): jargs = jarray(self.gateway, self.jvm.java.lang.String, pargs) t = self.jvm.org.broadinstitute.hail.driver.ToplevelCommands.lookup(jargs) cmd = t._1() cmd_args = t._2() jstate = self._jstate(vds.jvds if vds != None else None) - result = cmd.run(jstate, - cmd_args) + + try: + result = cmd.run(jstate, cmd_args) + except Py4JJavaError as e: + self._raise_py4j_exception(e) + return VariantDataset(self, result.vds()) def grep(self, regex, path, max_count=100): @@ -74,7 +116,7 @@ def import_annotations_table(self, path, variant_expr, code=None, npartitions=No # text table options types=None, missing="NA", delimiter="\\t", comment=None, header=True, impute=False): - """Import variants and variant annotaitons from a delimited text file + """Import variants and variant annotations from a delimited text file (text table) as a sites-only VariantDataset. :param path: The files to import. @@ -235,7 +277,43 @@ def import_gen(self, path, tolerance=0.2, sample_file=None, npartitions=None, ch return self.run_command(None, pargs) - def import_plink(self, bed, bim, fam, npartitions=None, delimiter='\\\\s+', missing="NA", quantpheno=False): + def import_keytable(self, path, key_names, npartitions=None, config=None): + """Import delimited text file (text table) as KeyTable. + + :param path: files to import. + :type path: str or list of str + + :param key_names: The name(s) of fields to be considered keys + :type key_names: str or list of str + + :param npartitions: Number of partitions. 
+ :type npartitions: int or None + + :param config: Configuration options for importing text files + :type config: :class:`.TextTableConfig` or None + + :rtype: :class:`.KeyTable` + """ + path_args = [] + if isinstance(path, str): + path_args.append(path) + else: + for p in path: + path_args.append(p) + + if not isinstance(key_names, str): + key_names = ','.join(key_names) + + if not npartitions: + npartitions = self.sc.defaultMinPartitions + + if not config: + config = TextTableConfig() + + return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable( + self.jsc, jarray(self.gateway, self.jvm.java.lang.String, path_args), key_names, npartitions, config.to_java(self))) + + def import_plink(self, bed, bim, fam, npartitions=None, delimiter='\\\\s+', missing='NA', quantpheno=False): """ Import PLINK binary file (.bed, .bim, .fam) as VariantDataset @@ -427,7 +505,8 @@ def balding_nichols_model(self, populations, samples, variants, npartitions, :rtype: :class:`.VariantDataset` """ - pargs = ['baldingnichols', '-k', str(populations), '-n', str(samples), '-m', str(variants), '--npartitions', str(npartitions), + pargs = ['baldingnichols', '-k', str(populations), '-n', str(samples), '-m', str(variants), '--npartitions', + str(npartitions), '--root', root] if population_dist: pargs.append('-d') diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py index e921bb6dcc8..8fd18373f1e 100644 --- a/python/pyhail/dataset.py +++ b/python/pyhail/dataset.py @@ -1,12 +1,33 @@ from pyhail.java import scala_package_object +from pyhail.keytable import KeyTable import pyspark +from py4j.protocol import Py4JJavaError class VariantDataset(object): def __init__(self, hc, jvds): self.hc = hc self.jvds = jvds + def _raise_py4j_exception(self, e): + self.hc._raise_py4j_exception(e) + + def aggregate_by_key(self, key_code, agg_code): + """Aggregate by user-defined key and aggregation expressions. + Equivalent of a group-by operation in SQL. 
+ + :param key_code: Named expression(s) for which fields are keys. + :type key_code: str or list of str + + :param agg_code: Named aggregation expression(s). + :type agg_code: str or list of str + + :rtype: :class:`.KeyTable` + + """ + + return KeyTable(self.hc, self.jvds.aggregateByKey(key_code, agg_code)) + def aggregate_intervals(self, input, condition, output): """Aggregate over intervals and export. @@ -41,7 +62,6 @@ def annotate_global_list(self, input, root, as_set=False): :param bool as_set: If True, load text file as Set[String], otherwise, load as Array[String]. - """ pargs = ['annotateglobal', 'list', '-i', input, '-r', root] @@ -324,9 +344,12 @@ def count(self, genotypes=False): """ - return (scala_package_object(self.hc.jvm.org.broadinstitute.hail.driver) - .count(self.jvds, genotypes) - .toJavaMap()) + try: + return (scala_package_object(self.hc.jvm.org.broadinstitute.hail.driver) + .count(self.jvds, genotypes) + .toJavaMap()) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def deduplicate(self): """Remove duplicate variants.""" @@ -381,14 +404,14 @@ def export_genotypes(self, output, condition, types=None, export_ref=False, expo pargs.append('--print-missing') return self.hc.run_command(self, pargs) - def export_plink(self, output): + def export_plink(self, output, fam_expr = 'id = s.id'): """Export as PLINK .bed/.bim/.fam :param str output: Output file base. Will write .bed, .bim and .fam files. """ - pargs = ['exportplink', '--output', output] + pargs = ['exportplink', '--output', output, '--fam-expr', fam_expr] return self.hc.run_command(self, pargs) def export_samples(self, output, condition, types=None): @@ -505,7 +528,7 @@ def write(self, output, overwrite=False): :param str output: Path of .vds file to write. :param bool overwrite: If True, overwrite any existing .vds file. 
- + """ pargs = ['write', '-o', output] @@ -513,14 +536,16 @@ def write(self, output, overwrite=False): pargs.append('--overwrite') return self.hc.run_command(self, pargs) - def filter_genotypes(self, condition): + def filter_genotypes(self, condition, keep=True): """Filter variants based on expression. :param str condition: Expression for filter condition. """ - pargs = ['filtergenotypes', '--keep', '-c', condition] + pargs = ['filtergenotypes', + '--keep' if keep else '--remove', + '-c', condition] return self.hc.run_command(self, pargs) def filter_multi(self): @@ -539,24 +564,28 @@ def filter_samples_all(self): pargs = ['filtersamples', 'all'] return self.hc.run_command(self, pargs) - def filter_samples_expr(self, condition): + def filter_samples_expr(self, condition, keep=True): """Filter samples based on expression. :param str condition: Expression for filter condition. """ - pargs = ['filtersamples', 'expr', '--keep', '-c', condition] + pargs = ['filtersamples', 'expr', + '--keep' if keep else '--remove', + '-c', condition] return self.hc.run_command(self, pargs) - def filter_samples_list(self, input): + def filter_samples_list(self, input, keep=True): """Filter samples with a sample list file. :param str input: Path to sample list file. """ - pargs = ['filtersamples', 'list', '--keep', '-i', input] + pargs = ['filtersamples', 'list', + '--keep' if keep else '--remove', + '-i', input] return self.hc.run_command(self, pargs) def filter_variants_all(self): @@ -565,34 +594,40 @@ def filter_variants_all(self): pargs = ['filtervariants', 'all'] return self.hc.run_command(self, pargs) - def filter_variants_expr(self, condition): + def filter_variants_expr(self, condition, keep=True): """Filter variants based on expression. :param str condition: Expression for filter condition. 
""" - pargs = ['filtervariants', 'expr', '--keep', '-c', condition] + pargs = ['filtervariants', 'expr', + '--keep' if keep else '--remove', + '-c', condition] return self.hc.run_command(self, pargs) - def filter_variants_intervals(self, input): + def filter_variants_intervals(self, input, keep=True): """Filter variants with an .interval_list file. :param str input: Path to .interval_list file. """ - pargs = ['filtervariants', 'intervals', '--keep', '-i', input] + pargs = ['filtervariants', 'intervals', + '--keep' if keep else '--remove', + '-i', input] return self.hc.run_command(self, pargs) - def filter_variants_list(self, input): + def filter_variants_list(self, input, keep=True): """Filter variants with a list of variants. :param str input: Path to variant list file. """ - pargs = ['filtervariants', 'list', '--keep', '-i', input] + pargs = ['filtervariants', 'list', + '--keep' if keep else '--remove', + '-i', input] return self.hc.run_command(self, pargs) def grm(self, format, output, id_file=None, n_file=None): @@ -698,11 +733,10 @@ def join(self, right): and global annotations from self. """ - - return VariantDataset( - self.hc, - self.hc.jvm.org.broadinstitute.hail.driver.Join.join(self.jvds, - right.jvds)) + try: + return VariantDataset(self.hc, self.hc.jvm.org.broadinstitute.hail.driver.Join.join(self.jvds, right.jvds)) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def linreg(self, y, covariates='', root='va.linreg', minac=1, minaf=None): """Test each variant for association using the linear regression @@ -865,8 +899,10 @@ def same(self, other): :rtype: bool """ - - return self.jvds.same(other.jvds, 1e-6) + try: + return self.jvds.same(other.jvds, 1e-6) + except Py4JJavaError as e: + self._raise_py4j_exception(e) def sample_qc(self, branching_factor=None): """Compute per-sample QC metrics. 
@@ -914,7 +950,7 @@ def split_multi(self, propagate_gq=False): pargs.append('--propagate-gq') return self.hc.run_command(self, pargs) - def tdt(self, fam, root = 'va.tdt'): + def tdt(self, fam, root='va.tdt'): """Find transmitted and untransmitted variants; count per variant and nuclear family. @@ -975,5 +1011,16 @@ def vep(self, config, block_size=None, root=None, force=False, csq=False): def variants_to_pandas(self): """Convert variants and variant annotations to Pandas dataframe.""" - return pyspark.sql.DataFrame(self.jvds.variantsDF(self.hc.jsql_context), - self.hc.sql_context).toPandas() + try: + return pyspark.sql.DataFrame(self.jvds.variantsDF(self.hc.jsql_context), + self.hc.sql_context).toPandas() + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def samples_to_pandas(self): + """Convert samples and sample annotations to Pandas dataframe.""" + try: + return pyspark.sql.DataFrame(self.jvds.samplesDF(self.hc.jsql_context), + self.hc.sql_context).toPandas() + except Py4JJavaError as e: + self._raise_py4j_exception(e) diff --git a/python/pyhail/docs/index.rst b/python/pyhail/docs/index.rst index c743f23687e..228c112559c 100644 --- a/python/pyhail/docs/index.rst +++ b/python/pyhail/docs/index.rst @@ -17,6 +17,12 @@ Contents: .. autoclass:: pyhail.VariantDataset :members: +.. autoclass:: pyhail.KeyTable + :members: + +.. autoclass:: pyhail.TextTableConfig + :members: + Indices and tables ================== diff --git a/python/pyhail/keytable.py b/python/pyhail/keytable.py new file mode 100644 index 00000000000..ec294f4c556 --- /dev/null +++ b/python/pyhail/keytable.py @@ -0,0 +1,193 @@ +from pyhail.type import Type + +class KeyTable(object): + """:class:`.KeyTable` is Hail's version of a SQL table where fields + can be designated as keys. + + """ + + def __init__(self, hc, jkt): + """ + :param HailContext hc: Hail spark context. + + :param JavaKeyTable jkt: Java KeyTable object. 
+ """ + self.hc = hc + self.jkt = jkt + + def _raise_py4j_exception(self, e): + self.hc._raise_py4j_exception(e) + + def __repr__(self): + try: + return self.jkt.toString() + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def nfields(self): + """Number of fields in the key-table + + :rtype: int + """ + try: + return self.jkt.nFields() + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def schema(self): + """Key-table schema + + :rtype: :class:`.Type` + """ + try: + return Type(self.jkt.signature()) + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def key_names(self): + """Field names that are keys + + :rtype: list of str + """ + try: + return self.jkt.keyNames() + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def field_names(self): + """Names of all fields in the key-table + + :rtype: list of str + """ + try: + return self.jkt.fieldNames() + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def nrows(self): + """Number of rows in the key-table + + :rtype: long + """ + try: + return self.jkt.nRows() + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + + def same(self, other): + """Test whether two key-tables are identical + + :param other: KeyTable to compare to + :type other: :class:`.KeyTable` + + :rtype: bool + """ + try: + return self.jkt.same(other.jkt) + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def export(self, output, types_file=None): + """Export key-table to a TSV file. + + :param str output: Output file path + + :param str types_file: Output path of types file + + :rtype: Nothing. + """ + try: + self.jkt.export(self.hc.jsc, output, types_file) + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def filter(self, code, keep=True): + """Filter rows from key-table. + + :param str code: Annotation expression. 
+ + :param bool keep: Keep rows where annotation expression evaluates to True + + :rtype: :class:`.KeyTable` + """ + try: + return KeyTable(self.hc, self.jkt.filter(code, keep)) + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def annotate(self, code, key_names=''): + """Add fields to key-table. + + :param str code: Annotation expression. + + :param str key_names: Comma separated list of field names to be treated as a key + + :rtype: :class:`.KeyTable` + """ + try: + return KeyTable(self.hc, self.jkt.annotate(code, key_names)) + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def join(self, right, how='inner'): + """Join two key-tables together. Both key-tables must have identical key schemas + and non-overlapping field names. + + :param right: Key-table to join + :type right: :class:`.KeyTable` + + :param str how: Method for joining two tables together. One of "inner", "outer", "left", "right". + + :rtype: :class:`.KeyTable` + """ + try: + return KeyTable(self.hc, self.jkt.join(right.jkt, how)) + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def aggregate_by_key(self, key_code, agg_code): + """Group by key condition and aggregate results + + :param key_code: Named expression(s) for which fields are keys. + :type key_code: str or list of str + + :param agg_code: Named aggregation expression(s). 
+    :type agg_code: str or list of str + +        :rtype: :class:`.KeyTable` + """ + if isinstance(key_code, list): + key_code = ", ".join([str(l) for l in key_code]) + + if isinstance(agg_code, list): + agg_code = ", ".join([str(l) for l in agg_code]) + + try: + return KeyTable(self.hc, self.jkt.aggregate(key_code, agg_code)) + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def forall(self, code): + """Tests whether a condition is true for all rows + + :param str code: Boolean expression + + :rtype: bool + """ + try: + return self.jkt.forall(code) + except Py4JJavaError as e: + self._raise_py4j_exception(e) + + def exists(self, code): + """Tests whether a condition is true for any row + + :param str code: Boolean expression + + :rtype: bool + """ + try: + return self.jkt.exists(code) + except Py4JJavaError as e: + self._raise_py4j_exception(e) diff --git a/python/pyhail/tests.py b/python/pyhail/tests.py index 158d77220e7..e83627ddff1 100644 --- a/python/pyhail/tests.py +++ b/python/pyhail/tests.py @@ -6,7 +6,7 @@ import unittest from pyspark import SparkContext -from pyhail import HailContext +from pyhail import HailContext, TextTableConfig class ContextTests(unittest.TestCase): @@ -208,6 +208,45 @@ def test_dataset(self): sample2_split.variant_qc().print_schema() sample2.variants_to_pandas() - + + sample_split.annotate_variants_expr("va.nHet = gs.filter(g => g.isHet).count()") + + kt = sample_split.aggregate_by_key("Variant = v", "nHet = g.map(g => g.isHet.toInt).sum().toLong") + + def test_keytable(self): + # Import + kt = self.hc.import_keytable(self.test_resources + '/sampleAnnotations.tsv', 'Sample', config = TextTableConfig(impute = True)) + kt2 = self.hc.import_keytable(self.test_resources + '/sampleAnnotations2.tsv', 'Sample', config = TextTableConfig(impute = True)) + + # Variables + self.assertEqual(kt.nfields(), 3) + self.assertEqual(kt.key_names()[0], "Sample") + self.assertEqual(kt.field_names()[2], "qPhen") + self.assertEqual(kt.nrows(), 100) + 
kt.schema() + + # Export + kt.export('/tmp/testExportKT.tsv') + + # Filter, Same + ktcase = kt.filter('Status == "CASE"', True) + ktcase2 = kt.filter('Status == "CTRL"', False) + self.assertTrue(ktcase.same(ktcase2)) + + # Annotate + (kt.annotate('X = Status', 'Sample, Status') + .nrows()) + + # Join + kt.join(kt2, 'left').nrows() + + # AggregateByKey + (kt.aggregate_by_key("Status = Status", "Sum = qPhen.sum()") + .nrows()) + + # Forall, Exists + self.assertFalse(kt.forall('Status == "CASE"')) + self.assertTrue(kt.exists('Status == "CASE"')) + def tearDown(self): self.sc.stop() diff --git a/python/pyhail/type.py b/python/pyhail/type.py new file mode 100644 index 00000000000..ea85d1f4c84 --- /dev/null +++ b/python/pyhail/type.py @@ -0,0 +1,12 @@ + +class Type(object): + """Type of values.""" + + def __init__(self, jtype): + self.jtype = jtype + + def __repr__(self): + return self.jtype.toString() + + def __str__(self): + return self.jtype.toPrettyString(False, False) diff --git a/python/pyhail/utils.py b/python/pyhail/utils.py new file mode 100644 index 00000000000..85ee9ff36a9 --- /dev/null +++ b/python/pyhail/utils.py @@ -0,0 +1,47 @@ + +class TextTableConfig(object): + """Configuration for delimited (text table) files. + + :param bool noheader: File has no header and columns should be indicated by `_1, _2, ... 
_N' (0-indexed) + + :param bool impute: Impute column types from the file + + :param comment: Skip lines beginning with the given pattern + :type comment: str or None + + :param str delimiter: Field delimiter regex + + :param str missing: Specify identifier to be treated as missing + + :param types: Define types of fields in annotations files + :type types: str or None + """ + def __init__(self, noheader=False, impute=False, + comment=None, delimiter="\t", missing="NA", types=None): + self.noheader = noheader + self.impute = impute + self.comment = comment + self.delimiter = delimiter + self.missing = missing + self.types = types + + def __str__(self): + res = ["--comment", self.comment, "--delimiter", self.delimiter, + "--missing", self.missing] + + if self.noheader: + res.append("--no-header") + + if self.impute: + res.append("--impute") + + return " ".join(res) + + def to_java(self, hc): + """Convert to Java TextTableConfiguration object. + + :param :class:`.HailContext` The Hail context. 
+ """ + return hc.jvm.org.broadinstitute.hail.utils.TextTableConfiguration.apply(self.types, self.comment, + self.delimiter, self.missing, + self.noheader, self.impute) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala index af89ccb2842..e94a6689865 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala @@ -58,7 +58,7 @@ object AggregateIntervals extends Command { ec.set(1, vds.globalAnnotation) aggregationEC.set(1, vds.globalAnnotation) - val (header, _, f) = Parser.parseNamedArgs(cond, ec) + val (header, _, f) = Parser.parseExportArgs(cond, ec) if (header.isEmpty) fatal("this module requires one or more named expr arguments") @@ -67,8 +67,6 @@ object AggregateIntervals extends Command { val zvf = () => zVals.indices.map(zVals).toArray - val variantAggregations = Aggregators.buildVariantAggregations(vds, aggregationEC) - val iList = IntervalListAnnotator.read(options.input, sc.hadoopConfiguration) val iListBc = sc.broadcast(iList) diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala index f0b9e62a9e4..6359054501f 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala @@ -48,7 +48,7 @@ object AnnotateGlobalExpr extends Command { aggECS.set(1, vds.globalAnnotation) aggECV.set(1, vds.globalAnnotation) - val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Annotation.GLOBAL_HEAD) + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Some(Annotation.GLOBAL_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] @@ -82,7 +82,7 @@ object AnnotateGlobalExpr extends Command { val ga = inserters .zip(fns.map(_ ())) 
.foldLeft(vds.globalAnnotation) { case (a, (ins, res)) => - ins(a, res) + ins(a, Option(res)) } state.copy(vds = vds.copy( diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala index 4bbfa9663a4..c2ea5b0eb59 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala @@ -47,7 +47,7 @@ object AnnotateSamplesExpr extends Command { ec.set(2, vds.globalAnnotation) aggregationEC.set(4, vds.globalAnnotation) - val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Annotation.SAMPLE_HEAD) + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Some(Annotation.SAMPLE_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finalType = parseTypes.foldLeft(vds.saSignature) { case (sas, (ids, signature)) => @@ -70,7 +70,7 @@ object AnnotateSamplesExpr extends Command { fns.zip(inserters) .foldLeft(sa) { case (sa, (fn, inserter)) => - inserter(sa, fn()) + inserter(sa, Option(fn())) } } state.copy(vds = vds.copy( diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala index d10b20810f7..dcfce9e75f9 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala @@ -49,7 +49,7 @@ object AnnotateVariantsExpr extends Command { ec.set(2, vds.globalAnnotation) aggregationEC.set(4, vds.globalAnnotation) - val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Annotation.VARIANT_HEAD) + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Some(Annotation.VARIANT_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finalType = parseTypes.foldLeft(vds.vaSignature) { case (vas, (ids, signature)) => @@ -67,7 +67,7 @@ object 
AnnotateVariantsExpr extends Command { aggregateOption.foreach(f => f(v, va, gs)) fns.zip(inserters) .foldLeft(va) { case (va, (fn, inserter)) => - inserter(va, fn()) + inserter(va, Option(fn())) } }.copy(vaSignature = finalType) state.copy(vds = annotated) diff --git a/src/main/scala/org/broadinstitute/hail/driver/Command.scala b/src/main/scala/org/broadinstitute/hail/driver/Command.scala index 436f6420500..e40ef4554d5 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/Command.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/Command.scala @@ -254,7 +254,7 @@ abstract class Command { fatal("this module does not support multiallelic variants.\n Please run `splitmulti' first.") else { if (requiresVDS) - log.info(s"sparkinfo: $name, ${state.vds.nPartitions} partitions, ${state.vds.rdd.getStorageLevel.toReadableString()}") + log.info(s"sparkinfo: $name, ${ state.vds.nPartitions } partitions, ${ state.vds.rdd.getStorageLevel.toReadableString() }") run(state, options) } } diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala index b51ef90d540..f9ef1f99e75 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala @@ -59,7 +59,7 @@ object ExportGenotypes extends Command with TextExporter { val ec = EvalContext(symTab) ec.set(5, vds.globalAnnotation) - val (header, ts, f) = Parser.parseExportArgs(cond, ec) + val (header, ts, f) = Parser.parseNamedArgs(cond, ec) Option(options.typesFile).foreach { file => val typeInfo = header diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala index 104f3db7724..f7b898f77a6 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala @@ -1,6 +1,7 @@ package 
org.broadinstitute.hail.driver import org.apache.spark.storage.StorageLevel +import org.broadinstitute.hail.expr.{BaseAggregable, EvalContext, Parser, TBoolean, TDouble, TGenotype, TSample, TString, Type} import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.io.plink.ExportBedBimFam import org.kohsuke.args4j.{Option => Args4jOption} @@ -11,6 +12,10 @@ object ExportPlink extends Command { @Args4jOption(required = true, name = "-o", aliases = Array("--output"), usage = "Output file base (will generate .bed, .bim, .fam)") var output: String = _ + + @Args4jOption(name = "-f", aliases = Array("--fam-expr"), + usage = "Expression for .fam file values, in sample context only (global, s, sa in scope), assignable fields: famID, id, matID, patID (String), isFemale (Boolean), isCase (Boolean) or qPheno (Double)") + var famExpr: String = "id = s.id" } def newOptions = new Options @@ -26,12 +31,64 @@ object ExportPlink extends Command { def run(state: State, options: Options): State = { val vds = state.vds + val symTab = Map( + "s" -> (0, TSample), + "sa" -> (1, vds.saSignature), + "global" -> (2, vds.globalSignature)) + + val ec = EvalContext(symTab) + ec.set(2, vds.globalAnnotation) + + type Formatter = (() => Option[Any]) => () => String + + val formatID: Formatter = f => () => f().map(_.asInstanceOf[String]).getOrElse("0") + val formatIsFemale: Formatter = f => () => f().map { + _.asInstanceOf[Boolean] match { + case true => "2" + case false => "1" + } + }.getOrElse("0") + val formatIsCase: Formatter = f => () => f().map { + _.asInstanceOf[Boolean] match { + case true => "2" + case false => "1" + } + }.getOrElse("-9") + val formatQPheno: Formatter = f => () => f().map(_.toString).getOrElse("-9") + + val famColumns: Map[String, (Type, Int, Formatter)] = Map( + "famID" -> (TString, 0, formatID), + "id" -> (TString, 1, formatID), + "patID" -> (TString, 2, formatID), + "matID" -> (TString, 3, formatID), + "isFemale" -> (TBoolean, 4, formatIsFemale), + "qPheno" 
-> (TDouble, 5, formatQPheno), + "isCase" -> (TBoolean, 5, formatIsCase)) + + val exprs = Parser.parseNamedExprs(options.famExpr, ec) + + val famFns: Array[() => String] = Array( + () => "0", () => "0", () => "0", () => "0", () => "-9", () => "-9") + + exprs.foreach { case (name, t, f) => + famColumns.get(name) match { + case Some((colt, i, formatter)) => + if (colt != t) + fatal("invalid type for .fam file column $h: expected $colt, got $t") + famFns(i) = formatter(f) + + case None => + fatal(s"no .fam file column $name") + } + } + val spaceRegex = """\s+""".r val badSampleIds = vds.sampleIds.filter(id => spaceRegex.findFirstIn(id).isDefined) if (badSampleIds.nonEmpty) { - fatal(s"""Found ${ badSampleIds.length } sample IDs with whitespace + fatal( + s"""Found ${ badSampleIds.length } sample IDs with whitespace | Please run `renamesamples' to fix this problem before exporting to plink format - | Bad sample IDs: @1 """.stripMargin, badSampleIds) + | Bad sample IDs: @1 """.stripMargin, badSampleIds) } val bedHeader = Array[Byte](108, 27, 1) @@ -49,8 +106,11 @@ object ExportPlink extends Command { plinkRDD.unpersist() val famRows = vds - .sampleIds - .map(ExportBedBimFam.makeFamRow) + .sampleIdsAndAnnotations + .map { case (s, sa) => + ec.setAll(s, sa) + famFns.map(_()).mkString("\t") + } state.hadoopConf.writeTextFile(options.output + ".fam")(out => famRows.foreach(line => { diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala index ea5d75f974e..333c2f7a7ba 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala @@ -58,7 +58,7 @@ object ExportSamples extends Command with TextExporter { ec.set(2, vds.globalAnnotation) aggregationEC.set(5, vds.globalAnnotation) - val (header, types, f) = Parser.parseExportArgs(cond, ec) + val (header, types, f) = Parser.parseNamedArgs(cond, ec) 
Option(options.typesFile).foreach { file => val typeInfo = header .getOrElse(types.indices.map(i => s"_$i").toArray) diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala index 0e2aaaa7b20..46389d657c4 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala @@ -48,6 +48,7 @@ object ExportVariants extends Command with TextExporter { "sa" -> (3, vds.saSignature), "g" -> (4, TGenotype), "global" -> (5, vds.globalSignature))) + val symTab = Map( "v" -> (0, TVariant), "va" -> (1, vds.vaSignature), @@ -59,7 +60,7 @@ object ExportVariants extends Command with TextExporter { ec.set(2, vds.globalAnnotation) aggregationEC.set(5, vds.globalAnnotation) - val (header, types, f) = Parser.parseExportArgs(cond, ec) + val (header, types, f) = Parser.parseNamedArgs(cond, ec) Option(options.typesFile).foreach { file => val typeInfo = header diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala index 482b7ab9bf0..9c039890401 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala @@ -153,7 +153,7 @@ object ExportVariantsCass extends Command { val vEC = EvalContext(vSymTab) val vA = vEC.a - val (vHeader, vTypes, vf) = Parser.parseNamedArgs(vCond, vEC) + val (vHeader, vTypes, vf) = Parser.parseExportArgs(vCond, vEC) val gSymTab = Map( "v" -> (0, TVariant), @@ -164,7 +164,7 @@ object ExportVariantsCass extends Command { val gEC = EvalContext(gSymTab) val gA = gEC.a - val (gHeader, gTypes, gf) = Parser.parseNamedArgs(gCond, gEC) + val (gHeader, gTypes, gf) = Parser.parseExportArgs(gCond, gEC) val symTab = Map( "v" -> (0, TVariant), diff --git 
a/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala b/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala index 6939be79ad3..37f7f38332c 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala @@ -72,7 +72,7 @@ object FilterAlleles extends Command { "v" -> (0, TVariant), "va" -> (1, state.vds.vaSignature), "aIndices" -> (2, TArray(TInt)))) - val (types, generators) = Parser.parseAnnotationArgs(options.annotation, annotationEC, Annotation.VARIANT_HEAD) + val (types, generators) = Parser.parseAnnotationArgs(options.annotation, annotationEC, Some(Annotation.VARIANT_HEAD)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finalType = types.foldLeft(state.vds.vaSignature) { case (vas, (path, signature)) => val (newVas, i) = vas.insert(signature, path) @@ -120,7 +120,7 @@ object FilterAlleles extends Command { def updateAnnotation(v: Variant, va: Annotation, newToOld: IndexedSeq[Int]): Annotation = { annotationEC.setAll(v, va, newToOld) - generators.zip(inserters).foldLeft(va) { case (va, (fn, inserter)) => inserter(va, fn()) } + generators.zip(inserters).foldLeft(va) { case (va, (fn, inserter)) => inserter(va, Option(fn())) } } def updateGenotypes(gs: Iterable[Genotype], oldToNew: Array[Int], newCount: Int): Iterable[Genotype] = { diff --git a/src/main/scala/org/broadinstitute/hail/driver/Main.scala b/src/main/scala/org/broadinstitute/hail/driver/Main.scala index a9aa78c025b..6d5d2d9a317 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/Main.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/Main.scala @@ -34,10 +34,7 @@ object SparkManager { conf.setMaster(local) } - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") _sc = new SparkContext(conf) - _sc.hadoopConfiguration.set("io.compression.codecs", - 
"org.apache.hadoop.io.compress.DefaultCodec,org.broadinstitute.hail.io.compress.BGzipCodec,org.apache.hadoop.io.compress.GzipCodec") } _sc @@ -241,29 +238,6 @@ object Main { sys.exit(1) } - val logProps = new Properties() - if (options.logQuiet) { - logProps.put("log4j.rootLogger", "OFF, stderr") - - logProps.put("log4j.appender.stderr", "org.apache.log4j.ConsoleAppender") - logProps.put("log4j.appender.stderr.Target", "System.err") - logProps.put("log4j.appender.stderr.threshold", "OFF") - logProps.put("log4j.appender.stderr.layout", "org.apache.log4j.PatternLayout") - logProps.put("log4j.appender.stderr.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n") - } else { - logProps.put("log4j.rootLogger", "INFO, logfile") - - logProps.put("log4j.appender.logfile", "org.apache.log4j.FileAppender") - logProps.put("log4j.appender.logfile.append", options.logAppend.toString) - logProps.put("log4j.appender.logfile.file", options.logFile) - logProps.put("log4j.appender.logfile.threshold", "INFO") - logProps.put("log4j.appender.logfile.layout", "org.apache.log4j.PatternLayout") - logProps.put("log4j.appender.logfile.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n") - } - - LogManager.resetConfiguration() - PropertyConfigurator.configure(logProps) - if (splitArgs.length == 1) fail(s"hail: fatal: no commands given") @@ -288,23 +262,9 @@ object Main { val sc = SparkManager.createSparkContext("Hail", Option(options.master), "local[*]") - val conf = sc.getConf - conf.set("spark.ui.showConsoleProgress", "false") - val progressBar = ProgressBarBuilder.build(sc) - - conf.set("spark.sql.parquet.compression.codec", options.parquetCompression) - - sc.hadoopConfiguration.setLong("mapreduce.input.fileinputformat.split.minsize", options.blockSize * 1024L * 1024L) - - /* `DataFrame.write` writes one file per partition. Without this, read will split files larger than the default - * parquet block size into multiple partitions. 
This causes `OrderedRDD` to fail since the per-partition range - * no longer line up with the RDD partitions. - * - * For reasons we don't understand, the DataFrame code uses `SparkHadoopUtil.get.conf` instead of the Hadoop - * configuration in the SparkContext. Set both for consistency. - */ - SparkHadoopUtil.get.conf.setLong("parquet.block.size", 1099511627776L) - sc.hadoopConfiguration.setLong("parquet.block.size", 1099511627776L) + configure(sc, logFile = options.logFile, quiet = options.logQuiet, append = options.logAppend, + parquetCompression = options.parquetCompression, blockSize = options.blockSize, + branchingFactor = options.branchingFactor, tmpDir = options.tmpDir) val sqlContext = SparkManager.createSQLContext() @@ -313,12 +273,9 @@ object Main { sc.addJar(jar) HailConfiguration.installDir = new File(jar).getParent + "/.." - HailConfiguration.tmpDir = options.tmpDir - HailConfiguration.branchingFactor = options.branchingFactor runCommands(sc, sqlContext, invocations) sc.stop() - progressBar.stop() } } diff --git a/src/main/scala/org/broadinstitute/hail/driver/VEP.scala b/src/main/scala/org/broadinstitute/hail/driver/VEP.scala index 74e6c9a664a..c7598b61ed6 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/VEP.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/VEP.scala @@ -211,7 +211,7 @@ object VEP extends Command { val rootType = vds.vaSignature.getOption(root) .filter { t => - val r = t == vepSignature + val r = t == (if(csq) TString else vepSignature) if (!r) { if (options.force) warn(s"type for ${ options.root } does not match vep signature, overwriting.") diff --git a/src/main/scala/org/broadinstitute/hail/driver/package.scala b/src/main/scala/org/broadinstitute/hail/driver/package.scala index d564e84199b..a2dced1a88c 100644 --- a/src/main/scala/org/broadinstitute/hail/driver/package.scala +++ b/src/main/scala/org/broadinstitute/hail/driver/package.scala @@ -1,8 +1,15 @@ package org.broadinstitute.hail +import java.io.File import 
java.util +import java.util.Properties + +import org.apache.log4j.{Level, LogManager, PropertyConfigurator} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{ProgressBarBuilder, SparkContext} import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.variant.VariantDataset + import scala.collection.JavaConverters._ package object driver { @@ -39,4 +46,65 @@ package object driver { CountResult(vds.nSamples, nVariants, nCalled) } + + def configure(sc: SparkContext, logFile: String, quiet: Boolean, append: Boolean, + parquetCompression: String, blockSize: Long, branchingFactor: Int, tmpDir: String) { + require(blockSize > 0) + require(branchingFactor > 0) + + val logProps = new Properties() + if (quiet) { + logProps.put("log4j.rootLogger", "OFF, stderr") + logProps.put("log4j.appender.stderr", "org.apache.log4j.ConsoleAppender") + logProps.put("log4j.appender.stderr.Target", "System.err") + logProps.put("log4j.appender.stderr.threshold", "OFF") + logProps.put("log4j.appender.stderr.layout", "org.apache.log4j.PatternLayout") + logProps.put("log4j.appender.stderr.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n") + } else { + logProps.put("log4j.rootLogger", "INFO, logfile") + logProps.put("log4j.appender.logfile", "org.apache.log4j.FileAppender") + logProps.put("log4j.appender.logfile.append", append.toString) + logProps.put("log4j.appender.logfile.file", logFile) + logProps.put("log4j.appender.logfile.threshold", "INFO") + logProps.put("log4j.appender.logfile.layout", "org.apache.log4j.PatternLayout") + logProps.put("log4j.appender.logfile.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n") + } + + LogManager.resetConfiguration() + PropertyConfigurator.configure(logProps) + + val conf = sc.getConf + + conf.set("spark.ui.showConsoleProgress", "false") + val progressBar = ProgressBarBuilder.build(sc) + + sc.hadoopConfiguration.set( + "io.compression.codecs", + 
"org.apache.hadoop.io.compress.DefaultCodec," + + "org.broadinstitute.hail.io.compress.BGzipCodec," + + "org.apache.hadoop.io.compress.GzipCodec") + + conf.set("spark.sql.parquet.compression.codec", parquetCompression) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + + sc.hadoopConfiguration.setLong("mapreduce.input.fileinputformat.split.minsize", blockSize * 1024L * 1024L) + + /* `DataFrame.write` writes one file per partition. Without this, read will split files larger than the default + * parquet block size into multiple partitions. This causes `OrderedRDD` to fail since the per-partition range + * no longer line up with the RDD partitions. + * + * For reasons we don't understand, the DataFrame code uses `SparkHadoopUtil.get.conf` instead of the Hadoop + * configuration in the SparkContext. Set both for consistency. + */ + SparkHadoopUtil.get.conf.setLong("parquet.block.size", 1099511627776L) + sc.hadoopConfiguration.setLong("parquet.block.size", 1099511627776L) + + + val jar = getClass.getProtectionDomain.getCodeSource.getLocation.toURI.getPath + sc.addJar(jar) + + HailConfiguration.installDir = new File(jar).getParent + "/.." 
+ HailConfiguration.tmpDir = tmpDir + HailConfiguration.branchingFactor = branchingFactor + } } diff --git a/src/main/scala/org/broadinstitute/hail/expr/AST.scala b/src/main/scala/org/broadinstitute/hail/expr/AST.scala index 9967d896468..9018dd3de82 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/AST.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/AST.scala @@ -17,7 +17,9 @@ import scala.language.existentials import scala.reflect.ClassTag import org.broadinstitute.hail.utils.EitherIsAMonad._ -case class EvalContext(st: SymbolTable, a: ArrayBuffer[Any], aggregationFunctions: ArrayBuffer[Aggregator]) { +case class EvalContext(st: SymbolTable, + a: ArrayBuffer[Any], + aggregationFunctions: ArrayBuffer[((Any) => Any, Aggregator)]) { def setAll(args: Any*) { args.zipWithIndex.foreach { case (arg, i) => a(i) = arg } @@ -31,7 +33,7 @@ case class EvalContext(st: SymbolTable, a: ArrayBuffer[Any], aggregationFunction object EvalContext { def apply(symTab: SymbolTable): EvalContext = { val a = new ArrayBuffer[Any]() - val af = new ArrayBuffer[Aggregator]() + val af = new ArrayBuffer[((Any) => Any, Aggregator)]() for ((i, t) <- symTab.values) { if (i >= 0) a += null @@ -666,6 +668,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST localA(localIdx) = a fn() } + MappedAggregable(agg, t, mapF) case error => parseError(s"method `$method' expects a lambda function (param => Any), got invalid mapping (param => $error)") @@ -1143,7 +1146,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f - agg.ec.aggregationFunctions += new CountAggregator(aggF, localIdx) + agg.ec.aggregationFunctions += ((aggF, new CountAggregator(localIdx))) () => localA(localIdx) case (agg: TAggregable, "fraction", Array(Lambda(_, param, body))) => @@ -1157,7 +1160,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f - agg.ec.aggregationFunctions += new 
FractionAggregator(aggF, localIdx, localA, bodyFn, lambdaIdx) + agg.ec.aggregationFunctions += ((aggF, new FractionAggregator(localIdx, localA, bodyFn, lambdaIdx))) () => localA(localIdx) case (agg: TAggregable, "stats", Array()) => @@ -1168,7 +1171,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val t = agg.elementType val aggF = agg.f - agg.ec.aggregationFunctions += new StatAggregator(aggF, localIdx) + agg.ec.aggregationFunctions += ((aggF, new StatAggregator(localIdx))) val getOp = (a: Any) => { val sc = a.asInstanceOf[StatCounter] @@ -1193,7 +1196,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f - agg.ec.aggregationFunctions += new CounterAggregator(aggF, localIdx) + agg.ec.aggregationFunctions += ((aggF, new CounterAggregator(localIdx))) () => { val m = localA(localIdx).asInstanceOf[mutable.HashMap[Any, Long]] @@ -1211,7 +1214,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val vf = vAST.eval(ec) - agg.ec.aggregationFunctions += new CallStatsAggregator(aggF, localIdx, vf) + agg.ec.aggregationFunctions += ((aggF, new CallStatsAggregator(localIdx, vf))) () => { val cs = localA(localIdx).asInstanceOf[CallStats] @@ -1265,7 +1268,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f - agg.ec.aggregationFunctions += new HistAggregator(aggF, localIdx, indices) + agg.ec.aggregationFunctions += ((aggF, new HistAggregator(localIdx, indices))) () => localA(localIdx).asInstanceOf[HistogramCombiner].toAnnotation @@ -1276,7 +1279,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f - agg.ec.aggregationFunctions += new CollectAggregator(aggF, localIdx) + agg.ec.aggregationFunctions += ((aggF, new CollectAggregator(localIdx))) () => localA(localIdx).asInstanceOf[ArrayBuffer[Any]].toIndexedSeq case (agg: TAggregable, "infoScore", Array()) => @@ 
-1287,7 +1290,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val localPos = posn val aggF = agg.f - agg.ec.aggregationFunctions += new InfoScoreAggregator(aggF, localIdx) + agg.ec.aggregationFunctions += ((aggF, new InfoScoreAggregator(localIdx))) () => localA(localIdx).asInstanceOf[InfoScoreCombiner].asAnnotation case (agg: TAggregable, "inbreeding", Array(mafAST)) => @@ -1299,7 +1302,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f val maf = mafAST.eval(agg.ec) - agg.ec.aggregationFunctions += new InbreedingAggregator(aggF, localIdx, maf) + agg.ec.aggregationFunctions += ((aggF, new InbreedingAggregator(localIdx, maf))) () => localA(localIdx).asInstanceOf[InbreedingCombiner].asAnnotation case (agg: TAggregable, "hardyWeinberg", Array()) => @@ -1310,7 +1313,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val localPos = posn val aggF = agg.f - agg.ec.aggregationFunctions += new HWEAggregator(aggF, localIdx) + agg.ec.aggregationFunctions += ((aggF, new HWEAggregator(localIdx))) () => localA(localIdx).asInstanceOf[HWECombiner].asAnnotation case (agg: TAggregable, "sum", Array()) => @@ -1322,8 +1325,8 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST val aggF = agg.f (`type`: @unchecked) match { - case TDouble => agg.ec.aggregationFunctions += new SumAggregator(aggF, localIdx) - case TArray(TDouble) => agg.ec.aggregationFunctions += new SumArrayAggregator(aggF, localIdx, localPos) + case TDouble => agg.ec.aggregationFunctions += ((aggF, new SumAggregator(localIdx))) + case TArray(TDouble) => agg.ec.aggregationFunctions += ((aggF, new SumArrayAggregator(localIdx, localPos))) } () => localA(localIdx) @@ -1649,9 +1652,10 @@ case class IndexOp(posn: Position, f: AST, idx: AST) extends AST(posn, Array(f, } catch { case e: java.lang.IndexOutOfBoundsException => ParserUtils.error(localPos, - s"""Tried to access 
index [$i] on array ${ JsonMethods.compact(localT.toJSON(a)) } of length ${ a.length } + s"""Invalid array index: tried to access index [$i] on array `@1' of length ${ a.length } | Hint: All arrays in Hail are zero-indexed (`array[0]' is the first element) - | Hint: For accessing `A'-numbered info fields in split variants, `va.info.field[va.aIndex - 1]' is correct""".stripMargin) + | Hint: For accessing `A'-numbered info fields in split variants, `va.info.field[va.aIndex - 1]' is correct""".stripMargin, + JsonMethods.compact(localT.toJSON(a))) case e: Throwable => throw e }) @@ -1697,6 +1701,7 @@ case class SymRef(posn: Position, symbol: String) extends AST(posn) { def eval(ec: EvalContext): () => Any = { val localI = ec.st(symbol)._1 val localA = ec.a + if (localI < 0) () => 0 // FIXME placeholder else diff --git a/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala b/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala index cbb6675d631..c7bc99ea53f 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala @@ -356,8 +356,18 @@ object FunctionRegistry { register("runif", { (min: Double, max: Double) => min + (max - min) * math.random }) register("rnorm", { (mean: Double, sd: Double) => mean + sd * scala.util.Random.nextGaussian() }) + register("pnorm", { (x: Double) => pnorm(x) }) + register("qnorm", { (p: Double) => qnorm(p) }) + + register("pchisq1tail", { (x: Double) => chiSquaredTail(1.0, x) }) + register("qchisq1tail", { (p: Double) => inverseChiSquaredTailOneDF(p) }) + registerConversion((x: Int) => x.toDouble, priority = 2) registerConversion { (x: Long) => x.toDouble } registerConversion { (x: Int) => x.toLong } registerConversion { (x: Float) => x.toDouble } + + register("gtj", (i: Int) => Genotype.gtPair(i).j) + register("gtk", (i: Int) => Genotype.gtPair(i).k) + register("gtIndex", (j: Int, k: Int) => Genotype.gtIndex(j, k)) } diff 
--git a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala index 2291d823750..fed27b945b7 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala @@ -27,7 +27,7 @@ trait JoinAnnotator { } def buildInserter(code: String, t: Type, ec: EvalContext, expectedHead: String): (Type, Inserter) = { - val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, expectedHead) + val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, Some(expectedHead)) val inserterBuilder = mutable.ArrayBuilder.make[Inserter] val finaltype = parseTypes.foldLeft(t) { case (t, (ids, signature)) => @@ -45,7 +45,7 @@ trait JoinAnnotator { val queries = fns.map(_ ()) var newAnnotation = left queries.indices.foreach { i => - newAnnotation = inserters(i)(newAnnotation, queries(i)) + newAnnotation = inserters(i)(newAnnotation, Option(queries(i))) } newAnnotation } diff --git a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala index 3b3713fe8dc..fbe15682002 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala @@ -20,6 +20,17 @@ object ParserUtils { lineContents.take(pos.column - 1).map { c => if (c == '\t') c else ' ' } }^""".stripMargin) } + + def error(pos: Position, msg: String, tr: Truncatable): Nothing = { + val lineContents = pos.longString.split("\n").head + val prefix = s":${ pos.line }:" + fatal( + s"""$msg + |$prefix$lineContents + |${ " " * prefix.length }${ + lineContents.take(pos.column - 1).map { c => if (c == '\t') c else ' ' } + }^""".stripMargin, tr) + } } object Parser extends JavaTokenParsers { @@ -64,13 +75,10 @@ object Parser extends JavaTokenParsers { } def parseIdentifierList(code: String): Array[String] = { - if (code.matches("""\s*""")) - Array.empty[String] - 
else - parseAll(identifierList, code) match { - case Success(result, _) => result - case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) - } + parseAll(identifierList, code) match { + case Success(result, _) => result + case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) + } } def withPos[T](p: => Parser[T]): Parser[Positioned[T]] = @@ -83,7 +91,7 @@ object Parser extends JavaTokenParsers { } } - def parseExportArgs(code: String, ec: EvalContext): (Option[Array[String]], Array[Type], () => Array[String]) = { + def parseNamedArgs(code: String, ec: EvalContext): (Option[Array[String]], Array[Type], () => Array[String]) = { val result = parseAll(export_args, code) match { case Success(r, _) => r case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) @@ -133,8 +141,8 @@ object Parser extends JavaTokenParsers { (someIf(names.nonEmpty, names), tb.result(), () => computations.flatMap(_ ())) } - def parseNamedArgs(code: String, ec: EvalContext): (Array[String], Array[Type], () => Array[String]) = { - val (headerOption, ts, f) = parseExportArgs(code, ec) + def parseExportArgs(code: String, ec: EvalContext): (Array[String], Array[Type], () => Array[String]) = { + val (headerOption, ts, f) = parseNamedArgs(code, ec) val header = headerOption match { case Some(h) => h case None => fatal( @@ -144,22 +152,24 @@ object Parser extends JavaTokenParsers { (header, ts, f) } - def parseAnnotationArgs(code: String, ec: EvalContext, expectedHead: String): (Array[(List[String], Type)], Array[() => Option[Any]]) = { + def parseAnnotationArgs(code: String, ec: EvalContext, expectedHead: Option[String]): (Array[(List[String], Type)], Array[() => Any]) = { val arr = parseAll(annotationExpressions, code) match { case Success(result, _) => result.asInstanceOf[Array[(List[String], AST)]] case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) } def checkType(l: List[String], t: BaseType): Type = { - if (l.head == expectedHead) - t match { - case t: Type => t - 
case bt => fatal( - s"""Got invalid type `$t' from the result of `${ l.mkString(".") }'""".stripMargin) - } else fatal( - s"""invalid annotation path `${ l.map(prettyIdentifier).mkString(".") }' - | Path should begin with `$expectedHead' + if (expectedHead.exists(l.head != _)) + fatal( + s"""invalid annotation path `${ l.map(prettyIdentifier).mkString(".") }' + | Path should begin with `$expectedHead' """.stripMargin) + + t match { + case t: Type => t + case bt => fatal( + s"""Got invalid type `$t' from the result of `${ l.mkString(".") }'""".stripMargin) + } } val all = arr.map { @@ -167,8 +177,10 @@ object Parser extends JavaTokenParsers { ast.typecheck(ec) val t = checkType(path, ast.`type`) val f = ast.eval(ec) - ((path.tail, t), () => Option(f())) + val name = if (expectedHead.isDefined) path.tail else path + ((name, t), () => f()) } + (all.map(_._1), all.map(_._2)) } @@ -186,6 +198,19 @@ object Parser extends JavaTokenParsers { path.tail } + def parseNamedExprs(code: String, ec: EvalContext): Array[(String, BaseType, () => Option[Any])] = { + val parsed = parseAll(named_args, code) match { + case Success(result, _) => result.asInstanceOf[Array[(String, AST)]] + case NoSuccess(msg, _) => fatal(msg) + } + + parsed.map { case (name, ast) => + ast.typecheck(ec) + val f = ast.eval(ec) + (name, ast.`type`, () => Option(f())) + } + } + def parseExprs(code: String, ec: EvalContext): (Array[(BaseType, () => Option[Any])]) = { if (code.matches("""\s*""")) @@ -273,7 +298,7 @@ object Parser extends JavaTokenParsers { tsvIdentifier ~ "=" ~ expr ^^ { case id ~ _ ~ expr => (id, expr) } def annotationExpressions: Parser[Array[(List[String], AST)]] = - rep1sep(annotationExpression, ",") ^^ { + repsep(annotationExpression, ",") ^^ { _.toArray } @@ -295,7 +320,7 @@ object Parser extends JavaTokenParsers { def identifier = backtickLiteral | ident - def identifierList: Parser[Array[String]] = rep1sep(identifier, ",") ^^ { + def identifierList: Parser[Array[String]] = 
repsep(identifier, ",") ^^ { _.toArray } diff --git a/src/main/scala/org/broadinstitute/hail/expr/Type.scala b/src/main/scala/org/broadinstitute/hail/expr/Type.scala index e495bb0c429..a0351d29d88 100644 --- a/src/main/scala/org/broadinstitute/hail/expr/Type.scala +++ b/src/main/scala/org/broadinstitute/hail/expr/Type.scala @@ -6,6 +6,7 @@ import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.annotations.{Annotation, AnnotationPathException, _} import org.broadinstitute.hail.check.Arbitrary._ import org.broadinstitute.hail.check.{Gen, _} +import org.broadinstitute.hail.keytable.KeyTable import org.broadinstitute.hail.utils import org.broadinstitute.hail.utils.{Interval, StringEscapeUtils} import org.broadinstitute.hail.variant.{AltAllele, Genotype, Locus, Variant} @@ -249,6 +250,14 @@ abstract class TAggregable extends BaseType { def f: (Any) => Any } +case class KeyTableAggregable(ec: EvalContext, elementType: Type, idx: Int) extends TAggregable { + def f: (Any) => Any = { + (a: Any) => { + KeyTable.annotationToSeq(a, ec.st.size)(idx) + } + } +} + case class BaseAggregable(ec: EvalContext, elementType: Type) extends TAggregable { def f: (Any) => Any = identity } @@ -455,6 +464,8 @@ case class TStruct(fields: IndexedSeq[Field]) extends Type { def selfField(name: String): Option[Field] = fieldIdx.get(name).map(i => fields(i)) + def hasField(name: String): Boolean = fieldIdx.contains(name) + def size: Int = fields.length override def getOption(path: List[String]): Option[Type] = diff --git a/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala b/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala index 005f7cea76f..3dd37176696 100644 --- a/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala +++ b/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala @@ -35,8 +35,4 @@ object ExportBedBimFam { val id = s"${v.contig}:${v.start}:${v.ref}:${v.alt}" 
s"""${v.contig}\t$id\t0\t${v.start}\t${v.alt}\t${v.ref}""" } - - def makeFamRow(s: String): String = { - s"0\t$s\t0\t0\t0\t-9" - } } diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala new file mode 100644 index 00000000000..1f096289cee --- /dev/null +++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala @@ -0,0 +1,382 @@ +package org.broadinstitute.hail.keytable + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.broadinstitute.hail.annotations._ +import org.broadinstitute.hail.check.Gen +import org.broadinstitute.hail.expr._ +import org.broadinstitute.hail.methods.{Aggregators, Filter} +import org.broadinstitute.hail.utils._ +import org.broadinstitute.hail.io.TextExporter + +import scala.collection.mutable +import scala.reflect.ClassTag + + +object KeyTable extends Serializable with TextExporter { + + def importTextTable(sc: SparkContext, path: Array[String], keysStr: String, nPartitions: Int, config: TextTableConfiguration) = { + require(nPartitions > 1) + + val files = sc.hadoopConfiguration.globAll(path) + if (files.isEmpty) + fatal("Arguments referred to no files") + + val keys = Parser.parseIdentifierList(keysStr) + + val (struct, rdd) = + TextTableReader.read(sc)(files, config, nPartitions) + + val invalidKeys = keys.filter(!struct.hasField(_)) + if (invalidKeys.nonEmpty) + fatal(s"invalid keys: ${ invalidKeys.mkString(", ") }") + + KeyTable(rdd.map(_.value), struct, keys) + } + + def annotationToSeq(a: Annotation, nFields: Int) = Option(a).map(_.asInstanceOf[Row].toSeq).getOrElse(Seq.fill[Any](nFields)(null)) + + def setEvalContext(ec: EvalContext, k: Annotation, v: Annotation, nKeys: Int, nValues: Int) = + ec.setAll(annotationToSeq(k, nKeys) ++ annotationToSeq(v, nValues): _*) + + def setEvalContext(ec: EvalContext, a: Annotation, nFields: Int) = + ec.setAll(annotationToSeq(a, nFields): 
_*) + + def toSingleRDD(rdd: RDD[(Annotation, Annotation)], nKeys: Int, nValues: Int): RDD[Annotation] = + rdd.map { case (k, v) => + val x = Annotation.fromSeq(annotationToSeq(k, nKeys) ++ annotationToSeq(v, nValues)) + x + } + + def apply(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[String]): KeyTable = { + val keyFields = signature.fields.filter(fd => keyNames.contains(fd.name)) + val keyIndices = keyFields.map(_.index) + + val valueFields = signature.fields.filterNot(fd => keyNames.contains(fd.name)) + val valueIndices = valueFields.map(_.index) + + assert(keyIndices.toSet.intersect(valueIndices.toSet).isEmpty) + + val nFields = signature.size + + val newKeySignature = TStruct(keyFields.map(fd => (fd.name, fd.`type`)): _*) + val newValueSignature = TStruct(valueFields.map(fd => (fd.name, fd.`type`)): _*) + + val newRDD = rdd.map { a => + val r = annotationToSeq(a, nFields).zipWithIndex + val keyRow = keyIndices.map(i => r(i)._1) + val valueRow = valueIndices.map(i => r(i)._1) + (Annotation.fromSeq(keyRow), Annotation.fromSeq(valueRow)) + } + + KeyTable(newRDD, newKeySignature, newValueSignature) + } +} + +case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, valueSignature: TStruct) { + require(fieldNames.areDistinct()) + + def signature = keySignature.merge(valueSignature)._1 + + def fields = signature.fields + + def keySchema = keySignature.schema + + def valueSchema = valueSignature.schema + + def schema = signature.schema + + def keyNames = keySignature.fields.map(_.name).toArray + + def valueNames = valueSignature.fields.map(_.name).toArray + + def fieldNames = keyNames ++ valueNames + + def nRows = rdd.count() + + def nFields = fields.length + + def nKeys = keySignature.size + + def nValues = valueSignature.size + + def same(other: KeyTable): Boolean = { + if (fields.toSet != other.fields.toSet) { + println(s"signature: this=${ schema } other=${ other.schema }") + false + } else if (keyNames.toSet != 
other.keyNames.toSet) { + println(s"keyNames: this=${ keyNames.mkString(",") } other=${ other.keyNames.mkString(",") }") + false + } else { + val thisFieldNames = valueNames + val otherFieldNames = other.valueNames + + rdd.groupByKey().fullOuterJoin(other.rdd.groupByKey()).forall { case (k, (v1, v2)) => + (v1, v2) match { + case (None, None) => true + case (Some(x), Some(y)) => + val r1 = x.map(r => thisFieldNames.zip(r.asInstanceOf[Row].toSeq).toMap).toSet + val r2 = y.map(r => otherFieldNames.zip(r.asInstanceOf[Row].toSeq).toMap).toSet + val res = r1 == r2 + if (!res) + println(s"k=$k r1=${ r1.mkString(",") } r2=${ r2.mkString(",") }") + res + case _ => + println(s"k=$k v1=$v1 v2=$v2") + false + } + } + } + } + + def mapAnnotations[T](f: (Annotation) => T)(implicit tct: ClassTag[T]): RDD[T] = + KeyTable.toSingleRDD(rdd, nKeys, nValues).map(a => f(a)) + + def mapAnnotations[T](f: (Annotation, Annotation) => T)(implicit tct: ClassTag[T]): RDD[T] = + rdd.map { case (k, v) => f(k, v) } + + def query(code: String): (BaseType, (Annotation, Annotation) => Option[Any]) = { + val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + val nKeysLocal = nKeys + val nValuesLocal = nValues + + val (t, f) = Parser.parse(code, ec) + + val f2: (Annotation, Annotation) => Option[Any] = { + case (k, v) => + KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) + f() + } + + (t, f2) + } + + def querySingle(code: String): (BaseType, Querier) = { + val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + val nFieldsLocal = nFields + + val (t, f) = Parser.parse(code, ec) + + val f2: (Annotation) => Option[Any] = { a => + KeyTable.setEvalContext(ec, a, nFieldsLocal) + f() + } + + (t, f2) + } + + def annotate(cond: String, keysStr: String): KeyTable = { + val ec = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) + + val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, None) + + val inserterBuilder = mutable.ArrayBuilder.make[Inserter] + + val 
finalSignature = parseTypes.foldLeft(signature) { case (vs, (ids, signature)) => + val (s: TStruct, i) = vs.insert(signature, ids) + inserterBuilder += i + s + } + + val inserters = inserterBuilder.result() + + val keys = Parser.parseIdentifierList(keysStr) + + val nFieldsLocal = nFields + + val f: Annotation => Annotation = { a => + KeyTable.setEvalContext(ec, a, nFieldsLocal) + + fns.zip(inserters) + .foldLeft(a) { case (a1, (fn, inserter)) => + inserter(a1, Option(fn())) + } + } + + KeyTable(mapAnnotations(f), finalSignature, keys) + } + + def filter(p: (Annotation, Annotation) => Boolean): KeyTable = + copy(rdd = rdd.filter { case (k, v) => p(k, v) }) + + def filter(cond: String, keep: Boolean): KeyTable = { + val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + val nKeysLocal = nKeys + val nValuesLocal = nValues + + val f: () => Option[Boolean] = Parser.parse[Boolean](cond, ec, TBoolean) + + val p = (k: Annotation, v: Annotation) => { + KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) + Filter.keepThis(f(), keep) + } + + filter(p) + } + + def join(other: KeyTable, joinType: String): KeyTable = { + if (keySignature != other.keySignature) + fatal( + s"""Key signatures must be identical. + |Left signature: ${ keySignature.toPrettyString(compact = true) } + |Right signature: ${ other.keySignature.toPrettyString(compact = true) }""".stripMargin) + + val overlappingFields = valueNames.toSet.intersect(other.valueNames.toSet) + if (overlappingFields.nonEmpty) + fatal( + s"""Fields that are not keys cannot be present in both key-tables. + |Overlapping fields: ${ overlappingFields.mkString(", ") }""".stripMargin) + + joinType match { + case "left" => leftJoin(other) + case "right" => rightJoin(other) + case "inner" => innerJoin(other) + case "outer" => outerJoin(other) + case _ => fatal("Invalid join type specified. 
Choose one of `left', `right', `inner', `outer'") + } + } + + def leftJoin(other: KeyTable): KeyTable = { + require(keySignature == other.keySignature) + + val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) + val newRDD = rdd.leftOuterJoin(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl, vr.orNull)) } + + KeyTable(newRDD, keySignature, newValueSignature) + } + + def rightJoin(other: KeyTable): KeyTable = { + require(keySignature == other.keySignature) + + val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) + val newRDD = rdd.rightOuterJoin(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl.orNull, vr)) } + + KeyTable(newRDD, keySignature, newValueSignature) + } + + def outerJoin(other: KeyTable): KeyTable = { + require(keySignature == other.keySignature) + + val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) + val newRDD = rdd.fullOuterJoin(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl.orNull, vr.orNull)) } + + KeyTable(newRDD, keySignature, newValueSignature) + } + + def innerJoin(other: KeyTable): KeyTable = { + require(keySignature == other.keySignature) + + val (newValueSignature, merger) = valueSignature.merge(other.valueSignature) + val newRDD = rdd.join(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl, vr)) } + + KeyTable(newRDD, keySignature, newValueSignature) + } + + def forall(code: String): Boolean = { + val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + val nKeysLocal = nKeys + val nValuesLocal = nValues + + val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean) + + rdd.forall { case (k, v) => + KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) + f().getOrElse(false) + } + } + + def exists(code: String): Boolean = { + val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*) + val nKeysLocal = nKeys + val nValuesLocal = nValues + + val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean) + + 
rdd.exists { case (k, v) => + KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) + f().getOrElse(false) + } + } + + def export(sc: SparkContext, output: String, typesFile: String) = { + val hConf = sc.hadoopConfiguration + + val ec = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) + + val (header, types, f) = Parser.parseNamedArgs(fieldNames.map(n => n + " = " + n).mkString(","), ec) + + Option(typesFile).foreach { file => + val typeInfo = header + .getOrElse(types.indices.map(i => s"_$i").toArray) + .zip(types) + + KeyTable.exportTypes(file, hConf, typeInfo) + } + + hConf.delete(output, recursive = true) + + val nKeysLocal = nKeys + val nValuesLocal = nValues + + rdd + .mapPartitions { it => + val sb = new StringBuilder() + it.map { case (k, v) => + sb.clear() + KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal) + f().foreachBetween(x => sb.append(x))(sb += '\t') + sb.result() + } + }.writeTable(output, header.map(_.mkString("\t"))) + } + + def aggregate(keyCond: String, aggCond: String): KeyTable = { + + val aggregationEC = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*) + val ec = EvalContext(fields.zipWithIndex.map { case (fd, i) => (fd.name, (-1, KeyTableAggregable(aggregationEC, fd.`type`, i))) }.toMap) + + val (keyNameParseTypes, keyF) = + if (keyCond != null) + Parser.parseAnnotationArgs(keyCond, aggregationEC, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val (aggNameParseTypes, aggF) = Parser.parseAnnotationArgs(aggCond, ec, None) + + val keyNames = keyNameParseTypes.map(_._1.head) + val aggNames = aggNameParseTypes.map(_._1.head) + + val keySignature = TStruct(keyNameParseTypes.map { case (n, t) => (n.head, t) }: _*) + val valueSignature = TStruct(aggNameParseTypes.map { case (n, t) => (n.head, t) }: _*) + + val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) + val aggFunctions = aggregationEC.aggregationFunctions.map(_._1) + + assert(zVals.length == 
aggFunctions.length) + + val seqOp = (array: Array[Aggregator], b: Any) => { + KeyTable.setEvalContext(aggregationEC, b, nFields) + for (i <- array.indices) { + array(i).seqOp(aggFunctions(i)(b)) + } + array + } + + val nFieldsLocal = nFields + + val newRDD = KeyTable.toSingleRDD(rdd, nKeys, nValues).mapPartitions { it => + it.map { a => + KeyTable.setEvalContext(aggregationEC, a, nFieldsLocal) + val key = Annotation.fromSeq(keyF.map(_ ())) + (key, a) + } + }.aggregateByKey(zVals)(seqOp, combOp) + .map { case (k, agg) => + resultOp(agg) + (k, Annotation.fromSeq(aggF.map(_ ()))) + } + + KeyTable(newRDD, keySignature, valueSignature) + } +} \ No newline at end of file diff --git a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala index 00be90c835f..b07b3044e1c 100644 --- a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala +++ b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala @@ -15,7 +15,10 @@ import scala.util.parsing.input.Position object Aggregators { def buildVariantAggregations(vds: VariantDataset, ec: EvalContext): Option[(Variant, Annotation, Iterable[Genotype]) => Unit] = { - val aggregators = ec.aggregationFunctions.toArray + + val aggFunctions = ec.aggregationFunctions.map(_._1) + val aggregators = ec.aggregationFunctions.map(_._2) + val aggregatorA = ec.a if (aggregators.nonEmpty) { @@ -28,29 +31,29 @@ object Aggregators { aggregatorA(0) = v aggregatorA(1) = va (gs, localSamplesBc.value, localAnnotationsBc.value).zipped - .foreach { - case (g, s, sa) => - aggregatorA(2) = s - aggregatorA(3) = sa - baseArray.foreach { - _.seqOp(g) - } + .foreach { case (g, s, sa) => + aggregatorA(2) = s + aggregatorA(3) = sa + (baseArray, aggFunctions).zipped.foreach { case (agg, aggF) => + agg.seqOp(aggF(g)) + } } baseArray.foreach { agg => aggregatorA(agg.idx) = agg.result } } Some(f) - } else None + } else + None } def buildSampleAggregations(vds: 
VariantDataset, ec: EvalContext): Option[(String) => Unit] = { - val aggregators = ec.aggregationFunctions.toArray + val aggFunctions = ec.aggregationFunctions.map(_._1) + val aggregators = ec.aggregationFunctions.map(_._2) val aggregatorA = ec.a if (aggregators.isEmpty) None else { - val localSamplesBc = vds.sampleIdsBc val localAnnotationsBc = vds.sampleAnnotationsBc @@ -73,7 +76,7 @@ object Aggregators { var j = 0 while (j < nAggregations) { - arr(i, j).seqOp(g) + arr(i, j).seqOp(aggFunctions(j)(g)) j += 1 } i += 1 @@ -100,7 +103,8 @@ object Aggregators { def makeFunctions(ec: EvalContext): (Array[Aggregator], (Array[Aggregator], (Any, Any)) => Array[Aggregator], (Array[Aggregator], Array[Aggregator]) => Array[Aggregator], (Array[Aggregator]) => Unit) = { - val aggregators = ec.aggregationFunctions.toArray + val aggFunctions = ec.aggregationFunctions.map(_._1) + val aggregators = ec.aggregationFunctions.map(_._2) val arr = ec.a @@ -116,7 +120,7 @@ object Aggregators { val (aggT, annotation) = b ec.set(0, annotation) for (i <- array.indices) { - array(i).seqOp(aggT) + array(i).seqOp(aggFunctions(i)(aggT)) } array } @@ -135,15 +139,14 @@ object Aggregators { } } -class CountAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[Long] { +class CountAggregator(val idx: Int) extends TypedAggregator[Long] { var _state = 0L def result = _state def seqOp(x: Any) { - val v = f(x) - if (f(x) != null) + if (x != null) _state += 1 } @@ -151,10 +154,10 @@ class CountAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[Lon _state += agg2._state } - def copy() = new CountAggregator(f, idx) + def copy() = new CountAggregator(idx) } -class FractionAggregator(f: (Any) => Any, val idx: Int, localA: ArrayBuffer[Any], bodyFn: () => Any, lambdaIdx: Int) +class FractionAggregator(val idx: Int, localA: ArrayBuffer[Any], bodyFn: () => Any, lambdaIdx: Int) extends TypedAggregator[java.lang.Double] { var _num = 0L @@ -167,10 +170,9 @@ class FractionAggregator(f: 
(Any) => Any, val idx: Int, localA: ArrayBuffer[Any] _num.toDouble / _denom def seqOp(x: Any) { - val r = f(x) - if (r != null) { + if (x != null) { _denom += 1 - localA(lambdaIdx) = r + localA(lambdaIdx) = x if (bodyFn().asInstanceOf[Boolean]) _num += 1 } @@ -181,37 +183,35 @@ class FractionAggregator(f: (Any) => Any, val idx: Int, localA: ArrayBuffer[Any] _denom += agg2._denom } - def copy() = new FractionAggregator(f, idx, localA, bodyFn, lambdaIdx) + def copy() = new FractionAggregator(idx, localA, bodyFn, lambdaIdx) } -class StatAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[StatCounter] { +class StatAggregator(val idx: Int) extends TypedAggregator[StatCounter] { var _state = new StatCounter() def result = _state def seqOp(x: Any) { - val r = f(x) - if (r != null) - _state.merge(DoubleNumericConversion.to(r)) + if (x != null) + _state.merge(DoubleNumericConversion.to(x)) } def combOp(agg2: this.type) { _state.merge(agg2._state) } - def copy() = new StatAggregator(f, idx) + def copy() = new StatAggregator(idx) } -class CounterAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[mutable.HashMap[Any, Long]] { +class CounterAggregator(val idx: Int) extends TypedAggregator[mutable.HashMap[Any, Long]] { var m = new mutable.HashMap[Any, Long] def result = m def seqOp(x: Any) { - val r = f(x) - if (r != null) - m.updateValue(r, 0L, _ + 1) + if (x != null) + m.updateValue(x, 0L, _ + 1) } def combOp(agg2: this.type) { @@ -220,10 +220,10 @@ class CounterAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[m } } - def copy() = new CounterAggregator(f, idx) + def copy() = new CounterAggregator(idx) } -class HistAggregator(f: (Any) => Any, val idx: Int, indices: Array[Double]) +class HistAggregator(val idx: Int, indices: Array[Double]) extends TypedAggregator[HistogramCombiner] { var _state = new HistogramCombiner(indices) @@ -231,90 +231,85 @@ class HistAggregator(f: (Any) => Any, val idx: Int, indices: Array[Double]) def result 
= _state def seqOp(x: Any) { - val r = f(x) - if (r != null) - _state.merge(DoubleNumericConversion.to(r)) + if (x != null) + _state.merge(DoubleNumericConversion.to(x)) } def combOp(agg2: this.type) { _state.merge(agg2._state) } - def copy() = new HistAggregator(f, idx, indices) + def copy() = new HistAggregator(idx, indices) } -class CollectAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[ArrayBuffer[Any]] { +class CollectAggregator(val idx: Int) extends TypedAggregator[ArrayBuffer[Any]] { var _state = new ArrayBuffer[Any] def result = _state def seqOp(x: Any) { - val r = f(x) - if (r != null) - _state += f(x) + if (x != null) + _state += x } def combOp(agg2: this.type) = _state ++= agg2._state - def copy() = new CollectAggregator(f, idx) + def copy() = new CollectAggregator(idx) } -class InfoScoreAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[InfoScoreCombiner] { +class InfoScoreAggregator(val idx: Int) extends TypedAggregator[InfoScoreCombiner] { var _state = new InfoScoreCombiner() def result = _state def seqOp(x: Any) { - val r = f(x) - if (r != null) - _state.merge(r.asInstanceOf[Genotype]) + if (x != null) + _state.merge(x.asInstanceOf[Genotype]) } def combOp(agg2: this.type) { _state.merge(agg2._state) } - def copy() = new InfoScoreAggregator(f, idx) + def copy() = new InfoScoreAggregator(idx) } -class HWEAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[HWECombiner] { +class HWEAggregator(val idx: Int) extends TypedAggregator[HWECombiner] { var _state = new HWECombiner() def result = _state def seqOp(x: Any) { - val r = f(x) - if (r != null) - _state.merge(r.asInstanceOf[Genotype]) + if (x != null) + _state.merge(x.asInstanceOf[Genotype]) } def combOp(agg2: this.type) { _state.merge(agg2._state) } - def copy() = new HWEAggregator(f, idx) + def copy() = new HWEAggregator(idx) } -class SumAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[Double] { +class SumAggregator(val idx: Int) extends 
TypedAggregator[Double] { var _state = 0d def result = _state def seqOp(x: Any) { - val r = f(x) - if (r != null) - _state += DoubleNumericConversion.to(r) + if (x != null) + _state += DoubleNumericConversion.to(x) } def combOp(agg2: this.type) = _state += agg2._state - def copy() = new SumAggregator(f, idx) + def copy() = new SumAggregator(idx) } -class SumArrayAggregator(f: (Any) => Any, val idx: Int, localPos: Position) +class SumArrayAggregator(val idx: Int, localPos: Position) extends TypedAggregator[IndexedSeq[Double]] { var _state: Array[Double] = _ @@ -322,7 +317,7 @@ class SumArrayAggregator(f: (Any) => Any, val idx: Int, localPos: Position) def result = _state def seqOp(x: Any) { - val r = f(x).asInstanceOf[IndexedSeq[Any]] + val r = x.asInstanceOf[IndexedSeq[Any]] if (r != null) { if (_state == null) _state = r.map(x => if (x == null) 0d else DoubleNumericConversion.to(x)).toArray @@ -353,10 +348,10 @@ class SumArrayAggregator(f: (Any) => Any, val idx: Int, localPos: Position) _state(i) += agg2state(i) } - def copy() = new SumArrayAggregator(f, idx, localPos) + def copy() = new SumArrayAggregator(idx, localPos) } -class CallStatsAggregator(f: (Any) => Any, val idx: Int, variantF: () => Any) +class CallStatsAggregator(val idx: Int, variantF: () => Any) extends TypedAggregator[CallStats] { var first = true @@ -378,9 +373,8 @@ class CallStatsAggregator(f: (Any) => Any, val idx: Int, variantF: () => Any) } if (combiner != null) { - val r = f(x) - if (r != null) - combiner.merge(r.asInstanceOf[Genotype]) + if (x != null) + combiner.merge(x.asInstanceOf[Genotype]) } } @@ -392,26 +386,25 @@ class CallStatsAggregator(f: (Any) => Any, val idx: Int, variantF: () => Any) combiner.merge(agg2.combiner) } - def copy(): TypedAggregator[CallStats] = new CallStatsAggregator(f, idx, variantF) + def copy(): TypedAggregator[CallStats] = new CallStatsAggregator(idx, variantF) } -class InbreedingAggregator(f: (Any) => Any, localIdx: Int, getAF: () => Any) extends 
TypedAggregator[InbreedingCombiner] { +class InbreedingAggregator(localIdx: Int, getAF: () => Any) extends TypedAggregator[InbreedingCombiner] { var _state = new InbreedingCombiner() def result = _state def seqOp(x: Any) = { - val r = f(x) val af = getAF() - if (r != null && af != null) - _state.merge(r.asInstanceOf[Genotype], DoubleNumericConversion.to(af)) + if (x != null && af != null) + _state.merge(x.asInstanceOf[Genotype], DoubleNumericConversion.to(af)) } def combOp(agg2: this.type) = _state.merge(agg2.asInstanceOf[InbreedingAggregator]._state) - def copy() = new InbreedingAggregator(f, localIdx, getAF) + def copy() = new InbreedingAggregator(localIdx, getAF) def idx = localIdx } diff --git a/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala b/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala index 55d926153a0..5d52748adef 100644 --- a/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala +++ b/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala @@ -9,7 +9,7 @@ object HistogramCombiner { def schema: Type = TStruct( "binEdges" -> TArray(TDouble), "binFrequencies" -> TArray(TLong), - "nSmaller" -> TLong, + "nLess" -> TLong, "nGreater" -> TLong) } @@ -18,13 +18,13 @@ class HistogramCombiner(indices: Array[Double]) extends Serializable { val min = indices.head val max = indices(indices.length - 1) - var nSmaller = 0L + var nLess = 0L var nGreater = 0L - val density = Array.fill(indices.length - 1)(0L) + val frequency = Array.fill(indices.length - 1)(0L) def merge(d: Double): HistogramCombiner = { if (d < min) - nSmaller += 1 + nLess += 1 else if (d > max) nGreater += 1 else if (!d.isNaN) { @@ -32,27 +32,28 @@ class HistogramCombiner(indices: Array[Double]) extends Serializable { val ind = if (bs < 0) -bs - 2 else - math.min(bs, density.length - 1) - assert(ind < density.length && ind >= 0, s"""found out of bounds index $ind - | Resulted from trying to merge $d - | Indices are 
[${indices.mkString(", ")}] - | Binary search index was $bs""".stripMargin) - density(ind) += 1 + math.min(bs, frequency.length - 1) + assert(ind < frequency.length && ind >= 0, + s"""found out of bounds index $ind + | Resulted from trying to merge $d + | Indices are [${ indices.mkString(", ") }] + | Binary search index was $bs""".stripMargin) + frequency(ind) += 1 } this } def merge(that: HistogramCombiner): HistogramCombiner = { - require(density.length == that.density.length) + require(frequency.length == that.frequency.length) - nSmaller += that.nSmaller + nLess += that.nLess nGreater += that.nGreater - for (i <- density.indices) - density(i) += that.density(i) + for (i <- frequency.indices) + frequency(i) += that.frequency(i) this } - def toAnnotation: Annotation = Annotation(indices: IndexedSeq[Double], density: IndexedSeq[Long], nSmaller, nGreater) + def toAnnotation: Annotation = Annotation(indices: IndexedSeq[Double], frequency: IndexedSeq[Long], nLess, nGreater) } diff --git a/src/main/scala/org/broadinstitute/hail/stats/package.scala b/src/main/scala/org/broadinstitute/hail/stats/package.scala index 6abffa56048..2e43d01fb4b 100644 --- a/src/main/scala/org/broadinstitute/hail/stats/package.scala +++ b/src/main/scala/org/broadinstitute/hail/stats/package.scala @@ -2,10 +2,9 @@ package org.broadinstitute.hail import breeze.linalg.Matrix import org.apache.commons.math3.distribution.HypergeometricDistribution -import org.apache.commons.math3.special.Gamma +import org.apache.commons.math3.special.{Erf, Gamma} import org.apache.spark.SparkContext import org.broadinstitute.hail.annotations.Annotation -import org.broadinstitute.hail.expr.{TDouble, TInt, TStruct} import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.variant.{Genotype, Variant, VariantDataset, VariantMetadata, VariantSampleMatrix} @@ -271,11 +270,26 @@ package object stats { Array(Option(pvalue), oddsRatioEstimate, confInterval._1, confInterval._2) } + val sqrt2 = math.sqrt(2) + + 
// Returns the p for which p = Prob(Z < x) with Z a standard normal RV + def pnorm(x: Double) = 0.5 * (1 + Erf.erf(x / sqrt2)) + + // Returns the x for which p = Prob(Z < x) with Z a standard normal RV + def qnorm(p: Double) = sqrt2 * Erf.erfInv(2 * p - 1) + + // Returns the p for which p = Prob(Z^2 > x) with Z^2 a chi-squared RV with one degree of freedom // This implementation avoids the round-off error truncation issue in // org.apache.commons.math3.distribution.ChiSquaredDistribution, // which computes the CDF with regularizedGammaP and p = 1 - CDF. def chiSquaredTail(df: Double, x: Double) = Gamma.regularizedGammaQ(df / 2, x / 2) + // Returns the x for which p = Prob(Z^2 > x) with Z^2 a chi-squared RV with one degree of freedom + def inverseChiSquaredTailOneDF(p: Double) = { + val q = qnorm(0.5 * p) + q * q + } + def uninitialized[T]: T = { class A { var x: T = _ diff --git a/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala b/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala index 93e35866f2a..19589520f0f 100644 --- a/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala +++ b/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala @@ -46,6 +46,11 @@ trait TextTableOptions { ) } +object TextTableConfiguration { + def apply(types: String, commentChar: String, separator: String, missing: String, noHeader: Boolean, impute: Boolean): TextTableConfiguration = + TextTableConfiguration(Parser.parseAnnotationTypes(Option(types).getOrElse("")), Option(commentChar), separator, missing, noHeader, impute) +} + case class TextTableConfiguration( types: Map[String, Type] = Map.empty[String, Type], commentChar: Option[String] = None, diff --git a/src/main/scala/org/broadinstitute/hail/utils/package.scala b/src/main/scala/org/broadinstitute/hail/utils/package.scala index 3c476e45829..4c41172fe5c 100644 --- a/src/main/scala/org/broadinstitute/hail/utils/package.scala +++ 
b/src/main/scala/org/broadinstitute/hail/utils/package.scala @@ -15,7 +15,44 @@ package object utils extends Logging with richUtils.Implicits with utils.NumericImplicits { - class FatalException(msg: String, logMsg: Option[String] = None) extends RuntimeException(msg) + class FatalException(val msg: String, val logMsg: Option[String] = None) extends RuntimeException(msg) + + def digForFatal(e: Throwable): Option[String] = { + val r = e match { + case f: FatalException => + println(s"found fatal $f") + Some(s"${ e.getMessage }") + case _ => + Option(e.getCause).flatMap(c => digForFatal(c)) + } + r + } + + def deepestMessage(e: Throwable): String = { + var iterE = e + while (iterE.getCause != null) + iterE = iterE.getCause + + s"${ e.getClass.getSimpleName }: ${ e.getLocalizedMessage }" + } + + def expandException(e: Throwable): String = { + val msg = e match { + case f: FatalException => f.logMsg.getOrElse(f.msg) + case _ => e.getLocalizedMessage + } + s"${ e.getClass.getName }: $msg\n\tat ${ e.getStackTrace.mkString("\n\tat ") }${ + Option(e.getCause).map(exception => expandException(exception)).getOrElse("") + }" + } + + def getMinimalMessage(e: Exception): String = { + val fatalOption = digForFatal(e) + val prefix = if (fatalOption.isDefined) "fatal" else "caught exception" + val msg = fatalOption.getOrElse(deepestMessage(e)) + log.error(s"hail: $prefix: $msg\nFrom ${ expandException(e) }") + msg + } trait Truncatable { def truncate: String diff --git a/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala b/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala index 97d1acc8591..d1e3bdbd98f 100644 --- a/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala +++ b/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala @@ -23,6 +23,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.kududb.spark.kudu.{KuduContext, _} import Variant.orderedKey +import 
org.broadinstitute.hail.keytable.KeyTable import org.broadinstitute.hail.methods.{Aggregators, Filter} import org.broadinstitute.hail.utils @@ -598,6 +599,76 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata, */ } + def aggregateByKey(keyCond: String, aggCond: String): KeyTable = { + val aggregationEC = EvalContext(Map( + "v" -> (0, TVariant), + "va" -> (1, vaSignature), + "s" -> (2, TSample), + "sa" -> (3, saSignature), + "global" -> (4, globalSignature), + "g" -> (5, TGenotype))) + + val ec = EvalContext(Map( + "v" -> (0, TVariant), + "va" -> (1, vaSignature), + "s" -> (2, TSample), + "sa" -> (3, saSignature), + "global" -> (4, globalSignature), + "gs" -> (-1, BaseAggregable(aggregationEC, TGenotype)))) + + val ktEC = EvalContext( + aggregationEC.st.map { case (name, (i, t)) => name -> (-1, KeyTableAggregable(aggregationEC, t.asInstanceOf[Type], i)) } + ) + + ec.set(4, globalAnnotation) + aggregationEC.set(4, globalAnnotation) + + val (keyNameParseTypes, keyF) = + if (keyCond != null) + Parser.parseAnnotationArgs(keyCond, ec, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val (aggNameParseTypes, aggF) = + if (aggCond != null) + Parser.parseAnnotationArgs(aggCond, ktEC, None) + else + (Array.empty[(List[String], Type)], Array.empty[() => Any]) + + val keyNames = keyNameParseTypes.map(_._1.head) + val aggNames = aggNameParseTypes.map(_._1.head) + + val keySignature = TStruct(keyNameParseTypes.map { case (n, t) => (n.head, t) }: _*) + val valueSignature = TStruct(aggNameParseTypes.map { case (n, t) => (n.head, t) }: _*) + + val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC) + val aggFunctions = aggregationEC.aggregationFunctions.map(_._1) + + val localGlobalAnnotation = globalAnnotation + + val seqOp = (array: Array[Aggregator], r: Annotation) => { + KeyTable.setEvalContext(aggregationEC, r, 6) + for (i <- array.indices) { + array(i).seqOp(aggFunctions(i)(r)) + } + array + } + + val ktRDD = 
mapPartitionsWithAll { it => + it.map { case (v, va, s, sa, g) => + ec.setAll(v, va, s, sa, g) + val key = Annotation.fromSeq(keyF.map(_ ())) + (key, Annotation(v, va, s, sa, localGlobalAnnotation, g)) + } + }.aggregateByKey(zVals)(seqOp, combOp) + .map { case (k, agg) => + resultOp(agg) + (k, Annotation.fromSeq(aggF.map(_ ()))) + } + + KeyTable(ktRDD, keySignature, valueSignature) + } + def foldBySample(zeroValue: T)(combOp: (T, T) => T): RDD[(String, T)] = { val localtct = tct @@ -923,6 +994,18 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata, sqlContext.createDataFrame(rowRDD, schema) } + def samplesDF(sqlContext: SQLContext): DataFrame = { + val rowRDD = sparkContext.parallelize( + sampleIdsAndAnnotations.map { case (s, sa) => + Row(s, SparkAnnotationImpex.exportAnnotation(sa, saSignature)) + }) + val schema = StructType(Array( + StructField("sample", StringType, nullable = false), + StructField("sa", saSignature.schema, nullable = true) + )) + + sqlContext.createDataFrame(rowRDD, schema) + } } // FIXME AnyVal Scala 2.11 diff --git a/src/test/resources/sampleAnnotations2.tsv b/src/test/resources/sampleAnnotations2.tsv new file mode 100644 index 00000000000..c4ded00778e --- /dev/null +++ b/src/test/resources/sampleAnnotations2.tsv @@ -0,0 +1,124 @@ +Sample qPhen2 qPhen3 +HG00096 5540.8 27694 +HG00097 3327.2 16626 +HG00099 1451.2 7246 +HG00100 5714.8 28564 +HG00101 2417.6 12078 +HG00102 3948 19730 +HG00103 372.2 1851 +HG00105 4455.6 22268 +HG00106 5296.8 26474 +HG00107 5945.2 29716 +HG00108 3295 16465 +HG00109 6519 32585 +HG00110 4163.2 20806 +HG00111 6013 30055 +HG00112 4918 24580 +HG00113 1769 8835 +HG00114 6251 31245 +HG00115 5638 28180 +HG00116 2548.4 12732 +HG00117 4724.4 23612 +HG00118 3573.4 17857 +HG00119 6177.2 30876 +HG00120 3919.8 19589 +HG00121 966.4 4822 +HG00122 0 -10 +HG00123 5662.2 28301 +HG00124 538.2 2681 +HG00125 2893.2 14456 +HG00126 5506 27520 +HG00127 2044.8 10214 +HG00128 561.4 2797 +HG00129 1630.2 8141 +HG00130 5212 26050 
+HG00131 4312.4 21552 +HG00132 2222.4 11102 +HG00133 4943.2 24706 +HG00136 2469.6 12338 +HG00137 3757.2 18776 +HG00138 1799 8985 +HG00139 385.6 1918 +HG00140 0 -10 +HG00096_B 5540.8 27694 +HG00097_B 3327.2 16626 +HG00099_B 1451.2 7246 +HG00100_B 5714.8 28564 +HG00101_B 2417.6 12078 +HG00102_B 3948 19730 +HG00103_B 372.2 1851 +HG00105_B 4455.6 22268 +HG00106_B 5296.8 26474 +HG00107_B 5945.2 29716 +HG00108_B 3295 16465 +HG00109_B 6519 32585 +HG00110_B 4163.2 20806 +HG00111_B 6013 30055 +HG00112_B 4918 24580 +HG00113_B 1769 8835 +HG00114_B 6251 31245 +HG00115_B 5638 28180 +HG00116_B 2548.4 12732 +HG00117_B 4724.4 23612 +HG00118_B 3573.4 17857 +HG00119_B 6177.2 30876 +HG00120_B 3919.8 19589 +HG00121_B 966.4 4822 +HG00122_B 0 -10 +HG00123_B 5662.2 28301 +HG00124_B 538.2 2681 +HG00125_B 2893.2 14456 +HG00126_B 5506 27520 +HG00127_B 2044.8 10214 +HG00128_B 561.4 2797 +HG00129_B 1630.2 8141 +HG00130_B 5212 26050 +HG00131_B 4312.4 21552 +HG00132_B 2222.4 11102 +HG00133_B 4943.2 24706 +HG00136_B 2469.6 12338 +HG00137_B 3757.2 18776 +HG00138_B 1799 8985 +HG00139_B 385.6 1918 +HG00140_B 0 -10 +HG00096_B_B 5540.8 27694 +HG00097_B_B 3327.2 16626 +HG00099_B_B 1451.2 7246 +HG00100_B_B 5714.8 28564 +HG00101_B_B 2417.6 12078 +HG00102_B_B 3948 19730 +HG00103_B_B 372.2 1851 +HG00105_B_B 4455.6 22268 +HG00106_B_B 5296.8 26474 +HG00107_B_B 5945.2 29716 +HG00108_B_B 3295 16465 +HG00109_B_B 6519 32585 +HG00110_B_B 4163.2 20806 +HG00111_B_B 6013 30055 +HG00112_B_B 4918 24580 +HG00113_B_B 1769 8835 +HG00114_B_B 6251 31245 +HG00115_B_B 5638 28180 +HG00116_B_B 2548.4 12732 +HG00117_B_B 4724.4 23612 +HG00118_B_B 3573.4 17857 +HG00119_B_B 6177.2 30876 +HG00120_B_B 3919.8 19589 +HG00121_B_B 966.4 4822 +HG00122_B_B 0 -10 +HG00123_B_B 5662.2 28301 +HG00124_B_B 538.2 2681 +HG00125_B_B 2893.2 14456 +HG00126_B_B 5506 27520 +HG00127_B_B 2044.8 10214 +HG00128_B_B 561.4 2797 +HG00129_B_B 1630.2 8141 +HG00130_B_B 5212 26050 +HG00131_B_B 4312.4 21552 +HG00132_B_B 2222.4 11102 +HG00133_B_B 4943.2 24706 
+HG00136_B_B 2469.6 12338 +HG00137_B_B 3757.2 18776 +HG00138_B_B 1799 8985 +HG00139_B_B 385.6 1918 +HG00140_B_B 0 -10 diff --git a/src/test/scala/org/broadinstitute/hail/SparkSuite.scala b/src/test/scala/org/broadinstitute/hail/SparkSuite.scala index 94bfd70cfdf..bbe2cddac4a 100644 --- a/src/test/scala/org/broadinstitute/hail/SparkSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/SparkSuite.scala @@ -34,6 +34,9 @@ class SparkSuite extends TestNGSuite { val jar = getClass.getProtectionDomain.getCodeSource.getLocation.toURI.getPath HailConfiguration.installDir = new File(jar).getParent + "/.." HailConfiguration.tmpDir = "/tmp" + + driver.configure(sc, logFile = "hail.log", quiet = true, append = false, + parquetCompression = "uncompressed", blockSize = 1L, branchingFactor = 50, tmpDir = "/tmp") } @AfterClass(alwaysRun = true) diff --git a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala new file mode 100644 index 00000000000..c72b145e678 --- /dev/null +++ b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala @@ -0,0 +1,61 @@ +package org.broadinstitute.hail.driver + +import org.broadinstitute.hail.SparkSuite +import org.broadinstitute.hail.utils._ +import org.broadinstitute.hail.variant._ +import org.testng.annotations.Test + +class AggregateByKeySuite extends SparkSuite { + @Test def replicateSampleAggregation() = { + val inputVCF = "src/test/resources/sample.vcf" + var s = State(sc, sqlContext) + s = ImportVCF.run(s, Array(inputVCF)) + s = AnnotateSamplesExpr.run(s, Array("-c", "sa.nHet = gs.filter(g => g.isHet).count()")) + val kt = s.vds.aggregateByKey("Sample = s", "nHet = g.map(g => g.isHet.toInt).sum().toLong") + + val (_, ktHetQuery) = kt.query("nHet") + val (_, ktSampleQuery) = kt.query("Sample") + val (_, saHetQuery) = s.vds.querySA("sa.nHet") + + val ktSampleResults = kt.rdd.map { case (k, v) => + (ktSampleQuery(k, 
v).map(_.asInstanceOf[String]), ktHetQuery(k, v).map(_.asInstanceOf[Long])) + }.collectAsMap() + + assert(s.vds.sampleIdsAndAnnotations.forall { case (sid, sa) => saHetQuery(sa) == ktSampleResults(Option(sid)) }) + } + + @Test def replicateVariantAggregation() = { + val inputVCF = "src/test/resources/sample.vcf" + var s = State(sc, sqlContext) + s = ImportVCF.run(s, Array(inputVCF)) + s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()")) + val kt = s.vds.aggregateByKey("Variant = v", "nHet = g.map(g => g.isHet.toInt).sum().toLong") + + val (_, ktHetQuery) = kt.query("nHet") + val (_, ktVariantQuery) = kt.query("Variant") + val (_, vaHetQuery) = s.vds.queryVA("va.nHet") + + val ktVariantResults = kt.rdd.map { case (k, v) => + (ktVariantQuery(k, v).map(_.asInstanceOf[Variant]), ktHetQuery(k, v).map(_.asInstanceOf[Long])) + }.collectAsMap() + + assert(s.vds.variantsAndAnnotations.forall { case (v, va) => vaHetQuery(va) == ktVariantResults(Option(v)) }) + } + + @Test def replicateGlobalAggregation() = { + val inputVCF = "src/test/resources/sample.vcf" + var s = State(sc, sqlContext) + s = ImportVCF.run(s, Array(inputVCF)) + s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()")) + s = AnnotateGlobalExpr.run(s, Array("-c", "global.nHet = variants.map(v => va.nHet).sum().toLong")) + val kt = s.vds.aggregateByKey(null, "nHet = g.map(g => g.isHet.toInt).sum().toLong") + + val (_, ktHetQuery) = kt.query("nHet") + val (_, globalHetResult) = s.vds.queryGlobal("global.nHet") + + val ktGlobalResult = kt.rdd.map { case (k, v) => ktHetQuery(k, v).map(_.asInstanceOf[Long]) }.collect().head + val vdsGlobalResult = globalHetResult.map(_.asInstanceOf[Long]) + + assert(ktGlobalResult == vdsGlobalResult) + } +} diff --git a/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala b/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala index 815972b617e..bb4046583f3 100644 --- 
a/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala @@ -66,4 +66,24 @@ class ExportPlinkSuite extends SparkSuite { } ) } + + @Test def testFamExport() { + val plink = tmpDir.createTempFile("mendel") + + var s = State(sc, sqlContext) + s = ImportVCF.run(s, Array("src/test/resources/mendel.vcf")) + s = SplitMulti.run(s) + s = HardCalls.run(s) + s = AnnotateSamplesFam.run(s, Array("-i", "src/test/resources/mendel.fam", "-d", "\\\\s+")) + s = AnnotateSamplesExpr.run(s, Array("-c", "sa = sa.fam")) + s = AnnotateVariantsExpr.run(s, Array("-c", "va.rsid = str(v)")) + s = AnnotateVariantsExpr.run(s, Array("-c", "va = select(va, rsid)")) + + s = ExportPlink.run(s, Array("-o", plink, "-f", + "famID = sa.famID, id = s.id, matID = sa.matID, patID = sa.patID, isFemale = sa.isFemale, isCase = sa.isCase")) + + var s2 = ImportPlink.run(s, Array("--bfile", plink)) + + assert(s.vds.same(s2.vds)) + } } diff --git a/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala b/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala index 5e797a4ffc7..86a848508c9 100644 --- a/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala @@ -44,7 +44,7 @@ class TestFileInputFormat extends hd.mapreduce.lib.input.TextInputFormat { class BGzipCodecSuite extends SparkSuite { @Test def test() { - sc.hadoopConfiguration.set("io.compression.codecs", "org.apache.hadoop.io.compress.DefaultCodec,org.broadinstitute.hail.io.compress.BGzipCodec,org.apache.hadoop.io.compress.GzipCodec") + sc.hadoopConfiguration.setLong("mapreduce.input.fileinputformat.split.minsize", 1L) val uncompPath = "src/test/resources/sample.vcf" diff --git a/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala index 
674357ce3ca..7f18f5205d5 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala @@ -140,24 +140,24 @@ class AggregatorSuite extends SparkSuite { s2.vds.rdd.collect.foreach { case (v, (va, gs)) => val r = va.asInstanceOf[Row] - val densities = r.getAs[IndexedSeq[Long]](1) + val frequencies = r.getAs[IndexedSeq[Long]](1) val definedGq = gs.flatMap(_.gq) - assert(densities(0) == definedGq.count(gq => gq < 5)) - assert(densities(1) == definedGq.count(gq => gq >= 5 && gq < 10)) - assert(densities.last == definedGq.count(gq => gq >= 95)) + assert(frequencies(0) == definedGq.count(gq => gq < 5)) + assert(frequencies(1) == definedGq.count(gq => gq >= 5 && gq < 10)) + assert(frequencies.last == definedGq.count(gq => gq >= 95)) } val s3 = AnnotateVariantsExpr.run(s, Array("-c", "va = gs.map(g => g.gq).hist(22, 80, 5)")) s3.vds.rdd.collect.foreach { case (v, (va, gs)) => val r = va.asInstanceOf[Row] - val nSmaller = r.getAs[Long](2) + val nLess = r.getAs[Long](2) val nGreater = r.getAs[Long](3) val definedGq = gs.flatMap(_.gq) - assert(nSmaller == definedGq.count(_ < 22)) + assert(nLess == definedGq.count(_ < 22)) assert(nGreater == definedGq.count(_ > 80)) } diff --git a/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala index c81a992d21d..2140bc830aa 100644 --- a/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala @@ -507,7 +507,17 @@ class ExprSuite extends SparkSuite { assert(eval[Boolean]("rnorm(2.0, 4.0).abs > -1.0").contains(true)) + assert(D_==(eval[Double]("pnorm(qnorm(0.5))").get, 0.5)) + assert(D_==(eval[Double]("qnorm(pnorm(0.5))").get, 0.5)) + + assert(D_==(eval[Double]("qchisq1tail(pchisq1tail(0.5))").get, 0.5)) + assert(D_==(eval[Double]("pchisq1tail(qchisq1tail(0.5))").get, 0.5)) + assert(eval[Any]("if 
(true) NA: Double else 0.0").isEmpty) + + assert(eval[Int]("gtIndex(3, 5)").contains(18)) + assert(eval[Int]("gtj(18)").contains(3)) + assert(eval[Int]("gtk(18)").contains(5)) } @Test def testParseTypes() { diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala new file mode 100644 index 00000000000..18aefd093ee --- /dev/null +++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala @@ -0,0 +1,176 @@ +package org.broadinstitute.hail.methods + +import org.broadinstitute.hail.SparkSuite +import org.broadinstitute.hail.annotations._ +import org.broadinstitute.hail.driver._ +import org.broadinstitute.hail.expr._ +import org.broadinstitute.hail.keytable.KeyTable +import org.broadinstitute.hail.utils._ +import org.testng.annotations.Test + +class KeyTableSuite extends SparkSuite { + + @Test def testSingleToPairRDD() = { + val inputFile = "src/test/resources/sampleAnnotations.tsv" + val kt = KeyTable.importTextTable(sc, Array(inputFile), "Sample, Status", sc.defaultMinPartitions, TextTableConfiguration()) + val kt2 = KeyTable(KeyTable.toSingleRDD(kt.rdd, kt.nKeys, kt.nValues), kt.signature, kt.keyNames) + + assert(kt.rdd.fullOuterJoin(kt2.rdd).forall { case (k, (v1, v2)) => + val res = v1 == v2 + if (!res) + println(s"k=$k v1=$v1 v2=$v2 res=${ v1 == v2 }") + res + }) + } + + @Test def testImportExport() = { + val inputFile = "src/test/resources/sampleAnnotations.tsv" + val outputFile = tmpDir.createTempFile("ktImpExp", "tsv") + val kt = KeyTable.importTextTable(sc, Array(inputFile), "Sample, Status", sc.defaultMinPartitions, TextTableConfiguration()) + kt.export(sc, outputFile, null) + + val importedData = sc.hadoopConfiguration.readLines(inputFile)(_.map(_.value).toIndexedSeq) + val exportedData = sc.hadoopConfiguration.readLines(outputFile)(_.map(_.value).toIndexedSeq) + + intercept[FatalException] { + val kt2 = KeyTable.importTextTable(sc, Array(inputFile), 
"Sample, Status, BadKeyName", sc.defaultMinPartitions, TextTableConfiguration()) + } + + assert(importedData == exportedData) + } + + @Test def testAnnotate() = { + val inputFile = "src/test/resources/sampleAnnotations.tsv" + val kt1 = KeyTable.importTextTable(sc, Array(inputFile), "Sample", sc.defaultMinPartitions, TextTableConfiguration(impute = true)) + val kt2 = kt1.annotate("""qPhen2 = pow(qPhen, 2), NotStatus = Status == "CASE", X = qPhen == 5""", kt1.keyNames.mkString(",")) + val kt3 = kt2.annotate("", kt2.keyNames.mkString(",")) + val kt4 = kt3.annotate("", "qPhen, NotStatus") + + val kt1ValueNames = kt1.valueNames.toSet + val kt2ValueNames = kt2.valueNames.toSet + + assert(kt1.nKeys == 1) + assert(kt2.nKeys == 1) + assert(kt1.nValues == 2 && kt2.nValues == 5) + assert(kt1.keySignature == kt2.keySignature) + assert(kt1ValueNames ++ Set("qPhen2", "NotStatus", "X") == kt2ValueNames) + assert(kt2 same kt3) + + def getDataAsMap(kt: KeyTable) = { + val fieldNames = kt.fieldNames + val nFields = kt.nFields + KeyTable.toSingleRDD(kt.rdd, kt.nKeys, kt.nValues) + .map { a => fieldNames.zip(KeyTable.annotationToSeq(a, nFields)).toMap }.collect().toSet + } + + val kt3data = getDataAsMap(kt3) + val kt4data = getDataAsMap(kt4) + + assert(kt4.keyNames.toSet == Set("qPhen", "NotStatus") && + kt4.valueNames.toSet == Set("qPhen2", "X", "Sample", "Status") && + kt3data == kt4data + ) + } + + @Test def testFilter() = { + val data = Array(Array(5, 9, 0), Array(2, 3, 4), Array(1, 2, 3)) + val rdd = sc.parallelize(data.map(Annotation.fromSeq(_))) + val signature = TStruct(("field1", TInt), ("field2", TInt), ("field3", TInt)) + val keyNames = Array("field1") + + val kt1 = KeyTable(rdd, signature, keyNames) + val kt2 = kt1.filter("field1 < 3", keep = true) + val kt3 = kt1.filter("field1 < 3 && field3 == 4", keep = true) + val kt4 = kt1.filter("field1 == 5 && field2 == 9 && field3 == 0", keep = false) + val kt5 = kt1.filter("field1 < -5 && field3 == 100", keep = true) + + 
assert(kt1.nRows == 3 && kt2.nRows == 2 && kt3.nRows == 1 && kt4.nRows == 2 && kt5.nRows == 0) + } + + @Test def testJoin() = { + val inputFile1 = "src/test/resources/sampleAnnotations.tsv" + val inputFile2 = "src/test/resources/sampleAnnotations2.tsv" + + val ktLeft = KeyTable.importTextTable(sc, Array(inputFile1), "Sample", sc.defaultMinPartitions, TextTableConfiguration(impute = true)) + val ktRight = KeyTable.importTextTable(sc, Array(inputFile2), "Sample", sc.defaultMinPartitions, TextTableConfiguration(impute = true)) + + val ktLeftJoin = ktLeft.leftJoin(ktRight) + val ktRightJoin = ktLeft.rightJoin(ktRight) + val ktInnerJoin = ktLeft.innerJoin(ktRight) + val ktOuterJoin = ktLeft.outerJoin(ktRight) + + val nExpectedValues = ktLeft.nValues + ktRight.nValues + + val (_, leftKeyQuery) = ktLeft.query("Sample") + val (_, rightKeyQuery) = ktRight.query("Sample") + val (_, leftJoinKeyQuery) = ktLeftJoin.query("Sample") + val (_, rightJoinKeyQuery) = ktRightJoin.query("Sample") + + val leftKeys = ktLeft.rdd.map { case (k, v) => leftKeyQuery(k, v).map(_.asInstanceOf[String]) }.collect().toSet + val rightKeys = ktRight.rdd.map { case (k, v) => rightKeyQuery(k, v).map(_.asInstanceOf[String]) }.collect().toSet + + val nIntersectRows = leftKeys.intersect(rightKeys).size + val nUnionRows = rightKeys.union(leftKeys).size + val nExpectedKeys = ktLeft.nKeys + + assert(ktLeftJoin.nRows == ktLeft.nRows && + ktLeftJoin.nKeys == nExpectedKeys && + ktLeftJoin.nValues == nExpectedValues && + ktLeftJoin.filter { case (k, v) => + !rightKeys.contains(leftJoinKeyQuery(k, v).map(_.asInstanceOf[String])) + }.forall("isMissing(qPhen2) && isMissing(qPhen3)") + ) + + assert(ktRightJoin.nRows == ktRight.nRows && + ktRightJoin.nKeys == nExpectedKeys && + ktRightJoin.nValues == nExpectedValues && + ktRightJoin.filter { case (k, v) => + !leftKeys.contains(rightJoinKeyQuery(k, v).map(_.asInstanceOf[String])) + }.forall("isMissing(Status) && isMissing(qPhen)")) + + assert(ktOuterJoin.nRows == 
nUnionRows && + ktOuterJoin.nKeys == ktLeft.nKeys && + ktOuterJoin.nValues == nExpectedValues) + + assert(ktInnerJoin.nRows == nIntersectRows && + ktInnerJoin.nKeys == nExpectedKeys && + ktInnerJoin.nValues == nExpectedValues) + } + + @Test def testAggregate() { + val data = Array(Array("Case", 9, 0), Array("Case", 3, 4), Array("Control", 2, 3), Array("Control", 1, 5)) + val rdd = sc.parallelize(data.map(Annotation.fromSeq(_))) + val signature = TStruct(("field1", TString), ("field2", TInt), ("field3", TInt)) + val keyNames = Array("field1") + + val kt1 = KeyTable(rdd, signature, keyNames) + val kt2 = kt1.aggregate("Status = field1", + "A = field2.sum(), " + + "B = field2.map(f => field2).sum(), " + + "C = field2.map(f => field2 + field3).sum(), " + + "D = field2.count(), " + + "E = field2.filter(f => field2 == 3).count()" + ) + + val result = Array(Array("Case", 12.0, 12.0, 16.0, 2L, 1L), Array("Control", 3.0, 3.0, 11.0, 2L, 0L)) + val resRDD = sc.parallelize(result.map(Annotation.fromSeq(_))) + val resSignature = TStruct(("Status", TString), ("A", TDouble), ("B", TDouble), ("C", TDouble), ("D", TLong), ("E", TLong)) + val ktResult = KeyTable(resRDD, resSignature, keyNames = Array("Status")) + + assert(kt2 same ktResult) + } + + @Test def testForallExists() { + val data = Array(Array("Sample1", 9, 5), Array("Sample2", 3, 5), Array("Sample3", 2, 5), Array("Sample4", 1, 5)) + val rdd = sc.parallelize(data.map(Annotation.fromSeq(_))) + val signature = TStruct(("Sample", TString), ("field1", TInt), ("field2", TInt)) + val keyNames = Array("Sample") + + val kt = KeyTable(rdd, signature, keyNames) + assert(kt.forall("field2 == 5 && field1 != 0")) + assert(!kt.forall("field2 == 0 && field1 == 5")) + assert(kt.exists("""Sample == "Sample1" && field1 == 9 && field2 == 5""")) + assert(!kt.exists("""Sample == "Sample1" && field1 == 13 && field2 == 2""")) + } + +} diff --git a/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala 
b/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala index 59a3fa40ac9..eeca50781e9 100644 --- a/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala +++ b/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala @@ -1,7 +1,7 @@ package org.broadinstitute.hail.stats import breeze.linalg.DenseMatrix -import org.apache.commons.math3.distribution.ChiSquaredDistribution +import org.apache.commons.math3.distribution.{ChiSquaredDistribution, NormalDistribution} import org.broadinstitute.hail.utils._ import org.broadinstitute.hail.variant.Variant import org.broadinstitute.hail.SparkSuite @@ -21,6 +21,24 @@ class StatsSuite extends SparkSuite { val chiSq5 = new ChiSquaredDistribution(5.2) assert(D_==(chiSquaredTail(5.2, 1), 1 - chiSq5.cumulativeProbability(1))) assert(D_==(chiSquaredTail(5.2, 5.52341), 1 - chiSq5.cumulativeProbability(5.52341))) + + assert(D_==(inverseChiSquaredTailOneDF(.1), chiSq1.inverseCumulativeProbability(1 - .1))) + assert(D_==(inverseChiSquaredTailOneDF(.0001), chiSq1.inverseCumulativeProbability(1 - .0001))) + + val a = List(.0000000001, .5, .9999999999, 1.0) + a.foreach(p => println(p, inverseChiSquaredTailOneDF(p))) + a.foreach(p => assert(D_==(chiSquaredTail(1.0, inverseChiSquaredTailOneDF(p)), p))) + } + + @Test def normTest() = { + val normalDist = new NormalDistribution() + assert(D_==(pnorm(1), normalDist.cumulativeProbability(1))) + assert(D_==(pnorm(-10), normalDist.cumulativeProbability(-10))) + assert(D_==(qnorm(.6), normalDist.inverseCumulativeProbability(.6))) + assert(D_==(qnorm(.0001), normalDist.inverseCumulativeProbability(.0001))) + + val a = List(0.0, .0000000001, .5, .9999999999, 1.0) + assert(a.forall(p => D_==(qnorm(pnorm(qnorm(p))), qnorm(p)))) } @Test def vdsFromMatrixTest() {