// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package stack

import (
	"encoding/binary"
	"math"
	"testing"
	"time"

	"gvisor.dev/gvisor/pkg/buffer"
	"gvisor.dev/gvisor/pkg/sync"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/faketime"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// Parameters of the toy network protocol exercised by the forwarding tests.
const (
	// fwdTestNetNumber is the network protocol number under which the test
	// protocol is registered (math.MaxUint32).
	fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32

	// fwdTestNetHeaderLen is the size, in bytes, of the test protocol's
	// network header.
	fwdTestNetHeaderLen = 12

	// fwdTestNetDefaultPrefixLen is the prefix length of the addresses
	// assigned to the test NICs.
	fwdTestNetDefaultPrefixLen = 8

	// fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
	// except where another value is explicitly used. It is chosen to match
	// the MTU of loopback interfaces on linux systems.
	fwdTestNetDefaultMTU = 65536
)

// Byte offsets of the fields inside the test protocol's network header.
const (
	dstAddrOffset        = 0
	srcAddrOffset        = 4
	protocolNumberOffset = 8
)

var (
	_ LinkAddressResolver = (*fwdTestNetworkEndpoint)(nil)
	_ NetworkEndpoint     = (*fwdTestNetworkEndpoint)(nil)
)

// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
//
// Headers of this protocol are fwdTestNetHeaderLen bytes long, but only the
// first nine bytes are used: a 4-byte destination address at dstAddrOffset,
// a 4-byte source address at srcAddrOffset, and a 1-byte transport protocol
// number at protocolNumberOffset. The remaining header bytes are neither read
// nor written by this endpoint.
type fwdTestNetworkEndpoint struct {
	AddressableEndpointState

	nic        NetworkInterface
	proto      *fwdTestNetworkProtocol
	dispatcher TransportDispatcher

	// mu guards the fields it contains.
	mu struct {
		sync.RWMutex
		// forwarding records whether forwarding is enabled on this endpoint.
		forwarding bool
	}
}

// Enable implements stack.NetworkEndpoint. There is nothing to set up, so it
// always succeeds.
func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { return nil }

// Enabled implements stack.NetworkEndpoint; the endpoint is always enabled.
func (*fwdTestNetworkEndpoint) Enabled() bool { return true }

// Disable implements stack.NetworkEndpoint and is a no-op.
func (*fwdTestNetworkEndpoint) Disable() {}

// MTU implements stack.NetworkEndpoint. It is the NIC's MTU less the header
// space reported by MaxHeaderLength.
func (f *fwdTestNetworkEndpoint) MTU() uint32 {
	linkMTU := f.nic.MTU()
	return linkMTU - uint32(f.MaxHeaderLength())
}

// DefaultTTL implements stack.NetworkEndpoint with a fixed, arbitrary value.
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
	const ttl = 123
	return ttl
}

// HandlePacket implements stack.NetworkEndpoint.
//
// A packet whose destination address is assigned to this endpoint (or that is
// accepted because the NIC is promiscuous) is delivered to the transport
// dispatcher. Any other packet is forwarded. The route to its destination is
// looked up, and a copy of the whole packet, with its network header, is
// written out through that route. Packets that fail to parse, or that have no
// route, are dropped.
func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
	if _, _, ok := f.proto.Parse(pkt); !ok {
		return
	}

	netHdr := pkt.NetworkHeader().Slice()
	_, dst := f.proto.ParseAddresses(netHdr)

	addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint, true /* readOnly */)
	if addressEndpoint != nil {
		// Dispatch the packet to the transport protocol.
		f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt)
		return
	}

	r, err := f.proto.stack.FindRoute(0, tcpip.Address{}, dst, fwdTestNetNumber, false /* multicastLoop */)
	if err != nil {
		return
	}
	defer r.Release()

	// fwdPkt is created (and therefore owned) here, so release our reference
	// once it has been handed to the route. Previously it was never released.
	// NOTE(review): this relies on anything that retains the packet past the
	// write (e.g. a queue waiting on link resolution, or the test link
	// endpoint's channel) taking its own reference — confirm.
	fwdPkt := NewPacketBuffer(PacketBufferOptions{
		ReserveHeaderBytes: int(r.MaxHeaderLength()),
		Payload:            pkt.ToBuffer(),
	})
	defer fwdPkt.DecRef()

	// TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets.
	// Write errors are deliberately ignored: an unforwardable packet is dropped.
	_ = r.WriteHeaderIncludedPacket(fwdPkt)
}

