diff --git a/src/cli.rs b/src/cli.rs index b21883b..35cbf5b 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -315,6 +315,48 @@ pub enum Command { output: Option, }, + /// Forward-selection LD pruning on conditional score-test p-values + #[command(name = "ld-prune")] + LdPrune { + /// Pre-built cohort id (under the store root). + #[arg(long)] + cohort: String, + + /// Phenotype file (TSV with sample_id as first column) + #[arg(long)] + phenotype: PathBuf, + + /// Trait column name in the phenotype file + #[arg(long)] + trait_name: String, + + /// Covariate columns (comma-separated, e.g. age,sex,PC1,PC2) + #[arg(long, value_delimiter = ',')] + covariates: Vec, + + /// Candidate variants file. Tab-delimited or colon-delimited with + /// four fields per row: CHR POS REF ALT. `#`-prefixed lines skipped. + #[arg(long)] + variants: PathBuf, + + /// Minor allele frequency floor for candidates (default 0.01) + #[arg(long, default_value = "0.01")] + maf_cutoff: f64, + + /// Conditional p-value threshold at which forward selection stops + /// (default 1e-4, matches STAARpipeline LD_pruning). + #[arg(long, default_value = "1e-4")] + cond_p_thresh: f64, + + /// Column name mapping for phenotype file (key=value pairs) + #[arg(long, value_delimiter = ',')] + column_map: Vec, + + /// Output TSV path (default: .ld_pruned.tsv) + #[arg(short, long)] + output: Option, + }, + /// Meta-analysis of STAAR across studies (MetaSTAAR) #[command(name = "meta-staar")] MetaStaar { diff --git a/src/commands/ld_prune.rs b/src/commands/ld_prune.rs new file mode 100644 index 0000000..d8a9da2 --- /dev/null +++ b/src/commands/ld_prune.rs @@ -0,0 +1,365 @@ +//! `favor ld-prune` subcommand. +//! +//! Forward-selection LD pruning on conditional score-test p-values. +//! Mirrors STAARpipeline R/LD_pruning.R for the gaussian, unrelated, +//! single-trait path. + +use std::collections::HashMap; +use std::io::Write; +use std::path::PathBuf; + +use serde_json::json; + +use crate::commands; +use crate::error::CohortError; +use crate::output::Output; +use crate::runtime::Engine; +use crate::staar::ld_prune::{self, Candidate, KeptVariant, LdPruneParams}; +use crate::staar::model::load_phenotype; +use crate::store::cohort::CohortId; +use crate::types::Chromosome; + +const GB: u64 = 1024 * 1024 * 1024; + +pub struct LdPruneArgs { + pub cohort: String, + pub phenotype: PathBuf, + pub trait_name: String, + pub covariates: Vec, + pub variants: PathBuf, + pub maf_cutoff: f64, + pub cond_p_thresh: f64, + pub column_map: Vec, + pub output: Option, +} + +pub struct LdPruneConfig { + pub cohort_id: CohortId, + pub phenotype: PathBuf, + pub trait_name: String, + pub covariates: Vec, + pub variants: PathBuf, + pub maf_cutoff: f64, + pub cond_p_thresh: f64, + pub column_map: HashMap, + pub output: PathBuf, +} + +pub fn run( + engine: &Engine, + args: LdPruneArgs, + out: &dyn Output, + dry_run: bool, +) -> Result<(), CohortError> { + let config = build_config(args)?; + if dry_run { + return emit_dry_run(&config, out); + } + run_ld_prune(engine, &config, out) +} + +fn build_config(args: LdPruneArgs) -> Result { + if !args.phenotype.exists() { + return Err(CohortError::Input(format!( + "Phenotype file not found: '{}'", + args.phenotype.display() + ))); + } + if !args.variants.exists() { + return Err(CohortError::Input(format!( + "Variants file not found: '{}'", + args.variants.display() + ))); + } + if args.trait_name.trim().is_empty() { + return Err(CohortError::Input("--trait-name is required".into())); + } + if args.maf_cutoff <= 0.0 || args.maf_cutoff >= 0.5 { + return Err(CohortError::Input(format!( + "MAF cutoff must be in (0, 0.5), got {}", + args.maf_cutoff + ))); + } + if args.cond_p_thresh <= 0.0 || args.cond_p_thresh >= 1.0 { + return Err(CohortError::Input(format!( + "--cond-p-thresh must be in (0, 1), got {}", + args.cond_p_thresh + ))); + } + + let cohort_id = CohortId::new(args.cohort.trim().to_string()); + let output = args.output.unwrap_or_else(|| { + PathBuf::from(format!("{}.ld_pruned.tsv", cohort_id.as_str())) + }); + + let mut column_map = HashMap::new(); + for entry in &args.column_map { + let (k, v) = entry.split_once('=').ok_or_else(|| { + CohortError::Input(format!( + "Invalid --column-map entry '{entry}'. Expected key=value." + )) + })?; + column_map.insert(k.trim().to_string(), v.trim().to_string()); + } + + Ok(LdPruneConfig { + cohort_id, + phenotype: args.phenotype, + trait_name: args.trait_name, + covariates: args.covariates, + variants: args.variants, + maf_cutoff: args.maf_cutoff, + cond_p_thresh: args.cond_p_thresh, + column_map, + output, + }) +} + +fn emit_dry_run(config: &LdPruneConfig, out: &dyn Output) -> Result<(), CohortError> { + let n_candidates = std::fs::read_to_string(&config.variants) + .map(|s| s.lines().filter(|l| !l.is_empty() && !l.starts_with('#')).count()) + .unwrap_or(0); + let plan = commands::DryRunPlan { + command: "ld-prune".into(), + inputs: json!({ + "cohort_id": config.cohort_id.as_str(), + "phenotype": config.phenotype.to_string_lossy(), + "trait": config.trait_name, + "covariates": config.covariates, + "variants": config.variants.to_string_lossy(), + "n_candidates": n_candidates, + "maf_cutoff": config.maf_cutoff, + "cond_p_thresh": config.cond_p_thresh, + }), + memory: commands::MemoryEstimate { + minimum: "4G".into(), + recommended: "8G".into(), + minimum_bytes: 4 * GB, + recommended_bytes: 8 * GB, + }, + runtime: None, + output_path: config.output.to_string_lossy().into(), + }; + commands::emit(&plan, out); + Ok(()) +} + +fn run_ld_prune( + engine: &Engine, + config: &LdPruneConfig, + out: &dyn Output, +) -> Result<(), CohortError> { + let cohort = engine.cohort(&config.cohort_id); + let store = cohort.load()?; + + let pheno = load_phenotype( + engine.df(), + &config.phenotype, + &config.covariates, + &store.geno, + std::slice::from_ref(&config.trait_name), + None, + 5, + 0, + &config.column_map, + out, + )?; + if !matches!(pheno.trait_type, crate::staar::TraitType::Continuous) { + return Err(CohortError::Input( + "ld-prune currently supports continuous traits only".into(), + )); + } + + let candidates_by_chrom = parse_candidates(&config.variants)?; + if candidates_by_chrom.is_empty() { + return Err(CohortError::Input(format!( + "No variants parsed from '{}'", + config.variants.display() + ))); + } + + out.status(&format!( + "ld-prune: {} candidate variants across {} chromosome(s)", + candidates_by_chrom.values().map(|v| v.len()).sum::(), + candidates_by_chrom.len() + )); + + let params = LdPruneParams { + maf_cutoff: config.maf_cutoff, + cond_p_thresh: config.cond_p_thresh, + }; + + let mut kept_all: Vec = Vec::new(); + for (chrom, cands) in &candidates_by_chrom { + let view = match cohort.chromosome(chrom) { + Ok(v) => v, + Err(e) => { + out.warn(&format!(" chr{}: skipped ({e})", chrom.label())); + continue; + } + }; + let kept = ld_prune::ld_prune_chromosome( + &view, + *chrom, + &pheno.y, + &pheno.x, + &pheno.pheno_mask, + cands, + ¶ms, + )?; + out.status(&format!( + " chr{}: {} / {} variants kept", + chrom.label(), + kept.len(), + cands.len(), + )); + kept_all.extend(kept); + } + + write_output(&config.output, &kept_all)?; + out.success(&format!( + "ld-prune: {} variants retained -> {}", + kept_all.len(), + config.output.display() + )); + out.result_json(&json!({ + "status": "ok", + "output_path": config.output.to_string_lossy(), + "n_kept": kept_all.len(), + })); + Ok(()) +} + +fn parse_candidates( + path: &std::path::Path, +) -> Result>, CohortError> { + let content = std::fs::read_to_string(path) + .map_err(|e| CohortError::Resource(format!("read {}: {e}", path.display())))?; + + let mut by_chrom: std::collections::BTreeMap> = + std::collections::BTreeMap::new(); + + for (lineno, raw) in content.lines().enumerate() { + let line = raw.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + let parts: Vec<&str> = if line.contains('\t') { + line.split('\t').collect() + } else { + line.split(':').collect() + }; + if parts.len() < 4 { + return Err(CohortError::Input(format!( + "variants file {}:{}: expected CHR{{:\\t}}POS{{:\\t}}REF{{:\\t}}ALT, got '{raw}'", + path.display(), + lineno + 1 + ))); + } + let chrom: Chromosome = parts[0].parse().map_err(|e: String| { + CohortError::Input(format!( + "variants file {}:{}: {e}", + path.display(), + lineno + 1 + )) + })?; + let position: u32 = parts[1].parse().map_err(|e| { + CohortError::Input(format!( + "variants file {}:{}: bad position '{}': {e}", + path.display(), + lineno + 1, + parts[1] + )) + })?; + by_chrom.entry(chrom).or_default().push(Candidate { + position, + ref_allele: parts[2].to_string(), + alt_allele: parts[3].to_string(), + }); + } + + Ok(by_chrom) +} + +fn write_output(path: &std::path::Path, kept: &[KeptVariant]) -> Result<(), CohortError> { + if let Some(parent) = path.parent() { + if !parent.as_os_str().is_empty() { + std::fs::create_dir_all(parent).map_err(|e| { + CohortError::Resource(format!("create {}: {e}", parent.display())) + })?; + } + } + let mut f = std::fs::File::create(path) + .map_err(|e| CohortError::Resource(format!("create {}: {e}", path.display())))?; + writeln!(f, "CHR\tPOS\tREF\tALT\tentry_log10p") + .map_err(|e| CohortError::Resource(format!("write {}: {e}", path.display())))?; + for v in kept { + writeln!( + f, + "{}\t{}\t{}\t{}\t{}", + v.chromosome.label(), + v.position, + v.ref_allele, + v.alt_allele, + format_log10p(v.entry_log10p), + ) + .map_err(|e| CohortError::Resource(format!("write {}: {e}", path.display())))?; + } + Ok(()) +} + +fn format_log10p(lp: f64) -> String { + if lp.is_infinite() { + "Inf".into() + } else { + format!("{:.6}", lp) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_candidates_colon_and_tsv() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("candidates.txt"); + std::fs::write( + &path, + "# header\n1:100:A:T\n1\t200\tG\tC\n2:300:C:A\n\n", + ) + .unwrap(); + let map = parse_candidates(&path).unwrap(); + assert_eq!(map.get(&Chromosome::Autosome(1)).unwrap().len(), 2); + assert_eq!(map.get(&Chromosome::Autosome(2)).unwrap().len(), 1); + } + + #[test] + fn parse_candidates_rejects_short_rows() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("bad.txt"); + std::fs::write(&path, "1:100:A\n").unwrap(); + assert!(parse_candidates(&path).is_err()); + } + + #[test] + fn build_config_rejects_bad_maf() { + let dir = tempfile::tempdir().unwrap(); + let pheno = dir.path().join("pheno.tsv"); + let vars = dir.path().join("vars.tsv"); + std::fs::write(&pheno, "id\ttrait\n").unwrap(); + std::fs::write(&vars, "1:100:A:T\n").unwrap(); + let args = LdPruneArgs { + cohort: "c1".into(), + phenotype: pheno.clone(), + trait_name: "trait".into(), + covariates: vec![], + variants: vars.clone(), + maf_cutoff: 0.8, + cond_p_thresh: 1e-4, + column_map: vec![], + output: None, + }; + assert!(matches!(build_config(args), Err(CohortError::Input(_)))); + } +} diff --git a/src/commands/mod.rs b/src/commands/mod.rs index e9ca2b7..011921f 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -4,6 +4,7 @@ pub mod enrich; pub mod ingest; pub mod inspect; pub mod interpret; +pub mod ld_prune; pub mod meta_staar; pub mod staar; pub mod store; diff --git a/src/main.rs b/src/main.rs index 5e8d858..ebb430b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -199,6 +199,35 @@ fn run( dry_run, ) } + Command::LdPrune { + cohort, + phenotype, + trait_name, + covariates, + variants, + maf_cutoff, + cond_p_thresh, + column_map, + output: output_path, + } => { + let engine = runtime::Engine::open(store_path)?; + commands::ld_prune::run( + &engine, + commands::ld_prune::LdPruneArgs { + cohort, + phenotype, + trait_name, + covariates, + variants, + maf_cutoff, + cond_p_thresh, + column_map, + output: output_path, + }, + out, + dry_run, + ) + } Command::MetaStaar { studies, masks, diff --git a/src/staar/ld_prune.rs b/src/staar/ld_prune.rs new file mode 100644 index 0000000..34640e5 --- /dev/null +++ b/src/staar/ld_prune.rs @@ -0,0 +1,334 @@ +//! LD pruning via sequential conditional analysis. +//! +//! Forward selection on conditional p-values, not r² correlations. Starts +//! with the most significant marginal variant, then at each step refits +//! the null on `[X, G_known]` and picks the candidate with the smallest +//! conditional p-value. Stops when no remaining candidate beats +//! `cond_p_thresh`. +//! +//! Mirrors STAARpipeline R/LD_pruning.R lines 146–185 on the gaussian, +//! unrelated, single-trait path. + +use faer::Mat; + +use crate::error::CohortError; +use crate::staar::carrier::sparse_score::{carriers_to_dense_compact, individual_score_test}; +use crate::staar::carrier::AnalysisVectors; +use crate::staar::model::{augment_covariates, fit_glm}; +use crate::store::cohort::types::VariantVcf; +use crate::store::cohort::variants::CarrierList; +use crate::store::cohort::ChromosomeView; +use crate::types::Chromosome; + +#[derive(Clone, Debug)] +pub struct Candidate { + pub position: u32, + pub ref_allele: String, + pub alt_allele: String, +} + +#[derive(Clone, Debug)] +pub struct KeptVariant { + pub chromosome: Chromosome, + pub position: u32, + pub ref_allele: String, + pub alt_allele: String, + /// −log10 of the p-value under which this variant entered the pruned + /// set. The first pick carries its marginal p; later picks carry their + /// conditional p at the time of selection. + pub entry_log10p: f64, +} + +pub struct LdPruneParams { + pub maf_cutoff: f64, + pub cond_p_thresh: f64, +} + +/// Resolve candidates to cohort-indexed carriers, MAF-filter, and hand off +/// to `ld_prune_from_carriers`. Keeps disk IO and math isolated so tests +/// can exercise the loop with synthetic carrier lists. +pub fn ld_prune_chromosome( + view: &ChromosomeView<'_>, + chromosome: Chromosome, + y: &Mat, + x_base: &Mat, + pheno_mask: &[bool], + candidates: &[Candidate], + params: &LdPruneParams, +) -> Result, CohortError> { + if candidates.is_empty() { + return Ok(Vec::new()); + } + + let index = view.index()?; + let all_entries = index.all_entries(); + + let mut matched: Vec<(u32, Candidate)> = Vec::with_capacity(candidates.len()); + for c in candidates { + for (i, e) in all_entries.iter().enumerate() { + if e.position == c.position + && e.ref_allele.as_ref() == c.ref_allele.as_str() + && e.alt_allele.as_ref() == c.alt_allele.as_str() + && e.maf > params.maf_cutoff + { + matched.push((i as u32, c.clone())); + break; + } + } + } + if matched.is_empty() { + return Ok(Vec::new()); + } + + matched.sort_by_key(|(v, _)| *v); + let sorted_vcfs: Vec = matched.iter().map(|(v, _)| VariantVcf(*v)).collect(); + let sorted_candidates: Vec = matched.into_iter().map(|(_, c)| c).collect(); + let carriers: Vec = view.carriers_batch(&sorted_vcfs)?.entries; + + ld_prune_from_carriers( + chromosome, + &sorted_candidates, + &carriers, + y, + x_base, + pheno_mask, + params.cond_p_thresh, + ) +} + +/// Core forward-selection loop over pre-loaded carriers. `candidates[i]` +/// identifies the variant that produced `carriers[i]`; vectors must have +/// the same length. +pub fn ld_prune_from_carriers( + chromosome: Chromosome, + candidates: &[Candidate], + carriers: &[CarrierList], + y: &Mat, + x_base: &Mat, + pheno_mask: &[bool], + cond_p_thresh: f64, +) -> Result, CohortError> { + assert_eq!(candidates.len(), carriers.len()); + let m = carriers.len(); + + if m == 0 { + return Ok(Vec::new()); + } + + let null0 = fit_glm(y, x_base); + let analysis0 = AnalysisVectors::from_null_model(&null0, pheno_mask)?; + + if m == 1 { + let s = individual_score_test(&carriers[0], &analysis0); + let c = &candidates[0]; + return Ok(vec![KeptVariant { + chromosome, + position: c.position, + ref_allele: c.ref_allele.clone(), + alt_allele: c.alt_allele.clone(), + entry_log10p: safe_log10p(s.pvalue), + }]); + } + + let mut best = 0usize; + let mut best_log10p = f64::NEG_INFINITY; + for (i, carrier) in carriers.iter().enumerate() { + let s = individual_score_test(carrier, &analysis0); + let lp = safe_log10p(s.pvalue); + if lp > best_log10p { + best_log10p = lp; + best = i; + } + } + + let mut known: Vec = vec![best]; + let mut kept: Vec = Vec::with_capacity(4); + kept.push(kept_variant(chromosome, &candidates[best], best_log10p)); + + let cond_log10_thresh = -cond_p_thresh.log10(); + + loop { + if known.len() == m { + break; + } + + let known_carriers: Vec = + known.iter().map(|&i| carriers[i].clone()).collect(); + let g_known = carriers_to_dense_compact( + &known_carriers, + &analysis0.vcf_to_pheno, + analysis0.n_pheno, + ); + let x_cond = augment_covariates(x_base, &g_known); + let null_cond = fit_glm(y, &x_cond); + let analysis_cond = AnalysisVectors::from_null_model(&null_cond, pheno_mask)?; + + let mut pick: Option = None; + let mut pick_log10p = cond_log10_thresh; + for (i, carrier) in carriers.iter().enumerate() { + if known.contains(&i) { + continue; + } + let s = individual_score_test(carrier, &analysis_cond); + let lp = safe_log10p(s.pvalue); + if lp > pick_log10p { + pick_log10p = lp; + pick = Some(i); + } + } + + let Some(idx) = pick else { break }; + known.push(idx); + kept.push(kept_variant(chromosome, &candidates[idx], pick_log10p)); + } + + Ok(kept) +} + +fn kept_variant(chromosome: Chromosome, c: &Candidate, entry_log10p: f64) -> KeptVariant { + KeptVariant { + chromosome, + position: c.position, + ref_allele: c.ref_allele.clone(), + alt_allele: c.alt_allele.clone(), + entry_log10p, + } +} + +#[inline] +fn safe_log10p(p: f64) -> f64 { + if p > 0.0 { + -p.log10() + } else { + f64::INFINITY + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::store::cohort::variants::{CarrierEntry, CarrierList}; + + fn dense_carrier(n: usize, dosages: impl Fn(usize) -> u8) -> CarrierList { + let entries: Vec = (0..n) + .map(|i| CarrierEntry { + sample_idx: i as u32, + dosage: dosages(i), + }) + .filter(|e| e.dosage != 0) + .collect(); + CarrierList { entries } + } + + fn fake_candidate(pos: u32) -> Candidate { + Candidate { + position: pos, + ref_allele: "A".into(), + alt_allele: "T".into(), + } + } + + #[test] + fn single_variant_returns_one_row() { + let n = 200; + let pheno_mask = vec![true; n]; + let y = Mat::from_fn(n, 1, |i, _| (i as f64) * 0.01); + let x = Mat::from_fn(n, 1, |_, _| 1.0); + let carriers = vec![dense_carrier(n, |i| if i < 10 { 1 } else { 0 })]; + let cands = vec![fake_candidate(100)]; + + let kept = ld_prune_from_carriers( + Chromosome::Autosome(1), + &cands, + &carriers, + &y, + &x, + &pheno_mask, + 1e-4, + ) + .unwrap(); + assert_eq!(kept.len(), 1); + assert_eq!(kept[0].position, 100); + } + + #[test] + fn perfectly_collinear_variant_is_pruned() { + // Two variants carry identical genotype; the partner should drop + // out on the conditional pass because its signal is already + // absorbed by the first pick. + let n = 400; + let pheno_mask = vec![true; n]; + // y correlates with carrier pattern so marginal p is tiny. + let y = Mat::from_fn(n, 1, |i, _| if i < 40 { 2.0 } else { 0.0 }); + let x = Mat::from_fn(n, 1, |_, _| 1.0); + let pattern = |i: usize| if i < 40 { 1u8 } else { 0u8 }; + let carriers = vec![ + dense_carrier(n, pattern), + dense_carrier(n, pattern), + ]; + let cands = vec![fake_candidate(100), fake_candidate(200)]; + + let kept = ld_prune_from_carriers( + Chromosome::Autosome(1), + &cands, + &carriers, + &y, + &x, + &pheno_mask, + 1e-4, + ) + .unwrap(); + assert_eq!(kept.len(), 1, "collinear partner must not be kept"); + } + + #[test] + fn independent_signals_both_kept() { + // Two orthogonal carrier patterns on two disjoint halves of y. + let n = 400; + let pheno_mask = vec![true; n]; + let mut y_vals = vec![0.0_f64; n]; + for v in y_vals.iter_mut().take(20) { + *v = 3.0; + } + for v in y_vals.iter_mut().take(220).skip(200) { + *v = 3.0; + } + let y = Mat::from_fn(n, 1, |i, _| y_vals[i]); + let x = Mat::from_fn(n, 1, |_, _| 1.0); + let carriers = vec![ + dense_carrier(n, |i| if i < 20 { 1 } else { 0 }), + dense_carrier(n, |i| if (200..220).contains(&i) { 1 } else { 0 }), + ]; + let cands = vec![fake_candidate(100), fake_candidate(500)]; + + let kept = ld_prune_from_carriers( + Chromosome::Autosome(1), + &cands, + &carriers, + &y, + &x, + &pheno_mask, + 1e-4, + ) + .unwrap(); + assert_eq!(kept.len(), 2, "orthogonal signals must both survive"); + } + + #[test] + fn empty_input_returns_empty() { + let n = 10; + let y = Mat::::zeros(n, 1); + let x = Mat::from_fn(n, 1, |_, _| 1.0); + let kept = ld_prune_from_carriers( + Chromosome::Autosome(1), + &[], + &[], + &y, + &x, + &vec![true; n], + 1e-4, + ) + .unwrap(); + assert!(kept.is_empty()); + } +} diff --git a/src/staar/mod.rs b/src/staar/mod.rs index 3fef1d4..fc922a7 100644 --- a/src/staar/mod.rs +++ b/src/staar/mod.rs @@ -6,6 +6,7 @@ mod ground_truth_test; #[cfg(test)] mod invariance_test; pub mod kinship; +pub mod ld_prune; pub mod masks; pub mod meta; pub mod model;