diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index b48ede661296c..746fe45d6f3b0 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -709,6 +709,29 @@ func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg datab return database.OrganizationMember{}, sql.ErrNoRows } +func (q *fakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uuid.UUID) ([]database.GetOrganizationIDsByMemberIDsRow, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + getOrganizationIDsByMemberIDRows := make([]database.GetOrganizationIDsByMemberIDsRow, 0, len(ids)) + for _, userID := range ids { + userOrganizationIDs := make([]uuid.UUID, 0) + for _, membership := range q.organizationMembers { + if membership.UserID == userID { + userOrganizationIDs = append(userOrganizationIDs, membership.OrganizationID) + } + } + getOrganizationIDsByMemberIDRows = append(getOrganizationIDsByMemberIDRows, database.GetOrganizationIDsByMemberIDsRow{ + UserID: userID, + OrganizationIDs: userOrganizationIDs, + }) + } + if len(getOrganizationIDsByMemberIDRows) == 0 { + return nil, sql.ErrNoRows + } + return getOrganizationIDsByMemberIDRows, nil +} + func (q *fakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.ProvisionerDaemon, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3819eb3608934..304e012e24e88 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -17,6 +17,7 @@ type querier interface { GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error) GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error) GetOrganizationByName(ctx context.Context, name string) (Organization, error) + GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]GetOrganizationIDsByMemberIDsRow, error) GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error) GetOrganizations(ctx context.Context) ([]Organization, error) GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]Organization, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 5f0574812b79c..d80bb3135b6c1 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -306,6 +306,45 @@ func (q *sqlQuerier) UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyPar return err } +const getOrganizationIDsByMemberIDs = `-- name: GetOrganizationIDsByMemberIDs :many +SELECT + user_id, array_agg(organization_id) :: uuid [ ] AS "organization_IDs" +FROM + organization_members +WHERE + user_id = ANY($1 :: uuid [ ]) +GROUP BY + user_id +` + +type GetOrganizationIDsByMemberIDsRow struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + OrganizationIDs []uuid.UUID `db:"organization_IDs" json:"organization_IDs"` +} + +func (q *sqlQuerier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]GetOrganizationIDsByMemberIDsRow, error) { + rows, err := q.db.QueryContext(ctx, getOrganizationIDsByMemberIDs, pq.Array(ids)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetOrganizationIDsByMemberIDsRow + for rows.Next() { + var i GetOrganizationIDsByMemberIDsRow + if err := rows.Scan(&i.UserID, pq.Array(&i.OrganizationIDs)); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getOrganizationMemberByUserID = `-- name: GetOrganizationMemberByUserID :one SELECT user_id, organization_id, created_at, updated_at, roles diff --git a/coderd/database/queries/organizationmembers.sql b/coderd/database/queries/organizationmembers.sql index 243bdc26c9878..27c41ed53f577 100644 --- a/coderd/database/queries/organizationmembers.sql +++ b/coderd/database/queries/organizationmembers.sql @@ -20,3 +20,13 @@ INSERT INTO ) VALUES ($1, $2, $3, $4, $5) RETURNING *; + +-- name: GetOrganizationIDsByMemberIDs :many +SELECT + user_id, array_agg(organization_id) :: uuid [ ] AS "organization_IDs" +FROM + organization_members +WHERE + user_id = ANY(@ids :: uuid [ ]) +GROUP BY + user_id; diff --git a/coderd/users.go b/coderd/users.go index a38af8ba12f63..b77acee499994 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -137,16 +137,34 @@ func (api *api) users(rw http.ResponseWriter, r *http.Request) { LimitOpt: int32(pageLimit), Search: searchName, }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: err.Error(), + }) + return + } + userIDs := make([]uuid.UUID, 0, len(users)) + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + organizationIDsByMemberIDsRows, err := api.Database.GetOrganizationIDsByMemberIDs(r.Context(), userIDs) + if xerrors.Is(err, sql.ErrNoRows) { + err = nil + } if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: err.Error(), }) return } + organizationIDsByUserID := map[uuid.UUID][]uuid.UUID{} + for _, organizationIDsByMemberIDsRow := range organizationIDsByMemberIDsRows { + organizationIDsByUserID[organizationIDsByMemberIDsRow.UserID] = organizationIDsByMemberIDsRow.OrganizationIDs + } render.Status(r, http.StatusOK) - render.JSON(rw, r, convertUsers(users)) + render.JSON(rw, r, convertUsers(users, organizationIDsByUserID)) } // Creates a new user. @@ -213,15 +231,23 @@ func (api *api) postUser(rw http.ResponseWriter, r *http.Request) { return } - httpapi.Write(rw, http.StatusCreated, convertUser(user)) + httpapi.Write(rw, http.StatusCreated, convertUser(user, []uuid.UUID{createUser.OrganizationID})) } // Returns the parameterized user requested. All validation // is completed in the middleware for this route. -func (*api) userByName(rw http.ResponseWriter, r *http.Request) { +func (api *api) userByName(rw http.ResponseWriter, r *http.Request) { user := httpmw.UserParam(r) + organizationIDs, err := userOrganizationIDs(r.Context(), api, user) + + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get organization IDs: %s", err.Error()), + }) + return + } - httpapi.Write(rw, http.StatusOK, convertUser(user)) + httpapi.Write(rw, http.StatusOK, convertUser(user, organizationIDs)) } func (api *api) putUserProfile(rw http.ResponseWriter, r *http.Request) { @@ -278,7 +304,15 @@ func (api *api) putUserProfile(rw http.ResponseWriter, r *http.Request) { return } - httpapi.Write(rw, http.StatusOK, convertUser(updatedUserProfile)) + organizationIDs, err := userOrganizationIDs(r.Context(), api, user) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get organization IDs: %s", err.Error()), + }) + return + } + + httpapi.Write(rw, http.StatusOK, convertUser(updatedUserProfile, organizationIDs)) } func (api *api) putUserSuspend(rw http.ResponseWriter, r *http.Request) { @@ -297,7 +331,15 @@ func (api *api) putUserSuspend(rw http.ResponseWriter, r *http.Request) { return } - httpapi.Write(rw, http.StatusOK, convertUser(suspendedUser)) + organizations, err := userOrganizationIDs(r.Context(), api, user) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get organization IDs: %s", err.Error()), + }) + return + } + + httpapi.Write(rw, http.StatusOK, convertUser(suspendedUser, organizations)) } // Returns organizations the parameterized user has access to. @@ -626,20 +668,34 @@ func (api *api) createUser(ctx context.Context, req codersdk.CreateUserRequest) }) } -func convertUser(user database.User) codersdk.User { +func convertUser(user database.User, organizationIDs []uuid.UUID) codersdk.User { return codersdk.User{ - ID: user.ID, - Email: user.Email, - CreatedAt: user.CreatedAt, - Username: user.Username, - Status: codersdk.UserStatus(user.Status), + ID: user.ID, + Email: user.Email, + CreatedAt: user.CreatedAt, + Username: user.Username, + Status: codersdk.UserStatus(user.Status), + OrganizationIDs: organizationIDs, } } -func convertUsers(users []database.User) []codersdk.User { +func convertUsers(users []database.User, organizationIDsByUserID map[uuid.UUID][]uuid.UUID) []codersdk.User { converted := make([]codersdk.User, 0, len(users)) for _, u := range users { - converted = append(converted, convertUser(u)) + userOrganizationIDs := organizationIDsByUserID[u.ID] + converted = append(converted, convertUser(u, userOrganizationIDs)) } return converted } + +func userOrganizationIDs(ctx context.Context, api *api, user database.User) ([]uuid.UUID, error) { + organizationIDsByMemberIDsRows, err := api.Database.GetOrganizationIDsByMemberIDs(ctx, []uuid.UUID{user.ID}) + if errors.Is(err, sql.ErrNoRows) || len(organizationIDsByMemberIDsRows) == 0 { + return []uuid.UUID{}, nil + } + if err != nil { + return []uuid.UUID{}, err + } + member := organizationIDsByMemberIDsRows[0] + return member.OrganizationIDs, nil +} diff --git a/coderd/users_test.go b/coderd/users_test.go index 8ee4db1c6f80e..8d4d6b2da2f52 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -321,9 +321,11 @@ func TestPutUserSuspend(t *testing.T) { func TestUserByName(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) - _ = coderdtest.CreateFirstUser(t, client) - _, err := client.User(context.Background(), codersdk.Me) + firstUser := coderdtest.CreateFirstUser(t, client) + user, err := client.User(context.Background(), codersdk.Me) + require.NoError(t, err) + require.Equal(t, firstUser.OrganizationID, user.OrganizationIDs[0]) } func TestGetUsers(t *testing.T) { @@ -340,6 +342,7 @@ func TestGetUsers(t *testing.T) { users, err := client.Users(context.Background(), codersdk.UsersRequest{}) require.NoError(t, err) require.Len(t, users, 2) + require.Len(t, users[0].OrganizationIDs, 1) } func TestOrganizationsByUser(t *testing.T) { @@ -451,14 +454,12 @@ func TestPaginatedUsers(t *testing.T) { coderdtest.CreateFirstUser(t, client) me, err := client.User(context.Background(), codersdk.Me) require.NoError(t, err) + orgID := me.OrganizationIDs[0] allUsers := make([]codersdk.User, 0) allUsers = append(allUsers, me) specialUsers := make([]codersdk.User, 0) - org, err := client.CreateOrganization(ctx, me.ID, codersdk.CreateOrganizationRequest{ - Name: "default", - }) require.NoError(t, err) // When 100 users exist @@ -481,7 +482,7 @@ func TestPaginatedUsers(t *testing.T) { Email: email, Username: username, Password: "password", - OrganizationID: org.ID, + OrganizationID: orgID, }) require.NoError(t, err) allUsers = append(allUsers, newUser) diff --git a/codersdk/users.go b/codersdk/users.go index 38506f03cd192..317ab7cd0f85b 100644 --- a/codersdk/users.go +++ b/codersdk/users.go @@ -37,11 +37,12 @@ const ( // User represents a user in Coder. type User struct { - ID uuid.UUID `json:"id" validate:"required"` - Email string `json:"email" validate:"required"` - CreatedAt time.Time `json:"created_at" validate:"required"` - Username string `json:"username" validate:"required"` - Status UserStatus `json:"status"` + ID uuid.UUID `json:"id" validate:"required"` + Email string `json:"email" validate:"required"` + CreatedAt time.Time `json:"created_at" validate:"required"` + Username string `json:"username" validate:"required"` + Status UserStatus `json:"status"` + OrganizationIDs []uuid.UUID `json:"organization_ids"` } type CreateFirstUserRequest struct { diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 268d1f837c347..10acc3acbc4fb 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -90,7 +90,7 @@ export interface User { readonly status: UserStatus } -// From codersdk/users.go:47:6. +// From codersdk/users.go:48:6. export interface CreateFirstUserRequest { readonly email: string readonly username: string @@ -98,41 +98,41 @@ export interface CreateFirstUserRequest { readonly organization: string } -// From codersdk/users.go:60:6. +// From codersdk/users.go:61:6. export interface CreateUserRequest { readonly email: string readonly username: string readonly password: string } -// From codersdk/users.go:67:6. +// From codersdk/users.go:68:6. export interface UpdateUserProfileRequest { readonly email: string readonly username: string } -// From codersdk/users.go:73:6. +// From codersdk/users.go:74:6. export interface LoginWithPasswordRequest { readonly email: string readonly password: string } -// From codersdk/users.go:79:6. +// From codersdk/users.go:80:6. export interface LoginWithPasswordResponse { readonly session_token: string } -// From codersdk/users.go:84:6. +// From codersdk/users.go:85:6. export interface GenerateAPIKeyResponse { readonly key: string } -// From codersdk/users.go:88:6. +// From codersdk/users.go:89:6. export interface CreateOrganizationRequest { readonly name: string } -// From codersdk/users.go:93:6. +// From codersdk/users.go:94:6. export interface AuthMethods { readonly password: boolean readonly github: boolean