Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 99b8467

Browse files
committed
chore: consolidate websocketNetConn implementations
1 parent ec8e41f commit 99b8467

10 files changed

+145
-184
lines changed

coderd/workspaceagents.go

+4-46
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"errors"
88
"fmt"
99
"io"
10-
"net"
1110
"net/http"
1211
"net/url"
1312
"sort"
@@ -544,7 +543,7 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
544543
}
545544
go httpapi.Heartbeat(ctx, conn)
546545

547-
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
546+
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
548547
defer wsNetConn.Close() // Also closes conn.
549548

550549
// The Go stdlib JSON encoder appends a newline character after message write.
@@ -881,7 +880,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
881880
})
882881
return
883882
}
884-
ctx, nconn := websocketNetConn(ctx, ws, websocket.MessageBinary)
883+
ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary)
885884
defer nconn.Close()
886885

887886
// Slurp all packets from the connection into io.Discard so pongs get sent
@@ -990,7 +989,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
990989
return
991990
}
992991

993-
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
992+
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
994993
defer wsNetConn.Close()
995994

996995
closeCtx, closeCtxCancel := context.WithCancel(ctx)
@@ -1077,7 +1076,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
10771076
})
10781077
return
10791078
}
1080-
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
1079+
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
10811080
defer wsNetConn.Close()
10821081

10831082
go httpapi.Heartbeat(ctx, conn)
@@ -2108,47 +2107,6 @@ func createExternalAuthResponse(typ, token string, extra pqtype.NullRawMessage)
21082107
return resp, err
21092108
}
21102109

2111-
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
2112-
// is called if a read or write error is encountered.
2113-
type wsNetConn struct {
2114-
cancel context.CancelFunc
2115-
net.Conn
2116-
}
2117-
2118-
func (c *wsNetConn) Read(b []byte) (n int, err error) {
2119-
n, err = c.Conn.Read(b)
2120-
if err != nil {
2121-
c.cancel()
2122-
}
2123-
return n, err
2124-
}
2125-
2126-
func (c *wsNetConn) Write(b []byte) (n int, err error) {
2127-
n, err = c.Conn.Write(b)
2128-
if err != nil {
2129-
c.cancel()
2130-
}
2131-
return n, err
2132-
}
2133-
2134-
func (c *wsNetConn) Close() error {
2135-
defer c.cancel()
2136-
return c.Conn.Close()
2137-
}
2138-
2139-
// websocketNetConn wraps websocket.NetConn and returns a context that
2140-
// is tied to the parent context and the lifetime of the conn. Any error
2141-
// during read or write will cancel the context, but not close the
2142-
// conn. Close should be called to release context resources.
2143-
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
2144-
ctx, cancel := context.WithCancel(ctx)
2145-
nc := websocket.NetConn(ctx, conn, msgType)
2146-
return ctx, &wsNetConn{
2147-
cancel: cancel,
2148-
Conn: nc,
2149-
}
2150-
}
2151-
21522110
func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.WorkspaceAgentLog {
21532111
sdk := make([]codersdk.WorkspaceAgentLog, 0, len(logs))
21542112
for _, logEntry := range logs {

coderd/workspaceagentsrpc.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
100100
return
101101
}
102102

103-
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
103+
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
104104
defer wsNetConn.Close()
105105

106106
ycfg := yamux.DefaultConfig()

codersdk/agentsdk/agentsdk.go

+1-45
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
203203
return nil, codersdk.ReadBodyAsError(res)
204204
}
205205

206-
_, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
206+
_, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
207207

208208
netConn := &closeNetConn{
209209
Conn: wsNetConn,
@@ -596,50 +596,6 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
596596
return authResp, json.NewDecoder(res.Body).Decode(&authResp)
597597
}
598598

