Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e6506f0

Browse files
authored
feat: change port-forward to opportunistically listen on IPv6 (#15640)
If the local IP address is not explicitly set, previously we assumed 127.0.0.1 (that is, IPv4 only localhost). This PR adds support to opportunistically _also_ listen on IPv6 ::1.
1 parent 1cdc3e8 commit e6506f0

File tree

3 files changed

+169
-67
lines changed

3 files changed

+169
-67
lines changed

cli/portforward.go

+70-51
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ import (
2525
"github.com/coder/serpent"
2626
)
2727

28+
var (
29+
// noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
30+
// when the local address is not specified in port-forward flags.
31+
noAddr netip.Addr
32+
ipv6Loopback = netip.MustParseAddr("::1")
33+
ipv4Loopback = netip.MustParseAddr("127.0.0.1")
34+
)
35+
2836
func (r *RootCmd) portForward() *serpent.Command {
2937
var (
3038
tcpForwards []string // <port>:<port>
@@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command {
122130
// Start all listeners.
123131
var (
124132
wg = new(sync.WaitGroup)
125-
listeners = make([]net.Listener, len(specs))
133+
listeners = make([]net.Listener, 0, len(specs)*2)
126134
closeAllListeners = func() {
127135
logger.Debug(ctx, "closing all listeners")
128136
for _, l := range listeners {
@@ -135,13 +143,25 @@ func (r *RootCmd) portForward() *serpent.Command {
135143
)
136144
defer closeAllListeners()
137145

138-
for i, spec := range specs {
146+
for _, spec := range specs {
147+
if spec.listenHost == noAddr {
148+
// first, opportunistically try to listen on IPv6
149+
spec6 := spec
150+
spec6.listenHost = ipv6Loopback
151+
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger)
152+
if err6 != nil {
153+
logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6))
154+
} else {
155+
listeners = append(listeners, l6)
156+
}
157+
spec.listenHost = ipv4Loopback
158+
}
139159
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)
140160
if err != nil {
141161
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
142162
return err
143163
}
144-
listeners[i] = l
164+
listeners = append(listeners, l)
145165
}
146166

147167
stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID)
@@ -206,12 +226,19 @@ func listenAndPortForward(
206226
spec portForwardSpec,
207227
logger slog.Logger,
208228
) (net.Listener, error) {
209-
logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress))
210-
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
229+
logger = logger.With(
230+
slog.F("network", spec.network),
231+
slog.F("listen_host", spec.listenHost),
232+
slog.F("listen_port", spec.listenPort),
233+
)
234+
listenAddress := netip.AddrPortFrom(spec.listenHost, spec.listenPort)
235+
dialAddress := fmt.Sprintf("127.0.0.1:%d", spec.dialPort)
236+
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n",
237+
spec.network, listenAddress, spec.network, dialAddress)
211238

212-
l, err := inv.Net.Listen(spec.listenNetwork, spec.listenAddress)
239+
l, err := inv.Net.Listen(spec.network, listenAddress.String())
213240
if err != nil {
214-
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
241+
return nil, xerrors.Errorf("listen '%s://%s': %w", spec.network, listenAddress.String(), err)
215242
}
216243
logger.Debug(ctx, "listening")
217244

@@ -226,24 +253,31 @@ func listenAndPortForward(
226253
logger.Debug(ctx, "listener closed")
227254
return
228255
}
229-
_, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err)
256+
_, _ = fmt.Fprintf(inv.Stderr,
257+
"Error accepting connection from '%s://%s': %v\n",
258+
spec.network, listenAddress.String(), err)
230259
_, _ = fmt.Fprintln(inv.Stderr, "Killing listener")
231260
return
232261
}
233-
logger.Debug(ctx, "accepted connection", slog.F("remote_addr", netConn.RemoteAddr()))
262+
logger.Debug(ctx, "accepted connection",
263+
slog.F("remote_addr", netConn.RemoteAddr()))
234264

