From 286c5565da843a1cb41e8ad29677bbecd0f29194 Mon Sep 17 00:00:00 2001 From: abhayrajjais01 Date: Sat, 7 Mar 2026 03:10:31 +0530 Subject: [PATCH] Security hardening: replace unsafe eval() with AST-based evaluator in configure_optimizers --- src/deepforest/main.py | 66 +++++++++++++++++++++++++++++++++++++++++- tests/test_main.py | 30 +++++++++++++++++++ 2 files changed, 95 insertions(+), 1 deletion(-) diff --git a/src/deepforest/main.py b/src/deepforest/main.py index d1879d19b..22c468053 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -1,4 +1,5 @@ # entry point for deepforest model +import ast import importlib import os import warnings @@ -887,7 +888,7 @@ def configure_optimizers(self): # Assume the lambda is a function of epoch def lr_lambda(epoch): - return eval(params.lr_lambda) + return self._safe_eval_lr_lambda(params.lr_lambda, epoch) if scheduler_type == "cosine": scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( @@ -940,6 +941,69 @@ def lr_lambda(epoch): else: return optimizer + @staticmethod + def _safe_eval_lr_lambda(expr: str, epoch: int) -> float: + """Safely evaluate arithmetic scheduler expressions against `epoch`. + + Supported syntax is intentionally limited to numeric constants, + parentheses, unary +/- and arithmetic operators (+, -, *, /, //, %, **), + with `epoch` as the only allowed variable name. + """ + + allowed_binary_ops = ( + ast.Add, + ast.Sub, + ast.Mult, + ast.Div, + ast.FloorDiv, + ast.Mod, + ast.Pow, + ) + allowed_unary_ops = (ast.UAdd, ast.USub) + + def _eval(node): + if isinstance(node, ast.Expression): + return _eval(node.body) + + if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): + return node.value + + if isinstance(node, ast.Name) and node.id == "epoch": + return epoch + + if isinstance(node, ast.BinOp) and isinstance(node.op, allowed_binary_ops): + left = _eval(node.left) + right = _eval(node.right) + if isinstance(node.op, ast.Add): + return left + right + if isinstance(node.op, ast.Sub): + return left - right + if isinstance(node.op, ast.Mult): + return left * right + if isinstance(node.op, ast.Div): + return left / right + if isinstance(node.op, ast.FloorDiv): + return left // right + if isinstance(node.op, ast.Mod): + return left % right + if isinstance(node.op, ast.Pow): + return left**right + + if isinstance(node, ast.UnaryOp) and isinstance(node.op, allowed_unary_ops): + operand = _eval(node.operand) + if isinstance(node.op, ast.UAdd): + return +operand + if isinstance(node.op, ast.USub): + return -operand + + raise ValueError(f"Unsafe lr_lambda expression: {expr}") + + try: + parsed = ast.parse(expr, mode="eval") + return float(_eval(parsed)) + except (SyntaxError, TypeError, ValueError, ZeroDivisionError) as exc: + raise ValueError(f"Unsafe lr_lambda expression: {expr}") from exc + def evaluate( self, csv_file, diff --git a/tests/test_main.py b/tests/test_main.py index d80b28a18..4d15a19a7 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1173,6 +1173,36 @@ def test_custom_log_root(m, tmpdir): version_dir = version_dirs[0] assert version_dir.join("hparams.yaml").exists(), "hparams.yaml not found" +def test_configure_optimizers_rejects_unsafe_lr_lambda(tmp_path): + """Regression test: malicious lr_lambda expressions must be rejected.""" + annotations_file = get_data("testfile_deepforest.csv") + root_dir = os.path.dirname(get_data("testfile_deepforest.csv")) + + config_args = { + "train": { + "lr": 0.01, + "scheduler": { + "type": "lambdaLR", + "params": { + "lr_lambda": "__import__('os').system('echo injected')", + }, + }, + "csv_file": annotations_file, + "root_dir": root_dir, + "fast_dev_run": False, + }, + "validation": { + "csv_file": None, + "root_dir": root_dir, + }, + "log_root": str(tmp_path), + } + + m = main.deepforest(model=torch.nn.Linear(1, 1), config_args=config_args) + + with pytest.raises(ValueError, match="Unsafe lr_lambda"): + m.configure_optimizers() + def test_huggingface_model_loads_correct_label_dict(): """Regression test for #1286: HuggingFace models should load correct label_dict from config.json.