599-
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
600-
// is called if a read or write error is encountered.
601-
type wsNetConn struct {
602-
cancel context.CancelFunc
603-
net.Conn
604-
}
605-
606-
func (c *wsNetConn) Read(b []byte) (n int, err error) {
607-
n, err = c.Conn.Read(b)
608-
if err != nil {
609-
c.cancel()
610-
}
611-
return n, err
612-
}
613-
614-
func (c *wsNetConn) Write(b []byte) (n int, err error) {
615-
n, err = c.Conn.Write(b)
616-
if err != nil {
617-
c.cancel()
618-
}
619-
return n, err
620-
}
621-
622-
func (c *wsNetConn) Close() error {
623-
defer c.cancel()
624-
return c.Conn.Close()
625-
}
626-
627-
// websocketNetConn wraps websocket.NetConn and returns a context that
628-
// is tied to the parent context and the lifetime of the conn. Any error
629-
// during read or write will cancel the context, but not close the
630-
// conn. Close should be called to release context resources.
631-
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
632-
// Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than
633-
// the default because some of our protocols can include large messages like startup scripts.
634-
conn.SetReadLimit(1 << 22)
635-
ctx, cancel := context.WithCancel(ctx)
636-
nc := websocket.NetConn(ctx, conn, msgType)
637-
return ctx, &wsNetConn{
638-
cancel: cancel,
639-
Conn: nc,
640-
}
641-
}
642-
643599
// LogsNotifyChannel returns the channel name responsible for notifying
644600
// of new logs.
645601
func LogsNotifyChannel(agentID uuid.UUID) string {

codersdk/provisionerdaemons.go

+1-44
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"encoding/json"
66
"fmt"
77
"io"
8-
"net"
98
"net/http"
109
"net/http/cookiejar"
1110
"time"
@@ -248,7 +247,7 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
248247
config := yamux.DefaultConfig()
249248
config.LogOutput = io.Discard
250249
// Use background context because caller should close the client.
251-
_, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary)
250+
_, wsNetConn := WebsocketNetConn(context.Background(), conn, websocket.MessageBinary)
252251
session, err := yamux.Client(wsNetConn, config)
253252
if err != nil {
254253
_ = conn.Close(websocket.StatusGoingAway, "")
@@ -257,45 +256,3 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
257256
}
258257
return proto.NewDRPCProvisionerDaemonClient(drpc.MultiplexedConn(session)), nil
259258
}
260-
261-
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
262-
// is called if a read or write error is encountered.
263-
// @typescript-ignore wsNetConn
264-
type wsNetConn struct {
265-
cancel context.CancelFunc
266-
net.Conn
267-
}
268-
269-
func (c *wsNetConn) Read(b []byte) (n int, err error) {
270-
n, err = c.Conn.Read(b)
271-
if err != nil {
272-
c.cancel()
273-
}
274-
return n, err
275-
}
276-
277-
func (c *wsNetConn) Write(b []byte) (n int, err error) {
278-
n, err = c.Conn.Write(b)
279-
if err != nil {
280-
c.cancel()
281-
}
282-
return n, err
283-
}
284-
285-
func (c *wsNetConn) Close() error {
286-
defer c.cancel()
287-
return c.Conn.Close()
288-
}
289-
290-
// websocketNetConn wraps websocket.NetConn and returns a context that
291-
// is tied to the parent context and the lifetime of the conn. Any error
292-
// during read or write will cancel the context, but not close the
293-
// conn. Close should be called to release context resources.
294-
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
295-
ctx, cancel := context.WithCancel(ctx)
296-
nc := websocket.NetConn(ctx, conn, msgType)
297-
return ctx, &wsNetConn{
298-
cancel: cancel,
299-
Conn: nc,
300-
}
301-
}

