Skip to content

Commit 545df4f

Browse files
eenblamBen ElamMark90
authored
Make auth callbacks async (#1239)
* Make auth callbacks async * Fix workflow tests * Update docs * Reorganize broadcast_invalidate_status_counts * Run DB calls in separate thread * Update documentation for async callback --------- Co-authored-by: Ben Elam <baelam@es.net> Co-authored-by: Mark90 <mark_moes@live.nl>
1 parent c5025e3 commit 545df4f

File tree

9 files changed

+78
-37
lines changed

9 files changed

+78
-37
lines changed

docs/reference-docs/auth-backend-and-frontend.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ app.register_graphql_authorization(graphql_authorization_instance)
275275
Role-based access control for workflows is currently in beta.
276276
Initial support has been added to the backend, but the feature is not fully communicated through the UI yet.
277277

278-
Certain `orchestrator-core` decorators accept authorization callbacks of type `type Authorizer = Callable[OIDCUserModel, bool]`, which return True when the input user is authorized, otherwise False.
278+
Certain `orchestrator-core` decorators accept authorization callbacks of type `type Authorizer = Callable[[OIDCUserModel | None], Awaitable[bool]]`, which return True when the input user is authorized, otherwise False.
279+
In other words, authorization callbacks are async, take a nullable OIDCUserModel (or subclass) as argument, and return a bool.
279280

280281
A table (below) is available for comparing possible configuration states with the policy that will be enforced.
281282

@@ -528,8 +529,11 @@ are prioritized in different workflow and inputstep configurations.
528529
Assume we have the following function that can be used to create callbacks:
529530

530531
```python
531-
def allow_roles(*roles) -> Callable[OIDCUserModel|None, bool]:
532-
def f(user: OIDCUserModel) -> bool:
532+
from oauth2_lib.fastapi import OIDCUserModel
533+
from orchestrator.workflows.utils import Authorizer
534+
535+
def allow_roles(*roles) -> Authorizer:
536+
async def f(user: OIDCUserModel) -> bool:
533537
if is_admin(user): # Relative to your authorization provider
534538
return True
535539
for role in roles:

orchestrator/api/api_v1/endpoints/processes.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
"""Module that implements process related API endpoints."""
1515

16+
import asyncio
1617
import struct
1718
import zlib
1819
from http import HTTPStatus
@@ -62,10 +63,12 @@
6263
from orchestrator.websocket import (
6364
WS_CHANNELS,
6465
broadcast_invalidate_status_counts,
66+
broadcast_invalidate_status_counts_async,
6567
broadcast_process_update_to_websocket,
6668
websocket_manager,
6769
)
6870
from orchestrator.workflow import ProcessStat, ProcessStatus, StepList, Workflow
71+
from orchestrator.workflows import get_workflow
6972
from pydantic_forms.types import JSON, State
7073

7174
router = APIRouter()
@@ -175,16 +178,29 @@ def delete(process_id: UUID) -> None:
175178
status_code=HTTPStatus.CREATED,
176179
dependencies=[Depends(check_global_lock, use_cache=False)],
177180
)
178-
def new_process(
181+
async def new_process(
179182
workflow_key: str,
180183
request: Request,
181184
json_data: list[dict[str, Any]] | None = Body(...),
182185
user: str = Depends(user_name),
183186
user_model: OIDCUserModel | None = Depends(authenticate),
184187
) -> dict[str, UUID]:
185188
broadcast_func = api_broadcast_process_data(request)
186-
process_id = start_process(
187-
workflow_key, user_inputs=json_data, user_model=user_model, user=user, broadcast_func=broadcast_func
189+
190+
workflow = get_workflow(workflow_key)
191+
if not workflow:
192+
raise_status(HTTPStatus.NOT_FOUND, "Workflow does not exist")
193+
194+
if not await workflow.authorize_callback(user_model):
195+
raise_status(HTTPStatus.FORBIDDEN, f"User is not authorized to execute '{workflow_key}' workflow")
196+
197+
process_id = await asyncio.to_thread(
198+
start_process,
199+
workflow_key,
200+
user_inputs=json_data,
201+
user_model=user_model,
202+
user=user,
203+
broadcast_func=broadcast_func,
188204
)
189205

190206
return {"id": process_id}
@@ -196,31 +212,31 @@ def new_process(
196212
status_code=HTTPStatus.NO_CONTENT,
197213
dependencies=[Depends(check_global_lock, use_cache=False)],
198214
)
199-
def resume_process_endpoint(
215+
async def resume_process_endpoint(
200216
process_id: UUID,
201217
request: Request,
202218
json_data: JSON = Body(...),
203219
user: str = Depends(user_name),
204220
user_model: OIDCUserModel | None = Depends(authenticate),
205221
) -> None:
206-
process = _get_process(process_id)
222+
process = await asyncio.to_thread(_get_process, process_id)
207223

208224
if not can_be_resumed(process.last_status):
209225
raise_status(HTTPStatus.CONFLICT, f"Resuming a {process.last_status.lower()} workflow is not possible")
210226

211227
pstat = load_process(process)
212228
auth_resume, auth_retry = get_auth_callbacks(get_steps_to_evaluate_for_rbac(pstat), pstat.workflow)
213229
if process.last_status == ProcessStatus.SUSPENDED:
214-
if auth_resume is not None and not auth_resume(user_model):
230+
if auth_resume is not None and not (await auth_resume(user_model)):
215231
raise_status(HTTPStatus.FORBIDDEN, "User is not authorized to resume step")
216232
elif process.last_status in (ProcessStatus.FAILED, ProcessStatus.WAITING):
217-
if auth_retry is not None and not auth_retry(user_model):
233+
if auth_retry is not None and not (await auth_retry(user_model)):
218234
raise_status(HTTPStatus.FORBIDDEN, "User is not authorized to retry step")
219235

220-
broadcast_invalidate_status_counts()
236+
await broadcast_invalidate_status_counts_async()
221237
broadcast_func = api_broadcast_process_data(request)
222238

223-
resume_process(process, user=user, user_inputs=json_data, broadcast_func=broadcast_func)
239+
await asyncio.to_thread(resume_process, process, user=user, user_inputs=json_data, broadcast_func=broadcast_func)
224240

225241

226242
@router.post(

orchestrator/graphql/schemas/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@ async def is_allowed(self, info: OrchestratorInfo) -> bool:
3838
workflow_table = get_original_model(self, WorkflowTable)
3939
workflow = get_workflow(workflow_table.name)
4040

41-
return workflow.authorize_callback(oidc_user) # type: ignore
41+
return await workflow.authorize_callback(oidc_user) # type: ignore

orchestrator/services/processes.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -421,10 +421,6 @@ def run() -> WFProcess:
421421
return process_id
422422

423423

424-
def error_message_unauthorized(workflow_key: str) -> str:
425-
return f"User is not authorized to execute '{workflow_key}' workflow"
426-
427-
428424
def create_process(
429425
workflow_key: str,
430426
user_inputs: list[State] | None = None,
@@ -442,9 +438,6 @@ def create_process(
442438
if not workflow:
443439
raise_status(HTTPStatus.NOT_FOUND, "Workflow does not exist")
444440

445-
if not workflow.authorize_callback(user_model):
446-
raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
447-
448441
initial_state = {
449442
"process_id": process_id,
450443
"reporter": user,

orchestrator/utils/auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Callable
1+
from collections.abc import Awaitable, Callable
22
from typing import TypeAlias, TypeVar
33

44
from oauth2_lib.fastapi import OIDCUserModel
@@ -7,4 +7,4 @@
77

88
# Can instead use "type Authorizer = ..." in later Python versions.
99
T = TypeVar("T", bound=OIDCUserModel)
10-
Authorizer: TypeAlias = Callable[[T | None], bool]
10+
Authorizer: TypeAlias = Callable[[T | None], Awaitable[bool]]

orchestrator/websocket/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,19 @@ async def invalidate_subscription_cache(subscription_id: UUID | UUIDstr, invalid
105105
await broadcast_invalidate_cache({"type": "subscriptions", "id": str(subscription_id)})
106106

107107

108+
async def broadcast_invalidate_status_counts_async() -> None:
109+
"""Broadcast message to invalidate the status counts of the connected websocket clients.
110+
111+
This breaks the pattern of `sync_` prefixes to maintain backwards compatibility of
112+
broadcast_invalidate_status_counts, a sync function.
113+
"""
114+
if not websocket_manager.enabled:
115+
logger.debug("WebSocketManager is not enabled. Skip broadcasting through websocket.")
116+
return
117+
118+
await broadcast_invalidate_cache({"type": "processStatusCounts"})
119+
120+
108121
def broadcast_invalidate_status_counts() -> None:
109122
"""Broadcast message to invalidate the status counts of the connected websocket clients."""
110123
if not websocket_manager.enabled:
@@ -148,4 +161,5 @@ async def broadcast_process_update_to_websocket_async(
148161
"broadcast_process_update_to_websocket_async",
149162
"WS_CHANNELS",
150163
"broadcast_invalidate_status_counts",
164+
"broadcast_invalidate_status_counts_async",
151165
]

orchestrator/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def form_generator(state: State) -> FormGenerator:
193193
return form_generator
194194

195195

196-
def allow(_: OIDCUserModel | None) -> bool:
196+
async def allow(_: OIDCUserModel | None) -> bool:
197197
"""Default function to return True in absence of user-defined authorize function."""
198198
return True
199199

test/unit_tests/api/test_processes.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -544,12 +544,26 @@ def test_resume_all_processes_value_error(test_client, mocked_processes_resumeal
544544
)
545545
def test_create_process_reporter(test_client, fastapi_app, oidc_user, reporter, expected_user):
546546
# given
547+
async def allow(_: object) -> bool:
548+
return True
549+
550+
fake_workflow = make_workflow(
551+
f=lambda _: None,
552+
description="fake",
553+
initial_input_form=None,
554+
target=Target.CREATE,
555+
steps=StepList([]),
556+
authorize_callback=allow,
557+
retry_auth_callback=allow,
558+
)
547559
url_params = {"reporter": reporter} if reporter is not None else {}
548560
fastapi_depends = {authenticate: lambda: oidc_user}
549561
with (
562+
mock.patch("orchestrator.api.api_v1.endpoints.processes.get_workflow") as mock_get_workflow,
550563
mock.patch("orchestrator.api.api_v1.endpoints.processes.start_process") as mock_start_process,
551564
mock.patch.dict(fastapi_app.dependency_overrides, fastapi_depends),
552565
):
566+
mock_get_workflow.return_value = fake_workflow
553567
mock_start_process.return_value = uuid.uuid4()
554568
# when
555569
response = test_client.post("/api/processes/fake_workflow", json=[], params=url_params)
@@ -609,7 +623,7 @@ def test_new_process_higher_version_invalid(test_client, generic_subscription_1)
609623

610624

611625
def test_unauthorized_to_run_process(test_client):
612-
def disallow(_: OIDCUserModel | None = None) -> bool:
626+
async def disallow(_: OIDCUserModel | None = None) -> bool:
613627
return False
614628

615629
@workflow("unauthorized_workflow", target=Target.CREATE, authorize_callback=disallow)
@@ -623,10 +637,10 @@ def unauthorized_workflow():
623637

624638
@pytest.fixture
625639
def authorize_resume_workflow():
626-
def disallow(_: OIDCUserModel | None = None) -> bool:
640+
async def disallow(_: OIDCUserModel | None = None) -> bool:
627641
return False
628642

629-
def allow(_: OIDCUserModel | None = None) -> bool:
643+
async def allow(_: OIDCUserModel | None = None) -> bool:
630644
return True
631645

632646
class ConfirmForm(FormPage):
@@ -719,19 +733,19 @@ def test_unauthorized_resume_input_step(test_client, process_on_unauthorized_res
719733
assert HTTPStatus.FORBIDDEN == response.status_code
720734

721735

722-
def _A(_: OIDCUserModel) -> bool:
736+
async def _A(_: OIDCUserModel) -> bool:
723737
return True
724738

725739

726-
def _B(_: OIDCUserModel) -> bool:
740+
async def _B(_: OIDCUserModel) -> bool:
727741
return True
728742

729743

730-
def _C(_: OIDCUserModel) -> bool:
744+
async def _C(_: OIDCUserModel) -> bool:
731745
return True
732746

733747

734-
def _D(_: OIDCUserModel) -> bool:
748+
async def _D(_: OIDCUserModel) -> bool:
735749
return True
736750

737751

@@ -844,10 +858,10 @@ def test_continue_awaiting_process_endpoint_wrong_process_status(test_client, pr
844858

845859
@pytest.fixture
846860
def authorize_step_group_retry_workflow():
847-
def disallow(_: OIDCUserModel | None = None) -> bool:
861+
async def disallow(_: OIDCUserModel | None = None) -> bool:
848862
return False
849863

850-
def allow(_: OIDCUserModel | None = None) -> bool:
864+
async def allow(_: OIDCUserModel | None = None) -> bool:
851865
return True
852866

853867
steps = StepList([])
@@ -919,10 +933,10 @@ def test_unauthorized_step_group_retry(test_client, process_on_unretriable_step_
919933

920934
@pytest.fixture
921935
def authorize_step_retry_workflow():
922-
def disallow(_: OIDCUserModel | None = None) -> bool:
936+
async def disallow(_: OIDCUserModel | None = None) -> bool:
923937
return False
924938

925-
def allow(_: OIDCUserModel | None = None) -> bool:
939+
async def allow(_: OIDCUserModel | None = None) -> bool:
926940
return True
927941

928942
@step("authorized_retry", retry_auth_callback=allow)
@@ -998,10 +1012,10 @@ def test_unauthorized_step_retry(test_client, process_on_unretriable_step):
9981012

9991013
@pytest.fixture
10001014
def authorize_retrystep_retry_workflow():
1001-
def disallow(_: OIDCUserModel | None = None) -> bool:
1015+
async def disallow(_: OIDCUserModel | None = None) -> bool:
10021016
return False
10031017

1004-
def allow(_: OIDCUserModel | None = None) -> bool:
1018+
async def allow(_: OIDCUserModel | None = None) -> bool:
10051019
return True
10061020

10071021
@retrystep("authorized_retry", retry_auth_callback=allow)

test/unit_tests/graphql/test_workflows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def test_workflows_sort_by_resource_type_desc(test_client):
246246
def test_workflows_not_allowed(test_client):
247247
forbidden_workflow_name = "unauthorized_workflow"
248248

249-
def disallow(_: OIDCUserModel | None = None) -> bool:
249+
async def disallow(_: OIDCUserModel | None = None) -> bool:
250250
return False
251251

252252
@workflow(forbidden_workflow_name, target=Target.CREATE, authorize_callback=disallow)

0 commit comments

Comments
 (0)