diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 651daf4fba353..88c17101f833c 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -231,17 +231,19 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams users = tmp } - if len(params.Status) > 0 { - usersFilteredByStatus := make([]database.User, 0, len(users)) - for i, user := range users { - for _, status := range params.Status { - if user.Status == status { - usersFilteredByStatus = append(usersFilteredByStatus, users[i]) - } + if len(params.Status) == 0 { + params.Status = []database.UserStatus{database.UserStatusActive} + } + + usersFilteredByStatus := make([]database.User, 0, len(users)) + for i, user := range users { + for _, status := range params.Status { + if user.Status == status { + usersFilteredByStatus = append(usersFilteredByStatus, users[i]) } } - users = usersFilteredByStatus } + users = usersFilteredByStatus if params.OffsetOpt > 0 { if int(params.OffsetOpt) > len(users)-1 { diff --git a/coderd/organizations.go b/coderd/organizations.go index 49dba0bb33324..0f58eac546bcb 100644 --- a/coderd/organizations.go +++ b/coderd/organizations.go @@ -57,8 +57,8 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) { } var organization database.Organization - err = api.Database.InTx(func(db database.Store) error { - organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{ + err = api.Database.InTx(func(store database.Store) error { + organization, err = store.InsertOrganization(r.Context(), database.InsertOrganizationParams{ ID: uuid.New(), Name: req.Name, CreatedAt: database.Now(), @@ -67,7 +67,7 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) { if err != nil { return xerrors.Errorf("create organization: %w", err) } - _, err = api.Database.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{ + _, err = store.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{ OrganizationID: organization.ID, UserID: apiKey.UserID, CreatedAt: database.Now(), diff --git a/coderd/templateversions.go b/coderd/templateversions.go index c0157114c46d9..d4fe2f9e779f4 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -385,51 +385,77 @@ func (api *API) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Reque return } - apiVersion := []codersdk.TemplateVersion{} - versions, err := api.Database.GetTemplateVersionsByTemplateID(r.Context(), database.GetTemplateVersionsByTemplateIDParams{ - TemplateID: template.ID, - AfterID: paginationParams.AfterID, - LimitOpt: int32(paginationParams.Limit), - OffsetOpt: int32(paginationParams.Offset), - }) - if errors.Is(err, sql.ErrNoRows) { - httpapi.Write(rw, http.StatusOK, apiVersion) - return - } - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get template version: %s", err), - }) - return - } - jobIDs := make([]uuid.UUID, 0, len(versions)) - for _, version := range versions { - jobIDs = append(jobIDs, version.JobID) - } - jobs, err := api.Database.GetProvisionerJobsByIDs(r.Context(), jobIDs) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get jobs: %s", err), + var err error + apiVersions := []codersdk.TemplateVersion{} + err = api.Database.InTx(func(store database.Store) error { + if paginationParams.AfterID != uuid.Nil { + // See if the record exists first. If the record does not exist, the pagination + // query will not work. + _, err := store.GetTemplateVersionByID(r.Context(), paginationParams.AfterID) + if err != nil && xerrors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("record at \"after_id\" (%q) does not exists", paginationParams.AfterID.String()), + }) + return err + } else if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get template version at after_id: %s", err), + }) + return err + } + } + + versions, err := store.GetTemplateVersionsByTemplateID(r.Context(), database.GetTemplateVersionsByTemplateIDParams{ + TemplateID: template.ID, + AfterID: paginationParams.AfterID, + LimitOpt: int32(paginationParams.Limit), + OffsetOpt: int32(paginationParams.Offset), }) - return - } - jobByID := map[string]database.ProvisionerJob{} - for _, job := range jobs { - jobByID[job.ID.String()] = job - } + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusOK, apiVersions) + return err + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get template version: %s", err), + }) + return err + } - for _, version := range versions { - job, exists := jobByID[version.JobID.String()] - if !exists { + jobIDs := make([]uuid.UUID, 0, len(versions)) + for _, version := range versions { + jobIDs = append(jobIDs, version.JobID) + } + jobs, err := store.GetProvisionerJobsByIDs(r.Context(), jobIDs) + if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("job %q doesn't exist for version %q", version.JobID, version.ID), + Message: fmt.Sprintf("get jobs: %s", err), }) - return + return err } - apiVersion = append(apiVersion, convertTemplateVersion(version, convertProvisionerJob(job))) + jobByID := map[string]database.ProvisionerJob{} + for _, job := range jobs { + jobByID[job.ID.String()] = job + } + + for _, version := range versions { + job, exists := jobByID[version.JobID.String()] + if !exists { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("job %q doesn't exist for version %q", version.JobID, version.ID), + }) + return err + } + apiVersions = append(apiVersions, convertTemplateVersion(version, convertProvisionerJob(job))) + } + + return nil + }) + if err != nil { + return } - httpapi.Write(rw, http.StatusOK, apiVersion) + httpapi.Write(rw, http.StatusOK, apiVersions) } func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) { @@ -582,7 +608,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht } } - provisionerJob, err = api.Database.InsertProvisionerJob(r.Context(), database.InsertProvisionerJobParams{ + provisionerJob, err = db.InsertProvisionerJob(r.Context(), database.InsertProvisionerJobParams{ ID: jobID, CreatedAt: database.Now(), UpdatedAt: database.Now(), @@ -606,7 +632,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht } } - templateVersion, err = api.Database.InsertTemplateVersion(r.Context(), database.InsertTemplateVersionParams{ + templateVersion, err = db.InsertTemplateVersion(r.Context(), database.InsertTemplateVersionParams{ ID: uuid.New(), TemplateID: templateID, OrganizationID: organization.ID, diff --git a/coderd/templateversions_test.go b/coderd/templateversions_test.go index 359acf391b223..c868cc65ac7fc 100644 --- a/coderd/templateversions_test.go +++ b/coderd/templateversions_test.go @@ -694,9 +694,10 @@ func TestPaginatedTemplateVersions(t *testing.T) { pagination codersdk.Pagination } tests := []struct { - name string - args args - want []codersdk.TemplateVersion + name string + args args + want []codersdk.TemplateVersion + expectedError string }{ { name: "Single result", @@ -728,6 +729,11 @@ func TestPaginatedTemplateVersions(t *testing.T) { args: args{ctx: ctx, pagination: codersdk.Pagination{Limit: 2, Offset: 10}}, want: []codersdk.TemplateVersion{}, }, + { + name: "After_id does not exist", + args: args{ctx: ctx, pagination: codersdk.Pagination{AfterID: uuid.New()}}, + expectedError: "does not exist", + }, } for _, tt := range tests { tt := tt @@ -737,8 +743,13 @@ func TestPaginatedTemplateVersions(t *testing.T) { TemplateID: template.ID, Pagination: tt.args.pagination, }) - assert.NoError(t, err) - assert.Equal(t, tt.want, got) + if tt.expectedError != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.expectedError) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } }) } } diff --git a/coderd/users_test.go b/coderd/users_test.go index 821fd909a92f5..5da48ccf1a55a 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -830,6 +830,50 @@ func TestWorkspacesByUser(t *testing.T) { }) } +// TestSuspendedPagination is when the after_id is a suspended record. +// The database query should still return the correct page, as the after_id +// is in a subquery that finds the record regardless of its status. +// This is mainly to confirm the db fake has the same behavior. +func TestSuspendedPagination(t *testing.T) { + t.Parallel() + ctx := context.Background() + client := coderdtest.New(t, &coderdtest.Options{APIRateLimit: -1}) + coderdtest.CreateFirstUser(t, client) + me, err := client.User(context.Background(), codersdk.Me) + require.NoError(t, err) + orgID := me.OrganizationIDs[0] + + total := 10 + users := make([]codersdk.User, 0, total) + // Create users + for i := 0; i < total; i++ { + email := fmt.Sprintf("%d@coder.com", i) + username := fmt.Sprintf("user%d", i) + user, err := client.CreateUser(context.Background(), codersdk.CreateUserRequest{ + Email: email, + Username: username, + Password: "password", + OrganizationID: orgID, + }) + require.NoError(t, err) + users = append(users, user) + } + sortUsers(users) + deletedUser := users[2] + expected := users[3:8] + _, err = client.UpdateUserStatus(ctx, deletedUser.ID.String(), codersdk.UserStatusSuspended) + require.NoError(t, err, "suspend user") + + page, err := client.Users(ctx, codersdk.UsersRequest{ + Pagination: codersdk.Pagination{ + Limit: len(expected), + AfterID: deletedUser.ID, + }, + }) + require.NoError(t, err) + require.Equal(t, expected, page, "expected page") +} + // TestPaginatedUsers creates a list of users, then tries to paginate through // them using different page sizes. func TestPaginatedUsers(t *testing.T) { diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index bb8dc2a578bd3..50d60b34cb604 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -51,22 +51,51 @@ func (api *API) workspaceBuilds(rw http.ResponseWriter, r *http.Request) { if !ok { return } - req := database.GetWorkspaceBuildByWorkspaceIDParams{ - WorkspaceID: workspace.ID, - AfterID: paginationParams.AfterID, - OffsetOpt: int32(paginationParams.Offset), - LimitOpt: int32(paginationParams.Limit), - } - builds, err := api.Database.GetWorkspaceBuildByWorkspaceID(r.Context(), req) - if xerrors.Is(err, sql.ErrNoRows) { - err = nil - } + + var builds []database.WorkspaceBuild + // Ensure all db calls happen in the same tx + err := api.Database.InTx(func(store database.Store) error { + var err error + if paginationParams.AfterID != uuid.Nil { + // See if the record exists first. If the record does not exist, the pagination + // query will not work. + _, err := store.GetWorkspaceBuildByID(r.Context(), paginationParams.AfterID) + if err != nil && xerrors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("record at \"after_id\" (%q) does not exist", paginationParams.AfterID.String()), + }) + return err + } else if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get workspace build at after_id: %s", err), + }) + return err + } + } + + req := database.GetWorkspaceBuildByWorkspaceIDParams{ + WorkspaceID: workspace.ID, + AfterID: paginationParams.AfterID, + OffsetOpt: int32(paginationParams.Offset), + LimitOpt: int32(paginationParams.Limit), + } + builds, err = store.GetWorkspaceBuildByWorkspaceID(r.Context(), req) + if xerrors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get workspace builds: %s", err), + }) + return err + } + + return nil + }) if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get workspace builds: %s", err), - }) return } + jobIDs := make([]uuid.UUID, 0, len(builds)) for _, version := range builds { jobIDs = append(jobIDs, version.JobID) diff --git a/coderd/workspacebuilds_test.go b/coderd/workspacebuilds_test.go index 70f61657e2a87..5b15f54c29d52 100644 --- a/coderd/workspacebuilds_test.go +++ b/coderd/workspacebuilds_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/coderdtest" @@ -44,6 +45,30 @@ func TestWorkspaceBuilds(t *testing.T) { require.NoError(t, err) }) + t.Run("PaginateNonExistentRow", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) + user := coderdtest.CreateFirstUser(t, client) + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + + _, err := client.WorkspaceBuilds(ctx, codersdk.WorkspaceBuildsRequest{ + WorkspaceID: workspace.ID, + Pagination: codersdk.Pagination{ + AfterID: uuid.New(), + }, + }) + var apiError *codersdk.Error + require.ErrorAs(t, err, &apiError) + require.Equal(t, http.StatusBadRequest, apiError.StatusCode()) + require.Contains(t, apiError.Message, "does not exist") + }) + t.Run("PaginateLimitOffset", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) diff --git a/docker-compose.yaml b/docker-compose.yaml index a8beab8129ad5..74c8916fd0d21 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -18,6 +18,8 @@ services: condition: service_healthy database: image: "postgres:14.2" + ports: + - "5432:5432" environment: POSTGRES_USER: ${POSTGRES_USER:-username} # The PostgreSQL user (useful to connect to the database) POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password} # The PostgreSQL password (useful to connect to the database)