Thanks to visit codestin.com
Credit goes to github.com

Skip to content

feat: change port-forward to opportunistically listen on IPv6 #15640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 70 additions & 51 deletions cli/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ import (
"github.com/coder/serpent"
)

var (
// noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
// when the local address is not specified in port-forward flags.
noAddr netip.Addr
ipv6Loopback = netip.MustParseAddr("::1")
ipv4Loopback = netip.MustParseAddr("127.0.0.1")
)

func (r *RootCmd) portForward() *serpent.Command {
var (
tcpForwards []string // <port>:<port>
Expand Down Expand Up @@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command {
// Start all listeners.
var (
wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs))
listeners = make([]net.Listener, 0, len(specs)*2)
closeAllListeners = func() {
logger.Debug(ctx, "closing all listeners")
for _, l := range listeners {
Expand All @@ -135,13 +143,25 @@ func (r *RootCmd) portForward() *serpent.Command {
)
defer closeAllListeners()

for i, spec := range specs {
for _, spec := range specs {
if spec.listenHost == noAddr {
// first, opportunistically try to listen on IPv6
spec6 := spec
spec6.listenHost = ipv6Loopback
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger)
if err6 != nil {
logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6))
} else {
listeners = append(listeners, l6)
}
spec.listenHost = ipv4Loopback
}
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)
if err != nil {
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
return err
}
listeners[i] = l
listeners = append(listeners, l)
}

stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID)
Expand Down Expand Up @@ -206,12 +226,19 @@ func listenAndPortForward(
spec portForwardSpec,
logger slog.Logger,
) (net.Listener, error) {
logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress))
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
logger = logger.With(
slog.F("network", spec.network),
slog.F("listen_host", spec.listenHost),
slog.F("listen_port", spec.listenPort),
)
listenAddress := netip.AddrPortFrom(spec.listenHost, spec.listenPort)
dialAddress := fmt.Sprintf("127.0.0.1:%d", spec.dialPort)
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n",
spec.network, listenAddress, spec.network, dialAddress)

l, err := inv.Net.Listen(spec.listenNetwork, spec.listenAddress)
l, err := inv.Net.Listen(spec.network, listenAddress.String())
if err != nil {
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
return nil, xerrors.Errorf("listen '%s://%s': %w", spec.network, listenAddress.String(), err)
}
logger.Debug(ctx, "listening")

Expand All @@ -226,24 +253,31 @@ func listenAndPortForward(
logger.Debug(ctx, "listener closed")
return
}
_, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err)
_, _ = fmt.Fprintf(inv.Stderr,
"Error accepting connection from '%s://%s': %v\n",
spec.network, listenAddress.String(), err)
_, _ = fmt.Fprintln(inv.Stderr, "Killing listener")
return
}
logger.Debug(ctx, "accepted connection", slog.F("remote_addr", netConn.RemoteAddr()))
logger.Debug(ctx, "accepted connection",
slog.F("remote_addr", netConn.RemoteAddr()))

go func(netConn net.Conn) {
defer netConn.Close()
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress)
if err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
_, _ = fmt.Fprintf(inv.Stderr,
"Failed to dial '%s://%s' in workspace: %s\n",
spec.network, dialAddress, err)
return
}
defer remoteConn.Close()
logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
logger.Debug(ctx,
"dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))

agentssh.Bicopy(ctx, netConn, remoteConn)
logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
logger.Debug(ctx,
"connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
}(netConn)
}
}(spec)
Expand All @@ -252,58 +286,48 @@ func listenAndPortForward(
}

type portForwardSpec struct {
listenNetwork string // tcp, udp
listenAddress string // <ip>:<port> or path

dialNetwork string // tcp, udp
dialAddress string // <ip>:<port> or path
network string // tcp, udp
listenHost netip.Addr
listenPort, dialPort uint16
}

func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
specs := []portForwardSpec{}

for _, specEntry := range tcpSpecs {
for _, spec := range strings.Split(specEntry, ",") {
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
if err != nil {
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
}

for _, port := range ports {
specs = append(specs, portForwardSpec{
listenNetwork: "tcp",
listenAddress: port.local.String(),
dialNetwork: "tcp",
dialAddress: port.remote.String(),
})
for _, pfSpec := range pfSpecs {
pfSpec.network = "tcp"
specs = append(specs, pfSpec)
}
}
}

for _, specEntry := range udpSpecs {
for _, spec := range strings.Split(specEntry, ",") {
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
if err != nil {
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
}

for _, port := range ports {
specs = append(specs, portForwardSpec{
listenNetwork: "udp",
listenAddress: port.local.String(),
dialNetwork: "udp",
dialAddress: port.remote.String(),
})
for _, pfSpec := range pfSpecs {
pfSpec.network = "udp"
specs = append(specs, pfSpec)
}
}
}

// Check for duplicate entries.
locals := map[string]struct{}{}
for _, spec := range specs {
localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress)
localStr := fmt.Sprintf("%s:%s:%d", spec.network, spec.listenHost, spec.listenPort)
if _, ok := locals[localStr]; ok {
return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress)
return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.network, spec.listenHost, spec.listenPort)
}
locals[localStr] = struct{}{}
}
Expand All @@ -323,10 +347,6 @@ func parsePort(in string) (uint16, error) {
return uint16(port), nil
}