codersdk/websocket.go

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package codersdk
2+
3+
import (
4+
"context"
5+
"net"
6+
7+
"nhooyr.io/websocket"
8+
)
9+
10+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
11+
// is called if a read or write error is encountered.
12+
// @typescript-ignore wsNetConn
13+
type wsNetConn struct {
14+
cancel context.CancelFunc
15+
net.Conn
16+
}
17+
18+
func (c *wsNetConn) Read(b []byte) (n int, err error) {
19+
n, err = c.Conn.Read(b)
20+
if err != nil {
21+
c.cancel()
22+
}
23+
return n, err
24+
}
25+
26+
func (c *wsNetConn) Write(b []byte) (n int, err error) {
27+
n, err = c.Conn.Write(b)
28+
if err != nil {
29+
c.cancel()
30+
}
31+
return n, err
32+
}
33+
34+
func (c *wsNetConn) Close() error {
35+
defer c.cancel()
36+
return c.Conn.Close()
37+
}
38+
39+
// WebsocketNetConn wraps websocket.NetConn and returns a context that
40+
// is tied to the parent context and the lifetime of the conn. Any error
41+
// during read or write will cancel the context, but not close the
42+
// conn. Close should be called to release context resources.
43+
func WebsocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
44+
// Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than
45+
// the default because some of our protocols can include large messages like startup scripts.
46+
conn.SetReadLimit(1 << 22)
47+
ctx, cancel := context.WithCancel(ctx)
48+
nc := websocket.NetConn(ctx, conn, msgType)
49+
return ctx, &wsNetConn{
50+
cancel: cancel,
51+
Conn: nc,
52+
}
53+
}

codersdk/websocket_test.go

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package codersdk_test
2+
3+
import (
4+
"crypto/rand"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
"nhooyr.io/websocket"
12+
13+
"github.com/coder/coder/v2/codersdk"
14+
"github.com/coder/coder/v2/testutil"
15+
)
16+
17+
// TestWebsocketNetConn_LargeWrites tests that we can write large amounts of data thru the netconn
18+
// in a single write. Without specifically setting the read limit, the websocket library limits
19+
// the amount of data that can be read in a single message to 32kiB. Even after raising the limit,
20+
// curiously, it still only reads 32kiB per Read(), but allows the large write to go thru.
21+
func TestWebsocketNetConn_LargeWrites(t *testing.T) {
22+
t.Parallel()
23+
ctx := testutil.Context(t, testutil.WaitShort)
24+
n := 4 * 1024 * 1024 // 4 MiB
25+
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
26+
sws, err := websocket.Accept(w, r, nil)
27+
if !assert.NoError(t, err) {
28+
return
29+
}
30+
_, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
31+
defer nc.Close()
32+
33+
// Although the writes are all in one go, the reads get broken up by
34+
// the library.
35+
j := 0
36+
b := make([]byte, n)
37+
for j < n {
38+
k, err := nc.Read(b[j:])
39+
if !assert.NoError(t, err) {
40+
return
41+
}
42+
j += k
43+
t.Logf("server read %d bytes, total %d", k, j)
44+
}
45+
assert.Equal(t, n, j)
46+
j, err = nc.Write(b)
47+
assert.Equal(t, n, j)
48+
if !assert.NoError(t, err) {
49+
return
50+
}
51+
}))
52+
53+
// use of random data is worst case scenario for compression
54+
cb := make([]byte, n)
55+
rk, err := rand.Read(cb)
56+
require.NoError(t, err)
57+
require.Equal(t, n, rk)
58+
59+
// nolint: bodyclose
60+
cws, _, err := websocket.Dial(ctx, svr.URL, nil)
61+
require.NoError(t, err)
62+
_, cnc := codersdk.WebsocketNetConn(ctx, cws, websocket.MessageBinary)
63+
ck, err := cnc.Write(cb)
64+
require.NoError(t, err)
65+
require.Equal(t, n, ck)
66+
67+
cb2 := make([]byte, n)
68+
j := 0
69+
for j < n {
70+
k, err := cnc.Read(cb2[j:])
71+
if !assert.NoError(t, err) {
72+
return
73+
}
74+
j += k
75+
t.Logf("client read %d bytes, total %d", k, j)
76+
}
77+
require.NoError(t, err)
78+
require.Equal(t, n, j)
79+
require.Equal(t, cb, cb2)
80+
}

codersdk/workspaceagents.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID,
844844
}
845845
logChunks := make(chan []WorkspaceAgentLog, 1)
846846
closed := make(chan struct{})
847-
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
847+
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText)
848848
decoder := json.NewDecoder(wsNetConn)
849849
go func() {
850850
defer close(closed)

codersdk/workspaceagents_internal_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
5050
if !assert.NoError(t, err) {
5151
return
5252
}
53-
ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary)
53+
ctx, nc := WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
5454
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
5555
Name: "client",
5656
ID: clientID,

0 commit comments

Comments
 (0)