Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions acro/acro.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ def __init__(self, config: str = "default", suppress: bool = False) -> None:
Tables.__init__(self, suppress)
Regression.__init__(self, config)
self.config: dict[str, Any] = {}
self.results: Records = Records()
self.suppress: bool = suppress
path: pathlib.Path = pathlib.Path(__file__).with_name(config + ".yaml")
logger.debug("path: %s", path)
with open(path, encoding="utf-8") as handle:
self.config = yaml.load(handle, Loader=yaml.loader.SafeLoader)
self.results: Records = Records(
blocked_extensions=self.config.get("blocked_extensions", [])
)
logger.info("version: %s", __version__)
logger.info("config: %s", self.config)
logger.info("automatic suppression: %s", self.suppress)
Expand All @@ -78,6 +80,10 @@ def __init__(self, config: str = "default", suppress: bool = False) -> None:
acro_tables.ZEROS_ARE_DISCLOSIVE = self.config["zeros_are_disclosive"]
# set globals for survival analysis
acro_tables.SURVIVAL_THRESHOLD = self.config["survival_safe_threshold"]
# set globals for blocked file extensions
acro_tables.BLOCKED_EXTENSIONS = [
ext.lower() for ext in self.config.get("blocked_extensions", [])
]

def finalise(
self, path: str = "outputs", ext: str = "json", interactive: bool = False
Expand Down Expand Up @@ -138,7 +144,7 @@ def print_outputs(self) -> str:
"""
return self.results.print()

def custom_output(self, filename: str, comment: str = "") -> None:
def custom_output(self, filename: str, comment: str = "") -> bool:
"""Add an unsupported output to the results dictionary.

Parameters
Expand All @@ -147,8 +153,13 @@ def custom_output(self, filename: str, comment: str = "") -> None:
The name of the file that will be added to the list of the outputs.
comment : str
An optional comment.

Returns
-------
bool
False if the file extension is blocked, True otherwise.
"""
self.results.add_custom(filename, comment)
return self.results.add_custom(filename, comment)

def rename_output(self, old: str, new: str) -> None:
"""Rename an output.
Expand Down
8 changes: 3 additions & 5 deletions acro/acro_stata_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,12 +376,10 @@ def add_custom_output(varlist: list[str]) -> str:
except IndexError:
return "syntax error: please pass the name of the output to be added"

# .gph extension contain data
_, file_extension = os.path.splitext(the_output)
if file_extension == ".gph":
return "Warning: .gph files may not be exported as they contain data."
comment_str = " ".join(varlist)
stata_config.stata_acro.custom_output(the_output, comment_str)
if not stata_config.stata_acro.custom_output(the_output, comment_str):
_, ext = os.path.splitext(the_output)
return f"Warning: {ext} files are not allowed and cannot be exported."
outcome = f"file {the_output} with comment {comment_str} added to session."
return outcome

Expand Down
27 changes: 27 additions & 0 deletions acro/acro_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def mode_aggfunc(values: Series) -> Series:
# survival analysis parameters
SURVIVAL_THRESHOLD: int = 10

# blocked file extensions for outputs
BLOCKED_EXTENSIONS: list[str] = []


class Tables:
"""Creates tabular data.
Expand Down Expand Up @@ -577,6 +580,14 @@ def survival_plot(
summary: str,
) -> tuple[Any, str] | None:
"""Create the survival plot according to the status of suppressing."""
_, extension = os.path.splitext(filename)
if extension.lower() in BLOCKED_EXTENSIONS:
logger.warning(
"Blocked file extension %s. Files with extension %s are not allowed.",
extension,
extension,
)
return None
if self.suppress:
survival_table = _rounded_survival_table(survival_table)
plot = survival_table.plot(y="rounded_survival_fun", xlim=0, ylim=0)
Expand Down Expand Up @@ -702,6 +713,14 @@ def hist(
The name of the file where the histogram is saved.
"""
logger.debug("hist()")
_, extension = os.path.splitext(filename)
if extension.lower() in BLOCKED_EXTENSIONS:
logger.warning(
"Blocked file extension %s. Files with extension %s are not allowed.",
extension,
extension,
)
return None
command: str = utils.get_command("hist()", stack())

if isinstance(data, list): # pragma: no cover
Expand Down Expand Up @@ -847,6 +866,14 @@ def pie(
The path to the saved pie chart file.
"""
logger.debug("pie()")
_, extension = os.path.splitext(filename)
if extension.lower() in BLOCKED_EXTENSIONS:
logger.warning(
"Blocked file extension %s. Files with extension %s are not allowed.",
extension,
extension,
)
return None
command: str = utils.get_command("pie()", stack())

# COMPUTE PRE-CATEGORY COUNTS
Expand Down
9 changes: 9 additions & 0 deletions acro/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,13 @@ survival_safe_threshold: 10

# consider zeros to be disclosive
zeros_are_disclosive: True

################################################################################
# Blocked file extensions #
################################################################################
# File extensions that are not allowed in custom outputs or plots.
# Extensions are case-insensitive and must include the leading dot.
blocked_extensions:
- .svg
- .gph
...
29 changes: 26 additions & 3 deletions acro/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,19 @@ def __str__(self) -> str:
class Records:
"""Stores data related to a collection of output records."""

def __init__(self) -> None:
"""Construct a new object for storing multiple records."""
def __init__(self, blocked_extensions: list[str] | None = None) -> None:
"""Construct a new object for storing multiple records.

Parameters
----------
blocked_extensions : list[str] | None, default None
File extensions that are not allowed in custom outputs.
"""
self.results: dict[str, Record] = {}
self.output_id: int = 0
self.blocked_extensions: list[str] = [
ext.lower() for ext in (blocked_extensions or [])
]

def add(
self,
Expand Down Expand Up @@ -322,7 +331,7 @@ def get_index(self, index: int) -> Record:
key = list(self.results.keys())[index]
return self.results[key]

def add_custom(self, filename: str, comment: str | None = None) -> None:
def add_custom(self, filename: str, comment: str | None = None) -> bool:
"""Add an unsupported output to the results dictionary.

Parameters
Expand All @@ -331,7 +340,20 @@ def add_custom(self, filename: str, comment: str | None = None) -> None:
The name of the file that will be added to the list of the outputs.
comment : str | None, default None
An optional comment.

Returns
-------
bool
False if the file extension is blocked, True otherwise.
"""
_, ext = os.path.splitext(filename)
if ext.lower() in self.blocked_extensions:
logger.warning(
"Blocked file extension %s. Files with extension %s are not allowed.",
filename,
ext,
)
return False
if os.path.exists(filename):
output = Record(
uid=f"output_{self.output_id}",
Expand All @@ -352,6 +374,7 @@ def add_custom(self, filename: str, comment: str | None = None) -> None:
logger.info(
"WARNING: Unable to add %s because the file does not exist", filename
) # pragma: no cover
return True

def rename(self, old: str, new: str) -> None:
"""Rename an output.
Expand Down
55 changes: 55 additions & 0 deletions test/test_initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,61 @@ def test_custom_output(acro):
shutil.rmtree(PATH)


def test_blocked_extension(acro, tmp_path):
"""Test that blocked file extensions are rejected in custom output."""
# create temporary files with blocked extensions
svg_file = tmp_path / "test.svg"
svg_file.write_text("<svg></svg>")
gph_file = tmp_path / "test.gph"
gph_file.write_text("data")

# blocked extensions should be rejected
acro.custom_output(str(svg_file))
acro.custom_output(str(gph_file))
assert len(acro.results.results) == 0

# allowed extensions should be accepted
txt_file = tmp_path / "test.txt"
txt_file.write_text("hello")
acro.custom_output(str(txt_file))
assert len(acro.results.results) == 1

# case-insensitive check
svg_upper = tmp_path / "test.SVG"
svg_upper.write_text("<svg></svg>")
acro.custom_output(str(svg_upper))
assert len(acro.results.results) == 1


def test_blocked_extension_hist(data, acro):
"""Test that blocked file extensions are rejected for histograms."""
result = acro.hist(data, "inc_grants", bins=1, filename="hist.svg")
assert result is None
assert len(acro.results.results) == 0


def test_blocked_extension_pie(data, acro):
"""Test that blocked file extensions are rejected for pie charts."""
result = acro.pie(data, "grant_type", filename="pie.svg")
assert result is None
assert len(acro.results.results) == 0


def test_blocked_extension_survival(acro):
"""Test that blocked file extensions are rejected for survival plots."""
result = acro.survival_plot(
survival_table=pd.DataFrame(),
survival_func=None,
filename="surv.svg",
status="pass",
sdc={},
command="test",
summary="test",
)
assert result is None
assert len(acro.results.results) == 0


def test_missing(data, acro, monkeypatch):
"""Pivot table and Crosstab with negative values."""
acro_tables.CHECK_MISSING_VALUES = True
Expand Down
2 changes: 1 addition & 1 deletion test/test_stata17_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def test_stata_custom_output_invalid():
options="nototals",
stata_version="17",
)
correct = "Warning: .gph files may not be exported as they contain data."
correct = "Warning: .gph files are not allowed and cannot be exported."
assert ret == correct, f" we got : {ret}\nexpected:{correct}"
newres = stata_config.stata_acro.results.__dict__
assert newres == previous, (
Expand Down
2 changes: 1 addition & 1 deletion test/test_stata_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def test_stata_custom_output_invalid():
options="nototals",
stata_version="17",
)
correct = "Warning: .gph files may not be exported as they contain data."
correct = "Warning: .gph files are not allowed and cannot be exported."
assert ret == correct, f" we got : {ret}\nexpected:{correct}"
newres = stata_config.stata_acro.results.__dict__
assert newres == previous, (
Expand Down
Loading