1515# Python 3.10 compatibility: use concurrent.futures.TimeoutError instead of the builtin
1616# In 3.11+, these are the same type, in 3.10 futures have their own timeout exception
1717from concurrent .futures import Future as SyncFuture , TimeoutError as SyncFutureTimeout
18+ from contextvars import ContextVar
1819from contextlib import AsyncExitStack , contextmanager
1920from functools import partial
2021from typing import (
2122 Any ,
2223 Awaitable ,
2324 Coroutine ,
2425 Callable ,
26+ ClassVar ,
2527 Generator ,
2628 TypeAlias ,
2729 TypeVar ,
3537from httpx_ws import aconnect_ws , AsyncWebSocketSession , HTTPXWSException
3638
3739from .schemas import DictObject
40+ from .sdk_api import LMStudioRuntimeError
3841from .json_api import (
3942 LMStudioWebsocket ,
4043 LMStudioWebsocketError ,
5861
5962
6063class AsyncTaskManager :
64+ _LMS_TASK_MANAGER : ClassVar [ContextVar [Self ]] = ContextVar ("_LMS_TASK_MANAGER" )
65+
6166 def __init__ (self , * , on_activation : Callable [[], Any ] | None = None ) -> None :
6267 self ._activated = False
6368 self ._event_loop : asyncio .AbstractEventLoop | None = None
@@ -98,15 +103,19 @@ async def __aexit__(self, *args: Any) -> None:
98103 with move_on_after (self .TERMINATION_TIMEOUT ):
99104 await self ._terminated .wait ()
100105
101- def check_running_in_task_loop (self , * , allow_inactive : bool = False ) -> bool :
102- """Returns if running in this manager's event loop, raises RuntimeError otherwise."""
106+ @classmethod
107+ def get_running_task_manager (cls ) -> Self :
108+ try :
109+ return cls ._LMS_TASK_MANAGER .get ()
110+ except LookupError :
111+ err_msg = "No async task manager active in current context"
112+ raise LMStudioRuntimeError (err_msg ) from None
113+
114+ def ensure_running_in_task_loop (self ) -> None :
103115 this_loop = self ._event_loop
104116 if this_loop is None :
105117 # Task manager isn't active -> no coroutine can be running in it
106- if allow_inactive :
107- # No exception, but indicate the task manager isn't actually running
108- return False
109- raise RuntimeError (f"{ self !r} is currently inactive." )
118+ raise LMStudioRuntimeError (f"{ self !r} is currently inactive." )
110119 try :
111120 running_loop = asyncio .get_running_loop ()
112121 except RuntimeError :
@@ -116,12 +125,27 @@ def check_running_in_task_loop(self, *, allow_inactive: bool = False) -> bool:
116125 if running_loop is not this_loop :
117126 err_details = f"Expected: { this_loop !r} Running: { running_loop !r} "
118127 err_msg = f"{ self !r} is running in a different event loop ({ err_details } )."
119- raise RuntimeError (err_msg )
128+ raise LMStudioRuntimeError (err_msg )
129+
130+ def is_running_in_task_loop (self ) -> bool :
131+ try :
132+ self .ensure_running_in_task_loop ()
133+ except LMStudioRuntimeError :
134+ return False
120135 return True
121136
137+ def ensure_running_in_task_manager (self ) -> None :
138+ # Task manager must be active in the running event loop
139+ self .ensure_running_in_task_loop ()
140+ running_tm = self .get_running_task_manager ()
141+ if running_tm is not self :
142+ err_details = f"Expected: { self !r} Running: { running_tm !r} "
143+ err_msg = f"Task is running in a different task manager ({ err_details } )."
144+ raise LMStudioRuntimeError (err_msg )
145+
122146 async def request_termination (self ) -> bool :
123147 """Request termination of the task manager from the same thread."""
124- if not self .check_running_in_task_loop ( allow_inactive = True ):
148+ if not self .is_running_in_task_loop ( ):
125149 return False
126150 if self ._terminate .is_set ():
127151 return False
@@ -139,7 +163,7 @@ def request_termination_threadsafe(self) -> SyncFuture[bool]:
139163
140164 async def wait_for_termination (self ) -> None :
141165 """Wait in the same thread for the task manager to indicate it has terminated."""
142- if not self .check_running_in_task_loop ( allow_inactive = True ):
166+ if not self .is_running_in_task_loop ( ):
143167 return
144168 await self ._terminated .wait ()
145169
@@ -163,11 +187,13 @@ def terminate_threadsafe(self) -> None:
163187 if self .request_termination_threadsafe ().result ():
164188 self .wait_for_termination_threadsafe ()
165189
166- def _init_event_loop (self ) -> None :
190+ def _mark_as_running (self : Self ) -> None :
191+ # Explicit type hint to work around https://github.com/python/mypy/issues/16871
167192 if self ._event_loop is not None :
168- raise RuntimeError ( )
193+ raise LMStudioRuntimeError ( "Async task manager is already running" )
169194 self ._event_loop = asyncio .get_running_loop ()
170195 self ._activated = True
196+ self ._LMS_TASK_MANAGER .set (self )
171197 notify = self ._on_activation
172198 if notify is not None :
173199 notify ()
@@ -177,7 +203,7 @@ async def run_until_terminated(
177203 self , func : Callable [[], Coroutine [Any , Any , Any ]] | None = None
178204 ) -> None :
179205 """Run task manager until termination is requested."""
180- self ._init_event_loop ()
206+ self ._mark_as_running ()
181207 # Use anyio and exceptiongroup to handle the lack of native task
182208 # and exception groups prior to Python 3.11
183209 try :
@@ -206,18 +232,15 @@ async def schedule_task(self, func: Callable[[], Awaitable[Any]]) -> None:
206232
207233 Important: task must NOT access any scoped resources from the scheduling scope.
208234 """
209- self .check_running_in_task_loop ()
235+ self .ensure_running_in_task_loop ()
210236 await self ._task_queue .put (func )
211237
212238 def schedule_task_threadsafe (self , func : Callable [[], Awaitable [Any ]]) -> None :
213239 """Schedule given task in the task manager's base coroutine from any thread.
214240
215241 Important: task must NOT access any scoped resources from the scheduling scope.
216242 """
217- loop = self ._event_loop
218- if loop is None :
219- raise RuntimeError (f"{ self !r} is currently inactive." )
220- asyncio .run_coroutine_threadsafe (self .schedule_task (func ), loop )
243+ self .run_coroutine_threadsafe (self .schedule_task (func ))
221244
222245 def run_coroutine_threadsafe (self , coro : Coroutine [Any , Any , T ]) -> SyncFuture [T ]:
223246 """Call given coroutine in the task manager's event loop from any thread.
@@ -280,7 +303,6 @@ def __init__(
280303 async def connect (self ) -> bool :
281304 """Connect websocket from the task manager's event loop."""
282305 task_manager = self ._task_manager
283- assert task_manager .check_running_in_task_loop ()
284306 await task_manager .schedule_task (self ._logged_ws_handler )
285307 await self ._connection_attempted .wait ()
286308 return self ._ws is not None
@@ -293,7 +315,9 @@ def connect_threadsafe(self) -> bool:
293315
294316 async def disconnect (self ) -> None :
295317 """Disconnect websocket from the task manager's event loop."""
296- assert self ._task_manager .check_running_in_task_loop ()
318+ self ._task_manager .ensure_running_in_task_loop ()
319+ # Websocket handler task may already have been cancelled,
320+ # but the closure can be requested multiple times without issue
297321 self ._ws_disconnected .set ()
298322 ws = self ._ws
299323 if ws is None :
@@ -321,9 +345,10 @@ async def _logged_ws_handler(self) -> None:
321345 self ._logger .debug ("Websocket task terminated" )
322346
323347 async def _handle_ws (self ) -> None :
324- assert self ._task_manager .check_running_in_task_loop ()
325348 resources = AsyncExitStack ()
326349 try :
350+ # For reliable shutdown, handler must run entirely inside the task manager
351+ self ._task_manager .ensure_running_in_task_manager ()
327352 ws : AsyncWebSocketSession = await resources .enter_async_context (
328353 aconnect_ws (self ._ws_url )
329354 )
@@ -370,7 +395,7 @@ def _clear_task_state() -> None:
370395
371396 async def send_json (self , message : DictObject ) -> None :
372397 # This is only called if the websocket has been created
373- assert self ._task_manager .check_running_in_task_loop ()
398+ self ._task_manager .ensure_running_in_task_loop ()
374399 ws = self ._ws
375400 if ws is None :
376401 # Assume app is shutting down and the owning task has already been cancelled
@@ -396,14 +421,14 @@ def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T:
396421
397422 @contextmanager
398423 def open_channel (self ) -> Generator [AsyncChannelInfo , None , None ]:
399- assert self ._task_manager .check_running_in_task_loop ()
424+ self ._task_manager .ensure_running_in_task_loop ()
400425 rx_queue : RxQueue = asyncio .Queue ()
401426 with self ._mux .assign_channel_id (rx_queue ) as call_id :
402427 yield call_id , rx_queue .get
403428
404429 @contextmanager
405430 def start_call (self ) -> Generator [AsyncRemoteCallInfo , None , None ]:
406- assert self ._task_manager .check_running_in_task_loop ()
431+ self ._task_manager .ensure_running_in_task_loop ()
407432 rx_queue : RxQueue = asyncio .Queue ()
408433 with self ._mux .assign_call_id (rx_queue ) as call_id :
409434 yield call_id , rx_queue .get
@@ -444,7 +469,9 @@ def _rx_queue_get_threadsafe(self, rx_queue: RxQueue, timeout: float | None) ->
444469
445470 async def _receive_json (self ) -> Any :
446471 # This is only called if the websocket has been created
447- assert self ._task_manager .check_running_in_task_loop ()
472+ if __debug__ :
473+ # This should only be called as part of the self._handle_ws task
474+ self ._task_manager .ensure_running_in_task_manager ()
448475 ws = self ._ws
449476 if ws is None :
450477 # Assume app is shutting down and the owning task has already been cancelled
@@ -459,7 +486,9 @@ async def _receive_json(self) -> Any:
459486
460487 async def _authenticate (self ) -> bool :
461488 # This is only called if the websocket has been created
462- assert self ._task_manager .check_running_in_task_loop ()
489+ if __debug__ :
490+ # This should only be called as part of the self._handle_ws task
491+ self ._task_manager .ensure_running_in_task_manager ()
463492 ws = self ._ws
464493 if ws is None :
465494 # Assume app is shutting down and the owning task has already been cancelled
@@ -479,7 +508,9 @@ async def _process_next_message(self) -> bool:
479508 Returns True if a message queue was updated.
480509 """
481510 # This is only called if the websocket has been created
482- assert self ._task_manager .check_running_in_task_loop ()
511+ if __debug__ :
512+ # This should only be called as part of the self._handle_ws task
513+ self ._task_manager .ensure_running_in_task_manager ()
483514 ws = self ._ws
484515 if ws is None :
485516 # Assume app is shutting down and the owning task has already been cancelled
0 commit comments