Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

fix pubsub/poll race on provisioner job logs #2783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion coderd/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto.
return nil, xerrors.Errorf("insert job logs: %w", err)
}
server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
data, err := json.Marshal(logs)
data, err := json.Marshal(provisionerJobLogsMessage{Logs: logs})
if err != nil {
return nil, xerrors.Errorf("marshal job log: %w", err)
}
Expand Down Expand Up @@ -549,6 +549,16 @@ func (server *provisionerdServer) FailJob(ctx context.Context, failJob *proto.Fa
}
case *proto.FailedJob_TemplateImport_:
}

data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true})
if err != nil {
return nil, xerrors.Errorf("marshal job log: %w", err)
}
err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data)
if err != nil {
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
return nil, xerrors.Errorf("publish end of job logs: %w", err)
}
return &proto.Empty{}, nil
}

Expand Down Expand Up @@ -711,6 +721,16 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *pr
reflect.TypeOf(completed.Type).String())
}

data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true})
if err != nil {
return nil, xerrors.Errorf("marshal job log: %w", err)
}
err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data)
if err != nil {
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
return nil, xerrors.Errorf("publish end of job logs: %w", err)
}

server.Logger.Debug(ctx, "CompleteJob done", slog.F("job_id", jobID))
return &proto.Empty{}, nil
}
Expand Down
194 changes: 114 additions & 80 deletions coderd/provisionerjobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
// The combination of these responses should provide all current logs
// to the consumer, and future logs are streamed in the follow request.
func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
logger := api.Logger.With(slog.F("job_id", job.ID))
follow := r.URL.Query().Has("follow")
afterRaw := r.URL.Query().Get("after")
beforeRaw := r.URL.Query().Get("before")
Expand All @@ -38,6 +39,37 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
return
}

// if we are following logs, start the subscription before we query the database, so that we don't miss any logs
// between the end of our query and the start of the subscription. We might get duplicates, so we'll keep track
// of processed IDs.
var bufferedLogs <-chan database.ProvisionerJobLog
if follow {
bl, closeFollow, err := api.followLogs(job.ID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error watching provisioner logs.",
Detail: err.Error(),
})
return
}
defer closeFollow()
bufferedLogs = bl

// Next query the job itself to see if it is complete. If so, the historical query to the database will return
// the full set of logs. It's a little sad to have to query the job again, given that our caller definitely
// has, but we need to query it *after* we start following the pubsub to avoid a race condition where the job
// completes between the prior query and the start of following the pubsub. A more substantial refactor could
// avoid this, but not worth it for one fewer query at this point.
job, err = api.Database.GetProvisionerJobByID(r.Context(), job.ID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error querying job.",
Detail: err.Error(),
})
return
}
}

var after time.Time
// Only fetch logs created after the time provided.
if afterRaw != "" {
Expand Down Expand Up @@ -78,26 +110,27 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
}
}

if !follow {
logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
JobID: job.ID,
CreatedAfter: after,
CreatedBefore: before,
logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
JobID: job.ID,
CreatedAfter: after,
CreatedBefore: before,
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error fetching provisioner logs.",
Detail: err.Error(),
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error fetching provisioner logs.",
Detail: err.Error(),
})
return
}
if logs == nil {
logs = []database.ProvisionerJobLog{}
}
return
}
if logs == nil {
logs = []database.ProvisionerJobLog{}
}

if !follow {
logger.Debug(r.Context(), "Finished non-follow job logs")
httpapi.Write(rw, http.StatusOK, convertProvisionerJobLogs(logs))
return
}
Expand All @@ -118,82 +151,43 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageText)
defer wsNetConn.Close() // Also closes conn.

bufferedLogs := make(chan database.ProvisionerJobLog, 128)
closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(job.ID), func(ctx context.Context, message []byte) {
var logs []database.ProvisionerJobLog
err := json.Unmarshal(message, &logs)
if err != nil {
api.Logger.Warn(ctx, fmt.Sprintf("invalid provisioner job log on channel %q: %s", provisionerJobLogsChannel(job.ID), err.Error()))
return
}

for _, log := range logs {
select {
case bufferedLogs <- log:
api.Logger.Debug(r.Context(), "subscribe buffered log", slog.F("job_id", job.ID), slog.F("stage", log.Stage))
default:
// If this overflows users could miss logs streaming. This can happen
// if a database request takes a long amount of time, and we get a lot of logs.
api.Logger.Warn(ctx, "provisioner job log overflowing channel")
}
}
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error watching provisioner logs.",
Detail: err.Error(),
})
return
}
defer closeSubscribe()