// MaxHeaderLength implements stack.NetworkEndpoint. It covers the NIC's own
// header space plus this protocol's fixed-size network header.
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
	linkHdrLen := f.nic.MaxHeaderLength()
	return linkHdrLen + fwdTestNetHeaderLen
}

// NetworkProtocolNumber implements stack.NetworkEndpoint by delegating to the
// protocol that created the endpoint.
func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
	owner := f.proto
	return owner.Number()
}

// WritePacket implements stack.NetworkEndpoint. It pushes a
// fwdTestNetHeaderLen-byte header and fills it in with the route's remote
// address, its local address, and the transport protocol. It then passes the
// packet to the NIC.
func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
	hdr := pkt.NetworkHeader().Push(fwdTestNetHeaderLen)
	dstAddr := r.RemoteAddress()
	srcAddr := r.LocalAddress()
	copy(hdr[dstAddrOffset:], dstAddr.AsSlice())
	copy(hdr[srcAddrOffset:], srcAddr.AsSlice())
	hdr[protocolNumberOffset] = byte(params.Protocol)

	pkt.NetworkProtocolNumber = fwdTestNetNumber
	return f.nic.WritePacket(r, pkt)
}

// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. The caller
// supplies the network header at the front of the packet data. That header
// is consumed into the packet's network header before the packet goes to the
// NIC. If the data is too short to hold a full header, ErrMalformedHeader is
// returned.
func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error {
	if _, consumed := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !consumed {
		return &tcpip.ErrMalformedHeader{}
	}

	pkt.NetworkProtocolNumber = fwdTestNetNumber
	return f.nic.WritePacket(r, pkt)
}

// Close implements stack.NetworkEndpoint. It cleans up the endpoint's
// addressable state.
func (f *fwdTestNetworkEndpoint) Close() {
	f.AddressableEndpointState.Cleanup()
}

// Stats implements stack.NetworkEndpoint. The test protocol tracks no
// statistics, so a fresh empty stats value is returned.
func (*fwdTestNetworkEndpoint) Stats() NetworkEndpointStats {
	return &fwdTestNetworkEndpointStats{}
}

var _ NetworkEndpointStats = (*fwdTestNetworkEndpointStats)(nil)

// fwdTestNetworkEndpointStats is the empty stats type returned by
// fwdTestNetworkEndpoint.Stats.
type fwdTestNetworkEndpointStats struct{}

// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {}

var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)

// fwdTestNetworkProtocol is a network-layer protocol whose endpoints perform
// link address resolution. Tests control resolution through the hooks below.
type fwdTestNetworkProtocol struct {
	// stack is the stack the protocol is registered with.
	stack *Stack

	// neigh is the neighbor cache passed to onLinkAddressResolved.
	neigh *neighborCache

	// addrResolveDelay is how long LinkAddressRequest waits, on the stack's
	// clock, before invoking onLinkAddressResolved.
	addrResolveDelay time.Duration

	// onLinkAddressResolved, if non-nil, is invoked for each link address
	// request with neigh, the requested address, and the remote link address.
	onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress)

	// onResolveStaticAddress, if non-nil, backs ResolveStaticAddress.
	onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
}

// Number implements stack.NetworkProtocol.
func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fwdTestNetNumber }

// MinimumPacketSize implements stack.NetworkProtocol: a packet must be at
// least one full network header long.
func (*fwdTestNetworkProtocol) MinimumPacketSize() int { return fwdTestNetHeaderLen }

