From 14fcceb0078f0aaa2cb11aaa1e3e53132f755d50 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 22 Jan 2026 12:02:10 +0500 Subject: [PATCH 01/13] Implement pagination for /api/projects/list --- .../_internal/server/routers/projects.py | 10 +- src/dstack/_internal/server/schemas/fleets.py | 6 +- .../_internal/server/schemas/projects.py | 8 +- .../_internal/server/services/projects.py | 92 +++++++++++-------- 4 files changed, 73 insertions(+), 43 deletions(-) diff --git a/src/dstack/_internal/server/routers/projects.py b/src/dstack/_internal/server/routers/projects.py index b07b7b1c62..0b4f122103 100644 --- a/src/dstack/_internal/server/routers/projects.py +++ b/src/dstack/_internal/server/routers/projects.py @@ -52,10 +52,16 @@ async def list_projects( """ if body is None: # For backward compatibility - body = ListProjectsRequest() + body = ListProjectsRequest(limit=2000) return CustomORJSONResponse( await projects.list_user_accessible_projects( - session=session, user=user, include_not_joined=body.include_not_joined + session=session, + user=user, + include_not_joined=body.include_not_joined, + prev_created_at=body.prev_created_at, + prev_id=body.prev_id, + limit=body.limit, + ascending=body.ascending, ) ) diff --git a/src/dstack/_internal/server/schemas/fleets.py b/src/dstack/_internal/server/schemas/fleets.py index ae66818ab9..3df43d12ce 100644 --- a/src/dstack/_internal/server/schemas/fleets.py +++ b/src/dstack/_internal/server/schemas/fleets.py @@ -9,10 +9,10 @@ class ListFleetsRequest(CoreModel): - project_name: Optional[str] + project_name: Optional[str] = None only_active: bool = False - prev_created_at: Optional[datetime] - prev_id: Optional[UUID] + prev_created_at: Optional[datetime] = None + prev_id: Optional[UUID] = None limit: int = Field(100, ge=0, le=100) ascending: bool = False diff --git a/src/dstack/_internal/server/schemas/projects.py b/src/dstack/_internal/server/schemas/projects.py index ec05c1fb47..529f07cd2a 100644 --- a/src/dstack/_internal/server/schemas/projects.py +++ b/src/dstack/_internal/server/schemas/projects.py @@ -1,4 +1,6 @@ -from typing import Annotated, List +from datetime import datetime +from typing import Annotated, List, Optional +from uuid import UUID from pydantic import Field @@ -10,6 +12,10 @@ class ListProjectsRequest(CoreModel): include_not_joined: Annotated[ bool, Field(description="Include public projects where user is not a member") ] = True + prev_created_at: Optional[datetime] = None + prev_id: Optional[UUID] = None + limit: int = Field(2000, ge=0, le=2000) + ascending: bool = False class CreateProjectRequest(CoreModel): diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 937247f5a1..022c9eaa61 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -1,8 +1,9 @@ import secrets import uuid +from datetime import datetime from typing import Awaitable, Callable, List, Optional, Tuple -from sqlalchemy import delete, select, update +from sqlalchemy import and_, delete, or_, select, update from sqlalchemy import func as safunc from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import QueryableAttribute, joinedload, load_only @@ -62,56 +63,73 @@ async def get_or_create_default_project( return default_project, True -async def list_user_projects( - session: AsyncSession, - user: UserModel, -) -> List[Project]: - """ - Returns projects where the user is a member or all projects for global admins. - """ - projects = await list_user_project_models( - session=session, - user=user, - ) - projects = sorted(projects, key=lambda p: p.created_at) - return [ - project_model_to_project(p, include_backends=False, include_members=False) - for p in projects - ] - - async def list_user_accessible_projects( session: AsyncSession, user: UserModel, include_not_joined: bool, + prev_created_at: Optional[datetime], + prev_id: Optional[uuid.UUID], + limit: int, + ascending: bool, ) -> List[Project]: """ Returns all projects accessible to the user: + - All projects for global admins - Projects where user is a member (public or private) - if `include_not_joined`: Public projects where user is NOT a member """ - if user.global_role == GlobalRole.ADMIN: - projects = await list_project_models(session=session) - else: - projects = await list_member_project_models(session=session, user=user) + stmt = select(ProjectModel).where(ProjectModel.deleted == False) + if user.global_role != GlobalRole.ADMIN: + stmt = stmt.outerjoin( + MemberModel, + onclause=and_( + MemberModel.project_id == ProjectModel.id, + MemberModel.user_id == user.id, + ), + ) if include_not_joined: - public_projects = await list_public_non_member_project_models( - session=session, user=user + stmt = stmt.where( + or_( + ProjectModel.is_public == True, + MemberModel.user_id.is_not(None), + ) ) - projects += public_projects - - projects = sorted(projects, key=lambda p: p.created_at) - return [ - project_model_to_project(p, include_backends=False, include_members=False) - for p in projects - ] - - -async def list_projects(session: AsyncSession) -> List[Project]: - projects = await list_project_models(session=session) + else: + stmt = stmt.where(MemberModel.user_id.is_not(None)) + pagination_filters = [] + if prev_created_at is not None: + if ascending: + if prev_id is None: + pagination_filters.append(ProjectModel.created_at > prev_created_at) + else: + pagination_filters.append( + or_( + ProjectModel.created_at > prev_created_at, + and_( + ProjectModel.created_at == prev_created_at, ProjectModel.id < prev_id + ), + ) + ) + else: + if prev_id is None: + pagination_filters.append(ProjectModel.created_at < prev_created_at) + else: + pagination_filters.append( + or_( + ProjectModel.created_at < prev_created_at, + and_( + ProjectModel.created_at == prev_created_at, ProjectModel.id > prev_id + ), + ) + ) + order_by = (ProjectModel.created_at.desc(), ProjectModel.id) + if ascending: + order_by = (ProjectModel.created_at.asc(), ProjectModel.id.desc()) + res = await session.execute(stmt.where(*pagination_filters).order_by(*order_by).limit(limit)) + project_models = res.scalars().all() return [ project_model_to_project(p, include_backends=False, include_members=False) - for p in projects + for p in project_models ] From 3f80d9c3b57a2d688995b645cdf3f370af94297d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 22 Jan 2026 12:20:43 +0500 Subject: [PATCH 02/13] Test /api/projects/list pagination --- .../_internal/server/routers/projects.py | 2 +- .../_internal/server/services/projects.py | 2 +- .../_internal/server/routers/test_projects.py | 128 +++++++++++++++++- 3 files changed, 125 insertions(+), 7 deletions(-) diff --git a/src/dstack/_internal/server/routers/projects.py b/src/dstack/_internal/server/routers/projects.py index 0b4f122103..87f4915534 100644 --- a/src/dstack/_internal/server/routers/projects.py +++ b/src/dstack/_internal/server/routers/projects.py @@ -43,7 +43,7 @@ async def list_projects( user: UserModel = Depends(Authenticated()), ): """ - Returns projects visible to the user, sorted by ascending `created_at`. + Returns projects visible to the user. Returns all accessible projects (member projects for regular users, all non-deleted projects for global admins, plus public projects if `include_not_joined` is `True`). diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 022c9eaa61..fecd33ff56 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -126,7 +126,7 @@ async def list_user_accessible_projects( if ascending: order_by = (ProjectModel.created_at.asc(), ProjectModel.id.desc()) res = await session.execute(stmt.where(*pagination_filters).order_by(*order_by).limit(limit)) - project_models = res.scalars().all() + project_models = res.unique().scalars().all() return [ project_model_to_project(p, include_backends=False, include_members=False) for p in project_models diff --git a/src/tests/_internal/server/routers/test_projects.py b/src/tests/_internal/server/routers/test_projects.py index 5c9ef42ffb..79e38762c3 100644 --- a/src/tests/_internal/server/routers/test_projects.py +++ b/src/tests/_internal/server/routers/test_projects.py @@ -15,7 +15,6 @@ from dstack._internal.server.services.permissions import DefaultPermissions from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( - create_backend, create_fleet, create_project, create_repo, @@ -65,10 +64,6 @@ async def test_returns_projects(self, test_db, session: AsyncSession, client: As await add_project_member( session=session, project=project, user=user, project_role=ProjectRole.ADMIN ) - await create_backend( - session=session, - project_id=project.id, - ) response = await client.post("/api/projects/list", headers=get_auth_headers(user.token)) assert response.status_code in [200] assert response.json() == [ @@ -216,6 +211,129 @@ async def test_member_sees_both_public_and_private_projects( assert "public_project" in project_names assert "private_project" in project_names + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_paginated_projects( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user( + session=session, + created_at=datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc), + global_role=GlobalRole.ADMIN, + ) + project1 = await create_project( + session=session, + name="project1", + owner=user, + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + ) + project2 = await create_project( + session=session, + name="project2", + owner=user, + created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + ) + project3 = await create_project( + session=session, + name="project3", + owner=user, + created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + ) + response = await client.post( + "/api/projects/list", + headers=get_auth_headers(user.token), + json={"limit": 1}, + ) + assert response.status_code == 200 + assert response.json() == [ + { + "project_id": str(project3.id), + "project_name": project3.name, + "owner": { + "id": str(user.id), + "username": user.name, + "created_at": "2023-01-02T03:00:00+00:00", + "global_role": user.global_role, + "email": None, + "active": True, + "permissions": { + "can_create_projects": True, + }, + "ssh_public_key": None, + }, + "created_at": "2023-01-02T03:05:00+00:00", + "backends": [], + "members": [], + "is_public": False, + } + ] + response = await client.post( + "/api/projects/list", + headers=get_auth_headers(user.token), + json={ + "prev_created_at": "2023-01-02T03:05:00+00:00", + "prev_id": str(project3.id), + "limit": 1, + }, + ) + assert response.status_code == 200 + assert response.json() == [ + { + "project_id": str(project2.id), + "project_name": project2.name, + "owner": { + "id": str(user.id), + "username": user.name, + "created_at": "2023-01-02T03:00:00+00:00", + "global_role": user.global_role, + "email": None, + "active": True, + "permissions": { + "can_create_projects": True, + }, + "ssh_public_key": None, + }, + "created_at": "2023-01-02T03:05:00+00:00", + "backends": [], + "members": [], + "is_public": False, + } + ] + response = await client.post( + "/api/projects/list", + headers=get_auth_headers(user.token), + json={ + "prev_created_at": "2023-01-02T03:05:00+00:00", + "prev_id": str(project2.id), + "limit": 1, + }, + ) + assert response.status_code == 200 + assert response.json() == [ + { + "project_id": str(project1.id), + "project_name": project1.name, + "owner": { + "id": str(user.id), + "username": user.name, + "created_at": "2023-01-02T03:00:00+00:00", + "global_role": user.global_role, + "email": None, + "active": True, + "permissions": { + "can_create_projects": True, + }, + "ssh_public_key": None, + }, + "created_at": "2023-01-02T03:04:00+00:00", + "backends": [], + "members": [], + "is_public": False, + } + ] + + +class TestListOnlyNoFleets: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_only_no_fleets_returns_projects_without_active_fleets( From 7f21ffeb7674f991fad4986bd8fcfeed8de152f8 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 22 Jan 2026 15:26:51 +0500 Subject: [PATCH 03/13] Update ProjectsAPIClient.list() --- .../_internal/server/routers/projects.py | 2 +- .../_internal/server/schemas/projects.py | 2 +- src/dstack/api/server/_projects.py | 30 +++++++++++++++---- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/dstack/_internal/server/routers/projects.py b/src/dstack/_internal/server/routers/projects.py index 87f4915534..00e59e6699 100644 --- a/src/dstack/_internal/server/routers/projects.py +++ b/src/dstack/_internal/server/routers/projects.py @@ -52,7 +52,7 @@ async def list_projects( """ if body is None: # For backward compatibility - body = ListProjectsRequest(limit=2000) + body = ListProjectsRequest() return CustomORJSONResponse( await projects.list_user_accessible_projects( session=session, diff --git a/src/dstack/_internal/server/schemas/projects.py b/src/dstack/_internal/server/schemas/projects.py index 529f07cd2a..c9695fbb18 100644 --- a/src/dstack/_internal/server/schemas/projects.py +++ b/src/dstack/_internal/server/schemas/projects.py @@ -14,7 +14,7 @@ class ListProjectsRequest(CoreModel): ] = True prev_created_at: Optional[datetime] = None prev_id: Optional[UUID] = None - limit: int = Field(2000, ge=0, le=2000) + limit: int = Field(default=2000, ge=0, le=2000) ascending: bool = False diff --git a/src/dstack/api/server/_projects.py b/src/dstack/api/server/_projects.py index 31bdc3b2de..70f9435303 100644 --- a/src/dstack/api/server/_projects.py +++ b/src/dstack/api/server/_projects.py @@ -1,4 +1,7 @@ -from typing import List +import json +from datetime import datetime +from typing import Any, List, Optional +from uuid import UUID from pydantic import parse_obj_as @@ -8,7 +11,6 @@ AddProjectMemberRequest, CreateProjectRequest, DeleteProjectsRequest, - ListProjectsRequest, MemberSetting, RemoveProjectMemberRequest, SetProjectMembersRequest, @@ -17,9 +19,27 @@ class ProjectsAPIClient(APIClientGroup): - def list(self, include_not_joined: bool = True) -> List[Project]: - body = ListProjectsRequest(include_not_joined=include_not_joined) - resp = self._request("/api/projects/list", body=body.json()) + def list( + self, + include_not_joined: bool = True, + prev_created_at: Optional[datetime] = None, + prev_id: Optional[UUID] = None, + limit: Optional[int] = None, + ascending: Optional[bool] = None, + ) -> List[Project]: + # Excluding None fields for backward compatibility with 0.20 servers. + body: dict[str, Any] = { + "include_not_joined": include_not_joined, + } + if prev_created_at is not None: + body["prev_created_at"] = prev_created_at + if prev_id is not None: + body["prev_id"] = prev_id + if limit is not None: + body["limit"] = limit + if ascending is not None: + body["ascending"] = ascending + resp = self._request("/api/projects/list", body=json.dumps(body)) return parse_obj_as(List[Project.__response__], resp.json()) def create(self, project_name: str, is_public: bool = False) -> Project: From 1dd378f1be231a21b331bdbfbef5f5185a8e9995 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 22 Jan 2026 16:19:20 +0500 Subject: [PATCH 04/13] Return projects total_count --- src/dstack/_internal/core/models/projects.py | 12 +++- .../_internal/server/routers/projects.py | 5 +- .../_internal/server/schemas/projects.py | 33 +++++++++-- .../_internal/server/services/projects.py | 16 ++++- src/dstack/api/server/_projects.py | 17 ++++-- .../_internal/server/routers/test_projects.py | 58 ++++++++++++++++++- 6 files changed, 123 insertions(+), 18 deletions(-) diff --git a/src/dstack/_internal/core/models/projects.py b/src/dstack/_internal/core/models/projects.py index 9748ece1ae..63adf91962 100644 --- a/src/dstack/_internal/core/models/projects.py +++ b/src/dstack/_internal/core/models/projects.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Optional +from typing import List, Optional, Union from pydantic import UUID4 @@ -28,6 +28,16 @@ class Project(CoreModel): is_public: bool = False +class ProjectsInfoList(CoreModel): + total_count: Optional[int] = None + projects: List[Project] + + +# For backward compatibility with 0.20 clients, endpoints return `List[Project]` if `total_count` is None. +# TODO: Replace with ProjectsInfoList in 0.21. +ProjectsInfoListOrProjectsList = Union[List[Project], ProjectsInfoList] + + class ProjectHookConfig(CoreModel): """ This class can be inherited to extend the project creation configuration passed to the hooks. diff --git a/src/dstack/_internal/server/routers/projects.py b/src/dstack/_internal/server/routers/projects.py index 00e59e6699..e6971e242a 100644 --- a/src/dstack/_internal/server/routers/projects.py +++ b/src/dstack/_internal/server/routers/projects.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.models.projects import Project +from dstack._internal.core.models.projects import Project, ProjectsInfoListOrProjectsList from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.projects import ( @@ -36,7 +36,7 @@ ) -@router.post("/list", response_model=List[Project]) +@router.post("/list", response_model=ProjectsInfoListOrProjectsList) async def list_projects( body: Optional[ListProjectsRequest] = None, session: AsyncSession = Depends(get_session), @@ -58,6 +58,7 @@ async def list_projects( session=session, user=user, include_not_joined=body.include_not_joined, + return_total_count=body.return_total_count, prev_created_at=body.prev_created_at, prev_id=body.prev_id, limit=body.limit, diff --git a/src/dstack/_internal/server/schemas/projects.py b/src/dstack/_internal/server/schemas/projects.py index c9695fbb18..bc54ef52df 100644 --- a/src/dstack/_internal/server/schemas/projects.py +++ b/src/dstack/_internal/server/schemas/projects.py @@ -10,12 +10,35 @@ class ListProjectsRequest(CoreModel): include_not_joined: Annotated[ - bool, Field(description="Include public projects where user is not a member") + bool, Field(description="Include public projects where user is not a member.") ] = True - prev_created_at: Optional[datetime] = None - prev_id: Optional[UUID] = None - limit: int = Field(default=2000, ge=0, le=2000) - ascending: bool = False + return_total_count: Annotated[ + bool, Field(description="Return `total_count` with the total number of projects.") + ] = False + prev_created_at: Annotated[ + Optional[datetime], + Field( + description="Paginate projects by specifying `created_at` of the last (first) project in previous batch for descending (ascending)." + ), + ] = None + prev_id: Annotated[ + Optional[UUID], + Field( + description=( + "Paginate projects by specifying `id` of the last (first) project in previous batch for descending (ascending)." + " Must be used together with `prev_created_at`." + ) + ), + ] = None + limit: Annotated[ + int, Field(ge=0, le=2000, description="Limit number of projects returned.") + ] = 2000 + ascending: Annotated[ + bool, + Field( + description="Return projects sorted by `created_at` in ascending order. Defaults to descending." + ), + ] = False class CreateProjectRequest(CoreModel): diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index fecd33ff56..a8313a843f 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Awaitable, Callable, List, Optional, Tuple -from sqlalchemy import and_, delete, or_, select, update +from sqlalchemy import and_, delete, func, literal_column, or_, select, update from sqlalchemy import func as safunc from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import QueryableAttribute, joinedload, load_only @@ -20,6 +20,8 @@ MemberPermissions, Project, ProjectHookConfig, + ProjectsInfoList, + ProjectsInfoListOrProjectsList, ) from dstack._internal.core.models.runs import RunStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole @@ -67,11 +69,12 @@ async def list_user_accessible_projects( session: AsyncSession, user: UserModel, include_not_joined: bool, + return_total_count: bool, prev_created_at: Optional[datetime], prev_id: Optional[uuid.UUID], limit: int, ascending: bool, -) -> List[Project]: +) -> ProjectsInfoListOrProjectsList: """ Returns all projects accessible to the user: - All projects for global admins @@ -125,12 +128,19 @@ async def list_user_accessible_projects( order_by = (ProjectModel.created_at.desc(), ProjectModel.id) if ascending: order_by = (ProjectModel.created_at.asc(), ProjectModel.id.desc()) + total_count = None + if return_total_count: + res = await session.execute(stmt.with_only_columns(func.count(literal_column("1")))) + total_count = res.scalar_one() res = await session.execute(stmt.where(*pagination_filters).order_by(*order_by).limit(limit)) project_models = res.unique().scalars().all() - return [ + projects = [ project_model_to_project(p, include_backends=False, include_members=False) for p in project_models ] + if total_count is None: + return projects + return ProjectsInfoList(total_count=total_count, projects=projects) async def get_project_by_name( diff --git a/src/dstack/api/server/_projects.py b/src/dstack/api/server/_projects.py index 70f9435303..496ba71f13 100644 --- a/src/dstack/api/server/_projects.py +++ b/src/dstack/api/server/_projects.py @@ -5,7 +5,11 @@ from pydantic import parse_obj_as -from dstack._internal.core.models.projects import Project +from dstack._internal.core.models.projects import ( + Project, + ProjectsInfoList, + ProjectsInfoListOrProjectsList, +) from dstack._internal.core.models.users import ProjectRole from dstack._internal.server.schemas.projects import ( AddProjectMemberRequest, @@ -22,15 +26,18 @@ class ProjectsAPIClient(APIClientGroup): def list( self, include_not_joined: bool = True, + return_total_count: Optional[bool] = None, prev_created_at: Optional[datetime] = None, prev_id: Optional[UUID] = None, limit: Optional[int] = None, ascending: Optional[bool] = None, - ) -> List[Project]: - # Excluding None fields for backward compatibility with 0.20 servers. + ) -> ProjectsInfoListOrProjectsList: + # Passing only non-None fields for backward compatibility with 0.20 servers. body: dict[str, Any] = { "include_not_joined": include_not_joined, } + if return_total_count is not None: + body["return_total_count"] = return_total_count if prev_created_at is not None: body["prev_created_at"] = prev_created_at if prev_id is not None: @@ -40,7 +47,9 @@ def list( if ascending is not None: body["ascending"] = ascending resp = self._request("/api/projects/list", body=json.dumps(body)) - return parse_obj_as(List[Project.__response__], resp.json()) + if return_total_count is None: + return parse_obj_as(List[Project.__response__], resp.json()) + return parse_obj_as(ProjectsInfoList, resp.json()) def create(self, project_name: str, is_public: bool = False) -> Project: body = CreateProjectRequest(project_name=project_name, is_public=is_public) diff --git a/src/tests/_internal/server/routers/test_projects.py b/src/tests/_internal/server/routers/test_projects.py index 79e38762c3..2da65c7d5c 100644 --- a/src/tests/_internal/server/routers/test_projects.py +++ b/src/tests/_internal/server/routers/test_projects.py @@ -237,7 +237,7 @@ async def test_returns_paginated_projects( session=session, name="project3", owner=user, - created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + created_at=datetime(2023, 1, 2, 3, 6, tzinfo=timezone.utc), ) response = await client.post( "/api/projects/list", @@ -261,7 +261,7 @@ async def test_returns_paginated_projects( }, "ssh_public_key": None, }, - "created_at": "2023-01-02T03:05:00+00:00", + "created_at": "2023-01-02T03:06:00+00:00", "backends": [], "members": [], "is_public": False, @@ -271,7 +271,7 @@ async def test_returns_paginated_projects( "/api/projects/list", headers=get_auth_headers(user.token), json={ - "prev_created_at": "2023-01-02T03:05:00+00:00", + "prev_created_at": "2023-01-02T03:06:00+00:00", "prev_id": str(project3.id), "limit": 1, }, @@ -332,6 +332,58 @@ async def test_returns_paginated_projects( } ] + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_total_count(self, test_db, session: AsyncSession, client: AsyncClient): + user = await create_user( + session=session, + created_at=datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc), + global_role=GlobalRole.ADMIN, + ) + await create_project( + session=session, + name="project1", + owner=user, + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + ) + project3 = await create_project( + session=session, + name="project3", + owner=user, + created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + ) + response = await client.post( + "/api/projects/list", + headers=get_auth_headers(user.token), + json={"limit": 1, "return_total_count": True}, + ) + assert response.status_code == 200 + assert response.json() == { + "total_count": 2, + "projects": [ + { + "project_id": str(project3.id), + "project_name": project3.name, + "owner": { + "id": str(user.id), + "username": user.name, + "created_at": "2023-01-02T03:00:00+00:00", + "global_role": user.global_role, + "email": None, + "active": True, + "permissions": { + "can_create_projects": True, + }, + "ssh_public_key": None, + }, + "created_at": "2023-01-02T03:05:00+00:00", + "backends": [], + "members": [], + "is_public": False, + } + ], + } + class TestListOnlyNoFleets: @pytest.mark.asyncio From add86dede46f0a61c7d2abb168ecd72fed68661e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 22 Jan 2026 16:43:53 +0500 Subject: [PATCH 05/13] Fix APIClient.list() --- src/dstack/api/server/_projects.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/dstack/api/server/_projects.py b/src/dstack/api/server/_projects.py index 496ba71f13..a85e73b121 100644 --- a/src/dstack/api/server/_projects.py +++ b/src/dstack/api/server/_projects.py @@ -39,17 +39,18 @@ def list( if return_total_count is not None: body["return_total_count"] = return_total_count if prev_created_at is not None: - body["prev_created_at"] = prev_created_at + body["prev_created_at"] = prev_created_at.isoformat() if prev_id is not None: - body["prev_id"] = prev_id + body["prev_id"] = str(prev_id) if limit is not None: body["limit"] = limit if ascending is not None: body["ascending"] = ascending resp = self._request("/api/projects/list", body=json.dumps(body)) - if return_total_count is None: - return parse_obj_as(List[Project.__response__], resp.json()) - return parse_obj_as(ProjectsInfoList, resp.json()) + resp_json = resp.json() + if isinstance(resp_json, list): + return parse_obj_as(List[Project.__response__], resp_json) + return parse_obj_as(ProjectsInfoList, resp_json) def create(self, project_name: str, is_public: bool = False) -> Project: body = CreateProjectRequest(project_name=project_name, is_public=is_public) From c16dbe32b9348a2fb3fd1702e5faca3d173a1da3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 22 Jan 2026 17:19:00 +0500 Subject: [PATCH 06/13] Test APIClient.list() --- src/tests/api/common.py | 29 ++++++++++++++++ src/tests/api/test_projects.py | 60 ++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) create mode 100644 src/tests/api/common.py create mode 100644 src/tests/api/test_projects.py diff --git a/src/tests/api/common.py b/src/tests/api/common.py new file mode 100644 index 0000000000..c453b6afee --- /dev/null +++ b/src/tests/api/common.py @@ -0,0 +1,29 @@ +import json +from dataclasses import dataclass, field +from typing import Any, Optional + +import requests + + +@dataclass +class RequestRecorder: + payload: Any + last_path: Optional[str] = None + last_body: Optional[str] = None + last_kwargs: dict[str, Any] = field(default_factory=dict) + + def __call__( + self, + path: str, + body: Optional[str] = None, + raise_for_status: bool = True, + method: str = "POST", + **kwargs, + ) -> requests.Response: + self.last_path = path + self.last_body = body + self.last_kwargs = kwargs + resp = requests.Response() + resp.status_code = 200 + resp._content = json.dumps(self.payload).encode("utf-8") + return resp diff --git a/src/tests/api/test_projects.py b/src/tests/api/test_projects.py new file mode 100644 index 0000000000..d84608750c --- /dev/null +++ b/src/tests/api/test_projects.py @@ -0,0 +1,60 @@ +import json +import logging +from datetime import datetime, timezone +from uuid import UUID + +from dstack.api.server._projects import ProjectsAPIClient +from tests.api.common import RequestRecorder + +PROJECT_PAYLOAD = { + "project_id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e", + "project_name": "p", + "owner": { + "id": "2b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e", + "username": "u", + "created_at": "2023-01-02T03:04:00+00:00", + "global_role": "user", + "email": None, + "active": True, + "permissions": {"can_create_projects": True}, + "ssh_public_key": None, + }, + "created_at": "2023-01-02T03:04:00+00:00", + "backends": [], + "members": [], + "is_public": False, +} + + +class TestProjectsList: + def test_projects_list_serializes_pagination_and_parses_total_count(self): + request = RequestRecorder(payload={"total_count": 1, "projects": [PROJECT_PAYLOAD]}) + client = ProjectsAPIClient(_request=request, _logger=logging.getLogger("test")) + dt = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) + pid = UUID("3b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e") + + result = client.list( + return_total_count=True, + prev_created_at=dt, + prev_id=pid, + limit=1, + ascending=True, + ) + + payload = json.loads(request.last_body) + assert request.last_path == "/api/projects/list" + assert payload["include_not_joined"] is True + assert payload["return_total_count"] is True + assert payload["prev_created_at"] == dt.isoformat() + assert payload["prev_id"] == str(pid) + assert payload["limit"] == 1 + assert payload["ascending"] is True + assert result.total_count == 1 + assert result.projects[0].project_name == "p" + + def test_projects_list_parses_list_response(self): + request = RequestRecorder(payload=[PROJECT_PAYLOAD]) + client = ProjectsAPIClient(_request=request, _logger=logging.getLogger("test")) + result = client.list() + assert isinstance(result, list) + assert result[0].project_name == PROJECT_PAYLOAD["project_name"] From 3b7b4d3f6efdfa253e068483ce76719f0c74dcdb Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 23 Jan 2026 11:04:16 +0500 Subject: [PATCH 07/13] Implement pagination for /api/users/list --- src/dstack/_internal/core/models/users.py | 12 +- src/dstack/_internal/server/routers/users.py | 31 ++++- src/dstack/_internal/server/schemas/users.py | 41 ++++++- src/dstack/_internal/server/services/users.py | 78 ++++++++++-- src/dstack/api/server/_users.py | 47 ++++++- .../_internal/server/routers/test_users.py | 116 +++++++++++++++++- 6 files changed, 304 insertions(+), 21 deletions(-) diff --git a/src/dstack/_internal/core/models/users.py b/src/dstack/_internal/core/models/users.py index 99fb8823e3..8e70e092d6 100644 --- a/src/dstack/_internal/core/models/users.py +++ b/src/dstack/_internal/core/models/users.py @@ -1,6 +1,6 @@ import enum from datetime import datetime -from typing import Optional +from typing import List, Optional, Union from pydantic import UUID4 @@ -42,6 +42,16 @@ class UserWithCreds(User): ssh_private_key: Optional[str] = None +class UsersInfoList(CoreModel): + total_count: Optional[int] = None + users: List[User] + + +# For backward compatibility with 0.20 clients, endpoints return `List[User]` if `total_count` is None. +# TODO: Replace with UsersInfoList in 0.21. +UsersInfoListOrUsersList = Union[List[User], UsersInfoList] + + class UserHookConfig(CoreModel): """ This class can be inherited to extend the user creation configuration passed to the hooks. diff --git a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py index 6030416f50..4c20038d1d 100644 --- a/src/dstack/_internal/server/routers/users.py +++ b/src/dstack/_internal/server/routers/users.py @@ -1,16 +1,17 @@ -from typing import List +from typing import Optional from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.errors import ResourceNotExistsError -from dstack._internal.core.models.users import User, UserWithCreds +from dstack._internal.core.models.users import User, UsersInfoListOrUsersList, UserWithCreds from dstack._internal.server.db import get_session from dstack._internal.server.models import UserModel from dstack._internal.server.schemas.users import ( CreateUserRequest, DeleteUsersRequest, GetUserRequest, + ListUsersRequest, RefreshTokenRequest, UpdateUserRequest, ) @@ -28,12 +29,34 @@ ) -@router.post("/list", response_model=List[User]) +@router.post("/list", response_model=UsersInfoListOrUsersList) async def list_users( + body: Optional[ListUsersRequest] = None, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), ): - return CustomORJSONResponse(await users.list_users_for_user(session=session, user=user)) + """ + Returns users visible to the user, sorted by descending `created_at`. + + Admins see all non-deleted users. Non-admins only see themselves. + + The results are paginated. To get the next page, pass `created_at` and `id` of + the last user from the previous page as `prev_created_at` and `prev_id`. + """ + if body is None: + # For backward compatibility + body = ListUsersRequest() + return CustomORJSONResponse( + await users.list_users_for_user( + session=session, + user=user, + return_total_count=body.return_total_count, + prev_created_at=body.prev_created_at, + prev_id=body.prev_id, + limit=body.limit, + ascending=body.ascending, + ) + ) @router.post("/get_my_user", response_model=UserWithCreds) diff --git a/src/dstack/_internal/server/schemas/users.py b/src/dstack/_internal/server/schemas/users.py index 6579d96572..17f3aca236 100644 --- a/src/dstack/_internal/server/schemas/users.py +++ b/src/dstack/_internal/server/schemas/users.py @@ -1,9 +1,48 @@ -from typing import List, Optional +from datetime import datetime +from typing import Annotated, List, Optional +from uuid import UUID + +from pydantic import Field from dstack._internal.core.models.common import CoreModel from dstack._internal.core.models.users import GlobalRole +class ListUsersRequest(CoreModel): + return_total_count: Annotated[ + bool, Field(description="Return `total_count` with the total number of users.") + ] = False + prev_created_at: Annotated[ + Optional[datetime], + Field( + description=( + "Paginate users by specifying `created_at` of the last (first) user in previous " + "batch for descending (ascending)." + ) + ), + ] = None + prev_id: Annotated[ + Optional[UUID], + Field( + description=( + "Paginate users by specifying `id` of the last (first) user in previous batch " + "for descending (ascending). Must be used together with `prev_created_at`." + ) + ), + ] = None + limit: Annotated[int, Field(ge=0, le=2000, description="Limit number of users returned.")] = ( + 2000 + ) + ascending: Annotated[ + bool, + Field( + description=( + "Return users sorted by `created_at` in ascending order. Defaults to descending." + ) + ), + ] = False + + class GetUserRequest(CoreModel): username: str diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index 3f8f6afa7b..a35ba9c694 100644 --- a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -5,9 +5,10 @@ import uuid from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from datetime import datetime from typing import Awaitable, Callable, List, Optional, Tuple -from sqlalchemy import delete, select +from sqlalchemy import and_, delete, literal_column, or_, select from sqlalchemy import func as safunc from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import load_only @@ -21,6 +22,8 @@ User, UserHookConfig, UserPermissions, + UsersInfoList, + UsersInfoListOrUsersList, UserTokenCreds, UserWithCreds, ) @@ -55,23 +58,82 @@ async def get_or_create_admin_user(session: AsyncSession) -> Tuple[UserModel, bo async def list_users_for_user( session: AsyncSession, user: UserModel, -) -> List[User]: + return_total_count: bool, + prev_created_at: Optional[datetime], + prev_id: Optional[uuid.UUID], + limit: int, + ascending: bool, +) -> UsersInfoListOrUsersList: if user.global_role == GlobalRole.ADMIN: - return await list_all_users(session=session) - return [user_model_to_user(user)] + return await list_all_users( + session=session, + include_deleted=False, + return_total_count=return_total_count, + prev_created_at=prev_created_at, + prev_id=prev_id, + limit=limit, + ascending=ascending, + ) + users = [user_model_to_user(user)] + if return_total_count: + return UsersInfoList(total_count=len(users), users=users) + return users async def list_all_users( session: AsyncSession, include_deleted: bool = False, -) -> List[User]: + return_total_count: bool = False, + prev_created_at: Optional[datetime] = None, + prev_id: Optional[uuid.UUID] = None, + limit: int = 2000, + ascending: bool = False, +) -> UsersInfoListOrUsersList: filters = [] if not include_deleted: filters.append(UserModel.deleted == False) - res = await session.execute(select(UserModel).where(*filters)) + stmt = select(UserModel).where(*filters) + pagination_filters = [] + if prev_created_at is not None: + if ascending: + if prev_id is None: + pagination_filters.append(UserModel.created_at > prev_created_at) + else: + pagination_filters.append( + or_( + UserModel.created_at > prev_created_at, + and_( + UserModel.created_at == prev_created_at, + UserModel.id < prev_id, + ), + ) + ) + else: + if prev_id is None: + pagination_filters.append(UserModel.created_at < prev_created_at) + else: + pagination_filters.append( + or_( + UserModel.created_at < prev_created_at, + and_( + UserModel.created_at == prev_created_at, + UserModel.id > prev_id, + ), + ) + ) + order_by = (UserModel.created_at.desc(), UserModel.id) + if ascending: + order_by = (UserModel.created_at.asc(), UserModel.id.desc()) + total_count = None + if return_total_count: + res = await session.execute(stmt.with_only_columns(safunc.count(literal_column("1")))) + total_count = res.scalar_one() + res = await session.execute(stmt.where(*pagination_filters).order_by(*order_by).limit(limit)) user_models = res.scalars().all() - user_models = sorted(user_models, key=lambda u: u.created_at) - return [user_model_to_user(u) for u in user_models] + users = [user_model_to_user(u) for u in user_models] + if total_count is None: + return users + return UsersInfoList(total_count=total_count, users=users) async def get_user_with_creds_by_name( diff --git a/src/dstack/api/server/_users.py b/src/dstack/api/server/_users.py index 6082636c4b..a60b45a11d 100644 --- a/src/dstack/api/server/_users.py +++ b/src/dstack/api/server/_users.py @@ -1,8 +1,18 @@ -from typing import List +import json +from datetime import datetime +from typing import Any, List, Optional +from uuid import UUID from pydantic import parse_obj_as +from pydantic.json import pydantic_encoder -from dstack._internal.core.models.users import GlobalRole, User, UserWithCreds +from dstack._internal.core.models.users import ( + GlobalRole, + User, + UsersInfoList, + UsersInfoListOrUsersList, + UserWithCreds, +) from dstack._internal.server.schemas.users import ( CreateUserRequest, GetUserRequest, @@ -13,9 +23,36 @@ class UsersAPIClient(APIClientGroup): - def list(self) -> List[User]: - resp = self._request("/api/users/list") - return parse_obj_as(List[User.__response__], resp.json()) + def list( + self, + return_total_count: Optional[bool] = None, + prev_created_at: Optional[datetime] = None, + prev_id: Optional[UUID] = None, + limit: Optional[int] = None, + ascending: Optional[bool] = None, + ) -> UsersInfoListOrUsersList: + # Passing only non-None fields for backward compatibility with 0.20 servers. + body: dict[str, Any] = {} + if return_total_count is not None: + body["return_total_count"] = return_total_count + if prev_created_at is not None: + body["prev_created_at"] = prev_created_at + if prev_id is not None: + body["prev_id"] = prev_id + if limit is not None: + body["limit"] = limit + if ascending is not None: + body["ascending"] = ascending + if body: + resp = self._request( + "/api/users/list", body=json.dumps(body, default=pydantic_encoder) + ) + else: + resp = self._request("/api/users/list") + resp_json = resp.json() + if isinstance(resp_json, list): + return parse_obj_as(List[User.__response__], resp_json) + return parse_obj_as(UsersInfoList, resp_json) def get_my_user(self) -> UserWithCreds: resp = self._request("/api/users/get_my_user") diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py index 6c5b373a63..2e62f17a4f 100644 --- a/src/tests/_internal/server/routers/test_users.py +++ b/src/tests/_internal/server/routers/test_users.py @@ -39,7 +39,7 @@ async def test_admins_see_all_non_deleted_users( admin = await create_user( session=session, name="admin", - created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), global_role=GlobalRole.ADMIN, ) other_user = await create_user( @@ -61,7 +61,7 @@ async def test_admins_see_all_non_deleted_users( { "id": str(admin.id), "username": admin.name, - "created_at": "2023-01-02T03:04:00+00:00", + "created_at": "2023-01-02T03:05:00+00:00", "global_role": admin.global_role, "email": None, "active": True, @@ -84,6 +84,118 @@ async def test_admins_see_all_non_deleted_users( }, ] + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_total_count(self, test_db, session: AsyncSession, client: AsyncClient): + admin = await create_user( + session=session, + name="admin", + created_at=datetime(2023, 1, 2, 3, 6, tzinfo=timezone.utc), + global_role=GlobalRole.ADMIN, + ) + await create_user( + session=session, + name="user_one", + created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + global_role=GlobalRole.USER, + ) + await create_user( + session=session, + name="deleted_user", + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + global_role=GlobalRole.USER, + deleted=True, + ) + response = await client.post( + "/api/users/list", + headers=get_auth_headers(admin.token), + json={"limit": 1, "return_total_count": True}, + ) + assert response.status_code == 200 + assert response.json() == { + "total_count": 2, + "users": [ + { + "id": str(admin.id), + "username": admin.name, + "created_at": "2023-01-02T03:06:00+00:00", + "global_role": admin.global_role, + "email": None, + "active": True, + "permissions": { + "can_create_projects": True, + }, + "ssh_public_key": None, + } + ], + } + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_paginates_results(self, test_db, session: AsyncSession, client: AsyncClient): + admin = await create_user( + session=session, + name="admin", + created_at=datetime(2023, 1, 2, 3, 6, tzinfo=timezone.utc), + global_role=GlobalRole.ADMIN, + ) + user_one = await create_user( + session=session, + name="user_one", + created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + global_role=GlobalRole.USER, + ) + await create_user( + session=session, + name="user_two", + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + global_role=GlobalRole.USER, + ) + response = await client.post( + "/api/users/list", + headers=get_auth_headers(admin.token), + json={"limit": 1}, + ) + assert response.status_code == 200 + assert response.json() == [ + { + "id": str(admin.id), + "username": admin.name, + "created_at": "2023-01-02T03:06:00+00:00", + "global_role": admin.global_role, + "email": None, + "active": True, + "permissions": { + "can_create_projects": True, + }, + "ssh_public_key": None, + } + ] + response = await client.post( + "/api/users/list", + headers=get_auth_headers(admin.token), + json={ + "prev_created_at": "2023-01-02T03:06:00+00:00", + "prev_id": str(admin.id), + "limit": 1, + }, + ) + assert response.status_code == 200 + assert response.json() == [ + { + "id": str(user_one.id), + "username": user_one.name, + "created_at": "2023-01-02T03:05:00+00:00", + "global_role": user_one.global_role, + "email": None, + "active": True, + "permissions": { + "can_create_projects": True, + }, + "ssh_public_key": None, + } + ] + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_non_admins_see_only_themselves( From 658fcb52bc74617f8e8e79eaa5d4fd9f5f9257ff Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 23 Jan 2026 11:33:48 +0500 Subject: [PATCH 08/13] Add name_pattern for projects --- src/dstack/_internal/server/routers/projects.py | 1 + src/dstack/_internal/server/schemas/projects.py | 7 +++++++ src/dstack/_internal/server/services/projects.py | 10 +++++++--- src/dstack/api/server/_projects.py | 3 +++ src/tests/api/test_projects.py | 2 ++ 5 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/server/routers/projects.py b/src/dstack/_internal/server/routers/projects.py index e6971e242a..aaf08809b7 100644 --- a/src/dstack/_internal/server/routers/projects.py +++ b/src/dstack/_internal/server/routers/projects.py @@ -59,6 +59,7 @@ async def list_projects( user=user, include_not_joined=body.include_not_joined, return_total_count=body.return_total_count, + name_pattern=body.name_pattern, prev_created_at=body.prev_created_at, prev_id=body.prev_id, limit=body.limit, diff --git a/src/dstack/_internal/server/schemas/projects.py b/src/dstack/_internal/server/schemas/projects.py index bc54ef52df..8c23e9f3c0 100644 --- a/src/dstack/_internal/server/schemas/projects.py +++ b/src/dstack/_internal/server/schemas/projects.py @@ -15,6 +15,13 @@ class ListProjectsRequest(CoreModel): return_total_count: Annotated[ bool, Field(description="Return `total_count` with the total number of projects.") ] = False + name_pattern: Annotated[ + Optional[str], + Field( + description="Include only projects with the name containing `name_pattern`.", + regex="^[a-zA-Z0-9-]*$", + ), + ] = None prev_created_at: Annotated[ Optional[datetime], Field( diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index a8313a843f..46960f40ea 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Awaitable, Callable, List, Optional, Tuple -from sqlalchemy import and_, delete, func, literal_column, or_, select, update +from sqlalchemy import and_, delete, literal_column, or_, select, update from sqlalchemy import func as safunc from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import QueryableAttribute, joinedload, load_only @@ -70,6 +70,7 @@ async def list_user_accessible_projects( user: UserModel, include_not_joined: bool, return_total_count: bool, + name_pattern: Optional[str], prev_created_at: Optional[datetime], prev_id: Optional[uuid.UUID], limit: int, @@ -81,7 +82,10 @@ async def list_user_accessible_projects( - Projects where user is a member (public or private) - if `include_not_joined`: Public projects where user is NOT a member """ - stmt = select(ProjectModel).where(ProjectModel.deleted == False) + filters = [ProjectModel.deleted == False] + if name_pattern: + filters.append(ProjectModel.name.ilike(f"%{name_pattern}%")) + stmt = select(ProjectModel).where(*filters) if user.global_role != GlobalRole.ADMIN: stmt = stmt.outerjoin( MemberModel, @@ -130,7 +134,7 @@ async def list_user_accessible_projects( order_by = (ProjectModel.created_at.asc(), ProjectModel.id.desc()) total_count = None if return_total_count: - res = await session.execute(stmt.with_only_columns(func.count(literal_column("1")))) + res = await session.execute(stmt.with_only_columns(safunc.count(literal_column("1")))) total_count = res.scalar_one() res = await session.execute(stmt.where(*pagination_filters).order_by(*order_by).limit(limit)) project_models = res.unique().scalars().all() diff --git a/src/dstack/api/server/_projects.py b/src/dstack/api/server/_projects.py index a85e73b121..744edb8e12 100644 --- a/src/dstack/api/server/_projects.py +++ b/src/dstack/api/server/_projects.py @@ -27,6 +27,7 @@ def list( self, include_not_joined: bool = True, return_total_count: Optional[bool] = None, + name_pattern: Optional[str] = None, prev_created_at: Optional[datetime] = None, prev_id: Optional[UUID] = None, limit: Optional[int] = None, @@ -38,6 +39,8 @@ def list( } if return_total_count is not None: body["return_total_count"] = return_total_count + if name_pattern is not None: + body["name_pattern"] = name_pattern if prev_created_at is not None: body["prev_created_at"] = prev_created_at.isoformat() if prev_id is not None: diff --git a/src/tests/api/test_projects.py b/src/tests/api/test_projects.py index d84608750c..ebfbff31fd 100644 --- a/src/tests/api/test_projects.py +++ b/src/tests/api/test_projects.py @@ -36,6 +36,7 @@ def test_projects_list_serializes_pagination_and_parses_total_count(self): result = client.list( return_total_count=True, prev_created_at=dt, + name_pattern="p", prev_id=pid, limit=1, ascending=True, @@ -45,6 +46,7 @@ def test_projects_list_serializes_pagination_and_parses_total_count(self): assert request.last_path == "/api/projects/list" assert payload["include_not_joined"] is True assert payload["return_total_count"] is True + assert payload["name_pattern"] == "p" assert payload["prev_created_at"] == dt.isoformat() assert payload["prev_id"] == str(pid) assert payload["limit"] == 1 From f35d4ea7a4d63ab8a84cc011381907cc9cff1c5e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 23 Jan 2026 11:40:47 +0500 Subject: [PATCH 09/13] Add TestProjectsAPIClientList --- src/tests/api/test_projects.py | 4 +-- src/tests/api/test_users.py | 52 ++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 src/tests/api/test_users.py diff --git a/src/tests/api/test_projects.py b/src/tests/api/test_projects.py index ebfbff31fd..38b93bae5a 100644 --- a/src/tests/api/test_projects.py +++ b/src/tests/api/test_projects.py @@ -26,8 +26,8 @@ } -class TestProjectsList: - def test_projects_list_serializes_pagination_and_parses_total_count(self): +class TestProjectsAPIClientList: + def test_projects_list_serializes_pagination_and_parses_info_list(self): request = RequestRecorder(payload={"total_count": 1, "projects": [PROJECT_PAYLOAD]}) client = ProjectsAPIClient(_request=request, _logger=logging.getLogger("test")) dt = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) diff --git a/src/tests/api/test_users.py b/src/tests/api/test_users.py new file mode 100644 index 0000000000..29beebb7d8 --- /dev/null +++ b/src/tests/api/test_users.py @@ -0,0 +1,52 @@ +import json +import logging +from datetime import datetime, timezone +from uuid import UUID + +from dstack.api.server._users import UsersAPIClient +from tests.api.common import RequestRecorder + +USER_PAYLOAD = { + "id": "11111111-1111-4111-8111-111111111111", + "username": "user", + "created_at": "2023-01-02T03:04:00+00:00", + "global_role": "user", + "email": None, + "active": True, + "permissions": {"can_create_projects": True}, + "ssh_public_key": None, +} + + +class TestUsersAPIClientList: + def test_serializes_pagination_and_parses_total_count(self): + recorder = RequestRecorder({"total_count": 1, "users": [USER_PAYLOAD]}) + client = UsersAPIClient(_request=recorder, _logger=logging.getLogger("test")) + dt = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) + uid = UUID("22222222-2222-4222-8222-222222222222") + + result = client.list( + return_total_count=True, + prev_created_at=dt, + prev_id=uid, + limit=1, + ascending=True, + ) + + payload = json.loads(recorder.last_body) + assert recorder.last_path == "/api/users/list" + assert payload["return_total_count"] is True + assert payload["prev_created_at"] == dt.isoformat() + assert payload["prev_id"] == str(uid) + assert payload["limit"] == 1 + assert payload["ascending"] is True + assert result.total_count == 1 + assert result.users[0].username == "user" + + def test_parses_list_response(self): + recorder = RequestRecorder([USER_PAYLOAD]) + client = UsersAPIClient(_request=recorder, _logger=logging.getLogger("test")) + result = client.list() + + assert isinstance(result, list) + assert result[0].username == "user" From 21a4a8918b2b50fd8711e47fd71e7e234f3e0a30 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 23 Jan 2026 11:59:06 +0500 Subject: [PATCH 10/13] Add server-side validation for project name --- src/dstack/_internal/server/services/projects.py | 11 +++++++++++ src/tests/api/test_users.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 46960f40ea..cbaa19687c 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -1,3 +1,4 @@ +import re import secrets import uuid from datetime import datetime @@ -167,6 +168,7 @@ async def create_project( user_permissions = users.get_user_permissions(user) if not user_permissions.can_create_projects: raise ForbiddenError("User cannot create projects") + validate_project_name(project_name) project = await get_project_model_by_name( session=session, project_name=project_name, ignore_case=True ) @@ -681,6 +683,15 @@ def get_member_permissions(member_model: MemberModel) -> MemberPermissions: ) +def validate_project_name(project_name: str): + if not is_valid_project_name(project_name): + raise ServerClientError("Project name should match regex '^[a-zA-Z0-9-_]{1,50}$'") + + +def is_valid_project_name(project_name: str) -> bool: + return re.match("^[a-zA-Z0-9-_]{1,50}$", project_name) is not None + + _CREATE_PROJECT_HOOKS = [] diff --git a/src/tests/api/test_users.py b/src/tests/api/test_users.py index 29beebb7d8..1a2177a57e 100644 --- a/src/tests/api/test_users.py +++ b/src/tests/api/test_users.py @@ -19,7 +19,7 @@ class TestUsersAPIClientList: - def test_serializes_pagination_and_parses_total_count(self): + def test_serializes_pagination_and_parses_info_list(self): recorder = RequestRecorder({"total_count": 1, "users": [USER_PAYLOAD]}) client = UsersAPIClient(_request=recorder, _logger=logging.getLogger("test")) dt = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) From d2c1655f4f3efb97518a757bbededaf33513cb0b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 23 Jan 2026 12:00:02 +0500 Subject: [PATCH 11/13] Allow _ in name_pattern --- src/dstack/_internal/server/schemas/projects.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/schemas/projects.py b/src/dstack/_internal/server/schemas/projects.py index 8c23e9f3c0..5f0133ab72 100644 --- a/src/dstack/_internal/server/schemas/projects.py +++ b/src/dstack/_internal/server/schemas/projects.py @@ -19,7 +19,7 @@ class ListProjectsRequest(CoreModel): Optional[str], Field( description="Include only projects with the name containing `name_pattern`.", - regex="^[a-zA-Z0-9-]*$", + regex="^[a-zA-Z0-9-_]*$", ), ] = None prev_created_at: Annotated[ From c871d1230ba3bc48edaa101841db4665e222512d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 23 Jan 2026 12:29:01 +0500 Subject: [PATCH 12/13] Add name_pattern for users --- src/dstack/_internal/server/routers/users.py | 1 + src/dstack/_internal/server/schemas/users.py | 7 +++ .../_internal/server/services/projects.py | 5 ++- src/dstack/_internal/server/services/users.py | 10 ++++- src/dstack/api/server/_users.py | 3 ++ .../_internal/server/routers/test_users.py | 44 +++++++++++++++++++ src/tests/api/test_users.py | 2 + 7 files changed, 69 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py index 4c20038d1d..6cd72f00a1 100644 --- a/src/dstack/_internal/server/routers/users.py +++ b/src/dstack/_internal/server/routers/users.py @@ -51,6 +51,7 @@ async def list_users( session=session, user=user, return_total_count=body.return_total_count, + name_pattern=body.name_pattern, prev_created_at=body.prev_created_at, prev_id=body.prev_id, limit=body.limit, diff --git a/src/dstack/_internal/server/schemas/users.py b/src/dstack/_internal/server/schemas/users.py index 17f3aca236..574d5b093e 100644 --- a/src/dstack/_internal/server/schemas/users.py +++ b/src/dstack/_internal/server/schemas/users.py @@ -12,6 +12,13 @@ class ListUsersRequest(CoreModel): return_total_count: Annotated[ bool, Field(description="Return `total_count` with the total number of users.") ] = False + name_pattern: Annotated[ + Optional[str], + Field( + description="Include only users with the name containing `name_pattern`.", + regex="^[a-zA-Z0-9-_]*$", + ), + ] = None prev_created_at: Annotated[ Optional[datetime], Field( diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index cbaa19687c..2383594690 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -85,7 +85,8 @@ async def list_user_accessible_projects( """ filters = [ProjectModel.deleted == False] if name_pattern: - filters.append(ProjectModel.name.ilike(f"%{name_pattern}%")) + name_pattern = name_pattern.replace("_", "/_") + filters.append(ProjectModel.name.ilike(f"%{name_pattern}%", escape="/")) stmt = select(ProjectModel).where(*filters) if user.global_role != GlobalRole.ADMIN: stmt = stmt.outerjoin( @@ -168,7 +169,6 @@ async def create_project( user_permissions = users.get_user_permissions(user) if not user_permissions.can_create_projects: raise ForbiddenError("User cannot create projects") - validate_project_name(project_name) project = await get_project_model_by_name( session=session, project_name=project_name, ignore_case=True ) @@ -577,6 +577,7 @@ async def get_project_model_by_id_or_error( async def create_project_model( session: AsyncSession, owner: UserModel, project_name: str, is_public: bool = False ) -> ProjectModel: + validate_project_name(project_name) private_bytes, public_bytes = await run_async( generate_rsa_key_pair_bytes, f"{project_name}@dstack" ) diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index a35ba9c694..73ceebe0ef 100644 --- a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -59,6 +59,7 @@ async def list_users_for_user( session: AsyncSession, user: UserModel, return_total_count: bool, + name_pattern: Optional[str], prev_created_at: Optional[datetime], prev_id: Optional[uuid.UUID], limit: int, @@ -69,12 +70,15 @@ async def list_users_for_user( session=session, include_deleted=False, return_total_count=return_total_count, + name_pattern=name_pattern, prev_created_at=prev_created_at, prev_id=prev_id, limit=limit, ascending=ascending, ) - users = [user_model_to_user(user)] + users = [] + if not user.deleted and (name_pattern is None or name_pattern.lower() in user.name.lower()): + users.append(user_model_to_user(user)) if return_total_count: return UsersInfoList(total_count=len(users), users=users) return users @@ -84,6 +88,7 @@ async def list_all_users( session: AsyncSession, include_deleted: bool = False, return_total_count: bool = False, + name_pattern: Optional[str] = None, prev_created_at: Optional[datetime] = None, prev_id: Optional[uuid.UUID] = None, limit: int = 2000, @@ -92,6 +97,9 @@ async def list_all_users( filters = [] if not include_deleted: filters.append(UserModel.deleted == False) + if name_pattern: + name_pattern = name_pattern.replace("_", "/_") + filters.append(UserModel.name.ilike(f"%{name_pattern}%", escape="/")) stmt = select(UserModel).where(*filters) pagination_filters = [] if prev_created_at is not None: diff --git a/src/dstack/api/server/_users.py b/src/dstack/api/server/_users.py index a60b45a11d..885eae54a2 100644 --- a/src/dstack/api/server/_users.py +++ b/src/dstack/api/server/_users.py @@ -26,6 +26,7 @@ class UsersAPIClient(APIClientGroup): def list( self, return_total_count: Optional[bool] = None, + name_pattern: Optional[str] = None, prev_created_at: Optional[datetime] = None, prev_id: Optional[UUID] = None, limit: Optional[int] = None, @@ -35,6 +36,8 @@ def list( body: dict[str, Any] = {} if return_total_count is not None: body["return_total_count"] = return_total_count + if name_pattern is not None: + body["name_pattern"] = name_pattern if prev_created_at is not None: body["prev_created_at"] = prev_created_at if prev_id is not None: diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py index 2e62f17a4f..5042e75d6b 100644 --- a/src/tests/_internal/server/routers/test_users.py +++ b/src/tests/_internal/server/routers/test_users.py @@ -196,6 +196,50 @@ async def test_paginates_results(self, test_db, session: AsyncSession, client: A } ] + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_filters_by_name_pattern( + self, test_db, session: AsyncSession, client: AsyncClient + ): + admin = await create_user( + session=session, + name="admin", + created_at=datetime(2023, 1, 2, 3, 6, tzinfo=timezone.utc), + global_role=GlobalRole.ADMIN, + ) + matching_user = await create_user( + session=session, + name="alpha_user", + created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + global_role=GlobalRole.USER, + ) + await create_user( + session=session, + name="bravo", + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + global_role=GlobalRole.USER, + ) + response = await client.post( + "/api/users/list", + headers=get_auth_headers(admin.token), + json={"name_pattern": "alpha"}, + ) + assert response.status_code == 200 + assert response.json() == [ + { + "id": str(matching_user.id), + "username": matching_user.name, + "created_at": "2023-01-02T03:05:00+00:00", + "global_role": matching_user.global_role, + "email": None, + "active": True, + "permissions": { + "can_create_projects": True, + }, + "ssh_public_key": None, + } + ] + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_non_admins_see_only_themselves( diff --git a/src/tests/api/test_users.py b/src/tests/api/test_users.py index 1a2177a57e..c01703b811 100644 --- a/src/tests/api/test_users.py +++ b/src/tests/api/test_users.py @@ -27,6 +27,7 @@ def test_serializes_pagination_and_parses_info_list(self): result = client.list( return_total_count=True, + name_pattern="user", prev_created_at=dt, prev_id=uid, limit=1, @@ -36,6 +37,7 @@ def test_serializes_pagination_and_parses_info_list(self): payload = json.loads(recorder.last_body) assert recorder.last_path == "/api/users/list" assert payload["return_total_count"] is True + assert payload["name_pattern"] == "user" assert payload["prev_created_at"] == dt.isoformat() assert payload["prev_id"] == str(uid) assert payload["limit"] == 1 From f8e0ff297be6b6202a9346c2008e31c188a13ee8 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 23 Jan 2026 13:27:09 +0500 Subject: [PATCH 13/13] Add @overload for ProjectsAPIClient.list() --- src/dstack/api/server/_projects.py | 31 +++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/src/dstack/api/server/_projects.py b/src/dstack/api/server/_projects.py index 744edb8e12..96a1f511f7 100644 --- a/src/dstack/api/server/_projects.py +++ b/src/dstack/api/server/_projects.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import Any, List, Optional +from typing import Any, List, Literal, Optional, Union, overload from uuid import UUID from pydantic import parse_obj_as @@ -23,9 +23,38 @@ class ProjectsAPIClient(APIClientGroup): + @overload def list( self, include_not_joined: bool = True, + *, + return_total_count: Literal[True], + name_pattern: Optional[str] = None, + prev_created_at: Optional[datetime] = None, + prev_id: Optional[UUID] = None, + limit: Optional[int] = None, + ascending: Optional[bool] = None, + ) -> ProjectsInfoList: + pass + + @overload + def list( + self, + include_not_joined: bool = True, + *, + return_total_count: Union[Literal[False], None] = None, + name_pattern: Optional[str] = None, + prev_created_at: Optional[datetime] = None, + prev_id: Optional[UUID] = None, + limit: Optional[int] = None, + ascending: Optional[bool] = None, + ) -> List[Project]: + pass + + def list( + self, + include_not_joined: bool = True, + *, return_total_count: Optional[bool] = None, name_pattern: Optional[str] = None, prev_created_at: Optional[datetime] = None,