diff --git a/.dockerignore b/.dockerignore index 45c1def32..385a6449f 100644 --- a/.dockerignore +++ b/.dockerignore @@ -37,8 +37,6 @@ build/ *.tgz # Backend -backend/assets/* -!backend/assets/test.wav backend/flower_db.sqlite uploads/ test/ @@ -60,4 +58,4 @@ assets/ .Spotlight-V100 .Trashes ehthumbs.db -Thumbs.db \ No newline at end of file +Thumbs.db diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 8a3fbc807..b38b17e56 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -21,7 +21,7 @@ from database.a2a_agent_db import PROTOCOL_JSONRPC from services.memory_config_service import build_memory_context -from services.image_service import get_vlm_model +from services.image_service import get_video_understanding_model, get_vlm_model from database.agent_db import search_agent_info_by_agent_id, query_sub_agents_id_list from database.agent_version_db import query_current_version_no from database.tool_db import search_tools_for_sub_agent @@ -31,13 +31,36 @@ from utils.model_name_utils import add_repo_to_name from utils.prompt_template_utils import get_agent_prompt_template from utils.config_utils import tenant_config_manager, get_model_name_from_config -from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE +from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE, MINIO_DEFAULT_BUCKET from consts.exceptions import ValidationError logger = logging.getLogger("create_agent_info") logger.setLevel(logging.DEBUG) +def _build_internal_s3_url(file: dict) -> str: + """Build a valid S3 URL for internal tools from uploaded file metadata.""" + if not isinstance(file, dict): + return "" + + object_name = str(file.get("object_name") or "").strip().lstrip("/") + if object_name: + bucket = MINIO_DEFAULT_BUCKET or "nexent" + return f"s3://{bucket}/{object_name}" + + url = str(file.get("url") or "").strip() + if not url or url.startswith("blob:") or url.startswith("s3:/blob:"): + return "" + + if url.startswith("s3://"): + return url + + if url.startswith("s3:/"): + return "s3://" + url.replace("s3:/", "", 1).lstrip("/") + + return "s3:/" + url + + def _get_skills_for_template( agent_id: int, tenant_id: str, @@ -532,10 +555,17 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int } elif tool_config.class_name == "AnalyzeImageTool": tool_config.metadata = { + # get_vlm_model reads the first multimodal slot, now shown as image understanding. "vlm_model": get_vlm_model(tenant_id=tenant_id), "storage_client": minio_client, "validate_url_access": lambda urls: validate_urls_access(urls, user_id) } + elif tool_config.class_name in ["AnalyzeAudioTool", "AnalyzeVideoTool"]: + tool_config.metadata = { + "vlm_model": get_video_understanding_model(tenant_id=tenant_id), + "storage_client": minio_client, + "validate_url_access": lambda urls: validate_urls_access(urls, user_id) + } tool_config_list.append(tool_config) @@ -636,10 +666,12 @@ async def join_minio_file_description_to_query( # Collect files from current message first (higher priority) if minio_files and isinstance(minio_files, list): for file in minio_files: - if isinstance(file, dict) and file.get("url") and file.get("name"): - url = file["url"] - if url not in seen_urls: - seen_urls.add(url) + if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")): + s3_url = _build_internal_s3_url(file) + if not s3_url: + continue + if s3_url not in seen_urls: + seen_urls.add(s3_url) all_files.append(file) # Collect files from historical messages (lower priority, already-deduped) @@ -647,10 +679,12 @@ async def join_minio_file_description_to_query( for msg in history: if isinstance(msg, dict) and msg.get("minio_files"): for file in msg["minio_files"]: - if isinstance(file, dict) and file.get("url") and file.get("name"): - url = file["url"] - if url not in seen_urls: - seen_urls.add(url) + if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")): + s3_url = _build_internal_s3_url(file) + if not s3_url: + continue + if s3_url not in seen_urls: + seen_urls.add(s3_url) all_files.append(file) # Enforce file count limit (keep most recent files by truncating from the end) @@ -666,7 +700,7 @@ async def join_minio_file_description_to_query( fixed_overhead = len(prefix) + len(suffix) for i, file in enumerate(all_files): - s3_url = f"s3:/{file['url']}" + s3_url = _build_internal_s3_url(file) presigned_url = file.get("presigned_url", "") # Build description with both URLs @@ -718,8 +752,10 @@ def _format_minio_files_for_content(minio_files: Optional[List[dict]], max_files if i >= max_files: file_lines.append(f" - ... (and {len(minio_files) - max_files} more files)") break - if isinstance(file, dict) and file.get("url") and file.get("name"): - s3_url = f"s3:/{file['url']}" + if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")): + s3_url = _build_internal_s3_url(file) + if not s3_url: + continue presigned_url = file.get("presigned_url", "") if presigned_url: file_lines.append( diff --git a/backend/consts/const.py b/backend/consts/const.py index fdc09c9e7..bbfb1ce41 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -308,6 +308,8 @@ class VectorDatabaseType(str, Enum): "multiEmbedding": "MULTI_EMBEDDING_ID", "rerank": "RERANK_ID", "vlm": "VLM_ID", + "vlm2": "VLM2_ID", + "vlm3": "VLM3_ID", "stt": "STT_ID", "tts": "TTS_ID" } diff --git a/backend/consts/model.py b/backend/consts/model.py index 10dc05231..f17cef5e0 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -177,6 +177,14 @@ class STTModelConfig(BaseModel): accessToken: Optional[str] = None +def _empty_model_config() -> SingleModelConfig: + return SingleModelConfig( + modelName="", + displayName="", + apiConfig=ModelApiConfig(apiKey="", modelUrl="") + ) + + class TTSModelConfig(BaseModel): """TTS model specific configuration with factory, appid, and access token fields""" modelName: str @@ -193,6 +201,8 @@ class ModelConfig(BaseModel): multiEmbedding: SingleModelConfig rerank: SingleModelConfig vlm: SingleModelConfig + vlm2: SingleModelConfig = Field(default_factory=_empty_model_config) + vlm3: SingleModelConfig = Field(default_factory=_empty_model_config) stt: STTModelConfig tts: TTSModelConfig diff --git a/backend/prompts/managed_system_prompt_template_en.yaml b/backend/prompts/managed_system_prompt_template_en.yaml index 1cbe81096..5c2893c39 100644 --- a/backend/prompts/managed_system_prompt_template_en.yaml +++ b/backend/prompts/managed_system_prompt_template_en.yaml @@ -117,7 +117,7 @@ system_prompt: |- → Use **presigned_url** (already includes proxy prefix, format: `http://.../api/nb/v1/file/fetch?presigned_url=...`) Directly use the **presigned_url** field provided in the user's uploaded file info. No need to construct or append anything. 2. **Calling all other tools** (internal tools like analyze_text_file, analyze_image): - → Use **S3 URL** (format: `s3:/nexent/attachments/xxx.pdf`) + → Use **S3 URL** (format: `s3://nexent/attachments/xxx.pdf`) Reason: Internal tools run inside Nexent and can directly access MinIO storage {%- else %} diff --git a/backend/prompts/manager_system_prompt_template_en.yaml b/backend/prompts/manager_system_prompt_template_en.yaml index 464ec5264..8ce58db29 100644 --- a/backend/prompts/manager_system_prompt_template_en.yaml +++ b/backend/prompts/manager_system_prompt_template_en.yaml @@ -119,7 +119,7 @@ system_prompt: |- → Use **Download URL** (format: `https://minio.example.com/...?token=xxx`) Reason: MCP tools run on external services and cannot access internal S3 storage 2. **Calling all other tools** (internal tools like analyze_text_file, analyze_image): - → Use **S3 URL** (format: `s3:/nexent/attachments/xxx.pdf`) + → Use **S3 URL** (format: `s3://nexent/attachments/xxx.pdf`) Reason: Internal tools run inside Nexent and can directly access MinIO storage {%- else %} - No tools are currently available diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 3889c9d58..12b2ebebd 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -6,6 +6,7 @@ dependencies = [ "aiofiles>=0.8.0", "uvicorn>=0.34.0", "fastapi>=0.115.12", + "email-validator>=2.0.0", "aiohttp>=3.8.0", "authlib>=1.3.0", "cryptography>=42.0.0", @@ -16,6 +17,7 @@ dependencies = [ "supabase>=2.18.1", "websocket-client>=1.8.0", "pyyaml>=6.0.2", + "jsonref>=1.1.0", "ruamel-yaml==0.19.1", "redis>=5.0.0", "fastmcp==2.12.0", diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index 82c4f6f0d..c90d707e1 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -1265,7 +1265,7 @@ async def export_agent_by_agent_id(agent_id: int, tenant_id: str, user_id: str) # Check if any tool is KnowledgeBaseSearchTool and set its metadata to empty dict for tool in tool_list: - if tool.class_name in ["KnowledgeBaseSearchTool", "AnalyzeTextFileTool", "AnalyzeImageTool", "DataMateSearchTool"]: + if tool.class_name in ["KnowledgeBaseSearchTool", "AnalyzeTextFileTool", "AnalyzeImageTool", "AnalyzeAudioTool", "AnalyzeVideoTool", "DataMateSearchTool"]: tool.metadata = {} # Get model_id and model display name from agent_info diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py index d2ea8fea6..7feea9452 100644 --- a/backend/services/config_sync_service.py +++ b/backend/services/config_sync_service.py @@ -20,7 +20,7 @@ MODEL_ENGINE_ENABLED, TENANT_NAME ) -from database.model_management_db import get_model_id_by_display_name +from database.model_management_db import get_model_id_by_display_name, get_model_records from utils.config_utils import ( get_env_key, get_model_name_from_config, @@ -31,6 +31,20 @@ logger = logging.getLogger("config_sync_service") +def get_model_id_for_config(model_type: str, display_name: str, tenant_id: str) -> Optional[int]: + if not display_name: + return None + + records = get_model_records( + {"display_name": display_name, "model_type": model_type}, + tenant_id + ) + if records: + return records[0].get("model_id") + + return get_model_id_by_display_name(display_name, tenant_id) + + def handle_model_config(tenant_id: str, user_id: str, config_key: str, model_id: Optional[int], tenant_config_dict: dict) -> None: """ Handle model configuration updates, deletions, and settings operations @@ -98,8 +112,8 @@ async def save_config_impl(config, tenant_id, user_id): model_display_name = model_config.get("displayName") config_key = get_env_key(model_type) + "_ID" - model_id = get_model_id_by_display_name( - model_display_name, tenant_id) + model_id = get_model_id_for_config( + model_type, model_display_name, tenant_id) handle_model_config(tenant_id, user_id, config_key, model_id, tenant_config_dict) diff --git a/backend/services/image_service.py b/backend/services/image_service.py index 8decbd541..8a924e9cc 100644 --- a/backend/services/image_service.py +++ b/backend/services/image_service.py @@ -31,7 +31,11 @@ async def proxy_image_impl(decoded_url: str): def get_vlm_model(tenant_id: str): - # Get the tenant config + """Return the configured image understanding model for AnalyzeImageTool. + + The first multimodal model slot is still stored under MODEL_CONFIG_MAPPING["vlm"] + for compatibility, but it is the user-facing image understanding configuration. + """ vlm_model_config = tenant_config_manager.get_model_config( key=MODEL_CONFIG_MAPPING["vlm"], tenant_id=tenant_id) if not vlm_model_config: @@ -48,3 +52,27 @@ def get_vlm_model(tenant_id: str): max_tokens=512, ssl_verify=vlm_model_config.get("ssl_verify", True), ) + + +def get_image_understanding_model(tenant_id: str): + return get_vlm_model(tenant_id=tenant_id) + + +def get_video_understanding_model(tenant_id: str): + """Return the configured video understanding model for multimodal tools.""" + vlm_model_config = tenant_config_manager.get_model_config( + key=MODEL_CONFIG_MAPPING["vlm3"], tenant_id=tenant_id) + if not vlm_model_config: + return None + return OpenAIVLModel( + observer=MessageObserver(), + model_id=get_model_name_from_config( + vlm_model_config) if vlm_model_config else "", + api_base=vlm_model_config.get("base_url", ""), + api_key=vlm_model_config.get("api_key", ""), + temperature=0.7, + top_p=0.7, + frequency_penalty=0.5, + max_tokens=512, + ssl_verify=vlm_model_config.get("ssl_verify", True), + ) diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index 4324bf12f..58a0af91f 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -134,7 +134,7 @@ async def _perform_connectivity_check( ssl_verify=ssl_verify, ) connectivity = await rerank_model.connectivity_check() - elif model_type == "vlm": + elif model_type in ("vlm", "vlm2", "vlm3"): observer = MessageObserver() set_monitoring_operation("connectivity_check", display_name=display_name) diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 5cf4e381b..8f6d191fd 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -164,6 +164,13 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay tenant_id, provider, model_type) model_list_ids = {model.get("id") for model in model_list} if model_list else set() + existing_model_map = { + add_repo_to_name( + model_repo=model["model_repo"], + model_name=model["model_name"], + ): model + for model in existing_model_list + } # Delete existing models not present for model in existing_model_list: @@ -173,21 +180,20 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay # Create or update new models for model in model_list: + model["model_type"] = model_type _, model_name = split_repo_name( model["id"]) if model.get("id") else ("", "") model_repo, model_name_only = split_repo_name( model.get("id", "")) if model.get("id") else ("", "") model_display_name = add_repo_to_name(model_repo, model_name_only) if model_name: - existing_model_by_display = get_model_by_display_name( - model_display_name, tenant_id) - if existing_model_by_display: + existing_model = existing_model_map.get(model_display_name) + if existing_model: # Check if max_tokens has changed - existing_max_tokens = existing_model_by_display.get( - "max_tokens") + existing_max_tokens = existing_model.get("max_tokens") new_max_tokens = model.get("max_tokens") if new_max_tokens is not None and existing_max_tokens != new_max_tokens: - update_model_record(existing_model_by_display["model_id"], { + update_model_record(existing_model["model_id"], { "max_tokens": new_max_tokens}, user_id) continue diff --git a/backend/services/providers/silicon_provider.py b/backend/services/providers/silicon_provider.py index ea41cc95d..130f2346e 100644 --- a/backend/services/providers/silicon_provider.py +++ b/backend/services/providers/silicon_provider.py @@ -1,4 +1,5 @@ import httpx +import re from typing import Dict, List from consts.const import DEFAULT_LLM_MAX_TOKENS @@ -6,6 +7,62 @@ from services.providers.base import AbstractModelProvider, _classify_provider_error +SILICON_VLM_MODEL_KEYWORDS = ( + "-vl", + "_vl", + "/vl", + ".vl", + "vl-", + "vision", + "visual", + "internvl", + "deepseek-vl", + "deepseekvl", + "glm-4v", + "minicpm-v", + "llava", + "kimi-vl", + "kimi-k2.5", + "kimi-k2.6", + "qvq", + "omni", + "qwen3.5", + "qwen3.6", +) + +SILICON_VLM_METADATA_KEYWORDS = ("image", "video", "vision", "visual") + + +def _contains_silicon_vlm_metadata(value) -> bool: + if isinstance(value, str): + lower_value = value.lower() + return any(keyword in lower_value for keyword in SILICON_VLM_METADATA_KEYWORDS) + if isinstance(value, list): + return any(_contains_silicon_vlm_metadata(item) for item in value) + if isinstance(value, dict): + return any(_contains_silicon_vlm_metadata(item) for item in value.values()) + return False + + +def _is_silicon_vlm_model(model: Dict) -> bool: + if _contains_silicon_vlm_metadata(model): + return True + + model_id = str(model.get("id", "")).lower() + model_name = str(model.get("name", "")).lower() + searchable_text = f"{model_id} {model_name}" + if any(keyword in searchable_text for keyword in SILICON_VLM_MODEL_KEYWORDS): + return True + + return bool(re.search(r"glm-\d+(?:\.\d+)?v", searchable_text)) + + +def _is_silicon_omni_model(model: Dict) -> bool: + model_id = str(model.get("id", "")).lower() + model_name = str(model.get("name", "")).lower() + return "omni" in f"{model_id} {model_name}" + + class SiliconModelProvider(AbstractModelProvider): """Concrete implementation for SiliconFlow provider.""" @@ -25,12 +82,14 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: headers = {"Authorization": f"Bearer {model_api_key}"} + provider_model_type = "vlm" if model_type in ("vlm2", "vlm3") else model_type + # Choose endpoint by model type - if model_type in ("llm", "vlm"): + if provider_model_type in ("llm", "vlm"): silicon_url = f"{SILICON_GET_URL}?sub_type=chat" - elif model_type in ("embedding", "multi_embedding"): + elif provider_model_type in ("embedding", "multi_embedding"): silicon_url = f"{SILICON_GET_URL}?sub_type=embedding" - elif model_type == "rerank": + elif provider_model_type == "rerank": silicon_url = f"{SILICON_GET_URL}?sub_type=reranker" else: silicon_url = SILICON_GET_URL @@ -40,17 +99,22 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: response.raise_for_status() model_list: List[Dict] = response.json()["data"] + if model_type == "vlm3": + model_list = [item for item in model_list if _is_silicon_omni_model(item)] + elif provider_model_type == "vlm": + model_list = [item for item in model_list if _is_silicon_vlm_model(item)] + # Annotate models with canonical fields expected downstream - if model_type in ("llm", "vlm"): + if provider_model_type in ("llm", "vlm"): for item in model_list: item["model_tag"] = "chat" item["model_type"] = model_type item["max_tokens"] = DEFAULT_LLM_MAX_TOKENS - elif model_type in ("embedding", "multi_embedding"): + elif provider_model_type in ("embedding", "multi_embedding"): for item in model_list: item["model_tag"] = "embedding" item["model_type"] = model_type - elif model_type == "rerank": + elif provider_model_type == "rerank": for item in model_list: item["model_tag"] = "rerank" item["model_type"] = model_type diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index c30a6006f..2e868044e 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -38,7 +38,7 @@ from services.file_management_service import get_llm_model, validate_urls_access from services.vectordatabase_service import get_embedding_model_by_index_name, get_rerank_model from database.client import minio_client -from services.image_service import get_vlm_model +from services.image_service import get_video_understanding_model, get_vlm_model from nexent.monitor import set_monitoring_context, set_monitoring_operation from services.vectordatabase_service import get_vector_db_core from utils.langchain_utils import discover_langchain_modules @@ -779,6 +779,7 @@ def _validate_local_tool( if not tenant_id or not user_id: raise ToolExecutionException( f"Tenant ID and User ID are required for {tool_name} validation") + # get_vlm_model reads the first multimodal slot, now shown as image understanding. image_to_text_model = get_vlm_model(tenant_id=tenant_id) vlm_display_name = getattr( image_to_text_model, 'display_name', None) @@ -792,6 +793,23 @@ def _validate_local_tool( 'validate_url_access': lambda urls: validate_urls_access(urls, user_id) } tool_instance = tool_class(**params) + elif tool_name in ["analyze_audio", "analyze_video"]: + if not tenant_id or not user_id: + raise ToolExecutionException( + f"Tenant ID and User ID are required for {tool_name} validation") + video_understanding_model = get_video_understanding_model(tenant_id=tenant_id) + model_display_name = getattr( + video_understanding_model, 'display_name', None) + set_monitoring_context(tenant_id=tenant_id) + set_monitoring_operation( + "tool_validation", display_name=model_display_name) + params = { + **instantiation_params, + 'vlm_model': video_understanding_model, + 'storage_client': minio_client, + 'validate_url_access': lambda urls: validate_urls_access(urls, user_id) + } + tool_instance = tool_class(**params) elif tool_name == "analyze_text_file": if not tenant_id or not user_id: raise ToolExecutionException( diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index cb1c8cd3f..993795c98 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -35,11 +35,17 @@ const TOOLS_REQUIRING_EMBEDDING = [ "knowledge_base_search", ]; -// Tool types that require VLM model -const TOOLS_REQUIRING_VLM = [ +// Tool types that require the image understanding model +const TOOLS_REQUIRING_IMAGE_UNDERSTANDING = [ "analyze_image", ]; +// Tool types that require the video understanding model +const TOOLS_REQUIRING_VIDEO_UNDERSTANDING = [ + "analyze_audio", + "analyze_video", +]; + function getToolKbType( toolName: string ): "knowledge_base_search" | "dify_search" | "datamate_search" | "idata_search" | "haotian_search" | null { @@ -54,9 +60,18 @@ function getToolKbType( /** * Check if a tool requires VLM model but VLM is not available */ -function isToolDisabledDueToVlm(toolName: string, vlmAvailable: boolean): boolean { - if (!TOOLS_REQUIRING_VLM.includes(toolName)) return false; - return !vlmAvailable; +function isToolDisabledDueToVlm( + toolName: string, + imageUnderstandingAvailable: boolean, + videoUnderstandingAvailable: boolean +): boolean { + if (TOOLS_REQUIRING_IMAGE_UNDERSTANDING.includes(toolName)) { + return !imageUnderstandingAvailable; + } + if (TOOLS_REQUIRING_VIDEO_UNDERSTANDING.includes(toolName)) { + return !videoUnderstandingAvailable; + } + return false; } /** @@ -98,7 +113,11 @@ export default function ToolManagement({ // Use tool list hook for data management const { availableTools } = useToolList(); - const { isVlmAvailable, isEmbeddingAvailable } = useConfig(); + const { + isImageUnderstandingAvailable, + isVideoUnderstandingAvailable, + isEmbeddingAvailable, + } = useConfig(); // Prefetch knowledge bases for KB tools const { prefetchKnowledgeBases } = usePrefetchKnowledgeBases(); @@ -358,7 +377,11 @@ export default function ToolManagement({ const isSelected = originalSelectedToolIdsSet.has( tool.id ); - const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); + const isDisabledDueToVlm = isToolDisabledDueToVlm( + tool.name, + isImageUnderstandingAvailable, + isVideoUnderstandingAvailable + ); const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding @@ -463,7 +486,11 @@ export default function ToolManagement({ > {group.tools.map((tool) => { const isSelected = originalSelectedToolIdsSet.has(tool.id); - const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); + const isDisabledDueToVlm = isToolDisabledDueToVlm( + tool.name, + isImageUnderstandingAvailable, + isVideoUnderstandingAvailable + ); const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding diff --git a/frontend/app/[locale]/chat/components/chatAttachment.tsx b/frontend/app/[locale]/chat/components/chatAttachment.tsx index 5c9da8ec9..d12e939cd 100644 --- a/frontend/app/[locale]/chat/components/chatAttachment.tsx +++ b/frontend/app/[locale]/chat/components/chatAttachment.tsx @@ -87,6 +87,14 @@ const getFileIcon = (name: string, contentType?: string) => { return ; } + // Audio and video files are uploaded as regular attachments for multimodal tools. + if (chatConfig.fileIcons.audio.includes(extension) || fileType.startsWith("audio/")) { + return ; + } + if (chatConfig.fileIcons.video.includes(extension) || fileType.startsWith("video/")) { + return ; + } + // Compressed file if (chatConfig.fileIcons.compressed.includes(extension)) { return ; @@ -230,4 +238,4 @@ export function ChatAttachment({ )} ); -} \ No newline at end of file +} diff --git a/frontend/app/[locale]/chat/components/chatInput.tsx b/frontend/app/[locale]/chat/components/chatInput.tsx index 8de0d17eb..bcfc86f6b 100644 --- a/frontend/app/[locale]/chat/components/chatInput.tsx +++ b/frontend/app/[locale]/chat/components/chatInput.tsx @@ -96,10 +96,24 @@ const getFileIcon = (file: File) => { return ; } + if (chatConfig.fileIcons.audio.includes(extension) || fileType.startsWith("audio/")) { + return ; + } + + if (chatConfig.fileIcons.video.includes(extension) || fileType.startsWith("video/")) { + return ; + } + // Default file icon return ; }; +const isSupportedMediaFile = (extension: string, fileType: string) => + fileType.startsWith("audio/") || + fileType.startsWith("video/") || + chatConfig.audioExtensions.includes(extension) || + chatConfig.videoExtensions.includes(extension); + // File limit constants from config const MAX_FILE_COUNT = chatConfig.maxFileCount; const MAX_FILE_SIZE = chatConfig.maxFileSize; @@ -617,8 +631,9 @@ export function ChatInput({ chatConfig.supportedTextExtensions.includes(extension) || file.type === "text/csv" || file.type === "text/plain"; + const isMedia = isSupportedMediaFile(extension, file.type); - if (isImage || isDocument || isSupportedTextFile) { + if (isImage || isDocument || isSupportedTextFile || isMedia) { // Create a preview URL for images const previewUrl = isImage ? URL.createObjectURL(file) : undefined; @@ -899,7 +914,7 @@ export function ChatInput({ id="file-upload-regular" className="hidden" onChange={handleFileUpload} - accept={`image/*,${Object.values(chatConfig.fileIcons).flat().map(ext => `.${ext}`).join(',')}`} + accept={`image/*,audio/*,video/*,${Object.values(chatConfig.fileIcons).flat().map(ext => `.${ext}`).join(',')}`} multiple /> @@ -1026,8 +1041,9 @@ export function ChatInput({ chatConfig.supportedTextExtensions.includes(extension) || fileType === "text/csv" || fileType === "text/plain"; + const isMedia = isSupportedMediaFile(extension, fileType); - return !(isImage || isDocument || isSupportedTextFile); + return !(isImage || isDocument || isSupportedTextFile || isMedia); }); // Regular mode, keep the original rendering logic diff --git a/frontend/app/[locale]/chat/internal/chatInterface.tsx b/frontend/app/[locale]/chat/internal/chatInterface.tsx index 6e0de48b5..0f3c99715 100644 --- a/frontend/app/[locale]/chat/internal/chatInterface.tsx +++ b/frontend/app/[locale]/chat/internal/chatInterface.tsx @@ -38,7 +38,7 @@ import { extractAssistantMsgFromResponse, } from "@/lib/chatMessageExtractor"; -import { Layout } from "antd"; +import { Layout, message } from "antd"; import log from "@/lib/logger"; const stepIdCounter = { current: 0 }; @@ -268,9 +268,23 @@ export function ChatInterface() { // Use preprocessing function to upload attachments const uploadResult = await uploadAttachments(attachments, t); + if (uploadResult.error) { + message.error(`${t("chatPreprocess.fileUploadFailed")} ${uploadResult.error}`); + setIsLoading(false); + return; + } uploadedFileUrls = uploadResult.uploadedFileUrls; objectNames = uploadResult.objectNames; // Get object name mapping presignedUrls = uploadResult.presignedUrls; // Get presigned URLs for external access + + const missingUploads = attachments.filter( + (attachment) => !uploadedFileUrls[attachment.file.name] || !objectNames[attachment.file.name] + ); + if (missingUploads.length > 0) { + message.error(`${t("chatPreprocess.fileUploadFailed")} ${missingUploads.map((item) => item.file.name).join(", ")}`); + setIsLoading(false); + return; + } } // Use preprocessing function to create message attachments diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 3bff1101a..acfedb120 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -68,6 +68,23 @@ const DEFAULT_FORM_STATE = { ttsProvider: "dashscope", // ali or volcengine }; +const resolveConnectivityModelType = (type: ModelType): ModelType => + type === MODEL_TYPES.VLM2 || type === MODEL_TYPES.VLM3 + ? (MODEL_TYPES.VLM as ModelType) + : type; + +const resolveConfigKey = (type: ModelType): string => + type; + +const isVlmConfigType = (type: ModelType): boolean => + type === MODEL_TYPES.VLM || type === MODEL_TYPES.VLM2 || type === MODEL_TYPES.VLM3; + +const emptyModelConfig = { + modelName: "", + displayName: "", + apiConfig: { apiKey: "", modelUrl: "" }, +}; + // Connectivity status type comes from utils // Helper function to translate error messages from backend @@ -198,7 +215,7 @@ export const ModelAddDialog = ({ }: ModelAddDialogProps) => { const { t } = useTranslation(); const { message } = App.useApp(); - const { updateModelConfig, saveConfig } = useConfig(); + const { modelConfig: currentModelConfig, updateModelConfig, saveConfig } = useConfig(); // Parse backend error message and return i18n key with params const parseModelError = ( @@ -490,7 +507,7 @@ export const ModelAddDialog = ({ const modelType = form.type === MODEL_TYPES.EMBEDDING && form.isMultimodal ? (MODEL_TYPES.MULTI_EMBEDDING as ModelType) - : form.type; + : resolveConnectivityModelType(form.type); let connectivity = false; @@ -653,6 +670,32 @@ export const ModelAddDialog = ({ }); } + if (isVlmConfigType(form.type) && enabledModels.length > 0) { + const selectedModel = enabledModels[0]; + const selectedDisplayName = selectedModel.displayName || selectedModel.id || ""; + const configKey = resolveConfigKey(form.type); + const vlmConfigUpdate: any = { + [configKey]: { + modelName: selectedModel.id || selectedModel.model_name || "", + displayName: selectedDisplayName, + apiConfig: { + apiKey: form.apiKey, + modelUrl: "", + }, + }, + }; + for (const key of [MODEL_TYPES.VLM, MODEL_TYPES.VLM2, MODEL_TYPES.VLM3]) { + if ( + key !== configKey && + currentModelConfig?.[key]?.displayName === selectedDisplayName + ) { + vlmConfigUpdate[key] = emptyModelConfig; + } + } + updateModelConfig(vlmConfigUpdate); + await persistModelConfig(); + } + // Reset form state and close dialog on success resetForm(); handleClose(); @@ -841,6 +884,7 @@ export const ModelAddDialog = ({ // Update the local storage according to the model type let configUpdate: any = {}; + const configKey = resolveConfigKey(form.type); switch (modelType) { case MODEL_TYPES.LLM: @@ -853,7 +897,17 @@ export const ModelAddDialog = ({ configUpdate = { multiEmbedding: modelConfig }; break; case MODEL_TYPES.VLM: - configUpdate = { vlm: modelConfig }; + case MODEL_TYPES.VLM2: + case MODEL_TYPES.VLM3: + configUpdate = { [configKey]: modelConfig }; + for (const key of [MODEL_TYPES.VLM, MODEL_TYPES.VLM2, MODEL_TYPES.VLM3]) { + if ( + key !== configKey && + currentModelConfig?.[key]?.displayName === modelConfig.displayName + ) { + configUpdate[key] = emptyModelConfig; + } + } break; case MODEL_TYPES.RERANK: configUpdate = { rerank: modelConfig }; @@ -996,7 +1050,15 @@ export const ModelAddDialog = ({ - + + + diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx index 6ba9cf0c3..55c31c2c0 100644 --- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx @@ -99,6 +99,8 @@ export const ModelDeleteDialog = ({ border: "border-purple-100", }; case MODEL_TYPES.VLM: + case MODEL_TYPES.VLM2: + case MODEL_TYPES.VLM3: return { bg: "bg-yellow-50", text: "text-yellow-600", @@ -141,6 +143,8 @@ export const ModelDeleteDialog = ({ case MODEL_TYPES.TTS: return "🔊"; case MODEL_TYPES.VLM: + case MODEL_TYPES.VLM2: + case MODEL_TYPES.VLM3: return "👁️"; default: return "⚙️"; @@ -165,6 +169,10 @@ export const ModelDeleteDialog = ({ return t("model.type.tts"); case MODEL_TYPES.VLM: return t("model.type.vlm"); + case MODEL_TYPES.VLM2: + return `${t("model.type.vlm")}2`; + case MODEL_TYPES.VLM3: + return `${t("model.type.vlm")}3`; default: return t("model.type.unknown"); } @@ -344,7 +352,10 @@ export const ModelDeleteDialog = ({ if (cfgUrl && cfgUrl.trim() !== "") return cfgUrl; } if (type === MODEL_TYPES.VLM) { - const cfgUrl = modelConfig?.vlm?.apiConfig?.modelUrl; + const cfgUrl = + modelConfig?.vlm?.apiConfig?.modelUrl || + modelConfig?.vlm2?.apiConfig?.modelUrl || + modelConfig?.vlm3?.apiConfig?.modelUrl; if (cfgUrl && cfgUrl.trim() !== "") return cfgUrl; } if (type === MODEL_TYPES.LLM) { @@ -501,6 +512,22 @@ export const ModelDeleteDialog = ({ }; } + if (modelConfig.vlm2?.displayName === displayName) { + configUpdates.vlm2 = { + modelName: "", + displayName: "", + apiConfig: { apiKey: "", modelUrl: "" }, + }; + } + + if (modelConfig.vlm3?.displayName === displayName) { + configUpdates.vlm3 = { + modelName: "", + displayName: "", + apiConfig: { apiKey: "", modelUrl: "" }, + }; + } + if (modelConfig.stt.displayName === displayName) { configUpdates.stt = { modelName: "", displayName: "" }; } @@ -1017,6 +1044,8 @@ export const ModelDeleteDialog = ({ MODEL_TYPES.MULTI_EMBEDDING, MODEL_TYPES.RERANK, MODEL_TYPES.VLM, + MODEL_TYPES.VLM2, + MODEL_TYPES.VLM3, MODEL_TYPES.STT, MODEL_TYPES.TTS, ] as ModelType[] diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index 4b44a6361..c44251f15 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -110,6 +110,10 @@ export const ModelEditDialog = ({ form.type === MODEL_TYPES.EMBEDDING || form.type === MODEL_TYPES.MULTI_EMBEDDING; const isRerankModel = form.type === MODEL_TYPES.RERANK; + const connectivityModelType = + form.type === MODEL_TYPES.VLM2 || form.type === MODEL_TYPES.VLM3 + ? (MODEL_TYPES.VLM as ModelType) + : form.type; const isVoiceModel = form.type === MODEL_TYPES.STT || form.type === MODEL_TYPES.TTS; @@ -141,13 +145,9 @@ export const ModelEditDialog = ({ }); try { - const modelType = form.type as ModelType; - const isVoiceModel = - modelType === MODEL_TYPES.STT || modelType === MODEL_TYPES.TTS; - const config: any = { modelName: form.name, - modelType: modelType, + modelType: connectivityModelType, baseUrl: form.url, apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, maxTokens: @@ -271,6 +271,8 @@ export const ModelEditDialog = ({ embedding: MODEL_TYPES.EMBEDDING, multi_embedding: MODEL_TYPES.MULTI_EMBEDDING, vlm: MODEL_TYPES.VLM, + vlm2: MODEL_TYPES.VLM2, + vlm3: MODEL_TYPES.VLM3, rerank: MODEL_TYPES.RERANK, tts: MODEL_TYPES.TTS, stt: MODEL_TYPES.STT, diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx index 07eee5c06..36fcdbb31 100644 --- a/frontend/app/[locale]/models/components/modelConfig.tsx +++ b/frontend/app/[locale]/models/components/modelConfig.tsx @@ -56,7 +56,11 @@ const getModelData = (t: any) => ({ }, multimodal: { title: t("modelConfig.category.multimodal"), - options: [{ id: MODEL_TYPES.VLM, name: t("modelConfig.option.vlmModel") }], + options: [ + { id: MODEL_TYPES.VLM, name: t("modelConfig.option.imageUnderstandingModel") }, + { id: MODEL_TYPES.VLM2, name: t("modelConfig.option.imageGenerationModel") }, + { id: MODEL_TYPES.VLM3, name: t("modelConfig.option.videoUnderstandingModel") }, + ], }, voice: { title: t("modelConfig.category.voice"), @@ -142,7 +146,7 @@ export const ModelConfigSection = forwardRef< llm: { main: "" }, embedding: { embedding: "", multi_embedding: "" }, reranker: { reranker: "" }, - multimodal: { vlm: "" }, + multimodal: { vlm: "", vlm2: "", vlm3: "" }, voice: { tts: "", stt: "" }, }); @@ -284,11 +288,23 @@ export const ModelConfigSection = forwardRef< : true; const vlm = modelConfig.vlm.displayName; + const vlm2 = modelConfig.vlm2?.displayName || ""; + const vlm3 = modelConfig.vlm3?.displayName || ""; const vlmExists = vlm ? allModels.some( (m) => m.displayName === vlm && m.type === MODEL_TYPES.VLM ) : true; + const vlm2Exists = vlm2 + ? allModels.some( + (m) => m.displayName === vlm2 && m.type === MODEL_TYPES.VLM2 + ) + : true; + const vlm3Exists = vlm3 + ? allModels.some( + (m) => m.displayName === vlm3 && m.type === MODEL_TYPES.VLM3 + ) + : true; const stt = modelConfig.stt.displayName; const sttExists = stt @@ -318,6 +334,8 @@ export const ModelConfigSection = forwardRef< }, multimodal: { vlm: vlmExists ? vlm : "", + vlm2: vlm2Exists ? vlm2 : "", + vlm3: vlm3Exists ? vlm3 : "", }, voice: { tts: ttsExists ? tts : "", @@ -363,6 +381,14 @@ export const ModelConfigSection = forwardRef< configUpdates.vlm = { modelName: "", displayName: "" }; } + if (!vlm2Exists && vlm2) { + configUpdates.vlm2 = { modelName: "", displayName: "" }; + } + + if (!vlm3Exists && vlm3) { + configUpdates.vlm3 = { modelName: "", displayName: "" }; + } + if (!sttExists && stt) { configUpdates.stt = { modelName: "", displayName: "" }; } @@ -385,6 +411,8 @@ export const ModelConfigSection = forwardRef< !!modelConfig.multiEmbedding.modelName || !!modelConfig.rerank.modelName || !!modelConfig.vlm.modelName || + !!modelConfig.vlm2?.modelName || + !!modelConfig.vlm3?.modelName || !!modelConfig.tts.modelName || !!modelConfig.stt.modelName; @@ -441,11 +469,13 @@ export const ModelConfigSection = forwardRef< const hasEmbedding = !!modelConfig.embedding.modelName; const hasReranker = !!modelConfig.rerank.modelName; const hasVlm = !!modelConfig.vlm.modelName; + const hasVlm2 = !!modelConfig.vlm2?.modelName; + const hasVlm3 = !!modelConfig.vlm3?.modelName; const hasTts = !!modelConfig.tts.modelName; const hasStt = !!modelConfig.stt.modelName; hasSelectedModels = - hasLlmMain || hasEmbedding || hasReranker || hasVlm || hasTts || hasStt; + hasLlmMain || hasEmbedding || hasReranker || hasVlm || hasVlm2 || hasVlm3 || hasTts || hasStt; if (hasSelectedModels) { currentSelectedModels.llm.main = modelConfig.llm.modelName; @@ -455,6 +485,8 @@ export const ModelConfigSection = forwardRef< modelConfig.multiEmbedding.modelName || ""; currentSelectedModels.reranker.reranker = modelConfig.rerank.modelName; currentSelectedModels.multimodal.vlm = modelConfig.vlm.modelName; + currentSelectedModels.multimodal.vlm2 = modelConfig.vlm2?.modelName || ""; + currentSelectedModels.multimodal.vlm3 = modelConfig.vlm3?.modelName || ""; currentSelectedModels.voice.tts = modelConfig.tts.modelName; currentSelectedModels.voice.stt = modelConfig.stt.modelName; } else { @@ -492,7 +524,7 @@ export const ModelConfigSection = forwardRef< } else if (category === "reranker") { modelType = MODEL_TYPES.RERANK; } else if (category === "multimodal") { - modelType = MODEL_TYPES.VLM; + modelType = optionId as ModelType; } else if (category === MODEL_TYPES.EMBEDDING) { modelType = optionId === MODEL_TYPES.MULTI_EMBEDDING @@ -654,7 +686,7 @@ export const ModelConfigSection = forwardRef< } else if (category === "reranker") { modelType = MODEL_TYPES.RERANK; } else if (category === "multimodal") { - modelType = MODEL_TYPES.VLM; + modelType = option as ModelType; } else if (category === MODEL_TYPES.EMBEDDING) { modelType = option === MODEL_TYPES.MULTI_EMBEDDING @@ -679,7 +711,7 @@ export const ModelConfigSection = forwardRef< ) { configKey = "multiEmbedding"; } else if (category === "multimodal") { - configKey = MODEL_TYPES.VLM; + configKey = option; } else if (category === "reranker") { configKey = MODEL_TYPES.RERANK; } else if (category === "voice" && option === "tts") { @@ -1005,7 +1037,7 @@ export const ModelConfigSection = forwardRef< ? MODEL_TYPES.TTS : MODEL_TYPES.STT : key === "multimodal" - ? MODEL_TYPES.VLM + ? (option.id as ModelType) : key === MODEL_TYPES.EMBEDDING && option.id === MODEL_TYPES.MULTI_EMBEDDING ? MODEL_TYPES.MULTI_EMBEDDING diff --git a/frontend/const/chatConfig.ts b/frontend/const/chatConfig.ts index fc0dbe6d5..27b3b887d 100644 --- a/frontend/const/chatConfig.ts +++ b/frontend/const/chatConfig.ts @@ -38,6 +38,12 @@ export const chatConfig = { // Supported document file extensions documentExtensions: ["pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "epub", "html", "xml"], + + // Supported audio file extensions + audioExtensions: ["mp3", "wav", "m4a", "aac", "ogg", "oga", "flac", "webm"], + + // Supported video file extensions + videoExtensions: ["mp4", "mov", "m4v", "avi", "mkv", "webm", "wmv", "flv"], // Supported text document extensions supportedTextExtensions: ["md", "markdown", "txt", "csv", "json"], @@ -73,6 +79,12 @@ export const chatConfig = { // Compressed file compressed: ["zip", "rar", "7z", "tar", "gz"], + + // Audio files + audio: ["mp3", "wav", "m4a", "aac", "ogg", "oga", "flac", "webm"], + + // Video files + video: ["mp4", "mov", "m4v", "avi", "mkv", "wmv", "flv"], }, // File preview type constants @@ -148,4 +160,4 @@ export const MESSAGE_ROLES = { USER: "user" as const, ASSISTANT: "assistant" as const, SYSTEM: "system" as const, -} as const; \ No newline at end of file +} as const; diff --git a/frontend/const/modelConfig.ts b/frontend/const/modelConfig.ts index 9bdc5a4a8..b7762ace0 100644 --- a/frontend/const/modelConfig.ts +++ b/frontend/const/modelConfig.ts @@ -7,6 +7,8 @@ export const MODEL_TYPES = { STT: "stt", TTS: "tts", VLM: "vlm", + VLM2: "vlm2", + VLM3: "vlm3", } as const; // Model source constants diff --git a/frontend/hooks/model/useDashscopeModelList.ts b/frontend/hooks/model/useDashscopeModelList.ts index b44348fe5..5d1035e8a 100644 --- a/frontend/hooks/model/useDashscopeModelList.ts +++ b/frontend/hooks/model/useDashscopeModelList.ts @@ -39,7 +39,9 @@ export const useDashscopeModelList = ({ const modelType = form.type === "embedding" && form.isMultimodal ? ("multi_embedding" as ModelType) - : form.type; + : form.type === "vlm2" || form.type === "vlm3" + ? ("vlm" as ModelType) + : form.type; try { // Use manage interface if tenantId is provided (for super admin) diff --git a/frontend/hooks/model/useTokenponyModelList.ts b/frontend/hooks/model/useTokenponyModelList.ts index 0a7e23581..0c502a404 100644 --- a/frontend/hooks/model/useTokenponyModelList.ts +++ b/frontend/hooks/model/useTokenponyModelList.ts @@ -39,7 +39,9 @@ export const useTokenPonyModelList = ({ const modelType = form.type === "embedding" && form.isMultimodal ? ("multi_embedding" as ModelType) - : form.type; + : form.type === "vlm2" || form.type === "vlm3" + ? ("vlm" as ModelType) + : form.type; try { // Use manage interface if tenantId is provided (for super admin) diff --git a/frontend/hooks/useConfig.ts b/frontend/hooks/useConfig.ts index be4e463f0..94d4d57db 100644 --- a/frontend/hooks/useConfig.ts +++ b/frontend/hooks/useConfig.ts @@ -81,6 +81,22 @@ const defaultConfig: GlobalConfig = { modelUrl: "", }, }, + vlm2: { + modelName: "", + displayName: "", + apiConfig: { + apiKey: "", + modelUrl: "", + }, + }, + vlm3: { + modelName: "", + displayName: "", + apiConfig: { + apiKey: "", + modelUrl: "", + }, + }, stt: { id: 0, modelName: "", @@ -173,6 +189,8 @@ function transformBackendToFrontend(backendConfig: any): GlobalConfig { ), rerank: transformModelEntry(backendConfig.models.rerank), vlm: transformModelEntry(backendConfig.models.vlm), + vlm2: transformModelEntry(backendConfig.models.vlm2), + vlm3: transformModelEntry(backendConfig.models.vlm3), stt: transformVoiceModelEntry(backendConfig.models.stt), tts: transformVoiceModelEntry(backendConfig.models.tts), } @@ -207,7 +225,10 @@ function loadConfigFromStorage(): GlobalConfig | null { if (storedModelConfig) { try { - mergedConfig.models = JSON.parse(storedModelConfig); + mergedConfig.models = deepMerge( + mergedConfig.models, + JSON.parse(storedModelConfig) + ); } catch (error) { log.error("Failed to parse model config:", error); } @@ -297,7 +318,24 @@ export function useConfig() { const config: GlobalConfig = (query.data as GlobalConfig | undefined) ?? defaultConfig; // Whether config has selected a VLM model - const isVlmAvailable = !!(config?.models?.vlm?.modelName || config?.models?.vlm?.displayName); + const isVlmAvailable = !!( + config?.models?.vlm?.modelName || + config?.models?.vlm?.displayName || + config?.models?.vlm2?.modelName || + config?.models?.vlm2?.displayName || + config?.models?.vlm3?.modelName || + config?.models?.vlm3?.displayName + ); + + const isImageUnderstandingAvailable = !!( + config?.models?.vlm?.modelName || + config?.models?.vlm?.displayName + ); + + const isVideoUnderstandingAvailable = !!( + config?.models?.vlm3?.modelName || + config?.models?.vlm3?.displayName + ); // Whether config has selected an Embedding model const isEmbeddingAvailable = !!(config?.models?.embedding?.modelName || config?.models?.embedding?.displayName); @@ -383,6 +421,8 @@ export function useConfig() { appConfig: config?.app, modelConfig: config?.models, isVlmAvailable, + isImageUnderstandingAvailable, + isVideoUnderstandingAvailable, isEmbeddingAvailable, defaultLlmModelName, defaultLlmModelConfig, diff --git a/frontend/lib/chat/chatAttachmentUtils.ts b/frontend/lib/chat/chatAttachmentUtils.ts index c85615b4e..bff686ca1 100644 --- a/frontend/lib/chat/chatAttachmentUtils.ts +++ b/frontend/lib/chat/chatAttachmentUtils.ts @@ -69,6 +69,19 @@ export const uploadAttachments = async ( }); } + const failedResults = uploadResult.results.filter((result) => !result.success); + if (failedResults.length > 0 || uploadResult.success_count < attachments.length) { + const failedMessage = failedResults + .map((result) => `${result.file_name || "file"}: ${result.error || "Upload failed"}`) + .join("; "); + return { + uploadedFileUrls, + objectNames, + presignedUrls, + error: failedMessage || "Upload failed", + }; + } + return { uploadedFileUrls, objectNames, presignedUrls }; } catch (error) { log.error(t("chatPreprocess.fileUploadFailed"), error); diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 1adf9d6e3..b008cf70b 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -1,4 +1,4 @@ -{ +{ "assistant.name": "Nexent", "mainPage.layout.title": "Nexent | AI Agents", @@ -830,6 +830,9 @@ "model.type.llm": "Large Language Model", "model.type.embedding": "Embedding Model", "model.type.vlm": "Vision Language Model", + "model.type.imageUnderstanding": "Image Understanding Model", + "model.type.imageGeneration": "Image Generation Model", + "model.type.videoUnderstanding": "Video Understanding Model", "model.type.rerank": "Rerank Model", "model.type.stt": "Speech-to-Text Model", "model.type.tts": "Text-to-Speech Model", @@ -898,6 +901,9 @@ "modelConfig.option.multiEmbeddingModel": "Multimodal Embedding Model", "modelConfig.option.rerankerModel": "Reranker Model", "modelConfig.option.vlmModel": "Vision Language Model", + "modelConfig.option.imageUnderstandingModel": "Image Understanding Model", + "modelConfig.option.imageGenerationModel": "Image Generation Model", + "modelConfig.option.videoUnderstandingModel": "Video Understanding Model", "modelConfig.option.ttsModel": "Text-to-Speech Model", "modelConfig.option.sttModel": "Speech-to-Text Model", "modelConfig.error.loadList": "Failed to load model list:", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index e37bd8936..b258edf20 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -1,4 +1,4 @@ -{ +{ "assistant.name": "Nexent", "mainPage.layout.title": "Nexent | 智能问答", @@ -820,6 +820,9 @@ "model.type.llm": "大语言模型", "model.type.embedding": "向量模型", "model.type.vlm": "视觉语言模型", + "model.type.imageUnderstanding": "图片理解模型", + "model.type.imageGeneration": "图片生成模型", + "model.type.videoUnderstanding": "视频理解模型", "model.type.rerank": "重排模型", "model.type.stt": "语音识别模型", "model.type.tts": "语音合成模型", @@ -889,6 +892,9 @@ "modelConfig.option.multiEmbeddingModel": "多模态向量模型", "modelConfig.option.rerankerModel": "重排模型", "modelConfig.option.vlmModel": "视觉语言模型", + "modelConfig.option.imageUnderstandingModel": "图片理解模型", + "modelConfig.option.imageGenerationModel": "图片生成模型", + "modelConfig.option.videoUnderstandingModel": "视频理解模型", "modelConfig.option.ttsModel": "语音合成模型", "modelConfig.option.sttModel": "语音识别模型", "modelConfig.error.loadList": "加载模型列表失败:", diff --git a/frontend/tailwind.config.ts b/frontend/tailwind.config.js similarity index 100% rename from frontend/tailwind.config.ts rename to frontend/tailwind.config.js diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts index 21b7ef1c5..8f4789f6b 100644 --- a/frontend/types/modelConfig.ts +++ b/frontend/types/modelConfig.ts @@ -31,6 +31,8 @@ export type ModelType = | "stt" | "tts" | "vlm" + | "vlm2" + | "vlm3" | "multi_embedding"; // Model option interface @@ -89,7 +91,7 @@ export interface TTSModelConfig extends SingleModelConfig { // Single model configuration interface export interface SingleModelConfig { - id: number; + id?: number; modelName: string; displayName: string; apiConfig: ModelApiConfig; @@ -103,6 +105,8 @@ export interface ModelConfig { multiEmbedding: SingleModelConfig; rerank: SingleModelConfig; vlm: SingleModelConfig; + vlm2: SingleModelConfig; + vlm3: SingleModelConfig; stt: STTModelConfig; tts: TTSModelConfig; } diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index bce103bb9..2ccf1d72a 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -241,7 +241,7 @@ def create_local_tool(self, tool_config: ToolConfig): data_process_service_url=tool_config.metadata.get("data_process_service_url", []), validate_url_access=validate_url_access, **params) - elif class_name == "AnalyzeImageTool": + elif class_name in ["AnalyzeImageTool", "AnalyzeAudioTool", "AnalyzeVideoTool"]: # Extract validate_url_access from metadata if it's callable validate_url_access = tool_config.metadata.get("validate_url_access") if tool_config.metadata else None if validate_url_access is not None and not callable(validate_url_access): diff --git a/sdk/nexent/core/models/openai_vlm.py b/sdk/nexent/core/models/openai_vlm.py index 1babb0057..cbc7388d6 100644 --- a/sdk/nexent/core/models/openai_vlm.py +++ b/sdk/nexent/core/models/openai_vlm.py @@ -126,6 +126,47 @@ def prepare_image_message(self, image_input: Union[str, BinaryIO], system_prompt return messages + def prepare_media_message( + self, + media_input: Union[str, BinaryIO], + media_type: str, + content_type: str, + system_prompt: str) -> List[Dict[str, Any]]: + """ + Prepare an OpenAI-compatible multimodal message for audio or video inputs. + + Args: + media_input: Media file path or file stream object. + media_type: Either "audio" or "video". + content_type: MIME type for the data URL. + system_prompt: System prompt. + + Returns: + List[Dict[str, Any]]: Prepared message list. + """ + if media_type not in ("audio", "video"): + raise ValueError(f"Unsupported media type: {media_type}") + + base64_media = self.encode_image(media_input) + media_url_key = f"{media_type}_url" + media_config: Dict[str, Any] = {"url": f"data:{content_type};base64,{base64_media}"} + if media_type == "video": + media_config.update({"detail": "high", "max_frames": 16, "fps": 1}) + + messages = [ + { + "role": "user", + "content": [ + { + "type": media_url_key, + media_url_key: media_config + }, + {"type": "text", "text": system_prompt} + ] + } + ] + return messages + def analyze_image(self, image_input: Union[str, BinaryIO], system_prompt: str = "Please describe this picture concisely and carefully, within 200 words.", stream: bool = True, **kwargs) -> ChatMessage: @@ -144,3 +185,23 @@ def analyze_image(self, image_input: Union[str, BinaryIO], messages = self.prepare_image_message(image_input, system_prompt) # Call __call__ explicitly so instance-level mocks work in tests. return self.__call__(messages=messages, **kwargs) + + def analyze_audio( + self, + audio_input: Union[str, BinaryIO], + system_prompt: str = "Please analyze this audio carefully.", + content_type: str = "audio/mpeg", + **kwargs) -> ChatMessage: + """Analyze audio content using the configured multimodal model.""" + messages = self.prepare_media_message(audio_input, "audio", content_type, system_prompt) + return self.__call__(messages=messages, **kwargs) + + def analyze_video( + self, + video_input: Union[str, BinaryIO], + system_prompt: str = "Please analyze this video carefully.", + content_type: str = "video/mp4", + **kwargs) -> ChatMessage: + """Analyze video content using the configured multimodal model.""" + messages = self.prepare_media_message(video_input, "video", content_type, system_prompt) + return self.__call__(messages=messages, **kwargs) diff --git a/sdk/nexent/core/prompts/analyze_audio_en.yaml b/sdk/nexent/core/prompts/analyze_audio_en.yaml new file mode 100644 index 000000000..eee0bb060 --- /dev/null +++ b/sdk/nexent/core/prompts/analyze_audio_en.yaml @@ -0,0 +1,13 @@ +# Audio Understanding Prompt Templates + +system_prompt: |- + The user has asked a question: {{ query }}. Please analyze this audio from the perspective of answering this question, within 300 words. + + **Audio Analysis Requirements:** + 1. Focus on speech, sound events, tone, timing, and other audio content relevant to the user's question + 2. If speech is present, summarize or transcribe the key spoken content when possible + 3. Keep the answer concise and grounded in observable audio evidence + 4. Avoid guessing identities or facts that cannot be inferred from the audio + +user_prompt: | + Please listen to this audio and describe it from the perspective of answering the user's question. diff --git a/sdk/nexent/core/prompts/analyze_audio_zh.yaml b/sdk/nexent/core/prompts/analyze_audio_zh.yaml new file mode 100644 index 000000000..ae6f1fa0d --- /dev/null +++ b/sdk/nexent/core/prompts/analyze_audio_zh.yaml @@ -0,0 +1,13 @@ +# 音频理解 Prompt 模板 + +system_prompt: |- + 用户提出的问题是:{{ query }}。请从回答该问题的角度分析这段音频,控制在 300 字以内。 + + **音频分析要求:** + 1. 关注与用户问题相关的语音、声音事件、语气、节奏和其他音频内容 + 2. 如果包含人声,请尽可能总结或转写关键口语内容 + 3. 回答要简洁,并基于音频中可观察到的信息 + 4. 不要猜测无法从音频中判断的身份或事实 + +user_prompt: | + 请仔细聆听这段音频,并从回答用户问题的角度进行描述。 diff --git a/sdk/nexent/core/prompts/analyze_video_en.yaml b/sdk/nexent/core/prompts/analyze_video_en.yaml new file mode 100644 index 000000000..7834ca7f3 --- /dev/null +++ b/sdk/nexent/core/prompts/analyze_video_en.yaml @@ -0,0 +1,13 @@ +# Video Understanding Prompt Templates + +system_prompt: |- + The user has asked a question: {{ query }}. Please analyze this video from the perspective of answering this question, within 300 words. + + **Video Analysis Requirements:** + 1. Focus on scenes, actions, objects, people, visible text, and temporal changes relevant to the user's question + 2. Mention important audio cues only when they help answer the question + 3. Keep the answer concise, structured, and grounded in visible or audible evidence + 4. Avoid over-interpreting intent or facts that cannot be inferred from the video + +user_prompt: | + Please watch this video and describe it from the perspective of answering the user's question. diff --git a/sdk/nexent/core/prompts/analyze_video_zh.yaml b/sdk/nexent/core/prompts/analyze_video_zh.yaml new file mode 100644 index 000000000..e83a1676d --- /dev/null +++ b/sdk/nexent/core/prompts/analyze_video_zh.yaml @@ -0,0 +1,13 @@ +# 视频理解 Prompt 模板 + +system_prompt: |- + 用户提出的问题是:{{ query }}。请从回答该问题的角度分析这段视频,控制在 300 字以内。 + + **视频分析要求:** + 1. 关注与用户问题相关的场景、动作、物体、人物、可见文字和时间变化 + 2. 只有在有助于回答问题时,才补充重要的音频线索 + 3. 回答要简洁、有条理,并基于视频中可见或可听的信息 + 4. 不要过度推断无法从视频中判断的意图或事实 + +user_prompt: | + 请仔细观看这段视频,并从回答用户问题的角度进行描述。 diff --git a/sdk/nexent/core/tools/__init__.py b/sdk/nexent/core/tools/__init__.py index 851690f16..a640cb5ff 100644 --- a/sdk/nexent/core/tools/__init__.py +++ b/sdk/nexent/core/tools/__init__.py @@ -19,6 +19,8 @@ from .terminal_tool import TerminalTool from .analyze_text_file_tool import AnalyzeTextFileTool from .analyze_image_tool import AnalyzeImageTool +from .analyze_audio_tool import AnalyzeAudioTool +from .analyze_video_tool import AnalyzeVideoTool from .run_skill_script_tool import run_skill_script from .read_skill_md_tool import read_skill_md from .read_skill_config_tool import read_skill_config @@ -47,6 +49,8 @@ "TerminalTool", "AnalyzeTextFileTool", "AnalyzeImageTool", + "AnalyzeAudioTool", + "AnalyzeVideoTool", "run_skill_script", "read_skill_md", "read_skill_config" diff --git a/sdk/nexent/core/tools/analyze_audio_tool.py b/sdk/nexent/core/tools/analyze_audio_tool.py new file mode 100644 index 000000000..c7509a6c2 --- /dev/null +++ b/sdk/nexent/core/tools/analyze_audio_tool.py @@ -0,0 +1,169 @@ +""" +Analyze Audio Tool + +Analyze audio using the configured video understanding model. +Supports audio from S3, HTTP, and HTTPS URLs. +""" + +import logging +from io import BytesIO +from typing import List + +from jinja2 import StrictUndefined, Template +from pydantic import Field +from smolagents.tools import Tool + +from ...core.models import OpenAIVLModel +from ...core.utils.observer import MessageObserver, ProcessType +from ...core.utils.prompt_template_utils import get_prompt_template +from ...core.utils.tools_common_message import ToolCategory, ToolSign +from ...multi_modal.load_save_object import LoadSaveObjectManager +from ...multi_modal.utils import detect_content_type_from_bytes +from ...storage import MinIOStorageClient + +logger = logging.getLogger("analyze_audio_tool") + + +class AnalyzeAudioTool(Tool): + """Tool for understanding and analyzing audio using the video understanding model.""" + + name = "analyze_audio" + description = ( + "This tool uses the configured video understanding model to understand audio based on your query and then returns an audio analysis result.\n" + "It is used to understand and analyze multiple audio files, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " + "HTTP, and HTTPS URLs.\n" + "Use this tool when you want to retrieve information contained in audio and provide the audio URL and your query." + ) + description_zh = ( + "使用视频理解模型,根据你的提示词来理解音频,并返回音频分析结果。" + "可用于理解和分析多个音频文件,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" + ) + + inputs = { + "audio_urls_list": { + "type": "array", + "description": "List of audio URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", + "description_zh": "列表形式输入音频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。" + }, + "query": { + "type": "string", + "description": "User's question to guide the audio analysis", + "description_zh": "用户用于指导音频分析的问题" + } + } + + init_param_descriptions = { + "observer": {"description": "Message observer"}, + "vlm_model": {"description": "The video understanding model to use"}, + "storage_client": {"description": "Storage client for downloading files"}, + "validate_url_access": { + "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" + } + } + output_type = "array" + category = ToolCategory.MULTIMODAL.value + tool_sign = ToolSign.MULTIMODAL_OPERATION.value + + def __init__( + self, + observer: MessageObserver = Field( + description="Message observer", + default=None, + exclude=True), + vlm_model: OpenAIVLModel = Field( + description="The video understanding model to use", + default=None, + exclude=True), + storage_client: MinIOStorageClient = Field( + description="Storage client for downloading files from S3 URLs, HTTP URLs, and HTTPS URLs.", + default=None, + exclude=True), + validate_url_access: callable = Field( + description="Callback function to validate URL access permissions", + default=None, + exclude=True) + ): + super().__init__() + self.observer = observer + self.vlm_model = vlm_model + self.storage_client = storage_client + self._is_chinese = bool(observer and observer.lang == "zh") + + validate_callback = None + if validate_url_access is not None and callable(validate_url_access): + validate_callback = validate_url_access + self.mm = LoadSaveObjectManager( + storage_client=self.storage_client, + validate_url_access=validate_callback + ) + self.forward = self.mm.load_object( + input_names=["audio_urls_list"])(self._forward_impl) + + self.running_prompt_zh = "正在分析音频..." + self.running_prompt_en = "Analyzing audio..." + + def _validate_audio_capable_model(self) -> None: + """Fail early for SiliconFlow models that are known not to accept audio input.""" + client_kwargs = getattr(self.vlm_model, "client_kwargs", {}) or {} + base_url = client_kwargs.get("base_url", "") if isinstance(client_kwargs, dict) else "" + model_id = str(getattr(self.vlm_model, "model_id", "") or "") + + if "siliconflow" in str(base_url).lower() and model_id and "omni" not in model_id.lower(): + raise ValueError( + "The selected video understanding model does not support audio input on SiliconFlow. " + "Please choose a Qwen3-Omni model for analyze_audio." + ) + + def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: + """Analyze audio files and return one result per audio input.""" + if self.vlm_model is None: + error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试" + error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again." + error_msg = error_msg_zh if self._is_chinese else error_msg_en + logger.error(error_msg) + raise Exception(error_msg) + self._validate_audio_capable_model() + + if self.observer: + running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en + self.observer.add_message("", ProcessType.TOOL, running_prompt) + + if audio_urls_list is None: + raise ValueError("audio_urls cannot be None") + if not isinstance(audio_urls_list, list): + raise ValueError("audio_urls must be a list of bytes") + if not audio_urls_list: + raise ValueError("audio_urls must contain at least one audio file") + + language = self.observer.lang if self.observer else "en" + prompts = get_prompt_template( + template_type='analyze_audio', language=language) + system_prompt = Template( + prompts['system_prompt'], undefined=StrictUndefined).render({'query': query}) + + try: + analysis_results: List[str] = [] + for index, audio_bytes in enumerate(audio_urls_list, start=1): + logger.info(f"Analyzing audio #{index}, query: {query}") + content_type = detect_content_type_from_bytes(audio_bytes) + if not content_type.startswith("audio/"): + content_type = "audio/mpeg" + audio_stream = BytesIO(audio_bytes) + try: + response = self.vlm_model.analyze_audio( + audio_input=audio_stream, + system_prompt=system_prompt, + content_type=content_type + ) + except Exception as e: + error_msg_zh = f"音频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。" + error_msg_en = f"Failed to analyze audio {index}: {str(e)}. Please check if the video understanding model is configured correctly." + error_msg = error_msg_zh if self._is_chinese else error_msg_en + raise Exception(error_msg) + + analysis_results.append(response.content) + + return analysis_results + except Exception as e: + logger.error(f"Error analyzing audio: {str(e)}", exc_info=True) + raise Exception(f"Error analyzing audio: {str(e)}") diff --git a/sdk/nexent/core/tools/analyze_image_tool.py b/sdk/nexent/core/tools/analyze_image_tool.py index 3851a896b..f7640a9dc 100644 --- a/sdk/nexent/core/tools/analyze_image_tool.py +++ b/sdk/nexent/core/tools/analyze_image_tool.py @@ -24,17 +24,17 @@ class AnalyzeImageTool(Tool): - """Tool for understanding and analyzing image using a visual language model""" + """Tool for understanding and analyzing images using the image understanding model.""" name = "analyze_image" description = ( - "This tool uses a visual language model to understand images based on your query and then returns a description of the image.\n" + "This tool uses the configured image understanding model to understand images based on your query and then returns a description of the image.\n" "It is used to understand and analyze multiple images, with image sources supporting S3 URLs (s3://bucket/key or /bucket/key), " "HTTP, and HTTPS URLs.\n" "Use this tool when you want to retrieve information contained in an image and provide the image's URL and your query." ) - description_zh = "使用视觉语言模型,根据你的提示词来理解图像,并返回图像的描述。可用于理解和分析多张图片,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" + description_zh = "使用图片理解模型,根据你的提示词来理解图像,并返回图像的描述。可用于理解和分析多张图片,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" inputs = { "image_urls_list": { @@ -54,7 +54,7 @@ class AnalyzeImageTool(Tool): "description": "Message observer" }, "vlm_model": { - "description": "The VLM model to use" + "description": "The image understanding model to use" }, "storage_client": { "description": "Storage client for downloading files" @@ -74,7 +74,7 @@ def __init__( default=None, exclude=True), vlm_model: OpenAIVLModel = Field( - description="The VLM model to use", + description="The image understanding model to use", default=None, exclude=True), storage_client: MinIOStorageClient = Field( @@ -130,10 +130,10 @@ def _forward_impl(self, image_urls_list: List[bytes], query: str) -> List[str]: Raises: Exception: If the image cannot be downloaded or analyzed. """ - # Check if VLM model is available + # Check if the image understanding model is available. if self.vlm_model is None: - error_msg_zh = "视觉语言模型(VLM)未配置,请联系管理员配置VLM模型后重试" - error_msg_en = "Vision Language Model (VLM) is not configured. Please contact your administrator to configure the VLM model and try again." + error_msg_zh = "图片理解模型未配置,请联系管理员配置图片理解模型后重试" + error_msg_en = "Image understanding model is not configured. Please contact your administrator to configure the image understanding model and try again." error_msg = error_msg_zh if self._is_chinese else error_msg_en logger.error(error_msg) raise Exception(error_msg) @@ -170,8 +170,8 @@ def _forward_impl(self, image_urls_list: List[bytes], query: str) -> List[str]: system_prompt=system_prompt ) except Exception as e: - error_msg_zh = f"图片{index}分析失败: {str(e)}。请检查VLM模型配置是否正确。" - error_msg_en = f"Failed to analyze image {index}: {str(e)}. Please check if the VLM model is configured correctly." + error_msg_zh = f"图片{index}分析失败: {str(e)}。请检查图片理解模型配置是否正确。" + error_msg_en = f"Failed to analyze image {index}: {str(e)}. Please check if the image understanding model is configured correctly." error_msg = error_msg_zh if self._is_chinese else error_msg_en raise Exception(error_msg) diff --git a/sdk/nexent/core/tools/analyze_video_tool.py b/sdk/nexent/core/tools/analyze_video_tool.py new file mode 100644 index 000000000..3dc033551 --- /dev/null +++ b/sdk/nexent/core/tools/analyze_video_tool.py @@ -0,0 +1,156 @@ +""" +Analyze Video Tool + +Analyze videos using the configured video understanding model. +Supports videos from S3, HTTP, and HTTPS URLs. +""" + +import logging +from io import BytesIO +from typing import List + +from jinja2 import StrictUndefined, Template +from pydantic import Field +from smolagents.tools import Tool + +from ...core.models import OpenAIVLModel +from ...core.utils.observer import MessageObserver, ProcessType +from ...core.utils.prompt_template_utils import get_prompt_template +from ...core.utils.tools_common_message import ToolCategory, ToolSign +from ...multi_modal.load_save_object import LoadSaveObjectManager +from ...multi_modal.utils import detect_content_type_from_bytes +from ...storage import MinIOStorageClient + +logger = logging.getLogger("analyze_video_tool") + + +class AnalyzeVideoTool(Tool): + """Tool for understanding and analyzing videos using the video understanding model.""" + + name = "analyze_video" + description = ( + "This tool uses the configured video understanding model to understand videos based on your query and then returns a video analysis result.\n" + "It is used to understand and analyze multiple videos, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " + "HTTP, and HTTPS URLs.\n" + "Use this tool when you want to retrieve information contained in a video and provide the video's URL and your query." + ) + description_zh = ( + "使用视频理解模型,根据你的提示词来理解视频,并返回视频分析结果。" + "可用于理解和分析多个视频,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" + ) + + inputs = { + "video_urls_list": { + "type": "array", + "description": "List of video URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", + "description_zh": "列表形式输入视频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。" + }, + "query": { + "type": "string", + "description": "User's question to guide the video analysis", + "description_zh": "用户用于指导视频分析的问题" + } + } + + init_param_descriptions = { + "observer": {"description": "Message observer"}, + "vlm_model": {"description": "The video understanding model to use"}, + "storage_client": {"description": "Storage client for downloading files"}, + "validate_url_access": { + "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" + } + } + output_type = "array" + category = ToolCategory.MULTIMODAL.value + tool_sign = ToolSign.MULTIMODAL_OPERATION.value + + def __init__( + self, + observer: MessageObserver = Field( + description="Message observer", + default=None, + exclude=True), + vlm_model: OpenAIVLModel = Field( + description="The video understanding model to use", + default=None, + exclude=True), + storage_client: MinIOStorageClient = Field( + description="Storage client for downloading files from S3 URLs, HTTP URLs, and HTTPS URLs.", + default=None, + exclude=True), + validate_url_access: callable = Field( + description="Callback function to validate URL access permissions", + default=None, + exclude=True) + ): + super().__init__() + self.observer = observer + self.vlm_model = vlm_model + self.storage_client = storage_client + self._is_chinese = bool(observer and observer.lang == "zh") + + validate_callback = None + if validate_url_access is not None and callable(validate_url_access): + validate_callback = validate_url_access + self.mm = LoadSaveObjectManager( + storage_client=self.storage_client, + validate_url_access=validate_callback + ) + self.forward = self.mm.load_object( + input_names=["video_urls_list"])(self._forward_impl) + + self.running_prompt_zh = "正在分析视频..." + self.running_prompt_en = "Analyzing video..." + + def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: + """Analyze videos and return one result per video input.""" + if self.vlm_model is None: + error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试" + error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again." + error_msg = error_msg_zh if self._is_chinese else error_msg_en + logger.error(error_msg) + raise Exception(error_msg) + + if self.observer: + running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en + self.observer.add_message("", ProcessType.TOOL, running_prompt) + + if video_urls_list is None: + raise ValueError("video_urls cannot be None") + if not isinstance(video_urls_list, list): + raise ValueError("video_urls must be a list of bytes") + if not video_urls_list: + raise ValueError("video_urls must contain at least one video") + + language = self.observer.lang if self.observer else "en" + prompts = get_prompt_template( + template_type='analyze_video', language=language) + system_prompt = Template( + prompts['system_prompt'], undefined=StrictUndefined).render({'query': query}) + + try: + analysis_results: List[str] = [] + for index, video_bytes in enumerate(video_urls_list, start=1): + logger.info(f"Analyzing video #{index}, query: {query}") + content_type = detect_content_type_from_bytes(video_bytes) + if not content_type.startswith("video/"): + content_type = "video/mp4" + video_stream = BytesIO(video_bytes) + try: + response = self.vlm_model.analyze_video( + video_input=video_stream, + system_prompt=system_prompt, + content_type=content_type + ) + except Exception as e: + error_msg_zh = f"视频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。" + error_msg_en = f"Failed to analyze video {index}: {str(e)}. Please check if the video understanding model is configured correctly." + error_msg = error_msg_zh if self._is_chinese else error_msg_en + raise Exception(error_msg) + + analysis_results.append(response.content) + + return analysis_results + except Exception as e: + logger.error(f"Error analyzing video: {str(e)}", exc_info=True) + raise Exception(f"Error analyzing video: {str(e)}") diff --git a/sdk/nexent/core/utils/prompt_template_utils.py b/sdk/nexent/core/utils/prompt_template_utils.py index ad06e9119..24b273876 100644 --- a/sdk/nexent/core/utils/prompt_template_utils.py +++ b/sdk/nexent/core/utils/prompt_template_utils.py @@ -17,6 +17,14 @@ LANGUAGE["ZH"]: 'core/prompts/analyze_image_zh.yaml', LANGUAGE["EN"]: 'core/prompts/analyze_image_en.yaml' }, + 'analyze_audio': { + LANGUAGE["ZH"]: 'core/prompts/analyze_audio_zh.yaml', + LANGUAGE["EN"]: 'core/prompts/analyze_audio_en.yaml' + }, + 'analyze_video': { + LANGUAGE["ZH"]: 'core/prompts/analyze_video_zh.yaml', + LANGUAGE["EN"]: 'core/prompts/analyze_video_en.yaml' + }, 'analyze_file': { LANGUAGE["ZH"]: 'core/prompts/analyze_file_zh.yaml', LANGUAGE["EN"]: 'core/prompts/analyze_file_en.yaml' @@ -30,6 +38,8 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw Args: template_type: Template type, supports the following values: - 'analyze_image': Analyze image template + - 'analyze_audio': Analyze audio template + - 'analyze_video': Analyze video template - 'analyze_file': Analyze file template (for text files) language: Language code ('zh' or 'en') **kwargs: Additional parameters, for agent type need to pass is_manager parameter @@ -52,4 +62,4 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw # Read and return template content with open(absolute_template_path, 'r', encoding='utf-8') as f: - return yaml.safe_load(f) \ No newline at end of file + return yaml.safe_load(f) diff --git a/sdk/nexent/multi_modal/utils.py b/sdk/nexent/multi_modal/utils.py index e118f6940..bcd6cdd35 100644 --- a/sdk/nexent/multi_modal/utils.py +++ b/sdk/nexent/multi_modal/utils.py @@ -34,10 +34,10 @@ def is_url(url: str) -> Optional[UrlType]: if url.startswith("https://"): return "https" - if url.startswith("s3://"): - bucket_path = url.replace("s3://", "", 1) + if url.startswith("s3://") or url.startswith("s3:/"): + bucket_path = url.replace("s3://", "", 1) if url.startswith("s3://") else url.replace("s3:/", "", 1).lstrip("/") bucket_object = bucket_path.split("/", 1) - if len(bucket_object) == 2 and all(bucket_object): + if len(bucket_object) == 2 and all(bucket_object) and ":" not in bucket_object[0]: return "s3" return None @@ -321,6 +321,7 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]: Supports formats: - s3://bucket/key + - s3:/bucket/key - /bucket/key (MinIO path format) Args: @@ -335,11 +336,16 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]: if not s3_url: raise ValueError("S3 URL cannot be empty") - if s3_url.startswith('s3://'): - parts = s3_url.replace('s3://', '').split('/', 1) + if s3_url.startswith('s3://') or s3_url.startswith('s3:/'): + normalized_url = ( + s3_url.replace('s3://', '', 1) + if s3_url.startswith('s3://') + else s3_url.replace('s3:/', '', 1).lstrip('/') + ) + parts = normalized_url.split('/', 1) if len(parts) == 2: bucket, object_name = parts - if not bucket or not object_name: + if not bucket or not object_name or ":" in bucket: raise ValueError(f"Invalid s3:// URL format: {s3_url}") return bucket, object_name raise ValueError(f"Invalid s3:// URL format: {s3_url}") @@ -351,4 +357,4 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]: return bucket, object_name raise ValueError(f"Invalid path format: {s3_url}") - raise ValueError(f"Unrecognized S3 URL format: {s3_url[:50]}...") \ No newline at end of file + raise ValueError(f"Unrecognized S3 URL format: {s3_url[:50]}...") diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index 5817fbe27..20340f2ea 100644 --- a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -30,6 +30,21 @@ class ValidationError(Exception): pass +class MCPConnectionError(Exception): + """Mock MCPConnectionError for testing.""" + pass + + +class NotFoundException(Exception): + """Mock NotFoundException for testing.""" + pass + + +class ToolExecutionException(Exception): + """Mock ToolExecutionException for testing.""" + pass + + consts_model_module = types.ModuleType("consts.model") consts_model_module.HistoryItem = HistoryItem sys.modules["consts.model"] = consts_model_module @@ -37,6 +52,9 @@ class ValidationError(Exception): # Mock consts.exceptions module with ValidationError consts_exceptions_module = types.ModuleType("consts.exceptions") consts_exceptions_module.ValidationError = ValidationError +consts_exceptions_module.MCPConnectionError = MCPConnectionError +consts_exceptions_module.NotFoundException = NotFoundException +consts_exceptions_module.ToolExecutionException = ToolExecutionException sys.modules["consts.exceptions"] = consts_exceptions_module # Also add model and exceptions to consts module attributes @@ -165,7 +183,9 @@ def _create_stub_module(name: str, **attrs): services_module = _create_stub_module("services") sys.modules['services'] = services_module sys.modules['services.image_service'] = _create_stub_module( - "services.image_service", get_vlm_model=MagicMock(return_value="stub_vlm") + "services.image_service", + get_vlm_model=MagicMock(return_value="stub_vlm"), + get_video_understanding_model=MagicMock(return_value="stub_video_vlm"), ) sys.modules['services.memory_config_service'] = MagicMock() # Extend services hierarchy with additional stubs @@ -250,6 +270,7 @@ def _create_stub_module(name: str, **attrs): _extract_url_from_card, _build_external_agent_config, _get_external_a2a_agents, + _build_internal_s3_url, _format_minio_files_for_content, _convert_history_with_minio_files, ) @@ -727,6 +748,48 @@ async def test_create_tool_config_list_with_analyze_image_tool(self): assert "validate_url_access" in mock_tool_instance.metadata assert callable(mock_tool_instance.metadata["validate_url_access"]) + @pytest.mark.asyncio + @pytest.mark.parametrize( + "class_name,tool_name", + [ + ("AnalyzeAudioTool", "analyze_audio"), + ("AnalyzeVideoTool", "analyze_video"), + ], + ) + async def test_create_tool_config_list_with_audio_video_tools(self, class_name, tool_name): + """Ensure audio/video tools receive video understanding model metadata.""" + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = class_name + mock_tool_config.return_value = mock_tool_instance + + with patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_video_understanding_model') as mock_get_video_model, \ + patch('backend.agents.create_agent_info.minio_client', new_callable=MagicMock): + + mock_search_tools.return_value = [ + { + "class_name": class_name, + "name": tool_name, + "description": "Analyze media tool", + "inputs": "string", + "output_type": "string", + "params": [{"name": "prompt", "default": "describe"}], + "source": "local", + "usage": None + } + ] + mock_get_video_model.return_value = "mock_video_model" + + result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + assert len(result) == 1 + assert result[0] is mock_tool_instance + mock_get_video_model.assert_called_once_with(tenant_id="tenant_1") + assert mock_tool_instance.metadata["vlm_model"] == "mock_video_model" + assert "storage_client" in mock_tool_instance.metadata + assert callable(mock_tool_instance.metadata["validate_url_access"]) + @pytest.mark.asyncio async def test_create_tool_config_list_with_analyze_text_file_tool(self): """Ensure AnalyzeTextFileTool receives text-specific metadata.""" @@ -3297,6 +3360,26 @@ async def test_create_agent_run_info_is_need_auth_true_includes_token(self): class TestJoinMinioFileDescriptionToQuery: """Tests for the join_minio_file_description_to_query function""" + def test_build_internal_s3_url_prefers_object_name(self): + file = { + "object_name": "attachments/user/image.png", + "url": "blob:http://localhost:3000/preview", + "name": "image.png", + } + + result = _build_internal_s3_url(file) + + assert result.endswith("/attachments/user/image.png") + assert result.startswith("s3://") + + def test_build_internal_s3_url_rejects_blob_preview_url(self): + file = { + "url": "blob:http://localhost:3000/preview", + "name": "image.png", + } + + assert _build_internal_s3_url(file) == "" + @pytest.mark.asyncio async def test_join_minio_file_description_to_query_with_files(self): """Test case with file descriptions""" @@ -3345,6 +3428,40 @@ async def test_join_minio_file_description_to_query_no_descriptions(self): assert result == "test query" + @pytest.mark.asyncio + async def test_join_minio_file_description_to_query_prefers_object_name_over_blob_url(self): + """Uploaded images should be exposed to internal tools through MinIO, not browser blob URLs.""" + minio_files = [ + { + "object_name": "attachments/user/image.png", + "url": "blob:http://localhost:3000/preview", + "name": "image.png", + } + ] + query = "describe the image" + + result = await join_minio_file_description_to_query(minio_files, query) + + assert "blob:http" not in result + assert "File name: image.png" in result + assert "attachments/user/image.png" in result + assert "S3 URL: s3://" in result + + @pytest.mark.asyncio + async def test_join_minio_file_description_to_query_skips_blob_only_file(self): + """Browser-only preview URLs cannot be used by internal tools.""" + minio_files = [ + { + "url": "blob:http://localhost:3000/preview", + "name": "image.png", + } + ] + query = "describe the image" + + result = await join_minio_file_description_to_query(minio_files, query) + + assert result == query + @pytest.mark.asyncio async def test_join_minio_file_description_to_query_deduplication_current(self): """Test that duplicate files in current message are de-duplicated by URL""" @@ -4455,6 +4572,21 @@ def test_format_minio_files_for_content_single_file_without_presigned_url(self): result = _format_minio_files_for_content(minio_files) assert result == "\n[Attached files]:\n - file.txt: s3:/bucket/file.txt" + def test_format_minio_files_for_content_uses_object_name_for_blob_url(self): + """Use uploaded object_name instead of browser-only blob preview URL.""" + minio_files = [ + { + "object_name": "attachments/user/image.png", + "url": "blob:http://localhost:3000/preview", + "name": "image.png", + } + ] + + result = _format_minio_files_for_content(minio_files) + + assert "blob:http" not in result + assert "attachments/user/image.png" in result + def test_format_minio_files_for_content_multiple_files(self): """Test case for multiple files""" minio_files = [ diff --git a/test/backend/services/providers/test_silicon_provider.py b/test/backend/services/providers/test_silicon_provider.py index b947040c3..c596643b2 100644 --- a/test/backend/services/providers/test_silicon_provider.py +++ b/test/backend/services/providers/test_silicon_provider.py @@ -66,7 +66,13 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): mock_response.status_code = 200 mock_response.json.return_value = { "data": [ - {"id": "gpt-4v", "name": "GPT-4 Vision"}, + {"id": "deepseek-ai/DeepSeek-R1", "name": "DeepSeek R1"}, + {"id": "Qwen/Qwen2.5-VL-72B-Instruct", "name": "Qwen2.5 VL"}, + {"id": "OpenGVLab/InternVL2-26B", "name": "InternVL2 26B"}, + {"id": "Pro/moonshotai/Kimi-K2.6", "name": "Kimi K2.6"}, + {"id": "Pro/moonshotai/Kimi-K2.5", "name": "Kimi K2.5"}, + {"id": "Qwen/Qwen3.6-27B", "name": "Qwen3.6 27B"}, + {"id": "Qwen/Qwen3.6-35B-A3B", "name": "Qwen3.6 35B A3B"}, ] } mock_response.raise_for_status = MagicMock() @@ -95,10 +101,66 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): result = await provider.get_models(provider_config) - assert len(result) == 1 - assert result[0]["id"] == "gpt-4v" - assert result[0]["model_type"] == "vlm" - assert result[0]["model_tag"] == "chat" + assert [model["id"] for model in result] == [ + "Qwen/Qwen2.5-VL-72B-Instruct", + "OpenGVLab/InternVL2-26B", + "Pro/moonshotai/Kimi-K2.6", + "Pro/moonshotai/Kimi-K2.5", + "Qwen/Qwen3.6-27B", + "Qwen/Qwen3.6-35B-A3B", + ] + assert all(model["model_type"] == "vlm" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + + @pytest.mark.asyncio + async def test_get_models_vlm3_only_returns_omni_models(self, mocker: MockFixture): + """Test that SiliconFlow video understanding models are restricted to Omni models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + {"id": "Qwen/Qwen3-VL-32B-Instruct", "name": "Qwen3 VL"}, + {"id": "Qwen/Qwen3-Omni-30B-A3B-Instruct", "name": "Qwen3 Omni Instruct"}, + {"id": "Qwen/Qwen3-Omni-30B-A3B-Thinking", "name": "Qwen3 Omni Thinking"}, + {"id": "Qwen/Qwen3-Omni-30B-A3B-Captioner", "name": "Qwen3 Omni Captioner"}, + {"id": "zai-org/GLM-4.5V", "name": "GLM 4.5V"}, + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.silicon_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.silicon_provider.SILICON_GET_URL", + "https://api.siliconflow.com/v1/models" + ) + + provider = SiliconModelProvider() + provider_config = { + "model_type": "vlm3", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert [model["id"] for model in result] == [ + "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "Qwen/Qwen3-Omni-30B-A3B-Thinking", + "Qwen/Qwen3-Omni-30B-A3B-Captioner", + ] + assert all(model["model_type"] == "vlm3" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + call_args = mock_client.get.call_args + assert "sub_type=chat" in call_args[0][0] @pytest.mark.asyncio async def test_get_models_embedding_success(self, mocker: MockFixture): diff --git a/test/backend/services/test_config_sync_service.py b/test/backend/services/test_config_sync_service.py index 7a583792b..d3db1db3b 100644 --- a/test/backend/services/test_config_sync_service.py +++ b/test/backend/services/test_config_sync_service.py @@ -1,4 +1,6 @@ import sys +import types +import importlib from unittest.mock import patch, MagicMock, call import pytest @@ -22,6 +24,31 @@ minio_config_mock = MagicMock() minio_config_mock.validate = MagicMock() +if 'consts.const' in sys.modules and not hasattr(sys.modules['consts.const'], 'APP_DESCRIPTION'): + sys.modules.pop('consts.const', None) +if 'consts' in sys.modules and not hasattr(sys.modules['consts'], '__path__'): + sys.modules.pop('consts', None) + +database_client_module = types.ModuleType('database.client') +database_client_module.MinioClient = MagicMock() +database_client_module.minio_client = minio_client_mock +database_client_module.as_dict = MagicMock(side_effect=lambda value: value) +database_client_module.db_client = MagicMock() +database_client_module.db_client.clean_string_values = MagicMock(side_effect=lambda value: value) +database_client_module.get_db_session = MagicMock() +sys.modules['database.client'] = database_client_module +database_package = sys.modules.get('database') or importlib.import_module('database') +setattr(database_package, 'client', database_client_module) +database_model_management_module = types.ModuleType('database.model_management_db') +database_model_management_module.get_model_by_model_id = MagicMock() +database_model_management_module.get_model_id_by_display_name = MagicMock() +database_model_management_module.get_model_records = MagicMock(return_value=[]) +sys.modules['database.model_management_db'] = database_model_management_module +setattr(database_package, 'model_management_db', database_model_management_module) +backend_database_client_module = sys.modules.get('backend.database.client') +if backend_database_client_module is not None and not hasattr(backend_database_client_module, 'minio_client'): + backend_database_client_module.minio_client = minio_client_mock + patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() patch('nexent.storage.minio_config.MinIOStorageConfig', @@ -29,7 +56,7 @@ patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() patch('database.client.MinioClient', return_value=minio_client_mock).start() -patch('backend.database.client.minio_client', minio_client_mock).start() +patch('backend.database.client.minio_client', minio_client_mock, create=True).start() patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start() # Import backend modules after all patches are applied @@ -52,14 +79,17 @@ def service_mocks(): with patch('backend.services.config_sync_service.tenant_config_manager') as mock_tenant_config_manager, \ patch('backend.services.config_sync_service.get_env_key') as mock_get_env_key, \ patch('backend.services.config_sync_service.safe_value') as mock_safe_value, \ + patch('backend.services.config_sync_service.get_model_records') as mock_get_model_records, \ patch('backend.services.config_sync_service.get_model_id_by_display_name') as mock_get_model_id, \ patch('backend.services.config_sync_service.get_model_name_from_config') as mock_get_model_name, \ patch('backend.services.config_sync_service.logger') as mock_logger: + mock_get_model_records.return_value = [] yield { 'tenant_config_manager': mock_tenant_config_manager, 'get_env_key': mock_get_env_key, 'safe_value': mock_safe_value, + 'get_model_records': mock_get_model_records, 'get_model_id': mock_get_model_id, 'get_model_name': mock_get_model_name, 'logger': mock_logger @@ -1336,6 +1366,8 @@ def side_effect(config_key, tenant_id=None): "MULTI_EMBEDDING_ID": {}, "RERANK_ID": {}, "VLM_ID": {}, + "VLM2_ID": {}, + "VLM3_ID": {}, "STT_ID": {}, "TTS_ID": {} } @@ -1348,7 +1380,7 @@ def side_effect(config_key, tenant_id=None): # Assert assert isinstance(result, dict) - assert len(result) == 7 # All model types should be present + assert len(result) == 9 # All model types should be present # Verify successful configs assert result["llm"]["displayName"] == "GPT-4" @@ -1372,20 +1404,20 @@ def test_build_models_config_all_failures(self, service_mocks): # Assert assert isinstance(result, dict) # All model types should still be present with empty configs - assert len(result) == 7 + assert len(result) == 9 # All configs should be empty due to exceptions - for model_key in ["llm", "embedding", "multiEmbedding", "rerank", "vlm", "stt", "tts"]: + for model_key in ["llm", "embedding", "multiEmbedding", "rerank", "vlm", "vlm2", "vlm3", "stt", "tts"]: assert result[model_key]["name"] == "" assert result[model_key]["displayName"] == "" assert result[model_key]["apiConfig"]["apiKey"] == "" assert result[model_key]["apiConfig"]["modelUrl"] == "" # Verify that logger.warning was called for each model type - assert service_mocks['logger'].warning.call_count == 7 + assert service_mocks['logger'].warning.call_count == 9 warning_calls = service_mocks['logger'].warning.call_args_list expected_configs = ["LLM_ID", "EMBEDDING_ID", "MULTI_EMBEDDING_ID", - "RERANK_ID", "VLM_ID", "STT_ID", "TTS_ID"] + "RERANK_ID", "VLM_ID", "VLM2_ID", "VLM3_ID", "STT_ID", "TTS_ID"] for i, config_key in enumerate(expected_configs): assert f"Failed to get config for {config_key}: Database completely down" in warning_calls[ i][0][0] diff --git a/test/backend/services/test_config_sync_service_voice.py b/test/backend/services/test_config_sync_service_voice.py index fcfd531f1..1a3144036 100644 --- a/test/backend/services/test_config_sync_service_voice.py +++ b/test/backend/services/test_config_sync_service_voice.py @@ -3,6 +3,8 @@ These tests cover the STT specific fields in save_config_impl. """ import sys +import types +import importlib from unittest.mock import patch, MagicMock import pytest @@ -22,6 +24,31 @@ minio_config_mock = MagicMock() minio_config_mock.validate = MagicMock() +if 'consts.const' in sys.modules and not hasattr(sys.modules['consts.const'], 'APP_DESCRIPTION'): + sys.modules.pop('consts.const', None) +if 'consts' in sys.modules and not hasattr(sys.modules['consts'], '__path__'): + sys.modules.pop('consts', None) + +database_client_module = types.ModuleType('database.client') +database_client_module.MinioClient = MagicMock() +database_client_module.minio_client = minio_client_mock +database_client_module.as_dict = MagicMock(side_effect=lambda value: value) +database_client_module.db_client = MagicMock() +database_client_module.db_client.clean_string_values = MagicMock(side_effect=lambda value: value) +database_client_module.get_db_session = MagicMock() +sys.modules['database.client'] = database_client_module +database_package = sys.modules.get('database') or importlib.import_module('database') +setattr(database_package, 'client', database_client_module) +database_model_management_module = types.ModuleType('database.model_management_db') +database_model_management_module.get_model_by_model_id = MagicMock() +database_model_management_module.get_model_id_by_display_name = MagicMock() +database_model_management_module.get_model_records = MagicMock(return_value=[]) +sys.modules['database.model_management_db'] = database_model_management_module +setattr(database_package, 'model_management_db', database_model_management_module) +backend_database_client_module = sys.modules.get('backend.database.client') +if backend_database_client_module is not None and not hasattr(backend_database_client_module, 'minio_client'): + backend_database_client_module.minio_client = minio_client_mock + patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() patch('nexent.storage.minio_config.MinIOStorageConfig', @@ -29,7 +56,7 @@ patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() patch('database.client.MinioClient', return_value=minio_client_mock).start() -patch('backend.database.client.minio_client', minio_client_mock).start() +patch('backend.database.client.minio_client', minio_client_mock, create=True).start() patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start() # Import backend modules after all patches are applied @@ -47,14 +74,17 @@ def service_mocks(): with patch('backend.services.config_sync_service.tenant_config_manager') as mock_tenant_config_manager, \ patch('backend.services.config_sync_service.get_env_key') as mock_get_env_key, \ patch('backend.services.config_sync_service.safe_value') as mock_safe_value, \ + patch('backend.services.config_sync_service.get_model_records') as mock_get_model_records, \ patch('backend.services.config_sync_service.get_model_id_by_display_name') as mock_get_model_id, \ patch('backend.services.config_sync_service.get_model_name_from_config') as mock_get_model_name, \ patch('backend.services.config_sync_service.logger') as mock_logger: + mock_get_model_records.return_value = [] yield { 'tenant_config_manager': mock_tenant_config_manager, 'get_env_key': mock_get_env_key, 'safe_value': mock_safe_value, + 'get_model_records': mock_get_model_records, 'get_model_id': mock_get_model_id, 'get_model_name': mock_get_model_name, 'logger': mock_logger diff --git a/test/backend/services/test_image_service.py b/test/backend/services/test_image_service.py index 1de8d49fd..34f24568c 100644 --- a/test/backend/services/test_image_service.py +++ b/test/backend/services/test_image_service.py @@ -13,10 +13,17 @@ helpers_env = bootstrap_test_env() helpers_env["mock_const"].DATA_PROCESS_SERVICE = "http://mock-data-process-service" -helpers_env["mock_const"].MODEL_CONFIG_MAPPING = {"vlm": "vlm_model_config"} +helpers_env["mock_const"].MODEL_CONFIG_MAPPING = { + "vlm": "vlm_model_config", + "vlm3": "video_model_config", +} mock_const = helpers_env["mock_const"] -from services.image_service import get_vlm_model, proxy_image_impl +from services.image_service import get_image_understanding_model, get_video_understanding_model, get_vlm_model, proxy_image_impl + +image_service_module = sys.modules[get_vlm_model.__module__] +if "services" in sys.modules: + setattr(sys.modules["services"], "image_service", image_service_module) # Sample test data test_url = "https://example.com/image.jpg" @@ -50,7 +57,7 @@ async def test_proxy_image_impl_success(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function @@ -85,7 +92,7 @@ async def test_proxy_image_impl_remote_error(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function @@ -118,7 +125,7 @@ async def test_proxy_image_impl_500_error(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function @@ -146,7 +153,7 @@ async def test_proxy_image_impl_connection_exception(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function - should raise the exception @@ -178,7 +185,7 @@ async def test_proxy_image_impl_with_special_chars(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function @@ -213,7 +220,7 @@ async def test_proxy_image_impl_json_parse_error(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function - should raise the exception @@ -253,7 +260,7 @@ async def test_proxy_image_impl_different_status_codes(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function @@ -289,7 +296,7 @@ async def test_proxy_image_impl_url_encoding(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function with encoded URL @@ -305,10 +312,10 @@ async def test_proxy_image_impl_url_encoding(): assert f"url={encoded_url}" in called_url -@patch('services.image_service.OpenAIVLModel') -@patch('services.image_service.MessageObserver') -@patch('services.image_service.get_model_name_from_config') -@patch('services.image_service.tenant_config_manager') +@patch.object(image_service_module, 'OpenAIVLModel') +@patch.object(image_service_module, 'MessageObserver') +@patch.object(image_service_module, 'get_model_name_from_config') +@patch.object(image_service_module, 'tenant_config_manager') def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model): """Ensure get_vlm_model builds OpenAIVLModel with tenant config.""" mock_config = { @@ -324,7 +331,7 @@ def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name, result = get_vlm_model("tenant-1") mock_tenant_config_manager.get_model_config.assert_called_once_with( - key=mock_const.MODEL_CONFIG_MAPPING["vlm"], + key="vlm_model_config", tenant_id="tenant-1" ) mock_message_observer.assert_called_once_with() @@ -342,10 +349,10 @@ def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name, assert result == mock_model_instance -@patch('services.image_service.OpenAIVLModel') -@patch('services.image_service.MessageObserver') -@patch('services.image_service.get_model_name_from_config') -@patch('services.image_service.tenant_config_manager') +@patch.object(image_service_module, 'OpenAIVLModel') +@patch.object(image_service_module, 'MessageObserver') +@patch.object(image_service_module, 'get_model_name_from_config') +@patch.object(image_service_module, 'tenant_config_manager') def test_get_vlm_model_with_none_config(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model): """Return None when tenant config is None.""" mock_tenant_config_manager.get_model_config.return_value = None @@ -359,3 +366,40 @@ def test_get_vlm_model_with_none_config(mock_tenant_config_manager, mock_get_mod # OpenAIVLModel should not be called when config is None mock_openai_vl_model.assert_not_called() assert result is None + + +@patch.object(image_service_module, 'get_vlm_model') +def test_get_image_understanding_model_uses_first_multimodal_slot(mock_get_vlm_model): + """Ensure the image understanding alias keeps using the first multimodal slot.""" + mock_get_vlm_model.return_value = "image-understanding-model" + + result = get_image_understanding_model("tenant-1") + + mock_get_vlm_model.assert_called_once_with(tenant_id="tenant-1") + assert result == "image-understanding-model" + + +@patch.object(image_service_module, 'OpenAIVLModel') +@patch.object(image_service_module, 'MessageObserver') +@patch.object(image_service_module, 'get_model_name_from_config') +@patch.object(image_service_module, 'tenant_config_manager') +def test_get_video_understanding_model_success(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model): + """Ensure video understanding tools use the third multimodal model slot.""" + mock_config = { + "base_url": "https://mock-video-api", + "api_key": "secret", + "model_name": "video-model" + } + mock_tenant_config_manager.get_model_config.return_value = mock_config + mock_get_model_name.return_value = "video-model" + mock_model_instance = MagicMock() + mock_openai_vl_model.return_value = mock_model_instance + + result = get_video_understanding_model("tenant-1") + + mock_tenant_config_manager.get_model_config.assert_called_once_with( + key="video_model_config", + tenant_id="tenant-1" + ) + mock_openai_vl_model.assert_called_once() + assert result == mock_model_instance diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 479e649bf..83a070fe0 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -326,6 +326,11 @@ def _add_repo_to_name(model_repo, model_name): def import_svc(): """Import service under MinioClient patch to avoid real initialization.""" minio_client_mock = mock.MagicMock() + sys.modules["database"] = database_mod + sys.modules["database.model_management_db"] = db_mm_mod + setattr(database_mod, "model_management_db", db_mm_mod) + sys.modules.pop("backend.services.model_management_service", None) + sys.modules.pop("services.model_management_service", None) with mock.patch("backend.database.client.MinioClient", return_value=minio_client_mock): from backend.services import model_management_service as svc # type: ignore return svc @@ -673,17 +678,11 @@ async def test_batch_create_models_for_tenant_flow(): existing = [ {"model_id": "del-id", "model_repo": "silicon", "model_name": "delete"}, - {"model_id": "keep-id", "model_repo": "silicon", "model_name": "keep"}, + {"model_id": "keep-id", "model_repo": "silicon", "model_name": "keep", "max_tokens": 1024}, ] - def get_by_display(display_name, tenant_id): - if display_name == "silicon/keep": - return {"model_id": "keep-id", "max_tokens": 1024} - return None - with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=existing) as mock_get_existing, \ mock.patch.object(svc, "delete_model_record") as mock_delete, \ - mock.patch.object(svc, "get_model_by_display_name", side_effect=get_by_display) as mock_get_by_display, \ mock.patch.object(svc, "update_model_record") as mock_update, \ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"prepared": True})) as mock_prep, \ mock.patch.object(svc, "create_model_record") as mock_create: @@ -692,13 +691,35 @@ def get_by_display(display_name, tenant_id): mock_get_existing.assert_called_once_with("t1", "silicon", "llm") mock_delete.assert_called_once_with("del-id", "u1", "t1") - mock_get_by_display.assert_any_call("silicon/keep", "t1") mock_update.assert_called_once_with( "keep-id", {"max_tokens": 4096}, "u1") mock_prep.assert_awaited() mock_create.assert_called_once() +@pytest.mark.asyncio +async def test_batch_create_models_uses_requested_type_for_each_model(): + svc = import_svc() + + batch_payload = { + "provider": "silicon", + "type": "vlm", + "models": [ + {"id": "Qwen/Qwen2.5-VL-72B-Instruct", "model_type": "llm", "max_tokens": 4096}, + ], + "api_key": "k", + } + + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \ + mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"prepared": True})) as mock_prep, \ + mock.patch.object(svc, "create_model_record"): + + await svc.batch_create_models_for_tenant("u1", "t1", batch_payload) + + prepared_model = mock_prep.call_args.kwargs["model"] + assert prepared_model["model_type"] == "vlm" + + @pytest.mark.asyncio async def test_batch_create_models_max_tokens_update(): """Test batch_create_models updates max_tokens when display_name exists and max_tokens changed (covers lines 160->173, 168->171)""" @@ -715,22 +736,16 @@ async def test_batch_create_models_max_tokens_update(): "api_key": "k", } - def get_by_display(display_name, tenant_id): - if display_name == "silicon/model1": - # Different from new value - return {"model_id": "id1", "max_tokens": 4096} - elif display_name == "silicon/model2": - return {"model_id": "id2", "max_tokens": 4096} # Same as new value - elif display_name == "silicon/model3": - # Existing has value, new is None - return {"model_id": "id3", "max_tokens": 2048} - return None + existing = [ + {"model_id": "id1", "model_repo": "silicon", "model_name": "model1", "max_tokens": 4096}, + {"model_id": "id2", "model_repo": "silicon", "model_name": "model2", "max_tokens": 4096}, + {"model_id": "id3", "model_repo": "silicon", "model_name": "model3", "max_tokens": 2048}, + ] - with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \ + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=existing), \ mock.patch.object(svc, "delete_model_record"), \ mock.patch.object(svc, "split_repo_name", side_effect=lambda x: ("silicon", x.split("/")[1] if "/" in x else x)), \ - mock.patch.object(svc, "add_repo_to_name", side_effect=lambda r, n: f"{r}/{n}"), \ - mock.patch.object(svc, "get_model_by_display_name", side_effect=get_by_display) as mock_get_by_display, \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda *args, **kwargs: f"{kwargs.get('model_repo', args[0] if args else '')}/{kwargs.get('model_name', args[1] if len(args) > 1 else '')}"), \ mock.patch.object(svc, "update_model_record") as mock_update, \ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \ mock.patch.object(svc, "create_model_record", return_value=True): diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 7855fbbbb..7b08788fa 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -14,8 +14,59 @@ boto3_mock = MagicMock() minio_client_mock = MagicMock() sys.modules['boto3'] = boto3_mock +jsonref_mock = types.ModuleType('jsonref') +jsonref_mock.replace_refs = lambda value: value +sys.modules['jsonref'] = jsonref_mock -# Patch smolagents and its sub-modules before importi ng consts.model to avoid ImportError +fastmcp_mock = types.ModuleType('fastmcp') +fastmcp_mock.__path__ = [] + + +class MockFastMcpClient: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def is_connected(self): + return True + + async def call_tool(self, *args, **kwargs): + return MagicMock() + + +class MockSSETransport: + def __init__(self, *args, **kwargs): + pass + + +class MockStreamableHttpTransport: + def __init__(self, *args, **kwargs): + pass + + +fastmcp_mock.Client = MockFastMcpClient +fastmcp_client_mock = types.ModuleType('fastmcp.client') +fastmcp_client_mock.__path__ = [] +fastmcp_transports_mock = types.ModuleType('fastmcp.client.transports') +fastmcp_transports_mock.SSETransport = MockSSETransport +fastmcp_transports_mock.StreamableHttpTransport = MockStreamableHttpTransport +sys.modules['fastmcp'] = fastmcp_mock +sys.modules['fastmcp.client'] = fastmcp_client_mock +sys.modules['fastmcp.client.transports'] = fastmcp_transports_mock + +mcpadapt_mock = types.ModuleType('mcpadapt') +mcpadapt_mock.__path__ = [] +mcpadapt_smolagents_adapter_mock = types.ModuleType('mcpadapt.smolagents_adapter') +mcpadapt_smolagents_adapter_mock._sanitize_function_name = lambda name: name +sys.modules['mcpadapt'] = mcpadapt_mock +sys.modules['mcpadapt.smolagents_adapter'] = mcpadapt_smolagents_adapter_mock + +# Patch smolagents and its sub-modules before importing consts.model to avoid ImportError mock_smolagents = MagicMock() sys.modules['smolagents'] = mock_smolagents @@ -328,6 +379,38 @@ def validate(self): sys.modules['nexent.monitor'] = monitor_module setattr(sys.modules['nexent'], 'monitor', monitor_module) +# Mock services modules before importing tool_configuration_service so absolute +# imports inside that module do not walk into real service dependency chains. +sys.modules['services'] = _create_package_mock('services') +services_modules = { + 'file_management_service': { + 'get_llm_model': MagicMock(), + 'validate_urls_access': MagicMock(return_value=True), + }, + 'vectordatabase_service': { + 'get_embedding_model': MagicMock(), + 'get_embedding_model_by_index_name': MagicMock(), + 'get_rerank_model': MagicMock(), + 'get_vector_db_core': MagicMock(), + 'ElasticSearchService': MagicMock(), + }, + 'tenant_config_service': { + 'get_selected_knowledge_list': MagicMock(), + 'build_knowledge_name_mapping': MagicMock(), + }, + 'image_service': { + 'get_vlm_model': MagicMock(), + 'get_video_understanding_model': MagicMock(), + }, +} +for service_name, attrs in services_modules.items(): + service_module = types.ModuleType(f'services.{service_name}') + for attr_name, attr_value in attrs.items(): + setattr(service_module, attr_name, attr_value) + sys.modules[f'services.{service_name}'] = service_module + # Expose on parent package for patch resolution + setattr(sys.modules['services'], service_name, service_module) + # Load actual backend modules so that patch targets resolve correctly import importlib # noqa: E402 backend_module = importlib.import_module('backend') @@ -342,23 +425,6 @@ def validate(self): # Ensure services package can resolve tool_configuration_service for patching sys.modules['services.tool_configuration_service'] = backend_services_module -# Mock services modules -sys.modules['services'] = _create_package_mock('services') -services_modules = { - 'file_management_service': {'get_llm_model': MagicMock()}, - 'vectordatabase_service': {'get_embedding_model': MagicMock(), 'get_vector_db_core': MagicMock(), - 'ElasticSearchService': MagicMock()}, - 'tenant_config_service': {'get_selected_knowledge_list': MagicMock(), 'build_knowledge_name_mapping': MagicMock()}, - 'image_service': {'get_vlm_model': MagicMock()} -} -for service_name, attrs in services_modules.items(): - service_module = types.ModuleType(f'services.{service_name}') - for attr_name, attr_value in attrs.items(): - setattr(service_module, attr_name, attr_value) - sys.modules[f'services.{service_name}'] = service_module - # Expose on parent package for patch resolution - setattr(sys.modules['services'], service_name, service_module) - # Patch storage factory and MinIO config validation to avoid errors during initialization # These patches must be started before any imports that use MinioClient storage_client_mock = MagicMock() @@ -380,6 +446,7 @@ def validate(self): patch('services.tenant_config_service.build_knowledge_name_mapping', MagicMock()).start() patch('services.image_service.get_vlm_model', MagicMock()).start() +patch('services.image_service.get_video_understanding_model', MagicMock()).start() patch('backend.database.knowledge_db.get_knowledge_name_map_by_index_names', MagicMock()).start() patch('backend.services.tool_configuration_service.get_embedding_model_by_index_name', MagicMock()).start() @@ -2703,6 +2770,63 @@ def test_validate_local_tool_analyze_image_missing_user(self, mock_get_class): ) +class TestValidateLocalToolAnalyzeAudioVideo: + """Test cases for _validate_local_tool with analyze_audio/analyze_video tools.""" + + @pytest.mark.parametrize("tool_name", ["analyze_audio", "analyze_video"]) + @patch('backend.services.tool_configuration_service.minio_client') + @patch('backend.services.tool_configuration_service.get_video_understanding_model') + @patch('backend.services.tool_configuration_service._get_tool_class_by_name') + @patch('backend.services.tool_configuration_service.inspect.signature') + def test_validate_local_tool_analyze_audio_video_success( + self, mock_signature, mock_get_class, mock_get_video_model, mock_minio_client, tool_name): + mock_tool_class = Mock() + mock_tool_instance = Mock() + mock_tool_instance.forward.return_value = f"{tool_name} result" + mock_tool_class.return_value = mock_tool_instance + mock_get_class.return_value = mock_tool_class + mock_get_video_model.return_value = "mock_video_model" + + mock_sig = Mock() + mock_sig.parameters = {} + mock_signature.return_value = mock_sig + + from backend.services.tool_configuration_service import _validate_local_tool + + result = _validate_local_tool( + tool_name, + {"media": "bytes"}, + {"prompt": "describe"}, + "tenant1", + "user1" + ) + + assert result == f"{tool_name} result" + mock_get_video_model.assert_called_once_with(tenant_id="tenant1") + call_kwargs = mock_tool_class.call_args.kwargs + assert call_kwargs["vlm_model"] == "mock_video_model" + assert "storage_client" in call_kwargs + assert callable(call_kwargs["validate_url_access"]) + mock_tool_instance.forward.assert_called_once_with(media="bytes") + + @pytest.mark.parametrize("tool_name", ["analyze_audio", "analyze_video"]) + @patch('backend.services.tool_configuration_service._get_tool_class_by_name') + def test_validate_local_tool_analyze_audio_video_missing_tenant(self, mock_get_class, tool_name): + mock_get_class.return_value = Mock() + + from backend.services.tool_configuration_service import _validate_local_tool + + with pytest.raises(ToolExecutionException, + match=f"Tenant ID and User ID are required for {tool_name} validation"): + _validate_local_tool( + tool_name, + {"media": "bytes"}, + {"prompt": "describe"}, + None, + "user1" + ) + + class TestValidateLocalToolDatamateSearchTool: """Test cases for _validate_local_tool function with datamate_search_tool""" diff --git a/test/common/test_mocks.py b/test/common/test_mocks.py index c87b52859..c57941780 100644 --- a/test/common/test_mocks.py +++ b/test/common/test_mocks.py @@ -112,6 +112,8 @@ def setup_common_mocks(): "multiEmbedding": "MULTI_EMBEDDING_ID", "rerank": "RERANK_ID", "vlm": "VLM_ID", + "vlm2": "VLM2_ID", + "vlm3": "VLM3_ID", "stt": "STT_ID", "tts": "TTS_ID" } diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 7d6af852c..ab09b95b6 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -2475,6 +2475,52 @@ def test_create_local_tool_analyze_image(self, nexent_agent_instance): assert call_kwargs["param1"] == "value1" assert result == mock_tool_instance + @pytest.mark.parametrize( + "class_name,tool_name", + [ + ("AnalyzeAudioTool", "analyze_audio"), + ("AnalyzeVideoTool", "analyze_video"), + ], + ) + def test_create_local_tool_analyze_audio_video(self, nexent_agent_instance, class_name, tool_name): + """Test successful audio/video analysis tool creation.""" + mock_tool_class = MagicMock() + mock_tool_instance = MagicMock() + mock_tool_class.return_value = mock_tool_instance + + tool_config = ToolConfig( + class_name=class_name, + name=tool_name, + description="desc", + inputs="{}", + output_type="string", + params={"param1": "value1"}, + source="local", + metadata={ + "vlm_model": ["video-understanding-model"], + "storage_client": "storage" + } + ) + + original_value = nexent_agent.__dict__.get(class_name) + nexent_agent.__dict__[class_name] = mock_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__[class_name] = original_value + elif class_name in nexent_agent.__dict__: + del nexent_agent.__dict__[class_name] + + mock_tool_class.assert_called_once() + call_kwargs = mock_tool_class.call_args[1] + assert call_kwargs["observer"] == nexent_agent_instance.observer + assert call_kwargs["vlm_model"] == ["video-understanding-model"] + assert call_kwargs["storage_client"] == "storage" + assert call_kwargs["param1"] == "value1" + assert result == mock_tool_instance + def test_create_local_tool_analyze_text_file_with_validate_url_access_none(self, nexent_agent_instance): """Test AnalyzeTextFileTool creation with validate_url_access not in metadata (None).""" mock_tool_class = MagicMock() diff --git a/test/sdk/core/models/test_openai_vlm.py b/test/sdk/core/models/test_openai_vlm.py index f1db49380..4f7104290 100644 --- a/test/sdk/core/models/test_openai_vlm.py +++ b/test/sdk/core/models/test_openai_vlm.py @@ -62,7 +62,13 @@ def vl_model_instance(): """Return an OpenAIVLModel instance with minimal viable attributes for tests.""" observer = MagicMock() - model = ImportedOpenAIVLModel(observer=observer, ssl_verify=True) + model = ImportedOpenAIVLModel( + observer=observer, + model_id="dummy-model", + api_key="dummy-key", + api_base="https://example.test", + ssl_verify=True, + ) # Inject dummy attributes required by the method under test model.model_id = "dummy-model" @@ -321,3 +327,55 @@ def test_analyze_image_calls_prepare_image_message(vl_model_instance, tmp_path): # Verify prepare_image_message was called with correct arguments mock_prepare.assert_called_once_with(str(test_image), custom_prompt) + + +def test_prepare_media_message_audio(vl_model_instance): + audio_stream = MagicMock() + audio_stream.read.return_value = b"audio bytes" + + messages = vl_model_instance.prepare_media_message( + audio_stream, + media_type="audio", + content_type="audio/mpeg", + system_prompt="Listen carefully", + ) + + assert messages[0]["content"][0]["type"] == "audio_url" + assert messages[0]["content"][0]["audio_url"]["url"].startswith("data:audio/mpeg;base64,") + assert messages[0]["content"][1] == {"type": "text", "text": "Listen carefully"} + + +def test_prepare_media_message_video(vl_model_instance): + video_stream = MagicMock() + video_stream.read.return_value = b"video bytes" + + messages = vl_model_instance.prepare_media_message( + video_stream, + media_type="video", + content_type="video/mp4", + system_prompt="Watch carefully", + ) + + assert messages[0]["content"][0]["type"] == "video_url" + assert messages[0]["content"][0]["video_url"]["url"].startswith("data:video/mp4;base64,") + assert messages[0]["content"][0]["video_url"]["max_frames"] == 16 + assert messages[0]["content"][0]["video_url"]["fps"] == 1 + assert messages[0]["content"][1] == {"type": "text", "text": "Watch carefully"} + + +def test_analyze_audio_calls_prepare_media_message(vl_model_instance): + with patch.object(vl_model_instance, "prepare_media_message", return_value=[{"role": "user", "content": "test"}]) as mock_prepare: + vl_model_instance.__call__ = MagicMock(return_value=MagicMock()) + + vl_model_instance.analyze_audio("audio.mp3", system_prompt="Analyze", content_type="audio/mpeg") + + mock_prepare.assert_called_once_with("audio.mp3", "audio", "audio/mpeg", "Analyze") + + +def test_analyze_video_calls_prepare_media_message(vl_model_instance): + with patch.object(vl_model_instance, "prepare_media_message", return_value=[{"role": "user", "content": "test"}]) as mock_prepare: + vl_model_instance.__call__ = MagicMock(return_value=MagicMock()) + + vl_model_instance.analyze_video("video.mp4", system_prompt="Analyze", content_type="video/mp4") + + mock_prepare.assert_called_once_with("video.mp4", "video", "video/mp4", "Analyze") diff --git a/test/sdk/core/tools/test_analyze_audio_video_tool.py b/test/sdk/core/tools/test_analyze_audio_video_tool.py new file mode 100644 index 000000000..94401b61d --- /dev/null +++ b/test/sdk/core/tools/test_analyze_audio_video_tool.py @@ -0,0 +1,119 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from sdk.nexent.core.tools import analyze_audio_tool, analyze_video_tool +from sdk.nexent.core.tools.analyze_audio_tool import AnalyzeAudioTool +from sdk.nexent.core.tools.analyze_video_tool import AnalyzeVideoTool +from sdk.nexent.core.utils.observer import MessageObserver, ProcessType + + +@pytest.fixture +def mock_storage_client(): + class DummyStorage: + pass + + return DummyStorage() + + +@pytest.fixture +def mock_vlm_model(): + return MagicMock() + + +@pytest.fixture +def observer_en(): + observer = MagicMock(spec=MessageObserver) + observer.lang = "en" + return observer + + +def test_analyze_audio_uses_video_understanding_model(observer_en, mock_vlm_model, mock_storage_client, monkeypatch): + calls = [] + + def _fake_get_prompt(template_type, language=None, **_): + calls.append((template_type, language)) + return {"system_prompt": "Analyze audio for {{ query }}"} + + monkeypatch.setattr(analyze_audio_tool, "get_prompt_template", _fake_get_prompt) + mock_vlm_model.analyze_audio.return_value = SimpleNamespace(content="audio result") + tool = AnalyzeAudioTool( + observer=observer_en, + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + ) + + result = tool._forward_impl([b"ID3audio-bytes"], "what happened?") + + assert result == ["audio result"] + assert calls == [("analyze_audio", "en")] + mock_vlm_model.analyze_audio.assert_called_once() + call_kwargs = mock_vlm_model.analyze_audio.call_args.kwargs + assert hasattr(call_kwargs["audio_input"], "read") + assert call_kwargs["content_type"].startswith("audio/") + observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing audio...") + + +def test_analyze_audio_rejects_siliconflow_non_omni_model(observer_en, mock_storage_client): + vlm_model = SimpleNamespace( + model_id="Qwen/Qwen3-VL-32B-Instruct", + client_kwargs={"base_url": "https://api.siliconflow.cn/v1"}, + ) + tool = AnalyzeAudioTool( + observer=observer_en, + vlm_model=vlm_model, + storage_client=mock_storage_client, + ) + + with pytest.raises(ValueError) as exc_info: + tool._forward_impl([b"ID3audio-bytes"], "what happened?") + + assert "Please choose a Qwen3-Omni model" in str(exc_info.value) + + +def test_analyze_video_uses_video_understanding_model(observer_en, mock_vlm_model, mock_storage_client, monkeypatch): + calls = [] + + def _fake_get_prompt(template_type, language=None, **_): + calls.append((template_type, language)) + return {"system_prompt": "Analyze video for {{ query }}"} + + monkeypatch.setattr(analyze_video_tool, "get_prompt_template", _fake_get_prompt) + mock_vlm_model.analyze_video.return_value = SimpleNamespace(content="video result") + tool = AnalyzeVideoTool( + observer=observer_en, + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + ) + + result = tool._forward_impl([b"\x00\x00\x00\x18ftypmp42video-bytes"], "what happened?") + + assert result == ["video result"] + assert calls == [("analyze_video", "en")] + mock_vlm_model.analyze_video.assert_called_once() + call_kwargs = mock_vlm_model.analyze_video.call_args.kwargs + assert hasattr(call_kwargs["video_input"], "read") + assert call_kwargs["content_type"].startswith("video/") + observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing video...") + + +@pytest.mark.parametrize( + "tool_class,input_name,error_text", + [ + (AnalyzeAudioTool, "audio_urls_list", "Video understanding model is not configured"), + (AnalyzeVideoTool, "video_urls_list", "Video understanding model is not configured"), + ], +) +def test_analyze_audio_video_require_video_understanding_model( + tool_class, input_name, error_text, observer_en, mock_storage_client): + tool = tool_class( + observer=observer_en, + vlm_model=None, + storage_client=mock_storage_client, + ) + + with pytest.raises(Exception) as exc_info: + tool._forward_impl(**{input_name: [b"media"], "query": "question"}) + + assert error_text in str(exc_info.value) diff --git a/test/sdk/core/tools/test_analyze_image_tool.py b/test/sdk/core/tools/test_analyze_image_tool.py index a8598a8ad..63be0ac54 100644 --- a/test/sdk/core/tools/test_analyze_image_tool.py +++ b/test/sdk/core/tools/test_analyze_image_tool.py @@ -136,7 +136,7 @@ def test_forward_impl_vlm_model_none(self, observer_en, mock_storage_client): with pytest.raises(Exception) as exc_info: tool._forward_impl([b"img"], "question") - assert "Vision Language Model (VLM) is not configured" in str( + assert "Image understanding model is not configured" in str( exc_info.value) def test_forward_impl_vlm_model_none_chinese(self, observer_zh, mock_storage_client): @@ -150,7 +150,7 @@ def test_forward_impl_vlm_model_none_chinese(self, observer_zh, mock_storage_cli with pytest.raises(Exception) as exc_info: tool._forward_impl([b"img"], "问题") - assert "视觉语言模型(VLM)未配置" in str(exc_info.value) + assert "图片理解模型未配置" in str(exc_info.value) def test_forward_impl_observer_none_uses_english(self, mock_vlm_model, mock_storage_client): """Test that English is used when observer is None.""" @@ -353,7 +353,7 @@ def test_observer_add_message_not_called_when_none(self, mock_vlm_model, mock_st def test_tool_name_and_description(self, tool): """Test that tool name and description are set correctly.""" assert tool.name == "analyze_image" - assert "visual language model" in tool.description.lower() + assert "image understanding model" in tool.description.lower() assert "image" in tool.description.lower() def test_tool_inputs_schema(self, tool): diff --git a/test/sdk/core/utils/test_prompt_template_utils.py b/test/sdk/core/utils/test_prompt_template_utils.py index c0a3ad634..a50929b8d 100644 --- a/test/sdk/core/utils/test_prompt_template_utils.py +++ b/test/sdk/core/utils/test_prompt_template_utils.py @@ -61,6 +61,28 @@ def test_get_prompt_template_analyze_image_en(self, mock_yaml_load, mock_file): # Verify result assert result == {"system_prompt": "Test prompt", "user_prompt": "User prompt"} + @pytest.mark.parametrize( + "template_type,language,expected_file", + [ + ("analyze_audio", "en", "prompts/analyze_audio_en.yaml"), + ("analyze_audio", "zh", "prompts/analyze_audio_zh.yaml"), + ("analyze_video", "en", "prompts/analyze_video_en.yaml"), + ("analyze_video", "zh", "prompts/analyze_video_zh.yaml"), + ], + ) + @patch('builtins.open', new_callable=mock_open, read_data='system_prompt: "Test prompt"\nuser_prompt: "User prompt"') + @patch('yaml.safe_load') + def test_get_prompt_template_analyze_audio_video( + self, mock_yaml_load, mock_file, template_type, language, expected_file): + """Test get_prompt_template for audio/video templates.""" + mock_yaml_load.return_value = {"system_prompt": "Test prompt", "user_prompt": "User prompt"} + + result = get_prompt_template(template_type=template_type, language=language) + + call_args = mock_file.call_args[0] + assert expected_file in call_args[0].replace('\\', '/') + assert result == {"system_prompt": "Test prompt", "user_prompt": "User prompt"} + @patch('builtins.open', new_callable=mock_open, read_data='system_prompt: "Test prompt"') @patch('yaml.safe_load') @patch('sdk.nexent.core.utils.prompt_template_utils.LANGUAGE', {'ZH': 'zh', 'EN': 'en'}) @@ -174,4 +196,4 @@ def test_get_prompt_template_path_resolution(self, mock_yaml_load, mock_file): assert mock_file.called call_args = mock_file.call_args[0] # Path should be absolute or contain the expected template file - assert 'analyze_image_en.yaml' in call_args[0] \ No newline at end of file + assert 'analyze_image_en.yaml' in call_args[0] diff --git a/test/sdk/multi_modal/test_load_save_object.py b/test/sdk/multi_modal/test_load_save_object.py index 92425791c..1670e6a9d 100644 --- a/test/sdk/multi_modal/test_load_save_object.py +++ b/test/sdk/multi_modal/test_load_save_object.py @@ -26,6 +26,23 @@ def test_get_client_requires_initialized_storage(): manager._get_client() +def test_s3_single_slash_url_supported(): + assert lso.is_url("s3:/bucket/path/to/image.png") == "s3" + assert lso.parse_s3_url("s3:/bucket/path/to/image.png") == ( + "bucket", + "path/to/image.png", + ) + + +def test_s3_blob_preview_url_rejected(): + assert lso.is_url("s3:/blob:http://localhost:3000/preview") is None + + +def test_parse_s3_blob_preview_url_rejected(): + with pytest.raises(ValueError, match="Invalid s3:// URL format"): + lso.parse_s3_url("s3:/blob:http://localhost:3000/preview") + + def test_download_file_from_http(monkeypatch): manager = make_manager()