Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 3576258

Browse files
authored
Improve internal task management checks (#128)
* Move task management state validity checks out of assert statements * Add stricter checks to methods that *must* run in the task manager
1 parent 83c3f19 commit 3576258

File tree

1 file changed

+57
-26
lines changed

1 file changed

+57
-26
lines changed

src/lmstudio/_ws_impl.py

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
# Python 3.10 compatibility: use concurrent.futures.TimeoutError instead of the builtin
1616
# In 3.11+, these are the same type, in 3.10 futures have their own timeout exception
1717
from concurrent.futures import Future as SyncFuture, TimeoutError as SyncFutureTimeout
18+
from contextvars import ContextVar
1819
from contextlib import AsyncExitStack, contextmanager
1920
from functools import partial
2021
from typing import (
2122
Any,
2223
Awaitable,
2324
Coroutine,
2425
Callable,
26+
ClassVar,
2527
Generator,
2628
TypeAlias,
2729
TypeVar,
@@ -35,6 +37,7 @@
3537
from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException
3638

3739
from .schemas import DictObject
40+
from .sdk_api import LMStudioRuntimeError
3841
from .json_api import (
3942
LMStudioWebsocket,
4043
LMStudioWebsocketError,
@@ -58,6 +61,8 @@
5861

5962

6063
class AsyncTaskManager:
64+
_LMS_TASK_MANAGER: ClassVar[ContextVar[Self]] = ContextVar("_LMS_TASK_MANAGER")
65+
6166
def __init__(self, *, on_activation: Callable[[], Any] | None = None) -> None:
6267
self._activated = False
6368
self._event_loop: asyncio.AbstractEventLoop | None = None
@@ -98,15 +103,19 @@ async def __aexit__(self, *args: Any) -> None:
98103
with move_on_after(self.TERMINATION_TIMEOUT):
99104
await self._terminated.wait()
100105

101-
def check_running_in_task_loop(self, *, allow_inactive: bool = False) -> bool:
102-
"""Returns if running in this manager's event loop, raises RuntimeError otherwise."""
106+
@classmethod
107+
def get_running_task_manager(cls) -> Self:
108+
try:
109+
return cls._LMS_TASK_MANAGER.get()
110+
except LookupError:
111+
err_msg = "No async task manager active in current context"
112+
raise LMStudioRuntimeError(err_msg) from None
113+
114+
def ensure_running_in_task_loop(self) -> None:
103115
this_loop = self._event_loop
104116
if this_loop is None:
105117
# Task manager isn't active -> no coroutine can be running in it
106-
if allow_inactive:
107-
# No exception, but indicate the task manager isn't actually running
108-
return False
109-
raise RuntimeError(f"{self!r} is currently inactive.")
118+
raise LMStudioRuntimeError(f"{self!r} is currently inactive.")
110119
try:
111120
running_loop = asyncio.get_running_loop()
112121
except RuntimeError:
@@ -116,12 +125,27 @@ def check_running_in_task_loop(self, *, allow_inactive: bool = False) -> bool:
116125
if running_loop is not this_loop:
117126
err_details = f"Expected: {this_loop!r} Running: {running_loop!r}"
118127
err_msg = f"{self!r} is running in a different event loop ({err_details})."
119-
raise RuntimeError(err_msg)
128+
raise LMStudioRuntimeError(err_msg)
129+
130+
def is_running_in_task_loop(self) -> bool:
131+
try:
132+
self.ensure_running_in_task_loop()
133+
except LMStudioRuntimeError:
134+
return False
120135
return True
121136

137+
def ensure_running_in_task_manager(self) -> None:
138+
# Task manager must be active in the running event loop
139+
self.ensure_running_in_task_loop()
140+
running_tm = self.get_running_task_manager()
141+
if running_tm is not self:
142+
err_details = f"Expected: {self!r} Running: {running_tm!r}"
143+
err_msg = f"Task is running in a different task manager ({err_details})."
144+
raise LMStudioRuntimeError(err_msg)
145+
122146
async def request_termination(self) -> bool:
123147
"""Request termination of the task manager from the same thread."""
124-
if not self.check_running_in_task_loop(allow_inactive=True):
148+
if not self.is_running_in_task_loop():
125149
return False
126150
if self._terminate.is_set():
127151
return False
@@ -139,7 +163,7 @@ def request_termination_threadsafe(self) -> SyncFuture[bool]:
139163

140164
async def wait_for_termination(self) -> None:
141165
"""Wait in the same thread for the task manager to indicate it has terminated."""
142-
if not self.check_running_in_task_loop(allow_inactive=True):
166+
if not self.is_running_in_task_loop():
143167
return
144168
await self._terminated.wait()
145169

@@ -163,11 +187,13 @@ def terminate_threadsafe(self) -> None:
163187
if self.request_termination_threadsafe().result():
164188
self.wait_for_termination_threadsafe()
165189

166-
def _init_event_loop(self) -> None:
190+
def _mark_as_running(self: Self) -> None:
191+
# Explicit type hint to work around https://github.com/python/mypy/issues/16871
167192
if self._event_loop is not None:
168-
raise RuntimeError()
193+
raise LMStudioRuntimeError("Async task manager is already running")
169194
self._event_loop = asyncio.get_running_loop()
170195
self._activated = True
196+
self._LMS_TASK_MANAGER.set(self)
171197
notify = self._on_activation
172198
if notify is not None:
173199
notify()
@@ -177,7 +203,7 @@ async def run_until_terminated(
177203
self, func: Callable[[], Coroutine[Any, Any, Any]] | None = None
178204
) -> None:
179205
"""Run task manager until termination is requested."""
180-
self._init_event_loop()
206+
self._mark_as_running()
181207
# Use anyio and exceptiongroup to handle the lack of native task
182208
# and exception groups prior to Python 3.11
183209
try:
@@ -206,18 +232,15 @@ async def schedule_task(self, func: Callable[[], Awaitable[Any]]) -> None:
206232
207233
Important: task must NOT access any scoped resources from the scheduling scope.
208234
"""
209-
self.check_running_in_task_loop()
235+
self.ensure_running_in_task_loop()
210236
await self._task_queue.put(func)
211237

212238
def schedule_task_threadsafe(self, func: Callable[[], Awaitable[Any]]) -> None:
213239
"""Schedule given task in the task manager's base coroutine from any thread.
214240
215241
Important: task must NOT access any scoped resources from the scheduling scope.
216242
"""
217-
loop = self._event_loop
218-
if loop is None:
219-
raise RuntimeError(f"{self!r} is currently inactive.")
220-
asyncio.run_coroutine_threadsafe(self.schedule_task(func), loop)
243+
self.run_coroutine_threadsafe(self.schedule_task(func))
221244

222245
def run_coroutine_threadsafe(self, coro: Coroutine[Any, Any, T]) -> SyncFuture[T]:
223246
"""Call given coroutine in the task manager's event loop from any thread.
@@ -280,7 +303,6 @@ def __init__(
280303
async def connect(self) -> bool:
281304
"""Connect websocket from the task manager's event loop."""
282305
task_manager = self._task_manager
283-
assert task_manager.check_running_in_task_loop()
284306
await task_manager.schedule_task(self._logged_ws_handler)
285307
await self._connection_attempted.wait()
286308
return self._ws is not None
@@ -293,7 +315,9 @@ def connect_threadsafe(self) -> bool:
293315

294316
async def disconnect(self) -> None:
295317
"""Disconnect websocket from the task manager's event loop."""
296-
assert self._task_manager.check_running_in_task_loop()
318+
self._task_manager.ensure_running_in_task_loop()
319+
# Websocket handler task may already have been cancelled,
320+
# but the closure can be requested multiple times without issue
297321
self._ws_disconnected.set()
298322
ws = self._ws
299323
if ws is None:
@@ -321,9 +345,10 @@ async def _logged_ws_handler(self) -> None:
321345
self._logger.debug("Websocket task terminated")
322346

323347
async def _handle_ws(self) -> None:
324-
assert self._task_manager.check_running_in_task_loop()
325348
resources = AsyncExitStack()
326349
try:
350+
# For reliable shutdown, handler must run entirely inside the task manager
351+
self._task_manager.ensure_running_in_task_manager()
327352
ws: AsyncWebSocketSession = await resources.enter_async_context(
328353
aconnect_ws(self._ws_url)
329354
)
@@ -370,7 +395,7 @@ def _clear_task_state() -> None:
370395

371396
async def send_json(self, message: DictObject) -> None:
372397
# This is only called if the websocket has been created
373-
assert self._task_manager.check_running_in_task_loop()
398+
self._task_manager.ensure_running_in_task_loop()
374399
ws = self._ws
375400
if ws is None:
376401
# Assume app is shutting down and the owning task has already been cancelled
@@ -396,14 +421,14 @@ def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T:
396421

397422
@contextmanager
398423
def open_channel(self) -> Generator[AsyncChannelInfo, None, None]:
399-
assert self._task_manager.check_running_in_task_loop()
424+
self._task_manager.ensure_running_in_task_loop()
400425
rx_queue: RxQueue = asyncio.Queue()
401426
with self._mux.assign_channel_id(rx_queue) as call_id:
402427
yield call_id, rx_queue.get
403428

404429
@contextmanager
405430
def start_call(self) -> Generator[AsyncRemoteCallInfo, None, None]:
406-
assert self._task_manager.check_running_in_task_loop()
431+
self._task_manager.ensure_running_in_task_loop()
407432
rx_queue: RxQueue = asyncio.Queue()
408433
with self._mux.assign_call_id(rx_queue) as call_id:
409434
yield call_id, rx_queue.get
@@ -444,7 +469,9 @@ def _rx_queue_get_threadsafe(self, rx_queue: RxQueue, timeout: float | None) ->
444469

445470
async def _receive_json(self) -> Any:
446471
# This is only called if the websocket has been created
447-
assert self._task_manager.check_running_in_task_loop()
472+
if __debug__:
473+
# This should only be called as part of the self._handle_ws task
474+
self._task_manager.ensure_running_in_task_manager()
448475
ws = self._ws
449476
if ws is None:
450477
# Assume app is shutting down and the owning task has already been cancelled
@@ -459,7 +486,9 @@ async def _receive_json(self) -> Any:
459486

460487
async def _authenticate(self) -> bool:
461488
# This is only called if the websocket has been created
462-
assert self._task_manager.check_running_in_task_loop()
489+
if __debug__:
490+
# This should only be called as part of the self._handle_ws task
491+
self._task_manager.ensure_running_in_task_manager()
463492
ws = self._ws
464493
if ws is None:
465494
# Assume app is shutting down and the owning task has already been cancelled
@@ -479,7 +508,9 @@ async def _process_next_message(self) -> bool:
479508
Returns True if a message queue was updated.
480509
"""
481510
# This is only called if the websocket has been created
482-
assert self._task_manager.check_running_in_task_loop()
511+
if __debug__:
512+
# This should only be called as part of the self._handle_ws task
513+
self._task_manager.ensure_running_in_task_manager()
483514
ws = self._ws
484515
if ws is None:
485516
# Assume app is shutting down and the owning task has already been cancelled

0 commit comments

Comments
 (0)