Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f0f9569

Browse files
authored
chore: enforce that provisioners can only acquire jobs in their own organization (#12600)
* chore: add org ID as optional param to AcquireJob * chore: plumb through organization id to provisioner daemons * add org id to provisioner domain key * enforce org id argument * dbgen provisioner jobs defaults to default org
1 parent 0e8ebb9 commit f0f9569

File tree

15 files changed

+204
-126
lines changed

15 files changed

+204
-126
lines changed

coderd/coderd.go

+1
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(dialCtx context.Context, name st
12711271
api.ctx, // use the same ctx as the API
12721272
api.AccessURL,
12731273
daemon.ID,
1274+
defaultOrg.ID,
12741275
logger,
12751276
daemon.Provisioners,
12761277
provisionerdserver.Tags(daemon.Tags),

coderd/database/dbauthz/dbauthz_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -2093,7 +2093,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
20932093
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
20942094
StartedAt: sql.NullTime{Valid: false},
20952095
})
2096-
check.Args(database.AcquireProvisionerJobParams{Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}).
2096+
check.Args(database.AcquireProvisionerJobParams{OrganizationID: j.OrganizationID, Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}).
20972097
Asserts( /*rbac.ResourceSystem, rbac.ActionUpdate*/ )
20982098
}))
20992099
s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) {

coderd/database/dbfake/dbfake.go

+1
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
187187
// import job as well
188188
for {
189189
j, err := b.db.AcquireProvisionerJob(ownerCtx, database.AcquireProvisionerJobParams{
190+
OrganizationID: job.OrganizationID,
190191
StartedAt: sql.NullTime{
191192
Time: dbtime.Now(),
192193
Valid: true,

coderd/database/dbgen/dbgen.go

+12-5
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,12 @@ func GroupMember(t testing.TB, db database.Store, orig database.GroupMember) dat
387387
func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig database.ProvisionerJob) database.ProvisionerJob {
388388
t.Helper()
389389

390+
var defOrgID uuid.UUID
391+
if orig.OrganizationID == uuid.Nil {
392+
defOrg, _ := db.GetDefaultOrganization(genCtx)
393+
defOrgID = defOrg.ID
394+
}
395+
390396
jobID := takeFirst(orig.ID, uuid.New())
391397
// Always set some tags to prevent Acquire from grabbing jobs it should not.
392398
if !orig.StartedAt.Time.IsZero() {
@@ -401,7 +407,7 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data
401407
ID: jobID,
402408
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
403409
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
404-
OrganizationID: takeFirst(orig.OrganizationID, uuid.New()),
410+
OrganizationID: takeFirst(orig.OrganizationID, defOrgID, uuid.New()),
405411
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
406412
Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho),
407413
StorageMethod: takeFirst(orig.StorageMethod, database.ProvisionerStorageMethodFile),
@@ -418,10 +424,11 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data
418424
}
419425
if !orig.StartedAt.Time.IsZero() {
420426
job, err = db.AcquireProvisionerJob(genCtx, database.AcquireProvisionerJobParams{
421-
StartedAt: orig.StartedAt,
422-
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
423-
Tags: must(json.Marshal(orig.Tags)),
424-
WorkerID: uuid.NullUUID{},
427+
StartedAt: orig.StartedAt,
428+
OrganizationID: job.OrganizationID,
429+
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
430+
Tags: must(json.Marshal(orig.Tags)),
431+
WorkerID: uuid.NullUUID{},
425432
})
426433
require.NoError(t, err)
427434
// There is no easy way to make sure we acquire the correct job.

coderd/database/dbmem/dbmem.go

+13-9
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,9 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
803803
defer q.mutex.Unlock()
804804

805805
for index, provisionerJob := range q.provisionerJobs {
806+
if provisionerJob.OrganizationID != arg.OrganizationID {
807+
continue
808+
}
806809
if provisionerJob.StartedAt.Valid {
807810
continue
808811
}
@@ -7861,15 +7864,16 @@ func (q *FakeQuerier) UpsertProvisionerDaemon(_ context.Context, arg database.Up
78617864
}
78627865
}
78637866
d := database.ProvisionerDaemon{
7864-
ID: uuid.New(),
7865-
CreatedAt: arg.CreatedAt,
7866-
Name: arg.Name,
7867-
Provisioners: arg.Provisioners,
7868-
Tags: maps.Clone(arg.Tags),
7869-
ReplicaID: uuid.NullUUID{},
7870-
LastSeenAt: arg.LastSeenAt,
7871-
Version: arg.Version,
7872-
APIVersion: arg.APIVersion,
7867+
ID: uuid.New(),
7868+
CreatedAt: arg.CreatedAt,
7869+
Name: arg.Name,
7870+
Provisioners: arg.Provisioners,
7871+
Tags: maps.Clone(arg.Tags),
7872+
ReplicaID: uuid.NullUUID{},
7873+
LastSeenAt: arg.LastSeenAt,
7874+
Version: arg.Version,
7875+
APIVersion: arg.APIVersion,
7876+
OrganizationID: arg.OrganizationID,
78737877
}
78747878
q.provisionerDaemons = append(q.provisionerDaemons, d)
78757879
return d, nil

coderd/database/querier_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ func TestQueuePosition(t *testing.T) {
363363
}
364364

365365
job, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
366+
OrganizationID: org.ID,
366367
StartedAt: sql.NullTime{
367368
Time: dbtime.Now(),
368369
Valid: true,

coderd/database/queries.sql.go

+10-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/provisionerjobs.sql

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ WHERE
1919
provisioner_jobs AS nested
2020
WHERE
2121
nested.started_at IS NULL
22+
AND nested.organization_id = @organization_id
2223
-- Ensure the caller has the correct provisioner.
2324
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
2425
AND CASE

coderd/prometheusmetrics/prometheusmetrics_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ func TestWorkspaces(t *testing.T) {
134134
require.NoError(t, err)
135135
// This marks the job as started.
136136
_, err = db.AcquireProvisionerJob(context.Background(), database.AcquireProvisionerJobParams{
137+
OrganizationID: job.OrganizationID,
137138
StartedAt: sql.NullTime{
138139
Time: dbtime.Now(),
139140
Valid: true,

coderd/provisionerdserver/acquirer.go

+12-7
Original file line numberDiff line numberDiff line change
@@ -89,24 +89,25 @@ func NewAcquirer(ctx context.Context, logger slog.Logger, store AcquirerStore, p
8989
// done, or the database returns an error _other_ than that no jobs are available.
9090
// If no jobs are available, this method handles retrying as appropriate.
9191
func (a *Acquirer) AcquireJob(
92-
ctx context.Context, worker uuid.UUID, pt []database.ProvisionerType, tags Tags,
92+
ctx context.Context, organization uuid.UUID, worker uuid.UUID, pt []database.ProvisionerType, tags Tags,
9393
) (
9494
retJob database.ProvisionerJob, retErr error,
9595
) {
9696
logger := a.logger.With(
97+
slog.F("organization_id", organization),
9798
slog.F("worker_id", worker),
9899
slog.F("provisioner_types", pt),
99100
slog.F("tags", tags))
100101
logger.Debug(ctx, "acquiring job")
101-
dk := domainKey(pt, tags)
102+
dk := domainKey(organization, pt, tags)
102103
dbTags, err := tags.ToJSON()
103104
if err != nil {
104105
return database.ProvisionerJob{}, err
105106
}
106107
// buffer of 1 so that cancel doesn't deadlock while writing to the channel
107108
clearance := make(chan struct{}, 1)
108109
for {
109-
a.want(pt, tags, clearance)
110+
a.want(organization, pt, tags, clearance)
110111
select {
111112
case <-ctx.Done():
112113
err := ctx.Err()
@@ -120,6 +121,7 @@ func (a *Acquirer) AcquireJob(
120121
case <-clearance:
121122
logger.Debug(ctx, "got clearance to call database")
122123
job, err := a.store.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
124+
OrganizationID: organization,
123125
StartedAt: sql.NullTime{
124126
Time: dbtime.Now(),
125127
Valid: true,
@@ -152,8 +154,8 @@ func (a *Acquirer) AcquireJob(
152154
}
153155

154156
// want signals that an acquiree wants clearance to query for a job with the given dKey.
155-
func (a *Acquirer) want(pt []database.ProvisionerType, tags Tags, clearance chan<- struct{}) {
156-
dk := domainKey(pt, tags)
157+
func (a *Acquirer) want(organization uuid.UUID, pt []database.ProvisionerType, tags Tags, clearance chan<- struct{}) {
158+
dk := domainKey(organization, pt, tags)
157159
a.mu.Lock()
158160
defer a.mu.Unlock()
159161
cleared := false
@@ -404,13 +406,16 @@ type dKey string
404406
// unprintable control character and won't show up in any "reasonable" set of
405407
// string tags, even in non-Latin scripts. It is important that Tags are
406408
// validated not to contain this control character prior to use.
407-
func domainKey(pt []database.ProvisionerType, tags Tags) dKey {
409+
func domainKey(orgID uuid.UUID, pt []database.ProvisionerType, tags Tags) dKey {
410+
sb := strings.Builder{}
411+
_, _ = sb.WriteString(orgID.String())
412+
_ = sb.WriteByte(0x00)
413+
408414
// make a copy of pt before sorting, so that we don't mutate the original
409415
// slice or underlying array.
410416
pts := make([]database.ProvisionerType, len(pt))
411417
copy(pts, pt)
412418
slices.Sort(pts)
413-
sb := strings.Builder{}
414419
for _, t := range pts {
415420
_, _ = sb.WriteString(string(t))
416421
_ = sb.WriteByte(0x00)

0 commit comments

Comments
 (0)