From d52afea55e716937580fe3d9e65aab98e845bd0f Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Thu, 14 May 2026 09:58:24 +0800 Subject: [PATCH 01/48] =?UTF-8?q?=E8=A7=A3=E5=86=B3=E5=86=B2=E7=AA=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/agents/create_agent_info.py | 9 +- backend/consts/model.py | 3 + backend/database/db_models.py | 2 + backend/database/model_management_db.py | 55 +++++++++ .../conversation_management_service.py | 5 +- backend/services/file_management_service.py | 2 + backend/services/model_health_service.py | 13 ++- backend/services/model_management_service.py | 28 ++++- backend/services/model_provider_service.py | 7 +- backend/utils/llm_utils.py | 3 + ..._add_timeout_seconds_to_model_record_t.sql | 10 ++ .../components/model/ModelAddDialog.tsx | 106 +++++++++++++++++- .../components/model/ModelDeleteDialog.tsx | 57 +++++++++- .../components/model/ModelEditDialog.tsx | 47 +++++++- frontend/public/locales/en/common.json | 1 + frontend/public/locales/zh/common.json | 1 + frontend/services/modelService.ts | 14 +++ frontend/types/modelConfig.ts | 1 + sdk/nexent/core/agents/agent_model.py | 4 + sdk/nexent/core/agents/nexent_agent.py | 1 + sdk/nexent/core/models/openai_llm.py | 37 ++++-- 21 files changed, 371 insertions(+), 35 deletions(-) create mode 100644 docker/sql/v2.0.5_0507_add_timeout_seconds_to_model_record_t.sql diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 5a11b550b..90509c8f5 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -247,7 +247,8 @@ async def create_model_config_list(tenant_id): ), url=record["base_url"], ssl_verify=record.get("ssl_verify", True), - model_factory=record.get("model_factory"))) + model_factory=record.get("model_factory"), + timeout_seconds=record.get("timeout_seconds"))) # fit for old version, main_model and sub_model use default model main_model_config = tenant_config_manager.get_model_config( key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) @@ -258,7 +259,8 @@ async def create_model_config_list(tenant_id): "model_name") else "", url=main_model_config.get("base_url", ""), ssl_verify=main_model_config.get("ssl_verify", True), - model_factory=main_model_config.get("model_factory"))) + model_factory=main_model_config.get("model_factory"), + timeout_seconds=main_model_config.get("timeout_seconds"))) model_list.append( ModelConfig(cite_name="sub_model", api_key=main_model_config.get("api_key", ""), @@ -266,7 +268,8 @@ async def create_model_config_list(tenant_id): "model_name") else "", url=main_model_config.get("base_url", ""), ssl_verify=main_model_config.get("ssl_verify", True), - model_factory=main_model_config.get("model_factory"))) + model_factory=main_model_config.get("model_factory"), + timeout_seconds=main_model_config.get("timeout_seconds"))) return model_list diff --git a/backend/consts/model.py b/backend/consts/model.py index 6c792501f..9e264c7f8 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -121,6 +121,7 @@ class ModelRequest(BaseModel): # STT specific fields model_appid: Optional[str] = None access_token: Optional[str] = None + timeout_seconds: Optional[int] = None class ProviderModelRequest(BaseModel): @@ -772,6 +773,7 @@ class ManageTenantModelCreateRequest(BaseModel): # STT specific fields model_appid: Optional[str] = Field(None, description="Application ID for STT models (e.g., Volcano Engine)") access_token: Optional[str] = Field(None, description="Access token for STT models (e.g., Volcano Engine)") + timeout_seconds: Optional[int] = Field(None, description="Request timeout in seconds") class ManageTenantModelUpdateRequest(BaseModel): @@ -792,6 +794,7 @@ class ManageTenantModelUpdateRequest(BaseModel): # STT specific fields model_appid: Optional[str] = Field(None, description="Application ID for STT models") access_token: Optional[str] = Field(None, description="Access token for STT models") + timeout_seconds: Optional[int] = Field(None, description="Request timeout in seconds") class ManageTenantModelDeleteRequest(BaseModel): diff --git a/backend/database/db_models.py b/backend/database/db_models.py index baa8e903e..94f5be80b 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -182,6 +182,8 @@ class ModelRecord(TableBase): String(100), doc="Application ID for model authentication (used by some STT/TTS providers like Volcano Engine)") access_token = Column( String(100), doc="Access token for model authentication (used by some STT/TTS providers like Volcano Engine)") + timeout_seconds = Column( + Integer, doc="Request timeout in seconds for this model. Default is 120 seconds.") class ModelMonitoringRecord(SimpleTableBase): diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py index cb1c6c69f..7838315b8 100644 --- a/backend/database/model_management_db.py +++ b/backend/database/model_management_db.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Dict, List, Optional from sqlalchemy import and_, desc, func, insert, select, update @@ -7,6 +8,8 @@ from .db_models import ModelRecord from .utils import add_creation_tracking, add_update_tracking +logger = logging.getLogger("database.model_management_db") + def create_model_record(model_data: Dict[str, Any], user_id: str, tenant_id: str) -> bool: """ @@ -84,6 +87,58 @@ def update_model_record( return result.rowcount > 0 +def update_model_record_by_model_name( + model_name: str, + update_data: Dict[str, Any], + user_id: Optional[str] = None, + tenant_id: Optional[str] = None, + model_repo: Optional[str] = None +) -> bool: + """ + Update a model record by model_name and tenant_id. + + Args: + model_name: Model name (display name, not the primary key) + update_data: Dictionary containing update data + user_id: Reserved parameter for filling updated_by field + tenant_id: Tenant ID for filtering + model_repo: Optional model repo for more precise matching + + Returns: + bool: Whether the operation was successful + """ + import logging + db_logger = logging.getLogger("database.client") + + with get_db_session() as session: + # Data cleaning + cleaned_data = db_client.clean_string_values(update_data) + + # Add update timestamp + cleaned_data["update_time"] = func.current_timestamp() + if user_id: + cleaned_data = add_update_tracking(cleaned_data, user_id) + + db_logger.info(f"update_model_record_by_model_name: model_name={model_name}, model_repo={model_repo}, tenant_id={tenant_id}, cleaned_data={cleaned_data}") + + # Build conditions list + conditions = [ + ModelRecord.model_name == model_name, + ModelRecord.tenant_id == tenant_id + ] + if model_repo: + conditions.append(ModelRecord.model_repo == model_repo) + + # Build the update statement + stmt = update(ModelRecord).where(*conditions).values(cleaned_data) + + # Execute the update statement + result = session.execute(stmt) + db_logger.info(f"update_model_record_by_model_name: rowcount={result.rowcount}") + + return result.rowcount > 0 + + def delete_model_record(model_id: int, user_id: str, tenant_id: str) -> bool: """ Delete a model record (soft delete) and update the update timestamp diff --git a/backend/services/conversation_management_service.py b/backend/services/conversation_management_service.py index d5d4a85a4..c3571fcf3 100644 --- a/backend/services/conversation_management_service.py +++ b/backend/services/conversation_management_service.py @@ -248,6 +248,8 @@ def call_llm_for_title(question: str, tenant_id: str, language: str = LANGUAGE[" display_name = model_config.get("display_name", "") if model_config else "" set_monitoring_operation("title_generation", display_name=display_name or None) + timeout_seconds = model_config.get("timeout_seconds") if model_config else None + # Create OpenAIModel instance llm = OpenAIModel( model_id=get_model_name_from_config(model_config) if model_config.get("model_name") else "", @@ -256,7 +258,8 @@ def call_llm_for_title(question: str, tenant_id: str, language: str = LANGUAGE[" temperature=0.7, top_p=0.95, model_factory=model_config.get("model_factory", None), - ssl_verify=model_config.get("ssl_verify", True) + ssl_verify=model_config.get("ssl_verify", True), + timeout_seconds=timeout_seconds, ) # Build messages - use new template variable 'question' instead of 'content' diff --git a/backend/services/file_management_service.py b/backend/services/file_management_service.py index b5cd048bf..7dad75a0a 100644 --- a/backend/services/file_management_service.py +++ b/backend/services/file_management_service.py @@ -352,6 +352,7 @@ def get_llm_model(tenant_id: str): # Get the tenant config main_model_config = tenant_config_manager.get_model_config( key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) + timeout_seconds = main_model_config.get("timeout_seconds") if main_model_config else None long_text_to_text_model = OpenAILongContextModel( observer=MessageObserver(), model_id=get_model_name_from_config(main_model_config), @@ -359,6 +360,7 @@ def get_llm_model(tenant_id: str): api_key=main_model_config.get("api_key"), max_context_tokens=main_model_config.get("max_tokens"), ssl_verify=main_model_config.get("ssl_verify", True), + timeout_seconds=timeout_seconds, ) return long_text_to_text_model diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index a20b2a6ca..b6dac2d04 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -71,6 +71,7 @@ async def _perform_connectivity_check( model_appid: Optional[str] = None, access_token: Optional[str] = None, display_name: Optional[str] = None, + timeout_seconds: Optional[float] = None, ) -> bool: """ Perform specific model connectivity check @@ -80,6 +81,8 @@ async def _perform_connectivity_check( model_base_url: Model base URL model_api_key: API key ssl_verify: Whether to verify SSL certificates (default: True) + display_name: Optional display name for monitoring + timeout_seconds: Optional request timeout in seconds Returns: bool: Connectivity check result """ @@ -115,7 +118,8 @@ async def _perform_connectivity_check( model_id=model_name, api_base=model_base_url, api_key=model_api_key, - ssl_verify=ssl_verify + ssl_verify=ssl_verify, + timeout_seconds=timeout_seconds, ).check_connectivity() elif model_type == "rerank": rerank_model = OpenAICompatibleRerank( @@ -192,6 +196,7 @@ async def check_model_connectivity(display_name: str, tenant_id: str) -> dict: model_factory = model.get("model_factory") model_appid = model.get("model_appid") access_token = model.get("access_token") + timeout_seconds = model.get("timeout_seconds") try: set_monitoring_context(tenant_id=tenant_id) @@ -199,6 +204,8 @@ async def check_model_connectivity(display_name: str, tenant_id: str) -> dict: connectivity = await _perform_connectivity_check( model_name, model_type, model_base_url, model_api_key, ssl_verify, model_factory, model_appid, access_token,display_name=display_name, + display_name=display_name, + timeout_seconds=timeout_seconds, ) except Exception as e: update_data = { @@ -245,16 +252,20 @@ async def verify_model_config_connectivity(model_config: dict): model_factory = model_config.get("model_factory") model_appid = model_config.get("model_appid") access_token = model_config.get("access_token") + # Get timeout from model config if present + timeout_seconds = model_config.get("timeout_seconds") try: connectivity = await _perform_connectivity_check( model_name, model_type, model_base_url, model_api_key, ssl_verify, model_factory, model_appid, access_token + timeout_seconds=timeout_seconds, ) if not connectivity and ssl_verify: connectivity = await _perform_connectivity_check( model_name, model_type, model_base_url, model_api_key, False, model_factory, model_appid, access_token + timeout_seconds=timeout_seconds, ) if not connectivity: error_msg = f"Failed to connect to model '{model_name}' at {model_base_url}. Please verify the URL, API key, and network connection." diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index d012803be..64675d047 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -13,6 +13,7 @@ get_model_records, get_models_by_tenant_factory_type, update_model_record, + update_model_record_by_model_name, ) from services.model_provider_service import ( prepare_model_dict, @@ -276,12 +277,31 @@ async def update_single_model_for_tenant( async def batch_update_models_for_tenant(user_id: str, tenant_id: str, model_list: List[Dict[str, Any]]): - """Batch update models for a tenant.""" + """Batch update models for a tenant by model_id or model_name.""" try: for model in model_list: - update_model_record(model["model_id"], model, user_id, tenant_id) - - logging.debug("Batch update models successfully") + # Build update data excluding id fields + update_data = {k: v for k, v in model.items() if k not in ["model_id", "model_name"]} + + model_id_or_name = model.get("model_id") or model.get("model_name") + + # Check if model_id is a numeric string (primary key) + if model_id_or_name and model_id_or_name.isdigit(): + # Use model_id (primary key) for update + logging.info(f"[DEBUG] Updating model by id: model_id={model_id_or_name}, tenant_id={tenant_id}, update_data={update_data}") + update_model_record(int(model_id_or_name), update_data, user_id, tenant_id) + else: + # Parse "model_repo/model_name" format from frontend's model_id field + if "/" in model_id_or_name: + model_repo, model_name = model_id_or_name.split("/", 1) + else: + model_repo = None + model_name = model_id_or_name + + logging.info(f"[DEBUG] Updating model by name: model_name={model_name}, model_repo={model_repo}, tenant_id={tenant_id}, update_data={update_data}") + update_model_record_by_model_name(model_name, update_data, user_id, tenant_id, model_repo) + + logging.info("[DEBUG] Batch update models successfully") except Exception as e: logging.error(f"Failed to batch update models: {str(e)}") raise Exception(f"Failed to batch update models: {str(e)}") diff --git a/backend/services/model_provider_service.py b/backend/services/model_provider_service.py index dbff17082..6fc729a39 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -100,11 +100,13 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a # Build the canonical representation using the existing Pydantic schema for # consistency of validation and default handling. # For embedding/multi_embedding models, max_tokens will be set via connectivity check later, - # so use 0 as placeholder if not provided + # so use 0 as placeholder if not provided. + # Set default timeout_seconds to 120 for LLM models (embedding models don't need it). model_type = model["model_type"] is_embedding_type = model_type in ["embedding", "multi_embedding"] max_tokens_value = model.get( "max_tokens", 0) if not is_embedding_type else 0 + timeout_seconds_value = 120 if not is_embedding_type else None model_obj = ModelRequest( model_factory=provider, @@ -115,7 +117,8 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a display_name=model_display_name, expected_chunk_size=expected_chunk_size, maximum_chunk_size=maximum_chunk_size, - chunk_batch=chunk_batch + chunk_batch=chunk_batch, + timeout_seconds=timeout_seconds_value ) model_dict = model_obj.model_dump() diff --git a/backend/utils/llm_utils.py b/backend/utils/llm_utils.py index fb2e06fdb..a5e90c727 100644 --- a/backend/utils/llm_utils.py +++ b/backend/utils/llm_utils.py @@ -73,6 +73,8 @@ def call_llm_for_system_prompt( set_monitoring_operation("system_prompt_generation", display_name=display_name or None) + timeout_seconds = llm_model_config.get("timeout_seconds") if llm_model_config else None + llm = OpenAIModel( model_id=get_model_name_from_config(llm_model_config) if llm_model_config else "", api_base=llm_model_config.get("base_url", "") if llm_model_config else "", @@ -82,6 +84,7 @@ def call_llm_for_system_prompt( model_factory=llm_model_config.get("model_factory") if llm_model_config else None, ssl_verify=llm_model_config.get("ssl_verify", True) if llm_model_config else True, display_name=display_name or None, + timeout_seconds=timeout_seconds, ) messages = [ {"role": MESSAGE_ROLE["SYSTEM"], "content": system_prompt}, diff --git a/docker/sql/v2.0.5_0507_add_timeout_seconds_to_model_record_t.sql b/docker/sql/v2.0.5_0507_add_timeout_seconds_to_model_record_t.sql new file mode 100644 index 000000000..6c0ef24db --- /dev/null +++ b/docker/sql/v2.0.5_0507_add_timeout_seconds_to_model_record_t.sql @@ -0,0 +1,10 @@ +-- Migration: Add timeout_seconds column to model_record_t table +-- Date: 2026-05-07 +-- Description: Add timeout_seconds field to control request timeout per model + +-- Add timeout_seconds column to model_record_t table +ALTER TABLE nexent.model_record_t +ADD COLUMN IF NOT EXISTS timeout_seconds INTEGER DEFAULT 120; + +-- Add comment to the column +COMMENT ON COLUMN nexent.model_record_t.timeout_seconds IS 'Request timeout in seconds for this model. Default is 120 seconds.'; diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 11391c133..eee1ab277 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -50,6 +50,7 @@ const DEFAULT_FORM_STATE = { url: "", apiKey: "", maxTokens: "4096", + timeoutSeconds: "120", isMultimodal: false, isBatchImport: false, provider: "modelengine", @@ -252,6 +253,7 @@ export const ModelAddDialog = ({ const [selectedModelForSettings, setSelectedModelForSettings] = useState(null); const [modelMaxTokens, setModelMaxTokens] = useState("4096"); + const [modelTimeoutSeconds, setModelTimeoutSeconds] = useState("120"); // Use the silicon model list hook const siliconHook = useSiliconModelList({ @@ -639,23 +641,49 @@ export const ModelAddDialog = ({ const handleSettingsClick = (model: any) => { setSelectedModelForSettings(model); setModelMaxTokens(model.max_tokens?.toString() || "4096"); + setModelTimeoutSeconds(model.timeout_seconds?.toString() || "120"); setSettingsModalVisible(true); }; // Handle settings save - const handleSettingsSave = () => { - if (selectedModelForSettings) { - // Update the model in the list with new max_tokens + const handleSettingsSave = async () => { + if (!selectedModelForSettings) return; + + try { + // Use model_name as the identifier (API returns model_name field, id is combined format) + const modelName = selectedModelForSettings.model_name || selectedModelForSettings.id; + + // Call API to update model settings + await modelService.updateBatchModel( + [ + { + model_id: modelName, + apiKey: selectedModelForSettings.api_key || "", + maxTokens: parseInt(modelMaxTokens) || 4096, + timeoutSeconds: parseInt(modelTimeoutSeconds) || 120, + }, + ], + selectedModelForSettings.model_factory + ); + + // Update the model in the list with new max_tokens and timeout_seconds setModelList((prev) => prev.map((model) => model.id === selectedModelForSettings.id - ? { ...model, max_tokens: parseInt(modelMaxTokens) || 4096 } + ? { + ...model, + max_tokens: parseInt(modelMaxTokens) || 4096, + timeout_seconds: parseInt(modelTimeoutSeconds) || 120, + } : model ) ); + } catch (error) { + console.error("Failed to update model settings:", error); + } finally { + setSettingsModalVisible(false); + setSelectedModelForSettings(null); } - setSettingsModalVisible(false); - setSelectedModelForSettings(null); }; // Handle adding a model @@ -698,6 +726,7 @@ export const ModelAddDialog = ({ apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, maxTokens: maxTokensValue, displayName: form.displayName || form.name, +<<<<<<< HEAD }; // Add STT specific fields @@ -717,6 +746,21 @@ export const ModelAddDialog = ({ } await modelService.createManageTenantModel(modelParams); +======= + expectedChunkSize: isEmbeddingModel + ? form.chunkSizeRange[0] + : undefined, + maximumChunkSize: isEmbeddingModel + ? form.chunkSizeRange[1] + : undefined, + chunkingBatchSize: isEmbeddingModel + ? parseInt(form.chunkingBatchSize) || 10 + : undefined, + timeoutSeconds: !isEmbeddingModel && !isRerankModel + ? parseInt(form.timeoutSeconds) || 120 + : undefined, + }); +>>>>>>> a64daaea1 (Feat: support user to configurate model timeout) } else { const modelParams: any = { name: form.name, @@ -725,6 +769,7 @@ export const ModelAddDialog = ({ apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, maxTokens: maxTokensValue, displayName: form.displayName || form.name, +<<<<<<< HEAD }; // Add STT specific fields @@ -744,6 +789,23 @@ export const ModelAddDialog = ({ } await modelService.addCustomModel(modelParams); +======= + // Send chunk size range for embedding models + ...(isEmbeddingModel + ? { + expectedChunkSize: form.chunkSizeRange[0], + maximumChunkSize: form.chunkSizeRange[1], + chunkingBatchSize: parseInt(form.chunkingBatchSize) || 10, + } + : {}), + // Send timeout for non-embedding models + ...(!isEmbeddingModel && !isRerankModel + ? { + timeoutSeconds: parseInt(form.timeoutSeconds) || 120, + } + : {}), + }); +>>>>>>> a64daaea1 (Feat: support user to configurate model timeout) } // Create the model configuration object @@ -1190,6 +1252,26 @@ export const ModelAddDialog = ({ )} + {/* Timeout Seconds */} + {!isEmbeddingModel && !isRerankModel && !form.isBatchImport && ( +
+ + handleFormChange("timeoutSeconds", e.target.value)} + /> +
+ )} + {/* Connectivity verification area */} {!form.isBatchImport && (
@@ -1713,6 +1795,18 @@ export const ModelAddDialog = ({ placeholder={t("model.dialog.placeholder.maxTokens")} />
+
+ + setModelTimeoutSeconds(e.target.value)} + placeholder="120" + /> +
diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx index ad3cf0391..f58ca242e 100644 --- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx @@ -57,6 +57,7 @@ export const ModelDeleteDialog = ({ const [selectedModelForSettings, setSelectedModelForSettings] = useState(null); const [modelMaxTokens, setModelMaxTokens] = useState("4096"); + const [modelTimeoutSeconds, setModelTimeoutSeconds] = useState("120"); const [providerModelSearchTerm, setProviderModelSearchTerm] = useState(""); // Embedding model chunk config modal state @@ -589,9 +590,11 @@ export const ModelDeleteDialog = ({ const handleProviderConfigSave = async ({ apiKey, maxTokens, + timeoutSeconds, }: { apiKey: string; maxTokens: number; + timeoutSeconds?: number; }) => { setMaxTokens(maxTokens); if ( @@ -624,6 +627,7 @@ export const ModelDeleteDialog = ({ model_id: String(m.id), apiKey: apiKey || m.apiKey, maxTokens: maxTokens || m.maxTokens, + ...(timeoutSeconds !== undefined ? { timeoutSeconds } : {}), })); await modelService.updateBatchModel( @@ -653,23 +657,52 @@ export const ModelDeleteDialog = ({ const handleSettingsClick = (model: any) => { setSelectedModelForSettings(model); setModelMaxTokens(model.max_tokens?.toString() || "4096"); + setModelTimeoutSeconds(model.timeout_seconds?.toString() || "120"); setSettingsModalVisible(true); }; // Handle settings save - const handleSettingsSave = () => { - if (selectedModelForSettings) { - // Update the model in the list with new max_tokens + const handleSettingsSave = async () => { + if (!selectedModelForSettings) return; + + try { + // Use model_name as the identifier (API returns model_name field, id is combined format) + const modelName = selectedModelForSettings.model_name || selectedModelForSettings.id; + + // Call API to update model settings + await modelService.updateBatchModel( + [ + { + model_id: modelName, + apiKey: selectedModelForSettings.api_key || "", + maxTokens: parseInt(modelMaxTokens) || 4096, + timeoutSeconds: parseInt(modelTimeoutSeconds) || 120, + }, + ], + selectedModelForSettings.model_factory + ); + + // Update the model in the list with new max_tokens and timeout_seconds setProviderModels((prev) => prev.map((model) => model.id === selectedModelForSettings.id - ? { ...model, max_tokens: parseInt(modelMaxTokens) || 4096 } + ? { + ...model, + max_tokens: parseInt(modelMaxTokens) || 4096, + timeout_seconds: parseInt(modelTimeoutSeconds) || 120, + } : model ) ); + + message.success(t("model.message.updateSuccess") || "Update successful"); + } catch (error) { + console.error("Failed to update model settings:", error); + message.error(t("model.message.updateFailed") || "Failed to update settings"); + } finally { + setSettingsModalVisible(false); + setSelectedModelForSettings(null); } - setSettingsModalVisible(false); - setSelectedModelForSettings(null); }; // Handle embedding model click to open config modal @@ -1542,6 +1575,18 @@ export const ModelDeleteDialog = ({ placeholder={t("model.dialog.placeholder.maxTokens")} /> +
+ + setModelTimeoutSeconds(e.target.value)} + placeholder="120" + /> +
diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index 3114c5535..a784258df 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -39,6 +39,7 @@ export const ModelEditDialog = ({ url: "", apiKey: "", maxTokens: "4096", + timeoutSeconds: "120", vectorDimension: "1024", chunkSizeRange: [ DEFAULT_EXPECTED_CHUNK_SIZE, @@ -65,6 +66,7 @@ export const ModelEditDialog = ({ url: model.apiUrl || "", apiKey: model.apiKey || "", maxTokens: model.maxTokens?.toString() || "4096", + timeoutSeconds: model.timeoutSeconds?.toString() || "120", vectorDimension: model.maxTokens?.toString() || "1024", chunkSizeRange: [ model.expectedChunkSize || DEFAULT_EXPECTED_CHUNK_SIZE, @@ -78,7 +80,7 @@ export const ModelEditDialog = ({ const handleFormChange = (field: string, value: string) => { setForm((prev) => ({ ...prev, [field]: value })); // If the key configuration item changes, clear the verification status - if (["url", "apiKey", "maxTokens", "vectorDimension"].includes(field)) { + if (["url", "apiKey", "maxTokens", "timeoutSeconds", "vectorDimension"].includes(field)) { setConnectivityStatus({ status: null, message: "" }); } }; @@ -176,6 +178,7 @@ export const ModelEditDialog = ({ expectedChunkSize: isEmbeddingModel ? form.chunkSizeRange[0] : undefined, maximumChunkSize: isEmbeddingModel ? form.chunkSizeRange[1] : undefined, chunkingBatchSize: isEmbeddingModel ? parseInt(form.chunkingBatchSize) || 10 : undefined, + timeoutSeconds: !isEmbeddingModel && !isRerankModel ? parseInt(form.timeoutSeconds) || 120 : undefined, }); } else { await modelService.updateSingleModel({ @@ -196,6 +199,12 @@ export const ModelEditDialog = ({ chunkingBatchSize: parseInt(form.chunkingBatchSize) || 10, } : {}), + // Send timeout for non-embedding models + ...(!isEmbeddingModel && !isRerankModel + ? { + timeoutSeconds: parseInt(form.timeoutSeconds) || 120, + } + : {}), }); } @@ -303,6 +312,12 @@ export const ModelEditDialog = ({ value={form.maxTokens} onChange={(e) => handleFormChange("maxTokens", e.target.value)} /> + handleFormChange("timeoutSeconds", e.target.value)} + /> )} @@ -408,15 +423,17 @@ interface ProviderConfigEditDialogProps { isOpen: boolean initialApiKey?: string initialMaxTokens?: string + initialTimeoutSeconds?: string modelType?: ModelType onClose: () => void - onSave: (config: { apiKey: string; maxTokens: number }) => Promise | void + onSave: (config: { apiKey: string; maxTokens: number; timeoutSeconds?: number }) => Promise | void } export const ProviderConfigEditDialog = ({ isOpen, initialApiKey = '', initialMaxTokens = '4096', + initialTimeoutSeconds = '120', modelType, onClose, onSave, @@ -424,12 +441,14 @@ export const ProviderConfigEditDialog = ({ const { t } = useTranslation() const [apiKey, setApiKey] = useState(initialApiKey) const [maxTokens, setMaxTokens] = useState(initialMaxTokens) + const [timeoutSeconds, setTimeoutSeconds] = useState(initialTimeoutSeconds) const [saving, setSaving] = useState(false) useEffect(() => { setApiKey(initialApiKey) setMaxTokens(initialMaxTokens) - }, [initialApiKey, initialMaxTokens]) + setTimeoutSeconds(initialTimeoutSeconds) + }, [initialApiKey, initialMaxTokens, initialTimeoutSeconds]) const valid = () => { const parsed = parseInt(maxTokens) @@ -440,7 +459,13 @@ export const ProviderConfigEditDialog = ({ if (!valid()) return try { setSaving(true) - await onSave({ apiKey: apiKey.trim() === '' ? 'sk-no-api-key' : apiKey, maxTokens: parseInt(maxTokens) }) + const isEmbeddingModel = modelType === MODEL_TYPES.EMBEDDING || modelType === MODEL_TYPES.MULTI_EMBEDDING + const isRerankModel = modelType === MODEL_TYPES.RERANK + await onSave({ + apiKey: apiKey.trim() === '' ? 'sk-no-api-key' : apiKey, + maxTokens: parseInt(maxTokens), + ...(!isEmbeddingModel && !isRerankModel ? { timeoutSeconds: parseInt(timeoutSeconds) || 120 } : {}), + }) onClose() } finally { setSaving(false) @@ -448,6 +473,7 @@ export const ProviderConfigEditDialog = ({ } const isEmbeddingModel = modelType === MODEL_TYPES.EMBEDDING || modelType === MODEL_TYPES.MULTI_EMBEDDING + const isRerankModel = modelType === MODEL_TYPES.RERANK return ( setMaxTokens(e.target.value)} /> )} + {!isEmbeddingModel && !isRerankModel && ( +
+ + setTimeoutSeconds(e.target.value)} + /> +
+ )}
)} + {/* Concurrency Limit */} + {!isEmbeddingModel && !isRerankModel && !form.isBatchImport && ( +
+ + handleFormChange("concurrencyLimit", e.target.value)} + /> +
+ {t("model.dialog.hint.concurrencyLimit")} +
+
+ )} + {/* Connectivity verification area */} {!form.isBatchImport && (
@@ -1428,7 +1411,7 @@ export const ModelAddDialog = ({ size="small" onClick={(e) => { e.stopPropagation(); // Prevent switch toggle - handleSettingsClick(model); + handleSingleModelSettingsClick(model); }} /> @@ -1773,42 +1756,52 @@ export const ModelAddDialog = ({
- {/* Settings Modal */} - setSettingsModalVisible(false)} - onOk={handleSettingsSave} - cancelText={t("common.cancel")} - okText={t("common.confirm")} - destroyOnHidden - > -
-
- - setModelMaxTokens(e.target.value)} - placeholder={t("model.dialog.placeholder.maxTokens")} - /> -
-
- - setModelTimeoutSeconds(e.target.value)} - placeholder="120" - /> -
-
-
+ {/* Single Model Settings Modal */} + { + setIsSingleModelSettingsOpen(false); + setSelectedSingleModel(null); + }} + initialMaxTokens={selectedSingleModel?.max_tokens?.toString() || "4096"} + initialTimeoutSeconds={selectedSingleModel?.timeout_seconds?.toString() || "120"} + modelType={form.type} + showApiKeyField={false} + onSave={async (config) => { + if (!selectedSingleModel) return; + try { + const modelName = selectedSingleModel.model_name || selectedSingleModel.id; + await modelService.updateBatchModel( + [ + { + model_id: modelName, + apiKey: config.apiKey, + maxTokens: config.maxTokens, + timeoutSeconds: config.timeoutSeconds, + concurrencyLimit: config.concurrencyLimit, + }, + ], + selectedSingleModel.model_factory + ); + + // Update the model in the list + setModelList((prev) => + prev.map((model) => + model.id === selectedSingleModel.id + ? { + ...model, + api_key: config.apiKey, + max_tokens: config.maxTokens, + timeout_seconds: config.timeoutSeconds, + } + : model + ) + ); + } catch (error) { + console.error("Failed to update model settings:", error); + } + }} + />
); }; diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx index f58ca242e..0074a9bb5 100644 --- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx @@ -52,12 +52,9 @@ export const ModelDeleteDialog = ({ const [isConfirmLoading, setIsConfirmLoading] = useState(false); const [maxTokens, setMaxTokens] = useState(0); - // Settings modal state - const [settingsModalVisible, setSettingsModalVisible] = useState(false); - const [selectedModelForSettings, setSelectedModelForSettings] = - useState(null); - const [modelMaxTokens, setModelMaxTokens] = useState("4096"); - const [modelTimeoutSeconds, setModelTimeoutSeconds] = useState("120"); + // Single model settings modal state + const [isSingleModelSettingsOpen, setIsSingleModelSettingsOpen] = useState(false); + const [selectedSingleModel, setSelectedSingleModel] = useState(null); const [providerModelSearchTerm, setProviderModelSearchTerm] = useState(""); // Embedding model chunk config modal state @@ -591,10 +588,12 @@ export const ModelDeleteDialog = ({ apiKey, maxTokens, timeoutSeconds, + concurrencyLimit, }: { apiKey: string; maxTokens: number; timeoutSeconds?: number; + concurrencyLimit?: number; }) => { setMaxTokens(maxTokens); if ( @@ -628,6 +627,7 @@ export const ModelDeleteDialog = ({ apiKey: apiKey || m.apiKey, maxTokens: maxTokens || m.maxTokens, ...(timeoutSeconds !== undefined ? { timeoutSeconds } : {}), + ...(concurrencyLimit !== undefined ? { concurrencyLimit } : {}), })); await modelService.updateBatchModel( @@ -643,6 +643,8 @@ export const ModelDeleteDialog = ({ prev.map((model) => ({ ...model, max_tokens: maxTokens || model.max_tokens || 4096, + timeout_seconds: timeoutSeconds || model.timeout_seconds, + concurrency_limit: concurrencyLimit !== undefined ? concurrencyLimit : model.concurrency_limit, })) ); } catch (e) { @@ -653,58 +655,6 @@ export const ModelDeleteDialog = ({ setIsProviderConfigOpen(false); }; - // Handle settings button click - const handleSettingsClick = (model: any) => { - setSelectedModelForSettings(model); - setModelMaxTokens(model.max_tokens?.toString() || "4096"); - setModelTimeoutSeconds(model.timeout_seconds?.toString() || "120"); - setSettingsModalVisible(true); - }; - - // Handle settings save - const handleSettingsSave = async () => { - if (!selectedModelForSettings) return; - - try { - // Use model_name as the identifier (API returns model_name field, id is combined format) - const modelName = selectedModelForSettings.model_name || selectedModelForSettings.id; - - // Call API to update model settings - await modelService.updateBatchModel( - [ - { - model_id: modelName, - apiKey: selectedModelForSettings.api_key || "", - maxTokens: parseInt(modelMaxTokens) || 4096, - timeoutSeconds: parseInt(modelTimeoutSeconds) || 120, - }, - ], - selectedModelForSettings.model_factory - ); - - // Update the model in the list with new max_tokens and timeout_seconds - setProviderModels((prev) => - prev.map((model) => - model.id === selectedModelForSettings.id - ? { - ...model, - max_tokens: parseInt(modelMaxTokens) || 4096, - timeout_seconds: parseInt(modelTimeoutSeconds) || 120, - } - : model - ) - ); - - message.success(t("model.message.updateSuccess") || "Update successful"); - } catch (error) { - console.error("Failed to update model settings:", error); - message.error(t("model.message.updateFailed") || "Failed to update settings"); - } finally { - setSettingsModalVisible(false); - setSelectedModelForSettings(null); - } - }; - // Handle embedding model click to open config modal const handleEmbeddingModelClick = (model: ModelOption | any) => { const isEmbeddingModel = @@ -762,6 +712,12 @@ export const ModelDeleteDialog = ({ } }; + // Handle single model settings button click + const handleSingleModelSettingsClick = (model: any) => { + setSelectedSingleModel(model); + setIsSingleModelSettingsOpen(true); + }; + // Handle embedding config save const handleEmbeddingConfigSave = async () => { if (!selectedEmbeddingModel) return; @@ -1363,7 +1319,7 @@ export const ModelDeleteDialog = ({ size="small" onClick={(e) => { e.stopPropagation(); // Prevent switch toggle - handleSettingsClick(providerModel); + handleSingleModelSettingsClick(providerModel); }} /> @@ -1549,46 +1505,75 @@ export const ModelDeleteDialog = ({ m.source === (selectedSource || MODEL_SOURCES.SILICON) )?.maxTokens || 4096 ).toString()} + initialTimeoutSeconds={( + models.find( + (m) => + m.type === deletingModelType && + m.source === (selectedSource || MODEL_SOURCES.SILICON) + )?.timeoutSeconds?.toString() || "120" + )} + initialConcurrencyLimit={( + models.find( + (m) => + m.type === deletingModelType && + m.source === (selectedSource || MODEL_SOURCES.SILICON) + )?.concurrencyLimit?.toString() || "" + )} modelType={deletingModelType || undefined} onSave={handleProviderConfigSave} /> - {/* Settings Modal */} - setSettingsModalVisible(false)} - onOk={handleSettingsSave} - cancelText={t("common.button.cancel")} - okText={t("common.button.save")} - destroyOnHidden - > -
-
- - setModelMaxTokens(e.target.value)} - placeholder={t("model.dialog.placeholder.maxTokens")} - /> -
-
- - setModelTimeoutSeconds(e.target.value)} - placeholder="120" - /> -
-
-
+ {/* Single Model Settings Modal */} + { + setIsSingleModelSettingsOpen(false); + setSelectedSingleModel(null); + }} + initialMaxTokens={selectedSingleModel?.max_tokens?.toString() || "4096"} + initialTimeoutSeconds={selectedSingleModel?.timeout_seconds?.toString() || "120"} + initialConcurrencyLimit={selectedSingleModel?.concurrency_limit?.toString() || ""} + modelType={deletingModelType || undefined} + showApiKeyField={false} + onSave={async (config) => { + if (!selectedSingleModel) return; + try { + const modelName = selectedSingleModel.model_name || selectedSingleModel.id; + await modelService.updateBatchModel( + [ + { + model_id: modelName, + apiKey: config.apiKey, + maxTokens: config.maxTokens, + timeoutSeconds: config.timeoutSeconds, + concurrencyLimit: config.concurrencyLimit, + }, + ], + selectedSingleModel.model_factory + ); + + // Update the model in the list + setProviderModels((prev) => + prev.map((model) => + model.id === selectedSingleModel.id + ? { + ...model, + api_key: config.apiKey, + max_tokens: config.maxTokens, + timeout_seconds: config.timeoutSeconds, + concurrency_limit: config.concurrencyLimit, + } + : model + ) + ); + + message.success(t("model.message.updateSuccess") || "Update successful"); + } catch (error) { + console.error("Failed to update model settings:", error); + message.error(t("model.message.updateFailed") || "Failed to update settings"); + } + }} + /> {/* Embedding Model Config Modal */} handleFormChange("maxTokens", e.target.value)} /> + + )} + + {/* Timeout Seconds */} + {!isEmbeddingModel && !isRerankModel && ( +
+ )} + {/* Concurrency Limit */} + {!isEmbeddingModel && !isRerankModel && ( +
+ + handleFormChange("concurrencyLimit", e.target.value)} + placeholder={t("model.dialog.placeholder.concurrencyLimit")} + /> +
+ {t("model.dialog.hint.concurrencyLimit")} +
+
+ )} + {/* Chunk Size Range for embedding models */} {isEmbeddingModel && (
@@ -424,9 +456,11 @@ interface ProviderConfigEditDialogProps { initialApiKey?: string initialMaxTokens?: string initialTimeoutSeconds?: string + initialConcurrencyLimit?: string modelType?: ModelType + showApiKeyField?: boolean // Whether to show API Key field (default: true) onClose: () => void - onSave: (config: { apiKey: string; maxTokens: number; timeoutSeconds?: number }) => Promise | void + onSave: (config: { apiKey: string; maxTokens: number; timeoutSeconds?: number; concurrencyLimit?: number }) => Promise | void } export const ProviderConfigEditDialog = ({ @@ -434,7 +468,9 @@ export const ProviderConfigEditDialog = ({ initialApiKey = '', initialMaxTokens = '4096', initialTimeoutSeconds = '120', + initialConcurrencyLimit = '', modelType, + showApiKeyField = true, onClose, onSave, }: ProviderConfigEditDialogProps) => { @@ -442,13 +478,15 @@ export const ProviderConfigEditDialog = ({ const [apiKey, setApiKey] = useState(initialApiKey) const [maxTokens, setMaxTokens] = useState(initialMaxTokens) const [timeoutSeconds, setTimeoutSeconds] = useState(initialTimeoutSeconds) + const [concurrencyLimit, setConcurrencyLimit] = useState(initialConcurrencyLimit) const [saving, setSaving] = useState(false) useEffect(() => { setApiKey(initialApiKey) setMaxTokens(initialMaxTokens) setTimeoutSeconds(initialTimeoutSeconds) - }, [initialApiKey, initialMaxTokens, initialTimeoutSeconds]) + setConcurrencyLimit(initialConcurrencyLimit) + }, [initialApiKey, initialMaxTokens, initialTimeoutSeconds, initialConcurrencyLimit]) const valid = () => { const parsed = parseInt(maxTokens) @@ -462,9 +500,10 @@ export const ProviderConfigEditDialog = ({ const isEmbeddingModel = modelType === MODEL_TYPES.EMBEDDING || modelType === MODEL_TYPES.MULTI_EMBEDDING const isRerankModel = modelType === MODEL_TYPES.RERANK await onSave({ - apiKey: apiKey.trim() === '' ? 'sk-no-api-key' : apiKey, + apiKey: showApiKeyField ? (apiKey.trim() === '' ? 'sk-no-api-key' : apiKey) : '', maxTokens: parseInt(maxTokens), ...(!isEmbeddingModel && !isRerankModel ? { timeoutSeconds: parseInt(timeoutSeconds) || 120 } : {}), + ...(!isEmbeddingModel && !isRerankModel ? { concurrencyLimit: concurrencyLimit ? parseInt(concurrencyLimit) : undefined } : {}), }) onClose() } finally { @@ -484,12 +523,14 @@ export const ProviderConfigEditDialog = ({ destroyOnHidden >
-
- - setApiKey(e.target.value)} visibilityToggle={false} /> -
+ {showApiKeyField && ( +
+ + setApiKey(e.target.value)} visibilityToggle={false} /> +
+ )} {!isEmbeddingModel && (
)} + {!isEmbeddingModel && !isRerankModel && ( +
+ + setConcurrencyLimit(e.target.value)} + placeholder={t("model.dialog.placeholder.concurrencyLimit")} + /> +
+ {t("model.dialog.hint.concurrencyLimit")} +
+
+ )}
- {/* Single Model Settings Modal */} - { - setIsSingleModelSettingsOpen(false); - setSelectedSingleModel(null); - }} - initialMaxTokens={selectedSingleModel?.max_tokens?.toString() || "4096"} - initialTimeoutSeconds={selectedSingleModel?.timeout_seconds?.toString() || "120"} - modelType={form.type} - showApiKeyField={false} - onSave={async (config) => { - if (!selectedSingleModel) return; - try { - const modelName = selectedSingleModel.model_name || selectedSingleModel.id; - await modelService.updateBatchModel( - [ - { - model_id: modelName, - apiKey: config.apiKey, - maxTokens: config.maxTokens, - timeoutSeconds: config.timeoutSeconds, - concurrencyLimit: config.concurrencyLimit, - }, - ], - selectedSingleModel.model_factory - ); - - // Update the model in the list - setModelList((prev) => - prev.map((model) => - model.id === selectedSingleModel.id - ? { - ...model, - api_key: config.apiKey, - max_tokens: config.maxTokens, - timeout_seconds: config.timeoutSeconds, - } - : model - ) - ); - } catch (error) { - console.error("Failed to update model settings:", error); - } - }} - /> + {/* Settings Modal */} + setSettingsModalVisible(false)} + onOk={handleSettingsSave} + cancelText={t("common.cancel")} + okText={t("common.confirm")} + destroyOnHidden + > +
+
+ + setModelMaxTokens(e.target.value)} + placeholder={t("model.dialog.placeholder.maxTokens")} + /> +
+
+
); -}; +}; \ No newline at end of file From 65a7c51a1559c6bfe62443ad529b863299814181 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Sat, 9 May 2026 17:18:23 +0800 Subject: [PATCH 12/48] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=89=8D=E7=AB=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../app/[locale]/models/components/model/ModelAddDialog.tsx | 1 + frontend/services/modelService.ts | 2 ++ 2 files changed, 3 insertions(+) diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 471963439..94a869301 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -829,6 +829,7 @@ export const ModelAddDialog = ({ }; const isEmbeddingModel = form.type === MODEL_TYPES.EMBEDDING; + const isRerankModel = form.type === MODEL_TYPES.RERANK; const isSTTModel = form.type === MODEL_TYPES.STT; return ( diff --git a/frontend/services/modelService.ts b/frontend/services/modelService.ts index 3538b34f4..07796d2c4 100644 --- a/frontend/services/modelService.ts +++ b/frontend/services/modelService.ts @@ -124,6 +124,7 @@ export const modelService = { maximum_chunk_size: model.maximumChunkSize, chunk_batch: model.chunkingBatchSize, timeout_seconds: model.timeoutSeconds, + concurrency_limit: model.concurrencyLimit, }; // Add STT specific fields @@ -721,6 +722,7 @@ export const modelService = { api_key: params.apiKey, max_tokens: params.maxTokens || 4096, display_name: params.displayName || params.name, + model_factory: params.modelFactory || "OpenAI-API-Compatible", expected_chunk_size: params.expectedChunkSize, maximum_chunk_size: params.maximumChunkSize, chunk_batch: params.chunkingBatchSize, From 764df4d4acd4128b473105a92abb36654c8a37d5 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 11 May 2026 14:25:51 +0800 Subject: [PATCH 13/48] =?UTF-8?q?=E4=B8=BAnexent-config=E6=8C=82=E8=BD=BD?= =?UTF-8?q?=E8=AF=81=E4=B9=A6=EF=BC=8C=E4=BB=A4=E5=AE=B9=E5=99=A8=E5=86=85?= =?UTF-8?q?=E7=9A=84=20Python=20=E5=BA=94=E7=94=A8=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E5=AE=BF=E4=B8=BB=E6=9C=BA=E7=9A=84=20CA=20=E8=AF=81=E4=B9=A6?= =?UTF-8?q?=E6=9D=A5=E9=AA=8C=E8=AF=81=E5=A4=96=E9=83=A8=20SMTP=20?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E5=99=A8=E7=9A=84=20SSL=20=E8=AF=81=E4=B9=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docker/docker-compose.prod.yml | 2 ++ docker/docker-compose.yml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/docker/docker-compose.prod.yml b/docker/docker-compose.prod.yml index 934fe8b2f..3cc7ac59a 100644 --- a/docker/docker-compose.prod.yml +++ b/docker/docker-compose.prod.yml @@ -78,6 +78,8 @@ services: - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro - ${ROOT_DIR}/scripts/sync_user_supabase2pg.py:/opt/sync_user_supabase2pg.py:ro - /var/run/docker.sock:/var/run/docker.sock:ro # Docker socket for MCP container management + # CA certificates for external service SSL verification (e.g., SMTP) + - /etc/ssl/certs:/etc/ssl/certs:ro environment: <<: [*minio-vars, *es-vars] skip_proxy: "true" diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 89088f2c3..4056683dc 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -89,6 +89,8 @@ services: - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro - ${ROOT_DIR}/scripts/sync_user_supabase2pg.py:/opt/sync_user_supabase2pg.py:ro - /var/run/docker.sock:/var/run/docker.sock:ro # Docker socket for MCP container management + # CA certificates for external service SSL verification (e.g., SMTP) + - /etc/ssl/certs:/etc/ssl/certs:ro environment: <<: [*minio-vars, *es-vars] skip_proxy: "true" From 3682d6bef8afafc18647bf8d4c9dc3e1f0c4268f Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 11 May 2026 14:36:38 +0800 Subject: [PATCH 14/48] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=81=A5=E5=BA=B7=E6=A3=80=E6=9F=A5=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/model_health_service.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index e22f6c642..73adacc00 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -97,23 +97,23 @@ async def _perform_connectivity_check( # Test connectivity based on different model types if model_type == "embedding": - connectivity = len(await OpenAICompatibleEmbedding( + embedding = OpenAICompatibleEmbedding( model_name=model_name, base_url=model_base_url, api_key=model_api_key, embedding_dim=0, ssl_verify=ssl_verify, - timeout_seconds=timeout_seconds, - ).dimension_check()) > 0 + ) + connectivity = len(await embedding.dimension_check(timeout=timeout_seconds if timeout_seconds else 5.0)) > 0 elif model_type == "multi_embedding": - connectivity = len(await JinaEmbedding( + embedding = JinaEmbedding( model_name=model_name, base_url=model_base_url, api_key=model_api_key, embedding_dim=0, ssl_verify=ssl_verify, - timeout_seconds=timeout_seconds, - ).dimension_check()) > 0 + ) + connectivity = len(await embedding.dimension_check(timeout=timeout_seconds if timeout_seconds else 5.0)) > 0 elif model_type == "llm": observer = MessageObserver() set_monitoring_operation("connectivity_check", From e38dce10bac5903dc9c646146812434f7d56b4c0 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 11 May 2026 15:13:59 +0800 Subject: [PATCH 15/48] =?UTF-8?q?=E5=8C=BA=E5=88=86send=20email=E9=92=88?= =?UTF-8?q?=E5=AF=B9=E6=98=AF=E5=90=A6=E8=B7=B3=E8=BF=87=E8=AF=81=E4=B9=A6?= =?UTF-8?q?=E6=A0=A1=E9=AA=8C=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sdk/nexent/core/tools/send_email_tool.py | 49 ++++++-- test/sdk/core/tools/test_send_email_tool.py | 122 +++++++++++++------- 2 files changed, 120 insertions(+), 51 deletions(-) diff --git a/sdk/nexent/core/tools/send_email_tool.py b/sdk/nexent/core/tools/send_email_tool.py index 2451020ea..097ad838c 100644 --- a/sdk/nexent/core/tools/send_email_tool.py +++ b/sdk/nexent/core/tools/send_email_tool.py @@ -65,8 +65,8 @@ class SendEmailTool(Tool): "description_zh": "SMTP 服务器密码" }, "use_ssl": { - "description": "Use SSL", - "description_zh": "使用 SSL" + "description": "Use SSL/TLS encryption (set to False for plain text)", + "description_zh": "使用 SSL/TLS 加密(设为 False 使用明文)" }, "sender_name": { "description": "Sender name", @@ -80,13 +80,13 @@ class SendEmailTool(Tool): output_type = "string" category = ToolCategory.EMAIL.value - def __init__(self, smtp_server: str=Field(description="SMTP Server Address"), - smtp_port: int=Field(description="SMTP server port"), - username: str=Field(description="SMTP server username"), - password: str=Field(description="SMTP server password"), - use_ssl: bool=Field(description="Use SSL", default=True), - sender_name: Optional[str] = Field(description="Sender name", default=None), - timeout: int = Field(description="Timeout", default=30)): + def __init__(self, smtp_server: str = "", + smtp_port: int = 587, + username: str = "", + password: str = "", + use_ssl: bool = True, + sender_name: Optional[str] = None, + timeout: int = 30): super().__init__() self.smtp_server = smtp_server self.smtp_port = smtp_port @@ -96,6 +96,18 @@ def __init__(self, smtp_server: str=Field(description="SMTP Server Address"), self.sender_name = sender_name self.timeout = timeout + def _create_ssl_context(self, skip_verify: bool = False) -> ssl.SSLContext: + """Create SSL context with optional verification disabled for self-signed certs.""" + context = ssl.create_default_context() + if skip_verify: + logger.warning("SSL verification disabled - use only for internal/local SMTP servers") + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + else: + context.check_hostname = True + context.verify_mode = ssl.CERT_REQUIRED + return context + def forward(self, to: str, subject: str, content: str, cc: str = "", bcc: str = "") -> str: try: logger.info("Creating email message...") @@ -119,13 +131,26 @@ def forward(self, to: str, subject: str, content: str, cc: str = "", bcc: str = if self.smtp_port == 465: # Port 465 uses implicit SSL logger.info("Using implicit SSL connection (port 465)...") - context = ssl.create_default_context() + context = self._create_ssl_context(skip_verify=False) server = smtplib.SMTP_SSL(self.smtp_server, self.smtp_port, context=context, timeout=self.timeout) - else: + elif self.use_ssl: # Port 587 (and others) use STARTTLS logger.info("Using STARTTLS connection...") server = smtplib.SMTP(self.smtp_server, self.smtp_port, timeout=self.timeout) - server.starttls(context=ssl.create_default_context()) + server.starttls(context=self._create_ssl_context(skip_verify=False)) + else: + # Port 25 - plain connection (may have self-signed certs) + logger.info("Using plain text connection (port 25)...") + server = smtplib.SMTP(self.smtp_server, self.smtp_port, timeout=self.timeout) + # Some servers force TLS handshake even on plain connections + # Skip cert verification for port 25 to handle self-signed certs + try: + server.starttls(context=self._create_ssl_context(skip_verify=True)) + logger.info("Server upgraded to TLS connection") + except smtplib.SMTPNotSupportedError: + logger.info("Server does not support STARTTLS, using plain connection") + except Exception as tls_err: + logger.warning(f"TLS upgrade failed: {tls_err}, continuing with plain connection") logger.info("Logging in...") # Login diff --git a/test/sdk/core/tools/test_send_email_tool.py b/test/sdk/core/tools/test_send_email_tool.py index 1287a4f53..88b279eb2 100644 --- a/test/sdk/core/tools/test_send_email_tool.py +++ b/test/sdk/core/tools/test_send_email_tool.py @@ -60,6 +60,17 @@ def test_init_with_custom_values(self): assert tool.sender_name == "Custom Sender" assert tool.timeout == 60 + def test_init_use_ssl_default(self): + """Test that use_ssl defaults to True""" + tool = SendEmailTool( + smtp_server="smtp.example.com", + smtp_port=587, + username="user@example.com", + password="password123" + ) + assert tool.use_ssl is True + assert tool.timeout == 30 + def test_tool_attributes(self, send_email_tool): """Test tool class attributes""" assert send_email_tool.name == "send_email" @@ -91,9 +102,9 @@ def test_tool_inputs_schema(self, send_email_tool): assert inputs["bcc"]["type"] == "string" assert inputs["bcc"]["nullable"] is True - @patch('smtplib.SMTP_SSL') + @patch('smtplib.SMTP') @patch('ssl.create_default_context') - def test_forward_success_basic_email(self, mock_ssl_context, mock_smtp_ssl, send_email_tool): + def test_forward_success_basic_email(self, mock_ssl_context, mock_smtp, send_email_tool): """Test successful basic email sending""" # Mock SSL context mock_context = Mock() @@ -101,7 +112,7 @@ def test_forward_success_basic_email(self, mock_ssl_context, mock_smtp_ssl, send # Mock SMTP server mock_server = Mock() - mock_smtp_ssl.return_value = mock_server + mock_smtp.return_value = mock_server result = send_email_tool.forward( to="recipient@example.com", @@ -119,17 +130,16 @@ def test_forward_success_basic_email(self, mock_ssl_context, mock_smtp_ssl, send assert result_data["subject"] == "Test Subject" # Verify SMTP operations - mock_smtp_ssl.assert_called_once_with( - "smtp.test.com", 587, context=mock_context, timeout=30 - ) + mock_smtp.assert_called_once_with("smtp.test.com", 587, timeout=30) + mock_server.starttls.assert_called_once_with(context=mock_context) mock_server.login.assert_called_once_with( "test@test.com", "test_password") mock_server.send_message.assert_called_once() mock_server.quit.assert_called_once() - @patch('smtplib.SMTP_SSL') + @patch('smtplib.SMTP') @patch('ssl.create_default_context') - def test_forward_success_with_cc_and_bcc(self, mock_ssl_context, mock_smtp_ssl, send_email_tool): + def test_forward_success_with_cc_and_bcc(self, mock_ssl_context, mock_smtp, send_email_tool): """Test successful email sending with CC and BCC""" # Mock SSL context mock_context = Mock() @@ -137,7 +147,7 @@ def test_forward_success_with_cc_and_bcc(self, mock_ssl_context, mock_smtp_ssl, # Mock SMTP server mock_server = Mock() - mock_smtp_ssl.return_value = mock_server + mock_smtp.return_value = mock_server result = send_email_tool.forward( to="recipient@example.com", @@ -164,9 +174,9 @@ def test_forward_success_with_cc_and_bcc(self, mock_ssl_context, mock_smtp_ssl, assert call_args['Cc'] == "cc1@example.com,cc2@example.com" assert call_args['Bcc'] == "bcc@example.com" - @patch('smtplib.SMTP_SSL') + @patch('smtplib.SMTP') @patch('ssl.create_default_context') - def test_forward_success_multiple_recipients(self, mock_ssl_context, mock_smtp_ssl, send_email_tool): + def test_forward_success_multiple_recipients(self, mock_ssl_context, mock_smtp, send_email_tool): """Test successful email sending with multiple recipients""" # Mock SSL context mock_context = Mock() @@ -174,7 +184,7 @@ def test_forward_success_multiple_recipients(self, mock_ssl_context, mock_smtp_s # Mock SMTP server mock_server = Mock() - mock_smtp_ssl.return_value = mock_server + mock_smtp.return_value = mock_server result = send_email_tool.forward( to="recipient1@example.com,recipient2@example.com", @@ -191,9 +201,9 @@ def test_forward_success_multiple_recipients(self, mock_ssl_context, mock_smtp_s assert result_data["status"] == "success" assert result_data["to"] == "recipient1@example.com,recipient2@example.com" - @patch('smtplib.SMTP_SSL') + @patch('smtplib.SMTP') @patch('ssl.create_default_context') - def test_forward_smtp_send_error(self, mock_ssl_context, mock_smtp_ssl, send_email_tool): + def test_forward_smtp_send_error(self, mock_ssl_context, mock_smtp, send_email_tool): """Test email sending with SMTP send error""" # Mock SSL context mock_context = Mock() @@ -204,7 +214,7 @@ def test_forward_smtp_send_error(self, mock_ssl_context, mock_smtp_ssl, send_ema mock_server.send_message.side_effect = smtplib.SMTPRecipientsRefused( "Recipients refused" ) - mock_smtp_ssl.return_value = mock_server + mock_smtp.return_value = mock_server result = send_email_tool.forward( to="recipient@example.com", @@ -219,9 +229,9 @@ def test_forward_smtp_send_error(self, mock_ssl_context, mock_smtp_ssl, send_ema assert result_data["status"] == "error" assert "Failed to send email" in result_data["message"] - @patch('smtplib.SMTP_SSL') + @patch('smtplib.SMTP') @patch('ssl.create_default_context') - def test_forward_unexpected_exception(self, mock_ssl_context, mock_smtp_ssl, send_email_tool): + def test_forward_unexpected_exception(self, mock_ssl_context, mock_smtp, send_email_tool): """Test email sending with unexpected exception""" # Mock SSL context mock_context = Mock() @@ -230,7 +240,7 @@ def test_forward_unexpected_exception(self, mock_ssl_context, mock_smtp_ssl, sen # Mock SMTP server with unexpected error mock_server = Mock() mock_server.login.side_effect = RuntimeError("Unexpected error") - mock_smtp_ssl.return_value = mock_server + mock_smtp.return_value = mock_server result = send_email_tool.forward( to="recipient@example.com", @@ -246,9 +256,9 @@ def test_forward_unexpected_exception(self, mock_ssl_context, mock_smtp_ssl, sen assert "An unexpected error occurred" in result_data["message"] assert "Unexpected error" in result_data["message"] - @patch('smtplib.SMTP_SSL') + @patch('smtplib.SMTP') @patch('ssl.create_default_context') - def test_forward_empty_cc_and_bcc(self, mock_ssl_context, mock_smtp_ssl, send_email_tool): + def test_forward_empty_cc_and_bcc(self, mock_ssl_context, mock_smtp, send_email_tool): """Test email sending with empty CC and BCC""" # Mock SSL context mock_context = Mock() @@ -256,7 +266,7 @@ def test_forward_empty_cc_and_bcc(self, mock_ssl_context, mock_smtp_ssl, send_em # Mock SMTP server mock_server = Mock() - mock_smtp_ssl.return_value = mock_server + mock_smtp.return_value = mock_server result = send_email_tool.forward( to="recipient@example.com", @@ -277,9 +287,9 @@ def test_forward_empty_cc_and_bcc(self, mock_ssl_context, mock_smtp_ssl, send_em assert 'Cc' not in call_args assert 'Bcc' not in call_args - @patch('smtplib.SMTP_SSL') + @patch('smtplib.SMTP') @patch('ssl.create_default_context') - def test_forward_html_content_attachment(self, mock_ssl_context, mock_smtp_ssl, send_email_tool): + def test_forward_html_content_attachment(self, mock_ssl_context, mock_smtp, send_email_tool): """Test that HTML content is properly attached to email""" # Mock SSL context mock_context = Mock() @@ -287,7 +297,7 @@ def test_forward_html_content_attachment(self, mock_ssl_context, mock_smtp_ssl, # Mock SMTP server mock_server = Mock() - mock_smtp_ssl.return_value = mock_server + mock_smtp.return_value = mock_server html_content = "

Test Header

This is bold text.

" @@ -314,17 +324,19 @@ def test_forward_html_content_attachment(self, mock_ssl_context, mock_smtp_ssl, assert attachments[0].get_content_type() == "text/html" assert attachments[0].get_payload() == html_content - @patch('smtplib.SMTP_SSL') + @patch('smtplib.SMTP') @patch('ssl.create_default_context') - def test_forward_ssl_context_configuration(self, mock_ssl_context, mock_smtp_ssl, send_email_tool): - """Test SSL context is properly configured""" + def test_forward_ssl_context_configuration(self, mock_ssl_context, mock_smtp, send_email_tool): + """Test SSL context is properly configured for STARTTLS""" # Mock SSL context mock_context = Mock() + mock_context.check_hostname = True + mock_context.verify_mode = ssl.CERT_REQUIRED mock_ssl_context.return_value = mock_context # Mock SMTP server mock_server = Mock() - mock_smtp_ssl.return_value = mock_server + mock_smtp.return_value = mock_server send_email_tool.forward( to="recipient@example.com", @@ -332,16 +344,48 @@ def test_forward_ssl_context_configuration(self, mock_ssl_context, mock_smtp_ssl content="

Test content

" ) - # Verify SSL context configuration + # Verify SSL context is created (default settings preserved) mock_ssl_context.assert_called_once() - assert mock_context.check_hostname is True - assert mock_context.verify_mode == ssl.CERT_REQUIRED - # Verify SMTP_SSL is called with context - mock_smtp_ssl.assert_called_once_with( - "smtp.test.com", 587, context=mock_context, timeout=30 + # Verify STARTTLS is called with context + mock_server.starttls.assert_called_once_with(context=mock_context) + + @patch('smtplib.SMTP') + @patch('ssl.create_default_context') + def test_forward_port_25_skips_ssl_verification(self, mock_ssl_context, mock_smtp): + """Test that port 25 skips SSL certificate verification for self-signed certs""" + # Create tool with port 25 + tool = SendEmailTool( + smtp_server="smtp.local.com", + smtp_port=25, + username="user@example.com", + password="password123", + use_ssl=False + ) + + # Mock SSL context + mock_context = Mock() + mock_context.check_hostname = False + mock_context.verify_mode = ssl.CERT_NONE + mock_ssl_context.return_value = mock_context + + # Mock SMTP server + mock_server = Mock() + mock_smtp.return_value = mock_server + + result = tool.forward( + to="recipient@example.com", + subject="Test Subject", + content="

Test content

" ) + # Parse result + result_data = json.loads(result) + assert result_data["status"] == "success" + + # Verify STARTTLS is called with context for self-signed certs + mock_server.starttls.assert_called_once_with(context=mock_context) + @patch('smtplib.SMTP_SSL') @patch('ssl.create_default_context') def test_forward_timeout_configuration(self, mock_ssl_context, mock_smtp_ssl): @@ -374,9 +418,9 @@ def test_forward_timeout_configuration(self, mock_ssl_context, mock_smtp_ssl): "smtp.example.com", 465, context=mock_context, timeout=60 ) - @patch('smtplib.SMTP_SSL') + @patch('smtplib.SMTP') @patch('ssl.create_default_context') - def test_forward_server_quit_called_on_success(self, mock_ssl_context, mock_smtp_ssl, send_email_tool): + def test_forward_server_quit_called_on_success(self, mock_ssl_context, mock_smtp, send_email_tool): """Test that server.quit() is called on successful send""" # Mock SSL context mock_context = Mock() @@ -384,7 +428,7 @@ def test_forward_server_quit_called_on_success(self, mock_ssl_context, mock_smtp # Mock SMTP server mock_server = Mock() - mock_smtp_ssl.return_value = mock_server + mock_smtp.return_value = mock_server send_email_tool.forward( to="recipient@example.com", @@ -397,7 +441,7 @@ def test_forward_server_quit_called_on_success(self, mock_ssl_context, mock_smtp def test_forward_empty_parameters(self, send_email_tool): """Test forward method with empty parameters""" - with patch('smtplib.SMTP_SSL') as mock_smtp_ssl, \ + with patch('smtplib.SMTP') as mock_smtp, \ patch('ssl.create_default_context') as mock_ssl_context: # Mock SSL context @@ -406,7 +450,7 @@ def test_forward_empty_parameters(self, send_email_tool): # Mock SMTP server mock_server = Mock() - mock_smtp_ssl.return_value = mock_server + mock_smtp.return_value = mock_server result = send_email_tool.forward( to="", From d53ab2387fac05e9780af3a1c370d88b30888630 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 11 May 2026 17:08:37 +0800 Subject: [PATCH 16/48] =?UTF-8?q?=E5=8C=BA=E5=88=86sender=5Femail=E5=92=8C?= =?UTF-8?q?=E5=92=8Csender=5Fname?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../services/tool_configuration_service.py | 26 +++++++++--- backend/utils/tool_utils.py | 3 +- sdk/nexent/core/tools/send_email_tool.py | 34 ++++++++++++---- test/sdk/core/tools/test_send_email_tool.py | 40 ++++++++++++++++++- 4 files changed, 87 insertions(+), 16 deletions(-) diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index 5e5229ff6..0f779cb98 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -130,11 +130,15 @@ def get_local_tools() -> List[ToolInfo]: if hasattr(param.default, 'exclude') and param.default.exclude: continue + # Check if default is a Pydantic FieldInfo (has .default attribute) + is_pydantic_field = hasattr(param.default, 'default') + # Get description in both languages - param_description = param.default.description if hasattr(param.default, 'description') else "" + param_description = param.default.description if is_pydantic_field else "" # First try to get from param.default.description_zh (FieldInfo) - param_description_zh = param.default.description_zh if hasattr(param.default, 'description_zh') else None + # Note: Pydantic Field doesn't have description_zh attribute, so use getattr with default + param_description_zh = getattr(param.default, 'description_zh', None) if is_pydantic_field else None # Fallback to init_param_descriptions if not found if param_description_zh is None and param_name in init_param_descriptions: @@ -146,11 +150,21 @@ def get_local_tools() -> List[ToolInfo]: "description": param_description, "description_zh": param_description_zh } - if param.default.default is PydanticUndefined: - param_info["optional"] = False + + # Handle both Pydantic FieldInfo and simple defaults + if is_pydantic_field: + if param.default.default is PydanticUndefined: + param_info["optional"] = False + else: + param_info["default"] = param.default.default + param_info["optional"] = True else: - param_info["default"] = param.default.default - param_info["optional"] = True + # Simple default value (not a FieldInfo) + if param.default == inspect.Parameter.empty: + param_info["optional"] = False + else: + param_info["default"] = param.default + param_info["optional"] = True init_params_list.append(param_info) diff --git a/backend/utils/tool_utils.py b/backend/utils/tool_utils.py index f06f36bc3..f1d9147e3 100644 --- a/backend/utils/tool_utils.py +++ b/backend/utils/tool_utils.py @@ -46,7 +46,8 @@ def get_local_tools_description_zh() -> Dict[str, Dict]: if hasattr(param.default, 'exclude') and param.default.exclude: continue - param_description_zh = param.default.description_zh if hasattr(param.default, 'description_zh') else None + # Note: Pydantic Field doesn't have description_zh attribute + param_description_zh = getattr(param.default, 'description_zh', None) if hasattr(param.default, 'description_zh') else None if param_description_zh is None and param_name in init_param_descriptions: param_description_zh = init_param_descriptions[param_name].get('description_zh') diff --git a/sdk/nexent/core/tools/send_email_tool.py b/sdk/nexent/core/tools/send_email_tool.py index 097ad838c..42453e16b 100644 --- a/sdk/nexent/core/tools/send_email_tool.py +++ b/sdk/nexent/core/tools/send_email_tool.py @@ -44,6 +44,12 @@ class SendEmailTool(Tool): "description": "BCC email address, multiple BCCs separated by commas, optional", "description_zh": "密送邮箱地址,多个密送用逗号分隔,可选", "nullable": True + }, + "sender_email": { + "type": "string", + "description": "Actual sender email address (From address), optional - defaults to username", + "description_zh": "实际发件人邮箱地址(From字段),可选,默认为username", + "nullable": True } } @@ -68,6 +74,10 @@ class SendEmailTool(Tool): "description": "Use SSL/TLS encryption (set to False for plain text)", "description_zh": "使用 SSL/TLS 加密(设为 False 使用明文)" }, + "sender_email": { + "description": "Actual sender email address (From address), defaults to username", + "description_zh": "实际发件人邮箱地址,默认为 username" + }, "sender_name": { "description": "Sender name", "description_zh": "发件人名称" @@ -81,10 +91,11 @@ class SendEmailTool(Tool): category = ToolCategory.EMAIL.value def __init__(self, smtp_server: str = "", - smtp_port: int = 587, - username: str = "", - password: str = "", + smtp_port: int = 587, + username: str = "", + password: str = "", use_ssl: bool = True, + sender_email: Optional[str] = None, sender_name: Optional[str] = None, timeout: int = 30): super().__init__() @@ -93,6 +104,7 @@ def __init__(self, smtp_server: str = "", self.username = username self.password = password self.use_ssl = use_ssl + self.sender_email = sender_email or username self.sender_name = sender_name self.timeout = timeout @@ -108,12 +120,18 @@ def _create_ssl_context(self, skip_verify: bool = False) -> ssl.SSLContext: context.verify_mode = ssl.CERT_REQUIRED return context - def forward(self, to: str, subject: str, content: str, cc: str = "", bcc: str = "") -> str: + def forward(self, to: str, subject: str, content: str, cc: str = "", bcc: str = "", + sender_email: Optional[str] = None) -> str: try: logger.info("Creating email message...") - # Create email object msg = MIMEMultipart() - msg['From'] = f"{self.sender_name} <{self.username}>" if self.sender_name else self.username + + sender = sender_email or self.sender_email + if self.sender_name: + msg['From'] = f"{self.sender_name} <{sender}>" + else: + msg['From'] = sender + msg['To'] = to msg['Subject'] = subject @@ -131,13 +149,13 @@ def forward(self, to: str, subject: str, content: str, cc: str = "", bcc: str = if self.smtp_port == 465: # Port 465 uses implicit SSL logger.info("Using implicit SSL connection (port 465)...") - context = self._create_ssl_context(skip_verify=False) + context = self._create_ssl_context(skip_verify=True) server = smtplib.SMTP_SSL(self.smtp_server, self.smtp_port, context=context, timeout=self.timeout) elif self.use_ssl: # Port 587 (and others) use STARTTLS logger.info("Using STARTTLS connection...") server = smtplib.SMTP(self.smtp_server, self.smtp_port, timeout=self.timeout) - server.starttls(context=self._create_ssl_context(skip_verify=False)) + server.starttls(context=self._create_ssl_context(skip_verify=True)) else: # Port 25 - plain connection (may have self-signed certs) logger.info("Using plain text connection (port 25)...") diff --git a/test/sdk/core/tools/test_send_email_tool.py b/test/sdk/core/tools/test_send_email_tool.py index 88b279eb2..d3bc9f946 100644 --- a/test/sdk/core/tools/test_send_email_tool.py +++ b/test/sdk/core/tools/test_send_email_tool.py @@ -19,6 +19,7 @@ def send_email_tool(): username="test@test.com", password="test_password", use_ssl=True, + sender_email="actual@test.com", sender_name="Test Sender", timeout=30 ) @@ -102,6 +103,10 @@ def test_tool_inputs_schema(self, send_email_tool): assert inputs["bcc"]["type"] == "string" assert inputs["bcc"]["nullable"] is True + assert "sender_email" in inputs + assert inputs["sender_email"]["type"] == "string" + assert inputs["sender_email"]["nullable"] is True + @patch('smtplib.SMTP') @patch('ssl.create_default_context') def test_forward_success_basic_email(self, mock_ssl_context, mock_smtp, send_email_tool): @@ -168,7 +173,7 @@ def test_forward_success_with_cc_and_bcc(self, mock_ssl_context, mock_smtp, send call_args = mock_server.send_message.call_args[0][0] # Verify email headers - assert call_args['From'] == "Test Sender " + assert call_args['From'] == "Test Sender " assert call_args['To'] == "recipient@example.com" assert call_args['Subject'] == "Test Subject" assert call_args['Cc'] == "cc1@example.com,cc2@example.com" @@ -466,6 +471,39 @@ def test_forward_empty_parameters(self, send_email_tool): assert result_data["to"] == "" assert result_data["subject"] == "" + @patch('smtplib.SMTP') + @patch('ssl.create_default_context') + def test_forward_sender_email_override(self, mock_ssl_context, mock_smtp): + """Test that sender_email parameter in forward overrides instance sender_email""" + tool = SendEmailTool( + smtp_server="smtp.test.com", + smtp_port=587, + username="auth@test.com", + password="password", + use_ssl=True, + sender_email="instance@test.com", + sender_name="Instance Sender" + ) + + mock_context = Mock() + mock_ssl_context.return_value = mock_context + + mock_server = Mock() + mock_smtp.return_value = mock_server + + result = tool.forward( + to="recipient@example.com", + subject="Test Subject", + content="

Test content

", + sender_email="override@test.com" + ) + + result_data = json.loads(result) + assert result_data["status"] == "success" + + call_args = mock_server.send_message.call_args[0][0] + assert call_args['From'] == "Instance Sender " + if __name__ == '__main__': pytest.main([__file__]) From a4c17f034880276c884455e60d70e9ce8989079b Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 11 May 2026 20:28:51 +0800 Subject: [PATCH 17/48] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=97=A0=E6=B3=95?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E6=98=8A=E5=A4=A9=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E5=88=97=E8=A1=A8=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/haotian_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/services/haotian_service.py b/backend/services/haotian_service.py index a49079ec7..97c5db564 100644 --- a/backend/services/haotian_service.py +++ b/backend/services/haotian_service.py @@ -77,7 +77,7 @@ async def fetch_haotian_knowledge_sets_impl( ) headers = {"Authorization": external_authorization} - async with httpx.AsyncClient(timeout=timeout_s, follow_redirects=True) as client: + async with httpx.AsyncClient(timeout=timeout_s, follow_redirects=True, trust_env=False) as client: resp = await client.get(list_url, headers=headers) if resp.status_code >= 400: raise RuntimeError( From 0635864042ea88f5cc393cffa29da544c8bdb2ca Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 12 May 2026 16:18:30 +0800 Subject: [PATCH 18/48] Create a session with trust_env=False to ignore proxy environment variables --- sdk/nexent/core/models/embedding_model.py | 12 ++++++++++-- sdk/nexent/core/models/openai_llm.py | 16 +++++++--------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/sdk/nexent/core/models/embedding_model.py b/sdk/nexent/core/models/embedding_model.py index 092877941..a7379efcb 100644 --- a/sdk/nexent/core/models/embedding_model.py +++ b/sdk/nexent/core/models/embedding_model.py @@ -171,6 +171,10 @@ def __init__( self.model = model_name self.embedding_dim = embedding_dim self.ssl_verify = ssl_verify + + # Create a session with trust_env=False to ignore proxy environment variables + self.session = requests.Session() + self.session.trust_env = False self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} @@ -189,7 +193,7 @@ def _make_request(self, data: Dict[str, Any], timeout: Optional[float] = None) - Returns: Dict[str, Any]: API response """ - response = requests.post(self.api_url, headers=self.headers, json=data, timeout=timeout, verify=self.ssl_verify) + response = self.session.post(self.api_url, headers=self.headers, json=data, timeout=timeout, verify=self.ssl_verify) response.raise_for_status() return response.json() @@ -332,6 +336,10 @@ def __init__(self, model_name: str, base_url: str, api_key: str, embedding_dim: self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + # Create a session with trust_env=False to ignore proxy environment variables + self.session = requests.Session() + self.session.trust_env = False + def _prepare_input(self, inputs: Union[str, List[str]]) -> Dict[str, Any]: """Prepare the input data for the API request.""" if isinstance(inputs, str): @@ -349,7 +357,7 @@ def _make_request(self, data: Dict[str, Any], timeout: Optional[float] = None) - Returns: Dict[str, Any]: API response """ - response = requests.post(self.api_url, headers=self.headers, json=data, timeout=timeout, verify=self.ssl_verify) + response = self.session.post(self.api_url, headers=self.headers, json=data, timeout=timeout, verify=self.ssl_verify) response.raise_for_status() return response.json() diff --git a/sdk/nexent/core/models/openai_llm.py b/sdk/nexent/core/models/openai_llm.py index 4c41e0021..02c4f74bd 100644 --- a/sdk/nexent/core/models/openai_llm.py +++ b/sdk/nexent/core/models/openai_llm.py @@ -56,15 +56,13 @@ def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, if concurrency_limit is not None and concurrency_limit > 0: self._semaphore = asyncio.Semaphore(concurrency_limit) - # Create http_client based on ssl_verify parameter and timeout_seconds - if not ssl_verify or timeout_seconds is not None: - import httpx - # Build timeout configuration - timeout = httpx.Timeout(timeout_seconds) if timeout_seconds is not None else httpx.Timeout(120.0) - http_client = httpx.Client(verify=ssl_verify, timeout=timeout) - client_kwargs = kwargs.get('client_kwargs', {}) - client_kwargs['http_client'] = http_client - kwargs['client_kwargs'] = client_kwargs + # Create http_client with trust_env=False to ignore proxy env vars + import httpx + timeout = httpx.Timeout(timeout_seconds) if timeout_seconds is not None else httpx.Timeout(120.0) + http_client = httpx.Client(verify=ssl_verify, timeout=timeout, trust_env=False) + client_kwargs = kwargs.get('client_kwargs', {}) + client_kwargs['http_client'] = http_client + kwargs['client_kwargs'] = client_kwargs super().__init__(*args, **kwargs) From 842f31270c6c3b600968c6c03fa352a81a9973d2 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 12 May 2026 19:26:57 +0800 Subject: [PATCH 19/48] =?UTF-8?q?=E8=AE=BE=E7=BD=AEgenerate=5Ftitle?= =?UTF-8?q?=E4=B8=BA=E9=9D=9E=E6=B5=81=E5=BC=8F=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/conversation_management_service.py | 4 ++-- sdk/nexent/core/models/openai_llm.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/services/conversation_management_service.py b/backend/services/conversation_management_service.py index c3571fcf3..f03c32512 100644 --- a/backend/services/conversation_management_service.py +++ b/backend/services/conversation_management_service.py @@ -275,8 +275,8 @@ def call_llm_for_title(question: str, tenant_id: str, language: str = LANGUAGE[" if model_config.get("model_factory", "").lower() == "modelengine": messages = [{"role": msg["role"], "content": str(msg.get("content", ""))} for msg in messages] - # Call the model - response = llm.generate(messages) + # Call the model with stream=False to get a single response + response = llm.generate(messages, stream=False) if not response or not response.content or not response.content.strip(): return DEFAULT_EN_TITLE if language == LANGUAGE["EN"] else DEFAULT_ZH_TITLE return remove_think_blocks(response.content.strip()) diff --git a/sdk/nexent/core/models/openai_llm.py b/sdk/nexent/core/models/openai_llm.py index 02c4f74bd..918e362a3 100644 --- a/sdk/nexent/core/models/openai_llm.py +++ b/sdk/nexent/core/models/openai_llm.py @@ -142,7 +142,7 @@ def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List completion_kwargs["stream_options"] = {"include_usage": True} current_request = self.client.chat.completions.create( - stream=True, **completion_kwargs) + stream=kwargs.get("stream", True), **completion_kwargs) # Validate response type: ensure we got a proper iterator, not error strings or dicts # Some APIs return error strings like "error: rate limit" or JSON dicts on failure From 9091a81204999048d7bac76b60c4d67ec0cb6180 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 12 May 2026 20:01:49 +0800 Subject: [PATCH 20/48] =?UTF-8?q?Revert=20"=E8=AE=BE=E7=BD=AEgenerate=5Fti?= =?UTF-8?q?tle=E4=B8=BA=E9=9D=9E=E6=B5=81=E5=BC=8F=E6=8E=A5=E5=8F=A3"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit d1cffeb589b3ea2cb735d42d4d1ab7f61e125b39. --- backend/services/conversation_management_service.py | 4 ++-- sdk/nexent/core/models/openai_llm.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/services/conversation_management_service.py b/backend/services/conversation_management_service.py index f03c32512..c3571fcf3 100644 --- a/backend/services/conversation_management_service.py +++ b/backend/services/conversation_management_service.py @@ -275,8 +275,8 @@ def call_llm_for_title(question: str, tenant_id: str, language: str = LANGUAGE[" if model_config.get("model_factory", "").lower() == "modelengine": messages = [{"role": msg["role"], "content": str(msg.get("content", ""))} for msg in messages] - # Call the model with stream=False to get a single response - response = llm.generate(messages, stream=False) + # Call the model + response = llm.generate(messages) if not response or not response.content or not response.content.strip(): return DEFAULT_EN_TITLE if language == LANGUAGE["EN"] else DEFAULT_ZH_TITLE return remove_think_blocks(response.content.strip()) diff --git a/sdk/nexent/core/models/openai_llm.py b/sdk/nexent/core/models/openai_llm.py index 918e362a3..02c4f74bd 100644 --- a/sdk/nexent/core/models/openai_llm.py +++ b/sdk/nexent/core/models/openai_llm.py @@ -142,7 +142,7 @@ def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List completion_kwargs["stream_options"] = {"include_usage": True} current_request = self.client.chat.completions.create( - stream=kwargs.get("stream", True), **completion_kwargs) + stream=True, **completion_kwargs) # Validate response type: ensure we got a proper iterator, not error strings or dicts # Some APIs return error strings like "error: rate limit" or JSON dicts on failure From ba854712e8cfd92b69eb797f26f05d342829a317 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 12 May 2026 20:25:32 +0800 Subject: [PATCH 21/48] =?UTF-8?q?"=E8=AE=BE=E7=BD=AEgenerate=5Ftitle?= =?UTF-8?q?=E4=B8=BA=E9=9D=9E=E6=B5=81=E5=BC=8F=E6=8E=A5=E5=8F=A3"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/conversation_management_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/services/conversation_management_service.py b/backend/services/conversation_management_service.py index c3571fcf3..302ec63a8 100644 --- a/backend/services/conversation_management_service.py +++ b/backend/services/conversation_management_service.py @@ -260,6 +260,7 @@ def call_llm_for_title(question: str, tenant_id: str, language: str = LANGUAGE[" model_factory=model_config.get("model_factory", None), ssl_verify=model_config.get("ssl_verify", True), timeout_seconds=timeout_seconds, + stream=False, ) # Build messages - use new template variable 'question' instead of 'content' From 64dc28484432fac835c7deb573ce9122e36bcf63 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Wed, 13 May 2026 14:19:10 +0800 Subject: [PATCH 22/48] =?UTF-8?q?=E8=AE=BE=E7=BD=AEauthorization=E5=AD=97?= =?UTF-8?q?=E6=AE=B5=E4=B9=9F=E4=B8=BA=E5=AF=86=E7=A0=81=E5=B1=95=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../agentConfig/tool/ToolConfigModal.tsx | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index 53c6d3f03..39c3bbce2 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -1474,10 +1474,21 @@ export default function ToolConfigModal({ case TOOL_PARAM_TYPES.ARRAY: case TOOL_PARAM_TYPES.OBJECT: default: - // Check if parameter name contains "password" for secure input - const isPasswordType = param.name.toLowerCase().includes("password"); + // Check if parameter name indicates a secure/sensitive field + const sensitivePatterns = [ + "password", + "authorization", + "api_key", + "apikey", + "api-key", + "secret", + "token", + ]; + const isSecureField = sensitivePatterns.some((pattern) => + param.name.toLowerCase().includes(pattern) + ); - if (isPasswordType) { + if (isSecureField) { return ( Date: Wed, 13 May 2026 18:18:25 +0800 Subject: [PATCH 23/48] =?UTF-8?q?=E5=A6=82=E6=9E=9C=E6=98=AF=E5=85=AC?= =?UTF-8?q?=E5=85=B1=E7=9F=A5=E8=AF=86=E5=BA=93=EF=BC=8C=E8=AE=BE=E7=BD=AE?= =?UTF-8?q?=E9=BB=98=E8=AE=A4id?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/haotian_service.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/backend/services/haotian_service.py b/backend/services/haotian_service.py index 97c5db564..e7f762244 100644 --- a/backend/services/haotian_service.py +++ b/backend/services/haotian_service.py @@ -11,6 +11,8 @@ logger = logging.getLogger("haotian_service") +_DEFAULT_KNOWLEDGE_BASE_ID = "abcdefg" + def _normalize_list_payload(raw: Dict[str, Any]) -> Dict[str, Any]: """ @@ -24,7 +26,7 @@ def _normalize_list_payload(raw: Dict[str, Any]) -> Dict[str, Any]: ] } - This function also filters out knowledge sets with name == "Public". + When dify_dataset_id is "null", it is replaced with the default ID. """ knowledge_sets = raw.get("knowledge_sets", []) if not isinstance(knowledge_sets, list): @@ -35,7 +37,7 @@ def _normalize_list_payload(raw: Dict[str, Any]) -> Dict[str, Any]: if not isinstance(ks, dict): continue set_name = str(ks.get("name", "") or "").strip() - if not set_name or set_name == "Public": + if not set_name: continue bases = ks.get("knowledge_bases", []) @@ -48,15 +50,18 @@ def _normalize_list_payload(raw: Dict[str, Any]) -> Dict[str, Any]: continue dataset_id = str(kb.get("dify_dataset_id", "") or "").strip() kb_name = str(kb.get("name", "") or "").strip() - if not dataset_id or not kb_name: + if not kb_name: continue + if dataset_id == "null" or not dataset_id: + dataset_id = _DEFAULT_KNOWLEDGE_BASE_ID normalized_bases.append( {"dify_dataset_id": dataset_id, "name": kb_name} ) - normalized_sets.append( - {"name": set_name, "knowledge_bases": normalized_bases} - ) + if normalized_bases: + normalized_sets.append( + {"name": set_name, "knowledge_bases": normalized_bases} + ) return {"knowledge_sets": normalized_sets} From 9d6fe0c4991d5b53f6dde95baeb6f00fb432b95a Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Thu, 14 May 2026 14:58:37 +0800 Subject: [PATCH 24/48] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=B9=B6=E5=8F=91?= =?UTF-8?q?=E6=95=B0=E9=87=8F=E7=9A=84=E9=99=90=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/services/modelService.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/services/modelService.ts b/frontend/services/modelService.ts index 07796d2c4..58e9b9887 100644 --- a/frontend/services/modelService.ts +++ b/frontend/services/modelService.ts @@ -727,6 +727,7 @@ export const modelService = { maximum_chunk_size: params.maximumChunkSize, chunk_batch: params.chunkingBatchSize, timeout_seconds: params.timeoutSeconds, + concurrency_limit: params.concurrencyLimit, }; // Add STT specific fields From d7c5bdfa1fad6271832ddb7e1dc3ef8cc391a49f Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Fri, 15 May 2026 19:09:33 +0800 Subject: [PATCH 25/48] Bugfix: Resolve frontend cache issue when only one model is available --- backend/utils/llm_utils.py | 7 ++--- .../agentInfo/AgentGenerateDetail.tsx | 27 ++++++++++++++++--- frontend/types/agentConfig.ts | 2 +- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/backend/utils/llm_utils.py b/backend/utils/llm_utils.py index a5e90c727..fec97c827 100644 --- a/backend/utils/llm_utils.py +++ b/backend/utils/llm_utils.py @@ -103,12 +103,13 @@ def call_llm_for_system_prompt( reasoning_content_seen = False content_tokens_seen = 0 for chunk in current_request: - if not hasattr(chunk, "choices"): - logger.warning("Received non-standard chunk without choices during prompt generation.") - continue if not chunk.choices: logger.debug("Received empty choices chunk during prompt generation; skipping.") + # Safety check: skip non-standard chunks that lack expected attributes + if not hasattr(chunk, 'choices'): + if hasattr(chunk, '__str__'): + logger.warning(f"Received non-standard chunk (no 'choices'): {str(chunk)[:200]}") continue delta = chunk.choices[0].delta diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index db2667535..a12571385 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -281,6 +281,21 @@ export default function AgentGenerateDetail({ delete initialAgentInfo.group_ids; } + // Check if the agent's model is still available + const agentModelAvailable = availableLlmModels.some( + (m) => m.name === editedAgent.model || m.displayName === editedAgent.model + ); + let effectiveMainAgentModel = initialAgentInfo.mainAgentModel; + let effectiveMainAgentModelId = editedAgent.model_id || 0; + + if (!agentModelAvailable && defaultLlmModel) { + // Agent's original model is no longer available, switch to default model + effectiveMainAgentModel = defaultLlmModel.displayName || ""; + effectiveMainAgentModelId = defaultLlmModel.id || 0; + // Update the initialAgentInfo with the new model + initialAgentInfo.mainAgentModel = effectiveMainAgentModel; + } + const initialBusinessInfo = { businessDescription: editedAgent.business_description || "", businessLogicModelName: @@ -294,12 +309,18 @@ export default function AgentGenerateDetail({ setBusinessInfo(initialBusinessInfo); form.setFieldsValue(initialAgentInfo); - // Sync model to store if not already set (e.g., in create mode with default model) + // Sync model to store (use default model if original is unavailable) if (isCreatingMode && defaultLlmModel) { updateProfileInfo({ model: defaultLlmModel.displayName || "", model_id: defaultLlmModel.id || 0, }); + } else if (!agentModelAvailable && defaultLlmModel) { + // Update model in store when original model is no longer available + updateProfileInfo({ + model: effectiveMainAgentModel, + model_id: effectiveMainAgentModelId, + }); } // Sync max_step to store in create mode (default to 5) if (isCreatingMode && !editedAgent.max_step) { @@ -313,7 +334,7 @@ export default function AgentGenerateDetail({ }); } - }, [currentAgentId, defaultLlmModel?.id, isCreatingMode, forceRefreshKey]); + }, [currentAgentId, defaultLlmModel, isCreatingMode, forceRefreshKey, availableLlmModels]); // Default to selecting all groups when creating a new agent. // Only applies when groups are loaded and no group is selected yet. @@ -755,7 +776,7 @@ export default function AgentGenerateDetail({ { agent_id: effectiveAgentId, task_description: businessInfo.businessDescription, - model_id: businessInfo.businessLogicModelId.toString(), + model_id: businessInfo.businessLogicModelId, sub_agent_ids: editedAgent.sub_agent_id_list, tool_ids: Array.isArray(editedAgent.tools) ? editedAgent.tools.map((tool: any) => diff --git a/frontend/types/agentConfig.ts b/frontend/types/agentConfig.ts index b506730f8..a70d2a96b 100644 --- a/frontend/types/agentConfig.ts +++ b/frontend/types/agentConfig.ts @@ -407,7 +407,7 @@ export interface McpContainer { export interface GeneratePromptParams { agent_id: number; task_description: string; - model_id: string; + model_id: number; tool_ids?: number[]; // Optional: tool IDs selected in frontend (takes precedence over database query) sub_agent_ids?: number[]; // Optional: sub-agent IDs selected in frontend (takes precedence over database query) /** From c33b3688ccdf5468d3069ee34df7d088f5d3df4b Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 18 May 2026 14:51:48 +0800 Subject: [PATCH 26/48] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=BE=AA=E7=8E=AF?= =?UTF-8?q?=E4=BE=9D=E8=B5=96=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../agents/components/agentInfo/AgentGenerateDetail.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index a12571385..8aac21829 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -334,7 +334,7 @@ export default function AgentGenerateDetail({ }); } - }, [currentAgentId, defaultLlmModel, isCreatingMode, forceRefreshKey, availableLlmModels]); + }, [currentAgentId, defaultLlmModel?.id, isCreatingMode, forceRefreshKey, availableLlmModels.length]); // Default to selecting all groups when creating a new agent. // Only applies when groups are loaded and no group is selected yet. From 626f263c82f214d74f8b53cd24d4d8279147fb6f Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 18 May 2026 17:48:32 +0800 Subject: [PATCH 27/48] Bugfix: Prevent overwriting of agent name and variable name when generating agent info --- .../agentInfo/AgentGenerateDetail.tsx | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index 8aac21829..8d5c5ed53 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -837,7 +837,11 @@ export default function AgentGenerateDetail({ agentName: data.content, })); } - saveGeneratedField(generationAgentId, 'agentName', data.content); + // Only save to cache if user hasn't filled in agent name themselves + // This preserves user's input even if backend generates different values + if (!editedAgent.name && !form.getFieldValue("agentName")?.trim()) { + saveGeneratedField(generationAgentId, 'agentName', data.content); + } break; case GENERATE_PROMPT_STREAM_TYPES.AGENT_DESCRIPTION: if (isSameAgent) { @@ -847,7 +851,11 @@ export default function AgentGenerateDetail({ agentDescription: data.content, })); } - saveGeneratedField(generationAgentId, 'agentDescription', data.content); + // Only save to cache if user hasn't filled in agent description themselves + // This preserves user's input even if backend generates different values + if (!editedAgent.description && !form.getFieldValue("agentDescription")?.trim()) { + saveGeneratedField(generationAgentId, 'agentDescription', data.content); + } break; case GENERATE_PROMPT_STREAM_TYPES.AGENT_DISPLAY_NAME: if (isSameAgent) { @@ -860,7 +868,11 @@ export default function AgentGenerateDetail({ agentDisplayName: data.content, })); } - saveGeneratedField(generationAgentId, 'agentDisplayName', data.content); + // Only save to cache if user hasn't filled in agent display name themselves + // This preserves user's input even if backend generates different values + if (!editedAgent.display_name && !form.getFieldValue("agentDisplayName")?.trim()) { + saveGeneratedField(generationAgentId, 'agentDisplayName', data.content); + } break; } }, From 9b59e72125d564a020dabf9c3dc37e166557046a Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 18 May 2026 17:50:28 +0800 Subject: [PATCH 28/48] Bugfix: Immediately show login page on 401 response --- frontend/services/api.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/frontend/services/api.ts b/frontend/services/api.ts index 38958f21a..b441ff2e0 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -437,6 +437,12 @@ export const fetchWithErrorHandling = async ( throw new ApiError(errorCode, errorMessage); } + // Handle HTTP 401 - trigger session expired modal for all unauthorized errors + if (response.status === 401) { + handleSessionExpired(); + throw new ApiError(errorCode, errorMessage); + } + // Handle custom 499 error code (client closed connection) if (response.status === 499) { handleSessionExpired(); From 7c0ff333090f7240eead5f72997fd214d0d9458e Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 18 May 2026 19:20:32 +0800 Subject: [PATCH 29/48] Revert "Bugfix: Immediately show login page on 401 response" This reverts commit 9b59e72125d564a020dabf9c3dc37e166557046a. --- frontend/services/api.ts | 6 ------ 1 file changed, 6 deletions(-) diff --git a/frontend/services/api.ts b/frontend/services/api.ts index b441ff2e0..38958f21a 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -437,12 +437,6 @@ export const fetchWithErrorHandling = async ( throw new ApiError(errorCode, errorMessage); } - // Handle HTTP 401 - trigger session expired modal for all unauthorized errors - if (response.status === 401) { - handleSessionExpired(); - throw new ApiError(errorCode, errorMessage); - } - // Handle custom 499 error code (client closed connection) if (response.status === 499) { handleSessionExpired(); From dfe69ee3db6b6bebc1de27a73795969e5eed7f1f Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 18 May 2026 20:22:04 +0800 Subject: [PATCH 30/48] Bugfix: Remove invalid concurrency_limit related code from OpenAIModel --- sdk/nexent/core/agents/nexent_agent.py | 1 - sdk/nexent/core/models/openai_llm.py | 24 +++++------------------- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index 691d441d9..b836c9e8f 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -71,7 +71,6 @@ def create_model(self, model_cite_name: str): model_factory=model_config.model_factory, display_name=model_config.cite_name, timeout_seconds=model_config.timeout_seconds, - concurrency_limit=model_config.concurrency_limit, ) model.stop_event = self.stop_event return model diff --git a/sdk/nexent/core/models/openai_llm.py b/sdk/nexent/core/models/openai_llm.py index 02c4f74bd..4ed3374d8 100644 --- a/sdk/nexent/core/models/openai_llm.py +++ b/sdk/nexent/core/models/openai_llm.py @@ -25,7 +25,6 @@ class OpenAIModel(OpenAIServerModel): def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, top_p=0.95, ssl_verify=True, model_factory: Optional[str] = None, display_name: Optional[str] = None, timeout_seconds: Optional[float] = None, - concurrency_limit: Optional[int] = None, *args, **kwargs): """ Initialize OpenAI Model with observer and SSL verification option. @@ -39,7 +38,6 @@ def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, model_factory: Provider identifier (e.g., openai, modelengine) display_name: Human-readable display name for monitoring timeout_seconds: Request timeout in seconds. If None, uses httpx default. - concurrency_limit: Maximum concurrent requests. If None, no limit. *args: Additional positional arguments for OpenAIServerModel **kwargs: Additional keyword arguments for OpenAIServerModel """ @@ -51,10 +49,6 @@ def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, self.model_factory = (model_factory or "").lower() self.display_name = display_name self.timeout_seconds = timeout_seconds - self.concurrency_limit = concurrency_limit - self._semaphore = None - if concurrency_limit is not None and concurrency_limit > 0: - self._semaphore = asyncio.Semaphore(concurrency_limit) # Create http_client with trust_env=False to ignore proxy env vars import httpx @@ -308,19 +302,11 @@ async def check_connectivity(self) -> bool: import httpx request_kwargs["timeout"] = httpx.Timeout(self.timeout_seconds) - # Use semaphore for concurrency control if configured - async def _make_request(): - # Offload the blocking SDK call to a thread pool to avoid blocking the event loop - await asyncio.to_thread( - self.client.chat.completions.create, - **request_kwargs, - ) - - if self._semaphore is not None: - async with self._semaphore: - await _make_request() - else: - await _make_request() + # Offload the blocking SDK call to a thread pool to avoid blocking the event loop + await asyncio.to_thread( + self.client.chat.completions.create, + **request_kwargs, + ) # If no exception is raised, the connection is successful return True From b7c5b82898d35d8d7998cd8ba86cc600db74e549 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 18 May 2026 21:19:42 +0800 Subject: [PATCH 31/48] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=EF=BC=8C=E4=BF=AE=E5=A4=8D=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/database/model_management_db.py | 10 +++++++-- backend/services/model_management_service.py | 2 -- .../components/model/ModelDeleteDialog.tsx | 4 ++-- .../components/model/ModelEditDialog.tsx | 4 ++-- sdk/nexent/core/models/openai_llm.py | 13 +----------- .../services/test_agent_version_service.py | 21 +++++++++++++++++++ .../services/test_file_management_service.py | 3 ++- .../test_tool_configuration_service.py | 3 ++- test/backend/utils/test_llm_utils.py | 1 + test/sdk/core/agents/test_nexent_agent.py | 2 ++ 10 files changed, 41 insertions(+), 22 deletions(-) diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py index d501fd52f..8ecf6c60d 100644 --- a/backend/database/model_management_db.py +++ b/backend/database/model_management_db.py @@ -92,7 +92,8 @@ def update_model_record_by_model_name( update_data: Dict[str, Any], user_id: Optional[str] = None, tenant_id: Optional[str] = None, - model_repo: Optional[str] = None + model_repo: Optional[str] = None, + model_factory: Optional[str] = None ) -> bool: """ Update a model record by model_name and tenant_id. @@ -103,6 +104,7 @@ def update_model_record_by_model_name( user_id: Reserved parameter for filling updated_by field tenant_id: Tenant ID for filtering model_repo: Optional model repo for more precise matching + model_factory: Optional model vendor for more precise matching Returns: bool: Whether the operation was successful @@ -119,7 +121,9 @@ def update_model_record_by_model_name( if user_id: cleaned_data = add_update_tracking(cleaned_data, user_id) - db_logger.debug(f"update_model_record_by_model_name: model_name={model_name}, model_repo={model_repo}, tenant_id={tenant_id}, cleaned_data={cleaned_data}") + db_logger.debug(f"update_model_record_by_model_name: model_name={model_name}, " + f"model_repo={model_repo}, model_factory={model_factory}, " + f"tenant_id={tenant_id}, cleaned_data={cleaned_data}") # Build conditions list conditions = [ @@ -128,6 +132,8 @@ def update_model_record_by_model_name( ] if model_repo: conditions.append(ModelRecord.model_repo == model_repo) + if model_factory: + conditions.append(ModelRecord.model_factory == model_factory) # Build the update statement stmt = update(ModelRecord).where(*conditions).values(cleaned_data) diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index ab0e52259..f65cae6b0 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -302,8 +302,6 @@ async def batch_update_models_for_tenant(user_id: str, tenant_id: str, model_lis # Check if model_id is a numeric string (primary key) if model_id_or_name and model_id_or_name.isdigit(): - # Use model_id (primary key) for update - logging.info(f"[DEBUG] Updating model by id: model_id={model_id_or_name}, tenant_id={tenant_id}, update_data={update_data}") update_model_record(int(model_id_or_name), update_data, user_id, tenant_id) else: # Parse "model_repo/model_name" format from frontend's model_id field diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx index 894f50907..88fec353d 100644 --- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx @@ -590,7 +590,7 @@ export const ModelDeleteDialog = ({ timeoutSeconds, concurrencyLimit, }: { - apiKey: string; + apiKey?: string; maxTokens: number; timeoutSeconds?: number; concurrencyLimit?: number; @@ -624,7 +624,7 @@ export const ModelDeleteDialog = ({ ) .map((m) => ({ model_id: String(m.id), - apiKey: apiKey || m.apiKey, + apiKey: apiKey ?? m.apiKey, maxTokens: maxTokens || m.maxTokens, ...(timeoutSeconds !== undefined ? { timeoutSeconds } : {}), ...(concurrencyLimit !== undefined ? { concurrencyLimit } : {}), diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index 1fa2b4a64..5470fd39b 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -460,7 +460,7 @@ interface ProviderConfigEditDialogProps { modelType?: ModelType showApiKeyField?: boolean // Whether to show API Key field (default: true) onClose: () => void - onSave: (config: { apiKey: string; maxTokens: number; timeoutSeconds?: number; concurrencyLimit?: number }) => Promise | void + onSave: (config: { apiKey?: string; maxTokens: number; timeoutSeconds?: number; concurrencyLimit?: number }) => Promise | void } export const ProviderConfigEditDialog = ({ @@ -500,7 +500,7 @@ export const ProviderConfigEditDialog = ({ const isEmbeddingModel = modelType === MODEL_TYPES.EMBEDDING || modelType === MODEL_TYPES.MULTI_EMBEDDING const isRerankModel = modelType === MODEL_TYPES.RERANK await onSave({ - apiKey: showApiKeyField ? (apiKey.trim() === '' ? 'sk-no-api-key' : apiKey) : '', + ...(showApiKeyField ? { apiKey: apiKey.trim() === '' ? 'sk-no-api-key' : apiKey } : {}), maxTokens: parseInt(maxTokens), ...(!isEmbeddingModel && !isRerankModel ? { timeoutSeconds: parseInt(timeoutSeconds) || 120 } : {}), ...(!isEmbeddingModel && !isRerankModel ? { concurrencyLimit: concurrencyLimit ? parseInt(concurrencyLimit) : undefined } : {}), diff --git a/sdk/nexent/core/models/openai_llm.py b/sdk/nexent/core/models/openai_llm.py index 4ed3374d8..fba18151a 100644 --- a/sdk/nexent/core/models/openai_llm.py +++ b/sdk/nexent/core/models/openai_llm.py @@ -37,7 +37,7 @@ def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, Set to False for local services without SSL support. model_factory: Provider identifier (e.g., openai, modelengine) display_name: Human-readable display name for monitoring - timeout_seconds: Request timeout in seconds. If None, uses httpx default. + timeout_seconds: Request timeout in seconds. If None, defaults to 120 seconds. *args: Additional positional arguments for OpenAIServerModel **kwargs: Additional keyword arguments for OpenAIServerModel """ @@ -60,17 +60,6 @@ def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, super().__init__(*args, **kwargs) - # Apply custom timeout to client if specified (even when ssl_verify is True) - if timeout_seconds is not None and hasattr(self, 'client'): - import httpx - # Update client's timeout - new_timeout = httpx.Timeout(timeout_seconds) - if hasattr(self.client, '_client'): - # httpx client wrapped by openai - self.client._client.timeout = new_timeout - elif hasattr(self.client, 'timeout'): - self.client.timeout = new_timeout - # Wrap the OpenAI client with monitoring interceptor model_type = _detect_model_type(self) model_id = getattr(self, "model_id", None) diff --git a/test/backend/services/test_agent_version_service.py b/test/backend/services/test_agent_version_service.py index d29795200..2273a45c2 100644 --- a/test/backend/services/test_agent_version_service.py +++ b/test/backend/services/test_agent_version_service.py @@ -22,6 +22,27 @@ sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_mock.const +# Mock consts.agent_unavailable_reasons +agent_unavailable_reasons_mock = MagicMock() +agent_unavailable_reasons_mock.AgentUnavailableReason = type('AgentUnavailableReason', (), { + 'DUPLICATE_NAME': 'duplicate_name', + 'DUPLICATE_DISPLAY_NAME': 'duplicate_display_name', + 'MODEL_NOT_CONFIGURED': 'model_not_configured', + 'MODEL_UNAVAILABLE': 'model_unavailable', + 'TOOL_UNAVAILABLE': 'tool_unavailable', + 'ALL_TOOLS_DISABLED': 'all_tools_disabled', + 'AGENT_NOT_FOUND': 'agent_not_found', + 'all_reasons': classmethod(lambda cls: [ + 'duplicate_name', 'duplicate_display_name', 'model_not_configured', + 'model_unavailable', 'tool_unavailable', 'all_tools_disabled', 'agent_not_found' + ]), + 'is_valid_reason': classmethod(lambda cls, reason: reason in [ + 'duplicate_name', 'duplicate_display_name', 'model_not_configured', + 'model_unavailable', 'tool_unavailable', 'all_tools_disabled', 'agent_not_found' + ]), +})() +sys.modules['consts.agent_unavailable_reasons'] = agent_unavailable_reasons_mock + # Mock utils module utils_mock = MagicMock() utils_mock.auth_utils = MagicMock() diff --git a/test/backend/services/test_file_management_service.py b/test/backend/services/test_file_management_service.py index 1effa6e26..73df6441d 100644 --- a/test/backend/services/test_file_management_service.py +++ b/test/backend/services/test_file_management_service.py @@ -1382,7 +1382,8 @@ def test_get_llm_model_success(self, mock_tenant_config, mock_get_model_name, mo api_base="http://api.example.com", api_key="test_api_key", max_context_tokens=4096, - ssl_verify=True + ssl_verify=True, + timeout_seconds=None ) @patch('backend.services.file_management_service.MODEL_CONFIG_MAPPING', {"llm": "llm_config_key"}) diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 3cbdcee2b..1c9bf2a8f 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -3065,7 +3065,8 @@ def test_get_llm_model_success(self, mock_tenant_config, mock_get_model_name, mo api_base="http://api.example.com", api_key="test_api_key", max_context_tokens=4096, - ssl_verify=True + ssl_verify=True, + timeout_seconds=None ) @patch('backend.services.file_management_service.MODEL_CONFIG_MAPPING', {"llm": "llm_config_key"}) diff --git a/test/backend/utils/test_llm_utils.py b/test/backend/utils/test_llm_utils.py index e0d5577ae..8f5e8e3c8 100644 --- a/test/backend/utils/test_llm_utils.py +++ b/test/backend/utils/test_llm_utils.py @@ -141,6 +141,7 @@ def test_call_llm_for_system_prompt_success(self, mocker: MockFixture): top_p=0.95, ssl_verify=True, display_name=None, + timeout_seconds=None, ) def test_call_llm_for_system_prompt_exception(self, mocker: MockFixture): diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 9853b9eca..f640b2095 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -470,6 +470,7 @@ def test_create_model_success(nexent_agent_with_models, mock_model_config): top_p=mock_model_config.top_p, ssl_verify=True, display_name=mock_model_config.cite_name, + timeout_seconds=mock_model_config.timeout_seconds, ) # Verify stop_event was set @@ -499,6 +500,7 @@ def test_create_model_deep_thinking_success(nexent_agent_with_models, mock_deep_ top_p=mock_deep_thinking_model_config.top_p, ssl_verify=True, display_name=mock_deep_thinking_model_config.cite_name, + timeout_seconds=mock_deep_thinking_model_config.timeout_seconds, ) # Verify stop_event was set From 6d6c6dfab1a754e19f48c7644e258dda48e0e313 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 19 May 2026 10:01:26 +0800 Subject: [PATCH 32/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/backend/services/test_model_health_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/backend/services/test_model_health_service.py b/test/backend/services/test_model_health_service.py index 4cb527603..4434bda1e 100644 --- a/test/backend/services/test_model_health_service.py +++ b/test/backend/services/test_model_health_service.py @@ -577,7 +577,7 @@ async def test_verify_model_config_connectivity_success(): mock_connectivity_check.assert_called_once_with( "gpt-4", "llm", "https://api.openai.com", "test-key", True, - None, None, None, "GPT-4" + None, None, None, "GPT-4", timeout_seconds=None, ) From ac7539074cccc04aafa49a89006da305bbf607d2 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 19 May 2026 10:39:02 +0800 Subject: [PATCH 33/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/database/model_management_db.py | 60 +--------------- backend/services/model_management_service.py | 34 +++++---- test/backend/services/test_prompt_service.py | 74 +++++++++++++++++++- test/sdk/core/models/test_embedding_model.py | 73 ++++++++++--------- test/sdk/core/models/test_openai_llm.py | 22 +++--- 5 files changed, 144 insertions(+), 119 deletions(-) diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py index 8ecf6c60d..d260ec48c 100644 --- a/backend/database/model_management_db.py +++ b/backend/database/model_management_db.py @@ -3,7 +3,7 @@ from sqlalchemy import and_, desc, func, insert, select, update -from consts.const import DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE +from backend.consts.const import DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE from .client import as_dict, db_client, get_db_session from .db_models import ModelRecord from .utils import add_creation_tracking, add_update_tracking @@ -87,64 +87,6 @@ def update_model_record( return result.rowcount > 0 -def update_model_record_by_model_name( - model_name: str, - update_data: Dict[str, Any], - user_id: Optional[str] = None, - tenant_id: Optional[str] = None, - model_repo: Optional[str] = None, - model_factory: Optional[str] = None -) -> bool: - """ - Update a model record by model_name and tenant_id. - - Args: - model_name: Model name (display name, not the primary key) - update_data: Dictionary containing update data - user_id: Reserved parameter for filling updated_by field - tenant_id: Tenant ID for filtering - model_repo: Optional model repo for more precise matching - model_factory: Optional model vendor for more precise matching - - Returns: - bool: Whether the operation was successful - """ - import logging - db_logger = logging.getLogger("database.client") - - with get_db_session() as session: - # Data cleaning - cleaned_data = db_client.clean_string_values(update_data) - - # Add update timestamp - cleaned_data["update_time"] = func.current_timestamp() - if user_id: - cleaned_data = add_update_tracking(cleaned_data, user_id) - - db_logger.debug(f"update_model_record_by_model_name: model_name={model_name}, " - f"model_repo={model_repo}, model_factory={model_factory}, " - f"tenant_id={tenant_id}, cleaned_data={cleaned_data}") - - # Build conditions list - conditions = [ - ModelRecord.model_name == model_name, - ModelRecord.tenant_id == tenant_id - ] - if model_repo: - conditions.append(ModelRecord.model_repo == model_repo) - if model_factory: - conditions.append(ModelRecord.model_factory == model_factory) - - # Build the update statement - stmt = update(ModelRecord).where(*conditions).values(cleaned_data) - - # Execute the update statement - result = session.execute(stmt) - db_logger.info(f"update_model_record_by_model_name: rowcount={result.rowcount}") - - return result.rowcount > 0 - - def delete_model_record(model_id: int, user_id: str, tenant_id: str) -> bool: """ Delete a model record (soft delete) and update the update timestamp diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index f65cae6b0..1fce67b54 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -1,11 +1,11 @@ import logging from typing import List, Dict, Any, Optional -from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST -from consts.model import ModelConnectStatusEnum -from consts.provider import ProviderEnum, SILICON_BASE_URL, DASHSCOPE_BASE_URL, TOKENPONY_BASE_URL +from backend.consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST +from backend.consts.model import ModelConnectStatusEnum +from backend.consts.provider import ProviderEnum, SILICON_BASE_URL, DASHSCOPE_BASE_URL, TOKENPONY_BASE_URL -from database.model_management_db import ( +from backend.database.model_management_db import ( create_model_record, delete_model_record, get_model_by_display_name, @@ -13,22 +13,21 @@ get_model_records, get_models_by_tenant_factory_type, update_model_record, - update_model_record_by_model_name, ) -from services.model_provider_service import ( +from backend.services.model_provider_service import ( prepare_model_dict, merge_existing_model_attributes, get_provider_models, ) -from services.model_health_service import embedding_dimension_check -from utils.model_name_utils import ( +from backend.services.model_health_service import embedding_dimension_check +from backend.utils.model_name_utils import ( add_repo_to_name, split_repo_name, sort_models_by_id, ) -from utils.memory_utils import build_memory_config as build_memory_config_for_tenant -from services.vectordatabase_service import get_vector_db_core -from nexent.memory.memory_service import clear_model_memories +from backend.utils.memory_utils import build_memory_config as build_memory_config_for_tenant +from backend.services.vectordatabase_service import get_vector_db_core +from backend.nexent.memory.memory_service import clear_model_memories logger = logging.getLogger("model_management_service") @@ -293,6 +292,8 @@ async def update_single_model_for_tenant( async def batch_update_models_for_tenant(user_id: str, tenant_id: str, model_list: List[Dict[str, Any]]): """Batch update models for a tenant by model_id or model_name.""" + from backend.database.model_management_db import get_model_by_name_factory + try: for model in model_list: # Build update data excluding id fields @@ -311,8 +312,15 @@ async def batch_update_models_for_tenant(user_id: str, tenant_id: str, model_lis model_repo = None model_name = model_id_or_name - logging.info(f"[DEBUG] Updating model by name: model_name={model_name}, model_repo={model_repo}, tenant_id={tenant_id}, update_data={update_data}") - update_model_record_by_model_name(model_name, update_data, user_id, tenant_id, model_repo) + logging.info(f"[DEBUG] Updating model by name: model_name={model_name}, model_repo={model_repo}, tenant_id={tenant_id}") + + # Query to get model_id first, then update by primary key + model_record = get_model_by_name_factory(model_name, model_repo, tenant_id) + if not model_record: + logging.warning(f"Model not found: model_name={model_name}, model_repo={model_repo}, tenant_id={tenant_id}") + continue + + update_model_record(model_record["model_id"], update_data, user_id, tenant_id) logging.info("[DEBUG] Batch update models successfully") except Exception as e: diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 1b71baa5c..27df68d07 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -1,11 +1,65 @@ import json +import sys import unittest from unittest.mock import patch, MagicMock from consts.error_code import ErrorCode from consts.exceptions import AppException +# Mock database submodules before any imports +# These must be mocked before importing modules that depend on them +database_mock = MagicMock() +sys.modules['database'] = database_mock +sys.modules['database.model_management_db'] = MagicMock() +sys.modules['database.agent_db'] = MagicMock() +sys.modules['database.tool_db'] = MagicMock() +sys.modules['database.knowledge_db'] = MagicMock() +sys.modules['database.client'] = MagicMock() +sys.modules['database.attachment_db'] = MagicMock() +sys.modules['database.group_db'] = MagicMock() +sys.modules['database.user_tenant_db'] = MagicMock() +sys.modules['database.remote_mcp_db'] = MagicMock() +sys.modules['database.agent_version_db'] = MagicMock() +sys.modules['database.a2a_agent_db'] = MagicMock() +sys.modules['backend.database'] = MagicMock() +sys.modules['backend.database.model_management_db'] = MagicMock() +sys.modules['backend.database.agent_db'] = MagicMock() +sys.modules['backend.database.tool_db'] = MagicMock() +sys.modules['backend.database.knowledge_db'] = MagicMock() +sys.modules['backend.database.client'] = MagicMock() +sys.modules['backend.database.attachment_db'] = MagicMock() +sys.modules['backend.database.group_db'] = MagicMock() +sys.modules['backend.database.user_tenant_db'] = MagicMock() +sys.modules['backend.database.remote_mcp_db'] = MagicMock() +sys.modules['backend.database.agent_version_db'] = MagicMock() +sys.modules['backend.database.a2a_agent_db'] = MagicMock() + +# Mock services submodules (NOT backend.services which blocks import of prompt_service) +sys.modules['services'] = MagicMock() +sys.modules['services.agent_service'] = MagicMock() +sys.modules['services.file_management_service'] = MagicMock() +sys.modules['services.conversation_management_service'] = MagicMock() +sys.modules['services.memory_config_service'] = MagicMock() +sys.modules['services.agent_version_service'] = MagicMock() + +# Mock agents submodules +sys.modules['agents'] = MagicMock() +sys.modules['agents.create_agent_info'] = MagicMock() +sys.modules['backend.agents'] = MagicMock() +sys.modules['backend.agents.create_agent_info'] = MagicMock() + +# Mock utils submodules to avoid llm_utils import triggering database connection +sys.modules['utils'] = MagicMock() +sys.modules['utils.llm_utils'] = MagicMock() +sys.modules['utils.prompt_template_utils'] = MagicMock() +sys.modules['utils.config_utils'] = MagicMock() +sys.modules['utils.auth_utils'] = MagicMock() +sys.modules['backend.utils'] = MagicMock() +sys.modules['backend.utils.llm_utils'] = MagicMock() +sys.modules['backend.utils.prompt_template_utils'] = MagicMock() +sys.modules['backend.utils.config_utils'] = MagicMock() +sys.modules['backend.utils.auth_utils'] = MagicMock() + # Mock boto3 and minio client before importing the module under test -import sys boto3_mock = MagicMock() sys.modules['boto3'] = boto3_mock @@ -674,10 +728,13 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): expected_data = f"data: {json.dumps({'success': True, 'data': test_data[i]}, ensure_ascii=False)}\n\n" self.assertEqual(result, expected_data) + @patch('database.model_management_db.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') - def test_generate_system_prompt(self, mock_get_prompt_template, mock_join_info, mock_call_llm): + def test_generate_system_prompt(self, mock_get_prompt_template, mock_join_info, mock_call_llm, mock_get_model): + # Mock model config to avoid concurrency limit issue + mock_get_model.return_value = {"concurrency_limit": None} # Setup mock_prompt_config = { "USER_PROMPT": "Test user prompt template", @@ -791,10 +848,13 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): self.assertIsInstance(result["is_complete"], bool) self.assertIsInstance(result["content"], str) + @patch('database.model_management_db.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') - def test_generate_system_prompt_with_exception(self, mock_get_prompt_template, mock_join_info, mock_call_llm): + def test_generate_system_prompt_with_exception(self, mock_get_prompt_template, mock_join_info, mock_call_llm, mock_get_model): + # Mock model config to avoid concurrency limit issue + mock_get_model.return_value = {"concurrency_limit": None} # Setup mock_prompt_config = { "USER_PROMPT": "Test user prompt template", @@ -1088,6 +1148,7 @@ def mock_gen(*args, **kwargs): self.assertIn("Failed to generate prompt content", str(context.exception)) + @patch('database.model_management_db.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') @@ -1096,7 +1157,10 @@ def test_generate_system_prompt_error_before_streaming( mock_get_prompt_template, mock_join_info, mock_call_llm, + mock_get_model, ): + # Mock model config to avoid concurrency limit issue + mock_get_model.return_value = {"concurrency_limit": None} """Test generate_system_prompt handles error that occurs before streaming (line 307-311)""" # Setup mock_prompt_config = { @@ -1137,6 +1201,7 @@ def mock_llm_call_error(model_id, content, sys_prompt, callback, tenant_id): self.assertIn("LLM connection error", str(context.exception)) + @patch('database.model_management_db.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') @@ -1145,7 +1210,10 @@ def test_generate_system_prompt_error_during_streaming( mock_get_prompt_template, mock_join_info, mock_call_llm, + mock_get_model, ): + # Mock model config to avoid concurrency limit issue + mock_get_model.return_value = {"concurrency_limit": None} """Test generate_system_prompt handles error that occurs during streaming (line 330-331)""" # Setup mock_prompt_config = { diff --git a/test/sdk/core/models/test_embedding_model.py b/test/sdk/core/models/test_embedding_model.py index 9c3f8824b..c7856f46f 100644 --- a/test/sdk/core/models/test_embedding_model.py +++ b/test/sdk/core/models/test_embedding_model.py @@ -3,7 +3,7 @@ import sys from unittest.mock import AsyncMock, Mock, patch -from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding +from sdk.nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding class DummyResponse: def __init__(self, status_code=200, json_data=None): @@ -54,7 +54,7 @@ async def test_dimension_check_success(openai_embedding_instance): expected_embeddings = [[0.1, 0.2, 0.3]] with patch( - "nexent.core.models.embedding_model.asyncio.to_thread", + "sdk.nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, return_value=expected_embeddings, ) as mock_to_thread: @@ -69,7 +69,7 @@ async def test_dimension_check_failure(openai_embedding_instance): """dimension_check should return an empty list when an exception is raised inside to_thread.""" with patch( - "nexent.core.models.embedding_model.asyncio.to_thread", + "sdk.nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, side_effect=Exception("connection error"), ) as mock_to_thread: @@ -91,7 +91,7 @@ async def test_jina_dimension_check_success(jina_embedding_instance): expected_embeddings = [[0.5, 0.4, 0.3]] with patch( - "nexent.core.models.embedding_model.asyncio.to_thread", + "sdk.nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, return_value=expected_embeddings, ) as mock_to_thread: @@ -106,7 +106,7 @@ async def test_jina_dimension_check_failure(jina_embedding_instance): """dimension_check should return an empty list when an exception is raised inside to_thread.""" with patch( - "nexent.core.models.embedding_model.asyncio.to_thread", + "sdk.nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, side_effect=Exception("connection error"), ) as mock_to_thread: @@ -127,7 +127,7 @@ def test_openai_get_embeddings_success_returns_list(openai_embedding_instance): fake_response = {"data": [{"embedding": [0.9, 0.8]}]} with patch( - "nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", + "sdk.nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", return_value=fake_response, ) as mock_make_request: result = openai_embedding_instance.get_embeddings( @@ -145,7 +145,7 @@ def test_openai_get_embeddings_with_metadata(openai_embedding_instance): "data": [{"embedding": [1, 2, 3]}], "meta": {"foo": "bar"}} with patch( - "nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", + "sdk.nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", return_value=fake_response, ) as mock_make_request: result = openai_embedding_instance.get_embeddings( @@ -172,7 +172,7 @@ def side_effect(data, timeout=None): side_effect.calls = 0 with patch( - "nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", + "sdk.nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", side_effect=side_effect, ) as mock_make_request: result = openai_embedding_instance.get_embeddings( @@ -192,7 +192,7 @@ def test_openai_get_embeddings_timeout_exhausts_raises(openai_embedding_instance """Should raise Timeout after exhausting retries.""" with patch( - "nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", + "sdk.nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", side_effect=requests.exceptions.Timeout(), ) as mock_make_request: with pytest.raises(requests.exceptions.Timeout): @@ -226,7 +226,7 @@ def side_effect(inputs, with_metadata=False, timeout=None): return [[0.3, 0.4]] with patch( - "nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", + "sdk.nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", side_effect=side_effect, ) as mock_delegate: result = jina_embedding_instance.get_embeddings( @@ -251,7 +251,7 @@ def side_effect(inputs, with_metadata=False, timeout=None): side_effect.calls = 0 with patch( - "nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", + "sdk.nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", side_effect=side_effect, ) as mock_delegate: result = jina_embedding_instance.get_embeddings( @@ -273,7 +273,7 @@ def test_jina_get_embeddings_timeout_exhausts_raises(jina_embedding_instance): """Should raise Timeout after exhausting retries.""" with patch( - "nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", + "sdk.nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", side_effect=requests.exceptions.Timeout(), ) as mock_delegate: with pytest.raises(requests.exceptions.Timeout): @@ -306,7 +306,7 @@ def test_jina_get_multimodal_embeddings_parses_embeddings(jina_embedding_instanc mock_resp.json = Mock(return_value=fake_response) with patch( - "nexent.core.models.embedding_model.requests.post", return_value=mock_resp + "sdk.nexent.core.models.embedding_model.requests.Session.post", return_value=mock_resp ) as mock_post: inputs = [{"text": "t1"}, {"image": "http://x/y.jpg"}] result = jina_embedding_instance.get_multimodal_embeddings( @@ -334,7 +334,7 @@ def test_jina_get_multimodal_embeddings_with_metadata(jina_embedding_instance): mock_resp.raise_for_status = Mock() mock_resp.json = Mock(return_value=fake_response) - with patch("nexent.core.models.embedding_model.requests.post", return_value=mock_resp) as mock_post: + with patch("sdk.nexent.core.models.embedding_model.requests.Session.post", return_value=mock_resp) as mock_post: inputs = [{"text": "t"}] result = jina_embedding_instance.get_multimodal_embeddings( inputs, with_metadata=True, timeout=4 @@ -370,7 +370,7 @@ def side_effect(url, headers=None, json=None, timeout=None, **kwargs): side_effect.calls = 0 with patch( - "nexent.core.models.embedding_model.requests.post", side_effect=side_effect + "sdk.nexent.core.models.embedding_model.requests.Session.post", side_effect=side_effect ) as mock_post: inputs = [{"text": "t"}] result = jina_embedding_instance.get_multimodal_embeddings( @@ -391,7 +391,7 @@ def test_jina_get_multimodal_embeddings_timeout_exhausts_raises( """Should raise Timeout after exhausting retries.""" with patch( - "nexent.core.models.embedding_model.requests.post", + "sdk.nexent.core.models.embedding_model.requests.Session.post", side_effect=requests.exceptions.Timeout(), ) as mock_post: with pytest.raises(requests.exceptions.Timeout): @@ -438,7 +438,7 @@ async def test_jina_dimension_check_connection_error_returns_empty(jina_embeddin """dimension_check should return [] on ConnectionError.""" with patch( - "nexent.core.models.embedding_model.asyncio.to_thread", + "sdk.nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, side_effect=requests.exceptions.ConnectionError(), ): @@ -457,7 +457,7 @@ def side_effect(data, timeout=None): return {"data": [{"embedding": [0.21, 0.22]}]} with patch( - "nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", + "sdk.nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", side_effect=side_effect, ) as mock_make_request: result = openai_embedding_instance.get_embeddings( @@ -470,7 +470,7 @@ def side_effect(data, timeout=None): def test_openai_make_request_invokes_requests_post(openai_embedding_instance): - """Cover OpenAI _make_request by patching requests.post path.""" + """Cover OpenAI _make_request by patching requests.Session.post path.""" fake_response = {"data": [{"embedding": [7, 8]}]} @@ -478,7 +478,7 @@ def test_openai_make_request_invokes_requests_post(openai_embedding_instance): mock_resp.raise_for_status = Mock() mock_resp.json = Mock(return_value=fake_response) - with patch("nexent.core.models.embedding_model.requests.post", return_value=mock_resp) as mock_post: + with patch("sdk.nexent.core.models.embedding_model.requests.Session.post", return_value=mock_resp) as mock_post: result = openai_embedding_instance.get_embeddings( ["hi"], with_metadata=False, timeout=2 ) @@ -502,7 +502,7 @@ async def test_openai_dimension_check_connection_error_returns_empty(openai_embe """dimension_check should return [] on ConnectionError.""" with patch( - "nexent.core.models.embedding_model.asyncio.to_thread", + "sdk.nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, side_effect=requests.exceptions.ConnectionError(), ): @@ -513,13 +513,13 @@ async def test_openai_dimension_check_connection_error_returns_empty(openai_embe def test_api_key_normalization_and_verify_jina(monkeypatch): captured = {} - def fake_post(url, headers=None, json=None, timeout=None, verify=True): + def fake_post(self, url, headers=None, json=None, timeout=None, verify=True): captured['url'] = url captured['headers'] = headers captured['verify'] = verify return DummyResponse() - monkeypatch.setattr("requests.post", fake_post) + monkeypatch.setattr(requests.Session, "post", fake_post) # api_key containing Bearer prefix should be normalized emb = JinaEmbedding(api_key="my-secret", base_url="https://example.com/emb", ssl_verify=False) @@ -533,13 +533,13 @@ def fake_post(url, headers=None, json=None, timeout=None, verify=True): def test_api_key_normalization_and_verify_openaicompatible(monkeypatch): captured = {} - def fake_post(url, headers=None, json=None, timeout=None, verify=True): + def fake_post(self, url, headers=None, json=None, timeout=None, verify=True): captured['url'] = url captured['headers'] = headers captured['verify'] = verify return DummyResponse() - monkeypatch.setattr("requests.post", fake_post) + monkeypatch.setattr(requests.Session, "post", fake_post) emb = OpenAICompatibleEmbedding(model_name="m", base_url="https://api.example/emb", api_key="KEY", embedding_dim=16, ssl_verify=True) data = emb._prepare_input("hi") @@ -574,9 +574,9 @@ async def dimension_check(self, timeout: float = 5.0): def test_jina_make_request_raises_http_error(monkeypatch): - """Ensure _make_request propagates HTTP errors from requests.post""" + """Ensure _make_request propagates HTTP errors from requests.Session.post""" - def fake_post(url, headers=None, json=None, timeout=None, verify=True): + def fake_post(self, url, headers=None, json=None, timeout=None, verify=True): class BadResp: status_code = 500 @@ -585,7 +585,7 @@ def raise_for_status(self): return BadResp() - monkeypatch.setattr("requests.post", fake_post) + monkeypatch.setattr(requests.Session, "post", fake_post) emb = JinaEmbedding(api_key="k", base_url="https://api.jina.ai/v1/embeddings", ssl_verify=True) data = emb._prepare_multimodal_input([{"text": "hi"}]) @@ -596,7 +596,7 @@ def raise_for_status(self): def test_openai_make_request_raises_http_error(monkeypatch): """Ensure OpenAICompatibleEmbedding._make_request propagates HTTP errors""" - def fake_post(url, headers=None, json=None, timeout=None, verify=True): + def fake_post(self, url, headers=None, json=None, timeout=None, verify=True): class BadResp: status_code = 502 @@ -605,7 +605,7 @@ def raise_for_status(self): return BadResp() - monkeypatch.setattr("requests.post", fake_post) + monkeypatch.setattr(requests.Session, "post", fake_post) emb = OpenAICompatibleEmbedding(model_name="m", base_url="https://api.example.com/emb", api_key="k", embedding_dim=16, ssl_verify=False) data = emb._prepare_input("hello") @@ -623,7 +623,10 @@ def raise_for_status(self): def json(self): return {"meta": {"ok": True}} - monkeypatch.setattr("requests.post", lambda *a, **k: RespNoData()) + def fake_post(self, *a, **k): + return RespNoData() + + monkeypatch.setattr(requests.Session, "post", fake_post) emb = JinaEmbedding(api_key="k") with pytest.raises(KeyError): @@ -641,13 +644,13 @@ def test_openai_get_embeddings_calls_record_model_call(mocker): mock_ctx.__enter__ = mocker.MagicMock(return_value=None) mock_ctx.__exit__ = mocker.MagicMock(return_value=False) mock_record = mocker.patch( - "nexent.core.models.embedding_model.record_model_call", + "sdk.nexent.core.models.embedding_model.record_model_call", return_value=mock_ctx, ) mock_resp = Mock() mock_resp.raise_for_status = Mock() mock_resp.json.return_value = {"data": [{"embedding": [0.1, 0.2]}]} - mocker.patch("requests.post", return_value=mock_resp) + mocker.patch("sdk.nexent.core.models.embedding_model.requests.Session.post", return_value=mock_resp) emb = OpenAICompatibleEmbedding( model_name="text-emb-3", @@ -669,13 +672,13 @@ def test_jina_get_embeddings_calls_record_model_call(mocker): mock_ctx.__enter__ = mocker.MagicMock(return_value=None) mock_ctx.__exit__ = mocker.MagicMock(return_value=False) mock_record = mocker.patch( - "nexent.core.models.embedding_model.record_model_call", + "sdk.nexent.core.models.embedding_model.record_model_call", return_value=mock_ctx, ) mock_resp = Mock() mock_resp.raise_for_status = Mock() mock_resp.json.return_value = {"data": [{"embedding": [0.1, 0.2]}]} - mocker.patch("requests.post", return_value=mock_resp) + mocker.patch("sdk.nexent.core.models.embedding_model.requests.Session.post", return_value=mock_resp) emb = JinaEmbedding(api_key="k", ssl_verify=True) emb.get_multimodal_embeddings([{"text": "hi"}], with_metadata=False, timeout=5) diff --git a/test/sdk/core/models/test_openai_llm.py b/test/sdk/core/models/test_openai_llm.py index 0477a86a1..2589e1d6c 100644 --- a/test/sdk/core/models/test_openai_llm.py +++ b/test/sdk/core/models/test_openai_llm.py @@ -887,29 +887,33 @@ def test_init_with_ssl_verify_false(): observer = MagicMock() - # Mock DefaultHttpxClient from openai module - with patch("openai.DefaultHttpxClient") as mock_httpx_client: + # Mock httpx.Client directly (it's imported inside __init__) + with patch("httpx.Client") as mock_httpx_client: mock_httpx_client.return_value = MagicMock() # Create model with ssl_verify=False model = ImportedOpenAIModel(observer=observer, ssl_verify=False) - # Verify DefaultHttpxClient was called with verify=False - mock_httpx_client.assert_called_once_with(verify=False) + # Verify httpx.Client was called with verify=False + mock_httpx_client.assert_called_once() + call_kwargs = mock_httpx_client.call_args + assert call_kwargs.kwargs.get("verify") is False def test_init_with_ssl_verify_true(): - """Test __init__ method doesn't create http_client when ssl_verify=True (default)""" + """Test __init__ method creates http_client when ssl_verify=True (default)""" observer = MagicMock() - # Mock DefaultHttpxClient from openai module - with patch("openai.DefaultHttpxClient") as mock_httpx_client: + # Mock httpx.Client directly (it's imported inside __init__) + with patch("httpx.Client") as mock_httpx_client: # Create model with ssl_verify=True (default) model = ImportedOpenAIModel(observer=observer, ssl_verify=True) - # Verify DefaultHttpxClient was NOT called - mock_httpx_client.assert_not_called() + # Verify httpx.Client was called (it's always created, the verify param differs) + assert mock_httpx_client.call_count == 1 + call_kwargs = mock_httpx_client.call_args + assert call_kwargs.kwargs.get("verify") is True # --------------------------------------------------------------------------- From 6d63acb8c058f27cb425779e70302c88214e4e16 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 19 May 2026 11:02:48 +0800 Subject: [PATCH 34/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/backend/services/test_model_health_service.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/backend/services/test_model_health_service.py b/test/backend/services/test_model_health_service.py index 4434bda1e..7cc527372 100644 --- a/test/backend/services/test_model_health_service.py +++ b/test/backend/services/test_model_health_service.py @@ -169,7 +169,6 @@ async def test_perform_connectivity_check_embedding(): api_key="test-key", embedding_dim=0, ssl_verify=True, - timeout_seconds=None, ) mock_embedding_instance.dimension_check.assert_called_once() @@ -199,7 +198,6 @@ async def test_perform_connectivity_check_multi_embedding(): api_key="test-key", embedding_dim=0, ssl_verify=True, - timeout_seconds=None, ) mock_embedding_instance.dimension_check.assert_called_once() From ad4a25f3945232f9c5272e4eca78b1dff649debd Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 19 May 2026 11:19:00 +0800 Subject: [PATCH 35/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../services/test_model_health_service.py | 7 +-- test/backend/services/test_prompt_service.py | 61 +++++++++++++------ 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/test/backend/services/test_model_health_service.py b/test/backend/services/test_model_health_service.py index 7cc527372..b823c370b 100644 --- a/test/backend/services/test_model_health_service.py +++ b/test/backend/services/test_model_health_service.py @@ -446,9 +446,7 @@ async def test_check_model_connectivity_success(): "model123", {"connect_status": "available"}) mock_connectivity_check.assert_called_once_with( "openai/gpt-4", "llm", "https://api.openai.com", "test-key", True, - None, None, None, - display_name="GPT-4", - timeout_seconds=None, + None, None, None, "GPT-4", None, ) @@ -575,8 +573,7 @@ async def test_verify_model_config_connectivity_success(): mock_connectivity_check.assert_called_once_with( "gpt-4", "llm", "https://api.openai.com", "test-key", True, - None, None, None, "GPT-4", - timeout_seconds=None, + None, None, None, None, None, ) diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 27df68d07..771699313 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -1,26 +1,51 @@ import json +import os import sys import unittest +from importlib.machinery import ModuleSpec from unittest.mock import patch, MagicMock from consts.error_code import ErrorCode from consts.exceptions import AppException -# Mock database submodules before any imports -# These must be mocked before importing modules that depend on them -database_mock = MagicMock() + +def _make_package_mock(module_name: str, module_dir: str) -> MagicMock: + """Create a MagicMock that behaves like a Python package with proper __path__.""" + mock = MagicMock() + mock.__path__ = [module_dir] + mock.__package__ = module_name + mock.__spec__ = ModuleSpec(module_name, None) + mock.__loader__ = None + return mock + + +# Setup paths for real modules +backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../backend")) +database_dir = os.path.join(backend_dir, "database") +services_dir = os.path.join(backend_dir, "services") +utils_dir = os.path.join(backend_dir, "utils") +agents_dir = os.path.join(backend_dir, "agents") + +# Mock backend package and its sub-packages as proper packages +backend_mock = _make_package_mock("backend", backend_dir) +sys.modules['backend'] = backend_mock + +database_mock = _make_package_mock("database", database_dir) sys.modules['database'] = database_mock -sys.modules['database.model_management_db'] = MagicMock() -sys.modules['database.agent_db'] = MagicMock() -sys.modules['database.tool_db'] = MagicMock() -sys.modules['database.knowledge_db'] = MagicMock() -sys.modules['database.client'] = MagicMock() -sys.modules['database.attachment_db'] = MagicMock() -sys.modules['database.group_db'] = MagicMock() -sys.modules['database.user_tenant_db'] = MagicMock() -sys.modules['database.remote_mcp_db'] = MagicMock() -sys.modules['database.agent_version_db'] = MagicMock() -sys.modules['database.a2a_agent_db'] = MagicMock() -sys.modules['backend.database'] = MagicMock() +backend_mock.database = database_mock + +services_mock = _make_package_mock("services", services_dir) +sys.modules['services'] = services_mock + +utils_mock = _make_package_mock("utils", utils_dir) +sys.modules['utils'] = utils_mock +backend_mock.utils = utils_mock + +agents_mock = _make_package_mock("agents", agents_dir) +sys.modules['agents'] = agents_mock +backend_mock.agents = agents_mock + +# Mock backend.database submodules for patching +sys.modules['backend.database'] = database_mock sys.modules['backend.database.model_management_db'] = MagicMock() sys.modules['backend.database.agent_db'] = MagicMock() sys.modules['backend.database.tool_db'] = MagicMock() @@ -34,7 +59,6 @@ sys.modules['backend.database.a2a_agent_db'] = MagicMock() # Mock services submodules (NOT backend.services which blocks import of prompt_service) -sys.modules['services'] = MagicMock() sys.modules['services.agent_service'] = MagicMock() sys.modules['services.file_management_service'] = MagicMock() sys.modules['services.conversation_management_service'] = MagicMock() @@ -42,18 +66,15 @@ sys.modules['services.agent_version_service'] = MagicMock() # Mock agents submodules -sys.modules['agents'] = MagicMock() sys.modules['agents.create_agent_info'] = MagicMock() -sys.modules['backend.agents'] = MagicMock() sys.modules['backend.agents.create_agent_info'] = MagicMock() # Mock utils submodules to avoid llm_utils import triggering database connection -sys.modules['utils'] = MagicMock() sys.modules['utils.llm_utils'] = MagicMock() sys.modules['utils.prompt_template_utils'] = MagicMock() sys.modules['utils.config_utils'] = MagicMock() sys.modules['utils.auth_utils'] = MagicMock() -sys.modules['backend.utils'] = MagicMock() +sys.modules['utils.str_utils'] = MagicMock() sys.modules['backend.utils.llm_utils'] = MagicMock() sys.modules['backend.utils.prompt_template_utils'] = MagicMock() sys.modules['backend.utils.config_utils'] = MagicMock() From 81b6928f3875c2e28f58cfc099eeaeaf40c54dd1 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 19 May 2026 11:48:09 +0800 Subject: [PATCH 36/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/model_management_service.py | 2 +- .../services/test_model_management_service.py | 194 ++++++++++++------ test/backend/services/test_prompt_service.py | 2 +- 3 files changed, 134 insertions(+), 64 deletions(-) diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 1fce67b54..33b5e31db 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -27,7 +27,7 @@ ) from backend.utils.memory_utils import build_memory_config as build_memory_config_for_tenant from backend.services.vectordatabase_service import get_vector_db_core -from backend.nexent.memory.memory_service import clear_model_memories +from nexent.memory.memory_service import clear_model_memories logger = logging.getLogger("model_management_service") diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 6e504e90a..438598722 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -14,12 +14,48 @@ # Stub external modules required by consts.model before importing services -if "nexent" not in sys.modules: - sys.modules["nexent"] = mock.MagicMock() -if "nexent.core" not in sys.modules: - sys.modules["nexent.core"] = mock.MagicMock() -if "nexent.core.agents" not in sys.modules: - sys.modules["nexent.core.agents"] = mock.MagicMock() +# Use namespace packages that delegate to the real nexent package for proper submodule traversal +_real_nexent_base = os.path.abspath(os.path.join(current_dir, "../../../sdk/nexent")) + +nexent_pkg = types.ModuleType("nexent") +nexent_pkg.__path__ = [_real_nexent_base] +nexent_pkg.__name__ = "nexent" +sys.modules["nexent"] = nexent_pkg + +nexent_core_pkg = types.ModuleType("nexent.core") +nexent_core_pkg.__path__ = [os.path.join(_real_nexent_base, "core")] +nexent_core_pkg.__name__ = "nexent.core" +nexent_pkg.core = nexent_core_pkg + +# Add MessageObserver stub - it's imported from nexent.core +nexent_core_pkg.MessageObserver = type("MessageObserver", (), {}) + +sys.modules["nexent.core"] = nexent_core_pkg + +nexent_core_models_pkg = types.ModuleType("nexent.core.models") +nexent_core_models_pkg.__path__ = [os.path.join(_real_nexent_base, "core/models")] +nexent_core_models_pkg.__name__ = "nexent.core.models" +nexent_core_pkg.models = nexent_core_models_pkg + +# Import real models from the actual module and expose them +try: + from nexent.core.models import OpenAIModel, OpenAIVLModel + nexent_core_models_pkg.OpenAIModel = OpenAIModel + nexent_core_models_pkg.OpenAIVLModel = OpenAIVLModel +except ImportError: + # Fallback to stub classes if import fails + nexent_core_models_pkg.OpenAIModel = type("OpenAIModel", (), {}) + nexent_core_models_pkg.OpenAIVLModel = type("OpenAIVLModel", (), {}) + +sys.modules["nexent.core.models"] = nexent_core_models_pkg + +nexent_core_agents_pkg = types.ModuleType("nexent.core.agents") +nexent_core_agents_pkg.__path__ = [os.path.join(_real_nexent_base, "core/agents")] +nexent_core_agents_pkg.__name__ = "nexent.core.agents" +nexent_core_pkg.agents = nexent_core_agents_pkg +sys.modules["nexent.core.agents"] = nexent_core_agents_pkg + +# Stub nexent.core.agents.agent_model if "nexent.core.agents.agent_model" not in sys.modules: agent_model_mod = types.ModuleType("nexent.core.agents.agent_model") @@ -28,6 +64,7 @@ class ToolConfig: # minimal stub agent_model_mod.ToolConfig = ToolConfig sys.modules["nexent.core.agents.agent_model"] = agent_model_mod + nexent_core_agents_pkg.agent_model = agent_model_mod # Stub boto3 used by backend.database.client if "boto3" not in sys.modules: @@ -42,7 +79,32 @@ class _MinioClient: # minimal stub pass +def _as_dict(*args, **kwargs): + return {} + + +def _get_db_session(*args, **kwargs): + class _MockSession: + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def execute(self, *args, **kwargs): + class _Result: + rowcount = 0 + + return _Result() + + return _MockSession() + + backend_db_client_mod.MinioClient = _MinioClient +backend_db_client_mod.as_dict = _as_dict +backend_db_client_mod.get_db_session = _get_db_session +backend_db_client_mod.db_client = mock.MagicMock() +backend_db_client_mod.minio_client = mock.MagicMock() sys.modules["backend.database.client"] = backend_db_client_mod # Ensure parent package exposes the submodule attribute for import machinery @@ -58,76 +120,67 @@ class _MinioClient: # minimal stub # Also stub database.client.MinioClient in case modules import without the 'backend.' prefix database_client_mod = types.ModuleType("database.client") database_client_mod.MinioClient = _MinioClient +database_client_mod.as_dict = _as_dict +database_client_mod.get_db_session = _get_db_session +database_client_mod.db_client = mock.MagicMock() +database_client_mod.minio_client = mock.MagicMock() sys.modules["database.client"] = database_client_mod if "database" in sys.modules: setattr(sys.modules["database"], "client", database_client_mod) -# Stub consts.model to avoid deep dependencies -consts_model_mod = types.ModuleType("consts.model") - - -class _EnumItem: - def __init__(self, value: str): - self.value = value - - -class _ModelConnectStatusEnum: - OPERATIONAL = _EnumItem("operational") - NOT_DETECTED = _EnumItem("not_detected") - DETECTING = _EnumItem("detecting") - UNAVAILABLE = _EnumItem("unavailable") - - @staticmethod - def get_value(status): - return status or _ModelConnectStatusEnum.NOT_DETECTED.value - - -consts_model_mod.ModelConnectStatusEnum = _ModelConnectStatusEnum -sys.modules["consts.model"] = consts_model_mod -if "consts" not in sys.modules: - sys.modules["consts"] = types.ModuleType("consts") - -# Stub consts.const required by service -consts_const_mod = types.ModuleType("consts.const") -consts_const_mod.LOCALHOST_IP = "127.0.0.1" -consts_const_mod.LOCALHOST_NAME = "localhost" -consts_const_mod.DOCKER_INTERNAL_HOST = "host.docker.internal" -# Fields required by utils.memory_utils and services.vectordatabase_service -consts_const_mod.MODEL_CONFIG_MAPPING = { - "llm": "LLM_ID", "embedding": "EMBEDDING_ID"} -consts_const_mod.ES_HOST = "http://localhost:9200" -consts_const_mod.ES_API_KEY = "" -consts_const_mod.ES_USERNAME = "" -consts_const_mod.ES_PASSWORD = "" -sys.modules["consts.const"] = consts_const_mod +# Make consts a namespace package that can delegate to real backend/consts +_real_consts_path = os.path.abspath(os.path.join(current_dir, "../../../backend/consts")) +consts_pkg = types.ModuleType("consts") +consts_pkg.__path__ = [_real_consts_path] +consts_pkg.__name__ = "consts" +sys.modules["consts"] = consts_pkg + +# Import real consts.const from backend and register it +import importlib.util +_real_consts_spec = importlib.util.spec_from_file_location( + "consts.const", os.path.join(_real_consts_path, "const.py") +) +_real_consts_mod = importlib.util.module_from_spec(_real_consts_spec) +sys.modules["consts.const"] = _real_consts_mod +_real_consts_spec.loader.exec_module(_real_consts_mod) + +# Import real consts.model from backend and register it +_real_model_spec = importlib.util.spec_from_file_location( + "consts.model", os.path.join(_real_consts_path, "model.py") +) +_real_model_mod = importlib.util.module_from_spec(_real_model_spec) +sys.modules["consts.model"] = _real_model_mod +_real_model_spec.loader.exec_module(_real_model_mod) # Stub sqlalchemy.sql.func used by utils.config_utils -sqlalchemy_sql_mod = types.ModuleType("sqlalchemy.sql") +# Import the real module first to preserve SQLAlchemy's internal imports, +# then add the func attribute if it doesn't exist +import sqlalchemy.sql as _real_sql - -class _Func: - pass - - -sqlalchemy_sql_mod.func = _Func() -sys.modules["sqlalchemy.sql"] = sqlalchemy_sql_mod +if not hasattr(_real_sql, "func") or _real_sql.func is None: + _real_sql.func = types.ModuleType("sqlalchemy.sql.func") +sys.modules["sqlalchemy.sql"] = _real_sql +sys.modules["sqlalchemy.sql.func"] = _real_sql.func # Stub consts.provider used by service consts_provider_mod = types.ModuleType("consts.provider") class _ProviderEnum: - SILICON = _EnumItem("silicon") - MODELENGINE = _EnumItem("modelengine") - DASHSCOPE = _EnumItem("dashscope") - TOKENPONY = _EnumItem("tokenpony") + SILICON = "silicon" + MODELENGINE = "modelengine" + DASHSCOPE = "dashscope" + TOKENPONY = "tokenpony" consts_provider_mod.ProviderEnum = _ProviderEnum consts_provider_mod.SILICON_BASE_URL = "http://silicon.test" +consts_provider_mod.SILICON_GET_URL = "http://silicon.test/models" consts_provider_mod.DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/" +consts_provider_mod.DASHSCOPE_GET_URL = "https://dashscope.aliyuncs.com/api/v1/models" consts_provider_mod.TOKENPONY_BASE_URL = "https://api.tokenpony.cn/v1/" +consts_provider_mod.TOKENPONY_GET_URL = "https://api.tokenpony.cn/v1/models" sys.modules["consts.provider"] = consts_provider_mod # Stub services.model_provider_service used by service @@ -189,7 +242,10 @@ def _sort_models_by_id(model_list): sys.modules["utils.model_name_utils"] = utils_name_mod # Stub database.model_management_db to avoid importing heavy DB client +_real_database_path = os.path.abspath(os.path.join(current_dir, "../../../backend/database")) database_mod = types.ModuleType("database") +database_mod.__path__ = [_real_database_path] +database_mod.__name__ = "database" db_mm_mod = types.ModuleType("database.model_management_db") @@ -210,6 +266,12 @@ def _get_models_by_display_name(*args, **kwargs): return [] +def _get_model_by_name_factory(*args, **kwargs): + """Return None for model name factory lookups in tests.""" + return None + + +db_mm_mod.get_model_by_name_factory = _get_model_by_name_factory db_mm_mod.create_model_record = _noop db_mm_mod.delete_model_record = _noop db_mm_mod.get_model_by_display_name = _noop @@ -232,8 +294,10 @@ def _get_model_by_model_id(model_id: int, tenant_id: str): db_mm_mod.get_model_by_model_id = _get_model_by_model_id db_mm_mod.update_model_record = _noop +db_mm_mod.model_management_db = db_mm_mod sys.modules["database"] = database_mod sys.modules["database.model_management_db"] = db_mm_mod +sys.modules["backend.database.model_management_db"] = db_mm_mod # Stub database.tenant_config_db required by utils.config_utils db_tenant_cfg_mod = types.ModuleType("database.tenant_config_db") @@ -420,7 +484,8 @@ async def test_create_model_for_tenant_multi_embedding_creates_two_records(): with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ mock.patch.object(svc, "create_model_record") as mock_create, \ - mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")): + mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")), \ + mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=1536)): user_id = "u1" tenant_id = "t1" @@ -536,7 +601,7 @@ async def test_create_provider_models_for_tenant_success(): models = [{"id": "silicon/a"}, {"id": "silicon/b"}] with mock.patch.object(svc, "get_provider_models", new=mock.AsyncMock(return_value=models)) as mock_get, \ - mock.patch.object(svc, "merge_existing_model_tokens", return_value=models) as mock_merge, \ + mock.patch.object(svc, "merge_existing_model_attributes", return_value=models) as mock_merge, \ mock.patch.object(svc, "sort_models_by_id", side_effect=lambda m: m) as mock_sort: out = await svc.create_provider_models_for_tenant("t1", req) @@ -867,18 +932,23 @@ async def test_batch_update_models_for_tenant_success(): svc = import_svc() models = [{"model_id": "a"}, {"model_id": "b"}] - with mock.patch.object(svc, "update_model_record") as mock_update: + # Mock get_model_by_name_factory to return valid model records + mock_model_record = {"model_id": 1} + with mock.patch.object(svc, "update_model_record") as mock_update, \ + mock.patch("backend.database.model_management_db.get_model_by_name_factory", return_value=mock_model_record): await svc.batch_update_models_for_tenant("u1", "t1", models) assert mock_update.call_count == 2 - mock_update.assert_any_call("a", models[0], "u1", "t1") - mock_update.assert_any_call("b", models[1], "u1", "t1") + # update_data is models[i] with model_id/model_name excluded -> {} + mock_update.assert_any_call(1, {}, "u1", "t1") async def test_batch_update_models_for_tenant_exception(): svc = import_svc() models = [{"model_id": "a"}] - with mock.patch.object(svc, "update_model_record", side_effect=Exception("oops")): + mock_model_record = {"model_id": 1} + with mock.patch.object(svc, "update_model_record", side_effect=Exception("oops")), \ + mock.patch("backend.database.model_management_db.get_model_by_name_factory", return_value=mock_model_record): with pytest.raises(Exception) as exc: await svc.batch_update_models_for_tenant("u1", "t1", models) assert "Failed to batch update models" in str(exc.value) diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 771699313..e0502dc83 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -749,7 +749,7 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): expected_data = f"data: {json.dumps({'success': True, 'data': test_data[i]}, ensure_ascii=False)}\n\n" self.assertEqual(result, expected_data) - @patch('database.model_management_db.get_model_by_model_id') + @patch('backend.services.prompt_service.get_model_by_model_id', create=True) @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') From d13c107266eea8e64ad51f691e97cd06ccd5f81c Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 19 May 2026 12:42:58 +0800 Subject: [PATCH 37/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/model_management_service.py | 3 +-- test/backend/services/test_model_management_service.py | 5 ++--- test/backend/services/test_prompt_service.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 33b5e31db..5cfc0a59e 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -12,6 +12,7 @@ get_models_by_display_name, get_model_records, get_models_by_tenant_factory_type, + get_model_by_name_factory, update_model_record, ) from backend.services.model_provider_service import ( @@ -292,8 +293,6 @@ async def update_single_model_for_tenant( async def batch_update_models_for_tenant(user_id: str, tenant_id: str, model_list: List[Dict[str, Any]]): """Batch update models for a tenant by model_id or model_name.""" - from backend.database.model_management_db import get_model_by_name_factory - try: for model in model_list: # Build update data excluding id fields diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 438598722..3bc2a7c89 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -932,10 +932,9 @@ async def test_batch_update_models_for_tenant_success(): svc = import_svc() models = [{"model_id": "a"}, {"model_id": "b"}] - # Mock get_model_by_name_factory to return valid model records mock_model_record = {"model_id": 1} with mock.patch.object(svc, "update_model_record") as mock_update, \ - mock.patch("backend.database.model_management_db.get_model_by_name_factory", return_value=mock_model_record): + mock.patch.object(svc, "get_model_by_name_factory", return_value=mock_model_record): await svc.batch_update_models_for_tenant("u1", "t1", models) assert mock_update.call_count == 2 # update_data is models[i] with model_id/model_name excluded -> {} @@ -948,7 +947,7 @@ async def test_batch_update_models_for_tenant_exception(): models = [{"model_id": "a"}] mock_model_record = {"model_id": 1} with mock.patch.object(svc, "update_model_record", side_effect=Exception("oops")), \ - mock.patch("backend.database.model_management_db.get_model_by_name_factory", return_value=mock_model_record): + mock.patch.object(svc, "get_model_by_name_factory", return_value=mock_model_record): with pytest.raises(Exception) as exc: await svc.batch_update_models_for_tenant("u1", "t1", models) assert "Failed to batch update models" in str(exc.value) diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index e0502dc83..771699313 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -749,7 +749,7 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): expected_data = f"data: {json.dumps({'success': True, 'data': test_data[i]}, ensure_ascii=False)}\n\n" self.assertEqual(result, expected_data) - @patch('backend.services.prompt_service.get_model_by_model_id', create=True) + @patch('database.model_management_db.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') From 42b28b8f60a8c19440bfca95e4f0d5bd9ee12581 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 19 May 2026 14:09:24 +0800 Subject: [PATCH 38/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/prompt_service.py | 2 +- test/backend/services/test_prompt_service.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/backend/services/prompt_service.py b/backend/services/prompt_service.py index 4f26eb40a..a34922706 100644 --- a/backend/services/prompt_service.py +++ b/backend/services/prompt_service.py @@ -12,6 +12,7 @@ from consts.exceptions import AppException from database.agent_db import search_agent_info_by_agent_id, query_all_agent_info_by_tenant_id, \ query_sub_agents_id_list +from database.model_management_db import get_model_by_model_id from database.knowledge_db import get_knowledge_name_map_by_index_names from database.tool_db import query_tools_by_ids, query_tool_instances_by_id from services.agent_service import ( @@ -360,7 +361,6 @@ def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list # Get model concurrency limit to control the number of concurrent LLM calls # If None or >= 6, no limit (all 6 calls run concurrently) # If < 6, use semaphore to limit concurrent calls - from database.model_management_db import get_model_by_model_id model_config = get_model_by_model_id(model_id, tenant_id) concurrency_limit = model_config.get("concurrency_limit") if model_config else None diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 771699313..040a4c9db 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -749,7 +749,7 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): expected_data = f"data: {json.dumps({'success': True, 'data': test_data[i]}, ensure_ascii=False)}\n\n" self.assertEqual(result, expected_data) - @patch('database.model_management_db.get_model_by_model_id') + @patch('backend.services.prompt_service.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') @@ -869,7 +869,7 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): self.assertIsInstance(result["is_complete"], bool) self.assertIsInstance(result["content"], str) - @patch('database.model_management_db.get_model_by_model_id') + @patch('backend.services.prompt_service.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') @@ -1169,7 +1169,7 @@ def mock_gen(*args, **kwargs): self.assertIn("Failed to generate prompt content", str(context.exception)) - @patch('database.model_management_db.get_model_by_model_id') + @patch('backend.services.prompt_service.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') @@ -1222,7 +1222,7 @@ def mock_llm_call_error(model_id, content, sys_prompt, callback, tenant_id): self.assertIn("LLM connection error", str(context.exception)) - @patch('database.model_management_db.get_model_by_model_id') + @patch('backend.services.prompt_service.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') From 4f3c6c679da9c10158ec90b9555fcec8ff2412f6 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Wed, 20 May 2026 17:05:54 +0800 Subject: [PATCH 39/48] =?UTF-8?q?=E5=88=A0=E9=99=A4=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E5=AE=BF=E4=B8=BB=E6=9C=BA=E7=9A=84=E8=AF=81=E4=B9=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docker/docker-compose.prod.yml | 2 -- docker/docker-compose.yml | 2 -- 2 files changed, 4 deletions(-) diff --git a/docker/docker-compose.prod.yml b/docker/docker-compose.prod.yml index 3cc7ac59a..934fe8b2f 100644 --- a/docker/docker-compose.prod.yml +++ b/docker/docker-compose.prod.yml @@ -78,8 +78,6 @@ services: - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro - ${ROOT_DIR}/scripts/sync_user_supabase2pg.py:/opt/sync_user_supabase2pg.py:ro - /var/run/docker.sock:/var/run/docker.sock:ro # Docker socket for MCP container management - # CA certificates for external service SSL verification (e.g., SMTP) - - /etc/ssl/certs:/etc/ssl/certs:ro environment: <<: [*minio-vars, *es-vars] skip_proxy: "true" diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 4056683dc..89088f2c3 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -89,8 +89,6 @@ services: - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro - ${ROOT_DIR}/scripts/sync_user_supabase2pg.py:/opt/sync_user_supabase2pg.py:ro - /var/run/docker.sock:/var/run/docker.sock:ro # Docker socket for MCP container management - # CA certificates for external service SSL verification (e.g., SMTP) - - /etc/ssl/certs:/etc/ssl/certs:ro environment: <<: [*minio-vars, *es-vars] skip_proxy: "true" From f6eefe5a318f5860402723fd7aad587be177db0d Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Wed, 20 May 2026 17:09:04 +0800 Subject: [PATCH 40/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9sql?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docker/init.sql | 4 ++++ ....0_0520_add_concurrency_and_timeout_to_model_record_t.sql} | 0 k8s/helm/nexent/charts/nexent-common/files/init.sql | 4 ++++ 3 files changed, 8 insertions(+) rename docker/sql/{v2.1.1_0507_add_concurrency_and_timeout_to_model_record_t.sql => v2.2.0_0520_add_concurrency_and_timeout_to_model_record_t.sql} (100%) diff --git a/docker/init.sql b/docker/init.sql index ed45026ab..aadaa044b 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -177,6 +177,8 @@ CREATE TABLE IF NOT EXISTS "model_record_t" ( "tenant_id" varchar(100) COLLATE "pg_catalog"."default" DEFAULT 'tenant_id', "model_appid" varchar(100) COLLATE "pg_catalog"."default" DEFAULT '', "access_token" varchar(100) COLLATE "pg_catalog"."default" DEFAULT '', + "concurrency_limit" INTEGER DEFAULT NULL, + "timeout_seconds" INTEGER DEFAULT 120, CONSTRAINT "nexent_models_t_pk" PRIMARY KEY ("model_id") ); ALTER TABLE "model_record_t" OWNER TO "root"; @@ -202,6 +204,8 @@ COMMENT ON COLUMN "model_record_t"."created_by" IS 'Creator ID, audit field'; COMMENT ON COLUMN "model_record_t"."tenant_id" IS 'Tenant ID for filtering'; COMMENT ON COLUMN "model_record_t"."model_appid" IS 'Application ID for model authentication.'; COMMENT ON COLUMN "model_record_t"."access_token" IS 'Access token for model authentication.'; +COMMENT ON COLUMN "model_record_t"."concurrency_limit" IS 'Maximum concurrent requests for this model. Default is NULL (unlimited).'; +COMMENT ON COLUMN "model_record_t"."timeout_seconds" IS 'Request timeout in seconds for this model. Default is 120 seconds.'; COMMENT ON TABLE "model_record_t" IS 'List of models defined by users in the configuration page'; INSERT INTO "nexent"."model_record_t" ("model_repo", "model_name", "model_factory", "model_type", "api_key", "base_url", "max_tokens", "used_token", "display_name", "connect_status") VALUES ('', 'volcano_tts', 'OpenAI-API-Compatible', 'tts', '', '', 0, 0, 'volcano_tts', 'unavailable'); diff --git a/docker/sql/v2.1.1_0507_add_concurrency_and_timeout_to_model_record_t.sql b/docker/sql/v2.2.0_0520_add_concurrency_and_timeout_to_model_record_t.sql similarity index 100% rename from docker/sql/v2.1.1_0507_add_concurrency_and_timeout_to_model_record_t.sql rename to docker/sql/v2.2.0_0520_add_concurrency_and_timeout_to_model_record_t.sql diff --git a/k8s/helm/nexent/charts/nexent-common/files/init.sql b/k8s/helm/nexent/charts/nexent-common/files/init.sql index 453a7dcbb..9558f5afd 100644 --- a/k8s/helm/nexent/charts/nexent-common/files/init.sql +++ b/k8s/helm/nexent/charts/nexent-common/files/init.sql @@ -177,6 +177,8 @@ CREATE TABLE IF NOT EXISTS "model_record_t" ( "tenant_id" varchar(100) COLLATE "pg_catalog"."default" DEFAULT 'tenant_id', "model_appid" varchar(100) COLLATE "pg_catalog"."default" DEFAULT '', "access_token" varchar(100) COLLATE "pg_catalog"."default" DEFAULT '', + "concurrency_limit" INTEGER DEFAULT NULL, + "timeout_seconds" INTEGER DEFAULT 120, CONSTRAINT "nexent_models_t_pk" PRIMARY KEY ("model_id") ); ALTER TABLE "model_record_t" OWNER TO "root"; @@ -202,6 +204,8 @@ COMMENT ON COLUMN "model_record_t"."created_by" IS 'Creator ID, audit field'; COMMENT ON COLUMN "model_record_t"."tenant_id" IS 'Tenant ID for filtering'; COMMENT ON COLUMN "model_record_t"."model_appid" IS 'Application ID for model authentication.'; COMMENT ON COLUMN "model_record_t"."access_token" IS 'Access token for model authentication.'; +COMMENT ON COLUMN "model_record_t"."concurrency_limit" IS 'Maximum concurrent requests for this model. Default is NULL (unlimited).'; +COMMENT ON COLUMN "model_record_t"."timeout_seconds" IS 'Request timeout in seconds for this model. Default is 120 seconds.'; COMMENT ON TABLE "model_record_t" IS 'List of models defined by users in the configuration page'; INSERT INTO "nexent"."model_record_t" ("model_repo", "model_name", "model_factory", "model_type", "api_key", "base_url", "max_tokens", "used_token", "display_name", "connect_status") VALUES ('', 'volcano_tts', 'OpenAI-API-Compatible', 'tts', '', '', 0, 0, 'volcano_tts', 'unavailable'); From 05226b3468e637060cf89ff1ddc501de462c66bf Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Wed, 20 May 2026 17:17:45 +0800 Subject: [PATCH 41/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9UT?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/backend/services/test_prompt_service.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 7aec9e0de..b25baf708 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -755,8 +755,8 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): @patch('backend.services.prompt_service.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') - @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') - def test_generate_system_prompt(self, mock_get_prompt_template, mock_join_info, mock_call_llm, mock_get_model): + @patch('backend.services.prompt_service.resolve_prompt_generate_template') + def test_generate_system_prompt(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm, mock_get_model): # Mock model config to avoid concurrency limit issue mock_get_model.return_value = {"concurrency_limit": None} # Setup @@ -881,8 +881,8 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): @patch('backend.services.prompt_service.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') - @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') - def test_generate_system_prompt_with_exception(self, mock_get_prompt_template, mock_join_info, mock_call_llm, mock_get_model): + @patch('backend.services.prompt_service.resolve_prompt_generate_template') + def test_generate_system_prompt_with_exception(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm, mock_get_model): # Mock model config to avoid concurrency limit issue mock_get_model.return_value = {"concurrency_limit": None} # Setup From 90a41f031facd97e0248526c717a57ac2a61a953 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Thu, 21 May 2026 09:26:40 +0800 Subject: [PATCH 42/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E6=98=8A=E5=A4=A9=E5=85=AC=E5=85=B1=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?Id?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/haotian_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/services/haotian_service.py b/backend/services/haotian_service.py index e7f762244..4d86823b5 100644 --- a/backend/services/haotian_service.py +++ b/backend/services/haotian_service.py @@ -11,7 +11,7 @@ logger = logging.getLogger("haotian_service") -_DEFAULT_KNOWLEDGE_BASE_ID = "abcdefg" +_DEFAULT_KNOWLEDGE_BASE_ID = "a8d68fbf-bd6e-5461-a9d1-cf1bb3522e38" def _normalize_list_payload(raw: Dict[str, Any]) -> Dict[str, Any]: From b073df43e86da7d216f0e4d72a390b09b196f462 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Thu, 21 May 2026 11:49:00 +0800 Subject: [PATCH 43/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/model_management_service.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 5cfc0a59e..d46178d50 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -1,33 +1,33 @@ import logging from typing import List, Dict, Any, Optional -from backend.consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST -from backend.consts.model import ModelConnectStatusEnum -from backend.consts.provider import ProviderEnum, SILICON_BASE_URL, DASHSCOPE_BASE_URL, TOKENPONY_BASE_URL +from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST +from consts.model import ModelConnectStatusEnum +from consts.provider import ProviderEnum, SILICON_BASE_URL, DASHSCOPE_BASE_URL, TOKENPONY_BASE_URL -from backend.database.model_management_db import ( +from database.model_management_db import ( create_model_record, delete_model_record, get_model_by_display_name, get_models_by_display_name, get_model_records, get_models_by_tenant_factory_type, - get_model_by_name_factory, update_model_record, + update_model_record_by_model_name, ) -from backend.services.model_provider_service import ( +from services.model_provider_service import ( prepare_model_dict, merge_existing_model_attributes, get_provider_models, ) -from backend.services.model_health_service import embedding_dimension_check -from backend.utils.model_name_utils import ( +from services.model_health_service import embedding_dimension_check +from utils.model_name_utils import ( add_repo_to_name, split_repo_name, sort_models_by_id, ) -from backend.utils.memory_utils import build_memory_config as build_memory_config_for_tenant -from backend.services.vectordatabase_service import get_vector_db_core +from utils.memory_utils import build_memory_config as build_memory_config_for_tenant +from services.vectordatabase_service import get_vector_db_core from nexent.memory.memory_service import clear_model_memories logger = logging.getLogger("model_management_service") From 9e05f77ee9169a02374ade6d77610354d0672094 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Thu, 21 May 2026 14:51:39 +0800 Subject: [PATCH 44/48] =?UTF-8?q?=E5=88=A0=E9=99=A4=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/database/model_management_db.py | 2 +- .../services/test_model_management_service.py | 193 ++++++------------ test/backend/services/test_prompt_service.py | 95 +-------- test/sdk/core/models/test_embedding_model.py | 73 ++++--- test/sdk/core/models/test_openai_llm.py | 22 +- 5 files changed, 110 insertions(+), 275 deletions(-) diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py index d260ec48c..6ae401931 100644 --- a/backend/database/model_management_db.py +++ b/backend/database/model_management_db.py @@ -3,7 +3,7 @@ from sqlalchemy import and_, desc, func, insert, select, update -from backend.consts.const import DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE +from consts.const import DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE from .client import as_dict, db_client, get_db_session from .db_models import ModelRecord from .utils import add_creation_tracking, add_update_tracking diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 3bc2a7c89..6e504e90a 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -14,48 +14,12 @@ # Stub external modules required by consts.model before importing services -# Use namespace packages that delegate to the real nexent package for proper submodule traversal -_real_nexent_base = os.path.abspath(os.path.join(current_dir, "../../../sdk/nexent")) - -nexent_pkg = types.ModuleType("nexent") -nexent_pkg.__path__ = [_real_nexent_base] -nexent_pkg.__name__ = "nexent" -sys.modules["nexent"] = nexent_pkg - -nexent_core_pkg = types.ModuleType("nexent.core") -nexent_core_pkg.__path__ = [os.path.join(_real_nexent_base, "core")] -nexent_core_pkg.__name__ = "nexent.core" -nexent_pkg.core = nexent_core_pkg - -# Add MessageObserver stub - it's imported from nexent.core -nexent_core_pkg.MessageObserver = type("MessageObserver", (), {}) - -sys.modules["nexent.core"] = nexent_core_pkg - -nexent_core_models_pkg = types.ModuleType("nexent.core.models") -nexent_core_models_pkg.__path__ = [os.path.join(_real_nexent_base, "core/models")] -nexent_core_models_pkg.__name__ = "nexent.core.models" -nexent_core_pkg.models = nexent_core_models_pkg - -# Import real models from the actual module and expose them -try: - from nexent.core.models import OpenAIModel, OpenAIVLModel - nexent_core_models_pkg.OpenAIModel = OpenAIModel - nexent_core_models_pkg.OpenAIVLModel = OpenAIVLModel -except ImportError: - # Fallback to stub classes if import fails - nexent_core_models_pkg.OpenAIModel = type("OpenAIModel", (), {}) - nexent_core_models_pkg.OpenAIVLModel = type("OpenAIVLModel", (), {}) - -sys.modules["nexent.core.models"] = nexent_core_models_pkg - -nexent_core_agents_pkg = types.ModuleType("nexent.core.agents") -nexent_core_agents_pkg.__path__ = [os.path.join(_real_nexent_base, "core/agents")] -nexent_core_agents_pkg.__name__ = "nexent.core.agents" -nexent_core_pkg.agents = nexent_core_agents_pkg -sys.modules["nexent.core.agents"] = nexent_core_agents_pkg - -# Stub nexent.core.agents.agent_model +if "nexent" not in sys.modules: + sys.modules["nexent"] = mock.MagicMock() +if "nexent.core" not in sys.modules: + sys.modules["nexent.core"] = mock.MagicMock() +if "nexent.core.agents" not in sys.modules: + sys.modules["nexent.core.agents"] = mock.MagicMock() if "nexent.core.agents.agent_model" not in sys.modules: agent_model_mod = types.ModuleType("nexent.core.agents.agent_model") @@ -64,7 +28,6 @@ class ToolConfig: # minimal stub agent_model_mod.ToolConfig = ToolConfig sys.modules["nexent.core.agents.agent_model"] = agent_model_mod - nexent_core_agents_pkg.agent_model = agent_model_mod # Stub boto3 used by backend.database.client if "boto3" not in sys.modules: @@ -79,32 +42,7 @@ class _MinioClient: # minimal stub pass -def _as_dict(*args, **kwargs): - return {} - - -def _get_db_session(*args, **kwargs): - class _MockSession: - def __enter__(self): - return self - - def __exit__(self, *args): - pass - - def execute(self, *args, **kwargs): - class _Result: - rowcount = 0 - - return _Result() - - return _MockSession() - - backend_db_client_mod.MinioClient = _MinioClient -backend_db_client_mod.as_dict = _as_dict -backend_db_client_mod.get_db_session = _get_db_session -backend_db_client_mod.db_client = mock.MagicMock() -backend_db_client_mod.minio_client = mock.MagicMock() sys.modules["backend.database.client"] = backend_db_client_mod # Ensure parent package exposes the submodule attribute for import machinery @@ -120,67 +58,76 @@ class _Result: # Also stub database.client.MinioClient in case modules import without the 'backend.' prefix database_client_mod = types.ModuleType("database.client") database_client_mod.MinioClient = _MinioClient -database_client_mod.as_dict = _as_dict -database_client_mod.get_db_session = _get_db_session -database_client_mod.db_client = mock.MagicMock() -database_client_mod.minio_client = mock.MagicMock() sys.modules["database.client"] = database_client_mod if "database" in sys.modules: setattr(sys.modules["database"], "client", database_client_mod) -# Make consts a namespace package that can delegate to real backend/consts -_real_consts_path = os.path.abspath(os.path.join(current_dir, "../../../backend/consts")) -consts_pkg = types.ModuleType("consts") -consts_pkg.__path__ = [_real_consts_path] -consts_pkg.__name__ = "consts" -sys.modules["consts"] = consts_pkg - -# Import real consts.const from backend and register it -import importlib.util -_real_consts_spec = importlib.util.spec_from_file_location( - "consts.const", os.path.join(_real_consts_path, "const.py") -) -_real_consts_mod = importlib.util.module_from_spec(_real_consts_spec) -sys.modules["consts.const"] = _real_consts_mod -_real_consts_spec.loader.exec_module(_real_consts_mod) - -# Import real consts.model from backend and register it -_real_model_spec = importlib.util.spec_from_file_location( - "consts.model", os.path.join(_real_consts_path, "model.py") -) -_real_model_mod = importlib.util.module_from_spec(_real_model_spec) -sys.modules["consts.model"] = _real_model_mod -_real_model_spec.loader.exec_module(_real_model_mod) +# Stub consts.model to avoid deep dependencies +consts_model_mod = types.ModuleType("consts.model") + + +class _EnumItem: + def __init__(self, value: str): + self.value = value + + +class _ModelConnectStatusEnum: + OPERATIONAL = _EnumItem("operational") + NOT_DETECTED = _EnumItem("not_detected") + DETECTING = _EnumItem("detecting") + UNAVAILABLE = _EnumItem("unavailable") + + @staticmethod + def get_value(status): + return status or _ModelConnectStatusEnum.NOT_DETECTED.value + + +consts_model_mod.ModelConnectStatusEnum = _ModelConnectStatusEnum +sys.modules["consts.model"] = consts_model_mod +if "consts" not in sys.modules: + sys.modules["consts"] = types.ModuleType("consts") + +# Stub consts.const required by service +consts_const_mod = types.ModuleType("consts.const") +consts_const_mod.LOCALHOST_IP = "127.0.0.1" +consts_const_mod.LOCALHOST_NAME = "localhost" +consts_const_mod.DOCKER_INTERNAL_HOST = "host.docker.internal" +# Fields required by utils.memory_utils and services.vectordatabase_service +consts_const_mod.MODEL_CONFIG_MAPPING = { + "llm": "LLM_ID", "embedding": "EMBEDDING_ID"} +consts_const_mod.ES_HOST = "http://localhost:9200" +consts_const_mod.ES_API_KEY = "" +consts_const_mod.ES_USERNAME = "" +consts_const_mod.ES_PASSWORD = "" +sys.modules["consts.const"] = consts_const_mod # Stub sqlalchemy.sql.func used by utils.config_utils -# Import the real module first to preserve SQLAlchemy's internal imports, -# then add the func attribute if it doesn't exist -import sqlalchemy.sql as _real_sql +sqlalchemy_sql_mod = types.ModuleType("sqlalchemy.sql") -if not hasattr(_real_sql, "func") or _real_sql.func is None: - _real_sql.func = types.ModuleType("sqlalchemy.sql.func") -sys.modules["sqlalchemy.sql"] = _real_sql -sys.modules["sqlalchemy.sql.func"] = _real_sql.func + +class _Func: + pass + + +sqlalchemy_sql_mod.func = _Func() +sys.modules["sqlalchemy.sql"] = sqlalchemy_sql_mod # Stub consts.provider used by service consts_provider_mod = types.ModuleType("consts.provider") class _ProviderEnum: - SILICON = "silicon" - MODELENGINE = "modelengine" - DASHSCOPE = "dashscope" - TOKENPONY = "tokenpony" + SILICON = _EnumItem("silicon") + MODELENGINE = _EnumItem("modelengine") + DASHSCOPE = _EnumItem("dashscope") + TOKENPONY = _EnumItem("tokenpony") consts_provider_mod.ProviderEnum = _ProviderEnum consts_provider_mod.SILICON_BASE_URL = "http://silicon.test" -consts_provider_mod.SILICON_GET_URL = "http://silicon.test/models" consts_provider_mod.DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/" -consts_provider_mod.DASHSCOPE_GET_URL = "https://dashscope.aliyuncs.com/api/v1/models" consts_provider_mod.TOKENPONY_BASE_URL = "https://api.tokenpony.cn/v1/" -consts_provider_mod.TOKENPONY_GET_URL = "https://api.tokenpony.cn/v1/models" sys.modules["consts.provider"] = consts_provider_mod # Stub services.model_provider_service used by service @@ -242,10 +189,7 @@ def _sort_models_by_id(model_list): sys.modules["utils.model_name_utils"] = utils_name_mod # Stub database.model_management_db to avoid importing heavy DB client -_real_database_path = os.path.abspath(os.path.join(current_dir, "../../../backend/database")) database_mod = types.ModuleType("database") -database_mod.__path__ = [_real_database_path] -database_mod.__name__ = "database" db_mm_mod = types.ModuleType("database.model_management_db") @@ -266,12 +210,6 @@ def _get_models_by_display_name(*args, **kwargs): return [] -def _get_model_by_name_factory(*args, **kwargs): - """Return None for model name factory lookups in tests.""" - return None - - -db_mm_mod.get_model_by_name_factory = _get_model_by_name_factory db_mm_mod.create_model_record = _noop db_mm_mod.delete_model_record = _noop db_mm_mod.get_model_by_display_name = _noop @@ -294,10 +232,8 @@ def _get_model_by_model_id(model_id: int, tenant_id: str): db_mm_mod.get_model_by_model_id = _get_model_by_model_id db_mm_mod.update_model_record = _noop -db_mm_mod.model_management_db = db_mm_mod sys.modules["database"] = database_mod sys.modules["database.model_management_db"] = db_mm_mod -sys.modules["backend.database.model_management_db"] = db_mm_mod # Stub database.tenant_config_db required by utils.config_utils db_tenant_cfg_mod = types.ModuleType("database.tenant_config_db") @@ -484,8 +420,7 @@ async def test_create_model_for_tenant_multi_embedding_creates_two_records(): with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ mock.patch.object(svc, "create_model_record") as mock_create, \ - mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")), \ - mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=1536)): + mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")): user_id = "u1" tenant_id = "t1" @@ -601,7 +536,7 @@ async def test_create_provider_models_for_tenant_success(): models = [{"id": "silicon/a"}, {"id": "silicon/b"}] with mock.patch.object(svc, "get_provider_models", new=mock.AsyncMock(return_value=models)) as mock_get, \ - mock.patch.object(svc, "merge_existing_model_attributes", return_value=models) as mock_merge, \ + mock.patch.object(svc, "merge_existing_model_tokens", return_value=models) as mock_merge, \ mock.patch.object(svc, "sort_models_by_id", side_effect=lambda m: m) as mock_sort: out = await svc.create_provider_models_for_tenant("t1", req) @@ -932,22 +867,18 @@ async def test_batch_update_models_for_tenant_success(): svc = import_svc() models = [{"model_id": "a"}, {"model_id": "b"}] - mock_model_record = {"model_id": 1} - with mock.patch.object(svc, "update_model_record") as mock_update, \ - mock.patch.object(svc, "get_model_by_name_factory", return_value=mock_model_record): + with mock.patch.object(svc, "update_model_record") as mock_update: await svc.batch_update_models_for_tenant("u1", "t1", models) assert mock_update.call_count == 2 - # update_data is models[i] with model_id/model_name excluded -> {} - mock_update.assert_any_call(1, {}, "u1", "t1") + mock_update.assert_any_call("a", models[0], "u1", "t1") + mock_update.assert_any_call("b", models[1], "u1", "t1") async def test_batch_update_models_for_tenant_exception(): svc = import_svc() models = [{"model_id": "a"}] - mock_model_record = {"model_id": 1} - with mock.patch.object(svc, "update_model_record", side_effect=Exception("oops")), \ - mock.patch.object(svc, "get_model_by_name_factory", return_value=mock_model_record): + with mock.patch.object(svc, "update_model_record", side_effect=Exception("oops")): with pytest.raises(Exception) as exc: await svc.batch_update_models_for_tenant("u1", "t1", models) assert "Failed to batch update models" in str(exc.value) diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index b25baf708..522a850f0 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -1,86 +1,11 @@ import json -import os -import sys import unittest -from importlib.machinery import ModuleSpec from unittest.mock import patch, MagicMock from consts.error_code import ErrorCode from consts.exceptions import AppException - -def _make_package_mock(module_name: str, module_dir: str) -> MagicMock: - """Create a MagicMock that behaves like a Python package with proper __path__.""" - mock = MagicMock() - mock.__path__ = [module_dir] - mock.__package__ = module_name - mock.__spec__ = ModuleSpec(module_name, None) - mock.__loader__ = None - return mock - - -# Setup paths for real modules -backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../backend")) -database_dir = os.path.join(backend_dir, "database") -services_dir = os.path.join(backend_dir, "services") -utils_dir = os.path.join(backend_dir, "utils") -agents_dir = os.path.join(backend_dir, "agents") - -# Mock backend package and its sub-packages as proper packages -backend_mock = _make_package_mock("backend", backend_dir) -sys.modules['backend'] = backend_mock - -database_mock = _make_package_mock("database", database_dir) -sys.modules['database'] = database_mock -backend_mock.database = database_mock - -services_mock = _make_package_mock("services", services_dir) -sys.modules['services'] = services_mock - -utils_mock = _make_package_mock("utils", utils_dir) -sys.modules['utils'] = utils_mock -backend_mock.utils = utils_mock - -agents_mock = _make_package_mock("agents", agents_dir) -sys.modules['agents'] = agents_mock -backend_mock.agents = agents_mock - -# Mock backend.database submodules for patching -sys.modules['backend.database'] = database_mock -sys.modules['backend.database.model_management_db'] = MagicMock() -sys.modules['backend.database.agent_db'] = MagicMock() -sys.modules['backend.database.tool_db'] = MagicMock() -sys.modules['backend.database.knowledge_db'] = MagicMock() -sys.modules['backend.database.client'] = MagicMock() -sys.modules['backend.database.attachment_db'] = MagicMock() -sys.modules['backend.database.group_db'] = MagicMock() -sys.modules['backend.database.user_tenant_db'] = MagicMock() -sys.modules['backend.database.remote_mcp_db'] = MagicMock() -sys.modules['backend.database.agent_version_db'] = MagicMock() -sys.modules['backend.database.a2a_agent_db'] = MagicMock() - -# Mock services submodules (NOT backend.services which blocks import of prompt_service) -sys.modules['services.agent_service'] = MagicMock() -sys.modules['services.file_management_service'] = MagicMock() -sys.modules['services.conversation_management_service'] = MagicMock() -sys.modules['services.memory_config_service'] = MagicMock() -sys.modules['services.agent_version_service'] = MagicMock() - -# Mock agents submodules -sys.modules['agents.create_agent_info'] = MagicMock() -sys.modules['backend.agents.create_agent_info'] = MagicMock() - -# Mock utils submodules to avoid llm_utils import triggering database connection -sys.modules['utils.llm_utils'] = MagicMock() -sys.modules['utils.prompt_template_utils'] = MagicMock() -sys.modules['utils.config_utils'] = MagicMock() -sys.modules['utils.auth_utils'] = MagicMock() -sys.modules['utils.str_utils'] = MagicMock() -sys.modules['backend.utils.llm_utils'] = MagicMock() -sys.modules['backend.utils.prompt_template_utils'] = MagicMock() -sys.modules['backend.utils.config_utils'] = MagicMock() -sys.modules['backend.utils.auth_utils'] = MagicMock() - # Mock boto3 and minio client before importing the module under test +import sys boto3_mock = MagicMock() sys.modules['boto3'] = boto3_mock @@ -752,13 +677,10 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): expected_data = f"data: {json.dumps({'success': True, 'data': test_data[i]}, ensure_ascii=False)}\n\n" self.assertEqual(result, expected_data) - @patch('backend.services.prompt_service.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.resolve_prompt_generate_template') - def test_generate_system_prompt(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm, mock_get_model): - # Mock model config to avoid concurrency limit issue - mock_get_model.return_value = {"concurrency_limit": None} + def test_generate_system_prompt(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm): # Setup mock_prompt_config = { "user_prompt": "Test user prompt template", @@ -878,13 +800,10 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): self.assertIsInstance(result["is_complete"], bool) self.assertIsInstance(result["content"], str) - @patch('backend.services.prompt_service.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.resolve_prompt_generate_template') - def test_generate_system_prompt_with_exception(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm, mock_get_model): - # Mock model config to avoid concurrency limit issue - mock_get_model.return_value = {"concurrency_limit": None} + def test_generate_system_prompt_with_exception(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm): # Setup mock_prompt_config = { "user_prompt": "Test user prompt template", @@ -1179,7 +1098,6 @@ def mock_gen(*args, **kwargs): self.assertIn("Failed to generate prompt content", str(context.exception)) - @patch('backend.services.prompt_service.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.resolve_prompt_generate_template') @@ -1188,10 +1106,7 @@ def test_generate_system_prompt_error_before_streaming( mock_resolve_prompt_template, mock_join_info, mock_call_llm, - mock_get_model, ): - # Mock model config to avoid concurrency limit issue - mock_get_model.return_value = {"concurrency_limit": None} """Test generate_system_prompt handles error that occurs before streaming (line 307-311)""" # Setup mock_prompt_config = { @@ -1233,7 +1148,6 @@ def mock_llm_call_error(model_id, content, sys_prompt, callback, tenant_id): self.assertIn("LLM connection error", str(context.exception)) - @patch('backend.services.prompt_service.get_model_by_model_id') @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.resolve_prompt_generate_template') @@ -1242,10 +1156,7 @@ def test_generate_system_prompt_error_during_streaming( mock_resolve_prompt_template, mock_join_info, mock_call_llm, - mock_get_model, ): - # Mock model config to avoid concurrency limit issue - mock_get_model.return_value = {"concurrency_limit": None} """Test generate_system_prompt handles error that occurs during streaming (line 330-331)""" # Setup mock_prompt_config = { diff --git a/test/sdk/core/models/test_embedding_model.py b/test/sdk/core/models/test_embedding_model.py index c7856f46f..9c3f8824b 100644 --- a/test/sdk/core/models/test_embedding_model.py +++ b/test/sdk/core/models/test_embedding_model.py @@ -3,7 +3,7 @@ import sys from unittest.mock import AsyncMock, Mock, patch -from sdk.nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding +from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding class DummyResponse: def __init__(self, status_code=200, json_data=None): @@ -54,7 +54,7 @@ async def test_dimension_check_success(openai_embedding_instance): expected_embeddings = [[0.1, 0.2, 0.3]] with patch( - "sdk.nexent.core.models.embedding_model.asyncio.to_thread", + "nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, return_value=expected_embeddings, ) as mock_to_thread: @@ -69,7 +69,7 @@ async def test_dimension_check_failure(openai_embedding_instance): """dimension_check should return an empty list when an exception is raised inside to_thread.""" with patch( - "sdk.nexent.core.models.embedding_model.asyncio.to_thread", + "nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, side_effect=Exception("connection error"), ) as mock_to_thread: @@ -91,7 +91,7 @@ async def test_jina_dimension_check_success(jina_embedding_instance): expected_embeddings = [[0.5, 0.4, 0.3]] with patch( - "sdk.nexent.core.models.embedding_model.asyncio.to_thread", + "nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, return_value=expected_embeddings, ) as mock_to_thread: @@ -106,7 +106,7 @@ async def test_jina_dimension_check_failure(jina_embedding_instance): """dimension_check should return an empty list when an exception is raised inside to_thread.""" with patch( - "sdk.nexent.core.models.embedding_model.asyncio.to_thread", + "nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, side_effect=Exception("connection error"), ) as mock_to_thread: @@ -127,7 +127,7 @@ def test_openai_get_embeddings_success_returns_list(openai_embedding_instance): fake_response = {"data": [{"embedding": [0.9, 0.8]}]} with patch( - "sdk.nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", + "nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", return_value=fake_response, ) as mock_make_request: result = openai_embedding_instance.get_embeddings( @@ -145,7 +145,7 @@ def test_openai_get_embeddings_with_metadata(openai_embedding_instance): "data": [{"embedding": [1, 2, 3]}], "meta": {"foo": "bar"}} with patch( - "sdk.nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", + "nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", return_value=fake_response, ) as mock_make_request: result = openai_embedding_instance.get_embeddings( @@ -172,7 +172,7 @@ def side_effect(data, timeout=None): side_effect.calls = 0 with patch( - "sdk.nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", + "nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", side_effect=side_effect, ) as mock_make_request: result = openai_embedding_instance.get_embeddings( @@ -192,7 +192,7 @@ def test_openai_get_embeddings_timeout_exhausts_raises(openai_embedding_instance """Should raise Timeout after exhausting retries.""" with patch( - "sdk.nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", + "nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", side_effect=requests.exceptions.Timeout(), ) as mock_make_request: with pytest.raises(requests.exceptions.Timeout): @@ -226,7 +226,7 @@ def side_effect(inputs, with_metadata=False, timeout=None): return [[0.3, 0.4]] with patch( - "sdk.nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", + "nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", side_effect=side_effect, ) as mock_delegate: result = jina_embedding_instance.get_embeddings( @@ -251,7 +251,7 @@ def side_effect(inputs, with_metadata=False, timeout=None): side_effect.calls = 0 with patch( - "sdk.nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", + "nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", side_effect=side_effect, ) as mock_delegate: result = jina_embedding_instance.get_embeddings( @@ -273,7 +273,7 @@ def test_jina_get_embeddings_timeout_exhausts_raises(jina_embedding_instance): """Should raise Timeout after exhausting retries.""" with patch( - "sdk.nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", + "nexent.core.models.embedding_model.JinaEmbedding.get_multimodal_embeddings", side_effect=requests.exceptions.Timeout(), ) as mock_delegate: with pytest.raises(requests.exceptions.Timeout): @@ -306,7 +306,7 @@ def test_jina_get_multimodal_embeddings_parses_embeddings(jina_embedding_instanc mock_resp.json = Mock(return_value=fake_response) with patch( - "sdk.nexent.core.models.embedding_model.requests.Session.post", return_value=mock_resp + "nexent.core.models.embedding_model.requests.post", return_value=mock_resp ) as mock_post: inputs = [{"text": "t1"}, {"image": "http://x/y.jpg"}] result = jina_embedding_instance.get_multimodal_embeddings( @@ -334,7 +334,7 @@ def test_jina_get_multimodal_embeddings_with_metadata(jina_embedding_instance): mock_resp.raise_for_status = Mock() mock_resp.json = Mock(return_value=fake_response) - with patch("sdk.nexent.core.models.embedding_model.requests.Session.post", return_value=mock_resp) as mock_post: + with patch("nexent.core.models.embedding_model.requests.post", return_value=mock_resp) as mock_post: inputs = [{"text": "t"}] result = jina_embedding_instance.get_multimodal_embeddings( inputs, with_metadata=True, timeout=4 @@ -370,7 +370,7 @@ def side_effect(url, headers=None, json=None, timeout=None, **kwargs): side_effect.calls = 0 with patch( - "sdk.nexent.core.models.embedding_model.requests.Session.post", side_effect=side_effect + "nexent.core.models.embedding_model.requests.post", side_effect=side_effect ) as mock_post: inputs = [{"text": "t"}] result = jina_embedding_instance.get_multimodal_embeddings( @@ -391,7 +391,7 @@ def test_jina_get_multimodal_embeddings_timeout_exhausts_raises( """Should raise Timeout after exhausting retries.""" with patch( - "sdk.nexent.core.models.embedding_model.requests.Session.post", + "nexent.core.models.embedding_model.requests.post", side_effect=requests.exceptions.Timeout(), ) as mock_post: with pytest.raises(requests.exceptions.Timeout): @@ -438,7 +438,7 @@ async def test_jina_dimension_check_connection_error_returns_empty(jina_embeddin """dimension_check should return [] on ConnectionError.""" with patch( - "sdk.nexent.core.models.embedding_model.asyncio.to_thread", + "nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, side_effect=requests.exceptions.ConnectionError(), ): @@ -457,7 +457,7 @@ def side_effect(data, timeout=None): return {"data": [{"embedding": [0.21, 0.22]}]} with patch( - "sdk.nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", + "nexent.core.models.embedding_model.OpenAICompatibleEmbedding._make_request", side_effect=side_effect, ) as mock_make_request: result = openai_embedding_instance.get_embeddings( @@ -470,7 +470,7 @@ def side_effect(data, timeout=None): def test_openai_make_request_invokes_requests_post(openai_embedding_instance): - """Cover OpenAI _make_request by patching requests.Session.post path.""" + """Cover OpenAI _make_request by patching requests.post path.""" fake_response = {"data": [{"embedding": [7, 8]}]} @@ -478,7 +478,7 @@ def test_openai_make_request_invokes_requests_post(openai_embedding_instance): mock_resp.raise_for_status = Mock() mock_resp.json = Mock(return_value=fake_response) - with patch("sdk.nexent.core.models.embedding_model.requests.Session.post", return_value=mock_resp) as mock_post: + with patch("nexent.core.models.embedding_model.requests.post", return_value=mock_resp) as mock_post: result = openai_embedding_instance.get_embeddings( ["hi"], with_metadata=False, timeout=2 ) @@ -502,7 +502,7 @@ async def test_openai_dimension_check_connection_error_returns_empty(openai_embe """dimension_check should return [] on ConnectionError.""" with patch( - "sdk.nexent.core.models.embedding_model.asyncio.to_thread", + "nexent.core.models.embedding_model.asyncio.to_thread", new_callable=AsyncMock, side_effect=requests.exceptions.ConnectionError(), ): @@ -513,13 +513,13 @@ async def test_openai_dimension_check_connection_error_returns_empty(openai_embe def test_api_key_normalization_and_verify_jina(monkeypatch): captured = {} - def fake_post(self, url, headers=None, json=None, timeout=None, verify=True): + def fake_post(url, headers=None, json=None, timeout=None, verify=True): captured['url'] = url captured['headers'] = headers captured['verify'] = verify return DummyResponse() - monkeypatch.setattr(requests.Session, "post", fake_post) + monkeypatch.setattr("requests.post", fake_post) # api_key containing Bearer prefix should be normalized emb = JinaEmbedding(api_key="my-secret", base_url="https://example.com/emb", ssl_verify=False) @@ -533,13 +533,13 @@ def fake_post(self, url, headers=None, json=None, timeout=None, verify=True): def test_api_key_normalization_and_verify_openaicompatible(monkeypatch): captured = {} - def fake_post(self, url, headers=None, json=None, timeout=None, verify=True): + def fake_post(url, headers=None, json=None, timeout=None, verify=True): captured['url'] = url captured['headers'] = headers captured['verify'] = verify return DummyResponse() - monkeypatch.setattr(requests.Session, "post", fake_post) + monkeypatch.setattr("requests.post", fake_post) emb = OpenAICompatibleEmbedding(model_name="m", base_url="https://api.example/emb", api_key="KEY", embedding_dim=16, ssl_verify=True) data = emb._prepare_input("hi") @@ -574,9 +574,9 @@ async def dimension_check(self, timeout: float = 5.0): def test_jina_make_request_raises_http_error(monkeypatch): - """Ensure _make_request propagates HTTP errors from requests.Session.post""" + """Ensure _make_request propagates HTTP errors from requests.post""" - def fake_post(self, url, headers=None, json=None, timeout=None, verify=True): + def fake_post(url, headers=None, json=None, timeout=None, verify=True): class BadResp: status_code = 500 @@ -585,7 +585,7 @@ def raise_for_status(self): return BadResp() - monkeypatch.setattr(requests.Session, "post", fake_post) + monkeypatch.setattr("requests.post", fake_post) emb = JinaEmbedding(api_key="k", base_url="https://api.jina.ai/v1/embeddings", ssl_verify=True) data = emb._prepare_multimodal_input([{"text": "hi"}]) @@ -596,7 +596,7 @@ def raise_for_status(self): def test_openai_make_request_raises_http_error(monkeypatch): """Ensure OpenAICompatibleEmbedding._make_request propagates HTTP errors""" - def fake_post(self, url, headers=None, json=None, timeout=None, verify=True): + def fake_post(url, headers=None, json=None, timeout=None, verify=True): class BadResp: status_code = 502 @@ -605,7 +605,7 @@ def raise_for_status(self): return BadResp() - monkeypatch.setattr(requests.Session, "post", fake_post) + monkeypatch.setattr("requests.post", fake_post) emb = OpenAICompatibleEmbedding(model_name="m", base_url="https://api.example.com/emb", api_key="k", embedding_dim=16, ssl_verify=False) data = emb._prepare_input("hello") @@ -623,10 +623,7 @@ def raise_for_status(self): def json(self): return {"meta": {"ok": True}} - def fake_post(self, *a, **k): - return RespNoData() - - monkeypatch.setattr(requests.Session, "post", fake_post) + monkeypatch.setattr("requests.post", lambda *a, **k: RespNoData()) emb = JinaEmbedding(api_key="k") with pytest.raises(KeyError): @@ -644,13 +641,13 @@ def test_openai_get_embeddings_calls_record_model_call(mocker): mock_ctx.__enter__ = mocker.MagicMock(return_value=None) mock_ctx.__exit__ = mocker.MagicMock(return_value=False) mock_record = mocker.patch( - "sdk.nexent.core.models.embedding_model.record_model_call", + "nexent.core.models.embedding_model.record_model_call", return_value=mock_ctx, ) mock_resp = Mock() mock_resp.raise_for_status = Mock() mock_resp.json.return_value = {"data": [{"embedding": [0.1, 0.2]}]} - mocker.patch("sdk.nexent.core.models.embedding_model.requests.Session.post", return_value=mock_resp) + mocker.patch("requests.post", return_value=mock_resp) emb = OpenAICompatibleEmbedding( model_name="text-emb-3", @@ -672,13 +669,13 @@ def test_jina_get_embeddings_calls_record_model_call(mocker): mock_ctx.__enter__ = mocker.MagicMock(return_value=None) mock_ctx.__exit__ = mocker.MagicMock(return_value=False) mock_record = mocker.patch( - "sdk.nexent.core.models.embedding_model.record_model_call", + "nexent.core.models.embedding_model.record_model_call", return_value=mock_ctx, ) mock_resp = Mock() mock_resp.raise_for_status = Mock() mock_resp.json.return_value = {"data": [{"embedding": [0.1, 0.2]}]} - mocker.patch("sdk.nexent.core.models.embedding_model.requests.Session.post", return_value=mock_resp) + mocker.patch("requests.post", return_value=mock_resp) emb = JinaEmbedding(api_key="k", ssl_verify=True) emb.get_multimodal_embeddings([{"text": "hi"}], with_metadata=False, timeout=5) diff --git a/test/sdk/core/models/test_openai_llm.py b/test/sdk/core/models/test_openai_llm.py index 2589e1d6c..0477a86a1 100644 --- a/test/sdk/core/models/test_openai_llm.py +++ b/test/sdk/core/models/test_openai_llm.py @@ -887,33 +887,29 @@ def test_init_with_ssl_verify_false(): observer = MagicMock() - # Mock httpx.Client directly (it's imported inside __init__) - with patch("httpx.Client") as mock_httpx_client: + # Mock DefaultHttpxClient from openai module + with patch("openai.DefaultHttpxClient") as mock_httpx_client: mock_httpx_client.return_value = MagicMock() # Create model with ssl_verify=False model = ImportedOpenAIModel(observer=observer, ssl_verify=False) - # Verify httpx.Client was called with verify=False - mock_httpx_client.assert_called_once() - call_kwargs = mock_httpx_client.call_args - assert call_kwargs.kwargs.get("verify") is False + # Verify DefaultHttpxClient was called with verify=False + mock_httpx_client.assert_called_once_with(verify=False) def test_init_with_ssl_verify_true(): - """Test __init__ method creates http_client when ssl_verify=True (default)""" + """Test __init__ method doesn't create http_client when ssl_verify=True (default)""" observer = MagicMock() - # Mock httpx.Client directly (it's imported inside __init__) - with patch("httpx.Client") as mock_httpx_client: + # Mock DefaultHttpxClient from openai module + with patch("openai.DefaultHttpxClient") as mock_httpx_client: # Create model with ssl_verify=True (default) model = ImportedOpenAIModel(observer=observer, ssl_verify=True) - # Verify httpx.Client was called (it's always created, the verify param differs) - assert mock_httpx_client.call_count == 1 - call_kwargs = mock_httpx_client.call_args - assert call_kwargs.kwargs.get("verify") is True + # Verify DefaultHttpxClient was NOT called + mock_httpx_client.assert_not_called() # --------------------------------------------------------------------------- From 15fe84fe0c9b06205c938ecc2e4f92d0a894fb56 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Thu, 21 May 2026 15:29:11 +0800 Subject: [PATCH 45/48] =?UTF-8?q?=E4=BD=BF=E7=94=A8DefaultHttpxClient?= =?UTF-8?q?=E8=80=8C=E9=9D=9Ehttpx.client?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sdk/nexent/core/models/openai_llm.py | 34 ++++++++++++---------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/sdk/nexent/core/models/openai_llm.py b/sdk/nexent/core/models/openai_llm.py index fba18151a..455eb5f05 100644 --- a/sdk/nexent/core/models/openai_llm.py +++ b/sdk/nexent/core/models/openai_llm.py @@ -23,9 +23,8 @@ class OpenAIModel(OpenAIServerModel): def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, top_p=0.95, - ssl_verify=True, model_factory: Optional[str] = None, - display_name: Optional[str] = None, timeout_seconds: Optional[float] = None, - *args, **kwargs): + ssl_verify=True, timeout_seconds: Optional[float] = None, model_factory: Optional[str] = None, + display_name: Optional[str] = None, *args, **kwargs): """ Initialize OpenAI Model with observer and SSL verification option. @@ -35,9 +34,9 @@ def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, top_p: Top-p sampling parameter (default: 0.95) ssl_verify: Whether to verify SSL certificates (default: True). Set to False for local services without SSL support. + timeout_seconds: Timeout in seconds for HTTP requests (default: None, uses client default). model_factory: Provider identifier (e.g., openai, modelengine) display_name: Human-readable display name for monitoring - timeout_seconds: Request timeout in seconds. If None, defaults to 120 seconds. *args: Additional positional arguments for OpenAIServerModel **kwargs: Additional keyword arguments for OpenAIServerModel """ @@ -48,15 +47,17 @@ def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, self._monitoring = get_monitoring_manager() self.model_factory = (model_factory or "").lower() self.display_name = display_name - self.timeout_seconds = timeout_seconds - # Create http_client with trust_env=False to ignore proxy env vars - import httpx - timeout = httpx.Timeout(timeout_seconds) if timeout_seconds is not None else httpx.Timeout(120.0) - http_client = httpx.Client(verify=ssl_verify, timeout=timeout, trust_env=False) - client_kwargs = kwargs.get('client_kwargs', {}) - client_kwargs['http_client'] = http_client - kwargs['client_kwargs'] = client_kwargs + # Create http_client based on ssl_verify parameter and timeout + if not ssl_verify or timeout_seconds is not None: + from openai import DefaultHttpxClient + client_config = {"verify": ssl_verify} + if timeout_seconds is not None: + client_config["timeout"] = timeout_seconds + http_client = DefaultHttpxClient(**client_config) + client_kwargs = kwargs.get('client_kwargs', {}) + client_kwargs['http_client'] = http_client + kwargs['client_kwargs'] = client_kwargs super().__init__(*args, **kwargs) @@ -285,16 +286,11 @@ async def check_connectivity(self) -> bool: max_tokens=5, ) - # Use custom timeout if specified - request_kwargs = {"stream": False, **completion_kwargs} - if self.timeout_seconds is not None: - import httpx - request_kwargs["timeout"] = httpx.Timeout(self.timeout_seconds) - # Offload the blocking SDK call to a thread pool to avoid blocking the event loop await asyncio.to_thread( self.client.chat.completions.create, - **request_kwargs, + stream=False, + **completion_kwargs, ) # If no exception is raised, the connection is successful From 6237459710fe616abef6d86c4ceefd8d3d58b93e Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Thu, 21 May 2026 15:38:50 +0800 Subject: [PATCH 46/48] =?UTF-8?q?=E5=88=A0=E9=99=A4=E4=B8=8D=E5=AD=98?= =?UTF-8?q?=E5=9C=A8=E7=9A=84=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/model_management_service.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index d46178d50..0de1d41e5 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -12,8 +12,7 @@ get_models_by_display_name, get_model_records, get_models_by_tenant_factory_type, - update_model_record, - update_model_record_by_model_name, + update_model_record ) from services.model_provider_service import ( prepare_model_dict, From 6ed26ec754bd96dac4e23533b894c72563866418 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Thu, 21 May 2026 16:14:33 +0800 Subject: [PATCH 47/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/model_management_service.py | 1 + .../services/test_model_management_service.py | 36 ++++++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 0de1d41e5..a90794a85 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -9,6 +9,7 @@ create_model_record, delete_model_record, get_model_by_display_name, + get_model_by_name_factory, get_models_by_display_name, get_model_records, get_models_by_tenant_factory_type, diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 6e504e90a..5013d27ab 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -138,6 +138,10 @@ async def _prepare_model_dict(**kwargs): return {} +def _merge_existing_model_attributes(model_list, tenant_id, provider, model_type, fields=None): + return model_list + + def _merge_existing_model_tokens(model_list, tenant_id, provider, model_type): return model_list @@ -145,6 +149,7 @@ def _merge_existing_model_tokens(model_list, tenant_id, provider, model_type): async def _get_provider_models(model_data): return [] services_provider_mod.prepare_model_dict = _prepare_model_dict +services_provider_mod.merge_existing_model_attributes = _merge_existing_model_attributes services_provider_mod.merge_existing_model_tokens = _merge_existing_model_tokens services_provider_mod.get_provider_models = _get_provider_models sys.modules["services.model_provider_service"] = services_provider_mod @@ -210,9 +215,15 @@ def _get_models_by_display_name(*args, **kwargs): return [] +def _get_model_by_name_factory(*args, **kwargs): + """Return None by default; tests can patch svc.get_model_by_name_factory.""" + return None + + db_mm_mod.create_model_record = _noop db_mm_mod.delete_model_record = _noop db_mm_mod.get_model_by_display_name = _noop +db_mm_mod.get_model_by_name_factory = _get_model_by_name_factory db_mm_mod.get_models_by_display_name = _get_models_by_display_name db_mm_mod.get_model_records = _get_model_records db_mm_mod.get_models_by_tenant_factory_type = _get_models_by_tenant_factory_type @@ -536,7 +547,7 @@ async def test_create_provider_models_for_tenant_success(): models = [{"id": "silicon/a"}, {"id": "silicon/b"}] with mock.patch.object(svc, "get_provider_models", new=mock.AsyncMock(return_value=models)) as mock_get, \ - mock.patch.object(svc, "merge_existing_model_tokens", return_value=models) as mock_merge, \ + mock.patch.object(svc, "merge_existing_model_attributes", return_value=models) as mock_merge, \ mock.patch.object(svc, "sort_models_by_id", side_effect=lambda m: m) as mock_sort: out = await svc.create_provider_models_for_tenant("t1", req) @@ -866,18 +877,33 @@ async def test_update_single_model_for_tenant_multi_embedding_updates_both(): async def test_batch_update_models_for_tenant_success(): svc = import_svc() - models = [{"model_id": "a"}, {"model_id": "b"}] + models = [{"model_id": "1", "max_tokens": 4096}, {"model_id": "2", "max_tokens": 8192}] with mock.patch.object(svc, "update_model_record") as mock_update: await svc.batch_update_models_for_tenant("u1", "t1", models) assert mock_update.call_count == 2 - mock_update.assert_any_call("a", models[0], "u1", "t1") - mock_update.assert_any_call("b", models[1], "u1", "t1") + mock_update.assert_any_call(1, {"max_tokens": 4096}, "u1", "t1") + mock_update.assert_any_call(2, {"max_tokens": 8192}, "u1", "t1") + + +async def test_batch_update_models_for_tenant_by_name_factory(): + """Batch update resolves model_id via get_model_by_name_factory when model_id is not numeric.""" + svc = import_svc() + + models = [{"model_id": "openai/gpt-4", "max_tokens": 4096}] + with mock.patch.object( + svc, + "get_model_by_name_factory", + return_value={"model_id": 42}, + ) as mock_lookup, mock.patch.object(svc, "update_model_record") as mock_update: + await svc.batch_update_models_for_tenant("u1", "t1", models) + mock_lookup.assert_called_once_with("gpt-4", "openai", "t1") + mock_update.assert_called_once_with(42, {"max_tokens": 4096}, "u1", "t1") async def test_batch_update_models_for_tenant_exception(): svc = import_svc() - models = [{"model_id": "a"}] + models = [{"model_id": "1"}] with mock.patch.object(svc, "update_model_record", side_effect=Exception("oops")): with pytest.raises(Exception) as exc: await svc.batch_update_models_for_tenant("u1", "t1", models) From 2bfd1f8be457ca3417d82109efa68d2e52e60cdc Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Thu, 21 May 2026 16:42:10 +0800 Subject: [PATCH 48/48] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/backend/services/test_prompt_service.py | 14 ++++++- test/sdk/core/models/test_embedding_model.py | 41 +++++++++----------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 522a850f0..53f54c34a 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -680,8 +680,10 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.resolve_prompt_generate_template') - def test_generate_system_prompt(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm): + @patch('backend.services.prompt_service.get_model_by_model_id') + def test_generate_system_prompt(self, mock_get_model, mock_resolve_prompt_template, mock_join_info, mock_call_llm): # Setup + mock_get_model.return_value = None # No DB connection needed; concurrency_limit defaults to unlimited mock_prompt_config = { "user_prompt": "Test user prompt template", "duty_system_prompt": "Generate duty prompt", @@ -803,8 +805,10 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.resolve_prompt_generate_template') - def test_generate_system_prompt_with_exception(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm): + @patch('backend.services.prompt_service.get_model_by_model_id') + def test_generate_system_prompt_with_exception(self, mock_get_model, mock_resolve_prompt_template, mock_join_info, mock_call_llm): # Setup + mock_get_model.return_value = None # No DB connection needed; concurrency_limit defaults to unlimited mock_prompt_config = { "user_prompt": "Test user prompt template", "duty_system_prompt": "Generate duty prompt", @@ -1101,14 +1105,17 @@ def mock_gen(*args, **kwargs): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.resolve_prompt_generate_template') + @patch('backend.services.prompt_service.get_model_by_model_id') def test_generate_system_prompt_error_before_streaming( self, + mock_get_model, mock_resolve_prompt_template, mock_join_info, mock_call_llm, ): """Test generate_system_prompt handles error that occurs before streaming (line 307-311)""" # Setup + mock_get_model.return_value = None # No DB connection needed; concurrency_limit defaults to unlimited mock_prompt_config = { "user_prompt": "Test user prompt template", "duty_system_prompt": "Generate duty prompt", @@ -1151,14 +1158,17 @@ def mock_llm_call_error(model_id, content, sys_prompt, callback, tenant_id): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') @patch('backend.services.prompt_service.resolve_prompt_generate_template') + @patch('backend.services.prompt_service.get_model_by_model_id') def test_generate_system_prompt_error_during_streaming( self, + mock_get_model, mock_resolve_prompt_template, mock_join_info, mock_call_llm, ): """Test generate_system_prompt handles error that occurs during streaming (line 330-331)""" # Setup + mock_get_model.return_value = None # No DB connection needed; concurrency_limit defaults to unlimited mock_prompt_config = { "user_prompt": "Test user prompt template", "duty_system_prompt": "Generate duty prompt", diff --git a/test/sdk/core/models/test_embedding_model.py b/test/sdk/core/models/test_embedding_model.py index 9c3f8824b..02d706c04 100644 --- a/test/sdk/core/models/test_embedding_model.py +++ b/test/sdk/core/models/test_embedding_model.py @@ -305,9 +305,7 @@ def test_jina_get_multimodal_embeddings_parses_embeddings(jina_embedding_instanc mock_resp.raise_for_status = Mock() mock_resp.json = Mock(return_value=fake_response) - with patch( - "nexent.core.models.embedding_model.requests.post", return_value=mock_resp - ) as mock_post: + with patch.object(jina_embedding_instance.session, "post", return_value=mock_resp) as mock_post: inputs = [{"text": "t1"}, {"image": "http://x/y.jpg"}] result = jina_embedding_instance.get_multimodal_embeddings( inputs, with_metadata=False, timeout=3 @@ -334,7 +332,7 @@ def test_jina_get_multimodal_embeddings_with_metadata(jina_embedding_instance): mock_resp.raise_for_status = Mock() mock_resp.json = Mock(return_value=fake_response) - with patch("nexent.core.models.embedding_model.requests.post", return_value=mock_resp) as mock_post: + with patch.object(jina_embedding_instance.session, "post", return_value=mock_resp) as mock_post: inputs = [{"text": "t"}] result = jina_embedding_instance.get_multimodal_embeddings( inputs, with_metadata=True, timeout=4 @@ -369,9 +367,7 @@ def side_effect(url, headers=None, json=None, timeout=None, **kwargs): side_effect.calls = 0 - with patch( - "nexent.core.models.embedding_model.requests.post", side_effect=side_effect - ) as mock_post: + with patch.object(jina_embedding_instance.session, "post", side_effect=side_effect) as mock_post: inputs = [{"text": "t"}] result = jina_embedding_instance.get_multimodal_embeddings( inputs, with_metadata=False, timeout=None, retries=2, retry_timeout_step=2 @@ -390,8 +386,9 @@ def test_jina_get_multimodal_embeddings_timeout_exhausts_raises( ): """Should raise Timeout after exhausting retries.""" - with patch( - "nexent.core.models.embedding_model.requests.post", + with patch.object( + jina_embedding_instance.session, + "post", side_effect=requests.exceptions.Timeout(), ) as mock_post: with pytest.raises(requests.exceptions.Timeout): @@ -470,7 +467,7 @@ def side_effect(data, timeout=None): def test_openai_make_request_invokes_requests_post(openai_embedding_instance): - """Cover OpenAI _make_request by patching requests.post path.""" + """Cover OpenAI _make_request by patching session.post path.""" fake_response = {"data": [{"embedding": [7, 8]}]} @@ -478,7 +475,7 @@ def test_openai_make_request_invokes_requests_post(openai_embedding_instance): mock_resp.raise_for_status = Mock() mock_resp.json = Mock(return_value=fake_response) - with patch("nexent.core.models.embedding_model.requests.post", return_value=mock_resp) as mock_post: + with patch.object(openai_embedding_instance.session, "post", return_value=mock_resp) as mock_post: result = openai_embedding_instance.get_embeddings( ["hi"], with_metadata=False, timeout=2 ) @@ -513,13 +510,13 @@ async def test_openai_dimension_check_connection_error_returns_empty(openai_embe def test_api_key_normalization_and_verify_jina(monkeypatch): captured = {} - def fake_post(url, headers=None, json=None, timeout=None, verify=True): + def fake_post(self, url, headers=None, json=None, timeout=None, verify=True, **kwargs): captured['url'] = url captured['headers'] = headers captured['verify'] = verify return DummyResponse() - monkeypatch.setattr("requests.post", fake_post) + monkeypatch.setattr("requests.Session.post", fake_post) # api_key containing Bearer prefix should be normalized emb = JinaEmbedding(api_key="my-secret", base_url="https://example.com/emb", ssl_verify=False) @@ -533,13 +530,13 @@ def fake_post(url, headers=None, json=None, timeout=None, verify=True): def test_api_key_normalization_and_verify_openaicompatible(monkeypatch): captured = {} - def fake_post(url, headers=None, json=None, timeout=None, verify=True): + def fake_post(self, url, headers=None, json=None, timeout=None, verify=True, **kwargs): captured['url'] = url captured['headers'] = headers captured['verify'] = verify return DummyResponse() - monkeypatch.setattr("requests.post", fake_post) + monkeypatch.setattr("requests.Session.post", fake_post) emb = OpenAICompatibleEmbedding(model_name="m", base_url="https://api.example/emb", api_key="KEY", embedding_dim=16, ssl_verify=True) data = emb._prepare_input("hi") @@ -576,7 +573,7 @@ async def dimension_check(self, timeout: float = 5.0): def test_jina_make_request_raises_http_error(monkeypatch): """Ensure _make_request propagates HTTP errors from requests.post""" - def fake_post(url, headers=None, json=None, timeout=None, verify=True): + def fake_post(self, url, headers=None, json=None, timeout=None, verify=True, **kwargs): class BadResp: status_code = 500 @@ -585,7 +582,7 @@ def raise_for_status(self): return BadResp() - monkeypatch.setattr("requests.post", fake_post) + monkeypatch.setattr("requests.Session.post", fake_post) emb = JinaEmbedding(api_key="k", base_url="https://api.jina.ai/v1/embeddings", ssl_verify=True) data = emb._prepare_multimodal_input([{"text": "hi"}]) @@ -596,7 +593,7 @@ def raise_for_status(self): def test_openai_make_request_raises_http_error(monkeypatch): """Ensure OpenAICompatibleEmbedding._make_request propagates HTTP errors""" - def fake_post(url, headers=None, json=None, timeout=None, verify=True): + def fake_post(self, url, headers=None, json=None, timeout=None, verify=True, **kwargs): class BadResp: status_code = 502 @@ -605,7 +602,7 @@ def raise_for_status(self): return BadResp() - monkeypatch.setattr("requests.post", fake_post) + monkeypatch.setattr("requests.Session.post", fake_post) emb = OpenAICompatibleEmbedding(model_name="m", base_url="https://api.example.com/emb", api_key="k", embedding_dim=16, ssl_verify=False) data = emb._prepare_input("hello") @@ -623,7 +620,7 @@ def raise_for_status(self): def json(self): return {"meta": {"ok": True}} - monkeypatch.setattr("requests.post", lambda *a, **k: RespNoData()) + monkeypatch.setattr("requests.Session.post", lambda *a, **k: RespNoData()) emb = JinaEmbedding(api_key="k") with pytest.raises(KeyError): @@ -647,7 +644,6 @@ def test_openai_get_embeddings_calls_record_model_call(mocker): mock_resp = Mock() mock_resp.raise_for_status = Mock() mock_resp.json.return_value = {"data": [{"embedding": [0.1, 0.2]}]} - mocker.patch("requests.post", return_value=mock_resp) emb = OpenAICompatibleEmbedding( model_name="text-emb-3", @@ -656,6 +652,7 @@ def test_openai_get_embeddings_calls_record_model_call(mocker): embedding_dim=2, ssl_verify=True, ) + mocker.patch.object(emb.session, "post", return_value=mock_resp) emb.get_embeddings(["hello"], with_metadata=False, timeout=5) mock_record.assert_called_once_with( @@ -675,9 +672,9 @@ def test_jina_get_embeddings_calls_record_model_call(mocker): mock_resp = Mock() mock_resp.raise_for_status = Mock() mock_resp.json.return_value = {"data": [{"embedding": [0.1, 0.2]}]} - mocker.patch("requests.post", return_value=mock_resp) emb = JinaEmbedding(api_key="k", ssl_verify=True) + mocker.patch.object(emb.session, "post", return_value=mock_resp) emb.get_multimodal_embeddings([{"text": "hi"}], with_metadata=False, timeout=5) mock_record.assert_called_once_with(