@@ -17,6 +17,7 @@ import (
17
17
"nhooyr.io/websocket"
18
18
19
19
"cdr.dev/slog"
20
+
20
21
"github.com/coder/coder/agent"
21
22
"github.com/coder/coder/coderd/database"
22
23
"github.com/coder/coder/coderd/httpapi"
@@ -77,17 +78,18 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
77
78
})
78
79
return
79
80
}
80
- defer func () {
81
- _ = conn .Close (websocket .StatusNormalClosure , "" )
82
- }()
81
+
82
+ ctx , wsNetConn := websocketNetConn (r .Context (), conn , websocket .MessageBinary )
83
+ defer wsNetConn .Close () // Also closes conn.
84
+
83
85
config := yamux .DefaultConfig ()
84
86
config .LogOutput = io .Discard
85
- session , err := yamux .Server (websocket . NetConn ( r . Context (), conn , websocket . MessageBinary ) , config )
87
+ session , err := yamux .Server (wsNetConn , config )
86
88
if err != nil {
87
89
_ = conn .Close (websocket .StatusAbnormalClosure , err .Error ())
88
90
return
89
91
}
90
- err = peerbroker .ProxyListen (r . Context () , session , peerbroker.ProxyOptions {
92
+ err = peerbroker .ProxyListen (ctx , session , peerbroker.ProxyOptions {
91
93
ChannelID : workspaceAgent .ID .String (),
92
94
Logger : api .Logger .Named ("peerbroker-proxy-dial" ),
93
95
Pubsub : api .Pubsub ,
@@ -201,13 +203,12 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
201
203
return
202
204
}
203
205
204
- defer func () {
205
- _ = conn .Close (websocket .StatusNormalClosure , "" )
206
- }()
206
+ ctx , wsNetConn := websocketNetConn (r .Context (), conn , websocket .MessageBinary )
207
+ defer wsNetConn .Close () // Also closes conn.
207
208
208
209
config := yamux .DefaultConfig ()
209
210
config .LogOutput = io .Discard
210
- session , err := yamux .Server (websocket . NetConn ( r . Context (), conn , websocket . MessageBinary ) , config )
211
+ session , err := yamux .Server (wsNetConn , config )
211
212
if err != nil {
212
213
_ = conn .Close (websocket .StatusAbnormalClosure , err .Error ())
213
214
return
@@ -237,7 +238,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
237
238
}
238
239
disconnectedAt := workspaceAgent .DisconnectedAt
239
240
updateConnectionTimes := func () error {
240
- err = api .Database .UpdateWorkspaceAgentConnectionByID (r . Context () , database.UpdateWorkspaceAgentConnectionByIDParams {
241
+ err = api .Database .UpdateWorkspaceAgentConnectionByID (ctx , database.UpdateWorkspaceAgentConnectionByIDParams {
241
242
ID : workspaceAgent .ID ,
242
243
FirstConnectedAt : firstConnectedAt ,
243
244
LastConnectedAt : lastConnectedAt ,
@@ -263,7 +264,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
263
264
return
264
265
}
265
266
266
- api .Logger .Info (r . Context () , "accepting agent" , slog .F ("resource" , resource ), slog .F ("agent" , workspaceAgent ))
267
+ api .Logger .Info (ctx , "accepting agent" , slog .F ("resource" , resource ), slog .F ("agent" , workspaceAgent ))
267
268
268
269
ticker := time .NewTicker (api .AgentConnectionUpdateFrequency )
269
270
defer ticker .Stop ()
@@ -332,16 +333,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
332
333
})
333
334
return
334
335
}
335
- defer func () {
336
- _ = wsConn . Close ( websocket . StatusNormalClosure , "" )
337
- }()
338
- netConn := websocket . NetConn ( r . Context (), wsConn , websocket . MessageBinary )
339
- api .Logger .Debug (r . Context () , "accepting turn connection" , slog .F ("remote-address" , r .RemoteAddr ), slog .F ("local-address" , localAddress ))
336
+
337
+ ctx , wsNetConn := websocketNetConn ( r . Context (), wsConn , websocket . MessageBinary )
338
+ defer wsNetConn . Close () // Also closes conn.
339
+
340
+ api .Logger .Debug (ctx , "accepting turn connection" , slog .F ("remote-address" , r .RemoteAddr ), slog .F ("local-address" , localAddress ))
340
341
select {
341
- case <- api .TURNServer .Accept (netConn , remoteAddress , localAddress ).Closed ():
342
- case <- r . Context () .Done ():
342
+ case <- api .TURNServer .Accept (wsNetConn , remoteAddress , localAddress ).Closed ():
343
+ case <- ctx .Done ():
343
344
}
344
- api .Logger .Debug (r . Context () , "completed turn connection" , slog .F ("remote-address" , r .RemoteAddr ), slog .F ("local-address" , localAddress ))
345
+ api .Logger .Debug (ctx , "completed turn connection" , slog .F ("remote-address" , r .RemoteAddr ), slog .F ("local-address" , localAddress ))
345
346
}
346
347
347
348
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
@@ -392,11 +393,10 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
392
393
})
393
394
return
394
395
}
395
- defer func () {
396
- _ = conn .Close (websocket .StatusNormalClosure , "ended" )
397
- }()
398
- // Accept text connections, because it's more developer friendly.
399
- wsNetConn := websocket .NetConn (r .Context (), conn , websocket .MessageBinary )
396
+
397
+ _ , wsNetConn := websocketNetConn (r .Context (), conn , websocket .MessageBinary )
398
+ defer wsNetConn .Close () // Also closes conn.
399
+
400
400
agentConn , release , err := api .workspaceAgentCache .Acquire (r , workspaceAgent .ID )
401
401
if err != nil {
402
402
_ = conn .Close (websocket .StatusInternalError , httpapi .WebsocketCloseSprintf ("dial workspace agent: %s" , err ))
@@ -416,8 +416,10 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
416
416
_ , _ = io .Copy (ptNetConn , wsNetConn )
417
417
}
418
418
419
- // dialWorkspaceAgent connects to a workspace agent by ID.
420
- func (api * API ) dialWorkspaceAgent (r * http.Request , agentID uuid.UUID ) (* agent.Conn , error ) {
419
+ // dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
420
+ // r.Context() for cancellation if it's use is safe or r.Hijack() has
421
+ // not been performed.
422
+ func (api * API ) dialWorkspaceAgent (ctx context.Context , r * http.Request , agentID uuid.UUID ) (* agent.Conn , error ) {
421
423
client , server := provisionersdk .TransportPipe ()
422
424
ctx , cancelFunc := context .WithCancel (context .Background ())
423
425
go func () {
@@ -446,7 +448,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
446
448
options .SettingEngine .SetICEProxyDialer (turnconn .ProxyDialer (func () (c net.Conn , err error ) {
447
449
clientPipe , serverPipe := net .Pipe ()
448
450
go func () {
449
- <- r . Context () .Done ()
451
+ <- ctx .Done ()
450
452
_ = clientPipe .Close ()
451
453
_ = serverPipe .Close ()
452
454
}()
@@ -546,3 +548,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
546
548
547
549
return workspaceAgent , nil
548
550
}
551
+
552
+ // wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
553
+ // is called if a read or write error is encountered.
554
+ type wsNetConn struct {
555
+ cancel context.CancelFunc
556
+ net.Conn
557
+ }
558
+
559
+ func (c * wsNetConn ) Read (b []byte ) (n int , err error ) {
560
+ n , err = c .Conn .Read (b )
561
+ if err != nil {
562
+ c .cancel ()
563
+ }
564
+ return n , err
565
+ }
566
+
567
+ func (c * wsNetConn ) Write (b []byte ) (n int , err error ) {
568
+ n , err = c .Conn .Write (b )
569
+ if err != nil {
570
+ c .cancel ()
571
+ }
572
+ return n , err
573
+ }
574
+
575
+ func (c * wsNetConn ) Close () error {
576
+ defer c .cancel ()
577
+ return c .Conn .Close ()
578
+ }
579
+
580
+ // websocketNetConn wraps websocket.NetConn and returns a context that
581
+ // is tied to the parent context and the lifetime of the conn. Any error
582
+ // during read or write will cancel the context, but not close the
583
+ // conn. Close should be called to release context resources.
584
+ func websocketNetConn (ctx context.Context , conn * websocket.Conn , msgType websocket.MessageType ) (context.Context , net.Conn ) {
585
+ ctx , cancel := context .WithCancel (ctx )
586
+ nc := websocket .NetConn (ctx , conn , msgType )
587
+ return ctx , & wsNetConn {
588
+ cancel : cancel ,
589
+ Conn : nc ,
590
+ }
591
+ }
0 commit comments