Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cd2d12e

Browse files
committed
Merge branch 'main' into apps
2 parents 5b9194f + b4f5920 commit cd2d12e

File tree

1 file changed

+70
-27
lines changed

1 file changed

+70
-27
lines changed

coderd/workspaceagents.go

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"nhooyr.io/websocket"
1818

1919
"cdr.dev/slog"
20+
2021
"github.com/coder/coder/agent"
2122
"github.com/coder/coder/coderd/database"
2223
"github.com/coder/coder/coderd/httpapi"
@@ -77,17 +78,18 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
7778
})
7879
return
7980
}
80-
defer func() {
81-
_ = conn.Close(websocket.StatusNormalClosure, "")
82-
}()
81+
82+
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
83+
defer wsNetConn.Close() // Also closes conn.
84+
8385
config := yamux.DefaultConfig()
8486
config.LogOutput = io.Discard
85-
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
87+
session, err := yamux.Server(wsNetConn, config)
8688
if err != nil {
8789
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
8890
return
8991
}
90-
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
92+
err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{
9193
ChannelID: workspaceAgent.ID.String(),
9294
Logger: api.Logger.Named("peerbroker-proxy-dial"),
9395
Pubsub: api.Pubsub,
@@ -201,13 +203,12 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
201203
return
202204
}
203205

204-
defer func() {
205-
_ = conn.Close(websocket.StatusNormalClosure, "")
206-
}()
206+
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
207+
defer wsNetConn.Close() // Also closes conn.
207208

208209
config := yamux.DefaultConfig()
209210
config.LogOutput = io.Discard
210-
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
211+
session, err := yamux.Server(wsNetConn, config)
211212
if err != nil {
212213
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
213214
return
@@ -237,7 +238,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
237238
}
238239
disconnectedAt := workspaceAgent.DisconnectedAt
239240
updateConnectionTimes := func() error {
240-
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
241+
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
241242
ID: workspaceAgent.ID,
242243
FirstConnectedAt: firstConnectedAt,
243244
LastConnectedAt: lastConnectedAt,
@@ -263,7 +264,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
263264
return
264265
}
265266

266-
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
267+
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
267268

268269
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
269270
defer ticker.Stop()
@@ -332,16 +333,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
332333
})
333334
return
334335
}
335-
defer func() {
336-
_ = wsConn.Close(websocket.StatusNormalClosure, "")
337-
}()
338-
netConn := websocket.NetConn(r.Context(), wsConn, websocket.MessageBinary)
339-
api.Logger.Debug(r.Context(), "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
336+
337+
ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary)
338+
defer wsNetConn.Close() // Also closes conn.
339+
340+
api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
340341
select {
341-
case <-api.TURNServer.Accept(netConn, remoteAddress, localAddress).Closed():
342-
case <-r.Context().Done():
342+
case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed():
343+
case <-ctx.Done():
343344
}
344-
api.Logger.Debug(r.Context(), "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
345+
api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
345346
}
346347

347348
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
@@ -392,11 +393,10 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
392393
})
393394
return
394395
}
395-
defer func() {
396-
_ = conn.Close(websocket.StatusNormalClosure, "ended")
397-
}()
398-
// Accept text connections, because it's more developer friendly.
399-
wsNetConn := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
396+
397+
_, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
398+
defer wsNetConn.Close() // Also closes conn.
399+
400400
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID)
401401
if err != nil {
402402
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
@@ -416,8 +416,10 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
416416
_, _ = io.Copy(ptNetConn, wsNetConn)
417417
}
418418

419-
// dialWorkspaceAgent connects to a workspace agent by ID.
420-
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
419+
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
420+
// r.Context() for cancellation if it's use is safe or r.Hijack() has
421+
// not been performed.
422+
func (api *API) dialWorkspaceAgent(ctx context.Context, r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
421423
client, server := provisionersdk.TransportPipe()
422424
ctx, cancelFunc := context.WithCancel(context.Background())
423425
go func() {
@@ -446,7 +448,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
446448
options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) {
447449
clientPipe, serverPipe := net.Pipe()
448450
go func() {
449-
<-r.Context().Done()
451+
<-ctx.Done()
450452
_ = clientPipe.Close()
451453
_ = serverPipe.Close()
452454
}()
@@ -546,3 +548,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
546548

547549
return workspaceAgent, nil
548550
}
551+
552+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
553+
// is called if a read or write error is encountered.
554+
type wsNetConn struct {
555+
cancel context.CancelFunc
556+
net.Conn
557+
}
558+
559+
func (c *wsNetConn) Read(b []byte) (n int, err error) {
560+
n, err = c.Conn.Read(b)
561+
if err != nil {
562+
c.cancel()
563+
}
564+
return n, err
565+
}
566+
567+
func (c *wsNetConn) Write(b []byte) (n int, err error) {
568+
n, err = c.Conn.Write(b)
569+
if err != nil {
570+
c.cancel()
571+
}
572+
return n, err
573+
}
574+
575+
func (c *wsNetConn) Close() error {
576+
defer c.cancel()
577+
return c.Conn.Close()
578+
}
579+
580+
// websocketNetConn wraps websocket.NetConn and returns a context that
581+
// is tied to the parent context and the lifetime of the conn. Any error
582+
// during read or write will cancel the context, but not close the
583+
// conn. Close should be called to release context resources.
584+
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
585+
ctx, cancel := context.WithCancel(ctx)
586+
nc := websocket.NetConn(ctx, conn, msgType)
587+
return ctx, &wsNetConn{
588+
cancel: cancel,
589+
Conn: nc,
590+
}
591+
}

0 commit comments

Comments
 (0)