diff --git a/lib/crewai/src/crewai/__init__.py b/lib/crewai/src/crewai/__init__.py index e87574ab8a8..bb099d26549 100644 --- a/lib/crewai/src/crewai/__init__.py +++ b/lib/crewai/src/crewai/__init__.py @@ -4,6 +4,8 @@ import urllib.request import warnings +from pydantic import PydanticUserError + from crewai.agent.core import Agent from crewai.agent.planning_config import PlanningConfig from crewai.crew import Crew @@ -93,6 +95,38 @@ def __getattr__(name: str) -> Any: raise AttributeError(f"module 'crewai' has no attribute {name!r}") +try: + from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler + from crewai.experimental.agent_executor import AgentExecutor as _AgentExecutor + from crewai.hooks.llm_hooks import LLMCallHookContext as _LLMCallHookContext + from crewai.tools.tool_types import ToolResult as _ToolResult + from crewai.utilities.prompts import ( + StandardPromptResult as _StandardPromptResult, + SystemPromptResult as _SystemPromptResult, + ) + + _AgentExecutor.model_rebuild( + force=True, + _types_namespace={ + "Agent": Agent, + "ToolsHandler": _ToolsHandler, + "Crew": Crew, + "BaseLLM": BaseLLM, + "Task": Task, + "StandardPromptResult": _StandardPromptResult, + "SystemPromptResult": _SystemPromptResult, + "LLMCallHookContext": _LLMCallHookContext, + "ToolResult": _ToolResult, + }, + ) +except (ImportError, PydanticUserError): + import logging as _logging + + _logging.getLogger(__name__).warning( + "AgentExecutor.model_rebuild() failed; forward refs may be unresolved.", + exc_info=True, + ) + __all__ = [ "LLM", "Agent", diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 8c31dd13999..e125dd7d432 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -1011,7 +1011,7 @@ def _update_executor_parameters( self.agent_executor.tools = tools self.agent_executor.original_tools = raw_tools self.agent_executor.prompt = prompt - self.agent_executor.stop = stop_words + 
self.agent_executor.stop_words = stop_words self.agent_executor.tools_names = get_tool_names(tools) self.agent_executor.tools_description = render_text_description_and_args(tools) self.agent_executor.response_model = ( diff --git a/lib/crewai/src/crewai/experimental/agent_executor.py b/lib/crewai/src/crewai/experimental/agent_executor.py index a504e5097d4..bbd14f518d9 100644 --- a/lib/crewai/src/crewai/experimental/agent_executor.py +++ b/lib/crewai/src/crewai/experimental/agent_executor.py @@ -11,10 +11,15 @@ from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast from uuid import uuid4 -from pydantic import BaseModel, Field, GetCoreSchemaHandler -from pydantic_core import CoreSchema, core_schema +from pydantic import ( + BaseModel, + Field, + PrivateAttr, + model_validator, +) from rich.console import Console from rich.text import Text +from typing_extensions import Self from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin from crewai.agents.parser import ( @@ -119,6 +124,7 @@ class AgentExecutorState(BaseModel): (todos, observations, replan tracking) in a single validated model. """ + id: str = Field(default_factory=lambda: str(uuid4())) messages: list[LLMMessage] = Field(default_factory=list) iterations: int = Field(default=0) current_answer: AgentAction | AgentFinish | None = Field(default=None) @@ -152,6 +158,9 @@ class AgentExecutorState(BaseModel): class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): """Agent Executor for both standalone agents and crew-bound agents. + _skip_auto_memory prevents Flow from eagerly allocating a Memory + instance — the executor uses agent/crew memory, not its own. 
+ Inherits from: - Flow[AgentExecutorState]: Provides flow orchestration capabilities - CrewAgentExecutorMixin: Provides memory methods (short/long/external term) @@ -159,136 +168,74 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): This executor can operate in two modes: - Standalone mode: When crew and task are None (used by Agent.kickoff()) - Crew mode: When crew and task are provided (used by Agent.execute_task()) - - Note: Multiple instances may be created during agent initialization - (cache setup, RPM controller setup, etc.) but only the final instance - should execute tasks via invoke(). """ - def __init__( - self, - llm: BaseLLM, - agent: Agent, - prompt: SystemPromptResult | StandardPromptResult, - max_iter: int, - tools: list[CrewStructuredTool], - tools_names: str, - stop_words: list[str], - tools_description: str, - tools_handler: ToolsHandler, - task: Task | None = None, - crew: Crew | None = None, - step_callback: Any = None, - original_tools: list[BaseTool] | None = None, - function_calling_llm: BaseLLM | Any | None = None, - respect_context_window: bool = False, - request_within_rpm_limit: Callable[[], bool] | None = None, - callbacks: list[Any] | None = None, - response_model: type[BaseModel] | None = None, - i18n: I18N | None = None, - ) -> None: - """Initialize the flow-based agent executor. 
+ _skip_auto_memory: bool = True + + suppress_flow_events: bool = True # always suppress for executor + llm: BaseLLM = Field(exclude=True) + agent: Agent = Field(exclude=True) + prompt: SystemPromptResult | StandardPromptResult = Field(exclude=True) + max_iter: int = Field(default=25, exclude=True) + tools: list[CrewStructuredTool] = Field(default_factory=list, exclude=True) + tools_names: str = Field(default="", exclude=True) + stop_words: list[str] = Field(default_factory=list, exclude=True) + tools_description: str = Field(default="", exclude=True) + tools_handler: ToolsHandler | None = Field(default=None, exclude=True) + task: Task | None = Field(default=None, exclude=True) + crew: Crew | None = Field(default=None, exclude=True) + step_callback: Any = Field(default=None, exclude=True) + original_tools: list[BaseTool] = Field(default_factory=list, exclude=True) + function_calling_llm: BaseLLM | None = Field(default=None, exclude=True) + respect_context_window: bool = Field(default=False, exclude=True) + request_within_rpm_limit: Callable[[], bool] | None = Field( + default=None, exclude=True + ) + callbacks: list[Any] = Field(default_factory=list, exclude=True) + response_model: type[BaseModel] | None = Field(default=None, exclude=True) + i18n: I18N | None = Field(default=None, exclude=True) + log_error_after: int = Field(default=3, exclude=True) + before_llm_call_hooks: list[BeforeLLMCallHookType | BeforeLLMCallHookCallable] = ( + Field(default_factory=list, exclude=True) + ) + after_llm_call_hooks: list[AfterLLMCallHookType | AfterLLMCallHookCallable] = Field( + default_factory=list, exclude=True + ) - Args: - llm: Language model instance. - agent: Agent to execute. - prompt: Prompt templates. - max_iter: Maximum iterations. - tools: Available tools. - tools_names: Tool names string. - stop_words: Stop word list. - tools_description: Tool descriptions. - tools_handler: Tool handler instance. 
- task: Optional task to execute (None for standalone agent execution). - crew: Optional crew instance (None for standalone agent execution). - step_callback: Optional step callback. - original_tools: Original tool list. - function_calling_llm: Optional function calling LLM. - respect_context_window: Respect context limits. - request_within_rpm_limit: RPM limit check function. - callbacks: Optional callbacks list. - response_model: Optional Pydantic model for structured outputs. - """ - self._i18n: I18N = i18n or get_i18n() - self.llm = llm - self.task: Task | None = task - self.agent = agent - self.crew: Crew | None = crew - self.prompt = prompt - self.tools = tools - self.tools_names = tools_names - self.stop = stop_words - self.max_iter = max_iter - self.callbacks = callbacks or [] - self._printer: Printer = Printer() - self.tools_handler = tools_handler - self.original_tools = original_tools or [] - self.step_callback = step_callback - self.tools_description = tools_description - self.function_calling_llm = function_calling_llm - self.respect_context_window = respect_context_window - self.request_within_rpm_limit = request_within_rpm_limit - self.response_model = response_model - self.log_error_after = 3 - self._console: Console = Console() - - # Error context storage for recovery - self._last_parser_error: OutputParserError | None = None - self._last_context_error: Exception | None = None - - # Execution guard to prevent concurrent/duplicate executions - self._execution_lock = threading.Lock() - self._finalize_lock = threading.Lock() - self._finalize_called: bool = False - self._is_executing: bool = False - self._has_been_invoked: bool = False - self._flow_initialized: bool = False - - self._instance_id = str(uuid4())[:8] - - self.before_llm_call_hooks: list[ - BeforeLLMCallHookType | BeforeLLMCallHookCallable - ] = [] - self.after_llm_call_hooks: list[ - AfterLLMCallHookType | AfterLLMCallHookCallable - ] = [] + _i18n: I18N = 
PrivateAttr(default_factory=get_i18n) + _printer: Printer = PrivateAttr(default_factory=Printer) + _console: Console = PrivateAttr(default_factory=Console) + _last_parser_error: OutputParserError | None = PrivateAttr(default=None) + _last_context_error: Exception | None = PrivateAttr(default=None) + _execution_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) + _finalize_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) + _finalize_called: bool = PrivateAttr(default=False) + _is_executing: bool = PrivateAttr(default=False) + _has_been_invoked: bool = PrivateAttr(default=False) + _instance_id: str = PrivateAttr(default_factory=lambda: str(uuid4())[:8]) + _step_executor: Any = PrivateAttr(default=None) + _planner_observer: Any = PrivateAttr(default=None) + + @model_validator(mode="after") + def _setup_executor(self) -> Self: + """Configure executor after Pydantic field initialization.""" + self._i18n = self.i18n or get_i18n() self.before_llm_call_hooks.extend(get_before_llm_call_hooks()) self.after_llm_call_hooks.extend(get_after_llm_call_hooks()) if self.llm: existing_stop = getattr(self.llm, "stop", []) - self.llm.stop = list( - set( - existing_stop + self.stop - if isinstance(existing_stop, list) - else self.stop - ) - ) - self._state = AgentExecutorState() - - # Plan-and-Execute components (Phase 2) - # Lazy-imported to avoid circular imports during module load - self._step_executor: Any = None - self._planner_observer: Any = None + if not isinstance(existing_stop, list): + existing_stop = [] + self.llm.stop = list(set(existing_stop + self.stop_words)) - def _ensure_flow_initialized(self) -> None: - """Ensure Flow.__init__() has been called. + self._state = AgentExecutorState() + self.max_method_calls = self.max_iter * 10 - This is deferred from __init__ to prevent FlowCreatedEvent emission - during agent setup when multiple executor instances are created. - Only the instance that actually executes via invoke() will emit events. 
- """ - if not self._flow_initialized: - current_tracing = is_tracing_enabled_in_context() - # Now call Flow's __init__ which will replace self._state - # with Flow's managed state. Suppress flow events since this is - # an agent executor, not a user-facing flow. - super().__init__( - suppress_flow_events=True, - tracing=current_tracing if current_tracing else None, - max_method_calls=self.max_iter * 10, - ) - self._flow_initialized = True + current_tracing = is_tracing_enabled_in_context() + self.tracing = current_tracing if current_tracing else None + self._flow_post_init() + return self def _check_native_tool_support(self) -> bool: """Check if LLM supports native function calling.""" @@ -318,19 +265,13 @@ def use_stop_words(self) -> bool: @property def state(self) -> AgentExecutorState: - """Get state - returns temporary state if Flow not yet initialized. - - Flow initialization is deferred to prevent event emission during agent setup. - Returns the temporary state until invoke() is called. 
- """ - if self._flow_initialized and hasattr(self, "_state_lock"): - return StateProxy(self._state, self._state_lock) # type: ignore[return-value] - return self._state + """Get thread-safe state proxy.""" + return StateProxy(self._state, self._state_lock) # type: ignore[return-value] @property def iterations(self) -> int: """Compatibility property for mixin - returns state iterations.""" - return self._state.iterations + return self._state.iterations # type: ignore[no-any-return] @iterations.setter def iterations(self, value: int) -> None: @@ -340,7 +281,7 @@ def iterations(self, value: int) -> None: @property def messages(self) -> list[LLMMessage]: """Compatibility property - returns state messages.""" - return self._state.messages + return self._state.messages # type: ignore[no-any-return] @messages.setter def messages(self, value: list[LLMMessage]) -> None: @@ -1969,8 +1910,7 @@ def _execute_single_native_tool_call(self, tool_call: Any) -> dict[str, Any]: @listen("initialized") def continue_iteration(self) -> Literal["check_iteration"]: """Bridge listener that connects iteration loop back to iteration check.""" - if self._flow_initialized: - self._discard_or_listener(FlowMethodName("continue_iteration")) + self._discard_or_listener(FlowMethodName("continue_iteration")) return "check_iteration" @router(or_(initialize_reasoning, continue_iteration)) @@ -2598,8 +2538,6 @@ def invoke( if is_inside_event_loop(): return self.invoke_async(inputs) - self._ensure_flow_initialized() - with self._execution_lock: if self._is_executing: raise RuntimeError( @@ -2690,8 +2628,6 @@ async def invoke_async(self, inputs: dict[str, Any]) -> dict[str, Any]: Returns: Dictionary with agent output. 
""" - self._ensure_flow_initialized() - with self._execution_lock: if self._is_executing: raise RuntimeError( @@ -3007,17 +2943,6 @@ def _is_training_mode(self) -> bool: """ return bool(self.crew and self.crew._train) - @classmethod - def __get_pydantic_core_schema__( - cls, _source_type: Any, _handler: GetCoreSchemaHandler - ) -> CoreSchema: - """Generate Pydantic core schema for Protocol compatibility. - - Allows the executor to be used in Pydantic models without - requiring arbitrary_types_allowed=True. - """ - return core_schema.any_schema() - # Backward compatibility alias (deprecated) CrewAgentExecutorFlow = AgentExecutor diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 0624f7bec17..def7d1ba9b9 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -39,7 +39,14 @@ from opentelemetry import baggage from opentelemetry.context import attach, detach -from pydantic import BaseModel, Field, ValidationError +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + ValidationError, +) +from pydantic._internal._model_construction import ModelMetaclass from rich.console import Console from rich.panel import Panel @@ -81,6 +88,7 @@ SimpleFlowCondition, StartMethod, ) +from crewai.flow.human_feedback import HumanFeedbackResult from crewai.flow.input_provider import InputProvider from crewai.flow.persistence.base import FlowPersistence from crewai.flow.types import ( @@ -108,7 +116,6 @@ from crewai_files import FileInput from crewai.flow.async_feedback.types import PendingFeedbackContext - from crewai.flow.human_feedback import HumanFeedbackResult from crewai.llms.base_llm import BaseLLM from crewai.flow.visualization import build_flow_structure, render_interactive @@ -728,7 +735,7 @@ def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]: return result -class FlowMeta(type): +class FlowMeta(ModelMetaclass): def __new__( mcs, name: str, @@ -736,6 +743,45 @@ def 
__new__( namespace: dict[str, Any], **kwargs: Any, ) -> type: + parent_fields: set[str] = set() + for base in bases: + if hasattr(base, "model_fields"): + parent_fields.update(base.model_fields) + + annotations = namespace.get("__annotations__", {}) + _skip_types = (classmethod, staticmethod, property) + + for base in bases: + if isinstance(base, ModelMetaclass): + continue + for attr_name in getattr(base, "__annotations__", {}): + if attr_name not in annotations and attr_name not in namespace: + annotations[attr_name] = ClassVar + + for attr_name, attr_value in namespace.items(): + if isinstance(attr_value, property) and attr_name not in annotations: + for base in bases: + base_ann = getattr(base, "__annotations__", {}) + if attr_name in base_ann: + annotations[attr_name] = ClassVar + + for attr_name, attr_value in list(namespace.items()): + if attr_name in annotations or attr_name.startswith("_"): + continue + if attr_name in parent_fields: + annotations[attr_name] = Any + if isinstance(attr_value, BaseModel): + namespace[attr_name] = Field( + default_factory=lambda v=attr_value: v, exclude=True + ) + continue + if callable(attr_value) or isinstance( + attr_value, (*_skip_types, FlowMethod) + ): + continue + annotations[attr_name] = ClassVar[type(attr_value)] + namespace["__annotations__"] = annotations + cls = super().__new__(mcs, name, bases, namespace) start_methods = [] @@ -820,88 +866,90 @@ def __new__( return cls -class Flow(Generic[T], metaclass=FlowMeta): +class Flow(BaseModel, Generic[T], metaclass=FlowMeta): """Base class for all flows. 
type parameter T must be either dict[str, Any] or a subclass of BaseModel.""" + model_config = ConfigDict( + arbitrary_types_allowed=True, + ignored_types=(StartMethod, ListenMethod, RouterMethod), + revalidate_instances="never", + ) + __hash__ = object.__hash__ + _start_methods: ClassVar[list[FlowMethodName]] = [] _listeners: ClassVar[dict[FlowMethodName, SimpleFlowCondition | FlowCondition]] = {} _routers: ClassVar[set[FlowMethodName]] = set() _router_paths: ClassVar[dict[FlowMethodName, list[FlowMethodName]]] = {} - initial_state: type[T] | T | None = None - name: str | None = None - tracing: bool | None = None - stream: bool = False - memory: Memory | MemoryScope | MemorySlice | None = None - input_provider: InputProvider | None = None - def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: - class _FlowGeneric(cls): # type: ignore - _initial_state_t = item + initial_state: Any = Field(default=None) + name: str | None = Field(default=None) + tracing: bool | None = Field(default=None) + stream: bool = Field(default=False) + memory: Memory | MemoryScope | MemorySlice | None = Field(default=None) + input_provider: InputProvider | None = Field(default=None) + suppress_flow_events: bool = Field(default=False) + human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list) + last_human_feedback: HumanFeedbackResult | None = Field(default=None) + + persistence: Any = Field(default=None, exclude=True) + max_method_calls: int = Field(default=100, exclude=True) + + _methods: dict[FlowMethodName, FlowMethod[Any, Any]] = PrivateAttr( + default_factory=dict + ) + _method_execution_counts: dict[FlowMethodName, int] = PrivateAttr( + default_factory=dict + ) + _pending_and_listeners: dict[PendingListenerKey, set[FlowMethodName]] = PrivateAttr( + default_factory=dict + ) + _fired_or_listeners: set[FlowMethodName] = PrivateAttr(default_factory=set) + _method_outputs: list[Any] = PrivateAttr(default_factory=list) + _state_lock: 
threading.Lock = PrivateAttr(default_factory=threading.Lock) + _or_listeners_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) + _completed_methods: set[FlowMethodName] = PrivateAttr(default_factory=set) + _method_call_counts: dict[FlowMethodName, int] = PrivateAttr(default_factory=dict) + _is_execution_resuming: bool = PrivateAttr(default=False) + _event_futures: list[Future[None]] = PrivateAttr(default_factory=list) + _pending_feedback_context: PendingFeedbackContext | None = PrivateAttr(default=None) + _human_feedback_method_outputs: dict[str, Any] = PrivateAttr(default_factory=dict) + _input_history: list[InputHistoryEntry] = PrivateAttr(default_factory=list) + _state: Any = PrivateAttr(default=None) + + def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override] + class _FlowGeneric(cls): # type: ignore[valid-type,misc] + pass _FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]" + _FlowGeneric._initial_state_t = item return _FlowGeneric - def __init__( - self, - persistence: FlowPersistence | None = None, - tracing: bool | None = None, - suppress_flow_events: bool = False, - max_method_calls: int = 100, - **kwargs: Any, - ) -> None: - """Initialize a new Flow instance. 
+ def __setattr__(self, name: str, value: Any) -> None: + """Allow arbitrary attribute assignment for backward compat with plain class.""" + if name in self.model_fields or name in self.__private_attributes__: + super().__setattr__(name, value) + else: + object.__setattr__(self, name, value) + + def model_post_init(self, __context: Any) -> None: + self._flow_post_init() + + def _flow_post_init(self) -> None: + """Heavy initialization: state creation, events, memory, method registration.""" + if getattr(self, "_flow_post_init_done", False): + return + object.__setattr__(self, "_flow_post_init_done", True) + + if self._state is None: + self._state = self._create_initial_state() - Args: - persistence: Optional persistence backend for storing flow states - tracing: Whether to enable tracing. True=always enable, False=always disable, None=check environment/user settings - suppress_flow_events: Whether to suppress flow event emissions (internal use) - max_method_calls: Maximum times a single method can be called per execution before raising RecursionError - **kwargs: Additional state values to initialize or override - """ - # Initialize basic instance attributes - self._methods: dict[FlowMethodName, FlowMethod[Any, Any]] = {} - self._method_execution_counts: dict[FlowMethodName, int] = {} - self._pending_and_listeners: dict[PendingListenerKey, set[FlowMethodName]] = {} - self._fired_or_listeners: set[FlowMethodName] = ( - set() - ) # Track OR listeners that already fired - self._method_outputs: list[Any] = [] # list to store all method outputs - self._state_lock = threading.Lock() - self._or_listeners_lock = threading.Lock() - self._completed_methods: set[FlowMethodName] = ( - set() - ) # Track completed methods for reload - self._method_call_counts: dict[FlowMethodName, int] = {} - self._max_method_calls = max_method_calls - self._persistence: FlowPersistence | None = persistence - self._is_execution_resuming: bool = False - self._event_futures: list[Future[None]] = [] 
- - # Human feedback storage - self.human_feedback_history: list[HumanFeedbackResult] = [] - self.last_human_feedback: HumanFeedbackResult | None = None - self._pending_feedback_context: PendingFeedbackContext | None = None - # Per-method stash for real @human_feedback output (keyed by method name) - # Used to decouple routing outcome from method return value when emit is set - self._human_feedback_method_outputs: dict[str, Any] = {} - self.suppress_flow_events: bool = suppress_flow_events - - # User input history (for self.ask()) - self._input_history: list[InputHistoryEntry] = [] - - # Initialize state with initial values - self._state = self._create_initial_state() - self.tracing = tracing tracing_enabled = should_enable_tracing(override=self.tracing) set_tracing_enabled(tracing_enabled) trace_listener = TraceCollectionListener() trace_listener.setup_listeners(crewai_event_bus) - # Apply any additional kwargs - if kwargs: - self._initialize_state(kwargs) if not self.suppress_flow_events: crewai_event_bus.emit( @@ -1385,8 +1433,8 @@ async def resume_async(self, feedback: str = "") -> Any: self._pending_feedback_context = None # Clear pending feedback from persistence - if self._persistence: - self._persistence.clear_pending_feedback(context.flow_id) + if self.persistence: + self.persistence.clear_pending_feedback(context.flow_id) # Emit feedback received event crewai_event_bus.emit( @@ -1427,17 +1475,17 @@ async def resume_async(self, feedback: str = "") -> Any: if isinstance(e, HumanFeedbackPending): self._pending_feedback_context = e.context - if self._persistence is None: + if self.persistence is None: from crewai.flow.persistence import SQLiteFlowPersistence - self._persistence = SQLiteFlowPersistence() + self.persistence = SQLiteFlowPersistence() state_data = ( self._state if isinstance(self._state, dict) else self._state.model_dump() ) - self._persistence.save_pending_feedback( + self.persistence.save_pending_feedback( flow_uuid=e.context.flow_id, 
context=e.context, state_data=state_data, @@ -1487,39 +1535,33 @@ def _create_initial_state(self) -> T: """ init_state = self.initial_state - # Handle case where initial_state is None but we have a type parameter if init_state is None and hasattr(self, "_initial_state_t"): state_type = self._initial_state_t if isinstance(state_type, type): if issubclass(state_type, FlowState): - # Create instance - FlowState auto-generates id via default_factory instance = state_type() - # Ensure id is set - generate UUID if empty if not getattr(instance, "id", None): object.__setattr__(instance, "id", str(uuid4())) return cast(T, instance) if issubclass(state_type, BaseModel): - # Create a new type with FlowState first for proper id default + class StateWithId(FlowState, state_type): # type: ignore pass instance = StateWithId() - # Ensure id is set - generate UUID if empty if not getattr(instance, "id", None): object.__setattr__(instance, "id", str(uuid4())) return cast(T, instance) if state_type is dict: return cast(T, {"id": str(uuid4())}) - # Handle case where no initial state is provided if init_state is None: return cast(T, {"id": str(uuid4())}) - # Handle case where initial_state is a type (class) if isinstance(init_state, type): state_class = init_state if issubclass(state_class, FlowState): - return state_class() + return cast(T, state_class()) if issubclass(state_class, BaseModel): model_fields = getattr(state_class, "model_fields", None) if not model_fields or "id" not in model_fields: @@ -1527,7 +1569,7 @@ class StateWithId(FlowState, state_type): # type: ignore model_instance = state_class() if not getattr(model_instance, "id", None): object.__setattr__(model_instance, "id", str(uuid4())) - return model_instance + return cast(T, model_instance) if init_state is dict: return cast(T, {"id": str(uuid4())}) @@ -1538,32 +1580,21 @@ class StateWithId(FlowState, state_type): # type: ignore new_state["id"] = str(uuid4()) return cast(T, new_state) - # Handle BaseModel instance 
case if isinstance(init_state, BaseModel): - model = cast(BaseModel, init_state) - if not hasattr(model, "id"): - raise ValueError("Flow state model must have an 'id' field") - - # Create new instance with same values to avoid mutations - if hasattr(model, "model_dump"): - # Pydantic v2 + model = init_state + if hasattr(model, "id"): state_dict = model.model_dump() - elif hasattr(model, "dict"): - # Pydantic v1 - state_dict = model.dict() - else: - # Fallback for other BaseModel implementations - state_dict = { - k: v for k, v in model.__dict__.items() if not k.startswith("_") - } + if not state_dict.get("id"): + state_dict["id"] = str(uuid4()) + model_class = type(model) + return cast(T, model_class(**state_dict)) - # Ensure id is set - generate UUID if empty - if not state_dict.get("id"): - state_dict["id"] = str(uuid4()) + class StateWithId(FlowState, type(model)): # type: ignore + pass - # Create new instance of the same class - model_class = type(model) - return cast(T, model_class(**state_dict)) + state_dict = model.model_dump() + state_dict["id"] = str(uuid4()) + return cast(T, StateWithId(**state_dict)) raise TypeError( f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" ) @@ -1576,17 +1607,17 @@ def _copy_state(self) -> T: """ if isinstance(self._state, BaseModel): try: - return self._state.model_copy(deep=True) + return cast(T, self._state.model_copy(deep=True)) except (TypeError, AttributeError): try: state_dict = self._state.model_dump() model_class = type(self._state) - return model_class(**state_dict) + return cast(T, model_class(**state_dict)) except Exception: - return self._state.model_copy(deep=False) + return cast(T, self._state.model_copy(deep=False)) else: try: - return copy.deepcopy(self._state) + return cast(T, copy.deepcopy(self._state)) except (TypeError, AttributeError): return cast(T, self._state.copy()) @@ -1662,7 +1693,7 @@ def _initialize_state(self, inputs: dict[str, Any]) -> None: elif isinstance(self._state, 
BaseModel): # For BaseModel states, preserve existing fields unless overridden try: - model = cast(BaseModel, self._state) + model = self._state # Get current state as dict if hasattr(model, "model_dump"): current_state = model.model_dump() @@ -1713,7 +1744,7 @@ def _restore_state(self, stored_state: dict[str, Any]) -> None: self._state.update(stored_state) elif isinstance(self._state, BaseModel): # For BaseModel states, create new instance with stored values - model = cast(BaseModel, self._state) + model = self._state if hasattr(model, "model_validate"): # Pydantic v2 self._state = cast(T, type(model).model_validate(stored_state)) @@ -1938,7 +1969,7 @@ async def run_flow() -> None: try: # Reset flow state for fresh execution unless restoring from persistence - is_restoring = inputs and "id" in inputs and self._persistence is not None + is_restoring = inputs and "id" in inputs and self.persistence is not None if not is_restoring: # Clear completed methods and outputs for a fresh start self._completed_methods.clear() @@ -1964,9 +1995,9 @@ async def run_flow() -> None: setattr(self._state, "id", inputs["id"]) # noqa: B010 # If persistence is enabled, attempt to restore the stored state using the provided id. 
- if "id" in inputs and self._persistence is not None: + if "id" in inputs and self.persistence is not None: restore_uuid = inputs["id"] - stored_state = self._persistence.load_state(restore_uuid) + stored_state = self.persistence.load_state(restore_uuid) if stored_state: self._log_flow_event( f"Loading flow state from memory for UUID: {restore_uuid}" @@ -2036,17 +2067,17 @@ async def run_flow() -> None: if isinstance(e, HumanFeedbackPending): # Auto-save pending feedback (create default persistence if needed) - if self._persistence is None: + if self.persistence is None: from crewai.flow.persistence import SQLiteFlowPersistence - self._persistence = SQLiteFlowPersistence() + self.persistence = SQLiteFlowPersistence() state_data = ( self._state if isinstance(self._state, dict) else self._state.model_dump() ) - self._persistence.save_pending_feedback( + self.persistence.save_pending_feedback( flow_uuid=e.context.flow_id, context=e.context, state_data=state_data, @@ -2332,10 +2363,10 @@ async def _execute_method( if isinstance(e, HumanFeedbackPending): e.context.method_name = method_name - if self._persistence is None: + if self.persistence is None: from crewai.flow.persistence import SQLiteFlowPersistence - self._persistence = SQLiteFlowPersistence() + self.persistence = SQLiteFlowPersistence() # Emit paused event (not failed) if not self.suppress_flow_events: @@ -2696,9 +2727,9 @@ async def _execute_single_listener( - Catches and logs any exceptions during execution, preventing individual listener failures from breaking the entire flow """ count = self._method_call_counts.get(listener_name, 0) + 1 - if count > self._max_method_calls: + if count > self.max_method_calls: raise RecursionError( - f"Method '{listener_name}' has been called {self._max_method_calls} times in " + f"Method '{listener_name}' has been called {self.max_method_calls} times in " f"this flow execution, which indicates an infinite loop. 
" f"This commonly happens when a @listen label matches the " f"method's own name." @@ -2805,7 +2836,7 @@ def _checkpoint_state_for_ask(self) -> None: This is best-effort: if persistence is not configured, this is a no-op. """ - if self._persistence is None: + if self.persistence is None: return try: state_data = ( @@ -2813,7 +2844,7 @@ def _checkpoint_state_for_ask(self) -> None: if isinstance(self._state, dict) else self._state.model_dump() ) - self._persistence.save_state( + self.persistence.save_state( flow_uuid=self.flow_id, method_name="_ask_checkpoint", state_data=state_data, diff --git a/lib/crewai/src/crewai/memory/encoding_flow.py b/lib/crewai/src/crewai/memory/encoding_flow.py index 1580544907a..acd025d5538 100644 --- a/lib/crewai/src/crewai/memory/encoding_flow.py +++ b/lib/crewai/src/crewai/memory/encoding_flow.py @@ -98,7 +98,7 @@ class EncodingFlow(Flow[EncodingState]): _skip_auto_memory: bool = True - initial_state = EncodingState + initial_state: type[EncodingState] = EncodingState def __init__( self, diff --git a/lib/crewai/src/crewai/memory/recall_flow.py b/lib/crewai/src/crewai/memory/recall_flow.py index f056c9a1d2c..3a058f27bdb 100644 --- a/lib/crewai/src/crewai/memory/recall_flow.py +++ b/lib/crewai/src/crewai/memory/recall_flow.py @@ -65,7 +65,7 @@ class RecallFlow(Flow[RecallState]): _skip_auto_memory: bool = True - initial_state = RecallState + initial_state: type[RecallState] = RecallState def __init__( self, diff --git a/lib/crewai/src/crewai/memory/unified_memory.py b/lib/crewai/src/crewai/memory/unified_memory.py index 1454f0fcfc5..d879bace0cc 100644 --- a/lib/crewai/src/crewai/memory/unified_memory.py +++ b/lib/crewai/src/crewai/memory/unified_memory.py @@ -148,6 +148,36 @@ class Memory(BaseModel): _pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list) _pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Memory: + """Deepcopy that 
handles unpickleable private attrs (ThreadPoolExecutor, Lock).""" + import copy as _copy + + cls = type(self) + new = cls.__new__(cls) + if memo is None: + memo = {} + memo[id(self)] = new + object.__setattr__(new, "__dict__", _copy.deepcopy(self.__dict__, memo)) + object.__setattr__( + new, "__pydantic_fields_set__", _copy.copy(self.__pydantic_fields_set__) + ) + object.__setattr__( + new, "__pydantic_extra__", _copy.deepcopy(self.__pydantic_extra__, memo) + ) + # Private attrs: create fresh pool/lock instead of deepcopying. + # NOTE: threading.Lock is a factory function (not a type) on Python < 3.13, + # so isinstance must check against type(threading.Lock()) == _thread.lock. + private = {} + for k, v in (self.__pydantic_private__ or {}).items(): + if isinstance(v, (ThreadPoolExecutor, type(threading.Lock()))): + attr = self.__private_attributes__[k] + private[k] = attr.get_default() + else: + try: + private[k] = _copy.deepcopy(v, memo) + except Exception: + private[k] = v + object.__setattr__(new, "__pydantic_private__", private) + return new + def model_post_init(self, __context: Any) -> None: """Initialize runtime state from field values.""" self._config = MemoryConfig( diff --git a/lib/crewai/src/crewai/utilities/prompts.py b/lib/crewai/src/crewai/utilities/prompts.py index 57b54be1c17..e88a9708a95 100644 --- a/lib/crewai/src/crewai/utilities/prompts.py +++ b/lib/crewai/src/crewai/utilities/prompts.py @@ -2,9 +2,10 @@ from __future__ import annotations -from typing import Annotated, Any, Literal, TypedDict +from typing import Annotated, Any, Literal from pydantic import BaseModel, Field +from typing_extensions import TypedDict from crewai.utilities.i18n import I18N, get_i18n diff --git a/lib/crewai/tests/agents/test_agent_executor.py b/lib/crewai/tests/agents/test_agent_executor.py index 9989feb3654..1ec1a1788ca 100644 --- a/lib/crewai/tests/agents/test_agent_executor.py +++ b/lib/crewai/tests/agents/test_agent_executor.py @@ -4,13 +4,55 @@ flow methods, routing logic, and error handling. 
""" +from __future__ import annotations + import asyncio import time +from typing import Any from unittest.mock import AsyncMock, Mock, patch import pytest +from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler from crewai.agents.step_executor import StepExecutor + + +def _build_executor(**kwargs: Any) -> AgentExecutor: + """Create an AgentExecutor without validation — for unit tests. + + Uses model_construct to skip Pydantic validators so plain Mock() + objects are accepted for typed fields like llm, agent, crew, task. + """ + executor = AgentExecutor.model_construct(**kwargs) + executor._state = AgentExecutorState() + executor._methods = {} + executor._method_outputs = [] + executor._completed_methods = set() + executor._fired_or_listeners = set() + executor._pending_and_listeners = {} + executor._method_execution_counts = {} + executor._method_call_counts = {} + executor._event_futures = [] + executor._human_feedback_method_outputs = {} + executor._input_history = [] + executor._is_execution_resuming = False + import threading + executor._state_lock = threading.Lock() + executor._or_listeners_lock = threading.Lock() + executor._execution_lock = threading.Lock() + executor._finalize_lock = threading.Lock() + executor._finalize_called = False + executor._is_executing = False + executor._has_been_invoked = False + executor._last_parser_error = None + executor._last_context_error = None + executor._step_executor = None + executor._planner_observer = None + from crewai.utilities.printer import Printer + executor._printer = Printer() + from crewai.utilities.i18n import get_i18n + executor._i18n = kwargs.get("i18n") or get_i18n() + return executor from crewai.agents.planner_observer import PlannerObserver from crewai.experimental.agent_executor import ( AgentExecutorState, @@ -75,6 +117,7 @@ def mock_dependencies(self): """Create mock dependencies for executor.""" llm = Mock() llm.supports_stop_words.return_value = True + llm.stop = [] task = Mock() 
task.description = "Test task" @@ -94,7 +137,7 @@ def mock_dependencies(self): prompt = {"prompt": "Test prompt with {input}, {tool_names}, {tools}"} tools = [] - tools_handler = Mock() + tools_handler = Mock(spec=_ToolsHandler) return { "llm": llm, @@ -112,7 +155,7 @@ def mock_dependencies(self): def test_executor_initialization(self, mock_dependencies): """Test AgentExecutor initialization.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) assert executor.llm == mock_dependencies["llm"] assert executor.task == mock_dependencies["task"] @@ -126,7 +169,7 @@ def test_initialize_reasoning(self, mock_dependencies): with patch.object( AgentExecutor, "_show_start_logs" ) as mock_show_start: - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) result = executor.initialize_reasoning() assert result == "initialized" @@ -134,7 +177,7 @@ def test_initialize_reasoning(self, mock_dependencies): def test_check_max_iterations_not_reached(self, mock_dependencies): """Test routing when iterations < max.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor.state.iterations = 5 result = executor.check_max_iterations() @@ -142,7 +185,7 @@ def test_check_max_iterations_not_reached(self, mock_dependencies): def test_check_max_iterations_reached(self, mock_dependencies): """Test routing when iterations >= max.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor.state.iterations = 10 result = executor.check_max_iterations() @@ -150,7 +193,7 @@ def test_check_max_iterations_reached(self, mock_dependencies): def test_route_by_answer_type_action(self, mock_dependencies): """Test routing for AgentAction.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor.state.current_answer = AgentAction( thought="thinking", tool="search", 
tool_input="query", text="action text" ) @@ -160,7 +203,7 @@ def test_route_by_answer_type_action(self, mock_dependencies): def test_route_by_answer_type_finish(self, mock_dependencies): """Test routing for AgentFinish.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor.state.current_answer = AgentFinish( thought="final thoughts", output="Final answer", text="complete" ) @@ -170,7 +213,7 @@ def test_route_by_answer_type_finish(self, mock_dependencies): def test_continue_iteration(self, mock_dependencies): """Test iteration continuation.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) result = executor.continue_iteration() @@ -179,7 +222,7 @@ def test_continue_iteration(self, mock_dependencies): def test_finalize_success(self, mock_dependencies): """Test finalize with valid AgentFinish.""" with patch.object(AgentExecutor, "_show_logs") as mock_show_logs: - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor.state.current_answer = AgentFinish( thought="final thinking", output="Done", text="complete" ) @@ -192,7 +235,7 @@ def test_finalize_success(self, mock_dependencies): def test_finalize_failure(self, mock_dependencies): """Test finalize skips when given AgentAction instead of AgentFinish.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor.state.current_answer = AgentAction( thought="thinking", tool="search", tool_input="query", text="action text" ) @@ -208,7 +251,7 @@ def test_finalize_skips_synthesis_for_strong_last_todo_result( ): """Finalize should skip synthesis when last todo is already a complete answer.""" with patch.object(AgentExecutor, "_show_logs") as mock_show_logs: - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor.state.todos.items = [ TodoItem( step_number=1, @@ -252,7 
+295,7 @@ def test_finalize_keeps_synthesis_when_response_model_is_set( ): """Finalize should still synthesize when response_model is configured.""" with patch.object(AgentExecutor, "_show_logs"): - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor.response_model = Mock() executor.state.todos.items = [ TodoItem( @@ -287,7 +330,7 @@ def _set_current_answer() -> None: def test_format_prompt(self, mock_dependencies): """Test prompt formatting.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) inputs = {"input": "test input", "tool_names": "tool1, tool2", "tools": "desc"} result = executor._format_prompt("Prompt {input} {tool_names} {tools}", inputs) @@ -298,18 +341,18 @@ def test_format_prompt(self, mock_dependencies): def test_is_training_mode_false(self, mock_dependencies): """Test training mode detection when not in training.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) assert executor._is_training_mode() is False def test_is_training_mode_true(self, mock_dependencies): """Test training mode detection when in training.""" mock_dependencies["crew"]._train = True - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) assert executor._is_training_mode() is True def test_append_message_to_state(self, mock_dependencies): """Test message appending to state.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) initial_count = len(executor.state.messages) executor._append_message_to_state("test message") @@ -322,7 +365,7 @@ def test_invoke_step_callback(self, mock_dependencies): callback = Mock() mock_dependencies["step_callback"] = callback - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) answer = AgentFinish(thought="thinking", output="test", text="final") 
executor._invoke_step_callback(answer) @@ -332,7 +375,7 @@ def test_invoke_step_callback(self, mock_dependencies): def test_invoke_step_callback_none(self, mock_dependencies): """Test step callback when none provided.""" mock_dependencies["step_callback"] = None - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) # Should not raise error executor._invoke_step_callback( @@ -346,7 +389,7 @@ async def test_invoke_step_callback_async_inside_running_loop( """Test async step callback scheduling when already in an event loop.""" callback = AsyncMock() mock_dependencies["step_callback"] = callback - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) answer = AgentFinish(thought="thinking", output="test", text="final") with patch("crewai.experimental.agent_executor.asyncio.run") as mock_run: @@ -364,6 +407,7 @@ class TestStepExecutorCriticalFixes: def mock_dependencies(self): """Create mock dependencies for AgentExecutor tests in this class.""" llm = Mock() + llm.stop = [] llm.supports_stop_words.return_value = True task = Mock() @@ -393,6 +437,7 @@ def mock_dependencies(self): @pytest.fixture def step_executor(self): llm = Mock() + llm.stop = [] llm.supports_stop_words.return_value = True agent = Mock() @@ -485,7 +530,7 @@ def test_recover_from_parser_error( mock_handle_exception.return_value = None - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor._last_parser_error = OutputParserError("test error") initial_iterations = executor.state.iterations @@ -500,7 +545,7 @@ def test_recover_from_context_length( self, mock_handle_context, mock_dependencies ): """Test recovery from context length error.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor._last_context_error = Exception("context too long") initial_iterations = executor.state.iterations @@ -513,16 +558,16 @@ def 
test_recover_from_context_length( def test_use_stop_words_property(self, mock_dependencies): """Test use_stop_words property.""" mock_dependencies["llm"].supports_stop_words.return_value = True - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) assert executor.use_stop_words is True mock_dependencies["llm"].supports_stop_words.return_value = False - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) assert executor.use_stop_words is False def test_compatibility_properties(self, mock_dependencies): """Test compatibility properties for mixin.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor.state.messages = [{"role": "user", "content": "test"}] executor.state.iterations = 5 @@ -538,6 +583,7 @@ class TestFlowErrorHandling: def mock_dependencies(self): """Create mock dependencies.""" llm = Mock() + llm.stop = [] llm.supports_stop_words.return_value = True task = Mock() @@ -575,7 +621,7 @@ def test_call_llm_parser_error( mock_enforce_rpm.return_value = None mock_get_llm.side_effect = OutputParserError("parse failed") - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) result = executor.call_llm_and_parse() assert result == "parser_error" @@ -596,7 +642,7 @@ def test_call_llm_context_error( mock_get_llm.side_effect = Exception("context length") mock_is_context_exceeded.return_value = True - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) result = executor.call_llm_and_parse() assert result == "context_error" @@ -610,6 +656,7 @@ class TestFlowInvoke: def mock_dependencies(self): """Create mock dependencies.""" llm = Mock() + llm.stop = [] task = Mock() task.description = "Test" task.human_input = False @@ -646,7 +693,7 @@ def test_invoke_success( mock_dependencies, ): """Test successful invoke without human feedback.""" - 
executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) # Mock kickoff to set the final answer in state def mock_kickoff_side_effect(): @@ -666,7 +713,7 @@ def mock_kickoff_side_effect(): @patch.object(AgentExecutor, "kickoff") def test_invoke_failure_no_agent_finish(self, mock_kickoff, mock_dependencies): """Test invoke fails without AgentFinish.""" - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) executor.state.current_answer = AgentAction( thought="thinking", tool="test", tool_input="test", text="action text" ) @@ -689,7 +736,7 @@ def test_invoke_with_system_prompt( "system": "System: {input}", "user": "User: {input} {tool_names} {tools}", } - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) def mock_kickoff_side_effect(): executor.state.current_answer = AgentFinish( @@ -713,6 +760,7 @@ class TestNativeToolExecution: @pytest.fixture def mock_dependencies(self): llm = Mock() + llm.stop = [] llm.supports_stop_words.return_value = True task = Mock() @@ -734,7 +782,7 @@ def mock_dependencies(self): prompt = {"prompt": "Test {input} {tool_names} {tools}"} - tools_handler = Mock() + tools_handler = Mock(spec=_ToolsHandler) tools_handler.cache = None return { @@ -754,7 +802,7 @@ def mock_dependencies(self): def test_execute_native_tool_runs_parallel_for_multiple_calls( self, mock_dependencies ): - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) def slow_one() -> str: time.sleep(0.2) @@ -790,7 +838,7 @@ def slow_two() -> str: def test_execute_native_tool_falls_back_to_sequential_for_result_as_answer( self, mock_dependencies ): - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) def slow_one() -> str: time.sleep(0.2) @@ -832,7 +880,7 @@ def slow_two() -> str: def test_execute_native_tool_result_as_answer_short_circuits_remaining_calls( self, 
mock_dependencies ): - executor = AgentExecutor(**mock_dependencies) + executor = _build_executor(**mock_dependencies) call_counts = {"slow_one": 0, "slow_two": 0} def slow_one() -> str: diff --git a/lib/crewai/tests/test_async_human_feedback.py b/lib/crewai/tests/test_async_human_feedback.py index a72147213c0..a664c6ffa2d 100644 --- a/lib/crewai/tests/test_async_human_feedback.py +++ b/lib/crewai/tests/test_async_human_feedback.py @@ -873,7 +873,7 @@ def generate(self): # Create flow WITHOUT persistence flow = TestFlow() - assert flow._persistence is None # No persistence initially + assert flow.persistence is None # No persistence initially # kickoff should auto-create persistence when HumanFeedbackPending is raised result = flow.kickoff() @@ -882,11 +882,11 @@ def generate(self): assert isinstance(result, HumanFeedbackPending) # Persistence should have been auto-created - assert flow._persistence is not None + assert flow.persistence is not None # The pending feedback should be saved flow_id = result.context.flow_id - loaded = flow._persistence.load_pending_feedback(flow_id) + loaded = flow.persistence.load_pending_feedback(flow_id) assert loaded is not None