diff --git a/README.md b/README.md
index 6fc80e13..e8dc57fb 100644
--- a/README.md
+++ b/README.md
@@ -89,6 +89,12 @@ It will listen on the IP address `0.0.0.0`, which means all the available IP add
In most cases you would (and should) have a "termination proxy" handling HTTPS for you on top, this will depend on how you deploy your application, your provider might do this for you, or you might need to set it up yourself. You can learn more about it in the FastAPI Deployment documentation.
+## `fastapi schema`
+
+When you run `fastapi schema`, it will generate a swagger/openapi document.
+
+This document will be output to stderr by default, however `--output ` option can be used to write output into file. You can control the format of the JSON file by specifying indent level with `--indent #`. If set to 0, JSON will be in the minimal/compress form. Default is 2 spaces.
+
## License
This project is licensed under the terms of the MIT license.
diff --git a/src/fastapi_cli/cli.py b/src/fastapi_cli/cli.py
index 28afa297..50fdd67e 100644
--- a/src/fastapi_cli/cli.py
+++ b/src/fastapi_cli/cli.py
@@ -1,4 +1,6 @@
+import json
import logging
+import sys
from pathlib import Path
from typing import Any, List, Union
@@ -7,7 +9,11 @@
from rich.tree import Tree
from typing_extensions import Annotated
-from fastapi_cli.discover import get_import_data, get_import_data_from_import_string
+from fastapi_cli.discover import (
+ get_app,
+ get_import_data,
+ get_import_data_from_import_string,
+)
from fastapi_cli.exceptions import FastAPICLIException
from . import __version__
@@ -397,5 +403,42 @@ def run(
)
+@app.command()
+def schema(
+ path: Annotated[
+ Union[Path, None],
+ typer.Argument(
+ help="A path to a Python file or package directory (with [blue]__init__.py[/blue] files) containing a [bold]FastAPI[/bold] app. If not provided, a default set of paths will be tried."
+ ),
+ ] = None,
+ *,
+ app: Annotated[
+ Union[str, None],
+ typer.Option(
+ help="The name of the variable that contains the [bold]FastAPI[/bold] app in the imported module or package. If not provided, it is detected automatically."
+ ),
+ ] = None,
+ output: Annotated[
+ Union[str, None],
+ typer.Option(
+ help="The filename to write schema to. If not provided, write to stderr."
+ ),
+ ] = None,
+ indent: Annotated[
+ int,
+ typer.Option(help="JSON format indent. If 0, disable pretty printing"),
+ ] = 2,
+) -> Any:
+ """Generate schema"""
+ fastapi_app = get_app(path=path, app_name=app)
+ schema = fastapi_app.openapi()
+
+ stream = open(output, "w") if output else sys.stderr
+ json.dump(schema, stream, indent=indent if indent > 0 else None)
+ if output:
+ stream.close()
+ return 0
+
+
def main() -> None:
app()
diff --git a/src/fastapi_cli/discover.py b/src/fastapi_cli/discover.py
index b174f8fb..559c2d73 100644
--- a/src/fastapi_cli/discover.py
+++ b/src/fastapi_cli/discover.py
@@ -1,9 +1,10 @@
import importlib
import sys
+from contextlib import contextmanager
from dataclasses import dataclass
from logging import getLogger
from pathlib import Path
-from typing import List, Union
+from typing import Iterator, List, Union
from fastapi_cli.exceptions import FastAPICLIException
@@ -41,6 +42,18 @@ class ModuleData:
extra_sys_path: Path
module_paths: List[Path]
+ @contextmanager
+ def sys_path(self) -> Iterator[str]:
+ """Context manager to temporarily alter sys.path"""
+ extra_sys_path = str(self.extra_sys_path) if self.extra_sys_path else ""
+ if extra_sys_path:
+ logger.debug("Adding %s to sys.path...", extra_sys_path)
+ sys.path.insert(0, extra_sys_path)
+ yield extra_sys_path
+ if extra_sys_path and sys.path and sys.path[0] == extra_sys_path:
+ logger.debug("Removing %s from sys.path...", extra_sys_path)
+ sys.path.pop(0)
+
def get_module_data_from_path(path: Path) -> ModuleData:
use_path = path.resolve()
@@ -153,3 +166,16 @@ def get_import_data_from_import_string(import_string: str) -> ImportData:
),
import_string=import_string,
)
+
+
+def get_app(
+ *, path: Union[Path, None] = None, app_name: Union[str, None] = None
+) -> FastAPI:
+ """Get the FastAPI app instance from the given path and app name."""
+ import_data: ImportData = get_import_data(path=path, app_name=app_name)
+ mod_data, use_app_name = import_data.module_data, import_data.app_name
+ with mod_data.sys_path():
+ mod = importlib.import_module(mod_data.module_import_str)
+ app = getattr(mod, use_app_name)
+ ## get_import_string_parts guarantees app is FastAPI object
+ return app # type: ignore[no-any-return]
diff --git a/tests/assets/openapi.json b/tests/assets/openapi.json
new file mode 100644
index 00000000..61506d7c
--- /dev/null
+++ b/tests/assets/openapi.json
@@ -0,0 +1,25 @@
+{
+ "openapi": "3.1.0",
+ "info": {
+ "title": "FastAPI",
+ "version": "0.1.0"
+ },
+ "paths": {
+ "/": {
+ "get": {
+ "summary": "App Root",
+ "operationId": "app_root__get",
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {
+ "application/json": {
+ "schema": {}
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/tests/test_cli.py b/tests/test_cli.py
index b87a811a..abc18cb4 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,3 +1,4 @@
+import os
import subprocess
import sys
from pathlib import Path
@@ -7,6 +8,7 @@
from typer.testing import CliRunner
from fastapi_cli.cli import app
+from fastapi_cli.exceptions import FastAPICLIException
from fastapi_cli.utils.cli import get_uvicorn_log_config
from tests.utils import changing_dir
@@ -15,6 +17,13 @@
assets_path = Path(__file__).parent / "assets"
+def read_file(filename: str, strip: bool = True) -> str:
+ """Read file and return content as string"""
+ with open("openapi.json") as stream:
+ data = stream.read()
+ return data.strip() if data and strip else data
+
+
def test_dev() -> None:
with changing_dir(assets_path):
with patch.object(uvicorn, "run") as mock_run:
@@ -377,6 +386,51 @@ def test_dev_help() -> None:
assert "Use multiple worker processes." not in result.output
+def test_schema() -> None:
+ with changing_dir(assets_path):
+ with open("openapi.json") as stream:
+ expected = stream.read().strip()
+ assert expected != "", "Failed to read expected result"
+ result = runner.invoke(app, ["schema", "single_file_app.py"])
+ assert result.exit_code == 0, result.output
+ assert expected in result.output, result.output
+
+
+def test_schema_file() -> None:
+ with changing_dir(assets_path):
+ filename = "unit-test.json"
+ expected = read_file("openapi.json", strip=True)
+ assert expected != "", "Failed to read expected result"
+ result = runner.invoke(
+ app, ["schema", "single_file_app.py", "--output", filename]
+ )
+ assert os.path.isfile(filename)
+ actual = read_file(filename, strip=True)
+ os.remove(filename)
+ assert result.exit_code == 0, result.output
+ assert expected == actual
+
+
+def test_schema_invalid_path() -> None:
+ with changing_dir(assets_path):
+ result = runner.invoke(app, ["schema", "invalid/single_file_app.py"])
+ assert result.exit_code == 1, result.output
+ assert isinstance(result.exception, FastAPICLIException)
+ assert "Path does not exist invalid/single_file_app.py" in str(result.exception)
+
+
+#
+#
+# def test_schema_invalid_package() -> None:
+# with changing_dir(assets_path):
+# result = runner.invoke(
+# app, ["schema", "broken_package/mod/app.py"]
+# )
+# assert result.exit_code == 1, result.output
+# assert isinstance(result.exception, ImportError)
+# assert "attempted relative import beyond top-level package" in str(result.exception)
+
+
def test_run_help() -> None:
result = runner.invoke(app, ["run", "--help"])
assert result.exit_code == 0, result.output
diff --git a/tests/test_utils_package.py b/tests/test_utils_package.py
index 407561da..c43c9b98 100644
--- a/tests/test_utils_package.py
+++ b/tests/test_utils_package.py
@@ -3,7 +3,7 @@
import pytest
from pytest import CaptureFixture
-from fastapi_cli.discover import get_import_data
+from fastapi_cli.discover import get_app, get_import_data
from fastapi_cli.exceptions import FastAPICLIException
from tests.utils import changing_dir
@@ -169,6 +169,17 @@ def test_broken_package_dir(capsys: CaptureFixture[str]) -> None:
assert "Ensure all the package directories have an __init__.py file" in captured.out
+def test_get_app_broken_package_dir(capsys: CaptureFixture[str]) -> None:
+ with changing_dir(assets_path):
+ # TODO (when deprecating Python 3.8): remove ValueError
+ with pytest.raises((ImportError, ValueError)):
+ get_app(path=Path("broken_package/mod/app.py"))
+
+ captured = capsys.readouterr()
+ assert "Import error:" in captured.out
+ assert "Ensure all the package directories have an __init__.py file" in captured.out
+
+
def test_package_dir_no_app() -> None:
with changing_dir(assets_path):
with pytest.raises(FastAPICLIException) as e: