|
7 | 7 | "errors"
|
8 | 8 | "fmt"
|
9 | 9 | "io"
|
| 10 | + "math" |
10 | 11 | "net"
|
11 | 12 | "os"
|
12 | 13 | "path/filepath"
|
@@ -37,12 +38,30 @@ const (
|
37 | 38 | X11MaxPort = X11StartPort + X11MaxDisplays
|
38 | 39 | )
|
39 | 40 |
|
| 41 | +// X11Network abstracts the creation of network listeners for X11 forwarding. |
| 42 | +// It is intended mainly for testing; production code uses the default |
| 43 | +// implementation backed by the operating system networking stack. |
| 44 | +type X11Network interface { |
| 45 | + Listen(network, address string) (net.Listener, error) |
| 46 | +} |
| 47 | + |
| 48 | +// osNet is the default X11Network implementation that uses the standard |
| 49 | +// library network stack. |
| 50 | +type osNet struct{} |
| 51 | + |
| 52 | +func (osNet) Listen(network, address string) (net.Listener, error) { |
| 53 | + return net.Listen(network, address) |
| 54 | +} |
| 55 | + |
40 | 56 | type x11Forwarder struct {
|
41 | 57 | logger slog.Logger
|
42 | 58 | x11HandlerErrors *prometheus.CounterVec
|
43 | 59 | fs afero.Fs
|
44 | 60 | displayOffset int
|
45 | 61 |
|
| 62 | + // network creates X11 listener sockets. Defaults to osNet{}. |
| 63 | + network X11Network |
| 64 | + |
46 | 65 | mu sync.Mutex
|
47 | 66 | sessions map[*x11Session]struct{}
|
48 | 67 | connections map[net.Conn]struct{}
|
@@ -145,26 +164,35 @@ func (x *x11Forwarder) listenForConnections(ctx context.Context, session *x11Ses
|
145 | 164 | x.cleanSession(session)
|
146 | 165 | }
|
147 | 166 |
|
148 |
| - tcpConn, ok := conn.(*net.TCPConn) |
149 |
| - if !ok { |
150 |
| - x.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn)) |
151 |
| - _ = conn.Close() |
152 |
| - continue |
| 167 | + var originAddr string |
| 168 | + var originPort uint32 |
| 169 | + |
| 170 | + if tcpConn, ok := conn.(*net.TCPConn); ok { |
| 171 | + if tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr); ok { |
| 172 | + originAddr = tcpAddr.IP.String() |
| 173 | + // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535) |
| 174 | + originPort = uint32(tcpAddr.Port) |
| 175 | + } |
153 | 176 | }
|
154 |
| - tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr) |
155 |
| - if !ok { |
156 |
| - x.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr())) |
157 |
| - _ = conn.Close() |
158 |
| - continue |
| 177 | + // Fallback values for in-memory or non-TCP connections. |
| 178 | + if originAddr == "" { |
| 179 | + originAddr = "127.0.0.1" |
| 180 | + } |
| 181 | + if originPort == 0 { |
| 182 | + p := X11StartPort + session.display |
| 183 | + if p > math.MaxUint32 { |
| 184 | + panic("overflow") |
| 185 | + } |
| 186 | + // #nosec G115 - Safe conversion as port number is within uint32 range |
| 187 | + originPort = uint32(p) |
159 | 188 | }
|
160 | 189 |
|
161 | 190 | channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
|
162 | 191 | OriginatorAddress string
|
163 | 192 | OriginatorPort uint32
|
164 | 193 | }{
|
165 |
| - OriginatorAddress: tcpAddr.IP.String(), |
166 |
| - // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535) |
167 |
| - OriginatorPort: uint32(tcpAddr.Port), |
| 194 | + OriginatorAddress: originAddr, |
| 195 | + OriginatorPort: originPort, |
168 | 196 | }))
|
169 | 197 | if err != nil {
|
170 | 198 | x.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
|
@@ -281,13 +309,13 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
|
281 | 309 | // createX11Listener creates a listener for X11 forwarding, it will use
|
282 | 310 | // the next available port starting from X11StartPort and displayOffset.
|
283 | 311 | func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) {
|
284 |
| - var lc net.ListenConfig |
285 | 312 | // Look for an open port to listen on.
|
286 | 313 | for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ {
|
287 | 314 | if ctx.Err() != nil {
|
288 | 315 | return nil, -1, ctx.Err()
|
289 | 316 | }
|
290 |
| - ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) |
| 317 | + |
| 318 | + ln, err = x.network.Listen("tcp", fmt.Sprintf("localhost:%d", port)) |
291 | 319 | if err == nil {
|
292 | 320 | display = port - X11StartPort
|
293 | 321 | return ln, display, nil
|
|
0 commit comments