Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion fia_api/routers/instrument.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import os
from typing import Annotated

from fastapi import APIRouter, Depends
from fastapi.security import HTTPAuthorizationCredentials
from sqlalchemy.orm import Session

from fia_api.core.auth.tokens import JWTAPIBearer, get_user_from_token
from fia_api.core.cache import cache_get_json, cache_set_json
from fia_api.core.exceptions import AuthError
from fia_api.core.services.instrument import get_latest_run_by_instrument_name, update_latest_run_for_instrument
from fia_api.core.session import get_db_session

InstrumentRouter = APIRouter(prefix="/instrument")
jwt_api_security = JWTAPIBearer()
INSTRUMENT_LATEST_RUN_CACHE_TTL_SECONDS = int(os.environ.get("INSTRUMENT_LATEST_RUN_CACHE_TTL_SECONDS", "15"))


def _latest_run_cache_key(instrument: str) -> str:
return f"fia_api:instrument:latest_run:{instrument.upper()}"


@InstrumentRouter.get("/{instrument}/latest-run", tags=["instrument"])
Expand All @@ -30,8 +37,19 @@ async def get_instrument_latest_run(
if user.role != "staff":
# If not staff this is not allowed
raise AuthError("User not authorised for this action")

if INSTRUMENT_LATEST_RUN_CACHE_TTL_SECONDS > 0:
cached = cache_get_json(_latest_run_cache_key(instrument))
if isinstance(cached, dict):
return cached

latest_run = get_latest_run_by_instrument_name(instrument.upper(), session)
return {"latest_run": latest_run}
payload = {"latest_run": latest_run}

if INSTRUMENT_LATEST_RUN_CACHE_TTL_SECONDS > 0:
cache_set_json(_latest_run_cache_key(instrument), payload, INSTRUMENT_LATEST_RUN_CACHE_TTL_SECONDS)

return payload


@InstrumentRouter.put("/{instrument}/latest-run", tags=["instrument"])
Expand All @@ -55,4 +73,5 @@ async def update_instrument_latest_run(
# If not staff this is not allowed
raise AuthError("User not authorised for this action")
update_latest_run_for_instrument(instrument.upper(), latest_run["latest_run"], session)
cache_set_json(_latest_run_cache_key(instrument), None, 1)
return {"latest_run": latest_run["latest_run"]}
23 changes: 22 additions & 1 deletion fia_api/routers/instrument_specs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Annotated, Any

from fastapi import APIRouter, Depends
Expand All @@ -6,12 +7,18 @@
from sqlalchemy.orm import Session

from fia_api.core.auth.tokens import JWTAPIBearer, get_user_from_token
from fia_api.core.cache import cache_get_json, cache_set_json
from fia_api.core.exceptions import AuthError
from fia_api.core.services.instrument import get_specification_by_instrument_name, update_specification_for_instrument
from fia_api.core.session import get_db_session

InstrumentSpecRouter = APIRouter()
jwt_api_security = JWTAPIBearer()
INSTRUMENT_SPEC_CACHE_TTL_SECONDS = int(os.environ.get("INSTRUMENT_SPEC_CACHE_TTL_SECONDS", "120"))


def _spec_cache_key(instrument_name: str) -> str:
return f"fia_api:instrument:spec:{instrument_name.upper()}"


@InstrumentSpecRouter.get(
Expand All @@ -32,7 +39,20 @@ async def get_instrument_specification(
if user.role != "staff":
# If not staff this is not allowed
raise AuthError("User not authorised for this action")
return get_specification_by_instrument_name(instrument_name.upper(), session)

if INSTRUMENT_SPEC_CACHE_TTL_SECONDS > 0:
cached = cache_get_json(_spec_cache_key(instrument_name))
if isinstance(cached, dict):
return cached.get("specification")

specification = get_specification_by_instrument_name(instrument_name.upper(), session)

if INSTRUMENT_SPEC_CACHE_TTL_SECONDS > 0:
cache_set_json(
_spec_cache_key(instrument_name), {"specification": specification}, INSTRUMENT_SPEC_CACHE_TTL_SECONDS
)

return specification


@InstrumentSpecRouter.put("/instrument/{instrument_name}/specification", tags=["instrument specifications"])
Expand All @@ -54,4 +74,5 @@ async def update_instrument_specification(
# If not staff this is not allowed
raise AuthError("User not authorised for this action")
update_specification_for_instrument(instrument_name.upper(), specification, session)
cache_set_json(_spec_cache_key(instrument_name), None, 1)
return specification
37 changes: 33 additions & 4 deletions fia_api/routers/live_data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Live Data Script router
import os
from typing import Annotated, Literal

from fastapi import APIRouter, Depends
from fastapi.security import HTTPAuthorizationCredentials
from sqlalchemy.orm import Session

from fia_api.core.auth.tokens import JWTAPIBearer, get_user_from_token
from fia_api.core.cache import cache_get, cache_set_json
from fia_api.core.cache import cache_get, cache_get_json, cache_set_json
from fia_api.core.exceptions import AuthError
from fia_api.core.request_models import LiveDataScriptUpdateRequest
from fia_api.core.services.instrument import (
Expand All @@ -18,6 +19,8 @@

LiveDataRouter = APIRouter(tags=["live-data"])
jwt_api_security = JWTAPIBearer()
LIVE_DATA_INSTRUMENTS_CACHE_TTL_SECONDS = int(os.environ.get("LIVE_DATA_INSTRUMENTS_CACHE_TTL_SECONDS", "120"))
LIVE_DATA_SCRIPT_CACHE_TTL_SECONDS = int(os.environ.get("LIVE_DATA_SCRIPT_CACHE_TTL_SECONDS", "60"))


@LiveDataRouter.get("/live-data/instruments")
Expand All @@ -28,7 +31,18 @@ async def get_live_data_instruments(session: Annotated[Session, Depends(get_db_s
:param session: The current session of the request
:return: List of instrument names with live data support enabled
"""
return get_instruments_with_live_data_support(session)
cache_key = "fia_api:live_data:instruments"
if LIVE_DATA_INSTRUMENTS_CACHE_TTL_SECONDS > 0:
cached = cache_get_json(cache_key)
if isinstance(cached, list):
return cached

instruments = get_instruments_with_live_data_support(session)

if LIVE_DATA_INSTRUMENTS_CACHE_TTL_SECONDS > 0:
cache_set_json(cache_key, instruments, LIVE_DATA_INSTRUMENTS_CACHE_TTL_SECONDS)

return instruments


def _get_traceback_key(instrument: str) -> str:
Expand All @@ -46,6 +60,10 @@ async def get_instrument_traceback(instrument: str) -> str | None:
return cache_get(_get_traceback_key(instrument.lower()))


def _get_script_cache_key(instrument: str) -> str:
return f"fia_api:live_data:script:{instrument.upper()}"


@LiveDataRouter.get("/live-data/{instrument}/script")
async def get_instrument_script(instrument: str, session: Annotated[Session, Depends(get_db_session)]) -> str | None:
"""
Expand All @@ -55,7 +73,17 @@ async def get_instrument_script(instrument: str, session: Annotated[Session, Dep
:param session: The current session of the request
:return: The live data script or None
"""
return get_live_data_script_by_instrument_name(instrument.upper(), session)
if LIVE_DATA_SCRIPT_CACHE_TTL_SECONDS > 0:
cached = cache_get_json(_get_script_cache_key(instrument))
if isinstance(cached, dict):
return cached.get("script")

script = get_live_data_script_by_instrument_name(instrument.upper(), session)

if LIVE_DATA_SCRIPT_CACHE_TTL_SECONDS > 0:
cache_set_json(_get_script_cache_key(instrument), {"script": script}, LIVE_DATA_SCRIPT_CACHE_TTL_SECONDS)

return script


@LiveDataRouter.put("/live-data/{instrument}/script")
Expand All @@ -79,6 +107,7 @@ async def update_instrument_script(
raise AuthError("Only Staff can update Live Data Scripts")

update_live_data_script_for_instrument(instrument.upper(), script_request.value, session)
# Clear traceback when script is updated
# Clear traceback and script cache when script is updated
cache_set_json(_get_traceback_key(instrument), None, 1)
cache_set_json(_get_script_cache_key(instrument), None, 1)
return "ok"
107 changes: 107 additions & 0 deletions test/e2e/test_endpoint_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Cache behavior tests for live-data and instrument endpoints."""

from http import HTTPStatus
from unittest.mock import patch

from starlette.testclient import TestClient

from fia_api.fia_api import app

from .constants import STAFF_HEADER

client = TestClient(app)


# --- GET /live-data/instruments ---


@patch("fia_api.routers.live_data.LIVE_DATA_INSTRUMENTS_CACHE_TTL_SECONDS", 120)
@patch("fia_api.routers.live_data.get_instruments_with_live_data_support")
@patch("fia_api.routers.live_data.cache_set_json")
@patch("fia_api.routers.live_data.cache_get_json")
def test_live_data_instruments_cache_hit(mock_cache_get, mock_cache_set, mock_get_instruments):
cached_payload = ["INSTRUMENT_1", "INSTRUMENT_2"]
mock_cache_get.return_value = cached_payload

response = client.get("/live-data/instruments")

assert response.status_code == HTTPStatus.OK
assert response.json() == cached_payload
mock_get_instruments.assert_not_called()
mock_cache_set.assert_not_called()


# --- GET /live-data/{instrument}/script ---


@patch("fia_api.routers.live_data.LIVE_DATA_SCRIPT_CACHE_TTL_SECONDS", 60)
@patch("fia_api.routers.live_data.get_live_data_script_by_instrument_name")
@patch("fia_api.routers.live_data.cache_set_json")
@patch("fia_api.routers.live_data.cache_get_json")
def test_live_data_script_cache_hit(mock_cache_get, mock_cache_set, mock_get_script):
mock_cache_get.return_value = {"script": "print('hello')"}

response = client.get("/live-data/TEST/script")

assert response.status_code == HTTPStatus.OK
assert response.json() == "print('hello')"
mock_get_script.assert_not_called()
mock_cache_set.assert_not_called()


@patch("fia_api.routers.live_data.LIVE_DATA_SCRIPT_CACHE_TTL_SECONDS", 60)
@patch("fia_api.routers.live_data.get_live_data_script_by_instrument_name")
@patch("fia_api.routers.live_data.cache_set_json")
@patch("fia_api.routers.live_data.cache_get_json")
def test_live_data_script_cache_hit_none_script(mock_cache_get, mock_cache_set, mock_get_script):
"""A cached None script (instrument has no script) should still be a cache hit."""
mock_cache_get.return_value = {"script": None}

response = client.get("/live-data/TEST/script")

assert response.status_code == HTTPStatus.OK
assert response.json() is None
mock_get_script.assert_not_called()
mock_cache_set.assert_not_called()


# --- GET /instrument/{instrument_name}/specification ---


@patch("fia_api.routers.instrument_specs.INSTRUMENT_SPEC_CACHE_TTL_SECONDS", 120)
@patch("fia_api.core.auth.tokens.requests.post")
@patch("fia_api.routers.instrument_specs.get_specification_by_instrument_name")
@patch("fia_api.routers.instrument_specs.cache_set_json")
@patch("fia_api.routers.instrument_specs.cache_get_json")
def test_instrument_spec_cache_hit(mock_cache_get, mock_cache_set, mock_get_spec, mock_post):
cached_spec = {"foo": "bar", "baz": 42}
mock_cache_get.return_value = {"specification": cached_spec}
mock_post.return_value.status_code = HTTPStatus.OK

response = client.get("/instrument/TEST/specification", headers=STAFF_HEADER)

assert response.status_code == HTTPStatus.OK
assert response.json() == cached_spec
mock_get_spec.assert_not_called()
mock_cache_set.assert_not_called()


# --- GET /instrument/{instrument}/latest-run ---


@patch("fia_api.routers.instrument.INSTRUMENT_LATEST_RUN_CACHE_TTL_SECONDS", 15)
@patch("fia_api.core.auth.tokens.requests.post")
@patch("fia_api.routers.instrument.get_latest_run_by_instrument_name")
@patch("fia_api.routers.instrument.cache_set_json")
@patch("fia_api.routers.instrument.cache_get_json")
def test_instrument_latest_run_cache_hit(mock_cache_get, mock_cache_set, mock_get_latest, mock_post):
cached_payload = {"latest_run": "12345"}
mock_cache_get.return_value = cached_payload
mock_post.return_value.status_code = HTTPStatus.OK

response = client.get("/instrument/TEST/latest-run", headers=STAFF_HEADER)

assert response.status_code == HTTPStatus.OK
assert response.json() == cached_payload
mock_get_latest.assert_not_called()
mock_cache_set.assert_not_called()
Loading