diff --git a/src/isolate/server/server.py b/src/isolate/server/server.py index b462855..d954fbd 100644 --- a/src/isolate/server/server.py +++ b/src/isolate/server/server.py @@ -398,9 +398,10 @@ def SetMetadata( context: ServicerContext, ) -> definitions.SetMetadataResponse: if request.task_id not in self.background_tasks: - raise GRPCException( + self.abort_with_msg( f"Task {request.task_id} not found.", - StatusCode.NOT_FOUND, + context, + code=StatusCode.NOT_FOUND, ) self.set_metadata(self.background_tasks[request.task_id], request.metadata) @@ -423,7 +424,7 @@ def Run( self.background_tasks["RUN"] = task yield from self._run_task(task) except GRPCException as exc: - return self.abort_with_msg( + self.abort_with_msg( exc.message, context, code=exc.code, @@ -696,6 +697,17 @@ def __init__(self, controller_auth_key: str) -> None: ) def intercept_service(self, continuation, handler_call_details): + skipped_auth_methods = [ + # Already used in deployed apps without authentication, so open it up + # for now and then close it again after rolling new version for all users. + "/Isolate/SetMetadata", + ] + + if handler_call_details.method in skipped_auth_methods: + print(f"[debug] Skipping authentication for {handler_call_details.method}") + # Let these requests pass through without authentication + return continuation(handler_call_details) + metadata = dict(handler_call_details.invocation_metadata) controller_token = metadata.get("controller-token") if controller_token != self.controller_auth_key: diff --git a/tests/test_server.py b/tests/test_server.py index f9e9128..2ecb717 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -614,6 +614,16 @@ def test_controller_auth_rejects_without_token( stub.List(definitions.ListRequest()) assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED + with pytest.raises(grpc.RpcError) as exc_info: + stub.SetMetadata( + definitions.SetMetadataRequest( + task_id="task-id", + metadata=definitions.TaskMetadata(logger_labels={"test": "test"}), + ) + ) + # there is no task, so it should return NOT_FOUND + assert exc_info.value.code() == grpc.StatusCode.NOT_FOUND + @pytest.mark.parametrize( "interceptors",