diff --git a/debug_test.py b/debug_test.py new file mode 100644 index 0000000..2ffd5b2 --- /dev/null +++ b/debug_test.py @@ -0,0 +1,36 @@ +import asyncio +from fastapi.testclient import TestClient +from unittest.mock import AsyncMock, patch + +from src.api.app import create_app + +app = create_app() +client = TestClient(app) + +with patch("src.api.routes.memory.require_api_key", return_value={"username": "test_user"}): + from src.api.dependencies import require_api_key, enforce_rate_limit, require_ready + app.dependency_overrides[require_api_key] = lambda: {"username": "test_user"} + app.dependency_overrides[enforce_rate_limit] = lambda: True + app.dependency_overrides[require_ready] = lambda: True + + payload = { + "items": [ + { + "user_query": "Hello world", + "agent_response": "Hi there", + "user_id": "test_user_1", + } + ] + } + + try: + response = client.post( + "/v1/memory/batch-ingest", + json=payload, + headers={"Authorization": "Bearer test-key"} + ) + print("Status code:", response.status_code) + import json + print(json.dumps(response.json(), indent=2)) + except Exception as e: + print("Exception:", e) diff --git a/server.py b/server.py index 593535b..d487241 100644 --- a/server.py +++ b/server.py @@ -50,6 +50,7 @@ from src.pipelines.ingest import IngestPipeline from src.pipelines.retrieval import RetrievalPipeline +from src.api.ingestion_coordinator import UserIngestionCoordinator # ═══════════════════════════════════════════════════════════════════ @@ -82,6 +83,7 @@ def emit(self, record: logging.LogRecord) -> None: _pipelines_ready = asyncio.Event() _init_error: str | None = None SKIP_PIPELINES = os.getenv("XMEM_SKIP_PIPELINES", "").lower() in {"1", "true", "yes"} +_user_coordinator = UserIngestionCoordinator() def _init_pipelines_sync() -> None: @@ -315,14 +317,15 @@ async def v1_ingest_memory(req: IngestRequest): ) try: - result = await ingest_pipeline.run( - user_query=req.user_query, - agent_response=req.agent_response or "Acknowledged.", - user_id=req.user_id, - session_datetime=req.session_datetime, - image_url=req.image_url, - effort_level=req.effort_level, - ) + async with _user_coordinator.acquire(req.user_id): + result = await ingest_pipeline.run( + user_query=req.user_query, + agent_response=req.agent_response or "Acknowledged.", + user_id=req.user_id, + session_datetime=req.session_datetime, + image_url=req.image_url, + effort_level=req.effort_level, + ) data = { "model": _get_model_name(ingest_pipeline.model), @@ -368,11 +371,12 @@ async def api_ingest(req: IngestRequest): lg.addHandler(capture) try: - result = await ingest_pipeline.run( - user_query=req.user_query, - agent_response=req.agent_response or "Acknowledged.", - user_id=req.user_id, - ) + async with _user_coordinator.acquire(req.user_id): + result = await ingest_pipeline.run( + user_query=req.user_query, + agent_response=req.agent_response or "Acknowledged.", + user_id=req.user_id, + ) # Build structured response response: Dict[str, Any] = { diff --git a/src/agents/judge.py b/src/agents/judge.py index 8ca8c57..7a74a25 100644 --- a/src/agents/judge.py +++ b/src/agents/judge.py @@ -19,6 +19,7 @@ import asyncio import json +from difflib import SequenceMatcher from typing import Any, Callable, Dict, List, Optional from langchain_core.language_models import BaseChatModel @@ -162,7 +163,9 @@ def __init__( # Public entry point # ------------------------------------------------------------------ - async def arun(self, state: Dict[str, Any]) -> JudgeResult: + async def arun( + self, state: Dict[str, Any], pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: domain_str = state.get("domain", "") try: domain = JudgeDomain(domain_str) @@ -186,6 +189,7 @@ async def arun(self, state: Dict[str, Any]) -> JudgeResult: new_items=new_items, user_id=user_id, domain=domain, + pending_ops=pending_ops, ) if domain == JudgeDomain.SUMMARY and not _has_summary_judge_candidates(matches_per_item): @@ -218,7 +222,9 @@ async def arun(self, state: Dict[str, Any]) -> JudgeResult: return result - async def arun_deterministic(self, state: Dict[str, Any]) -> JudgeResult: + async def arun_deterministic( + self, state: Dict[str, Any], pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: """Build operations without an LLM for structured domains. Profile and temporal extraction already returns normalized structured @@ -238,15 +244,15 @@ async def arun_deterministic(self, state: Dict[str, Any]) -> JudgeResult: return JudgeResult() if domain == JudgeDomain.PROFILE: - result = await self._deterministic_profile(new_items, user_id) + result = await self._deterministic_profile(new_items, user_id, pending_ops=pending_ops) elif domain == JudgeDomain.TEMPORAL: - result = await self._deterministic_temporal(new_items, user_id) + result = await self._deterministic_temporal(new_items, user_id, pending_ops=pending_ops) else: self.logger.warning( "Deterministic judge unsupported for %s; falling back to LLM judge.", domain.value, ) - return await self.arun(state) + return await self.arun(state, pending_ops=pending_ops) self._log_result(domain, result) return result @@ -275,11 +281,16 @@ async def _fetch_similar( new_items: list, user_id: str, domain: JudgeDomain, + pending_ops: Optional[List[Operation]] = None, ) -> Dict[str, List[SearchResult]]: if domain == JudgeDomain.TEMPORAL: - return await self._fetch_similar_temporal(items_strings, new_items, user_id) + return await self._fetch_similar_temporal( + items_strings, new_items, user_id, pending_ops=pending_ops + ) else: - return await self._fetch_similar_vector(items_strings, new_items, user_id, domain) + return await self._fetch_similar_vector( + items_strings, new_items, user_id, domain, pending_ops=pending_ops + ) # -- Profile / Summary: Pinecone (vector store) ----------------------- @@ -289,6 +300,7 @@ async def _fetch_similar_vector( new_items: list, user_id: str, domain: JudgeDomain, + pending_ops: Optional[List[Operation]] = None, ) -> Dict[str, List[SearchResult]]: if not self.vector_store: self.logger.debug("No vector store attached — skipping similarity search.") @@ -297,7 +309,7 @@ async def _fetch_similar_vector( # Profile domain: use deterministic metadata lookup if domain == JudgeDomain.PROFILE: return await self._fetch_similar_profile_metadata( - items_strings, new_items, user_id, + items_strings, new_items, user_id, pending_ops=pending_ops ) # Summary / other: parallel semantic search across all items @@ -320,7 +332,43 @@ async def _search_one(item_str: str) -> tuple[str, List[SearchResult]]: return item_str, [] pairs = await asyncio.gather(*(_search_one(s) for s in items_strings)) - return dict(pairs) + matches_per_item = dict(pairs) + + if pending_ops: + for item_str in items_strings: + matches = matches_per_item.get(item_str, []) + + # Apply deletes first + deletes = {op.embedding_id for op in pending_ops if op.type == OperationType.DELETE and op.embedding_id} + if deletes: + matches = [m for m in matches if m.id not in deletes] + + # Apply adds/updates + for op in pending_ops: + if op.type in (OperationType.ADD, OperationType.UPDATE): + ratio = SequenceMatcher(None, _norm_text(item_str), _norm_text(op.content)).ratio() + if ratio > 0.5: + simulated = SearchResult( + id=op.embedding_id or f"pending_{domain.value}_{hash(op.content)}", + content=op.content, + score=ratio, + metadata={ + "domain": domain.value, + "user_id": user_id, + } + ) + # Replace if same ID already exists in matches, otherwise prepend + existing_idx = next((i for i, m in enumerate(matches) if m.id == simulated.id), None) + if existing_idx is not None: + matches[existing_idx] = simulated + else: + matches.insert(0, simulated) + + # Sort matches by score descending + matches = sorted(matches, key=lambda x: x.score or 0.0, reverse=True)[:self.top_k] + matches_per_item[item_str] = matches + + return matches_per_item # -- Profile: deterministic metadata lookup ---------------------------- @@ -329,6 +377,7 @@ async def _fetch_similar_profile_metadata( items_strings: List[str], new_items: list, user_id: str, + pending_ops: Optional[List[Operation]] = None, ) -> Dict[str, List[SearchResult]]: """Fetch existing profile records by exact topic_subtopic match (parallel). @@ -377,7 +426,36 @@ async def _lookup_one(idx: int, item_str: str) -> tuple[str, List[SearchResult]] pairs = await asyncio.gather( *(_lookup_one(i, s) for i, s in enumerate(items_strings)) ) - return dict(pairs) + matches_per_item = dict(pairs) + + if pending_ops: + for idx, item_str in enumerate(items_strings): + item = new_items[idx] if idx < len(new_items) else {} + meta_key = _build_profile_metadata_key(item) + if not meta_key: + continue + + for op in pending_ops: + op_meta_key = _profile_meta_key_from_content(op.content) + if op_meta_key == meta_key: + if op.type in (OperationType.ADD, OperationType.UPDATE): + _, _, memo = _parse_profile_content(op.content) + simulated = SearchResult( + id=op.embedding_id or f"pending_profile_{meta_key}", + content=op.content, + score=1.0, + metadata={ + "main_content": meta_key, + "subcontent": memo, + "domain": "profile", + "user_id": user_id, + } + ) + matches_per_item[item_str] = [simulated] + elif op.type == OperationType.DELETE: + matches_per_item[item_str] = [] + + return matches_per_item # -- Semantic search fallback (summary domain) ------------------------- @@ -406,6 +484,7 @@ async def _fetch_similar_temporal( items_strings: List[str], new_items: list, user_id: str, + pending_ops: Optional[List[Operation]] = None, ) -> Dict[str, List[SearchResult]]: if not self.graph_event_search: self.logger.debug("No graph_event_search provided — skipping Neo4j lookup.") @@ -434,7 +513,36 @@ async def _lookup_one(idx: int, item_str: str) -> tuple[str, List[SearchResult]] pairs = await asyncio.gather( *(_lookup_one(i, s) for i, s in enumerate(items_strings)) ) - return dict(pairs) + matches_per_item = dict(pairs) + + if pending_ops: + for idx, item_str in enumerate(items_strings): + event = new_items[idx] if idx < len(new_items) else {} + event_name = event.get("event_name", "") if isinstance(event, dict) else "" + if not event_name: + continue + norm_event_name = _norm_text(event_name) + + for op in pending_ops: + fields = _temporal_fields_from_content(op.content) + op_event_name = fields.get("event_name", "") + if _norm_text(op_event_name) == norm_event_name: + if op.type in (OperationType.ADD, OperationType.UPDATE): + simulated = SearchResult( + id=op.embedding_id or f"pending_temporal_{norm_event_name}", + content=op.content, + score=1.0, + metadata={ + **fields, + "domain": "temporal", + "user_id": user_id, + } + ) + matches_per_item[item_str] = [simulated] + elif op.type == OperationType.DELETE: + matches_per_item[item_str] = [] + + return matches_per_item # -- Deterministic operation builders --------------------------------- @@ -687,3 +795,20 @@ def _temporal_fields_from_match(match: SearchResult) -> Dict[str, str]: def _same_temporal_event(incoming: Dict[str, str], existing: Dict[str, str]) -> bool: keys = ["date", "event_name", "desc", "year", "time", "date_expression"] return all(_norm_text(incoming.get(key)) == _norm_text(existing.get(key)) for key in keys) + + +def _parse_profile_content(content: str) -> tuple[str, str, str]: + if " = " not in content: + return "", "", content + left, memo = content.split(" = ", 1) + if " / " not in left: + return left.strip(), "", memo.strip() + topic, sub_topic = left.split(" / ", 1) + return topic.strip(), sub_topic.strip(), memo.strip() + + +def _profile_meta_key_from_content(content: str) -> str: + topic, sub_topic, _ = _parse_profile_content(content) + if not topic or not sub_topic: + return "" + return f"{topic}_{sub_topic}".replace(" ", "_").lower() diff --git a/src/api/ingestion_coordinator.py b/src/api/ingestion_coordinator.py new file mode 100644 index 0000000..1503a3b --- /dev/null +++ b/src/api/ingestion_coordinator.py @@ -0,0 +1,94 @@ +""" +Per-user ingestion coordinator — serialises ingestion for each user. + +Guarantees that only one ingestion pipeline runs at a time for any given +``user_id``, while allowing different users to proceed in parallel. +Requests for the same user are processed in strict FIFO order. + +This is the **in-memory** implementation (Option 1). A future distributed +lock (Redis, etc.) can be swapped in by implementing the same ``acquire()`` +context-manager interface. + +Usage:: + + from src.api.ingestion_coordinator import UserIngestionCoordinator + + coordinator = UserIngestionCoordinator() + + async with coordinator.acquire(user_id): + result = await pipeline.run(...) +""" + +from __future__ import annotations + +import asyncio +import logging +from contextlib import asynccontextmanager +from typing import AsyncIterator, Dict + +logger = logging.getLogger("xmem.api.ingestion_coordinator") + + +class UserIngestionCoordinator: + """Per-user FIFO ingestion lock. + + Internally maintains a ``dict[str, asyncio.Lock]`` keyed by ``user_id``. + Locks are created lazily on first access and removed once no tasks are + waiting or holding them, preventing unbounded memory growth. + + Thread-safety note + ------------------ + All mutations to the internal registry are protected by a single + ``asyncio.Lock`` (the *registry lock*). Since this code runs on the + asyncio event loop, ``asyncio.Lock`` is sufficient — no OS-level + threading primitives are needed. + """ + + def __init__(self) -> None: + # Maps user_id -> (asyncio.Lock, active_count) + # active_count tracks how many tasks are either holding or waiting + # for the lock so we know when it's safe to clean up. + self._locks: Dict[str, asyncio.Lock] = {} + self._waiters: Dict[str, int] = {} + self._registry_lock = asyncio.Lock() + + @asynccontextmanager + async def acquire(self, user_id: str) -> AsyncIterator[None]: + """Acquire the per-user ingestion lock. + + Usage:: + + async with coordinator.acquire("user_123"): + # Only one coroutine per user_id reaches here at a time. + await do_work() + + The lock is automatically released (and cleaned up if idle) when + the ``async with`` block exits, even if an exception is raised. + """ + # ── Get-or-create the user lock ────────────────────────────── + async with self._registry_lock: + if user_id not in self._locks: + self._locks[user_id] = asyncio.Lock() + self._waiters[user_id] = 0 + self._waiters[user_id] += 1 + user_lock = self._locks[user_id] + + logger.debug("User %s: waiting for ingestion lock (waiters=%d)", user_id, self._waiters.get(user_id, 0)) + + try: + async with user_lock: + logger.debug("User %s: ingestion lock acquired", user_id) + yield + finally: + # ── Cleanup: remove the lock if nobody else is waiting ──── + async with self._registry_lock: + self._waiters[user_id] -= 1 + if self._waiters[user_id] <= 0: + self._locks.pop(user_id, None) + self._waiters.pop(user_id, None) + logger.debug("User %s: ingestion lock cleaned up", user_id) + + @property + def active_users(self) -> int: + """Return the number of users with active or pending ingestion locks.""" + return len(self._locks) diff --git a/src/api/routes/memory.py b/src/api/routes/memory.py index ab40c07..0c7d91b 100644 --- a/src/api/routes/memory.py +++ b/src/api/routes/memory.py @@ -42,6 +42,7 @@ WeaverSummary, ) from src.pipelines.retrieval import RetrievalPipeline +from src.api.ingestion_coordinator import UserIngestionCoordinator from bs4 import BeautifulSoup import json @@ -58,6 +59,7 @@ logger = logging.getLogger("xmem.api.routes.memory") _ingest_semaphore = asyncio.Semaphore(5) +_user_coordinator = UserIngestionCoordinator() router = APIRouter( prefix="/v1/memory", @@ -210,14 +212,55 @@ async def _run_ingest_payload( return data.model_dump() -async def _run_batch_ingest_payload( +async def _run_staged_batch_payload( payload: Dict[str, Any], user_id: str, ) -> Dict[str, Any]: - results = [] + pipeline = get_ingest_pipeline() + items = [] for item in payload["items"]: - results.append(await _run_ingest_payload(item, user_id)) - return {"results": results} + if hasattr(item, "model_dump"): + items.append(item.model_dump()) + elif isinstance(item, dict): + items.append(item) + else: + items.append(dict(item)) + + batch_results = await pipeline.run_staged_batch(items, user_id=user_id) + + results = [] + for result in batch_results: + data = IngestResponse( + model=_model_name(pipeline.model), + classification=_safe_classifications(result), + profile=_build_domain_result( + result.get("profile_judge"), + result.get("profile_weaver"), + ), + temporal=_build_domain_result( + result.get("temporal_judge"), + result.get("temporal_weaver"), + ), + summary=_build_domain_result( + result.get("summary_judge"), + result.get("summary_weaver"), + ), + image=_build_domain_result( + result.get("image_judge"), + result.get("image_weaver"), + ), + ) + results.append(data) + + return {"results": [r.model_dump() for r in results]} + + +async def _run_batch_ingest_payload( + payload: Dict[str, Any], + user_id: str, +) -> Dict[str, Any]: + async with _ingest_semaphore: + return await _run_staged_batch_payload(payload, user_id) async def _run_scrape_payload(payload: Dict[str, Any]) -> Dict[str, Any]: @@ -681,10 +724,11 @@ async def ingest_memory(req: IngestRequest, request: Request, user: dict = Depen payload = req.model_dump() try: - data = await asyncio.wait_for( - _run_ingest_payload(payload, user_id), - timeout=120.0, - ) + async with _user_coordinator.acquire(user_id): + data = await asyncio.wait_for( + _run_ingest_payload(payload, user_id), + timeout=120.0, + ) elapsed = round((time.perf_counter() - start) * 1000, 2) return _wrap(request, data, elapsed) @@ -800,15 +844,15 @@ async def batch_ingest_memory(req: BatchIngestRequest, request: Request, user: d user_id = _current_user_id(user) try: - results = [] - for item in req.items: - data = await asyncio.wait_for( - _run_ingest_payload(item.model_dump(), user_id), - timeout=120.0, - ) - results.append(IngestResponse(**data)) + payload = req.model_dump() + async with _user_coordinator.acquire(user_id): + async with _ingest_semaphore: + data = await asyncio.wait_for( + _run_staged_batch_payload(payload, user_id), + timeout=max(120.0, len(req.items) * 120.0), + ) - response_data = BatchIngestResponse(results=results) + response_data = BatchIngestResponse(**data) elapsed = round((time.perf_counter() - start) * 1000, 2) return _wrap(request, response_data, elapsed) diff --git a/src/pipelines/ingest.py b/src/pipelines/ingest.py index 446440a..f2b8c87 100644 --- a/src/pipelines/ingest.py +++ b/src/pipelines/ingest.py @@ -82,7 +82,7 @@ ) from src.schemas.events import EventResult from src.schemas.image import ImageResult -from src.schemas.judge import JudgeDomain, JudgeResult, OperationType +from src.schemas.judge import JudgeDomain, JudgeResult, OperationType, Operation from src.schemas.profile import ProfileResult from src.schemas.summary import SummaryResult from src.schemas.weaver import WeaverResult @@ -762,33 +762,130 @@ def _route_after_classify(self, state: IngestState) -> List[Send]: # ── Extraction nodes ────────────────────────────────────────────── + # ── Decoupled helpers ───────────────────────────────────────────── + + async def _extract_profile(self, combined_query: str) -> ProfileResult: + return await self.profiler.arun({"classifier_output": combined_query}) + + async def _judge_profile( + self, items: list, user_id: str, pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: + return await self.judge.arun_deterministic({ + "domain": "profile", + "new_items": items, + "user_id": user_id, + }, pending_ops=pending_ops) + + async def _weave_profile(self, judge_result: JudgeResult, user_id: str) -> WeaverResult: + return await self.weaver.execute( + judge_result=judge_result, + domain=JudgeDomain.PROFILE, + user_id=user_id, + ) + + async def _extract_temporal(self, combined_query: str, session_dt: str) -> EventResult: + return await self.temporal.arun({ + "classifier_output": combined_query, + "session_datetime": session_dt, + }) + + async def _judge_temporal( + self, items: list, user_id: str, pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: + return await self.judge.arun_deterministic({ + "domain": "temporal", + "new_items": items, + "user_id": user_id, + }, pending_ops=pending_ops) + + async def _weave_temporal(self, judge_result: JudgeResult, user_id: str) -> WeaverResult: + return await self.weaver.execute( + judge_result=judge_result, + domain=JudgeDomain.TEMPORAL, + user_id=user_id, + ) + + async def _extract_image(self, state: IngestState) -> ImageResult: + return await self.image_agent.arun(state) + + async def _extract_code(self, combined_query: str) -> CodeAnnotationResult: + return await self.code_agent.arun({"classifier_output": combined_query}) + + async def _judge_code( + self, items: list, user_id: str, pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: + return await self.judge.arun({ + "domain": JudgeDomain.CODE, + "new_items": items, + "user_id": user_id, + }, pending_ops=pending_ops) + + async def _weave_code(self, judge_result: JudgeResult, user_id: str) -> WeaverResult: + return await self.weaver.execute( + judge_result=judge_result, + domain=JudgeDomain.CODE, + user_id=user_id, + ) + + async def _extract_snippet(self, combined_query: str) -> SnippetExtractionResult: + return await self.snippet_agent.arun({"classifier_output": combined_query}) + + async def _judge_snippet( + self, items: list, user_id: str, pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: + return await self.judge.arun({ + "domain": JudgeDomain.SNIPPET, + "new_items": items, + "user_id": user_id, + }, pending_ops=pending_ops) + + async def _weave_snippet(self, judge_result: JudgeResult, user_id: str) -> WeaverResult: + self.weaver.snippet_vector_store = self._get_snippet_store(user_id) + return await self.weaver.execute( + judge_result=judge_result, + domain=JudgeDomain.SNIPPET, + user_id=user_id, + ) + + async def _extract_summary(self, user_query: str, agent_response: str) -> SummaryResult: + return await self.summarizer.arun({ + "user_query": user_query, + "agent_response": agent_response, + }) + + async def _judge_summary( + self, items: list, user_id: str, pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: + return await self.judge.arun({ + "domain": JudgeDomain.SUMMARY, + "new_items": items, + "user_id": user_id, + }, pending_ops=pending_ops) + + async def _weave_summary(self, judge_result: JudgeResult, user_id: str) -> WeaverResult: + return await self.weaver.execute( + judge_result=judge_result, + domain=JudgeDomain.SUMMARY, + user_id=user_id, + ) + + # ── Extraction nodes ────────────────────────────────────────────── + async def _node_extract_profile(self, state: IngestState) -> Dict[str, Any]: """Extract profile facts from the classifier query.""" queries = state.get("profile_queries", []) user_id = state.get("user_id", "default") - # Merge into a single query (safety net if classifier outputs duplicate lines) combined_query = " ".join(queries) - result = await self.profiler.arun({"classifier_output": combined_query}) + result = await self._extract_profile(combined_query) if result.is_empty: return {"status": "no_profile_facts"} - # Profile facts are already structured; exact metadata lookup avoids - # an extra judge LLM call on the hot path. items = [f.model_dump() for f in result.facts] - judge_result = await self.judge.arun_deterministic({ - "domain": "profile", - "new_items": items, - "user_id": user_id, - }) + judge_result = await self._judge_profile(items, user_id) - # Weave - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.PROFILE, - user_id=user_id, - ) + weaver_result = await self._weave_profile(judge_result, user_id) return { "profile_result": result, "profile_judge": judge_result, @@ -801,12 +898,8 @@ async def _node_extract_temporal(self, state: IngestState) -> Dict[str, Any]: user_id = state.get("user_id", "default") session_dt = state.get("session_datetime", "") - # Merge into a single query combined_query = " ".join(queries) - result = await self.temporal.arun({ - "classifier_output": combined_query, - "session_datetime": session_dt, - }) + result = await self._extract_temporal(combined_query, session_dt) if result.is_empty: return {"status": "no_temporal_event"} @@ -822,17 +915,9 @@ async def _node_extract_temporal(self, state: IngestState) -> Dict[str, Any]: "date_expression": event.date_expression or "", }) - judge_result = await self.judge.arun_deterministic({ - "domain": "temporal", - "new_items": all_items, - "user_id": user_id, - }) + judge_result = await self._judge_temporal(all_items, user_id) - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.TEMPORAL, - user_id=user_id, - ) + weaver_result = await self._weave_temporal(judge_result, user_id) return { "temporal_result": result, "temporal_judge": judge_result, @@ -843,16 +928,11 @@ async def _node_extract_image(self, state: IngestState) -> Dict[str, Any]: """Extract visual observations from the image and store them as summary.""" user_id = state.get("user_id", "default") - # ImageAgent reads classifier_output and image_url from state - result = await self.image_agent.arun(state) + result = await self._extract_image(state) if result.is_empty: return {"status": "no_image_observations"} - # Convert observations to list of dicts for Judge - # items = [obs.model_dump() for obs in result.observations] - - #converted observation of images to summary and stored as summary items = [] if result.description: items.append(f"[Image] {result.description}") @@ -863,17 +943,9 @@ async def _node_extract_image(self, state: IngestState) -> Dict[str, Any]: if not items: return {"status": "no_image_observations"} - judge_result = await self.judge.arun({ - "domain": JudgeDomain.SUMMARY, - "new_items": items, - "user_id": user_id, - }) + judge_result = await self._judge_summary(items, user_id) - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.SUMMARY, - user_id=user_id, - ) + weaver_result = await self._weave_summary(judge_result, user_id) return { "image_result": result, @@ -886,9 +958,8 @@ async def _node_extract_code(self, state: IngestState) -> Dict[str, Any]: queries = state.get("code_queries", []) user_id = state.get("user_id", "default") - # Merge into a single query combined_query = " ".join(queries) - result = await self.code_agent.arun({"classifier_output": combined_query}) + result = await self._extract_code(combined_query) if result.is_empty: return {"status": "no_code_annotations"} @@ -905,17 +976,9 @@ async def _node_extract_code(self, state: IngestState) -> Dict[str, Any]: ] all_items.append(" | ".join(parts)) - judge_result = await self.judge.arun({ - "domain": JudgeDomain.CODE, - "new_items": all_items, - "user_id": user_id, - }) + judge_result = await self._judge_code(all_items, user_id) - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.CODE, - user_id=user_id, - ) + weaver_result = await self._weave_code(judge_result, user_id) return { "code_result": result, "code_judge": judge_result, @@ -927,9 +990,8 @@ async def _node_extract_snippet(self, state: IngestState) -> Dict[str, Any]: queries = state.get("code_queries", []) user_id = state.get("user_id", "default") - # Merge into a single query combined_query = " ".join(queries) - result = await self.snippet_agent.arun({"classifier_output": combined_query}) + result = await self._extract_snippet(combined_query) if result.is_empty: return {"status": "no_snippets"} @@ -945,20 +1007,9 @@ async def _node_extract_snippet(self, state: IngestState) -> Dict[str, Any]: ] all_items.append(" | ".join(parts)) - judge_result = await self.judge.arun({ - "domain": JudgeDomain.SNIPPET, - "new_items": all_items, - "user_id": user_id, - }) - - # Bind the user-scoped snippet store before executing - self.weaver.snippet_vector_store = self._get_snippet_store(user_id) + judge_result = await self._judge_snippet(all_items, user_id) - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.SNIPPET, - user_id=user_id, - ) + weaver_result = await self._weave_snippet(judge_result, user_id) return { "snippet_result": result, "snippet_judge": judge_result, @@ -966,14 +1017,14 @@ async def _node_extract_snippet(self, state: IngestState) -> Dict[str, Any]: } async def _node_extract_summary(self, state: IngestState) -> Dict[str, Any]: - result = await self.summarizer.arun({ - "user_query": state.get("user_query", ""), - "agent_response": state.get("agent_response", ""), - }) + user_id = state.get("user_id", "default") + result = await self._extract_summary( + user_query=state.get("user_query", ""), + agent_response=state.get("agent_response", ""), + ) if result.is_empty: return {"status": "no_summary"} - # Split bullet summary into individual items items = [ line.lstrip("- •").strip() for line in result.summary.strip().splitlines() @@ -982,17 +1033,9 @@ async def _node_extract_summary(self, state: IngestState) -> Dict[str, Any]: if not items: return {"status": "no_summary_items"} - judge_result = await self.judge.arun({ - "domain": "summary", - "new_items": items, - "user_id": state.get("user_id", "default"), - }) + judge_result = await self._judge_summary(items, user_id) - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.SUMMARY, - user_id=state.get("user_id", "default"), - ) + weaver_result = await self._weave_summary(judge_result, user_id) return { "summary_result": result, "summary_judge": judge_result, @@ -1143,6 +1186,285 @@ async def _invoke_graph( } return await self.graph.ainvoke(initial_state) + async def _process_item_phase_a(self, idx: int, item: Dict[str, Any], user_id: str) -> Dict[str, Any]: + """Phase A - Classification and domain extraction concurrently for a single item.""" + user_query = item.get("user_query", "") + agent_response = item.get("agent_response", "") or "Acknowledged." + session_dt = item.get("session_datetime", "") + image_url = item.get("image_url", "") + disabled_domains = set(item.get("disabled_domains") or []) + + # Run Classifier + classifier_query = user_query + if image_url: + classifier_query += " [User has attached an image]" + + classification_result = await self.classifier.arun({ + "user_query": classifier_query, + }) + + # Collect sub-queries per domain + profile_queries = [] + temporal_queries = [] + image_queries = [] + code_queries = [] + + if classification_result and classification_result.classifications: + for c in classification_result.classifications: + if c["source"] == "profile": + profile_queries.append(c["query"]) + elif c["source"] == "event": + temporal_queries.append(c["query"]) + elif c["source"] == "image": + image_queries.append(c["query"]) + elif c["source"] == "code": + code_queries.append(c["query"]) + + words = user_query.split() + is_trivial = len(words) < 4 and not any([profile_queries, temporal_queries, code_queries, image_queries]) + + tasks = [] + task_names = [] + + if not is_trivial: + tasks.append(self._extract_summary(user_query, agent_response)) + task_names.append("summary") + + if profile_queries: + combined_profile = " ".join(profile_queries) + tasks.append(self._extract_profile(combined_profile)) + task_names.append("profile") + + if temporal_queries: + combined_temporal = " ".join(temporal_queries) + tasks.append(self._extract_temporal(combined_temporal, session_dt)) + task_names.append("temporal") + + if code_queries and not {"code", "snippet"}.issubset(disabled_domains): + is_enterprise = self.org_id != "default" + if is_enterprise and "code" not in disabled_domains: + combined_code = " ".join(code_queries) + tasks.append(self._extract_code(combined_code)) + task_names.append("code") + elif not is_enterprise and "snippet" not in disabled_domains: + combined_snippet = " ".join(code_queries) + tasks.append(self._extract_snippet(combined_snippet)) + task_names.append("snippet") + + if image_url: + if not image_queries: + image_queries.append("Analyze this image for memory-relevant details.") + combined_image = " ".join(image_queries) + image_state = { + "classifier_output": combined_image, + "image_url": image_url, + "user_id": user_id, + } + tasks.append(self._extract_image(image_state)) + task_names.append("image") + + extraction_results = await asyncio.gather(*tasks, return_exceptions=True) + + item_state = { + "user_query": user_query, + "agent_response": agent_response, + "user_id": user_id, + "session_datetime": session_dt, + "image_url": image_url, + "disabled_domains": list(disabled_domains), + "classification_result": classification_result, + "errors": [], + "status": "extracted", + } + + for name, result in zip(task_names, extraction_results): + if isinstance(result, Exception): + logger.error(f"Error during {name} extraction for batch item {idx}: {result}") + item_state["errors"].append(f"{name}_extraction_error: {str(result)}") + else: + item_state[f"{name}_result"] = result + + return {"idx": idx, "item_state": item_state} + + async def run_staged_batch( + self, + items: List[Dict[str, Any]], + user_id: str, + ) -> List[Dict[str, Any]]: + """Run batch memory ingestion using a staged parallel/sequential hybrid pipeline.""" + logger.info("=" * 60) + logger.info("RUN STAGED BATCH: %d items", len(items)) + logger.info("=" * 60) + + # Phase A: Concurrently run classification + domain extraction across all items + phase_a_tasks = [self._process_item_phase_a(idx, item, user_id) for idx, item in enumerate(items)] + phase_a_outputs = await asyncio.gather(*phase_a_tasks) + + # Phase B: Sequentially run Judge across all items with pending_ops tracking + pending_ops: List[Operation] = [] + + for phase_a_out in phase_a_outputs: + item_state = phase_a_out["item_state"] + idx = phase_a_out["idx"] + + judge_tasks = [] + judge_domains = [] + + # 1. Profile facts + profile_res = item_state.get("profile_result") + if profile_res and not profile_res.is_empty: + items_data = [f.model_dump() for f in profile_res.facts] + judge_tasks.append(self._judge_profile(items_data, user_id, pending_ops=pending_ops)) + judge_domains.append("profile") + + # 2. Temporal events + temporal_res = item_state.get("temporal_result") + if temporal_res and not temporal_res.is_empty: + items_data = [] + for event in temporal_res.events: + items_data.append({ + "date": event.date, + "event_name": event.event_name or "", + "desc": event.desc or "", + "year": event.year or "", + "time": event.time or "", + "date_expression": event.date_expression or "", + }) + judge_tasks.append(self._judge_temporal(items_data, user_id, pending_ops=pending_ops)) + judge_domains.append("temporal") + + # 3. Summary (and Image) + summary_res = item_state.get("summary_result") + image_res = item_state.get("image_result") + + summary_items = [] + if summary_res and not summary_res.is_empty: + summary_items.extend([ + line.lstrip("- •").strip() + for line in summary_res.summary.strip().splitlines() + if line.strip() and line.strip() not in ("-", "•") + ]) + + if image_res and not image_res.is_empty: + if image_res.description: + summary_items.append(f"[Image] {image_res.description}") + for obs in image_res.observations: + conf = f" ({obs.confidence})" if obs.confidence else "" + summary_items.append(f"[Image/{obs.category}] {obs.description}{conf}") + + if summary_items: + judge_tasks.append(self._judge_summary(summary_items, user_id, pending_ops=pending_ops)) + judge_domains.append("summary") + + # 4. Code annotations + code_res = item_state.get("code_judge") or item_state.get("code_result") + # Wait, let's look at the result schema. It's code_result + code_res = item_state.get("code_result") + if code_res and not code_res.is_empty: + items_data = [] + for ann in code_res.annotations: + parts = [ + ann.annotation_type.value, + ann.target_symbol or "", + ann.target_file or "", + ann.repo or "", + ann.severity.value if ann.severity else "", + ann.content, + ] + items_data.append(" | ".join(parts)) + judge_tasks.append(self._judge_code(items_data, user_id, pending_ops=pending_ops)) + judge_domains.append("code") + + # 5. Personal code snippets + snippet_res = item_state.get("snippet_result") + if snippet_res and not snippet_res.is_empty: + items_data = [] + for snip in snippet_res.snippets: + parts = [ + snip.content, + snip.code_snippet.replace("\n", "\\n") if snip.code_snippet else "", + snip.language, + snip.snippet_type.value, + ",".join(snip.tags), + ] + items_data.append(" | ".join(parts)) + judge_tasks.append(self._judge_snippet(items_data, user_id, pending_ops=pending_ops)) + judge_domains.append("snippet") + + if judge_tasks: + judge_results = await asyncio.gather(*judge_tasks, return_exceptions=True) + for domain_name, jr in zip(judge_domains, judge_results): + if isinstance(jr, Exception): + logger.error(f"Error during {domain_name} judge for batch item {idx}: {jr}") + item_state["errors"].append(f"{domain_name}_judge_error: {str(jr)}") + else: + item_state[f"{domain_name}_judge"] = jr + if domain_name == "summary" and image_res and not image_res.is_empty: + item_state["image_judge"] = jr + + if jr and jr.operations: + pending_ops.extend(jr.operations) + + # Phase C: Concurrently run Weaver to write changes in parallel across all items + weave_tasks = [] + weave_mappings = [] + + for phase_a_out in phase_a_outputs: + item_state = phase_a_out["item_state"] + idx = phase_a_out["idx"] + + # Profile + profile_judge = item_state.get("profile_judge") + if profile_judge: + weave_tasks.append(self._weave_profile(profile_judge, user_id)) + weave_mappings.append((item_state, "profile_weaver")) + + # Temporal + temporal_judge = item_state.get("temporal_judge") + if temporal_judge: + weave_tasks.append(self._weave_temporal(temporal_judge, user_id)) + weave_mappings.append((item_state, "temporal_weaver")) + + # Summary + summary_judge = item_state.get("summary_judge") + if summary_judge: + weave_tasks.append(self._weave_summary(summary_judge, user_id)) + weave_mappings.append((item_state, "summary_weaver")) + + # Image + image_judge = item_state.get("image_judge") + if image_judge and "image_result" in item_state: + weave_tasks.append(self._weave_summary(image_judge, user_id)) + weave_mappings.append((item_state, "image_weaver")) + + # Code + code_judge = item_state.get("code_judge") + if code_judge: + weave_tasks.append(self._weave_code(code_judge, user_id)) + weave_mappings.append((item_state, "code_weaver")) + + # Snippet + snippet_judge = item_state.get("snippet_judge") + if snippet_judge: + weave_tasks.append(self._weave_snippet(snippet_judge, user_id)) + weave_mappings.append((item_state, "snippet_weaver")) + + if weave_tasks: + weave_results = await asyncio.gather(*weave_tasks, return_exceptions=True) + for (item_state, key), wr in zip(weave_mappings, weave_results): + if isinstance(wr, Exception): + logger.error(f"Error during weaving for key {key}: {wr}") + item_state["errors"].append(f"{key}_error: {str(wr)}") + else: + item_state[key] = wr + + # Complete all items + for phase_a_out in phase_a_outputs: + item_state = phase_a_out["item_state"] + item_state["status"] = "completed" + + return [out["item_state"] for out in phase_a_outputs] + async def _run_high_effort( self, user_query: str, @@ -1153,13 +1475,7 @@ async def _run_high_effort( cfg: EffortConfig, disabled_domains: Optional[List[str]] = None, ) -> Dict[str, Any]: - """HIGH-effort path: chunk user_query → sequential pipeline calls → merge. - - Each chunk gets the full pipeline run independently and sequentially. - The ``agent_response`` is passed to every chunk so summary extraction - always has the full assistant context. Image is only forwarded to the - first chunk to avoid duplicate image processing. - """ + """HIGH-effort path: chunk user_query -> parallel staged staged batch run -> merge.""" chunks = chunk_text( user_query, chunk_size_tokens=cfg.chunk_size_tokens, @@ -1175,25 +1491,21 @@ async def _run_high_effort( cfg.chunk_threshold_tokens, ) - # Process every chunk through the pipeline sequentially to avoid duplicates. - # Image is only sent with chunk[0] to avoid duplicate processing. - chunk_results: List[Dict[str, Any]] = [] - for idx, chunk in enumerate(chunks): - logger.info("Processing chunk %d/%d...", idx + 1, len(chunks)) - res = await self._invoke_graph( - user_query=chunk, - agent_response=agent_response, - user_id=user_id, - session_datetime=session_datetime, - image_url=image_url if idx == 0 else "", - disabled_domains=disabled_domains, - ) - chunk_results.append(res) + batch_items = [ + { + "user_query": chunk, + "agent_response": agent_response, + "user_id": user_id, + "session_datetime": session_datetime, + "image_url": image_url if idx == 0 else "", + "disabled_domains": disabled_domains or [], + } + for idx, chunk in enumerate(chunks) + ] + + chunk_results = await self.run_staged_batch(batch_items, user_id=user_id) # ── Merge states ───────────────────────────────────────────── - # All writes (Pinecone / Neo4j) already happened inside each chunk's - # pipeline run. We merge the returned state dicts so callers get a - # sensible aggregate view. merged: Dict[str, Any] = {} all_errors: List[str] = [] @@ -1201,9 +1513,7 @@ async def _run_high_effort( # Accumulate errors from every chunk. all_errors.extend(state.get("errors") or []) - # For every key, prefer the last non-None value; this gives - # callers the final-chunk's extraction results while retaining - # earlier chunks' results when a later chunk produced nothing. + # For every key, prefer the last non-None value for key, value in state.items(): if key == "errors": continue diff --git a/test_output.txt b/test_output.txt new file mode 100644 index 0000000..77f325d Binary files /dev/null and b/test_output.txt differ diff --git a/tests/test_batch_ingest.py b/tests/test_batch_ingest.py new file mode 100644 index 0000000..7d474d1 --- /dev/null +++ b/tests/test_batch_ingest.py @@ -0,0 +1,263 @@ +import pytest +from fastapi.testclient import TestClient +from unittest.mock import AsyncMock, patch +from typing import Dict, Any + +from src.api.app import create_app +from src.api.schemas import BatchIngestRequest, IngestRequest +from src.pipelines.ingest import IngestPipeline + +@pytest.fixture +def client(): + app = create_app() + return TestClient(app) + +@pytest.fixture +def mock_ingest_pipeline(): + with patch("src.api.routes.memory.get_ingest_pipeline") as mock_get_pipeline: + from types import SimpleNamespace + mock_pipeline = AsyncMock(spec=IngestPipeline) + mock_pipeline.model = SimpleNamespace(model_name="test-model") + + # Default mock behavior + async def mock_run(*args, **kwargs): + return { + "classification_result": SimpleNamespace(classifications=["test"]), + "profile_judge": None, + "profile_weaver": None, + "temporal_judge": None, + "temporal_weaver": None, + "summary_judge": None, + "summary_weaver": None, + "image_judge": None, + "image_weaver": None, + } + + async def mock_run_staged_batch(items, user_id): + return [ + { + "classification_result": SimpleNamespace(classifications=["test"]), + "profile_judge": None, + "profile_weaver": None, + "temporal_judge": None, + "temporal_weaver": None, + "summary_judge": None, + "summary_weaver": None, + "image_judge": None, + "image_weaver": None, + } + for _ in items + ] + + mock_pipeline.run.side_effect = mock_run + mock_pipeline.run_staged_batch.side_effect = mock_run_staged_batch + mock_get_pipeline.return_value = mock_pipeline + yield mock_pipeline + +def test_batch_ingest_success(client, mock_ingest_pipeline): + """Test that multiple items can be successfully ingested in a batch.""" + payload = { + "items": [ + { + "user_query": "Hello world", + "agent_response": "Hi there", + "user_id": "test_user_1", + }, + { + "user_query": "Second message", + "agent_response": "Understood", + "user_id": "test_user_1", + } + ] + } + + # You must provide API key or mock dependency for require_api_key + # For test purposes, we assume we override the dependency or add a test key + # Let's mock require_api_key in dependencies + with patch("src.api.routes.memory.require_api_key", return_value={"username": "test_user"}): + app = client.app + from src.api.dependencies import require_api_key, enforce_rate_limit, require_ready + app.dependency_overrides[require_api_key] = lambda: {"username": "test_user"} + app.dependency_overrides[enforce_rate_limit] = lambda: True + app.dependency_overrides[require_ready] = lambda: True + + response = client.post( + "/v1/memory/batch-ingest", + json=payload, + headers={"Authorization": "Bearer test-key"} + ) + + assert response.status_code == 200, response.json() + data = response.json() + assert data["status"] == "ok", data + assert len(data["data"]["results"]) == 2, data + for item in data["data"]["results"]: + assert item["model"] == "test-model" + + +def test_coordinator_serializes_concurrent_batches(client, mock_ingest_pipeline): + """Two concurrent batch-ingest requests for the same user must not overlap. + + We verify this by checking that all 4 pipeline.run calls were made + (2 items × 2 batches) and both requests succeed. + """ + import threading + + payload = { + "items": [ + { + "user_query": "Batch message 1", + "agent_response": "Ack 1", + "user_id": "same_user", + }, + { + "user_query": "Batch message 2", + "agent_response": "Ack 2", + "user_id": "same_user", + }, + ] + } + + with patch("src.api.routes.memory.require_api_key", return_value={"username": "same_user"}): + app = client.app + from src.api.dependencies import require_api_key, enforce_rate_limit, require_ready + app.dependency_overrides[require_api_key] = lambda: {"username": "same_user"} + app.dependency_overrides[enforce_rate_limit] = lambda: True + app.dependency_overrides[require_ready] = lambda: True + + # Send two batch requests concurrently via threads + results = [None, None] + + def _send_batch(idx): + results[idx] = client.post( + "/v1/memory/batch-ingest", + json=payload, + headers={"Authorization": "Bearer test-key"}, + ) + + t1 = threading.Thread(target=_send_batch, args=(0,)) + t2 = threading.Thread(target=_send_batch, args=(1,)) + t1.start() + t2.start() + t1.join() + t2.join() + + # Both requests should succeed + for r in results: + assert r is not None + assert r.status_code == 200, r.json() + + # All 2 run_staged_batch calls (2 batches) should have been made + assert mock_ingest_pipeline.run_staged_batch.call_count == 2 + + +@pytest.mark.asyncio +async def test_run_staged_batch_overlays(): + """Verify that JudgeAgent applies simulated overlays for profile, temporal, and semantic domains when pending_ops are passed.""" + from unittest.mock import MagicMock + from src.agents.judge import JudgeAgent, JudgeDomain + from src.schemas.judge import Operation, OperationType + from src.storage.base import SearchResult + + # Mock the vector store and graph event search + mock_vector_store = MagicMock() + mock_graph_search = MagicMock() + + agent = JudgeAgent(model=MagicMock(), name="judge", system_prompt="system") + agent.vector_store = mock_vector_store + agent.graph_event_search = mock_graph_search + agent.top_k = 5 + + # 1. Test profile topic/sub-topic exact match + pending_profile_ops = [ + Operation( + type=OperationType.ADD, + content="work / company = XMem", + reason="User works at XMem", + embedding_id="pending_prof_1" + ) + ] + + mock_vector_store.search_by_metadata.return_value = [] + + res = await agent._fetch_similar( + items_strings=["work / company = XMem"], + new_items=[{"topic": "work", "sub_topic": "company", "memo": "XMem"}], + user_id="user_1", + domain=JudgeDomain.PROFILE, + pending_ops=pending_profile_ops + ) + + assert "work / company = XMem" in res + assert len(res["work / company = XMem"]) == 1 + match = res["work / company = XMem"][0] + assert match.id == "pending_prof_1" + assert match.score == 1.0 + assert match.metadata["domain"] == "profile" + + # 2. Test profile delete overlay (should clear matches) + pending_delete_profile_ops = [ + Operation( + type=OperationType.DELETE, + content="work / company = XMem", + reason="User left XMem", + embedding_id="pending_prof_1" + ) + ] + + res_del = await agent._fetch_similar( + items_strings=["work / company = XMem"], + new_items=[{"topic": "work", "sub_topic": "company", "memo": "XMem"}], + user_id="user_1", + domain=JudgeDomain.PROFILE, + pending_ops=pending_delete_profile_ops + ) + assert len(res_del["work / company = XMem"]) == 0 + + # 3. Test temporal event name match + pending_temporal_ops = [ + Operation( + type=OperationType.ADD, + content="Date: 05-22, 2026 | Event: Launch | Description: Final Release | Time: | Date expression: today", + reason="Launch event", + embedding_id="pending_temp_1" + ) + ] + + mock_graph_search.search_events_by_embedding.return_value = [] + res_temp = await agent._fetch_similar( + items_strings=["Launch event"], + new_items=[{"date": "05-22", "event_name": "Launch", "desc": "Final Release", "year": "2026", "time": "", "date_expression": "today"}], + user_id="user_1", + domain=JudgeDomain.TEMPORAL, + pending_ops=pending_temporal_ops + ) + assert len(res_temp["Launch event"]) == 1 + match_temp = res_temp["Launch event"][0] + assert match_temp.id == "pending_temp_1" + assert match_temp.score == 1.0 + + # 4. Test semantic similarity match (SequenceMatcher) + pending_summary_ops = [ + Operation( + type=OperationType.ADD, + content="Likes clean coding and unit testing", + reason="Clean code preference", + embedding_id="pending_sum_1" + ) + ] + + mock_vector_store.search_by_text = MagicMock(return_value=[]) + res_sum = await agent._fetch_similar( + items_strings=["Likes clean coding"], + new_items=[], + user_id="user_1", + domain=JudgeDomain.SUMMARY, + pending_ops=pending_summary_ops + ) + assert len(res_sum["Likes clean coding"]) == 1 + match_sum = res_sum["Likes clean coding"][0] + assert match_sum.id == "pending_sum_1" + assert match_sum.score > 0.5 + + diff --git a/tests/test_ingestion_coordinator.py b/tests/test_ingestion_coordinator.py new file mode 100644 index 0000000..2b8efdf --- /dev/null +++ b/tests/test_ingestion_coordinator.py @@ -0,0 +1,200 @@ +""" +Tests for the UserIngestionCoordinator. + +Validates per-user serialisation, cross-user parallelism, FIFO ordering, +lock cleanup, and exception safety. +""" + +import asyncio +import time + +import importlib.util +import os +import sys + +import pytest + +# Import the coordinator module directly from its file to avoid pulling in +# src.api.__init__ → src.api.app → src.config.Settings (requires env vars). +_coordinator_path = os.path.join( + os.path.dirname(__file__), os.pardir, "src", "api", "ingestion_coordinator.py" +) +_spec = importlib.util.spec_from_file_location( + "ingestion_coordinator", os.path.abspath(_coordinator_path) +) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) +UserIngestionCoordinator = _mod.UserIngestionCoordinator + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +async def _timed_task(coordinator: UserIngestionCoordinator, user_id: str, duration: float, log: list): + """Acquire the user lock, record (start, end) timestamps, and sleep.""" + async with coordinator.acquire(user_id): + start = time.monotonic() + await asyncio.sleep(duration) + end = time.monotonic() + log.append((user_id, start, end)) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_sequential_for_same_user(): + """Concurrent tasks for the same user must execute one at a time (non-overlapping).""" + coordinator = UserIngestionCoordinator() + log: list = [] + + tasks = [ + asyncio.create_task(_timed_task(coordinator, "alice", 0.05, log)) + for _ in range(5) + ] + await asyncio.gather(*tasks) + + assert len(log) == 5 + + # Sort by start time and verify no overlaps + log.sort(key=lambda x: x[1]) + for i in range(1, len(log)): + prev_end = log[i - 1][2] + curr_start = log[i][1] + assert curr_start >= prev_end - 0.001, ( + f"Task {i} started at {curr_start:.4f} before task {i-1} ended at {prev_end:.4f}" + ) + + +@pytest.mark.asyncio +async def test_parallel_for_different_users(): + """Tasks for different users should run concurrently (overlapping in time).""" + coordinator = UserIngestionCoordinator() + log: list = [] + + users = ["alice", "bob", "charlie"] + tasks = [ + asyncio.create_task(_timed_task(coordinator, user, 0.1, log)) + for user in users + ] + await asyncio.gather(*tasks) + + assert len(log) == 3 + + # All three should start roughly at the same time (within 20ms of each other) + starts = sorted(entry[1] for entry in log) + spread = starts[-1] - starts[0] + assert spread < 0.05, ( + f"Different users should run in parallel, but start-time spread was {spread:.4f}s" + ) + + +@pytest.mark.asyncio +async def test_fifo_ordering(): + """Tasks for the same user should complete in submission order (FIFO).""" + coordinator = UserIngestionCoordinator() + completion_order: list = [] + + async def _ordered_task(task_id: int): + async with coordinator.acquire("user_fifo"): + await asyncio.sleep(0.02) + completion_order.append(task_id) + + # Create tasks in order 0, 1, 2, 3, 4 + tasks = [] + for i in range(5): + tasks.append(asyncio.create_task(_ordered_task(i))) + # Small delay to ensure submission order is deterministic + await asyncio.sleep(0.005) + + await asyncio.gather(*tasks) + + assert completion_order == [0, 1, 2, 3, 4], ( + f"Expected FIFO order [0,1,2,3,4], got {completion_order}" + ) + + +@pytest.mark.asyncio +async def test_lock_cleanup(): + """After all tasks complete, internal lock dict should be empty.""" + coordinator = UserIngestionCoordinator() + + async with coordinator.acquire("cleanup_user"): + assert coordinator.active_users == 1 + + # After context exit, lock should be cleaned up + assert coordinator.active_users == 0 + assert "cleanup_user" not in coordinator._locks + assert "cleanup_user" not in coordinator._waiters + + +@pytest.mark.asyncio +async def test_exception_safety(): + """If a task raises inside the lock, the lock must still be released for subsequent tasks.""" + coordinator = UserIngestionCoordinator() + results: list = [] + + async def _failing_task(): + async with coordinator.acquire("error_user"): + raise ValueError("deliberate test error") + + async def _succeeding_task(): + async with coordinator.acquire("error_user"): + results.append("success") + + # First task fails + with pytest.raises(ValueError, match="deliberate test error"): + await _failing_task() + + # Second task should still be able to acquire the lock and succeed + await _succeeding_task() + + assert results == ["success"] + assert coordinator.active_users == 0 + + +@pytest.mark.asyncio +async def test_concurrent_same_user_does_not_deadlock(): + """Many concurrent acquires for the same user must all complete without deadlock.""" + coordinator = UserIngestionCoordinator() + counter = {"value": 0} + + async def _increment(): + async with coordinator.acquire("stress_user"): + counter["value"] += 1 + await asyncio.sleep(0.001) + + tasks = [asyncio.create_task(_increment()) for _ in range(20)] + await asyncio.gather(*tasks) + + assert counter["value"] == 20 + assert coordinator.active_users == 0 + + +@pytest.mark.asyncio +async def test_mixed_users_serialization(): + """Two users interleaving: each user's tasks are serial, but different users overlap.""" + coordinator = UserIngestionCoordinator() + log: list = [] + + # 3 tasks for alice, 3 tasks for bob — all launched concurrently + tasks = [] + for i in range(3): + tasks.append(asyncio.create_task(_timed_task(coordinator, "alice", 0.03, log))) + tasks.append(asyncio.create_task(_timed_task(coordinator, "bob", 0.03, log))) + + await asyncio.gather(*tasks) + + assert len(log) == 6 + + # Verify per-user serialisation + for user in ("alice", "bob"): + user_entries = sorted([e for e in log if e[0] == user], key=lambda x: x[1]) + for i in range(1, len(user_entries)): + prev_end = user_entries[i - 1][2] + curr_start = user_entries[i][1] + assert curr_start >= prev_end - 0.001, ( + f"{user} task {i} overlapped with task {i-1}" + ) diff --git a/xlsx.py b/xlsx.py new file mode 100644 index 0000000..bdaa05d --- /dev/null +++ b/xlsx.py @@ -0,0 +1,98 @@ +from openpyxl import Workbook +from openpyxl.styles import Font, PatternFill, Alignment, Border, Side +from openpyxl.utils import get_column_letter + +wb = Workbook() +wb.remove(wb.active) + +NAVY = "1B3A6B" +GOLD = "C9A84C" +WHITE = "FFFFFF" +LGRAY = "F4F6FA" +MGRAY = "8A99B0" +DGRAY = "2D3748" +TEAL = "0D9488" +RED = "C0392B" +GREEN_TX = "007A3D" +BLUE_TX = "0000FF" + +INR = '₹#,##0;(₹#,##0);"-"' +PCT = '0.0%;(0.0%);"-"' + +def side(style="thin", color="D1D5DB"): + return Side(style=style, color=color) + +def border(): + s = side() + return Border(top=s, bottom=s, left=s, right=s) + +def cell(ws, row, col, val=None, bold=False, bg=WHITE, fg=BLACK_TX, + size=10, align="left", fmt=None, italic=False): + c = ws.cell(row=row, column=col, value=val) + c.font = Font(name="Arial", bold=bold, color=fg, size=size, italic=italic) + c.fill = PatternFill("solid", start_color=bg) + c.alignment = Alignment(horizontal=align, vertical="center") + c.border = border() + if fmt: + c.number_format = fmt + return c + +def merge(ws, r1, c1, r2, c2, val=None, bold=False, bg=WHITE, fg=BLACK_TX): + ws.merge_cells(start_row=r1, start_column=c1, end_row=r2, end_column=c2) + c = ws.cell(row=r1, column=c1, value=val) + c.font = Font(name="Arial", bold=bold, color=fg, size=11) + c.fill = PatternFill("solid", start_color=bg) + c.alignment = Alignment(horizontal="center", vertical="center") + return c + +# ================= MASTER SHEET ================= +ws = wb.create_sheet("Master Budget") + +headers = ["Item", "Amount", "%"] +for i, h in enumerate(headers, 1): + cell(ws, 1, i, h, bold=True, bg=NAVY, fg=WHITE) + +data = [ + ("1st Prize", 25000), + ("2nd Prize", 12000), + ("3rd Prize", 6000), + ("Marketing", 78000), + ("Team", 25000), + ("Misc", 25000), +] + +row = 2 +for label, amount in data: + cell(ws, row, 1, label) + cell(ws, row, 2, amount, fg=BLUE_TX, align="center", fmt=INR) + row += 1 + +# Total +cell(ws, row, 1, "TOTAL", bold=True) +cell(ws, row, 2, f"=SUM(B2:B{row-1})", bold=True, fmt=INR) + +# Percent column +for r in range(2, row): + cell(ws, r, 3, f"=B{r}/B{row}", fmt=PCT, align="center") + +# ================= PRIZE SHEET ================= +ws2 = wb.create_sheet("Prize Structure") + +headers = ["Rank", "Cash Prize"] +for i, h in enumerate(headers, 1): + cell(ws2, 1, i, h, bold=True, bg=NAVY, fg=WHITE) + +prizes = [ + ("1st", 25000), + ("2nd", 12000), + ("3rd", 6000), +] + +for i, (rank, val) in enumerate(prizes, start=2): + cell(ws2, i, 1, rank) + cell(ws2, i, 2, val, fmt=INR, align="center") + +# ================= SAVE FILE ================= +wb.save("DSA_Budget.xlsx") + +print("✅ Excel file saved as DSA_Budget.xlsx in your folder") \ No newline at end of file