// ParseAddresses implements stack.NetworkProtocol, reading the 4-byte source
// and destination addresses out of the network header v.
func (*fwdTestNetworkProtocol) ParseAddresses(v []byte) (src, dst tcpip.Address) {
	src = tcpip.AddrFrom4Slice(v[srcAddrOffset : srcAddrOffset+4])
	dst = tcpip.AddrFrom4Slice(v[dstAddrOffset : dstAddrOffset+4])
	return src, dst
}

// Parse implements stack.NetworkProtocol. It consumes a network header from
// pkt and reports the transport protocol recorded in it. Parsing fails only
// when pkt is too short to hold a header.
func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
	hdr, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen)
	if ok {
		return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
	}
	return 0, false, false
}

// NewEndpoint implements stack.NetworkProtocol.
func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint {
	ep := &fwdTestNetworkEndpoint{
		nic:        nic,
		proto:      f,
		dispatcher: dispatcher,
	}
	opts := AddressableEndpointStateOptions{HiddenWhileDisabled: false}
	ep.AddressableEndpointState.Init(ep, opts)
	return ep
}

// SetOption implements stack.NetworkProtocol; no options are supported.
func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error {
	return &tcpip.ErrUnknownProtocolOption{}
}

// Option implements stack.NetworkProtocol; no options are supported.
func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error {
	return &tcpip.ErrUnknownProtocolOption{}
}

// Close implements stack.NetworkProtocol and is a no-op.
func (*fwdTestNetworkProtocol) Close() {}

// Wait implements stack.NetworkProtocol and returns immediately.
func (*fwdTestNetworkProtocol) Wait() {}

// LinkAddressRequest implements stack.LinkAddressResolver. If the protocol has
// an onLinkAddressResolved hook, the hook is scheduled on the stack's clock to
// run after addrResolveDelay. Otherwise the request is silently ignored. The
// request itself never fails.
func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
	onResolved := f.proto.onLinkAddressResolved
	if onResolved == nil {
		return nil
	}
	f.proto.stack.clock.AfterFunc(f.proto.addrResolveDelay, func() {
		onResolved(f.proto.neigh, addr, remoteLinkAddr)
	})
	return nil
}

// ResolveStaticAddress implements stack.LinkAddressResolver by deferring to
// the protocol's onResolveStaticAddress hook. Without a hook, nothing
// resolves statically.
func (f *fwdTestNetworkEndpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
	resolve := f.proto.onResolveStaticAddress
	if resolve == nil {
		return "", false
	}
	return resolve(addr)
}

// LinkAddressProtocol implements stack.LinkAddressResolver.
func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return fwdTestNetNumber }

// Forwarding implements stack.ForwardingNetworkEndpoint.
func (f *fwdTestNetworkEndpoint) Forwarding() bool {
	f.mu.RLock()
	enabled := f.mu.forwarding
	f.mu.RUnlock()
	return enabled
}

// SetForwarding implements stack.ForwardingNetworkEndpoint. It stores v and
// returns the previous forwarding state.
func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) bool {
	f.mu.Lock()
	old := f.mu.forwarding
	f.mu.forwarding = v
	f.mu.Unlock()
	return old
}

var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil)

// fwdTestLinkEndpoint is a link endpoint for the forwarding tests. Inbound
// packets are injected straight into the attached dispatcher, and outbound
// packets are captured on C instead of being transmitted.
type fwdTestLinkEndpoint struct {
	// dispatcher receives injected inbound packets once Attach is called.
	dispatcher NetworkDispatcher
	// mtu is the value reported by MTU.
	mtu uint32
	// linkAddr is the value reported by LinkAddress.
	linkAddr tcpip.LinkAddress

	// C is where outbound packets are queued.
	C chan *PacketBuffer
}

// InjectInbound injects an inbound packet with no remote link address.
func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
	e.InjectLinkAddr(protocol, "", pkt)
}

// InjectLinkAddr injects an inbound packet with a remote link address. The
// remote address is not used. The packet is delivered directly to the
// attached dispatcher, which must have been set by Attach.
func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) {
	e.dispatcher.DeliverNetworkPacket(protocol, pkt)
}

