Thanks to visit codestin.com
Credit goes to github.com

Skip to content

fix: Improve use of context in websocket.NetConn code paths #6198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -732,10 +732,13 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
})
return
}
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()

go httpapi.Heartbeat(ctx, conn)

defer conn.Close(websocket.StatusNormalClosure, "")
err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
err = (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn, uuid.New(), workspaceAgent.ID)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
Expand Down
45 changes: 44 additions & 1 deletion codersdk/agentsdk/agentsdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
return nil, codersdk.ReadBodyAsError(res)
}

ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)

// Ping once every 30 seconds to ensure that the websocket is alive. If we
// don't get a response within 30s we kill the websocket and reconnect.
// See: https://github.com/coder/coder/pull/5824
Expand Down Expand Up @@ -195,7 +197,7 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
}
}()

return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
return wsNetConn, nil
}

type PostAppHealthsRequest struct {
Expand Down Expand Up @@ -529,3 +531,44 @@ type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}

// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}

func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}

// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}
58 changes: 54 additions & 4 deletions codersdk/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/cookiejar"
"net/url"
Expand Down Expand Up @@ -143,8 +144,9 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
return nil, nil, ReadBodyAsError(res)
}
logs := make(chan ProvisionerJobLog)
decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText))
closed := make(chan struct{})
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
decoder := json.NewDecoder(wsNetConn)
go func() {
defer close(closed)
defer close(logs)
Expand All @@ -163,13 +165,15 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
}
}()
return logs, closeFunc(func() error {
_ = conn.Close(websocket.StatusNormalClosure, "")
_ = wsNetConn.Close()
<-closed
return nil
}), nil
}

// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon implementation.
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon
// implementation. The context is during dial, not during the lifetime of the
// client. Client should be closed after use.
func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization))
if err != nil {
Expand Down Expand Up @@ -210,9 +214,55 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.U

config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config)
// Use background context because caller should close the client.
_, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary)
session, err := yamux.Client(wsNetConn, config)
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
_ = wsNetConn.Close()
return nil, xerrors.Errorf("multiplex client: %w", err)
}
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.MultiplexedConn(session)), nil
}

// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
// @typescript-ignore wsNetConn
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}

func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}

// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}
2 changes: 1 addition & 1 deletion codersdk/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec
}
return nil, ReadBodyAsError(res)
}
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
return websocket.NetConn(context.Background(), conn, websocket.MessageBinary), nil
}

// WorkspaceAgentListeningPorts returns a list of ports that are currently being
Expand Down
71 changes: 59 additions & 12 deletions enterprise/coderd/provisionerdaemons.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package coderd

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"

Expand Down Expand Up @@ -94,12 +96,14 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
// @Success 101
// @Router /organizations/{organization}/provisionerdaemons/serve [get]
func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

tags := map[string]string{}
if r.URL.Query().Has("tag") {
for _, tag := range r.URL.Query()["tag"] {
parts := strings.SplitN(tag, "=", 2)
if len(parts) < 2 {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag),
})
return
Expand All @@ -108,7 +112,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
}
}
if !r.URL.Query().Has("provisioner") {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "The provisioner query parameter must be specified.",
})
return
Expand All @@ -122,7 +126,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
case string(codersdk.ProvisionerTypeTerraform):
provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{}
default:
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Unknown provisioner type %q", provisioner),
})
return
Expand All @@ -137,7 +141,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)

if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization {
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) {
httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: "You aren't allowed to create provisioner daemons for the organization.",
})
return
Expand All @@ -155,15 +159,15 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
}

name := namesgenerator.GetRandomName(1)
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
ID: uuid.New(),
CreatedAt: database.Now(),
Name: name,
Provisioners: provisioners,
Tags: tags,
})
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error writing provisioner daemon.",
Detail: err.Error(),
})
Expand All @@ -172,7 +176,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)

rawTags, err := json.Marshal(daemon.Tags)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error marshaling daemon tags.",
Detail: err.Error(),
})
Expand All @@ -189,7 +193,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error accepting websocket connection.",
Detail: err.Error(),
})
Expand All @@ -203,7 +207,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
// the same connection.
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()
session, err := yamux.Server(wsNetConn, config)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err))
return
Expand All @@ -229,12 +235,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
if xerrors.Is(err, io.EOF) {
return
}
api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err))
api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
},
})
err = server.Serve(r.Context(), session)
err = server.Serve(ctx, session)
if err != nil && !xerrors.Is(err, io.EOF) {
api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err))
api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
return
}
Expand All @@ -254,3 +260,44 @@ func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.Provis
}
return result
}

// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}

func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}

func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}

// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}