-
Notifications
You must be signed in to change notification settings - Fork 58
Expand file tree
/
Copy pathtest_ws_callback_handler_leak.py
More file actions
126 lines (97 loc) · 4.2 KB
/
Copy pathtest_ws_callback_handler_leak.py
File metadata and controls
126 lines (97 loc) · 4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright (c) Mehmet Bektas <[email protected]>
"""Pin the contract that `_messageCallbackHandlers` does not grow
unbounded for the lifetime of the websocket connection.
Pre-fix: every chat / generate-code / inline-completion request added an
entry keyed by messageId and nothing ever removed them. Across a long
chat session this leaked one response emitter + cancel token per turn.
Fix: the worker thread is wrapped in `_run_request_thread`, which pops
the entry once the request coroutine returns. `on_close` clears the
whole dict so requests still in flight at disconnect don't pin their
state for the worker's lifetime.
"""
import asyncio
import threading
from unittest.mock import MagicMock
from notebook_intelligence.extension import (
MessageCallbackHandlers,
WebsocketCopilotHandler,
)
def _make_handler():
    """Return a WebsocketCopilotHandler with only its callback dict set up.

    ``__new__`` allocates the instance without running ``__init__``, so no
    Tornado application has to be booted. The tests here only exercise the
    ``_messageCallbackHandlers`` bookkeeping, which is seeded as empty.
    """
    handler = WebsocketCopilotHandler.__new__(WebsocketCopilotHandler)
    handler._messageCallbackHandlers = dict()
    return handler


class TestRunRequestThreadPopsHandler:
    """`_run_request_thread` must drop its own messageId entry when it exits."""

    def test_pops_entry_after_coro_returns_normally(self):
        h = _make_handler()
        h._messageCallbackHandlers["m1"] = MessageCallbackHandlers(
            MagicMock(), MagicMock()
        )

        async def coro():
            return "ok"

        h._run_request_thread(coro(), "m1")
        assert "m1" not in h._messageCallbackHandlers

    def test_pops_entry_when_coro_raises(self):
        # A worker exception must not leak the entry. The user may keep
        # the chat session open after a failed turn and start another;
        # repeated failures must not grow the dict. The wrapper deliberately
        # re-raises (asyncio.run propagates) so the upstream error surfaces
        # in thread-level logging; pytest.raises pins that contract.
        # Plain name (not `_pytest`, which is pytest's own internal package).
        import pytest

        h = _make_handler()
        h._messageCallbackHandlers["m1"] = MessageCallbackHandlers(
            MagicMock(), MagicMock()
        )

        async def boom():
            raise RuntimeError("upstream failure")

        with pytest.raises(RuntimeError, match="upstream failure"):
            h._run_request_thread(boom(), "m1")
        assert "m1" not in h._messageCallbackHandlers

    def test_unknown_message_id_is_safe(self):
        # Pop with default short-circuits cleanly so a race where the
        # cleanup runs twice does not crash -- and it must not insert
        # anything for an id that was never registered.
        h = _make_handler()

        async def coro():
            return None

        h._run_request_thread(coro(), "never-registered")
        # Second call is a no-op.
        h._run_request_thread(coro(), "never-registered")
        assert h._messageCallbackHandlers == {}

    def test_multiple_concurrent_requests_each_clean_up(self):
        # The realistic concurrency pattern: several inline-completion
        # requests in flight at once. Each thread's cleanup pops only
        # its own entry; no global clear. The "bystander" entry is never
        # run by any thread, so it must survive -- that is what tells a
        # targeted pop apart from a blanket clear().
        h = _make_handler()
        ids = [f"m{i}" for i in range(5)]
        for message_id in [*ids, "bystander"]:
            h._messageCallbackHandlers[message_id] = MessageCallbackHandlers(
                MagicMock(), MagicMock()
            )

        # threading swallows exceptions raised in a thread target; record
        # them so a failing wrapper cannot masquerade as a clean run.
        errors = []

        def worker(message_id):
            async def coro():
                return None

            try:
                h._run_request_thread(coro(), message_id)
            except BaseException as exc:
                errors.append(exc)

        threads = [threading.Thread(target=worker, args=(mid,)) for mid in ids]
        for t in threads:
            t.start()
        for t in threads:
            # Bounded wait so a deadlocked wrapper fails instead of hanging.
            t.join(timeout=10)
            assert not t.is_alive(), "request thread did not finish"

        assert errors == []
        assert set(h._messageCallbackHandlers) == {"bystander"}


class TestOnCloseClearsHandlers:
    """`on_close` must leave no messageId entries behind."""

    def test_on_close_drops_all_in_flight_entries(self):
        # Long-running requests left in flight at disconnect would
        # otherwise pin their emitter + cancel token until the worker
        # finished. on_close drops everything so the GC can reclaim.
        h = _make_handler()
        for i in range(3):
            h._messageCallbackHandlers[f"m{i}"] = MessageCallbackHandlers(
                MagicMock(), MagicMock()
            )
        h.on_close()
        assert h._messageCallbackHandlers == {}

    def test_on_close_is_idempotent(self):
        # Populate first so the initial call does real work; the repeat
        # call must then be a harmless no-op on the already-cleared dict.
        # (Starting from an empty dict would never exercise the repeat
        # after an actual clear.)
        h = _make_handler()
        h._messageCallbackHandlers["m0"] = MessageCallbackHandlers(
            MagicMock(), MagicMock()
        )
        h.on_close()
        assert h._messageCallbackHandlers == {}
        h.on_close()
        assert h._messageCallbackHandlers == {}