Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions src/commands/meta_staar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,6 @@ pub fn build_config(
loci.display()
)));
}
if matches!(conditional_model, crate::cli::ConditionalModel::Heterogeneous) {
return Err(CohortError::Input(
"--conditional-model heterogeneous requires per-study U vectors \
which --emit-sumstats does not yet persist. Use --conditional-model \
homogeneous for now."
.into(),
));
}
}

let mask_categories = crate::commands::parse_mask_categories(&masks)?;
Expand Down
1 change: 1 addition & 0 deletions src/staar/carrier/sparse_score.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ pub(crate) fn null_model_from_analysis(analysis: &AnalysisVectors) -> NullModel
// a dense G and a non-mixed-model NullModel. Kinship-aware analyses
// route through `score_gene_sparse_kinship` directly.
kinship: None,
scang: None,
}
}

Expand Down
157 changes: 134 additions & 23 deletions src/staar/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,10 @@ pub fn merge_chromosome(
bool_or({ccre_p}) AS {ccre_p}, \
bool_or({ccre_e}) AS {ccre_e}, \
{weight_aggs}, \
CAST(array_agg(named_struct('s', study_idx, 'seg', segment_id)) AS VARCHAR) AS study_segs \
CAST(array_agg(named_struct('s', study_idx, 'seg', segment_id)) AS VARCHAR) AS study_segs, \
CAST(array_agg(named_struct('s', study_idx, \
'u', CASE WHEN {ref_a} <= {alt_a} THEN u_stat ELSE -u_stat END)) \
AS VARCHAR) AS study_us \
FROM _study_variants \
WHERE {maf} < {maf_cutoff} \
GROUP BY {pos}, \
Expand All @@ -428,7 +431,7 @@ pub fn merge_chromosome(
cadd_phred_raw, {revel}, {msp}, {gh}, \
{cage_p}, {cage_e}, {ccre_p}, {ccre_e}, \
{weight_select}, \
study_segs \
study_segs, study_us \
FROM _meta_variants ORDER BY {pos}",
pos = Col::Position,
ref_a = Col::RefAllele,
Expand Down Expand Up @@ -512,6 +515,7 @@ pub fn merge_chromosome(
w_arrs.push(f64_col(17 + i)?);
}
let segs_arr = str_col(28)?;
let study_u_arr = str_col(29)?;

for i in 0..n {
let mut weights = [0.0f64; 11];
Expand All @@ -521,6 +525,7 @@ pub fn merge_chromosome(
let cadd_phred = f64_or(cadd_raw_arr, i, 0.0);

let study_segments = parse_study_segments(str_or(segs_arr, i, ""));
let u_study = parse_study_us(str_or(study_u_arr, i, ""));

let mac_total = i64_or(mac_arr, i, 0);
let n_total = i64_or(n_obs_arr, i, 0);
Expand Down Expand Up @@ -559,6 +564,7 @@ pub fn merge_chromosome(
mac_total,
n_total,
study_segments,
u_study,
});
}
}
Expand Down Expand Up @@ -671,6 +677,7 @@ pub fn meta_score_gene(
n_variants: m as u32,
cumulative_mac: cmac as u32,
staar: sr,
emthr: f64::NAN,
},
burden_beta,
burden_se,
Expand Down Expand Up @@ -750,6 +757,34 @@ fn parse_study_segments(s: &str) -> Vec<(usize, i32)> {
result
}

/// Decode the stringified `array_agg(named_struct('s', .., 'u', ..))` column
/// into `(study_idx, signed_u)` pairs.
///
/// The text is cut at every `{`; each piece is stripped of surrounding
/// `[`, `]`, `,`, space and `}` characters and then read as comma-separated
/// `key: value` fields. Only the `s` (study index) and `u` (score) keys are
/// recognised, written either as `s:` / `u:` or as `s :` / `u :`. An entry
/// is emitted only when both values parse; anything else is dropped
/// silently, so malformed input yields fewer pairs rather than an error.
///
/// Same parsing strategy as `parse_study_segments`, except that the second
/// field is a float instead of an integer segment id.
fn parse_study_us(s: &str) -> Vec<(usize, f64)> {
    // Characters that wrap or separate struct entries in the rendered array.
    const WRAPPER: [char; 5] = ['[', ']', ',', ' ', '}'];

    s.split('{')
        .map(|chunk| chunk.trim_matches(|c: char| WRAPPER.contains(&c)))
        .filter(|chunk| !chunk.is_empty())
        .filter_map(|chunk| {
            let mut study_idx: Option<usize> = None;
            let mut signed_u: Option<f64> = None;

            for field in chunk.split(',').map(str::trim) {
                let as_study = field
                    .strip_prefix("s:")
                    .or_else(|| field.strip_prefix("s :"));
                let as_u = field
                    .strip_prefix("u:")
                    .or_else(|| field.strip_prefix("u :"));

                // A repeated key overwrites the earlier value, including
                // with `None` when the later occurrence fails to parse.
                match (as_study, as_u) {
                    (Some(raw), _) => study_idx = raw.trim().parse().ok(),
                    (None, Some(raw)) => signed_u = raw.trim().parse().ok(),
                    (None, None) => {}
                }
            }

            study_idx.zip(signed_u)
        })
        .collect()
}

const SEGMENT_BP: u32 = 500_000;
const MAX_SEGMENT_VARIANTS: usize = 2000;

Expand Down Expand Up @@ -1215,27 +1250,23 @@ pub fn parse_known_loci_file(
}

/// Conditional meta-analysis: condition gene-level U/K on known loci
/// before running STAAR tests.
///
/// Homogeneous model: condition the merged (cross-study) U and K.
/// Heterogeneous model: condition per-study U and K before merging.
///
/// The conditioning step uses Schur complement:
/// U_cond = U_t - K_tc * K_cc^{-1} * U_c
/// K_cond = K_tt - K_tc * K_cc^{-1} * K_ct
///
/// before running STAAR tests. Schur complement:
/// U_cond = U_t - K_tc · K_cc⁻¹ · U_c
/// K_cond = K_tt - K_tc · K_cc⁻¹ · K_ct
/// where t = test (gene) variants, c = conditioning (known loci) variants.
/// Conditional meta-scoring uses the homogeneous model: condition the
/// merged (cross-study) U and K on known-loci variants via Schur complement.
/// The heterogeneous model (per-study conditioning) is rejected at config
/// time because --emit-sumstats does not yet persist per-study U vectors.
///
/// Homogeneous model sums U and K across studies before the Schur solve.
/// Heterogeneous model does the Schur solve inside each study first and
/// sums per-study (U_cond_i, K_cond_i) at the end, matching MetaSTAAR
/// R/MetaSTAAR_merge_cond.R:391-404 and tolerating per-study variation in
/// covariate-adjusted residual variance.
pub fn meta_score_gene_conditional(
group: &MaskGroup,
meta_variants: &[MetaVariant],
studies: &[StudyHandle],
segment_cache: &HashMap<(usize, i32), SegmentCov>,
known_loci_indices: &[usize],
_heterogeneous: bool,
heterogeneous: bool,
) -> Option<MetaGeneResult> {
let gene_indices: Vec<usize> = group
.variant_indices
Expand Down Expand Up @@ -1264,20 +1295,13 @@ pub fn meta_score_gene_conditional(
let m_t = gene_indices.len();
let m_c = cond_indices.len();

// Build combined keys: [gene variants | conditioning variants].
let combined: Vec<usize> = gene_indices
.iter()
.chain(cond_indices.iter())
.copied()
.collect();
let m_all = combined.len();

// Homogeneous: condition merged U/K.
let mut u_all = Mat::zeros(m_all, 1);
for (local, &gi) in combined.iter().enumerate() {
u_all[(local, 0)] = meta_variants[gi].u_meta;
}

let keys: Vec<(u32, &str, &str)> = combined
.iter()
.map(|&gi| {
Expand All @@ -1289,6 +1313,46 @@ pub fn meta_score_gene_conditional(
})
.collect();

if heterogeneous {
// Mirrors MetaSTAAR R/MetaSTAAR_merge_cond.R:391-404.
// Each study conditions its own (U_i, K_i) via Schur complement;
// the per-study conditional pair is summed across studies. Studies
// with no overlap on a variant contribute zero U and zero K for
// that row / column, so they drop out of the Schur solve naturally.
let mut u_cond_sum = Mat::zeros(m_t, 1);
let mut k_cond_sum = Mat::zeros(m_t, m_t);

for study_idx in 0..studies.len() {
let (u_study, cov_study) = build_study_u_cov(
study_idx,
&combined,
meta_variants,
segment_cache,
&keys,
);

let (u_t_i, k_tt_i, u_c_i, k_cc_i, k_tc_i) =
partition_u_cov(&u_study, &cov_study, m_t, m_c);
let (u_cond_i, k_cond_i) =
schur_condition(&u_t_i, &k_tt_i, &u_c_i, &k_cc_i, &k_tc_i);

for i in 0..m_t {
u_cond_sum[(i, 0)] += u_cond_i[(i, 0)];
for j in 0..m_t {
k_cond_sum[(i, j)] += k_cond_i[(i, j)];
}
}
}

return finish_conditional(&gene_indices, meta_variants, &u_cond_sum, &k_cond_sum, group);
}

// Homogeneous: condition merged U/K.
let mut u_all = Mat::zeros(m_all, 1);
for (local, &gi) in combined.iter().enumerate() {
u_all[(local, 0)] = meta_variants[gi].u_meta;
}

let mut cov_all = Mat::zeros(m_all, m_all);
for study_idx in 0..studies.len() {
let mut needed_segments: std::collections::HashSet<i32> =
Expand Down Expand Up @@ -1318,6 +1382,50 @@ pub fn meta_score_gene_conditional(
finish_conditional(&gene_indices, meta_variants, &u_cond, &k_cond, group)
}

/// Assemble a single study's score vector and covariance matrix over the
/// combined variant list.
///
/// * `study_idx` - the study whose contributions are collected.
/// * `combined` - indices into `meta_variants`; position in this slice is
///   the row/column index of the outputs.
/// * `segment_cache` - covariance segments keyed by `(study_idx, segment_id)`.
/// * `keys` - variant keys forwarded to `extract_submatrix` for every
///   segment that is found.
///
/// Returns `(U, K)` where `U` is `combined.len() x 1` and `K` is
/// `combined.len() x combined.len()`. A variant with no `u_study` entry for
/// this study leaves its `U` row at zero, and a segment id missing from the
/// cache adds nothing to `K`.
fn build_study_u_cov(
    study_idx: usize,
    combined: &[usize],
    meta_variants: &[MetaVariant],
    segment_cache: &HashMap<(usize, i32), SegmentCov>,
    keys: &[(u32, &str, &str)],
) -> (Mat<f64>, Mat<f64>) {
    let dim = combined.len();

    // Score vector: take the first u_study entry recorded for this study.
    let mut u_vec = Mat::zeros(dim, 1);
    for (row, &variant_idx) in combined.iter().enumerate() {
        let hit = meta_variants[variant_idx]
            .u_study
            .iter()
            .find(|&&(owner, _)| owner == study_idx);
        if let Some(&(_, score)) = hit {
            u_vec[(row, 0)] = score;
        }
    }

    // Distinct segment ids this study contributes for any listed variant.
    let segment_ids: std::collections::HashSet<i32> = combined
        .iter()
        .flat_map(|&variant_idx| meta_variants[variant_idx].study_segments.iter())
        .filter(|&&(owner, _)| owner == study_idx)
        .map(|&(_, seg_id)| seg_id)
        .collect();

    // Covariance: add up the extracted sub-matrix of every cached segment.
    let mut k_mat = Mat::zeros(dim, dim);
    for seg_id in segment_ids {
        if let Some(segment) = segment_cache.get(&(study_idx, seg_id)) {
            let block = segment.extract_submatrix(keys);
            for row in 0..dim {
                for col in 0..dim {
                    k_mat[(row, col)] += block[(row, col)];
                }
            }
        }
    }

    (u_vec, k_mat)
}

/// Partition combined U and K into test (t) and conditioning (c) blocks.
#[allow(clippy::type_complexity)]
fn partition_u_cov(
Expand Down Expand Up @@ -1442,6 +1550,7 @@ fn finish_conditional(
n_variants: m_t as u32,
cumulative_mac: cmac as u32,
staar: sr,
emthr: f64::NAN,
},
burden_beta,
burden_se,
Expand Down Expand Up @@ -1562,6 +1671,7 @@ mod tests {
mac_total: (2.0 * mafs[i] * n as f64).round() as i64,
n_total: n as i64,
study_segments: vec![(0, 0)],
u_study: vec![(0, u[(i, 0)])],
})
.collect();

Expand Down Expand Up @@ -1674,6 +1784,7 @@ mod tests {
mac_total: (2.0 * mafs[i] * n as f64).round() as i64,
n_total: n as i64,
study_segments: vec![(0, 10), (1, 20)],
u_study: vec![(0, half_u[(i, 0)]), (1, half_u[(i, 0)])],
})
.collect();

Expand Down
6 changes: 6 additions & 0 deletions src/staar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub mod multi;
pub mod output;
pub mod pipeline;
pub mod run_manifest;
pub mod scang;
pub mod score;
pub mod scoring;
pub mod stats;
Expand Down Expand Up @@ -152,4 +153,9 @@ pub struct GeneResult {
pub n_variants: u32,
pub cumulative_mac: u32,
pub staar: score::StaarResult,
/// SCANG-O empirical −log10(p) threshold at α = 0.05, NaN otherwise.
/// Emitted alongside per-window p-values so operators can cross-check
/// `-log10(p) > emthr` matches the R `SCANG_O_res$th0` gate. See
/// `crate::staar::scang::chrom_threshold` and SCANG R/SCANG.r:181-205.
pub emthr: f64,
}
5 changes: 5 additions & 0 deletions src/staar/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,9 @@ pub struct NullModel {
/// Kinship-aware variance components and projection. Set when `--kinship`
/// is in use; the score path then dispatches to the kinship-aware kernel.
pub kinship: Option<KinshipState>,
/// SCANG-side state populated lazily once a run needs Monte Carlo
/// thresholds. See `crate::staar::scang::ScangExt`.
pub scang: Option<crate::staar::scang::ScangExt>,
}

impl NullModel {
Expand Down Expand Up @@ -603,6 +606,7 @@ pub fn fit_glm(y: &Mat<f64>, x: &Mat<f64>) -> NullModel {
fitted_values: None,
working_weights: None,
kinship: None,
scang: None,
}
}

Expand Down Expand Up @@ -704,6 +708,7 @@ pub fn fit_logistic(y: &Mat<f64>, x: &Mat<f64>, max_iter: usize) -> NullModel {
fitted_values: Some(fitted),
working_weights: Some(w_final),
kinship: None,
scang: None,
}
}

Expand Down
8 changes: 8 additions & 0 deletions src/staar/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ fn mask_results_schema(channels: &[&str]) -> Schema {
}
fields.push(Field::new("ACAT-O", DataType::Float64, true));
fields.push(Field::new("STAAR-O", DataType::Float64, true));
// SCANG-only empirical −log10(p) threshold; NaN elsewhere. Stays on
// the shared schema so downstream readers never have to branch on
// mask type to know whether the column exists.
fields.push(Field::new("emthr", DataType::Float64, true));
Schema::new(fields)
}

Expand All @@ -308,6 +312,7 @@ fn build_mask_columns(sorted: &[&GeneResult], n_channels: usize) -> Vec<ArrayRef
.collect();
let mut b_acat_o = Float64Builder::with_capacity(nr);
let mut b_staar_o = Float64Builder::with_capacity(nr);
let mut b_emthr = Float64Builder::with_capacity(nr);

for r in sorted {
let s = &r.staar;
Expand Down Expand Up @@ -354,6 +359,7 @@ fn build_mask_columns(sorted: &[&GeneResult], n_channels: usize) -> Vec<ArrayRef
}
b_acat_o.append_value(s.acat_o);
b_staar_o.append_value(s.staar_o);
b_emthr.append_value(r.emthr);
}

let mut columns: Vec<ArrayRef> = vec![
Expand All @@ -370,6 +376,7 @@ fn build_mask_columns(sorted: &[&GeneResult], n_channels: usize) -> Vec<ArrayRef
}
columns.push(Arc::new(b_acat_o.finish()));
columns.push(Arc::new(b_staar_o.finish()));
columns.push(Arc::new(b_emthr.finish()));
columns
}

Expand Down Expand Up @@ -1056,6 +1063,7 @@ mod tests {
acat_o: p,
staar_o: p,
},
emthr: f64::NAN,
}
}

Expand Down
Loading
Loading