Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b60589d

Browse files
committed
chore: refactor TestServer_X11 to use inproc networking
1 parent 1e438a6 commit b60589d

File tree

3 files changed

+72
-28
lines changed

3 files changed

+72
-28
lines changed

agent/agentssh/agentssh.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ type Config struct {
116116
// Experimental: allow connecting to running containers if
117117
// CODER_AGENT_DEVCONTAINERS_ENABLE=true.
118118
ExperimentalDevContainersEnabled bool
119+
// X11Net allows overriding the networking implementation used for X11
120+
// forwarding listeners. When nil, a default implementation backed by the
121+
// standard library networking package is used.
122+
X11Net X11Network
119123
}
120124

121125
type Server struct {
@@ -195,6 +199,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
195199
displayOffset: *config.X11DisplayOffset,
196200
sessions: make(map[*x11Session]struct{}),
197201
connections: make(map[net.Conn]struct{}),
202+
network: func() X11Network {
203+
if config.X11Net != nil {
204+
return config.X11Net
205+
}
206+
return osNet{}
207+
}(),
198208
},
199209
}
200210

agent/agentssh/x11.go

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"io"
10+
"math"
1011
"net"
1112
"os"
1213
"path/filepath"
@@ -37,12 +38,30 @@ const (
3738
X11MaxPort = X11StartPort + X11MaxDisplays
3839
)
3940

41+
// X11Network abstracts the creation of network listeners for X11 forwarding.
42+
// It is intended mainly for testing; production code uses the default
43+
// implementation backed by the operating system networking stack.
44+
type X11Network interface {
45+
Listen(network, address string) (net.Listener, error)
46+
}
47+
48+
// osNet is the default X11Network implementation that uses the standard
49+
// library network stack.
50+
type osNet struct{}
51+
52+
func (osNet) Listen(network, address string) (net.Listener, error) {
53+
return net.Listen(network, address)
54+
}
55+
4056
type x11Forwarder struct {
4157
logger slog.Logger
4258
x11HandlerErrors *prometheus.CounterVec
4359
fs afero.Fs
4460
displayOffset int
4561

62+
// network creates X11 listener sockets. Defaults to osNet{}.
63+
network X11Network
64+
4665
mu sync.Mutex
4766
sessions map[*x11Session]struct{}
4867
connections map[net.Conn]struct{}
@@ -145,26 +164,35 @@ func (x *x11Forwarder) listenForConnections(ctx context.Context, session *x11Ses
145164
x.cleanSession(session)
146165
}
147166

148-
tcpConn, ok := conn.(*net.TCPConn)
149-
if !ok {
150-
x.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
151-
_ = conn.Close()
152-
continue
167+
var originAddr string
168+
var originPort uint32
169+
170+
if tcpConn, ok := conn.(*net.TCPConn); ok {
171+
if tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr); ok {
172+
originAddr = tcpAddr.IP.String()
173+
// #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
174+
originPort = uint32(tcpAddr.Port)
175+
}
153176
}
154-
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
155-
if !ok {
156-
x.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
157-
_ = conn.Close()
158-
continue
177+
// Fallback values for in-memory or non-TCP connections.
178+
if originAddr == "" {
179+
originAddr = "127.0.0.1"
180+
}
181+
if originPort == 0 {
182+
p := X11StartPort + session.display
183+
if p > math.MaxUint32 {
184+
panic("overflow")
185+
}
186+
// #nosec G115 - Safe conversion as port number is within uint32 range
187+
originPort = uint32(p)
159188
}
160189

161190
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
162191
OriginatorAddress string
163192
OriginatorPort uint32
164193
}{
165-
OriginatorAddress: tcpAddr.IP.String(),
166-
// #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
167-
OriginatorPort: uint32(tcpAddr.Port),
194+
OriginatorAddress: originAddr,
195+
OriginatorPort: originPort,
168196
}))
169197
if err != nil {
170198
x.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
@@ -281,13 +309,13 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
281309
// createX11Listener creates a listener for X11 forwarding, it will use
282310
// the next available port starting from X11StartPort and displayOffset.
283311
func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) {
284-
var lc net.ListenConfig
285312
// Look for an open port to listen on.
286313
for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ {
287314
if ctx.Err() != nil {
288315
return nil, -1, ctx.Err()
289316
}
290-
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
317+
318+
ln, err = x.network.Listen("tcp", fmt.Sprintf("localhost:%d", port))
291319
if err == nil {
292320
display = port - X11StartPort
293321
return ln, display, nil

agent/agentssh/x11_test.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package agentssh_test
33
import (
44
"bufio"
55
"bytes"
6-
"context"
76
"encoding/hex"
87
"fmt"
98
"net"
@@ -32,10 +31,19 @@ func TestServer_X11(t *testing.T) {
3231
t.Skip("X11 forwarding is only supported on Linux")
3332
}
3433

35-
ctx := context.Background()
34+
ctx := testutil.Context(t, testutil.WaitShort)
3635
logger := testutil.Logger(t)
3736
fs := afero.NewMemMapFs()
38-
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
37+
38+
// Use in-process networking for X11 forwarding.
39+
inproc := testutil.NewInProcNet()
40+
41+
// Create server config with custom X11 listener.
42+
cfg := &agentssh.Config{
43+
X11Net: inproc,
44+
}
45+
46+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg)
3947
require.NoError(t, err)
4048
defer s.Close()
4149
err = s.UpdateHostSigner(42)
@@ -93,17 +101,15 @@ func TestServer_X11(t *testing.T) {
93101

94102
x11Chans := c.HandleChannelOpen("x11")
95103
payload := "hello world"
96-
require.Eventually(t, func() bool {
97-
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
98-
if err == nil {
99-
_, err = conn.Write([]byte(payload))
100-
assert.NoError(t, err)
101-
_ = conn.Close()
102-
}
103-
return err == nil
104-
}, testutil.WaitShort, testutil.IntervalFast)
104+
go func() {
105+
conn, err := inproc.Dial(ctx, testutil.NewAddr("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber)))
106+
assert.NoError(t, err)
107+
_, err = conn.Write([]byte(payload))
108+
assert.NoError(t, err)
109+
_ = conn.Close()
110+
}()
105111

106-
x11 := <-x11Chans
112+
x11 := testutil.RequireReceive(ctx, t, x11Chans)
107113
ch, reqs, err := x11.Accept()
108114
require.NoError(t, err)
109115
go gossh.DiscardRequests(reqs)

0 commit comments

Comments
 (0)