From 0a384f259e30e46600b23e40f860bacc9f2caf03 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 22 Jan 2024 15:12:18 +0400 Subject: [PATCH 1/2] feat: changes codersdk to use tailnet v2 for DERPMap updates --- codersdk/workspaceagents.go | 324 +++++++++++++++++++++--------------- 1 file changed, 188 insertions(+), 136 deletions(-) diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 63b8de3c044b2..b50fd7e71af07 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -14,6 +14,8 @@ import ( "strings" "time" + "golang.org/x/sync/errgroup" + "github.com/google/uuid" "golang.org/x/xerrors" "nhooyr.io/websocket" @@ -317,142 +319,28 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, q := coordinateURL.Query() q.Add("version", proto.CurrentVersion.String()) coordinateURL.RawQuery = q.Encode() - closedCoordinator := make(chan struct{}) - // Must only ever be used once, send error OR close to avoid - // reassignment race. Buffered so we don't hang in goroutine. - firstCoordinator := make(chan error, 1) - go func() { - defer close(closedCoordinator) - isFirst := true - for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { - options.Logger.Debug(ctx, "connecting") - // nolint:bodyclose - ws, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ - HTTPClient: c.HTTPClient, - HTTPHeader: headers, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) - if isFirst { - if res != nil && res.StatusCode == http.StatusConflict { - firstCoordinator <- ReadBodyAsError(res) - return - } - isFirst = false - close(firstCoordinator) - } - if err != nil { - if errors.Is(err, context.Canceled) { - return - } - options.Logger.Debug(ctx, "failed to dial", slog.Error(err)) - continue - } - client, err := tailnet.NewDRPCClient(websocket.NetConn(ctx, ws, websocket.MessageBinary)) - if err != nil { - options.Logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err)) - _ = ws.Close(websocket.StatusInternalError, "") - continue - } - coordinate, err := client.Coordinate(ctx) - if err != nil { - options.Logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err)) - _ = ws.Close(websocket.StatusInternalError, "") - continue - } - - coordination := tailnet.NewRemoteCoordination(options.Logger, coordinate, conn, agentID) - options.Logger.Debug(ctx, "serving coordinator") - err = <-coordination.Error() - if errors.Is(err, context.Canceled) { - _ = ws.Close(websocket.StatusGoingAway, "") - return - } - if err != nil { - options.Logger.Debug(ctx, "error serving coordinator", slog.Error(err)) - _ = ws.Close(websocket.StatusGoingAway, "") - continue - } - _ = ws.Close(websocket.StatusGoingAway, "") - } - }() - derpMapURL, err := c.URL.Parse("/api/v2/derp-map") - if err != nil { - return nil, xerrors.Errorf("parse url: %w", err) - } - closedDerpMap := make(chan struct{}) - // Must only ever be used once, send error OR close to avoid - // reassignment race. Buffered so we don't hang in goroutine. - firstDerpMap := make(chan error, 1) - go func() { - defer close(closedDerpMap) - isFirst := true - for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { - options.Logger.Debug(ctx, "connecting to server for derp map updates") - // nolint:bodyclose - ws, res, err := websocket.Dial(ctx, derpMapURL.String(), &websocket.DialOptions{ - HTTPClient: c.HTTPClient, - HTTPHeader: headers, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) - if isFirst { - if res != nil && res.StatusCode == http.StatusConflict { - firstDerpMap <- ReadBodyAsError(res) - return - } - isFirst = false - close(firstDerpMap) - } - if err != nil { - if errors.Is(err, context.Canceled) { - return - } - options.Logger.Debug(ctx, "failed to dial", slog.Error(err)) - continue - } - - var ( - nconn = websocket.NetConn(ctx, ws, websocket.MessageBinary) - dec = json.NewDecoder(nconn) - ) - for { - var derpMap tailcfg.DERPMap - err := dec.Decode(&derpMap) - if xerrors.Is(err, context.Canceled) { - _ = ws.Close(websocket.StatusGoingAway, "") - return - } - if err != nil { - options.Logger.Debug(ctx, "failed to decode derp map", slog.Error(err)) - _ = ws.Close(websocket.StatusGoingAway, "") - return - } - - if !tailnet.CompareDERPMaps(conn.DERPMap(), &derpMap) { - options.Logger.Debug(ctx, "updating derp map due to detected changes") - conn.SetDERPMap(&derpMap) - } - } - } - }() - - for firstCoordinator != nil || firstDerpMap != nil { - select { - case <-dialCtx.Done(): - return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err()) - case err = <-firstCoordinator: - if err != nil { - return nil, xerrors.Errorf("start coordinator: %w", err) - } - firstCoordinator = nil - case err = <-firstDerpMap: - if err != nil { - return nil, xerrors.Errorf("receive derp map: %w", err) - } - firstDerpMap = nil + connector := runTailnetAPIConnector(ctx, options.Logger, + agentID, coordinateURL.String(), + &websocket.DialOptions{ + HTTPClient: c.HTTPClient, + HTTPHeader: headers, + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, + }, + conn, + ) + options.Logger.Debug(ctx, "running tailnet API v2+ connector") + + select { + case <-dialCtx.Done(): + return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err()) + case err = <-connector.connected: + if err != nil { + options.Logger.Error(ctx, "failed to connect to tailnet v2+ API", slog.Error(err)) + return nil, xerrors.Errorf("start connector: %w", err) } + options.Logger.Debug(ctx, "connected to tailnet v2+ API") } agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{ @@ -464,8 +352,7 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, AgentIP: WorkspaceAgentIP, CloseFunc: func() error { cancel() - <-closedCoordinator - <-closedDerpMap + <-connector.closed return conn.Close() }, }) @@ -478,6 +365,171 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, return agentConn, nil } +// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to +// +// 1) run the Coordinate API and pass node information back and forth +// 2) stream DERPMap updates and program the Conn +// +// These functions share the same websocket, and so are combined here so that if we hit a problem +// we tear the whole thing down and start over with a new websocket. +// +// @typescript-ignore tailnetAPIConnector +type tailnetAPIConnector struct { + ctx context.Context + logger slog.Logger + + agentID uuid.UUID + coordinateURL string + dialOptions *websocket.DialOptions + conn *tailnet.Conn + + connected chan error + isFirst bool + closed chan struct{} +} + +// runTailnetAPIConnector creates and runs a tailnetAPIConnector +func runTailnetAPIConnector( + ctx context.Context, logger slog.Logger, + agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions, + conn *tailnet.Conn, +) *tailnetAPIConnector { + tac := &tailnetAPIConnector{ + ctx: ctx, + logger: logger, + agentID: agentID, + coordinateURL: coordinateURL, + dialOptions: dialOptions, + conn: conn, + connected: make(chan error, 1), + closed: make(chan struct{}), + } + go tac.run() + return tac +} + +func (tac *tailnetAPIConnector) run() { + tac.isFirst = true + defer close(tac.closed) + for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); { + tailnetClient, err := tac.dial() + if err != nil { + continue + } + tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client") + tac.coordinateAndDERPMap(tailnetClient) + tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost") + } +} + +func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) { + tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API") + // nolint:bodyclose + ws, res, err := websocket.Dial(tac.ctx, tac.coordinateURL, tac.dialOptions) + if tac.isFirst { + if res != nil && res.StatusCode == http.StatusConflict { + err = ReadBodyAsError(res) + tac.connected <- err + return nil, err + } + tac.isFirst = false + close(tac.connected) + } + if err != nil { + if !errors.Is(err, context.Canceled) { + tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err)) + } + return nil, err + } + client, err := tailnet.NewDRPCClient(websocket.NetConn(tac.ctx, ws, websocket.MessageBinary)) + if err != nil { + tac.logger.Debug(tac.ctx, "failed to create DRPCClient", slog.Error(err)) + _ = ws.Close(websocket.StatusInternalError, "") + return nil, err + } + return client, err +} + +// coordinateAndDERPMap uses the provided client to coordinate and stream DERP Maps. It is combined +// into one function so that a problem with one tears down the other and triggers a retry (if +// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same +// fate. +func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetClient) { + defer func() { + conn := client.DRPCConn() + closeErr := conn.Close() + if closeErr != nil && + !xerrors.Is(closeErr, io.EOF) && + !xerrors.Is(closeErr, context.Canceled) && + !xerrors.Is(closeErr, context.DeadlineExceeded) { + tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr)) + <-conn.Closed() + } + }() + eg, egCtx := errgroup.WithContext(tac.ctx) + eg.Go(func() error { + return tac.coordinate(egCtx, client) + }) + eg.Go(func() error { + return tac.derpMap(egCtx, client) + }) + err := eg.Wait() + if err != nil && + !xerrors.Is(err, io.EOF) && + !xerrors.Is(err, context.Canceled) && + !xerrors.Is(err, context.DeadlineExceeded) { + tac.logger.Error(tac.ctx, "error while connected to tailnet v2+ API") + } +} + +func (tac *tailnetAPIConnector) coordinate(ctx context.Context, client proto.DRPCTailnetClient) error { + coord, err := client.Coordinate(ctx) + if err != nil { + return xerrors.Errorf("failed to connect to Coordinate RPC: %w", err) + } + defer func() { + cErr := coord.Close() + if cErr != nil { + tac.logger.Debug(ctx, "error closing Coordinate RPC", slog.Error(cErr)) + } + }() + coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID) + tac.logger.Debug(ctx, "serving coordinator") + err = <-coordination.Error() + if err != nil && + !xerrors.Is(err, io.EOF) && + !xerrors.Is(err, context.Canceled) && + !xerrors.Is(err, context.DeadlineExceeded) { + return xerrors.Errorf("remote coordination error: %w", err) + } + return nil +} + +func (tac *tailnetAPIConnector) derpMap(ctx context.Context, client proto.DRPCTailnetClient) error { + s, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{}) + if err != nil { + return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err) + } + defer func() { + cErr := s.Close() + if cErr != nil { + tac.logger.Debug(ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr)) + } + }() + for { + dmp, err := s.Recv() + if err != nil { + if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { + return nil + } + return xerrors.Errorf("error receiving DERP Map: %w", err) + } + tac.logger.Debug(ctx, "got new DERP Map", slog.F("derp_map", dmp)) + dm := tailnet.DERPMapFromProto(dmp) + tac.conn.SetDERPMap(dm) + } +} + // WatchWorkspaceAgentMetadata watches the metadata of a workspace agent. // The returned channel will be closed when the context is canceled. Exactly // one error will be sent on the error channel. The metadata channel is never closed. From 38f23eaae5d7c171b5cd903a67cff875dadd81f7 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 23 Jan 2024 13:00:41 +0400 Subject: [PATCH 2/2] fix: fix TestServiceBanners/Agent --- enterprise/coderd/appearance_test.go | 46 +++++++++++++--------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/enterprise/coderd/appearance_test.go b/enterprise/coderd/appearance_test.go index ca82e40168505..493bfd7cc5aad 100644 --- a/enterprise/coderd/appearance_test.go +++ b/enterprise/coderd/appearance_test.go @@ -6,7 +6,11 @@ import ( "net/http" "testing" - "github.com/google/uuid" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,7 +21,6 @@ import ( "github.com/coder/coder/v2/enterprise/coderd" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/license" - "github.com/coder/coder/v2/provisioner/echo" "github.com/coder/coder/v2/testutil" ) @@ -125,13 +128,15 @@ func TestServiceBanners(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() + store, ps := dbtestutil.NewDB(t) client, user := coderdenttest.New(t, &coderdenttest.Options{ Options: &coderdtest.Options{ - IncludeProvisionerDaemon: true, + Database: store, + Pubsub: ps, }, DontAddLicense: true, }) - license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + lic := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ Features: license.Features{ codersdk.FeatureAppearance: 1, }, @@ -146,35 +151,28 @@ func TestServiceBanners(t *testing.T) { err := client.UpdateAppearance(ctx, cfg) require.NoError(t, err) - authToken := uuid.NewString() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionPlan: echo.PlanComplete, - ProvisionApply: echo.ProvisionApplyWithAgent(authToken), - }) - template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) - coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + r := dbfake.WorkspaceBuild(t, store, database.Workspace{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() + agentClient := agentsdk.New(client.URL) + agentClient.SetSessionToken(r.AgentToken) banner, err := agentClient.GetServiceBanner(ctx) require.NoError(t, err) require.Equal(t, cfg.ServiceBanner, banner) - // No enterprise means a 404 on the endpoint meaning no banner. - client = coderdtest.New(t, &coderdtest.Options{ - IncludeProvisionerDaemon: true, - }) - agentClient = agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) - banner, err = agentClient.GetServiceBanner(ctx) + // Create an AGPL Coderd against the same database + agplClient := coderdtest.New(t, &coderdtest.Options{Database: store, Pubsub: ps}) + agplAgentClient := agentsdk.New(agplClient.URL) + agplAgentClient.SetSessionToken(r.AgentToken) + banner, err = agplAgentClient.GetServiceBanner(ctx) require.NoError(t, err) require.Equal(t, codersdk.ServiceBannerConfig{}, banner) // No license means no banner. - client.DeleteLicense(ctx, license.ID) + err = client.DeleteLicense(ctx, lic.ID) + require.NoError(t, err) banner, err = agentClient.GetServiceBanner(ctx) require.NoError(t, err) require.Equal(t, codersdk.ServiceBannerConfig{}, banner)