From f6b7c9464bf94a136562d3306ea002201685cd71 Mon Sep 17 00:00:00 2001 From: PavelMakarchuk Date: Tue, 10 Mar 2026 01:13:01 +0100 Subject: [PATCH] Add Stata (.dta) file format support for all CLI subcommands Auto-detects format by file extension. Introduces a thin I/O helper (core/io.py) that dispatches to pd.read_stata/to_stata for .dta files and falls back to CSV for everything else. The stdin/stdout default mode remains CSV-only since Stata is a binary format. Closes #758 Co-Authored-By: Claude Opus 4.6 --- changelog.d/add-stata-support.added.md | 1 + policyengine_taxsim/cli.py | 16 +++-- policyengine_taxsim/core/io.py | 36 ++++++++++ policyengine_taxsim/exe.py | 4 +- policyengine_taxsim/runners/base_runner.py | 11 +++- tests/test_stata_io.py | 76 ++++++++++++++++++++++ 6 files changed, 133 insertions(+), 11 deletions(-) create mode 100644 changelog.d/add-stata-support.added.md create mode 100644 policyengine_taxsim/core/io.py create mode 100644 tests/test_stata_io.py diff --git a/changelog.d/add-stata-support.added.md b/changelog.d/add-stata-support.added.md new file mode 100644 index 00000000..ac669d1e --- /dev/null +++ b/changelog.d/add-stata-support.added.md @@ -0,0 +1 @@ +Added support for the Stata (.dta) file format. The format of input and output files is auto-detected from the file extension; all CLI subcommands now accept .dta files alongside CSV. The stdin/stdout mode remains CSV-only. 
diff --git a/policyengine_taxsim/cli.py b/policyengine_taxsim/cli.py index ada17a40..816231df 100644 --- a/policyengine_taxsim/cli.py +++ b/policyengine_taxsim/cli.py @@ -12,6 +12,7 @@ from .core.yaml_generator import generate_pe_tests_yaml from .core.input_mapper import form_household_situation from .core.utils import get_state_code, convert_taxsim32_dependents + from .core.io import read_input, write_output except ImportError: from policyengine_taxsim.runners.policyengine_runner import PolicyEngineRunner from policyengine_taxsim.runners.taxsim_runner import TaxsimRunner @@ -27,6 +28,7 @@ get_state_code, convert_taxsim32_dependents, ) + from policyengine_taxsim.core.io import read_input, write_output def _generate_yaml_files(input_df: pd.DataFrame, results_df: pd.DataFrame): @@ -169,7 +171,7 @@ def policyengine(input_file, output, logs, disable_salt, assume_w2_wages, sample """ try: # Read input file - df = pd.read_csv(input_file) + df = read_input(input_file) # Apply sampling if requested if sample and sample < len(df): @@ -192,7 +194,7 @@ def policyengine(input_file, output, logs, disable_salt, assume_w2_wages, sample click.echo(f"Generated {len(df_with_ids)} YAML test files") # Save results to output file - results_df.to_csv(output, index=False) + write_output(results_df, output) click.echo(f"Results saved to {output}") except Exception as e: @@ -213,7 +215,7 @@ def taxsim(input_file, output, sample, taxsim_path): """Run TAXSIM-35 tax calculations""" try: # Load and optionally sample data - df = pd.read_csv(input_file) + df = read_input(input_file) if sample and sample < len(df): click.echo(f"Sampling {sample} records from {len(df)} total records") @@ -224,7 +226,7 @@ def taxsim(input_file, output, sample, taxsim_path): results = runner.run() # Save results - results.to_csv(output, index=False) + write_output(results, output) click.echo(f"TAXSIM results saved to: {output}") except Exception as e: @@ -258,7 +260,7 @@ def compare(input_file, sample, output_dir, 
year, disable_salt, logs, assume_w2_ """Compare PolicyEngine and TAXSIM results""" try: # Load and optionally sample data - df = pd.read_csv(input_file) + df = read_input(input_file) # Override year column if specified if year is not None and "year" in df.columns: @@ -344,7 +346,7 @@ def compare(input_file, sample, output_dir, year, disable_salt, logs, assume_w2_ def sample_data(input_file, sample, output): """Sample records from a large dataset""" try: - df = pd.read_csv(input_file) + df = read_input(input_file) if not sample: click.echo( @@ -368,7 +370,7 @@ def sample_data(input_file, sample, output): ) # Save sampled data - sampled_df.to_csv(output, index=False) + write_output(sampled_df, output) click.echo(f"Sampled {sample} records from {len(df)} total records") click.echo(f"Sampled data saved to: {output}") diff --git a/policyengine_taxsim/core/io.py b/policyengine_taxsim/core/io.py new file mode 100644 index 00000000..5bea0c95 --- /dev/null +++ b/policyengine_taxsim/core/io.py @@ -0,0 +1,36 @@ +"""File I/O helpers supporting CSV and Stata (.dta) formats. 
+ +Format is auto-detected from file extensions: + .dta → Stata + anything else → CSV (the default) +""" + +import pandas as pd +from pathlib import Path +from typing import Union + + +STATA_EXTENSIONS = {".dta"} + + +def _is_stata(path: Union[str, Path]) -> bool: + return Path(path).suffix.lower() in STATA_EXTENSIONS + + +def read_input(path: Union[str, Path]) -> pd.DataFrame: + """Read a TAXSIM-format input file (CSV or Stata).""" + if _is_stata(path): + return pd.read_stata(path) + return pd.read_csv(path) + + +def write_output( + df: pd.DataFrame, + path: Union[str, Path], + index: bool = False, +) -> None: + """Write a DataFrame to CSV or Stata based on file extension.""" + if _is_stata(path): + df.to_stata(path, write_index=index) + else: + df.to_csv(path, index=index) diff --git a/policyengine_taxsim/exe.py b/policyengine_taxsim/exe.py index 0c19771c..b28e904e 100644 --- a/policyengine_taxsim/exe.py +++ b/policyengine_taxsim/exe.py @@ -49,7 +49,9 @@ def main(input_file, output, logs, disable_salt): generate_household, export_household = get_mappers() # Read input file - df = pd.read_csv(input_file) + from policyengine_taxsim.core.io import read_input + + df = read_input(input_file) # Process each row idtl_0_results = [] diff --git a/policyengine_taxsim/runners/base_runner.py b/policyengine_taxsim/runners/base_runner.py index bd36e963..33889fe8 100644 --- a/policyengine_taxsim/runners/base_runner.py +++ b/policyengine_taxsim/runners/base_runner.py @@ -3,6 +3,11 @@ from typing import Optional, Union from pathlib import Path +try: + from ..core.io import write_output +except ImportError: + from policyengine_taxsim.core.io import write_output + class BaseTaxRunner(ABC): """Abstract base class for tax calculation runners @@ -87,14 +92,14 @@ def save_input(self, output_path: Union[str, Path]): output_path: Path where to save the input data """ output_path = Path(output_path) - self.input_df.to_csv(output_path, index=False) + write_output(self.input_df, 
output_path) print(f"Input data saved to: {output_path}") def save_results( self, output_path: Union[str, Path], results_df: Optional[pd.DataFrame] = None ): """ - Save results to CSV file + Save results to file (CSV or Stata based on extension) Args: output_path: Path where to save the results @@ -108,7 +113,7 @@ def save_results( results_df = self.results output_path = Path(output_path) - results_df.to_csv(output_path, index=False) + write_output(results_df, output_path) print(f"Results saved to: {output_path}") def get_record_count(self) -> int: diff --git a/tests/test_stata_io.py b/tests/test_stata_io.py new file mode 100644 index 00000000..d25d6562 --- /dev/null +++ b/tests/test_stata_io.py @@ -0,0 +1,76 @@ +"""Tests for Stata (.dta) file format support.""" + +import pandas as pd +import pytest +from pathlib import Path + +from policyengine_taxsim.core.io import read_input, write_output + + +@pytest.fixture +def sample_df(): + return pd.DataFrame( + { + "taxsimid": [1, 2], + "year": [2024, 2024], + "state": [5, 6], + "mstat": [1, 2], + "pwages": [50000.0, 80000.0], + } + ) + + +class TestReadInput: + def test_read_csv(self, tmp_path, sample_df): + csv_path = tmp_path / "input.csv" + sample_df.to_csv(csv_path, index=False) + + result = read_input(csv_path) + assert len(result) == 2 + assert list(result.columns) == list(sample_df.columns) + + def test_read_stata(self, tmp_path, sample_df): + dta_path = tmp_path / "input.dta" + sample_df.to_stata(dta_path, write_index=False) + + result = read_input(dta_path) + assert len(result) == 2 + assert "taxsimid" in result.columns + assert "pwages" in result.columns + + def test_read_unknown_extension_defaults_to_csv(self, tmp_path, sample_df): + txt_path = tmp_path / "input.txt" + sample_df.to_csv(txt_path, index=False) + + result = read_input(txt_path) + assert len(result) == 2 + + +class TestWriteOutput: + def test_write_csv(self, tmp_path, sample_df): + csv_path = tmp_path / "output.csv" + write_output(sample_df, 
csv_path) + + result = pd.read_csv(csv_path) + assert len(result) == 2 + assert list(result.columns) == list(sample_df.columns) + + def test_write_stata(self, tmp_path, sample_df): + dta_path = tmp_path / "output.dta" + write_output(sample_df, dta_path) + + result = pd.read_stata(dta_path) + assert len(result) == 2 + assert "taxsimid" in result.columns + + def test_roundtrip_stata(self, tmp_path, sample_df): + """Write then read a .dta file and verify data is preserved.""" + dta_path = tmp_path / "roundtrip.dta" + write_output(sample_df, dta_path) + result = read_input(dta_path) + + pd.testing.assert_frame_equal( + result[sample_df.columns].reset_index(drop=True), + sample_df.reset_index(drop=True), + check_dtype=False, + )