diff --git a/coderd/coderd.go b/coderd/coderd.go
index b59dbb47b669f..2a0f255523e7e 100644
--- a/coderd/coderd.go
+++ b/coderd/coderd.go
@@ -23,6 +23,7 @@ type Options struct {
 func New(options *Options) http.Handler {
 	api := &api{
 		Database: options.Database,
+		Logger:   options.Logger,
 		Pubsub:   options.Pubsub,
 	}
@@ -110,5 +111,6 @@ func New(options *Options) http.Handler {
 // be added to this struct for code clarity.
 type api struct {
 	Database database.Store
+	Logger   slog.Logger
 	Pubsub   database.Pubsub
 }
diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go
index 6a6f97f3ef090..d064c82dcfb61 100644
--- a/coderd/coderdtest/coderdtest.go
+++ b/coderd/coderdtest/coderdtest.go
@@ -125,7 +125,7 @@ func New(t *testing.T) Server {
 	}
 	handler := coderd.New(&coderd.Options{
-		Logger:   slogtest.Make(t, nil),
+		Logger:   slogtest.Make(t, nil).Leveled(slog.LevelDebug),
 		Database: db,
 		Pubsub:   pubsub,
 	})
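Note: the test harness now builds its logger at debug level so provisioner activity shows up in test output. A minimal, self-contained sketch of the same pattern, assuming the real `cdr.dev/slog` and `slogtest` packages (the test itself is illustrative):

```go
package coderd_test

import (
	"context"
	"testing"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
)

// slogtest.Make writes log entries through t.Log and, by default, fails
// the test on error-level entries. Chaining Leveled(slog.LevelDebug)
// keeps debug lines visible instead of filtering them out.
func TestLoggerInjection(t *testing.T) {
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	logger.Debug(context.Background(), "visible because the logger is leveled at debug",
		slog.F("job_id", "example"))
}
```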
diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go
index 1a315402f08fc..4318dd2cde1be 100644
--- a/coderd/provisionerdaemons.go
+++ b/coderd/provisionerdaemons.go
@@ -19,6 +19,8 @@ import (
 	"storj.io/drpc/drpcmux"
 	"storj.io/drpc/drpcserver"
 
+	"cdr.dev/slog"
+
 	"github.com/coder/coder/coderd/projectparameter"
 	"github.com/coder/coder/database"
 	"github.com/coder/coder/httpapi"
@@ -84,6 +86,7 @@ func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request) {
 		Database:     api.Database,
 		Pubsub:       api.Pubsub,
 		Provisioners: daemon.Provisioners,
+		Logger:       api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
 	})
 	if err != nil {
 		_ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("drpc register provisioner daemon: %s", err))
@@ -109,6 +112,7 @@ type projectImportJob struct {
 // Implementation of the provisioner daemon protobuf server.
 type provisionerdServer struct {
 	ID           uuid.UUID
+	Logger       slog.Logger
 	Provisioners []database.ProvisionerType
 	Database     database.Store
 	Pubsub       database.Pubsub
@@ -136,9 +140,11 @@ func (server *provisionerdServer) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
 	if err != nil {
 		return nil, xerrors.Errorf("acquire job: %w", err)
 	}
+	server.Logger.Debug(ctx, "locked job from database", slog.F("id", job.ID))
+
 	// Marks the acquired job as failed with the error message provided.
 	failJob := func(errorMessage string) error {
-		err = server.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{
+		err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
 			ID: job.ID,
 			CompletedAt: sql.NullTime{
 				Time:  database.Now(),
 				Valid: true,
 			},
@@ -381,8 +387,12 @@ func (server *provisionerdServer) CancelJob(ctx context.Context, cancelJob *proto.CancelledJob) (*proto.Empty, error) {
 	if err != nil {
 		return nil, xerrors.Errorf("parse job id: %w", err)
 	}
-	err = server.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{
+	err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
 		ID: jobID,
+		CompletedAt: sql.NullTime{
+			Time:  database.Now(),
+			Valid: true,
+		},
 		CancelledAt: sql.NullTime{
 			Time:  database.Now(),
 			Valid: true,
 		},
@@ -476,7 +486,7 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) {
 	// This must occur in a transaction in case of failure.
 	err = server.Database.InTx(func(db database.Store) error {
-		err = db.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{
+		err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
 			ID:        jobID,
 			UpdatedAt: database.Now(),
 			CompletedAt: sql.NullTime{
@@ -495,6 +505,7 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) {
 				return xerrors.Errorf("insert project parameter %q: %w", projectParameter.Name, err)
 			}
 		}
+		server.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID))
 		return nil
 	})
 	if err != nil {
@@ -513,7 +524,7 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) {
 	}
 	err = server.Database.InTx(func(db database.Store) error {
-		err = db.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{
+		err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
 			ID:        jobID,
 			UpdatedAt: database.Now(),
 			CompletedAt: sql.NullTime{
diff --git a/coderd/provisioners.go b/coderd/provisioners.go
index f2afefa00cbef..959e69b565801 100644
--- a/coderd/provisioners.go
+++ b/coderd/provisioners.go
@@ -12,7 +12,7 @@ type ProvisionerJobStatus string
 
 // Completed returns whether the job is still processing.
 func (p ProvisionerJobStatus) Completed() bool {
-	return p == ProvisionerJobStatusSucceeded || p == ProvisionerJobStatusFailed
+	return p == ProvisionerJobStatusSucceeded || p == ProvisionerJobStatusFailed || p == ProvisionerJobStatusCancelled
 }
 
 const (
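This one-line predicate change is central to the fix: a cancelled import previously never counted as "completed", so pollers waited until timeout. A self-contained sketch of the behavior (the pending/running constant names are assumptions for illustration; the diff only shows the terminal states):

```go
package main

import "fmt"

// Stand-in for the coderd type; names mirror the diff but this is an
// illustration, not the real package.
type ProvisionerJobStatus string

const (
	ProvisionerJobStatusPending   ProvisionerJobStatus = "pending"   // assumed
	ProvisionerJobStatusRunning   ProvisionerJobStatus = "running"   // assumed
	ProvisionerJobStatusSucceeded ProvisionerJobStatus = "succeeded"
	ProvisionerJobStatusFailed    ProvisionerJobStatus = "failed"
	ProvisionerJobStatusCancelled ProvisionerJobStatus = "cancelled"
)

// Completed reports whether the job has reached a terminal state. Before
// this change, cancelled was missing, so require.Eventually loops polling
// this predicate would spin forever on a cancelled import.
func (p ProvisionerJobStatus) Completed() bool {
	return p == ProvisionerJobStatusSucceeded ||
		p == ProvisionerJobStatusFailed ||
		p == ProvisionerJobStatusCancelled
}

func main() {
	for _, s := range []ProvisionerJobStatus{
		ProvisionerJobStatusRunning,
		ProvisionerJobStatusCancelled,
	} {
		fmt.Printf("%s completed: %v\n", s, s.Completed())
	}
}
```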
diff --git a/coderd/workspacehistory.go b/coderd/workspacehistory.go
index f9e4c7690b4d3..e5ea22d4b6ac6 100644
--- a/coderd/workspacehistory.go
+++ b/coderd/workspacehistory.go
@@ -82,6 +82,10 @@ func (api *api) postWorkspaceHistoryByUser(rw http.ResponseWriter, r *http.Request) {
 			Message: fmt.Sprintf("The provided project history %q has failed to import. You cannot create workspaces using it!", projectHistory.Name),
 		})
 		return
+	case ProvisionerJobStatusCancelled:
+		httpapi.Write(rw, http.StatusPreconditionFailed, httpapi.Response{
+			Message: "The provided project history was canceled during import. You cannot create workspaces using it!",
+		})
 	}
 
 	project, err := api.Database.GetProjectByID(r.Context(), projectHistory.ProjectID)
diff --git a/coderd/workspacehistory_test.go b/coderd/workspacehistory_test.go
index 66dc5bd444621..49a41a4d25e0e 100644
--- a/coderd/workspacehistory_test.go
+++ b/coderd/workspacehistory_test.go
@@ -56,6 +56,7 @@ func TestWorkspaceHistory(t *testing.T) {
 		require.Eventually(t, func() bool {
 			hist, err := client.ProjectHistory(context.Background(), user.Organization, project.Name, projectHistory.Name)
 			require.NoError(t, err)
+			t.Logf("Import status: %s\n", hist.Import.Status)
 			return hist.Import.Status.Completed()
 		}, 15*time.Second, 50*time.Millisecond)
 		return projectHistory
diff --git a/database/databasefake/databasefake.go b/database/databasefake/databasefake.go
index 7ddb71ba04751..cba5e5a3ece90 100644
--- a/database/databasefake/databasefake.go
+++ b/database/databasefake/databasefake.go
@@ -904,9 +904,24 @@ func (q *fakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error {
 		if arg.ID.String() != job.ID.String() {
 			continue
 		}
+		job.UpdatedAt = arg.UpdatedAt
+		q.provisionerJobs[index] = job
+		return nil
+	}
+	return sql.ErrNoRows
+}
+
+func (q *fakeQuerier) UpdateProvisionerJobWithCompleteByID(_ context.Context, arg database.UpdateProvisionerJobWithCompleteByIDParams) error {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+
+	for index, job := range q.provisionerJobs {
+		if arg.ID.String() != job.ID.String() {
+			continue
+		}
+		job.UpdatedAt = arg.UpdatedAt
 		job.CompletedAt = arg.CompletedAt
 		job.CancelledAt = arg.CancelledAt
-		job.UpdatedAt = arg.UpdatedAt
 		job.Error = arg.Error
 		q.provisionerJobs[index] = job
 		return nil
diff --git a/database/querier.go b/database/querier.go
index 870d122f11440..3db5795f9b656 100644
--- a/database/querier.go
+++ b/database/querier.go
@@ -60,6 +60,7 @@ type querier interface {
 	UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
 	UpdateProvisionerDaemonByID(ctx context.Context, arg UpdateProvisionerDaemonByIDParams) error
 	UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
+	UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error
 	UpdateWorkspaceHistoryByID(ctx context.Context, arg UpdateWorkspaceHistoryByIDParams) error
 }
diff --git a/database/query.sql b/database/query.sql
index f0b09b4081850..c2a8fcb579113 100644
--- a/database/query.sql
+++ b/database/query.sql
@@ -604,12 +604,20 @@ WHERE
   id = $1;
 
 -- name: UpdateProvisionerJobByID :exec
+UPDATE
+  provisioner_job
+SET
+  updated_at = $2
+WHERE
+  id = $1;
+
+-- name: UpdateProvisionerJobWithCompleteByID :exec
 UPDATE
   provisioner_job
 SET
   updated_at = $2,
-  cancelled_at = $3,
-  completed_at = $4,
+  completed_at = $3,
+  cancelled_at = $4,
   error = $5
 WHERE
   id = $1;
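Splitting the statement in two means callers that only bump updated_at can no longer accidentally NULL out the completion fields. A small, standard-library-only sketch of why the completion variant's parameters use sql.NullTime:

```go
package main

import (
	"database/sql"
	"fmt"
	"time"
)

func main() {
	// Marking a job completed-but-not-cancelled: Valid=true serializes as
	// a timestamp, while the zero value (Valid=false) serializes as SQL
	// NULL, so one statement can set either column independently.
	completedAt := sql.NullTime{Time: time.Now(), Valid: true}
	cancelledAt := sql.NullTime{} // zero value -> NULL

	fmt.Printf("completed_at valid=%v, cancelled_at valid=%v\n",
		completedAt.Valid, cancelledAt.Valid)
}
```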
diff --git a/database/query.sql.go b/database/query.sql.go
index ad322bbc9b392..366e156132563 100644
--- a/database/query.sql.go
+++ b/database/query.sql.go
@@ -2143,31 +2143,50 @@ func (q *sqlQuerier) UpdateProvisionerDaemonByID(ctx context.Context, arg UpdateProvisionerDaemonByIDParams) error {
 }
 
 const updateProvisionerJobByID = `-- name: UpdateProvisionerJobByID :exec
+UPDATE
+	provisioner_job
+SET
+	updated_at = $2
+WHERE
+	id = $1
+`
+
+type UpdateProvisionerJobByIDParams struct {
+	ID        uuid.UUID `db:"id" json:"id"`
+	UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
+}
+
+func (q *sqlQuerier) UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error {
+	_, err := q.db.ExecContext(ctx, updateProvisionerJobByID, arg.ID, arg.UpdatedAt)
+	return err
+}
+
+const updateProvisionerJobWithCompleteByID = `-- name: UpdateProvisionerJobWithCompleteByID :exec
 UPDATE
 	provisioner_job
 SET
 	updated_at = $2,
-	cancelled_at = $3,
-	completed_at = $4,
+	completed_at = $3,
+	cancelled_at = $4,
 	error = $5
 WHERE
 	id = $1
 `
 
-type UpdateProvisionerJobByIDParams struct {
+type UpdateProvisionerJobWithCompleteByIDParams struct {
 	ID          uuid.UUID      `db:"id" json:"id"`
 	UpdatedAt   time.Time      `db:"updated_at" json:"updated_at"`
-	CancelledAt sql.NullTime   `db:"cancelled_at" json:"cancelled_at"`
 	CompletedAt sql.NullTime   `db:"completed_at" json:"completed_at"`
+	CancelledAt sql.NullTime   `db:"cancelled_at" json:"cancelled_at"`
 	Error       sql.NullString `db:"error" json:"error"`
 }
 
-func (q *sqlQuerier) UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error {
-	_, err := q.db.ExecContext(ctx, updateProvisionerJobByID,
+func (q *sqlQuerier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error {
+	_, err := q.db.ExecContext(ctx, updateProvisionerJobWithCompleteByID,
 		arg.ID,
 		arg.UpdatedAt,
-		arg.CancelledAt,
 		arg.CompletedAt,
+		arg.CancelledAt,
 		arg.Error,
 	)
 	return err
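Because the generated method passes arguments positionally, swapping completed_at and cancelled_at in the SQL requires the matching swap in the ExecContext call, exactly as regenerated above. A hedged sketch of that coupling (the table name and db handle are illustrative, not the real generated code):

```go
package main

import (
	"context"
	"database/sql"
	"time"
)

// Each ExecContext argument must line up with its $n placeholder; a silent
// mismatch here would write completion times into the cancellation column.
const updateJob = `
UPDATE provisioner_job
SET updated_at = $2, completed_at = $3, cancelled_at = $4, error = $5
WHERE id = $1`

func markCompleted(ctx context.Context, db *sql.DB, id string) error {
	now := time.Now()
	_, err := db.ExecContext(ctx, updateJob,
		id,                                   // $1 id
		now,                                  // $2 updated_at
		sql.NullTime{Time: now, Valid: true}, // $3 completed_at
		sql.NullTime{},                       // $4 cancelled_at (NULL)
		sql.NullString{},                     // $5 error (NULL)
	)
	return err
}

func main() {}
```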
diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go
index 94e7e3577800d..4bac6bb777d5f 100644
--- a/provisionerd/provisionerd.go
+++ b/provisionerd/provisionerd.go
@@ -14,8 +14,6 @@ import (
 	"sync"
 	"time"
 
-	"go.uber.org/atomic"
-
 	"github.com/hashicorp/yamux"
 
 	"cdr.dev/slog"
@@ -56,7 +54,12 @@ func New(clientDialer Dialer, opts *Options) io.Closer {
 		closeContext: ctx,
 		closeCancel:  ctxCancel,
 		closed:       make(chan struct{}),
+
+		jobRunning: make(chan struct{}),
 	}
+	// Start off with a closed channel so
+	// isRunningJob() returns properly.
+	close(daemon.jobRunning)
 	go daemon.connect(ctx)
 	return daemon
 }
@@ -65,38 +68,25 @@ type provisionerDaemon struct {
 	opts         *Options
 	clientDialer Dialer
 
-	connectMutex sync.Mutex
 	client       proto.DRPCProvisionerDaemonClient
 	updateStream proto.DRPCProvisionerDaemon_UpdateJobClient
 
-	// Only use for ending a job.
 	closeContext context.Context
 	closeCancel  context.CancelFunc
 	closed       chan struct{}
 	closeMutex   sync.Mutex
 	closeError   error
 
-	// Lock on acquiring a job so two can't happen at once...?
-	// If a single cancel can happen, but an acquire could happen?
-
-	// Lock on acquire
-	// Use atomic for checking if we are running a job
-	// Use atomic for checking if we are canceling job
-	// If we're running a job, wait for the done chan in
-	// close.
-
-	acquiredJob          *proto.AcquiredJob
-	acquiredJobMutex     sync.Mutex
-	acquiredJobCancel    context.CancelFunc
-	acquiredJobCancelled atomic.Bool
-	acquiredJobRunning   atomic.Bool
-	acquiredJobGroup     sync.WaitGroup
+	jobID      string
+	jobMutex   sync.Mutex
+	jobRunning chan struct{}
+	jobCancel  context.CancelFunc
 }
 
 // Connect establishes a connection to coderd.
 func (p *provisionerDaemon) connect(ctx context.Context) {
-	p.connectMutex.Lock()
-	defer p.connectMutex.Unlock()
+	p.jobMutex.Lock()
+	defer p.jobMutex.Unlock()
 
 	var err error
 	// An exponential back-off occurs when the connection is failing to dial.
@@ -118,6 +108,11 @@ func (p *provisionerDaemon) connect(ctx context.Context) {
 		p.opts.Logger.Debug(context.Background(), "connected")
 		break
 	}
+	select {
+	case <-ctx.Done():
+		return
+	default:
+	}
 
 	go func() {
 		if p.isClosed() {
 			return
 		}
@@ -131,6 +126,12 @@ func (p *provisionerDaemon) connect(ctx context.Context) {
 		// has been interrupted. This works well, because logs need
 		// to buffer if a job is running in the background.
 		p.opts.Logger.Debug(context.Background(), "update stream ended", slog.Error(p.updateStream.Context().Err()))
+		// Make sure we're not closing here!
+		p.closeMutex.Lock()
+		defer p.closeMutex.Unlock()
+		if p.isClosed() {
+			return
+		}
 		p.connect(ctx)
 	}()
@@ -143,7 +144,7 @@ func (p *provisionerDaemon) connect(ctx context.Context) {
 		defer ticker.Stop()
 		for {
 			select {
-			case <-ctx.Done():
+			case <-p.closed:
 				return
 			case <-p.updateStream.Context().Done():
 				return
@@ -154,16 +155,25 @@ func (p *provisionerDaemon) connect(ctx context.Context) {
 	}()
 }
 
+func (p *provisionerDaemon) isRunningJob() bool {
+	select {
+	case <-p.jobRunning:
+		return false
+	default:
+		return true
+	}
+}
+
 // Locks a job in the database, and runs it!
 func (p *provisionerDaemon) acquireJob(ctx context.Context) {
-	p.acquiredJobMutex.Lock()
-	defer p.acquiredJobMutex.Unlock()
+	p.jobMutex.Lock()
+	defer p.jobMutex.Unlock()
 	if p.isRunningJob() {
 		p.opts.Logger.Debug(context.Background(), "skipping acquire; job is already running")
 		return
 	}
 	var err error
-	p.acquiredJob, err = p.client.AcquireJob(ctx, &proto.Empty{})
+	job, err := p.client.AcquireJob(ctx, &proto.Empty{})
 	if err != nil {
 		if errors.Is(err, context.Canceled) {
 			return
@@ -177,35 +187,28 @@ func (p *provisionerDaemon) acquireJob(ctx context.Context) {
 	if p.isClosed() {
 		return
 	}
-	if p.acquiredJob.JobId == "" {
+	if job.JobId == "" {
 		p.opts.Logger.Debug(context.Background(), "no jobs available")
 		return
 	}
-	ctx, p.acquiredJobCancel = context.WithCancel(ctx)
-	p.acquiredJobCancelled.Store(false)
-	p.acquiredJobRunning.Store(true)
-	p.acquiredJobGroup.Add(1)
+	p.closeMutex.Lock()
+	defer p.closeMutex.Unlock()
+	ctx, p.jobCancel = context.WithCancel(ctx)
+	p.jobRunning = make(chan struct{})
+	p.jobID = job.JobId
 	p.opts.Logger.Info(context.Background(), "acquired job",
-		slog.F("organization_name", p.acquiredJob.OrganizationName),
-		slog.F("project_name", p.acquiredJob.ProjectName),
-		slog.F("username", p.acquiredJob.UserName),
-		slog.F("provisioner", p.acquiredJob.Provisioner),
+		slog.F("organization_name", job.OrganizationName),
+		slog.F("project_name", job.ProjectName),
+		slog.F("username", job.UserName),
+		slog.F("provisioner", job.Provisioner),
+		slog.F("id", job.JobId),
 	)
-	go p.runJob(ctx)
+	go p.runJob(ctx, job)
 }
 
-func (p *provisionerDaemon) isRunningJob() bool {
-	return p.acquiredJobRunning.Load()
-}
-
-func (p *provisionerDaemon) runJob(ctx context.Context) {
-	// Prevents p.updateStream from being accessed and
-	// written to at the same time.
-	p.connectMutex.Lock()
-	defer p.connectMutex.Unlock()
-
+func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob) {
 	go func() {
 		ticker := time.NewTicker(p.opts.UpdateInterval)
 		defer ticker.Stop()
@@ -216,10 +219,10 @@ func (p *provisionerDaemon) runJob(ctx context.Context) {
 				return
 			case <-ticker.C:
 				err := p.updateStream.Send(&proto.JobUpdate{
-					JobId: p.acquiredJob.JobId,
+					JobId: job.JobId,
 				})
 				if err != nil {
-					p.cancelActiveJob(fmt.Sprintf("send periodic update: %s", err))
+					go p.cancelActiveJob(fmt.Sprintf("send periodic update: %s", err))
 					return
 				}
 			}
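The atomic.Bool bookkeeping is replaced by a single channel: closed means idle, open means a job is in flight, and close() doubles as a broadcast to anyone waiting for the job to end. A runnable sketch of the idiom with simplified names:

```go
package main

import "fmt"

type daemon struct {
	jobRunning chan struct{}
}

// A receive on a closed channel succeeds immediately; on an open channel
// it would block, so the default branch fires. That turns one channel
// into both a boolean flag and a completion signal.
func (d *daemon) isRunningJob() bool {
	select {
	case <-d.jobRunning:
		return false // channel closed: nothing running
	default:
		return true // receive would block: a job is in flight
	}
}

func main() {
	d := &daemon{jobRunning: make(chan struct{})}
	close(d.jobRunning) // start in the idle state, as New does above
	fmt.Println("running:", d.isRunningJob())

	d.jobRunning = make(chan struct{}) // acquire: a job starts
	fmt.Println("running:", d.isRunningJob())

	close(d.jobRunning) // job finished; also wakes any waiters
	fmt.Println("running:", d.isRunningJob())
}
```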
@@ -228,43 +231,40 @@ func (p *provisionerDaemon) runJob(ctx context.Context) {
 		// Cleanup the work directory after execution.
 		err := os.RemoveAll(p.opts.WorkDirectory)
 		if err != nil {
-			p.cancelActiveJob(fmt.Sprintf("remove all from %q directory: %s", p.opts.WorkDirectory, err))
+			go p.cancelActiveJob(fmt.Sprintf("remove all from %q directory: %s", p.opts.WorkDirectory, err))
 			return
 		}
 		p.opts.Logger.Debug(ctx, "cleaned up work directory")
-		p.acquiredJobMutex.Lock()
-		defer p.acquiredJobMutex.Unlock()
-		p.acquiredJobRunning.Store(false)
-		p.acquiredJobGroup.Done()
+		close(p.jobRunning)
 	}()
 
 	// It's safe to cast this ProvisionerType. This data is coming directly from coderd.
-	provisioner, hasProvisioner := p.opts.Provisioners[p.acquiredJob.Provisioner]
+	provisioner, hasProvisioner := p.opts.Provisioners[job.Provisioner]
 	if !hasProvisioner {
-		p.cancelActiveJob(fmt.Sprintf("provisioner %q not registered", p.acquiredJob.Provisioner))
+		go p.cancelActiveJob(fmt.Sprintf("provisioner %q not registered", job.Provisioner))
 		return
 	}
 
 	err := os.MkdirAll(p.opts.WorkDirectory, 0700)
 	if err != nil {
-		p.cancelActiveJob(fmt.Sprintf("create work directory %q: %s", p.opts.WorkDirectory, err))
+		go p.cancelActiveJob(fmt.Sprintf("create work directory %q: %s", p.opts.WorkDirectory, err))
 		return
 	}
 
-	p.opts.Logger.Info(ctx, "unpacking project source archive", slog.F("size_bytes", len(p.acquiredJob.ProjectSourceArchive)))
-	reader := tar.NewReader(bytes.NewBuffer(p.acquiredJob.ProjectSourceArchive))
+	p.opts.Logger.Info(ctx, "unpacking project source archive", slog.F("size_bytes", len(job.ProjectSourceArchive)))
+	reader := tar.NewReader(bytes.NewBuffer(job.ProjectSourceArchive))
 	for {
 		header, err := reader.Next()
 		if errors.Is(err, io.EOF) {
 			break
 		}
 		if err != nil {
-			p.cancelActiveJob(fmt.Sprintf("read project source archive: %s", err))
+			go p.cancelActiveJob(fmt.Sprintf("read project source archive: %s", err))
 			return
 		}
 		// #nosec
 		path := filepath.Join(p.opts.WorkDirectory, header.Name)
 		if !strings.HasPrefix(path, filepath.Clean(p.opts.WorkDirectory)) {
-			p.cancelActiveJob("tar attempts to target relative upper directory")
+			go p.cancelActiveJob("tar attempts to target relative upper directory")
 			return
 		}
 		mode := header.FileInfo().Mode()
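The extraction loop keeps the existing zip-slip guard while switching to the job argument. A runnable sketch of just that guard (paths and names are illustrative):

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// safeJoin mirrors the check in the diff: filepath.Join cleans the joined
// path, so an entry like "../escape.txt" resolves outside the work
// directory and fails the prefix test.
func safeJoin(workDir, name string) (string, error) {
	path := filepath.Join(workDir, name)
	if !strings.HasPrefix(path, filepath.Clean(workDir)) {
		return "", fmt.Errorf("tar entry %q escapes %q", name, workDir)
	}
	return path, nil
}

func main() {
	for _, name := range []string{"main.tf", "../escape.txt"} {
		path, err := safeJoin("/tmp/work", name)
		fmt.Println(path, err)
	}
}
```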
@@ -275,14 +275,14 @@ func (p *provisionerDaemon) runJob(ctx context.Context) {
 		case tar.TypeDir:
 			err = os.MkdirAll(path, mode)
 			if err != nil {
-				p.cancelActiveJob(fmt.Sprintf("mkdir %q: %s", path, err))
+				go p.cancelActiveJob(fmt.Sprintf("mkdir %q: %s", path, err))
 				return
 			}
 			p.opts.Logger.Debug(context.Background(), "extracted directory", slog.F("path", path))
 		case tar.TypeReg:
 			file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, mode)
 			if err != nil {
-				p.cancelActiveJob(fmt.Sprintf("create file %q (mode %s): %s", path, mode, err))
+				go p.cancelActiveJob(fmt.Sprintf("create file %q (mode %s): %s", path, mode, err))
 				return
 			}
 			// Max file size of 10MB.
@@ -291,12 +291,12 @@ func (p *provisionerDaemon) runJob(ctx context.Context) {
 				err = nil
 			}
 			if err != nil {
-				p.cancelActiveJob(fmt.Sprintf("copy file %q: %s", path, err))
+				go p.cancelActiveJob(fmt.Sprintf("copy file %q: %s", path, err))
 				return
 			}
 			err = file.Close()
 			if err != nil {
-				p.cancelActiveJob(fmt.Sprintf("close file %q: %s", path, err))
+				go p.cancelActiveJob(fmt.Sprintf("close file %q: %s", path, err))
 				return
 			}
 			p.opts.Logger.Debug(context.Background(), "extracted file",
@@ -307,13 +307,13 @@ func (p *provisionerDaemon) runJob(ctx context.Context) {
 		}
 	}
 
-	switch jobType := p.acquiredJob.Type.(type) {
+	switch jobType := job.Type.(type) {
 	case *proto.AcquiredJob_ProjectImport_:
 		p.opts.Logger.Debug(context.Background(), "acquired job is project import",
 			slog.F("project_history_name", jobType.ProjectImport.ProjectHistoryName),
 		)
 
-		p.runProjectImport(ctx, provisioner, jobType)
+		p.runProjectImport(ctx, provisioner, job)
 	case *proto.AcquiredJob_WorkspaceProvision_:
 		p.opts.Logger.Debug(context.Background(), "acquired job is workspace provision",
 			slog.F("workspace_name", jobType.WorkspaceProvision.WorkspaceName),
@@ -321,29 +321,28 @@ func (p *provisionerDaemon) runJob(ctx context.Context) {
 			slog.F("parameters", jobType.WorkspaceProvision.ParameterValues),
 		)
 
-		p.runWorkspaceProvision(ctx, provisioner, jobType)
+		p.runWorkspaceProvision(ctx, provisioner, job)
 	default:
-		p.cancelActiveJob(fmt.Sprintf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(p.acquiredJob.Type).String()))
+		go p.cancelActiveJob(fmt.Sprintf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(job.Type).String()))
 		return
 	}
 
-	p.acquiredJobCancel()
-	p.opts.Logger.Info(context.Background(), "completed job")
+	p.opts.Logger.Info(context.Background(), "completed job", slog.F("id", job.JobId))
 }
 
-func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_ProjectImport_) {
+func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) {
 	stream, err := provisioner.Parse(ctx, &sdkproto.Parse_Request{
 		Directory: p.opts.WorkDirectory,
 	})
 	if err != nil {
-		p.cancelActiveJob(fmt.Sprintf("parse source: %s", err))
+		go p.cancelActiveJob(fmt.Sprintf("parse source: %s", err))
 		return
 	}
 	defer stream.Close()
 	for {
 		msg, err := stream.Recv()
 		if err != nil {
-			p.cancelActiveJob(fmt.Sprintf("recv parse source: %s", err))
+			go p.cancelActiveJob(fmt.Sprintf("recv parse source: %s", err))
 			return
 		}
 		switch msgType := msg.Type.(type) {
@@ -351,11 +350,11 @@ func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_ProjectImport_) {
 			p.opts.Logger.Debug(context.Background(), "parse job logged",
 				slog.F("level", msgType.Log.Level),
 				slog.F("output", msgType.Log.Output),
-				slog.F("project_history_id", job.ProjectImport.ProjectHistoryId),
+				slog.F("project_history_id", job.GetProjectImport().ProjectHistoryId),
 			)
 
 			err = p.updateStream.Send(&proto.JobUpdate{
-				JobId: p.acquiredJob.JobId,
+				JobId: job.JobId,
 				ProjectImportLogs: []*proto.Log{{
 					Source:    proto.LogSource_PROVISIONER,
 					Level:     msgType.Log.Level,
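Passing the whole *proto.AcquiredJob and reading variants through Get accessors leans on protoc-gen-go's nil-safe getters for oneof fields. A hand-written stand-in for the pattern (the generated types differ in detail; these are illustrations only):

```go
package main

import "fmt"

type ProjectImport struct{ ProjectHistoryId string }

// AcquiredJob mimics a message with a oneof: Type holds exactly one
// variant wrapper at a time.
type AcquiredJob struct {
	Type interface{ isAcquiredJobType() }
}

type AcquiredJob_ProjectImport_ struct{ ProjectImport *ProjectImport }

func (*AcquiredJob_ProjectImport_) isAcquiredJobType() {}

// GetProjectImport is nil-safe the way generated accessors are: holding
// the wrong variant yields a nil message instead of a panicking type
// assertion at the call site.
func (j *AcquiredJob) GetProjectImport() *ProjectImport {
	if v, ok := j.Type.(*AcquiredJob_ProjectImport_); ok {
		return v.ProjectImport
	}
	return nil
}

func main() {
	job := &AcquiredJob{Type: &AcquiredJob_ProjectImport_{
		ProjectImport: &ProjectImport{ProjectHistoryId: "abc123"},
	}}
	fmt.Println(job.GetProjectImport().ProjectHistoryId)
}
```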
@@ -364,12 +363,15 @@ func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_ProjectImport_) {
 				}},
 			})
 			if err != nil {
-				p.cancelActiveJob(fmt.Sprintf("update job: %s", err))
+				go p.cancelActiveJob(fmt.Sprintf("update job: %s", err))
 				return
 			}
 		case *sdkproto.Parse_Response_Complete:
+			p.opts.Logger.Info(context.Background(), "parse job complete",
+				slog.F("parameter_schemas", msgType.Complete.ParameterSchemas))
+
 			_, err = p.client.CompleteJob(ctx, &proto.CompletedJob{
-				JobId: p.acquiredJob.JobId,
+				JobId: job.JobId,
 				Type: &proto.CompletedJob_ProjectImport_{
 					ProjectImport: &proto.CompletedJob_ProjectImport{
 						ParameterSchemas: msgType.Complete.ParameterSchemas,
@@ -377,27 +379,27 @@ func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_ProjectImport_) {
 					},
 				},
 			})
 			if err != nil {
-				p.cancelActiveJob(fmt.Sprintf("complete job: %s", err))
+				go p.cancelActiveJob(fmt.Sprintf("complete job: %s", err))
 				return
 			}
 			// Return so we stop looping!
 			return
 		default:
-			p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner",
+			go p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner",
 				reflect.TypeOf(msg.Type).String()))
 			return
 		}
 	}
 }
 
-func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_WorkspaceProvision_) {
+func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) {
 	stream, err := provisioner.Provision(ctx, &sdkproto.Provision_Request{
 		Directory:       p.opts.WorkDirectory,
-		ParameterValues: job.WorkspaceProvision.ParameterValues,
-		State:           job.WorkspaceProvision.State,
+		ParameterValues: job.GetWorkspaceProvision().ParameterValues,
+		State:           job.GetWorkspaceProvision().State,
 	})
 	if err != nil {
-		p.cancelActiveJob(fmt.Sprintf("provision: %s", err))
+		go p.cancelActiveJob(fmt.Sprintf("provision: %s", err))
 		return
 	}
 	defer stream.Close()
@@ -405,7 +407,7 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_WorkspaceProvision_) {
 	for {
 		msg, err := stream.Recv()
 		if err != nil {
-			p.cancelActiveJob(fmt.Sprintf("recv workspace provision: %s", err))
+			go p.cancelActiveJob(fmt.Sprintf("recv workspace provision: %s", err))
 			return
 		}
 		switch msgType := msg.Type.(type) {
@@ -413,11 +415,11 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_WorkspaceProvision_) {
 			p.opts.Logger.Debug(context.Background(), "workspace provision job logged",
 				slog.F("level", msgType.Log.Level),
 				slog.F("output", msgType.Log.Output),
-				slog.F("workspace_history_id", job.WorkspaceProvision.WorkspaceHistoryId),
+				slog.F("workspace_history_id", job.GetWorkspaceProvision().WorkspaceHistoryId),
 			)
 
 			err = p.updateStream.Send(&proto.JobUpdate{
-				JobId: p.acquiredJob.JobId,
+				JobId: job.JobId,
 				WorkspaceProvisionLogs: []*proto.Log{{
 					Source:    proto.LogSource_PROVISIONER,
 					Level:     msgType.Log.Level,
@@ -426,7 +428,7 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_WorkspaceProvision_) {
 				}},
 			})
 			if err != nil {
-				p.cancelActiveJob(fmt.Sprintf("send job update: %s", err))
+				go p.cancelActiveJob(fmt.Sprintf("send job update: %s", err))
 				return
 			}
 		case *sdkproto.Provision_Response_Complete:
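Every failure path inside the job now dispatches cancelActiveJob on a fresh goroutine. That appears deliberate: cancelActiveJob takes jobMutex and blocks on <-jobRunning, a channel the job goroutine itself closes (in its deferred cleanup) only after returning, so a synchronous call from inside the job would wait on itself. A simplified, runnable sketch of the hazard and the fix (names are stand-ins):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

type daemon struct {
	mu         sync.Mutex
	jobRunning chan struct{}
}

func (d *daemon) cancelActiveJob(reason string) {
	d.mu.Lock()
	defer d.mu.Unlock()
	fmt.Println("canceling:", reason)
	<-d.jobRunning // wait for the job goroutine to finish its cleanup
	fmt.Println("canceled")
}

func (d *daemon) runJob() {
	defer close(d.jobRunning) // cleanup: signal "no job running"
	// A failure mid-job: cancel asynchronously so this goroutine can
	// return and run the deferred close above. A synchronous call here
	// would block forever on <-d.jobRunning.
	go d.cancelActiveJob("recv workspace provision: stream closed")
}

func main() {
	d := &daemon{jobRunning: make(chan struct{})}
	d.runJob()
	time.Sleep(100 * time.Millisecond) // let the async cancel finish
}
```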
@@ -439,7 +441,7 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_WorkspaceProvision_) {
 			// Complete job may need to be async if we disconnected...
 			// When we reconnect we can flush any of these cached values.
 			_, err = p.client.CompleteJob(ctx, &proto.CompletedJob{
-				JobId: p.acquiredJob.JobId,
+				JobId: job.JobId,
 				Type: &proto.CompletedJob_WorkspaceProvision_{
 					WorkspaceProvision: &proto.CompletedJob_WorkspaceProvision{
 						State: msgType.Complete.State,
@@ -448,13 +450,13 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_WorkspaceProvision_) {
 					},
 				},
 			})
 			if err != nil {
-				p.cancelActiveJob(fmt.Sprintf("complete job: %s", err))
+				go p.cancelActiveJob(fmt.Sprintf("complete job: %s", err))
 				return
 			}
 			// Return so we stop looping!
 			return
 		default:
-			p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner",
+			go p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner",
 				reflect.TypeOf(msg.Type).String()))
 			return
 		}
@@ -462,6 +464,8 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob_WorkspaceProvision_) {
 	}
 }
 
 func (p *provisionerDaemon) cancelActiveJob(errMsg string) {
+	p.jobMutex.Lock()
+	defer p.jobMutex.Unlock()
 	if p.isClosed() {
 		return
 	}
@@ -469,23 +473,20 @@ func (p *provisionerDaemon) cancelActiveJob(errMsg string) {
 		p.opts.Logger.Warn(context.Background(), "skipping job cancel; none running", slog.F("error_message", errMsg))
 		return
 	}
-	if p.acquiredJobCancelled.Load() {
-		return
-	}
-	p.acquiredJobCancelled.Store(true)
-	p.acquiredJobCancel()
+	p.jobCancel()
 	p.opts.Logger.Info(context.Background(), "canceling running job",
 		slog.F("error_message", errMsg),
-		slog.F("job_id", p.acquiredJob.JobId),
+		slog.F("job_id", p.jobID),
 	)
-	_, err := p.client.CancelJob(p.closeContext, &proto.CancelledJob{
-		JobId: p.acquiredJob.JobId,
+	_, err := p.client.CancelJob(context.Background(), &proto.CancelledJob{
+		JobId: p.jobID,
 		Error: fmt.Sprintf("provisioner daemon: %s", errMsg),
 	})
 	if err != nil {
 		p.opts.Logger.Warn(context.Background(), "failed to notify of cancel; job is no longer running", slog.Error(err))
 		return
 	}
+	<-p.jobRunning
 	p.opts.Logger.Debug(context.Background(), "canceled running job")
 }
 
@@ -511,25 +512,18 @@ func (p *provisionerDaemon) closeWithError(err error) error {
 	if p.isClosed() {
 		return p.closeError
 	}
-
-	if p.isRunningJob() {
-		errMsg := "provisioner daemon was shutdown gracefully"
-		if err != nil {
-			errMsg = err.Error()
-		}
-		if !p.acquiredJobCancelled.Load() {
-			p.cancelActiveJob(errMsg)
-		}
-		p.acquiredJobGroup.Wait()
+	p.closeCancel()
+	errMsg := "provisioner daemon was shutdown gracefully"
+	if err != nil {
+		errMsg = err.Error()
 	}
-
+	p.cancelActiveJob(errMsg)
+	p.jobMutex.Lock()
+	defer p.jobMutex.Unlock()
 	p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err))
 	p.closeError = err
 	close(p.closed)
-	p.closeCancel()
-	p.connectMutex.Lock()
-	defer p.connectMutex.Unlock()
 	if p.updateStream != nil {
 		_ = p.client.DRPCConn().Close()
 		_ = p.updateStream.Close()
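Shutdown now follows a strict order: cancel the close context first so in-flight dials stop, cancel (and wait for) any active job, then take jobMutex and flip the closed state before tearing down streams. A simplified, runnable sketch of that ordering with stand-in types, not the real daemon:

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

type daemon struct {
	closeCancel context.CancelFunc
	closed      chan struct{}
	closeMutex  sync.Mutex
	closeError  error
	jobMutex    sync.Mutex
}

func (d *daemon) isClosed() bool {
	select {
	case <-d.closed:
		return true
	default:
		return false
	}
}

// Stand-in for the real cancel path, which also waits on jobRunning.
func (d *daemon) cancelActiveJob(errMsg string) {
	d.jobMutex.Lock()
	defer d.jobMutex.Unlock()
	fmt.Println("cancel requested:", errMsg)
}

func (d *daemon) closeWithError(err error) error {
	d.closeMutex.Lock()
	defer d.closeMutex.Unlock()
	if d.isClosed() {
		return d.closeError
	}
	d.closeCancel() // 1. stop reconnect loops and in-flight work
	errMsg := "provisioner daemon was shutdown gracefully"
	if err != nil {
		errMsg = err.Error()
	}
	d.cancelActiveJob(errMsg) // 2. cancel any active job
	d.jobMutex.Lock()         // 3. no new job can start during teardown
	defer d.jobMutex.Unlock()
	d.closeError = err
	close(d.closed) // 4. finally flip the closed flag
	return err
}

func main() {
	_, cancel := context.WithCancel(context.Background())
	d := &daemon{closeCancel: cancel, closed: make(chan struct{})}
	fmt.Println(d.closeWithError(nil))
}
```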