235265
go func(netConn net.Conn) {
236266
defer netConn.Close()
237-
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
267+
remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress)
238268
if err != nil {
239-
_, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
269+
_, _ = fmt.Fprintf(inv.Stderr,
270+
"Failed to dial '%s://%s' in workspace: %s\n",
271+
spec.network, dialAddress, err)
240272
return
241273
}
242274
defer remoteConn.Close()
243-
logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
275+
logger.Debug(ctx,
276+
"dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
244277

245278
agentssh.Bicopy(ctx, netConn, remoteConn)
246-
logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
279+
logger.Debug(ctx,
280+
"connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
247281
}(netConn)
248282
}
249283
}(spec)
@@ -252,58 +286,48 @@ func listenAndPortForward(
252286
}
253287

254288
type portForwardSpec struct {
255-
listenNetwork string // tcp, udp
256-
listenAddress string // <ip>:<port> or path
257-
258-
dialNetwork string // tcp, udp
259-
dialAddress string // <ip>:<port> or path
289+
network string // tcp, udp
290+
listenHost netip.Addr
291+
listenPort, dialPort uint16
260292
}
261293

262294
func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
263295
specs := []portForwardSpec{}
264296

265297
for _, specEntry := range tcpSpecs {
266298
for _, spec := range strings.Split(specEntry, ",") {
267-
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
299+
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
268300
if err != nil {
269301
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
270302
}
271303

272-
for _, port := range ports {
273-
specs = append(specs, portForwardSpec{
274-
listenNetwork: "tcp",
275-
listenAddress: port.local.String(),
276-
dialNetwork: "tcp",
277-
dialAddress: port.remote.String(),
278-
})
304+
for _, pfSpec := range pfSpecs {
305+
pfSpec.network = "tcp"
306+
specs = append(specs, pfSpec)
279307
}
280308
}
281309
}
282310

283311
for _, specEntry := range udpSpecs {
284312
for _, spec := range strings.Split(specEntry, ",") {
285-
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
313+
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
286314
if err != nil {
287315
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
288316
}
289317

290-
for _, port := range ports {
291-
specs = append(specs, portForwardSpec{
292-
listenNetwork: "udp",
293-
listenAddress: port.local.String(),
294-
dialNetwork: "udp",
295-
dialAddress: port.remote.String(),
296-
})
318+
for _, pfSpec := range pfSpecs {
319+
pfSpec.network = "udp"
320+
specs = append(specs, pfSpec)
297321
}
298322
}
299323
}
300324

301325
// Check for duplicate entries.
302326
locals := map[string]struct{}{}
303327
for _, spec := range specs {
304-
localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress)
328+
localStr := fmt.Sprintf("%s:%s:%d", spec.network, spec.listenHost, spec.listenPort)
305329
if _, ok := locals[localStr]; ok {
306-
return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress)
330+
return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.network, spec.listenHost, spec.listenPort)
307331
}
308332
locals[localStr] = struct{}{}
309333
}
@@ -323,10 +347,6 @@ func parsePort(in string) (uint16, error) {
323347
return uint16(port), nil
324348
}
325349

