Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit dafdb32

Browse files
committed
improvement to dialing, added local TCP dialing to agentconn
1 parent b0e1cef commit dafdb32

File tree

19 files changed

+1239
-540
lines changed

19 files changed

+1239
-540
lines changed

agent/agent.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ type agent struct {
275275

276276
// Immortal streams
277277
immortalStreamsManager *immortalstreams.Manager
278-
immortalStreamsDialer *immortalstreams.SmartDialer
278+
immortalStreamsDialer *immortalstreams.LocalDialer
279279
}
280280

281281
func (a *agent) TailnetConn() *tailnet.Conn {
@@ -343,8 +343,8 @@ func (a *agent) init() {
343343

344344
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
345345

346-
// Initialize immortal streams manager with smart dialer
347-
a.immortalStreamsDialer = immortalstreams.NewSmartDialerWithoutTailnet(a.logger.Named("immortal-streams"))
346+
// Initialize immortal streams manager with local dialer
347+
a.immortalStreamsDialer = immortalstreams.NewLocalDialer(a.logger.Named("immortal-streams"))
348348
a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), a.immortalStreamsDialer)
349349

350350
a.reconnectingPTYServer = reconnectingpty.NewServer(
@@ -1303,10 +1303,9 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
13031303
a.network = network
13041304
a.statsReporter = newStatsReporter(a.logger, network, a)
13051305

1306-
// Update the immortal streams dialer with the new tailnet connection
1306+
// Update the immortal streams components with the new tailnet connection
13071307
if a.immortalStreamsDialer != nil {
1308-
agentAddr := tailnet.TailscaleServicePrefix.AddrFromUUID(manifest.AgentID)
1309-
a.immortalStreamsDialer.UpdateTailnetConn(network, agentAddr)
1308+
a.immortalStreamsDialer.UpdateTailnetConn(network)
13101309
}
13111310
}
13121311
a.closeMutex.Unlock()

agent/immortalstreams/handler.go

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -105,31 +105,18 @@ func (h *Handler) listStreams(w http.ResponseWriter, r *http.Request) {
105105
// handleStreamRequest handles GET requests for a specific stream and returns stream info or handles WebSocket upgrades
106106
func (h *Handler) handleStreamRequest(w http.ResponseWriter, r *http.Request) {
107107
ctx := r.Context()
108-
streamID := getStreamID(ctx)
108+
_ = getStreamID(ctx)
109109

110-
// Check if this is a WebSocket upgrade request by looking for WebSocket headers
110+
// Require WebSocket upgrade for connection/reconnect
111111
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
112112
h.handleUpgrade(w, r)
113113
return
114114
}
115115

116-
// Otherwise, return stream info
117-
stream, ok := h.manager.GetStream(streamID)
118-
if !ok {
119-
httpapi.Write(ctx, w, http.StatusNotFound, codersdk.Response{
120-
Message: "Stream not found",
121-
})
122-
return
123-
}
124-
125-
// Include the current remote reader sequence number in response header
126-
// for consistency with WebSocket upgrade behavior
127-
if stream != nil && stream.GetPipe() != nil {
128-
readerSeqNum := stream.GetPipe().ReaderSequenceNum()
129-
w.Header().Set(codersdk.HeaderImmortalStreamSequenceNum, strconv.FormatUint(readerSeqNum, 10))
130-
}
131-
132-
httpapi.Write(ctx, w, http.StatusOK, stream.ToAPI())
116+
// Otherwise, return bad request since only reconnect is supported
117+
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
118+
Message: "Upgrade required for immortal stream",
119+
})
133120
}
134121