provisionerJobLogs, err := api.Database.GetProvisionerLogsByIDBetween(ctx, database.GetProvisionerLogsByIDBetweenParams{
JobID: job.ID,
CreatedAfter: after,
CreatedBefore: before,
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error fetching provisioner logs.",
Detail: err.Error(),
})
return
}
logIdsDone := make(map[uuid.UUID]bool)

// The Go stdlib JSON encoder appends a newline character after message write.
encoder := json.NewEncoder(wsNetConn)
for _, provisionerJobLog := range provisionerJobLogs {
for _, provisionerJobLog := range logs {
logIdsDone[provisionerJobLog.ID] = true
err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
if err != nil {
return
}
}
if job.CompletedAt.Valid {
// job was complete before we queried the database for historical logs, meaning we got everything. No need
// to stream anything from the bufferedLogs.
return
}

ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-r.Context().Done():
api.Logger.Debug(context.Background(), "job logs context canceled", slog.F("job_id", job.ID))
case <-ctx.Done():
logger.Debug(context.Background(), "job logs context canceled")
return
case log := <-bufferedLogs:
api.Logger.Debug(r.Context(), "subscribe encoding log", slog.F("job_id", job.ID), slog.F("stage", log.Stage))
err = encoder.Encode(convertProvisionerJobLog(log))
if err != nil {
case log, ok := <-bufferedLogs:
if !ok {
logger.Debug(context.Background(), "done with published logs")
return
}
case <-ticker.C:
job, err := api.Database.GetProvisionerJobByID(r.Context(), job.ID)
if err != nil {
api.Logger.Warn(r.Context(), "streaming job logs; checking if completed", slog.Error(err), slog.F("job_id", job.ID.String()))
continue
}
if job.CompletedAt.Valid {
api.Logger.Debug(context.Background(), "streaming job logs done; job done", slog.F("job_id", job.ID))
return
if logIdsDone[log.ID] {
logger.Debug(r.Context(), "subscribe duplicated log",
slog.F("stage", log.Stage))
} else {
logger.Debug(r.Context(), "subscribe encoding log",
slog.F("stage", log.Stage))
err = encoder.Encode(convertProvisionerJobLog(log))
if err != nil {
return
}
}
}
}
Expand Down Expand Up @@ -343,3 +337,43 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov
// provisionerJobLogsChannel returns the pubsub channel name on which log
// messages for the given provisioner job are published.
func provisionerJobLogsChannel(jobID uuid.UUID) string {
	return "provisioner-log-logs:" + jobID.String()
}

// provisionerJobLogsMessage is the message type published on the provisionerJobLogsChannel() channel
type provisionerJobLogsMessage struct {
EndOfLogs bool `json:"end_of_logs,omitempty"`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be End instead for naming redundancy. Final may be more apt, chefs choice.

Logs []database.ProvisionerJobLog `json:"logs,omitempty"`
}

// followLogs subscribes to the pubsub channel for the given provisioner job
// and returns a buffered channel on which incoming logs are delivered, plus a
// cancel function that tears down the subscription. The returned channel is
// closed by the subscription callback when a message with EndOfLogs is seen.
//
// NOTE(review): if EndOfLogs were ever published more than once for the same
// job, the second close() would panic — confirm publishers send it at most
// once per job.
func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) {
	logger := api.Logger.With(slog.F("job_id", jobID))
	// Buffered so a slow consumer does not block the pubsub callback; when the
	// buffer is full, logs are dropped (see the default case below).
	bufferedLogs := make(chan database.ProvisionerJobLog, 128)
	closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(jobID),
		func(ctx context.Context, message []byte) {
			jlMsg := provisionerJobLogsMessage{}
			err := json.Unmarshal(message, &jlMsg)
			if err != nil {
				logger.Warn(ctx, "invalid provisioner job log on channel", slog.Error(err))
				return
			}

			// Deliver any logs carried by the message before honoring EndOfLogs,
			// so a message carrying both is fully drained before close.
			for _, log := range jlMsg.Logs {
				select {
				case bufferedLogs <- log:
					logger.Debug(ctx, "subscribe buffered log", slog.F("stage", log.Stage))
				default:
					// If this overflows, users could miss streamed logs. This can
					// happen when we get a lot of logs and the consumer isn't keeping
					// up. We don't want to block the pubsub, so just drop them.
					logger.Warn(ctx, "provisioner job log overflowing channel")
				}
			}
			if jlMsg.EndOfLogs {
				// Closing here signals the consumer that the job is done and no
				// further logs will arrive.
				logger.Debug(ctx, "got End of Logs")
				close(bufferedLogs)
			}
		})
	if err != nil {
		return nil, nil, err
	}
	return bufferedLogs, closeSubscribe, nil
}
Loading