diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6550ec38..b194cdec 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,9 +8,11 @@ on: branches: [ "main" ] pull_request: branches: [ "main" ] + types: [opened, synchronize, reopened, ready_for_review] jobs: test: + if: github.event.pull_request.draft == false strategy: matrix: python-version: ["3.9", "3.10", "3.11"] diff --git a/.gitignore b/.gitignore index aaa34162..ae3778db 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,5 @@ scrap/ .vscode/ .ruff_cache/ .python-version +.cursor/ +scripts/ diff --git a/README.md b/README.md index 6ffc231b..9f9fed06 100644 --- a/README.md +++ b/README.md @@ -10,32 +10,115 @@ -Build simple, portable, and scalable AI and NLP applications in a healthcare context πŸ’« πŸ₯. +Connect your AI models to any healthcare system with a few lines of Python πŸ’« πŸ₯. + +Integrating AI with electronic health records (EHRs) is complex, manual, and time-consuming. Let's try to change that. -Integrating electronic health record systems (EHRs) data is complex, and so is designing reliable, reactive algorithms involving unstructured healthcare data. Let's try to change that. ```bash pip install healthchain ``` First time here? Check out our [Docs](https://dotimplement.github.io/HealthChain/) page! -Came here from NHS RPySOC 2024 ✨? -[CDS sandbox walkthrough](https://dotimplement.github.io/HealthChain/cookbook/cds_sandbox/) -[Slides](https://speakerdeck.com/jenniferjiangkells/building-healthcare-context-aware-applications-with-healthchain) ## Features -- [x] πŸ”₯ Build FHIR-native pipelines or use [pre-built ones](https://dotimplement.github.io/HealthChain/reference/pipeline/pipeline/#prebuilt) for your healthcare NLP and ML tasks -- [x] πŸ”Œ Connect pipelines to any EHR system with built-in [CDA and FHIR Connectors](https://dotimplement.github.io/HealthChain/reference/pipeline/connectors/connectors/) -- [x] πŸ”„ Convert between FHIR, CDA, and HL7v2 with the [InteropEngine](https://dotimplement.github.io/HealthChain/reference/interop/interop/) -- [x] πŸ§ͺ Test your pipelines in full healthcare-context aware [sandbox](https://dotimplement.github.io/HealthChain/reference/sandbox/sandbox/) environments -- [x] πŸ—ƒοΈ Generate [synthetic healthcare data](https://dotimplement.github.io/HealthChain/reference/utilities/data_generator/) for testing and development -- [x] πŸš€ Deploy sandbox servers locally with [FastAPI](https://fastapi.tiangolo.com/) +- [x] πŸ”Œ **Gateway**: Connect to multiple EHR systems with [unified API](https://dotimplement.github.io/HealthChain/reference/gateway/gateway/) supporting FHIR, CDS Hooks, and SOAP/CDA protocols +- [x] πŸ”₯ **Pipelines**: Build FHIR-native ML workflows or use [pre-built ones](https://dotimplement.github.io/HealthChain/reference/pipeline/pipeline/#prebuilt) for your healthcare NLP and AI tasks +- [x] πŸ”„ **InteropEngine**: Convert between FHIR, CDA, and HL7v2 with a [template-based engine](https://dotimplement.github.io/HealthChain/reference/interop/interop/) +- [x] πŸ”’ Type-safe healthcare data with full type hints and Pydantic validation for [FHIR resources](https://dotimplement.github.io/HealthChain/reference/utilities/fhir_helpers/) +- [x] ⚑ Event-driven architecture with real-time event handling and [audit trails](https://dotimplement.github.io/HealthChain/reference/gateway/events/) built-in +- [x] πŸš€ Deploy production-ready applications with [HealthChainAPI](https://dotimplement.github.io/HealthChain/reference/gateway/api/) and FastAPI integration +- [x] πŸ§ͺ Generate [synthetic healthcare data](https://dotimplement.github.io/HealthChain/reference/utilities/data_generator/) and [sandbox testing](https://dotimplement.github.io/HealthChain/reference/sandbox/sandbox/) utilities ## Why use HealthChain? -- **EHR integrations are manual and time-consuming** - HealthChain abstracts away complexities so you can focus on AI development, not EHR configurations. -- **It's difficult to track and evaluate multiple integration instances** - HealthChain provides a framework to test the real-world resilience of your whole system, not just your models. -- [**Most healthcare data is unstructured**](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6372467/) - HealthChain is optimized for real-time AI and NLP applications that deal with realistic healthcare data. -- **Built by health tech developers, for health tech developers** - HealthChain is tech stack agnostic, modular, and easily extensible. +- **EHR integrations are manual and time-consuming** - **HealthChainAPI** abstracts away complexities so you can focus on AI development, not learning FHIR APIs, CDS Hooks, and authentication schemes. +- **Healthcare data is fragmented and complex** - **InteropEngine** handles the conversion between FHIR, CDA, and HL7v2 so you don't have to become an expert in healthcare data standards. +- [**Most healthcare data is unstructured**](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6372467/) - HealthChain **Pipelines** are optimized for real-time AI and NLP applications that deal with realistic healthcare data. +- **Built by health tech developers, for health tech developers** - HealthChain is tech stack agnostic, modular, and easily extensible with built-in compliance and audit features. + +## HealthChainAPI + +The HealthChainAPI provides a secure, asynchronous integration layer that coordinates multiple healthcare systems in a single application. + +### Multi-Protocol Support + +Connect to multiple healthcare data sources and protocols: + +```python +from healthchain.gateway import ( + HealthChainAPI, FHIRGateway, + CDSHooksService, NoteReaderService +) + +# Create your healthcare application +app = HealthChainAPI( + title="My Healthcare AI App", + description="AI-powered patient care platform" +) + +# FHIR for patient data from multiple EHRs +fhir = FHIRGateway() +fhir.add_source("epic", "fhir://fhir.epic.com/r4?client_id=...") +fhir.add_source("medplum", "fhir://api.medplum.com/fhir/R4/?client_id=...") + +# CDS Hooks for real-time clinical decision support +cds = CDSHooksService() + +@cds.hook("patient-view", id="allergy-alerts") +def check_allergies(request): + # Your AI logic here + return {"cards": [...]} + +# SOAP for clinical document processing +notes = NoteReaderService() + +@notes.method("ProcessDocument") +def process_note(request): + # Your NLP pipeline here + return processed_document + +# Register everything +app.register_gateway(fhir) +app.register_service(cds) +app.register_service(notes) + +# Your API now handles: +# /fhir/* - Patient data, observations, etc. +# /cds/* - Real-time clinical alerts +# /soap/* - Clinical document processing +``` + +### FHIR Operations with AI Enhancement + +```python +from healthchain.gateway import FHIRGateway +from fhir.resources.patient import Patient + +gateway = FHIRGateway() +gateway.add_source("epic", "fhir://fhir.epic.com/r4?...") + +# Add AI transformations to FHIR data +@gateway.transform(Patient) +async def enhance_patient(id: str, source: str = None) -> Patient: + async with gateway.modify(Patient, id, source) as patient: + # Get lab results and process with AI + lab_results = await gateway.search( + Observation, + {"patient": id, "category": "laboratory"}, + source + ) + insights = nlp_pipeline.process(patient, lab_results) + + # Add AI summary to patient record + patient.extension = patient.extension or [] + patient.extension.append({ + "url": "http://healthchain.org/fhir/summary", + "valueString": insights.summary + }) + return patient + +# Automatically available at: GET /fhir/transform/Patient/123?source=epic +``` ## Pipeline Pipelines provide a flexible way to build and manage processing pipelines for NLP and ML tasks that can easily integrate with complex healthcare systems. @@ -139,116 +222,40 @@ cda_data = engine.from_fhir(fhir_resources, dest_format=FormatType.CDA) ## Sandbox -Sandboxes provide a staging environment for testing and validating your pipeline in a realistic healthcare context. - -### Clinical Decision Support (CDS) -[CDS Hooks](https://cds-hooks.org/) is an [HL7](https://cds-hooks.hl7.org) published specification for clinical decision support. - -**When is this used?** CDS hooks are triggered at certain events during a clinician's workflow in an electronic health record (EHR), e.g. when a patient record is opened, when an order is elected. - -**What information is sent**: the context of the event and [FHIR](https://hl7.org/fhir/) resources that are requested by your service, for example, the patient ID and information on the encounter and conditions they are being seen for. - -**What information is returned**: β€œcards” displaying text, actionable suggestions, or links to launch a [SMART](https://smarthealthit.org/) app from within the workflow. - +Test your AI applications in realistic healthcare contexts with [CDS Hooks](https://cds-hooks.org/) sandbox environments. ```python import healthchain as hc - -from healthchain.pipeline import SummarizationPipeline from healthchain.sandbox.use_cases import ClinicalDecisionSupport -from healthchain.models import Card, Prefetch, CDSRequest -from healthchain.data_generator import CdsDataGenerator -from typing import List @hc.sandbox class MyCDS(ClinicalDecisionSupport): - def __init__(self) -> None: - self.pipeline = SummarizationPipeline.from_model_id( - "facebook/bart-large-cnn", source="huggingface" - ) - self.data_generator = CdsDataGenerator() + def __init__(self): + self.pipeline = SummarizationPipeline.from_model_id("facebook/bart-large-cnn") - # Sets up an instance of a mock EHR client of the specified workflow @hc.ehr(workflow="encounter-discharge") - def ehr_database_client(self) -> Prefetch: + def ehr_database_client(self): return self.data_generator.generate_prefetch() - # Define your application logic here - @hc.api - def my_service(self, data: CDSRequest) -> CDSRequest: - result = self.pipeline(data) - return result -``` - -### Clinical Documentation - -The `ClinicalDocumentation` use case implements a real-time Clinical Documentation Improvement (CDI) service. It helps convert free-text medical documentation into coded information that can be used for billing, quality reporting, and clinical decision support. - -**When is this used?** Triggered when a clinician opts in to a CDI functionality (e.g. Epic NoteReader) and signs or pends a note after writing it. - -**What information is sent**: A [CDA (Clinical Document Architecture)](https://www.hl7.org.uk/standards/hl7-standards/cda-clinical-document-architecture/) document which contains continuity of care data and free-text data, e.g. a patient's problem list and the progress note that the clinician has entered in the EHR. - -```python -import healthchain as hc - -from healthchain.pipeline import MedicalCodingPipeline -from healthchain.sandbox.use_cases import ClinicalDocumentation -from healthchain.models import CdaRequest, CdaResponse -from fhir.resources.documentreference import DocumentReference - -@hc.sandbox -class NotereaderSandbox(ClinicalDocumentation): - def __init__(self): - self.pipeline = MedicalCodingPipeline.from_model_id( - "en_core_sci_md", source="spacy" - ) - - # Load an existing CDA file - @hc.ehr(workflow="sign-note-inpatient") - def load_data_in_client(self) -> DocumentReference: - with open("/path/to/cda/data.xml", "r") as file: - xml_string = file.read() - - cda_document_reference = create_document_reference( - data=xml_string, - content_type="text/xml", - description="Original CDA Document loaded from my sandbox", - ) - return cda_document_reference - - @hc.api - def my_service(self, data: CdaRequest) -> CdaResponse: - annotated_ccd = self.pipeline(data) - return annotated_ccd -``` -### Running a sandbox - -Ensure you run the following commands in your `mycds.py` file: - -```python cds = MyCDS() cds.start_sandbox() -``` -This will populate your EHR client with the data generation method you have defined, send requests to your server for processing, and save the data in the `./output` directory. -Then run: -```bash -healthchain run mycds.py +# Run with: healthchain run mycds.py ``` -By default, the server runs at `http://127.0.0.1:8000`, and you can interact with the exposed endpoints at `/docs`. ## Road Map -- [x] πŸ”„ Transform and validate healthcare HL7v2, CDA to FHIR with template-based interop engine -- [ ] πŸ₯ Runtime connection health and EHR integration management - connect to FHIR APIs and legacy systems +- [ ] πŸ”’ Built-in HIPAA compliance validation and PHI detection - [ ] πŸ“Š Track configurations, data provenance, and monitor model performance with MLFlow integration - [ ] πŸš€ Compliance monitoring, auditing at deployment as a sidecar service -- [ ] πŸ”’ Built-in HIPAA compliance validation and PHI detection -- [ ] 🧠 Multi-modal pipelines that that have built-in NLP to utilize unstructured data +- [ ] πŸ”„ HL7v2 parsing and FHIR profile conversion support +- [ ] 🧠 Multi-modal pipelines + ## Contribute We are always eager to hear feedback and suggestions, especially if you are a developer or researcher working with healthcare systems! - πŸ’‘ Let's chat! [Discord](https://discord.gg/UQC6uAepUz) - πŸ› οΈ [Contribution Guidelines](CONTRIBUTING.md) -## Acknowledgement -This repository makes use of [fhir.resources](https://github.com/nazrulworld/fhir.resources), and [CDS Hooks](https://cds-hooks.org/) developed by [HL7](https://www.hl7.org/) and [Boston Children’s Hospital](https://www.childrenshospital.org/). + +## Acknowledgements πŸ€— +This project builds on [fhir.resources](https://github.com/nazrulworld/fhir.resources) and [CDS Hooks](https://cds-hooks.org/) standards developed by [HL7](https://www.hl7.org/) and [Boston Children's Hospital](https://www.childrenshospital.org/). diff --git a/docs/assets/images/openapi_docs.png b/docs/assets/images/openapi_docs.png new file mode 100644 index 00000000..76afa7cd Binary files /dev/null and b/docs/assets/images/openapi_docs.png differ diff --git a/docs/index.md b/docs/index.md index 6d97fba0..53d9e534 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,6 @@ -# Welcome to HealthChain +# Welcome to HealthChain πŸ’« πŸ₯ -HealthChain πŸ’«πŸ₯ is an open-source Python framework designed to streamline the development, testing, and validation of AI, Natural Language Processing, and Machine Learning applications in a healthcare context. +HealthChain is an open-source Python framework for building real-time AI applications in a healthcare context. [ :fontawesome-brands-discord: Join our Discord](https://discord.gg/UQC6uAepUz){ .md-button .md-button--primary }      @@ -19,19 +19,19 @@ HealthChain πŸ’«πŸ₯ is an open-source Python framework designed to streamline t [:octicons-arrow-right-24: Pipeline](reference/pipeline/pipeline.md) -- :octicons-beaker-24:{ .lg .middle } __Test in a sandbox__ +- :material-connection:{ .lg .middle } __Connect to multiple data sources__ --- - Test your models in a full health-context aware environment from day 1 + Connect to multiple healthcare data sources and protocols with **HealthChainAPI**. - [:octicons-arrow-right-24: Sandbox](reference/sandbox/sandbox.md) + [:octicons-arrow-right-24: Gateway](reference/gateway/gateway.md) - :material-database:{ .lg .middle } __Interoperability__ --- - Configuration-driven InteropEngine to convert between FHIR, CDA, and HL7v2 + Configuration-driven **InteropEngine** to convert between FHIR, CDA, and HL7v2 [:octicons-arrow-right-24: Interoperability](reference/interop/interop.md) @@ -49,16 +49,17 @@ HealthChain πŸ’«πŸ₯ is an open-source Python framework designed to streamline t ## Why HealthChain? -You've probably heard every *AI will revolutionize healthcare* pitch by now, but if you're one of the people who think: wait, can we go beyond just vibe-checking and *actually* build products that are reliable, reactive, and easy to scale in complex healthcare systems? Then HealthChain is probably for you. +Healthcare AI development has a **missing middleware layer**. Traditional enterprise integration engines move data around, EHR platforms serve end users, but there's nothing in between for developers building AI applications that need to talk to multiple healthcare systems. Few solutions are open-source, and even fewer are built in modern Python where most ML/AI libraries thrive. -Specifically, HealthChain addresses two challenges: +HealthChain fills that gap with: -1. **Scaling Electronic Health Record system (EHRs) integrations of real-time AI, NLP, and ML applications is a manual and time-consuming process.** +- **πŸ”₯ FHIR-native ML pipelines** - Pre-built NLP/ML pipelines optimized for structured / unstructured healthcare data, or build your own with familiar Python libraries such as πŸ€— Hugging Face, πŸ€– LangChain, and πŸ“š spaCy +- **πŸ”’ Type-safe healthcare data** - Full type hints and Pydantic validation for FHIR resources with automatic data validation and error handling +- **πŸ”Œ Multi-protocol connectivity** - Handle FHIR, CDS Hooks, and SOAP/CDA in the same codebase with OAuth2 authentication and connection pooling +- **⚑ Event-driven architecture** - Real-time event handling with audit trails and workflow automation built-in +- **πŸ”„ Built-in interoperability** - Convert between FHIR, CDA, and HL7v2 using a template-based engine +- **πŸš€ Production-ready deployment** - FastAPI integration for scalable, real-time applications -2. **Testing and evaluating unstructured data in complex, outcome focused systems is a challenging and labour-intensive task.** - -We believe more efficient end-to-end pipeline and integration testing at an early stage in development will give you back time to focus on what actually matters: developing safer, more effective and more explainable models that scale to real-world *adoption*. Building products for healthcare in a process that is *human*-centric. - -HealthChain is made by a (very) small team with experience in software engineering, machine learning, and healthcare NLP. We understand that good data science is about more than just building models, and that good engineering is about more than just building systems. This rings especially true in healthcare, where people, processes, and technology all play a role in making an impact. +HealthChain is made by a small team with experience in software engineering, machine learning, and healthcare NLP. We understand that good data science is about more than just building models, and that good engineering is about more than just building systems. This rings especially true in healthcare, where people, processes, and technology all play a role in making an impact. For inquiries and collaborations, please get [in touch](mailto:jenniferjiangkells@gmail.com)! diff --git a/docs/quickstart.md b/docs/quickstart.md index 816e621e..660ca885 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -2,8 +2,41 @@ After [installing HealthChain](installation.md), get up to speed quickly with the core components before diving further into the [full documentation](reference/index.md)! +HealthChain provides three core tools for healthcare AI integration: **Gateway** for connecting to multiple healthcare systems, **Pipelines** for FHIR-native AI workflows, and **InteropEngine** for healthcare data format conversion between FHIR, CDA, and HL7v2. + ## Core Components +### HealthChainAPI Gateway πŸ”Œ + +The HealthChainAPI provides a unified interface for connecting your AI models to multiple healthcare systems through a single API. Handle FHIR, CDS Hooks, and SOAP/CDA protocols with OAuth2 authentication and connection pooling. + +[(Full Documentation on Gateway)](./reference/gateway/gateway.md) + +```python +from healthchain.gateway import HealthChainAPI, FHIRGateway + +# Create your healthcare application +app = HealthChainAPI(title="My Healthcare AI App") + +# Connect to multiple FHIR servers +fhir = FHIRGateway() +fhir.add_source("epic", "fhir://fhir.epic.com/r4?client_id=...") +fhir.add_source("medplum", "fhir://api.medplum.com/fhir/R4/?client_id=...") + +# Add AI transformations to FHIR data +@fhir.transform(Patient) +async def enhance_patient(id: str, source: str = None) -> Patient: + async with fhir.modify(Patient, id, source) as patient: + # Your AI logic here + patient.active = True + return patient + +# Register and run +app.register_gateway(fhir) + +# Available at: GET /fhir/transform/Patient/123?source=epic +``` + ### Pipeline πŸ› οΈ HealthChain Pipelines provide a flexible way to build and manage processing pipelines for NLP and ML tasks that can easily integrate with electronic health record (EHR) systems. @@ -149,72 +182,30 @@ The interop module provides a flexible, template-based approach to healthcare fo For more details, see the [conversion examples](cookbook/interop/basic_conversion.md). -### Sandbox πŸ§ͺ -Once you've built your pipeline, you might want to experiment with how it interacts with different healthcare systems. A sandbox helps you stage and test the end-to-end workflow of your pipeline application where real-time EHR integrations are involved. - -Running a sandbox will start a [FastAPI](https://fastapi.tiangolo.com/) server with pre-defined standardized endpoints and create a sandboxed environment for you to interact with your application. +## Utilities βš™οΈ -To create a sandbox, initialize a class that inherits from a type of [UseCase](./reference/sandbox/use_cases/use_cases.md) and decorate it with the `@hc.sandbox` decorator. +### Sandbox Testing -Every sandbox also requires a **client** function marked by `@hc.ehr` and a **service** function marked by `@hc.api`. A **workflow** must be specified when creating an EHR client. +Test your AI applications in realistic healthcare contexts with sandbox environments for CDS Hooks and clinical documentation workflows. -[(Full Documentation on Sandbox and Use Cases)](./reference/sandbox/sandbox.md) +[(Full Documentation on Sandbox)](./reference/sandbox/sandbox.md) ```python import healthchain as hc - -from healthchain.sandbox.use_cases import ClinicalDocumentation -from healthchain.pipeline import MedicalCodingPipeline -from healthchain.models import CdaRequest, CdaResponse -from healthchain.fhir import create_document_reference - -from fhir.resources.documentreference import DocumentReference +from healthchain.sandbox.use_cases import ClinicalDecisionSupport @hc.sandbox -class MyCoolSandbox(ClinicalDocumentation): - def __init__(self) -> None: - # Load your pipeline - self.pipeline = MedicalCodingPipeline.from_local_model( - "./path/to/model", source="spacy" - ) - - @hc.ehr(workflow="sign-note-inpatient") - def load_data_in_client(self) -> DocumentReference: - # Load your data - with open('/path/to/data.xml', "r") as file: - xml_string = file.read() - - cda_document_reference = create_document_reference( - data=xml_string, - content_type="text/xml", - description="Original CDA Document loaded from my sandbox", - ) - - return cda_document_reference - - @hc.api - def my_service(self, request: CdaRequest) -> CdaResponse: - # Run your pipeline - results = self.pipeline(request) - return results - -if __name__ == "__main__": - clindoc = MyCoolSandbox() - clindoc.start_sandbox() -``` - -#### Deploy sandbox locally with FastAPI πŸš€ +class MyCDS(ClinicalDecisionSupport): + def __init__(self): + self.pipeline = SummarizationPipeline.from_model_id("facebook/bart-large-cnn") -To run your sandbox: + @hc.ehr(workflow="encounter-discharge") + def ehr_database_client(self): + return self.data_generator.generate_prefetch() -```bash -healthchain run my_sandbox.py +# Run with: healthchain run mycds.py ``` -This will start a server by default at `http://127.0.0.1:8000`, and you can interact with the exposed endpoints at `/docs`. Data generated from your sandbox runs is saved at `./output/` by default. - -## Utilities βš™οΈ - ### FHIR Helpers The `fhir` module provides a set of helper functions for working with FHIR resources. diff --git a/docs/reference/gateway/api.md b/docs/reference/gateway/api.md new file mode 100644 index 00000000..29a4b3eb --- /dev/null +++ b/docs/reference/gateway/api.md @@ -0,0 +1,181 @@ +# HealthChainAPI πŸ₯ + +The `HealthChainAPI` is your main application that coordinates all the different gateways and services. + +It's a [FastAPI](https://fastapi.tiangolo.com/) app under the hood, so you get all the benefits of FastAPI (automatic docs, type safety, performance) plus healthcare-specific features that makes it easier to work with healthcare data sources, such as FHIR APIs, CDS Hooks, and SOAP/CDA services. + + +## Basic Usage + + +```python +from healthchain.gateway import HealthChainAPI, FHIRGateway +import uvicorn + +# Create your app +app = HealthChainAPI( + title="My Healthcare App", + description="AI-powered patient care", +) + +# Add a FHIR gateway +fhir = FHIRGateway() +app.register_gateway(fhir) + +# Run it (docs automatically available at /docs) +if __name__ == "__main__": + uvicorn.run(app) +``` + +You can also register multiple services of different protocols: + +```python +from healthchain.gateway import ( + HealthChainAPI, FHIRGateway, + CDSHooksService, NoteReaderService +) + +app = HealthChainAPI() + +# Register everything you need +app.register_gateway(FHIRGateway(), path="/fhir") +app.register_service(CDSHooksService(), path="/cds") +app.register_service(NoteReaderService(), path="/soap") + +# Your API now handles: +# /fhir/* - Patient data, observations, etc. +# /cds/* - Real-time clinical alerts +# /soap/* - Clinical document processing +``` + +## Default Endpoints + +![open_api_docs](../../assets/images/openapi_docs.png) + +The HealthChainAPI automatically provides several default endpoints: + +### Root Endpoint: `GET /` + +Returns basic API information and registered components. + +```json +{ + "name": "HealthChain API", + "version": "1.0.0", + "description": "Healthcare Integration Platform", + "gateways": ["FHIRGateway"], + "services": ["CDSHooksService", "NoteReaderService"] +} +``` + +### Health Check: `GET /health` + +Simple health check endpoint for monitoring. + +```json +{ + "status": "healthy" +} +``` + +### Gateway Status: `GET /gateway/status` + +Comprehensive status of all registered gateways and services. + +```json +{ + "gateways": { + "FHIRGateway": { + "status": "active", + "sources": ["epic", "cerner"], + "connection_pool": {...} + } + }, + "services": { + "CDSHooksService": { + "status": "active", + "hooks": ["patient-view", "order-select"] + } + }, + "events": { + "enabled": true, + "dispatcher": "LocalEventDispatcher" + } +} +``` + + +## Event Integration + +The HealthChainAPI coordinates events across all registered components. This is useful for auditing, workflow automation, and other use cases. For more information, see the **[Events](events.md)** page. + + +```python +from healthchain.gateway.events.dispatcher import local_handler + +app = HealthChainAPI() + +# Register global event handler +@local_handler.register(event_name="fhir.patient.read") +async def log_patient_access(event): + event_name, payload = event + print(f"Patient accessed: {payload['resource_id']}") + +# Register handler for all events from specific component +@local_handler.register(event_name="cdshooks.*") +async def log_cds_events(event): + event_name, payload = event + print(f"CDS Hook fired: {event_name}") +``` + +## Dependencies and Injection + +The HealthChainAPI provides dependency injection for accessing registered components. + +### Gateway Dependencies + +```python +from healthchain.gateway.api.dependencies import get_gateway +from fastapi import Depends + +@app.get("/custom/patient/{id}") +async def get_enhanced_patient( + id: str, + fhir: FHIRGateway = Depends(get_gateway("FHIRGateway")) +): + """Custom endpoint using FHIR gateway dependency.""" + patient = await fhir.read(Patient, id) + return patient + +# Or get all gateways +from healthchain.gateway.api.dependencies import get_all_gateways + +@app.get("/admin/gateways") +async def list_gateways( + gateways: Dict[str, Any] = Depends(get_all_gateways) +): + return {"gateways": list(gateways.keys())} +``` + +### Application Dependencies + +```python +from healthchain.gateway.api.dependencies import get_app + +@app.get("/admin/status") +async def admin_status( + app_instance: HealthChainAPI = Depends(get_app) +): + return { + "gateways": len(app_instance.gateways), + "services": len(app_instance.services), + "events_enabled": app_instance.enable_events + } +``` + + +## See Also + +- **[FHIR Gateway](fhir_gateway.md)**: Complete FHIR operations reference +- **[CDS Hooks Service](cdshooks.md)**: Complete CDS Hooks service reference +- **[NoteReader Service](soap_cda.md)**: Complete NoteReader service reference diff --git a/docs/reference/gateway/cdshooks.md b/docs/reference/gateway/cdshooks.md new file mode 100644 index 00000000..6564a4a4 --- /dev/null +++ b/docs/reference/gateway/cdshooks.md @@ -0,0 +1,105 @@ +# CDS Hooks Protocol + +CDS Hooks is an [HL7](https://cds-hooks.hl7.org) published specification for clinical decision support that enables external services to provide real-time recommendations during clinical workflows. + +## Overview + +CDS hooks are triggered at specific events during a clinician's workflow in an electronic health record (EHR), such as when a patient record is opened or when an order is selected. The hooks communicate using [FHIR (Fast Healthcare Interoperability Resources)](https://hl7.org/fhir/). + +CDS Hooks are unique in that they are *real-time* webhooks that are triggered by the EHR, not by external services. This makes them ideal for real-time clinical decision support and alerts, but also trickier to test and debug for a developer. They are also a relatively new standard, so not all EHRs support them yet. + +| When | Where | What you receive | What you send back | Common Use Cases | +| :-------- | :-----| :-------------------------- |----------------------------|-----------------| +| Triggered at certain events during a clinician's workflow | EHR | The context of the event and FHIR resources that are requested by your service | "Cards" displaying text, actionable suggestions, or links to launch a [SMART](https://smarthealthit.org/) app | Allergy alerts, medication reconciliation, clinical decision support | + +## HealthChainAPI Integration + +Use the `CDSHooksService` with HealthChainAPI to handle CDS Hooks workflows: + +```python +from healthchain.gateway import HealthChainAPI, CDSHooksService +from healthchain.models import CDSRequest, CDSResponse + +app = HealthChainAPI() +cds = CDSHooksService() + +@cds.hook("patient-view", id="allergy-alerts") +def check_allergies(request: CDSRequest) -> CDSResponse: + # Your AI logic here + return CDSResponse(cards=[...]) + +app.register_service(cds, path="/cds") +``` + +## Supported Workflows + +| Workflow Name | Description | Trigger | Status | +|-----------|-------------|---------|----------| +| `patient-view` | Triggered when a patient chart is opened | Opening a patient's chart | βœ… | +| `order-select` | Triggered when a new order is selected | Selecting a new order | ⏳ | +| `order-sign` | Triggered when orders are being signed | Signing orders | ⏳ | +| `encounter-discharge` | Triggered when a patient is being discharged | Discharging a patient | βœ… | + +## API Endpoints + +When registered with HealthChainAPI, the following endpoints are automatically created: + +| Endpoint | Method | Function | Description | +|------|--------|----------|-------------| +| `/cds-services` | GET | Service Discovery | Lists all available CDS services | +| `/cds-services/{id}` | POST | Hook Execution | Executes the specified CDS hook | + +## Request/Response Format + +### CDSRequest Example + +```json +{ + "hookInstance": "23f1a303-991f-4118-86c5-11d99a39222e", + "fhirServer": "https://fhir.example.org", + "hook": "patient-view", + "context": { + "patientId": "1288992", + "userId": "Practitioner/example" + }, + "prefetch": { + "patientToGreet": { + "resourceType": "Patient", + "gender": "male", + "birthDate": "1925-12-23", + "id": "1288992", + "active": true + } + } +} +``` + +### CDSResponse Example + +```json +{ + "cards": [{ + "summary": "Bilirubin: Based on the age of this patient consider overlaying bilirubin results", + "indicator": "info", + "detail": "The focus of this app is to reduce the incidence of severe hyperbilirubinemia...", + "source": { + "label": "Intermountain", + "url": null + }, + "links": [{ + "label": "Bilirubin SMART app", + "url": "https://example.com/launch", + "type": "smart" + }] + }] +} +``` + +## Supported FHIR Resources + +- `Patient` +- `Encounter` +- `Procedure` +- `MedicationRequest` + +For more information, see the [official CDS Hooks documentation](https://cds-hooks.org/). diff --git a/docs/reference/gateway/events.md b/docs/reference/gateway/events.md new file mode 100644 index 00000000..e850f9b9 --- /dev/null +++ b/docs/reference/gateway/events.md @@ -0,0 +1,76 @@ +# Events + +The FHIR Gateway emits events for all operations. The events are emitted using the `EventDispatcher`. + +!!! warning "Develoment Use Only" + This is a development feature and may change in future releases. + + + +## Event System + +The FHIR Gateway uses the `EventDispatcher` to emit events. + +## Event Types + +- `ehr.generic` +- `fhir.read` +- `fhir.search` +- `fhir.update` +- `fhir.delete` +- `fhir.create` +- `cds.patient.view` +- `cds.encounter.discharge` +- `cds.order.sign` +- `cds.order.select` +- `notereader.sign.note` +- `notereader.process.note` + +## Automatic Events + +The FHIR Gateway automatically emits events for all operations: + +```python +from healthchain.gateway.events.dispatcher import local_handler + +# Listen for FHIR read events +@local_handler.register(event_name="fhir.read") +async def audit_fhir_access(event): + event_name, payload = event + print(f"FHIR Read: {payload['resource_type']}/{payload['resource_id']} from {payload.get('source', 'unknown')}") + +# Listen for patient-specific events +@local_handler.register(event_name="fhir.patient.*") +async def track_patient_access(event): + event_name, payload = event + operation = event_name.split('.')[-1] # read, create, update, delete + print(f"Patient {operation}: {payload['resource_id']}") +``` + +### Custom Event Creation + +```python +# Configure custom event creation +def custom_event_creator(operation, resource_type, resource_id, resource=None, source=None): + """Create custom events with additional metadata.""" + return EHREvent( + event_type=EHREventType.FHIR_READ, + source_system=source or "unknown", + timestamp=datetime.now(), + payload={ + "operation": operation, + "resource_type": resource_type, + "resource_id": resource_id, + "user_id": get_current_user_id(), # Your auth system + "session_id": get_session_id(), + "ip_address": get_client_ip() + }, + metadata={ + "compliance": "HIPAA", + "audit_required": True + } + ) + +# Apply to gateway +gateway.events.set_event_creator(custom_event_creator) +``` diff --git a/docs/reference/gateway/fhir_gateway.md b/docs/reference/gateway/fhir_gateway.md new file mode 100644 index 00000000..007882b0 --- /dev/null +++ b/docs/reference/gateway/fhir_gateway.md @@ -0,0 +1,293 @@ +# FHIR Gateway + +The `FHIRGateway` provides a unified **asynchronous** interface for connecting to multiple FHIR servers with automatic authentication, connection pooling, error handling, and simplified CRUD operations. It handles the complexity of managing multiple FHIR clients and provides a consistent API across different healthcare systems. + + +## Basic Usage + +```python +from healthchain.gateway import FHIRGateway +from fhir.resources.patient import Patient + +# Create gateway +gateway = FHIRGateway() + +# Connect to FHIR server +gateway.add_source( + "my_fhir_server", + "fhir://fhir.example.com/api/FHIR/R4/?client_id=your_app&client_secret=secret&token_url=https://fhir.example.com/oauth2/token" +) + +async with gateway: + # FHIR operations + patient = await gateway.read(Patient, "123", "my_fhir_server") + print(f"Patient: {patient.name[0].family}") +``` + + +## Adding Sources πŸ₯ + +The gateway currently supports adding sources with OAuth2 authentication flow. + +```python +# Epic Sandbox (JWT assertion) +gateway.add_source( + "epic", + ( + "fhir://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4/" + "?client_id=your_app" + "&client_secret_path=keys/private.pem" + "&token_url=https://fhir.epic.com/interconnect-fhir-oauth/oauth2/token" + "&use_jwt_assertion=true" + ) +) + +# Medplum (Client Credentials) +gateway.add_source( + "medplum", + ( + "fhir://api.medplum.com/fhir/R4/" + "?client_id=your_app" + "&client_secret=secret" + "&token_url=https://api.medplum.com/oauth2/token" + "&scope=openid" + ) +) +``` +!!! info "For more information on configuring specific FHIR servers" + + **Epic FHIR API:** + + - [Epic on FHIR Documentation](https://fhir.epic.com/) + - [Epic OAuth2 Setup](https://fhir.epic.com/Documentation?docId=oauth2) + - [Test Patients in Epic Sandbox](https://fhir.epic.com/Documentation?docId=testpatients) + - [Useful Epic Sandbox Setup Guide](https://docs.interfaceware.com/docs/IguanaX_Documentation_Home/Development/iNTERFACEWARE_Collections/HL7_Collection/Epic_FHIR_Adapter/Set_up_your_Epic_FHIR_Sandbox_2783739933/) + + **Medplum FHIR API:** + + - [Medplum app tutorial](https://www.medplum.com/docs/tutorials) + - [Medplum OAuth2 Client Credentials Setup](https://www.medplum.com/docs/auth/methods/client-credentials) + + **General Resources:** + + - [OAuth2](https://oauth.net/2/) + - [FHIR RESTful API](https://hl7.org/fhir/http.html) + - [FHIR Specification](https://hl7.org/fhir/) + + +### Connection String Format + +Connection strings use the `fhir://` scheme with query parameters: + +``` +fhir://hostname:port/path?param1=value1¶m2=value2 +``` + +**Required Parameters:** + +- `client_id`: OAuth2 client ID +- `token_url`: OAuth2 token endpoint + +**Optional Parameters:** + +- `client_secret`: OAuth2 client secret (for client credentials flow) +- `client_secret_path`: Path to private key file (for JWT assertion) +- `scope`: OAuth2 scope (default: "`system/*.read system/*.write`") +- `use_jwt_assertion`: Use JWT assertion flow (default: false) +- `audience`: Token audience (for some servers) + + +## FHIR Operations πŸ”₯ + +!!! note Prerequisites + These examples assume you have already created and configured your gateway as shown in the [Basic Usage](#basic-usage) section above. + +### Create Resources + +```python +from fhir.resources.patient import Patient +from fhir.resources.humanname import HumanName + +# Create a new patient +patient = Patient( + name=[HumanName(family="Smith", given=["John"])], + gender="male", + birthDate="1990-01-01" +) + +created_patient = await gateway.create(resource=patient, source="medplum") +print(f"Created patient with ID: {created_patient.id}") +``` + +### Read Resources + +```python +from fhir.resources.patient import Patient + +# Read a specific patient (Derrick Lin, Epic Sandbox) +patient = await gateway.read( + resource_type=Patient, + fhir_id="eq081-VQEgP8drUUqCWzHfw3", + source="epic" + ) +``` + +### Update Resources + +```python +from fhir.resources.patient import Patient + +# Read, modify, and update +patient = await gateway.read(Patient, "123", "medplum") +patient.name[0].family = "Johnson" +updated_patient = await gateway.update(patient, "medplum") + +# Using context manager +async with gateway.modify(Patient, "123", "medplum") as patient: + patient.active = True + patient.name[0].given = ["Jane"] + # Automatic save on exit +``` + +### Delete Resources + +```python +from fhir.resources.patient import Patient + +# Delete a patient +success = await gateway.delete(Patient, "123", "medplum") +if success: + print("Patient deleted successfully") +``` + +## Search Operations + +### Basic Search + +```python +from fhir.resources.patient import Patient +from fhir.resources.bundle import Bundle + +# Search by name +search_params = {"family": "Smith", "given": "John"} +results: Bundle = await gateway.search(Patient, search_params, "epic") + +for entry in results.entry: + patient = entry.resource + print(f"Found: {patient.name[0].family}, {patient.name[0].given[0]}") +``` + +### Advanced Search + +```python +from fhir.resources.patient import Patient + +# Complex search with multiple parameters +search_params = { + "birthdate": "1990-01-01", + "gender": "male", + "address-city": "Boston", + "_count": 50, + "_sort": "family" +} + +results = await gateway.search(Patient, search_params, "epic") +print(f"Found {len(results.entry)} patients") +``` + +## Transform Handlers πŸ€– + +Transform handlers allow you to create custom API endpoints that process and enhance FHIR resources with additional logic, AI insights, or data transformations before returning them to clients. These handlers run before the response is sent, enabling real-time data enrichment and processing. + +```python +from fhir.resources.patient import Patient +from fhir.resources.observation import Observation + +@fhir_gateway.transform(Patient) +async def get_enhanced_patient_summary(id: str, source: str = None) -> Patient: + """Create enhanced patient summary with AI insights""" + + async with fhir_gateway.modify(Patient, id, source=source) as patient: + # Get lab results and process with AI + lab_results = await fhir_gateway.search( + resource_type=Observation, + search_params={"patient": id, "category": "laboratory"}, + source=source + ) + insights = nlp_pipeline.process(patient, lab_results) + + # Add AI summary + patient.extension = patient.extension or [] + patient.extension.append({ + "url": "http://healthchain.org/fhir/summary", + "valueString": insights.summary + }) + + return patient + +# The handler is automatically called via HTTP endpoint: +# GET /fhir/transform/Patient/123?source=epic +``` + +## Aggregate Handlers πŸ”— + +Aggregate handlers allow you to combine data from multiple FHIR sources into a single resource. This is useful for creating unified views across different EHR systems or consolidating patient data from various healthcare providers. + + +```python +from fhir.resources.observation import Observation +from fhir.resources.bundle import Bundle + +@gateway.aggregate(Observation) +async def aggregate_vitals(patient_id: str, sources: List[str] = None) -> Bundle: + """Aggregate vital signs from multiple sources.""" + sources = sources or ["epic", "cerner"] + all_observations = [] + + for source in sources: + try: + results = await gateway.search( + Observation, + {"patient": patient_id, "category": "vital-signs"}, + source + ) + processed_observations = process_observations(results) + all_observations.append(processed_observations) + except Exception as e: + print(f"Could not get vitals from {source}: {e}") + + return Bundle(type="searchset", entry=[{"resource": obs} for obs in all_observations]) + +# The handler is automatically called via HTTP endpoint: +# GET /fhir/aggregate/Observation?patient_id=123&sources=epic&sources=cerner +``` + +## Server Capabilities + +- **GET** `/fhir/metadata` - Returns FHIR-style `CapabilityStatement` of transform and aggregate endpoints +- **GET** `/fhir/status` - Returns Gateway status and connection health + + +## Connection Pool Management + +When you add a connection to a FHIR server, the gateway will automatically add it to a connection pool to manage connections to FHIR servers. + + +### Pool Configuration + +```python +# Create gateway with optimized connection settings +gateway = FHIRGateway( + max_connections=100, # Total connections across all sources + max_keepalive_connections=20, # Keep-alive connections per source + keepalive_expiry=30.0, # Keep connections alive for 30 seconds +) + +# Add multiple sources - they share the connection pool +gateway.add_source("epic", "fhir://epic.org/...") +gateway.add_source("cerner", "fhir://cerner.org/...") +gateway.add_source("medplum", "fhir://medplum.com/...") + +stats = gateway.get_pool_status() +print(stats) +``` diff --git a/docs/reference/gateway/gateway.md b/docs/reference/gateway/gateway.md new file mode 100644 index 00000000..6fc16773 --- /dev/null +++ b/docs/reference/gateway/gateway.md @@ -0,0 +1,105 @@ +# Gateway + +The HealthChain Gateway module provides a secure, asynchronous integration layer for connecting your NLP/ML pipelines with multiple healthcare systems. It provides a unified interface for connecting to FHIR servers, CDS Hooks, and SOAP/CDA services and is designed to be used in conjunction with the [HealthChainAPI](api.md) to create a complete healthcare integration platform. + + +## Features πŸš€ + +The Gateway handles the complex parts of healthcare integration: + +- **Multiple Protocols**: Works with [FHIR RESTful APIs](https://hl7.org/fhir/http.html), [CDS Hooks](https://cds-hooks.hl7.org/), and [Epic NoteReader CDI](https://discovery.hgdata.com/product/epic-notereader-cdi) (SOAP/CDA service) out of the box +- **Multi-Source**: Context managers to work with data from multiple EHR systems and FHIR servers safely +- **Smart Connections**: Handles [OAuth2.0 authentication](https://oauth.net/2/), connection pooling, and automatic token refresh +- **Event-Driven**: Native [asyncio](https://docs.python.org/3/library/asyncio.html) support for real-time events, audit trails, and workflow automation +- **Transform & Aggregate**: FastAPI-style declarative patterns to create endpoints for enhancing and combining data +- **Developer-Friendly**: Modern Python typing and validation support via [fhir.resources](https://github.com/nazrulworld/fhir.resources) (powered by [Pydantic](https://docs.pydantic.dev/)), protocol-based interfaces, and informative error messages + +## Key Components + +| Component | Description | Use Case | +|-----------|-------------|----------| +| [**HealthChainAPI**](api.md) | FastAPI app with gateway and service registration | Main app that coordinates everything | +| [**FHIRGateway**](fhir_gateway.md) | FHIR client with connection pooling and authentication| Reading/writing patient data from EHRs (Epic, Cerner, etc.) or application FHIR servers (Medplum, Hapi etc.) | +| [**CDSHooksService**](cdshooks.md) | Clinical Decision Support hooks service | Real-time alerts and recommendations | +| [**NoteReaderService**](soap_cda.md) | SOAP/CDA document processing service | Processing clinical documents and notes | +| [**Event System**](events.md) | Event-driven integration | Audit trails, workflow automation | + + +## Basic Usage + + +```python +from healthchain.gateway import HealthChainAPI, FHIRGateway +from fhir.resources.patient import Patient + +# Create the application +app = HealthChainAPI() + +# Create and configure a FHIR gateway +fhir = FHIRGateway() + +# Connect to your FHIR APIs +fhir.add_source("epic", "fhir://epic.org/api/FHIR/R4?client_id=...") +fhir.add_source("medplum", "fhir://api.medplum.com/fhir/R4/?client_id=...") + +# Add AI enhancements to patient data +@fhir.transform(Patient) +async def enhance_patient(id: str, source: str = None) -> Patient: + async with fhir.modify(Patient, id, source) as patient: + patient.active = True # Your custom logic here + return patient + +# Register and run +app.register_gateway(fhir) + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app) + +# Default: http://127.0.0.1:8000/ +``` + +You can also register multiple services of different protocols! + +```python +from healthchain.gateway import ( + HealthChainAPI, FHIRGateway, + CDSHooksService, NoteReaderService +) + +app = HealthChainAPI() + +# FHIR for patient data +fhir = FHIRGateway() +fhir.add_source("epic", "fhir://fhir.epic.com/r4?...") + +# CDS Hooks for real-time alerts +cds = CDSHooksService() + +@cds.hook("patient-view", id="allergy-alerts") +def check_allergies(request): + # Your logic here + return {"cards": [...]} + +# SOAP for clinical documents +notes = NoteReaderService() + +@notes.method("ProcessDocument") +def process_note(request): + # Your NLP pipeline here + return processed_document + +# Register everything +app.register_gateway(fhir) +app.register_service(cds) +app.register_service(notes) +``` + + +## Protocol Support + +| Protocol | Implementation | Features | +|----------|---------------|----------| +| **FHIR API** | `FHIRGateway` | FHIR-instance level CRUD operations - [read](https://hl7.org/fhir/http.html#read), [create](https://hl7.org/fhir/http.html#create), [update](https://hl7.org/fhir/http.html#update), [delete](https://hl7.org/fhir/http.html#delete), [search](https://hl7.org/fhir/http.html#search), register `transform` and `aggregate` handlers, connection pooling and authentication management | +| **CDS Hooks** | `CDSHooksService` | Hook Registration, Service Discovery | +| **SOAP/CDA** | `NoteReaderService` | Method Registration (`ProcessDocument`), SOAP Service Discovery (WSDL)| diff --git a/docs/reference/gateway/soap_cda.md b/docs/reference/gateway/soap_cda.md new file mode 100644 index 00000000..89992e13 --- /dev/null +++ b/docs/reference/gateway/soap_cda.md @@ -0,0 +1,157 @@ +# SOAP/CDA Protocol + +The SOAP/CDA protocol enables real-time Clinical Documentation Improvement (CDI) services. This implementation follows the Epic-integrated NoteReader CDI specification for analyzing clinical notes and extracting structured data. + +## Overview + +Clinical Documentation workflows communicate using [CDA (Clinical Document Architecture)](https://www.hl7.org.uk/standards/hl7-standards/cda-clinical-document-architecture/). CDAs are standardized electronic documents for exchanging clinical information between different healthcare systems. They provide a common structure for capturing and sharing patient data like medical history, medications, and care plans between different healthcare systems and providers. Think of it as a collaborative Google Doc that you can add, amend, and remove entries from. + +The Epic NoteReader CDI is a SOAP/CDA-based NLP service that extracts structured data from clinical notes. Like CDS Hooks, it operates in real-time and is triggered when a clinician opts into CDI functionality and signs or pends a note. + +The primary use case for Epic NoteReader is to convert free-text medical documentation into coded information that can be used for billing, quality reporting, continuity of care, and clinical decision support at the point-of-care ([case study](https://www.researchsquare.com/article/rs-4925228/v1)). + +It is a vendor-specific component (Epic), but we plan to add support for other IHE SOAP/CDA services in the future. + +| When | Where | What you receive | What you send back | +| :-------- | :-----| :-------------------------- |----------------------------| +| Triggered when a clinician opts in to CDI functionality and signs or pends a note | EHR documentation modules (e.g. NoteReader in Epic) | A CDA document containing continuity of care data and free-text clinical notes | A CDA document with additional structured data extracted by your CDI service | + +## HealthChainAPI Integration + +Use the `NoteReaderService` with HealthChainAPI to handle SOAP/CDA workflows: + +```python +from healthchain.gateway import HealthChainAPI, NoteReaderService +from healthchain.models import CdaRequest, CdaResponse + +app = HealthChainAPI() +notes = NoteReaderService() + +@notes.method("ProcessDocument") +def process_note(request: CdaRequest) -> CdaResponse: + # Your NLP pipeline here + processed_document = nlp_pipeline.process(request) + return processed_document + +app.register_service(notes, path="/soap") +``` + +## Supported Workflows + +| Workflow Name | Description | Trigger | Status | +|-----------|-------------|---------|----------| +| `sign-note-inpatient` | CDI processing for inpatient clinical notes | Signing or pending a note in Epic inpatient setting | βœ… | +| `sign-note-outpatient` | CDI processing for outpatient clinical notes | Signing or pending a note in Epic outpatient setting | ⏳ | + +Currently supports parsing of problems, medications, and allergies sections. + +## API Endpoints + +When registered with HealthChainAPI, the following endpoints are automatically created: + +| Endpoint | Method | Function | Protocol | +|------|--------|----------|----------| +| `/notereader/` | POST | `process_notereader_document` | SOAP | + +*Note: NoteReader is a vendor-specific component (Epic). Different EHR vendors have varying support for third-party CDI services.* + +## Request/Response Format + +### CDA Request Example + +```xml + + + + + + + CDA Document with Problem List and Progress Note + + + + + + + + +
+ + + Problems + + + Hypertension + + + +
+
+ + + +
+ + + Progress Note + + Patient's blood pressure remains elevated. + Discussed lifestyle modifications and medication adherence. + Started Lisinopril 10 mg daily for hypertension management. + +
+
+
+
+
+``` + +### CDA Response Example + +The response includes additional structured sections extracted from the clinical text: + +```xml + + + + + + + + +
+ + + Medications + + + Lisinopril 10 mg oral tablet, once daily + + + +
+
+
+
+
+``` + +## Supported CDA Sections + +- **Problems/Conditions**: ICD-10/SNOMED CT coded diagnoses +- **Medications**: SNOMED CT/RxNorm coded medications with dosage and frequency +- **Allergies**: Allergen identification and reaction severity +- **Progress Notes**: Free-text clinical documentation + +## Data Flow + +| Stage | Input | Output | +|-------|-------|--------| +| Gateway Receives | `CdaRequest` | Processed by your service | +| Gateway Returns | Your processed result | `CdaResponse` | + +You can use the [CdaConnector](../pipeline/connectors/cdaconnector.md) to handle conversion between CDA documents and HealthChain pipeline data containers. diff --git a/docs/reference/index.md b/docs/reference/index.md index 8d405beb..3aa9afeb 100644 --- a/docs/reference/index.md +++ b/docs/reference/index.md @@ -2,6 +2,7 @@ ## Core Components +- [Gateway](gateway/gateway.md): Connect to multiple healthcare systems and services. - [Pipeline](pipeline/pipeline.md): Build and manage processing pipelines for healthcare NLP and ML tasks. - [Sandbox](sandbox/sandbox.md): Test your pipelines in a simulated healthcare environment. - [Interoperability](interop/interop.md): Convert between healthcare data formats like FHIR, CDA, and HL7v2. diff --git a/docs/reference/pipeline/components/cdscardcreator.md b/docs/reference/pipeline/components/cdscardcreator.md index e0c35354..81cf46e7 100644 --- a/docs/reference/pipeline/components/cdscardcreator.md +++ b/docs/reference/pipeline/components/cdscardcreator.md @@ -127,4 +127,4 @@ pipeline.add_component(CdsCardCreator( ## Related Documentation - [CDS Hooks Specification](https://cds-hooks.org/) -- [Clinical Decision Support Documentation](../../sandbox/use_cases/cds.md) +- [Clinical Decision Support Documentation](../../gateway/cdshooks.md) diff --git a/docs/reference/pipeline/connectors/cdaconnector.md b/docs/reference/pipeline/connectors/cdaconnector.md index 4a6ddda8..a6338470 100644 --- a/docs/reference/pipeline/connectors/cdaconnector.md +++ b/docs/reference/pipeline/connectors/cdaconnector.md @@ -4,7 +4,7 @@ The `CdaConnector` parses CDA documents, extracting free-text notes and relevant This connector is particularly useful for clinical documentation improvement (CDI) workflows where a document needs to be processed and updated with additional structured data. -[(Full Documentation on Clinical Documentation)](../../sandbox/use_cases/clindoc.md) +[(Full Documentation on Clinical Documentation)](../../gateway/soap_cda.md) ## Input and Output diff --git a/docs/reference/pipeline/connectors/cdsfhirconnector.md b/docs/reference/pipeline/connectors/cdsfhirconnector.md index 2b49a86e..088dc8e4 100644 --- a/docs/reference/pipeline/connectors/cdsfhirconnector.md +++ b/docs/reference/pipeline/connectors/cdsfhirconnector.md @@ -2,7 +2,7 @@ The `CdsFhirConnector` handles FHIR data in the context of Clinical Decision Support (CDS) services, specifically using the [CDS Hooks specification](https://cds-hooks.org/). -[(Full Documentation on Clinical Decision Support)](../../sandbox/use_cases/cds.md) +[(Full Documentation on Clinical Decision Support)](../../gateway/cdshooks.md) ## Input and Output diff --git a/docs/reference/pipeline/connectors/connectors.md b/docs/reference/pipeline/connectors/connectors.md index 0f58651e..9e2f1463 100644 --- a/docs/reference/pipeline/connectors/connectors.md +++ b/docs/reference/pipeline/connectors/connectors.md @@ -23,8 +23,8 @@ Each connector can be mapped to a specific use case in the sandbox module. | Connector | Use Case | |-----------|----------| -| `CdaConnector` | [**Clinical Documentation**](../../sandbox/use_cases/clindoc.md) | -| `CdsFhirConnector` | [**Clinical Decision Support**](../../sandbox/use_cases/cds.md) | +| `CdaConnector` | [**Clinical Documentation**](../../gateway/soap_cda.md) | +| `CdsFhirConnector` | [**Clinical Decision Support**](../../gateway/cdshooks.md) | ## Adding connectors to your pipeline diff --git a/docs/reference/pipeline/pipeline.md b/docs/reference/pipeline/pipeline.md index e5df359c..20c64f42 100644 --- a/docs/reference/pipeline/pipeline.md +++ b/docs/reference/pipeline/pipeline.md @@ -15,7 +15,7 @@ HealthChain comes with a set of prebuilt pipelines that are out-of-the-box imple | **QAPipeline** [TODO] | `Document` | N/A | A Question Answering pipeline suitable for conversational AI applications | Developing a chatbot to answer patient queries about their medical records | | **ClassificationPipeline** [TODO] | `Tabular` | `CdsFhirConnector` | A pipeline for machine learning classification tasks | Predicting patient readmission risk based on historical health data | -Prebuilt pipelines are end-to-end workflows with Connectors built into them. They interact with raw data received from EHR interfaces, usually CDA or FHIR data from specific [use cases](../sandbox/use_cases/use_cases.md). +Prebuilt pipelines are end-to-end workflows with Connectors built into them. They interact with raw data received from EHR interfaces, usually CDA or FHIR data from specific [protocols](../gateway/gateway.md). You can load your models directly as a pipeline object, from local files or from a remote model repository such as Hugging Face. diff --git a/docs/reference/sandbox/client.md b/docs/reference/sandbox/client.md deleted file mode 100644 index 50712925..00000000 --- a/docs/reference/sandbox/client.md +++ /dev/null @@ -1,50 +0,0 @@ -# Client - -A client is a healthcare system object that requests information and processing from an external service. This is typically an EHR system, but we may also support other health objects in the future such as a CPOE (Computerized Physician Order Entry). - -We can mark a client by using the decorator `@hc.ehr`. You must declare a particular **workflow** for the EHR client, which informs the sandbox how your data will be formatted. You can find more information on the [Use Cases](./use_cases/use_cases.md) documentation page. - -Data returned from the client should be wrapped in a [Prefetch](../../../api/data_models.md#healthchain.models.data.prefetch) object, where prefetch is a dictionary of FHIR resources with keys corresponding to the CDS service. - -You can optionally specify the number of requests to generate with the `num` parameter. - -=== "Clinical Documentation" - ```python - import healthchain as hc - - from healthchain.sandbox.use_cases import ClinicalDocumentation - from healthchain.fhir import create_document_reference - - from fhir.resources.documentreference import DocumentReference - - @hc.sandbox - class MyCoolSandbox(ClinicalDocumentation): - def __init__(self) -> None: - pass - - @hc.ehr(workflow="sign-note-inpatient", num=10) - def load_data_in_client(self) -> DocumentReference: - # Do things here to load in your data - return create_document_reference(data="", content_type="text/xml") - ``` - -=== "CDS" - ```python - import healthchain as hc - - from healthchain.sandbox.use_cases import ClinicalDecisionSupport - from healthchain.models import Prefetch - - from fhir.resources.patient import Patient - - @hc.sandbox - class MyCoolSandbox(ClinicalDecisionSupport): - def __init__(self) -> None: - pass - - @hc.ehr(workflow="patient-view", num=10) - def load_data_in_client(self) -> Prefetch: - # Do things here to load in your data - return Prefetch(prefetch={"patient": Patient(id="123")}) - - ``` diff --git a/docs/reference/sandbox/sandbox.md b/docs/reference/sandbox/sandbox.md deleted file mode 100644 index cff13b3d..00000000 --- a/docs/reference/sandbox/sandbox.md +++ /dev/null @@ -1,60 +0,0 @@ -# Sandbox - -Designing your pipeline to integrate well in a healthcare context is an essential step to turning it into an application that -could potentially be adapted for real-world use. As a developer who has years of experience deploying healthcare NLP solutions into hospitals, I know how painful and slow this process can be. - -A sandbox makes this process easier. It provides a staging environment to debug, test, track, and interact with your application in realistic deployment scenarios without having to gain access to such environments, especially ones that are tightly integrated with local EHR configurations. Think of it as integration testing in healthcare systems. - -For a given sandbox run: - -1. Data is generated or loaded into a client (EHR) - -2. Data is wrapped and sent as standardized API requests the designated service - -3. Data is processed by the service (you application) - -4. Processed result is wrapped and sent back to the service as a standardized API response - -5. Data is received by the client which could be rendered in a UI interface - -To create a sandbox, initialize a class that inherits from a type of `UseCase` and decorate it with the `@hc.sandbox` decorator. `UseCase` loads in the blueprint of the API endpoints for the specified use case, and `@hc.sandbox` orchestrates these interactions. - -Every sandbox also requires a [**Client**](./client.md) function marked by `@hc.ehr` and a [**Service**](./service.md) function marked by `@hc.api`. Every client function must specify a **workflow** that informs the sandbox how your data will be formatted. For more information on workflows, see the [Use Cases](./use_cases/use_cases.md) documentation. - -!!! success "For each sandbox you need to specify..." - - - Use case - - service function - - client function - - workflow of client - - -```python -import healthchain as hc - -from healthchain.pipeline import SummarizationPipeline -from healthchain.sandbox.use_cases import ClinicalDecisionSupport -from healthchain.data_generators import CdsDataGenerator -from healthchain.models import CDSRequest, Prefetch, CDSResponse - - -@hc.sandbox -class MyCoolSandbox(ClinicalDecisionSupport): - def __init__(self): - self.data_generator = CdsDataGenerator() - self.pipeline = SummarizationPipeline('gpt-4o') - - @hc.ehr(workflow="encounter-discharge") - def load_data_in_client(self) -> Prefetch: - prefetch = self.data_generator.generate_prefetch() - return prefetch - - @hc.api - def my_service(self, request: CDSRequest) -> CDSResponse: - cds_response = self.pipeline(request) - return cds_response - -if __name__ == "__main__": - cds = MyCoolSandbox() - cds.start_sandbox() -``` diff --git a/docs/reference/sandbox/service.md b/docs/reference/sandbox/service.md deleted file mode 100644 index 417a7117..00000000 --- a/docs/reference/sandbox/service.md +++ /dev/null @@ -1,66 +0,0 @@ -# Service - -A service is typically an API of a third-party system that returns data to the client, the healthcare provider object. This is where you define your application logic. - -When you decorate a function with `@hc.api` in a sandbox, the function is mounted standardized API endpoint an EHR client can make requests to. This can be defined by healthcare interoperability standards, such as HL7, or the EHR provider. HealthChain will start a [FastAPI](https://fastapi.tiangolo.com/) server with these APIs pre-defined for you. - -Your service function receives use case specific request data as input and returns the response data. - -We recommend you initialize your pipeline in the class `__init__` method. - -Here are minimal examples for each use case: - -=== "Clinical Documentation" - ```python - import healthchain as hc - - from healthchain.sandbox.use_cases import ClinicalDocumentation - from healthchain.pipeline import MedicalCodingPipeline - from healthchain.models import CdaRequest, CdaResponse - from healthchain.fhir import create_document_reference - from fhir.resources.documentreference import DocumentReference - - @hc.sandbox - class MyCoolSandbox(ClinicalDocumentation): - def __init__(self): - self.pipeline = MedicalCodingPipeline.load("./path/to/model") - - @hc.ehr(workflow="sign-note-inpatient") - def load_data_in_client(self) -> DocumentReference: - with open('/path/to/data.xml', "r") as file: - xml_string = file.read() - - return create_document_reference(data=xml_string, content_type="text/xml") - - @hc.api - def my_service(self, request: CdaRequest) -> CdaResponse: - response = self.pipeline(request) - return response - ``` - -=== "CDS" - ```python - import healthchain as hc - - from healthchain.sandbox.use_cases import ClinicalDecisionSupport - from healthchain.pipeline import SummarizationPipeline - from healthchain.models import CDSRequest, CDSResponse, Prefetch - from fhir.resources.patient import Patient - - @hc.sandbox - class MyCoolSandbox(ClinicalDecisionSupport): - def __init__(self): - self.pipeline = SummarizationPipeline.load("model-name") - - @hc.ehr(workflow="patient-view") - def load_data_in_client(self) -> Prefetch: - with open('/path/to/data.json', "r") as file: - fhir_json = file.read() - - return Prefetch(prefetch={"patient": Patient(**fhir_json)}) - - @hc.api - def my_service(self, request: CDSRequest) -> CDSResponse: - response = self.pipeline(request) - return response - ``` diff --git a/docs/reference/sandbox/use_cases/cds.md b/docs/reference/sandbox/use_cases/cds.md deleted file mode 100644 index 87b6ab0f..00000000 --- a/docs/reference/sandbox/use_cases/cds.md +++ /dev/null @@ -1,91 +0,0 @@ -# Use Cases - -## Clinical Decision Support (CDS) - -CDS workflows are based on [CDS Hooks](https://cds-hooks.org/). CDS Hooks is an [HL7](https://cds-hooks.hl7.org) published specification for clinical decision support. CDS hooks communicate using [FHIR (Fast Healthcare Interoperability Resources)](https://hl7.org/fhir/). For more information you can consult the [official documentation](https://cds-hooks.org/). - -| When | Where | What you receive | What you send back | -| :-------- | :-----| :-------------------------- |----------------------------| -| Triggered at certain events during a clinician's workflow, e.g. when a patient record is opened. | EHR | The context of the event and FHIR resources that are requested by your service. e.g. patient ID, `Encounter` and `Patient`. | β€œCards” displaying text, actionable suggestions, or links to launch a [SMART](https://smarthealthit.org/) app from within the workflow. | - -## Data Flow - -| Stage | Input | Output | -|-------|-------|--------| -| Client | N/A | `Prefetch` | -| Service | `CDSRequest` | `CDSResponse` | - - -[CdsFhirConnector](../../pipeline/connectors/cdsfhirconnector.md) handles the conversion of `CDSRequests` :material-swap-horizontal: `Document` :material-swap-horizontal: `CDSResponse` in a HealthChain pipeline. - - -## Supported Workflows - -| Workflow Name | Description | Trigger | Maturity | -|-----------|-------------|---------|----------| -| `patient-view` | Triggered when a patient chart is opened | Opening a patient's chart | βœ… | -| `order-select` | Triggered when a new order is selected | Selecting a new order | ⏳ | -| `order-sign` | Triggered when orders are being signed | Signing orders | ⏳ | -| `encounter-discharge` | Triggered when a patient is being discharged | Discharging a patient | βœ… | - - - -## Generated API Endpoints - -| Endpoint | Method | Function Name | API Protocol | -|------|--------|----------|--------------| -| `/cds-services` | GET | `cds_discovery` | REST | -| `/cds-services/{id}` | POST | `cds_service` | REST | - -## What does the data look like? - -### Example `CDSRequest` - -```json -{ - "hookInstance" : "23f1a303-991f-4118-86c5-11d99a39222e", - "fhirServer" : "https://fhir.example.org", - "hook" : "patient-view", - "context" : { - "patientId" : "1288992", - "userId" : "Practitioner/example" - }, - "prefetch" : { - "patientToGreet" : { - "resourceType" : "Patient", - "gender" : "male", - "birthDate" : "1925-12-23", - "id" : "1288992", - "active" : true - } - } -} -``` -### Example `CDSResponse` - -```json -{ - "summary": "Bilirubin: Based on the age of this patient consider overlaying bilirubin [Mass/volume] results over a time-based risk chart", - "indicator": "info", - "detail": "The focus of this app is to reduce the incidence of severe hyperbilirubinemia and bilirubin encephalopathy while minimizing the risks of unintended harm such as maternal anxiety, decreased breastfeeding, and unnecessary costs or treatment.", - "source": { - "label": "Intermountain", - "url": null - }, - "links": [ - { - "label": "Bilirubin SMART app", - "url": "https://example.com/launch", - "type": "smart" - } - ] -} - -``` - -## Implemented FHIR Resources - -- `Patient` -- `Encounter` -- `Procedure` -- `MedicationRequest` diff --git a/docs/reference/sandbox/use_cases/clindoc.md b/docs/reference/sandbox/use_cases/clindoc.md deleted file mode 100644 index b2d2d59a..00000000 --- a/docs/reference/sandbox/use_cases/clindoc.md +++ /dev/null @@ -1,229 +0,0 @@ -# Clinical Documentation -The `ClinicalDocumentation` use case implements a real-time Clinical Documentation Improvement (CDI) service. It currently implements the Epic-integrated NoteReader CDI specification, which communicates with a third-party NLP engine to analyse clinical notes and extract structured data. It helps convert free-text medical documentation into coded information that can be used for billing, quality reporting, continuity of care, and clinical decision support ([case study](https://www.researchsquare.com/article/rs-4925228/v1)). - -`ClinicalDocumentation` communicates using [CDA (Clinical Document Architecture)](https://www.hl7.org.uk/standards/hl7-standards/cda-clinical-document-architecture/). CDAs are standardized electronic documents for exchanging clinical information. They provide a common structure for capturing and sharing patient data like medical history, medications, and care plans between different healthcare systems and providers. Think of it as a collaborative Google Doc that you can add, amend, and remove entries from. - -| When | Where | What you receive | What you send back | -| :-------- | :-----| :-------------------------- |----------------------------| -| Triggered when a clinician opts in to a CDI functionality and signs or pends a note after writing it. | Specific modules in EHR where clinical documentation takes place, such as NoteReader in Epic. | A CDA document which contains continuity of care data and free-text data, e.g. a patient's problem list and the progress note that the clinician has entered in the EHR. | A CDA document which contains additional structured data extracted and returned by your CDI service. | - - -## Data Flow - -| Stage | Input | Output | -|-------|-------|--------| -| Client | N/A | `DocumentReference` | -| Service | `CdaRequest` | `CdaResponse` | - - -[CdaConnector](../../pipeline/connectors/cdaconnector.md) handles the conversion of `CdaRequests` :material-swap-horizontal: `DocumentReference` :material-swap-horizontal: `CdaResponse` in a HealthChain pipeline. - - -## Supported Workflows - -| Workflow Name | Description | Trigger | Maturity | -|-----------|-------------|---------|----------| -| `sign-note-inpatient` | Triggered when a clinician opts in to a CDI functionality and signs or pends a note after writing it in an inpatient setting. | Signing or pending a note in Epic | βœ… | -| `sign-note-outpatient` | Triggered when a clinician opts in to a CDI functionality and signs or pends a note after writing it in an outpatient setting. | Signing or pending a note in Epic | ⏳ | - -We support parsing of problems, medications, and allergies sections, though some of the data fields may be limited. We plan to implement additional CDI services and workflows for different vendor specifications. - -## Generated API Endpoints - -| Endpoint | Method | Function | API Protocol | -|------|--------|----------|--------------| -| `/notereader/` | POST | `process_notereader_document` | SOAP | - - -Note that NoteReader is a vendor-specific component (Epic). This particular note-based workflow is one type of CDI service. Different EHR vendors will have different support for third-party CDI services. - -## What does the data look like? -### Example CDA Request - -```xml - - - - - - - CDA Document with Problem List and Progress Note - - - - - - - - -
- - - Problems - - - Hypertension - - - - - - - - - - - - - - - - - Hypertension - - - - - - - - - -
-
- - - -
- - - Progress Note - - Patient's blood pressure remains elevated. Discussed lifestyle modifications and medication adherence. Started Lisinopril 10 mg daily for hypertension management. Will follow up in 3 months to assess response to treatment. - -
-
-
-
-
-``` - -### Example CDA Response - -```xml - - - - - - - CDA Document with Problem List, Medication, and Progress Note - - - - - - - - -
- - - Problems - - - Hypertension - - - - - - - - - - - - - - - - - Hypertension - - - - - - - - - -
-
- - - -
- - - Medications - - - Lisinopril 10 mg oral tablet, once daily - - - - - - - - - - - - - - - - - - - Lisinopril 10 mg oral tablet - - - - - - - - - - - - - -
-
- - - -
- - - Progress Note - - Patient's blood pressure remains elevated. Discussed lifestyle modifications and medication adherence. Started Lisinopril 10 mg daily for hypertension management. Will follow up in 3 months to assess response to treatment. - -
-
-
-
-
-``` - -## Implemented CDA Sections -- Problems -- Medications (including information on dosage, frequency, duration, route) -- Allergies (including information on severity, reaction and type of allergen) -- Progress Note (free-text) diff --git a/docs/reference/sandbox/use_cases/use_cases.md b/docs/reference/sandbox/use_cases/use_cases.md deleted file mode 100644 index 8104764e..00000000 --- a/docs/reference/sandbox/use_cases/use_cases.md +++ /dev/null @@ -1,10 +0,0 @@ -# Use Cases - -Use cases are the core building blocks of sandboxes. They define the API endpoints and the data formats for a given workflow. - -We currently support: - -- [Clinical Decision Support](./cds.md) -- [Clinical Documentation](./clindoc.md) - -More documentation on the pros and cons of each use case will be added soon. For now, you can refer to the source code for more details. diff --git a/docs/reference/utilities/data_generator.md b/docs/reference/utilities/data_generator.md index 8c18b8c6..57c37aee 100644 --- a/docs/reference/utilities/data_generator.md +++ b/docs/reference/utilities/data_generator.md @@ -1,15 +1,10 @@ # Data Generator -Healthcare data is interoperable, but not composable - every deployment site will have different ways of configuring data and terminology. This matters when you develop applications that need to integrate into these systems, especially when you need to reliably extract data for your model to consume. +Healthcare systems use standardized data formats, but each hospital or clinic configures their data differently. This creates challenges when building applications that need to work across multiple healthcare systems. -The aim of the data generator is not to generate realistic data suitable for use cases such as patient population studies, but rather to generate data that is structurally compliant with what is expected of EHR configurations, and to be able to test and handle variations in this. +The data generator creates test data that matches the structure and format expected by Electronic Health Record (EHR) systems. It's designed for testing your applications, not for research studies that need realistic patient populations. -For this reason the data generator is opinionated by specific workflows and use cases. - -!!! note - We're aware we may not cover everyone's use cases, so if you have strong opinions about this, please [reach out](https://discord.gg/UQC6uAepUz)! - -On the synthetic data spectrum defined by [this UK ONS methodology working paper](https://www.ons.gov.uk/methodology/methodologicalpublications/generalmethodology/onsworkingpaperseries/onsmethodologyworkingpaperseriesnumber16syntheticdatapilot#:~:text=Synthetic%20data%20at%20ONS&text=Synthetic%20data%20is%20created%20by,that%20provided%20the%20original%20data.%E2%80%9D), HealthChain generates level 1: synthetic structural data. +According to the [UK ONS synthetic data classification](https://www.ons.gov.uk/methodology/methodologicalpublications/generalmethodology/onsworkingpaperseries/onsmethodologyworkingpaperseriesnumber16syntheticdatapilot#:~:text=Synthetic%20data%20at%20ONS&text=Synthetic%20data%20is%20created%20by,that%20provided%20the%20original%20data.%E2%80%9D), HealthChain generates "level 1: synthetic structural data" - data that follows the correct format but contains fictional information. ![Synthetic data](../../assets/images/synthetic_data_ons.png) @@ -28,7 +23,7 @@ Current implemented workflows: | [order-sign](https://cds-hooks.org/hooks/order-sign/)| :material-check: Partial | Future: `MedicationRequest`, `ProcedureRequest`, `ServiceRequest` | | [order-select](https://cds-hooks.org/hooks/order-select/) | :material-check: Partial | Future: `MedicationRequest`, `ProcedureRequest`, `ServiceRequest` | -For more information on CDS workflows, see the [CDS Use Case](../sandbox/use_cases/cds.md) documentation. +For more information on CDS workflows, see the [CDS Hooks Protocol](../gateway/cdshooks.md) documentation. You can use the data generator within a client function or on its own. diff --git a/docs/reference/utilities/sandbox.md b/docs/reference/utilities/sandbox.md new file mode 100644 index 00000000..cc6c8cba --- /dev/null +++ b/docs/reference/utilities/sandbox.md @@ -0,0 +1,105 @@ +# Sandbox Testing + +Sandbox environments provide testing utilities for validating your HealthChain applications in realistic healthcare contexts. These are primarily used for development and testing rather than production deployment. + +!!! info "For production applications, use [HealthChainAPI](../gateway/api.md) instead" + + Sandbox is a testing utility. For production healthcare AI applications, use the [Gateway](../gateway/gateway.md) with [HealthChainAPI](../gateway/api.md). + +## Quick Example + +Test CDS Hooks workflows with synthetic data: + +```python +import healthchain as hc +from healthchain.sandbox.use_cases import ClinicalDecisionSupport + +@hc.sandbox +class TestCDS(ClinicalDecisionSupport): + def __init__(self): + self.pipeline = SummarizationPipeline.from_model_id("facebook/bart-large-cnn") + + @hc.ehr(workflow="encounter-discharge") + def ehr_database_client(self): + return self.data_generator.generate_prefetch() + +# Run with: healthchain run test_cds.py +``` + +## Available Testing Scenarios + +- **[CDS Hooks](../gateway/cdshooks.md)**: `ClinicalDecisionSupport` - Test clinical decision support workflows +- **[Clinical Documentation](../gateway/soap_cda.md)**: `ClinicalDocumentation` - Test SOAP/CDA document processing workflows + +## EHR Client Simulation + +The `@hc.ehr` decorator simulates EHR client behavior for testing. You must specify a **workflow** that determines how your data will be formatted. + +Data should be wrapped in a [Prefetch](../../../api/data_models.md#healthchain.models.data.prefetch) object for CDS workflows, or return appropriate FHIR resources for document workflows. + +=== "Clinical Decision Support" + ```python + import healthchain as hc + from healthchain.sandbox.use_cases import ClinicalDecisionSupport + from healthchain.models import Prefetch + from fhir.resources.patient import Patient + + @hc.sandbox + class MyCoolSandbox(ClinicalDecisionSupport): + @hc.ehr(workflow="patient-view", num=10) + def load_data_in_client(self) -> Prefetch: + # Load your test data here + return Prefetch(prefetch={"patient": Patient(id="123")}) + ``` + +=== "Clinical Documentation" + ```python + import healthchain as hc + from healthchain.sandbox.use_cases import ClinicalDocumentation + from healthchain.fhir import create_document_reference + from fhir.resources.documentreference import DocumentReference + + @hc.sandbox + class MyCoolSandbox(ClinicalDocumentation): + @hc.ehr(workflow="sign-note-inpatient", num=10) + def load_data_in_client(self) -> DocumentReference: + # Load your test data here + return create_document_reference(data="", content_type="text/xml") + ``` + +**Parameters:** + +- `workflow`: The healthcare workflow to simulate (e.g., "patient-view", "sign-note-inpatient") +- `num`: Optional number of requests to generate for testing + +## Migration to Production + +!!! warning "Sandbox Decorators are Deprecated" + `@hc.api` is deprecated. Use [HealthChainAPI](../gateway/api.md) for production. + +**Quick Migration:** + +```python +# Before (Testing) - Shows deprecation warning +@hc.sandbox +class TestCDS(ClinicalDecisionSupport): + @hc.api # ⚠️ DEPRECATED + def my_service(self, request): ... + +# After (Production) +from healthchain.gateway import HealthChainAPI, CDSHooksService + +app = HealthChainAPI() +cds = CDSHooksService() + +@cds.hook("patient-view") +def my_service(request): ... + +app.register_service(cds) +``` + +**Next Steps:** + +1. **Testing**: Continue using sandbox utilities with deprecation warnings +2. **Production**: Migrate to [HealthChainAPI Gateway](../gateway/gateway.md) +3. **Protocols**: See [CDS Hooks](../gateway/cdshooks.md) and [SOAP/CDA](../gateway/soap_cda.md) diff --git a/healthchain/gateway/README.md b/healthchain/gateway/README.md index 580231c0..ac670c0f 100644 --- a/healthchain/gateway/README.md +++ b/healthchain/gateway/README.md @@ -16,7 +16,7 @@ All protocol implementations extend `BaseGateway` to provide protocol-specific f ```python from healthchain.gateway import ( HealthChainAPI, BaseGateway, - FHIRGateway, CDSHooksGateway, NoteReaderGateway + FHIRGateway, CDSHooksService, NoteReaderService ) # Create the application @@ -24,8 +24,8 @@ app = HealthChainAPI() # Create gateways for different protocols fhir = FHIRGateway(base_url="https://fhir.example.com/r4") -cds = CDSHooksGateway() -soap = NoteReaderGateway() +cds = CDSHooksService() +soap = NoteReaderService() # Register protocol-specific handlers @fhir.read(Patient) @@ -49,12 +49,12 @@ app.register_gateway(soap) ## Core Types - `BaseGateway`: The central abstraction for all protocol gateway implementations -- `EventDispatcherMixin`: A reusable mixin that provides event dispatching +- `EventCapability`: A component that provides event dispatching - `HealthChainAPI`: FastAPI wrapper for healthcare gateway registration - Concrete gateway implementations: - `FHIRGateway`: FHIR REST API protocol - - `CDSHooksGateway`: CDS Hooks protocol - - `NoteReaderGateway`: SOAP/CDA protocol + - `CDSHooksService`: CDS Hooks protocol + - `NoteReaderService`: SOAP/CDA protocol ## Quick Start @@ -87,9 +87,9 @@ The gateway module uses Python's Protocol typing for robust interface definition ```python # Register gateways with explicit types -app.register_gateway(fhir) # Implements FHIRGatewayProtocol -app.register_gateway(cds) # Implements CDSHooksGatewayProtocol -app.register_gateway(soap) # Implements SOAPGatewayProtocol +app.register_gateway(fhir) # Implements FHIRGateway +app.register_gateway(cds) # Implements CDSHooksService +app.register_gateway(soap) # Implements NoteReaderService # Get typed gateway dependencies in API routes @app.get("/api/patient/{id}") @@ -106,3 +106,26 @@ This approach provides: - Clear interface definition for gateway implementations - Runtime type safety with detailed error messages - Better testability through protocol-based mocking + +## Context Managers + +Context managers are a powerful tool for managing resource lifecycles in a safe and predictable way. They are particularly useful for: + +- Standalone CRUD operations +- Creating new resources +- Bulk operations +- Cross-resource transactions +- When you need guaranteed cleanup/connection management + +The decorator pattern is more for processing existing resources, while context managers are for managing resource lifecycles. + +```python +@fhir.read(Patient) +async def read_patient_and_create_note(patient): + # Use context manager to create related resources + async with fhir.resource_context("DiagnosticReport") as report: + report["subject"] = {"reference": f"Patient/{patient.id}"} + report["status"] = "final" + + return patient +``` diff --git a/healthchain/gateway/__init__.py b/healthchain/gateway/__init__.py index 56afba4b..62f2aa57 100644 --- a/healthchain/gateway/__init__.py +++ b/healthchain/gateway/__init__.py @@ -1,51 +1,57 @@ """ HealthChain Gateway Module. -This module provides a secure gateway layer that manages routing, transformation, -and event handling between healthcare systems (FHIR servers, EHRs) with a focus on -maintainable, compliant integration patterns. - -Core components: -- BaseGateway: Abstract base class for all gateway implementations -- Protocol implementations: Concrete gateways for various healthcare protocols -- Event system: Publish-subscribe framework for healthcare events -- API framework: FastAPI-based application for exposing gateway endpoints +This module provides the core gateway functionality for HealthChain, +including API applications, protocol handlers, and healthcare integrations. """ -# Main application exports -from healthchain.gateway.api.app import HealthChainAPI, create_app - -# Core components -from healthchain.gateway.core.base import ( - BaseGateway, - GatewayConfig, +# API Components +from healthchain.gateway.api.app import HealthChainAPI +from healthchain.gateway.api.dependencies import ( + get_app, + get_event_dispatcher, + get_gateway, + get_all_gateways, ) -# Event system +# Core Components +from healthchain.gateway.core.base import BaseGateway, BaseProtocolHandler +from healthchain.gateway.core.fhirgateway import FHIRGateway + +# Protocol Handlers +from healthchain.gateway.protocols.cdshooks import CDSHooksService +from healthchain.gateway.protocols.notereader import NoteReaderService + +# Event System from healthchain.gateway.events.dispatcher import ( EventDispatcher, EHREvent, EHREventType, ) -# Re-export gateway implementations -from healthchain.gateway.protocols import ( - CDSHooksGateway, - NoteReaderGateway, -) +# Client Utilities +from healthchain.gateway.clients.fhir import AsyncFHIRClient +from healthchain.gateway.clients.pool import FHIRClientPool __all__ = [ # API "HealthChainAPI", - "create_app", + "get_app", + "get_event_dispatcher", + "get_gateway", + "get_all_gateways", # Core "BaseGateway", - "GatewayConfig", + "BaseProtocolHandler", + "FHIRGateway", + # Protocols + "CDSHooksService", + "NoteReaderService", # Events "EventDispatcher", "EHREvent", "EHREventType", - # Gateways - "CDSHooksGateway", - "NoteReaderGateway", + # Clients + "AsyncFHIRClient", + "FHIRClientPool", ] diff --git a/healthchain/gateway/api/__init__.py b/healthchain/gateway/api/__init__.py index 8e19de07..d6226a54 100644 --- a/healthchain/gateway/api/__init__.py +++ b/healthchain/gateway/api/__init__.py @@ -1,39 +1,38 @@ """ -HealthChain API module. +API module for HealthChain Gateway. -This module provides API components for the HealthChain gateway. +This module provides the FastAPI application wrapper and dependency injection +for healthcare integrations. """ -from healthchain.gateway.api.app import HealthChainAPI, create_app +from healthchain.gateway.api.app import HealthChainAPI from healthchain.gateway.api.dependencies import ( get_app, get_event_dispatcher, get_gateway, get_all_gateways, - get_typed_gateway, + get_service, + get_all_services, + get_gateway_by_name, + get_service_by_name, ) from healthchain.gateway.api.protocols import ( HealthChainAPIProtocol, - GatewayProtocol, EventDispatcherProtocol, - FHIRGatewayProtocol, - SOAPGatewayProtocol, + FHIRConnectionManagerProtocol, ) __all__ = [ - # Classes "HealthChainAPI", - # Functions - "create_app", "get_app", "get_event_dispatcher", "get_gateway", "get_all_gateways", - "get_typed_gateway", - # Protocols + "get_service", + "get_all_services", + "get_gateway_by_name", + "get_service_by_name", "HealthChainAPIProtocol", - "GatewayProtocol", "EventDispatcherProtocol", - "FHIRGatewayProtocol", - "SOAPGatewayProtocol", + "FHIRConnectionManagerProtocol", ] diff --git a/healthchain/gateway/api/app.py b/healthchain/gateway/api/app.py index 0e73d1a1..4b27583d 100644 --- a/healthchain/gateway/api/app.py +++ b/healthchain/gateway/api/app.py @@ -6,23 +6,19 @@ """ import logging -import importlib -import inspect -import os -import signal +from contextlib import asynccontextmanager from datetime import datetime from fastapi import FastAPI, APIRouter, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse -from contextlib import asynccontextmanager from termcolor import colored -from typing import Dict, Optional, Type, Union, Set +from typing import Dict, Optional, Type, Union -from healthchain.gateway.core.base import BaseGateway +from healthchain.gateway.core.base import BaseGateway, BaseProtocolHandler from healthchain.gateway.events.dispatcher import EventDispatcher from healthchain.gateway.api.dependencies import get_app @@ -48,13 +44,14 @@ class HealthChainAPI(FastAPI): # Create and register gateways fhir_gateway = FHIRGateway() - cds_gateway = CDSHooksGateway() - note_gateway = NoteReaderGateway() + cds_service = CDSHooksService() + note_service = NoteReaderService() # Register with the API app.register_gateway(fhir_gateway) - app.register_gateway(cds_gateway) - app.register_gateway(note_gateway) + + app.register_service(cds_service) + app.register_service(note_service) # Run the app with uvicorn uvicorn.run(app) @@ -75,35 +72,44 @@ def __init__( Initialize the HealthChainAPI application. Args: - title: API title for documentation - description: API description for documentation + title: API title + description: API description version: API version - enable_cors: Whether to enable CORS middleware - enable_events: Whether to enable event dispatching functionality - event_dispatcher: Optional event dispatcher to use (for testing/DI) - **kwargs: Additional keyword arguments to pass to FastAPI + enable_cors: Enable CORS middleware + enable_events: Enable event dispatching + event_dispatcher: Optional custom event dispatcher + **kwargs: Additional FastAPI configuration """ - # Set up the lifespan - if "lifespan" not in kwargs: - kwargs["lifespan"] = self.lifespan - super().__init__( - title=title, description=description, version=version, **kwargs + title=title, + description=description, + version=version, + lifespan=self._lifespan, + **kwargs, ) - self.gateways: Dict[str, BaseGateway] = {} - self.gateway_endpoints: Dict[str, Set[str]] = {} + # Gateway and service registries + self.gateways = {} + self.services = {} + self.gateway_endpoints = {} + self.service_endpoints = {} + + # Event system setup self.enable_events = enable_events + self.event_dispatcher = None - # Initialize event dispatcher if events are enabled - if self.enable_events: - self.event_dispatcher = event_dispatcher or EventDispatcher() - if not event_dispatcher: # Only initialize if we created it - self.event_dispatcher.init_app(self) - else: - self.event_dispatcher = None + if enable_events: + if event_dispatcher: + self.event_dispatcher = event_dispatcher + else: + from healthchain.gateway.events.dispatcher import EventDispatcher - # Add default middleware + self.event_dispatcher = EventDispatcher() + + # Initialize the event dispatcher + self.event_dispatcher.init_app(self) + + # Setup middleware if enable_cors: self.add_middleware( CORSMiddleware, @@ -113,12 +119,8 @@ def __init__( allow_headers=["*"], ) - # Add exception handlers - self.add_exception_handler( - RequestValidationError, self._validation_exception_handler - ) - self.add_exception_handler(HTTPException, self._http_exception_handler) - self.add_exception_handler(Exception, self._general_exception_handler) + # Add global exception handler + self.add_exception_handler(Exception, self._exception_handler) # Add default routes self._add_default_routes() @@ -126,12 +128,14 @@ def __init__( # Register self as a dependency for get_app self.dependency_overrides[get_app] = lambda: self - # Add a shutdown route - shutdown_router = APIRouter() - shutdown_router.add_api_route( - "/shutdown", self._shutdown, methods=["GET"], include_in_schema=False - ) - self.include_router(shutdown_router) + @asynccontextmanager + async def _lifespan(self, app: FastAPI): + """ + Handle application lifespan events (startup and shutdown). + """ + await self._startup() + yield + await self._shutdown() def get_event_dispatcher(self) -> Optional[EventDispatcher]: """Get the event dispatcher instance. @@ -143,206 +147,179 @@ def get_event_dispatcher(self) -> Optional[EventDispatcher]: """ return self.event_dispatcher - def get_gateway(self, gateway_name: str) -> Optional[BaseGateway]: - """Get a specific gateway by name. + def _register_component( + self, + component: Union[Type, object], + component_type: str, + path: Optional[str] = None, + use_events: Optional[bool] = None, + **options, + ) -> None: + """Register a healthcare component (gateway or service).""" - Args: - gateway_name: The name of the gateway to retrieve + use_events = use_events if use_events is not None else self.enable_events + registry, endpoints_registry, base_class = self._get_component_config( + component_type + ) - Returns: - The gateway instance or None if not found - """ - return self.gateways.get(gateway_name) + component_instance = self._get_component_instance( + component, base_class, use_events, **options + ) - def get_all_gateways(self) -> Dict[str, BaseGateway]: - """Get all registered gateways. + registry[component_instance.__class__.__name__] = component_instance - Returns: - Dictionary of all registered gateways - """ - return self.gateways + self._get_component_events(component_instance, use_events) + self._add_component_routes( + component_instance, component_type, endpoints_registry, path + ) - def register_gateway( + def _get_component_config(self, component_type: str) -> tuple: + """Get the appropriate registries and base class for a component type.""" + if component_type == "gateway": + return self.gateways, self.gateway_endpoints, BaseGateway + else: # service + return self.services, self.service_endpoints, BaseProtocolHandler + + def _get_component_instance( self, - gateway: Union[Type[BaseGateway], BaseGateway], - path: Optional[str] = None, - use_events: Optional[bool] = None, + component: Union[Type, object], + base_class: Type, + use_events: bool, **options, + ) -> object: + """Create or validate a component instance and return it with its name.""" + if isinstance(component, base_class): + # Already an instance + component_instance = component + else: + # Create a new instance from the class + if "use_events" not in options: + options["use_events"] = use_events + component_instance = component(**options) + + return component_instance + + def _get_component_events( + self, component_instance: object, use_events: bool ) -> None: - """ - Register a gateway with the API and mount its endpoints. + """Connect the event dispatcher to a component if events are enabled.""" + if ( + use_events + and self.event_dispatcher + and hasattr(component_instance, "events") + and hasattr(component_instance.events, "set_dispatcher") + ): + component_instance.events.set_dispatcher(self.event_dispatcher) - Args: - gateway: The gateway class or instance to register - path: Optional override for the gateway's mount path - use_events: Whether to enable events for this gateway (defaults to app setting) - **options: Options to pass to the constructor - """ - try: - # Determine if events should be used for this gateway - gateway_use_events = ( - self.enable_events if use_events is None else use_events - ) + def _add_component_routes( + self, + component: Union[BaseGateway, BaseProtocolHandler], + component_type: str, + endpoints_registry: Dict[str, set], + path: Optional[str] = None, + ) -> None: + """Add routes for a component.""" - gateway_name = gateway.__class__.__name__ + component_name = component.__class__.__name__ + endpoints_registry[component_name] = set() - # Create a new instance - if isinstance(gateway, BaseGateway): - gateway_instance = gateway - else: - if "use_events" not in options: - options["use_events"] = gateway_use_events - gateway_instance = gateway(**options) - - # Add to internal gateway registry - self.gateways[gateway_name] = gateway_instance - - # Provide event dispatcher to gateway if events are enabled - if ( - gateway_use_events - and self.event_dispatcher - and hasattr(gateway_instance, "set_event_dispatcher") - and callable(gateway_instance.set_event_dispatcher) - ): - gateway_instance.set_event_dispatcher(self.event_dispatcher) - - # Add gateway routes to FastAPI app - self._add_gateway_routes(gateway_instance, path) - - except Exception as e: - logger.error( - f"Failed to register gateway {gateway.__name__ if hasattr(gateway, '__name__') else gateway.__class__.__name__}: {str(e)}" + # Case 1: APIRouter-based components (gateways and CDSHooksService) + if isinstance(component, APIRouter): + self._register_api_router( + component, component_name, endpoints_registry, path ) - raise + return - def _add_gateway_routes( - self, gateway: BaseGateway, path: Optional[str] = None - ) -> None: - """ - Add gateway routes to the FastAPI app. + # Case 2: WSGI services (like NoteReaderService) - only for services + if ( + component_type == "service" + and hasattr(component, "create_wsgi_app") + and callable(component.create_wsgi_app) + ): + self._register_wsgi_service( + component, component_name, endpoints_registry, path + ) + return - Args: - gateway: The gateway to add routes for - path: Optional override for the mount path - """ - gateway_name = gateway.__class__.__name__ - self.gateway_endpoints[gateway_name] = set() - - # Case 1: Gateways with get_routes implementation - if hasattr(gateway, "get_routes") and callable(gateway.get_routes): - routes = gateway.get_routes(path) - if routes: - for route_path, methods, handler, kwargs in routes: - for method in methods: - self.add_api_route( - path=route_path, - endpoint=handler, - methods=[method], - **kwargs, - ) - self.gateway_endpoints[gateway_name].add( - f"{method}:{route_path}" - ) - logger.debug( - f"Registered {method} route {route_path} for {gateway_name}" - ) - - # Case 2: WSGI gateways (like SOAP) - elif hasattr(gateway, "create_wsgi_app") and callable(gateway.create_wsgi_app): - # For SOAP/WSGI gateways - wsgi_app = gateway.create_wsgi_app() - - # Determine mount path - mount_path = path - if mount_path is None and hasattr(gateway, "config"): - # Try to get the default path from the gateway config - mount_path = getattr(gateway.config, "default_mount_path", None) - if not mount_path: - mount_path = getattr(gateway.config, "base_path", None) - - if not mount_path: - # Fallback path based on gateway name - mount_path = f"/{gateway_name.lower().replace('gateway', '')}" - - # Mount the WSGI app - self.mount(mount_path, WSGIMiddleware(wsgi_app)) - self.gateway_endpoints[gateway_name].add(f"WSGI:{mount_path}") - logger.debug(f"Registered WSGI gateway {gateway_name} at {mount_path}") - - # Case 3: Gateway instances that are also APIRouters (like FHIRGateway) - elif isinstance(gateway, APIRouter): - # Include the router - self.include_router(gateway) - if hasattr(gateway, "routes"): - for route in gateway.routes: - for method in route.methods: - self.gateway_endpoints[gateway_name].add( - f"{method}:{route.path}" - ) - logger.debug( - f"Registered {method} route {route.path} from {gateway_name} router" - ) - else: - logger.debug(f"Registered {gateway_name} as router (routes unknown)") + # Case 3: Unsupported patterns + if component_type == "gateway": + logger.warning( + f"Gateway {component_name} is not an APIRouter and cannot be registered" + ) + else: + logger.warning( + f"Service {component_name} does not implement APIRouter or WSGI patterns. " + f"Services must either inherit from APIRouter or implement create_wsgi_app()." + ) - elif not ( - hasattr(gateway, "get_routes") - and callable(gateway.get_routes) - and gateway.get_routes(path) - ): - logger.warning(f"Gateway {gateway_name} does not provide any routes") + def _register_api_router( + self, + router: APIRouter, + component_name: str, + endpoints_registry: Dict[str, set], + path: Optional[str] = None, + ) -> None: + """Register an APIRouter component.""" + mount_path = path or router.prefix + if path: + router.prefix = mount_path + + self.include_router(router) + + if hasattr(router, "routes"): + for route in router.routes: + for method in route.methods: + endpoint = f"{method}:{route.path}" + endpoints_registry[component_name].add(endpoint) + logger.debug( + f"Registered {method} route {route.path} from {component_name} router" + ) + else: + logger.debug(f"Registered {component_name} as router (routes unknown)") - def register_router( - self, router: Union[APIRouter, Type, str, list], **options + def _register_wsgi_service( + self, + service: BaseProtocolHandler, + service_name: str, + endpoints_registry: Dict[str, set], + path: Optional[str] = None, ) -> None: - """ - Register one or more routers with the API. + """Register a WSGI service.""" + # Create WSGI app + wsgi_app = service.create_wsgi_app() + + # Determine mount path with fallback chain + mount_path = ( + path + or getattr(service.config, "default_mount_path", None) + or getattr(service.config, "base_path", None) + or f"/{service_name.lower().replace('service', '').replace('gateway', '')}" + ) - Args: - router: The router(s) to register (can be an instance, class, import path, or list of any of these) - **options: Options to pass to the router constructor or include_router - """ - try: - # Handle list of routers - if isinstance(router, list): - for r in router: - self.register_router(r, **options) - return - - # Case 1: Direct APIRouter instance - if isinstance(router, APIRouter): - self.include_router(router, **options) - return - - # Case 2: Router class that needs instantiation - if inspect.isclass(router): - instance = router(**options) - if not isinstance(instance, APIRouter): - raise TypeError( - f"Expected APIRouter instance, got {type(instance)}" - ) - self.include_router(instance) - return - - # Case 3: Import path as string - if isinstance(router, str): - module_path, class_name = router.rsplit(".", 1) - module = importlib.import_module(module_path) - router_class = getattr(module, class_name) - instance = router_class(**options) - if not isinstance(instance, APIRouter): - raise TypeError( - f"Expected APIRouter instance, got {type(instance)}" - ) - self.include_router(instance) - return + # Mount the WSGI app + self.mount(mount_path, WSGIMiddleware(wsgi_app)) + endpoints_registry[service_name].add(f"WSGI:{mount_path}") + logger.debug(f"Registered WSGI service {service_name} at {mount_path}") - raise TypeError(f"Unsupported router type: {type(router)}") + def register_gateway( + self, + gateway: Union[Type[BaseGateway], BaseGateway], + path: Optional[str] = None, + use_events: Optional[bool] = None, + **options, + ) -> None: + """Register a gateway with the API and mount its endpoints.""" + self._register_component(gateway, "gateway", path, use_events, **options) - except Exception as e: - router_name = getattr(router, "__name__", str(router)) - logger.error(f"Failed to register router {router_name}: {str(e)}") - raise + def register_service( + self, + service: Union[Type[BaseProtocolHandler], BaseProtocolHandler], + path: Optional[str] = None, + use_events: Optional[bool] = None, + **options, + ) -> None: + """Register a service with the API and mount its endpoints.""" + self._register_component(service, "service", path, use_events, **options) def _add_default_routes(self) -> None: """Add default routes for the API.""" @@ -355,6 +332,7 @@ async def root(): "version": self.version, "description": self.description, "gateways": list(self.gateways.keys()), + "services": list(self.services.keys()), } @self.get("/health") @@ -365,16 +343,24 @@ async def health_check(): @self.get("/metadata") async def metadata(): """Provide capability statement for the API.""" - gateway_info = {} - for name, gateway in self.gateways.items(): - # Try to get metadata if available - if hasattr(gateway, "get_metadata") and callable(gateway.get_metadata): - gateway_info[name] = gateway.get_metadata() - else: - gateway_info[name] = { - "type": name, - "endpoints": list(self.gateway_endpoints.get(name, set())), - } + + def get_component_info(components, endpoints_registry): + """Helper function to get metadata for components.""" + info = {} + for name, component in components.items(): + if hasattr(component, "get_metadata") and callable( + component.get_metadata + ): + info[name] = component.get_metadata() + else: + info[name] = { + "type": name, + "endpoints": list(endpoints_registry.get(name, set())), + } + return info + + gateway_info = get_component_info(self.gateways, self.gateway_endpoints) + service_info = get_component_info(self.services, self.service_endpoints) return { "resourceType": "CapabilityStatement", @@ -390,115 +376,72 @@ async def metadata(): "url": "/", }, "gateways": gateway_info, + "services": service_info, } - async def _validation_exception_handler( - self, request: Request, exc: RequestValidationError - ) -> JSONResponse: - """Handle validation exceptions.""" - return JSONResponse( - status_code=422, - content={"detail": exc.errors(), "body": exc.body}, - ) - - async def _http_exception_handler( - self, request: Request, exc: HTTPException - ) -> JSONResponse: - """Handle HTTP exceptions.""" - return JSONResponse( - status_code=exc.status_code, - content={"detail": exc.detail}, - headers=exc.headers, - ) - - async def _general_exception_handler( + async def _exception_handler( self, request: Request, exc: Exception ) -> JSONResponse: - """Handle general exceptions.""" - logger.exception("Unhandled exception", exc_info=exc) - return JSONResponse( - status_code=500, - content={"detail": "Internal server error"}, - ) - - @asynccontextmanager - async def lifespan(self, app: FastAPI): - """Lifecycle manager for the application.""" - self._startup() - yield - self._shutdown() - - def _startup(self) -> None: - """Display startup information and log registered endpoints.""" - healthchain_ascii = r""" + """Unified exception handler for all types of exceptions.""" + if isinstance(exc, RequestValidationError): + return JSONResponse( + status_code=422, + content={"detail": exc.errors(), "body": exc.body}, + ) + elif isinstance(exc, HTTPException): + return JSONResponse( + status_code=exc.status_code, + content={"detail": exc.detail}, + headers=exc.headers, + ) + else: + logger.exception("Unhandled exception", exc_info=exc) + return JSONResponse( + status_code=500, + content={"detail": "Internal server error"}, + ) + async def _startup(self) -> None: + """Display startup information and initialize components.""" + # Display banner + banner = r""" __ __ ____ __ ________ _ / / / /__ ____ _/ / /_/ /_ / ____/ /_ ____ _(_)___ / /_/ / _ \/ __ `/ / __/ __ \/ / / __ \/ __ `/ / __ \ / __ / __/ /_/ / / /_/ / / / /___/ / / / /_/ / / / / / /_/ /_/\___/\__,_/_/\__/_/ /_/\____/_/ /_/\__,_/_/_/ /_/ - -""" # noqa: E501 - +""" colors = ["red", "yellow", "green", "cyan", "blue", "magenta"] - for i, line in enumerate(healthchain_ascii.split("\n")): - color = colors[i % len(colors)] - print(colored(line, color)) - - # Log registered gateways and endpoints - for name, gateway in self.gateways.items(): - endpoints = self.gateway_endpoints.get(name, set()) - for endpoint in endpoints: - print(f"{colored('HEALTHCHAIN', 'green')}: {endpoint}") - - print( - f"{colored('HEALTHCHAIN', 'green')}: See more details at {colored(self.docs_url, 'magenta')}" - ) - - def _shutdown(self): - """ - Shuts down server by sending a SIGTERM signal. - """ - os.kill(os.getpid(), signal.SIGTERM) - return JSONResponse(content={"message": "Server is shutting down..."}) - - -def create_app( - config: Optional[Dict] = None, - enable_events: bool = True, - event_dispatcher: Optional[EventDispatcher] = None, -) -> HealthChainAPI: - """ - Factory function to create a new HealthChainAPI application. - - This function provides a simple way to create a HealthChainAPI application - with standard middleware and basic configuration. It's useful for quickly - bootstrapping an application with sensible defaults. - - Args: - config: Optional configuration dictionary - enable_events: Whether to enable event dispatching functionality - event_dispatcher: Optional event dispatcher to use (for testing/DI) - - Returns: - Configured HealthChainAPI instance - """ - # Setup basic application config - app_config = { - "title": "HealthChain API", - "description": "Healthcare Integration API", - "version": "0.1.0", - "docs_url": "/docs", - "redoc_url": "/redoc", - "enable_events": enable_events, - "event_dispatcher": event_dispatcher, - } - - # Override with user config if provided - if config: - app_config.update(config) - - # Create application - app = HealthChainAPI(**app_config) - - return app + for i, line in enumerate(banner.split("\n")): + print(colored(line, colors[i % len(colors)])) + + # Log startup info + logger.info(f"πŸš€ Starting {self.title} v{self.version}") + logger.info(f"Gateways: {list(self.gateways.keys())}") + logger.info(f"Services: {list(self.services.keys())}") + + # Initialize components + for name, component in {**self.gateways, **self.services}.items(): + if hasattr(component, "startup") and callable(component.startup): + try: + await component.startup() + logger.debug(f"Initialized: {name}") + except Exception as e: + logger.warning(f"Failed to initialize {name}: {e}") + + logger.info(f"πŸ“– Docs: {self.docs_url}") + + async def _shutdown(self) -> None: + """Handle graceful shutdown.""" + logger.info("πŸ›‘ Shutting down...") + + # Shutdown all components + for name, component in {**self.services, **self.gateways}.items(): + if hasattr(component, "shutdown") and callable(component.shutdown): + try: + await component.shutdown() + logger.debug(f"Shutdown: {name}") + except Exception as e: + logger.warning(f"Failed to shutdown {name}: {e}") + + logger.info("βœ… Shutdown completed") diff --git a/healthchain/gateway/api/dependencies.py b/healthchain/gateway/api/dependencies.py index a123bf4f..b742c934 100644 --- a/healthchain/gateway/api/dependencies.py +++ b/healthchain/gateway/api/dependencies.py @@ -1,47 +1,35 @@ """ Dependency providers for HealthChainAPI. -This module contains FastAPI dependency injection providers that can be +This module contains dependency functions that can be used in route handlers to access HealthChainAPI components. """ -from typing import Dict, Optional, TypeVar, cast, Callable -from fastapi import Depends +from fastapi import Depends, HTTPException +from typing import Dict, Optional, Any from healthchain.gateway.api.protocols import ( HealthChainAPIProtocol, - GatewayProtocol, EventDispatcherProtocol, ) -# Type variable for type hinting -T = TypeVar("T", bound=GatewayProtocol) - -# Application instance dependency def get_app() -> HealthChainAPIProtocol: """Get the current HealthChainAPI application instance. - This is a dependency that returns the current application instance. - It should be overridden during application startup. + This is a placeholder that should be overridden by the actual + HealthChainAPI instance through dependency_overrides. Returns: The HealthChainAPI instance """ - raise RuntimeError( - "get_app dependency has not been overridden. " - "This usually happens when you try to use the dependency outside " - "of a request context or before the application has been initialized." - ) + raise RuntimeError("HealthChainAPI instance not available") def get_event_dispatcher( app: HealthChainAPIProtocol = Depends(get_app), ) -> Optional[EventDispatcherProtocol]: - """Get the event dispatcher from the app. - - This is a dependency that can be used in route handlers to access - the event dispatcher. + """Get the event dispatcher from the current application. Args: app: The HealthChainAPI instance @@ -54,29 +42,23 @@ def get_event_dispatcher( def get_gateway( gateway_name: str, app: HealthChainAPIProtocol = Depends(get_app) -) -> Optional[GatewayProtocol]: - """Get a specific gateway from the app. - - This is a dependency that can be used in route handlers to access - a specific gateway. +) -> Optional[Any]: + """Get a specific gateway by name. Args: gateway_name: The name of the gateway to retrieve app: The HealthChainAPI instance Returns: - The gateway or None if not found + The gateway instance or None if not found """ - return app.get_gateway(gateway_name) + return app.gateways.get(gateway_name) def get_all_gateways( app: HealthChainAPIProtocol = Depends(get_app), -) -> Dict[str, GatewayProtocol]: - """Get all registered gateways from the app. - - This is a dependency that can be used in route handlers to access - all gateways. +) -> Dict[str, Any]: + """Get all registered gateways. Args: app: The HealthChainAPI instance @@ -84,31 +66,79 @@ def get_all_gateways( Returns: Dictionary of all registered gateways """ - return app.get_all_gateways() + return app.gateways + + +def get_service( + service_name: str, app: HealthChainAPIProtocol = Depends(get_app) +) -> Optional[Any]: + """Get a specific service by name. + + Args: + service_name: The name of the service to retrieve + app: The HealthChainAPI instance + + Returns: + The service instance or None if not found + """ + return app.services.get(service_name) + + +def get_all_services( + app: HealthChainAPIProtocol = Depends(get_app), +) -> Dict[str, Any]: + """Get all registered services. + Args: + app: The HealthChainAPI instance + + Returns: + Dictionary of all registered services + """ + return app.services -def get_typed_gateway( - gateway_name: str, gateway_type: type[T] -) -> Callable[[], Optional[T]]: - """Create a dependency that returns a gateway of a specific type. - This creates a dependency that returns a gateway cast to a specific type, - which is useful when you need a specific gateway protocol. +def get_gateway_by_name(gateway_name: str): + """Dependency factory for getting a specific gateway by name. Args: - gateway_name: Name of the gateway to retrieve - gateway_type: The expected gateway type/protocol + gateway_name: The name of the gateway to retrieve Returns: - A dependency function that returns the typed gateway + A dependency function that returns the gateway """ - def _get_typed_gateway( + def _get_gateway_dependency( app: HealthChainAPIProtocol = Depends(get_app), - ) -> Optional[T]: # type: ignore - gateway = app.get_gateway(gateway_name) + ) -> Any: + gateway = app.gateways.get(gateway_name) if gateway is None: - return None - return cast(T, gateway) + raise HTTPException( + status_code=404, detail=f"Gateway '{gateway_name}' not found" + ) + return gateway + + return _get_gateway_dependency + - return _get_typed_gateway +def get_service_by_name(service_name: str): + """Dependency factory for getting a specific service by name. + + Args: + service_name: The name of the service to retrieve + + Returns: + A dependency function that returns the service + """ + + def _get_service_dependency( + app: HealthChainAPIProtocol = Depends(get_app), + ) -> Any: + service = app.services.get(service_name) + if service is None: + raise HTTPException( + status_code=404, detail=f"Service '{service_name}' not found" + ) + return service + + return _get_service_dependency diff --git a/healthchain/gateway/api/protocols.py b/healthchain/gateway/api/protocols.py index 7ac44017..fdbfcd0c 100644 --- a/healthchain/gateway/api/protocols.py +++ b/healthchain/gateway/api/protocols.py @@ -6,9 +6,22 @@ typing and better type checking. """ -from typing import Dict, Optional, Set, Any, Protocol, Callable, Union +from typing import ( + Dict, + Optional, + Set, + Any, + Protocol, + Callable, + Union, + Type, + TYPE_CHECKING, +) -from healthchain.gateway.events.dispatcher import EHREvent +from healthchain.gateway.events.dispatcher import EHREvent, EHREventType + +if TYPE_CHECKING: + from fastapi import FastAPI, APIRouter class EventDispatcherProtocol(Protocol): @@ -16,103 +29,39 @@ class EventDispatcherProtocol(Protocol): async def publish( self, event: EHREvent, middleware_id: Optional[int] = None - ) -> bool: + ) -> None: """Dispatch an event to registered handlers. Args: event: The event to publish middleware_id: Optional middleware ID - - Returns: - True if the event was successfully dispatched - """ - ... - - def init_app(self, app: Any) -> None: - """Initialize the dispatcher with an application. - - Args: - app: Application instance to initialize with - """ - ... - - def register_handler(self, event_name: str, handler: Callable) -> None: - """Register a handler for a specific event. - - Args: - event_name: The name of the event to handle - handler: The handler function - """ - ... - - -class GatewayProtocol(Protocol): - """Protocol defining the interface for gateways.""" - - def get_metadata(self) -> Dict[str, Any]: - """Get metadata about the gateway. - - Returns: - Dictionary with gateway metadata - """ - ... - - def set_event_dispatcher(self, dispatcher: EventDispatcherProtocol) -> None: - """Set the event dispatcher for this gateway. - - Args: - dispatcher: The event dispatcher to use """ ... - -class FHIRGatewayProtocol(GatewayProtocol, Protocol): - """Protocol defining the interface for FHIR gateways.""" - - async def search( - self, resource_type: str, params: Dict[str, Any] - ) -> Dict[str, Any]: - """Search for FHIR resources. + def init_app(self, app: "FastAPI") -> None: + """Initialize the dispatcher with a FastAPI application. Args: - resource_type: The FHIR resource type - params: Search parameters - - Returns: - FHIR Bundle containing search results + app: FastAPI application instance to initialize with """ ... - async def read(self, resource_type: str, resource_id: str) -> Dict[str, Any]: - """Read a FHIR resource. + def register_handler(self, event_type: EHREventType) -> Callable: + """Register a handler for a specific event type. Args: - resource_type: The FHIR resource type - resource_id: The resource ID + event_type: The EHR event type to handle Returns: - FHIR resource + Decorator function for registering handlers """ ... - -class SOAPGatewayProtocol(GatewayProtocol, Protocol): - """Protocol defining the interface for SOAP gateways.""" - - def create_wsgi_app(self) -> Any: - """Create a WSGI application for the SOAP service. + def register_default_handler(self) -> Callable: + """Register a handler for all events. Returns: - WSGI application - """ - ... - - def register_method(self, method_name: str, handler: Callable) -> None: - """Register a method handler for the SOAP service. - - Args: - method_name: The SOAP method name - handler: The handler function + Decorator function for registering handlers """ ... @@ -120,8 +69,10 @@ def register_method(self, method_name: str, handler: Callable) -> None: class HealthChainAPIProtocol(Protocol): """Protocol defining the interface for the HealthChainAPI.""" - gateways: Dict[str, GatewayProtocol] + gateways: Dict[str, Any] + services: Dict[str, Any] gateway_endpoints: Dict[str, Set[str]] + service_endpoints: Dict[str, Set[str]] enable_events: bool event_dispatcher: Optional[EventDispatcherProtocol] @@ -133,47 +84,107 @@ def get_event_dispatcher(self) -> Optional[EventDispatcherProtocol]: """ ... - def get_gateway(self, gateway_name: str) -> Optional[GatewayProtocol]: - """Get a gateway by name. + def register_gateway( + self, + gateway: Union[Type[Any], Any], + path: Optional[str] = None, + use_events: Optional[bool] = None, + **options, + ) -> None: + """Register a gateway. Args: - gateway_name: The name of the gateway - - Returns: - The gateway or None if not found - """ - ... - - def get_all_gateways(self) -> Dict[str, GatewayProtocol]: - """Get all registered gateways. - - Returns: - Dictionary of all registered gateways + gateway: The gateway to register (class or instance) + path: Optional mount path + use_events: Whether to use events + **options: Additional options """ ... - def register_gateway( + def register_service( self, - gateway: Union[GatewayProtocol, Any], + service: Union[Type[Any], Any], path: Optional[str] = None, use_events: Optional[bool] = None, **options, ) -> None: - """Register a gateway. + """Register a service. Args: - gateway: The gateway to register + service: The service to register (class or instance) path: Optional mount path use_events: Whether to use events **options: Additional options """ ... - def register_router(self, router: Any, **options) -> None: + def register_router(self, router: "APIRouter", **options) -> None: """Register a router. Args: - router: The router to register + router: The APIRouter instance to register **options: Additional options """ ... + + +# Protocols below are primarily used for testing + + +class FHIRConnectionManagerProtocol(Protocol): + """Protocol for FHIR connection management.""" + + def add_source(self, name: str, connection_string: str) -> None: + """Add a FHIR data source.""" + ... + + async def get_client(self, source: str = None) -> "FHIRServerInterfaceProtocol": + """Get a FHIR client for the specified source.""" + ... + + def get_pool_status(self) -> Dict[str, Any]: + """Get connection pool status.""" + ... + + async def close(self) -> None: + """Close all connections.""" + ... + + @property + def sources(self) -> Dict[str, Any]: + """Get registered sources.""" + ... + + +class FHIRServerInterfaceProtocol(Protocol): + """Protocol for FHIR server interface.""" + + async def read(self, resource_type: Type[Any], resource_id: str) -> Any: + """Read a FHIR resource.""" + ... + + async def search( + self, resource_type: Type[Any], params: Dict[str, Any] = None + ) -> Any: + """Search for FHIR resources.""" + ... + + async def create(self, resource: Any) -> Any: + """Create a FHIR resource.""" + ... + + async def update(self, resource: Any) -> Any: + """Update a FHIR resource.""" + ... + + async def delete(self, resource_type: Type[Any], resource_id: str) -> bool: + """Delete a FHIR resource.""" + ... + + async def transaction(self, bundle: Any) -> Any: + """Execute a transaction bundle.""" + ... + + async def capabilities(self) -> Any: + """Get server capabilities.""" + ... diff --git a/healthchain/gateway/clients/__init__.py b/healthchain/gateway/clients/__init__.py new file mode 100644 index 00000000..f1d7ace3 --- /dev/null +++ b/healthchain/gateway/clients/__init__.py @@ -0,0 +1,13 @@ +from .fhir import FHIRServerInterface, AsyncFHIRClient, create_fhir_client +from .auth import OAuth2TokenManager, FHIRAuthConfig, parse_fhir_auth_connection_string +from .pool import FHIRClientPool + +__all__ = [ + "FHIRServerInterface", + "AsyncFHIRClient", + "create_fhir_client", + "OAuth2TokenManager", + "FHIRAuthConfig", + "parse_fhir_auth_connection_string", + "FHIRClientPool", +] diff --git a/healthchain/gateway/clients/auth.py b/healthchain/gateway/clients/auth.py new file mode 100644 index 00000000..ba0eae5f --- /dev/null +++ b/healthchain/gateway/clients/auth.py @@ -0,0 +1,465 @@ +""" +OAuth2 authentication manager for FHIR clients. + +This module provides OAuth2 client credentials flow for automatic token +management and refresh. +""" + +import logging +import os +import uuid +import asyncio +import httpx + +from typing import Dict, Optional, Any +from datetime import datetime, timedelta, timezone +from pydantic import BaseModel + + +logger = logging.getLogger(__name__) + + +class OAuth2Config(BaseModel): + """OAuth2 configuration for client credentials flow.""" + + client_id: str + token_url: str + client_secret: Optional[str] = None # Client secret string for standard flow + client_secret_path: Optional[str] = ( + None # Path to private key file for JWT assertion + ) + scope: Optional[str] = None + audience: Optional[str] = None # For Epic and other systems that require audience + use_jwt_assertion: bool = False # Use JWT client assertion instead of client secret + + def model_post_init(self, __context) -> None: + """Validate that exactly one of client_secret or client_secret_path is provided.""" + if not self.client_secret and not self.client_secret_path: + raise ValueError( + "Either client_secret or client_secret_path must be provided" + ) + + if self.client_secret and self.client_secret_path: + raise ValueError("Cannot provide both client_secret and client_secret_path") + + if self.use_jwt_assertion and not self.client_secret_path: + raise ValueError( + "use_jwt_assertion=True requires client_secret_path to be set" + ) + + if not self.use_jwt_assertion and self.client_secret_path: + raise ValueError( + "client_secret_path can only be used with use_jwt_assertion=True" + ) + + @property + def secret_value(self) -> str: + """Get the secret value, reading from file if necessary.""" + if self.client_secret_path: + try: + with open(self.client_secret_path, "rb") as f: + return f.read().decode("utf-8") + except Exception as e: + raise ValueError( + f"Failed to read secret from {self.client_secret_path}: {e}" + ) + return self.client_secret + + +class TokenInfo(BaseModel): + """Token information with expiry tracking.""" + + access_token: str + token_type: str = "Bearer" + expires_in: int + scope: Optional[str] = None + expires_at: datetime + + @classmethod + def from_response(cls, response_data: Dict[str, Any]) -> "TokenInfo": + """Create TokenInfo from OAuth2.0 token response.""" + expires_at = datetime.now() + timedelta( + seconds=response_data.get("expires_in", 3600) + ) + + return cls( + access_token=response_data["access_token"], + token_type=response_data.get("token_type", "Bearer"), + expires_in=response_data.get("expires_in", 3600), + scope=response_data.get("scope"), + expires_at=expires_at, + ) + + def is_expired(self, buffer_seconds: int = 300) -> bool: + """Check if token is expired or will expire within buffer time.""" + return datetime.now() + timedelta(seconds=buffer_seconds) >= self.expires_at + + +class OAuth2TokenManager: + """ + Manages OAuth2.0 tokens with automatic refresh for FHIR clients. + + Supports client credentials flow commonly used in healthcare integrations. + """ + + def __init__(self, config: OAuth2Config, refresh_buffer_seconds: int = 300): + """ + Initialize OAuth2 token manager. + + Args: + config: OAuth2 configuration + refresh_buffer_seconds: Refresh token this many seconds before expiry + """ + self.config = config + self.refresh_buffer_seconds = refresh_buffer_seconds + self._token: Optional[TokenInfo] = None + self._refresh_lock: Optional[asyncio.Lock] = None + + def _get_refresh_lock(self) -> asyncio.Lock: + """Get or create the refresh lock when an event loop is running.""" + if self._refresh_lock is None: + # Only create the lock when we have a running event loop + # This ensures Python 3.9 compatibility + self._refresh_lock = asyncio.Lock() + return self._refresh_lock + + async def get_access_token(self) -> str: + """ + Get a valid access token, refreshing if necessary. + + Returns: + Valid Bearer access token + """ + async with self._get_refresh_lock(): + if self._token is None or self._token.is_expired( + self.refresh_buffer_seconds + ): + await self._refresh_token() + + return self._token.access_token + + async def _refresh_token(self): + """Refresh the access token using client credentials flow.""" + logger.debug(f"Refreshing token from {self.config.token_url}") + + # Check if client_secret is a private key path or JWT assertion is enabled + if self.config.use_jwt_assertion or self.config.client_secret_path: + # Use JWT client assertion flow (Epic/SMART on FHIR style) + jwt_assertion = self._create_jwt_assertion() + token_data = { + "grant_type": "client_credentials", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_assertion": jwt_assertion, + } + else: + # Standard client credentials flow + token_data = { + "grant_type": "client_credentials", + "client_id": self.config.client_id, + "client_secret": self.config.secret_value, + } + + if self.config.scope: + token_data["scope"] = self.config.scope + + if self.config.audience: + token_data["audience"] = self.config.audience + + # Make token request + async with httpx.AsyncClient() as client: + try: + response = await client.post( + self.config.token_url, + data=token_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30, + ) + response.raise_for_status() + + response_data = response.json() + self._token = TokenInfo.from_response(response_data) + + logger.debug( + f"Token refreshed successfully, expires at {self._token.expires_at}" + ) + + except httpx.HTTPStatusError as e: + logger.error( + f"Token refresh failed: {e.response.status_code} {e.response.text}" + ) + raise Exception(f"Failed to refresh token: {e.response.status_code}") + except Exception as e: + logger.error(f"Token refresh error: {str(e)}") + raise + + def invalidate_token(self): + """Invalidate the current token to force refresh on next request.""" + self._token = None + + def _create_jwt_assertion(self) -> str: + """Create JWT client assertion for SMART on FHIR authentication.""" + from jwt import JWT, jwk_from_pem + + # Generate unique JTI + jti = str(uuid.uuid4()) + + # Load private key (client_secret should be path to private key for JWT assertion) + try: + with open(self.config.client_secret_path, "rb") as f: + private_key_data = f.read() + key = jwk_from_pem(private_key_data) + except Exception as e: + raise Exception( + f"Failed to load private key from {os.path.basename(self.config.client_secret_path)}: {e}" + ) + + # Create JWT claims matching the script + now = datetime.now(timezone.utc) + claims = { + "iss": self.config.client_id, # Issuer (client ID) + "sub": self.config.client_id, # Subject (client ID) + "aud": self.config.token_url, # Audience (token endpoint) + "jti": jti, # Unique token identifier + "iat": int(now.timestamp()), # Issued at + "exp": int( + (now + timedelta(minutes=5)).timestamp() + ), # Expires in 5 minutes + } + + # Create and sign JWT + signed_jwt = JWT().encode(claims, key, alg="RS384") + + return signed_jwt + + +class FHIRAuthConfig(BaseModel): + """Configuration for FHIR server authentication.""" + + # OAuth2 settings + client_id: str + client_secret: Optional[str] = None # Client secret string for standard flow + client_secret_path: Optional[str] = ( + None # Path to private key file for JWT assertion + ) + token_url: str + scope: Optional[str] = "system/*.read system/*.write" + audience: Optional[str] = None + use_jwt_assertion: bool = False # Use JWT client assertion (Epic/SMART style) + + # Connection settings + base_url: str + timeout: int = 30 + verify_ssl: bool = True + + def model_post_init(self, __context) -> None: + """Validate that exactly one of client_secret or client_secret_path is provided.""" + if not self.client_secret and not self.client_secret_path: + raise ValueError( + "Either client_secret or client_secret_path must be provided" + ) + + if self.client_secret and self.client_secret_path: + raise ValueError("Cannot provide both client_secret and client_secret_path") + + if self.use_jwt_assertion and not self.client_secret_path: + raise ValueError( + "use_jwt_assertion=True requires client_secret_path to be set" + ) + + if not self.use_jwt_assertion and self.client_secret_path: + raise ValueError( + "client_secret_path can only be used with use_jwt_assertion=True" + ) + + def to_oauth2_config(self) -> OAuth2Config: + """Convert to OAuth2Config for token manager.""" + return OAuth2Config( + client_id=self.client_id, + client_secret=self.client_secret, + client_secret_path=self.client_secret_path, + token_url=self.token_url, + scope=self.scope, + audience=self.audience, + use_jwt_assertion=self.use_jwt_assertion, + ) + + @classmethod + def from_env(cls, env_prefix: str) -> "FHIRAuthConfig": + """ + Create FHIRAuthConfig from environment variables. + + Args: + env_prefix: Environment variable prefix (e.g., "EPIC") + + Expected environment variables: + {env_prefix}_CLIENT_ID + {env_prefix}_CLIENT_SECRET (or {env_prefix}_CLIENT_SECRET_PATH) + {env_prefix}_TOKEN_URL + {env_prefix}_BASE_URL + {env_prefix}_SCOPE (optional) + {env_prefix}_AUDIENCE (optional) + {env_prefix}_TIMEOUT (optional, default: 30) + {env_prefix}_VERIFY_SSL (optional, default: true) + {env_prefix}_USE_JWT_ASSERTION (optional, default: false) + + Returns: + FHIRAuthConfig instance + + Example: + # Set environment variables: + # EPIC_CLIENT_ID=app123 + # EPIC_CLIENT_SECRET=secret456 + # EPIC_TOKEN_URL=https://epic.com/oauth2/token + # EPIC_BASE_URL=https://epic.com/api/FHIR/R4 + + config = FHIRAuthConfig.from_env("EPIC") + """ + import os + + # Read required environment variables + client_id = os.getenv(f"{env_prefix}_CLIENT_ID") + client_secret = os.getenv(f"{env_prefix}_CLIENT_SECRET") + client_secret_path = os.getenv(f"{env_prefix}_CLIENT_SECRET_PATH") + token_url = os.getenv(f"{env_prefix}_TOKEN_URL") + base_url = os.getenv(f"{env_prefix}_BASE_URL") + + if not all([client_id, token_url, base_url]): + missing = [ + var + for var, val in [ + (f"{env_prefix}_CLIENT_ID", client_id), + (f"{env_prefix}_TOKEN_URL", token_url), + (f"{env_prefix}_BASE_URL", base_url), + ] + if not val + ] + raise ValueError(f"Missing required environment variables: {missing}") + + # Read optional environment variables + scope = os.getenv(f"{env_prefix}_SCOPE", "system/*.read system/*.write") + audience = os.getenv(f"{env_prefix}_AUDIENCE") + timeout = int(os.getenv(f"{env_prefix}_TIMEOUT", "30")) + verify_ssl = os.getenv(f"{env_prefix}_VERIFY_SSL", "true").lower() == "true" + use_jwt_assertion = ( + os.getenv(f"{env_prefix}_USE_JWT_ASSERTION", "false").lower() == "true" + ) + + return cls( + client_id=client_id, + client_secret=client_secret, + client_secret_path=client_secret_path, + token_url=token_url, + base_url=base_url, + scope=scope, + audience=audience, + timeout=timeout, + verify_ssl=verify_ssl, + use_jwt_assertion=use_jwt_assertion, + ) + + def to_connection_string(self) -> str: + """ + Convert FHIRAuthConfig to connection string format. + + Returns: + Connection string in fhir:// format + + Example: + config = FHIRAuthConfig(...) + connection_string = config.to_connection_string() + # Returns: "fhir://hostname/path?client_id=...&token_url=..." + """ + # Extract hostname and path from base_url + import urllib.parse + + parsed_base = urllib.parse.urlparse(self.base_url) + + # Build query parameters + params = { + "client_id": self.client_id, + "token_url": self.token_url, + } + + # Add secret (either client_secret or client_secret_path) + if self.client_secret: + params["client_secret"] = self.client_secret + elif self.client_secret_path: + params["client_secret_path"] = self.client_secret_path + + # Add optional parameters + if self.scope: + params["scope"] = self.scope + if self.audience: + params["audience"] = self.audience + if self.timeout != 30: + params["timeout"] = str(self.timeout) + if not self.verify_ssl: + params["verify_ssl"] = "false" + if self.use_jwt_assertion: + params["use_jwt_assertion"] = "true" + + # Build connection string + query_string = urllib.parse.urlencode(params) + return f"fhir://{parsed_base.netloc}{parsed_base.path}?{query_string}" + + +def parse_fhir_auth_connection_string(connection_string: str) -> FHIRAuthConfig: + """ + Parse a FHIR connection string into authentication configuration. + + Format: fhir://hostname:port/path?client_id=xxx&client_secret=xxx&token_url=xxx&scope=xxx + Or for JWT: fhir://hostname:port/path?client_id=xxx&client_secret_path=xxx&token_url=xxx&use_jwt_assertion=true + + Args: + connection_string: FHIR connection string with OAuth2 credentials + + Returns: + FHIRAuthConfig with parsed settings + + Raises: + ValueError: If connection string is invalid or missing required parameters + """ + import urllib.parse + + if not connection_string.startswith("fhir://"): + raise ValueError("Connection string must start with fhir://") + + parsed = urllib.parse.urlparse(connection_string) + params = dict(urllib.parse.parse_qsl(parsed.query)) + + # Validate required parameters + required_params = ["client_id", "token_url"] + missing_params = [param for param in required_params if param not in params] + + if missing_params: + raise ValueError(f"Missing required parameters: {missing_params}") + + # Check that exactly one of client_secret or client_secret_path is provided + has_secret = "client_secret" in params + has_secret_path = "client_secret_path" in params + + if not has_secret and not has_secret_path: + raise ValueError( + "Either 'client_secret' or 'client_secret_path' parameter must be provided" + ) + + if has_secret and has_secret_path: + raise ValueError( + "Cannot provide both 'client_secret' and 'client_secret_path' parameters" + ) + + # Build base URL + base_url = f"https://{parsed.netloc}{parsed.path}" + + return FHIRAuthConfig( + client_id=params["client_id"], + client_secret=params.get("client_secret"), + client_secret_path=params.get("client_secret_path"), + token_url=params["token_url"], + scope=params.get("scope", "system/*.read system/*.write"), + audience=params.get("audience"), + base_url=base_url, + timeout=int(params.get("timeout", 30)), + verify_ssl=params.get("verify_ssl", "true").lower() == "true", + use_jwt_assertion=params.get("use_jwt_assertion", "false").lower() == "true", + ) diff --git a/healthchain/gateway/clients/fhir.py b/healthchain/gateway/clients/fhir.py new file mode 100644 index 00000000..2ca6513d --- /dev/null +++ b/healthchain/gateway/clients/fhir.py @@ -0,0 +1,389 @@ +""" +FHIR client interfaces and implementations. + +This module provides standardized interfaces for different FHIR client libraries. +""" + +import logging +import json +import httpx + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, Union, Type +from urllib.parse import urljoin, urlencode +from functools import lru_cache + +from fhir.resources.resource import Resource +from fhir.resources.bundle import Bundle +from fhir.resources.capabilitystatement import CapabilityStatement + +from healthchain.gateway.clients.auth import OAuth2TokenManager, FHIRAuthConfig + + +logger = logging.getLogger(__name__) + + +def create_fhir_client( + auth_config: FHIRAuthConfig, + limits: httpx.Limits = None, + **additional_params, +) -> "FHIRServerInterface": + """ + Factory function to create and configure a FHIR server interface using OAuth2.0 + + Args: + auth_config: OAuth2.0 authentication configuration + limits: httpx connection limits for pooling + **additional_params: Additional parameters for the client + + Returns: + A configured FHIRServerInterface implementation + """ + logger.debug(f"Creating FHIR server with OAuth2.0 for {auth_config.base_url}") + return AsyncFHIRClient(auth_config=auth_config, limits=limits, **additional_params) + + +class FHIRClientError(Exception): + """Base exception for FHIR client errors.""" + + def __init__( + self, message: str, status_code: int = None, response_data: dict = None + ): + self.status_code = status_code + self.response_data = response_data + super().__init__(message) + + +class FHIRServerInterface(ABC): + """ + Interface for FHIR servers. + + Provides a standardized interface for interacting with FHIR servers + using different client libraries. + """ + + @abstractmethod + async def read( + self, resource_type: Union[str, Type[Resource]], resource_id: str + ) -> Resource: + """Read a specific resource by ID.""" + pass + + @abstractmethod + async def create(self, resource: Resource) -> Resource: + """Create a new resource.""" + pass + + @abstractmethod + async def update(self, resource: Resource) -> Resource: + """Update an existing resource.""" + pass + + @abstractmethod + async def delete( + self, resource_type: Union[str, Type[Resource]], resource_id: str + ) -> bool: + """Delete a resource.""" + pass + + @abstractmethod + async def search( + self, + resource_type: Union[str, Type[Resource]], + params: Optional[Dict[str, Any]] = None, + ) -> Bundle: + """Search for resources.""" + pass + + @abstractmethod + async def transaction(self, bundle: Bundle) -> Bundle: + """Execute a transaction bundle.""" + pass + + @abstractmethod + async def capabilities(self) -> CapabilityStatement: + """Get the capabilities of the FHIR server.""" + pass + + +class AsyncFHIRClient(FHIRServerInterface): + """ + Async FHIR client optimized for HealthChain gateway use cases. + + - Uses fhir.resources for validation + - Supports JWT Bearer token authentication + - Async-first with httpx + """ + + def __init__( + self, + auth_config: FHIRAuthConfig, + limits: httpx.Limits = None, + **kwargs, + ): + """ + Initialize the FHIR client with OAuth2.0 authentication. + + Args: + auth_config: OAuth2.0 authentication configuration + limits: httpx connection limits for pooling + **kwargs: Additional parameters passed to httpx.AsyncClient + """ + self.base_url = auth_config.base_url.rstrip("/") + "/" + self.timeout = auth_config.timeout + self.verify_ssl = auth_config.verify_ssl + self.token_manager = OAuth2TokenManager(auth_config.to_oauth2_config()) + + # Setup base headers + self.base_headers = { + "Accept": "application/fhir+json", + "Content-Type": "application/fhir+json", + } + + # Create httpx client with connection pooling and additional kwargs + client_kwargs = {"timeout": self.timeout, "verify": self.verify_ssl} + if limits is not None: + client_kwargs["limits"] = limits + + # Pass through additional kwargs to httpx.AsyncClient + client_kwargs.update(kwargs) + + self.client = httpx.AsyncClient(**client_kwargs) + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() + + async def close(self): + """Close the HTTP client.""" + await self.client.aclose() + + async def _get_headers(self) -> Dict[str, str]: + """Get headers with fresh OAuth2.0 token.""" + headers = self.base_headers.copy() + token = await self.token_manager.get_access_token() + headers["Authorization"] = f"Bearer {token}" + return headers + + def _build_url(self, path: str, params: Dict[str, Any] = None) -> str: + """Build a complete URL with optional query parameters.""" + url = urljoin(self.base_url, path) + if params: + # Filter out None values and convert to strings + clean_params = {k: str(v) for k, v in params.items() if v is not None} + if clean_params: + url += "?" + urlencode(clean_params) + return url + + def _handle_response(self, response: httpx.Response) -> dict: + """Handle HTTP response and convert to dict.""" + try: + data = response.json() + except json.JSONDecodeError: + raise FHIRClientError( + f"Invalid JSON response: {response.text}", + status_code=response.status_code, + ) + + if not response.is_success: + error_msg = f"FHIR request failed: {response.status_code}" + if isinstance(data, dict) and "issue" in data: + # FHIR OperationOutcome format + issues = data.get("issue", []) + if issues: + error_msg += f" - {issues[0].get('diagnostics', 'Unknown error')}" + + raise FHIRClientError( + error_msg, status_code=response.status_code, response_data=data + ) + + return data + + @lru_cache(maxsize=128) + def _resolve_resource_type( + self, resource_type: Union[str, Type[Resource]] + ) -> tuple[str, Type[Resource]]: + """ + Resolve FHIR resource type to string name and class. Cached with LRU. + + Args: + resource_type: FHIR resource type or class + + Returns: + Tuple of (type_name: str, resource_class: Type[Resource]) + """ + if hasattr(resource_type, "__name__"): + # It's already a class + type_name = resource_type.__name__ + resource_class = resource_type + else: + # It's a string, need to dynamically import + type_name = str(resource_type) + module_name = f"fhir.resources.{type_name.lower()}" + module = __import__(module_name, fromlist=[type_name]) + resource_class = getattr(module, type_name) + + return type_name, resource_class + + async def capabilities(self) -> CapabilityStatement: + """ + Fetch the server's CapabilityStatement. + + Returns: + CapabilityStatement resource + """ + headers = await self._get_headers() + response = await self.client.get(self._build_url("metadata"), headers=headers) + data = self._handle_response(response) + return CapabilityStatement(**data) + + async def read( + self, resource_type: Union[str, Type[Resource]], resource_id: str + ) -> Resource: + """ + Read a specific resource by ID. + + Args: + resource_type: FHIR resource type or class + resource_id: Resource ID + + Returns: + Resource instance + """ + type_name, resource_class = self._resolve_resource_type(resource_type) + url = self._build_url(f"{type_name}/{resource_id}") + logger.debug(f"Sending GET request to {url}") + + headers = await self._get_headers() + response = await self.client.get(url, headers=headers) + data = self._handle_response(response) + + return resource_class(**data) + + async def search( + self, resource_type: Union[str, Type[Resource]], params: Dict[str, Any] = None + ) -> Bundle: + """ + Search for resources. + + Args: + resource_type: FHIR resource type or class + params: Search parameters + + Returns: + Bundle containing search results + """ + type_name, _ = self._resolve_resource_type(resource_type) + url = self._build_url(type_name, params) + logger.debug(f"Sending GET request to {url}") + + headers = await self._get_headers() + response = await self.client.get(url, headers=headers) + data = self._handle_response(response) + + return Bundle(**data) + + async def create(self, resource: Resource) -> Resource: + """ + Create a new resource. + + Args: + resource: Resource to create + + Returns: + Created resource with server-assigned ID + """ + type_name, resource_class = self._resolve_resource_type( + resource.__resource_type__ + ) + url = self._build_url(type_name) + logger.debug(f"Sending POST request to {url}") + + headers = await self._get_headers() + response = await self.client.post( + url, content=resource.model_dump_json(), headers=headers + ) + data = self._handle_response(response) + + # Return the same resource type + return resource_class(**data) + + async def update(self, resource: Resource) -> Resource: + """ + Update an existing resource. + + Args: + resource: Resource to update (must have ID) + + Returns: + Updated resource + """ + if not resource.id: + raise ValueError("Resource must have an ID for update") + + type_name, resource_class = self._resolve_resource_type( + resource.__resource_type__ + ) + url = self._build_url(f"{type_name}/{resource.id}") + logger.debug(f"Sending PUT request to {url}") + + headers = await self._get_headers() + response = await self.client.put( + url, content=resource.model_dump_json(), headers=headers + ) + data = self._handle_response(response) + + # Return the same resource type + return resource_class(**data) + + async def delete( + self, resource_type: Union[str, Type[Resource]], resource_id: str + ) -> bool: + """ + Delete a resource. + + Args: + resource_type: FHIR resource type or class + resource_id: Resource ID to delete + + Returns: + True if successful + """ + type_name, _ = self._resolve_resource_type(resource_type) + url = self._build_url(f"{type_name}/{resource_id}") + logger.debug(f"Sending DELETE request to {url}") + + headers = await self._get_headers() + response = await self.client.delete(url, headers=headers) + + # Delete operations typically return 204 No Content + if response.status_code in (200, 204): + return True + + self._handle_response(response) # This will raise an error + return False + + async def transaction(self, bundle: Bundle) -> Bundle: + """ + Execute a transaction bundle. + + Args: + bundle: Transaction bundle + + Returns: + Response bundle + """ + url = self._build_url("") # Base URL for transaction + logger.debug(f"Sending POST request to {url}") + + headers = await self._get_headers() + response = await self.client.post( + url, content=bundle.model_dump_json(), headers=headers + ) + data = self._handle_response(response) + + return Bundle(**data) diff --git a/healthchain/gateway/clients/pool.py b/healthchain/gateway/clients/pool.py new file mode 100644 index 00000000..ae9da57d --- /dev/null +++ b/healthchain/gateway/clients/pool.py @@ -0,0 +1,89 @@ +import httpx + +from typing import Any, Callable, Dict +from healthchain.gateway.clients import FHIRServerInterface + + +class FHIRClientPool: + """ + Manages FHIR client instances with connection pooling using httpx. + Handles connection lifecycle, timeouts, and resource cleanup. + """ + + def __init__( + self, + max_connections: int = 100, + max_keepalive_connections: int = 20, + keepalive_expiry: float = 5.0, + ): + """ + Initialize the FHIR client pool. + + Args: + max_connections: Maximum number of total connections + max_keepalive_connections: Maximum number of keep-alive connections + keepalive_expiry: How long to keep connections alive (seconds) + """ + self._clients: Dict[str, FHIRServerInterface] = {} + self._client_limits = httpx.Limits( + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + ) + + async def get_client( + self, connection_string: str, client_factory: Callable + ) -> FHIRServerInterface: + """ + Get a FHIR client for the given connection string. + + Args: + connection_string: FHIR connection string + client_factory: Factory function to create new clients + + Returns: + FHIRServerInterface: A FHIR client with pooled connections + """ + if connection_string not in self._clients: + # Create new client with connection pooling + self._clients[connection_string] = client_factory( + connection_string, limits=self._client_limits + ) + + return self._clients[connection_string] + + async def close_all(self): + """Close all client connections.""" + for client in self._clients.values(): + if hasattr(client, "close"): + await client.close() + self._clients.clear() + + def get_pool_stats(self) -> Dict[str, Any]: + """Get connection pool statistics.""" + stats = { + "total_clients": len(self._clients), + "limits": { + "max_connections": self._client_limits.max_connections, + "max_keepalive_connections": self._client_limits.max_keepalive_connections, + "keepalive_expiry": self._client_limits.keepalive_expiry, + }, + "clients": {}, + } + + for conn_str, client in self._clients.items(): + # Try to get httpx client stats if available + client_stats = {} + if hasattr(client, "client") and hasattr(client.client, "_pool"): + pool = client.client._pool + client_stats.update( + { + "active_connections": len(pool._pool), + "available_connections": len( + [c for c in pool._pool if c.is_available()] + ), + } + ) + stats["clients"][conn_str] = client_stats + + return stats diff --git a/healthchain/gateway/core/__init__.py b/healthchain/gateway/core/__init__.py index 90e5d606..e8dab522 100644 --- a/healthchain/gateway/core/__init__.py +++ b/healthchain/gateway/core/__init__.py @@ -5,16 +5,34 @@ that define the gateway architecture. """ -from .base import BaseGateway, GatewayConfig +from .base import BaseGateway, GatewayConfig, EventCapability +from .connection import FHIRConnectionManager +from .errors import FHIRErrorHandler, FHIRConnectionError +from .fhirgateway import FHIRGateway # Import these if available, but don't error if they're not try: __all__ = [ "BaseGateway", "GatewayConfig", + "EventCapability", + "FHIRConnectionManager", + "FHIRErrorHandler", + "FHIRConnectionError", + "FHIRGateway", + "EHREvent", + "SOAPEvent", + "EHREventType", + "RequestModel", + "ResponseModel", ] except ImportError: __all__ = [ "BaseGateway", "GatewayConfig", + "EventCapability", + "FHIRConnectionManager", + "FHIRErrorHandler", + "FHIRConnectionError", + "FHIRGateway", ] diff --git a/healthchain/gateway/core/base.py b/healthchain/gateway/core/base.py index e1e0ff41..be77b1e8 100644 --- a/healthchain/gateway/core/base.py +++ b/healthchain/gateway/core/base.py @@ -11,11 +11,15 @@ from abc import ABC from typing import Any, Callable, Dict, List, TypeVar, Generic, Optional, Union from pydantic import BaseModel +from fastapi import APIRouter + +from healthchain.gateway.api.protocols import EventDispatcherProtocol logger = logging.getLogger(__name__) # Type variables for self-referencing return types and generic gateways G = TypeVar("G", bound="BaseGateway") +P = TypeVar("P", bound="BaseProtocolHandler") T = TypeVar("T") # For generic request types R = TypeVar("R") # For generic response types @@ -27,54 +31,35 @@ class GatewayConfig(BaseModel): system_type: str = "GENERIC" -class EventDispatcherMixin: +class EventCapability: """ - Mixin class that provides event dispatching capabilities. + Encapsulates event dispatching functionality. - This mixin encapsulates all event-related functionality to allow for cleaner separation - of concerns and optional event support in gateways. """ def __init__(self): - """ - Initialize event dispatching capabilities. - """ - self.event_dispatcher = None - self._event_creator = None + """Initialize event dispatching capabilities.""" + self.dispatcher: Optional[EventDispatcherProtocol] = ( + None # EventDispatcherProtocol + ) + self._event_creator: Optional[Callable] = None - def _run_async_publish(self, event): + def publish(self, event): """ - Safely run the async publish method in a way that works in both sync and async contexts. + Publish an event using the configured dispatcher. Args: event: The event to publish """ - if not self.event_dispatcher: + if not self.dispatcher: return - try: - # Try to get the running loop (only works in async context) - try: - loop = asyncio.get_running_loop() - # We're in an async context, so create_task works - asyncio.create_task(self.event_dispatcher.publish(event)) - except RuntimeError: - # We're not in an async context, create a new loop - loop = asyncio.new_event_loop() - try: - # Run the coroutine to completion in the new loop - loop.run_until_complete(self.event_dispatcher.publish(event)) - finally: - # Clean up the loop - loop.close() - except Exception as e: - logger.error(f"Failed to publish event: {str(e)}", exc_info=True) - - def set_event_dispatcher(self, dispatcher): - """ - Set the event dispatcher for this gateway. + # Delegate to dispatcher's sync-friendly publish method + self.dispatcher.emit(event) - This allows the gateway to publish events and register handlers. + def set_dispatcher(self, dispatcher) -> "EventCapability": + """ + Set the event dispatcher. Args: dispatcher: The event dispatcher instance @@ -82,20 +67,13 @@ def set_event_dispatcher(self, dispatcher): Returns: Self, to allow for method chaining """ - self.event_dispatcher = dispatcher - - # Register default handlers - self._register_default_handlers() - + self.dispatcher = dispatcher return self - def set_event_creator(self, creator_function: Callable): + def set_event_creator(self, creator_function: Callable) -> "EventCapability": """ Set a custom function to map gateway-specific events to EHREvents. - The creator function will be called instead of any default event creation logic, - allowing users to define custom event creation without subclassing. - Args: creator_function: Function that accepts gateway-specific arguments and returns an EHREvent or None @@ -106,18 +84,7 @@ def set_event_creator(self, creator_function: Callable): self._event_creator = creator_function return self - def _register_default_handlers(self): - """ - Register default event handlers for this gateway. - - Override this method in subclasses to register default handlers - for specific event types relevant to the gateway. - """ - # Base implementation does nothing - # Subclasses should override this method to register their default handlers - pass - - def register_event_handler(self, event_type, handler=None): + def register_handler(self, event_type, handler=None): """ Register a custom event handler for a specific event type. @@ -128,40 +95,75 @@ def register_event_handler(self, event_type, handler=None): handler: The handler function (optional if used as decorator) Returns: - Decorator function if handler is None, self otherwise + Decorator function if handler is None, the capability object otherwise """ - if not self.event_dispatcher: - raise ValueError("Event dispatcher not set for this gateway") + if not self.dispatcher: + raise ValueError("Event dispatcher not set") # If used as a decorator (no handler provided) if handler is None: - return self.event_dispatcher.register_handler(event_type) + return self.dispatcher.register_handler(event_type) # If called directly with a handler - self.event_dispatcher.register_handler(event_type)(handler) + self.dispatcher.register_handler(event_type)(handler) return self + def emit_event( + self, creator_function: Callable, *args, use_events: bool = True, **kwargs + ) -> None: + """ + Emit an event using the standard custom/fallback pattern. -class BaseGateway(ABC, Generic[T, R], EventDispatcherMixin): - """ - Base class for healthcare standard gateways that handle communication with external systems. + This method implements the common event emission pattern used across + all protocol handlers: try custom event creator first, then fallback + to standard event creator. + + Args: + creator_function: Standard event creator function to use as fallback + *args: Positional arguments to pass to the event creator + use_events: Whether events are enabled for this operation + **kwargs: Keyword arguments to pass to the event creator + + Example: + # In a protocol handler + self.events.emit_event( + create_fhir_event, + operation, resource_type, resource_id, resource + ) + """ + # Skip if events are disabled or no dispatcher + if not self.dispatcher or not use_events: + return + + # Use custom event creator if provided + if self._event_creator: + event = self._event_creator(*args) + if event: + self.publish(event) + return - Gateways provide a consistent interface for interacting with healthcare standards - and protocols through the decorator pattern for handler registration. + # Create a standard event using the provided creator function + event = creator_function(*args, **kwargs) + if event: + self.publish(event) - Type Parameters: - T: The request type this gateway handles - R: The response type this gateway returns + +class BaseProtocolHandler(ABC, Generic[T, R]): + """ + Base class for protocol handlers that process specific request/response types. + + This is designed for CDS Hooks, SOAP, and other protocol-specific handlers. + Register handlers with the register_handler method. """ def __init__( self, config: Optional[GatewayConfig] = None, use_events: bool = True, **options ): """ - Initialize a new gateway. + Initialize a new protocol handler. Args: - config: Configuration options for the gateway + config: Configuration options for the handler use_events: Whether to enable event dispatching **options: Additional configuration options """ @@ -173,11 +175,9 @@ def __init__( self.return_errors = self.config.return_errors or options.get( "return_errors", False ) + self.events = EventCapability() - # Initialize event dispatcher mixin - EventDispatcherMixin.__init__(self) - - def register_handler(self, operation: str, handler: Callable) -> G: + def register_handler(self, operation: str, handler: Callable) -> P: """ Register a handler function for a specific operation. @@ -280,64 +280,85 @@ async def _default_handler( def get_capabilities(self) -> List[str]: """ - Get list of operations this gateway supports. + Get list of operations this handler supports. Returns: List of supported operation names """ return list(self._handlers.keys()) - def get_routes(self, path: Optional[str] = None) -> List[tuple]: + @classmethod + def create(cls, **options) -> G: """ - Get routes that this gateway wants to register with the FastAPI app. - - This method returns a list of tuples with the following structure: - (path, methods, handler, kwargs) where: - - path is the URL path for the endpoint - - methods is a list of HTTP methods this endpoint supports - - handler is the function to be called when the endpoint is accessed - - kwargs are additional arguments to pass to the add_api_route method + Factory method to create a new gateway with default configuration. Args: - path: Optional base path to prefix all routes + **options: Options to pass to the constructor Returns: - List of route tuples (path, methods, handler, kwargs) + New gateway instance """ - # Default implementation returns empty list - # Specific gateway classes should override this - return [] + return cls(**options) + - def get_metadata(self) -> Dict[str, Any]: +class BaseGateway(ABC, APIRouter): + """ + Base class for healthcare integration gateways. + + Combines FastAPI routing capabilities with event dispatching using composition. + """ + + def __init__( + self, + config: Optional[GatewayConfig] = None, + use_events: bool = True, + prefix: str = "/api", + tags: Optional[List[str]] = None, + **options, + ): + """ + Initialize a new gateway. + + Args: + config: Configuration options for the gateway + use_events: Whether to enable event dispatching + prefix: URL prefix for API routes + tags: OpenAPI tags + **options: Additional configuration options + """ + # Initialize APIRouter + APIRouter.__init__(self, prefix=prefix, tags=tags or []) + + self.options = options + self.config = config or GatewayConfig() + self.use_events = use_events + # Default to raising exceptions unless configured otherwise + self.return_errors = self.config.return_errors or options.get( + "return_errors", False + ) + self.events = EventCapability() + + def get_gateway_status(self) -> Dict[str, Any]: """ - Get metadata for this gateway, including capabilities and configuration. + Get operational status and metadata for this gateway. Returns: - Dictionary of gateway metadata + Dictionary of gateway operational status and metadata """ # Default implementation returns basic info # Specific gateway classes should override this - metadata = { + status = { "gateway_type": self.__class__.__name__, - "operations": self.get_capabilities(), "system_type": self.config.system_type, + "status": "active", + "return_errors": self.return_errors, } # Add event-related metadata if events are enabled - if self.event_dispatcher: - metadata["event_enabled"] = True - - return metadata - - @classmethod - def create(cls, **options) -> G: - """ - Factory method to create a new gateway with default configuration. - - Args: - **options: Options to pass to the constructor + if self.use_events: + status["events"] = { + "enabled": True, + "dispatcher_configured": self.events.dispatcher is not None, + } - Returns: - New gateway instance - """ - return cls(**options) + return status diff --git a/healthchain/gateway/core/connection.py b/healthchain/gateway/core/connection.py new file mode 100644 index 00000000..009f18c1 --- /dev/null +++ b/healthchain/gateway/core/connection.py @@ -0,0 +1,172 @@ +""" +FHIR Connection Management for HealthChain Gateway. + +This module provides centralized connection management for FHIR sources, +including connection string parsing, client pooling, and source configuration. +""" + +import logging +import urllib.parse +from typing import Dict + +import httpx + +from healthchain.gateway.clients.fhir import FHIRServerInterface +from healthchain.gateway.clients.pool import FHIRClientPool +from healthchain.gateway.core.errors import FHIRConnectionError + + +logger = logging.getLogger(__name__) + + +class FHIRConnectionManager: + """ + Manages FHIR connections and client pooling. + + Handles connection strings, source configuration, and provides + pooled FHIR clients for efficient resource management. + """ + + def __init__( + self, + max_connections: int = 100, + max_keepalive_connections: int = 20, + keepalive_expiry: float = 5.0, + ): + """ + Initialize the connection manager. + + Args: + max_connections: Maximum total HTTP connections across all sources + max_keepalive_connections: Maximum keep-alive connections per source + keepalive_expiry: How long to keep connections alive (seconds) + """ + # Create httpx-based client pool + self.client_pool = FHIRClientPool( + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + ) + + # Store configuration + self.sources = {} + self._connection_strings = {} + + def add_source(self, name: str, connection_string: str): + """ + Add a FHIR data source using connection string. + + Format: fhir://hostname:port/path?param1=value1¶m2=value2 + + Examples: + fhir://epic.org/api/FHIR/R4?client_id=my_app&client_secret=secret&token_url=https://epic.org/oauth2/token&use_jwt_assertion=true + fhir://cerner.org/r4?client_id=app_id&client_secret=app_secret&token_url=https://cerner.org/token&scope=openid + + Args: + name: Source name identifier + connection_string: FHIR connection string + + Raises: + FHIRConnectionError: If connection string is invalid + """ + # Store connection string for pooling + self._connection_strings[name] = connection_string + + # Parse the connection string for validation only + try: + if not connection_string.startswith("fhir://"): + raise ValueError("Connection string must start with fhir://") + + # Parse URL for validation + parsed = urllib.parse.urlparse(connection_string) + + # Validate that we have a valid hostname + if not parsed.netloc: + raise ValueError("Invalid connection string: missing hostname") + + # Store the source name + self.sources[name] = None # Placeholder - store metadata here + + logger.info(f"Added FHIR source '{name}'") + + except Exception as e: + raise FHIRConnectionError( + message=f"Failed to parse connection string: {str(e)}", + code="Invalid connection string", + state="500", + ) + + def _create_server_from_connection_string( + self, connection_string: str, limits: httpx.Limits = None + ) -> FHIRServerInterface: + """ + Create a FHIR server instance from a connection string with connection pooling. + + This is used by the client pool to create new server instances. + + Args: + connection_string: FHIR connection string + limits: httpx connection limits for pooling + + Returns: + FHIRServerInterface: A new FHIR server instance with pooled connections + """ + from healthchain.gateway.clients import create_fhir_client + from healthchain.gateway.clients.auth import parse_fhir_auth_connection_string + + # Parse connection string as OAuth2.0 configuration + auth_config = parse_fhir_auth_connection_string(connection_string) + + # Pass httpx limits for connection pooling + return create_fhir_client(auth_config=auth_config, limits=limits) + + async def get_client(self, source: str = None) -> FHIRServerInterface: + """ + Get a FHIR client for the specified source. + + Args: + source: Source name to get client for (uses first available if None) + + Returns: + FHIRServerInterface: A FHIR client with pooled connections + + Raises: + ValueError: If source is unknown or no connection string found + """ + source_name = source or next(iter(self.sources.keys())) + if source_name not in self.sources: + raise ValueError(f"Unknown source: {source_name}") + + if source_name not in self._connection_strings: + raise ValueError(f"No connection string found for source: {source_name}") + + connection_string = self._connection_strings[source_name] + + return await self.client_pool.get_client( + connection_string, self._create_server_from_connection_string + ) + + def get_pool_status(self) -> Dict[str, any]: + """ + Get the current status of the connection pool. + + Returns: + Dict containing pool status information including: + - max_connections: Maximum connections across all sources + - sources: Dict of source names and their connection info + - client_stats: Detailed httpx connection pool statistics + """ + return self.client_pool.get_pool_stats() + + def get_sources(self) -> Dict[str, any]: + """ + Get all configured sources. + + Returns: + Dict of source names and their configurations + """ + return self.sources.copy() + + async def close(self): + """Close all connections and clean up resources.""" + await self.client_pool.close_all() diff --git a/healthchain/gateway/core/errors.py b/healthchain/gateway/core/errors.py new file mode 100644 index 00000000..f1368bb3 --- /dev/null +++ b/healthchain/gateway/core/errors.py @@ -0,0 +1,195 @@ +""" +FHIR Error Handling for HealthChain Gateway. + +This module provides standardized error handling for FHIR operations, +including status code mapping, error formatting, and exception types. +""" + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class FHIRConnectionError(Exception): + """Standardized FHIR connection error with state codes.""" + + def __init__( + self, + message: str, + code: str, + state: Optional[str] = None, + show_state: bool = True, + ): + """ + Initialize a FHIR connection error. + + Args: + message: Human-readable error message e.g. Server does not allow client defined ids + code: Error code or technical details e.g. METHOD_NOT_ALLOWED + state: HTTP status code e.g. 405 + show_state: Whether to include the state in the error message + """ + self.message = message + self.code = code + self.state = state + if show_state: + super().__init__(f"[{state} {code}] {message}") + else: + super().__init__(f"[{code}] {message}") + + +class FHIRErrorHandler: + """ + Handles FHIR operation errors consistently across the gateway. + + Provides standardized error mapping, formatting, and exception handling + for FHIR-specific operations and status codes. + """ + + # Map HTTP status codes to FHIR error types and messages + # Based on: https://build.fhir.org/http.html + ERROR_MAP = { + 400: "Resource could not be parsed or failed basic FHIR validation rules (or multiple matches were found for conditional criteria)", + 401: "Authorization is required for the interaction that was attempted", + 403: "You may not have permission to perform this operation", + 404: "The resource you are looking for does not exist, is not a resource type, or is not a FHIR end point", + 405: "The server does not allow client defined ids for resources", + 409: "Version conflict - update cannot be done", + 410: "The resource you are looking for is no longer available", + 412: "Version conflict - version id does not match", + 422: "Proposed resource violated applicable FHIR profiles or server business rules", + } + + @classmethod + def handle_fhir_error( + cls, + e: Exception, + resource_type: str, + fhir_id: Optional[str] = None, + operation: str = "operation", + ) -> None: + """ + Handle FHIR operation errors consistently. + + Args: + e: The original exception + resource_type: The FHIR resource type being operated on + fhir_id: The resource ID (if applicable) + operation: The operation being performed + + Raises: + FHIRConnectionError: Standardized FHIR error with proper formatting + """ + error_msg = str(e) + resource_ref = f"{resource_type}{'' if fhir_id is None else f'/{fhir_id}'}" + + # Try status code first + status_code = getattr(e, "status_code", None) + if status_code in cls.ERROR_MAP: + msg = cls.ERROR_MAP[status_code] + raise FHIRConnectionError( + message=f"{operation} {resource_ref} failed: {msg}", + code=error_msg, + state=str(status_code), + show_state=False, + ) + + # Fall back to message parsing + error_msg_lower = error_msg.lower() + for code, msg in cls.ERROR_MAP.items(): + if str(code) in error_msg_lower: + raise FHIRConnectionError( + message=f"{operation} {resource_ref} failed: {msg}", + code=error_msg, + state=str(code), + show_state=False, + ) + + # Default fallback error + raise FHIRConnectionError( + message=f"{operation} {resource_ref} failed: HTTP error", + code=error_msg, + state=str(status_code) if status_code else "UNKNOWN", + show_state=False, + ) + + @classmethod + def create_validation_error( + cls, message: str, resource_type: str = None, field_name: str = None + ) -> FHIRConnectionError: + """ + Create a standardized validation error. + + Args: + message: The validation error message + resource_type: The resource type being validated (optional) + field_name: The specific field that failed validation (optional) + + Returns: + FHIRConnectionError: Formatted validation error + """ + if resource_type and field_name: + formatted_message = ( + f"Validation failed for {resource_type}.{field_name}: {message}" + ) + elif resource_type: + formatted_message = f"Validation failed for {resource_type}: {message}" + else: + formatted_message = f"Validation failed: {message}" + + return FHIRConnectionError( + message=formatted_message, + code="VALIDATION_ERROR", + state="422", # Unprocessable Entity + ) + + @classmethod + def create_connection_error( + cls, message: str, source: str = None + ) -> FHIRConnectionError: + """ + Create a standardized connection error. + + Args: + message: The connection error message + source: The source name that failed to connect (optional) + + Returns: + FHIRConnectionError: Formatted connection error + """ + if source: + formatted_message = f"Connection to source '{source}' failed: {message}" + else: + formatted_message = f"Connection failed: {message}" + + return FHIRConnectionError( + message=formatted_message, + code="CONNECTION_ERROR", + state="503", # Service Unavailable + ) + + @classmethod + def create_authentication_error( + cls, message: str, source: str = None + ) -> FHIRConnectionError: + """ + Create a standardized authentication error. + + Args: + message: The authentication error message + source: The source name that failed authentication (optional) + + Returns: + FHIRConnectionError: Formatted authentication error + """ + if source: + formatted_message = f"Authentication to source '{source}' failed: {message}" + else: + formatted_message = f"Authentication failed: {message}" + + return FHIRConnectionError( + message=formatted_message, + code="AUTHENTICATION_ERROR", + state="401", # Unauthorized + ) diff --git a/healthchain/gateway/core/fhirgateway.py b/healthchain/gateway/core/fhirgateway.py new file mode 100644 index 00000000..e951040b --- /dev/null +++ b/healthchain/gateway/core/fhirgateway.py @@ -0,0 +1,913 @@ +""" +FHIR Gateway for HealthChain. + +This module provides a specialized FHIR integration hub for data aggregation, +transformation, and routing. +""" + +import logging +import inspect +import warnings + +from contextlib import asynccontextmanager +from typing import ( + Dict, + List, + Any, + Callable, + Optional, + TypeVar, + Type, +) +from fastapi import Depends, HTTPException, Query, Path +from fastapi.responses import JSONResponse +from datetime import datetime + +from fhir.resources.resource import Resource +from fhir.resources.bundle import Bundle +from fhir.resources.capabilitystatement import CapabilityStatement + +from healthchain.gateway.core.base import BaseGateway +from healthchain.gateway.core.connection import FHIRConnectionManager +from healthchain.gateway.core.errors import FHIRErrorHandler +from healthchain.gateway.events.fhir import create_fhir_event +from healthchain.gateway.clients.fhir import FHIRServerInterface + + +logger = logging.getLogger(__name__) + +# Type variable for FHIR Resource +T = TypeVar("T", bound=Resource) + + +class FHIRResponse(JSONResponse): + """ + Custom response class for FHIR resources. + + This sets the correct content-type header for FHIR resources. + """ + + media_type = "application/fhir+json" + + +class FHIRGateway(BaseGateway): + """ + FHIR Gateway for HealthChain. + + A specialized gateway for FHIR resource operations including: + - Connection pooling and management + - Resource transformation and aggregation + - Event-driven processing + - OAuth2 authentication support + + Example: + ```python + # Initialize with connection pooling + async with FHIRGateway(max_connections=50) as gateway: + # Add FHIR source + gateway.add_source("epic", "fhir://epic.org/api/FHIR/R4?...") + + # Register transformation handler + @gateway.transform(Patient) + async def enhance_patient(id: str, source: str = None) -> Patient: + async with gateway.modify(Patient, id, source) as patient: + patient.active = True + return patient + + # Use the gateway + patient = await gateway.read(Patient, "123", "epic") + ``` + """ + + def __init__( + self, + sources: Dict[str, FHIRServerInterface] = None, + prefix: str = "/fhir", + tags: List[str] = ["FHIR"], + use_events: bool = True, + max_connections: int = 100, + max_keepalive_connections: int = 20, + keepalive_expiry: float = 5.0, + **options, + ): + """ + Initialize the FHIR Gateway. + + Args: + sources: Dictionary of named FHIR servers or connection strings + prefix: URL prefix for API routes + tags: OpenAPI tags + use_events: Enable event-based processing + max_connections: Maximum total HTTP connections across all sources + max_keepalive_connections: Maximum keep-alive connections per source + keepalive_expiry: How long to keep connections alive (seconds) + **options: Additional options + """ + # Initialize as BaseGateway (which includes APIRouter) + super().__init__(use_events=use_events, prefix=prefix, tags=tags, **options) + + self.use_events = use_events + + # Create connection manager + self.connection_manager = FHIRConnectionManager( + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + ) + + # Add sources if provided + if sources: + for name, source in sources.items(): + if isinstance(source, str): + self.connection_manager.add_source(name, source) + else: + self.connection_manager.sources[name] = source + + # Handlers for resource operations + self._resource_handlers: Dict[str, Dict[str, Callable]] = {} + + # Register base routes only (metadata endpoint) + self._register_base_routes() + + def _get_gateway_dependency(self): + """Create a dependency function that returns this gateway instance.""" + + def get_self_gateway(): + return self + + return get_self_gateway + + def _get_resource_name(self, resource_type: Type[Resource]) -> str: + """Extract resource name from resource type.""" + return resource_type.__resource_type__ + + def _register_base_routes(self): + """Register basic endpoints""" + get_self_gateway = self._get_gateway_dependency() + + # FHIR Metadata endpoint - returns CapabilityStatement + @self.get("/metadata", response_class=FHIRResponse) + def capability_statement( + fhir: "FHIRGateway" = Depends(get_self_gateway), + ): + """Return the FHIR capability statement for this gateway's services.""" + return fhir.build_capability_statement().model_dump() + + # Gateway status endpoint - returns operational metadata + @self.get("/status", response_class=JSONResponse) + def gateway_status( + fhir: "FHIRGateway" = Depends(get_self_gateway), + ): + """Return operational status and metadata for this gateway.""" + return fhir.get_gateway_status() + + def build_capability_statement(self) -> CapabilityStatement: + """ + Build a FHIR CapabilityStatement for this gateway's value-add services. + + Only includes resources and operations that this gateway provides through + its transform/aggregate endpoints, not the underlying FHIR sources. + + Returns: + CapabilityStatement: FHIR-compliant capability statement + """ + # Build resource entries based on registered handlers + resources = [] + for resource_type, operations in self._resource_handlers.items(): + interactions = [] + + # Add supported interactions based on registered handlers + for operation in operations: + if operation == "transform": + interactions.append( + {"code": "read"} + ) # Transform requires read access + elif operation == "aggregate": + interactions.append( + {"code": "search-type"} + ) # Aggregate is like search + + if interactions: + # Extract the resource name from the resource type class + resource_name = self._get_resource_name(resource_type) + resources.append( + { + "type": resource_name, + "interaction": interactions, + "documentation": f"Gateway provides {', '.join(operations)} operations for {resource_name}", + } + ) + + capability_data = { + "resourceType": "CapabilityStatement", + "status": "active", + "date": datetime.now().strftime("%Y-%m-%d"), + "publisher": "HealthChain", + "kind": "instance", + "software": { + "name": "HealthChain FHIR Gateway", + "version": "1.0.0", # TODO: Extract from package + }, + "fhirVersion": "4.0.1", + "format": ["application/fhir+json"], + "rest": [ + { + "mode": "server", + "documentation": "HealthChain FHIR Gateway provides transformation and aggregation services", + "resource": resources, + } + ] + if resources + else [], + } + + return CapabilityStatement(**capability_data) + + @property + def supported_resources(self) -> List[str]: + """Get list of supported FHIR resource types.""" + return [ + self._get_resource_name(resource_type) + for resource_type in self._resource_handlers.keys() + ] + + def get_capabilities(self) -> List[str]: + """ + Get list of supported FHIR operations and resources. + + Returns: + List of capabilities this gateway supports + """ + capabilities = [] + for resource_type, operations in self._resource_handlers.items(): + resource_name = self._get_resource_name(resource_type) + for operation in operations: + capabilities.append(f"{operation}:{resource_name}") + return capabilities + + def get_gateway_status(self) -> Dict[str, Any]: + """ + Get operational status and metadata for this gateway. + + This provides gateway-specific operational information. + + Returns: + Dict containing gateway operational status and metadata + """ + status = { + "gateway_type": "FHIRGateway", + "version": "1.0.0", # TODO: Extract from package + "status": "active", + "timestamp": datetime.now().isoformat() + "Z", + "sources": { + "count": len(self.connection_manager.sources), + "names": list(self.connection_manager.sources.keys()), + }, + "connection_pool": self.get_pool_status(), + "supported_operations": { + "resources": self.supported_resources, + "operations": self.get_capabilities(), + "endpoints": { + "transform": len( + [ + r + for r, ops in self._resource_handlers.items() + if "transform" in ops + ] + ), + "aggregate": len( + [ + r + for r, ops in self._resource_handlers.items() + if "aggregate" in ops + ] + ), + }, + }, + "events": { + "enabled": self.use_events, + "dispatcher_configured": self.events.dispatcher is not None, + }, + } + + return status + + def _register_resource_handler( + self, + resource_type: Type[Resource], + operation: str, + handler: Callable, + ) -> None: + """Register a custom handler for a resource operation.""" + self._validate_handler_annotations(resource_type, operation, handler) + + if resource_type not in self._resource_handlers: + self._resource_handlers[resource_type] = {} + self._resource_handlers[resource_type][operation] = handler + + resource_name = self._get_resource_name(resource_type) + logger.debug( + f"Registered {operation} handler for {resource_name}: {handler.__name__}" + ) + + self._register_operation_route(resource_type, operation) + + def _validate_handler_annotations( + self, + resource_type: Type[Resource], + operation: str, + handler: Callable, + ) -> None: + """Validate that handler annotations match the decorator resource type.""" + if operation != "transform": + return + + try: + sig = inspect.signature(handler) + return_annotation = sig.return_annotation + + if return_annotation == inspect.Parameter.empty: + warnings.warn( + f"Handler {handler.__name__} missing return type annotation for {resource_type.__name__}" + ) + return + + if return_annotation != resource_type: + raise TypeError( + f"Handler {handler.__name__} return type ({return_annotation}) " + f"doesn't match decorator resource type ({resource_type})" + ) + + except Exception as e: + if isinstance(e, TypeError): + raise + logger.warning(f"Could not validate handler annotations: {str(e)}") + + def _register_operation_route( + self, resource_type: Type[Resource], operation: str + ) -> None: + """Register a route for a specific resource type and operation.""" + resource_name = self._get_resource_name(resource_type) + + if operation == "transform": + path = f"/transform/{resource_name}/{{id}}" + summary = f"Transform {resource_name}" + description = ( + f"Transform a {resource_name} resource with registered handler" + ) + elif operation == "aggregate": + path = f"/aggregate/{resource_name}" + summary = f"Aggregate {resource_name}" + description = f"Aggregate {resource_name} resources from multiple sources" + else: + raise ValueError(f"Unsupported operation: {operation}") + + handler = self._create_route_handler(resource_type, operation) + + self.add_api_route( + path=path, + endpoint=handler, + methods=["GET"], + summary=summary, + description=description, + response_model_exclude_none=True, + response_class=FHIRResponse, + tags=self.tags, + include_in_schema=True, + ) + logger.debug(f"Registered {operation} endpoint: {self.prefix}{path}") + + def _create_route_handler( + self, resource_type: Type[Resource], operation: str + ) -> Callable: + """Create a route handler for the given resource type and operation.""" + get_self_gateway = self._get_gateway_dependency() + + def _execute_handler(fhir: "FHIRGateway", *args) -> Any: + """Common handler execution logic with error handling.""" + try: + handler_func = fhir._resource_handlers[resource_type][operation] + result = handler_func(*args) + return result + except Exception as e: + logger.error(f"Error in {operation} handler: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + if operation == "transform": + + async def handler( + id: str = Path(..., description="Resource ID to transform"), + source: Optional[str] = Query( + None, description="Source system to retrieve the resource from" + ), + fhir: "FHIRGateway" = Depends(get_self_gateway), + ): + """Transform a resource with registered handler.""" + return _execute_handler(fhir, id, source) + + elif operation == "aggregate": + + async def handler( + id: Optional[str] = Query(None, description="ID to aggregate data for"), + sources: Optional[List[str]] = Query( + None, description="List of source names to query" + ), + fhir: "FHIRGateway" = Depends(get_self_gateway), + ): + """Aggregate resources with registered handler.""" + return _execute_handler(fhir, id, sources) + + else: + raise ValueError(f"Unsupported operation: {operation}") + + return handler + + def add_source(self, name: str, connection_string: str) -> None: + """ + Add a FHIR data source using connection string with OAuth2.0 flow. + + Format: fhir://hostname:port/path?param1=value1¶m2=value2 + + Examples: + fhir://epic.org/api/FHIR/R4?client_id=my_app&client_secret=secret&token_url=https://epic.org/oauth2/token&scope=system/*.read + fhir://cerner.org/r4?client_id=app_id&client_secret=app_secret&token_url=https://cerner.org/token&audience=https://cerner.org/fhir + """ + return self.connection_manager.add_source(name, connection_string) + + async def get_client(self, source: str = None) -> FHIRServerInterface: + """ + Get a FHIR client for the specified source. + + Args: + source: Source name to get client for (uses first available if None) + + Returns: + FHIRServerInterface: A FHIR client with pooled connections + + Raises: + ValueError: If source is unknown or no connection string found + """ + return await self.connection_manager.get_client(source) + + async def capabilities(self, source: str = None) -> CapabilityStatement: + """ + Get the capabilities of the FHIR server. + + Args: + source: Source name to get capabilities for (uses first available if None) + + Returns: + CapabilityStatement: The capabilities of the FHIR server + + Raises: + FHIRConnectionError: If connection fails + """ + capabilities = await self._execute_with_client( + "capabilities", + source=source, + resource_type=CapabilityStatement, + ) + + # Emit capabilities event + self._emit_fhir_event("capabilities", "CapabilityStatement", None, capabilities) + logger.debug("Retrieved server capabilities") + + return capabilities + + async def read( + self, + resource_type: Type[Resource], + fhir_id: str, + source: str = None, + ) -> Resource: + """ + Read a FHIR resource. + + Args: + resource_type: The FHIR resource type class + fhir_id: Resource ID to fetch + source: Source name to fetch from (uses first available if None) + + Returns: + The FHIR resource object + + Raises: + ValueError: If resource not found or source invalid + FHIRConnectionError: If connection fails + + Example: + # Simple read-only access + document = await fhir_gateway.read(DocumentReference, "123", "epic") + summary = extract_summary(document.text) + """ + resource = await self._execute_with_client( + "read", + source=source, + resource_type=resource_type, + resource_id=fhir_id, + client_args=(resource_type, fhir_id), + ) + + if not resource: + type_name = resource_type.__resource_type__ + raise ValueError(f"Resource {type_name}/{fhir_id} not found") + + # Emit read event + type_name = resource.__resource_type__ + self._emit_fhir_event("read", type_name, fhir_id, resource) + logger.debug(f"Retrieved {type_name}/{fhir_id} for read-only access") + + return resource + + async def search( + self, + resource_type: Type[Resource], + params: Dict[str, Any] = None, + source: str = None, + ) -> Bundle: + """ + Search for FHIR resources. + + Args: + resource_type: The FHIR resource type class + params: Search parameters (e.g., {"name": "Smith", "active": "true"}) + source: Source name to search in (uses first available if None) + + Returns: + Bundle containing search results + + Raises: + ValueError: If source is invalid + FHIRConnectionError: If connection fails + + Example: + # Search for patients by name + bundle = await fhir_gateway.search(Patient, {"name": "Smith"}, "epic") + for entry in bundle.entry or []: + patient = entry.resource + print(f"Found patient: {patient.name[0].family}") + """ + bundle = await self._execute_with_client( + "search", + source=source, + resource_type=resource_type, + client_args=(resource_type,), + client_kwargs={"params": params}, + ) + + # Emit search event with result count + type_name = resource_type.__resource_type__ + event_data = { + "params": params, + "result_count": len(bundle.entry) if bundle.entry else 0, + } + self._emit_fhir_event("search", type_name, None, event_data) + logger.debug( + f"Searched {type_name} with params {params}, found {len(bundle.entry) if bundle.entry else 0} results" + ) + + return bundle + + async def create(self, resource: Resource, source: str = None) -> Resource: + """ + Create a new FHIR resource. + + Args: + resource: The FHIR resource to create + source: Source name to create in (uses first available if None) + + Returns: + The created FHIR resource with server-assigned ID + + Raises: + ValueError: If source is invalid + FHIRConnectionError: If connection fails + + Example: + # Create a new patient + patient = Patient(name=[HumanName(family="Smith", given=["John"])]) + created = await fhir_gateway.create(patient, "epic") + print(f"Created patient with ID: {created.id}") + """ + created = await self._execute_with_client( + "create", + source=source, + resource_type=resource.__class__, + client_args=(resource,), + ) + + # Emit create event + type_name = resource.__resource_type__ + self._emit_fhir_event("create", type_name, created.id, created) + logger.debug(f"Created {type_name} resource with ID {created.id}") + + return created + + async def update(self, resource: Resource, source: str = None) -> Resource: + """ + Update an existing FHIR resource. + + Args: + resource: The FHIR resource to update (must have ID) + source: Source name to update in (uses first available if None) + + Returns: + The updated FHIR resource + + Raises: + ValueError: If resource has no ID or source is invalid + FHIRConnectionError: If connection fails + + Example: + # Update a patient's name + patient = await fhir_gateway.read(Patient, "123", "epic") + patient.name[0].family = "Jones" + updated = await fhir_gateway.update(patient, "epic") + """ + if not resource.id: + raise ValueError("Resource must have an ID for update") + + updated = await self._execute_with_client( + "update", + source=source, + resource_type=resource.__class__, + resource_id=resource.id, + client_args=(resource,), + ) + + # Emit update event + type_name = resource.__resource_type__ + self._emit_fhir_event("update", type_name, resource.id, updated) + logger.debug(f"Updated {type_name} resource with ID {resource.id}") + + return updated + + async def delete( + self, resource_type: Type[Resource], fhir_id: str, source: str = None + ) -> bool: + """ + Delete a FHIR resource. + + Args: + resource_type: The FHIR resource type class + fhir_id: Resource ID to delete + source: Source name to delete from (uses first available if None) + + Returns: + True if deletion was successful + + Raises: + ValueError: If source is invalid + FHIRConnectionError: If connection fails + + Example: + # Delete a patient + success = await fhir_gateway.delete(Patient, "123", "epic") + if success: + print("Patient deleted successfully") + """ + success = await self._execute_with_client( + "delete", + source=source, + resource_type=resource_type, + resource_id=fhir_id, + client_args=(resource_type, fhir_id), + ) + + if success: + # Emit delete event + type_name = resource_type.__resource_type__ + self._emit_fhir_event("delete", type_name, fhir_id, None) + logger.debug(f"Deleted {type_name} resource with ID {fhir_id}") + + return success + + async def transaction(self, bundle: Bundle, source: str = None) -> Bundle: + """ + Execute a FHIR transaction bundle. + + Args: + bundle: The transaction bundle to execute + source: Source name to execute in (uses first available if None) + + Returns: + The response bundle with results + + Raises: + ValueError: If source is invalid + FHIRConnectionError: If connection fails + + Example: + # Create a transaction bundle + bundle = Bundle(type="transaction", entry=[ + BundleEntry(resource=patient1, request=BundleRequest(method="POST")), + BundleEntry(resource=patient2, request=BundleRequest(method="POST")) + ]) + result = await fhir_gateway.transaction(bundle, "epic") + """ + result = await self._execute_with_client( + "transaction", + source=source, + resource_type=Bundle, + client_args=(bundle,), + ) + + # Emit transaction event with entry counts + event_data = { + "entry_count": len(bundle.entry) if bundle.entry else 0, + "result_count": len(result.entry) if result.entry else 0, + } + self._emit_fhir_event("transaction", "Bundle", None, event_data) + logger.debug( + f"Executed transaction bundle with {len(bundle.entry) if bundle.entry else 0} entries" + ) + + return result + + @asynccontextmanager + async def modify( + self, resource_type: Type[Resource], fhir_id: str = None, source: str = None + ): + """ + Context manager for working with FHIR resources. + + Automatically handles fetching, updating, and error handling using connection pooling. + + Args: + resource_type: The FHIR resource type class (e.g. Patient) + fhir_id: Resource ID (if None, creates a new resource) + source: Source name to use (uses first available if None) + + Yields: + Resource: The FHIR resource object + + Raises: + FHIRConnectionError: If connection fails + ValueError: If resource type is invalid + """ + client = await self.get_client(source) + resource = None + is_new = fhir_id is None + + # Get type name for error messages + type_name = resource_type.__resource_type__ + + try: + if is_new: + resource = resource_type() + else: + resource = await client.read(resource_type, fhir_id) + logger.debug(f"Retrieved {type_name}/{fhir_id} in modify context") + self._emit_fhir_event("read", type_name, fhir_id, resource) + + yield resource + + if is_new: + updated_resource = await client.create(resource) + else: + updated_resource = await client.update(resource) + + resource.id = updated_resource.id + for field_name, field_value in updated_resource.model_dump().items(): + if hasattr(resource, field_name): + setattr(resource, field_name, field_value) + + operation = "create" if is_new else "update" + self._emit_fhir_event(operation, type_name, resource.id, updated_resource) + logger.debug( + f"{'Created' if is_new else 'Updated'} {type_name} resource in modify context" + ) + + except Exception as e: + operation = ( + "read" + if not is_new and resource is None + else "create" + if is_new + else "update" + ) + FHIRErrorHandler.handle_fhir_error(e, type_name, fhir_id, operation) + + def aggregate(self, resource_type: Type[Resource]): + """ + Decorator for custom aggregation functions. + + Args: + resource_type: The FHIR resource type class that this handler aggregates + + Example: + @fhir_gateway.aggregate(Patient) + def aggregate_patients(id: str = None, sources: List[str] = None) -> List[Patient]: + # Handler implementation + pass + """ + + def decorator(handler: Callable): + self._register_resource_handler(resource_type, "aggregate", handler) + return handler + + return decorator + + def transform(self, resource_type: Type[Resource]): + """ + Decorator for custom transformation functions. + + Args: + resource_type: The FHIR resource type class that this handler transforms + + Example: + @fhir_gateway.transform(DocumentReference) + def transform_document(id: str, source: str = None) -> DocumentReference: + # Handler implementation + pass + """ + + def decorator(handler: Callable): + self._register_resource_handler(resource_type, "transform", handler) + return handler + + return decorator + + def _emit_fhir_event( + self, operation: str, resource_type: str, resource_id: str, resource: Any = None + ): + """ + Emit an event for FHIR operations. + + Args: + operation: The FHIR operation (read, search, create, update, delete) + resource_type: The FHIR resource type + resource_id: The resource ID + resource: The resource object or data + """ + self.events.emit_event( + create_fhir_event, + operation, + resource_type, + resource_id, + resource, + use_events=self.use_events, + ) + + def get_pool_status(self) -> Dict[str, Any]: + """ + Get the current status of the connection pool. + + Returns: + Dict containing pool status information including: + - max_connections: Maximum connections across all sources + - sources: Dict of source names and their connection info + - client_stats: Detailed httpx connection pool statistics + """ + return self.connection_manager.get_pool_status() + + async def close(self): + """Close all connections and clean up resources.""" + await self.connection_manager.close() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() + + async def _execute_with_client( + self, + operation: str, + *, # Force keyword-only arguments + source: str = None, + resource_type: Type[Resource] = None, + resource_id: str = None, + client_args: tuple = (), + client_kwargs: dict = None, + ): + """ + Execute a client operation with consistent error handling. + + Args: + operation: Operation name (read, create, update, delete, etc.) + source: Source name to use + resource_type: Resource type for error handling + resource_id: Resource ID for error handling (if applicable) + client_args: Positional arguments to pass to the client method + client_kwargs: Keyword arguments to pass to the client method + """ + client = await self.get_client(source) + client_kwargs = client_kwargs or {} + + try: + result = await getattr(client, operation)(*client_args, **client_kwargs) + return result + + except Exception as e: + # Use existing error handler + error_resource_type = resource_type or ( + client_args[0].__class__ + if client_args and hasattr(client_args[0], "__class__") + else None + ) + FHIRErrorHandler.handle_fhir_error( + e, error_resource_type, resource_id, operation + ) diff --git a/healthchain/gateway/events/__init__.py b/healthchain/gateway/events/__init__.py index 9e1f5857..ba674dc0 100644 --- a/healthchain/gateway/events/__init__.py +++ b/healthchain/gateway/events/__init__.py @@ -11,5 +11,4 @@ "EventDispatcher", "EHREvent", "EHREventType", - "EHREventPublisher", ] diff --git a/healthchain/gateway/events/cdshooks.py b/healthchain/gateway/events/cdshooks.py new file mode 100644 index 00000000..7fb02cab --- /dev/null +++ b/healthchain/gateway/events/cdshooks.py @@ -0,0 +1,71 @@ +""" +CDS Hooks specific event handling utilities. + +This module provides constants and helper functions for creating +and managing CDS Hooks operation events. +""" + +from datetime import datetime +from typing import Any, Dict, Optional + +from healthchain.gateway.events.dispatcher import EHREvent, EHREventType +from healthchain.models.requests.cdsrequest import CDSRequest +from healthchain.models.responses.cdsresponse import CDSResponse + + +# Mapping of CDS Hook types to event types +HOOK_TO_EVENT = { + "patient-view": EHREventType.CDS_PATIENT_VIEW, + "encounter-discharge": EHREventType.CDS_ENCOUNTER_DISCHARGE, + "order-sign": EHREventType.CDS_ORDER_SIGN, + "order-select": EHREventType.CDS_ORDER_SELECT, +} + + +def create_cds_hook_event( + hook_type: str, + request: CDSRequest, + response: CDSResponse, + extra_payload: Optional[Dict[str, Any]] = None, +) -> Optional[EHREvent]: + """ + Create a standardized CDS Hook event. + + Args: + hook_type: The hook type being invoked (e.g., "patient-view") + request: The CDSRequest object + response: The CDSResponse object + extra_payload: Additional payload data + + Returns: + EHREvent or None if hook type is not mapped + + Example: + event = create_cds_hook_event( + "patient-view", request, response + ) + """ + # Get the event type from the mapping + event_type = HOOK_TO_EVENT.get(hook_type, EHREventType.EHR_GENERIC) + + # Build the base payload + payload = { + "hook": hook_type, + "hook_instance": request.hookInstance, + "context": dict(request.context), + } + + # Add any extra payload data + if extra_payload: + payload.update(extra_payload) + + # Create and return the event + return EHREvent( + event_type=event_type, + source_system="CDS-Hooks", + timestamp=datetime.now(), + payload=payload, + metadata={ + "cards_count": len(response.cards) if response.cards else 0, + }, + ) diff --git a/healthchain/gateway/events/dispatcher.py b/healthchain/gateway/events/dispatcher.py index 4ddfe052..5d2ef09f 100644 --- a/healthchain/gateway/events/dispatcher.py +++ b/healthchain/gateway/events/dispatcher.py @@ -1,4 +1,5 @@ import logging +import asyncio from enum import Enum from pydantic import BaseModel from typing import Dict, Optional @@ -42,38 +43,21 @@ def get_name(self) -> str: class EventDispatcher: """Event dispatcher for handling EHR system events using fastapi-events. - This class provides a simple way to work with fastapi-events for dispatching - healthcare-related events in a FastAPI application. + Provides a simple interface for dispatching healthcare-related events in FastAPI applications. + Supports both request-scoped and application-scoped event handling. Example: ```python - from fastapi import FastAPI - from fastapi_events.handlers.local import local_handler - from fastapi_events.middleware import EventHandlerASGIMiddleware - app = FastAPI() dispatcher = EventDispatcher() - - # Register with the app dispatcher.init_app(app) - # Register a handler for a specific event type - @local_handler.register(event_name="patient.admission") - async def handle_admission(event): - # Process admission event - event_name, payload = event - print(f"Processing admission for {payload}") - pass - - # Register a default handler for all events - @local_handler.register(event_name="*") - async def log_all_events(event): - # Log all events + @dispatcher.register_handler(EHREventType.FHIR_READ) + async def handle_fhir_read(event): event_name, payload = event - print(f"Event logged: {event_name}") - pass + print(f"Processing FHIR read: {payload}") - # Publish an event (from anywhere in your application) + event = create_fhir_event(EHREventType.FHIR_READ, "test-system", {"resource_id": "123"}) await dispatcher.publish(event) ``` """ @@ -149,3 +133,32 @@ async def publish(self, event: EHREvent, middleware_id: Optional[int] = None): result = dispatch(event_name, event_data, middleware_id=mid) if result is not None: await result + + def emit(self, event: EHREvent, middleware_id: Optional[int] = None): + """Publish an event from synchronous code by handling async context automatically. + + This method handles the complexity of managing event loops when called from + synchronous contexts, while delegating to the async publish method when + already in an async context. + + Args: + event (EHREvent): The event to publish + middleware_id (Optional[int]): Custom middleware ID, defaults to self.middleware_id + """ + try: + # Try to get the running loop (only works in async context) + try: + loop = asyncio.get_running_loop() + # We're in an async context, so create_task works + asyncio.create_task(self.publish(event, middleware_id)) + except RuntimeError: + # We're not in an async context, create a new loop + loop = asyncio.new_event_loop() + try: + # Run the coroutine to completion in the new loop + loop.run_until_complete(self.publish(event, middleware_id)) + finally: + # Clean up the loop + loop.close() + except Exception as e: + logger.error(f"Failed to publish event: {str(e)}", exc_info=True) diff --git a/healthchain/gateway/events/fhir.py b/healthchain/gateway/events/fhir.py new file mode 100644 index 00000000..a1902f02 --- /dev/null +++ b/healthchain/gateway/events/fhir.py @@ -0,0 +1,79 @@ +""" +FHIR-specific event handling utilities. + +This module provides constants and helper functions for creating +and managing FHIR operation events. +""" + +from datetime import datetime +from typing import Any, Dict, Optional + +from healthchain.gateway.events.dispatcher import EHREvent, EHREventType + + +# Mapping of FHIR operations to event types +OPERATION_TO_EVENT = { + "read": EHREventType.FHIR_READ, + "search": EHREventType.FHIR_SEARCH, + "create": EHREventType.FHIR_CREATE, + "update": EHREventType.FHIR_UPDATE, + "delete": EHREventType.FHIR_DELETE, +} + + +def create_fhir_event( + operation: str, + resource_type: str, + resource_id: Optional[str], + resource: Any = None, + extra_payload: Optional[Dict[str, Any]] = None, +) -> Optional[EHREvent]: + """ + Create a standardized FHIR event. + + Args: + operation: The FHIR operation (read, search, create, update, delete) + resource_type: The FHIR resource type + resource_id: The resource ID (can be None for operations like search) + resource: The resource object or data + extra_payload: Additional payload data + + Returns: + EHREvent or None if operation is not mapped + + Example: + event = create_fhir_event( + "read", "Patient", "123", patient_resource + ) + """ + # Get the event type from the mapping + event_type = OPERATION_TO_EVENT.get(operation) + if not event_type: + return None + + # Build the base payload + payload = { + "resource_type": resource_type, + "resource_id": resource_id, + "operation": operation, + } + + # Add the resource data if available + if resource: + payload["resource"] = resource + + # Add any extra payload data + if extra_payload: + payload.update(extra_payload) + + # Create and return the event + return EHREvent( + event_type=event_type, + source_system="FHIR", + timestamp=datetime.now(), + payload=payload, + metadata={ + "operation": operation, + "resource_type": resource_type, + }, + ) diff --git a/healthchain/gateway/events/notereader.py b/healthchain/gateway/events/notereader.py new file mode 100644 index 00000000..f03fdae3 --- /dev/null +++ b/healthchain/gateway/events/notereader.py @@ -0,0 +1,63 @@ +""" +NoteReader specific event handling utilities. + +This module provides constants and helper functions for creating +and managing NoteReader SOAP operation events. +""" + +from datetime import datetime +from typing import Any, Dict, Optional + +from healthchain.gateway.events.dispatcher import EHREvent, EHREventType +from healthchain.models.requests import CdaRequest +from healthchain.models.responses.cdaresponse import CdaResponse + + +def create_notereader_event( + operation: str, + request: CdaRequest, + response: CdaResponse, + system_type: str = "EHR_CDA", + extra_payload: Optional[Dict[str, Any]] = None, +) -> EHREvent: + """ + Create a standardized NoteReader event. + + Args: + operation: The SOAP method name (e.g., "ProcessDocument") + request: The CdaRequest object + response: The CdaResponse object + system_type: The system type identifier + extra_payload: Additional payload data + + Returns: + EHREvent for the NoteReader operation + + Example: + event = create_notereader_event( + "ProcessDocument", request, response + ) + """ + # Build the base payload + payload = { + "operation": operation, + "work_type": request.work_type, + "session_id": request.session_id, + "has_error": response.error is not None, + } + + # Add any extra payload data + if extra_payload: + payload.update(extra_payload) + + # Create and return the event + return EHREvent( + event_type=EHREventType.NOTEREADER_PROCESS_NOTE, + source_system="NoteReader", + timestamp=datetime.now(), + payload=payload, + metadata={ + "service": "NoteReaderService", + "system_type": system_type, + }, + ) diff --git a/healthchain/gateway/protocols/__init__.py b/healthchain/gateway/protocols/__init__.py index 89ac147e..b3e2c699 100644 --- a/healthchain/gateway/protocols/__init__.py +++ b/healthchain/gateway/protocols/__init__.py @@ -8,12 +8,12 @@ interface for registration, event handling, and endpoint management. """ -from .cdshooks import CDSHooksGateway -from .notereader import NoteReaderGateway +from .cdshooks import CDSHooksService +from .notereader import NoteReaderService from .apiprotocol import ApiProtocol __all__ = [ - "CDSHooksGateway", - "NoteReaderGateway", + "CDSHooksService", + "NoteReaderService", "ApiProtocol", ] diff --git a/healthchain/gateway/protocols/cdshooks.py b/healthchain/gateway/protocols/cdshooks.py index 24b6cedd..328d608f 100644 --- a/healthchain/gateway/protocols/cdshooks.py +++ b/healthchain/gateway/protocols/cdshooks.py @@ -6,20 +6,14 @@ """ import logging -from datetime import datetime -from typing import Dict, List, Optional, Any, Callable, Union, TypeVar +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union +from fastapi import APIRouter, Body, Depends from pydantic import BaseModel -from fastapi import Depends, Body - -from healthchain.gateway.core.base import BaseGateway -from healthchain.gateway.events.dispatcher import ( - EventDispatcher, - EHREvent, - EHREventType, -) -from healthchain.gateway.api.protocols import GatewayProtocol +from healthchain.gateway.core.base import BaseProtocolHandler +from healthchain.gateway.events.cdshooks import create_cds_hook_event +from healthchain.gateway.events.dispatcher import EventDispatcher from healthchain.models.requests.cdsrequest import CDSRequest from healthchain.models.responses.cdsdiscovery import CDSService, CDSServiceInformation from healthchain.models.responses.cdsresponse import CDSResponse @@ -29,20 +23,12 @@ # Type variable for self-referencing return types -T = TypeVar("T", bound="CDSHooksGateway") - - -HOOK_TO_EVENT = { - "patient-view": EHREventType.CDS_PATIENT_VIEW, - "encounter-discharge": EHREventType.CDS_ENCOUNTER_DISCHARGE, - "order-sign": EHREventType.CDS_ORDER_SIGN, - "order-select": EHREventType.CDS_ORDER_SELECT, -} +T = TypeVar("T", bound="CDSHooksService") -# Configuration options for CDS Hooks gateway +# Configuration options for CDS Hooks service class CDSHooksConfig(BaseModel): - """Configuration options for CDS Hooks gateway""" + """Configuration options for CDS Hooks service""" system_type: str = "CDS-HOOKS" base_path: str = "/cds" @@ -51,21 +37,21 @@ class CDSHooksConfig(BaseModel): allowed_hooks: List[str] = UseCaseMapping.ClinicalDecisionSupport.allowed_workflows -class CDSHooksGateway(BaseGateway[CDSRequest, CDSResponse], GatewayProtocol): +class CDSHooksService(BaseProtocolHandler[CDSRequest, CDSResponse], APIRouter): """ - Gateway for CDS Hooks protocol integration. + Service for CDS Hooks protocol integration. - This gateway implements the CDS Hooks standard for integrating clinical decision + This service implements the CDS Hooks standard for integrating clinical decision support with EHR systems. It provides discovery and hook execution endpoints that conform to the CDS Hooks specification. Example: ```python - # Create a CDS Hooks gateway - cds_gateway = CDSHooksGateway() + # Create a CDS Hooks service + cds_service = CDSHooksService() # Register a hook handler - @cds_gateway.hook("patient-view", id="patient-summary") + @cds_service.hook("patient-view", id="patient-summary") def handle_patient_view(request: CDSRequest) -> CDSResponse: # Create cards based on the patient context return CDSResponse( @@ -78,8 +64,8 @@ def handle_patient_view(request: CDSRequest) -> CDSResponse: ] ) - # Register the gateway with the API - app.register_gateway(cds_gateway) + # Register the service with the API + app.register_service(cds_service) ``` """ @@ -91,40 +77,74 @@ def __init__( **options, ): """ - Initialize a new CDS Hooks gateway. + Initialize a new CDS Hooks service. Args: - config: Configuration options for the gateway + config: Configuration options for the service event_dispatcher: Optional event dispatcher for publishing events use_events: Whether to enable event dispatching functionality - **options: Additional options for the gateway + **options: Additional options for the service """ - # Initialize the base gateway - super().__init__(use_events=use_events, **options) + # Initialize the base protocol handler + BaseProtocolHandler.__init__(self, use_events=use_events, **options) # Initialize specific configuration self.config = config or CDSHooksConfig() + + # Initialize APIRouter with configuration + APIRouter.__init__(self, prefix=self.config.base_path, tags=["CDS Hooks"]) + self._handler_metadata = {} # Set event dispatcher if provided if event_dispatcher and use_events: - self.set_event_dispatcher(event_dispatcher) + self.events.set_dispatcher(event_dispatcher) - def set_event_dispatcher(self, event_dispatcher: Optional[EventDispatcher] = None): - """ - Set the event dispatcher for this gateway. + self._register_base_routes() - Args: - event_dispatcher: The event dispatcher to use + def _get_service_dependency(self): + """Create a dependency function that returns this service instance.""" - Returns: - Self, for method chaining - """ - # TODO: This is a hack to avoid inheritance issues. Should find a solution to this. - self.event_dispatcher = event_dispatcher - # Register default handlers if needed - self._register_default_handlers() - return self + def get_self_service(): + return self + + return get_self_service + + def _register_base_routes(self): + """Register base routes for CDS Hooks service.""" + get_self_service = self._get_service_dependency() + + # Discovery endpoint + discovery_path = self.config.discovery_path.lstrip("/") + + @self.get(f"/{discovery_path}", response_model_exclude_none=True) + async def discovery_handler(cds: "CDSHooksService" = Depends(get_self_service)): + """CDS Hooks discovery endpoint.""" + return cds.handle_discovery() + + def _register_hook_route(self, hook_id: str): + """Register a route for a specific hook ID.""" + get_self_service = self._get_service_dependency() + service_path = self.config.service_path.lstrip("/") + endpoint = f"/{service_path}/{hook_id}" + + async def service_handler( + request: CDSRequest = Body(...), + cds: "CDSHooksService" = Depends(get_self_service), + ): + """CDS Hook service endpoint.""" + return cds.handle_request(request) + + self.add_api_route( + path=endpoint, + endpoint=service_handler, + methods=["POST"], + response_model_exclude_none=True, + summary=f"CDS Hook: {hook_id}", + description=f"Execute CDS Hook service: {hook_id}", + ) + + logger.debug(f"Registered CDS Hook endpoint: {self.prefix}{endpoint}") def hook( self, @@ -165,6 +185,9 @@ def decorator(handler): "usage_requirements": usage_requirements, } + # Register the route for this hook + self._register_hook_route(id) + return handler return decorator @@ -208,7 +231,7 @@ def handle_request(self, request: CDSRequest) -> CDSResponse: response = self.handle(hook_type, request=request) # If we have an event dispatcher, emit an event for the hook execution - if self.event_dispatcher and self.use_events: + if self.events.dispatcher and self.use_events: try: self._emit_hook_event(hook_type, request, response) except Exception as e: @@ -336,38 +359,14 @@ def _emit_hook_event( request: The CDSRequest object response: The CDSResponse object """ - # Skip if events are disabled or no dispatcher - if not self.event_dispatcher or not self.use_events: - return - - # Use custom event creator if provided - if self._event_creator: - event = self._event_creator(hook_type, request, response) - if event: - self._run_async_publish(event) - return - - # Get the event type from the mapping - event_type = HOOK_TO_EVENT.get(hook_type, EHREventType.EHR_GENERIC) - - # Create a standard event - event = EHREvent( - event_type=event_type, - source_system="CDS-Hooks", - timestamp=datetime.now(), - payload={ - "hook": hook_type, - "hook_instance": request.hookInstance, - "context": dict(request.context), - }, - metadata={ - "cards_count": len(response.cards) if response.cards else 0, - }, + self.events.emit_event( + create_cds_hook_event, + hook_type, + request, + response, + use_events=self.use_events, ) - # Publish the event - self._run_async_publish(event) - def get_metadata(self) -> List[Dict[str, Any]]: """ Get metadata for all registered hooks. @@ -390,74 +389,3 @@ def get_metadata(self) -> List[Dict[str, Any]]: ) return metadata - - def get_routes(self, path: Optional[str] = None) -> List[tuple]: - """ - Get routes for the CDS Hooks gateway. - - Args: - path: Optional path to add the gateway at (uses config if None) - - Returns: - List of route tuples (path, methods, handler, kwargs) - """ - routes = [] - - # Create a dependency for this specific gateway instance - def get_self_cds(): - return self - - base_path = path or self.config.base_path - if base_path: - base_path = base_path.rstrip("/") - - # Register the discovery endpoint - discovery_path = self.config.discovery_path.lstrip("/") - discovery_endpoint = ( - f"{base_path}/{discovery_path}" if base_path else f"/{discovery_path}" - ) - - # Create handlers with dependency injection - async def discovery_handler(cds: GatewayProtocol = Depends(get_self_cds)): - return cds.handle_discovery() - - routes.append( - ( - discovery_endpoint, - ["GET"], - discovery_handler, - {"response_model_exclude_none": True}, - ) - ) - - # Register service endpoints for each hook - service_path = self.config.service_path.lstrip("/") - for metadata in self.get_metadata(): - hook_id = metadata.get("id") - if hook_id: - service_endpoint = ( - f"{base_path}/{service_path}/{hook_id}" - if base_path - else f"/{service_path}/{hook_id}" - ) - - # Create a handler factory to properly capture hook_id in closure - def create_handler_for_hook(): - async def service_handler( - request: CDSRequest = Body(...), - cds: GatewayProtocol = Depends(get_self_cds), - ): - return cds.handle_request(request) - - return service_handler - - routes.append( - ( - service_endpoint, - ["POST"], - create_handler_for_hook(), - {"response_model_exclude_none": True}, - ) - ) - - return routes diff --git a/healthchain/gateway/protocols/notereader.py b/healthchain/gateway/protocols/notereader.py index 6a7d4b58..bdfcceda 100644 --- a/healthchain/gateway/protocols/notereader.py +++ b/healthchain/gateway/protocols/notereader.py @@ -6,33 +6,32 @@ """ import logging -from typing import Optional, Dict, Any, Callable, TypeVar, Union +from typing import Any, Callable, Dict, Optional, TypeVar, Union + +from pydantic import BaseModel from spyne import Application from spyne.protocol.soap import Soap11 from spyne.server.wsgi import WsgiApplication -from pydantic import BaseModel -from datetime import datetime -from healthchain.gateway.events.dispatcher import EHREvent, EHREventType -from healthchain.gateway.core.base import BaseGateway +from healthchain.gateway.core.base import BaseProtocolHandler from healthchain.gateway.events.dispatcher import EventDispatcher +from healthchain.gateway.events.notereader import create_notereader_event from healthchain.gateway.soap.epiccdsservice import CDSServices -from healthchain.models.requests import CdaRequest -from healthchain.models.responses.cdaresponse import CdaResponse from healthchain.gateway.soap.model.epicclientfault import ClientFault from healthchain.gateway.soap.model.epicserverfault import ServerFault -from healthchain.gateway.api.protocols import SOAPGatewayProtocol +from healthchain.models.requests.cdarequest import CdaRequest +from healthchain.models.responses.cdaresponse import CdaResponse logger = logging.getLogger(__name__) # Type variable for self-referencing return types -T = TypeVar("T", bound="NoteReaderGateway") +T = TypeVar("T", bound="NoteReaderService") class NoteReaderConfig(BaseModel): - """Configuration options for NoteReader gateway""" + """Configuration options for NoteReader service""" service_name: str = "ICDSServices" namespace: str = "urn:epic-com:Common.2013.Services" @@ -40,9 +39,9 @@ class NoteReaderConfig(BaseModel): default_mount_path: str = "/notereader" -class NoteReaderGateway(BaseGateway[CdaRequest, CdaResponse], SOAPGatewayProtocol): +class NoteReaderService(BaseProtocolHandler[CdaRequest, CdaResponse]): """ - Gateway for Epic NoteReader SOAP protocol integration. + Service for Epic NoteReader SOAP protocol integration. Provides SOAP integration with healthcare systems, particularly Epic's NoteReader CDA document processing and other SOAP-based @@ -50,11 +49,11 @@ class NoteReaderGateway(BaseGateway[CdaRequest, CdaResponse], SOAPGatewayProtoco Example: ```python - # Create NoteReader gateway with default configuration - gateway = NoteReaderGateway() + # Create NoteReader service with default configuration + service = NoteReaderService() # Register method handler with decorator - @gateway.method("ProcessDocument") + @service.method("ProcessDocument") def process_document(request: CdaRequest) -> CdaResponse: # Process the document return CdaResponse( @@ -62,8 +61,8 @@ def process_document(request: CdaRequest) -> CdaResponse: error=None ) - # Register the gateway with the API - app.register_gateway(gateway) + # Register the service with the API + app.register_service(service) ``` """ @@ -75,15 +74,15 @@ def __init__( **options, ): """ - Initialize a new NoteReader gateway. + Initialize a new NoteReader service. Args: - config: Configuration options for the gateway + config: Configuration options for the service event_dispatcher: Optional event dispatcher for publishing events use_events: Whether to enable event dispatching functionality - **options: Additional options for the gateway + **options: Additional options for the service """ - # Initialize the base gateway + # Initialize the base protocol handler super().__init__(use_events=use_events, **options) # Initialize specific configuration @@ -92,23 +91,7 @@ def __init__( # Set event dispatcher if provided if event_dispatcher and use_events: - self.set_event_dispatcher(event_dispatcher) - - def set_event_dispatcher(self, event_dispatcher: Optional[EventDispatcher] = None): - """ - Set the event dispatcher for this gateway. - - Args: - event_dispatcher: The event dispatcher to use - - Returns: - Self, for method chaining - """ - # TODO: This is a hack to avoid inheritance issues. Should find a solution to this. - self.event_dispatcher = event_dispatcher - # Register default handlers if needed - self._register_default_handlers() - return self + self.events.set_dispatcher(event_dispatcher) def method(self, method_name: str) -> Callable: """ @@ -251,7 +234,7 @@ def create_wsgi_app(self) -> WsgiApplication: raise ValueError( "No ProcessDocument handler registered. " "You must register a handler before creating the WSGI app. " - "Use @gateway.method('ProcessDocument') to register a handler." + "Use @service.method('ProcessDocument') to register a handler." ) # Create adapter for SOAP service integration @@ -264,7 +247,7 @@ def service_adapter(cda_request: CdaRequest) -> CdaResponse: processed_result = self._process_result(result) # Emit event if we have an event dispatcher - if self.event_dispatcher and self.use_events: + if self.events.dispatcher and self.use_events: self._emit_document_event( "ProcessDocument", cda_request, processed_result ) @@ -300,48 +283,27 @@ def _emit_document_event( request: The CdaRequest object response: The CdaResponse object """ - # Skip if events are disabled or no dispatcher - if not self.event_dispatcher or not self.use_events: - return - - # Use custom event creator if provided - if self._event_creator: - event = self._event_creator(operation, request, response) - if event: - self._run_async_publish(event) - return - - # Create a standard event - event = EHREvent( - event_type=EHREventType.NOTEREADER_PROCESS_NOTE, - source_system="NoteReader", - timestamp=datetime.now(), - payload={ - "operation": operation, - "work_type": request.work_type, - "session_id": request.session_id, - "has_error": response.error is not None, - }, - metadata={ - "service": "NoteReaderService", - "system_type": self.config.system_type, - }, + self.events.emit_event( + create_notereader_event, + operation, + request, + response, + use_events=self.use_events, + system_type=self.config.system_type, ) - # Publish the event - self._run_async_publish(event) - def get_metadata(self) -> Dict[str, Any]: """ - Get metadata for this gateway. + Get metadata for this service. Returns: - Dictionary of gateway metadata + Dictionary of service metadata """ return { - "gateway_type": self.__class__.__name__, + "service_type": self.__class__.__name__, "operations": self.get_capabilities(), "system_type": self.config.system_type, "soap_service": self.config.service_name, + "namespace": self.config.namespace, "mount_path": self.config.default_mount_path, } diff --git a/healthchain/sandbox/clients/ehr.py b/healthchain/sandbox/clients/ehr.py index 419aac32..30c2cfe9 100644 --- a/healthchain/sandbox/clients/ehr.py +++ b/healthchain/sandbox/clients/ehr.py @@ -3,7 +3,7 @@ import httpx -from healthchain.models import CDSRequest +from healthchain.models import CDSRequest, CDSResponse from healthchain.models.responses.cdaresponse import CdaResponse from healthchain.sandbox.base import BaseClient, BaseRequestConstructor from healthchain.sandbox.workflows import Workflow @@ -92,7 +92,13 @@ async def send_request(self, url: str) -> List[Dict]: timeout=timeout, ) response.raise_for_status() - responses.append(response.json()) + response_data = response.json() + try: + cds_response = CDSResponse(**response_data) + responses.append(cds_response.model_dump(exclude_none=True)) + except Exception: + # Fallback to raw response if parsing fails + responses.append(response_data) except httpx.HTTPStatusError as exc: log.error( f"Error response {exc.response.status_code} while requesting {exc.request.url!r}: {exc.response.json()}" diff --git a/healthchain/sandbox/environment.py b/healthchain/sandbox/environment.py index 244ff096..2a852eb3 100644 --- a/healthchain/sandbox/environment.py +++ b/healthchain/sandbox/environment.py @@ -2,7 +2,6 @@ import logging import uuid import httpx -import requests from pathlib import Path from typing import Dict, Optional @@ -127,9 +126,3 @@ def start_sandbox( extension, ) log.info(f"Saved response data at {response_path}/") - - # TODO: may not be relevant anymore - def stop_sandbox(self) -> None: - """Shuts down sandbox instance""" - log.info("Shutting down server...") - requests.get(str(self.api.join("/shutdown"))) diff --git a/mkdocs.yml b/mkdocs.yml index d4ea5412..c5c4ed2b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -18,6 +18,14 @@ nav: - NoteReader Sandbox: cookbook/notereader_sandbox.md - Docs: - Welcome: reference/index.md + - Gateway: + - Overview: reference/gateway/gateway.md + - HealthChainAPI: reference/gateway/api.md + - FHIR Gateway: reference/gateway/fhir_gateway.md + - Events: reference/gateway/events.md + - Protocols: + - CDS Hooks: reference/gateway/cdshooks.md + - SOAP/CDA: reference/gateway/soap_cda.md - Pipeline: - Overview: reference/pipeline/pipeline.md - Data Container: reference/pipeline/data_container.md @@ -42,16 +50,9 @@ nav: - Parsers: reference/interop/parsers.md - Generators: reference/interop/generators.md - Working with xmltodict: reference/interop/xmltodict.md - - Sandbox: - - Overview: reference/sandbox/sandbox.md - - Client: reference/sandbox/client.md - - Service: reference/sandbox/service.md - - Use Cases: - - Overview: reference/sandbox/use_cases/use_cases.md - - Clinical Decision Support: reference/sandbox/use_cases/cds.md - - Clinical Documentation: reference/sandbox/use_cases/clindoc.md - Utilities: - FHIR Helpers: reference/utilities/fhir_helpers.md + - Sandbox: reference/sandbox/sandbox.md - Data Generator: reference/utilities/data_generator.md - API Reference: - api/index.md diff --git a/poetry.lock b/poetry.lock index b6720a89..49e06b35 100644 --- a/poetry.lock +++ b/poetry.lock @@ -59,25 +59,6 @@ files = [ astroid = ["astroid (>=2,<4)"] test = ["astroid (>=2,<4)", "pytest", "pytest-cov", "pytest-xdist"] -[[package]] -name = "attrs" -version = "25.3.0" -description = "Classes Without Boilerplate" -optional = false -python-versions = ">=3.8" -files = [ - {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"}, - {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"}, -] - -[package.extras] -benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier"] -tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] - [[package]] name = "babel" version = "2.17.0" @@ -440,6 +421,55 @@ files = [ pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<3.0.0" srsly = ">=2.4.0,<3.0.0" +[[package]] +name = "cryptography" +version = "43.0.3" +description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." +optional = false +python-versions = ">=3.7" +files = [ + {file = "cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63efa177ff54aec6e1c0aefaa1a241232dcd37413835a9b674b6e3f0ae2bfd3e"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e1ce50266f4f70bf41a2c6dc4358afadae90e2a1e5342d3c08883df1675374f"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:443c4a81bb10daed9a8f334365fe52542771f25aedaf889fd323a853ce7377d6"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:74f57f24754fe349223792466a709f8e0c093205ff0dca557af51072ff47ab18"}, + {file = "cryptography-43.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9762ea51a8fc2a88b70cf2995e5675b38d93bf36bd67d91721c309df184f49bd"}, + {file = "cryptography-43.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:81ef806b1fef6b06dcebad789f988d3b37ccaee225695cf3e07648eee0fc6b73"}, + {file = "cryptography-43.0.3-cp37-abi3-win32.whl", hash = "sha256:cbeb489927bd7af4aa98d4b261af9a5bc025bd87f0e3547e11584be9e9427be2"}, + {file = "cryptography-43.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:f46304d6f0c6ab8e52770addfa2fc41e6629495548862279641972b6215451cd"}, + {file = "cryptography-43.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:8ac43ae87929a5982f5948ceda07001ee5e83227fd69cf55b109144938d96984"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:846da004a5804145a5f441b8530b4bf35afbf7da70f82409f151695b127213d5"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f996e7268af62598f2fc1204afa98a3b5712313a55c4c9d434aef49cadc91d4"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:f7b178f11ed3664fd0e995a47ed2b5ff0a12d893e41dd0494f406d1cf555cab7"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:c2e6fc39c4ab499049df3bdf567f768a723a5e8464816e8f009f121a5a9f4405"}, + {file = "cryptography-43.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e1be4655c7ef6e1bbe6b5d0403526601323420bcf414598955968c9ef3eb7d16"}, + {file = "cryptography-43.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:df6b6c6d742395dd77a23ea3728ab62f98379eff8fb61be2744d4679ab678f73"}, + {file = "cryptography-43.0.3-cp39-abi3-win32.whl", hash = "sha256:d56e96520b1020449bbace2b78b603442e7e378a9b3bd68de65c782db1507995"}, + {file = "cryptography-43.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:0c580952eef9bf68c4747774cde7ec1d85a6e61de97281f2dba83c7d2c806362"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d03b5621a135bffecad2c73e9f4deb1a0f977b9a8ffe6f8e002bf6c9d07b918c"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a2a431ee15799d6db9fe80c82b055bae5a752bef645bba795e8e52687c69efe3"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:281c945d0e28c92ca5e5930664c1cefd85efe80e5c0d2bc58dd63383fda29f83"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:f18c716be16bc1fea8e95def49edf46b82fccaa88587a45f8dc0ff6ab5d8e0a7"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4a02ded6cd4f0a5562a8887df8b3bd14e822a90f97ac5e544c162899bc467664"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:53a583b6637ab4c4e3591a15bc9db855b8d9dee9a669b550f311480acab6eb08"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1ec0bcf7e17c0c5669d881b1cd38c4972fade441b27bda1051665faaa89bdcaa"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2ce6fae5bdad59577b44e4dfed356944fbf1d925269114c28be377692643b4ff"}, + {file = "cryptography-43.0.3.tar.gz", hash = "sha256:315b9001266a492a6ff443b61238f956b214dbec9910a081ba5b6646a055a805"}, +] + +[package.dependencies] +cffi = {version = ">=1.12", markers = "platform_python_implementation != \"PyPy\""} + +[package.extras] +docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"] +docstest = ["pyenchant (>=1.6.11)", "readme-renderer", "sphinxcontrib-spelling (>=4.0.1)"] +nox = ["nox"] +pep8test = ["check-sdist", "click", "mypy", "ruff"] +sdist = ["build"] +ssh = ["bcrypt (>=3.1.5)"] +test = ["certifi", "cryptography-vectors (==43.0.3)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] +test-randomorder = ["pytest-randomly"] + [[package]] name = "cymem" version = "2.0.11" @@ -1013,6 +1043,19 @@ traitlets = ">=5.3" docs = ["intersphinx-registry", "myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-spelling", "traitlets"] test = ["ipykernel", "pre-commit", "pytest (<9)", "pytest-cov", "pytest-timeout"] +[[package]] +name = "jwt" +version = "1.3.1" +description = "JSON Web Token library for Python 3." +optional = false +python-versions = ">= 3.6" +files = [ + {file = "jwt-1.3.1-py3-none-any.whl", hash = "sha256:61c9170f92e736b530655e75374681d4fcca9cfa8763ab42be57353b2b203494"}, +] + +[package.dependencies] +cryptography = ">=3.1,<3.4.0 || >3.4.0" + [[package]] name = "langcodes" version = "3.5.0" @@ -1693,20 +1736,6 @@ files = [ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] -[[package]] -name = "outcome" -version = "1.3.0.post0" -description = "Capture the outcome of Python function calls." -optional = false -python-versions = ">=3.7" -files = [ - {file = "outcome-1.3.0.post0-py2.py3-none-any.whl", hash = "sha256:e771c5ce06d1415e356078d3bdd68523f284b4ce5419828922b6871e65eda82b"}, - {file = "outcome-1.3.0.post0.tar.gz", hash = "sha256:9dcf02e65f2971b80047b377468e72a268e15c0af3cf1238e6ff14f7f91143b8"}, -] - -[package.dependencies] -attrs = ">=19.2.0" - [[package]] name = "packaging" version = "25.0" @@ -2201,19 +2230,22 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""} dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"] [[package]] -name = "pytest-anyio" -version = "0.0.0" -description = "The pytest anyio plugin is built into anyio. You don't need this package." +name = "pytest-asyncio" +version = "0.24.0" +description = "Pytest support for asyncio" optional = false -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "pytest-anyio-0.0.0.tar.gz", hash = "sha256:b41234e9e9ad7ea1dbfefcc1d6891b23d5ef7c9f07ccf804c13a9cc338571fd3"}, - {file = "pytest_anyio-0.0.0-py2.py3-none-any.whl", hash = "sha256:dc8b5c4741cb16ff90be37fddd585ca943ed12bbeb563de7ace6cd94441d8746"}, + {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"}, + {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"}, ] [package.dependencies] -anyio = "*" -pytest = "*" +pytest = ">=8.2,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] name = "python-dateutil" @@ -2570,18 +2602,18 @@ files = [ [[package]] name = "requests" -version = "2.32.3" +version = "2.32.4" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" files = [ - {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, - {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, + {file = "requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c"}, + {file = "requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422"}, ] [package.dependencies] certifi = ">=2017.4.17" -charset-normalizer = ">=2,<4" +charset_normalizer = ">=2,<4" idna = ">=2.5,<4" urllib3 = ">=1.21.1,<3" @@ -2712,17 +2744,6 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] -[[package]] -name = "sortedcontainers" -version = "2.4.0" -description = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set" -optional = false -python-versions = "*" -files = [ - {file = "sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0"}, - {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"}, -] - [[package]] name = "spacy" version = "3.8.7" @@ -3118,26 +3139,6 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"] -[[package]] -name = "trio" -version = "0.25.1" -description = "A friendly Python library for async concurrency and I/O" -optional = false -python-versions = ">=3.8" -files = [ - {file = "trio-0.25.1-py3-none-any.whl", hash = "sha256:e42617ba091e7b2e50c899052e83a3c403101841de925187f61e7b7eaebdf3fb"}, - {file = "trio-0.25.1.tar.gz", hash = "sha256:9f5314f014ea3af489e77b001861c535005c3858d38ec46b6b071ebfa339d7fb"}, -] - -[package.dependencies] -attrs = ">=23.2.0" -cffi = {version = ">=1.14", markers = "os_name == \"nt\" and implementation_name != \"pypy\""} -exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} -idna = "*" -outcome = "*" -sniffio = ">=1.3.0" -sortedcontainers = "*" - [[package]] name = "typer" version = "0.16.0" @@ -3423,13 +3424,13 @@ files = [ [[package]] name = "zipp" -version = "3.22.0" +version = "3.23.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.9" files = [ - {file = "zipp-3.22.0-py3-none-any.whl", hash = "sha256:fe208f65f2aca48b81f9e6fd8cf7b8b32c26375266b009b413d45306b6148343"}, - {file = "zipp-3.22.0.tar.gz", hash = "sha256:dd2f28c3ce4bc67507bfd3781d21b7bb2be31103b51a4553ad7d90b84e57ace5"}, + {file = "zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e"}, + {file = "zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166"}, ] [package.extras] @@ -3437,10 +3438,10 @@ check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] -test = ["big-O", "importlib_resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "da53bb58ad4735ea5fb701ffa281813c23ae66363f9456da5b2fc6da1573b771" +content-hash = "9393b4f9f835e103b8fc351d583cbbe4fdf8592ae600b5d7f82e75dedd77a256" diff --git a/pyproject.toml b/pyproject.toml index 2ded2331..c2a679f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,6 @@ jinja2 = "^3.1.3" fastapi = "^0.115.3" starlette = ">=0.40.0,<0.42.0" uvicorn = "^0.24.0" -requests = "^2.31.0" httpx = "^0.27.0" spyne = "^2.14.0" lxml = "^5.2.2" @@ -45,13 +44,13 @@ fhir-resources = "^8.0.0" python-liquid = "^1.13.0" regex = "!=2019.12.17" fastapi-events = "^0.12.2" +jwt = "^1.3.1" [tool.poetry.group.dev.dependencies] ruff = "^0.4.2" pytest = "^8.2.0" pre-commit = "^3.5.0" -pytest-anyio = "^0.0.0" -trio = "^0.25.0" +pytest-asyncio = "^0.24.0" ipykernel = "^6.29.5" [tool.poetry.group.docs.dependencies] diff --git a/tests/gateway/test_api_app.py b/tests/gateway/test_api_app.py index f93c8fbc..06556f13 100644 --- a/tests/gateway/test_api_app.py +++ b/tests/gateway/test_api_app.py @@ -1,19 +1,13 @@ -""" -Tests for the HealthChainAPI class with dependency injection. - -This module contains tests for the HealthChainAPI class, focusing on -testing with dependency injection. -""" +"""Tests for the HealthChainAPI class.""" import pytest from unittest.mock import AsyncMock -from fastapi import Depends, APIRouter, HTTPException +from fastapi import Depends, HTTPException from fastapi.testclient import TestClient -from fastapi.responses import JSONResponse +from fastapi.exceptions import RequestValidationError -from healthchain.gateway.api.app import create_app, HealthChainAPI +from healthchain.gateway.api.app import HealthChainAPI from healthchain.gateway.api.dependencies import ( - get_app, get_event_dispatcher, get_gateway, get_all_gateways, @@ -22,49 +16,23 @@ from healthchain.gateway.core.base import BaseGateway -# Custom create_app function for testing -def create_app_for_testing(enable_events=True, event_dispatcher=None, app_class=None): - """Create a test app with optional custom app class.""" - if app_class is None: - # Use the default HealthChainAPI class - return create_app( - enable_events=enable_events, event_dispatcher=event_dispatcher - ) - - # Use a custom app class - app_config = { - "title": "Test HealthChain API", - "description": "Test API", - "version": "0.1.0", - "docs_url": "/docs", - "redoc_url": "/redoc", - "enable_events": enable_events, - "event_dispatcher": event_dispatcher, - } - return app_class(**app_config) - - class MockGateway(BaseGateway): """Mock gateway for testing.""" def __init__(self, **kwargs): super().__init__(**kwargs) self.name = "MockGateway" - self.event_dispatcher = None - - def get_metadata(self): - return {"type": "mock", "version": "1.0.0"} + self.startup_called = False + self.shutdown_called = False - def set_event_dispatcher(self, dispatcher): - self.event_dispatcher = dispatcher + async def startup(self): + self.startup_called = True + async def shutdown(self): + self.shutdown_called = True -class AnotherMockGateway(BaseGateway): - """Another mock gateway for testing.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.name = "AnotherMockGateway" + def get_metadata(self): + return {"type": "mock", "version": "1.0.0"} class MockEventDispatcher(EventDispatcher): @@ -79,216 +47,132 @@ def init_app(self, app): @pytest.fixture -def mock_event_dispatcher(): - """Create a mock event dispatcher.""" +def mock_dispatcher(): return MockEventDispatcher() @pytest.fixture def mock_gateway(): - """Create a mock gateway.""" return MockGateway() @pytest.fixture -def test_app(mock_event_dispatcher, mock_gateway): - """Create a test app with mocked dependencies.""" - - # Create a test subclass that overrides _shutdown to avoid termination - class SafeHealthChainAPI(HealthChainAPI): - def _shutdown(self): - # Override to avoid termination - return JSONResponse(content={"message": "Server is shutting down..."}) - - # Create the app with the safe implementation - app = create_app_for_testing( +def app(mock_dispatcher, mock_gateway): + """Create test app with mocked dependencies.""" + app = HealthChainAPI( + title="Test API", + version="0.1.0", enable_events=True, - event_dispatcher=mock_event_dispatcher, - app_class=SafeHealthChainAPI, + event_dispatcher=mock_dispatcher, ) app.register_gateway(mock_gateway) return app @pytest.fixture -def client(test_app): - """Create a test client.""" - return TestClient(test_app) +def client(app): + return TestClient(app) -def test_app_creation(): - """Test that the app can be created with custom dependencies.""" - - # Create a test subclass that overrides _shutdown to avoid termination - class SafeHealthChainAPI(HealthChainAPI): - def _shutdown(self): - # Override to avoid termination - return JSONResponse(content={"message": "Server is shutting down..."}) - - mock_dispatcher = MockEventDispatcher() - app = create_app_for_testing( - enable_events=True, - event_dispatcher=mock_dispatcher, - app_class=SafeHealthChainAPI, - ) - +def test_app_creation(mock_dispatcher): + """Test app creation with custom dependencies.""" + app = HealthChainAPI(enable_events=True, event_dispatcher=mock_dispatcher) assert app.get_event_dispatcher() is mock_dispatcher assert app.enable_events is True -def test_dependency_injection_get_app(test_app): - """Test that get_app dependency returns the app.""" - # Override dependency to return our test app - test_app.dependency_overrides[get_app] = lambda: test_app - - with TestClient(test_app) as client: - response = client.get("/health") - assert response.status_code == 200 - +def test_lifespan_startup_shutdown(): + """Test lifespan events call startup and shutdown.""" + gateway = MockGateway() + app = HealthChainAPI() + app.register_gateway(gateway) -def test_dependency_injection_event_dispatcher(test_app, mock_event_dispatcher): - """Test that get_event_dispatcher dependency returns the event dispatcher.""" + with TestClient(app) as client: + assert gateway.startup_called + assert client.get("/health").status_code == 200 - # Create a test route that uses the dependency - @test_app.get("/test-event-dispatcher") - def test_route(dispatcher=Depends(get_event_dispatcher)): - assert dispatcher is mock_event_dispatcher - return {"success": True} + assert gateway.shutdown_called - with TestClient(test_app) as client: - response = client.get("/test-event-dispatcher") - assert response.status_code == 200 - assert response.json() == {"success": True} +def test_dependency_injection(app, mock_dispatcher, mock_gateway): + """Test dependency injection works correctly.""" -def test_dependency_injection_gateway(test_app, mock_gateway): - """Test that get_gateway dependency returns the gateway.""" + @app.get("/test-dispatcher") + def test_dispatcher(dispatcher=Depends(get_event_dispatcher)): + assert dispatcher is mock_dispatcher + return {"success": True} - # Create a test route that uses the dependency - @test_app.get("/test-gateway/{gateway_name}") - def test_route(gateway_name: str, gateway=Depends(get_gateway)): + @app.get("/test-gateway") + def test_gateway(gateway=Depends(get_gateway)): assert gateway is mock_gateway return {"success": True} - with TestClient(test_app) as client: - response = client.get("/test-gateway/MockGateway") - assert response.status_code == 200 - assert response.json() == {"success": True} - - -def test_dependency_injection_all_gateways(test_app, mock_gateway): - """Test that get_all_gateways dependency returns all gateways.""" - - # Create a test route that uses the dependency - @test_app.get("/test-all-gateways") - def test_route(gateways=Depends(get_all_gateways)): + @app.get("/test-all-gateways") + def test_all_gateways(gateways=Depends(get_all_gateways)): assert "MockGateway" in gateways assert gateways["MockGateway"] is mock_gateway return {"success": True} - with TestClient(test_app) as client: - response = client.get("/test-all-gateways") - assert response.status_code == 200 - assert response.json() == {"success": True} - - -def test_root_endpoint(client): - """Test the root endpoint returns gateway information.""" - response = client.get("/") - assert response.status_code == 200 - assert "MockGateway" in response.json()["gateways"] - - -def test_metadata_endpoint(client): - """Test the metadata endpoint returns gateway information.""" - response = client.get("/metadata") - assert response.status_code == 200 + with TestClient(app) as client: + assert client.get("/test-dispatcher").json() == {"success": True} + assert client.get("/test-gateway?gateway_name=MockGateway").json() == { + "success": True + } + assert client.get("/test-all-gateways").json() == {"success": True} - data = response.json() - assert data["resourceType"] == "CapabilityStatement" - assert "MockGateway" in data["gateways"] - assert data["gateways"]["MockGateway"]["type"] == "mock" +def test_endpoints(client): + """Test default API endpoints.""" + # Root endpoint + root = client.get("/").json() + assert "MockGateway" in root["gateways"] -def test_register_gateway(test_app): - """Test registering a gateway.""" - # Create a gateway instance - another_gateway = AnotherMockGateway() + # Health endpoint + assert client.get("/health").json() == {"status": "healthy"} - # Register it with the app - test_app.register_gateway(another_gateway) + # Metadata endpoint + metadata = client.get("/metadata").json() + assert metadata["resourceType"] == "CapabilityStatement" + assert "MockGateway" in metadata["gateways"] - # Verify it was registered - assert "AnotherMockGateway" in test_app.gateways - assert test_app.gateways["AnotherMockGateway"] is another_gateway +def test_register_gateway(app): + """Test gateway registration.""" -def test_register_router(test_app): - """Test registering a router.""" - # Create a router - router = APIRouter(prefix="/test-router", tags=["test"]) - - @router.get("/test") - def test_route(): - return {"message": "Router test"} - - # Register the router - test_app.register_router(router) + class TestGateway(BaseGateway): + pass - # Test the route - with TestClient(test_app) as client: - response = client.get("/test-router/test") - assert response.status_code == 200 - assert response.json() == {"message": "Router test"} + gateway = TestGateway() + app.register_gateway(gateway) + assert "TestGateway" in app.gateways -def test_exception_handling(test_app): - """Test the exception handling middleware.""" +def test_exception_handling(app): + """Test unified exception handling.""" - # Add a route that raises an exception - @test_app.get("/test-error") - def error_route(): + @app.get("/http-error") + def http_error(): raise HTTPException(status_code=400, detail="Test error") - # Add a route that raises an unexpected exception - @test_app.get("/test-unexpected-error") - def unexpected_error_route(): - raise ValueError("Unexpected test error") + @app.get("/validation-error") + def validation_error(): + raise RequestValidationError([{"msg": "test validation error"}]) - with TestClient(test_app) as client: - # Test HTTP exception handling - response = client.get("/test-error") + with TestClient(app) as client: + # HTTP exception + response = client.get("/http-error") assert response.status_code == 400 assert response.json() == {"detail": "Test error"} - # Test unexpected exception handling - with pytest.raises(ValueError): - response = client.get("/test-unexpected-error") - assert response.status_code == 500 - assert response.json() == {"detail": "Internal server error"} - - -def test_gateway_event_dispatcher_integration(mock_event_dispatcher): - """Test that gateways receive the event dispatcher when registered.""" + # Validation exception + response = client.get("/validation-error") + assert response.status_code == 422 + assert "detail" in response.json() - # Create a test subclass that overrides _shutdown to avoid termination - class SafeHealthChainAPI(HealthChainAPI): - def _shutdown(self): - # Override to avoid termination - return JSONResponse(content={"message": "Server is shutting down..."}) - # Create a gateway +def test_event_dispatcher_integration(mock_dispatcher): + """Test gateway receives event dispatcher.""" gateway = MockGateway() - - # Create app with events enabled - app = create_app_for_testing( - enable_events=True, - event_dispatcher=mock_event_dispatcher, - app_class=SafeHealthChainAPI, - ) - - # Register gateway + app = HealthChainAPI(enable_events=True, event_dispatcher=mock_dispatcher) app.register_gateway(gateway) - - # Check that gateway received the event dispatcher - assert gateway.event_dispatcher is mock_event_dispatcher + assert gateway.events.dispatcher is mock_dispatcher diff --git a/tests/gateway/test_auth.py b/tests/gateway/test_auth.py new file mode 100644 index 00000000..3532b646 --- /dev/null +++ b/tests/gateway/test_auth.py @@ -0,0 +1,424 @@ +""" +Tests for the OAuth2 authentication module in the HealthChain gateway system. + +This module tests OAuth2 token management, configuration, and connection string parsing. +""" + +import pytest +import tempfile +import os +from unittest.mock import patch, Mock +from datetime import datetime, timedelta + +from healthchain.gateway.clients.auth import ( + OAuth2Config, + TokenInfo, + OAuth2TokenManager, + FHIRAuthConfig, + parse_fhir_auth_connection_string, +) + +# Configure pytest-asyncio for async tests only (sync tests don't need the mark) + + +@pytest.fixture +def oauth2_config(): + """Create a basic OAuth2 configuration for testing.""" + return OAuth2Config( + client_id="test_client", + client_secret="test_secret", + token_url="https://example.com/oauth/token", + scope="system/*.read", + audience="https://example.com/fhir", + ) + + +@pytest.fixture +def oauth2_config_jwt(): + """Create an OAuth2 configuration for JWT assertion testing.""" + return OAuth2Config( + client_id="test_client", + client_secret_path="/path/to/private.pem", + token_url="https://example.com/oauth/token", + scope="system/*.read", + audience="https://example.com/fhir", + use_jwt_assertion=True, + ) + + +@pytest.fixture +def token_manager(oauth2_config): + """Create an OAuth2TokenManager for testing.""" + return OAuth2TokenManager(oauth2_config) + + +@pytest.fixture +def token_manager_jwt(oauth2_config_jwt): + """Create an OAuth2TokenManager for JWT testing.""" + return OAuth2TokenManager(oauth2_config_jwt) + + +@pytest.fixture +def mock_token_response(): + """Create a mock token response.""" + return { + "access_token": "test_access_token", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "system/*.read", + } + + +@pytest.fixture +def temp_key_file(): + """Create a temporary private key file for testing.""" + key_content = """-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC4f6a8v... +-----END PRIVATE KEY-----""" + + with tempfile.NamedTemporaryFile(mode="w", suffix=".pem", delete=False) as f: + f.write(key_content) + temp_path = f.name + + yield temp_path + + # Cleanup + os.unlink(temp_path) + + +# Core Validation Tests +@pytest.mark.parametrize( + "config_args,expected_error", + [ + # Missing both secrets + ( + {"client_id": "test", "token_url": "https://example.com/token"}, + "Either client_secret or client_secret_path must be provided", + ), + # Both secrets provided + ( + { + "client_id": "test", + "client_secret": "secret", + "client_secret_path": "/path", + "token_url": "https://example.com/token", + }, + "Cannot provide both client_secret and client_secret_path", + ), + # JWT without path + ( + { + "client_id": "test", + "client_secret": "secret", + "token_url": "https://example.com/token", + "use_jwt_assertion": True, + }, + "use_jwt_assertion=True requires client_secret_path to be set", + ), + # Path without JWT + ( + { + "client_id": "test", + "client_secret_path": "/path", + "token_url": "https://example.com/token", + "use_jwt_assertion": False, + }, + "client_secret_path can only be used with use_jwt_assertion=True", + ), + ], +) +def test_oauth2_config_validation_rules(config_args, expected_error): + """OAuth2Config enforces validation rules for secret configuration.""" + with pytest.raises(ValueError, match=expected_error): + OAuth2Config(**config_args) + + +def test_oauth2_config_secret_value_reads_from_file(temp_key_file): + """OAuth2Config reads secret from file when client_secret_path is provided.""" + config = OAuth2Config( + client_id="test_client", + client_secret_path=temp_key_file, + token_url="https://example.com/token", + use_jwt_assertion=True, + ) + secret_value = config.secret_value + assert "BEGIN PRIVATE KEY" in secret_value + assert "END PRIVATE KEY" in secret_value + + +def test_oauth2_config_secret_value_handles_file_errors(): + """OAuth2Config raises clear error when file cannot be read.""" + config = OAuth2Config( + client_id="test_client", + client_secret_path="/nonexistent/file.pem", + token_url="https://example.com/token", + use_jwt_assertion=True, + ) + with pytest.raises(ValueError, match="Failed to read secret from"): + _ = config.secret_value + + +# Token Management Core Tests +def test_token_info_expiration_logic(): + """TokenInfo correctly calculates expiration with buffer.""" + # Test near-expiry with buffer + near_expiry_token = TokenInfo( + access_token="test_token", + expires_in=240, + expires_at=datetime.now() + timedelta(minutes=4), + ) + assert near_expiry_token.is_expired( + buffer_seconds=300 + ) # 5 min buffer, expires in 4 + assert not near_expiry_token.is_expired( + buffer_seconds=120 + ) # 2 min buffer, expires in 4 + + +@pytest.mark.asyncio +@patch("httpx.AsyncClient.post") +async def test_oauth2_token_manager_standard_flow( + mock_post, token_manager, mock_token_response +): + """OAuth2TokenManager performs standard client credentials flow correctly.""" + # Mock successful response + mock_response = Mock() + mock_response.json.return_value = mock_token_response + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + token = await token_manager.get_access_token() + + # Verify token returned + assert token == "test_access_token" + + # Verify correct request data for standard flow + call_args = mock_post.call_args + request_data = call_args[1]["data"] + assert request_data["grant_type"] == "client_credentials" + assert request_data["client_id"] == "test_client" + assert request_data["client_secret"] == "test_secret" + assert "client_assertion" not in request_data + + +@pytest.mark.asyncio +@patch("healthchain.gateway.clients.auth.OAuth2TokenManager._create_jwt_assertion") +@patch("httpx.AsyncClient.post") +async def test_oauth2_token_manager_jwt_flow( + mock_post, mock_create_jwt, token_manager_jwt, mock_token_response +): + """OAuth2TokenManager performs JWT assertion flow correctly.""" + mock_create_jwt.return_value = "mock_jwt_assertion" + + # Mock successful response + mock_response = Mock() + mock_response.json.return_value = mock_token_response + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + token = await token_manager_jwt.get_access_token() + assert token == "test_access_token" + + # Verify JWT-specific request data + call_args = mock_post.call_args + request_data = call_args[1]["data"] + assert request_data["grant_type"] == "client_credentials" + assert ( + request_data["client_assertion_type"] + == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + ) + assert request_data["client_assertion"] == "mock_jwt_assertion" + assert "client_secret" not in request_data + + +@pytest.mark.asyncio +@patch("httpx.AsyncClient.post") +async def test_oauth2_token_manager_caching_and_refresh( + mock_post, token_manager, mock_token_response +): + """OAuth2TokenManager caches valid tokens and refreshes expired ones.""" + # Set up valid cached token + token_manager._token = TokenInfo( + access_token="cached_token", + expires_in=3600, + expires_at=datetime.now() + timedelta(hours=1), + ) + + # Should use cached token + token = await token_manager.get_access_token() + assert token == "cached_token" + mock_post.assert_not_called() + + # Set expired token + token_manager._token = TokenInfo( + access_token="expired_token", + expires_in=3600, + expires_at=datetime.now() - timedelta(minutes=10), + ) + + # Mock refresh response + mock_response = Mock() + mock_response.json.return_value = mock_token_response + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # Should refresh token + token = await token_manager.get_access_token() + assert token == "test_access_token" + mock_post.assert_called_once() + + +@pytest.mark.asyncio +@patch("httpx.AsyncClient.post") +async def test_oauth2_token_manager_error_handling(mock_post, token_manager): + """OAuth2TokenManager handles HTTP errors gracefully.""" + from httpx import HTTPStatusError, Request + + mock_response = Mock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + + mock_post.side_effect = HTTPStatusError( + "401 Unauthorized", request=Mock(spec=Request), response=mock_response + ) + + with pytest.raises(Exception, match="Failed to refresh token: 401"): + await token_manager.get_access_token() + + +@patch("jwt.JWT.encode") +@patch("jwt.jwk_from_pem") +def test_oauth2_token_manager_jwt_assertion_creation( + mock_jwk_from_pem, mock_jwt_encode, token_manager_jwt, temp_key_file +): + """OAuth2TokenManager creates valid JWT assertions with correct claims.""" + token_manager_jwt.config.client_secret_path = temp_key_file + + mock_key = Mock() + mock_jwk_from_pem.return_value = mock_key + mock_jwt_encode.return_value = "signed_jwt_token" + + jwt_assertion = token_manager_jwt._create_jwt_assertion() + + assert jwt_assertion == "signed_jwt_token" + + # Verify JWT claims structure + call_args = mock_jwt_encode.call_args[0] + claims = call_args[0] + assert claims["iss"] == "test_client" + assert claims["sub"] == "test_client" + assert claims["aud"] == "https://example.com/oauth/token" + assert "jti" in claims + assert "iat" in claims + assert "exp" in claims + + +# FHIR Config Tests (Core Validation Only) +def test_fhir_auth_config_validation_mirrors_oauth2_config(): + """FHIRAuthConfig enforces same validation rules as OAuth2Config.""" + # Should fail with same validation error + with pytest.raises( + ValueError, match="Either client_secret or client_secret_path must be provided" + ): + FHIRAuthConfig( + client_id="test_client", + token_url="https://example.com/token", + base_url="https://example.com/fhir/R4", + ) + + +def test_fhir_auth_config_to_oauth2_config_conversion(): + """FHIRAuthConfig correctly converts to OAuth2Config preserving all auth settings.""" + fhir_config = FHIRAuthConfig( + client_id="test_client", + client_secret_path="/path/to/private.pem", + token_url="https://example.com/token", + base_url="https://example.com/fhir/R4", + use_jwt_assertion=True, + scope="custom_scope", + audience="custom_audience", + ) + + oauth2_config = fhir_config.to_oauth2_config() + + # Verify auth-related fields are preserved + assert oauth2_config.client_id == fhir_config.client_id + assert oauth2_config.client_secret_path == fhir_config.client_secret_path + assert oauth2_config.token_url == fhir_config.token_url + assert oauth2_config.use_jwt_assertion == fhir_config.use_jwt_assertion + assert oauth2_config.scope == fhir_config.scope + assert oauth2_config.audience == fhir_config.audience + + +# Connection String Parsing Tests (Core Functionality) +@pytest.mark.parametrize( + "connection_string,expected_error", + [ + # Invalid scheme + ("invalid://not-fhir", "Connection string must start with fhir://"), + # Missing required params + ( + "fhir://example.com/fhir/R4?client_id=test_client", + "Missing required parameters", + ), + # Missing secrets + ( + "fhir://example.com/fhir/R4?client_id=test&token_url=https://example.com/token", + "Either 'client_secret' or 'client_secret_path' parameter must be provided", + ), + # Both secrets + ( + "fhir://example.com/fhir/R4?client_id=test&client_secret=secret&client_secret_path=/path&token_url=https://example.com/token", + "Cannot provide both 'client_secret' and 'client_secret_path' parameters", + ), + ], +) +def test_connection_string_parsing_validation(connection_string, expected_error): + """Connection string parsing enforces validation rules.""" + with pytest.raises(ValueError, match=expected_error): + parse_fhir_auth_connection_string(connection_string) + + +def test_connection_string_parsing_handles_both_auth_types(): + """Connection string parsing correctly handles both standard and JWT authentication.""" + # Standard auth + standard_string = "fhir://example.com/fhir/R4?client_id=test&client_secret=secret&token_url=https://example.com/token" + standard_config = parse_fhir_auth_connection_string(standard_string) + assert standard_config.client_secret == "secret" + assert standard_config.client_secret_path is None + assert not standard_config.use_jwt_assertion + + # JWT auth + jwt_string = ( + "fhir://example.com/fhir/R4?client_id=test&client_secret_path=/path/key.pem&" + "token_url=https://example.com/token&use_jwt_assertion=true" + ) + jwt_config = parse_fhir_auth_connection_string(jwt_string) + assert jwt_config.client_secret is None + assert jwt_config.client_secret_path == "/path/key.pem" + assert jwt_config.use_jwt_assertion + + +def test_connection_string_parsing_handles_complex_parameters(): + """Connection string parsing correctly handles all parameters and URL encoding.""" + connection_string = ( + "fhir://example.com:8080/fhir/R4?" + "client_id=test%20client&" + "client_secret=test%20secret&" + "token_url=https%3A//example.com/token&" + "scope=system%2F*.read&" + "audience=https://example.com/fhir&" + "timeout=60&" + "verify_ssl=false" + ) + + config = parse_fhir_auth_connection_string(connection_string) + + assert config.client_id == "test client" # URL decoded + assert config.client_secret == "test secret" # URL decoded + assert config.token_url == "https://example.com/token" # URL decoded + assert config.scope == "system/*.read" # URL decoded + assert config.base_url == "https://example.com:8080/fhir/R4" + assert config.audience == "https://example.com/fhir" + assert config.timeout == 60 + assert not config.verify_ssl diff --git a/tests/gateway/test_cdshooks.py b/tests/gateway/test_cdshooks.py index a1c6cf20..4ef07fc4 100644 --- a/tests/gateway/test_cdshooks.py +++ b/tests/gateway/test_cdshooks.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock from healthchain.gateway.protocols.cdshooks import ( - CDSHooksGateway, + CDSHooksService, CDSHooksConfig, ) from healthchain.gateway.events.dispatcher import EventDispatcher @@ -11,70 +11,90 @@ from healthchain.models.responses.cdsdiscovery import CDSServiceInformation -def test_cdshooks_gateway_initialization(): - """Test CDSHooksGateway initialization with default config""" - gateway = CDSHooksGateway() +@pytest.mark.parametrize( + "config_args,expected_paths", + [ + # Default config + ( + {}, + { + "base_path": "/cds", + "discovery_path": "/cds-discovery", + "service_path": "/cds-services", + }, + ), + # Custom config + ( + { + "base_path": "/custom-cds", + "discovery_path": "/custom-discovery", + "service_path": "/custom-services", + }, + { + "base_path": "/custom-cds", + "discovery_path": "/custom-discovery", + "service_path": "/custom-services", + }, + ), + ], +) +def test_cdshooks_service_configuration(config_args, expected_paths): + """CDSHooksService supports both default and custom path configurations.""" + if config_args: + config = CDSHooksConfig(**config_args) + gateway = CDSHooksService(config=config) + else: + gateway = CDSHooksService.create() + + assert isinstance(gateway, CDSHooksService) assert isinstance(gateway.config, CDSHooksConfig) assert gateway.config.system_type == "CDS-HOOKS" - assert gateway.config.base_path == "/cds" - assert gateway.config.discovery_path == "/cds-discovery" - assert gateway.config.service_path == "/cds-services" - -def test_cdshooks_gateway_create(): - """Test CDSHooksGateway.create factory method""" - gateway = CDSHooksGateway.create() - assert isinstance(gateway, CDSHooksGateway) - assert isinstance(gateway.config, CDSHooksConfig) + for path_name, expected_value in expected_paths.items(): + assert getattr(gateway.config, path_name) == expected_value -def test_cdshooks_gateway_hook_decorator(): - """Test hook decorator for registering handlers""" - gateway = CDSHooksGateway() +def test_cdshooks_hook_decorator_with_metadata_variants(): + """Hook decorator supports default and custom metadata configurations.""" + gateway = CDSHooksService() + # Default metadata @gateway.hook("patient-view", id="test-patient-view") - def handle_patient_view(request): + def handle_patient_view_default(request): return CDSResponse(cards=[]) - # Verify handler is registered - assert "patient-view" in gateway._handlers - assert "patient-view" in gateway._handler_metadata - assert gateway._handler_metadata["patient-view"]["id"] == "test-patient-view" - assert gateway._handler_metadata["patient-view"]["title"] == "Patient View" - assert ( - gateway._handler_metadata["patient-view"]["description"] - == "CDS Hook service created by HealthChain" - ) - - -def test_cdshooks_gateway_hook_with_custom_metadata(): - """Test hook decorator with custom metadata""" - gateway = CDSHooksGateway() - + # Custom metadata @gateway.hook( - "patient-view", + "order-select", id="custom-id", title="Custom Title", description="Custom description", usage_requirements="Requires patient context", ) - def handle_patient_view(request): + def handle_order_select_custom(request): return CDSResponse(cards=[]) - assert gateway._handler_metadata["patient-view"]["id"] == "custom-id" - assert gateway._handler_metadata["patient-view"]["title"] == "Custom Title" - assert ( - gateway._handler_metadata["patient-view"]["description"] == "Custom description" - ) - assert ( - gateway._handler_metadata["patient-view"]["usage_requirements"] - == "Requires patient context" - ) + # Verify both handlers registered correctly + assert "patient-view" in gateway._handlers + assert "order-select" in gateway._handlers + + # Check default metadata + default_meta = gateway._handler_metadata["patient-view"] + assert default_meta["id"] == "test-patient-view" + assert default_meta["title"] == "Patient View" + assert default_meta["description"] == "CDS Hook service created by HealthChain" + + # Check custom metadata + custom_meta = gateway._handler_metadata["order-select"] + assert custom_meta["id"] == "custom-id" + assert custom_meta["title"] == "Custom Title" + assert custom_meta["description"] == "Custom description" + assert custom_meta["usage_requirements"] == "Requires patient context" def test_cdshooks_gateway_handle_request(test_cds_request): """Test request handler endpoint""" - gateway = CDSHooksGateway() + gateway = CDSHooksService() # Register a handler with the hook decorator @gateway.hook("patient-view", id="test-patient-view") @@ -96,7 +116,7 @@ def handle_patient_view(request): def test_cdshooks_gateway_handle_discovery(): """Test discovery endpoint handler""" - gateway = CDSHooksGateway() + gateway = CDSHooksService() # Register sample hooks @gateway.hook("patient-view", id="test-patient-view", title="Patient View") @@ -123,56 +143,58 @@ def handle_order_select(request): assert hooks["order-select"].title == "Order Select" -def test_cdshooks_gateway_get_routes(): - """Test that CDSHooksGateway correctly returns routes with get_routes method""" - gateway = CDSHooksGateway() +def test_cdshooks_gateway_routing_and_custom_paths(): + """CDSHooksService works as APIRouter with correct route registration.""" + # Test default paths + gateway = CDSHooksService() - # Register sample hooks @gateway.hook("patient-view", id="test-patient-view") def handle_patient_view(request): return CDSResponse(cards=[]) - # Get routes from gateway - routes = gateway.get_routes() + # Verify gateway is now an APIRouter + from fastapi import APIRouter - # Should return at least 2 routes (discovery endpoint and hook endpoint) - assert len(routes) >= 2 + assert isinstance(gateway, APIRouter) - # Verify discovery endpoint - discovery_routes = [r for r in routes if "GET" in r[1]] - assert len(discovery_routes) >= 1 - discovery_route = discovery_routes[0] - assert discovery_route[1] == ["GET"] # HTTP method is GET + # Verify routes are registered directly in the router + assert hasattr(gateway, "routes") + assert len(gateway.routes) >= 2 - # Verify hook endpoint - hook_routes = [r for r in routes if "POST" in r[1]] - assert len(hook_routes) >= 1 - hook_route = hook_routes[0] - assert hook_route[1] == ["POST"] # HTTP method is POST - assert "test-patient-view" in hook_route[0] # Route path contains hook ID + # Check that routes have been registered + route_paths = [route.path for route in gateway.routes] + route_methods = [list(route.methods)[0] for route in gateway.routes] + # Should have discovery endpoint + assert any("cds-discovery" in path for path in route_paths) + assert "GET" in route_methods -def test_cdshooks_gateway_custom_base_path(): - """Test CDSHooksGateway with custom base path""" - config = CDSHooksConfig( + # Should have hook service endpoint + assert any("test-patient-view" in path for path in route_paths) + assert "POST" in route_methods + + # Test custom paths + custom_config = CDSHooksConfig( base_path="/custom-cds", discovery_path="/custom-discovery", service_path="/custom-services", ) - gateway = CDSHooksGateway(config=config) + custom_gateway = CDSHooksService(config=custom_config) - @gateway.hook("patient-view", id="test-service") - def handle_patient_view(request): + @custom_gateway.hook("patient-view", id="test-service") + def handle_custom_patient_view(request): return CDSResponse(cards=[]) - routes = gateway.get_routes() + # Verify custom gateway has correct prefix + assert custom_gateway.prefix == "/custom-cds" - # Check that custom paths are used in routes - discovery_route = [r for r in routes if "GET" in r[1]][0] - assert discovery_route[0] == "/custom-cds/custom-discovery" + # Verify routes exist + custom_route_paths = [route.path for route in custom_gateway.routes] + assert any("custom-discovery" in path for path in custom_route_paths) + assert any("test-service" in path for path in custom_route_paths) - service_route = [r for r in routes if "POST" in r[1]][0] - assert "/custom-cds/custom-services/test-service" in service_route[0] + # Verify get_routes() method no longer exists + assert not hasattr(gateway, "get_routes") def test_cdshooks_gateway_event_emission(): @@ -181,7 +203,7 @@ def test_cdshooks_gateway_event_emission(): mock_dispatcher = MagicMock(spec=EventDispatcher) # Create gateway with event dispatcher - gateway = CDSHooksGateway(event_dispatcher=mock_dispatcher) + gateway = CDSHooksService(event_dispatcher=mock_dispatcher) # Register a handler @gateway.hook("patient-view", id="test-service") @@ -202,13 +224,12 @@ def handle_patient_view(request): # Handle the request gateway.handle_request(request) - # Verify event was dispatched - assert mock_dispatcher.publish.called or mock_dispatcher.publish_async.called + assert mock_dispatcher.emit.called def test_cdshooks_gateway_hook_invalid_hook_type(): """Test hook decorator with invalid hook type""" - gateway = CDSHooksGateway() + gateway = CDSHooksService() # Try to register an invalid hook type with pytest.raises(ValueError): @@ -216,63 +237,3 @@ def test_cdshooks_gateway_hook_invalid_hook_type(): @gateway.hook("invalid-hook-type", id="test") def handle_invalid(request): return CDSResponse(cards=[]) - - -def test_cdshooks_gateway_handle_with_direct_request(): - """Test handling a CDSRequest directly with the handle method""" - gateway = CDSHooksGateway() - - # Register a handler - @gateway.hook("patient-view", id="test-service") - def handle_patient_view(request): - return CDSResponse( - cards=[ - Card(summary="Direct test", indicator="info", source={"label": "Test"}) - ] - ) - - # Create a test request - request = CDSRequest( - hook="patient-view", - hookInstance="test-instance", - context={"patientId": "123", "userId": "456"}, - ) - - # Handle the request directly with the handle method - result = gateway.handle("patient-view", request=request) - - # Verify response - assert isinstance(result, CDSResponse) - assert len(result.cards) == 1 - assert result.cards[0].summary == "Direct test" - - -def test_cdshooks_gateway_get_metadata(): - """Test retrieving metadata for registered hooks""" - gateway = CDSHooksGateway() - - # Register handlers with different metadata - @gateway.hook("patient-view", id="patient-service", title="Patient Service") - def handle_patient_view(request): - return CDSResponse(cards=[]) - - @gateway.hook("order-select", id="order-service", description="Custom description") - def handle_order_select(request): - return CDSResponse(cards=[]) - - # Get metadata - metadata = gateway.get_metadata() - - # Verify metadata contains both services - assert len(metadata) == 2 - - # Find each service by hook type - patient_metadata = next(item for item in metadata if item["hook"] == "patient-view") - order_metadata = next(item for item in metadata if item["hook"] == "order-select") - - # Verify metadata values - assert patient_metadata["id"] == "patient-service" - assert patient_metadata["title"] == "Patient Service" - - assert order_metadata["id"] == "order-service" - assert order_metadata["description"] == "Custom description" diff --git a/tests/gateway/test_client_pool.py b/tests/gateway/test_client_pool.py new file mode 100644 index 00000000..55b16081 --- /dev/null +++ b/tests/gateway/test_client_pool.py @@ -0,0 +1,182 @@ +"""Tests for FHIR client connection pooling functionality.""" + +import pytest +from unittest.mock import Mock, AsyncMock + +from healthchain.gateway.clients.pool import FHIRClientPool +from healthchain.gateway.api.protocols import FHIRServerInterfaceProtocol + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def mock_client_factory(): + """Create a mock client factory function.""" + + def factory(connection_string, limits=None): + client = Mock(spec=FHIRServerInterfaceProtocol) + client.close = AsyncMock() + + # Add httpx client attributes for pool stats + client.client = Mock() + client.client._pool = Mock() + available_conn = Mock() + available_conn.is_available.return_value = True + unavailable_conn = Mock() + unavailable_conn.is_available.return_value = False + client.client._pool._pool = [available_conn, unavailable_conn] + + client._limits = limits + return client + + return factory + + +@pytest.fixture +def client_pool(): + """Create a FHIRClientPool for testing.""" + return FHIRClientPool( + max_connections=50, max_keepalive_connections=10, keepalive_expiry=3.0 + ) + + +@pytest.mark.parametrize( + "max_conn,keepalive_conn,expiry", + [ + (200, 50, 10.0), + (100, 20, 5.0), # defaults + ], +) +def test_client_pool_initialization(max_conn, keepalive_conn, expiry): + """FHIRClientPool initializes with custom or default limits.""" + if max_conn == 100: # test defaults + pool = FHIRClientPool() + else: + pool = FHIRClientPool( + max_connections=max_conn, + max_keepalive_connections=keepalive_conn, + keepalive_expiry=expiry, + ) + + assert pool._client_limits.max_connections == max_conn + assert pool._client_limits.max_keepalive_connections == keepalive_conn + assert pool._client_limits.keepalive_expiry == expiry + assert pool._clients == {} + + +async def test_client_creation_and_reuse(client_pool, mock_client_factory): + """FHIRClientPool creates new clients and reuses existing ones.""" + conn1 = "fhir://server1.example.com/R4" + conn2 = "fhir://server2.example.com/R4" + + # Create first client + client1a = await client_pool.get_client(conn1, mock_client_factory) + assert client1a is not None + assert conn1 in client_pool._clients + assert client1a._limits is client_pool._client_limits + + # Reuse same client + client1b = await client_pool.get_client(conn1, mock_client_factory) + assert client1a is client1b + + # Create different client for different connection + client2 = await client_pool.get_client(conn2, mock_client_factory) + assert client1a is not client2 + assert len(client_pool._clients) == 2 + + +async def test_close_all_clients(client_pool, mock_client_factory): + """FHIRClientPool closes all clients and handles missing close methods.""" + conn1 = "fhir://server1.example.com/R4" + conn2 = "fhir://server2.example.com/R4" + + # Create clients + client1 = await client_pool.get_client(conn1, mock_client_factory) + client2 = await client_pool.get_client(conn2, mock_client_factory) + + # Add client without close method + client_without_close = Mock(spec=[]) + client_pool._clients["no_close"] = client_without_close + + # Close all clients + await client_pool.close_all() + + # Verify all clients were closed + client1.close.assert_called_once() + client2.close.assert_called_once() + assert client_pool._clients == {} + + +async def test_pool_stats(client_pool, mock_client_factory): + """FHIRClientPool provides accurate statistics.""" + # Empty pool stats + stats = client_pool.get_pool_stats() + assert stats["total_clients"] == 0 + assert stats["limits"]["max_connections"] == 50 + assert stats["limits"]["max_keepalive_connections"] == 10 + assert stats["limits"]["keepalive_expiry"] == 3.0 + assert stats["clients"] == {} + + # Add clients and check stats + conn1 = "fhir://server1.example.com/R4" + conn2 = "fhir://server2.example.com/R4" + + await client_pool.get_client(conn1, mock_client_factory) + await client_pool.get_client(conn2, mock_client_factory) + + stats = client_pool.get_pool_stats() + assert stats["total_clients"] == 2 + assert conn1 in stats["clients"] + assert conn2 in stats["clients"] + + # Check connection details + client_stats = stats["clients"][conn1] + assert client_stats["active_connections"] == 2 + assert client_stats["available_connections"] == 1 + + +async def test_pool_stats_without_pool_info(client_pool): + """FHIRClientPool handles clients without connection pool info.""" + simple_client = Mock(spec=[]) + client_pool._clients["simple"] = simple_client + + stats = client_pool.get_pool_stats() + assert stats["total_clients"] == 1 + assert stats["clients"]["simple"] == {} + + +async def test_client_factory_exceptions(client_pool): + """FHIRClientPool propagates exceptions from client factory.""" + + def failing_factory(connection_string, limits=None): + raise ValueError("Factory failed") + + with pytest.raises(ValueError, match="Factory failed"): + await client_pool.get_client("fhir://test.com/R4", failing_factory) + + +async def test_concurrent_client_creation(client_pool): + """FHIRClientPool handles concurrent requests for same connection.""" + connection_string = "fhir://test.example.com/R4" + call_count = 0 + + def counting_factory(conn_str, limits=None): + nonlocal call_count + call_count += 1 + client = Mock() + client.close = AsyncMock() + return client + + import asyncio + + async def get_client(): + return await client_pool.get_client(connection_string, counting_factory) + + # Create concurrent tasks + tasks = [get_client() for _ in range(3)] + results = await asyncio.gather(*tasks) + + # All clients should be the same instance + assert all(client is results[0] for client in results) + # Factory should only be called once due to caching + assert call_count == 1 diff --git a/tests/gateway/test_connection_manager.py b/tests/gateway/test_connection_manager.py new file mode 100644 index 00000000..ae951380 --- /dev/null +++ b/tests/gateway/test_connection_manager.py @@ -0,0 +1,102 @@ +""" +Tests for the FHIR connection manager in the HealthChain gateway system. + +This module tests centralized connection management for FHIR sources: +- Connection string parsing and validation +- Source lifecycle management +- Client pooling and retrieval +""" + +import pytest +from unittest.mock import Mock, AsyncMock + +from healthchain.gateway.core.connection import FHIRConnectionManager +from healthchain.gateway.core.errors import FHIRConnectionError +from healthchain.gateway.api.protocols import FHIRServerInterfaceProtocol + +# Configure pytest-asyncio for async tests +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def connection_manager(): + """Create a connection manager for testing.""" + return FHIRConnectionManager( + max_connections=50, max_keepalive_connections=10, keepalive_expiry=30.0 + ) + + +@pytest.fixture +def mock_fhir_client(): + """Create a mock FHIR client using protocol.""" + client = Mock(spec=FHIRServerInterfaceProtocol) + client.base_url = "https://test.fhir.com/R4" + return client + + +@pytest.mark.parametrize( + "connection_string,should_succeed", + [ + # Valid connection strings + ( + "fhir://epic.org/api/FHIR/R4?client_id=test&client_secret=secret&token_url=https://epic.org/token", + True, + ), + ( + "fhir://localhost:8080/fhir?client_id=local&client_secret=pass&token_url=http://localhost/token", + True, + ), + # Invalid connection strings + ("http://not-fhir.com/api", False), # Wrong scheme + ("fhir://", False), # Missing hostname + ("invalid-string", False), # Not a URL + ], +) +def test_connection_manager_source_validation_and_parsing( + connection_manager, connection_string, should_succeed +): + """FHIRConnectionManager validates connection strings and parses hostnames correctly.""" + if should_succeed: + connection_manager.add_source("test_source", connection_string) + assert "test_source" in connection_manager.sources + assert "test_source" in connection_manager._connection_strings + assert ( + connection_manager._connection_strings["test_source"] == connection_string + ) + else: + with pytest.raises( + FHIRConnectionError, match="Failed to parse connection string" + ): + connection_manager.add_source("test_source", connection_string) + + +async def test_connection_manager_client_retrieval_and_default_selection( + connection_manager, mock_fhir_client +): + """FHIRConnectionManager retrieves clients through pooling and selects defaults correctly.""" + # Add multiple sources + connection_manager.add_source( + "first", + "fhir://first.com/fhir?client_id=test&client_secret=secret&token_url=https://first.com/token", + ) + connection_manager.add_source( + "second", + "fhir://second.com/fhir?client_id=test&client_secret=secret&token_url=https://second.com/token", + ) + + connection_manager.client_pool.get_client = AsyncMock(return_value=mock_fhir_client) + + # Test specific source retrieval + client = await connection_manager.get_client("first") + assert client == mock_fhir_client + + # Test default source selection (should use first available) + client_default = await connection_manager.get_client() + assert client_default == mock_fhir_client + call_args = connection_manager.client_pool.get_client.call_args + from urllib.parse import urlparse + + parsed_url = urlparse(call_args[0][0]) + assert ( + parsed_url.hostname == "first.com" + ) # Should use first source's connection string diff --git a/tests/gateway/test_core_base.py b/tests/gateway/test_core_base.py new file mode 100644 index 00000000..aca6dd88 --- /dev/null +++ b/tests/gateway/test_core_base.py @@ -0,0 +1,221 @@ +""" +Tests for the core base classes in the HealthChain gateway system. + +This module tests the fundamental base classes that define the gateway architecture: +- BaseGateway +- BaseProtocolHandler +- EventCapability +- GatewayConfig +""" + +import pytest +from unittest.mock import Mock, AsyncMock +from typing import Dict, Any + +from healthchain.gateway.core.base import ( + BaseGateway, + BaseProtocolHandler, + EventCapability, + GatewayConfig, +) +from healthchain.gateway.events.dispatcher import EventDispatcher + +# Configure pytest-asyncio for async tests +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def mock_event_dispatcher(): + """Create a mock event dispatcher for testing.""" + dispatcher = Mock(spec=EventDispatcher) + dispatcher.publish = AsyncMock() + dispatcher.register_handler = Mock(return_value=lambda f: f) + return dispatcher + + +class ConcreteProtocolHandler(BaseProtocolHandler[Dict[str, Any], Dict[str, Any]]): + """Concrete implementation of BaseProtocolHandler for testing.""" + + def _process_result(self, result: Any) -> Dict[str, Any]: + """Process results into expected dict format.""" + if isinstance(result, dict): + return result + return {"processed": str(result)} + + +class ConcreteGateway(BaseGateway): + """Concrete implementation of BaseGateway for testing.""" + + def get_metadata(self) -> Dict[str, Any]: + metadata = super().get_metadata() + metadata["test_specific"] = True + return metadata + + +def test_event_capability_configuration_and_chaining(mock_event_dispatcher): + """EventCapability supports configuration and method chaining.""" + capability = EventCapability() + mock_creator = Mock(return_value={"event": "test"}) + + # Test method chaining and configuration + result = capability.set_dispatcher(mock_event_dispatcher).set_event_creator( + mock_creator + ) + + assert capability.dispatcher == mock_event_dispatcher + assert capability._event_creator == mock_creator + assert result == capability # Method chaining + + +def test_event_capability_delegated_publishing(mock_event_dispatcher): + """EventCapability delegates to dispatcher's emit method.""" + capability = EventCapability() + capability.set_dispatcher(mock_event_dispatcher) + + test_event = {"type": "test_event"} + capability.publish(test_event) + + mock_event_dispatcher.emit.assert_called_once_with(test_event) + + +async def test_protocol_handler_supports_sync_and_async_handlers(): + """BaseProtocolHandler supports both synchronous and asynchronous handlers.""" + handler = ConcreteProtocolHandler() + + # Register handlers + handler.register_handler("sync_op", lambda data: {"sync_result": data}) + handler.register_handler( + "async_op", AsyncMock(return_value={"async_result": "test"}) + ) + + # Test both handler types + sync_result = await handler.handle("sync_op", data="test_sync") + async_result = await handler.handle("async_op", data="test_async") + + assert sync_result == {"sync_result": "test_sync"} + assert async_result == {"async_result": "test"} + + +@pytest.mark.parametrize( + "return_errors,operation_exists,expected_behavior", + [ + # Handler exists - should succeed + (False, True, {"success": True, "raises": False}), + (True, True, {"success": True, "raises": False}), + # Handler missing, return_errors=False - should raise + (False, False, {"success": False, "raises": True}), + # Handler missing, return_errors=True - should return error dict + (True, False, {"success": False, "raises": False, "error_in_response": True}), + ], +) +async def test_protocol_handler_error_handling_behavior( + return_errors, operation_exists, expected_behavior +): + """BaseProtocolHandler handles missing operations and errors according to configuration.""" + config = GatewayConfig(return_errors=return_errors) + handler = ConcreteProtocolHandler(config=config) + + if operation_exists: + handler.register_handler("test_op", lambda data: {"result": data}) + + if expected_behavior["raises"]: + with pytest.raises(ValueError, match="Unsupported operation"): + await handler.handle( + "test_op" if operation_exists else "missing_op", data="test" + ) + else: + result = await handler.handle( + "test_op" if operation_exists else "missing_op", data="test" + ) + + if expected_behavior.get("error_in_response"): + assert "error" in result + assert "Unsupported operation" in result["error"] + else: + assert result == {"result": "test"} + + +async def test_protocol_handler_exception_handling_in_handlers(): + """BaseProtocolHandler handles exceptions in registered handlers appropriately.""" + # Test with return_errors=False (should raise) + handler_raise = ConcreteProtocolHandler(config=GatewayConfig(return_errors=False)) + handler_raise.register_handler("failing_op", lambda: 1 / 0) + + with pytest.raises(ValueError, match="Error during operation execution"): + await handler_raise.handle("failing_op") + + # Test with return_errors=True (should return error dict) + handler_return = ConcreteProtocolHandler(config=GatewayConfig(return_errors=True)) + handler_return.register_handler("failing_op", lambda: 1 / 0) + + result = await handler_return.handle("failing_op") + assert "error" in result + assert "Error during operation execution" in result["error"] + + +def test_base_gateway_initialization_and_metadata_generation(): + """BaseGateway initializes correctly and generates metadata including event capabilities.""" + # Test default initialization + gateway = ConcreteGateway() + assert gateway.prefix == "/api" + assert gateway.tags == [] + + # Test custom initialization and metadata + custom_gateway = ConcreteGateway( + prefix="/custom", tags=["test"], config=GatewayConfig(system_type="TEST_SYSTEM") + ) + + assert custom_gateway.prefix == "/custom" + assert custom_gateway.tags == ["test"] + + # Test metadata generation + metadata = custom_gateway.get_gateway_status() + assert metadata["gateway_type"] == "ConcreteGateway" + assert metadata["system_type"] == "TEST_SYSTEM" + assert metadata["status"] == "active" + + # Test with event dispatcher + custom_gateway.events.set_dispatcher(Mock(spec=EventDispatcher)) + metadata_with_events = custom_gateway.get_gateway_status() + assert metadata_with_events["events"]["enabled"] is True + + +def test_base_gateway_event_handler_registration(mock_event_dispatcher): + """BaseGateway supports event handler registration via events capability.""" + gateway = ConcreteGateway() + gateway.events.set_dispatcher(mock_event_dispatcher) + + # Test decorator usage and direct registration + decorator = gateway.events.register_handler("test_event") + assert callable(decorator) + + def test_handler(event): + return "handled" + + result = gateway.events.register_handler("direct_event", test_handler) + assert result == gateway.events # Method chaining returns EventCapability + + # Test error when no dispatcher set + no_dispatcher_gateway = ConcreteGateway() + with pytest.raises(ValueError, match="Event dispatcher not set"): + no_dispatcher_gateway.events.register_handler("event", test_handler) + + +def test_protocol_handler_capabilities_and_factory_method(): + """BaseProtocolHandler provides capabilities introspection and factory method.""" + # Test capabilities + handler = ConcreteProtocolHandler() + handler.register_handler("op1", lambda: "result1") + handler.register_handler("op2", lambda: "result2") + + capabilities = handler.get_capabilities() + assert set(capabilities) == {"op1", "op2"} + + # Test factory method + factory_handler = ConcreteProtocolHandler.create( + config=GatewayConfig(system_type="FACTORY_TEST"), return_errors=True + ) + + assert isinstance(factory_handler, ConcreteProtocolHandler) + assert factory_handler.config.system_type == "FACTORY_TEST" + assert factory_handler.return_errors is True diff --git a/tests/gateway/test_error_handling.py b/tests/gateway/test_error_handling.py new file mode 100644 index 00000000..d3107df6 --- /dev/null +++ b/tests/gateway/test_error_handling.py @@ -0,0 +1,210 @@ +"""Tests for FHIR error handling functionality.""" + +import pytest + +from healthchain.gateway.core.errors import ( + FHIRConnectionError, + FHIRErrorHandler, +) + + +@pytest.mark.parametrize( + "show_state,expected", + [ + (True, "[404 NOT_FOUND] Resource not found"), + (False, "[VALIDATION_ERROR] Validation failed"), + (None, "[None GENERIC_ERROR] Generic error"), # no state provided + ], +) +def test_fhir_connection_error_formatting(show_state, expected): + """FHIRConnectionError formats messages correctly based on show_state.""" + if show_state is None: + error = FHIRConnectionError(message="Generic error", code="GENERIC_ERROR") + elif show_state: + error = FHIRConnectionError( + message="Resource not found", code="NOT_FOUND", state="404", show_state=True + ) + else: + error = FHIRConnectionError( + message="Validation failed", + code="VALIDATION_ERROR", + state="422", + show_state=False, + ) + + assert str(error) == expected + + +@pytest.mark.parametrize( + "status_code,expected_fragment", + [ + (400, "Resource could not be parsed"), + (401, "Authorization is required"), + (403, "You may not have permission"), + (404, "resource you are looking for does not exist"), + (405, "server does not allow client defined ids"), + (409, "Version conflict - update cannot be done"), + (410, "resource you are looking for is no longer available"), + (412, "Version conflict - version id does not match"), + (422, "Proposed resource violated applicable FHIR profiles"), + ], +) +def test_error_mapping_by_status_code(status_code, expected_fragment): + """FHIRErrorHandler maps HTTP status codes to appropriate FHIR error messages.""" + mock_exception = Exception("HTTP error") + mock_exception.status_code = status_code + + with pytest.raises(FHIRConnectionError) as exc_info: + FHIRErrorHandler.handle_fhir_error( + mock_exception, resource_type="Patient", fhir_id="123", operation="read" + ) + + error = exc_info.value + assert expected_fragment.lower() in error.message.lower() + assert error.state == str(status_code) + assert "read Patient/123 failed" in error.message + + +def test_error_mapping_by_message_content(): + """FHIRErrorHandler maps errors by parsing status code from error message.""" + mock_exception = Exception("Request failed with status 404 - not found") + + with pytest.raises(FHIRConnectionError) as exc_info: + FHIRErrorHandler.handle_fhir_error( + mock_exception, resource_type="Observation", operation="search" + ) + + error = exc_info.value + assert "resource you are looking for does not exist" in error.message + assert error.state == "404" + assert "search Observation failed" in error.message + + +@pytest.mark.parametrize( + "has_status_code,expected_state", + [ + (True, "599"), + (False, "UNKNOWN"), + ], +) +def test_error_handling_edge_cases(has_status_code, expected_state): + """FHIRErrorHandler handles unknown status codes and missing attributes.""" + mock_exception = Exception("Server error") + if has_status_code: + mock_exception.status_code = 599 # Unknown status code + + with pytest.raises(FHIRConnectionError) as exc_info: + FHIRErrorHandler.handle_fhir_error( + mock_exception, resource_type="Patient", fhir_id="123", operation="update" + ) + + error = exc_info.value + assert error.state == expected_state + assert "update Patient/123 failed: HTTP error" in error.message + + +@pytest.mark.parametrize( + "fhir_id,expected_format", + [ + ("patient-123", "read Patient/patient-123 failed"), + (None, "create Patient failed"), + ], +) +def test_resource_reference_formatting(fhir_id, expected_format): + """FHIRErrorHandler formats resource references correctly with or without ID.""" + mock_exception = Exception("Error") + mock_exception.status_code = 404 if fhir_id else 400 + + with pytest.raises(FHIRConnectionError) as exc_info: + FHIRErrorHandler.handle_fhir_error( + mock_exception, + resource_type="Patient", + fhir_id=fhir_id, + operation="read" if fhir_id else "create", + ) + + assert expected_format in str(exc_info.value) + + +@pytest.mark.parametrize( + "resource_type,field_name,expected_format", + [ + ( + "Patient", + "identifier", + "Validation failed for Patient.identifier: Invalid format", + ), + ( + "Observation", + None, + "Validation failed for Observation: Missing required field", + ), + (None, None, "Validation failed: General validation error"), + ], +) +def test_validation_error_creation(resource_type, field_name, expected_format): + """FHIRErrorHandler creates validation errors with appropriate formatting.""" + message = ( + "Invalid format" + if field_name + else "Missing required field" + if resource_type + else "General validation error" + ) + + error = FHIRErrorHandler.create_validation_error( + message=message, resource_type=resource_type, field_name=field_name + ) + + assert error.message == expected_format + assert error.code == "VALIDATION_ERROR" + assert error.state == "422" + + +@pytest.mark.parametrize( + "source,error_type,expected_code,expected_state", + [ + ("epic_prod", "connection", "CONNECTION_ERROR", "503"), + ("cerner_dev", "authentication", "AUTHENTICATION_ERROR", "401"), + (None, "connection", "CONNECTION_ERROR", "503"), + (None, "authentication", "AUTHENTICATION_ERROR", "401"), + ], +) +def test_specialized_error_creation(source, error_type, expected_code, expected_state): + """FHIRErrorHandler creates connection and authentication errors correctly.""" + message = "Network timeout" if error_type == "connection" else "Invalid token" + + if error_type == "connection": + error = FHIRErrorHandler.create_connection_error(message=message, source=source) + expected_prefix = f"Connection to source '{source}'" if source else "Connection" + else: + error = FHIRErrorHandler.create_authentication_error( + message=message, source=source + ) + expected_prefix = ( + f"Authentication to source '{source}'" if source else "Authentication" + ) + + expected_message = ( + f"{expected_prefix} failed: {message}" + if source + else f"{expected_prefix} failed: {message}" + ) + + assert error.message == expected_message + assert error.code == expected_code + assert error.state == expected_state + + +def test_error_chaining_preserves_original_message(): + """FHIRErrorHandler preserves original exception message in error code.""" + original_message = "Detailed server error: Resource validation failed on field X" + mock_exception = Exception(original_message) + mock_exception.status_code = 422 + + with pytest.raises(FHIRConnectionError) as exc_info: + FHIRErrorHandler.handle_fhir_error( + mock_exception, resource_type="Patient", operation="create" + ) + + assert exc_info.value.code == original_message diff --git a/tests/gateway/test_event_dispatcher.py b/tests/gateway/test_event_dispatcher.py index a7090a58..45fafd36 100644 --- a/tests/gateway/test_event_dispatcher.py +++ b/tests/gateway/test_event_dispatcher.py @@ -1,86 +1,228 @@ """ -Tests for the EventDispatcher in the HealthChain gateway system. +Tests for the event dispatcher core functionality. -This module tests the functionality of the EventDispatcher class -for handling EHR events in the system. +Focuses on pub/sub behavior, handler registration, and event publishing patterns. """ import pytest -from datetime import datetime +from unittest.mock import Mock, patch, AsyncMock from fastapi import FastAPI +from datetime import datetime from healthchain.gateway.events.dispatcher import ( EventDispatcher, - EHREventType, EHREvent, + EHREventType, ) - -@pytest.fixture -def app(): - """Create a FastAPI app for testing.""" - return FastAPI() +pytestmark = pytest.mark.asyncio @pytest.fixture -def dispatcher(): - """Create an EventDispatcher for testing.""" - return EventDispatcher() +def mock_fastapi_app(): + """Create a mock FastAPI app for testing.""" + return Mock(spec=FastAPI) @pytest.fixture -def initialized_dispatcher(app, dispatcher): - """Create an EventDispatcher initialized with a FastAPI app.""" - dispatcher.init_app(app) - return dispatcher +def event_dispatcher(): + """Create an event dispatcher for testing.""" + return EventDispatcher() @pytest.fixture -def sample_event(): +def sample_ehr_event(): """Create a sample EHR event for testing.""" return EHREvent( - event_type=EHREventType.EHR_GENERIC, + event_type=EHREventType.FHIR_READ, source_system="test_system", timestamp=datetime.now(), - payload={"data": "test data"}, - metadata={"test": "metadata"}, + payload={"resource_id": "123", "resource_type": "Patient"}, + metadata={"user": "test_user"}, ) -def test_event_dispatcher_initialization(dispatcher): - """Test that EventDispatcher initializes correctly.""" +def test_event_dispatcher_conforms_to_protocol(): + """EventDispatcher implements the required protocol methods.""" + dispatcher = EventDispatcher() + + # Check that dispatcher has all required protocol methods + assert hasattr(dispatcher, "publish") + assert hasattr(dispatcher, "init_app") + assert hasattr(dispatcher, "register_handler") + assert hasattr(dispatcher, "register_default_handler") + assert callable(getattr(dispatcher, "publish")) + assert callable(getattr(dispatcher, "init_app")) + + +def test_event_dispatcher_initialization(): + """EventDispatcher initializes with empty registry and unique middleware ID.""" + dispatcher = EventDispatcher() + + assert dispatcher.handlers_registry == {} assert dispatcher.app is None - assert dispatcher.middleware_id is not None + assert isinstance(dispatcher.middleware_id, int) + + # Each instance should have unique middleware ID + dispatcher2 = EventDispatcher() + assert dispatcher.middleware_id != dispatcher2.middleware_id + + +@patch("healthchain.gateway.events.dispatcher.EventHandlerASGIMiddleware") +def test_event_dispatcher_app_initialization( + mock_middleware, event_dispatcher, mock_fastapi_app +): + """EventDispatcher correctly initializes with FastAPI app and registers middleware.""" + event_dispatcher.init_app(mock_fastapi_app) + + assert event_dispatcher.app is mock_fastapi_app + mock_fastapi_app.add_middleware.assert_called_once() + + # Verify middleware was called with correct parameters + call_args = mock_fastapi_app.add_middleware.call_args + assert call_args[0][0] == mock_middleware + assert "handlers" in call_args[1] + assert call_args[1]["middleware_id"] == event_dispatcher.middleware_id + + +@pytest.mark.parametrize( + "event_type,expected_name", + [ + (EHREventType.FHIR_READ, "fhir.read"), + (EHREventType.CDS_PATIENT_VIEW, "cds.patient.view"), + (EHREventType.NOTEREADER_SIGN_NOTE, "notereader.sign.note"), + ], +) +def test_ehr_event_name_mapping(event_type, expected_name): + """EHREvent correctly maps event types to string names.""" + event = EHREvent( + event_type=event_type, + source_system="test", + timestamp=datetime.now(), + payload={}, + metadata={}, + ) + + assert event.get_name() == expected_name + assert event.event_type.value == expected_name + + +@patch("healthchain.gateway.events.dispatcher.local_handler") +def test_event_handler_registration_returns_decorator( + mock_local_handler, event_dispatcher +): + """Event handler registration returns correct fastapi-events decorator.""" + mock_decorator = Mock() + mock_local_handler.register.return_value = mock_decorator + + result = event_dispatcher.register_handler(EHREventType.FHIR_READ) + + assert result is mock_decorator + mock_local_handler.register.assert_called_once_with(event_name="fhir.read") + + +@patch("healthchain.gateway.events.dispatcher.local_handler") +def test_default_handler_registration(mock_local_handler, event_dispatcher): + """Default handler registration uses wildcard pattern.""" + mock_decorator = Mock() + mock_local_handler.register.return_value = mock_decorator + + result = event_dispatcher.register_default_handler() + assert result is mock_decorator + mock_local_handler.register.assert_called_once_with(event_name="*") -def test_event_dispatcher_init_app(app, dispatcher): - """Test that EventDispatcher can be initialized with a FastAPI app.""" - dispatcher.init_app(app) - assert dispatcher.app == app - assert len(app.user_middleware) == 1 +@patch("healthchain.gateway.events.dispatcher.dispatch") +async def test_event_publishing_with_default_middleware_id( + mock_dispatch, event_dispatcher, sample_ehr_event +): + """Event publishing uses dispatcher's middleware ID when none provided.""" + mock_dispatch.return_value = None # dispatch may return None -def test_register_handler(initialized_dispatcher): - """Test that register_handler returns a decorator.""" - decorator = initialized_dispatcher.register_handler(EHREventType.EHR_GENERIC) - assert callable(decorator) + await event_dispatcher.publish(sample_ehr_event) + mock_dispatch.assert_called_once_with( + "fhir.read", + sample_ehr_event.model_dump(), + middleware_id=event_dispatcher.middleware_id, + ) + + +@patch("healthchain.gateway.events.dispatcher.dispatch") +async def test_event_publishing_with_custom_middleware_id( + mock_dispatch, event_dispatcher, sample_ehr_event +): + """Event publishing uses provided middleware ID when specified.""" + custom_middleware_id = 12345 + mock_dispatch.return_value = None + + await event_dispatcher.publish(sample_ehr_event, middleware_id=custom_middleware_id) + + mock_dispatch.assert_called_once_with( + "fhir.read", sample_ehr_event.model_dump(), middleware_id=custom_middleware_id + ) + + +@patch("healthchain.gateway.events.dispatcher.dispatch") +async def test_event_publishing_awaits_dispatch_result( + mock_dispatch, event_dispatcher, sample_ehr_event +): + """Event publishing awaits dispatch result when it returns an awaitable.""" -# TODO: test async -# @patch("healthchain.gateway.events.dispatcher.dispatch") -# async def test_publish_event(mock_dispatch, initialized_dispatcher, sample_event): -# """Test that publish correctly dispatches an event.""" -# mock_dispatch.return_value = None -# await initialized_dispatcher.publish(sample_event) -# mock_dispatch.assert_called_once() + # Create a proper coroutine that can be awaited + async def mock_coroutine(): + return "dispatched" + mock_dispatch.return_value = mock_coroutine() -def test_ehr_event_get_name(sample_event): - """Test that EHREvent.get_name returns the correct event name.""" - assert sample_event.get_name() == "ehr.generic" + await event_dispatcher.publish(sample_ehr_event) + + # Verify dispatch was called with correct parameters + mock_dispatch.assert_called_once_with( + "fhir.read", + sample_ehr_event.model_dump(), + middleware_id=event_dispatcher.middleware_id, + ) -def test_basic_event_types(): - """Test a few basic event types.""" - assert EHREventType.EHR_GENERIC.value == "ehr.generic" - assert EHREventType.FHIR_READ.value == "fhir.read" +def test_emit_method_handles_sync_context(event_dispatcher, sample_ehr_event): + """EventDispatcher.emit creates a new loop when not in async context.""" + # Mock all the asyncio components + with patch.object( + event_dispatcher, "publish", new_callable=AsyncMock + ) as mock_publish: + with patch( + "asyncio.get_running_loop", side_effect=RuntimeError("No running loop") + ): + with patch("asyncio.new_event_loop") as mock_new_loop: + mock_loop = Mock() + mock_new_loop.return_value = mock_loop + + # Call emit from sync context + event_dispatcher.emit(sample_ehr_event, middleware_id=42) + + # Verify behavior + mock_new_loop.assert_called_once() + mock_loop.run_until_complete.assert_called_once() + mock_loop.close.assert_called_once() + mock_publish.assert_called_once_with(sample_ehr_event, 42) + + +def test_emit_method_handles_async_context(event_dispatcher, sample_ehr_event): + """EventDispatcher.emit correctly handles existing async context.""" + # Mock the async publish method + with patch.object( + event_dispatcher, "publish", new_callable=AsyncMock + ) as mock_publish: + # Test async context - should use create_task + with patch("asyncio.get_running_loop") as mock_get_loop: + with patch("asyncio.create_task") as mock_create_task: + mock_loop = Mock() + mock_get_loop.return_value = mock_loop + + event_dispatcher.emit(sample_ehr_event) + + # Verify create_task was used (async context) + mock_create_task.assert_called_once() + mock_publish.assert_called_once_with(sample_ehr_event, None) diff --git a/tests/gateway/test_fhir_client.py b/tests/gateway/test_fhir_client.py new file mode 100644 index 00000000..79424d7c --- /dev/null +++ b/tests/gateway/test_fhir_client.py @@ -0,0 +1,346 @@ +""" +Tests for FHIR client external API integration functionality. + +Focuses on HTTP operations, authentication, error handling, and response processing. +""" + +import pytest +import json +import httpx +from unittest.mock import Mock, AsyncMock, patch +from fhir.resources.patient import Patient +from fhir.resources.bundle import Bundle +from fhir.resources.capabilitystatement import CapabilityStatement + +from healthchain.gateway.clients.fhir import ( + AsyncFHIRClient, + FHIRClientError, +) +from healthchain.gateway.clients.auth import FHIRAuthConfig + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def mock_auth_config(): + """Create a mock FHIR auth configuration.""" + return FHIRAuthConfig( + base_url="https://test.fhir.org/R4", + client_id="test_client", + client_secret="test_secret", + token_url="https://test.fhir.org/oauth/token", + scope="system/*.read", + timeout=30.0, + verify_ssl=True, + ) + + +@pytest.fixture +def fhir_client(mock_auth_config): + """Create a FHIR client for testing.""" + with patch( + "healthchain.gateway.clients.fhir.OAuth2TokenManager" + ) as mock_manager_class: + mock_manager = AsyncMock() + mock_manager.get_access_token.return_value = "test_token" + mock_manager_class.return_value = mock_manager + + client = AsyncFHIRClient(auth_config=mock_auth_config) + client.token_manager = mock_manager + return client + + +@pytest.fixture +def fhir_client_with_limits(mock_auth_config): + """Create an AsyncFHIRClient with connection limits for testing.""" + limits = httpx.Limits( + max_connections=50, + max_keepalive_connections=10, + keepalive_expiry=30.0, + ) + with patch( + "healthchain.gateway.clients.fhir.OAuth2TokenManager" + ) as mock_manager_class: + mock_manager = AsyncMock() + mock_manager.get_access_token.return_value = "test_token" + mock_manager_class.return_value = mock_manager + + client = AsyncFHIRClient(auth_config=mock_auth_config, limits=limits) + client.token_manager = mock_manager + return client + + +@pytest.fixture +def mock_httpx_response(): + """Create a mock httpx response.""" + response = Mock(spec=httpx.Response) + response.is_success = True + response.status_code = 200 + response.json.return_value = {"resourceType": "Patient", "id": "123"} + return response + + +def test_fhir_client_initialization_and_configuration(mock_auth_config): + """AsyncFHIRClient initializes with correct configuration and headers.""" + with patch("healthchain.gateway.clients.fhir.OAuth2TokenManager"): + client = AsyncFHIRClient(auth_config=mock_auth_config) + + # Test configuration + assert client.base_url == "https://test.fhir.org/R4/" + assert client.timeout == 30.0 + assert client.verify_ssl is True + + # Test headers + assert client.base_headers["Accept"] == "application/fhir+json" + assert client.base_headers["Content-Type"] == "application/fhir+json" + + +def test_async_fhir_client_conforms_to_protocol(fhir_client): + """AsyncFHIRClient implements the required protocol methods.""" + # Check that client has all required protocol methods + assert hasattr(fhir_client, "read") + assert hasattr(fhir_client, "search") + assert hasattr(fhir_client, "create") + assert hasattr(fhir_client, "update") + assert hasattr(fhir_client, "delete") + assert hasattr(fhir_client, "transaction") + assert hasattr(fhir_client, "capabilities") + + # Check that methods are callable + assert callable(getattr(fhir_client, "read")) + assert callable(getattr(fhir_client, "search")) + + +async def test_fhir_client_authentication_and_headers(fhir_client): + """AsyncFHIRClient manages OAuth tokens and includes proper headers.""" + # Test first call includes token and headers + headers = await fhir_client._get_headers() + assert headers["Authorization"] == "Bearer test_token" + assert headers["Accept"] == "application/fhir+json" + assert headers["Content-Type"] == "application/fhir+json" + + # Test token refresh on subsequent calls + await fhir_client._get_headers() + assert fhir_client.token_manager.get_access_token.call_count == 2 + + +def test_fhir_client_url_building(fhir_client): + """AsyncFHIRClient builds URLs correctly with and without parameters.""" + # Without parameters + url = fhir_client._build_url("Patient/123") + assert url == "https://test.fhir.org/R4/Patient/123" + + # With parameters (None values filtered) + params = {"name": "John", "active": True, "limit": None} + url = fhir_client._build_url("Patient", params) + assert "https://test.fhir.org/R4/Patient?" in url + assert "name=John" in url + assert "active=True" in url + assert "limit" not in url + + +@pytest.mark.parametrize( + "status_code,is_success,should_raise", + [ + (200, True, False), + (201, True, False), + (400, False, True), + (404, False, True), + (500, False, True), + ], +) +def test_fhir_client_response_handling( + fhir_client, status_code, is_success, should_raise +): + """AsyncFHIRClient handles HTTP status codes and error responses appropriately.""" + mock_response = Mock(spec=httpx.Response) + mock_response.is_success = is_success + mock_response.status_code = status_code + mock_response.json.return_value = {"resourceType": "OperationOutcome"} + + if should_raise: + with pytest.raises(FHIRClientError) as exc_info: + fhir_client._handle_response(mock_response) + assert exc_info.value.status_code == status_code + else: + result = fhir_client._handle_response(mock_response) + assert result == {"resourceType": "OperationOutcome"} + + +def test_fhir_client_error_extraction_and_invalid_json(fhir_client): + """AsyncFHIRClient extracts error diagnostics and handles invalid JSON.""" + # Test error extraction from OperationOutcome + mock_response = Mock(spec=httpx.Response) + mock_response.is_success = False + mock_response.status_code = 422 + mock_response.json.return_value = { + "resourceType": "OperationOutcome", + "issue": [{"diagnostics": "Validation failed on field X"}], + } + + with pytest.raises(FHIRClientError) as exc_info: + fhir_client._handle_response(mock_response) + assert "Validation failed on field X" in str(exc_info.value) + assert exc_info.value.status_code == 422 + + # Test invalid JSON handling + mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "doc", 0) + mock_response.text = "Invalid response text" + mock_response.status_code = 500 + + with pytest.raises(FHIRClientError) as exc_info: + fhir_client._handle_response(mock_response) + assert "Invalid JSON response" in str(exc_info.value) + + +async def test_fhir_client_crud_operations(fhir_client, mock_httpx_response): + """AsyncFHIRClient performs CRUD operations correctly.""" + # Test READ operation + with patch.object( + fhir_client.client, "get", return_value=mock_httpx_response + ) as mock_get: + with patch.object( + fhir_client, "_get_headers", return_value={"Authorization": "Bearer token"} + ): + result = await fhir_client.read(Patient, "123") + mock_get.assert_called_once_with( + "https://test.fhir.org/R4/Patient/123", + headers={"Authorization": "Bearer token"}, + ) + assert isinstance(result, Patient) + assert result.id == "123" + + # Test CREATE operation + patient = Patient(id="123", active=True) + mock_httpx_response.json.return_value = { + "resourceType": "Patient", + "id": "new-123", + "active": True, + } + + with patch.object( + fhir_client.client, "post", return_value=mock_httpx_response + ) as mock_post: + with patch.object( + fhir_client, "_get_headers", return_value={"Authorization": "Bearer token"} + ): + result = await fhir_client.create(patient) + call_args = mock_post.call_args + assert call_args[0][0] == "https://test.fhir.org/R4/Patient" + assert "content" in call_args[1] + assert isinstance(result, Patient) + assert result.id == "new-123" + + # Test DELETE operation + mock_delete_response = Mock(spec=httpx.Response) + mock_delete_response.is_success = True + mock_delete_response.status_code = 204 + + with patch.object( + fhir_client.client, "delete", return_value=mock_delete_response + ) as mock_delete: + with patch.object(fhir_client, "_get_headers", return_value={}): + result = await fhir_client.delete(Patient, "123") + mock_delete.assert_called_once_with( + "https://test.fhir.org/R4/Patient/123", headers={} + ) + assert result is True + + +async def test_fhir_client_search_and_capabilities(fhir_client): + """AsyncFHIRClient handles search operations and server capabilities.""" + # Test SEARCH operation + bundle_response = { + "resourceType": "Bundle", + "type": "searchset", + "entry": [{"resource": {"resourceType": "Patient", "id": "123"}}], + } + mock_response = Mock(spec=httpx.Response) + mock_response.is_success = True + mock_response.json.return_value = bundle_response + + with patch.object( + fhir_client.client, "get", return_value=mock_response + ) as mock_get: + with patch.object(fhir_client, "_get_headers", return_value={}): + params = {"name": "John", "active": True} + result = await fhir_client.search(Patient, params) + + call_url = mock_get.call_args[0][0] + assert "Patient?" in call_url + assert "name=John" in call_url + assert "active=True" in call_url + assert isinstance(result, Bundle) + assert result.type == "searchset" + + # Test CAPABILITIES operation + capabilities_response = { + "resourceType": "CapabilityStatement", + "status": "active", + "kind": "instance", + "fhirVersion": "4.0.1", + "date": "2023-01-01T00:00:00Z", + "format": ["json"], + } + mock_response.json.return_value = capabilities_response + + with patch.object( + fhir_client.client, "get", return_value=mock_response + ) as mock_get: + with patch.object(fhir_client, "_get_headers", return_value={}): + result = await fhir_client.capabilities() + mock_get.assert_called_once_with( + "https://test.fhir.org/R4/metadata", headers={} + ) + assert isinstance(result, CapabilityStatement) + assert result.status == "active" + + +def test_fhir_client_resource_type_resolution(fhir_client): + """AsyncFHIRClient resolves resource types from classes, strings, and handles errors.""" + # Test with FHIR resource class + type_name, resource_class = fhir_client._resolve_resource_type(Patient) + assert type_name == "Patient" + assert resource_class == Patient + + # Test with string name + with patch("builtins.__import__") as mock_import: + mock_module = Mock() + mock_module.Patient = Patient + mock_import.return_value = mock_module + + type_name, resource_class = fhir_client._resolve_resource_type("Patient") + assert type_name == "Patient" + assert resource_class == Patient + mock_import.assert_called_once_with( + "fhir.resources.patient", fromlist=["Patient"] + ) + + # Test invalid resource type + with pytest.raises(ModuleNotFoundError, match="No module named"): + fhir_client._resolve_resource_type("InvalidResource") + + +async def test_fhir_client_authentication_failure(fhir_client): + """AsyncFHIRClient handles authentication failures.""" + fhir_client.token_manager.get_access_token.side_effect = Exception("Auth failed") + with pytest.raises(Exception, match="Auth failed"): + await fhir_client._get_headers() + + +async def test_fhir_client_http_timeout(fhir_client): + """AsyncFHIRClient handles HTTP timeout errors.""" + with patch.object(fhir_client.client, "get") as mock_get: + mock_get.side_effect = httpx.TimeoutException("Request timed out") + with pytest.raises(httpx.TimeoutException): + await fhir_client.read(Patient, "123") + + +def test_fhir_client_error_class(): + """FHIRClientError preserves response data for debugging.""" + response_data = {"resourceType": "OperationOutcome", "issue": []} + error = FHIRClientError("Test error", status_code=400, response_data=response_data) + + assert error.status_code == 400 + assert error.response_data == response_data + assert str(error) == "Test error" diff --git a/tests/gateway/test_fhir_gateway.py b/tests/gateway/test_fhir_gateway.py new file mode 100644 index 00000000..3bda8baf --- /dev/null +++ b/tests/gateway/test_fhir_gateway.py @@ -0,0 +1,279 @@ +import pytest +from unittest.mock import AsyncMock, Mock, patch +from typing import Dict, Any, List + +from fhir.resources.patient import Patient +from fhir.resources.bundle import Bundle + +from healthchain.gateway.core.fhirgateway import FHIRGateway + +pytestmark = pytest.mark.asyncio + + +class MockConnectionManager: + """Mock FHIR connection manager for testing.""" + + def __init__(self): + self.sources = {"test_source": Mock()} + + def add_source(self, name: str, connection_string: str) -> None: + self.sources[name] = Mock() + + async def get_client(self, source: str = None): + return AsyncMock() + + def get_pool_status(self) -> Dict[str, Any]: + return { + "max_connections": 100, + "sources": {"test_source": "connected"}, + } + + async def close(self) -> None: + pass + + +@pytest.fixture +def mock_connection_manager(): + """Fixture providing a mock connection manager.""" + return MockConnectionManager() + + +@pytest.fixture +def fhir_gateway(mock_connection_manager): + """Fixture providing a FHIRGateway with mocked dependencies.""" + with patch( + "healthchain.gateway.core.fhirgateway.FHIRConnectionManager", + return_value=mock_connection_manager, + ): + return FHIRGateway(use_events=False) + + +@pytest.fixture +def test_patient(): + """Fixture providing a test Patient resource.""" + return Patient(id="123", active=True) + + +def test_transform_handler_registration_with_correct_annotation(fhir_gateway): + """Transform handlers with correct return type annotations register successfully.""" + + @fhir_gateway.transform(Patient) + def transform_patient(id: str, source: str = None) -> Patient: + return Patient(id=id) + + assert fhir_gateway._resource_handlers[Patient]["transform"] == transform_patient + + +def test_transform_handler_validation_enforces_return_type_match(fhir_gateway): + """Transform handler registration validates return type matches decorator resource type.""" + from fhir.resources.observation import Observation + + with pytest.raises( + TypeError, match="return type .* doesn't match decorator resource type" + ): + + @fhir_gateway.transform(Patient) + def invalid_handler(id: str) -> Observation: # Wrong return type + return Observation() + + +def test_aggregate_handler_registration_without_validation(fhir_gateway): + """Aggregate handlers register without return type validation.""" + + @fhir_gateway.aggregate(Patient) + def aggregate_patients(id: str = None, sources: List[str] = None): + return [] + + assert fhir_gateway._resource_handlers[Patient]["aggregate"] == aggregate_patients + + +def test_handler_registration_creates_routes(fhir_gateway): + """Handler registration automatically creates corresponding API routes.""" + initial_routes = len(fhir_gateway.routes) + + @fhir_gateway.transform(Patient) + def transform_patient(id: str) -> Patient: + return Patient(id=id) + + assert len(fhir_gateway.routes) == initial_routes + 1 + + +def test_empty_capability_statement_with_no_handlers(fhir_gateway): + """Gateway with no handlers generates minimal CapabilityStatement.""" + capability = fhir_gateway.build_capability_statement() + + assert capability.model_dump()["resourceType"] == "CapabilityStatement" + assert capability.status == "active" + assert capability.kind == "instance" + assert capability.fhirVersion == "4.0.1" + + +def test_capability_statement_includes_registered_resources(fhir_gateway): + """CapabilityStatement includes resources with registered handlers.""" + from fhir.resources.observation import Observation + + @fhir_gateway.transform(Patient) + def transform_patient(id: str) -> Patient: + return Patient(id=id) + + @fhir_gateway.aggregate(Observation) + def aggregate_observations(id: str = None) -> List[Observation]: + return [] + + capability = fhir_gateway.build_capability_statement() + resources = capability.rest[0].resource + resource_types = [r.type for r in resources] + + assert "Patient" in resource_types + assert "Observation" in resource_types + + +def test_gateway_status_structure(fhir_gateway): + """Gateway status contains required fields with correct structure.""" + status = fhir_gateway.get_gateway_status() + + assert status["gateway_type"] == "FHIRGateway" + assert status["status"] == "active" + assert isinstance(status["timestamp"], str) + assert isinstance(status["version"], str) + + +def test_supported_operations_tracking(fhir_gateway): + """Gateway accurately tracks registered operations.""" + initial_ops = fhir_gateway.get_gateway_status()["supported_operations"][ + "endpoints" + ]["transform"] + + @fhir_gateway.transform(Patient) + def transform_patient(id: str) -> Patient: + return Patient(id=id) + + updated_status = fhir_gateway.get_gateway_status() + + assert ( + updated_status["supported_operations"]["endpoints"]["transform"] + == initial_ops + 1 + ) + assert "Patient" in updated_status["supported_operations"]["resources"] + + +async def test_read_operation_with_client_delegation(fhir_gateway, test_patient): + """Read operation delegates to client and handles results correctly.""" + with patch.object( + fhir_gateway, "_execute_with_client", return_value=test_patient + ) as mock_execute: + result = await fhir_gateway.read(Patient, "123", "test_source") + + mock_execute.assert_called_once_with( + "read", + source="test_source", + resource_type=Patient, + resource_id="123", + client_args=(Patient, "123"), + ) + assert result == test_patient + + +async def test_read_operation_raises_on_not_found(fhir_gateway): + """Read operation raises ValueError when resource not found.""" + with patch.object(fhir_gateway, "_execute_with_client", return_value=None): + with pytest.raises(ValueError, match="Resource Patient/123 not found"): + await fhir_gateway.read(Patient, "123") + + +async def test_create_operation_with_validation(fhir_gateway, test_patient): + """Create operation validates input and returns created resource.""" + created_patient = Patient(id="456", active=True) + with patch.object( + fhir_gateway, "_execute_with_client", return_value=created_patient + ) as mock_execute: + result = await fhir_gateway.create(test_patient) + + mock_execute.assert_called_once_with( + "create", + source=None, + resource_type=Patient, + client_args=(test_patient,), + ) + assert result == created_patient + + +async def test_update_operation_requires_resource_id(fhir_gateway): + """Update operation validates that resource has ID.""" + patient_without_id = Patient(active=True) # No ID + + with pytest.raises(ValueError, match="Resource must have an ID for update"): + await fhir_gateway.update(patient_without_id) + + +async def test_search_operation_with_parameters(fhir_gateway): + """Search operation passes parameters correctly to client.""" + mock_bundle = Bundle(type="searchset", total=1) + params = {"name": "Smith", "active": "true"} + + with patch.object( + fhir_gateway, "_execute_with_client", return_value=mock_bundle + ) as mock_execute: + result = await fhir_gateway.search(Patient, params, "test_source") + + mock_execute.assert_called_once_with( + "search", + source="test_source", + resource_type=Patient, + client_args=(Patient,), + client_kwargs={"params": params}, + ) + assert result == mock_bundle + + +async def test_modify_context_for_existing_resource(fhir_gateway, test_patient): + """Modify context manager fetches, yields, and updates existing resources.""" + mock_client = AsyncMock() + mock_client.read.return_value = test_patient + mock_client.update.return_value = Patient(id="123", active=False) + + with patch.object(fhir_gateway, "get_client", return_value=mock_client): + async with fhir_gateway.modify(Patient, "123") as patient: + assert patient == test_patient + patient.active = False + + mock_client.read.assert_called_once_with(Patient, "123") + mock_client.update.assert_called_once_with(test_patient) + + +async def test_modify_context_for_new_resource(fhir_gateway): + """Modify context manager creates new resources when no ID provided.""" + created_patient = Patient(id="456", active=True) + mock_client = AsyncMock() + mock_client.create.return_value = created_patient + + with patch.object(fhir_gateway, "get_client", return_value=mock_client): + async with fhir_gateway.modify(Patient) as patient: + assert patient.id is None # New resource + patient.active = True + + mock_client.create.assert_called_once() + # Verify the created resource was updated with returned values + assert patient.id == "456" + + +async def test_execute_with_client_handles_client_errors(fhir_gateway): + """_execute_with_client properly handles and re-raises client errors.""" + mock_client = AsyncMock() + mock_client.read.side_effect = Exception("Client error") + + with patch.object(fhir_gateway, "get_client", return_value=mock_client): + with patch( + "healthchain.gateway.core.fhirgateway.FHIRErrorHandler.handle_fhir_error" + ) as mock_handler: + mock_handler.side_effect = Exception("Handled error") + + with pytest.raises(Exception, match="Handled error"): + await fhir_gateway._execute_with_client( + "read", + resource_type=Patient, + resource_id="123", + client_args=(Patient, "123"), + ) + + mock_handler.assert_called_once() diff --git a/tests/gateway/test_notereader.py b/tests/gateway/test_notereader.py index 510e61be..865c884b 100644 --- a/tests/gateway/test_notereader.py +++ b/tests/gateway/test_notereader.py @@ -2,93 +2,79 @@ from unittest.mock import patch, MagicMock from healthchain.gateway.protocols.notereader import ( - NoteReaderGateway, + NoteReaderService, NoteReaderConfig, ) from healthchain.models.requests import CdaRequest from healthchain.models.responses.cdaresponse import CdaResponse -from healthchain.gateway.events.dispatcher import EventDispatcher -def test_notereader_gateway_initialization(): - """Test NoteReaderGateway initialization with default config""" - gateway = NoteReaderGateway() +@pytest.mark.parametrize( + "config_args,expected_values", + [ + # Default config via create() + ( + {}, + { + "service_name": "ICDSServices", + "namespace": "urn:epic-com:Common.2013.Services", + "system_type": "EHR_CDA", + "default_mount_path": "/notereader", + }, + ), + # Custom config + ( + { + "service_name": "CustomService", + "namespace": "urn:custom:namespace", + "system_type": "CUSTOM_SYSTEM", + "default_mount_path": "/custom-path", + }, + { + "service_name": "CustomService", + "namespace": "urn:custom:namespace", + "system_type": "CUSTOM_SYSTEM", + "default_mount_path": "/custom-path", + }, + ), + ], +) +def test_notereader_service_configuration(config_args, expected_values): + """NoteReaderService supports both default and custom configurations.""" + if config_args: + config = NoteReaderConfig(**config_args) + gateway = NoteReaderService(config=config) + else: + gateway = NoteReaderService.create() + + assert isinstance(gateway, NoteReaderService) assert isinstance(gateway.config, NoteReaderConfig) - assert gateway.config.service_name == "ICDSServices" - assert gateway.config.namespace == "urn:epic-com:Common.2013.Services" - assert gateway.config.system_type == "EHR_CDA" + for attr_name, expected_value in expected_values.items(): + assert getattr(gateway.config, attr_name) == expected_value -def test_notereader_gateway_create(): - """Test NoteReaderGateway.create factory method""" - gateway = NoteReaderGateway.create() - assert isinstance(gateway, NoteReaderGateway) - assert isinstance(gateway.config, NoteReaderConfig) +def test_notereader_handler_registration_methods(): + """NoteReaderService supports both direct registration and decorator-based registration.""" + gateway = NoteReaderService() -def test_notereader_gateway_register_handler(): - """Test handler registration with gateway""" - gateway = NoteReaderGateway() + # Test direct registration mock_handler = MagicMock(return_value=CdaResponse(document="test", error=None)) - - # Register handler gateway.register_handler("ProcessDocument", mock_handler) - - # Verify handler is registered assert "ProcessDocument" in gateway._handlers assert gateway._handlers["ProcessDocument"] == mock_handler - -def test_notereader_gateway_method_decorator(): - """Test method decorator for registering handlers""" - gateway = NoteReaderGateway() - - @gateway.method("ProcessDocument") - def process_document(request): - return CdaResponse(document="processed", error=None) - - # Verify handler is registered - assert "ProcessDocument" in gateway._handlers - - -def test_notereader_gateway_handle(): - """Test request handling logic directly (bypassing async methods)""" - gateway = NoteReaderGateway() - - # Register a handler - @gateway.method("ProcessDocument") - def process_document(request): + # Test decorator registration + @gateway.method("ProcessNotes") + def process_notes(request): return CdaResponse(document="processed", error=None) - # Create a request - request = CdaRequest(document="test") - - # Instead of testing the async handle method, let's test the core logic directly - # Extract the request - extracted_request = gateway._extract_request( - "ProcessDocument", {"request": request} - ) - assert extracted_request == request - - # Verify handler is properly registered - assert "ProcessDocument" in gateway._handlers - handler = gateway._handlers["ProcessDocument"] - - # Call the handler directly - handler_result = handler(request) - assert isinstance(handler_result, CdaResponse) - assert handler_result.document == "processed" - - # Verify process_result works correctly - processed_result = gateway._process_result(handler_result) - assert isinstance(processed_result, CdaResponse) - assert processed_result.document == "processed" - assert processed_result.error is None + assert "ProcessNotes" in gateway._handlers def test_notereader_gateway_extract_request(): """Test request extraction from parameters""" - gateway = NoteReaderGateway() + gateway = NoteReaderService() # Case 1: CdaRequest passed directly request = CdaRequest(document="test") @@ -110,7 +96,7 @@ def test_notereader_gateway_extract_request(): def test_notereader_gateway_process_result(): """Test processing results from handlers""" - gateway = NoteReaderGateway() + gateway = NoteReaderService() # Test with CdaResponse object response = CdaResponse(document="test", error=None) @@ -123,22 +109,15 @@ def test_notereader_gateway_process_result(): assert isinstance(result, CdaResponse) assert result.document == "test_dict" - # Test with unexpected type - result = gateway._process_result("just a string") - assert isinstance(result, CdaResponse) - assert result.document == "just a string" - assert result.error is None - @patch("healthchain.gateway.protocols.notereader.Application") @patch("healthchain.gateway.protocols.notereader.WsgiApplication") def test_notereader_gateway_create_wsgi_app(mock_wsgi, mock_application): """Test WSGI app creation for SOAP service""" - # Set up the mock to return a simple mock object instead of trying to create a real WsgiApplication mock_wsgi_instance = MagicMock() mock_wsgi.return_value = mock_wsgi_instance - gateway = NoteReaderGateway() + gateway = NoteReaderService() # Register required ProcessDocument handler @gateway.method("ProcessDocument") @@ -161,7 +140,7 @@ def process_document(request): def test_notereader_gateway_create_wsgi_app_no_handler(): """Test WSGI app creation fails without ProcessDocument handler""" - gateway = NoteReaderGateway() + gateway = NoteReaderService() # No handler registered - should raise ValueError with pytest.raises(ValueError): @@ -170,7 +149,7 @@ def test_notereader_gateway_create_wsgi_app_no_handler(): def test_notereader_gateway_get_metadata(): """Test retrieving gateway metadata""" - gateway = NoteReaderGateway() + gateway = NoteReaderService() # Register a handler to have some capabilities @gateway.method("ProcessDocument") @@ -181,68 +160,11 @@ def process_document(request): metadata = gateway.get_metadata() # Verify metadata contains expected keys - assert "gateway_type" in metadata - assert metadata["gateway_type"] == "NoteReaderGateway" + assert "service_type" in metadata + assert metadata["service_type"] in "NoteReaderService" assert "operations" in metadata assert "ProcessDocument" in metadata["operations"] assert "system_type" in metadata assert metadata["system_type"] == "EHR_CDA" assert "mount_path" in metadata assert metadata["mount_path"] == "/notereader" - - -def test_notereader_gateway_custom_config(): - """Test NoteReaderGateway with custom configuration""" - custom_config = NoteReaderConfig( - service_name="CustomService", - namespace="urn:custom:namespace", - system_type="CUSTOM_SYSTEM", - default_mount_path="/custom-path", - ) - - gateway = NoteReaderGateway(config=custom_config) - - assert gateway.config.service_name == "CustomService" - assert gateway.config.namespace == "urn:custom:namespace" - assert gateway.config.system_type == "CUSTOM_SYSTEM" - assert gateway.config.default_mount_path == "/custom-path" - - -@patch("healthchain.gateway.protocols.notereader.CDSServices") -def test_notereader_gateway_event_emission(mock_cds_services): - """Test that events are emitted when handling requests""" - # Create mock event dispatcher - mock_dispatcher = MagicMock(spec=EventDispatcher) - - # Create gateway with event dispatcher - gateway = NoteReaderGateway(event_dispatcher=mock_dispatcher) - - # Mock the service adapter directly - mock_service_adapter = MagicMock() - mock_cds_services._service = mock_service_adapter - - # Register a handler - @gateway.method("ProcessDocument") - def process_document(request): - return CdaResponse(document="processed", error=None) - - # Create WSGI app to install handler - with patch("healthchain.gateway.protocols.notereader.WsgiApplication"): - with patch("healthchain.gateway.protocols.notereader.Application"): - gateway.create_wsgi_app() - - # Get the adapter function from the CDSServices class (this would be set by create_wsgi_app) - mock_cds_services._service - - # Create a request and manually call the adapter function - # just to verify it would call our event dispatcher - with patch.object(gateway, "_emit_document_event") as mock_emit: - request = CdaRequest(document="test") - mock_handler = gateway._handlers["ProcessDocument"] - - # Simulate what would happen in service_adapter - result = mock_handler(request) - gateway._emit_document_event("ProcessDocument", request, result) - - # Verify event emission was called - mock_emit.assert_called_once() diff --git a/tests/gateway/test_protocols.py b/tests/gateway/test_protocols.py deleted file mode 100644 index 9ff02d86..00000000 --- a/tests/gateway/test_protocols.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -Tests for Protocol conformance in the HealthChain gateway system. - -This module tests whether the implementations of various components -correctly conform to their defined Protocol interfaces. -""" - -from typing import cast - -from healthchain.gateway.api.protocols import ( - HealthChainAPIProtocol, - GatewayProtocol, - EventDispatcherProtocol, -) -from healthchain.gateway.api.app import create_app -from healthchain.gateway.events.dispatcher import EventDispatcher -from tests.gateway.test_api_app import MockGateway - - -def test_healthchainapi_conforms_to_protocol(): - """Test that HealthChainAPI conforms to HealthChainAPIProtocol.""" - # Create an instance of HealthChainAPI - app = create_app() - - # Cast to the protocol type - this will fail at runtime if not compatible - protocol_app = cast(HealthChainAPIProtocol, app) - - # Basic assertions to check that it functions as expected - assert hasattr(protocol_app, "get_event_dispatcher") - assert hasattr(protocol_app, "get_gateway") - assert hasattr(protocol_app, "get_all_gateways") - assert hasattr(protocol_app, "register_gateway") - assert hasattr(protocol_app, "register_router") - - -def test_eventdispatcher_conforms_to_protocol(): - """Test that EventDispatcher conforms to EventDispatcherProtocol.""" - # Create an instance of EventDispatcher - dispatcher = EventDispatcher() - - # Cast to the protocol type - this will fail at runtime if not compatible - protocol_dispatcher = cast(EventDispatcherProtocol, dispatcher) - - # Basic assertions to check that it functions as expected - assert hasattr(protocol_dispatcher, "publish") - assert hasattr(protocol_dispatcher, "init_app") - assert hasattr(protocol_dispatcher, "register_handler") - - -def test_gateway_conforms_to_protocol(): - """Test that MockGateway conforms to GatewayProtocol.""" - # Create an instance of MockGateway - gateway = MockGateway() - - # Cast to the protocol type - this will fail at runtime if not compatible - protocol_gateway = cast(GatewayProtocol, gateway) - - # Basic assertions to check that it functions as expected - assert hasattr(protocol_gateway, "get_metadata") - assert hasattr(protocol_gateway, "set_event_dispatcher") - - -def test_typed_gateway_access(): - """Test accessing a gateway with a specific protocol type.""" - # Create app and gateway - app = create_app() - gateway = MockGateway() - app.register_gateway(gateway) - - # Test getting the gateway as a general GatewayProtocol - retrieved_gateway = app.get_gateway("MockGateway") - assert retrieved_gateway is not None - - # Cast to protocol type - will fail if not compatible - protocol_gateway = cast(GatewayProtocol, retrieved_gateway) - assert protocol_gateway.get_metadata() == gateway.get_metadata() diff --git a/tests/sandbox/test_cds_sandbox.py b/tests/sandbox/test_cds_sandbox.py index 82663ae0..bf51ec06 100644 --- a/tests/sandbox/test_cds_sandbox.py +++ b/tests/sandbox/test_cds_sandbox.py @@ -1,7 +1,7 @@ from unittest.mock import patch, MagicMock import healthchain as hc -from healthchain.gateway.protocols.cdshooks import CDSHooksGateway +from healthchain.gateway.protocols.cdshooks import CDSHooksService from healthchain.gateway.api import HealthChainAPI from healthchain.models.requests.cdsrequest import CDSRequest from healthchain.models.responses.cdsresponse import CDSResponse, Card @@ -14,7 +14,7 @@ def test_cdshooks_sandbox_integration(): """Test CDSHooks service integration with sandbox decorator""" # Create HealthChainAPI instead of FastAPI app = HealthChainAPI() - cds_service = CDSHooksGateway() + cds_service = CDSHooksService() # Register a hook handler for the service @cds_service.hook("patient-view", id="test-patient-view") @@ -26,7 +26,7 @@ async def handle_patient_view(request: CDSRequest) -> CDSResponse: ) # Register the service with the HealthChainAPI - app.register_gateway(cds_service, "/cds") + app.register_service(cds_service, "/cds") # Define a sandbox class using the CDSHooks service @hc.sandbox("http://localhost:8000/") diff --git a/tests/sandbox/test_clients.py b/tests/sandbox/test_clients.py index 320c2cb5..694653ac 100644 --- a/tests/sandbox/test_clients.py +++ b/tests/sandbox/test_clients.py @@ -26,7 +26,7 @@ def test_generate_request(ehr_client, mock_strategy): assert len(ehr_client.request_data) == 1 -@pytest.mark.anyio +@pytest.mark.asyncio @patch.object( httpx.AsyncClient, "post", @@ -38,7 +38,7 @@ async def test_send_request(ehr_client): assert all(response["status"] == "success" for response in responses) -@pytest.mark.anyio +@pytest.mark.asyncio async def test_logging_on_send_request_error(caplog, ehr_client): with patch.object(httpx.AsyncClient, "post") as mock_post: mock_post.return_value = Mock() diff --git a/tests/sandbox/test_clindoc_sandbox.py b/tests/sandbox/test_clindoc_sandbox.py index 99ebd93f..b071b778 100644 --- a/tests/sandbox/test_clindoc_sandbox.py +++ b/tests/sandbox/test_clindoc_sandbox.py @@ -1,7 +1,7 @@ from unittest.mock import patch, MagicMock import healthchain as hc -from healthchain.gateway.protocols.notereader import NoteReaderGateway +from healthchain.gateway.protocols.notereader import NoteReaderService from healthchain.gateway.api import HealthChainAPI from healthchain.models.requests import CdaRequest from healthchain.models.responses.cdaresponse import CdaResponse @@ -13,7 +13,7 @@ def test_notereader_sandbox_integration(): """Test NoteReaderService integration with sandbox decorator""" # Use HealthChainAPI instead of FastAPI app = HealthChainAPI() - note_service = NoteReaderGateway() + note_service = NoteReaderService() # Register a method handler for the service @note_service.method("ProcessDocument") @@ -21,7 +21,7 @@ def process_document(cda_request: CdaRequest) -> CdaResponse: return CdaResponse(document="document", error=None) # Register service with HealthChainAPI - app.register_gateway(note_service, "/notereader") + app.register_service(note_service, "/notereader") # Define a sandbox class that uses the NoteReader service @hc.sandbox("http://localhost:8000/")