Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 64265cb

Browse files
committed
feat: opportunistically listen on IPv6 in port-forward
1 parent f6c3f0a commit 64265cb

File tree

3 files changed

+170
-67
lines changed

3 files changed

+170
-67
lines changed

cli/portforward.go

+71-51
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ import (
2525
"github.com/coder/serpent"
2626
)
2727

28+
var (
29+
// noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
30+
// when the local address is not specified in port-forward flags.
31+
noAddr netip.Addr
32+
ipv6Loopback = netip.MustParseAddr("::1")
33+
ipv4Loopback = netip.MustParseAddr("127.0.0.1")
34+
)
35+
2836
func (r *RootCmd) portForward() *serpent.Command {
2937
var (
3038
tcpForwards []string // <port>:<port>
@@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command {
122130
// Start all listeners.
123131
var (
124132
wg = new(sync.WaitGroup)
125-
listeners = make([]net.Listener, len(specs))
133+
listeners = make([]net.Listener, 0, len(specs)*2)
126134
closeAllListeners = func() {
127135
logger.Debug(ctx, "closing all listeners")
128136
for _, l := range listeners {
@@ -135,13 +143,26 @@ func (r *RootCmd) portForward() *serpent.Command {
135143
)
136144
defer closeAllListeners()
137145

138-
for i, spec := range specs {
146+
for _, spec := range specs {
147+
148+
if spec.listenHost == noAddr {
149+
// first, opportunistically try to listen on IPv6
150+
spec6 := spec
151+
spec6.listenHost = ipv6Loopback
152+
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger)
153+
if err6 != nil {
154+
logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6))
155+
} else {
156+
listeners = append(listeners, l6)
157+
}
158+
spec.listenHost = ipv4Loopback
159+
}
139160
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)
140161
if err != nil {
141162
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
142163
return err
143164
}
144-
listeners[i] = l
165+
listeners = append(listeners, l)
145166
}
146167

147168
stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID)
@@ -206,12 +227,19 @@ func listenAndPortForward(
206227
spec portForwardSpec,
207228
logger slog.Logger,
208229
) (net.Listener, error) {
209-
logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress))
210-
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
230+
logger = logger.With(
231+
slog.F("network", spec.network),
232+
slog.F("listen_host", spec.listenHost),
233+
slog.F("listen_port", spec.listenPort),
234+
)
235+
listenAddress := netip.AddrPortFrom(spec.listenHost, spec.listenPort)
236+
dialAddress := fmt.Sprintf("127.0.0.1:%d", spec.dialPort)
237+
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n",
238+
spec.network, listenAddress, spec.network, dialAddress)
211239

212-
l, err := inv.Net.Listen(spec.listenNetwork, spec.listenAddress)
240+
l, err := inv.Net.Listen(spec.network, listenAddress.String())
213241
if err != nil {
214-
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
242+
return nil, xerrors.Errorf("listen '%s://%s': %w", spec.network, listenAddress.String(), err)
215243
}
216244
logger.Debug(ctx, "listening")
217245

@@ -226,24 +254,31 @@ func listenAndPortForward(
226254
logger.Debug(ctx, "listener closed")
227255
return
228256
}
229-
_, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err)
257+
_, _ = fmt.Fprintf(inv.Stderr,
258+
"Error accepting connection from '%s://%s': %v\n",
259+
spec.network, listenAddress.String(), err)
230260
_, _ = fmt.Fprintln(inv.Stderr, "Killing listener")
231261
return
232262
}
233-
logger.Debug(ctx, "accepted connection", slog.F("remote_addr", netConn.RemoteAddr()))
263+
logger.Debug(ctx, "accepted connection",
264+
slog.F("remote_addr", netConn.RemoteAddr()))
234265

235266
go func(netConn net.Conn) {
236267
defer netConn.Close()
237-
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
268+
remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress)
238269
if err != nil {
239-
_, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
270+
_, _ = fmt.Fprintf(inv.Stderr,
271+
"Failed to dial '%s://%s' in workspace: %s\n",
272+
spec.network, dialAddress, err)
240273
return
241274
}
242275
defer remoteConn.Close()
243-
logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
276+
logger.Debug(ctx,
277+
"dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
244278

245279
agentssh.Bicopy(ctx, netConn, remoteConn)
246-
logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
280+
logger.Debug(ctx,
281+
"connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
247282
}(netConn)
248283
}
249284
}(spec)
@@ -252,58 +287,48 @@ func listenAndPortForward(
252287
}
253288