// Attach saves the stack network-layer dispatcher for use later when packets
// are injected.
func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
	e.dispatcher = dispatcher
}

// IsAttached implements stack.LinkEndpoint.IsAttached.
func (e *fwdTestLinkEndpoint) IsAttached() bool {
	return e.dispatcher != nil
}

// MTU implements stack.LinkEndpoint.MTU.
func (e *fwdTestLinkEndpoint) MTU() uint32 {
	return e.mtu
}

// SetMTU implements stack.LinkEndpoint.SetMTU.
func (e *fwdTestLinkEndpoint) SetMTU(mtu uint32) {
	e.mtu = mtu
}

// Capabilities implements stack.LinkEndpoint.Capabilities. The only
// capability advertised is CapabilityResolutionRequired, so link addresses
// must be resolved before packets are written through this endpoint.
//
// A pointer receiver is used, like every other method on this type.
func (*fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
	return CapabilityResolutionRequired
}

// MaxHeaderLength returns the maximum size of the link layer header. Given it
// doesn't have a header, it just returns 0.
func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
	return 0
}

// LinkAddress returns the link address of this endpoint.
func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
	return e.linkAddr
}

// SetLinkAddress sets the link address of this endpoint.
func (e *fwdTestLinkEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
	e.linkAddr = addr
}

// WritePackets implements stack.LinkEndpoint. It stores outbound packets in
// C, each carrying an extra reference that the reader of C then owns. When C
// is full, the packet is dropped. Dropped packets are still counted as
// written, so the stack never sees a write error.
func (e *fwdTestLinkEndpoint) WritePackets(pkts PacketBufferList) (int, tcpip.Error) {
	n := 0
	for _, pkt := range pkts.AsSlice() {
		// Take the reference before the select. The value of a send case is
		// evaluated even when the default case is chosen, so calling IncRef
		// inside the send case leaked a reference for every dropped packet.
		pkt.IncRef()
		select {
		case e.C <- pkt:
		default:
			// C is full: drop the packet and return the reference taken above.
			pkt.DecRef()
		}
		n++
	}
	return n, nil
}

// Wait implements stack.LinkEndpoint.Wait. There is no background work to
// wait for.
func (*fwdTestLinkEndpoint) Wait() {}

// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. It is not
// expected to be called and panics if it is.
func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
	panic("not implemented")
}

// AddHeader implements stack.LinkEndpoint.AddHeader. The endpoint has no
// link-layer header, so nothing is added.
func (*fwdTestLinkEndpoint) AddHeader(*PacketBuffer) {}

// ParseHeader implements stack.LinkEndpoint.ParseHeader. With no link-layer
// header to parse, every packet is accepted.
func (*fwdTestLinkEndpoint) ParseHeader(*PacketBuffer) bool {
	return true
}

// Close implements stack.LinkEndpoint.Close and is a no-op.
func (*fwdTestLinkEndpoint) Close() {}

// SetOnCloseAction implements stack.LinkEndpoint.SetOnCloseAction. The action
// is discarded.
func (*fwdTestLinkEndpoint) SetOnCloseAction(func()) {}

