Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
86b2eb2
feat: add summary generation and export functionality for results
JessUWE Feb 19, 2026
ea0a63e
test: add coverage tests for table info extraction
JessUWE Feb 19, 2026
f3c6d96
fix coverage
JessUWE Feb 19, 2026
f9ef74e
test: add coverage for empty summary edge case (line 595)
JessUWE Feb 19, 2026
47359b2
test: remove unused variable idx_type
JessUWE Feb 19, 2026
f817df2
test: fix unused variable warnings by using underscore
JessUWE Feb 19, 2026
e8bc926
feat: implement session summary with differencing risk detection
JessUWE Mar 5, 2026
e3665ff
feat(record): update warning message
JessUWE Mar 5, 2026
5f14a04
Merge branch 'main' into feature/224-session-summary
JessUWE Mar 5, 2026
404bba3
Add generate_summary() to provide high-level output overview for chec…
JessUWE Mar 5, 2026
fcbe9fa
- Add multi-layered 'DO NOT RELEASE' warnings (filename, comment)
JessUWE Mar 5, 2026
4c16dcd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
e1271ae
Remove tests for index and columns name extraction
JessUWE Mar 6, 2026
6a96fd2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
c2cd7b6
refactor: remove unreachable elif branches and unnecessary tests
JessUWE Mar 6, 2026
ee46528
Fixes issue where tables with identical variables but different
JessUWE Mar 10, 2026
0202e39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2026
61ab307
Fixes issue where tables with identical variables but different
JessUWE Mar 10, 2026
f7c8937
Merge branch 'feature/224-session-summary' of https://github.com/AI-S…
JessUWE Mar 10, 2026
603db30
fix: correct differencing risk detection for suppression settings
JessUWE Mar 10, 2026
f72f049
fix: correct differencing risk detection for suppression settings
JessUWE Mar 10, 2026
55c3320
fix: correct differencing risk detection for suppression settings
JessUWE Mar 10, 2026
6799fab
fix: correct differencing risk detection for suppression settings
JessUWE Mar 10, 2026
c6524a5
refactor: simplify test_extract_table_info_with_numeric_data and remo…
JessUWE Mar 10, 2026
7234259
Merge branch 'feature/224-session-summary' of https://github.com/AI-S…
JessUWE Mar 10, 2026
a0116b8
refactor: simplify docstring for generate_variable_matrix_table and r…
JessUWE Mar 10, 2026
1238fdd
Add per-file ignores for acro_stata_parser.py
jim-smith Mar 11, 2026
5f18e0c
fix: resolve code review issues in session summary implementation
JessUWE Mar 12, 2026
10b8ccd
Merge main branch into feature/224-session-summary
JessUWE Mar 26, 2026
8c30d0d
Refactor Record class and improve variable extraction
JessUWE Mar 26, 2026
b4a3871
Enhance assertions for summary DataFrame variables
JessUWE Mar 26, 2026
4e28a64
refactor: add regression variable extraction functions and tests
JessUWE Mar 26, 2026
9002c7f
Refactor test_initial.py and improve documentation
JessUWE Mar 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 129 additions & 9 deletions acro/acro_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import re
import warnings
from inspect import stack
from typing import Any
Expand All @@ -22,6 +23,119 @@
logger = logging.getLogger("acro")


def _get_endog_exog_variables(endog: ArrayLike, exog: ArrayLike) -> list[str]:
"""Extract variable names from endog and exog arguments.

Parameters
----------
endog : array_like
The dependent variable (Series or array).
exog : array_like
The independent variables (DataFrame, Series, or array).

Returns
-------
list[str]
List of variable names: [dependent, independent1, independent2, ...].
"""
variables: list[str] = []

if hasattr(endog, "name") and endog.name is not None:
variables.append(str(endog.name))
if hasattr(exog, "columns"):
for col in exog.columns:
if str(col) != "const":
variables.append(str(col))
elif hasattr(exog, "name") and exog.name is not None:
variables.append(str(exog.name))
return variables


def _split_formula_terms(text: str, delimiters: str = "+") -> list[str]:
"""Split a formula string on delimiters, but only outside parentheses.

Parameters
----------
text : str
The string to split.
delimiters : str
Characters to split on (e.g., '+' or ':*').

Returns
-------
list[str]
The split terms.
"""
terms: list[str] = []
depth = 0
current: list[str] = []
for char in text:
if char == "(":
depth += 1
current.append(char)
elif char == ")":
depth -= 1
current.append(char)
elif char in delimiters and depth == 0:
terms.append("".join(current))
current = []
else:
current.append(char)
terms.append("".join(current))
return terms