type parsedSrcDestPort struct {
local, remote netip.AddrPort
}

// specRegexp matches port specs. It handles all the following formats:
//
// 8000
Expand All @@ -347,21 +367,19 @@ type parsedSrcDestPort struct {
// 9: end or remote port range
var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`)

func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
var (
err error
localAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
remoteAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
)
func parseSrcDestPorts(in string) ([]portForwardSpec, error) {
groups := specRegexp.FindStringSubmatch(in)
if len(groups) == 0 {
return nil, xerrors.Errorf("invalid port specification %q", in)
}

var localAddr netip.Addr
if groups[2] != "" {
localAddr, err = netip.ParseAddr(strings.Trim(groups[2], "[]"))
parsedAddr, err := netip.ParseAddr(strings.Trim(groups[2], "[]"))
if err != nil {
return nil, xerrors.Errorf("invalid IP address %q", groups[2])
}
localAddr = parsedAddr
}

local, err := parsePortRange(groups[3], groups[5])
Expand All @@ -378,11 +396,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
if len(local) != len(remote) {
return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote))
}
var out []parsedSrcDestPort
var out []portForwardSpec
for i := range local {
out = append(out, parsedSrcDestPort{
local: netip.AddrPortFrom(localAddr, local[i]),
remote: netip.AddrPortFrom(remoteAddr, remote[i]),
out = append(out, portForwardSpec{
listenHost: localAddr,
listenPort: local[i],
dialPort: remote[i],
})
}
return out, nil
Expand Down
32 changes: 16 additions & 16 deletions cli/portforward_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ func Test_parsePortForwards(t *testing.T) {
},
},
want: []portForwardSpec{
{"tcp", "127.0.0.1:8000", "tcp", "127.0.0.1:8000"},
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
{"tcp", "127.0.0.1:9000", "tcp", "127.0.0.1:9000"},
{"tcp", "127.0.0.1:9001", "tcp", "127.0.0.1:9001"},
{"tcp", "127.0.0.1:9002", "tcp", "127.0.0.1:9002"},
{"tcp", "127.0.0.1:9003", "tcp", "127.0.0.1:9005"},
{"tcp", "127.0.0.1:9004", "tcp", "127.0.0.1:9006"},
{"tcp", "127.0.0.1:10000", "tcp", "127.0.0.1:10000"},
{"tcp", "127.0.0.1:4444", "tcp", "127.0.0.1:4444"},
{"tcp", noAddr, 8000, 8000},
{"tcp", noAddr, 8080, 8081},
{"tcp", noAddr, 9000, 9000},
{"tcp", noAddr, 9001, 9001},
{"tcp", noAddr, 9002, 9002},
{"tcp", noAddr, 9003, 9005},
{"tcp", noAddr, 9004, 9006},
{"tcp", noAddr, 10000, 10000},
{"tcp", noAddr, 4444, 4444},
},
},
{
Expand All @@ -46,7 +46,7 @@ func Test_parsePortForwards(t *testing.T) {
tcpSpecs: []string{"127.0.0.1:8080:8081"},
},
want: []portForwardSpec{
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
{"tcp", ipv4Loopback, 8080, 8081},
},
},
{
Expand All @@ -55,7 +55,7 @@ func Test_parsePortForwards(t *testing.T) {
tcpSpecs: []string{"[::1]:8080:8081"},
},
want: []portForwardSpec{
{"tcp", "[::1]:8080", "tcp", "127.0.0.1:8081"},
{"tcp", ipv6Loopback, 8080, 8081},
},
},
{
Expand All @@ -64,9 +64,9 @@ func Test_parsePortForwards(t *testing.T) {
udpSpecs: []string{"8000,8080-8081"},
},
want: []portForwardSpec{
{"udp", "127.0.0.1:8000", "udp", "127.0.0.1:8000"},
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8080"},
{"udp", "127.0.0.1:8081", "udp", "127.0.0.1:8081"},
{"udp", noAddr, 8000, 8000},
{"udp", noAddr, 8080, 8080},
{"udp", noAddr, 8081, 8081},
},
},
{
Expand All @@ -75,7 +75,7 @@ func Test_parsePortForwards(t *testing.T) {
udpSpecs: []string{"127.0.0.1:8080:8081"},
},
want: []portForwardSpec{
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8081"},
{"udp", ipv4Loopback, 8080, 8081},
},
},
{
Expand All @@ -84,7 +84,7 @@ func Test_parsePortForwards(t *testing.T) {
udpSpecs: []string{"[::1]:8080:8081"},
},
want: []portForwardSpec{
{"udp", "[::1]:8080", "udp", "127.0.0.1:8081"},
{"udp", ipv6Loopback, 8080, 8081},
},
},
{
Expand Down
Loading
Loading