diff --git a/acro/acro.py b/acro/acro.py
index 2dd145a..677a02f 100644
--- a/acro/acro.py
+++ b/acro/acro.py
@@ -60,12 +60,14 @@ def __init__(self, config: str = "default", suppress: bool = False) -> None:
Tables.__init__(self, suppress)
Regression.__init__(self, config)
self.config: dict[str, Any] = {}
- self.results: Records = Records()
self.suppress: bool = suppress
path: pathlib.Path = pathlib.Path(__file__).with_name(config + ".yaml")
logger.debug("path: %s", path)
with open(path, encoding="utf-8") as handle:
self.config = yaml.load(handle, Loader=yaml.loader.SafeLoader)
+ self.results: Records = Records(
+ blocked_extensions=self.config.get("blocked_extensions", [])
+ )
logger.info("version: %s", __version__)
logger.info("config: %s", self.config)
logger.info("automatic suppression: %s", self.suppress)
@@ -78,6 +80,10 @@ def __init__(self, config: str = "default", suppress: bool = False) -> None:
acro_tables.ZEROS_ARE_DISCLOSIVE = self.config["zeros_are_disclosive"]
# set globals for survival analysis
acro_tables.SURVIVAL_THRESHOLD = self.config["survival_safe_threshold"]
+ # set globals for blocked file extensions
+ acro_tables.BLOCKED_EXTENSIONS = [
+ ext.lower() for ext in self.config.get("blocked_extensions", [])
+ ]
def finalise(
self, path: str = "outputs", ext: str = "json", interactive: bool = False
@@ -138,7 +144,7 @@ def print_outputs(self) -> str:
"""
return self.results.print()
- def custom_output(self, filename: str, comment: str = "") -> None:
+ def custom_output(self, filename: str, comment: str = "") -> bool:
"""Add an unsupported output to the results dictionary.
Parameters
@@ -147,8 +153,13 @@ def custom_output(self, filename: str, comment: str = "") -> None:
The name of the file that will be added to the list of the outputs.
comment : str
An optional comment.
+
+ Returns
+ -------
+ bool
+ False if the file extension is blocked, True otherwise.
"""
- self.results.add_custom(filename, comment)
+ return self.results.add_custom(filename, comment)
def rename_output(self, old: str, new: str) -> None:
"""Rename an output.
diff --git a/acro/acro_stata_parser.py b/acro/acro_stata_parser.py
index 0e1a423..ebcb7f5 100644
--- a/acro/acro_stata_parser.py
+++ b/acro/acro_stata_parser.py
@@ -376,12 +376,10 @@ def add_custom_output(varlist: list[str]) -> str:
except IndexError:
return "syntax error: please pass the name of the output to be added"
- # .gph extension contain data
- _, file_extension = os.path.splitext(the_output)
- if file_extension == ".gph":
- return "Warning: .gph files may not be exported as they contain data."
comment_str = " ".join(varlist)
- stata_config.stata_acro.custom_output(the_output, comment_str)
+ if not stata_config.stata_acro.custom_output(the_output, comment_str):
+ _, ext = os.path.splitext(the_output)
+ return f"Warning: {ext} files are not allowed and cannot be exported."
outcome = f"file {the_output} with comment {comment_str} added to session."
return outcome
diff --git a/acro/acro_tables.py b/acro/acro_tables.py
index 72ecbe2..bc41748 100644
--- a/acro/acro_tables.py
+++ b/acro/acro_tables.py
@@ -58,6 +58,9 @@ def mode_aggfunc(values: Series) -> Series:
# survival analysis parameters
SURVIVAL_THRESHOLD: int = 10
+# blocked file extensions for outputs
+BLOCKED_EXTENSIONS: list[str] = []
+
class Tables:
"""Creates tabular data.
@@ -577,6 +580,14 @@ def survival_plot(
summary: str,
) -> tuple[Any, str] | None:
"""Create the survival plot according to the status of suppressing."""
+ _, extension = os.path.splitext(filename)
+ if extension.lower() in BLOCKED_EXTENSIONS:
+ logger.warning(
+ "Blocked file extension %s. Files with extension %s are not allowed.",
+ extension,
+ extension,
+ )
+ return None
if self.suppress:
survival_table = _rounded_survival_table(survival_table)
plot = survival_table.plot(y="rounded_survival_fun", xlim=0, ylim=0)
@@ -702,6 +713,14 @@ def hist(
The name of the file where the histogram is saved.
"""
logger.debug("hist()")
+ _, extension = os.path.splitext(filename)
+ if extension.lower() in BLOCKED_EXTENSIONS:
+ logger.warning(
+ "Blocked file extension %s. Files with extension %s are not allowed.",
+ extension,
+ extension,
+ )
+ return None
command: str = utils.get_command("hist()", stack())
if isinstance(data, list): # pragma: no cover
@@ -847,6 +866,14 @@ def pie(
The path to the saved pie chart file.
"""
logger.debug("pie()")
+ _, extension = os.path.splitext(filename)
+ if extension.lower() in BLOCKED_EXTENSIONS:
+ logger.warning(
+ "Blocked file extension %s. Files with extension %s are not allowed.",
+ extension,
+ extension,
+ )
+ return None
command: str = utils.get_command("pie()", stack())
# COMPUTE PRE-CATEGORY COUNTS
diff --git a/acro/default.yaml b/acro/default.yaml
index d0b3d92..c178030 100644
--- a/acro/default.yaml
+++ b/acro/default.yaml
@@ -29,4 +29,13 @@ survival_safe_threshold: 10
# consider zeros to be disclosive
zeros_are_disclosive: True
+
+################################################################################
+# Blocked file extensions #
+################################################################################
+# File extensions that are not allowed in custom outputs or plots.
+# Extensions are case-insensitive and must include the leading dot.
+blocked_extensions:
+ - .svg
+ - .gph
...
diff --git a/acro/record.py b/acro/record.py
index 1868cfe..e04a0d4 100644
--- a/acro/record.py
+++ b/acro/record.py
@@ -209,10 +209,19 @@ def __str__(self) -> str:
class Records:
"""Stores data related to a collection of output records."""
- def __init__(self) -> None:
- """Construct a new object for storing multiple records."""
+ def __init__(self, blocked_extensions: list[str] | None = None) -> None:
+ """Construct a new object for storing multiple records.
+
+ Parameters
+ ----------
+ blocked_extensions : list[str] | None, default None
+ File extensions that are not allowed in custom outputs.
+ """
self.results: dict[str, Record] = {}
self.output_id: int = 0
+ self.blocked_extensions: list[str] = [
+ ext.lower() for ext in (blocked_extensions or [])
+ ]
def add(
self,
@@ -322,7 +331,7 @@ def get_index(self, index: int) -> Record:
key = list(self.results.keys())[index]
return self.results[key]
- def add_custom(self, filename: str, comment: str | None = None) -> None:
+ def add_custom(self, filename: str, comment: str | None = None) -> bool:
"""Add an unsupported output to the results dictionary.
Parameters
@@ -331,7 +340,20 @@ def add_custom(self, filename: str, comment: str | None = None) -> None:
The name of the file that will be added to the list of the outputs.
comment : str | None, default None
An optional comment.
+
+ Returns
+ -------
+ bool
+ False if the file extension is blocked, True otherwise.
"""
+ _, ext = os.path.splitext(filename)
+ if ext.lower() in self.blocked_extensions:
+ logger.warning(
+ "Blocked file extension %s. Files with extension %s are not allowed.",
+ filename,
+ ext,
+ )
+ return False
if os.path.exists(filename):
output = Record(
uid=f"output_{self.output_id}",
@@ -352,6 +374,7 @@ def add_custom(self, filename: str, comment: str | None = None) -> None:
logger.info(
"WARNING: Unable to add %s because the file does not exist", filename
) # pragma: no cover
+ return True
def rename(self, old: str, new: str) -> None:
"""Rename an output.
diff --git a/test/test_initial.py b/test/test_initial.py
index a9af6cb..2d4b2d7 100644
--- a/test/test_initial.py
+++ b/test/test_initial.py
@@ -505,6 +505,61 @@ def test_custom_output(acro):
shutil.rmtree(PATH)
+def test_blocked_extension(acro, tmp_path):
+ """Test that blocked file extensions are rejected in custom output."""
+ # create temporary files with blocked extensions
+ svg_file = tmp_path / "test.svg"
+ svg_file.write_text("")
+ gph_file = tmp_path / "test.gph"
+ gph_file.write_text("data")
+
+ # blocked extensions should be rejected
+ acro.custom_output(str(svg_file))
+ acro.custom_output(str(gph_file))
+ assert len(acro.results.results) == 0
+
+ # allowed extensions should be accepted
+ txt_file = tmp_path / "test.txt"
+ txt_file.write_text("hello")
+ acro.custom_output(str(txt_file))
+ assert len(acro.results.results) == 1
+
+ # case-insensitive check
+ svg_upper = tmp_path / "test.SVG"
+ svg_upper.write_text("")
+ acro.custom_output(str(svg_upper))
+ assert len(acro.results.results) == 1
+
+
+def test_blocked_extension_hist(data, acro):
+ """Test that blocked file extensions are rejected for histograms."""
+ result = acro.hist(data, "inc_grants", bins=1, filename="hist.svg")
+ assert result is None
+ assert len(acro.results.results) == 0
+
+
+def test_blocked_extension_pie(data, acro):
+ """Test that blocked file extensions are rejected for pie charts."""
+ result = acro.pie(data, "grant_type", filename="pie.svg")
+ assert result is None
+ assert len(acro.results.results) == 0
+
+
+def test_blocked_extension_survival(acro):
+ """Test that blocked file extensions are rejected for survival plots."""
+ result = acro.survival_plot(
+ survival_table=pd.DataFrame(),
+ survival_func=None,
+ filename="surv.svg",
+ status="pass",
+ sdc={},
+ command="test",
+ summary="test",
+ )
+ assert result is None
+ assert len(acro.results.results) == 0
+
+
def test_missing(data, acro, monkeypatch):
"""Pivot table and Crosstab with negative values."""
acro_tables.CHECK_MISSING_VALUES = True
diff --git a/test/test_stata17_interface.py b/test/test_stata17_interface.py
index 16f6001..b0f7067 100644
--- a/test/test_stata17_interface.py
+++ b/test/test_stata17_interface.py
@@ -472,7 +472,7 @@ def test_stata_custom_output_invalid():
options="nototals",
stata_version="17",
)
- correct = "Warning: .gph files may not be exported as they contain data."
+ correct = "Warning: .gph files are not allowed and cannot be exported."
assert ret == correct, f" we got : {ret}\nexpected:{correct}"
newres = stata_config.stata_acro.results.__dict__
assert newres == previous, (
diff --git a/test/test_stata_interface.py b/test/test_stata_interface.py
index 449467f..78c3d94 100644
--- a/test/test_stata_interface.py
+++ b/test/test_stata_interface.py
@@ -562,7 +562,7 @@ def test_stata_custom_output_invalid():
options="nototals",
stata_version="17",
)
- correct = "Warning: .gph files may not be exported as they contain data."
+ correct = "Warning: .gph files are not allowed and cannot be exported."
assert ret == correct, f" we got : {ret}\nexpected:{correct}"
newres = stata_config.stata_acro.results.__dict__
assert newres == previous, (