diff --git a/.dockerignore b/.dockerignore
index 45c1def32..385a6449f 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -37,8 +37,6 @@ build/
*.tgz
# Backend
-backend/assets/*
-!backend/assets/test.wav
backend/flower_db.sqlite
uploads/
test/
@@ -60,4 +58,4 @@ assets/
.Spotlight-V100
.Trashes
ehthumbs.db
-Thumbs.db
\ No newline at end of file
+Thumbs.db
diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py
index 8a3fbc807..b38b17e56 100644
--- a/backend/agents/create_agent_info.py
+++ b/backend/agents/create_agent_info.py
@@ -21,7 +21,7 @@
from database.a2a_agent_db import PROTOCOL_JSONRPC
from services.memory_config_service import build_memory_context
-from services.image_service import get_vlm_model
+from services.image_service import get_video_understanding_model, get_vlm_model
from database.agent_db import search_agent_info_by_agent_id, query_sub_agents_id_list
from database.agent_version_db import query_current_version_no
from database.tool_db import search_tools_for_sub_agent
@@ -31,13 +31,36 @@
from utils.model_name_utils import add_repo_to_name
from utils.prompt_template_utils import get_agent_prompt_template
from utils.config_utils import tenant_config_manager, get_model_name_from_config
-from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE
+from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE, MINIO_DEFAULT_BUCKET
from consts.exceptions import ValidationError
logger = logging.getLogger("create_agent_info")
logger.setLevel(logging.DEBUG)
+def _build_internal_s3_url(file: dict) -> str:
+ """Build a valid S3 URL for internal tools from uploaded file metadata."""
+ if not isinstance(file, dict):
+ return ""
+
+ object_name = str(file.get("object_name") or "").strip().lstrip("/")
+ if object_name:
+ bucket = MINIO_DEFAULT_BUCKET or "nexent"
+ return f"s3://{bucket}/{object_name}"
+
+ url = str(file.get("url") or "").strip()
+ if not url or url.startswith("blob:") or url.startswith("s3:/blob:"):
+ return ""
+
+ if url.startswith("s3://"):
+ return url
+
+ if url.startswith("s3:/"):
+ return "s3://" + url.replace("s3:/", "", 1).lstrip("/")
+
+ return "s3:/" + url
+
+
def _get_skills_for_template(
agent_id: int,
tenant_id: str,
@@ -532,10 +555,17 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
}
elif tool_config.class_name == "AnalyzeImageTool":
tool_config.metadata = {
+ # get_vlm_model reads the first multimodal slot, now shown as image understanding.
"vlm_model": get_vlm_model(tenant_id=tenant_id),
"storage_client": minio_client,
"validate_url_access": lambda urls: validate_urls_access(urls, user_id)
}
+ elif tool_config.class_name in ["AnalyzeAudioTool", "AnalyzeVideoTool"]:
+ tool_config.metadata = {
+ "vlm_model": get_video_understanding_model(tenant_id=tenant_id),
+ "storage_client": minio_client,
+ "validate_url_access": lambda urls: validate_urls_access(urls, user_id)
+ }
tool_config_list.append(tool_config)
@@ -636,10 +666,12 @@ async def join_minio_file_description_to_query(
# Collect files from current message first (higher priority)
if minio_files and isinstance(minio_files, list):
for file in minio_files:
- if isinstance(file, dict) and file.get("url") and file.get("name"):
- url = file["url"]
- if url not in seen_urls:
- seen_urls.add(url)
+ if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")):
+ s3_url = _build_internal_s3_url(file)
+ if not s3_url:
+ continue
+ if s3_url not in seen_urls:
+ seen_urls.add(s3_url)
all_files.append(file)
# Collect files from historical messages (lower priority, already-deduped)
@@ -647,10 +679,12 @@ async def join_minio_file_description_to_query(
for msg in history:
if isinstance(msg, dict) and msg.get("minio_files"):
for file in msg["minio_files"]:
- if isinstance(file, dict) and file.get("url") and file.get("name"):
- url = file["url"]
- if url not in seen_urls:
- seen_urls.add(url)
+ if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")):
+ s3_url = _build_internal_s3_url(file)
+ if not s3_url:
+ continue
+ if s3_url not in seen_urls:
+ seen_urls.add(s3_url)
all_files.append(file)
# Enforce file count limit (keep most recent files by truncating from the end)
@@ -666,7 +700,7 @@ async def join_minio_file_description_to_query(
fixed_overhead = len(prefix) + len(suffix)
for i, file in enumerate(all_files):
- s3_url = f"s3:/{file['url']}"
+ s3_url = _build_internal_s3_url(file)
presigned_url = file.get("presigned_url", "")
# Build description with both URLs
@@ -718,8 +752,10 @@ def _format_minio_files_for_content(minio_files: Optional[List[dict]], max_files
if i >= max_files:
file_lines.append(f" - ... (and {len(minio_files) - max_files} more files)")
break
- if isinstance(file, dict) and file.get("url") and file.get("name"):
- s3_url = f"s3:/{file['url']}"
+ if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")):
+ s3_url = _build_internal_s3_url(file)
+ if not s3_url:
+ continue
presigned_url = file.get("presigned_url", "")
if presigned_url:
file_lines.append(
diff --git a/backend/consts/const.py b/backend/consts/const.py
index fdc09c9e7..bbfb1ce41 100644
--- a/backend/consts/const.py
+++ b/backend/consts/const.py
@@ -308,6 +308,8 @@ class VectorDatabaseType(str, Enum):
"multiEmbedding": "MULTI_EMBEDDING_ID",
"rerank": "RERANK_ID",
"vlm": "VLM_ID",
+ "vlm2": "VLM2_ID",
+ "vlm3": "VLM3_ID",
"stt": "STT_ID",
"tts": "TTS_ID"
}
diff --git a/backend/consts/model.py b/backend/consts/model.py
index 10dc05231..f17cef5e0 100644
--- a/backend/consts/model.py
+++ b/backend/consts/model.py
@@ -177,6 +177,14 @@ class STTModelConfig(BaseModel):
accessToken: Optional[str] = None
+def _empty_model_config() -> SingleModelConfig:
+ return SingleModelConfig(
+ modelName="",
+ displayName="",
+ apiConfig=ModelApiConfig(apiKey="", modelUrl="")
+ )
+
+
class TTSModelConfig(BaseModel):
"""TTS model specific configuration with factory, appid, and access token fields"""
modelName: str
@@ -193,6 +201,8 @@ class ModelConfig(BaseModel):
multiEmbedding: SingleModelConfig
rerank: SingleModelConfig
vlm: SingleModelConfig
+ vlm2: SingleModelConfig = Field(default_factory=_empty_model_config)
+ vlm3: SingleModelConfig = Field(default_factory=_empty_model_config)
stt: STTModelConfig
tts: TTSModelConfig
diff --git a/backend/prompts/managed_system_prompt_template_en.yaml b/backend/prompts/managed_system_prompt_template_en.yaml
index 1cbe81096..5c2893c39 100644
--- a/backend/prompts/managed_system_prompt_template_en.yaml
+++ b/backend/prompts/managed_system_prompt_template_en.yaml
@@ -117,7 +117,7 @@ system_prompt: |-
→ Use **presigned_url** (already includes proxy prefix, format: `http://.../api/nb/v1/file/fetch?presigned_url=...`)
Directly use the **presigned_url** field provided in the user's uploaded file info. No need to construct or append anything.
2. **Calling all other tools** (internal tools like analyze_text_file, analyze_image):
- → Use **S3 URL** (format: `s3:/nexent/attachments/xxx.pdf`)
+ → Use **S3 URL** (format: `s3://nexent/attachments/xxx.pdf`)
Reason: Internal tools run inside Nexent and can directly access MinIO storage
{%- else %}
diff --git a/backend/prompts/manager_system_prompt_template_en.yaml b/backend/prompts/manager_system_prompt_template_en.yaml
index 464ec5264..8ce58db29 100644
--- a/backend/prompts/manager_system_prompt_template_en.yaml
+++ b/backend/prompts/manager_system_prompt_template_en.yaml
@@ -119,7 +119,7 @@ system_prompt: |-
→ Use **Download URL** (format: `https://minio.example.com/...?token=xxx`)
Reason: MCP tools run on external services and cannot access internal S3 storage
2. **Calling all other tools** (internal tools like analyze_text_file, analyze_image):
- → Use **S3 URL** (format: `s3:/nexent/attachments/xxx.pdf`)
+ → Use **S3 URL** (format: `s3://nexent/attachments/xxx.pdf`)
Reason: Internal tools run inside Nexent and can directly access MinIO storage
{%- else %}
- No tools are currently available
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 3889c9d58..12b2ebebd 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -6,6 +6,7 @@ dependencies = [
"aiofiles>=0.8.0",
"uvicorn>=0.34.0",
"fastapi>=0.115.12",
+ "email-validator>=2.0.0",
"aiohttp>=3.8.0",
"authlib>=1.3.0",
"cryptography>=42.0.0",
@@ -16,6 +17,7 @@ dependencies = [
"supabase>=2.18.1",
"websocket-client>=1.8.0",
"pyyaml>=6.0.2",
+ "jsonref>=1.1.0",
"ruamel-yaml==0.19.1",
"redis>=5.0.0",
"fastmcp==2.12.0",
diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py
index 82c4f6f0d..c90d707e1 100644
--- a/backend/services/agent_service.py
+++ b/backend/services/agent_service.py
@@ -1265,7 +1265,7 @@ async def export_agent_by_agent_id(agent_id: int, tenant_id: str, user_id: str)
# Check if any tool is KnowledgeBaseSearchTool and set its metadata to empty dict
for tool in tool_list:
- if tool.class_name in ["KnowledgeBaseSearchTool", "AnalyzeTextFileTool", "AnalyzeImageTool", "DataMateSearchTool"]:
+ if tool.class_name in ["KnowledgeBaseSearchTool", "AnalyzeTextFileTool", "AnalyzeImageTool", "AnalyzeAudioTool", "AnalyzeVideoTool", "DataMateSearchTool"]:
tool.metadata = {}
# Get model_id and model display name from agent_info
diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py
index d2ea8fea6..7feea9452 100644
--- a/backend/services/config_sync_service.py
+++ b/backend/services/config_sync_service.py
@@ -20,7 +20,7 @@
MODEL_ENGINE_ENABLED,
TENANT_NAME
)
-from database.model_management_db import get_model_id_by_display_name
+from database.model_management_db import get_model_id_by_display_name, get_model_records
from utils.config_utils import (
get_env_key,
get_model_name_from_config,
@@ -31,6 +31,20 @@
logger = logging.getLogger("config_sync_service")
+def get_model_id_for_config(model_type: str, display_name: str, tenant_id: str) -> Optional[int]:
+ if not display_name:
+ return None
+
+ records = get_model_records(
+ {"display_name": display_name, "model_type": model_type},
+ tenant_id
+ )
+ if records:
+ return records[0].get("model_id")
+
+ return get_model_id_by_display_name(display_name, tenant_id)
+
+
def handle_model_config(tenant_id: str, user_id: str, config_key: str, model_id: Optional[int], tenant_config_dict: dict) -> None:
"""
Handle model configuration updates, deletions, and settings operations
@@ -98,8 +112,8 @@ async def save_config_impl(config, tenant_id, user_id):
model_display_name = model_config.get("displayName")
config_key = get_env_key(model_type) + "_ID"
- model_id = get_model_id_by_display_name(
- model_display_name, tenant_id)
+ model_id = get_model_id_for_config(
+ model_type, model_display_name, tenant_id)
handle_model_config(tenant_id, user_id, config_key,
model_id, tenant_config_dict)
diff --git a/backend/services/image_service.py b/backend/services/image_service.py
index 8decbd541..8a924e9cc 100644
--- a/backend/services/image_service.py
+++ b/backend/services/image_service.py
@@ -31,7 +31,11 @@ async def proxy_image_impl(decoded_url: str):
def get_vlm_model(tenant_id: str):
- # Get the tenant config
+ """Return the configured image understanding model for AnalyzeImageTool.
+
+ The first multimodal model slot is still stored under MODEL_CONFIG_MAPPING["vlm"]
+ for compatibility, but it is the user-facing image understanding configuration.
+ """
vlm_model_config = tenant_config_manager.get_model_config(
key=MODEL_CONFIG_MAPPING["vlm"], tenant_id=tenant_id)
if not vlm_model_config:
@@ -48,3 +52,27 @@ def get_vlm_model(tenant_id: str):
max_tokens=512,
ssl_verify=vlm_model_config.get("ssl_verify", True),
)
+
+
+def get_image_understanding_model(tenant_id: str):
+ return get_vlm_model(tenant_id=tenant_id)
+
+
+def get_video_understanding_model(tenant_id: str):
+ """Return the configured video understanding model for multimodal tools."""
+ vlm_model_config = tenant_config_manager.get_model_config(
+ key=MODEL_CONFIG_MAPPING["vlm3"], tenant_id=tenant_id)
+ if not vlm_model_config:
+ return None
+ return OpenAIVLModel(
+ observer=MessageObserver(),
+ model_id=get_model_name_from_config(
+ vlm_model_config) if vlm_model_config else "",
+ api_base=vlm_model_config.get("base_url", ""),
+ api_key=vlm_model_config.get("api_key", ""),
+ temperature=0.7,
+ top_p=0.7,
+ frequency_penalty=0.5,
+ max_tokens=512,
+ ssl_verify=vlm_model_config.get("ssl_verify", True),
+ )
diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py
index 4324bf12f..58a0af91f 100644
--- a/backend/services/model_health_service.py
+++ b/backend/services/model_health_service.py
@@ -134,7 +134,7 @@ async def _perform_connectivity_check(
ssl_verify=ssl_verify,
)
connectivity = await rerank_model.connectivity_check()
- elif model_type == "vlm":
+ elif model_type in ("vlm", "vlm2", "vlm3"):
observer = MessageObserver()
set_monitoring_operation("connectivity_check",
display_name=display_name)
diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py
index 5cf4e381b..8f6d191fd 100644
--- a/backend/services/model_management_service.py
+++ b/backend/services/model_management_service.py
@@ -164,6 +164,13 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay
tenant_id, provider, model_type)
model_list_ids = {model.get("id")
for model in model_list} if model_list else set()
+ existing_model_map = {
+ add_repo_to_name(
+ model_repo=model["model_repo"],
+ model_name=model["model_name"],
+ ): model
+ for model in existing_model_list
+ }
# Delete existing models not present
for model in existing_model_list:
@@ -173,21 +180,20 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay
# Create or update new models
for model in model_list:
+ model["model_type"] = model_type
_, model_name = split_repo_name(
model["id"]) if model.get("id") else ("", "")
model_repo, model_name_only = split_repo_name(
model.get("id", "")) if model.get("id") else ("", "")
model_display_name = add_repo_to_name(model_repo, model_name_only)
if model_name:
- existing_model_by_display = get_model_by_display_name(
- model_display_name, tenant_id)
- if existing_model_by_display:
+ existing_model = existing_model_map.get(model_display_name)
+ if existing_model:
# Check if max_tokens has changed
- existing_max_tokens = existing_model_by_display.get(
- "max_tokens")
+ existing_max_tokens = existing_model.get("max_tokens")
new_max_tokens = model.get("max_tokens")
if new_max_tokens is not None and existing_max_tokens != new_max_tokens:
- update_model_record(existing_model_by_display["model_id"], {
+ update_model_record(existing_model["model_id"], {
"max_tokens": new_max_tokens}, user_id)
continue
diff --git a/backend/services/providers/silicon_provider.py b/backend/services/providers/silicon_provider.py
index ea41cc95d..130f2346e 100644
--- a/backend/services/providers/silicon_provider.py
+++ b/backend/services/providers/silicon_provider.py
@@ -1,4 +1,5 @@
import httpx
+import re
from typing import Dict, List
from consts.const import DEFAULT_LLM_MAX_TOKENS
@@ -6,6 +7,62 @@
from services.providers.base import AbstractModelProvider, _classify_provider_error
+SILICON_VLM_MODEL_KEYWORDS = (
+ "-vl",
+ "_vl",
+ "/vl",
+ ".vl",
+ "vl-",
+ "vision",
+ "visual",
+ "internvl",
+ "deepseek-vl",
+ "deepseekvl",
+ "glm-4v",
+ "minicpm-v",
+ "llava",
+ "kimi-vl",
+ "kimi-k2.5",
+ "kimi-k2.6",
+ "qvq",
+ "omni",
+ "qwen3.5",
+ "qwen3.6",
+)
+
+SILICON_VLM_METADATA_KEYWORDS = ("image", "video", "vision", "visual")
+
+
+def _contains_silicon_vlm_metadata(value) -> bool:
+ if isinstance(value, str):
+ lower_value = value.lower()
+ return any(keyword in lower_value for keyword in SILICON_VLM_METADATA_KEYWORDS)
+ if isinstance(value, list):
+ return any(_contains_silicon_vlm_metadata(item) for item in value)
+ if isinstance(value, dict):
+ return any(_contains_silicon_vlm_metadata(item) for item in value.values())
+ return False
+
+
+def _is_silicon_vlm_model(model: Dict) -> bool:
+ if _contains_silicon_vlm_metadata(model):
+ return True
+
+ model_id = str(model.get("id", "")).lower()
+ model_name = str(model.get("name", "")).lower()
+ searchable_text = f"{model_id} {model_name}"
+ if any(keyword in searchable_text for keyword in SILICON_VLM_MODEL_KEYWORDS):
+ return True
+
+ return bool(re.search(r"glm-\d+(?:\.\d+)?v", searchable_text))
+
+
+def _is_silicon_omni_model(model: Dict) -> bool:
+ model_id = str(model.get("id", "")).lower()
+ model_name = str(model.get("name", "")).lower()
+ return "omni" in f"{model_id} {model_name}"
+
+
class SiliconModelProvider(AbstractModelProvider):
"""Concrete implementation for SiliconFlow provider."""
@@ -25,12 +82,14 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
headers = {"Authorization": f"Bearer {model_api_key}"}
+ provider_model_type = "vlm" if model_type in ("vlm2", "vlm3") else model_type
+
# Choose endpoint by model type
- if model_type in ("llm", "vlm"):
+ if provider_model_type in ("llm", "vlm"):
silicon_url = f"{SILICON_GET_URL}?sub_type=chat"
- elif model_type in ("embedding", "multi_embedding"):
+ elif provider_model_type in ("embedding", "multi_embedding"):
silicon_url = f"{SILICON_GET_URL}?sub_type=embedding"
- elif model_type == "rerank":
+ elif provider_model_type == "rerank":
silicon_url = f"{SILICON_GET_URL}?sub_type=reranker"
else:
silicon_url = SILICON_GET_URL
@@ -40,17 +99,22 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
response.raise_for_status()
model_list: List[Dict] = response.json()["data"]
+ if model_type == "vlm3":
+ model_list = [item for item in model_list if _is_silicon_omni_model(item)]
+ elif provider_model_type == "vlm":
+ model_list = [item for item in model_list if _is_silicon_vlm_model(item)]
+
# Annotate models with canonical fields expected downstream
- if model_type in ("llm", "vlm"):
+ if provider_model_type in ("llm", "vlm"):
for item in model_list:
item["model_tag"] = "chat"
item["model_type"] = model_type
item["max_tokens"] = DEFAULT_LLM_MAX_TOKENS
- elif model_type in ("embedding", "multi_embedding"):
+ elif provider_model_type in ("embedding", "multi_embedding"):
for item in model_list:
item["model_tag"] = "embedding"
item["model_type"] = model_type
- elif model_type == "rerank":
+ elif provider_model_type == "rerank":
for item in model_list:
item["model_tag"] = "rerank"
item["model_type"] = model_type
diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py
index c30a6006f..2e868044e 100644
--- a/backend/services/tool_configuration_service.py
+++ b/backend/services/tool_configuration_service.py
@@ -38,7 +38,7 @@
from services.file_management_service import get_llm_model, validate_urls_access
from services.vectordatabase_service import get_embedding_model_by_index_name, get_rerank_model
from database.client import minio_client
-from services.image_service import get_vlm_model
+from services.image_service import get_video_understanding_model, get_vlm_model
from nexent.monitor import set_monitoring_context, set_monitoring_operation
from services.vectordatabase_service import get_vector_db_core
from utils.langchain_utils import discover_langchain_modules
@@ -779,6 +779,7 @@ def _validate_local_tool(
if not tenant_id or not user_id:
raise ToolExecutionException(
f"Tenant ID and User ID are required for {tool_name} validation")
+ # get_vlm_model reads the first multimodal slot, now shown as image understanding.
image_to_text_model = get_vlm_model(tenant_id=tenant_id)
vlm_display_name = getattr(
image_to_text_model, 'display_name', None)
@@ -792,6 +793,23 @@ def _validate_local_tool(
'validate_url_access': lambda urls: validate_urls_access(urls, user_id)
}
tool_instance = tool_class(**params)
+ elif tool_name in ["analyze_audio", "analyze_video"]:
+ if not tenant_id or not user_id:
+ raise ToolExecutionException(
+ f"Tenant ID and User ID are required for {tool_name} validation")
+ video_understanding_model = get_video_understanding_model(tenant_id=tenant_id)
+ model_display_name = getattr(
+ video_understanding_model, 'display_name', None)
+ set_monitoring_context(tenant_id=tenant_id)
+ set_monitoring_operation(
+ "tool_validation", display_name=model_display_name)
+ params = {
+ **instantiation_params,
+ 'vlm_model': video_understanding_model,
+ 'storage_client': minio_client,
+ 'validate_url_access': lambda urls: validate_urls_access(urls, user_id)
+ }
+ tool_instance = tool_class(**params)
elif tool_name == "analyze_text_file":
if not tenant_id or not user_id:
raise ToolExecutionException(
diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx
index cb1c8cd3f..993795c98 100644
--- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx
+++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx
@@ -35,11 +35,17 @@ const TOOLS_REQUIRING_EMBEDDING = [
"knowledge_base_search",
];
-// Tool types that require VLM model
-const TOOLS_REQUIRING_VLM = [
+// Tool types that require the image understanding model
+const TOOLS_REQUIRING_IMAGE_UNDERSTANDING = [
"analyze_image",
];
+// Tool types that require the video understanding model
+const TOOLS_REQUIRING_VIDEO_UNDERSTANDING = [
+ "analyze_audio",
+ "analyze_video",
+];
+
function getToolKbType(
toolName: string
): "knowledge_base_search" | "dify_search" | "datamate_search" | "idata_search" | "haotian_search" | null {
@@ -54,9 +60,18 @@ function getToolKbType(
/**
* Check if a tool requires VLM model but VLM is not available
*/
-function isToolDisabledDueToVlm(toolName: string, vlmAvailable: boolean): boolean {
- if (!TOOLS_REQUIRING_VLM.includes(toolName)) return false;
- return !vlmAvailable;
+function isToolDisabledDueToVlm(
+ toolName: string,
+ imageUnderstandingAvailable: boolean,
+ videoUnderstandingAvailable: boolean
+): boolean {
+ if (TOOLS_REQUIRING_IMAGE_UNDERSTANDING.includes(toolName)) {
+ return !imageUnderstandingAvailable;
+ }
+ if (TOOLS_REQUIRING_VIDEO_UNDERSTANDING.includes(toolName)) {
+ return !videoUnderstandingAvailable;
+ }
+ return false;
}
/**
@@ -98,7 +113,11 @@ export default function ToolManagement({
// Use tool list hook for data management
const { availableTools } = useToolList();
- const { isVlmAvailable, isEmbeddingAvailable } = useConfig();
+ const {
+ isImageUnderstandingAvailable,
+ isVideoUnderstandingAvailable,
+ isEmbeddingAvailable,
+ } = useConfig();
// Prefetch knowledge bases for KB tools
const { prefetchKnowledgeBases } = usePrefetchKnowledgeBases();
@@ -358,7 +377,11 @@ export default function ToolManagement({
const isSelected = originalSelectedToolIdsSet.has(
tool.id
);
- const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable);
+ const isDisabledDueToVlm = isToolDisabledDueToVlm(
+ tool.name,
+ isImageUnderstandingAvailable,
+ isVideoUnderstandingAvailable
+ );
const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable);
const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly;
// Tooltip priority: permission > VLM > Embedding
@@ -463,7 +486,11 @@ export default function ToolManagement({
>
{group.tools.map((tool) => {
const isSelected = originalSelectedToolIdsSet.has(tool.id);
- const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable);
+ const isDisabledDueToVlm = isToolDisabledDueToVlm(
+ tool.name,
+ isImageUnderstandingAvailable,
+ isVideoUnderstandingAvailable
+ );
const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable);
const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly;
// Tooltip priority: permission > VLM > Embedding
diff --git a/frontend/app/[locale]/chat/components/chatAttachment.tsx b/frontend/app/[locale]/chat/components/chatAttachment.tsx
index 5c9da8ec9..d12e939cd 100644
--- a/frontend/app/[locale]/chat/components/chatAttachment.tsx
+++ b/frontend/app/[locale]/chat/components/chatAttachment.tsx
@@ -87,6 +87,14 @@ const getFileIcon = (name: string, contentType?: string) => {
return ;
}
+ // Audio and video files are uploaded as regular attachments for multimodal tools.
+ if (chatConfig.fileIcons.audio.includes(extension) || fileType.startsWith("audio/")) {
+ return ;
+ }
+ if (chatConfig.fileIcons.video.includes(extension) || fileType.startsWith("video/")) {
+ return ;
+ }
+
// Compressed file
if (chatConfig.fileIcons.compressed.includes(extension)) {
return ;
@@ -230,4 +238,4 @@ export function ChatAttachment({
)}
);
-}
\ No newline at end of file
+}
diff --git a/frontend/app/[locale]/chat/components/chatInput.tsx b/frontend/app/[locale]/chat/components/chatInput.tsx
index 8de0d17eb..bcfc86f6b 100644
--- a/frontend/app/[locale]/chat/components/chatInput.tsx
+++ b/frontend/app/[locale]/chat/components/chatInput.tsx
@@ -96,10 +96,24 @@ const getFileIcon = (file: File) => {
return ;
}
+ if (chatConfig.fileIcons.audio.includes(extension) || fileType.startsWith("audio/")) {
+ return ;
+ }
+
+ if (chatConfig.fileIcons.video.includes(extension) || fileType.startsWith("video/")) {
+ return ;
+ }
+
// Default file icon
return ;
};
+const isSupportedMediaFile = (extension: string, fileType: string) =>
+ fileType.startsWith("audio/") ||
+ fileType.startsWith("video/") ||
+ chatConfig.audioExtensions.includes(extension) ||
+ chatConfig.videoExtensions.includes(extension);
+
// File limit constants from config
const MAX_FILE_COUNT = chatConfig.maxFileCount;
const MAX_FILE_SIZE = chatConfig.maxFileSize;
@@ -617,8 +631,9 @@ export function ChatInput({
chatConfig.supportedTextExtensions.includes(extension) ||
file.type === "text/csv" ||
file.type === "text/plain";
+ const isMedia = isSupportedMediaFile(extension, file.type);
- if (isImage || isDocument || isSupportedTextFile) {
+ if (isImage || isDocument || isSupportedTextFile || isMedia) {
// Create a preview URL for images
const previewUrl = isImage ? URL.createObjectURL(file) : undefined;
@@ -899,7 +914,7 @@ export function ChatInput({
id="file-upload-regular"
className="hidden"
onChange={handleFileUpload}
- accept={`image/*,${Object.values(chatConfig.fileIcons).flat().map(ext => `.${ext}`).join(',')}`}
+ accept={`image/*,audio/*,video/*,${Object.values(chatConfig.fileIcons).flat().map(ext => `.${ext}`).join(',')}`}
multiple
/>
@@ -1026,8 +1041,9 @@ export function ChatInput({
chatConfig.supportedTextExtensions.includes(extension) ||
fileType === "text/csv" ||
fileType === "text/plain";
+ const isMedia = isSupportedMediaFile(extension, fileType);
- return !(isImage || isDocument || isSupportedTextFile);
+ return !(isImage || isDocument || isSupportedTextFile || isMedia);
});
// Regular mode, keep the original rendering logic
diff --git a/frontend/app/[locale]/chat/internal/chatInterface.tsx b/frontend/app/[locale]/chat/internal/chatInterface.tsx
index 6e0de48b5..0f3c99715 100644
--- a/frontend/app/[locale]/chat/internal/chatInterface.tsx
+++ b/frontend/app/[locale]/chat/internal/chatInterface.tsx
@@ -38,7 +38,7 @@ import {
extractAssistantMsgFromResponse,
} from "@/lib/chatMessageExtractor";
-import { Layout } from "antd";
+import { Layout, message } from "antd";
import log from "@/lib/logger";
const stepIdCounter = { current: 0 };
@@ -268,9 +268,23 @@ export function ChatInterface() {
// Use preprocessing function to upload attachments
const uploadResult = await uploadAttachments(attachments, t);
+ if (uploadResult.error) {
+ message.error(`${t("chatPreprocess.fileUploadFailed")} ${uploadResult.error}`);
+ setIsLoading(false);
+ return;
+ }
uploadedFileUrls = uploadResult.uploadedFileUrls;
objectNames = uploadResult.objectNames; // Get object name mapping
presignedUrls = uploadResult.presignedUrls; // Get presigned URLs for external access
+
+ const missingUploads = attachments.filter(
+ (attachment) => !uploadedFileUrls[attachment.file.name] || !objectNames[attachment.file.name]
+ );
+ if (missingUploads.length > 0) {
+ message.error(`${t("chatPreprocess.fileUploadFailed")} ${missingUploads.map((item) => item.file.name).join(", ")}`);
+ setIsLoading(false);
+ return;
+ }
}
// Use preprocessing function to create message attachments
diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx
index 3bff1101a..acfedb120 100644
--- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx
+++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx
@@ -68,6 +68,23 @@ const DEFAULT_FORM_STATE = {
ttsProvider: "dashscope", // ali or volcengine
};
+const resolveConnectivityModelType = (type: ModelType): ModelType =>
+ type === MODEL_TYPES.VLM2 || type === MODEL_TYPES.VLM3
+ ? (MODEL_TYPES.VLM as ModelType)
+ : type;
+
+const resolveConfigKey = (type: ModelType): string =>
+ type;
+
+const isVlmConfigType = (type: ModelType): boolean =>
+ type === MODEL_TYPES.VLM || type === MODEL_TYPES.VLM2 || type === MODEL_TYPES.VLM3;
+
+const emptyModelConfig = {
+ modelName: "",
+ displayName: "",
+ apiConfig: { apiKey: "", modelUrl: "" },
+};
+
// Connectivity status type comes from utils
// Helper function to translate error messages from backend
@@ -198,7 +215,7 @@ export const ModelAddDialog = ({
}: ModelAddDialogProps) => {
const { t } = useTranslation();
const { message } = App.useApp();
- const { updateModelConfig, saveConfig } = useConfig();
+ const { modelConfig: currentModelConfig, updateModelConfig, saveConfig } = useConfig();
// Parse backend error message and return i18n key with params
const parseModelError = (
@@ -490,7 +507,7 @@ export const ModelAddDialog = ({
const modelType =
form.type === MODEL_TYPES.EMBEDDING && form.isMultimodal
? (MODEL_TYPES.MULTI_EMBEDDING as ModelType)
- : form.type;
+ : resolveConnectivityModelType(form.type);
let connectivity = false;
@@ -653,6 +670,32 @@ export const ModelAddDialog = ({
});
}
+ if (isVlmConfigType(form.type) && enabledModels.length > 0) {
+ const selectedModel = enabledModels[0];
+ const selectedDisplayName = selectedModel.displayName || selectedModel.id || "";
+ const configKey = resolveConfigKey(form.type);
+ const vlmConfigUpdate: any = {
+ [configKey]: {
+ modelName: selectedModel.id || selectedModel.model_name || "",
+ displayName: selectedDisplayName,
+ apiConfig: {
+ apiKey: form.apiKey,
+ modelUrl: "",
+ },
+ },
+ };
+ for (const key of [MODEL_TYPES.VLM, MODEL_TYPES.VLM2, MODEL_TYPES.VLM3]) {
+ if (
+ key !== configKey &&
+ currentModelConfig?.[key]?.displayName === selectedDisplayName
+ ) {
+ vlmConfigUpdate[key] = emptyModelConfig;
+ }
+ }
+ updateModelConfig(vlmConfigUpdate);
+ await persistModelConfig();
+ }
+
// Reset form state and close dialog on success
resetForm();
handleClose();
@@ -841,6 +884,7 @@ export const ModelAddDialog = ({
// Update the local storage according to the model type
let configUpdate: any = {};
+ const configKey = resolveConfigKey(form.type);
switch (modelType) {
case MODEL_TYPES.LLM:
@@ -853,7 +897,17 @@ export const ModelAddDialog = ({
configUpdate = { multiEmbedding: modelConfig };
break;
case MODEL_TYPES.VLM:
- configUpdate = { vlm: modelConfig };
+ case MODEL_TYPES.VLM2:
+ case MODEL_TYPES.VLM3:
+ configUpdate = { [configKey]: modelConfig };
+ for (const key of [MODEL_TYPES.VLM, MODEL_TYPES.VLM2, MODEL_TYPES.VLM3]) {
+ if (
+ key !== configKey &&
+ currentModelConfig?.[key]?.displayName === modelConfig.displayName
+ ) {
+ configUpdate[key] = emptyModelConfig;
+ }
+ }
break;
case MODEL_TYPES.RERANK:
configUpdate = { rerank: modelConfig };
@@ -996,7 +1050,15 @@ export const ModelAddDialog = ({
-
+
+
+
diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx
index 6ba9cf0c3..55c31c2c0 100644
--- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx
+++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx
@@ -99,6 +99,8 @@ export const ModelDeleteDialog = ({
border: "border-purple-100",
};
case MODEL_TYPES.VLM:
+ case MODEL_TYPES.VLM2:
+ case MODEL_TYPES.VLM3:
return {
bg: "bg-yellow-50",
text: "text-yellow-600",
@@ -141,6 +143,8 @@ export const ModelDeleteDialog = ({
case MODEL_TYPES.TTS:
return "🔊";
case MODEL_TYPES.VLM:
+ case MODEL_TYPES.VLM2:
+ case MODEL_TYPES.VLM3:
return "👁️";
default:
return "⚙️";
@@ -165,6 +169,10 @@ export const ModelDeleteDialog = ({
return t("model.type.tts");
case MODEL_TYPES.VLM:
return t("model.type.vlm");
+ case MODEL_TYPES.VLM2:
+ return `${t("model.type.vlm")}2`;
+ case MODEL_TYPES.VLM3:
+ return `${t("model.type.vlm")}3`;
default:
return t("model.type.unknown");
}
@@ -344,7 +352,10 @@ export const ModelDeleteDialog = ({
if (cfgUrl && cfgUrl.trim() !== "") return cfgUrl;
}
if (type === MODEL_TYPES.VLM) {
- const cfgUrl = modelConfig?.vlm?.apiConfig?.modelUrl;
+ const cfgUrl =
+ modelConfig?.vlm?.apiConfig?.modelUrl ||
+ modelConfig?.vlm2?.apiConfig?.modelUrl ||
+ modelConfig?.vlm3?.apiConfig?.modelUrl;
if (cfgUrl && cfgUrl.trim() !== "") return cfgUrl;
}
if (type === MODEL_TYPES.LLM) {
@@ -501,6 +512,22 @@ export const ModelDeleteDialog = ({
};
}
+ if (modelConfig.vlm2?.displayName === displayName) {
+ configUpdates.vlm2 = {
+ modelName: "",
+ displayName: "",
+ apiConfig: { apiKey: "", modelUrl: "" },
+ };
+ }
+
+ if (modelConfig.vlm3?.displayName === displayName) {
+ configUpdates.vlm3 = {
+ modelName: "",
+ displayName: "",
+ apiConfig: { apiKey: "", modelUrl: "" },
+ };
+ }
+
if (modelConfig.stt.displayName === displayName) {
configUpdates.stt = { modelName: "", displayName: "" };
}
@@ -1017,6 +1044,8 @@ export const ModelDeleteDialog = ({
MODEL_TYPES.MULTI_EMBEDDING,
MODEL_TYPES.RERANK,
MODEL_TYPES.VLM,
+ MODEL_TYPES.VLM2,
+ MODEL_TYPES.VLM3,
MODEL_TYPES.STT,
MODEL_TYPES.TTS,
] as ModelType[]
diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx
index 4b44a6361..c44251f15 100644
--- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx
+++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx
@@ -110,6 +110,10 @@ export const ModelEditDialog = ({
form.type === MODEL_TYPES.EMBEDDING ||
form.type === MODEL_TYPES.MULTI_EMBEDDING;
const isRerankModel = form.type === MODEL_TYPES.RERANK;
+ const connectivityModelType =
+ form.type === MODEL_TYPES.VLM2 || form.type === MODEL_TYPES.VLM3
+ ? (MODEL_TYPES.VLM as ModelType)
+ : form.type;
const isVoiceModel =
form.type === MODEL_TYPES.STT || form.type === MODEL_TYPES.TTS;
@@ -141,13 +145,9 @@ export const ModelEditDialog = ({
});
try {
- const modelType = form.type as ModelType;
- const isVoiceModel =
- modelType === MODEL_TYPES.STT || modelType === MODEL_TYPES.TTS;
-
const config: any = {
modelName: form.name,
- modelType: modelType,
+ modelType: connectivityModelType,
baseUrl: form.url,
apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey,
maxTokens:
@@ -271,6 +271,8 @@ export const ModelEditDialog = ({
embedding: MODEL_TYPES.EMBEDDING,
multi_embedding: MODEL_TYPES.MULTI_EMBEDDING,
vlm: MODEL_TYPES.VLM,
+ vlm2: MODEL_TYPES.VLM2,
+ vlm3: MODEL_TYPES.VLM3,
rerank: MODEL_TYPES.RERANK,
tts: MODEL_TYPES.TTS,
stt: MODEL_TYPES.STT,
diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx
index 07eee5c06..36fcdbb31 100644
--- a/frontend/app/[locale]/models/components/modelConfig.tsx
+++ b/frontend/app/[locale]/models/components/modelConfig.tsx
@@ -56,7 +56,11 @@ const getModelData = (t: any) => ({
},
multimodal: {
title: t("modelConfig.category.multimodal"),
- options: [{ id: MODEL_TYPES.VLM, name: t("modelConfig.option.vlmModel") }],
+ options: [
+ { id: MODEL_TYPES.VLM, name: t("modelConfig.option.imageUnderstandingModel") },
+ { id: MODEL_TYPES.VLM2, name: t("modelConfig.option.imageGenerationModel") },
+ { id: MODEL_TYPES.VLM3, name: t("modelConfig.option.videoUnderstandingModel") },
+ ],
},
voice: {
title: t("modelConfig.category.voice"),
@@ -142,7 +146,7 @@ export const ModelConfigSection = forwardRef<
llm: { main: "" },
embedding: { embedding: "", multi_embedding: "" },
reranker: { reranker: "" },
- multimodal: { vlm: "" },
+ multimodal: { vlm: "", vlm2: "", vlm3: "" },
voice: { tts: "", stt: "" },
});
@@ -284,11 +288,23 @@ export const ModelConfigSection = forwardRef<
: true;
const vlm = modelConfig.vlm.displayName;
+ const vlm2 = modelConfig.vlm2?.displayName || "";
+ const vlm3 = modelConfig.vlm3?.displayName || "";
const vlmExists = vlm
? allModels.some(
(m) => m.displayName === vlm && m.type === MODEL_TYPES.VLM
)
: true;
+ const vlm2Exists = vlm2
+ ? allModels.some(
+ (m) => m.displayName === vlm2 && m.type === MODEL_TYPES.VLM2
+ )
+ : true;
+ const vlm3Exists = vlm3
+ ? allModels.some(
+ (m) => m.displayName === vlm3 && m.type === MODEL_TYPES.VLM3
+ )
+ : true;
const stt = modelConfig.stt.displayName;
const sttExists = stt
@@ -318,6 +334,8 @@ export const ModelConfigSection = forwardRef<
},
multimodal: {
vlm: vlmExists ? vlm : "",
+ vlm2: vlm2Exists ? vlm2 : "",
+ vlm3: vlm3Exists ? vlm3 : "",
},
voice: {
tts: ttsExists ? tts : "",
@@ -363,6 +381,14 @@ export const ModelConfigSection = forwardRef<
configUpdates.vlm = { modelName: "", displayName: "" };
}
+ if (!vlm2Exists && vlm2) {
+ configUpdates.vlm2 = { modelName: "", displayName: "" };
+ }
+
+ if (!vlm3Exists && vlm3) {
+ configUpdates.vlm3 = { modelName: "", displayName: "" };
+ }
+
if (!sttExists && stt) {
configUpdates.stt = { modelName: "", displayName: "" };
}
@@ -385,6 +411,8 @@ export const ModelConfigSection = forwardRef<
!!modelConfig.multiEmbedding.modelName ||
!!modelConfig.rerank.modelName ||
!!modelConfig.vlm.modelName ||
+ !!modelConfig.vlm2?.modelName ||
+ !!modelConfig.vlm3?.modelName ||
!!modelConfig.tts.modelName ||
!!modelConfig.stt.modelName;
@@ -441,11 +469,13 @@ export const ModelConfigSection = forwardRef<
const hasEmbedding = !!modelConfig.embedding.modelName;
const hasReranker = !!modelConfig.rerank.modelName;
const hasVlm = !!modelConfig.vlm.modelName;
+ const hasVlm2 = !!modelConfig.vlm2?.modelName;
+ const hasVlm3 = !!modelConfig.vlm3?.modelName;
const hasTts = !!modelConfig.tts.modelName;
const hasStt = !!modelConfig.stt.modelName;
hasSelectedModels =
- hasLlmMain || hasEmbedding || hasReranker || hasVlm || hasTts || hasStt;
+ hasLlmMain || hasEmbedding || hasReranker || hasVlm || hasVlm2 || hasVlm3 || hasTts || hasStt;
if (hasSelectedModels) {
currentSelectedModels.llm.main = modelConfig.llm.modelName;
@@ -455,6 +485,8 @@ export const ModelConfigSection = forwardRef<
modelConfig.multiEmbedding.modelName || "";
currentSelectedModels.reranker.reranker = modelConfig.rerank.modelName;
currentSelectedModels.multimodal.vlm = modelConfig.vlm.modelName;
+ currentSelectedModels.multimodal.vlm2 = modelConfig.vlm2?.modelName || "";
+ currentSelectedModels.multimodal.vlm3 = modelConfig.vlm3?.modelName || "";
currentSelectedModels.voice.tts = modelConfig.tts.modelName;
currentSelectedModels.voice.stt = modelConfig.stt.modelName;
} else {
@@ -492,7 +524,7 @@ export const ModelConfigSection = forwardRef<
} else if (category === "reranker") {
modelType = MODEL_TYPES.RERANK;
} else if (category === "multimodal") {
- modelType = MODEL_TYPES.VLM;
+ modelType = optionId as ModelType;
} else if (category === MODEL_TYPES.EMBEDDING) {
modelType =
optionId === MODEL_TYPES.MULTI_EMBEDDING
@@ -654,7 +686,7 @@ export const ModelConfigSection = forwardRef<
} else if (category === "reranker") {
modelType = MODEL_TYPES.RERANK;
} else if (category === "multimodal") {
- modelType = MODEL_TYPES.VLM;
+ modelType = option as ModelType;
} else if (category === MODEL_TYPES.EMBEDDING) {
modelType =
option === MODEL_TYPES.MULTI_EMBEDDING
@@ -679,7 +711,7 @@ export const ModelConfigSection = forwardRef<
) {
configKey = "multiEmbedding";
} else if (category === "multimodal") {
- configKey = MODEL_TYPES.VLM;
+ configKey = option;
} else if (category === "reranker") {
configKey = MODEL_TYPES.RERANK;
} else if (category === "voice" && option === "tts") {
@@ -1005,7 +1037,7 @@ export const ModelConfigSection = forwardRef<
? MODEL_TYPES.TTS
: MODEL_TYPES.STT
: key === "multimodal"
- ? MODEL_TYPES.VLM
+ ? (option.id as ModelType)
: key === MODEL_TYPES.EMBEDDING &&
option.id === MODEL_TYPES.MULTI_EMBEDDING
? MODEL_TYPES.MULTI_EMBEDDING
diff --git a/frontend/const/chatConfig.ts b/frontend/const/chatConfig.ts
index fc0dbe6d5..27b3b887d 100644
--- a/frontend/const/chatConfig.ts
+++ b/frontend/const/chatConfig.ts
@@ -38,6 +38,12 @@ export const chatConfig = {
// Supported document file extensions
documentExtensions: ["pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "epub", "html", "xml"],
+
+ // Supported audio file extensions
+ audioExtensions: ["mp3", "wav", "m4a", "aac", "ogg", "oga", "flac", "webm"],
+
+ // Supported video file extensions
+ videoExtensions: ["mp4", "mov", "m4v", "avi", "mkv", "webm", "wmv", "flv"],
// Supported text document extensions
supportedTextExtensions: ["md", "markdown", "txt", "csv", "json"],
@@ -73,6 +79,12 @@ export const chatConfig = {
// Compressed file
compressed: ["zip", "rar", "7z", "tar", "gz"],
+
+ // Audio files
+ audio: ["mp3", "wav", "m4a", "aac", "ogg", "oga", "flac", "webm"],
+
+ // Video files
+ video: ["mp4", "mov", "m4v", "avi", "mkv", "wmv", "flv"],
},
// File preview type constants
@@ -148,4 +160,4 @@ export const MESSAGE_ROLES = {
USER: "user" as const,
ASSISTANT: "assistant" as const,
SYSTEM: "system" as const,
-} as const;
\ No newline at end of file
+} as const;
diff --git a/frontend/const/modelConfig.ts b/frontend/const/modelConfig.ts
index 9bdc5a4a8..b7762ace0 100644
--- a/frontend/const/modelConfig.ts
+++ b/frontend/const/modelConfig.ts
@@ -7,6 +7,8 @@ export const MODEL_TYPES = {
STT: "stt",
TTS: "tts",
VLM: "vlm",
+ VLM2: "vlm2",
+ VLM3: "vlm3",
} as const;
// Model source constants
diff --git a/frontend/hooks/model/useDashscopeModelList.ts b/frontend/hooks/model/useDashscopeModelList.ts
index b44348fe5..5d1035e8a 100644
--- a/frontend/hooks/model/useDashscopeModelList.ts
+++ b/frontend/hooks/model/useDashscopeModelList.ts
@@ -39,7 +39,9 @@ export const useDashscopeModelList = ({
const modelType =
form.type === "embedding" && form.isMultimodal
? ("multi_embedding" as ModelType)
- : form.type;
+ : form.type === "vlm2" || form.type === "vlm3"
+ ? ("vlm" as ModelType)
+ : form.type;
try {
// Use manage interface if tenantId is provided (for super admin)
diff --git a/frontend/hooks/model/useTokenponyModelList.ts b/frontend/hooks/model/useTokenponyModelList.ts
index 0a7e23581..0c502a404 100644
--- a/frontend/hooks/model/useTokenponyModelList.ts
+++ b/frontend/hooks/model/useTokenponyModelList.ts
@@ -39,7 +39,9 @@ export const useTokenPonyModelList = ({
const modelType =
form.type === "embedding" && form.isMultimodal
? ("multi_embedding" as ModelType)
- : form.type;
+ : form.type === "vlm2" || form.type === "vlm3"
+ ? ("vlm" as ModelType)
+ : form.type;
try {
// Use manage interface if tenantId is provided (for super admin)
diff --git a/frontend/hooks/useConfig.ts b/frontend/hooks/useConfig.ts
index be4e463f0..94d4d57db 100644
--- a/frontend/hooks/useConfig.ts
+++ b/frontend/hooks/useConfig.ts
@@ -81,6 +81,22 @@ const defaultConfig: GlobalConfig = {
modelUrl: "",
},
},
+ vlm2: {
+ modelName: "",
+ displayName: "",
+ apiConfig: {
+ apiKey: "",
+ modelUrl: "",
+ },
+ },
+ vlm3: {
+ modelName: "",
+ displayName: "",
+ apiConfig: {
+ apiKey: "",
+ modelUrl: "",
+ },
+ },
stt: {
id: 0,
modelName: "",
@@ -173,6 +189,8 @@ function transformBackendToFrontend(backendConfig: any): GlobalConfig {
),
rerank: transformModelEntry(backendConfig.models.rerank),
vlm: transformModelEntry(backendConfig.models.vlm),
+ vlm2: transformModelEntry(backendConfig.models.vlm2),
+ vlm3: transformModelEntry(backendConfig.models.vlm3),
stt: transformVoiceModelEntry(backendConfig.models.stt),
tts: transformVoiceModelEntry(backendConfig.models.tts),
}
@@ -207,7 +225,10 @@ function loadConfigFromStorage(): GlobalConfig | null {
if (storedModelConfig) {
try {
- mergedConfig.models = JSON.parse(storedModelConfig);
+ mergedConfig.models = deepMerge(
+ mergedConfig.models,
+ JSON.parse(storedModelConfig)
+ );
} catch (error) {
log.error("Failed to parse model config:", error);
}
@@ -297,7 +318,24 @@ export function useConfig() {
const config: GlobalConfig = (query.data as GlobalConfig | undefined) ?? defaultConfig;
// Whether config has selected a VLM model
- const isVlmAvailable = !!(config?.models?.vlm?.modelName || config?.models?.vlm?.displayName);
+ const isVlmAvailable = !!(
+ config?.models?.vlm?.modelName ||
+ config?.models?.vlm?.displayName ||
+ config?.models?.vlm2?.modelName ||
+ config?.models?.vlm2?.displayName ||
+ config?.models?.vlm3?.modelName ||
+ config?.models?.vlm3?.displayName
+ );
+
+ const isImageUnderstandingAvailable = !!(
+ config?.models?.vlm?.modelName ||
+ config?.models?.vlm?.displayName
+ );
+
+ const isVideoUnderstandingAvailable = !!(
+ config?.models?.vlm3?.modelName ||
+ config?.models?.vlm3?.displayName
+ );
// Whether config has selected an Embedding model
const isEmbeddingAvailable = !!(config?.models?.embedding?.modelName || config?.models?.embedding?.displayName);
@@ -383,6 +421,8 @@ export function useConfig() {
appConfig: config?.app,
modelConfig: config?.models,
isVlmAvailable,
+ isImageUnderstandingAvailable,
+ isVideoUnderstandingAvailable,
isEmbeddingAvailable,
defaultLlmModelName,
defaultLlmModelConfig,
diff --git a/frontend/lib/chat/chatAttachmentUtils.ts b/frontend/lib/chat/chatAttachmentUtils.ts
index c85615b4e..bff686ca1 100644
--- a/frontend/lib/chat/chatAttachmentUtils.ts
+++ b/frontend/lib/chat/chatAttachmentUtils.ts
@@ -69,6 +69,19 @@ export const uploadAttachments = async (
});
}
+ const failedResults = uploadResult.results.filter((result) => !result.success);
+ if (failedResults.length > 0 || uploadResult.success_count < attachments.length) {
+ const failedMessage = failedResults
+ .map((result) => `${result.file_name || "file"}: ${result.error || "Upload failed"}`)
+ .join("; ");
+ return {
+ uploadedFileUrls,
+ objectNames,
+ presignedUrls,
+ error: failedMessage || "Upload failed",
+ };
+ }
+
return { uploadedFileUrls, objectNames, presignedUrls };
} catch (error) {
log.error(t("chatPreprocess.fileUploadFailed"), error);
diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json
index 1adf9d6e3..b008cf70b 100644
--- a/frontend/public/locales/en/common.json
+++ b/frontend/public/locales/en/common.json
@@ -1,4 +1,4 @@
-{
+{
"assistant.name": "Nexent",
"mainPage.layout.title": "Nexent | AI Agents",
@@ -830,6 +830,9 @@
"model.type.llm": "Large Language Model",
"model.type.embedding": "Embedding Model",
"model.type.vlm": "Vision Language Model",
+ "model.type.imageUnderstanding": "Image Understanding Model",
+ "model.type.imageGeneration": "Image Generation Model",
+ "model.type.videoUnderstanding": "Video Understanding Model",
"model.type.rerank": "Rerank Model",
"model.type.stt": "Speech-to-Text Model",
"model.type.tts": "Text-to-Speech Model",
@@ -898,6 +901,9 @@
"modelConfig.option.multiEmbeddingModel": "Multimodal Embedding Model",
"modelConfig.option.rerankerModel": "Reranker Model",
"modelConfig.option.vlmModel": "Vision Language Model",
+ "modelConfig.option.imageUnderstandingModel": "Image Understanding Model",
+ "modelConfig.option.imageGenerationModel": "Image Generation Model",
+ "modelConfig.option.videoUnderstandingModel": "Video Understanding Model",
"modelConfig.option.ttsModel": "Text-to-Speech Model",
"modelConfig.option.sttModel": "Speech-to-Text Model",
"modelConfig.error.loadList": "Failed to load model list:",
diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json
index e37bd8936..b258edf20 100644
--- a/frontend/public/locales/zh/common.json
+++ b/frontend/public/locales/zh/common.json
@@ -1,4 +1,4 @@
-{
+{
"assistant.name": "Nexent",
"mainPage.layout.title": "Nexent | 智能问答",
@@ -820,6 +820,9 @@
"model.type.llm": "大语言模型",
"model.type.embedding": "向量模型",
"model.type.vlm": "视觉语言模型",
+ "model.type.imageUnderstanding": "图片理解模型",
+ "model.type.imageGeneration": "图片生成模型",
+ "model.type.videoUnderstanding": "视频理解模型",
"model.type.rerank": "重排模型",
"model.type.stt": "语音识别模型",
"model.type.tts": "语音合成模型",
@@ -889,6 +892,9 @@
"modelConfig.option.multiEmbeddingModel": "多模态向量模型",
"modelConfig.option.rerankerModel": "重排模型",
"modelConfig.option.vlmModel": "视觉语言模型",
+ "modelConfig.option.imageUnderstandingModel": "图片理解模型",
+ "modelConfig.option.imageGenerationModel": "图片生成模型",
+ "modelConfig.option.videoUnderstandingModel": "视频理解模型",
"modelConfig.option.ttsModel": "语音合成模型",
"modelConfig.option.sttModel": "语音识别模型",
"modelConfig.error.loadList": "加载模型列表失败:",
diff --git a/frontend/tailwind.config.ts b/frontend/tailwind.config.js
similarity index 100%
rename from frontend/tailwind.config.ts
rename to frontend/tailwind.config.js
diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts
index 21b7ef1c5..8f4789f6b 100644
--- a/frontend/types/modelConfig.ts
+++ b/frontend/types/modelConfig.ts
@@ -31,6 +31,8 @@ export type ModelType =
| "stt"
| "tts"
| "vlm"
+ | "vlm2"
+ | "vlm3"
| "multi_embedding";
// Model option interface
@@ -89,7 +91,7 @@ export interface TTSModelConfig extends SingleModelConfig {
// Single model configuration interface
export interface SingleModelConfig {
- id: number;
+ id?: number;
modelName: string;
displayName: string;
apiConfig: ModelApiConfig;
@@ -103,6 +105,8 @@ export interface ModelConfig {
multiEmbedding: SingleModelConfig;
rerank: SingleModelConfig;
vlm: SingleModelConfig;
+ vlm2: SingleModelConfig;
+ vlm3: SingleModelConfig;
stt: STTModelConfig;
tts: TTSModelConfig;
}
diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py
index bce103bb9..2ccf1d72a 100644
--- a/sdk/nexent/core/agents/nexent_agent.py
+++ b/sdk/nexent/core/agents/nexent_agent.py
@@ -241,7 +241,7 @@ def create_local_tool(self, tool_config: ToolConfig):
data_process_service_url=tool_config.metadata.get("data_process_service_url", []),
validate_url_access=validate_url_access,
**params)
- elif class_name == "AnalyzeImageTool":
+ elif class_name in ["AnalyzeImageTool", "AnalyzeAudioTool", "AnalyzeVideoTool"]:
# Extract validate_url_access from metadata if it's callable
validate_url_access = tool_config.metadata.get("validate_url_access") if tool_config.metadata else None
if validate_url_access is not None and not callable(validate_url_access):
diff --git a/sdk/nexent/core/models/openai_vlm.py b/sdk/nexent/core/models/openai_vlm.py
index 1babb0057..cbc7388d6 100644
--- a/sdk/nexent/core/models/openai_vlm.py
+++ b/sdk/nexent/core/models/openai_vlm.py
@@ -126,6 +126,47 @@ def prepare_image_message(self, image_input: Union[str, BinaryIO], system_prompt
return messages
+ def prepare_media_message(
+ self,
+ media_input: Union[str, BinaryIO],
+ media_type: str,
+ content_type: str,
+ system_prompt: str) -> List[Dict[str, Any]]:
+ """
+ Prepare an OpenAI-compatible multimodal message for audio or video inputs.
+
+ Args:
+ media_input: Media file path or file stream object.
+ media_type: Either "audio" or "video".
+ content_type: MIME type for the data URL.
+ system_prompt: System prompt.
+
+ Returns:
+ List[Dict[str, Any]]: Prepared message list.
+ """
+ if media_type not in ("audio", "video"):
+ raise ValueError(f"Unsupported media type: {media_type}")
+
+ base64_media = self.encode_image(media_input)
+ media_url_key = f"{media_type}_url"
+ media_config: Dict[str, Any] = {"url": f"data:{content_type};base64,{base64_media}"}
+ if media_type == "video":
+ media_config.update({"detail": "high", "max_frames": 16, "fps": 1})
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": media_url_key,
+ media_url_key: media_config
+ },
+ {"type": "text", "text": system_prompt}
+ ]
+ }
+ ]
+ return messages
+
def analyze_image(self, image_input: Union[str, BinaryIO],
system_prompt: str = "Please describe this picture concisely and carefully, within 200 words.", stream: bool = True,
**kwargs) -> ChatMessage:
@@ -144,3 +185,23 @@ def analyze_image(self, image_input: Union[str, BinaryIO],
messages = self.prepare_image_message(image_input, system_prompt)
# Call __call__ explicitly so instance-level mocks work in tests.
return self.__call__(messages=messages, **kwargs)
+
+ def analyze_audio(
+ self,
+ audio_input: Union[str, BinaryIO],
+ system_prompt: str = "Please analyze this audio carefully.",
+ content_type: str = "audio/mpeg",
+ **kwargs) -> ChatMessage:
+ """Analyze audio content using the configured multimodal model."""
+ messages = self.prepare_media_message(audio_input, "audio", content_type, system_prompt)
+ return self.__call__(messages=messages, **kwargs)
+
+ def analyze_video(
+ self,
+ video_input: Union[str, BinaryIO],
+ system_prompt: str = "Please analyze this video carefully.",
+ content_type: str = "video/mp4",
+ **kwargs) -> ChatMessage:
+ """Analyze video content using the configured multimodal model."""
+ messages = self.prepare_media_message(video_input, "video", content_type, system_prompt)
+ return self.__call__(messages=messages, **kwargs)
diff --git a/sdk/nexent/core/prompts/analyze_audio_en.yaml b/sdk/nexent/core/prompts/analyze_audio_en.yaml
new file mode 100644
index 000000000..eee0bb060
--- /dev/null
+++ b/sdk/nexent/core/prompts/analyze_audio_en.yaml
@@ -0,0 +1,13 @@
+# Audio Understanding Prompt Templates
+
+system_prompt: |-
+ The user has asked a question: {{ query }}. Please analyze this audio from the perspective of answering this question, within 300 words.
+
+ **Audio Analysis Requirements:**
+ 1. Focus on speech, sound events, tone, timing, and other audio content relevant to the user's question
+ 2. If speech is present, summarize or transcribe the key spoken content when possible
+ 3. Keep the answer concise and grounded in observable audio evidence
+ 4. Avoid guessing identities or facts that cannot be inferred from the audio
+
+user_prompt: |
+ Please listen to this audio and describe it from the perspective of answering the user's question.
diff --git a/sdk/nexent/core/prompts/analyze_audio_zh.yaml b/sdk/nexent/core/prompts/analyze_audio_zh.yaml
new file mode 100644
index 000000000..ae6f1fa0d
--- /dev/null
+++ b/sdk/nexent/core/prompts/analyze_audio_zh.yaml
@@ -0,0 +1,13 @@
+# 音频理解 Prompt 模板
+
+system_prompt: |-
+ 用户提出的问题是:{{ query }}。请从回答该问题的角度分析这段音频,控制在 300 字以内。
+
+ **音频分析要求:**
+ 1. 关注与用户问题相关的语音、声音事件、语气、节奏和其他音频内容
+ 2. 如果包含人声,请尽可能总结或转写关键口语内容
+ 3. 回答要简洁,并基于音频中可观察到的信息
+ 4. 不要猜测无法从音频中判断的身份或事实
+
+user_prompt: |
+ 请仔细聆听这段音频,并从回答用户问题的角度进行描述。
diff --git a/sdk/nexent/core/prompts/analyze_video_en.yaml b/sdk/nexent/core/prompts/analyze_video_en.yaml
new file mode 100644
index 000000000..7834ca7f3
--- /dev/null
+++ b/sdk/nexent/core/prompts/analyze_video_en.yaml
@@ -0,0 +1,13 @@
+# Video Understanding Prompt Templates
+
+system_prompt: |-
+ The user has asked a question: {{ query }}. Please analyze this video from the perspective of answering this question, within 300 words.
+
+ **Video Analysis Requirements:**
+ 1. Focus on scenes, actions, objects, people, visible text, and temporal changes relevant to the user's question
+ 2. Mention important audio cues only when they help answer the question
+ 3. Keep the answer concise, structured, and grounded in visible or audible evidence
+ 4. Avoid over-interpreting intent or facts that cannot be inferred from the video
+
+user_prompt: |
+ Please watch this video and describe it from the perspective of answering the user's question.
diff --git a/sdk/nexent/core/prompts/analyze_video_zh.yaml b/sdk/nexent/core/prompts/analyze_video_zh.yaml
new file mode 100644
index 000000000..e83a1676d
--- /dev/null
+++ b/sdk/nexent/core/prompts/analyze_video_zh.yaml
@@ -0,0 +1,13 @@
+# 视频理解 Prompt 模板
+
+system_prompt: |-
+ 用户提出的问题是:{{ query }}。请从回答该问题的角度分析这段视频,控制在 300 字以内。
+
+ **视频分析要求:**
+ 1. 关注与用户问题相关的场景、动作、物体、人物、可见文字和时间变化
+ 2. 只有在有助于回答问题时,才补充重要的音频线索
+ 3. 回答要简洁、有条理,并基于视频中可见或可听的信息
+ 4. 不要过度推断无法从视频中判断的意图或事实
+
+user_prompt: |
+ 请仔细观看这段视频,并从回答用户问题的角度进行描述。
diff --git a/sdk/nexent/core/tools/__init__.py b/sdk/nexent/core/tools/__init__.py
index 851690f16..a640cb5ff 100644
--- a/sdk/nexent/core/tools/__init__.py
+++ b/sdk/nexent/core/tools/__init__.py
@@ -19,6 +19,8 @@
from .terminal_tool import TerminalTool
from .analyze_text_file_tool import AnalyzeTextFileTool
from .analyze_image_tool import AnalyzeImageTool
+from .analyze_audio_tool import AnalyzeAudioTool
+from .analyze_video_tool import AnalyzeVideoTool
from .run_skill_script_tool import run_skill_script
from .read_skill_md_tool import read_skill_md
from .read_skill_config_tool import read_skill_config
@@ -47,6 +49,8 @@
"TerminalTool",
"AnalyzeTextFileTool",
"AnalyzeImageTool",
+ "AnalyzeAudioTool",
+ "AnalyzeVideoTool",
"run_skill_script",
"read_skill_md",
"read_skill_config"
diff --git a/sdk/nexent/core/tools/analyze_audio_tool.py b/sdk/nexent/core/tools/analyze_audio_tool.py
new file mode 100644
index 000000000..c7509a6c2
--- /dev/null
+++ b/sdk/nexent/core/tools/analyze_audio_tool.py
@@ -0,0 +1,169 @@
+"""
+Analyze Audio Tool
+
+Analyze audio using the configured video understanding model.
+Supports audio from S3, HTTP, and HTTPS URLs.
+"""
+
+import logging
+from io import BytesIO
+from typing import List
+
+from jinja2 import StrictUndefined, Template
+from pydantic import Field
+from smolagents.tools import Tool
+
+from ...core.models import OpenAIVLModel
+from ...core.utils.observer import MessageObserver, ProcessType
+from ...core.utils.prompt_template_utils import get_prompt_template
+from ...core.utils.tools_common_message import ToolCategory, ToolSign
+from ...multi_modal.load_save_object import LoadSaveObjectManager
+from ...multi_modal.utils import detect_content_type_from_bytes
+from ...storage import MinIOStorageClient
+
+logger = logging.getLogger("analyze_audio_tool")
+
+
+class AnalyzeAudioTool(Tool):
+ """Tool for understanding and analyzing audio using the video understanding model."""
+
+ name = "analyze_audio"
+ description = (
+ "This tool uses the configured video understanding model to understand audio based on your query and then returns an audio analysis result.\n"
+ "It is used to understand and analyze multiple audio files, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), "
+ "HTTP, and HTTPS URLs.\n"
+ "Use this tool when you want to retrieve information contained in audio and provide the audio URL and your query."
+ )
+ description_zh = (
+ "使用视频理解模型,根据你的提示词来理解音频,并返回音频分析结果。"
+ "可用于理解和分析多个音频文件,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
+ )
+
+ inputs = {
+ "audio_urls_list": {
+ "type": "array",
+ "description": "List of audio URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.",
+ "description_zh": "列表形式输入音频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。"
+ },
+ "query": {
+ "type": "string",
+ "description": "User's question to guide the audio analysis",
+ "description_zh": "用户用于指导音频分析的问题"
+ }
+ }
+
+ init_param_descriptions = {
+ "observer": {"description": "Message observer"},
+ "vlm_model": {"description": "The video understanding model to use"},
+ "storage_client": {"description": "Storage client for downloading files"},
+ "validate_url_access": {
+ "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)"
+ }
+ }
+ output_type = "array"
+ category = ToolCategory.MULTIMODAL.value
+ tool_sign = ToolSign.MULTIMODAL_OPERATION.value
+
+ def __init__(
+ self,
+ observer: MessageObserver = Field(
+ description="Message observer",
+ default=None,
+ exclude=True),
+ vlm_model: OpenAIVLModel = Field(
+ description="The video understanding model to use",
+ default=None,
+ exclude=True),
+ storage_client: MinIOStorageClient = Field(
+ description="Storage client for downloading files from S3 URLs, HTTP URLs, and HTTPS URLs.",
+ default=None,
+ exclude=True),
+ validate_url_access: callable = Field(
+ description="Callback function to validate URL access permissions",
+ default=None,
+ exclude=True)
+ ):
+ super().__init__()
+ self.observer = observer
+ self.vlm_model = vlm_model
+ self.storage_client = storage_client
+ self._is_chinese = bool(observer and observer.lang == "zh")
+
+ validate_callback = None
+ if validate_url_access is not None and callable(validate_url_access):
+ validate_callback = validate_url_access
+ self.mm = LoadSaveObjectManager(
+ storage_client=self.storage_client,
+ validate_url_access=validate_callback
+ )
+ self.forward = self.mm.load_object(
+ input_names=["audio_urls_list"])(self._forward_impl)
+
+ self.running_prompt_zh = "正在分析音频..."
+ self.running_prompt_en = "Analyzing audio..."
+
+ def _validate_audio_capable_model(self) -> None:
+ """Fail early for SiliconFlow models that are known not to accept audio input."""
+ client_kwargs = getattr(self.vlm_model, "client_kwargs", {}) or {}
+ base_url = client_kwargs.get("base_url", "") if isinstance(client_kwargs, dict) else ""
+ model_id = str(getattr(self.vlm_model, "model_id", "") or "")
+
+ if "siliconflow" in str(base_url).lower() and model_id and "omni" not in model_id.lower():
+ raise ValueError(
+ "The selected video understanding model does not support audio input on SiliconFlow. "
+ "Please choose a Qwen3-Omni model for analyze_audio."
+ )
+
+ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]:
+ """Analyze audio files and return one result per audio input."""
+ if self.vlm_model is None:
+ error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试"
+ error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again."
+ error_msg = error_msg_zh if self._is_chinese else error_msg_en
+ logger.error(error_msg)
+ raise Exception(error_msg)
+ self._validate_audio_capable_model()
+
+ if self.observer:
+ running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en
+ self.observer.add_message("", ProcessType.TOOL, running_prompt)
+
+ if audio_urls_list is None:
+ raise ValueError("audio_urls cannot be None")
+ if not isinstance(audio_urls_list, list):
+ raise ValueError("audio_urls must be a list of bytes")
+ if not audio_urls_list:
+ raise ValueError("audio_urls must contain at least one audio file")
+
+ language = self.observer.lang if self.observer else "en"
+ prompts = get_prompt_template(
+ template_type='analyze_audio', language=language)
+ system_prompt = Template(
+ prompts['system_prompt'], undefined=StrictUndefined).render({'query': query})
+
+ try:
+ analysis_results: List[str] = []
+ for index, audio_bytes in enumerate(audio_urls_list, start=1):
+ logger.info(f"Analyzing audio #{index}, query: {query}")
+ content_type = detect_content_type_from_bytes(audio_bytes)
+ if not content_type.startswith("audio/"):
+ content_type = "audio/mpeg"
+ audio_stream = BytesIO(audio_bytes)
+ try:
+ response = self.vlm_model.analyze_audio(
+ audio_input=audio_stream,
+ system_prompt=system_prompt,
+ content_type=content_type
+ )
+ except Exception as e:
+ error_msg_zh = f"音频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。"
+ error_msg_en = f"Failed to analyze audio {index}: {str(e)}. Please check if the video understanding model is configured correctly."
+ error_msg = error_msg_zh if self._is_chinese else error_msg_en
+ raise Exception(error_msg)
+
+ analysis_results.append(response.content)
+
+ return analysis_results
+ except Exception as e:
+ logger.error(f"Error analyzing audio: {str(e)}", exc_info=True)
+ raise Exception(f"Error analyzing audio: {str(e)}")
diff --git a/sdk/nexent/core/tools/analyze_image_tool.py b/sdk/nexent/core/tools/analyze_image_tool.py
index 3851a896b..f7640a9dc 100644
--- a/sdk/nexent/core/tools/analyze_image_tool.py
+++ b/sdk/nexent/core/tools/analyze_image_tool.py
@@ -24,17 +24,17 @@
class AnalyzeImageTool(Tool):
- """Tool for understanding and analyzing image using a visual language model"""
+ """Tool for understanding and analyzing images using the image understanding model."""
name = "analyze_image"
description = (
- "This tool uses a visual language model to understand images based on your query and then returns a description of the image.\n"
+ "This tool uses the configured image understanding model to understand images based on your query and then returns a description of the image.\n"
"It is used to understand and analyze multiple images, with image sources supporting S3 URLs (s3://bucket/key or /bucket/key), "
"HTTP, and HTTPS URLs.\n"
"Use this tool when you want to retrieve information contained in an image and provide the image's URL and your query."
)
- description_zh = "使用视觉语言模型,根据你的提示词来理解图像,并返回图像的描述。可用于理解和分析多张图片,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
+ description_zh = "使用图片理解模型,根据你的提示词来理解图像,并返回图像的描述。可用于理解和分析多张图片,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
inputs = {
"image_urls_list": {
@@ -54,7 +54,7 @@ class AnalyzeImageTool(Tool):
"description": "Message observer"
},
"vlm_model": {
- "description": "The VLM model to use"
+ "description": "The image understanding model to use"
},
"storage_client": {
"description": "Storage client for downloading files"
@@ -74,7 +74,7 @@ def __init__(
default=None,
exclude=True),
vlm_model: OpenAIVLModel = Field(
- description="The VLM model to use",
+ description="The image understanding model to use",
default=None,
exclude=True),
storage_client: MinIOStorageClient = Field(
@@ -130,10 +130,10 @@ def _forward_impl(self, image_urls_list: List[bytes], query: str) -> List[str]:
Raises:
Exception: If the image cannot be downloaded or analyzed.
"""
- # Check if VLM model is available
+ # Check if the image understanding model is available.
if self.vlm_model is None:
- error_msg_zh = "视觉语言模型(VLM)未配置,请联系管理员配置VLM模型后重试"
- error_msg_en = "Vision Language Model (VLM) is not configured. Please contact your administrator to configure the VLM model and try again."
+ error_msg_zh = "图片理解模型未配置,请联系管理员配置图片理解模型后重试"
+ error_msg_en = "Image understanding model is not configured. Please contact your administrator to configure the image understanding model and try again."
error_msg = error_msg_zh if self._is_chinese else error_msg_en
logger.error(error_msg)
raise Exception(error_msg)
@@ -170,8 +170,8 @@ def _forward_impl(self, image_urls_list: List[bytes], query: str) -> List[str]:
system_prompt=system_prompt
)
except Exception as e:
- error_msg_zh = f"图片{index}分析失败: {str(e)}。请检查VLM模型配置是否正确。"
- error_msg_en = f"Failed to analyze image {index}: {str(e)}. Please check if the VLM model is configured correctly."
+ error_msg_zh = f"图片{index}分析失败: {str(e)}。请检查图片理解模型配置是否正确。"
+ error_msg_en = f"Failed to analyze image {index}: {str(e)}. Please check if the image understanding model is configured correctly."
error_msg = error_msg_zh if self._is_chinese else error_msg_en
raise Exception(error_msg)
diff --git a/sdk/nexent/core/tools/analyze_video_tool.py b/sdk/nexent/core/tools/analyze_video_tool.py
new file mode 100644
index 000000000..3dc033551
--- /dev/null
+++ b/sdk/nexent/core/tools/analyze_video_tool.py
@@ -0,0 +1,156 @@
+"""
+Analyze Video Tool
+
+Analyze videos using the configured video understanding model.
+Supports videos from S3, HTTP, and HTTPS URLs.
+"""
+
+import logging
+from io import BytesIO
+from typing import List
+
+from jinja2 import StrictUndefined, Template
+from pydantic import Field
+from smolagents.tools import Tool
+
+from ...core.models import OpenAIVLModel
+from ...core.utils.observer import MessageObserver, ProcessType
+from ...core.utils.prompt_template_utils import get_prompt_template
+from ...core.utils.tools_common_message import ToolCategory, ToolSign
+from ...multi_modal.load_save_object import LoadSaveObjectManager
+from ...multi_modal.utils import detect_content_type_from_bytes
+from ...storage import MinIOStorageClient
+
+logger = logging.getLogger("analyze_video_tool")
+
+
+class AnalyzeVideoTool(Tool):
+ """Tool for understanding and analyzing videos using the video understanding model."""
+
+ name = "analyze_video"
+ description = (
+ "This tool uses the configured video understanding model to understand videos based on your query and then returns a video analysis result.\n"
+ "It is used to understand and analyze multiple videos, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), "
+ "HTTP, and HTTPS URLs.\n"
+ "Use this tool when you want to retrieve information contained in a video and provide the video's URL and your query."
+ )
+ description_zh = (
+ "使用视频理解模型,根据你的提示词来理解视频,并返回视频分析结果。"
+ "可用于理解和分析多个视频,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
+ )
+
+ inputs = {
+ "video_urls_list": {
+ "type": "array",
+ "description": "List of video URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.",
+ "description_zh": "列表形式输入视频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。"
+ },
+ "query": {
+ "type": "string",
+ "description": "User's question to guide the video analysis",
+ "description_zh": "用户用于指导视频分析的问题"
+ }
+ }
+
+ init_param_descriptions = {
+ "observer": {"description": "Message observer"},
+ "vlm_model": {"description": "The video understanding model to use"},
+ "storage_client": {"description": "Storage client for downloading files"},
+ "validate_url_access": {
+ "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)"
+ }
+ }
+ output_type = "array"
+ category = ToolCategory.MULTIMODAL.value
+ tool_sign = ToolSign.MULTIMODAL_OPERATION.value
+
+ def __init__(
+ self,
+ observer: MessageObserver = Field(
+ description="Message observer",
+ default=None,
+ exclude=True),
+ vlm_model: OpenAIVLModel = Field(
+ description="The video understanding model to use",
+ default=None,
+ exclude=True),
+ storage_client: MinIOStorageClient = Field(
+ description="Storage client for downloading files from S3 URLs, HTTP URLs, and HTTPS URLs.",
+ default=None,
+ exclude=True),
+ validate_url_access: callable = Field(
+ description="Callback function to validate URL access permissions",
+ default=None,
+ exclude=True)
+ ):
+ super().__init__()
+ self.observer = observer
+ self.vlm_model = vlm_model
+ self.storage_client = storage_client
+ self._is_chinese = bool(observer and observer.lang == "zh")
+
+ validate_callback = None
+ if validate_url_access is not None and callable(validate_url_access):
+ validate_callback = validate_url_access
+ self.mm = LoadSaveObjectManager(
+ storage_client=self.storage_client,
+ validate_url_access=validate_callback
+ )
+ self.forward = self.mm.load_object(
+ input_names=["video_urls_list"])(self._forward_impl)
+
+ self.running_prompt_zh = "正在分析视频..."
+ self.running_prompt_en = "Analyzing video..."
+
+ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]:
+ """Analyze videos and return one result per video input."""
+ if self.vlm_model is None:
+ error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试"
+ error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again."
+ error_msg = error_msg_zh if self._is_chinese else error_msg_en
+ logger.error(error_msg)
+ raise Exception(error_msg)
+
+ if self.observer:
+ running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en
+ self.observer.add_message("", ProcessType.TOOL, running_prompt)
+
+ if video_urls_list is None:
+ raise ValueError("video_urls cannot be None")
+ if not isinstance(video_urls_list, list):
+ raise ValueError("video_urls must be a list of bytes")
+ if not video_urls_list:
+ raise ValueError("video_urls must contain at least one video")
+
+ language = self.observer.lang if self.observer else "en"
+ prompts = get_prompt_template(
+ template_type='analyze_video', language=language)
+ system_prompt = Template(
+ prompts['system_prompt'], undefined=StrictUndefined).render({'query': query})
+
+ try:
+ analysis_results: List[str] = []
+ for index, video_bytes in enumerate(video_urls_list, start=1):
+ logger.info(f"Analyzing video #{index}, query: {query}")
+ content_type = detect_content_type_from_bytes(video_bytes)
+ if not content_type.startswith("video/"):
+ content_type = "video/mp4"
+ video_stream = BytesIO(video_bytes)
+ try:
+ response = self.vlm_model.analyze_video(
+ video_input=video_stream,
+ system_prompt=system_prompt,
+ content_type=content_type
+ )
+ except Exception as e:
+ error_msg_zh = f"视频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。"
+ error_msg_en = f"Failed to analyze video {index}: {str(e)}. Please check if the video understanding model is configured correctly."
+ error_msg = error_msg_zh if self._is_chinese else error_msg_en
+ raise Exception(error_msg)
+
+ analysis_results.append(response.content)
+
+ return analysis_results
+ except Exception as e:
+ logger.error(f"Error analyzing video: {str(e)}", exc_info=True)
+ raise Exception(f"Error analyzing video: {str(e)}")
diff --git a/sdk/nexent/core/utils/prompt_template_utils.py b/sdk/nexent/core/utils/prompt_template_utils.py
index ad06e9119..24b273876 100644
--- a/sdk/nexent/core/utils/prompt_template_utils.py
+++ b/sdk/nexent/core/utils/prompt_template_utils.py
@@ -17,6 +17,14 @@
LANGUAGE["ZH"]: 'core/prompts/analyze_image_zh.yaml',
LANGUAGE["EN"]: 'core/prompts/analyze_image_en.yaml'
},
+ 'analyze_audio': {
+ LANGUAGE["ZH"]: 'core/prompts/analyze_audio_zh.yaml',
+ LANGUAGE["EN"]: 'core/prompts/analyze_audio_en.yaml'
+ },
+ 'analyze_video': {
+ LANGUAGE["ZH"]: 'core/prompts/analyze_video_zh.yaml',
+ LANGUAGE["EN"]: 'core/prompts/analyze_video_en.yaml'
+ },
'analyze_file': {
LANGUAGE["ZH"]: 'core/prompts/analyze_file_zh.yaml',
LANGUAGE["EN"]: 'core/prompts/analyze_file_en.yaml'
@@ -30,6 +38,8 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw
Args:
template_type: Template type, supports the following values:
- 'analyze_image': Analyze image template
+ - 'analyze_audio': Analyze audio template
+ - 'analyze_video': Analyze video template
- 'analyze_file': Analyze file template (for text files)
language: Language code ('zh' or 'en')
**kwargs: Additional parameters, for agent type need to pass is_manager parameter
@@ -52,4 +62,4 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw
# Read and return template content
with open(absolute_template_path, 'r', encoding='utf-8') as f:
- return yaml.safe_load(f)
\ No newline at end of file
+ return yaml.safe_load(f)
diff --git a/sdk/nexent/multi_modal/utils.py b/sdk/nexent/multi_modal/utils.py
index e118f6940..bcd6cdd35 100644
--- a/sdk/nexent/multi_modal/utils.py
+++ b/sdk/nexent/multi_modal/utils.py
@@ -34,10 +34,10 @@ def is_url(url: str) -> Optional[UrlType]:
if url.startswith("https://"):
return "https"
- if url.startswith("s3://"):
- bucket_path = url.replace("s3://", "", 1)
+ if url.startswith("s3://") or url.startswith("s3:/"):
+ bucket_path = url.replace("s3://", "", 1) if url.startswith("s3://") else url.replace("s3:/", "", 1).lstrip("/")
bucket_object = bucket_path.split("/", 1)
- if len(bucket_object) == 2 and all(bucket_object):
+ if len(bucket_object) == 2 and all(bucket_object) and ":" not in bucket_object[0]:
return "s3"
return None
@@ -321,6 +321,7 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]:
Supports formats:
- s3://bucket/key
+ - s3:/bucket/key
- /bucket/key (MinIO path format)
Args:
@@ -335,11 +336,16 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]:
if not s3_url:
raise ValueError("S3 URL cannot be empty")
- if s3_url.startswith('s3://'):
- parts = s3_url.replace('s3://', '').split('/', 1)
+ if s3_url.startswith('s3://') or s3_url.startswith('s3:/'):
+ normalized_url = (
+ s3_url.replace('s3://', '', 1)
+ if s3_url.startswith('s3://')
+ else s3_url.replace('s3:/', '', 1).lstrip('/')
+ )
+ parts = normalized_url.split('/', 1)
if len(parts) == 2:
bucket, object_name = parts
- if not bucket or not object_name:
+ if not bucket or not object_name or ":" in bucket:
raise ValueError(f"Invalid s3:// URL format: {s3_url}")
return bucket, object_name
raise ValueError(f"Invalid s3:// URL format: {s3_url}")
@@ -351,4 +357,4 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]:
return bucket, object_name
raise ValueError(f"Invalid path format: {s3_url}")
- raise ValueError(f"Unrecognized S3 URL format: {s3_url[:50]}...")
\ No newline at end of file
+ raise ValueError(f"Unrecognized S3 URL format: {s3_url[:50]}...")
diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py
index 5817fbe27..20340f2ea 100644
--- a/test/backend/agents/test_create_agent_info.py
+++ b/test/backend/agents/test_create_agent_info.py
@@ -30,6 +30,21 @@ class ValidationError(Exception):
pass
+class MCPConnectionError(Exception):
+ """Mock MCPConnectionError for testing."""
+ pass
+
+
+class NotFoundException(Exception):
+ """Mock NotFoundException for testing."""
+ pass
+
+
+class ToolExecutionException(Exception):
+ """Mock ToolExecutionException for testing."""
+ pass
+
+
consts_model_module = types.ModuleType("consts.model")
consts_model_module.HistoryItem = HistoryItem
sys.modules["consts.model"] = consts_model_module
@@ -37,6 +52,9 @@ class ValidationError(Exception):
# Mock consts.exceptions module with ValidationError
consts_exceptions_module = types.ModuleType("consts.exceptions")
consts_exceptions_module.ValidationError = ValidationError
+consts_exceptions_module.MCPConnectionError = MCPConnectionError
+consts_exceptions_module.NotFoundException = NotFoundException
+consts_exceptions_module.ToolExecutionException = ToolExecutionException
sys.modules["consts.exceptions"] = consts_exceptions_module
# Also add model and exceptions to consts module attributes
@@ -165,7 +183,9 @@ def _create_stub_module(name: str, **attrs):
services_module = _create_stub_module("services")
sys.modules['services'] = services_module
sys.modules['services.image_service'] = _create_stub_module(
- "services.image_service", get_vlm_model=MagicMock(return_value="stub_vlm")
+ "services.image_service",
+ get_vlm_model=MagicMock(return_value="stub_vlm"),
+ get_video_understanding_model=MagicMock(return_value="stub_video_vlm"),
)
sys.modules['services.memory_config_service'] = MagicMock()
# Extend services hierarchy with additional stubs
@@ -250,6 +270,7 @@ def _create_stub_module(name: str, **attrs):
_extract_url_from_card,
_build_external_agent_config,
_get_external_a2a_agents,
+ _build_internal_s3_url,
_format_minio_files_for_content,
_convert_history_with_minio_files,
)
@@ -727,6 +748,48 @@ async def test_create_tool_config_list_with_analyze_image_tool(self):
assert "validate_url_access" in mock_tool_instance.metadata
assert callable(mock_tool_instance.metadata["validate_url_access"])
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "class_name,tool_name",
+ [
+ ("AnalyzeAudioTool", "analyze_audio"),
+ ("AnalyzeVideoTool", "analyze_video"),
+ ],
+ )
+ async def test_create_tool_config_list_with_audio_video_tools(self, class_name, tool_name):
+ """Ensure audio/video tools receive video understanding model metadata."""
+ mock_tool_instance = MagicMock()
+ mock_tool_instance.class_name = class_name
+ mock_tool_config.return_value = mock_tool_instance
+
+ with patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \
+ patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \
+ patch('backend.agents.create_agent_info.get_video_understanding_model') as mock_get_video_model, \
+ patch('backend.agents.create_agent_info.minio_client', new_callable=MagicMock):
+
+ mock_search_tools.return_value = [
+ {
+ "class_name": class_name,
+ "name": tool_name,
+ "description": "Analyze media tool",
+ "inputs": "string",
+ "output_type": "string",
+ "params": [{"name": "prompt", "default": "describe"}],
+ "source": "local",
+ "usage": None
+ }
+ ]
+ mock_get_video_model.return_value = "mock_video_model"
+
+ result = await create_tool_config_list("agent_1", "tenant_1", "user_1")
+
+ assert len(result) == 1
+ assert result[0] is mock_tool_instance
+ mock_get_video_model.assert_called_once_with(tenant_id="tenant_1")
+ assert mock_tool_instance.metadata["vlm_model"] == "mock_video_model"
+ assert "storage_client" in mock_tool_instance.metadata
+ assert callable(mock_tool_instance.metadata["validate_url_access"])
+
@pytest.mark.asyncio
async def test_create_tool_config_list_with_analyze_text_file_tool(self):
"""Ensure AnalyzeTextFileTool receives text-specific metadata."""
@@ -3297,6 +3360,26 @@ async def test_create_agent_run_info_is_need_auth_true_includes_token(self):
class TestJoinMinioFileDescriptionToQuery:
"""Tests for the join_minio_file_description_to_query function"""
+ def test_build_internal_s3_url_prefers_object_name(self):
+ file = {
+ "object_name": "attachments/user/image.png",
+ "url": "blob:http://localhost:3000/preview",
+ "name": "image.png",
+ }
+
+ result = _build_internal_s3_url(file)
+
+ assert result.endswith("/attachments/user/image.png")
+ assert result.startswith("s3://")
+
+ def test_build_internal_s3_url_rejects_blob_preview_url(self):
+ file = {
+ "url": "blob:http://localhost:3000/preview",
+ "name": "image.png",
+ }
+
+ assert _build_internal_s3_url(file) == ""
+
@pytest.mark.asyncio
async def test_join_minio_file_description_to_query_with_files(self):
"""Test case with file descriptions"""
@@ -3345,6 +3428,40 @@ async def test_join_minio_file_description_to_query_no_descriptions(self):
assert result == "test query"
+ @pytest.mark.asyncio
+ async def test_join_minio_file_description_to_query_prefers_object_name_over_blob_url(self):
+ """Uploaded images should be exposed to internal tools through MinIO, not browser blob URLs."""
+ minio_files = [
+ {
+ "object_name": "attachments/user/image.png",
+ "url": "blob:http://localhost:3000/preview",
+ "name": "image.png",
+ }
+ ]
+ query = "describe the image"
+
+ result = await join_minio_file_description_to_query(minio_files, query)
+
+ assert "blob:http" not in result
+ assert "File name: image.png" in result
+ assert "attachments/user/image.png" in result
+ assert "S3 URL: s3://" in result
+
+ @pytest.mark.asyncio
+ async def test_join_minio_file_description_to_query_skips_blob_only_file(self):
+ """Browser-only preview URLs cannot be used by internal tools."""
+ minio_files = [
+ {
+ "url": "blob:http://localhost:3000/preview",
+ "name": "image.png",
+ }
+ ]
+ query = "describe the image"
+
+ result = await join_minio_file_description_to_query(minio_files, query)
+
+ assert result == query
+
@pytest.mark.asyncio
async def test_join_minio_file_description_to_query_deduplication_current(self):
"""Test that duplicate files in current message are de-duplicated by URL"""
@@ -4455,6 +4572,21 @@ def test_format_minio_files_for_content_single_file_without_presigned_url(self):
result = _format_minio_files_for_content(minio_files)
assert result == "\n[Attached files]:\n - file.txt: s3:/bucket/file.txt"
+ def test_format_minio_files_for_content_uses_object_name_for_blob_url(self):
+ """Use uploaded object_name instead of browser-only blob preview URL."""
+ minio_files = [
+ {
+ "object_name": "attachments/user/image.png",
+ "url": "blob:http://localhost:3000/preview",
+ "name": "image.png",
+ }
+ ]
+
+ result = _format_minio_files_for_content(minio_files)
+
+ assert "blob:http" not in result
+ assert "attachments/user/image.png" in result
+
def test_format_minio_files_for_content_multiple_files(self):
"""Test case for multiple files"""
minio_files = [
diff --git a/test/backend/services/providers/test_silicon_provider.py b/test/backend/services/providers/test_silicon_provider.py
index b947040c3..c596643b2 100644
--- a/test/backend/services/providers/test_silicon_provider.py
+++ b/test/backend/services/providers/test_silicon_provider.py
@@ -66,7 +66,13 @@ async def test_get_models_vlm_success(self, mocker: MockFixture):
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
- {"id": "gpt-4v", "name": "GPT-4 Vision"},
+ {"id": "deepseek-ai/DeepSeek-R1", "name": "DeepSeek R1"},
+ {"id": "Qwen/Qwen2.5-VL-72B-Instruct", "name": "Qwen2.5 VL"},
+ {"id": "OpenGVLab/InternVL2-26B", "name": "InternVL2 26B"},
+ {"id": "Pro/moonshotai/Kimi-K2.6", "name": "Kimi K2.6"},
+ {"id": "Pro/moonshotai/Kimi-K2.5", "name": "Kimi K2.5"},
+ {"id": "Qwen/Qwen3.6-27B", "name": "Qwen3.6 27B"},
+ {"id": "Qwen/Qwen3.6-35B-A3B", "name": "Qwen3.6 35B A3B"},
]
}
mock_response.raise_for_status = MagicMock()
@@ -95,10 +101,66 @@ async def test_get_models_vlm_success(self, mocker: MockFixture):
result = await provider.get_models(provider_config)
- assert len(result) == 1
- assert result[0]["id"] == "gpt-4v"
- assert result[0]["model_type"] == "vlm"
- assert result[0]["model_tag"] == "chat"
+ assert [model["id"] for model in result] == [
+ "Qwen/Qwen2.5-VL-72B-Instruct",
+ "OpenGVLab/InternVL2-26B",
+ "Pro/moonshotai/Kimi-K2.6",
+ "Pro/moonshotai/Kimi-K2.5",
+ "Qwen/Qwen3.6-27B",
+ "Qwen/Qwen3.6-35B-A3B",
+ ]
+ assert all(model["model_type"] == "vlm" for model in result)
+ assert all(model["model_tag"] == "chat" for model in result)
+
+ @pytest.mark.asyncio
+ async def test_get_models_vlm3_only_returns_omni_models(self, mocker: MockFixture):
+ """Test that SiliconFlow video understanding models are restricted to Omni models."""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {
+ "data": [
+ {"id": "Qwen/Qwen3-VL-32B-Instruct", "name": "Qwen3 VL"},
+ {"id": "Qwen/Qwen3-Omni-30B-A3B-Instruct", "name": "Qwen3 Omni Instruct"},
+ {"id": "Qwen/Qwen3-Omni-30B-A3B-Thinking", "name": "Qwen3 Omni Thinking"},
+ {"id": "Qwen/Qwen3-Omni-30B-A3B-Captioner", "name": "Qwen3 Omni Captioner"},
+ {"id": "zai-org/GLM-4.5V", "name": "GLM 4.5V"},
+ ]
+ }
+ mock_response.raise_for_status = MagicMock()
+
+ mock_client = AsyncMock()
+ mock_client.get.return_value = mock_response
+
+ mock_cm = MagicMock()
+ mock_cm.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_cm.__aexit__ = AsyncMock(return_value=None)
+
+ mocker.patch(
+ "backend.services.providers.silicon_provider.httpx.AsyncClient",
+ return_value=mock_cm
+ )
+ mocker.patch(
+ "backend.services.providers.silicon_provider.SILICON_GET_URL",
+ "https://api.siliconflow.com/v1/models"
+ )
+
+ provider = SiliconModelProvider()
+ provider_config = {
+ "model_type": "vlm3",
+ "api_key": "test-api-key"
+ }
+
+ result = await provider.get_models(provider_config)
+
+ assert [model["id"] for model in result] == [
+ "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+ "Qwen/Qwen3-Omni-30B-A3B-Thinking",
+ "Qwen/Qwen3-Omni-30B-A3B-Captioner",
+ ]
+ assert all(model["model_type"] == "vlm3" for model in result)
+ assert all(model["model_tag"] == "chat" for model in result)
+ call_args = mock_client.get.call_args
+ assert "sub_type=chat" in call_args[0][0]
@pytest.mark.asyncio
async def test_get_models_embedding_success(self, mocker: MockFixture):
diff --git a/test/backend/services/test_config_sync_service.py b/test/backend/services/test_config_sync_service.py
index 7a583792b..d3db1db3b 100644
--- a/test/backend/services/test_config_sync_service.py
+++ b/test/backend/services/test_config_sync_service.py
@@ -1,4 +1,6 @@
import sys
+import types
+import importlib
from unittest.mock import patch, MagicMock, call
import pytest
@@ -22,6 +24,31 @@
minio_config_mock = MagicMock()
minio_config_mock.validate = MagicMock()
+if 'consts.const' in sys.modules and not hasattr(sys.modules['consts.const'], 'APP_DESCRIPTION'):
+ sys.modules.pop('consts.const', None)
+if 'consts' in sys.modules and not hasattr(sys.modules['consts'], '__path__'):
+ sys.modules.pop('consts', None)
+
+database_client_module = types.ModuleType('database.client')
+database_client_module.MinioClient = MagicMock()
+database_client_module.minio_client = minio_client_mock
+database_client_module.as_dict = MagicMock(side_effect=lambda value: value)
+database_client_module.db_client = MagicMock()
+database_client_module.db_client.clean_string_values = MagicMock(side_effect=lambda value: value)
+database_client_module.get_db_session = MagicMock()
+sys.modules['database.client'] = database_client_module
+database_package = sys.modules.get('database') or importlib.import_module('database')
+setattr(database_package, 'client', database_client_module)
+database_model_management_module = types.ModuleType('database.model_management_db')
+database_model_management_module.get_model_by_model_id = MagicMock()
+database_model_management_module.get_model_id_by_display_name = MagicMock()
+database_model_management_module.get_model_records = MagicMock(return_value=[])
+sys.modules['database.model_management_db'] = database_model_management_module
+setattr(database_package, 'model_management_db', database_model_management_module)
+backend_database_client_module = sys.modules.get('backend.database.client')
+if backend_database_client_module is not None and not hasattr(backend_database_client_module, 'minio_client'):
+ backend_database_client_module.minio_client = minio_client_mock
+
patch('nexent.storage.storage_client_factory.create_storage_client_from_config',
return_value=storage_client_mock).start()
patch('nexent.storage.minio_config.MinIOStorageConfig',
@@ -29,7 +56,7 @@
patch('backend.database.client.MinioClient',
return_value=minio_client_mock).start()
patch('database.client.MinioClient', return_value=minio_client_mock).start()
-patch('backend.database.client.minio_client', minio_client_mock).start()
+patch('backend.database.client.minio_client', minio_client_mock, create=True).start()
patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start()
# Import backend modules after all patches are applied
@@ -52,14 +79,17 @@ def service_mocks():
with patch('backend.services.config_sync_service.tenant_config_manager') as mock_tenant_config_manager, \
patch('backend.services.config_sync_service.get_env_key') as mock_get_env_key, \
patch('backend.services.config_sync_service.safe_value') as mock_safe_value, \
+ patch('backend.services.config_sync_service.get_model_records') as mock_get_model_records, \
patch('backend.services.config_sync_service.get_model_id_by_display_name') as mock_get_model_id, \
patch('backend.services.config_sync_service.get_model_name_from_config') as mock_get_model_name, \
patch('backend.services.config_sync_service.logger') as mock_logger:
+ mock_get_model_records.return_value = []
yield {
'tenant_config_manager': mock_tenant_config_manager,
'get_env_key': mock_get_env_key,
'safe_value': mock_safe_value,
+ 'get_model_records': mock_get_model_records,
'get_model_id': mock_get_model_id,
'get_model_name': mock_get_model_name,
'logger': mock_logger
@@ -1336,6 +1366,8 @@ def side_effect(config_key, tenant_id=None):
"MULTI_EMBEDDING_ID": {},
"RERANK_ID": {},
"VLM_ID": {},
+ "VLM2_ID": {},
+ "VLM3_ID": {},
"STT_ID": {},
"TTS_ID": {}
}
@@ -1348,7 +1380,7 @@ def side_effect(config_key, tenant_id=None):
# Assert
assert isinstance(result, dict)
- assert len(result) == 7 # All model types should be present
+ assert len(result) == 9 # All model types should be present
# Verify successful configs
assert result["llm"]["displayName"] == "GPT-4"
@@ -1372,20 +1404,20 @@ def test_build_models_config_all_failures(self, service_mocks):
# Assert
assert isinstance(result, dict)
# All model types should still be present with empty configs
- assert len(result) == 7
+ assert len(result) == 9
# All configs should be empty due to exceptions
- for model_key in ["llm", "embedding", "multiEmbedding", "rerank", "vlm", "stt", "tts"]:
+ for model_key in ["llm", "embedding", "multiEmbedding", "rerank", "vlm", "vlm2", "vlm3", "stt", "tts"]:
assert result[model_key]["name"] == ""
assert result[model_key]["displayName"] == ""
assert result[model_key]["apiConfig"]["apiKey"] == ""
assert result[model_key]["apiConfig"]["modelUrl"] == ""
# Verify that logger.warning was called for each model type
- assert service_mocks['logger'].warning.call_count == 7
+ assert service_mocks['logger'].warning.call_count == 9
warning_calls = service_mocks['logger'].warning.call_args_list
expected_configs = ["LLM_ID", "EMBEDDING_ID", "MULTI_EMBEDDING_ID",
- "RERANK_ID", "VLM_ID", "STT_ID", "TTS_ID"]
+ "RERANK_ID", "VLM_ID", "VLM2_ID", "VLM3_ID", "STT_ID", "TTS_ID"]
for i, config_key in enumerate(expected_configs):
assert f"Failed to get config for {config_key}: Database completely down" in warning_calls[
i][0][0]
diff --git a/test/backend/services/test_config_sync_service_voice.py b/test/backend/services/test_config_sync_service_voice.py
index fcfd531f1..1a3144036 100644
--- a/test/backend/services/test_config_sync_service_voice.py
+++ b/test/backend/services/test_config_sync_service_voice.py
@@ -3,6 +3,8 @@
These tests cover the STT specific fields in save_config_impl.
"""
import sys
+import types
+import importlib
from unittest.mock import patch, MagicMock
import pytest
@@ -22,6 +24,31 @@
minio_config_mock = MagicMock()
minio_config_mock.validate = MagicMock()
+if 'consts.const' in sys.modules and not hasattr(sys.modules['consts.const'], 'APP_DESCRIPTION'):
+ sys.modules.pop('consts.const', None)
+if 'consts' in sys.modules and not hasattr(sys.modules['consts'], '__path__'):
+ sys.modules.pop('consts', None)
+
+database_client_module = types.ModuleType('database.client')
+database_client_module.MinioClient = MagicMock()
+database_client_module.minio_client = minio_client_mock
+database_client_module.as_dict = MagicMock(side_effect=lambda value: value)
+database_client_module.db_client = MagicMock()
+database_client_module.db_client.clean_string_values = MagicMock(side_effect=lambda value: value)
+database_client_module.get_db_session = MagicMock()
+sys.modules['database.client'] = database_client_module
+database_package = sys.modules.get('database') or importlib.import_module('database')
+setattr(database_package, 'client', database_client_module)
+database_model_management_module = types.ModuleType('database.model_management_db')
+database_model_management_module.get_model_by_model_id = MagicMock()
+database_model_management_module.get_model_id_by_display_name = MagicMock()
+database_model_management_module.get_model_records = MagicMock(return_value=[])
+sys.modules['database.model_management_db'] = database_model_management_module
+setattr(database_package, 'model_management_db', database_model_management_module)
+backend_database_client_module = sys.modules.get('backend.database.client')
+if backend_database_client_module is not None and not hasattr(backend_database_client_module, 'minio_client'):
+ backend_database_client_module.minio_client = minio_client_mock
+
patch('nexent.storage.storage_client_factory.create_storage_client_from_config',
return_value=storage_client_mock).start()
patch('nexent.storage.minio_config.MinIOStorageConfig',
@@ -29,7 +56,7 @@
patch('backend.database.client.MinioClient',
return_value=minio_client_mock).start()
patch('database.client.MinioClient', return_value=minio_client_mock).start()
-patch('backend.database.client.minio_client', minio_client_mock).start()
+patch('backend.database.client.minio_client', minio_client_mock, create=True).start()
patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start()
# Import backend modules after all patches are applied
@@ -47,14 +74,17 @@ def service_mocks():
with patch('backend.services.config_sync_service.tenant_config_manager') as mock_tenant_config_manager, \
patch('backend.services.config_sync_service.get_env_key') as mock_get_env_key, \
patch('backend.services.config_sync_service.safe_value') as mock_safe_value, \
+ patch('backend.services.config_sync_service.get_model_records') as mock_get_model_records, \
patch('backend.services.config_sync_service.get_model_id_by_display_name') as mock_get_model_id, \
patch('backend.services.config_sync_service.get_model_name_from_config') as mock_get_model_name, \
patch('backend.services.config_sync_service.logger') as mock_logger:
+ mock_get_model_records.return_value = []
yield {
'tenant_config_manager': mock_tenant_config_manager,
'get_env_key': mock_get_env_key,
'safe_value': mock_safe_value,
+ 'get_model_records': mock_get_model_records,
'get_model_id': mock_get_model_id,
'get_model_name': mock_get_model_name,
'logger': mock_logger
diff --git a/test/backend/services/test_image_service.py b/test/backend/services/test_image_service.py
index 1de8d49fd..34f24568c 100644
--- a/test/backend/services/test_image_service.py
+++ b/test/backend/services/test_image_service.py
@@ -13,10 +13,17 @@
helpers_env = bootstrap_test_env()
helpers_env["mock_const"].DATA_PROCESS_SERVICE = "http://mock-data-process-service"
-helpers_env["mock_const"].MODEL_CONFIG_MAPPING = {"vlm": "vlm_model_config"}
+helpers_env["mock_const"].MODEL_CONFIG_MAPPING = {
+ "vlm": "vlm_model_config",
+ "vlm3": "video_model_config",
+}
mock_const = helpers_env["mock_const"]
-from services.image_service import get_vlm_model, proxy_image_impl
+from services.image_service import get_image_understanding_model, get_video_understanding_model, get_vlm_model, proxy_image_impl
+
+image_service_module = sys.modules[get_vlm_model.__module__]
+if "services" in sys.modules:
+ setattr(sys.modules["services"], "image_service", image_service_module)
# Sample test data
test_url = "https://example.com/image.jpg"
@@ -50,7 +57,7 @@ async def test_proxy_image_impl_success():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function
@@ -85,7 +92,7 @@ async def test_proxy_image_impl_remote_error():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function
@@ -118,7 +125,7 @@ async def test_proxy_image_impl_500_error():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function
@@ -146,7 +153,7 @@ async def test_proxy_image_impl_connection_exception():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function - should raise the exception
@@ -178,7 +185,7 @@ async def test_proxy_image_impl_with_special_chars():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function
@@ -213,7 +220,7 @@ async def test_proxy_image_impl_json_parse_error():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function - should raise the exception
@@ -253,7 +260,7 @@ async def test_proxy_image_impl_different_status_codes():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function
@@ -289,7 +296,7 @@ async def test_proxy_image_impl_url_encoding():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function with encoded URL
@@ -305,10 +312,10 @@ async def test_proxy_image_impl_url_encoding():
assert f"url={encoded_url}" in called_url
-@patch('services.image_service.OpenAIVLModel')
-@patch('services.image_service.MessageObserver')
-@patch('services.image_service.get_model_name_from_config')
-@patch('services.image_service.tenant_config_manager')
+@patch.object(image_service_module, 'OpenAIVLModel')
+@patch.object(image_service_module, 'MessageObserver')
+@patch.object(image_service_module, 'get_model_name_from_config')
+@patch.object(image_service_module, 'tenant_config_manager')
def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model):
"""Ensure get_vlm_model builds OpenAIVLModel with tenant config."""
mock_config = {
@@ -324,7 +331,7 @@ def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name,
result = get_vlm_model("tenant-1")
mock_tenant_config_manager.get_model_config.assert_called_once_with(
- key=mock_const.MODEL_CONFIG_MAPPING["vlm"],
+ key="vlm_model_config",
tenant_id="tenant-1"
)
mock_message_observer.assert_called_once_with()
@@ -342,10 +349,10 @@ def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name,
assert result == mock_model_instance
-@patch('services.image_service.OpenAIVLModel')
-@patch('services.image_service.MessageObserver')
-@patch('services.image_service.get_model_name_from_config')
-@patch('services.image_service.tenant_config_manager')
+@patch.object(image_service_module, 'OpenAIVLModel')
+@patch.object(image_service_module, 'MessageObserver')
+@patch.object(image_service_module, 'get_model_name_from_config')
+@patch.object(image_service_module, 'tenant_config_manager')
def test_get_vlm_model_with_none_config(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model):
"""Return None when tenant config is None."""
mock_tenant_config_manager.get_model_config.return_value = None
@@ -359,3 +366,40 @@ def test_get_vlm_model_with_none_config(mock_tenant_config_manager, mock_get_mod
# OpenAIVLModel should not be called when config is None
mock_openai_vl_model.assert_not_called()
assert result is None
+
+
+@patch.object(image_service_module, 'get_vlm_model')
+def test_get_image_understanding_model_uses_first_multimodal_slot(mock_get_vlm_model):
+ """Ensure the image understanding alias keeps using the first multimodal slot."""
+ mock_get_vlm_model.return_value = "image-understanding-model"
+
+ result = get_image_understanding_model("tenant-1")
+
+ mock_get_vlm_model.assert_called_once_with(tenant_id="tenant-1")
+ assert result == "image-understanding-model"
+
+
+@patch.object(image_service_module, 'OpenAIVLModel')
+@patch.object(image_service_module, 'MessageObserver')
+@patch.object(image_service_module, 'get_model_name_from_config')
+@patch.object(image_service_module, 'tenant_config_manager')
+def test_get_video_understanding_model_success(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model):
+ """Ensure video understanding tools use the third multimodal model slot."""
+ mock_config = {
+ "base_url": "https://mock-video-api",
+ "api_key": "secret",
+ "model_name": "video-model"
+ }
+ mock_tenant_config_manager.get_model_config.return_value = mock_config
+ mock_get_model_name.return_value = "video-model"
+ mock_model_instance = MagicMock()
+ mock_openai_vl_model.return_value = mock_model_instance
+
+ result = get_video_understanding_model("tenant-1")
+
+ mock_tenant_config_manager.get_model_config.assert_called_once_with(
+ key="video_model_config",
+ tenant_id="tenant-1"
+ )
+ mock_openai_vl_model.assert_called_once()
+ assert result == mock_model_instance
diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py
index 479e649bf..83a070fe0 100644
--- a/test/backend/services/test_model_management_service.py
+++ b/test/backend/services/test_model_management_service.py
@@ -326,6 +326,11 @@ def _add_repo_to_name(model_repo, model_name):
def import_svc():
"""Import service under MinioClient patch to avoid real initialization."""
minio_client_mock = mock.MagicMock()
+ sys.modules["database"] = database_mod
+ sys.modules["database.model_management_db"] = db_mm_mod
+ setattr(database_mod, "model_management_db", db_mm_mod)
+ sys.modules.pop("backend.services.model_management_service", None)
+ sys.modules.pop("services.model_management_service", None)
with mock.patch("backend.database.client.MinioClient", return_value=minio_client_mock):
from backend.services import model_management_service as svc # type: ignore
return svc
@@ -673,17 +678,11 @@ async def test_batch_create_models_for_tenant_flow():
existing = [
{"model_id": "del-id", "model_repo": "silicon", "model_name": "delete"},
- {"model_id": "keep-id", "model_repo": "silicon", "model_name": "keep"},
+ {"model_id": "keep-id", "model_repo": "silicon", "model_name": "keep", "max_tokens": 1024},
]
- def get_by_display(display_name, tenant_id):
- if display_name == "silicon/keep":
- return {"model_id": "keep-id", "max_tokens": 1024}
- return None
-
with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=existing) as mock_get_existing, \
mock.patch.object(svc, "delete_model_record") as mock_delete, \
- mock.patch.object(svc, "get_model_by_display_name", side_effect=get_by_display) as mock_get_by_display, \
mock.patch.object(svc, "update_model_record") as mock_update, \
mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"prepared": True})) as mock_prep, \
mock.patch.object(svc, "create_model_record") as mock_create:
@@ -692,13 +691,35 @@ def get_by_display(display_name, tenant_id):
mock_get_existing.assert_called_once_with("t1", "silicon", "llm")
mock_delete.assert_called_once_with("del-id", "u1", "t1")
- mock_get_by_display.assert_any_call("silicon/keep", "t1")
mock_update.assert_called_once_with(
"keep-id", {"max_tokens": 4096}, "u1")
mock_prep.assert_awaited()
mock_create.assert_called_once()
+@pytest.mark.asyncio
+async def test_batch_create_models_uses_requested_type_for_each_model():
+ svc = import_svc()
+
+ batch_payload = {
+ "provider": "silicon",
+ "type": "vlm",
+ "models": [
+ {"id": "Qwen/Qwen2.5-VL-72B-Instruct", "model_type": "llm", "max_tokens": 4096},
+ ],
+ "api_key": "k",
+ }
+
+ with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \
+ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"prepared": True})) as mock_prep, \
+ mock.patch.object(svc, "create_model_record"):
+
+ await svc.batch_create_models_for_tenant("u1", "t1", batch_payload)
+
+ prepared_model = mock_prep.call_args.kwargs["model"]
+ assert prepared_model["model_type"] == "vlm"
+
+
@pytest.mark.asyncio
async def test_batch_create_models_max_tokens_update():
"""Test batch_create_models updates max_tokens when display_name exists and max_tokens changed (covers lines 160->173, 168->171)"""
@@ -715,22 +736,16 @@ async def test_batch_create_models_max_tokens_update():
"api_key": "k",
}
- def get_by_display(display_name, tenant_id):
- if display_name == "silicon/model1":
- # Different from new value
- return {"model_id": "id1", "max_tokens": 4096}
- elif display_name == "silicon/model2":
- return {"model_id": "id2", "max_tokens": 4096} # Same as new value
- elif display_name == "silicon/model3":
- # Existing has value, new is None
- return {"model_id": "id3", "max_tokens": 2048}
- return None
+ existing = [
+ {"model_id": "id1", "model_repo": "silicon", "model_name": "model1", "max_tokens": 4096},
+ {"model_id": "id2", "model_repo": "silicon", "model_name": "model2", "max_tokens": 4096},
+ {"model_id": "id3", "model_repo": "silicon", "model_name": "model3", "max_tokens": 2048},
+ ]
- with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \
+ with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=existing), \
mock.patch.object(svc, "delete_model_record"), \
mock.patch.object(svc, "split_repo_name", side_effect=lambda x: ("silicon", x.split("/")[1] if "/" in x else x)), \
- mock.patch.object(svc, "add_repo_to_name", side_effect=lambda r, n: f"{r}/{n}"), \
- mock.patch.object(svc, "get_model_by_display_name", side_effect=get_by_display) as mock_get_by_display, \
+ mock.patch.object(svc, "add_repo_to_name", side_effect=lambda *args, **kwargs: f"{kwargs.get('model_repo', args[0] if args else '')}/{kwargs.get('model_name', args[1] if len(args) > 1 else '')}"), \
mock.patch.object(svc, "update_model_record") as mock_update, \
mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \
mock.patch.object(svc, "create_model_record", return_value=True):
diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py
index 7855fbbbb..7b08788fa 100644
--- a/test/backend/services/test_tool_configuration_service.py
+++ b/test/backend/services/test_tool_configuration_service.py
@@ -14,8 +14,59 @@
boto3_mock = MagicMock()
minio_client_mock = MagicMock()
sys.modules['boto3'] = boto3_mock
+jsonref_mock = types.ModuleType('jsonref')
+jsonref_mock.replace_refs = lambda value: value
+sys.modules['jsonref'] = jsonref_mock
-# Patch smolagents and its sub-modules before importi ng consts.model to avoid ImportError
+fastmcp_mock = types.ModuleType('fastmcp')
+fastmcp_mock.__path__ = []
+
+
+class MockFastMcpClient:
+ def __init__(self, *args, **kwargs):
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
+ def is_connected(self):
+ return True
+
+ async def call_tool(self, *args, **kwargs):
+ return MagicMock()
+
+
+class MockSSETransport:
+ def __init__(self, *args, **kwargs):
+ pass
+
+
+class MockStreamableHttpTransport:
+ def __init__(self, *args, **kwargs):
+ pass
+
+
+fastmcp_mock.Client = MockFastMcpClient
+fastmcp_client_mock = types.ModuleType('fastmcp.client')
+fastmcp_client_mock.__path__ = []
+fastmcp_transports_mock = types.ModuleType('fastmcp.client.transports')
+fastmcp_transports_mock.SSETransport = MockSSETransport
+fastmcp_transports_mock.StreamableHttpTransport = MockStreamableHttpTransport
+sys.modules['fastmcp'] = fastmcp_mock
+sys.modules['fastmcp.client'] = fastmcp_client_mock
+sys.modules['fastmcp.client.transports'] = fastmcp_transports_mock
+
+mcpadapt_mock = types.ModuleType('mcpadapt')
+mcpadapt_mock.__path__ = []
+mcpadapt_smolagents_adapter_mock = types.ModuleType('mcpadapt.smolagents_adapter')
+mcpadapt_smolagents_adapter_mock._sanitize_function_name = lambda name: name
+sys.modules['mcpadapt'] = mcpadapt_mock
+sys.modules['mcpadapt.smolagents_adapter'] = mcpadapt_smolagents_adapter_mock
+
+# Patch smolagents and its sub-modules before importing consts.model to avoid ImportError
mock_smolagents = MagicMock()
sys.modules['smolagents'] = mock_smolagents
@@ -328,6 +379,38 @@ def validate(self):
sys.modules['nexent.monitor'] = monitor_module
setattr(sys.modules['nexent'], 'monitor', monitor_module)
+# Mock services modules before importing tool_configuration_service so absolute
+# imports inside that module do not walk into real service dependency chains.
+sys.modules['services'] = _create_package_mock('services')
+services_modules = {
+ 'file_management_service': {
+ 'get_llm_model': MagicMock(),
+ 'validate_urls_access': MagicMock(return_value=True),
+ },
+ 'vectordatabase_service': {
+ 'get_embedding_model': MagicMock(),
+ 'get_embedding_model_by_index_name': MagicMock(),
+ 'get_rerank_model': MagicMock(),
+ 'get_vector_db_core': MagicMock(),
+ 'ElasticSearchService': MagicMock(),
+ },
+ 'tenant_config_service': {
+ 'get_selected_knowledge_list': MagicMock(),
+ 'build_knowledge_name_mapping': MagicMock(),
+ },
+ 'image_service': {
+ 'get_vlm_model': MagicMock(),
+ 'get_video_understanding_model': MagicMock(),
+ },
+}
+for service_name, attrs in services_modules.items():
+ service_module = types.ModuleType(f'services.{service_name}')
+ for attr_name, attr_value in attrs.items():
+ setattr(service_module, attr_name, attr_value)
+ sys.modules[f'services.{service_name}'] = service_module
+ # Expose on parent package for patch resolution
+ setattr(sys.modules['services'], service_name, service_module)
+
# Load actual backend modules so that patch targets resolve correctly
import importlib # noqa: E402
backend_module = importlib.import_module('backend')
@@ -342,23 +425,6 @@ def validate(self):
# Ensure services package can resolve tool_configuration_service for patching
sys.modules['services.tool_configuration_service'] = backend_services_module
-# Mock services modules
-sys.modules['services'] = _create_package_mock('services')
-services_modules = {
- 'file_management_service': {'get_llm_model': MagicMock()},
- 'vectordatabase_service': {'get_embedding_model': MagicMock(), 'get_vector_db_core': MagicMock(),
- 'ElasticSearchService': MagicMock()},
- 'tenant_config_service': {'get_selected_knowledge_list': MagicMock(), 'build_knowledge_name_mapping': MagicMock()},
- 'image_service': {'get_vlm_model': MagicMock()}
-}
-for service_name, attrs in services_modules.items():
- service_module = types.ModuleType(f'services.{service_name}')
- for attr_name, attr_value in attrs.items():
- setattr(service_module, attr_name, attr_value)
- sys.modules[f'services.{service_name}'] = service_module
- # Expose on parent package for patch resolution
- setattr(sys.modules['services'], service_name, service_module)
-
# Patch storage factory and MinIO config validation to avoid errors during initialization
# These patches must be started before any imports that use MinioClient
storage_client_mock = MagicMock()
@@ -380,6 +446,7 @@ def validate(self):
patch('services.tenant_config_service.build_knowledge_name_mapping',
MagicMock()).start()
patch('services.image_service.get_vlm_model', MagicMock()).start()
+patch('services.image_service.get_video_understanding_model', MagicMock()).start()
patch('backend.database.knowledge_db.get_knowledge_name_map_by_index_names', MagicMock()).start()
patch('backend.services.tool_configuration_service.get_embedding_model_by_index_name', MagicMock()).start()
@@ -2703,6 +2770,63 @@ def test_validate_local_tool_analyze_image_missing_user(self, mock_get_class):
)
+class TestValidateLocalToolAnalyzeAudioVideo:
+ """Test cases for _validate_local_tool with analyze_audio/analyze_video tools."""
+
+ @pytest.mark.parametrize("tool_name", ["analyze_audio", "analyze_video"])
+ @patch('backend.services.tool_configuration_service.minio_client')
+ @patch('backend.services.tool_configuration_service.get_video_understanding_model')
+ @patch('backend.services.tool_configuration_service._get_tool_class_by_name')
+ @patch('backend.services.tool_configuration_service.inspect.signature')
+ def test_validate_local_tool_analyze_audio_video_success(
+ self, mock_signature, mock_get_class, mock_get_video_model, mock_minio_client, tool_name):
+ mock_tool_class = Mock()
+ mock_tool_instance = Mock()
+ mock_tool_instance.forward.return_value = f"{tool_name} result"
+ mock_tool_class.return_value = mock_tool_instance
+ mock_get_class.return_value = mock_tool_class
+ mock_get_video_model.return_value = "mock_video_model"
+
+ mock_sig = Mock()
+ mock_sig.parameters = {}
+ mock_signature.return_value = mock_sig
+
+ from backend.services.tool_configuration_service import _validate_local_tool
+
+ result = _validate_local_tool(
+ tool_name,
+ {"media": "bytes"},
+ {"prompt": "describe"},
+ "tenant1",
+ "user1"
+ )
+
+ assert result == f"{tool_name} result"
+ mock_get_video_model.assert_called_once_with(tenant_id="tenant1")
+ call_kwargs = mock_tool_class.call_args.kwargs
+ assert call_kwargs["vlm_model"] == "mock_video_model"
+ assert "storage_client" in call_kwargs
+ assert callable(call_kwargs["validate_url_access"])
+ mock_tool_instance.forward.assert_called_once_with(media="bytes")
+
+ @pytest.mark.parametrize("tool_name", ["analyze_audio", "analyze_video"])
+ @patch('backend.services.tool_configuration_service._get_tool_class_by_name')
+ def test_validate_local_tool_analyze_audio_video_missing_tenant(self, mock_get_class, tool_name):
+ mock_get_class.return_value = Mock()
+
+ from backend.services.tool_configuration_service import _validate_local_tool
+
+ with pytest.raises(ToolExecutionException,
+ match=f"Tenant ID and User ID are required for {tool_name} validation"):
+ _validate_local_tool(
+ tool_name,
+ {"media": "bytes"},
+ {"prompt": "describe"},
+ None,
+ "user1"
+ )
+
+
class TestValidateLocalToolDatamateSearchTool:
"""Test cases for _validate_local_tool function with datamate_search_tool"""
diff --git a/test/common/test_mocks.py b/test/common/test_mocks.py
index c87b52859..c57941780 100644
--- a/test/common/test_mocks.py
+++ b/test/common/test_mocks.py
@@ -112,6 +112,8 @@ def setup_common_mocks():
"multiEmbedding": "MULTI_EMBEDDING_ID",
"rerank": "RERANK_ID",
"vlm": "VLM_ID",
+ "vlm2": "VLM2_ID",
+ "vlm3": "VLM3_ID",
"stt": "STT_ID",
"tts": "TTS_ID"
}
diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py
index 7d6af852c..ab09b95b6 100644
--- a/test/sdk/core/agents/test_nexent_agent.py
+++ b/test/sdk/core/agents/test_nexent_agent.py
@@ -2475,6 +2475,52 @@ def test_create_local_tool_analyze_image(self, nexent_agent_instance):
assert call_kwargs["param1"] == "value1"
assert result == mock_tool_instance
+ @pytest.mark.parametrize(
+ "class_name,tool_name",
+ [
+ ("AnalyzeAudioTool", "analyze_audio"),
+ ("AnalyzeVideoTool", "analyze_video"),
+ ],
+ )
+ def test_create_local_tool_analyze_audio_video(self, nexent_agent_instance, class_name, tool_name):
+ """Test successful audio/video analysis tool creation."""
+ mock_tool_class = MagicMock()
+ mock_tool_instance = MagicMock()
+ mock_tool_class.return_value = mock_tool_instance
+
+ tool_config = ToolConfig(
+ class_name=class_name,
+ name=tool_name,
+ description="desc",
+ inputs="{}",
+ output_type="string",
+ params={"param1": "value1"},
+ source="local",
+ metadata={
+ "vlm_model": ["video-understanding-model"],
+ "storage_client": "storage"
+ }
+ )
+
+ original_value = nexent_agent.__dict__.get(class_name)
+ nexent_agent.__dict__[class_name] = mock_tool_class
+
+ try:
+ result = nexent_agent_instance.create_local_tool(tool_config)
+ finally:
+ if original_value is not None:
+ nexent_agent.__dict__[class_name] = original_value
+ elif class_name in nexent_agent.__dict__:
+ del nexent_agent.__dict__[class_name]
+
+ mock_tool_class.assert_called_once()
+ call_kwargs = mock_tool_class.call_args[1]
+ assert call_kwargs["observer"] == nexent_agent_instance.observer
+ assert call_kwargs["vlm_model"] == ["video-understanding-model"]
+ assert call_kwargs["storage_client"] == "storage"
+ assert call_kwargs["param1"] == "value1"
+ assert result == mock_tool_instance
+
def test_create_local_tool_analyze_text_file_with_validate_url_access_none(self, nexent_agent_instance):
"""Test AnalyzeTextFileTool creation with validate_url_access not in metadata (None)."""
mock_tool_class = MagicMock()
diff --git a/test/sdk/core/models/test_openai_vlm.py b/test/sdk/core/models/test_openai_vlm.py
index f1db49380..4f7104290 100644
--- a/test/sdk/core/models/test_openai_vlm.py
+++ b/test/sdk/core/models/test_openai_vlm.py
@@ -62,7 +62,13 @@ def vl_model_instance():
"""Return an OpenAIVLModel instance with minimal viable attributes for tests."""
observer = MagicMock()
- model = ImportedOpenAIVLModel(observer=observer, ssl_verify=True)
+ model = ImportedOpenAIVLModel(
+ observer=observer,
+ model_id="dummy-model",
+ api_key="dummy-key",
+ api_base="https://example.test",
+ ssl_verify=True,
+ )
# Inject dummy attributes required by the method under test
model.model_id = "dummy-model"
@@ -321,3 +327,55 @@ def test_analyze_image_calls_prepare_image_message(vl_model_instance, tmp_path):
# Verify prepare_image_message was called with correct arguments
mock_prepare.assert_called_once_with(str(test_image), custom_prompt)
+
+
+def test_prepare_media_message_audio(vl_model_instance):
+ audio_stream = MagicMock()
+ audio_stream.read.return_value = b"audio bytes"
+
+ messages = vl_model_instance.prepare_media_message(
+ audio_stream,
+ media_type="audio",
+ content_type="audio/mpeg",
+ system_prompt="Listen carefully",
+ )
+
+ assert messages[0]["content"][0]["type"] == "audio_url"
+ assert messages[0]["content"][0]["audio_url"]["url"].startswith("data:audio/mpeg;base64,")
+ assert messages[0]["content"][1] == {"type": "text", "text": "Listen carefully"}
+
+
+def test_prepare_media_message_video(vl_model_instance):
+ video_stream = MagicMock()
+ video_stream.read.return_value = b"video bytes"
+
+ messages = vl_model_instance.prepare_media_message(
+ video_stream,
+ media_type="video",
+ content_type="video/mp4",
+ system_prompt="Watch carefully",
+ )
+
+ assert messages[0]["content"][0]["type"] == "video_url"
+ assert messages[0]["content"][0]["video_url"]["url"].startswith("data:video/mp4;base64,")
+ assert messages[0]["content"][0]["video_url"]["max_frames"] == 16
+ assert messages[0]["content"][0]["video_url"]["fps"] == 1
+ assert messages[0]["content"][1] == {"type": "text", "text": "Watch carefully"}
+
+
+def test_analyze_audio_calls_prepare_media_message(vl_model_instance):
+ with patch.object(vl_model_instance, "prepare_media_message", return_value=[{"role": "user", "content": "test"}]) as mock_prepare:
+ vl_model_instance.__call__ = MagicMock(return_value=MagicMock())
+
+ vl_model_instance.analyze_audio("audio.mp3", system_prompt="Analyze", content_type="audio/mpeg")
+
+ mock_prepare.assert_called_once_with("audio.mp3", "audio", "audio/mpeg", "Analyze")
+
+
+def test_analyze_video_calls_prepare_media_message(vl_model_instance):
+ with patch.object(vl_model_instance, "prepare_media_message", return_value=[{"role": "user", "content": "test"}]) as mock_prepare:
+ vl_model_instance.__call__ = MagicMock(return_value=MagicMock())
+
+ vl_model_instance.analyze_video("video.mp4", system_prompt="Analyze", content_type="video/mp4")
+
+ mock_prepare.assert_called_once_with("video.mp4", "video", "video/mp4", "Analyze")
diff --git a/test/sdk/core/tools/test_analyze_audio_video_tool.py b/test/sdk/core/tools/test_analyze_audio_video_tool.py
new file mode 100644
index 000000000..94401b61d
--- /dev/null
+++ b/test/sdk/core/tools/test_analyze_audio_video_tool.py
@@ -0,0 +1,119 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+
+from sdk.nexent.core.tools import analyze_audio_tool, analyze_video_tool
+from sdk.nexent.core.tools.analyze_audio_tool import AnalyzeAudioTool
+from sdk.nexent.core.tools.analyze_video_tool import AnalyzeVideoTool
+from sdk.nexent.core.utils.observer import MessageObserver, ProcessType
+
+
+@pytest.fixture
+def mock_storage_client():
+ class DummyStorage:
+ pass
+
+ return DummyStorage()
+
+
+@pytest.fixture
+def mock_vlm_model():
+ return MagicMock()
+
+
+@pytest.fixture
+def observer_en():
+ observer = MagicMock(spec=MessageObserver)
+ observer.lang = "en"
+ return observer
+
+
+def test_analyze_audio_uses_video_understanding_model(observer_en, mock_vlm_model, mock_storage_client, monkeypatch):
+ calls = []
+
+ def _fake_get_prompt(template_type, language=None, **_):
+ calls.append((template_type, language))
+ return {"system_prompt": "Analyze audio for {{ query }}"}
+
+ monkeypatch.setattr(analyze_audio_tool, "get_prompt_template", _fake_get_prompt)
+ mock_vlm_model.analyze_audio.return_value = SimpleNamespace(content="audio result")
+ tool = AnalyzeAudioTool(
+ observer=observer_en,
+ vlm_model=mock_vlm_model,
+ storage_client=mock_storage_client,
+ )
+
+ result = tool._forward_impl([b"ID3audio-bytes"], "what happened?")
+
+ assert result == ["audio result"]
+ assert calls == [("analyze_audio", "en")]
+ mock_vlm_model.analyze_audio.assert_called_once()
+ call_kwargs = mock_vlm_model.analyze_audio.call_args.kwargs
+ assert hasattr(call_kwargs["audio_input"], "read")
+ assert call_kwargs["content_type"].startswith("audio/")
+ observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing audio...")
+
+
+def test_analyze_audio_rejects_siliconflow_non_omni_model(observer_en, mock_storage_client):
+ vlm_model = SimpleNamespace(
+ model_id="Qwen/Qwen3-VL-32B-Instruct",
+ client_kwargs={"base_url": "https://api.siliconflow.cn/v1"},
+ )
+ tool = AnalyzeAudioTool(
+ observer=observer_en,
+ vlm_model=vlm_model,
+ storage_client=mock_storage_client,
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ tool._forward_impl([b"ID3audio-bytes"], "what happened?")
+
+ assert "Please choose a Qwen3-Omni model" in str(exc_info.value)
+
+
+def test_analyze_video_uses_video_understanding_model(observer_en, mock_vlm_model, mock_storage_client, monkeypatch):
+ calls = []
+
+ def _fake_get_prompt(template_type, language=None, **_):
+ calls.append((template_type, language))
+ return {"system_prompt": "Analyze video for {{ query }}"}
+
+ monkeypatch.setattr(analyze_video_tool, "get_prompt_template", _fake_get_prompt)
+ mock_vlm_model.analyze_video.return_value = SimpleNamespace(content="video result")
+ tool = AnalyzeVideoTool(
+ observer=observer_en,
+ vlm_model=mock_vlm_model,
+ storage_client=mock_storage_client,
+ )
+
+ result = tool._forward_impl([b"\x00\x00\x00\x18ftypmp42video-bytes"], "what happened?")
+
+ assert result == ["video result"]
+ assert calls == [("analyze_video", "en")]
+ mock_vlm_model.analyze_video.assert_called_once()
+ call_kwargs = mock_vlm_model.analyze_video.call_args.kwargs
+ assert hasattr(call_kwargs["video_input"], "read")
+ assert call_kwargs["content_type"].startswith("video/")
+ observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing video...")
+
+
+@pytest.mark.parametrize(
+ "tool_class,input_name,error_text",
+ [
+ (AnalyzeAudioTool, "audio_urls_list", "Video understanding model is not configured"),
+ (AnalyzeVideoTool, "video_urls_list", "Video understanding model is not configured"),
+ ],
+)
+def test_analyze_audio_video_require_video_understanding_model(
+ tool_class, input_name, error_text, observer_en, mock_storage_client):
+ tool = tool_class(
+ observer=observer_en,
+ vlm_model=None,
+ storage_client=mock_storage_client,
+ )
+
+ with pytest.raises(Exception) as exc_info:
+ tool._forward_impl(**{input_name: [b"media"], "query": "question"})
+
+ assert error_text in str(exc_info.value)
diff --git a/test/sdk/core/tools/test_analyze_image_tool.py b/test/sdk/core/tools/test_analyze_image_tool.py
index a8598a8ad..63be0ac54 100644
--- a/test/sdk/core/tools/test_analyze_image_tool.py
+++ b/test/sdk/core/tools/test_analyze_image_tool.py
@@ -136,7 +136,7 @@ def test_forward_impl_vlm_model_none(self, observer_en, mock_storage_client):
with pytest.raises(Exception) as exc_info:
tool._forward_impl([b"img"], "question")
- assert "Vision Language Model (VLM) is not configured" in str(
+ assert "Image understanding model is not configured" in str(
exc_info.value)
def test_forward_impl_vlm_model_none_chinese(self, observer_zh, mock_storage_client):
@@ -150,7 +150,7 @@ def test_forward_impl_vlm_model_none_chinese(self, observer_zh, mock_storage_cli
with pytest.raises(Exception) as exc_info:
tool._forward_impl([b"img"], "问题")
- assert "视觉语言模型(VLM)未配置" in str(exc_info.value)
+ assert "图片理解模型未配置" in str(exc_info.value)
def test_forward_impl_observer_none_uses_english(self, mock_vlm_model, mock_storage_client):
"""Test that English is used when observer is None."""
@@ -353,7 +353,7 @@ def test_observer_add_message_not_called_when_none(self, mock_vlm_model, mock_st
def test_tool_name_and_description(self, tool):
"""Test that tool name and description are set correctly."""
assert tool.name == "analyze_image"
- assert "visual language model" in tool.description.lower()
+ assert "image understanding model" in tool.description.lower()
assert "image" in tool.description.lower()
def test_tool_inputs_schema(self, tool):
diff --git a/test/sdk/core/utils/test_prompt_template_utils.py b/test/sdk/core/utils/test_prompt_template_utils.py
index c0a3ad634..a50929b8d 100644
--- a/test/sdk/core/utils/test_prompt_template_utils.py
+++ b/test/sdk/core/utils/test_prompt_template_utils.py
@@ -61,6 +61,28 @@ def test_get_prompt_template_analyze_image_en(self, mock_yaml_load, mock_file):
# Verify result
assert result == {"system_prompt": "Test prompt", "user_prompt": "User prompt"}
+ @pytest.mark.parametrize(
+ "template_type,language,expected_file",
+ [
+ ("analyze_audio", "en", "prompts/analyze_audio_en.yaml"),
+ ("analyze_audio", "zh", "prompts/analyze_audio_zh.yaml"),
+ ("analyze_video", "en", "prompts/analyze_video_en.yaml"),
+ ("analyze_video", "zh", "prompts/analyze_video_zh.yaml"),
+ ],
+ )
+ @patch('builtins.open', new_callable=mock_open, read_data='system_prompt: "Test prompt"\nuser_prompt: "User prompt"')
+ @patch('yaml.safe_load')
+ def test_get_prompt_template_analyze_audio_video(
+ self, mock_yaml_load, mock_file, template_type, language, expected_file):
+ """Test get_prompt_template for audio/video templates."""
+ mock_yaml_load.return_value = {"system_prompt": "Test prompt", "user_prompt": "User prompt"}
+
+ result = get_prompt_template(template_type=template_type, language=language)
+
+ call_args = mock_file.call_args[0]
+ assert expected_file in call_args[0].replace('\\', '/')
+ assert result == {"system_prompt": "Test prompt", "user_prompt": "User prompt"}
+
@patch('builtins.open', new_callable=mock_open, read_data='system_prompt: "Test prompt"')
@patch('yaml.safe_load')
@patch('sdk.nexent.core.utils.prompt_template_utils.LANGUAGE', {'ZH': 'zh', 'EN': 'en'})
@@ -174,4 +196,4 @@ def test_get_prompt_template_path_resolution(self, mock_yaml_load, mock_file):
assert mock_file.called
call_args = mock_file.call_args[0]
# Path should be absolute or contain the expected template file
- assert 'analyze_image_en.yaml' in call_args[0]
\ No newline at end of file
+ assert 'analyze_image_en.yaml' in call_args[0]
diff --git a/test/sdk/multi_modal/test_load_save_object.py b/test/sdk/multi_modal/test_load_save_object.py
index 92425791c..1670e6a9d 100644
--- a/test/sdk/multi_modal/test_load_save_object.py
+++ b/test/sdk/multi_modal/test_load_save_object.py
@@ -26,6 +26,23 @@ def test_get_client_requires_initialized_storage():
manager._get_client()
+def test_s3_single_slash_url_supported():
+ assert lso.is_url("s3:/bucket/path/to/image.png") == "s3"
+ assert lso.parse_s3_url("s3:/bucket/path/to/image.png") == (
+ "bucket",
+ "path/to/image.png",
+ )
+
+
+def test_s3_blob_preview_url_rejected():
+ assert lso.is_url("s3:/blob:http://localhost:3000/preview") is None
+
+
+def test_parse_s3_blob_preview_url_rejected():
+ with pytest.raises(ValueError, match="Invalid s3:// URL format"):
+ lso.parse_s3_url("s3:/blob:http://localhost:3000/preview")
+
+
def test_download_file_from_http(monkeypatch):
manager = make_manager()