// fwdTestNetFactory builds a stack for the forwarding tests. The stack runs
// on a manual clock, registers proto as its only network protocol, and has
// forwarding enabled. It has two NICs:
//
//   - NIC 1: link address "a", network address 1 (returned as ep1).
//   - NIC 2: link address "b", network address 2 (returned as ep2).
//
// Every destination is routed through NIC 2. If NIC 2 has a link address
// resolver for the test protocol, its neighbor cache is stored in
// proto.neigh. The test fails immediately if any setup step errors.
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.ManualClock, *fwdTestLinkEndpoint, *fwdTestLinkEndpoint) {
	// Report setup failures at the calling test's line, not inside this helper.
	t.Helper()

	clock := faketime.NewManualClock()
	// Create a stack with the network protocol and two NICs.
	s := New(Options{
		NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol {
			proto.stack = s
			return proto
		}},
		Clock: clock,
	})

	protoNum := proto.Number()
	if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil {
		t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err)
	}

	// NIC 1 has the link address "a", and added the network address 1.
	ep1 := &fwdTestLinkEndpoint{
		C:        make(chan *PacketBuffer, 300),
		mtu:      fwdTestNetDefaultMTU,
		linkAddr: "a",
	}
	if err := s.CreateNIC(1, ep1); err != nil {
		t.Fatal("CreateNIC #1 failed:", err)
	}
	protocolAddr1 := tcpip.ProtocolAddress{
		Protocol: fwdTestNetNumber,
		AddressWithPrefix: tcpip.AddressWithPrefix{
			Address:   tcpip.AddrFrom4Slice([]byte("\x01\x00\x00\x00")),
			PrefixLen: fwdTestNetDefaultPrefixLen,
		},
	}
	if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil {
		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
	}

	// NIC 2 has the link address "b", and added the network address 2.
	ep2 := &fwdTestLinkEndpoint{
		C:        make(chan *PacketBuffer, 300),
		mtu:      fwdTestNetDefaultMTU,
		linkAddr: "b",
	}
	if err := s.CreateNIC(2, ep2); err != nil {
		t.Fatal("CreateNIC #2 failed:", err)
	}
	protocolAddr2 := tcpip.ProtocolAddress{
		Protocol: fwdTestNetNumber,
		AddressWithPrefix: tcpip.AddressWithPrefix{
			Address:   tcpip.AddrFrom4Slice([]byte("\x02\x00\x00\x00")),
			PrefixLen: fwdTestNetDefaultPrefixLen,
		},
	}
	if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil {
		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
	}

	s.mu.RLock()
	nic, ok := s.nics[2]
	s.mu.RUnlock()
	if !ok {
		t.Fatal("NIC 2 does not exist")
	}

	// Expose NIC 2's neighbor cache so resolution hooks can confirm entries.
	if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok {
		proto.neigh = &l.neigh
	}

	// Route all packets to NIC 2.
	{
		subnet, err := tcpip.NewSubnet(tcpip.AddrFrom4Slice([]byte("\x00\x00\x00\x00")), tcpip.MaskFrom("\x00\x00\x00\x00"))
		if err != nil {
			t.Fatal(err)
		}
		s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}})
	}

	return clock, ep1, ep2
}

func TestForwardingWithStaticResolver(t *testing.T) {
	// Create a network protocol with a static resolver.
	proto := &fwdTestNetworkProtocol{
		onResolveStaticAddress:
		// The network address 3 is resolved to the link address "c".
		func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
			if addr == tcpip.AddrFrom4Slice([]byte("\x03\x00\x00\x00")) {
				return "c", true
			}
			return "", false
		},
	}

	clock, ep1, ep2 := fwdTestNetFactory(t, proto)

	// Inject an inbound packet to address 3 on NIC 1, and see if it is
	// forwarded to NIC 2.
	buf := make([]byte, 30)
	copy(buf[dstAddrOffset:], []byte("\x03\x00\x00\x00"))
	ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
		Payload: buffer.MakeWithData(buf),
	}))

	var p *PacketBuffer

	clock.Advance(proto.addrResolveDelay)
	select {
	case p = <-ep2.C:
	default:
		t.Fatal("packet not forwarded")
	}

	// Test that the static address resolution happened correctly.
	if p.EgressRoute.RemoteLinkAddress != "c" {
		t.Fatalf("got p.EgressRoute.RemoteLinkAddress = %s, want = c", p.EgressRoute.RemoteLinkAddress)
	}
	if p.EgressRoute.LocalLinkAddress != "b" {
		t.Fatalf("got p.EgressRoute.LocalLinkAddress = %s, want = b", p.EgressRoute.LocalLinkAddress)
	}
}

