From d3c6a8fd61116eb46c9d5aa0fc2694cdc1b40bc0 Mon Sep 17 00:00:00 2001 From: ssrhaso Date: Thu, 9 Apr 2026 12:58:45 +0100 Subject: [PATCH 1/2] feat: create and enforce blocked file extension list for custom outputs --- acro/acro.py | 17 ++++++++++++++--- acro/acro_stata_parser.py | 8 +++----- acro/acro_tables.py | 24 ++++++++++++++++++++++++ acro/default.yaml | 9 +++++++++ acro/record.py | 29 ++++++++++++++++++++++++++--- test/test_initial.py | 26 ++++++++++++++++++++++++++ test/test_stata17_interface.py | 2 +- test/test_stata_interface.py | 2 +- 8 files changed, 104 insertions(+), 13 deletions(-) diff --git a/acro/acro.py b/acro/acro.py index 2dd145a7..677a02fd 100644 --- a/acro/acro.py +++ b/acro/acro.py @@ -60,12 +60,14 @@ def __init__(self, config: str = "default", suppress: bool = False) -> None: Tables.__init__(self, suppress) Regression.__init__(self, config) self.config: dict[str, Any] = {} - self.results: Records = Records() self.suppress: bool = suppress path: pathlib.Path = pathlib.Path(__file__).with_name(config + ".yaml") logger.debug("path: %s", path) with open(path, encoding="utf-8") as handle: self.config = yaml.load(handle, Loader=yaml.loader.SafeLoader) + self.results: Records = Records( + blocked_extensions=self.config.get("blocked_extensions", []) + ) logger.info("version: %s", __version__) logger.info("config: %s", self.config) logger.info("automatic suppression: %s", self.suppress) @@ -78,6 +80,10 @@ def __init__(self, config: str = "default", suppress: bool = False) -> None: acro_tables.ZEROS_ARE_DISCLOSIVE = self.config["zeros_are_disclosive"] # set globals for survival analysis acro_tables.SURVIVAL_THRESHOLD = self.config["survival_safe_threshold"] + # set globals for blocked file extensions + acro_tables.BLOCKED_EXTENSIONS = [ + ext.lower() for ext in self.config.get("blocked_extensions", []) + ] def finalise( self, path: str = "outputs", ext: str = "json", interactive: bool = False @@ -138,7 +144,7 @@ def print_outputs(self) -> str: """ 
return self.results.print() - def custom_output(self, filename: str, comment: str = "") -> None: + def custom_output(self, filename: str, comment: str = "") -> bool: """Add an unsupported output to the results dictionary. Parameters @@ -147,8 +153,13 @@ def custom_output(self, filename: str, comment: str = "") -> None: The name of the file that will be added to the list of the outputs. comment : str An optional comment. + + Returns + ------- + bool + False if the file extension is blocked, True otherwise. """ - self.results.add_custom(filename, comment) + return self.results.add_custom(filename, comment) def rename_output(self, old: str, new: str) -> None: """Rename an output. diff --git a/acro/acro_stata_parser.py b/acro/acro_stata_parser.py index 0e1a4238..ebcb7f57 100644 --- a/acro/acro_stata_parser.py +++ b/acro/acro_stata_parser.py @@ -376,12 +376,10 @@ def add_custom_output(varlist: list[str]) -> str: except IndexError: return "syntax error: please pass the name of the output to be added" - # .gph extension contain data - _, file_extension = os.path.splitext(the_output) - if file_extension == ".gph": - return "Warning: .gph files may not be exported as they contain data." comment_str = " ".join(varlist) - stata_config.stata_acro.custom_output(the_output, comment_str) + if not stata_config.stata_acro.custom_output(the_output, comment_str): + _, ext = os.path.splitext(the_output) + return f"Warning: {ext} files are not allowed and cannot be exported." outcome = f"file {the_output} with comment {comment_str} added to session." return outcome diff --git a/acro/acro_tables.py b/acro/acro_tables.py index 72ecbe27..e77c2b6f 100644 --- a/acro/acro_tables.py +++ b/acro/acro_tables.py @@ -58,6 +58,9 @@ def mode_aggfunc(values: Series) -> Series: # survival analysis parameters SURVIVAL_THRESHOLD: int = 10 +# blocked file extensions for outputs +BLOCKED_EXTENSIONS: list[str] = [] + class Tables: """Creates tabular data. 
@@ -594,6 +597,13 @@ def survival_plot( if not extension: # pragma: no cover logger.info("Please provide a valid file extension") return None # pragma: no cover + if extension.lower() in BLOCKED_EXTENSIONS: + logger.warning( + "Blocked file extension %s. Files with extension %s are not allowed.", + extension, + extension, + ) + return None increment_number = 0 while os.path.exists( f"acro_artifacts/{filename}_{increment_number}{extension}" @@ -788,6 +798,13 @@ def hist( if not extension: # pragma: no cover logger.info("Please provide a valid file extension") return None + if extension.lower() in BLOCKED_EXTENSIONS: + logger.warning( + "Blocked file extension %s. Files with extension %s are not allowed.", + extension, + extension, + ) + return None increment_number = 0 while os.path.exists( f"acro_artifacts/{filename}_{increment_number}{extension}" @@ -888,6 +905,13 @@ def pie( if not extension: # pragma: no cover logger.info("Please provide a valid file extension") return None + if extension.lower() in BLOCKED_EXTENSIONS: + logger.warning( + "Blocked file extension %s. Files with extension %s are not allowed.", + extension, + extension, + ) + return None increment_number = 0 while os.path.exists( diff --git a/acro/default.yaml b/acro/default.yaml index d0b3d92a..c178030d 100644 --- a/acro/default.yaml +++ b/acro/default.yaml @@ -29,4 +29,13 @@ survival_safe_threshold: 10 # consider zeros to be disclosive zeros_are_disclosive: True + +################################################################################ +# Blocked file extensions # +################################################################################ +# File extensions that are not allowed in custom outputs or plots. +# Extensions are case-insensitive and must include the leading dot. +blocked_extensions: + - .svg + - .gph ... 
diff --git a/acro/record.py b/acro/record.py index 1868cfef..e04a0d4b 100644 --- a/acro/record.py +++ b/acro/record.py @@ -209,10 +209,19 @@ def __str__(self) -> str: class Records: """Stores data related to a collection of output records.""" - def __init__(self) -> None: - """Construct a new object for storing multiple records.""" + def __init__(self, blocked_extensions: list[str] | None = None) -> None: + """Construct a new object for storing multiple records. + + Parameters + ---------- + blocked_extensions : list[str] | None, default None + File extensions that are not allowed in custom outputs. + """ self.results: dict[str, Record] = {} self.output_id: int = 0 + self.blocked_extensions: list[str] = [ + ext.lower() for ext in (blocked_extensions or []) + ] def add( self, @@ -322,7 +331,7 @@ def get_index(self, index: int) -> Record: key = list(self.results.keys())[index] return self.results[key] - def add_custom(self, filename: str, comment: str | None = None) -> None: + def add_custom(self, filename: str, comment: str | None = None) -> bool: """Add an unsupported output to the results dictionary. Parameters @@ -331,7 +340,20 @@ def add_custom(self, filename: str, comment: str | None = None) -> None: The name of the file that will be added to the list of the outputs. comment : str | None, default None An optional comment. + + Returns + ------- + bool + False if the file extension is blocked, True otherwise. """ + _, ext = os.path.splitext(filename) + if ext.lower() in self.blocked_extensions: + logger.warning( + "Blocked file extension %s. 
Files with extension %s are not allowed.", + ext, + ext, + ) + return False if os.path.exists(filename): output = Record( uid=f"output_{self.output_id}", @@ -352,6 +374,7 @@ def add_custom(self, filename: str, comment: str | None = None) -> None: logger.info( "WARNING: Unable to add %s because the file does not exist", filename ) # pragma: no cover + return True def rename(self, old: str, new: str) -> None: """Rename an output. diff --git a/test/test_initial.py b/test/test_initial.py index a9af6cbc..6d6e7738 100644 --- a/test/test_initial.py +++ b/test/test_initial.py @@ -505,6 +505,32 @@ def test_custom_output(acro): shutil.rmtree(PATH) +def test_blocked_extension(acro, tmp_path): + """Test that blocked file extensions are rejected in custom output.""" + # create temporary files with blocked extensions + svg_file = tmp_path / "test.svg" + svg_file.write_text("") + gph_file = tmp_path / "test.gph" + gph_file.write_text("data") + + # blocked extensions should be rejected + acro.custom_output(str(svg_file)) + acro.custom_output(str(gph_file)) + assert len(acro.results.results) == 0 + + # allowed extensions should be accepted + txt_file = tmp_path / "test.txt" + txt_file.write_text("hello") + acro.custom_output(str(txt_file)) + assert len(acro.results.results) == 1 + + # case-insensitive check + svg_upper = tmp_path / "test.SVG" + svg_upper.write_text("") + acro.custom_output(str(svg_upper)) + assert len(acro.results.results) == 1 + + def test_missing(data, acro, monkeypatch): """Pivot table and Crosstab with negative values.""" acro_tables.CHECK_MISSING_VALUES = True diff --git a/test/test_stata17_interface.py b/test/test_stata17_interface.py index 16f6001e..b0f70673 100644 --- a/test/test_stata17_interface.py +++ b/test/test_stata17_interface.py @@ -472,7 +472,7 @@ def test_stata_custom_output_invalid(): options="nototals", stata_version="17", ) - correct = "Warning: .gph files may not be exported as they contain data." 
+ correct = "Warning: .gph files are not allowed and cannot be exported." assert ret == correct, f" we got : {ret}\nexpected:{correct}" newres = stata_config.stata_acro.results.__dict__ assert newres == previous, ( diff --git a/test/test_stata_interface.py b/test/test_stata_interface.py index 449467f0..78c3d943 100644 --- a/test/test_stata_interface.py +++ b/test/test_stata_interface.py @@ -562,7 +562,7 @@ def test_stata_custom_output_invalid(): options="nototals", stata_version="17", ) - correct = "Warning: .gph files may not be exported as they contain data." + correct = "Warning: .gph files are not allowed and cannot be exported." assert ret == correct, f" we got : {ret}\nexpected:{correct}" newres = stata_config.stata_acro.results.__dict__ assert newres == previous, ( From 5c2dc104e7f1251ac4b915490ee12e94d22e4b19 Mon Sep 17 00:00:00 2001 From: ssrhaso Date: Thu, 9 Apr 2026 13:07:43 +0100 Subject: [PATCH 2/2] test: add coverage for blocked extension checks in plot outputs --- acro/acro_tables.py | 45 +++++++++++++++++++++++--------------------- test/test_initial.py | 29 ++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 21 deletions(-) diff --git a/acro/acro_tables.py b/acro/acro_tables.py index e77c2b6f..bc41748e 100644 --- a/acro/acro_tables.py +++ b/acro/acro_tables.py @@ -580,6 +580,14 @@ def survival_plot( summary: str, ) -> tuple[Any, str] | None: """Create the survival plot according to the status of suppressing.""" + _, extension = os.path.splitext(filename) + if extension.lower() in BLOCKED_EXTENSIONS: + logger.warning( + "Blocked file extension %s. 
Files with extension %s are not allowed.", + extension, + extension, + ) + return None if self.suppress: survival_table = _rounded_survival_table(survival_table) plot = survival_table.plot(y="rounded_survival_fun", xlim=0, ylim=0) @@ -597,13 +605,6 @@ def survival_plot( if not extension: # pragma: no cover logger.info("Please provide a valid file extension") return None # pragma: no cover - if extension.lower() in BLOCKED_EXTENSIONS: - logger.warning( - "Blocked file extension %s. Files with extension %s are not allowed.", - extension, - extension, - ) - return None increment_number = 0 while os.path.exists( f"acro_artifacts/{filename}_{increment_number}{extension}" @@ -712,6 +713,14 @@ def hist( The name of the file where the histogram is saved. """ logger.debug("hist()") + _, extension = os.path.splitext(filename) + if extension.lower() in BLOCKED_EXTENSIONS: + logger.warning( + "Blocked file extension %s. Files with extension %s are not allowed.", + extension, + extension, + ) + return None command: str = utils.get_command("hist()", stack()) if isinstance(data, list): # pragma: no cover @@ -798,13 +807,6 @@ def hist( if not extension: # pragma: no cover logger.info("Please provide a valid file extension") return None - if extension.lower() in BLOCKED_EXTENSIONS: - logger.warning( - "Blocked file extension %s. Files with extension %s are not allowed.", - extension, - extension, - ) - return None increment_number = 0 while os.path.exists( f"acro_artifacts/{filename}_{increment_number}{extension}" @@ -864,6 +866,14 @@ def pie( The path to the saved pie chart file. """ logger.debug("pie()") + _, extension = os.path.splitext(filename) + if extension.lower() in BLOCKED_EXTENSIONS: + logger.warning( + "Blocked file extension %s. 
Files with extension %s are not allowed.", + extension, + extension, + ) + return None command: str = utils.get_command("pie()", stack()) # COMPUTE PRE-CATEGORY COUNTS @@ -905,13 +915,6 @@ def pie( if not extension: # pragma: no cover logger.info("Please provide a valid file extension") return None - if extension.lower() in BLOCKED_EXTENSIONS: - logger.warning( - "Blocked file extension %s. Files with extension %s are not allowed.", - extension, - extension, - ) - return None increment_number = 0 while os.path.exists( diff --git a/test/test_initial.py b/test/test_initial.py index 6d6e7738..2d4b2d78 100644 --- a/test/test_initial.py +++ b/test/test_initial.py @@ -531,6 +531,35 @@ def test_blocked_extension(acro, tmp_path): assert len(acro.results.results) == 1 +def test_blocked_extension_hist(data, acro): + """Test that blocked file extensions are rejected for histograms.""" + result = acro.hist(data, "inc_grants", bins=1, filename="hist.svg") + assert result is None + assert len(acro.results.results) == 0 + + +def test_blocked_extension_pie(data, acro): + """Test that blocked file extensions are rejected for pie charts.""" + result = acro.pie(data, "grant_type", filename="pie.svg") + assert result is None + assert len(acro.results.results) == 0 + + +def test_blocked_extension_survival(acro): + """Test that blocked file extensions are rejected for survival plots.""" + result = acro.survival_plot( + survival_table=pd.DataFrame(), + survival_func=None, + filename="surv.svg", + status="pass", + sdc={}, + command="test", + summary="test", + ) + assert result is None + assert len(acro.results.results) == 0 + + def test_missing(data, acro, monkeypatch): """Pivot table and Crosstab with negative values.""" acro_tables.CHECK_MISSING_VALUES = True