From e1a4f0b701c039036632575535579f8288502e48 Mon Sep 17 00:00:00 2001 From: Garrett Delfosse Date: Wed, 23 Oct 2024 17:43:51 +0000 Subject: [PATCH 1/2] fix: close server pty connections on client disconnect --- coderd/httpapi/websocket.go | 23 ++++++++++++++++++++--- coderd/workspaceapps/proxy.go | 7 +++---- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/coderd/httpapi/websocket.go b/coderd/httpapi/websocket.go index 629dcac8131f3..b22193bfadf65 100644 --- a/coderd/httpapi/websocket.go +++ b/coderd/httpapi/websocket.go @@ -2,8 +2,10 @@ package httpapi import ( "context" + "errors" "time" + "golang.org/x/xerrors" "nhooyr.io/websocket" "cdr.dev/slog" @@ -31,7 +33,8 @@ func Heartbeat(ctx context.Context, conn *websocket.Conn) { // Heartbeat loops to ping a WebSocket to keep it alive. It calls `exit` on ping // failure. func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn *websocket.Conn) { - ticker := time.NewTicker(15 * time.Second) + inverval := 15 * time.Second + ticker := time.NewTicker(inverval) defer ticker.Stop() for { @@ -40,12 +43,26 @@ func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn * return case <-ticker.C: } - err := conn.Ping(ctx) + err := pingWithTimeout(ctx, conn, inverval) if err != nil { + // context.DeadlineExceeded is expected when the client disconnects without sending a close frame + if !errors.Is(err, context.DeadlineExceeded) { + logger.Error(ctx, "failed to heartbeat ping", slog.Error(err)) + } _ = conn.Close(websocket.StatusGoingAway, "Ping failed") - logger.Info(ctx, "failed to heartbeat ping", slog.Error(err)) exit() return } } } + +func pingWithTimeout(ctx context.Context, conn *websocket.Conn, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + err := conn.Ping(ctx) + if err != nil { + return xerrors.Errorf("failed to ping: %w", err) + } + + return nil +} diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 69f1aadca49b2..c6cd01395db5c 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -593,7 +593,6 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx)) report := newStatsReportFromSignedToken(appToken) - s.collectStats(report) defer func() { // We must use defer here because ServeHTTP may panic. report.SessionEndedAt = dbtime.Now() @@ -614,7 +613,8 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT // @Success 101 // @Router /workspaceagents/{workspaceagent}/pty [get] func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { - ctx := r.Context() + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() s.websocketWaitMutex.Lock() s.websocketWaitGroup.Add(1) @@ -670,12 +670,11 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { }) return } + go httpapi.HeartbeatClose(ctx, s.Logger, cancel, conn) ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageBinary) defer wsNetConn.Close() // Also closes conn. - go httpapi.Heartbeat(ctx, conn) - agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID) if err != nil { log.Debug(ctx, "dial workspace agent", slog.Error(err)) From 3e2717dab82230796806c42b7d87f83b34612b09 Mon Sep 17 00:00:00 2001 From: Garrett Delfosse Date: Wed, 23 Oct 2024 18:00:06 +0000 Subject: [PATCH 2/2] fix typo --- coderd/httpapi/websocket.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/coderd/httpapi/websocket.go b/coderd/httpapi/websocket.go index b22193bfadf65..2d6f131fd5aa3 100644 --- a/coderd/httpapi/websocket.go +++ b/coderd/httpapi/websocket.go @@ -33,8 +33,8 @@ func Heartbeat(ctx context.Context, conn *websocket.Conn) { // Heartbeat loops to ping a WebSocket to keep it alive. It calls `exit` on ping // failure. func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn *websocket.Conn) { - inverval := 15 * time.Second - ticker := time.NewTicker(inverval) + interval := 15 * time.Second + ticker := time.NewTicker(interval) defer ticker.Stop() for { @@ -43,7 +43,7 @@ func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn * return case <-ticker.C: } - err := pingWithTimeout(ctx, conn, inverval) + err := pingWithTimeout(ctx, conn, interval) if err != nil { // context.DeadlineExceeded is expected when the client disconnects without sending a close frame if !errors.Is(err, context.DeadlineExceeded) {