Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
1 change: 1 addition & 0 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(dialCtx context.Context, name st
api.ctx, // use the same ctx as the API
api.AccessURL,
daemon.ID,
defaultOrg.ID,
logger,
daemon.Provisioners,
provisionerdserver.Tags(daemon.Tags),
Expand Down
2 changes: 1 addition & 1 deletion coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2093,7 +2093,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
StartedAt: sql.NullTime{Valid: false},
})
check.Args(database.AcquireProvisionerJobParams{Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}).
check.Args(database.AcquireProvisionerJobParams{OrganizationID: j.OrganizationID, Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}).
Asserts( /*rbac.ResourceSystem, rbac.ActionUpdate*/ )
}))
s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) {
Expand Down
1 change: 1 addition & 0 deletions coderd/database/dbfake/dbfake.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
// import job as well
for {
j, err := b.db.AcquireProvisionerJob(ownerCtx, database.AcquireProvisionerJobParams{
OrganizationID: job.OrganizationID,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
Expand Down
17 changes: 12 additions & 5 deletions coderd/database/dbgen/dbgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,12 @@ func GroupMember(t testing.TB, db database.Store, orig database.GroupMember) dat
func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig database.ProvisionerJob) database.ProvisionerJob {
t.Helper()

var defOrgID uuid.UUID
if orig.OrganizationID == uuid.Nil {
defOrg, _ := db.GetDefaultOrganization(genCtx)
defOrgID = defOrg.ID
}

jobID := takeFirst(orig.ID, uuid.New())
// Always set some tags to prevent Acquire from grabbing jobs it should not.
if !orig.StartedAt.Time.IsZero() {
Expand All @@ -401,7 +407,7 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data
ID: jobID,
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
OrganizationID: takeFirst(orig.OrganizationID, uuid.New()),
OrganizationID: takeFirst(orig.OrganizationID, defOrgID, uuid.New()),
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho),
StorageMethod: takeFirst(orig.StorageMethod, database.ProvisionerStorageMethodFile),
Expand All @@ -418,10 +424,11 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data
}
if !orig.StartedAt.Time.IsZero() {
job, err = db.AcquireProvisionerJob(genCtx, database.AcquireProvisionerJobParams{
StartedAt: orig.StartedAt,
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: must(json.Marshal(orig.Tags)),
WorkerID: uuid.NullUUID{},
StartedAt: orig.StartedAt,
OrganizationID: job.OrganizationID,
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: must(json.Marshal(orig.Tags)),
WorkerID: uuid.NullUUID{},
})
require.NoError(t, err)
// There is no easy way to make sure we acquire the correct job.
Expand Down
22 changes: 13 additions & 9 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,9 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
defer q.mutex.Unlock()

for index, provisionerJob := range q.provisionerJobs {
if provisionerJob.OrganizationID != arg.OrganizationID {
continue
}
if provisionerJob.StartedAt.Valid {
continue
}
Expand Down Expand Up @@ -7871,15 +7874,16 @@ func (q *FakeQuerier) UpsertProvisionerDaemon(_ context.Context, arg database.Up
}
}
d := database.ProvisionerDaemon{
ID: uuid.New(),
CreatedAt: arg.CreatedAt,
Name: arg.Name,
Provisioners: arg.Provisioners,
Tags: maps.Clone(arg.Tags),
ReplicaID: uuid.NullUUID{},
LastSeenAt: arg.LastSeenAt,
Version: arg.Version,
APIVersion: arg.APIVersion,
ID: uuid.New(),
CreatedAt: arg.CreatedAt,
Name: arg.Name,
Provisioners: arg.Provisioners,
Tags: maps.Clone(arg.Tags),
ReplicaID: uuid.NullUUID{},
LastSeenAt: arg.LastSeenAt,
Version: arg.Version,
APIVersion: arg.APIVersion,
OrganizationID: arg.OrganizationID,
}
q.provisionerDaemons = append(q.provisionerDaemons, d)
return d, nil
Expand Down
1 change: 1 addition & 0 deletions coderd/database/querier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ func TestQueuePosition(t *testing.T) {
}

job, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: org.ID,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
Expand Down
17 changes: 10 additions & 7 deletions coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions coderd/database/queries/provisionerjobs.sql
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ WHERE
provisioner_jobs AS nested
WHERE
nested.started_at IS NULL
AND nested.organization_id = @organization_id
-- Ensure the caller has the correct provisioner.
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
AND CASE
Expand Down
1 change: 1 addition & 0 deletions coderd/prometheusmetrics/prometheusmetrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func TestWorkspaces(t *testing.T) {
require.NoError(t, err)
// This marks the job as started.
_, err = db.AcquireProvisionerJob(context.Background(), database.AcquireProvisionerJobParams{
OrganizationID: job.OrganizationID,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
Expand Down
19 changes: 12 additions & 7 deletions coderd/provisionerdserver/acquirer.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,24 +89,25 @@ func NewAcquirer(ctx context.Context, logger slog.Logger, store AcquirerStore, p
// done, or the database returns an error _other_ than that no jobs are available.
// If no jobs are available, this method handles retrying as appropriate.
func (a *Acquirer) AcquireJob(
ctx context.Context, worker uuid.UUID, pt []database.ProvisionerType, tags Tags,
ctx context.Context, organization uuid.UUID, worker uuid.UUID, pt []database.ProvisionerType, tags Tags,
) (
retJob database.ProvisionerJob, retErr error,
) {
logger := a.logger.With(
slog.F("organization_id", organization),
slog.F("worker_id", worker),
slog.F("provisioner_types", pt),
slog.F("tags", tags))
logger.Debug(ctx, "acquiring job")
dk := domainKey(pt, tags)
dk := domainKey(organization, pt, tags)
dbTags, err := tags.ToJSON()
if err != nil {
return database.ProvisionerJob{}, err
}
// buffer of 1 so that cancel doesn't deadlock while writing to the channel
clearance := make(chan struct{}, 1)
for {
a.want(pt, tags, clearance)
a.want(organization, pt, tags, clearance)
select {
case <-ctx.Done():
err := ctx.Err()
Expand All @@ -120,6 +121,7 @@ func (a *Acquirer) AcquireJob(
case <-clearance:
logger.Debug(ctx, "got clearance to call database")
job, err := a.store.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: organization,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
Expand Down Expand Up @@ -152,8 +154,8 @@ func (a *Acquirer) AcquireJob(
}

// want signals that an acquiree wants clearance to query for a job with the given dKey.
func (a *Acquirer) want(pt []database.ProvisionerType, tags Tags, clearance chan<- struct{}) {
dk := domainKey(pt, tags)
func (a *Acquirer) want(organization uuid.UUID, pt []database.ProvisionerType, tags Tags, clearance chan<- struct{}) {
dk := domainKey(organization, pt, tags)
a.mu.Lock()
defer a.mu.Unlock()
cleared := false
Expand Down Expand Up @@ -404,13 +406,16 @@ type dKey string
// unprintable control character and won't show up in any "reasonable" set of
// string tags, even in non-Latin scripts. It is important that Tags are
// validated not to contain this control character prior to use.
func domainKey(pt []database.ProvisionerType, tags Tags) dKey {
func domainKey(orgID uuid.UUID, pt []database.ProvisionerType, tags Tags) dKey {
sb := strings.Builder{}
_, _ = sb.WriteString(orgID.String())
_ = sb.WriteByte(0x00)

// make a copy of pt before sorting, so that we don't mutate the original
// slice or underlying array.
pts := make([]database.ProvisionerType, len(pt))
copy(pts, pt)
slices.Sort(pts)
sb := strings.Builder{}
for _, t := range pts {
_, _ = sb.WriteString(string(t))
_ = sb.WriteByte(0x00)
Expand Down
Loading