diff --git a/bigtop-manager-ai/pom.xml b/bigtop-manager-ai/pom.xml
index 6eb68ca83..6c1fbbd73 100644
--- a/bigtop-manager-ai/pom.xml
+++ b/bigtop-manager-ai/pom.xml
@@ -67,6 +67,10 @@
org.apache.commons
commons-lang3
+
+ org.springframework.ai
+ spring-ai-starter-mcp-client-webflux
+
diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java
index 38951c8ec..7912a7e1b 100644
--- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java
+++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java
@@ -20,6 +20,7 @@
import org.apache.bigtop.manager.ai.assistant.config.GeneralAssistantConfig;
import org.apache.bigtop.manager.ai.assistant.provider.ChatMemoryStoreProvider;
+import org.apache.bigtop.manager.ai.config.McpAsyncClientManager;
import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory;
import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
@@ -34,11 +35,15 @@
import org.springframework.stereotype.Component;
+import lombok.extern.slf4j.Slf4j;
+
import jakarta.annotation.Resource;
import java.util.ArrayList;
import java.util.List;
+import java.util.function.Consumer;
@Component
+@Slf4j
public class GeneralAssistantFactory extends AbstractAIAssistantFactory {
@Resource
@@ -47,6 +52,9 @@ public class GeneralAssistantFactory extends AbstractAIAssistantFactory {
@Resource
private ChatMemoryStoreProvider chatMemoryStoreProvider;
+ @Resource
+ private McpAsyncClientManager mcpAsyncClientManager;
+
private void configureSystemPrompt(AIAssistant.Builder builder, SystemPrompt systemPrompt, String locale) {
List systemPrompts = new ArrayList<>();
if (systemPrompt != null) {
@@ -68,7 +76,15 @@ private AIAssistant.Builder initializeBuilder(PlatformType platformType) {
}
@Override
- public AIAssistant createWithPrompt(AIAssistantConfig config, Object toolProvider, SystemPrompt systemPrompt) {
+ public AIAssistant createWithPrompt(AIAssistantConfig config, SystemPrompt systemPrompt) {
+ return createWithPrompt(config, systemPrompt, null);
+ }
+
+ @Override
+ public AIAssistant createWithPrompt(
+ AIAssistantConfig config,
+ SystemPrompt systemPrompt,
+ Consumer toolExecutionListener) {
GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config;
PlatformType platformType = generalAssistantConfig.getPlatformType();
Object id = generalAssistantConfig.getId();
@@ -80,6 +96,15 @@ public AIAssistant createWithPrompt(AIAssistantConfig config, Object toolProvide
builder.id(id)
.memoryStore(chatMemoryStoreProvider.createPersistentChatMemoryStore(id))
.withConfig(generalAssistantConfig);
+ builder.withToolExecutionListener(toolExecutionListener);
+
+ List mcpAsyncClients = mcpAsyncClientManager.getClients();
+ if (!mcpAsyncClients.isEmpty()) {
+ log.info("MCP clients available for platform {} (chat), count={}", platformType, mcpAsyncClients.size());
+ builder.withMcpClients(mcpAsyncClients);
+ } else {
+ log.info("MCP client unavailable for platform {} (chat)", platformType);
+ }
configureSystemPrompt(builder, systemPrompt, generalAssistantConfig.getLanguage());
@@ -87,7 +112,7 @@ public AIAssistant createWithPrompt(AIAssistantConfig config, Object toolProvide
}
@Override
- public AIAssistant createForTest(AIAssistantConfig config, Object toolProvider) {
+ public AIAssistant createForTest(AIAssistantConfig config) {
GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config;
PlatformType platformType = generalAssistantConfig.getPlatformType();
AIAssistant.Builder builder = initializeBuilder(platformType);
@@ -96,6 +121,34 @@ public AIAssistant createForTest(AIAssistantConfig config, Object toolProvider)
.memoryStore(chatMemoryStoreProvider.createInMemoryChatMemoryStore())
.withConfig(generalAssistantConfig);
+ List mcpAsyncClients = mcpAsyncClientManager.getClients();
+ if (!mcpAsyncClients.isEmpty()) {
+ log.info("MCP clients available for platform {} (test), count={}", platformType, mcpAsyncClients.size());
+ builder.withMcpClients(mcpAsyncClients);
+ } else {
+ log.info("MCP client unavailable for platform {} (test)", platformType);
+ }
+
return builder.build();
}
+
+ @Override
+ public List getModels(AIAssistantConfig config) {
+ GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config;
+ PlatformType platformType = generalAssistantConfig.getPlatformType();
+ try {
+ AIAssistant.Builder builder = initializeBuilder(platformType);
+ builder.withConfig(generalAssistantConfig);
+ List models = builder.getModels();
+ if (models != null && !models.isEmpty()) {
+ log.info("Fetched {} dynamic models for platform {}.", models.size(), platformType);
+ return models;
+ }
+ } catch (Exception e) {
+ log.warn("Failed to fetch dynamic models from platform {}: {}", platformType, e.getMessage());
+ }
+
+ log.info("No dynamic models for platform {}, fallback to default models.", platformType);
+ return java.util.Collections.emptyList();
+ }
}
diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManager.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManager.java
new file mode 100644
index 000000000..90ca8661c
--- /dev/null
+++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManager.java
@@ -0,0 +1,441 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.bigtop.manager.ai.config;
+
+import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpSseClientProperties;
+import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStdioClientProperties;
+import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStreamableHttpClientProperties;
+import org.springframework.beans.factory.ObjectProvider;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Component;
+import org.springframework.util.StringUtils;
+import org.springframework.web.reactive.function.client.WebClient;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.client.McpAsyncClient;
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.transport.ServerParameters;
+import io.modelcontextprotocol.client.transport.StdioClientTransport;
+import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
+import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
+import io.modelcontextprotocol.json.McpJsonMapper;
+import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper;
+import io.modelcontextprotocol.spec.McpClientTransport;
+import io.modelcontextprotocol.spec.McpSchema;
+import lombok.extern.slf4j.Slf4j;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+@Component
+@Slf4j
+public class McpAsyncClientManager {
+
+ @Value("${spring.ai.mcp.client.enabled:false}")
+ private boolean enabled;
+
+ @Value("${spring.ai.mcp.client.request-timeout-seconds:120}")
+ private long requestTimeoutSeconds;
+
+ @Value("${spring.ai.mcp.client.init-timeout-seconds:10}")
+ private long initTimeoutSeconds;
+
+ @Value("${spring.ai.mcp.client.connections:}")
+ private String connections;
+
+ private final ObjectMapper objectMapper = new ObjectMapper();
+ private final McpJsonMapper mcpJsonMapper = new JacksonMcpJsonMapper(objectMapper);
+
+ private final ObjectProvider sseClientPropertiesProvider;
+ private final ObjectProvider streamableHttpClientPropertiesProvider;
+ private final ObjectProvider stdioClientPropertiesProvider;
+
+ private List clients = Collections.emptyList();
+ private boolean initialized = false;
+ private boolean disabledLogged = false;
+
+ public McpAsyncClientManager(
+ ObjectProvider sseClientPropertiesProvider,
+ ObjectProvider streamableHttpClientPropertiesProvider,
+ ObjectProvider stdioClientPropertiesProvider) {
+ this.sseClientPropertiesProvider = sseClientPropertiesProvider;
+ this.streamableHttpClientPropertiesProvider = streamableHttpClientPropertiesProvider;
+ this.stdioClientPropertiesProvider = stdioClientPropertiesProvider;
+ }
+
+ public synchronized McpAsyncClient getClient() {
+ List allClients = getClients();
+ if (allClients.isEmpty()) {
+ return null;
+ }
+ return allClients.get(0);
+ }
+
+ public synchronized List getClients() {
+ if (!enabled) {
+ if (!disabledLogged) {
+ log.info("MCP Async Client is disabled by config (spring.ai.mcp.client.enabled=false)");
+ disabledLogged = true;
+ }
+ return Collections.emptyList();
+ }
+
+ if (!initialized) {
+ initialized = true;
+ clients = initializeClients();
+ log.info("MCP Async Clients initialized: {}", clients.size());
+ }
+
+ return clients;
+ }
+
+ private List initializeClients() {
+ List parsedConnections = parseConnections();
+
+ if (parsedConnections.isEmpty()) {
+ parsedConnections.addAll(fromSpringMcpProperties());
+ }
+
+ if (parsedConnections.isEmpty()) {
+ log.warn(
+ "No MCP connections configured. Please use spring.ai.mcp.client.connections or spring.ai.mcp.client..connections.");
+ return Collections.emptyList();
+ }
+
+ List initializedClients = new ArrayList<>();
+ for (McpConnection connection : parsedConnections) {
+ McpAsyncClient client = initializeClient(connection);
+ if (client != null) {
+ initializedClients.add(client);
+ }
+ }
+ return initializedClients;
+ }
+
+ private List fromSpringMcpProperties() {
+ List conns = new ArrayList<>();
+
+ McpSseClientProperties sseProperties = sseClientPropertiesProvider.getIfAvailable();
+ if (sseProperties != null && sseProperties.getConnections() != null) {
+ sseProperties.getConnections().forEach((name, params) -> {
+ if (params != null) {
+ conns.add(McpConnection.sse(name, params.url(), params.sseEndpoint()));
+ }
+ });
+ }
+
+ McpStreamableHttpClientProperties streamableHttpProperties =
+ streamableHttpClientPropertiesProvider.getIfAvailable();
+ if (streamableHttpProperties != null && streamableHttpProperties.getConnections() != null) {
+ streamableHttpProperties.getConnections().forEach((name, params) -> {
+ if (params != null) {
+ conns.add(McpConnection.streamableHttp(name, params.url(), params.endpoint()));
+ }
+ });
+ }
+
+ McpStdioClientProperties stdioProperties = stdioClientPropertiesProvider.getIfAvailable();
+ if (stdioProperties != null && stdioProperties.getConnections() != null) {
+ stdioProperties.getConnections().forEach((name, params) -> {
+ if (params != null) {
+ conns.add(McpConnection.local(name, params.command(), params.args(), params.env()));
+ }
+ });
+ }
+
+ return conns;
+ }
+
+ private McpAsyncClient initializeClient(McpConnection connection) {
+ try {
+ McpClientTransport transport = buildTransport(connection);
+ if (transport == null) {
+ log.warn("Skip MCP connection {} due to unsupported type {}", connection.name, connection.type);
+ return null;
+ }
+
+ Duration requestTimeout = Duration.ofSeconds(Math.max(
+ connection.requestTimeoutSeconds > 0 ? connection.requestTimeoutSeconds : requestTimeoutSeconds,
+ 1));
+ Duration initTimeout = Duration.ofSeconds(Math.max(
+ connection.initTimeoutSeconds > 0 ? connection.initTimeoutSeconds : initTimeoutSeconds, 1));
+
+ McpAsyncClient client =
+ McpClient.async(transport).requestTimeout(requestTimeout).build();
+
+ client.initialize().block(initTimeout);
+ client.ping().block(initTimeout);
+
+ log.info(
+ "MCP Async Client [{}] initialized, type={}, baseUrl={}, endpoint={}, requestTimeout={}, initTimeout={}",
+ connection.name,
+ connection.type,
+ connection.baseUrl,
+ connection.endpoint,
+ requestTimeout,
+ initTimeout);
+ logRegisteredTools(connection.name, client, initTimeout);
+ return client;
+ } catch (Exception e) {
+ log.warn("Failed to initialize MCP client [{}]: {}", connection.name, e.getMessage());
+ return null;
+ }
+ }
+
+ private McpClientTransport buildTransport(McpConnection connection) {
+ String type = connection.type.toLowerCase();
+ return switch (type) {
+ case "sse" -> buildWebFluxSseTransport(connection);
+ case "streamable-http", "http", "http-streamable" -> buildWebFluxStreamableHttpTransport(connection);
+ case "local", "stdio" -> buildStdioTransport(connection);
+ default -> null;
+ };
+ }
+
+ private McpClientTransport buildWebFluxSseTransport(McpConnection connection) {
+ String baseUrl = connection.baseUrl;
+ if (!StringUtils.hasText(baseUrl)) {
+ log.warn("MCP SSE connection {} missing baseUrl", connection.name);
+ return null;
+ }
+ String endpoint = StringUtils.hasText(connection.endpoint) ? connection.endpoint : "/mcp/sse";
+ WebClient.Builder webClientBuilder = WebClient.builder().baseUrl(baseUrl);
+ return WebFluxSseClientTransport.builder(webClientBuilder)
+ .jsonMapper(mcpJsonMapper)
+ .sseEndpoint(endpoint)
+ .build();
+ }
+
+ private McpClientTransport buildWebFluxStreamableHttpTransport(McpConnection connection) {
+ String baseUrl = connection.baseUrl;
+ if (!StringUtils.hasText(baseUrl)) {
+ log.warn("MCP streamable-http connection {} missing baseUrl", connection.name);
+ return null;
+ }
+ String endpoint = StringUtils.hasText(connection.endpoint) ? connection.endpoint : "/mcp";
+ WebClient.Builder webClientBuilder = WebClient.builder().baseUrl(baseUrl);
+ return WebClientStreamableHttpTransport.builder(webClientBuilder)
+ .jsonMapper(mcpJsonMapper)
+ .endpoint(endpoint)
+ .openConnectionOnStartup(true)
+ .build();
+ }
+
+ private McpClientTransport buildStdioTransport(McpConnection connection) {
+ if (!StringUtils.hasText(connection.command)) {
+ log.warn("MCP stdio/local connection {} missing command", connection.name);
+ return null;
+ }
+
+ ServerParameters.Builder serverBuilder = ServerParameters.builder(connection.command);
+ if (!connection.args.isEmpty()) {
+ serverBuilder.args(connection.args);
+ }
+ if (!connection.env.isEmpty()) {
+ serverBuilder.env(connection.env);
+ }
+
+ StdioClientTransport stdioTransport = new StdioClientTransport(serverBuilder.build(), mcpJsonMapper);
+ stdioTransport.setStdErrorHandler(line -> log.warn("MCP stdio [{}] stderr: {}", connection.name, line));
+ return stdioTransport;
+ }
+
+ private void logRegisteredTools(String connectionName, McpAsyncClient client, Duration initTimeout) {
+ try {
+ List toolNames = new ArrayList<>();
+ String cursor = null;
+ while (true) {
+ McpSchema.ListToolsResult listToolsResult = cursor == null || cursor.isBlank()
+ ? client.listTools().block(initTimeout)
+ : client.listTools(cursor).block(initTimeout);
+ if (listToolsResult == null
+ || listToolsResult.tools() == null
+ || listToolsResult.tools().isEmpty()) {
+ break;
+ }
+
+ for (McpSchema.Tool tool : listToolsResult.tools()) {
+ if (tool != null && tool.name() != null) {
+ toolNames.add(tool.name());
+ }
+ }
+
+ String nextCursor = listToolsResult.nextCursor();
+ if (!StringUtils.hasText(nextCursor) || nextCursor.equals(cursor)) {
+ break;
+ }
+ cursor = nextCursor;
+ }
+
+ log.info("MCP tools discovered for [{}]: count={}, names={}", connectionName, toolNames.size(), toolNames);
+ } catch (Exception e) {
+ log.warn("Failed to list MCP tools for [{}]: {}", connectionName, e.getMessage());
+ }
+ }
+
+ private List parseConnections() {
+ if (!StringUtils.hasText(connections)) {
+ return new ArrayList<>();
+ }
+
+ List parsed = new ArrayList<>();
+ String[] segments = connections.split(";");
+ for (int i = 0; i < segments.length; i++) {
+ String segment = segments[i].trim();
+ if (!StringUtils.hasText(segment)) {
+ continue;
+ }
+
+ Map kv = new HashMap<>();
+ String[] pairs = segment.split(",");
+ for (String pair : pairs) {
+ String[] items = pair.split("=", 2);
+ if (items.length != 2) {
+ continue;
+ }
+ kv.put(items[0].trim().toLowerCase(), items[1].trim());
+ }
+
+ String name = valueOrDefault(kv.get("name"), "conn-" + (i + 1));
+ String type = valueOrDefault(kv.get("type"), "sse");
+ String baseUrl = kv.get("baseurl");
+ String endpoint = kv.get("endpoint");
+ String command = kv.get("command");
+ List args = parseList(kv.get("args"));
+ Map env = parseEnv(kv.get("env"));
+ long connectionRequestTimeout = parseLongOrDefault(kv.get("requesttimeoutseconds"), -1);
+ long connectionInitTimeout = parseLongOrDefault(kv.get("inittimeoutseconds"), -1);
+
+ parsed.add(new McpConnection(
+ name,
+ type,
+ baseUrl,
+ endpoint,
+ command,
+ args,
+ env,
+ connectionRequestTimeout,
+ connectionInitTimeout));
+ }
+
+ return parsed;
+ }
+
+ private static long parseLongOrDefault(String raw, long defaultValue) {
+ if (!StringUtils.hasText(raw)) {
+ return defaultValue;
+ }
+ try {
+ return Long.parseLong(raw);
+ } catch (NumberFormatException e) {
+ return defaultValue;
+ }
+ }
+
+ private static List parseList(String raw) {
+ if (!StringUtils.hasText(raw)) {
+ return Collections.emptyList();
+ }
+ List list = new ArrayList<>();
+ for (String item : raw.split("\\|")) {
+ String trimmed = item.trim();
+ if (StringUtils.hasText(trimmed)) {
+ list.add(trimmed);
+ }
+ }
+ return list;
+ }
+
+ private static Map parseEnv(String raw) {
+ if (!StringUtils.hasText(raw)) {
+ return Collections.emptyMap();
+ }
+ Map env = new HashMap<>();
+ for (String item : raw.split("\\|")) {
+ String[] pair = item.split(":", 2);
+ if (pair.length == 2 && StringUtils.hasText(pair[0])) {
+ env.put(pair[0].trim(), pair[1].trim());
+ }
+ }
+ return env;
+ }
+
+ private static String valueOrDefault(String value, String defaultValue) {
+ return StringUtils.hasText(value) ? value : defaultValue;
+ }
+
+ private static class McpConnection {
+ private final String name;
+ private final String type;
+ private final String baseUrl;
+ private final String endpoint;
+ private final String command;
+ private final List args;
+ private final Map env;
+ private final long requestTimeoutSeconds;
+ private final long initTimeoutSeconds;
+
+ private McpConnection(
+ String name,
+ String type,
+ String baseUrl,
+ String endpoint,
+ String command,
+ List args,
+ Map env,
+ long requestTimeoutSeconds,
+ long initTimeoutSeconds) {
+ this.name = name;
+ this.type = type;
+ this.baseUrl = baseUrl;
+ this.endpoint = endpoint;
+ this.command = command;
+ this.args = args == null ? Collections.emptyList() : args;
+ this.env = env == null ? Collections.emptyMap() : env;
+ this.requestTimeoutSeconds = requestTimeoutSeconds;
+ this.initTimeoutSeconds = initTimeoutSeconds;
+ }
+
+ private static McpConnection sse(String name, String baseUrl, String endpoint) {
+ return new McpConnection(
+ name, "sse", baseUrl, endpoint, null, Collections.emptyList(), Collections.emptyMap(), -1, -1);
+ }
+
+ private static McpConnection streamableHttp(String name, String baseUrl, String endpoint) {
+ return new McpConnection(
+ name,
+ "streamable-http",
+ baseUrl,
+ endpoint,
+ null,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ -1,
+ -1);
+ }
+
+ private static McpConnection local(String name, String command, List args, Map env) {
+ return new McpConnection(name, "local", null, null, command, args, env, -1, -1);
+ }
+ }
+}
diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java
index 30ec7ab62..39c52b069 100644
--- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java
+++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java
@@ -24,9 +24,18 @@
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
+import org.springframework.web.reactive.function.client.WebClient;
+import com.fasterxml.jackson.databind.JsonNode;
import reactor.core.publisher.Flux;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+
public abstract class AbstractAIAssistant implements AIAssistant {
protected final AIAssistant.Service aiServices;
protected static final Integer MEMORY_LEN = 10;
@@ -67,6 +76,12 @@ public abstract static class Builder implements AIAssistant.Builder {
protected String systemPrompt;
+ protected io.modelcontextprotocol.client.McpAsyncClient mcpAsyncClient;
+ protected List mcpAsyncClients = new ArrayList<>();
+ protected Consumer toolExecutionListener;
+
+ private static final String AUTHORIZATION_HEADER = "Authorization";
+
public Builder() {}
public Builder withSystemPrompt(String systemPrompt) {
@@ -74,6 +89,12 @@ public Builder withSystemPrompt(String systemPrompt) {
return this;
}
+ @Override
+ public Builder withToolExecutionListener(Consumer toolExecutionListener) {
+ this.toolExecutionListener = toolExecutionListener;
+ return this;
+ }
+
public Builder withConfig(AIAssistantConfig config) {
this.config = config;
return this;
@@ -89,6 +110,44 @@ public Builder memoryStore(ChatMemory chatMemory) {
return this;
}
+ public Builder withMcpClient(io.modelcontextprotocol.client.McpAsyncClient mcpAsyncClient) {
+ this.mcpAsyncClient = mcpAsyncClient;
+ if (mcpAsyncClient != null) {
+ this.mcpAsyncClients = List.of(mcpAsyncClient);
+ }
+ return this;
+ }
+
+ @Override
+ public Builder withMcpClients(List mcpAsyncClients) {
+ if (mcpAsyncClients == null || mcpAsyncClients.isEmpty()) {
+ this.mcpAsyncClients = Collections.emptyList();
+ this.mcpAsyncClient = null;
+ return this;
+ }
+
+ this.mcpAsyncClients = List.copyOf(mcpAsyncClients);
+ this.mcpAsyncClient = this.mcpAsyncClients.get(0);
+ return this;
+ }
+
+ protected List getMcpAsyncClients() {
+ if (mcpAsyncClients != null && !mcpAsyncClients.isEmpty()) {
+ return mcpAsyncClients;
+ }
+ if (mcpAsyncClient != null) {
+ return List.of(mcpAsyncClient);
+ }
+ return Collections.emptyList();
+ }
+
+ protected void emitToolExecutionEvent(String executionId, String toolName, String status, String payload) {
+ if (toolExecutionListener != null) {
+ toolExecutionListener.accept(
+ new AIAssistant.ToolExecutionEvent(executionId, toolName, status, payload));
+ }
+ }
+
public ChatMemory getChatMemory() {
if (chatMemory == null) {
chatMemory = MessageWindowChatMemory.builder()
@@ -97,5 +156,81 @@ public ChatMemory getChatMemory() {
}
return chatMemory;
}
+
+ protected String resolveModelsBaseUrl() {
+ return null;
+ }
+
+ protected String resolveModelsPath() {
+ return "/v1/models";
+ }
+
+ protected String resolveApiKey(Map credentials) {
+ if (credentials == null) {
+ return null;
+ }
+ String apiKey = credentials.get("apiKey");
+ if (apiKey == null) {
+ return null;
+ }
+ apiKey = apiKey.trim();
+ if (apiKey.startsWith("Bearer ")) {
+ apiKey = apiKey.substring("Bearer ".length()).trim();
+ }
+ return apiKey;
+ }
+
+ protected void applyModelRequestAuth(WebClient.RequestHeadersSpec> requestSpec, String apiKey) {
+ if (apiKey != null && !apiKey.isBlank()) {
+ requestSpec.header(AUTHORIZATION_HEADER, "Bearer " + apiKey);
+ }
+ }
+
+ protected List parseModelsResponse(JsonNode response) {
+ if (response == null || !response.has("data")) {
+ return Collections.emptyList();
+ }
+ List models = new ArrayList<>();
+ for (JsonNode node : response.get("data")) {
+ JsonNode idNode = node.get("id");
+ if (idNode != null && !idNode.isNull()) {
+ models.add(idNode.asText());
+ }
+ }
+ return models;
+ }
+
+ @Override
+ public List getModels() {
+ String baseUrl = resolveModelsBaseUrl();
+ if (baseUrl == null || baseUrl.isBlank()) {
+ return Collections.emptyList();
+ }
+
+ String path = resolveModelsPath();
+ if (path == null || path.isBlank()) {
+ path = "/v1/models";
+ }
+
+ Map credentials = config == null ? Collections.emptyMap() : config.getCredentials();
+ String apiKey = resolveApiKey(credentials);
+
+ try {
+ WebClient webClient =
+ WebClient.builder().baseUrl(baseUrl.trim()).build();
+ WebClient.RequestHeadersSpec> requestSpec = webClient.get().uri(path);
+ applyModelRequestAuth(requestSpec, apiKey);
+
+ JsonNode response = requestSpec
+ .retrieve()
+ .bodyToMono(JsonNode.class)
+ .timeout(Duration.ofSeconds(10))
+ .block();
+
+ return parseModelsResponse(response);
+ } catch (Exception ignored) {
+ return Collections.emptyList();
+ }
+ }
}
}
diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java
index ff2b96f5c..0898dcbb0 100644
--- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java
+++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java
@@ -27,6 +27,9 @@
import reactor.core.publisher.Flux;
+import java.util.List;
+import java.util.function.Consumer;
+
public interface AIAssistant {
/**
@@ -75,8 +78,14 @@ interface Builder {
Builder withConfig(AIAssistantConfig configProvider);
+ Builder withMcpClient(io.modelcontextprotocol.client.McpAsyncClient mcpAsyncClient);
+
+ Builder withMcpClients(List mcpAsyncClients);
+
Builder withSystemPrompt(String systemPrompt);
+ Builder withToolExecutionListener(Consumer toolExecutionListener);
+
AIAssistant build();
ChatModel getChatModel();
@@ -84,5 +93,9 @@ interface Builder {
StreamingChatModel getStreamingChatModel();
ChatMemory getChatMemory();
+
+ List getModels();
}
+
+ record ToolExecutionEvent(String executionId, String toolName, String status, String payload) {}
}
diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java
index 06e1dafef..d97f40081 100644
--- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java
+++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java
@@ -21,13 +21,30 @@
import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
+import java.util.List;
+import java.util.function.Consumer;
+
public interface AIAssistantFactory {
- AIAssistant createWithPrompt(AIAssistantConfig config, Object toolProvider, SystemPrompt systemPrompt);
+ AIAssistant createWithPrompt(AIAssistantConfig config, SystemPrompt systemPrompt);
+
+ AIAssistant createForTest(AIAssistantConfig config);
+
+ default AIAssistant createAIService(AIAssistantConfig config) {
+ return createWithPrompt(config, SystemPrompt.DEFAULT_PROMPT);
+ }
- AIAssistant createForTest(AIAssistantConfig config, Object toolProvider);
+ default AIAssistant createAIService(
+ AIAssistantConfig config, Consumer toolExecutionListener) {
+ return createWithPrompt(config, SystemPrompt.DEFAULT_PROMPT, toolExecutionListener);
+ }
- default AIAssistant createAIService(AIAssistantConfig config, Object toolProvider) {
- return createWithPrompt(config, toolProvider, SystemPrompt.DEFAULT_PROMPT);
+ default AIAssistant createWithPrompt(
+ AIAssistantConfig config,
+ SystemPrompt systemPrompt,
+ Consumer toolExecutionListener) {
+ return createWithPrompt(config, systemPrompt);
}
+
+ List getModels(AIAssistantConfig config);
}
diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DashScopeAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DashScopeAssistant.java
index 6661e6e02..b4053ea4c 100644
--- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DashScopeAssistant.java
+++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DashScopeAssistant.java
@@ -29,18 +29,23 @@
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.tool.ToolCallback;
+import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
+import java.util.UUID;
public class DashScopeAssistant extends AbstractAIAssistant {
+ private static final String BASE_URL_ENV_KEY = "BIGTOP_MANAGER_AI_DASHSCOPE_BASE_URL";
private static final String BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode";
public DashScopeAssistant(Object memoryId, ChatMemory chatMemory, AIAssistant.Service aiServices) {
@@ -58,6 +63,19 @@ public static Builder builder() {
public static class Builder extends AbstractAIAssistant.Builder {
+ @Override
+ protected String resolveModelsBaseUrl() {
+ return resolveDefaultBaseUrl();
+ }
+
+ private String resolveDefaultBaseUrl() {
+ String envBaseUrl = System.getenv(BASE_URL_ENV_KEY);
+ if (envBaseUrl != null && !envBaseUrl.isBlank()) {
+ return envBaseUrl;
+ }
+ return BASE_URL;
+ }
+
@Override
public ChatModel getChatModel() {
String model = config.getModel();
@@ -65,9 +83,17 @@ public ChatModel getChatModel() {
String apiKey = config.getCredentials().get("apiKey");
Assert.notNull(apiKey, "apiKey must not be null");
- OpenAiApi openAiApi =
- OpenAiApi.builder().baseUrl(BASE_URL).apiKey(apiKey).build();
- OpenAiChatOptions options = OpenAiChatOptions.builder().model(model).build();
+ OpenAiApi openAiApi = OpenAiApi.builder()
+ .baseUrl(resolveDefaultBaseUrl())
+ .apiKey(apiKey)
+ .build();
+ OpenAiChatOptions.Builder optionsBuilder =
+ OpenAiChatOptions.builder().model(model);
+ List mcpClients = getMcpAsyncClients();
+ if (!mcpClients.isEmpty()) {
+ optionsBuilder.toolCallbacks(buildObservedToolCallbacks(mcpClients));
+ }
+ OpenAiChatOptions options = optionsBuilder.build();
return OpenAiChatModel.builder()
.openAiApi(openAiApi)
.defaultOptions(options)
@@ -80,6 +106,45 @@ public StreamingChatModel getStreamingChatModel() {
return getChatModel();
}
+ private ToolCallback[] buildObservedToolCallbacks(
+ List mcpClients) {
+ ToolCallback[] callbacks = new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks();
+ ToolCallback[] observedCallbacks = new ToolCallback[callbacks.length];
+ for (int i = 0; i < callbacks.length; i++) {
+ observedCallbacks[i] = wrapToolCallback(callbacks[i]);
+ }
+ return observedCallbacks;
+ }
+
+ private ToolCallback wrapToolCallback(ToolCallback delegate) {
+ return new ToolCallback() {
+ @Override
+ public ToolDefinition getToolDefinition() {
+ return delegate.getToolDefinition();
+ }
+
+ @Override
+ public String call(String toolInput) {
+ return call(toolInput, null);
+ }
+
+ @Override
+ public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) {
+ String toolName = getToolDefinition().name();
+ String executionId = UUID.randomUUID().toString();
+ emitToolExecutionEvent(executionId, toolName, "started", toolInput);
+ try {
+ String result = delegate.call(toolInput, toolContext);
+ emitToolExecutionEvent(executionId, toolName, "completed", result);
+ return result;
+ } catch (Exception e) {
+ emitToolExecutionEvent(executionId, toolName, "failed", e.getMessage());
+ throw e;
+ }
+ }
+ };
+ }
+
public AIAssistant build() {
ChatModel chatModel = getChatModel();
StreamingChatModel streamingChatModel = getStreamingChatModel();
@@ -132,13 +197,18 @@ public Flux streamChat(String userMessage) {
StringBuilder responseBuilder = new StringBuilder();
return streamingChatModel.stream(prompt)
- .map(chatResponse -> {
- String content =
- chatResponse.getResult().getOutput().getText();
- if (content != null) {
+ .concatMap(chatResponse -> {
+ String content = null;
+ if (chatResponse.getResult() != null
+ && chatResponse.getResult().getOutput() != null) {
+ content =
+ chatResponse.getResult().getOutput().getText();
+ }
+ if (content != null && !content.isEmpty()) {
responseBuilder.append(content);
+ return Flux.just(content);
}
- return content;
+ return Flux.empty();
})
.doOnComplete(() -> {
// Save to memory when streaming completes
diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DeepSeekAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DeepSeekAssistant.java
index 962ad8258..17f3286cd 100644
--- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DeepSeekAssistant.java
+++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/DeepSeekAssistant.java
@@ -33,15 +33,22 @@
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.deepseek.DeepSeekChatOptions;
import org.springframework.ai.deepseek.api.DeepSeekApi;
+import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
+import org.springframework.ai.tool.ToolCallback;
+import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
+import java.util.UUID;
public class DeepSeekAssistant extends AbstractAIAssistant {
+ private static final String BASE_URL_ENV_KEY = "BIGTOP_MANAGER_AI_DEEPSEEK_BASE_URL";
+ private static final String BASE_URL = "https://api.deepseek.com";
+
public DeepSeekAssistant(Object memoryId, ChatMemory chatMemory, AIAssistant.Service aiServices) {
super(memoryId, chatMemory, aiServices);
}
@@ -57,6 +64,19 @@ public static Builder builder() {
public static class Builder extends AbstractAIAssistant.Builder {
+ @Override
+ protected String resolveModelsBaseUrl() {
+ return resolveDefaultBaseUrl();
+ }
+
+ private String resolveDefaultBaseUrl() {
+ String envBaseUrl = System.getenv(BASE_URL_ENV_KEY);
+ if (envBaseUrl != null && !envBaseUrl.isBlank()) {
+ return envBaseUrl;
+ }
+ return BASE_URL;
+ }
+
@Override
public ChatModel getChatModel() {
String model = config.getModel();
@@ -64,9 +84,17 @@ public ChatModel getChatModel() {
String apiKey = config.getCredentials().get("apiKey");
Assert.notNull(apiKey, "apiKey must not be null");
- DeepSeekApi deepSeekApi = DeepSeekApi.builder().apiKey(apiKey).build();
- DeepSeekChatOptions options =
- DeepSeekChatOptions.builder().model(model).build();
+ DeepSeekApi deepSeekApi = DeepSeekApi.builder()
+ .baseUrl(resolveDefaultBaseUrl())
+ .apiKey(apiKey)
+ .build();
+ DeepSeekChatOptions.Builder optionsBuilder =
+ DeepSeekChatOptions.builder().model(model);
+ List mcpClients = getMcpAsyncClients();
+ if (!mcpClients.isEmpty()) {
+ optionsBuilder.toolCallbacks(buildObservedToolCallbacks(mcpClients));
+ }
+ DeepSeekChatOptions options = optionsBuilder.build();
return DeepSeekChatModel.builder()
.deepSeekApi(deepSeekApi)
.defaultOptions(options)
@@ -79,6 +107,45 @@ public StreamingChatModel getStreamingChatModel() {
return getChatModel();
}
+ private ToolCallback[] buildObservedToolCallbacks(
+ List mcpClients) {
+ ToolCallback[] callbacks = new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks();
+ ToolCallback[] observedCallbacks = new ToolCallback[callbacks.length];
+ for (int i = 0; i < callbacks.length; i++) {
+ observedCallbacks[i] = wrapToolCallback(callbacks[i]);
+ }
+ return observedCallbacks;
+ }
+
+ private ToolCallback wrapToolCallback(ToolCallback delegate) {
+ return new ToolCallback() {
+ @Override
+ public ToolDefinition getToolDefinition() {
+ return delegate.getToolDefinition();
+ }
+
+ @Override
+ public String call(String toolInput) {
+ return call(toolInput, null);
+ }
+
+ @Override
+ public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) {
+ String toolName = getToolDefinition().name();
+ String executionId = UUID.randomUUID().toString();
+ emitToolExecutionEvent(executionId, toolName, "started", toolInput);
+ try {
+ String result = delegate.call(toolInput, toolContext);
+ emitToolExecutionEvent(executionId, toolName, "completed", result);
+ return result;
+ } catch (Exception e) {
+ emitToolExecutionEvent(executionId, toolName, "failed", e.getMessage());
+ throw e;
+ }
+ }
+ };
+ }
+
public AIAssistant build() {
ChatModel chatModel = getChatModel();
StreamingChatModel streamingChatModel = getStreamingChatModel();
@@ -131,13 +198,18 @@ public Flux streamChat(String userMessage) {
StringBuilder responseBuilder = new StringBuilder();
return streamingChatModel.stream(prompt)
- .map(chatResponse -> {
- String content =
- chatResponse.getResult().getOutput().getText();
- if (content != null) {
+ .concatMap(chatResponse -> {
+ String content = null;
+ if (chatResponse.getResult() != null
+ && chatResponse.getResult().getOutput() != null) {
+ content =
+ chatResponse.getResult().getOutput().getText();
+ }
+ if (content != null && !content.isEmpty()) {
responseBuilder.append(content);
+ return Flux.just(content);
}
- return content;
+ return Flux.empty();
})
.doOnComplete(() -> {
// Save to memory when streaming completes
diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/OpenAIAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/OpenAIAssistant.java
index b51fa75ef..c6fa5e2c2 100644
--- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/OpenAIAssistant.java
+++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/OpenAIAssistant.java
@@ -29,18 +29,24 @@
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.tool.ToolCallback;
+import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
+import java.util.UUID;
public class OpenAIAssistant extends AbstractAIAssistant {
+ private static final String BASE_URL_ENV_KEY = "BIGTOP_MANAGER_AI_OPENAI_BASE_URL";
private static final String BASE_URL = "https://api.openai.com";
public OpenAIAssistant(Object memoryId, ChatMemory chatMemory, AIAssistant.Service aiServices) {
@@ -58,6 +64,26 @@ public static Builder builder() {
public static class Builder extends AbstractAIAssistant.Builder {
+ @Override
+ protected String resolveModelsBaseUrl() {
+ Map credentials = config == null ? null : config.getCredentials();
+ if (credentials != null) {
+ String baseUrl = credentials.get("baseUrl");
+ if (baseUrl != null && !baseUrl.isBlank()) {
+ return baseUrl;
+ }
+ }
+ return resolveDefaultBaseUrl();
+ }
+
+ private String resolveDefaultBaseUrl() {
+ String envBaseUrl = System.getenv(BASE_URL_ENV_KEY);
+ if (envBaseUrl != null && !envBaseUrl.isBlank()) {
+ return envBaseUrl;
+ }
+ return BASE_URL;
+ }
+
@Override
public ChatModel getChatModel() {
String model = config.getModel();
@@ -65,9 +91,17 @@ public ChatModel getChatModel() {
String apiKey = config.getCredentials().get("apiKey");
Assert.notNull(apiKey, "apiKey must not be null");
- OpenAiApi openAiApi =
- OpenAiApi.builder().baseUrl(BASE_URL).apiKey(apiKey).build();
- OpenAiChatOptions options = OpenAiChatOptions.builder().model(model).build();
+ OpenAiApi openAiApi = OpenAiApi.builder()
+ .baseUrl(resolveDefaultBaseUrl())
+ .apiKey(apiKey)
+ .build();
+ OpenAiChatOptions.Builder optionsBuilder =
+ OpenAiChatOptions.builder().model(model);
+ List mcpClients = getMcpAsyncClients();
+ if (!mcpClients.isEmpty()) {
+ optionsBuilder.toolCallbacks(buildObservedToolCallbacks(mcpClients));
+ }
+ OpenAiChatOptions options = optionsBuilder.build();
return OpenAiChatModel.builder()
.openAiApi(openAiApi)
.defaultOptions(options)
@@ -80,6 +114,45 @@ public StreamingChatModel getStreamingChatModel() {
return getChatModel();
}
+ private ToolCallback[] buildObservedToolCallbacks(
+ List mcpClients) {
+ ToolCallback[] callbacks = new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks();
+ ToolCallback[] observedCallbacks = new ToolCallback[callbacks.length];
+ for (int i = 0; i < callbacks.length; i++) {
+ observedCallbacks[i] = wrapToolCallback(callbacks[i]);
+ }
+ return observedCallbacks;
+ }
+
+ private ToolCallback wrapToolCallback(ToolCallback delegate) {
+ return new ToolCallback() {
+ @Override
+ public ToolDefinition getToolDefinition() {
+ return delegate.getToolDefinition();
+ }
+
+ @Override
+ public String call(String toolInput) {
+ return call(toolInput, null);
+ }
+
+ @Override
+ public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) {
+ String toolName = getToolDefinition().name();
+ String executionId = UUID.randomUUID().toString();
+ emitToolExecutionEvent(executionId, toolName, "started", toolInput);
+ try {
+ String result = delegate.call(toolInput, toolContext);
+ emitToolExecutionEvent(executionId, toolName, "completed", result);
+ return result;
+ } catch (Exception e) {
+ emitToolExecutionEvent(executionId, toolName, "failed", e.getMessage());
+ throw e;
+ }
+ }
+ };
+ }
+
public AIAssistant build() {
ChatModel chatModel = getChatModel();
StreamingChatModel streamingChatModel = getStreamingChatModel();
@@ -132,13 +205,18 @@ public Flux streamChat(String userMessage) {
StringBuilder responseBuilder = new StringBuilder();
return streamingChatModel.stream(prompt)
- .map(chatResponse -> {
- String content =
- chatResponse.getResult().getOutput().getText();
- if (content != null) {
+ .concatMap(chatResponse -> {
+ String content = null;
+ if (chatResponse.getResult() != null
+ && chatResponse.getResult().getOutput() != null) {
+ content =
+ chatResponse.getResult().getOutput().getText();
+ }
+ if (content != null && !content.isEmpty()) {
responseBuilder.append(content);
+ return Flux.just(content);
}
- return content;
+ return Flux.empty();
})
.doOnComplete(() -> {
// Save to memory when streaming completes
diff --git a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/QianFanAssistant.java b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/QianFanAssistant.java
index 468020eb1..65b1782a1 100644
--- a/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/QianFanAssistant.java
+++ b/bigtop-manager-ai/src/main/java/org/apache/bigtop/manager/ai/platform/QianFanAssistant.java
@@ -29,18 +29,26 @@
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.tool.ToolCallback;
+import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.util.Assert;
+import org.springframework.web.reactive.function.client.WebClient;
+import com.fasterxml.jackson.databind.JsonNode;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
+import java.util.UUID;
public class QianFanAssistant extends AbstractAIAssistant {
+ private static final String BASE_URL_ENV_KEY = "BIGTOP_MANAGER_AI_QIANFAN_BASE_URL";
private static final String BASE_URL = "https://qianfan.baidubce.com";
public QianFanAssistant(Object memoryId, ChatMemory chatMemory, AIAssistant.Service aiServices) {
@@ -58,6 +66,60 @@ public static Builder builder() {
public static class Builder extends AbstractAIAssistant.Builder {
+ @Override
+ protected String resolveModelsBaseUrl() {
+ return resolveDefaultBaseUrl();
+ }
+
+ private String resolveDefaultBaseUrl() {
+ String envBaseUrl = System.getenv(BASE_URL_ENV_KEY);
+ if (envBaseUrl != null && !envBaseUrl.isBlank()) {
+ return envBaseUrl;
+ }
+ return BASE_URL;
+ }
+
+ @Override
+ protected String resolveModelsPath() {
+ return "/v2/chat/models";
+ }
+
+ @Override
+ public List getModels() {
+ String apiKey = resolveApiKey(config == null ? null : config.getCredentials());
+ if (apiKey == null || apiKey.isBlank()) {
+ return Collections.emptyList();
+ }
+
+ try {
+ WebClient webClient =
+ WebClient.builder().baseUrl(resolveModelsBaseUrl()).build();
+ JsonNode response = webClient
+ .get()
+ .uri(resolveModelsPath())
+ .header("Authorization", "Bearer " + apiKey)
+ .retrieve()
+ .bodyToMono(JsonNode.class)
+ .timeout(java.time.Duration.ofSeconds(10))
+ .block();
+
+ if (response == null || !response.has("result")) {
+ return Collections.emptyList();
+ }
+
+ List models = new ArrayList<>();
+ for (JsonNode modelNode : response.get("result")) {
+ JsonNode modelId = modelNode.get("model");
+ if (modelId != null && !modelId.isNull()) {
+ models.add(modelId.asText());
+ }
+ }
+ return models;
+ } catch (Exception ignored) {
+ return Collections.emptyList();
+ }
+ }
+
@Override
public ChatModel getChatModel() {
String model = config.getModel();
@@ -66,11 +128,17 @@ public ChatModel getChatModel() {
Assert.notNull(apiKey, "apiKey must not be null");
OpenAiApi openAiApi = OpenAiApi.builder()
- .baseUrl(BASE_URL)
+ .baseUrl(resolveDefaultBaseUrl())
.completionsPath("/v2/chat/completions")
.apiKey(apiKey)
.build();
- OpenAiChatOptions options = OpenAiChatOptions.builder().model(model).build();
+ OpenAiChatOptions.Builder optionsBuilder =
+ OpenAiChatOptions.builder().model(model);
+ List mcpClients = getMcpAsyncClients();
+ if (!mcpClients.isEmpty()) {
+ optionsBuilder.toolCallbacks(buildObservedToolCallbacks(mcpClients));
+ }
+ OpenAiChatOptions options = optionsBuilder.build();
return OpenAiChatModel.builder()
.openAiApi(openAiApi)
.defaultOptions(options)
@@ -82,6 +150,45 @@ public StreamingChatModel getStreamingChatModel() {
return getChatModel();
}
+ private ToolCallback[] buildObservedToolCallbacks(
+ List mcpClients) {
+ ToolCallback[] callbacks = new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks();
+ ToolCallback[] observedCallbacks = new ToolCallback[callbacks.length];
+ for (int i = 0; i < callbacks.length; i++) {
+ observedCallbacks[i] = wrapToolCallback(callbacks[i]);
+ }
+ return observedCallbacks;
+ }
+
+ private ToolCallback wrapToolCallback(ToolCallback delegate) {
+ return new ToolCallback() {
+ @Override
+ public ToolDefinition getToolDefinition() {
+ return delegate.getToolDefinition();
+ }
+
+ @Override
+ public String call(String toolInput) {
+ return call(toolInput, null);
+ }
+
+ @Override
+ public String call(String toolInput, org.springframework.ai.chat.model.ToolContext toolContext) {
+ String toolName = getToolDefinition().name();
+ String executionId = UUID.randomUUID().toString();
+ emitToolExecutionEvent(executionId, toolName, "started", toolInput);
+ try {
+ String result = delegate.call(toolInput, toolContext);
+ emitToolExecutionEvent(executionId, toolName, "completed", result);
+ return result;
+ } catch (Exception e) {
+ emitToolExecutionEvent(executionId, toolName, "failed", e.getMessage());
+ throw e;
+ }
+ }
+ };
+ }
+
public AIAssistant build() {
ChatModel chatModel = getChatModel();
StreamingChatModel streamingChatModel = getStreamingChatModel();
@@ -134,13 +241,18 @@ public Flux streamChat(String userMessage) {
StringBuilder responseBuilder = new StringBuilder();
return streamingChatModel.stream(prompt)
- .map(chatResponse -> {
- String content =
- chatResponse.getResult().getOutput().getText();
- if (content != null) {
+ .concatMap(chatResponse -> {
+ String content = null;
+ if (chatResponse.getResult() != null
+ && chatResponse.getResult().getOutput() != null) {
+ content =
+ chatResponse.getResult().getOutput().getText();
+ }
+ if (content != null && !content.isEmpty()) {
responseBuilder.append(content);
+ return Flux.just(content);
}
- return content;
+ return Flux.empty();
})
.doOnComplete(() -> {
// Save to memory when streaming completes
diff --git a/bigtop-manager-ai/src/test/java/assistant/GeneralAssistantFactoryTest.java b/bigtop-manager-ai/src/test/java/assistant/GeneralAssistantFactoryTest.java
index e87d2c418..a780a65d5 100644
--- a/bigtop-manager-ai/src/test/java/assistant/GeneralAssistantFactoryTest.java
+++ b/bigtop-manager-ai/src/test/java/assistant/GeneralAssistantFactoryTest.java
@@ -64,8 +64,8 @@ void testCreateAIAssistant() {
try (MockedStatic openAIAssistantMockedStatic = mockStatic(OpenAIAssistant.class)) {
openAIAssistantMockedStatic.when(OpenAIAssistant::builder).thenReturn(mockBuilder);
- generalAssistantFactory.createAIService(assistantConfigProvider, null);
- generalAssistantFactory.createForTest(assistantConfigProvider, null);
+ generalAssistantFactory.createAIService(assistantConfigProvider);
+ generalAssistantFactory.createForTest(assistantConfigProvider);
}
}
}
diff --git a/bigtop-manager-ai/src/test/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManagerTest.java b/bigtop-manager-ai/src/test/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManagerTest.java
new file mode 100644
index 000000000..544f89d9c
--- /dev/null
+++ b/bigtop-manager-ai/src/test/java/org/apache/bigtop/manager/ai/config/McpAsyncClientManagerTest.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.bigtop.manager.ai.config;
+
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpSseClientProperties;
+import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStdioClientProperties;
+import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStreamableHttpClientProperties;
+import org.springframework.beans.factory.ObjectProvider;
+
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class McpAsyncClientManagerTest {
+
+ private static final ObjectProvider SSE_PROVIDER = new EmptyObjectProvider<>();
+ private static final ObjectProvider STREAMABLE_HTTP_PROVIDER =
+ new EmptyObjectProvider<>();
+ private static final ObjectProvider STDIO_PROVIDER = new EmptyObjectProvider<>();
+
+ @Test
+ void resolveBaseUrlShouldPreferConfiguredValue() throws Exception {
+ McpAsyncClientManager manager =
+ new McpAsyncClientManager(SSE_PROVIDER, STREAMABLE_HTTP_PROVIDER, STDIO_PROVIDER);
+ setField(manager, "connections", "name=sseMain,type=sse,baseUrl=http://127.0.0.1:9000,endpoint=/mcp/sse");
+
+ Method parseMethod = McpAsyncClientManager.class.getDeclaredMethod("parseConnections");
+ parseMethod.setAccessible(true);
+ @SuppressWarnings("unchecked")
+ List