diff --git a/disco/disco.go b/disco/disco.go index 0e7c3f7e5f882..8badad2dc895f 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -27,6 +27,7 @@ import ( "net/netip" "go4.org/mem" + "golang.org/x/crypto/nacl/box" "tailscale.com/types/key" ) @@ -48,6 +49,19 @@ const ( const v0 = byte(0) +// v1 Ping and Pong are padded as follows. CallMeMaybe is still on v0 and unpadded. +const v1 = byte(1) + +// paddedPayloadLen is the desired length we want to pad Ping and Pong payloads +// to so that they are the maximum size of a Wireguard packet we would +// subsequently send. This ensures that any UDP paths we discover will actually +// support the packet sizes the net stack will send over those paths. Any peers +// behind a small-MTU link will have to depend on DERP. +// c.f. https://github.com/coder/coder/issues/15523 +// Our inner IP packets can be up to 1280 bytes, with the Wireguard header of +// 30 bytes, that is 1310. The final 2 is the inner payload header's type and version. +const paddedPayloadLen = 1310 - len(Magic) - keyLen - NonceLen - box.Overhead - 2 + var errShort = errors.New("short message") // LooksLikeDiscoWrapper reports whether p looks like it's a packet @@ -120,12 +134,8 @@ type Ping struct { } func (m *Ping) AppendMarshal(b []byte) []byte { - dataLen := 12 hasKey := !m.NodeKey.IsZero() - if hasKey { - dataLen += key.NodePublicRawLen - } - ret, d := appendMsgHeader(b, TypePing, v0, dataLen) + ret, d := appendMsgHeader(b, TypePing, v1, paddedPayloadLen) n := copy(d, m.TxID[:]) if hasKey { m.NodeKey.AppendTo(d[:n]) @@ -217,7 +227,7 @@ type Pong struct { const pongLen = 12 + 16 + 2 func (m *Pong) AppendMarshal(b []byte) []byte { - ret, d := appendMsgHeader(b, TypePong, v0, pongLen) + ret, d := appendMsgHeader(b, TypePong, v1, paddedPayloadLen) d = d[copy(d, m.TxID[:]):] ip16 := m.Src.Addr().As16() d = d[copy(d, ip16[:]):] diff --git a/disco/disco_test.go b/disco/disco_test.go index 67bd1561a9bf6..475203e0aa1d2 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -11,6 +11,7 @@ import ( "testing" "go4.org/mem" + "golang.org/x/crypto/nacl/box" "tailscale.com/types/key" ) @@ -25,7 +26,7 @@ func TestMarshalAndParse(t *testing.T) { m: &Ping{ TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c", + want: "01 01 01 02 03 04 05 06 07 08 09 0a 0b 0c", }, { name: "ping_with_nodekey_src", @@ -33,7 +34,7 @@ func TestMarshalAndParse(t *testing.T) { TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f", + want: "01 01 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f", }, { name: "pong", @@ -41,7 +42,7 @@ func TestMarshalAndParse(t *testing.T) { TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, Src: mustIPPort("2.3.4.5:1234"), }, - want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00 00 00 00 00 00 00 00 ff ff 02 03 04 05 04 d2", + want: "02 01 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00 00 00 00 00 00 00 00 ff ff 02 03 04 05 04 d2", }, { name: "pongv6", @@ -49,7 +50,7 @@ func TestMarshalAndParse(t *testing.T) { TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, Src: mustIPPort("[fed0::12]:6666"), }, - want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c fe d0 00 00 00 00 00 00 00 00 00 00 00 00 00 12 1a 0a", + want: "02 01 01 02 03 04 05 06 07 08 09 0a 0b 0c fe d0 00 00 00 00 00 00 00 00 00 00 00 00 00 12 1a 0a", }, { name: "call_me_maybe", @@ -75,10 +76,23 @@ func TestMarshalAndParse(t *testing.T) { if !ok { t.Fatalf("didn't start with foo: got %q", got) } + // CODER: 1310 is max size of a Wireguard packet we will send. + expectedLen := 1310 - len(Magic) - keyLen - NonceLen - box.Overhead + switch tt.m.(type) { + case *Ping: + if len(got) != expectedLen { + t.Fatalf("Ping not padded: got len %d, want len %d", len(got), expectedLen) + } + case *Pong: + if len(got) != expectedLen { + t.Fatalf("Pong not padded: got len %d, want len %d", len(got), expectedLen) + } + // CallMeMaybe is unpadded + } gotHex := fmt.Sprintf("% x", got) - if gotHex != tt.want { - t.Fatalf("wrong marshal\n got: %s\nwant: %s\n", gotHex, tt.want) + if !strings.HasPrefix(gotHex, tt.want) { + t.Fatalf("wrong marshal\n got: %s\nwant prefix: %s\n", gotHex, tt.want) } back, err := Parse([]byte(got)) @@ -92,6 +106,69 @@ func TestMarshalAndParse(t *testing.T) { } } +func TestParsePingPongV0(t *testing.T) { + tests := []struct { + name string + payload []byte + m Message + }{ + { + name: "ping", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + }, + payload: []byte{0x01, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c}, + }, + { + name: "ping_with_nodekey_src", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + }, + payload: []byte{ + 0x01, 0x00, + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, + 0x00, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x1e, 0x1f}, + }, + { + name: "pong", + m: &Pong{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + Src: mustIPPort("2.3.4.5:1234"), + }, + payload: []byte{ + 0x02, 0x00, + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x02, 0x03, 0x04, 0x05, + 0x04, 0xd2}, + }, + { + name: "pongv6", + m: &Pong{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + Src: mustIPPort("[fed0::12]:6666"), + }, + payload: []byte{ + 0x02, 0x00, + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, + 0xfe, 0xd0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, + 0x1a, 0x0a}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + back, err := Parse(tt.payload) + if err != nil { + t.Fatalf("parse back: %v", err) + } + if !reflect.DeepEqual(back, tt.m) { + t.Errorf("message in %+v doesn't match Parse result %+v", tt.m, back) + } + }) + } +} + func mustIPPort(s string) netip.AddrPort { ipp, err := netip.ParseAddrPort(s) if err != nil { diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 3c77a11353012..fc70b8ae6080c 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -2377,6 +2377,12 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur continue } trySetSocketBuffer(pconn, c.logf) + // CODER: https://github.com/coder/coder/issues/15523 + // Attempt to tell the OS not to fragment packets over this interface. We pad disco Ping and Pong packets to the + // size of the direct UDP packets that get sent for direct connections. Thus, any interfaces or paths that + // cannot fully support direct connections due to MTU limitations will not be selected. If no direct paths meet + // the MTU requirements for a peer, we will fall back to DERP for that peer. + tryPreventFragmentation(pconn, c.logf, network) // Success. if debugBindSocket() { c.logf("magicsock: bindSocket: successfully listened %v port %d", network, port) diff --git a/wgengine/magicsock/magicsock_darwin.go b/wgengine/magicsock/magicsock_darwin.go new file mode 100644 index 0000000000000..f3b32959f5ba8 --- /dev/null +++ b/wgengine/magicsock/magicsock_darwin.go @@ -0,0 +1,36 @@ +package magicsock + +import ( + "net" + + "golang.org/x/sys/unix" + "tailscale.com/types/logger" + "tailscale.com/types/nettype" +) + +func tryPreventFragmentation(pconn nettype.PacketConn, logf logger.Logf, network string) { + if c, ok := pconn.(*net.UDPConn); ok { + s, err := c.SyscallConn() + if err != nil { + logf("magicsock: dontfrag: failed to get syscall conn: %v", err) + } + level := unix.IPPROTO_IP + option := unix.IP_DONTFRAG + if network == "udp6" { + level = unix.IPPROTO_IPV6 + option = unix.IPV6_DONTFRAG + } + err = s.Control(func(fd uintptr) { + err := unix.SetsockoptInt(int(fd), level, option, 1) + if err != nil { + logf("magicsock: dontfrag: SetsockoptInt failed: %v", err) + } + }) + if err != nil { + logf("magicsock: dontfrag: control connection failed: %v", err) + } + logf("magicsock: dontfrag: success on %s", pconn.LocalAddr().String()) + return + } + logf("magicsock: dontfrag: failed because it was not a UDPConn") +} diff --git a/wgengine/magicsock/magicsock_linux.go b/wgengine/magicsock/magicsock_linux.go index a4101ccbaa69d..f1be3d1cf40e0 100644 --- a/wgengine/magicsock/magicsock_linux.go +++ b/wgengine/magicsock/magicsock_linux.go @@ -404,3 +404,30 @@ func init() { // message. These contain a single uint16 of data. controlMessageSize = unix.CmsgSpace(2) } + +func tryPreventFragmentation(pconn nettype.PacketConn, logf logger.Logf, network string) { + if c, ok := pconn.(*net.UDPConn); ok { + s, err := c.SyscallConn() + if err != nil { + logf("magicsock: dontfrag: failed to get syscall conn: %v", err) + } + level := unix.IPPROTO_IP + option := unix.IP_MTU_DISCOVER + if network == "udp6" { + level = unix.IPPROTO_IPV6 + option = unix.IPV6_MTU_DISCOVER + } + err = s.Control(func(fd uintptr) { + err := unix.SetsockoptInt(int(fd), level, option, unix.IP_PMTUDISC_DO) + if err != nil { + logf("magicsock: dontfrag: SetsockoptInt failed: %v", err) + } + }) + if err != nil { + logf("magicsock: dontfrag: control connection failed: %v", err) + } + logf("magicsock: dontfrag: success on %s", pconn.LocalAddr().String()) + return + } + logf("magicsock: dontfrag: failed because it was not a UDPConn") +} diff --git a/wgengine/magicsock/magicsock_windows.go b/wgengine/magicsock/magicsock_windows.go new file mode 100644 index 0000000000000..0206fbc47c896 --- /dev/null +++ b/wgengine/magicsock/magicsock_windows.go @@ -0,0 +1,40 @@ +package magicsock + +import ( + "net" + + "golang.org/x/sys/windows" + "tailscale.com/types/logger" + "tailscale.com/types/nettype" +) + +// https://github.com/tpn/winsdk-10/blob/9b69fd26ac0c7d0b83d378dba01080e93349c2ed/Include/10.0.16299.0/shared/ws2ipdef.h +const ( + IP_MTU_DISCOVER = 71 // IPV6_MTU_DISCOVER has the same value, which is nice. + IP_PMTUDISC_DO = 1 +) + +func tryPreventFragmentation(pconn nettype.PacketConn, logf logger.Logf, network string) { + if c, ok := pconn.(*net.UDPConn); ok { + s, err := c.SyscallConn() + if err != nil { + logf("magicsock: dontfrag: failed to get syscall conn: %v", err) + } + level := windows.IPPROTO_IP + if network == "udp6" { + level = windows.IPPROTO_IPV6 + } + err = s.Control(func(fd uintptr) { + err := windows.SetsockoptInt(windows.Handle(fd), level, IP_MTU_DISCOVER, IP_PMTUDISC_DO) + if err != nil { + logf("magicsock: dontfrag: SetsockoptInt failed: %v", err) + } + }) + if err != nil { + logf("magicsock: dontfrag: control connection failed: %v", err) + } + logf("magicsock: dontfrag: success on %s", pconn.LocalAddr().String()) + return + } + logf("magicsock: dontfrag: failed because it was not a UDPConn") +}