diff --git a/src/fastapi_cli/cli.py b/src/fastapi_cli/cli.py index 28afa29..849379b 100644 --- a/src/fastapi_cli/cli.py +++ b/src/fastapi_cli/cli.py @@ -100,6 +100,7 @@ def _run( entrypoint: Union[str, None] = None, proxy_headers: bool = False, forwarded_allow_ips: Union[str, None] = None, + is_factory: bool = False, ) -> None: with get_rich_toolkit() as toolkit: server_type = "development" if command == "dev" else "production" @@ -136,13 +137,14 @@ def _run( toolkit.print(root_tree, tag="module") toolkit.print_line() + imported_object_type = "factory" if is_factory else "app" toolkit.print( - "Importing the FastAPI app object from the module with the following code:", + f"Importing the FastAPI {imported_object_type} object from the module with the following code:", tag="code", ) toolkit.print_line() toolkit.print( - f"[underline]from [bold]{module_data.module_import_str}[/bold] import [bold]{import_data.app_name}[/bold]" + f"[underline]from [bold]{module_data.module_import_str}[/bold] import [bold]{import_data.candidate_name}[/bold]" ) toolkit.print_line() @@ -187,6 +189,7 @@ def _run( proxy_headers=proxy_headers, forwarded_allow_ips=forwarded_allow_ips, log_config=get_uvicorn_log_config(), + factory=is_factory, ) @@ -195,7 +198,7 @@ def dev( path: Annotated[ Union[Path, None], typer.Argument( - help="A path to a Python file or package directory (with [blue]__init__.py[/blue] files) containing a [bold]FastAPI[/bold] app. If not provided, a default set of paths will be tried." + help="A path to a Python file or package directory (with [blue]__init__.py[/blue] files) containing a [bold]FastAPI[/bold] app or app factory. If not provided, a default set of paths will be tried." ), ] = None, *, @@ -250,6 +253,12 @@ def dev( help="Comma separated list of IP Addresses to trust with proxy headers. The literal '*' means trust everything." ), ] = None, + factory: Annotated[ + bool, + typer.Option( + help="Treat [bold]path[bold] as an application factory, i.e. a () -> callable." + ), + ] = False, ) -> Any: """ Run a [bold]FastAPI[/bold] app in [yellow]development[/yellow] mode. ๐Ÿงช @@ -287,6 +296,7 @@ def dev( command="dev", proxy_headers=proxy_headers, forwarded_allow_ips=forwarded_allow_ips, + is_factory=factory, ) @@ -356,6 +366,12 @@ def run( help="Comma separated list of IP Addresses to trust with proxy headers. The literal '*' means trust everything." ), ] = None, + factory: Annotated[ + bool, + typer.Option( + help="Treat [bold]path[bold] as an application factory, i.e. a () -> callable." + ), + ] = False, ) -> Any: """ Run a [bold]FastAPI[/bold] app in [green]production[/green] mode. ๐Ÿš€ @@ -394,6 +410,7 @@ def run( command="run", proxy_headers=proxy_headers, forwarded_allow_ips=forwarded_allow_ips, + is_factory=factory, ) diff --git a/src/fastapi_cli/discover.py b/src/fastapi_cli/discover.py index b174f8f..294eaea 100644 --- a/src/fastapi_cli/discover.py +++ b/src/fastapi_cli/discover.py @@ -100,18 +100,23 @@ def get_app_name(*, mod_data: ModuleData, app_name: Union[str, None] = None) -> obj = getattr(mod, name) if isinstance(obj, FastAPI): return name - raise FastAPICLIException("Could not find FastAPI app in module, try using --app") + raise FastAPICLIException( + "Could not find FastAPI app or app factory in module, try using --app" + ) @dataclass class ImportData: - app_name: str + # candidate is an app or a factory + candidate_name: str module_data: ModuleData import_string: str def get_import_data( - *, path: Union[Path, None] = None, app_name: Union[str, None] = None + *, + path: Union[Path, None] = None, + app_name: Union[str, None] = None, ) -> ImportData: if not path: path = get_default_path() @@ -128,7 +133,7 @@ def get_import_data( import_string = f"{mod_data.module_import_str}:{use_app_name}" return ImportData( - app_name=use_app_name, module_data=mod_data, import_string=import_string + candidate_name=use_app_name, module_data=mod_data, import_string=import_string ) @@ -145,7 +150,7 @@ def get_import_data_from_import_string(import_string: str) -> ImportData: sys.path.insert(0, str(here)) return ImportData( - app_name=app_name, + candidate_name=app_name, module_data=ModuleData( module_import_str=module_str, extra_sys_path=here, diff --git a/tests/test_cli.py b/tests/test_cli.py index b87a811..084fc58 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -32,6 +32,7 @@ def test_dev() -> None: "proxy_headers": True, "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: single_file_app:app" in result.output assert "Starting development server ๐Ÿš€" in result.output @@ -62,6 +63,7 @@ def test_dev_package() -> None: "proxy_headers": True, "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: nested_package.package:app" in result.output assert "Starting development server ๐Ÿš€" in result.output @@ -111,6 +113,7 @@ def test_dev_args() -> None: "proxy_headers": False, "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: single_file_app:api" in result.output assert "Starting development server ๐Ÿš€" in result.output @@ -141,6 +144,7 @@ def test_dev_env_vars() -> None: "proxy_headers": True, "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: single_file_app:app" in result.output assert "Starting development server ๐Ÿš€" in result.output @@ -178,6 +182,7 @@ def test_dev_env_vars_and_args() -> None: "proxy_headers": True, "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: single_file_app:app" in result.output assert "Starting development server ๐Ÿš€" in result.output @@ -206,6 +211,7 @@ def test_run() -> None: "proxy_headers": True, "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: single_file_app:app" in result.output assert "Starting production server ๐Ÿš€" in result.output @@ -232,6 +238,7 @@ def test_run_trust_proxy() -> None: "proxy_headers": True, "forwarded_allow_ips": "*", "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: single_file_app:app" in result.output assert "Starting production server ๐Ÿš€" in result.output @@ -278,6 +285,7 @@ def test_run_args() -> None: "proxy_headers": False, "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: single_file_app:api" in result.output @@ -309,6 +317,7 @@ def test_run_env_vars() -> None: "proxy_headers": True, "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: single_file_app:app" in result.output assert "Starting production server ๐Ÿš€" in result.output @@ -342,6 +351,7 @@ def test_run_env_vars_and_args() -> None: "proxy_headers": True, "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: single_file_app:app" in result.output assert "Starting production server ๐Ÿš€" in result.output @@ -428,6 +438,7 @@ def test_dev_with_import_string() -> None: "root_path": "", "proxy_headers": True, "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: single_file_app:api" in result.output @@ -449,6 +460,7 @@ def test_run_with_import_string() -> None: "root_path": "", "proxy_headers": True, "log_config": get_uvicorn_log_config(), + "factory": False, } assert "Using import string: single_file_app:app" in result.output diff --git a/tests/test_discover.py b/tests/test_discover.py index b105205..2c42ec6 100644 --- a/tests/test_discover.py +++ b/tests/test_discover.py @@ -15,7 +15,7 @@ def test_get_import_data_from_import_string_valid() -> None: result = get_import_data_from_import_string("module.submodule:app") assert isinstance(result, ImportData) - assert result.app_name == "app" + assert result.candidate_name == "app" assert result.import_string == "module.submodule:app" assert result.module_data.module_import_str == "module.submodule" assert result.module_data.extra_sys_path == Path(".").resolve() diff --git a/tests/test_utils_package.py b/tests/test_utils_package.py index 407561d..595c422 100644 --- a/tests/test_utils_package.py +++ b/tests/test_utils_package.py @@ -174,7 +174,8 @@ def test_package_dir_no_app() -> None: with pytest.raises(FastAPICLIException) as e: get_import_data(path=Path("package/core/utils.py")) assert ( - "Could not find FastAPI app in module, try using --app" in e.value.args[0] + "Could not find FastAPI app or app factory in module, try using --app" + in e.value.args[0] )