Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0d455cd

Browse files
committed
fix: Use sync.WaitGroup to await hijacked HTTP connections
WebSockets hijack the HTTP connection from the server, causing server.Close() to not wait for these connections to fully cleanup. This adds a global wait-group to the coderd API, which ensures all WebSocket HTTP handlers have properly exited before returning.
1 parent 8f843d2 commit 0d455cd

File tree

4 files changed

+23
-13
lines changed

4 files changed

+23
-13
lines changed

coderd/cmd/root.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func Root() *cobra.Command {
3333
Use: "coderd",
3434
RunE: func(cmd *cobra.Command, args []string) error {
3535
logger := slog.Make(sloghuman.Sink(os.Stderr))
36-
handler := coderd.New(&coderd.Options{
36+
handler, closeCoderd := coderd.New(&coderd.Options{
3737
Logger: logger,
3838
Database: databasefake.New(),
3939
Pubsub: database.NewPubsubInMemory(),
@@ -49,18 +49,19 @@ func Root() *cobra.Command {
4949
Scheme: "http",
5050
Host: address,
5151
})
52-
closer, err := newProvisionerDaemon(cmd.Context(), client, logger)
52+
daemonClose, err := newProvisionerDaemon(cmd.Context(), client, logger)
5353
if err != nil {
5454
return xerrors.Errorf("create provisioner daemon: %w", err)
5555
}
56-
defer closer.Close()
56+
defer daemonClose.Close()
5757

5858
errCh := make(chan error)
5959
go func() {
6060
defer close(errCh)
6161
errCh <- http.Serve(listener, handler)
6262
}()
6363

64+
closeCoderd()
6465
select {
6566
case <-cmd.Context().Done():
6667
return cmd.Context().Err()

coderd/coderd.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package coderd
22

33
import (
44
"net/http"
5+
"sync"
56

67
"github.com/go-chi/chi/v5"
78

@@ -20,11 +21,12 @@ type Options struct {
2021
}
2122

2223
// New constructs the Coder API into an HTTP handler.
23-
func New(options *Options) http.Handler {
24+
//
25+
// A wait function is returned to handle awaiting closure
26+
// of hijacked HTTP requests.
27+
func New(options *Options) (http.Handler, func()) {
2428
api := &api{
25-
Database: options.Database,
26-
Logger: options.Logger,
27-
Pubsub: options.Pubsub,
29+
Options: options,
2830
}
2931

3032
r := chi.NewRouter()
@@ -144,13 +146,13 @@ func New(options *Options) http.Handler {
144146
})
145147
})
146148
r.NotFound(site.Handler(options.Logger).ServeHTTP)
147-
return r
149+
return r, api.websocketWaitGroup.Wait
148150
}
149151

150152
// API contains all route handlers. Only HTTP handlers should
151153
// be added to this struct for code clarity.
152154
type api struct {
153-
Database database.Store
154-
Logger slog.Logger
155-
Pubsub database.Pubsub
155+
*Options
156+
157+
websocketWaitGroup sync.WaitGroup
156158
}

coderd/coderdtest/coderdtest.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func New(t *testing.T) *codersdk.Client {
5555
})
5656
}
5757

58-
handler := coderd.New(&coderd.Options{
58+
handler, closeWait := coderd.New(&coderd.Options{
5959
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
6060
Database: db,
6161
Pubsub: pubsub,
@@ -69,7 +69,10 @@ func New(t *testing.T) *codersdk.Client {
6969
srv.Start()
7070
serverURL, err := url.Parse(srv.URL)
7171
require.NoError(t, err)
72-
t.Cleanup(srv.Close)
72+
t.Cleanup(func() {
73+
srv.Close()
74+
closeWait()
75+
})
7376

7477
return codersdk.New(serverURL)
7578
}

coderd/provisionerdaemons.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request)
6262
})
6363
return
6464
}
65+
api.websocketWaitGroup.Add(1)
66+
defer api.websocketWaitGroup.Done()
6567

6668
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
6769
ID: uuid.New(),
@@ -100,7 +102,9 @@ func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request)
100102
err = server.Serve(r.Context(), session)
101103
if err != nil {
102104
_ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("serve: %s", err))
105+
return
103106
}
107+
_ = conn.Close(websocket.StatusGoingAway, "")
104108
}
105109

106110
// The input for a "workspace_provision" job.

0 commit comments

Comments
 (0)