diff --git a/README.md b/README.md index ddf5e25..557128a 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,155 @@ pip install toondb-client > ℹ️ **About the Binaries**: This Python SDK packages pre-compiled binaries from the [main ToonDB repository](https://github.com/toondb/toondb). Each wheel contains platform-specific executables (`toondb-bulk`, `toondb-server`, `toondb-grpc-server`) and native FFI libraries. See [RELEASE.md](RELEASE.md) for details on the release process. -## What's New in Latest Release +## What's New in v0.3.3 + +### πŸ•ΈοΈ Graph Overlay for Agent Memory +Build lightweight graph structures on top of ToonDB's KV storage for agent memory: + +```python +from toondb import Database, GraphOverlay + +db = Database.open("./agent_db") +graph = GraphOverlay(db, namespace="agent_memory") + +# Add nodes (entities, concepts, events) +graph.add_node("user_alice", "person", {"name": "Alice", "role": "developer"}) +graph.add_node("conv_123", "conversation", {"topic": "ToonDB features"}) +graph.add_node("action_456", "action", {"type": "code_commit", "status": "success"}) + +# Add edges (relationships, causality, references) +graph.add_edge("user_alice", "started", "conv_123", {"timestamp": "2026-01-05"}) +graph.add_edge("conv_123", "triggered", "action_456", {"reason": "user request"}) + +# Retrieve nodes and edges +node = graph.get_node("user_alice") +edges = graph.get_edges("user_alice", edge_type="started") + +# Graph traversal +visited = graph.bfs_traversal("user_alice", max_depth=3) # BFS from Alice +path = graph.shortest_path("user_alice", "action_456") # Find connection + +# Get neighbors +neighbors = graph.get_neighbors("conv_123", direction="both") + +# Extract subgraph +subgraph = graph.get_subgraph(["user_alice", "conv_123", "action_456"]) +``` + +**Use Cases:** +- Agent conversation history with causal chains +- Entity relationship tracking across sessions +- Action dependency graphs for planning +- Knowledge graph construction + +### 
πŸ›‘οΈ Policy & Safety Hooks +Enforce safety policies on agent operations with pre/post triggers: + +```python +from toondb import Database, PolicyEngine, PolicyAction + +db = Database.open("./agent_data") +policy = PolicyEngine(db) + +# Block writes to system keys from agents +@policy.before_write("system/*") +def block_system_writes(key, value, context): + if context.get("agent_id"): + return PolicyAction.DENY + return PolicyAction.ALLOW + +# Redact sensitive data on read +@policy.after_read("users/*/email") +def redact_emails(key, value, context): + if context.get("redact_pii"): + return b"[REDACTED]" + return value + +# Rate limit writes per agent +policy.add_rate_limit("write", max_per_minute=100, scope="agent_id") + +# Enable audit logging +policy.enable_audit() + +# Use policy-wrapped operations +policy.put(b"users/alice", b"data", context={"agent_id": "agent_001"}) +``` + +### πŸ”€ Multi-Agent Tool Routing +Route tool calls to specialized agents with automatic failover: + +```python +from toondb import Database, ToolDispatcher, ToolCategory, RoutingStrategy + +db = Database.open("./agent_data") +dispatcher = ToolDispatcher(db) + +# Register agents with capabilities +dispatcher.register_local_agent( + "code_agent", + capabilities=[ToolCategory.CODE, ToolCategory.GIT], + handler=lambda tool, args: {"result": f"Processed {tool}"}, +) + +dispatcher.register_remote_agent( + "search_agent", + capabilities=[ToolCategory.SEARCH], + endpoint="http://localhost:8001/invoke", +) + +# Register tools +dispatcher.register_tool( + name="search_code", + description="Search codebase", + category=ToolCategory.CODE, +) + +# Invoke with automatic routing (priority, round-robin, fastest, etc.) 
+result = dispatcher.invoke("search_code", {"query": "auth"}, session_id="sess_001") +print(f"Routed to: {result.agent_id}, Success: {result.success}") +``` + +### πŸ•ΈοΈ Graph Overlay +Lightweight graph layer for agent memory relationships: + +```python +from toondb import Database, GraphOverlay, TraversalOrder + +db = Database.open("./agent_data") +graph = GraphOverlay(db) + +# Add nodes (entities, concepts, events) +graph.add_node("user:alice", node_type="user", properties={"role": "admin"}) +graph.add_node("project:toondb", node_type="project", properties={"status": "active"}) + +# Add relationships +graph.add_edge("user:alice", "project:toondb", edge_type="owns", properties={"since": "2024"}) + +# Traverse graph (BFS/DFS) +related = graph.bfs("user:alice", max_depth=2, edge_filter=lambda e: e.edge_type == "owns") + +# Find shortest path +path = graph.shortest_path("user:alice", "project:toondb") +``` + +### πŸ”— Unified Connection API +Single entry point with auto-detection: + +```python +import toondb + +# Auto-detects embedded mode from path +db = toondb.connect("./my_database") + +# Auto-detects IPC mode from socket +db = toondb.connect("/tmp/toondb.sock") + +# Auto-detects gRPC mode from host:port +db = toondb.connect("localhost:50051") + +# Explicit mode +db = toondb.connect("./data", mode="embedded", config={"sync_mode": "full"}) +``` ### 🎯 Namespace Isolation Logical database namespaces for true multi-tenancy without key prefixing: diff --git a/pyproject.toml b/pyproject.toml index 89c1542..ef3c510 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "toondb-client" -version = "0.3.2" +version = "0.3.3" description = "ToonDB is an AI-native database with token-optimized output, O(|path|) lookups, built-in vector search, and durable transactions." 
readme = "README.md" license = {text = "Apache-2.0"} diff --git a/src/toondb/__init__.py b/src/toondb/__init__.py index 93e1857..ed746f9 100644 --- a/src/toondb/__init__.py +++ b/src/toondb/__init__.py @@ -68,6 +68,44 @@ estimate_tokens, split_by_tokens, ) +from .graph import ( + # Graph Overlay (Task 10) + GraphOverlay, + GraphNode, + GraphEdge, + TraversalOrder, +) +from .policy import ( + # Policy & Safety Hooks (Task 11) + PolicyEngine, + PolicyAction, + PolicyTrigger, + PolicyResult, + PolicyContext, + PolicyHandler, + PatternPolicy, + RateLimiter, + PolicyViolation, + # Built-in policy helpers + deny_all, + allow_all, + require_agent_id, + redact_value, + log_and_allow, +) +from .routing import ( + # Tool Routing (Task 12) + ToolRouter, + AgentRegistry, + ToolDispatcher, + Tool, + Agent, + ToolCategory, + RoutingStrategy, + AgentStatus, + RouteResult, + RoutingContext, +) # Vector search (optional - requires libtoondb_index) try: @@ -104,8 +142,144 @@ track_batch_insert = None is_analytics_disabled = lambda: True -__version__ = "0.3.1" +__version__ = "0.3.3" + + +# ============================================================================= +# Unified Connection API (Task 9: Standardize Deployment Modes) +# ============================================================================= + +from enum import Enum +from typing import Optional, Union + + +class ConnectionMode(Enum): + """ToonDB connection mode.""" + EMBEDDED = "embedded" # Direct FFI to Rust library + IPC = "ipc" # Unix socket to local server + GRPC = "grpc" # gRPC to remote server + + +def connect( + path_or_url: str, + mode: Optional[Union[str, ConnectionMode]] = None, + config: Optional[dict] = None, +) -> Union[Database, IpcClient]: + """ + Connect to ToonDB with automatic mode detection. + + This is the unified entry point for all ToonDB connection modes. 
+ If mode is not specified, it auto-detects based on the path/URL: + + - Embedded: File paths (./data, /tmp/db, ~/toondb) + - IPC: Unix socket paths (/tmp/toondb.sock, unix://...) + - gRPC: URLs with grpc:// scheme or host:port format + + Args: + path_or_url: Database path, socket path, or gRPC URL + mode: Optional explicit mode ('embedded', 'ipc', 'grpc' or ConnectionMode enum) + config: Optional configuration dict (passed to underlying client) + + Returns: + Database, IpcClient, or GrpcClient depending on mode + + Examples: + # Embedded mode (auto-detected from file path) + db = toondb.connect("./my_database") + db.put(b"key", b"value") + + # IPC mode (auto-detected from .sock extension) + db = toondb.connect("/tmp/toondb.sock") + + # gRPC mode (auto-detected from host:port) + db = toondb.connect("localhost:50051") + + # Explicit mode + db = toondb.connect("./data", mode="embedded", config={ + "sync_mode": "full", + "index_policy": "scan_optimized", + }) + + # Using enum + db = toondb.connect("localhost:50051", mode=toondb.ConnectionMode.GRPC) + """ + # Normalize mode to enum + if mode is None: + detected_mode = _detect_mode(path_or_url) + elif isinstance(mode, str): + try: + detected_mode = ConnectionMode(mode.lower()) + except ValueError: + raise ValueError( + f"Invalid mode '{mode}'. 
Valid modes: embedded, ipc, grpc" + ) + else: + detected_mode = mode + + # Create appropriate client + if detected_mode == ConnectionMode.EMBEDDED: + return Database.open(path_or_url, config=config) + + elif detected_mode == ConnectionMode.IPC: + socket_path = path_or_url + if socket_path.startswith("unix://"): + socket_path = socket_path[7:] # Strip unix:// prefix + return IpcClient(socket_path) + + elif detected_mode == ConnectionMode.GRPC: + try: + from .grpc_client import GrpcClient + url = path_or_url + if url.startswith("grpc://"): + url = url[7:] # Strip grpc:// prefix + return GrpcClient(url) + except ImportError: + raise ImportError( + "gRPC mode requires grpc dependencies. " + "Install with: pip install toondb[grpc]" + ) + + else: + raise ValueError(f"Unknown connection mode: {detected_mode}") + + +def _detect_mode(path_or_url: str) -> ConnectionMode: + """Auto-detect connection mode from path/URL format.""" + import os + + # Explicit scheme detection + if path_or_url.startswith("grpc://"): + return ConnectionMode.GRPC + if path_or_url.startswith("unix://"): + return ConnectionMode.IPC + + # Socket file detection + if path_or_url.endswith(".sock"): + return ConnectionMode.IPC + if "/tmp/" in path_or_url and "sock" in path_or_url.lower(): + return ConnectionMode.IPC + + # Host:port detection (gRPC) + if ":" in path_or_url: + parts = path_or_url.rsplit(":", 1) + if len(parts) == 2: + try: + port = int(parts[1]) + if 1 <= port <= 65535: + # Looks like host:port - probably gRPC + return ConnectionMode.GRPC + except ValueError: + pass + + # Default to embedded for file paths + return ConnectionMode.EMBEDDED + + __all__ = [ + # Unified API (Task 9) + "connect", + "ConnectionMode", + # Core "Database", "Transaction", @@ -129,6 +303,40 @@ "SearchResult", "SearchResults", + # Graph Overlay (Task 10) + "GraphOverlay", + "GraphNode", + "GraphEdge", + "TraversalOrder", + + # Policy & Safety Hooks (Task 11) + "PolicyEngine", + "PolicyAction", + "PolicyTrigger", + 
"PolicyResult", + "PolicyContext", + "PolicyHandler", + "PatternPolicy", + "RateLimiter", + "PolicyViolation", + "deny_all", + "allow_all", + "require_agent_id", + "redact_value", + "log_and_allow", + + # Tool Routing (Task 12) + "ToolRouter", + "AgentRegistry", + "ToolDispatcher", + "Tool", + "Agent", + "ToolCategory", + "RoutingStrategy", + "AgentStatus", + "RouteResult", + "RoutingContext", + # ContextQuery (Task 12) "ContextQuery", "ContextResult", diff --git a/src/toondb/context.py b/src/toondb/context.py index 9b23421..9cfe417 100644 --- a/src/toondb/context.py +++ b/src/toondb/context.py @@ -699,3 +699,476 @@ def split_by_tokens( chunks.append(" ".join(current_chunk)) return chunks + + +# ============================================================================ +# CONTEXT SELECT: Production-Ready Query Builder (Aligned with Rust Model) +# ============================================================================ + +class SectionKind(str, Enum): + """Section content kind - aligned with Rust SectionContent enum.""" + GET = "get" # GET path expression + LAST = "last" # LAST N FROM table + SEARCH = "search" # SEARCH by similarity + SELECT = "select" # Standard SQL subquery + LITERAL = "literal" # Literal value + VARIABLE = "variable" # Variable reference + TOOL_REGISTRY = "tool_registry" # Available tools + TOOL_CALLS = "tool_calls" # Recent tool calls + + +class TruncationPolicy(str, Enum): + """Truncation policy when budget is exceeded.""" + TAIL_DROP = "tail_drop" # Drop from tail (keep head) + HEAD_DROP = "head_drop" # Drop from head (keep tail) + PROPORTIONAL = "proportional" # Proportional truncation + FAIL = "fail" # Fail on budget exceeded + + +@dataclass +class ContextSectionConfig: + """ + Configuration for a single context section. + + Aligned with Rust ContextSection model for production consistency. 
+ + Example: + # Get user profile + ContextSectionConfig( + name="user", + kind=SectionKind.GET, + priority=0, + path="users/alice/profile", + fields=["name", "preferences"] + ) + + # Last 10 tool calls + ContextSectionConfig( + name="history", + kind=SectionKind.LAST, + priority=1, + table="tool_calls", + count=10, + where={"status": "success"} + ) + + # Vector search + ContextSectionConfig( + name="knowledge", + kind=SectionKind.SEARCH, + priority=2, + collection="docs", + query="machine learning", + top_k=5 + ) + """ + name: str + kind: SectionKind + priority: int = 0 # Lower = higher priority + + # GET options + path: Optional[str] = None + fields: Optional[List[str]] = None + + # LAST/SELECT options + table: Optional[str] = None + count: Optional[int] = None + columns: Optional[List[str]] = None + where: Optional[Dict[str, Any]] = None + limit: Optional[int] = None + + # SEARCH options + collection: Optional[str] = None + query: Optional[str] = None + vector: Optional[List[float]] = None + top_k: int = 5 + min_score: Optional[float] = None + + # LITERAL options + text: Optional[str] = None + + # VARIABLE options + variable_name: Optional[str] = None + + # TOOL_REGISTRY options + include_tools: Optional[List[str]] = None + exclude_tools: Optional[List[str]] = None + include_schema: bool = True + + # TOOL_CALLS options + tool_filter: Optional[str] = None + status_filter: Optional[str] = None + include_outputs: bool = True + + +@dataclass +class ContextSelectResult: + """Result from CONTEXT SELECT execution.""" + sections: List[Dict[str, Any]] + total_tokens: int + budget_tokens: int + truncated: bool = False + truncated_sections: List[str] = field(default_factory=list) + provenance: Dict[str, Any] = field(default_factory=dict) + + def as_text(self, include_headers: bool = True) -> str: + """Format as text for LLM prompt.""" + parts = [] + for section in self.sections: + if include_headers: + parts.append(f"## {section['name']}") + if "content" in section: + 
content = section["content"] + if isinstance(content, list): + parts.append("\n".join(str(item) for item in content)) + elif isinstance(content, dict): + parts.append(json.dumps(content, indent=2)) + else: + parts.append(str(content)) + return "\n\n".join(parts) + + def as_json(self) -> str: + """Format as JSON.""" + return json.dumps({ + "sections": self.sections, + "total_tokens": self.total_tokens, + "budget_tokens": self.budget_tokens, + "truncated": self.truncated, + "provenance": self.provenance, + }, indent=2) + + +class ContextSelect: + """ + Production-ready CONTEXT SELECT query builder. + + Aligned with Rust ContextSelectQuery for consistent semantics across + embedded, IPC, and MCP deployment modes. + + Example: + from toondb import Database + from toondb.context import ContextSelect, SectionKind, ContextSectionConfig + + db = Database.open("./my_db") + + result = ( + ContextSelect(db) + .add_section(ContextSectionConfig( + name="user", + kind=SectionKind.GET, + priority=0, + path="users/alice/profile" + )) + .add_section(ContextSectionConfig( + name="history", + kind=SectionKind.LAST, + priority=1, + table="events", + count=10 + )) + .add_section(ContextSectionConfig( + name="knowledge", + kind=SectionKind.SEARCH, + priority=2, + collection="docs", + query="machine learning", + top_k=5 + )) + .with_token_budget(4096) + .execute() + ) + + prompt = f'''Context: + {result.as_text()} + + Question: What is machine learning? + ''' + """ + + def __init__( + self, + db: "Database", + token_estimator: Optional[TokenEstimator] = None, + ): + """ + Initialize CONTEXT SELECT builder. 
+ + Args: + db: ToonDB Database instance + token_estimator: Optional token estimator (default: heuristic) + """ + from .database import Database + if not isinstance(db, Database): + raise TypeError("db must be a Database instance") + + self._db = db + self._estimator = token_estimator or TokenEstimator() + self._sections: List[ContextSectionConfig] = [] + self._token_budget: int = 4096 + self._truncation: TruncationPolicy = TruncationPolicy.TAIL_DROP + self._include_headers: bool = True + self._output_format: str = "text" + + def add_section(self, section: ContextSectionConfig) -> "ContextSelect": + """Add a section to the context query.""" + self._sections.append(section) + return self + + def with_token_budget(self, tokens: int) -> "ContextSelect": + """Set token budget for entire context.""" + self._token_budget = tokens + return self + + def with_truncation(self, policy: TruncationPolicy) -> "ContextSelect": + """Set truncation policy when budget is exceeded.""" + self._truncation = policy + return self + + def with_headers(self, include: bool = True) -> "ContextSelect": + """Include section headers in output.""" + self._include_headers = include + return self + + def execute(self) -> ContextSelectResult: + """ + Execute the CONTEXT SELECT query. + + Sections are processed in priority order (lower = higher priority). + Token budget is enforced using greedy allocation. + + Returns: + ContextSelectResult with assembled context + """ + if not self._sections: + raise ValueError("No sections added. 
Use add_section().") + + # Sort by priority + sorted_sections = sorted(self._sections, key=lambda s: s.priority) + + assembled_sections = [] + total_tokens = 0 + truncated = False + truncated_sections = [] + provenance = {} + + for section in sorted_sections: + # Execute section + section_data = self._execute_section(section) + + # Estimate tokens + section_text = json.dumps(section_data) if isinstance(section_data, (dict, list)) else str(section_data) + section_tokens = self._estimator.count(section_text) + + # Check budget + if total_tokens + section_tokens > self._token_budget: + if self._truncation == TruncationPolicy.FAIL: + raise ValueError(f"Token budget exceeded at section '{section.name}'") + elif self._truncation == TruncationPolicy.TAIL_DROP: + # Try to fit partial content + remaining = self._token_budget - total_tokens + if remaining > 50: # Minimum useful content + section_data = self._truncate_section(section_data, remaining) + section_tokens = remaining + truncated = True + truncated_sections.append(section.name) + else: + truncated_sections.append(section.name) + continue + else: + truncated_sections.append(section.name) + continue + + assembled_sections.append({ + "name": section.name, + "priority": section.priority, + "kind": section.kind.value, + "content": section_data, + "tokens": section_tokens, + }) + total_tokens += section_tokens + + # Track provenance + provenance[section.name] = { + "kind": section.kind.value, + "tokens": section_tokens, + "source": self._get_provenance_source(section), + } + + return ContextSelectResult( + sections=assembled_sections, + total_tokens=total_tokens, + budget_tokens=self._token_budget, + truncated=truncated, + truncated_sections=truncated_sections, + provenance=provenance, + ) + + def _execute_section(self, section: ContextSectionConfig) -> Any: + """Execute a single section and return its content.""" + if section.kind == SectionKind.GET: + return self._exec_get(section) + elif section.kind == 
SectionKind.LAST: + return self._exec_last(section) + elif section.kind == SectionKind.SELECT: + return self._exec_select(section) + elif section.kind == SectionKind.SEARCH: + return self._exec_search(section) + elif section.kind == SectionKind.LITERAL: + return section.text or "" + elif section.kind == SectionKind.VARIABLE: + return f"${{{section.variable_name}}}" # Placeholder for variable expansion + elif section.kind == SectionKind.TOOL_REGISTRY: + return self._exec_tool_registry(section) + elif section.kind == SectionKind.TOOL_CALLS: + return self._exec_tool_calls(section) + else: + return None + + def _exec_get(self, section: ContextSectionConfig) -> Any: + """Execute GET section.""" + if not section.path: + return None + + data = self._db.get_path(section.path) + if data is None: + return None + + try: + parsed = json.loads(data.decode("utf-8")) + # Project fields if specified + if section.fields and isinstance(parsed, dict): + return {k: parsed.get(k) for k in section.fields if k in parsed} + return parsed + except (json.JSONDecodeError, UnicodeDecodeError): + return data.decode("utf-8", errors="replace") + + def _exec_last(self, section: ContextSectionConfig) -> List[Any]: + """Execute LAST section.""" + if not section.table: + return [] + + prefix = f"_sql/tables/{section.table}/rows/".encode() + results = [] + count = section.count or 10 + + for key, value in self._db.scan_prefix(prefix): + try: + row = json.loads(value.decode("utf-8")) + # Apply WHERE filter + if section.where: + if not self._matches_where(row, section.where): + continue + results.append(row) + if len(results) >= count: + break + except (json.JSONDecodeError, UnicodeDecodeError): + continue + + return results + + def _exec_select(self, section: ContextSectionConfig) -> List[Any]: + """Execute SELECT section.""" + # Use SQL engine if available + if section.table: + columns = ",".join(section.columns or ["*"]) + sql = f"SELECT {columns} FROM {section.table}" + if section.where: + 
conditions = " AND ".join(f"{k}='{v}'" for k, v in section.where.items()) + sql += f" WHERE {conditions}" + if section.limit: + sql += f" LIMIT {section.limit}" + + try: + result = self._db.execute(sql) + return result.rows + except Exception: + return [] + return [] + + def _exec_search(self, section: ContextSectionConfig) -> List[Any]: + """Execute SEARCH section.""" + # Vector search via namespace/collection + # This is a placeholder - real implementation would use vector index + if section.collection: + # Try to get collection from namespace + try: + ns = self._db.namespace("default") + coll = ns.collection(section.collection) + from .namespace import SearchRequest + + request = SearchRequest( + text_query=section.query, + k=section.top_k, + ) + results = coll.search(request) + return [{"id": r.id, "score": r.score, "metadata": r.metadata} for r in results] + except Exception: + pass + return [] + + def _exec_tool_registry(self, section: ContextSectionConfig) -> List[Dict[str, Any]]: + """Execute TOOL_REGISTRY section.""" + # Placeholder - would query actual tool registry + return [{"name": "example_tool", "description": "Example tool"}] + + def _exec_tool_calls(self, section: ContextSectionConfig) -> List[Any]: + """Execute TOOL_CALLS section.""" + prefix = b"_tool_calls/" + results = [] + count = section.count or 10 + + for key, value in self._db.scan_prefix(prefix): + try: + call = json.loads(value.decode("utf-8")) + if section.tool_filter and call.get("tool") != section.tool_filter: + continue + if section.status_filter and call.get("status") != section.status_filter: + continue + if not section.include_outputs: + call.pop("output", None) + results.append(call) + if len(results) >= count: + break + except (json.JSONDecodeError, UnicodeDecodeError): + continue + + return results + + def _matches_where(self, row: Dict[str, Any], where: Dict[str, Any]) -> bool: + """Check if row matches WHERE conditions.""" + for key, value in where.items(): + if row.get(key) 
!= value: + return False + return True + + def _truncate_section(self, data: Any, max_tokens: int) -> Any: + """Truncate section content to fit token budget.""" + if isinstance(data, str): + # Truncate string + chars = max_tokens * 4 # Heuristic + return data[:chars] + "..." if len(data) > chars else data + elif isinstance(data, list): + # Truncate list + truncated = [] + tokens = 0 + for item in data: + item_str = json.dumps(item) if isinstance(item, dict) else str(item) + item_tokens = self._estimator.count(item_str) + if tokens + item_tokens > max_tokens: + break + truncated.append(item) + tokens += item_tokens + return truncated + return data + + def _get_provenance_source(self, section: ContextSectionConfig) -> str: + """Get provenance source string for a section.""" + if section.kind == SectionKind.GET: + return f"path:{section.path}" + elif section.kind == SectionKind.LAST: + return f"table:{section.table}" + elif section.kind == SectionKind.SEARCH: + return f"collection:{section.collection}" + elif section.kind == SectionKind.SELECT: + return f"sql:{section.table}" + return section.kind.value diff --git a/src/toondb/database.py b/src/toondb/database.py index ed5c616..b75fe34 100644 --- a/src/toondb/database.py +++ b/src/toondb/database.py @@ -23,7 +23,7 @@ import sys import ctypes import warnings -from typing import Optional, Dict, List +from typing import Optional, Dict, List, Union from contextlib import contextmanager from .errors import ( DatabaseError, @@ -130,6 +130,33 @@ class C_TxnHandle(ctypes.Structure): ] +class C_CommitResult(ctypes.Structure): + """Commit result with HLC-backed monotonic timestamp.""" + _fields_ = [ + ("commit_ts", ctypes.c_uint64), # HLC timestamp, 0 on error + ("error_code", ctypes.c_int32), # 0=success, -1=error, -2=SSI conflict + ] + + +class C_DatabaseConfig(ctypes.Structure): + """Database configuration passed to toondb_open_with_config. + + Configuration options control durability, performance, and indexing behavior. 
+ Fields with _set suffix indicate whether the corresponding value was explicitly set. + """ + _fields_ = [ + ("wal_enabled", ctypes.c_bool), # Enable WAL for durability + ("wal_enabled_set", ctypes.c_bool), # Whether wal_enabled was set + ("sync_mode", ctypes.c_uint8), # 0=OFF, 1=NORMAL, 2=FULL + ("sync_mode_set", ctypes.c_bool), # Whether sync_mode was set + ("memtable_size_bytes", ctypes.c_uint64), # Memtable size (0=default 64MB) + ("group_commit", ctypes.c_bool), # Enable group commit + ("group_commit_set", ctypes.c_bool), # Whether group_commit was set + ("default_index_policy", ctypes.c_uint8), # 0=WriteOptimized, 1=Balanced, 2=ScanOptimized, 3=AppendOnly + ("default_index_policy_set", ctypes.c_bool), # Whether index policy was set + ] + + class C_StorageStats(ctypes.Structure): _fields_ = [ ("memtable_size_bytes", ctypes.c_uint64), @@ -163,6 +190,10 @@ def _setup_bindings(cls): lib.toondb_open.argtypes = [ctypes.c_char_p] lib.toondb_open.restype = ctypes.c_void_p + # toondb_open_with_config(path: *const c_char, config: C_DatabaseConfig) -> *mut DatabasePtr + lib.toondb_open_with_config.argtypes = [ctypes.c_char_p, C_DatabaseConfig] + lib.toondb_open_with_config.restype = ctypes.c_void_p + # toondb_close(ptr: *mut DatabasePtr) lib.toondb_close.argtypes = [ctypes.c_void_p] lib.toondb_close.restype = None @@ -172,9 +203,10 @@ def _setup_bindings(cls): lib.toondb_begin_txn.argtypes = [ctypes.c_void_p] lib.toondb_begin_txn.restype = C_TxnHandle - # toondb_commit(ptr: *mut DatabasePtr, handle: C_TxnHandle) -> c_int + # toondb_commit(ptr: *mut DatabasePtr, handle: C_TxnHandle) -> C_CommitResult + # Returns HLC-backed monotonic commit timestamp for MVCC observability lib.toondb_commit.argtypes = [ctypes.c_void_p, C_TxnHandle] - lib.toondb_commit.restype = ctypes.c_int + lib.toondb_commit.restype = C_CommitResult # toondb_abort(ptr: *mut DatabasePtr, handle: C_TxnHandle) -> c_int lib.toondb_abort.argtypes = [ctypes.c_void_p, C_TxnHandle] @@ -273,6 +305,24 @@ def 
_setup_bindings(cls): # toondb_stats(ptr) -> C_StorageStats lib.toondb_stats.argtypes = [ctypes.c_void_p] lib.toondb_stats.restype = C_StorageStats + + # Per-Table Index Policy API + # toondb_set_table_index_policy(ptr, table_name, policy) -> c_int + # Sets index policy for a table: 0=WriteOptimized, 1=Balanced, 2=ScanOptimized, 3=AppendOnly + lib.toondb_set_table_index_policy.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.c_uint8 + ] + lib.toondb_set_table_index_policy.restype = ctypes.c_int + + # toondb_get_table_index_policy(ptr, table_name) -> u8 + # Gets index policy for a table. Returns 255 on error. + lib.toondb_get_table_index_policy.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p + ] + lib.toondb_get_table_index_policy.restype = ctypes.c_uint8 class Transaction: @@ -474,18 +524,50 @@ def scan_prefix(self, prefix: bytes): This method is safe for multi-tenant isolation - it will NEVER return keys from other tenants/prefixes. + Prefix Safety: + A minimum prefix length of 2 bytes is required to prevent + expensive full-database scans. Use scan_prefix_unchecked() if + you need unrestricted access for internal operations. + Args: - prefix: The prefix to match. All returned keys will start with this prefix. + prefix: The prefix to match (minimum 2 bytes). All returned keys + will start with this prefix. Yields: (key, value) tuples where key.startswith(prefix) is True. + Raises: + ValueError: If prefix is less than 2 bytes. + Example: # Get all user keys - safe for multi-tenant for key, value in txn.scan_prefix(b"tenant_a/"): print(f"{key}: {value}") # Will NEVER include keys like b"tenant_b/..." """ + MIN_PREFIX_LEN = 2 + if len(prefix) < MIN_PREFIX_LEN: + raise ValueError( + f"Prefix too short: {len(prefix)} bytes (minimum {MIN_PREFIX_LEN} required). " + f"Use scan_prefix_unchecked() for unrestricted prefix access." 
+ ) + return self.scan_prefix_unchecked(prefix) + + def scan_prefix_unchecked(self, prefix: bytes): + """ + Scan keys matching a prefix without length validation. + + Warning: + This method allows empty/short prefixes which can cause expensive + full-database scans. Use scan_prefix() unless you specifically need + unrestricted prefix access for internal operations. + + Args: + prefix: The prefix to match. Can be empty for full scan. + + Yields: + (key, value) tuples where key.startswith(prefix) is True. + """ if self._committed or self._aborted: raise TransactionError("Transaction already completed") @@ -631,19 +713,30 @@ def commit(self) -> int: Commit the transaction. Returns: - Commit timestamp. + Commit timestamp (HLC-backed, monotonically increasing). + This timestamp is suitable for: + - MVCC observability ("what commit did I read?") + - Replication and log shipping + - Agent audit trails + - Time-travel queries + - Deterministic replay + + Raises: + TransactionError: If commit fails (e.g., SSI conflict) """ if self._committed: raise TransactionError("Transaction already committed") if self._aborted: raise TransactionError("Transaction already aborted") - res = self._lib.toondb_commit(self._db._handle, self._handle) - if res != 0: + result = self._lib.toondb_commit(self._db._handle, self._handle) + if result.error_code != 0: + if result.error_code == -2: + raise TransactionError("SSI conflict: transaction aborted due to serialization failure") raise TransactionError("Failed to commit transaction") self._committed = True - return 0 # TODO: Return actual commit timestamp if exposed + return result.commit_ts def abort(self) -> None: """Abort the transaction.""" @@ -727,29 +820,91 @@ def open(cls, path: str, config: Optional[dict] = None) -> "Database": Args: path: Path to the database directory. 
config: Optional configuration dictionary with keys: - - create_if_missing (bool): Create if missing (default: True) - - wal_enabled (bool): Enable WAL (default: True) + - wal_enabled (bool): Enable WAL for durability (default: True) - sync_mode (str): 'full', 'normal', or 'off' (default: 'normal') - - memtable_size_bytes (int): Memtable size (default: 64MB) + - 'off': No fsync, ~10x faster but risk of data loss + - 'normal': Fsync at checkpoints, good balance (default) + - 'full': Fsync every commit, safest but slowest + - memtable_size_bytes (int): Memtable size before flush (default: 64MB) + - group_commit (bool): Enable group commit for throughput (default: True) + - index_policy (str): Default index policy for tables: + - 'write_optimized': O(1) insert, O(N) scan - for high-write + - 'balanced': O(1) amortized insert, O(log K) scan - default + - 'scan_optimized': O(log N) insert, O(log N + K) scan - for analytics + - 'append_only': O(1) insert, O(N) scan - for time-series Returns: Database instance. - Note: - The config parameter is currently accepted but not yet fully - implemented in v0.2.8. Future versions will apply these settings. + Example: + # Default configuration (good for most use cases) + db = Database.open("./my_database") + + # High-durability configuration + db = Database.open("./critical_data", config={ + "sync_mode": "full", + "wal_enabled": True, + }) + + # High-throughput configuration + db = Database.open("./logs", config={ + "sync_mode": "off", + "group_commit": True, + "index_policy": "write_optimized", + }) """ - if config is not None: - warnings.warn( - "Database.open() config parameter is not yet fully implemented in v0.2.8. 
" - "Configuration options will be supported in future versions.", - FutureWarning, - stacklevel=2 - ) - lib = _FFI.get_lib() path_bytes = path.encode("utf-8") - handle = lib.toondb_open(path_bytes) + + if config is not None: + # Build C config struct from Python dict + c_config = C_DatabaseConfig() + + # WAL enabled + if "wal_enabled" in config: + c_config.wal_enabled = bool(config["wal_enabled"]) + c_config.wal_enabled_set = True + + # Sync mode + if "sync_mode" in config: + mode = config["sync_mode"].lower() if isinstance(config["sync_mode"], str) else str(config["sync_mode"]) + if mode in ("off", "0"): + c_config.sync_mode = 0 + elif mode in ("normal", "1"): + c_config.sync_mode = 1 + elif mode in ("full", "2"): + c_config.sync_mode = 2 + else: + c_config.sync_mode = 1 # Default to normal + c_config.sync_mode_set = True + + # Memtable size + if "memtable_size_bytes" in config: + c_config.memtable_size_bytes = int(config["memtable_size_bytes"]) + + # Group commit + if "group_commit" in config: + c_config.group_commit = bool(config["group_commit"]) + c_config.group_commit_set = True + + # Index policy + if "index_policy" in config: + policy = config["index_policy"].lower() if isinstance(config["index_policy"], str) else str(config["index_policy"]) + if policy == "write_optimized": + c_config.default_index_policy = 0 + elif policy == "balanced": + c_config.default_index_policy = 1 + elif policy == "scan_optimized": + c_config.default_index_policy = 2 + elif policy == "append_only": + c_config.default_index_policy = 3 + else: + c_config.default_index_policy = 1 # Default to balanced + c_config.default_index_policy_set = True + + handle = lib.toondb_open_with_config(path_bytes, c_config) + else: + handle = lib.toondb_open(path_bytes) if not handle: raise DatabaseError(f"Failed to open database at {path}") @@ -881,12 +1036,20 @@ def scan_prefix(self, prefix: bytes): which operates on an arbitrary range, scan_prefix() guarantees that only keys starting with the given 
def scan_prefix_unchecked(self, prefix: bytes):
    """
    Scan keys matching a prefix, skipping the minimum-length check
    (runs inside a fresh auto-commit transaction).

    Warning:
        Unlike scan_prefix(), no minimum prefix length is enforced, so an
        empty or very short prefix can trigger an expensive full-database
        scan. Intended for internal callers (e.g. the graph overlay), not
        general application code.

    Args:
        prefix: The prefix to match; may be empty for a full scan.

    Yields:
        (key, value) tuples where key.startswith(prefix) is True.
    """
    with self.transaction() as txn:
        yield from txn.scan_prefix_unchecked(prefix)
# =========================================================================
# Per-Table Index Policy API (members of the Database class)
# =========================================================================

# Index policy constants
INDEX_WRITE_OPTIMIZED = 0
INDEX_BALANCED = 1
INDEX_SCAN_OPTIMIZED = 2
INDEX_APPEND_ONLY = 3

# value -> canonical name
_POLICY_NAMES = {
    INDEX_WRITE_OPTIMIZED: "write_optimized",
    INDEX_BALANCED: "balanced",
    INDEX_SCAN_OPTIMIZED: "scan_optimized",
    INDEX_APPEND_ONLY: "append_only",
}

# accepted name/alias -> value
_POLICY_VALUES = {
    "write_optimized": INDEX_WRITE_OPTIMIZED,
    "write": INDEX_WRITE_OPTIMIZED,
    "balanced": INDEX_BALANCED,
    "default": INDEX_BALANCED,
    "scan_optimized": INDEX_SCAN_OPTIMIZED,
    "scan": INDEX_SCAN_OPTIMIZED,
    "append_only": INDEX_APPEND_ONLY,
    "append": INDEX_APPEND_ONLY,
}

def set_table_index_policy(self, table: str, policy: Union[int, str]) -> None:
    """
    Set the index policy for a specific table.

    Policies trade write speed against scan speed:

    - 'write_optimized' (0): O(1) writes, O(N) scans — write-heavy tables.
    - 'balanced' (1): O(1) amortized writes, O(output + log K) scans —
      mixed OLTP workloads; the default.
    - 'scan_optimized' (2): O(log N) writes, O(log N + K) scans —
      analytics tables with frequent range queries.
    - 'append_only' (3): O(1) writes, O(N) forward-only scans —
      naturally ordered time-series logs.

    Args:
        table: Table name (used as the key-grouping prefix).
        policy: Policy name (str, aliases accepted) or raw value (int).

    Raises:
        ValueError: If the policy name/value is not recognized.
        DatabaseError: If the FFI call fails.

    Example:
        db.set_table_index_policy("sessions", "write_optimized")
        db.set_table_index_policy("events", "scan_optimized")
    """
    self._check_open()

    # Normalize the policy argument to its integer value.
    if isinstance(policy, str):
        policy_value = self._POLICY_VALUES.get(policy.lower())
        if policy_value is None:
            raise ValueError(
                f"Invalid policy '{policy}'. Valid policies: "
                f"{list(self._POLICY_VALUES.keys())}"
            )
    else:
        policy_value = int(policy)
        if policy_value not in self._POLICY_NAMES:
            raise ValueError(
                f"Invalid policy value {policy_value}. Valid values: 0-3"
            )

    result = self._lib.toondb_set_table_index_policy(
        self._handle, table.encode("utf-8"), policy_value
    )
    # FFI error codes: -1 generic failure, -2 policy rejected server-side.
    if result == -1:
        raise DatabaseError("Failed to set table index policy")
    if result == -2:
        raise ValueError(f"Invalid policy value: {policy_value}")

def get_table_index_policy(self, table: str) -> str:
    """
    Get the index policy for a specific table.

    Args:
        table: Table name.

    Returns:
        Policy name: 'write_optimized', 'balanced', 'scan_optimized', or
        'append_only' (unknown values fall back to 'balanced').

    Raises:
        DatabaseError: If the FFI call fails (sentinel value 255).
    """
    self._check_open()

    policy_value = self._lib.toondb_get_table_index_policy(
        self._handle, table.encode("utf-8")
    )
    if policy_value == 255:  # FFI failure sentinel
        raise DatabaseError("Failed to get table index policy")
    return self._POLICY_NAMES.get(policy_value, "balanced")
# Copyright 2025 Sushanth (https://github.com/sushanthpy)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Semi-GraphDB overlay for agent memory.

A lightweight graph layer on top of ToonDB's KV storage for modeling agent
memory relationships. This is NOT a full graph database — it is optimized
for typical agent-memory patterns:

- entity-to-entity relationships (user <-> conversation <-> message)
- causal chains (action1 -> action2 -> action3)
- reference graphs (document <- citation <- quote)

Storage model
-------------
Nodes: _graph/{namespace}/nodes/{node_id}                      -> {type, properties}
Edges: _graph/{namespace}/edges/{from_id}/{edge_type}/{to_id}  -> {properties}
Index: _graph/{namespace}/index/{edge_type}/{to_id}/{from_id}  -> from_id (reverse lookup)

Performance
-----------
- add/get node, add edge: O(1)
- outgoing edges: O(degree); incoming edges: O(degree) via reverse index
- BFS/DFS traversal: O(V + E) over the reachable subgraph

Example
-------
    from toondb import Database
    from toondb.graph import GraphOverlay

    db = Database.open("./agent_memory")
    graph = GraphOverlay(db, namespace="agent_001")

    graph.add_node("user_1", "User", {"name": "Alice"})
    graph.add_node("conv_1", "Conversation", {"title": "Planning Session"})
    graph.add_edge("user_1", "STARTED", "conv_1")

    graph.get_edges("user_1", "STARTED")
    graph.bfs("user_1", max_depth=2)
"""

import json
# Hoisted here from two function-local imports (see _traverse, shortest_path).
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple


class TraversalOrder(Enum):
    """Graph traversal order."""
    BFS = "bfs"  # breadth-first search
    DFS = "dfs"  # depth-first search


@dataclass
class GraphNode:
    """A node in the graph: an id, a type label, and free-form properties."""
    id: str
    type: str
    properties: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Serialize to a plain dict (for JSON storage)."""
        return {
            "id": self.id,
            "type": self.type,
            "properties": self.properties,
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "GraphNode":
        """Rebuild a node from its stored dict form."""
        return cls(
            id=data["id"],
            type=data["type"],
            properties=data.get("properties", {}),
        )


@dataclass
class GraphEdge:
    """A directed, typed edge in the graph with optional properties."""
    from_id: str
    edge_type: str
    to_id: str
    properties: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Serialize to a plain dict (for JSON storage)."""
        return {
            "from_id": self.from_id,
            "edge_type": self.edge_type,
            "to_id": self.to_id,
            "properties": self.properties,
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "GraphEdge":
        """Rebuild an edge from its stored dict form."""
        return cls(
            from_id=data["from_id"],
            edge_type=data["edge_type"],
            to_id=data["to_id"],
            properties=data.get("properties", {}),
        )


class GraphOverlay:
    """
    Lightweight graph overlay on ToonDB.

    Provides graph operations for agent memory without a full graph
    database, persisting nodes and edges as JSON values in the underlying
    KV store with O(1) node/edge operations.
    """

    # Root key prefix for all graph data.
    PREFIX = "_graph"

    def __init__(self, db, namespace: str = "default"):
        """
        Initialize the graph overlay.

        Args:
            db: ToonDB Database instance (must support put/get/delete and
                scan_prefix_unchecked).
            namespace: Namespace for graph isolation (e.g. an agent id).
        """
        self._db = db
        self._namespace = namespace
        # NOTE: node/edge ids are embedded in '/'-separated keys; ids that
        # themselves contain '/' would corrupt key parsing — avoid them.

    # -- key builders ------------------------------------------------------

    def _node_key(self, node_id: str) -> bytes:
        """Storage key for a node."""
        return f"{self.PREFIX}/{self._namespace}/nodes/{node_id}".encode()

    def _edge_key(self, from_id: str, edge_type: str, to_id: str) -> bytes:
        """Storage key for an edge."""
        return f"{self.PREFIX}/{self._namespace}/edges/{from_id}/{edge_type}/{to_id}".encode()

    def _edge_prefix(self, from_id: str, edge_type: Optional[str] = None) -> bytes:
        """Scan prefix for outgoing edges (optionally narrowed to one type)."""
        if edge_type:
            return f"{self.PREFIX}/{self._namespace}/edges/{from_id}/{edge_type}/".encode()
        return f"{self.PREFIX}/{self._namespace}/edges/{from_id}/".encode()

    def _reverse_index_key(self, edge_type: str, to_id: str, from_id: str) -> bytes:
        """Storage key for the reverse (incoming-edge) index entry."""
        return f"{self.PREFIX}/{self._namespace}/index/{edge_type}/{to_id}/{from_id}".encode()

    def _reverse_index_prefix(self, edge_type: str, to_id: str) -> bytes:
        """Scan prefix for reverse-index lookups."""
        return f"{self.PREFIX}/{self._namespace}/index/{edge_type}/{to_id}/".encode()

    # =========================================================================
    # Node Operations
    # =========================================================================

    def add_node(
        self,
        node_id: str,
        node_type: str,
        properties: Optional[Dict[str, Any]] = None,
    ) -> GraphNode:
        """
        Add (or overwrite) a node.

        Args:
            node_id: Unique node identifier.
            node_type: Node type label (e.g. "User", "Message", "Tool").
            properties: Optional node properties.

        Returns:
            The created GraphNode.
        """
        node = GraphNode(id=node_id, type=node_type, properties=properties or {})
        self._db.put(self._node_key(node_id), json.dumps(node.to_dict()).encode())
        return node

    def get_node(self, node_id: str) -> Optional[GraphNode]:
        """Return the node with ``node_id``, or None if absent."""
        data = self._db.get(self._node_key(node_id))
        if data is None:
            return None
        return GraphNode.from_dict(json.loads(data.decode()))

    def update_node(
        self,
        node_id: str,
        properties: Optional[Dict[str, Any]] = None,
        node_type: Optional[str] = None,
    ) -> Optional[GraphNode]:
        """
        Update a node's properties and/or type.

        Args:
            node_id: Node identifier.
            properties: Properties to merge into the existing ones
                (None or empty dict: leave unchanged).
            node_type: New type label (None: keep existing).

        Returns:
            The updated GraphNode, or None if the node does not exist.
        """
        node = self.get_node(node_id)
        if node is None:
            return None

        if properties:
            node.properties.update(properties)
        if node_type:
            node.type = node_type

        self._db.put(self._node_key(node_id), json.dumps(node.to_dict()).encode())
        return node

    def delete_node(self, node_id: str, cascade: bool = False) -> bool:
        """
        Delete a node.

        Args:
            node_id: Node identifier.
            cascade: If True, also delete all edges touching the node.

        Returns:
            True if the node existed and was deleted, False otherwise.
        """
        if self.get_node(node_id) is None:
            return False

        if cascade:
            # Outgoing edges first, then incoming (via reverse index).
            for edge in self.get_edges(node_id):
                self.delete_edge(node_id, edge.edge_type, edge.to_id)
            for edge in self.get_incoming_edges(node_id):
                self.delete_edge(edge.from_id, edge.edge_type, node_id)

        self._db.delete(self._node_key(node_id))
        return True

    def node_exists(self, node_id: str) -> bool:
        """Return True if a node with ``node_id`` is stored."""
        return self._db.get(self._node_key(node_id)) is not None

    # =========================================================================
    # Edge Operations
    # =========================================================================

    def add_edge(
        self,
        from_id: str,
        edge_type: str,
        to_id: str,
        properties: Optional[Dict[str, Any]] = None,
    ) -> GraphEdge:
        """
        Add a directed edge (endpoints are NOT required to exist as nodes).

        Args:
            from_id: Source node id.
            edge_type: Edge type label (e.g. "SENT", "REFERENCES").
            to_id: Target node id.
            properties: Optional edge properties.

        Returns:
            The created GraphEdge.
        """
        edge = GraphEdge(
            from_id=from_id,
            edge_type=edge_type,
            to_id=to_id,
            properties=properties or {},
        )
        # Forward record holds the edge payload; the reverse-index entry
        # only stores from_id for incoming-edge lookups.
        self._db.put(
            self._edge_key(from_id, edge_type, to_id),
            json.dumps(edge.to_dict()).encode(),
        )
        self._db.put(
            self._reverse_index_key(edge_type, to_id, from_id),
            from_id.encode(),
        )
        return edge

    def get_edge(
        self,
        from_id: str,
        edge_type: str,
        to_id: str,
    ) -> Optional[GraphEdge]:
        """Return the edge (from_id)-[edge_type]->(to_id), or None."""
        data = self._db.get(self._edge_key(from_id, edge_type, to_id))
        if data is None:
            return None
        return GraphEdge.from_dict(json.loads(data.decode()))

    def get_edges(
        self,
        from_id: str,
        edge_type: Optional[str] = None,
    ) -> List[GraphEdge]:
        """
        Return all outgoing edges of ``from_id`` (optionally one type only).
        O(out-degree) via a prefix scan.
        """
        prefix = self._edge_prefix(from_id, edge_type)
        return [
            GraphEdge.from_dict(json.loads(value.decode()))
            for _, value in self._db.scan_prefix_unchecked(prefix)
        ]

    def get_incoming_edges(
        self,
        to_id: str,
        edge_type: Optional[str] = None,
    ) -> List[GraphEdge]:
        """
        Return all incoming edges of ``to_id``.

        With ``edge_type`` given this is an O(in-degree) reverse-index
        lookup; without it, ALL reverse-index entries in the namespace are
        scanned and filtered (slower, but correct).
        """
        edges: List[GraphEdge] = []

        if edge_type:
            prefix = self._reverse_index_prefix(edge_type, to_id)
            for _key, value in self._db.scan_prefix_unchecked(prefix):
                edge = self.get_edge(value.decode(), edge_type, to_id)
                if edge:
                    edges.append(edge)
        else:
            # Key layout: _graph/{ns}/index/{edge_type}/{to_id}/{from_id}
            # so parts[3] is the edge type and parts[4] the target id.
            index_prefix = f"{self.PREFIX}/{self._namespace}/index/".encode()
            for key, value in self._db.scan_prefix_unchecked(index_prefix):
                parts = key.decode().split("/")
                if len(parts) >= 6 and parts[4] == to_id:
                    edge = self.get_edge(value.decode(), parts[3], to_id)
                    if edge:
                        edges.append(edge)

        return edges

    def delete_edge(
        self,
        from_id: str,
        edge_type: str,
        to_id: str,
    ) -> bool:
        """
        Delete an edge and its reverse-index entry.

        Returns:
            True if the edge existed and was deleted, False otherwise.
        """
        if self.get_edge(from_id, edge_type, to_id) is None:
            return False

        self._db.delete(self._edge_key(from_id, edge_type, to_id))
        self._db.delete(self._reverse_index_key(edge_type, to_id, from_id))
        return True

    # =========================================================================
    # Traversal Operations
    # =========================================================================

    def bfs(
        self,
        start_id: str,
        max_depth: int = 10,
        edge_types: Optional[List[str]] = None,
        node_types: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Breadth-first search over outgoing edges from ``start_id``.

        Args:
            start_id: Starting node id.
            max_depth: Maximum traversal depth (edges from the start).
            edge_types: If given, only follow these edge types.
            node_types: If given, only yield (and expand) nodes of these
                types.

        Returns:
            Reachable node ids in BFS order (start first).
        """
        return list(self._traverse(
            start_id,
            max_depth=max_depth,
            edge_types=edge_types,
            node_types=node_types,
            order=TraversalOrder.BFS,
        ))

    def dfs(
        self,
        start_id: str,
        max_depth: int = 10,
        edge_types: Optional[List[str]] = None,
        node_types: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Depth-first search over outgoing edges from ``start_id``.
        Same filters and semantics as bfs(), different visit order.
        """
        return list(self._traverse(
            start_id,
            max_depth=max_depth,
            edge_types=edge_types,
            node_types=node_types,
            order=TraversalOrder.DFS,
        ))

    def _traverse(
        self,
        start_id: str,
        max_depth: int,
        edge_types: Optional[List[str]],
        node_types: Optional[List[str]],
        order: TraversalOrder,
    ) -> Iterator[str]:
        """Shared BFS/DFS engine; BFS pops the queue head, DFS the tail."""
        visited: Set[str] = set()

        if order == TraversalOrder.BFS:
            frontier: Any = deque([(start_id, 0)])
        else:
            frontier = [(start_id, 0)]

        while frontier:
            if order == TraversalOrder.BFS:
                node_id, depth = frontier.popleft()
            else:
                node_id, depth = frontier.pop()

            if node_id in visited:
                continue
            visited.add(node_id)

            # A node failing the type filter is neither yielded nor
            # expanded (traversal does not continue through it).
            if node_types:
                node = self.get_node(node_id)
                if node is None or node.type not in node_types:
                    continue

            yield node_id

            if depth >= max_depth:
                continue

            for edge in self.get_edges(node_id):
                if edge_types and edge.edge_type not in edge_types:
                    continue
                if edge.to_id not in visited:
                    frontier.append((edge.to_id, depth + 1))

    def shortest_path(
        self,
        from_id: str,
        to_id: str,
        max_depth: int = 10,
        edge_types: Optional[List[str]] = None,
    ) -> Optional[List[str]]:
        """
        Shortest directed path from ``from_id`` to ``to_id`` via BFS.

        Args:
            from_id: Source node id.
            to_id: Target node id.
            max_depth: Maximum path length in edges.
            edge_types: If given, only follow these edge types.

        Returns:
            Node ids along the path (inclusive of both ends), or None if
            ``to_id`` is unreachable within ``max_depth``.
        """
        if from_id == to_id:
            return [from_id]

        visited: Set[str] = {from_id}
        parent: Dict[str, str] = {}
        frontier: Any = deque([(from_id, 0)])

        while frontier:
            node_id, depth = frontier.popleft()
            if depth >= max_depth:
                continue

            for edge in self.get_edges(node_id):
                if edge_types and edge.edge_type not in edge_types:
                    continue

                next_id = edge.to_id
                if next_id in visited:
                    continue

                visited.add(next_id)
                parent[next_id] = node_id

                if next_id == to_id:
                    # Walk parent links back to the source.
                    path = [to_id]
                    current = to_id
                    while current in parent:
                        current = parent[current]
                        path.append(current)
                    return list(reversed(path))

                frontier.append((next_id, depth + 1))

        return None  # no path found

    # =========================================================================
    # Query Operations
    # =========================================================================

    def get_neighbors(
        self,
        node_id: str,
        edge_types: Optional[List[str]] = None,
        direction: str = "outgoing",
    ) -> List[Tuple[str, GraphEdge]]:
        """
        Return neighboring node ids with their connecting edges.

        Args:
            node_id: Node id.
            edge_types: If given, only include these edge types.
            direction: 'outgoing', 'incoming', or 'both'.

        Returns:
            List of (neighbor_id, edge) tuples; with 'both', outgoing
            neighbors come first.
        """
        neighbors: List[Tuple[str, GraphEdge]] = []

        if direction in ("outgoing", "both"):
            for edge in self.get_edges(node_id):
                if edge_types and edge.edge_type not in edge_types:
                    continue
                neighbors.append((edge.to_id, edge))

        if direction in ("incoming", "both"):
            for edge in self.get_incoming_edges(node_id):
                if edge_types and edge.edge_type not in edge_types:
                    continue
                neighbors.append((edge.from_id, edge))

        return neighbors

    def get_nodes_by_type(
        self,
        node_type: str,
        limit: int = 100,
    ) -> List[GraphNode]:
        """
        Return up to ``limit`` nodes of a given type.

        Note: this scans every node in the namespace — use sparingly on
        large graphs.
        """
        prefix = f"{self.PREFIX}/{self._namespace}/nodes/".encode()
        nodes: List[GraphNode] = []

        for _, value in self._db.scan_prefix_unchecked(prefix):
            node = GraphNode.from_dict(json.loads(value.decode()))
            if node.type == node_type:
                nodes.append(node)
                if len(nodes) >= limit:
                    break

        return nodes
+ +""" +Policy & Safety Hooks for Agent Operations. + +Provides a trigger system for enforcing policies on agent actions: + +- Pre-write validation (block dangerous operations) +- Post-read filtering (redact sensitive data) +- Rate limiting (prevent runaway agents) +- Audit logging (track all operations) + +This is designed for AI agent safety and compliance scenarios where you need +to enforce guardrails on what agents can read/write. + +Example: +-------- + from toondb import Database + from toondb.policy import PolicyEngine, Policy, PolicyAction, PolicyTrigger + + db = Database.open("./agent_data") + policy = PolicyEngine(db) + + # Block writes to system keys + @policy.before_write("system/*") + def block_system_writes(key, value, context): + if context.get("agent_id"): + return PolicyAction.DENY + return PolicyAction.ALLOW + + # Redact sensitive data on read + @policy.after_read("users/*/email") + def redact_emails(key, value, context): + if context.get("agent_id") and not context.get("has_pii_access"): + return "[REDACTED]".encode() + return value + + # Rate limit writes per agent + policy.add_rate_limit("writes", max_per_minute=100, scope="agent_id") + + # Use policy-wrapped operations + policy.put(b"users/alice/name", b"Alice", context={"agent_id": "agent_001"}) + value = policy.get(b"users/alice/email", context={"agent_id": "agent_001"}) +""" + +import time +import json +import re +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union +from threading import Lock + + +class PolicyAction(Enum): + """Action to take when a policy is triggered.""" + ALLOW = "allow" # Allow the operation + DENY = "deny" # Block the operation + MODIFY = "modify" # Allow with modifications + LOG = "log" # Allow but log the operation + RATE_LIMIT = "rate_limit" # Apply rate limiting + + +class PolicyTrigger(Enum): + 
"""When the policy is triggered.""" + BEFORE_READ = "before_read" + AFTER_READ = "after_read" + BEFORE_WRITE = "before_write" + AFTER_WRITE = "after_write" + BEFORE_DELETE = "before_delete" + AFTER_DELETE = "after_delete" + + +@dataclass +class PolicyResult: + """Result of a policy evaluation.""" + action: PolicyAction + modified_value: Optional[bytes] = None + reason: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class PolicyContext: + """Context passed to policy handlers.""" + operation: str # "read", "write", "delete" + key: bytes + value: Optional[bytes] = None + agent_id: Optional[str] = None + session_id: Optional[str] = None + timestamp: float = field(default_factory=time.time) + custom: Dict[str, Any] = field(default_factory=dict) + + def get(self, key: str, default: Any = None) -> Any: + """Get a context value.""" + if key == "agent_id": + return self.agent_id + if key == "session_id": + return self.session_id + return self.custom.get(key, default) + + +class PolicyHandler(ABC): + """Abstract base class for policy handlers.""" + + @abstractmethod + def evaluate(self, context: PolicyContext) -> PolicyResult: + """Evaluate the policy for the given context.""" + pass + + +class PatternPolicy(PolicyHandler): + """Policy that applies to keys matching a pattern.""" + + def __init__( + self, + pattern: str, + trigger: PolicyTrigger, + handler: Callable[[bytes, Optional[bytes], Dict[str, Any]], Union[PolicyAction, Tuple[PolicyAction, bytes]]], + ): + """ + Create a pattern-based policy. 
+ + Args: + pattern: Key pattern (glob-style: users/*/email, system/*) + trigger: When to trigger + handler: Function(key, value, context) -> PolicyAction or (PolicyAction, modified_value) + """ + self.pattern = pattern + self.trigger = trigger + self.handler = handler + self._regex = self._pattern_to_regex(pattern) + + def _pattern_to_regex(self, pattern: str) -> Pattern: + """Convert glob pattern to regex.""" + regex = pattern.replace(".", r"\.").replace("*", r"[^/]*").replace("**", r".*") + return re.compile(f"^{regex}$") + + def matches(self, key: bytes) -> bool: + """Check if key matches the pattern.""" + try: + key_str = key.decode("utf-8") + return bool(self._regex.match(key_str)) + except UnicodeDecodeError: + return False + + def evaluate(self, context: PolicyContext) -> PolicyResult: + """Evaluate the policy.""" + if not self.matches(context.key): + return PolicyResult(action=PolicyAction.ALLOW) + + result = self.handler(context.key, context.value, context.custom) + + if isinstance(result, PolicyAction): + return PolicyResult(action=result) + elif isinstance(result, tuple): + action, value = result + return PolicyResult(action=action, modified_value=value) + elif isinstance(result, bytes): + return PolicyResult(action=PolicyAction.MODIFY, modified_value=result) + else: + return PolicyResult(action=PolicyAction.ALLOW) + + +class RateLimiter: + """Token bucket rate limiter.""" + + def __init__(self, max_per_minute: int): + self.max_per_minute = max_per_minute + self.tokens = max_per_minute + self.last_refill = time.time() + self._lock = Lock() + + def try_acquire(self) -> bool: + """Try to acquire a token. 
class RateLimiter:
    """Continuously refilling token bucket.

    Starts full with ``max_per_minute`` tokens; whole tokens accrue at a
    rate of max_per_minute/60 per second, capped at the bucket size.
    Acquisition is guarded by an internal lock.
    """

    def __init__(self, max_per_minute: int):
        self.max_per_minute = max_per_minute
        self.tokens = max_per_minute
        self.last_refill = time.time()
        self._lock = Lock()

    def try_acquire(self) -> bool:
        """Consume one token; return True if one was available."""
        with self._lock:
            now = time.time()
            elapsed = now - self.last_refill
            # Integer refill: last_refill only advances once at least one
            # whole token has accrued, so fractional credit is never lost.
            refill = int(elapsed * self.max_per_minute / 60)
            if refill > 0:
                self.tokens = min(self.max_per_minute, self.tokens + refill)
                self.last_refill = now
            if self.tokens > 0:
                self.tokens -= 1
                return True
            return False

    def remaining(self) -> int:
        """Tokens currently left (unlocked read; may be slightly stale)."""
        return self.tokens


class PolicyEngine:
    """
    Policy engine enforcing safety rules on database operations.

    Wraps a ToonDB Database instance; all policy state (handlers, rate
    limiters, audit log) is kept in memory on this object.
    """

    def __init__(self, db):
        """
        Args:
            db: ToonDB Database instance to wrap.
        """
        self._db = db
        self._policies: Dict[PolicyTrigger, List[PolicyHandler]] = defaultdict(list)
        self._rate_limiters: Dict[str, Dict[str, RateLimiter]] = defaultdict(dict)
        self._rate_limit_configs: List[Dict] = []
        self._audit_log: List[Dict] = []
        self._audit_enabled = False
        self._max_audit_entries = 10000
        self._lock = Lock()

    # =========================================================================
    # Decorator API
    # =========================================================================

    def before_write(self, pattern: str):
        """
        Decorator registering a pre-write policy for keys matching ``pattern``.

        Example:
            @policy.before_write("system/*")
            def block_system_writes(key, value, context):
                return PolicyAction.DENY
        """
        def decorator(handler):
            rule = PatternPolicy(pattern, PolicyTrigger.BEFORE_WRITE, handler)
            self._policies[PolicyTrigger.BEFORE_WRITE].append(rule)
            return handler
        return decorator
# --- PolicyEngine methods (continued): decorator registration + limits ---
# NOTE(review): these are methods of the PolicyEngine class in the
# original file; behavior is unchanged.

def after_write(self, pattern: str):
    """Decorator registering a post-write policy (e.g. audit side effects)."""
    def decorator(handler):
        self._policies[PolicyTrigger.AFTER_WRITE].append(
            PatternPolicy(pattern, PolicyTrigger.AFTER_WRITE, handler)
        )
        return handler
    return decorator


def before_read(self, pattern: str):
    """Decorator registering a pre-read policy (e.g. access control)."""
    def decorator(handler):
        self._policies[PolicyTrigger.BEFORE_READ].append(
            PatternPolicy(pattern, PolicyTrigger.BEFORE_READ, handler)
        )
        return handler
    return decorator


def after_read(self, pattern: str):
    """
    Decorator registering a post-read policy (e.g. redaction).

    Example:
        @policy.after_read("users/*/email")
        def redact_emails(key, value, context):
            if context.get("redact_pii"):
                return b"[REDACTED]"
            return value
    """
    def decorator(handler):
        self._policies[PolicyTrigger.AFTER_READ].append(
            PatternPolicy(pattern, PolicyTrigger.AFTER_READ, handler)
        )
        return handler
    return decorator


def before_delete(self, pattern: str):
    """Decorator registering a pre-delete policy."""
    def decorator(handler):
        self._policies[PolicyTrigger.BEFORE_DELETE].append(
            PatternPolicy(pattern, PolicyTrigger.BEFORE_DELETE, handler)
        )
        return handler
    return decorator


# =========================================================================
# Rate Limiting API
# =========================================================================

def add_rate_limit(self, operation: str, max_per_minute: int, scope: str = "global"):
    """
    Record a rate-limit rule; token buckets are created lazily on first
    check (see _check_rate_limit).

    Args:
        operation: "read", "write", "delete", or "all".
        max_per_minute: Maximum operations per minute within the scope.
        scope: "global", "agent_id", "session_id", or any custom
            context key (one bucket per distinct scope value).

    Example:
        policy.add_rate_limit("write", max_per_minute=1000)
        policy.add_rate_limit("read", max_per_minute=100, scope="agent_id")
    """
    self._rate_limit_configs.append({
        "operation": operation,
        "max_per_minute": max_per_minute,
        "scope": scope,
    })
+ + Args: + operation: Which operation to limit + max_per_minute: Maximum operations per minute + scope: Scope for the limit (global, per-agent, per-session) + + Example: + # Global write limit + policy.add_rate_limit("write", max_per_minute=1000) + + # Per-agent read limit + policy.add_rate_limit("read", max_per_minute=100, scope="agent_id") + """ + self._rate_limit_configs.append({ + "operation": operation, + "max_per_minute": max_per_minute, + "scope": scope, + }) + + def _check_rate_limit(self, operation: str, context: PolicyContext) -> bool: + """Check if operation is allowed under rate limits.""" + for config in self._rate_limit_configs: + if config["operation"] not in (operation, "all"): + continue + + # Determine scope key + scope = config["scope"] + if scope == "global": + scope_key = "global" + elif scope == "agent_id": + scope_key = context.agent_id or "unknown" + elif scope == "session_id": + scope_key = context.session_id or "unknown" + else: + scope_key = context.get(scope, "unknown") + + # Get or create rate limiter + limiter_key = f"{config['operation']}:{scope}" + if scope_key not in self._rate_limiters[limiter_key]: + self._rate_limiters[limiter_key][scope_key] = RateLimiter( + config["max_per_minute"] + ) + + if not self._rate_limiters[limiter_key][scope_key].try_acquire(): + return False + + return True + + # ========================================================================= + # Audit API + # ========================================================================= + + def enable_audit(self, max_entries: int = 10000): + """Enable audit logging.""" + self._audit_enabled = True + self._max_audit_entries = max_entries + + def disable_audit(self): + """Disable audit logging.""" + self._audit_enabled = False + + def get_audit_log( + self, + limit: int = 100, + agent_id: Optional[str] = None, + operation: Optional[str] = None, + ) -> List[Dict]: + """ + Get audit log entries. 
+ + Args: + limit: Maximum entries to return + agent_id: Filter by agent ID + operation: Filter by operation type + + Returns: + List of audit log entries + """ + with self._lock: + entries = self._audit_log[-limit:] + + if agent_id: + entries = [e for e in entries if e.get("agent_id") == agent_id] + if operation: + entries = [e for e in entries if e.get("operation") == operation] + + return entries + + def _audit(self, operation: str, key: bytes, context: PolicyContext, result: str): + """Add an audit log entry.""" + if not self._audit_enabled: + return + + with self._lock: + entry = { + "timestamp": time.time(), + "operation": operation, + "key": key.decode("utf-8", errors="replace"), + "agent_id": context.agent_id, + "session_id": context.session_id, + "result": result, + } + self._audit_log.append(entry) + + # Trim if too many entries + if len(self._audit_log) > self._max_audit_entries: + self._audit_log = self._audit_log[-self._max_audit_entries:] + + # ========================================================================= + # Evaluation Logic + # ========================================================================= + + def _evaluate_policies( + self, + trigger: PolicyTrigger, + context: PolicyContext, + ) -> PolicyResult: + """Evaluate all policies for a trigger.""" + for policy in self._policies[trigger]: + result = policy.evaluate(context) + if result.action in (PolicyAction.DENY, PolicyAction.MODIFY): + return result + return PolicyResult(action=PolicyAction.ALLOW) + + # ========================================================================= + # Wrapped Database Operations + # ========================================================================= + + def put( + self, + key: bytes, + value: bytes, + context: Optional[Dict[str, Any]] = None, + ) -> bool: + """ + Put a value with policy enforcement. + + Args: + key: Key bytes + value: Value bytes + context: Policy context (agent_id, session_id, etc.) 
+ + Returns: + True if write succeeded, False if blocked by policy + + Raises: + PolicyViolation: If policy blocks the write + """ + ctx = self._make_context("write", key, value, context) + + # Check rate limits + if not self._check_rate_limit("write", ctx): + self._audit("write", key, ctx, "rate_limited") + raise PolicyViolation("Rate limit exceeded") + + # Evaluate before-write policies + result = self._evaluate_policies(PolicyTrigger.BEFORE_WRITE, ctx) + if result.action == PolicyAction.DENY: + self._audit("write", key, ctx, "denied") + raise PolicyViolation(result.reason or "Write blocked by policy") + + # Apply modifications if any + write_value = result.modified_value if result.action == PolicyAction.MODIFY else value + + # Perform the write + self._db.put(key, write_value) + + # Evaluate after-write policies + ctx.value = write_value + self._evaluate_policies(PolicyTrigger.AFTER_WRITE, ctx) + + self._audit("write", key, ctx, "allowed") + return True + + def get( + self, + key: bytes, + context: Optional[Dict[str, Any]] = None, + ) -> Optional[bytes]: + """ + Get a value with policy enforcement. 
+ + Args: + key: Key bytes + context: Policy context + + Returns: + Value bytes (possibly modified by policy) or None + """ + ctx = self._make_context("read", key, None, context) + + # Check rate limits + if not self._check_rate_limit("read", ctx): + self._audit("read", key, ctx, "rate_limited") + raise PolicyViolation("Rate limit exceeded") + + # Evaluate before-read policies + result = self._evaluate_policies(PolicyTrigger.BEFORE_READ, ctx) + if result.action == PolicyAction.DENY: + self._audit("read", key, ctx, "denied") + raise PolicyViolation(result.reason or "Read blocked by policy") + + # Perform the read + value = self._db.get(key) + if value is None: + return None + + # Evaluate after-read policies + ctx.value = value + result = self._evaluate_policies(PolicyTrigger.AFTER_READ, ctx) + + if result.action == PolicyAction.MODIFY: + value = result.modified_value + elif result.action == PolicyAction.DENY: + self._audit("read", key, ctx, "redacted") + return None + + self._audit("read", key, ctx, "allowed") + return value + + def delete( + self, + key: bytes, + context: Optional[Dict[str, Any]] = None, + ) -> bool: + """ + Delete a value with policy enforcement. 
+ + Args: + key: Key bytes + context: Policy context + + Returns: + True if delete succeeded + """ + ctx = self._make_context("delete", key, None, context) + + # Check rate limits + if not self._check_rate_limit("delete", ctx): + self._audit("delete", key, ctx, "rate_limited") + raise PolicyViolation("Rate limit exceeded") + + # Evaluate before-delete policies + result = self._evaluate_policies(PolicyTrigger.BEFORE_DELETE, ctx) + if result.action == PolicyAction.DENY: + self._audit("delete", key, ctx, "denied") + raise PolicyViolation(result.reason or "Delete blocked by policy") + + # Perform the delete + self._db.delete(key) + + self._audit("delete", key, ctx, "allowed") + return True + + def _make_context( + self, + operation: str, + key: bytes, + value: Optional[bytes], + context: Optional[Dict[str, Any]], + ) -> PolicyContext: + """Create a policy context.""" + ctx = context or {} + return PolicyContext( + operation=operation, + key=key, + value=value, + agent_id=ctx.get("agent_id"), + session_id=ctx.get("session_id"), + custom=ctx, + ) + + +class PolicyViolation(Exception): + """Raised when a policy blocks an operation.""" + pass + + +# ============================================================================= +# Built-in Policy Helpers +# ============================================================================= + +def deny_all(key: bytes, value: Optional[bytes], context: Dict) -> PolicyAction: + """Policy that denies all matching operations.""" + return PolicyAction.DENY + + +def allow_all(key: bytes, value: Optional[bytes], context: Dict) -> PolicyAction: + """Policy that allows all matching operations.""" + return PolicyAction.ALLOW + + +def require_agent_id(key: bytes, value: Optional[bytes], context: Dict) -> PolicyAction: + """Policy that requires an agent_id in context.""" + if context.get("agent_id"): + return PolicyAction.ALLOW + return PolicyAction.DENY + + +def redact_value(replacement: bytes = b"[REDACTED]"): + """Policy factory that redacts 
values.""" + def handler(key: bytes, value: Optional[bytes], context: Dict) -> bytes: + return replacement + return handler + + +def log_and_allow(logger=None): + """Policy factory that logs operations but allows them.""" + def handler(key: bytes, value: Optional[bytes], context: Dict) -> PolicyAction: + msg = f"Operation on {key.decode('utf-8', errors='replace')}" + if logger: + logger.info(msg) + else: + print(msg) + return PolicyAction.ALLOW + return handler diff --git a/src/toondb/routing.py b/src/toondb/routing.py new file mode 100644 index 0000000..2402a1e --- /dev/null +++ b/src/toondb/routing.py @@ -0,0 +1,810 @@ +# Copyright 2025 Sushanth (https://github.com/sushanthpy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tool Routing Primitive for Multi-Agent Scenarios. + +Provides a first-class system for routing tool calls to agents based on: +- Agent capabilities +- Tool requirements +- Load balancing +- Agent availability + +This enables multi-agent architectures where specialized agents handle +different tool domains (e.g., code, search, database, email). 
+ +Example: +-------- + from toondb import Database + from toondb.routing import ( + ToolRouter, + AgentRegistry, + Tool, + ToolCategory, + RoutingStrategy, + ) + + db = Database.open("./agent_data") + registry = AgentRegistry(db) + router = ToolRouter(registry) + + # Register agents with capabilities + registry.register_agent( + agent_id="code_agent", + capabilities=[ToolCategory.CODE, ToolCategory.GIT], + endpoint="http://localhost:8001/invoke", + ) + registry.register_agent( + agent_id="search_agent", + capabilities=[ToolCategory.SEARCH, ToolCategory.WEB], + endpoint="http://localhost:8002/invoke", + ) + + # Register tools with categories + router.register_tool(Tool( + name="search_code", + description="Search codebase", + category=ToolCategory.CODE, + schema={"type": "object", "properties": {"query": {"type": "string"}}}, + )) + + # Route a tool call to the best agent + result = router.route( + tool_name="search_code", + args={"query": "authentication handler"}, + context={"session_id": "sess_001"}, + ) +""" + +import time +import json +import hashlib +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from threading import Lock +import random + + +class ToolCategory(Enum): + """Standard tool categories for routing.""" + CODE = "code" # Code analysis, generation, editing + SEARCH = "search" # Search operations (semantic, keyword) + DATABASE = "database" # Database operations (CRUD, queries) + WEB = "web" # Web browsing, API calls + FILE = "file" # File system operations + GIT = "git" # Git/VCS operations + SHELL = "shell" # Shell command execution + EMAIL = "email" # Email operations + CALENDAR = "calendar" # Calendar/scheduling + MEMORY = "memory" # Agent memory operations + VECTOR = "vector" # Vector search and embeddings + GRAPH = "graph" # Graph operations + CUSTOM = "custom" # User-defined 
category + + +class RoutingStrategy(Enum): + """How to select among multiple capable agents.""" + ROUND_ROBIN = "round_robin" # Cycle through agents + RANDOM = "random" # Random selection + LEAST_LOADED = "least_loaded" # Prefer less busy agents + STICKY = "sticky" # Keep same agent for session + PRIORITY = "priority" # Use agent priority scores + FASTEST = "fastest" # Prefer historically fastest + + +class AgentStatus(Enum): + """Agent availability status.""" + AVAILABLE = "available" + BUSY = "busy" + OFFLINE = "offline" + DEGRADED = "degraded" + + +@dataclass +class Tool: + """Definition of a routable tool.""" + name: str + description: str + category: ToolCategory + schema: Dict[str, Any] = field(default_factory=dict) + required_capabilities: List[ToolCategory] = field(default_factory=list) + timeout_seconds: float = 30.0 + retries: int = 1 + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class Agent: + """Definition of an agent that can handle tools.""" + agent_id: str + capabilities: List[ToolCategory] + endpoint: Optional[str] = None + handler: Optional[Callable] = None # Local function handler + priority: int = 100 # Higher = preferred + max_concurrent: int = 10 + metadata: Dict[str, Any] = field(default_factory=dict) + + # Runtime state + status: AgentStatus = AgentStatus.AVAILABLE + current_load: int = 0 + total_calls: int = 0 + total_latency_ms: float = 0.0 + last_success: Optional[float] = None + last_failure: Optional[float] = None + + +@dataclass +class RouteResult: + """Result of a tool routing decision.""" + agent_id: str + tool_name: str + result: Any + latency_ms: float + success: bool + error: Optional[str] = None + retries_used: int = 0 + + +@dataclass +class RoutingContext: + """Context for routing decisions.""" + session_id: Optional[str] = None + user_id: Optional[str] = None + priority: int = 100 + timeout_override: Optional[float] = None + preferred_agent: Optional[str] = None + excluded_agents: List[str] = 
field(default_factory=list) + custom: Dict[str, Any] = field(default_factory=dict) + + +class AgentRegistry: + """ + Registry of agents and their capabilities. + + Persists agent registrations to ToonDB for durability across restarts. + """ + + PREFIX = "/_routing/agents/" + + def __init__(self, db): + """ + Create an agent registry. + + Args: + db: ToonDB Database instance + """ + self._db = db + self._agents: Dict[str, Agent] = {} + self._lock = Lock() + self._load_agents() + + def _load_agents(self): + """Load agent registrations from database.""" + try: + results = self._db.scan_prefix(self.PREFIX.encode()) + for key, value in results: + try: + data = json.loads(value.decode()) + agent = Agent( + agent_id=data["agent_id"], + capabilities=[ToolCategory(c) for c in data["capabilities"]], + endpoint=data.get("endpoint"), + priority=data.get("priority", 100), + max_concurrent=data.get("max_concurrent", 10), + metadata=data.get("metadata", {}), + ) + self._agents[agent.agent_id] = agent + except (json.JSONDecodeError, KeyError): + pass + except Exception: + pass # No existing agents + + def register_agent( + self, + agent_id: str, + capabilities: List[Union[ToolCategory, str]], + endpoint: Optional[str] = None, + handler: Optional[Callable] = None, + priority: int = 100, + max_concurrent: int = 10, + metadata: Optional[Dict[str, Any]] = None, + ) -> Agent: + """ + Register an agent with capabilities. 
+ + Args: + agent_id: Unique agent identifier + capabilities: List of tool categories this agent handles + endpoint: HTTP endpoint for remote agents + handler: Local function handler for in-process agents + priority: Routing priority (higher = preferred) + max_concurrent: Max concurrent tool calls + metadata: Additional metadata + + Returns: + The registered Agent + """ + caps = [ + c if isinstance(c, ToolCategory) else ToolCategory(c) + for c in capabilities + ] + + agent = Agent( + agent_id=agent_id, + capabilities=caps, + endpoint=endpoint, + handler=handler, + priority=priority, + max_concurrent=max_concurrent, + metadata=metadata or {}, + ) + + with self._lock: + self._agents[agent_id] = agent + + # Persist to database (skip handler as it's not serializable) + data = { + "agent_id": agent_id, + "capabilities": [c.value for c in caps], + "endpoint": endpoint, + "priority": priority, + "max_concurrent": max_concurrent, + "metadata": metadata or {}, + } + key = f"{self.PREFIX}{agent_id}".encode() + self._db.put(key, json.dumps(data).encode()) + + return agent + + def unregister_agent(self, agent_id: str) -> bool: + """Remove an agent registration.""" + with self._lock: + if agent_id in self._agents: + del self._agents[agent_id] + key = f"{self.PREFIX}{agent_id}".encode() + self._db.delete(key) + return True + return False + + def get_agent(self, agent_id: str) -> Optional[Agent]: + """Get an agent by ID.""" + return self._agents.get(agent_id) + + def list_agents(self) -> List[Agent]: + """List all registered agents.""" + return list(self._agents.values()) + + def find_capable_agents( + self, + required: List[ToolCategory], + exclude: Optional[List[str]] = None, + ) -> List[Agent]: + """ + Find agents capable of handling the required categories. 
+ + Args: + required: Required tool categories + exclude: Agent IDs to exclude + + Returns: + List of capable agents + """ + exclude_set = set(exclude or []) + capable = [] + + for agent in self._agents.values(): + if agent.agent_id in exclude_set: + continue + if agent.status == AgentStatus.OFFLINE: + continue + + # Check if agent has all required capabilities + agent_caps = set(agent.capabilities) + if all(req in agent_caps for req in required): + capable.append(agent) + + return capable + + def update_agent_status(self, agent_id: str, status: AgentStatus): + """Update an agent's status.""" + if agent := self._agents.get(agent_id): + agent.status = status + + def record_call(self, agent_id: str, latency_ms: float, success: bool): + """Record a tool call result for an agent.""" + if agent := self._agents.get(agent_id): + agent.total_calls += 1 + agent.total_latency_ms += latency_ms + if success: + agent.last_success = time.time() + else: + agent.last_failure = time.time() + + +class ToolRouter: + """ + Routes tool calls to appropriate agents. + + Supports multiple routing strategies and automatic failover. + """ + + TOOLS_PREFIX = "/_routing/tools/" + SESSIONS_PREFIX = "/_routing/sessions/" + + def __init__( + self, + registry: AgentRegistry, + default_strategy: RoutingStrategy = RoutingStrategy.PRIORITY, + ): + """ + Create a tool router. 
+ + Args: + registry: Agent registry + default_strategy: Default routing strategy + """ + self._registry = registry + self._db = registry._db + self._default_strategy = default_strategy + self._tools: Dict[str, Tool] = {} + self._round_robin_idx: Dict[str, int] = defaultdict(int) + self._session_affinity: Dict[str, str] = {} # session_id -> agent_id + self._lock = Lock() + self._load_tools() + + def _load_tools(self): + """Load tool registrations from database.""" + try: + results = self._db.scan_prefix(self.TOOLS_PREFIX.encode()) + for key, value in results: + try: + data = json.loads(value.decode()) + tool = Tool( + name=data["name"], + description=data["description"], + category=ToolCategory(data["category"]), + schema=data.get("schema", {}), + required_capabilities=[ + ToolCategory(c) for c in data.get("required_capabilities", []) + ], + timeout_seconds=data.get("timeout_seconds", 30.0), + retries=data.get("retries", 1), + metadata=data.get("metadata", {}), + ) + self._tools[tool.name] = tool + except (json.JSONDecodeError, KeyError): + pass + except Exception: + pass + + def register_tool(self, tool: Tool) -> Tool: + """ + Register a tool for routing. 
+ + Args: + tool: Tool definition + + Returns: + The registered tool + """ + with self._lock: + self._tools[tool.name] = tool + + # Persist to database + data = { + "name": tool.name, + "description": tool.description, + "category": tool.category.value, + "schema": tool.schema, + "required_capabilities": [c.value for c in tool.required_capabilities], + "timeout_seconds": tool.timeout_seconds, + "retries": tool.retries, + "metadata": tool.metadata, + } + key = f"{self.TOOLS_PREFIX}{tool.name}".encode() + self._db.put(key, json.dumps(data).encode()) + + return tool + + def unregister_tool(self, name: str) -> bool: + """Remove a tool registration.""" + with self._lock: + if name in self._tools: + del self._tools[name] + key = f"{self.TOOLS_PREFIX}{name}".encode() + self._db.delete(key) + return True + return False + + def get_tool(self, name: str) -> Optional[Tool]: + """Get a tool by name.""" + return self._tools.get(name) + + def list_tools(self) -> List[Tool]: + """List all registered tools.""" + return list(self._tools.values()) + + def route( + self, + tool_name: str, + args: Dict[str, Any], + context: Optional[Union[Dict[str, Any], RoutingContext]] = None, + strategy: Optional[RoutingStrategy] = None, + ) -> RouteResult: + """ + Route a tool call to the best agent. 
+ + Args: + tool_name: Name of the tool to call + args: Tool arguments + context: Routing context + strategy: Override routing strategy + + Returns: + RouteResult with the call outcome + """ + tool = self._tools.get(tool_name) + if not tool: + return RouteResult( + agent_id="", + tool_name=tool_name, + result=None, + latency_ms=0, + success=False, + error=f"Unknown tool: {tool_name}", + ) + + # Normalize context + if context is None: + ctx = RoutingContext() + elif isinstance(context, dict): + ctx = RoutingContext( + session_id=context.get("session_id"), + user_id=context.get("user_id"), + priority=context.get("priority", 100), + timeout_override=context.get("timeout"), + preferred_agent=context.get("preferred_agent"), + excluded_agents=context.get("excluded_agents", []), + custom=context, + ) + else: + ctx = context + + # Determine required capabilities + required = tool.required_capabilities or [tool.category] + + # Find capable agents + capable = self._registry.find_capable_agents(required, ctx.excluded_agents) + if not capable: + return RouteResult( + agent_id="", + tool_name=tool_name, + result=None, + latency_ms=0, + success=False, + error=f"No capable agents for tool '{tool_name}' (requires: {[c.value for c in required]})", + ) + + # Select agent using strategy + use_strategy = strategy or self._default_strategy + agent = self._select_agent(capable, use_strategy, ctx) + + # Execute with retries + timeout = ctx.timeout_override or tool.timeout_seconds + retries = tool.retries + last_error = None + + for attempt in range(retries + 1): + start_time = time.time() + try: + result = self._invoke_agent(agent, tool, args, timeout) + latency_ms = (time.time() - start_time) * 1000 + + # Record success + self._registry.record_call(agent.agent_id, latency_ms, True) + + # Update session affinity + if ctx.session_id: + self._session_affinity[ctx.session_id] = agent.agent_id + + return RouteResult( + agent_id=agent.agent_id, + tool_name=tool_name, + result=result, + 
latency_ms=latency_ms, + success=True, + retries_used=attempt, + ) + + except Exception as e: + latency_ms = (time.time() - start_time) * 1000 + self._registry.record_call(agent.agent_id, latency_ms, False) + last_error = str(e) + + # Try next capable agent on failure + capable = [a for a in capable if a.agent_id != agent.agent_id] + if capable: + agent = self._select_agent(capable, use_strategy, ctx) + + return RouteResult( + agent_id=agent.agent_id if agent else "", + tool_name=tool_name, + result=None, + latency_ms=0, + success=False, + error=last_error or "All retries exhausted", + retries_used=retries, + ) + + def _select_agent( + self, + capable: List[Agent], + strategy: RoutingStrategy, + ctx: RoutingContext, + ) -> Agent: + """Select an agent using the specified strategy.""" + if not capable: + raise ValueError("No capable agents") + + # Preferred agent override + if ctx.preferred_agent: + for agent in capable: + if agent.agent_id == ctx.preferred_agent: + return agent + + # Session affinity (sticky routing) + if strategy == RoutingStrategy.STICKY and ctx.session_id: + if prev_agent := self._session_affinity.get(ctx.session_id): + for agent in capable: + if agent.agent_id == prev_agent: + return agent + + if strategy == RoutingStrategy.ROUND_ROBIN: + with self._lock: + key = ",".join(sorted(a.agent_id for a in capable)) + idx = self._round_robin_idx[key] % len(capable) + self._round_robin_idx[key] = idx + 1 + return capable[idx] + + elif strategy == RoutingStrategy.RANDOM: + return random.choice(capable) + + elif strategy == RoutingStrategy.LEAST_LOADED: + return min(capable, key=lambda a: a.current_load) + + elif strategy == RoutingStrategy.PRIORITY: + # Sort by priority descending, then by load ascending + return max(capable, key=lambda a: (a.priority, -a.current_load)) + + elif strategy == RoutingStrategy.FASTEST: + # Sort by average latency + def avg_latency(a: Agent) -> float: + if a.total_calls == 0: + return float("inf") + return a.total_latency_ms / 
a.total_calls + return min(capable, key=avg_latency) + + # Default to priority + return max(capable, key=lambda a: a.priority) + + def _invoke_agent( + self, + agent: Agent, + tool: Tool, + args: Dict[str, Any], + timeout: float, + ) -> Any: + """Invoke a tool on an agent.""" + agent.current_load += 1 + try: + if agent.handler: + # Local function handler + return agent.handler(tool.name, args) + + elif agent.endpoint: + # Remote HTTP invocation + import urllib.request + import urllib.error + + request_data = json.dumps({ + "tool": tool.name, + "args": args, + "metadata": tool.metadata, + }).encode("utf-8") + + req = urllib.request.Request( + agent.endpoint, + data=request_data, + headers={"Content-Type": "application/json"}, + ) + + with urllib.request.urlopen(req, timeout=timeout) as resp: + response_data = resp.read().decode("utf-8") + return json.loads(response_data) + + else: + raise ValueError(f"Agent {agent.agent_id} has no handler or endpoint") + + finally: + agent.current_load -= 1 + + +class ToolDispatcher: + """ + High-level dispatcher for multi-agent tool orchestration. + + Provides a simple interface for agents to register and invoke tools. + """ + + def __init__(self, db): + """ + Create a tool dispatcher. + + Args: + db: ToonDB Database instance + """ + self._db = db + self._registry = AgentRegistry(db) + self._router = ToolRouter(self._registry) + + @property + def registry(self) -> AgentRegistry: + """Get the agent registry.""" + return self._registry + + @property + def router(self) -> ToolRouter: + """Get the tool router.""" + return self._router + + def register_local_agent( + self, + agent_id: str, + capabilities: List[Union[ToolCategory, str]], + handler: Callable[[str, Dict[str, Any]], Any], + priority: int = 100, + ) -> Agent: + """ + Register a local (in-process) agent. 
+ + Args: + agent_id: Unique agent identifier + capabilities: Tool categories this agent handles + handler: Function(tool_name, args) -> result + priority: Routing priority + + Returns: + The registered agent + """ + return self._registry.register_agent( + agent_id=agent_id, + capabilities=capabilities, + handler=handler, + priority=priority, + ) + + def register_remote_agent( + self, + agent_id: str, + capabilities: List[Union[ToolCategory, str]], + endpoint: str, + priority: int = 100, + ) -> Agent: + """ + Register a remote (HTTP) agent. + + Args: + agent_id: Unique agent identifier + capabilities: Tool categories this agent handles + endpoint: HTTP endpoint URL + priority: Routing priority + + Returns: + The registered agent + """ + return self._registry.register_agent( + agent_id=agent_id, + capabilities=capabilities, + endpoint=endpoint, + priority=priority, + ) + + def register_tool( + self, + name: str, + description: str, + category: Union[ToolCategory, str], + schema: Optional[Dict[str, Any]] = None, + timeout: float = 30.0, + ) -> Tool: + """ + Register a tool for routing. + + Args: + name: Tool name + description: Tool description + category: Tool category + schema: JSON schema for arguments + timeout: Call timeout in seconds + + Returns: + The registered tool + """ + cat = category if isinstance(category, ToolCategory) else ToolCategory(category) + tool = Tool( + name=name, + description=description, + category=cat, + schema=schema or {}, + timeout_seconds=timeout, + ) + return self._router.register_tool(tool) + + def invoke( + self, + tool_name: str, + args: Optional[Dict[str, Any]] = None, + session_id: Optional[str] = None, + strategy: Optional[RoutingStrategy] = None, + ) -> RouteResult: + """ + Invoke a tool with automatic routing. 
+ + Args: + tool_name: Name of the tool to call + args: Tool arguments + session_id: Optional session for sticky routing + strategy: Override routing strategy + + Returns: + RouteResult with the call outcome + """ + ctx = RoutingContext(session_id=session_id) + return self._router.route(tool_name, args or {}, ctx, strategy) + + def list_agents(self) -> List[Dict[str, Any]]: + """List all registered agents with their status.""" + agents = self._registry.list_agents() + return [ + { + "agent_id": a.agent_id, + "capabilities": [c.value for c in a.capabilities], + "status": a.status.value, + "priority": a.priority, + "current_load": a.current_load, + "total_calls": a.total_calls, + "avg_latency_ms": (a.total_latency_ms / a.total_calls) if a.total_calls > 0 else None, + "has_endpoint": a.endpoint is not None, + "has_handler": a.handler is not None, + } + for a in agents + ] + + def list_tools(self) -> List[Dict[str, Any]]: + """List all registered tools.""" + tools = self._router.list_tools() + return [ + { + "name": t.name, + "description": t.description, + "category": t.category.value, + "schema": t.schema, + "timeout_seconds": t.timeout_seconds, + "retries": t.retries, + } + for t in tools + ] diff --git a/src/toondb/sql_engine.py b/src/toondb/sql_engine.py index 8afc09a..c321bf5 100644 --- a/src/toondb/sql_engine.py +++ b/src/toondb/sql_engine.py @@ -92,15 +92,20 @@ def parse(sql: str) -> Tuple[str, Dict[str, Any]]: Returns: Tuple of (operation_type, parsed_info) - operation_type: CREATE_TABLE, DROP_TABLE, INSERT, SELECT, UPDATE, DELETE + operation_type: CREATE_TABLE, DROP_TABLE, CREATE_INDEX, DROP_INDEX, + INSERT, SELECT, UPDATE, DELETE """ sql = sql.strip() upper = sql.upper() if upper.startswith("CREATE TABLE"): return SQLParser._parse_create_table(sql) + elif upper.startswith("CREATE INDEX"): + return SQLParser._parse_create_index(sql) elif upper.startswith("DROP TABLE"): return SQLParser._parse_drop_table(sql) + elif upper.startswith("DROP INDEX"): + return 
# Reconstruction of the index-support additions from this (whitespace-mangled)
# diff hunk: two SQLParser statement parsers plus the SQLExecutor secondary-
# index infrastructure and the index-accelerated UPDATE/DELETE paths.
# Definitions whose bodies were cut off by diff context markers (`parse`,
# `_parse_create_table`, `execute`, `_drop_table`, `_insert`, `_select`,
# `_matches_conditions`) are intentionally not reproduced here.


@staticmethod
def _parse_create_index(sql: str) -> Tuple[str, Dict]:
    """
    Parse a CREATE INDEX statement.

    Syntax: CREATE INDEX idx_name ON table_name(column_name)

    Returns:
        ("CREATE_INDEX", {"index_name": ..., "table": ..., "column": ...})

    Raises:
        ValueError: if the statement does not match the supported syntax
            (only single-column indexes are supported).
    """
    match = re.match(
        r'CREATE\s+INDEX\s+(\w+)\s+ON\s+(\w+)\s*\(\s*(\w+)\s*\)',
        sql,
        re.IGNORECASE,
    )
    if not match:
        raise ValueError(f"Invalid CREATE INDEX syntax: {sql}")

    return "CREATE_INDEX", {
        "index_name": match.group(1),
        "table": match.group(2),
        "column": match.group(3),
    }


@staticmethod
def _parse_drop_index(sql: str) -> Tuple[str, Dict]:
    """
    Parse a DROP INDEX statement.

    Syntax: DROP INDEX idx_name ON table_name

    Raises:
        ValueError: if the statement does not match the supported syntax.
    """
    match = re.match(r'DROP\s+INDEX\s+(\w+)\s+ON\s+(\w+)', sql, re.IGNORECASE)
    if not match:
        raise ValueError(f"Invalid DROP INDEX syntax: {sql}")

    return "DROP_INDEX", {
        "index_name": match.group(1),
        "table": match.group(2),
    }


# SQLExecutor key-layout constants.  All SQL state lives under one KV
# namespace:
#   _sql/tables/{table}/schema                          -> table schema
#   _sql/tables/{table}/rows/{row_id}                   -> row JSON
#   _sql/tables/{table}/indexes/{idx}/meta              -> index metadata
#   _sql/tables/{table}/indexes/{idx}/{value}/{row_id}  -> index entry
TABLE_PREFIX = b"_sql/tables/"
SCHEMA_SUFFIX = b"/schema"
ROWS_PREFIX = b"/rows/"
INDEXES_PREFIX = b"/indexes/"
INDEX_META_SUFFIX = b"/meta"


def __init__(self, db):
    """Initialize with a Database instance."""
    self._db = db


# =========================================================================
# Index Infrastructure
# =========================================================================

def _index_key(self, table: str, index_name: str, value: Any, row_id: str) -> bytes:
    """Key for one index entry:
    _sql/tables/{table}/indexes/{idx}/{encoded value}/{row_id}."""
    value_str = self._encode_index_value(value)
    return (
        self.TABLE_PREFIX + table.encode()
        + self.INDEXES_PREFIX + index_name.encode()
        + b"/" + value_str.encode() + b"/" + row_id.encode()
    )


def _index_prefix(self, table: str, index_name: str) -> bytes:
    """Prefix covering every entry of one index."""
    return (
        self.TABLE_PREFIX + table.encode()
        + self.INDEXES_PREFIX + index_name.encode() + b"/"
    )


def _index_value_prefix(self, table: str, index_name: str, value: Any) -> bytes:
    """Prefix covering every entry of one index with one specific value."""
    value_str = self._encode_index_value(value)
    return (
        self.TABLE_PREFIX + table.encode()
        + self.INDEXES_PREFIX + index_name.encode()
        + b"/" + value_str.encode() + b"/"
    )


def _encode_index_value(self, value: Any) -> str:
    """Encode *value* as a string segment of an index key.

    Each type carries a prefix ("b:", "i:", "f:", "s:") and integers are
    zero-padded to a fixed width.  Whole-number floats are encoded with the
    integer form so that 5 and 5.0 -- which compare equal in Python and
    therefore match the same rows on the full-scan path -- also resolve to
    the same index entries.

    NOTE(review): only equality lookups use this encoding today; negative
    numbers do NOT sort correctly as strings, so true range scans would
    need an order-preserving encoding.
    """
    if value is None:
        return "__null__"
    if isinstance(value, bool):
        # Must be checked before int: bool is an int subclass.
        return f"b:{1 if value else 0}"
    if isinstance(value, int):
        return f"i:{value:020d}"
    if isinstance(value, float):
        if value.is_integer():
            # Normalize 5.0 to the same key as 5 (Python `==` semantics,
            # matching the full-scan WHERE evaluation).
            return f"i:{int(value):020d}"
        return f"f:{value:030.15f}"
    # Strings and anything else: escape '/' so the key layout stays parseable.
    return f"s:{str(value).replace('/', '__SLASH__')}"


def _index_meta_key(self, table: str, index_name: str) -> bytes:
    """Key holding the JSON metadata ({"column", "table"}) of one index."""
    return (
        self.TABLE_PREFIX + table.encode()
        + self.INDEXES_PREFIX + index_name.encode()
        + self.INDEX_META_SUFFIX
    )


def _get_indexes(self, table: str) -> Dict[str, str]:
    """Return {index_name: column_name} for every index on *table*."""
    prefix = self.TABLE_PREFIX + table.encode() + self.INDEXES_PREFIX
    indexes: Dict[str, str] = {}

    for key, value in self._db.scan_prefix(prefix):
        if not key.endswith(self.INDEX_META_SUFFIX):
            continue
        # Metadata keys have exactly six segments:
        #   _sql / tables / {table} / indexes / {idx} / meta
        # An index *entry* key has seven, so the length check prevents an
        # entry whose row id happens to be "meta" from being parsed (and
        # json.loads-crashed on) as metadata.
        parts = key.decode().split("/")
        if len(parts) != 6:
            continue
        idx_name = parts[4]
        meta = json.loads(value.decode())
        indexes[idx_name] = meta.get("column", idx_name)

    return indexes


def _update_index(self, table: str, index_name: str, column: str,
                  old_row: Optional[Dict], new_row: Optional[Dict], row_id: str):
    """Maintain one index across a single row change.

    Pass old_row=None for inserts and new_row=None for deletes.  The KV
    store is only touched when the indexed column's value actually changes
    (or the row appears/disappears), so no-op updates cost nothing.
    """
    old_value = old_row.get(column) if old_row else None
    new_value = new_row.get(column) if new_row else None

    # Remove the stale entry when the value changed or the row is gone.
    if old_row and (new_row is None or old_value != new_value):
        self._db.delete(self._index_key(table, index_name, old_value, row_id))

    # Write the fresh entry when the row is new or the value changed.
    if new_row and (old_row is None or old_value != new_value):
        self._db.put(
            self._index_key(table, index_name, new_value, row_id),
            row_id.encode(),
        )


def _lookup_by_index(self, table: str, column: str, value: Any) -> List[str]:
    """Return row ids indexed under *value* for *column*.

    Returns [] when no index covers the column.  Callers must still
    re-check the full WHERE clause against the fetched rows.
    """
    index_name = None
    for idx_name, col in self._get_indexes(table).items():
        if col == column:
            index_name = idx_name
            break

    if index_name is None:
        return []  # no index available

    prefix = self._index_value_prefix(table, index_name, value)
    return [entry.decode() for _, entry in self._db.scan_prefix(prefix)]


def _has_index_for_column(self, table: str, column: str) -> bool:
    """True if any index on *table* covers *column*."""
    return any(col == column for col in self._get_indexes(table).values())


def _find_indexed_equality_condition(
        self, table: str,
        conditions: List[Tuple]) -> Optional[Tuple[str, Any]]:
    """
    Find the first WHERE condition that an index can serve.
    Only equality ("=") conditions qualify.  Returns (column, value) or None.
    """
    for col, op, val in conditions:
        if op == "=" and self._has_index_for_column(table, col):
            return (col, val)
    return None


def _create_index(self, data: Dict) -> "SQLQueryResult":
    """
    Create a secondary index on a column and backfill it from existing rows.

    CREATE INDEX idx_name ON table(column)

    Returns rows_affected = number of rows indexed.

    Raises:
        ValueError: unknown table, unknown column, or duplicate index name.
    """
    index_name = data["index_name"]
    table = data["table"]
    column = data["column"]

    schema = self._get_schema(table)
    if schema is None:
        raise ValueError(f"Table '{table}' does not exist")

    if not any(c.name == column for c in schema.columns):
        raise ValueError(f"Column '{column}' does not exist in table '{table}'")

    if self._db.get(self._index_meta_key(table, index_name)) is not None:
        raise ValueError(f"Index '{index_name}' already exists on table '{table}'")

    # Metadata is written before the backfill so that writes arriving
    # mid-build already maintain the index (INSERT consults _get_indexes).
    # NOTE(review): a failed build leaves the metadata plus a partial index
    # behind -- there is no transaction wrapping this.
    meta = {"column": column, "table": table}
    self._db.put(
        self._index_meta_key(table, index_name),
        json.dumps(meta).encode()
    )

    indexed_count = 0
    for _, value in self._db.scan_prefix(self._row_prefix(table)):
        row = json.loads(value.decode())
        row_id = row.get("_id", "")
        self._db.put(
            self._index_key(table, index_name, row.get(column), row_id),
            row_id.encode(),
        )
        indexed_count += 1

    return SQLQueryResult(rows=[], columns=[], rows_affected=indexed_count)


def _drop_index(self, data: Dict) -> "SQLQueryResult":
    """Drop a secondary index: delete every entry, then the metadata.

    Returns rows_affected = number of index entries removed.
    """
    index_name = data["index_name"]
    table = data.get("table")

    # The parser always supplies the table; this guards direct callers.
    if table is None:
        raise ValueError("DROP INDEX requires table name: DROP INDEX idx_name ON table")

    deleted = 0
    for key, _ in self._db.scan_prefix(self._index_prefix(table, index_name)):
        self._db.delete(key)
        deleted += 1

    self._db.delete(self._index_meta_key(table, index_name))

    return SQLQueryResult(rows=[], columns=[], rows_affected=deleted)


def _apply_row_update(self, table: str, key: bytes, old_row: Dict,
                      updates: Dict, indexes: Dict[str, str], row_id: str) -> None:
    """Write one updated row and keep every affected index in sync."""
    new_row = old_row.copy()
    for ucol, uval in updates.items():
        new_row[ucol] = uval

    # Only indexes whose column was touched by SET need maintenance.
    for idx_name, idx_col in indexes.items():
        if idx_col in updates:
            self._update_index(table, idx_name, idx_col, old_row, new_row, row_id)

    self._db.put(key, json.dumps(new_row).encode())


def _update(self, data: Dict) -> "SQLQueryResult":
    """
    Update rows, using an index lookup when the WHERE clause contains an
    equality condition on an indexed column.

    Index-accelerated path: O(k) where k = matching rows
    Fallback path: O(n) full table scan
    """
    table = data["table"]
    updates = data["updates"]
    conditions = data.get("where", [])

    schema = self._get_schema(table)
    if schema is None:
        raise ValueError(f"Table '{table}' does not exist")

    indexes = self._get_indexes(table)
    rows_affected = 0

    indexed_cond = self._find_indexed_equality_condition(table, conditions)

    if indexed_cond:
        # Fetch candidates by row id, then re-check the FULL WHERE clause:
        # the index only guarantees the one equality condition.
        col, val = indexed_cond
        for row_id in self._lookup_by_index(table, col, val):
            key = self._row_key(table, row_id)
            value = self._db.get(key)
            if value is None:
                continue  # index entry with no backing row
            old_row = json.loads(value.decode())
            if not self._matches_conditions(old_row, conditions):
                continue
            self._apply_row_update(table, key, old_row, updates, indexes, row_id)
            rows_affected += 1
    else:
        # Fallback: full table scan.
        for key, value in self._db.scan_prefix(self._row_prefix(table)):
            old_row = json.loads(value.decode())
            if self._matches_conditions(old_row, conditions):
                self._apply_row_update(table, key, old_row, updates, indexes,
                                       old_row.get("_id", ""))
                rows_affected += 1

    return SQLQueryResult(rows=[], columns=[], rows_affected=rows_affected)


def _delete(self, data: Dict) -> "SQLQueryResult":
    """
    Delete rows, using an index lookup when the WHERE clause contains an
    equality condition on an indexed column.  Matches are collected first
    and deleted afterwards so the scan never mutates what it iterates.

    Index-accelerated path: O(k) where k = matching rows
    Fallback path: O(n) full table scan
    """
    table = data["table"]
    conditions = data.get("where", [])

    schema = self._get_schema(table)
    if schema is None:
        raise ValueError(f"Table '{table}' does not exist")

    indexes = self._get_indexes(table)
    rows_to_delete = []  # (key, row, row_id)

    indexed_cond = self._find_indexed_equality_condition(table, conditions)

    if indexed_cond:
        # Fetch candidates by row id, then re-check the FULL WHERE clause.
        col, val = indexed_cond
        for row_id in self._lookup_by_index(table, col, val):
            key = self._row_key(table, row_id)
            value = self._db.get(key)
            if value is None:
                continue  # index entry with no backing row
            row = json.loads(value.decode())
            if self._matches_conditions(row, conditions):
                rows_to_delete.append((key, row, row_id))
    else:
        # Fallback: full table scan.
        for key, value in self._db.scan_prefix(self._row_prefix(table)):
            row = json.loads(value.decode())
            if self._matches_conditions(row, conditions):
                rows_to_delete.append((key, row, row.get("_id", "")))

    rows_affected = 0
    for key, row, row_id in rows_to_delete:
        # Remove from every index, then drop the row itself.
        for idx_name, idx_col in indexes.items():
            self._update_index(table, idx_name, idx_col, row, None, row_id)
        self._db.delete(key)
        rows_affected += 1

    return SQLQueryResult(rows=[], columns=[], rows_affected=rows_affected)