Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 38f4004

Browse files
committed
Finish refactor to make this work
1 parent 36655f0 commit 38f4004

19 files changed

+355
-317
lines changed

.vscode/settings.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
"sdkproto",
7676
"sdktrace",
7777
"Signup",
78+
"slogtest",
7879
"sourcemapped",
7980
"Srcs",
8081
"stretchr",
@@ -110,6 +111,7 @@
110111
"wgmonitor",
111112
"wgnet",
112113
"workspaceagent",
114+
"workspaceagents",
113115
"workspaceapp",
114116
"workspaceapps",
115117
"workspacebuilds",

agent/agent.go

Lines changed: 154 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,19 @@ const (
4646
)
4747

4848
type Options struct {
49-
EnableWireguard bool
50-
UpdateTailscaleNode UpdateTailscaleNode
51-
ListenTailscaleNodes ListenTailscaleNodes
49+
EnableTailnet bool
50+
NodeDialer NodeDialer
51+
WebRTCDialer WebRTCDialer
52+
FetchMetadata FetchMetadata
53+
5254
ReconnectingPTYTimeout time.Duration
5355
EnvironmentVariables map[string]string
5456
Logger slog.Logger
5557
}
5658

5759
type Metadata struct {
58-
TailscaleAddresses []netaddr.IPPrefix `json:"tailscale_addresses"`
59-
TailscaleDERPMap *tailcfg.DERPMap `json:"tailscale_derpmap"`
60+
IPAddresses []netaddr.IP `json:"ip_addresses"`
61+
DERPMap *tailcfg.DERPMap `json:"derpmap"`
6062

6163
OwnerEmail string `json:"owner_email"`
6264
OwnerUsername string `json:"owner_username"`
@@ -65,37 +67,47 @@ type Metadata struct {
6567
Directory string `json:"directory"`
6668
}
6769

68-
type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker.Listener, error)
70+
type WebRTCDialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error)
6971

70-
type UpdateTailscaleNode func(ctx context.Context, node *tailnet.Node) error
71-
type ListenTailscaleNodes func(ctx context.Context, logger slog.Logger) (<-chan *tailnet.Node, func(), error)
72+
// NodeBroker handles the exchange of node information.
73+
type NodeBroker interface {
74+
io.Closer
75+
// Read will be a constant stream of incoming connection requests.
76+
Read(ctx context.Context) (*tailnet.Node, error)
77+
// Write should be called with the listening agent node information.
78+
Write(ctx context.Context, node *tailnet.Node) error
79+
}
7280

