Skip to content
4 changes: 3 additions & 1 deletion src/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from src.api.routes.enterprise import router as enterprise_router
from src.api.routes.health import router as health_router
from src.api.routes.memory import router as memory_router
from src.api.routes.memory import search_router as memory_search_router
from src.api.routes.memory import scrape_router as memory_scrape_router
from src.api.routes.memory import v2_router as memory_v2_router
from src.api.routes.memory import v2_scrape_router as memory_v2_scrape_router
Expand Down Expand Up @@ -157,6 +158,7 @@ async def lifespan(app: FastAPI):
# ── Routes ────────────────────────────────────────────────────────
app.include_router(health_router)
app.include_router(memory_scrape_router)
app.include_router(memory_search_router)
app.include_router(memory_router)
app.include_router(memory_v2_scrape_router)
app.include_router(memory_v2_router)
Expand Down Expand Up @@ -192,7 +194,7 @@ async def prometheus_metrics():
async def sentry_debug():
"""Intentionally raises an error to verify Sentry is capturing exceptions."""
try:
division_by_zero = 1 / 0
1 / 0
except ZeroDivisionError as exc:
from src.config.monitoring import capture_exception
capture_exception(exc)
Expand Down
3 changes: 2 additions & 1 deletion src/api/routes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .health import router as health_router
from .memory import router as memory_router
from .memory import search_router as memory_search_router

__all__ = ["health_router", "memory_router"]
__all__ = ["health_router", "memory_router", "memory_search_router"]
73 changes: 59 additions & 14 deletions src/api/routes/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import asyncio
import logging
import math
import threading
import time
from typing import Any, Dict, List
Expand Down Expand Up @@ -83,6 +84,11 @@
dependencies=[Depends(enforce_rate_limit)],
)

# Router on which the ``/search`` endpoint is registered. No prefix is set
# here; every route attached to it runs the ``require_ready`` and
# ``enforce_rate_limit`` dependencies before the handler executes.
search_router = APIRouter(
    tags=["memory"],
    dependencies=[Depends(require_ready), Depends(enforce_rate_limit)],
)


# Helpers
def _model_name(model: Any) -> str:
Expand Down Expand Up @@ -187,6 +193,7 @@ async def _run_ingest_payload(
image_url=payload.get("image_url", ""),
effort_level=payload.get("effort_level", "low"),
)
_invalidate_profile_cache(user_id)
data = IngestResponse(
model=_model_name(pipeline.model),
classification=_safe_classifications(result),
Expand Down Expand Up @@ -233,6 +240,14 @@ def _schedule_job(job: Dict[str, Any], handler) -> None:
asyncio.create_task(run_job(get_default_job_store(), job["job_id"], handler))


def _safe_score(score: Any) -> float:
try:
value = float(score)
except (TypeError, ValueError):
return 0.0
return value if math.isfinite(value) else 0.0


def _detect_chat_provider(*urls: str) -> str:
for url in urls:
lowered = (url or "").lower()
Expand Down Expand Up @@ -750,6 +765,13 @@ def _safe_classifications(result: Dict[str, Any]) -> list:
return []


def _invalidate_profile_cache(user_id: str) -> None:
    """Ask the retrieval pipeline to drop its cached profile for ``user_id``.

    This is best-effort: any ``Exception`` raised while obtaining the
    pipeline or invalidating the cache is logged as a warning and not
    propagated to the caller.
    """
    try:
        retrieval = get_retrieval_pipeline()
        retrieval.invalidate_profile_cache(user_id)
    except Exception as err:
        logger.warning("Failed to invalidate profile cache for user=%s: %s", user_id, err)


async def _read_user_job(job_id: str, user_id: str) -> Dict[str, Any] | None:
job = await asyncio.to_thread(get_default_job_store().get, job_id)
if not job:
Expand Down Expand Up @@ -885,13 +907,14 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D
sources=[
SourceRecord(
domain=s.domain, content=s.content,
score=round(s.score, 3), metadata=s.metadata,
score=round(_safe_score(s.score), 3), metadata=s.metadata,
)
for s in result.sources
],
confidence=result.confidence,
)
elapsed = round((time.perf_counter() - start) * 1000, 2)
pipeline.record_latency("agentic", elapsed)
return _wrap(request, data, elapsed)

except Exception as exc:
Expand All @@ -901,10 +924,15 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D


# POST /v1/memory/search
@search_router.post(
"/search",
response_model=APIResponse,
summary="Raw semantic search across memory domains with optional answer synthesis",
)
@router.post(
"/search",
response_model=APIResponse,
summary="Raw semantic search across memory domains (no LLM answer)",
summary="Raw semantic search across memory domains with optional answer synthesis",
)
async def search_memory(req: SearchRequest, request: Request, user: dict = Depends(require_api_key)):
start = time.perf_counter()
Expand All @@ -914,17 +942,34 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen
user_id = user.get("username") or user.get("name") or user["id"]