def _get_formula_variables(formula: str) -> list[str]: # noqa: C901
"""Extract variable names from an formula string.

Parses formulas like 'y ~ x1 + x2 + x3' to extract variable names.
Handles interaction terms (x1:x2), polynomial terms I(x^2), and
categorical terms C(x), respecting parentheses nesting.

Parameters
----------
formula : str
An R-style formula string, e.g., 'y ~ x1 + x2'.

Returns
-------
list[str]
List of variable names: [dependent, independent1, independent2, ...].
"""
variables: list[str] = []
parts = formula.split("~")
if len(parts) != 2:
return variables
dep_var = parts[0].strip()
if dep_var:
variables.append(dep_var)
rhs = parts[1].strip()
terms = _split_formula_terms(rhs, "+")
for term in terms:
term = term.strip()
if not term or term == "1":
continue
sub_terms = _split_formula_terms(term, ":*")
for sub in sub_terms:
sub = sub.strip()
if not sub or sub == "1":
continue
sub = re.sub(r"^[IC]\(", "", sub)
sub = re.sub(r"\)$", "", sub)
sub = re.sub(r"\^\d+$", "", sub)
while sub.startswith("(") and sub.endswith(")"):
sub = sub[1:-1]
sub = sub.strip()
if "+" in sub:
for inner in _split_formula_terms(sub, "+"):
inner = inner.strip()
if inner and inner not in variables:
variables.append(inner)
elif sub and sub not in variables:
variables.append(sub)
return variables


class Regression:
"""Creates regression models."""

