Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions crates/openshell-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,14 @@ enum SandboxCommands {
#[arg(long)]
memory: Option<String>,

/// Experimental driver-keyed JSON object for driver-specific sandbox settings.
/// Validation behavior is not yet finalized.
///
/// For Kubernetes, pass a value such as
/// `{"kubernetes":{"pod":{"node_selector":{"pool":"gpu"}}}}`.
#[arg(long, value_name = "JSON")]
driver_config_json: Option<String>,

/// Provider names to attach to this sandbox.
#[arg(long = "provider")]
providers: Vec<String>,
Expand Down Expand Up @@ -2541,6 +2549,7 @@ async fn main() -> Result<()> {
gpu_device,
cpu,
memory,
driver_config_json,
providers,
policy,
forward,
Expand Down Expand Up @@ -2621,6 +2630,7 @@ async fn main() -> Result<()> {
gpu_device.as_deref(),
cpu.as_deref(),
memory.as_deref(),
driver_config_json.as_deref(),
editor,
&providers,
policy.as_deref(),
Expand Down Expand Up @@ -4335,6 +4345,32 @@ mod tests {
}
}

#[test]
fn sandbox_create_driver_config_json_flag_parses() {
let json = r#"{"kubernetes":{"pod":{"node_selector":{"pool":"gpu"}}}}"#;
let cli = Cli::try_parse_from([
"openshell",
"sandbox",
"create",
"--driver-config-json",
json,
])
.expect("sandbox create driver config JSON flag should parse");

match cli.command {
Some(Commands::Sandbox {
command:
Some(SandboxCommands::Create {
driver_config_json, ..
}),
..
}) => {
assert_eq!(driver_config_json.as_deref(), Some(json));
}
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
}
}

#[test]
fn service_expose_accepts_positional_target_port_and_service() {
let cli = Cli::try_parse_from([
Expand Down
113 changes: 107 additions & 6 deletions crates/openshell-cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1607,6 +1607,52 @@ fn build_sandbox_resource_limits(
Ok(Some(Struct { fields }))
}

fn parse_driver_config_json(value: &str) -> Result<prost_types::Struct> {
let parsed: serde_json::Value = serde_json::from_str(value)
.into_diagnostic()
.wrap_err("--driver-config-json must be valid JSON")?;

let serde_json::Value::Object(fields) = parsed else {
return Err(miette!(
"--driver-config-json must be a JSON object keyed by driver name"
));
};

Ok(prost_types::Struct {
fields: fields
.into_iter()
.map(|(key, value)| json_to_protobuf_value(value).map(|value| (key, value)))
.collect::<Result<_>>()?,
})
}

fn json_to_protobuf_value(value: serde_json::Value) -> Result<prost_types::Value> {
use prost_types::{ListValue, Struct, Value, value::Kind};

let kind = match value {
serde_json::Value::Null => Kind::NullValue(0),
serde_json::Value::Bool(value) => Kind::BoolValue(value),
serde_json::Value::Number(value) => Kind::NumberValue(value.as_f64().ok_or_else(|| {
miette!("--driver-config-json contains a number that cannot be represented")
})?),
serde_json::Value::String(value) => Kind::StringValue(value),
serde_json::Value::Array(values) => Kind::ListValue(ListValue {
values: values
.into_iter()
.map(json_to_protobuf_value)
.collect::<Result<_>>()?,
}),
serde_json::Value::Object(fields) => Kind::StructValue(Struct {
fields: fields
.into_iter()
.map(|(key, value)| json_to_protobuf_value(value).map(|value| (key, value)))
.collect::<Result<_>>()?,
}),
};

Ok(Value { kind: Some(kind) })
}

fn validate_cpu_quantity(value: &str) -> Result<String> {
let value = value.trim();
if value.is_empty() {
Expand Down Expand Up @@ -1701,6 +1747,7 @@ pub async fn sandbox_create(
gpu_device: Option<&str>,
cpu: Option<&str>,
memory: Option<&str>,
driver_config_json: Option<&str>,
editor: Option<Editor>,
providers: &[String],
policy: Option<&str>,
Expand Down Expand Up @@ -1770,11 +1817,15 @@ pub async fn sandbox_create(

let policy = load_sandbox_policy(policy)?;
let resource_limits = build_sandbox_resource_limits(cpu, memory)?;
let driver_config = driver_config_json
.map(parse_driver_config_json)
.transpose()?;

let template = if image.is_some() || resource_limits.is_some() {
let template = if image.is_some() || resource_limits.is_some() || driver_config.is_some() {
Some(SandboxTemplate {
image: image.unwrap_or_default(),
resources: resource_limits,
driver_config,
..SandboxTemplate::default()
})
} else {
Expand Down Expand Up @@ -7471,11 +7522,11 @@ mod tests {
git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle,
inferred_provider_type, mtls_certs_exist_for_gateway, package_managed_tls_dirs,
parse_cli_setting_value, parse_credential_expiry_cli_value, parse_credential_expiry_pairs,
parse_credential_pairs, plaintext_gateway_is_remote, progress_step_from_metadata,
provider_profile_allows_refresh_bootstrap, provisioning_timeout_message,
ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from,
sandbox_should_persist, sandbox_upload_plan, service_expose_status_error,
service_url_for_gateway,
parse_credential_pairs, parse_driver_config_json, plaintext_gateway_is_remote,
progress_step_from_metadata, provider_profile_allows_refresh_bootstrap,
provisioning_timeout_message, ready_false_condition_message, refresh_status_header,
refresh_status_row, resolve_from, sandbox_should_persist, sandbox_upload_plan,
service_expose_status_error, service_url_for_gateway,
};
use crate::TEST_ENV_LOCK;
use hyper::StatusCode;
Expand Down Expand Up @@ -7873,6 +7924,56 @@ mod tests {
assert!(build_sandbox_resource_limits(None, Some("1.5Gi")).is_err());
}

#[test]
fn parse_driver_config_json_accepts_driver_keyed_object() {
let config =
parse_driver_config_json(r#"{"kubernetes":{"pod":{"node_selector":{"pool":"gpu"}}}}"#)
.expect("driver config should parse");

let kubernetes = config
.fields
.get("kubernetes")
.and_then(|value| value.kind.as_ref())
.and_then(|kind| match kind {
prost_types::value::Kind::StructValue(inner) => Some(inner),
_ => None,
})
.expect("kubernetes block should be a struct");
let pod = kubernetes
.fields
.get("pod")
.and_then(|value| value.kind.as_ref())
.and_then(|kind| match kind {
prost_types::value::Kind::StructValue(inner) => Some(inner),
_ => None,
})
.expect("pod block should be a struct");

assert!(pod.fields.contains_key("node_selector"));
}

#[test]
fn parse_driver_config_json_rejects_non_object() {
let err = parse_driver_config_json(r#"["kubernetes"]"#)
.expect_err("top-level array should be rejected");

assert!(
err.to_string().contains("keyed by driver name"),
"unexpected error: {err}"
);
}

#[test]
fn parse_driver_config_json_rejects_invalid_json() {
let err = parse_driver_config_json(r#"{"kubernetes":"#)
.expect_err("invalid JSON should be rejected");

assert!(
err.to_string().contains("must be valid JSON"),
"unexpected error: {err}"
);
}

#[test]
fn inferred_provider_type_returns_type_for_known_command() {
let result = inferred_provider_type(&["claude".to_string(), "--help".to_string()]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() {
None,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -829,6 +830,7 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() {
Some("500m"),
Some("2Gi"),
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -884,6 +886,79 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() {
assert!(!resources.fields.contains_key("requests"));
}

#[tokio::test]
async fn sandbox_create_sends_driver_config_json() {
let server = run_server().await;
let fake_ssh_dir = tempfile::tempdir().unwrap();
let xdg_dir = tempfile::tempdir().unwrap();
let _env = test_env(&fake_ssh_dir, &xdg_dir);
let tls = test_tls(&server);
install_fake_ssh(&fake_ssh_dir);

run::sandbox_create(
&server.endpoint,
Some("driver-config"),
None,
"openshell",
&[],
true,
false,
None,
None,
None,
Some(r#"{"kubernetes":{"pod":{"priority_class_name":"batch-low"}}}"#),
None,
&[],
None,
None,
&["echo".to_string(), "OK".to_string()],
Some(false),
Some(false),
&HashMap::new(),
"manual",
&tls,
)
.await
.expect("sandbox create should succeed");

let requests = create_requests(&server).await;
let driver_config = requests[0]
.spec
.as_ref()
.and_then(|spec| spec.template.as_ref())
.and_then(|template| template.driver_config.as_ref())
.expect("driver config should be sent");
let kubernetes = driver_config
.fields
.get("kubernetes")
.and_then(|value| value.kind.as_ref())
.and_then(|kind| match kind {
prost_types::value::Kind::StructValue(inner) => Some(inner),
_ => None,
})
.expect("kubernetes block should be a struct");
let pod = kubernetes
.fields
.get("pod")
.and_then(|value| value.kind.as_ref())
.and_then(|kind| match kind {
prost_types::value::Kind::StructValue(inner) => Some(inner),
_ => None,
})
.expect("pod block should be a struct");

assert_eq!(
pod.fields
.get("priority_class_name")
.and_then(|value| value.kind.as_ref())
.and_then(|kind| match kind {
prost_types::value::Kind::StringValue(value) => Some(value.as_str()),
_ => None,
}),
Some("batch-low")
);
}

#[tokio::test]
async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() {
let server = run_server().await;
Expand All @@ -906,6 +981,7 @@ async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() {
None,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -963,6 +1039,7 @@ async fn sandbox_create_returns_vm_error_without_waiting_for_timeout() {
None,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -1016,6 +1093,7 @@ async fn sandbox_create_keeps_waiting_while_vm_progress_arrives() {
None,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -1061,6 +1139,7 @@ async fn sandbox_create_times_out_when_only_logs_arrive() {
None,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -1102,6 +1181,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() {
None,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -1147,6 +1227,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() {
None,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -1192,6 +1273,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() {
None,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -1237,6 +1319,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() {
None,
None,
None,
None,
&[],
None,
Some(openshell_core::forward::ForwardSpec::new(forward_port)),
Expand Down
Loading
Loading