From 71c667098c971cb6a64e37b25628df89e5b618dd Mon Sep 17 00:00:00 2001
From: Kyle Carberry
Date: Fri, 4 Feb 2022 01:39:55 +0000
Subject: [PATCH] fix: Allow provisionerd to cleanup acquired job

If a job was acquired from the database and provisionerd was then
killed, the job would be left in an idle state where it was technically
in-progress.
---
 .vscode/settings.json        |   1 +
 provisionerd/provisionerd.go | 118 ++++++++++++++++-------------------
 2 files changed, 54 insertions(+), 65 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 3886c15fbfa72..3f149bf5d5604 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -34,6 +34,7 @@
     "goleak",
     "hashicorp",
     "httpmw",
+    "Jobf",
     "moby",
     "nhooyr",
     "nolint",
diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go
index 4bac6bb777d5f..c69fdec7c5efa 100644
--- a/provisionerd/provisionerd.go
+++ b/provisionerd/provisionerd.go
@@ -51,9 +51,8 @@ func New(clientDialer Dialer, opts *Options) io.Closer {
 		clientDialer: clientDialer,
 		opts:         opts,
 
-		closeContext: ctx,
-		closeCancel:  ctxCancel,
-		closed:       make(chan struct{}),
+		closeCancel: ctxCancel,
+		closed:      make(chan struct{}),
 
 		jobRunning: make(chan struct{}),
 	}
@@ -71,23 +70,21 @@ type provisionerDaemon struct {
 	client       proto.DRPCProvisionerDaemonClient
 	updateStream proto.DRPCProvisionerDaemon_UpdateJobClient
 
-	closeContext context.Context
-	closeCancel  context.CancelFunc
-	closed       chan struct{}
-	closeMutex   sync.Mutex
-	closeError   error
+	// Locked when closing the daemon.
+	closeMutex  sync.Mutex
+	closeCancel context.CancelFunc
+	closed      chan struct{}
+	closeError  error
 
-	jobID      string
+	// Locked when acquiring or canceling a job.
 	jobMutex   sync.Mutex
+	jobID      string
 	jobRunning chan struct{}
 	jobCancel  context.CancelFunc
 }
 
 // Connect establishes a connection to coderd.
 func (p *provisionerDaemon) connect(ctx context.Context) {
-	p.jobMutex.Lock()
-	defer p.jobMutex.Unlock()
-
 	var err error
 	// An exponential back-off occurs when the connection is failing to dial.
 	// This is to prevent server spam in case of a coderd outage.
@@ -102,6 +99,9 @@ func (p *provisionerDaemon) connect(ctx context.Context) {
 		}
 		p.updateStream, err = p.client.UpdateJob(ctx)
 		if err != nil {
+			if errors.Is(err, context.Canceled) {
+				return
+			}
 			p.opts.Logger.Warn(context.Background(), "create update job stream", slog.Error(err))
 			continue
 		}
@@ -126,12 +126,6 @@ func (p *provisionerDaemon) connect(ctx context.Context) {
 			// has been interrupted. This works well, because logs need
 			// to buffer if a job is running in the background.
 			p.opts.Logger.Debug(context.Background(), "update stream ended", slog.Error(p.updateStream.Context().Err()))
-			// Make sure we're not closing here!
-			p.closeMutex.Lock()
-			defer p.closeMutex.Unlock()
-			if p.isClosed() {
-				return
-			}
 			p.connect(ctx)
 		}
 	}()
@@ -168,6 +162,9 @@ func (p *provisionerDaemon) isRunningJob() bool {
 func (p *provisionerDaemon) acquireJob(ctx context.Context) {
 	p.jobMutex.Lock()
 	defer p.jobMutex.Unlock()
+	if p.isClosed() {
+		return
+	}
 	if p.isRunningJob() {
 		p.opts.Logger.Debug(context.Background(), "skipping acquire; job is already running")
 		return
@@ -184,15 +181,10 @@ func (p *provisionerDaemon) acquireJob(ctx context.Context) {
 		p.opts.Logger.Warn(context.Background(), "acquire job", slog.Error(err))
 		return
 	}
-	if p.isClosed() {
-		return
-	}
 	if job.JobId == "" {
 		p.opts.Logger.Debug(context.Background(), "no jobs available")
 		return
 	}
-	p.closeMutex.Lock()
-	defer p.closeMutex.Unlock()
 	ctx, p.jobCancel = context.WithCancel(ctx)
 	p.jobRunning = make(chan struct{})
 	p.jobID = job.JobId
@@ -222,7 +214,7 @@ func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob)
 				JobId: job.JobId,
 			})
 			if err != nil {
-				go p.cancelActiveJob(fmt.Sprintf("send periodic update: %s", err))
+				go p.cancelActiveJobf("send periodic update: %s", err)
 				return
 			}
 		}
@@ -230,23 +222,19 @@ func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob)
 	defer func() {
 		// Cleanup the work directory after execution.
 		err := os.RemoveAll(p.opts.WorkDirectory)
-		if err != nil {
-			go p.cancelActiveJob(fmt.Sprintf("remove all from %q directory: %s", p.opts.WorkDirectory, err))
-			return
-		}
-		p.opts.Logger.Debug(ctx, "cleaned up work directory")
+		p.opts.Logger.Debug(ctx, "cleaned up work directory", slog.Error(err))
 		close(p.jobRunning)
 	}()
 
 	// It's safe to cast this ProvisionerType. This data is coming directly from coderd.
 	provisioner, hasProvisioner := p.opts.Provisioners[job.Provisioner]
 	if !hasProvisioner {
-		go p.cancelActiveJob(fmt.Sprintf("provisioner %q not registered", job.Provisioner))
+		go p.cancelActiveJobf("provisioner %q not registered", job.Provisioner)
 		return
 	}
 
 	err := os.MkdirAll(p.opts.WorkDirectory, 0700)
 	if err != nil {
-		go p.cancelActiveJob(fmt.Sprintf("create work directory %q: %s", p.opts.WorkDirectory, err))
+		go p.cancelActiveJobf("create work directory %q: %s", p.opts.WorkDirectory, err)
 		return
 	}
@@ -258,13 +246,13 @@ func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob)
 			break
 		}
 		if err != nil {
-			go p.cancelActiveJob(fmt.Sprintf("read project source archive: %s", err))
+			go p.cancelActiveJobf("read project source archive: %s", err)
 			return
 		}
 		// #nosec
 		path := filepath.Join(p.opts.WorkDirectory, header.Name)
 		if !strings.HasPrefix(path, filepath.Clean(p.opts.WorkDirectory)) {
-			go p.cancelActiveJob("tar attempts to target relative upper directory")
+			go p.cancelActiveJobf("tar attempts to target relative upper directory")
 			return
 		}
 		mode := header.FileInfo().Mode()
@@ -275,14 +263,14 @@ func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob)
 		case tar.TypeDir:
 			err = os.MkdirAll(path, mode)
 			if err != nil {
-				go p.cancelActiveJob(fmt.Sprintf("mkdir %q: %s", path, err))
+				go p.cancelActiveJobf("mkdir %q: %s", path, err)
 				return
 			}
 			p.opts.Logger.Debug(context.Background(), "extracted directory", slog.F("path", path))
 		case tar.TypeReg:
 			file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, mode)
 			if err != nil {
-				go p.cancelActiveJob(fmt.Sprintf("create file %q (mode %s): %s", path, mode, err))
+				go p.cancelActiveJobf("create file %q (mode %s): %s", path, mode, err)
 				return
 			}
 			// Max file size of 10MB.
@@ -291,12 +279,12 @@ func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob)
 				err = nil
 			}
 			if err != nil {
-				go p.cancelActiveJob(fmt.Sprintf("copy file %q: %s", path, err))
+				go p.cancelActiveJobf("copy file %q: %s", path, err)
 				return
 			}
 			err = file.Close()
 			if err != nil {
-				go p.cancelActiveJob(fmt.Sprintf("close file %q: %s", path, err))
+				go p.cancelActiveJobf("close file %q: %s", path, err)
 				return
 			}
 			p.opts.Logger.Debug(context.Background(), "extracted file",
@@ -323,7 +311,7 @@ func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob)
 
 		p.runWorkspaceProvision(ctx, provisioner, job)
 	default:
-		go p.cancelActiveJob(fmt.Sprintf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(job.Type).String()))
+		go p.cancelActiveJobf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(job.Type).String())
 		return
 	}
 
@@ -335,14 +323,14 @@ func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sd
 		Directory: p.opts.WorkDirectory,
 	})
 	if err != nil {
-		go p.cancelActiveJob(fmt.Sprintf("parse source: %s", err))
+		go p.cancelActiveJobf("parse source: %s", err)
 		return
 	}
 	defer stream.Close()
 	for {
 		msg, err := stream.Recv()
 		if err != nil {
-			go p.cancelActiveJob(fmt.Sprintf("recv parse source: %s", err))
+			go p.cancelActiveJobf("recv parse source: %s", err)
 			return
 		}
 		switch msgType := msg.Type.(type) {
@@ -363,7 +351,7 @@ func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sd
 				}},
 			})
 			if err != nil {
-				go p.cancelActiveJob(fmt.Sprintf("update job: %s", err))
+				go p.cancelActiveJobf("update job: %s", err)
 				return
 			}
 		case *sdkproto.Parse_Response_Complete:
@@ -379,14 +367,14 @@ func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sd
 				},
 			})
 			if err != nil {
-				go p.cancelActiveJob(fmt.Sprintf("complete job: %s", err))
+				go p.cancelActiveJobf("complete job: %s", err)
 				return
 			}
 			// Return so we stop looping!
 			return
 		default:
-			go p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner",
-				reflect.TypeOf(msg.Type).String()))
+			go p.cancelActiveJobf("invalid message type %q received from provisioner",
+				reflect.TypeOf(msg.Type).String())
 			return
 		}
 	}
@@ -399,7 +387,7 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provision
 		State: job.GetWorkspaceProvision().State,
 	})
 	if err != nil {
-		go p.cancelActiveJob(fmt.Sprintf("provision: %s", err))
+		go p.cancelActiveJobf("provision: %s", err)
 		return
 	}
 	defer stream.Close()
@@ -407,7 +395,7 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provision
 	for {
 		msg, err := stream.Recv()
 		if err != nil {
-			go p.cancelActiveJob(fmt.Sprintf("recv workspace provision: %s", err))
+			go p.cancelActiveJobf("recv workspace provision: %s", err)
 			return
 		}
 		switch msgType := msg.Type.(type) {
@@ -428,7 +416,7 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provision
 				}},
 			})
 			if err != nil {
-				go p.cancelActiveJob(fmt.Sprintf("send job update: %s", err))
+				go p.cancelActiveJobf("send job update: %s", err)
 				return
 			}
 		case *sdkproto.Provision_Response_Complete:
@@ -450,26 +438,28 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provision
 				},
 			})
 			if err != nil {
-				go p.cancelActiveJob(fmt.Sprintf("complete job: %s", err))
+				go p.cancelActiveJobf("complete job: %s", err)
 				return
 			}
 			// Return so we stop looping!
 			return
 		default:
-			go p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner",
-				reflect.TypeOf(msg.Type).String()))
+			go p.cancelActiveJobf("invalid message type %q received from provisioner",
+				reflect.TypeOf(msg.Type).String())
 			return
 		}
 	}
 }
 
-func (p *provisionerDaemon) cancelActiveJob(errMsg string) {
+func (p *provisionerDaemon) cancelActiveJobf(format string, args ...interface{}) {
 	p.jobMutex.Lock()
 	defer p.jobMutex.Unlock()
-	if p.isClosed() {
-		return
-	}
+	errMsg := fmt.Sprintf(format, args...)
 	if !p.isRunningJob() {
+		if p.isClosed() {
+			// We don't want to log if we're already closed!
+			return
+		}
 		p.opts.Logger.Warn(context.Background(), "skipping job cancel; none running", slog.F("error_message", errMsg))
 		return
 	}
@@ -512,22 +502,20 @@ func (p *provisionerDaemon) closeWithError(err error) error {
 	if p.isClosed() {
 		return p.closeError
 	}
-	p.closeCancel()
+	p.closeError = err
+	close(p.closed)
+
 	errMsg := "provisioner daemon was shutdown gracefully"
 	if err != nil {
 		errMsg = err.Error()
 	}
-	p.cancelActiveJob(errMsg)
-	p.jobMutex.Lock()
-	defer p.jobMutex.Unlock()
-	p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err))
-	p.closeError = err
-	close(p.closed)
+	p.cancelActiveJobf(errMsg)
+	p.closeCancel()
 
-	if p.updateStream != nil {
-		_ = p.client.DRPCConn().Close()
-		_ = p.updateStream.Close()
-	}
+	// Required until we're on Go 1.18. See:
+	// https://github.com/golang/go/issues/50510
+	_ = os.RemoveAll(p.opts.WorkDirectory)
 
+	p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err))
 	return err
 }
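
The shutdown ordering this patch moves to (mark the daemon closed, format and log the cancel reason, cancel the active job under jobMutex, then cancel the daemon's own context) is easier to see outside of a diff. The standalone sketch below illustrates that pattern under those assumptions only; all names (daemon, startJob, and so on) are illustrative and it is not the coder/coder source.

// Sketch only: a tiny daemon with the same close/cancel ordering.
package main

import (
	"context"
	"fmt"
	"log"
	"sync"
)

type daemon struct {
	closeMutex  sync.Mutex
	closeCancel context.CancelFunc
	closed      chan struct{}

	jobMutex   sync.Mutex
	jobRunning chan struct{}
	jobCancel  context.CancelFunc
}

func (d *daemon) isClosed() bool {
	select {
	case <-d.closed:
		return true
	default:
		return false
	}
}

func (d *daemon) isRunningJob() bool {
	select {
	case <-d.jobRunning:
		return false
	default:
		return true
	}
}

// startJob registers a fake job, mirroring the shape of acquireJob.
func (d *daemon) startJob(ctx context.Context) {
	d.jobMutex.Lock()
	defer d.jobMutex.Unlock()
	if d.isClosed() || d.isRunningJob() {
		return
	}
	jobCtx, cancel := context.WithCancel(ctx)
	d.jobCancel = cancel
	running := make(chan struct{})
	d.jobRunning = running
	go func() {
		defer close(running) // always signal completion, even on failure
		<-jobCtx.Done()      // stand-in for real provisioning work
	}()
}

// cancelActiveJobf mirrors the formatting helper introduced by the patch.
func (d *daemon) cancelActiveJobf(format string, args ...interface{}) {
	d.jobMutex.Lock()
	defer d.jobMutex.Unlock()
	errMsg := fmt.Sprintf(format, args...)
	if !d.isRunningJob() {
		if !d.isClosed() {
			log.Printf("skipping job cancel; none running: %s", errMsg)
		}
		return
	}
	d.jobCancel()
	<-d.jobRunning // wait for the job goroutine to finish its cleanup
}

// closeWithError marks the daemon closed before canceling anything, so an
// in-flight job is canceled and cleaned up rather than abandoned mid-run.
func (d *daemon) closeWithError(err error) {
	d.closeMutex.Lock()
	defer d.closeMutex.Unlock()
	if d.isClosed() {
		return
	}
	close(d.closed)
	errMsg := "daemon was shutdown gracefully"
	if err != nil {
		errMsg = err.Error()
	}
	d.cancelActiveJobf(errMsg)
	d.closeCancel()
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	d := &daemon{
		closeCancel: cancel,
		closed:      make(chan struct{}),
		jobRunning:  make(chan struct{}),
	}
	close(d.jobRunning) // no job is running yet
	d.startJob(ctx)
	d.closeWithError(nil)
}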