// TestForwardingWithFakeResolver checks that a packet is forwarded once the
// fake resolver confirms a link address for its destination.
func TestForwardingWithFakeResolver(t *testing.T) {
	proto := &fwdTestNetworkProtocol{
		addrResolveDelay: 500 * time.Millisecond,
		onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
			t.Helper()
			if len(linkAddr) != 0 {
				t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
			}
			// Every address resolves to link address "c".
			flags := ReachabilityConfirmationFlags{
				Solicited: true,
				Override:  false,
				IsRouter:  false,
			}
			neigh.handleConfirmation(addr, "c", flags)
		},
	}
	clock, ep1, ep2 := fwdTestNetFactory(t, proto)

	// Send a packet for address 3 in through NIC 1; it should leave via NIC 2.
	payload := make([]byte, 30)
	payload[dstAddrOffset] = 3
	ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
		Payload: buffer.MakeWithData(payload),
	}))

	clock.Advance(proto.addrResolveDelay)
	var p *PacketBuffer
	select {
	case p = <-ep2.C:
	default:
		t.Fatal("packet not forwarded")
	}

	// The egress route must carry the resolved link address.
	if got := p.EgressRoute.RemoteLinkAddress; got != "c" {
		t.Fatalf("got p.EgressRoute.RemoteLinkAddress = %s, want = c", got)
	}
	if got := p.EgressRoute.LocalLinkAddress; got != "b" {
		t.Fatalf("got p.EgressRoute.LocalLinkAddress = %s, want = b", got)
	}
}

// TestForwardingWithNoResolver checks that nothing is forwarded when the
// protocol has neither a static nor a dynamic resolution hook.
func TestForwardingWithNoResolver(t *testing.T) {
	// With no hooks set, LinkAddressRequest does nothing, so resolution for
	// the destination never completes.
	proto := &fwdTestNetworkProtocol{}
	clock, ep1, ep2 := fwdTestNetFactory(t, proto)

	// Send a packet for address 3 in through NIC 1; it must not reach NIC 2.
	payload := make([]byte, 30)
	payload[dstAddrOffset] = 3
	ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
		Payload: buffer.MakeWithData(payload),
	}))

	clock.Advance(proto.addrResolveDelay)
	select {
	case <-ep2.C:
		t.Fatal("Packet should not be forwarded")
	default:
	}
}

// TestForwardingResolutionFailsForQueuedPackets checks that packets queued
// behind a link resolution that never completes are never forwarded.
func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
	proto := &fwdTestNetworkProtocol{
		addrResolveDelay: 50 * time.Millisecond,
		// The hook deliberately leaves the link address unresolved.
		onLinkAddressResolved: func(*neighborCache, tcpip.Address, tcpip.LinkAddress) {},
	}

	clock, ep1, ep2 := fwdTestNetFactory(t, proto)

	const numPackets int = 5

	// Every packet goes to address 3, so all of them wait in the link
	// resolution queue.
	for n := 0; n < numPackets; n++ {
		payload := make([]byte, 30)
		payload[dstAddrOffset] = 3
		ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
			Payload: buffer.MakeWithData(payload),
		}))
	}

	// None of the queued packets may ever come out of NIC 2.
	for n := 0; n < numPackets; n++ {
		clock.Advance(proto.addrResolveDelay)
		select {
		case got := <-ep2.C:
			t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got)
		default:
		}
	}
}