Expand Down Expand Up @@ -73,10 +187,11 @@ def ols(
results = model.fit()
status, summary, dof = self.__check_model_dof("ols", model)
tables: list[SimpleTable] = results.summary().tables
vars_used = _get_endog_exog_variables(endog, exog)
self.results.add(
status=status,
output_type="regression",
properties={"method": "ols", "dof": dof},
properties={"method": "ols", "dof": dof, "variables": vars_used},
sdc={},
command=command,
summary=summary,
Expand All @@ -85,7 +200,7 @@ def ols(
)
return results

def olsr(
def olsr( # pylint: disable=keyword-arg-before-vararg
self,
formula: str,
data: Any,
Expand Down Expand Up @@ -144,10 +259,11 @@ def olsr(
results = model.fit()
status, summary, dof = self.__check_model_dof("olsr", model)
tables: list[SimpleTable] = results.summary().tables
vars_used = _get_formula_variables(formula)
self.results.add(
status=status,
output_type="regression",
properties={"method": "olsr", "dof": dof},
properties={"method": "olsr", "dof": dof, "variables": vars_used},
sdc={},
command=command,
summary=summary,
Expand Down Expand Up @@ -193,10 +309,11 @@ def logit(
results = model.fit()
status, summary, dof = self.__check_model_dof("logit", model)
tables: list[SimpleTable] = results.summary().tables
vars_used = _get_endog_exog_variables(endog, exog)
self.results.add(
status=status,
output_type="regression",
properties={"method": "logit", "dof": dof},
properties={"method": "logit", "dof": dof, "variables": vars_used},
sdc={},
command=command,
summary=summary,
Expand All @@ -205,7 +322,7 @@ def logit(
)
return results

def logitr(
def logitr( # pylint: disable=keyword-arg-before-vararg
self,
formula: str,
data: Any,
Expand Down Expand Up @@ -264,10 +381,11 @@ def logitr(
results = model.fit()
status, summary, dof = self.__check_model_dof("logitr", model)
tables: list[SimpleTable] = results.summary().tables
vars_used = _get_formula_variables(formula)
self.results.add(
status=status,
output_type="regression",
properties={"method": "logitr", "dof": dof},
properties={"method": "logitr", "dof": dof, "variables": vars_used},
sdc={},
command=command,
summary=summary,
Expand Down Expand Up @@ -313,10 +431,11 @@ def probit(
results = model.fit()
status, summary, dof = self.__check_model_dof("probit", model)
tables: list[SimpleTable] = results.summary().tables
vars_used = _get_endog_exog_variables(endog, exog)
self.results.add(
status=status,
output_type="regression",
properties={"method": "probit", "dof": dof},
properties={"method": "probit", "dof": dof, "variables": vars_used},
sdc={},
command=command,
summary=summary,
Expand All @@ -325,7 +444,7 @@ def probit(
)
return results

def probitr(
def probitr( # pylint: disable=keyword-arg-before-vararg
self,
formula: str,
data: Any,
Expand Down Expand Up @@ -384,10 +503,11 @@ def probitr(
results = model.fit()
status, summary, dof = self.__check_model_dof("probitr", model)
tables: list[SimpleTable] = results.summary().tables
vars_used = _get_formula_variables(formula)
self.results.add(
status=status,
output_type="regression",
properties={"method": "probitr", "dof": dof},
properties={"method": "probitr", "dof": dof, "variables": vars_used},
sdc={},
command=command,
summary=summary,
Expand Down
59 changes: 45 additions & 14 deletions acro/acro_tables.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""ACRO: Tables functions."""

# pylint: disable=too-many-lines
from __future__ import annotations

import logging
Expand Down Expand Up @@ -72,7 +73,7 @@ def __init__(self, suppress: bool) -> None:
self.suppress: bool = suppress
self.results: Records = Records()

def crosstab(
def crosstab( # pylint: disable=too-many-arguments,too-many-locals # noqa: C901
self,
index: Any,
columns: Any,
Expand Down Expand Up @@ -212,14 +213,29 @@ def crosstab(
colnames=colnames,
normalize=normalize,
)
sdc = get_table_sdc(masks, self.suppress, table)

vars_used: list[str] = []
if isinstance(index, pd.Series):
vars_used.append(index.name)
elif isinstance(index, list):
for var in index:
if isinstance(var, pd.Series):
vars_used.append(var.name)
if isinstance(columns, pd.Series):
vars_used.append(columns.name)
elif isinstance(columns, list):
for var in columns:
if isinstance(var, pd.Series):
vars_used.append(var.name)
if values is not None and isinstance(values, pd.Series):
vars_used.append(values.name)

# record output

self.results.add(
status=status,
output_type="table",
properties={"method": "crosstab"},
properties={"method": "crosstab", "variables": vars_used},
sdc=sdc,
command=command,
summary=summary,
Expand All @@ -234,7 +250,7 @@ def crosstab(
)
return table

def pivot_table(
def pivot_table( # pylint: disable=too-many-arguments,too-many-locals # noqa: C901
self,
data: DataFrame,
values: Any = None,
Expand Down Expand Up @@ -422,12 +438,27 @@ def pivot_table(
observed=observed,
sort=sort,
)
sdc = get_table_sdc(masks, self.suppress, table)

vars_used: list[str] = []
if isinstance(index, list):
vars_used.extend(index)
elif index is not None:
vars_used.append(index)
if isinstance(columns, list):
vars_used.extend(columns)
elif columns is not None:
vars_used.append(columns)
if isinstance(values, list):
vars_used.extend(values)
elif values is not None:
vars_used.append(values)
vars_used = [str(v) for v in vars_used]

# record output
self.results.add(
status=status,
output_type="table",
properties={"method": "pivot_table"},
properties={"method": "pivot_table", "variables": vars_used},
sdc=sdc,
command=command,
summary=summary,
Expand All @@ -442,7 +473,7 @@ def pivot_table(
)
return table

def surv_func(
def surv_func( # pylint: disable=too-many-arguments,too-many-locals
self,
time: Any,
status: Any,
Expand Down Expand Up @@ -541,7 +572,7 @@ def surv_func(
return (plot, output_filename)
return None

def survival_table(
def survival_table( # pylint: disable=too-many-arguments
self,
survival_table: DataFrame,
safe_table: DataFrame,
Expand All @@ -566,7 +597,7 @@ def survival_table(
)
return survival_table

def survival_plot(
def survival_plot( # pylint: disable=too-many-arguments
self,
survival_table: DataFrame,
survival_func: Any,
Expand Down Expand Up @@ -617,7 +648,7 @@ def survival_plot(
)
return (plot, unique_filename)

def hist(
def hist( # pylint: disable=too-many-arguments,too-many-locals
self,
data: DataFrame,
column: str,
Expand Down Expand Up @@ -914,7 +945,7 @@ def pie(
return unique_filename


def create_crosstab_masks(
def create_crosstab_masks( # pylint: disable=too-many-arguments,too-many-locals
index: Any,
columns: Any,
values: Any,
Expand Down Expand Up @@ -1365,7 +1396,7 @@ def _align_mask_columns(m: DataFrame, table: DataFrame) -> DataFrame:
if table_nlevels == 2 and mask_nlevels == 2:
table_top = table.columns.get_level_values(0).unique().tolist()
mask_top = m.columns.get_level_values(0).unique().tolist()
if len(mask_top) == 1 and len(table_top) > 1:
if mask_top != table_top:
n_base = len(table.columns.get_level_values(1).unique())
base_mask = m.iloc[:, :n_base]
flat_cols = base_mask.columns.get_level_values(1)
Expand Down Expand Up @@ -1771,7 +1802,7 @@ def get_index_columns(
return index_new, columns_new


def crosstab_with_totals(
def crosstab_with_totals( # pylint: disable=too-many-arguments,too-many-locals
masks: dict[str, DataFrame],
aggfunc: Any,
index: Any,
Expand Down Expand Up @@ -1907,7 +1938,7 @@ def crosstab_with_totals(
return table


def manual_crossstab_with_totals(
def manual_crossstab_with_totals( # pylint: disable=too-many-arguments
table: DataFrame,
aggfunc: str | list[str] | None,
index: Any,
Expand Down
Loading
Loading