diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 21f15834b..d4a19c4bf 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1227,6 +1227,14 @@ enum SandboxCommands { #[arg(long)] memory: Option, + /// Experimental driver-keyed JSON object for driver-specific sandbox settings. + /// Validation behavior is not yet finalized. + /// + /// For Kubernetes, pass a value such as + /// `{"kubernetes":{"pod":{"node_selector":{"pool":"gpu"}}}}`. + #[arg(long, value_name = "JSON")] + driver_config_json: Option, + /// Provider names to attach to this sandbox. #[arg(long = "provider")] providers: Vec, @@ -2541,6 +2549,7 @@ async fn main() -> Result<()> { gpu_device, cpu, memory, + driver_config_json, providers, policy, forward, @@ -2621,6 +2630,7 @@ async fn main() -> Result<()> { gpu_device.as_deref(), cpu.as_deref(), memory.as_deref(), + driver_config_json.as_deref(), editor, &providers, policy.as_deref(), @@ -4335,6 +4345,32 @@ mod tests { } } + #[test] + fn sandbox_create_driver_config_json_flag_parses() { + let json = r#"{"kubernetes":{"pod":{"node_selector":{"pool":"gpu"}}}}"#; + let cli = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--driver-config-json", + json, + ]) + .expect("sandbox create driver config JSON flag should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: + Some(SandboxCommands::Create { + driver_config_json, .. + }), + .. 
+ }) => { + assert_eq!(driver_config_json.as_deref(), Some(json)); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + #[test] fn service_expose_accepts_positional_target_port_and_service() { let cli = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 3b2adfc59..b29e2f633 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -1607,6 +1607,52 @@ fn build_sandbox_resource_limits( Ok(Some(Struct { fields })) } +fn parse_driver_config_json(value: &str) -> Result { + let parsed: serde_json::Value = serde_json::from_str(value) + .into_diagnostic() + .wrap_err("--driver-config-json must be valid JSON")?; + + let serde_json::Value::Object(fields) = parsed else { + return Err(miette!( + "--driver-config-json must be a JSON object keyed by driver name" + )); + }; + + Ok(prost_types::Struct { + fields: fields + .into_iter() + .map(|(key, value)| json_to_protobuf_value(value).map(|value| (key, value))) + .collect::>()?, + }) +} + +fn json_to_protobuf_value(value: serde_json::Value) -> Result { + use prost_types::{ListValue, Struct, Value, value::Kind}; + + let kind = match value { + serde_json::Value::Null => Kind::NullValue(0), + serde_json::Value::Bool(value) => Kind::BoolValue(value), + serde_json::Value::Number(value) => Kind::NumberValue(value.as_f64().ok_or_else(|| { + miette!("--driver-config-json contains a number that cannot be represented") + })?), + serde_json::Value::String(value) => Kind::StringValue(value), + serde_json::Value::Array(values) => Kind::ListValue(ListValue { + values: values + .into_iter() + .map(json_to_protobuf_value) + .collect::>()?, + }), + serde_json::Value::Object(fields) => Kind::StructValue(Struct { + fields: fields + .into_iter() + .map(|(key, value)| json_to_protobuf_value(value).map(|value| (key, value))) + .collect::>()?, + }), + }; + + Ok(Value { kind: Some(kind) }) +} + fn validate_cpu_quantity(value: &str) -> 
Result { let value = value.trim(); if value.is_empty() { @@ -1701,6 +1747,7 @@ pub async fn sandbox_create( gpu_device: Option<&str>, cpu: Option<&str>, memory: Option<&str>, + driver_config_json: Option<&str>, editor: Option, providers: &[String], policy: Option<&str>, @@ -1770,11 +1817,15 @@ pub async fn sandbox_create( let policy = load_sandbox_policy(policy)?; let resource_limits = build_sandbox_resource_limits(cpu, memory)?; + let driver_config = driver_config_json + .map(parse_driver_config_json) + .transpose()?; - let template = if image.is_some() || resource_limits.is_some() { + let template = if image.is_some() || resource_limits.is_some() || driver_config.is_some() { Some(SandboxTemplate { image: image.unwrap_or_default(), resources: resource_limits, + driver_config, ..SandboxTemplate::default() }) } else { @@ -7471,11 +7522,11 @@ mod tests { git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, inferred_provider_type, mtls_certs_exist_for_gateway, package_managed_tls_dirs, parse_cli_setting_value, parse_credential_expiry_cli_value, parse_credential_expiry_pairs, - parse_credential_pairs, plaintext_gateway_is_remote, progress_step_from_metadata, - provider_profile_allows_refresh_bootstrap, provisioning_timeout_message, - ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from, - sandbox_should_persist, sandbox_upload_plan, service_expose_status_error, - service_url_for_gateway, + parse_credential_pairs, parse_driver_config_json, plaintext_gateway_is_remote, + progress_step_from_metadata, provider_profile_allows_refresh_bootstrap, + provisioning_timeout_message, ready_false_condition_message, refresh_status_header, + refresh_status_row, resolve_from, sandbox_should_persist, sandbox_upload_plan, + service_expose_status_error, service_url_for_gateway, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -7873,6 +7924,56 @@ mod tests { assert!(build_sandbox_resource_limits(None, 
Some("1.5Gi")).is_err()); } + #[test] + fn parse_driver_config_json_accepts_driver_keyed_object() { + let config = + parse_driver_config_json(r#"{"kubernetes":{"pod":{"node_selector":{"pool":"gpu"}}}}"#) + .expect("driver config should parse"); + + let kubernetes = config + .fields + .get("kubernetes") + .and_then(|value| value.kind.as_ref()) + .and_then(|kind| match kind { + prost_types::value::Kind::StructValue(inner) => Some(inner), + _ => None, + }) + .expect("kubernetes block should be a struct"); + let pod = kubernetes + .fields + .get("pod") + .and_then(|value| value.kind.as_ref()) + .and_then(|kind| match kind { + prost_types::value::Kind::StructValue(inner) => Some(inner), + _ => None, + }) + .expect("pod block should be a struct"); + + assert!(pod.fields.contains_key("node_selector")); + } + + #[test] + fn parse_driver_config_json_rejects_non_object() { + let err = parse_driver_config_json(r#"["kubernetes"]"#) + .expect_err("top-level array should be rejected"); + + assert!( + err.to_string().contains("keyed by driver name"), + "unexpected error: {err}" + ); + } + + #[test] + fn parse_driver_config_json_rejects_invalid_json() { + let err = parse_driver_config_json(r#"{"kubernetes":"#) + .expect_err("invalid JSON should be rejected"); + + assert!( + err.to_string().contains("must be valid JSON"), + "unexpected error: {err}" + ); + } + #[test] fn inferred_provider_type_returns_type_for_known_command() { let result = inferred_provider_type(&["claude".to_string(), "--help".to_string()]); diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 828dbd998..96821655d 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -787,6 +787,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() { None, None, None, + None, &[], None, None, @@ -829,6 +830,7 @@ async 
fn sandbox_create_sends_cpu_and_memory_limits_only() { Some("500m"), Some("2Gi"), None, + None, &[], None, None, @@ -884,6 +886,79 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { assert!(!resources.fields.contains_key("requests")); } +#[tokio::test] +async fn sandbox_create_sends_driver_config_json() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("driver-config"), + None, + "openshell", + &[], + true, + false, + None, + None, + None, + Some(r#"{"kubernetes":{"pod":{"priority_class_name":"batch-low"}}}"#), + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + "manual", + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let driver_config = requests[0] + .spec + .as_ref() + .and_then(|spec| spec.template.as_ref()) + .and_then(|template| template.driver_config.as_ref()) + .expect("driver config should be sent"); + let kubernetes = driver_config + .fields + .get("kubernetes") + .and_then(|value| value.kind.as_ref()) + .and_then(|kind| match kind { + prost_types::value::Kind::StructValue(inner) => Some(inner), + _ => None, + }) + .expect("kubernetes block should be a struct"); + let pod = kubernetes + .fields + .get("pod") + .and_then(|value| value.kind.as_ref()) + .and_then(|kind| match kind { + prost_types::value::Kind::StructValue(inner) => Some(inner), + _ => None, + }) + .expect("pod block should be a struct"); + + assert_eq!( + pod.fields + .get("priority_class_name") + .and_then(|value| value.kind.as_ref()) + .and_then(|kind| match kind { + prost_types::value::Kind::StringValue(value) => Some(value.as_str()), + _ => None, + }), + Some("batch-low") + ); +} + 
#[tokio::test] async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { let server = run_server().await; @@ -906,6 +981,7 @@ async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { None, None, None, + None, &[], None, None, @@ -963,6 +1039,7 @@ async fn sandbox_create_returns_vm_error_without_waiting_for_timeout() { None, None, None, + None, &[], None, None, @@ -1016,6 +1093,7 @@ async fn sandbox_create_keeps_waiting_while_vm_progress_arrives() { None, None, None, + None, &[], None, None, @@ -1061,6 +1139,7 @@ async fn sandbox_create_times_out_when_only_logs_arrive() { None, None, None, + None, &[], None, None, @@ -1102,6 +1181,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1147,6 +1227,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1192,6 +1273,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { None, None, None, + None, &[], None, None, @@ -1237,6 +1319,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { None, None, None, + None, &[], None, Some(openshell_core::forward::ForwardSpec::new(forward_port)), diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index c9b34ff8f..4a902a48b 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -39,8 +39,7 @@ fn test_sandbox() -> DriverSandbox { agent_socket_path: String::new(), labels: HashMap::new(), environment: HashMap::from([("TEMPLATE_ENV".to_string(), "template".to_string())]), - resources: None, - platform_config: None, + ..Default::default() }), gpu: false, gpu_device: String::new(), @@ -391,7 +390,7 @@ fn docker_resource_limits_rejects_requests() { memory_request: String::new(), memory_limit: String::new(), }), - platform_config: None, + ..Default::default() }; let err = 
docker_resource_limits(&template).unwrap_err(); @@ -411,7 +410,7 @@ fn docker_resource_limits_applies_cpu_and_memory_limits() { memory_limit: "2Gi".to_string(), ..Default::default() }), - platform_config: None, + ..Default::default() }; let limits = docker_resource_limits(&template).unwrap(); diff --git a/crates/openshell-driver-kubernetes/README.md b/crates/openshell-driver-kubernetes/README.md index f4b98ffdb..0bdcf3748 100644 --- a/crates/openshell-driver-kubernetes/README.md +++ b/crates/openshell-driver-kubernetes/README.md @@ -65,3 +65,37 @@ When a sandbox requests GPU support, the driver checks node allocatable capacity for `nvidia.com/gpu` and requests one GPU resource in the workload spec. The sandbox image must provide the user-space libraries needed by the agent workload. + +## Driver Config POC + +The RFC 0005 POC accepts the selected `SandboxTemplate.driver_config.kubernetes` +block as `DriverSandboxTemplate.driver_config`. The Kubernetes driver owns the +nested schema and currently accepts: + +- `pod.node_selector` +- `pod.tolerations` +- `pod.runtime_class_name` +- `pod.priority_class_name` +- `containers.agent.resources.requests` +- `containers.agent.resources.limits` + +Nested keys inside the `kubernetes` block use snake_case. The top-level +`driver_config` envelope is keyed by driver names, so `kubernetes` is not part +of the nested schema. + +Set this through the CLI with the public driver-keyed envelope. The gateway +forwards only the `kubernetes` object to this driver: + +```shell +openshell sandbox create \ + --driver-config-json '{"kubernetes":{"pod":{"runtime_class_name":"kata-containers","node_selector":{"pool":"gpu"}}}}' \ + -- claude +``` + +Resource keys use native Kubernetes resource names and quantity strings. The +POC parser renders the keys listed above and ignores unknown fields. 
+`pod.runtime_class_name` maps to PodSpec `runtimeClassName` and overrides the +driver's configured `default_runtime_class_name`; the typed public +`SandboxTemplate.runtime_class_name` still takes precedence when set. Use the +public `gpu` flag for the default GPU request and `driver_config` only for +additional driver-owned resource details. diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 830b85225..449cee58d 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -28,6 +28,7 @@ use openshell_core::proto::compute::v1::{ GetCapabilitiesResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, watch_sandboxes_event, }; +use serde::Deserialize; use std::collections::BTreeMap; use std::pin::Pin; use std::time::Duration; @@ -79,6 +80,47 @@ pub const SANDBOX_KIND: &str = "Sandbox"; const GPU_RESOURCE_NAME: &str = "nvidia.com/gpu"; const GPU_RESOURCE_QUANTITY: &str = "1"; +// This POC treats the selected Struct as a driver-local typed schema. Once the +// Kubernetes shape stabilizes, these serde structs may move to driver-local +// protobuf definitions, but the typed decode should stay inside this driver. +// Do not promote Kubernetes config messages into the public API or gateway +// translation layer; the RFC boundary is Struct at the gateway, typed config in +// the selected driver. 
+#[derive(Debug, Clone, Default, Deserialize)] +#[serde(default)] +struct KubernetesSandboxDriverConfig { + pod: KubernetesPodDriverConfig, + containers: KubernetesDriverContainersConfig, +} + +#[derive(Debug, Clone, Default, Deserialize)] +#[serde(default)] +struct KubernetesPodDriverConfig { + node_selector: BTreeMap, + runtime_class_name: String, + tolerations: Vec, + priority_class_name: String, +} + +#[derive(Debug, Clone, Default, Deserialize)] +#[serde(default)] +struct KubernetesDriverContainersConfig { + agent: KubernetesContainerDriverConfig, +} + +#[derive(Debug, Clone, Default, Deserialize)] +#[serde(default)] +struct KubernetesContainerDriverConfig { + resources: KubernetesContainerResourceConfig, +} + +#[derive(Debug, Clone, Default, Deserialize)] +#[serde(default)] +struct KubernetesContainerResourceConfig { + requests: BTreeMap, + limits: BTreeMap, +} + // --------------------------------------------------------------------------- // Default workspace persistence (temporary — will be replaced by snapshotting) // --------------------------------------------------------------------------- @@ -1086,6 +1128,21 @@ fn spec_pod_env(spec: Option<&SandboxSpec>) -> std::collections::HashMap KubernetesSandboxDriverConfig { + let Some(config) = template.driver_config.as_ref() else { + return KubernetesSandboxDriverConfig::default(); + }; + + let json = serde_json::Value::Object(proto_struct_to_json_object(config)); + match serde_json::from_value(json) { + Ok(config) => config, + Err(err) => { + warn!(error = %err, "Ignoring invalid Kubernetes driver_config"); + KubernetesSandboxDriverConfig::default() + } + } +} + fn sandbox_to_k8s_spec( spec: Option<&SandboxSpec>, params: &SandboxPodParams<'_>, @@ -1160,6 +1217,8 @@ fn sandbox_template_to_k8s( inject_workspace: bool, params: &SandboxPodParams<'_>, ) -> serde_json::Value { + let driver_config = kubernetes_driver_config(template); + let mut metadata = serde_json::Map::new(); if !template.labels.is_empty() { 
metadata.insert("labels".to_string(), serde_json::json!(template.labels)); @@ -1190,10 +1249,15 @@ fn sandbox_template_to_k8s( } let mut spec = serde_json::Map::new(); - let runtime_class_name = platform_config_string(template, "runtime_class_name").or_else(|| { - (!params.default_runtime_class_name.is_empty()) - .then(|| params.default_runtime_class_name.to_string()) - }); + let runtime_class_name = platform_config_string(template, "runtime_class_name") + .or_else(|| { + (!driver_config.pod.runtime_class_name.is_empty()) + .then(|| driver_config.pod.runtime_class_name.clone()) + }) + .or_else(|| { + (!params.default_runtime_class_name.is_empty()) + .then(|| params.default_runtime_class_name.to_string()) + }); if let Some(runtime_class) = runtime_class_name { spec.insert( "runtimeClassName".to_string(), @@ -1206,6 +1270,7 @@ fn sandbox_template_to_k8s( if let Some(tolerations) = platform_config_struct(template, "tolerations") { spec.insert("tolerations".to_string(), tolerations); } + apply_pod_driver_config(&mut spec, &driver_config.pod); // Per-sandbox platform_config.host_users overrides the cluster-wide default. 
let use_user_namespaces = platform_config_bool(template, "host_users") @@ -1317,6 +1382,7 @@ fn sandbox_template_to_k8s( if let Some(resources) = container_resources(template, gpu) { container.insert("resources".to_string(), resources); } + apply_agent_driver_resources(&mut container, &driver_config.containers.agent.resources); spec.insert( "containers".to_string(), serde_json::Value::Array(vec![serde_json::Value::Object(container)]), @@ -1386,6 +1452,83 @@ fn sandbox_template_to_k8s( result } +fn apply_pod_driver_config( + spec: &mut serde_json::Map, + config: &KubernetesPodDriverConfig, +) { + if !config.node_selector.is_empty() { + let node_selector = spec + .entry("nodeSelector".to_string()) + .or_insert_with(|| serde_json::json!({})); + merge_string_map(node_selector, &config.node_selector); + } + + if !config.priority_class_name.is_empty() { + spec.entry("priorityClassName".to_string()) + .or_insert_with(|| serde_json::json!(config.priority_class_name)); + } + + if !config.tolerations.is_empty() { + let tolerations = spec + .entry("tolerations".to_string()) + .or_insert_with(|| serde_json::json!([])); + if let Some(existing) = tolerations.as_array_mut() { + existing.extend(config.tolerations.iter().cloned()); + } else { + *tolerations = serde_json::Value::Array(config.tolerations.clone()); + } + } +} + +fn apply_agent_driver_resources( + container: &mut serde_json::Map, + resources: &KubernetesContainerResourceConfig, +) { + if resources.requests.is_empty() && resources.limits.is_empty() { + return; + } + + let target = container + .entry("resources".to_string()) + .or_insert_with(|| serde_json::json!({})); + apply_resource_quantity_map(target, "requests", &resources.requests); + apply_resource_quantity_map(target, "limits", &resources.limits); +} + +fn merge_string_map(target: &mut serde_json::Value, values: &BTreeMap) { + if !target.is_object() { + *target = serde_json::json!({}); + } + let target = target + .as_object_mut() + .expect("target was converted 
to object"); + for (key, value) in values { + target + .entry(key.clone()) + .or_insert_with(|| serde_json::json!(value)); + } +} + +fn apply_resource_quantity_map( + target: &mut serde_json::Value, + section: &str, + values: &BTreeMap, +) { + if values.is_empty() { + return; + } + if !target.is_object() { + *target = serde_json::json!({}); + } + let target = target + .as_object_mut() + .expect("target was converted to object"); + let section_value = target + .entry(section.to_string()) + .or_insert_with(|| serde_json::json!({})); + merge_string_map(section_value, values); +} + fn image_pull_secret_refs(secrets: &[String]) -> Vec { secrets .iter() @@ -1607,6 +1750,16 @@ fn platform_config_struct(template: &SandboxTemplate, key: &str) -> Option serde_json::Map { + config + .fields + .iter() + .map(|(key, value)| (key.clone(), proto_value_to_json(value))) + .collect() +} + fn proto_value_to_json(value: &prost_types::Value) -> serde_json::Value { match value.kind.as_ref() { Some(prost_types::value::Kind::NumberValue(num)) => serde_json::Number::from_f64(*num) @@ -1704,6 +1857,57 @@ mod tests { static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); + fn json_struct(value: serde_json::Value) -> Struct { + match json_value(value).kind { + Some(Kind::StructValue(value)) => value, + _ => panic!("expected JSON object"), + } + } + + fn json_value(value: serde_json::Value) -> Value { + match value { + serde_json::Value::Null => Value { kind: None }, + serde_json::Value::Bool(value) => Value { + kind: Some(Kind::BoolValue(value)), + }, + serde_json::Value::Number(value) => Value { + kind: value.as_f64().map(Kind::NumberValue), + }, + serde_json::Value::String(value) => Value { + kind: Some(Kind::StringValue(value)), + }, + serde_json::Value::Array(values) => Value { + kind: Some(Kind::ListValue(prost_types::ListValue { + values: values.into_iter().map(json_value).collect(), + })), + }, + serde_json::Value::Object(values) => Value { + 
kind: Some(Kind::StructValue(Struct { + fields: values + .into_iter() + .map(|(key, value)| (key, json_value(value))) + .collect(), + })), + }, + } + } + + #[test] + fn driver_config_ignores_invalid_shape() { + let template = SandboxTemplate { + driver_config: Some(json_struct(serde_json::json!({ + "pod": "not-an-object" + }))), + ..SandboxTemplate::default() + }; + + let config = kubernetes_driver_config(&template); + + assert!(config.pod.node_selector.is_empty()); + assert!(config.containers.agent.resources.requests.is_empty()); + assert!(config.containers.agent.resources.limits.is_empty()); + } + #[test] fn kube_pulling_event_adds_image_progress_metadata() { let mut metadata = std::collections::HashMap::new(); @@ -2154,6 +2358,102 @@ mod tests { ); } + #[test] + fn driver_config_runtime_class_name_applies_to_pod_spec() { + let template = SandboxTemplate { + driver_config: Some(json_struct(serde_json::json!({ + "pod": { + "runtime_class_name": "kata-containers" + } + }))), + ..SandboxTemplate::default() + }; + + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; + + assert_eq!( + pod_template["spec"]["runtimeClassName"], + serde_json::json!("kata-containers") + ); + } + + #[test] + fn driver_config_runtime_class_name_overrides_config_default() { + let template = SandboxTemplate { + driver_config: Some(json_struct(serde_json::json!({ + "pod": { + "runtime_class_name": "kata-containers" + } + }))), + ..SandboxTemplate::default() + }; + + let pod_template = { + let params = SandboxPodParams { + default_runtime_class_name: "gvisor", + ..SandboxPodParams::default() + }; + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; + + assert_eq!( + pod_template["spec"]["runtimeClassName"], + serde_json::json!("kata-containers") + ); + } + + #[test] + fn 
template_runtime_class_name_overrides_driver_config() { + let template = SandboxTemplate { + platform_config: Some(Struct { + fields: std::iter::once(( + "runtime_class_name".to_string(), + Value { + kind: Some(Kind::StringValue("gvisor".to_string())), + }, + )) + .collect(), + }), + driver_config: Some(json_struct(serde_json::json!({ + "pod": { + "runtime_class_name": "kata-containers" + } + }))), + ..SandboxTemplate::default() + }; + + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; + + assert_eq!( + pod_template["spec"]["runtimeClassName"], + serde_json::json!("gvisor") + ); + } + #[test] fn runtime_class_name_omitted_when_both_template_and_default_empty() { let template = SandboxTemplate::default(); @@ -2959,6 +3259,72 @@ mod tests { assert_eq!(tolerations[0]["effect"], "NoSchedule"); } + #[test] + fn driver_config_applies_pod_scheduling_and_agent_resources() { + let template = SandboxTemplate { + driver_config: Some(json_struct(serde_json::json!({ + "pod": { + "node_selector": { + "accelerator": "nvidia" + }, + "runtime_class_name": "kata-containers", + "priority_class_name": "gpu-workload", + "tolerations": [{ + "key": "nvidia.com/gpu", + "operator": "Exists", + "effect": "NoSchedule" + }] + }, + "containers": { + "agent": { + "resources": { + "requests": { + "vendor.example/gpu-memory": "8Gi" + }, + "limits": { + "vendor.example/gpu-slices": "1" + } + } + } + } + }))), + ..SandboxTemplate::default() + }; + + let pod_template = sandbox_template_to_k8s( + &template, + false, + &std::collections::HashMap::new(), + false, + &SandboxPodParams::default(), + ); + + assert_eq!( + pod_template["spec"]["nodeSelector"]["accelerator"], + serde_json::json!("nvidia") + ); + assert_eq!( + pod_template["spec"]["priorityClassName"], + serde_json::json!("gpu-workload") + ); + assert_eq!( + pod_template["spec"]["runtimeClassName"], + 
serde_json::json!("kata-containers") + ); + assert_eq!( + pod_template["spec"]["tolerations"][0]["key"], + serde_json::json!("nvidia.com/gpu") + ); + assert_eq!( + pod_template["spec"]["containers"][0]["resources"]["requests"]["vendor.example/gpu-memory"], + serde_json::json!("8Gi") + ); + assert_eq!( + pod_template["spec"]["containers"][0]["resources"]["limits"]["vendor.example/gpu-slices"], + serde_json::json!("1") + ); + } + #[test] fn default_workspace_vct_uses_provided_storage_size() { let vct = default_workspace_volume_claim_templates("5Gi"); diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index d6c1966e7..064eb3857 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -425,7 +425,8 @@ impl ComputeRuntime { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), Status> { - let driver_sandbox = driver_sandbox_from_public(sandbox); + let driver_sandbox = + driver_sandbox_from_public(sandbox, self.driver_kind).map_err(|status| *status)?; self.driver .validate_sandbox_create(Request::new(ValidateSandboxCreateRequest { sandbox: Some(driver_sandbox), @@ -440,6 +441,8 @@ impl ComputeRuntime { sandbox_token: Option, ) -> Result { let sandbox_id = sandbox.object_id().to_string(); + let mut driver_sandbox = + driver_sandbox_from_public(&sandbox, self.driver_kind).map_err(|status| *status)?; // Create with MustCreate condition to prevent duplicate creation race self.sandbox_index.update_from_sandbox(&sandbox); @@ -469,7 +472,6 @@ impl ComputeRuntime { } })?; - let mut driver_sandbox = driver_sandbox_from_public(&sandbox); if let Some(token) = sandbox_token && let Some(spec) = driver_sandbox.spec.as_mut() { @@ -551,12 +553,11 @@ impl ComputeRuntime { self.sandbox_watch_bus.notify(&id); self.cleanup_sandbox_owned_records(&sandbox).await; - let driver_sandbox = driver_sandbox_from_public(&sandbox); let deleted = self .driver 
.delete_sandbox(Request::new(DeleteSandboxRequest { - sandbox_id: driver_sandbox.id, - sandbox_name: driver_sandbox.name, + sandbox_id: sandbox.object_id().to_string(), + sandbox_name: sandbox.object_name().to_string(), })) .await .map(|response| response.into_inner().deleted) @@ -1249,38 +1250,75 @@ impl ComputeRuntime { } } -fn driver_sandbox_from_public(sandbox: &Sandbox) -> DriverSandbox { - DriverSandbox { +fn driver_sandbox_from_public( + sandbox: &Sandbox, + driver_kind: Option, +) -> Result> { + Ok(DriverSandbox { id: sandbox.object_id().to_string(), name: sandbox.object_name().to_string(), namespace: String::new(), // Namespace is set by the driver based on its config - spec: sandbox.spec.as_ref().map(driver_sandbox_spec_from_public), + spec: sandbox + .spec + .as_ref() + .map(|spec| driver_sandbox_spec_from_public(spec, driver_kind)) + .transpose()?, status: sandbox.status.as_ref().map(driver_status_from_public), - } + }) } -fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { - DriverSandboxSpec { +fn driver_sandbox_spec_from_public( + spec: &SandboxSpec, + driver_kind: Option, +) -> Result> { + Ok(DriverSandboxSpec { log_level: spec.log_level.clone(), environment: spec.environment.clone(), template: spec .template .as_ref() - .map(driver_sandbox_template_from_public), + .map(|template| driver_sandbox_template_from_public(template, driver_kind)) + .transpose()?, gpu: spec.gpu, gpu_device: spec.gpu_device.clone(), sandbox_token: String::new(), - } + }) } -fn driver_sandbox_template_from_public(template: &SandboxTemplate) -> DriverSandboxTemplate { - DriverSandboxTemplate { +fn driver_sandbox_template_from_public( + template: &SandboxTemplate, + driver_kind: Option, +) -> Result> { + Ok(DriverSandboxTemplate { image: template.image.clone(), agent_socket_path: template.agent_socket.clone(), labels: template.labels.clone(), environment: template.environment.clone(), resources: extract_typed_resources(&template.resources), 
platform_config: build_platform_config(template), + driver_config: select_driver_config(&template.driver_config, driver_kind)?, + }) +} + +fn select_driver_config( + config: &Option, + driver_kind: Option, +) -> Result, Box> { + let Some(config) = config else { + return Ok(None); + }; + let Some(driver_kind) = driver_kind else { + return Ok(None); + }; + let driver_name = driver_kind.as_str(); + let Some(value) = config.fields.get(driver_name) else { + return Ok(None); + }; + match value.kind.as_ref() { + Some(prost_types::value::Kind::StructValue(inner)) => Ok(Some(inner.clone())), + _ => Err(Box::new(Status::invalid_argument(format!( + "template.driver_config.{driver_name} must be an object" + )))), } } @@ -1819,6 +1857,61 @@ mod tests { } } + #[test] + fn select_driver_config_forwards_only_matching_driver_block() { + let config = prost_types::Struct { + fields: [ + ( + "kubernetes".to_string(), + struct_value([("node", string_value("gpu"))]), + ), + ( + "docker".to_string(), + struct_value([("network_mode", string_value("bridge"))]), + ), + ] + .into_iter() + .collect(), + }; + + let selected = + select_driver_config(&Some(config), Some(ComputeDriverKind::Kubernetes)).unwrap(); + let selected = selected.expect("kubernetes config should be selected"); + + assert!(selected.fields.contains_key("node")); + assert!(!selected.fields.contains_key("network_mode")); + } + + #[test] + fn select_driver_config_ignores_non_matching_driver_blocks() { + let config = prost_types::Struct { + fields: std::iter::once(( + "docker".to_string(), + struct_value([("network_mode", string_value("bridge"))]), + )) + .collect(), + }; + + let selected = + select_driver_config(&Some(config), Some(ComputeDriverKind::Kubernetes)).unwrap(); + + assert!(selected.is_none()); + } + + #[test] + fn select_driver_config_rejects_non_object_matching_driver_block() { + let config = prost_types::Struct { + fields: std::iter::once(("kubernetes".to_string(), string_value("not-an-object"))) + .collect(), + 
}; + + let err = + select_driver_config(&Some(config), Some(ComputeDriverKind::Kubernetes)).unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("template.driver_config.kubernetes")); + } + #[derive(Debug, Default)] struct TestDriver { listed_sandboxes: Vec, diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 83658ef7b..268c143d2 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -200,6 +200,14 @@ fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { ))); } } + if let Some(ref s) = tmpl.driver_config { + let size = s.encoded_len(); + if size > MAX_TEMPLATE_STRUCT_SIZE { + return Err(Status::invalid_argument(format!( + "template.driver_config serialized size exceeds maximum ({size} > {MAX_TEMPLATE_STRUCT_SIZE})" + ))); + } + } Ok(()) } diff --git a/docs/reference/sandbox-compute-drivers.mdx b/docs/reference/sandbox-compute-drivers.mdx index b4bbf68a7..229bb1bdb 100644 --- a/docs/reference/sandbox-compute-drivers.mdx +++ b/docs/reference/sandbox-compute-drivers.mdx @@ -38,6 +38,28 @@ Docker and Podman apply them as runtime limits. Kubernetes applies them as both container requests and limits. The VM driver accepts the fields but currently ignores them. +Sandbox create also accepts experimental driver-owned config through +`--driver-config-json`. The value is a JSON object keyed by driver name. The +gateway forwards only the block for the active driver, so a Kubernetes gateway +receives the `kubernetes` object from the example value shown below. + +Nested keys inside each driver block use snake_case. The top-level envelope keys +are driver names, such as `kubernetes`, and are not part of the nested schema. 
+ +```shell +openshell sandbox create \ + --driver-config-json '{"kubernetes":{"pod":{"runtime_class_name":"kata-containers","priority_class_name":"batch-low"}}}' \ + -- claude +``` + +Driver config is for fields without a stable public flag. Prefer `--cpu`, +`--memory`, and `--gpu` for portable resource intent. + +For Kubernetes, `pod.runtime_class_name` maps to PodSpec `runtimeClassName`. +It overrides the gateway's configured default runtime class for that sandbox, +while a typed `SandboxTemplate.runtime_class_name` value from the API still +takes precedence. + ## Docker Driver [Docker](https://www.docker.com/get-started/)-backed sandboxes run as containers on the gateway host. Use Docker for local development, single-machine gateways, and hosts that already use Docker Desktop or Docker Engine. diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 8a154c1b3..0db6d7678 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -43,6 +43,25 @@ value as both the request and the limit so the scheduler reserves the same amount the sandbox can use. The VM driver currently accepts these flags but does not change VM allocation. +### Driver-Specific Configuration + +Pass experimental driver-owned settings with `--driver-config-json`. The value +must be a JSON object keyed by driver name. The gateway forwards only the block +for its configured compute driver, as shown in the example below. + +Nested keys inside each driver block use snake_case. The top-level envelope keys +are driver names, such as `kubernetes`, and are not part of the nested schema. + +```shell +openshell sandbox create \ + --driver-config-json '{"kubernetes":{"pod":{"runtime_class_name":"kata-containers","node_selector":{"pool":"gpu"}}}}' \ + -- claude +``` + +Use this only for driver-specific fields that do not have a stable CLI flag. +Prefer stable flags such as `--cpu`, `--memory`, and `--gpu` when they cover +the same behavior.
+ ### GPU Resources To request GPU resources, add `--gpu`: diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 610d491c7..190a04e87 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -121,6 +121,10 @@ message DriverSandboxTemplate { // For the Kubernetes driver this carries fields such as runtimeClassName, // annotations, and volumeClaimTemplates. google.protobuf.Struct platform_config = 11; + // Caller-provided config for the selected driver only. + // This is the inner block selected from public SandboxTemplate.driver_config. + // The selected driver owns nested schema validation. + google.protobuf.Struct driver_config = 12; } // Typed compute-resource requirements. diff --git a/proto/openshell.proto b/proto/openshell.proto index a8ead0d31..bd144cba0 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -355,6 +355,11 @@ message SandboxTemplate { // available (beta through 1.35, GA in 1.36+) and a supporting runtime. // When unset, the cluster-wide default is used. optional bool user_namespaces = 10; + // Driver-keyed opaque config envelope supplied by the caller. + // The gateway selects the block matching the active compute driver and + // forwards only that inner Struct to DriverSandboxTemplate.driver_config. + // The selected driver owns nested schema validation. + google.protobuf.Struct driver_config = 11; } // User-facing sandbox status derived by the gateway from compute-driver observations.