diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go
index 5bee0d3f07065..3e09ed02a5fec 100644
--- a/coderd/provisionerdaemons.go
+++ b/coderd/provisionerdaemons.go
@@ -380,7 +380,7 @@ func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto.
 		return nil, xerrors.Errorf("insert job logs: %w", err)
 	}
 	server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
-	data, err := json.Marshal(logs)
+	data, err := json.Marshal(provisionerJobLogsMessage{Logs: logs})
 	if err != nil {
 		return nil, xerrors.Errorf("marshal job log: %w", err)
 	}
@@ -549,6 +549,16 @@ func (server *provisionerdServer) FailJob(ctx context.Context, failJob *proto.Fa
 		}
 	case *proto.FailedJob_TemplateImport_:
 	}
+
+	data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true})
+	if err != nil {
+		return nil, xerrors.Errorf("marshal job log: %w", err)
+	}
+	err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data)
+	if err != nil {
+		server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
+		return nil, xerrors.Errorf("publish end of job logs: %w", err)
+	}
 	return &proto.Empty{}, nil
 }
 
@@ -711,6 +721,16 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *pr
 			reflect.TypeOf(completed.Type).String())
 	}
 
+	data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true})
+	if err != nil {
+		return nil, xerrors.Errorf("marshal job log: %w", err)
+	}
+	err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data)
+	if err != nil {
+		server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
+		return nil, xerrors.Errorf("publish end of job logs: %w", err)
+	}
+
 	server.Logger.Debug(ctx, "CompleteJob done", slog.F("job_id", jobID))
 	return &proto.Empty{}, nil
 }
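Note: with this change, two kinds of payloads travel over provisionerJobLogsChannel(jobID): batches of logs from UpdateJob, and a bare end-of-logs marker from FailJob/CompleteJob. A minimal sketch of what those payloads look like on the wire, using a simplified stand-in for database.ProvisionerJobLog (the jobLog type below is illustrative, not from the codebase):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// jobLog is a simplified stand-in for database.ProvisionerJobLog, for illustration only.
type jobLog struct {
	ID    string `json:"id"`
	Stage string `json:"stage"`
}

// logsMessage mirrors the shape of the provisionerJobLogsMessage envelope added in this PR.
type logsMessage struct {
	EndOfLogs bool     `json:"end_of_logs,omitempty"`
	Logs      []jobLog `json:"logs,omitempty"`
}

func main() {
	// UpdateJob-style payload: a batch of logs wrapped in the envelope.
	batch, _ := json.Marshal(logsMessage{Logs: []jobLog{{ID: "1", Stage: "Init"}}})
	fmt.Println(string(batch)) // {"logs":[{"id":"1","stage":"Init"}]}

	// FailJob/CompleteJob-style payload: just the end-of-logs marker.
	done, _ := json.Marshal(logsMessage{EndOfLogs: true})
	fmt.Println(string(done)) // {"end_of_logs":true}
}
```

The omitempty tags keep the marker message down to a single field, so a subscriber can cheaply tell it apart from a log batch.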
diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go
index 8b163412f0ff4..97aafc95909d1 100644
--- a/coderd/provisionerjobs.go
+++ b/coderd/provisionerjobs.go
@@ -28,6 +28,7 @@ import (
 // The combination of these responses should provide all current logs
 // to the consumer, and future logs are streamed in the follow request.
 func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
+	logger := api.Logger.With(slog.F("job_id", job.ID))
 	follow := r.URL.Query().Has("follow")
 	afterRaw := r.URL.Query().Get("after")
 	beforeRaw := r.URL.Query().Get("before")
@@ -38,6 +39,37 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
 		return
 	}
 
+	// if we are following logs, start the subscription before we query the database, so that we don't miss any logs
+	// between the end of our query and the start of the subscription. We might get duplicates, so we'll keep track
+	// of processed IDs.
+	var bufferedLogs <-chan database.ProvisionerJobLog
+	if follow {
+		bl, closeFollow, err := api.followLogs(job.ID)
+		if err != nil {
+			httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
+				Message: "Internal error watching provisioner logs.",
+				Detail:  err.Error(),
+			})
+			return
+		}
+		defer closeFollow()
+		bufferedLogs = bl
+
+		// Next query the job itself to see if it is complete. If so, the historical query to the database will return
+		// the full set of logs. It's a little sad to have to query the job again, given that our caller definitely
+		// has, but we need to query it *after* we start following the pubsub to avoid a race condition where the job
+		// completes between the prior query and the start of following the pubsub. A more substantial refactor could
+		// avoid this, but not worth it for one fewer query at this point.
+		job, err = api.Database.GetProvisionerJobByID(r.Context(), job.ID)
+		if err != nil {
+			httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
+				Message: "Internal error querying job.",
+				Detail:  err.Error(),
+			})
+			return
+		}
+	}
+
 	var after time.Time
 	// Only fetch logs created after the time provided.
 	if afterRaw != "" {
@@ -78,26 +110,27 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
 		}
 	}
 
-	if !follow {
-		logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
-			JobID:         job.ID,
-			CreatedAfter:  after,
-			CreatedBefore: before,
+	logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
+		JobID:         job.ID,
+		CreatedAfter:  after,
+		CreatedBefore: before,
+	})
+	if errors.Is(err, sql.ErrNoRows) {
+		err = nil
+	}
+	if err != nil {
+		httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
+			Message: "Internal error fetching provisioner logs.",
+			Detail:  err.Error(),
 		})
-		if errors.Is(err, sql.ErrNoRows) {
-			err = nil
-		}
-		if err != nil {
-			httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
-				Message: "Internal error fetching provisioner logs.",
-				Detail:  err.Error(),
-			})
-			return
-		}
-		if logs == nil {
-			logs = []database.ProvisionerJobLog{}
-		}
+		return
+	}
+	if logs == nil {
+		logs = []database.ProvisionerJobLog{}
+	}
+	if !follow {
+		logger.Debug(r.Context(), "Finished non-follow job logs")
 		httpapi.Write(rw, http.StatusOK, convertProvisionerJobLogs(logs))
 		return
 	}
 
@@ -118,82 +151,43 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
 	ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageText)
 	defer wsNetConn.Close() // Also closes conn.
 
-	bufferedLogs := make(chan database.ProvisionerJobLog, 128)
-	closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(job.ID), func(ctx context.Context, message []byte) {
-		var logs []database.ProvisionerJobLog
-		err := json.Unmarshal(message, &logs)
-		if err != nil {
-			api.Logger.Warn(ctx, fmt.Sprintf("invalid provisioner job log on channel %q: %s", provisionerJobLogsChannel(job.ID), err.Error()))
-			return
-		}
-
-		for _, log := range logs {
-			select {
-			case bufferedLogs <- log:
-				api.Logger.Debug(r.Context(), "subscribe buffered log", slog.F("job_id", job.ID), slog.F("stage", log.Stage))
-			default:
-				// If this overflows users could miss logs streaming. This can happen
-				// if a database request takes a long amount of time, and we get a lot of logs.
-				api.Logger.Warn(ctx, "provisioner job log overflowing channel")
-			}
-		}
-	})
-	if err != nil {
-		httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
-			Message: "Internal error watching provisioner logs.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-	defer closeSubscribe()
-
-	provisionerJobLogs, err := api.Database.GetProvisionerLogsByIDBetween(ctx, database.GetProvisionerLogsByIDBetweenParams{
-		JobID:         job.ID,
-		CreatedAfter:  after,
-		CreatedBefore: before,
-	})
-	if errors.Is(err, sql.ErrNoRows) {
-		err = nil
-	}
-	if err != nil {
-		httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
-			Message: "Internal error fetching provisioner logs.",
-			Detail:  err.Error(),
-		})
-		return
-	}
+	logIdsDone := make(map[uuid.UUID]bool)
 
 	// The Go stdlib JSON encoder appends a newline character after message write.
 	encoder := json.NewEncoder(wsNetConn)
-	for _, provisionerJobLog := range provisionerJobLogs {
+	for _, provisionerJobLog := range logs {
+		logIdsDone[provisionerJobLog.ID] = true
 		err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
 		if err != nil {
 			return
 		}
 	}
+	if job.CompletedAt.Valid {
+		// job was complete before we queried the database for historical logs, meaning we got everything. No need
+		// to stream anything from the bufferedLogs.
+		return
+	}
 
-	ticker := time.NewTicker(250 * time.Millisecond)
-	defer ticker.Stop()
 	for {
 		select {
-		case <-r.Context().Done():
-			api.Logger.Debug(context.Background(), "job logs context canceled", slog.F("job_id", job.ID))
+		case <-ctx.Done():
+			logger.Debug(context.Background(), "job logs context canceled")
 			return
-		case log := <-bufferedLogs:
-			api.Logger.Debug(r.Context(), "subscribe encoding log", slog.F("job_id", job.ID), slog.F("stage", log.Stage))
-			err = encoder.Encode(convertProvisionerJobLog(log))
-			if err != nil {
+		case log, ok := <-bufferedLogs:
+			if !ok {
+				logger.Debug(context.Background(), "done with published logs")
 				return
 			}
-		case <-ticker.C:
-			job, err := api.Database.GetProvisionerJobByID(r.Context(), job.ID)
-			if err != nil {
-				api.Logger.Warn(r.Context(), "streaming job logs; checking if completed", slog.Error(err), slog.F("job_id", job.ID.String()))
-				continue
-			}
-			if job.CompletedAt.Valid {
-				api.Logger.Debug(context.Background(), "streaming job logs done; job done", slog.F("job_id", job.ID))
-				return
+			if logIdsDone[log.ID] {
+				logger.Debug(r.Context(), "subscribe duplicated log",
+					slog.F("stage", log.Stage))
+			} else {
+				logger.Debug(r.Context(), "subscribe encoding log",
+					slog.F("stage", log.Stage))
+				err = encoder.Encode(convertProvisionerJobLog(log))
+				if err != nil {
+					return
+				}
 			}
 		}
 	}
@@ -343,3 +337,43 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov
 func provisionerJobLogsChannel(jobID uuid.UUID) string {
 	return fmt.Sprintf("provisioner-log-logs:%s", jobID)
 }
+
+// provisionerJobLogsMessage is the message type published on the provisionerJobLogsChannel() channel
+type provisionerJobLogsMessage struct {
+	EndOfLogs bool                          `json:"end_of_logs,omitempty"`
+	Logs      []database.ProvisionerJobLog `json:"logs,omitempty"`
+}
+
+func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) {
+	logger := api.Logger.With(slog.F("job_id", jobID))
+	bufferedLogs := make(chan database.ProvisionerJobLog, 128)
+	closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(jobID),
+		func(ctx context.Context, message []byte) {
+			jlMsg := provisionerJobLogsMessage{}
+			err := json.Unmarshal(message, &jlMsg)
+			if err != nil {
+				logger.Warn(ctx, "invalid provisioner job log on channel", slog.Error(err))
+				return
+			}
+
+			for _, log := range jlMsg.Logs {
+				select {
+				case bufferedLogs <- log:
+					logger.Debug(ctx, "subscribe buffered log", slog.F("stage", log.Stage))
+				default:
+					// If this overflows users could miss logs streaming. This can happen when
+					// we get a lot of logs and the consumer isn't keeping up. We don't want to block the pubsub,
+					// so just drop them.
+					logger.Warn(ctx, "provisioner job log overflowing channel")
+				}
+			}
+			if jlMsg.EndOfLogs {
+				logger.Debug(ctx, "got End of Logs")
+				close(bufferedLogs)
+			}
+		})
+	if err != nil {
+		return nil, nil, err
+	}
+	return bufferedLogs, closeSubscribe, nil
+}
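The follow path above relies on a subscribe-then-query ordering: open the pubsub subscription first, then query the historical rows, then dedupe by log ID while draining the subscription, since anything published between the two steps can arrive twice. A rough, self-contained sketch of that dedup step (names are illustrative, not from the codebase; the live channel stands in for bufferedLogs and is assumed to have been subscribed before history was queried):

```go
package main

import "fmt"

// entry is an illustrative stand-in for a provisioner job log row.
type entry struct {
	ID    int
	Stage string
}

// streamLogs mirrors the ordering used above: history is sent first, and anything
// already sent is skipped when draining the subscription.
func streamLogs(history []entry, live <-chan entry, send func(entry)) {
	seen := make(map[int]bool)
	for _, e := range history {
		seen[e.ID] = true
		send(e)
	}
	for e := range live {
		if seen[e.ID] {
			continue // duplicate of a row already sent from the historical query
		}
		send(e)
	}
}

func main() {
	live := make(chan entry, 2)
	live <- entry{ID: 2, Stage: "Plan"}  // also present in history: skipped
	live <- entry{ID: 3, Stage: "Apply"} // only published, never queried: sent
	close(live)

	history := []entry{{ID: 1, Stage: "Init"}, {ID: 2, Stage: "Plan"}}
	streamLogs(history, live, func(e entry) { fmt.Println(e.Stage) })
	// Output: Init, Plan, Apply
}
```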
logger.Warn(ctx, "invalid provisioner job log on channel", slog.Error(err)) + return + } + + for _, log := range jlMsg.Logs { + select { + case bufferedLogs <- log: + logger.Debug(ctx, "subscribe buffered log", slog.F("stage", log.Stage)) + default: + // If this overflows users could miss logs streaming. This can happen + // we get a lot of logs and consumer isn't keeping up. We don't want to block the pubsub, + // so just drop them. + logger.Warn(ctx, "provisioner job log overflowing channel") + } + } + if jlMsg.EndOfLogs { + logger.Debug(ctx, "got End of Logs") + close(bufferedLogs) + } + }) + if err != nil { + return nil, nil, err + } + return bufferedLogs, closeSubscribe, nil +} diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go new file mode 100644 index 0000000000000..4901f2f1ea9a4 --- /dev/null +++ b/coderd/provisionerjobs_internal_test.go @@ -0,0 +1,183 @@ +package coderd + +import ( + "context" + "crypto/sha256" + "encoding/json" + "net/http/httptest" + "net/url" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/codersdk" +) + +func TestProvisionerJobLogs_Unit(t *testing.T) { + t.Parallel() + + t.Run("QueryPubSubDupes", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + // mDB := mocks.NewStore(t) + fDB := databasefake.New() + fPubsub := &fakePubSub{t: t, cond: sync.NewCond(&sync.Mutex{})} + opts := Options{ + Logger: logger, + Database: fDB, + Pubsub: fPubsub, + } + api := New(&opts) + server := httptest.NewServer(api.Handler) + t.Cleanup(server.Close) + userID := uuid.New() + keyID, keySecret, err := generateAPIKeyIDSecret() + require.NoError(t, err) + hashed := sha256.Sum256([]byte(keySecret)) + + u, err := url.Parse(server.URL) + require.NoError(t, err) + client := codersdk.Client{ + HTTPClient: server.Client(), + SessionToken: keyID + "-" + keySecret, + URL: u, + } + + buildID := uuid.New() + workspaceID := uuid.New() + jobID := uuid.New() + + expectedLogs := []database.ProvisionerJobLog{ + {ID: uuid.New(), JobID: jobID, Stage: "Stage0"}, + {ID: uuid.New(), JobID: jobID, Stage: "Stage1"}, + {ID: uuid.New(), JobID: jobID, Stage: "Stage2"}, + {ID: uuid.New(), JobID: jobID, Stage: "Stage3"}, + } + + // wow there are a lot of DB rows we touch... 
+		_, err = fDB.InsertAPIKey(ctx, database.InsertAPIKeyParams{
+			ID:           keyID,
+			HashedSecret: hashed[:],
+			UserID:       userID,
+			ExpiresAt:    time.Now().Add(5 * time.Hour),
+		})
+		require.NoError(t, err)
+		_, err = fDB.InsertUser(ctx, database.InsertUserParams{
+			ID:        userID,
+			RBACRoles: []string{"admin"},
+		})
+		require.NoError(t, err)
+		_, err = fDB.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{
+			ID:          buildID,
+			WorkspaceID: workspaceID,
+			JobID:       jobID,
+		})
+		require.NoError(t, err)
+		_, err = fDB.InsertWorkspace(ctx, database.InsertWorkspaceParams{
+			ID: workspaceID,
+		})
+		require.NoError(t, err)
+		_, err = fDB.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
+			ID: jobID,
+		})
+		require.NoError(t, err)
+		for _, l := range expectedLogs[:2] {
+			_, err := fDB.InsertProvisionerJobLogs(ctx, database.InsertProvisionerJobLogsParams{
+				ID:    []uuid.UUID{l.ID},
+				JobID: jobID,
+				Stage: []string{l.Stage},
+			})
+			require.NoError(t, err)
+		}
+
+		logs, err := client.WorkspaceBuildLogsAfter(ctx, buildID, time.Now())
+		require.NoError(t, err)
+
+		// when the endpoint calls subscribe, we get the listener here.
+		fPubsub.cond.L.Lock()
+		for fPubsub.listener == nil {
+			fPubsub.cond.Wait()
+		}
+
+		// endpoint should now be listening
+		assert.False(t, fPubsub.canceled)
+		assert.False(t, fPubsub.closed)
+
+		// send all the logs in two batches, duplicating what we already returned on the DB query.
+		msg := provisionerJobLogsMessage{}
+		msg.Logs = expectedLogs[:2]
+		data, err := json.Marshal(msg)
+		require.NoError(t, err)
+		fPubsub.listener(ctx, data)
+		msg.Logs = expectedLogs[2:]
+		data, err = json.Marshal(msg)
+		require.NoError(t, err)
+		fPubsub.listener(ctx, data)
+
+		// send end of logs
+		msg.Logs = nil
+		msg.EndOfLogs = true
+		data, err = json.Marshal(msg)
+		require.NoError(t, err)
+		fPubsub.listener(ctx, data)
+
+		var stages []string
+		for l := range logs {
+			logger.Info(ctx, "got log",
+				slog.F("id", l.ID),
+				slog.F("stage", l.Stage))
+			stages = append(stages, l.Stage)
+		}
+		assert.Equal(t, []string{"Stage0", "Stage1", "Stage2", "Stage3"}, stages)
+		for !fPubsub.canceled {
+			fPubsub.cond.Wait()
+		}
+		assert.False(t, fPubsub.closed)
+	})
+}
+
+type fakePubSub struct {
+	t        *testing.T
+	cond     *sync.Cond
+	listener database.Listener
+	canceled bool
+	closed   bool
+}
+
+func (f *fakePubSub) Subscribe(_ string, listener database.Listener) (cancel func(), err error) {
+	f.cond.L.Lock()
+	defer f.cond.L.Unlock()
+	f.listener = listener
+	f.cond.Signal()
+	return f.cancel, nil
+}
+
+func (f *fakePubSub) Publish(_ string, _ []byte) error {
+	f.t.Fail()
+	return nil
+}
+
+func (f *fakePubSub) Close() error {
+	f.cond.L.Lock()
+	defer f.cond.L.Unlock()
+	f.closed = true
+	f.cond.Signal()
+	return nil
+}
+
+func (f *fakePubSub) cancel() {
+	f.cond.L.Lock()
+	defer f.cond.L.Unlock()
+	f.canceled = true
+	f.cond.Signal()
+}
diff --git a/coderd/provisionerjobs_test.go b/coderd/provisionerjobs_test.go
index 404cd53683aa3..9d35f482dadc6 100644
--- a/coderd/provisionerjobs_test.go
+++ b/coderd/provisionerjobs_test.go
@@ -45,7 +45,8 @@ func TestProvisionerJobLogs(t *testing.T) {
 		logs, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before)
 		require.NoError(t, err)
 		for {
-			_, ok := <-logs
+			log, ok := <-logs
+			t.Logf("got log: [%s] %s %s", log.Level, log.Stage, log.Output)
 			if !ok {
 				return
 			}
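A side note on the new test double: fakePubSub synchronizes with the handler goroutine through a sync.Cond rather than channels, so the test can block until Subscribe (or cancel) has actually happened before it starts publishing. A minimal sketch of that wait/signal pattern, with illustrative names (not from the codebase):

```go
package main

import (
	"fmt"
	"sync"
)

// subscribeWaiter shows the sync.Cond pattern fakePubSub uses: every state change
// happens under the lock and signals the cond, so a test can block until the state
// it needs actually holds.
type subscribeWaiter struct {
	cond       *sync.Cond
	subscribed bool
}

func (w *subscribeWaiter) markSubscribed() {
	w.cond.L.Lock()
	defer w.cond.L.Unlock()
	w.subscribed = true
	w.cond.Signal()
}

func (w *subscribeWaiter) waitForSubscribe() {
	w.cond.L.Lock()
	defer w.cond.L.Unlock()
	// Wait releases the lock while blocked and reacquires it before returning,
	// so re-checking the condition in a loop is safe.
	for !w.subscribed {
		w.cond.Wait()
	}
}

func main() {
	w := &subscribeWaiter{cond: sync.NewCond(&sync.Mutex{})}
	go w.markSubscribed()
	w.waitForSubscribe()
	fmt.Println("endpoint subscribed")
}
```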