Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/add-stata-support.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for Stata (.dta) file format. Input and output files are auto-detected by extension; all CLI subcommands now accept .dta files alongside CSV.
16 changes: 9 additions & 7 deletions policyengine_taxsim/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .core.yaml_generator import generate_pe_tests_yaml
from .core.input_mapper import form_household_situation
from .core.utils import get_state_code, convert_taxsim32_dependents
from .core.io import read_input, write_output
except ImportError:
from policyengine_taxsim.runners.policyengine_runner import PolicyEngineRunner
from policyengine_taxsim.runners.taxsim_runner import TaxsimRunner
Expand All @@ -27,6 +28,7 @@
get_state_code,
convert_taxsim32_dependents,
)
from policyengine_taxsim.core.io import read_input, write_output


def _generate_yaml_files(input_df: pd.DataFrame, results_df: pd.DataFrame):
Expand Down Expand Up @@ -169,7 +171,7 @@ def policyengine(input_file, output, logs, disable_salt, assume_w2_wages, sample
"""
try:
# Read input file
df = pd.read_csv(input_file)
df = read_input(input_file)

# Apply sampling if requested
if sample and sample < len(df):
Expand All @@ -192,7 +194,7 @@ def policyengine(input_file, output, logs, disable_salt, assume_w2_wages, sample
click.echo(f"Generated {len(df_with_ids)} YAML test files")

# Save results to output file
results_df.to_csv(output, index=False)
write_output(results_df, output)
click.echo(f"Results saved to {output}")

except Exception as e:
Expand All @@ -213,7 +215,7 @@ def taxsim(input_file, output, sample, taxsim_path):
"""Run TAXSIM-35 tax calculations"""
try:
# Load and optionally sample data
df = pd.read_csv(input_file)
df = read_input(input_file)

if sample and sample < len(df):
click.echo(f"Sampling {sample} records from {len(df)} total records")
Expand All @@ -224,7 +226,7 @@ def taxsim(input_file, output, sample, taxsim_path):
results = runner.run()

# Save results
results.to_csv(output, index=False)
write_output(results, output)
click.echo(f"TAXSIM results saved to: {output}")

except Exception as e:
Expand Down Expand Up @@ -258,7 +260,7 @@ def compare(input_file, sample, output_dir, year, disable_salt, logs, assume_w2_
"""Compare PolicyEngine and TAXSIM results"""
try:
# Load and optionally sample data
df = pd.read_csv(input_file)
df = read_input(input_file)

# Override year column if specified
if year is not None and "year" in df.columns:
Expand Down Expand Up @@ -344,7 +346,7 @@ def compare(input_file, sample, output_dir, year, disable_salt, logs, assume_w2_
def sample_data(input_file, sample, output):
"""Sample records from a large dataset"""
try:
df = pd.read_csv(input_file)
df = read_input(input_file)

if not sample:
click.echo(
Expand All @@ -368,7 +370,7 @@ def sample_data(input_file, sample, output):
)

# Save sampled data
sampled_df.to_csv(output, index=False)
write_output(sampled_df, output)
click.echo(f"Sampled {sample} records from {len(df)} total records")
click.echo(f"Sampled data saved to: {output}")

Expand Down
36 changes: 36 additions & 0 deletions policyengine_taxsim/core/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""File I/O helpers supporting CSV and Stata (.dta) formats.

Format is auto-detected from file extensions:
.dta → Stata
anything else → CSV (the default)
"""

import pandas as pd
from pathlib import Path
from typing import Union


STATA_EXTENSIONS = {".dta"}


def _is_stata(path: Union[str, Path]) -> bool:
return Path(path).suffix.lower() in STATA_EXTENSIONS


def read_input(path: Union[str, Path]) -> pd.DataFrame:
"""Read a TAXSIM-format input file (CSV or Stata)."""
if _is_stata(path):
return pd.read_stata(path)
return pd.read_csv(path)


def write_output(
df: pd.DataFrame,
path: Union[str, Path],
index: bool = False,
) -> None:
"""Write a DataFrame to CSV or Stata based on file extension."""
if _is_stata(path):
df.to_stata(path, write_index=index)
else:
df.to_csv(path, index=index)
4 changes: 3 additions & 1 deletion policyengine_taxsim/exe.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def main(input_file, output, logs, disable_salt):
generate_household, export_household = get_mappers()

# Read input file
df = pd.read_csv(input_file)
from policyengine_taxsim.core.io import read_input

df = read_input(input_file)

# Process each row
idtl_0_results = []
Expand Down
11 changes: 8 additions & 3 deletions policyengine_taxsim/runners/base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from typing import Optional, Union
from pathlib import Path

try:
from ..core.io import write_output
except ImportError:
from policyengine_taxsim.core.io import write_output


class BaseTaxRunner(ABC):
"""Abstract base class for tax calculation runners
Expand Down Expand Up @@ -87,14 +92,14 @@ def save_input(self, output_path: Union[str, Path]):
output_path: Path where to save the input data
"""
output_path = Path(output_path)
self.input_df.to_csv(output_path, index=False)
write_output(self.input_df, output_path)
print(f"Input data saved to: {output_path}")

def save_results(
self, output_path: Union[str, Path], results_df: Optional[pd.DataFrame] = None
):
"""
Save results to CSV file
Save results to file (CSV or Stata based on extension)

Args:
output_path: Path where to save the results
Expand All @@ -108,7 +113,7 @@ def save_results(
results_df = self.results

output_path = Path(output_path)
results_df.to_csv(output_path, index=False)
write_output(results_df, output_path)
print(f"Results saved to: {output_path}")

def get_record_count(self) -> int:
Expand Down
76 changes: 76 additions & 0 deletions tests/test_stata_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Tests for Stata (.dta) file format support."""

import pandas as pd
import pytest
from pathlib import Path

from policyengine_taxsim.core.io import read_input, write_output


@pytest.fixture
def sample_df():
return pd.DataFrame(
{
"taxsimid": [1, 2],
"year": [2024, 2024],
"state": [5, 6],
"mstat": [1, 2],
"pwages": [50000.0, 80000.0],
}
)


class TestReadInput:
def test_read_csv(self, tmp_path, sample_df):
csv_path = tmp_path / "input.csv"
sample_df.to_csv(csv_path, index=False)

result = read_input(csv_path)
assert len(result) == 2
assert list(result.columns) == list(sample_df.columns)

def test_read_stata(self, tmp_path, sample_df):
dta_path = tmp_path / "input.dta"
sample_df.to_stata(dta_path, write_index=False)

result = read_input(dta_path)
assert len(result) == 2
assert "taxsimid" in result.columns
assert "pwages" in result.columns

def test_read_unknown_extension_defaults_to_csv(self, tmp_path, sample_df):
txt_path = tmp_path / "input.txt"
sample_df.to_csv(txt_path, index=False)

result = read_input(txt_path)
assert len(result) == 2


class TestWriteOutput:
def test_write_csv(self, tmp_path, sample_df):
csv_path = tmp_path / "output.csv"
write_output(sample_df, csv_path)

result = pd.read_csv(csv_path)
assert len(result) == 2
assert list(result.columns) == list(sample_df.columns)

def test_write_stata(self, tmp_path, sample_df):
dta_path = tmp_path / "output.dta"
write_output(sample_df, dta_path)

result = pd.read_stata(dta_path)
assert len(result) == 2
assert "taxsimid" in result.columns

def test_roundtrip_stata(self, tmp_path, sample_df):
"""Write then read a .dta file and verify data is preserved."""
dta_path = tmp_path / "roundtrip.dta"
write_output(sample_df, dta_path)
result = read_input(dta_path)

pd.testing.assert_frame_equal(
result[sample_df.columns].reset_index(drop=True),
sample_df.reset_index(drop=True),
check_dtype=False,
)