diff --git a/backend/apps/config_app.py b/backend/apps/config_app.py index 0cfc962ea..3119b38b6 100644 --- a/backend/apps/config_app.py +++ b/backend/apps/config_app.py @@ -16,6 +16,7 @@ from apps.model_managment_app import router as model_manager_router from apps.oauth_app import router as oauth_router from apps.prompt_app import router as prompt_router +from apps.prompt_template_app import router as prompt_template_router from apps.remote_mcp_app import router as remote_mcp_router from apps.skill_app import router as skill_router from apps.tenant_config_app import router as tenant_config_router @@ -30,6 +31,7 @@ from apps.monitoring_app import router as monitoring_router from apps.a2a_server_app import router as a2a_server_router from consts.const import IS_SPEED_MODE +from services.prompt_template_service import sync_system_default_prompt_template # Create logger instance logger = logging.getLogger("base_app") @@ -37,6 +39,16 @@ # Create FastAPI app with common configurations app = create_app(title="Nexent Config API", description="Configuration APIs") + +@app.on_event("startup") +async def sync_default_prompt_template_on_startup(): + """Sync the YAML-backed system default prompt template into the database on startup.""" + try: + sync_system_default_prompt_template() + logger.info("System default prompt template synced successfully.") + except Exception as exc: + logger.error(f"Failed to sync system default prompt template: {str(exc)}") + app.include_router(model_manager_router) app.include_router(config_sync_router) app.include_router(agent_router) @@ -62,6 +74,7 @@ app.include_router(summary_router) app.include_router(prompt_router) +app.include_router(prompt_template_router) app.include_router(skill_router) app.include_router(tenant_config_router) app.include_router(remote_mcp_router) diff --git a/backend/apps/prompt_app.py b/backend/apps/prompt_app.py index 23868ad79..47bc38a72 100644 --- a/backend/apps/prompt_app.py +++ b/backend/apps/prompt_app.py @@ -28,6 +28,7 @@ async def generate_and_save_system_prompt_api( agent_id=prompt_request.agent_id, model_id=prompt_request.model_id, task_description=prompt_request.task_description, + prompt_template_id=prompt_request.prompt_template_id, user_id=user_id, tenant_id=tenant_id, language=language, diff --git a/backend/apps/prompt_template_app.py b/backend/apps/prompt_template_app.py new file mode 100644 index 000000000..0f12bd614 --- /dev/null +++ b/backend/apps/prompt_template_app.py @@ -0,0 +1,143 @@ +import logging +from http import HTTPStatus +from typing import Optional + +from fastapi import APIRouter, Header, HTTPException +from starlette.responses import JSONResponse + +from consts.exceptions import DuplicateError, NotFoundException, ValidationError +from consts.model import PromptTemplateRequest +from services.prompt_template_service import ( + create_prompt_template_impl, + delete_prompt_template_impl, + get_prompt_template_detail_impl, + list_prompt_templates_impl, + update_prompt_template_impl, +) +from utils.auth_utils import get_current_user_id + +router = APIRouter(prefix="/prompt_templates") +logger = logging.getLogger("prompt_template_app") + + +@router.get("") +async def list_prompt_templates_api( + authorization: Optional[str] = Header(None), +): + """List prompt templates for the current user.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + result = list_prompt_templates_impl(tenant_id=tenant_id, user_id=user_id) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except Exception as exc: + logger.error(f"Prompt template list error: {str(exc)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Prompt template list error.", + ) + + +@router.get("/{template_id}") +async def get_prompt_template_api( + template_id: int, + authorization: Optional[str] = Header(None), +): + """Get prompt template detail.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + result = get_prompt_template_detail_impl( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + ) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except NotFoundException as exc: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(exc)) + except Exception as exc: + logger.error(f"Prompt template detail error: {str(exc)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Prompt template detail error.", + ) + + +@router.post("") +async def create_prompt_template_api( + request: PromptTemplateRequest, + authorization: Optional[str] = Header(None), +): + """Create a prompt template.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + result = create_prompt_template_impl( + request=request, + tenant_id=tenant_id, + user_id=user_id, + ) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except DuplicateError as exc: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(exc)) + except ValidationError as exc: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(exc)) + except Exception as exc: + logger.error(f"Prompt template create error: {str(exc)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Prompt template create error.", + ) + + +@router.put("/{template_id}") +async def update_prompt_template_api( + template_id: int, + request: PromptTemplateRequest, + authorization: Optional[str] = Header(None), +): + """Update a prompt template.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + result = update_prompt_template_impl( + template_id=template_id, + request=request, + tenant_id=tenant_id, + user_id=user_id, + ) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except NotFoundException as exc: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(exc)) + except DuplicateError as exc: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(exc)) + except ValidationError as exc: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(exc)) + except Exception as exc: + logger.error(f"Prompt template update error: {str(exc)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Prompt template update error.", + ) + + +@router.delete("/{template_id}") +async def delete_prompt_template_api( + template_id: int, + authorization: Optional[str] = Header(None), +): + """Delete a prompt template.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + result = delete_prompt_template_impl( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + ) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except NotFoundException as exc: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(exc)) + except ValidationError as exc: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(exc)) + except Exception as exc: + logger.error(f"Prompt template delete error: {str(exc)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Prompt template delete error.", + ) diff --git a/backend/consts/model.py b/backend/consts/model.py index 6c792501f..5e00a143f 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -1,9 +1,11 @@ from enum import Enum from typing import Optional, Any, List, Dict -from pydantic import BaseModel, Field, EmailStr +from pydantic import BaseModel, Field, EmailStr, ConfigDict from nexent.core.agents.agent_model import ToolConfig +from consts.prompt_template import PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP + class ModelConnectStatusEnum(Enum): """Enum class for model connection status""" @@ -312,6 +314,7 @@ class GeneratePromptRequest(BaseModel): task_description: str agent_id: int model_id: int + prompt_template_id: Optional[int] = None tool_ids: Optional[List[int]] = Field( None, description="Optional: tool IDs from frontend (takes precedence over database query)") sub_agent_ids: Optional[List[int]] = Field( @@ -320,6 +323,50 @@ class GeneratePromptRequest(BaseModel): None, description="Optional: knowledge base display names from frontend (takes precedence over database query)") +class PromptTemplateContentRequest(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + duty_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["duty_system_prompt"] + ) + constraint_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["constraint_system_prompt"] + ) + few_shots_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["few_shots_system_prompt"] + ) + agent_variable_name_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_variable_name_system_prompt"] + ) + agent_display_name_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_display_name_system_prompt"] + ) + agent_description_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_description_system_prompt"] + ) + user_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["user_prompt"] + ) + agent_name_regenerate_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_name_regenerate_system_prompt"] + ) + agent_name_regenerate_user_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_name_regenerate_user_prompt"] + ) + agent_display_name_regenerate_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_display_name_regenerate_system_prompt"] + ) + agent_display_name_regenerate_user_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_display_name_regenerate_user_prompt"] + ) + + +class PromptTemplateRequest(BaseModel): + template_name: str + description: Optional[str] = None + template_type: str = "agent_generate" + template_content_zh: PromptTemplateContentRequest + template_content_en: Optional[PromptTemplateContentRequest] = None class OptimizePromptSectionRequest(BaseModel): task_description: str agent_id: int @@ -359,6 +406,8 @@ class AgentInfoRequest(BaseModel): enabled: Optional[bool] = None business_logic_model_name: Optional[str] = None business_logic_model_id: Optional[int] = None + prompt_template_id: Optional[int] = None + prompt_template_name: Optional[str] = None enabled_tool_ids: Optional[List[int]] = None enabled_skill_ids: Optional[List[int]] = None related_agent_ids: Optional[List[int]] = None @@ -448,6 +497,8 @@ class ExportAndImportAgentInfo(BaseModel): model_name: Optional[str] = None business_logic_model_id: Optional[int] = None business_logic_model_name: Optional[str] = None + prompt_template_id: Optional[int] = None + prompt_template_name: Optional[str] = None class Config: arbitrary_types_allowed = True diff --git a/backend/consts/prompt_template.py b/backend/consts/prompt_template.py new file mode 100644 index 000000000..febcaeca5 --- /dev/null +++ b/backend/consts/prompt_template.py @@ -0,0 +1,15 @@ +PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} + +PROMPT_GENERATE_TEMPLATE_FIELDS = tuple(PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys()) diff --git a/backend/database/agent_db.py b/backend/database/agent_db.py index 7d14d7b8e..90de64ca9 100644 --- a/backend/database/agent_db.py +++ b/backend/database/agent_db.py @@ -192,6 +192,8 @@ def create_agent(agent_info, tenant_id: str, user_id: str): "business_description": new_agent.business_description, "business_logic_model_id": new_agent.business_logic_model_id, "business_logic_model_name": new_agent.business_logic_model_name, + "prompt_template_id": new_agent.prompt_template_id, + "prompt_template_name": new_agent.prompt_template_name, "group_ids": new_agent.group_ids, "is_new": new_agent.is_new, "enable_context_manager": new_agent.enable_context_manager, diff --git a/backend/database/db_models.py b/backend/database/db_models.py index baa8e903e..eeadf0192 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -1,4 +1,4 @@ -from sqlalchemy import BigInteger, Boolean, Column, Integer, JSON, Numeric, Sequence, String, Text, TIMESTAMP, UniqueConstraint, Index, Float +from sqlalchemy import BigInteger, Boolean, Column, Integer, JSON, Numeric, Sequence, String, Text, TIMESTAMP, UniqueConstraint, Index, Float, text from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import DeclarativeBase from sqlalchemy.sql import func @@ -313,6 +313,8 @@ class AgentInfo(TableBase): Text, doc="Manually entered by the user to describe the entire business process") business_logic_model_name = Column(String(100), doc="Model name used for business logic prompt generation") business_logic_model_id = Column(Integer, doc="Model ID used for business logic prompt generation, foreign key reference to model_record_t.model_id") + prompt_template_id = Column(Integer, doc="Prompt template ID used for business logic prompt generation") + prompt_template_name = Column(String(100), doc="Prompt template name used for business logic prompt generation") group_ids = Column(String, doc="Agent group IDs list") is_new = Column(Boolean, default=False, doc="Whether this agent is marked as new for the user") current_version_no = Column(Integer, nullable=True, doc="Current published version number. NULL means no version published yet") @@ -320,6 +322,41 @@ class AgentInfo(TableBase): enable_context_manager = Column(Boolean, default=False, doc="Whether to enable context management (compression) for this agent") +class PromptTemplate(TableBase): + """ + Prompt template table for user-defined prompt generation templates. + """ + __tablename__ = "ag_prompt_template_t" + __table_args__ = ( + Index( + "uq_prompt_template_user_name_active", + "tenant_id", + "user_id", + "template_name", + unique=True, + postgresql_where=text("delete_flag = 'N'"), + ), + Index( + "idx_ag_prompt_template_t_user", + "tenant_id", + "user_id", + "template_type", + postgresql_where=text("delete_flag = 'N'"), + ), + {"schema": SCHEMA}, + ) + + template_id = Column(Integer, Sequence( + "ag_prompt_template_t_template_id_seq", schema=SCHEMA), primary_key=True, nullable=False, autoincrement=True, doc="Prompt template ID") + template_name = Column(String(100), nullable=False, doc="Prompt template name") + description = Column(String(500), doc="Prompt template description") + template_type = Column(String(50), nullable=False, default="agent_generate", doc="Prompt template type") + tenant_id = Column(String(100), nullable=False, doc="Tenant ID") + user_id = Column(String(100), nullable=False, doc="User ID") + template_content_zh = Column(JSONB, nullable=False, doc="Chinese prompt template content") + template_content_en = Column(JSONB, doc="English prompt template content") + + class ToolInstance(TableBase): """ Information table for tenant tool configuration. diff --git a/backend/database/prompt_template_db.py b/backend/database/prompt_template_db.py new file mode 100644 index 000000000..fbc286cf9 --- /dev/null +++ b/backend/database/prompt_template_db.py @@ -0,0 +1,165 @@ +import logging +from typing import Optional + +from sqlalchemy import select, update + +from database.client import as_dict, filter_property, get_db_session +from database.db_models import PromptTemplate + +logger = logging.getLogger("prompt_template_db") + + +def create_prompt_template(template_data: dict) -> dict: + """Create a prompt template.""" + with get_db_session() as session: + prompt_template = PromptTemplate( + **filter_property(template_data, PromptTemplate) + ) + prompt_template.delete_flag = "N" + session.add(prompt_template) + session.flush() + return as_dict(prompt_template) + + +def upsert_prompt_template_by_id(template_id: int, template_data: dict, user_id: str) -> dict: + """Create or update a prompt template with a fixed template ID.""" + with get_db_session() as session: + prompt_template = session.query(PromptTemplate).filter( + PromptTemplate.template_id == template_id, + ).first() + + filtered_data = filter_property(template_data, PromptTemplate) + if prompt_template: + for key, value in filtered_data.items(): + setattr(prompt_template, key, value) + prompt_template.updated_by = user_id + else: + prompt_template = PromptTemplate(**filtered_data) + prompt_template.template_id = template_id + prompt_template.delete_flag = filtered_data.get("delete_flag", "N") + session.add(prompt_template) + + session.flush() + return as_dict(prompt_template) + + +def update_prompt_template(template_id: int, template_data: dict, user_id: str) -> dict: + """Update a prompt template.""" + with get_db_session() as session: + prompt_template = session.query(PromptTemplate).filter( + PromptTemplate.template_id == template_id, + PromptTemplate.delete_flag == "N", + ).first() + + if not prompt_template: + raise ValueError("prompt template not found") + + for key, value in filter_property(template_data, PromptTemplate).items(): + if value is None: + continue + setattr(prompt_template, key, value) + + prompt_template.updated_by = user_id + session.flush() + return as_dict(prompt_template) + + +def delete_prompt_template(template_id: int, user_id: str) -> int: + """Soft-delete a prompt template.""" + with get_db_session() as session: + result = session.execute( + update(PromptTemplate) + .where( + PromptTemplate.template_id == template_id, + PromptTemplate.delete_flag == "N", + ) + .values(delete_flag="Y", updated_by=user_id) + ) + return result.rowcount + + +def query_prompt_templates_by_user( + tenant_id: str, + user_id: str, + template_type: str = "agent_generate", +) -> list[dict]: + """Query prompt templates by tenant and user.""" + with get_db_session() as session: + templates = session.query(PromptTemplate).filter( + PromptTemplate.tenant_id == tenant_id, + PromptTemplate.user_id == user_id, + PromptTemplate.template_type == template_type, + PromptTemplate.delete_flag == "N", + ).order_by(PromptTemplate.update_time.desc(), PromptTemplate.template_id.desc()).all() + return [as_dict(template) for template in templates] + + +def get_prompt_template_by_id( + template_id: int, + tenant_id: str, + user_id: str, + template_type: str = "agent_generate", +) -> Optional[dict]: + """Get a prompt template by ID.""" + with get_db_session() as session: + template = session.query(PromptTemplate).filter( + PromptTemplate.template_id == template_id, + PromptTemplate.tenant_id == tenant_id, + PromptTemplate.user_id == user_id, + PromptTemplate.template_type == template_type, + PromptTemplate.delete_flag == "N", + ).first() + return as_dict(template) if template else None + + +def get_prompt_template_by_name( + template_name: str, + tenant_id: str, + user_id: str, + template_type: str = "agent_generate", +) -> Optional[dict]: + """Get a prompt template by name.""" + with get_db_session() as session: + template = session.query(PromptTemplate).filter( + PromptTemplate.template_name == template_name, + PromptTemplate.tenant_id == tenant_id, + PromptTemplate.user_id == user_id, + PromptTemplate.template_type == template_type, + PromptTemplate.delete_flag == "N", + ).first() + return as_dict(template) if template else None + + +def get_prompt_template_by_template_id( + template_id: int, + template_type: str = "agent_generate", + include_deleted: bool = False, +) -> Optional[dict]: + """Get a prompt template by template ID regardless of owner.""" + with get_db_session() as session: + query = session.query(PromptTemplate).filter( + PromptTemplate.template_id == template_id, + PromptTemplate.template_type == template_type, + ) + if not include_deleted: + query = query.filter(PromptTemplate.delete_flag == "N") + template = query.first() + return as_dict(template) if template else None + + +def query_prompt_template_names( + tenant_id: str, + user_id: str, + template_type: str = "agent_generate", +) -> set[str]: + """Query all active prompt template names for the current user.""" + with get_db_session() as session: + rows = session.execute( + select(PromptTemplate.template_name).where( + PromptTemplate.tenant_id == tenant_id, + PromptTemplate.user_id == user_id, + PromptTemplate.template_type == template_type, + PromptTemplate.delete_flag == "N", + ) + ).all() + return {row[0] for row in rows if row and row[0]} diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index 02fa7d8c6..453dabcdb 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -16,6 +16,7 @@ from agents.create_agent_info import create_agent_run_info, create_tool_config_list from agents.preprocess_manager import preprocess_manager from services.agent_version_service import publish_version_impl +from utils.prompt_template_utils import normalize_prompt_generate_template_content from consts.const import MEMORY_SEARCH_START_MSG, MEMORY_SEARCH_DONE_MSG, MEMORY_SEARCH_FAIL_MSG, TOOL_TYPE_MAPPING, \ LANGUAGE, MESSAGE_ROLE, MODEL_CONFIG_MAPPING, CAN_EDIT_ALL_USER_ROLES, PERMISSION_EDIT, PERMISSION_READ, PERMISSION_PRIVATE from consts.exceptions import MemoryPreparationException @@ -64,6 +65,11 @@ from database.group_db import query_group_ids_by_user from database.user_tenant_db import get_user_tenant_by_user_id from database.a2a_agent_db import get_server_agent_ids +from services.prompt_template_service import ( + SYSTEM_PROMPT_TEMPLATE_ID, + SYSTEM_PROMPT_TEMPLATE_NAME, + get_prompt_template_summary, +) from utils.str_utils import convert_list_to_string, convert_string_to_list from services.conversation_management_service import save_conversation_assistant, save_conversation_user from services.memory_config_service import build_memory_context @@ -312,12 +318,25 @@ def _regenerate_agent_value_with_llm( user_prompt_key: str, default_system_prompt: str, default_user_prompt_builder: Callable[[dict], str], - fallback_fn: Callable[[str], str] + fallback_fn: Callable[[str], str], + prompt_template_id: Optional[int] = None, + user_id: Optional[str] = None, ) -> str: """ Shared helper to regenerate agent-related values with an LLM. """ - prompt_template = get_prompt_generate_prompt_template(language) + if user_id is not None: + from services.prompt_template_service import resolve_prompt_generate_template + prompt_template = resolve_prompt_generate_template( + tenant_id=tenant_id, + user_id=user_id, + language=language, + prompt_template_id=prompt_template_id, + ) + else: + prompt_template = normalize_prompt_generate_template_content( + get_prompt_generate_prompt_template(language) + ) system_prompt = _render_prompt_template( prompt_template.get(system_prompt_key, ""), original_value=original_value @@ -374,7 +393,9 @@ def _regenerate_agent_name_with_llm( tenant_id: str, language: str = LANGUAGE["ZH"], agents_cache: list[dict] | None = None, - exclude_agent_id: int | None = None + exclude_agent_id: int | None = None, + prompt_template_id: Optional[int] = None, + user_id: Optional[str] = None, ) -> str: return _regenerate_agent_value_with_llm( original_value=original_name, @@ -383,8 +404,8 @@ def _regenerate_agent_name_with_llm( model_id=model_id, tenant_id=tenant_id, language=language, - system_prompt_key="AGENT_NAME_REGENERATE_SYSTEM_PROMPT", - user_prompt_key="AGENT_NAME_REGENERATE_USER_PROMPT", + system_prompt_key="agent_name_regenerate_system_prompt", + user_prompt_key="agent_name_regenerate_user_prompt", default_system_prompt=( "You refine agent variable names so that they stay close to the " "original meaning and remain unique within the tenant." @@ -402,7 +423,9 @@ def _regenerate_agent_name_with_llm( tenant_id=tenant_id, agents_cache=agents_cache, exclude_agent_id=exclude_agent_id - ) + ), + prompt_template_id=prompt_template_id, + user_id=user_id, ) @@ -415,7 +438,9 @@ def _regenerate_agent_display_name_with_llm( tenant_id: str, language: str = LANGUAGE["ZH"], agents_cache: list[dict] | None = None, - exclude_agent_id: int | None = None + exclude_agent_id: int | None = None, + prompt_template_id: Optional[int] = None, + user_id: Optional[str] = None, ) -> str: return _regenerate_agent_value_with_llm( original_value=original_display_name, @@ -424,8 +449,8 @@ def _regenerate_agent_display_name_with_llm( model_id=model_id, tenant_id=tenant_id, language=language, - system_prompt_key="AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", - user_prompt_key="AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", + system_prompt_key="agent_display_name_regenerate_system_prompt", + user_prompt_key="agent_display_name_regenerate_user_prompt", default_system_prompt=( "You refine agent display names so they remain unique, concise, " "and aligned with the agent's capability." @@ -442,7 +467,9 @@ def _regenerate_agent_display_name_with_llm( tenant_id=tenant_id, agents_cache=agents_cache, exclude_agent_id=exclude_agent_id - ) + ), + prompt_template_id=prompt_template_id, + user_id=user_id, ) @@ -749,6 +776,11 @@ async def get_agent_info_impl(agent_id: int, tenant_id: str, version_no: int = 0 elif "business_logic_model_name" not in agent_info: agent_info["business_logic_model_name"] = None + if not agent_info.get("prompt_template_id"): + agent_info["prompt_template_id"] = SYSTEM_PROMPT_TEMPLATE_ID + if not agent_info.get("prompt_template_name"): + agent_info["prompt_template_name"] = SYSTEM_PROMPT_TEMPLATE_NAME + if agent_info.get("group_ids") is not None: agent_info["group_ids"] = convert_string_to_list(agent_info.get("group_ids")) @@ -805,6 +837,11 @@ async def get_creating_sub_agent_info_impl(authorization: str = Header(None)): async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = Header(None)): user_id, tenant_id, _ = get_current_user_info(authorization) + prompt_template_id, prompt_template_name = get_prompt_template_summary( + template_id=request.prompt_template_id, + tenant_id=tenant_id, + user_id=user_id, + ) # If agent_id is None, create a new agent; otherwise, update existing agent_id: Optional[int] = request.agent_id @@ -822,6 +859,8 @@ async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = "model_name": request.model_name, "business_logic_model_id": request.business_logic_model_id, "business_logic_model_name": request.business_logic_model_name, + "prompt_template_id": prompt_template_id, + "prompt_template_name": prompt_template_name, "max_steps": request.max_steps, "provide_run_summary": request.provide_run_summary, "duty_prompt": request.duty_prompt, @@ -834,6 +873,8 @@ async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = agent_id = created["agent_id"] else: # Update agent + request.prompt_template_id = prompt_template_id + request.prompt_template_name = prompt_template_name update_agent(agent_id, request, user_id) except Exception as e: logger.error(f"Failed to update agent info: {str(e)}") @@ -1189,7 +1230,9 @@ async def export_agent_by_agent_id(agent_id: int, tenant_id: str, user_id: str) model_id=model_id, model_name=model_display_name, business_logic_model_id=business_logic_model_id, - business_logic_model_name=business_logic_model_display_name) + business_logic_model_name=business_logic_model_display_name, + prompt_template_id=agent_info.get("prompt_template_id"), + prompt_template_name=agent_info.get("prompt_template_name")) return agent_info @@ -1322,6 +1365,8 @@ async def import_agent_by_agent_id( "model_name": import_agent_info.model_name, "business_logic_model_id": business_logic_model_id, "business_logic_model_name": import_agent_info.business_logic_model_name, + "prompt_template_id": import_agent_info.prompt_template_id or SYSTEM_PROMPT_TEMPLATE_ID, + "prompt_template_name": import_agent_info.prompt_template_name or SYSTEM_PROMPT_TEMPLATE_NAME, "max_steps": import_agent_info.max_steps, "provide_run_summary": import_agent_info.provide_run_summary, "duty_prompt": import_agent_info.duty_prompt, diff --git a/backend/services/agent_version_service.py b/backend/services/agent_version_service.py index 69163dbc6..58c2b8654 100644 --- a/backend/services/agent_version_service.py +++ b/backend/services/agent_version_service.py @@ -387,6 +387,11 @@ def rollback_version_impl( if not target_agent: raise ValueError(f"Agent snapshot for version {target_version_no} not found") + # Ensure the draft still exists before attempting an in-place restore. + draft_agent, _, _ = query_agent_draft(agent_id, tenant_id) + if not draft_agent: + raise ValueError("Agent draft not found") + # Get skill snapshots for target version from database import skill_db as skill_db_module target_skills = skill_db_module.query_skill_instances_by_agent_id( diff --git a/backend/services/prompt_service.py b/backend/services/prompt_service.py index 199521146..d2ecb287a 100644 --- a/backend/services/prompt_service.py +++ b/backend/services/prompt_service.py @@ -23,6 +23,7 @@ _generate_unique_agent_name_with_suffix, _generate_unique_display_name_with_suffix ) +from services.prompt_template_service import resolve_prompt_generate_template from utils.llm_utils import call_llm_for_system_prompt from utils.prompt_template_utils import ( get_prompt_generate_prompt_template, @@ -46,7 +47,7 @@ } -def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: str, user_id: str, tenant_id: str, language: str, tool_ids: Optional[List[int]] = None, sub_agent_ids: Optional[List[int]] = None, knowledge_base_display_names: Optional[List[str]] = None): +def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: str, user_id: str, tenant_id: str, language: str, prompt_template_id: Optional[int] = None, tool_ids: Optional[List[int]] = None, sub_agent_ids: Optional[List[int]] = None, knowledge_base_display_names: Optional[List[str]] = None): try: for system_prompt in generate_and_save_system_prompt_impl( agent_id=agent_id, @@ -55,6 +56,7 @@ def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: user_id=user_id, tenant_id=tenant_id, language=language, + prompt_template_id=prompt_template_id, tool_ids=tool_ids, sub_agent_ids=sub_agent_ids, knowledge_base_display_names=knowledge_base_display_names @@ -80,6 +82,7 @@ def generate_and_save_system_prompt_impl(agent_id: int, user_id: str, tenant_id: str, language: str, + prompt_template_id: Optional[int] = None, tool_ids: Optional[List[int]] = None, sub_agent_ids: Optional[List[int]] = None, knowledge_base_display_names: Optional[List[str]] = None): @@ -144,8 +147,17 @@ def generate_and_save_system_prompt_impl(agent_id: int, ] # Collect results and yield non-name fields immediately, but hold name fields for duplicate checking - for result_data in generate_system_prompt(sub_agent_info_list, task_description, tool_info_list, tenant_id, - model_id, language, knowledge_base_display_names): + for result_data in generate_system_prompt( + sub_agent_info_list, + task_description, + tool_info_list, + tenant_id, + user_id, + model_id, + language, + prompt_template_id, + knowledge_base_display_names, + ): result_type = result_data["type"] final_results[result_type] = result_data["content"] @@ -174,7 +186,9 @@ def generate_and_save_system_prompt_impl(agent_id: int, tenant_id=tenant_id, language=language, agents_cache=all_agents, - exclude_agent_id=agent_id + exclude_agent_id=agent_id, + prompt_template_id=prompt_template_id, + user_id=user_id, ) logger.info(f"Regenerated agent name: '{agent_name}'") final_results["agent_var_name"] = agent_name @@ -215,7 +229,9 @@ def generate_and_save_system_prompt_impl(agent_id: int, tenant_id=tenant_id, language=language, agents_cache=all_agents, - exclude_agent_id=agent_id + exclude_agent_id=agent_id, + prompt_template_id=prompt_template_id, + user_id=user_id, ) logger.info(f"Regenerated agent display_name: '{agent_display_name}'") final_results["agent_display_name"] = agent_display_name @@ -253,7 +269,6 @@ def generate_and_save_system_prompt_impl(agent_id: int, if not has_content: raise Exception("Failed to generate prompt content.") - def optimize_prompt_section_impl( agent_id: int, model_id: int, @@ -336,9 +351,15 @@ def optimize_prompt_section_impl( } -def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list, tenant_id: str, model_id: int, language: str = LANGUAGE["ZH"], knowledge_base_display_names: Optional[List[str]] = None): + +def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list, tenant_id: str, user_id: str, model_id: int, language: str = LANGUAGE["ZH"], prompt_template_id: Optional[int] = None, knowledge_base_display_names: Optional[List[str]] = None): """Main function for generating system prompts""" - prompt_for_generate = get_prompt_generate_prompt_template(language) + prompt_for_generate = resolve_prompt_generate_template( + tenant_id=tenant_id, + user_id=user_id, + language=language, + prompt_template_id=prompt_template_id, + ) # Prepare content for generating system prompts content = join_info_for_generate_system_prompt( @@ -451,15 +472,15 @@ def run_and_flag(tag, sys_prompt): logger.info("Generating system prompt") prompt_configs = [ - ("duty", prompt_for_generate["DUTY_SYSTEM_PROMPT"]), - ("constraint", prompt_for_generate["CONSTRAINT_SYSTEM_PROMPT"]), - ("few_shots", prompt_for_generate["FEW_SHOTS_SYSTEM_PROMPT"]), + ("duty", prompt_for_generate["duty_system_prompt"]), + ("constraint", prompt_for_generate["constraint_system_prompt"]), + ("few_shots", prompt_for_generate["few_shots_system_prompt"]), ("agent_var_name", - prompt_for_generate["AGENT_VARIABLE_NAME_SYSTEM_PROMPT"]), + prompt_for_generate["agent_variable_name_system_prompt"]), ("agent_display_name", - prompt_for_generate["AGENT_DISPLAY_NAME_SYSTEM_PROMPT"]), + prompt_for_generate["agent_display_name_system_prompt"]), ("agent_description", - prompt_for_generate["AGENT_DESCRIPTION_SYSTEM_PROMPT"]) + prompt_for_generate["agent_description_system_prompt"]) ] for tag, sys_prompt in prompt_configs: @@ -557,7 +578,7 @@ def join_info_for_generate_system_prompt(prompt_for_generate, sub_agent_info_lis template_context["knowledge_base_names"] = kb_names_str # Generate content using template - content = Template(prompt_for_generate["USER_PROMPT"], undefined=StrictUndefined).render(template_context) + content = Template(prompt_for_generate["user_prompt"], undefined=StrictUndefined).render(template_context) return content diff --git a/backend/services/prompt_template_service.py b/backend/services/prompt_template_service.py new file mode 100644 index 000000000..14224a099 --- /dev/null +++ b/backend/services/prompt_template_service.py @@ -0,0 +1,322 @@ +import logging +from typing import Optional + +from consts.const import DEFAULT_TENANT_ID, DEFAULT_USER_ID +from consts.const import LANGUAGE +from consts.exceptions import DuplicateError, NotFoundException, ValidationError +from consts.model import PromptTemplateRequest +from database.prompt_template_db import ( + create_prompt_template, + delete_prompt_template, + get_prompt_template_by_id, + get_prompt_template_by_name, + get_prompt_template_by_template_id, + query_prompt_templates_by_user, + upsert_prompt_template_by_id, + update_prompt_template, +) +from utils.prompt_template_utils import ( + get_prompt_generate_prompt_template, + merge_prompt_generate_templates, + normalize_prompt_generate_template_content, +) + +logger = logging.getLogger("prompt_template_service") + +SYSTEM_PROMPT_TEMPLATE_ID = 0 +SYSTEM_PROMPT_TEMPLATE_NAME = "system_default" +PROMPT_TEMPLATE_TYPE_AGENT_GENERATE = "agent_generate" +SYSTEM_PROMPT_TEMPLATE_DESCRIPTION = "System default prompt template" +SYSTEM_PROMPT_TEMPLATE_TENANT_ID = DEFAULT_TENANT_ID +SYSTEM_PROMPT_TEMPLATE_USER_ID = DEFAULT_USER_ID + + +def _normalize_prompt_template_entity(template: Optional[dict]) -> Optional[dict]: + """Normalize prompt template entity content keys to lowercase.""" + if not template: + return template + + normalized_template = dict(template) + normalized_template["template_content_zh"] = normalize_prompt_generate_template_content( + normalized_template.get("template_content_zh") + ) + template_content_en = normalize_prompt_generate_template_content( + normalized_template.get("template_content_en") + ) + normalized_template["template_content_en"] = template_content_en or None + return normalized_template + + +def build_system_default_prompt_template_payload() -> dict: + """Build the canonical system default prompt template payload from YAML files.""" + system_template_zh = normalize_prompt_generate_template_content( + get_prompt_generate_prompt_template(LANGUAGE["ZH"]) + ) + system_template_en = normalize_prompt_generate_template_content( + get_prompt_generate_prompt_template(LANGUAGE["EN"]) + ) + return { + "template_id": SYSTEM_PROMPT_TEMPLATE_ID, + "template_name": SYSTEM_PROMPT_TEMPLATE_NAME, + "description": SYSTEM_PROMPT_TEMPLATE_DESCRIPTION, + "template_type": PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + "tenant_id": SYSTEM_PROMPT_TEMPLATE_TENANT_ID, + "user_id": SYSTEM_PROMPT_TEMPLATE_USER_ID, + "template_content_zh": system_template_zh, + "template_content_en": system_template_en, + "created_by": SYSTEM_PROMPT_TEMPLATE_USER_ID, + "updated_by": SYSTEM_PROMPT_TEMPLATE_USER_ID, + "delete_flag": "N", + } + + +def sync_system_default_prompt_template() -> dict: + """Sync the YAML-backed system default prompt template into the database.""" + payload = build_system_default_prompt_template_payload() + prompt_template = upsert_prompt_template_by_id( + template_id=SYSTEM_PROMPT_TEMPLATE_ID, + template_data=payload, + user_id=SYSTEM_PROMPT_TEMPLATE_USER_ID, + ) + prompt_template["is_system_default"] = True + return _normalize_prompt_template_entity(prompt_template) + + +def get_system_default_prompt_template() -> dict: + """Return the system default prompt generation template from the database.""" + prompt_template = get_prompt_template_by_template_id( + template_id=SYSTEM_PROMPT_TEMPLATE_ID, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not prompt_template: + prompt_template = sync_system_default_prompt_template() + else: + prompt_template["is_system_default"] = True + return _normalize_prompt_template_entity({ + **prompt_template, + "is_system_default": True, + }) + + +def _normalize_template_request(request: PromptTemplateRequest) -> dict: + """Normalize prompt template request payload.""" + template_name = (request.template_name or "").strip() + if not template_name: + raise ValidationError("template_name is required") + + if request.template_type != PROMPT_TEMPLATE_TYPE_AGENT_GENERATE: + raise ValidationError("Unsupported template type") + + zh_content = normalize_prompt_generate_template_content( + request.template_content_zh.model_dump() + ) + if len(zh_content) == 0: + raise ValidationError("template_content_zh is required") + + en_content = None + if request.template_content_en is not None: + en_content = normalize_prompt_generate_template_content( + request.template_content_en.model_dump() + ) + if len(en_content) == 0: + en_content = None + + return { + "template_name": template_name, + "description": (request.description or "").strip() or None, + "template_type": request.template_type, + "template_content_zh": zh_content, + "template_content_en": en_content, + } + + +def list_prompt_templates_impl(tenant_id: str, user_id: str) -> list[dict]: + """List all prompt templates for the current user.""" + system_default_template = sync_system_default_prompt_template() + templates = query_prompt_templates_by_user( + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + return [system_default_template, *[ + _normalize_prompt_template_entity({ + **template, + "is_system_default": False, + }) + for template in templates + if template.get("template_id") != SYSTEM_PROMPT_TEMPLATE_ID + ]] + + +def get_prompt_template_detail_impl(template_id: int, tenant_id: str, user_id: str) -> dict: + """Get prompt template detail.""" + if template_id == SYSTEM_PROMPT_TEMPLATE_ID: + return get_system_default_prompt_template() + + template = get_prompt_template_by_id( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not template: + raise NotFoundException("Prompt template not found") + + template["is_system_default"] = False + return _normalize_prompt_template_entity(template) + + +def create_prompt_template_impl( + request: PromptTemplateRequest, + tenant_id: str, + user_id: str, +) -> dict: + """Create a prompt template.""" + normalized_request = _normalize_template_request(request) + existing_template = get_prompt_template_by_name( + template_name=normalized_request["template_name"], + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if existing_template: + raise DuplicateError("Prompt template name already exists") + + created_template = create_prompt_template({ + **normalized_request, + "tenant_id": tenant_id, + "user_id": user_id, + "created_by": user_id, + "updated_by": user_id, + }) + created_template["is_system_default"] = False + return _normalize_prompt_template_entity(created_template) + + +def update_prompt_template_impl( + template_id: int, + request: PromptTemplateRequest, + tenant_id: str, + user_id: str, +) -> dict: + """Update a prompt template.""" + if template_id == SYSTEM_PROMPT_TEMPLATE_ID: + raise ValidationError("System default prompt template cannot be updated") + + existing_template = get_prompt_template_by_id( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not existing_template: + raise NotFoundException("Prompt template not found") + + normalized_request = _normalize_template_request(request) + duplicate_template = get_prompt_template_by_name( + template_name=normalized_request["template_name"], + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if duplicate_template and duplicate_template["template_id"] != template_id: + raise DuplicateError("Prompt template name already exists") + + updated_template = update_prompt_template( + template_id=template_id, + template_data=normalized_request, + user_id=user_id, + ) + updated_template["is_system_default"] = False + return _normalize_prompt_template_entity(updated_template) + + +def delete_prompt_template_impl(template_id: int, tenant_id: str, user_id: str) -> dict: + """Delete a prompt template.""" + if template_id == SYSTEM_PROMPT_TEMPLATE_ID: + raise ValidationError("System default prompt template cannot be deleted") + + existing_template = get_prompt_template_by_id( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not existing_template: + raise NotFoundException("Prompt template not found") + + deleted_count = delete_prompt_template(template_id=template_id, user_id=user_id) + return { + "template_id": template_id, + "deleted": deleted_count > 0, + } + + +def resolve_prompt_generate_template( + tenant_id: str, + user_id: str, + language: str, + prompt_template_id: Optional[int] = None, +) -> dict: + """Resolve prompt generation template for the current user and language.""" + system_default_template = sync_system_default_prompt_template() + system_template = ( + system_default_template.get("template_content_en") + if language == LANGUAGE["EN"] + else system_default_template.get("template_content_zh") + ) + fallback_system_template = system_default_template.get("template_content_zh") + + if not prompt_template_id or prompt_template_id == SYSTEM_PROMPT_TEMPLATE_ID: + return merge_prompt_generate_templates(system_template, fallback_system_template) + + prompt_template = get_prompt_template_by_id( + template_id=prompt_template_id, + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not prompt_template: + logger.warning( + "Prompt template %s not found for tenant %s user %s, falling back to system default", + prompt_template_id, + tenant_id, + user_id, + ) + return merge_prompt_generate_templates(system_template, fallback_system_template) + + custom_language_template = ( + prompt_template.get("template_content_en") + if language == LANGUAGE["EN"] + else prompt_template.get("template_content_zh") + ) + return merge_prompt_generate_templates( + custom_language_template, + prompt_template.get("template_content_zh"), + system_template, + fallback_system_template, + ) + + +def get_prompt_template_summary( + template_id: Optional[int], + tenant_id: str, + user_id: str, +) -> tuple[Optional[int], Optional[str]]: + """Resolve prompt template identity for saving on agent.""" + if template_id is None: + return None, None + + if template_id == SYSTEM_PROMPT_TEMPLATE_ID: + return SYSTEM_PROMPT_TEMPLATE_ID, SYSTEM_PROMPT_TEMPLATE_NAME + + prompt_template = get_prompt_template_by_id( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not prompt_template: + raise NotFoundException("Prompt template not found") + + return prompt_template["template_id"], prompt_template["template_name"] diff --git a/backend/utils/llm_utils.py b/backend/utils/llm_utils.py index fb2e06fdb..ec790f08e 100644 --- a/backend/utils/llm_utils.py +++ b/backend/utils/llm_utils.py @@ -100,17 +100,21 @@ def call_llm_for_system_prompt( reasoning_content_seen = False content_tokens_seen = 0 for chunk in current_request: - if not hasattr(chunk, "choices"): - logger.warning("Received non-standard chunk without choices during prompt generation.") + choices = getattr(chunk, "choices", None) + if choices is None: + logger.warning("Received non-standard chunk without choices during prompt generation.") + continue + if not choices: + logger.debug("Received empty choices chunk during prompt generation; skipping.") + continue + + delta = getattr(choices[0], "delta", None) + if delta is None: + logger.debug("Skipping LLM stream chunk without delta") continue - - if not chunk.choices: - logger.debug("Received empty choices chunk during prompt generation; skipping.") - continue - - delta = chunk.choices[0].delta + reasoning_content = getattr(delta, "reasoning_content", None) - new_token = delta.content + new_token = getattr(delta, "content", None) # Note: reasoning_content is separate metadata and doesn't affect content filtering # We only filter content based on tags in delta.content diff --git a/backend/utils/prompt_template_utils.py b/backend/utils/prompt_template_utils.py index 3cd267a10..8822e5fd4 100644 --- a/backend/utils/prompt_template_utils.py +++ b/backend/utils/prompt_template_utils.py @@ -5,9 +5,56 @@ import yaml from consts.const import LANGUAGE +from consts.prompt_template import ( + PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP, + PROMPT_GENERATE_TEMPLATE_FIELDS, +) logger = logging.getLogger("prompt_template_utils") +PROMPT_GENERATE_TEMPLATE_KEY_MAP = PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP +PROMPT_GENERATE_TEMPLATE_KEYS = PROMPT_GENERATE_TEMPLATE_FIELDS + + +def get_prompt_generate_template_keys() -> list[str]: + """Return the supported prompt generation template keys.""" + return list(PROMPT_GENERATE_TEMPLATE_FIELDS) + + +def normalize_prompt_generate_template_content( + template_content: Optional[Dict[str, Any]] +) -> Dict[str, str]: + """Normalize prompt generation template content and keep non-empty fields only.""" + normalized: Dict[str, str] = {} + if not isinstance(template_content, dict): + return normalized + + for key in PROMPT_GENERATE_TEMPLATE_FIELDS: + legacy_key = PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP[key] + value = template_content.get(key) + if value is None: + value = template_content.get(legacy_key) + if isinstance(value, str) and value.strip(): + normalized[key] = value + + return normalized + + +def merge_prompt_generate_templates( + *template_contents: Optional[Dict[str, Any]] +) -> Dict[str, str]: + """Merge multiple prompt generation templates with first-non-empty priority.""" + merged: Dict[str, str] = {} + + for template_content in template_contents: + normalized = normalize_prompt_generate_template_content(template_content) + for key in PROMPT_GENERATE_TEMPLATE_FIELDS: + value = normalized.get(key) + if value and key not in merged: + merged[key] = value + + return merged + def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kwargs) -> Dict[str, Any]: """ diff --git a/docker/init.sql b/docker/init.sql index 2e494fc72..ed45026ab 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -320,6 +320,8 @@ CREATE TABLE IF NOT EXISTS nexent.ag_tenant_agent_t ( model_id INTEGER, business_logic_model_name VARCHAR(100), business_logic_model_id INTEGER, + prompt_template_id INTEGER, + prompt_template_name VARCHAR(100), max_steps INTEGER, duty_prompt TEXT, constraint_prompt TEXT, @@ -370,6 +372,8 @@ COMMENT ON COLUMN nexent.ag_tenant_agent_t.model_name IS '[DEPRECATED] Name of t COMMENT ON COLUMN nexent.ag_tenant_agent_t.model_id IS 'Model ID, foreign key reference to model_record_t.model_id'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.business_logic_model_name IS 'Model name used for business logic prompt generation'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.business_logic_model_id IS 'Model ID used for business logic prompt generation, foreign key reference to model_record_t.model_id'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.prompt_template_id IS 'Prompt template ID used for business logic prompt generation'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.prompt_template_name IS 'Prompt template name used for business logic prompt generation'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.max_steps IS 'Maximum number of steps'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.duty_prompt IS 'Duty prompt'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.constraint_prompt IS 'Constraint prompt'; @@ -393,8 +397,98 @@ COMMENT ON COLUMN nexent.ag_tenant_agent_t.enable_context_manager IS 'Whether to -- Create index for is_new queries CREATE INDEX IF NOT EXISTS idx_ag_tenant_agent_t_is_new ON nexent.ag_tenant_agent_t (tenant_id, is_new) + +CREATE TABLE IF NOT EXISTS nexent.ag_prompt_template_t ( + template_id SERIAL PRIMARY KEY, + template_name VARCHAR(100) NOT NULL, + description VARCHAR(500), + template_type VARCHAR(50) NOT NULL DEFAULT 'agent_generate', + tenant_id VARCHAR(100) NOT NULL, + user_id VARCHAR(100) NOT NULL, + template_content_zh JSONB NOT NULL, + template_content_en JSONB, + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +ALTER TABLE nexent.ag_prompt_template_t OWNER TO "root"; + +CREATE OR REPLACE FUNCTION update_ag_prompt_template_update_time() +RETURNS TRIGGER AS $$ +BEGIN + NEW.update_time = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER update_ag_prompt_template_update_time_trigger +BEFORE UPDATE ON nexent.ag_prompt_template_t +FOR EACH ROW +EXECUTE FUNCTION update_ag_prompt_template_update_time(); + +COMMENT ON TABLE nexent.ag_prompt_template_t IS 'Prompt template table for user-defined business logic generation prompts'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_id IS 'Prompt template ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_name IS 'Prompt template name'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.description IS 'Prompt template description'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_type IS 'Prompt template type'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.tenant_id IS 'Tenant ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.user_id IS 'User ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_content_zh IS 'Chinese prompt template content'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_content_en IS 'English prompt template content'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.create_time IS 'Creation time'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.created_by IS 'Creator'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.updated_by IS 'Updater'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.delete_flag IS 'Whether it is deleted. Optional values: Y/N'; + +CREATE UNIQUE INDEX IF NOT EXISTS uq_prompt_template_user_name_active +ON nexent.ag_prompt_template_t (tenant_id, user_id, template_name) WHERE delete_flag = 'N'; +CREATE INDEX IF NOT EXISTS idx_ag_prompt_template_t_user +ON nexent.ag_prompt_template_t (tenant_id, user_id, template_type) +WHERE delete_flag = 'N'; + +INSERT INTO nexent.ag_prompt_template_t ( + template_id, + template_name, + description, + template_type, + tenant_id, + user_id, + template_content_zh, + template_content_en, + created_by, + updated_by, + delete_flag +) +VALUES ( + 0, + 'system_default', + 'System default prompt template', + 'agent_generate', + 'tenant_id', + 'user_id', + '{}'::jsonb, + '{}'::jsonb, + 'user_id', + 'user_id', + 'N' +) +ON CONFLICT (template_id) DO UPDATE SET + template_name = EXCLUDED.template_name, + description = EXCLUDED.description, + template_type = EXCLUDED.template_type, + tenant_id = EXCLUDED.tenant_id, + user_id = EXCLUDED.user_id, + template_content_zh = EXCLUDED.template_content_zh, + template_content_en = EXCLUDED.template_content_en, + updated_by = EXCLUDED.updated_by, + delete_flag = 'N'; + -- Create the ag_tool_instance_t table in the nexent schema CREATE TABLE IF NOT EXISTS nexent.ag_tool_instance_t ( diff --git a/docker/sql/v2.1.0_0503_add_prompt_template_t.sql b/docker/sql/v2.1.0_0503_add_prompt_template_t.sql new file mode 100644 index 000000000..3db9a9701 --- /dev/null +++ b/docker/sql/v2.1.0_0503_add_prompt_template_t.sql @@ -0,0 +1,115 @@ +-- Migration: Add prompt template table and agent prompt template fields +-- Date: 2026-05-03 +-- Description: Add user-scoped prompt template storage and bind selected prompt template to agents + +ALTER TABLE nexent.ag_tenant_agent_t +ADD COLUMN IF NOT EXISTS prompt_template_id INTEGER; + +ALTER TABLE nexent.ag_tenant_agent_t +ADD COLUMN IF NOT EXISTS prompt_template_name VARCHAR(100); + +COMMENT ON COLUMN nexent.ag_tenant_agent_t.prompt_template_id IS 'Prompt template ID used for business logic prompt generation'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.prompt_template_name IS 'Prompt template name used for business logic prompt generation'; + +UPDATE nexent.ag_tenant_agent_t +SET prompt_template_id = 0, + prompt_template_name = 'system_default' +WHERE delete_flag = 'N' + AND (prompt_template_id IS NULL OR prompt_template_name IS NULL); + +CREATE TABLE IF NOT EXISTS nexent.ag_prompt_template_t ( + template_id SERIAL PRIMARY KEY, + template_name VARCHAR(100) NOT NULL, + description VARCHAR(500), + template_type VARCHAR(50) NOT NULL DEFAULT 'agent_generate', + tenant_id VARCHAR(100) NOT NULL, + user_id VARCHAR(100) NOT NULL, + template_content_zh JSONB NOT NULL, + template_content_en JSONB, + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +ALTER TABLE nexent.ag_prompt_template_t OWNER TO "root"; + +CREATE OR REPLACE FUNCTION update_ag_prompt_template_update_time() +RETURNS TRIGGER AS $$ +BEGIN + NEW.update_time = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +DROP TRIGGER IF EXISTS update_ag_prompt_template_update_time_trigger ON nexent.ag_prompt_template_t; + +CREATE TRIGGER update_ag_prompt_template_update_time_trigger +BEFORE UPDATE ON nexent.ag_prompt_template_t +FOR EACH ROW +EXECUTE FUNCTION update_ag_prompt_template_update_time(); + +ALTER TABLE nexent.ag_prompt_template_t +DROP CONSTRAINT IF EXISTS uq_prompt_template_user_name; + +COMMENT ON TABLE nexent.ag_prompt_template_t IS 'Prompt template table for user-defined business logic generation prompts'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_id IS 'Prompt template ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_name IS 'Prompt template name'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.description IS 'Prompt template description'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_type IS 'Prompt template type'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.tenant_id IS 'Tenant ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.user_id IS 'User ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_content_zh IS 'Chinese prompt template content'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_content_en IS 'English prompt template content'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.create_time IS 'Creation time'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.created_by IS 'Creator'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.updated_by IS 'Updater'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.delete_flag IS 'Whether it is deleted. Optional values: Y/N'; + +DROP INDEX IF EXISTS nexent.uq_prompt_template_user_name_active; +CREATE UNIQUE INDEX IF NOT EXISTS uq_prompt_template_user_name_active +ON nexent.ag_prompt_template_t (tenant_id, user_id, template_name) +WHERE delete_flag = 'N'; + +CREATE INDEX IF NOT EXISTS idx_ag_prompt_template_t_user +ON nexent.ag_prompt_template_t (tenant_id, user_id, template_type) +WHERE delete_flag = 'N'; + +INSERT INTO nexent.ag_prompt_template_t ( + template_id, + template_name, + description, + template_type, + tenant_id, + user_id, + template_content_zh, + template_content_en, + created_by, + updated_by, + delete_flag +) +VALUES ( + 0, + 'system_default', + 'System default prompt template', + 'agent_generate', + 'tenant_id', + 'user_id', + '{}'::jsonb, + '{}'::jsonb, + 'user_id', + 'user_id', + 'N' +) +ON CONFLICT (template_id) DO UPDATE SET + template_name = EXCLUDED.template_name, + description = EXCLUDED.description, + template_type = EXCLUDED.template_type, + tenant_id = EXCLUDED.tenant_id, + user_id = EXCLUDED.user_id, + template_content_zh = EXCLUDED.template_content_zh, + template_content_en = EXCLUDED.template_content_en, + updated_by = EXCLUDED.updated_by, + delete_flag = 'N'; diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index db2667535..0f03c8580 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -17,10 +17,14 @@ import { App, } from "antd"; import type { TabsProps } from "antd"; -import { Sparkles, Zap, Maximize2 } from "lucide-react"; +import { Zap, Maximize2, Settings2, Sparkles } from "lucide-react"; import log from "@/lib/logger"; -import { AgentProfileInfo, AgentBusinessInfo } from "@/types/agentConfig"; +import { + AgentProfileInfo, + AgentBusinessInfo, + PromptTemplate, +} from "@/types/agentConfig"; import { getAgentGenerationCache, setAgentGenerationStatus, @@ -39,10 +43,12 @@ import { useModelList } from "@/hooks/model/useModelList"; import { useConfig } from "@/hooks/useConfig"; import { useTenantList } from "@/hooks/tenant/useTenantList"; import { useGroupList } from "@/hooks/group/useGroupList"; +import { usePromptTemplateList } from "@/hooks/agent/usePromptTemplateList"; import { USER_ROLES } from "@/const/auth"; import { Can } from "@/components/permission/Can"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; import ExpandEditModal from "./ExpandEditModal"; +import PromptTemplateManagerModal from "./PromptTemplateManagerModal"; import PromptOptimizeModal from "./PromptOptimizeModal"; const { TextArea } = Input; @@ -75,6 +81,11 @@ export default function AgentGenerateDetail({ // Model data: default LLM name from config, resolve to full model from model list const { defaultLlmModelName } = useConfig(); const { availableLlmModels, models, isLoading: loadingModels } = useModelList(); + const { + templates: promptTemplates, + isLoading: loadingPromptTemplates, + invalidate: invalidatePromptTemplates, + } = usePromptTemplateList(); const defaultLlmModel = useMemo(() => { if (defaultLlmModelName) { const found = availableLlmModels.find( @@ -116,6 +127,7 @@ export default function AgentGenerateDetail({ // Modal states const [expandModalOpen, setExpandModalOpen] = useState(false); const [expandModalType, setExpandModalType] = useState<'duty' | 'constraint' | 'few-shots' | null>(null); + const [promptTemplateManagerOpen, setPromptTemplateManagerOpen] = useState(false); const [optimizeModalOpen, setOptimizeModalOpen] = useState(false); const [optimizeModalType, setOptimizeModalType] = useState<'duty' | 'constraint' | 'few-shots' | null>(null); @@ -136,14 +148,24 @@ export default function AgentGenerateDetail({ useEffect(() => { if (editedAgent.business_description !== businessInfo.businessDescription || editedAgent.business_logic_model_name !== businessInfo.businessLogicModelName || - editedAgent.business_logic_model_id !== businessInfo.businessLogicModelId) { + editedAgent.business_logic_model_id !== businessInfo.businessLogicModelId || + (editedAgent.prompt_template_id ?? 0) !== businessInfo.promptTemplateId || + (editedAgent.prompt_template_name || "system_default") !== businessInfo.promptTemplateName) { setBusinessInfo({ businessDescription: editedAgent.business_description || "", businessLogicModelName: editedAgent.business_logic_model_name || "", businessLogicModelId: editedAgent.business_logic_model_id || 0, + promptTemplateId: editedAgent.prompt_template_id ?? 0, + promptTemplateName: editedAgent.prompt_template_name || "system_default", }); } - }, [editedAgent.business_description, editedAgent.business_logic_model_name, editedAgent.business_logic_model_id]); + }, [ + editedAgent.business_description, + editedAgent.business_logic_model_name, + editedAgent.business_logic_model_id, + editedAgent.prompt_template_id, + editedAgent.prompt_template_name, + ]); // Only show "no edit permission" tooltip when the panel is active and agent is read-only. // Note: when no agent is selected, AgentInfoComp shows an overlay and we should not show @@ -197,6 +219,8 @@ export default function AgentGenerateDetail({ businessDescription: "", businessLogicModelName: "", businessLogicModelId: 0, + promptTemplateId: 0, + promptTemplateName: "system_default", }); const normalizeNumberArray = (value: unknown): number[] => { @@ -289,6 +313,8 @@ export default function AgentGenerateDetail({ "", businessLogicModelId: editedAgent.business_logic_model_id || defaultLlmModel?.id || 0, + promptTemplateId: editedAgent.prompt_template_id ?? 0, + promptTemplateName: editedAgent.prompt_template_name || "system_default", }; // Initialize local business description state setBusinessInfo(initialBusinessInfo); @@ -403,6 +429,8 @@ export default function AgentGenerateDetail({ business_description: value, business_logic_model_id: businessInfo.businessLogicModelId, business_logic_model_name: businessInfo.businessLogicModelName, + prompt_template_id: businessInfo.promptTemplateId, + prompt_template_name: businessInfo.promptTemplateName, }); }; @@ -421,6 +449,34 @@ export default function AgentGenerateDetail({ business_description: businessInfo.businessDescription || "", business_logic_model_id: selectedModel?.id || 0, business_logic_model_name: modelName, + prompt_template_id: businessInfo.promptTemplateId, + prompt_template_name: businessInfo.promptTemplateName, + }); + }; + + const handlePromptTemplateChange = (templateId: number) => { + const selectedTemplate = promptTemplates.find( + (template) => template.template_id === templateId + ); + if (!selectedTemplate) { + return; + } + handleSelectPromptTemplate(selectedTemplate); + }; + + const handleSelectPromptTemplate = (template: PromptTemplate) => { + setBusinessInfo((prev) => ({ + ...prev, + promptTemplateId: template.template_id, + promptTemplateName: template.template_name, + })); + + updateBusinessInfo({ + business_description: businessInfo.businessDescription || "", + business_logic_model_id: businessInfo.businessLogicModelId, + business_logic_model_name: businessInfo.businessLogicModelName, + prompt_template_id: template.template_id, + prompt_template_name: template.template_name, }); }; @@ -756,6 +812,7 @@ export default function AgentGenerateDetail({ agent_id: effectiveAgentId, task_description: businessInfo.businessDescription, model_id: businessInfo.businessLogicModelId.toString(), + prompt_template_id: businessInfo.promptTemplateId, sub_agent_ids: editedAgent.sub_agent_id_list, tool_ids: Array.isArray(editedAgent.tools) ? editedAgent.tools.map((tool: any) => @@ -968,6 +1025,47 @@ export default function AgentGenerateDetail({ disabled: model.connect_status !== "available", })); + const promptTemplateSelectOptions = useMemo(() => { + const options = promptTemplates.map((template) => ({ + value: template.template_id, + label: template.is_system_default + ? t("businessLogic.config.template.systemDefault") + : template.template_name, + })); + + if ( + businessInfo.promptTemplateId && + !options.some((option) => option.value === businessInfo.promptTemplateId) + ) { + options.unshift({ + value: businessInfo.promptTemplateId, + label: businessInfo.promptTemplateName || t("businessLogic.config.template.label"), + }); + } + + return options; + }, [ + businessInfo.promptTemplateId, + businessInfo.promptTemplateName, + promptTemplates, + t, + ]); + + const generationControlSelectStyle = { + width: "min(300px, 100%)", + minWidth: "220px", + maxWidth: "300px", + overflow: "hidden", + textOverflow: "ellipsis", + whiteSpace: "nowrap", + }; + + const generationControlLabelStyle = { + width: 84, + minWidth: 84, + flexShrink: 0, + }; + // Tab items configuration const tabItems = [ { @@ -1286,46 +1384,77 @@ export default function AgentGenerateDetail({ )} {/* Control area */} - -
- - {t("model.type.llm")}: - - } + disabled={!editable || isGenerating} + style={generationControlSelectStyle} + /> +
+
+ {wrapNoEditTooltipInline( + + )} +
+
+ + +
+ - - {isGenerating - ? t("businessLogic.config.button.generating") - : t("businessLogic.config.button.generatePrompt")} - - - )} -
+ {t("model.type.llm")}: + + option.value === selectedTemplateId)?.label + } + disabled + style={{ flex: 1, minWidth: 220 }} + /> +
+ + { + const isSelected = selectedTemplateId === template.template_id; + const isSystemDefault = template.is_system_default; + return ( + + + + + + + {isSystemDefault + ? t("businessLogic.config.template.systemDefault") + : template.template_name} + + {isSystemDefault ? ( + + {t("businessLogic.config.template.system")} + + ) : null} + {isSelected ? ( + + {t("businessLogic.config.template.current")} + + ) : null} + + + {template.description || t("businessLogic.config.template.noDescription")} + + + + + + + + + + + + + ); + }} + /> + + + + + + + + {t("businessLogic.config.template.manageDescription")} + + + +
+ + + + + + + + + + +
+
+ + ); +} diff --git a/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx b/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx index edfeff559..a1d809b50 100644 --- a/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx +++ b/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx @@ -259,6 +259,8 @@ export default function AgentList({ few_shots_prompt: detail.few_shots_prompt, business_logic_model_name: detail.business_logic_model_name ?? undefined, business_logic_model_id: detail.business_logic_model_id ?? undefined, + prompt_template_id: detail.prompt_template_id ?? 0, + prompt_template_name: detail.prompt_template_name ?? "system_default", enabled_tool_ids: enabledToolIds, related_agent_ids: subAgentIds, }); diff --git a/frontend/const/promptTemplate.ts b/frontend/const/promptTemplate.ts new file mode 100644 index 000000000..aada2371e --- /dev/null +++ b/frontend/const/promptTemplate.ts @@ -0,0 +1,82 @@ +export const PROMPT_TEMPLATE_FIELD_CONFIG = [ + { + key: "duty_system_prompt", + labelKey: "systemPrompt.card.duty.title", + section: "basic", + }, + { + key: "constraint_system_prompt", + labelKey: "systemPrompt.card.constraint.title", + section: "basic", + }, + { + key: "few_shots_system_prompt", + labelKey: "systemPrompt.card.fewShots.title", + section: "basic", + }, + { + key: "user_prompt", + labelKey: "businessLogic.config.template.field.userPrompt", + section: "basic", + }, + { + key: "agent_variable_name_system_prompt", + labelKey: "businessLogic.config.template.field.agentVariableName", + section: "advanced", + }, + { + key: "agent_display_name_system_prompt", + labelKey: "businessLogic.config.template.field.agentDisplayName", + section: "advanced", + }, + { + key: "agent_description_system_prompt", + labelKey: "businessLogic.config.template.field.agentDescription", + section: "advanced", + }, + { + key: "agent_name_regenerate_system_prompt", + labelKey: "businessLogic.config.template.field.agentNameRegenerateSystem", + section: "advanced", + }, + { + key: "agent_name_regenerate_user_prompt", + labelKey: "businessLogic.config.template.field.agentNameRegenerateUser", + section: "advanced", + }, + { + key: "agent_display_name_regenerate_system_prompt", + labelKey: "businessLogic.config.template.field.agentDisplayNameRegenerateSystem", + section: "advanced", + }, + { + key: "agent_display_name_regenerate_user_prompt", + labelKey: "businessLogic.config.template.field.agentDisplayNameRegenerateUser", + section: "advanced", + }, +] as const; + +export type PromptTemplateFieldConfig = (typeof PROMPT_TEMPLATE_FIELD_CONFIG)[number]; +export type PromptTemplateFieldKey = PromptTemplateFieldConfig["key"]; + +export const PROMPT_TEMPLATE_FIELD_KEYS = PROMPT_TEMPLATE_FIELD_CONFIG.map( + (field) => field.key +) as PromptTemplateFieldKey[]; + +export const BASIC_PROMPT_TEMPLATE_FIELDS = PROMPT_TEMPLATE_FIELD_CONFIG.filter( + (field) => field.section === "basic" +); + +export const ADVANCED_PROMPT_TEMPLATE_FIELDS = PROMPT_TEMPLATE_FIELD_CONFIG.filter( + (field) => field.section === "advanced" +); + +export function createEmptyPromptTemplateContent(): Record { + return PROMPT_TEMPLATE_FIELD_KEYS.reduce( + (content, key) => { + content[key] = ""; + return content; + }, + {} as Record + ); +} diff --git a/frontend/hooks/agent/usePromptTemplateList.ts b/frontend/hooks/agent/usePromptTemplateList.ts new file mode 100644 index 000000000..592776b7c --- /dev/null +++ b/frontend/hooks/agent/usePromptTemplateList.ts @@ -0,0 +1,22 @@ +import { useQuery, useQueryClient } from "@tanstack/react-query"; + +import { promptTemplateService } from "@/services/promptTemplateService"; +import { PromptTemplate } from "@/types/agentConfig"; + +export function usePromptTemplateList() { + const queryClient = useQueryClient(); + + const query = useQuery({ + queryKey: ["promptTemplates"], + queryFn: async (): Promise => { + return promptTemplateService.list(); + }, + staleTime: 60_000, + }); + + return { + ...query, + templates: query.data ?? [], + invalidate: () => queryClient.invalidateQueries({ queryKey: ["promptTemplates"] }), + }; +} diff --git a/frontend/hooks/agent/useSaveGuard.ts b/frontend/hooks/agent/useSaveGuard.ts index 1f3e82783..6d948deff 100644 --- a/frontend/hooks/agent/useSaveGuard.ts +++ b/frontend/hooks/agent/useSaveGuard.ts @@ -142,6 +142,8 @@ export const useSaveGuard = () => { few_shots_prompt: currentEditedAgent.few_shots_prompt, business_logic_model_name: currentEditedAgent.business_logic_model_name ?? undefined, business_logic_model_id: currentEditedAgent.business_logic_model_id ?? undefined, + prompt_template_id: currentEditedAgent.prompt_template_id ?? 0, + prompt_template_name: currentEditedAgent.prompt_template_name ?? "system_default", enabled_tool_ids: enabledToolIds, enabled_skill_ids: enabledSkillIds, related_agent_ids: relatedAgentIds, diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 48f85641c..d51367587 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -947,6 +947,43 @@ "businessLogic.config.message.agentDeleteSuccess": "Agent delete success", "businessLogic.config.message.agentDeleteFailed": "Agent delete failed", "businessLogic.config.message.agentSaveSuccess": "Agent save success", + "businessLogic.config.template.label": "Prompt Template", + "businessLogic.config.template.manage": "Manage Templates", + "businessLogic.config.template.manageDescription": "Choose a prompt template for generation, or create your own private templates.", + "businessLogic.config.template.create": "New Template", + "businessLogic.config.template.use": "Use", + "businessLogic.config.template.current": "Current", + "businessLogic.config.template.system": "System", + "businessLogic.config.template.systemDefault": "System Default", + "businessLogic.config.template.empty": "No prompt templates", + "businessLogic.config.template.noDescription": "No description", + "businessLogic.config.template.name": "Template Name", + "businessLogic.config.template.nameRequired": "Please enter a template name", + "businessLogic.config.template.description": "Description", + "businessLogic.config.template.language.zh": "Chinese Template", + "businessLogic.config.template.language.en": "English Template", + "businessLogic.config.template.contentRequired": "This field is required", + "businessLogic.config.template.basicSection": "Basic Configuration", + "businessLogic.config.template.basicDescription": "Configure the core prompts users most often care about. The remaining prompt segments can be adjusted in Advanced Configuration.", + "businessLogic.config.template.englishOptionalDescription": "English content is optional. Leave it blank to fall back to the Chinese template during generation.", + "businessLogic.config.template.advancedSection": "Advanced Configuration", + "businessLogic.config.template.advancedDescription": "These fields are still stored with the template and are suitable for fine-grained control of naming and regeneration behavior.", + "businessLogic.config.template.createTitle": "Create Prompt Template", + "businessLogic.config.template.editTitle": "Edit Prompt Template", + "businessLogic.config.template.saveSuccess": "Prompt template saved successfully", + "businessLogic.config.template.saveError": "Failed to save prompt template", + "businessLogic.config.template.deleteSuccess": "Prompt template deleted successfully", + "businessLogic.config.template.deleteError": "Failed to delete prompt template", + "businessLogic.config.template.deleteConfirm": "Are you sure you want to delete prompt template {{name}}?", + "businessLogic.config.template.loadError": "Failed to load prompt template", + "businessLogic.config.template.field.agentVariableName": "Agent Variable Name Prompt", + "businessLogic.config.template.field.agentDisplayName": "Agent Display Name Prompt", + "businessLogic.config.template.field.agentDescription": "Agent Description Prompt", + "businessLogic.config.template.field.userPrompt": "User Prompt", + "businessLogic.config.template.field.agentNameRegenerateSystem": "Agent Name Regenerate System Prompt", + "businessLogic.config.template.field.agentNameRegenerateUser": "Agent Name Regenerate User Prompt", + "businessLogic.config.template.field.agentDisplayNameRegenerateSystem": "Agent Display Name Regenerate System Prompt", + "businessLogic.config.template.field.agentDisplayNameRegenerateUser": "Agent Display Name Regenerate User Prompt", "businessLogic.config.import.duplicateTitle": "Duplicate Agent detected", "businessLogic.config.import.duplicateDescription": "The imported Agent name or display name conflicts with an existing Agent. You can choose to import directly or call the LLM to regenerate a unique name before importing.", "businessLogic.config.import.duplicateConfirm": "Regenerate and import", @@ -1868,6 +1905,7 @@ "common.loading": "Loading", "common.save": "Save", "common.cancel": "Cancel", + "common.close": "Close", "common.confirm": "Confirm", "common.skip": "Skip", "common.saving": "Saving...", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 48c5ed9c5..7cbff0fa2 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -938,6 +938,43 @@ "businessLogic.config.message.agentDeleteSuccess": "智能体删除成功", "businessLogic.config.message.agentDeleteFailed": "智能体删除失败", "businessLogic.config.message.agentSaveSuccess": "智能体保存成功", + "businessLogic.config.template.label": "提示词模板", + "businessLogic.config.template.manage": "管理模板", + "businessLogic.config.template.manageDescription": "选择用于生成的提示词模板,或创建仅自己可见的私有模板。", + "businessLogic.config.template.create": "新建模板", + "businessLogic.config.template.use": "使用", + "businessLogic.config.template.current": "当前使用", + "businessLogic.config.template.system": "系统", + "businessLogic.config.template.systemDefault": "系统默认", + "businessLogic.config.template.empty": "暂无提示词模板", + "businessLogic.config.template.noDescription": "暂无描述", + "businessLogic.config.template.name": "模板名称", + "businessLogic.config.template.nameRequired": "请输入模板名称", + "businessLogic.config.template.description": "模板描述", + "businessLogic.config.template.language.zh": "中文模板", + "businessLogic.config.template.language.en": "英文模板", + "businessLogic.config.template.contentRequired": "该字段不能为空", + "businessLogic.config.template.basicSection": "基础配置", + "businessLogic.config.template.basicDescription": "默认展示用户最常调整的核心提示词,其余提示词片段可在高级配置中继续编辑。", + "businessLogic.config.template.englishOptionalDescription": "英文内容为选填,留空时生成阶段会回退使用中文模板。", + "businessLogic.config.template.advancedSection": "高级配置", + "businessLogic.config.template.advancedDescription": "这些字段也会随模板一并入库,适合精细控制名称生成和重生成行为。", + "businessLogic.config.template.createTitle": "新建提示词模板", + "businessLogic.config.template.editTitle": "编辑提示词模板", + "businessLogic.config.template.saveSuccess": "提示词模板保存成功", + "businessLogic.config.template.saveError": "提示词模板保存失败", + "businessLogic.config.template.deleteSuccess": "提示词模板删除成功", + "businessLogic.config.template.deleteError": "提示词模板删除失败", + "businessLogic.config.template.deleteConfirm": "确定要删除提示词模板 {{name}} 吗?", + "businessLogic.config.template.loadError": "加载提示词模板失败", + "businessLogic.config.template.field.agentVariableName": "智能体变量名提示词", + "businessLogic.config.template.field.agentDisplayName": "智能体展示名提示词", + "businessLogic.config.template.field.agentDescription": "智能体描述提示词", + "businessLogic.config.template.field.userPrompt": "用户提示词", + "businessLogic.config.template.field.agentNameRegenerateSystem": "变量名重生成系统提示词", + "businessLogic.config.template.field.agentNameRegenerateUser": "变量名重生成用户提示词", + "businessLogic.config.template.field.agentDisplayNameRegenerateSystem": "展示名重生成系统提示词", + "businessLogic.config.template.field.agentDisplayNameRegenerateUser": "展示名重生成用户提示词", "businessLogic.config.import.duplicateTitle": "检测到重名智能体", "businessLogic.config.import.duplicateDescription": "导入的智能体名称或展示名称与已有智能体重复。您可以选择直接导入或调用 LLM 重新生成唯一名称后导入。", "businessLogic.config.import.duplicateConfirm": "重新生成并导入", @@ -1915,6 +1952,7 @@ "common.loading": "加载中", "common.save": "保存", "common.cancel": "取消", + "common.close": "关闭", "common.confirm": "确定", "common.skip": "跳过", "common.saving": "保存中...", diff --git a/frontend/services/agentConfigService.ts b/frontend/services/agentConfigService.ts index 926096903..1bbffbd38 100644 --- a/frontend/services/agentConfigService.ts +++ b/frontend/services/agentConfigService.ts @@ -401,6 +401,8 @@ export interface UpdateAgentInfoPayload { business_description?: string; business_logic_model_name?: string; business_logic_model_id?: number; + prompt_template_id?: number; + prompt_template_name?: string; enabled_tool_ids?: number[]; enabled_skill_ids?: number[]; related_agent_ids?: number[]; @@ -698,6 +700,8 @@ export const searchAgentInfo = async (agentId: number, tenantId?: string, versio business_description: data.business_description, business_logic_model_name: data.business_logic_model_name, business_logic_model_id: data.business_logic_model_id, + prompt_template_id: data.prompt_template_id ?? 0, + prompt_template_name: data.prompt_template_name ?? "system_default", provide_run_summary: data.provide_run_summary, enabled: data.enabled, is_available: data.is_available, diff --git a/frontend/services/api.ts b/frontend/services/api.ts index b441ff2e0..b6ea75a6d 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -87,6 +87,13 @@ export const API_ENDPOINTS = { generate: `${API_BASE_URL}/prompt/generate`, optimize: `${API_BASE_URL}/prompt/optimize`, }, + promptTemplates: { + list: `${API_BASE_URL}/prompt_templates`, + detail: (templateId: number) => `${API_BASE_URL}/prompt_templates/${templateId}`, + create: `${API_BASE_URL}/prompt_templates`, + update: (templateId: number) => `${API_BASE_URL}/prompt_templates/${templateId}`, + delete: (templateId: number) => `${API_BASE_URL}/prompt_templates/${templateId}`, + }, stt: { ws: `/api/voice/stt/ws`, }, diff --git a/frontend/services/promptTemplateService.ts b/frontend/services/promptTemplateService.ts new file mode 100644 index 000000000..c88275ae1 --- /dev/null +++ b/frontend/services/promptTemplateService.ts @@ -0,0 +1,90 @@ +import { API_ENDPOINTS, fetchWithErrorHandling } from "./api"; + +import { getAuthHeaders } from "@/lib/auth"; +import log from "@/lib/logger"; +import { + PromptTemplate, + PromptTemplatePayload, +} from "@/types/agentConfig"; + +async function requestJson(url: string, options: RequestInit = {}): Promise { + const response = await fetchWithErrorHandling(url, { + ...options, + headers: { + ...getAuthHeaders(), + ...(options.headers || {}), + }, + }); + return response.json(); +} + +export const promptTemplateService = { + async list(): Promise { + try { + const data = await requestJson(API_ENDPOINTS.promptTemplates.list, { + method: "GET", + }); + return data || []; + } catch (error) { + log.error("Failed to list prompt templates:", error); + return []; + } + }, + + async detail(templateId: number): Promise { + try { + const data = await requestJson( + API_ENDPOINTS.promptTemplates.detail(templateId), + { method: "GET" } + ); + return data; + } catch (error) { + log.error("Failed to get prompt template detail:", error); + return null; + } + }, + + async create(payload: PromptTemplatePayload): Promise { + try { + const data = await requestJson( + API_ENDPOINTS.promptTemplates.create, + { + method: "POST", + body: JSON.stringify(payload), + } + ); + return data; + } catch (error) { + log.error("Failed to create prompt template:", error); + throw error; + } + }, + + async update(templateId: number, payload: PromptTemplatePayload): Promise { + try { + const data = await requestJson( + API_ENDPOINTS.promptTemplates.update(templateId), + { + method: "PUT", + body: JSON.stringify(payload), + } + ); + return data; + } catch (error) { + log.error("Failed to update prompt template:", error); + throw error; + } + }, + + async remove(templateId: number): Promise { + try { + await requestJson(API_ENDPOINTS.promptTemplates.delete(templateId), { + method: "DELETE", + }); + return true; + } catch (error) { + log.error("Failed to delete prompt template:", error); + throw error; + } + }, +}; diff --git a/frontend/stores/agentConfigStore.ts b/frontend/stores/agentConfigStore.ts index e0840acf3..83fbef586 100644 --- a/frontend/stores/agentConfigStore.ts +++ b/frontend/stores/agentConfigStore.ts @@ -35,6 +35,8 @@ export type EditableAgent = Pick< | "business_description" | "business_logic_model_name" | "business_logic_model_id" + | "prompt_template_id" + | "prompt_template_name" | "sub_agent_id_list" | "group_ids" | "ingroup_permission" @@ -159,6 +161,8 @@ const emptyEditableAgent: EditableAgent = { business_description: "", business_logic_model_name: "", business_logic_model_id: 0, + prompt_template_id: 0, + prompt_template_name: "system_default", sub_agent_id_list: [], group_ids: [], ingroup_permission: "READ_ONLY", @@ -183,6 +187,8 @@ const toEditable = (agent: Agent | null): EditableAgent => business_description: agent.business_description || "", business_logic_model_name: agent.business_logic_model_name || "", business_logic_model_id: agent.business_logic_model_id || 0, + prompt_template_id: agent.prompt_template_id ?? 0, + prompt_template_name: agent.prompt_template_name || "system_default", sub_agent_id_list: agent.sub_agent_id_list || [], group_ids: agent.group_ids || [], ingroup_permission: agent.ingroup_permission || "READ_ONLY", @@ -200,13 +206,17 @@ const isBusinessInfoDirty = (baselineAgent: EditableAgent | null, editedAgent: E return ( editedAgent.business_description !== "" || editedAgent.business_logic_model_name !== "" || - editedAgent.business_logic_model_id !== 0 + editedAgent.business_logic_model_id !== 0 || + (editedAgent.prompt_template_id ?? 0) !== 0 || + (editedAgent.prompt_template_name || "system_default") !== "system_default" ); } return ( baselineAgent.business_description !== editedAgent.business_description || baselineAgent.business_logic_model_name !== editedAgent.business_logic_model_name || - baselineAgent.business_logic_model_id !== editedAgent.business_logic_model_id + baselineAgent.business_logic_model_id !== editedAgent.business_logic_model_id || + (baselineAgent.prompt_template_id ?? 0) !== (editedAgent.prompt_template_id ?? 0) || + (baselineAgent.prompt_template_name || "system_default") !== (editedAgent.prompt_template_name || "system_default") ); }; diff --git a/frontend/types/agentConfig.ts b/frontend/types/agentConfig.ts index b506730f8..d0c6ee43c 100644 --- a/frontend/types/agentConfig.ts +++ b/frontend/types/agentConfig.ts @@ -4,10 +4,15 @@ import type { Dispatch, SetStateAction } from "react"; import { ChatMessageType } from "./chat"; import { ModelOption } from "@/types/modelConfig"; import { GENERATE_PROMPT_STREAM_TYPES } from "../const/agentConfig"; +import type { PromptTemplateFieldKey } from "../const/promptTemplate"; export type AgentBusinessInfo = Partial>; export type AgentProfileInfo = Partial< @@ -26,6 +31,8 @@ export type AgentProfileInfo = Partial< | "few_shots_prompt" | "group_ids" | "ingroup_permission" + | "prompt_template_id" + | "prompt_template_name" > >; @@ -50,6 +57,8 @@ export interface Agent { business_description?: string; business_logic_model_name?: string; business_logic_model_id?: number; + prompt_template_id?: number; + prompt_template_name?: string; is_available?: boolean; is_new?: boolean; sub_agent_id_list?: number[]; @@ -408,6 +417,7 @@ export interface GeneratePromptParams { agent_id: number; task_description: string; model_id: string; + prompt_template_id?: number; tool_ids?: number[]; // Optional: tool IDs selected in frontend (takes precedence over database query) sub_agent_ids?: number[]; // Optional: sub-agent IDs selected in frontend (takes precedence over database query) /** @@ -447,3 +457,25 @@ export interface StreamResponseData { content: string; is_complete: boolean; } + +export type PromptTemplateContent = Record; + +export interface PromptTemplate { + template_id: number; + template_name: string; + description?: string | null; + template_type: string; + template_content_zh: PromptTemplateContent; + template_content_en?: PromptTemplateContent | null; + is_system_default?: boolean; + create_time?: string; + update_time?: string; +} + +export interface PromptTemplatePayload { + template_name: string; + description?: string; + template_type?: string; + template_content_zh: PromptTemplateContent; + template_content_en?: PromptTemplateContent | null; +} diff --git a/test/backend/app/test_prompt_template_app.py b/test/backend/app/test_prompt_template_app.py new file mode 100644 index 000000000..8cd78cf1d --- /dev/null +++ b/test/backend/app/test_prompt_template_app.py @@ -0,0 +1,397 @@ +import importlib +import os +import sys +import types +from http import HTTPStatus + +import pytest + + +BACKEND_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../backend") +) + + +@pytest.fixture(autouse=True) +def _reset_prompt_template_app_modules(): + yield + sys.modules.pop("apps.prompt_template_app", None) + sys.modules.pop("services.prompt_template_service", None) + sys.modules.pop("utils.auth_utils", None) + + +@pytest.fixture +def prompt_template_app_module(monkeypatch): + if BACKEND_PATH not in sys.path: + sys.path.insert(0, BACKEND_PATH) + + service_module = types.ModuleType("services.prompt_template_service") + for name in [ + "create_prompt_template_impl", + "delete_prompt_template_impl", + "get_prompt_template_detail_impl", + "list_prompt_templates_impl", + "update_prompt_template_impl", + ]: + setattr(service_module, name, lambda *args, **kwargs: None) + monkeypatch.setitem(sys.modules, "services.prompt_template_service", service_module) + + auth_module = types.ModuleType("utils.auth_utils") + auth_module.get_current_user_id = lambda authorization: ("user-1", "tenant-1") + monkeypatch.setitem(sys.modules, "utils.auth_utils", auth_module) + + sys.modules.pop("apps.prompt_template_app", None) + module = importlib.import_module("apps.prompt_template_app") + return importlib.reload(module) + + +@pytest.fixture +def prompt_template_exceptions(): + if BACKEND_PATH not in sys.path: + sys.path.insert(0, BACKEND_PATH) + return importlib.import_module("consts.exceptions") + + +@pytest.fixture +def prompt_template_client(prompt_template_app_module): + from fastapi import FastAPI + from fastapi.testclient import TestClient + + app = FastAPI() + app.include_router(prompt_template_app_module.router) + return TestClient(app) + + +@pytest.fixture +def prompt_template_payload(): + return { + "template_name": "template-a", + "description": "template description", + "template_type": "agent_generate", + "template_content_zh": { + "duty_system_prompt": "zh-duty", + "constraint_system_prompt": "zh-constraint", + "few_shots_system_prompt": "zh-few-shots", + "agent_variable_name_system_prompt": "zh-agent-name", + "agent_display_name_system_prompt": "zh-display-name", + "agent_description_system_prompt": "zh-description", + "user_prompt": "zh-user", + "agent_name_regenerate_system_prompt": "zh-regen-name-system", + "agent_name_regenerate_user_prompt": "zh-regen-name-user", + "agent_display_name_regenerate_system_prompt": "zh-regen-display-system", + "agent_display_name_regenerate_user_prompt": "zh-regen-display-user", + }, + "template_content_en": { + "duty_system_prompt": "en-duty", + "constraint_system_prompt": "en-constraint", + "few_shots_system_prompt": "en-few-shots", + "agent_variable_name_system_prompt": "en-agent-name", + "agent_display_name_system_prompt": "en-display-name", + "agent_description_system_prompt": "en-description", + "user_prompt": "en-user", + "agent_name_regenerate_system_prompt": "en-regen-name-system", + "agent_name_regenerate_user_prompt": "en-regen-name-user", + "agent_display_name_regenerate_system_prompt": "en-regen-display-system", + "agent_display_name_regenerate_user_prompt": "en-regen-display-user", + }, + } + + +def test_list_prompt_templates_api_success( + mocker, prompt_template_app_module, prompt_template_client +): + auth_mock = mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + list_mock = mocker.patch.object( + prompt_template_app_module, + "list_prompt_templates_impl", + return_value=[{"template_id": 0, "template_name": "system_default"}], + ) + + response = prompt_template_client.get( + "/prompt_templates", + headers={"Authorization": "Bearer token"}, + ) + + assert response.status_code == HTTPStatus.OK + assert response.json() == [{"template_id": 0, "template_name": "system_default"}] + auth_mock.assert_called_once_with("Bearer token") + list_mock.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + + +def test_list_prompt_templates_api_returns_internal_error_on_unexpected_exception( + mocker, prompt_template_app_module, prompt_template_client +): + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + mocker.patch.object( + prompt_template_app_module, + "list_prompt_templates_impl", + side_effect=Exception("db error"), + ) + + response = prompt_template_client.get("/prompt_templates") + + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + assert response.json()["detail"] == "Prompt template list error." + + +def test_get_prompt_template_api_success( + mocker, prompt_template_app_module, prompt_template_client +): + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + detail_mock = mocker.patch.object( + prompt_template_app_module, + "get_prompt_template_detail_impl", + return_value={"template_id": 1, "template_name": "template-a"}, + ) + + response = prompt_template_client.get("/prompt_templates/1") + + assert response.status_code == HTTPStatus.OK + assert response.json() == {"template_id": 1, "template_name": "template-a"} + detail_mock.assert_called_once_with(template_id=1, tenant_id="tenant-1", user_id="user-1") + + +@pytest.mark.parametrize( + ("side_effect", "expected_status", "expected_detail"), + [ + pytest.param("not_found", HTTPStatus.NOT_FOUND, "Prompt template not found", id="not-found"), + (Exception("unexpected"), HTTPStatus.INTERNAL_SERVER_ERROR, "Prompt template detail error."), + ], +) +def test_get_prompt_template_api_error_mapping( + mocker, + prompt_template_app_module, + prompt_template_client, + prompt_template_exceptions, + side_effect, + expected_status, + expected_detail, +): + if side_effect == "not_found": + side_effect = prompt_template_exceptions.NotFoundException( + "Prompt template not found" + ) + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + mocker.patch.object( + prompt_template_app_module, + "get_prompt_template_detail_impl", + side_effect=side_effect, + ) + + response = prompt_template_client.get("/prompt_templates/3") + + assert response.status_code == expected_status + assert response.json()["detail"] == expected_detail + + +def test_create_prompt_template_api_success( + mocker, prompt_template_app_module, prompt_template_client, prompt_template_payload +): + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + create_mock = mocker.patch.object( + prompt_template_app_module, + "create_prompt_template_impl", + return_value={"template_id": 9, "template_name": "template-a"}, + ) + + response = prompt_template_client.post("/prompt_templates", json=prompt_template_payload) + + assert response.status_code == HTTPStatus.OK + assert response.json() == {"template_id": 9, "template_name": "template-a"} + assert create_mock.call_args.kwargs["tenant_id"] == "tenant-1" + assert create_mock.call_args.kwargs["user_id"] == "user-1" + + +@pytest.mark.parametrize( + ("side_effect", "expected_status", "expected_detail"), + [ + pytest.param("duplicate", HTTPStatus.BAD_REQUEST, "Prompt template name already exists", id="duplicate"), + pytest.param("validation", HTTPStatus.BAD_REQUEST, "template_content_zh is required", id="validation"), + (Exception("unexpected"), HTTPStatus.INTERNAL_SERVER_ERROR, "Prompt template create error."), + ], +) +def test_create_prompt_template_api_error_mapping( + mocker, + prompt_template_app_module, + prompt_template_client, + prompt_template_exceptions, + prompt_template_payload, + side_effect, + expected_status, + expected_detail, +): + if side_effect == "duplicate": + side_effect = prompt_template_exceptions.DuplicateError( + "Prompt template name already exists" + ) + elif side_effect == "validation": + side_effect = prompt_template_exceptions.ValidationError( + "template_content_zh is required" + ) + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + mocker.patch.object( + prompt_template_app_module, + "create_prompt_template_impl", + side_effect=side_effect, + ) + + response = prompt_template_client.post("/prompt_templates", json=prompt_template_payload) + + assert response.status_code == expected_status + assert response.json()["detail"] == expected_detail + + +def test_update_prompt_template_api_success( + mocker, prompt_template_app_module, prompt_template_client, prompt_template_payload +): + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + update_mock = mocker.patch.object( + prompt_template_app_module, + "update_prompt_template_impl", + return_value={"template_id": 4, "template_name": "template-a"}, + ) + + response = prompt_template_client.put("/prompt_templates/4", json=prompt_template_payload) + + assert response.status_code == HTTPStatus.OK + assert response.json() == {"template_id": 4, "template_name": "template-a"} + assert update_mock.call_args.kwargs["template_id"] == 4 + + +@pytest.mark.parametrize( + ("side_effect", "expected_status", "expected_detail"), + [ + pytest.param("not_found", HTTPStatus.NOT_FOUND, "Prompt template not found", id="not-found"), + pytest.param("duplicate", HTTPStatus.BAD_REQUEST, "Prompt template name already exists", id="duplicate"), + pytest.param("validation", HTTPStatus.BAD_REQUEST, "System default prompt template cannot be updated", id="validation"), + (Exception("unexpected"), HTTPStatus.INTERNAL_SERVER_ERROR, "Prompt template update error."), + ], +) +def test_update_prompt_template_api_error_mapping( + mocker, + prompt_template_app_module, + prompt_template_client, + prompt_template_exceptions, + prompt_template_payload, + side_effect, + expected_status, + expected_detail, +): + if side_effect == "not_found": + side_effect = prompt_template_exceptions.NotFoundException( + "Prompt template not found" + ) + elif side_effect == "duplicate": + side_effect = prompt_template_exceptions.DuplicateError( + "Prompt template name already exists" + ) + elif side_effect == "validation": + side_effect = prompt_template_exceptions.ValidationError( + "System default prompt template cannot be updated" + ) + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + mocker.patch.object( + prompt_template_app_module, + "update_prompt_template_impl", + side_effect=side_effect, + ) + + response = prompt_template_client.put("/prompt_templates/7", json=prompt_template_payload) + + assert response.status_code == expected_status + assert response.json()["detail"] == expected_detail + + +def test_delete_prompt_template_api_success( + mocker, prompt_template_app_module, prompt_template_client +): + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + delete_mock = mocker.patch.object( + prompt_template_app_module, + "delete_prompt_template_impl", + return_value={"template_id": 8, "deleted": True}, + ) + + response = prompt_template_client.delete("/prompt_templates/8") + + assert response.status_code == HTTPStatus.OK + assert response.json() == {"template_id": 8, "deleted": True} + delete_mock.assert_called_once_with(template_id=8, tenant_id="tenant-1", user_id="user-1") + + +@pytest.mark.parametrize( + ("side_effect", "expected_status", "expected_detail"), + [ + pytest.param("not_found", HTTPStatus.NOT_FOUND, "Prompt template not found", id="not-found"), + pytest.param("validation", HTTPStatus.BAD_REQUEST, "System default prompt template cannot be deleted", id="validation"), + (Exception("unexpected"), HTTPStatus.INTERNAL_SERVER_ERROR, "Prompt template delete error."), + ], +) +def test_delete_prompt_template_api_error_mapping( + mocker, + prompt_template_app_module, + prompt_template_client, + prompt_template_exceptions, + side_effect, + expected_status, + expected_detail, +): + if side_effect == "not_found": + side_effect = prompt_template_exceptions.NotFoundException( + "Prompt template not found" + ) + elif side_effect == "validation": + side_effect = prompt_template_exceptions.ValidationError( + "System default prompt template cannot be deleted" + ) + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + mocker.patch.object( + prompt_template_app_module, + "delete_prompt_template_impl", + side_effect=side_effect, + ) + + response = prompt_template_client.delete("/prompt_templates/11") + + assert response.status_code == expected_status + assert response.json()["detail"] == expected_detail diff --git a/test/backend/database/test_agent_db.py b/test/backend/database/test_agent_db.py index 6f2c780e5..de2ed8864 100644 --- a/test/backend/database/test_agent_db.py +++ b/test/backend/database/test_agent_db.py @@ -119,6 +119,8 @@ def __init__(self): self.parent_agent_id = None self.provide_run_summary = None self.business_description = None + self.prompt_template_id = None + self.prompt_template_name = None self.group_ids = None self.is_new = True self.enable_context_manager = False diff --git a/test/backend/services/test_agent_service.py b/test/backend/services/test_agent_service.py index 27298f25f..393695c09 100644 --- a/test/backend/services/test_agent_service.py +++ b/test/backend/services/test_agent_service.py @@ -1,6 +1,7 @@ import sys import asyncio import json +import types from contextlib import contextmanager from unittest.mock import patch, MagicMock, mock_open, call, Mock, AsyncMock import os @@ -62,10 +63,23 @@ def model_dump(self, **kwargs): sys.modules['database.a2a_agent_db'] = MagicMock() # Mock services submodules -sys.modules['services'] = MagicMock() -sys.modules['services.conversation_management_service'] = MagicMock() -sys.modules['services.memory_config_service'] = MagicMock() -sys.modules['services.agent_version_service'] = MagicMock() +services_module = types.ModuleType("services") +services_module.__path__ = [] +sys.modules['services'] = services_module + +conversation_management_service_mock = MagicMock() +memory_config_service_mock = MagicMock() +agent_version_service_mock = MagicMock() +prompt_template_service_mock = MagicMock() +prompt_template_service_mock.SYSTEM_PROMPT_TEMPLATE_ID = 0 +prompt_template_service_mock.SYSTEM_PROMPT_TEMPLATE_NAME = "system_default" +prompt_template_service_mock.get_prompt_template_summary = MagicMock(return_value=(None, None)) +prompt_template_service_mock.resolve_prompt_generate_template = MagicMock(return_value={}) + +sys.modules['services.conversation_management_service'] = conversation_management_service_mock +sys.modules['services.memory_config_service'] = memory_config_service_mock +sys.modules['services.agent_version_service'] = agent_version_service_mock +sys.modules['services.prompt_template_service'] = prompt_template_service_mock # Mock agents submodules sys.modules['agents'] = MagicMock() @@ -282,6 +296,18 @@ def reset_mocks(): yield +def apply_default_prompt_template_request_fields(request, prompt_template_id=None): + """Populate default request fields needed by prompt template aware service logic.""" + request.prompt_template_id = prompt_template_id + request.prompt_template_name = None + request.enabled_skill_ids = None + if not hasattr(request, "related_agent_ids"): + request.related_agent_ids = None + if not hasattr(request, "enabled_tool_ids"): + request.enabled_tool_ids = None + return request + + @pytest.mark.asyncio async def test_get_enable_tool_id_by_agent_id(): """ @@ -421,6 +447,8 @@ async def test_get_agent_info_impl_success(mock_search_agent_info, mock_search_t "sub_agent_id_list": mock_sub_agent_ids, "model_name": None, "business_logic_model_name": None, + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -479,6 +507,8 @@ async def test_get_agent_info_impl_with_version_no(mock_search_agent_info, mock_ "sub_agent_id_list": mock_sub_agent_ids, "model_name": None, "business_logic_model_name": None, + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -584,6 +614,7 @@ async def test_update_agent_info_impl_success(mock_get_current_user_info, mock_u request.business_description = "Updated agent" request.display_name = "Updated Display Name" request.enabled_tool_ids = None # Explicitly set to None to avoid tool handling path + apply_default_prompt_template_request_fields(request) # Execute await update_agent_info_impl(request, authorization="Bearer token") @@ -662,6 +693,7 @@ async def test_update_agent_info_impl_exception_handling(mock_get_current_user_i request.display_name = "Test Display Name" request.enabled_tool_ids = None request.related_agent_ids = None + apply_default_prompt_template_request_fields(request) # Execute & Assert with pytest.raises(ValueError) as context: @@ -701,6 +733,7 @@ async def test_update_agent_info_impl_with_enabled_tool_ids( request.agent_id = 123 request.enabled_tool_ids = [1, 2] # Enable tools 1 and 2 request.related_agent_ids = None + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -758,6 +791,7 @@ async def test_update_agent_info_impl_with_enabled_tool_ids_instance_having_null request.agent_id = 123 request.enabled_tool_ids = [1] # Enable only tool 1 request.related_agent_ids = None + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -805,6 +839,7 @@ async def test_update_agent_info_impl_with_enabled_tool_ids_disabled_existing_to request.enabled_tool_ids = [2] # Only enable tool 2 (new tool) # Tool 1 exists but is NOT in enabled_tool_ids, so it should be disabled request.related_agent_ids = None + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -858,6 +893,7 @@ async def test_update_agent_info_impl_with_related_agent_ids( request.agent_id = 123 request.enabled_tool_ids = None request.related_agent_ids = [456, 789] + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -896,6 +932,7 @@ async def test_update_agent_info_impl_circular_dependency_detection( request.agent_id = 123 request.enabled_tool_ids = None request.related_agent_ids = [123] # Agent tries to relate to itself + apply_default_prompt_template_request_fields(request) # Execute & Assert - self-reference should raise ValueError with pytest.raises(ValueError, match="Circular dependency detected"): @@ -941,6 +978,7 @@ async def test_update_agent_info_impl_with_both_tool_and_related_agents( request.agent_id = 123 request.enabled_tool_ids = [1] request.related_agent_ids = [456] + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -983,6 +1021,7 @@ async def test_update_agent_info_impl_tool_update_exception( request.agent_id = 123 request.enabled_tool_ids = [1] request.related_agent_ids = None + apply_default_prompt_template_request_fields(request) # Execute & Assert with pytest.raises(ValueError, match="Failed to update agent tools"): @@ -1015,6 +1054,7 @@ async def test_update_agent_info_impl_related_agent_update_exception( request.agent_id = 123 request.enabled_tool_ids = None request.related_agent_ids = [456] + apply_default_prompt_template_request_fields(request) # Execute & Assert with pytest.raises(ValueError, match="Failed to update related agents"): @@ -1216,6 +1256,7 @@ async def test_update_agent_info_impl_create_agent_auto_group_ids(mock_get_curre request.enabled_tool_ids = None request.related_agent_ids = None request.group_ids = None + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -1563,6 +1604,8 @@ async def test_get_agent_info_impl_with_model_id_success(mock_search_agent_info, "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", "business_logic_model_name": None, + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -1651,6 +1694,8 @@ async def test_get_agent_info_impl_with_model_id_no_display_name(mock_search_age "sub_agent_id_list": mock_sub_agent_ids, "model_name": None, "business_logic_model_name": None, + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -1702,6 +1747,8 @@ async def test_get_agent_info_impl_with_model_id_none_model_info(mock_search_age "sub_agent_id_list": mock_sub_agent_ids, "model_name": None, "business_logic_model_name": None, + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -1777,6 +1824,8 @@ def mock_get_model(model_id): "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", "business_logic_model_name": "Claude-3.5", + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -1848,6 +1897,8 @@ def mock_get_model(model_id): "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", "business_logic_model_name": None, # Should be None when model info is not found + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -1926,6 +1977,8 @@ def mock_get_model(model_id): "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", "business_logic_model_name": None, # Should be None when display_name is not in model_info + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -8015,6 +8068,7 @@ async def test_update_agent_info_impl_create_agent_with_ingroup_permission( request.related_agent_ids = None request.group_ids = [1, 2] request.ingroup_permission = PERMISSION_READ + apply_default_prompt_template_request_fields(request) result = await update_agent_info_impl(request, authorization="Bearer token") @@ -8065,6 +8119,7 @@ async def test_update_agent_info_impl_create_agent_with_ingroup_permission_none( request.related_agent_ids = None request.group_ids = None request.ingroup_permission = None + apply_default_prompt_template_request_fields(request) result = await update_agent_info_impl(request, authorization="Bearer token") @@ -8766,6 +8821,8 @@ async def test_update_agent_info_impl_skill_update_exception( mock_request.related_agent_ids = None mock_request.group_ids = None mock_request.ingroup_permission = None + mock_request.prompt_template_id = None + mock_request.prompt_template_name = None mock_query_skills.return_value = [] mock_create_skill.side_effect = Exception("Skill update failed") diff --git a/test/backend/services/test_agent_version_service.py b/test/backend/services/test_agent_version_service.py index 30ce8792b..d0146382f 100644 --- a/test/backend/services/test_agent_version_service.py +++ b/test/backend/services/test_agent_version_service.py @@ -607,8 +607,16 @@ def test_rollback_version_impl_success(monkeypatch): } mock_search = MagicMock(return_value=mock_version) monkeypatch.setattr(agent_version_service_module, "search_version_by_version_no", mock_search) - mock_query_snapshot = MagicMock(return_value=(mock_agent_snapshot, [], [])) + mock_query_snapshot = MagicMock( + return_value=( + {"agent_id": 1, "version_no": 1, "name": "Test Agent"}, + [{"tool_id": 1, "version_no": 1}], + [{"relation_id": 1, "version_no": 1}], + ) + ) monkeypatch.setattr(agent_version_service_module, "query_agent_snapshot", mock_query_snapshot) + mock_query_draft = MagicMock(return_value=({"agent_id": 1, "version_no": 0}, [], [])) + monkeypatch.setattr(agent_version_service_module, "query_agent_draft", mock_query_draft) mock_restore_draft = MagicMock() monkeypatch.setattr(agent_version_service_module, "restore_agent_draft", mock_restore_draft) monkeypatch.setattr(skill_db_mock, "query_skill_instances_by_agent_id", MagicMock(return_value=[])) @@ -640,14 +648,22 @@ def test_rollback_version_impl_version_not_found(monkeypatch): def test_rollback_version_impl_draft_not_found(monkeypatch): - """Test rolling back when snapshot is not found""" + """Test rolling back when draft doesn't exist""" mock_version = {"version_no": 1} mock_search = MagicMock(return_value=mock_version) monkeypatch.setattr(agent_version_service_module, "search_version_by_version_no", mock_search) - mock_query_snapshot = MagicMock(return_value=(None, [], [])) + mock_query_snapshot = MagicMock( + return_value=( + {"agent_id": 1, "version_no": 1, "name": "Test Agent"}, + [], + [], + ) + ) monkeypatch.setattr(agent_version_service_module, "query_agent_snapshot", mock_query_snapshot) + mock_query_draft = MagicMock(return_value=(None, [], [])) + monkeypatch.setattr(agent_version_service_module, "query_agent_draft", mock_query_draft) - with pytest.raises(ValueError, match="Agent snapshot for version 1 not found"): + with pytest.raises(ValueError, match="Agent draft not found"): rollback_version_impl( agent_id=1, tenant_id="tenant1", diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 1b71baa5c..522a850f0 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -268,9 +268,11 @@ def mock_generator(*args, **kwargs): "Test task", enabled_tools, # tool_info_list from helper "tenant456", + "user123", self.test_model_id, "zh", - None # knowledge_base_display_names + None, + None, ) @patch('backend.services.prompt_service._regenerate_agent_display_name_with_llm') @@ -663,6 +665,7 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): user_id="user123", tenant_id="tenant456", language="zh", + prompt_template_id=None, tool_ids=None, sub_agent_ids=None, knowledge_base_display_names=None, @@ -676,19 +679,19 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') - @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') - def test_generate_system_prompt(self, mock_get_prompt_template, mock_join_info, mock_call_llm): + @patch('backend.services.prompt_service.resolve_prompt_generate_template') + def test_generate_system_prompt(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm): # Setup mock_prompt_config = { - "USER_PROMPT": "Test user prompt template", - "DUTY_SYSTEM_PROMPT": "Generate duty prompt", - "CONSTRAINT_SYSTEM_PROMPT": "Generate constraint prompt", - "FEW_SHOTS_SYSTEM_PROMPT": "Generate few shots prompt", - "AGENT_VARIABLE_NAME_SYSTEM_PROMPT": "Generate agent var name", - "AGENT_DISPLAY_NAME_SYSTEM_PROMPT": "Generate agent display name", - "AGENT_DESCRIPTION_SYSTEM_PROMPT": "Generate agent description" + "user_prompt": "Test user prompt template", + "duty_system_prompt": "Generate duty prompt", + "constraint_system_prompt": "Generate constraint prompt", + "few_shots_system_prompt": "Generate few shots prompt", + "agent_variable_name_system_prompt": "Generate agent var name", + "agent_display_name_system_prompt": "Generate agent display name", + "agent_description_system_prompt": "Generate agent description" } - mock_get_prompt_template.return_value = mock_prompt_config + mock_resolve_prompt_template.return_value = mock_prompt_config mock_join_info.return_value = "Joined template content" @@ -740,6 +743,7 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): mock_task_description, mock_tools, mock_tenant_id, + "test_user", self.test_model_id, mock_language ): @@ -747,7 +751,12 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): # Assert # Verify template loading - mock_get_prompt_template.assert_called_once_with(mock_language) + mock_resolve_prompt_template.assert_called_once_with( + tenant_id=mock_tenant_id, + user_id="test_user", + language=mock_language, + prompt_template_id=None, + ) # Verify template joining - now includes knowledge_base_display_names parameter mock_join_info.assert_called_once_with( @@ -793,19 +802,19 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') - @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') - def test_generate_system_prompt_with_exception(self, mock_get_prompt_template, mock_join_info, mock_call_llm): + @patch('backend.services.prompt_service.resolve_prompt_generate_template') + def test_generate_system_prompt_with_exception(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm): # Setup mock_prompt_config = { - "USER_PROMPT": "Test user prompt template", - "DUTY_SYSTEM_PROMPT": "Generate duty prompt", - "CONSTRAINT_SYSTEM_PROMPT": "Generate constraint prompt", - "FEW_SHOTS_SYSTEM_PROMPT": "Generate few shots prompt", - "AGENT_VARIABLE_NAME_SYSTEM_PROMPT": "Generate agent var name", - "AGENT_DISPLAY_NAME_SYSTEM_PROMPT": "Generate agent display name", - "AGENT_DESCRIPTION_SYSTEM_PROMPT": "Generate agent description" + "user_prompt": "Test user prompt template", + "duty_system_prompt": "Generate duty prompt", + "constraint_system_prompt": "Generate constraint prompt", + "few_shots_system_prompt": "Generate few shots prompt", + "agent_variable_name_system_prompt": "Generate agent var name", + "agent_display_name_system_prompt": "Generate agent display name", + "agent_description_system_prompt": "Generate agent description" } - mock_get_prompt_template.return_value = mock_prompt_config + mock_resolve_prompt_template.return_value = mock_prompt_config mock_join_info.return_value = "Joined template content" # Mock call_llm_for_system_prompt to raise exception for one prompt type @@ -837,6 +846,7 @@ def mock_llm_call_with_exception(model_id, content, sys_prompt, callback, tenant mock_task_description, mock_tools, mock_tenant_id, + "test_user", self.test_model_id, mock_language ): @@ -848,7 +858,7 @@ def mock_llm_call_with_exception(model_id, content, sys_prompt, callback, tenant @patch('backend.services.prompt_service.Template') def test_join_info_for_generate_system_prompt(self, mock_template): # Setup - mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_prompt_for_generate = {"user_prompt": "Test User Prompt"} mock_sub_agents = [ {"name": "agent1", "description": "Agent 1 desc"}, {"name": "agent2", "description": "Agent 2 desc"} @@ -873,7 +883,7 @@ def test_join_info_for_generate_system_prompt(self, mock_template): # Assert self.assertEqual(result, "Rendered content") mock_template.assert_called_once_with( - mock_prompt_for_generate["USER_PROMPT"], undefined=StrictUndefined) + mock_prompt_for_generate["user_prompt"], undefined=StrictUndefined) mock_template_instance.render.assert_called_once() # Check template variables template_vars = mock_template_instance.render.call_args[0][0] @@ -1090,25 +1100,25 @@ def mock_gen(*args, **kwargs): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') - @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') + @patch('backend.services.prompt_service.resolve_prompt_generate_template') def test_generate_system_prompt_error_before_streaming( self, - mock_get_prompt_template, + mock_resolve_prompt_template, mock_join_info, mock_call_llm, ): """Test generate_system_prompt handles error that occurs before streaming (line 307-311)""" # Setup mock_prompt_config = { - "USER_PROMPT": "Test user prompt template", - "DUTY_SYSTEM_PROMPT": "Generate duty prompt", - "CONSTRAINT_SYSTEM_PROMPT": "Generate constraint prompt", - "FEW_SHOTS_SYSTEM_PROMPT": "Generate few shots prompt", - "AGENT_VARIABLE_NAME_SYSTEM_PROMPT": "Generate agent var name", - "AGENT_DISPLAY_NAME_SYSTEM_PROMPT": "Generate agent display name", - "AGENT_DESCRIPTION_SYSTEM_PROMPT": "Generate agent description" + "user_prompt": "Test user prompt template", + "duty_system_prompt": "Generate duty prompt", + "constraint_system_prompt": "Generate constraint prompt", + "few_shots_system_prompt": "Generate few shots prompt", + "agent_variable_name_system_prompt": "Generate agent var name", + "agent_display_name_system_prompt": "Generate agent display name", + "agent_description_system_prompt": "Generate agent description" } - mock_get_prompt_template.return_value = mock_prompt_config + mock_resolve_prompt_template.return_value = mock_prompt_config mock_join_info.return_value = "Joined template content" # Mock call_llm_for_system_prompt to raise exception immediately @@ -1130,6 +1140,7 @@ def mock_llm_call_error(model_id, content, sys_prompt, callback, tenant_id): "Test task", [{"name": "tool1"}], "tenant123", + "test_user", self.test_model_id, "zh" ): @@ -1139,25 +1150,25 @@ def mock_llm_call_error(model_id, content, sys_prompt, callback, tenant_id): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') - @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') + @patch('backend.services.prompt_service.resolve_prompt_generate_template') def test_generate_system_prompt_error_during_streaming( self, - mock_get_prompt_template, + mock_resolve_prompt_template, mock_join_info, mock_call_llm, ): """Test generate_system_prompt handles error that occurs during streaming (line 330-331)""" # Setup mock_prompt_config = { - "USER_PROMPT": "Test user prompt template", - "DUTY_SYSTEM_PROMPT": "Generate duty prompt", - "CONSTRAINT_SYSTEM_PROMPT": "Generate constraint prompt", - "FEW_SHOTS_SYSTEM_PROMPT": "Generate few shots prompt", - "AGENT_VARIABLE_NAME_SYSTEM_PROMPT": "Generate agent var name", - "AGENT_DISPLAY_NAME_SYSTEM_PROMPT": "Generate agent display name", - "AGENT_DESCRIPTION_SYSTEM_PROMPT": "Generate agent description" + "user_prompt": "Test user prompt template", + "duty_system_prompt": "Generate duty prompt", + "constraint_system_prompt": "Generate constraint prompt", + "few_shots_system_prompt": "Generate few shots prompt", + "agent_variable_name_system_prompt": "Generate agent var name", + "agent_display_name_system_prompt": "Generate agent display name", + "agent_description_system_prompt": "Generate agent description" } - mock_get_prompt_template.return_value = mock_prompt_config + mock_resolve_prompt_template.return_value = mock_prompt_config mock_join_info.return_value = "Joined template content" # Track which call we're on @@ -1188,6 +1199,7 @@ def mock_llm_call_error_after_first( "Test task", [{"name": "tool1"}], "tenant123", + "test_user", self.test_model_id, "zh" ): @@ -1242,7 +1254,7 @@ def test_get_enabled_sub_agent_description_for_generate_prompt_empty( def test_join_info_for_generate_system_prompt_english(self, mock_template): """Test join_info_for_generate_system_prompt with English language""" # Setup - mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_prompt_for_generate = {"user_prompt": "Test User Prompt"} mock_sub_agents = [ {"name": "agent1", "description": "Agent 1 desc"} ] @@ -1272,7 +1284,7 @@ def test_join_info_for_generate_system_prompt_english(self, mock_template): def test_join_info_for_generate_system_prompt_empty_tools_and_agents(self, mock_template): """Test join_info_for_generate_system_prompt with empty tools and sub-agents""" # Setup - mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_prompt_for_generate = {"user_prompt": "Test User Prompt"} mock_sub_agents = [] mock_task_description = "Test task" mock_tools = [] @@ -1293,7 +1305,7 @@ def test_join_info_for_generate_system_prompt_empty_tools_and_agents(self, mock_ def test_join_info_for_generate_system_prompt_with_knowledge_base_names(self, mock_template): """Test join_info_for_generate_system_prompt with knowledge_base_display_names""" # Setup - mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_prompt_for_generate = {"user_prompt": "Test User Prompt"} mock_sub_agents = [] mock_task_description = "Test task" mock_tools = [ @@ -1322,7 +1334,7 @@ def test_join_info_for_generate_system_prompt_with_knowledge_base_names(self, mo def test_join_info_for_generate_system_prompt_without_knowledge_base_names(self, mock_template): """Test join_info_for_generate_system_prompt without knowledge_base_display_names""" # Setup - mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_prompt_for_generate = {"user_prompt": "Test User Prompt"} mock_sub_agents = [] mock_task_description = "Test task" mock_tools = [ diff --git a/test/backend/services/test_prompt_template_service.py b/test/backend/services/test_prompt_template_service.py new file mode 100644 index 000000000..34415b203 --- /dev/null +++ b/test/backend/services/test_prompt_template_service.py @@ -0,0 +1,501 @@ +import importlib +import os +import sys +import types + +import pytest + + +BACKEND_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../backend") +) + + +@pytest.fixture(autouse=True) +def _reset_prompt_template_service_modules(): + yield + sys.modules.pop("services.prompt_template_service", None) + sys.modules.pop("database.prompt_template_db", None) + + +@pytest.fixture +def prompt_template_models(monkeypatch): + if BACKEND_PATH not in sys.path: + sys.path.insert(0, BACKEND_PATH) + + nexent_module = types.ModuleType("nexent") + nexent_core_module = types.ModuleType("nexent.core") + nexent_agents_module = types.ModuleType("nexent.core.agents") + agent_model_module = types.ModuleType("nexent.core.agents.agent_model") + agent_model_module.ToolConfig = type("ToolConfig", (), {}) + + monkeypatch.setitem(sys.modules, "nexent", nexent_module) + monkeypatch.setitem(sys.modules, "nexent.core", nexent_core_module) + monkeypatch.setitem(sys.modules, "nexent.core.agents", nexent_agents_module) + monkeypatch.setitem(sys.modules, "nexent.core.agents.agent_model", agent_model_module) + + consts_model = importlib.import_module("consts.model") + consts_exceptions = importlib.import_module("consts.exceptions") + return consts_model, consts_exceptions + + +@pytest.fixture +def prompt_template_service_module(monkeypatch): + if BACKEND_PATH not in sys.path: + sys.path.insert(0, BACKEND_PATH) + + db_module = types.ModuleType("database.prompt_template_db") + for name in [ + "create_prompt_template", + "delete_prompt_template", + "get_prompt_template_by_id", + "get_prompt_template_by_name", + "get_prompt_template_by_template_id", + "query_prompt_templates_by_user", + "upsert_prompt_template_by_id", + "update_prompt_template", + ]: + setattr(db_module, name, lambda *args, **kwargs: None) + monkeypatch.setitem(sys.modules, "database.prompt_template_db", db_module) + + sys.modules.pop("services.prompt_template_service", None) + module = importlib.import_module("services.prompt_template_service") + return importlib.reload(module) + + +@pytest.fixture +def template_content_factory(): + def _build(seed: str = "value", **overrides): + content = { + "duty_system_prompt": f"{seed}-duty", + "constraint_system_prompt": f"{seed}-constraint", + "few_shots_system_prompt": f"{seed}-few-shots", + "agent_variable_name_system_prompt": f"{seed}-agent-name", + "agent_display_name_system_prompt": f"{seed}-display-name", + "agent_description_system_prompt": f"{seed}-description", + "user_prompt": f"{seed}-user", + "agent_name_regenerate_system_prompt": f"{seed}-regen-name-system", + "agent_name_regenerate_user_prompt": f"{seed}-regen-name-user", + "agent_display_name_regenerate_system_prompt": f"{seed}-regen-display-system", + "agent_display_name_regenerate_user_prompt": f"{seed}-regen-display-user", + } + content.update(overrides) + return content + + return _build + + +@pytest.fixture +def prompt_template_request_factory(template_content_factory, prompt_template_models): + consts_model, _ = prompt_template_models + + def _build( + template_name: str = "template-a", + description: str | None = "template description", + template_type: str = "agent_generate", + template_content_zh: dict | None = None, + template_content_en: dict | None = None, + ): + return consts_model.PromptTemplateRequest( + template_name=template_name, + description=description, + template_type=template_type, + template_content_zh=consts_model.PromptTemplateContentRequest( + **(template_content_zh or template_content_factory("zh")) + ), + template_content_en=( + consts_model.PromptTemplateContentRequest( + **(template_content_en or template_content_factory("en")) + ) + if template_content_en is not None + else None + ), + ) + + return _build + + +def test_build_system_default_prompt_template_payload( + mocker, prompt_template_service_module, template_content_factory +): + mocker.patch.object( + prompt_template_service_module, + "get_prompt_generate_prompt_template", + side_effect=[ + template_content_factory("zh"), + template_content_factory("en"), + ], + ) + + payload = prompt_template_service_module.build_system_default_prompt_template_payload() + + assert payload["template_id"] == 0 + assert payload["template_name"] == "system_default" + assert payload["tenant_id"] == prompt_template_service_module.SYSTEM_PROMPT_TEMPLATE_TENANT_ID + assert payload["user_id"] == prompt_template_service_module.SYSTEM_PROMPT_TEMPLATE_USER_ID + assert payload["template_content_zh"]["duty_system_prompt"] == "zh-duty" + assert payload["template_content_en"]["duty_system_prompt"] == "en-duty" + + +def test_sync_system_default_prompt_template_marks_system_default( + mocker, prompt_template_service_module +): + payload = {"template_id": 0, "template_name": "system_default"} + mocker.patch.object( + prompt_template_service_module, + "build_system_default_prompt_template_payload", + return_value=payload, + ) + upsert_mock = mocker.patch.object( + prompt_template_service_module, + "upsert_prompt_template_by_id", + return_value={"template_id": 0, "template_name": "system_default"}, + ) + + result = prompt_template_service_module.sync_system_default_prompt_template() + + upsert_mock.assert_called_once_with( + template_id=0, + template_data=payload, + user_id=prompt_template_service_module.SYSTEM_PROMPT_TEMPLATE_USER_ID, + ) + assert result["is_system_default"] is True + + +def test_get_system_default_prompt_template_syncs_when_missing( + mocker, prompt_template_service_module +): + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_template_id", + return_value=None, + ) + sync_mock = mocker.patch.object( + prompt_template_service_module, + "sync_system_default_prompt_template", + return_value={"template_id": 0, "template_name": "system_default"}, + ) + + result = prompt_template_service_module.get_system_default_prompt_template() + + sync_mock.assert_called_once_with() + assert result["template_id"] == 0 + assert result["is_system_default"] is True + + +def test_normalize_template_request_trims_and_drops_empty_optional_fields( + prompt_template_service_module, prompt_template_request_factory, template_content_factory +): + request = prompt_template_request_factory( + template_name=" template-a ", + description=" ", + template_content_zh=template_content_factory( + "zh", + constraint_system_prompt="", + few_shots_system_prompt=" ", + ), + template_content_en=template_content_factory( + "en", + duty_system_prompt="", + constraint_system_prompt="", + few_shots_system_prompt="", + agent_variable_name_system_prompt="", + agent_display_name_system_prompt="", + agent_description_system_prompt="", + user_prompt="", + agent_name_regenerate_system_prompt="", + agent_name_regenerate_user_prompt="", + agent_display_name_regenerate_system_prompt="", + agent_display_name_regenerate_user_prompt="", + ), + ) + + result = prompt_template_service_module._normalize_template_request(request) + + assert result["template_name"] == "template-a" + assert result["description"] is None + assert "constraint_system_prompt" not in result["template_content_zh"] + assert result["template_content_en"] is None + + +def test_normalize_template_request_requires_non_empty_zh_content( + prompt_template_service_module, + prompt_template_request_factory, + template_content_factory, + prompt_template_models, +): + _, consts_exceptions = prompt_template_models + request = prompt_template_request_factory( + template_content_zh=template_content_factory( + "zh", + duty_system_prompt="", + constraint_system_prompt="", + few_shots_system_prompt="", + agent_variable_name_system_prompt="", + agent_display_name_system_prompt="", + agent_description_system_prompt="", + user_prompt="", + agent_name_regenerate_system_prompt="", + agent_name_regenerate_user_prompt="", + agent_display_name_regenerate_system_prompt="", + agent_display_name_regenerate_user_prompt="", + ) + ) + + with pytest.raises( + consts_exceptions.ValidationError, match="template_content_zh is required" + ): + prompt_template_service_module._normalize_template_request(request) + + +def test_list_prompt_templates_impl_prepends_system_default_and_filters_duplicate_id( + mocker, prompt_template_service_module +): + mocker.patch.object( + prompt_template_service_module, + "sync_system_default_prompt_template", + return_value={"template_id": 0, "template_name": "system_default", "is_system_default": True}, + ) + mocker.patch.object( + prompt_template_service_module, + "query_prompt_templates_by_user", + return_value=[ + {"template_id": 0, "template_name": "system_default"}, + {"template_id": 2, "template_name": "custom-template"}, + ], + ) + + result = prompt_template_service_module.list_prompt_templates_impl("tenant-1", "user-1") + + assert [item["template_id"] for item in result] == [0, 2] + assert result[0]["is_system_default"] is True + assert result[1]["is_system_default"] is False + + +def test_create_prompt_template_impl_rejects_duplicate_name( + mocker, + prompt_template_service_module, + prompt_template_request_factory, + prompt_template_models, +): + _, consts_exceptions = prompt_template_models + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_name", + return_value={"template_id": 1, "template_name": "template-a"}, + ) + + with pytest.raises( + consts_exceptions.DuplicateError, match="Prompt template name already exists" + ): + prompt_template_service_module.create_prompt_template_impl( + prompt_template_request_factory(), + tenant_id="tenant-1", + user_id="user-1", + ) + + +def test_create_prompt_template_impl_persists_user_template( + mocker, prompt_template_service_module, prompt_template_request_factory +): + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_name", + return_value=None, + ) + create_mock = mocker.patch.object( + prompt_template_service_module, + "create_prompt_template", + return_value={"template_id": 9, "template_name": "template-a"}, + ) + + result = prompt_template_service_module.create_prompt_template_impl( + prompt_template_request_factory(), + tenant_id="tenant-1", + user_id="user-1", + ) + + create_payload = create_mock.call_args.args[0] + assert create_payload["tenant_id"] == "tenant-1" + assert create_payload["user_id"] == "user-1" + assert create_payload["created_by"] == "user-1" + assert result["is_system_default"] is False + + +def test_update_prompt_template_impl_rejects_system_default( + prompt_template_service_module, + prompt_template_request_factory, + prompt_template_models, +): + _, consts_exceptions = prompt_template_models + with pytest.raises( + consts_exceptions.ValidationError, + match="System default prompt template cannot be updated", + ): + prompt_template_service_module.update_prompt_template_impl( + template_id=0, + request=prompt_template_request_factory(), + tenant_id="tenant-1", + user_id="user-1", + ) + + +def test_update_prompt_template_impl_updates_existing_template( + mocker, prompt_template_service_module, prompt_template_request_factory +): + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_id", + return_value={"template_id": 3, "template_name": "template-a"}, + ) + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_name", + return_value={"template_id": 3, "template_name": "template-a"}, + ) + update_mock = mocker.patch.object( + prompt_template_service_module, + "update_prompt_template", + return_value={"template_id": 3, "template_name": "template-a"}, + ) + + result = prompt_template_service_module.update_prompt_template_impl( + template_id=3, + request=prompt_template_request_factory(), + tenant_id="tenant-1", + user_id="user-1", + ) + + assert update_mock.call_args.kwargs["template_id"] == 3 + assert update_mock.call_args.kwargs["user_id"] == "user-1" + assert result["is_system_default"] is False + + +@pytest.mark.parametrize("deleted_count, expected_deleted", [(1, True), (0, False)]) +def test_delete_prompt_template_impl_returns_deleted_status( + mocker, prompt_template_service_module, deleted_count, expected_deleted +): + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_id", + return_value={"template_id": 5, "template_name": "template-a"}, + ) + mocker.patch.object( + prompt_template_service_module, + "delete_prompt_template", + return_value=deleted_count, + ) + + result = prompt_template_service_module.delete_prompt_template_impl( + template_id=5, + tenant_id="tenant-1", + user_id="user-1", + ) + + assert result == {"template_id": 5, "deleted": expected_deleted} + + +def test_resolve_prompt_generate_template_falls_back_to_system_default_when_custom_missing( + mocker, prompt_template_service_module +): + mocker.patch.object( + prompt_template_service_module, + "sync_system_default_prompt_template", + return_value={ + "template_content_en": {"duty_system_prompt": "system-en-duty"}, + "template_content_zh": {"constraint_system_prompt": "system-zh-constraint"}, + }, + ) + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_id", + return_value=None, + ) + + result = prompt_template_service_module.resolve_prompt_generate_template( + tenant_id="tenant-1", + user_id="user-1", + language=prompt_template_service_module.LANGUAGE["EN"], + prompt_template_id=8, + ) + + assert result == { + "duty_system_prompt": "system-en-duty", + "constraint_system_prompt": "system-zh-constraint", + } + + +def test_resolve_prompt_generate_template_merges_custom_and_system_fallbacks( + mocker, prompt_template_service_module +): + mocker.patch.object( + prompt_template_service_module, + "sync_system_default_prompt_template", + return_value={ + "template_content_en": {"few_shots_system_prompt": "system-en-few"}, + "template_content_zh": {"user_prompt": "system-zh-user"}, + }, + ) + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_id", + return_value={ + "template_id": 6, + "template_content_en": {"duty_system_prompt": "custom-en-duty"}, + "template_content_zh": {"constraint_system_prompt": "custom-zh-constraint"}, + }, + ) + + result = prompt_template_service_module.resolve_prompt_generate_template( + tenant_id="tenant-1", + user_id="user-1", + language=prompt_template_service_module.LANGUAGE["EN"], + prompt_template_id=6, + ) + + assert result == { + "duty_system_prompt": "custom-en-duty", + "constraint_system_prompt": "custom-zh-constraint", + "few_shots_system_prompt": "system-en-few", + "user_prompt": "system-zh-user", + } + + +@pytest.mark.parametrize( + ("template_id", "expected"), + [ + (None, (None, None)), + (0, (0, "system_default")), + ], +) +def test_get_prompt_template_summary_handles_none_and_system_default( + prompt_template_service_module, template_id, expected +): + assert ( + prompt_template_service_module.get_prompt_template_summary( + template_id=template_id, + tenant_id="tenant-1", + user_id="user-1", + ) + == expected + ) + + +def test_get_prompt_template_summary_raises_when_template_missing( + mocker, prompt_template_service_module, prompt_template_models +): + _, consts_exceptions = prompt_template_models + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_id", + return_value=None, + ) + + with pytest.raises( + consts_exceptions.NotFoundException, match="Prompt template not found" + ): + prompt_template_service_module.get_prompt_template_summary( + template_id=10, + tenant_id="tenant-1", + user_id="user-1", + ) diff --git a/test/backend/test_cluster_summarization.py b/test/backend/test_cluster_summarization.py index 82af6d5ba..dd24a9f20 100644 --- a/test/backend/test_cluster_summarization.py +++ b/test/backend/test_cluster_summarization.py @@ -35,10 +35,28 @@ consts_error_code_mock.ErrorCode = MagicMock() consts_exceptions_mock = MagicMock() consts_exceptions_mock.AppException = Exception +consts_prompt_template_mock = MagicMock() +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELDS = tuple( + consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys() +) sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock sys.modules['consts.error_code'] = consts_error_code_mock sys.modules['consts.exceptions'] = consts_exceptions_mock +sys.modules['consts.prompt_template'] = consts_prompt_template_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_document_vector_integration.py b/test/backend/test_document_vector_integration.py index 4fb094618..33d97c776 100644 --- a/test/backend/test_document_vector_integration.py +++ b/test/backend/test_document_vector_integration.py @@ -36,10 +36,28 @@ consts_error_code_mock.ErrorCode = MagicMock() consts_exceptions_mock = MagicMock() consts_exceptions_mock.AppException = Exception +consts_prompt_template_mock = MagicMock() +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELDS = tuple( + consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys() +) sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock sys.modules['consts.error_code'] = consts_error_code_mock sys.modules['consts.exceptions'] = consts_exceptions_mock +sys.modules['consts.prompt_template'] = consts_prompt_template_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_document_vector_utils.py b/test/backend/test_document_vector_utils.py index 9bce2af29..53c87a022 100644 --- a/test/backend/test_document_vector_utils.py +++ b/test/backend/test_document_vector_utils.py @@ -35,10 +35,28 @@ consts_error_code_mock.ErrorCode = MagicMock() consts_exceptions_mock = MagicMock() consts_exceptions_mock.AppException = Exception +consts_prompt_template_mock = MagicMock() +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELDS = tuple( + consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys() +) sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock sys.modules['consts.error_code'] = consts_error_code_mock sys.modules['consts.exceptions'] = consts_exceptions_mock +sys.modules['consts.prompt_template'] = consts_prompt_template_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_document_vector_utils_coverage.py b/test/backend/test_document_vector_utils_coverage.py index 23a6923c8..2b4278603 100644 --- a/test/backend/test_document_vector_utils_coverage.py +++ b/test/backend/test_document_vector_utils_coverage.py @@ -34,10 +34,28 @@ consts_error_code_mock.ErrorCode = MagicMock() consts_exceptions_mock = MagicMock() consts_exceptions_mock.AppException = Exception +consts_prompt_template_mock = MagicMock() +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELDS = tuple( + consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys() +) sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock sys.modules['consts.error_code'] = consts_error_code_mock sys.modules['consts.exceptions'] = consts_exceptions_mock +sys.modules['consts.prompt_template'] = consts_prompt_template_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_summary_formatting.py b/test/backend/test_summary_formatting.py index be9d6a20d..247e20399 100644 --- a/test/backend/test_summary_formatting.py +++ b/test/backend/test_summary_formatting.py @@ -32,10 +32,28 @@ consts_error_code_mock.ErrorCode = MagicMock() consts_exceptions_mock = MagicMock() consts_exceptions_mock.AppException = Exception +consts_prompt_template_mock = MagicMock() +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELDS = tuple( + consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys() +) sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock sys.modules['consts.error_code'] = consts_error_code_mock sys.modules['consts.exceptions'] = consts_exceptions_mock +sys.modules['consts.prompt_template'] = consts_prompt_template_mock # Add backend to path before patching backend modules sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend')) diff --git a/test/backend/utils/test_llm_utils.py b/test/backend/utils/test_llm_utils.py index e0d5577ae..f7cac586b 100644 --- a/test/backend/utils/test_llm_utils.py +++ b/test/backend/utils/test_llm_utils.py @@ -496,6 +496,30 @@ def gen(): res = call_llm_for_system_prompt(2, "u2", "s2") assert res == "ABC" + def test_call_llm_for_system_prompt_skips_chunk_without_choices(self, mocker: MockFixture): + mock_get_model_by_id = mocker.patch('backend.utils.llm_utils.get_model_by_model_id') + mock_get_model_name = mocker.patch('backend.utils.llm_utils.get_model_name_from_config') + mock_openai = mocker.patch('backend.utils.llm_utils.OpenAIModel') + + mock_get_model_by_id.return_value = {"base_url": "http://y", "api_key": "k2"} + mock_get_model_name.return_value = "gpt-6" + + mock_instance = mock_openai.return_value + + empty_chunk = MagicMock() + empty_chunk.choices = [] + + valid_chunk = MagicMock() + valid_chunk.choices = [MagicMock()] + valid_chunk.choices[0].delta.content = "OK" + + mock_instance.client = MagicMock() + mock_instance.client.chat.completions.create.return_value = [empty_chunk, valid_chunk] + mock_instance._prepare_completion_kwargs.return_value = {} + + res = call_llm_for_system_prompt(2, "u2", "s2") + assert res == "OK" + def test_call_llm_for_system_prompt_with_callback(self, mocker: MockFixture): """Test call_llm_for_system_prompt with callback""" mock_get_model_by_id = mocker.patch('backend.utils.llm_utils.get_model_by_model_id')