Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 011e763

Browse files
committed
chore: refactor TestServer_X11 to use inproc networking
1 parent abcf3df commit 011e763

File tree

3 files changed

+63
-28
lines changed

3 files changed

+63
-28
lines changed

agent/agentssh/agentssh.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ type Config struct {
117117
// Note that this is different from the devcontainers feature, which uses
118118
// subagents.
119119
ExperimentalContainers bool
120+
// X11Net allows overriding the networking implementation used for X11
121+
// forwarding listeners. When nil, a default implementation backed by the
122+
// standard library networking package is used.
123+
X11Net X11Network
120124
}
121125

122126
type Server struct {
@@ -196,6 +200,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
196200
displayOffset: *config.X11DisplayOffset,
197201
sessions: make(map[*x11Session]struct{}),
198202
connections: make(map[net.Conn]struct{}),
203+
network: func() X11Network {
204+
if config.X11Net != nil {
205+
return config.X11Net
206+
}
207+
return osNet{}
208+
}(),
199209
},
200210
}
201211

agent/agentssh/x11.go

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,30 @@ const (
3737
X11MaxPort = X11StartPort + X11MaxDisplays
3838
)
3939

40+
// X11Network abstracts the creation of network listeners for X11 forwarding.
41+
// It is intended mainly for testing; production code uses the default
42+
// implementation backed by the operating system networking stack.
43+
type X11Network interface {
44+
Listen(network, address string) (net.Listener, error)
45+
}
46+
47+
// osNet is the default X11Network implementation that uses the standard
48+
// library network stack.
49+
type osNet struct{}
50+
51+
func (osNet) Listen(network, address string) (net.Listener, error) {
52+
return net.Listen(network, address)
53+
}
54+
4055
type x11Forwarder struct {
4156
logger slog.Logger
4257
x11HandlerErrors *prometheus.CounterVec
4358
fs afero.Fs
4459
displayOffset int
4560

61+
// network creates X11 listener sockets. Defaults to osNet{}.
62+
network X11Network
63+
4664
mu sync.Mutex
4765
sessions map[*x11Session]struct{}
4866
connections map[net.Conn]struct{}
@@ -147,26 +165,27 @@ func (x *x11Forwarder) listenForConnections(
147165
x.closeAndRemoveSession(session)
148166
}
149167

150-
tcpConn, ok := conn.(*net.TCPConn)
151-
if !ok {
152-
x.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
153-
_ = conn.Close()
154-
continue
168+
var originAddr string
169+
var originPort uint32
170+
171+
if tcpConn, ok := conn.(*net.TCPConn); ok {
172+
if tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr); ok {
173+
originAddr = tcpAddr.IP.String()
174+
// #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
175+
originPort = uint32(tcpAddr.Port)
176+
}
155177
}
156-
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
157-
if !ok {
158-
x.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
159-
_ = conn.Close()
160-
continue
178+
// Fallback values for in-memory or non-TCP connections.
179+
if originAddr == "" {
180+
originAddr = "127.0.0.1"
161181
}
162182

163183
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
164184
OriginatorAddress string
165185
OriginatorPort uint32
166186
}{
167-
OriginatorAddress: tcpAddr.IP.String(),
168-
// #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
169-
OriginatorPort: uint32(tcpAddr.Port),
187+
OriginatorAddress: originAddr,
188+
OriginatorPort: originPort,
170189
}))
171190
if err != nil {
172191
x.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
@@ -287,13 +306,13 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
287306
// createX11Listener creates a listener for X11 forwarding, it will use
288307
// the next available port starting from X11StartPort and displayOffset.
289308
func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) {
290-
var lc net.ListenConfig
291309
// Look for an open port to listen on.
292310
for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ {
293311
if ctx.Err() != nil {
294312
return nil, -1, ctx.Err()
295313
}
296-
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
314+
315+
ln, err = x.network.Listen("tcp", fmt.Sprintf("localhost:%d", port))
297316
if err == nil {
298317
display = port - X11StartPort
299318
return ln, display, nil

agent/agentssh/x11_test.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package agentssh_test
33
import (
44
"bufio"
55
"bytes"
6-
"context"
76
"encoding/hex"
87
"fmt"
98
"net"
@@ -32,10 +31,19 @@ func TestServer_X11(t *testing.T) {
3231
t.Skip("X11 forwarding is only supported on Linux")
3332
}
3433

35-
ctx := context.Background()
34+
ctx := testutil.Context(t, testutil.WaitShort)
3635
logger := testutil.Logger(t)
3736
fs := afero.NewMemMapFs()
38-
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
37+
38+
// Use in-process networking for X11 forwarding.
39+
inproc := testutil.NewInProcNet()
40+
41+
// Create server config with custom X11 listener.
42+
cfg := &agentssh.Config{
43+
X11Net: inproc,
44+
}
45+
46+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg)
3947
require.NoError(t, err)
4048
defer s.Close()
4149
err = s.UpdateHostSigner(42)
@@ -93,17 +101,15 @@ func TestServer_X11(t *testing.T) {
93101

94102
x11Chans := c.HandleChannelOpen("x11")
95103
payload := "hello world"
96-
require.Eventually(t, func() bool {
97-
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
98-
if err == nil {
99-
_, err = conn.Write([]byte(payload))
100-
assert.NoError(t, err)
101-
_ = conn.Close()
102-
}
103-
return err == nil
104-
}, testutil.WaitShort, testutil.IntervalFast)
104+
go func() {
105+
conn, err := inproc.Dial(ctx, testutil.NewAddr("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber)))
106+
assert.NoError(t, err)
107+
_, err = conn.Write([]byte(payload))
108+
assert.NoError(t, err)
109+
_ = conn.Close()
110+
}()
105111

106-
x11 := <-x11Chans
112+
x11 := testutil.RequireReceive(ctx, t, x11Chans)
107113
ch, reqs, err := x11.Accept()
108114
require.NoError(t, err)
109115
go gossh.DiscardRequests(reqs)

0 commit comments

Comments
 (0)