73-
func New(dialer Dialer, options *Options) io.Closer {
74-
if options == nil {
75-
options = &Options{}
76-
}
81+
// NodeDialer is a function that constructs a new broker.
82+
// A dialer must be passed in to allow for reconnects.
83+
type NodeDialer func(ctx context.Context) (NodeBroker, error)
84+
85+
// FetchMetadata is a function to obtain metadata for the agent.
86+
type FetchMetadata func(ctx context.Context) (Metadata, error)
87+
88+
func New(options Options) io.Closer {
7789
if options.ReconnectingPTYTimeout == 0 {
7890
options.ReconnectingPTYTimeout = 5 * time.Minute
7991
}
8092
ctx, cancelFunc := context.WithCancel(context.Background())
8193
server := &agent{
82-
dialer: dialer,
94+
webrtcDialer: options.WebRTCDialer,
8395
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
8496
logger: options.Logger,
8597
closeCancel: cancelFunc,
8698
closed: make(chan struct{}),
8799
envVars: options.EnvironmentVariables,
88-
enableWireguard: options.EnableWireguard,
89-
updateTailscaleNode: options.UpdateTailscaleNode,
90-
listenTailscaleNodes: options.ListenTailscaleNodes,
100+
enableTailnet: options.EnableTailnet,
101+
nodeDialer: options.NodeDialer,
102+
fetchMetadata: options.FetchMetadata,
91103
}
92104
server.init(ctx)
93105
return server
94106
}
95107

96108
type agent struct {
97-
dialer Dialer
98-
logger slog.Logger
109+
webrtcDialer WebRTCDialer
110+
logger slog.Logger
99111

100112
reconnectingPTYs sync.Map
101113
reconnectingPTYTimeout time.Duration
@@ -108,23 +120,21 @@ type agent struct {
108120
envVars map[string]string
109121
// metadata is atomic because values can change after reconnection.
110122
metadata atomic.Value
111-
startupScript atomic.Bool
123+
fetchMetadata FetchMetadata
112124
sshServer *ssh.Server
113125

114-
enableWireguard bool
115-
network *tailnet.Server
116-
updateTailscaleNode UpdateTailscaleNode
117-
listenTailscaleNodes ListenTailscaleNodes
126+
enableTailnet bool
127+
network *tailnet.Server
128+
nodeDialer NodeDialer
118129
}
119130

120131
func (a *agent) run(ctx context.Context) {
121132
var metadata Metadata
122-
var peerListener *peerbroker.Listener
123133
var err error
124134
// An exponential back-off occurs when the connection is failing to dial.
125135
// This is to prevent server spam in case of a coderd outage.
126136
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
127-
metadata, peerListener, err = a.dialer(ctx, a.logger)
137+
metadata, err = a.fetchMetadata(ctx)
128138
if err != nil {
129139
if errors.Is(err, context.Canceled) {
130140
return
@@ -135,7 +145,7 @@ func (a *agent) run(ctx context.Context) {
135145
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
136146
continue
137147
}
138-
a.logger.Info(context.Background(), "connected")
148+
a.logger.Info(context.Background(), "fetched metadata")
139149
break
140150
}
141151
select {
@@ -145,25 +155,131 @@ func (a *agent) run(ctx context.Context) {
145155
}
146156
a.metadata.Store(metadata)
147157

148-
if a.startupScript.CAS(false, true) {
149-
// The startup script has not ran yet!
150-
go func() {
151-
err := a.runStartupScript(ctx, metadata.StartupScript)
158+
// The startup script has not ran yet!
159+
go func() {
160+
err := a.runStartupScript(ctx, metadata.StartupScript)
161+
if errors.Is(err, context.Canceled) {
162+
return
163+
}
164+
if err != nil {
165+
a.logger.Warn(ctx, "agent script failed", slog.Error(err))
166+
}
167+
}()
168+
169+
go a.runWebRTCNetworking(ctx)
170+
if a.enableTailnet {
171+
go a.runTailnet(ctx, metadata.IPAddresses, metadata.DERPMap)
172+
}
173+
}
174+
175+
func (a *agent) runTailnet(ctx context.Context, addresses []netaddr.IP, derpMap *tailcfg.DERPMap) {
176+
ipRanges := make([]netaddr.IPPrefix, 0, len(addresses))
177+
for _, address := range addresses {
178+
ipRanges = append(ipRanges, netaddr.IPPrefixFrom(address, 128))
179+
}
180+
var err error
181+
a.network, err = tailnet.New(&tailnet.Options{
182+
Addresses: ipRanges,
183+
DERPMap: derpMap,
184+
Logger: a.logger.Named("tailnet"),
185+
})
186+
if err != nil {
187+
a.logger.Critical(ctx, "create tailnet", slog.Error(err))
188+
return
189+
}
190+
go a.runNodeBroker(ctx)
191+
192+
sshListener, err := a.network.Listen("tcp", ":12212")
193+
if err != nil {
194+
a.logger.Critical(ctx, "listen for ssh", slog.Error(err))
195+
return
196+
}
197+
go func() {
198+
for {
199+
conn, err := sshListener.Accept()
200+
if err != nil {
201+
return
202+
}
203+
go a.sshServer.HandleConn(conn)
204+
}
205+
}()
206+
}
207+
208+
// runNodeBroker listens for nodes and updates the self-node as it changes.
209+
func (a *agent) runNodeBroker(ctx context.Context) {
210+
var nodeBroker NodeBroker
211+
var err error
212+
// An exponential back-off occurs when the connection is failing to dial.
213+
// This is to prevent server spam in case of a coderd outage.
214+
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
215+
nodeBroker, err = a.nodeDialer(ctx)
216+
if err != nil {
152217
if errors.Is(err, context.Canceled) {
153218
return
154219
}
155-
if err != nil {
156-
a.logger.Warn(ctx, "agent script failed", slog.Error(err))
220+
if a.isClosed() {
221+
return
157222
}
158-
}()
223+
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
224+
continue
225+
}
226+
a.logger.Info(context.Background(), "connected to node broker")
227+
break
228+
}
229+
select {
230+
case <-ctx.Done():
231+
return
232+
default:
233+
}
234+
235+
a.network.SetNodeCallback(func(node *tailnet.Node) {
236+
err := nodeBroker.Write(ctx, node)
237+
if err != nil {
238+
a.logger.Warn(context.Background(), "write node", slog.Error(err), slog.F("node", node))
239+
}
240+
})
241+
242+
for {
243+
node, err := nodeBroker.Read(ctx)
244+
if err != nil {
245+
if a.isClosed() {
246+
return
247+
}
248+
a.logger.Debug(ctx, "node broker accept exited; restarting connection", slog.Error(err))
249+
a.runNodeBroker(ctx)
250+
return
251+
}
252+
err = a.network.UpdateNodes([]*tailnet.Node{node})
253+
if err != nil {
254+
a.logger.Error(ctx, "update tailnet nodes", slog.Error(err), slog.F("node", node))
255+
}
159256
}
257+
}
160258

161-
// We don't want to reinitialize the network if it already exists.
162-
if a.enableWireguard && a.network == nil {
163-
err = a.startWireguard(ctx, metadata.TailscaleAddresses, metadata.TailscaleDERPMap)
259+
func (a *agent) runWebRTCNetworking(ctx context.Context) {
260+
var peerListener *peerbroker.Listener
261+
var err error
262+
// An exponential back-off occurs when the connection is failing to dial.
263+
// This is to prevent server spam in case of a coderd outage.
264+
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
265+
peerListener, err = a.webrtcDialer(ctx, a.logger)
164266
if err != nil {
165-
a.logger.Error(ctx, "start wireguard", slog.Error(err))
267+
if errors.Is(err, context.Canceled) {
268+
return
269+
}
270+
if a.isClosed() {
271+
return
272+
}
273+
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
274+
continue
166275
}
276+
a.logger.Info(context.Background(), "connected to webrtc broker")
277+
break
278+
}
279+
select {
280+
case <-ctx.Done():
281+
return
282+
default:
167283
}
168284

169285
for {
@@ -173,7 +289,7 @@ func (a *agent) run(ctx context.Context) {
173289
return
174290
}
175291
a.logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
176-
a.run(ctx)
292+
a.runWebRTCNetworking(ctx)
177293
return
178294
}
179295
a.closeMutex.Lock()
@@ -667,71 +783,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
667783
}
668784
}
669785

670-
func (a *agent) startWireguard(ctx context.Context, addresses []netaddr.IPPrefix, derpMap *tailcfg.DERPMap) error {
671-
var err error
672-
a.network, err = tailnet.New(&tailnet.Options{
673-
Addresses: addresses,
674-
DERPMap: derpMap,
675-
Logger: a.logger.Named("tailnet"),
676-
})
677-
if err != nil {
678-
return err
679-
}
680-
a.network.SetNodeCallback(func(node *tailnet.Node) {
681-
err := a.updateTailscaleNode(ctx, node)
682-
if err != nil {
683-
a.logger.Error(ctx, "update tailscale node", slog.Error(err))
684-
}
685-
})
686-
go func() {
687-
for {
688-
var nodes <-chan *tailnet.Node
689-
var err error
690-
var listenClose func()
691-
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
692-
nodes, listenClose, err = a.listenTailscaleNodes(ctx, a.logger)
693-
if err != nil {
694-
if errors.Is(err, context.Canceled) {
695-
return
696-
}
697-
a.logger.Warn(ctx, "listen for tailscale nodes", slog.Error(err))
698-
continue
699-
}
700-
defer listenClose()
701-
a.logger.Info(context.Background(), "listening for tailscale nodes")
702-
break
703-
}
704-
for {
705-
var node *tailnet.Node
706-
select {
707-
case <-ctx.Done():
708-
case node = <-nodes:
709-
}
710-
if node == nil {
711-
// The channel ended!
712-
break
713-
}
714-
a.network.UpdateNodes([]*tailnet.Node{node})
715-
}
716-
}
717-
}()
718-
719-
sshListener, err := a.network.Listen("tcp", ":12212")
720-
if err != nil {
721-
return xerrors.Errorf("listen for ssh: %w", err)
722-
}
723-
go func() {
724-
for {
725-
conn, err := sshListener.Accept()
726-
if err != nil {
727-
return
728-
}
729-
go a.sshServer.HandleConn(conn)
730-
}
731-
}()
732-
return nil
733-
}
734-
735786
// dialResponse is written to datachannels with protocol "dial" by the agent as
736787
// the first packet to signify whether the dial succeeded or failed.
737788
type dialResponse struct {

agent/agent_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -441,10 +441,14 @@ func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session {
441441

442442
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn {
443443
client, server := provisionersdk.TransportPipe()
444-
closer := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) {
445-
listener, err := peerbroker.Listen(server, nil)
446-
return metadata, listener, err
447-
}, &agent.Options{
444+
closer := agent.New(agent.Options{
445+
FetchMetadata: func(ctx context.Context) (agent.Metadata, error) {
446+
return metadata, nil
447+
},
448+
WebRTCDialer: func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) {
449+
listener, err := peerbroker.Listen(server, nil)
450+
return listener, err
451+
},
448452
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
449453
ReconnectingPTYTimeout: ptyTimeout,
450454
})

0 commit comments

Comments
 (0)