From 770e6d0365deca2ea71e3b2da4b5d1e8ddc523fa Mon Sep 17 00:00:00 2001 From: "Z.-L. Deng" Date: Sat, 25 Oct 2025 23:16:29 +0200 Subject: [PATCH 1/2] Add regex selectors and membership operators --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/common.rs | 189 ++++++++++++++++++++++++++++++++- src/cut.rs | 31 ++++-- src/expression.rs | 263 ++++++++++++++++++++++++++++++++++++++++++---- src/summarize.rs | 34 ++++-- 6 files changed, 479 insertions(+), 42 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6d80c91..3ac7c72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -603,7 +603,7 @@ dependencies = [ [[package]] name = "tsvkit" -version = "0.9.5" +version = "0.9.6" dependencies = [ "anyhow", "calamine", diff --git a/Cargo.toml b/Cargo.toml index 9be5e12..6be690c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tsvkit" -version = "0.9.5" +version = "0.9.6" edition = "2024" [dependencies] diff --git a/src/common.rs b/src/common.rs index ffbed4c..3af34f8 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::fs::File; use std::io::{self, BufReader}; use std::path::Path; @@ -5,6 +6,7 @@ use std::path::Path; use anyhow::{Context, Result, anyhow, bail}; use csv::ReaderBuilder; use flate2::read::MultiGzDecoder; +use regex::Regex; use xz2::read::XzDecoder; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -27,6 +29,7 @@ pub enum ColumnSelector { Index(usize), FromEnd(usize), Name(String), + Regex(String), Range(Option>, Option>), Special(SpecialColumn), } @@ -99,7 +102,24 @@ pub fn resolve_selectors( selectors: &[ColumnSelector], no_header: bool, ) -> Result> { - let mut indices = Vec::with_capacity(selectors.len()); + resolve_selectors_with_options(headers, selectors, no_header, false) +} + +pub fn resolve_selectors_allow_duplicates( + headers: &[String], + selectors: &[ColumnSelector], + no_header: bool, +) -> Result> { + resolve_selectors_with_options(headers, selectors, no_header, true) +} + +fn 
resolve_selectors_with_options( + headers: &[String], + selectors: &[ColumnSelector], + no_header: bool, + allow_duplicates: bool, +) -> Result> { + let mut indices = Vec::new(); for selector in selectors { match selector { ColumnSelector::Special(special) => { @@ -108,9 +128,13 @@ pub fn resolve_selectors( special.default_header() ); } - ColumnSelector::Index(_) | ColumnSelector::FromEnd(_) | ColumnSelector::Name(_) => { - let index = resolve_selector_index(headers, selector, no_header)?; - indices.push(index); + ColumnSelector::Index(_) + | ColumnSelector::FromEnd(_) + | ColumnSelector::Name(_) + | ColumnSelector::Regex(_) => { + let mut resolved = + resolve_selector_indices(headers, selector, no_header, allow_duplicates)?; + indices.append(&mut resolved); } ColumnSelector::Range(start, end) => { if headers.is_empty() { @@ -282,6 +306,9 @@ fn parse_simple_selector(token: &str) -> Result { if token.is_empty() { return Err(anyhow!("empty column selector")); } + if let Some(regex) = parse_regex_literal(token)? { + return Ok(ColumnSelector::Regex(regex)); + } if let Some(literal) = parse_backtick_literal(token)? 
{ return Ok(ColumnSelector::Name(literal)); } @@ -314,6 +341,46 @@ fn parse_simple_selector(token: &str) -> Result { Ok(ColumnSelector::Name(token.to_string())) } +fn parse_regex_literal(token: &str) -> Result> { + let trimmed = token.trim(); + if !trimmed.starts_with('~') { + return Ok(None); + } + let remainder = trimmed[1..].trim_start(); + let mut chars = remainder.chars(); + match chars.next() { + Some('"') => { + let mut value = String::new(); + let mut escaped = false; + while let Some(ch) = chars.next() { + if escaped { + value.push(ch); + escaped = false; + continue; + } + match ch { + '\\' => { + escaped = true; + } + '"' => { + if !chars.as_str().is_empty() { + bail!("unexpected trailing characters after regex selector literal"); + } + return Ok(Some(value)); + } + other => value.push(other), + } + } + bail!("unterminated regex selector literal"); + } + Some(other) => bail!( + "regex column selector must use double quotes (e.g. ~\"pattern\"), got '{}'", + other + ), + None => bail!("regex column selector requires a quoted pattern"), + } +} + fn parse_backtick_literal(token: &str) -> Result> { let trimmed = token.trim(); if !trimmed.starts_with('`') { @@ -521,6 +588,89 @@ fn tokenize_selector_spec(spec: &str) -> Result> { Ok(tokens) } +fn resolve_selector_indices( + headers: &[String], + selector: &ColumnSelector, + no_header: bool, + allow_duplicates: bool, +) -> Result> { + match selector { + ColumnSelector::Index(idx) => { + let index = *idx; + if index >= headers.len() { + bail!( + "column index {} out of range ({} columns)", + index + 1, + headers.len() + ); + } + Ok(vec![index]) + } + ColumnSelector::FromEnd(offset) => { + let offset = *offset; + if offset == 0 { + bail!("column selector '-0' is not allowed"); + } + if offset > headers.len() { + bail!( + "column selector '-{}' out of range ({} columns)", + offset, + headers.len() + ); + } + Ok(vec![headers.len() - offset]) + } + ColumnSelector::Name(name) => { + if no_header { + bail!("column names 
cannot be used when input lacks a header row"); + } + if allow_duplicates { + let mut matches = Vec::new(); + for (idx, header) in headers.iter().enumerate() { + if header == name { + matches.push(idx); + } + } + if matches.is_empty() { + bail!("column '{}' not found", name); + } + Ok(matches) + } else { + let index = headers + .iter() + .position(|h| h == name) + .with_context(|| format!("column '{}' not found", name))?; + Ok(vec![index]) + } + } + ColumnSelector::Regex(pattern) => { + if no_header { + bail!("regex column selectors require headers"); + } + let regex = Regex::new(pattern) + .with_context(|| format!("invalid regex pattern '{}'", pattern))?; + let mut seen = HashSet::new(); + let mut matches = Vec::new(); + for (idx, header) in headers.iter().enumerate() { + if regex.is_match(header) { + if allow_duplicates || seen.insert(header.clone()) { + matches.push(idx); + } + } + } + if matches.is_empty() { + bail!("regex pattern '{}' did not match any columns", pattern); + } + Ok(matches) + } + ColumnSelector::Range(_, _) => unreachable!("range selectors handled separately"), + ColumnSelector::Special(special) => bail!( + "special column '{}' not supported without column injection", + special.default_header() + ), + } +} + fn resolve_selector_index( headers: &[String], selector: &ColumnSelector, @@ -562,6 +712,9 @@ fn resolve_selector_index( .with_context(|| format!("column '{}' not found", name))?; Ok(index) } + ColumnSelector::Regex(_) => { + bail!("regex column selectors cannot be used in range endpoints") + } ColumnSelector::Special(special) => bail!( "special column '{}' not supported without column injection", special.default_header() @@ -576,7 +729,7 @@ fn resolve_selector_index( mod tests { use super::{ ColumnSelector, SpecialColumn, parse_selector_list, parse_single_selector, - resolve_selectors, + resolve_selectors, resolve_selectors_allow_duplicates, }; #[test] @@ -702,4 +855,30 @@ mod tests { assert!(matches!(selectors[1], ColumnSelector::Name(ref 
name) if name == "__file__")); assert!(matches!(selectors[2], ColumnSelector::Name(ref name) if name == "__base__")); } + + #[test] + fn regex_selector_matches_columns() { + let headers = vec![ + "sample_a".to_string(), + "other".to_string(), + "sample_b".to_string(), + ]; + let selectors = parse_selector_list("~\"^sample_\"").unwrap(); + let indices = resolve_selectors(&headers, &selectors, false).unwrap(); + assert_eq!(indices, vec![0, 2]); + } + + #[test] + fn allow_duplicates_includes_repeated_headers() { + let headers = vec![ + "value".to_string(), + "value".to_string(), + "other".to_string(), + ]; + let selectors = parse_selector_list("value").unwrap(); + let indices = resolve_selectors(&headers, &selectors, false).unwrap(); + assert_eq!(indices, vec![0]); + let indices = resolve_selectors_allow_duplicates(&headers, &selectors, false).unwrap(); + assert_eq!(indices, vec![0, 1]); + } } diff --git a/src/cut.rs b/src/cut.rs index f162696..65a6c66 100644 --- a/src/cut.rs +++ b/src/cut.rs @@ -6,20 +6,20 @@ use clap::Args; use crate::common::{ ColumnSelector, InputOptions, SpecialColumn, default_headers, parse_selector_list, - reader_for_path, resolve_selectors, should_skip_record, + reader_for_path, resolve_selectors, resolve_selectors_allow_duplicates, should_skip_record, }; #[derive(Args, Debug)] #[command( about = "Select and reorder TSV columns", - long_about = "Pick columns by name or 1-based index. Combine comma-separated selectors with ranges (colA:colD or 2:6) and single fields in one spec. Defaults to header-aware mode; add -H for headerless input.\n\nExamples:\n tsvkit cut -f id,sample3,sample1 examples/profiles.tsv\n tsvkit cut -f 'Purity,sample:FN,F1' examples/profiles.tsv\n tsvkit cut -H -f 3,1 data.tsv" + long_about = "Pick columns by name or 1-based index. Combine comma-separated selectors with ranges (colA:colD or 2:6) and single fields in one spec. Use ~\"regex\" to match columns by pattern. 
Defaults to header-aware mode; add -H for headerless input.\n\nExamples:\n tsvkit cut -f id,sample3,sample1 examples/profiles.tsv\n tsvkit cut -f 'Purity,sample:FN,F1' examples/profiles.tsv\n tsvkit cut -H -f 3,1 data.tsv" )] pub struct CutArgs { /// Input TSV file(s) (use '-' for stdin; supports gz/xz) #[arg(value_name = "FILES", num_args = 0.., default_values = ["-"])] pub files: Vec, - /// Fields to select, using names, 1-based indices, ranges (`colA:colD`, `2:5`), or mixes. Comma-separated list. + /// Fields to select, using names, 1-based indices, ranges (`colA:colD`, `2:5`), regex (`~"^sample"`), or mixes. Comma-separated list. #[arg(short = 'f', long = "fields", value_name = "COLS", required = true)] pub fields: String, @@ -47,6 +47,10 @@ pub struct CutArgs { /// Ignore rows whose column count differs from the header/first row #[arg(short = 'I', long = "ignore-illegal-row")] pub ignore_illegal_row: bool, + + /// Allow duplicate column matches when resolving names or regex selectors + #[arg(short = 'D', long = "allow-dups")] + pub allow_dups: bool, } pub fn run(args: CutArgs) -> Result<()> { @@ -73,6 +77,7 @@ pub fn run(args: CutArgs) -> Result<()> { &file_info, &input_opts, &mut writer, + args.allow_dups, )?; } else { process_header_file( @@ -84,6 +89,7 @@ pub fn run(args: CutArgs) -> Result<()> { &input_opts, &mut writer, &mut header_emitted, + args.allow_dups, )?; } } @@ -99,6 +105,7 @@ fn process_no_header_file( file_info: &FileInfo, input_opts: &InputOptions, writer: &mut BufWriter>, + allow_duplicates: bool, ) -> Result<()> { let mut records = reader.records(); let first_record = loop { @@ -115,7 +122,7 @@ fn process_no_header_file( }; let expected_width = first_record.len(); let headers = default_headers(expected_width); - let columns = build_cut_columns(&headers, selectors, true)?; + let columns = build_cut_columns(&headers, selectors, true, allow_duplicates)?; emit_record(&first_record, &columns, file_info, writer)?; for record in records { let 
record = record.with_context(|| format!("failed reading from {:?}", path))?; @@ -136,6 +143,7 @@ fn process_header_file( input_opts: &InputOptions, writer: &mut BufWriter>, header_emitted: &mut bool, + allow_duplicates: bool, ) -> Result<()> { let headers = reader .headers() @@ -143,7 +151,7 @@ fn process_header_file( .iter() .map(|s| s.to_string()) .collect::>(); - let columns = build_cut_columns(&headers, selectors, false)?; + let columns = build_cut_columns(&headers, selectors, false, allow_duplicates)?; let expected_width = headers.len(); if !*header_emitted { @@ -199,6 +207,7 @@ fn build_cut_columns( headers: &[String], selectors: &[ColumnSelector], no_header: bool, + allow_duplicates: bool, ) -> Result> { let mut columns = Vec::new(); for selector in selectors { @@ -214,11 +223,19 @@ fn build_cut_columns( { bail!("special columns cannot be used within a range selector"); } - let indices = resolve_selectors(headers, &[selector.clone()], no_header)?; + let indices = if allow_duplicates { + resolve_selectors_allow_duplicates(headers, &[selector.clone()], no_header)? + } else { + resolve_selectors(headers, &[selector.clone()], no_header)? + }; columns.extend(indices.into_iter().map(CutColumn::Index)); } _ => { - let indices = resolve_selectors(headers, &[selector.clone()], no_header)?; + let indices = if allow_duplicates { + resolve_selectors_allow_duplicates(headers, &[selector.clone()], no_header)? + } else { + resolve_selectors(headers, &[selector.clone()], no_header)? 
+ }; columns.extend(indices.into_iter().map(CutColumn::Index)); } } diff --git a/src/expression.rs b/src/expression.rs index 347f18c..98662c5 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -35,6 +35,7 @@ pub enum ValueExpr { default: Option>, }, RegexCall(Box, Box), + List(Vec), } #[derive(Debug, Clone)] @@ -100,6 +101,8 @@ pub enum CompareOp { Le, RegexMatch, RegexNotMatch, + In, + NotIn, } #[derive(Debug, Clone)] @@ -177,6 +180,13 @@ pub fn bind_expression(expr: Expr, headers: &[String], no_header: bool) -> Resul invert: matches!(op, CompareOp::RegexNotMatch), }) } + CompareOp::In | CompareOp::NotIn => { + let left = bind_value(lhs, headers, no_header)?; + ensure_scalar_bound_value(&left, "left-hand side of 'in'")?; + let right = bind_value(rhs, headers, no_header)?; + ensure_membership_target(&right)?; + Ok(BoundExpr::Compare(left, op, right)) + } _ => Ok(BoundExpr::Compare( bind_value(lhs, headers, no_header)?, op, @@ -240,6 +250,7 @@ pub enum BoundValue { value: Box, pattern: RegexPattern, }, + List(Vec), } #[derive(Debug, Clone)] @@ -570,6 +581,25 @@ where bool_eval(false) } } + BoundValue::List(items) => { + if items.is_empty() { + return empty_eval(); + } + let mut combined = String::new(); + for (idx, item) in items.iter().enumerate() { + let saved = ctx.take_captures(); + let value = eval_value_with_context(item, ctx); + ctx.restore_captures(saved); + if idx > 0 { + combined.push(','); + } + combined.push_str(value.text.as_ref()); + } + EvalValue { + text: Cow::Owned(combined), + numeric: None, + } + } } } @@ -616,30 +646,102 @@ fn evaluate_compare<'a, R>( where R: RowAccessor + ?Sized, { - let left = eval_value_with_context(lhs, ctx); - let right = eval_value_with_context(rhs, ctx); - match op { - CompareOp::Eq => { - if let (Some(a), Some(b)) = (left.numeric, right.numeric) { - a == b - } else { - left.text == right.text + CompareOp::In => evaluate_membership(lhs, rhs, ctx, false), + CompareOp::NotIn => evaluate_membership(lhs, rhs, ctx, 
true), + CompareOp::RegexMatch | CompareOp::RegexNotMatch => false, + _ => { + let left = eval_value_with_context(lhs, ctx); + let right = eval_value_with_context(rhs, ctx); + match op { + CompareOp::Eq => { + if let (Some(a), Some(b)) = (left.numeric, right.numeric) { + a == b + } else { + left.text == right.text + } + } + CompareOp::Ne => { + if let (Some(a), Some(b)) = (left.numeric, right.numeric) { + a != b + } else { + left.text != right.text + } + } + CompareOp::Gt => compare_numeric(&left, &right, |a, b| a > b), + CompareOp::Ge => compare_numeric(&left, &right, |a, b| a >= b), + CompareOp::Lt => compare_numeric(&left, &right, |a, b| a < b), + CompareOp::Le => compare_numeric(&left, &right, |a, b| a <= b), + _ => false, } } - CompareOp::Ne => { - if let (Some(a), Some(b)) = (left.numeric, right.numeric) { - a != b - } else { - left.text != right.text + } +} + +fn evaluate_membership<'a, R>( + lhs: &'a BoundValue, + rhs: &'a BoundValue, + ctx: &mut EvalContext<'a, R>, + invert: bool, +) -> bool +where + R: RowAccessor + ?Sized, +{ + let left_eval = eval_value_with_context(lhs, ctx); + let found = match rhs { + BoundValue::List(values) => { + let mut matched = false; + for value in values { + let saved = ctx.take_captures(); + let candidate = eval_value_with_context(value, ctx); + ctx.restore_captures(saved); + if eval_values_equal(&left_eval, &candidate) { + matched = true; + break; + } + } + matched + } + BoundValue::Column(idx) => { + let text = ctx.row().get(*idx).unwrap_or(""); + eval_matches_text(&left_eval, text) + } + BoundValue::Columns(indices) => { + let mut matched = false; + for idx in indices { + let text = ctx.row().get(*idx).unwrap_or(""); + if eval_matches_text(&left_eval, text) { + matched = true; + break; + } + } + matched + } + other => { + let candidate = eval_value_with_context(other, ctx); + eval_values_equal(&left_eval, &candidate) + } + }; + if invert { !found } else { found } +} + +fn eval_values_equal(left: &EvalValue<'_>, right: 
&EvalValue<'_>) -> bool { + if let (Some(a), Some(b)) = (left.numeric, right.numeric) { + a == b + } else { + left.text == right.text + } +} + +fn eval_matches_text(left: &EvalValue<'_>, candidate: &str) -> bool { + if let Some(left_num) = left.numeric { + if let Some(candidate_num) = parse_float(candidate) { + if left_num == candidate_num { + return true; } } - CompareOp::Gt => compare_numeric(&left, &right, |a, b| a > b), - CompareOp::Ge => compare_numeric(&left, &right, |a, b| a >= b), - CompareOp::Lt => compare_numeric(&left, &right, |a, b| a < b), - CompareOp::Le => compare_numeric(&left, &right, |a, b| a <= b), - CompareOp::RegexMatch | CompareOp::RegexNotMatch => false, } + left.text == candidate } fn evaluate_regex<'a, R>( @@ -794,6 +896,29 @@ fn bind_value(value: ValueExpr, headers: &[String], no_header: bool) -> Result { + let mut bound_items = Vec::with_capacity(items.len()); + for item in items { + let bound = bind_value(item, headers, no_header)?; + ensure_scalar_bound_value(&bound, "list element")?; + bound_items.push(bound); + } + Ok(BoundValue::List(bound_items)) + } + } +} + +fn ensure_scalar_bound_value(value: &BoundValue, context: &str) -> Result<()> { + if matches!(value, BoundValue::Columns(_) | BoundValue::List(_)) { + bail!("{} must resolve to a single value", context); + } + Ok(()) +} + +fn ensure_membership_target(value: &BoundValue) -> Result<()> { + match value { + BoundValue::List(_) | BoundValue::Column(_) | BoundValue::Columns(_) => Ok(()), + _ => bail!("right-hand side of 'in'/'!in' must be a list or column selector"), } } @@ -1094,6 +1219,55 @@ mod tests { let result = eval_value(&bound, &row); assert_eq!(result.numeric, Some(5.0)); } + + #[test] + fn membership_operator_accepts_literal_list() { + let expr = parse_expression("$group in [\"case\",\"control\"]").unwrap(); + let headers = vec!["group".to_string()]; + let bound = bind_expression(expr, &headers, false).unwrap(); + let row = csv::StringRecord::from(vec!["case"]); + 
assert!(evaluate(&bound, &row)); + let row = csv::StringRecord::from(vec!["test"]); + assert!(!evaluate(&bound, &row)); + } + + #[test] + fn membership_operator_supports_negation() { + let expr = parse_expression("$status !in [\"fail\",\"missing\"]").unwrap(); + let headers = vec!["status".to_string()]; + let bound = bind_expression(expr, &headers, false).unwrap(); + let row = csv::StringRecord::from(vec!["ok"]); + assert!(evaluate(&bound, &row)); + let row = csv::StringRecord::from(vec!["fail"]); + assert!(!evaluate(&bound, &row)); + } + + #[test] + fn membership_operator_accepts_column_range() { + let expr = parse_expression("$value in $a:$c").unwrap(); + let headers = vec![ + "value".to_string(), + "a".to_string(), + "b".to_string(), + "c".to_string(), + ]; + let bound = bind_expression(expr, &headers, false).unwrap(); + let row = csv::StringRecord::from(vec!["x", "w", "x", "z"]); + assert!(evaluate(&bound, &row)); + let row = csv::StringRecord::from(vec!["q", "w", "x", "z"]); + assert!(!evaluate(&bound, &row)); + } + + #[test] + fn membership_requires_list_on_rhs() { + let expr = parse_expression("$value in 5").unwrap(); + let headers = vec!["value".to_string()]; + let err = bind_expression(expr, &headers, false).unwrap_err(); + assert!( + err.to_string() + .contains("right-hand side of 'in'/'!in' must be a list") + ); + } } impl<'a> Lexer<'a> { @@ -1134,6 +1308,9 @@ impl<'a> Lexer<'a> { } else if self.peek_char(1) == Some(b'~') { self.pos += 2; Ok(Some(Token::Compare(CompareOp::RegexNotMatch))) + } else if self.match_keyword(1, "in") { + self.pos += 3; + Ok(Some(Token::Compare(CompareOp::NotIn))) } else { self.pos += 1; Ok(Some(Token::Not)) @@ -1217,7 +1394,11 @@ impl<'a> Lexer<'a> { c if c.is_ascii_digit() || c == b'.' 
=> self.lex_number(), c if c.is_ascii_alphabetic() || c == b'_' => { let ident = self.lex_identifier(); - Ok(Some(Token::Ident(ident))) + if ident.eq_ignore_ascii_case("in") { + Ok(Some(Token::Compare(CompareOp::In))) + } else { + Ok(Some(Token::Ident(ident))) + } } _ => bail!("unexpected character '{}' in expression", ch as char), } @@ -1404,6 +1585,26 @@ impl<'a> Lexer<'a> { self.pos < self.chars.len() && self.chars[self.pos] == expected } + fn match_keyword(&self, offset: usize, keyword: &str) -> bool { + let start = self.pos + offset; + let end = start + keyword.len(); + if end > self.chars.len() { + return false; + } + let slice = &self.chars[start..end]; + let Ok(text) = std::str::from_utf8(slice) else { + return false; + }; + if !text.eq_ignore_ascii_case(keyword) { + return false; + } + if end == self.chars.len() { + return true; + } + let next = self.chars[end]; + !next.is_ascii_alphanumeric() && next != b'_' + } + fn peek_char(&self, offset: usize) -> Option { self.chars.get(self.pos + offset).copied() } @@ -1755,6 +1956,11 @@ impl Parser { unreachable!() } } + Some(Token::LBracket) => { + self.pos += 1; + let items = self.parse_list_literal_values()?; + Ok(ValueExpr::List(items)) + } Some(Token::String(_)) => { if let Some(Token::String(value)) = self.advance().cloned() { Ok(ValueExpr::String(value)) @@ -2041,6 +2247,25 @@ impl Parser { } } + fn parse_list_literal_values(&mut self) -> Result> { + let mut values = Vec::new(); + if self.consume_token(TokenKind::RBracket) { + return Ok(values); + } + loop { + let value = self.parse_arith()?; + values.push(value); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RBracket) { + break; + } else { + bail!("expected ',' or ']' in list literal"); + } + } + Ok(values) + } + fn consume_compare(&mut self) -> Option { if let Some(Token::Compare(op)) = self.peek_token().cloned() { self.pos += 1; diff --git a/src/summarize.rs b/src/summarize.rs index 70823df..022b141 
100644 --- a/src/summarize.rs +++ b/src/summarize.rs @@ -10,7 +10,8 @@ use indexmap::IndexMap; use crate::common::{ InputOptions, default_headers, parse_selector_list, parse_single_selector, reader_for_path, - resolve_selectors, resolve_single_selector, should_skip_record, + resolve_selectors, resolve_selectors_allow_duplicates, resolve_single_selector, + should_skip_record, }; #[derive(Args, Debug)] @@ -27,7 +28,7 @@ pub struct SummarizeArgs { #[arg(short = 'g', long = "group", value_name = "COLS")] pub group_cols: Option, - /// Statistics to compute (`COLUMN=ops`). Columns accept names, indices, ranges (e.g. `IL6:IL10`); operations are comma-separated (sum, mean, median, sd, var, min, max, mode, distinct, q*/p* aliases). Repeatable. + /// Statistics to compute (`COLUMN=ops`). Columns accept names, indices, ranges (e.g. `IL6:IL10`), and regex (`~"^sample"`); operations are comma-separated (sum, mean, median, sd, var, min, max, mode, distinct, q*/p* aliases). Repeatable. #[arg(short = 's', long = "stat", value_name = "COLUMN=OPS", required = true)] pub stats: Vec, @@ -51,6 +52,10 @@ pub struct SummarizeArgs { /// Ignore rows whose column count differs from the header/first row #[arg(short = 'I', long = "ignore-illegal-row")] pub ignore_illegal_row: bool, + + /// Allow duplicate column matches when resolving names or regex selectors + #[arg(short = 'D', long = "allow-dups")] + pub allow_dups: bool, } #[derive(Debug, Clone, PartialEq)] @@ -665,7 +670,7 @@ pub fn run(args: SummarizeArgs) -> Result<()> { }; headers = default_headers(first_record.len()); group_indices = parse_group_indices(args.group_cols.as_deref(), &headers, true)?; - stat_requests = parse_stat_requests(&args.stats, &headers, true)?; + stat_requests = parse_stat_requests(&args.stats, &headers, true, args.allow_dups)?; process_record(&mut groups, &group_indices, &stat_requests, &first_record); for record in records { @@ -683,7 +688,7 @@ pub fn run(args: SummarizeArgs) -> Result<()> { .map(|s| 
s.to_string()) .collect::>(); group_indices = parse_group_indices(args.group_cols.as_deref(), &headers, false)?; - stat_requests = parse_stat_requests(&args.stats, &headers, false)?; + stat_requests = parse_stat_requests(&args.stats, &headers, false, args.allow_dups)?; for record in reader.records() { let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; @@ -716,7 +721,12 @@ fn parse_group_indices( } } -fn expand_target_columns(spec: &str, headers: &[String], no_header: bool) -> Result> { +fn expand_target_columns( + spec: &str, + headers: &[String], + no_header: bool, + allow_duplicates: bool, +) -> Result> { let mut indices = Vec::new(); for token in spec.split(',') { let token = token.trim(); @@ -726,7 +736,11 @@ fn expand_target_columns(spec: &str, headers: &[String], no_header: bool) -> Res let parts: Vec<&str> = token.split(':').collect(); if parts.len() == 1 { let selectors = parse_selector_list(token)?; - let resolved = resolve_selectors(headers, &selectors, no_header)?; + let resolved = if allow_duplicates { + resolve_selectors_allow_duplicates(headers, &selectors, no_header)? + } else { + resolve_selectors(headers, &selectors, no_header)? 
+ }; indices.extend(resolved); } else if parts.len() == 2 { let start_part = parts[0].trim(); @@ -768,6 +782,7 @@ fn parse_stat_requests( specs: &[String], headers: &[String], no_header: bool, + allow_duplicates: bool, ) -> Result> { if specs.is_empty() { bail!("at least one --stat specification is required"); @@ -779,7 +794,8 @@ fn parse_stat_requests( .split_once('=') .or_else(|| spec.split_once(':')) .with_context(|| format!("expected COLUMN=ops in '{}'", spec))?; - let column_indices = expand_target_columns(column_part.trim(), headers, no_header)?; + let column_indices = + expand_target_columns(column_part.trim(), headers, no_header, allow_duplicates)?; let ops = ops_part .split(',') @@ -1064,9 +1080,9 @@ mod tests { "c".to_string(), "d".to_string(), ]; - let end_open = expand_target_columns("2:", &headers, false).unwrap(); + let end_open = expand_target_columns("2:", &headers, false, false).unwrap(); assert_eq!(end_open, vec![1, 2, 3]); - let start_open = expand_target_columns(":3", &headers, false).unwrap(); + let start_open = expand_target_columns(":3", &headers, false, false).unwrap(); assert_eq!(start_open, vec![0, 1, 2]); } From 76229f19001bf5290d936b322e09616e9fdcdafd Mon Sep 17 00:00:00 2001 From: "Z.-L. 
Deng" Date: Sat, 25 Oct 2025 23:22:51 +0200 Subject: [PATCH 2/2] Document regex selectors and membership operators --- Cargo.lock | 2 +- Cargo.toml | 2 +- README.md | 36 ++++++- src/common.rs | 189 ++++++++++++++++++++++++++++++++- src/cut.rs | 31 ++++-- src/expression.rs | 263 ++++++++++++++++++++++++++++++++++++++++++---- src/summarize.rs | 34 ++++-- 7 files changed, 510 insertions(+), 47 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6d80c91..3ac7c72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -603,7 +603,7 @@ dependencies = [ [[package]] name = "tsvkit" -version = "0.9.5" +version = "0.9.6" dependencies = [ "anyhow", "calamine", diff --git a/Cargo.toml b/Cargo.toml index 9be5e12..6be690c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tsvkit" -version = "0.9.5" +version = "0.9.6" edition = "2024" [dependencies] diff --git a/README.md b/README.md index 4962d49..3625c63 100644 --- a/README.md +++ b/README.md @@ -34,8 +34,8 @@ ### Key features - Stream-friendly processing; every command reads from files or standard input and writes to standard output. -- Column selectors that accept names, 1-based indices, ranges, and multi-file specifications. -- Expression language with arithmetic, comparisons, logical operators, regex matching, and numeric helper functions. +- Column selectors that accept names, 1-based indices, ranges, regexes, and multi-file specifications. +- Expression language with arithmetic, comparisons, logical operators, list membership, regex matching, and numeric helper functions. - Aggregations for grouped summaries (`summarize`) and row-wise calculations (`mutate`). - Excel tooling to inspect, preview, export, and assemble `.xlsx` workbooks. @@ -120,14 +120,17 @@ Selectors are reused in `cut`, `filter`, `join`, `mutate`, `summarize`, and othe | `index` | 1-based column index. | `1,4,9` | | `-index` | Column counted from the end (1 = last). | `-1,-2` | | `start:end` | Inclusive range by name or index. 
Supports open ends. | `IL6:IL10`, `2:5`, `:IL10`, `IL6:` | +| `~"regex"` | Columns whose names match the regular expression. Requires headers. | `~"^sample_"` | | `:` | Select every column in order. | `-f ':'` | -| `mixed` | Combine names, indices, and ranges. | `sample_id,3:5,tech` | +| `mixed` | Combine names, indices, ranges, and regexes. | `sample_id,3:5,~"_pct$"` | | `multi-file` | Separate selectors for each input with semicolons (primarily `join`). | `sample_id;subject_id` | | `range in expressions` | Prefixed with `$` to access a slice of values. | `$IL6:$IL10` | > Wrap selectors in backticks or braces to treat punctuation literally. For example, ``-f '`IL6:IL10`,`total,reads`'`` or `-f '{IL6:IL10},{total,reads}'` selects columns named `IL6:IL10` and `total,reads` instead of expanding a range or splitting on the comma. -Negative indices are also valid inside ranges: `:-2` selects every column except the final two, while `-3:` keeps the last three columns. +Negative indices are also valid inside ranges: `:-2` selects every column except the final two, while `-3:` keeps the last three columns. Regex selectors deduplicate by first match; add `--allow-dups` (or `-D`) on `cut`/`summarize` when you need repeated columns. + +> Regex selectors require a header row. When `-H/--no-header` is active, using `~"..."` results in an error with guidance to remove the regex or restore headers. Anywhere you access column *values* inside an expression, prefix the selector with `$` (`$purity`, `$1`, `$IL6:$IL10`). @@ -152,9 +155,13 @@ The same expression language powers `filter -e`, `mutate -e name=EXPR`, and rege | `!` / `not` | Logical negation. | Booleans | | `~` | Regex match. Right-hand side can be literal text or a `$range`. | Strings | | `!~` | Regex does *not* match. | Strings | +| `in` | Membership test against a list literal or numeric range. | `$group in ["case","control"]` | +| `!in` | Negated membership test. 
| `$status !in ["fail","missing"]` | > Reference columns whose names contain operators or punctuation with `${column-name}` inside expressions (e.g. `${dna-} - $rna_ug`). This prevents the parser from treating the characters as arithmetic. +List literals use square brackets: `[1,2,3]`, `["case","control"]`, `[IL6:IL10]`. Combine them with `in`/`!in` to test membership, or pass them to helper functions that accept lists. + **Numeric helper functions** | Function | Description | @@ -228,13 +235,25 @@ Ranges expand consecutive columns automatically: tsvkit cut -f 'sample_id,IL6:IL10' examples/cytokines.tsv ``` +Regex selectors pick up columns whose headers match a pattern. Combine them with names, indices, and ranges in any order: + +```bash +tsvkit cut -f '1,group,~"^IL",~"_pct$"' examples/qc.tsv +``` + +Matches deduplicate by default; add `-D/--allow-dups` to keep every occurrence when multiple selectors target the same column. + ### `filter` -Filter rows with boolean logic, arithmetic, column ranges, and regexes. +Filter rows with boolean logic, arithmetic, column ranges, regexes, and list membership tests. ```bash tsvkit filter -e '$group == "case" & $purity >= 0.94' examples/samples.tsv ``` +```bash +tsvkit filter -e '$status !in ["fail","missing","error"] & $tech ~ "sRNA"' examples/samples.tsv +``` + **Expression building blocks for `filter`** | Building block | Examples | Notes | @@ -248,6 +267,7 @@ tsvkit filter -e '$group == "case" & $purity >= 0.94' examples/samples.tsv | Row-wise aggregators | `sum($dna_ug:$rna_ug)`, `mode($1,$3)`, `countunique($gene:)` | Same catalog as [`summarize`](#summarize): totals, quantiles (`q*` / `p*`), variance/SD, products, entropy, argmin/argmax, membership stats. Works with ranges, lists, and open selectors. | | Regex match | `$tech ~ "sRNA"`, `$notes !~ "(?i)fail"` | Patterns follow Rust `regex` syntax. `(?i)` enables case-insensitive matching. 
| | Regex across ranges | `$gene:$notes ~ "kinase"`, `~ "control"` | When the left-hand side is omitted, `~` scans all columns. | +| Membership | `$group in ["case","control"]`, `$rank in [1:3]` | Right-hand side must be a list literal, a numeric range, or a column selector/range such as `$a:$c`. | **Regex usage at a glance** @@ -329,6 +349,12 @@ tsvkit summarize \ examples/samples.tsv ``` +Regex selectors work here as well, so you can summarize whole families of columns in one shot: + +```bash +tsvkit summarize -g patient -s '~"^sample_"=mean,sd' -D cohort.tsv +``` + **Aggregators supported by `summarize`** _Counts & membership_ diff --git a/src/common.rs b/src/common.rs index ffbed4c..3af34f8 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::fs::File; use std::io::{self, BufReader}; use std::path::Path; @@ -5,6 +6,7 @@ use std::path::Path; use anyhow::{Context, Result, anyhow, bail}; use csv::ReaderBuilder; use flate2::read::MultiGzDecoder; +use regex::Regex; use xz2::read::XzDecoder; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -27,6 +29,7 @@ pub enum ColumnSelector { Index(usize), FromEnd(usize), Name(String), + Regex(String), Range(Option>, Option>), Special(SpecialColumn), } @@ -99,7 +102,24 @@ pub fn resolve_selectors( selectors: &[ColumnSelector], no_header: bool, ) -> Result> { - let mut indices = Vec::with_capacity(selectors.len()); + resolve_selectors_with_options(headers, selectors, no_header, false) +} + +pub fn resolve_selectors_allow_duplicates( + headers: &[String], + selectors: &[ColumnSelector], + no_header: bool, +) -> Result> { + resolve_selectors_with_options(headers, selectors, no_header, true) +} + +fn resolve_selectors_with_options( + headers: &[String], + selectors: &[ColumnSelector], + no_header: bool, + allow_duplicates: bool, +) -> Result> { + let mut indices = Vec::new(); for selector in selectors { match selector { ColumnSelector::Special(special) => { @@ -108,9 +128,13 @@ pub fn resolve_selectors( special.default_header()
); } - ColumnSelector::Index(_) | ColumnSelector::FromEnd(_) | ColumnSelector::Name(_) => { - let index = resolve_selector_index(headers, selector, no_header)?; - indices.push(index); + ColumnSelector::Index(_) + | ColumnSelector::FromEnd(_) + | ColumnSelector::Name(_) + | ColumnSelector::Regex(_) => { + let mut resolved = + resolve_selector_indices(headers, selector, no_header, allow_duplicates)?; + indices.append(&mut resolved); } ColumnSelector::Range(start, end) => { if headers.is_empty() { @@ -282,6 +306,9 @@ fn parse_simple_selector(token: &str) -> Result { if token.is_empty() { return Err(anyhow!("empty column selector")); } + if let Some(regex) = parse_regex_literal(token)? { + return Ok(ColumnSelector::Regex(regex)); + } if let Some(literal) = parse_backtick_literal(token)? { return Ok(ColumnSelector::Name(literal)); } @@ -314,6 +341,46 @@ fn parse_simple_selector(token: &str) -> Result { Ok(ColumnSelector::Name(token.to_string())) } +fn parse_regex_literal(token: &str) -> Result> { + let trimmed = token.trim(); + if !trimmed.starts_with('~') { + return Ok(None); + } + let remainder = trimmed[1..].trim_start(); + let mut chars = remainder.chars(); + match chars.next() { + Some('"') => { + let mut value = String::new(); + let mut escaped = false; + while let Some(ch) = chars.next() { + if escaped { + value.push(ch); + escaped = false; + continue; + } + match ch { + '\\' => { + escaped = true; + } + '"' => { + if !chars.as_str().is_empty() { + bail!("unexpected trailing characters after regex selector literal"); + } + return Ok(Some(value)); + } + other => value.push(other), + } + } + bail!("unterminated regex selector literal"); + } + Some(other) => bail!( + "regex column selector must use double quotes (e.g. 
~\"pattern\"), got '{}'", + other + ), + None => bail!("regex column selector requires a quoted pattern"), + } +} + fn parse_backtick_literal(token: &str) -> Result> { let trimmed = token.trim(); if !trimmed.starts_with('`') { @@ -521,6 +588,89 @@ fn tokenize_selector_spec(spec: &str) -> Result> { Ok(tokens) } +fn resolve_selector_indices( + headers: &[String], + selector: &ColumnSelector, + no_header: bool, + allow_duplicates: bool, +) -> Result> { + match selector { + ColumnSelector::Index(idx) => { + let index = *idx; + if index >= headers.len() { + bail!( + "column index {} out of range ({} columns)", + index + 1, + headers.len() + ); + } + Ok(vec![index]) + } + ColumnSelector::FromEnd(offset) => { + let offset = *offset; + if offset == 0 { + bail!("column selector '-0' is not allowed"); + } + if offset > headers.len() { + bail!( + "column selector '-{}' out of range ({} columns)", + offset, + headers.len() + ); + } + Ok(vec![headers.len() - offset]) + } + ColumnSelector::Name(name) => { + if no_header { + bail!("column names cannot be used when input lacks a header row"); + } + if allow_duplicates { + let mut matches = Vec::new(); + for (idx, header) in headers.iter().enumerate() { + if header == name { + matches.push(idx); + } + } + if matches.is_empty() { + bail!("column '{}' not found", name); + } + Ok(matches) + } else { + let index = headers + .iter() + .position(|h| h == name) + .with_context(|| format!("column '{}' not found", name))?; + Ok(vec![index]) + } + } + ColumnSelector::Regex(pattern) => { + if no_header { + bail!("regex column selectors require headers"); + } + let regex = Regex::new(pattern) + .with_context(|| format!("invalid regex pattern '{}'", pattern))?; + let mut seen = HashSet::new(); + let mut matches = Vec::new(); + for (idx, header) in headers.iter().enumerate() { + if regex.is_match(header) { + if allow_duplicates || seen.insert(header.clone()) { + matches.push(idx); + } + } + } + if matches.is_empty() { + bail!("regex pattern '{}' 
did not match any columns", pattern); + } + Ok(matches) + } + ColumnSelector::Range(_, _) => unreachable!("range selectors handled separately"), + ColumnSelector::Special(special) => bail!( + "special column '{}' not supported without column injection", + special.default_header() + ), + } +} + fn resolve_selector_index( headers: &[String], selector: &ColumnSelector, @@ -562,6 +712,9 @@ fn resolve_selector_index( .with_context(|| format!("column '{}' not found", name))?; Ok(index) } + ColumnSelector::Regex(_) => { + bail!("regex column selectors cannot be used in range endpoints") + } ColumnSelector::Special(special) => bail!( "special column '{}' not supported without column injection", special.default_header() @@ -576,7 +729,7 @@ fn resolve_selector_index( mod tests { use super::{ ColumnSelector, SpecialColumn, parse_selector_list, parse_single_selector, - resolve_selectors, + resolve_selectors, resolve_selectors_allow_duplicates, }; #[test] @@ -702,4 +855,30 @@ mod tests { assert!(matches!(selectors[1], ColumnSelector::Name(ref name) if name == "__file__")); assert!(matches!(selectors[2], ColumnSelector::Name(ref name) if name == "__base__")); } + + #[test] + fn regex_selector_matches_columns() { + let headers = vec![ + "sample_a".to_string(), + "other".to_string(), + "sample_b".to_string(), + ]; + let selectors = parse_selector_list("~\"^sample_\"").unwrap(); + let indices = resolve_selectors(&headers, &selectors, false).unwrap(); + assert_eq!(indices, vec![0, 2]); + } + + #[test] + fn allow_duplicates_includes_repeated_headers() { + let headers = vec![ + "value".to_string(), + "value".to_string(), + "other".to_string(), + ]; + let selectors = parse_selector_list("value").unwrap(); + let indices = resolve_selectors(&headers, &selectors, false).unwrap(); + assert_eq!(indices, vec![0]); + let indices = resolve_selectors_allow_duplicates(&headers, &selectors, false).unwrap(); + assert_eq!(indices, vec![0, 1]); + } } diff --git a/src/cut.rs b/src/cut.rs index 
f162696..65a6c66 100644 --- a/src/cut.rs +++ b/src/cut.rs @@ -6,20 +6,20 @@ use clap::Args; use crate::common::{ ColumnSelector, InputOptions, SpecialColumn, default_headers, parse_selector_list, - reader_for_path, resolve_selectors, should_skip_record, + reader_for_path, resolve_selectors, resolve_selectors_allow_duplicates, should_skip_record, }; #[derive(Args, Debug)] #[command( about = "Select and reorder TSV columns", - long_about = "Pick columns by name or 1-based index. Combine comma-separated selectors with ranges (colA:colD or 2:6) and single fields in one spec. Defaults to header-aware mode; add -H for headerless input.\n\nExamples:\n tsvkit cut -f id,sample3,sample1 examples/profiles.tsv\n tsvkit cut -f 'Purity,sample:FN,F1' examples/profiles.tsv\n tsvkit cut -H -f 3,1 data.tsv" + long_about = "Pick columns by name or 1-based index. Combine comma-separated selectors with ranges (colA:colD or 2:6) and single fields in one spec. Use ~\"regex\" to match columns by pattern. Defaults to header-aware mode; add -H for headerless input.\n\nExamples:\n tsvkit cut -f id,sample3,sample1 examples/profiles.tsv\n tsvkit cut -f 'Purity,sample:FN,F1' examples/profiles.tsv\n tsvkit cut -H -f 3,1 data.tsv" )] pub struct CutArgs { /// Input TSV file(s) (use '-' for stdin; supports gz/xz) #[arg(value_name = "FILES", num_args = 0.., default_values = ["-"])] pub files: Vec, - /// Fields to select, using names, 1-based indices, ranges (`colA:colD`, `2:5`), or mixes. Comma-separated list. + /// Fields to select, using names, 1-based indices, ranges (`colA:colD`, `2:5`), regex (`~"^sample"`), or mixes. Comma-separated list. 
#[arg(short = 'f', long = "fields", value_name = "COLS", required = true)] pub fields: String, @@ -47,6 +47,10 @@ pub struct CutArgs { /// Ignore rows whose column count differs from the header/first row #[arg(short = 'I', long = "ignore-illegal-row")] pub ignore_illegal_row: bool, + + /// Allow duplicate column matches when resolving names or regex selectors + #[arg(short = 'D', long = "allow-dups")] + pub allow_dups: bool, } pub fn run(args: CutArgs) -> Result<()> { @@ -73,6 +77,7 @@ pub fn run(args: CutArgs) -> Result<()> { &file_info, &input_opts, &mut writer, + args.allow_dups, )?; } else { process_header_file( @@ -84,6 +89,7 @@ pub fn run(args: CutArgs) -> Result<()> { &input_opts, &mut writer, &mut header_emitted, + args.allow_dups, )?; } } @@ -99,6 +105,7 @@ fn process_no_header_file( file_info: &FileInfo, input_opts: &InputOptions, writer: &mut BufWriter>, + allow_duplicates: bool, ) -> Result<()> { let mut records = reader.records(); let first_record = loop { @@ -115,7 +122,7 @@ fn process_no_header_file( }; let expected_width = first_record.len(); let headers = default_headers(expected_width); - let columns = build_cut_columns(&headers, selectors, true)?; + let columns = build_cut_columns(&headers, selectors, true, allow_duplicates)?; emit_record(&first_record, &columns, file_info, writer)?; for record in records { let record = record.with_context(|| format!("failed reading from {:?}", path))?; @@ -136,6 +143,7 @@ fn process_header_file( input_opts: &InputOptions, writer: &mut BufWriter>, header_emitted: &mut bool, + allow_duplicates: bool, ) -> Result<()> { let headers = reader .headers() @@ -143,7 +151,7 @@ fn process_header_file( .iter() .map(|s| s.to_string()) .collect::>(); - let columns = build_cut_columns(&headers, selectors, false)?; + let columns = build_cut_columns(&headers, selectors, false, allow_duplicates)?; let expected_width = headers.len(); if !*header_emitted { @@ -199,6 +207,7 @@ fn build_cut_columns( headers: &[String], selectors: 
&[ColumnSelector], no_header: bool, + allow_duplicates: bool, ) -> Result> { let mut columns = Vec::new(); for selector in selectors { @@ -214,11 +223,19 @@ fn build_cut_columns( { bail!("special columns cannot be used within a range selector"); } - let indices = resolve_selectors(headers, &[selector.clone()], no_header)?; + let indices = if allow_duplicates { + resolve_selectors_allow_duplicates(headers, &[selector.clone()], no_header)? + } else { + resolve_selectors(headers, &[selector.clone()], no_header)? + }; columns.extend(indices.into_iter().map(CutColumn::Index)); } _ => { - let indices = resolve_selectors(headers, &[selector.clone()], no_header)?; + let indices = if allow_duplicates { + resolve_selectors_allow_duplicates(headers, &[selector.clone()], no_header)? + } else { + resolve_selectors(headers, &[selector.clone()], no_header)? + }; columns.extend(indices.into_iter().map(CutColumn::Index)); } } diff --git a/src/expression.rs b/src/expression.rs index 347f18c..98662c5 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -35,6 +35,7 @@ pub enum ValueExpr { default: Option>, }, RegexCall(Box, Box), + List(Vec), } #[derive(Debug, Clone)] @@ -100,6 +101,8 @@ pub enum CompareOp { Le, RegexMatch, RegexNotMatch, + In, + NotIn, } #[derive(Debug, Clone)] @@ -177,6 +180,13 @@ pub fn bind_expression(expr: Expr, headers: &[String], no_header: bool) -> Resul invert: matches!(op, CompareOp::RegexNotMatch), }) } + CompareOp::In | CompareOp::NotIn => { + let left = bind_value(lhs, headers, no_header)?; + ensure_scalar_bound_value(&left, "left-hand side of 'in'")?; + let right = bind_value(rhs, headers, no_header)?; + ensure_membership_target(&right)?; + Ok(BoundExpr::Compare(left, op, right)) + } _ => Ok(BoundExpr::Compare( bind_value(lhs, headers, no_header)?, op, @@ -240,6 +250,7 @@ pub enum BoundValue { value: Box, pattern: RegexPattern, }, + List(Vec), } #[derive(Debug, Clone)] @@ -570,6 +581,25 @@ where bool_eval(false) } } + BoundValue::List(items) => { + 
if items.is_empty() { + return empty_eval(); + } + let mut combined = String::new(); + for (idx, item) in items.iter().enumerate() { + let saved = ctx.take_captures(); + let value = eval_value_with_context(item, ctx); + ctx.restore_captures(saved); + if idx > 0 { + combined.push(','); + } + combined.push_str(value.text.as_ref()); + } + EvalValue { + text: Cow::Owned(combined), + numeric: None, + } + } } } @@ -616,30 +646,102 @@ fn evaluate_compare<'a, R>( where R: RowAccessor + ?Sized, { - let left = eval_value_with_context(lhs, ctx); - let right = eval_value_with_context(rhs, ctx); - match op { - CompareOp::Eq => { - if let (Some(a), Some(b)) = (left.numeric, right.numeric) { - a == b - } else { - left.text == right.text + CompareOp::In => evaluate_membership(lhs, rhs, ctx, false), + CompareOp::NotIn => evaluate_membership(lhs, rhs, ctx, true), + CompareOp::RegexMatch | CompareOp::RegexNotMatch => false, + _ => { + let left = eval_value_with_context(lhs, ctx); + let right = eval_value_with_context(rhs, ctx); + match op { + CompareOp::Eq => { + if let (Some(a), Some(b)) = (left.numeric, right.numeric) { + a == b + } else { + left.text == right.text + } + } + CompareOp::Ne => { + if let (Some(a), Some(b)) = (left.numeric, right.numeric) { + a != b + } else { + left.text != right.text + } + } + CompareOp::Gt => compare_numeric(&left, &right, |a, b| a > b), + CompareOp::Ge => compare_numeric(&left, &right, |a, b| a >= b), + CompareOp::Lt => compare_numeric(&left, &right, |a, b| a < b), + CompareOp::Le => compare_numeric(&left, &right, |a, b| a <= b), + _ => false, } } - CompareOp::Ne => { - if let (Some(a), Some(b)) = (left.numeric, right.numeric) { - a != b - } else { - left.text != right.text + } +} + +fn evaluate_membership<'a, R>( + lhs: &'a BoundValue, + rhs: &'a BoundValue, + ctx: &mut EvalContext<'a, R>, + invert: bool, +) -> bool +where + R: RowAccessor + ?Sized, +{ + let left_eval = eval_value_with_context(lhs, ctx); + let found = match rhs { + 
BoundValue::List(values) => { + let mut matched = false; + for value in values { + let saved = ctx.take_captures(); + let candidate = eval_value_with_context(value, ctx); + ctx.restore_captures(saved); + if eval_values_equal(&left_eval, &candidate) { + matched = true; + break; + } + } + matched + } + BoundValue::Column(idx) => { + let text = ctx.row().get(*idx).unwrap_or(""); + eval_matches_text(&left_eval, text) + } + BoundValue::Columns(indices) => { + let mut matched = false; + for idx in indices { + let text = ctx.row().get(*idx).unwrap_or(""); + if eval_matches_text(&left_eval, text) { + matched = true; + break; + } + } + matched + } + other => { + let candidate = eval_value_with_context(other, ctx); + eval_values_equal(&left_eval, &candidate) + } + }; + if invert { !found } else { found } +} + +fn eval_values_equal(left: &EvalValue<'_>, right: &EvalValue<'_>) -> bool { + if let (Some(a), Some(b)) = (left.numeric, right.numeric) { + a == b + } else { + left.text == right.text + } +} + +fn eval_matches_text(left: &EvalValue<'_>, candidate: &str) -> bool { + if let Some(left_num) = left.numeric { + if let Some(candidate_num) = parse_float(candidate) { + if left_num == candidate_num { + return true; } } - CompareOp::Gt => compare_numeric(&left, &right, |a, b| a > b), - CompareOp::Ge => compare_numeric(&left, &right, |a, b| a >= b), - CompareOp::Lt => compare_numeric(&left, &right, |a, b| a < b), - CompareOp::Le => compare_numeric(&left, &right, |a, b| a <= b), - CompareOp::RegexMatch | CompareOp::RegexNotMatch => false, } + left.text == candidate } fn evaluate_regex<'a, R>( @@ -794,6 +896,29 @@ fn bind_value(value: ValueExpr, headers: &[String], no_header: bool) -> Result { + let mut bound_items = Vec::with_capacity(items.len()); + for item in items { + let bound = bind_value(item, headers, no_header)?; + ensure_scalar_bound_value(&bound, "list element")?; + bound_items.push(bound); + } + Ok(BoundValue::List(bound_items)) + } + } +} + +fn 
ensure_scalar_bound_value(value: &BoundValue, context: &str) -> Result<()> { + if matches!(value, BoundValue::Columns(_) | BoundValue::List(_)) { + bail!("{} must resolve to a single value", context); + } + Ok(()) +} + +fn ensure_membership_target(value: &BoundValue) -> Result<()> { + match value { + BoundValue::List(_) | BoundValue::Column(_) | BoundValue::Columns(_) => Ok(()), + _ => bail!("right-hand side of 'in'/'!in' must be a list or column selector"), } } @@ -1094,6 +1219,55 @@ mod tests { let result = eval_value(&bound, &row); assert_eq!(result.numeric, Some(5.0)); } + + #[test] + fn membership_operator_accepts_literal_list() { + let expr = parse_expression("$group in [\"case\",\"control\"]").unwrap(); + let headers = vec!["group".to_string()]; + let bound = bind_expression(expr, &headers, false).unwrap(); + let row = csv::StringRecord::from(vec!["case"]); + assert!(evaluate(&bound, &row)); + let row = csv::StringRecord::from(vec!["test"]); + assert!(!evaluate(&bound, &row)); + } + + #[test] + fn membership_operator_supports_negation() { + let expr = parse_expression("$status !in [\"fail\",\"missing\"]").unwrap(); + let headers = vec!["status".to_string()]; + let bound = bind_expression(expr, &headers, false).unwrap(); + let row = csv::StringRecord::from(vec!["ok"]); + assert!(evaluate(&bound, &row)); + let row = csv::StringRecord::from(vec!["fail"]); + assert!(!evaluate(&bound, &row)); + } + + #[test] + fn membership_operator_accepts_column_range() { + let expr = parse_expression("$value in $a:$c").unwrap(); + let headers = vec![ + "value".to_string(), + "a".to_string(), + "b".to_string(), + "c".to_string(), + ]; + let bound = bind_expression(expr, &headers, false).unwrap(); + let row = csv::StringRecord::from(vec!["x", "w", "x", "z"]); + assert!(evaluate(&bound, &row)); + let row = csv::StringRecord::from(vec!["q", "w", "x", "z"]); + assert!(!evaluate(&bound, &row)); + } + + #[test] + fn membership_requires_list_on_rhs() { + let expr = 
parse_expression("$value in 5").unwrap(); + let headers = vec!["value".to_string()]; + let err = bind_expression(expr, &headers, false).unwrap_err(); + assert!( + err.to_string() + .contains("right-hand side of 'in'/'!in' must be a list") + ); + } } impl<'a> Lexer<'a> { @@ -1134,6 +1308,9 @@ impl<'a> Lexer<'a> { } else if self.peek_char(1) == Some(b'~') { self.pos += 2; Ok(Some(Token::Compare(CompareOp::RegexNotMatch))) + } else if self.match_keyword(1, "in") { + self.pos += 3; + Ok(Some(Token::Compare(CompareOp::NotIn))) } else { self.pos += 1; Ok(Some(Token::Not)) @@ -1217,7 +1394,11 @@ impl<'a> Lexer<'a> { c if c.is_ascii_digit() || c == b'.' => self.lex_number(), c if c.is_ascii_alphabetic() || c == b'_' => { let ident = self.lex_identifier(); - Ok(Some(Token::Ident(ident))) + if ident.eq_ignore_ascii_case("in") { + Ok(Some(Token::Compare(CompareOp::In))) + } else { + Ok(Some(Token::Ident(ident))) + } } _ => bail!("unexpected character '{}' in expression", ch as char), } @@ -1404,6 +1585,26 @@ impl<'a> Lexer<'a> { self.pos < self.chars.len() && self.chars[self.pos] == expected } + fn match_keyword(&self, offset: usize, keyword: &str) -> bool { + let start = self.pos + offset; + let end = start + keyword.len(); + if end > self.chars.len() { + return false; + } + let slice = &self.chars[start..end]; + let Ok(text) = std::str::from_utf8(slice) else { + return false; + }; + if !text.eq_ignore_ascii_case(keyword) { + return false; + } + if end == self.chars.len() { + return true; + } + let next = self.chars[end]; + !next.is_ascii_alphanumeric() && next != b'_' + } + fn peek_char(&self, offset: usize) -> Option { self.chars.get(self.pos + offset).copied() } @@ -1755,6 +1956,11 @@ impl Parser { unreachable!() } } + Some(Token::LBracket) => { + self.pos += 1; + let items = self.parse_list_literal_values()?; + Ok(ValueExpr::List(items)) + } Some(Token::String(_)) => { if let Some(Token::String(value)) = self.advance().cloned() { Ok(ValueExpr::String(value)) @@ -2041,6 
+2247,25 @@ impl Parser { } } + fn parse_list_literal_values(&mut self) -> Result> { + let mut values = Vec::new(); + if self.consume_token(TokenKind::RBracket) { + return Ok(values); + } + loop { + let value = self.parse_arith()?; + values.push(value); + if self.consume_token(TokenKind::Comma) { + continue; + } else if self.consume_token(TokenKind::RBracket) { + break; + } else { + bail!("expected ',' or ']' in list literal"); + } + } + Ok(values) + } + fn consume_compare(&mut self) -> Option { if let Some(Token::Compare(op)) = self.peek_token().cloned() { self.pos += 1; diff --git a/src/summarize.rs b/src/summarize.rs index 70823df..022b141 100644 --- a/src/summarize.rs +++ b/src/summarize.rs @@ -10,7 +10,8 @@ use indexmap::IndexMap; use crate::common::{ InputOptions, default_headers, parse_selector_list, parse_single_selector, reader_for_path, - resolve_selectors, resolve_single_selector, should_skip_record, + resolve_selectors, resolve_selectors_allow_duplicates, resolve_single_selector, + should_skip_record, }; #[derive(Args, Debug)] @@ -27,7 +28,7 @@ pub struct SummarizeArgs { #[arg(short = 'g', long = "group", value_name = "COLS")] pub group_cols: Option, - /// Statistics to compute (`COLUMN=ops`). Columns accept names, indices, ranges (e.g. `IL6:IL10`); operations are comma-separated (sum, mean, median, sd, var, min, max, mode, distinct, q*/p* aliases). Repeatable. + /// Statistics to compute (`COLUMN=ops`). Columns accept names, indices, ranges (e.g. `IL6:IL10`), and regex (`~"^sample"`); operations are comma-separated (sum, mean, median, sd, var, min, max, mode, distinct, q*/p* aliases). Repeatable. 
#[arg(short = 's', long = "stat", value_name = "COLUMN=OPS", required = true)] pub stats: Vec, @@ -51,6 +52,10 @@ pub struct SummarizeArgs { /// Ignore rows whose column count differs from the header/first row #[arg(short = 'I', long = "ignore-illegal-row")] pub ignore_illegal_row: bool, + + /// Allow duplicate column matches when resolving names or regex selectors + #[arg(short = 'D', long = "allow-dups")] + pub allow_dups: bool, } #[derive(Debug, Clone, PartialEq)] @@ -665,7 +670,7 @@ pub fn run(args: SummarizeArgs) -> Result<()> { }; headers = default_headers(first_record.len()); group_indices = parse_group_indices(args.group_cols.as_deref(), &headers, true)?; - stat_requests = parse_stat_requests(&args.stats, &headers, true)?; + stat_requests = parse_stat_requests(&args.stats, &headers, true, args.allow_dups)?; process_record(&mut groups, &group_indices, &stat_requests, &first_record); for record in records { @@ -683,7 +688,7 @@ pub fn run(args: SummarizeArgs) -> Result<()> { .map(|s| s.to_string()) .collect::>(); group_indices = parse_group_indices(args.group_cols.as_deref(), &headers, false)?; - stat_requests = parse_stat_requests(&args.stats, &headers, false)?; + stat_requests = parse_stat_requests(&args.stats, &headers, false, args.allow_dups)?; for record in reader.records() { let record = record.with_context(|| format!("failed reading from {:?}", args.file))?; @@ -716,7 +721,12 @@ fn parse_group_indices( } } -fn expand_target_columns(spec: &str, headers: &[String], no_header: bool) -> Result> { +fn expand_target_columns( + spec: &str, + headers: &[String], + no_header: bool, + allow_duplicates: bool, +) -> Result> { let mut indices = Vec::new(); for token in spec.split(',') { let token = token.trim(); @@ -726,7 +736,11 @@ fn expand_target_columns(spec: &str, headers: &[String], no_header: bool) -> Res let parts: Vec<&str> = token.split(':').collect(); if parts.len() == 1 { let selectors = parse_selector_list(token)?; - let resolved = 
resolve_selectors(headers, &selectors, no_header)?; + let resolved = if allow_duplicates { + resolve_selectors_allow_duplicates(headers, &selectors, no_header)? + } else { + resolve_selectors(headers, &selectors, no_header)? + }; indices.extend(resolved); } else if parts.len() == 2 { let start_part = parts[0].trim(); @@ -768,6 +782,7 @@ fn parse_stat_requests( specs: &[String], headers: &[String], no_header: bool, + allow_duplicates: bool, ) -> Result> { if specs.is_empty() { bail!("at least one --stat specification is required"); @@ -779,7 +794,8 @@ fn parse_stat_requests( .split_once('=') .or_else(|| spec.split_once(':')) .with_context(|| format!("expected COLUMN=ops in '{}'", spec))?; - let column_indices = expand_target_columns(column_part.trim(), headers, no_header)?; + let column_indices = + expand_target_columns(column_part.trim(), headers, no_header, allow_duplicates)?; let ops = ops_part .split(',') @@ -1064,9 +1080,9 @@ mod tests { "c".to_string(), "d".to_string(), ]; - let end_open = expand_target_columns("2:", &headers, false).unwrap(); + let end_open = expand_target_columns("2:", &headers, false, false).unwrap(); assert_eq!(end_open, vec![1, 2, 3]); - let start_open = expand_target_columns(":3", &headers, false).unwrap(); + let start_open = expand_target_columns(":3", &headers, false, false).unwrap(); assert_eq!(start_open, vec![0, 1, 2]); }