135122
// deleteStream deletes a stream
@@ -197,6 +184,9 @@ func (h *Handler) handleUpgrade(w http.ResponseWriter, r *http.Request) {
197184
connCtx, cancel := context.WithCancel(ctx)
198185
defer cancel()
199186

187+
// Keep the WebSocket connection alive with periodic pings
188+
go httpapi.Heartbeat(connCtx, conn)
189+
200190
// Ensure WebSocket is closed when this function returns
201191
defer func() {
202192
conn.Close(websocket.StatusNormalClosure, "connection closed")
@@ -234,7 +224,6 @@ func (h *Handler) handleUpgrade(w http.ResponseWriter, r *http.Request) {
234224
<-connCtx.Done()
235225
}
236226

237-
// wsConn adapts a WebSocket connection to io.ReadWriteCloser
238227
type wsConn struct {
239228
conn *websocket.Conn
240229
logger slog.Logger

agent/immortalstreams/handler_test.go

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package immortalstreams_test
22

33
import (
4+
"bufio"
45
"bytes"
56
"context"
67
"encoding/json"
@@ -9,6 +10,9 @@ import (
910
"net"
1011
"net/http"
1112
"net/http/httptest"
13+
"net/url"
14+
"strconv"
15+
"strings"
1216
"testing"
1317

1418
"github.com/go-chi/chi/v5"
@@ -20,7 +24,6 @@ import (
2024
"github.com/coder/coder/v2/agent/immortalstreams"
2125
"github.com/coder/coder/v2/codersdk"
2226
"github.com/coder/coder/v2/testutil"
23-
"github.com/coder/websocket"
2427
)
2528

2629
func TestImmortalStreamsHandler_CreateStream(t *testing.T) {
@@ -341,20 +344,17 @@ func TestImmortalStreamsHandler_DeleteStream(t *testing.T) {
341344
assert.Equal(t, http.StatusNotFound, w.Code)
342345
}
343346

344-
func TestImmortalStreamsHandler_Upgrade(t *testing.T) {
347+
func TestImmortalStreamsHandler_RawUpgrade(t *testing.T) {
345348
t.Parallel()
346349

347350
ctx := testutil.Context(t, testutil.WaitShort)
348351
logger := slogtest.Make(t, nil)
349352

350-
// Start a test server
353+
// Start a test server providing echo on accepted connections
351354
listener, err := net.Listen("tcp", "localhost:0")
352355
require.NoError(t, err)
353356
defer listener.Close()
354-
355357
port := listener.Addr().(*net.TCPAddr).Port
356-
357-
// Accept connections in the background
358358
go func() {
359359
for {
360360
conn, err := listener.Accept()
@@ -363,64 +363,77 @@ func TestImmortalStreamsHandler_Upgrade(t *testing.T) {
363363
}
364364
go func() {
365365
defer conn.Close()
366-
// Echo server
367366
_, _ = io.Copy(conn, conn)
368367
}()
369368
}
370369
}()
371370

372-
// Create handler
371+
// Create handler and server
373372
dialer := &testDialer{}
374373
manager := immortalstreams.New(logger, dialer)
375374
defer manager.Close()
376-
377375
handler := immortalstreams.NewHandler(logger, manager)
378-
379-
// Create a test server
380376
server := httptest.NewServer(handler.Routes())
381377
defer server.Close()
382378

383379
// Create a stream
384380
stream, err := manager.CreateStream(ctx, port)
385381
require.NoError(t, err)
386382

387-
// Connect with WebSocket
388-
wsURL := fmt.Sprintf("ws%s/%s",
389-
server.URL[4:], // Remove "http" prefix
383+
// Dial server and send raw HTTP/1.1 Upgrade request
384+
u, err := url.Parse(server.URL)
385+
require.NoError(t, err)
386+
c, err := net.Dial("tcp", u.Host)
387+
require.NoError(t, err)
388+
defer c.Close()
389+
390+
req := fmt.Sprintf("GET /%s HTTP/1.1\r\nHost: %s\r\nUpgrade: %s\r\nConnection: Upgrade\r\n%s: 0\r\n\r\n",
390391
stream.ID,
392+
u.Host,
393+
codersdk.UpgradeImmortalStream,
394+
codersdk.HeaderImmortalStreamSequenceNum,
391395
)
392-
393-
conn, resp, err := websocket.Dial(ctx, wsURL, &websocket.DialOptions{
394-
HTTPHeader: http.Header{
395-
codersdk.HeaderImmortalStreamSequenceNum: []string{"0"},
396-
},
397-
})
398-
defer func() {
399-
if resp != nil && resp.Body != nil {
400-
_ = resp.Body.Close()
401-
}
402-
}()
396+
_, err = c.Write([]byte(req))
403397
require.NoError(t, err)
404-
defer conn.Close(websocket.StatusNormalClosure, "")
405-
406-
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
407398

408-
// Send some data
409-
testData := []byte("hello world")
410-
err = conn.Write(ctx, websocket.MessageBinary, testData)
399+
br := bufio.NewReader(c)
400+
status, err := br.ReadString('\n')
411401
require.NoError(t, err)
402+
assert.True(t, strings.HasPrefix(status, "HTTP/1.1 101 ") || strings.HasPrefix(status, "HTTP/1.0 101 "))
403+
404+
// Read headers until blank line and find sequence header
405+
seenSeq := false
406+
for {
407+
line, rerr := br.ReadString('\n')
408+
require.NoError(t, rerr)
409+
if line == "\r\n" {
410+
break
411+
}
412+
if i := strings.IndexByte(line, ':'); i > 0 {
413+
k := strings.TrimSpace(line[:i])
414+
v := strings.TrimSpace(strings.TrimSuffix(line[i+1:], "\r\n"))
415+
if strings.EqualFold(k, codersdk.HeaderImmortalStreamSequenceNum) {
416+
_, _ = strconv.ParseUint(v, 10, 64)
417+
seenSeq = true
418+
}
419+
}
420+
}
421+
assert.True(t, seenSeq)
412422

413-
// Read echoed data
414-
msgType, data, err := conn.Read(ctx)
423+
// Echo round-trip over upgraded connection
424+
payload := []byte("hello world")
425+
_, err = c.Write(payload)
426+
require.NoError(t, err)
427+
buf := make([]byte, len(payload))
428+
_, err = io.ReadFull(br, buf)
415429
require.NoError(t, err)
416-
assert.Equal(t, websocket.MessageBinary, msgType)
417-
assert.Equal(t, testData, data)
430+
assert.Equal(t, payload, buf)
418431
}
419432

420433
// Test helpers
421434

422435
type handlerTestDialer struct{}
423436

424-
func (*handlerTestDialer) DialContext(_ context.Context, network, address string) (net.Conn, error) {
425-
return net.Dial(network, address)
437+
func (*handlerTestDialer) DialContext(_ context.Context, address string) (net.Conn, error) {
438+
return net.Dial("tcp", address)
426439
}

agent/immortalstreams/localdialer.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package immortalstreams
2+
3+
import (
4+
"context"
5+
"net"
6+
"strconv"
7+
"strings"
8+
9+
"golang.org/x/xerrors"
10+
11+
"cdr.dev/slog"
12+
"github.com/coder/coder/v2/tailnet"
13+
)
14+
15+
// LocalDialer is a Dialer implementation that intelligently chooses between
16+
// local network connections and tailscale network connections based on the
17+
// target address and context.
18+
//
19+
// It follows a similar approach to tailnet.forwardTCP and the netstack
20+
// GetTCPHandlerForFlow pattern, providing context-aware routing for different
21+
// connection types.
22+
//
23+
// The LocalDialer can be created without a tailnet connection and will fall back
24+
// to local dialing until UpdateTailnetConn is called.
25+
type LocalDialer struct {
26+
logger slog.Logger
27+
28+
// localDialer handles traditional local network connections
29+
localDialer *net.Dialer
30+
31+
// tailnetConn provides access to the tailscale network (can be nil initially)
32+
tailnetConn *tailnet.Conn
33+
}
34+
35+
// NewLocalDialer creates a new LocalDialer that can route connections
36+
// intelligently between local and tailscale networks.
37+
// The tailnetConn will be nil initially and updated later
38+
// via UpdateTailnetConn.
39+
func NewLocalDialer(logger slog.Logger) *LocalDialer {
40+
return &LocalDialer{
41+
logger: logger.Named("local-dialer"),
42+
localDialer: &net.Dialer{},
43+
tailnetConn: nil,
44+
}
45+
}
46+
47+
// UpdateTailnetConn updates the tailnet connection and agent address.
48+
// This allows the LocalDialer to start using tailscale network routing.
49+
func (d *LocalDialer) UpdateTailnetConn(tailnetConn *tailnet.Conn) {
50+
d.tailnetConn = tailnetConn
51+
d.logger.Debug(context.Background(), "updated tailnet connection")
52+
}
53+
54+
// DialContext implements the Dialer interface that tries to connect to an
55+
// in-process listener on the requested TCP port via tailnet.Conn.forwardTCP using net.Pipe.
56+
// If no in-process listener is found, it falls back to dialing localhost.
57+
// Only TCP connections are supported as immortal streams only need TCP.
58+
func (d *LocalDialer) DialContext(ctx context.Context, address string) (net.Conn, error) {
59+
d.logger.Debug(ctx, "dialing connection",
60+
slog.F("address", address))
61+
62+
// Parse the address and extract port for potential in-process listener dial
63+
// We ignore the host part of the address; we always dial localhost.
64+
_, portStr, err := net.SplitHostPort(address)
65+
if err != nil {
66+
return nil, xerrors.Errorf("parse address %q: %w", address, err)
67+
}
68+
69+
port, err := strconv.ParseUint(portStr, 10, 16)
70+
if err != nil {
71+
return nil, xerrors.Errorf("parse port %q: %w", portStr, err)
72+
}
73+
74+
// If we have a tailnet connection, first attempt to connect to an
75+
// in-process listener on the requested TCP port via tailnet.Conn.forwardTCP using net.Pipe.
76+
if d.tailnetConn != nil {
77+
if c := d.tailnetConn.DialInternalTCP(ctx, uint16(port)); c != nil {
78+
d.logger.Debug(ctx, "connected to internal tailnet listener via pipe", slog.F("port", port))
79+
return c, nil
80+
}
81+
}
82+
83+
return d.dialLocal(ctx, net.JoinHostPort("localhost", portStr))
84+
}
85+
86+
// dialLocal uses the local network dialer
87+
func (d *LocalDialer) dialLocal(ctx context.Context, address string) (net.Conn, error) {
88+
conn, err := d.localDialer.DialContext(ctx, "tcp", address)
89+
if err != nil {
90+
d.logger.Debug(ctx, "local dial failed", slog.Error(err))
91+
return nil, err
92+
}
93+
d.logger.Debug(ctx, "local dial succeeded")
94+
return conn, nil
95+
}
96+
97+
// isConnectionRefusedError checks if an error indicates a connection was refused
98+
// This is used to preserve the semantics of the original error handling
99+
func isConnectionRefusedError(err error) bool {
100+
if err == nil {
101+
return false
102+
}
103+
104+
// This uses the same logic as the original manager.go
105+
errStr := err.Error()
106+
return strings.Contains(errStr, "connection refused") ||
107+
strings.Contains(errStr, "connectex: No connection could be made because the target machine actively refused it") ||
108+
strings.Contains(errStr, "actively refused")
109+
}

0 commit comments

Comments
 (0)