diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml
index c56e0257..8db85cf9 100644
--- a/.github/workflows/python-tests.yml
+++ b/.github/workflows/python-tests.yml
@@ -12,11 +12,13 @@ on:
- "main"
- "dev"
- "feat/*"
+ - "test"
pull_request:
branches:
- "main"
- "dev"
- "feat/*"
+ - "test"
jobs:
build:
diff --git a/.gitignore b/.gitignore
index d3d9957f..ae7bdc4d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@ tmp/
# evaluation data
*.csv
*.jsonl
+**/settings.json
evaluation/*tmp/
evaluation/results
evaluation/.env
@@ -19,7 +20,7 @@ evaluation/scripts/personamem
# benchmarks
benchmarks/
-
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -47,6 +48,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
+.run
# PyInstaller
# Usually these files are written by a python script from a template
@@ -210,3 +212,5 @@ cython_debug/
# Outputs and Evaluation Results
outputs
+
+evaluation/data/temporal_locomo
diff --git a/.vscode/settings.json b/.vscode/settings.json
deleted file mode 100644
index 815846ed..00000000
--- a/.vscode/settings.json
+++ /dev/null
@@ -1,9 +0,0 @@
-{
- "python.testing.pytestArgs": [
- "tests",
- "-vv"
- ],
- "python.testing.unittestEnabled": false,
- "python.testing.pytestEnabled": true,
- "python.analysis.typeCheckingMode": "off"
-}
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index d8998b6f..0f680505 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -16,6 +16,9 @@ services:
environment:
- PYTHONPATH=/app/src
- HF_ENDPOINT=https://hf-mirror.com
+ - QDRANT_HOST=qdrant-docker
+ - QDRANT_PORT=6333
+ - NEO4J_URI=bolt://neo4j-docker:7687
volumes:
- ../src:/app/src
- .:/app/docker
@@ -29,7 +32,7 @@ services:
- "7474:7474" # HTTP
- "7687:7687" # Bolt
healthcheck:
- test: wget http://localhost:7687 || exit 1
+ test: wget http://localhost:7474 || exit 1
interval: 1s
timeout: 10s
retries: 20
@@ -44,7 +47,7 @@ services:
- memos_network
qdrant:
- image: qdrant/qdrant:v1.15.0
+ image: qdrant/qdrant:v1.15.3
container_name: qdrant-docker
ports:
- "6333:6333" # REST API
diff --git a/docker/requirements.txt b/docker/requirements.txt
index 211ec3ca..d20c0b36 100644
--- a/docker/requirements.txt
+++ b/docker/requirements.txt
@@ -1,137 +1,160 @@
-annotated-types==0.7.0 ; python_version >= "3.10" and python_version < "4.0"
-anyio==4.9.0 ; python_version >= "3.10" and python_version < "4.0"
-attrs==25.3.0 ; python_version >= "3.10" and python_version < "4.0"
-authlib==1.6.0 ; python_version >= "3.10" and python_version < "4.0"
-beautifulsoup4==4.13.4 ; python_version >= "3.10" and python_version < "4.0"
-certifi==2025.7.14 ; python_version >= "3.10" and python_version < "4.0"
-cffi==1.17.1 ; python_version >= "3.10" and python_version < "4.0" and platform_python_implementation != "PyPy"
-cfgv==3.4.0 ; python_version >= "3.10" and python_version < "4.0"
-charset-normalizer==3.4.2 ; python_version >= "3.10" and python_version < "4.0"
-chonkie==1.1.1 ; python_version >= "3.10" and python_version < "4.0"
-click==8.2.1 ; python_version >= "3.10" and python_version < "4.0"
-cobble==0.1.4 ; python_version >= "3.10" and python_version < "4.0"
-colorama==0.4.6 ; python_version >= "3.10" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32")
-coloredlogs==15.0.1 ; python_version >= "3.10" and python_version < "4.0"
-cryptography==45.0.5 ; python_version >= "3.10" and python_version < "4.0"
-cyclopts==3.22.2 ; python_version >= "3.10" and python_version < "4.0"
-defusedxml==0.7.1 ; python_version >= "3.10" and python_version < "4.0"
-distlib==0.4.0 ; python_version >= "3.10" and python_version < "4.0"
-distro==1.9.0 ; python_version >= "3.10" and python_version < "4.0"
-dnspython==2.7.0 ; python_version >= "3.10" and python_version < "4.0"
-docstring-parser==0.16 ; python_version >= "3.10" and python_version < "4.0"
-docutils==0.21.2 ; python_version >= "3.10" and python_version < "4.0"
-email-validator==2.2.0 ; python_version >= "3.10" and python_version < "4.0"
-et-xmlfile==2.0.0 ; python_version >= "3.10" and python_version < "4.0"
-exceptiongroup==1.3.0 ; python_version >= "3.10" and python_version < "4.0"
-fastapi-cli==0.0.8 ; python_version >= "3.10" and python_version < "4.0"
-fastapi-cloud-cli==0.1.4 ; python_version >= "3.10" and python_version < "4.0"
-fastapi==0.115.14 ; python_version >= "3.10" and python_version < "4.0"
-fastmcp==2.10.5 ; python_version >= "3.10" and python_version < "4.0"
-filelock==3.18.0 ; python_version >= "3.10" and python_version < "4.0"
-flatbuffers==25.2.10 ; python_version >= "3.10" and python_version < "4.0"
-fsspec==2025.7.0 ; python_version >= "3.10" and python_version < "4.0"
-greenlet==3.2.3 ; python_version >= "3.10" and python_version < "3.14" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32")
-h11==0.16.0 ; python_version >= "3.10" and python_version < "4.0"
-hf-xet==1.1.5 ; python_version >= "3.10" and python_version < "4.0" and (platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "arm64" or platform_machine == "aarch64")
-httpcore==1.0.9 ; python_version >= "3.10" and python_version < "4.0"
-httptools==0.6.4 ; python_version >= "3.10" and python_version < "4.0"
-httpx-sse==0.4.1 ; python_version >= "3.10" and python_version < "4.0"
-httpx==0.28.1 ; python_version >= "3.10" and python_version < "4.0"
-huggingface-hub==0.33.4 ; python_version >= "3.10" and python_version < "4.0"
-humanfriendly==10.0 ; python_version >= "3.10" and python_version < "4.0"
-identify==2.6.12 ; python_version >= "3.10" and python_version < "4.0"
-idna==3.10 ; python_version >= "3.10" and python_version < "4.0"
-iniconfig==2.1.0 ; python_version >= "3.10" and python_version < "4.0"
-itsdangerous==2.2.0 ; python_version >= "3.10" and python_version < "4.0"
-jinja2==3.1.6 ; python_version >= "3.10" and python_version < "4.0"
-jiter==0.10.0 ; python_version >= "3.10" and python_version < "4.0"
-joblib==1.5.1 ; python_version >= "3.10" and python_version < "4.0"
-jsonschema-specifications==2025.4.1 ; python_version >= "3.10" and python_version < "4.0"
-jsonschema==4.24.1 ; python_version >= "3.10" and python_version < "4.0"
-lxml==6.0.0 ; python_version >= "3.10" and python_version < "4.0"
-magika==0.6.2 ; python_version >= "3.10" and python_version < "4.0"
-mammoth==1.9.1 ; python_version >= "3.10" and python_version < "4.0"
-markdown-it-py==3.0.0 ; python_version >= "3.10" and python_version < "4.0"
-markdownify==1.1.0 ; python_version >= "3.10" and python_version < "4.0"
-markitdown==0.1.2 ; python_version >= "3.10" and python_version < "4.0"
-markupsafe==3.0.2 ; python_version >= "3.10" and python_version < "4.0"
-mcp==1.12.0 ; python_version >= "3.10" and python_version < "4.0"
-mdurl==0.1.2 ; python_version >= "3.10" and python_version < "4.0"
-mpmath==1.3.0 ; python_version >= "3.10" and python_version < "4.0"
-neo4j==5.28.1 ; python_version >= "3.10" and python_version < "4.0"
-nodeenv==1.9.1 ; python_version >= "3.10" and python_version < "4.0"
-numpy==2.2.6 ; python_version == "3.10"
-numpy==2.3.1 ; python_version >= "3.11" and python_version < "4.0"
-ollama==0.4.9 ; python_version >= "3.10" and python_version < "4.0"
-onnxruntime==1.22.1 ; python_version >= "3.10" and python_version < "4.0"
-openai==1.97.0 ; python_version >= "3.10" and python_version < "4.0"
-openapi-pydantic==0.5.1 ; python_version >= "3.10" and python_version < "4.0"
-openpyxl==3.1.5 ; python_version >= "3.10" and python_version < "4.0"
-orjson==3.11.0 ; python_version >= "3.10" and python_version < "4.0"
-packaging==25.0 ; python_version >= "3.10" and python_version < "4.0"
-pandas==2.3.1 ; python_version >= "3.10" and python_version < "4.0"
-pdfminer-six==20250506 ; python_version >= "3.10" and python_version < "4.0"
-pillow==11.3.0 ; python_version >= "3.10" and python_version < "4.0"
-platformdirs==4.3.8 ; python_version >= "3.10" and python_version < "4.0"
-pluggy==1.6.0 ; python_version >= "3.10" and python_version < "4.0"
-pre-commit==4.2.0 ; python_version >= "3.10" and python_version < "4.0"
-protobuf==6.31.1 ; python_version >= "3.10" and python_version < "4.0"
-pycparser==2.22 ; python_version >= "3.10" and python_version < "4.0" and platform_python_implementation != "PyPy"
-pydantic-core==2.33.2 ; python_version >= "3.10" and python_version < "4.0"
-pydantic-extra-types==2.10.5 ; python_version >= "3.10" and python_version < "4.0"
-pydantic-settings==2.10.1 ; python_version >= "3.10" and python_version < "4.0"
-pydantic==2.11.7 ; python_version >= "3.10" and python_version < "4.0"
-pygments==2.19.2 ; python_version >= "3.10" and python_version < "4.0"
-pyperclip==1.9.0 ; python_version >= "3.10" and python_version < "4.0"
-pyreadline3==3.5.4 ; python_version >= "3.10" and python_version < "4.0" and sys_platform == "win32"
-pytest-asyncio==0.23.8 ; python_version >= "3.10" and python_version < "4.0"
-pytest==8.4.1 ; python_version >= "3.10" and python_version < "4.0"
-python-dateutil==2.9.0.post0 ; python_version >= "3.10" and python_version < "4.0"
-python-dotenv==1.1.1 ; python_version >= "3.10" and python_version < "4.0"
-python-multipart==0.0.20 ; python_version >= "3.10" and python_version < "4.0"
-python-pptx==1.0.2 ; python_version >= "3.10" and python_version < "4.0"
-pytz==2025.2 ; python_version >= "3.10" and python_version < "4.0"
-pywin32==311 ; python_version >= "3.10" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32")
-pyyaml==6.0.2 ; python_version >= "3.10" and python_version < "4.0"
-referencing==0.36.2 ; python_version >= "3.10" and python_version < "4.0"
-regex==2024.11.6 ; python_version >= "3.10" and python_version < "4.0"
-requests==2.32.4 ; python_version >= "3.10" and python_version < "4.0"
-rich-rst==1.3.1 ; python_version >= "3.10" and python_version < "4.0"
-rich-toolkit==0.14.8 ; python_version >= "3.10" and python_version < "4.0"
-rich==14.0.0 ; python_version >= "3.10" and python_version < "4.0"
-rignore==0.6.2 ; python_version >= "3.10" and python_version < "4.0"
-rpds-py==0.26.0 ; python_version >= "3.10" and python_version < "4.0"
-ruff==0.11.13 ; python_version >= "3.10" and python_version < "4.0"
-safetensors==0.5.3 ; python_version >= "3.10" and python_version < "4.0"
-schedule==1.2.2 ; python_version >= "3.10" and python_version < "4.0"
-scikit-learn==1.7.0 ; python_version >= "3.10" and python_version < "4.0"
-scipy==1.15.3 ; python_version == "3.10"
-scipy==1.16.0 ; python_version >= "3.11" and python_version < "4.0"
-sentry-sdk==2.33.0 ; python_version >= "3.10" and python_version < "4.0"
-shellingham==1.5.4 ; python_version >= "3.10" and python_version < "4.0"
-six==1.17.0 ; python_version >= "3.10" and python_version < "4.0"
-sniffio==1.3.1 ; python_version >= "3.10" and python_version < "4.0"
-soupsieve==2.7 ; python_version >= "3.10" and python_version < "4.0"
-sqlalchemy==2.0.41 ; python_version >= "3.10" and python_version < "4.0"
-sse-starlette==2.4.1 ; python_version >= "3.10" and python_version < "4.0"
-starlette==0.46.2 ; python_version >= "3.10" and python_version < "4.0"
-sympy==1.14.0 ; python_version >= "3.10" and python_version < "4.0"
-tenacity==9.1.2 ; python_version >= "3.10" and python_version < "4.0"
-threadpoolctl==3.6.0 ; python_version >= "3.10" and python_version < "4.0"
-tokenizers==0.21.2 ; python_version >= "3.10" and python_version < "4.0"
-tomli==2.2.1 ; python_version == "3.10"
-tqdm==4.67.1 ; python_version >= "3.10" and python_version < "4.0"
-transformers==4.53.2 ; python_version >= "3.10" and python_version < "4.0"
-typer==0.16.0 ; python_version >= "3.10" and python_version < "4.0"
-typing-extensions==4.14.1 ; python_version >= "3.10" and python_version < "4.0"
-typing-inspection==0.4.1 ; python_version >= "3.10" and python_version < "4.0"
-tzdata==2025.2 ; python_version >= "3.10" and python_version < "4.0"
-ujson==5.10.0 ; python_version >= "3.10" and python_version < "4.0"
-urllib3==2.5.0 ; python_version >= "3.10" and python_version < "4.0"
-uvicorn==0.35.0 ; python_version >= "3.10" and python_version < "4.0"
-uvloop==0.21.0 ; python_version >= "3.10" and python_version < "4.0" and platform_python_implementation != "PyPy" and sys_platform != "win32" and sys_platform != "cygwin"
-virtualenv==20.31.2 ; python_version >= "3.10" and python_version < "4.0"
-watchfiles==1.1.0 ; python_version >= "3.10" and python_version < "4.0"
-websockets==15.0.1 ; python_version >= "3.10" and python_version < "4.0"
-xlrd==2.0.2 ; python_version >= "3.10" and python_version < "4.0"
-xlsxwriter==3.2.5 ; python_version >= "3.10" and python_version < "4.0"
+# Docker-optimized requirements - core dependencies only
+# Excludes Windows-specific and heavy GPU packages for faster builds
+
+annotated-types==0.7.0
+anyio==4.9.0
+async-timeout==5.0.1
+attrs==25.3.0
+authlib==1.6.0
+beautifulsoup4==4.13.4
+certifi==2025.7.14
+cffi==1.17.1
+charset-normalizer==3.4.2
+chonkie==1.1.1
+click==8.2.1
+cobble==0.1.4
+colorama==0.4.6
+coloredlogs==15.0.1
+cryptography==45.0.5
+cyclopts==3.22.2
+defusedxml==0.7.1
+distro==1.9.0
+dnspython==2.7.0
+docstring-parser==0.16
+docutils==0.21.2
+email-validator==2.2.0
+et-xmlfile==2.0.0
+exceptiongroup==1.3.0
+fastapi-cli==0.0.8
+fastapi-cloud-cli==0.1.4
+fastapi==0.115.14
+fastmcp==2.10.5
+filelock==3.18.0
+flatbuffers==25.2.10
+fsspec==2025.7.0
+greenlet==3.2.3
+grpcio==1.73.1
+h11==0.16.0
+h2==4.2.0
+hf-xet==1.1.5
+hpack==4.1.0
+httpcore==1.0.9
+httptools==0.6.4
+httpx-sse==0.4.1
+httpx==0.28.1
+huggingface-hub==0.33.4
+humanfriendly==10.0
+hyperframe==6.1.0
+idna==3.10
+itsdangerous==2.2.0
+jinja2==3.1.6
+jiter==0.10.0
+joblib==1.5.1
+jsonschema-specifications==2025.4.1
+jsonschema==4.24.1
+lxml==6.0.0
+magika==0.6.2
+mammoth==1.9.1
+markdown-it-py==3.0.0
+markdownify==1.1.0
+markitdown==0.1.2
+markupsafe==3.0.2
+mcp==1.12.0
+mdurl==0.1.2
+mpmath==1.3.0
+neo4j==5.28.1
+networkx==3.5
+numpy==2.3.1
+# NVIDIA CUDA packages excluded for lighter Docker images
+# If GPU support is needed, uncomment relevant packages below:
+# nvidia-cublas-cu12==12.6.4.1
+# nvidia-cuda-cupti-cu12==12.6.80
+# nvidia-cuda-nvrtc-cu12==12.6.77
+# nvidia-cuda-runtime-cu12==12.6.77
+# nvidia-cudnn-cu12==9.5.1.17
+# nvidia-cufft-cu12==11.3.0.4
+# nvidia-cufile-cu12==1.11.1.6
+# nvidia-curand-cu12==10.3.7.77
+# nvidia-cusolver-cu12==11.7.1.2
+# nvidia-cusparse-cu12==12.5.4.2
+# nvidia-cusparselt-cu12==0.6.3
+# nvidia-nccl-cu12==2.26.2
+# nvidia-nvjitlink-cu12==12.6.85
+# nvidia-nvtx-cu12==12.6.77
+ollama==0.4.9
+onnxruntime==1.22.1
+openai==1.97.0
+openapi-pydantic==0.5.1
+openpyxl==3.1.5
+orjson==3.11.0
+packaging==25.0
+pandas==2.3.1
+pdfminer-six==20250506
+pika==1.3.2
+pillow==11.3.0
+portalocker==2.10.1
+protobuf==6.31.1
+pycparser==2.22
+pydantic-core==2.33.2
+pydantic-extra-types==2.10.5
+pydantic-settings==2.10.1
+pydantic==2.11.7
+pygments==2.19.2
+pymysql==1.1.1
+pyperclip==1.9.0
+# Windows-specific packages excluded:
+# pyreadline3==3.5.4 # Windows only
+# pywin32==311 # Windows only
+python-dateutil==2.9.0.post0
+python-dotenv==1.1.1
+python-multipart==0.0.20
+python-pptx==1.0.2
+pytz==2025.2
+pyyaml==6.0.2
+qdrant-client==1.14.3
+redis==6.2.0
+referencing==0.36.2
+regex==2024.11.6
+requests==2.32.4
+rich-rst==1.3.1
+rich-toolkit==0.14.8
+rich==14.0.0
+rignore==0.6.2
+rpds-py==0.26.0
+safetensors==0.5.3
+schedule==1.2.2
+scikit-learn==1.7.0
+scipy==1.16.0
+sentence-transformers==4.1.0
+sentry-sdk==2.33.0
+setuptools==80.9.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+soupsieve==2.7
+sqlalchemy==2.0.41
+sse-starlette==2.4.1
+starlette==0.46.2
+sympy==1.14.0
+tenacity==9.1.2
+threadpoolctl==3.6.0
+tokenizers==0.21.2
+# Torch excluded for lighter Docker images (very large package ~2GB)
+# If needed for ML/AI features, uncomment:
+# torch==2.7.1
+# triton==3.3.1
+tqdm==4.67.1
+transformers==4.53.2
+typer==0.16.0
+typing-extensions==4.14.1
+typing-inspection==0.4.1
+tzdata==2025.2
+ujson==5.10.0
+urllib3==2.5.0
+uvicorn==0.35.0
+uvloop==0.21.0
+volcengine-python-sdk==4.0.6
+watchfiles==1.1.0
+websockets==15.0.1
+xlrd==2.0.2
+xlsxwriter==3.2.5
\ No newline at end of file
diff --git a/evaluation/scripts/temporal_locomo/__init__.py b/evaluation/scripts/temporal_locomo/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/evaluation/scripts/temporal_locomo/locomo_eval.py b/evaluation/scripts/temporal_locomo/locomo_eval.py
new file mode 100644
index 00000000..f19e5b68
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/locomo_eval.py
@@ -0,0 +1,417 @@
+import argparse
+import asyncio
+import json
+import os
+import time
+
+import nltk
+import numpy as np
+
+from bert_score import score as bert_score
+from dotenv import load_dotenv
+from modules.locomo_eval_module import LocomoEvalModelModules
+from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
+from nltk.translate.meteor_score import meteor_score
+from openai import AsyncOpenAI
+from pydantic import BaseModel, Field
+from rouge_score import rouge_scorer
+from scipy.spatial.distance import cosine
+from sentence_transformers import SentenceTransformer
+from tqdm import tqdm
+
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+# Download necessary NLTK resources
+try:
+ nltk.download("wordnet", quiet=True)
+ nltk.download("punkt", quiet=True)
+ print("NLTK resources downloaded successfully.")
+except Exception as e:
+ print(f"Warning: Failed to download NLTK resources: {e}")
+
+
+try:
+ sentence_model_name = "Qwen/Qwen3-Embedding-0.6B"
+ sentence_model = SentenceTransformer(sentence_model_name)
+ print(f"SentenceTransformer model : {sentence_model_name} loaded successfully.")
+except Exception as e:
+ print(f"Failed to load SentenceTransformer model: {e}")
+ sentence_model = None
+
+
+class LLMGrade(BaseModel):
+ llm_judgment: str = Field(description="CORRECT or WRONG")
+ llm_reasoning: str = Field(description="Explain why the answer is correct or incorrect.")
+
+
+async def locomo_grader(llm_client, question: str, gold_answer: str, response: str) -> bool:
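+ """Ask the judge LLM (gpt-4o-mini) whether the generated answer matches the gold answer; returns True for CORRECT."""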
+ system_prompt = """
+ You are an expert grader that determines if answers to questions match a gold standard answer
+ """
+
+ accuracy_prompt = f"""
+ Your task is to label an answer to a question as ’CORRECT’ or ’WRONG’. You will be given the following data:
+ (1) a question (posed by one user to another user),
+ (2) a ’gold’ (ground truth) answer,
+ (3) a generated answer
+ which you will score as CORRECT/WRONG.
+
+ The point of the question is to ask about something one user should know about the other user based on their prior conversations.
+ The gold answer will usually be a concise and short answer that includes the referenced topic, for example:
+ Question: Do you remember what I got the last time I went to Hawaii?
+ Gold answer: A shell necklace
+ The generated answer might be much longer, but you should be generous with your grading - as long as it touches on the same topic as the gold answer, it should be counted as CORRECT.
+
+ For time related questions, the gold answer will be a specific date, month, year, etc. The generated answer might be much longer or use relative time references (like "last Tuesday" or "next month"), but you should be generous with your grading - as long as it refers to the same date or time period as the gold answer, it should be counted as CORRECT. Even if the format differs (e.g., "May 7th" vs "7 May"), consider it CORRECT if it's the same date.
+
+ Now it’s time for the real question:
+ Question: {question}
+ Gold answer: {gold_answer}
+ Generated answer: {response}
+
+ First, provide a short (one sentence) explanation of your reasoning, then finish with CORRECT or WRONG.
+ Do NOT include both CORRECT and WRONG in your response, or it will break the evaluation script.
+
+ Just return the label CORRECT or WRONG in a json format with the key as "label".
+ """
+
+ response = await llm_client.chat.completions.create(
+ model="gpt-4o-mini",
+ messages=[
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": accuracy_prompt},
+ ],
+ temperature=0,
+ )
+ message_content = response.choices[0].message.content
+ label = json.loads(message_content)["label"]
+ parsed = LLMGrade(llm_judgment=label, llm_reasoning="")
+
+ return parsed.llm_judgment.strip().lower() == "correct"
+
+
+def calculate_rouge_scores(gold_answer, response):
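+ """Compute ROUGE-1/2/L F-measures between the gold answer and the response; returns zeros on failure."""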
+ metrics = {"rouge1_f": 0.0, "rouge2_f": 0.0, "rougeL_f": 0.0}
+ try:
+ scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
+ rouge_scores = scorer.score(gold_answer, response)
+ metrics["rouge1_f"] = rouge_scores["rouge1"].fmeasure
+ metrics["rouge2_f"] = rouge_scores["rouge2"].fmeasure
+ metrics["rougeL_f"] = rouge_scores["rougeL"].fmeasure
+ except Exception as e:
+ print(f"Failed to calculate ROUGE scores: {e}")
+ return metrics
+
+
+def calculate_bleu_scores(gold_tokens, response_tokens):
+ metrics = {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0}
+
+ try:
+ smoothing = SmoothingFunction().method1
+ weights = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.33, 0.33, 0.33, 0), (0.25, 0.25, 0.25, 0.25)]
+
+ for i, weight in enumerate(weights, 1):
+ metrics[f"bleu{i}"] = sentence_bleu(
+ [gold_tokens], response_tokens, weights=weight, smoothing_function=smoothing
+ )
+ except ZeroDivisionError:
+ pass
+ except Exception as e:
+ print(f"Failed to calculate BLEU scores: {e}")
+
+ return metrics
+
+
+def calculate_meteor_score(gold_tokens, response_tokens):
+ try:
+ return meteor_score([gold_tokens], response_tokens)
+ except Exception as e:
+ print(f"Failed to calculate METEOR score: {e}")
+ return 0.0
+
+
+def calculate_semantic_similarity(gold_answer, response):
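+ """Cosine similarity between sentence-embedding vectors of the gold answer and the response (0.0 on failure)."""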
+ global sentence_model
+
+ try:
+ if sentence_model is None:
+ sentence_model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")
+
+ gold_embedding = sentence_model.encode([gold_answer], show_progress_bar=False)[0]
+ response_embedding = sentence_model.encode([response], show_progress_bar=False)[0]
+ return 1 - cosine(gold_embedding, response_embedding)
+ except Exception as e:
+ print(f"Failed to calculate semantic similarity: {e}")
+ return 0.0
+
+
+def calculate_f1_score(gold_tokens, response_tokens):
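+ """Token-set F1: precision and recall over the unique tokens shared by the gold answer and the response."""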
+ try:
+ gold_set = set(gold_tokens)
+ response_set = set(response_tokens)
+
+ if len(gold_set) == 0 or len(response_set) == 0:
+ return 0.0
+
+ precision = len(gold_set.intersection(response_set)) / len(response_set)
+ recall = len(gold_set.intersection(response_set)) / len(gold_set)
+
+ if precision + recall > 0:
+ return 2 * precision * recall / (precision + recall)
+ return 0.0
+ except Exception as e:
+ print(f"Failed to calculate F1 score: {e}")
+ return 0.0
+
+
+def calculate_nlp_metrics(gold_answer, response, context, options=None):
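+ """Aggregate lexical (F1, ROUGE, BLEU, METEOR) and semantic (embedding similarity, BERTScore) metrics for one QA pair, plus the token count of the retrieved context."""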
+ if options is None:
+ options = ["lexical", "semantic"]
+
+ gold_answer = str(gold_answer) if gold_answer is not None else ""
+ response = str(response) if response is not None else ""
+
+ metrics = {"context_tokens": len(nltk.word_tokenize(context)) if context else 0}
+
+ if "lexical" in options:
+ gold_tokens = nltk.word_tokenize(gold_answer.lower())
+ response_tokens = nltk.word_tokenize(response.lower())
+
+ metrics["lexical"] = {}
+ metrics["lexical"]["f1"] = calculate_f1_score(gold_tokens, response_tokens)
+ metrics["lexical"].update(calculate_rouge_scores(gold_answer, response))
+ metrics["lexical"].update(calculate_bleu_scores(gold_tokens, response_tokens))
+ metrics["lexical"]["meteor"] = calculate_meteor_score(gold_tokens, response_tokens)
+
+ if "semantic" in options:
+ metrics["semantic"] = {}
+ metrics["semantic"]["similarity"] = calculate_semantic_similarity(gold_answer, response)
+ _, _, f1 = bert_score(
+ [gold_answer], [response], lang="en", rescale_with_baseline=True, verbose=False
+ )
+ metrics["semantic"]["bert_f1"] = f1.item() if f1 is not None else 0.0
+
+ return metrics
+
+
+def convert_numpy_types(obj):
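+ """Recursively convert numpy scalars inside dicts/lists to plain Python floats so the result is JSON-serializable."""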
+ if isinstance(obj, np.number):
+ return float(obj)
+ elif isinstance(obj, dict):
+ return {k: convert_numpy_types(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [convert_numpy_types(i) for i in obj]
+ else:
+ return obj
+
+
+async def process_group_responses(
+ group_id, group_responses, oai_client, evaluation_options, num_runs: int
+):
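+ """Grade every response in one conversation group: run the LLM judge num_runs times per question and attach NLP metrics."""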
+ graded_responses = []
+
+ # Process responses with asyncio for concurrent API calls
+ for response in tqdm(group_responses, desc=f"Processing group {group_id}"):
+ question = response.get("question")
+ answer = response.get("answer")
+ ground_truth = response.get("golden_answer")
+ category = response.get("category")
+
+ context = response.get("search_context", "")
+ response_duration_ms = response.get("response_duration_ms", 0.0)
+ search_duration_ms = response.get("search_duration_ms", 0.0)
+
+ if ground_truth is None:
+ continue
+
+ grading_tasks = [
+ locomo_grader(oai_client, question, ground_truth, answer) for _ in range(num_runs)
+ ]
+ judgments = await asyncio.gather(*grading_tasks)
+ judgments_dict = {f"judgment_{i + 1}": j for i, j in enumerate(judgments)}
+
+ nlp_metrics = calculate_nlp_metrics(ground_truth, answer, context, evaluation_options)
+
+ graded_response = {
+ "question": question,
+ "answer": answer,
+ "golden_answer": ground_truth,
+ "category": category,
+ "llm_judgments": judgments_dict,
+ "nlp_metrics": nlp_metrics,
+ "response_duration_ms": response_duration_ms,
+ "search_duration_ms": search_duration_ms,
+ "total_duration_ms": response_duration_ms + search_duration_ms,
+ }
+ graded_responses.append(graded_response)
+
+ return group_id, graded_responses
+
+
+async def process_single_group(group_id, group_responses, oai_client, evaluation_options, num_runs):
+ try:
+ start_time = time.time()
+ result = await process_group_responses(
+ group_id, group_responses, oai_client, evaluation_options, num_runs
+ )
+ end_time = time.time()
+ elapsed_time = round(end_time - start_time, 2)
+ print(f"Group {group_id} processed in {elapsed_time} seconds")
+ return result
+ except Exception as e:
+ logger.error(f"Error processing group {group_id}: {e}", exc_info=True)
+ return group_id, []
+
+
+class LocomoEvaluator(LocomoEvalModelModules):
+ def __init__(self, args):
+ # Initialize base class to populate self.frame, self.version, etc.
+ super().__init__(args=args)
+
+ self.evaluation_options = getattr(args, "evaluation_options", ["lexical", "semantic"])
+ self.num_runs = getattr(args, "num_runs", 1)
+ self.max_workers = getattr(args, "workers", 4)
+
+ load_dotenv()
+ self.oai_client = AsyncOpenAI(
+ api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL")
+ )
+
+ async def run(self):
+ print(
+ f"\n=== Starting LoCoMo evaluation for {self.frame} (version: {self.version}) with {self.num_runs} run(s) per question ==="
+ )
+ print(f"Using {self.max_workers} concurrent workers for processing groups")
+
+ with open(self.response_path) as file:
+ locomo_responses = json.load(file)
+
+ num_users = 10
+ all_grades = {}
+
+ total_responses_count = sum(
+ len(locomo_responses.get(f"locomo_exp_user_{i}", [])) for i in range(num_users)
+ )
+ print(f"Found {total_responses_count} total responses across {num_users} users to evaluate")
+
+ # Create tasks for processing each group
+ tasks = []
+ active_users = 0
+ for group_idx in range(num_users):
+ group_id = f"locomo_exp_user_{group_idx}"
+ group_responses = locomo_responses.get(group_id, [])
+ if not group_responses:
+ print(f"No responses found for group {group_id}")
+ continue
+
+ active_users += 1
+ tasks.append(
+ process_single_group(
+ group_id=group_id,
+ group_responses=group_responses,
+ oai_client=self.oai_client,
+ evaluation_options=self.evaluation_options,
+ num_runs=self.num_runs,
+ )
+ )
+
+ print(f"Starting evaluation of {active_users} user groups with responses")
+
+ semaphore = asyncio.Semaphore(self.max_workers)
+
+ async def limited_task(task):
+ async with semaphore:
+ return await task
+
+ limited_tasks = [limited_task(task) for task in tasks]
+ group_results = await asyncio.gather(*limited_tasks)
+
+ for group_id, graded_responses in group_results:
+ all_grades[group_id] = graded_responses
+
+ print("\n=== Evaluation Complete: Calculating final scores ===")
+
+ run_scores = []
+ evaluated_count = 0
+ if self.num_runs > 0:
+ for i in range(1, self.num_runs + 1):
+ judgment_key = f"judgment_{i}"
+ current_run_correct_count = 0
+ current_run_total_count = 0
+ for group in all_grades.values():
+ for response in group:
+ if judgment_key in response["llm_judgments"]:
+ if response["llm_judgments"][judgment_key]:
+ current_run_correct_count += 1
+ current_run_total_count += 1
+
+ if current_run_total_count > 0:
+ run_accuracy = current_run_correct_count / current_run_total_count
+ run_scores.append(run_accuracy)
+
+ evaluated_count = current_run_total_count
+
+ if evaluated_count > 0:
+ mean_of_scores = np.mean(run_scores)
+ std_of_scores = np.std(run_scores)
+ print(f"LLM-as-a-Judge Mean Score: {mean_of_scores:.4f}")
+ print(f"LLM-as-a-Judge Standard Deviation: {std_of_scores:.4f}")
+ print(
+ f"(Calculated from {self.num_runs} separate runs over {evaluated_count} questions)"
+ )
+ print(f"Individual run scores: {[round(s, 4) for s in run_scores]}")
+ else:
+ print("No responses were evaluated")
+ print("LLM-as-a-Judge score: N/A (0/0)")
+
+ all_grades = convert_numpy_types(all_grades)
+ with open(self.judged_path, "w") as f:
+ json.dump(all_grades, f, indent=2)
+ print(f"Saved detailed evaluation results to {self.judged_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lib",
+ type=str,
+ default="memos_scheduler",
+ choices=["zep", "memos", "memos_scheduler", "mem0", "mem0_graph", "langmem", "openai"],
+ help="Specify the memory framework (zep or memos or mem0 or mem0_graph)",
+ )
+ parser.add_argument(
+ "--version",
+ type=str,
+ default="v1.0.1",
+ help="Version identifier for loading results (e.g., 1010)",
+ )
+ parser.add_argument(
+ "--num_runs",
+ type=int,
+ default=3,
+ help="Number of times to run the LLM grader for each question",
+ )
+ parser.add_argument("--evaluation_options", nargs="+", default=["lexical", "semantic"])
+ parser.add_argument(
+ "--workers", type=int, default=10, help="Number of concurrent workers for processing groups"
+ )
+ cli_args = parser.parse_args()
+
+ # Build args for evaluator
+ class Args:
+ def __init__(self, cli_args):
+ self.frame = cli_args.lib
+ self.version = cli_args.version
+ self.workers = cli_args.workers
+ self.num_runs = cli_args.num_runs
+ self.evaluation_options = cli_args.evaluation_options
+ self.top_k = 20
+ self.scheduler_flag = True
+
+ args = Args(cli_args)
+ evaluator = LocomoEvaluator(args=args)
+ asyncio.run(evaluator.run())
diff --git a/evaluation/scripts/temporal_locomo/locomo_ingestion.py b/evaluation/scripts/temporal_locomo/locomo_ingestion.py
new file mode 100644
index 00000000..321302cf
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/locomo_ingestion.py
@@ -0,0 +1,303 @@
+import concurrent.futures
+import sys
+import time
+import traceback
+
+from datetime import datetime, timezone
+from pathlib import Path
+
+from modules.constants import (
+ MEM0_GRAPH_MODEL,
+ MEM0_MODEL,
+ MEMOS_MODEL,
+ MEMOS_SCHEDULER_MODEL,
+ ZEP_MODEL,
+)
+from modules.locomo_eval_module import LocomoEvalModelModules
+from tqdm import tqdm
+
+from memos.log import get_logger
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent
+sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
+
+logger = get_logger(__name__)
+
+
+class LocomoIngestor(LocomoEvalModelModules):
+ def __init__(self, args):
+ super().__init__(args=args)
+
+ def ingest_session(self, client, session, frame, metadata, revised_client=None):
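+ """Write one conversation session into the target memory backend (zep, memos variants, or mem0 variants); returns elapsed seconds."""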
+ session_date = metadata["session_date"]
+ date_format = "%I:%M %p on %d %B, %Y UTC"
+ date_string = datetime.strptime(session_date, date_format).replace(tzinfo=timezone.utc)
+ iso_date = date_string.isoformat()
+ conv_id = metadata["conv_id"]
+ conv_id = "locomo_exp_user_" + str(conv_id)
+ dt = datetime.fromisoformat(iso_date)
+ timestamp = int(dt.timestamp())
+ print(f"Processing conv {conv_id}, session {metadata['session_key']}")
+ start_time = time.time()
+ print_once = True # Print example only once per session
+
+ if frame == ZEP_MODEL:
+ for chat in tqdm(session, desc=f"{metadata['session_key']}"):
+ data = chat.get("speaker") + ": " + chat.get("text")
+
+ # Print example only once per session
+ if print_once:
+ print({"context": data, "conv_id": conv_id, "created_at": iso_date})
+ print_once = False
+
+ # Check if the group exists, if not create it
+ groups = client.group.get_all_groups()
+ groups = dict(groups)["groups"]
+ exist_ids = [gp.group_id for gp in groups]
+ if conv_id not in exist_ids:
+ client.group.add(group_id=conv_id)
+
+ # Add the message to the group
+ client.graph.add(
+ data=data,
+ type="message",
+ created_at=iso_date,
+ group_id=conv_id,
+ )
+
+ elif frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
+ messages = []
+ messages_reverse = []
+
+ for chat in tqdm(session, desc=f"{metadata['session_key']}"):
+ data = chat.get("speaker") + ": " + chat.get("text")
+
+ if chat.get("speaker") == metadata["speaker_a"]:
+ messages.append({"role": "user", "content": data, "chat_time": iso_date})
+ messages_reverse.append(
+ {"role": "assistant", "content": data, "chat_time": iso_date}
+ )
+ elif chat.get("speaker") == metadata["speaker_b"]:
+ messages.append({"role": "assistant", "content": data, "chat_time": iso_date})
+ messages_reverse.append(
+ {"role": "user", "content": data, "chat_time": iso_date}
+ )
+ else:
+ raise ValueError(
+ f"Unknown speaker {chat.get('speaker')} in session {metadata['session_key']}"
+ )
+
+ # Print example only once per session
+ if print_once:
+ print({"context": data, "conv_id": conv_id, "created_at": iso_date})
+ print_once = False
+
+ speaker_a_user_id = conv_id + "_speaker_a"
+ speaker_b_user_id = conv_id + "_speaker_b"
+
+ client.add(
+ messages=messages,
+ user_id=speaker_a_user_id,
+ )
+
+ revised_client.add(
+ messages=messages_reverse,
+ user_id=speaker_b_user_id,
+ )
+ print(f"Added messages for {speaker_a_user_id} and {speaker_b_user_id} successfully.")
+
+ elif frame in [MEM0_MODEL, MEM0_GRAPH_MODEL]:
+ print(f"Processing abc for {metadata['session_key']}")
+ messages = []
+ messages_reverse = []
+
+ for chat in tqdm(session, desc=f"{metadata['session_key']}"):
+ data = chat.get("speaker") + ": " + chat.get("text")
+
+ if chat.get("speaker") == metadata["speaker_a"]:
+ messages.append({"role": "user", "content": data})
+ messages_reverse.append({"role": "assistant", "content": data})
+ elif chat.get("speaker") == metadata["speaker_b"]:
+ messages.append({"role": "assistant", "content": data})
+ messages_reverse.append({"role": "user", "content": data})
+ else:
+ raise ValueError(
+ f"Unknown speaker {chat.get('speaker')} in session {metadata['session_key']}"
+ )
+
+ # Print example only once per session
+ if print_once:
+ print({"context": data, "conv_id": conv_id, "created_at": iso_date})
+ print_once = False
+
+ for i in range(0, len(messages), 2):
+ batch_messages = messages[i : i + 2]
+ batch_messages_reverse = messages_reverse[i : i + 2]
+
+ if frame == "mem0":
+ client.add(
+ messages=batch_messages,
+ timestamp=timestamp,
+ user_id=metadata["speaker_a_user_id"],
+ version="v2",
+ )
+ client.add(
+ messages=batch_messages_reverse,
+ timestamp=timestamp,
+ user_id=metadata["speaker_b_user_id"],
+ version="v2",
+ )
+
+ elif frame == "mem0_graph":
+ client.add(
+ messages=batch_messages,
+ timestamp=timestamp,
+ user_id=metadata["speaker_a_user_id"],
+ output_format="v1.1",
+ version="v2",
+ enable_graph=True,
+ )
+ client.add(
+ messages=batch_messages_reverse,
+ timestamp=timestamp,
+ user_id=metadata["speaker_b_user_id"],
+ output_format="v1.1",
+ version="v2",
+ enable_graph=True,
+ )
+
+ end_time = time.time()
+ elapsed_time = round(end_time - start_time, 2)
+
+ return elapsed_time
+
+ def process_user_for_ingestion(self, conv_id, frame, locomo_df, version, num_workers=1):
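+ """Ingest all sessions of a single LoCoMo conversation, running sessions in a thread pool; returns elapsed seconds or an error string."""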
+ try:
+ # Check if locomo_df is empty or doesn't have the required columns
+ if locomo_df.empty or "conversation" not in locomo_df.columns:
+ logger.warning(
+ f"Skipping user {conv_id}: locomo_df is empty or missing 'conversation' column"
+ )
+ return 0
+
+ conversation = locomo_df["conversation"].iloc[conv_id]
+ max_session_count = 35
+ start_time = time.time()
+ total_session_time = 0
+ valid_sessions = 0
+
+ revised_client = None
+ if frame == "zep":
+ client = self.get_client_for_ingestion(frame=frame, user_id=None, version="default")
+ elif frame == "mem0" or frame == "mem0_graph":
+ client = self.get_client_for_ingestion(frame=frame, user_id=None, version="default")
+ client.delete_all(user_id=f"locomo_exp_user_{conv_id}")
+ client.delete_all(user_id=f"{conversation.get('speaker_a')}_{conv_id}")
+ client.delete_all(user_id=f"{conversation.get('speaker_b')}_{conv_id}")
+ elif frame in ["memos", "memos_scheduler"]:
+ conv_id = "locomo_exp_user_" + str(conv_id)
+ speaker_a_user_id = conv_id + "_speaker_a"
+ speaker_b_user_id = conv_id + "_speaker_b"
+
+ client = self.get_client_for_ingestion(
+ frame=frame, user_id=speaker_a_user_id, version=version
+ )
+ revised_client = self.get_client_for_ingestion(
+ frame=frame, user_id=speaker_b_user_id, version=version
+ )
+ else:
+ raise NotImplementedError()
+
+ sessions_to_process = []
+ for session_idx in tqdm(range(max_session_count), desc=f"process_user {conv_id}"):
+ session_key = f"session_{session_idx}"
+ session = conversation.get(session_key)
+ if session is None:
+ continue
+
+ metadata = {
+ "session_date": conversation.get(f"session_{session_idx}_date_time") + " UTC",
+ "speaker_a": conversation.get("speaker_a"),
+ "speaker_b": conversation.get("speaker_b"),
+ "speaker_a_user_id": f"{conversation.get('speaker_a')}_{conv_id}",
+ "speaker_b_user_id": f"{conversation.get('speaker_b')}_{conv_id}",
+ "conv_id": conv_id,
+ "session_key": session_key,
+ }
+ sessions_to_process.append((session, metadata))
+ valid_sessions += 1
+
+ print(
+ f"Processing {valid_sessions} sessions for user {conv_id} with {num_workers} workers"
+ )
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = {
+ executor.submit(
+ self.ingest_session, client, session, frame, metadata, revised_client
+ ): metadata["session_key"]
+ for session, metadata in sessions_to_process
+ }
+
+ for future in concurrent.futures.as_completed(futures):
+ session_key = futures[future]
+ try:
+ session_time = future.result()
+ total_session_time += session_time
+ print(f"User {conv_id}, {session_key} processed in {session_time} seconds")
+ except Exception as e:
+ print(f"Error processing user {conv_id}, session {session_key}: {e!s}")
+
+ end_time = time.time()
+ elapsed_time = round(end_time - start_time, 2)
+ print(f"User {conv_id} processed successfully in {elapsed_time} seconds")
+
+ return elapsed_time
+
+ except Exception as e:
+ return f"Error processing user {conv_id}: {e!s}. Exception: {traceback.format_exc()}"
+
+ def run_ingestion(self):
+ frame = self.frame
+ version = self.version
+ num_workers = self.workers
+
+ num_users = 10
+ start_time = time.time()
+ total_time = 0
+
+ print(
+ f"Starting processing for {num_users} users in serial mode,"
+ f" each user using {num_workers} workers for sessions..."
+ )
+
+ for user_id in range(num_users):
+ try:
+ result = self.process_user_for_ingestion(
+ user_id, frame, self.locomo_df, version, num_workers
+ )
+ if isinstance(result, float):
+ total_time += result
+ else:
+ print(result)
+ except Exception as e:
+ print(
+ f"Error processing user {user_id}: {e!s}. Traceback: {traceback.format_exc()}"
+ )
+
+ if num_users > 0:
+ average_time = total_time / num_users
+ minutes = int(average_time // 60)
+ seconds = int(average_time % 60)
+ average_time_formatted = f"{minutes} minutes and {seconds} seconds"
+ print(
+ f"The frame {frame} processed {num_users} users in average of {average_time_formatted} per user."
+ )
+
+ end_time = time.time()
+ elapsed_time = round(end_time - start_time, 2)
+ minutes = int(elapsed_time // 60)
+ seconds = int(elapsed_time % 60)
+ elapsed_time = f"{minutes} minutes and {seconds} seconds"
+ print(f"Total processing time: {elapsed_time}.")
diff --git a/evaluation/scripts/temporal_locomo/locomo_metric.py b/evaluation/scripts/temporal_locomo/locomo_metric.py
new file mode 100644
index 00000000..0187c37e
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/locomo_metric.py
@@ -0,0 +1,390 @@
+import argparse
+import json
+
+import numpy as np
+import pandas as pd
+
+from modules.locomo_eval_module import LocomoEvalModelModules
+
+
+# Mapping from LoCoMo category IDs to human-readable question types
+category_mapping = {
+ "4": "single hop",
+ "1": "multi hop",
+ "2": "temporal reasoning",
+ "3": "open domain",
+}
+
+
+def calculate_scores(data):
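+ """Aggregate judged results into overall, per-category, and per-user scores: LLM-judge accuracy, lexical/semantic metrics, context tokens, and duration percentiles."""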
+ category_scores = {}
+ category_question_count = {}
+
+ overall_metrics = {
+ "lexical": {
+ m: []
+ for m in [
+ "f1",
+ "rouge1_f",
+ "rouge2_f",
+ "rougeL_f",
+ "bleu1",
+ "bleu2",
+ "bleu3",
+ "bleu4",
+ "meteor",
+ ]
+ },
+ "semantic": {m: [] for m in ["bert_f1", "similarity"]},
+ "context_tokens": [],
+ "duration": {
+ m: [] for m in ["response_duration_ms", "search_duration_ms", "total_duration_ms"]
+ },
+ }
+
+ category_metrics = {}
+ user_metrics = {}
+
+ total_questions = 0
+
+ all_judgment_keys = set()
+ judgment_run_scores = {}
+
+ for _user, questions in data.items():
+ for question in questions:
+ if "llm_judgments" in question:
+ all_judgment_keys.update(question["llm_judgments"].keys())
+
+ for key in all_judgment_keys:
+ judgment_run_scores[key] = []
+
+ for user, questions in data.items():
+ user_total = 0
+
+ # Initialize user_metrics with each judgment run
+ user_metrics[user] = {
+ "total": 0,
+ "llm_judge_score": 0,
+ "llm_judge_std": 0,
+ "judgment_run_scores": {key: [] for key in all_judgment_keys},
+ "lexical": {m: [] for m in overall_metrics["lexical"]},
+ "semantic": {m: [] for m in overall_metrics["semantic"]},
+ "context_tokens": [],
+ "duration": {m: [] for m in overall_metrics["duration"]},
+ }
+
+ for question in questions:
+ total_questions += 1
+ user_total += 1
+
+ if "llm_judgments" in question:
+ for judgment_key, judgment_value in question["llm_judgments"].items():
+ score = 1 if judgment_value else 0
+ judgment_run_scores[judgment_key].append(score)
+ user_metrics[user]["judgment_run_scores"][judgment_key].append(score)
+
+ category = question["category"]
+ if category not in category_scores:
+ category_scores[category] = {
+ "total": 0,
+ "category_name": category_mapping.get(str(category), "Unknown"),
+ "judgment_run_scores": {key: [] for key in all_judgment_keys},
+ }
+ category_metrics[category] = {
+ "lexical": {m: [] for m in overall_metrics["lexical"]},
+ "semantic": {m: [] for m in overall_metrics["semantic"]},
+ "context_tokens": [],
+ "duration": {m: [] for m in overall_metrics["duration"]},
+ }
+ category_question_count[category] = 0
+
+ category_scores[category]["total"] += 1
+ category_question_count[category] += 1
+
+ if "llm_judgments" in question:
+ for judgment_key, judgment_value in question["llm_judgments"].items():
+ score = 1 if judgment_value else 0
+ category_scores[category]["judgment_run_scores"][judgment_key].append(score)
+
+ nlp = question.get("nlp_metrics", {})
+ for metric in overall_metrics["lexical"]:
+ v = nlp.get("lexical", {}).get(metric)
+ if v is not None:
+ overall_metrics["lexical"][metric].append(v)
+ category_metrics[category]["lexical"][metric].append(v)
+ user_metrics[user]["lexical"][metric].append(v)
+
+ for metric in overall_metrics["semantic"]:
+ v = nlp.get("semantic", {}).get(metric)
+ if v is not None:
+ overall_metrics["semantic"][metric].append(v)
+ category_metrics[category]["semantic"][metric].append(v)
+ user_metrics[user]["semantic"][metric].append(v)
+
+ ct = nlp.get("context_tokens")
+ if ct is not None:
+ overall_metrics["context_tokens"].append(ct)
+ category_metrics[category]["context_tokens"].append(ct)
+ user_metrics[user]["context_tokens"].append(ct)
+
+ for metric in overall_metrics["duration"]:
+ v = question.get(metric)
+ if v is not None:
+ overall_metrics["duration"][metric].append(v)
+ category_metrics[category]["duration"][metric].append(v)
+ user_metrics[user]["duration"][metric].append(v)
+
+ user_metrics[user]["total"] = user_total
+
+ judgment_avgs = []
+ for _judgment_key, scores in user_metrics[user]["judgment_run_scores"].items():
+ if scores:
+ avg = np.mean(scores)
+ judgment_avgs.append(avg)
+
+ user_metrics[user]["llm_judge_score"] = np.mean(judgment_avgs) if judgment_avgs else 0.0
+ user_metrics[user]["llm_judge_std"] = (
+ np.std(judgment_avgs) if len(judgment_avgs) > 1 else 0.0
+ )
+
+ for group in ["lexical", "semantic"]:
+ for metric in user_metrics[user][group]:
+ values = user_metrics[user][group][metric]
+ user_metrics[user][group][metric] = np.mean(values) if values else 0.0
+
+ user_metrics[user]["context_tokens"] = (
+ np.mean(user_metrics[user]["context_tokens"])
+ if user_metrics[user]["context_tokens"]
+ else 0.0
+ )
+
+ duration_metrics = list(user_metrics[user]["duration"].keys())
+ for metric in duration_metrics:
+ values = user_metrics[user]["duration"][metric]
+ if values:
+ user_metrics[user]["duration"][metric] = np.mean(values)
+ user_metrics[user]["duration"][f"{metric}_p50"] = np.percentile(values, 50)
+ user_metrics[user]["duration"][f"{metric}_p95"] = np.percentile(values, 95)
+ else:
+ user_metrics[user]["duration"][metric] = 0.0
+ user_metrics[user]["duration"][f"{metric}_p50"] = 0.0
+ user_metrics[user]["duration"][f"{metric}_p95"] = 0.0
+
+ judgment_run_averages = []
+ for _judgment_key, scores in judgment_run_scores.items():
+ if scores:
+ judgment_run_averages.append(np.mean(scores))
+
+ llm_judge_score = np.mean(judgment_run_averages) if judgment_run_averages else 0.0
+ llm_judge_std = np.std(judgment_run_averages) if len(judgment_run_averages) > 1 else 0.0
+
+ category_overall_scores = {}
+ for category, score_data in category_scores.items():
+ category_judgment_avgs = []
+ for _judgment_key, scores in score_data["judgment_run_scores"].items():
+ if scores:
+ category_judgment_avgs.append(np.mean(scores))
+
+ category_overall_scores[category] = {
+ "category_name": score_data["category_name"],
+ "llm_judge_score": np.mean(category_judgment_avgs) if category_judgment_avgs else 0.0,
+ "llm_judge_std": np.std(category_judgment_avgs)
+ if len(category_judgment_avgs) > 1
+ else 0.0,
+ "total": score_data["total"],
+ "lexical": {},
+ "semantic": {},
+ "duration": {},
+ "context_tokens": 0.0,
+ }
+
+ for group in ["lexical", "semantic"]:
+ for metric in category_metrics[category][group]:
+ values = category_metrics[category][group][metric]
+ category_overall_scores[category][group][metric] = (
+ np.mean(values) if values else 0.0
+ )
+
+ category_overall_scores[category]["context_tokens"] = (
+ np.mean(category_metrics[category]["context_tokens"])
+ if category_metrics[category]["context_tokens"]
+ else 0.0
+ )
+
+ # Calculate mean and percentiles for category duration metrics
+ duration_metrics = list(
+ category_metrics[category]["duration"].keys()
+ ) # Create a list of keys first
+ for metric in duration_metrics:
+ values = category_metrics[category]["duration"][metric]
+ if values:
+ category_overall_scores[category]["duration"][metric] = np.mean(values)
+ # Add P50 (median) and P95 percentiles
+ category_overall_scores[category]["duration"][f"{metric}_p50"] = np.percentile(
+ values, 50
+ )
+ category_overall_scores[category]["duration"][f"{metric}_p95"] = np.percentile(
+ values, 95
+ )
+ else:
+ category_overall_scores[category]["duration"][metric] = 0.0
+ category_overall_scores[category]["duration"][f"{metric}_p50"] = 0.0
+ category_overall_scores[category]["duration"][f"{metric}_p95"] = 0.0
+
+ # calculate overall scores
+ overall_metric_averages = {
+ "llm_judge_score": llm_judge_score,
+ "llm_judge_std": llm_judge_std,
+ "lexical": {},
+ "semantic": {},
+ "context_tokens": 0.0,
+ "duration": {},
+ }
+
+ for group in ["lexical", "semantic"]:
+ for metric in overall_metrics[group]:
+ values = overall_metrics[group][metric]
+ overall_metric_averages[group][metric] = np.mean(values) if values else 0.0
+
+ overall_metric_averages["context_tokens"] = (
+ np.mean(overall_metrics["context_tokens"]) if overall_metrics["context_tokens"] else 0.0
+ )
+
+ duration_metrics = list(overall_metrics["duration"].keys())
+ for metric in duration_metrics:
+ values = overall_metrics["duration"][metric]
+ if values:
+ overall_metric_averages["duration"][metric] = np.mean(values)
+ overall_metric_averages["duration"][f"{metric}_p50"] = np.percentile(values, 50)
+ overall_metric_averages["duration"][f"{metric}_p95"] = np.percentile(values, 95)
+ else:
+ overall_metric_averages["duration"][metric] = 0.0
+ overall_metric_averages["duration"][f"{metric}_p50"] = 0.0
+ overall_metric_averages["duration"][f"{metric}_p95"] = 0.0
+
+ return {
+ "metrics": overall_metric_averages,
+ "category_scores": category_overall_scores,
+ "user_scores": user_metrics,
+ }
+
+
+def save_to_excel(results, output_path):
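+ """Flatten the overall and per-category metrics into one row each and write them to a single-sheet Excel file."""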
+ # Create a combined data structure for metrics and category scores
+ combined_data = []
+
+ # Process overall metrics - flatten nested structures
+ overall_row = {"category": "overall"}
+ overall_row["llm_judge_score"] = results["metrics"]["llm_judge_score"]
+ overall_row["llm_judge_std"] = results["metrics"]["llm_judge_std"]
+
+ # Add all lexical metrics
+ for metric, value in results["metrics"]["lexical"].items():
+ overall_row[metric] = value
+
+ # Add all semantic metrics
+ for metric, value in results["metrics"]["semantic"].items():
+ overall_row[metric] = value
+
+ # Add context tokens
+ overall_row["context_tokens"] = results["metrics"]["context_tokens"]
+
+ # Add all duration metrics, including percentiles
+ for metric, value in results["metrics"]["duration"].items():
+ overall_row[metric] = value
+
+ combined_data.append(overall_row)
+
+ # Process category scores - flatten nested structures
+ for _, scores in results["category_scores"].items():
+ category_row = {"category": scores["category_name"]}
+ category_row["llm_judge_score"] = scores["llm_judge_score"]
+ category_row["llm_judge_std"] = scores["llm_judge_std"]
+
+ # Add all lexical metrics
+ for metric, value in scores["lexical"].items():
+ category_row[metric] = value
+
+ # Add all semantic metrics
+ for metric, value in scores["semantic"].items():
+ category_row[metric] = value
+
+ # Add context tokens
+ category_row["context_tokens"] = scores["context_tokens"]
+
+ # Add all duration metrics, including percentiles
+ for metric, value in scores["duration"].items():
+ category_row[metric] = value
+
+ combined_data.append(category_row)
+
+ # Create DataFrame and save to Excel
+ combined_df = pd.DataFrame(combined_data)
+
+ # Create a pandas Excel writer
+ with pd.ExcelWriter(output_path) as writer:
+ combined_df.to_excel(writer, sheet_name="Metrics", index=False)
+
+ print(f"Excel file saved to: {output_path}")
+
+
+class LocomoMetric(LocomoEvalModelModules):
+ def __init__(self, args):
+ super().__init__(args=args)
+
+ def run(self):
+ with open(self.judged_path) as file:
+ data = json.load(file)
+
+ results = calculate_scores(data)
+
+ with open(self.grade_path, "w") as outfile:
+ json.dump(results, outfile, indent=4)
+
+ save_to_excel(results, self.excel_path)
+
+ print("\n=== Metric Calculation Complete ===")
+ total = sum(results["category_scores"][cat]["total"] for cat in results["category_scores"])
+ print(
+ f"LLM-as-a-Judge score: {results['metrics']['llm_judge_score']:.4f} ± {results['metrics']['llm_judge_std']:.4f}"
+ )
+ print(f"Total questions evaluated: {total}")
+
+ print("\n=== Duration Metrics ===")
+ for metric in ["response_duration_ms", "search_duration_ms", "total_duration_ms"]:
+ print(f"{metric} (avg): {results['metrics']['duration'][metric]:.2f} ms")
+ print(f"{metric} (P50): {results['metrics']['duration'][f'{metric}_p50']:.2f} ms")
+ print(f"{metric} (P95): {results['metrics']['duration'][f'{metric}_p95']:.2f} ms")
+
+ print(f"\nResults have been written to {self.grade_path}")
+ print(f"Excel report has been saved to {self.excel_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lib",
+ type=str,
+ default="memos_scheduler",
+ choices=["zep", "memos", "memos_scheduler", "mem0", "mem0_graph", "langmem", "openai"],
+ help="Specify the memory framework (zep or memos or mem0 or mem0_graph)",
+ )
+ parser.add_argument(
+ "--version",
+ type=str,
+ default="v1.0.1",
+ help="Version identifier for loading results (e.g., 1010)",
+ )
+ cli_args = parser.parse_args()
+
+ # Build a minimal args namespace compatible with LocomoEvalModelModules
+ class _Args:
+ def __init__(self, frame, version):
+ self.frame = frame
+ self.version = version
+ self.workers = 1
+ self.top_k = 20
+ self.scheduler_flag = True
+
+ args = _Args(frame=cli_args.lib, version=cli_args.version)
+ LocomoMetric(args=args).run()
diff --git a/evaluation/scripts/temporal_locomo/locomo_processor.py b/evaluation/scripts/temporal_locomo/locomo_processor.py
new file mode 100644
index 00000000..4ae9cf91
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/locomo_processor.py
@@ -0,0 +1,324 @@
+import json
+import sys
+
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from time import time
+
+from dotenv import load_dotenv
+from modules.constants import (
+ MEMOS_MODEL,
+ MEMOS_SCHEDULER_MODEL,
+)
+from modules.locomo_eval_module import LocomoEvalModelModules
+from modules.prompts import (
+ SEARCH_PROMPT_MEM0,
+ SEARCH_PROMPT_MEM0_GRAPH,
+ SEARCH_PROMPT_MEMOS,
+ SEARCH_PROMPT_ZEP,
+)
+from modules.schemas import ContextUpdateMethod, RecordingCase
+from modules.utils import save_evaluation_cases
+
+from memos.log import get_logger
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent
+sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
+
+logger = get_logger(__name__)
+
+
+class LocomoProcessor(LocomoEvalModelModules):
+ """
+ A class for handling conversational memory management across different memory frameworks.
+ Supports multiple memory backends (zep, mem0, memos, etc.) for searching and retrieving
+ relevant context to generate conversational responses.
+ """
+
+ def __init__(self, args):
+ """Initialize the LocomoChatter with path configurations and templates"""
+ super().__init__(args=args)
+
+ # Template definitions for different memory frameworks
+ self.search_template_zep = SEARCH_PROMPT_ZEP
+
+ self.search_template_mem0 = SEARCH_PROMPT_MEM0
+
+ self.search_template_mem0_graph = SEARCH_PROMPT_MEM0_GRAPH
+
+ self.search_template_memos = SEARCH_PROMPT_MEMOS
+
+ self.processed_data_dir = self.result_dir / "processed_data"
+
+ def update_context(self, conv_id, method, **kwargs):
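+ """Refresh the per-conversation pre-context cache: DIRECT replaces it with the current search context, TEMPLATE appends a User/Assistant turn."""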
+ if method == ContextUpdateMethod.DIRECT:
+ if "cur_context" not in kwargs:
+ raise ValueError("cur_context is required for DIRECT update method")
+ cur_context = kwargs["cur_context"]
+ self.pre_context_cache[conv_id] = cur_context
+ elif method == ContextUpdateMethod.TEMPLATE:
+ if "query" not in kwargs or "answer" not in kwargs:
+ raise ValueError("query and answer are required for TEMPLATE update method")
+ self._update_context_template(conv_id, kwargs["query"], kwargs["answer"])
+ else:
+ raise ValueError(f"Unsupported update method: {method}")
+
+ def _update_context_template(self, conv_id, query, answer):
+ new_context = f"User: {query}\nAssistant: {answer}\n\n"
+ if self.pre_context_cache[conv_id] is None:
+ self.pre_context_cache[conv_id] = ""
+ self.pre_context_cache[conv_id] += new_context
+
+ def _process_single_qa(
+ self,
+ qa,
+ *,
+ client,
+ reversed_client,
+ metadata,
+ frame,
+ version,
+ conv_id,
+ conv_stats_path,
+ oai_client,
+ top_k,
+ conv_stats,
+ ):
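+ """Handle one QA item: search for context, check whether the cached pre-context can answer it, generate a response, and record stats and cases."""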
+ query = qa.get("question")
+ gold_answer = qa.get("answer")
+ qa_category = qa.get("category")
+ if qa_category == 5:
+ return None
+
+ # Search
+ cur_context, search_duration_ms = self.search_query(
+ client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k
+ )
+ if not cur_context:
+ logger.warning(f"No context found for query: {query[:100]}")
+ cur_context = ""
+
+ # First question for this conversation: seed the pre-context cache and skip the answerability check
+ if self.pre_context_cache[conv_id] is None:
+ # Update pre-context cache with current context
+ if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
+ self.update_context(
+ conv_id=conv_id,
+ method=self.context_update_method,
+ cur_context=cur_context,
+ )
+ else:
+ self.update_context(
+ conv_id=conv_id,
+ method=self.context_update_method,
+ query=query,
+ answer=gold_answer,
+ )
+ return None
+
+ can_answer = False
+ can_answer_duration_ms = 0.0
+ can_answer_start = time()
+ can_answer = self.analyze_context_answerability(
+ self.pre_context_cache[conv_id], query, gold_answer, oai_client
+ )
+ can_answer_duration_ms = (time() - can_answer_start) * 1000
+ # Update global stats
+ with self.stats_lock:
+ self.stats[self.frame][self.version]["memory_stats"]["total_queries"] += 1
+ if can_answer:
+ self.stats[self.frame][self.version]["memory_stats"]["can_answer_count"] += 1
+ else:
+ self.stats[self.frame][self.version]["memory_stats"]["cannot_answer_count"] += 1
+ total_queries = self.stats[self.frame][self.version]["memory_stats"]["total_queries"]
+ can_answer_count = self.stats[self.frame][self.version]["memory_stats"][
+ "can_answer_count"
+ ]
+ hit_rate = (can_answer_count / total_queries * 100) if total_queries > 0 else 0
+ self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"] = hit_rate
+ self.stats[self.frame][self.version]["memory_stats"]["can_answer_duration_ms"] = (
+ can_answer_duration_ms
+ )
+ self.save_stats()
+
+ # Generate answer
+ answer_start = time()
+ answer = self.locomo_response(frame, oai_client, self.pre_context_cache[conv_id], query)
+ response_duration_ms = (time() - answer_start) * 1000
+
+ # Record case for memos_scheduler
+ if frame == "memos_scheduler":
+ try:
+ recording_case = RecordingCase(
+ conv_id=conv_id,
+ query=query,
+ answer=answer,
+ context=cur_context,
+ pre_context=self.pre_context_cache[conv_id],
+ can_answer=can_answer,
+ can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}",
+ search_duration_ms=search_duration_ms,
+ can_answer_duration_ms=can_answer_duration_ms,
+ response_duration_ms=response_duration_ms,
+ category=int(qa_category) if qa_category is not None else None,
+ golden_answer=str(qa.get("answer", "")),
+ memories=[],
+ pre_memories=[],
+ history_queries=[],
+ )
+ if can_answer:
+ self.can_answer_cases.append(recording_case)
+ else:
+ self.cannot_answer_cases.append(recording_case)
+ except Exception as e:
+ logger.error(f"Error creating RecordingCase: {e}")
+ print(f"Error creating RecordingCase: {e}")
+ logger.error(f"QA data: {qa}")
+ print(f"QA data: {qa}")
+ logger.error(f"Query: {query}")
+ logger.error(f"Answer: {answer}")
+ logger.error(
+ f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})"
+ )
+ logger.error(f"Category: {qa_category} (type: {type(qa_category)})")
+ logger.error(f"Can answer: {can_answer}")
+ raise e
+
+ # Update conversation stats
+ conv_stats["total_queries"] += 1
+ conv_stats["response_count"] += 1
+ if frame == "memos_scheduler":
+ if can_answer:
+ conv_stats["can_answer_count"] += 1
+ else:
+ conv_stats["cannot_answer_count"] += 1
+ if conv_stats["total_queries"] > 0:
+ conv_stats["answer_hit_rate"] = (
+ conv_stats["can_answer_count"] / conv_stats["total_queries"]
+ ) * 100
+
+ # Persist conversation stats snapshot
+ self._save_conv_stats(conv_id, frame, version, conv_stats, conv_stats_path)
+
+ logger.info(f"Processed question: {query[:100]}")
+ logger.info(f"Answer: {answer[:100]}")
+
+ # Update pre-context cache with current context
+ with self.stats_lock:
+ if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
+ self.update_context(
+ conv_id=conv_id,
+ method=self.context_update_method,
+ cur_context=cur_context,
+ )
+ else:
+ self.update_context(
+ conv_id=conv_id,
+ method=self.context_update_method,
+ query=query,
+ answer=gold_answer,
+ )
+
+ self.print_eval_info()
+
+ return {
+ "question": query,
+ "answer": answer,
+ "category": qa_category,
+ "golden_answer": gold_answer,
+ "search_context": cur_context,
+ "response_duration_ms": response_duration_ms,
+ "search_duration_ms": search_duration_ms,
+ "can_answer_duration_ms": can_answer_duration_ms,
+ "can_answer": can_answer if frame == "memos_scheduler" else None,
+ }
+
+ def run_locomo_processing(self, num_users=10):
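+        """Process up to ``num_users`` conversations (in parallel when ``self.workers`` > 1),
+        aggregate the search and response results, and persist them together with the
+        collected evaluation cases."""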
+ load_dotenv()
+
+ frame = self.frame
+ version = self.version
+ num_workers = self.workers
+ top_k = self.top_k
+
+ # Storage for aggregated results
+ all_search_results = defaultdict(list)
+ all_response_results = defaultdict(list)
+ num_users = num_users
+
+ # Prepare arguments for each user processing task
+ user_args = [(idx, self.locomo_df, frame, version, top_k) for idx in range(num_users)]
+
+ if num_workers > 1:
+ # === parallel running ====
+ # Use ThreadPoolExecutor for parallel processing
+ print(
+ f"Starting parallel processing for {num_users} users, using {num_workers} workers for sessions..."
+ )
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ # Submit all user processing tasks
+ future_to_user = {
+ executor.submit(self.process_user_wrapper, args): idx
+ for idx, args in enumerate(user_args)
+ }
+
+ # Collect results as they complete
+ for future in as_completed(future_to_user):
+ idx = future_to_user[future]
+ user_search_results, user_response_results, error = future.result()
+ if error is not None:
+ idx, e, traceback_str = error
+ print(f"Error processing user {idx}: {e}. Exception: {traceback_str}")
+ else:
+ # Aggregate results
+ conv_id = f"locomo_exp_user_{idx}"
+ all_search_results[conv_id].extend(user_search_results[conv_id])
+ all_response_results[conv_id].extend(user_response_results[conv_id])
+
+ else:
+ # Serial processing
+ print(
+ f"Starting serial processing for {num_users} users in serial mode, each user using {num_workers} workers for sessions..."
+ )
+ for idx, args in enumerate(user_args):
+ user_search_results, user_response_results, error = self.process_user_wrapper(args)
+ if error is not None:
+ idx, e, traceback_str = error
+ print(f"Error processing user {idx}: {e}. Exception: {traceback_str}")
+ else:
+ # Aggregate results
+ conv_id = f"locomo_exp_user_{idx}"
+ all_search_results[conv_id].extend(user_search_results[conv_id])
+ all_response_results[conv_id].extend(user_response_results[conv_id])
+
+ # Print evaluation information statistics
+ self.print_eval_info()
+ self.save_stats()
+
+ # Save all aggregated results
+ with open(self.search_path, "w") as fw:
+ json.dump(all_search_results, fw, indent=2)
+ print(f"Saved all search results to {self.search_path}")
+
+ with open(self.response_path, "w") as fw:
+ json.dump(all_response_results, fw, indent=2)
+ print(f"Saved all response results to {self.response_path}")
+
+ # Save evaluation cases if they exist
+ if self.can_answer_cases or self.cannot_answer_cases:
+ try:
+ saved_files = save_evaluation_cases(
+ can_answer_cases=self.can_answer_cases,
+ cannot_answer_cases=self.cannot_answer_cases,
+ output_dir=self.stats_dir,
+ frame=self.frame,
+ version=self.version,
+ )
+ print(f"Saved evaluation cases: {saved_files}")
+ except Exception as e:
+ logger.error(f"Error saving evaluation cases: {e}")
+
+ return dict(all_search_results), dict(all_response_results)
diff --git a/evaluation/scripts/temporal_locomo/modules/README.md b/evaluation/scripts/temporal_locomo/modules/README.md
new file mode 100644
index 00000000..31a274dd
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/modules/README.md
@@ -0,0 +1,83 @@
+# Evaluation Modules
+
+This directory contains the modularized code for the temporal locomo evaluation, organized using inheritance and composition patterns.
+
+## Structure
+
+### Base Classes
+
+- **`base_eval_module.py`**: Contains the `BaseEvalModule` class with common functionality:
+ - Statistics management
+ - Data loading and processing
+ - File I/O operations
+ - Basic evaluation methods
+
+### Specialized Modules
+
+- **`client_manager.py`**: Contains the `ClientManager` class for managing different memory framework clients:
+ - Zep client management
+ - Mem0 client management
+ - Memos client management
+ - Memos scheduler client management
+
+- **`search_modules.py`**: Contains the `SearchModules` class with all search methods:
+ - `mem0_search()`: Mem0 framework search
+ - `mem0_graph_search()`: Mem0 graph framework search
+ - `memos_search()`: Memos framework search
+ - `memos_scheduler_search()`: Memos scheduler framework search
+ - `zep_search()`: Zep framework search
+
+- **`locomo_eval_module.py`**: Contains the main `LocomoEvalModule` class that combines all functionality:
+ - Inherits from `BaseEvalModule`
+ - Uses `ClientManager` for client management
+ - Uses `SearchModules` for search operations
+  - Provides a unified interface for evaluation (see the inheritance sketch below)
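+
+A minimal sketch of how these pieces compose, using the class names as they actually appear in the Python files added alongside this README (`LocomoProcessor` lives one level up, in `locomo_processor.py`):
+
+```python
+class BaseEvalModule:                                       # base_eval_module.py: stats, data loading, file I/O
+    ...
+
+class EvalModuleWithClientManager(BaseEvalModule):          # client_manager.py: memory-framework client setup
+    ...
+
+class LocomoEvalModelModules(EvalModuleWithClientManager):  # locomo_eval_module.py: search and per-user processing
+    ...
+
+class LocomoProcessor(LocomoEvalModelModules):              # locomo_processor.py: end-to-end QA processing
+    ...
+```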
+
+## Usage
+
+### Basic Usage
+
+```python
+from modules import LocomoEvalModule
+import argparse
+
+# Create arguments
+args = argparse.Namespace()
+args.frame = 'memos_scheduler'
+args.version = 'v0.2.1'
+args.top_k = 20
+args.workers = 1
+
+# Initialize the evaluation module
+eval_module = LocomoEvalModule(args)
+
+# Use the module
+eval_module.print_eval_info()
+eval_module.save_stats()
+```
+
+### Backward Compatibility
+
+For backward compatibility, the original `LocomoEvalModelModules` class is available as an alias:
+
+```python
+from modules import LocomoEvalModule as LocomoEvalModelModules
+```
+
+## Benefits of Modularization
+
+1. **Separation of Concerns**: Each module has a specific responsibility
+2. **Maintainability**: Easier to modify and extend individual components
+3. **Testability**: Each module can be tested independently
+4. **Reusability**: Modules can be reused in different contexts
+5. **Readability**: Code is more organized and easier to understand
+
+## Migration from Original Code
+
+The original `eval_model_modules.py` has been refactored into this modular structure:
+
+- **Original class**: `LocomoEvalModelModules`
+- **New main class**: `LocomoEvalModule`
+- **Backward compatibility**: `LocomoEvalModelModules = LocomoEvalModule`
+
+All existing functionality is preserved, but now organized in a more maintainable structure.
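+
+For illustration, a migrated call site might look like the sketch below. The module path matches the files added in this change; note that the class defined in `locomo_eval_module.py` here is `LocomoEvalModelModules`, so the `LocomoEvalModule` name above is assumed to be exported as an alias:
+
+```python
+# Hypothetical migrated import; uses the concrete class name defined in this change.
+from modules.locomo_eval_module import LocomoEvalModelModules
+
+eval_module = LocomoEvalModelModules(args)  # args as in the Basic Usage example above
+eval_module.print_eval_info()
+```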
diff --git a/evaluation/scripts/temporal_locomo/modules/__init__.py b/evaluation/scripts/temporal_locomo/modules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/evaluation/scripts/temporal_locomo/modules/base_eval_module.py b/evaluation/scripts/temporal_locomo/modules/base_eval_module.py
new file mode 100644
index 00000000..4ec7d492
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/modules/base_eval_module.py
@@ -0,0 +1,382 @@
+import json
+import os
+import traceback
+
+from collections import defaultdict
+from pathlib import Path
+from threading import Lock
+from typing import TYPE_CHECKING
+
+import pandas as pd
+
+from dotenv import load_dotenv
+
+from memos.configs.mem_scheduler import AuthConfig
+from memos.log import get_logger
+
+from .constants import (
+ BASE_DIR,
+ MEMOS_MODEL,
+ MEMOS_SCHEDULER_MODEL,
+)
+from .prompts import (
+ CUSTOM_INSTRUCTIONS,
+)
+from .schemas import ContextUpdateMethod
+
+
+if TYPE_CHECKING:
+ from .schemas import RecordingCase
+
+
+logger = get_logger(__name__)
+
+
+class BaseEvalModule:
+ def __init__(self, args):
+ # hyper-parameters
+ self.args = args
+ self.frame = self.args.frame
+ self.version = self.args.version
+ self.workers = self.args.workers
+ self.top_k = self.args.top_k
+
+ # attributes
+ if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
+ self.context_update_method = ContextUpdateMethod.DIRECT
+ else:
+ self.context_update_method = ContextUpdateMethod.TEMPLATE
+ self.custom_instructions = CUSTOM_INSTRUCTIONS
+ self.data_dir = Path(f"{BASE_DIR}/data")
+ self.locomo_df = pd.read_json(f"{self.data_dir}/locomo/locomo10.json")
+
+ # Load temporal_locomo dataset if it exists
+ self.temporal_locomo_data = None
+ temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json"
+ if temporal_locomo_file.exists():
+ with open(temporal_locomo_file, encoding="utf-8") as f:
+ self.temporal_locomo_data = json.load(f)
+ logger.info(
+ f"Loaded temporal_locomo dataset with {len(self.temporal_locomo_data)} conversations"
+ )
+ else:
+ logger.warning(f"Temporal locomo dataset not found at {temporal_locomo_file}")
+ # Configure result dir; if scheduler disabled and using memos scheduler, mark as ablation
+ if (
+ hasattr(self.args, "scheduler_flag")
+ and self.frame == "memos_scheduler"
+ and self.args.scheduler_flag is False
+ ):
+ self.result_dir = Path(
+ f"{BASE_DIR}/results/temporal_locomo/{self.frame}-{self.version}-ablation/"
+ )
+ else:
+ self.result_dir = Path(
+ f"{BASE_DIR}/results/temporal_locomo/{self.frame}-{self.version}/"
+ )
+ self.result_dir.mkdir(parents=True, exist_ok=True)
+
+ self.search_path = self.result_dir / f"{self.frame}-{self.version}_search_results.json"
+ self.response_path = self.result_dir / f"{self.frame}-{self.version}_responses.json"
+ self.judged_path = self.result_dir / f"{self.frame}-{self.version}_judged.json"
+ self.grade_path = self.result_dir / f"{self.frame}-{self.version}_grades.json"
+ self.excel_path = self.result_dir / f"{self.frame}-{self.version}_metrics.xlsx"
+
+ self.ingestion_storage_dir = self.result_dir / "storages"
+ self.mos_config_path = Path(f"{BASE_DIR}/configs-example/mos_w_scheduler_config.json")
+ self.mem_cube_config_path = Path(f"{BASE_DIR}/configs-example/mem_cube_config.json")
+ self.openai_api_key = os.getenv("CHAT_MODEL_API_KEY")
+ self.openai_base_url = os.getenv("CHAT_MODEL_BASE_URL")
+ self.openai_chat_model = os.getenv("CHAT_MODEL")
+
+ auth_config_path = Path(f"{BASE_DIR}/scripts/temporal_locomo/eval_auth.json")
+ if auth_config_path.exists():
+ auth_config = AuthConfig.from_local_config(config_path=auth_config_path)
+
+ self.mos_config_data = json.load(self.mos_config_path.open("r", encoding="utf-8"))
+ self.mem_cube_config_data = json.load(
+ self.mem_cube_config_path.open("r", encoding="utf-8")
+ )
+
+ # Update LLM authentication information in MOS configuration using dictionary assignment
+ self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_key"] = (
+ auth_config.openai.api_key
+ )
+ self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_base"] = (
+ auth_config.openai.base_url
+ )
+
+ # Update graph database authentication information in memory cube configuration using dictionary assignment
+ self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["uri"] = (
+ auth_config.graph_db.uri
+ )
+ self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["user"] = (
+ auth_config.graph_db.user
+ )
+ self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["password"] = (
+ auth_config.graph_db.password
+ )
+ self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["db_name"] = (
+ auth_config.graph_db.db_name
+ )
+ self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["auto_create"] = (
+ auth_config.graph_db.auto_create
+ )
+
+ self.openai_api_key = auth_config.openai.api_key
+ self.openai_base_url = auth_config.openai.base_url
+ self.openai_chat_model = auth_config.openai.default_model
+ else:
+ print("Please referring to configs-example to provide valid configs.")
+ exit()
+
+ # Logger initialization
+ self.logger = logger
+
+ # Statistics tracking with thread safety
+ self.stats = {self.frame: {self.version: defaultdict(dict)}}
+ self.stats[self.frame][self.version]["response_stats"] = defaultdict(dict)
+ self.stats[self.frame][self.version]["response_stats"]["response_failure"] = 0
+ self.stats[self.frame][self.version]["response_stats"]["response_count"] = 0
+
+ self.stats[self.frame][self.version]["memory_stats"] = defaultdict(dict)
+ self.stats[self.frame][self.version]["memory_stats"]["total_queries"] = 0
+ self.stats[self.frame][self.version]["memory_stats"]["can_answer_count"] = 0
+ self.stats[self.frame][self.version]["memory_stats"]["cannot_answer_count"] = 0
+ self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"] = 0.0
+
+        # Lock guarding the shared statistics across worker threads
+ self.stats_lock = Lock()
+ # Reflect CLI flag
+ self.scheduler_flag = bool(getattr(self.args, "scheduler_flag", True))
+ self.stats_dir = self.result_dir / f"stats/{self.frame}_{self.version}"
+ self.stats_dir.mkdir(parents=True, exist_ok=True) # Ensure the directory exists
+ self.stats_path = self.stats_dir / "stats.txt"
+
+ self.can_answer_cases: list[RecordingCase] = []
+ self.cannot_answer_cases: list[RecordingCase] = []
+ load_dotenv()
+
+ def print_eval_info(self):
+ """
+ Calculate and print the evaluation information including answer statistics for memory scheduler (thread-safe).
+ Shows total queries, can answer count, cannot answer count, and answer hit rate.
+ """
+ with self.stats_lock:
+ # Get statistics
+ total_queries = self.stats[self.frame][self.version]["memory_stats"]["total_queries"]
+ can_answer_count = self.stats[self.frame][self.version]["memory_stats"][
+ "can_answer_count"
+ ]
+ cannot_answer_count = self.stats[self.frame][self.version]["memory_stats"][
+ "cannot_answer_count"
+ ]
+ hit_rate = self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"]
+
+ # Print basic statistics
+ print(f"Total Queries: {total_queries}")
+ logger.info(f"Total Queries: {total_queries}")
+
+ print(f"Can Answer Count: {can_answer_count}")
+ logger.info(f"Can Answer Count: {can_answer_count}")
+
+ print(f"Cannot Answer Count: {cannot_answer_count}")
+ logger.info(f"Cannot Answer Count: {cannot_answer_count}")
+
+ # Verify count consistency
+ if total_queries != (can_answer_count + cannot_answer_count):
+ print(
+ f"WARNING: Count mismatch! Total ({total_queries}) != Can Answer ({can_answer_count}) + Cannot Answer ({cannot_answer_count})"
+ )
+ logger.warning(
+ f"Count mismatch! Total ({total_queries}) != Can Answer ({can_answer_count}) + Cannot Answer ({cannot_answer_count})"
+ )
+
+ print(f"Answer Hit Rate: {hit_rate:.2f}% ({can_answer_count}/{total_queries})")
+ logger.info(f"Answer Hit Rate: {hit_rate:.2f}% ({can_answer_count}/{total_queries})")
+
+ def save_stats(self):
+ """
+        Serialize and save the contents of self.stats to the stats file under the
+        result directory (result_dir/stats/{frame}_{version}/stats.txt).
+
+ This method handles directory creation, thread-safe access to statistics data,
+ and proper JSON serialization of complex data structures.
+ """
+ try:
+            # Shallow-copy the top level of the stats data before serialization;
+            # callers may already hold the stats lock, so no locking is done here
+ stats_data = dict(self.stats)
+
+ # Helper function to convert defaultdict to regular dict for JSON serialization
+ def convert_defaultdict(obj):
+ if isinstance(obj, defaultdict):
+ return dict(obj)
+ return obj
+
+ # Debug: Print stats summary before saving
+ self.logger.info(f"DEBUG: Saving stats for {self.frame}-{self.version}")
+ self.logger.info(f"DEBUG: Stats path: {self.stats_path}")
+ self.logger.info(f"DEBUG: Stats data keys: {list(stats_data.keys())}")
+ if self.frame in stats_data and self.version in stats_data[self.frame]:
+ frame_data = stats_data[self.frame][self.version]
+ self.logger.info(f"DEBUG: Memory stats: {frame_data.get('memory_stats', {})}")
+ self.logger.info(
+ f"DEBUG: Total queries: {frame_data.get('memory_stats', {}).get('total_queries', 0)}"
+ )
+
+ # Serialize and save the statistics data to file
+ with self.stats_path.open("w", encoding="utf-8") as fw:
+ json.dump(stats_data, fw, ensure_ascii=False, indent=2, default=convert_defaultdict)
+
+ self.logger.info(f"Successfully saved stats to: {self.stats_path}")
+ print(f"DEBUG: Stats file created at {self.stats_path}")
+
+ except Exception as e:
+ self.logger.error(f"Failed to save stats: {e!s}")
+ self.logger.error(traceback.format_exc())
+ print(f"DEBUG: Error saving stats: {e}")
+
+ def get_answer_hit_rate(self):
+ """
+ Get current answer hit rate statistics.
+
+ Returns:
+ dict: Hit rate statistics
+ """
+ with self.stats_lock:
+ return {
+ "total_queries": self.stats[self.frame][self.version]["memory_stats"][
+ "total_queries"
+ ],
+ "can_answer_count": self.stats[self.frame][self.version]["memory_stats"][
+ "can_answer_count"
+ ],
+ "hit_rate_percentage": self.stats[self.frame][self.version]["memory_stats"][
+ "answer_hit_rate"
+ ],
+ }
+
+ def group_and_sort_qa_by_day(self, qa_set, sort_by_evidence):
+ """
+ Groups QA pairs by day and sorts them chronologically within each day group.
+
+ Args:
+            qa_set (list): List of dictionaries containing QA data with evidence references
+            sort_by_evidence (bool): If True, sort QA pairs within each day by their earliest evidence position
+
+ Returns:
+ dict: Dictionary where keys are day strings (e.g., 'D1') and values are
+ lists of QA pairs sorted by evidence order within that day
+ """
+ # Initialize a dictionary that automatically creates lists for new keys
+ day_groups = defaultdict(list)
+
+ # Process each QA pair in the input dataset
+ for qa in qa_set:
+ # Extract all unique days referenced in this QA pair's evidence
+ days = set()
+ for evidence in qa["evidence"]:
+ # Split evidence string (e.g., 'D1:3') into day and position parts
+ day = evidence.split(":")[0] # Gets 'D1', 'D2', etc.
+ days.add(day)
+
+ # Add this QA pair to each day group it references
+ for day in days:
+ day_groups[day].append(qa)
+
+ if sort_by_evidence:
+ # Sort QA pairs within each day group by their earliest evidence position
+ for day in day_groups:
+ # Create list of (qa, position) pairs for proper sorting
+ qa_position_pairs = []
+
+ for qa in day_groups[day]:
+ # Find the earliest evidence position for this day
+ earliest_position = None
+ for evidence in qa["evidence"]:
+ if evidence.startswith(day + ":"):
+ try:
+ position = int(evidence.split(":")[1])
+ if earliest_position is None or position < earliest_position:
+ earliest_position = position
+ except (IndexError, ValueError):
+ # Skip invalid evidence format
+ continue
+
+ if earliest_position is not None:
+ qa_position_pairs.append((qa, earliest_position))
+
+ # Sort by evidence position (earliest first)
+ qa_position_pairs = sorted(qa_position_pairs, key=lambda x: x[1])
+ day_groups[day] = [qa for qa, _ in qa_position_pairs]
+
+ return dict(day_groups)
+
+ def convert_locomo_to_temporal_locomo(self, output_dir: str | None = None):
+ """
+ Convert locomo dataset to temporal_locomo dataset format.
+
+ This function processes the original locomo dataset and reorganizes it by days
+ with proper chronological ordering within each day group.
+
+ Args:
+ output_dir: Output directory for the converted dataset.
+ Defaults to evaluation/data/temporal_locomo/
+
+ Returns:
+ str: Path to the converted dataset file
+ """
+ if output_dir is None:
+ output_dir = f"{BASE_DIR}/data/temporal_locomo"
+
+ # Create output directory
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Load original locomo data
+ locomo_data = self.locomo_df.to_dict("records")
+
+ # Process each conversation
+ temporal_data = []
+
+ for conv_id, conversation in enumerate(locomo_data):
+ logger.info(f"Processing conversation {conv_id + 1}/{len(locomo_data)}")
+
+ # Get QA pairs for this conversation
+ qa_set = conversation.get("qa", [])
+
+ # Group and sort QA pairs by day
+ day_groups = self.group_and_sort_qa_by_day(qa_set, sort_by_evidence=False)
+
+ # Create temporal structure for this conversation
+ temporal_conversation = {"conversation_id": f"locomo_exp_user_{conv_id}", "days": {}}
+
+ # Process each day group
+ for day, qa_list in day_groups.items():
+ temporal_conversation["days"][day] = {
+ "day_id": day,
+ "qa_pairs": qa_list,
+ "total_qa_pairs": len(qa_list),
+ }
+
+ temporal_data.append(temporal_conversation)
+
+ # Save the converted dataset
+ output_file = os.path.join(output_dir, "temporal_locomo_qa.json")
+ with open(output_file, "w", encoding="utf-8") as f:
+ json.dump(temporal_data, f, indent=2, ensure_ascii=False)
+
+ logger.info(f"Converted dataset saved to: {output_file}")
+ logger.info(f"Total conversations: {len(temporal_data)}")
+
+ # Log statistics
+ total_qa_pairs = sum(len(conv["qa"]) for conv in locomo_data)
+ total_temporal_qa_pairs = sum(
+ sum(day_data["total_qa_pairs"] for day_data in conv["days"].values())
+ for conv in temporal_data
+ )
+
+ logger.info(f"Original QA pairs: {total_qa_pairs}")
+ logger.info(f"Temporal QA pairs: {total_temporal_qa_pairs}")
+ logger.info("QA pairs may be duplicated across days if they reference multiple days")
+
+ return output_file
diff --git a/evaluation/scripts/temporal_locomo/modules/client_manager.py b/evaluation/scripts/temporal_locomo/modules/client_manager.py
new file mode 100644
index 00000000..f49ab40f
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/modules/client_manager.py
@@ -0,0 +1,186 @@
+"""
+Client management module for handling different memory framework clients.
+"""
+
+import os
+
+from mem0 import MemoryClient
+from zep_cloud.client import Zep
+
+from memos.configs.mem_cube import GeneralMemCubeConfig
+from memos.configs.mem_os import MOSConfig
+from memos.log import get_logger
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_os.main import MOS
+from memos.mem_scheduler.analyzer.scheduler_for_eval import SchedulerForEval
+
+from .base_eval_module import BaseEvalModule
+from .constants import (
+ MEM0_GRAPH_MODEL,
+ MEM0_MODEL,
+ MEMOS_MODEL,
+ MEMOS_SCHEDULER_MODEL,
+ ZEP_MODEL,
+)
+from .prompts import (
+ ANSWER_PROMPT_MEM0,
+ ANSWER_PROMPT_MEMOS,
+ ANSWER_PROMPT_ZEP,
+)
+
+
+logger = get_logger(__name__)
+
+
+class EvalModuleWithClientManager(BaseEvalModule):
+ """
+ Manages different memory framework clients for evaluation.
+ """
+
+ def __init__(self, args):
+ super().__init__(args=args)
+
+ def get_client_for_ingestion(
+ self, frame: str, user_id: str | None = None, version: str = "default"
+ ):
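+        """Build a framework client for the ingestion step.
+
+        For the MemOS frames this creates the user, dumps a fresh memory cube into the
+        ingestion storage directory, and registers it with the MOS instance; the memory
+        scheduler is disabled because ingestion does not need it."""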
+ if frame == ZEP_MODEL:
+ zep = Zep(api_key=os.getenv("ZEP_API_KEY"), base_url="https://api.getzep.com/api/v2")
+ return zep
+
+ elif frame in (MEM0_MODEL, MEM0_GRAPH_MODEL):
+ mem0 = MemoryClient(api_key=os.getenv("MEM0_API_KEY"))
+ mem0.update_project(custom_instructions=self.custom_instructions)
+ return mem0
+ else:
+ if frame not in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
+ raise NotImplementedError(f"Unsupported framework: {frame}")
+
+ # scheduler is not needed in the ingestion step
+ self.mos_config_data["top_k"] = 20
+ self.mos_config_data["enable_mem_scheduler"] = False
+
+ mos_config = MOSConfig(**self.mos_config_data)
+ mos = MOS(mos_config)
+ mos.create_user(user_id=user_id)
+
+ self.mem_cube_config_data["user_id"] = user_id
+ self.mem_cube_config_data["cube_id"] = user_id
+ self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["db_name"] = (
+ f"{user_id.replace('_', '')}{version}"
+ )
+ mem_cube_config = GeneralMemCubeConfig.model_validate(self.mem_cube_config_data)
+ mem_cube = GeneralMemCube(mem_cube_config)
+
+ storage_path = str(self.ingestion_storage_dir / user_id)
+ try:
+ mem_cube.dump(storage_path)
+ except Exception as e:
+ print(f"dumping memory cube: {e!s} already exists, will use it.")
+
+ mos.register_mem_cube(
+ mem_cube_name_or_path=storage_path,
+ mem_cube_id=user_id,
+ user_id=user_id,
+ )
+
+ return mos
+
+ def get_client_from_storage(
+ self, frame: str, user_id: str | None = None, version: str = "default", top_k: int = 20
+ ):
+ """
+ Get a client instance for the specified memory framework.
+
+ Args:
+ frame: Memory framework to use (zep, mem0, mem0_graph, memos, memos_scheduler)
+ user_id: Unique identifier for the user
+ version: Version identifier for result storage
+ top_k: Number of results to retrieve in search queries
+
+ Returns:
+ Client instance for the specified framework
+ """
+ storage_path = str(self.ingestion_storage_dir / user_id)
+
+ if frame == ZEP_MODEL:
+ zep = Zep(api_key=os.getenv("ZEP_API_KEY"), base_url="https://api.getzep.com/api/v2")
+ return zep
+
+        elif frame in (MEM0_MODEL, MEM0_GRAPH_MODEL):
+ mem0 = MemoryClient(api_key=os.getenv("MEM0_API_KEY"))
+ return mem0
+
+ else:
+ if frame not in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
+ raise NotImplementedError(f"Unsupported framework: {frame}")
+
+ if frame == MEMOS_MODEL:
+ self.mos_config_data["enable_mem_scheduler"] = False
+
+ self.mos_config_data["top_k"] = top_k
+ mos_config = MOSConfig(**self.mos_config_data)
+ mos = MOS(mos_config)
+ mos.create_user(user_id=user_id)
+ mos.register_mem_cube(
+ mem_cube_name_or_path=storage_path,
+ mem_cube_id=user_id,
+ user_id=user_id,
+ )
+
+ if frame == MEMOS_SCHEDULER_MODEL:
+ # Configure memory scheduler
+ mos.mem_scheduler.current_mem_cube = mos.mem_cubes[user_id]
+ mos.mem_scheduler.current_mem_cube_id = user_id
+ mos.mem_scheduler.current_user_id = user_id
+
+ # Create SchedulerForEval instance with the same config
+ scheduler_for_eval = SchedulerForEval(config=mos.mem_scheduler.config)
+ # Initialize with the same modules as the original scheduler
+ scheduler_for_eval.initialize_modules(
+ chat_llm=mos.mem_scheduler.chat_llm,
+ process_llm=mos.mem_scheduler.process_llm,
+ db_engine=mos.mem_scheduler.db_engine,
+ )
+ # Set the same context
+ scheduler_for_eval.current_mem_cube = mos.mem_cubes[user_id]
+ scheduler_for_eval.current_mem_cube_id = user_id
+ scheduler_for_eval.current_user_id = user_id
+
+ # Replace the original scheduler
+ mos.mem_scheduler = scheduler_for_eval
+
+ return mos
+
+ def locomo_response(self, frame, llm_client, context: str, question: str) -> str:
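+        """Build the frame-specific answer prompt, query the chat model, and return
+        its answer (empty answers are counted as response failures)."""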
+ if frame == ZEP_MODEL:
+ prompt = ANSWER_PROMPT_ZEP.format(
+ context=context,
+ question=question,
+ )
+ elif frame in (MEM0_MODEL, MEM0_GRAPH_MODEL):
+ prompt = ANSWER_PROMPT_MEM0.format(
+ context=context,
+ question=question,
+ )
+ elif frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
+ prompt = ANSWER_PROMPT_MEMOS.format(
+ context=context,
+ question=question,
+ )
+ else:
+ raise NotImplementedError()
+ response = llm_client.chat.completions.create(
+ model=self.openai_chat_model,
+ messages=[
+ {"role": "system", "content": prompt},
+ ],
+ temperature=0,
+ )
+
+ result = response.choices[0].message.content or ""
+
+ if result == "":
+ with self.stats_lock:
+ self.stats[self.frame][self.version]["response_stats"]["response_failure"] += 1
+ self.stats[self.frame][self.version]["response_stats"]["response_count"] += 1
+ return result
diff --git a/evaluation/scripts/temporal_locomo/modules/constants.py b/evaluation/scripts/temporal_locomo/modules/constants.py
new file mode 100644
index 00000000..51ad7c72
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/modules/constants.py
@@ -0,0 +1,19 @@
+import sys
+
+from pathlib import Path
+
+from memos.log import get_logger
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent.parent
+sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
+
+logger = get_logger(__name__)
+
+
+ZEP_MODEL = "zep"
+MEM0_MODEL = "mem0"
+MEM0_GRAPH_MODEL = "mem0_graph"
+MEMOS_MODEL = "memos"
+MEMOS_SCHEDULER_MODEL = "memos_scheduler"
diff --git a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py b/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py
new file mode 100644
index 00000000..c824fe5f
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py
@@ -0,0 +1,553 @@
+import json
+import time
+import traceback
+
+from collections import defaultdict
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+from openai import OpenAI
+from tqdm import tqdm
+
+from memos.log import get_logger
+
+from .client_manager import EvalModuleWithClientManager
+from .constants import (
+ MEMOS_MODEL,
+ MEMOS_SCHEDULER_MODEL,
+)
+from .prompts import (
+ CONTEXT_ANSWERABILITY_PROMPT,
+ SEARCH_PROMPT_MEM0,
+ SEARCH_PROMPT_MEM0_GRAPH,
+ SEARCH_PROMPT_MEMOS,
+ SEARCH_PROMPT_ZEP,
+)
+from .utils import filter_memory_data
+
+
+if TYPE_CHECKING:
+ from memos.mem_os.main import MOS
+logger = get_logger(__name__)
+
+
+class LocomoEvalModelModules(EvalModuleWithClientManager):
+ """
+    Main evaluation module: combines client management with the search and
+    per-conversation processing logic for the different memory frameworks.
+ """
+
+ def __init__(self, args):
+ super().__init__(args=args)
+ self.pre_context_cache = {}
+
+ def analyze_context_answerability(self, context, query, gold_answer, oai_client):
+ """
+ Analyze whether the given context can answer the query.
+
+ Args:
+ context: The context string to analyze
+            query: The query string
+            gold_answer: The ground-truth answer used to guide the judgment
+            oai_client: OpenAI client for LLM analysis
+
+ Returns:
+ bool: True if context can answer the query, False otherwise
+ """
+ try:
+ prompt = CONTEXT_ANSWERABILITY_PROMPT.format(
+ context=context, question=query, gold_answer=str(gold_answer)
+ )
+
+ response = oai_client.chat.completions.create(
+ model="gpt-4o-mini",
+ messages=[{"role": "user", "content": prompt}],
+ temperature=0.1,
+ max_tokens=10,
+ )
+
+ answer = response.choices[0].message.content.strip().upper()
+ return answer == "YES"
+ except Exception as e:
+ logger.error(f"Error analyzing context answerability: {e}")
+ return False
+
+ def mem0_search(self, client, query, speaker_a_user_id, speaker_b_user_id, top_k=20):
+ """
+ Search memories using the mem0 framework.
+
+ Args:
+ client: mem0 client instance
+ query: Search query string
+ speaker_a_user_id: User ID for first speaker
+ speaker_b_user_id: User ID for second speaker
+ top_k: Number of results to retrieve
+
+ Returns:
+ Tuple containing formatted context and search duration in milliseconds
+ """
+ start = time.time()
+ search_speaker_a_results = client.search(
+ query=query,
+ top_k=top_k,
+ user_id=speaker_a_user_id,
+ output_format="v1.1",
+ version="v2",
+ filters={"AND": [{"user_id": f"{speaker_a_user_id}"}, {"run_id": "*"}]},
+ )
+ search_speaker_b_results = client.search(
+ query=query,
+ top_k=top_k,
+ user_id=speaker_b_user_id,
+ output_format="v1.1",
+ version="v2",
+ filters={"AND": [{"user_id": f"{speaker_b_user_id}"}, {"run_id": "*"}]},
+ )
+
+ # Format speaker A memories
+ search_speaker_a_memory = [
+ {
+ "memory": memory["memory"],
+ "timestamp": memory["created_at"],
+ "score": round(memory["score"], 2),
+ }
+ for memory in search_speaker_a_results["results"]
+ ]
+
+ search_speaker_a_memory = [
+ [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_a_memory]
+ ]
+
+ # Format speaker B memories
+ search_speaker_b_memory = [
+ {
+ "memory": memory["memory"],
+ "timestamp": memory["created_at"],
+ "score": round(memory["score"], 2),
+ }
+ for memory in search_speaker_b_results["results"]
+ ]
+
+ search_speaker_b_memory = [
+ [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_b_memory]
+ ]
+
+ # Create context using template
+ context = SEARCH_PROMPT_MEM0.format(
+ speaker_1_user_id=speaker_a_user_id.split("_")[0],
+ speaker_1_memories=json.dumps(search_speaker_a_memory, indent=4),
+ speaker_2_user_id=speaker_b_user_id.split("_")[0],
+ speaker_2_memories=json.dumps(search_speaker_b_memory, indent=4),
+ )
+
+ duration_ms = (time.time() - start) * 1000
+ return context, duration_ms
+
+ def memos_search(self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None):
+ """
+ Search memories using the memos framework.
+
+ Args:
+ client: memos client instance
+ query: Search query string
+ conv_id: Conversation ID
+ speaker_a: First speaker identifier
+ speaker_b: Second speaker identifier
+ reversed_client: Client instance for reversed speaker context
+
+ Returns:
+ Tuple containing formatted context and search duration in milliseconds
+ """
+ start = time.time()
+ # Search memories for speaker A
+ search_a_results = client.search(
+ query=query,
+ user_id=conv_id + "_speaker_a",
+ )
+ filtered_search_a_results = filter_memory_data(search_a_results)["text_mem"][0]["memories"]
+ speaker_a_context = ""
+ for item in filtered_search_a_results:
+ speaker_a_context += f"{item['memory']}\n"
+
+ # Search memories for speaker B
+ search_b_results = reversed_client.search(
+ query=query,
+ user_id=conv_id + "_speaker_b",
+ )
+ filtered_search_b_results = filter_memory_data(search_b_results)["text_mem"][0]["memories"]
+ speaker_b_context = ""
+ for item in filtered_search_b_results:
+ speaker_b_context += f"{item['memory']}\n"
+
+ # Create context using template
+ context = SEARCH_PROMPT_MEMOS.format(
+ speaker_1=speaker_a,
+ speaker_1_memories=speaker_a_context,
+ speaker_2=speaker_b,
+ speaker_2_memories=speaker_b_context,
+ )
+
+ duration_ms = (time.time() - start) * 1000
+ return context, duration_ms
+
+ def memos_scheduler_search(
+ self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None
+ ):
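+        """Search both speakers' memories via the scheduler's search_for_eval and
+        assemble the MemOS context prompt; returns (context, duration_ms)."""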
+ start = time.time()
+ client: MOS = client
+
+ # Search for speaker A
+ search_a_results = client.mem_scheduler.search_for_eval(
+ query=query,
+ user_id=conv_id + "_speaker_a",
+ top_k=client.config.top_k,
+ scheduler_flag=self.scheduler_flag,
+ )
+
+ # Search for speaker B
+ search_b_results = reversed_client.mem_scheduler.search_for_eval(
+ query=query,
+ user_id=conv_id + "_speaker_b",
+ top_k=client.config.top_k,
+ scheduler_flag=self.scheduler_flag,
+ )
+
+ speaker_a_context = ""
+ for item in search_a_results:
+ speaker_a_context += f"{item}\n"
+
+ speaker_b_context = ""
+ for item in search_b_results:
+ speaker_b_context += f"{item}\n"
+
+ context = SEARCH_PROMPT_MEMOS.format(
+ speaker_1=speaker_a,
+ speaker_1_memories=speaker_a_context,
+ speaker_2=speaker_b,
+ speaker_2_memories=speaker_b_context,
+ )
+
+        logger.info(f'query: "{query[:100]}", context: "{context[:100]}"')
+ duration_ms = (time.time() - start) * 1000
+
+ return context, duration_ms
+
+ def mem0_graph_search(self, client, query, speaker_a_user_id, speaker_b_user_id, top_k=20):
+ start = time.time()
+ search_speaker_a_results = client.search(
+ query=query,
+ top_k=top_k,
+ user_id=speaker_a_user_id,
+ output_format="v1.1",
+ version="v2",
+ enable_graph=True,
+ filters={"AND": [{"user_id": f"{speaker_a_user_id}"}, {"run_id": "*"}]},
+ )
+ search_speaker_b_results = client.search(
+ query=query,
+ top_k=top_k,
+ user_id=speaker_b_user_id,
+ output_format="v1.1",
+ version="v2",
+ enable_graph=True,
+ filters={"AND": [{"user_id": f"{speaker_b_user_id}"}, {"run_id": "*"}]},
+ )
+
+ search_speaker_a_memory = [
+ {
+ "memory": memory["memory"],
+ "timestamp": memory["created_at"],
+ "score": round(memory["score"], 2),
+ }
+ for memory in search_speaker_a_results["results"]
+ ]
+
+ search_speaker_a_memory = [
+ [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_a_memory]
+ ]
+
+ search_speaker_b_memory = [
+ {
+ "memory": memory["memory"],
+ "timestamp": memory["created_at"],
+ "score": round(memory["score"], 2),
+ }
+ for memory in search_speaker_b_results["results"]
+ ]
+
+ search_speaker_b_memory = [
+ [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_b_memory]
+ ]
+
+ search_speaker_a_graph = [
+ {
+ "source": relation["source"],
+ "relationship": relation["relationship"],
+ "target": relation["target"],
+ }
+ for relation in search_speaker_a_results["relations"]
+ ]
+
+ search_speaker_b_graph = [
+ {
+ "source": relation["source"],
+ "relationship": relation["relationship"],
+ "target": relation["target"],
+ }
+ for relation in search_speaker_b_results["relations"]
+ ]
+ context = SEARCH_PROMPT_MEM0_GRAPH.format(
+ speaker_1_user_id=speaker_a_user_id.split("_")[0],
+ speaker_1_memories=json.dumps(search_speaker_a_memory, indent=4),
+ speaker_1_graph_memories=json.dumps(search_speaker_a_graph, indent=4),
+ speaker_2_user_id=speaker_b_user_id.split("_")[0],
+ speaker_2_memories=json.dumps(search_speaker_b_memory, indent=4),
+ speaker_2_graph_memories=json.dumps(search_speaker_b_graph, indent=4),
+ )
+ print(query, context)
+ duration_ms = (time.time() - start) * 1000
+ return context, duration_ms
+
+ def zep_search(self, client, query, group_id, top_k=20):
+ start = time.time()
+ nodes_result = client.graph.search(
+ query=query,
+ group_id=group_id,
+ scope="nodes",
+ reranker="rrf",
+ limit=top_k,
+ )
+ edges_result = client.graph.search(
+ query=query,
+ group_id=group_id,
+ scope="edges",
+ reranker="cross_encoder",
+ limit=top_k,
+ )
+
+ nodes = nodes_result.nodes
+ edges = edges_result.edges
+
+ facts = [f" - {edge.fact} (event_time: {edge.valid_at})" for edge in edges]
+ entities = [f" - {node.name}: {node.summary}" for node in nodes]
+
+ context = SEARCH_PROMPT_ZEP.format(facts="\n".join(facts), entities="\n".join(entities))
+
+ duration_ms = (time.time() - start) * 1000
+
+ return context, duration_ms
+
+ def search_query(self, client, query, metadata, frame, reversed_client=None, top_k=20):
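+        """Dispatch the query to the search method of the configured framework and
+        return (context, duration_ms)."""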
+ conv_id = metadata.get("conv_id")
+ speaker_a = metadata.get("speaker_a")
+ speaker_b = metadata.get("speaker_b")
+ speaker_a_user_id = metadata.get("speaker_a_user_id")
+ speaker_b_user_id = metadata.get("speaker_b_user_id")
+
+ if frame == "zep":
+ context, duration_ms = self.zep_search(client, query, conv_id, top_k)
+ elif frame == "mem0":
+ context, duration_ms = self.mem0_search(
+ client, query, speaker_a_user_id, speaker_b_user_id, top_k
+ )
+ elif frame == "mem0_graph":
+ context, duration_ms = self.mem0_graph_search(
+ client, query, speaker_a_user_id, speaker_b_user_id, top_k
+ )
+ elif frame == "memos":
+ context, duration_ms = self.memos_search(
+ client, query, conv_id, speaker_a, speaker_b, reversed_client
+ )
+ elif frame == "memos_scheduler":
+ context, duration_ms = self.memos_scheduler_search(
+ client, query, conv_id, speaker_a, speaker_b, reversed_client
+ )
+ else:
+ raise NotImplementedError()
+
+ return context, duration_ms
+
+ def _initialize_conv_stats(self):
+ """Create a fresh statistics dictionary for a conversation."""
+ return {
+ "total_queries": 0,
+ "can_answer_count": 0,
+ "cannot_answer_count": 0,
+ "answer_hit_rate": 0.0,
+ "response_failure": 0,
+ "response_count": 0,
+ }
+
+ def _build_day_groups(self, temporal_conv):
+ """Build mapping day_id -> qa_pairs from a temporal conversation dict."""
+ day_groups = {}
+ for day_id, day_data in temporal_conv.get("days", {}).items():
+ day_groups[day_id] = day_data.get("qa_pairs", [])
+ return day_groups
+
+ def _build_metadata(self, speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id):
+ """Assemble metadata for downstream calls."""
+ return {
+ "speaker_a": speaker_a,
+ "speaker_b": speaker_b,
+ "speaker_a_user_id": speaker_a_user_id,
+ "speaker_b_user_id": speaker_b_user_id,
+ "conv_id": conv_id,
+ }
+
+ def _get_clients(self, frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k):
+ """Return (client, reversed_client) according to the target frame."""
+ reversed_client = None
+ if frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
+ client = self.get_client_from_storage(frame, speaker_a_user_id, version, top_k=top_k)
+ reversed_client = self.get_client_from_storage(
+ frame, speaker_b_user_id, version, top_k=top_k
+ )
+ else:
+ client = self.get_client_from_storage(frame, conv_id, version)
+ return client, reversed_client
+
+ def _save_conv_stats(self, conv_id, frame, version, conv_stats, conv_stats_path):
+ """Persist per-conversation stats to disk."""
+ conv_stats_data = {
+ "conversation_id": conv_id,
+ "frame": frame,
+ "version": version,
+ "statistics": conv_stats,
+ "timestamp": str(datetime.now()),
+ }
+ with open(conv_stats_path, "w") as fw:
+ json.dump(conv_stats_data, fw, indent=2, ensure_ascii=False)
+ print(f"Saved conversation stats for {conv_id} to {conv_stats_path}")
+
+ def _write_user_search_results(self, user_search_path, search_results, conv_id):
+ """Write per-user search results to a temporary JSON file."""
+ with open(user_search_path, "w") as fw:
+ json.dump(dict(search_results), fw, indent=2)
+ print(f"Save search results {conv_id}")
+
+ def process_user(self, conv_id, locomo_df, frame, version, top_k=20):
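+        """Process all QA pairs of one conversation, grouped by day.
+
+        ``conv_id`` is an integer index into the locomo dataframe on entry and is
+        reassigned to the temporal conversation id once the temporal data is read.
+        Returns (search_results, response_results) keyed by that conversation id."""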
+ user_search_path = self.result_dir / f"tmp/{frame}_locomo_search_results_{conv_id}.json"
+ user_search_path.parent.mkdir(exist_ok=True, parents=True)
+ search_results = defaultdict(list)
+ response_results = defaultdict(list)
+ conv_stats_path = self.stats_dir / f"{frame}_{version}_conv_{conv_id}_stats.json"
+
+ conversation = locomo_df["conversation"].iloc[conv_id]
+ speaker_a = conversation.get("speaker_a", "speaker_a")
+ speaker_b = conversation.get("speaker_b", "speaker_b")
+
+ # Use temporal_locomo data if available, otherwise fall back to original locomo data
+ temporal_conv = self.temporal_locomo_data[conv_id]
+ conv_id = temporal_conv["conversation_id"]
+ speaker_a_user_id = f"{conv_id}_speaker_a"
+ speaker_b_user_id = f"{conv_id}_speaker_b"
+
+ # Process temporal data by days
+ day_groups = {}
+ for day_id, day_data in temporal_conv["days"].items():
+ day_groups[day_id] = day_data["qa_pairs"]
+
+ # Initialize conversation-level statistics
+ conv_stats = self._initialize_conv_stats()
+
+ metadata = self._build_metadata(
+ speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id
+ )
+
+ client, reversed_client = self._get_clients(
+ frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k
+ )
+
+ oai_client = OpenAI(api_key=self.openai_api_key, base_url=self.openai_base_url)
+
+ with self.stats_lock:
+ self.pre_context_cache[conv_id] = None
+
+ def process_qa(qa):
+ return self._process_single_qa(
+ qa,
+ client=client,
+ reversed_client=reversed_client,
+ metadata=metadata,
+ frame=frame,
+ version=version,
+ conv_id=conv_id,
+ conv_stats_path=conv_stats_path,
+ oai_client=oai_client,
+ top_k=top_k,
+ conv_stats=conv_stats,
+ )
+
+        # Process the QA pairs of each day group in order
+        conv_stats["theoretical_total_queries"] = 0
+        conv_stats["processing_failure_count"] = 0
+        for day, qa_list in day_groups.items():
+            conv_stats["theoretical_total_queries"] += len(qa_list) - 1
+            print(f"Processing user {conv_id} day {day}")
+ for qa in tqdm(qa_list, desc=f"Processing user {conv_id} day {day}"):
+ try:
+ result = process_qa(qa)
+ except Exception as e:
+ logger.error(f"Error: {e}. traceback: {traceback.format_exc()}")
+ conv_stats["processing_failure_count"] += 1
+ continue
+ if result:
+ context_preview = (
+ result["search_context"][:20] + "..."
+ if result["search_context"]
+ else "No context"
+ )
+ if "can_answer" in result:
+ logger.info("Print can_answer case")
+ logger.info(
+ {
+ "question": result["question"][:100],
+ "pre context can answer": result["can_answer"],
+ "answer": result["answer"][:100],
+ "golden_answer": result["golden_answer"],
+ "search_context": context_preview[:100],
+ "search_duration_ms": result["search_duration_ms"],
+ }
+ )
+
+ search_results[conv_id].append(
+ {
+ "question": result["question"],
+ "context": result["search_context"],
+ "search_duration_ms": result["search_duration_ms"],
+ }
+ )
+ response_results[conv_id].append(result)
+
+ logger.warning(
+ f"Finished processing user {conv_id} day {day}, data_length: {len(qa_list)}"
+ )
+
+        # Persist this user's search results using the shared helper
+        self._write_user_search_results(user_search_path, search_results, conv_id)
+
+ # Dump stats after processing each user
+ self.save_stats()
+
+ return search_results, response_results
+
+ def process_user_wrapper(self, args):
+ """
+ Wraps the process_user function to support parallel execution and error handling.
+
+ Args:
+ args: Tuple containing parameters for process_user
+
+ Returns:
+ tuple: Contains user results or error information
+ """
+ idx, locomo_df, frame, version, top_k = args
+ try:
+ print(f"Processing user {idx}...")
+ user_search_results, user_response_results = self.process_user(
+ idx, locomo_df, frame, version, top_k
+ )
+ return (user_search_results, user_response_results, None)
+ except Exception as e:
+ return (None, None, (idx, e, traceback.format_exc()))
diff --git a/evaluation/scripts/temporal_locomo/modules/prompts.py b/evaluation/scripts/temporal_locomo/modules/prompts.py
new file mode 100644
index 00000000..c88a8ff2
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/modules/prompts.py
@@ -0,0 +1,219 @@
+CUSTOM_INSTRUCTIONS = """
+Generate personal memories that follow these guidelines:
+
+1. Each memory should be self-contained with complete context, including:
+ - The person's name, do not use "user" while creating memories
+ - Personal details (career aspirations, hobbies, life circumstances)
+ - Emotional states and reactions
+ - Ongoing journeys or future plans
+ - Specific dates when events occurred
+
+2. Include meaningful personal narratives focusing on:
+ - Identity and self-acceptance journeys
+ - Family planning and parenting
+ - Creative outlets and hobbies
+ - Mental health and self-care activities
+ - Career aspirations and education goals
+ - Important life events and milestones
+
+3. Make each memory rich with specific details rather than general statements
+ - Include timeframes (exact dates when possible)
+ - Name specific activities (e.g., "charity race for mental health" rather than just "exercise")
+ - Include emotional context and personal growth elements
+
+4. Extract memories only from user messages, not incorporating assistant responses
+
+5. Format each memory as a paragraph with a clear narrative structure that captures the person's experience, challenges, and aspirations
+"""
+
+SEARCH_PROMPT_ZEP = """
+FACTS and ENTITIES represent relevant context to the current conversation.
+
+# These are the most relevant facts for the conversation along with the datetime of the event that the fact refers to.
+If a fact mentions something happening a week ago, then the datetime will be the date time of last week and not the datetime
+of when the fact was stated.
+Timestamps in memories represent the actual time the event occurred, not the time the event was mentioned in a message.
+
+
+{facts}
+
+
+# These are the most relevant entities
+# ENTITY_NAME: entity summary
+
+{entities}
+
+"""
+
+SEARCH_PROMPT_MEM0 = """Memories for user {speaker_1_user_id}:
+
+ {speaker_1_memories}
+
+ Memories for user {speaker_2_user_id}:
+
+ {speaker_2_memories}
+"""
+
+SEARCH_PROMPT_MEM0_GRAPH = """Memories for user {speaker_1_user_id}:
+
+ {speaker_1_memories}
+
+ Relations for user {speaker_1_user_id}:
+
+ {speaker_1_graph_memories}
+
+ Memories for user {speaker_2_user_id}:
+
+ {speaker_2_memories}
+
+ Relations for user {speaker_2_user_id}:
+
+ {speaker_2_graph_memories}
+"""
+
+SEARCH_PROMPT_MEMOS = """Memories for user {speaker_1}:
+
+ {speaker_1_memories}
+
+ Memories for user {speaker_2}:
+
+ {speaker_2_memories}
+"""
+
+
+ANSWER_PROMPT_MEM0 = """
+ You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories.
+
+ # CONTEXT:
+ You have access to memories from two speakers in a conversation. These memories contain
+ timestamped information that may be relevant to answering the question.
+
+ # INSTRUCTIONS:
+ 1. Carefully analyze all provided memories from both speakers
+ 2. Pay special attention to the timestamps to determine the answer
+ 3. If the question asks about a specific event or fact, look for direct evidence in the memories
+ 4. If the memories contain contradictory information, prioritize the most recent memory
+ 5. If there is a question about time references (like "last year", "two months ago", etc.),
+ calculate the actual date based on the memory timestamp. For example, if a memory from
+ 4 May 2022 mentions "went to India last year," then the trip occurred in 2021.
+ 6. Always convert relative time references to specific dates, months, or years. For example,
+ convert "last year" to "2022" or "two months ago" to "March 2023" based on the memory
+ timestamp. Ignore the reference while answering the question.
+ 7. Focus only on the content of the memories from both speakers. Do not confuse character
+ names mentioned in memories with the actual users who created those memories.
+ 8. The answer should be less than 5-6 words.
+
+ # APPROACH (Think step by step):
+ 1. First, examine all memories that contain information related to the question
+ 2. Examine the timestamps and content of these memories carefully
+ 3. Look for explicit mentions of dates, times, locations, or events that answer the question
+ 4. If the answer requires calculation (e.g., converting relative time references), show your work
+ 5. Formulate a precise, concise answer based solely on the evidence in the memories
+ 6. Double-check that your answer directly addresses the question asked
+ 7. Ensure your final answer is specific and avoids vague time references
+
+ {context}
+
+ Question: {question}
+
+ Answer:
+ """
+
+
+ANSWER_PROMPT_ZEP = """
+ You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories.
+
+ # CONTEXT:
+ You have access to memories from a conversation. These memories contain
+ timestamped information that may be relevant to answering the question.
+
+ # INSTRUCTIONS:
+ 1. Carefully analyze all provided memories
+ 2. Pay special attention to the timestamps to determine the answer
+ 3. If the question asks about a specific event or fact, look for direct evidence in the memories
+ 4. If the memories contain contradictory information, prioritize the most recent memory
+ 5. If there is a question about time references (like "last year", "two months ago", etc.),
+ calculate the actual date based on the memory timestamp. For example, if a memory from
+ 4 May 2022 mentions "went to India last year," then the trip occurred in 2021.
+ 6. Always convert relative time references to specific dates, months, or years. For example,
+ convert "last year" to "2022" or "two months ago" to "March 2023" based on the memory
+ timestamp. Ignore the reference while answering the question.
+ 7. Focus only on the content of the memories. Do not confuse character
+ names mentioned in memories with the actual users who created those memories.
+ 8. The answer should be less than 5-6 words.
+
+ # APPROACH (Think step by step):
+ 1. First, examine all memories that contain information related to the question
+ 2. Examine the timestamps and content of these memories carefully
+ 3. Look for explicit mentions of dates, times, locations, or events that answer the question
+ 4. If the answer requires calculation (e.g., converting relative time references), show your work
+ 5. Formulate a precise, concise answer based solely on the evidence in the memories
+ 6. Double-check that your answer directly addresses the question asked
+ 7. Ensure your final answer is specific and avoids vague time references
+
+ Context:
+
+ {context}
+
+ Question: {question}
+ Answer:
+ """
+
+ANSWER_PROMPT_MEMOS = """
+ You are a knowledgeable and helpful AI assistant.
+
+ # CONTEXT:
+ You have access to memories from two speakers in a conversation. These memories contain
+ timestamped information that may be relevant to answering the question.
+
+ # INSTRUCTIONS:
+ 1. Carefully analyze all provided memories. Synthesize information across different entries if needed to form a complete answer.
+ 2. Pay close attention to the timestamps to determine the answer. If memories contain contradictory information, the **most recent memory** is the source of truth.
+ 3. If the question asks about a specific event or fact, look for direct evidence in the memories.
+ 4. Your answer must be grounded in the memories. However, you may use general world knowledge to interpret or complete information found within a memory (e.g., identifying a landmark mentioned by description).
+ 5. If the question involves time references (like "last year", "two months ago", etc.), you **must** calculate the actual date based on the memory's timestamp. For example, if a memory from 4 May 2022 mentions "went to India last year," then the trip occurred in 2021.
+ 6. Always convert relative time references to specific dates, months, or years in your final answer.
+ 7. Do not confuse character names mentioned in memories with the actual users who created them.
+ 8. The answer must be brief (under 5-6 words) and direct, with no extra description.
+
+ # APPROACH (Think step by step):
+ 1. First, examine all memories that contain information related to the question.
+ 2. Synthesize findings from multiple memories if a single entry is insufficient.
+ 3. Examine timestamps and content carefully, looking for explicit dates, times, locations, or events.
+ 4. If the answer requires calculation (e.g., converting relative time references), perform the calculation.
+ 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge).
+ 6. Double-check that your answer directly addresses the question asked and adheres to all instructions.
+ 7. Ensure your final answer is specific and avoids vague time references.
+
+ {context}
+
+ Question: {question}
+
+ Answer:
+ """
+
+CONTEXT_ANSWERABILITY_PROMPT = """
+You are an AI assistant that analyzes whether given context can answer a specific question, considering the ground-truth answer.
+
+# TASK:
+Analyze the provided context and determine if it contains sufficient information to answer the given question. Use the provided ground-truth answer to guide your judgment: if the context contains the necessary evidence to derive that answer (explicitly or via direct inference), respond YES; otherwise respond NO.
+
+# INSTRUCTIONS:
+1. Carefully examine the context provided
+2. Identify if the context contains information directly related to the question
+3. Determine if the information is sufficient to provide a complete answer that matches the ground-truth
+4. Consider both explicit mentions and straightforward implications present in the context
+5. Return only "YES" if the context can yield the ground-truth answer, "NO" if it cannot
+
+# CONTEXT:
+{context}
+
+# QUESTION:
+{question}
+
+# GROUND_TRUTH_ANSWER:
+{gold_answer}
+
+# ANALYSIS:
+Can this context answer the question and support the ground-truth answer? (YES/NO):
+"""
diff --git a/evaluation/scripts/temporal_locomo/modules/schemas.py b/evaluation/scripts/temporal_locomo/modules/schemas.py
new file mode 100644
index 00000000..e5872c35
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/modules/schemas.py
@@ -0,0 +1,141 @@
+from enum import Enum
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+
+class ContextUpdateMethod(Enum):
+ """Enumeration for context update methods"""
+
+ DIRECT = "direct" # Directly update with current context
+ TEMPLATE = "chat_history" # Update using template with history queries and answers
+
+
+class RecordingCase(BaseModel):
+ """
+ Data structure for recording evaluation cases in temporal locomo evaluation.
+
+ This schema represents a single evaluation case containing conversation history,
+ context information, memory data, and evaluation results.
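+
+ Example (illustrative values only):
+ >>> case = RecordingCase(conv_id="conv_1", query="Where did Alice travel?",
+ ... answer="India, 2021", category=2)
+ >>> case.can_answer is None
+ True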
+ """
+
+ # Conversation identification
+ conv_id: str = Field(description="Conversation identifier for this evaluation case")
+
+ # Conversation history and context
+ history_queries: list[str] = Field(
+ default_factory=list, description="List of previous queries in the conversation history"
+ )
+
+ context: str = Field(
+ default="",
+ description="Current search context retrieved from memory systems for answering the query",
+ )
+
+ pre_context: str | None = Field(
+ default=None,
+ description="Previous context from the last query, used for answerability analysis",
+ )
+
+ # Query and answer information
+ query: str = Field(description="The current question/query being evaluated")
+
+ answer: str = Field(description="The generated answer for the query")
+
+ # Memory data
+ memories: list[Any] = Field(
+ default_factory=list,
+ description="Current memories retrieved from the memory system for this query",
+ )
+
+ pre_memories: list[Any] | None = Field(
+ default=None, description="Previous memories from the last query, used for comparison"
+ )
+
+ # Evaluation metrics
+ can_answer: bool | None = Field(
+ default=None,
+ description="Whether the context can answer the query (only for memos_scheduler frame)",
+ )
+
+ can_answer_reason: str | None = Field(
+ default=None, description="Reasoning for the can_answer decision"
+ )
+
+ # Additional metadata
+ category: int | None = Field(
+ default=None, description="Category of the query (1-4, where 5 is filtered out)"
+ )
+
+ golden_answer: str | None = Field(
+ default=None, description="Ground truth answer for evaluation"
+ )
+
+ search_duration_ms: float | None = Field(
+ default=None, description="Time taken for memory search in milliseconds"
+ )
+
+ response_duration_ms: float | None = Field(
+ default=None, description="Time taken for response generation in milliseconds"
+ )
+
+ can_answer_duration_ms: float | None = Field(
+ default=None, description="Time taken for answerability analysis in milliseconds"
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Convert the RecordingCase to a dictionary for serialization.
+
+ Returns:
+ Dict[str, Any]: Dictionary representation of the RecordingCase
+ """
+ return self.dict()
+
+ def to_json(self, indent: int = 2) -> str:
+ """
+ Convert the RecordingCase to a JSON string.
+
+ Args:
+ indent: JSON indentation level
+
+ Returns:
+ str: JSON string representation of the RecordingCase
+ """
+ return self.json(indent=indent, ensure_ascii=False)
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "RecordingCase":
+ """
+ Create a RecordingCase from a dictionary.
+
+ Args:
+ data: Dictionary containing RecordingCase data
+
+ Returns:
+ RecordingCase: New instance created from the dictionary
+ """
+ return cls(**data)
+
+ @classmethod
+ def from_json(cls, json_str: str) -> "RecordingCase":
+ """
+ Create a RecordingCase from a JSON string.
+
+ Args:
+ json_str: JSON string containing RecordingCase data
+
+ Returns:
+ RecordingCase: New instance created from the JSON string
+ """
+ import json
+
+ data = json.loads(json_str)
+ return cls.from_dict(data)
+
+ class Config:
+ """Pydantic configuration"""
+
+ extra = "allow" # Allow additional fields not defined in the schema
+ validate_assignment = True # Validate on assignment
+ use_enum_values = True # Use enum values instead of enum names
diff --git a/evaluation/scripts/temporal_locomo/modules/utils.py b/evaluation/scripts/temporal_locomo/modules/utils.py
new file mode 100644
index 00000000..215bc425
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/modules/utils.py
@@ -0,0 +1,296 @@
+import json
+
+from pathlib import Path
+
+from .schemas import RecordingCase
+
+
+def filter_memory_data(memories_data):
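+ """Strip embeddings and other non-serializable attributes from retrieved memory data.
+
+ Handles both the newer structure ("memories" as a list of TextualMemoryItem objects)
+ and the older nodes/edges dictionary structure, returning a plain dict that is safe
+ to serialize to JSON. Entries other than "text_mem" are passed through unchanged.
+ """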
+ filtered_data = {}
+ for key, value in memories_data.items():
+ if key == "text_mem":
+ filtered_data[key] = []
+ for mem_group in value:
+ # Check if it's the new data structure (list of TextualMemoryItem objects)
+ if "memories" in mem_group and isinstance(mem_group["memories"], list):
+ # New data structure: directly a list of TextualMemoryItem objects
+ filtered_memories = []
+ for memory_item in mem_group["memories"]:
+ # Create filtered dictionary
+ filtered_item = {
+ "id": memory_item.id,
+ "memory": memory_item.memory,
+ "metadata": {},
+ }
+ # Filter metadata, excluding embedding
+ if hasattr(memory_item, "metadata") and memory_item.metadata:
+ for attr_name in dir(memory_item.metadata):
+ if not attr_name.startswith("_") and attr_name != "embedding":
+ attr_value = getattr(memory_item.metadata, attr_name)
+ if not callable(attr_value):
+ filtered_item["metadata"][attr_name] = attr_value
+ filtered_memories.append(filtered_item)
+
+ filtered_group = {
+ "cube_id": mem_group.get("cube_id", ""),
+ "memories": filtered_memories,
+ }
+ filtered_data[key].append(filtered_group)
+ else:
+ # Old data structure: dictionary with nodes and edges
+ filtered_group = {
+ "memories": {"nodes": [], "edges": mem_group["memories"].get("edges", [])}
+ }
+ for node in mem_group["memories"].get("nodes", []):
+ filtered_node = {
+ "id": node.get("id"),
+ "memory": node.get("memory"),
+ "metadata": {
+ k: v
+ for k, v in node.get("metadata", {}).items()
+ if k != "embedding"
+ },
+ }
+ filtered_group["memories"]["nodes"].append(filtered_node)
+ filtered_data[key].append(filtered_group)
+ else:
+ filtered_data[key] = value
+ return filtered_data
+
+
+def save_recording_cases(
+ cases: list[RecordingCase], output_dir: str | Path, filename: str = "recording_cases.json"
+) -> Path:
+ """
+ Save a list of RecordingCase objects to a JSON file.
+
+ Args:
+ cases: List of RecordingCase objects to save
+ output_dir: Directory to save the file
+ filename: Name of the output file (default: "recording_cases.json")
+
+ Returns:
+ Path: Path to the saved file
+ """
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ file_path = output_dir / filename
+
+ # Convert cases to dictionaries for JSON serialization
+ cases_data = [case.to_dict() for case in cases]
+
+ with open(file_path, "w", encoding="utf-8") as f:
+ json.dump(cases_data, f, indent=2, ensure_ascii=False)
+
+ return file_path
+
+
+def load_recording_cases(file_path: str | Path) -> list[RecordingCase]:
+ """
+ Load RecordingCase objects from a JSON file.
+
+ Args:
+ file_path: Path to the JSON file containing RecordingCase data
+
+ Returns:
+ List[RecordingCase]: List of RecordingCase objects loaded from the file
+ """
+ file_path = Path(file_path)
+
+ with open(file_path, encoding="utf-8") as f:
+ cases_data = json.load(f)
+
+ return [RecordingCase.from_dict(case_data) for case_data in cases_data]
+
+
+def save_evaluation_cases(
+ can_answer_cases: list[RecordingCase],
+ cannot_answer_cases: list[RecordingCase],
+ output_dir: str | Path,
+ frame: str = "default",
+ version: str = "default",
+) -> dict[str, Path]:
+ """
+ Save both can_answer_cases and cannot_answer_cases to separate JSON files.
+
+ Args:
+ can_answer_cases: List of cases that can be answered
+ cannot_answer_cases: List of cases that cannot be answered
+ output_dir: Directory to save the files
+ frame: Framework name for filename prefix
+ version: Version identifier for filename
+
+ Returns:
+ Dict[str, Path]: Dictionary mapping case type to saved file path
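+
+ Example (illustrative): with frame="memos_scheduler" and version="v1.0.1", the files
+ are written as "memos_scheduler_v1.0.1_can_answer_cases.json" and
+ "memos_scheduler_v1.0.1_cannot_answer_cases.json" under output_dir.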
+ """
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ saved_files = {}
+
+ # Save can_answer_cases
+ if can_answer_cases:
+ can_answer_filename = f"{frame}_{version}_can_answer_cases.json"
+ can_answer_path = save_recording_cases(can_answer_cases, output_dir, can_answer_filename)
+ saved_files["can_answer_cases"] = can_answer_path
+ print(f"Saved {len(can_answer_cases)} can_answer_cases to {can_answer_path}")
+
+ # Save cannot_answer_cases
+ if cannot_answer_cases:
+ cannot_answer_filename = f"{frame}_{version}_cannot_answer_cases.json"
+ cannot_answer_path = save_recording_cases(
+ cannot_answer_cases, output_dir, cannot_answer_filename
+ )
+ saved_files["cannot_answer_cases"] = cannot_answer_path
+ print(f"Saved {len(cannot_answer_cases)} cannot_answer_cases to {cannot_answer_path}")
+
+ return saved_files
+
+
+def compute_can_answer_stats(day_groups, rounds_to_consider=float("inf")):
+ """
+ Compute can-answer statistics for each day using the union of all prior evidences.
+
+ For each day, iterate over the QAs in the given order. If the current QA's
+ evidences (restricted to the same day) are a subset of the union of all
+ previously seen evidences for that day, increment can_answer_count. Then add
+ the current evidences to the seen set.
+
+ Note:
+ The first QA of each day is excluded from the statistics because it
+ cannot be answered without any prior evidences. It is still used to
+ seed the seen evidences for subsequent QAs.
+
+ Args:
+ day_groups: Dict mapping day_id (e.g., "D1") to a list of QA dicts. Each QA
+ dict should contain an "evidence" field that is a list of strings.
+ rounds_to_consider: Number of previous rounds to consider for evidence accumulation.
+ Default is infinity (all previous rounds).
+ Set to 1 to only consider the immediately preceding round.
+
+ Returns:
+ dict: Mapping day_id -> {"can_answer_count": int, "total": int, "ratio": float}
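+
+ Example (illustrative data, showing the subset rule):
+ >>> day_groups = {"D1": [
+ ...     {"evidence": ["e1", "e2"]},  # seeds the seen set; not counted
+ ...     {"evidence": ["e1"]},        # subset of {e1, e2} -> counted as answerable
+ ...     {"evidence": ["e3"]},        # introduces new evidence -> not answerable yet
+ ... ]}
+ >>> compute_can_answer_stats(day_groups)
+ {'D1': {'can_answer_count': 1, 'total': 2, 'ratio': 0.5}}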
+ """
+ results = {}
+ for day, qa_list in day_groups.items():
+ seen = set()
+ # Keep track of evidence history for limited rounds
+ evidence_history = []
+ can_answer = 0
+ total = max(len(qa_list) - 1, 0)
+ rounds_count = 0
+ for idx, qa in enumerate(qa_list):
+ cur = set(qa.get("evidence", []))
+ rounds_count += 1
+
+ if idx == 0:
+ # Seed seen evidences with the first QA but do not count it
+ evidence_history.append(cur)
+ seen = set().union(*evidence_history)
+ continue
+
+ # Check if current evidence is subset of accumulated evidence
+ if cur and cur.issubset(seen):
+ can_answer += 1
+
+ # Add current evidence to history
+ evidence_history.append(cur)
+
+ # Limit history to specified number of rounds
+ if rounds_count > rounds_to_consider:
+ evidence_history.pop(0)
+
+ # Recalculate seen as union of evidence_history
+ seen = set().union(*evidence_history)
+
+ results[day] = {
+ "can_answer_count": can_answer,
+ "total": total,
+ "ratio": (can_answer / total) if total else 0.0,
+ }
+ return results
+
+
+def compute_can_answer_count_by_pre_evidences(
+ temporal_locomo_data, num_of_users, stats_dir=None, rounds_to_consider=float("inf")
+):
+ """
+ Compute can-answer statistics per day for each conversation using the
+ union of all previously asked evidences within the same day.
+
+ Args:
+ temporal_locomo_data: The temporal locomo data containing conversations
+ num_of_users: Number of users/conversations to process
+ stats_dir: Directory to save statistics (optional)
+ rounds_to_consider: Number of previous rounds to consider for evidence accumulation.
+ Default is infinity (all previous rounds).
+ Set to 1 to only consider the immediately preceding round.
+
+ Returns:
+ dict: Mapping conversation_id -> per-day stats as produced by compute_can_answer_stats
+ """
+ all_conversations_stats = {}
+ for conv_idx in range(num_of_users):
+ temporal_conv = temporal_locomo_data[conv_idx]
+ conversation_id = temporal_conv["conversation_id"]
+
+ # Build day -> qa_pairs mapping
+ day_groups = {}
+ for day_id, day_data in temporal_conv.get("days", {}).items():
+ day_groups[day_id] = day_data.get("qa_pairs", [])
+
+ # Use shared utility to compute stats with correct accumulation logic
+ per_day_stats = compute_can_answer_stats(day_groups, rounds_to_consider)
+ all_conversations_stats[conversation_id] = per_day_stats
+
+ # Build per-conversation summaries and overall summary
+ per_conversation_summaries = {}
+ overall_can = 0
+ overall_total = 0
+ for conv_id, day_stats in all_conversations_stats.items():
+ conv_can = 0
+ conv_total = 0
+ for _day, stats in day_stats.items():
+ conv_can += int(stats.get("can_answer_count", 0))
+ conv_total += int(stats.get("total", 0))
+ conv_ratio = (conv_can / conv_total) if conv_total else 0.0
+ per_conversation_summaries[conv_id] = {
+ "can_answer_count": conv_can,
+ "total": conv_total,
+ "ratio": conv_ratio,
+ }
+ overall_can += conv_can
+ overall_total += conv_total
+
+ overall_summary = {
+ "can_answer_count": overall_can,
+ "total": overall_total,
+ "ratio": (overall_can / overall_total) if overall_total else 0.0,
+ }
+
+ # Add rounds information to the result
+ result_payload = {
+ "per_conversation_summary": per_conversation_summaries,
+ "overall_summary": overall_summary,
+ "rounds_considered": rounds_to_consider if rounds_to_consider != float("inf") else "all",
+ }
+
+ # Print results
+ print("\nComputed can-answer-by-pre-evidences stats:")
+ print(
+ f"Rounds considered: {rounds_to_consider if rounds_to_consider != float('inf') else 'all'}"
+ )
+ print(json.dumps(result_payload, indent=2, ensure_ascii=False))
+
+ # Save results if stats_dir is provided
+ if stats_dir:
+ output_path = (
+ stats_dir
+ / f"evidences_rounds_{rounds_to_consider if rounds_to_consider != float('inf') else 'all'}.json"
+ )
+ with open(output_path, "w", encoding="utf-8") as fw:
+ json.dump(result_payload, fw, indent=2, ensure_ascii=False)
+ print(f"Saved stats to {output_path}")
+
+ return result_payload
diff --git a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py
new file mode 100644
index 00000000..0a2c20a0
--- /dev/null
+++ b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py
@@ -0,0 +1,156 @@
+import argparse
+import asyncio
+import os
+import sys
+
+from pathlib import Path
+
+from locomo_eval import LocomoEvaluator
+from locomo_ingestion import LocomoIngestor
+from locomo_metric import LocomoMetric
+from locomo_processor import LocomoProcessor
+from modules.locomo_eval_module import LocomoEvalModelModules
+from modules.utils import compute_can_answer_count_by_pre_evidences
+
+from memos.log import get_logger
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent
+sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
+
+logger = get_logger(__name__)
+
+
+class TemporalLocomoEval(LocomoEvalModelModules):
+ def __init__(self, args):
+ super().__init__(args=args)
+ self.num_of_users = 10
+
+ self.locomo_ingestor = LocomoIngestor(args=args)
+ self.locomo_processor = LocomoProcessor(args=args)
+
+ def run_eval_pipeline(self):
+ """
+ Run the complete evaluation pipeline including dataset conversion,
+ data ingestion, and processing.
+ """
+ print("=" * 80)
+ print("Starting TimeLocomo Evaluation Pipeline")
+ print("=" * 80)
+
+ # Step 1: Check if temporal_locomo dataset exists, if not convert it
+ temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json"
+ if not temporal_locomo_file.exists():
+ print(f"Temporal locomo dataset not found at {temporal_locomo_file}")
+ print("Converting locomo dataset to temporal_locomo format...")
+ self.convert_locomo_to_temporal_locomo(output_dir=self.data_dir / "temporal_locomo")
+ print("Dataset conversion completed.")
+ else:
+ print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.")
+
+ # Step 2: Data ingestion
+ print("\n" + "=" * 50)
+ print("Step 2: Data Ingestion")
+ print("=" * 50)
+ if not self.ingestion_storage_dir.exists() or not any(self.ingestion_storage_dir.iterdir()):
+ print(f"Directory {self.ingestion_storage_dir} not found, starting data ingestion...")
+ self.locomo_ingestor.run_ingestion()
+ print("Data ingestion completed.")
+ else:
+ print(
+ f"Directory {self.ingestion_storage_dir} already exists and is not empty, skipping ingestion."
+ )
+
+ # Step 3: Processing and evaluation
+ print("\n" + "=" * 50)
+ print("Step 3: Processing and Evaluation")
+ print("=" * 50)
+ print("Running locomo processing to search and answer...")
+
+ print("Starting locomo processing to generate search and response results...")
+ self.locomo_processor.run_locomo_processing(num_users=self.num_of_users)
+ print("Processing completed successfully.")
+
+ # Optional: run post-hoc evaluation over generated responses if available
+ try:
+ evaluator = LocomoEvaluator(args=args)
+
+ if os.path.exists(evaluator.response_path):
+ print("Running LocomoEvaluator over existing response results...")
+ asyncio.run(evaluator.run())
+ else:
+ print(
+ f"Skipping LocomoEvaluator: response file not found at {evaluator.response_path}"
+ )
+ # Run metrics summarization if judged file is produced
+ metric = LocomoMetric(args=args)
+ if os.path.exists(metric.judged_path):
+ print("Running LocomoMetric over judged results...")
+ metric.run()
+ else:
+ print(f"Skipping LocomoMetric: judged file not found at {metric.judged_path}")
+ except Exception as e:
+ logger.error(f"LocomoEvaluator step skipped due to error: {e}", exc_info=True)
+
+ # Step 4: Summary
+ print("\n" + "=" * 80)
+ print("Evaluation Pipeline Completed Successfully!")
+ print("=" * 80)
+ print("Results saved to:")
+ print(f" - Search results: {self.search_path}")
+ print(f" - Response results: {self.response_path}")
+ print(f" - Statistics: {self.stats_path}")
+ print("=" * 80)
+
+ def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider):
+ """
+ Compute can-answer statistics per day for each conversation using the
+ union of all previously asked evidences within the same day.
+
+ Returns:
+ dict: Mapping conversation_id -> per-day stats as produced by compute_can_answer_stats
+ """
+ return compute_can_answer_count_by_pre_evidences(
+ temporal_locomo_data=self.temporal_locomo_data,
+ num_of_users=self.num_of_users,
+ stats_dir=self.stats_dir,
+ rounds_to_consider=rounds_to_consider,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--frame",
+ type=str,
+ default="memos_scheduler",
+ choices=["zep", "memos", "mem0", "mem0_graph", "memos_scheduler"],
+ help="Specify the memory framework (zep or memos or mem0 or mem0_graph)",
+ )
+ parser.add_argument(
+ "--version",
+ type=str,
+ default="v1.0.1",
+ help="Version identifier for saving results (e.g., 1010)",
+ )
+ parser.add_argument(
+ "--workers", type=int, default=10, help="Number of parallel workers to process users"
+ )
+ parser.add_argument(
+ "--top_k", type=int, default=20, help="Number of results to retrieve in search queries"
+ )
+ parser.add_argument(
+ "--scheduler-flag",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Enable or disable memory scheduler features",
+ )
+ args = parser.parse_args()
+
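+ # Example invocation (assumes the locomo dataset and service configs are available locally):
+ #   python temporal_locomo_eval.py --frame memos_scheduler --version v1.0.1 \
+ #       --workers 10 --top_k 20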
+ evaluator = TemporalLocomoEval(args=args)
+ evaluator.run_eval_pipeline()
+
+ # rule-based baselines
+ evaluator.compute_can_answer_count_by_pre_evidences(rounds_to_consider=float("inf"))
+ evaluator.compute_can_answer_count_by_pre_evidences(rounds_to_consider=1)
diff --git a/examples/basic_modules/reranker.py b/examples/basic_modules/reranker.py
index 3969cc43..47bf1405 100644
--- a/examples/basic_modules/reranker.py
+++ b/examples/basic_modules/reranker.py
@@ -88,6 +88,20 @@ def main():
for it, emb in zip(items, doc_embeddings, strict=False):
it.metadata.embedding = emb
+ items[0].metadata.user_id = "u_123"
+ items[0].metadata.session_id = "s_abc"
+ items[0].metadata.tags = [*items[0].metadata.tags, "paris"]
+
+ items[1].metadata.user_id = "u_124"
+ items[1].metadata.session_id = "s_xyz"
+ items[1].metadata.tags = [*items[1].metadata.tags, "germany"]
+ items[2].metadata.user_id = "u_125"
+ items[2].metadata.session_id = "s_ss3"
+ items[3].metadata.user_id = "u_126"
+ items[3].metadata.session_id = "s_ss4"
+ items[4].metadata.user_id = "u_127"
+ items[4].metadata.session_id = "s_ss5"
+
# -------------------------------
# 4) Rerank with cosine_local (uses your real embeddings)
# -------------------------------
@@ -124,7 +138,7 @@ def main():
"url": bge_url,
"model": os.getenv("BGE_RERANKER_MODEL", "bge-reranker-v2-m3"),
"timeout": int(os.getenv("BGE_RERANKER_TIMEOUT", "10")),
- # "headers_extra": {"Authorization": f"Bearer {os.getenv('BGE_RERANKER_TOKEN')}"}
+ "boost_weights": {"user_id": 0.5, "tags": 0.2},
},
}
)
@@ -136,6 +150,20 @@ def main():
top_k=10,
)
show_ranked("HTTP BGE Reranker (OpenAI-style API)", ranked_http, top_n=5)
+
+ # --- NEW: search_filter with rerank ---
+ # boost rules (driven by boost_weights above):
+ #   - user_id == "u_123" → score * (1 + 0.5), i.e. a 1.5x boost
+ #   - tags containing "paris" → score * (1 + 0.2), i.e. a 1.2x boost
+ #   - project_id (no such metadata field) → logs a warning; the score is unchanged
+ search_filter = {"session_id": "germany", "tags": "germany", "project_id": "demo-p1"}
+ ranked_http_boosted = http_reranker.rerank(
+ query=query,
+ graph_results=items,
+ top_k=10,
+ search_filter=search_filter,
+ )
+ show_ranked("HTTP BGE Reranker (with search_filter boosts)", ranked_http_boosted, top_n=5)
else:
print("\n[Info] Skipped HTTP BGE scenario because BGE_RERANKER_URL is not set.")
diff --git a/examples/data/config/mem_scheduler/mem_cube_config.yaml b/examples/data/config/mem_scheduler/mem_cube_config.yaml
index 76428abb..0dd7e0f3 100644
--- a/examples/data/config/mem_scheduler/mem_cube_config.yaml
+++ b/examples/data/config/mem_scheduler/mem_cube_config.yaml
@@ -34,7 +34,7 @@ act_mem:
config:
memory_filename: "activation_memory.pickle"
extractor_llm:
- backend: "huggingface"
+ backend: "huggingface_singleton"
config:
model_name_or_path: "Qwen/Qwen3-1.7B"
temperature: 0.8
@@ -48,7 +48,7 @@ para_mem:
config:
memory_filename: "parametric_memory.adapter"
extractor_llm:
- backend: "huggingface"
+ backend: "huggingface_singleton"
config:
model_name_or_path: "Qwen/Qwen3-1.7B"
temperature: 0.8
diff --git a/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml b/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml
new file mode 100644
index 00000000..2d3958e6
--- /dev/null
+++ b/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml
@@ -0,0 +1,51 @@
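+# MOS config exercising the optimized_scheduler backend with an OpenAI-backed mem_reader.
+# Referenced by the mem_scheduler example scripts; the api_key below is a placeholder.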
+user_id: "root"
+chat_model:
+ backend: "huggingface_singleton"
+ config:
+ model_name_or_path: "Qwen/Qwen3-1.7B"
+ temperature: 0.1
+ remove_think_prefix: true
+ max_tokens: 4096
+mem_reader:
+ backend: "simple_struct"
+ config:
+ llm:
+ backend: "openai"
+ config:
+ model_name_or_path: "gpt-4o-mini"
+ temperature: 0.8
+ max_tokens: 4096
+ top_p: 0.9
+ top_k: 50
+ remove_think_prefix: true
+ api_key: "sk-xxxxxx"
+ api_base: "https://api.openai.com/v1"
+ embedder:
+ backend: "ollama"
+ config:
+ model_name_or_path: "nomic-embed-text:latest"
+ chunker:
+ backend: "sentence"
+ config:
+ tokenizer_or_token_counter: "gpt2"
+ chunk_size: 512
+ chunk_overlap: 128
+ min_sentences_per_chunk: 1
+mem_scheduler:
+ backend: "optimized_scheduler"
+ config:
+ top_k: 10
+ act_mem_update_interval: 30
+ context_window_size: 10
+ thread_pool_max_workers: 10
+ consume_interval_seconds: 1
+ working_mem_monitor_capacity: 20
+ activation_mem_monitor_capacity: 5
+ enable_parallel_dispatch: true
+ enable_activation_memory: true
+max_turns_window: 20
+top_k: 5
+enable_textual_memory: true
+enable_activation_memory: true
+enable_parametric_memory: false
+enable_mem_scheduler: true
diff --git a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml
index a851ba77..0152d8cd 100644
--- a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml
+++ b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml
@@ -1,6 +1,6 @@
user_id: "root"
chat_model:
- backend: "huggingface"
+ backend: "huggingface_singleton"
config:
model_name_or_path: "Qwen/Qwen3-1.7B"
temperature: 0.1
diff --git a/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml b/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml
index b82dbd2b..cdfa49a7 100644
--- a/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml
+++ b/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml
@@ -1,6 +1,6 @@
user_id: "root"
chat_model:
- backend: "huggingface"
+ backend: "huggingface_singleton"
config:
model_name_or_path: "Qwen/Qwen3-1.7B"
temperature: 0.1
diff --git a/examples/mem_scheduler/debug_text_mem_replace.py b/examples/mem_scheduler/debug_text_mem_replace.py
new file mode 100644
index 00000000..df80f7d0
--- /dev/null
+++ b/examples/mem_scheduler/debug_text_mem_replace.py
@@ -0,0 +1,109 @@
+import json
+import shutil
+import sys
+
+from pathlib import Path
+
+from memos_w_scheduler_for_test import init_task
+
+from memos.configs.mem_cube import GeneralMemCubeConfig
+from memos.configs.mem_os import MOSConfig
+from memos.configs.mem_scheduler import AuthConfig
+from memos.log import get_logger
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent
+sys.path.insert(0, str(BASE_DIR))
+
+# Enable execution from any working directory
+
+logger = get_logger(__name__)
+
+if __name__ == "__main__":
+ # set up data
+ conversations, questions = init_task()
+
+ # set configs
+ mos_config = MOSConfig.from_yaml_file(
+ f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml"
+ )
+
+ mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
+ f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
+ )
+
+ # default local graphdb uri
+ if AuthConfig.default_config_exists():
+ auth_config = AuthConfig.from_local_config()
+
+ mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key
+ mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
+
+ mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
+ mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
+ mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
+ mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
+ mem_cube_config.text_mem.config.graph_db.config.auto_create = (
+ auth_config.graph_db.auto_create
+ )
+
+ # Initialization
+ mos = MOSForTestScheduler(mos_config)
+
+ user_id = "user_1"
+ mos.create_user(user_id)
+
+ mem_cube_id = "mem_cube_5"
+ mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
+
+ if Path(mem_cube_name_or_path).exists():
+ shutil.rmtree(mem_cube_name_or_path)
+ print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
+
+ mem_cube = GeneralMemCube(mem_cube_config)
+ mem_cube.dump(mem_cube_name_or_path)
+ mos.register_mem_cube(
+ mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
+ )
+
+ mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
+
+ # Add interfering conversations
+ file_path = Path(f"{BASE_DIR}/examples/data/mem_scheduler/scene_data.json")
+ scene_data = json.load(file_path.open("r", encoding="utf-8"))
+ mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id)
+ mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id)
+
+ # Test the replace_working_memory functionality
+ print("\n--- Testing replace_working_memory ---")
+
+ # Get current working memories
+ text_mem_base = mem_cube.text_mem
+ if text_mem_base is not None:
+ working_memories_before = text_mem_base.get_working_memory()
+ print(f"Working memories before replacement: {len(working_memories_before)}")
+
+ # Create filtered memories (simulate what the scheduler would do)
+ # Keep only memories related to Max
+ filtered_memories = [working_memories_before[1], working_memories_before[4]]
+
+ text_mem_base.replace_working_memory(memories=filtered_memories)
+
+ # Check working memory after replacement
+ working_memories_after = text_mem_base.get_working_memory()
+ print(f"Working memories after replacement: {len(working_memories_after)}")
+
+ if len(working_memories_after) == len(filtered_memories):
+ print("✅ SUCCESS: Working memory count matches filtered memories")
+ else:
+ print(
+ f"❌ FAILED: Expected {len(filtered_memories)}, got {len(working_memories_after)}"
+ )
+
+ else:
+ print("❌ text_mem is None - not properly initialized")
+
+ mos.mem_scheduler.stop()
diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler.py b/examples/mem_scheduler/memos_w_optimized_scheduler.py
new file mode 100644
index 00000000..fbd14536
--- /dev/null
+++ b/examples/mem_scheduler/memos_w_optimized_scheduler.py
@@ -0,0 +1,85 @@
+import shutil
+import sys
+
+from pathlib import Path
+
+from memos_w_scheduler import init_task, show_web_logs
+
+from memos.configs.mem_cube import GeneralMemCubeConfig
+from memos.configs.mem_os import MOSConfig
+from memos.configs.mem_scheduler import AuthConfig
+from memos.log import get_logger
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_os.main import MOS
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent
+sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
+
+logger = get_logger(__name__)
+
+
+def run_with_scheduler_init():
+ print("==== run_with_automatic_scheduler_init ====")
+ conversations, questions = init_task()
+
+ # set configs
+ mos_config = MOSConfig.from_yaml_file(
+ f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml"
+ )
+
+ mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
+ f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
+ )
+
+ # default local graphdb uri
+ if AuthConfig.default_config_exists():
+ auth_config = AuthConfig.from_local_config()
+
+ mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key
+ mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
+
+ mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
+ mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
+ mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
+ mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
+ mem_cube_config.text_mem.config.graph_db.config.auto_create = (
+ auth_config.graph_db.auto_create
+ )
+
+ # Initialization
+ mos = MOS(mos_config)
+
+ user_id = "user_1"
+ mos.create_user(user_id)
+
+ mem_cube_id = "mem_cube_5"
+ mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
+
+ if Path(mem_cube_name_or_path).exists():
+ shutil.rmtree(mem_cube_name_or_path)
+ print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
+
+ mem_cube = GeneralMemCube(mem_cube_config)
+ mem_cube.dump(mem_cube_name_or_path)
+ mos.register_mem_cube(
+ mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
+ )
+
+ mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
+
+ for item in questions:
+ print("===== Chat Start =====")
+ query = item["question"]
+ print(f"Query:\n {query}\n")
+ response = mos.chat(query=query, user_id=user_id)
+ print(f"Answer:\n {response}\n")
+
+ show_web_logs(mem_scheduler=mos.mem_scheduler)
+
+ mos.mem_scheduler.stop()
+
+
+if __name__ == "__main__":
+ run_with_scheduler_init()
diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py b/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py
new file mode 100644
index 00000000..9b39bf77
--- /dev/null
+++ b/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py
@@ -0,0 +1,87 @@
+import json
+import shutil
+import sys
+
+from pathlib import Path
+
+from memos_w_scheduler_for_test import init_task
+
+from memos.configs.mem_cube import GeneralMemCubeConfig
+from memos.configs.mem_os import MOSConfig
+from memos.configs.mem_scheduler import AuthConfig
+from memos.log import get_logger
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent
+sys.path.insert(0, str(BASE_DIR))
+
+# Enable execution from any working directory
+
+logger = get_logger(__name__)
+
+if __name__ == "__main__":
+ # set up data
+ conversations, questions = init_task()
+
+ # set configs
+ mos_config = MOSConfig.from_yaml_file(
+ f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml"
+ )
+
+ mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
+ f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
+ )
+
+ # default local graphdb uri
+ if AuthConfig.default_config_exists():
+ auth_config = AuthConfig.from_local_config()
+
+ mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key
+ mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
+
+ mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
+ mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
+ mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
+ mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
+ mem_cube_config.text_mem.config.graph_db.config.auto_create = (
+ auth_config.graph_db.auto_create
+ )
+
+ # Initialization
+ mos = MOSForTestScheduler(mos_config)
+
+ user_id = "user_1"
+ mos.create_user(user_id)
+
+ mem_cube_id = "mem_cube_5"
+ mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
+
+ if Path(mem_cube_name_or_path).exists():
+ shutil.rmtree(mem_cube_name_or_path)
+ print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
+
+ mem_cube = GeneralMemCube(mem_cube_config)
+ mem_cube.dump(mem_cube_name_or_path)
+ mos.register_mem_cube(
+ mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
+ )
+
+ mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
+
+ # Add interfering conversations
+ file_path = Path(f"{BASE_DIR}/examples/data/mem_scheduler/scene_data.json")
+ scene_data = json.load(file_path.open("r", encoding="utf-8"))
+ mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id)
+ mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id)
+
+ for item in questions:
+ print("===== Chat Start =====")
+ query = item["question"]
+ print(f"Query:\n {query}\n")
+ response = mos.chat(query=query, user_id=user_id)
+ print(f"Answer:\n {response}\n")
+
+ mos.mem_scheduler.stop()
diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py
index 63054586..28641507 100644
--- a/examples/mem_scheduler/memos_w_scheduler.py
+++ b/examples/mem_scheduler/memos_w_scheduler.py
@@ -15,7 +15,7 @@
if TYPE_CHECKING:
- from memos.mem_scheduler.schemas import (
+ from memos.mem_scheduler.schemas.message_schemas import (
ScheduleLogForWebItem,
)
diff --git a/examples/mem_scheduler/memos_w_scheduler_for_test.py b/examples/mem_scheduler/memos_w_scheduler_for_test.py
index 074400ee..ddf2dc6d 100644
--- a/examples/mem_scheduler/memos_w_scheduler_for_test.py
+++ b/examples/mem_scheduler/memos_w_scheduler_for_test.py
@@ -1,6 +1,7 @@
import json
import shutil
import sys
+import time
from pathlib import Path
@@ -9,7 +10,7 @@
from memos.configs.mem_scheduler import AuthConfig
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
-from memos.mem_scheduler.mos_for_test_scheduler import MOSForTestScheduler
+from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler
FILE_PATH = Path(__file__).absolute()
@@ -19,6 +20,77 @@
logger = get_logger(__name__)
+def display_memory_cube_stats(mos, user_id, mem_cube_id):
+ """Display detailed memory cube statistics."""
+ print(f"\n📊 MEMORY CUBE STATISTICS for {mem_cube_id}:")
+ print("-" * 60)
+
+ mem_cube = mos.mem_cubes.get(mem_cube_id)
+ if not mem_cube:
+ print(" ❌ Memory cube not found")
+ return
+
+ # Text memory stats
+ if mem_cube.text_mem:
+ text_mem = mem_cube.text_mem
+ working_memories = text_mem.get_working_memory()
+ all_memories = text_mem.get_all()
+
+ print(" 📝 Text Memory:")
+ print(f" • Working Memory Items: {len(working_memories)}")
+ print(
+ f" • Total Memory Items: {len(all_memories) if isinstance(all_memories, list) else 'N/A'}"
+ )
+
+ if working_memories:
+ print(" • Working Memory Content Preview:")
+ for i, mem in enumerate(working_memories[:2]):
+ content = mem.memory[:60] + "..." if len(mem.memory) > 60 else mem.memory
+ print(f" {i + 1}. {content}")
+
+ # Activation memory stats
+ if mem_cube.act_mem:
+ act_mem = mem_cube.act_mem
+ act_memories = list(act_mem.get_all())
+ print(" ⚡ Activation Memory:")
+ print(f" • KV Cache Items: {len(act_memories)}")
+ if act_memories:
+ print(
+ f" • Latest Cache Size: {len(act_memories[-1].memory) if hasattr(act_memories[-1], 'memory') else 'N/A'}"
+ )
+
+ print("-" * 60)
+
+
+def display_scheduler_status(mos):
+ """Display current scheduler status and configuration."""
+ print("\n⚙️ SCHEDULER STATUS:")
+ print("-" * 60)
+
+ if not mos.mem_scheduler:
+ print(" ❌ Memory scheduler not initialized")
+ return
+
+ scheduler = mos.mem_scheduler
+ print(f" 🔄 Scheduler Running: {scheduler._running}")
+ print(f" 📊 Internal Queue Size: {scheduler.memos_message_queue.qsize()}")
+ print(f" 🧵 Parallel Dispatch: {scheduler.enable_parallel_dispatch}")
+ print(f" 👥 Max Workers: {scheduler.thread_pool_max_workers}")
+ print(f" ⏱️ Consume Interval: {scheduler._consume_interval}s")
+
+ if scheduler.monitor:
+ print(" 📈 Monitor Active: ✅")
+ print(f" 🗄️ Database Engine: {'✅' if scheduler.db_engine else '❌'}")
+
+ if scheduler.dispatcher:
+ print(" 🚀 Dispatcher Active: ✅")
+ print(
+ f" 🔧 Dispatcher Status: {scheduler.dispatcher.status if hasattr(scheduler.dispatcher, 'status') else 'Unknown'}"
+ )
+
+ print("-" * 60)
+
+
def init_task():
conversations = [
{
@@ -83,6 +155,9 @@ def init_task():
if __name__ == "__main__":
+ print("🚀 Starting Enhanced Memory Scheduler Test...")
+ print("=" * 80)
+
# set up data
conversations, questions = init_task()
@@ -111,6 +186,7 @@ def init_task():
)
# Initialization
+ print("🔧 Initializing MOS with Scheduler...")
mos = MOSForTestScheduler(mos_config)
user_id = "user_1"
@@ -121,7 +197,7 @@ def init_task():
if Path(mem_cube_name_or_path).exists():
shutil.rmtree(mem_cube_name_or_path)
- print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
+ print(f"🗑️ {mem_cube_name_or_path} is not empty, and has been removed.")
mem_cube = GeneralMemCube(mem_cube_config)
mem_cube.dump(mem_cube_name_or_path)
@@ -129,6 +205,7 @@ def init_task():
mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
)
+ print("📚 Adding initial conversations...")
mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
# Add interfering conversations
@@ -137,11 +214,77 @@ def init_task():
mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id)
mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id)
- for item in questions:
- print("===== Chat Start =====")
- query = item["question"]
- print(f"Query:\n {query}\n")
- response = mos.chat(query=query, user_id=user_id)
- print(f"Answer:\n {response}\n")
+ # Display initial status
+ print("\n📊 INITIAL SYSTEM STATUS:")
+ display_scheduler_status(mos)
+ display_memory_cube_stats(mos, user_id, mem_cube_id)
+
+ # Process questions with enhanced monitoring
+ print(f"\n🎯 Starting Question Processing ({len(questions)} questions)...")
+ question_start_time = time.time()
+
+ for i, item in enumerate(questions, 1):
+ print(f"\n{'=' * 20} Question {i}/{len(questions)} {'=' * 20}")
+ print(f"📝 Category: {item['category']} | Difficulty: {item['difficulty']}")
+ print(f"🎯 Expected: {item['expected']}")
+ if "hint" in item:
+ print(f"💡 Hint: {item['hint']}")
+ if "requires" in item:
+ print(f"🔍 Requires: {', '.join(item['requires'])}")
+
+ print(f"\n🚀 Processing Query: {item['question']}")
+ query_start_time = time.time()
+
+ response = mos.chat(query=item["question"], user_id=user_id)
+
+ query_time = time.time() - query_start_time
+ print(f"⏱️ Query Processing Time: {query_time:.3f}s")
+ print(f"🤖 Response: {response}")
+
+ # Display intermediate status every 2 questions
+ if i % 2 == 0:
+ print(f"\n📊 INTERMEDIATE STATUS (Question {i}):")
+ display_scheduler_status(mos)
+ display_memory_cube_stats(mos, user_id, mem_cube_id)
+
+ total_processing_time = time.time() - question_start_time
+ print(f"\n⏱️ Total Question Processing Time: {total_processing_time:.3f}s")
+
+ # Display final scheduler performance summary
+ print("\n" + "=" * 80)
+ print("📊 FINAL SCHEDULER PERFORMANCE SUMMARY")
+ print("=" * 80)
+
+ summary = mos.get_scheduler_summary()
+ print(f"🔢 Total Queries Processed: {summary['total_queries']}")
+ print(f"⚡ Total Scheduler Calls: {summary['total_scheduler_calls']}")
+ print(f"⏱️ Average Scheduler Response Time: {summary['average_scheduler_response_time']:.3f}s")
+ print(f"🧠 Memory Optimizations Applied: {summary['memory_optimization_count']}")
+ print(f"🔄 Working Memory Updates: {summary['working_memory_updates']}")
+ print(f"⚡ Activation Memory Updates: {summary['activation_memory_updates']}")
+ print(f"📈 Average Query Processing Time: {summary['average_query_processing_time']:.3f}s")
+
+ # Performance insights
+ print("\n💡 PERFORMANCE INSIGHTS:")
+ if summary["total_scheduler_calls"] > 0:
+ optimization_rate = (
+ summary["memory_optimization_count"] / summary["total_scheduler_calls"]
+ ) * 100
+ print(f" • Memory Optimization Rate: {optimization_rate:.1f}%")
+
+ if summary["average_scheduler_response_time"] < 0.1:
+ print(" • Scheduler Performance: 🟢 Excellent (< 100ms)")
+ elif summary["average_scheduler_response_time"] < 0.5:
+ print(" • Scheduler Performance: 🟡 Good (100-500ms)")
+ else:
+ print(" • Scheduler Performance: 🔴 Needs Improvement (> 500ms)")
+
+ # Final system status
+ print("\n🔍 FINAL SYSTEM STATUS:")
+ display_scheduler_status(mos)
+ display_memory_cube_stats(mos, user_id, mem_cube_id)
+
+ print("=" * 80)
+ print("🏁 Test completed successfully!")
mos.mem_scheduler.stop()
diff --git a/examples/mem_scheduler/rabbitmq_example.py b/examples/mem_scheduler/rabbitmq_example.py
index ba573238..5e40eaad 100644
--- a/examples/mem_scheduler/rabbitmq_example.py
+++ b/examples/mem_scheduler/rabbitmq_example.py
@@ -2,7 +2,7 @@
import time
from memos.configs.mem_scheduler import AuthConfig
-from memos.mem_scheduler.general_modules.rabbitmq_service import RabbitMQSchedulerModule
+from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule
def publish_message(rabbitmq_module, message):
diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py
index 8c5d1415..634d69c3 100644
--- a/examples/mem_scheduler/try_schedule_modules.py
+++ b/examples/mem_scheduler/try_schedule_modules.py
@@ -12,8 +12,8 @@
from memos.configs.mem_scheduler import AuthConfig
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
+from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler
from memos.mem_scheduler.general_scheduler import GeneralScheduler
-from memos.mem_scheduler.mos_for_test_scheduler import MOSForTestScheduler
from memos.mem_scheduler.schemas.general_schemas import (
NOT_APPLICABLE_TYPE,
)
diff --git a/poetry.lock b/poetry.lock
index c6b6a0eb..e6830016 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3773,6 +3773,22 @@ files = [
[package.extras]
windows-terminal = ["colorama (>=0.4.6)"]
+[[package]]
+name = "pymysql"
+version = "1.1.2"
+description = "Pure Python MySQL Driver"
+optional = false
+python-versions = ">=3.8"
+groups = ["main", "mem-user"]
+files = [
+ {file = "pymysql-1.1.2-py3-none-any.whl", hash = "sha256:e6b1d89711dd51f8f74b1631fe08f039e7d76cf67a42a323d3178f0f25762ed9"},
+ {file = "pymysql-1.1.2.tar.gz", hash = "sha256:4961d3e165614ae65014e361811a724e2044ad3ea3739de9903ae7c21f539f03"},
+]
+
+[package.extras]
+ed25519 = ["PyNaCl (>=1.4.0)"]
+rsa = ["cryptography"]
+
[[package]]
name = "pyparsing"
version = "3.2.3"
@@ -6285,12 +6301,13 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
cffi = ["cffi (>=1.11)"]
[extras]
-all = ["chonkie", "markitdown", "neo4j", "pika", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"]
+all = ["chonkie", "markitdown", "neo4j", "pika", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"]
mem-reader = ["chonkie", "markitdown"]
mem-scheduler = ["pika", "redis"]
+mem-user = ["pymysql"]
tree-mem = ["neo4j", "schedule"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
-content-hash = "94a3c4f97f0deda4c6ccbfd8ceda194f18dbc7525aa49004ffcc7846a1c40f7e"
+content-hash = "d85cb8a08870d67df6e462610231f1e735ba5293bd3fe5b0c4a212b3ccff7b72"
diff --git a/pyproject.toml b/pyproject.toml
index c66bcb05..eae2e805 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
"tenacity (>=9.1.2,<10.0.0)", # Error handling and retrying library
"fastapi[all] (>=0.115.12,<0.116.0)", # Web framework for building APIs
"sqlalchemy (>=2.0.41,<3.0.0)", # SQL toolkit
+ "pymysql (>=1.1.0,<2.0.0)", # MySQL Python driver
"scikit-learn (>=1.7.0,<2.0.0)", # Machine learning
"fastmcp (>=2.10.5,<3.0.0)",
"python-dateutil (>=2.9.0.post0,<3.0.0)",
@@ -76,6 +77,11 @@ mem-scheduler = [
"pika (>=1.3.2,<2.0.0)", # RabbitMQ client
]
+# MemUser (MySQL support)
+mem-user = [
+ "pymysql (>=1.1.0,<2.0.0)", # MySQL client for SQLAlchemy
+]
+
# MemReader
mem-reader = [
"chonkie (>=1.0.7,<2.0.0)", # Sentence chunking library
@@ -90,6 +96,7 @@ all = [
"schedule (>=1.2.2,<2.0.0)",
"redis (>=6.2.0,<7.0.0)",
"pika (>=1.3.2,<2.0.0)",
+ "pymysql (>=1.1.0,<2.0.0)",
"chonkie (>=1.0.7,<2.0.0)",
"markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)",
@@ -158,6 +165,10 @@ python-dotenv = "^1.1.1"
langgraph = "^0.5.1"
langmem = "^0.0.27"
+
+[tool.poetry.group.mem-user.dependencies]
+pymysql = "^1.1.2"
+
[[tool.poetry.source]]
name = "mirrors"
url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/"
diff --git a/src/memos/api/client.py b/src/memos/api/client.py
new file mode 100644
index 00000000..d45276f2
--- /dev/null
+++ b/src/memos/api/client.py
@@ -0,0 +1,109 @@
+import json
+import os
+
+from typing import Any
+
+import requests
+
+from memos.api.product_models import MemOSAddResponse, MemOSGetMessagesResponse, MemOSSearchResponse
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+MAX_RETRY_COUNT = 3
+
+
+class MemOSClient:
+ """MemOS API client"""
+
+ def __init__(self, api_key: str | None = None, base_url: str | None = None):
+ self.base_url = (
+ base_url or os.getenv("MEMOS_BASE_URL") or "https://memos.memtensor.cn/api/openmem/v1"
+ )
+ api_key = api_key or os.getenv("MEMOS_API_KEY")
+
+ if not api_key:
+ raise ValueError("MemOS API key is required")
+
+ self.headers = {"Content-Type": "application/json", "Authorization": f"Token {api_key}"}
+
+ def _validate_required_params(self, **params):
+ """Validate required parameters - if passed, they must not be empty"""
+ for param_name, param_value in params.items():
+ if not param_value:
+ raise ValueError(f"{param_name} is required")
+
+ def get_message(
+ self, user_id: str, conversation_id: str | None = None
+ ) -> MemOSGetMessagesResponse:
+ """Get messages"""
+ # Validate required parameters
+ self._validate_required_params(user_id=user_id)
+
+ url = f"{self.base_url}/get/message"
+ payload = {"user_id": user_id, "conversation_id": conversation_id}
+ for retry in range(MAX_RETRY_COUNT):
+ try:
+ response = requests.post(
+ url, data=json.dumps(payload), headers=self.headers, timeout=30
+ )
+ response.raise_for_status()
+ response_data = response.json()
+ return MemOSGetMessagesResponse(**response_data)
+ except Exception as e:
+ logger.error(f"Failed to get messages (retry {retry + 1}/3): {e}")
+ if retry == MAX_RETRY_COUNT - 1:
+ raise
+
+ def add_message(
+ self, messages: list[dict[str, Any]], user_id: str, conversation_id: str
+ ) -> MemOSAddResponse:
+ """Add memories"""
+ # Validate required parameters
+ self._validate_required_params(
+ messages=messages, user_id=user_id, conversation_id=conversation_id
+ )
+
+ url = f"{self.base_url}/add/message"
+ payload = {"messages": messages, "user_id": user_id, "conversation_id": conversation_id}
+ for retry in range(MAX_RETRY_COUNT):
+ try:
+ response = requests.post(
+ url, data=json.dumps(payload), headers=self.headers, timeout=30
+ )
+ response.raise_for_status()
+ response_data = response.json()
+ return MemOSAddResponse(**response_data)
+ except Exception as e:
+ logger.error(f"Failed to add memory (retry {retry + 1}/3): {e}")
+ if retry == MAX_RETRY_COUNT - 1:
+ raise
+
+ def search_memory(
+ self, query: str, user_id: str, conversation_id: str, memory_limit_number: int = 6
+ ) -> MemOSSearchResponse:
+ """Search memories"""
+ # Validate required parameters
+ self._validate_required_params(query=query, user_id=user_id)
+
+ url = f"{self.base_url}/search/memory"
+ payload = {
+ "query": query,
+ "user_id": user_id,
+ "conversation_id": conversation_id,
+ "memory_limit_number": memory_limit_number,
+ }
+
+ for retry in range(MAX_RETRY_COUNT):
+ try:
+ response = requests.post(
+ url, data=json.dumps(payload), headers=self.headers, timeout=30
+ )
+ response.raise_for_status()
+ response_data = response.json()
+ return MemOSSearchResponse(**response_data)
+ except Exception as e:
+ logger.error(f"Failed to search memory (retry {retry + 1}/3): {e}")
+ if retry == MAX_RETRY_COUNT - 1:
+ raise
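+
+
+# Usage sketch (illustrative; requires a valid MEMOS_API_KEY in the environment):
+#   client = MemOSClient()
+#   client.add_message(
+#       messages=[{"role": "user", "content": "I moved to Berlin."}],
+#       user_id="u_1",
+#       conversation_id="c_1",
+#   )
+#   hits = client.search_memory("Where do I live?", user_id="u_1", conversation_id="c_1")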
diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index 990f4a16..355ee038 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -21,7 +21,7 @@ class APIConfig:
def get_openai_config() -> dict[str, Any]:
"""Get OpenAI configuration."""
return {
- "model_name_or_path": os.getenv("MOS_OPENAI_MODEL", "gpt-4o-mini"),
+ "model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-4o-mini"),
"temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.8")),
"max_tokens": int(os.getenv("MOS_MAX_TOKENS", "1024")),
"top_p": float(os.getenv("MOS_TOP_P", "0.9")),
@@ -100,8 +100,10 @@ def get_reranker_config() -> dict[str, Any]:
"backend": "http_bge",
"config": {
"url": os.getenv("MOS_RERANKER_URL"),
- "model": "bge-reranker-v2-m3",
+ "model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"),
"timeout": 10,
+ "headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"),
+ "rerank_source": os.getenv("MOS_RERANK_SOURCE"),
},
}
else:
@@ -186,22 +188,22 @@ def get_neo4j_community_config(user_id: str | None = None) -> dict[str, Any]:
return {
"uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
"user": os.getenv("NEO4J_USER", "neo4j"),
- "db_name": os.getenv("NEO4J_DB_NAME", "shared-tree-textual-memory"),
+ "db_name": os.getenv("NEO4J_DB_NAME", "neo4j"),
"password": os.getenv("NEO4J_PASSWORD", "12345678"),
"user_name": f"memos{user_id.replace('-', '')}",
- "auto_create": True,
+ "auto_create": False,
"use_multi_db": False,
- "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
+ "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)),
"vec_config": {
# Pass nested config to initialize external vector DB
# If you use qdrant, please use Server instead of local mode.
"backend": "qdrant",
"config": {
"collection_name": "neo4j_vec_db",
- "vector_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
+ "vector_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)),
"distance_metric": "cosine",
- "host": "localhost",
- "port": 6333,
+ "host": os.getenv("QDRANT_HOST", "localhost"),
+ "port": int(os.getenv("QDRANT_PORT", "6333")),
},
},
}
@@ -271,7 +273,7 @@ def get_mysql_config() -> dict[str, Any]:
def get_scheduler_config() -> dict[str, Any]:
"""Get scheduler configuration."""
return {
- "backend": "general_scheduler",
+ "backend": "optimized_scheduler",
"config": {
"top_k": int(os.getenv("MOS_SCHEDULER_TOP_K", "10")),
"act_mem_update_interval": int(
diff --git a/src/memos/api/context/context.py b/src/memos/api/context/context.py
deleted file mode 100644
index 8aee2cfe..00000000
--- a/src/memos/api/context/context.py
+++ /dev/null
@@ -1,147 +0,0 @@
-"""
-Global request context management for trace_id and request-scoped data.
-
-This module provides optional trace_id functionality that can be enabled
-when using the API components. It uses ContextVar to ensure thread safety
-and request isolation.
-"""
-
-import uuid
-
-from collections.abc import Callable
-from contextvars import ContextVar
-from typing import Any
-
-
-# Global context variable for request-scoped data
-_request_context: ContextVar[dict[str, Any] | None] = ContextVar("request_context", default=None)
-
-
-class RequestContext:
- """
- Request-scoped context object that holds trace_id and other request data.
-
- This provides a Flask g-like object for FastAPI applications.
- """
-
- def __init__(self, trace_id: str | None = None):
- self.trace_id = trace_id or str(uuid.uuid4())
- self._data: dict[str, Any] = {}
-
- def set(self, key: str, value: Any) -> None:
- """Set a value in the context."""
- self._data[key] = value
-
- def get(self, key: str, default: Any | None = None) -> Any:
- """Get a value from the context."""
- return self._data.get(key, default)
-
- def __setattr__(self, name: str, value: Any) -> None:
- if name.startswith("_") or name == "trace_id":
- super().__setattr__(name, value)
- else:
- if not hasattr(self, "_data"):
- super().__setattr__(name, value)
- else:
- self._data[name] = value
-
- def __getattr__(self, name: str) -> Any:
- if hasattr(self, "_data") and name in self._data:
- return self._data[name]
- raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
-
- def to_dict(self) -> dict[str, Any]:
- """Convert context to dictionary."""
- return {"trace_id": self.trace_id, "data": self._data.copy()}
-
-
-def set_request_context(context: RequestContext) -> None:
- """
- Set the current request context.
-
- This is typically called by the API dependency injection system.
- """
- _request_context.set(context.to_dict())
-
-
-def get_current_trace_id() -> str | None:
- """
- Get the current request's trace_id.
-
- Returns:
- The trace_id if available, None otherwise.
- """
- context = _request_context.get()
- if context:
- return context.get("trace_id")
- return None
-
-
-def get_current_context() -> RequestContext | None:
- """
- Get the current request context.
-
- Returns:
- The current RequestContext if available, None otherwise.
- """
- context_dict = _request_context.get()
- if context_dict:
- ctx = RequestContext(trace_id=context_dict.get("trace_id"))
- ctx._data = context_dict.get("data", {}).copy()
- return ctx
- return None
-
-
-def require_context() -> RequestContext:
- """
- Get the current request context, raising an error if not available.
-
- Returns:
- The current RequestContext.
-
- Raises:
- RuntimeError: If called outside of a request context.
- """
- context = get_current_context()
- if context is None:
- raise RuntimeError(
- "No request context available. This function must be called within a request handler."
- )
- return context
-
-
-# Type for trace_id getter function
-TraceIdGetter = Callable[[], str | None]
-
-# Global variable to hold the trace_id getter function
-_trace_id_getter: TraceIdGetter | None = None
-
-
-def set_trace_id_getter(getter: TraceIdGetter) -> None:
- """
- Set a custom trace_id getter function.
-
- This allows the logging system to retrieve trace_id without importing
- API-specific general_modules.
- """
- global _trace_id_getter
- _trace_id_getter = getter
-
-
-def get_trace_id_for_logging() -> str | None:
- """
- Get trace_id for logging purposes.
-
- This function is used by the logging system and will use either
- the custom getter function or fall back to the default context.
- """
- if _trace_id_getter:
- try:
- return _trace_id_getter()
- except Exception:
- pass
- return get_current_trace_id()
-
-
-# Initialize the default trace_id getter
-set_trace_id_getter(get_current_trace_id)
diff --git a/src/memos/api/context/context_thread.py b/src/memos/api/context/context_thread.py
deleted file mode 100644
index 41de13a6..00000000
--- a/src/memos/api/context/context_thread.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import functools
-import threading
-
-from collections.abc import Callable
-from concurrent.futures import ThreadPoolExecutor
-from typing import Any, TypeVar
-
-from memos.api.context.context import (
- RequestContext,
- get_current_context,
- get_current_trace_id,
- set_request_context,
-)
-
-
-T = TypeVar("T")
-
-
-class ContextThread(threading.Thread):
- """
- Thread class that automatically propagates the main thread's trace_id to child threads.
- """
-
- def __init__(self, target, args=(), kwargs=None, **thread_kwargs):
- super().__init__(**thread_kwargs)
- self.target = target
- self.args = args
- self.kwargs = kwargs or {}
-
- self.main_trace_id = get_current_trace_id()
- self.main_context = get_current_context()
-
- def run(self):
- # Create a new RequestContext with the main thread's trace_id
- if self.main_context:
- # Copy the context data
- child_context = RequestContext(trace_id=self.main_trace_id)
- child_context._data = self.main_context._data.copy()
-
- # Set the context in the child thread
- set_request_context(child_context)
-
- # Run the target function
- self.target(*self.args, **self.kwargs)
-
-
-class ContextThreadPoolExecutor(ThreadPoolExecutor):
- """
- ThreadPoolExecutor that automatically propagates the main thread's trace_id to worker threads.
- """
-
- def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any:
- """
- Submit a callable to be executed with the given arguments.
- Automatically propagates the current thread's context to the worker thread.
- """
- main_trace_id = get_current_trace_id()
- main_context = get_current_context()
-
- @functools.wraps(fn)
- def wrapper(*args: Any, **kwargs: Any) -> Any:
- if main_context:
- # Create and set new context in worker thread
- child_context = RequestContext(trace_id=main_trace_id)
- child_context._data = main_context._data.copy()
- set_request_context(child_context)
-
- return fn(*args, **kwargs)
-
- return super().submit(wrapper, *args, **kwargs)
-
- def map(
- self,
- fn: Callable[..., T],
- *iterables: Any,
- timeout: float | None = None,
- chunksize: int = 1,
- ) -> Any:
- """
- Returns an iterator equivalent to map(fn, iter).
- Automatically propagates the current thread's context to worker threads.
- """
- main_trace_id = get_current_trace_id()
- main_context = get_current_context()
-
- @functools.wraps(fn)
- def wrapper(*args: Any, **kwargs: Any) -> Any:
- if main_context:
- # Create and set new context in worker thread
- child_context = RequestContext(trace_id=main_trace_id)
- child_context._data = main_context._data.copy()
- set_request_context(child_context)
-
- return fn(*args, **kwargs)
-
- return super().map(wrapper, *iterables, timeout=timeout, chunksize=chunksize)
diff --git a/src/memos/api/context/dependencies.py b/src/memos/api/context/dependencies.py
index d26cadaa..d163fa0d 100644
--- a/src/memos/api/context/dependencies.py
+++ b/src/memos/api/context/dependencies.py
@@ -1,8 +1,6 @@
import logging
-from fastapi import Depends, Header, Request
-
-from memos.api.context.context import RequestContext, set_request_context
+from memos.context.context import RequestContext, get_current_context
logger = logging.getLogger(__name__)
@@ -11,56 +9,17 @@
G = RequestContext
-def get_trace_id_from_header(
- trace_id: str | None = Header(None, alias="trace-id"),
- x_trace_id: str | None = Header(None, alias="x-trace-id"),
- g_trace_id: str | None = Header(None, alias="g-trace-id"),
-) -> str | None:
- """
- Extract trace_id from various possible headers.
-
- Priority: g-trace-id > x-trace-id > trace-id
- """
- return g_trace_id or x_trace_id or trace_id
-
-
-def get_request_context(
- request: Request, trace_id: str | None = Depends(get_trace_id_from_header)
-) -> RequestContext:
- """
- Get request context object with trace_id and request metadata.
-
- This function creates a RequestContext and automatically sets it
- in the global context for use throughout the request lifecycle.
- """
- # Create context object
- ctx = RequestContext(trace_id=trace_id)
-
- # Set the context globally for this request
- set_request_context(ctx)
-
- # Log request start
- logger.info(f"Request started with trace_id: {ctx.trace_id}")
-
- # Add request metadata to context
- ctx.set("method", request.method)
- ctx.set("path", request.url.path)
- ctx.set("client_ip", request.client.host if request.client else None)
-
- return ctx
-
-
-def get_g_object(trace_id: str | None = Depends(get_trace_id_from_header)) -> G:
+def get_g_object() -> G:
"""
Get Flask g-like object for the current request.
-
- This creates a RequestContext and sets it globally for access
- throughout the request lifecycle.
+ Returns the context created by middleware.
"""
- g = RequestContext(trace_id=trace_id)
- set_request_context(g)
- logger.info(f"Request g object created with trace_id: {g.trace_id}")
- return g
+ ctx = get_current_context()
+ if ctx is None:
+ raise RuntimeError(
+ "No request context available. Make sure RequestContextMiddleware is properly configured."
+ )
+ return ctx
def get_current_g() -> G | None:
@@ -70,8 +29,6 @@ def get_current_g() -> G | None:
Returns:
The current request's g object if available, None otherwise.
"""
- from memos.context import get_current_context
-
return get_current_context()
@@ -85,6 +42,9 @@ def require_g() -> G:
Raises:
RuntimeError: If called outside of a request context.
"""
- from memos.context import require_context
-
- return require_context()
+ ctx = get_current_context()
+ if ctx is None:
+ raise RuntimeError(
+ "No request context available. This function must be called within a request handler."
+ )
+ return ctx
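
A minimal sketch of how the reworked dependency is consumed in a route handler, assuming RequestContextMiddleware (below) is installed on the app; the route path and response body are illustrative only:

```python
# Illustrative only: consuming the new get_g_object dependency.
# Assumes RequestContextMiddleware is registered on the FastAPI app,
# so a request context already exists when the handler runs.
from typing import Annotated

from fastapi import APIRouter, Depends

from memos.api.context.dependencies import G, get_g_object

router = APIRouter()


@router.get("/whoami")  # hypothetical route
def whoami(g: Annotated[G, Depends(get_g_object)]):
    return {"trace_id": g.trace_id, "api_path": g.api_path}
```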
diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py
index 01f57a27..cb41428d 100644
--- a/src/memos/api/middleware/request_context.py
+++ b/src/memos/api/middleware/request_context.py
@@ -2,40 +2,25 @@
Request context middleware for automatic trace_id injection.
"""
-import logging
-import os
-
from collections.abc import Callable
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
-from memos.api.context.context import RequestContext, set_request_context
-
+import memos.log
-logger = logging.getLogger(__name__)
+from memos.context.context import RequestContext, generate_trace_id, set_request_context
-def generate_trace_id() -> str:
- """Generate a random trace_id."""
- return os.urandom(16).hex()
+logger = memos.log.get_logger(__name__)
def extract_trace_id_from_headers(request: Request) -> str | None:
"""Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
- trace_id = request.headers.get("g-trace-id")
- if trace_id:
- return trace_id
-
- trace_id = request.headers.get("x-trace-id")
- if trace_id:
- return trace_id
-
- trace_id = request.headers.get("trace-id")
- if trace_id:
- return trace_id
-
+ for header in ["g-trace-id", "x-trace-id", "trace-id"]:
+ if trace_id := request.headers.get(header):
+ return trace_id
return None
@@ -51,19 +36,12 @@ class RequestContextMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Extract or generate trace_id
- trace_id = extract_trace_id_from_headers(request)
- if not trace_id:
- trace_id = generate_trace_id()
+ trace_id = extract_trace_id_from_headers(request) or generate_trace_id()
# Create and set request context
- context = RequestContext(trace_id=trace_id)
+ context = RequestContext(trace_id=trace_id, api_path=request.url.path)
set_request_context(context)
- # Add request metadata to context
- context.set("method", request.method)
- context.set("path", request.url.path)
- context.set("client_ip", request.client.host if request.client else None)
-
# Log request start with parameters
params_log = {}
@@ -71,16 +49,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
if request.query_params:
params_log["query_params"] = dict(request.query_params)
- # Get request body if it's available
- try:
- params_log = await request.json()
- except Exception as e:
- logger.error(f"Error getting request body: {e}")
- # If body is not JSON or empty, ignore it
-
- logger.info(
- f"Request started: {request.method} {request.url.path} - Parameters: {params_log}"
- )
+ logger.info(f"Request started: {request.method} {request.url.path}, {params_log}")
# Process the request
response = await call_next(request)
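
A short sketch of the middleware wiring and the resulting header precedence (g-trace-id > x-trace-id > trace-id); the endpoint is illustrative only:

```python
# Sketch only: wiring the middleware and reading the trace_id downstream.
from fastapi import FastAPI

from memos.api.middleware.request_context import RequestContextMiddleware
from memos.context.context import get_current_trace_id

app = FastAPI()
app.add_middleware(RequestContextMiddleware)


@app.get("/ping")  # hypothetical endpoint
def ping():
    # Populated from g-trace-id > x-trace-id > trace-id, or generated if absent.
    return {"trace_id": get_current_trace_id()}
```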
diff --git a/src/memos/api/product_api.py b/src/memos/api/product_api.py
index 08940997..681644a0 100644
--- a/src/memos/api/product_api.py
+++ b/src/memos/api/product_api.py
@@ -17,9 +17,7 @@
version="1.0.1",
)
-# Add request context middleware (must be added first)
app.add_middleware(RequestContextMiddleware)
-
# Include routers
app.include_router(product_router)
@@ -35,5 +33,6 @@
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8001)
+ parser.add_argument("--workers", type=int, default=32)
args = parser.parse_args()
- uvicorn.run(app, host="0.0.0.0", port=args.port)
+ uvicorn.run("memos.api.product_api:app", host="0.0.0.0", port=args.port, workers=args.workers)
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index 60764769..7e425415 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -1,26 +1,14 @@
import uuid
-from typing import Generic, Literal, TypeAlias, TypeVar
+from typing import Generic, Literal, TypeVar
from pydantic import BaseModel, Field
-from typing_extensions import TypedDict
+# Import message types from core types module
+from memos.types import MessageDict
-T = TypeVar("T")
-
-
-# ─── Message Types ──────────────────────────────────────────────────────────────
-
-# Chat message roles
-MessageRole: TypeAlias = Literal["user", "assistant", "system"]
-
-
-# Message structure
-class MessageDict(TypedDict):
- """Typed dictionary for chat message dictionaries."""
- role: MessageRole
- content: str
+T = TypeVar("T")
class BaseRequest(BaseModel):
@@ -42,6 +30,7 @@ class UserRegisterRequest(BaseRequest):
user_id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="User ID for registration"
)
+ mem_cube_id: str | None = Field(None, description="Cube ID for registration")
user_name: str | None = Field(None, description="User name for registration")
interests: str | None = Field(None, description="User interests")
@@ -85,6 +74,7 @@ class ChatRequest(BaseRequest):
history: list[MessageDict] | None = Field(None, description="Chat history")
internet_search: bool = Field(True, description="Whether to use internet search")
moscube: bool = Field(False, description="Whether to use MemOSCube")
+ session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
class ChatCompleteRequest(BaseRequest):
@@ -99,6 +89,7 @@ class ChatCompleteRequest(BaseRequest):
base_prompt: str | None = Field(None, description="Base prompt to use for chat")
top_k: int = Field(10, description="Number of results to return")
threshold: float = Field(0.5, description="Threshold for filtering references")
+ session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
class UserCreate(BaseRequest):
@@ -160,6 +151,7 @@ class MemoryCreateRequest(BaseRequest):
mem_cube_id: str | None = Field(None, description="Cube ID")
source: str | None = Field(None, description="Source of the memory")
user_profile: bool = Field(False, description="User profile memory")
+ session_id: str | None = Field(None, description="Session id")
class SearchRequest(BaseRequest):
@@ -169,6 +161,7 @@ class SearchRequest(BaseRequest):
query: str = Field(..., description="Search query")
mem_cube_id: str | None = Field(None, description="Cube ID to search in")
top_k: int = Field(10, description="Number of results to return")
+ session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
class SuggestionRequest(BaseRequest):
@@ -177,3 +170,85 @@ class SuggestionRequest(BaseRequest):
user_id: str = Field(..., description="User ID")
language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
message: list[MessageDict] | None = Field(None, description="List of messages to store.")
+
+
+# ─── MemOS Client Response Models ──────────────────────────────────────────────
+
+
+class MessageDetail(BaseModel):
+ """Individual message detail model based on actual API response."""
+
+ model_config = {"extra": "allow"}
+
+
+class MemoryDetail(BaseModel):
+ """Individual memory detail model based on actual API response."""
+
+ model_config = {"extra": "allow"}
+
+
+class GetMessagesData(BaseModel):
+ """Data model for get messages response based on actual API."""
+
+ message_detail_list: list[MessageDetail] = Field(
+ default_factory=list, alias="memory_detail_list", description="List of message details"
+ )
+
+
+class SearchMemoryData(BaseModel):
+ """Data model for search memory response based on actual API."""
+
+ memory_detail_list: list[MemoryDetail] = Field(
+ default_factory=list, alias="memory_detail_list", description="List of memory details"
+ )
+ message_detail_list: list[MessageDetail] | None = Field(
+ None, alias="message_detail_list", description="List of message details (usually None)"
+ )
+
+
+class AddMessageData(BaseModel):
+ """Data model for add message response based on actual API."""
+
+ success: bool = Field(..., description="Operation success status")
+
+
+# ─── MemOS Response Models (Similar to OpenAI ChatCompletion) ──────────────────
+
+
+class MemOSGetMessagesResponse(BaseModel):
+ """Response model for get messages operation based on actual API."""
+
+ code: int = Field(..., description="Response status code")
+ message: str = Field(..., description="Response message")
+ data: GetMessagesData = Field(..., description="Messages data")
+
+ @property
+ def messages(self) -> list[MessageDetail]:
+ """Convenient access to message list."""
+ return self.data.message_detail_list
+
+
+class MemOSSearchResponse(BaseModel):
+ """Response model for search memory operation based on actual API."""
+
+ code: int = Field(..., description="Response status code")
+ message: str = Field(..., description="Response message")
+ data: SearchMemoryData = Field(..., description="Search results data")
+
+ @property
+ def memories(self) -> list[MemoryDetail]:
+ """Convenient access to memory list."""
+ return self.data.memory_detail_list
+
+
+class MemOSAddResponse(BaseModel):
+ """Response model for add message operation based on actual API."""
+
+ code: int = Field(..., description="Response status code")
+ message: str = Field(..., description="Response message")
+ data: AddMessageData = Field(..., description="Add operation data")
+
+ @property
+ def success(self) -> bool:
+ """Convenient access to success status."""
+ return self.data.success
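
The new client-side response models wrap the raw API payload and expose convenience properties; a small parsing sketch (payload values are illustrative):

```python
# Sketch: validating a raw search payload into MemOSSearchResponse.
from memos.api.product_models import MemOSSearchResponse

raw = {
    "code": 0,
    "message": "ok",
    "data": {
        "memory_detail_list": [{"memory": "User prefers espresso", "score": 0.91}],
        "message_detail_list": None,
    },
}

parsed = MemOSSearchResponse.model_validate(raw)
print(parsed.memories)  # convenience property -> data.memory_detail_list
```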
diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py
index a27e4e48..75b614cf 100644
--- a/src/memos/api/routers/product_router.py
+++ b/src/memos/api/routers/product_router.py
@@ -1,14 +1,11 @@
import json
+import time
import traceback
-from datetime import datetime
-from typing import Annotated
-
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from memos.api.config import APIConfig
-from memos.api.context.dependencies import G, get_g_object
from memos.api.product_models import (
BaseResponse,
ChatCompleteRequest,
@@ -79,24 +76,19 @@ def set_config(config):
@router.post("/users/register", summary="Register a new user", response_model=UserRegisterResponse)
-def register_user(user_req: UserRegisterRequest, g: Annotated[G, Depends(get_g_object)]):
+def register_user(user_req: UserRegisterRequest):
"""Register a new user with configuration and default cube."""
try:
- # Set request-related information in g object
- g.user_id = user_req.user_id
- g.action = "user_register"
- g.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-
- logger.info(f"Starting user registration for user_id: {user_req.user_id}")
- logger.info(f"Request trace_id: {g.trace_id}")
- logger.info(f"Request timestamp: {g.timestamp}")
-
# Get configuration for the user
+ time_start_register = time.time()
user_config, default_mem_cube = APIConfig.create_user_config(
user_name=user_req.user_id, user_id=user_req.user_id
)
logger.info(f"user_config: {user_config.model_dump(mode='json')}")
logger.info(f"default_mem_cube: {default_mem_cube.config.model_dump(mode='json')}")
+ logger.info(
+ f"time register api : create user config time user_id: {user_req.user_id} time is: {time.time() - time_start_register}"
+ )
mos_product = get_mos_product_instance()
# Register user with default config and mem cube
@@ -106,8 +98,11 @@ def register_user(user_req: UserRegisterRequest, g: Annotated[G, Depends(get_g_o
interests=user_req.interests,
config=user_config,
default_mem_cube=default_mem_cube,
+ mem_cube_id=user_req.mem_cube_id,
+ )
+ logger.info(
+ f"time register api : register time user_id: {user_req.user_id} time is: {time.time() - time_start_register}"
)
-
if result["status"] == "success":
return UserRegisterResponse(
message="User registered successfully",
@@ -194,6 +189,7 @@ def get_all_memories(memory_req: GetMemoryRequest):
def create_memory(memory_req: MemoryCreateRequest):
"""Create a new memory for a specific user."""
try:
+ time_start_add = time.time()
mos_product = get_mos_product_instance()
mos_product.add(
user_id=memory_req.user_id,
@@ -203,6 +199,10 @@ def create_memory(memory_req: MemoryCreateRequest):
mem_cube_id=memory_req.mem_cube_id,
source=memory_req.source,
user_profile=memory_req.user_profile,
+ session_id=memory_req.session_id,
+ )
+ logger.info(
+ f"time add api : add time user_id: {memory_req.user_id} time is: {time.time() - time_start_add}"
)
return SimpleResponse(message="Memory created successfully")
@@ -217,12 +217,17 @@ def create_memory(memory_req: MemoryCreateRequest):
def search_memories(search_req: SearchRequest):
"""Search memories for a specific user."""
try:
+ time_start_search = time.time()
mos_product = get_mos_product_instance()
result = mos_product.search(
query=search_req.query,
user_id=search_req.user_id,
install_cube_ids=[search_req.mem_cube_id] if search_req.mem_cube_id else None,
top_k=search_req.top_k,
+ session_id=search_req.session_id,
+ )
+ logger.info(
+ f"time search api : add time user_id: {search_req.user_id} time is: {time.time() - time_start_search}"
)
return SearchResponse(message="Search completed successfully", data=result)
@@ -250,6 +255,7 @@ def generate_chat_response():
history=chat_req.history,
internet_search=chat_req.internet_search,
moscube=chat_req.moscube,
+ session_id=chat_req.session_id,
)
except Exception as e:
@@ -294,6 +300,7 @@ def chat_complete(chat_req: ChatCompleteRequest):
base_prompt=chat_req.base_prompt,
top_k=chat_req.top_k,
threshold=chat_req.threshold,
+ session_id=chat_req.session_id,
)
# Return the complete response
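
With session_id plumbed through the request models and router, a caller can soft-filter memories per session; a request sketch follows (host, port, and route prefix are assumptions about the deployment):

```python
# Illustrative request against the updated search endpoint.
import httpx

resp = httpx.post(
    "http://localhost:8001/product/search",  # hypothetical URL/prefix
    json={
        "user_id": "u-1",
        "query": "favorite coffee",
        "top_k": 5,
        "session_id": "sess-42",  # new soft-filter field
    },
    headers={"g-trace-id": "trace-abc123"},  # picked up by RequestContextMiddleware
)
print(resp.json())
```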
diff --git a/src/memos/api/start_api.py b/src/memos/api/start_api.py
index 9f464a4a..cbcdf6ce 100644
--- a/src/memos/api/start_api.py
+++ b/src/memos/api/start_api.py
@@ -421,3 +421,13 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=500,
content={"code": 500, "message": str(exc), "data": None},
)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
+ parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development")
+ args = parser.parse_args()
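
The hunk above parses --port, --host, and --reload but does not show the server start; a sketch of how the parsed args would typically be consumed inside the `if __name__ == "__main__":` block, assuming the FastAPI app in this module is named `app`:

```python
# Sketch only (not shown in the hunk): launching the app with the parsed args.
import uvicorn

uvicorn.run("memos.api.start_api:app", host=args.host, port=args.port, reload=args.reload)
```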
diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py
index 01246b06..2df91716 100644
--- a/src/memos/configs/graph_db.py
+++ b/src/memos/configs/graph_db.py
@@ -140,6 +140,10 @@ class NebulaGraphDBConfig(BaseGraphDBConfig):
"If False: use a single shared database with logical isolation by user_name."
),
)
+    max_client: int = Field(
+        default=1000,
+        description="Maximum number of clients in the shared NebulaGraph client pool",
+    )
embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding")
@model_validator(mode="after")
diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py
index 4d62bd11..a36f3e2f 100644
--- a/src/memos/configs/mem_scheduler.py
+++ b/src/memos/configs/mem_scheduler.py
@@ -6,7 +6,7 @@
from pydantic import ConfigDict, Field, field_validator, model_validator
from memos.configs.base import BaseConfig
-from memos.mem_scheduler.general_modules.misc import DictConversionMixin
+from memos.mem_scheduler.general_modules.misc import DictConversionMixin, EnvConfigMixin
from memos.mem_scheduler.schemas.general_schemas import (
BASE_DIR,
DEFAULT_ACT_MEM_DUMP_PATH,
@@ -64,6 +64,19 @@ class GeneralSchedulerConfig(BaseSchedulerConfig):
default=20, description="Capacity of the activation memory monitor"
)
+ # Database configuration for ORM persistence
+ db_path: str | None = Field(
+ default=None,
+ description="Path to SQLite database file for ORM persistence. If None, uses default scheduler_orm.db",
+ )
+ db_url: str | None = Field(
+ default=None,
+ description="Database URL for ORM persistence (e.g., mysql://user:pass@host/db). Takes precedence over db_path",
+ )
+ enable_orm_persistence: bool = Field(
+ default=True, description="Whether to enable ORM-based persistence for monitors"
+ )
+
class SchedulerConfigFactory(BaseConfig):
"""Factory class for creating scheduler configurations."""
@@ -74,6 +87,7 @@ class SchedulerConfigFactory(BaseConfig):
model_config = ConfigDict(extra="forbid", strict=True)
backend_to_class: ClassVar[dict[str, Any]] = {
"general_scheduler": GeneralSchedulerConfig,
+ "optimized_scheduler": GeneralSchedulerConfig, # optimized_scheduler uses same config as general_scheduler
}
@field_validator("backend")
@@ -94,6 +108,8 @@ def create_config(self) -> "SchedulerConfigFactory":
# ************************* Auth *************************
class RabbitMQConfig(
BaseConfig,
+ DictConversionMixin,
+ EnvConfigMixin,
):
host_name: str = Field(default="", description="Endpoint for RabbitMQ instance access")
user_name: str = Field(default="", description="Static username for RabbitMQ instance")
@@ -110,7 +126,7 @@ class RabbitMQConfig(
)
-class GraphDBAuthConfig(BaseConfig):
+class GraphDBAuthConfig(BaseConfig, DictConversionMixin, EnvConfigMixin):
uri: str = Field(
default="bolt://localhost:7687",
description="URI for graph database access (e.g., bolt://host:port)",
@@ -127,7 +143,7 @@ class GraphDBAuthConfig(BaseConfig):
)
-class OpenAIConfig(BaseConfig):
+class OpenAIConfig(BaseConfig, DictConversionMixin, EnvConfigMixin):
api_key: str = Field(default="", description="API key for OpenAI service")
base_url: str = Field(default="", description="Base URL for API endpoint")
default_model: str = Field(default="", description="Default model to use")
@@ -183,6 +199,25 @@ def from_local_config(cls, config_path: str | Path | None = None) -> "AuthConfig
"Please use YAML (.yaml, .yml) or JSON (.json) files."
)
+ @classmethod
+ def from_local_env(cls) -> "AuthConfig":
+ """Creates an AuthConfig instance by loading configuration from environment variables.
+
+ This method loads configuration for all nested components (RabbitMQ, OpenAI, GraphDB)
+ from their respective environment variables using each component's specific prefix.
+
+ Returns:
+ AuthConfig: Configured instance with values from environment variables
+
+ Raises:
+ ValueError: If any required environment variables are missing
+ """
+ return cls(
+ rabbitmq=RabbitMQConfig.from_env(),
+ openai=OpenAIConfig.from_env(),
+ graph_db=GraphDBAuthConfig.from_env(),
+ )
+
def set_openai_config_to_environment(self):
# Set environment variables
os.environ["OPENAI_API_KEY"] = self.openai.api_key
diff --git a/src/memos/context/context.py b/src/memos/context/context.py
new file mode 100644
index 00000000..4f54348f
--- /dev/null
+++ b/src/memos/context/context.py
@@ -0,0 +1,255 @@
+"""
+Global request context management for trace_id and request-scoped data.
+
+This module provides optional trace_id functionality that can be enabled
+when using the API components. It uses ContextVar to ensure thread safety
+and request isolation.
+"""
+
+import functools
+import os
+import threading
+
+from collections.abc import Callable
+from concurrent.futures import ThreadPoolExecutor
+from contextvars import ContextVar
+from typing import Any, TypeVar
+
+
+T = TypeVar("T")
+
+# Global context variable for request-scoped data
+_request_context: ContextVar[dict[str, Any] | None] = ContextVar("request_context", default=None)
+
+
+class RequestContext:
+ """
+ Request-scoped context object that holds trace_id and other request data.
+
+ This provides a Flask g-like object for FastAPI applications.
+ """
+
+ def __init__(self, trace_id: str | None = None, api_path: str | None = None):
+ self.trace_id = trace_id or "trace-id"
+ self.api_path = api_path
+ self._data: dict[str, Any] = {}
+
+ def set(self, key: str, value: Any) -> None:
+ """Set a value in the context."""
+ self._data[key] = value
+
+ def get(self, key: str, default: Any | None = None) -> Any:
+ """Get a value from the context."""
+ return self._data.get(key, default)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ if name.startswith("_") or name in ("trace_id", "api_path"):
+ super().__setattr__(name, value)
+ else:
+ if not hasattr(self, "_data"):
+ super().__setattr__(name, value)
+ else:
+ self._data[name] = value
+
+ def __getattr__(self, name: str) -> Any:
+ if hasattr(self, "_data") and name in self._data:
+ return self._data[name]
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert context to dictionary."""
+ return {"trace_id": self.trace_id, "api_path": self.api_path, "data": self._data.copy()}
+
+
+def set_request_context(context: RequestContext) -> None:
+ """
+ Set the current request context.
+
+ This is typically called by the API dependency injection system.
+ """
+ _request_context.set(context.to_dict())
+
+
+def get_current_trace_id() -> str | None:
+ """
+ Get the current request's trace_id.
+
+ Returns:
+ The trace_id if available, None otherwise.
+ """
+ context = _request_context.get()
+ if context:
+ return context.get("trace_id")
+ return None
+
+
+def get_current_api_path() -> str | None:
+ """
+ Get the current request's api path.
+ """
+ context = _request_context.get()
+ if context:
+ return context.get("api_path")
+ return None
+
+
+def get_current_context() -> RequestContext | None:
+ """
+ Get the current request context.
+
+ Returns:
+ The current RequestContext if available, None otherwise.
+ """
+ context_dict = _request_context.get()
+ if context_dict:
+ ctx = RequestContext(
+ trace_id=context_dict.get("trace_id"), api_path=context_dict.get("api_path")
+ )
+ ctx._data = context_dict.get("data", {}).copy()
+ return ctx
+ return None
+
+
+def require_context() -> RequestContext:
+ """
+ Get the current request context, raising an error if not available.
+
+ Returns:
+ The current RequestContext.
+
+ Raises:
+ RuntimeError: If called outside of a request context.
+ """
+ context = get_current_context()
+ if context is None:
+ raise RuntimeError(
+ "No request context available. This function must be called within a request handler."
+ )
+ return context
+
+
+class ContextThread(threading.Thread):
+ """
+ Thread class that automatically propagates the main thread's trace_id to child threads.
+ """
+
+ def __init__(self, target, args=(), kwargs=None, **thread_kwargs):
+ super().__init__(**thread_kwargs)
+ self.target = target
+ self.args = args
+ self.kwargs = kwargs or {}
+
+ self.main_trace_id = get_current_trace_id()
+ self.main_api_path = get_current_api_path()
+ self.main_context = get_current_context()
+
+ def run(self):
+ # Create a new RequestContext with the main thread's trace_id
+ if self.main_context:
+ # Copy the context data
+ child_context = RequestContext(
+ trace_id=self.main_trace_id, api_path=self.main_context.api_path
+ )
+ child_context._data = self.main_context._data.copy()
+
+ # Set the context in the child thread
+ set_request_context(child_context)
+
+ # Run the target function
+ self.target(*self.args, **self.kwargs)
+
+
+class ContextThreadPoolExecutor(ThreadPoolExecutor):
+ """
+ ThreadPoolExecutor that automatically propagates the main thread's trace_id to worker threads.
+ """
+
+ def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any:
+ """
+ Submit a callable to be executed with the given arguments.
+ Automatically propagates the current thread's context to the worker thread.
+ """
+ main_trace_id = get_current_trace_id()
+ main_api_path = get_current_api_path()
+ main_context = get_current_context()
+
+ @functools.wraps(fn)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ if main_context:
+ # Create and set new context in worker thread
+ child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path)
+ child_context._data = main_context._data.copy()
+ set_request_context(child_context)
+
+ return fn(*args, **kwargs)
+
+ return super().submit(wrapper, *args, **kwargs)
+
+ def map(
+ self,
+ fn: Callable[..., T],
+ *iterables: Any,
+ timeout: float | None = None,
+ chunksize: int = 1,
+ ) -> Any:
+ """
+ Returns an iterator equivalent to map(fn, iter).
+ Automatically propagates the current thread's context to worker threads.
+ """
+ main_trace_id = get_current_trace_id()
+ main_api_path = get_current_api_path()
+ main_context = get_current_context()
+
+ @functools.wraps(fn)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ if main_context:
+ # Create and set new context in worker thread
+ child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path)
+ child_context._data = main_context._data.copy()
+ set_request_context(child_context)
+
+ return fn(*args, **kwargs)
+
+ return super().map(wrapper, *iterables, timeout=timeout, chunksize=chunksize)
+
+
+# Type for trace_id getter function
+TraceIdGetter = Callable[[], str | None]
+
+# Global variable to hold the trace_id getter function
+_trace_id_getter: TraceIdGetter | None = None
+
+
+def generate_trace_id() -> str:
+ """Generate a random trace_id."""
+ return os.urandom(16).hex()
+
+
+def set_trace_id_getter(getter: TraceIdGetter) -> None:
+ """
+ Set a custom trace_id getter function.
+
+ This allows the logging system to retrieve trace_id without importing
+ API-specific general_modules.
+ """
+ global _trace_id_getter
+ _trace_id_getter = getter
+
+
+def get_trace_id_for_logging() -> str | None:
+ """
+ Get trace_id for logging purposes.
+
+ This function is used by the logging system and will use either
+ the custom getter function or fall back to the default context.
+ """
+ if _trace_id_getter:
+ try:
+ return _trace_id_getter()
+ except Exception:
+ pass
+ return get_current_trace_id()
+
+
+# Initialize the default trace_id getter
+set_trace_id_getter(get_current_trace_id)
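
A small sketch of the trace propagation the new module provides; the trace value is illustrative (in the API it is set by RequestContextMiddleware):

```python
# Sketch: trace_id propagation from the submitting thread into a worker thread.
from memos.context.context import (
    ContextThreadPoolExecutor,
    RequestContext,
    get_current_trace_id,
    set_request_context,
)

set_request_context(RequestContext(trace_id="trace-123", api_path="/demo"))

with ContextThreadPoolExecutor(max_workers=2) as pool:
    future = pool.submit(get_current_trace_id)
    print(future.result())  # -> "trace-123", copied from the submitting thread
```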
diff --git a/src/memos/embedders/factory.py b/src/memos/embedders/factory.py
index b15ad7c4..be14db9e 100644
--- a/src/memos/embedders/factory.py
+++ b/src/memos/embedders/factory.py
@@ -6,6 +6,7 @@
from memos.embedders.ollama import OllamaEmbedder
from memos.embedders.sentence_transformer import SenTranEmbedder
from memos.embedders.universal_api import UniversalAPIEmbedder
+from memos.memos_tools.singleton import singleton_factory
class EmbedderFactory(BaseEmbedder):
@@ -19,6 +20,7 @@ class EmbedderFactory(BaseEmbedder):
}
@classmethod
+ @singleton_factory()
def from_config(cls, config_factory: EmbedderConfigFactory) -> BaseEmbedder:
backend = config_factory.backend
if backend not in cls.backend_to_class:
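
The decorated factory is expected to return a cached embedder for an equivalent config; a sketch of the intended effect (the config module path and the ollama field names are assumptions, and the decorator's cache-key strategy is not shown in this diff):

```python
# Sketch: expected reuse behaviour of the decorated factory.
from memos.configs.embedder import EmbedderConfigFactory  # assumed module path
from memos.embedders.factory import EmbedderFactory

cfg = EmbedderConfigFactory.model_validate(
    {"backend": "ollama", "config": {"model_name_or_path": "nomic-embed-text"}}  # assumed fields
)

e1 = EmbedderFactory.from_config(cfg)
e2 = EmbedderFactory.from_config(cfg)
assert e1 is e2  # same cached instance for an equivalent config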
diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py
index 5ca8c895..66ad894a 100644
--- a/src/memos/graph_dbs/nebular.py
+++ b/src/memos/graph_dbs/nebular.py
@@ -3,7 +3,6 @@
from contextlib import suppress
from datetime import datetime
-from queue import Empty, Queue
from threading import Lock
from typing import TYPE_CHECKING, Any, ClassVar, Literal
@@ -17,12 +16,27 @@
if TYPE_CHECKING:
- from nebulagraph_python.client.pool import NebulaPool
+ from nebulagraph_python import (
+ NebulaClient,
+ )
logger = get_logger(__name__)
+_TRANSIENT_ERR_KEYS = (
+ "Session not found",
+ "Connection not established",
+ "timeout",
+ "deadline exceeded",
+ "Broken pipe",
+ "EOFError",
+ "socket closed",
+ "connection reset",
+ "connection refused",
+)
+
+
@timed
def _normalize(vec: list[float]) -> list[float]:
v = np.asarray(vec, dtype=np.float32)
@@ -87,137 +101,6 @@ def _normalize_datetime(val):
return str(val)
-class SessionPoolError(Exception):
- pass
-
-
-class SessionPool:
- @require_python_package(
- import_name="nebulagraph_python",
- install_command="pip install ... @Tianxing",
- install_link=".....",
- )
- def __init__(
- self,
- hosts: list[str],
- user: str,
- password: str,
- minsize: int = 1,
- maxsize: int = 10000,
- ):
- self.hosts = hosts
- self.user = user
- self.password = password
- self.minsize = minsize
- self.maxsize = maxsize
- self.pool = Queue(maxsize)
- self.lock = Lock()
-
- self.clients = []
-
- for _ in range(minsize):
- self._create_and_add_client()
-
- @timed
- def _create_and_add_client(self):
- from nebulagraph_python import NebulaClient
-
- client = NebulaClient(self.hosts, self.user, self.password)
- self.pool.put(client)
- self.clients.append(client)
-
- @timed
- def get_client(self, timeout: float = 5.0):
- try:
- return self.pool.get(timeout=timeout)
- except Empty:
- with self.lock:
- if len(self.clients) < self.maxsize:
- from nebulagraph_python import NebulaClient
-
- client = NebulaClient(self.hosts, self.user, self.password)
- self.clients.append(client)
- return client
- raise RuntimeError("NebulaClientPool exhausted") from None
-
- @timed
- def return_client(self, client):
- try:
- client.execute("YIELD 1")
- self.pool.put(client)
- except Exception:
- logger.info("[Pool] Client dead, replacing...")
- self.replace_client(client)
-
- @timed
- def close(self):
- for client in self.clients:
- with suppress(Exception):
- client.close()
- self.clients.clear()
-
- @timed
- def get(self):
- """
- Context manager: with pool.get() as client:
- """
-
- class _ClientContext:
- def __init__(self, outer):
- self.outer = outer
- self.client = None
-
- def __enter__(self):
- self.client = self.outer.get_client()
- return self.client
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if self.client:
- self.outer.return_client(self.client)
-
- return _ClientContext(self)
-
- @timed
- def reset_pool(self):
- """⚠️ Emergency reset: Close all clients and clear the pool."""
- logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
- with self.lock:
- for client in self.clients:
- try:
- client.close()
- except Exception:
- logger.error("Fail to close!!!")
- self.clients.clear()
- while not self.pool.empty():
- try:
- self.pool.get_nowait()
- except Empty:
- break
- for _ in range(self.minsize):
- self._create_and_add_client()
- logger.info("[Pool] Pool has been reset successfully.")
-
- @timed
- def replace_client(self, client):
- try:
- client.close()
- except Exception:
- logger.error("Fail to close client")
-
- if client in self.clients:
- self.clients.remove(client)
-
- from nebulagraph_python import NebulaClient
-
- new_client = NebulaClient(self.hosts, self.user, self.password)
- self.clients.append(new_client)
-
- self.pool.put(new_client)
-
- logger.info("[Pool] Replaced dead client with a new one.")
- return new_client
-
-
class NebulaGraphDB(BaseGraphDB):
"""
NebulaGraph-based implementation of a graph memory store.
@@ -226,94 +109,194 @@ class NebulaGraphDB(BaseGraphDB):
# ====== shared pool cache & refcount ======
# These are process-local; in a multi-process model each process will
# have its own cache.
- _POOL_CACHE: ClassVar[dict[str, "NebulaPool"]] = {}
- _POOL_REFCOUNT: ClassVar[dict[str, int]] = {}
- _POOL_LOCK: ClassVar[Lock] = Lock()
+ _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
+ _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
+ _CLIENT_LOCK: ClassVar[Lock] = Lock()
+ _CLIENT_INIT_DONE: ClassVar[set[str]] = set()
@staticmethod
- def _make_pool_key(cfg: NebulaGraphDBConfig) -> str:
- """
- Build a cache key that captures all connection-affecting options.
- Keep this key stable and include fields that change the underlying pool behavior.
- """
- # NOTE: Do not include tenant-like or query-scope-only fields here.
- # Only include things that affect the actual TCP/auth/session pool.
+ def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
+ hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None)
+ if isinstance(hosts, str):
+ return [hosts]
+ return list(hosts or [])
+
+ @staticmethod
+ def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
+ hosts = NebulaGraphDB._get_hosts_from_cfg(cfg)
return "|".join(
[
- "nebula",
- str(getattr(cfg, "uri", "")),
+ "nebula-sync",
+ ",".join(hosts),
str(getattr(cfg, "user", "")),
- str(getattr(cfg, "password", "")),
- # pool sizing / tls / timeouts if you have them in config:
- str(getattr(cfg, "max_client", 1000)),
- # multi-db mode can impact how we use sessions; keep it to be safe
str(getattr(cfg, "use_multi_db", False)),
+ str(getattr(cfg, "space", "")),
]
)
@classmethod
- def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
- """
- Get a shared NebulaPool from cache or create one if missing.
- Thread-safe with a lock; maintains a simple refcount.
- """
- key = cls._make_pool_key(cfg)
-
- with cls._POOL_LOCK:
- pool = cls._POOL_CACHE.get(key)
- if pool is None:
- # Create a new pool and put into cache
- pool = SessionPool(
- hosts=cfg.get("uri"),
- user=cfg.get("user"),
- password=cfg.get("password"),
- minsize=1,
- maxsize=cfg.get("max_client", 1000),
+ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
+ tmp = object.__new__(NebulaGraphDB)
+ tmp.config = cfg
+ tmp.db_name = cfg.space
+ tmp.user_name = getattr(cfg, "user_name", None)
+ tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
+ tmp.default_memory_dimension = 3072
+ tmp.common_fields = {
+ "id",
+ "memory",
+ "user_name",
+ "user_id",
+ "session_id",
+ "status",
+ "key",
+ "confidence",
+ "tags",
+ "created_at",
+ "updated_at",
+ "memory_type",
+ "sources",
+ "source",
+ "node_type",
+ "visibility",
+ "usage",
+ "background",
+ }
+ tmp.base_fields = set(tmp.common_fields) - {"usage"}
+ tmp.heavy_fields = {"usage"}
+ tmp.dim_field = (
+ f"embedding_{tmp.embedding_dimension}"
+ if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
+ else "embedding"
+ )
+ tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space
+ tmp._client = client
+ tmp._owns_client = False
+ return tmp
+
+ @classmethod
+ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
+ from nebulagraph_python import (
+ ConnectionConfig,
+ NebulaClient,
+ SessionConfig,
+ SessionPoolConfig,
+ )
+
+ key = cls._make_client_key(cfg)
+ with cls._CLIENT_LOCK:
+ client = cls._CLIENT_CACHE.get(key)
+ if client is None:
+ # Connection setting
+ conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
+ if conn_conf is None:
+ conn_conf = ConnectionConfig.from_defults(
+ cls._get_hosts_from_cfg(cfg),
+ getattr(cfg, "ssl_param", None),
+ )
+
+ sess_conf = SessionConfig(graph=getattr(cfg, "space", None))
+ pool_conf = SessionPoolConfig(
+ size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000
+ )
+
+ client = NebulaClient(
+ hosts=conn_conf.hosts,
+ username=cfg.user,
+ password=cfg.password,
+ conn_config=conn_conf,
+ session_config=sess_conf,
+ session_pool_config=pool_conf,
)
- cls._POOL_CACHE[key] = pool
- cls._POOL_REFCOUNT[key] = 0
- logger.info(f"[NebulaGraphDB] Created new shared NebulaPool for key={key}")
+ cls._CLIENT_CACHE[key] = client
+ cls._CLIENT_REFCOUNT[key] = 0
+ logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")
- # Increase refcount for the caller
- cls._POOL_REFCOUNT[key] = cls._POOL_REFCOUNT.get(key, 0) + 1
- return key, pool
+ cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
+
+ if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
+ try:
+ pass
+ finally:
+ pass
+
+ if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
+ with cls._CLIENT_LOCK:
+ if key not in cls._CLIENT_INIT_DONE:
+ admin = cls._bootstrap_admin(cfg, client)
+ try:
+ admin._ensure_database_exists()
+ admin._create_basic_property_indexes()
+ admin._create_vector_index(
+ dimensions=int(
+ admin.embedding_dimension or admin.default_memory_dimension
+ ),
+ )
+ cls._CLIENT_INIT_DONE.add(key)
+ logger.info("[NebulaGraphDBSync] One-time init done")
+ except Exception:
+ logger.exception("[NebulaGraphDBSync] One-time init failed")
+
+ return key, client
+
+ def _refresh_client(self):
+ """
+ refresh NebulaClient:
+ """
+ old_key = getattr(self, "_client_key", None)
+ if not old_key:
+ return
+
+ cls = self.__class__
+ with cls._CLIENT_LOCK:
+ try:
+ if old_key in cls._CLIENT_CACHE:
+ try:
+ cls._CLIENT_CACHE[old_key].close()
+ except Exception as e:
+ logger.warning(f"[refresh_client] close old client error: {e}")
+ finally:
+ cls._CLIENT_CACHE.pop(old_key, None)
+ finally:
+ cls._CLIENT_REFCOUNT[old_key] = 0
+
+ new_key, new_client = cls._get_or_create_shared_client(self.config)
+ self._client_key = new_key
+ self._client = new_client
+ logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")
@classmethod
- def _release_shared_pool(cls, key: str):
- """
- Decrease refcount for the given pool key; only close when refcount hits zero.
- """
- with cls._POOL_LOCK:
- if key not in cls._POOL_CACHE:
+ def _release_shared_client(cls, key: str):
+ with cls._CLIENT_LOCK:
+ if key not in cls._CLIENT_CACHE:
return
- cls._POOL_REFCOUNT[key] = max(0, cls._POOL_REFCOUNT.get(key, 0) - 1)
- if cls._POOL_REFCOUNT[key] == 0:
+ cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1)
+ if cls._CLIENT_REFCOUNT[key] == 0:
try:
- cls._POOL_CACHE[key].close()
+ cls._CLIENT_CACHE[key].close()
except Exception as e:
- logger.warning(f"[NebulaGraphDB] Error closing shared pool: {e}")
+ logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}")
finally:
- cls._POOL_CACHE.pop(key, None)
- cls._POOL_REFCOUNT.pop(key, None)
- logger.info(f"[NebulaGraphDB] Closed and removed shared pool key={key}")
+ cls._CLIENT_CACHE.pop(key, None)
+ cls._CLIENT_REFCOUNT.pop(key, None)
+ logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}")
@classmethod
- def close_all_shared_pools(cls):
- """Force close all cached pools. Call this on graceful shutdown."""
- with cls._POOL_LOCK:
- for key, pool in list(cls._POOL_CACHE.items()):
+ def close_all_shared_clients(cls):
+ with cls._CLIENT_LOCK:
+ for key, client in list(cls._CLIENT_CACHE.items()):
try:
- pool.close()
+ client.close()
except Exception as e:
- logger.warning(f"[NebulaGraphDB] Error closing pool key={key}: {e}")
+ logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}")
finally:
- logger.info(f"[NebulaGraphDB] Closed pool key={key}")
- cls._POOL_CACHE.clear()
- cls._POOL_REFCOUNT.clear()
+ logger.info(f"[NebulaGraphDBSync] Closed client key={key}")
+ cls._CLIENT_CACHE.clear()
+ cls._CLIENT_REFCOUNT.clear()
@require_python_package(
import_name="nebulagraph_python",
- install_command="pip install ... @Tianxing",
+ install_command="pip install nebulagraph-python>=5.1.1",
install_link=".....",
)
def __init__(self, config: NebulaGraphDBConfig):
@@ -371,34 +354,32 @@ def __init__(self, config: NebulaGraphDBConfig):
# ---- NEW: pool acquisition strategy
# Get or create a shared pool from the class-level cache
- self._pool_key, self.pool = self._get_or_create_shared_pool(config)
- self._owns_pool = True # We manage refcount for this instance
-
- # auto-create graph type / graph / index if needed
- if config.auto_create:
- self._ensure_database_exists()
-
- self.execute_query(f"SESSION SET GRAPH `{self.db_name}`")
-
- # Create only if not exists
- self.create_index(dimensions=config.embedding_dimension)
+ self._client_key, self._client = self._get_or_create_shared_client(config)
+ self._owns_client = True
logger.info("Connected to NebulaGraph successfully.")
@timed
- def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
- with self.pool.get() as client:
- try:
- if auto_set_db and self.db_name:
- client.execute(f"SESSION SET GRAPH `{self.db_name}`")
- return client.execute(gql, timeout=timeout)
+ def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
+ def _wrap_use_db(q: str) -> str:
+ if auto_set_db and self.db_name:
+ return f"USE `{self.db_name}`\n{q}"
+ return q
- except Exception as e:
- if "Session not found" in str(e) or "Connection not established" in str(e):
- logger.warning(f"[execute_query] {e!s}, replacing client...")
- self.pool.replace_client(client)
- return self.execute_query(gql, timeout, auto_set_db)
- raise
+ try:
+ return self._client.execute(_wrap_use_db(gql), timeout=timeout)
+
+ except Exception as e:
+ emsg = str(e)
+ if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
+ logger.warning(f"[execute_query] {e!s} → refreshing session pool and retry once...")
+ try:
+ self._refresh_client()
+ return self._client.execute(_wrap_use_db(gql), timeout=timeout)
+ except Exception:
+ logger.exception("[execute_query] retry after refresh failed")
+ raise
+ raise
@timed
def close(self):
@@ -409,13 +390,13 @@ def close(self):
- If pool was acquired via shared cache, decrement refcount and close
when the last owner releases it.
"""
- if not self._owns_pool:
- logger.debug("[NebulaGraphDB] close() skipped (injected pool).")
+ if not self._owns_client:
+ logger.debug("[NebulaGraphDBSync] close() skipped (injected client).")
return
- if self._pool_key:
- self._release_shared_pool(self._pool_key)
- self._pool_key = None
- self.pool = None
+ if self._client_key:
+ self._release_shared_client(self._client_key)
+ self._client_key = None
+ self._client = None
# NOTE: __del__ is best-effort; do not rely on GC order.
def __del__(self):
@@ -972,6 +953,7 @@ def search_by_embedding(
scope: str | None = None,
status: str | None = None,
threshold: float | None = None,
+ search_filter: dict | None = None,
**kwargs,
) -> list[dict]:
"""
@@ -984,6 +966,8 @@ def search_by_embedding(
status (str, optional): Node status filter (e.g., 'active', 'archived').
If provided, restricts results to nodes with matching status.
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
+ search_filter (dict, optional): Additional metadata filters for search results.
+ Keys should match node properties, values are the expected values.
Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
@@ -993,6 +977,7 @@ def search_by_embedding(
- If scope is provided, it restricts results to nodes with matching memory_type.
- If 'status' is provided, only nodes with the matching status will be returned.
- If threshold is provided, only results with score >= threshold will be returned.
+ - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
- Typical use case: restrict to 'status = activated' to avoid
matching archived or merged nodes.
"""
@@ -1012,10 +997,17 @@ def search_by_embedding(
else:
where_clauses.append(f'n.user_name = "{self.config.user_name}"')
+ # Add search_filter conditions
+ if search_filter:
+ for key, value in search_filter.items():
+ if isinstance(value, str):
+ where_clauses.append(f'n.{key} = "{value}"')
+ else:
+ where_clauses.append(f"n.{key} = {value}")
+
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
gql = f"""
- USE `{self.db_name}`
MATCH (n@Memory)
{where_clause}
ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
@@ -1038,7 +1030,7 @@ def search_by_embedding(
id_val = values[0].as_string()
score_val = values[1].as_double()
score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score
- if threshold is None or score_val <= threshold:
+ if threshold is None or score_val >= threshold:
output.append({"id": id_val, "score": score_val})
return output
except Exception as e:
@@ -1368,9 +1360,9 @@ def get_structure_optimization_candidates(
where_clause += f' AND n.user_name = "{self.config.user_name}"'
return_fields = self._build_return_fields(include_embedding)
+ return_fields += f", n.{self.dim_field} AS {self.dim_field}"
query = f"""
- USE `{self.db_name}`
MATCH (n@Memory)
WHERE {where_clause}
OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
@@ -1380,11 +1372,16 @@ def get_structure_optimization_candidates(
"""
candidates = []
+ node_ids = set()
try:
results = self.execute_query(query)
for row in results:
props = {k: v.value for k, v in row.items()}
- candidates.append(self._parse_node(props))
+ node = self._parse_node(props)
+ node_id = node["id"]
+ if node_id not in node_ids:
+ candidates.append(node)
+ node_ids.add(node_id)
except Exception as e:
logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
return candidates
@@ -1538,18 +1535,19 @@ def _ensure_database_exists(self):
logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")
create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
- set_graph_working = f"SESSION SET GRAPH `{self.db_name}`"
-
try:
self.execute_query(create_graph, auto_set_db=False)
- self.execute_query(set_graph_working)
logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
except Exception as e:
logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
@timed
def _create_vector_index(
- self, label: str, vector_property: str, dimensions: int, index_name: str
+ self,
+ label: str = "Memory",
+ vector_property: str = "embedding",
+ dimensions: int = 3072,
+ index_name: str = "memory_vector_index",
) -> None:
"""
Create a vector index for the specified property in the label.
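
A lifecycle sketch of the shared-client design: construction reuses a cached NebulaClient keyed on host/user/space, execute_query prefixes statements with USE and retries once after refreshing the client on transient errors, and close() only releases a refcount. Config field values below are placeholders, and any field name not referenced in this diff is an assumption:

```python
# Sketch: shared-client lifecycle (config values are placeholders).
from memos.configs.graph_db import NebulaGraphDBConfig
from memos.graph_dbs.nebular import NebulaGraphDB

cfg = NebulaGraphDBConfig(
    uri="127.0.0.1:9669", user="root", password="nebula",
    space="memos", user_name="u-1", use_multi_db=False, max_client=100,
)
db = NebulaGraphDB(cfg)   # reuses the cached NebulaClient for this key, if any
rows = db.execute_query("MATCH (n@Memory) RETURN n.id LIMIT 3")
db.close()                # decrements the refcount; the client closes with the last owner
```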
diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index b3a4a265..96908913 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -1,3 +1,4 @@
+import json
import time
from datetime import datetime
@@ -174,6 +175,12 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
n.updated_at = datetime($updated_at),
n += $metadata
"""
+
+    # serialize source entries to JSON strings so they can be stored as Neo4j properties
+    if metadata.get("sources"):
+ for idx in range(len(metadata["sources"])):
+ metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
+
with self.driver.session(database=self.db_name) as session:
session.run(
query,
@@ -606,6 +613,7 @@ def search_by_embedding(
scope: str | None = None,
status: str | None = None,
threshold: float | None = None,
+ search_filter: dict | None = None,
**kwargs,
) -> list[dict]:
"""
@@ -618,6 +626,8 @@ def search_by_embedding(
status (str, optional): Node status filter (e.g., 'active', 'archived').
If provided, restricts results to nodes with matching status.
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
+ search_filter (dict, optional): Additional metadata filters for search results.
+ Keys should match node properties, values are the expected values.
Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
@@ -627,6 +637,7 @@ def search_by_embedding(
- If scope is provided, it restricts results to nodes with matching memory_type.
- If 'status' is provided, only nodes with the matching status will be returned.
- If threshold is provided, only results with score >= threshold will be returned.
+ - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
- Typical use case: restrict to 'status = activated' to avoid
matching archived or merged nodes.
"""
@@ -639,6 +650,12 @@ def search_by_embedding(
if not self.config.use_multi_db and self.config.user_name:
where_clauses.append("node.user_name = $user_name")
+ # Add search_filter conditions
+ if search_filter:
+ for key, _ in search_filter.items():
+ param_name = f"filter_{key}"
+ where_clauses.append(f"node.{key} = ${param_name}")
+
where_clause = ""
if where_clauses:
where_clause = "WHERE " + " AND ".join(where_clauses)
@@ -650,7 +667,8 @@ def search_by_embedding(
RETURN node.id AS id, score
"""
- parameters = {"embedding": vector, "k": top_k, "scope": scope}
+ parameters = {"embedding": vector, "k": top_k}
+
if scope:
parameters["scope"] = scope
if status:
@@ -661,6 +679,12 @@ def search_by_embedding(
else:
parameters["user_name"] = self.config.user_name
+ # Add search_filter parameters
+ if search_filter:
+ for key, value in search_filter.items():
+ param_name = f"filter_{key}"
+ parameters[param_name] = value
+
with self.driver.session(database=self.db_name) as session:
result = session.run(query, parameters)
records = [{"id": record["id"], "score": record["score"]} for record in result]
@@ -1111,4 +1135,14 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
node[time_field] = node[time_field].isoformat()
node.pop("user_name", None)
+    # deserialize source entries that add_node stored as JSON strings
+    if node.get("sources"):
+        for idx, source in enumerate(node["sources"]):
+            if not (
+                isinstance(source, str)
+                and source.startswith("{")
+                and source.endswith("}")
+            ):
+                continue
+            node["sources"][idx] = json.loads(source)
return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}
diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py
index 500e2839..8acab420 100644
--- a/src/memos/graph_dbs/neo4j_community.py
+++ b/src/memos/graph_dbs/neo4j_community.py
@@ -129,6 +129,7 @@ def search_by_embedding(
scope: str | None = None,
status: str | None = None,
threshold: float | None = None,
+ search_filter: dict | None = None,
**kwargs,
) -> list[dict]:
"""
@@ -140,6 +141,7 @@ def search_by_embedding(
scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
status (str, optional): Node status filter (e.g., 'activated', 'archived').
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
+ search_filter (dict, optional): Additional metadata filters to apply.
Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
@@ -149,6 +151,7 @@ def search_by_embedding(
- If 'scope' is provided, it restricts results to nodes with matching memory_type.
- If 'status' is provided, it further filters nodes by status.
- If 'threshold' is provided, only results with score >= threshold will be returned.
+ - If 'search_filter' is provided, it applies additional metadata-based filtering.
- The returned IDs can be used to fetch full node data from Neo4j if needed.
"""
# Build VecDB filter
@@ -163,6 +166,10 @@ def search_by_embedding(
else:
vec_filter["user_name"] = self.config.user_name
+ # Add search_filter conditions
+ if search_filter:
+ vec_filter.update(search_filter)
+
# Perform vector search
results = self.vec_db.search(query_vector=vector, top_k=top_k, filter=vec_filter)
diff --git a/src/memos/llms/factory.py b/src/memos/llms/factory.py
index 0c12a667..8589d775 100644
--- a/src/memos/llms/factory.py
+++ b/src/memos/llms/factory.py
@@ -9,6 +9,7 @@
from memos.llms.openai import AzureLLM, OpenAILLM
from memos.llms.qwen import QwenLLM
from memos.llms.vllm import VLLMLLM
+from memos.memos_tools.singleton import singleton_factory
class LLMFactory(BaseLLM):
@@ -26,6 +27,7 @@ class LLMFactory(BaseLLM):
}
@classmethod
+ @singleton_factory()
def from_config(cls, config_factory: LLMConfigFactory) -> BaseLLM:
backend = config_factory.backend
if backend not in cls.backend_to_class:
diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py
index 148f6a2c..698bc326 100644
--- a/src/memos/llms/openai.py
+++ b/src/memos/llms/openai.py
@@ -1,4 +1,8 @@
+import hashlib
+import json
+
from collections.abc import Generator
+from typing import ClassVar
import openai
@@ -13,11 +17,44 @@
class OpenAILLM(BaseLLM):
- """OpenAI LLM class."""
+ """OpenAI LLM class with singleton pattern."""
+
+ _instances: ClassVar[dict] = {} # Class variable to store instances
+
+ def __new__(cls, config: OpenAILLMConfig) -> "OpenAILLM":
+ config_hash = cls._get_config_hash(config)
+
+ if config_hash not in cls._instances:
+ logger.info(f"Creating new OpenAI LLM instance for config hash: {config_hash}")
+ instance = super().__new__(cls)
+ cls._instances[config_hash] = instance
+ else:
+ logger.info(f"Reusing existing OpenAI LLM instance for config hash: {config_hash}")
+
+ return cls._instances[config_hash]
def __init__(self, config: OpenAILLMConfig):
+ # Avoid duplicate initialization
+ if hasattr(self, "_initialized"):
+ return
+
self.config = config
self.client = openai.Client(api_key=config.api_key, base_url=config.api_base)
+ self._initialized = True
+ logger.info("OpenAI LLM instance initialized")
+
+ @classmethod
+ def _get_config_hash(cls, config: OpenAILLMConfig) -> str:
+ """Generate hash value of configuration"""
+ config_dict = config.model_dump()
+ config_str = json.dumps(config_dict, sort_keys=True)
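+ # MD5 is used here only as a cache key for identical configs, not for security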
+ return hashlib.md5(config_str.encode()).hexdigest()
+
+ @classmethod
+ def clear_cache(cls):
+ """Clear all cached instances"""
+ cls._instances.clear()
+ logger.info("OpenAI LLM instance cache cleared")
def generate(self, messages: MessageList) -> str:
"""Generate a response from OpenAI LLM."""
@@ -71,15 +108,50 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, Non
class AzureLLM(BaseLLM):
- """Azure OpenAI LLM class."""
+ """Azure OpenAI LLM class with singleton pattern."""
+
+ _instances: ClassVar[dict] = {} # Class variable to store instances
+
+ def __new__(cls, config: AzureLLMConfig):
+ # Generate hash value of config as cache key
+ config_hash = cls._get_config_hash(config)
+
+ if config_hash not in cls._instances:
+ logger.info(f"Creating new Azure LLM instance for config hash: {config_hash}")
+ instance = super().__new__(cls)
+ cls._instances[config_hash] = instance
+ else:
+ logger.info(f"Reusing existing Azure LLM instance for config hash: {config_hash}")
+
+ return cls._instances[config_hash]
def __init__(self, config: AzureLLMConfig):
+ # Avoid duplicate initialization
+ if hasattr(self, "_initialized"):
+ return
+
self.config = config
self.client = openai.AzureOpenAI(
azure_endpoint=config.base_url,
api_version=config.api_version,
api_key=config.api_key,
)
+ self._initialized = True
+ logger.info("Azure LLM instance initialized")
+
+ @classmethod
+ def _get_config_hash(cls, config: AzureLLMConfig) -> str:
+ """Generate hash value of configuration"""
+ # Convert config to dict and sort to ensure consistency
+ config_dict = config.model_dump()
+ config_str = json.dumps(config_dict, sort_keys=True)
+ return hashlib.md5(config_str.encode()).hexdigest()
+
+ @classmethod
+ def clear_cache(cls):
+ """Clear all cached instances"""
+ cls._instances.clear()
+ logger.info("Azure LLM instance cache cleared")
def generate(self, messages: MessageList) -> str:
"""Generate a response from Azure OpenAI LLM."""
diff --git a/src/memos/log.py b/src/memos/log.py
index a5b6648f..339d13f2 100644
--- a/src/memos/log.py
+++ b/src/memos/log.py
@@ -2,7 +2,9 @@
import logging
import os
import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
from logging.config import dictConfig
from pathlib import Path
from sys import stdout
@@ -12,8 +14,7 @@
from dotenv import load_dotenv
from memos import settings
-from memos.api.context.context import get_current_trace_id
-from memos.api.context.context_thread import ContextThreadPoolExecutor
+from memos.context.context import get_current_api_path, get_current_trace_id
# Load environment variables
@@ -39,9 +40,9 @@ class TraceIDFilter(logging.Filter):
def filter(self, record):
try:
trace_id = get_current_trace_id()
- record.trace_id = trace_id if trace_id else "no-trace-id"
+ record.trace_id = trace_id if trace_id else "trace-id"
except Exception:
- record.trace_id = "no-trace-id"
+ record.trace_id = "trace-id"
return True
@@ -65,7 +66,7 @@ def __init__(self):
if not self._initialized:
super().__init__()
workers = int(os.getenv("CUSTOM_LOGGER_WORKERS", "2"))
- self._executor = ContextThreadPoolExecutor(
+ self._executor = ThreadPoolExecutor(
max_workers=workers, thread_name_prefix="log_sender"
)
self._is_shutting_down = threading.Event()
@@ -78,21 +79,32 @@ def emit(self, record):
if os.getenv("CUSTOM_LOGGER_URL") is None or self._is_shutting_down.is_set():
return
+ # Only forward INFO-level and above logs
+ if record.levelno < logging.INFO:  # Skip DEBUG and lower
+ return
+
try:
- trace_id = get_current_trace_id() or "no-trace-id"
- self._executor.submit(self._send_log_sync, record.getMessage(), trace_id)
+ trace_id = get_current_trace_id() or "trace-id"
+ api_path = get_current_api_path()
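+ # Forward the log only when the current context carries an API path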
+ if api_path is not None:
+ self._executor.submit(self._send_log_sync, record.getMessage(), trace_id, api_path)
except Exception as e:
if not self._is_shutting_down.is_set():
print(f"Error sending log: {e}")
- def _send_log_sync(self, message, trace_id):
+ def _send_log_sync(self, message, trace_id, api_path):
"""Send log message synchronously in a separate thread"""
try:
logger_url = os.getenv("CUSTOM_LOGGER_URL")
token = os.getenv("CUSTOM_LOGGER_TOKEN")
headers = {"Content-Type": "application/json"}
- post_content = {"message": message, "trace_id": trace_id}
+ post_content = {
+ "message": message,
+ "trace_id": trace_id,
+ "action": api_path,
+ "current_time": round(time.time(), 3),
+ }
# Add auth token if exists
if token:
@@ -139,7 +151,7 @@ def close(self):
"format": "[%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
},
"simplified": {
- "format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s | %(message)s"
+ "format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s"
},
},
"filters": {
@@ -151,7 +163,7 @@ def close(self):
"level": selected_log_level,
"class": "logging.StreamHandler",
"stream": stdout,
- "formatter": "simplified",
+ "formatter": "no_datetime",
"filters": ["package_tree_filter", "trace_id_filter"],
},
"file": {
@@ -160,18 +172,18 @@ def close(self):
"filename": _setup_logfile(),
"maxBytes": 1024**2 * 10,
"backupCount": 10,
- "formatter": "simplified",
+ "formatter": "standard",
"filters": ["trace_id_filter"],
},
"custom_logger": {
- "level": selected_log_level,
+ "level": "INFO",
"class": "memos.log.CustomLoggerRequestHandler",
"formatter": "simplified",
},
},
"root": { # Root logger handles all logs
- "level": selected_log_level,
- "handlers": ["console", "file", "custom_logger"],
+ "level": logging.DEBUG if settings.DEBUG else logging.INFO,
+ "handlers": ["console", "file"],
},
"loggers": {
"memos": {
diff --git a/src/memos/mem_cube/general.py b/src/memos/mem_cube/general.py
index 7217c354..17e45809 100644
--- a/src/memos/mem_cube/general.py
+++ b/src/memos/mem_cube/general.py
@@ -1,4 +1,5 @@
import os
+import time
from typing import Literal
@@ -23,11 +24,13 @@ class GeneralMemCube(BaseMemCube):
def __init__(self, config: GeneralMemCubeConfig):
"""Initialize the MemCube with a configuration."""
self.config = config
+ time_start = time.time()
self._text_mem: BaseTextMemory | None = (
MemoryFactory.from_config(config.text_mem)
if config.text_mem.backend != "uninitialized"
else None
)
+ logger.info(f"init_text_mem in {time.time() - time_start} seconds")
self._act_mem: BaseActMemory | None = (
MemoryFactory.from_config(config.act_mem)
if config.act_mem.backend != "uninitialized"
@@ -137,7 +140,6 @@ def init_from_dir(
if default_config is not None:
config = merge_config_with_default(config, default_config)
logger.info(f"Applied default config to cube {config.cube_id}")
-
mem_cube = GeneralMemCube(config)
mem_cube.load(dir, memory_types)
return mem_cube
diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py
index a201e22c..54e507b5 100644
--- a/src/memos/mem_os/core.py
+++ b/src/memos/mem_os/core.py
@@ -24,7 +24,7 @@
from memos.memories.activation.item import ActivationMemoryItem
from memos.memories.parametric.item import ParametricMemoryItem
from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
-from memos.memos_tools.thread_safe_dict import ThreadSafeDict
+from memos.memos_tools.thread_safe_dict_segment import OptimizedThreadSafeDict
from memos.templates.mos_prompts import QUERY_REWRITING_PROMPT
from memos.types import ChatHistory, MessageList, MOSSearchResult
@@ -47,8 +47,8 @@ def __init__(self, config: MOSConfig, user_manager: UserManager | None = None):
self.mem_reader = MemReaderFactory.from_config(config.mem_reader)
self.chat_history_manager: dict[str, ChatHistory] = {}
# use thread safe dict for multi-user product-server scenario
- self.mem_cubes: ThreadSafeDict[str, GeneralMemCube] = (
- ThreadSafeDict() if user_manager is not None else {}
+ self.mem_cubes: OptimizedThreadSafeDict[str, GeneralMemCube] = (
+ OptimizedThreadSafeDict() if user_manager is not None else {}
)
self._register_chat_history()
@@ -125,12 +125,16 @@ def _initialize_mem_scheduler(self) -> GeneralScheduler:
"missing required 'llm' attribute"
)
self._mem_scheduler.initialize_modules(
- chat_llm=self.chat_llm, process_llm=self.chat_llm
+ chat_llm=self.chat_llm,
+ process_llm=self.chat_llm,
+ db_engine=self.user_manager.engine,
)
else:
# Configure scheduler general_modules
self._mem_scheduler.initialize_modules(
- chat_llm=self.chat_llm, process_llm=self.mem_reader.llm
+ chat_llm=self.chat_llm,
+ process_llm=self.mem_reader.llm,
+ db_engine=self.user_manager.engine,
)
self._mem_scheduler.start()
return self._mem_scheduler
@@ -182,13 +186,13 @@ def mem_reorganizer_wait(self) -> bool:
logger.info(f"close reorganizer for {mem_cube.text_mem.config.cube_id}")
mem_cube.text_mem.memory_manager.wait_reorganizer()
- def _register_chat_history(self, user_id: str | None = None) -> None:
+ def _register_chat_history(
+ self, user_id: str | None = None, session_id: str | None = None
+ ) -> None:
"""Initialize chat history with user ID."""
- if user_id is None:
- user_id = self.user_id
+ user_id = user_id if user_id is not None else self.user_id
self.chat_history_manager[user_id] = ChatHistory(
user_id=user_id,
- session_id=self.session_id,
+ session_id=session_id if session_id is not None else self.session_id,
created_at=datetime.utcnow(),
total_messages=0,
chat_history=[],
@@ -483,14 +487,14 @@ def register_mem_cube(
self.mem_cubes[mem_cube_id] = mem_cube_name_or_path
logger.info(f"register new cube {mem_cube_id} for user {target_user_id}")
elif os.path.exists(mem_cube_name_or_path):
- self.mem_cubes[mem_cube_id] = GeneralMemCube.init_from_dir(mem_cube_name_or_path)
+ mem_cube_obj = GeneralMemCube.init_from_dir(mem_cube_name_or_path)
+ self.mem_cubes[mem_cube_id] = mem_cube_obj
else:
logger.warning(
f"MemCube {mem_cube_name_or_path} does not exist, try to init from remote repo."
)
- self.mem_cubes[mem_cube_id] = GeneralMemCube.init_from_remote_repo(
- mem_cube_name_or_path
- )
+ mem_cube_obj = GeneralMemCube.init_from_remote_repo(mem_cube_name_or_path)
+ self.mem_cubes[mem_cube_id] = mem_cube_obj
# Check if cube already exists in database
existing_cube = self.user_manager.get_cube(mem_cube_id)
@@ -547,6 +551,7 @@ def search(
mode: Literal["fast", "fine"] = "fast",
internet_search: bool = False,
moscube: bool = False,
+ session_id: str | None = None,
**kwargs,
) -> MOSSearchResult:
"""
@@ -562,7 +567,9 @@ def search(
Returns:
MemoryResult: A dictionary containing the search results.
"""
+ target_session_id = session_id if session_id is not None else self.session_id
target_user_id = user_id if user_id is not None else self.user_id
+
self._validate_user_exists(target_user_id)
# Get all cubes accessible by the target user
accessible_cubes = self.user_manager.get_user_cubes(target_user_id)
@@ -575,6 +582,11 @@ def search(
self._register_chat_history(target_user_id)
chat_history = self.chat_history_manager[target_user_id]
+ # Create search filter if session_id is provided
+ search_filter = None
+ if session_id is not None:
+ search_filter = {"session_id": session_id}
+
result: MOSSearchResult = {
"text_mem": [],
"act_mem": [],
@@ -584,9 +596,13 @@ def search(
install_cube_ids = user_cube_ids
# create exist dict in mem_cubes and avoid one search slow
tmp_mem_cubes = {}
+ time_start_cube_get = time.time()
for mem_cube_id in install_cube_ids:
if mem_cube_id in self.mem_cubes:
tmp_mem_cubes[mem_cube_id] = self.mem_cubes.get(mem_cube_id)
+ logger.info(
+ f"time search: transform cube time user_id: {target_user_id} time is: {time.time() - time_start_cube_get}"
+ )
for mem_cube_id, mem_cube in tmp_mem_cubes.items():
if (
@@ -602,10 +618,11 @@ def search(
manual_close_internet=not internet_search,
info={
"user_id": target_user_id,
- "session_id": self.session_id,
+ "session_id": target_session_id,
"chat_history": chat_history.chat_history,
},
moscube=moscube,
+ search_filter=search_filter,
)
result["text_mem"].append({"cube_id": mem_cube_id, "memories": memories})
logger.info(
@@ -624,6 +641,8 @@ def add(
doc_path: str | None = None,
mem_cube_id: str | None = None,
user_id: str | None = None,
+ session_id: str | None = None,
+ **kwargs,
) -> None:
"""
Add textual memories to a MemCube.
@@ -636,11 +655,16 @@ def add(
If None, the default MemCube for the user is used.
user_id (str, optional): The identifier of the user to add the memories to.
If None, the default user is used.
+ session_id (str, optional): The session identifier. If None, the default session ID is used.
"""
# user input messages
assert (messages is not None) or (memory_content is not None) or (doc_path is not None), (
"messages_or_doc_path or memory_content or doc_path must be provided."
)
+ # TODO: ensure that session_id is a valid string
+ time_start = time.time()
+
+ target_session_id = session_id if session_id else self.session_id
target_user_id = user_id if user_id is not None else self.user_id
if mem_cube_id is None:
# Try to find a default cube for the user
@@ -652,18 +676,29 @@ def add(
mem_cube_id = accessible_cubes[0].cube_id # TODO not only first
else:
self._validate_cube_access(target_user_id, mem_cube_id)
+ logger.info(
+ f"time add: get mem_cube_id time user_id: {target_user_id} time is: {time.time() - time_start}"
+ )
+ time_start_0 = time.time()
if mem_cube_id not in self.mem_cubes:
raise ValueError(f"MemCube '{mem_cube_id}' is not loaded. Please register.")
+ logger.info(
+ f"time add: get mem_cube_id check in mem_cubes time user_id: {target_user_id} time is: {time.time() - time_start_0}"
+ )
+ time_start_1 = time.time()
if (
(messages is not None)
and self.config.enable_textual_memory
and self.mem_cubes[mem_cube_id].text_mem
):
+ logger.info(
+ f"time add: messages is not None and enable_textual_memory and text_mem is not None time user_id: {target_user_id} time is: {time.time() - time_start_1}"
+ )
if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text":
add_memory = []
metadata = TextualMemoryMetadata(
- user_id=target_user_id, session_id=self.session_id, source="conversation"
+ user_id=target_user_id, session_id=target_session_id, source="conversation"
)
for message in messages:
add_memory.append(
@@ -672,12 +707,15 @@ def add(
self.mem_cubes[mem_cube_id].text_mem.add(add_memory)
else:
messages_list = [messages]
+ time_start_2 = time.time()
memories = self.mem_reader.get_memory(
messages_list,
type="chat",
- info={"user_id": target_user_id, "session_id": self.session_id},
+ info={"user_id": target_user_id, "session_id": target_session_id},
+ )
+ logger.info(
+ f"time add: get mem_reader time user_id: {target_user_id} time is: {time.time() - time_start_2}"
)
-
mem_ids = []
for mem in memories:
mem_id_list: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(mem)
@@ -707,7 +745,7 @@ def add(
):
if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text":
metadata = TextualMemoryMetadata(
- user_id=self.user_id, session_id=self.session_id, source="conversation"
+ user_id=target_user_id, session_id=target_session_id, source="conversation"
)
self.mem_cubes[mem_cube_id].text_mem.add(
[TextualMemoryItem(memory=memory_content, metadata=metadata)]
@@ -719,7 +757,7 @@ def add(
memories = self.mem_reader.get_memory(
messages_list,
type="chat",
- info={"user_id": target_user_id, "session_id": self.session_id},
+ info={"user_id": target_user_id, "session_id": target_session_id},
)
mem_ids = []
@@ -753,7 +791,7 @@ def add(
doc_memories = self.mem_reader.get_memory(
documents,
type="doc",
- info={"user_id": target_user_id, "session_id": self.session_id},
+ info={"user_id": target_user_id, "session_id": target_session_id},
)
mem_ids = []
@@ -986,7 +1024,7 @@ def load(
def get_user_info(self) -> dict[str, Any]:
"""Get current user information including accessible cubes.
-
+ TODO: maybe accept a user_id parameter
Returns:
dict: User information and accessible cubes.
"""
diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py
index 2520c8fd..2e5b3254 100644
--- a/src/memos/mem_os/main.py
+++ b/src/memos/mem_os/main.py
@@ -5,6 +5,7 @@
from typing import Any
from memos.configs.mem_os import MOSConfig
+from memos.context.context import ContextThreadPoolExecutor
from memos.llms.factory import LLMFactory
from memos.log import get_logger
from memos.mem_os.core import MOSCore
@@ -487,9 +488,7 @@ def generate_answer_for_question(question_index: int, sub_question: str) -> tupl
# Generate answers in parallel while maintaining order
sub_answers = [None] * len(sub_questions)
- with concurrent.futures.ThreadPoolExecutor(
- max_workers=min(len(sub_questions), 10)
- ) as executor:
+ with ContextThreadPoolExecutor(max_workers=min(len(sub_questions), 10)) as executor:
# Submit all answer generation tasks
future_to_index = {
executor.submit(generate_answer_for_question, i, question): i
@@ -552,9 +551,7 @@ def search_single_question(question: str) -> list[Any]:
# Search in parallel while maintaining order
all_memories = []
- with concurrent.futures.ThreadPoolExecutor(
- max_workers=min(len(sub_questions), 10)
- ) as executor:
+ with ContextThreadPoolExecutor(max_workers=min(len(sub_questions), 10)) as executor:
# Submit all search tasks and keep track of their order
future_to_index = {
executor.submit(search_single_question, question): i
diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py
index 5899b680..a4ab4ef2 100644
--- a/src/memos/mem_os/product.py
+++ b/src/memos/mem_os/product.py
@@ -2,7 +2,6 @@
import json
import os
import random
-import threading
import time
from collections.abc import Generator
@@ -14,6 +13,7 @@
from memos.configs.mem_cube import GeneralMemCubeConfig
from memos.configs.mem_os import MOSConfig
+from memos.context.context import ContextThread
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_os.core import MOSCore
@@ -46,6 +46,7 @@
get_memos_prompt,
)
from memos.types import MessageList
+from memos.utils import timed
logger = get_logger(__name__)
@@ -257,6 +258,7 @@ def _preload_user_cubes(
except Exception as e:
logger.error(f"Error pre-loading cubes for user {user_id}: {e}", exc_info=True)
+ @timed
def _load_user_cubes(
self, user_id: str, default_cube_config: GeneralMemCubeConfig | None = None
) -> None:
@@ -288,6 +290,7 @@ def _load_user_cubes(
)
except Exception as e:
logger.error(f"Failed to load cube {cube.cube_id} for user {user_id}: {e}")
+ logger.info(f"load user {user_id} cubes successfully")
def _ensure_user_instance(self, user_id: str, max_instances: int | None = None) -> None:
"""
@@ -694,8 +697,8 @@ def run_async_in_thread():
else None
)
except RuntimeError:
- # No event loop, run in a new thread
- thread = threading.Thread(
+ # No event loop, run in a new thread with context propagation
+ thread = ContextThread(
target=run_async_in_thread,
name=f"PostChatProcessing-{user_id}",
# Set as a daemon thread to avoid blocking program exit
@@ -775,10 +778,14 @@ def register_mem_cube(
return
# Create MemCube from path
+ time_start = time.time()
if os.path.exists(mem_cube_name_or_path):
mem_cube = GeneralMemCube.init_from_dir(
mem_cube_name_or_path, memory_types, default_config
)
+ logger.info(
+ f"time register_mem_cube: init_from_dir time is: {time.time() - time_start}"
+ )
else:
logger.warning(
f"MemCube {mem_cube_name_or_path} does not exist, try to init from remote repo."
@@ -791,7 +798,10 @@ def register_mem_cube(
logger.info(
f"Registering MemCube {mem_cube_id} with cube config {mem_cube.config.model_dump(mode='json')}"
)
+ time_start = time.time()
self.mem_cubes[mem_cube_id] = mem_cube
+ time_end = time.time()
+ logger.info(f"time register_mem_cube: add mem_cube time is: {time_end - time_start}")
def user_register(
self,
@@ -801,6 +811,7 @@ def user_register(
interests: str | None = None,
default_mem_cube: GeneralMemCube | None = None,
default_cube_config: GeneralMemCubeConfig | None = None,
+ mem_cube_id: str | None = None,
) -> dict[str, str]:
"""Register a new user with configuration and default cube.
@@ -836,15 +847,19 @@ def user_register(
default_cube_name = f"{user_name}_{user_id}_default_cube"
mem_cube_name_or_path = os.path.join(CUBE_PATH, default_cube_name)
default_cube_id = self.create_cube_for_user(
- cube_name=default_cube_name, owner_id=user_id, cube_path=mem_cube_name_or_path
+ cube_name=default_cube_name,
+ owner_id=user_id,
+ cube_path=mem_cube_name_or_path,
+ cube_id=mem_cube_id,
)
-
+ time_start = time.time()
if default_mem_cube:
try:
- default_mem_cube.dump(mem_cube_name_or_path)
+ default_mem_cube.dump(mem_cube_name_or_path, memory_types=[])
except Exception as e:
logger.error(f"Failed to dump default cube: {e}")
-
+ time_end = time.time()
+ logger.info(f"time user_register: dump default cube time is: {time_end - time_start}")
# Register the default cube with MOS
self.register_mem_cube(
mem_cube_name_or_path_or_object=default_mem_cube,
@@ -924,6 +939,7 @@ def chat(
moscube: bool = False,
top_k: int = 10,
threshold: float = 0.5,
+ session_id: str | None = None,
) -> str:
"""
Chat with LLM with memory references and complete response.
@@ -938,6 +954,7 @@ def chat(
mode="fine",
internet_search=internet_search,
moscube=moscube,
+ session_id=session_id,
)["text_mem"]
memories_list = []
@@ -982,6 +999,7 @@ def chat_with_references(
top_k: int = 20,
internet_search: bool = False,
moscube: bool = False,
+ session_id: str | None = None,
) -> Generator[str, None, None]:
"""
Chat with LLM with memory references and streaming output.
@@ -1008,6 +1026,7 @@ def chat_with_references(
mode="fine",
internet_search=internet_search,
moscube=moscube,
+ session_id=session_id,
)["text_mem"]
yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
@@ -1028,7 +1047,7 @@ def chat_with_references(
system_prompt = self._build_enhance_system_prompt(user_id, memories_list)
# Get chat history
if user_id not in self.chat_history_manager:
- self._register_chat_history(user_id)
+ self._register_chat_history(user_id, session_id)
chat_history = self.chat_history_manager[user_id]
if history:
@@ -1296,6 +1315,7 @@ def search(
install_cube_ids: list[str] | None = None,
top_k: int = 10,
mode: Literal["fast", "fine"] = "fast",
+ session_id: str | None = None,
):
"""Search memories for a specific user."""
@@ -1306,7 +1326,9 @@ def search(
logger.info(
f"time search: load_user_cubes time user_id: {user_id} time is: {load_user_cubes_time_end - time_start}"
)
- search_result = super().search(query, user_id, install_cube_ids, top_k, mode=mode)
+ search_result = super().search(
+ query, user_id, install_cube_ids, top_k, mode=mode, session_id=session_id
+ )
search_time_end = time.time()
logger.info(
f"time search: search text_mem time user_id: {user_id} time is: {search_time_end - load_user_cubes_time_end}"
@@ -1342,13 +1364,15 @@ def add(
mem_cube_id: str | None = None,
source: str | None = None,
user_profile: bool = False,
+ session_id: str | None = None,
):
"""Add memory for a specific user."""
# Load user cubes if not already loaded
self._load_user_cubes(user_id, self.default_cube_config)
-
- result = super().add(messages, memory_content, doc_path, mem_cube_id, user_id)
+ result = super().add(
+ messages, memory_content, doc_path, mem_cube_id, user_id, session_id=session_id
+ )
if user_profile:
try:
user_interests = memory_content.split("'userInterests': '")[1].split("', '")[0]
diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py
index 7997c821..52eed8d9 100644
--- a/src/memos/mem_reader/factory.py
+++ b/src/memos/mem_reader/factory.py
@@ -3,6 +3,7 @@
from memos.configs.mem_reader import MemReaderConfigFactory
from memos.mem_reader.base import BaseMemReader
from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.memos_tools.singleton import singleton_factory
class MemReaderFactory(BaseMemReader):
@@ -13,6 +14,7 @@ class MemReaderFactory(BaseMemReader):
}
@classmethod
+ @singleton_factory()
def from_config(cls, config_factory: MemReaderConfigFactory) -> BaseMemReader:
backend = config_factory.backend
if backend not in cls.backend_to_class:
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 2b0bbc5d..b439cb2b 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -13,6 +13,7 @@
from memos.chunkers import ChunkerFactory
from memos.configs.mem_reader import SimpleStructMemReaderConfig
from memos.configs.parser import ParserConfigFactory
+from memos.context.context import ContextThreadPoolExecutor
from memos.embedders.factory import EmbedderFactory
from memos.llms.factory import LLMFactory
from memos.mem_reader.base import BaseMemReader
@@ -26,6 +27,7 @@
SIMPLE_STRUCT_MEM_READER_PROMPT,
SIMPLE_STRUCT_MEM_READER_PROMPT_ZH,
)
+from memos.utils import timed
logger = log.get_logger(__name__)
@@ -55,45 +57,60 @@ def detect_lang(text):
def _build_node(idx, message, info, scene_file, llm, parse_json_result, embedder):
# generate
- raw = llm.generate(message)
- if not raw:
+ try:
+ raw = llm.generate(message)
+ if not raw:
+ logger.warning(f"[LLM] Empty generation for input: {message}")
+ return None
+ except Exception as e:
+ logger.error(f"[LLM] Exception during generation: {e}")
return None
# parse_json_result
- chunk_res = parse_json_result(raw)
- if not chunk_res:
+ try:
+ chunk_res = parse_json_result(raw)
+ if not chunk_res:
+ logger.warning(f"[Parse] Failed to parse result: {raw}")
+ return None
+ except Exception as e:
+ logger.error(f"[Parse] Exception during JSON parsing: {e}")
return None
- value = chunk_res.get("value")
- if not value:
+ try:
+ value = chunk_res.get("value", "").strip()
+ if not value:
+ logger.warning("[BuildNode] value is empty")
+ return None
+
+ tags = chunk_res.get("tags", [])
+ if not isinstance(tags, list):
+ tags = []
+
+ key = chunk_res.get("key", None)
+
+ embedding = embedder.embed([value])[0]
+
+ return TextualMemoryItem(
+ memory=value,
+ metadata=TreeNodeTextualMemoryMetadata(
+ user_id=info.get("user_id", ""),
+ session_id=info.get("session_id", ""),
+ memory_type="LongTermMemory",
+ status="activated",
+ tags=tags,
+ key=key,
+ embedding=embedding,
+ usage=[],
+ sources=[{"type": "doc", "doc_path": f"{scene_file}_{idx}"}],
+ background="",
+ confidence=0.99,
+ type="fact",
+ ),
+ )
+ except Exception as e:
+ logger.error(f"[BuildNode] Error building node: {e}")
return None
- # embed
- embedding = embedder.embed([value])[0]
-
- # TextualMemoryItem
- tags = chunk_res["tags"] if isinstance(chunk_res.get("tags"), list) else []
- key = chunk_res.get("key", None)
-
- node_i = TextualMemoryItem(
- memory=value,
- metadata=TreeNodeTextualMemoryMetadata(
- user_id=info.get("user_id"),
- session_id=info.get("session_id"),
- memory_type="LongTermMemory",
- status="activated",
- tags=tags,
- key=key,
- embedding=embedding,
- usage=[],
- sources=[f"{scene_file}_{idx}"],
- background="",
- confidence=0.99,
- type="fact",
- ),
- )
- return node_i
-
class SimpleStructMemReader(BaseMemReader, ABC):
"""Naive implementation of MemReader."""
@@ -110,44 +127,77 @@ def __init__(self, config: SimpleStructMemReaderConfig):
self.embedder = EmbedderFactory.from_config(config.embedder)
self.chunker = ChunkerFactory.from_config(config.chunker)
+ @timed
def _process_chat_data(self, scene_data_info, info):
- lang = detect_lang("\n".join(scene_data_info))
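+ # scene_data_info now holds raw message dicts (see get_scene_data_info), so rebuild the "role: [time]: content" strings here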
+ mem_list = []
+ for item in scene_data_info:
+ if "chat_time" in item:
+ mem = item["role"] + ": " + f"[{item['chat_time']}]: " + item["content"]
+ mem_list.append(mem)
+ else:
+ mem = item["role"] + ":" + item["content"]
+ mem_list.append(mem)
+ lang = detect_lang("\n".join(mem_list))
template = PROMPT_DICT["chat"][lang]
examples = PROMPT_DICT["chat"][f"{lang}_example"]
- prompt = template.replace("${conversation}", "\n".join(scene_data_info))
+ prompt = template.replace("${conversation}", "\n".join(mem_list))
if self.config.remove_prompt_example:
prompt = prompt.replace(examples, "")
messages = [{"role": "user", "content": prompt}]
- response_text = self.llm.generate(messages)
- response_json = self.parse_json_result(response_text)
+ try:
+ response_text = self.llm.generate(messages)
+ response_json = self.parse_json_result(response_text)
+ except Exception as e:
+ logger.error(f"[LLM] Exception during chat generation: {e}")
+ response_json = {
+ "memory list": [
+ {
+ "key": "\n".join(mem_list)[:10],
+ "memory_type": "UserMemory",
+ "value": "\n".join(mem_list),
+ "tags": [],
+ }
+ ],
+ "summary": "\n".join(mem_list),
+ }
chat_read_nodes = []
for memory_i_raw in response_json.get("memory list", []):
- node_i = TextualMemoryItem(
- memory=memory_i_raw.get("value", ""),
- metadata=TreeNodeTextualMemoryMetadata(
- user_id=info.get("user_id"),
- session_id=info.get("session_id"),
- memory_type=memory_i_raw.get("memory_type", "")
+ try:
+ memory_type = (
+ memory_i_raw.get("memory_type", "LongTermMemory")
.replace("长期记忆", "LongTermMemory")
- .replace("用户记忆", "UserMemory"),
- status="activated",
- tags=memory_i_raw.get("tags", [])
- if type(memory_i_raw.get("tags", [])) is list
- else [],
- key=memory_i_raw.get("key", ""),
- embedding=self.embedder.embed([memory_i_raw.get("value", "")])[0],
- usage=[],
- sources=scene_data_info,
- background=response_json.get("summary", ""),
- confidence=0.99,
- type="fact",
- ),
- )
- chat_read_nodes.append(node_i)
+ .replace("用户记忆", "UserMemory")
+ )
+
+ if memory_type not in ["LongTermMemory", "UserMemory"]:
+ memory_type = "LongTermMemory"
+
+ node_i = TextualMemoryItem(
+ memory=memory_i_raw.get("value", ""),
+ metadata=TreeNodeTextualMemoryMetadata(
+ user_id=info.get("user_id"),
+ session_id=info.get("session_id"),
+ memory_type=memory_type,
+ status="activated",
+ tags=memory_i_raw.get("tags", [])
+ if type(memory_i_raw.get("tags", [])) is list
+ else [],
+ key=memory_i_raw.get("key", ""),
+ embedding=self.embedder.embed([memory_i_raw.get("value", "")])[0],
+ usage=[],
+ sources=scene_data_info,
+ background=response_json.get("summary", ""),
+ confidence=0.99,
+ type="fact",
+ ),
+ )
+ chat_read_nodes.append(node_i)
+ except Exception as e:
+ logger.error(f"[ChatReader] Error parsing memory item: {e}")
return chat_read_nodes
@@ -200,8 +250,8 @@ def get_memory(
else:
processing_func = self._process_doc_data
- # Process Q&A pairs concurrently
- with concurrent.futures.ThreadPoolExecutor() as executor:
+ # Process Q&A pairs concurrently with context propagation
+ with ContextThreadPoolExecutor() as executor:
futures = [
executor.submit(processing_func, scene_data_info, info)
for scene_data_info in list_scene_data_info
@@ -239,11 +289,9 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]:
for item in items:
# Convert dictionary to string
if "chat_time" in item:
- mem = item["role"] + ": " + f"[{item['chat_time']}]: " + item["content"]
- result.append(mem)
+ result.append(item)
else:
- mem = item["role"] + ":" + item["content"]
- result.append(mem)
+ result.append(item)
if len(result) >= 10:
results.append(result)
context = copy.deepcopy(result[-2:])
@@ -254,17 +302,21 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]:
for item in scene_data:
try:
if os.path.exists(item):
- parsed_text = parser.parse(item)
- results.append({"file": "pure_text", "text": parsed_text})
+ try:
+ parsed_text = parser.parse(item)
+ results.append({"file": item, "text": parsed_text})
+ except Exception as e:
+ logger.error(f"[SceneParser] Error parsing {item}: {e}")
+ continue
else:
parsed_text = item
- results.append({"file": item, "text": parsed_text})
+ results.append({"file": "pure_text", "text": parsed_text})
except Exception as e:
print(f"Error parsing file {item}: {e!s}")
return results
- def _process_doc_data(self, scene_data_info, info):
+ def _process_doc_data(self, scene_data_info, info, **kwargs):
chunks = self.chunker.chunk(scene_data_info["text"])
messages = []
for chunk in chunks:
@@ -277,7 +329,7 @@ def _process_doc_data(self, scene_data_info, info):
doc_nodes = []
scene_file = scene_data_info["file"]
- with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
+ with ContextThreadPoolExecutor(max_workers=50) as executor:
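+ # ContextThreadPoolExecutor propagates the current trace context into worker threads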
futures = {
executor.submit(
_build_node,
@@ -302,6 +354,7 @@ def _process_doc_data(self, scene_data_info, info):
doc_nodes.append(node)
except Exception as e:
tqdm.write(f"[ERROR] {e}")
+ logger.error(f"[DocReader] Future task failed: {e}")
return doc_nodes
def parse_json_result(self, response_text):
@@ -309,14 +362,14 @@ def parse_json_result(self, response_text):
json_start = response_text.find("{")
response_text = response_text[json_start:]
response_text = response_text.replace("```", "").strip()
- if response_text[-1] != "}":
+ if not response_text.endswith("}"):
response_text += "}"
- response_json = json.loads(response_text)
- return response_json
+ return json.loads(response_text)
except json.JSONDecodeError as e:
- logger.warning(
- f"Failed to parse LLM response as JSON: {e}\nRaw response:\n{response_text}"
- )
+ logger.error(f"[JSONParse] Failed to decode JSON: {e}\nRaw:\n{response_text}")
+ return {}
+ except Exception as e:
+ logger.error(f"[JSONParse] Unexpected error: {e}")
return {}
def transform_memreader(self, data: dict) -> list[TextualMemoryItem]:
diff --git a/src/memos/mem_scheduler/analyzer/__init__.py b/src/memos/mem_scheduler/analyzer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py
new file mode 100644
index 00000000..7cd085ad
--- /dev/null
+++ b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py
@@ -0,0 +1,569 @@
+from datetime import datetime
+
+from memos.configs.mem_os import MOSConfig
+from memos.log import get_logger
+from memos.mem_os.main import MOS
+from memos.mem_scheduler.schemas.general_schemas import (
+ ANSWER_LABEL,
+ MONITOR_WORKING_MEMORY_TYPE,
+ QUERY_LABEL,
+)
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+
+
+logger = get_logger(__name__)
+
+
+class MOSForTestScheduler(MOS):
+ """This class is only to test abilities of mem scheduler with enhanced monitoring"""
+
+ def __init__(self, config: MOSConfig):
+ super().__init__(config)
+ self.memory_helpfulness_analysis = []
+
+ def _str_memories(self, memories: list[str]) -> str:
+ """Format memories for display."""
+ if not memories:
+ return "No memories."
+ return "\n".join(f"{i + 1}. {memory}" for i, memory in enumerate(memories))
+
+ def _analyze_memory_helpfulness(
+ self,
+ query: str,
+ working_memories_before: list,
+ working_memories_after: list,
+ scheduler_memories: list,
+ ):
+ """Analyze how helpful each memory is for answering the current query."""
+ print("\n" + "=" * 80)
+ print("🧠 MEMORY HELPFULNESS ANALYSIS FOR QUERY")
+ print("=" * 80)
+
+ print(f"📝 Query: {query}")
+ print(f"📊 Working Memories Before Scheduler: {len(working_memories_before)}")
+ print(f"📊 Working Memories After Scheduler: {len(working_memories_after)}")
+ print(f"📊 Working Memories from Monitor: {len(scheduler_memories)}")
+
+ # Display working memories before scheduler (first 5 only)
+ if working_memories_before:
+ print("\n🔄 WORKING MEMORIES BEFORE SCHEDULER (first 5):")
+ for i, mem in enumerate(working_memories_before[:5]):
+ print(f" {i + 1}. {mem}")
+
+ # Display working memories after scheduler (first 5 only)
+ if working_memories_after:
+ print("\n🔄 WORKING MEMORIES AFTER SCHEDULER (first 5):")
+ for i, mem in enumerate(working_memories_after[:5]):
+ print(f" {i + 1}. {mem}")
+
+ # Display scheduler memories from monitor (first 5 only)
+ if scheduler_memories:
+ print("\n🔄 WORKING MEMORIES FROM MONITOR (first 5):")
+ for i, mem in enumerate(scheduler_memories[:5]):
+ print(f" {i + 1}. {mem}")
+
+ # Batch assess working memory helpfulness before scheduler
+ if working_memories_before:
+ print(
+ f"\n🔄 WORKING MEMORY HELPFULNESS BEFORE SCHEDULER ({len(working_memories_before)}):"
+ )
+ before_assessment = self._batch_assess_memories(
+ query, working_memories_before[:5], "before scheduler"
+ )
+ for i, (_mem, score, reason) in enumerate(before_assessment):
+ print(f" {i + 1}. Helpfulness: {score}/10 - {reason}")
+
+ # Batch assess working memory helpfulness after scheduler
+ if working_memories_after:
+ print(
+ f"\n🔄 WORKING MEMORY HELPFULNESS AFTER SCHEDULER ({len(working_memories_after)}):"
+ )
+ after_assessment = self._batch_assess_memories(
+ query, working_memories_after[:5], "after scheduler"
+ )
+ for i, (_mem, score, reason) in enumerate(after_assessment):
+ print(f" {i + 1}. Helpfulness: {score}/10 - {reason}")
+
+ # Batch assess scheduler memories from monitor
+ if scheduler_memories:
+ print(f"\n🔄 WORKINGMEMORIES FROM MONITOR HELPFULNESS ({len(scheduler_memories)}):")
+ scheduler_assessment = self._batch_assess_memories(
+ query, scheduler_memories[:5], "from monitor"
+ )
+ for i, (_mem, score, reason) in enumerate(scheduler_assessment):
+ print(f" {i + 1}. Helpfulness: {score}/10 - {reason}")
+
+ # Overall assessment - compare before vs after vs scheduler
+ print("\n💡 OVERALL ASSESSMENT:")
+ if working_memories_before and working_memories_after:
+ before_scores = (
+ [score for _, score, _ in before_assessment]
+ if "before_assessment" in locals()
+ else []
+ )
+ after_scores = (
+ [score for _, score, _ in after_assessment]
+ if "after_assessment" in locals()
+ else []
+ )
+ scheduler_scores = (
+ [score for _, score, _ in scheduler_assessment]
+ if "scheduler_assessment" in locals()
+ else []
+ )
+
+ avg_before_helpfulness = sum(before_scores) / len(before_scores)
+ avg_after_helpfulness = sum(after_scores) / len(after_scores)
+
+ print(f" Average Helpfulness Before Scheduler: {avg_before_helpfulness:.1f}/10")
+ print(f" Average Helpfulness After Scheduler: {avg_after_helpfulness:.1f}/10")
+ print(f" Improvement: {avg_after_helpfulness - avg_before_helpfulness:+.1f}")
+
+ if avg_after_helpfulness > avg_before_helpfulness:
+ print(" ✅ Scheduler improved working memory quality")
+ elif avg_after_helpfulness < avg_before_helpfulness:
+ print(" ❌ Scheduler decreased working memory quality")
+ else:
+ print(" ⚖️ Scheduler maintained working memory quality")
+
+ # Compare scheduler memories vs working memories
+
+ avg_scheduler_helpfulness = (
+ sum(scheduler_scores) / len(scheduler_scores) if scheduler_scores else 0
+ )
+ print(
+ f" Average Helpfulness of Memories from Monitors: {avg_scheduler_helpfulness:.1f}/10"
+ )
+
+ if avg_scheduler_helpfulness > avg_after_helpfulness:
+ print(" 🎯 Memories from Monitors are more helpful than working memories")
+ elif avg_scheduler_helpfulness < avg_after_helpfulness:
+ print(" ⚠️ Working memories are more helpful than Memories from Monitors")
+ else:
+ print(
+ " ⚖️ WORKING Memories from Monitors and working memories have similar helpfulness"
+ )
+
+ # Record analysis results
+ self.memory_helpfulness_analysis.append(
+ {
+ "query": query,
+ "working_memories_before_count": len(working_memories_before),
+ "working_memories_after_count": len(working_memories_after),
+ "scheduler_memories_count": len(scheduler_memories),
+ "working_helpfulness_before": [score for _, score, _ in before_assessment]
+ if "before_assessment" in locals()
+ else [],
+ "working_helpfulness_after": [score for _, score, _ in after_assessment]
+ if "after_assessment" in locals()
+ else [],
+ "scheduler_helpfulness": [score for _, score, _ in scheduler_assessment]
+ if "scheduler_assessment" in locals()
+ else [],
+ }
+ )
+
+ print("=" * 80 + "\n")
+
+ def _batch_assess_memories(self, query: str, memories: list, context: str) -> list:
+ """Use LLM to assess multiple memories at once and compare their quality."""
+ try:
+ # Create prompt for batch assessment
+ memories_text = "\n".join([f"{i + 1}. {mem}" for i, mem in enumerate(memories)])
+
+ assessment_prompt = f"""
+ Task: Assess and compare the helpfulness of multiple memories for answering a query.
+
+ Query: "{query}"
+
+ Context: These are working memories {context}.
+
+ Memories to assess:
+ {memories_text}
+
+ Please provide:
+ 1. A helpfulness score from 1-10 for each memory (where 10 = extremely helpful, 1 = not helpful at all)
+ 2. A brief reason for each score
+ 3. Rank the memories from most helpful to least helpful
+
+ Format your response as:
+ Memory 1: Score [number] - [reason]
+ Memory 2: Score [number] - [reason]
+ Memory 3: Score [number] - [reason]
+ Memory 4: Score [number] - [reason]
+ Memory 5: Score [number] - [reason]
+
+ Ranking: [memory numbers in order from most to least helpful]
+
+ Consider:
+ - Direct relevance to the query
+ - Information completeness
+ - How directly it answers the question
+ - Whether it provides useful context or background
+ - Compare memories against each other for relative quality
+ """
+
+ # Use the chat LLM to get batch assessment
+ messages = [{"role": "user", "content": assessment_prompt}]
+ response = self.chat_llm.generate(messages)
+
+ # Parse the response to extract scores and reasons
+ assessment_results = []
+ lines = response.strip().split("\n")
+
+ for i, mem in enumerate(memories):
+ score = 5 # Default score
+ reason = "LLM assessment failed, using default score"
+
+ # Look for the corresponding memory line
+ for line in lines:
+ if line.startswith(f"Memory {i + 1}:"):
+ try:
+ # Extract score and reason from line like "Memory 1: Score 8 - Highly relevant"
+ parts = line.split("Score ")[1].split(" - ", 1)
+ score = int(parts[0])
+ score = max(1, min(10, score)) # Ensure score is 1-10
+ reason = parts[1] if len(parts) > 1 else "No reason provided"
+ except Exception:
+ pass
+ break
+
+ assessment_results.append((mem, score, reason))
+
+ return assessment_results
+
+ except Exception as e:
+ logger.warning(f"LLM batch assessment failed: {e}, using fallback scoring")
+ # Fallback to individual assessment if batch fails (one LLM call per memory)
+ fallback = [self._assess_memory_helpfulness(query, mem) for mem in memories]
+ return [(mem, res["score"], res["reason"]) for mem, res in zip(memories, fallback)]
+
+ def _assess_memory_helpfulness(self, query: str, memory: str) -> dict:
+ """Use LLM to assess how helpful a memory is for answering the current query (1-10 scale)"""
+ try:
+ # Create prompt for LLM assessment
+ assessment_prompt = f"""
+ Task: Rate how helpful this memory is for answering the given query on a scale of 1-10.
+
+ Query: "{query}"
+
+ Memory: "{memory}"
+
+ Please provide:
+ 1. A score from 1-10 (where 10 = extremely helpful, 1 = not helpful at all)
+ 2. A brief reason for your score
+
+ Format your response as:
+ Score: [number]
+ Reason: [your explanation]
+
+ Consider:
+ - Direct relevance to the query
+ - Information completeness
+ - How directly it answers the question
+ - Whether it provides useful context or background
+ """
+
+ # Use the chat LLM to get assessment
+ messages = [{"role": "user", "content": assessment_prompt}]
+ response = self.chat_llm.generate(messages)
+
+ # Parse the response to extract score and reason
+ lines = response.strip().split("\n")
+ score = 5 # Default score
+ reason = "LLM assessment failed, using default score"
+
+ for line in lines:
+ if line.startswith("Score:"):
+ try:
+ score_text = line.split(":")[1].strip()
+ score = int(score_text)
+ score = max(1, min(10, score)) # Ensure score is 1-10
+ except Exception:
+ pass
+ elif line.startswith("Reason:"):
+ reason = line.split(":", 1)[1].strip()
+
+ return {"score": score, "reason": reason}
+
+ except Exception as e:
+ logger.warning(f"LLM assessment failed: {e}, using fallback scoring")
+ # Fallback to simple keyword matching if LLM fails
+ return self._fallback_memory_assessment(query, memory)
+
+ def _fallback_memory_assessment(self, query: str, memory: str) -> dict:
+ """Fallback assessment method using keyword matching if LLM fails"""
+ query_lower = query.lower()
+ memory_lower = memory.lower()
+
+ # Keyword matching
+ query_words = set(query_lower.split())
+ memory_words = set(memory_lower.split())
+ common_words = query_words.intersection(memory_words)
+
+ # Semantic relevance scoring
+ score = 0
+
+ # Exact keyword matches (highest weight)
+ if len(common_words) > 0:
+ score += min(len(common_words) * 2, 6)
+
+ # Partial matches (medium weight)
+ partial_matches = sum(
+ 1 for qw in query_words for mw in memory_words if qw in mw or mw in qw
+ )
+ if partial_matches > 0:
+ score += min(partial_matches, 3)
+
+ # Topic relevance (through common topic words)
+ topic_words = [
+ "problem",
+ "solution",
+ "answer",
+ "method",
+ "reason",
+ "result",
+ "analysis",
+ "compare",
+ "explain",
+ ]
+ topic_matches = sum(1 for topic in topic_words if topic in memory_lower)
+ score += topic_matches
+
+ # Ensure score is 1-10
+ score = max(1, min(10, score))
+
+ # Determine helpfulness level
+ if score >= 8:
+ reason = "Highly relevant, directly answers the query"
+ elif score >= 6:
+ reason = "Relevant, provides useful information"
+ elif score >= 4:
+ reason = "Partially relevant, somewhat helpful"
+ elif score >= 2:
+ reason = "Low relevance, limited help"
+ else:
+ reason = "Very low relevance, minimal help"
+
+ return {"score": score, "reason": reason}
+
+ def _assess_ranking_quality(self, rank: int, helpfulness: int) -> str:
+ """Use LLM to assess whether the memory ranking is reasonable"""
+ try:
+ # Create prompt for LLM ranking assessment
+ ranking_prompt = f"""
+ Task: Assess whether this memory ranking is reasonable.
+
+ Context: A memory with helpfulness score {helpfulness}/10 is ranked at position {rank}.
+
+ Please evaluate if this ranking makes sense and provide a brief assessment.
+
+ Consider:
+ - Higher helpfulness scores should generally rank higher
+ - Rank 1 should typically have the highest helpfulness
+ - The relationship between rank and helpfulness
+
+ Provide a brief assessment in one sentence.
+ """
+
+ # Use the chat LLM to get assessment
+ messages = [{"role": "user", "content": ranking_prompt}]
+ response = self.chat_llm.generate(messages)
+
+ return response.strip()
+
+ except Exception as e:
+ logger.warning(f"LLM ranking assessment failed: {e}, using fallback assessment")
+ # Fallback assessment
+ if rank == 1 and helpfulness >= 8:
+ return "✅ Ranking is reasonable - most helpful memory ranked first"
+ elif rank == 1 and helpfulness <= 4:
+ return "❌ Ranking is unreasonable - first ranked memory has low helpfulness"
+ elif rank <= 3 and helpfulness >= 6:
+ return "✅ Ranking is reasonable - high helpfulness memory ranked high"
+ elif rank <= 3 and helpfulness <= 3:
+ return "⚠️ Ranking may be unreasonable - low helpfulness memory ranked high"
+ elif rank > 3 and helpfulness >= 7:
+ return "⚠️ Ranking may be unreasonable - high helpfulness memory ranked low"
+ else:
+ return "🟡 Ranking is acceptable - helpfulness and rank generally match"
+
+ def chat(self, query: str, user_id: str | None = None) -> str:
+ """
+ Chat with the MOS with memory helpfulness analysis.
+
+ Args:
+ query (str): The user's query.
+ user_id (str | None): The user ID.
+
+ Returns:
+ str: The response from the MOS.
+ """
+ target_user_id = user_id if user_id is not None else self.user_id
+ accessible_cubes = self.user_manager.get_user_cubes(target_user_id)
+ user_cube_ids = [cube.cube_id for cube in accessible_cubes]
+
+ if target_user_id not in self.chat_history_manager:
+ self._register_chat_history(target_user_id)
+
+ chat_history = self.chat_history_manager[target_user_id]
+ topk_for_scheduler = 2
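+ # reserve two of the top_k result slots for memories suggested by the scheduler monitor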
+
+ if self.config.enable_textual_memory and self.mem_cubes:
+ memories_all = []
+ for mem_cube_id, mem_cube in self.mem_cubes.items():
+ if mem_cube_id not in user_cube_ids:
+ continue
+ if not mem_cube.text_mem:
+ continue
+
+ # Get working memories BEFORE scheduler
+ working_memories_before = [m.memory for m in mem_cube.text_mem.get_working_memory()]
+
+ message_item = ScheduleMessageItem(
+ user_id=target_user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ label=QUERY_LABEL,
+ content=query,
+ timestamp=datetime.now(),
+ )
+
+ print(f"\n🚀 Starting Scheduler for {mem_cube_id}...")
+
+ # Force scheduler to run immediately
+ self.mem_scheduler.monitor.query_trigger_interval = 0
+ self.mem_scheduler._query_message_consumer(messages=[message_item])
+
+ # Get scheduler memories
+ scheduler_memories = self.mem_scheduler.monitor.get_monitor_memories(
+ user_id=target_user_id,
+ mem_cube_id=mem_cube_id,
+ memory_type=MONITOR_WORKING_MEMORY_TYPE,
+ top_k=20,
+ )
+
+ # Get working memories AFTER scheduler
+ working_memories_after = [m.memory for m in mem_cube.text_mem.get_working_memory()]
+
+ # Get mem_cube memories for response generation
+ memories = mem_cube.text_mem.search(
+ query,
+ top_k=self.config.top_k - topk_for_scheduler,
+ info={
+ "user_id": target_user_id,
+ "session_id": self.session_id,
+ "chat_history": chat_history.chat_history,
+ },
+ )
+ text_memories = [m.memory for m in memories]
+
+ # Analyze memory helpfulness - compare before vs after vs scheduler
+ self._analyze_memory_helpfulness(
+ query, working_memories_before, working_memories_after, scheduler_memories
+ )
+
+ # Combine all memories for response generation
+ memories_all.extend(scheduler_memories[:topk_for_scheduler])
+ memories_all.extend(text_memories)
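+ # Deduplicate; note that set() does not preserve the original memory order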
+ memories_all = list(set(memories_all))
+
+ logger.info(f"🧠 [Memory] Searched memories:\n{self._str_memories(memories_all)}\n")
+ system_prompt = self._build_system_prompt(memories_all)
+ else:
+ system_prompt = self._build_system_prompt()
+
+ current_messages = [
+ {"role": "system", "content": system_prompt},
+ *chat_history.chat_history,
+ {"role": "user", "content": query},
+ ]
+ past_key_values = None
+
+ if self.config.enable_activation_memory:
+ assert self.config.chat_model.backend == "huggingface", (
+ "Activation memory only used for huggingface backend."
+ )
+ # TODO this only one cubes
+ for mem_cube_id, mem_cube in self.mem_cubes.items():
+ if mem_cube_id not in user_cube_ids:
+ continue
+ if mem_cube.act_mem:
+ kv_cache = next(iter(mem_cube.act_mem.get_all()), None)
+ past_key_values = (
+ kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None
+ )
+ break
+ # Generate response
+ response = self.chat_llm.generate(current_messages, past_key_values=past_key_values)
+ else:
+ response = self.chat_llm.generate(current_messages)
+
+ logger.info(f"🤖 [Assistant] {response}\n")
+ chat_history.chat_history.append({"role": "user", "content": query})
+ chat_history.chat_history.append({"role": "assistant", "content": response})
+ self.chat_history_manager[target_user_id] = chat_history
+
+ # Submit message to scheduler for answer processing
+ for accessible_mem_cube in accessible_cubes:
+ mem_cube_id = accessible_mem_cube.cube_id
+ mem_cube = self.mem_cubes[mem_cube_id]
+ if self.enable_mem_scheduler and self.mem_scheduler is not None:
+ message_item = ScheduleMessageItem(
+ user_id=target_user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ label=ANSWER_LABEL,
+ content=response,
+ timestamp=datetime.now(),
+ )
+ self.mem_scheduler.submit_messages(messages=[message_item])
+
+ return response
+
+ def get_memory_helpfulness_summary(self) -> dict:
+ """Get summary of memory helpfulness analysis."""
+ if not self.memory_helpfulness_analysis:
+ return {"message": "No memory helpfulness analysis data available"}
+
+ total_queries = len(self.memory_helpfulness_analysis)
+
+ # Calculate average helpfulness for working memories before scheduler
+ before_scores = []
+ for analysis in self.memory_helpfulness_analysis:
+ before_scores.extend(analysis["working_helpfulness_before"])
+
+ # Calculate average helpfulness for working memories after scheduler
+ after_scores = []
+ for analysis in self.memory_helpfulness_analysis:
+ after_scores.extend(analysis["working_helpfulness_after"])
+
+ # Calculate average helpfulness for scheduler memories from monitor
+ scheduler_scores = []
+ for analysis in self.memory_helpfulness_analysis:
+ scheduler_scores.extend(analysis["scheduler_helpfulness"])
+
+ avg_before_helpfulness = sum(before_scores) / len(before_scores) if before_scores else 0
+ avg_after_helpfulness = sum(after_scores) / len(after_scores) if after_scores else 0
+ avg_scheduler_helpfulness = (
+ sum(scheduler_scores) / len(scheduler_scores) if scheduler_scores else 0
+ )
+
+ return {
+ "total_queries": total_queries,
+ "working_memories_before_analyzed": len(before_scores),
+ "working_memories_after_analyzed": len(after_scores),
+ "scheduler_memories_analyzed": len(scheduler_scores),
+ "average_helpfulness_before_scheduler": f"{avg_before_helpfulness:.1f}/10",
+ "average_helpfulness_after_scheduler": f"{avg_after_helpfulness:.1f}/10",
+ "average_helpfulness_scheduler_memories": f"{avg_scheduler_helpfulness:.1f}/10",
+ "overall_improvement": f"{avg_after_helpfulness - avg_before_helpfulness:+.1f}",
+ "improvement_percentage": f"{((avg_after_helpfulness - avg_before_helpfulness) / avg_before_helpfulness * 100):+.1f}%"
+ if avg_before_helpfulness > 0
+ else "N/A",
+ "scheduler_vs_working_comparison": f"{avg_scheduler_helpfulness - avg_after_helpfulness:+.1f}",
+ }
diff --git a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py
new file mode 100644
index 00000000..7c0fa5a4
--- /dev/null
+++ b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py
@@ -0,0 +1,280 @@
+from __future__ import annotations
+
+import time
+
+from functools import wraps
+from typing import TYPE_CHECKING, Any, ClassVar
+
+from memos.log import get_logger
+from memos.mem_scheduler.general_scheduler import GeneralScheduler
+from memos.mem_scheduler.schemas.general_schemas import (
+ DEFAULT_MAX_QUERY_KEY_WORDS,
+ UserID,
+)
+from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem
+
+
+if TYPE_CHECKING:
+ from memos.memories.textual.tree import TextualMemoryItem
+
+
+logger = get_logger(__name__)
+
+
+class SchedulerForEval(GeneralScheduler):
+ """
+ A scheduler class that inherits from GeneralScheduler and provides evaluation-specific functionality.
+ This class extends GeneralScheduler with evaluation methods.
+ """
+
+ # Class variable to store timing information for all instances
+ timer_cache: ClassVar[dict[str, dict[str, Any]]] = {}
+
+ def __init__(self, config):
+ """
+ Initialize the SchedulerForEval with the same configuration as GeneralScheduler.
+
+ Args:
+ config: Configuration object for the scheduler
+ """
+ super().__init__(config)
+ # Initialize instance timer_cache
+ self.timer_cache = {}
+
+ @staticmethod
+ def time_it(func_name: str | None = None):
+ """
+ Static method decorator to measure function execution time and store in timer_cache.
+
+ Args:
+ func_name: Custom name for the function in timer_cache. If None, uses function.__name__
+ """
+
+ def decorator(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ # Get function name
+ name = func_name or func.__name__
+
+ # Start timing
+ start_time = time.time()
+ result = func(self, *args, **kwargs)
+ end_time = time.time()
+
+ # Calculate execution time
+ exec_time = end_time - start_time
+
+ # Format time as HH:MM:SS.mmm
+ hours = int(exec_time // 3600)
+ minutes = int((exec_time % 3600) // 60)
+ seconds = exec_time % 60
+
+ if hours > 0:
+ time_str = f"{hours:02d}:{minutes:02d}:{seconds:06.3f}"
+ else:
+ time_str = f"{minutes:02d}:{seconds:06.3f}"
+
+ # Store in timer_cache
+ if not hasattr(self, "timer_cache"):
+ self.timer_cache = {}
+
+ self.timer_cache[name] = {
+ "time_str": time_str,
+ "seconds": exec_time,
+ }
+
+ logger.info(f"{name} executed in {time_str}")
+ return result
+
+ return wrapper
+
+ return decorator
+
+ def get_timer_summary(self) -> str:
+ """
+ Get a summary of all timed functions.
+
+ Returns:
+ Formatted string with timing information
+ """
+ if not self.timer_cache:
+ return "No timing data available."
+
+ summary = "=== Timing Summary ===\n"
+ for func_name, data in self.timer_cache.items():
+ summary += f"{func_name}: {data['time_str']} (at {data['timestamp']})\n"
+
+ return summary
+
+ def clear_timer_cache(self):
+ """Clear the timer cache."""
+ self.timer_cache.clear()
+
+ @time_it("update_working_memory")
+ def update_working_memory_for_eval(
+ self, query: str, user_id: UserID | str, top_k: int
+ ) -> list[str]:
+ """
+ Update working memory based on query and return the updated memory list.
+
+ Args:
+ query: The query string
+ user_id: User identifier
+ top_k: Number of top memories to return
+
+ Returns:
+ List of memory strings from updated working memory
+ """
+ self.monitor.register_query_monitor_if_not_exists(
+ user_id=user_id, mem_cube_id=self.current_mem_cube_id
+ )
+
+ query_keywords = self.monitor.extract_query_keywords(query=query)
+ logger.info(f'Extract keywords "{query_keywords}" from query "{query}"')
+
+ item = QueryMonitorItem(
+ user_id=user_id,
+ mem_cube_id=self.current_mem_cube_id,
+ query_text=query,
+ keywords=query_keywords,
+ max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS,
+ )
+ query_db_manager = self.monitor.query_monitors[user_id][self.current_mem_cube_id]
+ query_db_manager.obj.put(item=item)
+ # Sync with database after adding new item
+ query_db_manager.sync_with_orm()
+ logger.debug(f"Queries in monitor are {query_db_manager.obj.get_queries_with_timesort()}.")
+
+ queries = [query]
+
+ # recall
+ mem_cube = self.current_mem_cube
+ text_mem_base = mem_cube.text_mem
+
+ cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
+ text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
+ intent_result = self.monitor.detect_intent(
+ q_list=queries, text_working_memory=text_working_memory
+ )
+
+ if intent_result["trigger_retrieval"]:
+ missing_evidences = intent_result["missing_evidences"]
+ num_evidence = len(missing_evidences)
+ k_per_evidence = max(1, top_k // max(1, num_evidence))
+ new_candidates = []
+ for item in missing_evidences:
+ logger.info(f"missing_evidences: {item}")
+ results: list[TextualMemoryItem] = self.retriever.search(
+ query=item,
+ mem_cube=mem_cube,
+ top_k=k_per_evidence,
+ method=self.search_method,
+ )
+ logger.info(
+ f"search results for {missing_evidences}: {[one.memory for one in results]}"
+ )
+ new_candidates.extend(results)
+ logger.info(
+ f"missing_evidences: {missing_evidences} and get {len(new_candidates)} new candidate memories."
+ )
+ else:
+ new_candidates = []
+ logger.info(f"intent_result: {intent_result}. not triggered")
+
+ # rerank
+ new_order_working_memory = self.replace_working_memory(
+ user_id=user_id,
+ mem_cube_id=self.current_mem_cube_id,
+ mem_cube=self.current_mem_cube,
+ original_memory=cur_working_memory,
+ new_memory=new_candidates,
+ )
+ new_order_working_memory = new_order_working_memory[:top_k]
+ logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}")
+
+ return [m.memory for m in new_order_working_memory]
+
+ @time_it("memory_answer_ability")
+ def evaluate_memory_answer_ability(
+ self, query: str, memory_texts: list[str], top_k: int = 100
+ ) -> bool:
+ """
+ Use LLM to evaluate whether the given memories can answer the query.
+
+ Args:
+ query: The query string to evaluate
+ memory_texts: List of memory texts to check against
+ top_k: Maximum number of memories to consider for evaluation
+
+ Returns:
+ Boolean indicating whether the memories can answer the query
+ """
+ # Limit the number of memories to evaluate
+ limited_memories = memory_texts[:top_k] if memory_texts else []
+
+ # Build prompt using the template
+ prompt = self.monitor.build_prompt(
+ template_name="memory_answer_ability_evaluation",
+ query=query,
+ memory_list="\n".join([f"- {memory}" for memory in limited_memories])
+ if limited_memories
+ else "No memories available",
+ )
+
+ # Use the process LLM to generate response
+ response = self.monitor._process_llm.generate([{"role": "user", "content": prompt}])
+
+ try:
+ # Extract JSON response
+ from memos.mem_scheduler.utils.misc_utils import extract_json_dict
+
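+ # Assumed response shape produced by the "memory_answer_ability_evaluation" template,
+ # e.g. {"result": true, "reason": "the memories state the answer directly"}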
+ result = extract_json_dict(response)
+
+ # Validate response structure
+ if "result" in result:
+ logger.info(
+ f"Memory answer ability evaluation result: {result['result']}, reason: {result.get('reason', 'No reason provided')}"
+ )
+ return result["result"]
+ else:
+ logger.warning(f"Invalid response structure from LLM: {result}")
+ return False
+
+ except Exception as e:
+ logger.error(
+ f"Failed to parse LLM response for memory answer ability evaluation: {response}. Error: {e}"
+ )
+ # Fallback: return False if we can't determine answer ability
+ return False
+
+ @time_it("search_for_eval")
+ def search_for_eval(
+ self, query: str, user_id: UserID | str, top_k: int, scheduler_flag: bool = True
+ ) -> list[str]:
+ """
+ Evaluation search entry point, refactored to use the decomposed helper methods above.
+
+ Args:
+ query: The query string
+ user_id: User identifier
+ top_k: Number of top memories to return
+ scheduler_flag: Whether to update working memory or just evaluate
+
+ Returns:
+ List of memory strings: the updated working memory when scheduler_flag is True, otherwise the current working memory
+ """
+ if not scheduler_flag:
+ # Get current working memory without updating
+ mem_cube = self.current_mem_cube
+ text_mem_base = mem_cube.text_mem
+ cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
+ text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
+
+ return text_working_memory
+ else:
+ # Update working memory and get the result
+ updated_memories = self.update_working_memory_for_eval(
+ query=query, user_id=user_id, top_k=top_k
+ )
+
+ return updated_memories
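+
+# Illustrative usage sketch (assumed wiring; the real config, LLM, and mem_cube setup
+# lives in the evaluation scripts rather than in this module):
+#   scheduler = SchedulerForEval(config=scheduler_config)
+#   scheduler.initialize_modules(chat_llm=llm, process_llm=llm)
+#   memories = scheduler.search_for_eval(query="Where does Alice work?", user_id="u0", top_k=10)
+#   can_answer = scheduler.evaluate_memory_answer_ability(query="Where does Alice work?", memory_texts=memories)
+#   print(scheduler.get_timer_summary())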
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 44bc7da3..b6ef00d8 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -5,16 +5,16 @@
from datetime import datetime
from pathlib import Path
+from sqlalchemy.engine import Engine
+
from memos.configs.mem_scheduler import AuthConfig, BaseSchedulerConfig
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
-from memos.mem_scheduler.general_modules.rabbitmq_service import RabbitMQSchedulerModule
-from memos.mem_scheduler.general_modules.redis_service import RedisSchedulerModule
-from memos.mem_scheduler.general_modules.retriever import SchedulerRetriever
from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule
+from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever
from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor
from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
from memos.mem_scheduler.schemas.general_schemas import (
@@ -33,6 +33,8 @@
from memos.mem_scheduler.utils.filter_utils import (
transform_name_to_key,
)
+from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule
+from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
from memos.memories.activation.kv import KVCacheMemory
from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
@@ -62,6 +64,7 @@ def __init__(self, config: BaseSchedulerConfig):
)
self.retriever: SchedulerRetriever | None = None
+ self.db_engine: Engine | None = None
self.monitor: SchedulerGeneralMonitor | None = None
self.dispatcher_monitor: SchedulerDispatcherMonitor | None = None
self.dispatcher = SchedulerDispatcher(
@@ -70,12 +73,15 @@ def __init__(self, config: BaseSchedulerConfig):
)
# internal message queue
- self.max_internal_messae_queue_size = 100
+ self.max_internal_message_queue_size = self.config.get(
+ "max_internal_message_queue_size", 100
+ )
self.memos_message_queue: Queue[ScheduleMessageItem] = Queue(
- maxsize=self.max_internal_messae_queue_size
+ maxsize=self.max_internal_message_queue_size
)
+ self.max_web_log_queue_size = self.config.get("max_web_log_queue_size", 50)
self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue(
- maxsize=self.max_internal_messae_queue_size
+ maxsize=self.max_web_log_queue_size
)
self._consumer_thread = None # Reference to our consumer thread
self._running = False
@@ -92,34 +98,57 @@ def __init__(self, config: BaseSchedulerConfig):
self.auth_config = None
self.rabbitmq_config = None
- def initialize_modules(self, chat_llm: BaseLLM, process_llm: BaseLLM | None = None):
+ def initialize_modules(
+ self,
+ chat_llm: BaseLLM,
+ process_llm: BaseLLM | None = None,
+ db_engine: Engine | None = None,
+ ):
if process_llm is None:
process_llm = chat_llm
- # initialize submodules
- self.chat_llm = chat_llm
- self.process_llm = process_llm
- self.monitor = SchedulerGeneralMonitor(process_llm=self.process_llm, config=self.config)
- self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config)
- self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config)
+ try:
+ # initialize submodules
+ self.chat_llm = chat_llm
+ self.process_llm = process_llm
+ self.db_engine = db_engine
+ self.monitor = SchedulerGeneralMonitor(
+ process_llm=self.process_llm, config=self.config, db_engine=self.db_engine
+ )
+ self.db_engine = self.monitor.db_engine
+ self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config)
+ self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config)
+
+ if self.enable_parallel_dispatch:
+ self.dispatcher_monitor.initialize(dispatcher=self.dispatcher)
+ self.dispatcher_monitor.start()
+
+ # initialize with auth_config
+ if self.auth_config_path is not None and Path(self.auth_config_path).exists():
+ self.auth_config = AuthConfig.from_local_config(config_path=self.auth_config_path)
+ elif AuthConfig.default_config_exists():
+ self.auth_config = AuthConfig.from_local_config()
+ else:
+ self.auth_config = AuthConfig.from_local_env()
- if self.enable_parallel_dispatch:
- self.dispatcher_monitor.initialize(dispatcher=self.dispatcher)
- self.dispatcher_monitor.start()
-
- # initialize with auth_cofig
- if self.auth_config_path is not None and Path(self.auth_config_path).exists():
- self.auth_config = AuthConfig.from_local_config(config_path=self.auth_config_path)
- elif AuthConfig.default_config_exists():
- self.auth_config = AuthConfig.from_local_config()
- else:
- self.auth_config = None
+ if self.auth_config is not None:
+ self.rabbitmq_config = self.auth_config.rabbitmq
+ self.initialize_rabbitmq(config=self.rabbitmq_config)
- if self.auth_config is not None:
- self.rabbitmq_config = self.auth_config.rabbitmq
- self.initialize_rabbitmq(config=self.rabbitmq_config)
+ logger.debug("GeneralScheduler has been initialized")
+ except Exception as e:
+ logger.error(f"Failed to initialize scheduler modules: {e}", exc_info=True)
+ # Clean up any partially initialized resources
+ self._cleanup_on_init_failure()
+ raise
- logger.debug("GeneralScheduler has been initialized")
+ def _cleanup_on_init_failure(self):
+ """Clean up resources if initialization fails."""
+ try:
+ if hasattr(self, "dispatcher_monitor") and self.dispatcher_monitor is not None:
+ self.dispatcher_monitor.stop()
+ except Exception as e:
+ logger.warning(f"Error during cleanup: {e}")
@property
def mem_cube(self) -> GeneralMemCube:
@@ -200,8 +229,11 @@ def replace_working_memory(
text_mem_base: TreeTextMemory = text_mem_base
# process rerank memories with llm
- query_monitor = self.monitor.query_monitors[user_id][mem_cube_id]
- query_history = query_monitor.get_queries_with_timesort()
+ query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
+ # Sync with database to get latest query history
+ query_db_manager.sync_with_orm()
+
+ query_history = query_db_manager.obj.get_queries_with_timesort()
memories_with_new_order, rerank_success_flag = (
self.retriever.process_and_rerank_memories(
queries=query_history,
@@ -211,8 +243,27 @@ def replace_working_memory(
)
)
- # update working memory monitors
- query_keywords = query_monitor.get_keywords_collections()
+ # Filter completely unrelated memories according to query_history
+ logger.info(f"Filtering memories based on query history: {len(query_history)} queries")
+ filtered_memories, filter_success_flag = self.retriever.filter_unrelated_memories(
+ query_history=query_history,
+ memories=memories_with_new_order,
+ )
+
+ if filter_success_flag:
+ logger.info(
+ f"Memory filtering completed successfully. "
+ f"Filtered from {len(memories_with_new_order)} to {len(filtered_memories)} memories"
+ )
+ memories_with_new_order = filtered_memories
+ else:
+ logger.warning(
+ "Memory filtering failed - keeping all memories as fallback. "
+ f"Original count: {len(memories_with_new_order)}"
+ )
+
+ # Update working memory monitors
+ query_keywords = query_db_manager.obj.get_keywords_collections()
logger.info(
f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords"
)
@@ -235,7 +286,7 @@ def replace_working_memory(
mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][
mem_cube_id
- ].get_sorted_mem_monitors(reverse=True)
+ ].obj.get_sorted_mem_monitors(reverse=True)
new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors]
text_mem_base.replace_working_memory(memories=new_working_memories)
@@ -278,6 +329,7 @@ def update_activation_memory(
new_text_memories = new_memories
else:
logger.error("Not Implemented.")
+ return
try:
if isinstance(mem_cube.act_mem, VLLMKVCacheMemory):
@@ -333,7 +385,9 @@ def update_activation_memory(
)
except Exception as e:
- logger.warning(f"MOS-based activation memory update failed: {e}", exc_info=True)
+ logger.error(f"MOS-based activation memory update failed: {e}", exc_info=True)
+ # TODO: decide whether this failure is critical enough to re-raise; for now execution continues
def update_activation_memory_periodically(
self,
@@ -358,7 +412,8 @@ def update_activation_memory_periodically(
if (
user_id not in self.monitor.working_memory_monitors
or mem_cube_id not in self.monitor.working_memory_monitors[user_id]
- or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].memories) == 0
+ or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories)
+ == 0
):
logger.warning(
"No memories found in working_memory_monitors, activation memory update is skipped"
@@ -369,9 +424,13 @@ def update_activation_memory_periodically(
user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube
)
+ # Sync with database to get latest activation memories
+ activation_db_manager = self.monitor.activation_memory_monitors[user_id][
+ mem_cube_id
+ ]
+ activation_db_manager.sync_with_orm()
new_activation_memories = [
- m.memory_text
- for m in self.monitor.activation_memory_monitors[user_id][mem_cube_id].memories
+ m.memory_text for m in activation_db_manager.obj.memories
]
logger.info(
@@ -412,6 +471,11 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
messages = [messages] # transform single message to list
for message in messages:
+ if not isinstance(message, ScheduleMessageItem):
+ error_msg = f"Invalid message type: {type(message)}, expected ScheduleMessageItem"
+ logger.error(error_msg)
+ raise TypeError(error_msg)
+
self.memos_message_queue.put(message)
logger.info(f"Submitted message: {message.label} - {message.content}")
@@ -427,6 +491,11 @@ def _submit_web_logs(
messages = [messages] # transform single message to list
for message in messages:
+ if not isinstance(message, ScheduleLogForWebItem):
+ error_msg = f"Invalid message type: {type(message)}, expected ScheduleLogForWebItem"
+ logger.error(error_msg)
+ raise TypeError(error_msg)
+
self._web_log_message_queue.put(message)
message_info = message.debug_info()
logger.debug(f"Submitted Scheduling log for web: {message_info}")
@@ -461,25 +530,26 @@ def _message_consumer(self) -> None:
"""
while self._running: # Use a running flag for graceful shutdown
try:
- # Check if queue has messages (non-blocking)
- if not self.memos_message_queue.empty():
- # Get all available messages at once
- messages = []
- while not self.memos_message_queue.empty():
- try:
- messages.append(self.memos_message_queue.get_nowait())
- except queue.Empty:
- break
-
- if messages:
- try:
- self.dispatcher.dispatch(messages)
- except Exception as e:
- logger.error(f"Error dispatching messages: {e!s}")
- finally:
- # Mark all messages as processed
- for _ in messages:
- self.memos_message_queue.task_done()
+ # Get all available messages at once (thread-safe approach)
+ messages = []
+ while True:
+ try:
+ # Use get_nowait() directly without empty() check to avoid race conditions
+ message = self.memos_message_queue.get_nowait()
+ messages.append(message)
+ except queue.Empty:
+ # No more messages available
+ break
+
+ if messages:
+ try:
+ self.dispatcher.dispatch(messages)
+ except Exception as e:
+ logger.error(f"Error dispatching messages: {e!s}")
+ finally:
+ # Mark all messages as processed
+ for _ in messages:
+ self.memos_message_queue.task_done()
# Sleep briefly to prevent busy waiting
time.sleep(self._consume_interval) # Adjust interval as needed
diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py
index b2cb4bd3..ce6df4d5 100644
--- a/src/memos/mem_scheduler/general_modules/dispatcher.py
+++ b/src/memos/mem_scheduler/general_modules/dispatcher.py
@@ -2,8 +2,8 @@
from collections import defaultdict
from collections.abc import Callable
-from concurrent.futures import ThreadPoolExecutor
+from memos.context.context import ContextThreadPoolExecutor
from memos.log import get_logger
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
@@ -33,7 +33,7 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=False):
self.enable_parallel_dispatch = enable_parallel_dispatch
self.thread_name_prefix = "dispatcher"
if self.enable_parallel_dispatch:
- self.dispatcher_executor = ThreadPoolExecutor(
+ self.dispatcher_executor = ContextThreadPoolExecutor(
max_workers=self.max_workers, thread_name_prefix=self.thread_name_prefix
)
else:
diff --git a/src/memos/mem_scheduler/general_modules/misc.py b/src/memos/mem_scheduler/general_modules/misc.py
index 41ebdfd4..3c7116b7 100644
--- a/src/memos/mem_scheduler/general_modules/misc.py
+++ b/src/memos/mem_scheduler/general_modules/misc.py
@@ -1,9 +1,10 @@
import json
+import os
from contextlib import suppress
from datetime import datetime
from queue import Empty, Full, Queue
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
from pydantic import field_serializer
@@ -16,6 +17,75 @@
BaseModelType = TypeVar("T", bound="BaseModel")
+class EnvConfigMixin(Generic[T]):
+ """Abstract base class for environment variable configuration."""
+
+ ENV_PREFIX = "MEMSCHEDULER_"
+
+ @classmethod
+ def get_env_prefix(cls) -> str:
+ """Automatically generates environment variable prefix from class name.
+
+ Strips a trailing 'Config' suffix from the class name, uppercases the remainder,
+ prepends ENV_PREFIX, and appends a trailing underscore.
+
+ Examples:
+ RabbitMQConfig -> "MEMSCHEDULER_RABBITMQ_"
+ OpenAIConfig -> "MEMSCHEDULER_OPENAI_"
+ GraphDBAuthConfig -> "MEMSCHEDULER_GRAPHDBAUTH_"
+ """
+ class_name = cls.__name__
+ # Remove 'Config' suffix if present
+ if class_name.endswith("Config"):
+ class_name = class_name[:-6]
+ # Convert to uppercase and add trailing underscore
+ return f"{cls.ENV_PREFIX}{class_name.upper()}_"
+
+ @classmethod
+ def from_env(cls: type[T]) -> T:
+ """Creates a config instance from environment variables.
+
+ Reads all environment variables with the class-specific prefix and maps them
+ to corresponding configuration fields (converting to the appropriate types).
+
+ Returns:
+ An instance of the config class populated from environment variables.
+
+ Raises:
+ ValueError: If required environment variables are missing.
+ """
+ prefix = cls.get_env_prefix()
+ field_values = {}
+
+ for field_name, field_info in cls.model_fields.items():
+ env_var = f"{prefix}{field_name.upper()}"
+ field_type = field_info.annotation
+
+ if field_info.is_required() and env_var not in os.environ:
+ raise ValueError(f"Required environment variable {env_var} is missing")
+
+ if env_var in os.environ:
+ raw_value = os.environ[env_var]
+ field_values[field_name] = cls._parse_env_value(raw_value, field_type)
+ elif field_info.default is not None:
+ field_values[field_name] = field_info.default
+ else:
+ raise ValueError(f"Environment variable {env_var} is not set and field '{field_name}' has no usable default")
+ return cls(**field_values)
+
+ @classmethod
+ def _parse_env_value(cls, value: str, target_type: type) -> Any:
+ """Converts environment variable string to appropriate type."""
+ if target_type is bool:
+ return value.lower() in ("true", "1", "t", "y", "yes")
+ if target_type is int:
+ return int(value)
+ if target_type is float:
+ return float(value)
+ return value
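+
+ # Illustrative sketch (hypothetical subclass and fields): a Pydantic config class such as
+ # class RabbitMQConfig(BaseModel, EnvConfigMixin):
+ #     host_name: str
+ #     port: int = 5672
+ # could be built via RabbitMQConfig.from_env() from MEMSCHEDULER_RABBITMQ_HOST_NAME and
+ # MEMSCHEDULER_RABBITMQ_PORT, with the port string coerced to int by _parse_env_value.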
+
+
class DictConversionMixin:
"""
Provides conversion functionality between Pydantic models and dictionaries,
@@ -44,6 +114,26 @@ def to_dict(self) -> dict:
dump_data["timestamp"] = self.serialize_datetime(self.timestamp, None)
return dump_data
+ def to_json(self, **kwargs) -> str:
+ """
+ Convert model instance to a JSON string.
+ - Accepts the same kwargs as json.dumps (e.g., indent, ensure_ascii)
+ - Default settings make JSON human-readable and UTF-8 safe
+ """
+ return json.dumps(self.to_dict(), ensure_ascii=False, default=lambda o: str(o), **kwargs)
+
+ @classmethod
+ def from_json(cls: type[BaseModelType], json_str: str) -> BaseModelType:
+ """
+ Create model instance from a JSON string.
+ - Parses JSON into a dictionary and delegates to from_dict
+ """
+ try:
+ data = json.loads(json_str)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid JSON string: {e}") from e
+ return cls.from_dict(data)
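+
+ # Round-trip sketch (illustrative model name): restored = SomeItem.from_json(item.to_json(indent=2))
+ # yields an equivalent instance, since both methods delegate to to_dict/from_dict.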
+
@classmethod
def from_dict(cls: type[BaseModelType], data: dict) -> BaseModelType:
"""
@@ -102,3 +192,11 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non
def get_queue_content_without_pop(self) -> list[T]:
"""Return a copy of the queue's contents without modifying it."""
return list(self.queue)
+
+ def clear(self) -> None:
+ """Remove all items from the queue.
+
+ This operation is thread-safe.
+ """
+ with self.mutex:
+ self.queue.clear()
diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
index 0aa66707..44e74453 100644
--- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py
+++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
@@ -69,7 +69,7 @@ def create_autofilled_log_item(
and mem_cube_id in self.monitor.activation_memory_monitors[user_id]
):
activation_monitor = self.monitor.activation_memory_monitors[user_id][mem_cube_id]
- transformed_act_memory_size = len(activation_monitor.memories)
+ transformed_act_memory_size = len(activation_monitor.obj.memories)
logger.info(
f'activation_memory_monitors currently has "{transformed_act_memory_size}" transformed memory size'
)
@@ -98,6 +98,7 @@ def create_autofilled_log_item(
)
return log_message
+ # TODO: the number of log entries reported here is incorrect
@log_exceptions(logger=logger)
def log_working_memory_replacement(
self,
@@ -125,6 +126,7 @@ def log_working_memory_replacement(
added_memories = list(new_set - original_set) # Present in new but not original
# recording messages
+ log_messages = []
for memory in added_memories:
normalized_mem = transform_name_to_key(name=memory)
if normalized_mem not in memory_type_map:
@@ -145,11 +147,13 @@ def log_working_memory_replacement(
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
)
- log_func_callback([log_message])
- logger.info(
- f"{len(added_memories)} {LONG_TERM_MEMORY_TYPE} memorie(s) "
- f"transformed to {WORKING_MEMORY_TYPE} memories."
- )
+ log_messages.append(log_message)
+
+ logger.info(
+ f"{len(added_memories)} {LONG_TERM_MEMORY_TYPE} memorie(s) "
+ f"transformed to {WORKING_MEMORY_TYPE} memories."
+ )
+ log_func_callback(log_messages)
@log_exceptions(logger=logger)
def log_activation_memory_update(
@@ -170,6 +174,7 @@ def log_activation_memory_update(
added_memories = list(new_set - original_set) # Present in new but not original
# recording messages
+ log_messages = []
for mem in added_memories:
log_message_a = self.create_autofilled_log_item(
log_content=mem,
@@ -194,12 +199,13 @@ def log_activation_memory_update(
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
)
- logger.info(
- f"{len(added_memories)} {ACTIVATION_MEMORY_TYPE} memorie(s) "
- f"transformed to {PARAMETER_MEMORY_TYPE} memories."
- )
- log_func_callback([log_message_a, log_message_b])
+ log_messages.extend([log_message_a, log_message_b])
+ logger.info(
+ f"{len(added_memories)} {ACTIVATION_MEMORY_TYPE} memorie(s) "
+ f"transformed to {PARAMETER_MEMORY_TYPE} memories."
+ )
+ log_func_callback(log_messages)
@log_exceptions(logger=logger)
def log_adding_memory(
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py
index 08293886..340400ab 100644
--- a/src/memos/mem_scheduler/general_scheduler.py
+++ b/src/memos/mem_scheduler/general_scheduler.py
@@ -27,6 +27,8 @@ def __init__(self, config: GeneralSchedulerConfig):
"""Initialize the scheduler with the given configuration."""
super().__init__(config)
+ self.query_key_words_limit = self.config.get("query_key_words_limit", 20)
+
# register handlers
handlers = {
QUERY_LABEL: self._query_message_consumer,
@@ -35,78 +37,6 @@ def __init__(self, config: GeneralSchedulerConfig):
}
self.dispatcher.register_handlers(handlers)
- # for evaluation
- def search_for_eval(
- self, query: str, user_id: UserID | str, top_k: int, scheduler_flag: bool = True
- ) -> (list[str], bool):
- self.monitor.register_query_monitor_if_not_exists(
- user_id=user_id, mem_cube_id=self.current_mem_cube_id
- )
-
- query_keywords = self.monitor.extract_query_keywords(query=query)
- logger.info(f'Extract keywords "{query_keywords}" from query "{query}"')
-
- item = QueryMonitorItem(
- query_text=query,
- keywords=query_keywords,
- max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS,
- )
- query_monitor = self.monitor.query_monitors[user_id][self.current_mem_cube_id]
- query_monitor.put(item=item)
- logger.debug(f"Queries in monitor are {query_monitor.get_queries_with_timesort()}.")
-
- queries = [query]
-
- # recall
- mem_cube = self.current_mem_cube
- text_mem_base = mem_cube.text_mem
-
- cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
- text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
- intent_result = self.monitor.detect_intent(
- q_list=queries, text_working_memory=text_working_memory
- )
-
- if not scheduler_flag:
- return text_working_memory, intent_result["trigger_retrieval"]
- else:
- if intent_result["trigger_retrieval"]:
- missing_evidences = intent_result["missing_evidences"]
- num_evidence = len(missing_evidences)
- k_per_evidence = max(1, top_k // max(1, num_evidence))
- new_candidates = []
- for item in missing_evidences:
- logger.info(f"missing_evidences: {item}")
- results: list[TextualMemoryItem] = self.retriever.search(
- query=item,
- mem_cube=mem_cube,
- top_k=k_per_evidence,
- method=self.search_method,
- )
- logger.info(
- f"search results for {missing_evidences}: {[one.memory for one in results]}"
- )
- new_candidates.extend(results)
- print(
- f"missing_evidences: {missing_evidences} and get {len(new_candidates)} new candidate memories."
- )
- else:
- new_candidates = []
- print(f"intent_result: {intent_result}. not triggered")
-
- # rerank
- new_order_working_memory = self.replace_working_memory(
- user_id=user_id,
- mem_cube_id=self.current_mem_cube_id,
- mem_cube=self.current_mem_cube,
- original_memory=cur_working_memory,
- new_memory=new_candidates,
- )
- new_order_working_memory = new_order_working_memory[:top_k]
- logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}")
-
- return [m.memory for m in new_order_working_memory], intent_result["trigger_retrieval"]
-
def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
"""
Process and handle query trigger messages from the queue.
@@ -140,7 +70,9 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
query = msg.content
query_keywords = self.monitor.extract_query_keywords(query=query)
- logger.info(f'Extract keywords "{query_keywords}" from query "{query}"')
+ logger.info(
+ f'Extracted keywords "{query_keywords}" from query "{query}" for user_id={user_id}'
+ )
if len(query_keywords) == 0:
stripped_query = query.strip()
@@ -155,21 +87,26 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
)
words = stripped_query # Default to character count
- query_keywords = list(set(words[:20]))
+ query_keywords = list(set(words[: self.query_key_words_limit]))
logger.error(
- f"Keyword extraction failed for query. Using fallback keywords: {query_keywords[:10]}... (truncated)"
+ f"Keyword extraction failed for query '{query}' (user_id={user_id}). Using fallback keywords: {query_keywords[:10]}... (truncated)",
+ exc_info=True,
)
item = QueryMonitorItem(
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
query_text=query,
keywords=query_keywords,
max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS,
)
- self.monitor.query_monitors[user_id][mem_cube_id].put(item=item)
+ query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
+ query_db_manager.obj.put(item=item)
+ # Sync with database after adding new item
+ query_db_manager.sync_with_orm()
logger.debug(
- f"Queries in monitor are "
- f"{self.monitor.query_monitors[user_id][mem_cube_id].get_queries_with_timesort()}."
+ f"Queries in monitor for user_id={user_id}, mem_cube_id={mem_cube_id}: {query_db_manager.obj.get_queries_with_timesort()}"
)
queries = [msg.content for msg in messages]
@@ -183,7 +120,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
top_k=self.top_k,
)
logger.info(
- f"Processed {queries} and get {len(new_candidates)} new candidate memories."
+ f"Processed {len(queries)} queries {queries} and retrieved {len(new_candidates)} new candidate memories for user_id={user_id}"
)
# rerank
@@ -194,7 +131,9 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
original_memory=cur_working_memory,
new_memory=new_candidates,
)
- logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}")
+ logger.info(
+ f"Final working memory size: {len(new_order_working_memory)} memories for user_id={user_id}"
+ )
# update activation memories
logger.info(
@@ -293,10 +232,17 @@ def process_session_turn(
text_mem_base = mem_cube.text_mem
if not isinstance(text_mem_base, TreeTextMemory):
- logger.error("Not implemented!", exc_info=True)
+ logger.error(
+ f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} "
+ f"for mem_cube_id={mem_cube_id}, user_id={user_id}. "
+ f"text_mem_base value: {text_mem_base}",
+ exc_info=True,
+ )
return
- logger.info(f"Processing {len(queries)} queries.")
+ logger.info(
+ f"Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}"
+ )
cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
@@ -312,16 +258,20 @@ def process_session_turn(
time_trigger_flag = True
if (not intent_result["trigger_retrieval"]) and (not time_trigger_flag):
- logger.info(f"Query schedule not triggered. Intent_result: {intent_result}")
+ logger.info(
+ f"Query schedule not triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. Intent_result: {intent_result}"
+ )
return
elif (not intent_result["trigger_retrieval"]) and time_trigger_flag:
- logger.info("Query schedule is forced to trigger due to time ticker")
+ logger.info(
+ f"Query schedule forced to trigger due to time ticker for user_id={user_id}, mem_cube_id={mem_cube_id}"
+ )
intent_result["trigger_retrieval"] = True
intent_result["missing_evidences"] = queries
else:
logger.info(
- f'Query schedule triggered for user "{user_id}" and mem_cube "{mem_cube_id}".'
- f" Missing evidences: {intent_result['missing_evidences']}"
+ f"Query schedule triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. "
+ f"Missing evidences: {intent_result['missing_evidences']}"
)
missing_evidences = intent_result["missing_evidences"]
@@ -329,7 +279,9 @@ def process_session_turn(
k_per_evidence = max(1, top_k // max(1, num_evidence))
new_candidates = []
for item in missing_evidences:
- logger.info(f"missing_evidences: {item}")
+ logger.info(
+ f"Searching for missing evidence: '{item}' with top_k={k_per_evidence} for user_id={user_id}"
+ )
info = {
"user_id": user_id,
"session_id": "",
@@ -343,7 +295,7 @@ def process_session_turn(
info=info,
)
logger.info(
- f"search results for {missing_evidences}: {[one.memory for one in results]}"
+ f"Search results for missing evidence '{item}': {[one.memory for one in results]}"
)
new_candidates.extend(results)
return cur_working_memory, new_candidates
diff --git a/src/memos/mem_scheduler/memory_manage_modules/__init__.py b/src/memos/mem_scheduler/memory_manage_modules/__init__.py
new file mode 100644
index 00000000..94d70429
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/__init__.py
@@ -0,0 +1,5 @@
+from .memory_filter import MemoryFilter
+from .retriever import SchedulerRetriever
+
+
+__all__ = ["MemoryFilter", "SchedulerRetriever"]
diff --git a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py
new file mode 100644
index 00000000..e18c6e51
--- /dev/null
+++ b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py
@@ -0,0 +1,308 @@
+from memos.configs.mem_scheduler import BaseSchedulerConfig
+from memos.llms.base import BaseLLM
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
+from memos.mem_scheduler.utils.misc_utils import extract_json_dict
+from memos.memories.textual.tree import TextualMemoryItem
+
+
+logger = get_logger(__name__)
+
+
+class MemoryFilter(BaseSchedulerModule):
+ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
+ super().__init__()
+ self.config: BaseSchedulerConfig = config
+ self.process_llm = process_llm
+
+ def filter_unrelated_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> (list[TextualMemoryItem], bool):
+ """
+ Filter out memories that are completely unrelated to the query history using LLM.
+
+ Args:
+ query_history: List of query strings to determine relevance
+ memories: List of TextualMemoryItem objects to be filtered
+
+ Returns:
+ Tuple of (filtered_memories, success_flag)
+ - filtered_memories: List of TextualMemoryItem objects that are relevant to queries
+ - success_flag: Boolean indicating if LLM filtering was successful
+
+ Note:
+ If LLM filtering fails, returns all memories (conservative approach)
+ """
+ success_flag = False
+
+ if not memories:
+ logger.info("No memories to filter - returning empty list")
+ return [], True
+
+ if not query_history:
+ logger.info("No query history provided - keeping all memories")
+ return memories, True
+
+ logger.info(
+ f"Starting memory filtering for {len(memories)} memories against {len(query_history)} queries"
+ )
+
+ # Extract memory texts for LLM processing
+ memory_texts = [mem.memory for mem in memories]
+
+ # Build LLM prompt for memory filtering
+ prompt = self.build_prompt(
+ "memory_filtering",
+ query_history=[f"[{i}] {query}" for i, query in enumerate(query_history)],
+ memories=[f"[{i}] {mem}" for i, mem in enumerate(memory_texts)],
+ )
+ logger.debug(f"Generated filtering prompt: {prompt[:200]}...") # Log first 200 chars
+
+ # Get LLM response
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+ logger.debug(f"Received LLM filtering response: {response[:200]}...") # Log first 200 chars
+
+ try:
+ # Parse JSON response
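+ # Assumed shape of the "memory_filtering" template response, e.g.
+ # {"relevant_memories": [0, 2], "filtered_count": 1, "reasoning": "memory 1 is unrelated"}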
+ response = extract_json_dict(response)
+ logger.debug(f"Parsed JSON response: {response}")
+ relevant_indices = response["relevant_memories"]
+ filtered_count = response["filtered_count"]
+ reasoning = response["reasoning"]
+
+ # Validate indices
+ if not isinstance(relevant_indices, list):
+ raise ValueError("relevant_memories must be a list")
+
+ # Filter memories based on relevant indices
+ filtered_memories = []
+ for idx in relevant_indices:
+ if isinstance(idx, int) and 0 <= idx < len(memories):
+ filtered_memories.append(memories[idx])
+ else:
+ logger.warning(f"Invalid memory index {idx} - skipping")
+
+ logger.info(
+ f"Successfully filtered memories. Kept {len(filtered_memories)} out of {len(memories)} memories. "
+ f"Filtered out {filtered_count} unrelated memories. "
+ f"Filtering reasoning: {reasoning}"
+ )
+ success_flag = True
+
+ except Exception as e:
+ logger.error(
+ f"Failed to filter memories with LLM. Exception: {e}. Raw response: {response}",
+ exc_info=True,
+ )
+ # Conservative approach: keep all memories if filtering fails
+ filtered_memories = memories
+ success_flag = False
+
+ return filtered_memories, success_flag
+
+ def filter_redundant_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> (list[TextualMemoryItem], bool):
+ """
+ Filter out redundant memories using LLM analysis.
+
+ This function removes redundant memories by keeping the most informative
+ version when multiple memories contain similar information relevant to queries.
+
+ Args:
+ query_history: List of query strings to determine relevance and value
+ memories: List of TextualMemoryItem objects to be filtered
+
+ Returns:
+ Tuple of (filtered_memories, success_flag)
+ - filtered_memories: List of TextualMemoryItem objects after redundancy filtering
+ - success_flag: Boolean indicating if LLM filtering was successful
+
+ Note:
+ If LLM filtering fails, returns all memories (conservative approach)
+ """
+ success_flag = False
+
+ if not memories:
+ logger.info("No memories to filter for redundancy - returning empty list")
+ return [], True
+
+ if not query_history:
+ logger.info("No query history provided - keeping all memories")
+ return memories, True
+
+ if len(memories) <= 1:
+ logger.info("Only one memory - no redundancy to filter")
+ return memories, True
+
+ logger.info(
+ f"Starting redundancy filtering for {len(memories)} memories against {len(query_history)} queries"
+ )
+
+ # Extract memory texts for LLM processing
+ memory_texts = [mem.memory for mem in memories]
+
+ # Build LLM prompt for redundancy filtering
+ prompt = self.build_prompt(
+ "memory_redundancy_filtering",
+ query_history=[f"[{i}] {query}" for i, query in enumerate(query_history)],
+ memories=[f"[{i}] {mem}" for i, mem in enumerate(memory_texts)],
+ )
+ logger.debug(
+ f"Generated redundancy filtering prompt: {prompt[:200]}..."
+ ) # Log first 200 chars
+
+ # Get LLM response
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+ logger.debug(
+ f"Received LLM redundancy filtering response: {response[:200]}..."
+ ) # Log first 200 chars
+
+ try:
+ # Parse JSON response
+ response = extract_json_dict(response)
+ logger.debug(f"Parsed JSON response: {response}")
+ kept_indices = response["kept_memories"]
+ redundant_groups = response.get("redundant_groups", [])
+ reasoning = response["reasoning"]
+
+ # Validate indices
+ if not isinstance(kept_indices, list):
+ raise ValueError("kept_memories must be a list")
+
+ # Filter memories based on kept indices
+ filtered_memories = []
+ for idx in kept_indices:
+ if isinstance(idx, int) and 0 <= idx < len(memories):
+ filtered_memories.append(memories[idx])
+ else:
+ logger.warning(f"Invalid memory index {idx} - skipping")
+
+ logger.info(
+ f"Successfully filtered redundant memories. "
+ f"Kept {len(filtered_memories)} out of {len(memories)} memories. "
+ f"Removed {len(memories) - len(filtered_memories)} redundant memories. "
+ f"Redundant groups identified: {len(redundant_groups)}. "
+ f"Filtering reasoning: {reasoning}"
+ )
+ success_flag = True
+
+ except Exception as e:
+ logger.error(
+ f"Failed to filter redundant memories with LLM. Exception: {e}. Raw response: {response}",
+ exc_info=True,
+ )
+ # Conservative approach: keep all memories if filtering fails
+ filtered_memories = memories
+ success_flag = False
+
+ return filtered_memories, success_flag
+
+ def filter_unrelated_and_redundant_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> (list[TextualMemoryItem], bool):
+ """
+ Filter out both unrelated and redundant memories using LLM analysis.
+
+ This function performs two types of filtering in sequence:
+ 1. Remove memories that are completely unrelated to the query history
+ 2. Remove redundant memories by keeping the most informative version
+
+ Args:
+ query_history: List of query strings to determine relevance and value
+ memories: List of TextualMemoryItem objects to be filtered
+
+ Returns:
+ Tuple of (filtered_memories, success_flag)
+ - filtered_memories: List of TextualMemoryItem objects after both filtering steps
+ - success_flag: Boolean indicating if LLM filtering was successful
+
+ Note:
+ If LLM filtering fails, returns all memories (conservative approach)
+ """
+ success_flag = False
+
+ if not memories:
+ logger.info("No memories to filter for unrelated and redundant - returning empty list")
+ return [], True
+
+ if not query_history:
+ logger.info("No query history provided - keeping all memories")
+ return memories, True
+
+ if len(memories) <= 1:
+ logger.info("Only one memory - no filtering needed")
+ return memories, True
+
+ logger.info(
+ f"Starting combined unrelated and redundant filtering for {len(memories)} memories against {len(query_history)} queries"
+ )
+
+ # Extract memory texts for LLM processing
+ memory_texts = [mem.memory for mem in memories]
+
+ # Build LLM prompt for combined filtering
+ prompt = self.build_prompt(
+ "memory_combined_filtering",
+ query_history=[f"[{i}] {query}" for i, query in enumerate(query_history)],
+ memories=[f"[{i}] {mem}" for i, mem in enumerate(memory_texts)],
+ )
+ logger.debug(
+ f"Generated combined filtering prompt: {prompt[:200]}..."
+ ) # Log first 200 chars
+
+ # Get LLM response
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+ logger.debug(
+ f"Received LLM combined filtering response: {response[:200]}..."
+ ) # Log first 200 chars
+
+ try:
+ # Parse JSON response
+ response = extract_json_dict(response)
+ logger.debug(f"Parsed JSON response: {response}")
+ kept_indices = response["kept_memories"]
+ unrelated_removed_count = response.get("unrelated_removed_count", 0)
+ redundant_removed_count = response.get("redundant_removed_count", 0)
+ redundant_groups = response.get("redundant_groups", [])
+ reasoning = response["reasoning"]
+
+ # Validate indices
+ if not isinstance(kept_indices, list):
+ raise ValueError("kept_memories must be a list")
+
+ # Filter memories based on kept indices
+ filtered_memories = []
+ for idx in kept_indices:
+ if isinstance(idx, int) and 0 <= idx < len(memories):
+ filtered_memories.append(memories[idx])
+ else:
+ logger.warning(f"Invalid memory index {idx} - skipping")
+
+ logger.info(
+ f"Successfully filtered unrelated and redundant memories. "
+ f"Kept {len(filtered_memories)} out of {len(memories)} memories. "
+ f"Removed {len(memories) - len(filtered_memories)} memories total. "
+ f"Unrelated removed: {unrelated_removed_count}. "
+ f"Redundant removed: {redundant_removed_count}. "
+ f"Redundant groups identified: {len(redundant_groups)}. "
+ f"Filtering reasoning: {reasoning}"
+ )
+ success_flag = True
+
+ except Exception as e:
+ logger.error(
+ f"Failed to filter unrelated and redundant memories with LLM. Exception: {e}. Raw response: {response}",
+ exc_info=True,
+ )
+ # Conservative approach: keep all memories if filtering fails
+ filtered_memories = memories
+ success_flag = False
+
+ return filtered_memories, success_flag
diff --git a/src/memos/mem_scheduler/general_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
similarity index 85%
rename from src/memos/mem_scheduler/general_modules/retriever.py
rename to src/memos/mem_scheduler/memory_manage_modules/retriever.py
index 3732078d..b766f001 100644
--- a/src/memos/mem_scheduler/general_modules/retriever.py
+++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
@@ -8,8 +8,8 @@
TreeTextMemory_SEARCH_METHOD,
)
from memos.mem_scheduler.utils.filter_utils import (
- filter_similar_memories,
filter_too_short_memories,
+ filter_vector_based_similar_memories,
transform_name_to_key,
)
from memos.mem_scheduler.utils.misc_utils import (
@@ -17,6 +17,8 @@
)
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
+from .memory_filter import MemoryFilter
+
logger = get_logger(__name__)
@@ -32,6 +34,9 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
self.config: BaseSchedulerConfig = config
self.process_llm = process_llm
+ # Initialize memory filter
+ self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
+
def search(
self,
query: str,
@@ -77,10 +82,7 @@ def search(
return results
def rerank_memories(
- self,
- queries: list[str],
- original_memories: list[str],
- top_k: int,
+ self, queries: list[str], original_memories: list[str], top_k: int
) -> (list[str], bool):
"""
Rerank memories based on relevance to given queries using LLM.
@@ -96,7 +98,6 @@ def rerank_memories(
Note:
If LLM reranking fails, falls back to original order (truncated to top_k)
"""
- success_flag = False
logger.info(f"Starting memory reranking for {len(original_memories)} memories")
@@ -163,7 +164,7 @@ def process_and_rerank_memories(
combined_text_memory = [m.memory for m in combined_memory]
# Apply similarity filter to remove overly similar memories
- filtered_combined_text_memory = filter_similar_memories(
+ filtered_combined_text_memory = filter_vector_based_similar_memories(
text_memories=combined_text_memory,
similarity_threshold=self.filter_similarity_threshold,
)
@@ -197,3 +198,29 @@ def process_and_rerank_memories(
)
return memories_with_new_order, success_flag
+
+ def filter_unrelated_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> (list[TextualMemoryItem], bool):
+ return self.memory_filter.filter_unrelated_memories(query_history, memories)
+
+ def filter_redundant_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> (list[TextualMemoryItem], bool):
+ return self.memory_filter.filter_redundant_memories(query_history, memories)
+
+ def filter_unrelated_and_redundant_memories(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> (list[TextualMemoryItem], bool):
+ """
+ Filter out both unrelated and redundant memories using LLM analysis.
+
+ This method delegates to the MemoryFilter class.
+ """
+ return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories)
diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
index 229e9c3a..85dc17ad 100644
--- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
+++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
@@ -1,11 +1,11 @@
import threading
import time
-from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from time import perf_counter
from memos.configs.mem_scheduler import BaseSchedulerConfig
+from memos.context.context import ContextThreadPoolExecutor
from memos.log import get_logger
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher
@@ -21,7 +21,7 @@ def __init__(self, config: BaseSchedulerConfig):
super().__init__()
self.config: BaseSchedulerConfig = config
- self.check_interval = self.config.get("dispatcher_monitor_check_interval", 60)
+ self.check_interval = self.config.get("dispatcher_monitor_check_interval", 300)
self.max_failures = self.config.get("dispatcher_monitor_max_failures", 2)
# Registry of monitored thread pools
@@ -49,7 +49,7 @@ def initialize(self, dispatcher: SchedulerDispatcher):
def register_pool(
self,
name: str,
- executor: ThreadPoolExecutor,
+ executor: ContextThreadPoolExecutor,
max_workers: int,
restart_on_failure: bool = True,
) -> bool:
@@ -177,10 +177,11 @@ def _check_pools_health(self) -> None:
else:
pool_info["failure_count"] += 1
pool_info["healthy"] = False
- logger.warning(
- f"Pool '{name}' unhealthy ({pool_info['failure_count']}/{self.max_failures}): {reason}"
+ logger.info(
+ f"Pool '{name}' unhealthy ({pool_info['failure_count']}/{self.max_failures}): {reason}."
+ f" Note: This status does not necessarily indicate a problem with the pool itself - "
+ f"it may also be considered unhealthy if no tasks have been scheduled for an extended period"
)
-
if (
pool_info["failure_count"] >= self.max_failures
and pool_info["restart"]
@@ -236,14 +237,14 @@ def _restart_pool(self, name: str, pool_info: dict) -> None:
return
self._restart_in_progress = True
- logger.warning(f"Attempting to restart thread pool '{name}'")
+ logger.info(f"Attempting to restart thread pool '{name}'")
try:
old_executor = pool_info["executor"]
self.dispatcher.shutdown()
# Create new executor with same parameters
- new_executor = ThreadPoolExecutor(
+ new_executor = ContextThreadPoolExecutor(
max_workers=pool_info["max_workers"],
thread_name_prefix=self.dispatcher.thread_name_prefix, # pylint: disable=protected-access
)
diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py
index 6bc796cc..87d99654 100644
--- a/src/memos/mem_scheduler/monitors/general_monitor.py
+++ b/src/memos/mem_scheduler/monitors/general_monitor.py
@@ -2,11 +2,18 @@
from threading import Lock
from typing import Any
+from sqlalchemy.engine import Engine
+
from memos.configs.mem_scheduler import BaseSchedulerConfig
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
+from memos.mem_scheduler.orm_modules.base_model import BaseDBManager
+from memos.mem_scheduler.orm_modules.monitor_models import (
+ DBManagerForMemoryMonitorManager,
+ DBManagerForQueryMonitorQueue,
+)
from memos.mem_scheduler.schemas.general_schemas import (
DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT,
DEFAULT_WEIGHT_VECTOR_FOR_RANKING,
@@ -19,7 +26,6 @@
from memos.mem_scheduler.schemas.monitor_schemas import (
MemoryMonitorItem,
MemoryMonitorManager,
- QueryMonitorItem,
QueryMonitorQueue,
)
from memos.mem_scheduler.utils.misc_utils import extract_json_dict
@@ -32,7 +38,9 @@
class SchedulerGeneralMonitor(BaseSchedulerModule):
"""Monitors and manages scheduling operations with LLM integration."""
- def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
+ def __init__(
+ self, process_llm: BaseLLM, config: BaseSchedulerConfig, db_engine: Engine | None = None
+ ):
super().__init__()
# hyper-parameters
@@ -49,12 +57,22 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
"activation_mem_monitor_capacity", DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT
)
- # attributes
- # recording query_messages
- self.query_monitors: dict[UserID, dict[MemCubeID, QueryMonitorQueue[QueryMonitorItem]]] = {}
+ # ORM-based monitor managers
+ self.db_engine = db_engine
+ if self.db_engine is None:
+ logger.warning(
+ "No database engine provided; falling back to default temporary SQLite engine. "
+ "This is intended for testing only. Consider providing a configured engine for production use."
+ )
+ self.db_engine = BaseDBManager.create_default_engine()
- self.working_memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]] = {}
- self.activation_memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]] = {}
+ self.query_monitors: dict[UserID, dict[MemCubeID, DBManagerForQueryMonitorQueue]] = {}
+ self.working_memory_monitors: dict[
+ UserID, dict[MemCubeID, DBManagerForMemoryMonitorManager]
+ ] = {}
+ self.activation_memory_monitors: dict[
+ UserID, dict[MemCubeID, DBManagerForMemoryMonitorManager]
+ ] = {}
# Lifecycle monitor
self.last_activation_mem_update_time = datetime.min
@@ -96,40 +114,47 @@ def register_query_monitor_if_not_exists(
if user_id not in self.query_monitors:
self.query_monitors[user_id] = {}
if mem_cube_id not in self.query_monitors[user_id]:
- self.query_monitors[user_id][mem_cube_id] = QueryMonitorQueue(
- maxsize=self.config.context_window_size
- )
+ if self.db_engine:
+ # Create ORM manager with initial QueryMonitorQueue
+ initial_queue = QueryMonitorQueue(maxsize=self.config.context_window_size)
+ db_manager = DBManagerForQueryMonitorQueue(
+ engine=self.db_engine,
+ user_id=str(user_id),
+ mem_cube_id=str(mem_cube_id),
+ obj=initial_queue,
+ )
+ self.query_monitors[user_id][mem_cube_id] = db_manager
+ else:
+ # No usable database engine is available, so ORM-backed query monitoring cannot be set up
+ logger.error("ORM persistence is required for query monitors but no database engine is configured")
+ raise RuntimeError("ORM persistence is required but not properly configured")
def register_memory_manager_if_not_exists(
self,
user_id: UserID | str,
mem_cube_id: MemCubeID | str,
- memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]],
+ memory_monitors: dict[UserID, dict[MemCubeID, DBManagerForMemoryMonitorManager]],
max_capacity: int,
) -> None:
"""
- Register a new MemoryMonitorManager for the given user and memory cube if it doesn't exist.
+ Register a new MemoryMonitorManager ORM manager for the given user and memory cube if it doesn't exist.
Thread-safe implementation using double-checked locking pattern.
- Checks if a MemoryMonitorManager already exists for the specified user_id and mem_cube_id.
- If not, creates a new MemoryMonitorManager with appropriate capacity settings and registers it.
+ Checks if a MemoryMonitorManager ORM manager already exists for the specified user_id and mem_cube_id.
+ If not, creates a new ORM manager with appropriate capacity settings and registers it.
Args:
user_id: The ID of the user to associate with the memory manager
mem_cube_id: The ID of the memory cube to monitor
- memory_monitors: Dictionary storing existing memory monitor managers
+ memory_monitors: Dictionary storing existing memory monitor ORM managers
max_capacity: Maximum capacity for the new memory monitor manager
- lock: Threading lock to ensure safe concurrent access
-
- Note:
- This function will update the loose_max_working_memory_capacity based on the current
- WorkingMemory size plus partial retention number before creating a new manager.
"""
# First check (lock-free, fast path)
# Quickly verify existence without lock overhead
if user_id in memory_monitors and mem_cube_id in memory_monitors[user_id]:
logger.info(
- f"MemoryMonitorManager already exists for user_id={user_id}, "
+ f"MemoryMonitorManager ORM manager already exists for user_id={user_id}, "
f"mem_cube_id={mem_cube_id} in the provided memory_monitors dictionary"
)
return
@@ -140,22 +165,33 @@ def register_memory_manager_if_not_exists(
# Re-check after acquiring lock, as another thread might have created it
if user_id in memory_monitors and mem_cube_id in memory_monitors[user_id]:
logger.info(
- f"MemoryMonitorManager already exists for user_id={user_id}, "
+ f"MemoryMonitorManager ORM manager already exists for user_id={user_id}, "
f"mem_cube_id={mem_cube_id} in the provided memory_monitors dictionary"
)
return
- # Initialize MemoryMonitorManager with user ID, memory cube ID, and max capacity
- monitor_manager = MemoryMonitorManager(
- user_id=user_id, mem_cube_id=mem_cube_id, max_capacity=max_capacity
- )
+ if self.db_engine:
+ # Initialize MemoryMonitorManager with user ID, memory cube ID, and max capacity
+ monitor_manager = MemoryMonitorManager(
+ user_id=user_id, mem_cube_id=mem_cube_id, max_capacity=max_capacity
+ )
- # Safely register the new manager in the nested dictionary structure
- memory_monitors.setdefault(user_id, {})[mem_cube_id] = monitor_manager
- logger.info(
- f"Registered new MemoryMonitorManager for user_id={user_id},"
- f" mem_cube_id={mem_cube_id} with max_capacity={max_capacity}"
- )
+ # Create ORM manager
+ db_manager = DBManagerForMemoryMonitorManager(
+ engine=self.db_engine,
+ user_id=str(user_id),
+ mem_cube_id=str(mem_cube_id),
+ obj=monitor_manager,
+ )
+
+ # Safely register the new ORM manager in the nested dictionary structure
+ memory_monitors.setdefault(user_id, {})[mem_cube_id] = db_manager
+ logger.info(
+ f"Registered new MemoryMonitorManager ORM manager for user_id={user_id},"
+ f" mem_cube_id={mem_cube_id} with max_capacity={max_capacity}"
+ )
+ else:
+ raise RuntimeError("ORM persistence is required but not properly configured")
def update_working_memory_monitors(
self,
@@ -182,10 +218,14 @@ def update_working_memory_monitors(
max_capacity=self.working_mem_monitor_capacity,
)
- self.working_memory_monitors[user_id][mem_cube_id].update_memories(
+ # Get the ORM manager and update memories with database sync
+ db_manager = self.working_memory_monitors[user_id][mem_cube_id]
+ db_manager.obj.update_memories(
new_memory_monitors=new_working_memory_monitors,
partial_retention_number=self.partial_retention_number,
)
+ # Sync with database
+ db_manager.sync_with_orm(size_limit=self.working_mem_monitor_capacity)
def update_activation_memory_monitors(
self, user_id: str, mem_cube_id: str, mem_cube: GeneralMemCube
@@ -199,17 +239,21 @@ def update_activation_memory_monitors(
# === update activation memory monitors ===
# Sort by importance_score in descending order and take top k
+ working_db_manager = self.working_memory_monitors[user_id][mem_cube_id]
top_k_memories = sorted(
- self.working_memory_monitors[user_id][mem_cube_id].memories,
+ working_db_manager.obj.memories,
key=lambda m: m.get_importance_score(weight_vector=DEFAULT_WEIGHT_VECTOR_FOR_RANKING),
reverse=True,
)[: self.activation_mem_monitor_capacity]
# Update the activation memory monitors with these important memories
- self.activation_memory_monitors[user_id][mem_cube_id].update_memories(
+ activation_db_manager = self.activation_memory_monitors[user_id][mem_cube_id]
+ activation_db_manager.obj.update_memories(
new_memory_monitors=top_k_memories,
partial_retention_number=self.partial_retention_number,
)
+ # Sync with database
+ activation_db_manager.sync_with_orm(size_limit=self.activation_mem_monitor_capacity)
def timed_trigger(self, last_time: datetime, interval_seconds: float) -> bool:
now = datetime.utcnow()
@@ -255,9 +299,12 @@ def get_monitor_memories(
)
return []
- manager: MemoryMonitorManager = monitor_dict[user_id][mem_cube_id]
+ db_manager: DBManagerForMemoryMonitorManager = monitor_dict[user_id][mem_cube_id]
+ # Load latest data from database before accessing
+ db_manager.sync_with_orm()
+
# Sort memories by recording_count in descending order and return top_k items
- sorted_memory_monitors = manager.get_sorted_mem_monitors(reverse=True)
+ sorted_memory_monitors = db_manager.obj.get_sorted_mem_monitors(reverse=True)
sorted_text_memories = [m.memory_text for m in sorted_memory_monitors[:top_k]]
return sorted_text_memories
@@ -273,16 +320,19 @@ def get_monitors_info(self, user_id: str, mem_cube_id: str) -> dict[str, Any]:
return {}
info_dict = {}
- for manager in [
+ for db_manager in [
self.working_memory_monitors[user_id][mem_cube_id],
self.activation_memory_monitors[user_id][mem_cube_id],
]:
+ # Sync with database to get latest data
+ db_manager.sync_with_orm()
+ manager = db_manager.obj
info_dict[str(type(manager))] = {
"user_id": user_id,
"mem_cube_id": mem_cube_id,
"memory_count": manager.memory_size,
"max_capacity": manager.max_capacity,
- "top_memories": self.get_scheduler_working_memories(user_id, mem_cube_id, top_k=1),
+ "top_memories": self.get_monitor_memories(user_id, mem_cube_id, top_k=1),
}
return info_dict
@@ -308,3 +358,33 @@ def detect_intent(
logger.error(f"Fail to extract json dict from response: {response}")
response = {"trigger_retrieval": False, "missing_evidences": q_list}
return response
+
+ def close(self):
+ """Close all database connections and clean up resources"""
+ logger.info("Closing database connections for all monitors")
+
+ # Close all query monitor database managers
+ for user_monitors in self.query_monitors.values():
+ for db_manager in user_monitors.values():
+ try:
+ db_manager.close()
+ except Exception as e:
+ logger.error(f"Error closing query monitor DB manager: {e}")
+
+ # Close all working memory monitor database managers
+ for user_monitors in self.working_memory_monitors.values():
+ for db_manager in user_monitors.values():
+ try:
+ db_manager.close()
+ except Exception as e:
+ logger.error(f"Error closing working memory monitor DB manager: {e}")
+
+ # Close all activation memory monitor database managers
+ for user_monitors in self.activation_memory_monitors.values():
+ for db_manager in user_monitors.values():
+ try:
+ db_manager.close()
+ except Exception as e:
+ logger.error(f"Error closing activation memory monitor DB manager: {e}")
+
+ logger.info("All database connections closed")
diff --git a/src/memos/mem_scheduler/mos_for_test_scheduler.py b/src/memos/mem_scheduler/mos_for_test_scheduler.py
deleted file mode 100644
index f275da2b..00000000
--- a/src/memos/mem_scheduler/mos_for_test_scheduler.py
+++ /dev/null
@@ -1,146 +0,0 @@
-from datetime import datetime
-
-from memos.configs.mem_os import MOSConfig
-from memos.log import get_logger
-from memos.mem_os.main import MOS
-from memos.mem_scheduler.schemas.general_schemas import (
- ANSWER_LABEL,
- MONITOR_WORKING_MEMORY_TYPE,
- QUERY_LABEL,
-)
-from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
-
-
-logger = get_logger(__name__)
-
-
-class MOSForTestScheduler(MOS):
- """This class is only to test abilities of mem scheduler"""
-
- def __init__(self, config: MOSConfig):
- super().__init__(config)
-
- def _str_memories(self, memories: list[str]) -> str:
- """Format memories for display."""
- if not memories:
- return "No memories."
- return "\n".join(f"{i + 1}. {memory}" for i, memory in enumerate(memories))
-
- def chat(self, query: str, user_id: str | None = None) -> str:
- """
- Chat with the MOS.
-
- Args:
- query (str): The user's query.
-
- Returns:
- str: The response from the MOS.
- """
- target_user_id = user_id if user_id is not None else self.user_id
- accessible_cubes = self.user_manager.get_user_cubes(target_user_id)
- user_cube_ids = [cube.cube_id for cube in accessible_cubes]
- if target_user_id not in self.chat_history_manager:
- self._register_chat_history(target_user_id)
-
- chat_history = self.chat_history_manager[target_user_id]
-
- topk_for_scheduler = 2
-
- if self.config.enable_textual_memory and self.mem_cubes:
- memories_all = []
- for mem_cube_id, mem_cube in self.mem_cubes.items():
- if mem_cube_id not in user_cube_ids:
- continue
- if not mem_cube.text_mem:
- continue
-
- message_item = ScheduleMessageItem(
- user_id=target_user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- label=QUERY_LABEL,
- content=query,
- timestamp=datetime.now(),
- )
- cur_working_memories = [m.memory for m in mem_cube.text_mem.get_working_memory()]
- print(f"Working memories before schedule: {cur_working_memories}")
-
- # --- force to run mem_scheduler ---
- self.mem_scheduler.monitor.query_trigger_interval = 0
- self.mem_scheduler._query_message_consumer(messages=[message_item])
-
- # from scheduler
- scheduler_memories = self.mem_scheduler.monitor.get_monitor_memories(
- user_id=target_user_id,
- mem_cube_id=mem_cube_id,
- memory_type=MONITOR_WORKING_MEMORY_TYPE,
- top_k=topk_for_scheduler,
- )
- print(f"Working memories after schedule: {scheduler_memories}")
- memories_all.extend(scheduler_memories)
-
- # from mem_cube
- memories = mem_cube.text_mem.search(
- query,
- top_k=self.config.top_k - topk_for_scheduler,
- info={
- "user_id": target_user_id,
- "session_id": self.session_id,
- "chat_history": chat_history.chat_history,
- },
- )
- text_memories = [m.memory for m in memories]
- print(f"Search results with new working memories: {text_memories}")
- memories_all.extend(text_memories)
-
- memories_all = list(set(memories_all))
-
- logger.info(f"🧠 [Memory] Searched memories:\n{self._str_memories(memories_all)}\n")
- system_prompt = self._build_system_prompt(memories_all)
- else:
- system_prompt = self._build_system_prompt()
- current_messages = [
- {"role": "system", "content": system_prompt},
- *chat_history.chat_history,
- {"role": "user", "content": query},
- ]
- past_key_values = None
-
- if self.config.enable_activation_memory:
- assert self.config.chat_model.backend == "huggingface", (
- "Activation memory only used for huggingface backend."
- )
- # TODO this only one cubes
- for mem_cube_id, mem_cube in self.mem_cubes.items():
- if mem_cube_id not in user_cube_ids:
- continue
- if mem_cube.act_mem:
- kv_cache = next(iter(mem_cube.act_mem.get_all()), None)
- past_key_values = (
- kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None
- )
- break
- # Generate response
- response = self.chat_llm.generate(current_messages, past_key_values=past_key_values)
- else:
- response = self.chat_llm.generate(current_messages)
- logger.info(f"🤖 [Assistant] {response}\n")
- chat_history.chat_history.append({"role": "user", "content": query})
- chat_history.chat_history.append({"role": "assistant", "content": response})
- self.chat_history_manager[user_id] = chat_history
-
- # submit message to scheduler
- for accessible_mem_cube in accessible_cubes:
- mem_cube_id = accessible_mem_cube.cube_id
- mem_cube = self.mem_cubes[mem_cube_id]
- if self.enable_mem_scheduler and self.mem_scheduler is not None:
- message_item = ScheduleMessageItem(
- user_id=target_user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- label=ANSWER_LABEL,
- content=response,
- timestamp=datetime.now(),
- )
- self.mem_scheduler.submit_messages(messages=[message_item])
- return response
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
new file mode 100644
index 00000000..dd08954a
--- /dev/null
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -0,0 +1,124 @@
+from typing import TYPE_CHECKING
+
+from memos.configs.mem_scheduler import GeneralSchedulerConfig
+from memos.log import get_logger
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_scheduler.general_scheduler import GeneralScheduler
+from memos.mem_scheduler.schemas.general_schemas import (
+ MemCubeID,
+ UserID,
+)
+from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
+
+
+if TYPE_CHECKING:
+ from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
+
+
+logger = get_logger(__name__)
+
+
+class OptimizedScheduler(GeneralScheduler):
+ """Optimized scheduler with improved working memory management"""
+
+ def __init__(self, config: GeneralSchedulerConfig):
+ super().__init__(config)
+
+ def replace_working_memory(
+ self,
+ user_id: UserID | str,
+ mem_cube_id: MemCubeID | str,
+ mem_cube: GeneralMemCube,
+ original_memory: list[TextualMemoryItem],
+ new_memory: list[TextualMemoryItem],
+ ) -> None | list[TextualMemoryItem]:
+ """Replace working memory with new memories after reranking."""
+ text_mem_base = mem_cube.text_mem
+ if isinstance(text_mem_base, TreeTextMemory):
+ text_mem_base: TreeTextMemory = text_mem_base
+
+ # process rerank memories with llm
+ query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
+ # Sync with database to get latest query history
+ query_db_manager.sync_with_orm()
+
+ query_history = query_db_manager.obj.get_queries_with_timesort()
+ memories_with_new_order, rerank_success_flag = (
+ self.retriever.process_and_rerank_memories(
+ queries=query_history,
+ original_memory=original_memory,
+ new_memory=new_memory,
+ top_k=self.top_k,
+ )
+ )
+
+ # Apply combined filtering (unrelated + redundant)
+ logger.info(
+ f"Applying combined unrelated and redundant memory filtering to {len(memories_with_new_order)} memories"
+ )
+ filtered_memories, filtering_success_flag = (
+ self.retriever.filter_unrelated_and_redundant_memories(
+ query_history=query_history,
+ memories=memories_with_new_order,
+ )
+ )
+
+ if filtering_success_flag:
+ logger.info(
+ f"Combined filtering completed successfully. "
+ f"Filtered from {len(memories_with_new_order)} to {len(filtered_memories)} memories"
+ )
+ memories_with_new_order = filtered_memories
+ else:
+ logger.warning(
+ "Combined filtering failed - keeping memories as fallback. "
+ f"Count: {len(memories_with_new_order)}"
+ )
+
+ # Update working memory monitors
+ query_keywords = query_db_manager.obj.get_keywords_collections()
+ logger.info(
+ f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords"
+ )
+ new_working_memory_monitors = self.transform_working_memories_to_monitors(
+ query_keywords=query_keywords,
+ memories=memories_with_new_order,
+ )
+
+ if not rerank_success_flag:
+ for one in new_working_memory_monitors:
+ one.sorting_score = 0
+
+ logger.info(f"update {len(new_working_memory_monitors)} working_memory_monitors")
+ self.monitor.update_working_memory_monitors(
+ new_working_memory_monitors=new_working_memory_monitors,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ )
+
+ # Use the filtered and reranked memories directly
+ text_mem_base.replace_working_memory(memories=memories_with_new_order)
+
+ # Update monitor after replacing working memory
+ mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][
+ mem_cube_id
+ ].obj.get_sorted_mem_monitors(reverse=True)
+ new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors]
+
+ logger.info(
+ f"The working memory has been replaced with {len(memories_with_new_order)} new memories."
+ )
+ self.log_working_memory_replacement(
+ original_memory=original_memory,
+ new_memory=new_working_memories,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ log_func_callback=self._submit_web_logs,
+ )
+ else:
+ logger.error("memory_base is not supported")
+ memories_with_new_order = new_memory
+
+ return memories_with_new_order
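+
+ # Flow note (illustrative): replace_working_memory() reranks the candidate memories
+ # against the persisted query history (query_monitors[user][cube].obj), filters
+ # unrelated/redundant items, refreshes the working-memory monitors, and only then
+ # swaps the TreeTextMemory working set; the final ordered memory list is returned.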
diff --git a/src/memos/mem_scheduler/orm_modules/__init__.py b/src/memos/mem_scheduler/orm_modules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/memos/mem_scheduler/orm_modules/base_model.py b/src/memos/mem_scheduler/orm_modules/base_model.py
new file mode 100644
index 00000000..9d75a12b
--- /dev/null
+++ b/src/memos/mem_scheduler/orm_modules/base_model.py
@@ -0,0 +1,635 @@
+import json
+import os
+import tempfile
+import time
+
+from abc import abstractmethod
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, TypeVar
+
+from sqlalchemy import Boolean, Column, DateTime, String, Text, and_, create_engine
+from sqlalchemy.engine import Engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import Session, sessionmaker
+
+from memos.log import get_logger
+from memos.mem_user.user_manager import UserManager
+
+
+T = TypeVar("T") # The model type (MemoryMonitorManager, QueryMonitorManager, etc.)
+ORM = TypeVar("ORM") # The ORM model type
+
+logger = get_logger(__name__)
+
+Base = declarative_base()
+
+
+class LockableORM(Base):
+ """Abstract base class for lockable ORM models"""
+
+ __abstract__ = True
+
+ # Primary composite key
+ user_id = Column(String(255), primary_key=True)
+ mem_cube_id = Column(String(255), primary_key=True)
+
+ # Serialized data
+ serialized_data = Column(Text, nullable=False)
+
+ lock_acquired = Column(Boolean, default=False)
+ lock_expiry = Column(DateTime, nullable=True)
+
+ # Version control tag (0-255, cycles back to 0)
+ version_control = Column(String(3), default="0")
+
+
+class BaseDBManager(UserManager):
+ """Abstract base class for database managers with proper locking mechanism
+
+ This class provides a foundation for managing database operations with
+ distributed locking capabilities to ensure data consistency across
+ multiple processes or threads.
+ """
+
+ def __init__(
+ self,
+ engine: Engine,
+ user_id: str | None = None,
+ mem_cube_id: str | None = None,
+ lock_timeout: int = 10,
+ ):
+ """Initialize the database manager
+
+ Args:
+ engine: SQLAlchemy engine instance
+ user_id: Unique identifier for the user
+ mem_cube_id: Unique identifier for the memory cube
+ lock_timeout: Timeout in seconds for lock acquisition
+ """
+ # Do not use super init func to avoid UserManager initialization
+ self.engine = engine
+ self.SessionLocal = None
+ self.obj = None
+ self.user_id = user_id
+ self.mem_cube_id = mem_cube_id
+ self.lock_timeout = lock_timeout
+ self.last_version_control = None # Track the last version control tag
+
+ self.init_manager(
+ engine=self.engine,
+ user_id=self.user_id,
+ mem_cube_id=self.mem_cube_id,
+ )
+
+ @property
+ @abstractmethod
+ def orm_class(self) -> type[LockableORM]:
+ """Return the ORM model class for this manager
+
+ Returns:
+ The SQLAlchemy ORM model class
+ """
+ raise NotImplementedError()
+
+ @property
+ @abstractmethod
+ def obj_class(self) -> Any:
+ """Return the business object class for this manager
+
+ Returns:
+ The business logic object class
+ """
+ raise NotImplementedError()
+
+ def init_manager(self, engine: Engine, user_id: str, mem_cube_id: str):
+ """Initialize the database manager with engine and identifiers
+
+ Args:
+ engine: SQLAlchemy engine instance
+ user_id: User identifier
+ mem_cube_id: Memory cube identifier
+
+ Raises:
+ RuntimeError: If database initialization fails
+ """
+ try:
+ self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+ logger.info(f"{self.orm_class} initialized with engine {engine}")
+ logger.info(f"Set user_id to {user_id}; mem_cube_id to {mem_cube_id}")
+
+ # Create tables if they don't exist
+ self._create_table_with_error_handling(engine)
+ logger.debug(f"Successfully created/verified table for {self.orm_class.__tablename__}")
+
+ except Exception as e:
+ error_msg = f"Failed to initialize database manager for {self.orm_class.__name__}: {e}"
+ logger.error(error_msg, exc_info=True)
+ raise RuntimeError(error_msg) from e
+
+ def _create_table_with_error_handling(self, engine: Engine):
+ """Create table with proper error handling for common database conflicts
+
+ Args:
+ engine: SQLAlchemy engine instance
+
+ Raises:
+ RuntimeError: If table creation fails after handling known issues
+ """
+ try:
+ self.orm_class.__table__.create(bind=engine, checkfirst=True)
+ except Exception as e:
+ error_str = str(e).lower()
+
+ # Handle common SQLite index already exists error
+ if "index" in error_str and "already exists" in error_str:
+ logger.warning(f"Index already exists for {self.orm_class.__tablename__}: {e}")
+ # Try to create just the table without indexes
+ try:
+ # Create a temporary table definition without indexes
+ table_without_indexes = self.orm_class.__table__.copy()
+ table_without_indexes._indexes.clear() # Remove all indexes
+ table_without_indexes.create(bind=engine, checkfirst=True)
+ logger.info(
+ f"Created table {self.orm_class.__tablename__} without problematic indexes"
+ )
+ except Exception as table_error:
+ logger.error(f"Failed to create table even without indexes: {table_error}")
+ raise
+ else:
+ # Re-raise other types of errors
+ raise
+
+ def _get_session(self) -> Session:
+ """Get a database session"""
+ return self.SessionLocal()
+
+ def _serialize(self, obj: T) -> str:
+ """Serialize the object to JSON"""
+ if hasattr(obj, "to_json"):
+ return obj.to_json()
+ return json.dumps(obj)
+
+ def _deserialize(self, data: str, model_class: type[T]) -> T:
+ """Deserialize JSON to object"""
+ if hasattr(model_class, "from_json"):
+ return model_class.from_json(data)
+ return json.loads(data)
+
+ def acquire_lock(self, block: bool = True, **kwargs) -> bool:
+ """Acquire a distributed lock for the current user and memory cube
+
+ Args:
+ block: Whether to block until lock is acquired
+ **kwargs: Additional filter criteria
+
+ Returns:
+ True if lock was acquired, False otherwise
+ """
+ session = self._get_session()
+
+ try:
+ now = datetime.now()
+ expiry = now + timedelta(seconds=self.lock_timeout)
+
+ # Query for existing record with lock information
+ query = (
+ session.query(self.orm_class)
+ .filter_by(**kwargs)
+ .filter(
+ and_(
+ self.orm_class.user_id == self.user_id,
+ self.orm_class.mem_cube_id == self.mem_cube_id,
+ )
+ )
+ )
+
+ record = query.first()
+
+ # If no record exists, lock can be acquired immediately
+ if record is None:
+ logger.info(
+ f"No existing record found for {self.user_id}/{self.mem_cube_id}, lock can be acquired"
+ )
+ return True
+
+ # Check if lock is currently held and not expired
+ if record.lock_acquired and record.lock_expiry and now < record.lock_expiry:
+ if block:
+ # Wait for lock to be released or expire
+ logger.info(
+ f"Waiting for lock to be released for {self.user_id}/{self.mem_cube_id}"
+ )
+ while record.lock_acquired and record.lock_expiry and now < record.lock_expiry:
+ time.sleep(0.1) # Small delay before retry
+ session.refresh(record) # Refresh record state
+ now = datetime.now()
+ else:
+ logger.warning(
+ f"Lock is held for {self.user_id}/{self.mem_cube_id}, cannot acquire"
+ )
+ return False
+
+ # Acquire the lock by updating the record
+ query.update(
+ {
+ "lock_acquired": True,
+ "lock_expiry": expiry,
+ },
+ synchronize_session=False,
+ )
+
+ session.commit()
+ logger.info(f"Lock acquired for {self.user_id}/{self.mem_cube_id}")
+ return True
+
+ except Exception as e:
+ session.rollback()
+ logger.error(f"Failed to acquire lock for {self.user_id}/{self.mem_cube_id}: {e}")
+ return False
+ finally:
+ session.close()
+
+ def release_locks(self, user_id: str, mem_cube_id: str, **kwargs):
+ """Release locks for the specified user and memory cube
+
+ Args:
+ user_id: User identifier
+ mem_cube_id: Memory cube identifier
+ **kwargs: Additional filter criteria
+ """
+ session = self._get_session()
+
+ try:
+ # Update all matching records to release locks
+ result = (
+ session.query(self.orm_class)
+ .filter_by(**kwargs)
+ .filter(
+ and_(
+ self.orm_class.user_id == user_id, self.orm_class.mem_cube_id == mem_cube_id
+ )
+ )
+ .update(
+ {
+ "lock_acquired": False,
+ "lock_expiry": None, # Clear expiry time as well
+ },
+ synchronize_session=False,
+ )
+ )
+ session.commit()
+ logger.info(f"Lock released for {user_id}/{mem_cube_id} (affected {result} records)")
+
+ except Exception as e:
+ session.rollback()
+ logger.error(f"Failed to release lock for {user_id}/{mem_cube_id}: {e}")
+ finally:
+ session.close()
+
+ def _get_primary_key(self) -> dict[str, Any]:
+ """Get the primary key dictionary for the current instance
+
+ Returns:
+ Dictionary containing user_id and mem_cube_id
+ """
+ return {"user_id": self.user_id, "mem_cube_id": self.mem_cube_id}
+
+ def _increment_version_control(self, current_tag: str) -> str:
+ """Increment the version control tag, cycling from 255 back to 0
+
+ Args:
+ current_tag: Current version control tag as string
+
+ Returns:
+ Next version control tag as string
+ """
+ try:
+ current_value = int(current_tag)
+ next_value = (current_value + 1) % 256 # Cycle from 255 back to 0
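+ # e.g. "0" -> "1", "41" -> "42", "255" -> "0"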
+ return str(next_value)
+ except (ValueError, TypeError):
+ # If current_tag is invalid, start from 0
+ logger.warning(f"Invalid version_control '{current_tag}', resetting to '0'")
+ return "0"
+
+ @abstractmethod
+ def merge_items(self, orm_instance, obj_instance, size_limit):
+ """Merge items from database with current object instance
+
+ Args:
+ orm_instance: ORM instance from database
+ obj_instance: Current business object instance
+ size_limit: Maximum number of items to keep after merge
+ """
+
+ def sync_with_orm(self, size_limit: int | None = None) -> None:
+ """
+ Synchronize data between the database and the business object.
+
+ This method performs a three-step synchronization process:
+ 1. Acquire lock and get existing data from database
+ 2. Merge database items with current object items
+ 3. Write merged data back to database and release lock
+
+ Args:
+ size_limit: Optional maximum number of items to keep after synchronization.
+ If specified, only the most recent items will be retained.
+ """
+ logger.info(
+ f"Starting sync_with_orm for {self.user_id}/{self.mem_cube_id} with size_limit={size_limit}"
+ )
+ user_id = self.user_id
+ mem_cube_id = self.mem_cube_id
+
+ session = self._get_session()
+
+ try:
+ # Acquire lock before any database operations
+ lock_status = self.acquire_lock(block=True)
+ if not lock_status:
+ logger.error("Failed to acquire lock for synchronization")
+ return
+
+ # 1. Get existing data from database
+ orm_instance = (
+ session.query(self.orm_class)
+ .filter_by(user_id=user_id, mem_cube_id=mem_cube_id)
+ .first()
+ )
+
+ # If no existing record, create a new one
+ if orm_instance is None:
+ if self.obj is None:
+ logger.warning("No object to synchronize and no existing database record")
+ return
+
+ orm_instance = self.orm_class(
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ serialized_data=self.obj.to_json(),
+ version_control="0", # Start with tag 0 for new records
+ )
+ logger.info(
+ "No existing ORM instance found. Created a new one. "
+ "Note: size_limit was not applied because there is no existing data to merge."
+ )
+ session.add(orm_instance)
+ session.commit()
+ # Update last_version_control for new record
+ self.last_version_control = "0"
+ return
+
+ # 2. Check version control and merge data from database with current object
+ if self.obj is not None:
+ current_db_tag = orm_instance.version_control
+ new_tag = self._increment_version_control(current_db_tag)
+ # Check if this is the first sync (last_version_control is None)
+ if self.last_version_control is None:
+ # First sync, increment version and perform merge
+ logger.info(
+ f"First sync, incrementing version from {current_db_tag} to {new_tag} for {self.user_id}/{self.mem_cube_id}"
+ )
+ elif current_db_tag == self.last_version_control:
+ logger.info(
+ f"Version control unchanged ({current_db_tag}), directly update {self.user_id}/{self.mem_cube_id}"
+ )
+ else:
+ # Version control has changed, increment it and perform merge
+ logger.info(
+ f"Version control changed from {self.last_version_control} to {current_db_tag}, incrementing to {new_tag} for {self.user_id}/{self.mem_cube_id}"
+ )
+ try:
+ self.merge_items(
+ orm_instance=orm_instance, obj_instance=self.obj, size_limit=size_limit
+ )
+ except Exception as merge_error:
+ logger.error(f"Error during merge_items: {merge_error}", exc_info=True)
+ logger.warning("Continuing with current object data without merge")
+
+ # 3. Write merged data back to database
+ orm_instance.serialized_data = self.obj.to_json()
+ orm_instance.version_control = new_tag
+ logger.info(f"Updated serialized_data for {self.user_id}/{self.mem_cube_id}")
+
+ # Update last_version_control to current value
+ self.last_version_control = orm_instance.version_control
+ else:
+ logger.warning("No current object to merge with database data")
+
+ session.commit()
+ logger.info(f"Synchronization completed for {self.user_id}/{self.mem_cube_id}")
+
+ except Exception as e:
+ session.rollback()
+ logger.error(
+ f"Error during synchronization for {user_id}/{mem_cube_id}: {e}", exc_info=True
+ )
+ finally:
+ # Always release locks and close session
+ self.release_locks(user_id=user_id, mem_cube_id=mem_cube_id)
+ session.close()
+
+ def save_to_db(self, obj_instance) -> None:
+ """Save the current state of the business object to the database
+
+ Args:
+ obj_instance: The business object instance to save
+ """
+ user_id = self.user_id
+ mem_cube_id = self.mem_cube_id
+
+ session = self._get_session()
+
+ try:
+ # Acquire lock before database operations
+ lock_status = self.acquire_lock(block=True)
+ if not lock_status:
+ logger.error("Failed to acquire lock for saving to database")
+ return
+
+ # Check if record already exists
+ orm_instance = (
+ session.query(self.orm_class)
+ .filter_by(user_id=user_id, mem_cube_id=mem_cube_id)
+ .first()
+ )
+
+ if orm_instance is None:
+ # Create new record
+ orm_instance = self.orm_class(
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ serialized_data=obj_instance.to_json(),
+ version_control="0", # Start with version 0 for new records
+ )
+ session.add(orm_instance)
+ logger.info(f"Created new database record for {user_id}/{mem_cube_id}")
+ # Update last_version_control for new record
+ self.last_version_control = "0"
+ else:
+ # Update existing record with version control
+ current_version = orm_instance.version_control
+ new_version = self._increment_version_control(current_version)
+ orm_instance.serialized_data = obj_instance.to_json()
+ orm_instance.version_control = new_version
+ logger.info(
+ f"Updated existing database record for {user_id}/{mem_cube_id} with version {new_version}"
+ )
+ # Update last_version_control
+ self.last_version_control = new_version
+
+ session.commit()
+
+ except Exception as e:
+ session.rollback()
+ logger.error(f"Error saving to database for {user_id}/{mem_cube_id}: {e}")
+ finally:
+ # Always release locks and close session
+ self.release_locks(user_id=user_id, mem_cube_id=mem_cube_id)
+ session.close()
+
+ def load_from_db(self, acquire_lock: bool = False):
+ """Load the business object from the database
+
+ Args:
+ acquire_lock: Whether to acquire a lock during the load operation
+
+ Returns:
+ The deserialized business object instance, or None if not found
+ """
+ user_id = self.user_id
+ mem_cube_id = self.mem_cube_id
+
+ session = self._get_session()
+
+ try:
+ if acquire_lock:
+ lock_status = self.acquire_lock(block=True)
+ if not lock_status:
+ logger.error("Failed to acquire lock for loading from database")
+ return None
+
+ # Query for the database record
+ orm_instance = (
+ session.query(self.orm_class)
+ .filter_by(user_id=user_id, mem_cube_id=mem_cube_id)
+ .first()
+ )
+
+ if orm_instance is None:
+ logger.info(f"No database record found for {user_id}/{mem_cube_id}")
+ return None
+
+ # Deserialize the business object from JSON
+ db_instance = self.obj_class.from_json(orm_instance.serialized_data)
+ # Update last_version_control to track the loaded version
+ self.last_version_control = orm_instance.version_control
+ logger.info(
+ f"Successfully loaded object from database for {user_id}/{mem_cube_id} with version {orm_instance.version_control}"
+ )
+
+ return db_instance
+
+ except Exception as e:
+ logger.error(f"Error loading from database for {user_id}/{mem_cube_id}: {e}")
+ return None
+ finally:
+ if acquire_lock:
+ self.release_locks(user_id=user_id, mem_cube_id=mem_cube_id)
+ session.close()
+
+ def close(self):
+ """Close the database manager and clean up resources
+
+ This method releases any held locks and disposes of the database engine.
+ Should be called when the manager is no longer needed.
+ """
+ try:
+ # Release any locks held by this manager instance
+ if self.user_id and self.mem_cube_id:
+ self.release_locks(user_id=self.user_id, mem_cube_id=self.mem_cube_id)
+ logger.info(f"Released locks for {self.user_id}/{self.mem_cube_id}")
+
+ # Dispose of the engine to close all connections
+ if self.engine:
+ self.engine.dispose()
+ logger.info("Database engine disposed")
+
+ except Exception as e:
+ logger.error(f"Error during close operation: {e}")
+
+ @staticmethod
+ def create_default_engine() -> Engine:
+ """Create SQLAlchemy engine with default database path
+
+ Returns:
+ SQLAlchemy Engine instance using default scheduler_orm.db
+ """
+ temp_dir = tempfile.mkdtemp()
+ db_path = os.path.join(temp_dir, "test_scheduler_orm.db")
+
+ # Clean up any existing file (though unlikely)
+ if os.path.exists(db_path):
+ os.remove(db_path)
+ # Remove the temp directory if still exists (should be empty)
+ if os.path.exists(temp_dir) and not os.listdir(temp_dir):
+ os.rmdir(temp_dir)
+
+ # Ensure parent directory exists (re-create in case rmdir removed it)
+ parent_dir = Path(db_path).parent
+ parent_dir.mkdir(parents=True, exist_ok=True)
+
+ # Log the creation of the default engine with database path
+ logger.info(
+ "Creating default SQLAlchemy engine with temporary SQLite database at: %s", db_path
+ )
+
+ return create_engine(f"sqlite:///{db_path}", echo=False)
+
+ @staticmethod
+ def create_engine_from_db_path(db_path: str) -> Engine:
+ """Create SQLAlchemy engine from database path
+
+ Args:
+ db_path: Path to database file
+
+ Returns:
+ SQLAlchemy Engine instance
+ """
+ # Ensure the directory exists
+ Path(db_path).parent.mkdir(parents=True, exist_ok=True)
+
+ return create_engine(f"sqlite:///{db_path}", echo=False)
+
+ @staticmethod
+ def create_mysql_db_path(
+ host: str = "localhost",
+ port: int = 3306,
+ username: str = "root",
+ password: str = "",
+ database: str = "scheduler_orm",
+ charset: str = "utf8mb4",
+ ) -> str:
+ """Create MySQL database connection URL
+
+ Args:
+ host: MySQL server hostname
+ port: MySQL server port
+ username: Database username
+ password: Database password (optional)
+ database: Database name
+ charset: Character set encoding
+
+ Returns:
+ MySQL connection URL string
+ """
+ # Build MySQL connection URL with proper formatting
+ if password:
+ db_path = (
+ f"mysql+pymysql://{username}:{password}@{host}:{port}/{database}?charset={charset}"
+ )
+ else:
+ db_path = f"mysql+pymysql://{username}@{host}:{port}/{database}?charset={charset}"
+ return db_path
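+
+
+ # --- Hypothetical usage sketch (illustrative, not part of this change) ---
+ # Only the engine helpers above are exercised; BaseDBManager itself is abstract
+ # (orm_class/obj_class/merge_items come from subclasses), so it is not instantiated.
+ if __name__ == "__main__":
+     engine = BaseDBManager.create_default_engine()  # throwaway SQLite DB in a temp dir
+     print(engine.url)
+
+     # Engine bound to an explicit SQLite file; the path is a placeholder
+     file_engine = BaseDBManager.create_engine_from_db_path("/tmp/scheduler_orm.db")
+     print(file_engine.url)
+
+     # Builds a MySQL connection URL only; no connection is attempted here
+     print(BaseDBManager.create_mysql_db_path(host="127.0.0.1", password="secret"))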
diff --git a/src/memos/mem_scheduler/orm_modules/monitor_models.py b/src/memos/mem_scheduler/orm_modules/monitor_models.py
new file mode 100644
index 00000000..a5a04eb4
--- /dev/null
+++ b/src/memos/mem_scheduler/orm_modules/monitor_models.py
@@ -0,0 +1,261 @@
+from typing import TypeVar
+
+from sqlalchemy import Index
+from sqlalchemy.engine import Engine
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.monitor_schemas import (
+ MemoryMonitorItem,
+ MemoryMonitorManager,
+ QueryMonitorItem,
+ QueryMonitorQueue,
+)
+
+from .base_model import BaseDBManager, LockableORM
+
+
+logger = get_logger(__name__)
+
+# Type variables for generic type hints
+T = TypeVar("T") # The model type (MemoryMonitorManager, QueryMonitorManager, etc.)
+ORM = TypeVar("ORM") # The ORM model type
+
+
+class MemoryMonitorManagerORM(LockableORM):
+ """ORM model for MemoryMonitorManager persistence
+
+ This table stores serialized MemoryMonitorManager instances with
+ proper indexing for efficient user and memory cube lookups.
+ """
+
+ __tablename__ = "memory_monitor_manager"
+
+ # Database indexes for performance optimization
+ __table_args__ = (Index("idx_memory_monitor_user_memcube", "user_id", "mem_cube_id"),)
+
+
+class QueryMonitorQueueORM(LockableORM):
+ """ORM model for QueryMonitorQueue persistence
+
+ This table stores serialized QueryMonitorQueue instances with
+ proper indexing for efficient user and memory cube lookups.
+ """
+
+ __tablename__ = "query_monitor_queue"
+
+ # Database indexes for performance optimization
+ __table_args__ = (Index("idx_query_monitor_user_memcube", "user_id", "mem_cube_id"),)
+
+
+class DBManagerForMemoryMonitorManager(BaseDBManager):
+ """Database manager for MemoryMonitorManager objects
+
+ This class handles persistence, synchronization, and locking
+ for MemoryMonitorManager instances in the database.
+ """
+
+ def __init__(
+ self,
+ engine: Engine,
+ user_id: str | None = None,
+ mem_cube_id: str | None = None,
+ obj: MemoryMonitorManager | None = None,
+ lock_timeout: int = 10,
+ ):
+ """
+ Initialize the MemoryMonitorManager database manager.
+
+ Args:
+ engine: SQLAlchemy engine instance
+ user_id: Unique identifier for the user
+ mem_cube_id: Unique identifier for the memory cube
+ obj: Optional MemoryMonitorManager instance to manage
+ lock_timeout: Timeout in seconds for lock acquisition
+ """
+ super().__init__(
+ engine=engine, user_id=user_id, mem_cube_id=mem_cube_id, lock_timeout=lock_timeout
+ )
+ self.obj: MemoryMonitorManager | None = obj
+
+ @property
+ def orm_class(self) -> type[MemoryMonitorManagerORM]:
+ return MemoryMonitorManagerORM
+
+ @property
+ def obj_class(self) -> type[MemoryMonitorManager]:
+ return MemoryMonitorManager
+
+ def merge_items(
+ self,
+ orm_instance: MemoryMonitorManagerORM,
+ obj_instance: MemoryMonitorManager,
+ size_limit: int,
+ ):
+ """Merge memory monitor items from database with current object
+
+ This method combines items from the database with items in the current
+ object, prioritizing current object items and applying size limits.
+
+ Args:
+ orm_instance: ORM instance containing serialized database data
+ obj_instance: Current MemoryMonitorManager instance
+ size_limit: Maximum number of items to keep after merge
+
+ Returns:
+ Updated obj_instance with merged items
+ """
+ logger.debug(f"Starting merge_items for MemoryMonitorManager with size_limit={size_limit}")
+
+ try:
+ # Deserialize the database instance
+ db_instance: MemoryMonitorManager = MemoryMonitorManager.from_json(
+ orm_instance.serialized_data
+ )
+ except Exception as e:
+ logger.error(f"Failed to deserialize database instance: {e}", exc_info=True)
+ logger.warning("Skipping merge due to deserialization error, using current object only")
+ return obj_instance
+
+ # Merge items - prioritize existing ones in current object
+ merged_items: list[MemoryMonitorItem] = []
+ seen_ids = set()
+
+ # First, add all items from current object (higher priority)
+ for item in obj_instance.memories:
+ if item.item_id not in seen_ids:
+ merged_items.append(item)
+ seen_ids.add(item.item_id)
+
+ # Then, add items from database that aren't in current object
+ for item in db_instance.memories:
+ if item.item_id not in seen_ids:
+ merged_items.append(item)
+ seen_ids.add(item.item_id)
+
+ # Apply size limit if specified (keep highest-scoring items)
+ if size_limit is not None and size_limit > 0:
+ try:
+ # Sort by sorting_score descending (highest priority first) and take top N
+ # Note: MemoryMonitorItem doesn't have timestamp, so we use sorting_score instead
+ merged_items = sorted(merged_items, key=lambda x: x.sorting_score, reverse=True)[
+ :size_limit
+ ]
+ logger.debug(f"Applied size limit of {size_limit}, kept {len(merged_items)} items")
+ except AttributeError as e:
+ logger.error(f"Error sorting MemoryMonitorItem objects: {e}")
+ logger.error(
+ "Available attributes: "
+ + ", ".join(dir(merged_items[0]) if merged_items else [])
+ )
+ raise
+ except Exception as e:
+ logger.error(f"Unexpected error during sorting: {e}")
+ raise
+
+ # Update the object with merged items
+ obj_instance.memories = merged_items
+
+ logger.info(
+ f"Merged {len(merged_items)} memory items for {obj_instance} (size_limit: {size_limit})"
+ )
+
+ return obj_instance
+
+
+class DBManagerForQueryMonitorQueue(BaseDBManager):
+ """Database manager for QueryMonitorQueue objects
+
+ This class handles persistence, synchronization, and locking
+ for QueryMonitorQueue instances in the database.
+ """
+
+ def __init__(
+ self,
+ engine: Engine,
+ user_id: str | None = None,
+ mem_cube_id: str | None = None,
+ obj: QueryMonitorQueue | None = None,
+ lock_timeout: int = 10,
+ ):
+ """
+ Initialize the QueryMonitorQueue database manager.
+
+ Args:
+ engine: SQLAlchemy engine instance
+ user_id: Unique identifier for the user
+ mem_cube_id: Unique identifier for the memory cube
+ obj: Optional QueryMonitorQueue instance to manage
+ lock_timeout: Timeout in seconds for lock acquisition
+ """
+ super().__init__(
+ engine=engine, user_id=user_id, mem_cube_id=mem_cube_id, lock_timeout=lock_timeout
+ )
+ self.obj: QueryMonitorQueue | None = obj
+
+ @property
+ def orm_class(self) -> type[QueryMonitorQueueORM]:
+ return QueryMonitorQueueORM
+
+ @property
+ def obj_class(self) -> type[QueryMonitorQueue]:
+ return QueryMonitorQueue
+
+ def merge_items(
+ self, orm_instance: QueryMonitorQueueORM, obj_instance: QueryMonitorQueue, size_limit: int
+ ):
+ """Merge query monitor items from database with current queue
+
+ This method combines items from the database with items in the current
+ queue, prioritizing current queue items and applying size limits.
+
+ Args:
+ orm_instance: ORM instance containing serialized database data
+ obj_instance: Current QueryMonitorQueue instance
+ size_limit: Maximum number of items to keep after merge
+
+ Returns:
+ Updated obj_instance with merged items
+ """
+ try:
+ # Deserialize the database instance
+ db_instance: QueryMonitorQueue = QueryMonitorQueue.from_json(
+ orm_instance.serialized_data
+ )
+ except Exception as e:
+ logger.error(f"Failed to deserialize database instance: {e}")
+ logger.warning("Skipping merge due to deserialization error, using current object only")
+ return obj_instance
+
+ # Merge items - prioritize existing ones in current object
+ merged_items: list[QueryMonitorItem] = []
+ seen_ids = set()
+
+ # First, add all items from current queue (higher priority)
+ for item in obj_instance.get_queue_content_without_pop():
+ if item.item_id not in seen_ids:
+ merged_items.append(item)
+ seen_ids.add(item.item_id)
+
+ # Then, add items from database queue that aren't in current queue
+ for item in db_instance.get_queue_content_without_pop():
+ if item.item_id not in seen_ids:
+ merged_items.append(item)
+ seen_ids.add(item.item_id)
+
+ # Apply size limit if specified (keep most recent items)
+ if size_limit is not None and size_limit > 0:
+ # Sort by timestamp descending (newest first) and take top N
+ merged_items = sorted(merged_items, key=lambda x: x.timestamp, reverse=True)[
+ :size_limit
+ ]
+
+ # Update the queue with merged items
+ obj_instance.clear() # Clear existing items
+ for item in merged_items:
+ obj_instance.put(item) # Add merged items back
+
+ logger.info(
+ f"Merged {len(merged_items)} query items for {obj_instance} (size_limit: {size_limit})"
+ )
+
+ return obj_instance
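+
+
+ # --- Hypothetical usage sketch (illustrative, not part of this change) ---
+ # Wires an empty QueryMonitorQueue to the ORM layer using a throwaway SQLite engine;
+ # the user/cube identifiers and maxsize are arbitrary placeholders.
+ if __name__ == "__main__":
+     engine = BaseDBManager.create_default_engine()
+     queue = QueryMonitorQueue(maxsize=10)
+     manager = DBManagerForQueryMonitorQueue(
+         engine=engine, user_id="demo_user", mem_cube_id="demo_cube", obj=queue
+     )
+     manager.sync_with_orm(size_limit=10)  # persists the (empty) queue
+     restored = manager.load_from_db()     # round-trips the serialized queue
+     print(type(restored))
+     manager.close()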
diff --git a/src/memos/mem_scheduler/scheduler_factory.py b/src/memos/mem_scheduler/scheduler_factory.py
index 5bcd0e2b..3cd406f3 100644
--- a/src/memos/mem_scheduler/scheduler_factory.py
+++ b/src/memos/mem_scheduler/scheduler_factory.py
@@ -3,6 +3,7 @@
from memos.configs.mem_scheduler import SchedulerConfigFactory
from memos.mem_scheduler.base_scheduler import BaseScheduler
from memos.mem_scheduler.general_scheduler import GeneralScheduler
+from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler
class SchedulerFactory(BaseScheduler):
@@ -10,6 +11,7 @@ class SchedulerFactory(BaseScheduler):
backend_to_class: ClassVar[dict[str, Any]] = {
"general_scheduler": GeneralScheduler,
+ "optimized_scheduler": OptimizedScheduler,
}
@classmethod
diff --git a/src/memos/mem_scheduler/schemas/monitor_schemas.py b/src/memos/mem_scheduler/schemas/monitor_schemas.py
index 65238d72..f148f30d 100644
--- a/src/memos/mem_scheduler/schemas/monitor_schemas.py
+++ b/src/memos/mem_scheduler/schemas/monitor_schemas.py
@@ -1,3 +1,4 @@
+import json
import threading
from collections import Counter
@@ -30,6 +31,8 @@ class QueryMonitorItem(BaseModel, DictConversionMixin):
item_id: str = Field(
description="Unique identifier for the query item", default_factory=lambda: str(uuid4())
)
+ user_id: str = Field(..., description="Required user identifier", min_length=1)
+ mem_cube_id: str = Field(..., description="Required memory cube identifier", min_length=1)
query_text: str = Field(
...,
description="The actual user query text content",
@@ -111,7 +114,8 @@ def get_keywords_collections(self) -> Counter:
"""
with self.mutex:
logger.debug(f"Thread {threading.get_ident()} acquired mutex.")
- all_keywords = [kw for item in self.queue for kw in item.keywords]
+ # Fix: Handle None keywords safely
+ all_keywords = [kw for item in self.queue if item.keywords for kw in item.keywords]
return Counter(all_keywords)
def get_queries_with_timesort(self, reverse: bool = True) -> list[str]:
@@ -132,9 +136,62 @@ def get_queries_with_timesort(self, reverse: bool = True) -> list[str]:
for monitor in sorted(self.queue, key=lambda x: x.timestamp, reverse=reverse)
]
+ def to_json(self) -> str:
+ """Serialize the queue to a JSON string.
+
+ Args:
+ item_serializer: Optional function to serialize individual items.
+ If not provided, items must be JSON-serializable.
+
+ Returns:
+ A JSON string representing the queue's content and maxsize.
+ """
+ with self.mutex:
+ serialized_items = [item.to_json() for item in self.queue]
+
+ data = {"maxsize": self.maxsize, "items": serialized_items}
+ return json.dumps(data, ensure_ascii=False, indent=2)
+
+ @classmethod
+ def from_json(cls, json_str: str) -> "QueryMonitorQueue":
+ """Create a new AutoDroppingQueue from a JSON string.
+
+ Args:
+ json_str: JSON string created by to_json()
+ item_deserializer: Optional function to reconstruct items from dicts.
+ If not provided, items are used as-is.
+
+ Returns:
+ A new AutoDroppingQueue instance with deserialized data.
+ """
+ data = json.loads(json_str)
+ maxsize = data.get("maxsize", 0)
+ item_strs = data.get("items", [])
+
+ queue = cls(maxsize=maxsize)
+
+ items = [QueryMonitorItem.from_json(json_str=item_str) for item_str in item_strs]
+
+ # Fix: Add error handling for put operations
+ for item in items:
+ try:
+ queue.put(item) # Use put() to respect maxsize and auto-drop behavior
+ except Exception as e:
+ logger.error(f"Failed to add item to queue: {e}")
+ # Continue with other items instead of failing completely
+
+ return queue
+
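+ # Illustrative round trip (assumes the QueryMonitorItem required fields are supplied
+ # when items are enqueued):
+ #   restored = QueryMonitorQueue.from_json(queue.to_json())
+ # maxsize is restored and items are re-added via put(), so auto-drop semantics apply.
+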
# ============== Memories ==============
class MemoryMonitorItem(BaseModel, DictConversionMixin):
+ """
+ Represents a memory item in the monitoring system.
+
+ Note: This class does NOT have a timestamp field, unlike QueryMonitorItem.
+ For sorting by recency, use sorting_score or importance_score instead.
+ """
+
item_id: str = Field(
description="Unique identifier for the memory item", default_factory=lambda: str(uuid4())
)
@@ -167,7 +224,7 @@ class MemoryMonitorItem(BaseModel, DictConversionMixin):
recording_count: int = Field(
default=1,
description="How many times this memory has been recorded",
- ge=1, # Greater than or equal to 1
+ ge=1,
)
@field_validator("tree_memory_item_mapping_key", mode="before")
@@ -177,27 +234,28 @@ def generate_mapping_key(cls, v, values): # noqa: N805
return v
def get_importance_score(self, weight_vector: list[float] | None = None) -> float:
- """
- Calculate the effective score for the memory item.
+ return self._get_complex_importance_score(weight_vector=weight_vector)
- Returns:
- float: The importance_score if it has been initialized (>=0),
- otherwise the recording_count converted to float.
-
- Note:
- This method provides a unified way to retrieve a comparable score
- for memory items, regardless of whether their importance has been explicitly set.
- """
+ def _get_complex_importance_score(self, weight_vector: list[float] | None = None) -> float:
+ """Calculate traditional importance score using existing logic"""
if weight_vector is None:
- logger.warning("weight_vector of get_importance_score is None.")
+ logger.warning("weight_vector of get_complex_score is None.")
weight_vector = DEFAULT_WEIGHT_VECTOR_FOR_RANKING
- assert sum(weight_vector) == 1
- normalized_keywords_score = min(self.keywords_score * weight_vector[1], 5)
+
+ # Fix: Add proper validation for weight_vector
+ if not weight_vector or len(weight_vector) != 3 or abs(sum(weight_vector) - 1.0) > 1e-6:
+ raise ValueError("weight_vector must be provided, have length 3, and sum to 1.0")
+
+ # Fix: Handle uninitialized scores safely
+ sorting_score = self.sorting_score if self.sorting_score != NOT_INITIALIZED else 0.0
+ keywords_score = self.keywords_score if self.keywords_score != NOT_INITIALIZED else 0.0
+
+ normalized_keywords_score = min(keywords_score * weight_vector[1], 5)
normalized_recording_count_score = min(self.recording_count * weight_vector[2], 2)
self.importance_score = (
- self.sorting_score * weight_vector[0]
- + normalized_keywords_score
- + normalized_recording_count_score
+ sorting_score * weight_vector[0]
+ + normalized_keywords_score * weight_vector[1]
+ + normalized_recording_count_score * weight_vector[2]
)
return self.importance_score
@@ -258,7 +316,7 @@ def get_sorted_mem_monitors(self, reverse=True) -> list[MemoryMonitorItem]:
def update_memories(
self, new_memory_monitors: list[MemoryMonitorItem], partial_retention_number: int
- ) -> MemoryMonitorItem:
+ ) -> list[MemoryMonitorItem]: # Fix: Correct return type
"""
Update memories based on monitor_working_memories.
"""
@@ -302,6 +360,13 @@ def update_memories(
reverse=True,
)
+ # Fix: Add bounds checking to prevent IndexError
+ if partial_retention_number > len(sorted_old_mem_monitors):
+ partial_retention_number = len(sorted_old_mem_monitors)
+ logger.info(
+ f"partial_retention_number adjusted to {partial_retention_number} to match available old memories"
+ )
+
# Keep the top N old memories
memories_to_remove = sorted_old_mem_monitors[partial_retention_number:]
memories_to_change_score = sorted_old_mem_monitors[:partial_retention_number]
@@ -312,19 +377,21 @@ def update_memories(
for memory in memories_to_change_score:
memory.sorting_score = 0
- memory.recording_count = 0
+ memory.recording_count = 1
memory.keywords_score = 0
# Step 4: Enforce max_capacity if set
- sorted_memories = sorted(
- self.memories,
- key=lambda item: item.get_importance_score(
- weight_vector=DEFAULT_WEIGHT_VECTOR_FOR_RANKING
- ),
- reverse=True,
- )
- # Keep only the top max_capacity memories
- self.memories = sorted_memories[: self.max_capacity]
+ # Fix: Handle max_capacity safely
+ if self.max_capacity is not None:
+ sorted_memories = sorted(
+ self.memories,
+ key=lambda item: item.get_importance_score(
+ weight_vector=DEFAULT_WEIGHT_VECTOR_FOR_RANKING
+ ),
+ reverse=True,
+ )
+ # Keep only the top max_capacity memories
+ self.memories = sorted_memories[: self.max_capacity]
# Log the update result
logger.info(
diff --git a/src/memos/mem_scheduler/utils/config_utils.py b/src/memos/mem_scheduler/utils/config_utils.py
new file mode 100644
index 00000000..8bb1050e
--- /dev/null
+++ b/src/memos/mem_scheduler/utils/config_utils.py
@@ -0,0 +1,100 @@
+import json
+import os
+
+from typing import Any
+
+import yaml
+
+
+def flatten_dict(
+ data: dict[str, Any], parent_keys: list[str] | None = None, prefix: str = ""
+) -> dict[str, str]:
+ """
+ Recursively flattens a nested dictionary to generate environment variable keys following the specified format.
+ Combines nested keys with underscores, converts to uppercase, and prepends a custom prefix if provided.
+
+ Args:
+ data: Nested dictionary to be flattened (parsed from JSON/YAML)
+ parent_keys: List to track nested keys during recursion
+ prefix: Custom prefix to be added to all generated keys
+
+ Returns:
+ Flattened dictionary with keys in PREFIX_KEY1_KEY2... format and string values
+ """
+ parent_keys = parent_keys or []
+ flat_data = {}
+
+ for key, value in data.items():
+ # Clean and standardize key: convert to uppercase, replace spaces/hyphens with underscores
+ clean_key = key.upper().replace(" ", "_").replace("-", "_")
+ current_keys = [*parent_keys, clean_key]
+
+ if isinstance(value, dict):
+ # Recursively process nested dictionaries
+ nested_flat = flatten_dict(value, current_keys, prefix)
+ flat_data.update(nested_flat)
+ else:
+ # Construct full key name with prefix (if provided) and nested keys
+ if prefix:
+ full_key = f"{prefix.upper()}_{'_'.join(current_keys)}"
+ else:
+ full_key = "_".join(current_keys)
+
+ # Process value: ensure string type, convert None to empty string
+ flat_value = "" if value is None else str(value).strip()
+
+ flat_data[full_key] = flat_value
+
+ return flat_data
+
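+ # Example (illustrative): flatten_dict({"neo4j": {"uri": "bolt://x", "auth": {"user": "n"}}},
+ # prefix="memos") returns
+ # {"MEMOS_NEO4J_URI": "bolt://x", "MEMOS_NEO4J_AUTH_USER": "n"}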
+
+def convert_config_to_env(input_file: str, output_file: str = ".env", prefix: str = "") -> None:
+ """
+ Converts a JSON or YAML configuration file to a .env file with standardized environment variables.
+ Uses the flatten_dict function to generate keys in PREFIX_KEY1_KEY2... format.
+
+ Args:
+ input_file: Path to input configuration file (.json, .yaml, or .yml)
+ output_file: Path to output .env file (default: .env)
+ prefix: Custom prefix for all environment variable keys
+
+ Raises:
+ FileNotFoundError: If input file does not exist
+ ValueError: If file format is unsupported or parsing fails
+ """
+ # Check if input file exists
+ if not os.path.exists(input_file):
+ raise FileNotFoundError(f"Input file not found: {input_file}")
+
+ # Parse input file based on extension
+ file_ext = os.path.splitext(input_file)[1].lower()
+ config_data: dict[str, Any] = {}
+
+ try:
+ with open(input_file, encoding="utf-8") as f:
+ if file_ext in (".json",):
+ config_data = json.load(f)
+ elif file_ext in (".yaml", ".yml"):
+ config_data = yaml.safe_load(f)
+ else:
+ raise ValueError(
+ f"Unsupported file format: {file_ext}. Supported formats: .json, .yaml, .yml"
+ )
+ except (json.JSONDecodeError, yaml.YAMLError) as e:
+ raise ValueError(f"Error parsing file: {e!s}") from e
+
+ # Flatten configuration and generate environment variable key-value pairs
+ flat_config = flatten_dict(config_data, prefix=prefix)
+
+ # Write to .env file
+ with open(output_file, "w", encoding="utf-8") as f:
+ for key, value in flat_config.items():
+ # Handle values containing double quotes (use no surrounding quotes)
+ if '"' in value:
+ f.write(f"{key}={value}\n")
+ else:
+ f.write(f'{key}="{value}"\n') # Enclose regular values in double quotes
+
+ print(
+ f"Conversion complete! Generated {output_file} with {len(flat_config)} environment variables"
+ )
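+
+
+ # --- Hypothetical usage sketch (illustrative, not part of this change) ---
+ # Writes a tiny JSON config into a temporary directory and converts it; the key names
+ # and prefix below are made up for illustration.
+ if __name__ == "__main__":
+     import tempfile
+
+     with tempfile.TemporaryDirectory() as tmp:
+         cfg_path = os.path.join(tmp, "config.json")
+         with open(cfg_path, "w", encoding="utf-8") as f:
+             json.dump({"openai": {"api key": "sk-xxx", "base-url": None}}, f)
+
+         env_path = os.path.join(tmp, ".env")
+         convert_config_to_env(cfg_path, env_path, prefix="memscheduler")
+         # Expected keys: MEMSCHEDULER_OPENAI_API_KEY, MEMSCHEDULER_OPENAI_BASE_URL
+         with open(env_path, encoding="utf-8") as envf:
+             print(envf.read())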
diff --git a/src/memos/mem_scheduler/utils/db_utils.py b/src/memos/mem_scheduler/utils/db_utils.py
new file mode 100644
index 00000000..5d7cc52c
--- /dev/null
+++ b/src/memos/mem_scheduler/utils/db_utils.py
@@ -0,0 +1,33 @@
+import os
+import sqlite3
+
+
+def print_db_tables(db_path: str):
+ """Print all table names and structures in the SQLite database"""
+ print(f"\n🔍 Checking database file: {db_path}")
+
+ if not os.path.exists(db_path):
+ print(f"❌ File does not exist! Path: {db_path}")
+ return
+
+ conn = sqlite3.connect(db_path)
+ cursor = conn.cursor()
+
+ # List all tables
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+ tables = cursor.fetchall()
+ if not tables:
+ print("❌ Database is empty, no tables created")
+ else:
+ print(f"✅ Database contains {len(tables)} table(s):")
+ for (table_name,) in tables:
+ print(f" 📂 Table name: {table_name}")
+
+ # Print table structure
+ cursor.execute(f"PRAGMA table_info({table_name});")
+ columns = cursor.fetchall()
+ print(" 🧩 Structure:")
+ for col in columns:
+ print(f" {col[1]} ({col[2]}) {'(PK)' if col[5] else ''}")
+
+ conn.close()
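+
+
+ # --- Hypothetical usage sketch (illustrative, not part of this change) ---
+ # Creates a throwaway SQLite database with a single table and inspects it; the path
+ # and table name are placeholders.
+ if __name__ == "__main__":
+     import tempfile
+
+     db_file = os.path.join(tempfile.mkdtemp(), "demo.db")
+     conn = sqlite3.connect(db_file)
+     conn.execute("CREATE TABLE demo (id INTEGER PRIMARY KEY, name TEXT)")
+     conn.commit()
+     conn.close()
+     print_db_tables(db_file)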
diff --git a/src/memos/mem_scheduler/utils/filter_utils.py b/src/memos/mem_scheduler/utils/filter_utils.py
index 6055fe41..7aa0657e 100644
--- a/src/memos/mem_scheduler/utils/filter_utils.py
+++ b/src/memos/mem_scheduler/utils/filter_utils.py
@@ -60,7 +60,7 @@ def is_all_chinese(input_string: str) -> bool:
install_command="pip install scikit-learn",
install_link="https://scikit-learn.org/stable/install.html",
)
-def filter_similar_memories(
+def filter_vector_based_similar_memories(
text_memories: list[str], similarity_threshold: float = 0.75
) -> list[str]:
"""
diff --git a/src/memos/mem_scheduler/webservice_modules/__init__.py b/src/memos/mem_scheduler/webservice_modules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/memos/mem_scheduler/general_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
similarity index 100%
rename from src/memos/mem_scheduler/general_modules/rabbitmq_service.py
rename to src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
diff --git a/src/memos/mem_scheduler/general_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py
similarity index 100%
rename from src/memos/mem_scheduler/general_modules/redis_service.py
rename to src/memos/mem_scheduler/webservice_modules/redis_service.py
diff --git a/src/memos/memories/activation/kv.py b/src/memos/memories/activation/kv.py
index 06cef794..2fa08590 100644
--- a/src/memos/memories/activation/kv.py
+++ b/src/memos/memories/activation/kv.py
@@ -1,9 +1,10 @@
import os
import pickle
+
from datetime import datetime
from importlib.metadata import version
-from packaging.version import Version
+from packaging.version import Version
from transformers import DynamicCache
from memos.configs.memory import KVCacheMemoryConfig
diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py
index 6b6e70fd..2da283d4 100644
--- a/src/memos/memories/textual/item.py
+++ b/src/memos/memories/textual/item.py
@@ -1,13 +1,48 @@
"""Defines memory item types for textual memory."""
+import json
import uuid
from datetime import datetime
-from typing import Literal
+from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_validator
+ALLOWED_ROLES = {"user", "assistant", "system"}
+
+
+class SourceMessage(BaseModel):
+ """
+ Purpose: **memory provenance / traceability**.
+
+ Capture the minimal, reproducible origin context of a memory item so it can be
+ audited, traced, rolled back, or de-duplicated later.
+
+ Fields & conventions:
+ - type: Source kind (e.g., "chat", "doc", "web", "file", "system", ...).
+ If not provided, upstream logic may infer it:
+ presence of `role` ⇒ "chat"; otherwise ⇒ "doc".
+ - role: Conversation role ("user" | "assistant" | "system") when the
+ source is a chat turn.
+ - content: Minimal reproducible snippet from the source. If omitted,
+ upstream may fall back to `doc_path` / `url` / `message_id`.
+ - chat_time / message_id / doc_path: Locators for precisely pointing back
+ to the original record (timestamp, message id, document path).
+ - Extra fields: Allowed (`model_config.extra="allow"`) to carry arbitrary
+ provenance attributes (e.g., url, page, offset, span, local_confidence).
+ """
+
+ type: str | None = "chat"
+ role: Literal["user", "assistant", "system"] | None = None
+ chat_time: str | None = None
+ message_id: str | None = None
+ content: str | None = None
+ doc_path: str | None = None
+
+ model_config = ConfigDict(extra="allow")
+
+
class TextualMemoryMetadata(BaseModel):
"""Metadata for a memory item.
@@ -62,7 +97,7 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata):
memory_type: Literal["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"] = Field(
default="WorkingMemory", description="Memory lifecycle type."
)
- sources: list[str] | None = Field(
+ sources: list[SourceMessage] | None = Field(
default=None, description="Multiple origins of the memory (e.g., URLs, notes)."
)
embedding: list[float] | None = Field(
@@ -74,8 +109,8 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata):
description="The timestamp of the first creation to the memory. Useful "
"for tracking memory initialization. Format: ISO 8601.",
)
- usage: list[str] | None = Field(
- default=[],
+ usage: list[str] = Field(
+ default_factory=list,
description="Usage history of this node",
)
background: str | None = Field(
@@ -83,12 +118,40 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata):
description="background of this node",
)
- @field_validator("sources")
+ @field_validator("sources", mode="before")
@classmethod
- def validate_sources(cls, v):
- if v is not None and not isinstance(v, list):
- raise ValueError("Sources must be a list of strings.")
- return v
+ def coerce_sources(cls, v):
+ if v is None:
+ return v
+ if not isinstance(v, list):
+ raise TypeError("sources must be a list")
+ out = []
+ for item in v:
+ if isinstance(item, SourceMessage):
+ out.append(item)
+
+ elif isinstance(item, dict):
+ d = dict(item)
+ if d.get("type") is None:
+ d["type"] = "chat" if d.get("role") in ALLOWED_ROLES else "doc"
+ out.append(SourceMessage(**d))
+
+ elif isinstance(item, str):
+ try:
+ parsed = json.loads(item)
+ except Exception:
+ parsed = None
+
+ if isinstance(parsed, dict):
+ if parsed.get("type") is None:
+ parsed["type"] = "chat" if parsed.get("role") in ALLOWED_ROLES else "doc"
+ out.append(SourceMessage(**parsed))
+ else:
+ out.append(SourceMessage(type="doc", content=item))
+
+ else:
+ out.append(SourceMessage(type="doc", content=str(item)))
+ return out
def __str__(self) -> str:
"""Pretty string representation of the metadata."""
@@ -114,19 +177,17 @@ class TextualMemoryItem(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
memory: str
metadata: (
- TextualMemoryMetadata
+ SearchedTreeNodeTextualMemoryMetadata
| TreeNodeTextualMemoryMetadata
- | SearchedTreeNodeTextualMemoryMetadata
+ | TextualMemoryMetadata
) = Field(default_factory=TextualMemoryMetadata)
model_config = ConfigDict(extra="forbid")
+ @field_validator("id")
@classmethod
- def validate_id(cls, v):
- try:
- uuid.UUID(v)
- except ValueError as e:
- raise ValueError("Invalid UUID format") from e
+ def _validate_id(cls, v: str) -> str:
+ uuid.UUID(v)
return v
@classmethod
@@ -136,6 +197,24 @@ def from_dict(cls, data: dict) -> "TextualMemoryItem":
def to_dict(self) -> dict:
return self.model_dump(exclude_none=True)
+ @field_validator("metadata", mode="before")
+ @classmethod
+ def _coerce_metadata(cls, v: Any):
+ if isinstance(
+ v,
+ SearchedTreeNodeTextualMemoryMetadata
+ | TreeNodeTextualMemoryMetadata
+ | TextualMemoryMetadata,
+ ):
+ return v
+ if isinstance(v, dict):
+ if v.get("relativity") is not None:
+ return SearchedTreeNodeTextualMemoryMetadata(**v)
+ if any(k in v for k in ("sources", "memory_type", "embedding", "background", "usage")):
+ return TreeNodeTextualMemoryMetadata(**v)
+ return TextualMemoryMetadata(**v)
+ return v
+
def __str__(self) -> str:
"""Pretty string representation of the memory item."""
return f""
diff --git a/src/memos/memories/textual/naive.py b/src/memos/memories/textual/naive.py
index 236ce8f2..f8684729 100644
--- a/src/memos/memories/textual/naive.py
+++ b/src/memos/memories/textual/naive.py
@@ -115,7 +115,7 @@ def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any])
self.memories[i] = memory_dict
break
- def search(self, query: str, top_k: int) -> list[TextualMemoryItem]:
+ def search(self, query: str, top_k: int, **kwargs) -> list[TextualMemoryItem]:
"""Search for memories based on a query."""
sims = [
(memory, len(set(query.split()) & set(memory["memory"].split())))
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index 265150a2..f324f41c 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -2,6 +2,7 @@
import os
import shutil
import tempfile
+import time
from datetime import datetime
from pathlib import Path
@@ -32,15 +33,28 @@ class TreeTextMemory(BaseTextMemory):
def __init__(self, config: TreeTextMemoryConfig):
"""Initialize memory with the given configuration."""
+ time_start = time.time()
self.config: TreeTextMemoryConfig = config
self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(
config.extractor_llm
)
+ logger.info(f"time init: extractor_llm time is: {time.time() - time_start}")
+
+ time_start_ex = time.time()
self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(
config.dispatcher_llm
)
+ logger.info(f"time init: dispatcher_llm time is: {time.time() - time_start_ex}")
+
+ time_start_em = time.time()
self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder)
+ logger.info(f"time init: embedder time is: {time.time() - time_start_em}")
+
+ time_start_gs = time.time()
self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db)
+ logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}")
+
+ time_start_rr = time.time()
if config.reranker is None:
default_cfg = RerankerConfigFactory.model_validate(
{
@@ -54,9 +68,10 @@ def __init__(self, config: TreeTextMemoryConfig):
self.reranker = RerankerFactory.from_config(default_cfg)
else:
self.reranker = RerankerFactory.from_config(config.reranker)
-
+ logger.info(f"time init: reranker time is: {time.time() - time_start_rr}")
self.is_reorganize = config.reorganize
+ time_start_mm = time.time()
self.memory_manager: MemoryManager = MemoryManager(
self.graph_store,
self.embedder,
@@ -69,7 +84,8 @@ def __init__(self, config: TreeTextMemoryConfig):
},
is_reorganize=self.is_reorganize,
)
-
+ logger.info(f"time init: memory_manager time is: {time.time() - time_start_mm}")
+ time_start_ir = time.time()
# Create internet retriever if configured
self.internet_retriever = None
if config.internet_retriever is not None:
@@ -81,6 +97,7 @@ def __init__(self, config: TreeTextMemoryConfig):
)
else:
logger.info("No internet retriever configured")
+ logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}")
def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]:
"""Add memories.
@@ -122,6 +139,7 @@ def search(
memory_type: str = "All",
manual_close_internet: bool = False,
moscube: bool = False,
+ search_filter: dict | None = None,
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
User query -> TaskGoalParser -> MemoryPathResolver ->
@@ -136,6 +154,12 @@ def search(
memory_type (str): Type restriction for search.
['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory']
             manual_close_internet (bool): If True, the internet retriever will be closed by this search; it has higher priority than the config.
+            moscube (bool): Whether to use moscube when answering questions.
+ search_filter (dict, optional): Optional metadata filters for search results.
+ - Keys correspond to memory metadata fields (e.g., "user_id", "session_id").
+ - Values are exact-match conditions.
+ Example: {"user_id": "123", "session_id": "abc"}
+ If None, no additional filtering is applied.
Returns:
list[TextualMemoryItem]: List of matching memories.
"""
@@ -160,7 +184,7 @@ def search(
internet_retriever=self.internet_retriever,
moscube=moscube,
)
- return searcher.search(query, top_k, info, mode, memory_type)
+ return searcher.search(query, top_k, info, mode, memory_type, search_filter)
def get_relevant_subgraph(
self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated"
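
A hedged sketch of the new search_filter parameter in use (illustrative; `tree_memory` is assumed to be an already-initialized TreeTextMemory, and the filterable keys depend on what the graph store indexes):

# Only memories whose metadata exactly matches the filter are considered.
results = tree_memory.search(
    query="What did the user say about travel plans?",
    top_k=5,
    memory_type="UserMemory",
    search_filter={"user_id": "123", "session_id": "abc"},
)
for item in results:
    print(item.memory)
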
diff --git a/src/memos/memories/textual/tree_text_memory/organize/handler.py b/src/memos/memories/textual/tree_text_memory/organize/handler.py
index a1121fcd..595cf099 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/handler.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/handler.py
@@ -1,5 +1,6 @@
import json
import re
+
from datetime import datetime
from dateutil import parser
@@ -14,6 +15,7 @@
MEMORY_RELATION_RESOLVER_PROMPT,
)
+
logger = get_logger(__name__)
@@ -50,12 +52,12 @@ def detect(self, memory, top_k: int = 5, scope=None):
]
result = self.llm.generate(prompt).strip()
if result == "contradictory":
- logger.warning(
+ logger.info(
f'detected "{memory.memory}" <==CONFLICT==> "{embedding_candidate.memory}"'
)
detected_relationships.append([memory, embedding_candidate, "contradictory"])
elif result == "redundant":
- logger.warning(
+ logger.info(
f'detected "{memory.memory}" <==REDUNDANT==> "{embedding_candidate.memory}"'
)
detected_relationships.append([memory, embedding_candidate, "redundant"])
diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py
index 6b0a6a55..c9cd4de8 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/manager.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py
@@ -1,8 +1,10 @@
+import traceback
import uuid
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
from datetime import datetime
+from memos.context.context import ContextThreadPoolExecutor
from memos.embedders.factory import OllamaEmbedder
from memos.graph_dbs.neo4j import Neo4jGraphDB
from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM
@@ -55,24 +57,35 @@ def add(self, memories: list[TextualMemoryItem]) -> list[str]:
"""
added_ids: list[str] = []
- with ThreadPoolExecutor(max_workers=8) as executor:
+ with ContextThreadPoolExecutor(max_workers=8) as executor:
futures = {executor.submit(self._process_memory, m): m for m in memories}
- for future in as_completed(futures):
+ for future in as_completed(futures, timeout=60):
try:
ids = future.result()
added_ids.extend(ids)
except Exception as e:
logger.exception("Memory processing error: ", exc_info=e)
- self.graph_store.remove_oldest_memory(
- memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"]
- )
- self.graph_store.remove_oldest_memory(
- memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"]
- )
- self.graph_store.remove_oldest_memory(
- memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"]
- )
+ try:
+ self.graph_store.remove_oldest_memory(
+ memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"]
+ )
+ except Exception:
+ logger.warning(f"Remove WorkingMemory error: {traceback.format_exc()}")
+
+ try:
+ self.graph_store.remove_oldest_memory(
+ memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"]
+ )
+ except Exception:
+ logger.warning(f"Remove LongTermMemory error: {traceback.format_exc()}")
+
+ try:
+ self.graph_store.remove_oldest_memory(
+ memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"]
+ )
+ except Exception:
+ logger.warning(f"Remove UserMemory error: {traceback.format_exc()}")
self._refresh_memory_size()
return added_ids
@@ -82,12 +95,12 @@ def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None:
Replace WorkingMemory
"""
working_memory_top_k = memories[: self.memory_size["WorkingMemory"]]
- with ThreadPoolExecutor(max_workers=8) as executor:
+ with ContextThreadPoolExecutor(max_workers=8) as executor:
futures = [
executor.submit(self._add_memory_to_db, memory, "WorkingMemory")
for memory in working_memory_top_k
]
- for future in as_completed(futures):
+ for future in as_completed(futures, timeout=60):
try:
future.result()
except Exception as e:
@@ -102,6 +115,7 @@ def get_current_memory_size(self) -> dict[str, int]:
"""
Return the cached memory type counts.
"""
+ self._refresh_memory_size()
return self.current_memory_size
def _refresh_memory_size(self) -> None:
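
The change above follows a common pattern: bound the wait on submitted futures and isolate each cleanup call so one failure cannot abort the others. A generic sketch of that pattern with the standard library (the patch uses ContextThreadPoolExecutor as a drop-in replacement; all names here are illustrative):

import logging
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger(__name__)

def process_all(items, process_one, cleanups, timeout=60):
    done = []
    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = {executor.submit(process_one, item): item for item in items}
        # as_completed raises TimeoutError if the batch overruns the deadline.
        for future in as_completed(futures, timeout=timeout):
            try:
                done.append(future.result())
            except Exception as e:
                logger.exception("processing error", exc_info=e)
    for cleanup in cleanups:
        # Each cleanup failure is logged and swallowed, never propagated.
        try:
            cleanup()
        except Exception:
            logger.warning("cleanup error: %s", traceback.format_exc())
    return done
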
diff --git a/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py b/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py
index 39e0a2ed..ad9dcb2b 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py
@@ -46,7 +46,7 @@ def process_node(self, node: GraphDBNode, exclude_ids: list[str], top_k: int = 5
"sequence_links": [],
"aggregate_nodes": [],
}
-
+ """
nearest = self.graph_store.get_neighbors_by_tag(
tags=node.metadata.tags,
exclude_ids=exclude_ids,
@@ -55,7 +55,6 @@ def process_node(self, node: GraphDBNode, exclude_ids: list[str], top_k: int = 5
)
nearest = [GraphDBNode(**cand_data) for cand_data in nearest]
- """
# 1) Pairwise relations (including CAUSE/CONDITION/CONFLICT)
pairwise = self._detect_pairwise_causal_condition_relations(node, nearest)
results["relations"].extend(pairwise["relations"])
diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
index 586deaab..0337225d 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
@@ -4,19 +4,20 @@
import traceback
from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
from queue import PriorityQueue
from typing import Literal
import numpy as np
+from memos.context.context import ContextThreadPoolExecutor
from memos.dependency import require_python_package
from memos.embedders.factory import OllamaEmbedder
from memos.graph_dbs.item import GraphDBEdge, GraphDBNode
from memos.graph_dbs.neo4j import Neo4jGraphDB
from memos.llms.base import BaseLLM
from memos.log import get_logger
-from memos.memories.textual.item import TreeNodeTextualMemoryMetadata
+from memos.memories.textual.item import SourceMessage, TreeNodeTextualMemoryMetadata
from memos.memories.textual.tree_text_memory.organize.handler import NodeHandler
from memos.memories.textual.tree_text_memory.organize.relation_reason_detector import (
RelationAndReasoningDetector,
@@ -27,6 +28,22 @@
logger = get_logger(__name__)
+def build_summary_parent_node(cluster_nodes):
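+    """Normalize cluster member nodes into SourceMessage provenance entries for the summary parent node."""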
+ normalized_sources = []
+ for n in cluster_nodes:
+ sm = SourceMessage(
+ type="chat",
+ role=None,
+ chat_time=None,
+ message_id=None,
+ content=n.memory,
+ # extra
+ node_id=n.id,
+ )
+ normalized_sources.append(sm)
+ return normalized_sources
+
+
class QueueMessage:
def __init__(
self,
@@ -51,6 +68,15 @@ def __lt__(self, other: "QueueMessage") -> bool:
return op_priority[self.op] < op_priority[other.op]
+def extract_first_to_last_brace(text: str):
+ start = text.find("{")
+ end = text.rfind("}")
+ if start == -1 or end == -1 or end < start:
+ return "", None
+ json_str = text[start : end + 1]
+ return json_str, json.loads(json_str)
+
+
class GraphStructureReorganizer:
def __init__(
self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: OllamaEmbedder, is_reorganize: bool
@@ -87,6 +113,7 @@ def wait_until_current_task_done(self):
1) queue is empty
2) any running structure optimization is done
"""
+ deadline = time.time() + 600
if not self.is_reorganize:
return
@@ -96,6 +123,9 @@ def wait_until_current_task_done(self):
while any(self._is_optimizing.values()):
logger.debug(f"Waiting for structure optimizer to finish... {self._is_optimizing}")
+ if time.time() > deadline:
+ logger.error(f"Wait timed out; flags={self._is_optimizing}")
+ break
time.sleep(1)
logger.debug("Structure optimizer is now idle.")
@@ -129,6 +159,9 @@ def _run_structure_organizer_loop(self):
logger.info("Structure optimizer schedule started.")
while not getattr(self, "_stop_scheduler", False):
+ if any(self._is_optimizing.values()):
+ time.sleep(1)
+ continue
if self._reorganize_needed:
logger.info("[Reorganizer] Triggering optimize_structure due to new nodes.")
self.optimize_structure(scope="LongTermMemory")
@@ -176,6 +209,7 @@ def optimize_structure(
local_tree_threshold: int = 10,
min_cluster_size: int = 4,
min_group_size: int = 20,
+ max_duration_sec: int = 600,
):
"""
Periodically reorganize the graph:
@@ -183,8 +217,20 @@ def optimize_structure(
2. Summarize each cluster.
3. Create parent nodes and build local PARENT trees.
"""
+        # --- Total-time watchdog: helper for checking the deadline ---
+ start_ts = time.time()
+
+ def _check_deadline(where: str):
+ if time.time() - start_ts > max_duration_sec:
+ logger.error(
+ f"[GraphStructureReorganize] {scope} surpass {max_duration_sec}s,time "
+ f"over at {where}"
+ )
+ return True
+ return False
+
if self._is_optimizing[scope]:
- logger.info(f"Already optimizing for {scope}. Skipping.")
+ logger.info(f"[GraphStructureReorganize] Already optimizing for {scope}. Skipping.")
return
if self.graph_store.node_not_exist(scope):
@@ -198,32 +244,35 @@ def optimize_structure(
)
logger.debug(
- f"Num of scope in self.graph_store is {self.graph_store.get_memory_count(scope)}"
+ f"[GraphStructureReorganize] Num of scope in self.graph_store is"
+ f" {self.graph_store.get_memory_count(scope)}"
)
# Load candidate nodes
+ if _check_deadline("[GraphStructureReorganize] Before loading candidates"):
+ return
raw_nodes = self.graph_store.get_structure_optimization_candidates(scope)
nodes = [GraphDBNode(**n) for n in raw_nodes]
if not nodes:
logger.info("[GraphStructureReorganize] No nodes to optimize. Skipping.")
return
-
if len(nodes) < min_group_size:
logger.info(
f"[GraphStructureReorganize] Only {len(nodes)} candidate nodes found. Not enough to reorganize. Skipping."
)
return
- logger.info(f"[GraphStructureReorganize] Loaded {len(nodes)} nodes.")
-
# Step 2: Partition nodes
+ if _check_deadline("[GraphStructureReorganize] Before partition"):
+ return
partitioned_groups = self._partition(nodes)
-
logger.info(
f"[GraphStructureReorganize] Partitioned into {len(partitioned_groups)} clusters."
)
- with ThreadPoolExecutor(max_workers=4) as executor:
+ if _check_deadline("[GraphStructureReorganize] Before submit partition task"):
+ return
+ with ContextThreadPoolExecutor(max_workers=4) as executor:
futures = []
for cluster_nodes in partitioned_groups:
futures.append(
@@ -237,14 +286,17 @@ def optimize_structure(
)
for f in as_completed(futures):
+ if _check_deadline("[GraphStructureReorganize] Waiting clusters..."):
+ for x in futures:
+ x.cancel()
+ return
try:
f.result()
except Exception as e:
logger.warning(
- f"[Reorganize] Cluster processing "
- f"failed: {e}, cluster_nodes: {cluster_nodes}, trace: {traceback.format_exc()}"
+ f"[GraphStructureReorganize] Cluster processing failed: {e}, trace: {traceback.format_exc()}"
)
- logger.info("[GraphStructure Reorganize] Structure optimization finished.")
+ logger.info("[GraphStructure Reorganize] Structure optimization finished.")
finally:
self._is_optimizing[scope] = False
@@ -282,7 +334,7 @@ def _process_cluster_and_write(
nodes_to_check = cluster_nodes
exclude_ids = [n.id for n in nodes_to_check]
- with ThreadPoolExecutor(max_workers=4) as executor:
+ with ContextThreadPoolExecutor(max_workers=4) as executor:
futures = []
for node in nodes_to_check:
futures.append(
@@ -294,7 +346,7 @@ def _process_cluster_and_write(
)
)
- for f in as_completed(futures):
+ for f in as_completed(futures, timeout=300):
results = f.result()
# 1) Add pairwise relations
@@ -331,11 +383,11 @@ def _process_cluster_and_write(
for child_id in agg_node.metadata.sources:
self.graph_store.add_edge(agg_node.id, child_id, "AGGREGATE_TO")
- logger.info("[Reorganizer] Cluster relation/reasoning done.")
+ logger.info("[Reorganizer] Cluster relation/reasoning done.")
def _local_subcluster(
- self, cluster_nodes: list[GraphDBNode], max_length: int = 8000
- ) -> (list)[list[GraphDBNode]]:
+ self, cluster_nodes: list[GraphDBNode], max_length: int = 15000
+ ) -> list[list[GraphDBNode]]:
"""
Use LLM to split a large cluster into semantically coherent sub-clusters.
"""
@@ -350,7 +402,7 @@ def _local_subcluster(
joined_scene = "\n".join(scene_lines)
if len(joined_scene) > max_length:
- logger.warning(f"Sub-cluster too long: {joined_scene}")
+ logger.warning("Sub-cluster too long")
prompt = LOCAL_SUBCLUSTER_PROMPT.replace("{joined_scene}", joined_scene[:max_length])
messages = [{"role": "user", "content": prompt}]
@@ -499,17 +551,17 @@ def _summarize_cluster(self, cluster_nodes: list[GraphDBNode], scope: str) -> Gr
parent_node = GraphDBNode(
memory=parent_value,
metadata=TreeNodeTextualMemoryMetadata(
- user_id="", # TODO: summarized node: no user_id
- session_id="", # TODO: summarized node: no session_id
+ user_id=None,
+ session_id=None,
memory_type=scope,
status="activated",
key=parent_key,
tags=parent_tags,
embedding=embedding,
usage=[],
- sources=[n.id for n in cluster_nodes],
+ sources=build_summary_parent_node(cluster_nodes),
background=parent_background,
- confidence=0.99,
+ confidence=0.66,
type="topic",
),
)
@@ -518,7 +570,7 @@ def _summarize_cluster(self, cluster_nodes: list[GraphDBNode], scope: str) -> Gr
def _parse_json_result(self, response_text):
try:
response_text = response_text.replace("```", "").replace("json", "")
- response_json = json.loads(response_text)
+ response_json = extract_first_to_last_brace(response_text)[1]
return response_json
except json.JSONDecodeError as e:
logger.warning(
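
A small illustration of the brace-extraction helper added above (not part of the patch): it tolerates prose around the JSON object, which is typical of LLM output, and returns both the raw substring and the parsed value.

from memos.memories.textual.tree_text_memory.organize.reorganizer import extract_first_to_last_brace

raw = 'Sure, here is the result:\n{"key": "topic", "tags": ["travel", "plans"]}\nHope this helps!'
json_str, parsed = extract_first_to_last_brace(raw)
print(parsed["key"])  # -> "topic"
# If the braced span is not valid JSON, json.loads raises, which _parse_json_result
# catches and logs as a warning.
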
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py
index 07f2c0a5..31b91477 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py
@@ -2,15 +2,17 @@
import json
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
from datetime import datetime
+from typing import Any
import requests
+from memos.context.context import ContextThreadPoolExecutor
from memos.embedders.factory import OllamaEmbedder
from memos.log import get_logger
from memos.mem_reader.base import BaseMemReader
-from memos.memories.textual.item import TextualMemoryItem
+from memos.memories.textual.item import SourceMessage, TextualMemoryItem
logger = get_logger(__name__)
@@ -177,7 +179,7 @@ def _convert_to_mem_items(
if not info:
info = {"user_id": "", "session_id": ""}
- with ThreadPoolExecutor(max_workers=8) as executor:
+ with ContextThreadPoolExecutor(max_workers=8) as executor:
futures = [
executor.submit(self._process_result, r, query, parsed_goal, info)
for r in search_results
@@ -193,7 +195,7 @@ def _convert_to_mem_items(
return list(unique_memory_items.values())
def _process_result(
- self, result: dict, query: str, parsed_goal: str, info: None
+ self, result: dict, query: str, parsed_goal: str, info: dict[str, Any]
) -> list[TextualMemoryItem]:
"""Process one Bocha search result into TextualMemoryItem."""
title = result.get("name", "")
@@ -225,7 +227,7 @@ def _process_result(
)
read_item_i.metadata.source = "web"
read_item_i.metadata.memory_type = "OuterMemory"
- read_item_i.metadata.sources = [url] if url else []
+ read_item_i.metadata.sources = [SourceMessage(type="web", url=url)] if url else []
read_item_i.metadata.visibility = "public"
memory_items.append(read_item_i)
return memory_items
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py b/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py
index 7ec235fb..819b4e36 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py
@@ -7,7 +7,11 @@
import requests
from memos.embedders.factory import OllamaEmbedder
-from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
+from memos.memories.textual.item import (
+ SourceMessage,
+ TextualMemoryItem,
+ TreeNodeTextualMemoryMetadata,
+)
class GoogleCustomSearchAPI:
@@ -172,7 +176,7 @@ def retrieve_from_internet(
visibility="public",
memory_type="LongTermMemory", # Internet search results as working memory
key=title,
- sources=[link] if link else [],
+ sources=[SourceMessage(type="web", url=link)] if link else [],
embedding=self.embedder.embed([memory_content])[0], # Can add embedding later
created_at=datetime.now().isoformat(),
usage=[],
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py b/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py
index b9a1cf13..3498f596 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py
@@ -10,6 +10,7 @@
InternetGoogleRetriever,
)
from memos.memories.textual.tree_text_memory.retrieve.xinyusearch import XinyuSearchRetriever
+from memos.memos_tools.singleton import singleton_factory
class InternetRetrieverFactory:
@@ -23,6 +24,7 @@ class InternetRetrieverFactory:
}
@classmethod
+ @singleton_factory()
def from_config(
cls, config_factory: InternetRetrieverConfigFactory, embedder: BaseEmbedder
) -> InternetGoogleRetriever | None:
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
index 1f6a5a41..84cc8ecb 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -1,11 +1,16 @@
import concurrent.futures
+from memos.context.context import ContextThreadPoolExecutor
from memos.embedders.factory import OllamaEmbedder
from memos.graph_dbs.neo4j import Neo4jGraphDB
+from memos.log import get_logger
from memos.memories.textual.item import TextualMemoryItem
from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
+logger = get_logger(__name__)
+
+
class GraphMemoryRetriever:
"""
Unified memory retriever that combines both graph-based and vector-based retrieval logic.
@@ -14,6 +19,8 @@ class GraphMemoryRetriever:
def __init__(self, graph_store: Neo4jGraphDB, embedder: OllamaEmbedder):
self.graph_store = graph_store
self.embedder = embedder
+ self.max_workers = 10
+ self.filter_weight = 0.6
def retrieve(
self,
@@ -22,6 +29,7 @@ def retrieve(
top_k: int,
memory_scope: str,
query_embedding: list[list[float]] | None = None,
+ search_filter: dict | None = None,
) -> list[TextualMemoryItem]:
"""
Perform hybrid memory retrieval:
@@ -35,7 +43,7 @@ def retrieve(
top_k (int): Number of candidates to return.
memory_scope (str): One of ['working', 'long_term', 'user'].
query_embedding(list of embedding): list of embedding of query
-
+ search_filter (dict, optional): Optional metadata filters for search results.
Returns:
list: Combined memory items.
"""
@@ -45,16 +53,20 @@ def retrieve(
if memory_scope == "WorkingMemory":
# For working memory, retrieve all entries (no filtering)
working_memories = self.graph_store.get_all_memory_items(
- scope="WorkingMemory", include_embedding=True
+ scope="WorkingMemory", include_embedding=False
)
return [TextualMemoryItem.from_dict(record) for record in working_memories]
- with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+ with ContextThreadPoolExecutor(max_workers=2) as executor:
# Structured graph-based retrieval
future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope)
# Vector similarity search
future_vector = executor.submit(
- self._vector_recall, query_embedding, memory_scope, top_k
+ self._vector_recall,
+ query_embedding or [],
+ memory_scope,
+ top_k,
+ search_filter=search_filter,
)
graph_results = future_graph.result()
@@ -153,7 +165,7 @@ def _graph_recall(
return []
# Load nodes and post-filter
- node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=True)
+ node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False)
final_nodes = []
for node in node_dicts:
@@ -181,34 +193,70 @@ def _vector_recall(
top_k: int = 20,
max_num: int = 3,
cube_name: str | None = None,
+ search_filter: dict | None = None,
) -> list[TextualMemoryItem]:
"""
- # TODO: tackle with post-filter and pre-filter(5.18+) better.
Perform vector-based similarity retrieval using query embedding.
+ # TODO: tackle with post-filter and pre-filter(5.18+) better.
"""
- all_matches = []
+ if not query_embedding:
+ return []
- def search_single(vec):
+ def search_single(vec, filt=None):
return (
self.graph_store.search_by_embedding(
- vector=vec, top_k=top_k, scope=memory_scope, cube_name=cube_name
+ vector=vec,
+ top_k=top_k,
+ scope=memory_scope,
+ cube_name=cube_name,
+ search_filter=filt,
)
or []
)
- with concurrent.futures.ThreadPoolExecutor() as executor:
- futures = [executor.submit(search_single, vec) for vec in query_embedding[:max_num]]
- for future in concurrent.futures.as_completed(futures):
- result = future.result()
- all_matches.extend(result)
-
- if not all_matches:
+ def search_path_a():
+ """Path A: search without filter"""
+ path_a_hits = []
+ with ContextThreadPoolExecutor() as executor:
+ futures = [
+ executor.submit(search_single, vec, None) for vec in query_embedding[:max_num]
+ ]
+ for f in concurrent.futures.as_completed(futures):
+ path_a_hits.extend(f.result() or [])
+ return path_a_hits
+
+ def search_path_b():
+ """Path B: search with filter"""
+ if not search_filter:
+ return []
+ path_b_hits = []
+ with ContextThreadPoolExecutor() as executor:
+ futures = [
+ executor.submit(search_single, vec, search_filter)
+ for vec in query_embedding[:max_num]
+ ]
+ for f in concurrent.futures.as_completed(futures):
+ path_b_hits.extend(f.result() or [])
+ return path_b_hits
+
+ # Execute both paths concurrently
+ all_hits = []
+ with ContextThreadPoolExecutor(max_workers=2) as executor:
+ path_a_future = executor.submit(search_path_a)
+ path_b_future = executor.submit(search_path_b)
+
+ all_hits.extend(path_a_future.result())
+ all_hits.extend(path_b_future.result())
+
+ if not all_hits:
return []
- # Step 3: Extract matched IDs and retrieve full nodes
- unique_ids = set({r["id"] for r in all_matches})
- node_dicts = self.graph_store.get_nodes(
- list(unique_ids), include_embedding=True, cube_name=cube_name
+ # merge and deduplicate
+ unique_ids = {r["id"] for r in all_hits if r.get("id")}
+ node_dicts = (
+ self.graph_store.get_nodes(
+ list(unique_ids), include_embedding=False, cube_name=cube_name
+ )
+ or []
)
-
- return [TextualMemoryItem.from_dict(record) for record in node_dicts]
+ return [TextualMemoryItem.from_dict(n) for n in node_dicts]
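
The dual-path recall above reduces to: run an unfiltered and a filtered embedding search in parallel, merge the hits, and deduplicate by node id before loading full nodes. A generic sketch of that merge step (all names are illustrative; `search_fn(vector, filt)` stands in for graph_store.search_by_embedding):

from concurrent.futures import ThreadPoolExecutor

def dual_path_recall(search_fn, vectors, search_filter=None):
    def run(filt):
        hits = []
        for vec in vectors:
            hits.extend(search_fn(vec, filt) or [])
        return hits

    with ThreadPoolExecutor(max_workers=2) as executor:
        unfiltered = executor.submit(run, None)
        filtered = executor.submit(run, search_filter) if search_filter else None
        all_hits = unfiltered.result() + (filtered.result() if filtered else [])

    # Deduplicate by id, keeping the first occurrence of each node.
    unique = {h["id"]: h for h in all_hits if h.get("id")}
    return list(unique.values())
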
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index 9ac1646e..df154f23 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -1,8 +1,9 @@
-import concurrent.futures
import json
+import traceback
from datetime import datetime
+from memos.context.context import ContextThreadPoolExecutor
from memos.embedders.factory import OllamaEmbedder
from memos.graph_dbs.factory import Neo4jGraphDB
from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM
@@ -42,13 +43,17 @@ def __init__(
self.internet_retriever = internet_retriever
self.moscube = moscube
- self._usage_executor = concurrent.futures.ThreadPoolExecutor(
- max_workers=4, thread_name_prefix="usage"
- )
+ self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage")
@timed
def search(
- self, query: str, top_k: int, info=None, mode="fast", memory_type="All"
+ self,
+ query: str,
+ top_k: int,
+ info=None,
+ mode="fast",
+ memory_type="All",
+ search_filter: dict | None = None,
) -> list[TextualMemoryItem]:
"""
Search for memories based on a query.
@@ -63,6 +68,7 @@ def search(
- 'fine': Uses a more detailed search process, invoking large models for higher precision, but slower performance.
memory_type (str): Type restriction for search.
['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory']
+ search_filter (dict, optional): Optional metadata filters for search results.
Returns:
list[TextualMemoryItem]: List of matching memories.
"""
@@ -78,9 +84,11 @@ def search(
else:
logger.debug(f"[SEARCH] Received info dict: {info}")
- parsed_goal, query_embedding, context, query = self._parse_task(query, info, mode)
+ parsed_goal, query_embedding, context, query = self._parse_task(
+ query, info, mode, search_filter=search_filter
+ )
results = self._retrieve_paths(
- query, parsed_goal, query_embedding, info, top_k, mode, memory_type
+ query, parsed_goal, query_embedding, info, top_k, mode, memory_type, search_filter
)
deduped = self._deduplicate_results(results)
final_results = self._sort_and_trim(deduped, top_k)
@@ -96,7 +104,7 @@ def search(
return final_results
@timed
- def _parse_task(self, query, info, mode, top_k=5):
+ def _parse_task(self, query, info, mode, top_k=5, search_filter: dict | None = None):
"""Parse user query, do embedding search and create context"""
context = []
query_embedding = None
@@ -109,9 +117,24 @@ def _parse_task(self, query, info, mode, top_k=5):
# retrieve related nodes by embedding
related_nodes = [
self.graph_store.get_node(n["id"])
- for n in self.graph_store.search_by_embedding(query_embedding, top_k=top_k)
+ for n in self.graph_store.search_by_embedding(
+ query_embedding, top_k=top_k, search_filter=search_filter
+ )
]
- context = list({node["memory"] for node in related_nodes})
+ memories = []
+ for node in related_nodes:
+ try:
+ m = (
+ node.get("memory")
+ if isinstance(node, dict)
+ else (getattr(node, "memory", None))
+ )
+ if isinstance(m, str) and m:
+ memories.append(m)
+ except Exception:
+ logger.error(f"[SEARCH] Error during search: {traceback.format_exc()}")
+ continue
+ context = list(dict.fromkeys(memories))
# optional: supplement context with internet knowledge
"""if self.internet_retriever:
@@ -135,10 +158,20 @@ def _parse_task(self, query, info, mode, top_k=5):
return parsed_goal, query_embedding, context, query
@timed
- def _retrieve_paths(self, query, parsed_goal, query_embedding, info, top_k, mode, memory_type):
+ def _retrieve_paths(
+ self,
+ query,
+ parsed_goal,
+ query_embedding,
+ info,
+ top_k,
+ mode,
+ memory_type,
+ search_filter: dict | None = None,
+ ):
"""Run A/B/C retrieval paths in parallel"""
tasks = []
- with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
+ with ContextThreadPoolExecutor(max_workers=3) as executor:
tasks.append(
executor.submit(
self._retrieve_from_working_memory,
@@ -147,6 +180,7 @@ def _retrieve_paths(self, query, parsed_goal, query_embedding, info, top_k, mode
query_embedding,
top_k,
memory_type,
+ search_filter,
)
)
tasks.append(
@@ -157,6 +191,7 @@ def _retrieve_paths(self, query, parsed_goal, query_embedding, info, top_k, mode
query_embedding,
top_k,
memory_type,
+ search_filter,
)
)
tasks.append(
@@ -193,14 +228,24 @@ def _retrieve_paths(self, query, parsed_goal, query_embedding, info, top_k, mode
# --- Path A
@timed
def _retrieve_from_working_memory(
- self, query, parsed_goal, query_embedding, top_k, memory_type
+ self,
+ query,
+ parsed_goal,
+ query_embedding,
+ top_k,
+ memory_type,
+ search_filter: dict | None = None,
):
"""Retrieve and rerank from WorkingMemory"""
if memory_type not in ["All", "WorkingMemory"]:
logger.info(f"[PATH-A] '{query}'Skipped (memory_type does not match)")
return []
items = self.graph_retriever.retrieve(
- query=query, parsed_goal=parsed_goal, top_k=top_k, memory_scope="WorkingMemory"
+ query=query,
+ parsed_goal=parsed_goal,
+ top_k=top_k,
+ memory_scope="WorkingMemory",
+ search_filter=search_filter,
)
return self.reranker.rerank(
query=query,
@@ -208,37 +253,61 @@ def _retrieve_from_working_memory(
graph_results=items,
top_k=top_k,
parsed_goal=parsed_goal,
+ search_filter=search_filter,
)
# --- Path B
@timed
def _retrieve_from_long_term_and_user(
- self, query, parsed_goal, query_embedding, top_k, memory_type
+ self,
+ query,
+ parsed_goal,
+ query_embedding,
+ top_k,
+ memory_type,
+ search_filter: dict | None = None,
):
"""Retrieve and rerank from LongTermMemory and UserMemory"""
results = []
- if memory_type in ["All", "LongTermMemory"]:
- results += self.graph_retriever.retrieve(
- query=query,
- parsed_goal=parsed_goal,
- query_embedding=query_embedding,
- top_k=top_k * 2,
- memory_scope="LongTermMemory",
- )
- if memory_type in ["All", "UserMemory"]:
- results += self.graph_retriever.retrieve(
- query=query,
- parsed_goal=parsed_goal,
- query_embedding=query_embedding,
- top_k=top_k * 2,
- memory_scope="UserMemory",
- )
+ tasks = []
+
+ with ContextThreadPoolExecutor(max_workers=2) as executor:
+ if memory_type in ["All", "LongTermMemory"]:
+ tasks.append(
+ executor.submit(
+ self.graph_retriever.retrieve,
+ query=query,
+ parsed_goal=parsed_goal,
+ query_embedding=query_embedding,
+ top_k=top_k * 2,
+ memory_scope="LongTermMemory",
+ search_filter=search_filter,
+ )
+ )
+ if memory_type in ["All", "UserMemory"]:
+ tasks.append(
+ executor.submit(
+ self.graph_retriever.retrieve,
+ query=query,
+ parsed_goal=parsed_goal,
+ query_embedding=query_embedding,
+ top_k=top_k * 2,
+ memory_scope="UserMemory",
+ search_filter=search_filter,
+ )
+ )
+
+ # Collect results from all tasks
+ for task in tasks:
+ results.extend(task.result())
+
return self.reranker.rerank(
query=query,
query_embedding=query_embedding[0],
graph_results=results,
top_k=top_k,
parsed_goal=parsed_goal,
+ search_filter=search_filter,
)
@timed
@@ -300,8 +369,7 @@ def _sort_and_trim(self, results, top_k):
final_items = []
for item, score in sorted_results:
meta_data = item.metadata.model_dump()
- if "relativity" not in meta_data:
- meta_data["relativity"] = score
+ meta_data["relativity"] = score
final_items.append(
TextualMemoryItem(
id=item.id,
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py
index 2fae16c1..e5acd00f 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py
@@ -3,15 +3,16 @@
import json
import uuid
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
from datetime import datetime
import requests
+from memos.context.context import ContextThreadPoolExecutor
from memos.embedders.factory import OllamaEmbedder
from memos.log import get_logger
from memos.mem_reader.base import BaseMemReader
-from memos.memories.textual.item import TextualMemoryItem
+from memos.memories.textual.item import SourceMessage, TextualMemoryItem
logger = get_logger(__name__)
@@ -150,7 +151,7 @@ def retrieve_from_internet(
# Convert to TextualMemoryItem format
memory_items: list[TextualMemoryItem] = []
- with ThreadPoolExecutor(max_workers=8) as executor:
+ with ContextThreadPoolExecutor(max_workers=8) as executor:
futures = [
executor.submit(self._process_result, result, query, parsed_goal, info)
for result in search_results
@@ -332,7 +333,7 @@ def _process_result(
)
read_item_i.metadata.source = "web"
read_item_i.metadata.memory_type = "OuterMemory"
- read_item_i.metadata.sources = [url] if url else []
+ read_item_i.metadata.sources = [SourceMessage(type="web", url=url)] if url else []
read_item_i.metadata.visibility = "public"
memory_items.append(read_item_i)
diff --git a/src/memos/memos_tools/singleton.py b/src/memos/memos_tools/singleton.py
new file mode 100644
index 00000000..c3171edb
--- /dev/null
+++ b/src/memos/memos_tools/singleton.py
@@ -0,0 +1,174 @@
+"""
+Singleton decorator module for caching factory instances to avoid excessive memory usage
+from repeated initialization.
+"""
+
+import hashlib
+import json
+
+from collections.abc import Callable
+from functools import wraps
+from typing import Any, TypeVar
+from weakref import WeakValueDictionary
+
+
+T = TypeVar("T")
+
+
+class FactorySingleton:
+ """Factory singleton manager that caches instances based on configuration parameters"""
+
+ def __init__(self):
+ # Use weak reference dictionary for automatic cleanup when instances are no longer referenced
+ self._instances: dict[str, WeakValueDictionary] = {}
+
+ def _generate_cache_key(self, config: Any, *args, **kwargs) -> str:
+ """Generate cache key based on configuration only (ignoring other parameters)"""
+
+ # Handle configuration objects - only use the config parameter
+ if hasattr(config, "model_dump"): # Pydantic model
+ config_data = config.model_dump()
+ elif hasattr(config, "dict"): # Legacy Pydantic model
+ config_data = config.dict()
+ elif isinstance(config, dict):
+ config_data = config
+ else:
+ # For other types, try to convert to string
+ config_data = str(config)
+
+ # Filter out time-related fields that shouldn't affect caching
+ filtered_config = self._filter_temporal_fields(config_data)
+
+ # Generate hash key based only on config
+ try:
+ cache_str = json.dumps(filtered_config, sort_keys=True, ensure_ascii=False, default=str)
+ except (TypeError, ValueError):
+ # If JSON serialization fails, convert the entire config to string
+ cache_str = str(filtered_config)
+
+ return hashlib.md5(cache_str.encode("utf-8")).hexdigest()
+
+ def _filter_temporal_fields(self, config_data: Any) -> Any:
+ """Filter out temporal fields that shouldn't affect instance caching"""
+ if isinstance(config_data, dict):
+ filtered = {}
+ for key, value in config_data.items():
+ # Skip common temporal field names
+ if key.lower() in {
+ "created_at",
+ "updated_at",
+ "timestamp",
+ "time",
+ "date",
+ "created_time",
+ "updated_time",
+ "last_modified",
+ "modified_at",
+ "start_time",
+ "end_time",
+ "execution_time",
+ "run_time",
+ }:
+ continue
+ # Recursively filter nested dictionaries
+ filtered[key] = self._filter_temporal_fields(value)
+ return filtered
+ elif isinstance(config_data, list):
+ # Recursively filter lists
+ return [self._filter_temporal_fields(item) for item in config_data]
+ else:
+ # For primitive types, return as-is
+ return config_data
+
+ def get_or_create(self, factory_class: type, cache_key: str, creator_func: Callable) -> Any:
+ """Get or create instance"""
+ class_name = factory_class.__name__
+
+ if class_name not in self._instances:
+ self._instances[class_name] = WeakValueDictionary()
+
+ class_cache = self._instances[class_name]
+
+ if cache_key in class_cache:
+ return class_cache[cache_key]
+
+ # Create new instance
+ instance = creator_func()
+ class_cache[cache_key] = instance
+ return instance
+
+ def clear_cache(self, factory_class: type | None = None):
+ """Clear cache"""
+ if factory_class:
+ class_name = factory_class.__name__
+ if class_name in self._instances:
+ self._instances[class_name].clear()
+ else:
+ for cache in self._instances.values():
+ cache.clear()
+
+
+# Global singleton manager
+_factory_singleton = FactorySingleton()
+
+
+def singleton_factory(factory_class: type | str | None = None):
+ """
+ Factory singleton decorator
+
+ Usage:
+ @singleton_factory()
+ def from_config(cls, config):
+ return SomeClass(config)
+
+ Or specify factory class:
+ @singleton_factory(EmbedderFactory)
+ def from_config(cls, config):
+ return SomeClass(config)
+ """
+
+ def decorator(func: Callable[..., T]) -> Callable[..., T]:
+ @wraps(func)
+ def wrapper(*args, **kwargs) -> T:
+ # Determine factory class and config parameter
+ target_factory_class = factory_class
+ config = None
+
+ # Simple logic: check if first parameter is a class or config
+ if args:
+ if hasattr(args[0], "__name__") and hasattr(args[0], "__module__"):
+ # First parameter is a class (cls), so this is a @classmethod
+ if target_factory_class is None:
+ target_factory_class = args[0]
+ config = args[1] if len(args) > 1 else None
+ else:
+ # First parameter is config, so this is a @staticmethod
+ if target_factory_class is None:
+ raise ValueError(
+ "Factory class must be explicitly specified for static methods"
+ )
+ if isinstance(target_factory_class, str):
+ # Convert string to a mock class for caching purposes
+ class MockFactoryClass:
+ __name__ = target_factory_class
+
+ target_factory_class = MockFactoryClass
+ config = args[0]
+
+ if config is None:
+ # If no configuration parameter, call original function directly
+ return func(*args, **kwargs)
+
+ # Generate cache key based only on config
+ cache_key = _factory_singleton._generate_cache_key(config)
+
+ # Function to create instance
+ def creator():
+ return func(*args, **kwargs)
+
+ # Get or create instance
+ return _factory_singleton.get_or_create(target_factory_class, cache_key, creator)
+
+ return wrapper
+
+ return decorator
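
A usage sketch for the singleton decorator (illustrative; Embedder and EmbedderFactory below are placeholders, not classes from this patch): two calls with configs that differ only in temporal fields return the same cached instance, while a genuinely different config creates a new one.

from memos.memos_tools.singleton import singleton_factory

class Embedder:  # placeholder for an expensive-to-build resource
    def __init__(self, config: dict):
        self.config = config

class EmbedderFactory:  # placeholder factory
    @classmethod
    @singleton_factory()
    def from_config(cls, config: dict) -> Embedder:
        return Embedder(config)

a = EmbedderFactory.from_config({"model": "bge-m3", "created_at": "2024-01-01"})
b = EmbedderFactory.from_config({"model": "bge-m3", "created_at": "2024-06-30"})
c = EmbedderFactory.from_config({"model": "bge-large"})
assert a is b      # temporal fields are filtered out of the cache key
assert a is not c  # different config -> different instance
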
diff --git a/src/memos/memos_tools/thread_safe_dict.py b/src/memos/memos_tools/thread_safe_dict.py
index dc7f98f4..1e9fb6a1 100644
--- a/src/memos/memos_tools/thread_safe_dict.py
+++ b/src/memos/memos_tools/thread_safe_dict.py
@@ -7,10 +7,15 @@
from collections.abc import ItemsView, Iterator, KeysView, ValuesView
from typing import Generic, TypeVar
+from memos.log import get_logger
+from memos.utils import timed
+
K = TypeVar("K")
V = TypeVar("V")
+logger = get_logger(__name__)
+
class ReadWriteLock:
"""A simple read-write lock implementation. use for product-server scenario"""
@@ -19,6 +24,7 @@ def __init__(self):
self._read_ready = threading.Condition(threading.RLock())
self._readers = 0
+ @timed
def acquire_read(self):
"""Acquire a read lock. Multiple readers can hold the lock simultaneously."""
self._read_ready.acquire()
@@ -37,6 +43,7 @@ def release_read(self):
finally:
self._read_ready.release()
+ @timed
def acquire_write(self):
"""Acquire a write lock. Only one writer can hold the lock."""
self._read_ready.acquire()
@@ -67,6 +74,7 @@ def __init__(self, initial_dict: dict[K, V] | None = None):
self._dict: dict[K, V] = initial_dict.copy() if initial_dict else {}
self._lock = ReadWriteLock()
+ @timed
def __getitem__(self, key: K) -> V:
"""Get item by key."""
self._lock.acquire_read()
@@ -75,6 +83,7 @@ def __getitem__(self, key: K) -> V:
finally:
self._lock.release_read()
+ @timed
def __setitem__(self, key: K, value: V) -> None:
"""Set item by key."""
self._lock.acquire_write()
@@ -83,6 +92,7 @@ def __setitem__(self, key: K, value: V) -> None:
finally:
self._lock.release_write()
+ @timed
def __delitem__(self, key: K) -> None:
"""Delete item by key."""
self._lock.acquire_write()
@@ -91,6 +101,7 @@ def __delitem__(self, key: K) -> None:
finally:
self._lock.release_write()
+ @timed
def __contains__(self, key: K) -> bool:
"""Check if key exists in dictionary."""
self._lock.acquire_read()
@@ -99,6 +110,7 @@ def __contains__(self, key: K) -> bool:
finally:
self._lock.release_read()
+ @timed
def __len__(self) -> int:
"""Get length of dictionary."""
self._lock.acquire_read()
@@ -115,6 +127,7 @@ def __bool__(self) -> bool:
finally:
self._lock.release_read()
+ @timed
def __iter__(self) -> Iterator[K]:
"""Iterate over keys. Returns a snapshot to avoid iteration issues."""
self._lock.acquire_read()
@@ -124,6 +137,7 @@ def __iter__(self) -> Iterator[K]:
finally:
self._lock.release_read()
+ @timed
def get(self, key: K, default: V | None = None) -> V:
"""Get item by key with optional default."""
self._lock.acquire_read()
@@ -132,6 +146,7 @@ def get(self, key: K, default: V | None = None) -> V:
finally:
self._lock.release_read()
+ @timed
def pop(self, key: K, *args) -> V:
"""Pop item by key."""
self._lock.acquire_write()
@@ -140,6 +155,7 @@ def pop(self, key: K, *args) -> V:
finally:
self._lock.release_write()
+ @timed
def update(self, *args, **kwargs) -> None:
"""Update dictionary."""
self._lock.acquire_write()
@@ -148,6 +164,7 @@ def update(self, *args, **kwargs) -> None:
finally:
self._lock.release_write()
+ @timed
def clear(self) -> None:
"""Clear all items."""
self._lock.acquire_write()
@@ -156,6 +173,7 @@ def clear(self) -> None:
finally:
self._lock.release_write()
+ @timed
def keys(self) -> KeysView[K]:
"""Get dictionary keys view (snapshot)."""
self._lock.acquire_read()
@@ -164,6 +182,7 @@ def keys(self) -> KeysView[K]:
finally:
self._lock.release_read()
+ @timed
def values(self) -> ValuesView[V]:
"""Get dictionary values view (snapshot)."""
self._lock.acquire_read()
@@ -172,6 +191,7 @@ def values(self) -> ValuesView[V]:
finally:
self._lock.release_read()
+ @timed
def items(self) -> ItemsView[K, V]:
"""Get dictionary items view (snapshot)."""
self._lock.acquire_read()
@@ -180,6 +200,7 @@ def items(self) -> ItemsView[K, V]:
finally:
self._lock.release_read()
+ @timed
def copy(self) -> dict[K, V]:
"""Create a copy of the dictionary."""
self._lock.acquire_read()
@@ -188,6 +209,7 @@ def copy(self) -> dict[K, V]:
finally:
self._lock.release_read()
+ @timed
def setdefault(self, key: K, default: V | None = None) -> V:
"""Set default value for key if not exists."""
self._lock.acquire_write()
diff --git a/src/memos/memos_tools/thread_safe_dict_segment.py b/src/memos/memos_tools/thread_safe_dict_segment.py
new file mode 100644
index 00000000..c1c10e3e
--- /dev/null
+++ b/src/memos/memos_tools/thread_safe_dict_segment.py
@@ -0,0 +1,382 @@
+import threading
+import time
+
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import Any, Generic, TypeVar
+
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class FastReadWriteLock:
+ """Read-write lock optimized for FastAPI scenarios:
+ reader priority with writer starvation prevention"""
+
+ def __init__(self):
+ self._readers = 0
+ self._writers = 0
+ self._waiting_writers = 0
+ self._lock = threading.RLock()
+ self._read_ready = threading.Condition(self._lock)
+ self._write_ready = threading.Condition(self._lock)
+ # Writer starvation detection
+ self._last_write_time = 0
+ self._write_starvation_threshold = 0.1 # 100ms
+
+ def acquire_read(self) -> bool:
+ """Fast read lock acquisition"""
+ with self._lock:
+ # Check if writers are starving
+ current_time = time.time()
+ write_starving = (
+ self._waiting_writers > 0
+ and current_time - self._last_write_time > self._write_starvation_threshold
+ )
+
+ # If no writers are active and no starvation, allow readers to continue
+ if self._writers == 0 and not write_starving:
+ self._readers += 1
+ return True
+
+ # Otherwise wait
+ while self._writers > 0 or write_starving:
+ self._read_ready.wait()
+ current_time = time.time()
+ write_starving = (
+ self._waiting_writers > 0
+ and current_time - self._last_write_time > self._write_starvation_threshold
+ )
+
+ self._readers += 1
+ return True
+
+ def release_read(self):
+ """Release read lock"""
+ with self._lock:
+ self._readers -= 1
+ if self._readers == 0:
+ self._write_ready.notify()
+
+ def acquire_write(self) -> bool:
+ """Write lock acquisition"""
+ with self._lock:
+ self._waiting_writers += 1
+ try:
+ while self._readers > 0 or self._writers > 0:
+ self._write_ready.wait()
+
+ self._writers = 1
+ self._waiting_writers -= 1
+ self._last_write_time = time.time()
+ return True
+ except:
+ self._waiting_writers -= 1
+ raise
+
+ def release_write(self):
+ """Release write lock"""
+ with self._lock:
+ self._writers = 0
+ # Prioritize notifying readers (reader priority strategy)
+ self._read_ready.notify_all()
+ self._write_ready.notify()
+
+
+class SegmentedLock:
+ """Segmented lock, segments based on key hash"""
+
+ def __init__(self, segment_count: int = 64):
+ self.segment_count = segment_count
+ self.locks = [FastReadWriteLock() for _ in range(segment_count)]
+
+ def get_lock(self, key: K) -> FastReadWriteLock:
+ """Get the corresponding lock based on key"""
+ segment = hash(key) % self.segment_count
+ return self.locks[segment]
+
+ @contextmanager
+ def read_lock(self, key: K):
+ """Read lock context manager"""
+ lock = self.get_lock(key)
+ lock.acquire_read()
+ try:
+ yield
+ finally:
+ lock.release_read()
+
+ @contextmanager
+ def write_lock(self, key: K):
+ """Write lock context manager"""
+ lock = self.get_lock(key)
+ lock.acquire_write()
+ try:
+ yield
+ finally:
+ lock.release_write()
+
+
+class OptimizedThreadSafeDict(Generic[K, V]):
+ """
+ Thread-safe dictionary optimized for FastAPI scenarios:
+ - Segmented locks to reduce contention
+ - Reader priority with writer starvation prevention
+ - Support for large object storage
+ - Strong consistency guarantee
+ """
+
+ def __init__(
+ self, initial_dict: dict[K, V] | None = None, segment_count: int = 128
+ ): # More segments for high concurrency
+ self._segments: list[dict[K, V]] = [{} for _ in range(segment_count)]
+ self._segment_count = segment_count
+ self._segmented_lock = SegmentedLock(segment_count)
+
+ # Initialize data
+ if initial_dict:
+ for k, v in initial_dict.items():
+ segment_idx = self._get_segment(k)
+ self._segments[segment_idx][k] = v
+
+ def _get_segment(self, key: K) -> int:
+ """Calculate the segment corresponding to the key"""
+ return hash(key) % self._segment_count
+
+ def __getitem__(self, key: K) -> V:
+ """Get element"""
+ segment_idx = self._get_segment(key)
+ with self._segmented_lock.read_lock(key):
+ return self._segments[segment_idx][key]
+
+ def __setitem__(self, key: K, value: V) -> None:
+ """Set element - key optimization point"""
+ segment_idx = self._get_segment(key)
+ with self._segmented_lock.write_lock(key):
+ self._segments[segment_idx][key] = value
+
+ def __delitem__(self, key: K) -> None:
+ """Delete element"""
+ segment_idx = self._get_segment(key)
+ with self._segmented_lock.write_lock(key):
+ del self._segments[segment_idx][key]
+
+ def __contains__(self, key: K) -> bool:
+ """Check if key is contained"""
+ segment_idx = self._get_segment(key)
+ with self._segmented_lock.read_lock(key):
+ return key in self._segments[segment_idx]
+
+ def get(self, key: K, default: V | None = None) -> V | None:
+ """Safely get element"""
+ segment_idx = self._get_segment(key)
+ with self._segmented_lock.read_lock(key):
+ return self._segments[segment_idx].get(key, default)
+
+ def pop(self, key: K, *args) -> V:
+ """Pop element"""
+ segment_idx = self._get_segment(key)
+ with self._segmented_lock.write_lock(key):
+ return self._segments[segment_idx].pop(key, *args)
+
+ def setdefault(self, key: K, default: V | None = None) -> V:
+ """Set default value"""
+ segment_idx = self._get_segment(key)
+ with self._segmented_lock.write_lock(key):
+ return self._segments[segment_idx].setdefault(key, default)
+
+ def update(self, other=None, **kwargs) -> None:
+ """Batch update - optimized batch operation"""
+ items = (other.items() if hasattr(other, "items") else other) if other is not None else []
+
+ # Group update items by segment
+ segment_updates: dict[int, list[tuple[K, V]]] = {}
+
+ for k, v in items:
+ segment_idx = self._get_segment(k)
+ if segment_idx not in segment_updates:
+ segment_updates[segment_idx] = []
+ segment_updates[segment_idx].append((k, v))
+
+ for k, v in kwargs.items():
+ segment_idx = self._get_segment(k)
+ if segment_idx not in segment_updates:
+ segment_updates[segment_idx] = []
+ segment_updates[segment_idx].append((k, v))
+
+ # Update segment by segment to reduce lock holding time
+ for segment_idx, updates in segment_updates.items():
+ # Use the first key to get the lock (all keys in the same segment map to the same lock)
+ first_key = updates[0][0]
+ with self._segmented_lock.write_lock(first_key):
+ for k, v in updates:
+ self._segments[segment_idx][k] = v
+
+ def clear(self) -> None:
+ """Clear all elements - need to acquire all locks"""
+ # Acquire all locks in order to avoid deadlock
+ acquired_locks = []
+ try:
+ for i in range(self._segment_count):
+ lock = self._segmented_lock.locks[i]
+ lock.acquire_write()
+ acquired_locks.append(lock)
+
+ # Clear all segments
+ for segment in self._segments:
+ segment.clear()
+
+ finally:
+ # Release locks in reverse order
+ for lock in reversed(acquired_locks):
+ lock.release_write()
+
+ def __len__(self) -> int:
+ """Get total length - snapshot read"""
+ total = 0
+ acquired_locks = []
+ try:
+ # Acquire all read locks
+ for i in range(self._segment_count):
+ lock = self._segmented_lock.locks[i]
+ lock.acquire_read()
+ acquired_locks.append(lock)
+
+ # Calculate total length
+ for segment in self._segments:
+ total += len(segment)
+
+ return total
+
+ finally:
+ # Release all read locks
+ for lock in reversed(acquired_locks):
+ lock.release_read()
+
+ def __bool__(self) -> bool:
+ """Check if empty"""
+ return len(self) > 0
+
+ def keys(self) -> list[K]:
+ """Get snapshot of all keys"""
+ all_keys = []
+ acquired_locks = []
+
+ try:
+ # Acquire all read locks
+ for i in range(self._segment_count):
+ lock = self._segmented_lock.locks[i]
+ lock.acquire_read()
+ acquired_locks.append(lock)
+
+ # Collect all keys
+ for segment in self._segments:
+ all_keys.extend(segment.keys())
+
+ return all_keys
+
+ finally:
+ for lock in reversed(acquired_locks):
+ lock.release_read()
+
+ def values(self) -> list[V]:
+ """Get snapshot of all values"""
+ all_values = []
+ acquired_locks = []
+
+ try:
+ for i in range(self._segment_count):
+ lock = self._segmented_lock.locks[i]
+ lock.acquire_read()
+ acquired_locks.append(lock)
+
+ for segment in self._segments:
+ all_values.extend(segment.values())
+
+ return all_values
+
+ finally:
+ for lock in reversed(acquired_locks):
+ lock.release_read()
+
+ def items(self) -> list[tuple[K, V]]:
+ """Get snapshot of all items"""
+ all_items = []
+ acquired_locks = []
+
+ try:
+ for i in range(self._segment_count):
+ lock = self._segmented_lock.locks[i]
+ lock.acquire_read()
+ acquired_locks.append(lock)
+
+ for segment in self._segments:
+ all_items.extend(segment.items())
+
+ return all_items
+
+ finally:
+ for lock in reversed(acquired_locks):
+ lock.release_read()
+
+ def copy(self) -> dict[K, V]:
+ """Create dictionary copy"""
+ result = {}
+ acquired_locks = []
+
+ try:
+ for i in range(self._segment_count):
+ lock = self._segmented_lock.locks[i]
+ lock.acquire_read()
+ acquired_locks.append(lock)
+
+ for segment in self._segments:
+ result.update(segment)
+
+ return result
+
+ finally:
+ for lock in reversed(acquired_locks):
+ lock.release_read()
+
+ def __iter__(self) -> Iterator[K]:
+ """Iterator - returns snapshot"""
+ return iter(self.keys())
+
+ def __repr__(self) -> str:
+ """String representation"""
+ return f"OptimizedThreadSafeDict({dict(self.items())})"
+
+ def stats(self) -> dict[str, Any]:
+ """Get statistics"""
+ segment_sizes = []
+ total_items = 0
+
+ acquired_locks = []
+ try:
+ for i in range(self._segment_count):
+ lock = self._segmented_lock.locks[i]
+ lock.acquire_read()
+ acquired_locks.append(lock)
+
+ for segment in self._segments:
+ size = len(segment)
+ segment_sizes.append(size)
+ total_items += size
+
+ avg_size = total_items / self._segment_count if self._segment_count > 0 else 0
+ max_size = max(segment_sizes) if segment_sizes else 0
+ min_size = min(segment_sizes) if segment_sizes else 0
+
+ return {
+ "total_items": total_items,
+ "segment_count": self._segment_count,
+ "avg_segment_size": avg_size,
+ "max_segment_size": max_size,
+ "min_segment_size": min_size,
+ "load_balance_ratio": min_size / max_size if max_size > 0 else 1.0,
+ }
+
+ finally:
+ for lock in reversed(acquired_locks):
+ lock.release_read()
diff --git a/src/memos/parsers/factory.py b/src/memos/parsers/factory.py
index ddfe4401..9fedc28c 100644
--- a/src/memos/parsers/factory.py
+++ b/src/memos/parsers/factory.py
@@ -1,6 +1,7 @@
from typing import Any, ClassVar
from memos.configs.parser import ParserConfigFactory
+from memos.memos_tools.singleton import singleton_factory
from memos.parsers.base import BaseParser
from memos.parsers.markitdown import MarkItDownParser
@@ -11,6 +12,7 @@ class ParserFactory(BaseParser):
backend_to_class: ClassVar[dict[str, Any]] = {"markitdown": MarkItDownParser}
@classmethod
+ @singleton_factory()
def from_config(cls, config_factory: ParserConfigFactory) -> BaseParser:
backend = config_factory.backend
if backend not in cls.backend_to_class:
diff --git a/src/memos/reranker/concat.py b/src/memos/reranker/concat.py
new file mode 100644
index 00000000..5ad33952
--- /dev/null
+++ b/src/memos/reranker/concat.py
@@ -0,0 +1,59 @@
+import re
+
+from typing import Any
+
+
+_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")
+
+
+def process_source(
+ items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None, recent_num: int = 3
+) -> str:
+ """
+ Args:
+ items: List of tuples where each tuple contains (memory, source).
+ source can be str, Dict, or List.
+ recent_num: Number of recent items to concatenate (accepted for future use; not applied by the current implementation).
+ Returns:
+ str: Concatenated source.
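+ Example (illustrative values):
+ process_source([("memA", ["user: hi", "assistant: sure"])])
+ returns "memA\nuser: hi" -- "assistant:" turns are skipped.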
+ """
+ if items is None:
+ items = []
+ concat_data = []
+ memory = None
+ for item in items:
+ memory, source = item
+ for content in source:
+ if isinstance(content, str):
+ if "assistant:" in content:
+ continue
+ concat_data.append(content)
+ if memory is not None:
+ concat_data = [memory, *concat_data]
+ return "\n".join(concat_data)
+
+
+def concat_original_source(
+ graph_results: list,
+ merge_field: list[str] | None = None,
+) -> list[str]:
+ """
+ Merge memory items with original dialogue.
+ Args:
+ graph_results (list[TextualMemoryItem]): List of memory items with embeddings.
+ merge_field (List[str]): List of fields to merge.
+ Returns:
+ list[str]: Memories concatenated with their original source content.
+ """
+ if merge_field is None:
+ merge_field = ["sources"]
+ documents = []
+ for item in graph_results:
+ memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m
+ sources = []
+ for field in merge_field:
+ source = getattr(item.metadata, field, "")
+ sources.append((memory, source))
+ concat_string = process_source(sources)
+ documents.append(concat_string)
+ return documents
diff --git a/src/memos/reranker/cosine_local.py b/src/memos/reranker/cosine_local.py
index 39f44b9b..000b64cf 100644
--- a/src/memos/reranker/cosine_local.py
+++ b/src/memos/reranker/cosine_local.py
@@ -49,6 +49,7 @@ def __init__(
self,
level_weights: dict[str, float] | None = None,
level_field: str = "background",
+ **kwargs,
):
self.level_weights = level_weights or {"topic": 1.0, "concept": 1.0, "fact": 1.0}
self.level_field = level_field
diff --git a/src/memos/reranker/factory.py b/src/memos/reranker/factory.py
index 244b6928..134e29eb 100644
--- a/src/memos/reranker/factory.py
+++ b/src/memos/reranker/factory.py
@@ -3,6 +3,9 @@
from typing import TYPE_CHECKING, Any
+# Import singleton decorator
+from memos.memos_tools.singleton import singleton_factory
+
from .cosine_local import CosineLocalReranker
from .http_bge import HTTPBGEReranker
from .noop import NoopReranker
@@ -16,6 +19,7 @@
class RerankerFactory:
@staticmethod
+ @singleton_factory("RerankerFactory")
def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
if not cfg:
return None
@@ -29,6 +33,7 @@ def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
model=c.get("model", "bge-reranker-v2-m3"),
timeout=int(c.get("timeout", 10)),
headers_extra=c.get("headers_extra"),
+ rerank_source=c.get("rerank_source"),
)
if backend in {"cosine_local", "cosine"}:
diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py
index a852f325..f0f5d17a 100644
--- a/src/memos/reranker/http_bge.py
+++ b/src/memos/reranker/http_bge.py
@@ -3,22 +3,74 @@
import re
-from typing import TYPE_CHECKING
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any
import requests
+from memos.log import get_logger
+
from .base import BaseReranker
+from .concat import concat_original_source
+
+
+logger = get_logger(__name__)
if TYPE_CHECKING:
from memos.memories.textual.item import TextualMemoryItem
+# Strip a leading "[...]" tag (e.g., "[2025-09-01] ..." or "[meta] ...")
+# before sending text to the reranker. This keeps inputs clean and
+# avoids misleading the model with bracketed prefixes.
_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")
+DEFAULT_BOOST_WEIGHTS = {"user_id": 0.5, "tags": 0.2, "session_id": 0.3}
+
+
+def _value_matches(item_value: Any, wanted: Any) -> bool:
+ """
+ Generic matching:
+ - if item_value is list/tuple/set: check membership (any match if wanted is iterable)
+ - else: equality (any match if wanted is iterable)
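+
+ Examples (illustrative):
+ _value_matches(["a", "b"], "a") -> True (membership)
+ _value_matches("u1", ["u1", "u2"]) -> True (any-of equality)
+ _value_matches("u1", "u2") -> False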
+ """
+
+ def _iterable(x):
+ # exclude strings from "iterable"
+ return isinstance(x, Iterable) and not isinstance(x, str | bytes)
+
+ if _iterable(item_value):
+ if _iterable(wanted):
+ return any(w in item_value for w in wanted)
+ return wanted in item_value
+ else:
+ if _iterable(wanted):
+ return any(item_value == w for w in wanted)
+ return item_value == wanted
class HTTPBGEReranker(BaseReranker):
"""
- HTTP-based BGE reranker. Mirrors your old MemoryReranker, but configurable.
+ HTTP-based BGE reranker.
+
+ This class sends (query, documents[]) to a remote HTTP endpoint that
+ performs cross-encoder-style re-ranking (e.g., BGE reranker) and returns
+ relevance scores. It then maps those scores back onto the original
+ TextualMemoryItem list and returns (item, score) pairs sorted by score.
+
+ Notes
+ -----
+ - The endpoint is expected to accept JSON:
+ {
+ "model": "",
+ "query": "",
+ "documents": ["doc1", "doc2", ...]
+ }
+ - Two response shapes are supported:
+ 1) {"results": [{"index": , "relevance_score": }, ...]}
+ where "index" refers to the *position in the documents array*.
+ 2) {"data": [{"score": }, ...]} (aligned by list order)
+ - If the service fails or responds unexpectedly, this falls back to
+ returning the original items with 0.0 scores (best-effort).
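+ - Illustrative round trip (values are made up):
+ request: {"model": "bge-reranker-v2-m3", "query": "capital of France",
+ "documents": ["Paris is ...", "The Alps are ..."]}
+ response: {"results": [{"index": 0, "relevance_score": 0.93},
+ {"index": 1, "relevance_score": 0.12}]}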
"""
def __init__(
@@ -28,7 +80,26 @@ def __init__(
model: str = "bge-reranker-v2-m3",
timeout: int = 10,
headers_extra: dict | None = None,
+ rerank_source: list[str] | None = None,
+ boost_weights: dict[str, float] | None = None,
+ boost_default: float = 0.0,
+ warn_unknown_filter_keys: bool = True,
+ **kwargs,
):
+ """
+ Parameters
+ ----------
+ reranker_url : str
+ HTTP endpoint for the reranker service.
+ token : str, optional
+ Bearer token for auth. If non-empty, added to the Authorization header.
+ model : str, optional
+ Model identifier understood by the server.
+ timeout : int, optional
+ Request timeout (seconds).
+ headers_extra : dict | None, optional
+ Additional headers to merge into the request headers.
+ rerank_source : list[str] | None, optional
+ Metadata fields whose original source content is concatenated with each
+ memory before reranking (see `concat_original_source`).
+ boost_weights : dict[str, float] | None, optional
+ Per-key multiplicative boost weights applied when `search_filter` matches.
+ boost_default : float, optional
+ Weight used for filter keys not present in `boost_weights`.
+ warn_unknown_filter_keys : bool, optional
+ If True, log a one-time warning for filter keys that cannot be resolved.
+ """
if not reranker_url:
raise ValueError("reranker_url must not be empty")
self.reranker_url = reranker_url
@@ -36,22 +107,62 @@ def __init__(
self.model = model
self.timeout = timeout
self.headers_extra = headers_extra or {}
+ self.concat_source = rerank_source
+
+ self.boost_weights = (
+ DEFAULT_BOOST_WEIGHTS.copy()
+ if boost_weights is None
+ else {k: float(v) for k, v in boost_weights.items()}
+ )
+ self.boost_default = float(boost_default)
+ self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys)
+ self._warned_missing_keys: set[str] = set()
def rerank(
self,
query: str,
- graph_results: list,
+ graph_results: list[TextualMemoryItem],
top_k: int,
+ search_filter: dict | None = None,
**kwargs,
) -> list[tuple[TextualMemoryItem, float]]:
+ """
+ Rank candidate memories by relevance to the query.
+
+ Parameters
+ ----------
+ query : str
+ The search query.
+ graph_results : list[TextualMemoryItem]
+ Candidate items to re-rank. Each item is expected to have a
+ `.memory` str field; non-strings are ignored.
+ top_k : int
+ Return at most this many items.
+ search_filter : dict | None
+ Optional metadata filters; matching keys boost item scores (see `_apply_boost_generic`).
+
+ Returns
+ -------
+ list[tuple[TextualMemoryItem, float]]
+ Re-ranked items with scores, sorted descending by score.
+ """
if not graph_results:
return []
- documents = [
- (_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
- for item in graph_results
- ]
- documents = [d for d in documents if isinstance(d, str) and d]
+ # Build the documents payload sent to the reranker. In the plain-memory path,
+ # only items whose memory is a non-empty string are kept, so the indices the
+ # server returns are assumed to line up with graph_results.
+ if self.concat_source:
+ documents = concat_original_source(graph_results, self.concat_source)
+ else:
+ documents = [
+ (_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
+ for item in graph_results
+ ]
+ documents = [d for d in documents if isinstance(d, str) and d]
+
+ logger.info(f"[HTTPBGERerankerSample] query: {query} , documents: {documents[:5]}...")
+
if not documents:
return []
@@ -59,6 +170,7 @@ def rerank(
payload = {"model": self.model, "query": query, "documents": documents}
try:
+ # Make the HTTP request to the reranker service
resp = requests.post(
self.reranker_url, headers=headers, json=payload, timeout=self.timeout
)
@@ -68,18 +180,28 @@ def rerank(
scored_items: list[tuple[TextualMemoryItem, float]] = []
if "results" in data:
+ # Format:
+ # dict("results": [{"index": int, "relevance_score": float},
+ # ...])
rows = data.get("results", [])
for r in rows:
idx = r.get("index")
+ # The returned index refers to the order of 'documents',
+ # so map it back to the original graph_results index by position.
if isinstance(idx, int) and 0 <= idx < len(graph_results):
- score = float(r.get("relevance_score", r.get("score", 0.0)))
- scored_items.append((graph_results[idx], score))
+ raw_score = float(r.get("relevance_score", r.get("score", 0.0)))
+ item = graph_results[idx]
+ # generic boost
+ score = self._apply_boost_generic(item, raw_score, search_filter)
+ scored_items.append((item, score))
scored_items.sort(key=lambda x: x[1], reverse=True)
return scored_items[: min(top_k, len(scored_items))]
elif "data" in data:
+ # Format: {"data": [{"score": float}, ...]} aligned by list order
rows = data.get("data", [])
+ # Build a list of scores aligned with our 'documents' order
score_list = [float(r.get("score", 0.0)) for r in rows]
if len(score_list) < len(graph_results):
@@ -87,13 +209,104 @@ def rerank(
elif len(score_list) > len(graph_results):
score_list = score_list[: len(graph_results)]
- scored_items = list(zip(graph_results, score_list, strict=False))
+ scored_items = []
+ for item, raw_score in zip(graph_results, score_list, strict=False):
+ score = self._apply_boost_generic(item, raw_score, search_filter)
+ scored_items.append((item, score))
+
scored_items.sort(key=lambda x: x[1], reverse=True)
return scored_items[: min(top_k, len(scored_items))]
else:
+ # Unexpected response schema: fall back to returning the first top_k items
+ # with 0.0 scores.
return [(item, 0.0) for item in graph_results[:top_k]]
except Exception as e:
- print(f"[HTTPBGEReranker] request failed: {e}")
+ # Network error, timeout, JSON decode error, etc.
+ # Degrade gracefully by returning first top_k valid docs with 0.0 score.
+ logger.error(f"[HTTPBGEReranker] request failed: {e}")
return [(item, 0.0) for item in graph_results[:top_k]]
+
+ def _get_attr_or_key(self, obj: Any, key: str) -> Any:
+ """
+ Resolve `key` on `obj` with one-level fallback into `obj.metadata`.
+
+ Priority:
+ 1) obj.<key>
+ 2) obj[key]
+ 3) obj.metadata.<key>
+ 4) obj.metadata[key]
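+
+ Example (hypothetical item): `_get_attr_or_key(item, "metadata.user_id")` first
+ resolves `item.metadata`, then looks up `user_id` on that object or mapping.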
+ """
+ if obj is None:
+ return None
+
+ # support input like "metadata.user_id"
+ if "." in key:
+ head, tail = key.split(".", 1)
+ base = self._get_attr_or_key(obj, head)
+ return self._get_attr_or_key(base, tail)
+
+ def _resolve(o: Any, k: str):
+ if o is None:
+ return None
+ v = getattr(o, k, None)
+ if v is not None:
+ return v
+ if hasattr(o, "get"):
+ try:
+ return o.get(k)
+ except Exception:
+ return None
+ return None
+
+ # 1) find in obj
+ v = _resolve(obj, key)
+ if v is not None:
+ return v
+
+ # 2) find in obj.metadata
+ meta = _resolve(obj, "metadata")
+ if meta is not None:
+ return _resolve(meta, key)
+
+ return None
+
+ def _apply_boost_generic(
+ self,
+ item: TextualMemoryItem,
+ base_score: float,
+ search_filter: dict | None,
+ ) -> float:
+ """
+ Multiply base_score by (1 + weight) for each matching key in search_filter.
+ - key resolution: self._get_attr_or_key(item, key)
+ - weight = boost_weights.get(key, self.boost_default)
+ - unknown key -> one-time warning
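+
+ Worked example (default weights): with base_score=0.6 and a matching
+ "user_id" (weight 0.5), the score becomes 0.6 * 1.5 = 0.9; a further matching
+ "session_id" (weight 0.3) gives 0.9 * 1.3 = 1.17, clamped to 1.0.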
+ """
+ if not search_filter:
+ return base_score
+
+ score = float(base_score)
+
+ for key, wanted in search_filter.items():
+ # _get_attr_or_key looks the key up on the item and on
+ # item.metadata (dotted paths like "metadata.user_id" are supported)
+ resolved = self._get_attr_or_key(item, key)
+
+ if resolved is None:
+ if self.warn_unknown_filter_keys and key not in self._warned_missing_keys:
+ logger.warning(
+ "[HTTPBGEReranker] search_filter key '%s' not found on TextualMemoryItem or metadata",
+ key,
+ )
+ self._warned_missing_keys.add(key)
+ continue
+
+ if _value_matches(resolved, wanted):
+ w = float(self.boost_weights.get(key, self.boost_default))
+ if w != 0.0:
+ score *= 1.0 + w
+ score = min(max(0.0, score), 1.0)
+
+ return score
diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py
index a1fa5324..b4d091c1 100644
--- a/src/memos/templates/mem_scheduler_prompts.py
+++ b/src/memos/templates/mem_scheduler_prompts.py
@@ -151,11 +151,253 @@
Answer:
"""
+MEMORY_FILTERING_PROMPT = """
+# Memory Relevance Filtering Task
+
+## Role
+You are an intelligent memory filtering system. Your primary function is to analyze memory relevance and filter out memories that are completely unrelated to the user's query history.
+
+## Task Description
+Analyze the provided memories and determine which ones are relevant to the user's query history:
+1. Evaluate semantic relationship between each memory and the query history
+2. Identify memories that are completely unrelated or irrelevant
+3. Filter out memories that don't contribute to answering the queries
+4. Preserve memories that provide context, evidence, or relevant information
+
+## Relevance Criteria
+A memory is considered RELEVANT if it:
+- Directly answers questions from the query history
+- Provides context or background information related to the queries
+- Contains information that could be useful for understanding the queries
+- Shares semantic similarity with query topics or themes
+- Contains keywords or concepts mentioned in the queries
+
+A memory is considered IRRELEVANT if it:
+- Has no semantic connection to any query in the history
+- Discusses completely unrelated topics
+- Contains information that cannot help answer any query
+- Is too generic or vague to be useful
+
+## Input Format
+- Query History: List of user queries (chronological order)
+- Memories: List of memory texts to be evaluated
+
+## Output Format Requirements
+You MUST output a valid JSON object with EXACTLY the following structure:
+{{
+ "relevant_memories": [array_of_memory_indices],
+ "filtered_count": ,
+ "reasoning": "string_explanation"
+}}
+
+## Important Notes:
+- Only output the JSON object, nothing else
+- Do not include any markdown formatting or code block notation
+- Ensure all brackets and quotes are properly closed
+- The output must be parseable by a JSON parser
+- Memory indices should correspond to the input order (0-based)
+
+## Processing Guidelines
+1. Be conservative in filtering - when in doubt, keep the memory
+2. Consider both direct and indirect relevance
+3. Look for thematic connections, not just exact keyword matches
+4. Preserve memories that provide valuable context
+
+## Current Task
+Query History: {query_history}
+Memories to Filter: {memories}
+
+Please provide your filtering analysis:
+"""
+
+MEMORY_REDUNDANCY_FILTERING_PROMPT = """
+# Memory Redundancy Filtering Task
+
+## Role
+You are an intelligent memory optimization system. Your primary function is to analyze memories and remove redundancy to improve memory quality and relevance.
+
+## Task Description
+Analyze the provided memories and identify redundant ones:
+1. **Redundancy Detection**: Find memories that contain the same core facts relevant to queries
+2. **Best Memory Selection**: Keep only the most concise and focused version of redundant information
+3. **Quality Preservation**: Ensure the final set covers all necessary information without redundancy
+
+## Redundancy Detection Criteria
+A memory is considered REDUNDANT if it:
+- Contains the same core fact as another memory that's relevant to the queries
+- Provides the same information but with additional irrelevant details
+- Repeats information that's already covered by a more concise memory
+- Has overlapping content with another memory that serves the same purpose
+
+When redundancy is found, KEEP the memory that:
+- Is more concise and focused
+- Contains less irrelevant information
+- Is more directly relevant to the queries
+- Has higher information density
+
+## Input Format
+- Query History: List of user queries (chronological order)
+- Memories: List of memory texts to be evaluated
+
+## Output Format Requirements
+You MUST output a valid JSON object with EXACTLY the following structure:
+{{
+ "kept_memories": [array_of_memory_indices_to_keep],
+ "redundant_groups": [
+ {{
+ "group_id": ,
+ "memories": [array_of_redundant_memory_indices],
+ "kept_memory": ,
+ "reason": "explanation_of_why_this_memory_was_kept"
+ }}
+ ],
+ "reasoning": "string_explanation_of_filtering_decisions"
+}}
+
+## Important Notes:
+- Only output the JSON object, nothing else
+- Do not include any markdown formatting or code block notation
+- Ensure all brackets and quotes are properly closed
+- The output must be parseable by a JSON parser
+- Memory indices should correspond to the input order (0-based)
+- Be conservative in filtering - when in doubt, keep the memory
+- Focus on semantic similarity, not just exact text matches
+
+## Processing Guidelines
+1. First identify which memories are relevant to the queries
+2. Group relevant memories by semantic similarity and core facts
+3. Within each group, select the best memory (most concise, least noise)
+4. Ensure the final set covers all necessary information without redundancy
+
+## Current Task
+Query History: {query_history}
+Memories to Filter: {memories}
+
+Please provide your redundancy filtering analysis:
+"""
+
+MEMORY_COMBINED_FILTERING_PROMPT = """
+# Memory Combined Filtering Task
+
+## Role
+You are an intelligent memory optimization system. Your primary function is to analyze memories and perform two types of filtering in sequence:
+1. **Unrelated Memory Removal**: Remove memories that are completely unrelated to the user's query history
+2. **Redundancy Removal**: Remove redundant memories by keeping only the most informative version
+
+## Task Description
+Analyze the provided memories and perform comprehensive filtering:
+1. **First Step - Unrelated Filtering**: Identify and remove memories that have no semantic connection to any query
+2. **Second Step - Redundancy Filtering**: Group similar memories and keep only the best version from each group
+
+## Unrelated Memory Detection Criteria
+A memory is considered UNRELATED if it:
+- Has no semantic connection to any query in the history
+- Discusses completely unrelated topics
+- Contains information that cannot help answer any query
+- Is too generic or vague to be useful
+
+## Redundancy Detection Criteria
+A memory is considered REDUNDANT if it:
+- Contains the same core fact as another memory that's relevant to the queries
+- Provides the same information but with additional irrelevant details
+- Repeats information that's already covered by a more concise memory
+- Has overlapping content with another memory that serves the same purpose
+
+When redundancy is found, KEEP the memory that:
+- Is more concise and focused
+- Contains less irrelevant information
+- Is more directly relevant to the queries
+- Has higher information density
+
+## Input Format
+- Query History: List of user queries (chronological order)
+- Memories: List of memory texts to be evaluated
+
+## Output Format Requirements
+You MUST output a valid JSON object with EXACTLY the following structure:
+{{
+ "kept_memories": [array_of_memory_indices_to_keep],
+ "unrelated_removed_count": ,
+ "redundant_removed_count": ,
+ "redundant_groups": [
+ {{
+ "group_id": ,
+ "memories": [array_of_redundant_memory_indices],
+ "kept_memory": ,
+ "reason": "explanation_of_why_this_memory_was_kept"
+ }}
+ ],
+ "reasoning": "string_explanation_of_filtering_decisions"
+}}
+
+## Important Notes:
+- Only output the JSON object, nothing else
+- Do not include any markdown formatting or code block notation
+- Ensure all brackets and quotes are properly closed
+- The output must be parseable by a JSON parser
+- Memory indices should correspond to the input order (0-based)
+- Be conservative in filtering - when in doubt, keep the memory
+- Focus on semantic similarity, not just exact text matches
+
+## Processing Guidelines
+1. **First, identify unrelated memories** and mark them for removal
+2. **Then, group remaining memories** by semantic similarity and core facts
+3. **Within each group, select the best memory** (most concise, least noise)
+4. **Ensure the final set covers all necessary information** without redundancy
+5. **Count how many memories were removed** for each reason
+
+## Current Task
+Query History: {query_history}
+Memories to Filter: {memories}
+
+Please provide your combined filtering analysis:
+"""
+
+
+MEMORY_ANSWER_ABILITY_EVALUATION_PROMPT = """
+# Memory Answer Ability Evaluation Task
+
+## Task
+Evaluate whether the provided memories contain sufficient information to answer the user's query.
+
+## Evaluation Criteria
+Consider these factors:
+1. **Answer completeness**: Do the memories cover all aspects of the query?
+2. **Evidence relevance**: Do the memories directly support answering the query?
+3. **Detail specificity**: Do the memories contain necessary granularity?
+4. **Information gaps**: Are there obvious missing pieces of information?
+
+## Decision Rules
+- Return `True` for "result" ONLY when memories provide complete, relevant answers
+- Return `False` for "result" if memories are insufficient, irrelevant, or incomplete
+
+## User Query
+{query}
+
+## Available Memories
+{memory_list}
+
+## Required Output
+Return a JSON object with this exact structure:
+{{
+ "result": ,
+ "reason": ""
+}}
+
+## Instructions
+- Only output the JSON object, nothing else
+ - Be conservative: if there's any doubt about completeness, return `False`
+- Focus on whether the memories can fully answer the query without additional information
+"""
PROMPT_MAPPING = {
"intent_recognizing": INTENT_RECOGNIZING_PROMPT,
"memory_reranking": MEMORY_RERANKING_PROMPT,
"query_keywords_extraction": QUERY_KEYWORDS_EXTRACTION_PROMPT,
+ "memory_filtering": MEMORY_FILTERING_PROMPT,
+ "memory_redundancy_filtering": MEMORY_REDUNDANCY_FILTERING_PROMPT,
+ "memory_combined_filtering": MEMORY_COMBINED_FILTERING_PROMPT,
+ "memory_answer_ability_evaluation": MEMORY_ANSWER_ABILITY_EVALUATION_PROMPT,
}
MEMORY_ASSEMBLY_TEMPLATE = """The retrieved memories are listed as follows:\n\n {memory_text}"""
diff --git a/src/memos/types.py b/src/memos/types.py
index 0897cecd..60d5da8d 100644
--- a/src/memos/types.py
+++ b/src/memos/types.py
@@ -22,11 +22,14 @@
# Message structure
-class MessageDict(TypedDict):
+class MessageDict(TypedDict, total=False):
"""Typed dictionary for chat message dictionaries."""
role: MessageRole
content: str
+ chat_time: str | None # Optional timestamp for the message; the format is not
+ # restricted and may be any vague or precise time string.
+ message_id: str | None # Optional unique identifier for the message
# Message collections
diff --git a/tests/api/test_thread_context.py b/tests/api/test_thread_context.py
index 36da692f..97c395db 100644
--- a/tests/api/test_thread_context.py
+++ b/tests/api/test_thread_context.py
@@ -1,7 +1,12 @@
import time
-from memos.api.context.context import RequestContext, get_current_context, set_request_context
-from memos.api.context.context_thread import ContextThread, ContextThreadPoolExecutor
+from memos.context.context import (
+ ContextThread,
+ ContextThreadPoolExecutor,
+ RequestContext,
+ get_current_context,
+ set_request_context,
+)
from memos.log import get_logger
diff --git a/tests/mem_os/test_memos_core.py b/tests/mem_os/test_memos_core.py
index 2c873e5a..6d2408d0 100644
--- a/tests/mem_os/test_memos_core.py
+++ b/tests/mem_os/test_memos_core.py
@@ -682,41 +682,8 @@ def test_chat_without_memories(
# Verify response
assert response == "This is a test response from the assistant."
- @patch("memos.mem_os.core.UserManager")
- @patch("memos.mem_os.core.MemReaderFactory")
- @patch("memos.mem_os.core.LLMFactory")
- def test_clear_messages(
- self,
- mock_llm_factory,
- mock_reader_factory,
- mock_user_manager_class,
- mock_config,
- mock_llm,
- mock_mem_reader,
- mock_user_manager,
- ):
- """Test clearing chat history."""
- # Setup mocks
- mock_llm_factory.from_config.return_value = mock_llm
- mock_reader_factory.from_config.return_value = mock_mem_reader
- mock_user_manager_class.return_value = mock_user_manager
-
- mos = MOSCore(MOSConfig(**mock_config))
-
- # Add some chat history
- mos.chat_history_manager["test_user"].chat_history.append(
- {"role": "user", "content": "Hello"}
- )
- mos.chat_history_manager["test_user"].chat_history.append(
- {"role": "assistant", "content": "Hi"}
- )
-
- assert len(mos.chat_history_manager["test_user"].chat_history) == 2
-
- mos.clear_messages()
- assert len(mos.chat_history_manager["test_user"].chat_history) == 0
- assert mos.chat_history_manager["test_user"].user_id == "test_user"
+# TODO: test clear message
class TestMOSSystemPrompt:
diff --git a/tests/mem_reader/test_simple_structure.py b/tests/mem_reader/test_simple_structure.py
index 6048eee3..5407ae54 100644
--- a/tests/mem_reader/test_simple_structure.py
+++ b/tests/mem_reader/test_simple_structure.py
@@ -43,9 +43,9 @@ def test_init(self):
def test_process_chat_data(self):
"""Test processing chat data into memory items."""
scene_data_info = [
- "user: Hello",
- "assistant: Hi there",
- "user: How are you?",
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi there"},
+ {"role": "user", "content": "How are you?"},
]
info = {"user_id": "user1", "session_id": "session1"}
@@ -115,7 +115,14 @@ def test_get_scene_data_info_with_chat(self):
self.assertIsInstance(result, list)
self.assertEqual(len(result), 1)
- self.assertEqual(result[0][0], "user: [3 May 2025]: I'm feeling a bit down today.")
+ self.assertEqual(
+ result[0][0],
+ {
+ "role": "user",
+ "chat_time": "3 May 2025",
+ "content": "I'm feeling a bit down today.",
+ },
+ )
@patch("memos.mem_reader.simple_struct.ParserFactory")
def test_get_scene_data_info_with_doc(self, mock_parser_factory):
diff --git a/tests/mem_scheduler/test_config.py b/tests/mem_scheduler/test_config.py
new file mode 100644
index 00000000..b389220a
--- /dev/null
+++ b/tests/mem_scheduler/test_config.py
@@ -0,0 +1,319 @@
+import os
+import sys
+import unittest
+
+from pathlib import Path
+from tempfile import NamedTemporaryFile, TemporaryDirectory
+
+from memos.configs.mem_scheduler import AuthConfig, GraphDBAuthConfig, OpenAIConfig, RabbitMQConfig
+from memos.mem_scheduler.general_modules.misc import EnvConfigMixin
+from memos.mem_scheduler.utils.config_utils import convert_config_to_env, flatten_dict
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent
+sys.path.insert(0, str(BASE_DIR))
+
+ENV_PREFIX = EnvConfigMixin.ENV_PREFIX
+
+
+class TestEnvConfigMixin(unittest.TestCase):
+ """Tests specifically for the EnvConfigMixin functionality"""
+
+ def test_env_prefix_class_variable(self):
+ """Verify the base environment prefix is set correctly"""
+ self.assertEqual(EnvConfigMixin.ENV_PREFIX, "MEMSCHEDULER_")
+
+ def test_get_env_prefix_generation(self):
+ """Test the dynamic environment variable prefix generation"""
+ # Test GraphDBAuthConfig specifically since it's causing issues
+ self.assertEqual(
+ GraphDBAuthConfig.get_env_prefix(),
+ f"{ENV_PREFIX}GRAPHDBAUTH_", # Critical: This is the correct prefix!
+ )
+
+ # Verify other configs
+ self.assertEqual(RabbitMQConfig.get_env_prefix(), f"{ENV_PREFIX}RABBITMQ_")
+ self.assertEqual(OpenAIConfig.get_env_prefix(), f"{ENV_PREFIX}OPENAI_")
+
+
+class TestSchedulerConfig(unittest.TestCase):
+ def setUp(self):
+ self.env_backup = dict(os.environ)
+ self._clear_prefixed_env_vars()
+
+ def tearDown(self):
+ os.environ.clear()
+ os.environ.update(self.env_backup)
+
+ def _clear_prefixed_env_vars(self):
+ for key in list(os.environ.keys()):
+ if key.startswith(ENV_PREFIX):
+ del os.environ[key]
+
+ def test_loads_all_configs_from_env(self):
+ """Test loading all configurations from prefixed environment variables"""
+ os.environ.update(
+ {
+ # RabbitMQ configs
+ f"{ENV_PREFIX}RABBITMQ_HOST_NAME": "rabbit.test.com",
+ f"{ENV_PREFIX}RABBITMQ_USER_NAME": "test_user",
+ f"{ENV_PREFIX}RABBITMQ_PASSWORD": "test_pass",
+ f"{ENV_PREFIX}RABBITMQ_VIRTUAL_HOST": "test_vhost",
+ f"{ENV_PREFIX}RABBITMQ_ERASE_ON_CONNECT": "false",
+ f"{ENV_PREFIX}RABBITMQ_PORT": "5673",
+ # OpenAI configs
+ f"{ENV_PREFIX}OPENAI_API_KEY": "test_api_key",
+ f"{ENV_PREFIX}OPENAI_BASE_URL": "https://api.test.openai.com",
+ f"{ENV_PREFIX}OPENAI_DEFAULT_MODEL": "gpt-test",
+ # GraphDBAuthConfig configs - NOTE THE CORRECT PREFIX!
+ f"{ENV_PREFIX}GRAPHDBAUTH_URI": "bolt://test.db:7687",
+ f"{ENV_PREFIX}GRAPHDBAUTH_USER": "test_neo4j",
+ f"{ENV_PREFIX}GRAPHDBAUTH_PASSWORD": "test_db_pass_123", # 13 chars (valid)
+ f"{ENV_PREFIX}GRAPHDBAUTH_DB_NAME": "test_db",
+ f"{ENV_PREFIX}GRAPHDBAUTH_AUTO_CREATE": "false",
+ }
+ )
+
+ config = AuthConfig.from_local_env()
+
+ # Verify GraphDB configuration
+ self.assertEqual(config.graph_db.uri, "bolt://test.db:7687")
+ self.assertEqual(config.graph_db.user, "test_neo4j")
+ self.assertEqual(config.graph_db.password, "test_db_pass_123")
+ self.assertEqual(config.graph_db.db_name, "test_db")
+ self.assertFalse(config.graph_db.auto_create)
+
+ def test_uses_default_values_when_env_not_set(self):
+ """Test that default values are used when prefixed environment variables are not set"""
+ os.environ.update(
+ {
+ # RabbitMQ
+ f"{ENV_PREFIX}RABBITMQ_HOST_NAME": "rabbit.test.com",
+ # OpenAI
+ f"{ENV_PREFIX}OPENAI_API_KEY": "test_api_key",
+ # GraphDB - with correct prefix and valid password length
+ f"{ENV_PREFIX}GRAPHDBAUTH_URI": "bolt://test.db:7687",
+ f"{ENV_PREFIX}GRAPHDBAUTH_PASSWORD": "default_pass", # 11 chars (valid)
+ }
+ )
+
+ config = AuthConfig.from_local_env()
+
+ # Verify default values take effect
+ self.assertEqual(config.rabbitmq.port, 5672) # RabbitMQ default port
+ self.assertTrue(config.graph_db.auto_create) # GraphDB default auto-create
+
+ def test_raises_on_missing_required_variables(self):
+ """Test that exceptions are raised when required prefixed variables are missing"""
+ with self.assertRaises((ValueError, Exception)) as context:
+ AuthConfig.from_local_env()
+
+ error_msg = str(context.exception).lower()
+ self.assertTrue(
+ "missing" in error_msg or "validation" in error_msg or "required" in error_msg,
+ f"Error message does not meet expectations: {error_msg}",
+ )
+
+ def test_type_conversion(self):
+ """Test type conversion for prefixed environment variables"""
+ os.environ.update(
+ {
+ # RabbitMQ
+ f"{ENV_PREFIX}RABBITMQ_HOST_NAME": "rabbit.test.com",
+ f"{ENV_PREFIX}RABBITMQ_PORT": "1234",
+ f"{ENV_PREFIX}RABBITMQ_ERASE_ON_CONNECT": "yes",
+ # OpenAI
+ f"{ENV_PREFIX}OPENAI_API_KEY": "test_api_key",
+ # GraphDB - correct prefix and valid password
+ f"{ENV_PREFIX}GRAPHDBAUTH_URI": "bolt://test.db:7687",
+ f"{ENV_PREFIX}GRAPHDBAUTH_PASSWORD": "type_conv_pass", # 13 chars (valid)
+ f"{ENV_PREFIX}GRAPHDBAUTH_AUTO_CREATE": "0",
+ }
+ )
+
+ config = AuthConfig.from_local_env()
+
+ # Verify type conversion results
+ self.assertIsInstance(config.rabbitmq.port, int)
+ self.assertIsInstance(config.rabbitmq.erase_on_connect, bool)
+ self.assertIsInstance(config.graph_db.auto_create, bool)
+ self.assertTrue(config.rabbitmq.erase_on_connect)
+ self.assertFalse(config.graph_db.auto_create)
+
+ def test_combined_with_local_config(self):
+ """Test priority between prefixed environment variables and config files"""
+ with NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as f:
+ f.write("""
+ rabbitmq:
+ host_name: "file.rabbit.com"
+ port: 1234
+ openai:
+ api_key: "file_api_key"
+ graph_db:
+ uri: "bolt://file.db:7687"
+ password: "file_db_pass"
+ """)
+ config_file_path = f.name
+
+ try:
+ # Environment variables with correct prefixes
+ os.environ.update(
+ {
+ f"{ENV_PREFIX}RABBITMQ_HOST_NAME": "env.rabbit.com",
+ f"{ENV_PREFIX}OPENAI_API_KEY": "env_api_key",
+ f"{ENV_PREFIX}GRAPHDBAUTH_USER": "env_user",
+ f"{ENV_PREFIX}GRAPHDBAUTH_PASSWORD": "env_db_pass", # 11 chars (valid)
+ }
+ )
+
+ # 1. Test loading from config file
+ file_config = AuthConfig.from_local_config(Path(config_file_path))
+ self.assertEqual(file_config.rabbitmq.host_name, "file.rabbit.com")
+ self.assertEqual(file_config.rabbitmq.port, 1234)
+ self.assertEqual(file_config.openai.api_key, "file_api_key")
+ self.assertEqual(file_config.graph_db.password, "file_db_pass")
+
+ # 2. Test loading from environment variables
+ env_config = AuthConfig.from_local_env()
+ self.assertEqual(env_config.rabbitmq.host_name, "env.rabbit.com")
+ self.assertEqual(env_config.openai.api_key, "env_api_key")
+ self.assertEqual(env_config.graph_db.user, "env_user")
+ self.assertEqual(env_config.graph_db.password, "env_db_pass")
+ self.assertEqual(env_config.rabbitmq.port, 5672)
+
+ finally:
+ os.unlink(config_file_path)
+
+
+class TestConfigUtils(unittest.TestCase):
+ """Tests for config_utils functions: flatten_dict and convert_config_to_env"""
+
+ def test_flatten_dict_basic(self):
+ """Test basic dictionary flattening without prefix"""
+ input_dict = {"database": {"host": "localhost", "port": 5432}, "auth": {"enabled": True}}
+
+ expected = {"DATABASE_HOST": "localhost", "DATABASE_PORT": "5432", "AUTH_ENABLED": "True"}
+
+ self.assertEqual(flatten_dict(input_dict), expected)
+
+ def test_flatten_dict_with_prefix(self):
+ """Test dictionary flattening with a custom prefix"""
+ input_dict = {"rabbitmq": {"host": "rabbit.local"}}
+
+ expected = {"APP_RABBITMQ_HOST": "rabbit.local"}
+
+ self.assertEqual(flatten_dict(input_dict, prefix="app"), expected)
+
+ def test_flatten_dict_special_chars(self):
+ """Test handling of spaces and hyphens in keys"""
+ input_dict = {"my key": "value", "other-key": {"nested key": 123}}
+
+ expected = {"MY_KEY": "value", "OTHER_KEY_NESTED_KEY": "123"}
+
+ self.assertEqual(flatten_dict(input_dict), expected)
+
+ def test_flatten_dict_none_values(self):
+ """Test handling of None values"""
+ input_dict = {"optional": None, "required": "present"}
+
+ expected = {"OPTIONAL": "", "REQUIRED": "present"}
+
+ self.assertEqual(flatten_dict(input_dict), expected)
+
+ def test_convert_json_to_env(self):
+ """Test conversion from JSON to .env file"""
+ with TemporaryDirectory() as temp_dir:
+ input_path = os.path.join(temp_dir, "config.json")
+ output_path = os.path.join(temp_dir, ".env")
+
+ # Create test JSON file
+ with open(input_path, "w") as f:
+ f.write('{"server": {"port": 8080}, "debug": false}')
+
+ # Convert to .env
+ convert_config_to_env(input_path, output_path, prefix="app")
+
+ # Verify output
+ with open(output_path) as f:
+ content = f.read()
+
+ self.assertIn('APP_SERVER_PORT="8080"', content)
+ self.assertIn('APP_DEBUG="False"', content)
+
+ def test_convert_yaml_to_env(self):
+ """Test conversion from YAML to .env file"""
+ with TemporaryDirectory() as temp_dir:
+ input_path = os.path.join(temp_dir, "config.yaml")
+ output_path = os.path.join(temp_dir, ".env")
+
+ # Create test YAML file
+ with open(input_path, "w") as f:
+ f.write("""
+ database:
+ host: db.example.com
+ credentials:
+ user: admin
+ pass: secret
+ """)
+
+ # Convert to .env
+ convert_config_to_env(input_path, output_path)
+
+ # Verify output
+ with open(output_path) as f:
+ content = f.read()
+
+ self.assertIn('DATABASE_HOST="db.example.com"', content)
+ self.assertIn('DATABASE_CREDENTIALS_USER="admin"', content)
+ self.assertIn('DATABASE_CREDENTIALS_PASS="secret"', content)
+
+ def test_convert_with_special_values(self):
+ """Test conversion with values containing quotes and special characters"""
+ with TemporaryDirectory() as temp_dir:
+ input_path = os.path.join(temp_dir, "config.json")
+ output_path = os.path.join(temp_dir, ".env")
+
+ # Create test JSON with special values
+ with open(input_path, "w") as f:
+ f.write('{"description": "Hello \\"World\\"", "empty": null}')
+
+ # Convert to .env
+ convert_config_to_env(input_path, output_path)
+
+ # Verify output
+ with open(output_path) as f:
+ content = f.read()
+
+ # Values with double quotes should not have surrounding quotes
+ self.assertIn('DESCRIPTION=Hello "World"', content)
+ self.assertIn('EMPTY=""', content)
+
+ def test_unsupported_file_format(self):
+ """Test error handling for unsupported file formats"""
+ with TemporaryDirectory() as temp_dir:
+ input_path = os.path.join(temp_dir, "config.txt")
+ with open(input_path, "w") as f:
+ f.write("some content")
+
+ with self.assertRaises(ValueError) as context:
+ convert_config_to_env(input_path)
+
+ self.assertIn("Unsupported file format", str(context.exception))
+
+ def test_file_not_found(self):
+ """Test error handling for non-existent input file"""
+ with self.assertRaises(FileNotFoundError):
+ convert_config_to_env("non_existent_file.json")
+
+ def test_invalid_json(self):
+ """Test error handling for invalid JSON"""
+ with TemporaryDirectory() as temp_dir:
+ input_path = os.path.join(temp_dir, "bad.json")
+ with open(input_path, "w") as f:
+ f.write('{"invalid": json}') # Invalid JSON
+
+ with self.assertRaises(ValueError) as context:
+ convert_config_to_env(input_path)
+
+ self.assertIn("Error parsing file", str(context.exception))
diff --git a/tests/mem_scheduler/test_orm.py b/tests/mem_scheduler/test_orm.py
new file mode 100644
index 00000000..ddf4fea8
--- /dev/null
+++ b/tests/mem_scheduler/test_orm.py
@@ -0,0 +1,299 @@
+import os
+import tempfile
+import time
+
+from datetime import datetime, timedelta
+
+import pytest
+
+from memos.mem_scheduler.orm_modules.base_model import BaseDBManager
+
+# Import the classes to test
+from memos.mem_scheduler.orm_modules.monitor_models import (
+ DBManagerForMemoryMonitorManager,
+ DBManagerForQueryMonitorQueue,
+)
+from memos.mem_scheduler.schemas.monitor_schemas import (
+ MemoryMonitorItem,
+ MemoryMonitorManager,
+ QueryMonitorItem,
+ QueryMonitorQueue,
+)
+
+
+# Test data
+TEST_USER_ID = "test_user"
+TEST_MEM_CUBE_ID = "test_mem_cube"
+TEST_QUEUE_ID = "test_queue"
+
+
+class TestBaseDBManager:
+ """Base class for DBManager tests with common fixtures"""
+
+ @pytest.fixture
+ def temp_db(self):
+ """Create a temporary database for testing."""
+ temp_dir = tempfile.mkdtemp()
+ db_path = os.path.join(temp_dir, "test_scheduler_orm.db")
+ yield db_path
+ # Cleanup
+ try:
+ if os.path.exists(db_path):
+ os.remove(db_path)
+ os.rmdir(temp_dir)
+ except (OSError, PermissionError):
+ pass # Ignore cleanup errors (e.g., file locked on Windows)
+
+ @pytest.fixture
+ def memory_manager_obj(self):
+ """Create a MemoryMonitorManager object for testing"""
+ return MemoryMonitorManager(
+ user_id=TEST_USER_ID,
+ mem_cube_id=TEST_MEM_CUBE_ID,
+ items=[
+ MemoryMonitorItem(
+ item_id="custom-id-123",
+ memory_text="Full test memory",
+ tree_memory_item=None,
+ tree_memory_item_mapping_key="full_test_key",
+ keywords_score=0.8,
+ sorting_score=0.9,
+ importance_score=0.7,
+ recording_count=3,
+ )
+ ],
+ )
+
+ @pytest.fixture
+ def query_queue_obj(self):
+ """Create a QueryMonitorQueue object for testing"""
+ queue = QueryMonitorQueue()
+ queue.put(
+ QueryMonitorItem(
+ item_id="query1",
+ user_id=TEST_USER_ID,
+ mem_cube_id=TEST_MEM_CUBE_ID,
+ query_text="How are you?",
+ timestamp=datetime.now(),
+ keywords=["how", "you"],
+ )
+ )
+ return queue
+
+ @pytest.fixture
+ def query_monitor_manager(self, temp_db, query_queue_obj):
+ """Create DBManagerForQueryMonitorQueue instance with temp DB."""
+ engine = BaseDBManager.create_engine_from_db_path(temp_db)
+ manager = DBManagerForQueryMonitorQueue(
+ engine=engine,
+ user_id=TEST_USER_ID,
+ mem_cube_id=TEST_MEM_CUBE_ID,
+ obj=query_queue_obj,
+ lock_timeout=10,
+ )
+
+ assert manager.engine is not None
+ assert manager.SessionLocal is not None
+ assert os.path.exists(temp_db)
+
+ yield manager
+ manager.close()
+
+ @pytest.fixture
+ def memory_monitor_manager(self, temp_db, memory_manager_obj):
+ """Create DBManagerForMemoryMonitorManager instance with temp DB."""
+ engine = BaseDBManager.create_engine_from_db_path(temp_db)
+ manager = DBManagerForMemoryMonitorManager(
+ engine=engine,
+ user_id=TEST_USER_ID,
+ mem_cube_id=TEST_MEM_CUBE_ID,
+ obj=memory_manager_obj,
+ lock_timeout=10,
+ )
+
+ assert manager.engine is not None
+ assert manager.SessionLocal is not None
+ assert os.path.exists(temp_db)
+
+ yield manager
+ manager.close()
+
+ def test_save_and_load_query_queue(self, query_monitor_manager, query_queue_obj):
+ """Test saving and loading QueryMonitorQueue."""
+ # Save to database
+ query_monitor_manager.save_to_db(query_queue_obj)
+
+ # Load in a new manager
+ engine = BaseDBManager.create_engine_from_db_path(query_monitor_manager.engine.url.database)
+ new_manager = DBManagerForQueryMonitorQueue(
+ engine=engine,
+ user_id=TEST_USER_ID,
+ mem_cube_id=TEST_MEM_CUBE_ID,
+ obj=None,
+ lock_timeout=10,
+ )
+ loaded_queue = new_manager.load_from_db(acquire_lock=True)
+
+ assert loaded_queue is not None
+ items = loaded_queue.get_queue_content_without_pop()
+ assert len(items) == 1
+ assert items[0].item_id == "query1"
+ assert items[0].query_text == "How are you?"
+ new_manager.close()
+
+ def test_lock_mechanism(self, query_monitor_manager, query_queue_obj):
+ """Test lock acquisition and release."""
+ # Save current state
+ query_monitor_manager.save_to_db(query_queue_obj)
+
+ # Acquire lock
+ acquired = query_monitor_manager.acquire_lock(block=True)
+ assert acquired
+
+ # Try to acquire again (should fail without blocking)
+ assert not query_monitor_manager.acquire_lock(block=False)
+
+ # Release lock
+ query_monitor_manager.release_locks(
+ user_id=TEST_USER_ID,
+ mem_cube_id=TEST_MEM_CUBE_ID,
+ )
+
+ # Should be able to acquire again
+ assert query_monitor_manager.acquire_lock(block=False)
+
+ def test_lock_timeout(self, query_monitor_manager, query_queue_obj):
+ """Test lock timeout mechanism."""
+ # Save current state
+ query_monitor_manager.save_to_db(query_queue_obj)
+
+ query_monitor_manager.lock_timeout = 1
+
+ # Acquire lock
+ assert query_monitor_manager.acquire_lock(block=True)
+
+ # Wait for lock to expire
+ time.sleep(1.1)
+
+ # Should be able to acquire again
+ assert query_monitor_manager.acquire_lock(block=False)
+
+ def test_sync_with_orm(self, query_monitor_manager, query_queue_obj):
+ """Test synchronization between ORM and object."""
+ query_queue_obj.put(
+ QueryMonitorItem(
+ item_id="query2",
+ user_id=TEST_USER_ID,
+ mem_cube_id=TEST_MEM_CUBE_ID,
+ query_text="What's your name?",
+ timestamp=datetime.now(),
+ keywords=["name"],
+ )
+ )
+
+ # Save current state
+ query_monitor_manager.save_to_db(query_queue_obj)
+
+ # Create sync manager with empty queue
+ empty_queue = QueryMonitorQueue(maxsize=10)
+ engine = BaseDBManager.create_engine_from_db_path(query_monitor_manager.engine.url.database)
+ sync_manager = DBManagerForQueryMonitorQueue(
+ engine=engine,
+ user_id=TEST_USER_ID,
+ mem_cube_id=TEST_MEM_CUBE_ID,
+ obj=empty_queue,
+ lock_timeout=10,
+ )
+
+ # First sync - should create a new record with empty queue
+ sync_manager.sync_with_orm(size_limit=None)
+ items = sync_manager.obj.get_queue_content_without_pop()
+ assert len(items) == 0 # Empty queue since no existing data to merge
+
+ # Now save the empty queue to create a record
+ sync_manager.save_to_db(empty_queue)
+
+ # Test that sync_with_orm correctly handles version control
+ # The sync should increment version but not merge data when versions are the same
+ sync_manager.sync_with_orm(size_limit=None)
+ items = sync_manager.obj.get_queue_content_without_pop()
+ assert len(items) == 0 # Should remain empty since no merge occurred
+
+ # Verify that the version was incremented
+ assert sync_manager.last_version_control == "3" # Should increment from 2 to 3
+
+ sync_manager.close()
+
+ def test_sync_with_size_limit(self, query_monitor_manager, query_queue_obj):
+ """Test synchronization with size limit."""
+ now = datetime.now()
+ item_size = 1
+ for i in range(2, 6):
+ item_size += 1
+ query_queue_obj.put(
+ QueryMonitorItem(
+ item_id=f"query{i}",
+ user_id=TEST_USER_ID,
+ mem_cube_id=TEST_MEM_CUBE_ID,
+ query_text=f"Question {i}",
+ timestamp=now + timedelta(minutes=i),
+ keywords=[f"kw{i}"],
+ )
+ )
+
+ # First sync - should create a new record (size_limit not applied for new records)
+ size_limit = 3
+ query_monitor_manager.sync_with_orm(size_limit=size_limit)
+ items = query_monitor_manager.obj.get_queue_content_without_pop()
+ assert len(items) == item_size # All items since size_limit not applied for new records
+
+ # Save to create the record
+ query_monitor_manager.save_to_db(query_monitor_manager.obj)
+
+ # Test that sync_with_orm correctly handles version control
+ # The sync should increment version but not merge data when versions are the same
+ query_monitor_manager.sync_with_orm(size_limit=size_limit)
+ items = query_monitor_manager.obj.get_queue_content_without_pop()
+ assert len(items) == item_size # Should remain the same since no merge occurred
+
+ # Verify that the version was incremented
+ assert query_monitor_manager.last_version_control == "2"
+
+ def test_concurrent_access(self, temp_db, query_queue_obj):
+ """Test concurrent access to the same database."""
+
+ # Manager 1
+ engine1 = BaseDBManager.create_engine_from_db_path(temp_db)
+ manager1 = DBManagerForQueryMonitorQueue(
+ engine=engine1,
+ user_id=TEST_USER_ID,
+ mem_cube_id=TEST_MEM_CUBE_ID,
+ obj=query_queue_obj,
+ lock_timeout=10,
+ )
+ manager1.save_to_db(query_queue_obj)
+
+ # Manager 2
+ engine2 = BaseDBManager.create_engine_from_db_path(temp_db)
+ manager2 = DBManagerForQueryMonitorQueue(
+ engine=engine2,
+ user_id=TEST_USER_ID,
+ mem_cube_id=TEST_MEM_CUBE_ID,
+ obj=query_queue_obj,
+ lock_timeout=10,
+ )
+
+ # Manager1 acquires lock
+ assert manager1.acquire_lock(block=True)
+
+ # Manager2 fails to acquire
+ assert not manager2.acquire_lock(block=False)
+
+ # Manager1 releases
+ manager1.release_locks(user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID)
+
+ # Manager2 can now acquire
+ assert manager2.acquire_lock(block=False)
+
+ manager1.close()
+ manager2.close()
diff --git a/tests/mem_scheduler/test_retriever.py b/tests/mem_scheduler/test_retriever.py
index 0ef6eb8e..35c8b7f3 100644
--- a/tests/mem_scheduler/test_retriever.py
+++ b/tests/mem_scheduler/test_retriever.py
@@ -1,18 +1,25 @@
+import json
import sys
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch
-from memos.configs.mem_scheduler import SchedulerConfigFactory
+from memos.configs.mem_scheduler import (
+ AuthConfig,
+ GraphDBAuthConfig,
+ OpenAIConfig,
+ RabbitMQConfig,
+ SchedulerConfigFactory,
+)
from memos.llms.base import BaseLLM
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.scheduler_factory import SchedulerFactory
from memos.mem_scheduler.utils.filter_utils import (
- filter_similar_memories,
filter_too_short_memories,
+ filter_vector_based_similar_memories,
)
-from memos.memories.textual.tree import TreeTextMemory
+from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
FILE_PATH = Path(__file__).absolute()
@@ -21,6 +28,25 @@
class TestSchedulerRetriever(unittest.TestCase):
+ def _create_mock_auth_config(self):
+ """Create a mock AuthConfig for testing purposes."""
+ # Create mock configs with valid test values
+ graph_db_config = GraphDBAuthConfig(
+ uri="bolt://localhost:7687",
+ user="neo4j",
+ password="test_password_123", # 8+ characters to pass validation
+ db_name="neo4j",
+ auto_create=True,
+ )
+
+ rabbitmq_config = RabbitMQConfig(
+ host_name="localhost", port=5672, user_name="guest", password="guest", virtual_host="/"
+ )
+
+ openai_config = OpenAIConfig(api_key="test_api_key_123", default_model="gpt-3.5-turbo")
+
+ return AuthConfig(rabbitmq=rabbitmq_config, openai=openai_config, graph_db=graph_db_config)
+
def setUp(self):
"""Initialize test environment with mock objects."""
example_scheduler_config_path = (
@@ -37,6 +63,13 @@ def setUp(self):
self.mem_cube.text_mem = self.tree_text_memory
self.mem_cube.act_mem = MagicMock()
+ # Mock AuthConfig.from_local_env() to return our test config
+ mock_auth_config = self._create_mock_auth_config()
+ self.auth_config_patch = patch(
+ "memos.configs.mem_scheduler.AuthConfig.from_local_env", return_value=mock_auth_config
+ )
+ self.auth_config_patch.start()
+
# Initialize general_modules with mock LLM
self.scheduler.initialize_modules(chat_llm=self.llm, process_llm=self.llm)
self.scheduler.mem_cube = self.mem_cube
@@ -47,17 +80,21 @@ def setUp(self):
self.logging_warning_patch = patch("logging.warning")
self.mock_logging_warning = self.logging_warning_patch.start()
- self.logger_info_patch = patch("memos.mem_scheduler.general_modules.retriever.logger.info")
+ # Mock the MemoryFilter logger since that's where the actual logging happens
+ self.logger_info_patch = patch(
+ "memos.mem_scheduler.memory_manage_modules.memory_filter.logger.info"
+ )
self.mock_logger_info = self.logger_info_patch.start()
def tearDown(self):
"""Clean up patches."""
self.logging_warning_patch.stop()
self.logger_info_patch.stop()
+ self.auth_config_patch.stop()
def test_filter_similar_memories_empty_input(self):
"""Test filter_similar_memories with empty input list."""
- result = filter_similar_memories([])
+ result = filter_vector_based_similar_memories([])
self.assertEqual(result, [])
def test_filter_similar_memories_no_duplicates(self):
@@ -68,7 +105,7 @@ def test_filter_similar_memories_no_duplicates(self):
"And this third one has nothing in common with the others",
]
- result = filter_similar_memories(memories)
+ result = filter_vector_based_similar_memories(memories)
self.assertEqual(len(result), 3)
self.assertEqual(set(result), set(memories))
@@ -79,14 +116,14 @@ def test_filter_similar_memories_with_duplicates(self):
"The user is planning to move to Chicago next month, which reflects a significant change in their living situation.",
"The user is planning to move to Chicago in the upcoming month, indicating a significant change in their living situation.",
]
- result = filter_similar_memories(memories, similarity_threshold=0.75)
+ result = filter_vector_based_similar_memories(memories, similarity_threshold=0.75)
self.assertLess(len(result), len(memories))
def test_filter_similar_memories_error_handling(self):
"""Test filter_similar_memories error handling."""
# Test with non-string input (should return original list due to error)
memories = ["valid text", 12345, "another valid text"]
- result = filter_similar_memories(memories)
+ result = filter_vector_based_similar_memories(memories)
self.assertEqual(result, memories)
def test_filter_too_short_memories_empty_input(self):
@@ -134,3 +171,192 @@ def test_filter_too_short_memories_edge_case(self):
) # "Exactly three words here", "Two words only", "Four words right here"
self.assertIn("Exactly three words here", result)
self.assertIn("Four words right here", result)
+
+ def test_filter_unrelated_memories_empty_memories(self):
+ """Test filter_unrelated_memories with empty memories list."""
+ query_history = ["What is the weather like?", "Tell me about Python programming"]
+
+ result, success_flag = self.retriever.filter_unrelated_memories(
+ query_history=query_history, memories=[]
+ )
+
+ self.assertEqual(result, [])
+ self.assertTrue(success_flag)
+ self.mock_logger_info.assert_called_with("No memories to filter - returning empty list")
+
+ def test_filter_unrelated_memories_empty_query_history(self):
+ """Test filter_unrelated_memories with empty query history."""
+ memories = [
+ TextualMemoryItem(memory="Python is a programming language"),
+ TextualMemoryItem(memory="Machine learning uses algorithms"),
+ TextualMemoryItem(memory="Data science involves statistics"),
+ ]
+
+ result, success_flag = self.retriever.filter_unrelated_memories(
+ query_history=[], memories=memories
+ )
+
+ self.assertEqual(result, memories)
+ self.assertTrue(success_flag)
+ self.mock_logger_info.assert_called_with("No query history provided - keeping all memories")
+
+ def test_filter_unrelated_memories_successful_filtering(self):
+ """Test filter_unrelated_memories with successful LLM filtering."""
+ query_history = ["What is Python?", "How does machine learning work?"]
+ memories = [
+ TextualMemoryItem(memory="Python is a high-level programming language"),
+ TextualMemoryItem(memory="Machine learning algorithms learn from data"),
+ TextualMemoryItem(memory="The weather is sunny today"), # Unrelated
+ TextualMemoryItem(memory="Python has many libraries for ML"),
+ TextualMemoryItem(memory="Cooking recipes for pasta"), # Unrelated
+ ]
+
+ # Mock LLM response for successful filtering
+ mock_llm_response = {
+ "relevant_memories": [0, 1, 3], # Keep Python, ML, and Python ML libraries
+ "filtered_count": 2, # Filter out weather and cooking
+ "reasoning": "Kept memories related to Python and machine learning, filtered out unrelated topics",
+ }
+
+ # Convert to proper JSON string
+ self.llm.generate.return_value = json.dumps(mock_llm_response)
+
+ result, success_flag = self.retriever.filter_unrelated_memories(
+ query_history=query_history, memories=memories
+ )
+
+ # Verify results
+ self.assertEqual(len(result), 3)
+ self.assertIn(memories[0], result) # Python
+ self.assertIn(memories[1], result) # ML
+ self.assertIn(memories[3], result) # Python ML libraries
+ self.assertNotIn(memories[2], result) # Weather
+ self.assertNotIn(memories[4], result) # Cooking
+ self.assertTrue(success_flag)
+
+ # Verify LLM was called correctly
+ self.llm.generate.assert_called_once()
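+        # The prompt is passed to generate() as a list of chat messages; the assertions below inspect the first (user) message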
+ call_args = self.llm.generate.call_args[0][0]
+ self.assertEqual(call_args[0]["role"], "user")
+ self.assertIn("Memory Relevance Filtering Task", call_args[0]["content"])
+
+ def test_filter_unrelated_memories_llm_failure_fallback(self):
+ """Test filter_unrelated_memories with LLM failure - should fallback to keeping all memories."""
+ query_history = ["What is Python?"]
+ memories = [
+ TextualMemoryItem(memory="Python is a programming language"),
+ TextualMemoryItem(memory="Machine learning is a subset of AI"),
+ ]
+
+ # Mock LLM to return an invalid response that will trigger error handling
+ self.llm.generate.return_value = "Invalid response that cannot be parsed"
+
+ result, success_flag = self.retriever.filter_unrelated_memories(
+ query_history=query_history, memories=memories
+ )
+
+ # Should return all memories as fallback
+ self.assertEqual(result, memories)
+ self.assertFalse(success_flag)
+
+        # Verify the filtering attempt was logged (the start message is emitted before the LLM response fails to parse)
+ self.mock_logger_info.assert_called_with(
+ "Starting memory filtering for 2 memories against 1 queries"
+ )
+
+ def test_filter_unrelated_memories_invalid_json_response(self):
+ """Test filter_unrelated_memories with invalid JSON response from LLM."""
+ query_history = ["What is Python?"]
+ memories = [
+ TextualMemoryItem(memory="Python is a programming language"),
+ TextualMemoryItem(memory="Machine learning is a subset of AI"),
+ ]
+
+ # Mock LLM to return invalid JSON
+ self.llm.generate.return_value = "This is not valid JSON"
+
+ result, success_flag = self.retriever.filter_unrelated_memories(
+ query_history=query_history, memories=memories
+ )
+
+ # Should return all memories as fallback
+ self.assertEqual(result, memories)
+ self.assertFalse(success_flag)
+
+ def test_filter_unrelated_memories_invalid_indices(self):
+ """Test filter_unrelated_memories with invalid indices in LLM response."""
+ query_history = ["What is Python?"]
+ memories = [
+ TextualMemoryItem(memory="Python is a programming language"),
+ TextualMemoryItem(memory="Machine learning is a subset of AI"),
+ ]
+
+ # Mock LLM to return invalid indices
+ mock_llm_response = {
+ "relevant_memories": [0, 5, -1], # Invalid indices
+ "filtered_count": 1,
+ "reasoning": "Some memories are relevant",
+ }
+
+ # Convert to proper JSON string
+ self.llm.generate.return_value = json.dumps(mock_llm_response)
+
+ result, success_flag = self.retriever.filter_unrelated_memories(
+ query_history=query_history, memories=memories
+ )
+
+ # Should only include valid indices
+ self.assertEqual(len(result), 1)
+ self.assertIn(memories[0], result) # Index 0 is valid
+ self.assertTrue(success_flag)
+
+ def test_filter_unrelated_memories_missing_required_fields(self):
+ """Test filter_unrelated_memories with missing required fields in LLM response."""
+ query_history = ["What is Python?"]
+ memories = [
+ TextualMemoryItem(memory="Python is a programming language"),
+ TextualMemoryItem(memory="Machine learning is a subset of AI"),
+ ]
+
+ # Mock LLM to return response missing required fields
+ mock_llm_response = {
+ "relevant_memories": [0, 1]
+ # Missing "filtered_count" and "reasoning"
+ }
+
+ # Convert to proper JSON string
+ self.llm.generate.return_value = json.dumps(mock_llm_response)
+
+ result, success_flag = self.retriever.filter_unrelated_memories(
+ query_history=query_history, memories=memories
+ )
+
+ # Should return all memories as fallback due to missing fields
+ self.assertEqual(result, memories)
+ self.assertFalse(success_flag)
+
+ def test_filter_unrelated_memories_conservative_filtering(self):
+ """Test that filter_unrelated_memories uses conservative approach - keeps memories when in doubt."""
+ query_history = ["What is Python?"]
+ memories = [
+ TextualMemoryItem(memory="Python is a programming language"),
+ TextualMemoryItem(memory="Machine learning is a subset of AI"),
+ TextualMemoryItem(memory="The weather is sunny today"), # Potentially unrelated
+ ]
+
+ # Mock LLM to return all memories as relevant (conservative)
+ mock_llm_response = {
+ "relevant_memories": [0, 1, 2], # Keep all memories
+ "filtered_count": 0, # No filtering
+ "reasoning": "All memories could potentially provide context",
+ }
+
+ self.llm.generate.return_value = json.dumps(mock_llm_response)
+
+ result, success_flag = self.retriever.filter_unrelated_memories(
+ query_history=query_history, memories=memories
+ )
+
+ # Should return all memories
+ self.assertEqual(result, memories)
+ self.assertTrue(success_flag)
diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py
index 97377738..51ea5677 100644
--- a/tests/mem_scheduler/test_scheduler.py
+++ b/tests/mem_scheduler/test_scheduler.py
@@ -3,12 +3,18 @@
from datetime import datetime
from pathlib import Path
-from unittest.mock import MagicMock
-
-from memos.configs.mem_scheduler import SchedulerConfigFactory
+from unittest.mock import MagicMock, patch
+
+from memos.configs.mem_scheduler import (
+ AuthConfig,
+ GraphDBAuthConfig,
+ OpenAIConfig,
+ RabbitMQConfig,
+ SchedulerConfigFactory,
+)
from memos.llms.base import BaseLLM
from memos.mem_cube.general import GeneralMemCube
-from memos.mem_scheduler.general_modules.retriever import SchedulerRetriever
+from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever
from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
from memos.mem_scheduler.scheduler_factory import SchedulerFactory
from memos.mem_scheduler.schemas.general_schemas import (
@@ -27,6 +33,25 @@
class TestGeneralScheduler(unittest.TestCase):
+ def _create_mock_auth_config(self):
+ """Create a mock AuthConfig for testing purposes."""
+ # Create mock configs with valid test values
+ graph_db_config = GraphDBAuthConfig(
+ uri="bolt://localhost:7687",
+ user="neo4j",
+ password="test_password_123", # 8+ characters to pass validation
+ db_name="neo4j",
+ auto_create=True,
+ )
+
+ rabbitmq_config = RabbitMQConfig(
+ host_name="localhost", port=5672, user_name="guest", password="guest", virtual_host="/"
+ )
+
+ openai_config = OpenAIConfig(api_key="test_api_key_123", default_model="gpt-3.5-turbo")
+
+ return AuthConfig(rabbitmq=rabbitmq_config, openai=openai_config, graph_db=graph_db_config)
+
def setUp(self):
"""Initialize test environment with mock objects and test scheduler instance."""
example_scheduler_config_path = (
@@ -43,6 +68,13 @@ def setUp(self):
self.mem_cube.text_mem = self.tree_text_memory
self.mem_cube.act_mem = MagicMock()
+ # Mock AuthConfig.from_local_env() to return our test config
+ mock_auth_config = self._create_mock_auth_config()
+ self.auth_config_patch = patch(
+ "memos.configs.mem_scheduler.AuthConfig.from_local_env", return_value=mock_auth_config
+ )
+ self.auth_config_patch.start()
+
# Initialize general_modules with mock LLM
self.scheduler.initialize_modules(chat_llm=self.llm, process_llm=self.llm)
self.scheduler.mem_cube = self.mem_cube
@@ -51,6 +83,10 @@ def setUp(self):
self.scheduler.current_user_id = "test_user"
self.scheduler.current_mem_cube_id = "test_cube"
+ def tearDown(self):
+ """Clean up patches."""
+ self.auth_config_patch.stop()
+
def test_initialization(self):
"""Test that scheduler initializes with correct default values and handlers."""
# Verify handler registration
diff --git a/tests/mem_scheduler/test_version_control.py b/tests/mem_scheduler/test_version_control.py
new file mode 100644
index 00000000..efe2c6b7
--- /dev/null
+++ b/tests/mem_scheduler/test_version_control.py
@@ -0,0 +1,273 @@
+import os
+import tempfile
+
+import pytest
+
+from memos.mem_scheduler.orm_modules.base_model import BaseDBManager
+from memos.mem_scheduler.orm_modules.monitor_models import DBManagerForMemoryMonitorManager
+from memos.mem_scheduler.schemas.monitor_schemas import (
+ MemoryMonitorItem,
+ MemoryMonitorManager,
+)
+
+
+class TestVersionControl:
+ """Test version control functionality"""
+
+ @pytest.fixture
+ def temp_db(self):
+ """Create a temporary database for testing."""
+ temp_dir = tempfile.mkdtemp()
+ db_path = os.path.join(temp_dir, "test_version_control.db")
+ yield db_path
+ # Cleanup
+ try:
+ if os.path.exists(db_path):
+ os.remove(db_path)
+ os.rmdir(temp_dir)
+ except (OSError, PermissionError):
+ pass
+
+ @pytest.fixture
+ def memory_manager_obj(self):
+ """Create a MemoryMonitorManager object for testing"""
+ return MemoryMonitorManager(
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ memories=[
+ MemoryMonitorItem(
+ item_id="test-item-1",
+ memory_text="Test memory 1",
+ tree_memory_item=None,
+ tree_memory_item_mapping_key="test_key_1",
+ keywords_score=0.8,
+ sorting_score=0.9,
+ importance_score=0.7,
+ recording_count=1,
+ )
+ ],
+ )
+
+ def test_version_control_increment(self, temp_db, memory_manager_obj):
+ """Test that version_control increments correctly"""
+ engine = BaseDBManager.create_engine_from_db_path(temp_db)
+ manager = DBManagerForMemoryMonitorManager(
+ engine=engine,
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ obj=memory_manager_obj,
+ )
+
+ try:
+ # Test increment method
+ assert manager._increment_version_control("0") == "1"
+ assert manager._increment_version_control("255") == "0" # Should cycle back to 0
+ assert manager._increment_version_control("100") == "101"
+ assert (
+ manager._increment_version_control("invalid") == "0"
+ ) # Should handle invalid input
+
+ finally:
+ manager.close()
+
+ def test_new_record_has_version_zero(self, temp_db, memory_manager_obj):
+ """Test that new records start with version_control = "0" """
+ engine = BaseDBManager.create_engine_from_db_path(temp_db)
+ manager = DBManagerForMemoryMonitorManager(
+ engine=engine,
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ obj=memory_manager_obj,
+ )
+
+ try:
+ # Save to database
+ manager.save_to_db(memory_manager_obj)
+
+ # Check that last_version_control was set to "0"
+ assert manager.last_version_control == "0"
+
+ # Load from database and verify version_control
+ loaded_obj = manager.load_from_db()
+ assert loaded_obj is not None
+
+ # Check that the version was tracked
+ assert manager.last_version_control == "0"
+
+ finally:
+ manager.close()
+
+ def test_version_control_increments_on_save(self, temp_db, memory_manager_obj):
+ """Test that version_control increments when saving existing records"""
+ engine = BaseDBManager.create_engine_from_db_path(temp_db)
+ manager = DBManagerForMemoryMonitorManager(
+ engine=engine,
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ obj=memory_manager_obj,
+ )
+
+ try:
+ # First save - should create with version "0"
+ manager.save_to_db(memory_manager_obj)
+ assert manager.last_version_control == "0"
+
+ # Second save - should increment to version "1"
+ manager.save_to_db(memory_manager_obj)
+ assert manager.last_version_control == "1"
+
+ # Third save - should increment to version "2"
+ manager.save_to_db(memory_manager_obj)
+ assert manager.last_version_control == "2"
+
+ finally:
+ manager.close()
+
+ def test_sync_with_orm_version_control(self, temp_db, memory_manager_obj):
+ """Test version control behavior in sync_with_orm"""
+ engine = BaseDBManager.create_engine_from_db_path(temp_db)
+ manager = DBManagerForMemoryMonitorManager(
+ engine=engine,
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ obj=memory_manager_obj,
+ )
+
+ try:
+ # First sync - should create with version "0"
+ manager.sync_with_orm()
+ assert manager.last_version_control == "0"
+
+            # Second sync with the same object - should increment to "1" (sync_with_orm always increments)
+            manager.sync_with_orm()
+            assert manager.last_version_control == "1"
+
+ # Third sync - should increment to version "2"
+ manager.sync_with_orm()
+ assert manager.last_version_control == "2" # Should increment to "2"
+
+ # Simulate a change by creating a new object with different content
+ new_memory_manager = MemoryMonitorManager(
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ memories=[
+ MemoryMonitorItem(
+ item_id="test-item-2",
+ memory_text="Test memory 2",
+ tree_memory_item=None,
+ tree_memory_item_mapping_key="test_key_2",
+ keywords_score=0.9,
+ sorting_score=0.8,
+ importance_score=0.6,
+ recording_count=2,
+ )
+ ],
+ )
+
+ # Update the manager's object
+ manager.obj = new_memory_manager
+
+            # Sync again with the updated object - should increment to "3"
+            manager.sync_with_orm()
+            assert manager.last_version_control == "3"
+
+ finally:
+ manager.close()
+
+ def test_version_control_cycles_correctly(self, temp_db, memory_manager_obj):
+ """Test that version_control cycles from 255 back to 0"""
+ engine = BaseDBManager.create_engine_from_db_path(temp_db)
+ manager = DBManagerForMemoryMonitorManager(
+ engine=engine,
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ obj=memory_manager_obj,
+ )
+
+ try:
+ # Test the increment method directly
+ assert manager._increment_version_control("255") == "0"
+ assert manager._increment_version_control("254") == "255"
+ assert manager._increment_version_control("0") == "1"
+
+ finally:
+ manager.close()
+
+ def test_load_from_db_updates_version_control(self, temp_db, memory_manager_obj):
+ """Test that load_from_db updates last_version_control correctly"""
+ engine = BaseDBManager.create_engine_from_db_path(temp_db)
+ manager = DBManagerForMemoryMonitorManager(
+ engine=engine,
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ obj=memory_manager_obj,
+ )
+
+ try:
+ # Save to database first
+ manager.save_to_db(memory_manager_obj)
+ assert manager.last_version_control == "0"
+
+ # Create a new manager instance to load the data
+ load_manager = DBManagerForMemoryMonitorManager(
+ engine=engine,
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ )
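+            # Note: no obj is passed here - this manager only loads the previously saved state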
+
+ # Load from database
+ loaded_obj = load_manager.load_from_db()
+ assert loaded_obj is not None
+ assert load_manager.last_version_control == "0" # Should be updated to loaded version
+
+ load_manager.close()
+
+ finally:
+ manager.close()
+
+ def test_version_control_persistence_across_instances(self, temp_db, memory_manager_obj):
+ """Test that version control persists across different manager instances"""
+ engine = BaseDBManager.create_engine_from_db_path(temp_db)
+
+ # First manager instance
+ manager1 = DBManagerForMemoryMonitorManager(
+ engine=engine,
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ obj=memory_manager_obj,
+ )
+
+ try:
+ # Save multiple times to increment version
+ manager1.save_to_db(memory_manager_obj)
+ assert manager1.last_version_control == "0"
+
+ manager1.save_to_db(memory_manager_obj)
+ assert manager1.last_version_control == "1"
+
+ manager1.save_to_db(memory_manager_obj)
+ assert manager1.last_version_control == "2"
+
+ # Create second manager instance
+ manager2 = DBManagerForMemoryMonitorManager(
+ engine=engine,
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ obj=memory_manager_obj,
+ )
+
+ # Load should show the same version
+ loaded_obj = manager2.load_from_db()
+ assert loaded_obj is not None
+ assert manager2.last_version_control == "2" # Should match the last saved version
+
+ # Save again should increment from the loaded version
+ manager2.save_to_db(memory_manager_obj)
+ assert manager2.last_version_control == "3"
+
+ manager2.close()
+
+ finally:
+ manager1.close()