diff --git a/bigtop-manager-ai/pom.xml b/bigtop-manager-ai/pom.xml index 6eb68ca83..6c1fbbd73 100644 --- a/bigtop-manager-ai/pom.xml +++ b/bigtop-manager-ai/pom.xml @@ -67,6 +67,10 @@ org.apache.commons commons-lang3 + + org.springframework.ai + spring-ai-starter-mcp-client-webflux + diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java index 38951c8ec..7912a7e1b 100644 --- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java +++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java @@ -20,6 +20,7 @@ import org.apache.bigtop.manager.ai.assistant.config.GeneralAssistantConfig; import org.apache.bigtop.manager.ai.assistant.provider.ChatMemoryStoreProvider; +import org.apache.bigtop.manager.ai.config.McpAsyncClientManager; import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory; import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig; import org.apache.bigtop.manager.ai.core.enums.PlatformType; @@ -34,11 +35,15 @@ import org.springframework.stereotype.Component; +import lombok.extern.slf4j.Slf4j; + import jakarta.annotation.Resource; import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; @Component +@Slf4j public class GeneralAssistantFactory extends AbstractAIAssistantFactory { @Resource @@ -47,6 +52,9 @@ public class GeneralAssistantFactory extends AbstractAIAssistantFactory { @Resource private ChatMemoryStoreProvider chatMemoryStoreProvider; + @Resource + private McpAsyncClientManager mcpAsyncClientManager; + private void configureSystemPrompt(AIAssistant.Builder builder, SystemPrompt systemPrompt, String locale) { List systemPrompts = new ArrayList<>(); if (systemPrompt != null) { @@ -68,7 +76,15 @@ private AIAssistant.Builder initializeBuilder(PlatformType platformType) { } @Override - public AIAssistant createWithPrompt(AIAssistantConfig config, Object toolProvider, SystemPrompt systemPrompt) { + public AIAssistant createWithPrompt(AIAssistantConfig config, SystemPrompt systemPrompt) { + return createWithPrompt(config, systemPrompt, null); + } + + @Override + public AIAssistant createWithPrompt( + AIAssistantConfig config, + SystemPrompt systemPrompt, + Consumer toolExecutionListener) { GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config; PlatformType platformType = generalAssistantConfig.getPlatformType(); Object id = generalAssistantConfig.getId(); @@ -80,6 +96,15 @@ public AIAssistant createWithPrompt(AIAssistantConfig config, Object toolProvide builder.id(id) .memoryStore(chatMemoryStoreProvider.createPersistentChatMemoryStore(id)) .withConfig(generalAssistantConfig); + builder.withToolExecutionListener(toolExecutionListener); + + List mcpAsyncClients = mcpAsyncClientManager.getClients(); + if (!mcpAsyncClients.isEmpty()) { + log.info("MCP clients available for platform {} (chat), count={}", platformType, mcpAsyncClients.size()); + builder.withMcpClients(mcpAsyncClients); + } else { + log.info("MCP client unavailable for platform {} (chat)", platformType); + } configureSystemPrompt(builder, systemPrompt, generalAssistantConfig.getLanguage()); @@ -87,7 +112,7 @@ public AIAssistant createWithPrompt(AIAssistantConfig config, Object toolProvide } @Override - public AIAssistant createForTest(AIAssistantConfig config, Object toolProvider) { + public AIAssistant createForTest(AIAssistantConfig config) { GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config; PlatformType platformType = generalAssistantConfig.getPlatformType(); AIAssistant.Builder builder = initializeBuilder(platformType); @@ -96,6 +121,34 @@ public AIAssistant createForTest(AIAssistantConfig config, Object toolProvider) .memoryStore(chatMemoryStoreProvider.createInMemoryChatMemoryStore()) .withConfig(generalAssistantConfig); + List mcpAsyncClients = mcpAsyncClientManager.getClients(); + if (!mcpAsyncClients.isEmpty()) { + log.info("MCP clients available for platform {} (test), count={}", platformType, mcpAsyncClients.size()); + builder.withMcpClients(mcpAsyncClients); + } else { + log.info("MCP client unavailable for platform {} (test)", platformType); + } + return builder.build(); } + + @Override + public List getModels(AIAssistantConfig config) { + GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config; + PlatformType platformType = generalAssistantConfig.getPlatformType(); + try { + AIAssistant.Builder builder = initializeBuilder(platformType); + builder.withConfig(generalAssistantConfig); + List models = builder.getModels(); + if (models != null && !models.isEmpty()) { + log.info("Fetched {} dynamic models for platform {}.", models.size(), platformType); + return models; + } + } catch (Exception e) { + log.warn("Failed to fetch dynamic models from platform {}: {}", platformType, e.getMessage()); + } + + log.info("No dynamic models for platform {}, fallback to default models.", platformType); + return java.util.Collections.emptyList(); + } } diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManager.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManager.java new file mode 100644 index 000000000..90ca8661c --- /dev/null +++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManager.java @@ -0,0 +1,441 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.ai.config; + +import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpSseClientProperties; +import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStdioClientProperties; +import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStreamableHttpClientProperties; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; +import org.springframework.web.reactive.function.client.WebClient; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.ServerParameters; +import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import lombok.extern.slf4j.Slf4j; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@Component +@Slf4j +public class McpAsyncClientManager { + + @Value("${spring.ai.mcp.client.enabled:false}") + private boolean enabled; + + @Value("${spring.ai.mcp.client.request-timeout-seconds:120}") + private long requestTimeoutSeconds; + + @Value("${spring.ai.mcp.client.init-timeout-seconds:10}") + private long initTimeoutSeconds; + + @Value("${spring.ai.mcp.client.connections:}") + private String connections; + + private final ObjectMapper objectMapper = new ObjectMapper(); + private final McpJsonMapper mcpJsonMapper = new JacksonMcpJsonMapper(objectMapper); + + private final ObjectProvider sseClientPropertiesProvider; + private final ObjectProvider streamableHttpClientPropertiesProvider; + private final ObjectProvider stdioClientPropertiesProvider; + + private List clients = Collections.emptyList(); + private boolean initialized = false; + private boolean disabledLogged = false; + + public McpAsyncClientManager( + ObjectProvider sseClientPropertiesProvider, + ObjectProvider streamableHttpClientPropertiesProvider, + ObjectProvider stdioClientPropertiesProvider) { + this.sseClientPropertiesProvider = sseClientPropertiesProvider; + this.streamableHttpClientPropertiesProvider = streamableHttpClientPropertiesProvider; + this.stdioClientPropertiesProvider = stdioClientPropertiesProvider; + } + + public synchronized McpAsyncClient getClient() { + List allClients = getClients(); + if (allClients.isEmpty()) { + return null; + } + return allClients.get(0); + } + + public synchronized List getClients() { + if (!enabled) { + if (!disabledLogged) { + log.info("MCP Async Client is disabled by config (spring.ai.mcp.client.enabled=false)"); + disabledLogged = true; + } + return Collections.emptyList(); + } + + if (!initialized) { + initialized = true; + clients = initializeClients(); + log.info("MCP Async Clients initialized: {}", clients.size()); + } + + return clients; + } + + private List initializeClients() { + List parsedConnections = parseConnections(); + + if (parsedConnections.isEmpty()) { + parsedConnections.addAll(fromSpringMcpProperties()); + } + + if (parsedConnections.isEmpty()) { + log.warn( + "No MCP connections configured. Please use spring.ai.mcp.client.connections or spring.ai.mcp.client..connections."); + return Collections.emptyList(); + } + + List initializedClients = new ArrayList<>(); + for (McpConnection connection : parsedConnections) { + McpAsyncClient client = initializeClient(connection); + if (client != null) { + initializedClients.add(client); + } + } + return initializedClients; + } + + private List fromSpringMcpProperties() { + List conns = new ArrayList<>(); + + McpSseClientProperties sseProperties = sseClientPropertiesProvider.getIfAvailable(); + if (sseProperties != null && sseProperties.getConnections() != null) { + sseProperties.getConnections().forEach((name, params) -> { + if (params != null) { + conns.add(McpConnection.sse(name, params.url(), params.sseEndpoint())); + } + }); + } + + McpStreamableHttpClientProperties streamableHttpProperties = + streamableHttpClientPropertiesProvider.getIfAvailable(); + if (streamableHttpProperties != null && streamableHttpProperties.getConnections() != null) { + streamableHttpProperties.getConnections().forEach((name, params) -> { + if (params != null) { + conns.add(McpConnection.streamableHttp(name, params.url(), params.endpoint())); + } + }); + } + + McpStdioClientProperties stdioProperties = stdioClientPropertiesProvider.getIfAvailable(); + if (stdioProperties != null && stdioProperties.getConnections() != null) { + stdioProperties.getConnections().forEach((name, params) -> { + if (params != null) { + conns.add(McpConnection.local(name, params.command(), params.args(), params.env())); + } + }); + } + + return conns; + } + + private McpAsyncClient initializeClient(McpConnection connection) { + try { + McpClientTransport transport = buildTransport(connection); + if (transport == null) { + log.warn("Skip MCP connection {} due to unsupported type {}", connection.name, connection.type); + return null; + } + + Duration requestTimeout = Duration.ofSeconds(Math.max( + connection.requestTimeoutSeconds > 0 ? connection.requestTimeoutSeconds : requestTimeoutSeconds, + 1)); + Duration initTimeout = Duration.ofSeconds(Math.max( + connection.initTimeoutSeconds > 0 ? connection.initTimeoutSeconds : initTimeoutSeconds, 1)); + + McpAsyncClient client = + McpClient.async(transport).requestTimeout(requestTimeout).build(); + + client.initialize().block(initTimeout); + client.ping().block(initTimeout); + + log.info( + "MCP Async Client [{}] initialized, type={}, baseUrl={}, endpoint={}, requestTimeout={}, initTimeout={}", + connection.name, + connection.type, + connection.baseUrl, + connection.endpoint, + requestTimeout, + initTimeout); + logRegisteredTools(connection.name, client, initTimeout); + return client; + } catch (Exception e) { + log.warn("Failed to initialize MCP client [{}]: {}", connection.name, e.getMessage()); + return null; + } + } + + private McpClientTransport buildTransport(McpConnection connection) { + String type = connection.type.toLowerCase(); + return switch (type) { + case "sse" -> buildWebFluxSseTransport(connection); + case "streamable-http", "http", "http-streamable" -> buildWebFluxStreamableHttpTransport(connection); + case "local", "stdio" -> buildStdioTransport(connection); + default -> null; + }; + } + + private McpClientTransport buildWebFluxSseTransport(McpConnection connection) { + String baseUrl = connection.baseUrl; + if (!StringUtils.hasText(baseUrl)) { + log.warn("MCP SSE connection {} missing baseUrl", connection.name); + return null; + } + String endpoint = StringUtils.hasText(connection.endpoint) ? connection.endpoint : "/mcp/sse"; + WebClient.Builder webClientBuilder = WebClient.builder().baseUrl(baseUrl); + return WebFluxSseClientTransport.builder(webClientBuilder) + .jsonMapper(mcpJsonMapper) + .sseEndpoint(endpoint) + .build(); + } + + private McpClientTransport buildWebFluxStreamableHttpTransport(McpConnection connection) { + String baseUrl = connection.baseUrl; + if (!StringUtils.hasText(baseUrl)) { + log.warn("MCP streamable-http connection {} missing baseUrl", connection.name); + return null; + } + String endpoint = StringUtils.hasText(connection.endpoint) ? connection.endpoint : "/mcp"; + WebClient.Builder webClientBuilder = WebClient.builder().baseUrl(baseUrl); + return WebClientStreamableHttpTransport.builder(webClientBuilder) + .jsonMapper(mcpJsonMapper) + .endpoint(endpoint) + .openConnectionOnStartup(true) + .build(); + } + + private McpClientTransport buildStdioTransport(McpConnection connection) { + if (!StringUtils.hasText(connection.command)) { + log.warn("MCP stdio/local connection {} missing command", connection.name); + return null; + } + + ServerParameters.Builder serverBuilder = ServerParameters.builder(connection.command); + if (!connection.args.isEmpty()) { + serverBuilder.args(connection.args); + } + if (!connection.env.isEmpty()) { + serverBuilder.env(connection.env); + } + + StdioClientTransport stdioTransport = new StdioClientTransport(serverBuilder.build(), mcpJsonMapper); + stdioTransport.setStdErrorHandler(line -> log.warn("MCP stdio [{}] stderr: {}", connection.name, line)); + return stdioTransport; + } + + private void logRegisteredTools(String connectionName, McpAsyncClient client, Duration initTimeout) { + try { + List toolNames = new ArrayList<>(); + String cursor = null; + while (true) { + McpSchema.ListToolsResult listToolsResult = cursor == null || cursor.isBlank() + ? client.listTools().block(initTimeout) + : client.listTools(cursor).block(initTimeout); + if (listToolsResult == null + || listToolsResult.tools() == null + || listToolsResult.tools().isEmpty()) { + break; + } + + for (McpSchema.Tool tool : listToolsResult.tools()) { + if (tool != null && tool.name() != null) { + toolNames.add(tool.name()); + } + } + + String nextCursor = listToolsResult.nextCursor(); + if (!StringUtils.hasText(nextCursor) || nextCursor.equals(cursor)) { + break; + } + cursor = nextCursor; + } + + log.info("MCP tools discovered for [{}]: count={}, names={}", connectionName, toolNames.size(), toolNames); + } catch (Exception e) { + log.warn("Failed to list MCP tools for [{}]: {}", connectionName, e.getMessage()); + } + } + + private List parseConnections() { + if (!StringUtils.hasText(connections)) { + return new ArrayList<>(); + } + + List parsed = new ArrayList<>(); + String[] segments = connections.split(";"); + for (int i = 0; i < segments.length; i++) { + String segment = segments[i].trim(); + if (!StringUtils.hasText(segment)) { + continue; + } + + Map kv = new HashMap<>(); + String[] pairs = segment.split(","); + for (String pair : pairs) { + String[] items = pair.split("=", 2); + if (items.length != 2) { + continue; + } + kv.put(items[0].trim().toLowerCase(), items[1].trim()); + } + + String name = valueOrDefault(kv.get("name"), "conn-" + (i + 1)); + String type = valueOrDefault(kv.get("type"), "sse"); + String baseUrl = kv.get("baseurl"); + String endpoint = kv.get("endpoint"); + String command = kv.get("command"); + List args = parseList(kv.get("args")); + Map env = parseEnv(kv.get("env")); + long connectionRequestTimeout = parseLongOrDefault(kv.get("requesttimeoutseconds"), -1); + long connectionInitTimeout = parseLongOrDefault(kv.get("inittimeoutseconds"), -1); + + parsed.add(new McpConnection( + name, + type, + baseUrl, + endpoint, + command, + args, + env, + connectionRequestTimeout, + connectionInitTimeout)); + } + + return parsed; + } + + private static long parseLongOrDefault(String raw, long defaultValue) { + if (!StringUtils.hasText(raw)) { + return defaultValue; + } + try { + return Long.parseLong(raw); + } catch (NumberFormatException e) { + return defaultValue; + } + } + + private static List parseList(String raw) { + if (!StringUtils.hasText(raw)) { + return Collections.emptyList(); + } + List list = new ArrayList<>(); + for (String item : raw.split("\\|")) { + String trimmed = item.trim(); + if (StringUtils.hasText(trimmed)) { + list.add(trimmed); + } + } + return list; + } + + private static Map parseEnv(String raw) { + if (!StringUtils.hasText(raw)) { + return Collections.emptyMap(); + } + Map env = new HashMap<>(); + for (String item : raw.split("\\|")) { + String[] pair = item.split(":", 2); + if (pair.length == 2 && StringUtils.hasText(pair[0])) { + env.put(pair[0].trim(), pair[1].trim()); + } + } + return env; + } + + private static String valueOrDefault(String value, String defaultValue) { + return StringUtils.hasText(value) ? value : defaultValue; + } + + private static class McpConnection { + private final String name; + private final String type; + private final String baseUrl; + private final String endpoint; + private final String command; + private final List args; + private final Map env; + private final long requestTimeoutSeconds; + private final long initTimeoutSeconds; + + private McpConnection( + String name, + String type, + String baseUrl, + String endpoint, + String command, + List args, + Map env, + long requestTimeoutSeconds, + long initTimeoutSeconds) { + this.name = name; + this.type = type; + this.baseUrl = baseUrl; + this.endpoint = endpoint; + this.command = command; + this.args = args == null ? Collections.emptyList() : args; + this.env = env == null ? Collections.emptyMap() : env; + this.requestTimeoutSeconds = requestTimeoutSeconds; + this.initTimeoutSeconds = initTimeoutSeconds; + } + + private static McpConnection sse(String name, String baseUrl, String endpoint) { + return new McpConnection( + name, "sse", baseUrl, endpoint, null, Collections.emptyList(), Collections.emptyMap(), -1, -1); + } + + private static McpConnection streamableHttp(String name, String baseUrl, String endpoint) { + return new McpConnection( + name, + "streamable-http", + baseUrl, + endpoint, + null, + Collections.emptyList(), + Collections.emptyMap(), + -1, + -1); + } + + private static McpConnection local(String name, String command, List args, Map env) { + return new McpConnection(name, "local", null, null, command, args, env, -1, -1); + } + } +} diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java index 30ec7ab62..39c52b069 100644 --- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java +++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java @@ -24,9 +24,18 @@ import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository; import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.web.reactive.function.client.WebClient; +import com.fasterxml.jackson.databind.JsonNode; import reactor.core.publisher.Flux; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + public abstract class AbstractAIAssistant implements AIAssistant { protected final AIAssistant.Service aiServices; protected static final Integer MEMORY_LEN = 10; @@ -67,6 +76,12 @@ public abstract static class Builder implements AIAssistant.Builder { protected String systemPrompt; + protected io.modelcontextprotocol.client.McpAsyncClient mcpAsyncClient; + protected List mcpAsyncClients = new ArrayList<>(); + protected Consumer toolExecutionListener; + + private static final String AUTHORIZATION_HEADER = "Authorization"; + public Builder() {} public Builder withSystemPrompt(String systemPrompt) { @@ -74,6 +89,12 @@ public Builder withSystemPrompt(String systemPrompt) { return this; } + @Override + public Builder withToolExecutionListener(Consumer toolExecutionListener) { + this.toolExecutionListener = toolExecutionListener; + return this; + } + public Builder withConfig(AIAssistantConfig config) { this.config = config; return this; @@ -89,6 +110,44 @@ public Builder memoryStore(ChatMemory chatMemory) { return this; } + public Builder withMcpClient(io.modelcontextprotocol.client.McpAsyncClient mcpAsyncClient) { + this.mcpAsyncClient = mcpAsyncClient; + if (mcpAsyncClient != null) { + this.mcpAsyncClients = List.of(mcpAsyncClient); + } + return this; + } + + @Override + public Builder withMcpClients(List mcpAsyncClients) { + if (mcpAsyncClients == null || mcpAsyncClients.isEmpty()) { + this.mcpAsyncClients = Collections.emptyList(); + this.mcpAsyncClient = null; + return this; + } + + this.mcpAsyncClients = List.copyOf(mcpAsyncClients); + this.mcpAsyncClient = this.mcpAsyncClients.get(0); + return this; + } + + protected List getMcpAsyncClients() { + if (mcpAsyncClients != null && !mcpAsyncClients.isEmpty()) { + return mcpAsyncClients; + } + if (mcpAsyncClient != null) { + return List.of(mcpAsyncClient); + } + return Collections.emptyList(); + } + + protected void emitToolExecutionEvent(String executionId, String toolName, String status, String payload) { + if (toolExecutionListener != null) { + toolExecutionListener.accept( + new AIAssistant.ToolExecutionEvent(executionId, toolName, status, payload)); + } + } + public ChatMemory getChatMemory() { if (chatMemory == null) { chatMemory = MessageWindowChatMemory.builder() @@ -97,5 +156,81 @@ public ChatMemory getChatMemory() { } return chatMemory; } + + protected String resolveModelsBaseUrl() { + return null; + } + + protected String resolveModelsPath() { + return "/v1/models"; + } + + protected String resolveApiKey(Map credentials) { + if (credentials == null) { + return null; + } + String apiKey = credentials.get("apiKey"); + if (apiKey == null) { + return null; + } + apiKey = apiKey.trim(); + if (apiKey.startsWith("Bearer ")) { + apiKey = apiKey.substring("Bearer ".length()).trim(); + } + return apiKey; + } + + protected void applyModelRequestAuth(WebClient.RequestHeadersSpec requestSpec, String apiKey) { + if (apiKey != null && !apiKey.isBlank()) { + requestSpec.header(AUTHORIZATION_HEADER, "Bearer " + apiKey); + } + } + + protected List parseModelsResponse(JsonNode response) { + if (response == null || !response.has("data")) { + return Collections.emptyList(); + } + List models = new ArrayList<>(); + for (JsonNode node : response.get("data")) { + JsonNode idNode = node.get("id"); + if (idNode != null && !idNode.isNull()) { + models.add(idNode.asText()); + } + } + return models; + } + + @Override + public List getModels() { + String baseUrl = resolveModelsBaseUrl(); + if (baseUrl == null || baseUrl.isBlank()) { + return Collections.emptyList(); + } + + String path = resolveModelsPath(); + if (path == null || path.isBlank()) { + path = "/v1/models"; + } + + Map credentials = config == null ? Collections.emptyMap() : config.getCredentials(); + String apiKey = resolveApiKey(credentials); + + try { + WebClient webClient = + WebClient.builder().baseUrl(baseUrl.trim()).build(); + WebClient.RequestHeadersSpec requestSpec = webClient.get().uri(path); + applyModelRequestAuth(requestSpec, apiKey); + + JsonNode response = requestSpec + .retrieve() + .bodyToMono(JsonNode.class) + .timeout(Duration.ofSeconds(10)) + .block(); + + return parseModelsResponse(response); + } catch (Exception ignored) { + return Collections.emptyList(); + } + } } } diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java index ff2b96f5c..0898dcbb0 100644 --- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java +++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java @@ -27,6 +27,9 @@ import reactor.core.publisher.Flux; +import java.util.List; +import java.util.function.Consumer; + public interface AIAssistant { /** @@ -75,8 +78,14 @@ interface Builder { Builder withConfig(AIAssistantConfig configProvider); + Builder withMcpClient(io.modelcontextprotocol.client.McpAsyncClient mcpAsyncClient); + + Builder withMcpClients(List mcpAsyncClients); + Builder withSystemPrompt(String systemPrompt); + Builder withToolExecutionListener(Consumer toolExecutionListener); + AIAssistant build(); ChatModel getChatModel(); @@ -84,5 +93,9 @@ interface Builder { StreamingChatModel getStreamingChatModel(); ChatMemory getChatMemory(); + + List getModels(); } + + record ToolExecutionEvent(String executionId, String toolName, String status, String payload) {} } diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java index 06e1dafef..d97f40081 100644 --- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java +++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java @@ -21,13 +21,30 @@ import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig; import org.apache.bigtop.manager.ai.core.enums.SystemPrompt; +import java.util.List; +import java.util.function.Consumer; + public interface AIAssistantFactory { - AIAssistant createWithPrompt(AIAssistantConfig config, Object toolProvider, SystemPrompt systemPrompt); + AIAssistant createWithPrompt(AIAssistantConfig config, SystemPrompt systemPrompt); + + AIAssistant createForTest(AIAssistantConfig config); + + default AIAssistant createAIService(AIAssistantConfig config) { + return createWithPrompt(config, SystemPrompt.DEFAULT_PROMPT); + } - AIAssistant createForTest(AIAssistantConfig config, Object toolProvider); + default AIAssistant createAIService( + AIAssistantConfig config, Consumer toolExecutionListener) { + return createWithPrompt(config, SystemPrompt.DEFAULT_PROMPT, toolExecutionListener); + } - default AIAssistant createAIService(AIAssistantConfig config, Object toolProvider) { - return createWithPrompt(config, toolProvider, SystemPrompt.DEFAULT_PROMPT); + default AIAssistant createWithPrompt( + AIAssistantConfig config, + SystemPrompt systemPrompt, + Consumer toolExecutionListener) { + return createWithPrompt(config, systemPrompt); } + + List getModels(AIAssistantConfig config); } diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DashScopeAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DashScopeAssistant.java index 6661e6e02..b4053ea4c 100644 --- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DashScopeAssistant.java +++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DashScopeAssistant.java @@ -29,18 +29,23 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.util.Assert; import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; +import java.util.UUID; public class DashScopeAssistant extends AbstractAIAssistant { + private static final String BASE_URL_ENV_KEY = "BIGTOP_MANAGER_AI_DASHSCOPE_BASE_URL"; private static final String BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode"; public DashScopeAssistant(Object memoryId, ChatMemory chatMemory, AIAssistant.Service aiServices) { @@ -58,6 +63,19 @@ public static Builder builder() { public static class Builder extends AbstractAIAssistant.Builder { + @Override + protected String resolveModelsBaseUrl() { + return resolveDefaultBaseUrl(); + } + + private String resolveDefaultBaseUrl() { + String envBaseUrl = System.getenv(BASE_URL_ENV_KEY); + if (envBaseUrl != null && !envBaseUrl.isBlank()) { + return envBaseUrl; + } + return BASE_URL; + } + @Override public ChatModel getChatModel() { String model = config.getModel(); @@ -65,9 +83,17 @@ public ChatModel getChatModel() { String apiKey = config.getCredentials().get("apiKey"); Assert.notNull(apiKey, "apiKey must not be null"); - OpenAiApi openAiApi = - OpenAiApi.builder().baseUrl(BASE_URL).apiKey(apiKey).build(); - OpenAiChatOptions options = OpenAiChatOptions.builder().model(model).build(); + OpenAiApi openAiApi = OpenAiApi.builder() + .baseUrl(resolveDefaultBaseUrl()) + .apiKey(apiKey) + .build(); + OpenAiChatOptions.Builder optionsBuilder = + OpenAiChatOptions.builder().model(model); + List mcpClients = getMcpAsyncClients(); + if (!mcpClients.isEmpty()) { + optionsBuilder.toolCallbacks(buildObservedToolCallbacks(mcpClients)); + } + OpenAiChatOptions options = optionsBuilder.build(); return OpenAiChatModel.builder() .openAiApi(openAiApi) .defaultOptions(options) @@ -80,6 +106,45 @@ public StreamingChatModel getStreamingChatModel() { return getChatModel(); } + private ToolCallback[] buildObservedToolCallbacks( + List mcpClients) { + ToolCallback[] callbacks = new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks(); + ToolCallback[] observedCallbacks = new ToolCallback[callbacks.length]; + for (int i = 0; i < callbacks.length; i++) { + observedCallbacks[i] = wrapToolCallback(callbacks[i]); + } + return observedCallbacks; + } + + private ToolCallback wrapToolCallback(ToolCallback delegate) { + return new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return delegate.getToolDefinition(); + } + + @Override + public String call(String toolInput) { + return call(toolInput, null); + } + + @Override + public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) { + String toolName = getToolDefinition().name(); + String executionId = UUID.randomUUID().toString(); + emitToolExecutionEvent(executionId, toolName, "started", toolInput); + try { + String result = delegate.call(toolInput, toolContext); + emitToolExecutionEvent(executionId, toolName, "completed", result); + return result; + } catch (Exception e) { + emitToolExecutionEvent(executionId, toolName, "failed", e.getMessage()); + throw e; + } + } + }; + } + public AIAssistant build() { ChatModel chatModel = getChatModel(); StreamingChatModel streamingChatModel = getStreamingChatModel(); @@ -132,13 +197,18 @@ public Flux streamChat(String userMessage) { StringBuilder responseBuilder = new StringBuilder(); return streamingChatModel.stream(prompt) - .map(chatResponse -> { - String content = - chatResponse.getResult().getOutput().getText(); - if (content != null) { + .concatMap(chatResponse -> { + String content = null; + if (chatResponse.getResult() != null + && chatResponse.getResult().getOutput() != null) { + content = + chatResponse.getResult().getOutput().getText(); + } + if (content != null && !content.isEmpty()) { responseBuilder.append(content); + return Flux.just(content); } - return content; + return Flux.empty(); }) .doOnComplete(() -> { // Save to memory when streaming completes diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DeepSeekAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DeepSeekAssistant.java index 962ad8258..17f3286cd 100644 --- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DeepSeekAssistant.java +++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DeepSeekAssistant.java @@ -33,15 +33,22 @@ import org.springframework.ai.deepseek.DeepSeekChatModel; import org.springframework.ai.deepseek.DeepSeekChatOptions; import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.util.Assert; import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; +import java.util.UUID; public class DeepSeekAssistant extends AbstractAIAssistant { + private static final String BASE_URL_ENV_KEY = "BIGTOP_MANAGER_AI_DEEPSEEK_BASE_URL"; + private static final String BASE_URL = "https://api.deepseek.com"; + public DeepSeekAssistant(Object memoryId, ChatMemory chatMemory, AIAssistant.Service aiServices) { super(memoryId, chatMemory, aiServices); } @@ -57,6 +64,19 @@ public static Builder builder() { public static class Builder extends AbstractAIAssistant.Builder { + @Override + protected String resolveModelsBaseUrl() { + return resolveDefaultBaseUrl(); + } + + private String resolveDefaultBaseUrl() { + String envBaseUrl = System.getenv(BASE_URL_ENV_KEY); + if (envBaseUrl != null && !envBaseUrl.isBlank()) { + return envBaseUrl; + } + return BASE_URL; + } + @Override public ChatModel getChatModel() { String model = config.getModel(); @@ -64,9 +84,17 @@ public ChatModel getChatModel() { String apiKey = config.getCredentials().get("apiKey"); Assert.notNull(apiKey, "apiKey must not be null"); - DeepSeekApi deepSeekApi = DeepSeekApi.builder().apiKey(apiKey).build(); - DeepSeekChatOptions options = - DeepSeekChatOptions.builder().model(model).build(); + DeepSeekApi deepSeekApi = DeepSeekApi.builder() + .baseUrl(resolveDefaultBaseUrl()) + .apiKey(apiKey) + .build(); + DeepSeekChatOptions.Builder optionsBuilder = + DeepSeekChatOptions.builder().model(model); + List mcpClients = getMcpAsyncClients(); + if (!mcpClients.isEmpty()) { + optionsBuilder.toolCallbacks(buildObservedToolCallbacks(mcpClients)); + } + DeepSeekChatOptions options = optionsBuilder.build(); return DeepSeekChatModel.builder() .deepSeekApi(deepSeekApi) .defaultOptions(options) @@ -79,6 +107,45 @@ public StreamingChatModel getStreamingChatModel() { return getChatModel(); } + private ToolCallback[] buildObservedToolCallbacks( + List mcpClients) { + ToolCallback[] callbacks = new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks(); + ToolCallback[] observedCallbacks = new ToolCallback[callbacks.length]; + for (int i = 0; i < callbacks.length; i++) { + observedCallbacks[i] = wrapToolCallback(callbacks[i]); + } + return observedCallbacks; + } + + private ToolCallback wrapToolCallback(ToolCallback delegate) { + return new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return delegate.getToolDefinition(); + } + + @Override + public String call(String toolInput) { + return call(toolInput, null); + } + + @Override + public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) { + String toolName = getToolDefinition().name(); + String executionId = UUID.randomUUID().toString(); + emitToolExecutionEvent(executionId, toolName, "started", toolInput); + try { + String result = delegate.call(toolInput, toolContext); + emitToolExecutionEvent(executionId, toolName, "completed", result); + return result; + } catch (Exception e) { + emitToolExecutionEvent(executionId, toolName, "failed", e.getMessage()); + throw e; + } + } + }; + } + public AIAssistant build() { ChatModel chatModel = getChatModel(); StreamingChatModel streamingChatModel = getStreamingChatModel(); @@ -131,13 +198,18 @@ public Flux streamChat(String userMessage) { StringBuilder responseBuilder = new StringBuilder(); return streamingChatModel.stream(prompt) - .map(chatResponse -> { - String content = - chatResponse.getResult().getOutput().getText(); - if (content != null) { + .concatMap(chatResponse -> { + String content = null; + if (chatResponse.getResult() != null + && chatResponse.getResult().getOutput() != null) { + content = + chatResponse.getResult().getOutput().getText(); + } + if (content != null && !content.isEmpty()) { responseBuilder.append(content); + return Flux.just(content); } - return content; + return Flux.empty(); }) .doOnComplete(() -> { // Save to memory when streaming completes diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/OpenAIAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/OpenAIAssistant.java index b51fa75ef..c6fa5e2c2 100644 --- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/OpenAIAssistant.java +++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/OpenAIAssistant.java @@ -29,18 +29,24 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.util.Assert; import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.UUID; public class OpenAIAssistant extends AbstractAIAssistant { + private static final String BASE_URL_ENV_KEY = "BIGTOP_MANAGER_AI_OPENAI_BASE_URL"; private static final String BASE_URL = "https://api.openai.com"; public OpenAIAssistant(Object memoryId, ChatMemory chatMemory, AIAssistant.Service aiServices) { @@ -58,6 +64,26 @@ public static Builder builder() { public static class Builder extends AbstractAIAssistant.Builder { + @Override + protected String resolveModelsBaseUrl() { + Map credentials = config == null ? null : config.getCredentials(); + if (credentials != null) { + String baseUrl = credentials.get("baseUrl"); + if (baseUrl != null && !baseUrl.isBlank()) { + return baseUrl; + } + } + return resolveDefaultBaseUrl(); + } + + private String resolveDefaultBaseUrl() { + String envBaseUrl = System.getenv(BASE_URL_ENV_KEY); + if (envBaseUrl != null && !envBaseUrl.isBlank()) { + return envBaseUrl; + } + return BASE_URL; + } + @Override public ChatModel getChatModel() { String model = config.getModel(); @@ -65,9 +91,17 @@ public ChatModel getChatModel() { String apiKey = config.getCredentials().get("apiKey"); Assert.notNull(apiKey, "apiKey must not be null"); - OpenAiApi openAiApi = - OpenAiApi.builder().baseUrl(BASE_URL).apiKey(apiKey).build(); - OpenAiChatOptions options = OpenAiChatOptions.builder().model(model).build(); + OpenAiApi openAiApi = OpenAiApi.builder() + .baseUrl(resolveDefaultBaseUrl()) + .apiKey(apiKey) + .build(); + OpenAiChatOptions.Builder optionsBuilder = + OpenAiChatOptions.builder().model(model); + List mcpClients = getMcpAsyncClients(); + if (!mcpClients.isEmpty()) { + optionsBuilder.toolCallbacks(buildObservedToolCallbacks(mcpClients)); + } + OpenAiChatOptions options = optionsBuilder.build(); return OpenAiChatModel.builder() .openAiApi(openAiApi) .defaultOptions(options) @@ -80,6 +114,45 @@ public StreamingChatModel getStreamingChatModel() { return getChatModel(); } + private ToolCallback[] buildObservedToolCallbacks( + List mcpClients) { + ToolCallback[] callbacks = new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks(); + ToolCallback[] observedCallbacks = new ToolCallback[callbacks.length]; + for (int i = 0; i < callbacks.length; i++) { + observedCallbacks[i] = wrapToolCallback(callbacks[i]); + } + return observedCallbacks; + } + + private ToolCallback wrapToolCallback(ToolCallback delegate) { + return new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return delegate.getToolDefinition(); + } + + @Override + public String call(String toolInput) { + return call(toolInput, null); + } + + @Override + public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) { + String toolName = getToolDefinition().name(); + String executionId = UUID.randomUUID().toString(); + emitToolExecutionEvent(executionId, toolName, "started", toolInput); + try { + String result = delegate.call(toolInput, toolContext); + emitToolExecutionEvent(executionId, toolName, "completed", result); + return result; + } catch (Exception e) { + emitToolExecutionEvent(executionId, toolName, "failed", e.getMessage()); + throw e; + } + } + }; + } + public AIAssistant build() { ChatModel chatModel = getChatModel(); StreamingChatModel streamingChatModel = getStreamingChatModel(); @@ -132,13 +205,18 @@ public Flux streamChat(String userMessage) { StringBuilder responseBuilder = new StringBuilder(); return streamingChatModel.stream(prompt) - .map(chatResponse -> { - String content = - chatResponse.getResult().getOutput().getText(); - if (content != null) { + .concatMap(chatResponse -> { + String content = null; + if (chatResponse.getResult() != null + && chatResponse.getResult().getOutput() != null) { + content = + chatResponse.getResult().getOutput().getText(); + } + if (content != null && !content.isEmpty()) { responseBuilder.append(content); + return Flux.just(content); } - return content; + return Flux.empty(); }) .doOnComplete(() -> { // Save to memory when streaming completes diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/QianFanAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/QianFanAssistant.java index 468020eb1..65b1782a1 100644 --- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/QianFanAssistant.java +++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/QianFanAssistant.java @@ -29,18 +29,26 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.util.Assert; +import org.springframework.web.reactive.function.client.WebClient; +import com.fasterxml.jackson.databind.JsonNode; import reactor.core.publisher.Flux; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.UUID; public class QianFanAssistant extends AbstractAIAssistant { + private static final String BASE_URL_ENV_KEY = "BIGTOP_MANAGER_AI_QIANFAN_BASE_URL"; private static final String BASE_URL = "https://qianfan.baidubce.com"; public QianFanAssistant(Object memoryId, ChatMemory chatMemory, AIAssistant.Service aiServices) { @@ -58,6 +66,60 @@ public static Builder builder() { public static class Builder extends AbstractAIAssistant.Builder { + @Override + protected String resolveModelsBaseUrl() { + return resolveDefaultBaseUrl(); + } + + private String resolveDefaultBaseUrl() { + String envBaseUrl = System.getenv(BASE_URL_ENV_KEY); + if (envBaseUrl != null && !envBaseUrl.isBlank()) { + return envBaseUrl; + } + return BASE_URL; + } + + @Override + protected String resolveModelsPath() { + return "/v2/chat/models"; + } + + @Override + public List getModels() { + String apiKey = resolveApiKey(config == null ? null : config.getCredentials()); + if (apiKey == null || apiKey.isBlank()) { + return Collections.emptyList(); + } + + try { + WebClient webClient = + WebClient.builder().baseUrl(resolveModelsBaseUrl()).build(); + JsonNode response = webClient + .get() + .uri(resolveModelsPath()) + .header("Authorization", "Bearer " + apiKey) + .retrieve() + .bodyToMono(JsonNode.class) + .timeout(java.time.Duration.ofSeconds(10)) + .block(); + + if (response == null || !response.has("result")) { + return Collections.emptyList(); + } + + List models = new ArrayList<>(); + for (JsonNode modelNode : response.get("result")) { + JsonNode modelId = modelNode.get("model"); + if (modelId != null && !modelId.isNull()) { + models.add(modelId.asText()); + } + } + return models; + } catch (Exception ignored) { + return Collections.emptyList(); + } + } + @Override public ChatModel getChatModel() { String model = config.getModel(); @@ -66,11 +128,17 @@ public ChatModel getChatModel() { Assert.notNull(apiKey, "apiKey must not be null"); OpenAiApi openAiApi = OpenAiApi.builder() - .baseUrl(BASE_URL) + .baseUrl(resolveDefaultBaseUrl()) .completionsPath("/v2/chat/completions") .apiKey(apiKey) .build(); - OpenAiChatOptions options = OpenAiChatOptions.builder().model(model).build(); + OpenAiChatOptions.Builder optionsBuilder = + OpenAiChatOptions.builder().model(model); + List mcpClients = getMcpAsyncClients(); + if (!mcpClients.isEmpty()) { + optionsBuilder.toolCallbacks(buildObservedToolCallbacks(mcpClients)); + } + OpenAiChatOptions options = optionsBuilder.build(); return OpenAiChatModel.builder() .openAiApi(openAiApi) .defaultOptions(options) @@ -82,6 +150,45 @@ public StreamingChatModel getStreamingChatModel() { return getChatModel(); } + private ToolCallback[] buildObservedToolCallbacks( + List mcpClients) { + ToolCallback[] callbacks = new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks(); + ToolCallback[] observedCallbacks = new ToolCallback[callbacks.length]; + for (int i = 0; i < callbacks.length; i++) { + observedCallbacks[i] = wrapToolCallback(callbacks[i]); + } + return observedCallbacks; + } + + private ToolCallback wrapToolCallback(ToolCallback delegate) { + return new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return delegate.getToolDefinition(); + } + + @Override + public String call(String toolInput) { + return call(toolInput, null); + } + + @Override + public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) { + String toolName = getToolDefinition().name(); + String executionId = UUID.randomUUID().toString(); + emitToolExecutionEvent(executionId, toolName, "started", toolInput); + try { + String result = delegate.call(toolInput, toolContext); + emitToolExecutionEvent(executionId, toolName, "completed", result); + return result; + } catch (Exception e) { + emitToolExecutionEvent(executionId, toolName, "failed", e.getMessage()); + throw e; + } + } + }; + } + public AIAssistant build() { ChatModel chatModel = getChatModel(); StreamingChatModel streamingChatModel = getStreamingChatModel(); @@ -134,13 +241,18 @@ public Flux streamChat(String userMessage) { StringBuilder responseBuilder = new StringBuilder(); return streamingChatModel.stream(prompt) - .map(chatResponse -> { - String content = - chatResponse.getResult().getOutput().getText(); - if (content != null) { + .concatMap(chatResponse -> { + String content = null; + if (chatResponse.getResult() != null + && chatResponse.getResult().getOutput() != null) { + content = + chatResponse.getResult().getOutput().getText(); + } + if (content != null && !content.isEmpty()) { responseBuilder.append(content); + return Flux.just(content); } - return content; + return Flux.empty(); }) .doOnComplete(() -> { // Save to memory when streaming completes diff --git a/bigtop-manager-ai/src/test/java/assistant/GeneralAssistantFactoryTest.java b/bigtop-manager-ai/src/test/java/assistant/GeneralAssistantFactoryTest.java index e87d2c418..a780a65d5 100644 --- a/bigtop-manager-ai/src/test/java/assistant/GeneralAssistantFactoryTest.java +++ b/bigtop-manager-ai/src/test/java/assistant/GeneralAssistantFactoryTest.java @@ -64,8 +64,8 @@ void testCreateAIAssistant() { try (MockedStatic openAIAssistantMockedStatic = mockStatic(OpenAIAssistant.class)) { openAIAssistantMockedStatic.when(OpenAIAssistant::builder).thenReturn(mockBuilder); - generalAssistantFactory.createAIService(assistantConfigProvider, null); - generalAssistantFactory.createForTest(assistantConfigProvider, null); + generalAssistantFactory.createAIService(assistantConfigProvider); + generalAssistantFactory.createForTest(assistantConfigProvider); } } } diff --git a/bigtop-manager-ai/src/test/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManagerTest.java b/bigtop-manager-ai/src/test/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManagerTest.java new file mode 100644 index 000000000..544f89d9c --- /dev/null +++ b/bigtop-manager-ai/src/test/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManagerTest.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.ai.config; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpSseClientProperties; +import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStdioClientProperties; +import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStreamableHttpClientProperties; +import org.springframework.beans.factory.ObjectProvider; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class McpAsyncClientManagerTest { + + private static final ObjectProvider SSE_PROVIDER = new EmptyObjectProvider<>(); + private static final ObjectProvider STREAMABLE_HTTP_PROVIDER = + new EmptyObjectProvider<>(); + private static final ObjectProvider STDIO_PROVIDER = new EmptyObjectProvider<>(); + + @Test + void resolveBaseUrlShouldPreferConfiguredValue() throws Exception { + McpAsyncClientManager manager = + new McpAsyncClientManager(SSE_PROVIDER, STREAMABLE_HTTP_PROVIDER, STDIO_PROVIDER); + setField(manager, "connections", "name=sseMain,type=sse,baseUrl=http://127.0.0.1:9000,endpoint=/mcp/sse"); + + Method parseMethod = McpAsyncClientManager.class.getDeclaredMethod("parseConnections"); + parseMethod.setAccessible(true); + @SuppressWarnings("unchecked") + List conns = (List) parseMethod.invoke(manager); + assertEquals("http://127.0.0.1:9000", getStringField(conns.get(0), "baseUrl")); + } + + @Test + void resolveBaseUrlShouldFallbackToServerPort() throws Exception { + McpAsyncClientManager manager = + new McpAsyncClientManager(SSE_PROVIDER, STREAMABLE_HTTP_PROVIDER, STDIO_PROVIDER); + setField(manager, "connections", " "); + + Method parseMethod = McpAsyncClientManager.class.getDeclaredMethod("parseConnections"); + parseMethod.setAccessible(true); + @SuppressWarnings("unchecked") + List conns = (List) parseMethod.invoke(manager); + assertTrue(conns.isEmpty()); + } + + @Test + void parseConnectionsShouldSupportSseHttpAndLocal() throws Exception { + McpAsyncClientManager manager = + new McpAsyncClientManager(SSE_PROVIDER, STREAMABLE_HTTP_PROVIDER, STDIO_PROVIDER); + setField( + manager, + "connections", + "name=sseMain,type=sse,baseUrl=http://127.0.0.1:8080,endpoint=/mcp/sse;" + + "name=httpMain,type=streamable-http,baseUrl=http://127.0.0.1:8081,endpoint=/mcp;" + + "name=localMain,type=local,command=npx,args=-y|@modelcontextprotocol/server-memory,env=FOO:bar|BAZ:qux"); + + Method parseMethod = McpAsyncClientManager.class.getDeclaredMethod("parseConnections"); + parseMethod.setAccessible(true); + @SuppressWarnings("unchecked") + List conns = (List) parseMethod.invoke(manager); + + assertEquals(3, conns.size()); + assertEquals("sse", getStringField(conns.get(0), "type")); + assertEquals("streamable-http", getStringField(conns.get(1), "type")); + assertEquals("local", getStringField(conns.get(2), "type")); + + @SuppressWarnings("unchecked") + List args = (List) getField(conns.get(2), "args"); + assertEquals(2, args.size()); + assertEquals("-y", args.get(0)); + assertEquals("@modelcontextprotocol/server-memory", args.get(1)); + + @SuppressWarnings("unchecked") + Map env = (Map) getField(conns.get(2), "env"); + assertEquals("bar", env.get("FOO")); + assertEquals("qux", env.get("BAZ")); + } + + @Test + void parseConnectionsShouldFallbackEmptyWhenBlank() throws Exception { + McpAsyncClientManager manager = + new McpAsyncClientManager(SSE_PROVIDER, STREAMABLE_HTTP_PROVIDER, STDIO_PROVIDER); + setField(manager, "connections", " "); + + Method parseMethod = McpAsyncClientManager.class.getDeclaredMethod("parseConnections"); + parseMethod.setAccessible(true); + @SuppressWarnings("unchecked") + List conns = (List) parseMethod.invoke(manager); + + assertTrue(conns.isEmpty()); + } + + private void setField(McpAsyncClientManager manager, String fieldName, String value) throws Exception { + var field = McpAsyncClientManager.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(manager, value); + } + + private Object getField(Object object, String fieldName) throws Exception { + Field field = object.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(object); + } + + private String getStringField(Object object, String fieldName) throws Exception { + return (String) getField(object, fieldName); + } + + private static class EmptyObjectProvider implements ObjectProvider { + @Override + public T getObject() { + throw new UnsupportedOperationException("No object available"); + } + + @Override + public T getObject(Object... args) { + throw new UnsupportedOperationException("No object available"); + } + + @Override + public T getIfAvailable() { + return null; + } + + @Override + public T getIfUnique() { + return null; + } + } +} diff --git a/bigtop-manager-bom/pom.xml b/bigtop-manager-bom/pom.xml index 878ae561b..342956776 100644 --- a/bigtop-manager-bom/pom.xml +++ b/bigtop-manager-bom/pom.xml @@ -31,7 +31,7 @@ Bigtop Manager Bom - 1.0.0-RC1 + 1.1.3 3.2.0 2.2.0 2.3.32 @@ -286,6 +286,12 @@ ${spring-ai.version} + + org.springframework.ai + spring-ai-starter-mcp-client-webflux + ${spring-ai.version} + + dev.langchain4j diff --git a/bigtop-manager-server/pom.xml b/bigtop-manager-server/pom.xml index e6d3ff384..17cbc5ab1 100644 --- a/bigtop-manager-server/pom.xml +++ b/bigtop-manager-server/pom.xml @@ -63,10 +63,6 @@ org.apache.bigtop bigtop-manager-ai - - dev.langchain4j - langchain4j - org.jetbrains annotations diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/LLMConfigController.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/LLMConfigController.java index 141d5341a..f2daf7f0d 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/LLMConfigController.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/LLMConfigController.java @@ -41,6 +41,7 @@ import jakarta.annotation.Resource; import java.util.List; +import java.util.Map; @Tag(name = "LLM Config Controller") @RestController @@ -69,6 +70,15 @@ public ResponseEntity> platformsAuthCredential( return ResponseEntity.success(llmConfigService.platformsAuthCredentials(platformId)); } + @Operation(summary = "list platform models", description = "List models from /v1/models") + @PostMapping("/platforms/{platformId}/models") + public ResponseEntity> platformModels( + @PathVariable(name = "platformId") Long platformId, @RequestBody AuthPlatformReq authPlatformReq) { + AuthPlatformDTO authPlatformDTO = AuthPlatformConverter.INSTANCE.fromReq2DTO(authPlatformReq); + Map authCredentials = authPlatformDTO.getAuthCredentials(); + return ResponseEntity.success(llmConfigService.platformModels(platformId, authCredentials)); + } + @Operation(summary = "list auth platforms", description = "List authorized platforms") @GetMapping("/auth-platforms") public ResponseEntity> authorizedPlatforms() { diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/mcp/tool/StackMcpTool.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/mcp/tool/StackMcpTool.java index f27bf40b5..346fdb03b 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/mcp/tool/StackMcpTool.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/mcp/tool/StackMcpTool.java @@ -29,11 +29,14 @@ import org.springframework.ai.tool.annotation.Tool; import org.springframework.stereotype.Component; +import lombok.extern.slf4j.Slf4j; + import java.util.ArrayList; import java.util.List; import java.util.Map; @Component +@Slf4j public class StackMcpTool implements McpTool { @Tool( @@ -51,7 +54,7 @@ public List listStacks() { stackVO.setServices(ServiceConverter.INSTANCE.fromDTO2VO(serviceDTOList)); stackVOList.add(stackVO); } - + log.info("ListStacks tool called, total stacks: {}", stackVOList.size()); return stackVOList; } } diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/AuthPlatformConverter.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/AuthPlatformConverter.java index 329823f36..5b30d61c3 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/AuthPlatformConverter.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/AuthPlatformConverter.java @@ -54,6 +54,7 @@ default Map mapAuthCredentials(List authCrede return null; } return authCredentials.stream() + .filter(item -> item != null && item.getKey() != null) .collect(Collectors.toMap(AuthCredentialReq::getKey, AuthCredentialReq::getValue)); } diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/TalkVO.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/TalkVO.java index f52de9798..02269e47f 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/TalkVO.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/TalkVO.java @@ -25,4 +25,14 @@ public class TalkVO { private String content; private String finishReason; + + private String eventType; + + private String executionId; + + private String toolName; + + private String toolStatus; + + private String toolPayload; } diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/LLMConfigService.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/LLMConfigService.java index 223b5723c..3a3dd1d33 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/LLMConfigService.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/LLMConfigService.java @@ -24,6 +24,7 @@ import org.apache.bigtop.manager.server.model.vo.PlatformVO; import java.util.List; +import java.util.Map; public interface LLMConfigService { @@ -48,4 +49,6 @@ public interface LLMConfigService { AuthPlatformVO getAuthorizedPlatform(Long authId); PlatformVO getPlatform(Long id); + + List platformModels(Long platformId, Map authCredentials); } diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java index 45ae57aec..08eca694f 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java @@ -45,23 +45,28 @@ import org.apache.bigtop.manager.server.model.vo.ChatThreadVO; import org.apache.bigtop.manager.server.model.vo.TalkVO; import org.apache.bigtop.manager.server.service.ChatbotService; -import org.apache.bigtop.manager.server.tools.provider.AIServiceToolsProvider; import org.springframework.context.i18n.LocaleContextHolder; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import lombok.extern.slf4j.Slf4j; +import reactor.core.Disposable; import reactor.core.publisher.Flux; import jakarta.annotation.Resource; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; @Service @Slf4j public class ChatbotServiceImpl implements ChatbotService { + private static final long CHAT_SSE_TIMEOUT_MILLIS = 30 * 60 * 1000L; + @Resource private PlatformDao platformDao; @@ -74,9 +79,6 @@ public class ChatbotServiceImpl implements ChatbotService { @Resource private ChatMessageDao chatMessageDao; - @Resource - private AIServiceToolsProvider aiServiceToolsProvider; - @Resource private AIAssistantFactory aiAssistantFactory; @@ -129,19 +131,49 @@ public List getAllChatThreads() { @Override public SseEmitter talk(Long threadId, ChatbotCommand command, String message) { - AIAssistant aiAssistant = prepareTalk(threadId, command); + SseEmitter emitter = new SseEmitter(CHAT_SSE_TIMEOUT_MILLIS); + AtomicBoolean emitterClosed = new AtomicBoolean(false); + AtomicReference subscriptionRef = new AtomicReference<>(); + AIAssistant aiAssistant = prepareTalk(threadId, command, emitter, emitterClosed); Flux stringFlux = (command == null) ? aiAssistant.streamAsk(message) : Flux.just(aiAssistant.ask(message)); - SseEmitter emitter = new SseEmitter(); - - stringFlux.subscribe( - s -> sendTalkVO(emitter, s, null), - throwable -> handleError(emitter, throwable), - () -> completeEmitter(emitter)); + Disposable subscription = stringFlux.subscribe( + s -> { + if (!sendTalkVO(emitter, emitterClosed, s, null)) { + disposeSubscription(subscriptionRef); + } + }, + throwable -> { + disposeSubscription(subscriptionRef); + handleError(emitter, emitterClosed, throwable); + }, + () -> { + disposeSubscription(subscriptionRef); + completeEmitter(emitter, emitterClosed); + }); + + subscriptionRef.set(subscription); + + emitter.onCompletion(() -> { + emitterClosed.set(true); + disposeSubscription(subscriptionRef); + }); + + emitter.onTimeout(() -> { + emitterClosed.set(true); + disposeSubscription(subscriptionRef); + try { + emitter.complete(); + } catch (Exception ignored) { + } + }); - emitter.onTimeout(emitter::complete); + emitter.onError(error -> { + emitterClosed.set(true); + disposeSubscription(subscriptionRef); + }); return emitter; } @@ -233,13 +265,20 @@ private PlatformType getPlatformType(String platformName) { } private AIAssistant buildAIAssistant( - String platformName, String model, Map credentials, Long threadId, ChatbotCommand command) { + String platformName, + String model, + Map credentials, + Long threadId, + ChatbotCommand command, + SseEmitter emitter, + AtomicBoolean emitterClosed) { return aiAssistantFactory.createAIService( getAIAssistantConfig(platformName, model, credentials, threadId), - aiServiceToolsProvider.getToolsProvide(command)); + event -> sendToolExecutionEvent(emitter, emitterClosed, event)); } - private AIAssistant prepareTalk(Long threadId, ChatbotCommand command) { + private AIAssistant prepareTalk( + Long threadId, ChatbotCommand command, SseEmitter emitter, AtomicBoolean emitterClosed) { ChatThreadPO chatThreadPO = validateAndGetChatThread(threadId); AuthPlatformPO authPlatformPO = validateAndGetActiveAuthPlatform(); @@ -255,29 +294,112 @@ private AIAssistant prepareTalk(Long threadId, ChatbotCommand command) { authPlatformDTO.getModel(), authPlatformDTO.getAuthCredentials(), threadId, - command); + command, + emitter, + emitterClosed); + } + + private void sendToolExecutionEvent( + SseEmitter emitter, AtomicBoolean emitterClosed, AIAssistant.ToolExecutionEvent event) { + if (emitterClosed.get()) { + return; + } + + TalkVO talkVO = new TalkVO(); + talkVO.setEventType("tool_execution"); + talkVO.setExecutionId(event.executionId()); + talkVO.setToolName(event.toolName()); + talkVO.setToolStatus(event.status()); + talkVO.setToolPayload(event.payload()); + sendTalkVO(emitter, emitterClosed, talkVO); } - private void sendTalkVO(SseEmitter emitter, String content, String finishReason) { + private boolean sendTalkVO(SseEmitter emitter, AtomicBoolean emitterClosed, String content, String finishReason) { + TalkVO talkVO = new TalkVO(); + talkVO.setContent(content); + talkVO.setFinishReason(finishReason); + return sendTalkVO(emitter, emitterClosed, talkVO); + } + + private boolean sendTalkVO(SseEmitter emitter, AtomicBoolean emitterClosed, TalkVO talkVO) { + if (emitterClosed.get()) { + return false; + } + try { - TalkVO talkVO = new TalkVO(); - talkVO.setContent(content); - talkVO.setFinishReason(finishReason); emitter.send(talkVO); + return true; + } catch (IllegalStateException e) { + if (emitterClosed.compareAndSet(false, true)) { + log.warn("SSE emitter already closed, stop sending stream data: {}", e.getMessage()); + } + return false; } catch (Exception e) { - log.error("Error sending data to SseEmitter", e); - emitter.completeWithError(e); + if (emitterClosed.compareAndSet(false, true)) { + log.error("Error sending data to SseEmitter", e); + } + try { + emitter.complete(); + } catch (Exception ignored) { + } + return false; } } - private void handleError(SseEmitter emitter, Throwable throwable) { + private void handleError(SseEmitter emitter, AtomicBoolean emitterClosed, Throwable throwable) { + if (isStreamCancellation(throwable)) { + log.warn("SSE streaming cancelled: {}", throwable.getMessage()); + if (emitterClosed.compareAndSet(false, true)) { + try { + emitter.complete(); + } catch (Exception ignored) { + } + } + return; + } + log.error("Error during SSE streaming: {}", throwable.getMessage(), throwable); - sendTalkVO(emitter, null, "Error: " + throwable.getMessage()); - emitter.completeWithError(throwable); + + if (!sendTalkVO(emitter, emitterClosed, null, "Error: " + throwable.getMessage())) { + return; + } + + if (emitterClosed.compareAndSet(false, true)) { + try { + emitter.complete(); + } catch (Exception ignored) { + } + } } - private void completeEmitter(SseEmitter emitter) { - sendTalkVO(emitter, null, "completed"); - emitter.complete(); + private void completeEmitter(SseEmitter emitter, AtomicBoolean emitterClosed) { + if (!sendTalkVO(emitter, emitterClosed, null, "completed")) { + return; + } + + if (emitterClosed.compareAndSet(false, true)) { + try { + emitter.complete(); + } catch (Exception ignored) { + } + } + } + + private void disposeSubscription(AtomicReference subscriptionRef) { + Disposable disposable = subscriptionRef.getAndSet(null); + if (disposable != null && !disposable.isDisposed()) { + disposable.dispose(); + } + } + + private boolean isStreamCancellation(Throwable throwable) { + Throwable current = throwable; + while (current != null) { + if (current instanceof InterruptedException || current instanceof CancellationException) { + return true; + } + current = current.getCause(); + } + return false; } } diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java index eaf46eae4..6e2ea9944 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java @@ -22,7 +22,6 @@ import org.apache.bigtop.manager.ai.core.enums.PlatformType; import org.apache.bigtop.manager.ai.core.factory.AIAssistant; import org.apache.bigtop.manager.ai.core.factory.AIAssistantFactory; -import org.apache.bigtop.manager.common.utils.JsonUtils; import org.apache.bigtop.manager.dao.po.AuthPlatformPO; import org.apache.bigtop.manager.dao.po.ChatMessagePO; import org.apache.bigtop.manager.dao.po.ChatThreadPO; @@ -47,15 +46,11 @@ import org.springframework.context.i18n.LocaleContextHolder; import org.springframework.stereotype.Service; -import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.model.chat.request.json.JsonObjectSchema; -import dev.langchain4j.service.tool.ToolExecutor; -import dev.langchain4j.service.tool.ToolProvider; -import dev.langchain4j.service.tool.ToolProviderResult; import lombok.extern.slf4j.Slf4j; import jakarta.annotation.Resource; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -78,15 +73,61 @@ public class LLMConfigServiceImpl implements LLMConfigService { @Resource private AIAssistantFactory aiAssistantFactory; - private static final String TEST_FLAG = "ZmxhZw=="; - private static final String TEST_KEY = "bm"; - @Override public List platforms() { List platformPOs = platformDao.findAll(); return PlatformConverter.INSTANCE.fromPO2VO(platformPOs); } + private List getDynamicModels(PlatformPO platformPO, Map explicitCreds) { + PlatformType platformType = + PlatformType.getPlatformType(platformPO.getName().toLowerCase()); + if (platformType == null) { + return getDefaultModels(platformPO); + } + + Map creds = new HashMap<>(); + + if (explicitCreds != null && !explicitCreds.isEmpty()) { + creds.putAll(explicitCreds); + } else { + List authPlatformPOs = authPlatformDao.findAll(); + for (AuthPlatformPO auth : authPlatformPOs) { + if (auth.getPlatformId().equals(platformPO.getId()) && !auth.getIsDeleted()) { + creds = AuthPlatformConverter.INSTANCE.fromPO2DTO(auth).getAuthCredentials(); + break; + } + } + } + + GeneralAssistantConfig config = GeneralAssistantConfig.builder() + .setPlatformType(platformType) + .setModel("dummy") + .addCredentials(creds) + .build(); + + List models = aiAssistantFactory.getModels(config); + if (models != null && !models.isEmpty()) { + return models; + } + + return getDefaultModels(platformPO); + } + + private List getDefaultModels(PlatformPO platformPO) { + String supportModels = platformPO.getSupportModels(); + if (supportModels == null || supportModels.isBlank()) { + return Collections.emptyList(); + } + return List.of(supportModels.split(",")); + } + + @Override + public List platformModels(Long platformId, Map authCredentials) { + PlatformPO platformPO = validateAndGetPlatform(platformId); + return getDynamicModels(platformPO, authCredentials); + } + @Override public List platformsAuthCredentials(Long platformId) { PlatformPO platformPO = platformDao.findById(platformId); @@ -164,11 +205,6 @@ public boolean testAuthorizedPlatform(AuthPlatformDTO authPlatformDTO) { PlatformPO platformPO = validateAndGetPlatform(authPlatformDTO.getPlatformId()); - List supportModels = List.of(platformPO.getSupportModels().split(",")); - if (supportModels.isEmpty() || !supportModels.contains(authPlatformDTO.getModel())) { - throw new ApiException(ApiExceptionEnum.MODEL_NOT_SUPPORTED); - } - if (authPlatformDTO.getId() != null) { AuthPlatformPO authPlatformPO = validateAndGetAuthPlatform(authPlatformDTO.getId()); @@ -180,6 +216,11 @@ public boolean testAuthorizedPlatform(AuthPlatformDTO authPlatformDTO) { Map credentialSet = getStringMap(authPlatformDTO, PlatformConverter.INSTANCE.fromPO2DTO(platformPO)); + List dynamicModels = getDynamicModels(platformPO, credentialSet); + if (dynamicModels.isEmpty() || !dynamicModels.contains(authPlatformDTO.getModel())) { + throw new ApiException(ApiExceptionEnum.MODEL_NOT_SUPPORTED); + } + if (!testAuthorization(platformPO.getName(), authPlatformDTO.getModel(), credentialSet)) { throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT); } @@ -260,7 +301,6 @@ public AuthPlatformVO getAuthorizedPlatform(Long authId) { @Override public PlatformVO getPlatform(Long id) { PlatformPO platformPO = validateAndGetPlatform(id); - return PlatformConverter.INSTANCE.fromPO2VO(platformPO); } @@ -302,10 +342,8 @@ private PlatformType getPlatformType(String platformName) { } private Boolean testAuthorization(String platformName, String model, Map credentials) { - Boolean result = testFuncCalling(platformName, model, credentials); - log.info("Test func calling result: {}", result); GeneralAssistantConfig generalAssistantConfig = getAIAssistantConfig(platformName, model, credentials); - AIAssistant aiAssistant = aiAssistantFactory.createForTest(generalAssistantConfig, null); + AIAssistant aiAssistant = aiAssistantFactory.createForTest(generalAssistantConfig); try { return aiAssistant.test(); } catch (Exception e) { @@ -313,40 +351,6 @@ private Boolean testAuthorization(String platformName, String model, Map credentials) { - ToolProvider toolProvider = (toolProviderRequest) -> { - ToolSpecification toolSpecification = ToolSpecification.builder() - .name("getFlag") - .description("Get flag based on key") - .parameters(JsonObjectSchema.builder() - .addStringProperty("key") - .description("Lowercase key to get flag") - .build()) - .build(); - ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> { - Map arguments = JsonUtils.readFromString(toolExecutionRequest.arguments()); - String key = arguments.get("key").toString(); - if (key.equals(TEST_KEY)) { - return TEST_FLAG; - } - return null; - }; - - return ToolProviderResult.builder() - .add(toolSpecification, toolExecutor) - .build(); - }; - - GeneralAssistantConfig generalAssistantConfig = getAIAssistantConfig(platformName, model, credentials); - AIAssistant aiAssistant = aiAssistantFactory.createForTest(generalAssistantConfig, toolProvider); - try { - return aiAssistant.ask("What is the flag of " + TEST_KEY).contains(TEST_FLAG); - } catch (Exception e) { - log.error("Test function calling failed", e); - return false; - } - } - private void switchActivePlatform(Long id) { List authPlatformPOS = authPlatformDao.findAll(); for (AuthPlatformPO authPlatformPO : authPlatformPOS) { diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/functions/ClusterFunctions.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/functions/ClusterFunctions.java deleted file mode 100644 index 0978164da..000000000 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/functions/ClusterFunctions.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.bigtop.manager.server.tools.functions; - -import org.apache.bigtop.manager.common.utils.JsonUtils; -import org.apache.bigtop.manager.server.model.vo.ClusterVO; -import org.apache.bigtop.manager.server.service.ClusterService; - -import org.springframework.stereotype.Component; - -import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.model.chat.request.json.JsonObjectSchema; -import dev.langchain4j.service.tool.ToolExecutor; -import lombok.extern.slf4j.Slf4j; - -import jakarta.annotation.Resource; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -@Component -@Slf4j -public class ClusterFunctions { - @Resource - private ClusterService clusterService; - - public Map listCluster() { - ToolSpecification toolSpecification = ToolSpecification.builder() - .name("listCluster") - .description("Get cluster list") - .build(); - ToolExecutor toolExecutor = - (toolExecutionRequest, memoryId) -> JsonUtils.indentWriteAsString(clusterService.list()); - - return Map.of(toolSpecification, toolExecutor); - } - - public Map getClusterById() { - ToolSpecification toolSpecification = ToolSpecification.builder() - .name("getClusterById") - .description("Get cluster information based on ID") - .parameters(JsonObjectSchema.builder() - .description("Cluster ID") - .addNumberProperty("clusterId") - .build()) - .build(); - ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> { - Map arguments = JsonUtils.readFromString(toolExecutionRequest.arguments()); - Long clusterId = Long.valueOf(arguments.get("clusterId").toString()); - ClusterVO clusterVO = clusterService.get(clusterId); - if (clusterVO == null) { - return "Cluster not found"; - } - return JsonUtils.indentWriteAsString(clusterVO); - }; - - return Map.of(toolSpecification, toolExecutor); - } - - public Map getClusterByName() { - ToolSpecification toolSpecification = ToolSpecification.builder() - .name("getClusterByName") - .description("Get cluster information based on cluster name") - .parameters(JsonObjectSchema.builder() - .description("Cluster name") - .addStringProperty("clusterName") - .build()) - .build(); - ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> { - Map arguments = JsonUtils.readFromString(toolExecutionRequest.arguments()); - String clusterName = arguments.get("clusterName").toString(); - List clusterVOS = clusterService.list(); - for (ClusterVO clusterVO : clusterVOS) { - if (clusterVO.getName().equals(clusterName)) { - return JsonUtils.indentWriteAsString(clusterVO); - } - } - return "Cluster not found"; - }; - - return Map.of(toolSpecification, toolExecutor); - } - - public Map getAllFunctions() { - Map functions = new HashMap<>(); - functions.putAll(listCluster()); - functions.putAll(getClusterById()); - functions.putAll(getClusterByName()); - return functions; - } -} diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/functions/HostFunctions.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/functions/HostFunctions.java deleted file mode 100644 index b14b7327b..000000000 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/functions/HostFunctions.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.bigtop.manager.server.tools.functions; - -import org.apache.bigtop.manager.common.utils.JsonUtils; -import org.apache.bigtop.manager.dao.query.HostQuery; -import org.apache.bigtop.manager.server.model.vo.HostVO; -import org.apache.bigtop.manager.server.model.vo.PageVO; -import org.apache.bigtop.manager.server.service.HostService; - -import org.springframework.stereotype.Component; - -import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.model.chat.request.json.JsonObjectSchema; -import dev.langchain4j.service.tool.ToolExecutor; -import lombok.extern.slf4j.Slf4j; - -import jakarta.annotation.Resource; -import java.util.HashMap; -import java.util.Map; - -@Component -@Slf4j -public class HostFunctions { - @Resource - private HostService hostService; - - public Map getHostById() { - ToolSpecification toolSpecification = ToolSpecification.builder() - .name("getHostById") - .description("Get host information based on ID") - .parameters(JsonObjectSchema.builder() - .description("Host ID") - .addNumberProperty("hostId") - .build()) - .build(); - ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> { - Map arguments = JsonUtils.readFromString(toolExecutionRequest.arguments()); - Long hostId = Long.valueOf(arguments.get("hostId").toString()); - HostVO hostVO = hostService.get(hostId); - if (hostVO == null) { - return "Host not found"; - } - return JsonUtils.indentWriteAsString(hostVO); - }; - - return Map.of(toolSpecification, toolExecutor); - } - - public Map getHostByName() { - ToolSpecification toolSpecification = ToolSpecification.builder() - .name("getHostByName") - .description("Get host information based on cluster name") - .parameters(JsonObjectSchema.builder() - .description("Host name") - .addStringProperty("hostName") - .build()) - .build(); - ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> { - Map arguments = JsonUtils.readFromString(toolExecutionRequest.arguments()); - String hostName = arguments.get("hostName").toString(); - HostQuery hostQuery = new HostQuery(); - hostQuery.setHostname(hostName); - PageVO hostVO = hostService.list(hostQuery); - return JsonUtils.indentWriteAsString(hostVO); - }; - - return Map.of(toolSpecification, toolExecutor); - } - - public Map getAllFunctions() { - Map functions = new HashMap<>(); - functions.putAll(getHostById()); - functions.putAll(getHostByName()); - return functions; - } -} diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/functions/StackFunctions.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/functions/StackFunctions.java deleted file mode 100644 index 93215f99b..000000000 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/functions/StackFunctions.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.bigtop.manager.server.tools.functions; - -import org.apache.bigtop.manager.common.utils.JsonUtils; -import org.apache.bigtop.manager.server.model.vo.PropertyVO; -import org.apache.bigtop.manager.server.model.vo.ServiceConfigVO; -import org.apache.bigtop.manager.server.model.vo.ServiceVO; -import org.apache.bigtop.manager.server.model.vo.StackVO; -import org.apache.bigtop.manager.server.service.StackService; - -import org.springframework.stereotype.Component; - -import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.model.chat.request.json.JsonObjectSchema; -import dev.langchain4j.service.tool.ToolExecutor; -import lombok.extern.slf4j.Slf4j; - -import jakarta.annotation.Resource; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -@Component -@Slf4j -public class StackFunctions { - @Resource - private StackService stackService; - - public Map listStackAndService() { - ToolSpecification toolSpecification = ToolSpecification.builder() - .name("listStackAndService") - .description("Retrieve the list of services in each stack") - .build(); - ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> { - Map> stackInfo = new HashMap<>(); - for (StackVO stackVO : stackService.list()) { - List services = new ArrayList<>(); - for (ServiceVO serviceVO : stackVO.getServices()) { - services.add(serviceVO.getName()); - } - stackInfo.put(stackVO.getStackName(), services); - } - return JsonUtils.indentWriteAsString(stackInfo); - }; - return Map.of(toolSpecification, toolExecutor); - } - - public Map getServiceByName() { - ToolSpecification toolSpecification = ToolSpecification.builder() - .name("getServiceByName") - .description("Get service information and configs based on service name") - .parameters(JsonObjectSchema.builder() - .addStringProperty("serviceName") - .description("Service name") - .build()) - .build(); - ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> { - Map arguments = JsonUtils.readFromString(toolExecutionRequest.arguments()); - String serviceName = arguments.get("serviceName").toString(); - for (StackVO stackVO : stackService.list()) { - for (ServiceVO serviceVO : stackVO.getServices()) { - if (serviceVO.getName().equals(serviceName)) { - for (ServiceConfigVO serviceConfigVO : serviceVO.getConfigs()) { - for (PropertyVO propertyVO : serviceConfigVO.getProperties()) { - if (propertyVO.getName().equals("content")) { - propertyVO.setValue(null); - } - if (propertyVO.getAttrs() != null - && propertyVO.getAttrs().getType().equals("longtext")) { - propertyVO.setValue(null); - } - } - } - return JsonUtils.indentWriteAsString(serviceVO); - } - } - } - return "Service not found"; - }; - - return Map.of(toolSpecification, toolExecutor); - } - - public Map getAllFunctions() { - Map functions = new HashMap<>(); - functions.putAll(listStackAndService()); - functions.putAll(getServiceByName()); - return functions; - } -} diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/provider/AIServiceToolsProvider.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/provider/AIServiceToolsProvider.java deleted file mode 100644 index 72632ba3d..000000000 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/provider/AIServiceToolsProvider.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.bigtop.manager.server.tools.provider; - -import org.apache.bigtop.manager.server.enums.ChatbotCommand; - -import org.springframework.stereotype.Component; - -import dev.langchain4j.service.tool.ToolProvider; - -import jakarta.annotation.Resource; - -@Component -public class AIServiceToolsProvider { - @Resource - private InfoToolsProvider infoToolsProvider; - - public ToolProvider getToolsProvide(ChatbotCommand chatbotCommand) { - if (ChatbotCommand.INFO.equals(chatbotCommand)) { - return infoToolsProvider; - } - return null; - } -} diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/provider/InfoToolsProvider.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/provider/InfoToolsProvider.java deleted file mode 100644 index 36ec1f92b..000000000 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/provider/InfoToolsProvider.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.bigtop.manager.server.tools.provider; - -import org.apache.bigtop.manager.server.tools.functions.ClusterFunctions; -import org.apache.bigtop.manager.server.tools.functions.HostFunctions; -import org.apache.bigtop.manager.server.tools.functions.StackFunctions; - -import org.springframework.stereotype.Component; - -import dev.langchain4j.service.tool.ToolProvider; -import dev.langchain4j.service.tool.ToolProviderRequest; -import dev.langchain4j.service.tool.ToolProviderResult; -import lombok.extern.slf4j.Slf4j; - -import jakarta.annotation.Resource; - -@Component -@Slf4j -public class InfoToolsProvider implements ToolProvider { - @Resource - private ClusterFunctions clusterFunctions; - - @Resource - private HostFunctions hostFunctions; - - @Resource - private StackFunctions stackFunctions; - - @Override - public ToolProviderResult provideTools(ToolProviderRequest toolProviderRequest) { - return ToolProviderResult.builder() - .addAll(clusterFunctions.getAllFunctions()) - .addAll(hostFunctions.getAllFunctions()) - .addAll(stackFunctions.getAllFunctions()) - .build(); - } -} diff --git a/bigtop-manager-server/src/main/resources/application.yml b/bigtop-manager-server/src/main/resources/application.yml index 860573b97..5cf093202 100644 --- a/bigtop-manager-server/src/main/resources/application.yml +++ b/bigtop-manager-server/src/main/resources/application.yml @@ -25,6 +25,24 @@ spring: type: ASYNC sse-endpoint: /mcp/sse sse-message-endpoint: /mcp/messages + client: + enabled: true + request-timeout-seconds: 120 + init-timeout-seconds: 20 + # Optional multi-connection definition. + # Format: + # name=,type=,baseUrl=,endpoint=,command=,args=,env=,requestTimeoutSeconds=,initTimeoutSeconds=; + # Notes: + # - sse / streamable-http use baseUrl (+ optional endpoint) + # - local uses command/args/env + # - requestTimeoutSeconds/initTimeoutSeconds are optional per-connection overrides + # Example: + # connections: > + # name=embedded,type=sse,baseUrl=http://localhost:8080,endpoint=/mcp/sse; + # name=remote,type=streamable-http,baseUrl=http://127.0.0.1:9000,endpoint=/mcp,requestTimeoutSeconds=180; + # name=memory,type=local,command=npx,args=-y|@modelcontextprotocol/server-memory,env=NODE_ENV:production + connections: > + name=bm,type=sse,baseUrl=http://localhost:8080,endpoint=/mcp/sse banner: charset: utf-8 application: @@ -67,4 +85,4 @@ springdoc: pagehelper: reasonable: false params: count=countSql - support-methods-arguments: true \ No newline at end of file + support-methods-arguments: true diff --git a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/functions/ClusterFunctionsTest.java b/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/functions/ClusterFunctionsTest.java deleted file mode 100644 index 6a522da71..000000000 --- a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/functions/ClusterFunctionsTest.java +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.bigtop.manager.server.tools.functions; - -import org.apache.bigtop.manager.server.model.vo.ClusterVO; -import org.apache.bigtop.manager.server.service.ClusterService; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.InjectMocks; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; - -import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.service.tool.ToolExecutor; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -@ExtendWith(MockitoExtension.class) -class ClusterFunctionsTest { - - @Mock - private ClusterService clusterService; - - @InjectMocks - private ClusterFunctions clusterFunctions; - - private ClusterVO testCluster; - - @BeforeEach - void setUp() { - testCluster = new ClusterVO(); - testCluster.setId(1L); - testCluster.setName("test-cluster"); - } - - @Test - void testListCluster() { - // Mock clusterService response - when(clusterService.list()).thenReturn(Collections.singletonList(testCluster)); - - // Get the tool specification and executor - Map tools = clusterFunctions.listCluster(); - assertEquals(1, tools.size()); - - ToolSpecification spec = tools.keySet().iterator().next(); - ToolExecutor executor = tools.get(spec); - - // Execute the tool - String result = executor.execute(ToolExecutionRequest.builder().build(), "memoryId"); - - // Verify results - assertTrue(result.contains("test-cluster")); - verify(clusterService, times(1)).list(); - } - - @Test - void testGetClusterById() { - // Mock clusterService response - when(clusterService.get(1L)).thenReturn(testCluster); - - // Get the tool specification and executor - Map tools = clusterFunctions.getClusterById(); - assertEquals(1, tools.size()); - - ToolSpecification spec = tools.keySet().iterator().next(); - ToolExecutor executor = tools.get(spec); - - // Build request with arguments - String arguments = "{\"clusterId\": 1}"; - ToolExecutionRequest request = - ToolExecutionRequest.builder().arguments(arguments).build(); - - // Execute the tool - String result = executor.execute(request, "memoryId"); - - // Verify results - assertTrue(result.contains("test-cluster")); - verify(clusterService, times(1)).get(1L); - } - - @Test - void testGetClusterByIdWhenNotExists() { - // Mock clusterService response - when(clusterService.get(999L)).thenReturn(null); - - // Get the tool specification and executor - Map tools = clusterFunctions.getClusterById(); - ToolExecutor executor = tools.values().iterator().next(); - - // Build request with arguments - String arguments = "{\"clusterId\": 999}"; - ToolExecutionRequest request = - ToolExecutionRequest.builder().arguments(arguments).build(); - - // Execute the tool - String result = executor.execute(request, "memoryId"); - - // Verify results - assertEquals("Cluster not found", result); - } - - @Test - void testGetClusterByName() { - // Mock clusterService response - when(clusterService.list()).thenReturn(Collections.singletonList(testCluster)); - - // Get the tool specification and executor - Map tools = clusterFunctions.getClusterByName(); - ToolExecutor executor = tools.values().iterator().next(); - - // Build request with arguments - String arguments = "{\"clusterName\": \"test-cluster\"}"; - ToolExecutionRequest request = - ToolExecutionRequest.builder().arguments(arguments).build(); - - // Execute the tool - String result = executor.execute(request, "memoryId"); - - // Verify results - assertTrue(result.contains("test-cluster")); - verify(clusterService, times(1)).list(); - } - - @Test - void testGetClusterByNameWhenNotExists() { - // Mock clusterService response - when(clusterService.list()).thenReturn(Collections.singletonList(testCluster)); - - // Get the tool specification and executor - Map tools = clusterFunctions.getClusterByName(); - ToolExecutor executor = tools.values().iterator().next(); - - // Build request with arguments - String arguments = "{\"clusterName\": \"non-existent\"}"; - ToolExecutionRequest request = - ToolExecutionRequest.builder().arguments(arguments).build(); - - // Execute the tool - String result = executor.execute(request, "memoryId"); - - // Verify results - assertEquals("Cluster not found", result); - } - - @Test - void testGetAllFunctions() { - Map functions = clusterFunctions.getAllFunctions(); - assertEquals(3, functions.size()); - - List expectedToolNames = List.of("listCluster", "getClusterById", "getClusterByName"); - assertTrue(functions.keySet().stream().map(ToolSpecification::name).allMatch(expectedToolNames::contains)); - } -} diff --git a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/functions/HostFunctionsTest.java b/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/functions/HostFunctionsTest.java deleted file mode 100644 index b73996203..000000000 --- a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/functions/HostFunctionsTest.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.bigtop.manager.server.tools.functions; - -import org.apache.bigtop.manager.dao.query.HostQuery; -import org.apache.bigtop.manager.server.model.vo.HostVO; -import org.apache.bigtop.manager.server.model.vo.PageVO; -import org.apache.bigtop.manager.server.service.HostService; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.InjectMocks; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; - -import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.model.chat.request.json.JsonSchemaElement; -import dev.langchain4j.service.tool.ToolExecutor; - -import java.util.List; -import java.util.Map; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.Mockito.when; - -@ExtendWith(MockitoExtension.class) -class HostFunctionsTest { - - @Mock - private HostService hostService; - - @InjectMocks - private HostFunctions hostFunctions; - - private HostVO testHost; - private PageVO testPage; - - @BeforeEach - void setUp() { - testHost = new HostVO(); - testHost.setId(1L); - testHost.setHostname("test-host"); - - testPage = new PageVO<>(); - testPage.setContent(List.of(testHost)); - testPage.setTotal(1L); - } - - @Test - void testGetHostByIdToolSpecification() { - Map tools = hostFunctions.getHostById(); - assertEquals(1, tools.size()); - - ToolSpecification spec = tools.keySet().iterator().next(); - Map params = spec.parameters().properties(); - - assertEquals(1, params.size()); - assertTrue(params.containsKey("hostId")); - } - - @Test - void testGetHostByIdExecutorFound() throws Exception { - when(hostService.get(1L)).thenReturn(testHost); - - Map tools = hostFunctions.getHostById(); - ToolExecutor executor = tools.values().iterator().next(); - - String arguments = "{\"hostId\": 1}"; - String result = executor.execute( - ToolExecutionRequest.builder().arguments(arguments).build(), null); - - // Use system-independent newline character regex - String expectedPattern = ".*\"hostname\"\\s*:\\s*\"test-host\".*"; - assertTrue( - result.replaceAll("\\R", System.lineSeparator()).matches("(?s)" + expectedPattern), - "Hostname should match with any line separators"); - } - - @Test - void testGetHostByIdExecutorNotFound() { - when(hostService.get(anyLong())).thenReturn(null); - - Map tools = hostFunctions.getHostById(); - ToolExecutor executor = tools.values().iterator().next(); - - String arguments = "{\"hostId\": 999}"; - String result = executor.execute( - ToolExecutionRequest.builder().arguments(arguments).build(), null); - - assertEquals("Host not found", result); - } - - @Test - void testGetHostByNameToolSpecification() { - Map tools = hostFunctions.getHostByName(); - assertEquals(1, tools.size()); - - ToolSpecification spec = tools.keySet().iterator().next(); - assertEquals("getHostByName", spec.name()); - assertEquals("Get host information based on cluster name", spec.description()); - Map params = spec.parameters().properties(); - assertEquals(1, params.size()); - assertTrue(params.containsKey("hostName")); - } - - @Test - void testGetHostByNameExecutor() { - HostQuery query = new HostQuery(); - query.setHostname("test-host"); - when(hostService.list(query)).thenReturn(testPage); - - Map tools = hostFunctions.getHostByName(); - ToolExecutor executor = tools.values().iterator().next(); - - String arguments = "{\"hostName\":\"test-host\"}"; - String result = executor.execute( - ToolExecutionRequest.builder().arguments(arguments).build(), null); - - // System-independent matching pattern - String totalPattern = "(?s).*\"total\"\\s*:\\s*1.*"; - String hostPattern = "(?s).*\"hostname\"\\s*:\\s*\"test-host\".*"; - assertTrue(result.matches(totalPattern), "Should contain total=1"); - assertTrue(result.matches(hostPattern), "Should contain hostname=test-host"); - } - - @Test - void testGetAllFunctions() { - Map functions = hostFunctions.getAllFunctions(); - assertEquals(2, functions.size()); - assertTrue(functions.keySet().stream().anyMatch(s -> s.name().equals("getHostById"))); - assertTrue(functions.keySet().stream().anyMatch(s -> s.name().equals("getHostByName"))); - } -} diff --git a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/functions/StackFunctionsTest.java b/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/functions/StackFunctionsTest.java deleted file mode 100644 index 2324ac309..000000000 --- a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/functions/StackFunctionsTest.java +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.bigtop.manager.server.tools.functions; - -import org.apache.bigtop.manager.server.model.vo.AttrsVO; -import org.apache.bigtop.manager.server.model.vo.PropertyVO; -import org.apache.bigtop.manager.server.model.vo.ServiceConfigVO; -import org.apache.bigtop.manager.server.model.vo.ServiceVO; -import org.apache.bigtop.manager.server.model.vo.StackVO; -import org.apache.bigtop.manager.server.service.StackService; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.InjectMocks; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; - -import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.model.chat.request.json.JsonSchemaElement; -import dev.langchain4j.service.tool.ToolExecutor; - -import java.util.List; -import java.util.Map; - -import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.when; - -@ExtendWith(MockitoExtension.class) -class StackFunctionsTest { - - @Mock - private StackService stackService; - - @InjectMocks - private StackFunctions stackFunctions; - - private StackVO testStack; - - @BeforeEach - void setUp() { - // Initialize test data - testStack = new StackVO(); - testStack.setStackName("test-stack"); - - ServiceVO testService = new ServiceVO(); - testService.setName("test-service"); - - ServiceConfigVO config = new ServiceConfigVO(); - PropertyVO normalProp = new PropertyVO(); - normalProp.setName("normal"); - normalProp.setValue("value"); - PropertyVO contentProp = new PropertyVO(); - contentProp.setName("content"); - contentProp.setValue("secret"); - PropertyVO longtextProp = new PropertyVO(); - longtextProp.setName("longtext"); - longtextProp.setValue("very long text"); - AttrsVO attrs = new AttrsVO(); - attrs.setType("longtext"); - longtextProp.setAttrs(attrs); - - config.setProperties(List.of(normalProp, contentProp, longtextProp)); - testService.setConfigs(List.of(config)); - - testStack.setServices(List.of(testService)); - } - - @Test - void testListStackAndService() { - // Mock service layer return data - when(stackService.list()).thenReturn(List.of(testStack)); - - // Get tool - Map tools = stackFunctions.listStackAndService(); - assertEquals(1, tools.size()); - - // Validate tool specification - ToolSpecification spec = tools.keySet().iterator().next(); - assertEquals("listStackAndService", spec.name()); - assertEquals("Retrieve the list of services in each stack", spec.description()); - - // Execute tool - ToolExecutor executor = tools.values().iterator().next(); - String result = - executor.execute(ToolExecutionRequest.builder().arguments("{}").build(), null); - - // Validate result - String expectedJson = - """ - { - "test-stack": ["test-service"] - }"""; - assertEquals(expectedJson.replaceAll("\\s", ""), result.replaceAll("\\s", "")); - } - - @Test - void testGetServiceByNameFound() { - // Mock service layer return data - when(stackService.list()).thenReturn(List.of(testStack)); - - // Get tool - Map tools = stackFunctions.getServiceByName(); - ToolExecutor executor = tools.values().iterator().next(); - - // Execute query - String arguments = "{\"serviceName\" : \"test-service\"}"; - String result = executor.execute( - ToolExecutionRequest.builder().arguments(arguments).build(), null); - - // Validate result - assertAll( - () -> assertTrue(result.contains("\"name\" : \"test-service\"")), - () -> assertTrue(result.contains("\"name\" : \"normal\"")), - () -> assertTrue(result.contains("\"name\" : \"content\"")), - () -> assertTrue(result.contains("\"name\" : \"longtext\""))); - } - - @Test - void testGetServiceByNameNotFound() { - when(stackService.list()).thenReturn(List.of(testStack)); - - Map tools = stackFunctions.getServiceByName(); - ToolExecutor executor = tools.values().iterator().next(); - - String arguments = "{\"serviceName\":\"non-existent\"}"; - String result = executor.execute( - ToolExecutionRequest.builder().arguments(arguments).build(), null); - - assertEquals("Service not found", result); - } - - @Test - void testGetServiceByNameToolSpecification() { - Map tools = stackFunctions.getServiceByName(); - ToolSpecification spec = tools.keySet().iterator().next(); - Map params = spec.parameters().properties(); - assertAll( - () -> assertEquals("getServiceByName", spec.name()), - () -> assertEquals("Get service information and configs based on service name", spec.description()), - () -> assertEquals(1, params.size()), - () -> assertTrue(params.containsKey("serviceName"))); - } - - @Test - void testGetAllFunctions() { - Map functions = stackFunctions.getAllFunctions(); - assertEquals(2, functions.size()); - assertTrue(functions.keySet().stream().anyMatch(s -> s.name().equals("listStackAndService"))); - assertTrue(functions.keySet().stream().anyMatch(s -> s.name().equals("getServiceByName"))); - } -} diff --git a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/provider/AIServiceToolsProviderTest.java b/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/provider/AIServiceToolsProviderTest.java deleted file mode 100644 index 475dfe3f6..000000000 --- a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/tools/provider/AIServiceToolsProviderTest.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.bigtop.manager.server.tools.provider; - -import org.apache.bigtop.manager.server.enums.ChatbotCommand; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.InjectMocks; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; - -import dev.langchain4j.service.tool.ToolProvider; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; - -@ExtendWith(MockitoExtension.class) -public class AIServiceToolsProviderTest { - - @Mock - private InfoToolsProvider infoToolsProvider; - - @InjectMocks - private AIServiceToolsProvider aiServiceToolsProvider; - - @Test - public void testGetToolsProvideWithInfoCommand() { - ToolProvider toolProvider = aiServiceToolsProvider.getToolsProvide(ChatbotCommand.INFO); - assertNotNull(toolProvider); - assertEquals(infoToolsProvider, toolProvider); - } - - @Test - public void testGetToolsProvideWithOtherCommand() { - ToolProvider toolProvider = aiServiceToolsProvider.getToolsProvide(ChatbotCommand.HELP); - assertNull(toolProvider); - } - - @Test - public void testGetToolsProvideWithNullCommand() { - ToolProvider toolProvider = aiServiceToolsProvider.getToolsProvide(null); - assertNull(toolProvider); - } -} diff --git a/bigtop-manager-ui/src/api/ai-assistant/types.ts b/bigtop-manager-ui/src/api/ai-assistant/types.ts index 602974ddc..bcac129c3 100644 --- a/bigtop-manager-ui/src/api/ai-assistant/types.ts +++ b/bigtop-manager-ui/src/api/ai-assistant/types.ts @@ -47,9 +47,22 @@ export interface ChatMessageItem { sender: SenderType message: string createTime?: string + messageType?: 'message' | 'tool' + executionId?: string + toolName?: string + toolStatus?: 'started' | 'completed' | 'failed' + toolPayload?: string + toolInput?: string + toolOutput?: string + toolError?: string } export interface ReceivedMessageItem { content?: string - finishReason?: 'completed' | null + finishReason?: 'completed' | string | null + eventType?: 'tool_execution' | null + executionId?: string + toolName?: string + toolStatus?: 'started' | 'completed' | 'failed' + toolPayload?: string } diff --git a/bigtop-manager-ui/src/api/llm-config/index.ts b/bigtop-manager-ui/src/api/llm-config/index.ts index 4d0702bb5..234231746 100644 --- a/bigtop-manager-ui/src/api/llm-config/index.ts +++ b/bigtop-manager-ui/src/api/llm-config/index.ts @@ -18,7 +18,13 @@ */ import request from '@/api/request.ts' -import { Platform, PlatformCredential, AuthorizedPlatform, UpdateAuthorizedPlatformConfig } from './types' +import { + Platform, + PlatformCredential, + AuthorizedPlatform, + UpdateAuthorizedPlatformConfig, + PlatformModelsReq +} from './types' export const getPlatforms = (): Promise => { return request({ @@ -41,6 +47,14 @@ export const getPlatformCredentials = (platformId: number): Promise => { + return request({ + method: 'post', + url: `/llm/config/platforms/${platformId}/models`, + data + }) +} + export const addAuthorizedPlatform = (data: UpdateAuthorizedPlatformConfig): Promise => { return request({ method: 'post', diff --git a/bigtop-manager-ui/src/api/llm-config/types.ts b/bigtop-manager-ui/src/api/llm-config/types.ts index 0b458b561..c085734e5 100644 --- a/bigtop-manager-ui/src/api/llm-config/types.ts +++ b/bigtop-manager-ui/src/api/llm-config/types.ts @@ -68,3 +68,7 @@ export interface UpdateAuthorizedPlatformConfig extends AuthorizedPlatform { authCredentials: AuthCredential[] testPassed: boolean } + +export interface PlatformModelsReq { + authCredentials: AuthCredential[] +} diff --git a/bigtop-manager-ui/src/features/ai-assistant/chat-message.vue b/bigtop-manager-ui/src/features/ai-assistant/chat-message.vue index 6bd63bd9a..9dc7ff654 100644 --- a/bigtop-manager-ui/src/features/ai-assistant/chat-message.vue +++ b/bigtop-manager-ui/src/features/ai-assistant/chat-message.vue @@ -25,6 +25,19 @@ const props = defineProps() const isUser = computed(() => props.record.sender === 'USER') + const isTool = computed(() => props.record.messageType === 'tool') + const activeKey = ref([]) + const toolOutput = computed(() => props.record.toolOutput || '-') + const toolInput = computed(() => props.record.toolInput || '-') + const toolError = computed(() => props.record.toolError || '-') + + watch( + () => props.record.executionId, + () => { + activeKey.value = [] + }, + { immediate: true } + )