Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions shelfmark/config/security_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from collections.abc import Callable

_OIDC_LOCKOUT_MESSAGE = "A local admin account with a password is required before enabling OIDC. Use the 'Go to Users' button above to create one. This ensures you can still sign in if your identity provider is unavailable."
_OIDC_REQUIRED_FIELDS = (
("OIDC_DISCOVERY_URL", "Discovery URL"),
("OIDC_CLIENT_ID", "Client ID"),
("OIDC_CLIENT_SECRET", "Client Secret"),
)


def _has_local_password_admin() -> bool:
Expand All @@ -23,6 +28,30 @@ def _has_local_password_admin() -> bool:
)


def _load_effective_security_values(values: dict[str, Any]) -> dict[str, Any]:
"""Merge the current save payload onto the persisted security config."""
from shelfmark.core.settings_registry import load_config_file

effective_values = load_config_file("security")
effective_values.update(values)
return effective_values


def _get_missing_oidc_required_fields(effective_values: dict[str, Any]) -> list[str]:
"""Return missing required OIDC field labels from the effective config."""
missing_fields: list[str] = []

for key, label in _OIDC_REQUIRED_FIELDS:
value = effective_values.get(key)
if value is None:
missing_fields.append(label)
continue
if isinstance(value, str) and not value.strip():
missing_fields.append(label)

return missing_fields


def on_save_security(
values: dict[str, Any],
) -> dict[str, Any]:
Expand All @@ -45,8 +74,21 @@ def on_save_security(
strip_trailing_slash=False,
)

if normalized_values.get("AUTH_METHOD") == "oidc" and not _has_local_password_admin():
return {"error": True, "message": _OIDC_LOCKOUT_MESSAGE, "values": normalized_values}
effective_values = _load_effective_security_values(normalized_values)
auth_method = str(effective_values.get("AUTH_METHOD", "") or "").strip().lower()

if auth_method == "oidc":
if not _has_local_password_admin():
return {"error": True, "message": _OIDC_LOCKOUT_MESSAGE, "values": normalized_values}

missing_fields = _get_missing_oidc_required_fields(effective_values)
if missing_fields:
missing_fields_text = ", ".join(missing_fields)
return {
"error": True,
"message": f"OIDC configuration is incomplete: missing {missing_fields_text}.",
"values": normalized_values,
}

return {"error": False, "values": normalized_values}

Expand Down
88 changes: 81 additions & 7 deletions tests/config/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from shelfmark.core.user_db import UserDB


def _set_config_dir(monkeypatch, config_dir: Path) -> None:
"""Point config helpers at a test-local config directory."""
monkeypatch.setenv("CONFIG_DIR", str(config_dir))
monkeypatch.setattr("shelfmark.config.env.CONFIG_DIR", config_dir)


@pytest.fixture
def temp_config_dir():
"""Create a temporary config directory for tests."""
Expand Down Expand Up @@ -404,7 +410,7 @@ class TestSecurityOnSave:
def test_on_save_passthrough_for_non_oidc(self, tmp_path, monkeypatch):
from shelfmark.config.security import _on_save_security

monkeypatch.setenv("CONFIG_DIR", str(tmp_path))
_set_config_dir(monkeypatch, tmp_path)
values = {"AUTH_METHOD": "builtin", "PROXY_AUTH_USER_HEADER": "X-Auth-User"}

result = _on_save_security(values.copy())
Expand All @@ -415,32 +421,100 @@ def test_on_save_passthrough_for_non_oidc(self, tmp_path, monkeypatch):
def test_on_save_blocks_oidc_without_local_admin(self, tmp_path, monkeypatch):
from shelfmark.config.security import _on_save_security

monkeypatch.setenv("CONFIG_DIR", str(tmp_path))
_set_config_dir(monkeypatch, tmp_path)
UserDB(str(tmp_path / "users.db")).initialize()

result = _on_save_security({"AUTH_METHOD": "oidc"})

assert result["error"] is True
assert "local admin" in result["message"].lower()

def test_on_save_allows_oidc_with_local_password_admin(self, tmp_path, monkeypatch):
def test_on_save_blocks_oidc_when_client_id_is_missing(self, tmp_path, monkeypatch):
from shelfmark.config.security import _on_save_security

monkeypatch.setenv("CONFIG_DIR", str(tmp_path))
_set_config_dir(monkeypatch, tmp_path)
user_db = UserDB(str(tmp_path / "users.db"))
user_db.initialize()
user_db.create_user(username="admin", password_hash="hash", role="admin")

result = _on_save_security({"AUTH_METHOD": "oidc"})
result = _on_save_security(
{
"AUTH_METHOD": "oidc",
"OIDC_DISCOVERY_URL": "https://auth.example.com/.well-known/openid-configuration",
"OIDC_CLIENT_SECRET": "secret123",
}
)

