Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5aab9ef

Browse files
committed
feat(client): add timeout parameter to MemFuse clients for improved request handling
- Introduced a `timeout` parameter in both `AsyncMemFuse` and `MemFuse` classes to allow customization of request timeouts, defaulting to 10 seconds. - Updated relevant methods to utilize the new timeout setting for HTTP requests, enhancing reliability during long-running operations. - Adjusted tests to reflect the new timeout parameter, ensuring consistent behavior across different scenarios.
1 parent 22527d9 commit 5aab9ef

File tree

6 files changed

+44
-24
lines changed

6 files changed

+44
-24
lines changed

benchmarks/utils.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,15 +382,36 @@ def convert_messages_for_memfuse(messages: List[Dict[str, Any]], dataset_type: s
382382
# LoCoMo messages already have speaker names embedded in content (e.g., "[CAROLINE]: text")
383383
memfuse_messages = []
384384
for msg in messages:
385+
content = msg.get("content", "")
386+
role = msg.get("role", "user")
387+
388+
# Skip messages with empty content or role
389+
if not content or not content.strip() or not role or not role.strip():
390+
continue
391+
385392
memfuse_msg = {
386-
"role": msg.get("role", "user"),
387-
"content": msg.get("content", "")
393+
"role": role,
394+
"content": content
388395
}
389396
memfuse_messages.append(memfuse_msg)
390397
return memfuse_messages
391398
else:
392-
# MSC and LME use standard format already
393-
return messages
399+
# MSC and LME need to filter out extra fields and empty messages to match MemFuse format
400+
memfuse_messages = []
401+
for msg in messages:
402+
content = msg.get("content", "")
403+
role = msg.get("role", "user")
404+
405+
# Skip messages with empty content or role
406+
if not content or not content.strip() or not role or not role.strip():
407+
continue
408+
409+
memfuse_msg = {
410+
"role": role,
411+
"content": content
412+
}
413+
memfuse_messages.append(memfuse_msg)
414+
return memfuse_messages
394415

395416

396417
async def load_dataset_to_memfuse(

src/memfuse/client.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,17 @@ class AsyncMemFuse:
2323
# Class variable to track all instances
2424
_instances = set()
2525

26-
def __init__(self, base_url: str = "http://localhost:8000", api_key: Optional[str] = None):
26+
def __init__(self, base_url: str = "http://localhost:8000", api_key: Optional[str] = None, timeout: int = 10):
2727
"""Initialize the MemFuse client.
2828
2929
Args:
3030
base_url: URL of the MemFuse server API
3131
api_key: API key for authentication (optional for local usage)
32+
timeout: Request timeout in seconds (default: 10)
3233
"""
3334
self.base_url = base_url.rstrip("/")
3435
self.api_key = api_key or os.environ.get("MEMFUSE_API_KEY")
36+
self.timeout = timeout
3537
self.session = None
3638

3739
# Initialize ASYNC API clients using the classes from .api
@@ -69,7 +71,7 @@ async def _check_server_health(self) -> bool:
6971
await self._ensure_session()
7072
try:
7173
url = f"{self.base_url}/api/v1/health"
72-
async with self.session.get(url, timeout=10) as response:
74+
async with self.session.get(url, timeout=self.timeout) as response:
7375
if response.status == 200:
7476
return True
7577
return False
@@ -108,7 +110,7 @@ async def _request(
108110

109111
url = f"{self.base_url}{endpoint}"
110112

111-
async with getattr(self.session, method.lower())(url, json=data) as response:
113+
async with getattr(self.session, method.lower())(url, json=data, timeout=self.timeout) as response:
112114
response_data = await response.json()
113115
if response.status >= 400:
114116
error_message = response_data.get("message", "Unknown error")
@@ -257,15 +259,17 @@ async def _thread_safe_coro_runner(self, coro):
257259
class MemFuse:
258260
"""Synchronous MemFuse client for communicating with the MemFuse server."""
259261

260-
def __init__(self, base_url: str = "http://localhost:8000", api_key: Optional[str] = None):
262+
def __init__(self, base_url: str = "http://localhost:8000", api_key: Optional[str] = None, timeout: int = 10):
261263
"""Initialize the synchronous MemFuse client.
262264
263265
Args:
264266
base_url: URL of the MemFuse server API
265267
api_key: API key for authentication (optional for local usage)
268+
timeout: Request timeout in seconds (default: 10)
266269
"""
267270
self.base_url = base_url.rstrip("/")
268271
self.api_key = api_key or os.environ.get("MEMFUSE_API_KEY")
272+
self.timeout = timeout
269273
self.sync_session = None # requests session for sync requests
270274

271275
# Initialize API clients using the classes from .api
@@ -300,7 +304,7 @@ def _check_server_health_sync(self) -> bool:
300304
self._ensure_sync_session()
301305
try:
302306
url = f"{self.base_url}/api/v1/health"
303-
response = self.sync_session.get(url, timeout=10)
307+
response = self.sync_session.get(url, timeout=self.timeout)
304308
return response.status_code == 200
305309
except Exception:
306310
return False
@@ -336,7 +340,7 @@ def _request_sync(
336340

337341
url = f"{self.base_url}{endpoint}"
338342

339-
response = getattr(self.sync_session, method.lower())(url, json=data, timeout=10)
343+
response = getattr(self.sync_session, method.lower())(url, json=data, timeout=self.timeout)
340344
response_data = response.json()
341345
if response.status_code >= 400:
342346
error_message = response_data.get("message", "Unknown error")

tests/benchmarks/test_msc_accuracy.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,11 @@
5454

5555
# Suppress Pydantic serialization warnings from LiteLLM
5656
warnings.filterwarnings("ignore", message=".*Pydantic serializer warnings.*")
57+
warnings.filterwarnings("ignore", message=".*Expected.*fields but got.*")
58+
warnings.filterwarnings("ignore", message=".*PydanticSerializationUnexpectedValue.*")
5759
warnings.filterwarnings("ignore", category=UserWarning, module=".*pydantic.*")
5860
warnings.filterwarnings("ignore", category=UserWarning, module=".*litellm.*")
61+
warnings.filterwarnings("ignore", category=UserWarning, module=".*litellm_core_utils.*")
5962

6063
# Test configuration
6164
DEFAULT_NUM_QUESTIONS = 20
@@ -103,7 +106,6 @@ async def load_msc_to_memfuse(dataset, logger):
103106
else:
104107
logger.error(f"Failed to add messages to session '{session_id}' for Q{question_number}: {add_result}")
105108
continue
106-
await memory_instance.close()
107109
successfully_loaded_count += 1
108110
logger.info(f"--- Successfully loaded Question {question_number}/{len(dataset)} ---")
109111
except ConnectionError as e:
@@ -120,7 +122,6 @@ async def run_msc_benchmark_with_results(dataset, logger):
120122
"""Query MemFuse memory with MSC questions and return structured results."""
121123
results = MSCBenchmarkResults()
122124
memfuse_client = AsyncMemFuse()
123-
all_created_mem_instances = []
124125
total_start_time = time.perf_counter()
125126
for i, data_sample in enumerate(dataset):
126127
question_number = i + 1
@@ -143,7 +144,6 @@ async def run_msc_benchmark_with_results(dataset, logger):
143144
user=user_name_for_test,
144145
agent=agent_name_for_test,
145146
)
146-
all_created_mem_instances.append(query_memory_instance)
147147
start_time = time.perf_counter()
148148
memory_response = await query_memory_instance.query(query=question_text, top_k=3)
149149
end_time = time.perf_counter()
@@ -207,12 +207,7 @@ async def run_msc_benchmark_with_results(dataset, logger):
207207
except Exception as e:
208208
logger.error(f"A critical unexpected error occurred: {e}", exc_info=True)
209209
finally:
210-
logger.info("Closing AsyncMemFuse client and query session instances...")
211-
for mem_instance in all_created_mem_instances:
212-
try:
213-
await mem_instance.close()
214-
except Exception as e_close:
215-
logger.error(f"Error closing memory instance for session {mem_instance.session if mem_instance else 'N/A'}: {e_close}")
210+
logger.info("Closing AsyncMemFuse client...")
216211
if memfuse_client:
217212
await memfuse_client.close()
218213
total_end_time = time.perf_counter()

tests/e2e/test_e2e_async_memory_followup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def test_async_memory_followup_includes_mars_reference():
5656
# Arrange – create AsyncMemFuse session & AsyncOpenAI client with memory
5757
# ---------------------------------------------------------------------
5858
try:
59-
memfuse = AsyncMemFuse(base_url=memfuse_base_url)
59+
memfuse = AsyncMemFuse(base_url=memfuse_base_url, timeout=30)
6060
except Exception as exc:
6161
pytest.skip(f"Cannot connect to MemFuse server at {memfuse_base_url}: {exc}")
6262

@@ -174,7 +174,7 @@ async def test_async_memory_with_context_manager():
174174
# Test using async context manager (recommended approach)
175175
# ---------------------------------------------------------------------
176176
try:
177-
async with AsyncMemFuse(base_url=memfuse_base_url) as memfuse:
177+
async with AsyncMemFuse(base_url=memfuse_base_url, timeout=30) as memfuse:
178178
memory = await memfuse.init(user="e2e_async_context_test_user", session="context_test")
179179

180180
client = AsyncOpenAI(

tests/e2e/test_e2e_memory_followup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_memory_followup_includes_mars_reference():
5454
# Arrange – create MemFuse session & OpenAI client with memory attached
5555
# ---------------------------------------------------------------------
5656
try:
57-
memfuse = MemFuse(base_url=memfuse_base_url)
57+
memfuse = MemFuse(base_url=memfuse_base_url, timeout=30)
5858
except Exception as exc:
5959
pytest.skip(f"Cannot connect to MemFuse server at {memfuse_base_url}: {exc}")
6060

tests/e2e/test_e2e_multi_turn_memory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def test_multi_turn_conversation_memory():
5858
openai_api_key = os.getenv("OPENAI_API_KEY")
5959

6060
try:
61-
# Initialize MemFuse client
62-
memfuse = MemFuse(base_url=memfuse_base_url)
61+
# Initialize MemFuse client with longer timeout for multi-turn operations
62+
memfuse = MemFuse(base_url=memfuse_base_url, timeout=30)
6363

6464
# Create a unique session for this test
6565
test_session = f"test_multi_turn_{int(time.time())}"

0 commit comments

Comments
 (0)