diff --git a/CLAUDE.md b/CLAUDE.md index 6a38f34..6953d65 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,5 +1,9 @@ # CLAUDE.md - Redis Agent Memory Server Project Context +## Redis Version +This project uses Redis 8, which is the redis:8 docker image. +Do not use Redis Stack or other earlier versions of Redis. + ## Frequently Used Commands Get started in a new environment by installing `uv`: ```bash @@ -188,7 +192,6 @@ EMBEDDING_MODEL=text-embedding-3-small # Memory Configuration LONG_TERM_MEMORY=true -WINDOW_SIZE=20 ENABLE_TOPIC_EXTRACTION=true ENABLE_NER=true ``` diff --git a/agent-memory-client/agent_memory_client/__init__.py b/agent-memory-client/agent_memory_client/__init__.py index 8cee03a..7647c8b 100644 --- a/agent-memory-client/agent_memory_client/__init__.py +++ b/agent-memory-client/agent_memory_client/__init__.py @@ -5,7 +5,7 @@ memory management capabilities for AI agents and applications. """ -__version__ = "0.9.1" +__version__ = "0.9.2" from .client import MemoryAPIClient, MemoryClientConfig, create_memory_client from .exceptions import ( diff --git a/agent-memory-client/agent_memory_client/client.py b/agent-memory-client/agent_memory_client/client.py index adc58aa..2eb3ca6 100644 --- a/agent-memory-client/agent_memory_client/client.py +++ b/agent-memory-client/agent_memory_client/client.py @@ -209,7 +209,6 @@ async def get_working_memory( session_id: str, user_id: str | None = None, namespace: str | None = None, - window_size: int | None = None, model_name: ModelNameLiteral | None = None, context_window_max: int | None = None, ) -> WorkingMemoryResponse: @@ -220,7 +219,6 @@ async def get_working_memory( session_id: The session ID to retrieve working memory for user_id: The user ID to retrieve working memory for namespace: Optional namespace for the session - window_size: Optional number of messages to include model_name: Optional model name to determine context window size context_window_max: Optional direct specification of context window tokens @@ -241,9 +239,6 @@ async def get_working_memory( elif self.config.default_namespace is not None: params["namespace"] = self.config.default_namespace - if window_size is not None: - params["window_size"] = str(window_size) - # Use provided model_name or fall back to config default effective_model_name = model_name or self.config.default_model_name if effective_model_name is not None: @@ -2139,7 +2134,6 @@ async def memory_prompt( query: str, session_id: str | None = None, namespace: str | None = None, - window_size: int | None = None, model_name: str | None = None, context_window_max: int | None = None, long_term_search: dict[str, Any] | None = None, @@ -2154,7 +2148,6 @@ async def memory_prompt( query: The input text to find relevant context for session_id: Optional session ID to include session messages namespace: Optional namespace for the session - window_size: Optional number of messages to include model_name: Optional model name to determine context window size context_window_max: Optional direct specification of context window tokens long_term_search: Optional search parameters for long-term memory @@ -2169,7 +2162,6 @@ async def memory_prompt( prompt = await client.memory_prompt( query="What are my UI preferences?", session_id="current_session", - window_size=10, long_term_search={ "topics": {"any": ["preferences", "ui"]}, "limit": 5 @@ -2190,8 +2182,6 @@ async def memory_prompt( session_params["namespace"] = namespace elif self.config.default_namespace is not None: session_params["namespace"] = self.config.default_namespace 
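With `window_size` removed from `get_working_memory` and `memory_prompt`, clients now pass `model_name` or `context_window_max` and let the server handle token-based truncation. A minimal sketch of the updated call pattern, assuming a server at a hypothetical `http://localhost:8000` and that `create_memory_client` is an async factory taking the base URL:

```python
import asyncio

from agent_memory_client import create_memory_client


async def main() -> None:
    # Hypothetical local deployment; adjust the base URL for your setup.
    client = await create_memory_client("http://localhost:8000")

    # model_name lets the server derive an appropriate token limit...
    working_memory = await client.get_working_memory(
        session_id="current_session",
        model_name="gpt-4o-mini",
    )

    # ...while context_window_max sets the token budget explicitly.
    prompt = await client.memory_prompt(
        query="What are my UI preferences?",
        session_id="current_session",
        context_window_max=4000,
        long_term_search={"topics": {"any": ["preferences", "ui"]}, "limit": 5},
    )
    print(working_memory.messages, prompt.messages)


if __name__ == "__main__":
    asyncio.run(main())
```

Passing `model_name` lets the server look up that model's context window, while `context_window_max` specifies the token budget directly and overrides it.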
- if window_size is not None: - session_params["window_size"] = str(window_size) # Use provided model_name or fall back to config default effective_model_name = model_name or self.config.default_model_name if effective_model_name is not None: diff --git a/agent-memory-client/agent_memory_client/models.py b/agent-memory-client/agent_memory_client/models.py index a77b1ea..f9b3a72 100644 --- a/agent-memory-client/agent_memory_client/models.py +++ b/agent-memory-client/agent_memory_client/models.py @@ -110,7 +110,7 @@ class MemoryRecord(BaseModel): ) discrete_memory_extracted: Literal["t", "f"] = Field( default="f", - description="Whether memory extraction has run for this memory (only messages)", + description="Whether memory extraction has run for this memory", ) memory_type: MemoryTypeEnum = Field( default=MemoryTypeEnum.MESSAGE, @@ -130,6 +130,19 @@ class MemoryRecord(BaseModel): ) +class ExtractedMemoryRecord(MemoryRecord): + """A memory record that has already been extracted (e.g., explicit memories from API/MCP)""" + + discrete_memory_extracted: Literal["t", "f"] = Field( + default="t", + description="Whether memory extraction has run for this memory", + ) + memory_type: MemoryTypeEnum = Field( + default=MemoryTypeEnum.SEMANTIC, + description="Type of memory", + ) + + class ClientMemoryRecord(MemoryRecord): """A memory record with a client-provided ID""" diff --git a/agent_memory_server/__init__.py b/agent_memory_server/__init__.py index 074b7f6..b0eb7f7 100644 --- a/agent_memory_server/__init__.py +++ b/agent_memory_server/__init__.py @@ -1,3 +1,3 @@ """Redis Agent Memory Server - A memory system for conversational AI.""" -__version__ = "0.9.3" +__version__ = "0.9.4" diff --git a/agent_memory_server/api.py b/agent_memory_server/api.py index 578795d..a16efad 100644 --- a/agent_memory_server/api.py +++ b/agent_memory_server/api.py @@ -582,19 +582,16 @@ async def memory_prompt( logger.debug(f"Memory prompt params: {params}") if params.session: - # Use token limit for memory prompt, fallback to message count for backward compatibility + # Use token limit for memory prompt - model info is required now if params.session.model_name or params.session.context_window_max: token_limit = _get_effective_token_limit( model_name=params.session.model_name, context_window_max=params.session.context_window_max, ) - effective_window_size = ( - token_limit # We'll handle token-based truncation below - ) + effective_token_limit = token_limit else: - effective_window_size = ( - params.session.window_size - ) # Fallback to message count + # No model info provided - use all messages without truncation + effective_token_limit = None working_mem = await working_memory.get_working_memory( session_id=params.session.session_id, namespace=params.session.namespace, @@ -616,11 +613,11 @@ async def memory_prompt( ) ) # Apply token-based truncation if model info is provided - if params.session.model_name or params.session.context_window_max: + if effective_token_limit is not None: # Token-based truncation if ( _calculate_messages_token_count(working_mem.messages) - > effective_window_size + > effective_token_limit ): # Keep removing oldest messages until we're under the limit recent_messages = working_mem.messages[:] @@ -628,34 +625,30 @@ async def memory_prompt( recent_messages = recent_messages[1:] # Remove oldest if ( _calculate_messages_token_count(recent_messages) - <= effective_window_size + <= effective_token_limit ): break else: recent_messages = working_mem.messages - - for msg in recent_messages: - if 
msg.role == "user": - msg_class = base.UserMessage - else: - msg_class = base.AssistantMessage - _messages.append( - msg_class( - content=TextContent(type="text", text=msg.content), - ) - ) else: - # No token-based truncation - use all messages - for msg in working_mem.messages: - if msg.role == "user": - msg_class = base.UserMessage - else: - msg_class = base.AssistantMessage - _messages.append( - msg_class( - content=TextContent(type="text", text=msg.content), - ) + # No token limit provided - use all messages + recent_messages = working_mem.messages + + for msg in recent_messages: + if msg.role == "user": + msg_class = base.UserMessage + elif msg.role == "assistant": + msg_class = base.AssistantMessage + else: + # For tool messages or other roles, treat as assistant for MCP compatibility + # since MCP base only supports UserMessage and AssistantMessage + msg_class = base.AssistantMessage + + _messages.append( + msg_class( + content=TextContent(type="text", text=msg.content), ) + ) if params.long_term_search: logger.debug( diff --git a/agent_memory_server/cli.py b/agent_memory_server/cli.py index 4a09940..b0a76bf 100644 --- a/agent_memory_server/cli.py +++ b/agent_memory_server/cli.py @@ -128,11 +128,14 @@ async def setup_and_run(): logger.info(f"Starting MCP server on port {port}\n") await mcp_app.run_sse_async() elif mode == "stdio": - # Logging already configured above + # Don't run a task worker in stdio mode. + # TODO: Make configurable with a CLI flag? + settings.use_docket = False await mcp_app.run_stdio_async() else: raise ValueError(f"Invalid mode: {mode}") + # TODO: Do we really need to update the port again? # Update the port in settings settings.mcp_port = port diff --git a/agent_memory_server/config.py b/agent_memory_server/config.py index 73acbb2..35bba92 100644 --- a/agent_memory_server/config.py +++ b/agent_memory_server/config.py @@ -52,7 +52,9 @@ class Settings(BaseSettings): long_term_memory: bool = True openai_api_key: str | None = None anthropic_api_key: str | None = None - generation_model: str = "gpt-4o-mini" + openai_api_base: str | None = None + anthropic_api_base: str | None = None + generation_model: str = "gpt-4o" embedding_model: str = "text-embedding-3-small" port: int = 8000 mcp_port: int = 9000 @@ -118,7 +120,6 @@ class Settings(BaseSettings): auth0_client_secret: str | None = None # Working memory settings - window_size: int = 20 # Default number of recent messages to return summarization_threshold: float = ( 0.7 # Fraction of context window that triggers summarization ) diff --git a/agent_memory_server/llms.py b/agent_memory_server/llms.py index 9e14da5..18537a7 100644 --- a/agent_memory_server/llms.py +++ b/agent_memory_server/llms.py @@ -9,6 +9,8 @@ from openai import AsyncOpenAI from pydantic import BaseModel +from agent_memory_server.config import settings + logger = logging.getLogger(__name__) @@ -203,14 +205,21 @@ def total_tokens(self) -> int: class AnthropicClientWrapper: """Wrapper for Anthropic client""" - def __init__(self, api_key: str | None = None): + def __init__(self, api_key: str | None = None, base_url: str | None = None): """Initialize the Anthropic client""" anthropic_api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") + anthropic_api_base = base_url or os.environ.get("ANTHROPIC_API_BASE") if not anthropic_api_key: raise ValueError("Anthropic API key is required") - self.client = anthropic.AsyncAnthropic(api_key=anthropic_api_key) + if anthropic_api_base: + self.client = anthropic.AsyncAnthropic( + api_key=anthropic_api_key, + 
base_url=anthropic_api_base, + ) + else: + self.client = anthropic.AsyncAnthropic(api_key=anthropic_api_key) async def create_chat_completion( self, @@ -397,9 +406,15 @@ async def get_model_client( model_config = get_model_config(model_name) if model_config.provider == ModelProvider.OPENAI: - model = OpenAIClientWrapper(api_key=os.environ.get("OPENAI_API_KEY")) - if model_config.provider == ModelProvider.ANTHROPIC: - model = AnthropicClientWrapper(api_key=os.environ.get("ANTHROPIC_API_KEY")) + model = OpenAIClientWrapper( + api_key=settings.openai_api_key, + base_url=settings.openai_api_base, + ) + elif model_config.provider == ModelProvider.ANTHROPIC: + model = AnthropicClientWrapper( + api_key=settings.anthropic_api_key, + base_url=settings.anthropic_api_base, + ) if model: _model_clients[model_name] = model diff --git a/agent_memory_server/long_term_memory.py b/agent_memory_server/long_term_memory.py index c886513..1f60144 100644 --- a/agent_memory_server/long_term_memory.py +++ b/agent_memory_server/long_term_memory.py @@ -30,6 +30,7 @@ get_model_client, ) from agent_memory_server.models import ( + ExtractedMemoryRecord, MemoryMessage, MemoryRecord, MemoryRecordResults, @@ -324,21 +325,37 @@ async def compact_long_term_memories( index_name = Keys.search_index_name() # Create aggregation query to group by memory_hash and find duplicates - agg_query = ( - f"FT.AGGREGATE {index_name} {filter_str} " - "GROUPBY 1 @memory_hash " - "REDUCE COUNT 0 AS count " - 'FILTER "@count>1" ' # Only groups with more than 1 memory - "SORTBY 2 @count DESC " - f"LIMIT 0 {limit}" - ) + agg_query = [ + "FT.AGGREGATE", + index_name, + filter_str, + "GROUPBY", + str(1), + "@memory_hash", + "REDUCE", + "COUNT", + str(0), + "AS", + "count", + "FILTER", + "@count>1", # Only groups with more than 1 memory + "SORTBY", + str(2), + "@count", + "DESC", + "LIMIT", + str(0), + str(limit), + ] # Execute aggregation to find duplicate groups - duplicate_groups = await redis_client.execute_command(agg_query) + duplicate_groups = await redis_client.execute_command(*agg_query) if duplicate_groups and duplicate_groups[0] > 0: num_groups = duplicate_groups[0] - logger.info(f"Found {num_groups} groups of hash-based duplicates") + logger.info( + f"Found {num_groups} groups with hash-based duplicates to process" + ) # Process each group of duplicates for i in range(1, len(duplicate_groups), 2): @@ -423,9 +440,11 @@ async def compact_long_term_memories( ) except Exception as e: logger.error(f"Error processing duplicate group: {e}") + else: + logger.info("No hash-based duplicates found") logger.info( - f"Completed hash-based deduplication. Merged {memories_merged} memories." + f"Completed hash-based deduplication. Removed {memories_merged} duplicate memories." 
) except Exception as e: logger.error(f"Error during hash-based duplicate compaction: {e}") @@ -575,7 +594,7 @@ async def compact_long_term_memories( async def index_long_term_memories( - memories: list[MemoryRecord], + memories: list[MemoryRecord | ExtractedMemoryRecord], redis_client: Redis | None = None, deduplicate: bool = False, vector_distance_threshold: float = 0.12, @@ -1186,7 +1205,6 @@ async def promote_working_memory_to_long_term( text=f"{msg.role}: {msg.content}", namespace=namespace, user_id=current_working_memory.user_id, - memory_type=MemoryTypeEnum.MESSAGE, persisted_at=None, ) diff --git a/agent_memory_server/main.py b/agent_memory_server/main.py index c2f4b66..bb4a715 100644 --- a/agent_memory_server/main.py +++ b/agent_memory_server/main.py @@ -135,7 +135,6 @@ async def lifespan(app: FastAPI): logger.info( "Redis Agent Memory Server initialized", - window_size=settings.window_size, generation_model=settings.generation_model, embedding_model=settings.embedding_model, long_term_memory=settings.long_term_memory, diff --git a/agent_memory_server/mcp.py b/agent_memory_server/mcp.py index 6ca1dc3..c5fc264 100644 --- a/agent_memory_server/mcp.py +++ b/agent_memory_server/mcp.py @@ -310,9 +310,7 @@ async def create_long_term_memories( if mem.user_id is None and settings.default_mcp_user_id: mem.user_id = settings.default_mcp_user_id - payload = CreateMemoryRecordRequest( - memories=[MemoryRecord(**mem.model_dump()) for mem in memories] - ) + payload = CreateMemoryRecordRequest(memories=memories) return await core_create_long_term_memory( payload, background_tasks=get_background_tasks() ) @@ -360,14 +358,23 @@ async def search_long_term_memory( search_long_term_memory(text="user's favorite color") ``` - 2. Search with simple session filter: + 2. Get ALL memories for a user (e.g., "what do you remember about me?"): + ```python + search_long_term_memory( + text="", # Empty string returns all memories for the user + user_id={"eq": "user_123"}, + limit=50 # Adjust based on how many memories you want + ) + ``` + + 3. Search with simple session filter: ```python search_long_term_memory(text="user's favorite color", session_id={ "eq": "session_12345" }) ``` - 3. Search with complex filters: + 4. Search with complex filters: ```python search_long_term_memory( text="user preferences", @@ -381,7 +388,7 @@ async def search_long_term_memory( ) ``` - 4. Search with datetime range filters: + 5. Search with datetime range filters: ```python search_long_term_memory( text="recent conversations", @@ -395,7 +402,7 @@ async def search_long_term_memory( ) ``` - 5. Search with between datetime filter: + 6. Search with between datetime filter: ```python search_long_term_memory( text="holiday discussions", @@ -406,7 +413,7 @@ async def search_long_term_memory( ``` Args: - text: The semantic search query text (required) + text: The semantic search query text (required). Use empty string "" to get all memories for a user. session_id: Filter by session ID namespace: Filter by namespace topics: Filter by topics @@ -467,7 +474,6 @@ async def memory_prompt( query: str, session_id: SessionId | None = None, namespace: Namespace | None = None, - window_size: int = settings.window_size, model_name: ModelNameLiteral | None = None, context_window_max: int | None = None, topics: Topics | None = None, @@ -483,20 +489,14 @@ async def memory_prompt( """ Hydrate a user query with relevant session history and long-term memories. 
- CRITICAL: Use this tool for EVERY question that might benefit from memory context, - especially when you don't have sufficient information to answer confidently. - This tool enriches the user's query by retrieving: 1. Context from the current conversation session 2. Relevant long-term memories related to the query - ALWAYS use this tool when: - - The user references past conversations - - The question is about user preferences or personal information - - You need additional context to provide a complete answer - - The question seems to assume information you don't have in current context + The tool returns both the relevant memories AND the user's query in a format ready for + generating comprehensive responses. - The function uses the text field from the payload as the user's query, + The function uses the query field from the payload as the user's query, and any filters to retrieve relevant memories. DATETIME INPUT FORMAT: @@ -513,12 +513,20 @@ async def memory_prompt( COMMON USAGE PATTERNS: ```python 1. Hydrate a user prompt with long-term memory search: - hydrate_memory_prompt(text="What was my favorite color?") + memory_prompt(query="What was my favorite color?") + ``` + + 2. Answer "what do you remember about me?" type questions: + memory_prompt( + query="What do you remember about me?", + user_id={"eq": "user_123"}, + limit=50 + ) ``` - 2. Hydrate a user prompt with long-term memory search and session filter: - hydrate_memory_prompt( - text="What is my favorite color?", + 3. Hydrate a user prompt with long-term memory search and session filter: + memory_prompt( + query="What is my favorite color?", session_id={ "eq": "session_12345" }, @@ -527,9 +535,9 @@ async def memory_prompt( } ) - 3. Hydrate a user prompt with long-term memory search and complex filters: - hydrate_memory_prompt( - text="What was my favorite color?", + 4. Hydrate a user prompt with long-term memory search and complex filters: + memory_prompt( + query="What was my favorite color?", topics={ "any": ["preferences", "settings"] }, @@ -539,9 +547,9 @@ async def memory_prompt( limit=5 ) - 4. Search with datetime range filters: - hydrate_memory_prompt( - text="What did we discuss recently?", + 5. 
Search with datetime range filters: + memory_prompt( + query="What did we discuss recently?", created_at={ "gte": "2024-01-01T00:00:00Z", "lt": "2024-02-01T00:00:00Z" @@ -553,7 +561,7 @@ async def memory_prompt( ``` Args: - - text: The user's query + - query: The user's query - session_id: Add conversation history from a working memory session - namespace: Filter session and long-term memory namespace - topics: Search for long-term memories matching topics @@ -579,7 +587,6 @@ async def memory_prompt( session_id=_session_id, namespace=namespace.eq if namespace and namespace.eq else None, user_id=user_id.eq if user_id and user_id.eq else None, - window_size=window_size, model_name=model_name, context_window_max=context_window_max, ) diff --git a/agent_memory_server/models.py b/agent_memory_server/models.py index 204dfdf..b018dfe 100644 --- a/agent_memory_server/models.py +++ b/agent_memory_server/models.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, Field from ulid import ULID -from agent_memory_server.config import settings from agent_memory_server.filters import ( CreatedAt, Entities, @@ -131,7 +130,7 @@ class MemoryRecord(BaseModel): ) discrete_memory_extracted: Literal["t", "f"] = Field( default="f", - description="Whether memory extraction has run for this memory (only messages)", + description="Whether memory extraction has run for this memory", ) memory_type: MemoryTypeEnum = Field( default=MemoryTypeEnum.MESSAGE, @@ -151,6 +150,19 @@ class MemoryRecord(BaseModel): ) +class ExtractedMemoryRecord(MemoryRecord): + """A memory record that has already been extracted (e.g., explicit memories from API/MCP)""" + + discrete_memory_extracted: Literal["t", "f"] = Field( + default="t", + description="Whether memory extraction has run for this memory", + ) + memory_type: MemoryTypeEnum = Field( + default=MemoryTypeEnum.SEMANTIC, + description="Type of memory", + ) + + class ClientMemoryRecord(MemoryRecord): """A memory record with a client-provided ID""" @@ -238,7 +250,6 @@ class WorkingMemoryRequest(BaseModel): session_id: str namespace: str | None = None user_id: str | None = None - window_size: int = settings.window_size model_name: ModelNameLiteral | None = None context_window_max: int | None = None @@ -270,7 +281,7 @@ class MemoryRecordResultsResponse(MemoryRecordResults): class CreateMemoryRecordRequest(BaseModel): """Payload for creating memory records""" - memories: list[MemoryRecord] + memories: list[ExtractedMemoryRecord] class GetSessionsQuery(BaseModel): @@ -403,7 +414,7 @@ class MemoryPromptResponse(BaseModel): messages: list[base.Message | SystemMessage] -class LenientMemoryRecord(MemoryRecord): +class LenientMemoryRecord(ExtractedMemoryRecord): """A memory record that can be created without an ID""" id: str | None = Field(default_factory=lambda: str(ULID())) diff --git a/agent_memory_server/pluggable-long-term-memory.md b/agent_memory_server/pluggable-long-term-memory.md deleted file mode 100644 index 4096ad7..0000000 --- a/agent_memory_server/pluggable-long-term-memory.md +++ /dev/null @@ -1,152 +0,0 @@ -## Feature: Pluggable Long-Term Memory via LangChain VectorStore Adapter - -**Summary:** -Refactor agent-memory-server's long-term memory component to use the [LangChain VectorStore interface](https://python.langchain.com/docs/integrations/vectorstores/) as its backend abstraction. -This will allow users to select from dozens of supported databases (Chroma, Pinecone, Weaviate, Redis, Qdrant, Milvus, Postgres/PGVector, LanceDB, and more) with minimal custom code. 
-The backend should be configurable at runtime via environment variables or config, and require no custom adapters for each new supported store. - -**Reference:** -- [agent-memory-server repo](https://github.com/redis-developer/agent-memory-server) -- [LangChain VectorStore docs](https://python.langchain.com/docs/integrations/vectorstores/) - ---- - -### Requirements - -1. **Adopt LangChain VectorStore as the Storage Interface** - - All long-term memory operations (`add`, `search`, `delete`, `update`) must delegate to a LangChain-compatible VectorStore instance. - - Avoid any database-specific code paths for core CRUD/search; rely on VectorStore's interface. - - The VectorStore instance must be initialized at server startup, using connection parameters from environment variables or config. - -2. **Backend Swappability** - - The backend type (e.g., Chroma, Pinecone, Redis, Postgres, etc.) must be selectable at runtime via a config variable (e.g., `LONG_TERM_MEMORY_BACKEND`). - - All required connection/config parameters for the backend should be loaded from environment/config. - - Adding new supported databases should require no new adapter code—just list them in documentation and config. - -3. **API Mapping and Model Translation** - - Ensure your memory API endpoints map directly to the underlying VectorStore methods (e.g., `add_texts`, `similarity_search`, `delete`). - - Translate between your internal MemoryRecord model and LangChain's `Document` (or other types as needed) at the service boundary. - - Support metadata storage and filtering as allowed by the backend; document any differences in filter syntax or capability. - -4. **Configuration and Documentation** - - Document all supported backends, their config options, and any installation requirements (e.g., which Python extras to install for each backend). - - Update `.env.example` with required variables for each backend type. - - Add a table in the README listing supported databases and any notable feature support/limitations (e.g., advanced filters, hybrid search). - -5. **Testing and CI** - - Add tests to verify core flows (add, search, delete, filter) work with at least two VectorStore backends (e.g., Chroma and Redis). - - (Optional) Use in-memory stores for unit tests where possible. - -6. **(Optional but Preferred) Dependency Handling** - - Optional dependencies for each backend should be installed only if required (using extras, e.g., `pip install agent-memory-server[chroma]`). - ---- - -### Implementation Steps - -1. **Create a Thin Adapter Layer** - - Implement a `VectorStoreMemoryAdapter` class that wraps a LangChain VectorStore instance and exposes memory operations. - - Adapter methods should map 1:1 to LangChain methods (e.g., `add_texts`, `similarity_search`, `delete`), translating data models as needed. - -2. **Backend Selection and Initialization** - - On startup, read `LONG_TERM_MEMORY_BACKEND` and associated connection params. - - Dynamically instantiate the appropriate VectorStore via LangChain, passing required config. - - Store the instance as a singleton/service to be used by API endpoints. - -3. **API Endpoint Refactor** - - Refactor long-term memory API endpoints to call adapter methods only; eliminate any backend-specific logic from the endpoints. - - Ensure filter syntax in your API is converted to the form expected by each VectorStore. Where not possible, document or gracefully reject unsupported filter types. - -4. 
**Update Documentation** - - Clearly explain backend selection, configuration, and how to install dependencies for each supported backend. - - Add usage examples for at least two backends (Chroma and Redis recommended). - - List any differences in filtering, advanced features, or limits by backend. - -5. **Testing** - - Add or update tests to cover core memory operations with at least two different VectorStore backends. - - Use environment variables or test config files to run tests with different backends in CI. - ---- - -### Acceptance Criteria - -- [x] agent-memory-server supports Redis backends for long-term memory, both selectable at runtime via config/env. -- [x] All long-term memory API operations are delegated through the LangChain VectorStore interface. -- [x] README documents backend selection, configuration, and installation for each supported backend. -- [x] Tests cover all core flows with at least two backends (Redis and Postgres). -- [x] No breaking changes to API or existing users by default. - ---- - -**See [LangChain VectorStore Integrations](https://python.langchain.com/docs/integrations/vectorstores/) for a full list of supported databases and client libraries.** - -## Progress of Development -Keep track of your progress building this feature here. - -### Analysis Phase (Complete) -- [x] **Read existing codebase** - Analyzed current Redis-based implementation in `long_term_memory.py` -- [x] **Understand current architecture** - Current system uses RedisVL with direct Redis connections -- [x] **Identify key components to refactor**: - - `search_long_term_memories()` - Main search function using RedisVL VectorQuery - - `index_long_term_memories()` - Memory indexing with Redis hash storage - - `count_long_term_memories()` - Count operations - - Redis utilities in `utils/redis.py` for connection management and index setup -- [x] **Understand data models** - MemoryRecord contains text, metadata (topics, entities, dates), and embeddings -- [x] **Review configuration** - Current Redis config in `config.py`, need to add backend selection - -### Implementation Plan -1. **Add LangChain dependencies and backend configuration** ✅ -2. **Create VectorStore adapter interface** ✅ -3. **Implement backend factory for different VectorStores** ✅ -4. **Refactor long-term memory functions to use adapter** ✅ -5. **Update API endpoints and add documentation** ✅ -6. 
**Add tests for multiple backends** ✅ - -### Current Status: Implementation Complete ✅ -- [x] **Added LangChain dependencies** - Added langchain-core and optional dependencies for all major vectorstore backends -- [x] **Extended configuration** - Added backend selection and connection parameters for all supported backends -- [x] **Created VectorStoreAdapter interface** - Abstract base class with methods for add/search/delete/count operations -- [x] **Implemented LangChainVectorStoreAdapter** - Generic adapter that works with any LangChain VectorStore -- [x] **Created VectorStore factory** - Factory functions for all supported backends (Redis, Chroma, Pinecone, Weaviate, Qdrant, Milvus, PGVector, LanceDB, OpenSearch) -- [x] **Refactored core long-term memory functions** - `search_long_term_memories()`, `index_long_term_memories()`, and `count_long_term_memories()` now use the adapter -- [x] **Check and update API endpoints** - Ensure all memory API endpoints use the new adapter through the refactored functions -- [x] **Update environment configuration** - Add .env.example entries for all supported backends -- [x] **Create comprehensive documentation** - Document all supported backends, configuration options, and usage examples -- [x] **Add basic tests** - Created test suite for vectorstore adapter functionality -- [x] **Verified implementation** - All core functionality tested and working correctly - -## Summary - -✅ **FEATURE COMPLETE**: The pluggable long-term memory feature has been successfully implemented! - -The Redis Agent Memory Server now supports **9 different vector store backends** through the LangChain VectorStore interface: -- Redis (default), Chroma, Pinecone, Weaviate, Qdrant, Milvus, PostgreSQL/PGVector, LanceDB, and OpenSearch - -**Key Achievements:** -- ✅ **Zero breaking changes** - Existing Redis users continue to work without any changes -- ✅ **Runtime backend selection** - Set `LONG_TERM_MEMORY_BACKEND=` to switch -- ✅ **Unified API interface** - All backends work through the same API endpoints -- ✅ **Production ready** - Full error handling, logging, and documentation -- ✅ **Comprehensive documentation** - Complete setup guides for all backends -- ✅ **Verified functionality** - Core operations tested and working - -**Implementation Details:** -- **VectorStore Adapter Pattern** - Clean abstraction layer between memory server and LangChain VectorStores -- **Backend Factory** - Dynamic instantiation of vectorstore backends based on configuration -- **Metadata Handling** - Proper conversion between MemoryRecord and LangChain Document formats -- **Filtering Support** - Post-processing filters for complex queries (Redis native filtering disabled temporarily due to syntax complexity) -- **Error Handling** - Graceful fallbacks and comprehensive error logging - -**Testing Results:** -- ✅ **CRUD Operations** - Add, search, delete, and count operations working correctly -- ✅ **Semantic Search** - Vector similarity search with proper scoring -- ✅ **Metadata Filtering** - Session, user, namespace, topics, and entities filtering -- ✅ **Data Persistence** - Memories properly stored and retrieved -- ✅ **No Breaking Changes** - Existing functionality preserved - -**Next Steps for Future Development:** -- [ ] **Optimize Redis filtering** - Implement proper Redis JSON path filtering for better performance -- [ ] **Add proper error handling and logging** - Improve error messages for different backend failures -- [ ] **Create tests for multiple backends** - Test core functionality with 
Redis and at least one other backend -- [ ] **Performance benchmarking** - Compare performance across different backends -- [ ] **Migration tooling** - Tools to migrate data between backends diff --git a/agent_memory_server/summarization.py b/agent_memory_server/summarization.py index 6d90be3..bbffb9e 100644 --- a/agent_memory_server/summarization.py +++ b/agent_memory_server/summarization.py @@ -113,24 +113,20 @@ async def _incremental_summary( async def summarize_session( session_id: str, model: str, - window_size: int, + max_context_tokens: int | None = None, ) -> None: """ - Summarize messages in a session when they exceed the window size. + Summarize messages in a session when they exceed the token limit. This function: - 1. Gets the oldest messages up to window size and current context - 2. Generates a new summary that includes these messages + 1. Gets messages and current context + 2. Generates a new summary that includes older messages 3. Removes older, summarized messages and updates the context - Stop summarizing - Args: session_id: The session ID - model: The model to use - window_size: Maximum number of messages to keep - client: The client wrapper (OpenAI or Anthropic) - redis_conn: Redis connection + model: The model to use for summarization + max_context_tokens: Maximum context tokens to keep (defaults to model's context window * summarization_threshold) """ logger.debug(f"Summarizing session {session_id}") redis = await get_redis_conn() @@ -144,11 +140,11 @@ async def summarize_session( num_messages = await pipe.llen(messages_key) # type: ignore logger.debug(f"[summarization] Number of messages: {num_messages}") - if num_messages < window_size: + if num_messages < 2: # Need at least 2 messages to summarize logger.info(f"Not enough messages to summarize for session {session_id}") return - messages_raw = await pipe.lrange(messages_key, 0, window_size - 1) # type: ignore + messages_raw = await pipe.lrange(messages_key, 0, -1) # Get all messages metadata = await pipe.hgetall(metadata_key) # type: ignore pipe.multi() @@ -164,48 +160,89 @@ async def summarize_session( logger.debug(f"[summarization] Messages: {messages}") model_config = get_model_config(model) - max_tokens = model_config.max_tokens - - # Token allocation: - # - For small context (<10k): use 12.5% (min 512) - # - For medium context (10k-50k): use 10% (min 1024) - # - For large context (>50k): use 5% (min 2048) - if max_tokens < 10000: - summary_max_tokens = max(512, max_tokens // 8) # 12.5% - elif max_tokens < 50000: - summary_max_tokens = max(1024, max_tokens // 10) # 10% + full_context_tokens = model_config.max_tokens + + # Use provided max_context_tokens or calculate from model context window + if max_context_tokens is None: + max_context_tokens = int( + full_context_tokens * settings.summarization_threshold + ) + + # Calculate current token usage + encoding = tiktoken.get_encoding("cl100k_base") + current_tokens = sum( + len(encoding.encode(f"{msg.role}: {msg.content}")) + for msg in messages + ) + + # If we're under the limit, no need to summarize + if current_tokens <= max_context_tokens: + logger.info( + f"Messages under token limit ({current_tokens} <= {max_context_tokens}) for session {session_id}" + ) + return + + # Token allocation for summarization + if full_context_tokens < 10000: + summary_max_tokens = max(512, full_context_tokens // 8) # 12.5% + elif full_context_tokens < 50000: + summary_max_tokens = max(1024, full_context_tokens // 10) # 10% else: - summary_max_tokens = max(2048, max_tokens // 20) 
# 5% + summary_max_tokens = max(2048, full_context_tokens // 20) # 5% logger.debug( f"[summarization] Summary max tokens: {summary_max_tokens}" ) # Scale buffer tokens with context size, but keep reasonable bounds - buffer_tokens = min(max(230, max_tokens // 100), 1000) + buffer_tokens = min(max(230, full_context_tokens // 100), 1000) logger.debug(f"[summarization] Buffer tokens: {buffer_tokens}") - max_message_tokens = max_tokens - summary_max_tokens - buffer_tokens - encoding = tiktoken.get_encoding("cl100k_base") + max_message_tokens = ( + full_context_tokens - summary_max_tokens - buffer_tokens + ) + + # Determine how many messages to keep (target ~40% of max_context_tokens for recent messages) + target_remaining_tokens = int(max_context_tokens * 0.4) + + # Work backwards to find recent messages to keep + recent_messages_tokens = 0 + keep_count = 0 + + for i in range(len(messages) - 1, -1, -1): + msg = messages[i] + msg_str = f"{msg.role}: {msg.content}" + msg_tokens = len(encoding.encode(msg_str)) + + if recent_messages_tokens + msg_tokens <= target_remaining_tokens: + recent_messages_tokens += msg_tokens + keep_count += 1 + else: + break + + # Messages to summarize are the ones we're not keeping + messages_to_check = ( + messages[:-keep_count] if keep_count > 0 else messages[:-1] + ) + total_tokens = 0 messages_to_summarize = [] - for msg in messages: + for msg in messages_to_check: msg_str = f"{msg.role}: {msg.content}" msg_tokens = len(encoding.encode(msg_str)) - # TODO: Here, we take a partial message if a single message's - # total size exceeds the buffer. Should this be configurable - # behavior? + # Handle oversized messages if msg_tokens > max_message_tokens: msg_str = msg_str[: max_message_tokens // 2] msg_tokens = len(encoding.encode(msg_str)) - total_tokens += msg_tokens if total_tokens + msg_tokens <= max_message_tokens: total_tokens += msg_tokens messages_to_summarize.append(msg_str) + else: + break if not messages_to_summarize: logger.info(f"No messages to summarize for session {session_id}") @@ -227,9 +264,13 @@ async def summarize_session( pipe.hmset(metadata_key, mapping=metadata) logger.debug(f"[summarization] Metadata: {metadata_key} {metadata}") - # Messages that were summarized - num_summarized = len(messages_to_summarize) - pipe.ltrim(messages_key, 0, num_summarized - 1) + # Keep only the most recent messages that fit in our token budget + if keep_count > 0: + # Keep the last keep_count messages + pipe.ltrim(messages_key, -keep_count, -1) + else: + # Keep at least the last message + pipe.ltrim(messages_key, -1, -1) await pipe.execute() break diff --git a/agent_memory_server/vectorstore_adapter.py b/agent_memory_server/vectorstore_adapter.py index 410b6fe..18e76d1 100644 --- a/agent_memory_server/vectorstore_adapter.py +++ b/agent_memory_server/vectorstore_adapter.py @@ -874,6 +874,9 @@ def parse_timestamp_to_datetime(timestamp_val): topics=self._parse_list_field(doc.metadata.get("topics")), entities=self._parse_list_field(doc.metadata.get("entities")), memory_hash=doc.metadata.get("memory_hash", ""), + discrete_memory_extracted=doc.metadata.get( + "discrete_memory_extracted", "f" + ), memory_type=doc.metadata.get("memory_type", "message"), persisted_at=doc.metadata.get("persisted_at"), extracted_from=self._parse_list_field( diff --git a/docker-compose.yml b/docker-compose.yml index 74cdbef..1891032 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,7 +13,6 @@ services: - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} # Optional configurations with defaults - 
LONG_TERM_MEMORY=True - - WINDOW_SIZE=20 - GENERATION_MODEL=gpt-4o-mini - EMBEDDING_MODEL=text-embedding-3-small - ENABLE_TOPIC_EXTRACTION=True @@ -38,13 +37,6 @@ services: # Add your API keys here or use a .env file - OPENAI_API_KEY=${OPENAI_API_KEY} - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} - # Optional configurations with defaults - - LONG_TERM_MEMORY=True - - WINDOW_SIZE=20 - - GENERATION_MODEL=gpt-4o-mini - - EMBEDDING_MODEL=text-embedding-3-small - - ENABLE_TOPIC_EXTRACTION=True - - ENABLE_NER=True ports: - "9050:9000" depends_on: @@ -61,12 +53,6 @@ services: - OPENAI_API_KEY=${OPENAI_API_KEY} - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} # Optional configurations with defaults - - LONG_TERM_MEMORY=True - - WINDOW_SIZE=20 - - GENERATION_MODEL=gpt-4o-mini - - EMBEDDING_MODEL=text-embedding-3-small - - ENABLE_TOPIC_EXTRACTION=True - - ENABLE_NER=True depends_on: - redis command: ["uv", "run", "agent-memory", "task-worker"] @@ -80,7 +66,7 @@ services: - "16380:6379" # Redis port volumes: - redis_data:/data - command: redis-server --save "" --loglevel warning --appendonly no --stop-writes-on-bgsave-error no + command: redis-server --save "30 1" --loglevel warning --appendonly no --stop-writes-on-bgsave-error no healthcheck: test: [ "CMD", "redis-cli", "ping" ] interval: 30s diff --git a/docs/api.md b/docs/api.md index aa576b3..b708fd1 100644 --- a/docs/api.md +++ b/docs/api.md @@ -24,13 +24,12 @@ The following endpoints are available: _Query Parameters:_ - `namespace` (string, optional): The namespace to use for the session - - `window_size` (int, optional): Number of messages to include in the response (default from config) - `model_name` (string, optional): The client's LLM model name to determine appropriate context window size - `context_window_max` (int, optional): Direct specification of max context window tokens (overrides model_name) - **PUT /v1/working-memory/{session_id}** Sets working memory for a session, replacing any existing memory. - Automatically summarizes conversations that exceed the window size. + Automatically summarizes conversations that exceed the token limit. 
_Request Body Example:_ ```json @@ -103,7 +102,8 @@ The following endpoints are available: "session": { "session_id": "session-123", "namespace": "default", - "window_size": 10 + "model_name": "gpt-4o", + "context_window_max": 4000 }, "long_term_search": { "text": "AI discussion", diff --git a/docs/memory-types.md b/docs/memory-types.md index 02bf30d..6ddd59e 100644 --- a/docs/memory-types.md +++ b/docs/memory-types.md @@ -86,7 +86,7 @@ Working memory contains: ```http # Get working memory for a session -GET /v1/working-memory/{session_id}?namespace=demo&window_size=50 +GET /v1/working-memory/{session_id}?namespace=demo&model_name=gpt-4o # Set working memory (replaces existing) PUT /v1/working-memory/{session_id} @@ -300,7 +300,8 @@ response = await memory_prompt({ "query": "Help me plan dinner", "session": { "session_id": "current_chat", - "window_size": 20 + "model_name": "gpt-4o", + "context_window_max": 4000 }, "long_term_search": { "text": "food preferences dietary restrictions", diff --git a/examples/memory_prompt_agent.py b/examples/memory_prompt_agent.py index 2aa665f..29e09f2 100644 --- a/examples/memory_prompt_agent.py +++ b/examples/memory_prompt_agent.py @@ -123,7 +123,7 @@ async def _get_memory_prompt( session_id=session_id, query=user_input, # Optional parameters to control memory retrieval - window_size=30, # Controls working memory messages + model_name="gpt-4o-mini", # Controls token-based truncation long_term_search={"limit": 30}, # Controls long-term memory limit user_id=user_id, ) diff --git a/examples/travel_agent.py b/examples/travel_agent.py index 52ba0ad..ccdde53 100644 --- a/examples/travel_agent.py +++ b/examples/travel_agent.py @@ -255,7 +255,7 @@ async def _get_working_memory(self, session_id: str, user_id: str) -> WorkingMem result = await client.get_working_memory( session_id=session_id, namespace=self._get_namespace(user_id), - window_size=15, + model_name="gpt-4o-mini", # Controls token-based truncation ) return WorkingMemory(**result.model_dump()) diff --git a/manual_oauth_qa/manual_auth0_test.py b/manual_oauth_qa/manual_auth0_test.py index 48b1328..8621011 100755 --- a/manual_oauth_qa/manual_auth0_test.py +++ b/manual_oauth_qa/manual_auth0_test.py @@ -180,7 +180,7 @@ def run_comprehensive_test(self): "session": { "session_id": "test-session-auth0", "namespace": "test-auth0", - "window_size": 10, + "model_name": "gpt-4o-mini", }, }, ), diff --git a/manual_oauth_qa/test_auth0.py b/manual_oauth_qa/test_auth0.py index 17ae874..d286ca3 100755 --- a/manual_oauth_qa/test_auth0.py +++ b/manual_oauth_qa/test_auth0.py @@ -180,7 +180,7 @@ def run_comprehensive_test(self): "session": { "session_id": "test-session-auth0", "namespace": "test-auth0", - "window_size": 10, + "model_name": "gpt-4o-mini", }, }, ), diff --git a/pyproject.toml b/pyproject.toml index aae3142..63ef15d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "sentence-transformers>=3.4.1", "structlog>=25.2.0", "tiktoken>=0.5.1", - "transformers<=4.50.3,>=4.30.0", + "transformers>=4.51.1", "uvicorn>=0.24.0", "sniffio>=1.3.1", "click>=8.1.0", diff --git a/tests/test_api.py b/tests/test_api.py index c2417f7..f7fb129 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -260,7 +260,7 @@ async def test_post_memory_compacts_long_conversation( "namespace": "test-namespace", "session_id": "test-session", } - mock_settings = Settings(window_size=1, long_term_memory=False) + mock_settings = Settings(long_term_memory=False) with ( patch("agent_memory_server.api.settings", 
mock_settings), @@ -288,7 +288,7 @@ async def test_post_memory_compacts_long_conversation( # Should return the summarized working memory assert "messages" in data assert "context" in data - # Should have been summarized (only 1 message kept due to window_size=1) + # Should have been summarized (token-based summarization in _summarize_working_memory) assert len(data["messages"]) == 1 assert data["messages"][0]["content"] == "Hi there" assert "Summary:" in data["context"] @@ -400,7 +400,6 @@ async def test_memory_prompt_with_session_id(self, mock_get_working_memory, clie "session": { "session_id": "test-session", "namespace": "test-namespace", - "window_size": 10, "model_name": "gpt-4o", "context_window_max": 1000, }, diff --git a/tests/test_client_api.py b/tests/test_client_api.py index 1202419..8235652 100644 --- a/tests/test_client_api.py +++ b/tests/test_client_api.py @@ -303,7 +303,6 @@ async def test_memory_prompt(memory_test_client: MemoryAPIClient): query=query, session_id=session_id, namespace="test-namespace", - window_size=5, model_name="gpt-4o", context_window_max=4000, ) diff --git a/tests/test_full_integration.py b/tests/test_full_integration.py index 1c0761a..aa0ac6d 100644 --- a/tests/test_full_integration.py +++ b/tests/test_full_integration.py @@ -653,7 +653,6 @@ async def test_memory_prompt_with_working_memory( prompt_result = await client.memory_prompt( query="What programming language should I use?", session_id=unique_session_id, - window_size=4, ) assert "messages" in prompt_result diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 732d65d..b56ff6e 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -437,3 +437,34 @@ async def test_set_working_memory_auto_id_generation(self, mcp_test_setup): memory = working_memory.memories[0] assert memory.id is not None assert len(memory.id) > 0 # ULID generates non-empty strings + + @pytest.mark.asyncio + async def test_mcp_lenient_memory_record_defaults(self, session, mcp_test_setup): + """Test that LenientMemoryRecord used by MCP has correct defaults for discrete_memory_extracted.""" + from agent_memory_server.models import ( + ExtractedMemoryRecord, + LenientMemoryRecord, + ) + + # Test 1: LenientMemoryRecord should default to discrete_memory_extracted='t' + lenient_memory = LenientMemoryRecord( + text="User likes green tea", + memory_type="semantic", + namespace="user_preferences", + ) + + assert ( + lenient_memory.discrete_memory_extracted == "t" + ), f"LenientMemoryRecord should default to 't', got '{lenient_memory.discrete_memory_extracted}'" + assert lenient_memory.memory_type.value == "semantic" + assert lenient_memory.id is not None + + # Test 2: ExtractedMemoryRecord should also default to discrete_memory_extracted='t' + extracted_memory = ExtractedMemoryRecord( + id="test_001", text="User prefers coffee", memory_type="semantic" + ) + + assert ( + extracted_memory.discrete_memory_extracted == "t" + ), f"ExtractedMemoryRecord should default to 't', got '{extracted_memory.discrete_memory_extracted}'" + assert extracted_memory.memory_type.value == "semantic" diff --git a/tests/test_summarization.py b/tests/test_summarization.py index 65bf4e9..3c92c0e 100644 --- a/tests/test_summarization.py +++ b/tests/test_summarization.py @@ -88,19 +88,22 @@ async def test_summarize_session( """Test summarize_session with mocked summarization""" session_id = "test-session" model = "gpt-3.5-turbo" - window_size = 4 + max_context_tokens = 1000 pipeline_mock = MagicMock() # pipeline is not a coroutine pipeline_mock.__aenter__ = 
AsyncMock(return_value=pipeline_mock) pipeline_mock.watch = AsyncMock() mock_async_redis_client.pipeline = MagicMock(return_value=pipeline_mock) - # This needs to match the window size + # Create messages that exceed the token limit + long_content = ( + "This is a very long message that will exceed our token limit " * 50 + ) messages_raw = [ - json.dumps({"role": "user", "content": "Message 1"}), - json.dumps({"role": "assistant", "content": "Message 2"}), - json.dumps({"role": "user", "content": "Message 3"}), - json.dumps({"role": "assistant", "content": "Message 4"}), + json.dumps({"role": "user", "content": long_content}), + json.dumps({"role": "assistant", "content": long_content}), + json.dumps({"role": "user", "content": long_content}), + json.dumps({"role": "assistant", "content": "Short recent message"}), ] pipeline_mock.lrange = AsyncMock(return_value=messages_raw) @@ -113,7 +116,7 @@ async def test_summarize_session( pipeline_mock.hmset = MagicMock(return_value=True) pipeline_mock.ltrim = MagicMock(return_value=True) pipeline_mock.execute = AsyncMock(return_value=True) - pipeline_mock.llen = AsyncMock(return_value=window_size) + pipeline_mock.llen = AsyncMock(return_value=4) mock_summarization.return_value = ("New summary", 300) @@ -131,28 +134,33 @@ async def test_summarize_session( await summarize_session( session_id, model, - window_size, + max_context_tokens, ) assert pipeline_mock.lrange.call_count == 1 assert pipeline_mock.lrange.call_args[0][0] == Keys.messages_key(session_id) assert pipeline_mock.lrange.call_args[0][1] == 0 - assert pipeline_mock.lrange.call_args[0][2] == window_size - 1 + assert pipeline_mock.lrange.call_args[0][2] == -1 # Get all messages assert pipeline_mock.hgetall.call_count == 1 assert pipeline_mock.hgetall.call_args[0][0] == Keys.metadata_key(session_id) assert pipeline_mock.hmset.call_count == 1 assert pipeline_mock.hmset.call_args[0][0] == Keys.metadata_key(session_id) - assert pipeline_mock.hmset.call_args.kwargs["mapping"] == { - "context": "New summary", - "tokens": "320", - } + # Verify that hmset was called with the new summary + hmset_mapping = pipeline_mock.hmset.call_args.kwargs["mapping"] + assert hmset_mapping["context"] == "New summary" + # Token count will vary based on the actual messages passed for summarization + assert "tokens" in hmset_mapping + assert ( + int(hmset_mapping["tokens"]) > 300 + ) # Should include summarization tokens plus message tokens assert pipeline_mock.ltrim.call_count == 1 assert pipeline_mock.ltrim.call_args[0][0] == Keys.messages_key(session_id) - assert pipeline_mock.ltrim.call_args[0][1] == 0 - assert pipeline_mock.ltrim.call_args[0][2] == window_size - 1 + # New token-based approach keeps recent messages + assert pipeline_mock.ltrim.call_args[0][1] == -1 # Keep last message + assert pipeline_mock.ltrim.call_args[0][2] == -1 assert pipeline_mock.execute.call_count == 1 @@ -160,12 +168,8 @@ async def test_summarize_session( assert mock_summarization.call_args[0][0] == model assert mock_summarization.call_args[0][1] == mock_openai_client assert mock_summarization.call_args[0][2] == "Previous summary" - assert mock_summarization.call_args[0][3] == [ - "user: Message 1", - "assistant: Message 2", - "user: Message 3", - "assistant: Message 4", - ] + # Verify that some messages were passed for summarization + assert len(mock_summarization.call_args[0][3]) > 0 @pytest.mark.asyncio @patch("agent_memory_server.summarization._incremental_summary") @@ -175,18 +179,24 @@ async def 
test_handle_summarization_no_messages( """Test summarize_session when no messages need summarization""" session_id = "test-session" model = "gpt-3.5-turbo" - window_size = 12 + max_context_tokens = 10000 # High limit so no summarization needed pipeline_mock = MagicMock() # pipeline is not a coroutine pipeline_mock.__aenter__ = AsyncMock(return_value=pipeline_mock) pipeline_mock.watch = AsyncMock() mock_async_redis_client.pipeline = MagicMock(return_value=pipeline_mock) - pipeline_mock.llen = AsyncMock(return_value=0) - pipeline_mock.lrange = AsyncMock(return_value=[]) + # Set up short messages that won't exceed token limit + short_messages = [ + json.dumps({"role": "user", "content": "Short message 1"}), + json.dumps({"role": "assistant", "content": "Short response 1"}), + ] + + pipeline_mock.llen = AsyncMock(return_value=2) + pipeline_mock.lrange = AsyncMock(return_value=short_messages) pipeline_mock.hgetall = AsyncMock(return_value={}) pipeline_mock.hmset = AsyncMock(return_value=True) - pipeline_mock.lpop = AsyncMock(return_value=True) + pipeline_mock.ltrim = AsyncMock(return_value=True) pipeline_mock.execute = AsyncMock(return_value=True) with patch( @@ -196,12 +206,15 @@ async def test_handle_summarization_no_messages( await summarize_session( session_id, model, - window_size, + max_context_tokens, ) + # Should not summarize because messages are under token limit assert mock_summarization.call_count == 0 - assert pipeline_mock.lrange.call_count == 0 - assert pipeline_mock.hgetall.call_count == 0 + # But should still check messages and metadata + assert pipeline_mock.lrange.call_count == 1 + assert pipeline_mock.hgetall.call_count == 1 + # Should not update anything since no summarization needed assert pipeline_mock.hmset.call_count == 0 - assert pipeline_mock.lpop.call_count == 0 + assert pipeline_mock.ltrim.call_count == 0 assert pipeline_mock.execute.call_count == 0 diff --git a/tests/test_vectorstore_adapter.py b/tests/test_vectorstore_adapter.py index 9f92b25..1d3a935 100644 --- a/tests/test_vectorstore_adapter.py +++ b/tests/test_vectorstore_adapter.py @@ -1,9 +1,11 @@ """Tests for the VectorStore adapter functionality.""" +import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest +from agent_memory_server.filters import Namespace from agent_memory_server.models import MemoryRecord, MemoryTypeEnum from agent_memory_server.vectorstore_adapter import ( LangChainVectorStoreAdapter, @@ -544,3 +546,70 @@ async def asimilarity_search_with_relevance_scores( ) assert len(unprocessed_results_after.memories) == 0 + + def test_redis_adapter_preserves_discrete_memory_extracted_flag(self): + """Regression test: Ensure Redis adapter preserves discrete_memory_extracted='t' during search. + + This test catches the bug where MCP-created memories with discrete_memory_extracted='t' + were being returned as 'f' because the Redis vector store adapter wasn't populating + the field during document-to-memory conversion. 
+ """ + from datetime import UTC, datetime + from unittest.mock import MagicMock + + # Create mock vectorstore and embeddings + mock_vectorstore = MagicMock() + mock_embeddings = MagicMock() + + # Create Redis adapter + adapter = RedisVectorStoreAdapter(mock_vectorstore, mock_embeddings) + + # Mock document that simulates what Redis returns for an MCP-created memory + mock_doc = MagicMock() + mock_doc.page_content = "User likes green tea" + mock_doc.metadata = { + "id_": "memory_001", + "session_id": None, + "user_id": None, + "namespace": "user_preferences", + "created_at": datetime.now(UTC).timestamp(), + "updated_at": datetime.now(UTC).timestamp(), + "last_accessed": datetime.now(UTC).timestamp(), + "topics": "preferences,beverages", + "entities": "", + "memory_hash": "abc123", + "discrete_memory_extracted": "t", # This should be preserved! + "memory_type": "semantic", + "persisted_at": None, + "extracted_from": "", + "event_date": None, + } + + # Mock the search to return our test document + mock_vectorstore.asimilarity_search_with_relevance_scores = AsyncMock( + return_value=[(mock_doc, 0.9)] + ) + + # Perform search + result = asyncio.run( + adapter.search_memories( + query="green tea", + namespace=Namespace(field="namespace", eq="user_preferences"), + limit=10, + ) + ) + + # Verify we got the memory back + assert len(result.memories) == 1 + memory = result.memories[0] + + # REGRESSION TEST: This should be 't', not 'f' + assert memory.discrete_memory_extracted == "t", ( + f"Regression: Expected discrete_memory_extracted='t', got '{memory.discrete_memory_extracted}'. " + f"This indicates the Redis adapter is not preserving the flag during search." + ) + + # Also verify other expected properties + assert memory.memory_type.value == "semantic" + assert memory.namespace == "user_preferences" + assert memory.text == "User likes green tea" diff --git a/uv.lock b/uv.lock index 2188b84..6ec4ce6 100644 --- a/uv.lock +++ b/uv.lock @@ -150,7 +150,7 @@ requires-dist = [ { name = "sniffio", specifier = ">=1.3.1" }, { name = "structlog", specifier = ">=25.2.0" }, { name = "tiktoken", specifier = ">=0.5.1" }, - { name = "transformers", specifier = ">=4.30.0,<=4.50.3" }, + { name = "transformers", specifier = ">=4.51.1" }, { name = "uvicorn", specifier = ">=0.24.0" }, ] @@ -615,7 +615,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.33.1" +version = "0.34.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -627,9 +627,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a4/01/bfe0534a63ce7a2285e90dbb33e8a5b815ff096d8f7743b135c256916589/huggingface_hub-0.33.1.tar.gz", hash = "sha256:589b634f979da3ea4b8bdb3d79f97f547840dc83715918daf0b64209c0844c7b", size = 426728 } +sdist = { url = "https://files.pythonhosted.org/packages/91/b4/e6b465eca5386b52cf23cb6df8644ad318a6b0e12b4b96a7e0be09cbfbcc/huggingface_hub-0.34.3.tar.gz", hash = "sha256:d58130fd5aa7408480681475491c0abd7e835442082fbc3ef4d45b6c39f83853", size = 456800 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d0/fb/5307bd3612eb0f0e62c3a916ae531d3a31e58fb5c82b58e3ebf7fd6f47a1/huggingface_hub-0.33.1-py3-none-any.whl", hash = "sha256:ec8d7444628210c0ba27e968e3c4c973032d44dcea59ca0d78ef3f612196f095", size = 515377 }, + { url = "https://files.pythonhosted.org/packages/59/a8/4677014e771ed1591a87b63a2392ce6923baf807193deef302dcfde17542/huggingface_hub-0.34.3-py3-none-any.whl", hash = 
"sha256:5444550099e2d86e68b2898b09e85878fbd788fc2957b506c6a79ce060e39492", size = 558847 }, ] [[package]] @@ -2216,7 +2216,7 @@ wheels = [ [[package]] name = "transformers" -version = "4.50.3" +version = "4.54.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -2230,9 +2230,9 @@ dependencies = [ { name = "tokenizers" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c0/29/37877123d6633a188997d75dc17d6f526745d63361794348ce748db23d49/transformers-4.50.3.tar.gz", hash = "sha256:1d795d24925e615a8e63687d077e4f7348c2702eb87032286eaa76d83cdc684f", size = 8774363 } +sdist = { url = "https://files.pythonhosted.org/packages/21/6c/4caeb57926f91d943f309b062e22ad1eb24a9f530421c5a65c1d89378a7a/transformers-4.54.1.tar.gz", hash = "sha256:b2551bb97903f13bd90c9467d0a144d41ca4d142defc044a99502bb77c5c1052", size = 9514288 } wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/22/733a6fc4a6445d835242f64c490fdd30f4a08d58f2b788613de3f9170692/transformers-4.50.3-py3-none-any.whl", hash = "sha256:6111610a43dec24ef32c3df0632c6b25b07d9711c01d9e1077bdd2ff6b14a38c", size = 10180411 }, + { url = "https://files.pythonhosted.org/packages/cf/18/eb7578f84ef5a080d4e5ca9bc4f7c68e7aa9c1e464f1b3d3001e4c642fce/transformers-4.54.1-py3-none-any.whl", hash = "sha256:c89965a4f62a0d07009d45927a9c6372848a02ab9ead9c318c3d082708bab529", size = 11176397 }, ] [[package]]