Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d03849e

Browse files
committed
Migrate coordinator to use net.conn
1 parent 93f965a commit d03849e

File tree

7 files changed

+184
-118
lines changed

7 files changed

+184
-118
lines changed

agent/agent.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import (
2828
"go.uber.org/atomic"
2929
gossh "golang.org/x/crypto/ssh"
3030
"golang.org/x/xerrors"
31-
"nhooyr.io/websocket"
3231
"tailscale.com/tailcfg"
3332

3433
"cdr.dev/slog"
@@ -82,7 +81,7 @@ type WebRTCDialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Lis
8281

8382
// CoordinatorDialer is a function that constructs a new broker.
8483
// A dialer must be passed in to allow for reconnects.
85-
type CoordinatorDialer func(ctx context.Context) (*websocket.Conn, error)
84+
type CoordinatorDialer func(ctx context.Context) (net.Conn, error)
8685

8786
// FetchMetadata is a function to obtain metadata for the agent.
8887
type FetchMetadata func(ctx context.Context) (Metadata, error)
@@ -220,7 +219,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
220219

221220
// runCoordinator listens for nodes and updates the self-node as it changes.
222221
func (a *agent) runCoordinator(ctx context.Context) {
223-
var coordinator *websocket.Conn
222+
var coordinator net.Conn
224223
var err error
225224
// An exponential back-off occurs when the connection is failing to dial.
226225
// This is to prevent server spam in case of a coderd outage.
@@ -239,7 +238,7 @@ func (a *agent) runCoordinator(ctx context.Context) {
239238
a.logger.Info(context.Background(), "connected to coordination server")
240239
break
241240
}
242-
sendNodes, errChan := tailnet.ServeCoordinator(ctx, coordinator, a.network.UpdateNodes)
241+
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, a.network.UpdateNodes)
243242
a.network.SetNodeCallback(sendNodes)
244243
select {
245244
case <-ctx.Done():
@@ -885,9 +884,7 @@ func (a *agent) Close() error {
885884
}
886885
close(a.closed)
887886
a.closeCancel()
888-
fmt.Printf("CLOSING NETWORK!!!!\n")
889887
if a.network != nil {
890-
fmt.Printf("ACTUALLY CLOSING NETWORK!!!!\n")
891888
_ = a.network.Close()
892889
}
893890
_ = a.sshServer.Close()

agent/agent_test.go

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@ package agent_test
33
import (
44
"bufio"
55
"context"
6+
"crypto/tls"
67
"encoding/json"
78
"fmt"
89
"io"
910
"net"
11+
"net/http"
12+
"net/http/httptest"
13+
"net/netip"
1014
"os"
1115
"os/exec"
1216
"path/filepath"
@@ -17,6 +21,11 @@ import (
1721
"time"
1822

1923
"golang.org/x/xerrors"
24+
"tailscale.com/derp"
25+
"tailscale.com/derp/derphttp"
26+
"tailscale.com/tailcfg"
27+
"tailscale.com/types/key"
28+
tslogger "tailscale.com/types/logger"
2029

2130
scp "github.com/bramvdbogaerde/go-scp"
2231
"github.com/google/uuid"
@@ -38,6 +47,7 @@ import (
3847
"github.com/coder/coder/peerbroker/proto"
3948
"github.com/coder/coder/provisionersdk"
4049
"github.com/coder/coder/pty/ptytest"
50+
"github.com/coder/coder/tailnet"
4151
"github.com/coder/coder/testutil"
4252
)
4353

@@ -423,7 +433,14 @@ func TestAgent(t *testing.T) {
423433

424434
t.Run("Tailscale", func(t *testing.T) {
425435
t.Parallel()
426-
436+
derpMap := runDERPAndStun(t, tailnet.Logger(slogtest.Make(t, nil)))
437+
conn := setupSSHSession(t, agent.Metadata{
438+
DERPMap: derpMap,
439+
})
440+
defer conn.Close()
441+
output, err := conn.CombinedOutput("echo test")
442+
require.NoError(t, err)
443+
t.Log(string(output))
427444
})
428445
}
429446

@@ -469,6 +486,9 @@ func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session {
469486

470487
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) agent.Conn {
471488
client, server := provisionersdk.TransportPipe()
489+
tailscale := metadata.DERPMap != nil
490+
coordinator := tailnet.NewCoordinator()
491+
agentID := uuid.New()
472492
closer := agent.New(agent.Options{
473493
FetchMetadata: func(ctx context.Context) (agent.Metadata, error) {
474494
return metadata, nil
@@ -477,6 +497,12 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
477497
listener, err := peerbroker.Listen(server, nil)
478498
return listener, err
479499
},
500+
CoordinatorDialer: func(ctx context.Context) (net.Conn, error) {
501+
clientConn, serverConn := net.Pipe()
502+
go coordinator.ServeAgent(serverConn, agentID)
503+
return clientConn, nil
504+
},
505+
EnableTailnet: tailscale,
480506
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
481507
ReconnectingPTYTimeout: ptyTimeout,
482508
})
@@ -488,6 +514,24 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
488514
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
489515
stream, err := api.NegotiateConnection(context.Background())
490516
assert.NoError(t, err)
517+
if tailscale {
518+
conn, err := tailnet.NewConn(&tailnet.Options{
519+
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
520+
DERPMap: metadata.DERPMap,
521+
Logger: slogtest.Make(t, nil).Named("tailnet"),
522+
})
523+
require.NoError(t, err)
524+
525+
clientConn, serverConn := net.Pipe()
526+
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
527+
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
528+
return conn.UpdateNodes(node)
529+
})
530+
conn.SetNodeCallback(sendNode)
531+
return &agent.TailnetConn{
532+
Conn: conn,
533+
}
534+
}
491535
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
492536
Logger: slogtest.Make(t, nil),
493537
})
@@ -532,3 +576,53 @@ func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
532576
assert.NoError(t, err, "write payload")
533577
assert.Equal(t, len(payload), n, "payload length does not match")
534578
}
579+
580+
func runDERPAndStun(t *testing.T, logf tslogger.Logf) (derpMap *tailcfg.DERPMap) {
581+
d := derp.NewServer(key.NewNode(), logf)
582+
server := httptest.NewUnstartedServer(derphttp.Handler(d))
583+
server.Config.ErrorLog = tslogger.StdLogger(logf)
584+
server.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
585+
server.StartTLS()
586+
587+
// stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, nettype.Std{})
588+
t.Cleanup(func() {
589+
server.CloseClientConnections()
590+
server.Close()
591+
d.Close()
592+
// stunCleanup()
593+
})
594+
595+
tcpAddr, ok := server.Listener.Addr().(*net.TCPAddr)
596+
if !ok {
597+
t.FailNow()
598+
}
599+
600+
return &tailcfg.DERPMap{
601+
Regions: map[int]*tailcfg.DERPRegion{
602+
1: {
603+
RegionID: 1,
604+
RegionCode: "test",
605+
RegionName: "Testlandia",
606+
Nodes: []*tailcfg.DERPNode{
607+
{
608+
Name: "t1",
609+
RegionID: 1,
610+
HostName: "stun.l.google.com",
611+
DERPPort: -1,
612+
STUNPort: 19302,
613+
STUNOnly: true,
614+
},
615+
{
616+
Name: "t2",
617+
RegionID: 1,
618+
IPv4: "127.0.0.1",
619+
IPv6: "none",
620+
STUNPort: -1,
621+
DERPPort: tcpAddr.Port,
622+
InsecureForTests: true,
623+
},
624+
},
625+
},
626+
},
627+
}
628+
}