func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
	proto := fwdTestNetworkProtocol{
		addrResolveDelay: 500 * time.Millisecond,
		onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
			t.Helper()
			if len(linkAddr) != 0 {
				t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
			}
			// Only packets to address 3 will be resolved to the
			// link address "c".
			if addr == tcpip.AddrFrom4Slice([]byte("\x03\x00\x00\x00")) {
				neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
					Solicited: true,
					Override:  false,
					IsRouter:  false,
				})
			}
		},
	}
	clock, ep1, ep2 := fwdTestNetFactory(t, &proto)

	// Inject an inbound packet to address 4 on NIC 1. This packet should
	// not be forwarded.
	buf := make([]byte, 30)
	buf[dstAddrOffset] = 4
	ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
		Payload: buffer.MakeWithData(buf),
	}))

	// Inject an inbound packet to address 3 on NIC 1, and see if it is
	// forwarded to NIC 2.
	buf = make([]byte, 30)
	buf[dstAddrOffset] = 3
	ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
		Payload: buffer.MakeWithData(buf),
	}))

	var p *PacketBuffer

	clock.Advance(proto.addrResolveDelay)
	select {
	case p = <-ep2.C:
	default:
		t.Fatal("packet not forwarded")
	}

	nh := PayloadSince(p.NetworkHeader())
	defer nh.Release()
	if nh.AsSlice()[dstAddrOffset] != 3 {
		t.Fatalf("got p.NetworkHeader[dstAddrOffset] = %d, want = 3", nh.AsSlice()[dstAddrOffset])
	}

	// Test that the address resolution happened correctly.
	if p.EgressRoute.RemoteLinkAddress != "c" {
		t.Fatalf("got p.EgressRoute.RemoteLinkAddress = %s, want = c", p.EgressRoute.RemoteLinkAddress)
	}
	if p.EgressRoute.LocalLinkAddress != "b" {
		t.Fatalf("got p.EgressRoute.LocalLinkAddress = %s, want = b", p.EgressRoute.LocalLinkAddress)
	}
}

// TestForwardingWithFakeResolverTwoPackets checks that two packets queued
// behind the same link resolution are both forwarded once it completes.
func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
	proto := fwdTestNetworkProtocol{
		addrResolveDelay: 500 * time.Millisecond,
		onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
			t.Helper()
			if len(linkAddr) != 0 {
				t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
			}
			// Any packets will be resolved to the link address "c".
			neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
				Solicited: true,
				Override:  false,
				IsRouter:  false,
			})
		},
	}
	clock, ep1, ep2 := fwdTestNetFactory(t, &proto)

	// Inject two inbound packets to address 3 on NIC 1.
	for i := 0; i < 2; i++ {
		buf := make([]byte, 30)
		buf[dstAddrOffset] = 3
		ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
			Payload: buffer.MakeWithData(buf),
		}))
	}

	// checkForwarded receives and verifies the next forwarded packet. It is a
	// separate function so that the header view's deferred Release runs at
	// the end of each check. Previously all releases were deferred until the
	// test returned.
	checkForwarded := func() {
		var p *PacketBuffer

		clock.Advance(proto.addrResolveDelay)
		select {
		case p = <-ep2.C:
		default:
			t.Fatal("packet not forwarded")
		}

		nh := PayloadSince(p.NetworkHeader())
		defer nh.Release()
		if nh.AsSlice()[dstAddrOffset] != 3 {
			t.Fatalf("got p.NetworkHeader[dstAddrOffset] = %d, want = 3", nh.AsSlice()[dstAddrOffset])
		}

		// Test that the address resolution happened correctly.
		if p.EgressRoute.RemoteLinkAddress != "c" {
			t.Fatalf("got p.EgressRoute.RemoteLinkAddress = %s, want = c", p.EgressRoute.RemoteLinkAddress)
		}
		if p.EgressRoute.LocalLinkAddress != "b" {
			t.Fatalf("got p.EgressRoute.LocalLinkAddress = %s, want = b", p.EgressRoute.LocalLinkAddress)
		}
	}

	for i := 0; i < 2; i++ {
		checkForwarded()
	}
}

