From 0d455cd0e18f1204655251c3e5b7f6306306b4d9 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Sun, 20 Feb 2022 19:50:57 +0000 Subject: [PATCH] fix: Use sync.WaitGroup to await hijacked HTTP connections WebSockets hijack the HTTP connection from the server, causing server.Close() to not wait for these connections to fully cleanup. This adds a global wait-group to the coderd API, which ensures all WebSocket HTTP handlers have properly exited before returning. --- coderd/cmd/root.go | 7 ++++--- coderd/coderd.go | 18 ++++++++++-------- coderd/coderdtest/coderdtest.go | 7 +++++-- coderd/provisionerdaemons.go | 4 ++++ 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/coderd/cmd/root.go b/coderd/cmd/root.go index 9087ac435b715..2778cb320a17e 100644 --- a/coderd/cmd/root.go +++ b/coderd/cmd/root.go @@ -33,7 +33,7 @@ func Root() *cobra.Command { Use: "coderd", RunE: func(cmd *cobra.Command, args []string) error { logger := slog.Make(sloghuman.Sink(os.Stderr)) - handler := coderd.New(&coderd.Options{ + handler, closeCoderd := coderd.New(&coderd.Options{ Logger: logger, Database: databasefake.New(), Pubsub: database.NewPubsubInMemory(), @@ -49,11 +49,11 @@ func Root() *cobra.Command { Scheme: "http", Host: address, }) - closer, err := newProvisionerDaemon(cmd.Context(), client, logger) + daemonClose, err := newProvisionerDaemon(cmd.Context(), client, logger) if err != nil { return xerrors.Errorf("create provisioner daemon: %w", err) } - defer closer.Close() + defer daemonClose.Close() errCh := make(chan error) go func() { @@ -61,6 +61,7 @@ func Root() *cobra.Command { errCh <- http.Serve(listener, handler) }() + closeCoderd() select { case <-cmd.Context().Done(): return cmd.Context().Err() diff --git a/coderd/coderd.go b/coderd/coderd.go index 1213b04aa0a86..765e4b0f1951c 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -2,6 +2,7 @@ package coderd import ( "net/http" + "sync" "github.com/go-chi/chi/v5" @@ -20,11 +21,12 @@ type Options struct { } // New constructs the Coder API into an HTTP handler. -func New(options *Options) http.Handler { +// +// A wait function is returned to handle awaiting closure +// of hijacked HTTP requests. +func New(options *Options) (http.Handler, func()) { api := &api{ - Database: options.Database, - Logger: options.Logger, - Pubsub: options.Pubsub, + Options: options, } r := chi.NewRouter() @@ -144,13 +146,13 @@ func New(options *Options) http.Handler { }) }) r.NotFound(site.Handler(options.Logger).ServeHTTP) - return r + return r, api.websocketWaitGroup.Wait } // API contains all route handlers. Only HTTP handlers should // be added to this struct for code clarity. type api struct { - Database database.Store - Logger slog.Logger - Pubsub database.Pubsub + *Options + + websocketWaitGroup sync.WaitGroup } diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 40eba2fb53942..ff05abb63a78a 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -55,7 +55,7 @@ func New(t *testing.T) *codersdk.Client { }) } - handler := coderd.New(&coderd.Options{ + handler, closeWait := coderd.New(&coderd.Options{ Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), Database: db, Pubsub: pubsub, @@ -69,7 +69,10 @@ func New(t *testing.T) *codersdk.Client { srv.Start() serverURL, err := url.Parse(srv.URL) require.NoError(t, err) - t.Cleanup(srv.Close) + t.Cleanup(func() { + srv.Close() + closeWait() + }) return codersdk.New(serverURL) } diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index a75730dc0a876..d95e6e62c0f9d 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -62,6 +62,8 @@ func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request) }) return } + api.websocketWaitGroup.Add(1) + defer api.websocketWaitGroup.Done() daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{ ID: uuid.New(), @@ -100,7 +102,9 @@ func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request) err = server.Serve(r.Context(), session) if err != nil { _ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("serve: %s", err)) + return } + _ = conn.Close(websocket.StatusGoingAway, "") } // The input for a "workspace_provision" job.