Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions bluesky_httpserver/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datetime import datetime, timedelta
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security
from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security, WebSocket
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes
Expand Down Expand Up @@ -202,7 +202,6 @@ def get_current_principal(
# otherwise it is None. The original set of API key scopes is used for generating new
# API keys.
roles, scopes, api_key_scopes = {}, {}, None

if api_key is not None:
if authenticators:
# Tiled is in a multi-user configuration with authentication providers.
Expand Down Expand Up @@ -356,6 +355,40 @@ def get_current_principal(
return principal


def get_current_principal_websocket(
websocket: WebSocket,
scopes: str,
):
app = websocket.app
security_scopes = SecurityScopes(scopes=scopes or [])
settings = app.dependency_overrides[get_settings]()
authenticators = app.dependency_overrides[get_authenticators]()
api_access_manager = app.dependency_overrides[get_api_access_manager]()

auth_header = websocket.headers.get("Authorization", "")
access_token, api_key = None, None
if auth_header.startswith("Bearer "):
access_token = auth_header[len("Bearer") :].strip()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what it's worth, we chose not to support these in Tiled because there is no mechanism for the server to request that they be refreshed, since it cannot send HTTP response codes.

Instead, the client mints a short-lived API key and revokes it after the connection is formed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense. It should be possible to implement a refresh scheme when a token is validated by sending a plain HTTP request in case connection to a websocket fails and then refreshed if requested by the server, but it does not look like a standard approach.

if auth_header.startswith("ApiKey "):
api_key = auth_header[len("ApiKey") :].strip()

principal = None
try:
principal = get_current_principal(
request=websocket,
security_scopes=security_scopes,
access_token=access_token,
api_key=api_key,
settings=settings,
authenticators=authenticators,
api_access_manager=api_access_manager,
)
except HTTPException as ex:
print(f"WebSocket connection failed: {ex}")

return principal


def create_session(settings, identity_provider, id, scopes):
with get_sessionmaker(settings.database_settings)() as db:
# Have we seen this Identity before?
Expand Down
3 changes: 3 additions & 0 deletions bluesky_httpserver/authorization/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"write:plan:control",
"write:execute",
"write:history:edit",
"user:apikeys",
}

_DEFAULT_SCOPES_USER = {
Expand All @@ -91,6 +92,7 @@
"write:plan:control",
"write:execute",
"write:history:edit",
"user:apikeys",
}

_DEFAULT_SCOPES_OBSERVER = {
Expand All @@ -103,6 +105,7 @@
"read:console",
"read:lock",
"read:testing",
"user:apikeys",
}

# =============================================================================================
Expand Down
30 changes: 26 additions & 4 deletions bluesky_httpserver/routers/core_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
else:
from pydantic_settings import BaseSettings

from ..authentication import get_current_principal
from ..authentication import get_current_principal, get_current_principal_websocket
from ..console_output import ConsoleOutputEventStream, StreamingResponseFromClass
from ..resources import SERVER_RESOURCES as SR
from ..settings import get_settings
Expand Down Expand Up @@ -1139,7 +1139,12 @@ def is_alive(self):


@router.websocket("/console_output/ws")
async def console_output_ws(websocket: WebSocket):
async def console_output_ws(websocket: WebSocket, scopes=["read:console"]):
principal = get_current_principal_websocket(websocket=websocket, scopes=scopes)
if not principal:
await websocket.close(code=4001, reason="Invalid token")
return

await websocket.accept()
q = SR.console_output_stream.add_queue(websocket)
wsmon = WebSocketMonitor(websocket)
Expand All @@ -1151,33 +1156,48 @@ async def console_output_ws(websocket: WebSocket):
await websocket.send_text(msg)
except asyncio.TimeoutError:
pass
except RuntimeError: # 'send' after the client is disconnected
pass
except WebSocketDisconnect:
pass
finally:
SR.console_output_stream.remove_queue(websocket)


@router.websocket("/status/ws")
async def status_ws(websocket: WebSocket):
async def status_ws(websocket: WebSocket, scopes=["read:monitor"]):
principal = get_current_principal_websocket(websocket=websocket, scopes=scopes)
if not principal:
await websocket.close(code=4001, reason="Invalid token")
return

await websocket.accept()
q = SR.system_info_stream.add_queue_status(websocket)
wsmon = WebSocketMonitor(websocket)
wsmon.start()

try:
while wsmon.is_alive:
try:
msg = await asyncio.wait_for(q.get(), timeout=1)
await websocket.send_text(msg)
except asyncio.TimeoutError:
pass
except RuntimeError: # 'send' after the client is disconnected
pass
except WebSocketDisconnect:
pass
finally:
SR.system_info_stream.remove_queue_status(websocket)


@router.websocket("/info/ws")
async def info_ws(websocket: WebSocket):
async def info_ws(websocket: WebSocket, scopes=["read:monitor"]):
principal = get_current_principal_websocket(websocket=websocket, scopes=scopes)
if not principal:
await websocket.close(code=4001, reason="Invalid token")
return

await websocket.accept()
q = SR.system_info_stream.add_queue_info(websocket)
wsmon = WebSocketMonitor(websocket)
Expand All @@ -1189,6 +1209,8 @@ async def info_ws(websocket: WebSocket):
await websocket.send_text(msg)
except asyncio.TimeoutError:
pass
except RuntimeError: # 'send' after the client is disconnected
pass
except WebSocketDisconnect:
pass
finally:
Expand Down
175 changes: 175 additions & 0 deletions bluesky_httpserver/tests/test_auth_for_websockets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import json
import pprint
import threading
import time as ttime

import pytest
from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd # noqa F401
from websockets.sync.client import connect

from .conftest import fastapi_server_fs # noqa: F401
from .conftest import (
SERVER_ADDRESS,
SERVER_PORT,
request_to_json,
setup_server_with_config_file,
wait_for_environment_to_be_closed,
wait_for_environment_to_be_created,
)

config_toy_test = """
authentication:
allow_anonymous_access: True
providers:
- provider: toy
authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator
args:
users_to_passwords:
bob: bob_password
alice: alice_password
cara: cara_password
tom: tom_password
api_access:
policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl
args:
users:
bob:
roles:
- admin
- expert
alice:
roles: advanced
tom:
roles: user
"""


class _ReceiveSystemInfoSocket(threading.Thread):
"""
Catch streaming console output by connecting to /console_output/ws socket and
save messages to the buffer.
"""

def __init__(self, *, endpoint, api_key=None, token=None, **kwargs):
super().__init__(**kwargs)
self.received_data_buffer = []
self._exit = False
self._api_key = api_key
self._token = token
self._endpoint = endpoint

def run(self):
websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api{self._endpoint}"
if self._token is not None:
additional_headers = {"Authorization": f"Bearer {self._token}"}
elif self._api_key is not None:
additional_headers = {"Authorization": f"ApiKey {self._api_key}"}
else:
additional_headers = {}

try:
with connect(websocket_uri, additional_headers=additional_headers) as websocket:
while not self._exit:
try:
msg_json = websocket.recv(timeout=0.1, decode=False)
try:
msg = json.loads(msg_json)
self.received_data_buffer.append(msg)
except json.JSONDecodeError:
pass
except TimeoutError:
pass
except Exception as ex:
print(f"Failed to connect to server: {ex}")

def stop(self):
"""
Call this method to stop the thread. Then send a request to the server so that some output
is printed in ``stdout``.
"""
self._exit = True

def __del__(self):
self.stop()


# fmt: off
@pytest.mark.parametrize("ws_auth_type", ["apikey", "token", "apikey_invalid", "token_invalid", "none"])
# fmt: on
def test_websocket_auth_01(
tmpdir,
monkeypatch,
re_manager_cmd, # noqa: F811
fastapi_server_fs, # noqa: F811
ws_auth_type,
):
"""
Test authentication for websockets. The test is run only on ``/status/ws`` websocket.
The other websockets are expected to use the same authentication scheme.
"""

# Start RE Manager
params = ["--zmq-publish-console", "ON"]
re_manager_cmd(params)

setup_server_with_config_file(config_file_str=config_toy_test, tmpdir=tmpdir, monkeypatch=monkeypatch)
fastapi_server_fs()

resp1 = request_to_json("post", "/auth/provider/toy/token", login=("bob", "bob_password"))
assert "access_token" in pprint.pformat(resp1)
token = resp1["access_token"]

resp3 = request_to_json(
"post", "/auth/apikey", json={"expires_in": 900, "note": "API key for testing"}, token=token
)
assert "secret" in resp3, pprint.pformat(resp3)
assert "note" in resp3, pprint.pformat(resp3)
assert resp3["note"] == "API key for testing"
assert resp3["scopes"] == ["inherit"]
api_key = resp3["secret"]

endpoint = "/status/ws"
if ws_auth_type == "none":
ws_params = {}
elif ws_auth_type == "apikey":
ws_params = {"api_key": api_key}
elif ws_auth_type == "apikey_invalid":
ws_params = {"api_key": "InvalidApiKey"}
elif ws_auth_type == "token":
ws_params = {"token": token}
elif ws_auth_type == "token_invalid":
ws_params = {"token": "InvalidToken"}
else:
assert False, f"Unknown authentication type: {ws_auth_type!r}"

rsc = _ReceiveSystemInfoSocket(endpoint=endpoint, **ws_params)
rsc.start()
ttime.sleep(1) # Wait until the client connects to the socket

resp1 = request_to_json("post", "/environment/open", api_key=api_key)
assert resp1["success"] is True, pprint.pformat(resp1)

assert wait_for_environment_to_be_created(timeout=10, api_key=api_key)

resp2b = request_to_json("post", "/environment/close", api_key=api_key)
assert resp2b["success"] is True, pprint.pformat(resp2b)

assert wait_for_environment_to_be_closed(timeout=10, api_key=api_key)

# Wait until capture is complete
ttime.sleep(2)
rsc.stop()
rsc.join()

buffer = rsc.received_data_buffer
if ws_auth_type in ("none", "apikey_invalid", "token_invalid"):
assert len(buffer) == 0
elif ws_auth_type in ("apikey", "token"):
assert len(buffer) > 0
for msg in buffer:
assert "time" in msg, msg
assert isinstance(msg["time"], float), msg
assert "msg" in msg
assert isinstance(msg["msg"], dict)
else:
assert False, f"Unknown authentication type: {ws_auth_type!r}"
3 changes: 2 additions & 1 deletion bluesky_httpserver/tests/test_console_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ def __init__(self, api_key=API_KEY_FOR_TESTS, **kwargs):

def run(self):
websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api/console_output/ws"
with connect(websocket_uri) as websocket:
additional_headers = {"Authorization": f"ApiKey {self._api_key}"}
with connect(websocket_uri, additional_headers=additional_headers) as websocket:
while not self._exit:
try:
msg_json = websocket.recv(timeout=0.1, decode=False)
Expand Down
22 changes: 13 additions & 9 deletions bluesky_httpserver/tests/test_system_info_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,21 @@ def __init__(self, *, endpoint, api_key=API_KEY_FOR_TESTS, **kwargs):

def run(self):
websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api{self._endpoint}"
with connect(websocket_uri) as websocket:
while not self._exit:
try:
msg_json = websocket.recv(timeout=0.1, decode=False)
additional_headers = {"Authorization": f"ApiKey {self._api_key}"}
try:
with connect(websocket_uri, additional_headers=additional_headers) as websocket:
while not self._exit:
try:
msg = json.loads(msg_json)
self.received_data_buffer.append(msg)
except json.JSONDecodeError:
msg_json = websocket.recv(timeout=0.1, decode=False)
try:
msg = json.loads(msg_json)
self.received_data_buffer.append(msg)
except json.JSONDecodeError:
pass
except TimeoutError:
pass
except TimeoutError:
pass
except Exception as ex:
print(f"Failed to connect to server: {ex}")

def stop(self):
"""
Expand Down
Loading