cli/ssh.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"golang.org/x/xerrors"
2222

2323
"cdr.dev/slog"
24-
"cdr.dev/slog/sloggers/sloghuman"
2524

2625
"github.com/coder/coder/agent"
2726
"github.com/coder/coder/cli/cliflag"
@@ -91,7 +90,7 @@ func ssh() *cobra.Command {
9190
if !wireguard {
9291
conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
9392
} else {
94-
conn, err = client.DialWorkspaceAgentTailnet(ctx, slog.Make(sloghuman.Sink(cmd.ErrOrStderr())).Leveled(slog.LevelDebug), workspaceAgent.ID)
93+
conn, err = client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
9594
}
9695
if err != nil {
9796
return err

coderd/workspaceagents.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
530530
return
531531
}
532532
defer conn.Close(websocket.StatusNormalClosure, "")
533-
err = api.ConnCoordinator.ServeAgent(r.Context(), conn, workspaceAgent.ID)
533+
err = api.ConnCoordinator.ServeAgent(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), workspaceAgent.ID)
534534
if err != nil {
535535
_ = conn.Close(websocket.StatusInternalError, err.Error())
536536
return
@@ -556,7 +556,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
556556
return
557557
}
558558
defer conn.Close(websocket.StatusNormalClosure, "")
559-
err = api.ConnCoordinator.ServeClient(r.Context(), conn, uuid.New(), workspaceAgent.ID)
559+
err = api.ConnCoordinator.ServeClient(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
560560
if err != nil {
561561
_ = conn.Close(websocket.StatusInternalError, err.Error())
562562
return

codersdk/workspaceagents.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ func (c *Client) UpdateWorkspaceAgentNode(ctx context.Context, agentID uuid.UUID
281281
return nil
282282
}
283283

284-
func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (*websocket.Conn, error) {
284+
func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, error) {
285285
coordinateURL, err := c.URL.Parse("/api/v2/workspaceagents/me/coordinate")
286286
if err != nil {
287287
return nil, xerrors.Errorf("parse url: %w", err)
@@ -300,7 +300,10 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (*websocket.Co
300300
conn, _, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
301301
HTTPClient: httpClient,
302302
})
303-
return conn, err
303+
if err != nil {
304+
return nil, err
305+
}
306+
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
304307
}
305308

306309
func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (agent.Conn, error) {
@@ -370,7 +373,7 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
370373
logger.Debug(ctx, "failed to dial", slog.Error(err))
371374
continue
372375
}
373-
sendNode, errChan := tailnet.ServeCoordinator(ctx, ws, func(node []*tailnet.Node) error {
376+
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
374377
return conn.UpdateNodes(node)
375378
})
376379
conn.SetNodeCallback(sendNode)

0 commit comments

Comments
 (0)