254289
type portForwardSpec struct {
255-
listenNetwork string // tcp, udp
256-
listenAddress string // <ip>:<port> or path
257-
258-
dialNetwork string // tcp, udp
259-
dialAddress string // <ip>:<port> or path
290+
network string // tcp, udp
291+
listenHost netip.Addr
292+
listenPort, dialPort uint16
260293
}
261294

262295
func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
263296
specs := []portForwardSpec{}
264297

265298
for _, specEntry := range tcpSpecs {
266299
for _, spec := range strings.Split(specEntry, ",") {
267-
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
300+
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
268301
if err != nil {
269302
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
270303
}
271304

272-
for _, port := range ports {
273-
specs = append(specs, portForwardSpec{
274-
listenNetwork: "tcp",
275-
listenAddress: port.local.String(),
276-
dialNetwork: "tcp",
277-
dialAddress: port.remote.String(),
278-
})
305+
for _, pfSpec := range pfSpecs {
306+
pfSpec.network = "tcp"
307+
specs = append(specs, pfSpec)
279308
}
280309
}
281310
}
282311

283312
for _, specEntry := range udpSpecs {
284313
for _, spec := range strings.Split(specEntry, ",") {
285-
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
314+
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
286315
if err != nil {
287316
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
288317
}
289318

290-
for _, port := range ports {
291-
specs = append(specs, portForwardSpec{
292-
listenNetwork: "udp",
293-
listenAddress: port.local.String(),
294-
dialNetwork: "udp",
295-
dialAddress: port.remote.String(),
296-
})
319+
for _, pfSpec := range pfSpecs {
320+
pfSpec.network = "udp"
321+
specs = append(specs, pfSpec)
297322
}
298323
}
299324
}
300325

301326
// Check for duplicate entries.
302327
locals := map[string]struct{}{}
303328
for _, spec := range specs {
304-
localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress)
329+
localStr := fmt.Sprintf("%s:%s:%d", spec.network, spec.listenHost, spec.listenPort)
305330
if _, ok := locals[localStr]; ok {
306-
return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress)
331+
return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.network, spec.listenHost, spec.listenPort)
307332
}
308333
locals[localStr] = struct{}{}
309334
}
@@ -323,10 +348,6 @@ func parsePort(in string) (uint16, error) {
323348
return uint16(port), nil
324349
}
325350

