Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a5bfb20

Browse files
authored
chore: refactor TestServer_X11 to use inproc networking (#18564)
relates to #18263 Refactors the x11Forwarder to accept a networking `interface` that we can fake out for testing. This isolates the unit tests from other processes listening in the port range used by X11 forwarding. This will become extremely important in up-stack PRs where we listen on every port in the range and need to control which ports have conflicts.
1 parent abcf3df commit a5bfb20

File tree

3 files changed

+63
-28
lines changed

3 files changed

+63
-28
lines changed

agent/agentssh/agentssh.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ type Config struct {
117117
// Note that this is different from the devcontainers feature, which uses
118118
// subagents.
119119
ExperimentalContainers bool
120+
// X11Net allows overriding the networking implementation used for X11
121+
// forwarding listeners. When nil, a default implementation backed by the
122+
// standard library networking package is used.
123+
X11Net X11Network
120124
}
121125

122126
type Server struct {
@@ -196,6 +200,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
196200
displayOffset: *config.X11DisplayOffset,
197201
sessions: make(map[*x11Session]struct{}),
198202
connections: make(map[net.Conn]struct{}),
203+
network: func() X11Network {
204+
if config.X11Net != nil {
205+
return config.X11Net
206+
}
207+
return osNet{}
208+
}(),
199209
},
200210
}
201211

agent/agentssh/x11.go

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,30 @@ const (
3737
X11MaxPort = X11StartPort + X11MaxDisplays
3838
)
3939

40+
// X11Network abstracts the creation of network listeners for X11 forwarding.
41+
// It is intended mainly for testing; production code uses the default
42+
// implementation backed by the operating system networking stack.
43+
type X11Network interface {
44+
Listen(network, address string) (net.Listener, error)
45+
}
46+
47+
// osNet is the default X11Network implementation that uses the standard
48+
// library network stack.
49+
type osNet struct{}
50+
51+
func (osNet) Listen(network, address string) (net.Listener, error) {
52+
return net.Listen(network, address)
53+
}
54+
4055
type x11Forwarder struct {
4156
logger slog.Logger
4257
x11HandlerErrors *prometheus.CounterVec
4358
fs afero.Fs
4459
displayOffset int
4560

61+
// network creates X11 listener sockets. Defaults to osNet{}.
62+
network X11Network
63+
4664
mu sync.Mutex
4765
sessions map[*x11Session]struct{}
4866
connections map[net.Conn]struct{}
@@ -147,26 +165,27 @@ func (x *x11Forwarder) listenForConnections(
147165
x.closeAndRemoveSession(session)
148166
}
149167

150-
tcpConn, ok := conn.(*net.TCPConn)
151-
if !ok {
152-
x.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
153-
_ = conn.Close()
154-
continue
168+
var originAddr string
169+
var originPort uint32
170+
171+
if tcpConn, ok := conn.(*net.TCPConn); ok {
172+
if tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr); ok {
173+
originAddr = tcpAddr.IP.String()
174+
// #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
175+
originPort = uint32(tcpAddr.Port)
176+
}
155177
}
156-
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
157-
if !ok {
158-
x.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
159-
_ = conn.Close()
160-
continue
178+
// Fallback values for in-memory or non-TCP connections.
179+
if originAddr == "" {
180+
originAddr = "127.0.0.1"
161181
}
162182

163183
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
164184
OriginatorAddress string
165185
OriginatorPort uint32
166186
}{
167-
OriginatorAddress: tcpAddr.IP.String(),
168-
// #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
169-
OriginatorPort: uint32(tcpAddr.Port),
187+
OriginatorAddress: originAddr,
188+
OriginatorPort: originPort,
170189
}))
171190
if err != nil {
172191
x.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
@@ -287,13 +306,13 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
287306
// createX11Listener creates a listener for X11 forwarding, it will use
288307
// the next available port starting from X11StartPort and displayOffset.
289308
func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) {
290-
var lc net.ListenConfig
291309
// Look for an open port to listen on.
292310
for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ {
293311
if ctx.Err() != nil {
294312
return nil, -1, ctx.Err()
295313
}
296-
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
314+
315+
ln, err = x.network.Listen("tcp", fmt.Sprintf("localhost:%d", port))
297316
if err == nil {
298317
display = port - X11StartPort
299318
return ln, display, nil

agent/agentssh/x11_test.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package agentssh_test
33
import (
44
"bufio"
55
"bytes"
6-
"context"
76
"encoding/hex"
87
"fmt"
98
"net"
@@ -32,10 +31,19 @@ func TestServer_X11(t *testing.T) {
3231
t.Skip("X11 forwarding is only supported on Linux")
3332
}
3433

35-
ctx := context.Background()
34+
ctx := testutil.Context(t, testutil.WaitShort)
3635
logger := testutil.Logger(t)
3736
fs := afero.NewMemMapFs()
38-
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
37+
38+
// Use in-process networking for X11 forwarding.
39+
inproc := testutil.NewInProcNet()
40+
41+
// Create server config with custom X11 listener.
42+
cfg := &agentssh.Config{
43+
X11Net: inproc,
44+
}
45+
46+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg)
3947
require.NoError(t, err)
4048
defer s.Close()
4149
err = s.UpdateHostSigner(42)
@@ -93,17 +101,15 @@ func TestServer_X11(t *testing.T) {
93101

94102
x11Chans := c.HandleChannelOpen("x11")
95103
payload := "hello world"
96-
require.Eventually(t, func() bool {
97-
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
98-
if err == nil {
99-
_, err = conn.Write([]byte(payload))
100-
assert.NoError(t, err)
101-
_ = conn.Close()
102-
}
103-
return err == nil
104-
}, testutil.WaitShort, testutil.IntervalFast)
104+
go func() {
105+
conn, err := inproc.Dial(ctx, testutil.NewAddr("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber)))
106+
assert.NoError(t, err)
107+
_, err = conn.Write([]byte(payload))
108+
assert.NoError(t, err)
109+
_ = conn.Close()
110+
}()
105111

106-
x11 := <-x11Chans
112+
x11 := testutil.RequireReceive(ctx, t, x11Chans)
107113
ch, reqs, err := x11.Accept()
108114
require.NoError(t, err)
109115
go gossh.DiscardRequests(reqs)

0 commit comments

Comments
 (0)