326-
type parsedSrcDestPort struct {
327-
local, remote netip.AddrPort
328-
}
329-
330350
// specRegexp matches port specs. It handles all the following formats:
331351
//
332352
// 8000
@@ -347,21 +367,19 @@ type parsedSrcDestPort struct {
347367
// 9: end or remote port range
348368
var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`)
349369

350-
func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
351-
var (
352-
err error
353-
localAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
354-
remoteAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
355-
)
370+
func parseSrcDestPorts(in string) ([]portForwardSpec, error) {
356371
groups := specRegexp.FindStringSubmatch(in)
357372
if len(groups) == 0 {
358373
return nil, xerrors.Errorf("invalid port specification %q", in)
359374
}
375+
376+
var localAddr netip.Addr
360377
if groups[2] != "" {
361-
localAddr, err = netip.ParseAddr(strings.Trim(groups[2], "[]"))
378+
parsedAddr, err := netip.ParseAddr(strings.Trim(groups[2], "[]"))
362379
if err != nil {
363380
return nil, xerrors.Errorf("invalid IP address %q", groups[2])
364381
}
382+
localAddr = parsedAddr
365383
}
366384

367385
local, err := parsePortRange(groups[3], groups[5])
@@ -378,11 +396,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
378396
if len(local) != len(remote) {
379397
return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote))
380398
}
381-
var out []parsedSrcDestPort
399+
var out []portForwardSpec
382400
for i := range local {
383-
out = append(out, parsedSrcDestPort{
384-
local: netip.AddrPortFrom(localAddr, local[i]),
385-
remote: netip.AddrPortFrom(remoteAddr, remote[i]),
401+
out = append(out, portForwardSpec{
402+
listenHost: localAddr,
403+
listenPort: local[i],
404+
dialPort: remote[i],
386405
})
387406
}
388407
return out, nil

cli/portforward_internal_test.go

+16-16
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ func Test_parsePortForwards(t *testing.T) {
2929
},
3030
},
3131
want: []portForwardSpec{
32-
{"tcp", "127.0.0.1:8000", "tcp", "127.0.0.1:8000"},
33-
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
34-
{"tcp", "127.0.0.1:9000", "tcp", "127.0.0.1:9000"},
35-
{"tcp", "127.0.0.1:9001", "tcp", "127.0.0.1:9001"},
36-
{"tcp", "127.0.0.1:9002", "tcp", "127.0.0.1:9002"},
37-
{"tcp", "127.0.0.1:9003", "tcp", "127.0.0.1:9005"},
38-
{"tcp", "127.0.0.1:9004", "tcp", "127.0.0.1:9006"},
39-
{"tcp", "127.0.0.1:10000", "tcp", "127.0.0.1:10000"},
40-
{"tcp", "127.0.0.1:4444", "tcp", "127.0.0.1:4444"},
32+
{"tcp", noAddr, 8000, 8000},
33+
{"tcp", noAddr, 8080, 8081},
34+
{"tcp", noAddr, 9000, 9000},
35+
{"tcp", noAddr, 9001, 9001},
36+
{"tcp", noAddr, 9002, 9002},
37+
{"tcp", noAddr, 9003, 9005},
38+
{"tcp", noAddr, 9004, 9006},
39+
{"tcp", noAddr, 10000, 10000},
40+
{"tcp", noAddr, 4444, 4444},
4141
},
4242
},
4343
{
@@ -46,7 +46,7 @@ func Test_parsePortForwards(t *testing.T) {
4646
tcpSpecs: []string{"127.0.0.1:8080:8081"},
4747
},
4848
want: []portForwardSpec{
49-
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
49+
{"tcp", ipv4Loopback, 8080, 8081},
5050
},
5151
},
5252
{
@@ -55,7 +55,7 @@ func Test_parsePortForwards(t *testing.T) {
5555
tcpSpecs: []string{"[::1]:8080:8081"},
5656
},
5757
want: []portForwardSpec{
58-
{"tcp", "[::1]:8080", "tcp", "127.0.0.1:8081"},
58+
{"tcp", ipv6Loopback, 8080, 8081},
5959
},
6060
},
6161
{
@@ -64,9 +64,9 @@ func Test_parsePortForwards(t *testing.T) {
6464
udpSpecs: []string{"8000,8080-8081"},
6565
},
6666
want: []portForwardSpec{
67-
{"udp", "127.0.0.1:8000", "udp", "127.0.0.1:8000"},
68-
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8080"},
69-
{"udp", "127.0.0.1:8081", "udp", "127.0.0.1:8081"},
67+
{"udp", noAddr, 8000, 8000},
68+
{"udp", noAddr, 8080, 8080},
69+
{"udp", noAddr, 8081, 8081},
7070
},
7171
},
7272
{
@@ -75,7 +75,7 @@ func Test_parsePortForwards(t *testing.T) {
7575
udpSpecs: []string{"127.0.0.1:8080:8081"},
7676
},
7777
want: []portForwardSpec{
78-
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8081"},
78+
{"udp", ipv4Loopback, 8080, 8081},
7979
},
8080
},
8181
{
@@ -84,7 +84,7 @@ func Test_parsePortForwards(t *testing.T) {
8484
udpSpecs: []string{"[::1]:8080:8081"},
8585
},
8686
want: []portForwardSpec{
87-
{"udp", "[::1]:8080", "udp", "127.0.0.1:8081"},
87+
{"udp", ipv6Loopback, 8080, 8081},
8888
},
8989
},
9090
{

0 commit comments

Comments
 (0)