diff --git a/docs/faq/ExpressionLanguage.md b/docs/faq/ExpressionLanguage.md
index 3a009907f25..bada94b4cd4 100644
--- a/docs/faq/ExpressionLanguage.md
+++ b/docs/faq/ExpressionLanguage.md
@@ -1 +1,4 @@
## Expression Language
+
+
+
diff --git a/docs/reference/HailExpressionLanguage.md b/docs/reference/HailExpressionLanguage.md
index ba3d6876a47..b20a49982e9 100644
--- a/docs/reference/HailExpressionLanguage.md
+++ b/docs/reference/HailExpressionLanguage.md
@@ -58,6 +58,12 @@ Several Hail commands provide the ability to perform a broad array of computatio
- pcoin(p) -- returns `true` with probability `p`. `p` should be between 0.0 and 1.0
- runif(min, max) -- returns a random draw from a uniform distribution on \[`min`, `max`). `min` should be less than or equal to `max`
- rnorm(mean, sd) -- returns a random draw from a normal distribution with mean `mean` and standard deviation `sd`. `sd` should be non-negative
+
+ - Statistics
+ - pnorm(x) -- Returns left-tail probability p for which p = Prob($Z$ < x) with $Z$ a standard normal random variable
+ - qnorm(p) -- Returns left-quantile x for which p = Prob($Z$ < x) with $Z$ a standard normal random variable. `p` must satisfy `0 < p < 1`. Inverse of `pnorm`
+ - pchisq1tail(x) -- Returns right-tail probability p for which p = Prob($Z^2$ > x) with $Z^2$ a chi-squared random variable with one degree of freedom. `x` must be positive
+ - qchisq1tail(p) -- Returns right-quantile x for which p = Prob($Z^2$ > x) with $Z^2$ a chi-squared RV with one degree of freedom. `p` must satisfy `0 < p <= 1`. Inverse of `pchisq1tail`
- Array Operations:
- constructor: `[element1, element2, ...]` -- Create a new array from elements of the same type.
@@ -139,6 +145,10 @@ Several Hail commands provide the ability to perform a broad array of computatio
- range: `range(end)` or `range(start, end)`. This function will produce an `Array[Int]`. `range(3)` produces `[0, 1, 2]`. `range(-2, 2)` produces `[-2, -1, 0, 1]`.
+ - `gtj(i)` and `gtk(i)`. Convert from genotype index (triangular numbers) to `j/k` pairs.
+
+ - `gtIndex(j, k)`. Convert from `j/k` pair to genotype index (triangular numbers).
+
**Note:**
- All variables and values are case sensitive
@@ -355,7 +365,7 @@ The resulting array is sorted by count in descending order (the most common elem
.hist( start, end, bins )
```
-This aggregator is used to compute density distributions of numeric parameters. The start, end, and bins params are no-scope parameters, which means that while computations like `100 / 4` are acceptable, variable references like `global.nBins` are not.
+This aggregator is used to compute frequency distributions of numeric parameters. The start, end, and bins params are no-scope parameters, which means that while computations like `100 / 4` are acceptable, variable references like `global.nBins` are not.
The result of a `hist` invocation is a struct:
@@ -363,7 +373,7 @@ The result of a `hist` invocation is a struct:
Struct {
binEdges: Array[Double],
binFrequencies: Array[Long],
- nSmaller: Long,
+ nLess: Long,
nGreater: Long
}
```
@@ -374,7 +384,7 @@ Important properties:
- (bins + 1) breakpoints are generated from the range `(start to end by binsize)`
- `binEdges` stores an array of bin cutoffs. Each bin is left-inclusive, right-exclusive except the last bin, which includes the maximum value. This means that if there are N total bins, there will be N + 1 elements in binEdges. For the invocation `hist(0, 3, 3)`, `binEdges` would be `[0, 1, 2, 3]` where the bins are `[0, 1)`, `[1, 2)`, `[2, 3]`.
- `binFrequencies` stores the number of elements in the aggregable that fall in each bin. It contains one element for each bin.
- - Elements greater than the max bin or smaller than the min bin will be tracked separately by `nSmaller` and `nGreater`
+ - Elements greater than the max bin or less than the min bin will be tracked separately by `nLess` and `nGreater`
**Examples:**
@@ -388,7 +398,7 @@ Or, extend the above to compute a global gq histogram:
```
annotatevariants expr -c 'va.gqHist = gs.map(g => g.gq).hist(0, 100, 20)'
-annotateglobal expr -c 'global.gqDensity = variants.map(v => va.gqHist.densities).sum()'
+annotateglobal expr -c 'global.gqHist = variants.map(v => va.gqHist.binFrequencies).sum()'
```
### Collect
diff --git a/python/pyhail/__init__.py b/python/pyhail/__init__.py
index fc8f8d53054..a27e3d6f47d 100644
--- a/python/pyhail/__init__.py
+++ b/python/pyhail/__init__.py
@@ -1,4 +1,7 @@
from pyhail.context import HailContext
from pyhail.dataset import VariantDataset
+from pyhail.keytable import KeyTable
+from pyhail.utils import TextTableConfig
+from pyhail.type import Type
-__all__ = ["HailContext", "VariantDataset"]
+__all__ = ["HailContext", "VariantDataset", "KeyTable", "TextTableConfig", "Type"]
diff --git a/python/pyhail/context.py b/python/pyhail/context.py
index 2cb6b6a307a..87070f3b10f 100644
--- a/python/pyhail/context.py
+++ b/python/pyhail/context.py
@@ -1,16 +1,47 @@
import pyspark
from pyhail.dataset import VariantDataset
-from pyhail.java import jarray, scala_object
+from pyhail.java import jarray, scala_object, scala_package_object
+from pyhail.keytable import KeyTable
+from pyhail.utils import TextTableConfig
+from py4j.protocol import Py4JJavaError
+
+class FatalError(Exception):
+ """:class:`.FatalError` is an error thrown by Hail method failures"""
+
+    def __init__(self, message, java_exception):
+        self.msg = message
+        self.java_exception = java_exception
+        super(FatalError, self).__init__(message)
+
+ def __str__(self):
+ return self.msg
class HailContext(object):
""":class:`.HailContext` is the main entrypoint for PyHail
functionality.
:param SparkContext sc: The pyspark context.
+
+ :param str log: Log file.
+
+ :param bool quiet: Don't write log file.
+
+ :param bool append: Append to existing log file.
+
+ :param long block_size: Minimum size of file splits in MB.
+
+ :param str parquet_compression: Parquet compression codec.
+
+ :param int branching_factor: Branching factor to use in tree aggregate.
+
+ :param str tmp_dir: Temporary directory for file merging.
"""
- def __init__(self, sc):
+ def __init__(self, sc=None, log='hail.log', quiet=False, append=False,
+ block_size=1, parquet_compression='uncompressed',
+ branching_factor=50, tmp_dir='/tmp'):
+
self.sc = sc
self.gateway = sc._gateway
@@ -23,26 +54,37 @@ def __init__(self, sc):
self.sql_context = pyspark.sql.SQLContext(sc, self.jsql_context)
- self.jsc.hadoopConfiguration().set(
- 'io.compression.codecs',
- 'org.apache.hadoop.io.compress.DefaultCodec,org.broadinstitute.hail.io.compress.BGzipCodec,org.apache.hadoop.io.compress.GzipCodec')
+ scala_package_object(self.jvm.org.broadinstitute.hail.driver).configure(
+ self.jsc,
+ log,
+ quiet,
+ append,
+ parquet_compression,
+ block_size,
+ branching_factor,
+ tmp_dir)
- logger = sc._jvm.org.apache.log4j
- logger.LogManager.getLogger("org"). setLevel(logger.Level.ERROR)
- logger.LogManager.getLogger("akka").setLevel(logger.Level.ERROR)
def _jstate(self, jvds):
return self.jvm.org.broadinstitute.hail.driver.State(
self.jsc, self.jsql_context, jvds, scala_object(self.jvm.scala.collection.immutable, 'Map').empty())
+ def _raise_py4j_exception(self, e):
+ msg = scala_package_object(self.jvm.org.broadinstitute.hail.utils).getMinimalMessage(e.java_exception)
+ raise FatalError(msg, e.java_exception)
+
def run_command(self, vds, pargs):
jargs = jarray(self.gateway, self.jvm.java.lang.String, pargs)
t = self.jvm.org.broadinstitute.hail.driver.ToplevelCommands.lookup(jargs)
cmd = t._1()
cmd_args = t._2()
jstate = self._jstate(vds.jvds if vds != None else None)
- result = cmd.run(jstate,
- cmd_args)
+
+ try:
+ result = cmd.run(jstate, cmd_args)
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
return VariantDataset(self, result.vds())
def grep(self, regex, path, max_count=100):
@@ -74,7 +116,7 @@ def import_annotations_table(self, path, variant_expr, code=None, npartitions=No
# text table options
types=None, missing="NA", delimiter="\\t", comment=None,
header=True, impute=False):
- """Import variants and variant annotaitons from a delimited text file
+ """Import variants and variant annotations from a delimited text file
(text table) as a sites-only VariantDataset.
:param path: The files to import.
@@ -235,7 +277,43 @@ def import_gen(self, path, tolerance=0.2, sample_file=None, npartitions=None, ch
return self.run_command(None, pargs)
- def import_plink(self, bed, bim, fam, npartitions=None, delimiter='\\\\s+', missing="NA", quantpheno=False):
+ def import_keytable(self, path, key_names, npartitions=None, config=None):
+ """Import delimited text file (text table) as KeyTable.
+
+ :param path: files to import.
+ :type path: str or list of str
+
+ :param key_names: The name(s) of fields to be considered keys
+ :type key_names: str or list of str
+
+ :param npartitions: Number of partitions.
+ :type npartitions: int or None
+
+ :param config: Configuration options for importing text files
+ :type config: :class:`.TextTableConfig` or None
+
+ :rtype: :class:`.KeyTable`
+ """
+ path_args = []
+ if isinstance(path, str):
+ path_args.append(path)
+ else:
+ for p in path:
+ path_args.append(p)
+
+ if not isinstance(key_names, str):
+ key_names = ','.join(key_names)
+
+ if not npartitions:
+ npartitions = self.sc.defaultMinPartitions
+
+ if not config:
+ config = TextTableConfig()
+
+ return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable(
+ self.jsc, jarray(self.gateway, self.jvm.java.lang.String, path_args), key_names, npartitions, config.to_java(self)))
+
+ def import_plink(self, bed, bim, fam, npartitions=None, delimiter='\\\\s+', missing='NA', quantpheno=False):
"""
Import PLINK binary file (.bed, .bim, .fam) as VariantDataset
@@ -427,7 +505,8 @@ def balding_nichols_model(self, populations, samples, variants, npartitions,
:rtype: :class:`.VariantDataset`
"""
- pargs = ['baldingnichols', '-k', str(populations), '-n', str(samples), '-m', str(variants), '--npartitions', str(npartitions),
+ pargs = ['baldingnichols', '-k', str(populations), '-n', str(samples), '-m', str(variants), '--npartitions',
+ str(npartitions),
'--root', root]
if population_dist:
pargs.append('-d')
diff --git a/python/pyhail/dataset.py b/python/pyhail/dataset.py
index e921bb6dcc8..8fd18373f1e 100644
--- a/python/pyhail/dataset.py
+++ b/python/pyhail/dataset.py
@@ -1,12 +1,33 @@
from pyhail.java import scala_package_object
+from pyhail.keytable import KeyTable
import pyspark
+from py4j.protocol import Py4JJavaError
class VariantDataset(object):
def __init__(self, hc, jvds):
self.hc = hc
self.jvds = jvds
+ def _raise_py4j_exception(self, e):
+ self.hc._raise_py4j_exception(e)
+
+ def aggregate_by_key(self, key_code, agg_code):
+ """Aggregate by user-defined key and aggregation expressions.
+ Equivalent of a group-by operation in SQL.
+
+ :param key_code: Named expression(s) for which fields are keys.
+ :type key_code: str or list of str
+
+ :param agg_code: Named aggregation expression(s).
+ :type agg_code: str or list of str
+
+ :rtype: :class:`.KeyTable`
+
+ """
+
+ return KeyTable(self.hc, self.jvds.aggregateByKey(key_code, agg_code))
+
def aggregate_intervals(self, input, condition, output):
"""Aggregate over intervals and export.
@@ -41,7 +62,6 @@ def annotate_global_list(self, input, root, as_set=False):
:param bool as_set: If True, load text file as Set[String],
otherwise, load as Array[String].
-
"""
pargs = ['annotateglobal', 'list', '-i', input, '-r', root]
@@ -324,9 +344,12 @@ def count(self, genotypes=False):
"""
- return (scala_package_object(self.hc.jvm.org.broadinstitute.hail.driver)
- .count(self.jvds, genotypes)
- .toJavaMap())
+ try:
+ return (scala_package_object(self.hc.jvm.org.broadinstitute.hail.driver)
+ .count(self.jvds, genotypes)
+ .toJavaMap())
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
def deduplicate(self):
"""Remove duplicate variants."""
@@ -381,14 +404,14 @@ def export_genotypes(self, output, condition, types=None, export_ref=False, expo
pargs.append('--print-missing')
return self.hc.run_command(self, pargs)
- def export_plink(self, output):
+ def export_plink(self, output, fam_expr = 'id = s.id'):
"""Export as PLINK .bed/.bim/.fam
:param str output: Output file base. Will write .bed, .bim and .fam files.
"""
- pargs = ['exportplink', '--output', output]
+ pargs = ['exportplink', '--output', output, '--fam-expr', fam_expr]
return self.hc.run_command(self, pargs)
def export_samples(self, output, condition, types=None):
@@ -505,7 +528,7 @@ def write(self, output, overwrite=False):
:param str output: Path of .vds file to write.
:param bool overwrite: If True, overwrite any existing .vds file.
-
+
"""
pargs = ['write', '-o', output]
@@ -513,14 +536,16 @@ def write(self, output, overwrite=False):
pargs.append('--overwrite')
return self.hc.run_command(self, pargs)
- def filter_genotypes(self, condition):
+ def filter_genotypes(self, condition, keep=True):
"""Filter variants based on expression.
:param str condition: Expression for filter condition.
"""
- pargs = ['filtergenotypes', '--keep', '-c', condition]
+ pargs = ['filtergenotypes',
+ '--keep' if keep else '--remove',
+ '-c', condition]
return self.hc.run_command(self, pargs)
def filter_multi(self):
@@ -539,24 +564,28 @@ def filter_samples_all(self):
pargs = ['filtersamples', 'all']
return self.hc.run_command(self, pargs)
- def filter_samples_expr(self, condition):
+ def filter_samples_expr(self, condition, keep=True):
"""Filter samples based on expression.
:param str condition: Expression for filter condition.
"""
- pargs = ['filtersamples', 'expr', '--keep', '-c', condition]
+ pargs = ['filtersamples', 'expr',
+ '--keep' if keep else '--remove',
+ '-c', condition]
return self.hc.run_command(self, pargs)
- def filter_samples_list(self, input):
+ def filter_samples_list(self, input, keep=True):
"""Filter samples with a sample list file.
:param str input: Path to sample list file.
"""
- pargs = ['filtersamples', 'list', '--keep', '-i', input]
+ pargs = ['filtersamples', 'list',
+ '--keep' if keep else '--remove',
+ '-i', input]
return self.hc.run_command(self, pargs)
def filter_variants_all(self):
@@ -565,34 +594,40 @@ def filter_variants_all(self):
pargs = ['filtervariants', 'all']
return self.hc.run_command(self, pargs)
- def filter_variants_expr(self, condition):
+ def filter_variants_expr(self, condition, keep=True):
"""Filter variants based on expression.
:param str condition: Expression for filter condition.
"""
- pargs = ['filtervariants', 'expr', '--keep', '-c', condition]
+ pargs = ['filtervariants', 'expr',
+ '--keep' if keep else '--remove',
+ '-c', condition]
return self.hc.run_command(self, pargs)
- def filter_variants_intervals(self, input):
+ def filter_variants_intervals(self, input, keep=True):
"""Filter variants with an .interval_list file.
:param str input: Path to .interval_list file.
"""
- pargs = ['filtervariants', 'intervals', '--keep', '-i', input]
+ pargs = ['filtervariants', 'intervals',
+ '--keep' if keep else '--remove',
+ '-i', input]
return self.hc.run_command(self, pargs)
- def filter_variants_list(self, input):
+ def filter_variants_list(self, input, keep=True):
"""Filter variants with a list of variants.
:param str input: Path to variant list file.
"""
- pargs = ['filtervariants', 'list', '--keep', '-i', input]
+ pargs = ['filtervariants', 'list',
+ '--keep' if keep else '--remove',
+ '-i', input]
return self.hc.run_command(self, pargs)
def grm(self, format, output, id_file=None, n_file=None):
@@ -698,11 +733,10 @@ def join(self, right):
and global annotations from self.
"""
-
- return VariantDataset(
- self.hc,
- self.hc.jvm.org.broadinstitute.hail.driver.Join.join(self.jvds,
- right.jvds))
+ try:
+ return VariantDataset(self.hc, self.hc.jvm.org.broadinstitute.hail.driver.Join.join(self.jvds, right.jvds))
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
def linreg(self, y, covariates='', root='va.linreg', minac=1, minaf=None):
"""Test each variant for association using the linear regression
@@ -865,8 +899,10 @@ def same(self, other):
:rtype: bool
"""
-
- return self.jvds.same(other.jvds, 1e-6)
+ try:
+ return self.jvds.same(other.jvds, 1e-6)
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
def sample_qc(self, branching_factor=None):
"""Compute per-sample QC metrics.
@@ -914,7 +950,7 @@ def split_multi(self, propagate_gq=False):
pargs.append('--propagate-gq')
return self.hc.run_command(self, pargs)
- def tdt(self, fam, root = 'va.tdt'):
+ def tdt(self, fam, root='va.tdt'):
"""Find transmitted and untransmitted variants; count per variant and
nuclear family.
@@ -975,5 +1011,16 @@ def vep(self, config, block_size=None, root=None, force=False, csq=False):
def variants_to_pandas(self):
"""Convert variants and variant annotations to Pandas dataframe."""
- return pyspark.sql.DataFrame(self.jvds.variantsDF(self.hc.jsql_context),
- self.hc.sql_context).toPandas()
+ try:
+ return pyspark.sql.DataFrame(self.jvds.variantsDF(self.hc.jsql_context),
+ self.hc.sql_context).toPandas()
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def samples_to_pandas(self):
+ """Convert samples and sample annotations to Pandas dataframe."""
+ try:
+ return pyspark.sql.DataFrame(self.jvds.samplesDF(self.hc.jsql_context),
+ self.hc.sql_context).toPandas()
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
diff --git a/python/pyhail/docs/index.rst b/python/pyhail/docs/index.rst
index c743f23687e..228c112559c 100644
--- a/python/pyhail/docs/index.rst
+++ b/python/pyhail/docs/index.rst
@@ -17,6 +17,12 @@ Contents:
.. autoclass:: pyhail.VariantDataset
:members:
+.. autoclass:: pyhail.KeyTable
+ :members:
+
+.. autoclass:: pyhail.TextTableConfig
+ :members:
+
Indices and tables
==================
diff --git a/python/pyhail/keytable.py b/python/pyhail/keytable.py
new file mode 100644
index 00000000000..ec294f4c556
--- /dev/null
+++ b/python/pyhail/keytable.py
@@ -0,0 +1,194 @@
+from pyhail.type import Type
+from py4j.protocol import Py4JJavaError
+
+class KeyTable(object):
+ """:class:`.KeyTable` is Hail's version of a SQL table where fields
+ can be designated as keys.
+
+ """
+
+ def __init__(self, hc, jkt):
+ """
+ :param HailContext hc: Hail spark context.
+
+ :param JavaKeyTable jkt: Java KeyTable object.
+ """
+ self.hc = hc
+ self.jkt = jkt
+
+ def _raise_py4j_exception(self, e):
+ self.hc._raise_py4j_exception(e)
+
+ def __repr__(self):
+ try:
+ return self.jkt.toString()
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def nfields(self):
+ """Number of fields in the key-table
+
+ :rtype: int
+ """
+ try:
+ return self.jkt.nFields()
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def schema(self):
+ """Key-table schema
+
+ :rtype: :class:`.Type`
+ """
+ try:
+ return Type(self.jkt.signature())
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def key_names(self):
+ """Field names that are keys
+
+ :rtype: list of str
+ """
+ try:
+ return self.jkt.keyNames()
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def field_names(self):
+ """Names of all fields in the key-table
+
+ :rtype: list of str
+ """
+ try:
+ return self.jkt.fieldNames()
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def nrows(self):
+ """Number of rows in the key-table
+
+ :rtype: long
+ """
+ try:
+ return self.jkt.nRows()
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+
+ def same(self, other):
+ """Test whether two key-tables are identical
+
+ :param other: KeyTable to compare to
+ :type other: :class:`.KeyTable`
+
+ :rtype: bool
+ """
+ try:
+ return self.jkt.same(other.jkt)
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def export(self, output, types_file=None):
+ """Export key-table to a TSV file.
+
+ :param str output: Output file path
+
+ :param str types_file: Output path of types file
+
+ :rtype: Nothing.
+ """
+ try:
+ self.jkt.export(self.hc.jsc, output, types_file)
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def filter(self, code, keep=True):
+ """Filter rows from key-table.
+
+ :param str code: Annotation expression.
+
+ :param bool keep: Keep rows where annotation expression evaluates to True
+
+ :rtype: :class:`.KeyTable`
+ """
+ try:
+ return KeyTable(self.hc, self.jkt.filter(code, keep))
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def annotate(self, code, key_names=''):
+ """Add fields to key-table.
+
+ :param str code: Annotation expression.
+
+ :param str key_names: Comma separated list of field names to be treated as a key
+
+ :rtype: :class:`.KeyTable`
+ """
+ try:
+ return KeyTable(self.hc, self.jkt.annotate(code, key_names))
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def join(self, right, how='inner'):
+ """Join two key-tables together. Both key-tables must have identical key schemas
+ and non-overlapping field names.
+
+ :param right: Key-table to join
+ :type right: :class:`.KeyTable`
+
+ :param str how: Method for joining two tables together. One of "inner", "outer", "left", "right".
+
+ :rtype: :class:`.KeyTable`
+ """
+ try:
+ return KeyTable(self.hc, self.jkt.join(right.jkt, how))
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def aggregate_by_key(self, key_code, agg_code):
+ """Group by key condition and aggregate results
+
+ :param key_code: Named expression(s) for which fields are keys.
+ :type key_code: str or list of str
+
+ :param agg_code: Named aggregation expression(s).
+ :type agg_code: str or list of str
+
+ :rtype: :class:`.KeyTable`
+ """
+        if isinstance(key_code, list):
+            key_code = ", ".join([str(l) for l in key_code])
+
+        if isinstance(agg_code, list):
+            agg_code = ", ".join([str(l) for l in agg_code])
+
+ try:
+ return KeyTable(self.hc, self.jkt.aggregate(key_code, agg_code))
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def forall(self, code):
+ """Tests whether a condition is true for all rows
+
+ :param str code: Boolean expression
+
+ :rtype: bool
+ """
+ try:
+ return self.jkt.forall(code)
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
+
+ def exists(self, code):
+ """Tests whether a condition is true for any row
+
+ :param str code: Boolean expression
+
+ :rtype: bool
+ """
+ try:
+ return self.jkt.exists(code)
+ except Py4JJavaError as e:
+ self._raise_py4j_exception(e)
diff --git a/python/pyhail/tests.py b/python/pyhail/tests.py
index 158d77220e7..e83627ddff1 100644
--- a/python/pyhail/tests.py
+++ b/python/pyhail/tests.py
@@ -6,7 +6,7 @@
import unittest
from pyspark import SparkContext
-from pyhail import HailContext
+from pyhail import HailContext, TextTableConfig
class ContextTests(unittest.TestCase):
@@ -208,6 +208,45 @@ def test_dataset(self):
sample2_split.variant_qc().print_schema()
sample2.variants_to_pandas()
-
+
+ sample_split.annotate_variants_expr("va.nHet = gs.filter(g => g.isHet).count()")
+
+ kt = sample_split.aggregate_by_key("Variant = v", "nHet = g.map(g => g.isHet.toInt).sum().toLong")
+
+ def test_keytable(self):
+ # Import
+ kt = self.hc.import_keytable(self.test_resources + '/sampleAnnotations.tsv', 'Sample', config = TextTableConfig(impute = True))
+ kt2 = self.hc.import_keytable(self.test_resources + '/sampleAnnotations2.tsv', 'Sample', config = TextTableConfig(impute = True))
+
+ # Variables
+ self.assertEqual(kt.nfields(), 3)
+ self.assertEqual(kt.key_names()[0], "Sample")
+ self.assertEqual(kt.field_names()[2], "qPhen")
+ self.assertEqual(kt.nrows(), 100)
+ kt.schema()
+
+ # Export
+ kt.export('/tmp/testExportKT.tsv')
+
+ # Filter, Same
+ ktcase = kt.filter('Status == "CASE"', True)
+ ktcase2 = kt.filter('Status == "CTRL"', False)
+ self.assertTrue(ktcase.same(ktcase2))
+
+ # Annotate
+ (kt.annotate('X = Status', 'Sample, Status')
+ .nrows())
+
+ # Join
+ kt.join(kt2, 'left').nrows()
+
+ # AggregateByKey
+ (kt.aggregate_by_key("Status = Status", "Sum = qPhen.sum()")
+ .nrows())
+
+ # Forall, Exists
+ self.assertFalse(kt.forall('Status == "CASE"'))
+ self.assertTrue(kt.exists('Status == "CASE"'))
+
def tearDown(self):
self.sc.stop()
diff --git a/python/pyhail/type.py b/python/pyhail/type.py
new file mode 100644
index 00000000000..ea85d1f4c84
--- /dev/null
+++ b/python/pyhail/type.py
@@ -0,0 +1,12 @@
+
+class Type(object):
+ """Type of values."""
+
+ def __init__(self, jtype):
+ self.jtype = jtype
+
+ def __repr__(self):
+ return self.jtype.toString()
+
+ def __str__(self):
+ return self.jtype.toPrettyString(False, False)
diff --git a/python/pyhail/utils.py b/python/pyhail/utils.py
new file mode 100644
index 00000000000..85ee9ff36a9
--- /dev/null
+++ b/python/pyhail/utils.py
@@ -0,0 +1,47 @@
+
+class TextTableConfig(object):
+ """Configuration for delimited (text table) files.
+
+ :param bool noheader: File has no header and columns should be indicated by `_1, _2, ... _N' (0-indexed)
+
+ :param bool impute: Impute column types from the file
+
+ :param comment: Skip lines beginning with the given pattern
+ :type comment: str or None
+
+ :param str delimiter: Field delimiter regex
+
+ :param str missing: Specify identifier to be treated as missing
+
+ :param types: Define types of fields in annotations files
+ :type types: str or None
+ """
+ def __init__(self, noheader=False, impute=False,
+ comment=None, delimiter="\t", missing="NA", types=None):
+ self.noheader = noheader
+ self.impute = impute
+ self.comment = comment
+ self.delimiter = delimiter
+ self.missing = missing
+ self.types = types
+
+ def __str__(self):
+ res = ["--comment", self.comment, "--delimiter", self.delimiter,
+ "--missing", self.missing]
+
+ if self.noheader:
+ res.append("--no-header")
+
+ if self.impute:
+ res.append("--impute")
+
+ return " ".join(res)
+
+ def to_java(self, hc):
+ """Convert to Java TextTableConfiguration object.
+
+        :param hc: The Hail context (:class:`.HailContext`).
+ """
+ return hc.jvm.org.broadinstitute.hail.utils.TextTableConfiguration.apply(self.types, self.comment,
+ self.delimiter, self.missing,
+ self.noheader, self.impute)
diff --git a/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala b/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala
index af89ccb2842..e94a6689865 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/AggregateIntervals.scala
@@ -58,7 +58,7 @@ object AggregateIntervals extends Command {
ec.set(1, vds.globalAnnotation)
aggregationEC.set(1, vds.globalAnnotation)
- val (header, _, f) = Parser.parseNamedArgs(cond, ec)
+ val (header, _, f) = Parser.parseExportArgs(cond, ec)
if (header.isEmpty)
fatal("this module requires one or more named expr arguments")
@@ -67,8 +67,6 @@ object AggregateIntervals extends Command {
val zvf = () => zVals.indices.map(zVals).toArray
- val variantAggregations = Aggregators.buildVariantAggregations(vds, aggregationEC)
-
val iList = IntervalListAnnotator.read(options.input, sc.hadoopConfiguration)
val iListBc = sc.broadcast(iList)
diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala
index f0b9e62a9e4..6359054501f 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateGlobalExpr.scala
@@ -48,7 +48,7 @@ object AnnotateGlobalExpr extends Command {
aggECS.set(1, vds.globalAnnotation)
aggECV.set(1, vds.globalAnnotation)
- val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Annotation.GLOBAL_HEAD)
+ val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Some(Annotation.GLOBAL_HEAD))
val inserterBuilder = mutable.ArrayBuilder.make[Inserter]
@@ -82,7 +82,7 @@ object AnnotateGlobalExpr extends Command {
val ga = inserters
.zip(fns.map(_ ()))
.foldLeft(vds.globalAnnotation) { case (a, (ins, res)) =>
- ins(a, res)
+ ins(a, Option(res))
}
state.copy(vds = vds.copy(
diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala
index 4bbfa9663a4..c2ea5b0eb59 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateSamplesExpr.scala
@@ -47,7 +47,7 @@ object AnnotateSamplesExpr extends Command {
ec.set(2, vds.globalAnnotation)
aggregationEC.set(4, vds.globalAnnotation)
- val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Annotation.SAMPLE_HEAD)
+ val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Some(Annotation.SAMPLE_HEAD))
val inserterBuilder = mutable.ArrayBuilder.make[Inserter]
val finalType = parseTypes.foldLeft(vds.saSignature) { case (sas, (ids, signature)) =>
@@ -70,7 +70,7 @@ object AnnotateSamplesExpr extends Command {
fns.zip(inserters)
.foldLeft(sa) { case (sa, (fn, inserter)) =>
- inserter(sa, fn())
+ inserter(sa, Option(fn()))
}
}
state.copy(vds = vds.copy(
diff --git a/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala b/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala
index d10b20810f7..dcfce9e75f9 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/AnnotateVariantsExpr.scala
@@ -49,7 +49,7 @@ object AnnotateVariantsExpr extends Command {
ec.set(2, vds.globalAnnotation)
aggregationEC.set(4, vds.globalAnnotation)
- val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Annotation.VARIANT_HEAD)
+ val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, Some(Annotation.VARIANT_HEAD))
val inserterBuilder = mutable.ArrayBuilder.make[Inserter]
val finalType = parseTypes.foldLeft(vds.vaSignature) { case (vas, (ids, signature)) =>
@@ -67,7 +67,7 @@ object AnnotateVariantsExpr extends Command {
aggregateOption.foreach(f => f(v, va, gs))
fns.zip(inserters)
.foldLeft(va) { case (va, (fn, inserter)) =>
- inserter(va, fn())
+ inserter(va, Option(fn()))
}
}.copy(vaSignature = finalType)
state.copy(vds = annotated)
diff --git a/src/main/scala/org/broadinstitute/hail/driver/Command.scala b/src/main/scala/org/broadinstitute/hail/driver/Command.scala
index 436f6420500..e40ef4554d5 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/Command.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/Command.scala
@@ -254,7 +254,7 @@ abstract class Command {
fatal("this module does not support multiallelic variants.\n Please run `splitmulti' first.")
else {
if (requiresVDS)
- log.info(s"sparkinfo: $name, ${state.vds.nPartitions} partitions, ${state.vds.rdd.getStorageLevel.toReadableString()}")
+ log.info(s"sparkinfo: $name, ${ state.vds.nPartitions } partitions, ${ state.vds.rdd.getStorageLevel.toReadableString() }")
run(state, options)
}
}
diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala
index b51ef90d540..f9ef1f99e75 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/ExportGenotypes.scala
@@ -59,7 +59,7 @@ object ExportGenotypes extends Command with TextExporter {
val ec = EvalContext(symTab)
ec.set(5, vds.globalAnnotation)
- val (header, ts, f) = Parser.parseExportArgs(cond, ec)
+ val (header, ts, f) = Parser.parseNamedArgs(cond, ec)
Option(options.typesFile).foreach { file =>
val typeInfo = header
diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala
index 104f3db7724..f7b898f77a6 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/ExportPlink.scala
@@ -1,6 +1,7 @@
package org.broadinstitute.hail.driver
import org.apache.spark.storage.StorageLevel
+import org.broadinstitute.hail.expr.{BaseAggregable, EvalContext, Parser, TBoolean, TDouble, TGenotype, TSample, TString, Type}
import org.broadinstitute.hail.utils._
import org.broadinstitute.hail.io.plink.ExportBedBimFam
import org.kohsuke.args4j.{Option => Args4jOption}
@@ -11,6 +12,10 @@ object ExportPlink extends Command {
@Args4jOption(required = true, name = "-o", aliases = Array("--output"),
usage = "Output file base (will generate .bed, .bim, .fam)")
var output: String = _
+
+ @Args4jOption(name = "-f", aliases = Array("--fam-expr"),
+ usage = "Expression for .fam file values, in sample context only (global, s, sa in scope), assignable fields: famID, id, matID, patID (String), isFemale (Boolean), isCase (Boolean) or qPheno (Double)")
+ var famExpr: String = "id = s.id"
}
def newOptions = new Options
@@ -26,12 +31,64 @@ object ExportPlink extends Command {
def run(state: State, options: Options): State = {
val vds = state.vds
+ val symTab = Map(
+ "s" -> (0, TSample),
+ "sa" -> (1, vds.saSignature),
+ "global" -> (2, vds.globalSignature))
+
+ val ec = EvalContext(symTab)
+ ec.set(2, vds.globalAnnotation)
+
+ type Formatter = (() => Option[Any]) => () => String
+
+ val formatID: Formatter = f => () => f().map(_.asInstanceOf[String]).getOrElse("0")
+ val formatIsFemale: Formatter = f => () => f().map {
+ _.asInstanceOf[Boolean] match {
+ case true => "2"
+ case false => "1"
+ }
+ }.getOrElse("0")
+ val formatIsCase: Formatter = f => () => f().map {
+ _.asInstanceOf[Boolean] match {
+ case true => "2"
+ case false => "1"
+ }
+ }.getOrElse("-9")
+ val formatQPheno: Formatter = f => () => f().map(_.toString).getOrElse("-9")
+
+ val famColumns: Map[String, (Type, Int, Formatter)] = Map(
+ "famID" -> (TString, 0, formatID),
+ "id" -> (TString, 1, formatID),
+ "patID" -> (TString, 2, formatID),
+ "matID" -> (TString, 3, formatID),
+ "isFemale" -> (TBoolean, 4, formatIsFemale),
+ "qPheno" -> (TDouble, 5, formatQPheno),
+ "isCase" -> (TBoolean, 5, formatIsCase))
+
+ val exprs = Parser.parseNamedExprs(options.famExpr, ec)
+
+ val famFns: Array[() => String] = Array(
+ () => "0", () => "0", () => "0", () => "0", () => "-9", () => "-9")
+
+ exprs.foreach { case (name, t, f) =>
+ famColumns.get(name) match {
+ case Some((colt, i, formatter)) =>
+ if (colt != t)
fatal(s"invalid type for .fam file column $name: expected $colt, got $t")
+ famFns(i) = formatter(f)
+
+ case None =>
+ fatal(s"no .fam file column $name")
+ }
+ }
+
val spaceRegex = """\s+""".r
val badSampleIds = vds.sampleIds.filter(id => spaceRegex.findFirstIn(id).isDefined)
if (badSampleIds.nonEmpty) {
- fatal(s"""Found ${ badSampleIds.length } sample IDs with whitespace
+ fatal(
+ s"""Found ${ badSampleIds.length } sample IDs with whitespace
| Please run `renamesamples' to fix this problem before exporting to plink format
- | Bad sample IDs: @1 """.stripMargin, badSampleIds)
+ | Bad sample IDs: @1 """.stripMargin, badSampleIds)
}
val bedHeader = Array[Byte](108, 27, 1)
@@ -49,8 +106,11 @@ object ExportPlink extends Command {
plinkRDD.unpersist()
val famRows = vds
- .sampleIds
- .map(ExportBedBimFam.makeFamRow)
+ .sampleIdsAndAnnotations
+ .map { case (s, sa) =>
+ ec.setAll(s, sa)
+ famFns.map(_()).mkString("\t")
+ }
state.hadoopConf.writeTextFile(options.output + ".fam")(out =>
famRows.foreach(line => {
diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala
index ea5d75f974e..333c2f7a7ba 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/ExportSamples.scala
@@ -58,7 +58,7 @@ object ExportSamples extends Command with TextExporter {
ec.set(2, vds.globalAnnotation)
aggregationEC.set(5, vds.globalAnnotation)
- val (header, types, f) = Parser.parseExportArgs(cond, ec)
+ val (header, types, f) = Parser.parseNamedArgs(cond, ec)
Option(options.typesFile).foreach { file =>
val typeInfo = header
.getOrElse(types.indices.map(i => s"_$i").toArray)
diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala
index 0e2aaaa7b20..46389d657c4 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/ExportVariants.scala
@@ -48,6 +48,7 @@ object ExportVariants extends Command with TextExporter {
"sa" -> (3, vds.saSignature),
"g" -> (4, TGenotype),
"global" -> (5, vds.globalSignature)))
+
val symTab = Map(
"v" -> (0, TVariant),
"va" -> (1, vds.vaSignature),
@@ -59,7 +60,7 @@ object ExportVariants extends Command with TextExporter {
ec.set(2, vds.globalAnnotation)
aggregationEC.set(5, vds.globalAnnotation)
- val (header, types, f) = Parser.parseExportArgs(cond, ec)
+ val (header, types, f) = Parser.parseNamedArgs(cond, ec)
Option(options.typesFile).foreach { file =>
val typeInfo = header
diff --git a/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala b/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala
index 482b7ab9bf0..9c039890401 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/ExportVariantsCass.scala
@@ -153,7 +153,7 @@ object ExportVariantsCass extends Command {
val vEC = EvalContext(vSymTab)
val vA = vEC.a
- val (vHeader, vTypes, vf) = Parser.parseNamedArgs(vCond, vEC)
+ val (vHeader, vTypes, vf) = Parser.parseExportArgs(vCond, vEC)
val gSymTab = Map(
"v" -> (0, TVariant),
@@ -164,7 +164,7 @@ object ExportVariantsCass extends Command {
val gEC = EvalContext(gSymTab)
val gA = gEC.a
- val (gHeader, gTypes, gf) = Parser.parseNamedArgs(gCond, gEC)
+ val (gHeader, gTypes, gf) = Parser.parseExportArgs(gCond, gEC)
val symTab = Map(
"v" -> (0, TVariant),
diff --git a/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala b/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala
index 6939be79ad3..37f7f38332c 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/FilterAlleles.scala
@@ -72,7 +72,7 @@ object FilterAlleles extends Command {
"v" -> (0, TVariant),
"va" -> (1, state.vds.vaSignature),
"aIndices" -> (2, TArray(TInt))))
- val (types, generators) = Parser.parseAnnotationArgs(options.annotation, annotationEC, Annotation.VARIANT_HEAD)
+ val (types, generators) = Parser.parseAnnotationArgs(options.annotation, annotationEC, Some(Annotation.VARIANT_HEAD))
val inserterBuilder = mutable.ArrayBuilder.make[Inserter]
val finalType = types.foldLeft(state.vds.vaSignature) { case (vas, (path, signature)) =>
val (newVas, i) = vas.insert(signature, path)
@@ -120,7 +120,7 @@ object FilterAlleles extends Command {
def updateAnnotation(v: Variant, va: Annotation, newToOld: IndexedSeq[Int]): Annotation = {
annotationEC.setAll(v, va, newToOld)
- generators.zip(inserters).foldLeft(va) { case (va, (fn, inserter)) => inserter(va, fn()) }
+ generators.zip(inserters).foldLeft(va) { case (va, (fn, inserter)) => inserter(va, Option(fn())) }
}
def updateGenotypes(gs: Iterable[Genotype], oldToNew: Array[Int], newCount: Int): Iterable[Genotype] = {
diff --git a/src/main/scala/org/broadinstitute/hail/driver/Main.scala b/src/main/scala/org/broadinstitute/hail/driver/Main.scala
index a9aa78c025b..6d5d2d9a317 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/Main.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/Main.scala
@@ -34,10 +34,7 @@ object SparkManager {
conf.setMaster(local)
}
- conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
_sc = new SparkContext(conf)
- _sc.hadoopConfiguration.set("io.compression.codecs",
- "org.apache.hadoop.io.compress.DefaultCodec,org.broadinstitute.hail.io.compress.BGzipCodec,org.apache.hadoop.io.compress.GzipCodec")
}
_sc
@@ -241,29 +238,6 @@ object Main {
sys.exit(1)
}
- val logProps = new Properties()
- if (options.logQuiet) {
- logProps.put("log4j.rootLogger", "OFF, stderr")
-
- logProps.put("log4j.appender.stderr", "org.apache.log4j.ConsoleAppender")
- logProps.put("log4j.appender.stderr.Target", "System.err")
- logProps.put("log4j.appender.stderr.threshold", "OFF")
- logProps.put("log4j.appender.stderr.layout", "org.apache.log4j.PatternLayout")
- logProps.put("log4j.appender.stderr.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n")
- } else {
- logProps.put("log4j.rootLogger", "INFO, logfile")
-
- logProps.put("log4j.appender.logfile", "org.apache.log4j.FileAppender")
- logProps.put("log4j.appender.logfile.append", options.logAppend.toString)
- logProps.put("log4j.appender.logfile.file", options.logFile)
- logProps.put("log4j.appender.logfile.threshold", "INFO")
- logProps.put("log4j.appender.logfile.layout", "org.apache.log4j.PatternLayout")
- logProps.put("log4j.appender.logfile.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n")
- }
-
- LogManager.resetConfiguration()
- PropertyConfigurator.configure(logProps)
-
if (splitArgs.length == 1)
fail(s"hail: fatal: no commands given")
@@ -288,23 +262,9 @@ object Main {
val sc = SparkManager.createSparkContext("Hail", Option(options.master), "local[*]")
- val conf = sc.getConf
- conf.set("spark.ui.showConsoleProgress", "false")
- val progressBar = ProgressBarBuilder.build(sc)
-
- conf.set("spark.sql.parquet.compression.codec", options.parquetCompression)
-
- sc.hadoopConfiguration.setLong("mapreduce.input.fileinputformat.split.minsize", options.blockSize * 1024L * 1024L)
-
- /* `DataFrame.write` writes one file per partition. Without this, read will split files larger than the default
- * parquet block size into multiple partitions. This causes `OrderedRDD` to fail since the per-partition range
- * no longer line up with the RDD partitions.
- *
- * For reasons we don't understand, the DataFrame code uses `SparkHadoopUtil.get.conf` instead of the Hadoop
- * configuration in the SparkContext. Set both for consistency.
- */
- SparkHadoopUtil.get.conf.setLong("parquet.block.size", 1099511627776L)
- sc.hadoopConfiguration.setLong("parquet.block.size", 1099511627776L)
+ configure(sc, logFile = options.logFile, quiet = options.logQuiet, append = options.logAppend,
+ parquetCompression = options.parquetCompression, blockSize = options.blockSize,
+ branchingFactor = options.branchingFactor, tmpDir = options.tmpDir)
val sqlContext = SparkManager.createSQLContext()
@@ -313,12 +273,9 @@ object Main {
sc.addJar(jar)
HailConfiguration.installDir = new File(jar).getParent + "/.."
- HailConfiguration.tmpDir = options.tmpDir
- HailConfiguration.branchingFactor = options.branchingFactor
runCommands(sc, sqlContext, invocations)
sc.stop()
- progressBar.stop()
}
}
diff --git a/src/main/scala/org/broadinstitute/hail/driver/VEP.scala b/src/main/scala/org/broadinstitute/hail/driver/VEP.scala
index 74e6c9a664a..c7598b61ed6 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/VEP.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/VEP.scala
@@ -211,7 +211,7 @@ object VEP extends Command {
val rootType =
vds.vaSignature.getOption(root)
.filter { t =>
- val r = t == vepSignature
+ val r = t == (if(csq) TString else vepSignature)
if (!r) {
if (options.force)
warn(s"type for ${ options.root } does not match vep signature, overwriting.")
diff --git a/src/main/scala/org/broadinstitute/hail/driver/package.scala b/src/main/scala/org/broadinstitute/hail/driver/package.scala
index d564e84199b..a2dced1a88c 100644
--- a/src/main/scala/org/broadinstitute/hail/driver/package.scala
+++ b/src/main/scala/org/broadinstitute/hail/driver/package.scala
@@ -1,8 +1,15 @@
package org.broadinstitute.hail
+import java.io.File
import java.util
+import java.util.Properties
+
+import org.apache.log4j.{Level, LogManager, PropertyConfigurator}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.{ProgressBarBuilder, SparkContext}
import org.broadinstitute.hail.utils._
import org.broadinstitute.hail.variant.VariantDataset
+
import scala.collection.JavaConverters._
package object driver {
@@ -39,4 +46,65 @@ package object driver {
CountResult(vds.nSamples, nVariants, nCalled)
}
+
+ def configure(sc: SparkContext, logFile: String, quiet: Boolean, append: Boolean,
+ parquetCompression: String, blockSize: Long, branchingFactor: Int, tmpDir: String) {
+ require(blockSize > 0)
+ require(branchingFactor > 0)
+
+ val logProps = new Properties()
+ if (quiet) {
+ logProps.put("log4j.rootLogger", "OFF, stderr")
+ logProps.put("log4j.appender.stderr", "org.apache.log4j.ConsoleAppender")
+ logProps.put("log4j.appender.stderr.Target", "System.err")
+ logProps.put("log4j.appender.stderr.threshold", "OFF")
+ logProps.put("log4j.appender.stderr.layout", "org.apache.log4j.PatternLayout")
+ logProps.put("log4j.appender.stderr.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n")
+ } else {
+ logProps.put("log4j.rootLogger", "INFO, logfile")
+ logProps.put("log4j.appender.logfile", "org.apache.log4j.FileAppender")
+ logProps.put("log4j.appender.logfile.append", append.toString)
+ logProps.put("log4j.appender.logfile.file", logFile)
+ logProps.put("log4j.appender.logfile.threshold", "INFO")
+ logProps.put("log4j.appender.logfile.layout", "org.apache.log4j.PatternLayout")
+ logProps.put("log4j.appender.logfile.layout.ConversionPattern", "%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n")
+ }
+
+ LogManager.resetConfiguration()
+ PropertyConfigurator.configure(logProps)
+
+ val conf = sc.getConf
+
+ conf.set("spark.ui.showConsoleProgress", "false")
+ val progressBar = ProgressBarBuilder.build(sc)
+
+ sc.hadoopConfiguration.set(
+ "io.compression.codecs",
+ "org.apache.hadoop.io.compress.DefaultCodec," +
+ "org.broadinstitute.hail.io.compress.BGzipCodec," +
+ "org.apache.hadoop.io.compress.GzipCodec")
+
+ conf.set("spark.sql.parquet.compression.codec", parquetCompression)
+ conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+
+ sc.hadoopConfiguration.setLong("mapreduce.input.fileinputformat.split.minsize", blockSize * 1024L * 1024L)
+
+ /* `DataFrame.write` writes one file per partition. Without this, read will split files larger than the default
+ * parquet block size into multiple partitions. This causes `OrderedRDD` to fail since the per-partition range
+ * no longer line up with the RDD partitions.
+ *
+ * For reasons we don't understand, the DataFrame code uses `SparkHadoopUtil.get.conf` instead of the Hadoop
+ * configuration in the SparkContext. Set both for consistency.
+ */
+ SparkHadoopUtil.get.conf.setLong("parquet.block.size", 1099511627776L)
+ sc.hadoopConfiguration.setLong("parquet.block.size", 1099511627776L)
+
+
+ val jar = getClass.getProtectionDomain.getCodeSource.getLocation.toURI.getPath
+ sc.addJar(jar)
+
+ HailConfiguration.installDir = new File(jar).getParent + "/.."
+ HailConfiguration.tmpDir = tmpDir
+ HailConfiguration.branchingFactor = branchingFactor
+ }
}
diff --git a/src/main/scala/org/broadinstitute/hail/expr/AST.scala b/src/main/scala/org/broadinstitute/hail/expr/AST.scala
index 9967d896468..9018dd3de82 100644
--- a/src/main/scala/org/broadinstitute/hail/expr/AST.scala
+++ b/src/main/scala/org/broadinstitute/hail/expr/AST.scala
@@ -17,7 +17,9 @@ import scala.language.existentials
import scala.reflect.ClassTag
import org.broadinstitute.hail.utils.EitherIsAMonad._
-case class EvalContext(st: SymbolTable, a: ArrayBuffer[Any], aggregationFunctions: ArrayBuffer[Aggregator]) {
+case class EvalContext(st: SymbolTable,
+ a: ArrayBuffer[Any],
+ aggregationFunctions: ArrayBuffer[((Any) => Any, Aggregator)]) {
def setAll(args: Any*) {
args.zipWithIndex.foreach { case (arg, i) => a(i) = arg }
@@ -31,7 +33,7 @@ case class EvalContext(st: SymbolTable, a: ArrayBuffer[Any], aggregationFunction
object EvalContext {
def apply(symTab: SymbolTable): EvalContext = {
val a = new ArrayBuffer[Any]()
- val af = new ArrayBuffer[Aggregator]()
+ val af = new ArrayBuffer[((Any) => Any, Aggregator)]()
for ((i, t) <- symTab.values) {
if (i >= 0)
a += null
@@ -666,6 +668,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
localA(localIdx) = a
fn()
}
+
MappedAggregable(agg, t, mapF)
case error =>
parseError(s"method `$method' expects a lambda function (param => Any), got invalid mapping (param => $error)")
@@ -1143,7 +1146,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
val aggF = agg.f
- agg.ec.aggregationFunctions += new CountAggregator(aggF, localIdx)
+ agg.ec.aggregationFunctions += ((aggF, new CountAggregator(localIdx)))
() => localA(localIdx)
case (agg: TAggregable, "fraction", Array(Lambda(_, param, body))) =>
@@ -1157,7 +1160,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
val aggF = agg.f
- agg.ec.aggregationFunctions += new FractionAggregator(aggF, localIdx, localA, bodyFn, lambdaIdx)
+ agg.ec.aggregationFunctions += ((aggF, new FractionAggregator(localIdx, localA, bodyFn, lambdaIdx)))
() => localA(localIdx)
case (agg: TAggregable, "stats", Array()) =>
@@ -1168,7 +1171,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
val t = agg.elementType
val aggF = agg.f
- agg.ec.aggregationFunctions += new StatAggregator(aggF, localIdx)
+ agg.ec.aggregationFunctions += ((aggF, new StatAggregator(localIdx)))
val getOp = (a: Any) => {
val sc = a.asInstanceOf[StatCounter]
@@ -1193,7 +1196,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
val aggF = agg.f
- agg.ec.aggregationFunctions += new CounterAggregator(aggF, localIdx)
+ agg.ec.aggregationFunctions += ((aggF, new CounterAggregator(localIdx)))
() => {
val m = localA(localIdx).asInstanceOf[mutable.HashMap[Any, Long]]
@@ -1211,7 +1214,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
val vf = vAST.eval(ec)
- agg.ec.aggregationFunctions += new CallStatsAggregator(aggF, localIdx, vf)
+ agg.ec.aggregationFunctions += ((aggF, new CallStatsAggregator(localIdx, vf)))
() => {
val cs = localA(localIdx).asInstanceOf[CallStats]
@@ -1265,7 +1268,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
val aggF = agg.f
- agg.ec.aggregationFunctions += new HistAggregator(aggF, localIdx, indices)
+ agg.ec.aggregationFunctions += ((aggF, new HistAggregator(localIdx, indices)))
() => localA(localIdx).asInstanceOf[HistogramCombiner].toAnnotation
@@ -1276,7 +1279,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
val aggF = agg.f
- agg.ec.aggregationFunctions += new CollectAggregator(aggF, localIdx)
+ agg.ec.aggregationFunctions += ((aggF, new CollectAggregator(localIdx)))
() => localA(localIdx).asInstanceOf[ArrayBuffer[Any]].toIndexedSeq
case (agg: TAggregable, "infoScore", Array()) =>
@@ -1287,7 +1290,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
val localPos = posn
val aggF = agg.f
- agg.ec.aggregationFunctions += new InfoScoreAggregator(aggF, localIdx)
+ agg.ec.aggregationFunctions += ((aggF, new InfoScoreAggregator(localIdx)))
() => localA(localIdx).asInstanceOf[InfoScoreCombiner].asAnnotation
case (agg: TAggregable, "inbreeding", Array(mafAST)) =>
@@ -1299,7 +1302,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
val aggF = agg.f
val maf = mafAST.eval(agg.ec)
- agg.ec.aggregationFunctions += new InbreedingAggregator(aggF, localIdx, maf)
+ agg.ec.aggregationFunctions += ((aggF, new InbreedingAggregator(localIdx, maf)))
() => localA(localIdx).asInstanceOf[InbreedingCombiner].asAnnotation
case (agg: TAggregable, "hardyWeinberg", Array()) =>
@@ -1310,7 +1313,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
val localPos = posn
val aggF = agg.f
- agg.ec.aggregationFunctions += new HWEAggregator(aggF, localIdx)
+ agg.ec.aggregationFunctions += ((aggF, new HWEAggregator(localIdx)))
() => localA(localIdx).asInstanceOf[HWECombiner].asAnnotation
case (agg: TAggregable, "sum", Array()) =>
@@ -1322,8 +1325,8 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
val aggF = agg.f
(`type`: @unchecked) match {
- case TDouble => agg.ec.aggregationFunctions += new SumAggregator(aggF, localIdx)
- case TArray(TDouble) => agg.ec.aggregationFunctions += new SumArrayAggregator(aggF, localIdx, localPos)
+ case TDouble => agg.ec.aggregationFunctions += ((aggF, new SumAggregator(localIdx)))
+ case TArray(TDouble) => agg.ec.aggregationFunctions += ((aggF, new SumArrayAggregator(localIdx, localPos)))
}
() => localA(localIdx)
@@ -1649,9 +1652,10 @@ case class IndexOp(posn: Position, f: AST, idx: AST) extends AST(posn, Array(f,
} catch {
case e: java.lang.IndexOutOfBoundsException =>
ParserUtils.error(localPos,
- s"""Tried to access index [$i] on array ${ JsonMethods.compact(localT.toJSON(a)) } of length ${ a.length }
+ s"""Invalid array index: tried to access index [$i] on array `@1' of length ${ a.length }
| Hint: All arrays in Hail are zero-indexed (`array[0]' is the first element)
- | Hint: For accessing `A'-numbered info fields in split variants, `va.info.field[va.aIndex - 1]' is correct""".stripMargin)
+ | Hint: For accessing `A'-numbered info fields in split variants, `va.info.field[va.aIndex - 1]' is correct""".stripMargin,
+ JsonMethods.compact(localT.toJSON(a)))
case e: Throwable => throw e
})
@@ -1697,6 +1701,7 @@ case class SymRef(posn: Position, symbol: String) extends AST(posn) {
def eval(ec: EvalContext): () => Any = {
val localI = ec.st(symbol)._1
val localA = ec.a
+
if (localI < 0)
() => 0 // FIXME placeholder
else
diff --git a/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala b/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala
index cbb6675d631..c7bc99ea53f 100644
--- a/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala
+++ b/src/main/scala/org/broadinstitute/hail/expr/FunctionRegistry.scala
@@ -356,8 +356,18 @@ object FunctionRegistry {
register("runif", { (min: Double, max: Double) => min + (max - min) * math.random })
register("rnorm", { (mean: Double, sd: Double) => mean + sd * scala.util.Random.nextGaussian() })
+ register("pnorm", { (x: Double) => pnorm(x) })
+ register("qnorm", { (p: Double) => qnorm(p) })
+
+ register("pchisq1tail", { (x: Double) => chiSquaredTail(1.0, x) })
+ register("qchisq1tail", { (p: Double) => inverseChiSquaredTailOneDF(p) })
+
registerConversion((x: Int) => x.toDouble, priority = 2)
registerConversion { (x: Long) => x.toDouble }
registerConversion { (x: Int) => x.toLong }
registerConversion { (x: Float) => x.toDouble }
+
+ register("gtj", (i: Int) => Genotype.gtPair(i).j)
+ register("gtk", (i: Int) => Genotype.gtPair(i).k)
+ register("gtIndex", (j: Int, k: Int) => Genotype.gtIndex(j, k))
}
diff --git a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala
index 2291d823750..fed27b945b7 100644
--- a/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala
+++ b/src/main/scala/org/broadinstitute/hail/expr/JoinAnnotator.scala
@@ -27,7 +27,7 @@ trait JoinAnnotator {
}
def buildInserter(code: String, t: Type, ec: EvalContext, expectedHead: String): (Type, Inserter) = {
- val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, expectedHead)
+ val (parseTypes, fns) = Parser.parseAnnotationArgs(code, ec, Some(expectedHead))
val inserterBuilder = mutable.ArrayBuilder.make[Inserter]
val finaltype = parseTypes.foldLeft(t) { case (t, (ids, signature)) =>
@@ -45,7 +45,7 @@ trait JoinAnnotator {
val queries = fns.map(_ ())
var newAnnotation = left
queries.indices.foreach { i =>
- newAnnotation = inserters(i)(newAnnotation, queries(i))
+ newAnnotation = inserters(i)(newAnnotation, Option(queries(i)))
}
newAnnotation
}
diff --git a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala
index 3b3713fe8dc..fbe15682002 100644
--- a/src/main/scala/org/broadinstitute/hail/expr/Parser.scala
+++ b/src/main/scala/org/broadinstitute/hail/expr/Parser.scala
@@ -20,6 +20,17 @@ object ParserUtils {
lineContents.take(pos.column - 1).map { c => if (c == '\t') c else ' ' }
}^""".stripMargin)
}
+
+ def error(pos: Position, msg: String, tr: Truncatable): Nothing = {
+ val lineContents = pos.longString.split("\n").head
+ val prefix = s":${ pos.line }:"
+ fatal(
+ s"""$msg
+ |$prefix$lineContents
+ |${ " " * prefix.length }${
+ lineContents.take(pos.column - 1).map { c => if (c == '\t') c else ' ' }
+ }^""".stripMargin, tr)
+ }
}
object Parser extends JavaTokenParsers {
@@ -64,13 +75,10 @@ object Parser extends JavaTokenParsers {
}
def parseIdentifierList(code: String): Array[String] = {
- if (code.matches("""\s*"""))
- Array.empty[String]
- else
- parseAll(identifierList, code) match {
- case Success(result, _) => result
- case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg)
- }
+ parseAll(identifierList, code) match {
+ case Success(result, _) => result
+ case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg)
+ }
}
def withPos[T](p: => Parser[T]): Parser[Positioned[T]] =
@@ -83,7 +91,7 @@ object Parser extends JavaTokenParsers {
}
}
- def parseExportArgs(code: String, ec: EvalContext): (Option[Array[String]], Array[Type], () => Array[String]) = {
+ def parseNamedArgs(code: String, ec: EvalContext): (Option[Array[String]], Array[Type], () => Array[String]) = {
val result = parseAll(export_args, code) match {
case Success(r, _) => r
case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg)
@@ -133,8 +141,8 @@ object Parser extends JavaTokenParsers {
(someIf(names.nonEmpty, names), tb.result(), () => computations.flatMap(_ ()))
}
- def parseNamedArgs(code: String, ec: EvalContext): (Array[String], Array[Type], () => Array[String]) = {
- val (headerOption, ts, f) = parseExportArgs(code, ec)
+ def parseExportArgs(code: String, ec: EvalContext): (Array[String], Array[Type], () => Array[String]) = {
+ val (headerOption, ts, f) = parseNamedArgs(code, ec)
val header = headerOption match {
case Some(h) => h
case None => fatal(
@@ -144,22 +152,24 @@ object Parser extends JavaTokenParsers {
(header, ts, f)
}
- def parseAnnotationArgs(code: String, ec: EvalContext, expectedHead: String): (Array[(List[String], Type)], Array[() => Option[Any]]) = {
+ def parseAnnotationArgs(code: String, ec: EvalContext, expectedHead: Option[String]): (Array[(List[String], Type)], Array[() => Any]) = {
val arr = parseAll(annotationExpressions, code) match {
case Success(result, _) => result.asInstanceOf[Array[(List[String], AST)]]
case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg)
}
def checkType(l: List[String], t: BaseType): Type = {
- if (l.head == expectedHead)
- t match {
- case t: Type => t
- case bt => fatal(
- s"""Got invalid type `$t' from the result of `${ l.mkString(".") }'""".stripMargin)
- } else fatal(
- s"""invalid annotation path `${ l.map(prettyIdentifier).mkString(".") }'
- | Path should begin with `$expectedHead'
+ if (expectedHead.exists(l.head != _))
+ fatal(
+ s"""invalid annotation path `${ l.map(prettyIdentifier).mkString(".") }'
+ | Path should begin with `${ expectedHead.get }'
""".stripMargin)
+
+ t match {
+ case t: Type => t
+ case bt => fatal(
+ s"""Got invalid type `$t' from the result of `${ l.mkString(".") }'""".stripMargin)
+ }
}
val all = arr.map {
@@ -167,8 +177,10 @@ object Parser extends JavaTokenParsers {
ast.typecheck(ec)
val t = checkType(path, ast.`type`)
val f = ast.eval(ec)
- ((path.tail, t), () => Option(f()))
+ val name = if (expectedHead.isDefined) path.tail else path
+ ((name, t), () => f())
}
+
(all.map(_._1), all.map(_._2))
}
@@ -186,6 +198,19 @@ object Parser extends JavaTokenParsers {
path.tail
}
+ def parseNamedExprs(code: String, ec: EvalContext): Array[(String, BaseType, () => Option[Any])] = {
+ val parsed = parseAll(named_args, code) match {
+ case Success(result, _) => result.asInstanceOf[Array[(String, AST)]]
+ case NoSuccess(msg, _) => fatal(msg)
+ }
+
+ parsed.map { case (name, ast) =>
+ ast.typecheck(ec)
+ val f = ast.eval(ec)
+ (name, ast.`type`, () => Option(f()))
+ }
+ }
+
def parseExprs(code: String, ec: EvalContext): (Array[(BaseType, () => Option[Any])]) = {
if (code.matches("""\s*"""))
@@ -273,7 +298,7 @@ object Parser extends JavaTokenParsers {
tsvIdentifier ~ "=" ~ expr ^^ { case id ~ _ ~ expr => (id, expr) }
def annotationExpressions: Parser[Array[(List[String], AST)]] =
- rep1sep(annotationExpression, ",") ^^ {
+ repsep(annotationExpression, ",") ^^ {
_.toArray
}
@@ -295,7 +320,7 @@ object Parser extends JavaTokenParsers {
def identifier = backtickLiteral | ident
- def identifierList: Parser[Array[String]] = rep1sep(identifier, ",") ^^ {
+ def identifierList: Parser[Array[String]] = repsep(identifier, ",") ^^ {
_.toArray
}
diff --git a/src/main/scala/org/broadinstitute/hail/expr/Type.scala b/src/main/scala/org/broadinstitute/hail/expr/Type.scala
index e495bb0c429..a0351d29d88 100644
--- a/src/main/scala/org/broadinstitute/hail/expr/Type.scala
+++ b/src/main/scala/org/broadinstitute/hail/expr/Type.scala
@@ -6,6 +6,7 @@ import org.broadinstitute.hail.utils._
import org.broadinstitute.hail.annotations.{Annotation, AnnotationPathException, _}
import org.broadinstitute.hail.check.Arbitrary._
import org.broadinstitute.hail.check.{Gen, _}
+import org.broadinstitute.hail.keytable.KeyTable
import org.broadinstitute.hail.utils
import org.broadinstitute.hail.utils.{Interval, StringEscapeUtils}
import org.broadinstitute.hail.variant.{AltAllele, Genotype, Locus, Variant}
@@ -249,6 +250,14 @@ abstract class TAggregable extends BaseType {
def f: (Any) => Any
}
+case class KeyTableAggregable(ec: EvalContext, elementType: Type, idx: Int) extends TAggregable {
+ def f: (Any) => Any = {
+ (a: Any) => {
+ KeyTable.annotationToSeq(a, ec.st.size)(idx)
+ }
+ }
+}
+
case class BaseAggregable(ec: EvalContext, elementType: Type) extends TAggregable {
def f: (Any) => Any = identity
}
@@ -455,6 +464,8 @@ case class TStruct(fields: IndexedSeq[Field]) extends Type {
def selfField(name: String): Option[Field] = fieldIdx.get(name).map(i => fields(i))
+ def hasField(name: String): Boolean = fieldIdx.contains(name)
+
def size: Int = fields.length
override def getOption(path: List[String]): Option[Type] =
diff --git a/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala b/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala
index 005f7cea76f..3dd37176696 100644
--- a/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala
+++ b/src/main/scala/org/broadinstitute/hail/io/plink/ExportBedBimFam.scala
@@ -35,8 +35,4 @@ object ExportBedBimFam {
val id = s"${v.contig}:${v.start}:${v.ref}:${v.alt}"
s"""${v.contig}\t$id\t0\t${v.start}\t${v.alt}\t${v.ref}"""
}
-
- def makeFamRow(s: String): String = {
- s"0\t$s\t0\t0\t0\t-9"
- }
}
diff --git a/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala
new file mode 100644
index 00000000000..1f096289cee
--- /dev/null
+++ b/src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala
@@ -0,0 +1,382 @@
+package org.broadinstitute.hail.keytable
+
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+import org.broadinstitute.hail.annotations._
+import org.broadinstitute.hail.check.Gen
+import org.broadinstitute.hail.expr._
+import org.broadinstitute.hail.methods.{Aggregators, Filter}
+import org.broadinstitute.hail.utils._
+import org.broadinstitute.hail.io.TextExporter
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+
+object KeyTable extends Serializable with TextExporter {
+
+ def importTextTable(sc: SparkContext, path: Array[String], keysStr: String, nPartitions: Int, config: TextTableConfiguration) = {
+    require(nPartitions > 0)
+
+ val files = sc.hadoopConfiguration.globAll(path)
+ if (files.isEmpty)
+ fatal("Arguments referred to no files")
+
+ val keys = Parser.parseIdentifierList(keysStr)
+
+ val (struct, rdd) =
+ TextTableReader.read(sc)(files, config, nPartitions)
+
+ val invalidKeys = keys.filter(!struct.hasField(_))
+ if (invalidKeys.nonEmpty)
+ fatal(s"invalid keys: ${ invalidKeys.mkString(", ") }")
+
+ KeyTable(rdd.map(_.value), struct, keys)
+ }
+
+ def annotationToSeq(a: Annotation, nFields: Int) = Option(a).map(_.asInstanceOf[Row].toSeq).getOrElse(Seq.fill[Any](nFields)(null))
+
+ def setEvalContext(ec: EvalContext, k: Annotation, v: Annotation, nKeys: Int, nValues: Int) =
+ ec.setAll(annotationToSeq(k, nKeys) ++ annotationToSeq(v, nValues): _*)
+
+ def setEvalContext(ec: EvalContext, a: Annotation, nFields: Int) =
+ ec.setAll(annotationToSeq(a, nFields): _*)
+
+ def toSingleRDD(rdd: RDD[(Annotation, Annotation)], nKeys: Int, nValues: Int): RDD[Annotation] =
+ rdd.map { case (k, v) =>
+ val x = Annotation.fromSeq(annotationToSeq(k, nKeys) ++ annotationToSeq(v, nValues))
+ x
+ }
+
+ def apply(rdd: RDD[Annotation], signature: TStruct, keyNames: Array[String]): KeyTable = {
+ val keyFields = signature.fields.filter(fd => keyNames.contains(fd.name))
+ val keyIndices = keyFields.map(_.index)
+
+ val valueFields = signature.fields.filterNot(fd => keyNames.contains(fd.name))
+ val valueIndices = valueFields.map(_.index)
+
+ assert(keyIndices.toSet.intersect(valueIndices.toSet).isEmpty)
+
+ val nFields = signature.size
+
+ val newKeySignature = TStruct(keyFields.map(fd => (fd.name, fd.`type`)): _*)
+ val newValueSignature = TStruct(valueFields.map(fd => (fd.name, fd.`type`)): _*)
+
+ val newRDD = rdd.map { a =>
+ val r = annotationToSeq(a, nFields).zipWithIndex
+ val keyRow = keyIndices.map(i => r(i)._1)
+ val valueRow = valueIndices.map(i => r(i)._1)
+ (Annotation.fromSeq(keyRow), Annotation.fromSeq(valueRow))
+ }
+
+ KeyTable(newRDD, newKeySignature, newValueSignature)
+ }
+}
+
+case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, valueSignature: TStruct) {
+ require(fieldNames.areDistinct())
+
+ def signature = keySignature.merge(valueSignature)._1
+
+ def fields = signature.fields
+
+ def keySchema = keySignature.schema
+
+ def valueSchema = valueSignature.schema
+
+ def schema = signature.schema
+
+ def keyNames = keySignature.fields.map(_.name).toArray
+
+ def valueNames = valueSignature.fields.map(_.name).toArray
+
+ def fieldNames = keyNames ++ valueNames
+
+ def nRows = rdd.count()
+
+ def nFields = fields.length
+
+ def nKeys = keySignature.size
+
+ def nValues = valueSignature.size
+
+ def same(other: KeyTable): Boolean = {
+ if (fields.toSet != other.fields.toSet) {
+ println(s"signature: this=${ schema } other=${ other.schema }")
+ false
+ } else if (keyNames.toSet != other.keyNames.toSet) {
+ println(s"keyNames: this=${ keyNames.mkString(",") } other=${ other.keyNames.mkString(",") }")
+ false
+ } else {
+ val thisFieldNames = valueNames
+ val otherFieldNames = other.valueNames
+
+ rdd.groupByKey().fullOuterJoin(other.rdd.groupByKey()).forall { case (k, (v1, v2)) =>
+ (v1, v2) match {
+ case (None, None) => true
+ case (Some(x), Some(y)) =>
+ val r1 = x.map(r => thisFieldNames.zip(r.asInstanceOf[Row].toSeq).toMap).toSet
+ val r2 = y.map(r => otherFieldNames.zip(r.asInstanceOf[Row].toSeq).toMap).toSet
+ val res = r1 == r2
+ if (!res)
+ println(s"k=$k r1=${ r1.mkString(",") } r2=${ r2.mkString(",") }")
+ res
+ case _ =>
+ println(s"k=$k v1=$v1 v2=$v2")
+ false
+ }
+ }
+ }
+ }
+
+ def mapAnnotations[T](f: (Annotation) => T)(implicit tct: ClassTag[T]): RDD[T] =
+ KeyTable.toSingleRDD(rdd, nKeys, nValues).map(a => f(a))
+
+ def mapAnnotations[T](f: (Annotation, Annotation) => T)(implicit tct: ClassTag[T]): RDD[T] =
+ rdd.map { case (k, v) => f(k, v) }
+
+ def query(code: String): (BaseType, (Annotation, Annotation) => Option[Any]) = {
+ val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*)
+ val nKeysLocal = nKeys
+ val nValuesLocal = nValues
+
+ val (t, f) = Parser.parse(code, ec)
+
+ val f2: (Annotation, Annotation) => Option[Any] = {
+ case (k, v) =>
+ KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal)
+ f()
+ }
+
+ (t, f2)
+ }
+
+ def querySingle(code: String): (BaseType, Querier) = {
+ val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*)
+ val nFieldsLocal = nFields
+
+ val (t, f) = Parser.parse(code, ec)
+
+ val f2: (Annotation) => Option[Any] = { a =>
+ KeyTable.setEvalContext(ec, a, nFieldsLocal)
+ f()
+ }
+
+ (t, f2)
+ }
+
+ def annotate(cond: String, keysStr: String): KeyTable = {
+ val ec = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*)
+
+ val (parseTypes, fns) = Parser.parseAnnotationArgs(cond, ec, None)
+
+ val inserterBuilder = mutable.ArrayBuilder.make[Inserter]
+
+ val finalSignature = parseTypes.foldLeft(signature) { case (vs, (ids, signature)) =>
+ val (s: TStruct, i) = vs.insert(signature, ids)
+ inserterBuilder += i
+ s
+ }
+
+ val inserters = inserterBuilder.result()
+
+ val keys = Parser.parseIdentifierList(keysStr)
+
+ val nFieldsLocal = nFields
+
+ val f: Annotation => Annotation = { a =>
+ KeyTable.setEvalContext(ec, a, nFieldsLocal)
+
+ fns.zip(inserters)
+ .foldLeft(a) { case (a1, (fn, inserter)) =>
+ inserter(a1, Option(fn()))
+ }
+ }
+
+ KeyTable(mapAnnotations(f), finalSignature, keys)
+ }
+
+ def filter(p: (Annotation, Annotation) => Boolean): KeyTable =
+ copy(rdd = rdd.filter { case (k, v) => p(k, v) })
+
+ def filter(cond: String, keep: Boolean): KeyTable = {
+ val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*)
+ val nKeysLocal = nKeys
+ val nValuesLocal = nValues
+
+ val f: () => Option[Boolean] = Parser.parse[Boolean](cond, ec, TBoolean)
+
+ val p = (k: Annotation, v: Annotation) => {
+ KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal)
+ Filter.keepThis(f(), keep)
+ }
+
+ filter(p)
+ }
+
+ def join(other: KeyTable, joinType: String): KeyTable = {
+ if (keySignature != other.keySignature)
+ fatal(
+ s"""Key signatures must be identical.
+ |Left signature: ${ keySignature.toPrettyString(compact = true) }
+ |Right signature: ${ other.keySignature.toPrettyString(compact = true) }""".stripMargin)
+
+ val overlappingFields = valueNames.toSet.intersect(other.valueNames.toSet)
+ if (overlappingFields.nonEmpty)
+ fatal(
+ s"""Fields that are not keys cannot be present in both key-tables.
+ |Overlapping fields: ${ overlappingFields.mkString(", ") }""".stripMargin)
+
+ joinType match {
+ case "left" => leftJoin(other)
+ case "right" => rightJoin(other)
+ case "inner" => innerJoin(other)
+ case "outer" => outerJoin(other)
+ case _ => fatal("Invalid join type specified. Choose one of `left', `right', `inner', `outer'")
+ }
+ }
+
+ def leftJoin(other: KeyTable): KeyTable = {
+ require(keySignature == other.keySignature)
+
+ val (newValueSignature, merger) = valueSignature.merge(other.valueSignature)
+ val newRDD = rdd.leftOuterJoin(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl, vr.orNull)) }
+
+ KeyTable(newRDD, keySignature, newValueSignature)
+ }
+
+ def rightJoin(other: KeyTable): KeyTable = {
+ require(keySignature == other.keySignature)
+
+ val (newValueSignature, merger) = valueSignature.merge(other.valueSignature)
+ val newRDD = rdd.rightOuterJoin(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl.orNull, vr)) }
+
+ KeyTable(newRDD, keySignature, newValueSignature)
+ }
+
+ def outerJoin(other: KeyTable): KeyTable = {
+ require(keySignature == other.keySignature)
+
+ val (newValueSignature, merger) = valueSignature.merge(other.valueSignature)
+ val newRDD = rdd.fullOuterJoin(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl.orNull, vr.orNull)) }
+
+ KeyTable(newRDD, keySignature, newValueSignature)
+ }
+
+ def innerJoin(other: KeyTable): KeyTable = {
+ require(keySignature == other.keySignature)
+
+ val (newValueSignature, merger) = valueSignature.merge(other.valueSignature)
+ val newRDD = rdd.join(other.rdd).map { case (k, (vl, vr)) => (k, merger(vl, vr)) }
+
+ KeyTable(newRDD, keySignature, newValueSignature)
+ }
+
+ def forall(code: String): Boolean = {
+ val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*)
+ val nKeysLocal = nKeys
+ val nValuesLocal = nValues
+
+ val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean)
+
+ rdd.forall { case (k, v) =>
+ KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal)
+ f().getOrElse(false)
+ }
+ }
+
+ def exists(code: String): Boolean = {
+ val ec = EvalContext(fields.map(f => (f.name, f.`type`)): _*)
+ val nKeysLocal = nKeys
+ val nValuesLocal = nValues
+
+ val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean)
+
+ rdd.exists { case (k, v) =>
+ KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal)
+ f().getOrElse(false)
+ }
+ }
+
+ def export(sc: SparkContext, output: String, typesFile: String) = {
+ val hConf = sc.hadoopConfiguration
+
+ val ec = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*)
+
+ val (header, types, f) = Parser.parseNamedArgs(fieldNames.map(n => n + " = " + n).mkString(","), ec)
+
+ Option(typesFile).foreach { file =>
+ val typeInfo = header
+ .getOrElse(types.indices.map(i => s"_$i").toArray)
+ .zip(types)
+
+ KeyTable.exportTypes(file, hConf, typeInfo)
+ }
+
+ hConf.delete(output, recursive = true)
+
+ val nKeysLocal = nKeys
+ val nValuesLocal = nValues
+
+ rdd
+ .mapPartitions { it =>
+ val sb = new StringBuilder()
+ it.map { case (k, v) =>
+ sb.clear()
+ KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal)
+ f().foreachBetween(x => sb.append(x))(sb += '\t')
+ sb.result()
+ }
+ }.writeTable(output, header.map(_.mkString("\t")))
+ }
+
+ def aggregate(keyCond: String, aggCond: String): KeyTable = {
+
+ val aggregationEC = EvalContext(fields.map(fd => (fd.name, fd.`type`)): _*)
+ val ec = EvalContext(fields.zipWithIndex.map { case (fd, i) => (fd.name, (-1, KeyTableAggregable(aggregationEC, fd.`type`, i))) }.toMap)
+
+ val (keyNameParseTypes, keyF) =
+ if (keyCond != null)
+ Parser.parseAnnotationArgs(keyCond, aggregationEC, None)
+ else
+ (Array.empty[(List[String], Type)], Array.empty[() => Any])
+
+ val (aggNameParseTypes, aggF) = Parser.parseAnnotationArgs(aggCond, ec, None)
+
+ val keyNames = keyNameParseTypes.map(_._1.head)
+ val aggNames = aggNameParseTypes.map(_._1.head)
+
+ val keySignature = TStruct(keyNameParseTypes.map { case (n, t) => (n.head, t) }: _*)
+ val valueSignature = TStruct(aggNameParseTypes.map { case (n, t) => (n.head, t) }: _*)
+
+ val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC)
+ val aggFunctions = aggregationEC.aggregationFunctions.map(_._1)
+
+ assert(zVals.length == aggFunctions.length)
+
+    val nFieldsLocal = nFields
+
+    val seqOp = (array: Array[Aggregator], b: Any) => {
+      KeyTable.setEvalContext(aggregationEC, b, nFieldsLocal)
+      for (i <- array.indices) {
+        array(i).seqOp(aggFunctions(i)(b))
+      }
+      array
+    }
+
+ val newRDD = KeyTable.toSingleRDD(rdd, nKeys, nValues).mapPartitions { it =>
+ it.map { a =>
+ KeyTable.setEvalContext(aggregationEC, a, nFieldsLocal)
+ val key = Annotation.fromSeq(keyF.map(_ ()))
+ (key, a)
+ }
+ }.aggregateByKey(zVals)(seqOp, combOp)
+ .map { case (k, agg) =>
+ resultOp(agg)
+ (k, Annotation.fromSeq(aggF.map(_ ())))
+ }
+
+ KeyTable(newRDD, keySignature, valueSignature)
+ }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala
index 00be90c835f..b07b3044e1c 100644
--- a/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala
+++ b/src/main/scala/org/broadinstitute/hail/methods/Aggregators.scala
@@ -15,7 +15,10 @@ import scala.util.parsing.input.Position
object Aggregators {
def buildVariantAggregations(vds: VariantDataset, ec: EvalContext): Option[(Variant, Annotation, Iterable[Genotype]) => Unit] = {
- val aggregators = ec.aggregationFunctions.toArray
+
+ val aggFunctions = ec.aggregationFunctions.map(_._1)
+ val aggregators = ec.aggregationFunctions.map(_._2)
+
val aggregatorA = ec.a
if (aggregators.nonEmpty) {
@@ -28,29 +31,29 @@ object Aggregators {
aggregatorA(0) = v
aggregatorA(1) = va
(gs, localSamplesBc.value, localAnnotationsBc.value).zipped
- .foreach {
- case (g, s, sa) =>
- aggregatorA(2) = s
- aggregatorA(3) = sa
- baseArray.foreach {
- _.seqOp(g)
- }
+ .foreach { case (g, s, sa) =>
+ aggregatorA(2) = s
+ aggregatorA(3) = sa
+ (baseArray, aggFunctions).zipped.foreach { case (agg, aggF) =>
+ agg.seqOp(aggF(g))
+ }
}
baseArray.foreach { agg => aggregatorA(agg.idx) = agg.result }
}
Some(f)
- } else None
+ } else
+ None
}
def buildSampleAggregations(vds: VariantDataset, ec: EvalContext): Option[(String) => Unit] = {
- val aggregators = ec.aggregationFunctions.toArray
+ val aggFunctions = ec.aggregationFunctions.map(_._1)
+ val aggregators = ec.aggregationFunctions.map(_._2)
val aggregatorA = ec.a
if (aggregators.isEmpty)
None
else {
-
val localSamplesBc = vds.sampleIdsBc
val localAnnotationsBc = vds.sampleAnnotationsBc
@@ -73,7 +76,7 @@ object Aggregators {
var j = 0
while (j < nAggregations) {
- arr(i, j).seqOp(g)
+ arr(i, j).seqOp(aggFunctions(j)(g))
j += 1
}
i += 1
@@ -100,7 +103,8 @@ object Aggregators {
def makeFunctions(ec: EvalContext): (Array[Aggregator], (Array[Aggregator], (Any, Any)) => Array[Aggregator],
(Array[Aggregator], Array[Aggregator]) => Array[Aggregator], (Array[Aggregator]) => Unit) = {
- val aggregators = ec.aggregationFunctions.toArray
+ val aggFunctions = ec.aggregationFunctions.map(_._1)
+ val aggregators = ec.aggregationFunctions.map(_._2)
val arr = ec.a
@@ -116,7 +120,7 @@ object Aggregators {
val (aggT, annotation) = b
ec.set(0, annotation)
for (i <- array.indices) {
- array(i).seqOp(aggT)
+ array(i).seqOp(aggFunctions(i)(aggT))
}
array
}
@@ -135,15 +139,14 @@ object Aggregators {
}
}
-class CountAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[Long] {
+class CountAggregator(val idx: Int) extends TypedAggregator[Long] {
var _state = 0L
def result = _state
def seqOp(x: Any) {
- val v = f(x)
- if (f(x) != null)
+ if (x != null)
_state += 1
}
@@ -151,10 +154,10 @@ class CountAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[Lon
_state += agg2._state
}
- def copy() = new CountAggregator(f, idx)
+ def copy() = new CountAggregator(idx)
}
-class FractionAggregator(f: (Any) => Any, val idx: Int, localA: ArrayBuffer[Any], bodyFn: () => Any, lambdaIdx: Int)
+class FractionAggregator(val idx: Int, localA: ArrayBuffer[Any], bodyFn: () => Any, lambdaIdx: Int)
extends TypedAggregator[java.lang.Double] {
var _num = 0L
@@ -167,10 +170,9 @@ class FractionAggregator(f: (Any) => Any, val idx: Int, localA: ArrayBuffer[Any]
_num.toDouble / _denom
def seqOp(x: Any) {
- val r = f(x)
- if (r != null) {
+ if (x != null) {
_denom += 1
- localA(lambdaIdx) = r
+ localA(lambdaIdx) = x
if (bodyFn().asInstanceOf[Boolean])
_num += 1
}
@@ -181,37 +183,35 @@ class FractionAggregator(f: (Any) => Any, val idx: Int, localA: ArrayBuffer[Any]
_denom += agg2._denom
}
- def copy() = new FractionAggregator(f, idx, localA, bodyFn, lambdaIdx)
+ def copy() = new FractionAggregator(idx, localA, bodyFn, lambdaIdx)
}
-class StatAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[StatCounter] {
+class StatAggregator(val idx: Int) extends TypedAggregator[StatCounter] {
var _state = new StatCounter()
def result = _state
def seqOp(x: Any) {
- val r = f(x)
- if (r != null)
- _state.merge(DoubleNumericConversion.to(r))
+ if (x != null)
+ _state.merge(DoubleNumericConversion.to(x))
}
def combOp(agg2: this.type) {
_state.merge(agg2._state)
}
- def copy() = new StatAggregator(f, idx)
+ def copy() = new StatAggregator(idx)
}
-class CounterAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[mutable.HashMap[Any, Long]] {
+class CounterAggregator(val idx: Int) extends TypedAggregator[mutable.HashMap[Any, Long]] {
var m = new mutable.HashMap[Any, Long]
def result = m
def seqOp(x: Any) {
- val r = f(x)
- if (r != null)
- m.updateValue(r, 0L, _ + 1)
+ if (x != null)
+ m.updateValue(x, 0L, _ + 1)
}
def combOp(agg2: this.type) {
@@ -220,10 +220,10 @@ class CounterAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[m
}
}
- def copy() = new CounterAggregator(f, idx)
+ def copy() = new CounterAggregator(idx)
}
-class HistAggregator(f: (Any) => Any, val idx: Int, indices: Array[Double])
+class HistAggregator(val idx: Int, indices: Array[Double])
extends TypedAggregator[HistogramCombiner] {
var _state = new HistogramCombiner(indices)
@@ -231,90 +231,85 @@ class HistAggregator(f: (Any) => Any, val idx: Int, indices: Array[Double])
def result = _state
def seqOp(x: Any) {
- val r = f(x)
- if (r != null)
- _state.merge(DoubleNumericConversion.to(r))
+ if (x != null)
+ _state.merge(DoubleNumericConversion.to(x))
}
def combOp(agg2: this.type) {
_state.merge(agg2._state)
}
- def copy() = new HistAggregator(f, idx, indices)
+ def copy() = new HistAggregator(idx, indices)
}
-class CollectAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[ArrayBuffer[Any]] {
+class CollectAggregator(val idx: Int) extends TypedAggregator[ArrayBuffer[Any]] {
var _state = new ArrayBuffer[Any]
def result = _state
def seqOp(x: Any) {
- val r = f(x)
- if (r != null)
- _state += f(x)
+ if (x != null)
+ _state += x
}
def combOp(agg2: this.type) = _state ++= agg2._state
- def copy() = new CollectAggregator(f, idx)
+ def copy() = new CollectAggregator(idx)
}
-class InfoScoreAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[InfoScoreCombiner] {
+class InfoScoreAggregator(val idx: Int) extends TypedAggregator[InfoScoreCombiner] {
var _state = new InfoScoreCombiner()
def result = _state
def seqOp(x: Any) {
- val r = f(x)
- if (r != null)
- _state.merge(r.asInstanceOf[Genotype])
+ if (x != null)
+ _state.merge(x.asInstanceOf[Genotype])
}
def combOp(agg2: this.type) {
_state.merge(agg2._state)
}
- def copy() = new InfoScoreAggregator(f, idx)
+ def copy() = new InfoScoreAggregator(idx)
}
-class HWEAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[HWECombiner] {
+class HWEAggregator(val idx: Int) extends TypedAggregator[HWECombiner] {
var _state = new HWECombiner()
def result = _state
def seqOp(x: Any) {
- val r = f(x)
- if (r != null)
- _state.merge(r.asInstanceOf[Genotype])
+ if (x != null)
+ _state.merge(x.asInstanceOf[Genotype])
}
def combOp(agg2: this.type) {
_state.merge(agg2._state)
}
- def copy() = new HWEAggregator(f, idx)
+ def copy() = new HWEAggregator(idx)
}
-class SumAggregator(f: (Any) => Any, val idx: Int) extends TypedAggregator[Double] {
+class SumAggregator(val idx: Int) extends TypedAggregator[Double] {
var _state = 0d
def result = _state
def seqOp(x: Any) {
- val r = f(x)
- if (r != null)
- _state += DoubleNumericConversion.to(r)
+ if (x != null)
+ _state += DoubleNumericConversion.to(x)
}
def combOp(agg2: this.type) = _state += agg2._state
- def copy() = new SumAggregator(f, idx)
+ def copy() = new SumAggregator(idx)
}
-class SumArrayAggregator(f: (Any) => Any, val idx: Int, localPos: Position)
+class SumArrayAggregator(val idx: Int, localPos: Position)
extends TypedAggregator[IndexedSeq[Double]] {
var _state: Array[Double] = _
@@ -322,7 +317,7 @@ class SumArrayAggregator(f: (Any) => Any, val idx: Int, localPos: Position)
def result = _state
def seqOp(x: Any) {
- val r = f(x).asInstanceOf[IndexedSeq[Any]]
+ val r = x.asInstanceOf[IndexedSeq[Any]]
if (r != null) {
if (_state == null)
_state = r.map(x => if (x == null) 0d else DoubleNumericConversion.to(x)).toArray
@@ -353,10 +348,10 @@ class SumArrayAggregator(f: (Any) => Any, val idx: Int, localPos: Position)
_state(i) += agg2state(i)
}
- def copy() = new SumArrayAggregator(f, idx, localPos)
+ def copy() = new SumArrayAggregator(idx, localPos)
}
-class CallStatsAggregator(f: (Any) => Any, val idx: Int, variantF: () => Any)
+class CallStatsAggregator(val idx: Int, variantF: () => Any)
extends TypedAggregator[CallStats] {
var first = true
@@ -378,9 +373,8 @@ class CallStatsAggregator(f: (Any) => Any, val idx: Int, variantF: () => Any)
}
if (combiner != null) {
- val r = f(x)
- if (r != null)
- combiner.merge(r.asInstanceOf[Genotype])
+ if (x != null)
+ combiner.merge(x.asInstanceOf[Genotype])
}
}
@@ -392,26 +386,25 @@ class CallStatsAggregator(f: (Any) => Any, val idx: Int, variantF: () => Any)
combiner.merge(agg2.combiner)
}
- def copy(): TypedAggregator[CallStats] = new CallStatsAggregator(f, idx, variantF)
+ def copy(): TypedAggregator[CallStats] = new CallStatsAggregator(idx, variantF)
}
-class InbreedingAggregator(f: (Any) => Any, localIdx: Int, getAF: () => Any) extends TypedAggregator[InbreedingCombiner] {
+class InbreedingAggregator(localIdx: Int, getAF: () => Any) extends TypedAggregator[InbreedingCombiner] {
var _state = new InbreedingCombiner()
def result = _state
def seqOp(x: Any) = {
- val r = f(x)
val af = getAF()
- if (r != null && af != null)
- _state.merge(r.asInstanceOf[Genotype], DoubleNumericConversion.to(af))
+ if (x != null && af != null)
+ _state.merge(x.asInstanceOf[Genotype], DoubleNumericConversion.to(af))
}
def combOp(agg2: this.type) = _state.merge(agg2.asInstanceOf[InbreedingAggregator]._state)
- def copy() = new InbreedingAggregator(f, localIdx, getAF)
+ def copy() = new InbreedingAggregator(localIdx, getAF)
def idx = localIdx
}
diff --git a/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala b/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala
index 55d926153a0..5d52748adef 100644
--- a/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala
+++ b/src/main/scala/org/broadinstitute/hail/stats/HistogramCombiner.scala
@@ -9,7 +9,7 @@ object HistogramCombiner {
def schema: Type = TStruct(
"binEdges" -> TArray(TDouble),
"binFrequencies" -> TArray(TLong),
- "nSmaller" -> TLong,
+ "nLess" -> TLong,
"nGreater" -> TLong)
}
@@ -18,13 +18,13 @@ class HistogramCombiner(indices: Array[Double]) extends Serializable {
val min = indices.head
val max = indices(indices.length - 1)
- var nSmaller = 0L
+ var nLess = 0L
var nGreater = 0L
- val density = Array.fill(indices.length - 1)(0L)
+ val frequency = Array.fill(indices.length - 1)(0L)
def merge(d: Double): HistogramCombiner = {
if (d < min)
- nSmaller += 1
+ nLess += 1
else if (d > max)
nGreater += 1
else if (!d.isNaN) {
@@ -32,27 +32,28 @@ class HistogramCombiner(indices: Array[Double]) extends Serializable {
val ind = if (bs < 0)
-bs - 2
else
- math.min(bs, density.length - 1)
- assert(ind < density.length && ind >= 0, s"""found out of bounds index $ind
- | Resulted from trying to merge $d
- | Indices are [${indices.mkString(", ")}]
- | Binary search index was $bs""".stripMargin)
- density(ind) += 1
+ math.min(bs, frequency.length - 1)
+ assert(ind < frequency.length && ind >= 0,
+ s"""found out of bounds index $ind
+ | Resulted from trying to merge $d
+ | Indices are [${ indices.mkString(", ") }]
+ | Binary search index was $bs""".stripMargin)
+ frequency(ind) += 1
}
this
}
def merge(that: HistogramCombiner): HistogramCombiner = {
- require(density.length == that.density.length)
+ require(frequency.length == that.frequency.length)
- nSmaller += that.nSmaller
+ nLess += that.nLess
nGreater += that.nGreater
- for (i <- density.indices)
- density(i) += that.density(i)
+ for (i <- frequency.indices)
+ frequency(i) += that.frequency(i)
this
}
- def toAnnotation: Annotation = Annotation(indices: IndexedSeq[Double], density: IndexedSeq[Long], nSmaller, nGreater)
+ def toAnnotation: Annotation = Annotation(indices: IndexedSeq[Double], frequency: IndexedSeq[Long], nLess, nGreater)
}
diff --git a/src/main/scala/org/broadinstitute/hail/stats/package.scala b/src/main/scala/org/broadinstitute/hail/stats/package.scala
index 6abffa56048..2e43d01fb4b 100644
--- a/src/main/scala/org/broadinstitute/hail/stats/package.scala
+++ b/src/main/scala/org/broadinstitute/hail/stats/package.scala
@@ -2,10 +2,9 @@ package org.broadinstitute.hail
import breeze.linalg.Matrix
import org.apache.commons.math3.distribution.HypergeometricDistribution
-import org.apache.commons.math3.special.Gamma
+import org.apache.commons.math3.special.{Erf, Gamma}
import org.apache.spark.SparkContext
import org.broadinstitute.hail.annotations.Annotation
-import org.broadinstitute.hail.expr.{TDouble, TInt, TStruct}
import org.broadinstitute.hail.utils._
import org.broadinstitute.hail.variant.{Genotype, Variant, VariantDataset, VariantMetadata, VariantSampleMatrix}
@@ -271,11 +270,26 @@ package object stats {
Array(Option(pvalue), oddsRatioEstimate, confInterval._1, confInterval._2)
}
+ val sqrt2 = math.sqrt(2)
+
+ // Returns the p for which p = Prob(Z < x) with Z a standard normal RV
+ def pnorm(x: Double) = 0.5 * (1 + Erf.erf(x / sqrt2))
+
+ // Returns the x for which p = Prob(Z < x) with Z a standard normal RV
+ def qnorm(p: Double) = sqrt2 * Erf.erfInv(2 * p - 1)
+
+ // Returns the p for which p = Prob(Z^2 > x) with Z^2 a chi-squared RV with one degree of freedom
// This implementation avoids the round-off error truncation issue in
// org.apache.commons.math3.distribution.ChiSquaredDistribution,
// which computes the CDF with regularizedGammaP and p = 1 - CDF.
def chiSquaredTail(df: Double, x: Double) = Gamma.regularizedGammaQ(df / 2, x / 2)
+ // Returns the x for which p = Prob(Z^2 > x) with Z^2 a chi-squared RV with one degree of freedom
+ def inverseChiSquaredTailOneDF(p: Double) = {
+ val q = qnorm(0.5 * p)
+ q * q
+ }
+
def uninitialized[T]: T = {
class A {
var x: T = _
diff --git a/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala b/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala
index 93e35866f2a..19589520f0f 100644
--- a/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala
+++ b/src/main/scala/org/broadinstitute/hail/utils/TextTableReader.scala
@@ -46,6 +46,11 @@ trait TextTableOptions {
)
}
+object TextTableConfiguration {
+ def apply(types: String, commentChar: String, separator: String, missing: String, noHeader: Boolean, impute: Boolean): TextTableConfiguration =
+ TextTableConfiguration(Parser.parseAnnotationTypes(Option(types).getOrElse("")), Option(commentChar), separator, missing, noHeader, impute)
+}
+
case class TextTableConfiguration(
types: Map[String, Type] = Map.empty[String, Type],
commentChar: Option[String] = None,
diff --git a/src/main/scala/org/broadinstitute/hail/utils/package.scala b/src/main/scala/org/broadinstitute/hail/utils/package.scala
index 3c476e45829..4c41172fe5c 100644
--- a/src/main/scala/org/broadinstitute/hail/utils/package.scala
+++ b/src/main/scala/org/broadinstitute/hail/utils/package.scala
@@ -15,7 +15,44 @@ package object utils extends Logging
with richUtils.Implicits
with utils.NumericImplicits {
- class FatalException(msg: String, logMsg: Option[String] = None) extends RuntimeException(msg)
+ class FatalException(val msg: String, val logMsg: Option[String] = None) extends RuntimeException(msg)
+
+ def digForFatal(e: Throwable): Option[String] = {
+ val r = e match {
+ case f: FatalException =>
+ println(s"found fatal $f")
+ Some(s"${ e.getMessage }")
+ case _ =>
+ Option(e.getCause).flatMap(c => digForFatal(c))
+ }
+ r
+ }
+
+ def deepestMessage(e: Throwable): String = {
+ var iterE = e
+ while (iterE.getCause != null)
+ iterE = iterE.getCause
+
+    s"${ iterE.getClass.getSimpleName }: ${ iterE.getLocalizedMessage }"
+ }
+
+ def expandException(e: Throwable): String = {
+ val msg = e match {
+ case f: FatalException => f.logMsg.getOrElse(f.msg)
+ case _ => e.getLocalizedMessage
+ }
+ s"${ e.getClass.getName }: $msg\n\tat ${ e.getStackTrace.mkString("\n\tat ") }${
+ Option(e.getCause).map(exception => expandException(exception)).getOrElse("")
+ }"
+ }
+
+ def getMinimalMessage(e: Exception): String = {
+ val fatalOption = digForFatal(e)
+ val prefix = if (fatalOption.isDefined) "fatal" else "caught exception"
+ val msg = fatalOption.getOrElse(deepestMessage(e))
+ log.error(s"hail: $prefix: $msg\nFrom ${ expandException(e) }")
+ msg
+ }
trait Truncatable {
def truncate: String
diff --git a/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala b/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala
index 97d1acc8591..d1e3bdbd98f 100644
--- a/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala
+++ b/src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala
@@ -23,6 +23,7 @@ import org.json4s._
import org.json4s.jackson.JsonMethods
import org.kududb.spark.kudu.{KuduContext, _}
import Variant.orderedKey
+import org.broadinstitute.hail.keytable.KeyTable
import org.broadinstitute.hail.methods.{Aggregators, Filter}
import org.broadinstitute.hail.utils
@@ -598,6 +599,76 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata,
*/
}
+ def aggregateByKey(keyCond: String, aggCond: String): KeyTable = {
+ val aggregationEC = EvalContext(Map(
+ "v" -> (0, TVariant),
+ "va" -> (1, vaSignature),
+ "s" -> (2, TSample),
+ "sa" -> (3, saSignature),
+ "global" -> (4, globalSignature),
+ "g" -> (5, TGenotype)))
+
+ val ec = EvalContext(Map(
+ "v" -> (0, TVariant),
+ "va" -> (1, vaSignature),
+ "s" -> (2, TSample),
+ "sa" -> (3, saSignature),
+ "global" -> (4, globalSignature),
+ "gs" -> (-1, BaseAggregable(aggregationEC, TGenotype))))
+
+ val ktEC = EvalContext(
+ aggregationEC.st.map { case (name, (i, t)) => name -> (-1, KeyTableAggregable(aggregationEC, t.asInstanceOf[Type], i)) }
+ )
+
+ ec.set(4, globalAnnotation)
+ aggregationEC.set(4, globalAnnotation)
+
+ val (keyNameParseTypes, keyF) =
+ if (keyCond != null)
+ Parser.parseAnnotationArgs(keyCond, ec, None)
+ else
+ (Array.empty[(List[String], Type)], Array.empty[() => Any])
+
+ val (aggNameParseTypes, aggF) =
+ if (aggCond != null)
+ Parser.parseAnnotationArgs(aggCond, ktEC, None)
+ else
+ (Array.empty[(List[String], Type)], Array.empty[() => Any])
+
+ val keyNames = keyNameParseTypes.map(_._1.head)
+ val aggNames = aggNameParseTypes.map(_._1.head)
+
+ val keySignature = TStruct(keyNameParseTypes.map { case (n, t) => (n.head, t) }: _*)
+ val valueSignature = TStruct(aggNameParseTypes.map { case (n, t) => (n.head, t) }: _*)
+
+ val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC)
+ val aggFunctions = aggregationEC.aggregationFunctions.map(_._1)
+
+ val localGlobalAnnotation = globalAnnotation
+
+ val seqOp = (array: Array[Aggregator], r: Annotation) => {
+ KeyTable.setEvalContext(aggregationEC, r, 6)
+ for (i <- array.indices) {
+ array(i).seqOp(aggFunctions(i)(r))
+ }
+ array
+ }
+
+ val ktRDD = mapPartitionsWithAll { it =>
+ it.map { case (v, va, s, sa, g) =>
+ ec.setAll(v, va, s, sa, g)
+ val key = Annotation.fromSeq(keyF.map(_ ()))
+ (key, Annotation(v, va, s, sa, localGlobalAnnotation, g))
+ }
+ }.aggregateByKey(zVals)(seqOp, combOp)
+ .map { case (k, agg) =>
+ resultOp(agg)
+ (k, Annotation.fromSeq(aggF.map(_ ())))
+ }
+
+ KeyTable(ktRDD, keySignature, valueSignature)
+ }
+
def foldBySample(zeroValue: T)(combOp: (T, T) => T): RDD[(String, T)] = {
val localtct = tct
@@ -923,6 +994,18 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata,
sqlContext.createDataFrame(rowRDD, schema)
}
+ def samplesDF(sqlContext: SQLContext): DataFrame = {
+ val rowRDD = sparkContext.parallelize(
+ sampleIdsAndAnnotations.map { case (s, sa) =>
+ Row(s, SparkAnnotationImpex.exportAnnotation(sa, saSignature))
+ })
+ val schema = StructType(Array(
+ StructField("sample", StringType, nullable = false),
+ StructField("sa", saSignature.schema, nullable = true)
+ ))
+
+ sqlContext.createDataFrame(rowRDD, schema)
+ }
}
// FIXME AnyVal Scala 2.11
diff --git a/src/test/resources/sampleAnnotations2.tsv b/src/test/resources/sampleAnnotations2.tsv
new file mode 100644
index 00000000000..c4ded00778e
--- /dev/null
+++ b/src/test/resources/sampleAnnotations2.tsv
@@ -0,0 +1,124 @@
+Sample qPhen2 qPhen3
+HG00096 5540.8 27694
+HG00097 3327.2 16626
+HG00099 1451.2 7246
+HG00100 5714.8 28564
+HG00101 2417.6 12078
+HG00102 3948 19730
+HG00103 372.2 1851
+HG00105 4455.6 22268
+HG00106 5296.8 26474
+HG00107 5945.2 29716
+HG00108 3295 16465
+HG00109 6519 32585
+HG00110 4163.2 20806
+HG00111 6013 30055
+HG00112 4918 24580
+HG00113 1769 8835
+HG00114 6251 31245
+HG00115 5638 28180
+HG00116 2548.4 12732
+HG00117 4724.4 23612
+HG00118 3573.4 17857
+HG00119 6177.2 30876
+HG00120 3919.8 19589
+HG00121 966.4 4822
+HG00122 0 -10
+HG00123 5662.2 28301
+HG00124 538.2 2681
+HG00125 2893.2 14456
+HG00126 5506 27520
+HG00127 2044.8 10214
+HG00128 561.4 2797
+HG00129 1630.2 8141
+HG00130 5212 26050
+HG00131 4312.4 21552
+HG00132 2222.4 11102
+HG00133 4943.2 24706
+HG00136 2469.6 12338
+HG00137 3757.2 18776
+HG00138 1799 8985
+HG00139 385.6 1918
+HG00140 0 -10
+HG00096_B 5540.8 27694
+HG00097_B 3327.2 16626
+HG00099_B 1451.2 7246
+HG00100_B 5714.8 28564
+HG00101_B 2417.6 12078
+HG00102_B 3948 19730
+HG00103_B 372.2 1851
+HG00105_B 4455.6 22268
+HG00106_B 5296.8 26474
+HG00107_B 5945.2 29716
+HG00108_B 3295 16465
+HG00109_B 6519 32585
+HG00110_B 4163.2 20806
+HG00111_B 6013 30055
+HG00112_B 4918 24580
+HG00113_B 1769 8835
+HG00114_B 6251 31245
+HG00115_B 5638 28180
+HG00116_B 2548.4 12732
+HG00117_B 4724.4 23612
+HG00118_B 3573.4 17857
+HG00119_B 6177.2 30876
+HG00120_B 3919.8 19589
+HG00121_B 966.4 4822
+HG00122_B 0 -10
+HG00123_B 5662.2 28301
+HG00124_B 538.2 2681
+HG00125_B 2893.2 14456
+HG00126_B 5506 27520
+HG00127_B 2044.8 10214
+HG00128_B 561.4 2797
+HG00129_B 1630.2 8141
+HG00130_B 5212 26050
+HG00131_B 4312.4 21552
+HG00132_B 2222.4 11102
+HG00133_B 4943.2 24706
+HG00136_B 2469.6 12338
+HG00137_B 3757.2 18776
+HG00138_B 1799 8985
+HG00139_B 385.6 1918
+HG00140_B 0 -10
+HG00096_B_B 5540.8 27694
+HG00097_B_B 3327.2 16626
+HG00099_B_B 1451.2 7246
+HG00100_B_B 5714.8 28564
+HG00101_B_B 2417.6 12078
+HG00102_B_B 3948 19730
+HG00103_B_B 372.2 1851
+HG00105_B_B 4455.6 22268
+HG00106_B_B 5296.8 26474
+HG00107_B_B 5945.2 29716
+HG00108_B_B 3295 16465
+HG00109_B_B 6519 32585
+HG00110_B_B 4163.2 20806
+HG00111_B_B 6013 30055
+HG00112_B_B 4918 24580
+HG00113_B_B 1769 8835
+HG00114_B_B 6251 31245
+HG00115_B_B 5638 28180
+HG00116_B_B 2548.4 12732
+HG00117_B_B 4724.4 23612
+HG00118_B_B 3573.4 17857
+HG00119_B_B 6177.2 30876
+HG00120_B_B 3919.8 19589
+HG00121_B_B 966.4 4822
+HG00122_B_B 0 -10
+HG00123_B_B 5662.2 28301
+HG00124_B_B 538.2 2681
+HG00125_B_B 2893.2 14456
+HG00126_B_B 5506 27520
+HG00127_B_B 2044.8 10214
+HG00128_B_B 561.4 2797
+HG00129_B_B 1630.2 8141
+HG00130_B_B 5212 26050
+HG00131_B_B 4312.4 21552
+HG00132_B_B 2222.4 11102
+HG00133_B_B 4943.2 24706
+HG00136_B_B 2469.6 12338
+HG00137_B_B 3757.2 18776
+HG00138_B_B 1799 8985
+HG00139_B_B 385.6 1918
+HG00140_B_B 0 -10
diff --git a/src/test/scala/org/broadinstitute/hail/SparkSuite.scala b/src/test/scala/org/broadinstitute/hail/SparkSuite.scala
index 94bfd70cfdf..bbe2cddac4a 100644
--- a/src/test/scala/org/broadinstitute/hail/SparkSuite.scala
+++ b/src/test/scala/org/broadinstitute/hail/SparkSuite.scala
@@ -34,6 +34,9 @@ class SparkSuite extends TestNGSuite {
val jar = getClass.getProtectionDomain.getCodeSource.getLocation.toURI.getPath
HailConfiguration.installDir = new File(jar).getParent + "/.."
HailConfiguration.tmpDir = "/tmp"
+
+ driver.configure(sc, logFile = "hail.log", quiet = true, append = false,
+ parquetCompression = "uncompressed", blockSize = 1L, branchingFactor = 50, tmpDir = "/tmp")
}
@AfterClass(alwaysRun = true)
diff --git a/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala
new file mode 100644
index 00000000000..c72b145e678
--- /dev/null
+++ b/src/test/scala/org/broadinstitute/hail/driver/AggregateByKeySuite.scala
@@ -0,0 +1,61 @@
+package org.broadinstitute.hail.driver
+
+import org.broadinstitute.hail.SparkSuite
+import org.broadinstitute.hail.utils._
+import org.broadinstitute.hail.variant._
+import org.testng.annotations.Test
+
+class AggregateByKeySuite extends SparkSuite {
+ @Test def replicateSampleAggregation() = {
+ val inputVCF = "src/test/resources/sample.vcf"
+ var s = State(sc, sqlContext)
+ s = ImportVCF.run(s, Array(inputVCF))
+ s = AnnotateSamplesExpr.run(s, Array("-c", "sa.nHet = gs.filter(g => g.isHet).count()"))
+ val kt = s.vds.aggregateByKey("Sample = s", "nHet = g.map(g => g.isHet.toInt).sum().toLong")
+
+ val (_, ktHetQuery) = kt.query("nHet")
+ val (_, ktSampleQuery) = kt.query("Sample")
+ val (_, saHetQuery) = s.vds.querySA("sa.nHet")
+
+ val ktSampleResults = kt.rdd.map { case (k, v) =>
+ (ktSampleQuery(k, v).map(_.asInstanceOf[String]), ktHetQuery(k, v).map(_.asInstanceOf[Long]))
+ }.collectAsMap()
+
+ assert(s.vds.sampleIdsAndAnnotations.forall { case (sid, sa) => saHetQuery(sa) == ktSampleResults(Option(sid)) })
+ }
+
+ @Test def replicateVariantAggregation() = {
+ val inputVCF = "src/test/resources/sample.vcf"
+ var s = State(sc, sqlContext)
+ s = ImportVCF.run(s, Array(inputVCF))
+ s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()"))
+ val kt = s.vds.aggregateByKey("Variant = v", "nHet = g.map(g => g.isHet.toInt).sum().toLong")
+
+ val (_, ktHetQuery) = kt.query("nHet")
+ val (_, ktVariantQuery) = kt.query("Variant")
+ val (_, vaHetQuery) = s.vds.queryVA("va.nHet")
+
+ val ktVariantResults = kt.rdd.map { case (k, v) =>
+ (ktVariantQuery(k, v).map(_.asInstanceOf[Variant]), ktHetQuery(k, v).map(_.asInstanceOf[Long]))
+ }.collectAsMap()
+
+ assert(s.vds.variantsAndAnnotations.forall { case (v, va) => vaHetQuery(va) == ktVariantResults(Option(v)) })
+ }
+
+ @Test def replicateGlobalAggregation() = {
+ val inputVCF = "src/test/resources/sample.vcf"
+ var s = State(sc, sqlContext)
+ s = ImportVCF.run(s, Array(inputVCF))
+ s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()"))
+ s = AnnotateGlobalExpr.run(s, Array("-c", "global.nHet = variants.map(v => va.nHet).sum().toLong"))
+ val kt = s.vds.aggregateByKey(null, "nHet = g.map(g => g.isHet.toInt).sum().toLong")
+
+ val (_, ktHetQuery) = kt.query("nHet")
+ val (_, globalHetResult) = s.vds.queryGlobal("global.nHet")
+
+ val ktGlobalResult = kt.rdd.map { case (k, v) => ktHetQuery(k, v).map(_.asInstanceOf[Long]) }.collect().head
+ val vdsGlobalResult = globalHetResult.map(_.asInstanceOf[Long])
+
+ assert(ktGlobalResult == vdsGlobalResult)
+ }
+}
diff --git a/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala b/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala
index 815972b617e..bb4046583f3 100644
--- a/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala
+++ b/src/test/scala/org/broadinstitute/hail/io/ExportPlinkSuite.scala
@@ -66,4 +66,24 @@ class ExportPlinkSuite extends SparkSuite {
}
)
}
+
+ @Test def testFamExport() {
+ val plink = tmpDir.createTempFile("mendel")
+
+ var s = State(sc, sqlContext)
+ s = ImportVCF.run(s, Array("src/test/resources/mendel.vcf"))
+ s = SplitMulti.run(s)
+ s = HardCalls.run(s)
+ s = AnnotateSamplesFam.run(s, Array("-i", "src/test/resources/mendel.fam", "-d", "\\\\s+"))
+ s = AnnotateSamplesExpr.run(s, Array("-c", "sa = sa.fam"))
+ s = AnnotateVariantsExpr.run(s, Array("-c", "va.rsid = str(v)"))
+ s = AnnotateVariantsExpr.run(s, Array("-c", "va = select(va, rsid)"))
+
+ s = ExportPlink.run(s, Array("-o", plink, "-f",
+ "famID = sa.famID, id = s.id, matID = sa.matID, patID = sa.patID, isFemale = sa.isFemale, isCase = sa.isCase"))
+
+ var s2 = ImportPlink.run(s, Array("--bfile", plink))
+
+ assert(s.vds.same(s2.vds))
+ }
}
diff --git a/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala b/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala
index 5e797a4ffc7..86a848508c9 100644
--- a/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala
+++ b/src/test/scala/org/broadinstitute/hail/io/compress/BGzipCodecSuite.scala
@@ -44,7 +44,7 @@ class TestFileInputFormat extends hd.mapreduce.lib.input.TextInputFormat {
class BGzipCodecSuite extends SparkSuite {
@Test def test() {
- sc.hadoopConfiguration.set("io.compression.codecs", "org.apache.hadoop.io.compress.DefaultCodec,org.broadinstitute.hail.io.compress.BGzipCodec,org.apache.hadoop.io.compress.GzipCodec")
+ sc.hadoopConfiguration.setLong("mapreduce.input.fileinputformat.split.minsize", 1L)
val uncompPath = "src/test/resources/sample.vcf"
diff --git a/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala
index 674357ce3ca..7f18f5205d5 100644
--- a/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala
+++ b/src/test/scala/org/broadinstitute/hail/methods/AggregatorSuite.scala
@@ -140,24 +140,24 @@ class AggregatorSuite extends SparkSuite {
s2.vds.rdd.collect.foreach { case (v, (va, gs)) =>
val r = va.asInstanceOf[Row]
- val densities = r.getAs[IndexedSeq[Long]](1)
+ val frequencies = r.getAs[IndexedSeq[Long]](1)
val definedGq = gs.flatMap(_.gq)
- assert(densities(0) == definedGq.count(gq => gq < 5))
- assert(densities(1) == definedGq.count(gq => gq >= 5 && gq < 10))
- assert(densities.last == definedGq.count(gq => gq >= 95))
+ assert(frequencies(0) == definedGq.count(gq => gq < 5))
+ assert(frequencies(1) == definedGq.count(gq => gq >= 5 && gq < 10))
+ assert(frequencies.last == definedGq.count(gq => gq >= 95))
}
val s3 = AnnotateVariantsExpr.run(s, Array("-c", "va = gs.map(g => g.gq).hist(22, 80, 5)"))
s3.vds.rdd.collect.foreach { case (v, (va, gs)) =>
val r = va.asInstanceOf[Row]
- val nSmaller = r.getAs[Long](2)
+ val nLess = r.getAs[Long](2)
val nGreater = r.getAs[Long](3)
val definedGq = gs.flatMap(_.gq)
- assert(nSmaller == definedGq.count(_ < 22))
+ assert(nLess == definedGq.count(_ < 22))
assert(nGreater == definedGq.count(_ > 80))
}
diff --git a/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala
index c81a992d21d..2140bc830aa 100644
--- a/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala
+++ b/src/test/scala/org/broadinstitute/hail/methods/ExprSuite.scala
@@ -507,7 +507,17 @@ class ExprSuite extends SparkSuite {
assert(eval[Boolean]("rnorm(2.0, 4.0).abs > -1.0").contains(true))
+ assert(D_==(eval[Double]("pnorm(qnorm(0.5))").get, 0.5))
+ assert(D_==(eval[Double]("qnorm(pnorm(0.5))").get, 0.5))
+
+ assert(D_==(eval[Double]("qchisq1tail(pchisq1tail(0.5))").get, 0.5))
+ assert(D_==(eval[Double]("pchisq1tail(qchisq1tail(0.5))").get, 0.5))
+
assert(eval[Any]("if (true) NA: Double else 0.0").isEmpty)
+
+ assert(eval[Int]("gtIndex(3, 5)").contains(18))
+ assert(eval[Int]("gtj(18)").contains(3))
+ assert(eval[Int]("gtk(18)").contains(5))
}
@Test def testParseTypes() {
diff --git a/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala
new file mode 100644
index 00000000000..18aefd093ee
--- /dev/null
+++ b/src/test/scala/org/broadinstitute/hail/methods/KeyTableSuite.scala
@@ -0,0 +1,176 @@
+package org.broadinstitute.hail.methods
+
+import org.broadinstitute.hail.SparkSuite
+import org.broadinstitute.hail.annotations._
+import org.broadinstitute.hail.driver._
+import org.broadinstitute.hail.expr._
+import org.broadinstitute.hail.keytable.KeyTable
+import org.broadinstitute.hail.utils._
+import org.testng.annotations.Test
+
+class KeyTableSuite extends SparkSuite {
+
+ @Test def testSingleToPairRDD() = {
+ val inputFile = "src/test/resources/sampleAnnotations.tsv"
+ val kt = KeyTable.importTextTable(sc, Array(inputFile), "Sample, Status", sc.defaultMinPartitions, TextTableConfiguration())
+ val kt2 = KeyTable(KeyTable.toSingleRDD(kt.rdd, kt.nKeys, kt.nValues), kt.signature, kt.keyNames)
+
+ assert(kt.rdd.fullOuterJoin(kt2.rdd).forall { case (k, (v1, v2)) =>
+ val res = v1 == v2
+ if (!res)
+ println(s"k=$k v1=$v1 v2=$v2 res=${ v1 == v2 }")
+ res
+ })
+ }
+
+ @Test def testImportExport() = {
+ val inputFile = "src/test/resources/sampleAnnotations.tsv"
+ val outputFile = tmpDir.createTempFile("ktImpExp", "tsv")
+ val kt = KeyTable.importTextTable(sc, Array(inputFile), "Sample, Status", sc.defaultMinPartitions, TextTableConfiguration())
+ kt.export(sc, outputFile, null)
+
+ val importedData = sc.hadoopConfiguration.readLines(inputFile)(_.map(_.value).toIndexedSeq)
+ val exportedData = sc.hadoopConfiguration.readLines(outputFile)(_.map(_.value).toIndexedSeq)
+
+ intercept[FatalException] {
+ val kt2 = KeyTable.importTextTable(sc, Array(inputFile), "Sample, Status, BadKeyName", sc.defaultMinPartitions, TextTableConfiguration())
+ }
+
+ assert(importedData == exportedData)
+ }
+
+ @Test def testAnnotate() = {
+ val inputFile = "src/test/resources/sampleAnnotations.tsv"
+ val kt1 = KeyTable.importTextTable(sc, Array(inputFile), "Sample", sc.defaultMinPartitions, TextTableConfiguration(impute = true))
+ val kt2 = kt1.annotate("""qPhen2 = pow(qPhen, 2), NotStatus = Status == "CASE", X = qPhen == 5""", kt1.keyNames.mkString(","))
+ val kt3 = kt2.annotate("", kt2.keyNames.mkString(","))
+ val kt4 = kt3.annotate("", "qPhen, NotStatus")
+
+ val kt1ValueNames = kt1.valueNames.toSet
+ val kt2ValueNames = kt2.valueNames.toSet
+
+ assert(kt1.nKeys == 1)
+ assert(kt2.nKeys == 1)
+ assert(kt1.nValues == 2 && kt2.nValues == 5)
+ assert(kt1.keySignature == kt2.keySignature)
+ assert(kt1ValueNames ++ Set("qPhen2", "NotStatus", "X") == kt2ValueNames)
+ assert(kt2 same kt3)
+
+ def getDataAsMap(kt: KeyTable) = {
+ val fieldNames = kt.fieldNames
+ val nFields = kt.nFields
+ KeyTable.toSingleRDD(kt.rdd, kt.nKeys, kt.nValues)
+ .map { a => fieldNames.zip(KeyTable.annotationToSeq(a, nFields)).toMap }.collect().toSet
+ }
+
+ val kt3data = getDataAsMap(kt3)
+ val kt4data = getDataAsMap(kt4)
+
+ assert(kt4.keyNames.toSet == Set("qPhen", "NotStatus") &&
+ kt4.valueNames.toSet == Set("qPhen2", "X", "Sample", "Status") &&
+ kt3data == kt4data
+ )
+ }
+
+ @Test def testFilter() = {
+ val data = Array(Array(5, 9, 0), Array(2, 3, 4), Array(1, 2, 3))
+ val rdd = sc.parallelize(data.map(Annotation.fromSeq(_)))
+ val signature = TStruct(("field1", TInt), ("field2", TInt), ("field3", TInt))
+ val keyNames = Array("field1")
+
+ val kt1 = KeyTable(rdd, signature, keyNames)
+ val kt2 = kt1.filter("field1 < 3", keep = true)
+ val kt3 = kt1.filter("field1 < 3 && field3 == 4", keep = true)
+ val kt4 = kt1.filter("field1 == 5 && field2 == 9 && field3 == 0", keep = false)
+ val kt5 = kt1.filter("field1 < -5 && field3 == 100", keep = true)
+
+ assert(kt1.nRows == 3 && kt2.nRows == 2 && kt3.nRows == 1 && kt4.nRows == 2 && kt5.nRows == 0)
+ }
+
+ @Test def testJoin() = {
+ val inputFile1 = "src/test/resources/sampleAnnotations.tsv"
+ val inputFile2 = "src/test/resources/sampleAnnotations2.tsv"
+
+ val ktLeft = KeyTable.importTextTable(sc, Array(inputFile1), "Sample", sc.defaultMinPartitions, TextTableConfiguration(impute = true))
+ val ktRight = KeyTable.importTextTable(sc, Array(inputFile2), "Sample", sc.defaultMinPartitions, TextTableConfiguration(impute = true))
+
+ val ktLeftJoin = ktLeft.leftJoin(ktRight)
+ val ktRightJoin = ktLeft.rightJoin(ktRight)
+ val ktInnerJoin = ktLeft.innerJoin(ktRight)
+ val ktOuterJoin = ktLeft.outerJoin(ktRight)
+
+ val nExpectedValues = ktLeft.nValues + ktRight.nValues
+
+ val (_, leftKeyQuery) = ktLeft.query("Sample")
+ val (_, rightKeyQuery) = ktRight.query("Sample")
+ val (_, leftJoinKeyQuery) = ktLeftJoin.query("Sample")
+ val (_, rightJoinKeyQuery) = ktRightJoin.query("Sample")
+
+ val leftKeys = ktLeft.rdd.map { case (k, v) => leftKeyQuery(k, v).map(_.asInstanceOf[String]) }.collect().toSet
+ val rightKeys = ktRight.rdd.map { case (k, v) => rightKeyQuery(k, v).map(_.asInstanceOf[String]) }.collect().toSet
+
+ val nIntersectRows = leftKeys.intersect(rightKeys).size
+ val nUnionRows = rightKeys.union(leftKeys).size
+ val nExpectedKeys = ktLeft.nKeys
+
+ assert(ktLeftJoin.nRows == ktLeft.nRows &&
+ ktLeftJoin.nKeys == nExpectedKeys &&
+ ktLeftJoin.nValues == nExpectedValues &&
+ ktLeftJoin.filter { case (k, v) =>
+ !rightKeys.contains(leftJoinKeyQuery(k, v).map(_.asInstanceOf[String]))
+ }.forall("isMissing(qPhen2) && isMissing(qPhen3)")
+ )
+
+ assert(ktRightJoin.nRows == ktRight.nRows &&
+ ktRightJoin.nKeys == nExpectedKeys &&
+ ktRightJoin.nValues == nExpectedValues &&
+ ktRightJoin.filter { case (k, v) =>
+ !leftKeys.contains(rightJoinKeyQuery(k, v).map(_.asInstanceOf[String]))
+ }.forall("isMissing(Status) && isMissing(qPhen)"))
+
+ assert(ktOuterJoin.nRows == nUnionRows &&
+ ktOuterJoin.nKeys == ktLeft.nKeys &&
+ ktOuterJoin.nValues == nExpectedValues)
+
+ assert(ktInnerJoin.nRows == nIntersectRows &&
+ ktInnerJoin.nKeys == nExpectedKeys &&
+ ktInnerJoin.nValues == nExpectedValues)
+ }
+
+ @Test def testAggregate() {
+ val data = Array(Array("Case", 9, 0), Array("Case", 3, 4), Array("Control", 2, 3), Array("Control", 1, 5))
+ val rdd = sc.parallelize(data.map(Annotation.fromSeq(_)))
+ val signature = TStruct(("field1", TString), ("field2", TInt), ("field3", TInt))
+ val keyNames = Array("field1")
+
+ val kt1 = KeyTable(rdd, signature, keyNames)
+ val kt2 = kt1.aggregate("Status = field1",
+ "A = field2.sum(), " +
+ "B = field2.map(f => field2).sum(), " +
+ "C = field2.map(f => field2 + field3).sum(), " +
+ "D = field2.count(), " +
+ "E = field2.filter(f => field2 == 3).count()"
+ )
+
+ val result = Array(Array("Case", 12.0, 12.0, 16.0, 2L, 1L), Array("Control", 3.0, 3.0, 11.0, 2L, 0L))
+ val resRDD = sc.parallelize(result.map(Annotation.fromSeq(_)))
+ val resSignature = TStruct(("Status", TString), ("A", TDouble), ("B", TDouble), ("C", TDouble), ("D", TLong), ("E", TLong))
+ val ktResult = KeyTable(resRDD, resSignature, keyNames = Array("Status"))
+
+ assert(kt2 same ktResult)
+ }
+
+ @Test def testForallExists() {
+ val data = Array(Array("Sample1", 9, 5), Array("Sample2", 3, 5), Array("Sample3", 2, 5), Array("Sample4", 1, 5))
+ val rdd = sc.parallelize(data.map(Annotation.fromSeq(_)))
+ val signature = TStruct(("Sample", TString), ("field1", TInt), ("field2", TInt))
+ val keyNames = Array("Sample")
+
+ val kt = KeyTable(rdd, signature, keyNames)
+ assert(kt.forall("field2 == 5 && field1 != 0"))
+ assert(!kt.forall("field2 == 0 && field1 == 5"))
+ assert(kt.exists("""Sample == "Sample1" && field1 == 9 && field2 == 5"""))
+ assert(!kt.exists("""Sample == "Sample1" && field1 == 13 && field2 == 2"""))
+ }
+
+}
diff --git a/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala b/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala
index 59a3fa40ac9..eeca50781e9 100644
--- a/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala
+++ b/src/test/scala/org/broadinstitute/hail/stats/StatsSuite.scala
@@ -1,7 +1,7 @@
package org.broadinstitute.hail.stats
import breeze.linalg.DenseMatrix
-import org.apache.commons.math3.distribution.ChiSquaredDistribution
+import org.apache.commons.math3.distribution.{ChiSquaredDistribution, NormalDistribution}
import org.broadinstitute.hail.utils._
import org.broadinstitute.hail.variant.Variant
import org.broadinstitute.hail.SparkSuite
@@ -21,6 +21,24 @@ class StatsSuite extends SparkSuite {
val chiSq5 = new ChiSquaredDistribution(5.2)
assert(D_==(chiSquaredTail(5.2, 1), 1 - chiSq5.cumulativeProbability(1)))
assert(D_==(chiSquaredTail(5.2, 5.52341), 1 - chiSq5.cumulativeProbability(5.52341)))
+
+ assert(D_==(inverseChiSquaredTailOneDF(.1), chiSq1.inverseCumulativeProbability(1 - .1)))
+ assert(D_==(inverseChiSquaredTailOneDF(.0001), chiSq1.inverseCumulativeProbability(1 - .0001)))
+
+ val a = List(.0000000001, .5, .9999999999, 1.0)
+ a.foreach(p => println(p, inverseChiSquaredTailOneDF(p)))
+ a.foreach(p => assert(D_==(chiSquaredTail(1.0, inverseChiSquaredTailOneDF(p)), p)))
+ }
+
+ @Test def normTest() = {
+ val normalDist = new NormalDistribution()
+ assert(D_==(pnorm(1), normalDist.cumulativeProbability(1)))
+ assert(D_==(pnorm(-10), normalDist.cumulativeProbability(-10)))
+ assert(D_==(qnorm(.6), normalDist.inverseCumulativeProbability(.6)))
+ assert(D_==(qnorm(.0001), normalDist.inverseCumulativeProbability(.0001)))
+
+ val a = List(0.0, .0000000001, .5, .9999999999, 1.0)
+ assert(a.forall(p => D_==(qnorm(pnorm(qnorm(p))), qnorm(p))))
}
@Test def vdsFromMatrixTest() {