Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b0e1cef

Browse files
committed
smart dialer
1 parent 76489db commit b0e1cef

File tree

3 files changed

+481
-2
lines changed

3 files changed

+481
-2
lines changed

agent/agent.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ type agent struct {
275275

276276
// Immortal streams
277277
immortalStreamsManager *immortalstreams.Manager
278+
immortalStreamsDialer *immortalstreams.SmartDialer
278279
}
279280

280281
func (a *agent) TailnetConn() *tailnet.Conn {
@@ -342,8 +343,9 @@ func (a *agent) init() {
342343

343344
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
344345

345-
// Initialize immortal streams manager
346-
a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), &net.Dialer{})
346+
// Initialize immortal streams manager with smart dialer
347+
a.immortalStreamsDialer = immortalstreams.NewSmartDialerWithoutTailnet(a.logger.Named("immortal-streams"))
348+
a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), a.immortalStreamsDialer)
347349

348350
a.reconnectingPTYServer = reconnectingpty.NewServer(
349351
a.logger.Named("reconnecting-pty"),
@@ -1300,6 +1302,12 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
13001302
if !closing {
13011303
a.network = network
13021304
a.statsReporter = newStatsReporter(a.logger, network, a)
1305+
1306+
// Update the immortal streams dialer with the new tailnet connection
1307+
if a.immortalStreamsDialer != nil {
1308+
agentAddr := tailnet.TailscaleServicePrefix.AddrFromUUID(manifest.AgentID)
1309+
a.immortalStreamsDialer.UpdateTailnetConn(network, agentAddr)
1310+
}
13031311
}
13041312
a.closeMutex.Unlock()
13051313
if closing {

agent/immortalstreams/smartdialer.go

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
package immortalstreams
2+
3+
import (
4+
"context"
5+
"net"
6+
"net/netip"
7+
"strconv"
8+
"strings"
9+
10+
"golang.org/x/xerrors"
11+
12+
"cdr.dev/slog"
13+
"github.com/coder/coder/v2/tailnet"
14+
)
15+
16+
// SmartDialer is a Dialer implementation that intelligently chooses between
17+
// local network connections and tailscale network connections based on the
18+
// target address and context.
19+
//
20+
// It follows a similar approach to tailnet.forwardTCP and the netstack
21+
// GetTCPHandlerForFlow pattern, providing context-aware routing for different
22+
// connection types.
23+
//
24+
// The SmartDialer can be created without a tailnet connection and will fall back
25+
// to local dialing until UpdateTailnetConn is called.
26+
type SmartDialer struct {
27+
logger slog.Logger
28+
29+
// localDialer handles traditional local network connections
30+
localDialer *net.Dialer
31+
32+
// tailnetConn provides access to the tailscale network (can be nil initially)
33+
tailnetConn *tailnet.Conn
34+
35+
// agentAddr is the agent's own tailscale address (can be zero initially)
36+
agentAddr netip.Addr
37+
}
38+
39+
// NewSmartDialer creates a new SmartDialer that can route connections
40+
// intelligently between local and tailscale networks.
41+
// The tailnetConn and agentAddr can be nil/zero initially and updated later
42+
// via UpdateTailnetConn.
43+
func NewSmartDialer(logger slog.Logger, tailnetConn *tailnet.Conn, agentAddr netip.Addr) *SmartDialer {
44+
return &SmartDialer{
45+
logger: logger.Named("smart-dialer"),
46+
localDialer: &net.Dialer{},
47+
tailnetConn: tailnetConn,
48+
agentAddr: agentAddr,
49+
}
50+
}
51+
52+
// NewSmartDialerWithoutTailnet creates a new SmartDialer that initially only
53+
// supports local dialing. The tailnet connection can be added later via
54+
// UpdateTailnetConn.
55+
func NewSmartDialerWithoutTailnet(logger slog.Logger) *SmartDialer {
56+
return NewSmartDialer(logger, nil, netip.Addr{})
57+
}
58+
59+
// UpdateTailnetConn updates the tailnet connection and agent address.
60+
// This allows the SmartDialer to start using tailscale network routing.
61+
func (d *SmartDialer) UpdateTailnetConn(tailnetConn *tailnet.Conn, agentAddr netip.Addr) {
62+
d.tailnetConn = tailnetConn
63+
d.agentAddr = agentAddr
64+
d.logger.Debug(context.Background(), "updated tailnet connection",
65+
slog.F("has_tailnet", tailnetConn != nil),
66+
slog.F("agent_addr", agentAddr.String()))
67+
}
68+
69+
// DialContext implements the Dialer interface with intelligent routing.
70+
// It analyzes the target address to determine whether to use local network
71+
// or tailscale network connections.
72+
// Only TCP connections are supported as immortal streams only need TCP.
73+
func (d *SmartDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
74+
d.logger.Debug(ctx, "dialing connection",
75+
slog.F("network", network),
76+
slog.F("address", address))
77+
78+
// Only support TCP for immortal streams
79+
if network != "tcp" {
80+
return nil, xerrors.Errorf("unsupported network type: %q (only tcp is supported)", network)
81+
}
82+
83+
// Parse the address to determine routing strategy
84+
host, portStr, err := net.SplitHostPort(address)
85+
if err != nil {
86+
return nil, xerrors.Errorf("parse address %q: %w", address, err)
87+
}
88+
89+
port, err := strconv.ParseUint(portStr, 10, 16)
90+
if err != nil {
91+
return nil, xerrors.Errorf("parse port %q: %w", portStr, err)
92+
}
93+
94+
strategy := d.determineDialStrategy(host, uint16(port))
95+
96+
d.logger.Debug(ctx, "determined dial strategy",
97+
slog.F("strategy", strategy),
98+
slog.F("host", host),
99+
slog.F("port", port))
100+
101+
switch strategy {
102+
case dialStrategyLocal:
103+
return d.dialLocal(ctx, network, address)
104+
case dialStrategyTailscale:
105+
return d.dialTailscale(ctx, network, host, uint16(port))
106+
case dialStrategyFallback:
107+
return d.dialWithFallback(ctx, network, address, host, uint16(port))
108+
default:
109+
return nil, xerrors.Errorf("unknown dial strategy: %v", strategy)
110+
}
111+
}
112+
113+
type dialStrategy int
114+
115+
const (
116+
dialStrategyLocal dialStrategy = iota
117+
dialStrategyTailscale
118+
dialStrategyFallback
119+
)
120+
121+
// determineDialStrategy analyzes the target address and determines the best
122+
// routing strategy based on the host and port.
123+
func (d *SmartDialer) determineDialStrategy(host string, port uint16) dialStrategy {
124+
// Parse the host to check if it's an IP address
125+
if addr, err := netip.ParseAddr(host); err == nil {
126+
// It's an IP address - analyze it
127+
return d.analyzeIPAddress(addr, port)
128+
}
129+
130+
// It's a hostname - analyze it
131+
return d.analyzeHostname(host, port)
132+
}
133+
134+
// analyzeIPAddress determines the dial strategy for IP addresses
135+
func (d *SmartDialer) analyzeIPAddress(addr netip.Addr, _ uint16) dialStrategy {
136+
// Check if it's a localhost/loopback address
137+
if addr.IsLoopback() {
138+
d.logger.Debug(context.Background(), "detected loopback address, using local dial")
139+
return dialStrategyLocal
140+
}
141+
142+
// Check if it's a tailscale service prefix address
143+
if tailnet.TailscaleServicePrefix.AsNetip().Contains(addr) {
144+
d.logger.Debug(context.Background(), "detected tailscale service address, using tailscale dial")
145+
return dialStrategyTailscale
146+
}
147+
148+
// Check if it's a coder service prefix address
149+
if tailnet.CoderServicePrefix.AsNetip().Contains(addr) {
150+
d.logger.Debug(context.Background(), "detected coder service address, using tailscale dial")
151+
return dialStrategyTailscale
152+
}
153+
154+
// Check if it's a private/local network address
155+
if addr.IsPrivate() || addr.IsLinkLocalUnicast() {
156+
d.logger.Debug(context.Background(), "detected private address, trying local first with tailscale fallback")
157+
return dialStrategyFallback
158+
}
159+
160+
// For other addresses, prefer local dialing (might be reachable through local routing)
161+
d.logger.Debug(context.Background(), "unknown address type, trying local first with tailscale fallback")
162+
return dialStrategyFallback
163+
}
164+
165+
// analyzeHostname determines the dial strategy for hostnames
166+
func (d *SmartDialer) analyzeHostname(host string, _ uint16) dialStrategy {
167+
// Normalize hostname
168+
host = strings.ToLower(host)
169+
170+
// Check for localhost variants
171+
if host == "localhost" || host == "localhost.localdomain" {
172+
d.logger.Debug(context.Background(), "detected localhost hostname, using local dial")
173+
return dialStrategyLocal
174+
}
175+
176+
// Check for special tailscale hostnames (if any patterns emerge)
177+
if strings.HasSuffix(host, ".ts.net") || strings.HasSuffix(host, ".tailscale") {
178+
d.logger.Debug(context.Background(), "detected tailscale hostname, using tailscale dial")
179+
return dialStrategyTailscale
180+
}
181+
182+
// For other hostnames, try local first with fallback
183+
d.logger.Debug(context.Background(), "unknown hostname, trying local first with tailscale fallback")
184+
return dialStrategyFallback
185+
}
186+
187+
// dialLocal uses the local network dialer
188+
func (d *SmartDialer) dialLocal(ctx context.Context, network, address string) (net.Conn, error) {
189+
conn, err := d.localDialer.DialContext(ctx, network, address)
190+
if err != nil {
191+
d.logger.Debug(ctx, "local dial failed", slog.Error(err))
192+
return nil, err
193+
}
194+
d.logger.Debug(ctx, "local dial succeeded")
195+
return conn, nil
196+
}
197+
198+
// dialTailscale uses the tailscale network connection
199+
func (d *SmartDialer) dialTailscale(ctx context.Context, network, host string, port uint16) (net.Conn, error) {
200+
if d.tailnetConn == nil {
201+
return nil, xerrors.New("tailnet connection not available")
202+
}
203+
204+
// Parse the host as an IP address for tailscale dialing
205+
addr, err := netip.ParseAddr(host)
206+
if err != nil {
207+
return nil, xerrors.Errorf("parse tailscale address %q: %w", host, err)
208+
}
209+
210+
addrPort := netip.AddrPortFrom(addr, port)
211+
212+
if network != "tcp" {
213+
return nil, xerrors.Errorf("unsupported network type for tailscale dial: %q (only tcp is supported)", network)
214+
}
215+
216+
conn, err := d.tailnetConn.DialContextTCP(ctx, addrPort)
217+
if err != nil {
218+
d.logger.Debug(ctx, "tailscale TCP dial failed", slog.Error(err))
219+
return nil, err
220+
}
221+
d.logger.Debug(ctx, "tailscale TCP dial succeeded")
222+
return conn, nil
223+
}
224+
225+
// dialWithFallback tries local first, then falls back to tailscale if available
226+
func (d *SmartDialer) dialWithFallback(ctx context.Context, network, address, host string, port uint16) (net.Conn, error) {
227+
// Try local first
228+
localConn, localErr := d.dialLocal(ctx, network, address)
229+
if localErr == nil {
230+
d.logger.Debug(ctx, "fallback: local dial succeeded")
231+
return localConn, nil
232+
}
233+
234+
d.logger.Debug(ctx, "fallback: local dial failed, trying tailscale", slog.Error(localErr))
235+
236+
// If local failed and we have tailnet, try tailscale
237+
if d.tailnetConn != nil {
238+
// Try to parse as IP for tailscale dial
239+
if _, err := netip.ParseAddr(host); err == nil {
240+
tailscaleConn, tailscaleErr := d.dialTailscale(ctx, network, host, port)
241+
if tailscaleErr == nil {
242+
d.logger.Debug(ctx, "fallback: tailscale dial succeeded")
243+
return tailscaleConn, nil
244+
}
245+
d.logger.Debug(ctx, "fallback: tailscale dial also failed", slog.Error(tailscaleErr))
246+
247+
// Return the more specific error if both failed
248+
if isConnectionRefusedError(localErr) {
249+
return nil, localErr
250+
}
251+
return nil, tailscaleErr
252+
}
253+
}
254+
255+
// If we can't try tailscale or it's not available, return the local error
256+
d.logger.Debug(ctx, "fallback: only local dial attempted")
257+
return nil, localErr
258+
}
259+
260+
// isConnectionRefusedError checks if an error indicates a connection was refused
261+
// This is used to preserve the semantics of the original error handling
262+
func isConnectionRefusedError(err error) bool {
263+
if err == nil {
264+
return false
265+
}
266+
267+
// This uses the same logic as the original manager.go
268+
errStr := err.Error()
269+
return strings.Contains(errStr, "connection refused") ||
270+
strings.Contains(errStr, "connectex: No connection could be made because the target machine actively refused it") ||
271+
strings.Contains(errStr, "actively refused")
272+
}

0 commit comments

Comments
 (0)