diff --git a/fastdeploy/__init__.py b/fastdeploy/__init__.py index edb2aa43a2..b386676da1 100644 --- a/fastdeploy/__init__.py +++ b/fastdeploy/__init__.py @@ -44,6 +44,9 @@ # TODO(tangbinhan): remove this code +__version__ = "2.3.0-dev" + + def _patch_fastsafetensors(): try: file_path = ( diff --git a/fastdeploy/entrypoints/cli/main.py b/fastdeploy/entrypoints/cli/main.py index a4ba74afed..1966072603 100644 --- a/fastdeploy/entrypoints/cli/main.py +++ b/fastdeploy/entrypoints/cli/main.py @@ -17,15 +17,17 @@ # This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/cli/main.py from __future__ import annotations -import importlib.metadata +from fastdeploy import __version__ def main(): import fastdeploy.entrypoints.cli.openai + import fastdeploy.entrypoints.cli.serve from fastdeploy.utils import FlexibleArgumentParser CMD_MODULES = [ fastdeploy.entrypoints.cli.openai, + fastdeploy.entrypoints.cli.serve, ] parser = FlexibleArgumentParser(description="FastDeploy CLI") @@ -33,7 +35,7 @@ def main(): "-v", "--version", action="version", - version=importlib.metadata.version("fastdeploy"), + version=__version__, ) subparsers = parser.add_subparsers(required=False, dest="subparser") cmds = {} diff --git a/fastdeploy/entrypoints/cli/serve.py b/fastdeploy/entrypoints/cli/serve.py new file mode 100755 index 0000000000..a42fa690a1 --- /dev/null +++ b/fastdeploy/entrypoints/cli/serve.py @@ -0,0 +1,51 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +# This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/cli/serve.py + +import argparse + +import fastdeploy.entrypoints.openai.api_server as api_server +from fastdeploy.entrypoints.cli.types import CLISubcommand +from fastdeploy.entrypoints.openai.api_server import make_arg_parser +from fastdeploy.utils import FlexibleArgumentParser, YamlInputAction + + +class ServeSubcommand(CLISubcommand): + """The `serve` subcommand for the fastdeploy CLI.""" + + name = "serve" + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + api_server.main(args) + + def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + serve_parser = subparsers.add_parser( + name=self.name, + help="Start the FastDeploy OpenAI Compatible API server.", + description="Start the FastDeploy OpenAI Compatible API server.", + usage="fastdeploy serve [model_tag] [options]", + ) + serve_parser = make_arg_parser(serve_parser) + serve_parser.add_argument( + "--config", action=YamlInputAction, help="Read CLI options from a config file. 
Must be a YAML file" + ) + return serve_parser + + +def cmd_init() -> list[CLISubcommand]: + return [ServeSubcommand()] diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py old mode 100644 new mode 100755 index 40227d0a00..f74be5d166 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -14,16 +14,26 @@ # limitations under the License. """ +import argparse import asyncio +import multiprocessing import os +import signal +import socket +import sys import threading import time import traceback +import weakref from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from multiprocessing import current_process +from multiprocessing import connection, current_process +from multiprocessing.process import BaseProcess +from typing import Any, Callable, Optional +import setproctitle import uvicorn +import uvloop import zmq from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError @@ -64,46 +74,51 @@ api_server_logger, console_logger, is_port_available, + is_valid_ipv6_address, + kill_process_tree, retrive_model_from_server, ) -parser = FlexibleArgumentParser() -parser.add_argument("--port", default=8000, type=int, help="port to the http server") -parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server") -parser.add_argument("--workers", default=1, type=int, help="number of workers") -parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server") -parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server") -parser.add_argument( - "--max-waiting-time", - default=-1, - type=int, - help="max waiting time for connection, if set value -1 means no waiting time limit", -) -parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency") +llm_engine = None -parser.add_argument( - "--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output. " -) -parser.add_argument( - "--timeout-graceful-shutdown", - default=0, - type=int, - help="timeout for graceful shutdown in seconds (used by uvicorn)", -) -parser = EngineArgs.add_cli_args(parser) -args = parser.parse_args() +def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument("--port", default=8000, type=int, help="port to the http server") + parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server") + parser.add_argument("--workers", default=1, type=int, help="number of workers") + parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server") + parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server") + parser.add_argument( + "--max-waiting-time", + default=-1, + type=int, + help="max waiting time for connection, if set value -1 means no waiting time limit", + ) + parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency") -console_logger.info(f"Number of api-server workers: {args.workers}.") + parser.add_argument( + "--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output. 
" + ) + parser.add_argument( + "--timeout-graceful-shutdown", + default=0, + type=int, + help="timeout for graceful shutdown in seconds (used by uvicorn)", + ) + parser = EngineArgs.add_cli_args(parser) + return parser -args.model = retrive_model_from_server(args.model, args.revision) -chat_template = load_chat_template(args.chat_template, args.model) -if args.tool_parser_plugin: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) -llm_engine = None +def rewrite_args(args: argparse.Namespace) -> argparse.Namespace: + console_logger.info(f"Number of api-server workers: {args.workers}.") -def load_engine(): + args.model = retrive_model_from_server(args.model, args.revision) + if args.tool_parser_plugin: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + return args + + +def load_engine(args: argparse.Namespace): """ load engine """ @@ -122,13 +137,7 @@ def load_engine(): return engine -app = FastAPI() - -MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers -connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS) - - -def load_data_service(): +def load_data_service(args: argparse.Namespace) -> ExpertService: """ load data service """ @@ -147,434 +156,706 @@ def load_data_service(): return expert_service -@asynccontextmanager -async def lifespan(app: FastAPI): - """ - async context manager for FastAPI lifespan - """ - - if args.tokenizer is None: - args.tokenizer = args.model - if current_process().name != "MainProcess": - pid = os.getppid() - else: - pid = os.getpid() - api_server_logger.info(f"{pid}") - - if args.served_model_name is not None: - served_model_names = args.served_model_name - verification = True - else: - served_model_names = args.model - verification = False - model_paths = [ModelPath(name=served_model_names, model_path=args.model, verification=verification)] - - engine_client = EngineClient( - model_name_or_path=args.model, - tokenizer=args.tokenizer, - max_model_len=args.max_model_len, - tensor_parallel_size=args.tensor_parallel_size, - pid=pid, - port=int(args.engine_worker_queue_port[args.local_data_parallel_id]), - limit_mm_per_prompt=args.limit_mm_per_prompt, - mm_processor_kwargs=args.mm_processor_kwargs, - # args.enable_mm, - reasoning_parser=args.reasoning_parser, - data_parallel_size=args.data_parallel_size, - enable_logprob=args.enable_logprob, - workers=args.workers, - tool_parser=args.tool_call_parser, - ) - await engine_client.connection_manager.initialize() - app.state.dynamic_load_weight = args.dynamic_load_weight - model_handler = OpenAIServingModels( - model_paths, - args.max_model_len, - args.ips, - ) - app.state.model_handler = model_handler - chat_handler = OpenAIServingChat( - engine_client, - app.state.model_handler, - pid, - args.ips, - args.max_waiting_time, - chat_template, - args.enable_mm_output, - args.tokenizer_base_url, - ) - completion_handler = OpenAIServingCompletion( - engine_client, - app.state.model_handler, - pid, - args.ips, - args.max_waiting_time, - ) - engine_client.create_zmq_client(model=pid, mode=zmq.PUSH) - engine_client.pid = pid - app.state.engine_client = engine_client - app.state.chat_handler = chat_handler - app.state.completion_handler = completion_handler - global llm_engine - if llm_engine is not None: - llm_engine.engine.data_processor = engine_client.data_processor - yield - # close zmq - try: - await engine_client.connection_manager.close() - engine_client.zmq_client.close() - from prometheus_client import multiprocess - - 
multiprocess.mark_process_dead(os.getpid()) - api_server_logger.info(f"Closing metrics client pid: {pid}") - except Exception as e: - api_server_logger.warning(f"exit error: {e}, {str(traceback.format_exc())}") - - -app = FastAPI(lifespan=lifespan) -app.add_exception_handler(RequestValidationError, ExceptionHandler.handle_request_validation_exception) -app.add_exception_handler(Exception, ExceptionHandler.handle_exception) -instrument(app) - - -@asynccontextmanager -async def connection_manager(): - """ - async context manager for connection manager - """ - try: - await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001) - yield - except asyncio.TimeoutError: - api_server_logger.info(f"Reach max request concurrency, semaphore status: {connection_semaphore.status()}") - raise HTTPException( - status_code=429, detail=f"Too many requests,current max concurrency is {args.max_concurrency}" - ) - - -# TODO 传递真实引擎值 通过pid 获取状态 -@app.get("/health") -def health(request: Request) -> Response: - """Health check.""" - - status, msg = app.state.engine_client.check_health() - if not status: - return Response(content=msg, status_code=404) - status, msg = app.state.engine_client.is_workers_alive() - if not status: - return Response(content=msg, status_code=304) - return Response(status_code=200) - - -@app.get("/load") -async def list_all_routes(): - """ - 列出所有以/v1开头的路由信息 - - Args: - 无参数 - - Returns: - dict: 包含所有符合条件的路由信息的字典,格式如下: - { - "routes": [ - { - "path": str, # 路由路径 - "methods": list, # 支持的HTTP方法列表,已排序 - "tags": list # 路由标签列表,默认为空列表 - }, - ... - ] - } - - """ - routes_info = [] +# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 - for route in app.routes: - # 直接检查路径是否以/v1开头 - if route.path.startswith("/v1"): - methods = sorted(route.methods) - tags = getattr(route, "tags", []) or [] - routes_info.append({"path": route.path, "methods": methods, "tags": tags}) - return {"routes": routes_info} +def set_ulimit(target_soft_limit=65535): + if sys.platform.startswith("win"): + api_server_logger.info("Windows detected, skipping ulimit adjustment.") + return -@app.api_route("/ping", methods=["GET", "POST"]) -def ping(raw_request: Request) -> Response: - """Ping check. Endpoint required for SageMaker""" - return health(raw_request) - + import resource -def wrap_streaming_generator(original_generator: AsyncGenerator): - """ - Wrap an async generator to release the connection semaphore when the generator is finished. - """ + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) - async def wrapped_generator(): + if current_soft < target_soft_limit: try: - async for chunk in original_generator: - yield chunk - finally: - api_server_logger.debug(f"release: {connection_semaphore.status()}") - connection_semaphore.release() + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + api_server_logger.warning( + "Found ulimit of %s and failed to automatically increase " + "with error %s. This can cause fd limit errors like " + "`OSError: [Errno 24] Too many open files`. Consider " + "increasing with ulimit -n", + current_soft, + e, + ) - return wrapped_generator +class APIServerProcessManager: + """Manages a group of API server processes. + + Handles creation, monitoring, and termination of API server worker + processes. Also monitors extra processes to check if they are healthy. 
+ """ + + def __init__( + self, + target_server_fn: Callable, + listen_address: str, + sock: Any, + args: argparse.Namespace, + num_servers: int, + ): + """Initialize and start API server worker processes. + + Args: + target_server_fn: Function to call for each API server process + listen_address: Address to listen for client connections + sock: Socket for client connections + args: Command line arguments + num_servers: Number of API server processes to start + stats_update_address: Optional stats update address + """ + self.listen_address = listen_address + self.sock = sock + self.args = args + + # Start API servers + spawn_context = multiprocessing.get_context("spawn") + self.processes: list[BaseProcess] = [] + + for i in range(num_servers): + client_config = {"client_count": num_servers, "client_index": i} + + proc = spawn_context.Process( + target=target_server_fn, name=f"ApiServer_{i}", args=(args, listen_address, sock, client_config) + ) + self.processes.append(proc) + proc.start() + + api_server_logger.info("Started %d API server processes", len(self.processes)) + + # Shutdown only the API server processes on garbage collection + # The extra processes are managed by their owners + self._finalizer = weakref.finalize(self, self.shutdown, self.processes) + + def close(self) -> None: + self._finalizer() + + # Note(rob): shutdown function cannot be a bound method, + # else the gc cannot collect the object. + def shutdown(self, procs: list[BaseProcess]): + # Shutdown the process. + for proc in procs: + if proc.is_alive(): + proc.terminate() + + # Allow 5 seconds for remaining procs to terminate. + deadline = time.monotonic() + 5 + for proc in procs: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + if proc.is_alive(): + proc.join(remaining) + + for proc in procs: + if proc.is_alive() and (pid := proc.pid) is not None: + kill_process_tree(pid) + + +def create_server_socket(addr: tuple[str, int]) -> socket.socket: + family = socket.AF_INET + if is_valid_ipv6_address(addr[0]): + family = socket.AF_INET6 + + sock = socket.socket(family=family, type=socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + sock.bind(addr) + return sock + + +def setup_server(args): + """Validate API server args, set up signal handler, create socket + ready to serve.""" + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. 
+ # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + addr, port = sock_addr + # is_ssl = args.ssl_keyfile and args.ssl_certfile + is_ssl = False + host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + return listen_address, sock + + +async def serve_http(app: FastAPI, sock: Optional[socket.socket], **uvicorn_kwargs: Any): + config = uvicorn.Config(app, **uvicorn_kwargs) + config.workers = 1 + config.load() + server = uvicorn.Server(config=config) + loop = asyncio.get_running_loop() + server_task = loop.create_task(server.serve(sockets=[sock] if sock else None)) + + async def dummy_shutdown() -> None: + pass -@app.post("/v1/chat/completions") -async def create_chat_completion(request: ChatCompletionRequest): - """ - Create a chat completion for the provided prompt and parameters. - """ - api_server_logger.info(f"Chat Received request: {request.model_dump_json()}") - if app.state.dynamic_load_weight: - status, msg = app.state.engine_client.is_workers_alive() - if not status: - return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304) try: - async with connection_manager(): - inject_to_metadata(request) - generator = await app.state.chat_handler.create_chat_completion(request) - if isinstance(generator, ErrorResponse): - api_server_logger.debug(f"release: {connection_semaphore.status()}") - connection_semaphore.release() - return JSONResponse(content=generator.model_dump(), status_code=500) - elif isinstance(generator, ChatCompletionResponse): - api_server_logger.debug(f"release: {connection_semaphore.status()}") - connection_semaphore.release() - return JSONResponse(content=generator.model_dump()) - else: - wrapped_generator = wrap_streaming_generator(generator) - return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream") - - except HTTPException as e: - api_server_logger.error(f"Error in chat completion: {str(e)}") - return JSONResponse(status_code=e.status_code, content={"detail": e.detail}) + await server_task + return dummy_shutdown() + except asyncio.CancelledError: + return server.shutdown() -@app.post("/v1/completions") -async def create_completion(request: CompletionRequest): - """ - Create a completion for the provided prompt and parameters. 
- """ - api_server_logger.info(f"Completion Received request: {request.model_dump_json()}") - if app.state.dynamic_load_weight: - status, msg = app.state.engine_client.is_workers_alive() - if not status: - return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304) - try: - async with connection_manager(): - generator = await app.state.completion_handler.create_completion(request) - if isinstance(generator, ErrorResponse): - connection_semaphore.release() - return JSONResponse(content=generator.model_dump(), status_code=500) - elif isinstance(generator, CompletionResponse): - connection_semaphore.release() - return JSONResponse(content=generator.model_dump()) - else: - wrapped_generator = wrap_streaming_generator(generator) - return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream") - except HTTPException as e: - return JSONResponse(status_code=e.status_code, content={"detail": e.detail}) +def run_api_server_worker_proc(args, listen_address, sock, client_config=None, **uvicorn_kwargs) -> None: + """Entrypoint for individual API server worker processes.""" + # 设置进程标题,并为标准输出和标准错误添加特定于进程的前缀 + # Set process title and add process-specific prefix to stdout and stderr. + server_index = client_config.get("client_index", 0) if client_config else 0 + setproctitle.setproctitle(f"APIServer::{server_index}") -@app.get("/v1/models") -async def list_models() -> Response: - """ - List all available models. - """ - if app.state.dynamic_load_weight: - status, msg = app.state.engine_client.is_workers_alive() - if not status: - return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304) + uvloop.run(run_server_worker(listen_address, sock, args, client_config, **uvicorn_kwargs)) - models = await app.state.model_handler.list_models() - if isinstance(models, ErrorResponse): - return JSONResponse(content=models.model_dump()) - elif isinstance(models, ModelList): - return JSONResponse(content=models.model_dump()) +async def run_server_worker(listen_address, sock, args, client_config=None, **uvicorn_kwargs) -> None: + api_server_app = ApiServerApp(args) + # 异步启动HTTP服务 + shutdown_task = await serve_http( + app=api_server_app.build_app(), + sock=sock, + host=args.host, + port=args.port, + log_config=UVICORN_CONFIG, + log_level="info", + **uvicorn_kwargs, + ) + # NB: Await server shutdown only after the backend context is exited + try: + await shutdown_task + finally: + sock.close() + + +def run_multi_api_server(args: argparse.Namespace): + listen_address, sock = setup_server(args) # Construct common args for the APIServerProcessManager up-front. 
+ api_server_manager_kwargs = dict( + target_server_fn=run_api_server_worker_proc, + listen_address=listen_address, + sock=sock, + args=args, + num_servers=args.workers, + ) -@app.get("/update_model_weight") -def update_model_weight(request: Request) -> Response: - """ - update model weight - """ - if app.state.dynamic_load_weight: - status, msg = app.state.engine_client.update_model_weight() - if not status: - return Response(content=msg, status_code=404) - return Response(status_code=200) - else: - return Response(content="Dynamic Load Weight Disabled.", status_code=404) + api_server_manager = APIServerProcessManager(**api_server_manager_kwargs) + wait_for_completion_or_failure(api_server_manager) -@app.get("/clear_load_weight") -def clear_load_weight(request: Request) -> Response: - """ - clear model weight - """ - if app.state.dynamic_load_weight: - status, msg = app.state.engine_client.clear_load_weight() - if not status: - return Response(content=msg, status_code=404) - return Response(status_code=200) - else: - return Response(content="Dynamic Load Weight Disabled.", status_code=404) +def wait_for_completion_or_failure(api_server_manager: APIServerProcessManager) -> None: + """Wait for all processes to complete or detect if any fail. + Raises an exception if any process exits with a non-zero status. -def launch_api_server() -> None: - """ - 启动http服务 + Args: + api_server_manager: The manager for API servers. + engine_manager: The manager for engine processes. + If CoreEngineProcManager, it manages local engines; + if CoreEngineActorManager, it manages all engines. + coordinator: The coordinator for data parallel. """ - if not is_port_available(args.host, args.port): - raise Exception(f"The parameter `port`:{args.port} is already in use.") - - api_server_logger.info(f"launch Fastdeploy api server... 
port: {args.port}") - api_server_logger.info(f"args: {args.__dict__}") - fd_start_span("FD_START") try: - uvicorn.run( - app="fastdeploy.entrypoints.openai.api_server:app", - host=args.host, - port=args.port, - workers=args.workers, - log_config=UVICORN_CONFIG, - log_level="info", - timeout_graceful_shutdown=args.timeout_graceful_shutdown, - ) # set log level to error to avoid log + api_server_logger.info("Waiting for API servers to complete ...") + # Create a mapping of sentinels to their corresponding processes + # for efficient lookup + sentinel_to_proc: dict[Any, BaseProcess] = {proc.sentinel: proc for proc in api_server_manager.processes} + + # Check if any process terminates + while sentinel_to_proc: + # Wait for any process to terminate + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5) + + # Process any terminated processes + for sentinel in ready_sentinels: + proc = sentinel_to_proc.pop(sentinel) + + # Check if process exited with error + if proc.exitcode != 0: + raise RuntimeError( + f"Process {proc.name} (PID: {proc.pid}) " f"died with exit code {proc.exitcode}" + ) + + except KeyboardInterrupt: + api_server_logger.info("Received KeyboardInterrupt, shutting down API servers...") except Exception as e: - api_server_logger.error(f"launch sync http server error, {e}, {str(traceback.format_exc())}") - - -metrics_app = FastAPI() - - -@metrics_app.get("/metrics") -async def metrics(): - """ - metrics - """ - metrics_text = get_filtered_metrics( - EXCLUDE_LABELS, - extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=args.workers), - ) - return Response(metrics_text, media_type=CONTENT_TYPE_LATEST) - - -def run_metrics_server(): - """ - run metrics server - """ - - uvicorn.run(metrics_app, host="0.0.0.0", port=args.metrics_port, log_config=UVICORN_CONFIG, log_level="error") - - -def launch_metrics_server(): - """Metrics server running the sub thread""" - if not is_port_available(args.host, args.metrics_port): - raise Exception(f"The parameter `metrics_port`:{args.metrics_port} is already in use.") - - prom_dir = cleanup_prometheus_files(True) - os.environ["PROMETHEUS_MULTIPROC_DIR"] = prom_dir - metrics_server_thread = threading.Thread(target=run_metrics_server, daemon=True) - metrics_server_thread.start() - time.sleep(1) - - -controller_app = FastAPI() - - -@controller_app.post("/controller/reset_scheduler") -def reset_scheduler(): - """ - reset scheduler - """ - global llm_engine + api_server_logger.exception("Exception occurred while running API servers: %s", str(e)) + raise + finally: + api_server_logger.info("Terminating remaining processes ...") + api_server_manager.close() + + +class ApiServerApp(FastAPI): + + def __init__(self, args: argparse.Namespace): + self.args = args + + def build_app(self) -> FastAPI: + args = self.args + MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers + connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS) + chat_template = load_chat_template(args.chat_template, args.model) + + @asynccontextmanager + async def lifespan(app: FastAPI): + """ + async context manager for FastAPI lifespan + """ + + if args.tokenizer is None: + args.tokenizer = args.model + if current_process().name != "MainProcess": + pid = os.getppid() + else: + pid = os.getpid() + api_server_logger.info(f"{pid}") - if llm_engine is None: - return Response("Engine not loaded", status_code=500) - llm_engine.engine.scheduler.reset() - return Response("Scheduler Reset Successfully", 
status_code=200) + if args.served_model_name is not None: + served_model_names = args.served_model_name + verification = True + else: + served_model_names = args.model + verification = False + model_paths = [ModelPath(name=served_model_names, model_path=args.model, verification=verification)] + + engine_client = EngineClient( + model_name_or_path=args.model, + tokenizer=args.tokenizer, + max_model_len=args.max_model_len, + tensor_parallel_size=args.tensor_parallel_size, + pid=pid, + port=int(args.engine_worker_queue_port[args.local_data_parallel_id]), + limit_mm_per_prompt=args.limit_mm_per_prompt, + mm_processor_kwargs=args.mm_processor_kwargs, + # args.enable_mm, + reasoning_parser=args.reasoning_parser, + data_parallel_size=args.data_parallel_size, + enable_logprob=args.enable_logprob, + workers=args.workers, + tool_parser=args.tool_call_parser, + ) + await engine_client.connection_manager.initialize() + app.state.dynamic_load_weight = args.dynamic_load_weight + model_handler = OpenAIServingModels( + model_paths, + args.max_model_len, + args.ips, + ) + app.state.model_handler = model_handler + chat_handler = OpenAIServingChat( + engine_client, + app.state.model_handler, + pid, + args.ips, + args.max_waiting_time, + chat_template, + args.enable_mm_output, + args.tokenizer_base_url, + ) + completion_handler = OpenAIServingCompletion( + engine_client, + app.state.model_handler, + pid, + args.ips, + args.max_waiting_time, + ) + engine_client.create_zmq_client(model=pid, mode=zmq.PUSH) + engine_client.pid = pid + app.state.engine_client = engine_client + app.state.chat_handler = chat_handler + app.state.completion_handler = completion_handler + global llm_engine + if llm_engine is not None: + llm_engine.engine.data_processor = engine_client.data_processor + yield + # close zmq + try: + await engine_client.connection_manager.close() + engine_client.zmq_client.close() + from prometheus_client import multiprocess + + multiprocess.mark_process_dead(os.getpid()) + api_server_logger.info(f"Closing metrics client pid: {pid}") + except Exception as e: + api_server_logger.warning(f"exit error: {e}, {str(traceback.format_exc())}") + + app = FastAPI(lifespan=lifespan) + instrument(app) + app.add_exception_handler(RequestValidationError, ExceptionHandler.handle_request_validation_exception) + app.add_exception_handler(Exception, ExceptionHandler.handle_exception) + + @asynccontextmanager + async def connection_manager(): + """ + async context manager for connection manager + """ + try: + await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001) + yield + except asyncio.TimeoutError: + api_server_logger.info( + f"Reach max request concurrency, semaphore status: {connection_semaphore.status()}" + ) + raise HTTPException( + status_code=429, detail=f"Too many requests,current max concurrency is {args.max_concurrency}" + ) + + # TODO 传递真实引擎值 通过pid 获取状态 + @app.get("/health") + def health(request: Request) -> Response: + """Health check.""" + + status, msg = app.state.engine_client.check_health() + if not status: + return Response(content=msg, status_code=404) + status, msg = app.state.engine_client.is_workers_alive() + if not status: + return Response(content=msg, status_code=304) + return Response(status_code=200) + + @app.get("/load") + async def list_all_routes(): + """ + 列出所有以/v1开头的路由信息 + + Args: + 无参数 + + Returns: + dict: 包含所有符合条件的路由信息的字典,格式如下: + { + "routes": [ + { + "path": str, # 路由路径 + "methods": list, # 支持的HTTP方法列表,已排序 + "tags": list # 路由标签列表,默认为空列表 + }, + ... 
+ ] + } + + """ + routes_info = [] + + for route in app.routes: + # 直接检查路径是否以/v1开头 + if route.path.startswith("/v1"): + methods = sorted(route.methods) + tags = getattr(route, "tags", []) or [] + routes_info.append({"path": route.path, "methods": methods, "tags": tags}) + return {"routes": routes_info} + + @app.api_route("/ping", methods=["GET", "POST"]) + def ping(raw_request: Request) -> Response: + """Ping check. Endpoint required for SageMaker""" + return health(raw_request) + + def wrap_streaming_generator(original_generator: AsyncGenerator): + """ + Wrap an async generator to release the connection semaphore when the generator is finished. + """ + + async def wrapped_generator(): + try: + async for chunk in original_generator: + yield chunk + finally: + api_server_logger.debug(f"release: {connection_semaphore.status()}") + connection_semaphore.release() + + return wrapped_generator + + @app.post("/v1/chat/completions") + async def create_chat_completion(request: ChatCompletionRequest): + """ + Create a chat completion for the provided prompt and parameters. + """ + api_server_logger.info(f"Chat Received request: {request.model_dump_json()}") + if app.state.dynamic_load_weight: + status, msg = app.state.engine_client.is_workers_alive() + if not status: + return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304) + try: + async with connection_manager(): + inject_to_metadata(request) + generator = await app.state.chat_handler.create_chat_completion(request) + if isinstance(generator, ErrorResponse): + api_server_logger.debug(f"release: {connection_semaphore.status()}") + connection_semaphore.release() + return JSONResponse(content=generator.model_dump(), status_code=500) + elif isinstance(generator, ChatCompletionResponse): + api_server_logger.debug(f"release: {connection_semaphore.status()}") + connection_semaphore.release() + return JSONResponse(content=generator.model_dump()) + else: + wrapped_generator = wrap_streaming_generator(generator) + return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream") + + except HTTPException as e: + api_server_logger.error(f"Error in chat completion: {str(e)}") + + @app.post("/v1/completions") + async def create_completion(request: CompletionRequest): + """ + Create a completion for the provided prompt and parameters. + """ + api_server_logger.info(f"Completion Received request: {request.model_dump_json()}") + if app.state.dynamic_load_weight: + status, msg = app.state.engine_client.is_workers_alive() + if not status: + return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304) + try: + async with connection_manager(): + generator = await app.state.completion_handler.create_completion(request) + if isinstance(generator, ErrorResponse): + connection_semaphore.release() + return JSONResponse(content=generator.model_dump(), status_code=500) + elif isinstance(generator, CompletionResponse): + connection_semaphore.release() + return JSONResponse(content=generator.model_dump()) + else: + wrapped_generator = wrap_streaming_generator(generator) + return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream") + except HTTPException as e: + return JSONResponse(status_code=e.status_code, content={"detail": e.detail}) + + @app.get("/v1/models") + async def list_models() -> Response: + """ + List all available models. 
+ """ + if app.state.dynamic_load_weight: + status, msg = app.state.engine_client.is_workers_alive() + if not status: + return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304) + + models = await app.state.model_handler.list_models() + if isinstance(models, ErrorResponse): + return JSONResponse(content=models.model_dump()) + elif isinstance(models, ModelList): + return JSONResponse(content=models.model_dump()) + + @app.get("/update_model_weight") + def update_model_weight(request: Request) -> Response: + """ + update model weight + """ + if app.state.dynamic_load_weight: + status, msg = app.state.engine_client.update_model_weight() + if not status: + return Response(content=msg, status_code=404) + return Response(status_code=200) + else: + return Response(content="Dynamic Load Weight Disabled.", status_code=404) + + @app.get("/clear_load_weight") + def clear_load_weight(request: Request) -> Response: + """ + clear model weight + """ + if app.state.dynamic_load_weight: + status, msg = app.state.engine_client.clear_load_weight() + if not status: + return Response(content=msg, status_code=404) + return Response(status_code=200) + else: + return Response(content="Dynamic Load Weight Disabled.", status_code=404) + return app -@controller_app.post("/controller/scheduler") -def control_scheduler(request: ControlSchedulerRequest): - """ - Control the scheduler behavior with the given parameters. - """ + def launch_api_server(self) -> None: + """ + 启动http服务 + """ + args = self.args + if not is_port_available(args.host, args.port): + raise Exception(f"The parameter `port`:{args.port} is already in use.") - content = ErrorResponse(error=ErrorInfo(message="Scheduler updated successfully", code=0)) + api_server_logger.info(f"launch Fastdeploy api server... port: {args.port}") + api_server_logger.info(f"args: {args.__dict__}") + fd_start_span("FD_START") - global llm_engine - if llm_engine is None: - content.message = "Engine is not loaded" - content.code = 500 - return JSONResponse(content=content.model_dump(), status_code=500) - - if request.reset: - llm_engine.engine.scheduler.reset() - - if request.load_shards_num or request.reallocate_shard: - if hasattr(llm_engine.engine.scheduler, "update_config") and callable( - llm_engine.engine.scheduler.update_config - ): - llm_engine.engine.scheduler.update_config( - load_shards_num=request.load_shards_num, - reallocate=request.reallocate_shard, + try: + if args.workers > 1: + run_multi_api_server(args) + else: + app = self.build_app() + uvicorn.run( + app=app, + host=args.host, + port=args.port, + workers=args.workers, + log_config=UVICORN_CONFIG, + log_level="info", + timeout_graceful_shutdown=args.timeout_graceful_shutdown, + ) # set log level to error to avoid log + except Exception as e: + api_server_logger.error(f"launch sync http server error, {e}, {str(traceback.format_exc())}") + print("fastdeploy api server stopped") + + +class MetricsServerApp(FastAPI): + + def __init__(self, args): + self.args = args + + def build_app(self) -> FastAPI: + metrics_app = FastAPI() + + @metrics_app.get("/metrics") + async def metrics(): + """ + metrics + """ + metrics_text = get_filtered_metrics( + EXCLUDE_LABELS, + extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=args.workers), ) - else: - content.message = "This scheduler doesn't support the `update_config()` method." 
- content.code = 400 - return JSONResponse(content=content.model_dump(), status_code=400) - - return JSONResponse(content=content.model_dump(), status_code=200) + return Response(metrics_text, media_type=CONTENT_TYPE_LATEST) + return metrics_app -def run_controller_server(): - """ - run controller server - """ - uvicorn.run( - controller_app, - host="0.0.0.0", - port=args.controller_port, - log_config=UVICORN_CONFIG, - log_level="error", - ) + def run_metrics_server(self): + """ + run metrics server + """ + metrics_app = self.build_app() + uvicorn.run( + metrics_app, host="0.0.0.0", port=self.args.metrics_port, log_config=UVICORN_CONFIG, log_level="error" + ) + def launch_metrics_server(self): + """Metrics server running the sub thread""" + args = self.args + if not is_port_available(args.host, args.metrics_port): + raise Exception(f"The parameter `metrics_port`:{args.metrics_port} is already in use.") + + prom_dir = cleanup_prometheus_files(True) + os.environ["PROMETHEUS_MULTIPROC_DIR"] = prom_dir + metrics_server_thread = threading.Thread(target=self.run_metrics_server, daemon=True) + metrics_server_thread.start() + time.sleep(1) + + +class ControllerServerApp(FastAPI): + + def __init__(self, args): + self.args = args + + def build_app(self) -> FastAPI: + controller_app = FastAPI() + + @controller_app.post("/controller/reset_scheduler") + def reset_scheduler(): + """ + reset scheduler + """ + global llm_engine + + if llm_engine is None: + return Response("Engine not loaded", status_code=500) + llm_engine.engine.scheduler.reset() + return Response("Scheduler Reset Successfully", status_code=200) + + @controller_app.post("/controller/scheduler") + def control_scheduler(request: ControlSchedulerRequest): + """ + Control the scheduler behavior with the given parameters. + """ + content = ErrorResponse(error=ErrorInfo(message="Scheduler updated successfully", code=0)) + + global llm_engine + if llm_engine is None: + content.message = "Engine is not loaded" + content.code = 500 + return JSONResponse(content=content.model_dump(), status_code=500) + + if request.reset: + llm_engine.engine.scheduler.reset() + + if request.load_shards_num or request.reallocate_shard: + if hasattr(llm_engine.engine.scheduler, "update_config") and callable( + llm_engine.engine.scheduler.update_config + ): + llm_engine.engine.scheduler.update_config( + load_shards_num=request.load_shards_num, + reallocate=request.reallocate_shard, + ) + else: + content.message = "This scheduler doesn't support the `update_config()` method." 
+ content.code = 400 + return JSONResponse(content=content.model_dump(), status_code=400) + + return JSONResponse(content=content.model_dump(), status_code=200) + + return controller_app + + def run_controller_server(self): + """ + run controller server + """ + app = self.build_app() + uvicorn.run( + app, + host="0.0.0.0", + port=self.args.controller_port, + log_config=UVICORN_CONFIG, + log_level="error", + ) -def launch_controller_server(): - """Controller server running the sub thread""" - if args.controller_port < 0: - return + def launch_controller_server(self): + """Controller server running the sub thread""" + args = self.args + if args.controller_port < 0: + return - if not is_port_available(args.host, args.controller_port): - raise Exception(f"The parameter `controller_port`:{args.controller_port} is already in use.") + if not is_port_available(args.host, args.controller_port): + raise Exception(f"The parameter `controller_port`:{args.controller_port} is already in use.") - controller_server_thread = threading.Thread(target=run_controller_server, daemon=True) - controller_server_thread.start() - time.sleep(1) + controller_server_thread = threading.Thread(target=self.run_controller_server, daemon=True) + controller_server_thread.start() + time.sleep(1) -def main(): +def main(args: argparse.Namespace): """main函数""" + args = rewrite_args(args) if args.local_data_parallel_id == 0: - if not load_engine(): + if not load_engine(args): return else: - if not load_data_service(): + if not load_data_service(args): return api_server_logger.info("FastDeploy LLM engine initialized!\n") console_logger.info(f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics") console_logger.info(f"Launching chat completion service at http://{args.host}:{args.port}/v1/chat/completions") console_logger.info(f"Launching completion service at http://{args.host}:{args.port}/v1/completions") - - launch_controller_server() - launch_metrics_server() - launch_api_server() + controller_server = ControllerServerApp(args) + controller_server.launch_controller_server() + metrics_server = MetricsServerApp(args) + metrics_server.launch_metrics_server() + api_server = ApiServerApp(args) + api_server.launch_api_server() if __name__ == "__main__": - main() + parser = FlexibleArgumentParser() + parser = make_arg_parser(parser) + args = parser.parse_args() + main(args) diff --git a/fastdeploy/plugins/utils.py b/fastdeploy/plugins/utils.py index e457223acf..572b1a1579 100644 --- a/fastdeploy/plugins/utils.py +++ b/fastdeploy/plugins/utils.py @@ -32,7 +32,7 @@ def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: discovered_plugins = entry_points(group=group) if len(discovered_plugins) == 0: - logger.info("No plugins for group %s found.", group) + logger.debug("No plugins for group %s found.", group) return {} logger.info("Available plugins for group %s:", group) diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 1a2dd0c79b..e8324ae1c6 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -17,12 +17,15 @@ import argparse import asyncio import codecs +import contextlib import importlib +import ipaddress import json import logging import os import random import re +import signal import socket import sys import tarfile @@ -37,6 +40,7 @@ import numpy as np import paddle +import psutil import requests import yaml from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download @@ -852,3 +856,64 @@ def get_logger(name, file_name=None, without_formater=False, 
print_to_console=Fa console_logger = get_logger("console", "console.log", print_to_console=True) spec_logger = get_logger("speculate", "speculate.log") zmq_client_logger = get_logger("zmq_client", "zmq_client.log") + + +class YamlInputAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + # 支持从文件路径或直接传递 YAML 字符串 + if os.path.exists(values): + with open(values, "r") as f: + yaml_content = f.read() + else: + yaml_content = values # 直接处理 YAML 字符串(需用户确保格式正确) + + try: + config = yaml.safe_load(yaml_content) + if not isinstance(config, dict): + raise ValueError("YAML 内容必须为字典格式") + + # 如果目标属性已存在(如通过其他参数设置),则合并或覆盖 + if hasattr(namespace, self.dest): + existing_config = getattr(namespace, self.dest) + if isinstance(existing_config, dict): + existing_config.update(config) # 合并字典 + setattr(namespace, self.dest, existing_config) + else: + setattr(namespace, self.dest, config) # 直接覆盖 + else: + setattr(namespace, self.dest, config) + except yaml.YAMLError as e: + raise argparse.ArgumentError(self, f"YAML 解析错误: {e}") + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + +def kill_process_tree(pid: int) -> None: + """ + Kills all descendant processes of the given pid by sending SIGKILL. + + Args: + pid (int): Process ID of the parent process + """ + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + return + + # Get all children recursively + children = parent.children(recursive=True) + + # Send SIGKILL to all children first + for child in children: + with contextlib.suppress(ProcessLookupError): + os.kill(child.pid, signal.SIGKILL) + + # Finally kill the parent + with contextlib.suppress(ProcessLookupError): + os.kill(pid, signal.SIGKILL) diff --git a/requirements.txt b/requirements.txt index ddad9d9b3b..a3561553aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -39,3 +39,5 @@ opentelemetry-distro  opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi partial_json_parser +setproctitle +uvloop diff --git a/requirements_dcu.txt b/requirements_dcu.txt index 79bac3a622..64fcc1d47c 100644 --- a/requirements_dcu.txt +++ b/requirements_dcu.txt @@ -36,3 +36,5 @@ opentelemetry-distro  opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi partial_json_parser +setproctitle +uvloop diff --git a/requirements_iluvatar.txt b/requirements_iluvatar.txt index d481e3febb..add6b664c0 100644 --- a/requirements_iluvatar.txt +++ b/requirements_iluvatar.txt @@ -37,3 +37,5 @@ opentelemetry-distro opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi partial_json_parser +setproctitle +uvloop diff --git a/requirements_metaxgpu.txt b/requirements_metaxgpu.txt index 7aa310fa23..04c5ead61f 100644 --- a/requirements_metaxgpu.txt +++ b/requirements_metaxgpu.txt @@ -38,3 +38,5 @@ opentelemetry-distro  opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi partial_json_parser +setproctitle +uvloop diff --git a/setup.py b/setup.py index 1e98789363..98ae480ab7 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,8 @@ from setuptools.command.install import install from wheel.bdist_wheel import bdist_wheel +from fastdeploy import __version__ + long_description = "FastDeploy: Large Language Model Serving.\n\n" long_description += "GitHub: https://github.com/PaddlePaddle/FastDeploy\n" long_description += "Email: dltp@baidu.com" @@ -185,7 +187,7 @@ def get_name(): cmdclass_dict = {"bdist_wheel": CustomBdistWheel} cmdclass_dict["build_ext"] = 
CMakeBuild -FASTDEPLOY_VERSION = os.environ.get("FASTDEPLOY_VERSION", "2.3.0-dev") +FASTDEPLOY_VERSION = os.environ.get("FASTDEPLOY_VERSION", __version__) cmdclass_dict["build_optl"] = PostInstallCommand setup( diff --git a/tests/entrypoints/cli/test_main.py b/tests/entrypoints/cli/test_main.py index dada7f624b..787d3d035d 100644 --- a/tests/entrypoints/cli/test_main.py +++ b/tests/entrypoints/cli/test_main.py @@ -6,10 +6,8 @@ class TestCliMain(unittest.TestCase): @patch("fastdeploy.utils.FlexibleArgumentParser") - @patch("fastdeploy.entrypoints.cli.main.importlib.metadata") - def test_main_basic(self, mock_metadata, mock_parser): + def test_main_basic(self, mock_parser): # Setup mocks - mock_metadata.version.return_value = "1.0.0" mock_args = MagicMock() mock_args.subparser = None mock_parser.return_value.parse_args.return_value = mock_args @@ -18,7 +16,6 @@ def test_main_basic(self, mock_metadata, mock_parser): cli_main() # Verify version check - mock_metadata.version.assert_called_once_with("fastdeploy") mock_args.dispatch_function.assert_called_once() diff --git a/tests/entrypoints/cli/test_serve.py b/tests/entrypoints/cli/test_serve.py new file mode 100755 index 0000000000..c391ba2219 --- /dev/null +++ b/tests/entrypoints/cli/test_serve.py @@ -0,0 +1,45 @@ +import argparse +import unittest +from unittest.mock import MagicMock, patch + +from fastdeploy.entrypoints.cli.serve import ServeSubcommand, cmd_init + + +class TestServeSubcommand(unittest.TestCase): + """Tests for ServeSubcommand class.""" + + def test_name_property(self): + """Test the name property is correctly set.""" + self.assertEqual(ServeSubcommand.name, "serve") + + @patch("fastdeploy.entrypoints.cli.serve.api_server") + def test_cmd_method(self, mock_api_server): + """Test the cmd method calls the expected API server functions.""" + test_args = argparse.Namespace() + ServeSubcommand.cmd(test_args) + mock_api_server.main.assert_called_once_with(test_args) + + def test_validate_method(self): + """Test the validate method does nothing (no-op).""" + test_args = argparse.Namespace() + instance = ServeSubcommand() + instance.validate(test_args) # Should not raise any exceptions + + @patch("argparse._SubParsersAction.add_parser") + def test_subparser_init(self, mock_add_parser): + """Test the subparser initialization.""" + mock_subparsers = MagicMock() + instance = ServeSubcommand() + result = instance.subparser_init(mock_subparsers) + self.assertIsNotNone(result) + + def test_cmd_init_returns_list(self): + """Test cmd_init returns a list of subcommands.""" + result = cmd_init() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], ServeSubcommand) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/entrypoints/openai/test_api_server.py b/tests/entrypoints/openai/test_api_server.py new file mode 100755 index 0000000000..518af7159d --- /dev/null +++ b/tests/entrypoints/openai/test_api_server.py @@ -0,0 +1,196 @@ +import os +import unittest +from unittest.mock import MagicMock, patch + +from fastapi.testclient import TestClient + +from fastdeploy.entrypoints.openai import api_server +from fastdeploy.entrypoints.openai.api_server import ( + ApiServerApp, + ControllerServerApp, + MetricsServerApp, +) + +# 直接从本地模块导入 +make_arg_parser = api_server.make_arg_parser +rewrite_args = api_server.rewrite_args +load_engine = api_server.load_engine +load_data_service = api_server.load_data_service + + +class TestApiServer(unittest.TestCase): + + def setUp(self): + 
self.test_args = MagicMock() + self.test_args.port = 123123 + self.test_args.host = "0.0.0.0" + self.test_args.workers = 1 + self.test_args.metrics_port = 12334 + self.test_args.controller_port = 12231 + self.test_args.max_waiting_time = -1 + self.test_args.max_concurrency = 512 + self.test_args.enable_mm_output = False + self.test_args.timeout_graceful_shutdown = 0 + self.test_args.model = "test_model" + self.test_args.revision = None + self.test_args.tokenizer = None + self.test_args.max_model_len = 2048 + self.test_args.tensor_parallel_size = 1 + self.test_args.engine_worker_queue_port = [8002] + self.test_args.local_data_parallel_id = 0 + self.test_args.limit_mm_per_prompt = None + self.test_args.mm_processor_kwargs = None + self.test_args.reasoning_parser = None + self.test_args.data_parallel_size = 1 + self.test_args.enable_logprob = False + self.test_args.tool_call_parser = None + self.test_args.dynamic_load_weight = False + self.test_args.served_model_name = None + self.test_args.hidden_size = 11 + self.test_args.num_attention_heads = 11 + self.test_args.chat_template = None + self.test_args.ips = "127.0.0.1" + self.test_args.tokenizer_base_url = None + + @patch("fastdeploy.entrypoints.openai.api_server.LLMEngine") + def test_load_engine(self, mock_engine): + mock_engine_instance = MagicMock() + mock_engine.from_engine_args.return_value = mock_engine_instance + mock_engine_instance.start.return_value = True + + with patch("fastdeploy.entrypoints.openai.api_server.llm_engine", None): + result = load_engine(self.test_args) + self.assertEqual(result, mock_engine_instance) + mock_engine.from_engine_args.assert_called_once() + mock_engine_instance.start.assert_called_once_with(api_server_pid=os.getpid()) + + @patch("fastdeploy.entrypoints.openai.api_server.ExpertService") + @patch("fastdeploy.engine.args_utils.EngineArgs.from_cli_args") + @patch("fastdeploy.engine.args_utils.EngineArgs.create_engine_config") + @patch("fastdeploy.entrypoints.openai.api_server.os.getpid") + @patch("fastdeploy.entrypoints.openai.api_server.api_server_logger.info") + def test_load_data_service(self, mock_logger, mock_getpid, mock_create_config, mock_from_cli, mock_service): + """测试 load_data_service 函数的完整行为""" + # Setup mocks + mock_getpid.return_value = 12345 + + # 创建详细的配置对象 + config = MagicMock() + config.parallel_config.local_data_parallel_id = 0 + config.hidden_size = 768 + config.num_attention_heads = 12 + config.worker_num_per_node = 1 + config.nnode = 1 + config.parallel_config.data_parallel_size = 1 + config.parallel_config.tensor_parallel_size = 1 + config.splitwise_role = "mixed" + config.scheduler_config = MagicMock(name="default") + config.cache_config = MagicMock(rdma_comm_ports=[], pd_comm_port=[8000]) + config.device_ids = "0" + config.engine_worker_queue_port = [8000] + config.host_ip = "127.0.0.1" + config.disaggregate_info = None + config.print = MagicMock() + + engine_args = MagicMock() + engine_args.create_engine_config.return_value = config + mock_from_cli.return_value = engine_args + + mock_service_instance = MagicMock() + mock_service.return_value = mock_service_instance + mock_service_instance.start.return_value = True + + # 调用函数 + result = load_data_service(self.test_args) + # 验证点1: EngineArgs.from_cli_args 被正确调用 + mock_from_cli.assert_called_once_with(self.test_args) + # 验证点2: create_engine_config 被正确调用 + engine_args.create_engine_config.assert_called_once() + # 验证点3: ExpertService 被正确初始化 + mock_service.assert_called_once_with(config, 0) + # 验证点4: start 方法被正确调用 + 
mock_service_instance.start.assert_called_once_with(12345, 0) + # 验证点5: 函数返回预期的 ExpertService 实例 + self.assertEqual(result, mock_service_instance) + # 验证日志记录 + mock_logger.assert_called() + + def test_make_arg_parser(self): + parser = make_arg_parser(MagicMock()) + self.assertTrue(hasattr(parser, "add_argument")) + + def test_rewrite_args(self): + self.test_args.workers = None + self.test_args.max_num_seqs = 64 + self.test_args.tool_parser_plugin = None + self.test_args.model = "test_model" + with patch("fastdeploy.entrypoints.openai.api_server.retrive_model_from_server", return_value="test_model"): + result = rewrite_args(self.test_args) + self.assertEqual(result.workers, None) # 64 // 32 = 2 + + @patch("multiprocessing.get_context") + @patch("multiprocessing.connection.wait") + @patch("fastdeploy.utils.kill_process_tree") # 将mock提升到方法级别 + def test_run_multi_api_server(self, mock_kill_process_tree, mock_ready_sentinels, mock_spawn_context): + mocked_spawn_context = MagicMock() + mocked_process = MagicMock() + mocked_process.sentinel = "test_sentinel" + mocked_process.exitcode = 1 + mocked_process.name = "test_process" + mocked_process.is_alive.return_value = False + mocked_process.pid = 1 + mocked_spawn_context.Process.return_value = mocked_process + mock_spawn_context.return_value = mocked_spawn_context + mock_sentinels = ["test_sentinel"] + mock_ready_sentinels.return_value = mock_sentinels + with ( + patch("fastdeploy.entrypoints.openai.api_server.set_ulimit"), + patch("fastdeploy.entrypoints.openai.api_server.create_server_socket"), + patch("fastdeploy.entrypoints.openai.api_server.signal.SIGKILL"), + ): + with self.assertRaises(RuntimeError): + api_server.run_multi_api_server(self.test_args) + + @patch("fastdeploy.entrypoints.openai.api_server.LLMEngine") + @patch("uvicorn.run") + @patch("socket.socket") + @patch("fastdeploy.entrypoints.openai.api_server.retrive_model_from_server", return_value="test_model") + def test_api_server_app(self, mock_retrieve_model, mock_socket, mock_run, mock_engine): + mock_engine_instance = MagicMock() + mock_engine.from_engine_args.return_value = mock_engine_instance + mock_engine_instance.start.return_value = True + + app = ApiServerApp(self.test_args) + TestClient(app.build_app()) + app.launch_api_server() + mock_run.assert_called_once() + + @patch("fastdeploy.entrypoints.openai.api_server.LLMEngine") + @patch("uvicorn.run") + @patch("socket.socket") + def test_metrics_server_app(self, mock_socket, mock_run, mock_engine): + app = MetricsServerApp(self.test_args) + TestClient(app.build_app()) + app.launch_metrics_server() + mock_run.assert_called_once() + + @patch("fastdeploy.entrypoints.openai.api_server.LLMEngine") + @patch("uvicorn.run") + @patch("socket.socket") + def test_controller_server_app(self, mock_socket, mock_run, mock_engine): + + app = ControllerServerApp(self.test_args) + TestClient(app.build_app()) + app.launch_controller_server() + mock_run.assert_called_once() + + @patch("fastdeploy.entrypoints.openai.api_server.LLMEngine") + @patch("uvicorn.run") + @patch("socket.socket") + def test_main(self, mock_socket, mock_run, mock_engine): + with patch("fastdeploy.entrypoints.openai.api_server.retrive_model_from_server", return_value="test_model"): + api_server.main(self.test_args) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/test_kill_process_tree.py b/tests/utils/test_kill_process_tree.py new file mode 100644 index 0000000000..01f14848f8 --- /dev/null +++ b/tests/utils/test_kill_process_tree.py @@ -0,0 +1,63 
@@ +import signal +import unittest +from unittest.mock import MagicMock, patch + +import psutil + +from fastdeploy.utils import kill_process_tree + + +class TestKillProcessTree(unittest.TestCase): + @patch("psutil.Process") + @patch("os.kill") + def test_kill_process_tree_success(self, mock_os_kill, mock_process): + # Setup mock process tree + parent_process = MagicMock() + child1 = MagicMock() + child1.pid = 1001 + child2 = MagicMock() + child2.pid = 1002 + parent_process.children.return_value = [child1, child2] + mock_process.return_value = parent_process + + # Call function + kill_process_tree(1234) + + # Verify + mock_process.assert_called_once_with(1234) + parent_process.children.assert_called_once_with(recursive=True) + self.assertEqual(mock_os_kill.call_count, 3) # 2 children + parent + mock_os_kill.assert_any_call(1001, signal.SIGKILL) + mock_os_kill.assert_any_call(1002, signal.SIGKILL) + mock_os_kill.assert_any_call(1234, signal.SIGKILL) + + @patch("psutil.Process") + def test_kill_process_tree_no_such_process(self, mock_process): + mock_process.side_effect = psutil.NoSuchProcess(1234) + + # Should not raise exception + kill_process_tree(1234) + + mock_process.assert_called_once_with(1234) + + @patch("psutil.Process") + @patch("os.kill") + def test_kill_process_tree_child_kill_failure(self, mock_os_kill, mock_process): + parent_process = MagicMock() + child = MagicMock() + child.pid = 1001 + parent_process.children.return_value = [child] + mock_process.return_value = parent_process + + # First child kill fails, parent kill succeeds + mock_os_kill.side_effect = [ProcessLookupError, None] + + # Should not raise exception + kill_process_tree(1234) + + mock_os_kill.assert_any_call(1001, signal.SIGKILL) + mock_os_kill.assert_any_call(1234, signal.SIGKILL) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/test_yaml_input_action.py b/tests/utils/test_yaml_input_action.py new file mode 100644 index 0000000000..1befd68aed --- /dev/null +++ b/tests/utils/test_yaml_input_action.py @@ -0,0 +1,65 @@ +import argparse +import os +import tempfile +import unittest +from unittest.mock import MagicMock + +from fastdeploy.utils import YamlInputAction + + +class TestYamlInputAction(unittest.TestCase): + def setUp(self): + self.parser = MagicMock(spec=argparse.ArgumentParser) + self.namespace = argparse.Namespace() + self.action = YamlInputAction(option_strings=[], dest="config") + + def test_call_with_yaml_file(self): + # Create a temporary YAML file + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml_content = """ + key1: value1 + key2: value2 + """ + f.write(yaml_content) + f.close() + + # Test + self.action(self.parser, self.namespace, f.name) + + # Verify + self.assertEqual(self.namespace.config, {"key1": "value1", "key2": "value2"}) + + # Clean up + os.unlink(f.name) + + def test_call_with_yaml_string(self): + yaml_str = """ + key1: value1 + key2: value2 + """ + + self.action(self.parser, self.namespace, yaml_str) + self.assertEqual(self.namespace.config, {"key1": "value1", "key2": "value2"}) + + def test_call_with_invalid_yaml(self): + with self.assertRaises(ValueError): + self.action(self.parser, self.namespace, "invalid") + + def test_call_with_non_dict_yaml(self): + with self.assertRaises(ValueError): + self.action(self.parser, self.namespace, "- item1\n- item2") + + def test_call_with_existing_config(self): + # Set existing config + self.namespace.config = {"existing": "value"} + + yaml_str = """ + new_key: new_value + """ + + 
+        self.action(self.parser, self.namespace, yaml_str)
+        self.assertEqual(self.namespace.config, {"existing": "value", "new_key": "new_value"})
+
+
+if __name__ == "__main__":
+    unittest.main()
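
For reference, a minimal sketch (not part of the patch) of how the refactored entrypoint is driven once this change lands: `fastdeploy serve` dispatches to `api_server.main(args)` through `ServeSubcommand.cmd()`, and the standalone module does the same in its `__main__` block. The model path and worker count below are illustrative placeholders, not values taken from the patch.

    # Sketch only: mirrors the __main__ block added to api_server.py.
    from fastdeploy.entrypoints.openai.api_server import main, make_arg_parser
    from fastdeploy.utils import FlexibleArgumentParser

    parser = make_arg_parser(FlexibleArgumentParser())
    # "--model" is contributed by EngineArgs.add_cli_args(); "my-model" is a placeholder path.
    args = parser.parse_args(["--model", "my-model", "--workers", "2"])
    main(args)  # `fastdeploy serve` reaches this same function via ServeSubcommand.cmd()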