Thanks to visit codestin.com
Credit goes to github.com

Skip to content

chore: consolidate websocketNetConn implementations #12065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 4 additions & 46 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"sort"
Expand Down Expand Up @@ -544,7 +543,7 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
}
go httpapi.Heartbeat(ctx, conn)

ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
defer wsNetConn.Close() // Also closes conn.

// The Go stdlib JSON encoder appends a newline character after message write.
Expand Down Expand Up @@ -881,7 +880,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
})
return
}
ctx, nconn := websocketNetConn(ctx, ws, websocket.MessageBinary)
ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary)
defer nconn.Close()

// Slurp all packets from the connection into io.Discard so pongs get sent
Expand Down Expand Up @@ -990,7 +989,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
return
}

ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()

closeCtx, closeCtxCancel := context.WithCancel(ctx)
Expand Down Expand Up @@ -1077,7 +1076,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
})
return
}
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()

go httpapi.Heartbeat(ctx, conn)
Expand Down Expand Up @@ -2108,47 +2107,6 @@ func createExternalAuthResponse(typ, token string, extra pqtype.NullRawMessage)
return resp, err
}

// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}

func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}

// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.WorkspaceAgentLog {
sdk := make([]codersdk.WorkspaceAgentLog, 0, len(logs))
for _, logEntry := range logs {
Expand Down
2 changes: 1 addition & 1 deletion coderd/workspaceagentsrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
return
}

ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()

ycfg := yamux.DefaultConfig()
Expand Down
46 changes: 1 addition & 45 deletions codersdk/agentsdk/agentsdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
return nil, codersdk.ReadBodyAsError(res)
}

_, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
_, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)

netConn := &closeNetConn{
Conn: wsNetConn,
Expand Down Expand Up @@ -596,50 +596,6 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
return authResp, json.NewDecoder(res.Body).Decode(&authResp)
}

// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}

func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}

// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
// Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than
// the default because some of our protocols can include large messages like startup scripts.
conn.SetReadLimit(1 << 22)
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

// LogsNotifyChannel returns the channel name responsible for notifying
// of new logs.
func LogsNotifyChannel(agentID uuid.UUID) string {
Expand Down
45 changes: 1 addition & 44 deletions codersdk/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/cookiejar"
"time"
Expand Down Expand Up @@ -248,7 +247,7 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
// Use background context because caller should close the client.
_, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary)
_, wsNetConn := WebsocketNetConn(context.Background(), conn, websocket.MessageBinary)
session, err := yamux.Client(wsNetConn, config)
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
Expand All @@ -257,45 +256,3 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
}
return proto.NewDRPCProvisionerDaemonClient(drpc.MultiplexedConn(session)), nil
}

// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
// @typescript-ignore wsNetConn
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}

func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}

// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}
53 changes: 53 additions & 0 deletions codersdk/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package codersdk

import (
"context"
"net"

"nhooyr.io/websocket"
)

// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
// @typescript-ignore wsNetConn
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}

func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}

// WebsocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func WebsocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
// Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than
// the default because some of our protocols can include large messages like startup scripts.
conn.SetReadLimit(1 << 22)
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}
80 changes: 80 additions & 0 deletions codersdk/websocket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package codersdk_test

import (
"crypto/rand"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nhooyr.io/websocket"

"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)

// TestWebsocketNetConn_LargeWrites tests that we can write large amounts of data thru the netconn
// in a single write. Without specifically setting the read limit, the websocket library limits
// the amount of data that can be read in a single message to 32kiB. Even after raising the limit,
// curiously, it still only reads 32kiB per Read(), but allows the large write to go thru.
func TestWebsocketNetConn_LargeWrites(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
n := 4 * 1024 * 1024 // 4 MiB
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
_, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
defer nc.Close()

// Although the writes are all in one go, the reads get broken up by
// the library.
j := 0
b := make([]byte, n)
for j < n {
k, err := nc.Read(b[j:])
if !assert.NoError(t, err) {
return
}
j += k
t.Logf("server read %d bytes, total %d", k, j)
}
assert.Equal(t, n, j)
j, err = nc.Write(b)
assert.Equal(t, n, j)
if !assert.NoError(t, err) {
return
}
}))

// use of random data is worst case scenario for compression
cb := make([]byte, n)
rk, err := rand.Read(cb)
require.NoError(t, err)
require.Equal(t, n, rk)

// nolint: bodyclose
cws, _, err := websocket.Dial(ctx, svr.URL, nil)
require.NoError(t, err)
_, cnc := codersdk.WebsocketNetConn(ctx, cws, websocket.MessageBinary)
ck, err := cnc.Write(cb)
require.NoError(t, err)
require.Equal(t, n, ck)

cb2 := make([]byte, n)
j := 0
for j < n {
k, err := cnc.Read(cb2[j:])
if !assert.NoError(t, err) {
return
}
j += k
t.Logf("client read %d bytes, total %d", k, j)
}
require.NoError(t, err)
require.Equal(t, n, j)
require.Equal(t, cb, cb2)
}
2 changes: 1 addition & 1 deletion codersdk/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID,
}
logChunks := make(chan []WorkspaceAgentLog, 1)
closed := make(chan struct{})
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText)
decoder := json.NewDecoder(wsNetConn)
go func() {
defer close(closed)
Expand Down
2 changes: 1 addition & 1 deletion codersdk/workspaceagents_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
if !assert.NoError(t, err) {
return
}
ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary)
ctx, nc := WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to fix this bad use of r.Context() since someone could look at this and get the wrong idea. (The websocket.Accept invalidates r.Context() cancellation.)

Suggested change
ctx, nc := WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
ctx, nc := WebsocketNetConn(context.Background(), sws, websocket.MessageBinary)

I know some other places are using r.Context() via ctx := r.Context() too, but this one just seems too explicitly wrong. 😄

Perhaps this is actually a case for moving websocket.Accept into codersdk.WebsocketNetConn as well (or creating a unified function, codersdk.WebsocketAcceptNetConn).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The websocket.Accept invalidates r.Context() cancellation.

I don't understand what it means to "invalidate" a context cancelation.

Are you referring to the idea that once we have a websocket, passing r.Context() is a bit pointless because the context won't get canceled before the underlying TCP connection is closed? I guess that's true enough, but changing it to context.Background() doesn't change anything from a functional standpoint.

Are you suggesting we don't accept a context at all?

Copy link
Member

@mafredri mafredri Feb 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The context is actually not even tied to the TCP connection, it is tied to the handler which creates a scenario where it’s unlikely to ever be cancelled in a way that matters.

https://pkg.go.dev/net/http#Hijacker

Functionally, you’re right, there’s no difference but this behavior can easily trip anyone up so its usage should be discouraged.

err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
Name: "client",
ID: clientID,
Expand Down
Loading