diff --git a/medcat-service/env/app.env b/medcat-service/env/app.env index cb68c397c..9a7641611 100755 --- a/medcat-service/env/app.env +++ b/medcat-service/env/app.env @@ -36,6 +36,8 @@ SERVER_PORT=5000 SERVER_WORKERS=1 SERVER_WORKER_TIMEOUT=300 SERVER_THREADS=1 +SERVER_GUNICORN_MAX_REQUESTS=1000 +SERVER_GUNICORN_MAX_REQUESTS_JITTER=50 # set the number of torch threads, this should be used ONLY if you are using CPUs and the default image # set to -1 or 0 if you are using GPU @@ -43,4 +45,4 @@ APP_TORCH_THREADS=8 # GPU SETTING # CAUTION, use only if you are using the GPU docker image. -APP_CUDA_DEVICE_COUNT=1 +APP_CUDA_DEVICE_COUNT=-1 diff --git a/medcat-service/env/app_deid.env b/medcat-service/env/app_deid.env index 56607c72b..e59c7ad2f 100755 --- a/medcat-service/env/app_deid.env +++ b/medcat-service/env/app_deid.env @@ -36,6 +36,8 @@ SERVER_PORT=5000 SERVER_WORKERS=1 SERVER_WORKER_TIMEOUT=300 SERVER_THREADS=1 +SERVER_GUNICORN_MAX_REQUESTS=1000 +SERVER_GUNICORN_MAX_REQUESTS_JITTER=50 # set the number of torch threads, this should be used ONLY if you are using CPUs and the default image # set to -1 or 0 if you are using GPU diff --git a/medcat-service/medcat_service/demo/gradio_demo.py b/medcat-service/medcat_service/demo/gradio_demo.py index fbe7db96a..0ef5fe797 100644 --- a/medcat-service/medcat_service/demo/gradio_demo.py +++ b/medcat-service/medcat_service/demo/gradio_demo.py @@ -3,7 +3,7 @@ import gradio as gr from pydantic import BaseModel -from medcat_service.dependencies import get_medcat_processor, get_settings +from medcat_service.dependencies import get_global_processor from medcat_service.types import ProcessAPIInputContent from medcat_service.types_entities import Entity @@ -96,7 +96,7 @@ def convert_display_model_to_list_of_lists(entity_display_model: list[EntityAnno def process_input(input_text: str): - processor = get_medcat_processor(get_settings()) + processor = get_global_processor() input = 
ProcessAPIInputContent(text=input_text) result = processor.process_content(input.model_dump()) diff --git a/medcat-service/medcat_service/dependencies.py b/medcat-service/medcat_service/dependencies.py index f9cbab061..ce4ba5626 100644 --- a/medcat-service/medcat_service/dependencies.py +++ b/medcat-service/medcat_service/dependencies.py @@ -1,26 +1,55 @@ import logging -from functools import lru_cache -from typing import Annotated +from typing import Annotated, Optional -from fastapi import Depends +from fastapi import Depends, Request from medcat_service.config import Settings from medcat_service.nlp_processor.medcat_processor import MedCatProcessor log = logging.getLogger(__name__) +processor_singleton: Optional[MedCatProcessor] = None +settings_singleton: Optional[Settings] = None -@lru_cache -def get_settings() -> Settings: - settings = Settings() - log.debug("Using settings: %s", settings) - return settings +def get_settings(request: Request) -> Settings: + """Return the Settings held on request.app.state, failing loudly when it is missing.""" + _settings = getattr(request.app.state, "settings", None) + log.debug("Using settings: %s", _settings) + if _settings is None: + raise RuntimeError("Settings are not initialised on app.state") + return _settings -@lru_cache -def get_medcat_processor(settings: Annotated[Settings, Depends(get_settings)]) -> MedCatProcessor: - log.debug("Creating new Medcat Processsor using settings: %s", settings) - return MedCatProcessor(settings) +def set_global_settings(settings: Settings) -> None: + global settings_singleton + settings_singleton = settings + +def get_global_settings() -> Settings: + if settings_singleton is None: + raise RuntimeError("Settings have not been initialised yet") + return settings_singleton + + +def set_global_processor(proc: MedCatProcessor) -> None: + global processor_singleton + processor_singleton = proc + + +def get_medcat_processor(request: Request) -> MedCatProcessor: + proc = getattr(request.app.state, "medcat", None) + log.debug("Getting MedCatProcessor from app.state: %s", proc) + if proc is None: + raise RuntimeError("MedCatProcessor is not initialised on app.state") + return proc + + +def 
get_global_processor() -> MedCatProcessor: + if processor_singleton is None: + raise RuntimeError("MedCatProcessor has not been initialised yet") + return processor_singleton + + +SettingsDep = Annotated[Settings, Depends(get_settings)] MedCatProcessorDep = Annotated[MedCatProcessor, Depends(get_medcat_processor)] diff --git a/medcat-service/medcat_service/main.py b/medcat-service/medcat_service/main.py index 945d6255b..0f4d5f9c3 100644 --- a/medcat-service/medcat_service/main.py +++ b/medcat-service/medcat_service/main.py @@ -1,28 +1,58 @@ +import logging +from contextlib import asynccontextmanager + import gradio as gr from fastapi import FastAPI, Request from fastapi.responses import JSONResponse +from medcat_service.config import Settings from medcat_service.demo.gradio_demo import io -from medcat_service.dependencies import get_settings +from medcat_service.dependencies import set_global_processor, set_global_settings +from medcat_service.nlp_processor.medcat_processor import MedCatProcessor from medcat_service.routers import admin, health, process from medcat_service.types import HealthCheckFailedException -settings = get_settings() +@asynccontextmanager +async def lifespan(app: FastAPI): + """Populate app.state with Settings and a MedCatProcessor (reusing pre-set values) and publish them as globals.""" + log = logging.getLogger(__name__) + log.debug("Starting MedCAT Service lifespan setup") + + # allow overriding settings and medcat processor for testing + settings = getattr(app.state, "settings", None) + if settings is None: + settings = Settings() + app.state.settings = settings + + medcat = getattr(app.state, "medcat", None) + if medcat is None: + medcat = MedCatProcessor(settings) + app.state.medcat = medcat + + set_global_settings(settings) + set_global_processor(medcat) + log.debug("MedCAT Service lifespan setup complete") + + yield + + +# API metadata and root_path are FastAPI constructor arguments; values stored on app.state are ignored app = FastAPI( title="MedCAT Service", summary="MedCAT Service", contact={ "name": "CogStack Org", "url": "https://cogstack.org/", "email": "contact@cogstack.org", }, license_info={ "name": "Apache 2.0", "identifier": 
"Apache-2.0", }, - root_path=settings.app_root_path, + root_path=Settings().app_root_path, + lifespan=lifespan, ) app.include_router(admin.router) app.include_router(health.router) @@ -35,9 +65,9 @@ async def healthcheck_failed_exception_handler(request: Request, exc: HealthCheckFailedException): return JSONResponse(status_code=503, content=exc.reason.model_dump()) - if __name__ == "__main__": # Only run this when directly executing `python main.py` for local dev. import os + import uvicorn uvicorn.run("medcat_service.main:app", host="0.0.0.0", port=int(os.environ.get("SERVER_PORT", 8000))) diff --git a/medcat-service/medcat_service/test/common.py b/medcat-service/medcat_service/test/common.py index 4f1747868..2f3efd5a7 100644 --- a/medcat-service/medcat_service/test/common.py +++ b/medcat-service/medcat_service/test/common.py @@ -3,9 +3,20 @@ import logging import os +from medcat_service.config import Settings + log = logging.getLogger(__name__) +def get_settings_override_deid(): + return Settings( + deid_mode=True, + deid_redact=True, + APP_LOG_LEVEL=10, + MEDCAT_LOG_LEVEL=10 + ) # type: ignore + + def get_example_short_document(): """ Returns an example short document to be processed with possibly minimal set of annotations to be validated diff --git a/medcat-service/medcat_service/test/test_admin.py b/medcat-service/medcat_service/test/test_admin.py index b4b82b048..22cbebd8b 100644 --- a/medcat-service/medcat_service/test/test_admin.py +++ b/medcat-service/medcat_service/test/test_admin.py @@ -11,7 +11,12 @@ class TestAdminApi(unittest.TestCase): def setUp(self): setup_medcat_processor() - self.client = TestClient(app) + self._client_ctx = TestClient(app) + self.client = self._client_ctx.__enter__() + + def tearDown(self): + # exit context so shutdown runs + self._client_ctx.__exit__(None, None, None) def testGetInfo(self): response = self.client.get(self.ENDPOINT_INFO_ENDPOINT) diff --git 
a/medcat-service/medcat_service/test/test_deid.py b/medcat-service/medcat_service/test/test_deid.py index 24fe8063a..b24a6c340 100644 --- a/medcat-service/medcat_service/test/test_deid.py +++ b/medcat-service/medcat_service/test/test_deid.py @@ -4,13 +4,8 @@ from fastapi.testclient import TestClient import medcat_service.test.common as common -from medcat_service.config import Settings -from medcat_service.dependencies import get_settings from medcat_service.main import app - - -def get_settings_override(): - return Settings(deid_mode=True, deid_redact=True) +from medcat_service.nlp_processor.medcat_processor import MedCatProcessor class TestMedcatServiceDeId(unittest.TestCase): @@ -25,14 +20,30 @@ def setUpClass(cls): if "APP_MEDCAT_MODEL_PACK" not in os.environ: os.environ["APP_MEDCAT_MODEL_PACK"] = "./models/examples/example-deid-model-pack.zip" - app.dependency_overrides[get_settings] = get_settings_override - cls.client = TestClient(app) + test_settings = common.get_settings_override_deid() + app.state.settings = test_settings + app.state.medcat = MedCatProcessor(test_settings) + + cls._client_ctx = TestClient(app) + cls.client = cls._client_ctx.__enter__() + + @classmethod + def tearDownClass(cls): + # exit context so shutdown runs + cls._client_ctx.__exit__(None, None, None) + app.dependency_overrides.clear() + # reset the shared app state so the DeID settings/processor do not leak into other test classes + app.state.settings = None + app.state.medcat = None + + def test_settings_override_applied(self): + assert app.state.settings.deid_mode is True + assert app.state.settings.deid_redact is True def test_deid_process_api(self): payload = common.create_payload_content_from_doc_single( "John had been diagnosed with acute Kidney Failure the week before" ) - app.dependency_overrides[get_settings] = get_settings_override response = self.client.post(self.ENDPOINT_PROCESS_SINGLE, json=payload) self.assertEqual(response.status_code, 200) @@ -54,13 +65,11 @@ def test_deid_process_api(self): self.assertEqual(ann["pretty_name"], expected["pretty_name"]) self.assertEqual(ann["source_value"], expected["source_value"]) 
self.assertEqual(ann["cui"], expected["cui"]) - app.dependency_overrides = {} def test_deid_process_bulk_api(self): payload = common.create_payload_content_from_doc_bulk([ "John had been diagnosed with acute Kidney Failure the week before" ]) - app.dependency_overrides[get_settings] = get_settings_override response = self.client.post(self.ENDPOINT_PROCESS_BULK, json=payload) self.assertEqual(response.status_code, 200) @@ -87,4 +96,3 @@ def test_deid_process_bulk_api(self): # self.assertEqual(ann["pretty_name"], expected["pretty_name"]) # self.assertEqual(ann["source_value"], expected["source_value"]) # self.assertEqual(ann["cui"], expected["cui"]) - app.dependency_overrides = {} diff --git a/medcat-service/medcat_service/test/test_service.py b/medcat-service/medcat_service/test/test_service.py index 92362fe99..f3160f47c 100644 --- a/medcat-service/medcat_service/test/test_service.py +++ b/medcat-service/medcat_service/test/test_service.py @@ -6,7 +6,9 @@ from fastapi.testclient import TestClient import medcat_service.test.common as common +from medcat_service.config import Settings from medcat_service.main import app +from medcat_service.nlp_processor.medcat_processor import MedCatProcessor class TestMedcatService(unittest.TestCase): @@ -31,7 +33,11 @@ def setUpClass(cls): """ cls._setup_logging(cls) common.setup_medcat_processor() - cls.client = TestClient(app) + test_settings = Settings() + app.state.settings = test_settings + app.state.medcat = MedCatProcessor(test_settings) + cls._client_ctx = TestClient(app) + cls.client = cls._client_ctx.__enter__() @staticmethod def _setup_logging(cls): @@ -39,6 +45,11 @@ def _setup_logging(cls): logging.basicConfig(format=log_format, level=logging.INFO) cls.log = logging.getLogger(__name__) + @classmethod + def tearDownClass(cls): + # exit context so shutdown runs + cls._client_ctx.__exit__(None, None, None) + # unit test helper methods # def _testProcessSingleDoc(self, doc): diff --git 
a/medcat-service/start_service_production.sh b/medcat-service/start_service_production.sh index 5009c521f..0f5b1437a 100644 --- a/medcat-service/start_service_production.sh +++ b/medcat-service/start_service_production.sh @@ -33,6 +33,15 @@ if [ -z ${SERVER_WORKER_TIMEOUT+x} ]; then echo "SERVER_WORKER_TIMEOUT is unset -- setting to default (sec): $SERVER_WORKER_TIMEOUT"; fi +if [ -z ${SERVER_GUNICORN_MAX_REQUESTS+x} ]; then + SERVER_GUNICORN_MAX_REQUESTS=1000; + echo "SERVER_GUNICORN_MAX_REQUESTS is unset -- setting to default (requests): $SERVER_GUNICORN_MAX_REQUESTS"; +fi + +if [ -z ${SERVER_GUNICORN_MAX_REQUESTS_JITTER+x} ]; then + SERVER_GUNICORN_MAX_REQUESTS_JITTER=50; + echo "SERVER_GUNICORN_MAX_REQUESTS_JITTER is unset -- setting to default (requests): $SERVER_GUNICORN_MAX_REQUESTS_JITTER"; +fi SERVER_ACCESS_LOG_FORMAT="%(t)s [ACCESS] %(h)s \"%(r)s\" %(s)s \"%(f)s\" \"%(a)s\"" @@ -50,5 +59,7 @@ exec gunicorn \ --error-logfile=- \ --log-level info \ --config /cat/config.py \ + --max-requests="$SERVER_GUNICORN_MAX_REQUESTS" \ + --max-requests-jitter="$SERVER_GUNICORN_MAX_REQUESTS_JITTER" \ --worker-class uvicorn.workers.UvicornWorker \ medcat_service.main:app