|
1 | 1 | """Test common client session behaviour.""" |
2 | 2 |
|
3 | 3 | import logging |
| 4 | +import os |
| 5 | + |
4 | 6 | from typing import Generator |
| 7 | +from unittest import mock |
5 | 8 |
|
6 | 9 | import pytest |
7 | 10 | from pytest import LogCaptureFixture as LogCap |
8 | 11 |
|
9 | 12 | from lmstudio import ( |
10 | 13 | AsyncClient, |
11 | 14 | Client, |
| 15 | + LMStudioValueError, |
12 | 16 | LMStudioWebsocketError, |
13 | 17 | ) |
14 | 18 | from lmstudio.async_api import ( |
15 | 19 | _AsyncLMStudioWebsocket, |
16 | 20 | _AsyncSession, |
17 | 21 | _AsyncSessionSystem, |
18 | 22 | ) |
| 23 | +from lmstudio.json_api import ClientBase |
19 | 24 | from lmstudio.sync_api import ( |
20 | 25 | SyncLMStudioWebsocket, |
21 | 26 | _SyncSession, |
|
24 | 29 | from lmstudio._ws_impl import AsyncTaskManager |
25 | 30 | from lmstudio._ws_thread import AsyncWebsocketThread |
26 | 31 |
|
| 32 | +# This API token is structurally valid |
| 33 | +_VALID_API_TOKEN = "sk-lm-abcDEF78:abcDEF7890abcDEF7890" |
| 34 | + |
| 35 | + |
| 36 | +@pytest.mark.parametrize("client_cls", [AsyncClient, Client]) |
| 37 | +def test_auth_message_default(client_cls: ClientBase) -> None: |
| 38 | + with mock.patch.dict(os.environ) as env: |
| 39 | + env.pop("LMSTUDIO_API_TOKEN", None) |
| 40 | + auth_message = client_cls._create_auth_from_token(None) |
| 41 | + assert auth_message["authVersion"] == 1 |
| 42 | + assert auth_message["clientIdentifier"].startswith("guest:") |
| 43 | + client_key = auth_message["clientPasskey"] |
| 44 | + assert client_key != "" |
| 45 | + assert isinstance(client_key, str) |
| 46 | + |
| 47 | + |
| 48 | +@pytest.mark.parametrize("client_cls", [AsyncClient, Client]) |
| 49 | +def test_auth_message_empty_token(client_cls: ClientBase) -> None: |
| 50 | + with mock.patch.dict(os.environ) as env: |
| 51 | + # Set a valid token in the env to ensure it is ignored |
| 52 | + env["LMSTUDIO_API_TOKEN"] = _VALID_API_TOKEN |
| 53 | + auth_message = client_cls._create_auth_from_token("") |
| 54 | + assert auth_message["authVersion"] == 1 |
| 55 | + assert auth_message["clientIdentifier"].startswith("guest:") |
| 56 | + client_key = auth_message["clientPasskey"] |
| 57 | + assert client_key != "" |
| 58 | + assert isinstance(client_key, str) |
| 59 | + |
| 60 | + |
| 61 | +@pytest.mark.parametrize("client_cls", [AsyncClient, Client]) |
| 62 | +def test_auth_message_empty_token_from_env(client_cls: ClientBase) -> None: |
| 63 | + with mock.patch.dict(os.environ) as env: |
| 64 | + env["LMSTUDIO_API_TOKEN"] = "" |
| 65 | + auth_message = client_cls._create_auth_from_token(None) |
| 66 | + assert auth_message["authVersion"] == 1 |
| 67 | + assert auth_message["clientIdentifier"].startswith("guest:") |
| 68 | + client_key = auth_message["clientPasskey"] |
| 69 | + assert client_key != "" |
| 70 | + assert isinstance(client_key, str) |
| 71 | + |
| 72 | + |
| 73 | +@pytest.mark.parametrize("client_cls", [AsyncClient, Client]) |
| 74 | +def test_auth_message_valid_token(client_cls: ClientBase) -> None: |
| 75 | + auth_message = client_cls._create_auth_from_token(_VALID_API_TOKEN) |
| 76 | + assert auth_message["authVersion"] == 1 |
| 77 | + assert auth_message["clientIdentifier"] == "abcDEF78" |
| 78 | + assert auth_message["clientPasskey"] == "abcDEF7890abcDEF7890" |
| 79 | + |
| 80 | + |
| 81 | +@pytest.mark.parametrize("client_cls", [AsyncClient, Client]) |
| 82 | +def test_auth_message_valid_token_from_env(client_cls: ClientBase) -> None: |
| 83 | + with mock.patch.dict(os.environ) as env: |
| 84 | + env["LMSTUDIO_API_TOKEN"] = _VALID_API_TOKEN |
| 85 | + auth_message = client_cls._create_auth_from_token(None) |
| 86 | + assert auth_message["authVersion"] == 1 |
| 87 | + assert auth_message["clientIdentifier"] == "abcDEF78" |
| 88 | + assert auth_message["clientPasskey"] == "abcDEF7890abcDEF7890" |
| 89 | + |
| 90 | + |
| 91 | +_INVALID_TOKENS = [ |
| 92 | + "missing-token-prefix", |
| 93 | + "sk-lm-missing-id-and-key-separator", |
| 94 | + "sk-lm-invalid_id:invalid_key", |
| 95 | + "sk-lm-idtoolong:abcDEF7890abcDEF7890", |
| 96 | + "sk-lm-abcDEF78:keytooshort", |
| 97 | +] |
| 98 | + |
| 99 | + |
| 100 | +@pytest.mark.parametrize("client_cls", [AsyncClient, Client]) |
| 101 | +@pytest.mark.parametrize("api_token", _INVALID_TOKENS) |
| 102 | +def test_auth_message_invalid_token(client_cls: ClientBase, api_token: str) -> None: |
| 103 | + with mock.patch.dict(os.environ) as env: |
| 104 | + env["LMSTUDIO_API_TOKEN"] = _VALID_API_TOKEN |
| 105 | + with pytest.raises(LMStudioValueError): |
| 106 | + client_cls._create_auth_from_token(api_token) |
| 107 | + |
| 108 | + |
| 109 | +@pytest.mark.parametrize("client_cls", [AsyncClient, Client]) |
| 110 | +@pytest.mark.parametrize("api_token", _INVALID_TOKENS) |
| 111 | +def test_auth_message_invalid_token_from_env( |
| 112 | + client_cls: ClientBase, api_token: str |
| 113 | +) -> None: |
| 114 | + with mock.patch.dict(os.environ) as env: |
| 115 | + env["LMSTUDIO_API_TOKEN"] = api_token |
| 116 | + with pytest.raises(LMStudioValueError): |
| 117 | + client_cls._create_auth_from_token(None) |
| 118 | + |
27 | 119 |
|
28 | 120 | async def check_connected_async_session(session: _AsyncSession) -> None: |
29 | 121 | assert session.connected |
@@ -160,7 +252,7 @@ def test_implicit_reconnection_sync(caplog: LogCap) -> None: |
160 | 252 | async def test_websocket_cm_async(caplog: LogCap) -> None: |
161 | 253 | caplog.set_level(logging.DEBUG) |
162 | 254 | api_host = await AsyncClient.find_default_local_api_host() |
163 | | - auth_details = AsyncClient._format_auth_message() |
| 255 | + auth_details = AsyncClient._create_auth_from_token(None) |
164 | 256 | tm = AsyncTaskManager(on_activation=None) |
165 | 257 | lmsws = _AsyncLMStudioWebsocket(tm, f"http://{api_host}/system", auth_details) |
166 | 258 | # SDK client websockets start out disconnected |
@@ -200,7 +292,7 @@ def ws_thread() -> Generator[AsyncWebsocketThread, None, None]: |
200 | 292 | def test_websocket_cm_sync(ws_thread: AsyncWebsocketThread, caplog: LogCap) -> None: |
201 | 293 | caplog.set_level(logging.DEBUG) |
202 | 294 | api_host = Client.find_default_local_api_host() |
203 | | - auth_details = Client._format_auth_message() |
| 295 | + auth_details = Client._create_auth_from_token(None) |
204 | 296 | lmsws = SyncLMStudioWebsocket(ws_thread, f"http://{api_host}/system", auth_details) |
205 | 297 | # SDK client websockets start out disconnected |
206 | 298 | assert not lmsws.connected |
|
0 commit comments