assert result["error"] is True
assert "client id" in result["message"].lower()

def test_on_save_blocks_oidc_when_discovery_url_is_missing(self, tmp_path, monkeypatch):
from shelfmark.config.security import _on_save_security

_set_config_dir(monkeypatch, tmp_path)
user_db = UserDB(str(tmp_path / "users.db"))
user_db.initialize()
user_db.create_user(username="admin", password_hash="hash", role="admin")

result = _on_save_security(
{
"AUTH_METHOD": "oidc",
"OIDC_CLIENT_ID": "shelfmark",
"OIDC_CLIENT_SECRET": "secret123",
}
)

assert result["error"] is True
assert "discovery url" in result["message"].lower()

def test_on_save_blocks_oidc_when_secret_is_missing(self, tmp_path, monkeypatch):
from shelfmark.config.security import _on_save_security

_set_config_dir(monkeypatch, tmp_path)
user_db = UserDB(str(tmp_path / "users.db"))
user_db.initialize()
user_db.create_user(username="admin", password_hash="hash", role="admin")

result = _on_save_security(
{
"AUTH_METHOD": "oidc",
"OIDC_DISCOVERY_URL": "https://auth.example.com/.well-known/openid-configuration",
"OIDC_CLIENT_ID": "shelfmark",
}
)

assert result["error"] is True
assert "client secret" in result["message"].lower()

def test_on_save_allows_oidc_with_existing_saved_secret(self, tmp_path, monkeypatch):
from shelfmark.config.security import _on_save_security
from shelfmark.core.settings_registry import save_config_file

_set_config_dir(monkeypatch, tmp_path)
user_db = UserDB(str(tmp_path / "users.db"))
user_db.initialize()
user_db.create_user(username="admin", password_hash="hash", role="admin")
save_config_file(
"security",
{
"AUTH_METHOD": "oidc",
"OIDC_DISCOVERY_URL": "https://auth.example.com/.well-known/openid-configuration",
"OIDC_CLIENT_ID": "existing-client",
"OIDC_CLIENT_SECRET": "saved-secret",
},
)

result = _on_save_security({"AUTH_METHOD": "oidc", "OIDC_CLIENT_ID": "updated-client"})

assert result["error"] is False
assert result["values"]["OIDC_CLIENT_ID"] == "updated-client"

def test_on_save_normalizes_oidc_discovery_url_without_stripping_trailing_slash(
self, tmp_path, monkeypatch
):
from shelfmark.config.security import _on_save_security

monkeypatch.setenv("CONFIG_DIR", str(tmp_path))
_set_config_dir(monkeypatch, tmp_path)
values = {"OIDC_DISCOVERY_URL": " 'auth.example.com/.well-known/openid-configuration/' "}

result = _on_save_security(values)
Expand All @@ -454,7 +528,7 @@ def test_on_save_normalizes_oidc_discovery_url_without_stripping_trailing_slash(
def test_on_save_normalizes_proxy_logout_url(self, tmp_path, monkeypatch):
from shelfmark.config.security import _on_save_security

monkeypatch.setenv("CONFIG_DIR", str(tmp_path))
_set_config_dir(monkeypatch, tmp_path)
values = {"PROXY_AUTH_LOGOUT_URL": "auth.example.com/logout"}

result = _on_save_security(values)
Expand Down
10 changes: 9 additions & 1 deletion tests/core/test_admin_users_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,7 @@ def setup_config_dir(self, db_path, tmp_path, monkeypatch):
"""Point CONFIG_DIR to a temp dir so _on_save_security can find users.db."""
config_dir = str(tmp_path)
monkeypatch.setenv("CONFIG_DIR", config_dir)
monkeypatch.setattr("shelfmark.config.env.CONFIG_DIR", tmp_path)
# Create user_db at the path _on_save_security will look for
self._user_db = UserDB(os.path.join(config_dir, "users.db"))
self._user_db.initialize()
Expand Down Expand Up @@ -1768,7 +1769,14 @@ def test_oidc_allowed_with_local_admin(self):
password_hash="hashed_pw",
role="admin",
)
result = self._call_on_save({"AUTH_METHOD": "oidc"})
result = self._call_on_save(
{
"AUTH_METHOD": "oidc",
"OIDC_DISCOVERY_URL": "https://auth.example.com/.well-known/openid-configuration",
"OIDC_CLIENT_ID": "shelfmark",
"OIDC_CLIENT_SECRET": "secret123",
}
)
assert result["error"] is False

def test_non_oidc_methods_not_blocked(self):
Expand Down
Loading