// TestForwardingWithFakeResolverManyPackets checks that when more packets are
// queued than the per-resolution limit, the oldest ones are dropped and the
// newest maxPendingPacketsPerResolution are forwarded in order.
func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
	proto := fwdTestNetworkProtocol{
		addrResolveDelay: 500 * time.Millisecond,
		onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
			t.Helper()
			if len(linkAddr) != 0 {
				t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
			}
			// Any packets will be resolved to the link address "c".
			neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
				Solicited: true,
				Override:  false,
				IsRouter:  false,
			})
		},
	}
	clock, ep1, ep2 := fwdTestNetFactory(t, &proto)

	for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
		// Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
		buf := make([]byte, 30)
		buf[dstAddrOffset] = 3
		// Set the packet sequence number.
		binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
		ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
			Payload: buffer.MakeWithData(buf),
		}))
	}

	// checkForwarded receives the next forwarded packet and verifies that it
	// carries sequence number want. It is a separate function so that the
	// header view's deferred Release runs at the end of each check. Previously
	// all releases were deferred until the test returned.
	checkForwarded := func(want uint16) {
		var p *PacketBuffer

		clock.Advance(proto.addrResolveDelay)
		select {
		case p = <-ep2.C:
		default:
			t.Fatal("packet not forwarded")
		}

		b := PayloadSince(p.NetworkHeader())
		defer b.Release()
		if b.AsSlice()[dstAddrOffset] != 3 {
			t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b.AsSlice()[dstAddrOffset])
		}
		if b.Size() < fwdTestNetHeaderLen+2 {
			t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b.Size())
		}
		seqNumBuf := b.AsSlice()[fwdTestNetHeaderLen:]

		if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
			t.Fatalf("got the packet #%d, want = #%d", n, want)
		}

		// Test that the address resolution happened correctly.
		if p.EgressRoute.RemoteLinkAddress != "c" {
			t.Fatalf("got p.EgressRoute.RemoteLinkAddress = %s, want = c", p.EgressRoute.RemoteLinkAddress)
		}
		if p.EgressRoute.LocalLinkAddress != "b" {
			t.Fatalf("got p.EgressRoute.LocalLinkAddress = %s, want = b", p.EgressRoute.LocalLinkAddress)
		}
	}

	for i := 0; i < maxPendingPacketsPerResolution; i++ {
		// The first 5 packets should not be forwarded so the sequence number should
		// start with 5.
		checkForwarded(uint16(i + 5))
	}
}

// TestForwardingWithFakeResolverManyResolutions checks that when more
// destinations need resolution than maxPendingResolutions allows, the oldest
// resolutions are abandoned and only packets to the newest destinations are
// forwarded.
func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
	proto := fwdTestNetworkProtocol{
		addrResolveDelay: 500 * time.Millisecond,
		onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
			t.Helper()
			if len(linkAddr) != 0 {
				t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
			}
			// Any packets will be resolved to the link address "c".
			neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
				Solicited: true,
				Override:  false,
				IsRouter:  false,
			})
		},
	}
	clock, ep1, ep2 := fwdTestNetFactory(t, &proto)

	for i := 0; i < maxPendingResolutions+5; i++ {
		// Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
		// Each packet has a different destination address (3 to
		// maxPendingResolutions + 7).
		buf := make([]byte, 30)
		buf[dstAddrOffset] = byte(3 + i)
		ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
			Payload: buffer.MakeWithData(buf),
		}))
	}

	// checkForwarded receives and verifies the next forwarded packet. It is a
	// separate function so that the header view's deferred Release runs at
	// the end of each check. Previously all releases were deferred until the
	// test returned.
	checkForwarded := func() {
		var p *PacketBuffer

		clock.Advance(proto.addrResolveDelay)
		select {
		case p = <-ep2.C:
		default:
			t.Fatal("packet not forwarded")
		}

		// The first 5 packets (address 3 to 7) should not be forwarded
		// because their address resolutions are interrupted.
		nh := PayloadSince(p.NetworkHeader())
		defer nh.Release()
		if nh.AsSlice()[dstAddrOffset] < 8 {
			t.Fatalf("got p.NetworkHeader[dstAddrOffset] = %d, want p.NetworkHeader[dstAddrOffset] >= 8", nh.AsSlice()[dstAddrOffset])
		}

		// Test that the address resolution happened correctly.
		if p.EgressRoute.RemoteLinkAddress != "c" {
			t.Fatalf("got p.EgressRoute.RemoteLinkAddress = %s, want = c", p.EgressRoute.RemoteLinkAddress)
		}
		if p.EgressRoute.LocalLinkAddress != "b" {
			t.Fatalf("got p.EgressRoute.LocalLinkAddress = %s, want = b", p.EgressRoute.LocalLinkAddress)
		}
	}

	for i := 0; i < maxPendingResolutions; i++ {
		checkForwarded()
	}
}
