Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 3900323

Browse files
Support API tokens in SDK (lmstudio-ai#163)
* Pass api_token to Client or AsyncClient constructor * Pass api_token to configure_default_client() in the convenience API * Set LMSTUDIO_API_TOKEN in the process environment --------- Co-authored-by: Nick Coghlan <[email protected]>
1 parent c6998ba commit 3900323

File tree

6 files changed

+170
-26
lines changed

6 files changed

+170
-26
lines changed

src/lmstudio/async_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,9 +1492,11 @@ async def embed(
14921492
class AsyncClient(ClientBase):
14931493
"""Async SDK client interface."""
14941494

1495-
def __init__(self, api_host: str | None = None) -> None:
1495+
def __init__(
1496+
self, api_host: str | None = None, api_token: str | None = None
1497+
) -> None:
14961498
"""Initialize API client."""
1497-
super().__init__(api_host)
1499+
super().__init__(api_host, api_token)
14981500
self._resources = AsyncExitStack()
14991501
self._sessions: dict[str, _AsyncSession] = {}
15001502
self._task_manager = AsyncTaskManager()

src/lmstudio/json_api.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import copy
1414
import inspect
1515
import json
16+
import os
17+
import re
1618
import sys
1719
import uuid
1820
import warnings
@@ -197,6 +199,12 @@
197199

198200
DEFAULT_TTL = 60 * 60 # By default, leaves idle models loaded for an hour
199201

202+
# lmstudio-js and lmstudio-python use the same API token environment variable
203+
_ENV_API_TOKEN = "LMSTUDIO_API_TOKEN"
204+
_LMS_API_TOKEN_REGEX = re.compile(
205+
r"^sk-lm-(?P<clientIdentifier>[A-Za-z0-9]{8}):(?P<clientPasskey>[A-Za-z0-9]{20})$"
206+
)
207+
200208
# Require a coroutine (not just any awaitable) for run_coroutine_threadsafe compatibility
201209
SendMessageAsync: TypeAlias = Callable[[DictObject], Coroutine[Any, Any, None]]
202210

@@ -401,7 +409,9 @@ def from_details(message: str, details: DictObject) -> "LMStudioServerError":
401409
if display_data:
402410
specific_error: LMStudioServerError | None = None
403411
match display_data:
404-
case {"code": "generic.noModelMatchingQuery"}:
412+
case {"code": "generic.noModelMatchingQuery"} | {
413+
"code": "generic.pathNotFound"
414+
}:
405415
specific_error = LMStudioModelNotFoundError(str(default_error))
406416
case {"code": "generic.presetNotFound"}:
407417
specific_error = LMStudioPresetNotFoundError(str(default_error))
@@ -2041,10 +2051,12 @@ def _ensure_connected(self, usage: str) -> None | NoReturn:
20412051
class ClientBase:
20422052
"""Common base class for SDK client interfaces."""
20432053

2044-
def __init__(self, api_host: str | None = None) -> None:
2054+
def __init__(
2055+
self, api_host: str | None = None, api_token: str | None = None
2056+
) -> None:
20452057
"""Initialize API client."""
20462058
self._api_host = api_host
2047-
self._auth_details = self._create_auth_message()
2059+
self._auth_details = self._create_auth_message(api_token)
20482060

20492061
@property
20502062
def api_host(self) -> str:
@@ -2087,25 +2099,57 @@ def _format_auth_message(
20872099
client_id: str | None = None, client_key: str | None = None
20882100
) -> DictObject:
20892101
"""Create an LM Studio websocket authentication message."""
2090-
# Note: authentication (in its current form) is primarily a cooperative
2102+
# Note: the authentication fields are used for two distinct purposes.
2103+
# When extracted from an API token (see _create_auth_message below),
2104+
# they are an actual authentication & authorisation mechanism.
2105+
# When generated internally by the SDK, they are instead a cooperative
20912106
# resource management mechanism that allows the server to appropriately
20922107
# manage client-scoped resources (such as temporary file handles).
2093-
# As such, the client ID and client passkey are currently more a two part
2094-
# client identifier than they are an adversarial security measure. This is
2095-
# sufficient to prevent accidental conflicts and, in combination with secure
2096-
# websocket support, would be sufficient to ensure that access to the running
2097-
# client was required to extract the auth details.
2098-
client_identifier = client_id if client_id is not None else str(uuid.uuid4())
2108+
# As such, when the API host isn't configured to require API tokens,
2109+
# the client ID and client key are more a two part client
2110+
# identifier than they are an adversarial security measure.
2111+
client_identifier = (
2112+
client_id if client_id is not None else f"guest:{str(uuid.uuid4())}"
2113+
)
20992114
client_passkey = client_key if client_key is not None else str(uuid.uuid4())
21002115
return {
21012116
"authVersion": 1,
21022117
"clientIdentifier": client_identifier,
21032118
"clientPasskey": client_passkey,
21042119
}
21052120

2106-
def _create_auth_message(self) -> DictObject:
2121+
@classmethod
2122+
def _create_auth_from_token(cls, api_token: str | None) -> DictObject:
2123+
"""Create an LM Studio websocket auth message from an API token.
2124+
2125+
If no token is given, and none is set in the environment,
2126+
falls back to generating a client scoped guest identifier
2127+
"""
2128+
if api_token is None:
2129+
api_token = os.environ.get(_ENV_API_TOKEN, None)
2130+
if api_token: # Accept empty string as equivalent to None
2131+
match = _LMS_API_TOKEN_REGEX.match(api_token.strip())
2132+
if match is None:
2133+
raise LMStudioValueError(
2134+
"The api_token argument does not look like a valid LM Studio API token.\n\n"
2135+
"LM Studio API tokens are obtained from LM Studio, and they look like this:\n"
2136+
"sk-lm-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx."
2137+
)
2138+
groups = match.groupdict()
2139+
client_identifier = groups.get("clientIdentifier")
2140+
client_passkey = groups.get("clientPasskey")
2141+
if client_identifier is None or client_passkey is None:
2142+
raise LMStudioValueError(
2143+
"Unexpected error parsing api_token: required token fields were not detected."
2144+
)
2145+
return cls._format_auth_message(client_identifier, client_passkey)
2146+
2147+
return cls._format_auth_message()
2148+
2149+
def _create_auth_message(self, api_token: str | None = None) -> DictObject:
21072150
"""Create an LM Studio websocket authentication message."""
2108-
return self._format_auth_message()
2151+
# This is an instance method purely so subclasses may override it
2152+
return self._create_auth_from_token(api_token)
21092153

21102154

21112155
TClient = TypeVar("TClient", bound=ClientBase)

src/lmstudio/plugin/runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ def __init__(
103103
_AsyncSessionPlugins,
104104
)
105105

106-
def _create_auth_message(self) -> DictObject:
106+
def _create_auth_message(self, api_token: str | None = None) -> DictObject:
107107
"""Create an LM Studio websocket authentication message."""
108108
if self._client_id is None or self._client_key is None:
109-
return super()._create_auth_message()
109+
return super()._create_auth_message(api_token)
110110
# Use plugin credentials to unlock the full plugin client API
111111
return self._format_auth_message(self._client_id, self._client_key)
112112

src/lmstudio/sync_api.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,9 +1540,11 @@ def embed(
15401540
class Client(ClientBase):
15411541
"""Synchronous SDK client interface."""
15421542

1543-
def __init__(self, api_host: str | None = None) -> None:
1543+
def __init__(
1544+
self, api_host: str | None = None, api_token: str | None = None
1545+
) -> None:
15441546
"""Initialize API client."""
1545-
super().__init__(api_host)
1547+
super().__init__(api_host, api_token)
15461548
self._resources = rm = ExitStack()
15471549
self._ws_thread = ws_thread = AsyncWebsocketThread(dict(client=repr(self)))
15481550
ws_thread.start()
@@ -1699,40 +1701,43 @@ def list_loaded_models(
16991701

17001702

17011703
# Convenience API
1702-
_default_api_host = None
1704+
_default_api_host: str | None = None
1705+
_default_api_token: str | None = None
17031706
_default_client: Client | None = None
17041707

17051708

17061709
@sdk_public_api()
1707-
def configure_default_client(api_host: str) -> None:
1710+
def configure_default_client(api_host: str, api_token: str | None = None) -> None:
17081711
"""Set the server API host for the default global client (without creating the client)."""
17091712
global _default_api_host
17101713
if _default_client is not None:
17111714
raise LMStudioClientError(
1712-
"Default client is already created, cannot set its API host."
1715+
"Default client is already created, cannot set its API host or token."
17131716
)
17141717
_default_api_host = api_host
1718+
_default_api_token = api_token
17151719

17161720

17171721
@sdk_public_api()
17181722
def get_default_client(api_host: str | None = None) -> Client:
17191723
"""Get the default global client (creating it if necessary)."""
1724+
# Note: call configure_default_client() explicitly to set the API token
17201725
global _default_client
17211726
if api_host is not None:
17221727
# This will raise an exception if the client already exists
17231728
configure_default_client(api_host)
17241729
if _default_client is None:
1725-
_default_client = Client(_default_api_host)
1730+
_default_client = Client(_default_api_host, _default_api_token)
17261731
_default_client._ensure_api_host_is_valid()
17271732
return _default_client
17281733

17291734

17301735
def _reset_default_client() -> None:
17311736
# Allow the test suite to reset the client without
17321737
# having to poke directly at the module's internals
1733-
global _default_api_host, _default_client
1738+
global _default_api_host, _default_api_token, _default_client
17341739
previous_client = _default_client
1735-
_default_api_host = _default_client = None
1740+
_default_api_host = _default_api_token = _default_client = None
17361741
if previous_client is not None:
17371742
previous_client.close()
17381743

tests/test_sessions.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
"""Test common client session behaviour."""
22

33
import logging
4+
import os
5+
46
from typing import Generator
7+
from unittest import mock
58

69
import pytest
710
from pytest import LogCaptureFixture as LogCap
811

912
from lmstudio import (
1013
AsyncClient,
1114
Client,
15+
LMStudioValueError,
1216
LMStudioWebsocketError,
1317
)
1418
from lmstudio.async_api import (
1519
_AsyncLMStudioWebsocket,
1620
_AsyncSession,
1721
_AsyncSessionSystem,
1822
)
23+
from lmstudio.json_api import ClientBase
1924
from lmstudio.sync_api import (
2025
SyncLMStudioWebsocket,
2126
_SyncSession,
@@ -24,6 +29,93 @@
2429
from lmstudio._ws_impl import AsyncTaskManager
2530
from lmstudio._ws_thread import AsyncWebsocketThread
2631

32+
# This API token is structurally valid
33+
_VALID_API_TOKEN = "sk-lm-abcDEF78:abcDEF7890abcDEF7890"
34+
35+
36+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
37+
def test_auth_message_default(client_cls: ClientBase) -> None:
38+
with mock.patch.dict(os.environ) as env:
39+
env.pop("LMSTUDIO_API_TOKEN", None)
40+
auth_message = client_cls._create_auth_from_token(None)
41+
assert auth_message["authVersion"] == 1
42+
assert auth_message["clientIdentifier"].startswith("guest:")
43+
client_key = auth_message["clientPasskey"]
44+
assert client_key != ""
45+
assert isinstance(client_key, str)
46+
47+
48+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
49+
def test_auth_message_empty_token(client_cls: ClientBase) -> None:
50+
with mock.patch.dict(os.environ) as env:
51+
# Set a valid token in the env to ensure it is ignored
52+
env["LMSTUDIO_API_TOKEN"] = _VALID_API_TOKEN
53+
auth_message = client_cls._create_auth_from_token("")
54+
assert auth_message["authVersion"] == 1
55+
assert auth_message["clientIdentifier"].startswith("guest:")
56+
client_key = auth_message["clientPasskey"]
57+
assert client_key != ""
58+
assert isinstance(client_key, str)
59+
60+
61+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
62+
def test_auth_message_empty_token_from_env(client_cls: ClientBase) -> None:
63+
with mock.patch.dict(os.environ) as env:
64+
env["LMSTUDIO_API_TOKEN"] = ""
65+
auth_message = client_cls._create_auth_from_token(None)
66+
assert auth_message["authVersion"] == 1
67+
assert auth_message["clientIdentifier"].startswith("guest:")
68+
client_key = auth_message["clientPasskey"]
69+
assert client_key != ""
70+
assert isinstance(client_key, str)
71+
72+
73+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
74+
def test_auth_message_valid_token(client_cls: ClientBase) -> None:
75+
auth_message = client_cls._create_auth_from_token(_VALID_API_TOKEN)
76+
assert auth_message["authVersion"] == 1
77+
assert auth_message["clientIdentifier"] == "abcDEF78"
78+
assert auth_message["clientPasskey"] == "abcDEF7890abcDEF7890"
79+
80+
81+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
82+
def test_auth_message_valid_token_from_env(client_cls: ClientBase) -> None:
83+
with mock.patch.dict(os.environ) as env:
84+
env["LMSTUDIO_API_TOKEN"] = _VALID_API_TOKEN
85+
auth_message = client_cls._create_auth_from_token(None)
86+
assert auth_message["authVersion"] == 1
87+
assert auth_message["clientIdentifier"] == "abcDEF78"
88+
assert auth_message["clientPasskey"] == "abcDEF7890abcDEF7890"
89+
90+
91+
_INVALID_TOKENS = [
92+
"missing-token-prefix",
93+
"sk-lm-missing-id-and-key-separator",
94+
"sk-lm-invalid_id:invalid_key",
95+
"sk-lm-idtoolong:abcDEF7890abcDEF7890",
96+
"sk-lm-abcDEF78:keytooshort",
97+
]
98+
99+
100+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
101+
@pytest.mark.parametrize("api_token", _INVALID_TOKENS)
102+
def test_auth_message_invalid_token(client_cls: ClientBase, api_token: str) -> None:
103+
with mock.patch.dict(os.environ) as env:
104+
env["LMSTUDIO_API_TOKEN"] = _VALID_API_TOKEN
105+
with pytest.raises(LMStudioValueError):
106+
client_cls._create_auth_from_token(api_token)
107+
108+
109+
@pytest.mark.parametrize("client_cls", [AsyncClient, Client])
110+
@pytest.mark.parametrize("api_token", _INVALID_TOKENS)
111+
def test_auth_message_invalid_token_from_env(
112+
client_cls: ClientBase, api_token: str
113+
) -> None:
114+
with mock.patch.dict(os.environ) as env:
115+
env["LMSTUDIO_API_TOKEN"] = api_token
116+
with pytest.raises(LMStudioValueError):
117+
client_cls._create_auth_from_token(None)
118+
27119

28120
async def check_connected_async_session(session: _AsyncSession) -> None:
29121
assert session.connected
@@ -160,7 +252,7 @@ def test_implicit_reconnection_sync(caplog: LogCap) -> None:
160252
async def test_websocket_cm_async(caplog: LogCap) -> None:
161253
caplog.set_level(logging.DEBUG)
162254
api_host = await AsyncClient.find_default_local_api_host()
163-
auth_details = AsyncClient._format_auth_message()
255+
auth_details = AsyncClient._create_auth_from_token(None)
164256
tm = AsyncTaskManager(on_activation=None)
165257
lmsws = _AsyncLMStudioWebsocket(tm, f"http://{api_host}/system", auth_details)
166258
# SDK client websockets start out disconnected
@@ -200,7 +292,7 @@ def ws_thread() -> Generator[AsyncWebsocketThread, None, None]:
200292
def test_websocket_cm_sync(ws_thread: AsyncWebsocketThread, caplog: LogCap) -> None:
201293
caplog.set_level(logging.DEBUG)
202294
api_host = Client.find_default_local_api_host()
203-
auth_details = Client._format_auth_message()
295+
auth_details = Client._create_auth_from_token(None)
204296
lmsws = SyncLMStudioWebsocket(ws_thread, f"http://{api_host}/system", auth_details)
205297
# SDK client websockets start out disconnected
206298
assert not lmsws.connected

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ allowlist_externals = pytest
1919
passenv =
2020
CI
2121
LMS_*
22+
LMSTUDIO_*
2223
commands =
2324
# Even the "slow" tests aren't absurdly slow, so default to running them
2425
pytest {posargs} tests/

0 commit comments

Comments
 (0)