diff --git a/coderd/agentapi/api.go b/coderd/agentapi/api.go index f69f366b43d4e..62fe6fad8d4de 100644 --- a/coderd/agentapi/api.go +++ b/coderd/agentapi/api.go @@ -24,6 +24,7 @@ import ( "github.com/coder/coder/v2/coderd/prometheusmetrics" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/workspacestats" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/tailnet" @@ -45,14 +46,15 @@ type API struct { *ScriptsAPI *tailnet.DRPCService - mu sync.Mutex - cachedWorkspaceID uuid.UUID + mu sync.Mutex } var _ agentproto.DRPCAgentServer = &API{} type Options struct { - AgentID uuid.UUID + AgentID uuid.UUID + OwnerID uuid.UUID + WorkspaceID uuid.UUID Ctx context.Context Log slog.Logger @@ -62,7 +64,7 @@ type Options struct { TailnetCoordinator *atomic.Pointer[tailnet.Coordinator] StatsReporter *workspacestats.Reporter AppearanceFetcher *atomic.Pointer[appearance.Fetcher] - PublishWorkspaceUpdateFn func(ctx context.Context, workspaceID uuid.UUID) + PublishWorkspaceUpdateFn func(ctx context.Context, userID uuid.UUID, event wspubsub.WorkspaceEvent) PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) NetworkTelemetryHandler func(batch []*tailnetproto.TelemetryEvent) @@ -75,18 +77,13 @@ type Options struct { ExternalAuthConfigs []*externalauth.Config Experiments codersdk.Experiments - // Optional: - // WorkspaceID avoids a future lookup to find the workspace ID by setting - // the cache in advance. - WorkspaceID uuid.UUID UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) } func New(opts Options) *API { api := &API{ - opts: opts, - mu: sync.Mutex{}, - cachedWorkspaceID: opts.WorkspaceID, + opts: opts, + mu: sync.Mutex{}, } api.ManifestAPI = &ManifestAPI{ @@ -98,16 +95,7 @@ func New(opts Options) *API { AgentFn: api.agent, Database: opts.Database, DerpMapFn: opts.DerpMapFn, - WorkspaceIDFn: func(ctx context.Context, wa *database.WorkspaceAgent) (uuid.UUID, error) { - if opts.WorkspaceID != uuid.Nil { - return opts.WorkspaceID, nil - } - ws, err := opts.Database.GetWorkspaceByAgentID(ctx, wa.ID) - if err != nil { - return uuid.Nil, err - } - return ws.ID, nil - }, + WorkspaceID: opts.WorkspaceID, } api.AnnouncementBannerAPI = &AnnouncementBannerAPI{ @@ -125,7 +113,7 @@ func New(opts Options) *API { api.LifecycleAPI = &LifecycleAPI{ AgentFn: api.agent, - WorkspaceIDFn: api.workspaceID, + WorkspaceID: opts.WorkspaceID, Database: opts.Database, Log: opts.Log, PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate, @@ -209,39 +197,11 @@ func (a *API) agent(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil } -func (a *API) workspaceID(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { - a.mu.Lock() - if a.cachedWorkspaceID != uuid.Nil { - id := a.cachedWorkspaceID - a.mu.Unlock() - return id, nil - } - - if agent == nil { - agnt, err := a.agent(ctx) - if err != nil { - return uuid.Nil, err - } - agent = &agnt - } - - getWorkspaceAgentByIDRow, err := a.opts.Database.GetWorkspaceByAgentID(ctx, agent.ID) - if err != nil { - return uuid.Nil, xerrors.Errorf("get workspace by agent id %q: %w", agent.ID, err) - } - - a.mu.Lock() - a.cachedWorkspaceID = getWorkspaceAgentByIDRow.ID - a.mu.Unlock() - return getWorkspaceAgentByIDRow.ID, nil -} - -func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent) error { - workspaceID, err := a.workspaceID(ctx, agent) - if err != nil { - return err - } - - a.opts.PublishWorkspaceUpdateFn(ctx, workspaceID) +func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { + a.opts.PublishWorkspaceUpdateFn(ctx, a.opts.OwnerID, wspubsub.WorkspaceEvent{ + Kind: kind, + WorkspaceID: a.opts.WorkspaceID, + AgentID: &agent.ID, + }) return nil } diff --git a/coderd/agentapi/apps.go b/coderd/agentapi/apps.go index b8aefa8883c3b..956e154e89d0d 100644 --- a/coderd/agentapi/apps.go +++ b/coderd/agentapi/apps.go @@ -9,13 +9,14 @@ import ( "cdr.dev/slog" agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/wspubsub" ) type AppsAPI struct { AgentFn func(context.Context) (database.WorkspaceAgent, error) Database database.Store Log slog.Logger - PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error + PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error } func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) { @@ -96,7 +97,7 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat } if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 { - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) + err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAppHealthUpdate) if err != nil { return nil, xerrors.Errorf("publish workspace update: %w", err) } diff --git a/coderd/agentapi/apps_test.go b/coderd/agentapi/apps_test.go index c774c6777b32a..e212a093ae002 100644 --- a/coderd/agentapi/apps_test.go +++ b/coderd/agentapi/apps_test.go @@ -14,6 +14,7 @@ import ( "github.com/coder/coder/v2/coderd/agentapi" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/wspubsub" ) func TestBatchUpdateAppHealths(t *testing.T) { @@ -62,7 +63,7 @@ func TestBatchUpdateAppHealths(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishCalled = true return nil }, @@ -100,7 +101,7 @@ func TestBatchUpdateAppHealths(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishCalled = true return nil }, @@ -139,7 +140,7 @@ func TestBatchUpdateAppHealths(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishCalled = true return nil }, diff --git a/coderd/agentapi/lifecycle.go b/coderd/agentapi/lifecycle.go index e5211e804a7c4..5dd5e7b0c1b06 100644 --- a/coderd/agentapi/lifecycle.go +++ b/coderd/agentapi/lifecycle.go @@ -15,6 +15,7 @@ import ( agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/wspubsub" ) type contextKeyAPIVersion struct{} @@ -25,10 +26,10 @@ func WithAPIVersion(ctx context.Context, version string) context.Context { type LifecycleAPI struct { AgentFn func(context.Context) (database.WorkspaceAgent, error) - WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) + WorkspaceID uuid.UUID Database database.Store Log slog.Logger - PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error + PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error TimeNowFn func() time.Time // defaults to dbtime.Now() } @@ -45,13 +46,9 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda if err != nil { return nil, err } - workspaceID, err := a.WorkspaceIDFn(ctx, &workspaceAgent) - if err != nil { - return nil, err - } logger := a.Log.With( - slog.F("workspace_id", workspaceID), + slog.F("workspace_id", a.WorkspaceID), slog.F("payload", req), ) logger.Debug(ctx, "workspace agent state report") @@ -122,7 +119,7 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda } if a.PublishWorkspaceUpdateFn != nil { - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) + err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentLifecycleUpdate) if err != nil { return nil, xerrors.Errorf("publish workspace update: %w", err) } @@ -140,15 +137,11 @@ func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.Update if err != nil { return nil, err } - workspaceID, err := a.WorkspaceIDFn(ctx, &workspaceAgent) - if err != nil { - return nil, err - } a.Log.Debug( ctx, "post workspace agent version", - slog.F("workspace_id", workspaceID), + slog.F("workspace_id", a.WorkspaceID), slog.F("agent_version", req.Startup.Version), ) diff --git a/coderd/agentapi/lifecycle_test.go b/coderd/agentapi/lifecycle_test.go index fe1469db0aa99..5ec6834d6b878 100644 --- a/coderd/agentapi/lifecycle_test.go +++ b/coderd/agentapi/lifecycle_test.go @@ -19,6 +19,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/wspubsub" ) func TestUpdateLifecycle(t *testing.T) { @@ -69,12 +70,10 @@ func TestUpdateLifecycle(t *testing.T) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agentCreated, nil }, - WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { - return workspaceID, nil - }, - Database: dbM, - Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + WorkspaceID: workspaceID, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishCalled = true return nil }, @@ -111,11 +110,9 @@ func TestUpdateLifecycle(t *testing.T) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agentStarting, nil }, - WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { - return workspaceID, nil - }, - Database: dbM, - Log: slogtest.Make(t, nil), + WorkspaceID: workspaceID, + Database: dbM, + Log: slogtest.Make(t, nil), // Test that nil publish fn works. PublishWorkspaceUpdateFn: nil, } @@ -156,12 +153,10 @@ func TestUpdateLifecycle(t *testing.T) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agentCreated, nil }, - WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { - return workspaceID, nil - }, - Database: dbM, - Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + WorkspaceID: workspaceID, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishCalled = true return nil }, @@ -204,9 +199,7 @@ func TestUpdateLifecycle(t *testing.T) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agentCreated, nil }, - WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { - return workspaceID, nil - }, + WorkspaceID: workspaceID, Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: nil, @@ -239,12 +232,10 @@ func TestUpdateLifecycle(t *testing.T) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil }, - WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { - return workspaceID, nil - }, - Database: dbM, - Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + WorkspaceID: workspaceID, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { atomic.AddInt64(&publishCalled, 1) return nil }, @@ -314,12 +305,10 @@ func TestUpdateLifecycle(t *testing.T) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agentCreated, nil }, - WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { - return workspaceID, nil - }, - Database: dbM, - Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + WorkspaceID: workspaceID, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishCalled = true return nil }, @@ -354,11 +343,9 @@ func TestUpdateStartup(t *testing.T) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil }, - WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { - return workspaceID, nil - }, - Database: dbM, - Log: slogtest.Make(t, nil), + WorkspaceID: workspaceID, + Database: dbM, + Log: slogtest.Make(t, nil), // Not used by UpdateStartup. PublishWorkspaceUpdateFn: nil, } @@ -402,11 +389,9 @@ func TestUpdateStartup(t *testing.T) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil }, - WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { - return workspaceID, nil - }, - Database: dbM, - Log: slogtest.Make(t, nil), + WorkspaceID: workspaceID, + Database: dbM, + Log: slogtest.Make(t, nil), // Not used by UpdateStartup. PublishWorkspaceUpdateFn: nil, } @@ -435,11 +420,9 @@ func TestUpdateStartup(t *testing.T) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil }, - WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { - return workspaceID, nil - }, - Database: dbM, - Log: slogtest.Make(t, nil), + WorkspaceID: workspaceID, + Database: dbM, + Log: slogtest.Make(t, nil), // Not used by UpdateStartup. PublishWorkspaceUpdateFn: nil, } diff --git a/coderd/agentapi/logs.go b/coderd/agentapi/logs.go index 809137525fd04..1d63f32b7b0dd 100644 --- a/coderd/agentapi/logs.go +++ b/coderd/agentapi/logs.go @@ -11,6 +11,7 @@ import ( agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk/agentsdk" ) @@ -18,7 +19,7 @@ type LogsAPI struct { AgentFn func(context.Context) (database.WorkspaceAgent, error) Database database.Store Log slog.Logger - PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error + PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) TimeNowFn func() time.Time // defaults to dbtime.Now() @@ -123,7 +124,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea } if a.PublishWorkspaceUpdateFn != nil { - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) + err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentLogsOverflow) if err != nil { return nil, xerrors.Errorf("publish workspace update: %w", err) } @@ -143,7 +144,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil { // If these are the first logs being appended, we publish a UI update // to notify the UI that logs are now available. - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) + err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentFirstLogs) if err != nil { return nil, xerrors.Errorf("publish workspace update: %w", err) } diff --git a/coderd/agentapi/logs_test.go b/coderd/agentapi/logs_test.go index 261b6c8f6ea83..8e6638ba82624 100644 --- a/coderd/agentapi/logs_test.go +++ b/coderd/agentapi/logs_test.go @@ -19,6 +19,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk/agentsdk" ) @@ -50,7 +51,7 @@ func TestBatchCreateLogs(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishWorkspaceUpdateCalled = true return nil }, @@ -154,7 +155,7 @@ func TestBatchCreateLogs(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishWorkspaceUpdateCalled = true return nil }, @@ -202,7 +203,7 @@ func TestBatchCreateLogs(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishWorkspaceUpdateCalled = true return nil }, @@ -295,7 +296,7 @@ func TestBatchCreateLogs(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishWorkspaceUpdateCalled = true return nil }, @@ -339,7 +340,7 @@ func TestBatchCreateLogs(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishWorkspaceUpdateCalled = true return nil }, @@ -386,7 +387,7 @@ func TestBatchCreateLogs(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { publishWorkspaceUpdateCalled = true return nil }, diff --git a/coderd/agentapi/manifest.go b/coderd/agentapi/manifest.go index a58bf6941cb04..fd4d38d4a75ab 100644 --- a/coderd/agentapi/manifest.go +++ b/coderd/agentapi/manifest.go @@ -29,11 +29,11 @@ type ManifestAPI struct { ExternalAuthConfigs []*externalauth.Config DisableDirectConnections bool DerpForceWebSockets bool + WorkspaceID uuid.UUID - AgentFn func(context.Context) (database.WorkspaceAgent, error) - WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) - Database database.Store - DerpMapFn func() *tailcfg.DERPMap + AgentFn func(context.Context) (database.WorkspaceAgent, error) + Database database.Store + DerpMapFn func() *tailcfg.DERPMap } func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifestRequest) (*agentproto.Manifest, error) { @@ -41,11 +41,6 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest if err != nil { return nil, err } - workspaceID, err := a.WorkspaceIDFn(ctx, &workspaceAgent) - if err != nil { - return nil, err - } - var ( dbApps []database.WorkspaceApp scripts []database.WorkspaceAgentScript @@ -75,7 +70,7 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest return err }) eg.Go(func() (err error) { - workspace, err = a.Database.GetWorkspaceByID(ctx, workspaceID) + workspace, err = a.Database.GetWorkspaceByID(ctx, a.WorkspaceID) if err != nil { return xerrors.Errorf("getting workspace by id: %w", err) } diff --git a/coderd/agentapi/manifest_test.go b/coderd/agentapi/manifest_test.go index e7a36081f64b4..2cde35ba03ab9 100644 --- a/coderd/agentapi/manifest_test.go +++ b/coderd/agentapi/manifest_test.go @@ -288,11 +288,9 @@ func TestGetManifest(t *testing.T) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil }, - WorkspaceIDFn: func(ctx context.Context, _ *database.WorkspaceAgent) (uuid.UUID, error) { - return workspace.ID, nil - }, - Database: mDB, - DerpMapFn: derpMapFn, + WorkspaceID: workspace.ID, + Database: mDB, + DerpMapFn: derpMapFn, } mDB.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return(apps, nil) @@ -355,11 +353,9 @@ func TestGetManifest(t *testing.T) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil }, - WorkspaceIDFn: func(ctx context.Context, _ *database.WorkspaceAgent) (uuid.UUID, error) { - return workspace.ID, nil - }, - Database: mDB, - DerpMapFn: derpMapFn, + WorkspaceID: workspace.ID, + Database: mDB, + DerpMapFn: derpMapFn, } mDB.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return(apps, nil) diff --git a/coderd/agentapi/stats_test.go b/coderd/agentapi/stats_test.go index 83edb8cccc4e1..3ebf99aa6bc4b 100644 --- a/coderd/agentapi/stats_test.go +++ b/coderd/agentapi/stats_test.go @@ -23,6 +23,7 @@ import ( "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/workspacestats" "github.com/coder/coder/v2/coderd/workspacestats/workspacestatstest" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) @@ -153,12 +154,19 @@ func TestUpdateStates(t *testing.T) { }).Return(nil) // Ensure that pubsub notifications are sent. - notifyDescription := make(chan []byte) - ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, description []byte) { - go func() { - notifyDescription <- description - }() - }) + notifyDescription := make(chan struct{}) + ps.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspace.OwnerID), + wspubsub.HandleWorkspaceEvent( + func(_ context.Context, e wspubsub.WorkspaceEvent, err error) { + if err != nil { + return + } + if e.Kind == wspubsub.WorkspaceEventKindStatsUpdate && e.WorkspaceID == workspace.ID { + go func() { + notifyDescription <- struct{}{} + }() + } + })) resp, err := api.UpdateStats(context.Background(), req) require.NoError(t, err) @@ -183,8 +191,7 @@ func TestUpdateStates(t *testing.T) { select { case <-ctx.Done(): t.Error("timed out while waiting for pubsub notification") - case description := <-notifyDescription: - require.Equal(t, description, []byte{}) + case <-notifyDescription: } require.True(t, updateAgentMetricsFnCalled) }) @@ -495,12 +502,19 @@ func TestUpdateStates(t *testing.T) { dbM.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil) // Ensure that pubsub notifications are sent. - notifyDescription := make(chan []byte) - ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, description []byte) { - go func() { - notifyDescription <- description - }() - }) + notifyDescription := make(chan struct{}) + ps.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspace.OwnerID), + wspubsub.HandleWorkspaceEvent( + func(_ context.Context, e wspubsub.WorkspaceEvent, err error) { + if err != nil { + return + } + if e.Kind == wspubsub.WorkspaceEventKindStatsUpdate && e.WorkspaceID == workspace.ID { + go func() { + notifyDescription <- struct{}{} + }() + } + })) resp, err := api.UpdateStats(context.Background(), req) require.NoError(t, err) @@ -523,8 +537,7 @@ func TestUpdateStates(t *testing.T) { select { case <-ctx.Done(): t.Error("timed out while waiting for pubsub notification") - case description := <-notifyDescription: - require.Equal(t, description, []byte{}) + case <-notifyDescription: } require.True(t, updateAgentMetricsFnCalled) }) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 616dd2afac619..3ff9f59fa138e 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -19,7 +19,7 @@ import ( "github.com/coder/coder/v2/coderd/provisionerdserver" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/telemetry" - "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/provisionersdk" sdkproto "github.com/coder/coder/v2/provisionersdk/proto" ) @@ -225,7 +225,12 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse { _ = dbgen.WorkspaceBuildParameters(b.t, b.db, b.params) if b.ps != nil { - err = b.ps.Publish(codersdk.WorkspaceNotifyChannel(resp.Build.WorkspaceID), []byte{}) + msg, err := json.Marshal(wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindStateChange, + WorkspaceID: resp.Workspace.ID, + }) + require.NoError(b.t, err) + err = b.ps.Publish(wspubsub.WorkspaceEventChannel(resp.Workspace.OwnerID), msg) require.NoError(b.t, err) } diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index 6c72ff5831947..9a9da3d8112ab 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -39,6 +39,7 @@ import ( "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/drpc" "github.com/coder/coder/v2/provisioner" @@ -493,7 +494,15 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo for _, group := range ownerGroups { ownerGroupNames = append(ownerGroupNames, group.Group.Name) } - err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspace.ID), []byte{}) + + msg, err := json.Marshal(wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindStateChange, + WorkspaceID: workspace.ID, + }) + if err != nil { + return nil, failJob(fmt.Sprintf("marshal workspace update event: %s", err)) + } + err = s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg) if err != nil { return nil, failJob(fmt.Sprintf("publish workspace update: %s", err)) } @@ -1023,9 +1032,16 @@ func (s *server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto. s.notifyWorkspaceBuildFailed(ctx, workspace, build) - err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), []byte{}) + msg, err := json.Marshal(wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindStateChange, + WorkspaceID: workspace.ID, + }) if err != nil { - return nil, xerrors.Errorf("update workspace: %w", err) + return nil, xerrors.Errorf("marshal workspace update event: %s", err) + } + err = s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg) + if err != nil { + return nil, xerrors.Errorf("publish workspace update: %w", err) } case *proto.FailedJob_TemplateImport_: } @@ -1369,9 +1385,6 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) return nil, xerrors.Errorf("update provisioner job: %w", err) } s.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID)) - if err != nil { - return nil, xerrors.Errorf("complete job: %w", err) - } case *proto.CompletedJob_WorkspaceBuild_: var input WorkspaceProvisionJob err = json.Unmarshal(job.Input, &input) @@ -1491,7 +1504,15 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) return case <-wait: // Wait for the next potential timeout to occur. - if err := s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}); err != nil { + msg, err := json.Marshal(wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindAgentTimeout, + WorkspaceID: workspace.ID, + }) + if err != nil { + s.Logger.Error(ctx, "marshal workspace update event", slog.Error(err)) + break + } + if err := s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg); err != nil { if s.lifecycleCtx.Err() != nil { // If the server is shutting down, we don't want to log this error, nor wait around. s.Logger.Debug(ctx, "stopping notifications due to server shutdown", @@ -1608,7 +1629,14 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) }) } - err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}) + msg, err := json.Marshal(wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindStateChange, + WorkspaceID: workspace.ID, + }) + if err != nil { + return nil, xerrors.Errorf("marshal workspace update event: %s", err) + } + err = s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg) if err != nil { return nil, xerrors.Errorf("update workspace: %w", err) } @@ -1639,9 +1667,6 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) return nil, xerrors.Errorf("update provisioner job: %w", err) } s.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID)) - if err != nil { - return nil, xerrors.Errorf("complete job: %w", err) - } default: if completed.Type == nil { diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index baa53b92d74e2..98ab07db3d0f7 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -40,6 +40,7 @@ import ( "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/schedule/cron" "github.com/coder/coder/v2/coderd/telemetry" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionersdk" @@ -295,12 +296,19 @@ func TestAcquireJob(t *testing.T) { startPublished := make(chan struct{}) var closed bool - closeStartSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { - if !closed { - close(startPublished) - closed = true - } - }) + closeStartSubscribe, err := ps.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspace.OwnerID), + wspubsub.HandleWorkspaceEvent( + func(_ context.Context, e wspubsub.WorkspaceEvent, err error) { + if err != nil { + return + } + if e.Kind == wspubsub.WorkspaceEventKindStateChange && e.WorkspaceID == workspace.ID { + if !closed { + close(startPublished) + closed = true + } + } + })) require.NoError(t, err) defer closeStartSubscribe() @@ -398,9 +406,16 @@ func TestAcquireJob(t *testing.T) { }) stopPublished := make(chan struct{}) - closeStopSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { - close(stopPublished) - }) + closeStopSubscribe, err := ps.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspace.OwnerID), + wspubsub.HandleWorkspaceEvent( + func(_ context.Context, e wspubsub.WorkspaceEvent, err error) { + if err != nil { + return + } + if e.Kind == wspubsub.WorkspaceEventKindStateChange && e.WorkspaceID == workspace.ID { + close(stopPublished) + } + })) require.NoError(t, err) defer closeStopSubscribe() @@ -874,12 +889,11 @@ func TestFailJob(t *testing.T) { auditor: auditor, }) org := dbgen.Organization(t, db, database.Organization{}) - workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ + workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ ID: uuid.New(), AutomaticUpdates: database.AutomaticUpdatesNever, OrganizationID: org.ID, }) - require.NoError(t, err) buildID := uuid.New() input, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{ WorkspaceBuildID: buildID, @@ -889,6 +903,7 @@ func TestFailJob(t *testing.T) { job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Input: input, + InitiatorID: workspace.OwnerID, Provisioner: database.ProvisionerTypeEcho, Type: database.ProvisionerJobTypeWorkspaceBuild, StorageMethod: database.ProvisionerStorageMethodFile, @@ -897,6 +912,7 @@ func TestFailJob(t *testing.T) { err = db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ ID: buildID, WorkspaceID: workspace.ID, + InitiatorID: workspace.OwnerID, Transition: database.WorkspaceTransitionStart, Reason: database.BuildReasonInitiator, JobID: job.ID, @@ -913,9 +929,16 @@ func TestFailJob(t *testing.T) { require.NoError(t, err) publishedWorkspace := make(chan struct{}) - closeWorkspaceSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { - close(publishedWorkspace) - }) + closeWorkspaceSubscribe, err := ps.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspace.OwnerID), + wspubsub.HandleWorkspaceEvent( + func(_ context.Context, e wspubsub.WorkspaceEvent, err error) { + if err != nil { + return + } + if e.Kind == wspubsub.WorkspaceEventKindStateChange && e.WorkspaceID == workspace.ID { + close(publishedWorkspace) + } + })) require.NoError(t, err) defer closeWorkspaceSubscribe() publishedLogs := make(chan struct{}) @@ -1279,13 +1302,15 @@ func TestCompleteJob(t *testing.T) { }) build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ WorkspaceID: workspaceTable.ID, + InitiatorID: user.ID, TemplateVersionID: version.ID, Transition: c.transition, Reason: database.BuildReasonInitiator, }) job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{ - FileID: file.ID, - Type: database.ProvisionerJobTypeWorkspaceBuild, + FileID: file.ID, + InitiatorID: user.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{ WorkspaceBuildID: build.ID, })), @@ -1302,9 +1327,16 @@ func TestCompleteJob(t *testing.T) { require.NoError(t, err) publishedWorkspace := make(chan struct{}) - closeWorkspaceSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), func(_ context.Context, _ []byte) { - close(publishedWorkspace) - }) + closeWorkspaceSubscribe, err := ps.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspaceTable.OwnerID), + wspubsub.HandleWorkspaceEvent( + func(_ context.Context, e wspubsub.WorkspaceEvent, err error) { + if err != nil { + return + } + if e.Kind == wspubsub.WorkspaceEventKindStateChange && e.WorkspaceID == workspaceTable.ID { + close(publishedWorkspace) + } + })) require.NoError(t, err) defer closeWorkspaceSubscribe() publishedLogs := make(chan struct{}) @@ -1643,8 +1675,9 @@ func TestNotifications(t *testing.T) { Reason: tc.deletionReason, }) job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{ - FileID: file.ID, - Type: database.ProvisionerJobTypeWorkspaceBuild, + FileID: file.ID, + InitiatorID: initiator.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{ WorkspaceBuildID: build.ID, })), @@ -1761,8 +1794,9 @@ func TestNotifications(t *testing.T) { Reason: tc.buildReason, }) job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{ - FileID: file.ID, - Type: database.ProvisionerJobTypeWorkspaceBuild, + FileID: file.ID, + InitiatorID: initiator.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{ WorkspaceBuildID: build.ID, })), @@ -1833,6 +1867,7 @@ func TestNotifications(t *testing.T) { }) job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{ FileID: dbgen.File(t, db, database.File{CreatedBy: user.ID}).ID, + InitiatorID: user.ID, Type: database.ProvisionerJobTypeWorkspaceBuild, Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{WorkspaceBuildID: build.ID})), OrganizationID: pd.OrganizationID, diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index a181697f27279..14e986123edb7 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -34,6 +34,7 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/rbac/policy" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/workspacesdk" @@ -242,25 +243,20 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request) api.Logger.Warn(ctx, "failed to update workspace agent log overflow", slog.Error(err)) } - resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID) + workspace, err := api.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to get workspace resource.", + Message: "Failed to get workspace.", Detail: err.Error(), }) return } - build, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Internal error fetching workspace build job.", - Detail: err.Error(), - }) - return - } - - api.publishWorkspaceUpdate(ctx, build.WorkspaceID) + api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindAgentLogsOverflow, + WorkspaceID: workspace.ID, + AgentID: &workspaceAgent.ID, + }) httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{ Message: "Logs limit exceeded", @@ -279,25 +275,20 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request) if workspaceAgent.LogsLength == 0 { // If these are the first logs being appended, we publish a UI update // to notify the UI that logs are now available. - resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID) + workspace, err := api.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to get workspace resource.", + Message: "Failed to get workspace.", Detail: err.Error(), }) return } - build, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Internal error fetching workspace build job.", - Detail: err.Error(), - }) - return - } - - api.publishWorkspaceUpdate(ctx, build.WorkspaceID) + api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindAgentFirstLogs, + WorkspaceID: workspace.ID, + AgentID: &workspaceAgent.ID, + }) } httpapi.Write(ctx, rw, http.StatusOK, nil) @@ -426,12 +417,19 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) { notifyCh <- struct{}{} // Subscribe to workspace to detect new builds. - closeSubscribeWorkspace, err := api.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { - select { - case workspaceNotifyCh <- struct{}{}: - default: - } - }) + closeSubscribeWorkspace, err := api.Pubsub.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspace.OwnerID), + wspubsub.HandleWorkspaceEvent( + func(_ context.Context, e wspubsub.WorkspaceEvent, err error) { + if err != nil { + return + } + if e.Kind == wspubsub.WorkspaceEventKindStateChange && e.WorkspaceID == workspace.ID { + select { + case workspaceNotifyCh <- struct{}{}: + default: + } + } + })) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to subscribe to workspace for log streaming.", diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index a47fa0c12ed1a..29f2ad476dca0 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -26,6 +26,7 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet" tailnetproto "github.com/coder/coder/v2/tailnet/proto" @@ -132,11 +133,13 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { closeCtx, closeCtxCancel := context.WithCancel(ctx) defer closeCtxCancel() - monitor := api.startAgentYamuxMonitor(closeCtx, workspaceAgent, build, mux) + monitor := api.startAgentYamuxMonitor(closeCtx, workspace, workspaceAgent, build, mux) defer monitor.close() agentAPI := agentapi.New(agentapi.Options{ - AgentID: workspaceAgent.ID, + AgentID: workspaceAgent.ID, + OwnerID: workspace.OwnerID, + WorkspaceID: workspace.ID, Ctx: api.ctx, Log: logger, @@ -160,7 +163,6 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { Experiments: api.Experiments, // Optional: - WorkspaceID: build.WorkspaceID, // saves the extra lookup later UpdateAgentMetricsFn: api.UpdateAgentMetrics, }) @@ -225,11 +227,14 @@ func (y *yamuxPingerCloser) Ping(ctx context.Context) error { } func (api *API) startAgentYamuxMonitor(ctx context.Context, - workspaceAgent database.WorkspaceAgent, workspaceBuild database.WorkspaceBuild, + workspace database.Workspace, + workspaceAgent database.WorkspaceAgent, + workspaceBuild database.WorkspaceBuild, mux *yamux.Session, ) *agentConnectionMonitor { monitor := &agentConnectionMonitor{ apiCtx: api.ctx, + workspace: workspace, workspaceAgent: workspaceAgent, workspaceBuild: workspaceBuild, conn: &yamuxPingerCloser{mux: mux}, @@ -250,7 +255,7 @@ func (api *API) startAgentYamuxMonitor(ctx context.Context, } type workspaceUpdater interface { - publishWorkspaceUpdate(ctx context.Context, workspaceID uuid.UUID) + publishWorkspaceUpdate(ctx context.Context, ownerID uuid.UUID, event wspubsub.WorkspaceEvent) } type pingerCloser interface { @@ -262,6 +267,7 @@ type agentConnectionMonitor struct { apiCtx context.Context cancel context.CancelFunc wg sync.WaitGroup + workspace database.Workspace workspaceAgent database.WorkspaceAgent workspaceBuild database.WorkspaceBuild conn pingerCloser @@ -393,7 +399,11 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { ) } } - m.updater.publishWorkspaceUpdate(finalCtx, m.workspaceBuild.WorkspaceID) + m.updater.publishWorkspaceUpdate(finalCtx, m.workspace.OwnerID, wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindAgentConnectionUpdate, + WorkspaceID: m.workspaceBuild.WorkspaceID, + AgentID: &m.workspaceAgent.ID, + }) }() reason := "disconnect" defer func() { @@ -407,7 +417,11 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { reason = err.Error() return } - m.updater.publishWorkspaceUpdate(ctx, m.workspaceBuild.WorkspaceID) + m.updater.publishWorkspaceUpdate(ctx, m.workspace.OwnerID, wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindAgentConnectionUpdate, + WorkspaceID: m.workspaceBuild.WorkspaceID, + AgentID: &m.workspaceAgent.ID, + }) ticker := time.NewTicker(m.pingPeriod) defer ticker.Stop() @@ -441,7 +455,11 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { return } if connectionStatusChanged { - m.updater.publishWorkspaceUpdate(ctx, m.workspaceBuild.WorkspaceID) + m.updater.publishWorkspaceUpdate(ctx, m.workspace.OwnerID, wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindAgentConnectionUpdate, + WorkspaceID: m.workspaceBuild.WorkspaceID, + AgentID: &m.workspaceAgent.ID, + }) } err = checkBuildIsLatest(ctx, m.db, m.workspaceBuild) if err != nil { diff --git a/coderd/workspaceagentsrpc_internal_test.go b/coderd/workspaceagentsrpc_internal_test.go index dbae11a218619..338c2e4899368 100644 --- a/coderd/workspaceagentsrpc_internal_test.go +++ b/coderd/workspaceagentsrpc_internal_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -356,10 +357,10 @@ type fakeUpdater struct { updates []uuid.UUID } -func (f *fakeUpdater) publishWorkspaceUpdate(_ context.Context, workspaceID uuid.UUID) { +func (f *fakeUpdater) publishWorkspaceUpdate(_ context.Context, _ uuid.UUID, event wspubsub.WorkspaceEvent) { f.Lock() defer f.Unlock() - f.updates = append(f.updates, workspaceID) + f.updates = append(f.updates, event.WorkspaceID) } func (f *fakeUpdater) requireEventuallySomeUpdates(t *testing.T, workspaceID uuid.UUID) { diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index 3515bc4a944b5..da785ac3a5a8a 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -30,6 +30,7 @@ import ( "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/wsbuilder" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" ) @@ -412,7 +413,10 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) { return } - api.publishWorkspaceUpdate(ctx, workspace.ID) + api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindStateChange, + WorkspaceID: workspace.ID, + }) httpapi.Write(ctx, rw, http.StatusCreated, apiBuild) } @@ -491,7 +495,10 @@ func (api *API) patchCancelWorkspaceBuild(rw http.ResponseWriter, r *http.Reques return } - api.publishWorkspaceUpdate(ctx, workspace.ID) + api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindStateChange, + WorkspaceID: workspace.ID, + }) httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{ Message: "Job has been marked as canceled...", diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 394a728472b0d..4638596e66eae 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -34,6 +34,7 @@ import ( "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/coderd/wsbuilder" + "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" ) @@ -806,7 +807,11 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) { return } - api.publishWorkspaceUpdate(ctx, workspace.ID) + api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindMetadataUpdate, + WorkspaceID: workspace.ID, + }) + aReq.New = newWorkspace rw.WriteHeader(http.StatusNoContent) @@ -1216,7 +1221,11 @@ func (api *API) putExtendWorkspace(rw http.ResponseWriter, r *http.Request) { if err != nil { api.Logger.Info(ctx, "extending workspace", slog.Error(err)) } - api.publishWorkspaceUpdate(ctx, workspace.ID) + + api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindMetadataUpdate, + WorkspaceID: workspace.ID, + }) httpapi.Write(ctx, rw, code, resp) } @@ -1667,7 +1676,17 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { }) } - cancelWorkspaceSubscribe, err := api.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), sendUpdate) + cancelWorkspaceSubscribe, err := api.Pubsub.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspace.OwnerID), + wspubsub.HandleWorkspaceEvent( + func(ctx context.Context, payload wspubsub.WorkspaceEvent, err error) { + if err != nil { + return + } + if payload.WorkspaceID != workspace.ID { + return + } + sendUpdate(ctx, nil) + })) if err != nil { _ = sendEvent(ctx, codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeError, @@ -2006,11 +2025,24 @@ func validWorkspaceSchedule(s *string) (sql.NullString, error) { }, nil } -func (api *API) publishWorkspaceUpdate(ctx context.Context, workspaceID uuid.UUID) { - err := api.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceID), []byte{}) +func (api *API) publishWorkspaceUpdate(ctx context.Context, ownerID uuid.UUID, event wspubsub.WorkspaceEvent) { + err := event.Validate() + if err != nil { + api.Logger.Warn(ctx, "invalid workspace update event", + slog.F("workspace_id", event.WorkspaceID), + slog.F("event_kind", event.Kind), slog.Error(err)) + return + } + msg, err := json.Marshal(event) + if err != nil { + api.Logger.Warn(ctx, "failed to marshal workspace update", + slog.F("workspace_id", event.WorkspaceID), slog.Error(err)) + return + } + err = api.Pubsub.Publish(wspubsub.WorkspaceEventChannel(ownerID), msg) if err != nil { api.Logger.Warn(ctx, "failed to publish workspace update", - slog.F("workspace_id", workspaceID), slog.Error(err)) + slog.F("workspace_id", event.WorkspaceID), slog.Error(err)) } } diff --git a/coderd/workspacestats/reporter.go b/coderd/workspacestats/reporter.go index e59a9f15d5e95..b00523b1ad5d4 100644 --- a/coderd/workspacestats/reporter.go +++ b/coderd/workspacestats/reporter.go @@ -2,6 +2,7 @@ package workspacestats import ( "context" + "encoding/json" "sync/atomic" "time" @@ -18,7 +19,7 @@ import ( "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/coderd/workspaceapps" - "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/coderd/wspubsub" ) type ReporterOptions struct { @@ -174,7 +175,14 @@ func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspac r.opts.UsageTracker.Add(workspace.ID) // notify workspace update - err := r.opts.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspace.ID), []byte{}) + msg, err := json.Marshal(wspubsub.WorkspaceEvent{ + Kind: wspubsub.WorkspaceEventKindStatsUpdate, + WorkspaceID: workspace.ID, + }) + if err != nil { + return xerrors.Errorf("marshal workspace agent stats event: %w", err) + } + err = r.opts.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg) if err != nil { r.opts.Logger.Warn(ctx, "failed to publish workspace agent stats", slog.F("workspace_id", workspace.ID), slog.Error(err)) diff --git a/coderd/wspubsub/wspubsub.go b/coderd/wspubsub/wspubsub.go new file mode 100644 index 0000000000000..0326efa695304 --- /dev/null +++ b/coderd/wspubsub/wspubsub.go @@ -0,0 +1,71 @@ +package wspubsub + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/uuid" + "golang.org/x/xerrors" +) + +// WorkspaceEventChannel can be used to subscribe to events for +// workspaces owned by the provided user ID. +func WorkspaceEventChannel(ownerID uuid.UUID) string { + return fmt.Sprintf("workspace_owner:%s", ownerID) +} + +func HandleWorkspaceEvent(cb func(ctx context.Context, payload WorkspaceEvent, err error)) func(ctx context.Context, message []byte, err error) { + return func(ctx context.Context, message []byte, err error) { + if err != nil { + cb(ctx, WorkspaceEvent{}, xerrors.Errorf("workspace event pubsub: %w", err)) + return + } + var payload WorkspaceEvent + if err := json.Unmarshal(message, &payload); err != nil { + cb(ctx, WorkspaceEvent{}, xerrors.Errorf("unmarshal workspace event")) + return + } + if err := payload.Validate(); err != nil { + cb(ctx, payload, xerrors.Errorf("validate workspace event")) + return + } + cb(ctx, payload, err) + } +} + +type WorkspaceEvent struct { + Kind WorkspaceEventKind `json:"kind"` + WorkspaceID uuid.UUID `json:"workspace_id" format:"uuid"` + // AgentID is only set for WorkspaceEventKindAgent* events + // (excluding AgentTimeout) + AgentID *uuid.UUID `json:"agent_id,omitempty" format:"uuid"` +} + +type WorkspaceEventKind string + +const ( + WorkspaceEventKindStateChange WorkspaceEventKind = "state_change" + WorkspaceEventKindStatsUpdate WorkspaceEventKind = "stats_update" + WorkspaceEventKindMetadataUpdate WorkspaceEventKind = "mtd_update" + WorkspaceEventKindAppHealthUpdate WorkspaceEventKind = "app_health" + + WorkspaceEventKindAgentLifecycleUpdate WorkspaceEventKind = "agt_lifecycle_update" + WorkspaceEventKindAgentConnectionUpdate WorkspaceEventKind = "agt_connection_update" + WorkspaceEventKindAgentFirstLogs WorkspaceEventKind = "agt_first_logs" + WorkspaceEventKindAgentLogsOverflow WorkspaceEventKind = "agt_logs_overflow" + WorkspaceEventKindAgentTimeout WorkspaceEventKind = "agt_timeout" +) + +func (w *WorkspaceEvent) Validate() error { + if w.WorkspaceID == uuid.Nil { + return xerrors.New("workspaceID must be set") + } + if w.Kind == "" { + return xerrors.New("kind must be set") + } + if w.Kind == WorkspaceEventKindAgentLifecycleUpdate && w.AgentID == nil { + return xerrors.New("agentID must be set for Agent events") + } + return nil +} diff --git a/codersdk/workspaces.go b/codersdk/workspaces.go index 5ce1769150e02..d6f3e30a92979 100644 --- a/codersdk/workspaces.go +++ b/codersdk/workspaces.go @@ -639,10 +639,3 @@ func (c *Client) WorkspaceTimings(ctx context.Context, id uuid.UUID) (WorkspaceB var timings WorkspaceBuildTimings return timings, json.NewDecoder(res.Body).Decode(&timings) } - -// WorkspaceNotifyChannel is the PostgreSQL NOTIFY -// channel to listen for updates on. The payload is empty, -// because the size of a workspace payload can be very large. -func WorkspaceNotifyChannel(id uuid.UUID) string { - return fmt.Sprintf("workspace:%s", id) -}