try:
all_results: List[SourceRecord] = []

if "profile" in req.domains:
all_results.extend(_search_profile(pipeline, user_id))
if "temporal" in req.domains:
all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k))
if "summary" in req.domains:
all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k))
all_results = await pipeline.search_raw(
query=req.query,
user_id=user_id,
domains=req.domains,
top_k=req.top_k,
)
answer = ""
if req.answer:
answer = await pipeline.answer_from_sources(req.query, all_results)

data = SearchResponse(results=all_results, total=len(all_results))
elapsed = round((time.perf_counter() - start) * 1000, 2)
pipeline.record_latency("answer" if req.answer else "raw", elapsed)
data = SearchResponse(
results=[
SourceRecord(
domain=s.domain,
content=s.content,
score=round(_safe_score(s.score), 3),
metadata=s.metadata,
)
for s in all_results
],
total=len(all_results),
answer=answer,
model=_model_name(pipeline.model) if req.answer else "",
confidence=min(1.0, len(all_results) * 0.2) if answer else 0.0,
latency=pipeline.get_latency_snapshot(),
)
return _wrap(request, data, elapsed)

except Exception as exc:
Expand All @@ -938,7 +983,7 @@ def _search_profile(pipeline: RetrievalPipeline, user_id: str) -> List[SourceRec
raw = pipeline.vector_store.search_by_metadata(
filters={"user_id": user_id, "domain": "profile"}, top_k=100,
)
return [SourceRecord(domain="profile", content=r.content, score=r.score, metadata=r.metadata) for r in raw]
return [SourceRecord(domain="profile", content=r.content, score=_safe_score(r.score), metadata=r.metadata) for r in raw]
except Exception as exc:
logger.warning("Profile search error: %s", exc)
return []
Expand All @@ -965,7 +1010,7 @@ def _search_temporal(pipeline: RetrievalPipeline, query: str, user_id: str, top_
parts.append(f"Time: {ev['time']}")
results.append(SourceRecord(
domain="temporal", content=" | ".join(parts),
score=ev.get("similarity_score", 0.0), metadata=ev,
score=_safe_score(ev.get("similarity_score", 0.0)), metadata=ev,
))
return results
except Exception as exc:
Expand All @@ -980,7 +1025,7 @@ async def _search_summary(pipeline: RetrievalPipeline, query: str, user_id: str,
filters={"user_id": user_id, "domain": "summary"},
)
return [
SourceRecord(domain="summary", content=r.content, score=r.score, metadata={"id": r.id, **r.metadata})
SourceRecord(domain="summary", content=r.content, score=_safe_score(r.score), metadata={"id": r.id, **r.metadata})
for r in raw
]
except Exception as exc:
Expand Down
13 changes: 10 additions & 3 deletions src/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from __future__ import annotations

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -159,15 +158,19 @@ class SearchRequest(BaseModel):
..., min_length=1, max_length=256, pattern=r"^[\w.\-@]+$",
)
domains: List[str] = Field(
default=["profile", "temporal", "summary"],
default=["profile", "temporal", "summary", "snippet", "code"],
description="Which memory domains to search",
)
top_k: int = Field(default=10, ge=1, le=100)
answer: bool = Field(
default=False,
description="When true, synthesize an answer from the raw hits without agentic tool selection.",
)

@field_validator("domains")
@classmethod
def validate_domains(cls, v: List[str]) -> List[str]:
allowed = {"profile", "temporal", "summary"}
allowed = {"profile", "temporal", "summary", "snippet", "code"}
for d in v:
if d not in allowed:
raise ValueError(f"Invalid domain '{d}'. Allowed: {allowed}")
Expand All @@ -177,6 +180,10 @@ def validate_domains(cls, v: List[str]) -> List[str]:
# Response payload for the memory search endpoint.
class SearchResponse(BaseModel):
    # Matching records; defaults to a fresh empty list per instance.
    results: List[SourceRecord] = Field(default_factory=list)
    # Number of records reported for the search; defaults to 0.
    total: int = 0
    # Synthesized answer text; stays empty unless an answer was produced.
    answer: str = ""
    # Name of the model associated with the answer; empty by default.
    model: str = ""
    # Confidence value attached to the answer; defaults to 0.0.
    confidence: float = 0.0
    # Latency statistics: a mapping of string keys to dicts of numeric
    # values. NOTE(review): exact keys come from the pipeline's latency
    # snapshot — confirm the schema against that implementation.
    latency: Dict[str, Dict[str, float | int]] = Field(default_factory=dict)


# ── Scrape (extract from shared chat links) ────────────────────────────────
Expand Down
Loading
Loading