diff --git a/crates/ai/src/acp/bridge_init.rs b/crates/ai/src/acp/bridge_init.rs index 921ce4f43..00f4c1dda 100644 --- a/crates/ai/src/acp/bridge_init.rs +++ b/crates/ai/src/acp/bridge_init.rs @@ -326,7 +326,7 @@ async fn bootstrap_session( load_session(connection.clone(), cwd.clone(), existing_session_id.clone()).await; if let Ok(Err(err)) = &load_result - && matches!(err.code, acp::ErrorCode::AuthRequired) + && is_auth_required(err) { if let Err(e) = authenticate(connection.clone()).await { ctx.io_handle.abort(); @@ -342,16 +342,13 @@ async fn bootstrap_session( Ok(Ok(load_response)) => { log::info!("ACP session loaded: {}", existing_session_id); client.set_session_id(existing_session_id.clone()).await; - return Ok(SessionBootstrap { - session_id: Some(acp::SessionId::new(existing_session_id)), - initial_modes: load_response.modes.map(map_mode_state), - initial_config_options: load_response.config_options.map(&ctx.map_config_options), - }); + return Ok(bootstrap_from_loaded_session( + existing_session_id, + load_response, + &ctx.map_config_options, + )); } - Ok(Err(err)) - if matches!(err.code, acp::ErrorCode::MethodNotFound) - && ctx.supports_session_resume => - { + Ok(Err(err)) if should_try_resume_after_load(&err, ctx.supports_session_resume) => { log::warn!( "ACP session/load unavailable ({}), trying session/resume", err @@ -360,7 +357,7 @@ async fn bootstrap_session( resume_session(connection.clone(), cwd.clone(), existing_session_id.clone()).await; if let Ok(Err(err)) = &resume_result - && matches!(err.code, acp::ErrorCode::AuthRequired) + && is_auth_required(err) { if let Err(e) = authenticate(connection.clone()).await { ctx.io_handle.abort(); @@ -377,20 +374,13 @@ async fn bootstrap_session( Ok(Ok(resume_response)) => { log::info!("ACP session resumed: {}", existing_session_id); client.set_session_id(existing_session_id.clone()).await; - return Ok(SessionBootstrap { - session_id: Some(acp::SessionId::new(existing_session_id)), - initial_modes: resume_response.modes.map(map_mode_state), - initial_config_options: resume_response - .config_options - .map(&ctx.map_config_options), - }); + return Ok(bootstrap_from_resumed_session( + existing_session_id, + resume_response, + &ctx.map_config_options, + )); } - Ok(Err(err)) - if matches!( - err.code, - acp::ErrorCode::MethodNotFound | acp::ErrorCode::ResourceNotFound - ) => - { + Ok(Err(err)) if should_fallback_to_new_after_resume(&err) => { log::warn!( "ACP session/resume unavailable or session missing ({}), falling back to \ session/new", @@ -415,12 +405,7 @@ async fn bootstrap_session( } } } - Ok(Err(err)) - if matches!( - err.code, - acp::ErrorCode::MethodNotFound | acp::ErrorCode::ResourceNotFound - ) => - { + Ok(Err(err)) if should_fallback_to_new_after_load(&err) => { log::warn!( "ACP session/load unavailable or session missing ({}), falling back to session/new", err @@ -447,7 +432,7 @@ async fn bootstrap_session( let mut session_result = create_session(connection.clone(), cwd.clone()).await; if let Ok(Err(err)) = &session_result - && matches!(err.code, acp::ErrorCode::AuthRequired) + && is_auth_required(err) { if let Err(e) = authenticate(connection.clone()).await { ctx.io_handle.abort(); @@ -480,11 +465,64 @@ async fn bootstrap_session( log::info!("ACP session created: {}", session.session_id); client.set_session_id(session.session_id.to_string()).await; - Ok(SessionBootstrap { + Ok(bootstrap_from_new_session(session, &ctx.map_config_options)) +} + +fn is_auth_required(err: &acp::Error) -> bool { + matches!(err.code, acp::ErrorCode::AuthRequired) +} + +fn should_try_resume_after_load(err: &acp::Error, supports_session_resume: bool) -> bool { + supports_session_resume && matches!(err.code, acp::ErrorCode::MethodNotFound) +} + +fn should_fallback_to_new_after_load(err: &acp::Error) -> bool { + matches!( + err.code, + acp::ErrorCode::MethodNotFound | acp::ErrorCode::ResourceNotFound + ) +} + +fn should_fallback_to_new_after_resume(err: &acp::Error) -> bool { + matches!( + err.code, + acp::ErrorCode::MethodNotFound | acp::ErrorCode::ResourceNotFound + ) +} + +fn bootstrap_from_loaded_session( + existing_session_id: String, + load_response: acp::LoadSessionResponse, + map_config_options: &impl Fn(Vec) -> Vec, +) -> SessionBootstrap { + SessionBootstrap { + session_id: Some(acp::SessionId::new(existing_session_id)), + initial_modes: load_response.modes.map(map_mode_state), + initial_config_options: load_response.config_options.map(map_config_options), + } +} + +fn bootstrap_from_resumed_session( + existing_session_id: String, + resume_response: acp::ResumeSessionResponse, + map_config_options: &impl Fn(Vec) -> Vec, +) -> SessionBootstrap { + SessionBootstrap { + session_id: Some(acp::SessionId::new(existing_session_id)), + initial_modes: resume_response.modes.map(map_mode_state), + initial_config_options: resume_response.config_options.map(map_config_options), + } +} + +fn bootstrap_from_new_session( + session: acp::NewSessionResponse, + map_config_options: &impl Fn(Vec) -> Vec, +) -> SessionBootstrap { + SessionBootstrap { session_id: Some(session.session_id), initial_modes: session.modes.map(map_mode_state), - initial_config_options: session.config_options.map(ctx.map_config_options), - }) + initial_config_options: session.config_options.map(map_config_options), + } } async fn create_session( @@ -570,3 +608,89 @@ fn emit_initial_session_state( log::warn!("Failed to emit initial session config options: {}", e); } } + +#[cfg(test)] +mod tests { + use super::*; + + fn no_config_options(_: Vec) -> Vec { + Vec::new() + } + + #[test] + fn loaded_session_bootstrap_preserves_requested_session_id() { + let bootstrap = bootstrap_from_loaded_session( + "existing-session".to_string(), + acp::LoadSessionResponse::new(), + &no_config_options, + ); + + assert_eq!( + bootstrap.session_id.map(|id| id.to_string()), + Some("existing-session".to_string()) + ); + assert!(bootstrap.initial_modes.is_none()); + assert!(bootstrap.initial_config_options.is_none()); + } + + #[test] + fn method_not_found_load_uses_resume_only_when_supported() { + let err = acp::Error::method_not_found(); + + assert!(should_try_resume_after_load(&err, true)); + assert!(!should_try_resume_after_load(&err, false)); + } + + #[test] + fn missing_or_unsupported_load_falls_back_to_new_session() { + assert!(should_fallback_to_new_after_load( + &acp::Error::method_not_found() + )); + assert!(should_fallback_to_new_after_load( + &acp::Error::resource_not_found(None) + )); + assert!(!should_fallback_to_new_after_load( + &acp::Error::invalid_params() + )); + } + + #[test] + fn missing_or_unsupported_resume_falls_back_to_new_session() { + assert!(should_fallback_to_new_after_resume( + &acp::Error::method_not_found() + )); + assert!(should_fallback_to_new_after_resume( + &acp::Error::resource_not_found(None) + )); + assert!(!should_fallback_to_new_after_resume( + &acp::Error::internal_error() + )); + } + + #[test] + fn auth_required_errors_are_retriable_before_session_fallbacks() { + assert!(is_auth_required(&acp::Error::auth_required())); + assert!(!is_auth_required(&acp::Error::method_not_found())); + assert!(!should_fallback_to_new_after_load( + &acp::Error::auth_required() + )); + assert!(!should_fallback_to_new_after_resume( + &acp::Error::auth_required() + )); + } + + #[test] + fn new_session_bootstrap_uses_agent_created_session_id() { + let bootstrap = bootstrap_from_new_session( + acp::NewSessionResponse::new("created-session"), + &no_config_options, + ); + + assert_eq!( + bootstrap.session_id.map(|id| id.to_string()), + Some("created-session".to_string()) + ); + assert!(bootstrap.initial_modes.is_none()); + assert!(bootstrap.initial_config_options.is_none()); + } +} diff --git a/crates/ai/src/acp/client.rs b/crates/ai/src/acp/client.rs index 4b3ac6739..6695ea712 100644 --- a/crates/ai/src/acp/client.rs +++ b/crates/ai/src/acp/client.rs @@ -61,6 +61,51 @@ impl AthasAcpClient { self.permission_tx.clone() } + fn map_permission_option_kind( + kind: acp::PermissionOptionKind, + ) -> super::types::AcpPermissionOptionKind { + match kind { + acp::PermissionOptionKind::AllowOnce => super::types::AcpPermissionOptionKind::AllowOnce, + acp::PermissionOptionKind::AllowAlways => { + super::types::AcpPermissionOptionKind::AllowAlways + } + acp::PermissionOptionKind::RejectOnce => super::types::AcpPermissionOptionKind::RejectOnce, + acp::PermissionOptionKind::RejectAlways => { + super::types::AcpPermissionOptionKind::RejectAlways + } + _ => super::types::AcpPermissionOptionKind::RejectOnce, + } + } + + fn permission_response_for_choice( + options: &[acp::PermissionOption], + approved: bool, + ) -> acp::RequestPermissionResponse { + let selected_option = options + .iter() + .find(|opt| { + if approved { + matches!( + opt.kind, + acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways + ) + } else { + matches!( + opt.kind, + acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways + ) + } + }) + .or_else(|| options.first()) + .map(|opt| acp::SelectedPermissionOutcome::new(opt.option_id.clone())); + + if let Some(selected) = selected_option { + acp::RequestPermissionResponse::new(acp::RequestPermissionOutcome::Selected(selected)) + } else { + acp::RequestPermissionResponse::new(acp::RequestPermissionOutcome::Cancelled) + } + } + pub async fn set_session_id(&self, session_id: String) { let mut current = self.current_session_id.lock().await; *current = Some(session_id); @@ -209,23 +254,7 @@ impl AthasAcpClient { fn fallback_permission_response( args: &acp::RequestPermissionRequest, ) -> acp::RequestPermissionResponse { - let selected_option = args - .options - .iter() - .find(|opt| { - matches!( - opt.kind, - acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways - ) - }) - .or_else(|| args.options.first()) - .map(|opt| acp::SelectedPermissionOutcome::new(opt.option_id.clone())); - - if let Some(selected) = selected_option { - acp::RequestPermissionResponse::new(acp::RequestPermissionOutcome::Selected(selected)) - } else { - acp::RequestPermissionResponse::new(acp::RequestPermissionOutcome::Cancelled) - } + Self::permission_response_for_choice(&args.options, false) } fn map_plan_priority(priority: acp::PlanEntryPriority) -> AcpPlanEntryPriority { @@ -348,6 +377,14 @@ impl AthasAcpClient { .collect() } + fn map_usage_update(session_id: String, update: acp::UsageUpdate) -> AcpEvent { + AcpEvent::UsageUpdate { + session_id, + used: update.used, + size: update.size, + } + } + pub(crate) fn map_session_config_option( option: acp::SessionConfigOption, ) -> Option { @@ -418,21 +455,7 @@ impl acp::Client for AthasAcpClient { .map(|option| super::types::AcpPermissionOption { id: option.option_id.to_string(), name: option.name.clone(), - kind: match option.kind { - acp::PermissionOptionKind::AllowOnce => { - super::types::AcpPermissionOptionKind::AllowOnce - } - acp::PermissionOptionKind::AllowAlways => { - super::types::AcpPermissionOptionKind::AllowAlways - } - acp::PermissionOptionKind::RejectOnce => { - super::types::AcpPermissionOptionKind::RejectOnce - } - acp::PermissionOptionKind::RejectAlways => { - super::types::AcpPermissionOptionKind::RejectAlways - } - _ => super::types::AcpPermissionOptionKind::RejectOnce, - }, + kind: Self::map_permission_option_kind(option.kind), }) .collect(), }); @@ -486,53 +509,9 @@ impl acp::Client for AthasAcpClient { return Ok(Self::fallback_permission_response(&args)); } - // Prefer allow-once/allow-always options if available - let selected_option = args - .options - .iter() - .find(|opt| { - matches!( - opt.kind, - acp::PermissionOptionKind::AllowOnce - | acp::PermissionOptionKind::AllowAlways - ) - }) - .or_else(|| args.options.first()) - .map(|opt| acp::SelectedPermissionOutcome::new(opt.option_id.clone())); - - if let Some(selected) = selected_option { - Ok(acp::RequestPermissionResponse::new( - acp::RequestPermissionOutcome::Selected(selected), - )) - } else { - Ok(acp::RequestPermissionResponse::new( - acp::RequestPermissionOutcome::Cancelled, - )) - } + Ok(Self::permission_response_for_choice(&args.options, true)) } else { - // Prefer reject-once/reject-always options if available - let selected_option = args - .options - .iter() - .find(|opt| { - matches!( - opt.kind, - acp::PermissionOptionKind::RejectOnce - | acp::PermissionOptionKind::RejectAlways - ) - }) - .or_else(|| args.options.first()) - .map(|opt| acp::SelectedPermissionOutcome::new(opt.option_id.clone())); - - if let Some(selected) = selected_option { - Ok(acp::RequestPermissionResponse::new( - acp::RequestPermissionOutcome::Selected(selected), - )) - } else { - Ok(acp::RequestPermissionResponse::new( - acp::RequestPermissionOutcome::Cancelled, - )) - } + Ok(Self::permission_response_for_choice(&args.options, false)) } } _ => Ok(acp::RequestPermissionResponse::new( @@ -695,6 +674,9 @@ impl acp::Client for AthasAcpClient { updated_at: update.updated_at.take(), }); } + acp::SessionUpdate::UsageUpdate(update) => { + self.emit_event(Self::map_usage_update(session_id, update)); + } acp::SessionUpdate::AvailableCommandsUpdate(commands_update) => { self.emit_event(AcpEvent::SlashCommandsUpdate { session_id, @@ -1161,3 +1143,101 @@ impl acp::Client for AthasAcpClient { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + fn permission_option( + id: &'static str, + kind: acp::PermissionOptionKind, + ) -> acp::PermissionOption { + acp::PermissionOption::new(id, id, kind) + } + + fn selected_option_id(response: acp::RequestPermissionResponse) -> Option { + match response.outcome { + acp::RequestPermissionOutcome::Selected(selected) => Some(selected.option_id.to_string()), + acp::RequestPermissionOutcome::Cancelled => None, + _ => None, + } + } + + #[test] + fn usage_update_maps_to_frontend_event() { + let event = AthasAcpClient::map_usage_update( + "session-1".to_string(), + acp::UsageUpdate::new(1234, 200000), + ); + + match event { + AcpEvent::UsageUpdate { + session_id, + used, + size, + } => { + assert_eq!(session_id, "session-1"); + assert_eq!(used, 1234); + assert_eq!(size, 200000); + } + other => panic!("expected usage update event, got {other:?}"), + } + } + + #[test] + fn permission_option_kinds_map_to_frontend_kinds() { + assert!(matches!( + AthasAcpClient::map_permission_option_kind(acp::PermissionOptionKind::AllowOnce), + super::super::types::AcpPermissionOptionKind::AllowOnce + )); + assert!(matches!( + AthasAcpClient::map_permission_option_kind(acp::PermissionOptionKind::AllowAlways), + super::super::types::AcpPermissionOptionKind::AllowAlways + )); + assert!(matches!( + AthasAcpClient::map_permission_option_kind(acp::PermissionOptionKind::RejectOnce), + super::super::types::AcpPermissionOptionKind::RejectOnce + )); + assert!(matches!( + AthasAcpClient::map_permission_option_kind(acp::PermissionOptionKind::RejectAlways), + super::super::types::AcpPermissionOptionKind::RejectAlways + )); + } + + #[test] + fn approved_permission_prefers_allow_options() { + let options = vec![ + permission_option("reject-once", acp::PermissionOptionKind::RejectOnce), + permission_option("allow-always", acp::PermissionOptionKind::AllowAlways), + ]; + + let response = AthasAcpClient::permission_response_for_choice(&options, true); + + assert_eq!( + selected_option_id(response).as_deref(), + Some("allow-always") + ); + } + + #[test] + fn rejected_permission_prefers_reject_options() { + let options = vec![ + permission_option("allow-once", acp::PermissionOptionKind::AllowOnce), + permission_option("reject-always", acp::PermissionOptionKind::RejectAlways), + ]; + + let response = AthasAcpClient::permission_response_for_choice(&options, false); + + assert_eq!( + selected_option_id(response).as_deref(), + Some("reject-always") + ); + } + + #[test] + fn permission_choice_cancels_when_options_are_empty() { + let response = AthasAcpClient::permission_response_for_choice(&[], true); + + assert_eq!(selected_option_id(response), None); + } +} diff --git a/crates/ai/src/acp/types.rs b/crates/ai/src/acp/types.rs index c6e9af22a..8754c99fe 100644 --- a/crates/ai/src/acp/types.rs +++ b/crates/ai/src/acp/types.rs @@ -486,6 +486,13 @@ pub enum AcpEvent { title: Option, updated_at: Option, }, + /// Session token/context usage updated + #[serde(rename_all = "camelCase")] + UsageUpdate { + session_id: String, + used: u64, + size: u64, + }, /// Prompt turn completed with a stop reason #[serde(rename_all = "camelCase")] PromptComplete { diff --git a/src-tauri/src/commands/ai/acp.rs b/src-tauri/src/commands/ai/acp.rs index 7f0021d9b..2c1d10332 100644 --- a/src-tauri/src/commands/ai/acp.rs +++ b/src-tauri/src/commands/ai/acp.rs @@ -15,6 +15,8 @@ use tokio::sync::Mutex; pub type AcpBridgeState = Arc>; const EXTENSIONS_CDN_BASE_URL: &str = "https://athas.dev/extensions"; +const ACP_REGISTRY_URL: &str = + "https://cdn.agentclientprotocol.com/registry/v1/latest/registry.json"; const AGENT_CATALOG_CACHE_SECONDS: u64 = 300; #[derive(Deserialize)] @@ -30,15 +32,17 @@ pub struct PermissionResponseArgs { #[tauri::command] pub async fn get_available_agents( + app_handle: AppHandle, bridge: State<'_, AcpBridgeState>, ) -> Result, String> { let mut bridge = bridge.lock().await; - refresh_registered_agents(&mut bridge).await; + refresh_registered_agents(&app_handle, &mut bridge).await; Ok(bridge.detect_agents()) } #[tauri::command] pub async fn start_acp_agent( + app_handle: AppHandle, bridge: State<'_, AcpBridgeState>, agent_id: String, workspace_path: Option, @@ -46,7 +50,7 @@ pub async fn start_acp_agent( ) -> Result { let bridge = { let mut bridge = bridge.lock().await; - refresh_registered_agents(&mut bridge).await; + refresh_registered_agents(&app_handle, &mut bridge).await; bridge.detect_agents(); bridge.clone() }; @@ -64,7 +68,7 @@ pub async fn install_acp_agent( ) -> Result { let agent = { let mut bridge = bridge.lock().await; - refresh_registered_agents(&mut bridge).await; + refresh_registered_agents(&app_handle, &mut bridge).await; let agents = bridge.detect_agents(); agents .into_iter() @@ -97,7 +101,7 @@ pub async fn uninstall_acp_agent( ) -> Result { let agent = { let mut bridge = bridge.lock().await; - refresh_registered_agents(&mut bridge).await; + refresh_registered_agents(&app_handle, &mut bridge).await; bridge.invalidate_agent_detection_cache(); let agents = bridge.detect_agents(); agents @@ -163,12 +167,57 @@ struct MarketplaceExtensionManifest { agents: Vec, } +#[derive(Deserialize)] +struct AcpRegistryIndex { + #[serde(default)] + agents: Vec, +} + +#[derive(Deserialize)] +struct AcpRegistryAgent { + id: String, + name: String, + description: String, + icon: Option, + distribution: AcpRegistryDistribution, +} + +#[derive(Deserialize)] +struct AcpRegistryDistribution { + binary: Option>, + npx: Option, + uvx: Option, +} + +#[derive(Deserialize)] +struct AcpRegistryBinaryTarget { + archive: String, + cmd: String, + #[serde(default)] + args: Vec, + #[serde(default)] + env: HashMap, +} + +#[derive(Deserialize)] +struct AcpRegistryPackageTarget { + package: String, + #[serde(default)] + args: Vec, + #[serde(default)] + env: HashMap, +} + fn extensions_manifest_url() -> String { let base_url = std::env::var("ATHAS_EXTENSIONS_CDN_URL") .unwrap_or_else(|_| EXTENSIONS_CDN_BASE_URL.to_string()); format!("{}/manifests.json", base_url.trim_end_matches('/')) } +fn acp_registry_url() -> String { + std::env::var("ATHAS_ACP_REGISTRY_URL").unwrap_or_else(|_| ACP_REGISTRY_URL.to_string()) +} + fn current_platform_arch() -> Option<&'static str> { match (std::env::consts::OS, std::env::consts::ARCH) { ("macos", "aarch64") => Some("darwin-arm64"), @@ -181,6 +230,33 @@ fn current_platform_arch() -> Option<&'static str> { } } +fn current_acp_registry_platform() -> Option<&'static str> { + match (std::env::consts::OS, std::env::consts::ARCH) { + ("macos", "aarch64") => Some("darwin-aarch64"), + ("macos", "x86_64") => Some("darwin-x86_64"), + ("linux", "aarch64") => Some("linux-aarch64"), + ("linux", "x86_64") => Some("linux-x86_64"), + ("windows", "aarch64") => Some("windows-aarch64"), + ("windows", "x86_64") => Some("windows-x86_64"), + _ => None, + } +} + +fn registry_command_name(cmd: &str, fallback: &str) -> String { + Path::new(cmd) + .file_name() + .and_then(|name| name.to_str()) + .map(|name| { + if cfg!(windows) { + name.strip_suffix(".exe").unwrap_or(name).to_string() + } else { + name.to_string() + } + }) + .filter(|name| !name.is_empty()) + .unwrap_or_else(|| fallback.to_string()) +} + fn to_agent_config(contribution: MarketplaceAgentContribution) -> AgentConfig { let mut agent = AgentConfig { id: contribution.id, @@ -217,7 +293,184 @@ fn to_agent_config(contribution: MarketplaceAgentContribution) -> AgentConfig { agent } -async fn load_marketplace_agents() -> Result, String> { +fn acp_registry_agent_to_config(agent: AcpRegistryAgent) -> Option { + let AcpRegistryAgent { + id, + name, + description, + icon, + distribution, + } = agent; + + if let Some(target) = current_acp_registry_platform() + .and_then(|platform| distribution.binary.as_ref()?.get(platform)) + { + let binary_name = registry_command_name(&target.cmd, &id); + return Some(AgentConfig { + id, + name, + binary_name, + binary_path: None, + args: target.args.clone(), + env_vars: target.env.clone(), + icon, + description: Some(description), + installed: false, + install_runtime: Some(AgentRuntime::Binary), + install_package: Some(target.cmd.clone()), + install_download_url: Some(target.archive.clone()), + install_command: Some(registry_command_name(&target.cmd, "")), + can_install: true, + }); + } + + if let Some(target) = distribution.npx { + let mut args = vec!["-y".to_string(), target.package.clone()]; + args.extend(target.args.clone()); + return Some(AgentConfig { + id, + name, + binary_name: "npx".to_string(), + binary_path: None, + args, + env_vars: target.env, + icon, + description: Some(description), + installed: false, + install_runtime: None, + install_package: None, + install_download_url: None, + install_command: None, + can_install: false, + }); + } + + if let Some(target) = distribution.uvx { + let mut args = vec![target.package]; + args.extend(target.args); + return Some(AgentConfig { + id, + name, + binary_name: "uvx".to_string(), + binary_path: None, + args, + env_vars: target.env, + icon, + description: Some(description), + installed: false, + install_runtime: None, + install_package: None, + install_download_url: None, + install_command: None, + can_install: false, + }); + } + + None +} + +fn acp_registry_agents_from_index(index: AcpRegistryIndex) -> Vec { + let mut agents = index + .agents + .into_iter() + .filter_map(acp_registry_agent_to_config) + .collect::>(); + agents.sort_by_key(|agent| agent.name.clone()); + agents +} + +fn acp_registry_cache_path(app_handle: &AppHandle) -> Result { + app_handle + .path() + .app_data_dir() + .map(|dir| dir.join("acp-registry").join("registry.json")) + .map_err(|error| format!("Failed to resolve ACP registry cache path: {}", error)) +} + +fn acp_registry_agents_from_json(json: &str) -> Result, String> { + let registry = serde_json::from_str::(json) + .map_err(|error| format!("Invalid ACP registry: {}", error))?; + Ok(acp_registry_agents_from_index(registry)) +} + +fn load_cached_acp_registry_agents(app_handle: &AppHandle) -> Result, String> { + let cache_path = acp_registry_cache_path(app_handle)?; + let json = fs::read_to_string(&cache_path) + .map_err(|error| format!("Failed to read cached ACP registry: {}", error))?; + acp_registry_agents_from_json(&json) + .map_err(|error| format!("Invalid cached ACP registry: {}", error)) +} + +fn write_acp_registry_cache(app_handle: &AppHandle, json: &str) -> Result<(), String> { + let cache_path = acp_registry_cache_path(app_handle)?; + if let Some(parent) = cache_path.parent() { + fs::create_dir_all(parent) + .map_err(|error| format!("Failed to create ACP registry cache directory: {}", error))?; + } + fs::write(&cache_path, json) + .map_err(|error| format!("Failed to write ACP registry cache: {}", error)) +} + +async fn load_acp_registry_agents(app_handle: &AppHandle) -> Result, String> { + let response = reqwest::Client::new() + .get(acp_registry_url()) + .timeout(Duration::from_secs(5)) + .send() + .await + .map_err(|error| format!("Failed to load ACP registry: {}", error))?; + + if !response.status().is_success() { + return Err(format!( + "Failed to load ACP registry: HTTP {}", + response.status() + )); + } + + let json = response + .text() + .await + .map_err(|error| format!("Failed to read ACP registry response: {}", error))?; + + let agents = acp_registry_agents_from_json(&json)?; + if let Err(error) = write_acp_registry_cache(app_handle, &json) { + log::warn!("{}", error); + } + + Ok(agents) +} + +async fn load_preferred_registry_agents( + app_handle: &AppHandle, +) -> Result, String> { + match load_acp_registry_agents(app_handle).await { + Ok(agents) => Ok(agents), + Err(registry_error) => { + log::warn!("{}", registry_error); + load_cached_acp_registry_agents(app_handle).map_err(|cache_error| { + log::warn!("{}", cache_error); + registry_error + }) + } + } +} + +fn merge_agent_catalogs( + mut preferred_agents: Vec, + fallback_agents: Vec, +) -> Vec { + for agent in fallback_agents { + if !preferred_agents + .iter() + .any(|preferred| preferred.id == agent.id) + { + preferred_agents.push(agent); + } + } + preferred_agents.sort_by_key(|agent| agent.name.clone()); + preferred_agents +} + +async fn load_marketplace_agents(app_handle: &AppHandle) -> Result, String> { let cache = AGENT_CATALOG_CACHE.get_or_init(|| std::sync::Mutex::new(None)); { let cached = cache @@ -230,6 +483,40 @@ async fn load_marketplace_agents() -> Result, String> { } } + let registry_agents = load_preferred_registry_agents(app_handle).await; + let legacy_agents = load_legacy_marketplace_agents().await; + let agents = match (registry_agents, legacy_agents) { + (Ok(registry_agents), Ok(legacy_agents)) => { + merge_agent_catalogs(registry_agents, legacy_agents) + } + (Ok(registry_agents), Err(legacy_error)) => { + log::warn!("{}", legacy_error); + registry_agents + } + (Err(registry_error), Ok(legacy_agents)) => { + log::warn!("{}", registry_error); + legacy_agents + } + (Err(registry_error), Err(legacy_error)) => { + return Err(format!( + "{}; legacy agent catalog also failed: {}", + registry_error, legacy_error + )); + } + }; + + let mut cached = cache + .lock() + .map_err(|_| "Agent catalog cache poisoned".to_string())?; + *cached = Some(CachedAgentCatalog { + loaded_at: Instant::now(), + agents: agents.clone(), + }); + + Ok(agents) +} + +async fn load_legacy_marketplace_agents() -> Result, String> { let response = reqwest::Client::new() .get(extensions_manifest_url()) .timeout(Duration::from_secs(5)) @@ -256,19 +543,11 @@ async fn load_marketplace_agents() -> Result, String> { .collect::>(); agents.sort_by_key(|agent| agent.name.clone()); - let mut cached = cache - .lock() - .map_err(|_| "Agent catalog cache poisoned".to_string())?; - *cached = Some(CachedAgentCatalog { - loaded_at: Instant::now(), - agents: agents.clone(), - }); - Ok(agents) } -async fn refresh_registered_agents(bridge: &mut AcpAgentBridge) { - match load_marketplace_agents().await { +async fn refresh_registered_agents(app_handle: &AppHandle, bridge: &mut AcpAgentBridge) { + match load_marketplace_agents(app_handle).await { Ok(agents) => bridge.replace_registered_agents(agents), Err(error) => { log::warn!("{}", error); @@ -534,3 +813,224 @@ fn make_wrapper_executable(path: &PathBuf) -> Result<(), String> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + fn parse_registry(json: &str) -> Vec { + acp_registry_agents_from_json(json).expect("registry fixture") + } + + fn test_agent(id: &str, name: &str) -> AgentConfig { + AgentConfig { + id: id.to_string(), + name: name.to_string(), + binary_name: id.to_string(), + binary_path: None, + args: vec![], + env_vars: HashMap::new(), + icon: None, + description: None, + installed: false, + install_runtime: None, + install_package: None, + install_download_url: None, + install_command: None, + can_install: false, + } + } + + #[test] + fn acp_registry_maps_binary_agent_for_current_platform() { + let agents = parse_registry( + r#"{ + "agents": [ + { + "id": "codex-acp", + "name": "Codex CLI", + "version": "0.14.0", + "description": "ACP adapter for OpenAI's coding assistant", + "distribution": { + "binary": { + "darwin-aarch64": { + "archive": "https://example.com/codex-aarch64.tar.gz", + "cmd": "./codex-acp", + "args": ["--acp"], + "env": { "CODEX_HOME": "/tmp/codex" } + }, + "darwin-x86_64": { + "archive": "https://example.com/codex-x64.tar.gz", + "cmd": "./codex-acp" + }, + "linux-x86_64": { + "archive": "https://example.com/codex-linux.tar.gz", + "cmd": "./codex-acp" + }, + "windows-x86_64": { + "archive": "https://example.com/codex.zip", + "cmd": "./codex-acp.exe" + } + } + }, + "icon": "https://example.com/codex.svg" + } + ] + }"#, + ); + + let agent = agents + .into_iter() + .find(|agent| agent.id == "codex-acp") + .expect("codex agent"); + + assert_eq!(agent.name, "Codex CLI"); + assert_eq!(agent.install_runtime, Some(AgentRuntime::Binary)); + assert_eq!(agent.binary_name, "codex-acp"); + assert!(agent.can_install); + assert!(agent.install_download_url.is_some()); + assert_eq!(agent.icon.as_deref(), Some("https://example.com/codex.svg")); + } + + #[test] + fn acp_registry_maps_npx_agent_to_npx_launch() { + let agents = parse_registry( + r#"{ + "agents": [ + { + "id": "claude-acp", + "name": "Claude Agent", + "version": "0.33.1", + "description": "ACP wrapper for Claude", + "distribution": { + "npx": { + "package": "@agentclientprotocol/claude-agent-acp@0.33.1", + "args": ["--verbose"], + "env": { "ANTHROPIC_HOME": "/tmp/claude" } + } + } + } + ] + }"#, + ); + + let agent = agents + .into_iter() + .find(|agent| agent.id == "claude-acp") + .expect("claude agent"); + + assert_eq!(agent.binary_name, "npx"); + assert_eq!(agent.install_runtime, None); + assert_eq!( + agent.args, + vec![ + "-y".to_string(), + "@agentclientprotocol/claude-agent-acp@0.33.1".to_string(), + "--verbose".to_string() + ] + ); + assert_eq!(agent.install_package.as_deref(), None); + assert!(!agent.can_install); + } + + #[test] + fn acp_registry_includes_uvx_agents_without_managed_install_claim() { + let agents = parse_registry( + r#"{ + "agents": [ + { + "id": "fast-agent", + "name": "fast-agent", + "version": "0.7.0", + "description": "Code and build agents", + "distribution": { + "uvx": { + "package": "fast-agent-acp==0.7.0", + "args": ["-x"] + } + } + } + ] + }"#, + ); + + let agent = agents + .into_iter() + .find(|agent| agent.id == "fast-agent") + .expect("uvx agent"); + + assert_eq!(agent.binary_name, "uvx"); + assert_eq!( + agent.args, + vec!["fast-agent-acp==0.7.0".to_string(), "-x".to_string()] + ); + assert_eq!(agent.install_runtime, None); + assert!(!agent.can_install); + } + + #[test] + fn acp_registry_prefers_current_platform_binary_over_npx() { + let agents = parse_registry( + r#"{ + "agents": [ + { + "id": "kilo", + "name": "Kilo", + "version": "7.2.40", + "description": "Kilo ACP", + "distribution": { + "binary": { + "darwin-aarch64": { + "archive": "https://example.com/kilo-aarch64.tar.gz", + "cmd": "./kilo" + }, + "darwin-x86_64": { + "archive": "https://example.com/kilo-x64.tar.gz", + "cmd": "./kilo" + }, + "linux-x86_64": { + "archive": "https://example.com/kilo-linux.tar.gz", + "cmd": "./kilo" + }, + "windows-x86_64": { + "archive": "https://example.com/kilo.zip", + "cmd": "kilo.exe" + } + }, + "npx": { + "package": "kilo-code@7.2.40" + } + } + } + ] + }"#, + ); + + let agent = agents + .into_iter() + .find(|agent| agent.id == "kilo") + .expect("kilo agent"); + + assert_eq!(agent.install_runtime, Some(AgentRuntime::Binary)); + assert_eq!(agent.binary_name, "kilo"); + } + + #[test] + fn merge_agent_catalogs_preserves_legacy_agents() { + let agents = merge_agent_catalogs( + vec![test_agent("codex-acp", "Codex Registry")], + vec![ + test_agent("codex-acp", "Codex Legacy"), + test_agent("athas-legacy", "Athas Legacy"), + ], + ); + + assert_eq!(agents.len(), 2); + assert!(agents.iter().any(|agent| agent.id == "athas-legacy")); + let codex = agents + .iter() + .find(|agent| agent.id == "codex-acp") + .expect("codex entry"); + assert_eq!(codex.name, "Codex Registry"); + } +} diff --git a/src/features/ai/services/acp-stream-handler.ts b/src/features/ai/services/acp-stream-handler.ts index 177e830c3..48e7eced1 100644 --- a/src/features/ai/services/acp-stream-handler.ts +++ b/src/features/ai/services/acp-stream-handler.ts @@ -311,6 +311,9 @@ export class AcpStreamHandler { case "session_info_update": break; + case "usage_update": + break; + case "prompt_complete": this.handlePromptComplete(event); break; @@ -327,6 +330,7 @@ export class AcpStreamHandler { // The stop reason can be used to determine how to handle the completion if (event.stopReason === "cancelled") { // User cancelled the prompt + this.finalizeActiveToolsAsCancelled(); this.cleanup(); this.handlers.onComplete(); return; @@ -542,10 +546,19 @@ export class AcpStreamHandler { } } + private finalizeActiveToolsAsCancelled(): void { + if (!this.handlers.onToolComplete || this.activeTools.size === 0) return; + + for (const [toolId, toolName] of this.activeTools) { + this.handlers.onToolComplete(toolName, toolId, undefined, "Cancelled"); + } + } + private forceStop(): void { if (this.sessionComplete || this.cancelled) return; this.cancelled = true; this.pendingNewMessage = false; + this.finalizeActiveToolsAsCancelled(); this.cleanup(); this.handlers.onComplete(); } diff --git a/src/features/ai/tests/acp-cancellation.test.ts b/src/features/ai/tests/acp-cancellation.test.ts new file mode 100644 index 000000000..50d700174 --- /dev/null +++ b/src/features/ai/tests/acp-cancellation.test.ts @@ -0,0 +1,157 @@ +import { afterEach, describe, expect, it, vi } from "vite-plus/test"; +import { invoke } from "@tauri-apps/api/core"; +import { AcpStreamHandler } from "../services/acp-stream-handler"; +import type { AcpEvent } from "../types/acp"; + +vi.mock("@tauri-apps/api/core", () => ({ + invoke: vi.fn(), +})); + +vi.mock("@tauri-apps/api/event", () => ({ + listen: vi.fn(), +})); + +vi.mock("@/features/ai/store/store", () => ({ + useAIChatStore: { + getState: () => ({ + acpStatus: null, + getChatById: () => null, + getCurrentChat: () => null, + setAcpStatus: vi.fn(), + setChatAcpSessionId: vi.fn(), + setAvailableSlashCommands: vi.fn(), + setSessionConfigOptions: vi.fn(), + setSessionModeState: vi.fn(), + setCurrentModeId: vi.fn(), + }), + }, +})); + +vi.mock("@/features/editor/stores/buffer-store", () => ({ + useBufferStore: { + getState: () => ({ + actions: { + openWebViewerBuffer: vi.fn(), + openTerminalBuffer: vi.fn(), + }, + }), + }, +})); + +vi.mock("@/features/window/stores/project-store", () => ({ + useProjectStore: { + getState: () => ({ + rootFolderPath: "/repo", + }), + }, +})); + +vi.mock("@/features/ai/lib/acp-session-info", () => ({ + getChatTitleFromSessionInfo: (_currentTitle: string | undefined, nextTitle: string) => nextTitle, +})); + +vi.mock("../utils/ai-context-builder", () => ({ + buildContextPrompt: () => "", +})); + +const mockedInvoke = vi.mocked(invoke); + +type TestableAcpStreamHandler = { + handleAcpEvent: (event: unknown) => void; +}; + +type AcpStreamHandlerStatic = { + activeHandler: AcpStreamHandler | null; +}; + +const setActiveHandler = (handler: AcpStreamHandler | null) => { + (AcpStreamHandler as unknown as AcpStreamHandlerStatic).activeHandler = handler; +}; + +const handleAcpEvent = (handler: AcpStreamHandler, event: AcpEvent) => { + (handler as unknown as TestableAcpStreamHandler).handleAcpEvent(event); +}; + +describe("AcpStreamHandler cancellation", () => { + afterEach(() => { + mockedInvoke.mockReset(); + setActiveHandler(null); + }); + + it("finalizes active tools before sending backend cancellation", async () => { + const onComplete = vi.fn(); + const onToolComplete = vi.fn(); + const handler = new AcpStreamHandler( + "codex", + { + onChunk: vi.fn(), + onComplete, + onError: vi.fn(), + onToolComplete, + }, + "chat-1", + ); + + handleAcpEvent(handler, { + type: "tool_start", + sessionId: "session-1", + toolName: "read_text_file", + toolId: "tool-1", + input: { path: "src/main.ts" }, + kind: "read", + status: "in_progress", + locations: [], + }); + setActiveHandler(handler); + + await AcpStreamHandler.cancelPrompt(); + + expect(onToolComplete).toHaveBeenCalledWith("read_text_file", "tool-1", undefined, "Cancelled"); + expect(onComplete).toHaveBeenCalledOnce(); + expect(mockedInvoke).toHaveBeenCalledWith("cancel_acp_prompt"); + }); + + it("ignores late events after a cancelled turn is force-stopped", async () => { + const onComplete = vi.fn(); + const onToolComplete = vi.fn(); + const onToolUse = vi.fn(); + const handler = new AcpStreamHandler( + "codex", + { + onChunk: vi.fn(), + onComplete, + onError: vi.fn(), + onToolUse, + onToolComplete, + }, + "chat-1", + ); + + handleAcpEvent(handler, { + type: "tool_start", + sessionId: "session-1", + toolName: "read_text_file", + toolId: "tool-1", + input: { path: "src/main.ts" }, + kind: "read", + status: "in_progress", + locations: [], + }); + setActiveHandler(handler); + + await AcpStreamHandler.cancelPrompt(); + handleAcpEvent(handler, { + type: "tool_start", + sessionId: "session-1", + toolName: "write_text_file", + toolId: "tool-2", + input: { path: "src/main.ts" }, + kind: "edit", + status: "in_progress", + locations: [], + }); + + expect(onToolUse).toHaveBeenCalledOnce(); + expect(onToolComplete).toHaveBeenCalledOnce(); + }); +}); diff --git a/src/features/ai/tests/acp-permission.test.ts b/src/features/ai/tests/acp-permission.test.ts new file mode 100644 index 000000000..84cdcf1af --- /dev/null +++ b/src/features/ai/tests/acp-permission.test.ts @@ -0,0 +1,139 @@ +import { afterEach, describe, expect, it, vi } from "vite-plus/test"; +import { invoke } from "@tauri-apps/api/core"; +import { AcpStreamHandler } from "../services/acp-stream-handler"; +import type { AcpEvent } from "../types/acp"; + +vi.mock("@tauri-apps/api/core", () => ({ + invoke: vi.fn(), +})); + +vi.mock("@tauri-apps/api/event", () => ({ + listen: vi.fn(), +})); + +vi.mock("@/features/ai/store/store", () => ({ + useAIChatStore: { + getState: () => ({ + acpStatus: null, + getChatById: () => null, + getCurrentChat: () => null, + setAcpStatus: vi.fn(), + setChatAcpSessionId: vi.fn(), + setAvailableSlashCommands: vi.fn(), + setSessionConfigOptions: vi.fn(), + setSessionModeState: vi.fn(), + setCurrentModeId: vi.fn(), + }), + }, +})); + +vi.mock("@/features/editor/stores/buffer-store", () => ({ + useBufferStore: { + getState: () => ({ + actions: { + openWebViewerBuffer: vi.fn(), + openTerminalBuffer: vi.fn(), + }, + }), + }, +})); + +vi.mock("@/features/window/stores/project-store", () => ({ + useProjectStore: { + getState: () => ({ + rootFolderPath: "/repo", + }), + }, +})); + +vi.mock("@/features/ai/lib/acp-session-info", () => ({ + getChatTitleFromSessionInfo: (_currentTitle: string | undefined, nextTitle: string) => nextTitle, +})); + +vi.mock("../utils/ai-context-builder", () => ({ + buildContextPrompt: () => "", +})); + +type TestableAcpStreamHandler = { + handleAcpEvent: (event: unknown) => void; +}; + +const mockedInvoke = vi.mocked(invoke); + +const handleAcpEvent = (handler: AcpStreamHandler, event: AcpEvent) => { + (handler as unknown as TestableAcpStreamHandler).handleAcpEvent(event); +}; + +const permissionEvent: AcpEvent = { + type: "permission_request", + requestId: "request-1", + permissionType: "tool_call", + resource: "tool-1", + description: "Run command (tool-1)", + options: [ + { + id: "allow-once", + name: "Allow once", + kind: "allow_once", + }, + { + id: "reject-once", + name: "Reject once", + kind: "reject_once", + }, + ], +}; + +describe("AcpStreamHandler permission requests", () => { + afterEach(() => { + mockedInvoke.mockReset(); + vi.clearAllMocks(); + }); + + it("routes permission requests to the permission handler", () => { + const onEvent = vi.fn(); + const onPermissionRequest = vi.fn(); + const handler = new AcpStreamHandler( + "codex", + { + onChunk: vi.fn(), + onComplete: vi.fn(), + onError: vi.fn(), + onEvent, + onPermissionRequest, + }, + "chat-1", + ); + + handleAcpEvent(handler, permissionEvent); + + expect(onEvent).toHaveBeenCalledWith(permissionEvent); + expect(onPermissionRequest).toHaveBeenCalledWith(permissionEvent); + expect(mockedInvoke).not.toHaveBeenCalled(); + }); + + it("auto-rejects permission requests when no permission handler is registered", async () => { + mockedInvoke.mockResolvedValue(undefined); + const handler = new AcpStreamHandler( + "codex", + { + onChunk: vi.fn(), + onComplete: vi.fn(), + onError: vi.fn(), + }, + "chat-1", + ); + + handleAcpEvent(handler, permissionEvent); + await vi.waitFor(() => { + expect(mockedInvoke).toHaveBeenCalledWith("respond_acp_permission", { + args: { + requestId: "request-1", + approved: false, + cancelled: false, + optionId: undefined, + }, + }); + }); + }); +}); diff --git a/src/features/ai/tests/acp-usage.test.ts b/src/features/ai/tests/acp-usage.test.ts new file mode 100644 index 000000000..9d8c6a5ab --- /dev/null +++ b/src/features/ai/tests/acp-usage.test.ts @@ -0,0 +1,104 @@ +import { afterEach, describe, expect, it, vi } from "vite-plus/test"; +import { AcpStreamHandler } from "../services/acp-stream-handler"; +import type { AcpEvent } from "../types/acp"; + +vi.mock("@tauri-apps/api/core", () => ({ + invoke: vi.fn(), +})); + +vi.mock("@tauri-apps/api/event", () => ({ + listen: vi.fn(), +})); + +vi.mock("@/features/ai/store/store", () => ({ + useAIChatStore: { + getState: () => ({ + acpStatus: null, + getChatById: () => null, + getCurrentChat: () => null, + setAcpStatus: vi.fn(), + setChatAcpSessionId: vi.fn(), + setAvailableSlashCommands: vi.fn(), + setSessionConfigOptions: vi.fn(), + setSessionModeState: vi.fn(), + setCurrentModeId: vi.fn(), + }), + }, +})); + +vi.mock("@/features/editor/stores/buffer-store", () => ({ + useBufferStore: { + getState: () => ({ + actions: { + openWebViewerBuffer: vi.fn(), + openTerminalBuffer: vi.fn(), + }, + }), + }, +})); + +vi.mock("@/features/window/stores/project-store", () => ({ + useProjectStore: { + getState: () => ({ + rootFolderPath: "/repo", + }), + }, +})); + +vi.mock("@/features/ai/lib/acp-session-info", () => ({ + getChatTitleFromSessionInfo: (_currentTitle: string | undefined, nextTitle: string) => nextTitle, +})); + +vi.mock("../utils/ai-context-builder", () => ({ + buildContextPrompt: () => "", +})); + +type TestableAcpStreamHandler = { + handleAcpEvent: (event: unknown) => void; +}; + +const handleAcpEvent = (handler: AcpStreamHandler, event: AcpEvent) => { + (handler as unknown as TestableAcpStreamHandler).handleAcpEvent(event); +}; + +describe("AcpStreamHandler usage updates", () => { + afterEach(() => { + vi.clearAllMocks(); + }); + + it("passes usage updates through the generic ACP event stream without mutating chat output", () => { + const onEvent = vi.fn(); + const onChunk = vi.fn(); + const onToolUse = vi.fn(); + const onToolComplete = vi.fn(); + const onComplete = vi.fn(); + const onError = vi.fn(); + const handler = new AcpStreamHandler( + "codex", + { + onChunk, + onComplete, + onError, + onEvent, + onToolUse, + onToolComplete, + }, + "chat-1", + ); + const event: AcpEvent = { + type: "usage_update", + sessionId: "session-1", + used: 1234, + size: 200000, + }; + + handleAcpEvent(handler, event); + + expect(onEvent).toHaveBeenCalledWith(event); + expect(onChunk).not.toHaveBeenCalled(); + expect(onToolUse).not.toHaveBeenCalled(); + expect(onToolComplete).not.toHaveBeenCalled(); + expect(onComplete).not.toHaveBeenCalled(); + expect(onError).not.toHaveBeenCalled(); + }); +}); diff --git a/src/features/ai/types/acp.ts b/src/features/ai/types/acp.ts index aea4140de..53774e765 100644 --- a/src/features/ai/types/acp.ts +++ b/src/features/ai/types/acp.ts @@ -268,6 +268,12 @@ export type AcpEvent = title: string | null; updatedAt: string | null; } + | { + type: "usage_update"; + sessionId: string; + used: number; + size: number; + } | { type: "prompt_complete"; sessionId: string;