diff --git a/acro/acro.py b/acro/acro.py index 2dd145a..677a02f 100644 --- a/acro/acro.py +++ b/acro/acro.py @@ -60,12 +60,14 @@ def __init__(self, config: str = "default", suppress: bool = False) -> None: Tables.__init__(self, suppress) Regression.__init__(self, config) self.config: dict[str, Any] = {} - self.results: Records = Records() self.suppress: bool = suppress path: pathlib.Path = pathlib.Path(__file__).with_name(config + ".yaml") logger.debug("path: %s", path) with open(path, encoding="utf-8") as handle: self.config = yaml.load(handle, Loader=yaml.loader.SafeLoader) + self.results: Records = Records( + blocked_extensions=self.config.get("blocked_extensions", []) + ) logger.info("version: %s", __version__) logger.info("config: %s", self.config) logger.info("automatic suppression: %s", self.suppress) @@ -78,6 +80,10 @@ def __init__(self, config: str = "default", suppress: bool = False) -> None: acro_tables.ZEROS_ARE_DISCLOSIVE = self.config["zeros_are_disclosive"] # set globals for survival analysis acro_tables.SURVIVAL_THRESHOLD = self.config["survival_safe_threshold"] + # set globals for blocked file extensions + acro_tables.BLOCKED_EXTENSIONS = [ + ext.lower() for ext in (self.config.get("blocked_extensions") or []) + ] def finalise( self, path: str = "outputs", ext: str = "json", interactive: bool = False @@ -138,7 +144,7 @@ def print_outputs(self) -> str: """ return self.results.print() - def custom_output(self, filename: str, comment: str = "") -> None: + def custom_output(self, filename: str, comment: str = "") -> bool: """Add an unsupported output to the results dictionary. Parameters @@ -147,8 +153,13 @@ def custom_output(self, filename: str, comment: str = "") -> None: The name of the file that will be added to the list of the outputs. comment : str An optional comment. + + Returns + ------- + bool + False if the file extension is blocked, True otherwise. 
""" - self.results.add_custom(filename, comment) + return self.results.add_custom(filename, comment) def rename_output(self, old: str, new: str) -> None: """Rename an output. diff --git a/acro/acro_stata_parser.py b/acro/acro_stata_parser.py index 0e1a423..ebcb7f5 100644 --- a/acro/acro_stata_parser.py +++ b/acro/acro_stata_parser.py @@ -376,12 +376,10 @@ def add_custom_output(varlist: list[str]) -> str: except IndexError: return "syntax error: please pass the name of the output to be added" - # .gph extension contain data - _, file_extension = os.path.splitext(the_output) - if file_extension == ".gph": - return "Warning: .gph files may not be exported as they contain data." comment_str = " ".join(varlist) - stata_config.stata_acro.custom_output(the_output, comment_str) + if not stata_config.stata_acro.custom_output(the_output, comment_str): + _, ext = os.path.splitext(the_output) + return f"Warning: {ext} files are not allowed and cannot be exported." outcome = f"file {the_output} with comment {comment_str} added to session." return outcome diff --git a/acro/acro_tables.py b/acro/acro_tables.py index 72ecbe2..bc41748 100644 --- a/acro/acro_tables.py +++ b/acro/acro_tables.py @@ -58,6 +58,9 @@ def mode_aggfunc(values: Series) -> Series: # survival analysis parameters SURVIVAL_THRESHOLD: int = 10 +# blocked file extensions for outputs +BLOCKED_EXTENSIONS: list[str] = [] + class Tables: """Creates tabular data. @@ -577,6 +580,14 @@ def survival_plot( summary: str, ) -> tuple[Any, str] | None: """Create the survival plot according to the status of suppressing.""" + _, extension = os.path.splitext(filename) + if extension.lower() in BLOCKED_EXTENSIONS: + logger.warning( + "Blocked file extension %s. 
Files with extension %s are not allowed.", + extension, + extension, + ) + return None if self.suppress: survival_table = _rounded_survival_table(survival_table) plot = survival_table.plot(y="rounded_survival_fun", xlim=0, ylim=0) @@ -702,6 +713,14 @@ def hist( The name of the file where the histogram is saved. """ logger.debug("hist()") + _, extension = os.path.splitext(filename) + if extension.lower() in BLOCKED_EXTENSIONS: + logger.warning( + "Blocked file extension %s. Files with extension %s are not allowed.", + extension, + extension, + ) + return None command: str = utils.get_command("hist()", stack()) if isinstance(data, list): # pragma: no cover @@ -847,6 +866,14 @@ def pie( The path to the saved pie chart file. """ logger.debug("pie()") + _, extension = os.path.splitext(filename) + if extension.lower() in BLOCKED_EXTENSIONS: + logger.warning( + "Blocked file extension %s. Files with extension %s are not allowed.", + extension, + extension, + ) + return None command: str = utils.get_command("pie()", stack()) # COMPUTE PRE-CATEGORY COUNTS diff --git a/acro/default.yaml b/acro/default.yaml index d0b3d92..c178030 100644 --- a/acro/default.yaml +++ b/acro/default.yaml @@ -29,4 +29,13 @@ survival_safe_threshold: 10 # consider zeros to be disclosive zeros_are_disclosive: True + +################################################################################ +# Blocked file extensions # +################################################################################ +# File extensions that are not allowed in custom outputs or plots. +# Extensions are case-insensitive and must include the leading dot. +blocked_extensions: + - .svg + - .gph ... 
diff --git a/acro/record.py b/acro/record.py index 1868cfe..e04a0d4 100644 --- a/acro/record.py +++ b/acro/record.py @@ -209,10 +209,19 @@ def __str__(self) -> str: class Records: """Stores data related to a collection of output records.""" - def __init__(self) -> None: - """Construct a new object for storing multiple records.""" + def __init__(self, blocked_extensions: list[str] | None = None) -> None: + """Construct a new object for storing multiple records. + + Parameters + ---------- + blocked_extensions : list[str] | None, default None + File extensions that are not allowed in custom outputs. + """ self.results: dict[str, Record] = {} self.output_id: int = 0 + self.blocked_extensions: list[str] = [ + ext.lower() for ext in (blocked_extensions or []) + ] def add( self, @@ -322,7 +331,7 @@ def get_index(self, index: int) -> Record: key = list(self.results.keys())[index] return self.results[key] - def add_custom(self, filename: str, comment: str | None = None) -> None: + def add_custom(self, filename: str, comment: str | None = None) -> bool: """Add an unsupported output to the results dictionary. Parameters @@ -331,7 +340,20 @@ def add_custom(self, filename: str, comment: str | None = None) -> None: The name of the file that will be added to the list of the outputs. comment : str | None, default None An optional comment. + + Returns + ------- + bool + False if the file extension is blocked, True otherwise. """ + _, ext = os.path.splitext(filename) + if ext.lower() in self.blocked_extensions: + logger.warning( + "Blocked file extension %s. 
Files with extension %s are not allowed.", + ext, + ext, + ) + return False if os.path.exists(filename): output = Record( uid=f"output_{self.output_id}", @@ -352,6 +374,7 @@ def add_custom(self, filename: str, comment: str | None = None) -> None: logger.info( "WARNING: Unable to add %s because the file does not exist", filename ) # pragma: no cover + return True def rename(self, old: str, new: str) -> None: """Rename an output. diff --git a/test/test_initial.py b/test/test_initial.py index a9af6cb..2d4b2d7 100644 --- a/test/test_initial.py +++ b/test/test_initial.py @@ -505,6 +505,61 @@ def test_custom_output(acro): shutil.rmtree(PATH) +def test_blocked_extension(acro, tmp_path): + """Test that blocked file extensions are rejected in custom output.""" + # create temporary files with blocked extensions + svg_file = tmp_path / "test.svg" + svg_file.write_text("") + gph_file = tmp_path / "test.gph" + gph_file.write_text("data") + + # blocked extensions should be rejected + acro.custom_output(str(svg_file)) + acro.custom_output(str(gph_file)) + assert len(acro.results.results) == 0 + + # allowed extensions should be accepted + txt_file = tmp_path / "test.txt" + txt_file.write_text("hello") + acro.custom_output(str(txt_file)) + assert len(acro.results.results) == 1 + + # case-insensitive check + svg_upper = tmp_path / "test.SVG" + svg_upper.write_text("") + acro.custom_output(str(svg_upper)) + assert len(acro.results.results) == 1 + + +def test_blocked_extension_hist(data, acro): + """Test that blocked file extensions are rejected for histograms.""" + result = acro.hist(data, "inc_grants", bins=1, filename="hist.svg") + assert result is None + assert len(acro.results.results) == 0 + + +def test_blocked_extension_pie(data, acro): + """Test that blocked file extensions are rejected for pie charts.""" + result = acro.pie(data, "grant_type", filename="pie.svg") + assert result is None + assert len(acro.results.results) == 0 + + +def 
test_blocked_extension_survival(acro): + """Test that blocked file extensions are rejected for survival plots.""" + result = acro.survival_plot( + survival_table=pd.DataFrame(), + survival_func=None, + filename="surv.svg", + status="pass", + sdc={}, + command="test", + summary="test", + ) + assert result is None + assert len(acro.results.results) == 0 + + def test_missing(data, acro, monkeypatch): """Pivot table and Crosstab with negative values.""" acro_tables.CHECK_MISSING_VALUES = True diff --git a/test/test_stata17_interface.py b/test/test_stata17_interface.py index 16f6001..b0f7067 100644 --- a/test/test_stata17_interface.py +++ b/test/test_stata17_interface.py @@ -472,7 +472,7 @@ def test_stata_custom_output_invalid(): options="nototals", stata_version="17", ) - correct = "Warning: .gph files may not be exported as they contain data." + correct = "Warning: .gph files are not allowed and cannot be exported." assert ret == correct, f" we got : {ret}\nexpected:{correct}" newres = stata_config.stata_acro.results.__dict__ assert newres == previous, ( diff --git a/test/test_stata_interface.py b/test/test_stata_interface.py index 449467f..78c3d94 100644 --- a/test/test_stata_interface.py +++ b/test/test_stata_interface.py @@ -562,7 +562,7 @@ def test_stata_custom_output_invalid(): options="nototals", stata_version="17", ) - correct = "Warning: .gph files may not be exported as they contain data." + correct = "Warning: .gph files are not allowed and cannot be exported." assert ret == correct, f" we got : {ret}\nexpected:{correct}" newres = stata_config.stata_acro.results.__dict__ assert newres == previous, (