diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index a1a0f75b..074b971b 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -177,15 +177,47 @@ pub async fn sync_policy(endpoint: &str, sandbox: &str, policy: &ProtoSandboxPol sync_policy_with_client(&mut client, sandbox, policy).await } +/// Provider environment fetched from the server, indexed by provider type. +pub struct ProviderEnvironment { + /// Env vars indexed by provider type (e.g. `"anthropic"` -> `{"ANTHROPIC_API_KEY": "sk-..."}`). + pub by_type: HashMap>, +} + +impl ProviderEnvironment { + /// Flatten all provider env vars into a single map for injection into the + /// child process. When two different provider types set the same env var, + /// one value wins arbitrarily (iteration order over `HashMap` keys is + /// nondeterministic). + pub fn flatten(self) -> HashMap { + let mut flat = HashMap::new(); + for (_provider_type, env) in self.by_type { + for (key, value) in env { + flat.entry(key).or_insert(value); + } + } + flat + } + + /// Returns the set of provider types present. + pub fn provider_types(&self) -> Vec { + self.by_type.keys().cloned().collect() + } + + /// Check if a specific provider type is present. + pub fn has_provider_type(&self, provider_type: &str) -> bool { + self.by_type.contains_key(provider_type) + } +} + /// Fetch provider environment variables for a sandbox from OpenShell server via gRPC. /// -/// Returns a map of environment variable names to values derived from provider -/// credentials configured on the sandbox. Returns an empty map if the sandbox -/// has no providers or the call fails. +/// Returns provider credentials indexed by provider type. Use +/// [`ProviderEnvironment::flatten`] to merge into a single env var map for +/// injection into the child process. pub async fn fetch_provider_environment( endpoint: &str, sandbox_id: &str, -) -> Result> { +) -> Result { debug!(endpoint = %endpoint, sandbox_id = %sandbox_id, "Fetching provider environment"); let mut client = connect(endpoint).await?; @@ -197,7 +229,13 @@ pub async fn fetch_provider_environment( .await .into_diagnostic()?; - Ok(response.into_inner().environment) + let inner = response.into_inner(); + let by_type = inner + .providers + .into_iter() + .map(|(provider_type, entry)| (provider_type, entry.environment)) + .collect(); + Ok(ProviderEnvironment { by_type }) } /// A reusable gRPC client for the OpenShell service. diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 8246555b..a03840cb 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -186,22 +186,32 @@ pub async fn run_sandbox( // Fetch provider environment variables from the server. // This is done after loading the policy so the sandbox can still start // even if provider env fetch fails (graceful degradation). - let provider_env = if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { + let provider_result = if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { match grpc_client::fetch_provider_environment(endpoint, id).await { - Ok(env) => { - info!(env_count = env.len(), "Fetched provider environment"); - env + Ok(result) => { + info!( + provider_types = ?result.provider_types(), + "Fetched provider environment" + ); + result } Err(e) => { warn!(error = %e, "Failed to fetch provider environment, continuing without"); - std::collections::HashMap::new() + grpc_client::ProviderEnvironment { + by_type: std::collections::HashMap::new(), + } } } } else { - std::collections::HashMap::new() + grpc_client::ProviderEnvironment { + by_type: std::collections::HashMap::new(), + } }; - let (provider_env, secret_resolver) = SecretResolver::from_provider_env(provider_env); + let has_anthropic = provider_result.has_provider_type("anthropic") + || provider_result.has_provider_type("claude"); + let (provider_env, secret_resolver) = + SecretResolver::from_provider_env(provider_result.flatten()); let secret_resolver = secret_resolver.map(Arc::new); // Create identity cache for SHA256 TOFU when OPA is active @@ -502,6 +512,12 @@ pub async fn run_sandbox( } } + // Write provider-specific config files (e.g., Claude Code onboarding bypass + // for the anthropic provider). Non-fatal: sandbox still starts on failure. + if let Err(e) = write_provider_configs(has_anthropic, &policy) { + warn!(error = %e, "Failed to write provider config files, continuing without"); + } + #[cfg(target_os = "linux")] let mut handle = ProcessHandle::spawn( program, @@ -1179,6 +1195,72 @@ fn prepare_filesystem(_policy: &SandboxPolicy) -> Result<()> { Ok(()) } +/// Write provider-specific configuration files to the sandbox user's home directory. +/// +/// Currently handles the `anthropic` provider type by writing a `.claude.json` +/// file that marks onboarding as complete, allowing Claude Code to start +/// without interactive setup when `ANTHROPIC_API_KEY` is present in the +/// environment. +#[cfg(unix)] +fn write_provider_configs(has_anthropic: bool, policy: &SandboxPolicy) -> Result<()> { + use nix::unistd::{User, chown}; + + if !has_anthropic { + return Ok(()); + } + + // Resolve sandbox user and home directory (same logic as session_user_and_home in ssh.rs). + let user_name = policy.process.run_as_user.as_deref().unwrap_or("sandbox"); + let (uid, gid, home) = { + let user = User::from_name(user_name) + .into_diagnostic()? + .ok_or_else(|| miette::miette!("sandbox user '{user_name}' not found"))?; + let gid = user.gid; + let home = user.dir.to_string_lossy().into_owned(); + (user.uid, gid, home) + }; + + let claude_json_path = std::path::Path::new(&home).join(".claude.json"); + + // Merge into existing .claude.json if present, so we don't clobber + // user-supplied or BYOC-baked configuration. + let mut config: serde_json::Value = if claude_json_path.exists() { + let existing = std::fs::read_to_string(&claude_json_path).into_diagnostic()?; + serde_json::from_str(&existing).unwrap_or_else(|_| serde_json::json!({})) + } else { + serde_json::json!({}) + }; + + if let Some(obj) = config.as_object_mut() { + obj.entry("hasCompletedOnboarding") + .or_insert(serde_json::Value::Bool(true)); + } + + if let Some(parent) = claude_json_path.parent() { + std::fs::create_dir_all(parent).into_diagnostic()?; + } + + std::fs::write( + &claude_json_path, + serde_json::to_string_pretty(&config).unwrap(), + ) + .into_diagnostic()?; + + chown(&claude_json_path, Some(uid), Some(gid)).into_diagnostic()?; + + info!( + path = %claude_json_path.display(), + "Wrote Claude Code config for anthropic provider" + ); + + Ok(()) +} + +#[cfg(not(unix))] +fn write_provider_configs(_has_anthropic: bool, _policy: &SandboxPolicy) -> Result<()> { + Ok(()) +} + /// Background loop that polls the server for policy updates. /// /// When a new version is detected, attempts to reload the OPA engine via diff --git a/crates/openshell-server/src/grpc.rs b/crates/openshell-server/src/grpc.rs index 422b6463..375e890a 100644 --- a/crates/openshell-server/src/grpc.rs +++ b/crates/openshell-server/src/grpc.rs @@ -22,8 +22,8 @@ use openshell_core::proto::{ GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, ListSandboxesRequest, - ListSandboxesResponse, PolicyChunk, PolicyStatus, Provider, ProviderResponse, - PushSandboxLogsRequest, PushSandboxLogsResponse, RejectDraftChunkRequest, + ListSandboxesResponse, PolicyChunk, PolicyStatus, Provider, ProviderEnvironmentEntry, + ProviderResponse, PushSandboxLogsRequest, PushSandboxLogsResponse, RejectDraftChunkRequest, RejectDraftChunkResponse, ReportPolicyStatusRequest, ReportPolicyStatusResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxLogLine, SandboxPolicyRevision, SandboxResponse, SandboxStreamEvent, ServiceStatus, SshSession, SubmitPolicyAnalysisRequest, @@ -810,18 +810,26 @@ impl OpenShell for OpenShellService { .spec .ok_or_else(|| Status::internal("sandbox has no spec"))?; - let environment = + let by_type = resolve_provider_environment(self.state.store.as_ref(), &spec.providers).await?; + let env_count: usize = by_type.values().map(|e| e.len()).sum(); info!( sandbox_id = %sandbox_id, provider_count = spec.providers.len(), - env_count = environment.len(), + env_count = env_count, "GetSandboxProviderEnvironment request completed successfully" ); + let providers = by_type + .into_iter() + .map(|(provider_type, environment)| { + (provider_type, ProviderEnvironmentEntry { environment }) + }) + .collect(); + Ok(Response::new(GetSandboxProviderEnvironmentResponse { - environment, + providers, })) } @@ -2741,21 +2749,25 @@ fn build_remote_exec_command(req: &ExecSandboxRequest) -> String { } } -/// Resolve provider credentials into environment variables. +/// Resolved provider environment: credential env vars indexed by provider type. +type ProviderEnvByType = + std::collections::HashMap>; + +/// Resolve provider credentials into environment variables, indexed by provider type. /// /// For each provider name in the list, fetches the provider from the store and -/// collects credential key-value pairs. Returns a map of environment variables -/// to inject into the sandbox. When duplicate keys appear across providers, the -/// first provider's value wins. +/// collects credential key-value pairs grouped by provider type. When multiple +/// providers share the same type, their credentials merge under one entry +/// (first value wins on duplicate keys within the same type). async fn resolve_provider_environment( store: &crate::persistence::Store, provider_names: &[String], -) -> Result, Status> { +) -> Result { if provider_names.is_empty() { - return Ok(std::collections::HashMap::new()); + return Ok(ProviderEnvByType::new()); } - let mut env = std::collections::HashMap::new(); + let mut by_type = ProviderEnvByType::new(); for name in provider_names { let provider = store @@ -2764,9 +2776,17 @@ async fn resolve_provider_environment( .map_err(|e| Status::internal(format!("failed to fetch provider '{name}': {e}")))? .ok_or_else(|| Status::failed_precondition(format!("provider '{name}' not found")))?; + let provider_type = if provider.r#type.is_empty() { + "unknown".to_string() + } else { + provider.r#type.clone() + }; + + let type_env = by_type.entry(provider_type).or_default(); + for (key, value) in &provider.credentials { if is_valid_env_key(key) { - env.entry(key.clone()).or_insert_with(|| value.clone()); + type_env.entry(key.clone()).or_insert_with(|| value.clone()); } else { warn!( provider_name = %name, @@ -2777,7 +2797,7 @@ async fn resolve_provider_environment( } } - Ok(env) + Ok(by_type) } fn is_valid_env_key(key: &str) -> bool { @@ -3691,10 +3711,17 @@ mod tests { let result = resolve_provider_environment(&store, &["claude-local".to_string()]) .await .unwrap(); - assert_eq!(result.get("ANTHROPIC_API_KEY"), Some(&"sk-abc".to_string())); - assert_eq!(result.get("CLAUDE_API_KEY"), Some(&"sk-abc".to_string())); + let claude_env = result.get("claude").expect("claude type should be present"); + assert_eq!( + claude_env.get("ANTHROPIC_API_KEY"), + Some(&"sk-abc".to_string()) + ); + assert_eq!( + claude_env.get("CLAUDE_API_KEY"), + Some(&"sk-abc".to_string()) + ); // Config values should NOT be injected. - assert!(!result.contains_key("endpoint")); + assert!(!claude_env.contains_key("endpoint")); } #[tokio::test] @@ -3728,9 +3755,10 @@ mod tests { let result = resolve_provider_environment(&store, &["test-provider".to_string()]) .await .unwrap(); - assert_eq!(result.get("VALID_KEY"), Some(&"value".to_string())); - assert!(!result.contains_key("nested.api_key")); - assert!(!result.contains_key("bad-key")); + let test_env = result.get("test").expect("test type should be present"); + assert_eq!(test_env.get("VALID_KEY"), Some(&"value".to_string())); + assert!(!test_env.contains_key("nested.api_key")); + assert!(!test_env.contains_key("bad-key")); } #[tokio::test] @@ -3772,8 +3800,16 @@ mod tests { ) .await .unwrap(); - assert_eq!(result.get("ANTHROPIC_API_KEY"), Some(&"sk-abc".to_string())); - assert_eq!(result.get("GITLAB_TOKEN"), Some(&"glpat-xyz".to_string())); + let claude_env = result.get("claude").expect("claude type"); + assert_eq!( + claude_env.get("ANTHROPIC_API_KEY"), + Some(&"sk-abc".to_string()) + ); + let gitlab_env = result.get("gitlab").expect("gitlab type"); + assert_eq!( + gitlab_env.get("GITLAB_TOKEN"), + Some(&"glpat-xyz".to_string()) + ); } #[tokio::test] @@ -3784,7 +3820,7 @@ mod tests { Provider { id: String::new(), name: "provider-a".to_string(), - r#type: "claude".to_string(), + r#type: "shared-type".to_string(), credentials: std::iter::once(("SHARED_KEY".to_string(), "first-value".to_string())) .collect(), config: HashMap::new(), @@ -3797,7 +3833,7 @@ mod tests { Provider { id: String::new(), name: "provider-b".to_string(), - r#type: "gitlab".to_string(), + r#type: "shared-type".to_string(), credentials: std::iter::once(( "SHARED_KEY".to_string(), "second-value".to_string(), @@ -3815,7 +3851,57 @@ mod tests { ) .await .unwrap(); - assert_eq!(result.get("SHARED_KEY"), Some(&"first-value".to_string())); + let env = result.get("shared-type").expect("shared-type should exist"); + assert_eq!(env.get("SHARED_KEY"), Some(&"first-value".to_string())); + } + + #[tokio::test] + async fn resolve_provider_env_same_type_merges() { + let store = Store::connect("sqlite::memory:").await.unwrap(); + create_provider_record( + &store, + Provider { + id: String::new(), + name: "anthropic-1".to_string(), + r#type: "anthropic".to_string(), + credentials: std::iter::once(("ANTHROPIC_API_KEY".to_string(), "sk-1".to_string())) + .collect(), + config: HashMap::new(), + }, + ) + .await + .unwrap(); + create_provider_record( + &store, + Provider { + id: String::new(), + name: "anthropic-2".to_string(), + r#type: "anthropic".to_string(), + credentials: std::iter::once(("ANOTHER_KEY".to_string(), "val".to_string())) + .collect(), + config: HashMap::new(), + }, + ) + .await + .unwrap(); + + let result = resolve_provider_environment( + &store, + &["anthropic-1".to_string(), "anthropic-2".to_string()], + ) + .await + .unwrap(); + + assert_eq!( + result.len(), + 1, + "both providers should merge under one type" + ); + let env = result + .get("anthropic") + .expect("anthropic type should exist"); + assert_eq!(env.get("ANTHROPIC_API_KEY"), Some(&"sk-1".to_string())); + assert_eq!(env.get("ANOTHER_KEY"), Some(&"val".to_string())); } /// Simulates the handler flow: persist a sandbox with providers, then resolve @@ -3866,11 +3952,15 @@ mod tests { .unwrap() .unwrap(); let spec = loaded.spec.unwrap(); - let env = resolve_provider_environment(&store, &spec.providers) + let by_type = resolve_provider_environment(&store, &spec.providers) .await .unwrap(); - assert_eq!(env.get("ANTHROPIC_API_KEY"), Some(&"sk-test".to_string())); + let claude_env = by_type.get("claude").expect("claude type"); + assert_eq!( + claude_env.get("ANTHROPIC_API_KEY"), + Some(&"sk-test".to_string()) + ); } /// Handler flow returns empty map when sandbox has no providers. diff --git a/proto/openshell.proto b/proto/openshell.proto index ad93848d..5bee9b83 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -425,10 +425,18 @@ message GetSandboxProviderEnvironmentRequest { string sandbox_id = 1; } +// Environment variables for a single provider type. +message ProviderEnvironmentEntry { + // Credential environment variables for this provider type. + map environment = 1; +} + // Get sandbox provider environment response. message GetSandboxProviderEnvironmentResponse { - // Provider credential environment variables. - map environment = 1; + // Provider credential environment variables, indexed by provider type. + // When multiple providers share the same type, their env vars are merged + // under a single entry (first value wins on duplicate keys). + map providers = 1; } // ---------------------------------------------------------------------------