diff --git a/backend/app/api/admin_routes/knowledge_base/chunk/models.py b/backend/app/api/admin_routes/knowledge_base/chunk/models.py index f439ace47..79b0ec726 100644 --- a/backend/app/api/admin_routes/knowledge_base/chunk/models.py +++ b/backend/app/api/admin_routes/knowledge_base/chunk/models.py @@ -1,10 +1,11 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field from app.rag.retrievers.chunk.schema import VectorSearchRetrieverConfig class KBChunkRetrievalConfig(BaseModel): vector_search: VectorSearchRetrieverConfig + score_threshold: float = Field(gt=0, lt=1, default=0.3) # TODO: add fulltext and knowledge graph search config diff --git a/backend/app/api/admin_routes/knowledge_base/chunk/routes.py b/backend/app/api/admin_routes/knowledge_base/chunk/routes.py index afbb6fd9e..7199718f9 100644 --- a/backend/app/api/admin_routes/knowledge_base/chunk/routes.py +++ b/backend/app/api/admin_routes/knowledge_base/chunk/routes.py @@ -1,13 +1,13 @@ import logging -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from app.api.deps import SessionDep, CurrentSuperuserDep from app.rag.retrievers.chunk.simple_retriever import ( ChunkSimpleRetriever, ) from app.rag.retrievers.chunk.schema import ChunksRetrievalResult -from app.exceptions import InternalServerError, KBNotFound +from app.exceptions import InternalServerError from .models import KBRetrieveChunksRequest router = APIRouter() @@ -31,7 +31,7 @@ def retrieve_chunks( return retriever.retrieve_chunks( request.query, ) - except KBNotFound as e: + except HTTPException as e: raise e except Exception as e: logger.exception(e) diff --git a/backend/app/api/admin_routes/knowledge_base/graph/entity/routes.py b/backend/app/api/admin_routes/knowledge_base/graph/entity/routes.py new file mode 100644 index 000000000..e078264ec --- /dev/null +++ b/backend/app/api/admin_routes/knowledge_base/graph/entity/routes.py @@ -0,0 +1,164 @@ +import logging + +from typing import List, Annotated +from fastapi import APIRouter, HTTPException, Depends, Query +from fastapi_pagination import Params, Page + +from app.api.deps import SessionDep +from app.exceptions import InternalServerError +from app.models import EntityPublic, EntityType +from app.rag.indices.knowledge_graph.schema import ( + EntityCreate, + EntityFilters, + SynopsisEntityCreate, + EntityUpdate, +) +from app.rag.knowledge_base.index_store import ( + get_kb_graph_editor, + get_kb_tidb_graph_store, +) +from app.rag.retrievers.knowledge_graph.schema import ( + RetrievedEntity, + RetrievedKnowledgeGraph, +) +from app.repositories import knowledge_base_repo + +router = APIRouter( + prefix="/admin/knowledge_bases/{kb_id}/graph/entities", + tags=["knowledge_base/graph/entity"], +) +logger = logging.getLogger(__name__) + + +@router.get("/", response_model=Page[EntityPublic]) +def list_entities( + db_session: SessionDep, + kb_id: int, + filters: Annotated[EntityFilters, Query()] = EntityFilters(), + params: Params = Depends(), +): + try: + kb = knowledge_base_repo.must_get(db_session, kb_id) + graph_editor = get_kb_graph_editor(db_session, kb) + return graph_editor.query_entities(filters, params) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.post("/", response_model=EntityPublic) +def create_entity(session: SessionDep, kb_id: int, create: EntityCreate): + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_editor = get_kb_graph_editor(session, kb) + return graph_editor.create_entity(create) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.post("/synopsis", response_model=EntityPublic) +def create_synopsis_entity( + session: SessionDep, kb_id: int, create: SynopsisEntityCreate +): + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_editor = get_kb_graph_editor(session, kb) + return graph_editor.create_synopsis_entity(create) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.get( + "/search", +) +def search_similar_entities( + session: SessionDep, + kb_id: int, + query: str, + top_k: int = 10, + nprobe: int = 10, + entity_type: EntityType = EntityType.original, + similarity_threshold: float = 0.4, +) -> List[RetrievedEntity]: + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_store = get_kb_tidb_graph_store(session, kb) + return graph_store.retrieve_entities( + query=query, + top_k=top_k, + nprobe=nprobe, + entity_type=entity_type, + similarity_threshold=similarity_threshold, + ) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.get("/{entity_id}", response_model=EntityPublic) +def get_entity(session: SessionDep, kb_id: int, entity_id: int): + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_editor = get_kb_graph_editor(session, kb) + return graph_editor.must_get_entity(entity_id) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.put("/{entity_id}", response_model=EntityPublic) +def update_entity( + session: SessionDep, kb_id: int, entity_id: int, update: EntityUpdate +): + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_editor = get_kb_graph_editor(session, kb) + return graph_editor.update_entity(entity_id, update) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.delete("/{entity_id}") +def delete_entity(session: SessionDep, kb_id: int, entity_id: int): + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_editor = get_kb_graph_editor(session, kb) + graph_editor.delete_entity(entity_id) + return { + "detail": "success", + } + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.get("/{entity_id}/subgraph") +def get_entity_subgraph( + session: SessionDep, kb_id: int, entity_id: int +) -> RetrievedKnowledgeGraph: + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_editor = get_kb_graph_editor(session, kb) + return graph_editor.get_entity_subgraph(entity_id) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() diff --git a/backend/app/api/admin_routes/knowledge_base/graph/knowledge/routes.py b/backend/app/api/admin_routes/knowledge_base/graph/knowledge/routes.py index f01027a28..da7e99bea 100644 --- a/backend/app/api/admin_routes/knowledge_base/graph/knowledge/routes.py +++ b/backend/app/api/admin_routes/knowledge_base/graph/knowledge/routes.py @@ -1,27 +1,32 @@ from fastapi import HTTPException +from pydantic import BaseModel from starlette import status from app.api.admin_routes.knowledge_base.graph.models import ( KnowledgeRequest, KnowledgeNeighborRequest, - KnowledgeChunkRequest, ) from app.api.admin_routes.knowledge_base.graph.routes import router, logger from app.api.deps import SessionDep from app.exceptions import KBNotFound, InternalServerError -from app.rag.knowledge_base.index_store import get_kb_tidb_graph_store +from app.rag.knowledge_base.index_store import ( + get_kb_tidb_graph_store, + get_kb_graph_editor, +) from app.repositories import knowledge_base_repo # Experimental interface -@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge") -def retrieve_knowledge(session: SessionDep, kb_id: int, request: KnowledgeRequest): +@router.post("/knowledge", deprecated=True) +def legacy_retrieve_knowledge( + session: SessionDep, kb_id: int, request: KnowledgeRequest +): try: kb = knowledge_base_repo.must_get(session, kb_id) graph_store = get_kb_tidb_graph_store(session, kb) - data = graph_store.retrieve_graph_data( + data = graph_store.retrieve_subgraph_by_similar( request.query, request.top_k, request.similarity_threshold, @@ -37,8 +42,8 @@ def retrieve_knowledge(session: SessionDep, kb_id: int, request: KnowledgeReques raise InternalServerError() -@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge/neighbors") -def retrieve_knowledge_neighbors( +@router.post("/knowledge/neighbors", deprecated=True) +def legacy_retrieve_knowledge_neighbors( session: SessionDep, kb_id: int, request: KnowledgeNeighborRequest ): try: @@ -59,22 +64,24 @@ def retrieve_knowledge_neighbors( raise InternalServerError() -@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge/chunks") -def retrieve_knowledge_chunks( +class KnowledgeChunkRequest(BaseModel): + pass + + +@router.post("/knowledge/chunks", deprecated=True) +def legacy_retrieve_knowledge_chunks( session: SessionDep, kb_id: int, request: KnowledgeChunkRequest ): try: kb = knowledge_base_repo.must_get(session, kb_id) - graph_store = get_kb_tidb_graph_store(session, kb) - data = graph_store.get_chunks_by_relationships(request.relationships_ids) + graph_editor = get_kb_graph_editor(session, kb) + data = graph_editor.batch_get_chunks_by_relationships(request.relationships_ids) if not data: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="No chunks found for the given relationships", ) return data - except KBNotFound as e: - raise e except HTTPException as e: raise e except Exception as e: diff --git a/backend/app/api/admin_routes/knowledge_base/graph/models.py b/backend/app/api/admin_routes/knowledge_base/graph/models.py index cbe7e13da..046189ecf 100644 --- a/backend/app/api/admin_routes/knowledge_base/graph/models.py +++ b/backend/app/api/admin_routes/knowledge_base/graph/models.py @@ -1,31 +1,11 @@ from typing import List, Optional -from pydantic import BaseModel, model_validator +from pydantic import BaseModel from app.rag.retrievers.knowledge_graph.schema import ( KnowledgeGraphRetrieverConfig, ) -class SynopsisEntityCreate(BaseModel): - name: str - description: str - topic: str - meta: dict - entities: List[int] - - @model_validator(mode="after") - def validate_entities(self): - if len(self.entities) == 0: - raise ValueError("Entities list should not be empty") - return self - - -class EntityUpdate(BaseModel): - name: Optional[str] = None - description: Optional[str] = None - meta: Optional[dict] = None - - class RelationshipUpdate(BaseModel): description: Optional[str] = None meta: Optional[dict] = None @@ -70,5 +50,5 @@ class KnowledgeNeighborRequest(BaseModel): similarity_threshold: float = 0.55 -class KnowledgeChunkRequest(BaseModel): - relationships_ids: List[int] +class RelationshipBatchRequest(BaseModel): + relationship_ids: List[int] diff --git a/backend/app/api/admin_routes/knowledge_base/graph/relationship/routes.py b/backend/app/api/admin_routes/knowledge_base/graph/relationship/routes.py new file mode 100644 index 000000000..ccdb7f0da --- /dev/null +++ b/backend/app/api/admin_routes/knowledge_base/graph/relationship/routes.py @@ -0,0 +1,121 @@ +import logging + +from typing import Annotated, List +from fastapi import APIRouter, HTTPException, Query, Depends +from fastapi_pagination import Params, Page + +from app.api.admin_routes.knowledge_base.graph.models import ( + RelationshipUpdate, + RelationshipBatchRequest, +) +from app.api.deps import SessionDep +from app.exceptions import InternalServerError +from app.models import RelationshipPublic, Chunk as DBChunk +from app.rag.indices.knowledge_graph.schema import ( + RelationshipCreate, + RelationshipFilters, +) +from app.rag.knowledge_base.index_store import get_kb_graph_editor +from app.repositories import knowledge_base_repo + +router = APIRouter( + prefix="/admin/knowledge_bases/{kb_id}/graph/relationships", + tags=["knowledge_base/graph/relationship"], +) +logger = logging.getLogger(__name__) + + +@router.get("/", response_model=Page[RelationshipPublic]) +def query_relationships( + db_session: SessionDep, + kb_id: int, + filters: Annotated[RelationshipFilters, Query()] = RelationshipFilters(), + params: Params = Depends(), +): + try: + kb = knowledge_base_repo.must_get(db_session, kb_id) + graph_editor = get_kb_graph_editor(db_session, kb) + return graph_editor.query_relationships(filters, params) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.post("/", response_model=RelationshipPublic) +def create_relationship( + session: SessionDep, + kb_id: int, + create: RelationshipCreate, +): + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_editor = get_kb_graph_editor(session, kb) + return graph_editor.create_relationship(create) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.post("/chunks", response_model=List[DBChunk]) +def batch_get_chunks_by_relationships( + session: SessionDep, + kb_id: int, + request: RelationshipBatchRequest, +): + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_editor = get_kb_graph_editor(session, kb) + return graph_editor.batch_get_chunks_by_relationships(request.relationship_ids) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.get("/{relationship_id}", response_model=RelationshipPublic) +def get_relationship(session: SessionDep, kb_id: int, relationship_id: int): + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_editor = get_kb_graph_editor(session, kb) + return graph_editor.must_get_relationship(relationship_id) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.put("/{relationship_id}", response_model=RelationshipPublic) +def update_relationship( + session: SessionDep, + kb_id: int, + relationship_id: int, + update: RelationshipUpdate, +): + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_editor = get_kb_graph_editor(session, kb) + return graph_editor.update_relationship(relationship_id, update) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() + + +@router.delete("/{relationship_id}") +def delete_relationship(session: SessionDep, kb_id: int, relationship_id: int): + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_editor = get_kb_graph_editor(session, kb) + return graph_editor.delete_relationship(relationship_id) + except HTTPException as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() diff --git a/backend/app/api/admin_routes/knowledge_base/graph/routes.py b/backend/app/api/admin_routes/knowledge_base/graph/routes.py index 6f6c51d1e..2b9e44649 100644 --- a/backend/app/api/admin_routes/knowledge_base/graph/routes.py +++ b/backend/app/api/admin_routes/knowledge_base/graph/routes.py @@ -1,26 +1,15 @@ import logging -from typing import List - -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, HTTPException from app.api.admin_routes.knowledge_base.graph.models import ( - SynopsisEntityCreate, - EntityUpdate, - RelationshipUpdate, KBRetrieveKnowledgeGraphRequest, GraphSearchRequest, + KnowledgeRequest, ) from app.api.deps import SessionDep -from app.exceptions import KBNotFound, InternalServerError -from app.models import ( - EntityPublic, - RelationshipPublic, -) -from app.rag.retrievers.knowledge_graph.schema import ( - KnowledgeGraphRetrievalResult, -) +from app.exceptions import InternalServerError +from app.rag.retrievers.knowledge_graph.schema import KnowledgeGraphRetrievalResult from app.rag.knowledge_base.index_store import ( - get_kb_tidb_graph_editor, get_kb_tidb_graph_store, ) from app.rag.retrievers.knowledge_graph.simple_retriever import ( @@ -28,179 +17,15 @@ ) from app.repositories import knowledge_base_repo -router = APIRouter() -logger = logging.getLogger(__name__) - - -@router.get( - "/admin/knowledge_bases/{kb_id}/graph/entities/search", - response_model=List[EntityPublic], -) -def search_similar_entities( - session: SessionDep, kb_id: int, query: str, top_k: int = 10 -): - try: - kb = knowledge_base_repo.must_get(session, kb_id) - tidb_graph_editor = get_kb_tidb_graph_editor(session, kb) - return tidb_graph_editor.search_similar_entities(session, query, top_k) - except KBNotFound as e: - raise e - except Exception as e: - # TODO: throw InternalServerError - raise e - - -@router.post( - "/admin/knowledge_bases/{kb_id}/graph/entities/synopsis", - response_model=EntityPublic, -) -def create_synopsis_entity( - session: SessionDep, kb_id: int, request: SynopsisEntityCreate -): - try: - kb = knowledge_base_repo.must_get(session, kb_id) - tidb_graph_editor = get_kb_tidb_graph_editor(session, kb) - return tidb_graph_editor.create_synopsis_entity( - session, - request.name, - request.description, - request.topic, - request.meta, - request.entities, - ) - except KBNotFound as e: - raise e - except Exception as e: - # TODO: throw InternalServerError - raise e - - -@router.get( - "/admin/knowledge_bases/{kb_id}/graph/entities/{entity_id}", - response_model=EntityPublic, -) -def get_entity(session: SessionDep, kb_id: int, entity_id: int): - try: - kb = knowledge_base_repo.must_get(session, kb_id) - tidb_graph_editor = get_kb_tidb_graph_editor(session, kb) - entity = tidb_graph_editor.get_entity(session, entity_id) - if not entity: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Entity not found", - ) - return entity - except KBNotFound as e: - raise e - except Exception as e: - # TODO: throw InternalServerError - raise e - - -@router.put( - "/admin/knowledge_bases/{kb_id}/graph/entities/{entity_id}", - response_model=EntityPublic, -) -def update_entity( - session: SessionDep, kb_id: int, entity_id: int, entity_update: EntityUpdate -): - try: - kb = knowledge_base_repo.must_get(session, kb_id) - tidb_graph_editor = get_kb_tidb_graph_editor(session, kb) - old_entity = tidb_graph_editor.get_entity(session, entity_id) - if old_entity is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Entity not found", - ) - entity = tidb_graph_editor.update_entity( - session, old_entity, entity_update.model_dump() - ) - return entity - except KBNotFound as e: - raise e - except Exception as e: - # TODO: throw InternalServerError - raise e - - -@router.get("/admin/knowledge_bases/{kb_id}/graph/entities/{entity_id}/subgraph") -def get_entity_subgraph(session: SessionDep, kb_id: int, entity_id: int) -> dict: - try: - kb = knowledge_base_repo.must_get(session, kb_id) - tidb_graph_editor = get_kb_tidb_graph_editor(session, kb) - entity = tidb_graph_editor.get_entity(session, entity_id) - if entity is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Entity not found", - ) - relationships, entities = tidb_graph_editor.get_entity_subgraph(session, entity) - return { - "relationships": relationships, - "entities": entities, - } - except KBNotFound as e: - raise e - except Exception as e: - logger.exception(e) - raise InternalServerError() - - -@router.get( - "/admin/knowledge_bases/{kb_id}/graph/relationships/{relationship_id}", - response_model=RelationshipPublic, +router = APIRouter( + prefix="/admin/knowledge_bases/{kb_id}/graph", + tags=["knowledge_base/graph"], ) -def get_relationship(session: SessionDep, kb_id: int, relationship_id: int): - try: - kb = knowledge_base_repo.must_get(session, kb_id) - tidb_graph_editor = get_kb_tidb_graph_editor(session, kb) - relationship = tidb_graph_editor.get_relationship(session, relationship_id) - if relationship is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Relationship not found", - ) - return relationship - except KBNotFound as e: - raise e - except Exception as e: - # TODO: throw InternalServerError - raise e - - -@router.put( - "/admin/knowledge_bases/{kb_id}/graph/relationships/{relationship_id}", - response_model=RelationshipPublic, -) -def update_relationship( - session: SessionDep, - kb_id: int, - relationship_id: int, - relationship_update: RelationshipUpdate, -): - try: - kb = knowledge_base_repo.must_get(session, kb_id) - tidb_graph_editor = get_kb_tidb_graph_editor(session, kb) - old_relationship = tidb_graph_editor.get_relationship(session, relationship_id) - if old_relationship is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Relationship not found", - ) - relationship = tidb_graph_editor.update_relationship( - session, old_relationship, relationship_update.model_dump() - ) - return relationship - except KBNotFound as e: - raise e - except Exception as e: - # TODO: throw InternalServerError - raise e +logger = logging.getLogger(__name__) -@router.post("/admin/knowledge_bases/{kb_id}/graph/retrieve") -def retrieve_kb_knowledge_graph( +@router.post("/retrieve") +def retrieve_knowledge_graph( db_session: SessionDep, kb_id: int, request: KBRetrieveKnowledgeGraphRequest ) -> KnowledgeGraphRetrievalResult: try: @@ -214,14 +39,34 @@ def retrieve_kb_knowledge_graph( entities=knowledge_graph.entities, relationships=knowledge_graph.relationships, ) - except KBNotFound as e: + except HTTPException as e: raise e except Exception as e: - # TODO: throw InternalServerError + logger.exception(e) + raise InternalServerError() + + +@router.post("/knowledge", deprecated=True) +def retrieve_knowledge(session: SessionDep, kb_id: int, request: KnowledgeRequest): + try: + kb = knowledge_base_repo.must_get(session, kb_id) + graph_store = get_kb_tidb_graph_store(session, kb) + return graph_store.retrieve_subgraph_by_similar( + request.query, + request.top_k, + request.similarity_threshold, + ) + except HTTPException as e: raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() -@router.post("/admin/knowledge_bases/{kb_id}/graph/search", deprecated=True) +# Legacy + + +@router.post("/search", deprecated=True) def legacy_search_graph(session: SessionDep, kb_id: int, request: GraphSearchRequest): try: kb = knowledge_base_repo.must_get(session, kb_id) @@ -238,8 +83,8 @@ def legacy_search_graph(session: SessionDep, kb_id: int, request: GraphSearchReq "entities": entities, "relationships": relationships, } - except KBNotFound as e: + except HTTPException as e: raise e except Exception as e: - # TODO: throw InternalServerError - raise e + logger.exception(e) + raise InternalServerError() diff --git a/backend/app/api/main.py b/backend/app/api/main.py index ff432c908..fc1078d92 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -13,6 +13,12 @@ from app.api.admin_routes.knowledge_base.graph.routes import ( router as admin_kb_graph_router, ) +from app.api.admin_routes.knowledge_base.graph.entity.routes import ( + router as admin_kb_graph_entity_router, +) +from app.api.admin_routes.knowledge_base.graph.relationship.routes import ( + router as admin_kb_graph_relationship_router, +) from app.api.admin_routes.knowledge_base.graph.knowledge.routes import ( router as admin_kb_graph_knowledge_router, ) @@ -71,6 +77,12 @@ api_router.include_router(admin_upload.router, tags=["admin/upload"]) api_router.include_router(admin_knowledge_base_router, tags=["admin/knowledge_base"]) api_router.include_router(admin_kb_graph_router, tags=["admin/knowledge_base/graph"]) +api_router.include_router( + admin_kb_graph_entity_router, tags=["admin/knowledge_base/graph/entity"] +) +api_router.include_router( + admin_kb_graph_relationship_router, tags=["admin/knowledge_base/graph/relationship"] +) api_router.include_router( admin_kb_graph_knowledge_router, tags=["admin/knowledge_base/graph/knowledge"] ) diff --git a/backend/app/core/db.py b/backend/app/core/db.py index b8718df01..66b4c8b7f 100644 --- a/backend/app/core/db.py +++ b/backend/app/core/db.py @@ -9,12 +9,12 @@ from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel.ext.asyncio.session import AsyncSession -from app.core.config import settings, Environment +from app.core.config import settings -if settings.ENVIRONMENT == Environment.LOCAL: - logging.basicConfig() - logger = logging.getLogger("sqlalchemy.engine") - logger.setLevel(logging.DEBUG) +# if settings.ENVIRONMENT == Environment.LOCAL: +logging.basicConfig() +logger = logging.getLogger("sqlalchemy.engine") +logger.setLevel(logging.DEBUG) # TiDB Serverless clusters have a limitation: if there are no active connections for 5 minutes, # they will shut down, which closes all connections, so we need to recycle the connections diff --git a/backend/app/models/staff_action_log.py b/backend/app/models/staff_action_log.py index 3120287df..5cf27cb8e 100644 --- a/backend/app/models/staff_action_log.py +++ b/backend/app/models/staff_action_log.py @@ -8,6 +8,7 @@ class StaffActionLog(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) action: str action_time: datetime = Field(sa_column=Column(DateTime, server_default=func.now())) + # TODO: Add knowledge base ID. target_type: str target_id: int before: Dict = Field(default_factory=dict, sa_column=Column(JSON)) diff --git a/backend/app/rag/chat/chat_service.py b/backend/app/rag/chat/chat_service.py index e2a1b045d..9c0ab5992 100644 --- a/backend/app/rag/chat/chat_service.py +++ b/backend/app/rag/chat/chat_service.py @@ -121,14 +121,14 @@ def get_graph_data_from_chat_message( if "version" not in graph_data: kb = engine_config.get_knowledge_bases(db_session)[0] graph_store = get_kb_tidb_graph_store(db_session, kb) - return graph_store.get_subgraph_by_relationship_ids(graph_data["relationships"]) + return graph_store.list_relationships_by_ids(graph_data["relationships"]) # Stored Knowledge Graph -> Retrieved Knowledge Graph stored_kg = StoredKnowledgeGraph.model_validate(graph_data) if stored_kg.knowledge_base_id is not None: kb = knowledge_base_repo.must_get(db_session, stored_kg.knowledge_base_id) graph_store = get_kb_tidb_graph_store(db_session, kb) - retrieved_kg = graph_store.get_subgraph_by_relationship_ids( + retrieved_kg = graph_store.list_relationships_by_ids( ids=stored_kg.relationships, query=stored_kg.query ) return retrieved_kg @@ -150,7 +150,7 @@ def get_graph_data_from_chat_message( if kg_store is None: continue relationship_ids = stored_subgraph.relationships - subgraph = kg_store.get_subgraph_by_relationship_ids( + subgraph = kg_store.list_relationships_by_ids( ids=relationship_ids, query=stored_kg.query, ) diff --git a/backend/app/rag/indices/knowledge_graph/base.py b/backend/app/rag/indices/knowledge_graph/base.py index 419de07ca..7575ebcb9 100644 --- a/backend/app/rag/indices/knowledge_graph/base.py +++ b/backend/app/rag/indices/knowledge_graph/base.py @@ -3,15 +3,15 @@ from typing import Any, Dict, List, Optional, Sequence from llama_index.core.data_structs import IndexLPG -from llama_index.core.callbacks import CallbackManager from llama_index.core.indices.base import BaseIndex from llama_index.core.storage.docstore.types import RefDocInfo -from llama_index.core.storage.storage_context import StorageContext -from llama_index.core.schema import BaseNode, TransformComponent +from llama_index.core.schema import BaseNode import llama_index.core.instrumentation as instrument -from app.rag.indices.knowledge_graph.extractor import SimpleGraphExtractor -from app.rag.indices.knowledge_graph.graph_store import KnowledgeGraphStore +from sqlmodel import Session +from app.rag.indices.knowledge_graph.extractor import Extractor +from app.rag.indices.knowledge_graph.graph_store import TiDBGraphStore +from app.rag.indices.knowledge_graph.schema import AIEntity, EntityCreate logger = logging.getLogger(__name__) @@ -19,22 +19,11 @@ class KnowledgeGraphIndex(BaseIndex[IndexLPG]): - """An index for a property graph. + """An index for a knowledge graph. Args: - nodes (Optional[Sequence[BaseNode]]): - A list of nodes to insert into the index. dspy_lm (dspy.BaseLLM): The language model of dspy to use for extracting triplets. - callback_manager (Optional[CallbackManager]): - The callback manager to use. - transformations (Optional[List[TransformComponent]]): - A list of transformations to apply to the nodes before inserting them into the index. - These are applied prior to the `kg_extractors`. - storage_context (Optional[StorageContext]): - The storage context to use. - show_progress (bool): - Whether to show progress bars for transformations. Defaults to `False`. """ index_struct_cls = IndexLPG @@ -42,72 +31,110 @@ class KnowledgeGraphIndex(BaseIndex[IndexLPG]): def __init__( self, dspy_lm: dspy.LM, - kg_store: KnowledgeGraphStore, - nodes: Optional[Sequence[BaseNode]] = None, - # parent class params - callback_manager: Optional[CallbackManager] = None, + kg_store: TiDBGraphStore, **kwargs: Any, ) -> None: self._dspy_lm = dspy_lm self._kg_store = kg_store + self._kg_extractor = Extractor(dspy_lm=self._dspy_lm) super().__init__( - nodes=nodes, - callback_manager=callback_manager, **kwargs, ) @classmethod def from_existing( - cls: "KnowledgeGraphIndex", + cls, dspy_lm: dspy.LM, - kg_store: KnowledgeGraphStore, - # parent class params - callback_manager: Optional[CallbackManager] = None, - transformations: Optional[List[TransformComponent]] = None, - storage_context: Optional[StorageContext] = None, - show_progress: bool = False, + kg_store: TiDBGraphStore, **kwargs: Any, ) -> "KnowledgeGraphIndex": return cls( dspy_lm=dspy_lm, kg_store=kg_store, - nodes=[], # no nodes to insert - callback_manager=callback_manager, - transformations=transformations, - storage_context=storage_context, - show_progress=show_progress, **kwargs, ) - def _insert_nodes(self, nodes: Sequence[BaseNode]): + def insert_nodes(self, db_session: Session, nodes: Sequence[BaseNode]): """Insert nodes to the index struct.""" if len(nodes) == 0: return nodes - extractor = SimpleGraphExtractor(dspy_lm=self._dspy_lm) for node in nodes: - entities_df, rel_df = extractor.extract( - text=node.get_content(), - node=node, + self._inert_node(db_session, node) + + def _inert_node(self, db_session: Session, node: BaseNode): + node_id = node.node_id + logger.info("Extracting entities and relationships for node %s", node_id) + + knowledge_graph = self._kg_extractor.forward(text=node.get_content()) + if knowledge_graph.entities is None or knowledge_graph.relationships is None: + logger.warning( + f"Entities or relationships of node {node_id} are empty, not need to insert to index." + ) + return + + if self._kg_store.exists_chunk_relationships(node_id): + logger.info( + f"Node #{node_id} already exists in the relationship table, skip." + ) + return + + for extracted_entity in knowledge_graph.entities: + self._kg_store.find_or_create_entity( + EntityCreate( + name=extracted_entity.name, + description=extracted_entity.description, + meta=extracted_entity.meta, + ), + commit=False, ) - self._kg_store.save(node.node_id, entities_df, rel_df) + + for r in knowledge_graph.relationships: + source_entity = self._kg_store.find_or_create_entity( + EntityCreate( + name=r.source_entity, + description=r.source_entity_description, + ), + commit=False, + ) + target_entity = self._kg_store.find_or_create_entity( + EntityCreate( + name=r.target_entity, description=r.target_entity_description + ), + commit=False, + ) + self._kg_store.create_relationship( + source_entity, + target_entity, + r.relationship_desc, + metadata=node.metadata, + commit=False, + ) + + def _try_merge_entities(self, entities: List[AIEntity]) -> AIEntity: + logger.info(f"Trying to merge entities: {entities[0].name}") + try: + with dspy.settings.context(lm=self._dspy_lm): + pred = self.merge_entities_prog(entities=entities) + return pred.merged_entity + except Exception as e: + logger.error(f"Failed to merge entities: {e}", exc_info=True) + return None def _build_index_from_nodes(self, nodes: Optional[Sequence[BaseNode]]) -> IndexLPG: """Build index from nodes.""" - nodes = self._insert_nodes(nodes or []) + nodes = self.insert_nodes(nodes or []) return IndexLPG() def as_retriever(self, **kwargs: Any): """Return a retriever for the index.""" - # Our retriever params is more complex than the base retriever, - # so we can't use the base retriever. raise NotImplementedError( "Retriever not implemented for KnowledgeGraphIndex, use `retrieve_with_weight` instead." ) def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None: """Index-specific logic for inserting nodes to the index struct.""" - self._insert_nodes(nodes) + self.insert_nodes(nodes) def ref_doc_info(self) -> Dict[str, RefDocInfo]: """Retrieve a dict mapping of ingested documents and their nodes+metadata.""" diff --git a/backend/app/rag/indices/knowledge_graph/extractor.py b/backend/app/rag/indices/knowledge_graph/extractor.py index 3105b9909..43de594d7 100644 --- a/backend/app/rag/indices/knowledge_graph/extractor.py +++ b/backend/app/rag/indices/knowledge_graph/extractor.py @@ -7,9 +7,9 @@ from llama_index.core.schema import BaseNode from app.rag.indices.knowledge_graph.schema import ( - Entity, - Relationship, - KnowledgeGraph, + AIEntity, + AIRelationship, + AIKnowledgeGraph, EntityCovariateInput, EntityCovariateOutput, ) @@ -52,7 +52,7 @@ class ExtractGraphTriplet(dspy.Signature): text = dspy.InputField( desc="a paragraph of text to extract entities and relationships to form a knowledge graph" ) - knowledge: KnowledgeGraph = dspy.OutputField( + knowledge: AIKnowledgeGraph = dspy.OutputField( desc="Graph representation of the knowledge extracted from the text." ) @@ -118,7 +118,7 @@ def get_llm_output_config(self): "response_mime_type": "application/json", } - def forward(self, text): + def forward(self, text: str) -> AIKnowledgeGraph: with dspy.settings.context(lm=self.dspy_lm): pred_graph = self.prog_graph( text=text, @@ -166,8 +166,8 @@ def extract(self, text: str, node: BaseNode): def _to_df( self, - entities: list[Entity], - relationships: list[Relationship], + entities: list[AIEntity], + relationships: list[AIRelationship], extra_meta: Mapping[str, str], ): # Create lists to store dictionaries for entities and relationships diff --git a/backend/app/rag/indices/knowledge_graph/graph_store/__init__.py b/backend/app/rag/indices/knowledge_graph/graph_store/__init__.py index 9fcdea577..d3d70dde3 100644 --- a/backend/app/rag/indices/knowledge_graph/graph_store/__init__.py +++ b/backend/app/rag/indices/knowledge_graph/graph_store/__init__.py @@ -1,9 +1,7 @@ from .tidb_graph_store import TiDBGraphStore from .tidb_graph_editor import TiDBGraphEditor -from .tidb_graph_store import KnowledgeGraphStore __all__ = [ "TiDBGraphStore", "TiDBGraphEditor", - "KnowledgeGraphStore", ] diff --git a/backend/app/rag/indices/knowledge_graph/graph_store/helpers.py b/backend/app/rag/indices/knowledge_graph/graph_store/helpers.py index 52fd2ffde..43b9d3050 100644 --- a/backend/app/rag/indices/knowledge_graph/graph_store/helpers.py +++ b/backend/app/rag/indices/knowledge_graph/graph_store/helpers.py @@ -6,7 +6,7 @@ # The configuration for the weight coefficient # format: ((min_weight, max_weight), coefficient) -DEFAULT_WEIGHT_COEFFICIENT_CONFIG = [ +DEFAULT_WEIGHT_COEFFICIENTS = [ ((0, 100), 0.01), ((100, 1000), 0.001), ((1000, 10000), 0.0001), @@ -54,13 +54,13 @@ def calculate_relationship_score( in_degree: int, out_degree: int, alpha: float, - weight_coefficient_config: List[ + weight_coefficients: List[ Tuple[Tuple[int, int], float] - ] = DEFAULT_WEIGHT_COEFFICIENT_CONFIG, + ] = DEFAULT_WEIGHT_COEFFICIENTS, degree_coefficient: float = DEFAULT_DEGREE_COEFFICIENT, with_degree: bool = False, ) -> float: - weighted_score = get_weight_score(weight, weight_coefficient_config) + weighted_score = get_weight_score(weight, weight_coefficients) degree_score = 0 if with_degree: degree_score = get_degree_score(in_degree, out_degree, degree_coefficient) @@ -99,7 +99,7 @@ def get_entity_metadata_embedding( def get_relationship_description_embedding( source_entity_name: str, - source_entity_description, + source_entity_description: str, target_entity_name: str, target_entity_description: str, relationship_desc: str, diff --git a/backend/app/rag/indices/knowledge_graph/graph_store/schema.py b/backend/app/rag/indices/knowledge_graph/graph_store/schema.py deleted file mode 100644 index be4637545..000000000 --- a/backend/app/rag/indices/knowledge_graph/graph_store/schema.py +++ /dev/null @@ -1,25 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict, Optional, Tuple - -from sqlmodel import Session - - -class KnowledgeGraphStore(ABC): - @abstractmethod - def save(self, entities_df, relationships_df) -> None: - """Upsert entities and relationships to the graph store.""" - pass - - @abstractmethod - def retrieve_with_weight( - self, - query: str, - embedding: list, - depth: int = 2, - include_meta: bool = False, - with_degree: bool = False, - relationship_meta_filters: Dict = {}, - session: Optional[Session] = None, - ) -> Tuple[list, list, list]: - """Retrieve nodes and relationships with weights.""" - pass diff --git a/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_editor.py b/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_editor.py index 22de2f37d..41222aac9 100644 --- a/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_editor.py +++ b/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_editor.py @@ -1,222 +1,187 @@ -from typing import Optional, Tuple, List, Type - -from llama_index.core.embeddings import resolve_embed_model -from llama_index.core.embeddings.utils import EmbedType -from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingModelType -from sqlmodel import Session, select, SQLModel -from sqlalchemy.orm import joinedload -from sqlalchemy.orm.attributes import flag_modified +from typing import Optional, Type, List +from fastapi import HTTPException +from fastapi_pagination import Params, Page +from sqlmodel import Session, SQLModel + from app.models import EntityType -from app.rag.indices.knowledge_graph.schema import Relationship as RelationshipAIModel -from app.rag.indices.knowledge_graph.graph_store import TiDBGraphStore -from app.rag.indices.knowledge_graph.graph_store.helpers import ( - get_entity_description_embedding, - get_relationship_description_embedding, - get_entity_metadata_embedding, - get_query_embedding, +from app.rag.indices.knowledge_graph.schema import ( + EntityCreate, + SynopsisEntityCreate, + EntityUpdate, + EntityFilters, + RelationshipCreate, + RelationshipUpdate, + RelationshipFilters, ) +from app.rag.indices.knowledge_graph.graph_store import TiDBGraphStore +from app.rag.retrievers.knowledge_graph.schema import RetrievedKnowledgeGraph from app.staff_action import create_staff_action_log -# TODO: CRUD operations should move to TiDBGraphStore class TiDBGraphEditor: - _entity_db_model: Type[SQLModel] - _relationship_db_model: Type[SQLModel] - def __init__( self, - knowledge_base_id: int, - entity_db_model: Type[SQLModel], - relationship_db_model: Type[SQLModel], - embed_model: Optional[EmbedType] = None, + db_session: Session, + graph_store: TiDBGraphStore, ): - self.knowledge_base_id = knowledge_base_id - self._entity_db_model = entity_db_model - self._relationship_db_model = relationship_db_model - - if embed_model: - self._embed_model = resolve_embed_model(embed_model) - else: - self._embed_model = OpenAIEmbedding( - model=OpenAIEmbeddingModelType.TEXT_EMBED_3_SMALL - ) + self._db_session = db_session + self._graph_store = graph_store - def get_entity(self, session: Session, entity_id: int) -> Optional[SQLModel]: - return session.get(self._entity_db_model, entity_id) - - def update_entity( - self, session: Session, entity: SQLModel, new_entity: dict - ) -> SQLModel: - old_entity_dict = entity.screenshot() - for key, value in new_entity.items(): - if value is not None: - setattr(entity, key, value) - flag_modified(entity, key) - entity.description_vec = get_entity_description_embedding( - entity.name, entity.description, self._embed_model - ) - entity.meta_vec = get_entity_metadata_embedding(entity.meta, self._embed_model) - for relationship in session.exec( - select(self._relationship_db_model) - .options( - joinedload(self._relationship_db_model.source_entity), - joinedload(self._relationship_db_model.target_entity), - ) - .where( - (self._relationship_db_model.source_entity_id == entity.id) - | (self._relationship_db_model.target_entity_id == entity.id) - ) - ): - relationship.description_vec = get_relationship_description_embedding( - relationship.source_entity.name, - relationship.source_entity.description, - relationship.target_entity.name, - relationship.target_entity.description, - relationship.description, - self._embed_model, - ) - session.add(relationship) - session.commit() - session.refresh(entity) - new_entity_dict = entity.screenshot() + # Entities. + + def query_entities( + self, + filters: Optional[EntityFilters] = EntityFilters(), + params: Params = Params(), + ) -> Page[SQLModel]: + return self._graph_store.fetch_entities_page(filters, params) + + def create_entity(self, create: EntityCreate) -> SQLModel: + entity = self._graph_store.create_entity(create, commit=True) create_staff_action_log( - session, "update", "entity", entity.id, old_entity_dict, new_entity_dict + self._db_session, + "create_original_entity", + "entity", + entity.id, + {}, + entity.screenshot(), + commit=True, ) return entity - def get_entity_subgraph( - self, session: Session, entity: SQLModel - ) -> Tuple[list, list]: - """ - Get the subgraph of an entity, including all related relationships and entities. - """ - relationships_queryset = session.exec( - select(self._relationship_db_model) - .options( - joinedload(self._relationship_db_model.source_entity), - joinedload(self._relationship_db_model.target_entity), + def create_synopsis_entity(self, create: SynopsisEntityCreate) -> SQLModel: + synopsis_entity = self._graph_store.create_entity(create, commit=False) + + # Create relationships between synopsis entity and related entities. + related_entities = self._graph_store.list_entities( + EntityFilters(entity_ids=create.entities) + ) + for related_entity in related_entities: + self._graph_store.create_relationship( + source_entity=synopsis_entity, + target_entity=related_entity, + description=f"{related_entity.name} is a part of synopsis entity (name={synopsis_entity.name}, topic={create.topic})", + metadata={"relationship_type": EntityType.synopsis.value}, + commit=False, ) - .where( - (self._relationship_db_model.source_entity_id == entity.id) - | (self._relationship_db_model.target_entity_id == entity.id) + self._db_session.commit() + + create_staff_action_log( + self._db_session, + "create_synopsis_entity", + "entity", + synopsis_entity.id, + {}, + synopsis_entity.screenshot(), + commit=True, + ) + return synopsis_entity + + def must_get_entity(self, entity_id: int) -> Optional[Type[SQLModel]]: + entity = self._graph_store.get_entity_by_id(entity_id) + if entity is None: + raise HTTPException( + status_code=404, detail=f"Entity #{entity_id} is not found" ) + return entity + + def update_entity(self, entity_id: int, update: EntityUpdate) -> Type[SQLModel]: + old_entity = self.must_get_entity(entity_id) + old_entity_dict = old_entity.screenshot() + new_entity = self._graph_store.update_entity(old_entity, update, commit=True) + new_entity_dict = new_entity.screenshot() + create_staff_action_log( + self._db_session, + "update", + "entity", + entity_id, + old_entity_dict, + new_entity_dict, + ) + return new_entity + + def delete_entity(self, entity_id: int) -> Optional[Type[SQLModel]]: + old_entity = self.must_get_entity(entity_id) + old_entity_dict = old_entity.screenshot() + self._graph_store.delete_entity(old_entity, commit=True) + create_staff_action_log( + self._db_session, + "delete", + "entity", + entity_id, + old_entity_dict, + {}, + commit=True, ) - relationships = [] - entities = [] - entities_set = set() - for relationship in relationships_queryset: - entities_set.add(relationship.source_entity) - entities_set.add(relationship.target_entity) - relationships.append(relationship.screenshot()) + return old_entity - for entity in entities_set: - entities.append(entity.screenshot()) + def get_entity_subgraph(self, entity_id: int) -> RetrievedKnowledgeGraph: + entity = self.must_get_entity(entity_id) + return self._graph_store.list_relationships_by_connected_entity(entity.id) - return relationships, entities + # Relationships. + + def must_get_relationship(self, relationship_id: int) -> Optional[Type[SQLModel]]: + entity = self._graph_store.get_relationship_by_id(relationship_id) + if entity is None: + raise HTTPException( + status_code=404, detail=f"Relationship #{relationship_id} not found" + ) + return entity - def get_relationship( - self, session: Session, relationship_id: int - ) -> Optional[SQLModel]: - return session.get(self._relationship_db_model, relationship_id) + def create_relationship(self, create: RelationshipCreate) -> SQLModel: + source_entity = self.must_get_entity(create.source_entity_id) + target_entity = self.must_get_entity(create.target_entity_id) + new_relationship = self._graph_store.create_relationship( + source_entity=source_entity, + target_entity=target_entity, + description=create.description, + metadata=create.metadata, + commit=True, + ) + new_relationship_dict = new_relationship.screenshot() + create_staff_action_log( + self._db_session, + "create", + "relationship", + new_relationship.id, + {}, + new_relationship_dict, + commit=True, + ) + return new_relationship def update_relationship( - self, session: Session, relationship: SQLModel, new_relationship: dict - ) -> SQLModel: - old_relationship_dict = relationship.screenshot() - for key, value in new_relationship.items(): - if value is not None: - setattr(relationship, key, value) - flag_modified(relationship, key) - relationship.description_vec = get_relationship_description_embedding( - relationship.source_entity.name, - relationship.source_entity.description, - relationship.target_entity.name, - relationship.target_entity.description, - relationship.description, - self._embed_model, + self, relationship_id: int, update: RelationshipUpdate + ) -> Type[SQLModel]: + old_relationship = self.must_get_relationship(relationship_id) + old_relationship_dict = old_relationship.screenshot() + new_relationship = self._graph_store.update_relationship( + old_relationship, update, commit=True ) - session.commit() - session.refresh(relationship) - new_relationship_dict = relationship.screenshot() - # FIXME: some error when create staff action log + new_relationship_dict = new_relationship.screenshot() create_staff_action_log( - session, + self._db_session, "update", "relationship", - relationship.id, + old_relationship.id, old_relationship_dict, new_relationship_dict, ) - return relationship - - def search_similar_entities( - self, session: Session, query: str, top_k: int = 10 - ) -> list: - embedding = get_query_embedding(query, self._embed_model) - return session.exec( - select(self._entity_db_model) - .where(self._entity_db_model.entity_type == EntityType.original) - .order_by(self._entity_db_model.description_vec.cosine_distance(embedding)) - .limit(top_k) - ).all() - - def create_synopsis_entity( + return new_relationship + + def query_relationships( self, - session: Session, - name: str, - description: str, - topic: str, - meta: dict, - related_entities_ids: List[int], - ) -> SQLModel: - # with session.begin(): - synopsis_entity = self._entity_db_model( - name=name, - description=description, - description_vec=get_entity_description_embedding( - name, description, self._embed_model - ), - meta=meta, - meta_vec=get_entity_metadata_embedding(meta, self._embed_model), - entity_type=EntityType.synopsis, - synopsis_info={ - "entities": related_entities_ids, - "topic": topic, - }, - ) - session.add(synopsis_entity) - graph_store = TiDBGraphStore( - knowledge_base=self.knowledge_base_id, - dspy_lm=None, - session=session, - embed_model=self._embed_model, - entity_db_model=self._entity_db_model, - relationship_db_model=self._relationship_db_model, - ) - for related_entity in session.exec( - select(self._entity_db_model).where( - self._entity_db_model.id.in_(related_entities_ids) - ) - ).all(): - graph_store.create_relationship( - synopsis_entity, - related_entity, - RelationshipAIModel( - source_entity=synopsis_entity.name, - target_entity=related_entity.name, - relationship_desc=f"{related_entity.name} is a part of synopsis entity (name={synopsis_entity.name}, topic={topic})", - ), - {"relationship_type": EntityType.synopsis.value}, - commit=False, - ) - session.commit() - create_staff_action_log( - session, - "create_synopsis_entity", - "entity", - synopsis_entity.id, - {}, - synopsis_entity.screenshot(), - commit=False, - ) - return synopsis_entity + filters: Optional[RelationshipFilters] = RelationshipFilters(), + params: Params = Params(), + ) -> Page[Type[SQLModel]]: + return self._graph_store.fetch_relationships_page(filters, params) + + def delete_relationship(self, relationship_id: int) -> None: + relationship = self.must_get_relationship(relationship_id) + self._graph_store.delete_relationship(relationship, commit=True) + + def batch_get_chunks_by_relationships( + self, relationship_ids: List[int] + ) -> List[Type[SQLModel]]: + return self._graph_store.batch_get_chunks_by_relationships(relationship_ids) diff --git a/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_store.py b/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_store.py index 58989e9d9..a94e37312 100644 --- a/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_store.py +++ b/backend/app/rag/indices/knowledge_graph/graph_store/tidb_graph_store.py @@ -1,21 +1,23 @@ -import dspy import logging import numpy as np import tidb_vector -from dspy.functional import TypedPredictor +import sqlalchemy from deepdiff import DeepDiff -from typing import List, Optional, Tuple, Dict, Set, Type, Any -from collections import defaultdict +from typing import List, Optional, Tuple, Dict, Set, Type, Sequence +from fastapi_pagination import Params, Page +from fastapi_pagination.ext.sqlmodel import paginate from llama_index.core.embeddings.utils import EmbedType, resolve_embed_model -from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingModelType -import sqlalchemy -from sqlmodel import Session, asc, func, select, text, SQLModel from sqlalchemy.orm import aliased, defer, joinedload +from sqlalchemy.orm.attributes import flag_modified +from sqlmodel import Session, asc, func, select, text, SQLModel, or_ from tidb_vector.sqlalchemy import VectorAdaptor -from sqlalchemy import or_, desc +from sqlalchemy import desc from app.core.db import engine +from app.models.chunk import get_kb_chunk_model +from app.models.entity import get_kb_entity_model +from app.models.relationship import get_kb_relationship_model from app.rag.indices.knowledge_graph.graph_store.helpers import ( get_entity_description_embedding, get_relationship_description_embedding, @@ -23,15 +25,18 @@ get_entity_metadata_embedding, get_query_embedding, DEFAULT_RANGE_SEARCH_CONFIG, - DEFAULT_WEIGHT_COEFFICIENT_CONFIG, + DEFAULT_WEIGHT_COEFFICIENTS, DEFAULT_DEGREE_COEFFICIENT, ) -from app.rag.indices.knowledge_graph.graph_store.schema import KnowledgeGraphStore from app.rag.indices.knowledge_graph.schema import ( - Entity, - Relationship, - SynopsisEntity, + EntityCreate, + EntityDegree, + EntityFilters, + EntityUpdate, + RelationshipUpdate, + RelationshipFilters, ) +from app.rag.knowledge_base.config import get_kb_embed_model from app.rag.retrievers.knowledge_graph.schema import ( RetrievedEntity, RetrievedRelationship, @@ -40,7 +45,6 @@ from app.models import ( Entity as DBEntity, Relationship as DBRelationship, - Chunk as DBChunk, KnowledgeBase, ) from app.models import EntityType @@ -52,87 +56,61 @@ def cosine_distance(v1, v2): return 1 - np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)) -class MergeEntities(dspy.Signature): - """As a knowledge expert assistant specialized in database technologies, evaluate the two provided entities. These entities have been pre-analyzed and have same name but different descriptions and metadata. - Please carefully review the detailed descriptions and metadata for both entities to determine if they genuinely represent the same concept or object(entity). - If you conclude that the entities are identical, merge the descriptions and metadata fields of the two entities into a single consolidated entity. - If the entities are distinct despite their same name that may be due to different contexts or perspectives, do not merge the entities and return none as the merged entity. - - Considerations: Ensure your decision is based on a comprehensive analysis of the content and context provided within the entity descriptions and metadata. - Please only response in JSON Format. - """ - - entities: List[Entity] = dspy.InputField( - desc="List of entities identified from previous analysis." - ) - merged_entity: Optional[Entity] = dspy.OutputField( - desc="Merged entity with consolidated descriptions and metadata." - ) - - -class MergeEntitiesProgram(dspy.Module): - def __init__(self): - self.prog = TypedPredictor(MergeEntities) - - def forward(self, entities: List[Entity]): - if len(entities) != 2: - raise ValueError("The input should contain exactly two entities") - - pred = self.prog(entities=entities) - return pred - - -class TiDBGraphStore(KnowledgeGraphStore): +class TiDBGraphStore: def __init__( self, + db_session: Session, knowledge_base: KnowledgeBase, - dspy_lm: dspy.LM, - session: Optional[Session] = None, - embed_model: Optional[EmbedType] = None, - description_similarity_threshold=0.9, - entity_db_model: Type[SQLModel] = DBEntity, - relationship_db_model: Type[SQLModel] = DBRelationship, - chunk_db_model: Type[SQLModel] = DBChunk, + embed_model: EmbedType, + entity_model: Type[SQLModel], + relationship_model: Type[SQLModel], + chunk_model: Type[SQLModel], + entity_distance_threshold: Optional[float] = 0.1, ): - self.knowledge_base = knowledge_base - self._session = session - self._owns_session = session is None - if self._session is None: - self._session = Session(engine) - self._dspy_lm = dspy_lm - - if embed_model: - self._embed_model = resolve_embed_model(embed_model) - else: - self._embed_model = OpenAIEmbedding( - model=OpenAIEmbeddingModelType.TEXT_EMBED_3_SMALL - ) - - self.merge_entities_prog = MergeEntitiesProgram() - self.description_cosine_distance_threshold = ( - 1 - description_similarity_threshold + self._db_session = db_session + self._embed_model = resolve_embed_model(embed_model) + self._entity_db_model = entity_model + self._relationship_db_model = relationship_model + self._chunk_db_model = chunk_model + self._knowledge_base = knowledge_base + self._entity_distance_threshold = entity_distance_threshold + + @classmethod + def from_knowledge_base( + cls, knowledge_base: KnowledgeBase, db_session: Session + ) -> "TiDBGraphStore": + embed_model = get_kb_embed_model(db_session, knowledge_base) + entity_db_model = get_kb_entity_model(knowledge_base) + relationship_db_model = get_kb_relationship_model(knowledge_base) + chunk_db_model = get_kb_chunk_model(knowledge_base) + return cls( + db_session=db_session, + knowledge_base=knowledge_base, + embed_model=embed_model, + entity_model=entity_db_model, + relationship_model=relationship_db_model, + chunk_model=chunk_db_model, ) - self._entity_model = entity_db_model - self._relationship_model = relationship_db_model - self._chunk_model = chunk_db_model + + # Schema Operations def ensure_table_schema(self) -> None: inspector = sqlalchemy.inspect(engine) existed_table_names = inspector.get_table_names() - entities_table_name = self._entity_model.__tablename__ - relationships_table_name = self._relationship_model.__tablename__ + entities_table_name = self._entity_db_model.__tablename__ + relationships_table_name = self._relationship_db_model.__tablename__ if entities_table_name not in existed_table_names: - self._entity_model.metadata.create_all( - engine, tables=[self._entity_model.__table__] + self._entity_db_model.metadata.create_all( + engine, tables=[self._entity_db_model.__table__] ) # Add HNSW index to accelerate ann queries. VectorAdaptor(engine).create_vector_index( - self._entity_model.description_vec, tidb_vector.DistanceMetric.COSINE + self._entity_db_model.description_vec, tidb_vector.DistanceMetric.COSINE ) VectorAdaptor(engine).create_vector_index( - self._entity_model.meta_vec, tidb_vector.DistanceMetric.COSINE + self._entity_db_model.meta_vec, tidb_vector.DistanceMetric.COSINE ) logger.info( @@ -144,13 +122,13 @@ def ensure_table_schema(self) -> None: ) if relationships_table_name not in existed_table_names: - self._relationship_model.metadata.create_all( - engine, tables=[self._relationship_model.__table__] + self._relationship_db_model.metadata.create_all( + engine, tables=[self._relationship_db_model.__table__] ) # Add HNSW index to accelerate ann queries. VectorAdaptor(engine).create_vector_index( - self._relationship_model.description_vec, + self._relationship_db_model.description_vec, tidb_vector.DistanceMetric.COSINE, ) @@ -165,12 +143,12 @@ def ensure_table_schema(self) -> None: def drop_table_schema(self) -> None: inspector = sqlalchemy.inspect(engine) existed_table_names = inspector.get_table_names() - relationships_table_name = self._relationship_model.__tablename__ - entities_table_name = self._entity_model.__tablename__ + relationships_table_name = self._relationship_db_model.__tablename__ + entities_table_name = self._entity_db_model.__tablename__ if relationships_table_name in existed_table_names: - self._relationship_model.metadata.drop_all( - engine, tables=[self._relationship_model.__table__] + self._relationship_db_model.metadata.drop_all( + engine, tables=[self._relationship_db_model.__table__] ) logger.info( f"Relationships table <{relationships_table_name}> has been dropped successfully." @@ -181,8 +159,8 @@ def drop_table_schema(self) -> None: ) if entities_table_name in existed_table_names: - self._entity_model.metadata.drop_all( - engine, tables=[self._entity_model.__table__] + self._entity_db_model.metadata.drop_all( + engine, tables=[self._entity_db_model.__table__] ) logger.info( f"Entities table <{entities_table_name}> has been dropped successfully." @@ -192,151 +170,587 @@ def drop_table_schema(self) -> None: f"Entities table <{entities_table_name}> is not existed, not action to do." ) - def close_session(self) -> None: - # Always call this method is necessary to make sure the session is closed - if self._owns_session: - self._session.close() + # Entity Basic Operations - def save(self, chunk_id, entities_df, relationships_df): - if entities_df.empty or relationships_df.empty: - logger.info( - "Entities or relationships are empty, skip saving to the database" + def fetch_entities_page( + self, + filters: Optional[EntityFilters] = EntityFilters(), + params: Params = Params(), + ) -> Page[SQLModel]: + stmt = self._build_entities_query(filters) + return paginate(self._db_session, stmt, params) + + def list_entities( + self, filters: Optional[EntityFilters] = EntityFilters() + ) -> Sequence[SQLModel]: + stmt = self._build_entities_query(filters) + return self._db_session.exec(stmt).all() + + def _build_entities_query(self, filters: EntityFilters): + stmt = select(self._entity_db_model) + if filters.entity_type: + stmt = stmt.where(self._entity_db_model.entity_type == filters.entity_type) + if filters.search: + stmt = stmt.where( + or_( + self._entity_db_model.name.like(f"%{filters.search}%"), + self._entity_db_model.description.like(f"%{filters.search}%"), + ) ) - return + return stmt - if ( - self._session.exec( - select(self._relationship_model).where( - self._relationship_model.meta["chunk_id"] == chunk_id + def get_entity_by_id(self, entity_id: int) -> Type[SQLModel]: + return self._db_session.get(self._entity_db_model, entity_id) + + def must_get_entity_by_id(self, entity_id: int) -> Type[SQLModel]: + entity = self.get_entity_by_id(entity_id) + if entity is None: + raise ValueError(f"Entity <{entity_id}> does not exist") + return entity + + def create_entity(self, create: EntityCreate, commit: bool = True) -> SQLModel: + desc_vec = get_entity_description_embedding( + create.name, create.description, self._embed_model + ) + meta_vec = get_entity_metadata_embedding(create.meta, self._embed_model) + entity = self._entity_db_model( + name=create.name, + entity_type=EntityType.original, + description=create.description, + description_vec=desc_vec, + meta=create.meta, + meta_vec=meta_vec, + ) + + self._db_session.add(entity) + if commit: + self._db_session.commit() + self._db_session.refresh(entity) + else: + self._db_session.flush() + return entity + + def find_or_create_entity( + self, + create: EntityCreate, + commit: bool = True, + ) -> SQLModel: + most_similar_entity = self._get_the_most_similar_entity(create) + + if most_similar_entity is not None: + return most_similar_entity + + return self.create_entity(create, commit=commit) + + def update_entity( + self, entity: Type[SQLModel], update: EntityUpdate, commit: bool = True + ) -> Type[SQLModel]: + for key, value in update.model_dump().items(): + if value is None: + continue + setattr(entity, key, value) + flag_modified(entity, key) + + entity.description_vec = get_entity_description_embedding( + entity.name, entity.description, self._embed_model + ) + if update.meta is not None: + entity.meta_vec = get_entity_metadata_embedding( + entity.meta, self._embed_model + ) + self._db_session.add(entity) + + # Update linked relationships. + linked_relationships = self.list_relationships_by_connected_entity(entity.id) + for relationship in linked_relationships: + self.update_relationship(relationship, RelationshipUpdate(), commit) + + if commit: + self._db_session.commit() + self._db_session.refresh(entity) + else: + self._db_session.flush() + return entity + + def delete_entity(self, entity: Type[SQLModel], commit: bool = True): + # Delete linked relationships. + linked_relationships = self.list_entity_connected_relationships(entity.id) + for relationship in linked_relationships: + self._db_session.delete(relationship) + + self._db_session.delete(entity) + if commit: + self._db_session.commit() + else: + self._db_session.flush() + + def calc_entity_out_degree(self, entity_id: int) -> Optional[int]: + stmt = select(func.count(self._relationship_db_model.id)).where( + self._relationship_db_model.source_entity_id == entity_id + ) + return self._db_session.exec(stmt).one() + + def calc_entity_in_degree(self, entity_id: int) -> Optional[int]: + stmt = select(func.count(self._relationship_db_model.id)).where( + self._relationship_db_model.target_entity_id == entity_id + ) + return self._db_session.exec(stmt).one() + + def calc_entities_degrees(self, entity_ids: List[int]) -> List[EntityDegree]: + stmt = ( + select( + self._entity_db_model.id, + func.count(self._relationship_db_model.id) + .filter( + self._relationship_db_model.source_entity_id + == self._entity_db_model.id ) - ).first() - is not None - ): - logger.info(f"{chunk_id} already exists in the relationship table, skip.") - return - - entities_name_map = defaultdict(list) - for _, row in entities_df.iterrows(): - entities_name_map[row["name"]].append( - self.get_or_create_entity( - Entity( - name=row["name"], - description=row["description"], - metadata=row["meta"], - ), - commit=False, + .label("out_degree"), + func.count(self._relationship_db_model.id) + .filter( + self._relationship_db_model.target_entity_id + == self._entity_db_model.id ) + .label("in_degree"), ) + .where(self._entity_db_model.id.in_(entity_ids)) + .outerjoin(self._relationship_db_model) + .group_by(self._entity_db_model.id) + ) - def _find_or_create_entity_for_relation( - name: str, description: str - ) -> SQLModel: - _embedding = get_entity_description_embedding( - name, description, self._embed_model + results = self._db_session.exec(stmt).all() + return [ + EntityDegree( + entity_id=r.id, + in_degree=r.in_degree, + out_degree=r.out_degree, + degrees=r.in_degree + r.out_degree, ) - # Check entities_name_map first, if not found, then check the database - for e in entities_name_map.get(name, []): - if ( - cosine_distance(e.description_vec, _embedding) - < self.description_cosine_distance_threshold - ): - return e - return self.get_or_create_entity( - Entity( - name=name, - description=description, - metadata={"status": "need-revised"}, + for r in results + ] + + # Entities Retrieve Operations + + def retrieve_entities( + self, + query: str, + entity_type: EntityType = EntityType.original, + top_k: int = 10, + nprobe: Optional[int] = None, + similarity_threshold: Optional[float] = None, + ) -> List[RetrievedEntity]: + entities = self.search_similar_entities( + query=query, + top_k=top_k, + nprobe=nprobe, + entity_type=entity_type, + similarity_threshold=similarity_threshold, + ) + return [ + RetrievedEntity( + id=entity.id, + knowledge_base_id=self._knowledge_base.id, + entity_type=entity.entity_type, + name=entity.name, + description=entity.description, + meta=entity.meta, + similarity_score=similarity_score, + ) + for entity, similarity_score in entities + ] + + def search_similar_entities( + self, + query: Optional[str] = None, + query_embedding: List[float] = None, + top_k: int = 10, + nprobe: Optional[int] = None, + entity_type: EntityType = EntityType.original, + similarity_threshold: Optional[float] = None, + # TODO: Metadata filter + # TODO: include_metadata, include_metadata_keys, include_embeddings parameters + ) -> List[Tuple[SQLModel, float]]: + if query_embedding is None: + assert ( + query + ), "One of the parameters of `query` and `query_embedding` must be provided" + embedding = get_query_embedding(query, self._embed_model) + else: + embedding = query_embedding + + distance_threshold = 1 - similarity_threshold + entity_model = self._entity_db_model + nprobe = nprobe or top_k * 10 + + if entity_type == EntityType.synopsis: + return self._search_similar_synopsis_entities( + entity_model, embedding, top_k, distance_threshold + ) + else: + return self._search_similar_original_entities( + entity_model, embedding, top_k, distance_threshold, nprobe + ) + + def _search_similar_original_entities( + self, + entity_model: Type[SQLModel], + query_embedding: List[float], + top_k: int, + distance_threshold: float, + nprobe: int, + ) -> List[Tuple[SQLModel, float]]: + """ + For original entities, it leverages TiFlash's ANN search to efficiently retrieve the most similar entities + from a large-scale vector space. + + To optimize retrieval performance on ANN Index, there employ a two-phase retrieval strategy: + 1. Fetch more (nprobe) results from the ANN Index as candidates. + 2. Sort the candidates by distance and get the top-k results. + """ + subquery = ( + select( + entity_model.id, + entity_model.description_vec.cosine_distance(query_embedding).label( + "distance" ), - commit=False, ) + .order_by(asc("distance")) + .limit(nprobe) + .subquery("candidates") + ) + query = ( + select(entity_model, (1 - subquery.c.distance).label("similarity_score")) + .where(subquery.c.distance <= distance_threshold) + .where(entity_model.id == subquery.c.id) + .where(entity_model.entity_type == EntityType.original) + .order_by(desc("similarity_score")) + .limit(top_k) + ) + return self._db_session.exec(query).all() - try: - for _, row in relationships_df.iterrows(): - logger.info( - "save entities for relationship %s -> %s -> %s", - row["source_entity"], - row["relationship_desc"], - row["target_entity"], - ) - source_entity = _find_or_create_entity_for_relation( - row["source_entity"], row["source_entity_description"] - ) - target_entity = _find_or_create_entity_for_relation( - row["target_entity"], row["target_entity_description"] - ) + def _search_similar_synopsis_entities( + self, + entity_model: Type[SQLModel], + query_embedding: List[float], + top_k: int, + distance_threshold: float, + ) -> List[Tuple[SQLModel, float]]: + """ + For synopsis entities, it leverages TiKV to fetch the synopsis entity quickly by filtering by entity_type, + because the number of synopsis entities is very small, it is commonly faster than using TiFlash to perform + ANN search. + """ + hint = text(f"/*+ read_from_storage(tikv[{entity_model.__tablename__}]) */") + subquery = ( + select( + entity_model, + entity_model.description_vec.cosine_distance(query_embedding).label( + "distance" + ), + ) + .prefix_with(hint) + .where(entity_model.entity_type == EntityType.synopsis) + .order_by(asc("distance")) + .limit(top_k) + .subquery("candidates") + ) + query = ( + select(entity_model, (1 - subquery.c.distance).label("similarity_score")) + .where(subquery.c.distance <= distance_threshold) + .order_by(desc("similarity_score")) + .limit(top_k) + ) + return self._db_session.exec(query).all() - self.create_relationship( - source_entity, - target_entity, - Relationship( - source_entity=source_entity.name, - target_entity=target_entity.name, - relationship_desc=row["relationship_desc"], - ), - relationship_metadata=row["meta"], - commit=False, - ) + def _get_the_most_similar_entity( + self, + create: EntityCreate, + ) -> Optional[DBEntity]: + query = f"{create.name}: {create.description}" + similar_entities = self.search_similar_entities(query, top_k=1, nprobe=10) + + if len(similar_entities) == 0: + return None + + most_similar_entity = similar_entities[0] - self._session.commit() - except Exception as e: - logger.error(e, exc_info=True) - self._session.rollback() - raise e + # For same entity. + if ( + most_similar_entity.name == create.name + and most_similar_entity.description == create.description + and len(DeepDiff(most_similar_entity.meta, create.meta)) == 0 + ): + return most_similar_entity + + # For the most similar entity. + if most_similar_entity.distance < self.entity_distance_threshold: + return most_similar_entity + + return None + + # Relationship Basic Operations + + def get_relationship_by_id(self, relationship_id: int) -> Type[SQLModel]: + stmt = select(self._relationship_db_model).where( + self._relationship_db_model.id == relationship_id + ) + return self._db_session.exec(stmt).first() + + def fetch_relationships_page( + self, filters: RelationshipFilters, params: Params + ) -> Page[Type[SQLModel]]: + stmt = self._build_relationships_query(filters) + return paginate(self._db_session, stmt, params) + + def list_relationships(self, filters: RelationshipFilters) -> Sequence[SQLModel]: + stmt = self._build_relationships_query(filters) + return self._db_session.exec(stmt).all() + + def _build_relationships_query(self, filters: RelationshipFilters): + stmt = select(self._relationship_db_model) + if filters.target_entity_id: + stmt = stmt.where( + self._relationship_db_model.target_entity_id == filters.target_entity_id + ) + if filters.target_entity_id: + stmt = stmt.where( + self._relationship_db_model.source_target_id == filters.source_target_id + ) + if filters.relationship_ids: + stmt = stmt.where( + self._relationship_db_model.id.in_(filters.relationship_ids) + ) + if filters.search: + stmt = stmt.where( + or_( + self._relationship_db_model.name.like(f"%{filters.search}%"), + self._relationship_db_model.description.like(f"%{filters.search}%"), + ) + ) + return stmt def create_relationship( self, - source_entity: SQLModel, - target_entity: SQLModel, - relationship: Relationship, - relationship_metadata: dict = {}, - commit=True, - ): - relationship_object = self._relationship_model( + source_entity: Type[SQLModel] | SQLModel, + target_entity: Type[SQLModel] | SQLModel, + description: Optional[str] = None, + metadata: Optional[dict] = {}, + commit: bool = True, + ) -> SQLModel: + """ + Create a relationship between two entities. + """ + description_vec = get_relationship_description_embedding( + source_entity.name, + source_entity.description, + target_entity.name, + target_entity.description, + description, + self._embed_model, + ) + relationship = self._relationship_db_model( source_entity=source_entity, target_entity=target_entity, - description=relationship.relationship_desc, - description_vec=get_relationship_description_embedding( - source_entity.name, - source_entity.description, - target_entity.name, - target_entity.description, - relationship.relationship_desc, - self._embed_model, - ), - meta=relationship_metadata, - document_id=relationship_metadata.get("document_id"), - chunk_id=relationship_metadata.get("chunk_id"), + description=description, + description_vec=description_vec, + meta=metadata, + chunk_id=metadata["chunk_id"] if "chunk_id" in metadata else None, + document_id=metadata["document_id"] if "document_id" in metadata else None, ) - self._session.add(relationship_object) + + self._db_session.add(relationship) if commit: - self._session.commit() - self._session.refresh(relationship_object) + self._db_session.commit() + self._db_session.refresh(relationship) else: - self._session.flush() + self._db_session.flush() - def get_subgraph_by_relationship_ids( - self, ids: list[int], **kwargs - ) -> RetrievedKnowledgeGraph: + return relationship + + def update_relationship( + self, + relationship: Type[SQLModel], + update: RelationshipUpdate, + commit: bool = True, + ) -> Type[SQLModel]: + for key, value in update.items(): + if value is None: + continue + setattr(relationship, key, value) + flag_modified(relationship, key) + + # Update embeddings. + relationship.description_vec = get_relationship_description_embedding( + relationship.source_entity.name, + relationship.source_entity.description, + relationship.target_entity.name, + relationship.target_entity.description, + relationship.description, + self._embed_model, + ) + + self._db_session.add(relationship) + if commit: + self._db_session.commit() + self._db_session.refresh(relationship) + else: + self._db_session.flush() + return relationship + + def delete_relationship(self, relationship: Type[SQLModel], commit: bool = True): + self._db_session.delete(relationship) + + if commit: + self._db_session.commit() + else: + self._db_session.flush() + + def clear_orphan_entities(self): + pass + + # Relationship Chunks Operations + + def exists_chunk_relationships(self, chunk_id: str) -> bool: + stmt = select(self._relationship_db_model).where( + self._relationship_db_model.chunk_id == chunk_id + ) + return self._db_session.exec(stmt).first() is not None + + def batch_get_chunks_by_relationships( + self, relationships_ids: List[int] + ) -> List[Type[SQLModel]]: + """ + Batch get chunks for the provided relationships. + """ + subquery = ( + select(self._relationship_db_model.chunk_id) + .where(self._relationship_db_model.id.in_(relationships_ids)) + .subquery() + ) + stmt = select(self._chunk_db_model).where(self._chunk_db_model.id.in_(subquery)) + return self._db_session.exec(stmt).all() + + # Relationship Retrieve Operations + + def retrieve_relationships( + self, + query: str, + query_embedding: Optional[List[float]] = None, + top_k: int = 10, + nprobe: Optional[int] = None, + similarity_threshold: Optional[float] = 0, + ) -> List[RetrievedRelationship]: + relationships = self.search_similar_relationships( + query=query, + query_embedding=query_embedding, + top_k=top_k, + nprobe=nprobe, + similarity_threshold=similarity_threshold, + ) + return [ + RetrievedRelationship( + id=relationship.id, + knowledge_base_id=self._knowledge_base.id, + source_entity_id=relationship.source_entity_id, + target_entity_id=relationship.target_entity_id, + description=relationship.description, + rag_description=f"{relationship.source_entity.name} -> {relationship.description} -> {relationship.target_entity.name}", + meta=relationship.meta, + weight=relationship.weight, + last_modified_at=relationship.last_modified_at, + similarity_score=relationship.similarity_score, + ) + for relationship, similarity_score in relationships + ] + + def search_similar_relationships( + self, + query: str, + top_k: int = 10, + nprobe: Optional[int] = None, + query_embedding: List[float] = None, + similarity_threshold: Optional[float] = 0, + ) -> List[Tuple[DBRelationship, float]]: + embedding = query_embedding or get_query_embedding(query, self._embed_model) + distance_threshold = 1 - similarity_threshold + nprobe = nprobe or top_k * 10 + + subquery = ( + select( + self._relationship_db_model, + self._relationship_db_model.description_vec.cosine_distance( + embedding + ).label("distance"), + ) + .order_by(asc("distance")) + .limit(nprobe) + .subquery() + ) + query = ( + select(subquery, (1 - subquery.c.distance).label("similarity_score")) + .where(subquery.c.distance <= distance_threshold) + .order_by(desc("similarity_score")) + .limit(top_k) + ) + + results = self._db_session.exec(query).all() + return [(row[0], row.similarity_score) for row in results] + + # Graph Basic Operations + + def list_relationships_by_connected_entity( + self, entity_id: int + ) -> List[Type[SQLModel]]: + stmt = ( + select(self._relationship_db_model) + .where( + (self._relationship_db_model.source_entity_id == entity_id) + | (self._relationship_db_model.target_entity_id == entity_id) + ) + .options( + defer(self._relationship_db_model.description_vec), + joinedload(self._relationship_db_model.source_entity) + .defer(self._entity_db_model.description_vec) + .defer(self._entity_db_model.meta_vec), + joinedload(self._relationship_db_model.target_entity) + .defer(self._entity_db_model.description_vec) + .defer(self._entity_db_model.meta_vec), + ) + ) + return self._db_session.exec(stmt) + + def list_relationships_by_ids( + self, relationship_ids: list[int], **kwargs + ) -> List[Type[SQLModel]]: stmt = ( - select(self._relationship_model) - .where(self._relationship_model.id.in_(ids)) + select(self._relationship_db_model) .options( - joinedload(self._relationship_model.source_entity), - joinedload(self._relationship_model.target_entity), + defer(self._relationship_db_model.description_vec), + joinedload(self._relationship_db_model.source_entity) + .defer(self._entity_db_model.description_vec) + .defer(self._entity_db_model.meta_vec), + joinedload(self._relationship_db_model.target_entity) + .defer(self._entity_db_model.description_vec) + .defer(self._entity_db_model.meta_vec), ) + .where(self._relationship_db_model.id.in_(relationship_ids)) ) - relationships_set = self._session.exec(stmt) + return self._db_session.exec(stmt).all() + + def _relationships_to_knowledge_graph( + self, relationships: list[DBRelationship], **kwargs + ) -> RetrievedKnowledgeGraph: entities_set = set() - relationships = [] + relationship_set = set() entities = [] - for rel in relationships_set: + for rel in relationships: entities_set.add(rel.source_entity) entities_set.add(rel.target_entity) - relationships.append( + relationship_set.add( RetrievedRelationship( id=rel.id, - knowledge_base_id=self.knowledge_base.id, + knowledge_base_id=self._knowledge_base.id, source_entity_id=rel.source_entity_id, target_entity_id=rel.target_entity_id, description=rel.description, @@ -351,7 +765,7 @@ def get_subgraph_by_relationship_ids( entities.append( RetrievedEntity( id=entity.id, - knowledge_base_id=self.knowledge_base.id, + knowledge_base_id=self._knowledge_base.id, name=entity.name, description=entity.description, meta=entity.meta, @@ -360,127 +774,34 @@ def get_subgraph_by_relationship_ids( ) return RetrievedKnowledgeGraph( - knowledge_base=self.knowledge_base.to_descriptor(), + knowledge_base=self._knowledge_base.to_descriptor(), entities=entities, - relationships=relationships, + relationships=list(relationship_set), **kwargs, ) - def get_or_create_entity(self, entity: Entity, commit: bool = True) -> SQLModel: - # using the cosine distance between the description vectors to determine if the entity already exists - entity_type = ( - EntityType.synopsis - if isinstance(entity, SynopsisEntity) - else EntityType.original - ) - entity_description_vec = get_entity_description_embedding( - entity.name, - entity.description, - self._embed_model, - ) - hint = text( - f"/*+ read_from_storage(tikv[{self._entity_model.__tablename__}]) */" - ) - result = ( - self._session.query( - self._entity_model, - self._entity_model.description_vec.cosine_distance( - entity_description_vec - ).label("distance"), - ) - .filter( - self._entity_model.name == entity.name - and self._entity_model.entity_type == entity_type - ) - .prefix_with(hint) - .order_by(asc("distance")) - .first() - ) - if ( - result is not None - and result[1] < self.description_cosine_distance_threshold - ): - db_obj = result[0] - ob_obj_metadata = db_obj.meta - if ( - db_obj.description == entity.description - and db_obj.name == entity.name - and len(DeepDiff(ob_obj_metadata, entity.metadata)) == 0 - ): - return db_obj - elif entity_type == EntityType.original: - # TODO: move to TiDBKnowledgeGraphIndex - # use LLM to merge the most similar entities - merged_entity = self._try_merge_entities( - [ - Entity( - name=db_obj.name, - description=db_obj.description, - metadata=ob_obj_metadata, - ), - Entity( - name=entity.name, - description=entity.description, - metadata=entity.metadata, - ), - ] - ) - if merged_entity is not None: - db_obj.description = merged_entity.description - db_obj.meta = merged_entity.metadata - db_obj.description_vec = get_entity_description_embedding( - db_obj.name, db_obj.description, self._embed_model - ) - db_obj.meta_vec = get_entity_metadata_embedding( - db_obj.meta, self._embed_model - ) + # Knowledge Graph Retrieve Operations - self._session.add(db_obj) - if commit: - self._session.commit() - self._session.refresh(db_obj) - else: - self._session.flush() - return db_obj - - synopsis_info_str = ( - entity.group_info.model_dump() - if entity_type == EntityType.synopsis - else None - ) + def traval_knowledge_graph( + self, + ) -> RetrievedKnowledgeGraph: + pass - db_obj = self._entity_model( - name=entity.name, - description=entity.description, - description_vec=entity_description_vec, - meta=entity.metadata, - meta_vec=get_entity_metadata_embedding(entity.metadata, self._embed_model), - synopsis_info=synopsis_info_str, - entity_type=entity_type, - ) - self._session.add(db_obj) - if commit: - self._session.commit() - self._session.refresh(db_obj) - else: - self._session.flush() - - return db_obj - - def _try_merge_entities(self, entities: List[Entity]) -> Entity: - logger.info(f"Trying to merge entities: {entities[0].name}") - try: - with dspy.settings.context(lm=self._dspy_lm): - pred = self.merge_entities_prog(entities=entities) - return pred.merged_entity - except Exception as e: - logger.error(f"Failed to merge entities: {e}", exc_info=True) - return None + def retrieve_knowledge_graph( + self, + query: Optional[str], + query_embedding: Optional[list[float]], + depth: int = 2, + include_meta: bool = False, + with_degree: bool = False, + metadata_filters: Optional[dict] = None, + ) -> RetrievedKnowledgeGraph: + pass def retrieve_with_weight( self, query: str, - embedding: list, + query_embedding: list, depth: int = 2, include_meta: bool = False, with_degree: bool = False, @@ -488,14 +809,14 @@ def retrieve_with_weight( relationship_meta_filters: dict = {}, session: Optional[Session] = None, ) -> Tuple[List[RetrievedEntity], List[RetrievedRelationship]]: - if not embedding: - assert query, "Either query or embedding must be provided" - embedding = get_query_embedding(query, self._embed_model) + if not query_embedding: + assert query, "Either `query` or `query_embedding` must be provided" + query_embedding = get_query_embedding(query, self._embed_model) relationships, entities = self.search_relationships_weight( - embedding, - [], - [], + query_embedding, + set(), + set(), with_degree=with_degree, relationship_meta_filters=relationship_meta_filters, session=session, @@ -506,23 +827,25 @@ def retrieve_with_weight( visited_entities = set(e.id for e in entities) visited_relationships = set(r.id for r in relationships) + fetch_synopsis_entities_num = 2 + for _ in range(depth - 1): actual_number = 0 progress = 0 - search_number_each_depth = 10 + search_number_each_level = 10 for search_config in DEFAULT_RANGE_SEARCH_CONFIG: search_ratio = search_config[1] search_distance_range = search_config[0] - remaining_number = search_number_each_depth - actual_number + remaining_number = search_number_each_level - actual_number # calculate the expected number based search progress # It's a accumulative search, so the expected number should be the difference between the expected number and the actual number expected_number = ( int( - (search_ratio + progress) * search_number_each_depth + (search_ratio + progress) * search_number_each_level - actual_number ) - if progress * search_number_each_depth > actual_number - else int(search_ratio * search_number_each_depth) + if progress * search_number_each_level > actual_number + else int(search_ratio * search_number_each_level) ) if expected_number > remaining_number: expected_number = remaining_number @@ -530,7 +853,7 @@ def retrieve_with_weight( break new_relationships, new_entities = self.search_relationships_weight( - embedding, + query_embedding, visited_relationships, visited_entities, search_distance_range, @@ -550,21 +873,18 @@ def retrieve_with_weight( if search_ratio != 1: progress += search_ratio - synopsis_entities = self.fetch_similar_entities( - embedding, top_k=2, entity_type=EntityType.synopsis, session=session + # Fetch related synopsis entities. + synopsis_entities = self.search_similar_entities( + entity_type=EntityType.synopsis, + query_embedding=query_embedding, + top_k=fetch_synopsis_entities_num, ) all_entities.update(synopsis_entities) - related_doc_ids = set() - for r in all_relationships: - if "doc_id" not in r.meta: - continue - related_doc_ids.add(r.meta["doc_id"]) - entities = [ RetrievedEntity( id=e.id, - knowledge_base_id=self.knowledge_base.id, + knowledge_base_id=self._knowledge_base.id, name=e.name, description=e.description, meta=e.meta if include_meta else None, @@ -575,7 +895,7 @@ def retrieve_with_weight( relationships = [ RetrievedRelationship( id=r.id, - knowledge_base_id=self.knowledge_base.id, + knowledge_base_id=self._knowledge_base.id, source_entity_id=r.source_entity_id, target_entity_id=r.target_entity_id, rag_description=f"{r.source_entity.name} -> {r.description} -> {r.target_entity.name}", @@ -589,48 +909,6 @@ def retrieve_with_weight( return entities, relationships - # Function to fetch degrees for entities - def fetch_entity_degrees( - self, - entity_ids: List[int], - session: Optional[Session] = None, - ) -> Dict[int, Dict[str, int]]: - degrees = { - entity_id: {"in_degree": 0, "out_degree": 0} for entity_id in entity_ids - } - session = session or self._session - - try: - # Fetch out-degrees - out_degree_query = ( - session.query( - self._relationship_model.source_entity_id, - func.count(self._relationship_model.id).label("out_degree"), - ) - .filter(self._relationship_model.source_entity_id.in_(entity_ids)) - .group_by(self._relationship_model.source_entity_id) - ).all() - - for row in out_degree_query: - degrees[row.source_entity_id]["out_degree"] = row.out_degree - - # Fetch in-degrees - in_degree_query = ( - session.query( - self._relationship_model.target_entity_id, - func.count(self._relationship_model.id).label("in_degree"), - ) - .filter(self._relationship_model.target_entity_id.in_(entity_ids)) - .group_by(self._relationship_model.target_entity_id) - ).all() - - for row in in_degree_query: - degrees[row.target_entity_id]["in_degree"] = row.in_degree - except Exception as e: - logger.error(e) - - return degrees - def search_relationships_weight( self, embedding: List[float], @@ -638,9 +916,9 @@ def search_relationships_weight( visited_entities: Set[int], distance_range: Tuple[float, float] = (0.0, 1.0), limit: int = 100, - weight_coefficient_config: List[ + weight_coefficients: List[ Tuple[Tuple[int, int], float] - ] = DEFAULT_WEIGHT_COEFFICIENT_CONFIG, + ] = DEFAULT_WEIGHT_COEFFICIENTS, alpha: float = 1, rank_n: int = 10, degree_coefficient: float = DEFAULT_DEGREE_COEFFICIENT, @@ -651,28 +929,28 @@ def search_relationships_weight( # select the relationships to rank subquery = ( select( - self._relationship_model, - self._relationship_model.description_vec.cosine_distance( + self._relationship_db_model, + self._relationship_db_model.description_vec.cosine_distance( embedding ).label("embedding_distance"), ) - .options(defer(self._relationship_model.description_vec)) + .options(defer(self._relationship_db_model.description_vec)) .order_by(asc("embedding_distance")) .limit(limit * 10) ).subquery() - relationships_alias = aliased(self._relationship_model, subquery) + relationships_alias = aliased(self._relationship_db_model, subquery) query = ( select(relationships_alias, text("embedding_distance")) .options( defer(relationships_alias.description_vec), joinedload(relationships_alias.source_entity) - .defer(self._entity_model.meta_vec) - .defer(self._entity_model.description_vec), + .defer(self._entity_db_model.meta_vec) + .defer(self._entity_db_model.description_vec), joinedload(relationships_alias.target_entity) - .defer(self._entity_model.meta_vec) - .defer(self._entity_model.description_vec), + .defer(self._entity_db_model.meta_vec) + .defer(self._entity_db_model.description_vec), ) .where(relationships_alias.weight >= 0) ) @@ -683,7 +961,7 @@ def search_relationships_weight( if visited_relationships: query = query.where( - self._relationship_model.id.notin_(visited_relationships) + self._relationship_db_model.id.notin_(visited_relationships) ) if distance_range != (0.0, 1.0): @@ -696,7 +974,7 @@ def search_relationships_weight( if visited_entities: query = query.where( - self._relationship_model.source_entity_id.in_(visited_entities) + self._relationship_db_model.source_entity_id.in_(visited_entities) ) query = query.order_by(asc("embedding_distance")).limit(limit) @@ -742,7 +1020,7 @@ def search_relationships_weight( source_in_degree, target_out_degree, alpha, - weight_coefficient_config, + weight_coefficients, degree_coefficient, with_degree, ) @@ -758,90 +1036,12 @@ def search_relationships_weight( return list(relationship_set), list(entity_set) - def fetch_similar_entities_by_post_filter( - self, - embedding: list, - top_k: int = 5, - entity_type: EntityType = EntityType.original, - session: Optional[Session] = None, - post_filter_multiplier: int = 10, - ): - new_entity_set = set() - session = session or self._session - - # Create a subquery with a larger limit and include the distance - subquery = ( - select( - self._entity_model, - self._entity_model.description_vec.cosine_distance(embedding).label( - "distance" - ), - ) - .order_by(asc("distance")) - .limit( - post_filter_multiplier * top_k - if entity_type != EntityType.original - else top_k - ) - .subquery() - ) - - # Apply filter only for non-original entity types - query = ( - select(self._entity_model) - .where(subquery.c.entity_type == entity_type) - .order_by(asc(subquery.c.distance)) - .limit(top_k) - ) - - for row in session.exec(query).all(): - new_entity_set.add(row) - - return new_entity_set - - def fetch_similar_entities( - self, - embedding: list, - top_k: int = 10, - entity_type: EntityType = EntityType.original, - session: Optional[Session] = None, - ): - new_entity_set = set() - - # Retrieve entities based on their ID and similarity to the embedding - session = session or self._session - - query = select(self._entity_model) - - if entity_type == EntityType.synopsis: - query = query.where(self._entity_model.entity_type == entity_type) - hint = text("/*+ read_from_storage(tikv[entities]) */") - query = query.prefix_with(hint) - - query = query.order_by( - self._entity_model.description_vec.cosine_distance(embedding) - ).limit(top_k) - - # Debug: Print the SQL query - """ - from sqlalchemy.dialects import mysql - compiled_query = query.compile( - dialect=mysql.dialect(), compile_kwargs={"literal_binds": True} - ) - print(f"Debug - SQL Query: {compiled_query}") - """ - - for entity in session.exec(query).all(): - new_entity_set.add(entity) - - return new_entity_set - - def retrieve_graph_data( + def retrieve_subgraph_by_similar( self, query_text: str, top_k: int = 5, similarity_threshold: float = 0.7, - ) -> Dict[str, List[Dict[str, Any]]]: + ) -> RetrievedKnowledgeGraph: """Retrieve related entities and relationships using semantic search. Args: @@ -850,92 +1050,30 @@ def retrieve_graph_data( similarity_threshold: Minimum similarity score threshold Returns: - Dictionary containing: - - entities: List of similar entities with similarity scores - - relationships: List of similar relationships with similarity scores + RetrievedKnowledgeGraph containing similar entities and relationships """ query_embedding = get_query_embedding(query_text, self._embed_model) - # Query similar entities - entity_query = ( - select( - self._entity_model, - ( - 1 - - self._entity_model.description_vec.cosine_distance( - query_embedding - ) - ).label("similarity"), - ) - .options( - defer(self._entity_model.description_vec), - defer(self._entity_model.meta_vec), - ) - .order_by(desc("similarity")) - .limit(top_k * 2) # Fetch more results to account for filtering + # Get similar entities + entities = self.search_similar_entities( + query=query_text, + similarity_threshold=similarity_threshold, + top_k=top_k, ) - # Query similar relationships - relationship_query = ( - select( - self._relationship_model, - ( - 1 - - self._relationship_model.description_vec.cosine_distance( - query_embedding - ) - ).label("similarity"), - ) - .options( - defer(self._relationship_model.description_vec), - joinedload(self._relationship_model.source_entity) - .defer(self._entity_model.meta_vec) - .defer(self._entity_model.description_vec), - joinedload(self._relationship_model.target_entity) - .defer(self._entity_model.meta_vec) - .defer(self._entity_model.description_vec), - ) - .order_by(desc("similarity")) - .limit(top_k * 2) # Fetch more results to account for filtering + # Get similar relationships + relationships = self.search_similar_relationships( + query=query_text, + query_embedding=query_embedding, + similarity_threshold=similarity_threshold, + top_k=top_k, ) - # Execute both queries - entities = [] - relationships = [] - - for entity, similarity in self._session.exec(entity_query).all(): - if similarity >= similarity_threshold and len(entities) < top_k: - entities.append( - { - "id": entity.id, - "name": entity.name, - "description": entity.description, - "metadata": entity.meta, - "similarity_score": similarity, - } - ) - - for relationship, similarity in self._session.exec(relationship_query).all(): - if similarity >= similarity_threshold and len(relationships) < top_k: - relationships.append( - { - "id": relationship.id, - "relationship": relationship.description, - "source_entity": { - "id": relationship.source_entity.id, - "name": relationship.source_entity.name, - "description": relationship.source_entity.description, - }, - "target_entity": { - "id": relationship.target_entity.id, - "name": relationship.target_entity.name, - "description": relationship.target_entity.description, - }, - "similarity_score": similarity, - } - ) - - return {"entities": entities, "relationships": relationships} + return RetrievedKnowledgeGraph( + knowledge_base=self._knowledge_base.to_descriptor(), + entities=entities, + relationships=relationships, + ) def retrieve_neighbors( self, @@ -948,7 +1086,7 @@ def retrieve_neighbors( """Retrieve most relevant neighbor paths for a group of similar nodes. Args: - node_ids: List of source node IDs (representing similar entities) + entities_ids: List of source node IDs (representing similar entities) query: Search query for relevant relationships max_depth: Maximum depth for relationship traversal max_neighbors: Maximum number of total neighbor paths to return @@ -958,15 +1096,6 @@ def retrieve_neighbors( Dictionary containing most relevant paths from source nodes to neighbors """ query_embedding = get_query_embedding(query, self._embed_model) - # Get all source entities - source_entities = self._session.exec( - select(self._entity_model) - .options( - defer(self._entity_model.description_vec), - defer(self._entity_model.meta_vec), - ) - .where(self._entity_model.id.in_(entities_ids)) - ).all() # Track visited nodes and discovered paths all_visited = set(entities_ids) @@ -978,38 +1107,12 @@ def retrieve_neighbors( break # Query relationships for current level - relationships = self._session.exec( - select( - self._relationship_model, - ( - 1 - - self._relationship_model.description_vec.cosine_distance( - query_embedding - ) - ).label("similarity"), - ) - .options( - defer(self._relationship_model.description_vec), - joinedload(self._relationship_model.source_entity) - .defer(self._entity_model.meta_vec) - .defer(self._entity_model.description_vec), - joinedload(self._relationship_model.target_entity) - .defer(self._entity_model.meta_vec) - .defer(self._entity_model.description_vec), - ) - .where( - or_( - self._relationship_model.source_entity_id.in_( - current_level_nodes - ), - self._relationship_model.target_entity_id.in_( - current_level_nodes - ), - ) - ) - .order_by(desc("similarity")) - .limit(max_neighbors * 2) # Fetch more results to account for filtering - ).all() + relationships = self.search_similar_relationships( + query=query, + query_embedding=query_embedding, + nprobe=100, + similarity_threshold=similarity_threshold, + ) next_level_nodes = set() @@ -1054,59 +1157,3 @@ def retrieve_neighbors( neighbors.sort(key=lambda x: x["similarity_score"], reverse=True) return {"relationships": neighbors[:max_neighbors]} - - def get_chunks_by_relationships( - self, - relationships_ids: List[int], - session: Optional[Session] = None, - ) -> List[Dict[str, Any]]: - """Get chunks for a list of relationships. - - Args: - relationships: List of relationship objects - session: Optional database session - - Returns: - List of dictionaries containing chunk information: - - text: chunk text content - - document_id: associated document id - - meta: chunk metadata - """ - session = session or self._session - - relationships = session.exec( - select(self._relationship_model).where( - self._relationship_model.id.in_(relationships_ids) - ) - ).all() - - # Extract chunk IDs from relationships - chunk_ids = { - rel.meta.get("chunk_id") - for rel in relationships - if rel.meta.get("chunk_id") is not None - } - - if not chunk_ids: - return [] - - # Query chunks - chunks = session.exec( - select(self._chunk_model).where(self._chunk_model.id.in_(chunk_ids)) - ).all() - - return [ - { - "id": chunk.id, - "text": chunk.text, - "document_id": chunk.document_id, - "meta": { - "language": chunk.meta.get("language"), - "product": chunk.meta.get("product"), - "resource": chunk.meta.get("resource"), - "source_uri": chunk.meta.get("source_uri"), - "tidb_version": chunk.meta.get("tidb_version"), - }, - } - for chunk in chunks - ] diff --git a/backend/app/rag/indices/knowledge_graph/schema.py b/backend/app/rag/indices/knowledge_graph/schema.py index 01efcd584..d695aa6bb 100644 --- a/backend/app/rag/indices/knowledge_graph/schema.py +++ b/backend/app/rag/indices/knowledge_graph/schema.py @@ -1,8 +1,11 @@ -from pydantic import BaseModel, Field -from typing import Mapping, Any, List +from typing import Mapping, Any, List, Optional +import dspy +from dspy import TypedPredictor +from pydantic import BaseModel, model_validator, Field +from app.models.entity import EntityType -class Entity(BaseModel): +class AIEntity(BaseModel): """List of entities extracted from the text to form the knowledge graph""" name: str = Field( @@ -22,7 +25,7 @@ class Entity(BaseModel): ) -class EntityWithID(Entity): +class AIEntityWithID(AIEntity): """Entity extracted from the text to form the knowledge graph with an ID.""" id: int = Field(description="Unique identifier for the entity.") @@ -39,7 +42,7 @@ class SynopsisInfo(BaseModel): ) -class SynopsisEntity(Entity): +class SynopsisEntity(AIEntity): """Unified synopsis entity with comprehensive description and metadata based on the entities group.""" group_info: SynopsisInfo = Field( @@ -53,7 +56,7 @@ class ExistingSynopsisEntity(SynopsisEntity): id: int = Field(description="Unique identifier for the entity.") -class Relationship(BaseModel): +class AIRelationship(BaseModel): """List of relationships extracted from the text to form the knowledge graph""" source_entity: str = Field( @@ -70,7 +73,7 @@ class Relationship(BaseModel): ) -class RelationshipReasoning(Relationship): +class AIRelationshipReasoning(AIRelationship): """Relationship between two entities extracted from the query""" reasoning: str = Field( @@ -80,13 +83,13 @@ class RelationshipReasoning(Relationship): ) -class KnowledgeGraph(BaseModel): +class AIKnowledgeGraph(BaseModel): """Graph representation of the knowledge for text.""" - entities: List[Entity] = Field( + entities: List[AIEntity] = Field( description="List of entities in the knowledge graph" ) - relationships: List[Relationship] = Field( + relationships: List[AIRelationship] = Field( description="List of relationships in the knowledge graph" ) @@ -113,6 +116,95 @@ class EntityCovariateOutput(BaseModel): class DecomposedFactors(BaseModel): """Decomposed factors extracted from the query to form the knowledge graph""" - relationships: List[RelationshipReasoning] = Field( + relationships: List[AIRelationshipReasoning] = Field( description="List of relationships to represent critical concepts and their relationships extracted from the query." ) + + +class MergeEntities(dspy.Signature): + """As a knowledge expert assistant specialized in database technologies, evaluate the two provided entities. These entities have been pre-analyzed and have same name but different descriptions and metadata. + Please carefully review the detailed descriptions and metadata for both entities to determine if they genuinely represent the same concept or object(entity). + If you conclude that the entities are identical, merge the descriptions and metadata fields of the two entities into a single consolidated entity. + If the entities are distinct despite their same name that may be due to different contexts or perspectives, do not merge the entities and return none as the merged entity. + + Considerations: Ensure your decision is based on a comprehensive analysis of the content and context provided within the entity descriptions and metadata. + Please only response in JSON Format. + """ + + entities: List[AIEntity] = dspy.InputField( + desc="List of entities identified from previous analysis." + ) + merged_entity: Optional[AIEntity] = dspy.OutputField( + desc="Merged entity with consolidated descriptions and metadata." + ) + + +class MergeEntitiesProgram(dspy.Module): + def __init__(self): + self.prog = TypedPredictor(MergeEntities) + + def forward(self, entities: List[AIEntity]): + if len(entities) != 2: + raise ValueError("The input should contain exactly two entities") + + pred = self.prog(entities=entities) + return pred + + +# Entity + + +class EntityCreate(BaseModel): + entity_type: Optional[EntityType] = EntityType.original + name: Optional[str] = None + description: Optional[str] = None + meta: Optional[dict] = None + + +class SynopsisEntityCreate(EntityCreate): + topic: str + entities: List[int] = Field(description="The id list of the related entities") + + @model_validator(mode="after") + def validate_entities(self): + if len(self.entities) == 0: + raise ValueError("Entities list should not be empty") + return self + + +class EntityFilters(BaseModel): + entity_ids: Optional[List[int]] = None + entity_type: Optional[EntityType] = None + search: Optional[str] = None + + +class EntityUpdate(BaseModel): + description: Optional[str] = None + meta: Optional[dict] = None + + +class EntityDegree(BaseModel): + entity_id: int + out_degree: int + in_degree: int + degrees: int + + +# Relationship + + +class RelationshipCreate(BaseModel): + source_entity_id: int + target_entity_id: int + description: str + + +class RelationshipUpdate(BaseModel): + description: Optional[str] = None + + +class RelationshipFilters(BaseModel): + target_entity_id: Optional[int] = None + source_entity_id: Optional[int] = None + relationship_ids: Optional[List[int]] = None + search: Optional[str] = None diff --git a/backend/app/rag/knowledge_base/index_store.py b/backend/app/rag/knowledge_base/index_store.py index a1eb98c34..cb0b6816d 100644 --- a/backend/app/rag/knowledge_base/index_store.py +++ b/backend/app/rag/knowledge_base/index_store.py @@ -1,11 +1,7 @@ -from sqlalchemy import inspection from sqlmodel import Session from app.models import KnowledgeBase from app.models.chunk import get_kb_chunk_model -from app.models.entity import get_kb_entity_model -from app.rag.knowledge_base.config import get_kb_dspy_llm, get_kb_embed_model -from app.models.relationship import get_kb_relationship_model from app.rag.indices.knowledge_graph.graph_store import TiDBGraphStore, TiDBGraphEditor from app.rag.indices.vector_search.vector_store.tidb_vector_store import TiDBVectorStore @@ -23,23 +19,7 @@ def init_kb_tidb_vector_store(session: Session, kb: KnowledgeBase) -> TiDBVector def get_kb_tidb_graph_store(session: Session, kb: KnowledgeBase) -> TiDBGraphStore: - dspy_lm = get_kb_dspy_llm(session, kb) - embed_model = get_kb_embed_model(session, kb) - entity_model = get_kb_entity_model(kb) - relationship_model = get_kb_relationship_model(kb) - inspection.inspect(relationship_model) - chunk_model = get_kb_chunk_model(kb) - - graph_store = TiDBGraphStore( - knowledge_base=kb, - dspy_lm=dspy_lm, - session=session, - embed_model=embed_model, - entity_db_model=entity_model, - relationship_db_model=relationship_model, - chunk_db_model=chunk_model, - ) - return graph_store + return TiDBGraphStore.from_knowledge_base(kb, session) def init_kb_tidb_graph_store(session: Session, kb: KnowledgeBase) -> TiDBGraphStore: @@ -48,13 +28,6 @@ def init_kb_tidb_graph_store(session: Session, kb: KnowledgeBase) -> TiDBGraphSt return graph_store -def get_kb_tidb_graph_editor(session: Session, kb: KnowledgeBase) -> TiDBGraphEditor: - entity_db_model = get_kb_entity_model(kb) - relationship_db_model = get_kb_relationship_model(kb) - embed_model = get_kb_embed_model(session, kb) - return TiDBGraphEditor( - knowledge_base_id=kb.id, - entity_db_model=entity_db_model, - relationship_db_model=relationship_db_model, - embed_model=embed_model, - ) +def get_kb_graph_editor(session: Session, kb: KnowledgeBase) -> TiDBGraphEditor: + graph_store = get_kb_tidb_graph_store(session, kb) + return TiDBGraphEditor(session, graph_store) diff --git a/backend/app/rag/retrievers/knowledge_graph/schema.py b/backend/app/rag/retrievers/knowledge_graph/schema.py index 36868e842..b3186162b 100644 --- a/backend/app/rag/retrievers/knowledge_graph/schema.py +++ b/backend/app/rag/retrievers/knowledge_graph/schema.py @@ -66,6 +66,9 @@ class RetrievedEntity(BaseModel): name: str = Field(description="Name of the entity") description: str = Field(description="Description of the entity") meta: Optional[Mapping[str, Any]] = Field(description="Metadata of the entity") + similarity_score: Optional[float] = Field( + description="Similarity score of the entity", default=None + ) @property def global_id(self) -> str: @@ -91,6 +94,9 @@ class RetrievedRelationship(BaseModel): last_modified_at: Optional[datetime.datetime] = Field( description="Last modified at of the relationship", default=None ) + similarity_score: Optional[float] = Field( + description="Similarity score of the relationship", default=None + ) @property def global_id(self) -> str: diff --git a/backend/app/rag/retrievers/knowledge_graph/simple_retriever.py b/backend/app/rag/retrievers/knowledge_graph/simple_retriever.py index 5d5c42d13..c8cd12b91 100644 --- a/backend/app/rag/retrievers/knowledge_graph/simple_retriever.py +++ b/backend/app/rag/retrievers/knowledge_graph/simple_retriever.py @@ -44,8 +44,8 @@ def __init__( dspy_lm=dspy_lm, session=db_session, embed_model=self.embed_model, - entity_db_model=self.entity_db_model, - relationship_db_model=self.relationship_db_model, + entity_model=self.entity_db_model, + relationship_model=self.relationship_db_model, ) def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: @@ -55,7 +55,7 @@ def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: entities, relationships = self._kg_store.retrieve_with_weight( query_bundle.query_str, - embedding=[], + query_embedding=[], depth=self.config.depth, include_meta=self.config.include_meta, with_degree=self.config.with_degree, diff --git a/backend/tests/rag/storage/graph_store/__init__.py b/backend/tests/rag/storage/graph_store/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/rag/storage/graph_store/test_graph_store.py b/backend/tests/rag/storage/graph_store/test_graph_store.py new file mode 100644 index 000000000..97e86c3b4 --- /dev/null +++ b/backend/tests/rag/storage/graph_store/test_graph_store.py @@ -0,0 +1,98 @@ +from dotenv import load_dotenv +from sqlmodel import Session + +from app.core.db import engine +from app.models import ( + Chunk as DBChunk, + Entity as DBEntity, + Relationship as DBRelationship, +) +from app.models.knowledge_base import KnowledgeBase +from app.rag.indices.knowledge_graph.graph_store.helpers import get_default_embed_model +from app.rag.indices.knowledge_graph.graph_store.tidb_graph_store import TiDBGraphStore +from app.rag.indices.knowledge_graph.schema import ( + EntityCreate, + EntityDegree, +) + +load_dotenv() + + +class TestGraphStore: + @classmethod + def setup_class(cls): + """Set up test fixtures before running any tests in the class""" + cls.db_session = Session(engine) + # Create a test knowledge base + cls.kb = KnowledgeBase(name="test_kb") + cls.db_session.add(cls.kb) + cls.db_session.commit() + + cls.graph_store = TiDBGraphStore( + db_session=cls.db_session, + knowledge_base=cls.kb, + embed_model=get_default_embed_model(), + entity_model=DBEntity, + relationship_model=DBRelationship, + chunk_model=DBChunk, + ) + + @classmethod + def teardown_class(cls): + """Clean up after all tests in the class have run""" + # Clean up the test data + cls.db_session.delete(cls.kb) + cls.db_session.commit() + cls.db_session.close() + + def test_calc_entity_degrees(self): + tidb_entity = self.graph_store.create_entity( + EntityCreate( + name="TiDB", + ) + ) + tikv_entity = self.graph_store.create_entity( + EntityCreate( + name="TiKV", + ) + ) + ticdc_entity = self.graph_store.create_entity( + EntityCreate( + name="TiCDC", + ) + ) + self.graph_store.create_relationship( + source_entity=tidb_entity, + target_entity=tikv_entity, + description="TiDB has a component named TiKV", + ) + self.graph_store.create_relationship( + source_entity=tidb_entity, + target_entity=ticdc_entity, + description="TiDB has a tool named TiCDC", + ) + + out_degree = self.graph_store.calc_entity_out_degree(tidb_entity.id) + assert out_degree == 2 + + in_degree = self.graph_store.calc_entity_in_degree(tikv_entity.id) + assert in_degree == 1 + + degrees = self.graph_store.calc_entities_degrees( + [tidb_entity.id, tikv_entity.id, ticdc_entity.id] + ) + print(degrees) + assert degrees == [ + EntityDegree( + entity_id=tidb_entity.id, in_degree=0, out_degree=2, degrees=2 + ), + EntityDegree( + entity_id=tikv_entity.id, in_degree=1, out_degree=0, degrees=1 + ), + EntityDegree( + entity_id=ticdc_entity.id, in_degree=1, out_degree=0, degrees=1 + ), + ] + + def test_another_method(self): + pass