Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion architecture/gateway-single-node.md
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ When `openshell sandbox create` cannot connect to a gateway (connection refused,
1. `should_attempt_bootstrap()` in `crates/openshell-cli/src/bootstrap.rs` checks the error type. It returns `true` for connectivity errors and missing default TLS materials, but `false` for TLS handshake/auth errors.
2. If running in a terminal, the user is prompted to confirm.
3. `run_bootstrap()` deploys a gateway named `"openshell"`, sets it as active, and returns fresh `TlsOptions` pointing to the newly-written mTLS certs.
4. When `sandbox create` requests GPU explicitly (`--gpu`) or infers it from an image whose final name component contains `gpu` (such as `nvidia-gpu`), the bootstrap path enables gateway GPU support before retrying sandbox creation.
4. When `sandbox create` requests GPU explicitly (`--gpu`), requests `nvidia.com/gpu` through `--resource`, or infers GPU intent from an image whose final name component contains `gpu` (such as `nvidia-gpu`), the bootstrap path enables gateway GPU support before retrying sandbox creation.

## Container Environment Variables

Expand Down
49 changes: 49 additions & 0 deletions crates/openshell-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,21 @@ enum SandboxCommands {
#[arg(long)]
gpu: bool,

/// Request an additional sandbox resource limit.
///
/// Accepts `<NAME>=<COUNT>` and may be repeated. This is primarily
/// useful for device-plugin resources such as `nvidia.com/gpu=2` or
/// `vendor.com/vf=1`.
#[arg(long = "resource", value_name = "NAME=COUNT")]
resources: Vec<String>,

/// Override the sandbox pod runtime class.
///
/// This is forwarded to `template.runtime_class_name` and is useful
/// for device runtimes that require an explicit runtime class.
#[arg(long = "runtime-class")]
runtime_class: Option<String>,

/// SSH destination for remote bootstrap (e.g., user@hostname).
/// Only used when no cluster exists yet; ignored if a cluster is
/// already active.
Expand Down Expand Up @@ -2073,6 +2088,8 @@ async fn main() -> Result<()> {
no_keep,
editor,
gpu,
resources,
runtime_class,
remote,
ssh_key,
providers,
Expand Down Expand Up @@ -2154,6 +2171,8 @@ async fn main() -> Result<()> {
upload_spec.as_ref(),
keep,
gpu,
&resources,
runtime_class.as_deref(),
editor,
remote.as_deref(),
ssh_key.as_deref(),
Expand All @@ -2176,6 +2195,8 @@ async fn main() -> Result<()> {
upload_spec.as_ref(),
keep,
gpu,
&resources,
runtime_class.as_deref(),
editor,
remote.as_deref(),
ssh_key.as_deref(),
Expand Down Expand Up @@ -2872,6 +2893,34 @@ mod tests {
assert_eq!(dest.get_value_hint(), ValueHint::AnyPath);
}

#[test]
fn sandbox_create_accepts_resource_and_runtime_class_flags() {
let cli = Cli::try_parse_from([
"openshell",
"sandbox",
"create",
"--resource",
"nvidia.com/gpu=2",
"--resource",
"vendor.com/vf=1",
"--runtime-class",
"nvidia",
])
.expect("sandbox create should parse generic resource flags");

assert!(matches!(
cli.command,
Some(Commands::Sandbox {
command: Some(SandboxCommands::Create {
resources,
runtime_class,
..
})
}) if resources == vec!["nvidia.com/gpu=2".to_string(), "vendor.com/vf=1".to_string()]
&& runtime_class.as_deref() == Some("nvidia")
));
}

#[test]
fn parse_upload_spec_without_remote() {
let (local, remote) = parse_upload_spec("./src");
Expand Down
207 changes: 196 additions & 11 deletions crates/openshell-cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ use openshell_providers::{
ProviderRegistry, detect_provider_from_command, normalize_provider_type,
};
use owo_colors::OwoColorize;
use std::collections::{HashMap, HashSet, VecDeque};
use prost_types::{Struct, Value, value::Kind};
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
use std::io::{IsTerminal, Write};
use std::path::{Path, PathBuf};
use std::process::Command;
Expand Down Expand Up @@ -1845,6 +1846,8 @@ pub async fn sandbox_create_with_bootstrap(
upload: Option<&(String, Option<String>, bool)>,
keep: bool,
gpu: bool,
resources: &[String],
runtime_class_name: Option<&str>,
editor: Option<Editor>,
remote: Option<&str>,
ssh_key: Option<&str>,
Expand All @@ -1856,14 +1859,26 @@ pub async fn sandbox_create_with_bootstrap(
bootstrap_override: Option<bool>,
auto_providers_override: Option<bool>,
) -> Result<()> {
if gpu && resource_specs_request_named_resource(resources, "nvidia.com/gpu") {
return Err(miette!(
"--gpu conflicts with --resource nvidia.com/gpu=<COUNT>; use only one GPU request path"
));
}
if gpu && runtime_class_name.is_some() {
return Err(miette!(
"--runtime-class cannot be combined with --gpu because GPU sandboxes force the runtime class automatically"
));
}
if !crate::bootstrap::confirm_bootstrap(bootstrap_override)? {
return Err(miette::miette!(
"No active gateway.\n\
Set one with: openshell gateway select <name>\n\
Or deploy a new gateway: openshell gateway start"
));
}
let requested_gpu = gpu || from.is_some_and(source_requests_gpu);
let requested_gpu = gpu
|| from.is_some_and(source_requests_gpu)
|| resource_specs_request_named_resource(resources, "nvidia.com/gpu");
let (tls, server, gateway_name) =
crate::bootstrap::run_bootstrap(remote, ssh_key, requested_gpu).await?;
// Disable bootstrap inside sandbox_create so that a transient connection
Expand All @@ -1876,6 +1891,8 @@ pub async fn sandbox_create_with_bootstrap(
upload,
keep,
gpu,
resources,
runtime_class_name,
editor,
remote,
ssh_key,
Expand Down Expand Up @@ -1931,6 +1948,8 @@ pub async fn sandbox_create(
upload: Option<&(String, Option<String>, bool)>,
keep: bool,
gpu: bool,
resources: &[String],
runtime_class_name: Option<&str>,
editor: Option<Editor>,
remote: Option<&str>,
ssh_key: Option<&str>,
Expand Down Expand Up @@ -1988,7 +2007,9 @@ pub async fn sandbox_create(
eprintln!();
return Err(err);
}
let requested_gpu = gpu || from.is_some_and(source_requests_gpu);
let requested_gpu = gpu
|| from.is_some_and(source_requests_gpu)
|| resource_specs_request_named_resource(resources, "nvidia.com/gpu");
let (new_tls, new_server, _) =
crate::bootstrap::run_bootstrap(remote, ssh_key, requested_gpu).await?;
let c = grpc_client(&new_server, &new_tls)
Expand Down Expand Up @@ -2017,6 +2038,17 @@ pub async fn sandbox_create(
None => None,
};
let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu);
let requested_gpu_resource = resource_specs_request_named_resource(resources, "nvidia.com/gpu");
if gpu && requested_gpu_resource {
return Err(miette!(
"--gpu conflicts with --resource nvidia.com/gpu=<COUNT>; use only one GPU request path"
));
}
if gpu && runtime_class_name.is_some() {
return Err(miette!(
"--runtime-class cannot be combined with --gpu because GPU sandboxes force the runtime class automatically"
));
}

let inferred_types: Vec<String> = inferred_provider_type(command).into_iter().collect();
let configured_providers = ensure_required_providers(
Expand All @@ -2029,10 +2061,7 @@ pub async fn sandbox_create(

let policy = load_sandbox_policy(policy)?;

let template = image.map(|img| SandboxTemplate {
image: img,
..SandboxTemplate::default()
});
let template = build_sandbox_template(image, runtime_class_name, resources)?;

let request = CreateSandboxRequest {
spec: Some(SandboxSpec {
Expand Down Expand Up @@ -2547,6 +2576,87 @@ fn image_requests_gpu(image: &str) -> bool {
image_name.contains("gpu")
}

fn parse_resource_limits(resources: &[String]) -> Result<Option<Struct>> {
if resources.is_empty() {
return Ok(None);
}

let mut limits = BTreeMap::new();
for resource in resources {
let (name, count) = resource.split_once('=').ok_or_else(|| {
miette!("invalid --resource value '{resource}'; expected <NAME>=<COUNT>")
})?;
if name.is_empty() {
return Err(miette!(
"invalid --resource value '{resource}'; resource name is empty"
));
}
let count = count.parse::<u32>().map_err(|_| {
miette!("invalid --resource value '{resource}'; count must be a positive integer")
})?;
if count == 0 {
return Err(miette!(
"invalid --resource value '{resource}'; count must be greater than zero"
));
}
if limits
.insert(name.to_string(), string_value(count.to_string()))
.is_some()
{
return Err(miette!(
"duplicate --resource entry for '{name}'; specify each resource only once"
));
}
}

Ok(Some(Struct {
fields: BTreeMap::from([("limits".to_string(), struct_value(limits))]),
}))
}

fn resource_specs_request_named_resource(resources: &[String], resource_name: &str) -> bool {
resources
.iter()
.filter_map(|resource| resource.split_once('='))
.any(|(name, _)| name == resource_name)
}

fn build_sandbox_template(
image: Option<String>,
runtime_class_name: Option<&str>,
resources: &[String],
) -> Result<Option<SandboxTemplate>> {
let mut template = SandboxTemplate::default();
let mut changed = false;

if let Some(image) = image {
template.image = image;
changed = true;
}
if let Some(runtime_class_name) = runtime_class_name.filter(|value| !value.is_empty()) {
template.runtime_class_name = runtime_class_name.to_string();
changed = true;
}
if let Some(resource_limits) = parse_resource_limits(resources)? {
template.resources = Some(resource_limits);
changed = true;
}

Ok(changed.then_some(template))
}

fn string_value(value: impl Into<String>) -> Value {
Value {
kind: Some(Kind::StringValue(value.into())),
}
}

fn struct_value(fields: BTreeMap<String, Value>) -> Value {
Value {
kind: Some(Kind::StructValue(Struct { fields })),
}
}

/// Build a Dockerfile and push the resulting image into the gateway.
///
/// Returns the image tag that was built so the caller can use it for sandbox
Expand Down Expand Up @@ -4987,12 +5097,14 @@ fn format_timestamp_ms(ms: i64) -> String {
#[cfg(test)]
mod tests {
use super::{
GatewayControlTarget, TlsOptions, format_gateway_select_header,
GatewayControlTarget, TlsOptions, build_sandbox_template, format_gateway_select_header,
format_gateway_select_items, gateway_auth_label, gateway_select_with, gateway_type_label,
git_sync_files, http_health_check, image_requests_gpu, inferred_provider_type,
parse_cli_setting_value, parse_credential_pairs, provisioning_timeout_message,
ready_false_condition_message, resolve_gateway_control_target_from, sandbox_should_persist,
shell_escape, source_requests_gpu, validate_gateway_name, validate_ssh_host,
parse_cli_setting_value, parse_credential_pairs, parse_resource_limits,
provisioning_timeout_message, ready_false_condition_message,
resolve_gateway_control_target_from, resource_specs_request_named_resource,
sandbox_should_persist, shell_escape, source_requests_gpu, validate_gateway_name,
validate_ssh_host,
};
use crate::TEST_ENV_LOCK;
use hyper::StatusCode;
Expand All @@ -5006,6 +5118,7 @@ mod tests {

use openshell_bootstrap::GatewayMetadata;
use openshell_core::proto::{SandboxCondition, SandboxStatus};
use prost_types::value::Kind;

struct EnvVarGuard {
key: &'static str,
Expand Down Expand Up @@ -5206,6 +5319,78 @@ mod tests {
assert!(sandbox_should_persist(false, Some(&spec)));
}

#[test]
fn parse_resource_limits_builds_limits_struct() {
let parsed = parse_resource_limits(&[
"nvidia.com/gpu=2".to_string(),
"vendor.com/vf=1".to_string(),
])
.expect("parse resource limits")
.expect("resource limits should exist");

let limits = parsed
.fields
.get("limits")
.expect("limits field should exist");
let Kind::StructValue(limits) = limits.kind.clone().expect("limits kind should exist")
else {
panic!("limits should be a struct");
};
let Kind::StringValue(gpu) = limits
.fields
.get("nvidia.com/gpu")
.and_then(|value| value.kind.clone())
.expect("gpu limit should exist")
else {
panic!("gpu limit should be a string");
};
let Kind::StringValue(vf) = limits
.fields
.get("vendor.com/vf")
.and_then(|value| value.kind.clone())
.expect("vf limit should exist")
else {
panic!("vf limit should be a string");
};

assert_eq!(gpu, "2");
assert_eq!(vf, "1");
}

#[test]
fn parse_resource_limits_rejects_invalid_values() {
let err = parse_resource_limits(&["nvidia.com/gpu=two".to_string()])
.expect_err("invalid counts should fail");
assert!(err.to_string().contains("count must be a positive integer"));
}

#[test]
fn build_sandbox_template_includes_runtime_class_and_resources() {
let template = build_sandbox_template(
Some("example.com/test:latest".to_string()),
Some("nvidia"),
&["nvidia.com/gpu=2".to_string()],
)
.expect("build template")
.expect("template should exist");

assert_eq!(template.image, "example.com/test:latest");
assert_eq!(template.runtime_class_name, "nvidia");
assert!(template.resources.is_some());
}

#[test]
fn resource_specs_request_named_resource_detects_gpu_requests() {
assert!(resource_specs_request_named_resource(
&["nvidia.com/gpu=2".to_string()],
"nvidia.com/gpu"
));
assert!(!resource_specs_request_named_resource(
&["vendor.com/vf=1".to_string()],
"nvidia.com/gpu"
));
}

#[test]
fn image_requests_gpu_matches_known_gpu_image_names() {
for image in [
Expand Down
Loading
Loading