Thanks to visit codestin.com
Credit goes to github.com

Skip to content

chore: enforce that provisioners can only acquire jobs in their own organization #12600

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 18, 2024
1 change: 1 addition & 0 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(dialCtx context.Context, name st
api.ctx, // use the same ctx as the API
api.AccessURL,
daemon.ID,
defaultOrg.ID,
logger,
daemon.Provisioners,
provisionerdserver.Tags(daemon.Tags),
Expand Down
2 changes: 1 addition & 1 deletion coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2093,7 +2093,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
StartedAt: sql.NullTime{Valid: false},
})
check.Args(database.AcquireProvisionerJobParams{Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}).
check.Args(database.AcquireProvisionerJobParams{OrganizationID: j.OrganizationID, Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}).
Asserts( /*rbac.ResourceSystem, rbac.ActionUpdate*/ )
}))
s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) {
Expand Down
1 change: 1 addition & 0 deletions coderd/database/dbfake/dbfake.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
// import job as well
for {
j, err := b.db.AcquireProvisionerJob(ownerCtx, database.AcquireProvisionerJobParams{
OrganizationID: job.OrganizationID,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
Expand Down
17 changes: 12 additions & 5 deletions coderd/database/dbgen/dbgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,12 @@ func GroupMember(t testing.TB, db database.Store, orig database.GroupMember) dat
func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig database.ProvisionerJob) database.ProvisionerJob {
t.Helper()

var defOrgID uuid.UUID
if orig.OrganizationID == uuid.Nil {
defOrg, _ := db.GetDefaultOrganization(genCtx)
defOrgID = defOrg.ID
}

jobID := takeFirst(orig.ID, uuid.New())
// Always set some tags to prevent Acquire from grabbing jobs it should not.
if !orig.StartedAt.Time.IsZero() {
Expand All @@ -401,7 +407,7 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data
ID: jobID,
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
OrganizationID: takeFirst(orig.OrganizationID, uuid.New()),
OrganizationID: takeFirst(orig.OrganizationID, defOrgID, uuid.New()),
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho),
StorageMethod: takeFirst(orig.StorageMethod, database.ProvisionerStorageMethodFile),
Expand All @@ -418,10 +424,11 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data
}
if !orig.StartedAt.Time.IsZero() {
job, err = db.AcquireProvisionerJob(genCtx, database.AcquireProvisionerJobParams{
StartedAt: orig.StartedAt,
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: must(json.Marshal(orig.Tags)),
WorkerID: uuid.NullUUID{},
StartedAt: orig.StartedAt,
OrganizationID: job.OrganizationID,
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: must(json.Marshal(orig.Tags)),
WorkerID: uuid.NullUUID{},
})
require.NoError(t, err)
// There is no easy way to make sure we acquire the correct job.
Expand Down
22 changes: 13 additions & 9 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,9 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
defer q.mutex.Unlock()

for index, provisionerJob := range q.provisionerJobs {
if provisionerJob.OrganizationID != arg.OrganizationID {
continue
}
if provisionerJob.StartedAt.Valid {
continue
}
Expand Down Expand Up @@ -7871,15 +7874,16 @@ func (q *FakeQuerier) UpsertProvisionerDaemon(_ context.Context, arg database.Up
}
}
d := database.ProvisionerDaemon{
ID: uuid.New(),
CreatedAt: arg.CreatedAt,
Name: arg.Name,
Provisioners: arg.Provisioners,
Tags: maps.Clone(arg.Tags),
ReplicaID: uuid.NullUUID{},
LastSeenAt: arg.LastSeenAt,
Version: arg.Version,
APIVersion: arg.APIVersion,
ID: uuid.New(),
CreatedAt: arg.CreatedAt,
Name: arg.Name,
Provisioners: arg.Provisioners,
Tags: maps.Clone(arg.Tags),
ReplicaID: uuid.NullUUID{},
LastSeenAt: arg.LastSeenAt,
Version: arg.Version,
APIVersion: arg.APIVersion,
OrganizationID: arg.OrganizationID,
}
q.provisionerDaemons = append(q.provisionerDaemons, d)
return d, nil
Expand Down
1 change: 1 addition & 0 deletions coderd/database/querier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ func TestQueuePosition(t *testing.T) {
}

job, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: org.ID,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
Expand Down
17 changes: 10 additions & 7 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions coderd/database/queries/provisionerjobs.sql
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ WHERE
provisioner_jobs AS nested
WHERE
nested.started_at IS NULL
AND nested.organization_id = @organization_id
-- Ensure the caller has the correct provisioner.
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
AND CASE
Expand Down
1 change: 1 addition & 0 deletions coderd/prometheusmetrics/prometheusmetrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func TestWorkspaces(t *testing.T) {
require.NoError(t, err)
// This marks the job as started.
_, err = db.AcquireProvisionerJob(context.Background(), database.AcquireProvisionerJobParams{
OrganizationID: job.OrganizationID,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
Expand Down
19 changes: 12 additions & 7 deletions coderd/provisionerdserver/acquirer.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,24 +89,25 @@ func NewAcquirer(ctx context.Context, logger slog.Logger, store AcquirerStore, p
// done, or the database returns an error _other_ than that no jobs are available.
// If no jobs are available, this method handles retrying as appropriate.
func (a *Acquirer) AcquireJob(
ctx context.Context, worker uuid.UUID, pt []database.ProvisionerType, tags Tags,
ctx context.Context, organization uuid.UUID, worker uuid.UUID, pt []database.ProvisionerType, tags Tags,
) (
retJob database.ProvisionerJob, retErr error,
) {
logger := a.logger.With(
slog.F("organization_id", organization),
slog.F("worker_id", worker),
slog.F("provisioner_types", pt),
slog.F("tags", tags))
logger.Debug(ctx, "acquiring job")
dk := domainKey(pt, tags)
dk := domainKey(organization, pt, tags)
dbTags, err := tags.ToJSON()
if err != nil {
return database.ProvisionerJob{}, err
}
// buffer of 1 so that cancel doesn't deadlock while writing to the channel
clearance := make(chan struct{}, 1)
for {
a.want(pt, tags, clearance)
a.want(organization, pt, tags, clearance)
select {
case <-ctx.Done():
err := ctx.Err()
Expand All @@ -120,6 +121,7 @@ func (a *Acquirer) AcquireJob(
case <-clearance:
logger.Debug(ctx, "got clearance to call database")
job, err := a.store.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: organization,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
Expand Down Expand Up @@ -152,8 +154,8 @@ func (a *Acquirer) AcquireJob(
}

// want signals that an acquiree wants clearance to query for a job with the given dKey.
func (a *Acquirer) want(pt []database.ProvisionerType, tags Tags, clearance chan<- struct{}) {
dk := domainKey(pt, tags)
func (a *Acquirer) want(organization uuid.UUID, pt []database.ProvisionerType, tags Tags, clearance chan<- struct{}) {
dk := domainKey(organization, pt, tags)
a.mu.Lock()
defer a.mu.Unlock()
cleared := false
Expand Down Expand Up @@ -404,13 +406,16 @@ type dKey string
// unprintable control character and won't show up in any "reasonable" set of
// string tags, even in non-Latin scripts. It is important that Tags are
// validated not to contain this control character prior to use.
func domainKey(pt []database.ProvisionerType, tags Tags) dKey {
func domainKey(orgID uuid.UUID, pt []database.ProvisionerType, tags Tags) dKey {
sb := strings.Builder{}
_, _ = sb.WriteString(orgID.String())
_ = sb.WriteByte(0x00)

// make a copy of pt before sorting, so that we don't mutate the original
// slice or underlying array.
pts := make([]database.ProvisionerType, len(pt))
copy(pts, pt)
slices.Sort(pts)
sb := strings.Builder{}
for _, t := range pts {
_, _ = sb.WriteString(string(t))
_ = sb.WriteByte(0x00)
Expand Down
Loading