diff --git a/Cargo.lock b/Cargo.lock index c13d79adeb..3d8423ccd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2391,6 +2391,7 @@ dependencies = [ "figment", "futures", "humantime", + "inotify", "jsonschema", "local-ip-address", "log", @@ -3992,6 +3993,28 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" +[[package]] +name = "inotify" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3" +dependencies = [ + "bitflags 2.9.4", + "futures-core", + "inotify-sys", + "libc", + "tokio", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "insta" version = "1.43.2" diff --git a/components/src/dynamo/frontend/main.py b/components/src/dynamo/frontend/main.py index eaff60aa50..f64587ee8c 100644 --- a/components/src/dynamo/frontend/main.py +++ b/components/src/dynamo/frontend/main.py @@ -225,6 +225,12 @@ def parse_args(): ), help=f"Interval in seconds for polling custom backend metrics. Set to > 0 to enable polling (default: 0=disabled, suggested: 9.2s which is less than typical Prometheus scrape interval). Can be set via {CUSTOM_BACKEND_METRICS_POLLING_INTERVAL_ENV_VAR} env var.", ) + parser.add_argument( + "--store-kv", + type=str, + default=os.environ.get("DYN_STORE_KV", "etcd"), + help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", + ) flags = parser.parse_args() @@ -252,8 +258,7 @@ async def async_main(): os.environ["DYN_METRICS_PREFIX"] = flags.metrics_prefix loop = asyncio.get_running_loop() - - runtime = DistributedRuntime(loop, is_static) + runtime = DistributedRuntime(loop, flags.store_kv, is_static) def signal_handler(): asyncio.create_task(graceful_shutdown(runtime)) diff --git a/components/src/dynamo/mocker/args.py b/components/src/dynamo/mocker/args.py index 8f3631458e..7a10ba02d9 100644 --- a/components/src/dynamo/mocker/args.py +++ b/components/src/dynamo/mocker/args.py @@ -204,6 +204,12 @@ def parse_args(): default=False, help="Mark this as a decode worker which does not publish KV events and skips prefill cost estimation (default: False)", ) + parser.add_argument( + "--store-kv", + type=str, + default=os.environ.get("DYN_STORE_KV", "etcd"), + help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", + ) args = parser.parse_args() validate_worker_type_args(args) diff --git a/components/src/dynamo/mocker/main.py b/components/src/dynamo/mocker/main.py index 220f2995ac..6dbd3d9fc3 100644 --- a/components/src/dynamo/mocker/main.py +++ b/components/src/dynamo/mocker/main.py @@ -72,7 +72,7 @@ async def launch_workers(args, extra_engine_args_path): logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}") # Create a separate DistributedRuntime for this worker (on same event loop) - runtime = DistributedRuntime(loop, False) + runtime = DistributedRuntime(loop, args.store_kv, False) runtimes.append(runtime) # Create EntrypointArgs for this worker diff --git a/components/src/dynamo/sglang/args.py b/components/src/dynamo/sglang/args.py index 2417f0fc28..a08cb025de 100644 --- a/components/src/dynamo/sglang/args.py +++ b/components/src/dynamo/sglang/args.py @@ -93,6 +93,12 @@ "default": None, "help": "Dump debug config to the specified file path. If not specified, the config will be dumped to stdout at INFO level.", }, + "store-kv": { + "flags": ["--store-kv"], + "type": str, + "default": os.environ.get("DYN_STORE_KV", "etcd"), + "help": "Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", + }, } @@ -102,6 +108,7 @@ class DynamoArgs: component: str endpoint: str migration_limit: int + store_kv: str # tool and reasoning parser options tool_call_parser: Optional[str] = None @@ -329,6 +336,7 @@ async def parse_args(args: list[str]) -> Config: component=parsed_component_name, endpoint=parsed_endpoint_name, migration_limit=parsed_args.migration_limit, + store_kv=parsed_args.store_kv, tool_call_parser=tool_call_parser, reasoning_parser=reasoning_parser, custom_jinja_template=expanded_template_path, diff --git a/components/src/dynamo/sglang/main.py b/components/src/dynamo/sglang/main.py index 1dc20099a8..2d5e92bf49 100644 --- a/components/src/dynamo/sglang/main.py +++ b/components/src/dynamo/sglang/main.py @@ -11,7 +11,7 @@ from dynamo.common.config_dump import dump_config from dynamo.llm import ModelInput, ModelType -from dynamo.runtime import DistributedRuntime, dynamo_worker +from dynamo.runtime import DistributedRuntime from dynamo.runtime.logging import configure_dynamo_logging from dynamo.sglang.args import Config, DisaggregationMode, parse_args from dynamo.sglang.health_check import ( @@ -33,9 +33,12 @@ configure_dynamo_logging() -@dynamo_worker(static=False) -async def worker(runtime: DistributedRuntime): +async def worker(): + config = await parse_args(sys.argv[1:]) + dump_config(config.dynamo_args.dump_config_to, config) + loop = asyncio.get_running_loop() + runtime = DistributedRuntime(loop, config.dynamo_args.store_kv, False) def signal_handler(): asyncio.create_task(graceful_shutdown(runtime)) @@ -45,9 +48,6 @@ def signal_handler(): logging.info("Signal handlers will trigger a graceful shutdown of the runtime") - config = await parse_args(sys.argv[1:]) - dump_config(config.dynamo_args.dump_config_to, config) - if config.dynamo_args.embedding_worker: await init_embedding(runtime, config) elif config.dynamo_args.multimodal_processor: diff --git a/components/src/dynamo/trtllm/main.py b/components/src/dynamo/trtllm/main.py index 270c8ce58a..55d7723bc1 100644 --- a/components/src/dynamo/trtllm/main.py +++ b/components/src/dynamo/trtllm/main.py @@ -39,7 +39,7 @@ from dynamo.common.config_dump import dump_config from dynamo.common.utils.prometheus import register_engine_metrics_callback from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm -from dynamo.runtime import DistributedRuntime, dynamo_worker +from dynamo.runtime import DistributedRuntime from dynamo.runtime.logging import configure_dynamo_logging from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine from dynamo.trtllm.health_check import TrtllmHealthCheckPayload @@ -102,11 +102,13 @@ async def get_engine_runtime_config( return runtime_config -@dynamo_worker(static=False) -async def worker(runtime: DistributedRuntime): - # Set up signal handler for graceful shutdown +async def worker(): + config = cmd_line_args() + loop = asyncio.get_running_loop() + runtime = DistributedRuntime(loop, config.store_kv, False) + # Set up signal handler for graceful shutdown def signal_handler(): # Schedule the shutdown coroutine instead of calling it directly asyncio.create_task(graceful_shutdown(runtime)) @@ -116,7 +118,6 @@ def signal_handler(): logging.info("Signal handlers set up for graceful shutdown") - config = cmd_line_args() await init(runtime, config) diff --git a/components/src/dynamo/trtllm/utils/trtllm_utils.py b/components/src/dynamo/trtllm/utils/trtllm_utils.py index b7ec219f02..3a3da53d1f 100644 --- a/components/src/dynamo/trtllm/utils/trtllm_utils.py +++ b/components/src/dynamo/trtllm/utils/trtllm_utils.py @@ -58,6 +58,7 @@ def __init__(self) -> None: self.tool_call_parser: Optional[str] = None self.dump_config_to: Optional[str] = None self.custom_jinja_template: Optional[str] = None + self.store_kv: str = "" def __str__(self) -> str: return ( @@ -87,8 +88,9 @@ def __str__(self) -> str: f"max_file_size_mb={self.max_file_size_mb}, " f"reasoning_parser={self.reasoning_parser}, " f"tool_call_parser={self.tool_call_parser}, " - f"dump_config_to={self.dump_config_to}," - f"custom_jinja_template={self.custom_jinja_template}" + f"dump_config_to={self.dump_config_to}, " + f"custom_jinja_template={self.custom_jinja_template}, " + f"store_kv={self.store_kv}" ) @@ -278,6 +280,12 @@ def cmd_line_args(): default=None, help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.", ) + parser.add_argument( + "--store-kv", + type=str, + default=os.environ.get("DYN_STORE_KV", "etcd"), + help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", + ) args = parser.parse_args() @@ -337,6 +345,7 @@ def cmd_line_args(): config.reasoning_parser = args.dyn_reasoning_parser config.tool_call_parser = args.dyn_tool_call_parser config.dump_config_to = args.dump_config_to + config.store_kv = args.store_kv # Handle custom jinja template path expansion (environment variables and home directory) if args.custom_jinja_template: diff --git a/components/src/dynamo/vllm/args.py b/components/src/dynamo/vllm/args.py index dc7e73ed88..ace113a32f 100644 --- a/components/src/dynamo/vllm/args.py +++ b/components/src/dynamo/vllm/args.py @@ -38,6 +38,7 @@ class Config: migration_limit: int = 0 kv_port: Optional[int] = None custom_jinja_template: Optional[str] = None + store_kv: str # mirror vLLM model: str @@ -164,6 +165,12 @@ def parse_args() -> Config: "'USER: please describe the image ASSISTANT:'." ), ) + parser.add_argument( + "--store-kv", + type=str, + default=os.environ.get("DYN_STORE_KV", "etcd"), + help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.", + ) add_config_dump_args(parser) parser = AsyncEngineArgs.add_cli_args(parser) @@ -233,6 +240,7 @@ def parse_args() -> Config: config.multimodal_worker = args.multimodal_worker config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker config.mm_prompt_template = args.mm_prompt_template + config.store_kv = args.store_kv # Validate custom Jinja template file exists if provided if config.custom_jinja_template is not None: diff --git a/components/src/dynamo/vllm/main.py b/components/src/dynamo/vllm/main.py index 060390d01b..5f4eb4ddc8 100644 --- a/components/src/dynamo/vllm/main.py +++ b/components/src/dynamo/vllm/main.py @@ -25,7 +25,7 @@ fetch_llm, register_llm, ) -from dynamo.runtime import DistributedRuntime, dynamo_worker +from dynamo.runtime import DistributedRuntime from dynamo.runtime.logging import configure_dynamo_logging from dynamo.vllm.multimodal_handlers import ( EncodeWorkerHandler, @@ -70,16 +70,16 @@ async def graceful_shutdown(runtime): logging.info("DistributedRuntime shutdown complete") -@dynamo_worker(static=False) -async def worker(runtime: DistributedRuntime): +async def worker(): config = parse_args() + loop = asyncio.get_running_loop() + runtime = DistributedRuntime(loop, config.store_kv, False) + await configure_ports(config) overwrite_args(config) # Set up signal handler for graceful shutdown - loop = asyncio.get_running_loop() - def signal_handler(): asyncio.create_task(graceful_shutdown(runtime)) diff --git a/examples/custom_backend/cancellation/client.py b/examples/custom_backend/cancellation/client.py index fbcbead315..91d982f296 100644 --- a/examples/custom_backend/cancellation/client.py +++ b/examples/custom_backend/cancellation/client.py @@ -50,7 +50,7 @@ async def main(): return loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, True) + runtime = DistributedRuntime(loop, "mem", True) # Connect to middle server or direct server based on argument if use_middle_server: diff --git a/examples/custom_backend/cancellation/middle_server.py b/examples/custom_backend/cancellation/middle_server.py index 968cee014b..e200cffe9c 100644 --- a/examples/custom_backend/cancellation/middle_server.py +++ b/examples/custom_backend/cancellation/middle_server.py @@ -50,7 +50,7 @@ async def generate(self, request, context): async def main(): """Start the middle server""" loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, True) + runtime = DistributedRuntime(loop, "mem", True) # Create middle server handler handler = MiddleServer(runtime) diff --git a/examples/custom_backend/cancellation/server.py b/examples/custom_backend/cancellation/server.py index 63a1c70938..dfc9938507 100644 --- a/examples/custom_backend/cancellation/server.py +++ b/examples/custom_backend/cancellation/server.py @@ -31,7 +31,7 @@ async def generate(self, request, context): async def main(): """Start the demo server""" loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, True) + runtime = DistributedRuntime(loop, "mem", True) # Create server component component = runtime.namespace("demo").component("server") diff --git a/examples/custom_backend/nim/mock_nim_frontend.py b/examples/custom_backend/nim/mock_nim_frontend.py index c79b5c03f7..defebe7980 100755 --- a/examples/custom_backend/nim/mock_nim_frontend.py +++ b/examples/custom_backend/nim/mock_nim_frontend.py @@ -123,7 +123,7 @@ async def async_main(): # Create DistributedRuntime - similar to frontend/main.py line 246 is_static = True # Use static mode (no etcd) - runtime = DistributedRuntime(loop, is_static) # type: ignore[call-arg] + runtime = DistributedRuntime(loop, "mem", is_static) # type: ignore[call-arg] # Setup signal handlers for graceful shutdown def signal_handler(): diff --git a/launch/dynamo-run/src/flags.rs b/launch/dynamo-run/src/flags.rs index 66603cfdd6..09de1ad6cb 100644 --- a/launch/dynamo-run/src/flags.rs +++ b/launch/dynamo-run/src/flags.rs @@ -127,6 +127,12 @@ pub struct Flags { #[arg(long, default_value = "false")] pub static_worker: bool, + /// Which key-value backend to use: etcd, mem, file. + /// Etcd uses the ETCD_* env vars (e.g. ETCD_ENPOINTS) for connection details. + /// File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv. + #[arg(long, default_value = "etcd")] + pub store_kv: String, + /// Everything after a `--`. /// These are the command line arguments to the python engine when using `pystr` or `pytok`. #[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)] diff --git a/launch/dynamo-run/src/lib.rs b/launch/dynamo-run/src/lib.rs index d0a63c0792..a82a80eb00 100644 --- a/launch/dynamo-run/src/lib.rs +++ b/launch/dynamo-run/src/lib.rs @@ -6,10 +6,11 @@ use dynamo_llm::entrypoint::EngineConfig; use dynamo_llm::entrypoint::input::Input; use dynamo_llm::local_model::{LocalModel, LocalModelBuilder}; use dynamo_runtime::distributed::DistributedConfig; +use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect; +use dynamo_runtime::transports::nats; use dynamo_runtime::{DistributedRuntime, Runtime}; mod flags; -use either::Either; pub use flags::Flags; mod opt; pub use dynamo_llm::request_template::RequestTemplate; @@ -73,14 +74,16 @@ pub async fn run( // TODO: old, address this later: // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint. // If not, then the endpoint isn't exposed so we let LocalModel invent one. - let mut rt = Either::Left(runtime.clone()); if let Input::Endpoint(path) = &in_opt { builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?)); - - let dst_config = DistributedConfig::from_settings(flags.static_worker); - let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?; - rt = Either::Right(distributed_runtime); + } + let selected_store: KeyValueStoreSelect = flags.store_kv.parse()?; + let dst_config = DistributedConfig { + store_backend: selected_store, + nats_config: nats::ClientOptions::default(), + is_static: flags.static_worker, }; + let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?; if let Some(Output::Static(path)) = &out_opt { builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?)); } @@ -98,10 +101,16 @@ pub async fn run( flags.validate(&in_opt, &out_opt)?; // Make an engine from the local_model, flags and output. - let engine_config = engine_for(out_opt, flags.clone(), local_model, rt.clone()).await?; + let engine_config = engine_for( + out_opt, + flags.clone(), + local_model, + distributed_runtime.clone(), + ) + .await?; // Run it from an input - dynamo_llm::entrypoint::input::run_input(rt, in_opt, engine_config).await?; + dynamo_llm::entrypoint::input::run_input(distributed_runtime, in_opt, engine_config).await?; Ok(()) } @@ -112,7 +121,7 @@ async fn engine_for( out_opt: Output, flags: Flags, local_model: LocalModel, - rt: Either, + drt: DistributedRuntime, ) -> anyhow::Result { match out_opt { Output::Auto => { @@ -135,10 +144,6 @@ async fn engine_for( is_static: flags.static_worker, }), Output::Mocker => { - let Either::Right(drt) = rt else { - panic!("Mocker requires a distributed runtime to run."); - }; - let args = flags.mocker_config(); let endpoint = local_model.endpoint_id().clone(); diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index c9468ec900..3878f6e2a6 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -1606,6 +1606,7 @@ dependencies = [ "figment", "futures", "humantime", + "inotify", "local-ip-address", "log", "nid", @@ -2857,6 +2858,28 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" +[[package]] +name = "inotify" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3" +dependencies = [ + "bitflags 2.9.3", + "futures-core", + "inotify-sys", + "libc", + "tokio", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "instant" version = "0.1.13" diff --git a/lib/bindings/python/examples/cli/cli.py b/lib/bindings/python/examples/cli/cli.py index 6d2c6ded78..f1722a172d 100644 --- a/lib/bindings/python/examples/cli/cli.py +++ b/lib/bindings/python/examples/cli/cli.py @@ -115,7 +115,7 @@ def parse_args(): async def run(): loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, False) + runtime = DistributedRuntime(loop, "etcd", False) args = parse_args() diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index fc905b45d1..f1faf54670 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -2,6 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use dynamo_llm::local_model::LocalModel; +use dynamo_runtime::distributed::DistributedConfig; +use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect; use futures::StreamExt; use once_cell::sync::OnceCell; use pyo3::IntoPyObjectExt; @@ -426,7 +428,9 @@ enum ModelInput { #[pymethods] impl DistributedRuntime { #[new] - fn new(event_loop: PyObject, is_static: bool) -> PyResult { + fn new(event_loop: PyObject, store_kv: String, is_static: bool) -> PyResult { + let selected_kv_store: KeyValueStoreSelect = store_kv.parse().map_err(to_pyerr)?; + // Try to get existing runtime first, create new Worker only if needed // This allows multiple DistributedRuntime instances to share the same tokio runtime let runtime = rs::Worker::runtime_from_existing() @@ -464,9 +468,14 @@ impl DistributedRuntime { rs::DistributedRuntime::from_settings_without_discovery(runtime), ) } else { + let config = DistributedConfig { + store_backend: selected_kv_store, + is_static: false, + nats_config: dynamo_runtime::transports::nats::ClientOptions::default(), + }; runtime .secondary() - .block_on(rs::DistributedRuntime::from_settings(runtime)) + .block_on(rs::DistributedRuntime::new(runtime, config)) }; let inner = inner.map_err(to_pyerr)?; @@ -628,7 +637,7 @@ impl DistributedRuntime { } fn shutdown(&self) { - self.inner.runtime().shutdown(); + self.inner.shutdown(); } fn event_loop(&self) -> PyObject { diff --git a/lib/bindings/python/rust/llm/entrypoint.rs b/lib/bindings/python/rust/llm/entrypoint.rs index ba28abca21..6980fbe622 100644 --- a/lib/bindings/python/rust/llm/entrypoint.rs +++ b/lib/bindings/python/rust/llm/entrypoint.rs @@ -299,7 +299,7 @@ pub fn run_input<'p>( let input_enum: Input = input.parse().map_err(to_pyerr)?; pyo3_async_runtimes::tokio::future_into_py(py, async move { dynamo_llm::entrypoint::input::run_input( - either::Either::Right(distributed_runtime.inner.clone()), + distributed_runtime.inner.clone(), input_enum, engine_config.inner, ) diff --git a/lib/bindings/python/src/dynamo/runtime/__init__.py b/lib/bindings/python/src/dynamo/runtime/__init__.py index 7e9195c304..c46205d72a 100644 --- a/lib/bindings/python/src/dynamo/runtime/__init__.py +++ b/lib/bindings/python/src/dynamo/runtime/__init__.py @@ -25,7 +25,7 @@ def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, static) + runtime = DistributedRuntime(loop, "etcd", static) await func(runtime, *args, **kwargs) diff --git a/lib/bindings/python/tests/cancellation/test_cancellation.py b/lib/bindings/python/tests/cancellation/test_cancellation.py index 15e0619711..1050688453 100644 --- a/lib/bindings/python/tests/cancellation/test_cancellation.py +++ b/lib/bindings/python/tests/cancellation/test_cancellation.py @@ -256,7 +256,7 @@ async def test_server_context_cancel(server, client): except ValueError as e: # Verify the expected cancellation exception is received # TODO: Should this be a asyncio.CancelledError? - assert str(e) == "Stream ended before generation completed" + assert str(e).startswith("Stream ended before generation completed") # Verify server context cancellation status assert handler.context_is_stopped diff --git a/lib/bindings/python/tests/cancellation/test_example.py b/lib/bindings/python/tests/cancellation/test_example.py index 0a253d5928..2363e8a5f0 100644 --- a/lib/bindings/python/tests/cancellation/test_example.py +++ b/lib/bindings/python/tests/cancellation/test_example.py @@ -82,20 +82,17 @@ def run_client(example_dir, use_middle=False): ) # Wait for client to complete - stdout, _ = client_proc.communicate(timeout=1) - - if client_proc.returncode != 0: - pytest.fail( - f"Client failed with return code {client_proc.returncode}. Output: {stdout}" - ) + stdout, _ = client_proc.communicate(timeout=2) + print(f"Client stdout: {stdout}") return stdout -def stop_process(process): +def stop_process(name, process): """Stop a running process and capture its output""" process.terminate() stdout, _ = process.communicate(timeout=1) + print(f"{name}: {stdout}") return stdout @@ -109,7 +106,7 @@ async def test_direct_connection_cancellation(example_dir, server_process): await asyncio.sleep(1) # Capture server output - server_output = stop_process(server_process) + server_output = stop_process("server_process", server_process) # Assert expected messages assert ( @@ -132,8 +129,8 @@ async def test_middle_server_cancellation( await asyncio.sleep(1) # Capture output from all processes - server_output = stop_process(server_process) - middle_output = stop_process(middle_server_process) + server_output = stop_process("server_process", server_process) + middle_output = stop_process("middle_server_process", middle_server_process) # Assert expected messages assert ( diff --git a/lib/bindings/python/tests/conftest.py b/lib/bindings/python/tests/conftest.py index f34abbc79f..35a2835d03 100644 --- a/lib/bindings/python/tests/conftest.py +++ b/lib/bindings/python/tests/conftest.py @@ -153,7 +153,7 @@ def start_nats_and_etcd_default_ports(): print(f"Using ETCD on default client port {etcd_client_port}") # Start services with default ports - nats_server = subprocess.Popen(["nats-server", "-js"]) + nats_server = subprocess.Popen(["nats-server", "-js", "--trace"]) etcd = subprocess.Popen(["etcd"]) return nats_server, etcd, nats_port, etcd_client_port, nats_data_dir, etcd_data_dir @@ -181,6 +181,8 @@ def start_nats_and_etcd_random_ports(): etcd = subprocess.Popen( [ "etcd", + "--logger", + "zap", "--data-dir", str(etcd_data_dir), "--listen-client-urls", @@ -221,7 +223,11 @@ def start_nats_and_etcd_random_ports(): msg = log.get("msg", "") # Look for the client port - if "serving client traffic" in msg or "serving client" in msg: + if ( + "serving client traffic" in msg + or "serving client" in msg + or "serving insecure client" in msg + ): address = log.get("address", "") match = re.search(r":(\d+)$", address) if match: @@ -430,6 +436,6 @@ async def test_my_test(runtime): ) loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, True) + runtime = DistributedRuntime(loop, "mem", True) yield runtime runtime.shutdown() diff --git a/lib/bindings/python/tests/test_kv_bindings.py b/lib/bindings/python/tests/test_kv_bindings.py index f2a5256477..7dc95deb91 100644 --- a/lib/bindings/python/tests/test_kv_bindings.py +++ b/lib/bindings/python/tests/test_kv_bindings.py @@ -34,7 +34,7 @@ async def distributed_runtime(): Each test gets its own runtime in a forked process to avoid singleton conflicts. """ loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, False) + runtime = DistributedRuntime(loop, "etcd", False) yield runtime runtime.shutdown() diff --git a/lib/llm/src/audit/sink.rs b/lib/llm/src/audit/sink.rs index 0d4f4088bb..1f0628d7fb 100644 --- a/lib/llm/src/audit/sink.rs +++ b/lib/llm/src/audit/sink.rs @@ -89,8 +89,8 @@ fn parse_sinks_from_env( } /// spawn one worker per sink; each subscribes to the bus (off hot path) -pub fn spawn_workers_from_env(drt: Option<&dynamo_runtime::DistributedRuntime>) { - let nats_client = drt.and_then(|d| d.nats_client()); +pub fn spawn_workers_from_env(drt: &dynamo_runtime::DistributedRuntime) { + let nats_client = drt.nats_client(); let sinks = parse_sinks_from_env(nats_client); for sink in sinks { let name = sink.name(); diff --git a/lib/llm/src/discovery/watcher.rs b/lib/llm/src/discovery/watcher.rs index 00412422ad..2a3df5e961 100644 --- a/lib/llm/src/discovery/watcher.rs +++ b/lib/llm/src/discovery/watcher.rs @@ -183,8 +183,8 @@ impl ModelWatcher { } } } - WatchEvent::Delete(kv) => { - let deleted_key = kv.key_str(); + WatchEvent::Delete(key) => { + let deleted_key = key.as_ref(); match self .handle_delete(deleted_key, target_namespace, global_namespace) .await diff --git a/lib/llm/src/entrypoint/input.rs b/lib/llm/src/entrypoint/input.rs index 02967019af..09c6dfb781 100644 --- a/lib/llm/src/entrypoint/input.rs +++ b/lib/llm/src/entrypoint/input.rs @@ -23,7 +23,6 @@ pub mod http; pub mod text; use dynamo_runtime::protocols::ENDPOINT_SCHEME; -use either::Either; const BATCH_PREFIX: &str = "batch:"; @@ -107,15 +106,10 @@ impl Default for Input { /// For Input::Endpoint pass a DistributedRuntime. For everything else pass either a Runtime or a /// DistributedRuntime. pub async fn run_input( - rt: Either, + drt: dynamo_runtime::DistributedRuntime, in_opt: Input, engine_config: super::EngineConfig, ) -> anyhow::Result<()> { - let runtime = match &rt { - Either::Left(rt) => rt.clone(), - Either::Right(drt) => drt.runtime().clone(), - }; - // Initialize audit bus + sink workers (off hot path; fan-out supported) if crate::audit::config::policy().enabled { let cap: usize = std::env::var("DYN_AUDIT_CAPACITY") @@ -123,38 +117,30 @@ pub async fn run_input( .and_then(|v| v.parse().ok()) .unwrap_or(1024); crate::audit::bus::init(cap); - // Pass DistributedRuntime if available for shared NATS client - let drt_ref = match &rt { - Either::Right(drt) => Some(drt), - Either::Left(_) => None, - }; - crate::audit::sink::spawn_workers_from_env(drt_ref); - tracing::info!("Audit initialized: bus cap={}", cap); + crate::audit::sink::spawn_workers_from_env(&drt); + tracing::info!(cap, "Audit initialized"); } match in_opt { Input::Http => { - http::run(runtime, engine_config).await?; + http::run(drt, engine_config).await?; } Input::Grpc => { - grpc::run(runtime, engine_config).await?; + grpc::run(drt, engine_config).await?; } Input::Text => { - text::run(runtime, None, engine_config).await?; + text::run(drt, None, engine_config).await?; } Input::Stdin => { let mut prompt = String::new(); std::io::stdin().read_to_string(&mut prompt).unwrap(); - text::run(runtime, Some(prompt), engine_config).await?; + text::run(drt, Some(prompt), engine_config).await?; } Input::Batch(path) => { - batch::run(runtime, path, engine_config).await?; + batch::run(drt, path, engine_config).await?; } Input::Endpoint(path) => { - let Either::Right(distributed_runtime) = rt else { - anyhow::bail!("Input::Endpoint requires passing a DistributedRuntime"); - }; - endpoint::run(distributed_runtime, path, engine_config).await?; + endpoint::run(drt, path, engine_config).await?; } } Ok(()) diff --git a/lib/llm/src/entrypoint/input/batch.rs b/lib/llm/src/entrypoint/input/batch.rs index 0498f08889..f379676133 100644 --- a/lib/llm/src/entrypoint/input/batch.rs +++ b/lib/llm/src/entrypoint/input/batch.rs @@ -8,7 +8,7 @@ use crate::types::openai::chat_completions::{ }; use anyhow::Context as _; use dynamo_async_openai::types::FinishReason; -use dynamo_runtime::{Runtime, pipeline::Context, runtime::CancellationToken}; +use dynamo_runtime::{DistributedRuntime, pipeline::Context, runtime::CancellationToken}; use futures::StreamExt; use serde::{Deserialize, Serialize}; use std::cmp; @@ -51,11 +51,11 @@ struct Entry { } pub async fn run( - runtime: Runtime, + distributed_runtime: DistributedRuntime, input_jsonl: PathBuf, engine_config: EngineConfig, ) -> anyhow::Result<()> { - let cancel_token = runtime.primary_token(); + let cancel_token = distributed_runtime.primary_token(); // Check if the path exists and is a directory if !input_jsonl.exists() || !input_jsonl.is_file() { anyhow::bail!( @@ -64,7 +64,7 @@ pub async fn run( ); } - let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?; + let mut prepared_engine = common::prepare_engine(distributed_runtime, engine_config).await?; let pre_processor = if prepared_engine.has_tokenizer() { Some(OpenAIPreprocessor::new( diff --git a/lib/llm/src/entrypoint/input/common.rs b/lib/llm/src/entrypoint/input/common.rs index df382b3b62..0e5f52da41 100644 --- a/lib/llm/src/entrypoint/input/common.rs +++ b/lib/llm/src/entrypoint/input/common.rs @@ -24,9 +24,8 @@ use crate::{ }; use dynamo_runtime::{ - DistributedRuntime, Runtime, + DistributedRuntime, component::Client, - distributed::DistributedConfig, engine::{AsyncEngineStream, Data}, pipeline::{ Context, ManyOut, Operator, PushRouter, RouterMode, SegmentSource, ServiceBackend, @@ -55,23 +54,25 @@ impl PreparedEngine { /// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine. pub async fn prepare_engine( - runtime: Runtime, + distributed_runtime: DistributedRuntime, engine_config: EngineConfig, ) -> anyhow::Result { match engine_config { EngineConfig::Dynamic(local_model) => { - let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; - let store = Arc::new(distributed_runtime.store().clone()); let model_manager = Arc::new(ModelManager::new()); let watch_obj = Arc::new(ModelWatcher::new( - distributed_runtime, + distributed_runtime.clone(), model_manager.clone(), dynamo_runtime::pipeline::RouterMode::RoundRobin, None, None, )); - let (_, receiver) = store.watch(model_card::ROOT_PATH, None, runtime.primary_token()); + let (_, receiver) = store.watch( + model_card::ROOT_PATH, + None, + distributed_runtime.primary_token(), + ); let inner_watch_obj = watch_obj.clone(); let _watcher_task = tokio::spawn(async move { inner_watch_obj.watch(receiver, None).await; @@ -98,9 +99,6 @@ pub async fn prepare_engine( let card = local_model.card(); let router_mode = local_model.router_config().router_mode; - let dst_config = DistributedConfig::from_settings(true); - let distributed_runtime = DistributedRuntime::new(runtime, dst_config).await?; - let endpoint_id = local_model.endpoint_id(); let component = distributed_runtime .namespace(&endpoint_id.namespace)? diff --git a/lib/llm/src/entrypoint/input/grpc.rs b/lib/llm/src/entrypoint/input/grpc.rs index 8693c4d1d1..e39799d830 100644 --- a/lib/llm/src/entrypoint/input/grpc.rs +++ b/lib/llm/src/entrypoint/input/grpc.rs @@ -16,18 +16,20 @@ use crate::{ completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, }, }; -use dynamo_runtime::{DistributedRuntime, Runtime, storage::key_value_store::KeyValueStoreManager}; -use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode}; +use dynamo_runtime::pipeline::RouterMode; +use dynamo_runtime::{DistributedRuntime, storage::key_value_store::KeyValueStoreManager}; /// Build and run an KServe gRPC service -pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> { +pub async fn run( + distributed_runtime: DistributedRuntime, + engine_config: EngineConfig, +) -> anyhow::Result<()> { let grpc_service_builder = kserve::KserveService::builder() .port(engine_config.local_model().http_port()) // [WIP] generalize port.. .with_request_template(engine_config.local_model().request_template()); let grpc_service = match engine_config { EngineConfig::Dynamic(_) => { - let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let store = Arc::new(distributed_runtime.store().clone()); let grpc_service = grpc_service_builder.build()?; let router_config = engine_config.local_model().router_config(); @@ -39,7 +41,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul Some(namespace.to_string()) }; run_watcher( - distributed_runtime, + distributed_runtime.clone(), grpc_service.state().manager_clone(), store, router_config.router_mode, @@ -55,8 +57,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul let checksum = card.mdcsum(); let router_mode = local_model.router_config().router_mode; - let dst_config = DistributedConfig::from_settings(true); // true means static - let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?; let grpc_service = grpc_service_builder.build()?; let manager = grpc_service.model_manager(); @@ -157,8 +157,10 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul grpc_service } }; - grpc_service.run(runtime.primary_token()).await?; - runtime.shutdown(); // Cancel primary token + grpc_service + .run(distributed_runtime.primary_token()) + .await?; + distributed_runtime.shutdown(); // Cancel primary token Ok(()) } diff --git a/lib/llm/src/entrypoint/input/http.rs b/lib/llm/src/entrypoint/input/http.rs index 88b4e3e979..e5566141f2 100644 --- a/lib/llm/src/entrypoint/input/http.rs +++ b/lib/llm/src/entrypoint/input/http.rs @@ -17,12 +17,15 @@ use crate::{ completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, }, }; +use dynamo_runtime::DistributedRuntime; +use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::storage::key_value_store::KeyValueStoreManager; -use dynamo_runtime::{DistributedRuntime, Runtime}; -use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode}; /// Build and run an HTTP service -pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> { +pub async fn run( + distributed_runtime: DistributedRuntime, + engine_config: EngineConfig, +) -> anyhow::Result<()> { let local_model = engine_config.local_model(); let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) { (Some(tls_cert_path), Some(tls_key_path)) => { @@ -63,7 +66,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul let http_service = match engine_config { EngineConfig::Dynamic(_) => { - let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; // This allows the /health endpoint to query store for active instances http_service_builder = http_service_builder.store(distributed_runtime.store().clone()); let http_service = http_service_builder.build()?; @@ -80,7 +82,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul Some(namespace.to_string()) }; run_watcher( - distributed_runtime, + distributed_runtime.clone(), http_service.state().manager_clone(), store, router_config.router_mode, @@ -96,11 +98,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul EngineConfig::StaticRemote(local_model) => { let card = local_model.card(); let checksum = card.mdcsum(); - let router_mode = local_model.router_config().router_mode; - - let dst_config = DistributedConfig::from_settings(true); // true means static - let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?; let http_service = http_service_builder.build()?; let manager = http_service.model_manager(); @@ -233,8 +231,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul http_service.custom_backend_metrics_polling_interval, http_service.custom_backend_registry.as_ref(), ) { - // Create DistributedRuntime for polling, matching the engine's mode - let drt = DistributedRuntime::from_settings(runtime.clone()).await?; tracing::info!( namespace_component_endpoint=%namespace_component_endpoint, polling_interval_secs=polling_interval, @@ -246,7 +242,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul // shutdown phase. Some( crate::http::service::custom_backend_metrics::spawn_custom_backend_polling_task( - drt, + distributed_runtime.clone(), namespace_component_endpoint.clone(), polling_interval, registry.clone(), @@ -256,14 +252,16 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul None }; - http_service.run(runtime.primary_token()).await?; + http_service + .run(distributed_runtime.primary_token()) + .await?; // Abort the polling task if it was started if let Some(task) = polling_task { task.abort(); } - runtime.shutdown(); // Cancel primary token + distributed_runtime.shutdown(); // Cancel primary token Ok(()) } diff --git a/lib/llm/src/entrypoint/input/text.rs b/lib/llm/src/entrypoint/input/text.rs index 9659650a67..aced80e119 100644 --- a/lib/llm/src/entrypoint/input/text.rs +++ b/lib/llm/src/entrypoint/input/text.rs @@ -5,7 +5,8 @@ use crate::request_template::RequestTemplate; use crate::types::openai::chat_completions::{ NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine, }; -use dynamo_runtime::{Runtime, pipeline::Context, runtime::CancellationToken}; +use dynamo_runtime::DistributedRuntime; +use dynamo_runtime::pipeline::Context; use futures::StreamExt; use std::io::{ErrorKind, Write}; @@ -17,15 +18,15 @@ use crate::entrypoint::input::common; const MAX_TOKENS: u32 = 8192; pub async fn run( - runtime: Runtime, + distributed_runtime: DistributedRuntime, single_prompt: Option, engine_config: EngineConfig, ) -> anyhow::Result<()> { - let cancel_token = runtime.primary_token(); - let prepared_engine = common::prepare_engine(runtime, engine_config).await?; + let prepared_engine = + common::prepare_engine(distributed_runtime.clone(), engine_config).await?; // TODO: Pass prepared_engine directly main_loop( - cancel_token, + distributed_runtime, &prepared_engine.service_name, prepared_engine.engine, single_prompt, @@ -36,13 +37,14 @@ pub async fn run( } async fn main_loop( - cancel_token: CancellationToken, + distributed_runtime: DistributedRuntime, service_name: &str, engine: OpenAIChatCompletionsStreamingEngine, mut initial_prompt: Option, _inspect_template: bool, template: Option, ) -> anyhow::Result<()> { + let cancel_token = distributed_runtime.primary_token(); if initial_prompt.is_none() { tracing::info!("Ctrl-c to exit"); } @@ -179,7 +181,11 @@ async fn main_loop( break; } } - cancel_token.cancel(); // stop everything else println!(); + + // Stop the runtime and wait for it to stop + distributed_runtime.shutdown(); + cancel_token.cancelled().await; + Ok(()) } diff --git a/lib/llm/tests/audit_nats_integration.rs b/lib/llm/tests/audit_nats_integration.rs index f860adaf29..1819e6bcc1 100644 --- a/lib/llm/tests/audit_nats_integration.rs +++ b/lib/llm/tests/audit_nats_integration.rs @@ -167,7 +167,7 @@ mod tests { bus::init(100); let drt = create_test_drt().await; - sink::spawn_workers_from_env(Some(&drt)); + sink::spawn_workers_from_env(&drt); time::sleep(Duration::from_millis(100)).await; // Emit audit record @@ -224,7 +224,7 @@ mod tests { bus::init(100); let drt = create_test_drt().await; - sink::spawn_workers_from_env(Some(&drt)); + sink::spawn_workers_from_env(&drt); time::sleep(Duration::from_millis(100)).await; // Request with store=true (should be audited) diff --git a/lib/runtime/Cargo.toml b/lib/runtime/Cargo.toml index cd774ba16e..343ee01f34 100644 --- a/lib/runtime/Cargo.toml +++ b/lib/runtime/Cargo.toml @@ -63,6 +63,7 @@ bincode = { version = "1" } console-subscriber = { version = "0.4", optional = true } educe = { version = "0.6.0" } figment = { version = "0.10.19", features = ["env", "json", "toml", "test"] } +inotify = { version = "0.11" } local-ip-address = { version = "0.6.3" } log = { version = "0.4" } nid = { version = "3.0.0", features = ["serde"] } diff --git a/lib/runtime/examples/Cargo.lock b/lib/runtime/examples/Cargo.lock index d7074c7e28..12dc68568c 100644 --- a/lib/runtime/examples/Cargo.lock +++ b/lib/runtime/examples/Cargo.lock @@ -679,6 +679,7 @@ dependencies = [ "figment", "futures", "humantime", + "inotify", "local-ip-address", "log", "nid", @@ -1354,6 +1355,28 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" +[[package]] +name = "inotify" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3" +dependencies = [ + "bitflags 2.9.0", + "futures-core", + "inotify-sys", + "libc", + "tokio", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "iovec" version = "0.1.4" diff --git a/lib/runtime/src/component/client.rs b/lib/runtime/src/component/client.rs index 411194de09..8d888b7b21 100644 --- a/lib/runtime/src/component/client.rs +++ b/lib/runtime/src/component/client.rs @@ -1,19 +1,19 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::pipeline::{ - AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode, - SingleIn, +use crate::{ + pipeline::{ + AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode, + SingleIn, + }, + storage::key_value_store::{KeyValueStoreManager, WatchEvent}, }; use arc_swap::ArcSwap; use std::collections::HashMap; use std::sync::Arc; use tokio::net::unix::pipe::Receiver; -use crate::{ - pipeline::async_trait, - transports::etcd::{Client as EtcdClient, WatchEvent}, -}; +use crate::{pipeline::async_trait, transports::etcd::Client as EtcdClient}; use super::*; @@ -70,12 +70,7 @@ impl Client { const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1); // create live endpoint watcher - let Some(etcd_client) = &endpoint.component.drt.etcd_client else { - anyhow::bail!("Attempt to create a dynamic client on a static endpoint"); - }; - - let instance_source = - Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?; + let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?; let client = Client { endpoint, @@ -194,7 +189,6 @@ impl Client { } async fn get_or_create_dynamic_instance_source( - etcd_client: &EtcdClient, endpoint: &Endpoint, ) -> Result> { let drt = endpoint.drt(); @@ -209,12 +203,10 @@ impl Client { } } - let prefix_watcher = etcd_client - .kv_get_and_watch_prefix(endpoint.etcd_root()) - .await?; - - let (prefix, mut kv_event_rx) = prefix_watcher.dissolve(); - + let prefix = endpoint.etcd_root(); + let store = Arc::new(drt.store().clone()); + let (_, mut kv_event_rx) = + store.watch(super::INSTANCE_ROOT_PATH, None, drt.primary_token()); let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]); let secondary = endpoint.component.drt.runtime.secondary().clone(); @@ -223,7 +215,7 @@ impl Client { // currently this is created once per client, but this object/task should only be instantiated // once per worker/instance secondary.spawn(async move { - tracing::debug!("Starting endpoint watcher for prefix: {}", prefix); + tracing::debug!("Starting endpoint watcher for prefix: {prefix}"); let mut map = HashMap::new(); loop { @@ -245,23 +237,40 @@ impl Client { match kv_event { WatchEvent::Put(kv) => { - let key = String::from_utf8(kv.key().to_vec()); - let val = serde_json::from_slice::(kv.value()); - if let (Ok(key), Ok(val)) = (key, val) { - map.insert(key.clone(), val); - } else { - tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {prefix}"); - break; + let key = kv.key_str(); + if !key.starts_with(&prefix) { + continue; } - } - WatchEvent::Delete(kv) => { - match String::from_utf8(kv.key().to_vec()) { - Ok(key) => { map.remove(&key); } - Err(_) => { - tracing::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix); + let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else { + tracing::error!("WatchEvent::Put Key not in INSTANCE_ROOT_PATH. Should be impossible."); + continue; + }; + if key.starts_with("/") { + key = &key[1..]; + } + + match serde_json::from_slice::(kv.value()) { + Ok(val) => map.insert(key.to_string(), val), + Err(err) => { + tracing::error!(error = %err, prefix, + "Unable to parse put endpoint event; shutting down endpoint watcher"); break; } + }; + } + WatchEvent::Delete(key) => { + let key = key.as_ref(); + if !key.starts_with(&prefix) { + continue; + } + let Some(mut key) = key.strip_prefix(super::INSTANCE_ROOT_PATH) else { + tracing::error!("WatchEvent::Delete Key not in INSTANCE_ROOT_PATH. Should be impossible."); + continue; + }; + if key.starts_with("/") { + key = &key[1..]; } + map.remove(key); } } diff --git a/lib/runtime/src/component/endpoint.rs b/lib/runtime/src/component/endpoint.rs index baeb46683f..395e54f9c3 100644 --- a/lib/runtime/src/component/endpoint.rs +++ b/lib/runtime/src/component/endpoint.rs @@ -4,6 +4,8 @@ use derive_getters::Dissolve; use tokio_util::sync::CancellationToken; +use crate::storage::key_value_store; + use super::*; pub use async_nats::service::endpoint::Stats as EndpointStats; @@ -118,8 +120,6 @@ impl EndpointConfigBuilder { let endpoint_name = endpoint.name.clone(); let system_health = endpoint.drt().system_health.clone(); let subject = endpoint.subject_to(connection_id); - let etcd_path = endpoint.etcd_path_with_lease_id(connection_id); - let etcd_client = endpoint.component.drt.etcd_client.clone(); // Register health check target in SystemHealth if provided if let Some(health_check_payload) = &health_check_payload { @@ -193,9 +193,6 @@ impl EndpointConfigBuilder { result }); - // make the components service endpoint discovery in etcd - - // client.register_service() let info = Instance { component: component_name.clone(), endpoint: endpoint_name.clone(), @@ -206,15 +203,16 @@ impl EndpointConfigBuilder { let info = serde_json::to_vec_pretty(&info)?; - if let Some(etcd_client) = &etcd_client - && let Err(e) = etcd_client - .kv_create(&etcd_path, info, Some(connection_id)) - .await - { + let store = endpoint.drt().store(); + let instances_bucket = store + .get_or_create_bucket(super::INSTANCE_ROOT_PATH, None) + .await?; + let key = key_value_store::Key::from_raw(endpoint.unique_path(connection_id)); + if let Err(err) = instances_bucket.insert(&key, info.into(), 0).await { tracing::error!( component_name, endpoint_name, - error = %e, + error = %err, "Unable to register service for discovery" ); endpoint_shutdown_token.cancel(); diff --git a/lib/runtime/src/distributed.rs b/lib/runtime/src/distributed.rs index fd2846a0b9..86f587a96a 100644 --- a/lib/runtime/src/distributed.rs +++ b/lib/runtime/src/distributed.rs @@ -3,7 +3,8 @@ pub use crate::component::Component; use crate::storage::key_value_store::{ - EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, MemoryStore, + EtcdStore, KeyValueStore, KeyValueStoreEnum, KeyValueStoreManager, KeyValueStoreSelect, + MemoryStore, }; use crate::transports::nats::DRTNatsClientPrometheusMetrics; use crate::{ @@ -48,23 +49,22 @@ impl std::fmt::Debug for DistributedRuntime { impl DistributedRuntime { pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result { - let (etcd_config, nats_config, is_static) = config.dissolve(); + let (selected_kv_store, nats_config, is_static) = config.dissolve(); let runtime_clone = runtime.clone(); - // TODO: Here is where we will later select the KeyValueStore impl - let (etcd_client, store) = if is_static { - (None, KeyValueStoreManager::memory()) - } else { - match etcd::Client::new(etcd_config.clone(), runtime_clone).await { - Ok(etcd_client) => { - let store = KeyValueStoreManager::etcd(etcd_client.clone()); - (Some(etcd_client), store) - } - Err(err) => { - tracing::info!(%err, "Did not connect to etcd. Using memory storage."); - (None, KeyValueStoreManager::memory()) - } + let (etcd_client, store) = match (is_static, selected_kv_store) { + (false, KeyValueStoreSelect::Etcd(etcd_config)) => { + let etcd_client = etcd::Client::new(*etcd_config, runtime_clone).await.inspect_err(|err| + // The returned error doesn't show because of a dropped runtime error, so + // log it first. + tracing::error!(%err, "Could not connect to etcd. Pass `--store-kv ..` to use a different backend or start etcd."))?; + let store = KeyValueStoreManager::etcd(etcd_client.clone()); + (Some(etcd_client), store) + } + (false, KeyValueStoreSelect::File(root)) => (None, KeyValueStoreManager::file(root)), + (true, _) | (false, KeyValueStoreSelect::Memory) => { + (None, KeyValueStoreManager::memory()) } }; @@ -234,6 +234,7 @@ impl DistributedRuntime { pub fn shutdown(&self) { self.runtime.shutdown(); + self.store.shutdown(); } /// Create a [`Namespace`] @@ -302,7 +303,7 @@ impl DistributedRuntime { #[derive(Dissolve)] pub struct DistributedConfig { - pub etcd_config: etcd::ClientOptions, + pub store_backend: KeyValueStoreSelect, pub nats_config: nats::ClientOptions, pub is_static: bool, } @@ -310,22 +311,22 @@ pub struct DistributedConfig { impl DistributedConfig { pub fn from_settings(is_static: bool) -> DistributedConfig { DistributedConfig { - etcd_config: etcd::ClientOptions::default(), + store_backend: KeyValueStoreSelect::Etcd(Box::default()), nats_config: nats::ClientOptions::default(), is_static, } } pub fn for_cli() -> DistributedConfig { - let mut config = DistributedConfig { - etcd_config: etcd::ClientOptions::default(), + let etcd_config = etcd::ClientOptions { + attach_lease: false, + ..Default::default() + }; + DistributedConfig { + store_backend: KeyValueStoreSelect::Etcd(Box::new(etcd_config)), nats_config: nats::ClientOptions::default(), is_static: false, - }; - - config.etcd_config.attach_lease = false; - - config + } } } diff --git a/lib/runtime/src/storage/key_value_store.rs b/lib/runtime/src/storage/key_value_store.rs index 7fc122ec40..24dbb4eb39 100644 --- a/lib/runtime/src/storage/key_value_store.rs +++ b/lib/runtime/src/storage/key_value_store.rs @@ -4,14 +4,16 @@ //! Interface to a traditional key-value store such as etcd. //! "key_value_store" spelt out because in AI land "KV" means something else. -use std::collections::HashMap; -use std::fmt; use std::pin::Pin; +use std::str::FromStr; use std::sync::Arc; use std::time::Duration; +use std::{collections::HashMap, path::PathBuf}; +use std::{env, fmt}; use crate::CancellationToken; use crate::slug::Slug; +use crate::transports::etcd as etcd_transport; use async_trait::async_trait; use futures::StreamExt; use serde::{Deserialize, Serialize}; @@ -22,10 +24,15 @@ mod nats; pub use nats::NATSStore; mod etcd; pub use etcd::EtcdStore; +mod file; +pub use file::FileStore; const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100); /// A key that is safe to use directly in the KV store. +/// +/// TODO: Need to re-think this. etcd uses slash separators, so we often use from_raw +/// to avoid the slug. But other impl's, particularly file, need a real slug. #[derive(Debug, Clone, PartialEq)] pub struct Key(String); @@ -95,7 +102,7 @@ impl KeyValue { #[derive(Debug, Clone, PartialEq)] pub enum WatchEvent { Put(KeyValue), - Delete(KeyValue), + Delete(Key), } #[async_trait] @@ -112,6 +119,57 @@ pub trait KeyValueStore: Send + Sync { async fn get_bucket(&self, bucket_name: &str) -> Result, StoreError>; fn connection_id(&self) -> u64; + + fn shutdown(&self); +} + +#[derive(Clone, Debug, Default)] +pub enum KeyValueStoreSelect { + // Box it because it is significantly bigger than the other variants + Etcd(Box), + File(PathBuf), + #[default] + Memory, + // Nats not listed because likely we want to remove that impl. It is not currently used and not well tested. +} + +impl fmt::Display for KeyValueStoreSelect { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + KeyValueStoreSelect::Etcd(opts) => { + let urls = opts.etcd_url.join(","); + write!(f, "Etcd({urls})") + } + KeyValueStoreSelect::File(path) => write!(f, "File({})", path.display()), + KeyValueStoreSelect::Memory => write!(f, "Memory"), + } + } +} + +impl FromStr for KeyValueStoreSelect { + type Err = anyhow::Error; + + fn from_str(s: &str) -> anyhow::Result { + match s { + "etcd" => Ok(Self::Etcd(Box::default())), + "file" => { + let root = env::var("DYN_FILE_KV") + .map(PathBuf::from) + .unwrap_or_else(|_| env::temp_dir().join("dynamo_store_kv")); + Ok(Self::File(root)) + } + "mem" => Ok(Self::Memory), + x => anyhow::bail!("Unknown key-value store type '{x}'"), + } + } +} + +impl TryFrom for KeyValueStoreSelect { + type Error = anyhow::Error; + + fn try_from(s: String) -> anyhow::Result { + s.parse() + } } #[allow(clippy::large_enum_variant)] @@ -119,6 +177,7 @@ pub enum KeyValueStoreEnum { Memory(MemoryStore), Nats(NATSStore), Etcd(EtcdStore), + File(FileStore), } impl KeyValueStoreEnum { @@ -133,6 +192,7 @@ impl KeyValueStoreEnum { Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?), Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?), Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?), + File(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?), }) } @@ -154,6 +214,10 @@ impl KeyValueStoreEnum { .get_bucket(bucket_name) .await? .map(|b| Box::new(b) as Box), + File(x) => x + .get_bucket(bucket_name) + .await? + .map(|b| Box::new(b) as Box), }; Ok(maybe_bucket) } @@ -164,12 +228,23 @@ impl KeyValueStoreEnum { Memory(x) => x.connection_id(), Etcd(x) => x.connection_id(), Nats(x) => x.connection_id(), + File(x) => x.connection_id(), + } + } + + fn shutdown(&self) { + use KeyValueStoreEnum::*; + match self { + Memory(x) => x.shutdown(), + Etcd(x) => x.shutdown(), + Nats(x) => x.shutdown(), + File(x) => x.shutdown(), } } } #[derive(Clone)] -pub struct KeyValueStoreManager(Arc); +pub struct KeyValueStoreManager(pub Arc); impl Default for KeyValueStoreManager { fn default() -> Self { @@ -187,6 +262,10 @@ impl KeyValueStoreManager { Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client))) } + pub fn file>(root: P) -> Self { + Self::new(KeyValueStoreEnum::File(FileStore::new(root))) + } + fn new(s: KeyValueStoreEnum) -> KeyValueStoreManager { KeyValueStoreManager(Arc::new(s)) } @@ -302,6 +381,12 @@ impl KeyValueStoreManager { } Ok(outcome) } + + /// Cleanup any temporary state. + /// TODO: Should this be async? Take &mut self? + pub fn shutdown(&self) { + self.0.shutdown() + } } /// An online storage for key-value config values. @@ -366,6 +451,9 @@ pub enum StoreError { #[error("Internal etcd error: {0}")] EtcdError(String), + #[error("Internal filesystem error: {0}")] + FilesystemError(String), + #[error("Key Value Error: {0} for bucket '{1}'")] KeyValueError(String, String), diff --git a/lib/runtime/src/storage/key_value_store/etcd.rs b/lib/runtime/src/storage/key_value_store/etcd.rs index bd3934af96..5e8c6bf5db 100644 --- a/lib/runtime/src/storage/key_value_store/etcd.rs +++ b/lib/runtime/src/storage/key_value_store/etcd.rs @@ -54,6 +54,10 @@ impl KeyValueStore for EtcdStore { fn connection_id(&self) -> u64 { self.client.lease_id() } + + fn shutdown(&self) { + // Revoke the lease? etcd will do it for us on disconnect. + } } pub struct EtcdBucket { @@ -132,13 +136,13 @@ impl KeyValueBucket for EtcdBucket { continue; } }; - let item = KeyValue::new(key, v_bytes.into()); match e.event_type() { EventType::Put => { + let item = KeyValue::new(key, v_bytes.into()); yield WatchEvent::Put(item); } EventType::Delete => { - yield WatchEvent::Delete(item); + yield WatchEvent::Delete(Key::from_raw(key)); } } } diff --git a/lib/runtime/src/storage/key_value_store/file.rs b/lib/runtime/src/storage/key_value_store/file.rs new file mode 100644 index 0000000000..392d403a2a --- /dev/null +++ b/lib/runtime/src/storage/key_value_store/file.rs @@ -0,0 +1,307 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashSet; +use std::ffi::OsString; +use std::fmt; +use std::fs; +use std::os::unix::ffi::OsStrExt; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; +use std::{collections::HashMap, pin::Pin}; + +use anyhow::Context as _; +use async_trait::async_trait; +use futures::StreamExt; +use inotify::{Event, EventMask, EventStream, Inotify, WatchMask}; +use parking_lot::Mutex; + +use crate::storage::key_value_store::KeyValue; + +use super::{Key, KeyValueBucket, KeyValueStore, StoreError, StoreOutcome, WatchEvent}; + +/// Treat as a singleton +#[derive(Clone)] +pub struct FileStore { + root: PathBuf, + connection_id: u64, + /// Directories we may have created files in, for shutdown cleanup + /// Arc so that we only ever have one map here after clone + active_dirs: Arc>>, +} + +impl FileStore { + pub(super) fn new>(root_dir: P) -> Self { + FileStore { + root: root_dir.into(), + connection_id: rand::random::(), + active_dirs: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +#[async_trait] +impl KeyValueStore for FileStore { + type Bucket = Directory; + + /// A "bucket" is a directory + async fn get_or_create_bucket( + &self, + bucket_name: &str, + _ttl: Option, // TODO ttl not used yet + ) -> Result { + let p = self.root.join(bucket_name); + if let Some(dir) = self.active_dirs.lock().get(&p) { + return Ok(dir.clone()); + }; + + if p.exists() { + // Get + if !p.is_dir() { + return Err(StoreError::FilesystemError( + "Bucket name is not a directory".to_string(), + )); + } + } else { + // Create + fs::create_dir_all(&p).map_err(to_fs_err)?; + } + let dir = Directory::new(self.root.clone(), p.clone()); + self.active_dirs.lock().insert(p, dir.clone()); + Ok(dir) + } + + /// A "bucket" is a directory + async fn get_bucket(&self, bucket_name: &str) -> Result, StoreError> { + let p = self.root.join(bucket_name); + if let Some(dir) = self.active_dirs.lock().get(&p) { + return Ok(Some(dir.clone())); + }; + + if !p.exists() { + return Ok(None); + } + if !p.is_dir() { + return Err(StoreError::FilesystemError( + "Bucket name is not a directory".to_string(), + )); + } + let dir = Directory::new(self.root.clone(), p.clone()); + self.active_dirs.lock().insert(p, dir.clone()); + Ok(Some(dir)) + } + + fn connection_id(&self) -> u64 { + self.connection_id + } + + // This cannot be a Drop imp because DistributedRuntime is cloned various places including + // Python. Drop doesn't get called. + fn shutdown(&self) { + for (_, mut dir) in self.active_dirs.lock().drain() { + if let Err(err) = dir.delete_owned_files() { + tracing::error!(error = %err, %dir, "Failed shutdown delete of owned files"); + } + } + } +} + +#[derive(Clone)] +pub struct Directory { + root: PathBuf, + p: PathBuf, + /// These are the files we created and hence must delete on shutdown + owned_files: Arc>>, +} + +impl Directory { + fn new(root: PathBuf, p: PathBuf) -> Self { + Directory { + root, + p, + owned_files: Arc::new(Mutex::new(HashSet::new())), + } + } + + fn delete_owned_files(&mut self) -> anyhow::Result<()> { + let mut errs = Vec::new(); + for p in self.owned_files.lock().drain() { + if let Err(err) = fs::remove_file(&p) { + errs.push(format!("{}: {err}", p.display())); + } + } + if !errs.is_empty() { + anyhow::bail!(errs.join(", ")); + } + Ok(()) + } +} + +impl fmt::Display for Directory { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.p.display()) + } +} + +#[async_trait] +impl KeyValueBucket for Directory { + /// Write a file to the directory + async fn insert( + &self, + key: &Key, + value: bytes::Bytes, + _revision: u64, // Not used. Maybe put in file name? + ) -> Result { + let safe_key = Key::new(key.as_ref()); // because of from_raw + let full_path = self.p.join(safe_key.as_ref()); + self.owned_files.lock().insert(full_path.clone()); + let str_path = full_path.display().to_string(); + fs::write(&full_path, &value) + .context(str_path) + .map_err(a_to_fs_err)?; + Ok(StoreOutcome::Created(0)) + } + + /// Read a file from the directory + async fn get(&self, key: &Key) -> Result, StoreError> { + let safe_key = Key::new(key.as_ref()); // because of from_raw + let full_path = self.p.join(safe_key.as_ref()); + if !full_path.exists() { + return Ok(None); + } + let str_path = full_path.display().to_string(); + let data: bytes::Bytes = fs::read(&full_path) + .context(str_path) + .map_err(a_to_fs_err)? + .into(); + Ok(Some(data)) + } + + /// Delete a file from the directory + async fn delete(&self, key: &Key) -> Result<(), StoreError> { + let safe_key = Key::new(key.as_ref()); // because of from_raw + let full_path = self.p.join(safe_key.as_ref()); + let str_path = full_path.display().to_string(); + if !full_path.exists() { + return Err(StoreError::MissingKey(str_path)); + } + + self.owned_files.lock().remove(&full_path); + + fs::remove_file(&full_path) + .context(str_path) + .map_err(a_to_fs_err) + } + + async fn watch( + &self, + ) -> Result + Send + 'life0>>, StoreError> { + let inotify = Inotify::init().map_err(to_fs_err)?; + inotify + .watches() + .add( + &self.p, + WatchMask::MODIFY | WatchMask::CREATE | WatchMask::DELETE, + ) + .map_err(to_fs_err)?; + + let dir = self.p.clone(); + Ok(Box::pin(async_stream::stream! { + let mut buffer = [0; 1024]; + let mut events = match inotify.into_event_stream(&mut buffer) { + Ok(events) => events, + Err(err) => { + tracing::error!(error = %err, "Failed getting event stream from inotify"); + return; + } + }; + while let Some(Ok(event)) = events.next().await { + let Some(name) = event.name else { + tracing::warn!("Unexpected event on the directory itself"); + continue; + }; + let item_path = dir.join(name); + let key = match item_path.strip_prefix(&self.root) { + Ok(stripped) => stripped.display().to_string().replace("_", "/"), + Err(err) => { + // Possibly this should be a panic. + // A key cannot be outside the file store root. + tracing::error!( + error = %err, + item_path = %item_path.display(), + root = %self.root.display(), + "Item in file store is not prefixed with file store root. Should be impossible. Ignoring invalid key."); + continue; + } + }; + + match event.mask { + EventMask::MODIFY | EventMask::CREATE => { + let data: bytes::Bytes = match fs::read(&item_path) { + Ok(data) => data.into(), + Err(err) => { + tracing::warn!(error = %err, item = %item_path.display(), "Failed reading event item. Skipping."); + continue; + } + }; + let item = KeyValue::new(key, data); + yield WatchEvent::Put(item); + } + EventMask::DELETE => { + yield WatchEvent::Delete(Key::from_raw(key)); + } + event_type => { + tracing::warn!(?event_type, dir = %dir.display(), "Unexpected event type"); + continue; + } + } + } + })) + } + + async fn entries(&self) -> Result, StoreError> { + let contents = fs::read_dir(&self.p) + .with_context(|| self.p.display().to_string()) + .map_err(a_to_fs_err)?; + let mut out = HashMap::new(); + for entry in contents { + let entry = entry.map_err(to_fs_err)?; + if !entry.path().is_file() { + tracing::warn!( + path = %entry.path().display(), + "Unexpected entry, directory should only contain files." + ); + continue; + } + + let key = match entry.path().strip_prefix(&self.root) { + Ok(p) => p.to_string_lossy().to_string().replace("_", "/"), + Err(err) => { + tracing::error!( + error = %err, + path = %entry.path().display(), + root = %self.root.display(), + "FileStore path not in root. Should be impossible. Skipping entry." + ); + continue; + } + }; + let data: bytes::Bytes = fs::read(entry.path()) + .with_context(|| self.p.display().to_string()) + .map_err(a_to_fs_err)? + .into(); + out.insert(key, data); + } + Ok(out) + } +} + +// For anyhow preserve the context +fn a_to_fs_err(err: anyhow::Error) -> StoreError { + StoreError::FilesystemError(format!("{err:#}")) +} + +fn to_fs_err(err: E) -> StoreError { + StoreError::FilesystemError(err.to_string()) +} diff --git a/lib/runtime/src/storage/key_value_store/mem.rs b/lib/runtime/src/storage/key_value_store/mem.rs index 287a870693..a7a9037b28 100644 --- a/lib/runtime/src/storage/key_value_store/mem.rs +++ b/lib/runtime/src/storage/key_value_store/mem.rs @@ -57,7 +57,7 @@ impl MemoryBucket { } impl MemoryStore { - pub fn new() -> Self { + pub(super) fn new() -> Self { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); MemoryStore { inner: Arc::new(MemoryStoreInner { @@ -107,6 +107,8 @@ impl KeyValueStore for MemoryStore { fn connection_id(&self) -> u64 { self.connection_id } + + fn shutdown(&self) {} } #[async_trait] @@ -205,8 +207,7 @@ impl KeyValueBucket for MemoryBucketRef { yield WatchEvent::Put(item); }, Some(MemoryEvent::Delete { key }) => { - let item = KeyValue::new(key, bytes::Bytes::new()); - yield WatchEvent::Delete(item); + yield WatchEvent::Delete(Key::from_raw(key)); } } } diff --git a/lib/runtime/src/storage/key_value_store/nats.rs b/lib/runtime/src/storage/key_value_store/nats.rs index d30e779214..b6f5802efd 100644 --- a/lib/runtime/src/storage/key_value_store/nats.rs +++ b/lib/runtime/src/storage/key_value_store/nats.rs @@ -52,6 +52,11 @@ impl KeyValueStore for NATSStore { fn connection_id(&self) -> u64 { self.client.client().server_info().client_id } + + fn shutdown(&self) { + // TODO: Track and delete any owned keys + // The TTL should ensure NATS does it, but best we do it immediately + } } impl NATSStore { @@ -160,12 +165,14 @@ impl KeyValueBucket for NATSBucket { >| async move { match maybe_entry { Ok(entry) => { - let item = KeyValue::new(entry.key, entry.value); Some(match entry.operation { - Operation::Put => WatchEvent::Put(item), - Operation::Delete => WatchEvent::Delete(item), + Operation::Put => { + let item = KeyValue::new(entry.key, entry.value); + WatchEvent::Put(item) + } + Operation::Delete => WatchEvent::Delete(Key::from_raw(entry.key)), // TODO: What is Purge? Not urgent, NATS impl not used - Operation::Purge => WatchEvent::Delete(item), + Operation::Purge => WatchEvent::Delete(Key::from_raw(entry.key)), }) } Err(e) => { diff --git a/tests/planner/unit/test_virtual_connector.py b/tests/planner/unit/test_virtual_connector.py index 98b17f8bf0..367a3446a5 100644 --- a/tests/planner/unit/test_virtual_connector.py +++ b/tests/planner/unit/test_virtual_connector.py @@ -31,7 +31,7 @@ def get_runtime(): except Exception: # If no existing runtime, create a new one loop = asyncio.get_running_loop() - _runtime_instance = DistributedRuntime(loop, False) + _runtime_instance = DistributedRuntime(loop, "etcd", False) return _runtime_instance diff --git a/tests/router/test_router_e2e_with_mockers.py b/tests/router/test_router_e2e_with_mockers.py index 9423e8461c..6cf5e97fdf 100644 --- a/tests/router/test_router_e2e_with_mockers.py +++ b/tests/router/test_router_e2e_with_mockers.py @@ -226,7 +226,7 @@ def get_runtime(): # No running loop, create a new one (sync context) loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - _runtime_instance = DistributedRuntime(loop, False) + _runtime_instance = DistributedRuntime(loop, "etcd", False) return _runtime_instance