326-
type parsedSrcDestPort struct {
327-
local, remote netip.AddrPort
328-
}
329-
330351
// specRegexp matches port specs. It handles all the following formats:
331352
//
332353
// 8000
@@ -347,21 +368,19 @@ type parsedSrcDestPort struct {
347368
// 9: end or remote port range
348369
var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`)
349370

350-
func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
351-
var (
352-
err error
353-
localAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
354-
remoteAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
355-
)
371+
func parseSrcDestPorts(in string) ([]portForwardSpec, error) {
356372
groups := specRegexp.FindStringSubmatch(in)
357373
if len(groups) == 0 {
358374
return nil, xerrors.Errorf("invalid port specification %q", in)
359375
}
376+
377+
var localAddr netip.Addr
360378
if groups[2] != "" {
361-
localAddr, err = netip.ParseAddr(strings.Trim(groups[2], "[]"))
379+
parsedAddr, err := netip.ParseAddr(strings.Trim(groups[2], "[]"))
362380
if err != nil {
363381
return nil, xerrors.Errorf("invalid IP address %q", groups[2])
364382
}
383+
localAddr = parsedAddr
365384
}
366385

367386
local, err := parsePortRange(groups[3], groups[5])
@@ -378,11 +397,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
378397
if len(local) != len(remote) {
379398
return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote))
380399
}
381-
var out []parsedSrcDestPort
400+
var out []portForwardSpec
382401
for i := range local {
383-
out = append(out, parsedSrcDestPort{
384-
local: netip.AddrPortFrom(localAddr, local[i]),
385-
remote: netip.AddrPortFrom(remoteAddr, remote[i]),
402+
out = append(out, portForwardSpec{
403+
listenHost: localAddr,
404+
listenPort: local[i],
405+
dialPort: remote[i],
386406
})
387407
}
388408
return out, nil

cli/portforward_internal_test.go

+16-16
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ func Test_parsePortForwards(t *testing.T) {
2929
},
3030
},
3131
want: []portForwardSpec{
32-
{"tcp", "127.0.0.1:8000", "tcp", "127.0.0.1:8000"},
33-
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
34-
{"tcp", "127.0.0.1:9000", "tcp", "127.0.0.1:9000"},
35-
{"tcp", "127.0.0.1:9001", "tcp", "127.0.0.1:9001"},
36-
{"tcp", "127.0.0.1:9002", "tcp", "127.0.0.1:9002"},
37-
{"tcp", "127.0.0.1:9003", "tcp", "127.0.0.1:9005"},
38-
{"tcp", "127.0.0.1:9004", "tcp", "127.0.0.1:9006"},
39-
{"tcp", "127.0.0.1:10000", "tcp", "127.0.0.1:10000"},
40-
{"tcp", "127.0.0.1:4444", "tcp", "127.0.0.1:4444"},
32+
{"tcp", noAddr, 8000, 8000},
33+
{"tcp", noAddr, 8080, 8081},
34+
{"tcp", noAddr, 9000, 9000},
35+
{"tcp", noAddr, 9001, 9001},
36+
{"tcp", noAddr, 9002, 9002},
37+
{"tcp", noAddr, 9003, 9005},
38+
{"tcp", noAddr, 9004, 9006},
39+
{"tcp", noAddr, 10000, 10000},
40+
{"tcp", noAddr, 4444, 4444},
4141
},
4242
},
4343
{
@@ -46,7 +46,7 @@ func Test_parsePortForwards(t *testing.T) {
4646
tcpSpecs: []string{"127.0.0.1:8080:8081"},
4747
},
4848
want: []portForwardSpec{
49-
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
49+
{"tcp", ipv4Loopback, 8080, 8081},
5050
},
5151
},
5252
{
@@ -55,7 +55,7 @@ func Test_parsePortForwards(t *testing.T) {
5555
tcpSpecs: []string{"[::1]:8080:8081"},
5656
},
5757
want: []portForwardSpec{
58-
{"tcp", "[::1]:8080", "tcp", "127.0.0.1:8081"},
58+
{"tcp", ipv6Loopback, 8080, 8081},
5959
},
6060
},
6161
{
@@ -64,9 +64,9 @@ func Test_parsePortForwards(t *testing.T) {
6464
udpSpecs: []string{"8000,8080-8081"},
6565
},
6666
want: []portForwardSpec{
67-
{"udp", "127.0.0.1:8000", "udp", "127.0.0.1:8000"},
68-
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8080"},
69-
{"udp", "127.0.0.1:8081", "udp", "127.0.0.1:8081"},
67+
{"udp", noAddr, 8000, 8000},
68+
{"udp", noAddr, 8080, 8080},
69+
{"udp", noAddr, 8081, 8081},
7070
},
7171
},
7272
{
@@ -75,7 +75,7 @@ func Test_parsePortForwards(t *testing.T) {
7575
udpSpecs: []string{"127.0.0.1:8080:8081"},
7676
},
7777
want: []portForwardSpec{
78-
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8081"},
78+
{"udp", ipv4Loopback, 8080, 8081},
7979
},
8080
},
8181
{
@@ -84,7 +84,7 @@ func Test_parsePortForwards(t *testing.T) {
8484
udpSpecs: []string{"[::1]:8080:8081"},
8585
},
8686
want: []portForwardSpec{
87-
{"udp", "[::1]:8080", "udp", "127.0.0.1:8081"},
87+
{"udp", ipv6Loopback, 8080, 8081},
8888
},
8989
},
9090
{

0 commit comments

Comments
 (0)