Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 56 additions & 49 deletions acro/acro_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,41 @@ def mode_aggfunc(values) -> Series:
SURVIVAL_THRESHOLD: int = 10


def _save_plot(filename: str) -> str | None:
    """Write the current plot into ``acro_artifacts`` under a non-clashing name.

    The ``acro_artifacts`` directory is created if it is missing. The target
    path is built as ``acro_artifacts/<stem>_<n><extension>``, where ``n`` is
    the smallest non-negative integer for which no file exists yet, so an
    earlier artifact is never overwritten.

    Parameters
    ----------
    filename : str
        Requested file name, including its extension (e.g. ``"hist.png"``).

    Returns
    -------
    str | None
        The path the plot was saved to, or None when `filename` carries no
        extension (nothing is saved in that case).
    """
    try:
        os.makedirs("acro_artifacts")
    except FileExistsError:  # pragma: no cover
        logger.debug("Directory acro_artifacts already exists")
    else:
        logger.debug("Directory acro_artifacts created successfully")

    stem, suffix = os.path.splitext(filename)
    if not suffix:  # pragma: no cover
        # Without an extension the output format is unknown; bail out.
        logger.info("Please provide a valid file extension")
        return None

    # Probe numbered candidates, starting at 0, until an unused one is found.
    counter = 0
    candidate = f"acro_artifacts/{stem}_{counter}{suffix}"
    while os.path.exists(candidate):  # pragma: no cover
        counter += 1
        candidate = f"acro_artifacts/{stem}_{counter}{suffix}"

    plt.savefig(candidate)
    return candidate


class Tables:
"""Creates tabular data.

Expand Down Expand Up @@ -545,32 +580,20 @@ def survival_plot( # pylint: disable=too-many-arguments
self, survival_table, survival_func, filename, status, sdc, command, summary
):
"""Create the survival plot according to the status of suppressing."""
if self.suppress:
survival_table = _rounded_survival_table(survival_table)
plot = survival_table.plot(y="rounded_survival_fun", xlim=0, ylim=0)
else: # pragma: no cover
plot = survival_func.plot()

try:
os.makedirs("acro_artifacts")
logger.debug("Directory acro_artifacts created successfully")
except FileExistsError: # pragma: no cover
logger.debug("Directory acro_artifacts already exists")

# create a unique filename with number to avoid overwrite
filename, extension = os.path.splitext(filename)
if not extension: # pragma: no cover
logger.info("Please provide a valid file extension")
return None # pragma: no cover
increment_number = 0
while os.path.exists(
f"acro_artifacts/{filename}_{increment_number}{extension}"
): # pragma: no cover
increment_number += 1
unique_filename = f"acro_artifacts/{filename}_{increment_number}{extension}"

# save the plot to the acro artifacts directory
plt.savefig(unique_filename)
if self.suppress and status == "fail":
logger.warning("Survival plot will not be shown as it is disclosive.")
unique_filename = None
plot = None
else:
if self.suppress:
survival_table = _rounded_survival_table(survival_table)
plot = survival_table.plot(y="rounded_survival_fun", xlim=0, ylim=0)
else: # pragma: no cover
plot = survival_func.plot()

unique_filename = _save_plot(filename)

output_list = [os.path.normpath(unique_filename)] if unique_filename else []

# record output
self.results.add(
Expand All @@ -581,7 +604,7 @@ def survival_plot( # pylint: disable=too-many-arguments
command=command,
summary=summary,
outcome=pd.DataFrame(),
output=[os.path.normpath(unique_filename)],
output=output_list,
)
return (plot, unique_filename)

Expand Down Expand Up @@ -694,6 +717,7 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals
"Histogram will not be shown as the %s column is disclosive.",
column,
)
unique_filename = None
else: # pragma: no cover
data.hist(
column=column,
Expand All @@ -713,6 +737,7 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals
legend=legend,
**kwargs,
)
unique_filename = _save_plot(filename)
else:
status = "review"
data.hist(
Expand All @@ -733,6 +758,8 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals
legend=legend,
**kwargs,
)
unique_filename = _save_plot(filename)

logger.info("status: %s", status)

# create the summary
Expand All @@ -744,27 +771,7 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals
f"The maximum value of the {column} column is: {max_value}"
)

# create the acro_artifacts directory to save the plot in it
try:
os.makedirs("acro_artifacts")
logger.debug("Directory acro_artifacts created successfully")
except FileExistsError: # pragma: no cover
logger.debug("Directory acro_artifacts already exists")

# create a unique filename with number to avoid overwrite
filename, extension = os.path.splitext(filename)
if not extension: # pragma: no cover
logger.info("Please provide a valid file extension")
return None
increment_number = 0
while os.path.exists(
f"acro_artifacts/{filename}_{increment_number}{extension}"
): # pragma: no cover
increment_number += 1
unique_filename = f"acro_artifacts/{filename}_{increment_number}{extension}"

# save the plot to the acro artifacts directory
plt.savefig(unique_filename)
output_list = [os.path.normpath(unique_filename)] if unique_filename else []

# record output
self.results.add(
Expand All @@ -775,7 +782,7 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals
command=command,
summary=summary,
outcome=pd.DataFrame(),
output=[os.path.normpath(unique_filename)],
output=output_list,
)
return unique_filename

Expand Down
4 changes: 2 additions & 2 deletions test/test_initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,11 +1079,11 @@ def test_histogram_disclosive(data, acro, caplog):
"""Test a discolsive histogram."""
filename = os.path.normpath("acro_artifacts/histogram_0.png")
_ = acro.hist(data, "inc_grants")
assert os.path.exists(filename)
assert not os.path.exists(filename)
acro.add_exception("output_0", "Let me have it")
results: Records = acro.finalise(path=PATH)
output_0 = results.get_index(0)
assert output_0.output == [filename]
assert output_0.output == []
assert (
"Histogram will not be shown as the inc_grants column is disclosive."
in caplog.text
Expand Down