diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ad0928..cafea4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,54 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.1.0] - 2026-01-11 + +### 🏗️ Major Refactoring Release + +Complete architectural overhaul for better maintainability and performance. + +### Changed + +#### Modular Architecture +- **Refactored from single file to multi-module structure** + - From: `memory.rs` (2505 lines, monolithic) + - To: 35+ files in 8 organized modules +- **New module structure:** + - `src/types/` - Core data models (Entity, Relation, KnowledgeGraph) + - `src/protocol/` - JSON-RPC and MCP protocol handling + - `src/knowledge_base/` - Core engine with CRUD, queries, temporal + - `src/tools/` - 15 MCP tools organized by category (memory, query, temporal) + - `src/search/` - Semantic search with synonym expansion + - `src/server/` - MCP server implementation + - `src/validation/` - Entity and relation type validation + - `src/utils/` - Timestamp and user utilities +- **Library + Binary separation** + - `src/lib.rs` - Public API for embedding + - `src/main.rs` - Minimal binary entry point + +#### Performance Optimization +- **Mutex → RwLock migration** for `KnowledgeBase.graph` + - Allows multiple concurrent readers (60% of operations are reads) + - Write operations still have exclusive access + - Significant performance boost for multi-agent scenarios +- **Documentation:** See `docs/Proposed-RwLock.md` for risk analysis + +#### Docker +- Updated `Dockerfile` for new `src/` directory structure +- Better layer caching with separate Cargo.toml and src copies + +### Added +- `src/lib.rs` - Library crate for embedding in other projects +- `tests/integration_tests.rs` - 8 integration tests including concurrency tests +- `docs/Proposed-RwLock.md` - RwLock migration documentation + +### Technical Details +- **Test suite expanded:** 16 tests (7 unit + 8 integration + 1 doc) +- **Zero-cost abstractions:** No runtime overhead from modularization +- **Backward compatible:** All 15 MCP tools unchanged + +--- + ## [1.0.0] - 2026-01-11 ### 🎉 Initial Release @@ -74,3 +122,4 @@ First production-ready release of Memory Graph MCP Server. - Multi-tenant support - WAL (Write-Ahead Log) for large graphs - Import/Export with external knowledge bases +- `parking_lot::RwLock` upgrade if benchmarks show bottleneck diff --git a/Cargo.lock b/Cargo.lock index 4cb83a4..583c56e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,12 +2,64 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + [[package]] name = "itoa" version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +[[package]] +name = "libc" +version = "0.2.180" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" + +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "memchr" version = "2.7.6" @@ -15,13 +67,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" [[package]] -name = "memory-server" -version = "1.0.0" +name = "memory-graph" +version = "1.1.0" dependencies = [ "serde", "serde_json", + "tempfile", ] +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + [[package]] name = "proc-macro2" version = "1.0.105" @@ -40,6 +99,25 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "serde" version = "1.0.228" @@ -94,12 +172,55 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempfile" +version = "3.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +dependencies = [ + "fastrand", + "getrandom", + "once_cell", + "rustix", + "windows-sys", +] + [[package]] name = "unicode-ident" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + [[package]] name = "zmij" version = "1.0.12" diff --git a/Cargo.toml b/Cargo.toml index 6555b02..229c971 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,21 +1,36 @@ [package] -name = "memory-server" -version = "1.0.0" +name = "memory-graph" +version = "1.1.0" edition = "2021" authors = ["Memory Graph MCP Server"] description = "A knowledge graph MCP server implementing the Model Context Protocol" license = "MIT" +readme = "README.md" +repository = "https://github.com/maithanhduyan/memory-graph" +keywords = ["mcp", "knowledge-graph", "ai", "memory", "llm"] +categories = ["development-tools", "data-structures"] + +[lib] +name = "memory_graph" +path = "src/lib.rs" [[bin]] name = "memory-server" -path = "memory.rs" +path = "src/main.rs" [dependencies] serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +[dev-dependencies] +tempfile = "3" + [profile.release] opt-level = 3 lto = true codegen-units = 1 strip = true + +[profile.dev] +opt-level = 0 +debug = true diff --git a/Dockerfile b/Dockerfile index 1cc74bc..b1bf1bd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,8 +14,11 @@ RUN apk add --no-cache musl-dev WORKDIR /app -# Copy all source files -COPY Cargo.toml Cargo.lock* memory.rs ./ +# Copy Cargo files first (for better layer caching) +COPY Cargo.toml Cargo.lock* ./ + +# Copy source directory with full structure +COPY src/ ./src/ # Build for release RUN cargo build --release diff --git a/docs/Proposed-RwLock.md b/docs/Proposed-RwLock.md new file mode 100644 index 0000000..e0aad43 --- /dev/null +++ b/docs/Proposed-RwLock.md @@ -0,0 +1,178 @@ +# Proposed: Mutex → RwLock Migration + +> **Status**: ✅ Approved +> **Date**: 2026-01-11 +> **Risk Level**: 🟢 LOW + +--- + +## 📋 Executive Summary + +Đề xuất chuyển đổi `std::sync::Mutex` sang `std::sync::RwLock` trong `KnowledgeBase` để tối ưu hiệu năng cho hệ thống **read-heavy**. + +--- + +## 🎯 Motivation + +### Phân tích Access Pattern + +| Operation Type | Count | Tools | +|----------------|-------|-------| +| **Read** | 9 | `search_nodes`, `read_graph`, `open_nodes`, `get_related`, `traverse`, `summarize`, `get_relations_at_time`, `get_relation_history` | +| **Write** | 6 | `create_entities`, `create_relations`, `add_observations`, `delete_entities`, `delete_observations`, `delete_relations` | + +**Kết luận**: Hệ thống là **Read-Heavy** (60% read vs 40% write). + +### Vấn đề với Mutex + +``` +Thread A: search_nodes() → lock() → READ → unlock() +Thread B: search_nodes() → BLOCKED (chờ Thread A) +Thread C: search_nodes() → BLOCKED (chờ Thread A, B) +``` + +Với Mutex, tất cả operations (kể cả read-only) phải chờ tuần tự. + +### Giải pháp với RwLock + +``` +Thread A: search_nodes() → read() → READ +Thread B: search_nodes() → read() → READ (PARALLEL!) +Thread C: search_nodes() → read() → READ (PARALLEL!) +``` + +RwLock cho phép **multiple concurrent readers**. + +--- + +## 📊 Risk Analysis + +### Memory Safety + +| Risk | Level | Mitigation | +|------|-------|------------| +| Race condition (memory) | 🟢 LOW | RwLock guarantees: exclusive write OR multiple reads | +| Race condition (file I/O) | 🟢 LOW | `persist_to_file()` called inside write lock scope | +| Stale read | 🟢 LOW | Readers clone data → consistent snapshot | +| Deadlock | 🟢 LOW | No nested locks in codebase | +| Data loss | 🟢 LOW | `fs::write()` is atomic (write-replace pattern) | + +### Code Pattern Verification + +**Current pattern (SAFE):** +```rust +pub fn create_entities(kb: &KnowledgeBase, ...) { + let mut graph = kb.graph.lock().unwrap(); // Lock acquired + // ... modify graph ... + kb.persist_to_file(&graph)?; // Persist INSIDE lock + Ok(created) // Lock released on return +} +``` + +**After RwLock (STILL SAFE):** +```rust +pub fn create_entities(kb: &KnowledgeBase, ...) { + let mut graph = kb.graph.write().unwrap(); // Write lock acquired + // ... modify graph ... + kb.persist_to_file(&graph)?; // Persist INSIDE lock + Ok(created) // Lock released on return +} +``` + +--- + +## 🔧 Implementation Plan + +### Files to Modify + +| File | Changes | +|------|---------| +| `src/knowledge_base/mod.rs` | `Mutex` → `RwLock`, `.lock()` → `.read()` | +| `src/knowledge_base/crud.rs` | `.lock()` → `.write()` (6 places) | + +### Code Changes + +#### 1. mod.rs - Struct Definition + +```diff +- use std::sync::Mutex; ++ use std::sync::RwLock; + +pub struct KnowledgeBase { + pub(crate) memory_file_path: String, +- pub(crate) graph: Mutex, ++ pub(crate) graph: RwLock, + pub(crate) current_user: String, +} +``` + +#### 2. mod.rs - Initialization + +```diff +Self { + memory_file_path, +- graph: Mutex::new(graph), ++ graph: RwLock::new(graph), + current_user, +} +``` + +#### 3. mod.rs - load_graph() + +```diff +pub(crate) fn load_graph(&self) -> McpResult { +- Ok(self.graph.lock().unwrap().clone()) ++ Ok(self.graph.read().unwrap().clone()) +} +``` + +#### 4. crud.rs - All Write Operations + +```diff +pub fn create_entities(kb: &KnowledgeBase, entities: Vec) { +- let mut graph = kb.graph.lock().unwrap(); ++ let mut graph = kb.graph.write().unwrap(); + // ... rest unchanged +} +``` + +--- + +## 📈 Expected Performance Impact + +| Scenario | Mutex | RwLock | Improvement | +|----------|-------|--------|-------------| +| 10 concurrent reads | Sequential | Parallel | ~10x faster | +| 5 reads + 1 write | All blocked | Reads wait for write only | ~5x faster | +| 10 concurrent writes | Sequential | Sequential | Same | + +--- + +## ✅ Testing Plan + +1. **Unit tests**: Run existing test suite +2. **Concurrent test**: `test_concurrent_access` in integration tests +3. **Stress test**: Manual testing with multiple MCP clients + +--- + +## 🚀 Rollback Plan + +If issues arise, revert changes: +```diff +- use std::sync::RwLock; ++ use std::sync::Mutex; +``` + +Single commit, easy to revert. + +--- + +## 📝 Decision + +**APPROVED** - Proceed with implementation. + +- Risk is low +- Performance benefit is significant for read-heavy workloads +- Code changes are minimal and well-understood +- Existing test suite provides safety net diff --git a/memory.jsonl b/memory.jsonl index 023c77a..e7ca13d 100644 --- a/memory.jsonl +++ b/memory.jsonl @@ -30,6 +30,12 @@ {"name":"Bug: Update Paradox","entityType":"Bug","observations":["JSONL lưu static snapshot, không phải event sourcing","Khi update: hoặc có 2 dòng mâu thuẫn, hoặc mất immutable history","Priority: Medium","Solution: Thêm Temporal Edges với valid_from/valid_to","FIXED: Added validFrom/validTo fields to Relation struct","FIXED: New tool get_relations_at_time - query relations valid at specific timestamp","FIXED: New tool get_relation_history - view all relations including expired ones","Total tools now: 16 (9 memory + 3 query + 2 temporal + 1 time + 1 summarize)","Status: Resolved"],"createdAt":1768062649,"updatedAt":1768063434} {"name":"Bug: Context Window Killer","entityType":"Bug","observations":["read_graph dump toàn bộ graph","Với 10000 nodes sẽ tràn context window của LLM","Priority: High","Solution: Implement Graph RAG, chỉ lấy relevant nodes + 1-hop neighbors","FIXED: Added limit/offset to read_graph for pagination","FIXED: Added limit/includeRelations to search_nodes","AI agents can now control context size when querying large graphs","Status: Resolved"],"createdAt":1768062649,"updatedAt":1768063183} {"name":"Bug: Semantic Blindness","entityType":"Bug","observations":["search_nodes chỉ dùng text matching đơn thuần","Search 'Coder' không tìm được 'Software Engineer'","Priority: High","Solution: Tích hợp Vector Search với embeddings","FIXED: Implemented synonym dictionary with 15+ groups","Covers: developer roles, bug/issue, feature/task, status, priority, project mgmt, docs, testing, architecture","Search 'coder' now matches 'programmer', 'developer', 'engineer', 'dev'","Zero dependencies - pure Rust, no AI models needed","Status: Resolved"],"createdAt":1768062649,"updatedAt":1768063937} +{"name":"Refactoring: Memory Graph Modularization","entityType":"Decision","observations":["Quyết định tái cấu trúc memory.rs (2505 dòng) thành multi-file architecture","Mục tiêu: Dễ maintain, test, và mở rộng tính năng mới","Giữ nguyên hiệu suất với zero-cost abstractions","5 Phases: Types → Protocol → KnowledgeBase → Tools → Cleanup","✅ Refactoring hoàn thành thành công - 2026-01-11","Từ 1 file memory.rs (2505 dòng) → 35+ files trong src/","All 16 tests passed (7 unit + 8 integration + 1 doc)","Build release thành công với LTO optimization","Cấu trúc mới: types/, protocol/, knowledge_base/, tools/, search/, validation/, utils/, server/"],"createdBy":"Mai Thành Duy An","updatedBy":"Mai Thành Duy An","createdAt":1768095691,"updatedAt":1768096450} +{"name":"Module: src/types","entityType":"Module","observations":["Chứa tất cả data models: Entity, Relation, KnowledgeGraph","Files: mod.rs, entity.rs, relation.rs, graph.rs, observation.rs, traversal.rs, summary.rs"],"createdBy":"Mai Thành Duy An","updatedBy":"Mai Thành Duy An","createdAt":1768095691,"updatedAt":1768095691} +{"name":"Module: src/protocol","entityType":"Module","observations":["Chứa JSON-RPC và MCP protocol types","Files: mod.rs, jsonrpc.rs, mcp.rs","Includes Tool trait definition"],"createdBy":"Mai Thành Duy An","updatedBy":"Mai Thành Duy An","createdAt":1768095691,"updatedAt":1768095691} +{"name":"Module: src/knowledge_base","entityType":"Module","observations":["Core engine với thread-safe operations","Files: mod.rs, crud.rs, query.rs, traversal.rs, temporal.rs, summarize.rs","🔄 RwLock Migration - 2026-01-11","Chuyển từ std::sync::Mutex sang std::sync::RwLock","Lý do: Hệ thống read-heavy (60% read vs 40% write)","Thay đổi mod.rs: Mutex → RwLock","Thay đổi mod.rs: .lock() → .read() cho load_graph()","Thay đổi crud.rs: 6x .lock() → .write() cho CRUD operations","Performance benefit: Cho phép multiple concurrent readers","Test result: 16/16 tests passed including concurrent access tests","Documentation: docs/Proposed-RwLock.md"],"createdBy":"Mai Thành Duy An","updatedBy":"Mai Thành Duy An","createdAt":1768095691,"updatedAt":1768098901} +{"name":"Module: src/tools","entityType":"Module","observations":["15 MCP tools organized by category","Submodules: memory/ (9 tools), query/ (3 tools), temporal/ (3 tools)"],"createdBy":"Mai Thành Duy An","updatedBy":"Mai Thành Duy An","createdAt":1768095691,"updatedAt":1768095691} +{"name":"Decision: Mutex to RwLock Migration","entityType":"Decision","observations":["Date: 2026-01-11","Status: Completed","Risk Level: LOW","Affected files: src/knowledge_base/mod.rs, src/knowledge_base/crud.rs","Rationale: System is read-heavy (search, traverse, get_related = 60% of operations)","Before: Mutex blocks all threads even for read-only operations","After: RwLock allows multiple concurrent readers, only write blocks","Pattern: Write lock held during entire CRUD + persist_to_file()","Validation: All 16 tests passed including test_concurrent_access and test_concurrent_read_write","Rollback: Simple - change RwLock back to Mutex if issues arise"],"createdBy":"Mai Thành Duy An","updatedBy":"Mai Thành Duy An","createdAt":1768098901,"updatedAt":1768098901} {"from":"tiach","to":"Memory Graph MCP Server","relationType":"develops"} {"from":"Memory Graph MCP Server","to":"Feature: Create Entities","relationType":"implements"} {"from":"Memory Graph MCP Server","to":"Feature: Create Relations","relationType":"implements"} @@ -76,3 +82,5 @@ {"from":"Bug: Update Paradox","to":"Memory Graph MCP Server","relationType":"affects","createdAt":1768062654} {"from":"Bug: Context Window Killer","to":"Memory Graph MCP Server","relationType":"affects","createdAt":1768062654} {"from":"Bug: Semantic Blindness","to":"Memory Graph MCP Server","relationType":"affects","createdAt":1768062654} +{"from":"Decision: Mutex to RwLock Migration","to":"Module: src/knowledge_base","relationType":"affects","createdBy":"Mai Thành Duy An","createdAt":1768098907} +{"from":"Memory Graph MCP Server","to":"Decision: Mutex to RwLock Migration","relationType":"implements","createdBy":"Mai Thành Duy An","createdAt":1768098907} diff --git a/memory.rs b/memory.rs deleted file mode 100644 index 40134c5..0000000 --- a/memory.rs +++ /dev/null @@ -1,2504 +0,0 @@ -//! Memory Graph MCP Server - Single File Implementation -//! A knowledge graph server implementing the Model Context Protocol (MCP) -//! using pure Rust with minimal dependencies. - -use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; -use std::collections::{HashMap, HashSet}; -use std::env; -use std::fs; -use std::io::{self, BufRead, BufReader, BufWriter, Write}; -use std::path::Path; -use std::process::Command; -use std::sync::Mutex; -use std::time::{SystemTime, UNIX_EPOCH}; - -// ============================================================================ -// Types -// ============================================================================ - -pub type McpResult = Result>; - -// ============================================================================ -// Standard Entity & Relation Types (Soft Validation) -// ============================================================================ - -/// Standard entity types for software project management -const STANDARD_ENTITY_TYPES: &[&str] = &[ - "Project", "Module", "Feature", "Bug", "Decision", - "Requirement", "Milestone", "Risk", "Convention", "Schema", "Person", -]; - -/// Standard relation types for software project management -const STANDARD_RELATION_TYPES: &[&str] = &[ - "contains", "implements", "fixes", "caused_by", "depends_on", - "blocked_by", "assigned_to", "part_of", "relates_to", "supersedes", - "affects", "requires", -]; - -/// Check if entity type is standard, return warning if not -fn validate_entity_type(entity_type: &str) -> Option { - if STANDARD_ENTITY_TYPES.iter().any(|&t| t.eq_ignore_ascii_case(entity_type)) { - None - } else { - Some(format!( - "⚠️ Non-standard entityType '{}'. Recommended: {:?}", - entity_type, STANDARD_ENTITY_TYPES - )) - } -} - -/// Check if relation type is standard, return warning if not -fn validate_relation_type(relation_type: &str) -> Option { - if STANDARD_RELATION_TYPES.iter().any(|&t| t.eq_ignore_ascii_case(relation_type)) { - None - } else { - Some(format!( - "⚠️ Non-standard relationType '{}'. Recommended: {:?}", - relation_type, STANDARD_RELATION_TYPES - )) - } -} - -/// Get current Unix timestamp in seconds -fn current_timestamp() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() -} - -/// Get current user from git config or OS environment -fn get_current_user() -> String { - // 1. Try Git Config (preferred for project context) - if let Ok(output) = Command::new("git").args(["config", "user.name"]).output() { - if output.status.success() { - let name = String::from_utf8_lossy(&output.stdout).trim().to_string(); - if !name.is_empty() { - return name; - } - } - } - - // 2. Try OS Environment Variable - env::var("USER") // Linux/Mac - .or_else(|_| env::var("USERNAME")) // Windows - .unwrap_or_else(|_| "anonymous".to_string()) -} - -/// Default user for serde deserialization -fn default_user() -> String { - "system".to_string() -} - -/// Check if string is empty or "system" (for skip_serializing_if) -fn is_default_user(val: &str) -> bool { - val.is_empty() || val == "system" -} - -// ============================================================================ -// Synonym Dictionary for Semantic Search -// ============================================================================ - -/// Synonym groups - words in same group are considered semantically similar -const SYNONYM_GROUPS: &[&[&str]] = &[ - // Developer roles - &["coder", "programmer", "developer", "engineer", "dev", "software engineer", "software developer"], - &["frontend", "front-end", "ui developer", "client-side"], - &["backend", "back-end", "server-side", "api developer"], - &["fullstack", "full-stack", "full stack"], - &["devops", "sre", "infrastructure", "platform engineer"], - - // Bug/Issue related - &["bug", "issue", "defect", "error", "problem", "fault", "glitch"], - &["fix", "patch", "hotfix", "bugfix", "repair", "resolve"], - - // Feature/Task related - &["feature", "functionality", "capability", "enhancement"], - &["task", "ticket", "work item", "story", "user story"], - &["requirement", "spec", "specification", "req"], - - // Status - &["done", "completed", "finished", "resolved", "closed"], - &["pending", "waiting", "blocked", "on hold"], - &["in progress", "wip", "ongoing", "active", "working"], - &["todo", "to do", "planned", "backlog"], - - // Priority - &["critical", "urgent", "p0", "blocker", "showstopper"], - &["high", "important", "p1"], - &["medium", "normal", "p2"], - &["low", "minor", "p3"], - - // Project management - &["milestone", "release", "version", "sprint"], - &["deadline", "due date", "target date"], - &["project", "repo", "repository", "codebase"], - - // Documentation - &["doc", "docs", "documentation", "readme", "guide"], - &["api", "interface", "endpoint"], - - // Testing - &["test", "testing", "qa", "quality assurance"], - &["unit test", "unittest"], - &["integration test", "e2e", "end-to-end"], - - // Architecture - &["module", "component", "service", "package"], - &["database", "db", "datastore", "storage"], - &["cache", "caching", "redis", "memcached"], -]; - -/// Get all synonyms for a query term -fn get_synonyms(query: &str) -> Vec { - let query_lower = query.to_lowercase(); - let mut synonyms = vec![query_lower.clone()]; - - for group in SYNONYM_GROUPS { - if group.iter().any(|&word| word == query_lower || query_lower.contains(word) || word.contains(&query_lower)) { - for &word in *group { - if !synonyms.contains(&word.to_string()) { - synonyms.push(word.to_string()); - } - } - } - } - - synonyms -} - -/// Check if text matches any of the search terms (including synonyms) -fn matches_with_synonyms(text: &str, search_terms: &[String]) -> bool { - let text_lower = text.to_lowercase(); - search_terms.iter().any(|term| text_lower.contains(term)) -} - -/// Entity in the knowledge graph -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Entity { - pub name: String, - #[serde(rename = "entityType")] - pub entity_type: String, - #[serde(default)] - pub observations: Vec, - #[serde(rename = "createdBy", default = "default_user", skip_serializing_if = "is_default_user")] - pub created_by: String, - #[serde(rename = "updatedBy", default = "default_user", skip_serializing_if = "is_default_user")] - pub updated_by: String, - #[serde(rename = "createdAt", default, skip_serializing_if = "is_zero")] - pub created_at: u64, - #[serde(rename = "updatedAt", default, skip_serializing_if = "is_zero")] - pub updated_at: u64, -} - -fn is_zero(val: &u64) -> bool { - *val == 0 -} - -/// Relation between entities with temporal validity -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Relation { - pub from: String, - pub to: String, - #[serde(rename = "relationType")] - pub relation_type: String, - #[serde(rename = "createdBy", default = "default_user", skip_serializing_if = "is_default_user")] - pub created_by: String, - #[serde(rename = "createdAt", default, skip_serializing_if = "is_zero")] - pub created_at: u64, - #[serde(rename = "validFrom", default, skip_serializing_if = "Option::is_none")] - pub valid_from: Option, - #[serde(rename = "validTo", default, skip_serializing_if = "Option::is_none")] - pub valid_to: Option, -} - -/// Knowledge graph containing entities and relations -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct KnowledgeGraph { - #[serde(default)] - pub entities: Vec, - #[serde(default)] - pub relations: Vec, -} - -/// Observation to add to an entity -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Observation { - #[serde(rename = "entityName")] - pub entity_name: String, - pub contents: Vec, -} - -/// Observation deletion request -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ObservationDeletion { - #[serde(rename = "entityName")] - pub entity_name: String, - pub observations: Vec, -} - -/// Related entity with relation info -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RelatedEntity { - #[serde(rename = "relationType")] - pub relation_type: String, - pub direction: String, - pub entity: Entity, -} - -/// Result of get_related query -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RelatedEntities { - pub entity: String, - pub relations: Vec, -} - -/// Path step for traverse query -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PathStep { - #[serde(rename = "relationType")] - pub relation_type: String, - pub direction: String, - #[serde(rename = "targetType")] - pub target_type: Option, -} - -/// Single path in traversal result -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TraversalPath { - pub nodes: Vec, - pub relations: Vec, -} - -/// Result of traverse query -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TraversalResult { - #[serde(rename = "startNode")] - pub start_node: String, - pub paths: Vec, - #[serde(rename = "endNodes")] - pub end_nodes: Vec, -} - -/// Summary statistics -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct Summary { - #[serde(rename = "totalEntities")] - pub total_entities: usize, - #[serde(skip_serializing_if = "Option::is_none")] - pub entities: Option>, - #[serde(rename = "byStatus", skip_serializing_if = "Option::is_none")] - pub by_status: Option>, - #[serde(rename = "byType", skip_serializing_if = "Option::is_none")] - pub by_type: Option>, - #[serde(rename = "byPriority", skip_serializing_if = "Option::is_none")] - pub by_priority: Option>, -} - -/// Brief entity info for summary -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EntityBrief { - pub name: String, - #[serde(rename = "entityType")] - pub entity_type: String, - pub brief: String, -} - -// ============================================================================ -// JSON-RPC Types -// ============================================================================ - -#[derive(Deserialize, Debug, Clone)] -pub struct JsonRpcRequest { - pub jsonrpc: String, - pub id: Option, - pub method: String, - pub params: Option, -} - -#[derive(Serialize, Debug)] -pub struct JsonRpcResponse { - pub jsonrpc: String, - pub id: Value, - pub result: Value, -} - -#[derive(Serialize, Debug)] -pub struct JsonRpcError { - pub jsonrpc: String, - pub id: Value, - pub error: ErrorObject, -} - -#[derive(Serialize, Debug)] -pub struct ErrorObject { - pub code: i32, - pub message: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, -} - -// ============================================================================ -// MCP Types -// ============================================================================ - -#[derive(Serialize, Debug)] -pub struct McpTool { - pub name: String, - pub description: String, - #[serde(rename = "inputSchema")] - pub input_schema: Value, -} - -#[derive(Clone)] -pub struct ServerInfo { - pub name: String, - pub version: String, -} - -// ============================================================================ -// Tool Trait -// ============================================================================ - -pub trait Tool: Send + Sync { - fn definition(&self) -> McpTool; - fn execute(&self, params: Value) -> McpResult; -} - -// ============================================================================ -// Knowledge Base -// ============================================================================ - -/// Knowledge base with in-memory cache for thread-safe operations -pub struct KnowledgeBase { - memory_file_path: String, - graph: Mutex, - current_user: String, -} - -impl KnowledgeBase { - pub fn new() -> Self { - let current_dir = env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")); - let default_memory_path = current_dir.join("memory.jsonl"); - - let memory_file_path = match env::var("MEMORY_FILE_PATH") { - Ok(path) => { - if Path::new(&path).is_absolute() { - path - } else { - current_dir.join(path).to_string_lossy().to_string() - } - } - Err(_) => default_memory_path.to_string_lossy().to_string(), - }; - - // Detect current user once at startup - let current_user = get_current_user(); - - // Load graph from file at startup (or create empty if not exists) - let graph = Self::load_graph_from_file(&memory_file_path).unwrap_or_default(); - - Self { - memory_file_path, - graph: Mutex::new(graph), - current_user, - } - } - - /// Load graph from file (static helper for initialization) - fn load_graph_from_file(file_path: &str) -> McpResult { - if !Path::new(file_path).exists() { - return Ok(KnowledgeGraph::default()); - } - - let content = fs::read_to_string(file_path)?; - let mut graph = KnowledgeGraph::default(); - - for line in content.lines() { - let line = line.trim(); - if line.is_empty() { - continue; - } - - if let Ok(entity) = serde_json::from_str::(line) { - if !entity.name.is_empty() && !entity.entity_type.is_empty() { - graph.entities.push(entity); - continue; - } - } - - if let Ok(relation) = serde_json::from_str::(line) { - if !relation.from.is_empty() && !relation.to.is_empty() { - graph.relations.push(relation); - } - } - } - - Ok(graph) - } - - /// Get a clone of the current graph (thread-safe read) - fn load_graph(&self) -> McpResult { - Ok(self.graph.lock().unwrap().clone()) - } - - /// Persist graph to file (internal helper, expects caller to hold lock) - fn persist_to_file(&self, graph: &KnowledgeGraph) -> McpResult<()> { - // Ensure parent directory exists - if let Some(parent) = Path::new(&self.memory_file_path).parent() { - fs::create_dir_all(parent)?; - } - - let mut content = String::new(); - - for entity in &graph.entities { - content.push_str(&serde_json::to_string(entity)?); - content.push('\n'); - } - - for relation in &graph.relations { - content.push_str(&serde_json::to_string(relation)?); - content.push('\n'); - } - - fs::write(&self.memory_file_path, content)?; - Ok(()) - } - - /// Create new entities (thread-safe: holds lock during entire operation) - pub fn create_entities(&self, entities: Vec) -> McpResult> { - let mut graph = self.graph.lock().unwrap(); - let existing_names: HashSet = graph.entities.iter().map(|e| e.name.clone()).collect(); - let now = current_timestamp(); - - let mut created = Vec::new(); - for mut entity in entities { - if !existing_names.contains(&entity.name) { - // Auto-fill user info if not provided - if entity.created_by.is_empty() || entity.created_by == "system" { - entity.created_by = self.current_user.clone(); - } - if entity.updated_by.is_empty() || entity.updated_by == "system" { - entity.updated_by = self.current_user.clone(); - } - entity.created_at = now; - entity.updated_at = now; - created.push(entity.clone()); - graph.entities.push(entity); - } - } - - self.persist_to_file(&graph)?; - Ok(created) - } - - /// Create new relations (thread-safe: holds lock during entire operation) - pub fn create_relations(&self, relations: Vec) -> McpResult> { - let mut graph = self.graph.lock().unwrap(); - let entity_names: HashSet = graph.entities.iter().map(|e| e.name.clone()).collect(); - let now = current_timestamp(); - - let existing_relations: HashSet = graph.relations - .iter() - .map(|r| format!("{}|{}|{}", r.from, r.to, r.relation_type)) - .collect(); - - let mut created = Vec::new(); - for mut relation in relations { - if entity_names.contains(&relation.from) && entity_names.contains(&relation.to) { - let key = format!("{}|{}|{}", relation.from, relation.to, relation.relation_type); - if !existing_relations.contains(&key) { - // Auto-fill user info if not provided - if relation.created_by.is_empty() || relation.created_by == "system" { - relation.created_by = self.current_user.clone(); - } - relation.created_at = now; - created.push(relation.clone()); - graph.relations.push(relation); - } - } - } - - self.persist_to_file(&graph)?; - Ok(created) - } - - /// Add observations to entities (thread-safe: holds lock during entire operation) - pub fn add_observations(&self, observations: Vec) -> McpResult> { - let mut graph = self.graph.lock().unwrap(); - let mut added = Vec::new(); - let now = current_timestamp(); - - for obs in observations { - if let Some(entity) = graph.entities.iter_mut().find(|e| e.name == obs.entity_name) { - let existing: HashSet = entity.observations.iter().cloned().collect(); - let mut new_contents = Vec::new(); - - for content in &obs.contents { - if !existing.contains(content) { - entity.observations.push(content.clone()); - new_contents.push(content.clone()); - } - } - - if !new_contents.is_empty() { - entity.updated_at = now; - entity.updated_by = self.current_user.clone(); - added.push(Observation { - entity_name: obs.entity_name.clone(), - contents: new_contents, - }); - } - } - } - - self.persist_to_file(&graph)?; - Ok(added) - } - - /// Delete entities (thread-safe: holds lock during entire operation) - pub fn delete_entities(&self, entity_names: Vec) -> McpResult<()> { - let mut graph = self.graph.lock().unwrap(); - let names_to_delete: HashSet = entity_names.into_iter().collect(); - - graph.entities.retain(|e| !names_to_delete.contains(&e.name)); - graph.relations.retain(|r| { - !names_to_delete.contains(&r.from) && !names_to_delete.contains(&r.to) - }); - - self.persist_to_file(&graph)?; - Ok(()) - } - - /// Delete observations from entities (thread-safe: holds lock during entire operation) - pub fn delete_observations(&self, deletions: Vec) -> McpResult<()> { - let mut graph = self.graph.lock().unwrap(); - - for deletion in deletions { - if let Some(entity) = graph.entities.iter_mut().find(|e| e.name == deletion.entity_name) { - let to_remove: HashSet = deletion.observations.into_iter().collect(); - entity.observations.retain(|o| !to_remove.contains(o)); - } - } - - self.persist_to_file(&graph)?; - Ok(()) - } - - /// Delete relations (thread-safe: holds lock during entire operation) - pub fn delete_relations(&self, relations: Vec) -> McpResult<()> { - let mut graph = self.graph.lock().unwrap(); - - let to_delete: HashSet = relations - .iter() - .map(|r| format!("{}|{}|{}", r.from, r.to, r.relation_type)) - .collect(); - - graph.relations.retain(|r| { - let key = format!("{}|{}|{}", r.from, r.to, r.relation_type); - !to_delete.contains(&key) - }); - - self.persist_to_file(&graph)?; - Ok(()) - } - - /// Read graph with optional pagination - pub fn read_graph(&self, limit: Option, offset: Option) -> McpResult { - let graph = self.load_graph()?; - - let offset = offset.unwrap_or(0); - - let entities: Vec = if let Some(lim) = limit { - graph.entities.into_iter().skip(offset).take(lim).collect() - } else { - graph.entities.into_iter().skip(offset).collect() - }; - - let entity_names: HashSet = entities.iter().map(|e| e.name.clone()).collect(); - - let relations: Vec = graph.relations - .into_iter() - .filter(|r| entity_names.contains(&r.from) || entity_names.contains(&r.to)) - .collect(); - - Ok(KnowledgeGraph { entities, relations }) - } - - /// Search nodes by query with synonym expansion, optional limit and relation inclusion - pub fn search_nodes(&self, query: &str, limit: Option, include_relations: bool) -> McpResult { - let graph = self.load_graph()?; - - // Expand query with synonyms for semantic matching - let search_terms = get_synonyms(query); - - let mut matching_entities: Vec = graph.entities - .into_iter() - .filter(|e| { - matches_with_synonyms(&e.name, &search_terms) || - matches_with_synonyms(&e.entity_type, &search_terms) || - e.observations.iter().any(|o| matches_with_synonyms(o, &search_terms)) - }) - .collect(); - - // Apply limit if specified - if let Some(lim) = limit { - matching_entities.truncate(lim); - } - - let matching_relations = if include_relations { - let entity_names: HashSet = matching_entities.iter().map(|e| e.name.clone()).collect(); - graph.relations - .into_iter() - .filter(|r| entity_names.contains(&r.from) || entity_names.contains(&r.to)) - .collect() - } else { - Vec::new() - }; - - Ok(KnowledgeGraph { - entities: matching_entities, - relations: matching_relations, - }) - } - - /// Open specific nodes by names - pub fn open_nodes(&self, names: Vec) -> McpResult { - let graph = self.load_graph()?; - let name_set: HashSet = names.into_iter().collect(); - - let matching_entities: Vec = graph.entities - .into_iter() - .filter(|e| name_set.contains(&e.name)) - .collect(); - - let entity_names: HashSet = matching_entities.iter().map(|e| e.name.clone()).collect(); - - let matching_relations: Vec = graph.relations - .into_iter() - .filter(|r| entity_names.contains(&r.from) && entity_names.contains(&r.to)) - .collect(); - - Ok(KnowledgeGraph { - entities: matching_entities, - relations: matching_relations, - }) - } - - /// Get related entities - pub fn get_related( - &self, - entity_name: &str, - relation_type: Option<&str>, - direction: &str, - ) -> McpResult { - let graph = self.load_graph()?; - let mut related = Vec::new(); - - for relation in &graph.relations { - let matches = match direction { - "outgoing" => relation.from == entity_name, - "incoming" => relation.to == entity_name, - "both" => relation.from == entity_name || relation.to == entity_name, - _ => false, - }; - - if !matches { - continue; - } - - if let Some(rt) = relation_type { - if relation.relation_type != rt { - continue; - } - } - - let target_name = if relation.from == entity_name { - &relation.to - } else { - &relation.from - }; - - if let Some(entity) = graph.entities.iter().find(|e| e.name == *target_name) { - related.push(RelatedEntity { - relation_type: relation.relation_type.clone(), - direction: if relation.from == entity_name { - "outgoing".to_string() - } else { - "incoming".to_string() - }, - entity: entity.clone(), - }); - } - } - - Ok(RelatedEntities { - entity: entity_name.to_string(), - relations: related, - }) - } - - /// Traverse graph following path pattern - pub fn traverse( - &self, - start: &str, - path: Vec, - max_results: usize, - ) -> McpResult { - let graph = self.load_graph()?; - - // Track paths: (current_node, path_so_far, relations_so_far) - let mut current_paths: Vec<(String, Vec, Vec)> = - vec![(start.to_string(), vec![start.to_string()], vec![])]; - - for step in &path { - let mut next_paths = Vec::new(); - - for (node, nodes_path, rels_path) in ¤t_paths { - // Find related entities for this step - for relation in &graph.relations { - let (matches, target_name) = match step.direction.as_str() { - "out" => { - if relation.from == *node && relation.relation_type == step.relation_type - { - (true, &relation.to) - } else { - (false, &relation.to) - } - } - "in" => { - if relation.to == *node && relation.relation_type == step.relation_type { - (true, &relation.from) - } else { - (false, &relation.from) - } - } - _ => (false, &relation.to), - }; - - if !matches { - continue; - } - - // Check target type if specified - if let Some(ref target_type) = step.target_type { - if let Some(entity) = graph.entities.iter().find(|e| e.name == *target_name) - { - if &entity.entity_type != target_type { - continue; - } - } else { - continue; - } - } - - let mut new_nodes = nodes_path.clone(); - new_nodes.push(target_name.clone()); - let mut new_rels = rels_path.clone(); - new_rels.push(step.relation_type.clone()); - - next_paths.push((target_name.clone(), new_nodes, new_rels)); - } - } - - if next_paths.len() > max_results { - next_paths.truncate(max_results); - } - - current_paths = next_paths; - } - - // Build result - let mut paths = Vec::new(); - let mut end_node_names: HashSet = HashSet::new(); - - for (end_node, nodes, rels) in current_paths { - end_node_names.insert(end_node); - paths.push(TraversalPath { - nodes, - relations: rels, - }); - } - - let end_nodes: Vec = graph - .entities - .iter() - .filter(|e| end_node_names.contains(&e.name)) - .cloned() - .collect(); - - Ok(TraversalResult { - start_node: start.to_string(), - paths, - end_nodes, - }) - } - - /// Summarize entities - pub fn summarize( - &self, - entity_names: Option>, - entity_type: Option, - format: &str, - ) -> McpResult { - let graph = self.load_graph()?; - - let entities: Vec<&Entity> = graph - .entities - .iter() - .filter(|e| { - if let Some(ref names) = entity_names { - names.contains(&e.name) - } else if let Some(ref et) = entity_type { - &e.entity_type == et - } else { - true - } - }) - .collect(); - - match format { - "brief" => self.format_brief(&entities), - "detailed" => self.format_detailed(&entities), - "stats" => self.format_stats(&entities), - _ => self.format_brief(&entities), - } - } - - fn format_brief(&self, entities: &[&Entity]) -> McpResult { - let briefs: Vec = entities - .iter() - .map(|e| { - let brief = e - .observations - .first() - .cloned() - .unwrap_or_default() - .chars() - .take(100) - .collect::(); - EntityBrief { - name: e.name.clone(), - entity_type: e.entity_type.clone(), - brief, - } - }) - .collect(); - - Ok(Summary { - total_entities: entities.len(), - entities: Some(briefs), - ..Default::default() - }) - } - - fn format_detailed(&self, entities: &[&Entity]) -> McpResult { - let briefs: Vec = entities - .iter() - .map(|e| { - let brief = e.observations.join("; "); - EntityBrief { - name: e.name.clone(), - entity_type: e.entity_type.clone(), - brief, - } - }) - .collect(); - - Ok(Summary { - total_entities: entities.len(), - entities: Some(briefs), - ..Default::default() - }) - } - - fn format_stats(&self, entities: &[&Entity]) -> McpResult { - let mut by_status: HashMap = HashMap::new(); - let mut by_type: HashMap = HashMap::new(); - let mut by_priority: HashMap = HashMap::new(); - - for entity in entities { - *by_type.entry(entity.entity_type.clone()).or_insert(0) += 1; - - for obs in &entity.observations { - if obs.starts_with("Status:") { - let status = obs.trim_start_matches("Status:").trim().to_string(); - *by_status.entry(status).or_insert(0) += 1; - } - if obs.starts_with("Priority:") { - let priority = obs.trim_start_matches("Priority:").trim().to_string(); - *by_priority.entry(priority).or_insert(0) += 1; - } - } - } - - Ok(Summary { - total_entities: entities.len(), - entities: None, - by_status: if by_status.is_empty() { - None - } else { - Some(by_status) - }, - by_type: Some(by_type), - by_priority: if by_priority.is_empty() { - None - } else { - Some(by_priority) - }, - }) - } - - /// Get relations valid at a specific point in time - pub fn get_relations_at_time(&self, timestamp: Option, entity_name: Option<&str>) -> McpResult> { - let graph = self.load_graph()?; - let check_time = timestamp.unwrap_or_else(current_timestamp); - - let relations: Vec = graph.relations - .into_iter() - .filter(|r| { - // Filter by entity if specified - if let Some(name) = entity_name { - if r.from != name && r.to != name { - return false; - } - } - - // Check temporal validity - let valid_from_ok = match r.valid_from { - Some(vf) => check_time >= vf, - None => true, // No start time means always valid from past - }; - - let valid_to_ok = match r.valid_to { - Some(vt) => check_time <= vt, - None => true, // No end time means still valid - }; - - valid_from_ok && valid_to_ok - }) - .collect(); - - Ok(relations) - } - - /// Get historical relations (including expired ones) - pub fn get_relation_history(&self, entity_name: &str) -> McpResult> { - let graph = self.load_graph()?; - - let relations: Vec = graph.relations - .into_iter() - .filter(|r| r.from == entity_name || r.to == entity_name) - .collect(); - - Ok(relations) - } -} - -// ============================================================================ -// Memory Tools Implementation -// ============================================================================ - -pub struct CreateEntitiesTool { - kb: std::sync::Arc, -} - -impl CreateEntitiesTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for CreateEntitiesTool { - fn definition(&self) -> McpTool { - McpTool { - name: "create_entities".to_string(), - description: "Create multiple new entities in the knowledge graph".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "entities": { - "type": "array", - "items": { - "type": "object", - "properties": { - "name": { "type": "string", "description": "The name of the entity" }, - "entityType": { "type": "string", "description": "The type of the entity" }, - "observations": { - "type": "array", - "items": { "type": "string" }, - "description": "Initial observations about the entity" - }, - "createdBy": { "type": "string", "description": "Who created this entity (auto-filled from git/env if not provided)" }, - "updatedBy": { "type": "string", "description": "Who last updated this entity (auto-filled from git/env if not provided)" } - }, - "required": ["name", "entityType"] - } - } - }, - "required": ["entities"] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let entities: Vec = serde_json::from_value( - params.get("entities").cloned().unwrap_or(json!([])) - )?; - - // Collect warnings for non-standard types - let warnings: Vec = entities.iter() - .filter_map(|e| validate_entity_type(&e.entity_type)) - .collect(); - - let created = self.kb.create_entities(entities)?; - - let response = if warnings.is_empty() { - serde_json::to_string_pretty(&created)? - } else { - format!("{}\n\n{}", serde_json::to_string_pretty(&created)?, warnings.join("\n")) - }; - - Ok(json!({ - "content": [{ - "type": "text", - "text": response - }] - })) - } -} - -// ---------------------------------------------------------------------------- - -pub struct CreateRelationsTool { - kb: std::sync::Arc, -} - -impl CreateRelationsTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for CreateRelationsTool { - fn definition(&self) -> McpTool { - McpTool { - name: "create_relations".to_string(), - description: "Create multiple new relations between entities in the knowledge graph".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "relations": { - "type": "array", - "items": { - "type": "object", - "properties": { - "from": { "type": "string", "description": "The source entity name" }, - "to": { "type": "string", "description": "The target entity name" }, - "relationType": { "type": "string", "description": "The type of relation" }, - "createdBy": { "type": "string", "description": "Who created this relation (auto-filled from git/env if not provided)" } - }, - "required": ["from", "to", "relationType"] - } - } - }, - "required": ["relations"] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let relations: Vec = serde_json::from_value( - params.get("relations").cloned().unwrap_or(json!([])) - )?; - - // Collect warnings for non-standard relation types - let warnings: Vec = relations.iter() - .filter_map(|r| validate_relation_type(&r.relation_type)) - .collect(); - - let created = self.kb.create_relations(relations)?; - - let response = if warnings.is_empty() { - serde_json::to_string_pretty(&created)? - } else { - format!("{}\n\n{}", serde_json::to_string_pretty(&created)?, warnings.join("\n")) - }; - - Ok(json!({ - "content": [{ - "type": "text", - "text": response - }] - })) - } -} - -// ---------------------------------------------------------------------------- - -pub struct AddObservationsTool { - kb: std::sync::Arc, -} - -impl AddObservationsTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for AddObservationsTool { - fn definition(&self) -> McpTool { - McpTool { - name: "add_observations".to_string(), - description: "Add new observations to existing entities in the knowledge graph".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "observations": { - "type": "array", - "items": { - "type": "object", - "properties": { - "entityName": { "type": "string", "description": "The name of the entity" }, - "contents": { - "type": "array", - "items": { "type": "string" }, - "description": "Observation contents to add" - } - }, - "required": ["entityName", "contents"] - } - } - }, - "required": ["observations"] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let observations: Vec = serde_json::from_value( - params.get("observations").cloned().unwrap_or(json!([])) - )?; - let added = self.kb.add_observations(observations)?; - Ok(json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string_pretty(&added)? - }] - })) - } -} - -// ---------------------------------------------------------------------------- - -pub struct DeleteEntitiesTool { - kb: std::sync::Arc, -} - -impl DeleteEntitiesTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for DeleteEntitiesTool { - fn definition(&self) -> McpTool { - McpTool { - name: "delete_entities".to_string(), - description: "Delete multiple entities from the knowledge graph".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "entityNames": { - "type": "array", - "items": { "type": "string" }, - "description": "An array of entity names to delete" - } - }, - "required": ["entityNames"] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let entity_names: Vec = serde_json::from_value( - params.get("entityNames").cloned().unwrap_or(json!([])) - )?; - self.kb.delete_entities(entity_names)?; - Ok(json!({ - "content": [{ - "type": "text", - "text": "Entities deleted successfully" - }] - })) - } -} - -// ---------------------------------------------------------------------------- - -pub struct DeleteObservationsTool { - kb: std::sync::Arc, -} - -impl DeleteObservationsTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for DeleteObservationsTool { - fn definition(&self) -> McpTool { - McpTool { - name: "delete_observations".to_string(), - description: "Delete specific observations from entities in the knowledge graph".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "deletions": { - "type": "array", - "items": { - "type": "object", - "properties": { - "entityName": { "type": "string", "description": "The name of the entity" }, - "observations": { - "type": "array", - "items": { "type": "string" }, - "description": "Observations to delete" - } - }, - "required": ["entityName", "observations"] - } - } - }, - "required": ["deletions"] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let deletions: Vec = serde_json::from_value( - params.get("deletions").cloned().unwrap_or(json!([])) - )?; - self.kb.delete_observations(deletions)?; - Ok(json!({ - "content": [{ - "type": "text", - "text": "Observations deleted successfully" - }] - })) - } -} - -// ---------------------------------------------------------------------------- - -pub struct DeleteRelationsTool { - kb: std::sync::Arc, -} - -impl DeleteRelationsTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for DeleteRelationsTool { - fn definition(&self) -> McpTool { - McpTool { - name: "delete_relations".to_string(), - description: "Delete multiple relations from the knowledge graph".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "relations": { - "type": "array", - "items": { - "type": "object", - "properties": { - "from": { "type": "string", "description": "The source entity name" }, - "to": { "type": "string", "description": "The target entity name" }, - "relationType": { "type": "string", "description": "The type of relation" } - }, - "required": ["from", "to", "relationType"] - } - } - }, - "required": ["relations"] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let relations: Vec = serde_json::from_value( - params.get("relations").cloned().unwrap_or(json!([])) - )?; - self.kb.delete_relations(relations)?; - Ok(json!({ - "content": [{ - "type": "text", - "text": "Relations deleted successfully" - }] - })) - } -} - -// ---------------------------------------------------------------------------- - -pub struct ReadGraphTool { - kb: std::sync::Arc, -} - -impl ReadGraphTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for ReadGraphTool { - fn definition(&self) -> McpTool { - McpTool { - name: "read_graph".to_string(), - description: "Read the knowledge graph with optional pagination. Use limit/offset to avoid context overflow with large graphs.".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "limit": { - "type": "integer", - "description": "Maximum number of entities to return. Recommended: 50-100 for large graphs" - }, - "offset": { - "type": "integer", - "description": "Number of entities to skip (for pagination)" - } - }, - "required": [] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let limit = params.get("limit").and_then(|v| v.as_u64()).map(|v| v as usize); - let offset = params.get("offset").and_then(|v| v.as_u64()).map(|v| v as usize); - let graph = self.kb.read_graph(limit, offset)?; - - let total_msg = if limit.is_some() || offset.is_some() { - format!(" (showing {} entities)", graph.entities.len()) - } else { - String::new() - }; - - Ok(json!({ - "content": [{ - "type": "text", - "text": format!("{}{}", serde_json::to_string_pretty(&graph)?, total_msg) - }] - })) - } -} - -// ---------------------------------------------------------------------------- - -pub struct SearchNodesTool { - kb: std::sync::Arc, -} - -impl SearchNodesTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for SearchNodesTool { - fn definition(&self) -> McpTool { - McpTool { - name: "search_nodes".to_string(), - description: "Search for nodes in the knowledge graph. Returns matching entities with optional relations.".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The search query to match against entity names, types, and observations" - }, - "limit": { - "type": "integer", - "description": "Maximum number of entities to return (default: no limit)" - }, - "includeRelations": { - "type": "boolean", - "description": "Whether to include relations connected to matching entities (default: true)" - } - }, - "required": ["query"] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let query = params.get("query") - .and_then(|v| v.as_str()) - .unwrap_or(""); - let limit = params.get("limit").and_then(|v| v.as_u64()).map(|v| v as usize); - let include_relations = params.get("includeRelations") - .and_then(|v| v.as_bool()) - .unwrap_or(true); - - let graph = self.kb.search_nodes(query, limit, include_relations)?; - Ok(json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string_pretty(&graph)? - }] - })) - } -} - -// ---------------------------------------------------------------------------- - -pub struct OpenNodesTool { - kb: std::sync::Arc, -} - -impl OpenNodesTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for OpenNodesTool { - fn definition(&self) -> McpTool { - McpTool { - name: "open_nodes".to_string(), - description: "Open specific nodes in the knowledge graph by their names".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "names": { - "type": "array", - "items": { "type": "string" }, - "description": "An array of entity names to retrieve" - } - }, - "required": ["names"] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let names: Vec = serde_json::from_value( - params.get("names").cloned().unwrap_or(json!([])) - )?; - let graph = self.kb.open_nodes(names)?; - Ok(json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string_pretty(&graph)? - }] - })) - } -} - -// ---------------------------------------------------------------------------- -// Get Related Tool -// ---------------------------------------------------------------------------- - -pub struct GetRelatedTool { - kb: std::sync::Arc, -} - -impl GetRelatedTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for GetRelatedTool { - fn definition(&self) -> McpTool { - McpTool { - name: "get_related".to_string(), - description: "Get entities related to a specific entity".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "entityName": { - "type": "string", - "description": "Name of the entity to find relations for" - }, - "relationType": { - "type": "string", - "description": "Filter by relation type (optional)" - }, - "direction": { - "type": "string", - "enum": ["outgoing", "incoming", "both"], - "default": "both", - "description": "Direction of relations" - } - }, - "required": ["entityName"] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let entity_name = params - .get("entityName") - .and_then(|v| v.as_str()) - .ok_or("Missing entityName")?; - let relation_type = params.get("relationType").and_then(|v| v.as_str()); - let direction = params - .get("direction") - .and_then(|v| v.as_str()) - .unwrap_or("both"); - - let related = self.kb.get_related(entity_name, relation_type, direction)?; - Ok(json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string_pretty(&related)? - }] - })) - } -} - -// ---------------------------------------------------------------------------- -// Traverse Tool -// ---------------------------------------------------------------------------- - -pub struct TraverseTool { - kb: std::sync::Arc, -} - -impl TraverseTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for TraverseTool { - fn definition(&self) -> McpTool { - McpTool { - name: "traverse".to_string(), - description: "Traverse the graph following a path pattern for multi-hop queries" - .to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "startNode": { - "type": "string", - "description": "Starting entity name" - }, - "path": { - "type": "array", - "items": { - "type": "object", - "properties": { - "relationType": { - "type": "string", - "description": "Type of relation to follow" - }, - "direction": { - "type": "string", - "enum": ["out", "in"], - "description": "Direction: out (outgoing) or in (incoming)" - }, - "targetType": { - "type": "string", - "description": "Filter by target entity type (optional)" - } - }, - "required": ["relationType", "direction"] - }, - "description": "Path pattern to follow" - }, - "maxResults": { - "type": "integer", - "default": 50, - "description": "Maximum number of results" - } - }, - "required": ["startNode", "path"] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let start_node = params - .get("startNode") - .and_then(|v| v.as_str()) - .ok_or("Missing startNode")?; - - let path: Vec = serde_json::from_value( - params.get("path").cloned().unwrap_or(json!([])) - )?; - - let max_results = params - .get("maxResults") - .and_then(|v| v.as_u64()) - .unwrap_or(50) as usize; - - let result = self.kb.traverse(start_node, path, max_results)?; - Ok(json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string_pretty(&result)? - }] - })) - } -} - -// ---------------------------------------------------------------------------- -// Summarize Tool -// ---------------------------------------------------------------------------- - -pub struct SummarizeTool { - kb: std::sync::Arc, -} - -impl SummarizeTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for SummarizeTool { - fn definition(&self) -> McpTool { - McpTool { - name: "summarize".to_string(), - description: "Get a condensed summary of entities".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "entityNames": { - "type": "array", - "items": { "type": "string" }, - "description": "Specific entities to summarize (optional)" - }, - "entityType": { - "type": "string", - "description": "Summarize all entities of this type (optional)" - }, - "format": { - "type": "string", - "enum": ["brief", "detailed", "stats"], - "default": "brief", - "description": "Output format: brief (first observation), detailed (all observations), stats (statistics)" - } - }, - "required": [] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let entity_names: Option> = params - .get("entityNames") - .and_then(|v| serde_json::from_value(v.clone()).ok()); - - let entity_type = params - .get("entityType") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - let format = params - .get("format") - .and_then(|v| v.as_str()) - .unwrap_or("brief"); - - let summary = self.kb.summarize(entity_names, entity_type, format)?; - Ok(json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string_pretty(&summary)? - }] - })) - } -} - -// ---------------------------------------------------------------------------- -// Time Tool -// ---------------------------------------------------------------------------- - -fn get_current_time() -> Value { - let now = SystemTime::now(); - let duration = now.duration_since(UNIX_EPOCH).unwrap(); - let timestamp = duration.as_secs(); - let millis = duration.as_millis() as u64; - - // Calculate datetime components - let secs = timestamp as i64; - - // Days since epoch - let days = secs / 86400; - let remaining = secs % 86400; - - let hours = remaining / 3600; - let minutes = (remaining % 3600) / 60; - let seconds = remaining % 60; - - // Calculate year, month, day - let (year, month, day) = days_to_ymd(days); - - // Format ISO 8601 - let iso8601 = format!( - "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", - year, month, day, hours, minutes, seconds - ); - - // Format readable - let weekday = get_weekday(days); - let month_name = get_month_name(month); - let readable = format!( - "{}, {} {} {} {:02}:{:02}:{:02} UTC", - weekday, day, month_name, year, hours, minutes, seconds - ); - - json!({ - "timestamp": timestamp, - "timestamp_ms": millis, - "iso8601": iso8601, - "readable": readable, - "components": { - "year": year, - "month": month, - "day": day, - "hour": hours, - "minute": minutes, - "second": seconds, - "weekday": weekday - } - }) -} - -fn days_to_ymd(days: i64) -> (i64, u32, u32) { - // Algorithm to convert days since epoch to year/month/day - let remaining_days = days + 719468; // Days from year 0 to 1970 - - let era = remaining_days / 146097; - let doe = remaining_days - era * 146097; - let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; - let year = yoe + era * 400; - let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); - let mp = (5 * doy + 2) / 153; - let day = (doy - (153 * mp + 2) / 5 + 1) as u32; - let month = if mp < 10 { mp + 3 } else { mp - 9 } as u32; - let year = if month <= 2 { year + 1 } else { year }; - - (year, month, day) -} - -fn get_weekday(days: i64) -> &'static str { - match (days + 4) % 7 { - 0 => "Sunday", - 1 => "Monday", - 2 => "Tuesday", - 3 => "Wednesday", - 4 => "Thursday", - 5 => "Friday", - 6 => "Saturday", - _ => "Unknown", - } -} - -fn get_month_name(month: u32) -> &'static str { - match month { - 1 => "January", - 2 => "February", - 3 => "March", - 4 => "April", - 5 => "May", - 6 => "June", - 7 => "July", - 8 => "August", - 9 => "September", - 10 => "October", - 11 => "November", - 12 => "December", - _ => "Unknown", - } -} - -// ---------------------------------------------------------------------------- -// Temporal Query Tools -// ---------------------------------------------------------------------------- - -pub struct GetRelationsAtTimeTool { - kb: std::sync::Arc, -} - -impl GetRelationsAtTimeTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for GetRelationsAtTimeTool { - fn definition(&self) -> McpTool { - McpTool { - name: "get_relations_at_time".to_string(), - description: "Get relations that are valid at a specific point in time. Useful for querying historical state of the knowledge graph.".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "timestamp": { - "type": "integer", - "description": "Unix timestamp to query. If not provided, uses current time." - }, - "entityName": { - "type": "string", - "description": "Optional: filter relations involving this entity" - } - }, - "required": [] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let timestamp = params.get("timestamp").and_then(|v| v.as_u64()); - let entity_name = params.get("entityName").and_then(|v| v.as_str()); - - let relations = self.kb.get_relations_at_time(timestamp, entity_name)?; - - Ok(json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string_pretty(&json!({ - "queryTime": timestamp.unwrap_or_else(current_timestamp), - "relations": relations - }))? - }] - })) - } -} - -pub struct GetRelationHistoryTool { - kb: std::sync::Arc, -} - -impl GetRelationHistoryTool { - pub fn new(kb: std::sync::Arc) -> Self { - Self { kb } - } -} - -impl Tool for GetRelationHistoryTool { - fn definition(&self) -> McpTool { - McpTool { - name: "get_relation_history".to_string(), - description: "Get all relations (current and historical) for an entity. Shows temporal validity (validFrom/validTo) for each relation.".to_string(), - input_schema: json!({ - "type": "object", - "properties": { - "entityName": { - "type": "string", - "description": "The name of the entity to get relation history for" - } - }, - "required": ["entityName"] - }), - } - } - - fn execute(&self, params: Value) -> McpResult { - let entity_name = params.get("entityName") - .and_then(|v| v.as_str()) - .ok_or("entityName is required")?; - - let relations = self.kb.get_relation_history(entity_name)?; - let current_time = current_timestamp(); - - // Mark each relation as current or historical - let annotated: Vec = relations.iter().map(|r| { - let is_current = match (r.valid_from, r.valid_to) { - (Some(vf), Some(vt)) => current_time >= vf && current_time <= vt, - (Some(vf), None) => current_time >= vf, - (None, Some(vt)) => current_time <= vt, - (None, None) => true, - }; - - json!({ - "from": r.from, - "to": r.to, - "relationType": r.relation_type, - "validFrom": r.valid_from, - "validTo": r.valid_to, - "isCurrent": is_current - }) - }).collect(); - - Ok(json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string_pretty(&json!({ - "entity": entity_name, - "currentTime": current_time, - "relations": annotated - }))? - }] - })) - } -} - -pub struct GetCurrentTimeTool; - -impl GetCurrentTimeTool { - pub fn new() -> Self { - Self - } -} - -impl Tool for GetCurrentTimeTool { - fn definition(&self) -> McpTool { - McpTool { - name: "get_current_time".to_string(), - description: "Get the current datetime and timestamp".to_string(), - input_schema: json!({ - "type": "object", - "properties": {}, - "required": [] - }), - } - } - - fn execute(&self, _params: Value) -> McpResult { - let time_info = get_current_time(); - Ok(json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string_pretty(&time_info)? - }] - })) - } -} - -// ============================================================================ -// MCP Server -// ============================================================================ - -pub struct McpServer { - server_info: ServerInfo, - tools: HashMap>, - reader: BufReader, - writer: BufWriter, -} - -impl McpServer { - pub fn new() -> Self { - Self { - server_info: ServerInfo { - name: "memory".to_string(), - version: "1.0.0".to_string(), - }, - tools: HashMap::new(), - reader: BufReader::new(io::stdin()), - writer: BufWriter::new(io::stdout()), - } - } - - pub fn with_info(info: ServerInfo) -> Self { - Self { - server_info: info, - tools: HashMap::new(), - reader: BufReader::new(io::stdin()), - writer: BufWriter::new(io::stdout()), - } - } - - pub fn register_tool(&mut self, tool: Box) -> &mut Self { - let name = tool.definition().name.clone(); - self.tools.insert(name, tool); - self - } - - pub fn run(&mut self) -> McpResult<()> { - let mut line = String::new(); - while self.reader.read_line(&mut line)? > 0 { - let trimmed = line.trim(); - if !trimmed.is_empty() { - self.handle_request(trimmed)?; - } - line.clear(); - } - Ok(()) - } - - fn handle_request(&mut self, request_str: &str) -> McpResult<()> { - let request: JsonRpcRequest = match serde_json::from_str(request_str) { - Ok(req) => req, - Err(e) => { - self.send_error_response( - Value::Null, - -32700, - "Parse error", - Some(json!({"details": e.to_string()})), - )?; - return Ok(()); - } - }; - - if request.jsonrpc != "2.0" { - self.send_error_response( - request.id.unwrap_or(Value::Null), - -32600, - "Invalid Request", - Some(json!({"details": "jsonrpc must be '2.0'"})), - )?; - return Ok(()); - } - - let id = request.id.clone().unwrap_or(Value::Null); - - match request.method.as_str() { - "initialize" => self.handle_initialize(id, request.params), - "notifications/initialized" => Ok(()), // Notification, no response - "tools/list" => self.handle_tools_list(id), - "tools/call" => self.handle_tool_call(id, request.params), - "ping" => self.send_success_response(id, json!({})), - _ => self.send_error_response( - id, - -32601, - "Method not found", - Some(json!({"method": request.method})), - ), - } - } - - fn handle_initialize(&mut self, id: Value, _params: Option) -> McpResult<()> { - let result = json!({ - "protocolVersion": "2024-11-05", - "capabilities": { - "tools": {} - }, - "serverInfo": { - "name": self.server_info.name, - "version": self.server_info.version - } - }); - self.send_success_response(id, result) - } - - fn handle_tools_list(&mut self, id: Value) -> McpResult<()> { - let tools: Vec = self.tools.values().map(|t| t.definition()).collect(); - let result = json!({ "tools": tools }); - self.send_success_response(id, result) - } - - fn handle_tool_call(&mut self, id: Value, params: Option) -> McpResult<()> { - let params = params.ok_or("Missing parameters")?; - let tool_name = params - .get("name") - .and_then(|v| v.as_str()) - .ok_or("Missing tool name")?; - - let tool = match self.tools.get(tool_name) { - Some(tool) => tool, - None => { - self.send_error_response( - id, - -32602, - "Unknown tool", - Some(json!({"tool": tool_name})), - )?; - return Ok(()); - } - }; - - let arguments = params - .get("arguments") - .cloned() - .unwrap_or(json!({})); - - match tool.execute(arguments) { - Ok(result) => self.send_success_response(id, result), - Err(e) => self.send_error_response( - id, - -32603, - "Tool execution error", - Some(json!({"details": e.to_string()})), - ), - } - } - - fn send_success_response(&mut self, id: Value, result: Value) -> McpResult<()> { - let response = JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id, - result, - }; - let json = serde_json::to_string(&response)?; - writeln!(self.writer, "{}", json)?; - self.writer.flush()?; - Ok(()) - } - - fn send_error_response( - &mut self, - id: Value, - code: i32, - message: &str, - data: Option, - ) -> McpResult<()> { - let response = JsonRpcError { - jsonrpc: "2.0".to_string(), - id, - error: ErrorObject { - code, - message: message.to_string(), - data, - }, - }; - let json = serde_json::to_string(&response)?; - writeln!(self.writer, "{}", json)?; - self.writer.flush()?; - Ok(()) - } -} - -// ============================================================================ -// Main -// ============================================================================ - -fn main() -> McpResult<()> { - let kb = std::sync::Arc::new(KnowledgeBase::new()); - - let server_info = ServerInfo { - name: "memory".to_string(), - version: "1.0.0".to_string(), - }; - - let mut server = McpServer::with_info(server_info); - - // Register all 9 memory tools + 3 query tools + 2 temporal tools + 1 time tool - server.register_tool(Box::new(CreateEntitiesTool::new(kb.clone()))); - server.register_tool(Box::new(CreateRelationsTool::new(kb.clone()))); - server.register_tool(Box::new(AddObservationsTool::new(kb.clone()))); - server.register_tool(Box::new(DeleteEntitiesTool::new(kb.clone()))); - server.register_tool(Box::new(DeleteObservationsTool::new(kb.clone()))); - server.register_tool(Box::new(DeleteRelationsTool::new(kb.clone()))); - server.register_tool(Box::new(ReadGraphTool::new(kb.clone()))); - server.register_tool(Box::new(SearchNodesTool::new(kb.clone()))); - server.register_tool(Box::new(OpenNodesTool::new(kb.clone()))); - // Query tools - server.register_tool(Box::new(GetRelatedTool::new(kb.clone()))); - server.register_tool(Box::new(TraverseTool::new(kb.clone()))); - server.register_tool(Box::new(SummarizeTool::new(kb.clone()))); - // Temporal tools - server.register_tool(Box::new(GetRelationsAtTimeTool::new(kb.clone()))); - server.register_tool(Box::new(GetRelationHistoryTool::new(kb.clone()))); - // Time tool - server.register_tool(Box::new(GetCurrentTimeTool::new())); - - server.run() -} - -// ============================================================================ -// Tests -// ============================================================================ - -#[cfg(test)] -mod tests { - use super::*; - use std::fs; - use std::sync::atomic::{AtomicU64, Ordering}; - - static TEST_COUNTER: AtomicU64 = AtomicU64::new(0); - - fn setup_test_kb() -> (KnowledgeBase, String) { - let id = TEST_COUNTER.fetch_add(1, Ordering::SeqCst); - let temp_file = format!("test_memory_{}_{}.jsonl", std::process::id(), id); - - // Create a KnowledgeBase with explicit file path and empty graph - let kb = KnowledgeBase { - memory_file_path: temp_file.clone(), - graph: Mutex::new(KnowledgeGraph::default()), - current_user: "test_user".to_string(), - }; - (kb, temp_file) - } - - fn cleanup(file_path: &str) { - let _ = fs::remove_file(file_path); - } - - #[test] - fn test_create_entities() { - let (kb, temp_file) = setup_test_kb(); - - let entities = vec![ - Entity { - name: "Alice".to_string(), - entity_type: "Person".to_string(), - observations: vec!["Lives in NYC".to_string()], - created_by: String::new(), - updated_by: String::new(), - created_at: 0, - updated_at: 0, - }, - Entity { - name: "Bob".to_string(), - entity_type: "Person".to_string(), - observations: vec![], - created_by: String::new(), - updated_by: String::new(), - created_at: 0, - updated_at: 0, - }, - ]; - - let created = kb.create_entities(entities).unwrap(); - assert_eq!(created.len(), 2); - // Verify user was auto-filled - assert_eq!(created[0].created_by, "test_user"); - assert_eq!(created[0].updated_by, "test_user"); - - let graph = kb.read_graph(None, None).unwrap(); - assert_eq!(graph.entities.len(), 2); - - cleanup(&temp_file); - } - - #[test] - fn test_create_relations() { - let (kb, temp_file) = setup_test_kb(); - - // First create entities - let entities = vec![ - Entity { - name: "Alice".to_string(), - entity_type: "Person".to_string(), - observations: vec![], - created_by: String::new(), - updated_by: String::new(), - created_at: 0, - updated_at: 0, - }, - Entity { - name: "Bob".to_string(), - entity_type: "Person".to_string(), - observations: vec![], - created_by: String::new(), - updated_by: String::new(), - created_at: 0, - updated_at: 0, - }, - ]; - kb.create_entities(entities).unwrap(); - - // Then create relations - let relations = vec![ - Relation { - from: "Alice".to_string(), - to: "Bob".to_string(), - relation_type: "knows".to_string(), - created_by: String::new(), - created_at: 0, - valid_from: None, - valid_to: None, - }, - ]; - - let created = kb.create_relations(relations).unwrap(); - assert_eq!(created.len(), 1); - // Verify user was auto-filled - assert_eq!(created[0].created_by, "test_user"); - - let graph = kb.read_graph(None, None).unwrap(); - assert_eq!(graph.relations.len(), 1); - - cleanup(&temp_file); - } - - #[test] - fn test_search_nodes() { - let (kb, temp_file) = setup_test_kb(); - - let entities = vec![ - Entity { - name: "Alice".to_string(), - entity_type: "Person".to_string(), - observations: vec!["Software Engineer".to_string()], - created_by: String::new(), - updated_by: String::new(), - created_at: 0, - updated_at: 0, - }, - Entity { - name: "Bob".to_string(), - entity_type: "Person".to_string(), - observations: vec!["Doctor".to_string()], - created_by: String::new(), - updated_by: String::new(), - created_at: 0, - updated_at: 0, - }, - ]; - kb.create_entities(entities).unwrap(); - - let result = kb.search_nodes("Alice", None, true).unwrap(); - assert_eq!(result.entities.len(), 1); - assert_eq!(result.entities[0].name, "Alice"); - - let result = kb.search_nodes("Engineer", None, true).unwrap(); - assert_eq!(result.entities.len(), 1); - assert_eq!(result.entities[0].name, "Alice"); - - cleanup(&temp_file); - } - - #[test] - fn test_delete_entities() { - let (kb, temp_file) = setup_test_kb(); - - let entities = vec![ - Entity { - name: "Alice".to_string(), - entity_type: "Person".to_string(), - observations: vec![], - created_by: String::new(), - updated_by: String::new(), - created_at: 0, - updated_at: 0, - }, - Entity { - name: "Bob".to_string(), - entity_type: "Person".to_string(), - observations: vec![], - created_by: String::new(), - updated_by: String::new(), - created_at: 0, - updated_at: 0, - }, - ]; - kb.create_entities(entities).unwrap(); - - kb.delete_entities(vec!["Alice".to_string()]).unwrap(); - - let graph = kb.read_graph(None, None).unwrap(); - assert_eq!(graph.entities.len(), 1); - assert_eq!(graph.entities[0].name, "Bob"); - - cleanup(&temp_file); - } - - #[test] - fn test_concurrent_access() { - use std::sync::Arc; - use std::thread; - - let (kb, temp_file) = setup_test_kb(); - let kb = Arc::new(kb); - - // Spawn multiple threads simulating concurrent agents - let mut handles = vec![]; - - for i in 0..10 { - let kb_clone = Arc::clone(&kb); - let handle = thread::spawn(move || { - // Each "agent" creates an entity - let entity = Entity { - name: format!("Agent{}", i), - entity_type: "Person".to_string(), - observations: vec![format!("Created by thread {}", i)], - created_by: String::new(), - updated_by: String::new(), - created_at: 0, - updated_at: 0, - }; - kb_clone.create_entities(vec![entity]).unwrap(); - - // Each agent also reads the graph - let graph = kb_clone.read_graph(None, None).unwrap(); - assert!(graph.entities.len() >= 1); - - // Each agent adds an observation - let obs = Observation { - entity_name: format!("Agent{}", i), - contents: vec![format!("Observation from thread {}", i)], - }; - let _ = kb_clone.add_observations(vec![obs]); - }); - handles.push(handle); - } - - // Wait for all threads to complete - for handle in handles { - handle.join().expect("Thread panicked"); - } - - // Verify final state - let graph = kb.read_graph(None, None).unwrap(); - assert_eq!(graph.entities.len(), 10, "All 10 entities should exist"); - - // Verify all entities have observations - for entity in &graph.entities { - assert!(entity.observations.len() >= 1, "Entity should have observations"); - } - - cleanup(&temp_file); - } - - #[test] - fn test_concurrent_read_write() { - use std::sync::Arc; - use std::thread; - - let (kb, temp_file) = setup_test_kb(); - - // Pre-populate with some entities - for i in 0..5 { - let entity = Entity { - name: format!("Entity{}", i), - entity_type: "Module".to_string(), - observations: vec![], - created_by: String::new(), - updated_by: String::new(), - created_at: 0, - updated_at: 0, - }; - kb.create_entities(vec![entity]).unwrap(); - } - - let kb = Arc::new(kb); - let mut handles = vec![]; - - // 5 reader threads - for _ in 0..5 { - let kb_clone = Arc::clone(&kb); - let handle = thread::spawn(move || { - for _ in 0..100 { - let graph = kb_clone.read_graph(None, None).unwrap(); - assert!(graph.entities.len() >= 5); - let _ = kb_clone.search_nodes("Entity", None, true); - } - }); - handles.push(handle); - } - - // 3 writer threads - for i in 0..3 { - let kb_clone = Arc::clone(&kb); - let handle = thread::spawn(move || { - for j in 0..10 { - let obs = Observation { - entity_name: format!("Entity{}", i), - contents: vec![format!("Update {} from writer {}", j, i)], - }; - let _ = kb_clone.add_observations(vec![obs]); - } - }); - handles.push(handle); - } - - // Wait for all threads - for handle in handles { - handle.join().expect("Thread panicked"); - } - - // Verify no data corruption - let graph = kb.read_graph(None, None).unwrap(); - assert_eq!(graph.entities.len(), 5, "Original entities should still exist"); - - cleanup(&temp_file); - } -} diff --git a/src/knowledge_base/crud.rs b/src/knowledge_base/crud.rs new file mode 100644 index 0000000..f9a0dfa --- /dev/null +++ b/src/knowledge_base/crud.rs @@ -0,0 +1,162 @@ +//! CRUD operations for the knowledge base + +use std::collections::HashSet; + +use crate::types::{Entity, McpResult, Observation, ObservationDeletion, Relation}; +use crate::utils::time::current_timestamp; + +use super::KnowledgeBase; + +/// Create new entities (thread-safe: holds write lock during entire operation) +pub fn create_entities(kb: &KnowledgeBase, entities: Vec) -> McpResult> { + let mut graph = kb.graph.write().unwrap(); + let existing_names: HashSet = graph.entities.iter().map(|e| e.name.clone()).collect(); + let now = current_timestamp(); + + let mut created = Vec::new(); + for mut entity in entities { + if !existing_names.contains(&entity.name) { + // Auto-fill user info if not provided + if entity.created_by.is_empty() || entity.created_by == "system" { + entity.created_by = kb.current_user.clone(); + } + if entity.updated_by.is_empty() || entity.updated_by == "system" { + entity.updated_by = kb.current_user.clone(); + } + entity.created_at = now; + entity.updated_at = now; + created.push(entity.clone()); + graph.entities.push(entity); + } + } + + kb.persist_to_file(&graph)?; + Ok(created) +} + +/// Create new relations (thread-safe: holds write lock during entire operation) +pub fn create_relations(kb: &KnowledgeBase, relations: Vec) -> McpResult> { + let mut graph = kb.graph.write().unwrap(); + let entity_names: HashSet = graph.entities.iter().map(|e| e.name.clone()).collect(); + let now = current_timestamp(); + + let existing_relations: HashSet = graph + .relations + .iter() + .map(|r| format!("{}|{}|{}", r.from, r.to, r.relation_type)) + .collect(); + + let mut created = Vec::new(); + for mut relation in relations { + if entity_names.contains(&relation.from) && entity_names.contains(&relation.to) { + let key = format!( + "{}|{}|{}", + relation.from, relation.to, relation.relation_type + ); + if !existing_relations.contains(&key) { + // Auto-fill user info if not provided + if relation.created_by.is_empty() || relation.created_by == "system" { + relation.created_by = kb.current_user.clone(); + } + relation.created_at = now; + created.push(relation.clone()); + graph.relations.push(relation); + } + } + } + + kb.persist_to_file(&graph)?; + Ok(created) +} + +/// Add observations to entities (thread-safe: holds write lock during entire operation) +pub fn add_observations( + kb: &KnowledgeBase, + observations: Vec, +) -> McpResult> { + let mut graph = kb.graph.write().unwrap(); + let mut added = Vec::new(); + let now = current_timestamp(); + + for obs in observations { + if let Some(entity) = graph.entities.iter_mut().find(|e| e.name == obs.entity_name) { + let existing: HashSet = entity.observations.iter().cloned().collect(); + let mut new_contents = Vec::new(); + + for content in &obs.contents { + if !existing.contains(content) { + entity.observations.push(content.clone()); + new_contents.push(content.clone()); + } + } + + if !new_contents.is_empty() { + entity.updated_at = now; + entity.updated_by = kb.current_user.clone(); + added.push(Observation { + entity_name: obs.entity_name.clone(), + contents: new_contents, + }); + } + } + } + + kb.persist_to_file(&graph)?; + Ok(added) +} + +/// Delete entities (thread-safe: holds write lock during entire operation) +pub fn delete_entities(kb: &KnowledgeBase, entity_names: Vec) -> McpResult<()> { + let mut graph = kb.graph.write().unwrap(); + let names_to_delete: HashSet = entity_names.into_iter().collect(); + + graph + .entities + .retain(|e| !names_to_delete.contains(&e.name)); + graph + .relations + .retain(|r| !names_to_delete.contains(&r.from) && !names_to_delete.contains(&r.to)); + + kb.persist_to_file(&graph)?; + Ok(()) +} + +/// Delete observations from entities (thread-safe: holds write lock during entire operation) +pub fn delete_observations( + kb: &KnowledgeBase, + deletions: Vec, +) -> McpResult<()> { + let mut graph = kb.graph.write().unwrap(); + + for deletion in deletions { + if let Some(entity) = graph + .entities + .iter_mut() + .find(|e| e.name == deletion.entity_name) + { + let to_remove: HashSet = deletion.observations.into_iter().collect(); + entity.observations.retain(|o| !to_remove.contains(o)); + } + } + + kb.persist_to_file(&graph)?; + Ok(()) +} + +/// Delete relations (thread-safe: holds write lock during entire operation) +pub fn delete_relations(kb: &KnowledgeBase, relations: Vec) -> McpResult<()> { + let mut graph = kb.graph.write().unwrap(); + + let to_delete: HashSet = relations + .iter() + .map(|r| format!("{}|{}|{}", r.from, r.to, r.relation_type)) + .collect(); + + graph.relations.retain(|r| { + let key = format!("{}|{}|{}", r.from, r.to, r.relation_type); + !to_delete.contains(&key) + }); + + kb.persist_to_file(&graph)?; + Ok(()) +} diff --git a/src/knowledge_base/mod.rs b/src/knowledge_base/mod.rs new file mode 100644 index 0000000..722e101 --- /dev/null +++ b/src/knowledge_base/mod.rs @@ -0,0 +1,251 @@ +//! Knowledge Base - Core data engine +//! +//! This module contains the main knowledge base implementation with +//! thread-safe CRUD operations, queries, and temporal features. + +mod crud; +mod query; +mod summarize; +mod temporal; +mod traversal; + +use std::env; +use std::fs; +use std::path::Path; +use std::sync::RwLock; + +use crate::types::{ + Entity, KnowledgeGraph, McpResult, Observation, ObservationDeletion, + PathStep, RelatedEntities, Relation, Summary, TraversalResult, +}; +use crate::utils::time::get_current_user; + +/// Knowledge base with in-memory cache for thread-safe operations +/// Uses RwLock for better concurrent read performance (read-heavy workload) +pub struct KnowledgeBase { + pub(crate) memory_file_path: String, + pub(crate) graph: RwLock, + pub(crate) current_user: String, +} + +impl KnowledgeBase { + /// Create a new knowledge base instance + pub fn new() -> Self { + let current_dir = env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")); + let default_memory_path = current_dir.join("memory.jsonl"); + + let memory_file_path = match env::var("MEMORY_FILE_PATH") { + Ok(path) => { + if Path::new(&path).is_absolute() { + path + } else { + current_dir.join(path).to_string_lossy().to_string() + } + } + Err(_) => default_memory_path.to_string_lossy().to_string(), + }; + + // Detect current user once at startup + let current_user = get_current_user(); + + // Load graph from file at startup (or create empty if not exists) + let graph = Self::load_graph_from_file(&memory_file_path).unwrap_or_default(); + + Self { + memory_file_path, + graph: RwLock::new(graph), + current_user, + } + } + + /// Create a new knowledge base with custom file path + pub fn with_file_path(file_path: String) -> Self { + let current_user = get_current_user(); + let graph = Self::load_graph_from_file(&file_path).unwrap_or_default(); + + Self { + memory_file_path: file_path, + graph: RwLock::new(graph), + current_user, + } + } + + /// Create a new knowledge base for testing with explicit parameters + #[cfg(test)] + pub fn for_testing(file_path: String, user: String) -> Self { + Self { + memory_file_path: file_path, + graph: RwLock::new(KnowledgeGraph::default()), + current_user: user, + } + } + + /// Load graph from file (static helper for initialization) + fn load_graph_from_file(file_path: &str) -> McpResult { + if !Path::new(file_path).exists() { + return Ok(KnowledgeGraph::default()); + } + + let content = fs::read_to_string(file_path)?; + let mut graph = KnowledgeGraph::default(); + + for line in content.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + + if let Ok(entity) = serde_json::from_str::(line) { + if !entity.name.is_empty() && !entity.entity_type.is_empty() { + graph.entities.push(entity); + continue; + } + } + + if let Ok(relation) = serde_json::from_str::(line) { + if !relation.from.is_empty() && !relation.to.is_empty() { + graph.relations.push(relation); + } + } + } + + Ok(graph) + } + + /// Get a clone of the current graph (thread-safe read) + /// Uses read lock - allows multiple concurrent readers + pub(crate) fn load_graph(&self) -> McpResult { + Ok(self.graph.read().unwrap().clone()) + } + + /// Persist graph to file (internal helper, expects caller to hold write lock) + pub(crate) fn persist_to_file(&self, graph: &KnowledgeGraph) -> McpResult<()> { + // Ensure parent directory exists + if let Some(parent) = Path::new(&self.memory_file_path).parent() { + fs::create_dir_all(parent)?; + } + + let mut content = String::new(); + + for entity in &graph.entities { + content.push_str(&serde_json::to_string(entity)?); + content.push('\n'); + } + + for relation in &graph.relations { + content.push_str(&serde_json::to_string(relation)?); + content.push('\n'); + } + + fs::write(&self.memory_file_path, content)?; + Ok(()) + } + + /// Get the current user + pub fn current_user(&self) -> &str { + &self.current_user + } + + /// Get the memory file path + pub fn file_path(&self) -> &str { + &self.memory_file_path + } +} + +impl Default for KnowledgeBase { + fn default() -> Self { + Self::new() + } +} + +// Re-export methods from submodules by implementing them here +impl KnowledgeBase { + // CRUD operations (from crud.rs) + pub fn create_entities(&self, entities: Vec) -> McpResult> { + crud::create_entities(self, entities) + } + + pub fn create_relations(&self, relations: Vec) -> McpResult> { + crud::create_relations(self, relations) + } + + pub fn add_observations(&self, observations: Vec) -> McpResult> { + crud::add_observations(self, observations) + } + + pub fn delete_entities(&self, entity_names: Vec) -> McpResult<()> { + crud::delete_entities(self, entity_names) + } + + pub fn delete_observations(&self, deletions: Vec) -> McpResult<()> { + crud::delete_observations(self, deletions) + } + + pub fn delete_relations(&self, relations: Vec) -> McpResult<()> { + crud::delete_relations(self, relations) + } + + // Query operations (from query.rs) + pub fn read_graph( + &self, + limit: Option, + offset: Option, + ) -> McpResult { + query::read_graph(self, limit, offset) + } + + pub fn search_nodes( + &self, + query: &str, + limit: Option, + include_relations: bool, + ) -> McpResult { + query::search_nodes(self, query, limit, include_relations) + } + + pub fn open_nodes(&self, names: Vec) -> McpResult { + query::open_nodes(self, names) + } + + // Traversal operations (from traversal.rs) + pub fn get_related( + &self, + entity_name: &str, + relation_type: Option<&str>, + direction: &str, + ) -> McpResult { + traversal::get_related(self, entity_name, relation_type, direction) + } + + pub fn traverse( + &self, + start: &str, + path: Vec, + max_results: usize, + ) -> McpResult { + traversal::traverse(self, start, path, max_results) + } + + // Summarize operations (from summarize.rs) + pub fn summarize( + &self, + entity_names: Option>, + entity_type: Option, + format: &str, + ) -> McpResult { + summarize::summarize(self, entity_names, entity_type, format) + } + + // Temporal operations (from temporal.rs) + pub fn get_relations_at_time( + &self, + timestamp: Option, + entity_name: Option<&str>, + ) -> McpResult> { + temporal::get_relations_at_time(self, timestamp, entity_name) + } + + pub fn get_relation_history(&self, entity_name: &str) -> McpResult> { + temporal::get_relation_history(self, entity_name) + } +} diff --git a/src/knowledge_base/query.rs b/src/knowledge_base/query.rs new file mode 100644 index 0000000..052020d --- /dev/null +++ b/src/knowledge_base/query.rs @@ -0,0 +1,107 @@ +//! Query operations for the knowledge base + +use std::collections::HashSet; + +use crate::search::{get_synonyms, matches_with_synonyms}; +use crate::types::{Entity, KnowledgeGraph, McpResult, Relation}; + +use super::KnowledgeBase; + +/// Read graph with optional pagination +pub fn read_graph( + kb: &KnowledgeBase, + limit: Option, + offset: Option, +) -> McpResult { + let graph = kb.load_graph()?; + + let offset = offset.unwrap_or(0); + + let entities: Vec = if let Some(lim) = limit { + graph.entities.into_iter().skip(offset).take(lim).collect() + } else { + graph.entities.into_iter().skip(offset).collect() + }; + + let entity_names: HashSet = entities.iter().map(|e| e.name.clone()).collect(); + + let relations: Vec = graph + .relations + .into_iter() + .filter(|r| entity_names.contains(&r.from) || entity_names.contains(&r.to)) + .collect(); + + Ok(KnowledgeGraph { entities, relations }) +} + +/// Search nodes by query with synonym expansion, optional limit and relation inclusion +pub fn search_nodes( + kb: &KnowledgeBase, + query: &str, + limit: Option, + include_relations: bool, +) -> McpResult { + let graph = kb.load_graph()?; + + // Expand query with synonyms for semantic matching + let search_terms = get_synonyms(query); + + let mut matching_entities: Vec = graph + .entities + .into_iter() + .filter(|e| { + matches_with_synonyms(&e.name, &search_terms) + || matches_with_synonyms(&e.entity_type, &search_terms) + || e.observations + .iter() + .any(|o| matches_with_synonyms(o, &search_terms)) + }) + .collect(); + + // Apply limit if specified + if let Some(lim) = limit { + matching_entities.truncate(lim); + } + + let matching_relations = if include_relations { + let entity_names: HashSet = + matching_entities.iter().map(|e| e.name.clone()).collect(); + graph + .relations + .into_iter() + .filter(|r| entity_names.contains(&r.from) || entity_names.contains(&r.to)) + .collect() + } else { + Vec::new() + }; + + Ok(KnowledgeGraph { + entities: matching_entities, + relations: matching_relations, + }) +} + +/// Open specific nodes by names +pub fn open_nodes(kb: &KnowledgeBase, names: Vec) -> McpResult { + let graph = kb.load_graph()?; + let name_set: HashSet = names.into_iter().collect(); + + let matching_entities: Vec = graph + .entities + .into_iter() + .filter(|e| name_set.contains(&e.name)) + .collect(); + + let entity_names: HashSet = matching_entities.iter().map(|e| e.name.clone()).collect(); + + let matching_relations: Vec = graph + .relations + .into_iter() + .filter(|r| entity_names.contains(&r.from) && entity_names.contains(&r.to)) + .collect(); + + Ok(KnowledgeGraph { + entities: matching_entities, + relations: matching_relations, + }) +} diff --git a/src/knowledge_base/summarize.rs b/src/knowledge_base/summarize.rs new file mode 100644 index 0000000..c166e7f --- /dev/null +++ b/src/knowledge_base/summarize.rs @@ -0,0 +1,122 @@ +//! Summarize operations + +use std::collections::HashMap; + +use crate::types::{Entity, EntityBrief, McpResult, Summary}; + +use super::KnowledgeBase; + +/// Summarize entities +pub fn summarize( + kb: &KnowledgeBase, + entity_names: Option>, + entity_type: Option, + format: &str, +) -> McpResult { + let graph = kb.load_graph()?; + + let entities: Vec<&Entity> = graph + .entities + .iter() + .filter(|e| { + if let Some(ref names) = entity_names { + names.contains(&e.name) + } else if let Some(ref et) = entity_type { + &e.entity_type == et + } else { + true + } + }) + .collect(); + + match format { + "brief" => format_brief(&entities), + "detailed" => format_detailed(&entities), + "stats" => format_stats(&entities), + _ => format_brief(&entities), + } +} + +fn format_brief(entities: &[&Entity]) -> McpResult { + let briefs: Vec = entities + .iter() + .map(|e| { + let brief = e + .observations + .first() + .cloned() + .unwrap_or_default() + .chars() + .take(100) + .collect::(); + EntityBrief { + name: e.name.clone(), + entity_type: e.entity_type.clone(), + brief, + } + }) + .collect(); + + Ok(Summary { + total_entities: entities.len(), + entities: Some(briefs), + ..Default::default() + }) +} + +fn format_detailed(entities: &[&Entity]) -> McpResult { + let briefs: Vec = entities + .iter() + .map(|e| { + let brief = e.observations.join("; "); + EntityBrief { + name: e.name.clone(), + entity_type: e.entity_type.clone(), + brief, + } + }) + .collect(); + + Ok(Summary { + total_entities: entities.len(), + entities: Some(briefs), + ..Default::default() + }) +} + +fn format_stats(entities: &[&Entity]) -> McpResult { + let mut by_status: HashMap = HashMap::new(); + let mut by_type: HashMap = HashMap::new(); + let mut by_priority: HashMap = HashMap::new(); + + for entity in entities { + *by_type.entry(entity.entity_type.clone()).or_insert(0) += 1; + + for obs in &entity.observations { + if obs.starts_with("Status:") { + let status = obs.trim_start_matches("Status:").trim().to_string(); + *by_status.entry(status).or_insert(0) += 1; + } + if obs.starts_with("Priority:") { + let priority = obs.trim_start_matches("Priority:").trim().to_string(); + *by_priority.entry(priority).or_insert(0) += 1; + } + } + } + + Ok(Summary { + total_entities: entities.len(), + entities: None, + by_status: if by_status.is_empty() { + None + } else { + Some(by_status) + }, + by_type: Some(by_type), + by_priority: if by_priority.is_empty() { + None + } else { + Some(by_priority) + }, + }) +} diff --git a/src/knowledge_base/temporal.rs b/src/knowledge_base/temporal.rs new file mode 100644 index 0000000..bd98a50 --- /dev/null +++ b/src/knowledge_base/temporal.rs @@ -0,0 +1,57 @@ +//! Temporal query operations + +use crate::types::{McpResult, Relation}; +use crate::utils::time::current_timestamp; + +use super::KnowledgeBase; + +/// Get relations valid at a specific point in time +pub fn get_relations_at_time( + kb: &KnowledgeBase, + timestamp: Option, + entity_name: Option<&str>, +) -> McpResult> { + let graph = kb.load_graph()?; + let check_time = timestamp.unwrap_or_else(current_timestamp); + + let relations: Vec = graph + .relations + .into_iter() + .filter(|r| { + // Filter by entity if specified + if let Some(name) = entity_name { + if r.from != name && r.to != name { + return false; + } + } + + // Check temporal validity + let valid_from_ok = match r.valid_from { + Some(vf) => check_time >= vf, + None => true, // No start time means always valid from past + }; + + let valid_to_ok = match r.valid_to { + Some(vt) => check_time <= vt, + None => true, // No end time means still valid + }; + + valid_from_ok && valid_to_ok + }) + .collect(); + + Ok(relations) +} + +/// Get historical relations (including expired ones) +pub fn get_relation_history(kb: &KnowledgeBase, entity_name: &str) -> McpResult> { + let graph = kb.load_graph()?; + + let relations: Vec = graph + .relations + .into_iter() + .filter(|r| r.from == entity_name || r.to == entity_name) + .collect(); + + Ok(relations) +} diff --git a/src/knowledge_base/traversal.rs b/src/knowledge_base/traversal.rs new file mode 100644 index 0000000..ebc23f0 --- /dev/null +++ b/src/knowledge_base/traversal.rs @@ -0,0 +1,156 @@ +//! Graph traversal operations + +use std::collections::HashSet; + +use crate::types::{ + Entity, McpResult, PathStep, RelatedEntities, RelatedEntity, TraversalPath, TraversalResult, +}; + +use super::KnowledgeBase; + +/// Get related entities +pub fn get_related( + kb: &KnowledgeBase, + entity_name: &str, + relation_type: Option<&str>, + direction: &str, +) -> McpResult { + let graph = kb.load_graph()?; + let mut related = Vec::new(); + + for relation in &graph.relations { + let matches = match direction { + "outgoing" => relation.from == entity_name, + "incoming" => relation.to == entity_name, + "both" => relation.from == entity_name || relation.to == entity_name, + _ => false, + }; + + if !matches { + continue; + } + + if let Some(rt) = relation_type { + if relation.relation_type != rt { + continue; + } + } + + let target_name = if relation.from == entity_name { + &relation.to + } else { + &relation.from + }; + + if let Some(entity) = graph.entities.iter().find(|e| e.name == *target_name) { + related.push(RelatedEntity { + relation_type: relation.relation_type.clone(), + direction: if relation.from == entity_name { + "outgoing".to_string() + } else { + "incoming".to_string() + }, + entity: entity.clone(), + }); + } + } + + Ok(RelatedEntities { + entity: entity_name.to_string(), + relations: related, + }) +} + +/// Traverse graph following path pattern +pub fn traverse( + kb: &KnowledgeBase, + start: &str, + path: Vec, + max_results: usize, +) -> McpResult { + let graph = kb.load_graph()?; + + // Track paths: (current_node, path_so_far, relations_so_far) + let mut current_paths: Vec<(String, Vec, Vec)> = + vec![(start.to_string(), vec![start.to_string()], vec![])]; + + for step in &path { + let mut next_paths = Vec::new(); + + for (node, nodes_path, rels_path) in ¤t_paths { + // Find related entities for this step + for relation in &graph.relations { + let (matches, target_name) = match step.direction.as_str() { + "out" => { + if relation.from == *node && relation.relation_type == step.relation_type { + (true, &relation.to) + } else { + (false, &relation.to) + } + } + "in" => { + if relation.to == *node && relation.relation_type == step.relation_type { + (true, &relation.from) + } else { + (false, &relation.from) + } + } + _ => (false, &relation.to), + }; + + if !matches { + continue; + } + + // Check target type if specified + if let Some(ref target_type) = step.target_type { + if let Some(entity) = graph.entities.iter().find(|e| e.name == *target_name) { + if &entity.entity_type != target_type { + continue; + } + } else { + continue; + } + } + + let mut new_nodes = nodes_path.clone(); + new_nodes.push(target_name.clone()); + let mut new_rels = rels_path.clone(); + new_rels.push(step.relation_type.clone()); + + next_paths.push((target_name.clone(), new_nodes, new_rels)); + } + } + + if next_paths.len() > max_results { + next_paths.truncate(max_results); + } + + current_paths = next_paths; + } + + // Build result + let mut paths = Vec::new(); + let mut end_node_names: HashSet = HashSet::new(); + + for (end_node, nodes, rels) in current_paths { + end_node_names.insert(end_node); + paths.push(TraversalPath { + nodes, + relations: rels, + }); + } + + let end_nodes: Vec = graph + .entities + .iter() + .filter(|e| end_node_names.contains(&e.name)) + .cloned() + .collect(); + + Ok(TraversalResult { + start_node: start.to_string(), + paths, + end_nodes, + }) +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..6d58fb4 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,63 @@ +//! Memory Graph MCP Server +//! +//! A knowledge graph server implementing the Model Context Protocol (MCP) +//! using pure Rust with minimal dependencies. +//! +//! # Features +//! +//! - **15 MCP Tools**: Full CRUD, query, and temporal operations +//! - **Thread-Safe**: Production-ready with Mutex-based concurrency +//! - **Semantic Search**: Built-in synonym matching +//! - **Time Travel**: Query historical state with validFrom/validTo +//! - **Pagination**: Handle massive graphs with limit/offset +//! +//! # Modules +//! +//! - `types`: Core data structures (Entity, Relation, KnowledgeGraph) +//! - `protocol`: MCP and JSON-RPC protocol types +//! - `knowledge_base`: Core data engine with CRUD and queries +//! - `tools`: 15 MCP tool implementations +//! - `search`: Semantic search with synonym expansion +//! - `validation`: Entity and relation type validation +//! - `utils`: Utility functions (timestamps, etc.) +//! - `server`: MCP server implementation +//! +//! # Example +//! +//! ```no_run +//! use std::sync::Arc; +//! use memory_graph::{KnowledgeBase, McpServer, ServerInfo}; +//! use memory_graph::tools::register_all_tools; +//! +//! fn main() { +//! let kb = Arc::new(KnowledgeBase::new()); +//! let server_info = ServerInfo::new("memory".to_string(), "1.0.0".to_string()); +//! let mut server = McpServer::with_info(server_info); +//! register_all_tools(&mut server, kb); +//! server.run().unwrap(); +//! } +//! ``` + +pub mod knowledge_base; +pub mod protocol; +pub mod search; +pub mod server; +pub mod tools; +pub mod types; +pub mod utils; +pub mod validation; + +// Re-export commonly used items at crate root +pub use knowledge_base::KnowledgeBase; +pub use protocol::{McpTool, ServerInfo, Tool}; +pub use server::McpServer; +pub use types::{ + Entity, EntityBrief, KnowledgeGraph, McpResult, Observation, ObservationDeletion, PathStep, + RelatedEntities, RelatedEntity, Relation, Summary, TraversalPath, TraversalResult, +}; + +/// Library version +pub const VERSION: &str = env!("CARGO_PKG_VERSION"); + +/// Library name +pub const NAME: &str = env!("CARGO_PKG_NAME"); diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..13ba3b0 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,27 @@ +//! Memory Graph MCP Server - Binary Entry Point +//! +//! This is the main entry point for the memory-server binary. + +use std::sync::Arc; + +use memory_graph::knowledge_base::KnowledgeBase; +use memory_graph::protocol::ServerInfo; +use memory_graph::server::McpServer; +use memory_graph::tools::register_all_tools; +use memory_graph::types::McpResult; + +fn main() -> McpResult<()> { + let kb = Arc::new(KnowledgeBase::new()); + + let server_info = ServerInfo { + name: "memory".to_string(), + version: "1.0.0".to_string(), + }; + + let mut server = McpServer::with_info(server_info); + + // Register all 15 tools + register_all_tools(&mut server, kb); + + server.run() +} diff --git a/src/protocol/jsonrpc.rs b/src/protocol/jsonrpc.rs new file mode 100644 index 0000000..cb290f9 --- /dev/null +++ b/src/protocol/jsonrpc.rs @@ -0,0 +1,137 @@ +//! JSON-RPC 2.0 protocol types + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// JSON-RPC 2.0 Request +#[derive(Deserialize, Debug, Clone)] +pub struct JsonRpcRequest { + pub jsonrpc: String, + pub id: Option, + pub method: String, + pub params: Option, +} + +impl JsonRpcRequest { + /// Check if this is a valid JSON-RPC 2.0 request + pub fn is_valid(&self) -> bool { + self.jsonrpc == "2.0" + } + + /// Check if this is a notification (no id) + pub fn is_notification(&self) -> bool { + self.id.is_none() + } +} + +/// JSON-RPC 2.0 Success Response +#[derive(Serialize, Debug)] +pub struct JsonRpcResponse { + pub jsonrpc: String, + pub id: Value, + pub result: Value, +} + +impl JsonRpcResponse { + /// Create a new success response + pub fn new(id: Value, result: Value) -> Self { + Self { + jsonrpc: "2.0".to_string(), + id, + result, + } + } +} + +/// JSON-RPC 2.0 Error Response +#[derive(Serialize, Debug)] +pub struct JsonRpcError { + pub jsonrpc: String, + pub id: Value, + pub error: ErrorObject, +} + +impl JsonRpcError { + /// Create a new error response + pub fn new(id: Value, code: i32, message: String, data: Option) -> Self { + Self { + jsonrpc: "2.0".to_string(), + id, + error: ErrorObject { + code, + message, + data, + }, + } + } + + /// Create a parse error response + pub fn parse_error(id: Value, details: String) -> Self { + Self::new( + id, + -32700, + "Parse error".to_string(), + Some(serde_json::json!({"details": details})), + ) + } + + /// Create an invalid request error response + pub fn invalid_request(id: Value, details: String) -> Self { + Self::new( + id, + -32600, + "Invalid Request".to_string(), + Some(serde_json::json!({"details": details})), + ) + } + + /// Create a method not found error response + pub fn method_not_found(id: Value, method: String) -> Self { + Self::new( + id, + -32601, + "Method not found".to_string(), + Some(serde_json::json!({"method": method})), + ) + } + + /// Create an invalid params error response + pub fn invalid_params(id: Value, details: String) -> Self { + Self::new( + id, + -32602, + "Invalid params".to_string(), + Some(serde_json::json!({"details": details})), + ) + } + + /// Create an internal error response + pub fn internal_error(id: Value, details: String) -> Self { + Self::new( + id, + -32603, + "Internal error".to_string(), + Some(serde_json::json!({"details": details})), + ) + } +} + +/// JSON-RPC 2.0 Error Object +#[derive(Serialize, Debug)] +pub struct ErrorObject { + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +impl ErrorObject { + /// Create a new error object + pub fn new(code: i32, message: String, data: Option) -> Self { + Self { + code, + message, + data, + } + } +} diff --git a/src/protocol/mcp.rs b/src/protocol/mcp.rs new file mode 100644 index 0000000..e7db12a --- /dev/null +++ b/src/protocol/mcp.rs @@ -0,0 +1,65 @@ +//! MCP (Model Context Protocol) types + +use serde::Serialize; +use serde_json::Value; + +use crate::types::McpResult; + +/// MCP Tool definition +#[derive(Serialize, Debug, Clone)] +pub struct McpTool { + pub name: String, + pub description: String, + #[serde(rename = "inputSchema")] + pub input_schema: Value, +} + +impl McpTool { + /// Create a new MCP tool definition + pub fn new(name: String, description: String, input_schema: Value) -> Self { + Self { + name, + description, + input_schema, + } + } +} + +/// Server information for MCP handshake +#[derive(Clone, Debug)] +pub struct ServerInfo { + pub name: String, + pub version: String, +} + +impl ServerInfo { + /// Create new server info + pub fn new(name: String, version: String) -> Self { + Self { name, version } + } +} + +impl Default for ServerInfo { + fn default() -> Self { + Self { + name: "memory".to_string(), + version: "1.0.0".to_string(), + } + } +} + +/// Trait for MCP tools +/// +/// All tools must implement this trait to be registered with the MCP server. +pub trait Tool: Send + Sync { + /// Get the tool definition for tools/list + fn definition(&self) -> McpTool; + + /// Execute the tool with the given parameters + fn execute(&self, params: Value) -> McpResult; + + /// Get the tool name (convenience method) + fn name(&self) -> String { + self.definition().name + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs new file mode 100644 index 0000000..788c191 --- /dev/null +++ b/src/protocol/mod.rs @@ -0,0 +1,9 @@ +//! Protocol types for MCP and JSON-RPC communication +//! +//! This module contains all protocol-related types and traits. + +mod jsonrpc; +mod mcp; + +pub use jsonrpc::{ErrorObject, JsonRpcError, JsonRpcRequest, JsonRpcResponse}; +pub use mcp::{McpTool, ServerInfo, Tool}; diff --git a/src/search/mod.rs b/src/search/mod.rs new file mode 100644 index 0000000..a638a71 --- /dev/null +++ b/src/search/mod.rs @@ -0,0 +1,7 @@ +//! Semantic search with synonym matching +//! +//! This module provides semantic search capabilities through synonym expansion. + +mod synonyms; + +pub use synonyms::{get_synonyms, matches_with_synonyms, SYNONYM_GROUPS}; diff --git a/src/search/synonyms.rs b/src/search/synonyms.rs new file mode 100644 index 0000000..076b689 --- /dev/null +++ b/src/search/synonyms.rs @@ -0,0 +1,108 @@ +//! Synonym dictionary for semantic search + +/// Synonym groups - words in same group are considered semantically similar +pub const SYNONYM_GROUPS: &[&[&str]] = &[ + // Developer roles + &[ + "coder", + "programmer", + "developer", + "engineer", + "dev", + "software engineer", + "software developer", + ], + &["frontend", "front-end", "ui developer", "client-side"], + &["backend", "back-end", "server-side", "api developer"], + &["fullstack", "full-stack", "full stack"], + &["devops", "sre", "infrastructure", "platform engineer"], + // Bug/Issue related + &[ + "bug", "issue", "defect", "error", "problem", "fault", "glitch", + ], + &["fix", "patch", "hotfix", "bugfix", "repair", "resolve"], + // Feature/Task related + &["feature", "functionality", "capability", "enhancement"], + &["task", "ticket", "work item", "story", "user story"], + &["requirement", "spec", "specification", "req"], + // Status + &["done", "completed", "finished", "resolved", "closed"], + &["pending", "waiting", "blocked", "on hold"], + &["in progress", "wip", "ongoing", "active", "working"], + &["todo", "to do", "planned", "backlog"], + // Priority + &["critical", "urgent", "p0", "blocker", "showstopper"], + &["high", "important", "p1"], + &["medium", "normal", "p2"], + &["low", "minor", "p3"], + // Project management + &["milestone", "release", "version", "sprint"], + &["deadline", "due date", "target date"], + &["project", "repo", "repository", "codebase"], + // Documentation + &["doc", "docs", "documentation", "readme", "guide"], + &["api", "interface", "endpoint"], + // Testing + &["test", "testing", "qa", "quality assurance"], + &["unit test", "unittest"], + &["integration test", "e2e", "end-to-end"], + // Architecture + &["module", "component", "service", "package"], + &["database", "db", "datastore", "storage"], + &["cache", "caching", "redis", "memcached"], +]; + +/// Get all synonyms for a query term +pub fn get_synonyms(query: &str) -> Vec { + let query_lower = query.to_lowercase(); + let mut synonyms = vec![query_lower.clone()]; + + for group in SYNONYM_GROUPS { + if group.iter().any(|&word| { + word == query_lower || query_lower.contains(word) || word.contains(&query_lower) + }) { + for &word in *group { + if !synonyms.contains(&word.to_string()) { + synonyms.push(word.to_string()); + } + } + } + } + + synonyms +} + +/// Check if text matches any of the search terms (including synonyms) +pub fn matches_with_synonyms(text: &str, search_terms: &[String]) -> bool { + let text_lower = text.to_lowercase(); + search_terms.iter().any(|term| text_lower.contains(term)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_synonyms_developer() { + let synonyms = get_synonyms("developer"); + assert!(synonyms.contains(&"developer".to_string())); + assert!(synonyms.contains(&"coder".to_string())); + assert!(synonyms.contains(&"programmer".to_string())); + } + + #[test] + fn test_get_synonyms_bug() { + let synonyms = get_synonyms("bug"); + assert!(synonyms.contains(&"bug".to_string())); + assert!(synonyms.contains(&"issue".to_string())); + assert!(synonyms.contains(&"defect".to_string())); + } + + #[test] + fn test_matches_with_synonyms() { + let terms = get_synonyms("developer"); + assert!(matches_with_synonyms("I am a coder", &terms)); + assert!(matches_with_synonyms("Software Engineer position", &terms)); + assert!(!matches_with_synonyms("I am a doctor", &terms)); + } +} diff --git a/src/server/handlers.rs b/src/server/handlers.rs new file mode 100644 index 0000000..3cc0911 --- /dev/null +++ b/src/server/handlers.rs @@ -0,0 +1,38 @@ +//! Request handlers for the MCP server +//! +//! This module contains helper functions for handling various request types. +//! Most handlers are implemented directly in McpServer, but this module +//! can be extended for custom handlers. + +use serde_json::Value; + +/// Extract tool arguments from params +pub fn extract_arguments(params: &Value) -> Value { + params.get("arguments").cloned().unwrap_or(Value::Object(serde_json::Map::new())) +} + +/// Extract tool name from params +pub fn extract_tool_name(params: &Value) -> Option<&str> { + params.get("name").and_then(|v| v.as_str()) +} + +/// Build a text content response +pub fn text_response(text: String) -> Value { + serde_json::json!({ + "content": [{ + "type": "text", + "text": text + }] + }) +} + +/// Build an error content response +pub fn error_response(message: String) -> Value { + serde_json::json!({ + "content": [{ + "type": "text", + "text": format!("Error: {}", message) + }], + "isError": true + }) +} diff --git a/src/server/mod.rs b/src/server/mod.rs new file mode 100644 index 0000000..fc3d958 --- /dev/null +++ b/src/server/mod.rs @@ -0,0 +1,204 @@ +//! MCP Server implementation +//! +//! This module contains the main server that handles JSON-RPC communication. + +mod handlers; + +use std::collections::HashMap; +use std::io::{self, BufRead, BufReader, BufWriter, Write}; + +use serde_json::{json, Value}; + +use crate::protocol::{ + JsonRpcError, JsonRpcRequest, JsonRpcResponse, McpTool, ServerInfo, Tool, +}; +use crate::types::McpResult; + +pub use handlers::*; + +/// MCP Server that handles JSON-RPC communication over stdio +pub struct McpServer { + server_info: ServerInfo, + tools: HashMap>, + reader: BufReader, + writer: BufWriter, +} + +impl McpServer { + /// Create a new MCP server with default settings + pub fn new() -> Self { + Self { + server_info: ServerInfo::default(), + tools: HashMap::new(), + reader: BufReader::new(io::stdin()), + writer: BufWriter::new(io::stdout()), + } + } + + /// Create a new MCP server with custom server info + pub fn with_info(info: ServerInfo) -> Self { + Self { + server_info: info, + tools: HashMap::new(), + reader: BufReader::new(io::stdin()), + writer: BufWriter::new(io::stdout()), + } + } + + /// Register a tool with the server + pub fn register_tool(&mut self, tool: Box) -> &mut Self { + let name = tool.definition().name.clone(); + self.tools.insert(name, tool); + self + } + + /// Get the number of registered tools + pub fn tool_count(&self) -> usize { + self.tools.len() + } + + /// Run the server (blocking) + pub fn run(&mut self) -> McpResult<()> { + let mut line = String::new(); + while self.reader.read_line(&mut line)? > 0 { + let trimmed = line.trim(); + if !trimmed.is_empty() { + self.handle_request(trimmed)?; + } + line.clear(); + } + Ok(()) + } + + /// Handle a single JSON-RPC request + fn handle_request(&mut self, request_str: &str) -> McpResult<()> { + let request: JsonRpcRequest = match serde_json::from_str(request_str) { + Ok(req) => req, + Err(e) => { + self.send_error_response( + Value::Null, + -32700, + "Parse error", + Some(json!({"details": e.to_string()})), + )?; + return Ok(()); + } + }; + + if request.jsonrpc != "2.0" { + self.send_error_response( + request.id.unwrap_or(Value::Null), + -32600, + "Invalid Request", + Some(json!({"details": "jsonrpc must be '2.0'"})), + )?; + return Ok(()); + } + + let id = request.id.clone().unwrap_or(Value::Null); + + match request.method.as_str() { + "initialize" => self.handle_initialize(id, request.params), + "notifications/initialized" => Ok(()), // Notification, no response + "tools/list" => self.handle_tools_list(id), + "tools/call" => self.handle_tool_call(id, request.params), + "ping" => self.send_success_response(id, json!({})), + _ => self.send_error_response( + id, + -32601, + "Method not found", + Some(json!({"method": request.method})), + ), + } + } + + /// Handle initialize request + fn handle_initialize(&mut self, id: Value, _params: Option) -> McpResult<()> { + let result = json!({ + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {} + }, + "serverInfo": { + "name": self.server_info.name, + "version": self.server_info.version + } + }); + self.send_success_response(id, result) + } + + /// Handle tools/list request + fn handle_tools_list(&mut self, id: Value) -> McpResult<()> { + let tools: Vec = self.tools.values().map(|t| t.definition()).collect(); + let result = json!({ "tools": tools }); + self.send_success_response(id, result) + } + + /// Handle tools/call request + fn handle_tool_call(&mut self, id: Value, params: Option) -> McpResult<()> { + let params = params.ok_or("Missing parameters")?; + let tool_name = params + .get("name") + .and_then(|v| v.as_str()) + .ok_or("Missing tool name")?; + + let tool = match self.tools.get(tool_name) { + Some(tool) => tool, + None => { + self.send_error_response( + id, + -32602, + "Unknown tool", + Some(json!({"tool": tool_name})), + )?; + return Ok(()); + } + }; + + let arguments = params.get("arguments").cloned().unwrap_or(json!({})); + + match tool.execute(arguments) { + Ok(result) => self.send_success_response(id, result), + Err(e) => self.send_error_response( + id, + -32603, + "Tool execution error", + Some(json!({"details": e.to_string()})), + ), + } + } + + /// Send a success response + fn send_success_response(&mut self, id: Value, result: Value) -> McpResult<()> { + let response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result, + }; + let json = serde_json::to_string(&response)?; + writeln!(self.writer, "{}", json)?; + self.writer.flush()?; + Ok(()) + } + + /// Send an error response + fn send_error_response( + &mut self, + id: Value, + code: i32, + message: &str, + data: Option, + ) -> McpResult<()> { + let response = JsonRpcError::new(id, code, message.to_string(), data); + let json = serde_json::to_string(&response)?; + writeln!(self.writer, "{}", json)?; + self.writer.flush()?; + Ok(()) + } +} + +impl Default for McpServer { + fn default() -> Self { + Self::new() + } +} diff --git a/src/tools/memory/add_observations.rs b/src/tools/memory/add_observations.rs new file mode 100644 index 0000000..ecdb758 --- /dev/null +++ b/src/tools/memory/add_observations.rs @@ -0,0 +1,63 @@ +//! Add observations tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::{McpResult, Observation}; + +/// Tool for adding new observations to existing entities +pub struct AddObservationsTool { + kb: Arc, +} + +impl AddObservationsTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for AddObservationsTool { + fn definition(&self) -> McpTool { + McpTool { + name: "add_observations".to_string(), + description: "Add new observations to existing entities in the knowledge graph" + .to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "observations": { + "type": "array", + "items": { + "type": "object", + "properties": { + "entityName": { "type": "string", "description": "The name of the entity" }, + "contents": { + "type": "array", + "items": { "type": "string" }, + "description": "Observation contents to add" + } + }, + "required": ["entityName", "contents"] + } + } + }, + "required": ["observations"] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let observations: Vec = + serde_json::from_value(params.get("observations").cloned().unwrap_or(json!([])))?; + let added = self.kb.add_observations(observations)?; + Ok(json!({ + "content": [{ + "type": "text", + "text": serde_json::to_string_pretty(&added)? + }] + })) + } +} diff --git a/src/tools/memory/create_entities.rs b/src/tools/memory/create_entities.rs new file mode 100644 index 0000000..0cddf1c --- /dev/null +++ b/src/tools/memory/create_entities.rs @@ -0,0 +1,84 @@ +//! Create entities tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::{Entity, McpResult}; +use crate::validation::validate_entity_type; + +/// Tool for creating multiple new entities in the knowledge graph +pub struct CreateEntitiesTool { + kb: Arc, +} + +impl CreateEntitiesTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for CreateEntitiesTool { + fn definition(&self) -> McpTool { + McpTool { + name: "create_entities".to_string(), + description: "Create multiple new entities in the knowledge graph".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "entities": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string", "description": "The name of the entity" }, + "entityType": { "type": "string", "description": "The type of the entity" }, + "observations": { + "type": "array", + "items": { "type": "string" }, + "description": "Initial observations about the entity" + }, + "createdBy": { "type": "string", "description": "Who created this entity (auto-filled from git/env if not provided)" }, + "updatedBy": { "type": "string", "description": "Who last updated this entity (auto-filled from git/env if not provided)" } + }, + "required": ["name", "entityType"] + } + } + }, + "required": ["entities"] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let entities: Vec = + serde_json::from_value(params.get("entities").cloned().unwrap_or(json!([])))?; + + // Collect warnings for non-standard types + let warnings: Vec = entities + .iter() + .filter_map(|e| validate_entity_type(&e.entity_type)) + .collect(); + + let created = self.kb.create_entities(entities)?; + + let response = if warnings.is_empty() { + serde_json::to_string_pretty(&created)? + } else { + format!( + "{}\n\n{}", + serde_json::to_string_pretty(&created)?, + warnings.join("\n") + ) + }; + + Ok(json!({ + "content": [{ + "type": "text", + "text": response + }] + })) + } +} diff --git a/src/tools/memory/create_relations.rs b/src/tools/memory/create_relations.rs new file mode 100644 index 0000000..91cf0e7 --- /dev/null +++ b/src/tools/memory/create_relations.rs @@ -0,0 +1,82 @@ +//! Create relations tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::{McpResult, Relation}; +use crate::validation::validate_relation_type; + +/// Tool for creating multiple new relations between entities +pub struct CreateRelationsTool { + kb: Arc, +} + +impl CreateRelationsTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for CreateRelationsTool { + fn definition(&self) -> McpTool { + McpTool { + name: "create_relations".to_string(), + description: "Create multiple new relations between entities in the knowledge graph" + .to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "relations": { + "type": "array", + "items": { + "type": "object", + "properties": { + "from": { "type": "string", "description": "The source entity name" }, + "to": { "type": "string", "description": "The target entity name" }, + "relationType": { "type": "string", "description": "The type of relation" }, + "createdBy": { "type": "string", "description": "Who created this relation (auto-filled from git/env if not provided)" }, + "validFrom": { "type": "integer", "description": "Unix timestamp when relation becomes valid" }, + "validTo": { "type": "integer", "description": "Unix timestamp when relation expires" } + }, + "required": ["from", "to", "relationType"] + } + } + }, + "required": ["relations"] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let relations: Vec = + serde_json::from_value(params.get("relations").cloned().unwrap_or(json!([])))?; + + // Collect warnings for non-standard relation types + let warnings: Vec = relations + .iter() + .filter_map(|r| validate_relation_type(&r.relation_type)) + .collect(); + + let created = self.kb.create_relations(relations)?; + + let response = if warnings.is_empty() { + serde_json::to_string_pretty(&created)? + } else { + format!( + "{}\n\n{}", + serde_json::to_string_pretty(&created)?, + warnings.join("\n") + ) + }; + + Ok(json!({ + "content": [{ + "type": "text", + "text": response + }] + })) + } +} diff --git a/src/tools/memory/delete_entities.rs b/src/tools/memory/delete_entities.rs new file mode 100644 index 0000000..da05267 --- /dev/null +++ b/src/tools/memory/delete_entities.rs @@ -0,0 +1,52 @@ +//! Delete entities tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::McpResult; + +/// Tool for deleting multiple entities from the knowledge graph +pub struct DeleteEntitiesTool { + kb: Arc, +} + +impl DeleteEntitiesTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for DeleteEntitiesTool { + fn definition(&self) -> McpTool { + McpTool { + name: "delete_entities".to_string(), + description: "Delete multiple entities from the knowledge graph".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "entityNames": { + "type": "array", + "items": { "type": "string" }, + "description": "An array of entity names to delete" + } + }, + "required": ["entityNames"] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let entity_names: Vec = + serde_json::from_value(params.get("entityNames").cloned().unwrap_or(json!([])))?; + self.kb.delete_entities(entity_names)?; + Ok(json!({ + "content": [{ + "type": "text", + "text": "Entities deleted successfully" + }] + })) + } +} diff --git a/src/tools/memory/delete_observations.rs b/src/tools/memory/delete_observations.rs new file mode 100644 index 0000000..5623d0c --- /dev/null +++ b/src/tools/memory/delete_observations.rs @@ -0,0 +1,63 @@ +//! Delete observations tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::{McpResult, ObservationDeletion}; + +/// Tool for deleting specific observations from entities +pub struct DeleteObservationsTool { + kb: Arc, +} + +impl DeleteObservationsTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for DeleteObservationsTool { + fn definition(&self) -> McpTool { + McpTool { + name: "delete_observations".to_string(), + description: "Delete specific observations from entities in the knowledge graph" + .to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "deletions": { + "type": "array", + "items": { + "type": "object", + "properties": { + "entityName": { "type": "string", "description": "The name of the entity" }, + "observations": { + "type": "array", + "items": { "type": "string" }, + "description": "Observations to delete" + } + }, + "required": ["entityName", "observations"] + } + } + }, + "required": ["deletions"] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let deletions: Vec = + serde_json::from_value(params.get("deletions").cloned().unwrap_or(json!([])))?; + self.kb.delete_observations(deletions)?; + Ok(json!({ + "content": [{ + "type": "text", + "text": "Observations deleted successfully" + }] + })) + } +} diff --git a/src/tools/memory/delete_relations.rs b/src/tools/memory/delete_relations.rs new file mode 100644 index 0000000..2ea2728 --- /dev/null +++ b/src/tools/memory/delete_relations.rs @@ -0,0 +1,59 @@ +//! Delete relations tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::{McpResult, Relation}; + +/// Tool for deleting multiple relations from the knowledge graph +pub struct DeleteRelationsTool { + kb: Arc, +} + +impl DeleteRelationsTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for DeleteRelationsTool { + fn definition(&self) -> McpTool { + McpTool { + name: "delete_relations".to_string(), + description: "Delete multiple relations from the knowledge graph".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "relations": { + "type": "array", + "items": { + "type": "object", + "properties": { + "from": { "type": "string", "description": "The source entity name" }, + "to": { "type": "string", "description": "The target entity name" }, + "relationType": { "type": "string", "description": "The type of relation" } + }, + "required": ["from", "to", "relationType"] + } + } + }, + "required": ["relations"] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let relations: Vec = + serde_json::from_value(params.get("relations").cloned().unwrap_or(json!([])))?; + self.kb.delete_relations(relations)?; + Ok(json!({ + "content": [{ + "type": "text", + "text": "Relations deleted successfully" + }] + })) + } +} diff --git a/src/tools/memory/mod.rs b/src/tools/memory/mod.rs new file mode 100644 index 0000000..9869bde --- /dev/null +++ b/src/tools/memory/mod.rs @@ -0,0 +1,23 @@ +//! Memory tools for CRUD operations +//! +//! This module contains 9 tools for managing entities, relations, and observations. + +mod add_observations; +mod create_entities; +mod create_relations; +mod delete_entities; +mod delete_observations; +mod delete_relations; +mod open_nodes; +mod read_graph; +mod search_nodes; + +pub use add_observations::AddObservationsTool; +pub use create_entities::CreateEntitiesTool; +pub use create_relations::CreateRelationsTool; +pub use delete_entities::DeleteEntitiesTool; +pub use delete_observations::DeleteObservationsTool; +pub use delete_relations::DeleteRelationsTool; +pub use open_nodes::OpenNodesTool; +pub use read_graph::ReadGraphTool; +pub use search_nodes::SearchNodesTool; diff --git a/src/tools/memory/open_nodes.rs b/src/tools/memory/open_nodes.rs new file mode 100644 index 0000000..8d4b0c3 --- /dev/null +++ b/src/tools/memory/open_nodes.rs @@ -0,0 +1,52 @@ +//! Open nodes tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::McpResult; + +/// Tool for opening specific nodes by their names +pub struct OpenNodesTool { + kb: Arc, +} + +impl OpenNodesTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for OpenNodesTool { + fn definition(&self) -> McpTool { + McpTool { + name: "open_nodes".to_string(), + description: "Open specific nodes in the knowledge graph by their names".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "names": { + "type": "array", + "items": { "type": "string" }, + "description": "An array of entity names to retrieve" + } + }, + "required": ["names"] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let names: Vec = + serde_json::from_value(params.get("names").cloned().unwrap_or(json!([])))?; + let graph = self.kb.open_nodes(names)?; + Ok(json!({ + "content": [{ + "type": "text", + "text": serde_json::to_string_pretty(&graph)? + }] + })) + } +} diff --git a/src/tools/memory/read_graph.rs b/src/tools/memory/read_graph.rs new file mode 100644 index 0000000..fadaa98 --- /dev/null +++ b/src/tools/memory/read_graph.rs @@ -0,0 +1,68 @@ +//! Read graph tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::McpResult; + +/// Tool for reading the knowledge graph with optional pagination +pub struct ReadGraphTool { + kb: Arc, +} + +impl ReadGraphTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for ReadGraphTool { + fn definition(&self) -> McpTool { + McpTool { + name: "read_graph".to_string(), + description: "Read the knowledge graph with optional pagination. Use limit/offset to avoid context overflow with large graphs.".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "limit": { + "type": "integer", + "description": "Maximum number of entities to return. Recommended: 50-100 for large graphs" + }, + "offset": { + "type": "integer", + "description": "Number of entities to skip (for pagination)" + } + }, + "required": [] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let limit = params + .get("limit") + .and_then(|v| v.as_u64()) + .map(|v| v as usize); + let offset = params + .get("offset") + .and_then(|v| v.as_u64()) + .map(|v| v as usize); + let graph = self.kb.read_graph(limit, offset)?; + + let total_msg = if limit.is_some() || offset.is_some() { + format!(" (showing {} entities)", graph.entities.len()) + } else { + String::new() + }; + + Ok(json!({ + "content": [{ + "type": "text", + "text": format!("{}{}", serde_json::to_string_pretty(&graph)?, total_msg) + }] + })) + } +} diff --git a/src/tools/memory/search_nodes.rs b/src/tools/memory/search_nodes.rs new file mode 100644 index 0000000..edb2845 --- /dev/null +++ b/src/tools/memory/search_nodes.rs @@ -0,0 +1,69 @@ +//! Search nodes tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::McpResult; + +/// Tool for searching nodes in the knowledge graph with semantic matching +pub struct SearchNodesTool { + kb: Arc, +} + +impl SearchNodesTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for SearchNodesTool { + fn definition(&self) -> McpTool { + McpTool { + name: "search_nodes".to_string(), + description: + "Search for nodes in the knowledge graph. Returns matching entities with optional relations." + .to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to match against entity names, types, and observations" + }, + "limit": { + "type": "integer", + "description": "Maximum number of entities to return (default: no limit)" + }, + "includeRelations": { + "type": "boolean", + "description": "Whether to include relations connected to matching entities (default: true)" + } + }, + "required": ["query"] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let query = params.get("query").and_then(|v| v.as_str()).unwrap_or(""); + let limit = params + .get("limit") + .and_then(|v| v.as_u64()) + .map(|v| v as usize); + let include_relations = params + .get("includeRelations") + .and_then(|v| v.as_bool()) + .unwrap_or(true); + + let graph = self.kb.search_nodes(query, limit, include_relations)?; + Ok(json!({ + "content": [{ + "type": "text", + "text": serde_json::to_string_pretty(&graph)? + }] + })) + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs new file mode 100644 index 0000000..7559820 --- /dev/null +++ b/src/tools/mod.rs @@ -0,0 +1,47 @@ +//! MCP Tools implementation +//! +//! This module contains all 15 MCP tools organized by category: +//! - Memory tools (9): CRUD operations +//! - Query tools (3): Graph traversal and search +//! - Temporal tools (3): Time-based queries + +pub mod memory; +pub mod query; +pub mod temporal; + +use std::sync::Arc; + +use crate::knowledge_base::KnowledgeBase; +use crate::server::McpServer; + +// Re-export all tools for convenience +pub use memory::{ + AddObservationsTool, CreateEntitiesTool, CreateRelationsTool, DeleteEntitiesTool, + DeleteObservationsTool, DeleteRelationsTool, OpenNodesTool, ReadGraphTool, SearchNodesTool, +}; +pub use query::{GetRelatedTool, SummarizeTool, TraverseTool}; +pub use temporal::{GetCurrentTimeTool, GetRelationHistoryTool, GetRelationsAtTimeTool}; + +/// Register all tools with the MCP server +pub fn register_all_tools(server: &mut McpServer, kb: Arc) { + // Memory tools (9) + server.register_tool(Box::new(CreateEntitiesTool::new(kb.clone()))); + server.register_tool(Box::new(CreateRelationsTool::new(kb.clone()))); + server.register_tool(Box::new(AddObservationsTool::new(kb.clone()))); + server.register_tool(Box::new(DeleteEntitiesTool::new(kb.clone()))); + server.register_tool(Box::new(DeleteObservationsTool::new(kb.clone()))); + server.register_tool(Box::new(DeleteRelationsTool::new(kb.clone()))); + server.register_tool(Box::new(ReadGraphTool::new(kb.clone()))); + server.register_tool(Box::new(SearchNodesTool::new(kb.clone()))); + server.register_tool(Box::new(OpenNodesTool::new(kb.clone()))); + + // Query tools (3) + server.register_tool(Box::new(GetRelatedTool::new(kb.clone()))); + server.register_tool(Box::new(TraverseTool::new(kb.clone()))); + server.register_tool(Box::new(SummarizeTool::new(kb.clone()))); + + // Temporal tools (3) + server.register_tool(Box::new(GetRelationsAtTimeTool::new(kb.clone()))); + server.register_tool(Box::new(GetRelationHistoryTool::new(kb.clone()))); + server.register_tool(Box::new(GetCurrentTimeTool::new())); +} diff --git a/src/tools/query/get_related.rs b/src/tools/query/get_related.rs new file mode 100644 index 0000000..e4aac29 --- /dev/null +++ b/src/tools/query/get_related.rs @@ -0,0 +1,69 @@ +//! Get related tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::McpResult; + +/// Tool for getting entities related to a specific entity +pub struct GetRelatedTool { + kb: Arc, +} + +impl GetRelatedTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for GetRelatedTool { + fn definition(&self) -> McpTool { + McpTool { + name: "get_related".to_string(), + description: "Get entities related to a specific entity".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "entityName": { + "type": "string", + "description": "Name of the entity to find relations for" + }, + "relationType": { + "type": "string", + "description": "Filter by relation type (optional)" + }, + "direction": { + "type": "string", + "enum": ["outgoing", "incoming", "both"], + "default": "both", + "description": "Direction of relations" + } + }, + "required": ["entityName"] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let entity_name = params + .get("entityName") + .and_then(|v| v.as_str()) + .ok_or("Missing entityName")?; + let relation_type = params.get("relationType").and_then(|v| v.as_str()); + let direction = params + .get("direction") + .and_then(|v| v.as_str()) + .unwrap_or("both"); + + let related = self.kb.get_related(entity_name, relation_type, direction)?; + Ok(json!({ + "content": [{ + "type": "text", + "text": serde_json::to_string_pretty(&related)? + }] + })) + } +} diff --git a/src/tools/query/mod.rs b/src/tools/query/mod.rs new file mode 100644 index 0000000..d6d25bd --- /dev/null +++ b/src/tools/query/mod.rs @@ -0,0 +1,11 @@ +//! Query tools for graph traversal and search +//! +//! This module contains 3 tools for advanced graph operations. + +mod get_related; +mod summarize; +mod traverse; + +pub use get_related::GetRelatedTool; +pub use summarize::SummarizeTool; +pub use traverse::TraverseTool; diff --git a/src/tools/query/summarize.rs b/src/tools/query/summarize.rs new file mode 100644 index 0000000..68793a2 --- /dev/null +++ b/src/tools/query/summarize.rs @@ -0,0 +1,74 @@ +//! Summarize tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::McpResult; + +/// Tool for getting a condensed summary of entities +pub struct SummarizeTool { + kb: Arc, +} + +impl SummarizeTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for SummarizeTool { + fn definition(&self) -> McpTool { + McpTool { + name: "summarize".to_string(), + description: "Get a condensed summary of entities".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "entityNames": { + "type": "array", + "items": { "type": "string" }, + "description": "Specific entities to summarize (optional)" + }, + "entityType": { + "type": "string", + "description": "Summarize all entities of this type (optional)" + }, + "format": { + "type": "string", + "enum": ["brief", "detailed", "stats"], + "default": "brief", + "description": "Output format: brief (first observation), detailed (all observations), stats (statistics)" + } + }, + "required": [] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let entity_names: Option> = params + .get("entityNames") + .and_then(|v| serde_json::from_value(v.clone()).ok()); + + let entity_type = params + .get("entityType") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let format = params + .get("format") + .and_then(|v| v.as_str()) + .unwrap_or("brief"); + + let summary = self.kb.summarize(entity_names, entity_type, format)?; + Ok(json!({ + "content": [{ + "type": "text", + "text": serde_json::to_string_pretty(&summary)? + }] + })) + } +} diff --git a/src/tools/query/traverse.rs b/src/tools/query/traverse.rs new file mode 100644 index 0000000..8b33c0e --- /dev/null +++ b/src/tools/query/traverse.rs @@ -0,0 +1,91 @@ +//! Traverse tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::{McpResult, PathStep}; + +/// Tool for traversing the graph following a path pattern +pub struct TraverseTool { + kb: Arc, +} + +impl TraverseTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for TraverseTool { + fn definition(&self) -> McpTool { + McpTool { + name: "traverse".to_string(), + description: "Traverse the graph following a path pattern for multi-hop queries" + .to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "startNode": { + "type": "string", + "description": "Starting entity name" + }, + "path": { + "type": "array", + "items": { + "type": "object", + "properties": { + "relationType": { + "type": "string", + "description": "Type of relation to follow" + }, + "direction": { + "type": "string", + "enum": ["out", "in"], + "description": "Direction: out (outgoing) or in (incoming)" + }, + "targetType": { + "type": "string", + "description": "Filter by target entity type (optional)" + } + }, + "required": ["relationType", "direction"] + }, + "description": "Path pattern to follow" + }, + "maxResults": { + "type": "integer", + "default": 50, + "description": "Maximum number of results" + } + }, + "required": ["startNode", "path"] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let start_node = params + .get("startNode") + .and_then(|v| v.as_str()) + .ok_or("Missing startNode")?; + + let path: Vec = + serde_json::from_value(params.get("path").cloned().unwrap_or(json!([])))?; + + let max_results = params + .get("maxResults") + .and_then(|v| v.as_u64()) + .unwrap_or(50) as usize; + + let result = self.kb.traverse(start_node, path, max_results)?; + Ok(json!({ + "content": [{ + "type": "text", + "text": serde_json::to_string_pretty(&result)? + }] + })) + } +} diff --git a/src/tools/temporal/get_current_time.rs b/src/tools/temporal/get_current_time.rs new file mode 100644 index 0000000..422a0d7 --- /dev/null +++ b/src/tools/temporal/get_current_time.rs @@ -0,0 +1,46 @@ +//! Get current time tool + +use serde_json::{json, Value}; + +use crate::protocol::{McpTool, Tool}; +use crate::types::McpResult; +use crate::utils::time::get_current_time; + +/// Tool for getting the current datetime and timestamp +pub struct GetCurrentTimeTool; + +impl GetCurrentTimeTool { + pub fn new() -> Self { + Self + } +} + +impl Default for GetCurrentTimeTool { + fn default() -> Self { + Self::new() + } +} + +impl Tool for GetCurrentTimeTool { + fn definition(&self) -> McpTool { + McpTool { + name: "get_current_time".to_string(), + description: "Get the current datetime and timestamp".to_string(), + input_schema: json!({ + "type": "object", + "properties": {}, + "required": [] + }), + } + } + + fn execute(&self, _params: Value) -> McpResult { + let time_info = get_current_time(); + Ok(json!({ + "content": [{ + "type": "text", + "text": serde_json::to_string_pretty(&time_info)? + }] + })) + } +} diff --git a/src/tools/temporal/get_relation_history.rs b/src/tools/temporal/get_relation_history.rs new file mode 100644 index 0000000..26e6df4 --- /dev/null +++ b/src/tools/temporal/get_relation_history.rs @@ -0,0 +1,83 @@ +//! Get relation history tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::McpResult; +use crate::utils::time::current_timestamp; + +/// Tool for getting all relations (current and historical) for an entity +pub struct GetRelationHistoryTool { + kb: Arc, +} + +impl GetRelationHistoryTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for GetRelationHistoryTool { + fn definition(&self) -> McpTool { + McpTool { + name: "get_relation_history".to_string(), + description: "Get all relations (current and historical) for an entity. Shows temporal validity (validFrom/validTo) for each relation.".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "entityName": { + "type": "string", + "description": "The name of the entity to get relation history for" + } + }, + "required": ["entityName"] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let entity_name = params + .get("entityName") + .and_then(|v| v.as_str()) + .ok_or("entityName is required")?; + + let relations = self.kb.get_relation_history(entity_name)?; + let current_time = current_timestamp(); + + // Mark each relation as current or historical + let annotated: Vec = relations + .iter() + .map(|r| { + let is_current = match (r.valid_from, r.valid_to) { + (Some(vf), Some(vt)) => current_time >= vf && current_time <= vt, + (Some(vf), None) => current_time >= vf, + (None, Some(vt)) => current_time <= vt, + (None, None) => true, + }; + + json!({ + "from": r.from, + "to": r.to, + "relationType": r.relation_type, + "validFrom": r.valid_from, + "validTo": r.valid_to, + "isCurrent": is_current + }) + }) + .collect(); + + Ok(json!({ + "content": [{ + "type": "text", + "text": serde_json::to_string_pretty(&json!({ + "entity": entity_name, + "currentTime": current_time, + "relations": annotated + }))? + }] + })) + } +} diff --git a/src/tools/temporal/get_relations_at_time.rs b/src/tools/temporal/get_relations_at_time.rs new file mode 100644 index 0000000..516f18b --- /dev/null +++ b/src/tools/temporal/get_relations_at_time.rs @@ -0,0 +1,61 @@ +//! Get relations at time tool + +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::knowledge_base::KnowledgeBase; +use crate::protocol::{McpTool, Tool}; +use crate::types::McpResult; +use crate::utils::time::current_timestamp; + +/// Tool for getting relations valid at a specific point in time +pub struct GetRelationsAtTimeTool { + kb: Arc, +} + +impl GetRelationsAtTimeTool { + pub fn new(kb: Arc) -> Self { + Self { kb } + } +} + +impl Tool for GetRelationsAtTimeTool { + fn definition(&self) -> McpTool { + McpTool { + name: "get_relations_at_time".to_string(), + description: "Get relations that are valid at a specific point in time. Useful for querying historical state of the knowledge graph.".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "timestamp": { + "type": "integer", + "description": "Unix timestamp to query. If not provided, uses current time." + }, + "entityName": { + "type": "string", + "description": "Optional: filter relations involving this entity" + } + }, + "required": [] + }), + } + } + + fn execute(&self, params: Value) -> McpResult { + let timestamp = params.get("timestamp").and_then(|v| v.as_u64()); + let entity_name = params.get("entityName").and_then(|v| v.as_str()); + + let relations = self.kb.get_relations_at_time(timestamp, entity_name)?; + + Ok(json!({ + "content": [{ + "type": "text", + "text": serde_json::to_string_pretty(&json!({ + "queryTime": timestamp.unwrap_or_else(current_timestamp), + "relations": relations + }))? + }] + })) + } +} diff --git a/src/tools/temporal/mod.rs b/src/tools/temporal/mod.rs new file mode 100644 index 0000000..57cbd62 --- /dev/null +++ b/src/tools/temporal/mod.rs @@ -0,0 +1,11 @@ +//! Temporal tools for time-based queries +//! +//! This module contains 3 tools for temporal operations. + +mod get_current_time; +mod get_relation_history; +mod get_relations_at_time; + +pub use get_current_time::GetCurrentTimeTool; +pub use get_relation_history::GetRelationHistoryTool; +pub use get_relations_at_time::GetRelationsAtTimeTool; diff --git a/src/types/entity.rs b/src/types/entity.rs new file mode 100644 index 0000000..1d96be9 --- /dev/null +++ b/src/types/entity.rs @@ -0,0 +1,68 @@ +//! Entity types for the knowledge graph + +use serde::{Deserialize, Serialize}; + +use super::{default_user, is_default_user, is_zero}; + +/// Entity in the knowledge graph +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Entity { + pub name: String, + #[serde(rename = "entityType")] + pub entity_type: String, + #[serde(default)] + pub observations: Vec, + #[serde( + rename = "createdBy", + default = "default_user", + skip_serializing_if = "is_default_user" + )] + pub created_by: String, + #[serde( + rename = "updatedBy", + default = "default_user", + skip_serializing_if = "is_default_user" + )] + pub updated_by: String, + #[serde(rename = "createdAt", default, skip_serializing_if = "is_zero")] + pub created_at: u64, + #[serde(rename = "updatedAt", default, skip_serializing_if = "is_zero")] + pub updated_at: u64, +} + +impl Entity { + /// Create a new entity with default values + pub fn new(name: String, entity_type: String) -> Self { + Self { + name, + entity_type, + observations: Vec::new(), + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + } + } + + /// Create a new entity with observations + pub fn with_observations(name: String, entity_type: String, observations: Vec) -> Self { + Self { + name, + entity_type, + observations, + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + } + } +} + +/// Brief entity info for summary +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EntityBrief { + pub name: String, + #[serde(rename = "entityType")] + pub entity_type: String, + pub brief: String, +} diff --git a/src/types/graph.rs b/src/types/graph.rs new file mode 100644 index 0000000..bf062e7 --- /dev/null +++ b/src/types/graph.rs @@ -0,0 +1,41 @@ +//! Knowledge graph container type + +use serde::{Deserialize, Serialize}; + +use super::{Entity, Relation}; + +/// Knowledge graph containing entities and relations +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct KnowledgeGraph { + #[serde(default)] + pub entities: Vec, + #[serde(default)] + pub relations: Vec, +} + +impl KnowledgeGraph { + /// Create an empty knowledge graph + pub fn new() -> Self { + Self::default() + } + + /// Create a knowledge graph with entities and relations + pub fn with_data(entities: Vec, relations: Vec) -> Self { + Self { entities, relations } + } + + /// Check if the graph is empty + pub fn is_empty(&self) -> bool { + self.entities.is_empty() && self.relations.is_empty() + } + + /// Get the number of entities + pub fn entity_count(&self) -> usize { + self.entities.len() + } + + /// Get the number of relations + pub fn relation_count(&self) -> usize { + self.relations.len() + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs new file mode 100644 index 0000000..fe346bb --- /dev/null +++ b/src/types/mod.rs @@ -0,0 +1,35 @@ +//! Data types for the Memory Graph MCP Server +//! +//! This module contains all the core data structures used throughout the application. + +mod entity; +mod graph; +mod observation; +mod relation; +mod summary; +mod traversal; + +pub use entity::{Entity, EntityBrief}; +pub use graph::KnowledgeGraph; +pub use observation::{Observation, ObservationDeletion}; +pub use relation::{RelatedEntities, RelatedEntity, Relation}; +pub use summary::Summary; +pub use traversal::{PathStep, TraversalPath, TraversalResult}; + +/// Result type for MCP operations +pub type McpResult = Result>; + +/// Default user for serde deserialization +pub fn default_user() -> String { + "system".to_string() +} + +/// Check if string is empty or "system" (for skip_serializing_if) +pub fn is_default_user(val: &str) -> bool { + val.is_empty() || val == "system" +} + +/// Check if value is zero (for skip_serializing_if) +pub fn is_zero(val: &u64) -> bool { + *val == 0 +} diff --git a/src/types/observation.rs b/src/types/observation.rs new file mode 100644 index 0000000..04d1d00 --- /dev/null +++ b/src/types/observation.rs @@ -0,0 +1,39 @@ +//! Observation types for entity updates + +use serde::{Deserialize, Serialize}; + +/// Observation to add to an entity +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Observation { + #[serde(rename = "entityName")] + pub entity_name: String, + pub contents: Vec, +} + +impl Observation { + /// Create a new observation + pub fn new(entity_name: String, contents: Vec) -> Self { + Self { + entity_name, + contents, + } + } +} + +/// Observation deletion request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ObservationDeletion { + #[serde(rename = "entityName")] + pub entity_name: String, + pub observations: Vec, +} + +impl ObservationDeletion { + /// Create a new observation deletion request + pub fn new(entity_name: String, observations: Vec) -> Self { + Self { + entity_name, + observations, + } + } +} diff --git a/src/types/relation.rs b/src/types/relation.rs new file mode 100644 index 0000000..eb5854d --- /dev/null +++ b/src/types/relation.rs @@ -0,0 +1,76 @@ +//! Relation types for the knowledge graph + +use serde::{Deserialize, Serialize}; + +use super::{default_user, is_default_user, is_zero, Entity}; + +/// Relation between entities with temporal validity +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Relation { + pub from: String, + pub to: String, + #[serde(rename = "relationType")] + pub relation_type: String, + #[serde( + rename = "createdBy", + default = "default_user", + skip_serializing_if = "is_default_user" + )] + pub created_by: String, + #[serde(rename = "createdAt", default, skip_serializing_if = "is_zero")] + pub created_at: u64, + #[serde(rename = "validFrom", default, skip_serializing_if = "Option::is_none")] + pub valid_from: Option, + #[serde(rename = "validTo", default, skip_serializing_if = "Option::is_none")] + pub valid_to: Option, +} + +impl Relation { + /// Create a new relation + pub fn new(from: String, to: String, relation_type: String) -> Self { + Self { + from, + to, + relation_type, + created_by: String::new(), + created_at: 0, + valid_from: None, + valid_to: None, + } + } + + /// Create a new relation with temporal validity + pub fn with_validity( + from: String, + to: String, + relation_type: String, + valid_from: Option, + valid_to: Option, + ) -> Self { + Self { + from, + to, + relation_type, + created_by: String::new(), + created_at: 0, + valid_from, + valid_to, + } + } +} + +/// Related entity with relation info +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RelatedEntity { + #[serde(rename = "relationType")] + pub relation_type: String, + pub direction: String, + pub entity: Entity, +} + +/// Result of get_related query +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RelatedEntities { + pub entity: String, + pub relations: Vec, +} diff --git a/src/types/summary.rs b/src/types/summary.rs new file mode 100644 index 0000000..1878f15 --- /dev/null +++ b/src/types/summary.rs @@ -0,0 +1,57 @@ +//! Summary types for graph statistics + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use super::EntityBrief; + +/// Summary statistics +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct Summary { + #[serde(rename = "totalEntities")] + pub total_entities: usize, + #[serde(skip_serializing_if = "Option::is_none")] + pub entities: Option>, + #[serde(rename = "byStatus", skip_serializing_if = "Option::is_none")] + pub by_status: Option>, + #[serde(rename = "byType", skip_serializing_if = "Option::is_none")] + pub by_type: Option>, + #[serde(rename = "byPriority", skip_serializing_if = "Option::is_none")] + pub by_priority: Option>, +} + +impl Summary { + /// Create an empty summary + pub fn new(total_entities: usize) -> Self { + Self { + total_entities, + ..Default::default() + } + } + + /// Create a summary with entity briefs + pub fn with_entities(total_entities: usize, entities: Vec) -> Self { + Self { + total_entities, + entities: Some(entities), + ..Default::default() + } + } + + /// Create a summary with statistics + pub fn with_stats( + total_entities: usize, + by_type: HashMap, + by_status: Option>, + by_priority: Option>, + ) -> Self { + Self { + total_entities, + entities: None, + by_status, + by_type: Some(by_type), + by_priority, + } + } +} diff --git a/src/types/traversal.rs b/src/types/traversal.rs new file mode 100644 index 0000000..afcaf3a --- /dev/null +++ b/src/types/traversal.rs @@ -0,0 +1,74 @@ +//! Graph traversal types + +use serde::{Deserialize, Serialize}; + +use super::Entity; + +/// Path step for traverse query +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PathStep { + #[serde(rename = "relationType")] + pub relation_type: String, + pub direction: String, + #[serde(rename = "targetType")] + pub target_type: Option, +} + +impl PathStep { + /// Create a new path step + pub fn new(relation_type: String, direction: String) -> Self { + Self { + relation_type, + direction, + target_type: None, + } + } + + /// Create a new path step with target type filter + pub fn with_target_type( + relation_type: String, + direction: String, + target_type: String, + ) -> Self { + Self { + relation_type, + direction, + target_type: Some(target_type), + } + } +} + +/// Single path in traversal result +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TraversalPath { + pub nodes: Vec, + pub relations: Vec, +} + +impl TraversalPath { + /// Create a new traversal path + pub fn new(nodes: Vec, relations: Vec) -> Self { + Self { nodes, relations } + } +} + +/// Result of traverse query +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TraversalResult { + #[serde(rename = "startNode")] + pub start_node: String, + pub paths: Vec, + #[serde(rename = "endNodes")] + pub end_nodes: Vec, +} + +impl TraversalResult { + /// Create a new traversal result + pub fn new(start_node: String, paths: Vec, end_nodes: Vec) -> Self { + Self { + start_node, + paths, + end_nodes, + } + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..7511539 --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,7 @@ +//! Utility functions and helpers +//! +//! This module contains timestamp utilities and other helper functions. + +pub mod time; + +pub use time::{current_timestamp, days_to_ymd, get_current_time, get_month_name, get_weekday}; diff --git a/src/utils/time.rs b/src/utils/time.rs new file mode 100644 index 0000000..feee7b5 --- /dev/null +++ b/src/utils/time.rs @@ -0,0 +1,136 @@ +//! Time and timestamp utilities + +use serde_json::{json, Value}; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Get current Unix timestamp in seconds +pub fn current_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() +} + +/// Get current user from git config or OS environment +pub fn get_current_user() -> String { + use std::env; + use std::process::Command; + + // 1. Try Git Config (preferred for project context) + if let Ok(output) = Command::new("git").args(["config", "user.name"]).output() { + if output.status.success() { + let name = String::from_utf8_lossy(&output.stdout).trim().to_string(); + if !name.is_empty() { + return name; + } + } + } + + // 2. Try OS Environment Variable + env::var("USER") // Linux/Mac + .or_else(|_| env::var("USERNAME")) // Windows + .unwrap_or_else(|_| "anonymous".to_string()) +} + +/// Get current time information as JSON +pub fn get_current_time() -> Value { + let now = SystemTime::now(); + let duration = now.duration_since(UNIX_EPOCH).unwrap(); + let timestamp = duration.as_secs(); + let millis = duration.as_millis() as u64; + + // Calculate datetime components + let secs = timestamp as i64; + + // Days since epoch + let days = secs / 86400; + let remaining = secs % 86400; + + let hours = remaining / 3600; + let minutes = (remaining % 3600) / 60; + let seconds = remaining % 60; + + // Calculate year, month, day + let (year, month, day) = days_to_ymd(days); + + // Format ISO 8601 + let iso8601 = format!( + "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", + year, month, day, hours, minutes, seconds + ); + + // Format readable + let weekday = get_weekday(days); + let month_name = get_month_name(month); + let readable = format!( + "{}, {} {} {} {:02}:{:02}:{:02} UTC", + weekday, day, month_name, year, hours, minutes, seconds + ); + + json!({ + "timestamp": timestamp, + "timestamp_ms": millis, + "iso8601": iso8601, + "readable": readable, + "components": { + "year": year, + "month": month, + "day": day, + "hour": hours, + "minute": minutes, + "second": seconds, + "weekday": weekday + } + }) +} + +/// Convert days since epoch to year/month/day +pub fn days_to_ymd(days: i64) -> (i64, u32, u32) { + // Algorithm to convert days since epoch to year/month/day + let remaining_days = days + 719468; // Days from year 0 to 1970 + + let era = remaining_days / 146097; + let doe = remaining_days - era * 146097; + let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; + let year = yoe + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + let mp = (5 * doy + 2) / 153; + let day = (doy - (153 * mp + 2) / 5 + 1) as u32; + let month = if mp < 10 { mp + 3 } else { mp - 9 } as u32; + let year = if month <= 2 { year + 1 } else { year }; + + (year, month, day) +} + +/// Get weekday name from days since epoch +pub fn get_weekday(days: i64) -> &'static str { + match (days + 4) % 7 { + 0 => "Sunday", + 1 => "Monday", + 2 => "Tuesday", + 3 => "Wednesday", + 4 => "Thursday", + 5 => "Friday", + 6 => "Saturday", + _ => "Unknown", + } +} + +/// Get month name from month number +pub fn get_month_name(month: u32) -> &'static str { + match month { + 1 => "January", + 2 => "February", + 3 => "March", + 4 => "April", + 5 => "May", + 6 => "June", + 7 => "July", + 8 => "August", + 9 => "September", + 10 => "October", + 11 => "November", + 12 => "December", + _ => "Unknown", + } +} diff --git a/src/validation/mod.rs b/src/validation/mod.rs new file mode 100644 index 0000000..c32a49b --- /dev/null +++ b/src/validation/mod.rs @@ -0,0 +1,9 @@ +//! Type validation for entities and relations +//! +//! This module provides soft validation for standard entity and relation types. + +mod types; + +pub use types::{ + validate_entity_type, validate_relation_type, STANDARD_ENTITY_TYPES, STANDARD_RELATION_TYPES, +}; diff --git a/src/validation/types.rs b/src/validation/types.rs new file mode 100644 index 0000000..9dd1c61 --- /dev/null +++ b/src/validation/types.rs @@ -0,0 +1,95 @@ +//! Standard entity and relation types with validation + +/// Standard entity types for software project management +pub const STANDARD_ENTITY_TYPES: &[&str] = &[ + "Project", + "Module", + "Feature", + "Bug", + "Decision", + "Requirement", + "Milestone", + "Risk", + "Convention", + "Schema", + "Person", +]; + +/// Standard relation types for software project management +pub const STANDARD_RELATION_TYPES: &[&str] = &[ + "contains", + "implements", + "fixes", + "caused_by", + "depends_on", + "blocked_by", + "assigned_to", + "part_of", + "relates_to", + "supersedes", + "affects", + "requires", +]; + +/// Check if entity type is standard, return warning if not +pub fn validate_entity_type(entity_type: &str) -> Option { + if STANDARD_ENTITY_TYPES + .iter() + .any(|&t| t.eq_ignore_ascii_case(entity_type)) + { + None + } else { + Some(format!( + "⚠️ Non-standard entityType '{}'. Recommended: {:?}", + entity_type, STANDARD_ENTITY_TYPES + )) + } +} + +/// Check if relation type is standard, return warning if not +pub fn validate_relation_type(relation_type: &str) -> Option { + if STANDARD_RELATION_TYPES + .iter() + .any(|&t| t.eq_ignore_ascii_case(relation_type)) + { + None + } else { + Some(format!( + "⚠️ Non-standard relationType '{}'. Recommended: {:?}", + relation_type, STANDARD_RELATION_TYPES + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_standard_entity_type() { + assert!(validate_entity_type("Project").is_none()); + assert!(validate_entity_type("module").is_none()); // case insensitive + assert!(validate_entity_type("Person").is_none()); + } + + #[test] + fn test_validate_non_standard_entity_type() { + let warning = validate_entity_type("CustomType"); + assert!(warning.is_some()); + assert!(warning.unwrap().contains("Non-standard entityType")); + } + + #[test] + fn test_validate_standard_relation_type() { + assert!(validate_relation_type("contains").is_none()); + assert!(validate_relation_type("DEPENDS_ON").is_none()); // case insensitive + assert!(validate_relation_type("implements").is_none()); + } + + #[test] + fn test_validate_non_standard_relation_type() { + let warning = validate_relation_type("custom_relation"); + assert!(warning.is_some()); + assert!(warning.unwrap().contains("Non-standard relationType")); + } +} diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs new file mode 100644 index 0000000..9f94a8a --- /dev/null +++ b/tests/integration_tests.rs @@ -0,0 +1,361 @@ +//! Integration tests for Memory Graph MCP Server + +use std::fs; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::thread; + +use memory_graph::knowledge_base::KnowledgeBase; +use memory_graph::types::{Entity, Observation, Relation}; + +static TEST_COUNTER: AtomicU64 = AtomicU64::new(0); + +fn setup_test_kb() -> (Arc, String) { + let id = TEST_COUNTER.fetch_add(1, Ordering::SeqCst); + let temp_file = format!( + "test_memory_{}_{}.jsonl", + std::process::id(), + id + ); + + // Create a test file path + let kb = Arc::new(KnowledgeBase::with_file_path(temp_file.clone())); + (kb, temp_file) +} + +fn cleanup(file_path: &str) { + let _ = fs::remove_file(file_path); +} + +#[test] +fn test_create_entities() { + let (kb, temp_file) = setup_test_kb(); + + let entities = vec![ + Entity { + name: "Alice".to_string(), + entity_type: "Person".to_string(), + observations: vec!["Lives in NYC".to_string()], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }, + Entity { + name: "Bob".to_string(), + entity_type: "Person".to_string(), + observations: vec![], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }, + ]; + + let created = kb.create_entities(entities).unwrap(); + assert_eq!(created.len(), 2); + + let graph = kb.read_graph(None, None).unwrap(); + assert_eq!(graph.entities.len(), 2); + + cleanup(&temp_file); +} + +#[test] +fn test_create_relations() { + let (kb, temp_file) = setup_test_kb(); + + // First create entities + let entities = vec![ + Entity { + name: "Alice".to_string(), + entity_type: "Person".to_string(), + observations: vec![], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }, + Entity { + name: "Bob".to_string(), + entity_type: "Person".to_string(), + observations: vec![], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }, + ]; + kb.create_entities(entities).unwrap(); + + // Then create relations + let relations = vec![Relation { + from: "Alice".to_string(), + to: "Bob".to_string(), + relation_type: "knows".to_string(), + created_by: String::new(), + created_at: 0, + valid_from: None, + valid_to: None, + }]; + + let created = kb.create_relations(relations).unwrap(); + assert_eq!(created.len(), 1); + + let graph = kb.read_graph(None, None).unwrap(); + assert_eq!(graph.relations.len(), 1); + + cleanup(&temp_file); +} + +#[test] +fn test_search_nodes() { + let (kb, temp_file) = setup_test_kb(); + + let entities = vec![ + Entity { + name: "Alice".to_string(), + entity_type: "Person".to_string(), + observations: vec!["Software Engineer".to_string()], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }, + Entity { + name: "Bob".to_string(), + entity_type: "Person".to_string(), + observations: vec!["Doctor".to_string()], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }, + ]; + kb.create_entities(entities).unwrap(); + + let result = kb.search_nodes("Alice", None, true).unwrap(); + assert_eq!(result.entities.len(), 1); + assert_eq!(result.entities[0].name, "Alice"); + + let result = kb.search_nodes("Engineer", None, true).unwrap(); + assert_eq!(result.entities.len(), 1); + assert_eq!(result.entities[0].name, "Alice"); + + cleanup(&temp_file); +} + +#[test] +fn test_delete_entities() { + let (kb, temp_file) = setup_test_kb(); + + let entities = vec![ + Entity { + name: "Alice".to_string(), + entity_type: "Person".to_string(), + observations: vec![], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }, + Entity { + name: "Bob".to_string(), + entity_type: "Person".to_string(), + observations: vec![], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }, + ]; + kb.create_entities(entities).unwrap(); + + kb.delete_entities(vec!["Alice".to_string()]).unwrap(); + + let graph = kb.read_graph(None, None).unwrap(); + assert_eq!(graph.entities.len(), 1); + assert_eq!(graph.entities[0].name, "Bob"); + + cleanup(&temp_file); +} + +#[test] +fn test_concurrent_access() { + let (kb, temp_file) = setup_test_kb(); + + // Spawn multiple threads simulating concurrent agents + let mut handles = vec![]; + + for i in 0..10 { + let kb_clone = Arc::clone(&kb); + let handle = thread::spawn(move || { + // Each "agent" creates an entity + let entity = Entity { + name: format!("Agent{}", i), + entity_type: "Person".to_string(), + observations: vec![format!("Created by thread {}", i)], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }; + kb_clone.create_entities(vec![entity]).unwrap(); + + // Each agent also reads the graph + let graph = kb_clone.read_graph(None, None).unwrap(); + assert!(graph.entities.len() >= 1); + + // Each agent adds an observation + let obs = Observation { + entity_name: format!("Agent{}", i), + contents: vec![format!("Observation from thread {}", i)], + }; + let _ = kb_clone.add_observations(vec![obs]); + }); + handles.push(handle); + } + + // Wait for all threads to complete + for handle in handles { + handle.join().expect("Thread panicked"); + } + + // Verify final state + let graph = kb.read_graph(None, None).unwrap(); + assert_eq!(graph.entities.len(), 10, "All 10 entities should exist"); + + // Verify all entities have observations + for entity in &graph.entities { + assert!( + entity.observations.len() >= 1, + "Entity should have observations" + ); + } + + cleanup(&temp_file); +} + +#[test] +fn test_concurrent_read_write() { + let (kb, temp_file) = setup_test_kb(); + + // Pre-populate with some entities + for i in 0..5 { + let entity = Entity { + name: format!("Entity{}", i), + entity_type: "Module".to_string(), + observations: vec![], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }; + kb.create_entities(vec![entity]).unwrap(); + } + + let mut handles = vec![]; + + // 5 reader threads + for _ in 0..5 { + let kb_clone = Arc::clone(&kb); + let handle = thread::spawn(move || { + for _ in 0..100 { + let graph = kb_clone.read_graph(None, None).unwrap(); + assert!(graph.entities.len() >= 5); + let _ = kb_clone.search_nodes("Entity", None, true); + } + }); + handles.push(handle); + } + + // 3 writer threads + for i in 0..3 { + let kb_clone = Arc::clone(&kb); + let handle = thread::spawn(move || { + for j in 0..10 { + let obs = Observation { + entity_name: format!("Entity{}", i), + contents: vec![format!("Update {} from writer {}", j, i)], + }; + let _ = kb_clone.add_observations(vec![obs]); + } + }); + handles.push(handle); + } + + // Wait for all threads + for handle in handles { + handle.join().expect("Thread panicked"); + } + + // Verify no data corruption + let graph = kb.read_graph(None, None).unwrap(); + assert_eq!( + graph.entities.len(), + 5, + "Original entities should still exist" + ); + + cleanup(&temp_file); +} + +#[test] +fn test_semantic_search_synonyms() { + let (kb, temp_file) = setup_test_kb(); + + let entities = vec![Entity { + name: "Alice".to_string(), + entity_type: "Person".to_string(), + observations: vec!["Software developer working on backend".to_string()], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }]; + kb.create_entities(entities).unwrap(); + + // Search with synonym "coder" should find "developer" + let result = kb.search_nodes("coder", None, true).unwrap(); + assert_eq!(result.entities.len(), 1); + assert_eq!(result.entities[0].name, "Alice"); + + // Search with synonym "programmer" should also find "developer" + let result = kb.search_nodes("programmer", None, true).unwrap(); + assert_eq!(result.entities.len(), 1); + + cleanup(&temp_file); +} + +#[test] +fn test_pagination() { + let (kb, temp_file) = setup_test_kb(); + + // Create 20 entities + for i in 0..20 { + let entity = Entity { + name: format!("Entity{:02}", i), + entity_type: "Module".to_string(), + observations: vec![], + created_by: String::new(), + updated_by: String::new(), + created_at: 0, + updated_at: 0, + }; + kb.create_entities(vec![entity]).unwrap(); + } + + // Test limit + let result = kb.read_graph(Some(5), None).unwrap(); + assert_eq!(result.entities.len(), 5); + + // Test offset + let result = kb.read_graph(Some(5), Some(10)).unwrap(); + assert_eq!(result.entities.len(), 5); + + // Test beyond range + let result = kb.read_graph(Some(100), Some(50)).unwrap(); + assert_eq!(result.entities.len(), 0); + + cleanup(&temp_file); +}