Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion src/deepforest/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# entry point for deepforest model
import ast
import importlib
import os
import warnings
Expand Down Expand Up @@ -887,7 +888,7 @@ def configure_optimizers(self):

# Assume the lambda is a function of epoch
def lr_lambda(epoch):
return eval(params.lr_lambda)
return self._safe_eval_lr_lambda(params.lr_lambda, epoch)

if scheduler_type == "cosine":
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
Expand Down Expand Up @@ -940,6 +941,69 @@ def lr_lambda(epoch):
else:
return optimizer

@staticmethod
def _safe_eval_lr_lambda(expr: str, epoch: int) -> float:
    """Safely evaluate an arithmetic scheduler expression against `epoch`.

    The expression is parsed with `ast` and walked by hand, so nothing is
    ever passed to `eval`. Supported syntax is intentionally limited to
    numeric constants (int/float, not bool), parentheses, unary +/- and the
    arithmetic operators (+, -, *, /, //, %, **), with `epoch` as the only
    allowed variable name.

    Args:
        expr: The expression text, e.g. ``"0.95 ** epoch"``.
        epoch: The value substituted for the name ``epoch``.

    Returns:
        The value of the expression converted to ``float``.

    Raises:
        ValueError: If the expression cannot be parsed, uses any syntax
            outside the supported subset, or fails numerically (division
            by zero, overflow, complex result, excessive nesting or an
            unreasonably large integer power).
    """
    import operator  # local import: only this helper needs it

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    # Upper bound on the estimated bit size of an int ** int result. Far
    # beyond anything representable as a float, but small enough that a
    # tower such as 9 ** 9 ** 9 is refused instead of stalling the process.
    max_pow_bits = 10_000

    def _checked_pow(base, exponent):
        if (
            isinstance(base, int)
            and isinstance(exponent, int)
            and exponent > 0
            and base.bit_length() * exponent > max_pow_bits
        ):
            raise ValueError(f"Unsafe lr_lambda expression: {expr}")
        return base**exponent

    binary_ops[ast.Pow] = _checked_pow

    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)

        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int; True/False are not numeric literals.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return value

        elif isinstance(node, ast.Name):
            if node.id == "epoch":
                return epoch

        elif isinstance(node, ast.BinOp):
            apply_binary = binary_ops.get(type(node.op))
            if apply_binary is not None:
                # Arguments are evaluated left operand first, then right.
                return apply_binary(_eval(node.left), _eval(node.right))

        elif isinstance(node, ast.UnaryOp):
            apply_unary = unary_ops.get(type(node.op))
            if apply_unary is not None:
                return apply_unary(_eval(node.operand))

        # Anything else (calls, attributes, other names, ...) is rejected.
        raise ValueError(f"Unsafe lr_lambda expression: {expr}")

    try:
        parsed = ast.parse(expr, mode="eval")
        return float(_eval(parsed))
    except (
        SyntaxError,
        TypeError,
        ValueError,
        ArithmeticError,  # ZeroDivisionError and OverflowError
        RecursionError,  # pathologically nested expressions
    ) as exc:
        raise ValueError(f"Unsafe lr_lambda expression: {expr}") from exc

def evaluate(
self,
csv_file,
Expand Down
30 changes: 30 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,36 @@ def test_custom_log_root(m, tmpdir):
version_dir = version_dirs[0]
assert version_dir.join("hparams.yaml").exists(), "hparams.yaml not found"

def test_configure_optimizers_rejects_unsafe_lr_lambda(tmp_path):
    """Regression test: a code-injection payload in lr_lambda must raise.

    Builds a lambdaLR scheduler config whose `lr_lambda` string attempts to
    run a shell command, then checks that `configure_optimizers` raises a
    ValueError mentioning "Unsafe lr_lambda".
    """
    csv_path = get_data("testfile_deepforest.csv")
    data_dir = os.path.dirname(get_data("testfile_deepforest.csv"))

    scheduler_cfg = {
        "type": "lambdaLR",
        "params": {"lr_lambda": "__import__('os').system('echo injected')"},
    }
    train_cfg = {
        "lr": 0.01,
        "scheduler": scheduler_cfg,
        "csv_file": csv_path,
        "root_dir": data_dir,
        "fast_dev_run": False,
    }
    validation_cfg = {"csv_file": None, "root_dir": data_dir}

    model = main.deepforest(
        model=torch.nn.Linear(1, 1),
        config_args={
            "train": train_cfg,
            "validation": validation_cfg,
            "log_root": str(tmp_path),
        },
    )

    with pytest.raises(ValueError, match="Unsafe lr_lambda"):
        model.configure_optimizers()

def test_huggingface_model_loads_correct_label_dict():
"""Regression test for #1286:
HuggingFace models should load correct label_dict from config.json.
Expand Down
Loading