From 1b9add1bec9b281d2f7dd8a2c3d8a33e9f37c784 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 26 Jun 2025 13:49:44 +0800 Subject: [PATCH 01/14] fix retransmission logic for path probing packets (#5241) To achieve an exponential backoff, the timer should only be reset after having fired. --- path_manager_outgoing.go | 2 +- path_manager_outgoing_test.go | 41 +++++++++++++++++++++-------------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/path_manager_outgoing.go b/path_manager_outgoing.go index 595eda237a4..e8c065ffb93 100644 --- a/path_manager_outgoing.go +++ b/path_manager_outgoing.go @@ -50,6 +50,7 @@ func (p *Path) Probe(ctx context.Context) error { p.validated.Store(true) return nil case <-timerChan: + nextProbeDur *= 2 // exponential backoff p.pathManager.enqueueProbe(p) case <-path.ProbeSent(): case <-p.abandon: @@ -61,7 +62,6 @@ func (p *Path) Probe(ctx context.Context) error { } timer = time.NewTimer(nextProbeDur) timerChan = timer.C - nextProbeDur *= 2 // exponential backoff } } diff --git a/path_manager_outgoing_test.go b/path_manager_outgoing_test.go index 3e965b06a81..fc717890973 100644 --- a/path_manager_outgoing_test.go +++ b/path_manager_outgoing_test.go @@ -119,7 +119,7 @@ func TestPathManagerOutgoingRetransmissions(t *testing.T) { require.False(t, ok) tr1 := &Transport{} - initialRTT := scaleDuration(2 * time.Millisecond) + initialRTT := scaleDuration(5 * time.Millisecond) p := pm.NewPath(tr1, initialRTT, func() {}) pathChallengeChan := make(chan [8]byte) @@ -146,31 +146,40 @@ func TestPathManagerOutgoingRetransmissions(t *testing.T) { go func() { errChan <- p.Probe(context.Background()) }() start := time.Now() - var pathChallenges [][8]byte + type result struct { + pc *[8]byte + took time.Duration + } + var results []result for range 4 { select { case err := <-errChan: require.NoError(t, err) case pc := <-pathChallengeChan: - pathChallenges = append(pathChallenges, pc) + results = append(results, result{pc: &pc, took: time.Since(start)}) case <-time.After(scaleDuration(time.Second)): t.Fatal("timeout") } } - took := time.Since(start) - - require.NotContains(t, pathChallenges, [8]byte{}) - require.NotEqual(t, pathChallenges[0], pathChallenges[1]) - require.NotEqual(t, pathChallenges[0], pathChallenges[2]) - require.NotEqual(t, pathChallenges[0], pathChallenges[3]) - require.NotEqual(t, pathChallenges[1], pathChallenges[2]) - require.NotEqual(t, pathChallenges[2], pathChallenges[3]) - require.Greater(t, took, initialRTT*(1+2+4+8)) - require.Less(t, took, initialRTT*(1+2+4+8)*3/2) + for i, r1 := range results { + require.NotNil(t, r1.pc) + if i > 0 { + took := r1.took - results[i-1].took + t.Log("took", took) + require.Greater(t, took, initialRTT<<(i-1)) + require.Less(t, took, initialRTT< Date: Thu, 26 Jun 2025 14:42:08 +0800 Subject: [PATCH 02/14] implement receiver side behavior for RESET_STREAM_AT (#5235) * implement receiver side behavior for RESET_STREAM_AT * simplify reliable offset tracking --- receive_stream.go | 40 ++++++-- receive_stream_test.go | 211 ++++++++++++++++++++++++++++++++++------- 2 files changed, 205 insertions(+), 46 deletions(-) diff --git a/receive_stream.go b/receive_stream.go index 3b2d618ce57..61f4cdf4244 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -42,6 +42,9 @@ type ReceiveStream struct { cancelErr *StreamError closeForShutdownErr error + readPos protocol.ByteCount + reliableSize protocol.ByteCount + readChan chan struct{} readOnce chan struct{} // cap: 1, to protect against concurrent use of Read deadline time.Time @@ -128,7 +131,7 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW s.errorRead = true return false, false, 0, io.EOF } - if s.cancelledRemotely || s.cancelledLocally { + if s.cancelledLocally || (s.cancelledRemotely && s.readPos >= s.reliableSize) { s.errorRead = true return false, false, 0, s.cancelErr } @@ -151,9 +154,9 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW if s.closeForShutdownErr != nil { return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr } - if s.cancelledRemotely || s.cancelledLocally { + if s.cancelledLocally || (s.cancelledRemotely && s.readPos >= s.reliableSize) { s.errorRead = true - return hasStreamWindowUpdate, hasConnWindowUpdate, 0, s.cancelErr + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr } deadline := s.deadline @@ -194,14 +197,11 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW if s.readPosInFrame > len(s.currentFrame) { return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) } - m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:]) - s.readPosInFrame += m - bytesRead += m // when a RESET_STREAM was received, the flow controller was already - // informed about the final byteOffset for this stream - if !s.cancelledRemotely { + // informed about the final offset for this stream + if !s.cancelledRemotely || s.readPos < s.reliableSize { hasStream, hasConn := s.flowController.AddBytesRead(protocol.ByteCount(m)) if hasStream { s.queuedMaxStreamData = true @@ -212,6 +212,14 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW } } + s.readPosInFrame += m + s.readPos += protocol.ByteCount(m) + bytesRead += m + + if s.cancelledRemotely && s.readPos >= s.reliableSize { + s.flowController.Abandon() + } + if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast { s.currentFrame = nil if s.currentFrameDone != nil { @@ -221,6 +229,10 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, io.EOF } } + if s.cancelledRemotely && s.readPos >= s.reliableSize { + s.errorRead = true + return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr + } return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, nil } @@ -231,7 +243,7 @@ func (s *ReceiveStream) dequeueNextFrame() { s.currentFrameDone() } offset, s.currentFrame, s.currentFrameDone = s.frameQueue.Pop() - s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset + s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset && !s.cancelledRemotely s.readPosInFrame = 0 } @@ -323,11 +335,19 @@ func (s *ReceiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame, } s.finalOffset = frame.FinalSize + // senders are allowed to reduce the reliable size, but frames might have been reordered + if (!s.cancelledRemotely && s.reliableSize == 0) || frame.ReliableSize < s.reliableSize { + s.reliableSize = frame.ReliableSize + } + if s.readPos >= s.reliableSize { + // calling Abandon multiple times is a no-op + s.flowController.Abandon() + } // ignore duplicate RESET_STREAM frames for this stream (after checking their final offset) if s.cancelledRemotely { return nil } - s.flowController.Abandon() + // don't save the error if the RESET_STREAM frames was received after CancelRead was called if s.cancelledLocally { return nil diff --git a/receive_stream_test.go b/receive_stream_test.go index 2b9f38d2aa0..3619fa3ca40 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -447,7 +447,16 @@ func TestReceiveStreamCancellation(t *testing.T) { require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false}) } -func TestReceiveStreamCancelReadAfterFINReceived(t *testing.T) { +func TestReceiveStreamCancelReadAfterFIN(t *testing.T) { + t.Run("FIN not read", func(t *testing.T) { + testReceiveStreamCancelReadAfterFIN(t, false) + }) + t.Run("FIN read", func(t *testing.T) { + testReceiveStreamCancelReadAfterFIN(t, true) + }) +} + +func testReceiveStreamCancelReadAfterFIN(t *testing.T, finRead bool) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) @@ -456,46 +465,38 @@ func TestReceiveStreamCancelReadAfterFINReceived(t *testing.T) { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()) mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) + if finRead { + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) + n, err := str.Read(make([]byte, 10)) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 6, n) + } // if the FIN was received, but not read yet, a STOP_SENDING frame is queued - mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), str) - mockFC.EXPECT().Abandon() + if !finRead { + mockFC.EXPECT().Abandon() + mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), str) + } str.CancelRead(1337) f, ok, hasMore := str.getControlFrame(time.Now()) - require.True(t, ok) - require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1337}, f.Frame) - require.False(t, hasMore) - - // Read returns the error - n, err := str.Read([]byte{0}) - require.Zero(t, n) - require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: false}) -} - -func TestReceiveStreamCancelReadAfterFINRead(t *testing.T) { - mockCtrl := gomock.NewController(t) - mockFC := mocks.NewMockStreamFlowController(mockCtrl) - mockSender := NewMockStreamSender(mockCtrl) - str := newReceiveStream(42, mockSender, mockFC) - - mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any()) - mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6)) - mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) - require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) - n, err := str.Read(make([]byte, 10)) - require.ErrorIs(t, err, io.EOF) - require.Equal(t, 6, n) - // if the EOF was already read, no STOP_SENDING frame is queued - str.CancelRead(1234) - _, ok, hasMore := str.getControlFrame(time.Now()) - require.False(t, ok) - require.False(t, hasMore) + if finRead { + require.False(t, ok) + require.False(t, hasMore) + } else { + require.True(t, ok) + require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1337}, f.Frame) + require.False(t, hasMore) + } // Read returns the error - n, err = str.Read([]byte{0}) + n, err := str.Read([]byte{0}) require.Zero(t, n) - require.ErrorIs(t, err, io.EOF) + if finRead { + require.ErrorIs(t, err, io.EOF) + } else { + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: false}) + } } func TestReceiveStreamReset(t *testing.T) { @@ -520,7 +521,7 @@ func TestReceiveStreamReset(t *testing.T) { mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) gomock.InOrder( mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()), - mockFC.EXPECT().Abandon(), + mockFC.EXPECT().Abandon().MinTimes(1), ) require.NoError(t, str.handleResetStreamFrame( &wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1234, FinalSize: 42}, @@ -616,14 +617,14 @@ func TestReceiveStreamConcurrentReads(t *testing.T) { const num = 3 errChan := make(chan error, num) - for i := 0; i < num; i++ { + for range num { go func() { _, err := str.Read(make([]byte, 8)) errChan <- err }() } require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now())) - for i := 0; i < num; i++ { + for range num { select { case err := <-errChan: require.ErrorIs(t, err, io.EOF) @@ -634,3 +635,141 @@ func TestReceiveStreamConcurrentReads(t *testing.T) { require.Equal(t, protocol.ByteCount(6), bytesRead) require.Equal(t, int32(1), numCompleted.Load()) } + +func TestReceiveStreamResetStreamAtBeforeReadOffset(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)) + b := make([]byte, 3) + n, err := str.Read(b) + require.NoError(t, err) + require.Equal(t, 3, n) + require.Equal(t, []byte("foo"), b) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + mockFC.EXPECT().Abandon() + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 3}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // Read returns the error + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + n, err = str.Read([]byte{0}) + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) + require.Zero(t, n) +} + +func TestReceiveStreamResetStreamAtAfterReadOffset(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())) + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + b := make([]byte, 2) + n, err := str.Read(b) + require.NoError(t, err) + require.Equal(t, 2, n) + require.Equal(t, []byte("fo"), b) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 6}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // Read returns the error + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)) + n, err = str.Read(b) + require.NoError(t, err) + require.Equal(t, 2, n) + require.Equal(t, []byte("ob"), b) + require.True(t, mockCtrl.Satisfied()) + + gomock.InOrder( + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)), + mockFC.EXPECT().Abandon(), + ) + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + n, err = str.Read(b) + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) + require.Equal(t, 2, n) + require.Equal(t, []byte("ar"), b) +} + +func TestReceiveStreamMultipleResetStreamAt(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())) + + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)) + b := make([]byte, 3) + n, err := str.Read(b) + require.NoError(t, err) + require.Equal(t, 3, n) + require.Equal(t, []byte("foo"), b) + require.True(t, mockCtrl.Satisfied()) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 6}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // receiving a reordered RESET_STREAM_AT frame has no effect + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 8}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // receiving a RESET_STREAM_AT frame with a smaller reliable size is valid + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + mockFC.EXPECT().Abandon() + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 3}, time.Now()) + + // Read returns the error + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + n, err = str.Read(b) + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) + require.Zero(t, n) +} + +func TestReceiveStreamResetStreamAtAfterResetStream(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newReceiveStream(42, mockSender, mockFC) + + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any()) + require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now())) + + mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3)) + b := make([]byte, 3) + n, err := str.Read(b) + require.NoError(t, err) + require.Equal(t, 3, n) + require.Equal(t, []byte("foo"), b) + require.True(t, mockCtrl.Satisfied()) + + mockFC.EXPECT().Abandon() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // receiving a reordered RESET_STREAM_AT frame has no effect + mockFC.EXPECT().Abandon() + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any()) + str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 8}, time.Now()) + require.True(t, mockCtrl.Satisfied()) + + // Read returns the error + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42)) + n, err = str.Read(b) + require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true}) + require.Zero(t, n) +} From a2926a36032fbcd1cc709d114753834cf25d62cb Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 28 Jun 2025 11:44:47 +0800 Subject: [PATCH 03/14] implement sender side behavior for RESET_STREAM_AT (#5242) * improve existing send stream test * implement sender side behavior for RESET_STREAM_AT * refactor send stream cancelation and shutdown error handling * correctly deal with 0-RTT corner case --- connection.go | 4 +- send_stream.go | 291 +++++++++++----- send_stream_test.go | 631 +++++++++++++++++++++++++++++++++-- stream.go | 7 +- stream_test.go | 6 +- streams_map.go | 24 +- streams_map_incoming_test.go | 11 +- streams_map_outgoing.go | 9 + streams_map_outgoing_test.go | 5 + streams_map_test.go | 14 +- 10 files changed, 864 insertions(+), 138 deletions(-) diff --git a/connection.go b/connection.go index 056cdb7405a..a6af5fcda3c 100644 --- a/connection.go +++ b/connection.go @@ -1882,7 +1882,7 @@ func (c *Conn) restoreTransportParameters(params *wire.TransportParameters) { c.peerParams = params c.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) c.connFlowController.UpdateSendWindow(params.InitialMaxData) - c.streamsMap.UpdateLimits(params) + c.streamsMap.HandleTransportParameters(params) c.connStateMutex.Lock() c.connState.SupportsDatagrams = c.supportsDatagrams() c.connStateMutex.Unlock() @@ -1961,7 +1961,7 @@ func (c *Conn) applyTransportParameters() { c.idleTimeout = min(c.idleTimeout, params.MaxIdleTimeout) } c.keepAliveInterval = min(c.config.KeepAlivePeriod, c.idleTimeout/2) - c.streamsMap.UpdateLimits(params) + c.streamsMap.HandleTransportParameters(params) c.frameParser.SetAckDelayExponent(params.AckDelayExponent) c.connFlowController.UpdateSendWindow(params.InitialMaxData) c.rttStats.SetMaxAckDelay(params.MaxAckDelay) diff --git a/send_stream.go b/send_stream.go index 0c76ee5e310..d5ce777a674 100644 --- a/send_stream.go +++ b/send_stream.go @@ -9,7 +9,6 @@ import ( "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" ) @@ -27,21 +26,25 @@ type SendStream struct { streamID protocol.StreamID sender streamSender - writeOffset protocol.ByteCount + // reliableSize is the portion of the stream that needs to be transmitted reliably, + // even if the stream is cancelled. + // This requires the peer to support RESET_STREAM_AT. + // This value should not be accessed directly, but only through the reliableOffset method. + // This method returns 0 if the peer doesn't support the RESET_STREAM_AT extension. + reliableSize protocol.ByteCount + writeOffset protocol.ByteCount - // finalError is the error that is returned by Write. - // It can either be a cancellation error or the shutdown error. - finalError error + shutdownErr error + resetErr *StreamError queuedResetStreamFrame *wire.ResetStreamFrame - finishedWriting bool // set once Close() is called - finSent bool // set when a STREAM_FRAME with FIN bit has been sent + supportsResetStreamAt bool + finishedWriting bool // set once Close() is called + finSent bool // set when a STREAM_FRAME with FIN bit has been sent // Set when the application knows about the cancellation. // This can happen because the application called CancelWrite, // or because Write returned the error (for remote cancellations). cancellationFlagged bool - cancelled bool // both local and remote cancellations - closedForShutdown bool // set by closeForShutdown completed bool // set when this stream has been reported to the streamSender as completed dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out @@ -65,13 +68,15 @@ func newSendStream( streamID protocol.StreamID, sender streamSender, flowController flowcontrol.StreamFlowController, + supportsResetStreamAt bool, ) *SendStream { s := &SendStream{ - streamID: streamID, - sender: sender, - flowController: flowController, - writeChan: make(chan struct{}, 1), - writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write + streamID: streamID, + sender: sender, + flowController: flowController, + writeChan: make(chan struct{}, 1), + writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write + supportsResetStreamAt: supportsResetStreamAt, } s.ctx, s.ctxCancel = context.WithCancelCause(ctx) return s @@ -103,11 +108,12 @@ func (s *SendStream) write(p []byte) (bool /* is newly completed */, int, error) s.mutex.Lock() defer s.mutex.Unlock() - if s.finalError != nil { - if s.cancelled { - s.cancellationFlagged = true - } - return s.isNewlyCompleted(), 0, s.finalError + if s.resetErr != nil { + s.cancellationFlagged = true + return s.isNewlyCompleted(), 0, s.resetErr + } + if s.shutdownErr != nil { + return false, 0, s.shutdownErr } if s.finishedWriting { return false, 0, fmt.Errorf("write on closed stream %d", s.streamID) @@ -165,7 +171,7 @@ func (s *SendStream) write(p []byte) (bool /* is newly completed */, int, error) } deadlineTimer.Reset(deadline) } - if s.dataForWriting == nil || s.finalError != nil { + if s.dataForWriting == nil || s.shutdownErr != nil || s.resetErr != nil { break } } @@ -194,11 +200,12 @@ func (s *SendStream) write(p []byte) (bool /* is newly completed */, int, error) if bytesWritten == len(p) { return false, bytesWritten, nil } - if s.finalError != nil { - if s.cancelled { - s.cancellationFlagged = true - } - return s.isNewlyCompleted(), bytesWritten, s.finalError + if s.shutdownErr != nil { + return false, bytesWritten, s.shutdownErr + } + if s.resetErr != nil { + s.cancellationFlagged = true + return s.isNewlyCompleted(), bytesWritten, s.resetErr } return false, bytesWritten, nil } @@ -231,9 +238,15 @@ func (s *SendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers } func (s *SendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) { - if s.finalError != nil { + if s.shutdownErr != nil { return nil, nil, false } + if s.resetErr != nil { + reliableOffset := s.reliableOffset() + if reliableOffset == 0 || (s.writeOffset >= reliableOffset && len(s.retransmissionQueue) == 0) { + return nil, nil, false + } + } if len(s.retransmissionQueue) > 0 { f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v) @@ -260,12 +273,17 @@ func (s *SendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun return nil, nil, false } - sendWindow := s.flowController.SendWindowSize() - if sendWindow == 0 { + maxDataLen := s.flowController.SendWindowSize() + if maxDataLen == 0 { return nil, nil, true } - f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v) + // if the stream is canceled, only data up to the reliable size needs to be sent + reliableOffset := s.reliableOffset() + if s.resetErr != nil && reliableOffset > 0 { + maxDataLen = min(maxDataLen, reliableOffset-s.writeOffset) + } + f, hasMoreData := s.popNewStreamFrame(maxBytes, maxDataLen, v) if f == nil { return nil, nil, hasMoreData } @@ -273,10 +291,13 @@ func (s *SendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun s.writeOffset += f.DataLen() s.flowController.AddBytesSent(f.DataLen()) } + if s.resetErr != nil && s.writeOffset >= reliableOffset { + hasMoreData = false + } var blocked *wire.StreamDataBlockedFrame // If the entire send window is used, the stream might have become blocked on stream-level flow control. // This is not guaranteed though, because the stream might also have been blocked on connection-level flow control. - if f.DataLen() == sendWindow && s.flowController.IsNewlyBlocked() { + if f.DataLen() == maxDataLen && s.flowController.IsNewlyBlocked() { blocked = &wire.StreamDataBlockedFrame{StreamID: s.streamID, MaximumStreamData: s.writeOffset} } f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent @@ -286,9 +307,11 @@ func (s *SendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun return f, blocked, hasMoreData } -func (s *SendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) { +// popNewStreamFrame returns a new STREAM frame to send for this stream +// hasMoreData says if there's more data to send, *not* taking into account the reliable size +func (s *SendStream) popNewStreamFrame(maxBytes, maxDataLen protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, hasMoreData bool) { if s.nextFrame != nil { - maxDataLen := min(sendWindow, s.nextFrame.MaxDataLen(maxBytes, v)) + maxDataLen := min(maxDataLen, s.nextFrame.MaxDataLen(maxBytes, v)) if maxDataLen == 0 { return nil, true } @@ -315,7 +338,7 @@ func (s *SendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, f.DataLenPresent = true f.Data = f.Data[:0] - hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow, v) + hasMoreData = s.popNewStreamFrameWithoutBuffer(f, maxBytes, maxDataLen, v) if len(f.Data) == 0 && !f.Fin { f.PutBack() return nil, hasMoreData @@ -363,6 +386,9 @@ func (s *SendStream) isNewlyCompleted() bool { if s.completed { return false } + if s.nextFrame != nil && s.nextFrame.DataLen() > 0 { + return false + } // We need to keep the stream around until all frames have been sent and acknowledged. if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame != nil { return false @@ -377,7 +403,7 @@ func (s *SendStream) isNewlyCompleted() bool { // 2. we received a STOP_SENDING, and // * the application consumed the error via Write, or // * the application called Close - if s.cancelled && (s.cancellationFlagged || s.finishedWriting) { + if s.resetErr != nil && (s.cancellationFlagged || s.finishedWriting) { s.completed = true return true } @@ -390,12 +416,12 @@ func (s *SendStream) isNewlyCompleted() bool { // It must not be called after calling CancelWrite. func (s *SendStream) Close() error { s.mutex.Lock() - if s.closedForShutdown || s.finishedWriting { + if s.shutdownErr != nil || s.finishedWriting { s.mutex.Unlock() return nil } s.finishedWriting = true - cancelled := s.cancelled + cancelled := s.resetErr != nil if cancelled { s.cancellationFlagged = true } @@ -414,6 +440,20 @@ func (s *SendStream) Close() error { return nil } +// SetReliableBoundary marks the data written to this stream so far as reliable. +// It is valid to call this function multiple times, thereby increasing the reliable size. +// It only has an effect if the peer enabled support for the RESET_STREAM_AT extension, +// otherwise, it is a no-op. +func (s *SendStream) SetReliableBoundary() { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.reliableSize = s.writeOffset + if s.nextFrame != nil { + s.reliableSize += s.nextFrame.DataLen() + } +} + // CancelWrite aborts sending on this stream. // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably. // Write will unblock immediately, and future calls to Write will fail. @@ -421,45 +461,64 @@ func (s *SendStream) Close() error { // When called after Close, it aborts reliable delivery of outstanding stream data. // Note that there is no guarantee if the peer will receive the FIN or the cancellation error first. func (s *SendStream) CancelWrite(errorCode StreamErrorCode) { - s.cancelWrite(errorCode, false) -} - -// cancelWrite cancels the stream -// It is possible to cancel a stream after it has been closed, both locally and remotely. -// This is useful to prevent the retransmission of outstanding stream data. -func (s *SendStream) cancelWrite(errorCode qerr.StreamErrorCode, remote bool) { s.mutex.Lock() - if s.closedForShutdown { + if s.shutdownErr != nil { s.mutex.Unlock() return } - if !remote { - s.cancellationFlagged = true - if s.cancelled { - completed := s.isNewlyCompleted() - s.mutex.Unlock() - // The user has called CancelWrite. If the previous cancellation was - // because of a STOP_SENDING, we don't need to flag the error to the - // user anymore. - if completed { - s.sender.onStreamCompleted(s.streamID) - } - return - } - } - if s.cancelled { + + s.cancellationFlagged = true + + if s.resetErr != nil { + completed := s.isNewlyCompleted() s.mutex.Unlock() + // The user has called CancelWrite. If the previous cancellation was because of a + // STOP_SENDING, we don't need to flag the error to the user anymore. + if completed { + s.sender.onStreamCompleted(s.streamID) + } return } - s.cancelled = true - s.finalError = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote} - s.ctxCancel(s.finalError) - s.numOutstandingFrames = 0 - s.retransmissionQueue = nil + s.resetErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false} + s.ctxCancel(s.resetErr) + + reliableOffset := s.reliableOffset() + if reliableOffset == 0 { + s.numOutstandingFrames = 0 + s.retransmissionQueue = nil + } s.queuedResetStreamFrame = &wire.ResetStreamFrame{ StreamID: s.streamID, - FinalSize: s.writeOffset, + FinalSize: max(s.writeOffset, reliableOffset), ErrorCode: errorCode, + // if the peer doesn't support the extension, the reliable offset will always be 0 + ReliableSize: reliableOffset, + } + if reliableOffset > 0 { + if s.nextFrame != nil { + if s.nextFrame.Offset >= reliableOffset { + s.nextFrame.PutBack() + s.nextFrame = nil + } else if s.nextFrame.Offset+s.nextFrame.DataLen() > reliableOffset { + s.nextFrame.Data = s.nextFrame.Data[:reliableOffset-s.nextFrame.Offset] + } + } + if len(s.retransmissionQueue) > 0 { + retransmissionQueue := make([]*wire.StreamFrame, 0, len(s.retransmissionQueue)) + for _, f := range s.retransmissionQueue { + if f.Offset >= reliableOffset { + f.PutBack() + continue + } + if f.Offset+f.DataLen() <= reliableOffset { + retransmissionQueue = append(retransmissionQueue, f) + } else { + f.Data = f.Data[:reliableOffset-f.Offset] + retransmissionQueue = append(retransmissionQueue, f) + } + } + s.retransmissionQueue = retransmissionQueue + } } s.mutex.Unlock() @@ -467,6 +526,12 @@ func (s *SendStream) cancelWrite(errorCode qerr.StreamErrorCode, remote bool) { s.sender.onHasStreamControlFrame(s.streamID, s) } +func (s *SendStream) enableResetStreamAt() { + s.mutex.Lock() + s.supportsResetStreamAt = true + s.mutex.Unlock() +} + func (s *SendStream) updateSendWindow(limit protocol.ByteCount) { updated := s.flowController.UpdateSendWindow(limit) if !updated { // duplicate or reordered MAX_STREAM_DATA frame @@ -480,8 +545,36 @@ func (s *SendStream) updateSendWindow(limit protocol.ByteCount) { } } -func (s *SendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { - s.cancelWrite(frame.ErrorCode, true) +func (s *SendStream) handleStopSendingFrame(f *wire.StopSendingFrame) { + s.mutex.Lock() + if s.shutdownErr != nil { + s.mutex.Unlock() + return + } + + // If the stream was already cancelled (either locally, or due to a previous STOP_SENDING frame), + // there's nothing else to do. + if s.resetErr != nil && s.reliableOffset() == 0 { + s.mutex.Unlock() + return + } + // if the peer stopped reading from the stream, there's no need to transmit any data reliably + s.reliableSize = 0 + s.numOutstandingFrames = 0 + s.retransmissionQueue = nil + if s.resetErr == nil { + s.resetErr = &StreamError{StreamID: s.streamID, ErrorCode: f.ErrorCode, Remote: true} + s.ctxCancel(s.resetErr) + } + s.queuedResetStreamFrame = &wire.ResetStreamFrame{ + StreamID: s.streamID, + FinalSize: s.writeOffset, + ErrorCode: s.resetErr.ErrorCode, + } + s.mutex.Unlock() + + s.signalWrite() + s.sender.onHasStreamControlFrame(s.streamID, s) } func (s *SendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) { @@ -500,6 +593,13 @@ func (s *SendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore return f, true, false } +func (s *SendStream) reliableOffset() protocol.ByteCount { + if !s.supportsResetStreamAt { + return 0 + } + return s.reliableSize +} + // The Context is canceled as soon as the write-side of the stream is closed. // This happens when Close() or CancelWrite() is called, or when the peer // cancels the read-side of their stream. @@ -527,9 +627,8 @@ func (s *SendStream) SetWriteDeadline(t time.Time) error { // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. func (s *SendStream) closeForShutdown(err error) { s.mutex.Lock() - s.closedForShutdown = true - if s.finalError == nil && !s.finishedWriting { - s.finalError = err + if s.shutdownErr == nil && !s.finishedWriting { + s.shutdownErr = err } s.mutex.Unlock() s.signalWrite() @@ -550,8 +649,9 @@ var _ ackhandler.FrameHandler = &sendStreamAckHandler{} func (s *sendStreamAckHandler) OnAcked(f wire.Frame) { sf := f.(*wire.StreamFrame) sf.PutBack() + s.mutex.Lock() - if s.cancelled { + if s.resetErr != nil && (*SendStream)(s).reliableOffset() == 0 { s.mutex.Unlock() return } @@ -570,16 +670,39 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) { func (s *sendStreamAckHandler) OnLost(f wire.Frame) { sf := f.(*wire.StreamFrame) s.mutex.Lock() - if s.cancelled { + // If the reliable size was 0 when the stream was cancelled, + // the number of outstanding frames was immediately set to 0, and the retransmission queue was dropped. + if s.resetErr != nil && (*SendStream)(s).reliableOffset() == 0 { s.mutex.Unlock() return } - sf.DataLenPresent = true - s.retransmissionQueue = append(s.retransmissionQueue, sf) s.numOutstandingFrames-- if s.numOutstandingFrames < 0 { panic("numOutStandingFrames negative") } + + if s.resetErr != nil && (*SendStream)(s).reliableOffset() > 0 { + // If the stream was reset, and this frame is beyond the reliable offset, + // it doesn't need to be retransmitted. + if sf.Offset >= (*SendStream)(s).reliableOffset() { + sf.PutBack() + // If this frame was the last one tracked, losing it might cause the stream to be completed. + completed := (*SendStream)(s).isNewlyCompleted() + s.mutex.Unlock() + if completed { + s.sender.onStreamCompleted(s.streamID) + } + return + } + // If the payload of the frame extends beyond the reliable size, + // truncate the frame to the reliable size. + if sf.Offset+sf.DataLen() > (*SendStream)(s).reliableOffset() { + sf.Data = sf.Data[:(*SendStream)(s).reliableOffset()-sf.Offset] + } + } + + sf.DataLenPresent = true + s.retransmissionQueue = append(s.retransmissionQueue, sf) s.mutex.Unlock() s.sender.onHasStreamData(s.streamID, (*SendStream)(s)) @@ -589,8 +712,16 @@ type sendStreamResetStreamHandler SendStream var _ ackhandler.FrameHandler = &sendStreamResetStreamHandler{} -func (s *sendStreamResetStreamHandler) OnAcked(wire.Frame) { +func (s *sendStreamResetStreamHandler) OnAcked(f wire.Frame) { + rsf := f.(*wire.ResetStreamFrame) s.mutex.Lock() + // If the peer sent a STOP_SENDING after we sent a RESET_STREAM_AT frame, + // we sent 1. reduced the reliable size to 0 and 2. sent a RESET_STREAM frame. + // In this case, we don't care about the acknowledgment of this frame. + if rsf.ReliableSize != (*SendStream)(s).reliableOffset() { + s.mutex.Unlock() + return + } s.numOutstandingFrames-- if s.numOutstandingFrames < 0 { panic("numOutStandingFrames negative") @@ -604,8 +735,16 @@ func (s *sendStreamResetStreamHandler) OnAcked(wire.Frame) { } func (s *sendStreamResetStreamHandler) OnLost(f wire.Frame) { + rsf := f.(*wire.ResetStreamFrame) s.mutex.Lock() - s.queuedResetStreamFrame = f.(*wire.ResetStreamFrame) + // If the peer sent a STOP_SENDING after we sent a RESET_STREAM_AT frame, + // we sent 1. reduced the reliable size to 0 and 2. sent a RESET_STREAM frame. + // In this case, the loss of the RESET_STREAM_AT frame can be ignored. + if rsf.ReliableSize != (*SendStream)(s).reliableOffset() { + s.mutex.Unlock() + return + } + s.queuedResetStreamFrame = rsf s.numOutstandingFrames-- s.mutex.Unlock() s.sender.onHasStreamControlFrame(s.streamID, (*SendStream)(s)) diff --git a/send_stream_test.go b/send_stream_test.go index dcd593cd776..89456068650 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -10,9 +10,12 @@ import ( mrand "math/rand/v2" "net" "os" + "runtime" + "slices" "testing" "time" + "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/mocks" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" @@ -50,7 +53,7 @@ func TestSendStreamSetup(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) ctx := context.WithValue(context.Background(), "foo", "bar") - str := newSendStream(ctx, 1337, nil, mockFC) + str := newSendStream(ctx, 1337, nil, mockFC, false) require.NotNil(t, str.Context()) require.Equal(t, "bar", str.Context().Value("foo")) require.Equal(t, protocol.StreamID(1337), str.StreamID()) @@ -61,7 +64,7 @@ func TestSendStreamWriteData(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} mockSender.EXPECT().onHasStreamData(streamID, str) @@ -143,7 +146,7 @@ func TestSendStreamLargeWrites(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) data := make([]byte, 5000) @@ -173,6 +176,7 @@ func TestSendStreamLargeWrites(t *testing.T) { offset += size require.True(t, mockCtrl.Satisfied()) } + // Write should still be blocked, since there's more than protocol.MaxPacketBufferSize left to send select { case err := <-errChan: @@ -180,8 +184,13 @@ func TestSendStreamLargeWrites(t *testing.T) { case <-time.After(scaleDuration(5 * time.Millisecond)): // short wait to ensure write is blocked } + // empty frames are not sent + frame, _, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, offset), protocol.Version1) + require.Nil(t, frame.Frame) + require.True(t, hasMore) + mockSender.EXPECT().onHasStreamData(streamID, str) // from the Close call - frame, _, hasMore := str.popStreamFrame(size+expectedFrameHeaderLen(streamID, offset), protocol.Version1) + frame, _, hasMore = str.popStreamFrame(size+expectedFrameHeaderLen(streamID, offset), protocol.Version1) require.NotNil(t, frame.Frame) require.True(t, hasMore) require.Equal(t, data[offset:offset+size], frame.Frame.Data) @@ -206,7 +215,7 @@ func TestSendStreamLargeWriteBlocking(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) @@ -255,7 +264,7 @@ func TestSendStreamCopyData(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} // for small writes @@ -277,7 +286,7 @@ func TestSendStreamDeadlineInThePast(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), 42, mockSender, mockFC) + str := newSendStream(context.Background(), 42, mockSender, mockFC, false) // no data is written when the deadline is in the past require.NoError(t, str.SetWriteDeadline(time.Now().Add(-time.Second))) @@ -300,7 +309,7 @@ func TestSendStreamDeadlineRemoval(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), 42, mockSender, mockFC) + str := newSendStream(context.Background(), 42, mockSender, mockFC, false) deadline := scaleDuration(20 * time.Millisecond) require.NoError(t, str.SetWriteDeadline(time.Now().Add(deadline))) @@ -352,7 +361,7 @@ func TestSendStreamDeadlineExtension(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), 42, mockSender, mockFC) + str := newSendStream(context.Background(), 42, mockSender, mockFC, false) deadline := scaleDuration(20 * time.Millisecond) require.NoError(t, str.SetWriteDeadline(time.Now().Add(deadline))) @@ -388,7 +397,7 @@ func TestSendStreamClose(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) @@ -444,7 +453,7 @@ func TestSendStreamImmediateClose(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) require.NoError(t, str.Close()) frame, _, hasMore := str.popStreamFrame(expectedFrameHeaderLen(streamID, 13)+3, protocol.Version1) @@ -460,7 +469,7 @@ func TestSendStreamFlowControlBlocked(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) _, err := str.Write([]byte("foobar")) @@ -493,7 +502,7 @@ func TestSendStreamCloseForShutdown(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} mockSender.EXPECT().onHasStreamData(streamID, str) @@ -519,6 +528,12 @@ func TestSendStreamCloseForShutdown(t *testing.T) { t.Fatal("timeout") } + // STOP_SENDING frames are ignored + str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID, ErrorCode: 1337}) + _, ok, hasMore := str.getControlFrame(time.Now()) + require.False(t, ok) + require.False(t, hasMore) + // future calls to Write should return the error _, err := strWithTimeout.Write([]byte("foobar")) require.ErrorIs(t, err, assert.AnError) @@ -541,7 +556,7 @@ func TestSendStreamUpdateSendWindow(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), 42, mockSender, mockFC) + str := newSendStream(context.Background(), 42, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(gomock.Any(), str) _, err := str.Write([]byte("foobar")) @@ -551,6 +566,12 @@ func TestSendStreamUpdateSendWindow(t *testing.T) { // no calls to onHasStreamData if the window size wasn't increased mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(41)).Return(false) str.updateSendWindow(41) + + gomock.InOrder( + mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(123)).Return(true), + mockSender.EXPECT().onHasStreamData(protocol.StreamID(42), str), + ) + str.updateSendWindow(123) } func TestSendStreamCancellation(t *testing.T) { @@ -558,7 +579,7 @@ func TestSendStreamCancellation(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} mockSender.EXPECT().onHasStreamData(streamID, str) @@ -572,6 +593,10 @@ func TestSendStreamCancellation(t *testing.T) { require.Equal(t, []byte("foo"), frame.Frame.Data) require.True(t, mockCtrl.Satisfied()) + // The stream doesn't support RESET_STREAM_AT. + // Setting the reliable boundary has no effect. + str.SetReliableBoundary() + wrote := make(chan struct{}) mockSender.EXPECT().onHasStreamData(streamID, str).Do(func(protocol.StreamID, *SendStream) { close(wrote) }) errChan := make(chan error, 1) @@ -651,7 +676,7 @@ func TestSendStreamCancellationAfterClose(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second} mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) @@ -690,7 +715,7 @@ func testSendStreamCancellationStreamRetransmission(t *testing.T, remote bool) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) @@ -741,7 +766,7 @@ func TestSendStreamCancellationResetStreamRetransmission(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamControlFrame(streamID, str) str.CancelWrite(1337) @@ -767,12 +792,73 @@ func TestSendStreamCancellationResetStreamRetransmission(t *testing.T) { f2.Handler.OnAcked(f2.Frame) } -func TestSendStreamStopSending(t *testing.T) { +func TestSendStreamStopSendingAfterWrite(t *testing.T) { + t.Run("complete by Write", func(t *testing.T) { + testSendStreamStopSendingAfterWrite(t, "write") + }) + t.Run("complete by Close", func(t *testing.T) { + testSendStreamStopSendingAfterWrite(t, "close") + }) + t.Run("complete by CancelWrite", func(t *testing.T) { + testSendStreamStopSendingAfterWrite(t, "cancelwrite") + }) +} + +func testSendStreamStopSendingAfterWrite(t *testing.T, completeBy string) { + const streamID protocol.StreamID = 1000 + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) + + mockSender.EXPECT().onHasStreamData(streamID, str).MaxTimes(2) + _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.NoError(t, err) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + frame, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.NotNil(t, frame.Frame) + require.True(t, mockCtrl.Satisfied()) + + mockSender.EXPECT().onHasStreamControlFrame(streamID, str) + str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID, ErrorCode: 1337}) + + cf, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 6, ErrorCode: 1337}, cf.Frame) + require.False(t, hasMore) + + // acknowledging the RESET_STREAM frame doesn't complete the stream, + // since it was neither cancelled nor closed + cf.Handler.OnAcked(cf.Frame) + require.True(t, mockCtrl.Satisfied()) + + mockSender.EXPECT().onStreamCompleted(streamID) + switch completeBy { + case "write": + // calls to Write should return an error + _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) + case "close": + require.ErrorContains(t, str.Close(), "close called for canceled stream") + case "cancelwrite": + str.CancelWrite(1234) + } + // error code and remote flag are unchanged + _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) + require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) + frame, _, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.Nil(t, frame.Frame) + _, ok, _ = str.getControlFrame(time.Now()) + require.False(t, ok) +} + +func TestSendStreamStopSendingDuringWrite(t *testing.T) { const streamID protocol.StreamID = 1000 mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str).MaxTimes(2) _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) @@ -785,7 +871,7 @@ func TestSendStreamStopSending(t *testing.T) { errChan := make(chan error, 1) go func() { - _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write(make([]byte, 2000)) + _, err := str.Write(make([]byte, 2000)) errChan <- err }() @@ -804,6 +890,17 @@ func TestSendStreamStopSending(t *testing.T) { require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 6, ErrorCode: 1337}, cf.Frame) require.False(t, hasMore) + // receiving another STOP_SENDING frame has no effect + str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: streamID, ErrorCode: 1234}) + _, ok, hasMore = str.getControlFrame(time.Now()) + require.False(t, ok) + require.False(t, hasMore) + + // acknowledging the RESET_STREAM frame completes the stream + mockSender.EXPECT().onStreamCompleted(streamID) + cf.Handler.OnAcked(cf.Frame) + require.True(t, mockCtrl.Satisfied()) + // calls to Write should return an error _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true}) @@ -835,7 +932,7 @@ func TestSendStreamConcurrentWriteAndCancel(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamControlFrame(gomock.Any(), gomock.Any()).MaxTimes(1) mockSender.EXPECT().onHasStreamData(streamID, str).MaxTimes(1) @@ -884,7 +981,7 @@ func TestSendStreamRetransmissions(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) _, err := str.Write([]byte("foo")) @@ -937,7 +1034,7 @@ func TestSendStreamRetransmissionFraming(t *testing.T) { mockCtrl := gomock.NewController(t) mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockSender := NewMockStreamSender(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str) _, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) @@ -987,7 +1084,7 @@ func TestSendStreamRetransmitDataUntilAcknowledged(t *testing.T) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) mockFC := mocks.NewMockStreamFlowController(mockCtrl) - str := newSendStream(context.Background(), streamID, mockSender, mockFC) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, false) mockSender.EXPECT().onHasStreamData(streamID, str).AnyTimes() mockFC.EXPECT().SendWindowSize().DoAndReturn(func() protocol.ByteCount { @@ -1011,20 +1108,482 @@ func TestSendStreamRetransmitDataUntilAcknowledged(t *testing.T) { mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { completed = true }) received := make([]byte, dataLen) - for !completed { + var counter int + frameQueue := make([]ackhandler.StreamFrame, 0, 32) + for !completed || len(frameQueue) > 0 { + counter++ + if counter > 1e6 { + t.Fatal("stream should have completed") + } f, _, _ := str.popStreamFrame(protocol.ByteCount(mrand.IntN(300)+100), protocol.Version1) - if f.Frame == nil { - continue + var dequeuedFrame bool + if f.Frame != nil { + frameQueue = append(frameQueue, f) + dequeuedFrame = true } - sf := f.Frame - // 50%: acknowledge the frame and save the data - // 50%: lose the frame - if mrand.IntN(100) < 50 { - copy(received[sf.Offset:sf.Offset+sf.DataLen()], sf.Data) - f.Handler.OnAcked(f.Frame) - } else { - f.Handler.OnLost(f.Frame) + + // Process one of the queued frames at random. + // This simulates potential reordering. + if len(frameQueue) > 0 && (!dequeuedFrame || len(frameQueue) == cap(frameQueue)) { + idx := mrand.IntN(len(frameQueue)) + f := frameQueue[idx] + // 50%: acknowledge the frame and save the data + // 50%: lose the frame + if mrand.Int()%2 == 0 { + copy(received[f.Frame.Offset:f.Frame.Offset+f.Frame.DataLen()], f.Frame.Data) + f.Handler.OnAcked(f.Frame) + } else { + f.Handler.OnLost(f.Frame) + } + frameQueue = slices.Delete(frameQueue, idx, idx+1) } + runtime.Gosched() } require.Equal(t, data, received) } + +func TestSendStreamResetStreamAtCancelBeforeSend(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), 1337, mockSender, mockFC, true) + + mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str).Times(2) + _, err := str.Write([]byte("foobar")) + require.NoError(t, err) + str.SetReliableBoundary() + _, err = str.Write([]byte("baz")) + require.NoError(t, err) + + mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) + str.CancelWrite(1337) + cf, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 6, ErrorCode: 1337, ReliableSize: 6}, cf.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + mockFC.EXPECT().IsNewlyBlocked() + f, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 1337, Data: []byte("foobar"), DataLenPresent: true}, + f.Frame, + ) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + // Lose the frame. + // Since it's before the reliable size, we should get a retransmission. + mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str) + f.Handler.OnLost(f.Frame) + require.True(t, mockCtrl.Satisfied()) + + retransmission, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 1337, Data: []byte("foobar"), DataLenPresent: true}, + retransmission.Frame, + ) + require.True(t, hasMore) // hasMore is always true when dequeuing a retransmission + require.True(t, mockCtrl.Satisfied()) + f, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.Nil(t, f.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + // acknowledging the RESET_STREAM_AT and the retransmission completes the stream + cf.Handler.OnAcked(cf.Frame) + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(1337)) + retransmission.Handler.OnAcked(retransmission.Frame) +} + +func TestSendStreamResetStreamAtCancelAfterSend(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), 1337, mockSender, mockFC, true) + + mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str).Times(2) + _, err := str.Write([]byte("foobar")) + require.NoError(t, err) + str.SetReliableBoundary() + _, err = str.Write([]byte("baz")) + require.NoError(t, err) + + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(9)) + f, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 1337, Data: []byte("foobarbaz"), DataLenPresent: true}, + f.Frame, + ) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) + str.CancelWrite(42) + cf, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 9, ErrorCode: 42, ReliableSize: 6}, cf.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + cf.Handler.OnAcked(cf.Frame) + // lose the STREAM frame + mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str) + f.Handler.OnLost(f.Frame) + // only the first 6 bytes need to be retransmitted + retransmission1, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 1337, Data: []byte("foobar"), DataLenPresent: true}, + retransmission1.Frame, + ) + require.True(t, hasMore) // hasMore is always true when dequeuing a retransmission + require.True(t, mockCtrl.Satisfied()) + f, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.Nil(t, f.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + // lose the retransmission as well + mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str) + retransmission1.Handler.OnLost(retransmission1.Frame) + retransmission2, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 1337, Data: []byte("foobar"), DataLenPresent: true}, + retransmission2.Frame, + ) + require.True(t, hasMore) // hasMore is always true when dequeuing a retransmission + require.True(t, mockCtrl.Satisfied()) + f, _, hasMore = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.Nil(t, f.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + // acknowledge the 2nd retransmission + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(1337)) + retransmission2.Handler.OnAcked(retransmission2.Frame) +} + +func TestSendStreamResetStreamAtRetransmissions(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), 1337, mockSender, mockFC, true) + + // f1: lorem + // f2: ipsumdolor (reliable offset: right after the "ipsum") + // f3: sit + // f4: amet + // sitting in the write buffer: consectetur (but not popped) + mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str).AnyTimes() + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() + mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() + _, err := str.Write([]byte("lorem")) + require.NoError(t, err) + f1, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 1337, Data: []byte("lorem"), DataLenPresent: true}, + f1.Frame, + ) + _, err = str.Write([]byte("ipsum")) + require.NoError(t, err) + str.SetReliableBoundary() + _, err = str.Write([]byte("dolor")) + require.NoError(t, err) + f2, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 1337, Offset: 5, Data: []byte("ipsumdolor"), DataLenPresent: true}, + f2.Frame, + ) + _, err = str.Write([]byte("sit")) + require.NoError(t, err) + f3, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 1337, Offset: 15, Data: []byte("sit"), DataLenPresent: true}, + f3.Frame, + ) + _, err = str.Write([]byte("amet")) + require.NoError(t, err) + f4, _, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 1337, Offset: 18, Data: []byte("amet"), DataLenPresent: true}, + f4.Frame, + ) + _, err = str.Write([]byte("consectetur")) + require.NoError(t, err) + + // lose the frames, in no particular order + f2.Handler.OnLost(f2.Frame) + f1.Handler.OnLost(f1.Frame) + f3.Handler.OnLost(f3.Frame) + // f4 is lost at a later point + + // Now cancel the stream. + // We expect f1 and the first half of f2 to be retransmitted, + // but f3 and the data in the buffer should not. + mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) + str.CancelWrite(42) + cf, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 22, ErrorCode: 42, ReliableSize: 10}, cf.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + cf.Handler.OnAcked(cf.Frame) + + // // the retransmission of f1 should be truncated to 6 bytes + r1, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 1337, Offset: 5, Data: []byte("ipsum"), DataLenPresent: true}, + r1.Frame, + ) + require.True(t, hasMore) + r2, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.EqualExportedValues(t, + &wire.StreamFrame{StreamID: 1337, Data: []byte("lorem"), DataLenPresent: true}, + r2.Frame, + ) + require.True(t, hasMore) // hasMore is always true when dequeuing a retransmission + require.True(t, mockCtrl.Satisfied()) + r3, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.Nil(t, r3.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + r1.Handler.OnAcked(r1.Frame) + r2.Handler.OnAcked(r2.Frame) + require.True(t, mockCtrl.Satisfied()) + + // the stream is only completed once f4 is lost + // it's beyond the reliable size, so it's not retransmitted + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(1337)) + f4.Handler.OnLost(f4.Frame) +} + +func TestSendStreamResetStreamAtStopSendingBeforeCancelation(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), 1337, mockSender, mockFC, true) + + mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str).Times(2) + _, err := str.Write([]byte("foobar")) + require.NoError(t, err) + str.SetReliableBoundary() + _, err = str.Write([]byte("baz")) + require.NoError(t, err) + + // send out a STREAM frame with all the data written so far + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(9)) + f, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.Equal(t, protocol.ByteCount(9), f.Frame.DataLen()) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) + str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: 1337, ErrorCode: 42}) + cf, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + // Since the peer reset the stream, the resulting RESET_STREAM frame has a reliable size of 0 + require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 9, ErrorCode: 42, ReliableSize: 0}, cf.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + // calling CancelWrite doesn't cause any more frames to be enqueued + str.CancelWrite(1234) + + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(1337)) + cf.Handler.OnAcked(cf.Frame) +} + +func TestSendStreamResetStreamAtStopSendingAfterCancelation(t *testing.T) { + t.Run("RESET_STREAM_AT lost", func(t *testing.T) { + testSendStreamResetStreamAtStopSendingAfterCancelation(t, true) + }) + t.Run("RESET_STREAM_AT acknowledged", func(t *testing.T) { + testSendStreamResetStreamAtStopSendingAfterCancelation(t, false) + }) +} + +func testSendStreamResetStreamAtStopSendingAfterCancelation(t *testing.T, loseResetStreamAt bool) { + mockCtrl := gomock.NewController(t) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + mockSender := NewMockStreamSender(mockCtrl) + str := newSendStream(context.Background(), 1337, mockSender, mockFC, true) + + mockSender.EXPECT().onHasStreamData(protocol.StreamID(1337), str).Times(2) + _, err := str.Write([]byte("foobar")) + require.NoError(t, err) + str.SetReliableBoundary() + _, err = str.Write([]byte("baz")) + require.NoError(t, err) + + // send out a STREAM frame with all the data written so far + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(9)) + f, _, hasMore := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) + require.Equal(t, protocol.ByteCount(9), f.Frame.DataLen()) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + // Canceling the stream results in a RESET_STREAM_AT frame. + mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) + str.CancelWrite(42) + cf1, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 9, ErrorCode: 42, ReliableSize: 6}, cf1.Frame) + require.False(t, hasMore) + + // Receiving a STOP_SENDING frame results in a RESET_STREAM frame, + // effectively reducing the reliable size to 0. + mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) + str.handleStopSendingFrame(&wire.StopSendingFrame{StreamID: 1337, ErrorCode: 1234}) + cf2, ok, hasMore := str.getControlFrame(time.Now()) + require.True(t, ok) + // Since the peer reset the stream, the resulting RESET_STREAM frame has a reliable size of 0. + // The error code is still the one used for the CancelWrite call. + require.Equal(t, &wire.ResetStreamFrame{StreamID: 1337, FinalSize: 9, ErrorCode: 42, ReliableSize: 0}, cf2.Frame) + require.False(t, hasMore) + require.True(t, mockCtrl.Satisfied()) + + if loseResetStreamAt { + // losing the RESET_STREAM_AT frame does nothing + cf1.Handler.OnLost(cf1.Frame) + } else { + // receiving an acknowledgment for the RESET_STREAM_AT frame does nothing either: + // the RESET_STREAM frame still needs to be transmitted reliably + cf1.Handler.OnAcked(cf1.Frame) + } + _, ok, _ = str.getControlFrame(time.Now()) + require.False(t, ok) + + // but when the RESET_STREAM frame is lost, it needs to be retransmitted + mockSender.EXPECT().onHasStreamControlFrame(protocol.StreamID(1337), str) + cf2.Handler.OnLost(cf2.Frame) + cf3, ok, _ := str.getControlFrame(time.Now()) + require.True(t, ok) + require.Equal(t, cf2, cf3) + + mockSender.EXPECT().onStreamCompleted(protocol.StreamID(1337)) + cf3.Handler.OnAcked(cf3.Frame) +} + +func TestSendStreamResetStreamAtRandomized(t *testing.T) { + const streamID protocol.StreamID = 123456 + const dataLen = 8 << 10 + reliableOffset := 1 + mrand.IntN(dataLen*3/4) + t.Logf("reliable offset: %d", reliableOffset) + + mockCtrl := gomock.NewController(t) + mockSender := NewMockStreamSender(mockCtrl) + mockFC := mocks.NewMockStreamFlowController(mockCtrl) + str := newSendStream(context.Background(), streamID, mockSender, mockFC, true) + + mockSender.EXPECT().onHasStreamData(streamID, str).AnyTimes() + mockSender.EXPECT().onHasStreamControlFrame(streamID, str).AnyTimes() + mockFC.EXPECT().SendWindowSize().DoAndReturn(func() protocol.ByteCount { + return protocol.ByteCount(mrand.IntN(500)) + 50 + }).AnyTimes() + mockFC.EXPECT().IsNewlyBlocked().Return(false).AnyTimes() + mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() + + data := make([]byte, dataLen) + _, err := rand.Read(data) + require.NoError(t, err) + errChan := make(chan error, 1) + go func() { + b := data + var offset int + for len(b) > 0 { + m := mrand.IntN(1024) + if offset < reliableOffset { + m = min(m, reliableOffset-offset) + } + n, err := str.Write(b[:min(m, len(b))]) + if err != nil { + errChan <- err + return + } + offset += n + if offset <= reliableOffset { + str.SetReliableBoundary() + } + b = b[n:] + } + str.CancelWrite(1234) + errChan <- nil + }() + + var completed bool + mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { completed = true }) + + received := make([]byte, dataLen) + var highestOffset int + var receivedResetStreamAt bool + var counter int + frameQueue := make([]any, 0, 10) + for !completed || len(frameQueue) > 0 { + counter++ + if counter > 1e6 { + t.Fatal("stream should have completed") + } + var dequeuedFrame bool + cf, ok, _ := str.getControlFrame(time.Now()) + if ok { + dequeuedFrame = true + frameQueue = append(frameQueue, cf) + receivedResetStreamAt = true + require.Equal(t, protocol.ByteCount(reliableOffset), cf.Frame.(*wire.ResetStreamFrame).ReliableSize) + } else { + f, _, _ := str.popStreamFrame(protocol.ByteCount(mrand.IntN(300)+100), protocol.Version1) + if f.Frame != nil { + // make sure that only retransmissions are sent once the RESET_STREAM_AT frame is sent + if receivedResetStreamAt { + require.LessOrEqualf(t, + f.Frame.Offset+f.Frame.DataLen(), + protocol.ByteCount(reliableOffset), + "STREAM frame past reliable offset after RESET_STREAM_AT (offset: %d, data length: %d)", + f.Frame.Offset, f.Frame.DataLen(), + ) + } + dequeuedFrame = true + frameQueue = append(frameQueue, f) + } + } + + if len(frameQueue) > 0 && (!dequeuedFrame || len(frameQueue) == cap(frameQueue)) { + idx := mrand.IntN(len(frameQueue)) + switch f := frameQueue[idx].(type) { + case ackhandler.Frame: + // 50%: acknowledge the frame + // 50%: lose the frame + if mrand.Int()%2 == 0 { + f.Handler.OnLost(f.Frame) + } else { + f.Handler.OnAcked(f.Frame) + } + case ackhandler.StreamFrame: + sf := f.Frame + // 50%: acknowledge the frame and save the data + // 50%: lose the frame + if mrand.Int()%2 == 0 { + f.Handler.OnLost(f.Frame) + } else { + highestOffset = max(highestOffset, int(sf.Offset+sf.DataLen())) + copy(received[sf.Offset:sf.Offset+sf.DataLen()], sf.Data) + f.Handler.OnAcked(f.Frame) + } + default: + t.Fatalf("unexpected frame type: %T", f) + } + frameQueue = slices.Delete(frameQueue, idx, idx+1) + } + runtime.Gosched() + } + + t.Logf("highest received offset: %d", highestOffset) + require.GreaterOrEqual(t, highestOffset, reliableOffset) + require.Equal(t, data[:reliableOffset], received[:reliableOffset]) +} diff --git a/stream.go b/stream.go index 9c5a40e2a15..64bcc838caa 100644 --- a/stream.go +++ b/stream.go @@ -71,6 +71,7 @@ func newStream( streamID protocol.StreamID, sender streamSender, flowController flowcontrol.StreamFlowController, + supportsResetStreamAt bool, ) *Stream { s := &Stream{sender: sender} senderForSendStream := &uniStreamSender{ @@ -85,7 +86,7 @@ func newStream( sender.onHasStreamControlFrame(streamID, s) }, } - s.sendStr = newSendStream(ctx, streamID, senderForSendStream, flowController) + s.sendStr = newSendStream(ctx, streamID, senderForSendStream, flowController, supportsResetStreamAt) senderForReceiveStream := &uniStreamSender{ streamSender: sender, onStreamCompletedImpl: func() { @@ -162,6 +163,10 @@ func (s *Stream) updateSendWindow(limit protocol.ByteCount) { s.sendStr.updateSendWindow(limit) } +func (s *Stream) enableResetStreamAt() { + s.sendStr.enableResetStreamAt() +} + func (s *Stream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool) { return s.sendStr.popStreamFrame(maxBytes, v) } diff --git a/stream_test.go b/stream_test.go index 4b4807304b3..d4ee0a44bfb 100644 --- a/stream_test.go +++ b/stream_test.go @@ -20,7 +20,7 @@ func TestStreamDeadlines(t *testing.T) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) mockFC := mocks.NewMockStreamFlowController(mockCtrl) - str := newStream(context.Background(), streamID, mockSender, mockFC) + str := newStream(context.Background(), streamID, mockSender, mockFC, false) // SetDeadline sets both read and write deadlines str.SetDeadline(time.Now().Add(-time.Second)) @@ -82,7 +82,7 @@ func TestStreamCompletion(t *testing.T) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) mockFC := mocks.NewMockStreamFlowController(mockCtrl) - str := newStream(context.Background(), streamID, mockSender, mockFC) + str := newStream(context.Background(), streamID, mockSender, mockFC, false) completeReadSide(t, str, mockCtrl, mockFC) mockSender.EXPECT().onStreamCompleted(streamID) @@ -93,7 +93,7 @@ func TestStreamCompletion(t *testing.T) { mockCtrl := gomock.NewController(t) mockSender := NewMockStreamSender(mockCtrl) mockFC := mocks.NewMockStreamFlowController(mockCtrl) - str := newStream(context.Background(), streamID, mockSender, mockFC) + str := newStream(context.Background(), streamID, mockSender, mockFC, false) completeWriteSide(t, str, mockCtrl, mockFC, mockSender) mockSender.EXPECT().onStreamCompleted(streamID) diff --git a/streams_map.go b/streams_map.go index 7584744a689..14be5ad0bd1 100644 --- a/streams_map.go +++ b/streams_map.go @@ -30,12 +30,13 @@ type streamsMap struct { queueControlFrame func(wire.Frame) newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController - mutex sync.Mutex - outgoingBidiStreams *outgoingStreamsMap[*Stream] - outgoingUniStreams *outgoingStreamsMap[*SendStream] - incomingBidiStreams *incomingStreamsMap[*Stream] - incomingUniStreams *incomingStreamsMap[*ReceiveStream] - reset bool + mutex sync.Mutex + outgoingBidiStreams *outgoingStreamsMap[*Stream] + outgoingUniStreams *outgoingStreamsMap[*SendStream] + incomingBidiStreams *incomingStreamsMap[*Stream] + incomingUniStreams *incomingStreamsMap[*ReceiveStream] + reset bool + supportsResetStreamAt bool } func newStreamsMap( @@ -64,7 +65,7 @@ func (m *streamsMap) initMaps() { m.outgoingBidiStreams = newOutgoingStreamsMap( protocol.StreamTypeBidi, func(id protocol.StreamID) *Stream { - return newStream(m.ctx, id, m.sender, m.newFlowController(id)) + return newStream(m.ctx, id, m.sender, m.newFlowController(id), m.supportsResetStreamAt) }, m.queueControlFrame, m.perspective, @@ -72,7 +73,7 @@ func (m *streamsMap) initMaps() { m.incomingBidiStreams = newIncomingStreamsMap( protocol.StreamTypeBidi, func(id protocol.StreamID) *Stream { - return newStream(m.ctx, id, m.sender, m.newFlowController(id)) + return newStream(m.ctx, id, m.sender, m.newFlowController(id), m.supportsResetStreamAt) }, m.maxIncomingBidiStreams, m.queueControlFrame, @@ -81,7 +82,7 @@ func (m *streamsMap) initMaps() { m.outgoingUniStreams = newOutgoingStreamsMap( protocol.StreamTypeUni, func(id protocol.StreamID) *SendStream { - return newSendStream(m.ctx, id, m.sender, m.newFlowController(id)) + return newSendStream(m.ctx, id, m.sender, m.newFlowController(id), m.supportsResetStreamAt) }, m.queueControlFrame, m.perspective, @@ -316,7 +317,10 @@ func (m *streamsMap) HandleStreamFrame(f *wire.StreamFrame, rcvTime time.Time) e return str.handleStreamFrame(f, rcvTime) } -func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) { +func (m *streamsMap) HandleTransportParameters(p *wire.TransportParameters) { + m.supportsResetStreamAt = p.EnableResetStreamAt + m.outgoingBidiStreams.EnableResetStreamAt() + m.outgoingUniStreams.EnableResetStreamAt() m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote) m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective)) m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni) diff --git a/streams_map_incoming_test.go b/streams_map_incoming_test.go index 643a43aa2b4..9e2785f7a5b 100644 --- a/streams_map_incoming_test.go +++ b/streams_map_incoming_test.go @@ -17,9 +17,10 @@ import ( type mockStream struct { id protocol.StreamID - closed bool - closeErr error - sendWindow protocol.ByteCount + closed bool + closeErr error + sendWindow protocol.ByteCount + supportsResetStreamAt bool } func (s *mockStream) closeForShutdown(err error) { @@ -31,6 +32,10 @@ func (s *mockStream) updateSendWindow(limit protocol.ByteCount) { s.sendWindow = limit } +func (s *mockStream) enableResetStreamAt() { + s.supportsResetStreamAt = true +} + func TestStreamsMapIncomingGettingStreams(t *testing.T) { t.Run("client", func(t *testing.T) { testStreamsMapIncomingGettingStreams(t, protocol.PerspectiveClient, protocol.FirstIncomingUniStreamClient) diff --git a/streams_map_outgoing.go b/streams_map_outgoing.go index 28b2179f1bd..7d7975a5ebe 100644 --- a/streams_map_outgoing.go +++ b/streams_map_outgoing.go @@ -13,6 +13,7 @@ import ( type outgoingStream interface { updateSendWindow(protocol.ByteCount) + enableResetStreamAt() closeForShutdown(error) } @@ -204,6 +205,14 @@ func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) { m.mutex.Unlock() } +func (m *outgoingStreamsMap[T]) EnableResetStreamAt() { + m.mutex.Lock() + for _, str := range m.streams { + str.enableResetStreamAt() + } + m.mutex.Unlock() +} + // unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream func (m *outgoingStreamsMap[T]) maybeUnblockOpenSync() { if len(m.openQueue) == 0 { diff --git a/streams_map_outgoing_test.go b/streams_map_outgoing_test.go index 20330257bdc..c0c8d2860dc 100644 --- a/streams_map_outgoing_test.go +++ b/streams_map_outgoing_test.go @@ -54,6 +54,11 @@ func testStreamsMapOutgoingOpenAndDelete(t *testing.T, perspective protocol.Pers require.Equal(t, protocol.ByteCount(1000), str1.sendWindow) require.Equal(t, protocol.ByteCount(1000), str2.sendWindow) + // enable reset stream at + m.EnableResetStreamAt() + require.True(t, str1.supportsResetStreamAt) + require.True(t, str2.supportsResetStreamAt) + err = m.DeleteStream(firstStream + 1337*4) require.Error(t, err) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) diff --git a/streams_map_test.go b/streams_map_test.go index 07274e8eaee..ebb11a8da45 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -58,7 +58,7 @@ func testStreamsMapCreatingStreams(t *testing.T, 1, perspective, ) - m.UpdateLimits(&wire.TransportParameters{ + m.HandleTransportParameters(&wire.TransportParameters{ MaxBidiStreamNum: protocol.MaxStreamCount, MaxUniStreamNum: protocol.MaxStreamCount, }) @@ -135,7 +135,7 @@ func testStreamsMapDeletingStreams(t *testing.T, 100, perspective, ) - m.UpdateLimits(&wire.TransportParameters{ + m.HandleTransportParameters(&wire.TransportParameters{ MaxBidiStreamNum: 10, MaxUniStreamNum: 10, }) @@ -227,7 +227,7 @@ func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective) _, err := m.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) require.ErrorContains(t, err, "too many open streams") - m.UpdateLimits(&wire.TransportParameters{MaxBidiStreamNum: 1}) + m.HandleTransportParameters(&wire.TransportParameters{MaxBidiStreamNum: 1}) _, err = m.OpenStream() require.NoError(t, err) _, err = m.OpenStream() @@ -235,7 +235,7 @@ func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective) _, err = m.OpenUniStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) - m.UpdateLimits(&wire.TransportParameters{MaxUniStreamNum: 1}) + m.HandleTransportParameters(&wire.TransportParameters{MaxUniStreamNum: 1}) _, err = m.OpenUniStream() require.NoError(t, err) _, err = m.OpenUniStream() @@ -261,7 +261,7 @@ func testStreamsMapStreamLimits(t *testing.T, perspective protocol.Perspective) require.ErrorIs(t, err, &StreamLimitReachedError{}) // decrease via transport parameters - m.UpdateLimits(&wire.TransportParameters{MaxBidiStreamNum: 0}) + m.HandleTransportParameters(&wire.TransportParameters{MaxBidiStreamNum: 0}) _, err = m.OpenStream() require.ErrorIs(t, err, &StreamLimitReachedError{}) } @@ -544,7 +544,7 @@ func TestStreamsMap0RTT(t *testing.T) { protocol.PerspectiveClient, ) // restored transport parameters - m.UpdateLimits(&wire.TransportParameters{ + m.HandleTransportParameters(&wire.TransportParameters{ MaxBidiStreamNum: 1, MaxUniStreamNum: 1, }) @@ -556,7 +556,7 @@ func TestStreamsMap0RTT(t *testing.T) { fcBidi.EXPECT().UpdateSendWindow(protocol.ByteCount(1234)) fcUni.EXPECT().UpdateSendWindow(protocol.ByteCount(4321)) // new transport parameters - m.UpdateLimits(&wire.TransportParameters{ + m.HandleTransportParameters(&wire.TransportParameters{ MaxBidiStreamNum: 1000, InitialMaxStreamDataBidiRemote: 1234, MaxUniStreamNum: 1000, From 08e9c7e7acb8cf58c953f6c7cd2124667e3de47e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 28 Jun 2025 15:09:36 +0800 Subject: [PATCH 04/14] fix flaky TestTransportReplaceWithClosed (#5245) --- transport_test.go | 91 +++++++++++++++++++++++++++++++---------------- 1 file changed, 60 insertions(+), 31 deletions(-) diff --git a/transport_test.go b/transport_test.go index 8c918250670..37ec043878f 100644 --- a/transport_test.go +++ b/transport_test.go @@ -5,8 +5,10 @@ import ( "context" "crypto/tls" "errors" + "math" "net" - "os" + "runtime" + "sync/atomic" "syscall" "testing" "time" @@ -598,6 +600,8 @@ func TestTransportDialingVersionNegotiation(t *testing.T) { } func TestTransportReplaceWithClosed(t *testing.T) { + t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") + t.Run("local", func(t *testing.T) { testTransportReplaceWithClosed(t, true) }) @@ -616,7 +620,7 @@ func testTransportReplaceWithClosed(t *testing.T, local bool) { tr.init(true) defer tr.Close() - dur := scaleDuration(10 * time.Millisecond) + dur := scaleDuration(20 * time.Millisecond) var closePacket []byte if local { @@ -627,7 +631,6 @@ func testTransportReplaceWithClosed(t *testing.T, local bool) { connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) m := (*packetHandlerMap)(tr) require.True(t, m.Add(connID, handler)) - start := time.Now() m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, dur) p := make([]byte, 100) @@ -635,38 +638,64 @@ func testTransportReplaceWithClosed(t *testing.T, local bool) { copy(p[1:], connID.Bytes()) conn := newUDPConnLocalhost(t) - var sent int - for now := range time.NewTicker(dur / 20).C { - _, err := conn.WriteTo(p, tr.Conn.LocalAddr()) - require.NoError(t, err) - sent++ - if now.After(start.Add(dur / 2)) { - break + var sent atomic.Int64 + errChan := make(chan error, 1) + stopSending := make(chan struct{}) + go func() { + defer close(errChan) + ticker := time.NewTicker(dur / 50) + timeout := time.NewTimer(scaleDuration(time.Second)) + for { + select { + case <-stopSending: + return + case <-timeout.C: + errChan <- errors.New("timeout") + return + case <-ticker.C: + } + if _, err := conn.WriteTo(p, tr.Conn.LocalAddr()); err != nil { + errChan <- err + return + } + sent.Add(1) } - } + }() + // For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff - for i := 0; i*i < sent; i++ { - conn.SetReadDeadline(time.Now().Add(time.Second)) - b := make([]byte, 100) - if local { - n, _, err := conn.ReadFrom(b) - require.NoError(t, err) - require.Equal(t, []byte("foobar"), b[:n]) - } - } - // Afterwards, we receive a stateless reset, not a copy of the CONNECTION_CLOSE packet. - // Retry a few times, since the connection is deleted from the map on a timer. - require.Eventually(t, func() bool { - _, err := conn.WriteTo(p, tr.Conn.LocalAddr()) - require.NoError(t, err) - conn.SetReadDeadline(time.Now().Add(dur / 4)) + var received int + conn.SetReadDeadline(time.Now().Add(scaleDuration(time.Second))) + for { b := make([]byte, 100) n, _, err := conn.ReadFrom(b) - if errors.Is(err, os.ErrDeadlineExceeded) || bytes.Equal(b[:n], []byte("foobar")) { - return false + require.NoError(t, err) + // at some point, the connection is cleaned up, and we'll receive a stateless reset + if !bytes.Equal(b[:n], []byte("foobar")) { + require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize) + close(stopSending) // stop sending packets + break } + received++ + } + + select { + case err := <-errChan: require.NoError(t, err) - require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize) - return true - }, scaleDuration(200*time.Millisecond), scaleDuration(10*time.Millisecond)) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + numSent := sent.Load() + if !local { + require.Zero(t, received) + t.Logf("sent %d packets", numSent) + return + } + t.Logf("sent %d packets, received %d CONNECTION_CLOSE copies", numSent, received) + // timer resolution on Windows is terrible + if runtime.GOOS != "windows" { + require.GreaterOrEqual(t, numSent, int64(8)) + } + require.GreaterOrEqual(t, received, int(math.Floor(math.Log2(float64(numSent))))) + require.LessOrEqual(t, received, int(math.Ceil(math.Log2(float64(numSent))))) } From 3f3d3099b777207e312de49bdd062490fef0cae1 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 28 Jun 2025 15:37:18 +0800 Subject: [PATCH 05/14] fix flaky TestDrainServerAcceptQueue (#5247) The connections are not necessarily accepted in the same order that they are dialed. --- integrationtests/self/close_test.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/integrationtests/self/close_test.go b/integrationtests/self/close_test.go index ec4a52bbd0a..ce7df5dcd70 100644 --- a/integrationtests/self/close_test.go +++ b/integrationtests/self/close_test.go @@ -95,7 +95,7 @@ func TestDrainServerAcceptQueue(t *testing.T) { defer cancel() // fill up the accept queue conns := make([]*quic.Conn, 0, protocol.MaxAcceptQueueSize) - for i := 0; i < protocol.MaxAcceptQueueSize; i++ { + for range protocol.MaxAcceptQueueSize { conn, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) require.NoError(t, err) conns = append(conns, conn) @@ -107,9 +107,8 @@ func TestDrainServerAcceptQueue(t *testing.T) { c, err := server.Accept(ctx) require.NoError(t, err) // make sure the connection is not closed - require.NoError(t, conns[i].Context().Err(), "client connection closed") - require.NoError(t, c.Context().Err(), "server connection closed") - conns[i].CloseWithError(0, "") + require.NoError(t, context.Cause(conns[i].Context()), "client connection closed") + require.NoError(t, context.Cause(c.Context()), "server connection closed") c.CloseWithError(0, "") } _, err = server.Accept(ctx) From dadc8db8360e09ef60ad164840d0812253d83daa Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 28 Jun 2025 21:19:28 +0800 Subject: [PATCH 06/14] fix flaky TestServerReceiveQueue (#5249) --- server_test.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/server_test.go b/server_test.go index 3fd84691dec..6de609d7169 100644 --- a/server_test.go +++ b/server_test.go @@ -858,7 +858,9 @@ func TestServerGetConfigForClientReject(t *testing.T) { func TestServerReceiveQueue(t *testing.T) { mockCtrl := gomock.NewController(t) acceptConn := make(chan struct{}) + defer close(acceptConn) tracer, mockTracer := mocklogging.NewMockTracer(mockCtrl) + newConnChan := make(chan struct{}, protocol.MaxServerUnprocessedPackets+2) server := newTestServer(t, &serverOpts{ tracer: tracer, newConn: func( @@ -882,14 +884,23 @@ func TestServerReceiveQueue(t *testing.T) { _ utils.Logger, _ protocol.Version, ) *wrappedConn { + newConnChan <- struct{}{} <-acceptConn return &wrappedConn{testHooks: &connTestHooks{handlePacket: func(receivedPacket) {}}} }, }) conn := newUDPConnLocalhost(t) - for range protocol.MaxServerUnprocessedPackets + 1 { + for i := range protocol.MaxServerUnprocessedPackets + 1 { server.handlePacket(getValidInitialPacket(t, conn.LocalAddr(), randConnID(6), randConnID(8))) + // newConn blocks on the acceptConn channel, so this blocks the server's run loop + if i == 0 { + select { + case <-newConnChan: + case <-time.After(time.Second): + t.Fatal("timeout") + } + } } done := make(chan struct{}) @@ -904,7 +915,6 @@ func TestServerReceiveQueue(t *testing.T) { case <-time.After(time.Second): t.Fatal("timeout") } - close(acceptConn) } func TestServerAccept(t *testing.T) { From 61d2fa57ac6e31baa363498325033fb73356096e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 29 Jun 2025 02:28:07 +0800 Subject: [PATCH 07/14] http3: fix flaky TestConnGoAwayFailures (#5252) --- http3/conn_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/http3/conn_test.go b/http3/conn_test.go index 407ae543a79..21e50afa3e4 100644 --- a/http3/conn_test.go +++ b/http3/conn_test.go @@ -213,6 +213,8 @@ func testConnControlStreamFailures(t *testing.T, data []byte, readErr error, exp conn.handleUnidirectionalStreams(nil) }() + conn.openRequestStream(context.Background(), nil, nil, true, 1000) + switch readErr { case nil: _, err = controlStr.Write(data) @@ -227,8 +229,6 @@ func testConnControlStreamFailures(t *testing.T, data []byte, readErr error, exp controlStr.CancelWrite(1337) } - conn.openRequestStream(context.Background(), nil, nil, true, 1000) - select { case <-serverConn.Context().Done(): require.ErrorIs(t, From 0eb237f7973e232c73f58a4ebc7227c27070cbaa Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 29 Jun 2025 11:42:02 +0800 Subject: [PATCH 08/14] add a Config and ConnectionState flag for RESET_STREAM_AT (#5243) * add a Config and ConnectionState flag for RESET_STREAM_AT * add RESET_STREAM_AT to README --- README.md | 1 + config.go | 37 +++++++++++++++++++------------------ config_test.go | 6 ++++-- connection.go | 8 +++++++- interface.go | 7 ++++++- 5 files changed, 37 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index ccc9e2133ce..246c330490c 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ In addition to these base RFCs, it also implements the following RFCs: * Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)) * QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369)) * QUIC Event Logging using qlog ([draft-ietf-quic-qlog-main-schema](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-main-schema/) and [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/)) +* QUIC Stream Resets with Partial Delivery ([draft-ietf-quic-reliable-stream-reset](https://datatracker.ietf.org/doc/html/draft-ietf-quic-reliable-stream-reset-07)) Support for WebTransport over HTTP/3 ([draft-ietf-webtrans-http3](https://datatracker.ietf.org/doc/draft-ietf-webtrans-http3/)) is implemented in [webtransport-go](https://github.com/quic-go/webtransport-go). diff --git a/config.go b/config.go index 540a3240bca..74c2054e45c 100644 --- a/config.go +++ b/config.go @@ -106,23 +106,24 @@ func populateConfig(config *Config) *Config { } return &Config{ - GetConfigForClient: config.GetConfigForClient, - Versions: versions, - HandshakeIdleTimeout: handshakeIdleTimeout, - MaxIdleTimeout: idleTimeout, - KeepAlivePeriod: config.KeepAlivePeriod, - InitialStreamReceiveWindow: initialStreamReceiveWindow, - MaxStreamReceiveWindow: maxStreamReceiveWindow, - InitialConnectionReceiveWindow: initialConnectionReceiveWindow, - MaxConnectionReceiveWindow: maxConnectionReceiveWindow, - AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, - MaxIncomingStreams: maxIncomingStreams, - MaxIncomingUniStreams: maxIncomingUniStreams, - TokenStore: config.TokenStore, - EnableDatagrams: config.EnableDatagrams, - InitialPacketSize: initialPacketSize, - DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, - Allow0RTT: config.Allow0RTT, - Tracer: config.Tracer, + GetConfigForClient: config.GetConfigForClient, + Versions: versions, + HandshakeIdleTimeout: handshakeIdleTimeout, + MaxIdleTimeout: idleTimeout, + KeepAlivePeriod: config.KeepAlivePeriod, + InitialStreamReceiveWindow: initialStreamReceiveWindow, + MaxStreamReceiveWindow: maxStreamReceiveWindow, + InitialConnectionReceiveWindow: initialConnectionReceiveWindow, + MaxConnectionReceiveWindow: maxConnectionReceiveWindow, + AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, + MaxIncomingStreams: maxIncomingStreams, + MaxIncomingUniStreams: maxIncomingUniStreams, + TokenStore: config.TokenStore, + EnableDatagrams: config.EnableDatagrams, + InitialPacketSize: initialPacketSize, + DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, + EnableStreamResetPartialDelivery: config.EnableStreamResetPartialDelivery, + Allow0RTT: config.Allow0RTT, + Tracer: config.Tracer, } } diff --git a/config_test.go b/config_test.go index b1c1c6c6e5f..b8e4fc39144 100644 --- a/config_test.go +++ b/config_test.go @@ -126,6 +126,8 @@ func configWithNonZeroNonFunctionFields(t *testing.T) *Config { f.Set(reflect.ValueOf(true)) case "Allow0RTT": f.Set(reflect.ValueOf(true)) + case "EnableStreamResetPartialDelivery": + f.Set(reflect.ValueOf(true)) default: t.Fatalf("all fields must be accounted for, but saw unknown field %q", fn) } @@ -133,7 +135,7 @@ func configWithNonZeroNonFunctionFields(t *testing.T) *Config { return c } -func TestConfigCloning(t *testing.T) { +func TestConfigClone(t *testing.T) { t.Run("function fields", func(t *testing.T) { var calledAllowConnectionWindowIncrease, calledTracer bool c1 := &Config{ @@ -153,7 +155,7 @@ func TestConfigCloning(t *testing.T) { require.True(t, calledTracer) }) - t.Run("clones non-function fields", func(t *testing.T) { + t.Run("non-function fields", func(t *testing.T) { c := configWithNonZeroNonFunctionFields(t) require.Equal(t, c, c.Clone()) }) diff --git a/connection.go b/connection.go index a6af5fcda3c..ce565aa014f 100644 --- a/connection.go +++ b/connection.go @@ -315,6 +315,7 @@ var newConnection = func( ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, InitialSourceConnectionID: srcConnID, RetrySourceConnectionID: retrySrcConnID, + EnableResetStreamAt: conf.EnableStreamResetPartialDelivery, } if s.config.EnableDatagrams { params.MaxDatagramFrameSize = wire.MaxDatagramSize @@ -425,6 +426,7 @@ var newClientConnection = func( // See https://github.com/quic-go/quic-go/pull/3806. ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, InitialSourceConnectionID: srcConnID, + EnableResetStreamAt: conf.EnableStreamResetPartialDelivery, } if s.config.EnableDatagrams { params.MaxDatagramFrameSize = wire.MaxDatagramSize @@ -468,7 +470,10 @@ func (c *Conn) preSetup() { c.handshakeStream = newCryptoStream() c.sendQueue = newSendQueue(c.conn) c.retransmissionQueue = newRetransmissionQueue() - c.frameParser = *wire.NewFrameParser(c.config.EnableDatagrams, false) + c.frameParser = *wire.NewFrameParser( + c.config.EnableDatagrams, + c.config.EnableStreamResetPartialDelivery, + ) c.rttStats = &utils.RTTStats{} c.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ByteCount(c.config.InitialConnectionReceiveWindow), @@ -722,6 +727,7 @@ func (c *Conn) ConnectionState() ConnectionState { cs := c.cryptoStreamHandler.ConnectionState() c.connState.TLS = cs.ConnectionState c.connState.Used0RTT = cs.Used0RTT + c.connState.SupportsStreamResetPartialDelivery = c.peerParams.EnableResetStreamAt c.connState.GSO = c.conn.capabilities().GSO return c.connState } diff --git a/interface.go b/interface.go index 4ba75a378e3..70dcdef1c0a 100644 --- a/interface.go +++ b/interface.go @@ -179,7 +179,10 @@ type Config struct { Allow0RTT bool // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool - Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer + // Enable QUIC Stream Resets with Partial Delivery. + // See https://datatracker.ietf.org/doc/html/draft-ietf-quic-reliable-stream-reset-07. + EnableStreamResetPartialDelivery bool + Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer } // ClientHelloInfo contains information about an incoming connection attempt. @@ -207,6 +210,8 @@ type ConnectionState struct { // This is a unilateral declaration by the peer - receiving datagrams is only possible if // datagram support was enabled locally via Config.EnableDatagrams. SupportsDatagrams bool + // SupportsStreamResetPartialDelivery indicates whether the peer advertised support for QUIC Stream Resets with Partial Delivery. + SupportsStreamResetPartialDelivery bool // Used0RTT says if 0-RTT resumption was used. Used0RTT bool // Version is the QUIC version of the QUIC connection. From fd32cf5c69a5941890d55041313a71a2e42a7773 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 29 Jun 2025 13:09:45 +0800 Subject: [PATCH 09/14] fix flaky TestPostQuantumClientHello (#5253) --- integrationtests/self/handshake_drop_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 7d1d10b73a8..211d635606a 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -17,7 +17,6 @@ import ( "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" "github.com/quic-go/quic-go/internal/wire" - "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" ) @@ -277,8 +276,10 @@ func TestPostQuantumClientHello(t *testing.T) { b := make([]byte, 2500) // the ClientHello will now span across 3 packets rand.Read(b) wire.AdditionalTransportParametersClient = map[uint64][]byte{ - // Avoid random collisions with the greased transport parameters. - uint64(27+31*(1000+mrand.IntN(31))/31) % quicvarint.Max: b, + // We don't use a greased transport parameter here, since the transport parameter serialization function + // will add a greased transport parameter, and therefore there's a risk of a collision. + // Instead, we just use pseudorandom constant value. + 1234567: b, } ln, proxyPort := startDropTestListenerAndProxy(t, 10*time.Millisecond, 20*time.Second, dropCallbackDropOneThird(quicproxy.DirectionIncoming), false, false) From 0a9c6ea4c8c55e9c0616eb6efc571a420905e370 Mon Sep 17 00:00:00 2001 From: Robin Thellend Date: Mon, 7 Jul 2025 04:41:23 -0700 Subject: [PATCH 10/14] http3: remove dependency on quic internal packages (#5256) * Remove http3 dependency on quic internal packages Remove the dependency on internal/protocol from the http3 package. This makes it possible for a forked http3 to use the mainline quic-go package. * Address review comments * Fix syntax * Use broader pattern for http3 directory * Copy internal/testdata * Replace perspective with bool * clone the supported version slice --------- Co-authored-by: Marten Seemann --- .golangci.yml | 7 ++++ http3/client.go | 3 +- http3/conn.go | 47 +++++++++++---------- http3/conn_test.go | 27 ++++++------ http3/http3_helper_test.go | 12 +++--- http3/internal/testdata/ca.pem | 17 ++++++++ http3/internal/testdata/cert.go | 56 +++++++++++++++++++++++++ http3/internal/testdata/cert.pem | 18 ++++++++ http3/internal/testdata/cert_test.go | 28 +++++++++++++ http3/internal/testdata/generate_key.sh | 24 +++++++++++ http3/internal/testdata/priv.key | 28 +++++++++++++ http3/server.go | 3 +- http3/server_test.go | 2 +- http3/stream.go | 7 ++-- http3/stream_test.go | 7 ++-- http3/transport.go | 3 +- http3/transport_test.go | 3 +- interface.go | 7 ++++ 18 files changed, 240 insertions(+), 59 deletions(-) create mode 100644 http3/internal/testdata/ca.pem create mode 100644 http3/internal/testdata/cert.go create mode 100644 http3/internal/testdata/cert.pem create mode 100644 http3/internal/testdata/cert_test.go create mode 100755 http3/internal/testdata/generate_key.sh create mode 100644 http3/internal/testdata/priv.key diff --git a/.golangci.yml b/.golangci.yml index 0c9ecffa59c..4a97047bafd 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -46,6 +46,13 @@ linters: desc: "use standard Go tests" - pkg: github.com/onsi/gomega desc: "use standard Go tests" + http3-internal: + list-mode: lax + files: + - '**/http3/**' + deny: + - pkg: 'github.com/quic-go/quic-go/internal' + desc: 'no dependency on quic-go/internal' misspell: ignore-rules: - ect diff --git a/http3/client.go b/http3/client.go index 7cd02bc3286..678f5bc7b11 100644 --- a/http3/client.go +++ b/http3/client.go @@ -12,7 +12,6 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/qpack" @@ -102,7 +101,7 @@ func newClientConn( conn.Context(), conn, c.enableDatagrams, - protocol.PerspectiveClient, + false, // client c.logger, 0, ) diff --git a/http3/conn.go b/http3/conn.go index 63ac9cbe22f..842664963c0 100644 --- a/http3/conn.go +++ b/http3/conn.go @@ -14,7 +14,6 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/qpack" @@ -24,6 +23,9 @@ const maxQuarterStreamID = 1<<60 - 1 var errGoAway = errors.New("connection in graceful shutdown") +// invalidStreamID is a stream ID that is invalid. The first valid stream ID in QUIC is 0. +const invalidStreamID = quic.StreamID(-1) + // Conn is an HTTP/3 connection. // It has all methods from the quic.Conn expect for AcceptStream, AcceptUniStream, // SendDatagram and ReceiveDatagram. @@ -32,17 +34,17 @@ type Conn struct { ctx context.Context - perspective protocol.Perspective - logger *slog.Logger + isServer bool + logger *slog.Logger enableDatagrams bool decoder *qpack.Decoder streamMx sync.Mutex - streams map[protocol.StreamID]*stateTrackingStream - lastStreamID protocol.StreamID - maxStreamID protocol.StreamID + streams map[quic.StreamID]*stateTrackingStream + lastStreamID quic.StreamID + maxStreamID quic.StreamID settings *Settings receivedSettings chan struct{} @@ -55,22 +57,22 @@ func newConnection( ctx context.Context, quicConn *quic.Conn, enableDatagrams bool, - perspective protocol.Perspective, + isServer bool, logger *slog.Logger, idleTimeout time.Duration, ) *Conn { c := &Conn{ ctx: ctx, conn: quicConn, - perspective: perspective, + isServer: isServer, logger: logger, idleTimeout: idleTimeout, enableDatagrams: enableDatagrams, decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), receivedSettings: make(chan struct{}), - streams: make(map[protocol.StreamID]*stateTrackingStream), - maxStreamID: protocol.InvalidStreamID, - lastStreamID: protocol.InvalidStreamID, + streams: make(map[quic.StreamID]*stateTrackingStream), + maxStreamID: invalidStreamID, + lastStreamID: invalidStreamID, } if idleTimeout > 0 { c.idleTimer = time.AfterFunc(idleTimeout, c.onIdleTimer) @@ -124,7 +126,7 @@ func (c *Conn) clearStream(id quic.StreamID) { } // The server is performing a graceful shutdown. // If no more streams are remaining, close the connection. - if c.maxStreamID != protocol.InvalidStreamID { + if c.maxStreamID != invalidStreamID { if len(c.streams) == 0 { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") } @@ -141,7 +143,7 @@ func (c *Conn) openRequestStream( c.streamMx.Lock() maxStreamID := c.maxStreamID var nextStreamID quic.StreamID - if c.lastStreamID == protocol.InvalidStreamID { + if c.lastStreamID == invalidStreamID { nextStreamID = 0 } else { nextStreamID = c.lastStreamID + 4 @@ -149,7 +151,7 @@ func (c *Conn) openRequestStream( c.streamMx.Unlock() // Streams with stream ID equal to or greater than the stream ID carried in the GOAWAY frame // will be rejected, see section 5.2 of RFC 9114. - if maxStreamID != protocol.InvalidStreamID && nextStreamID >= maxStreamID { + if maxStreamID != invalidStreamID && nextStreamID >= maxStreamID { return nil, errGoAway } @@ -268,13 +270,12 @@ func (c *Conn) handleUnidirectionalStreams(hijack func(StreamType, quic.Connecti // Our QPACK implementation doesn't use the dynamic table yet. return case streamTypePushStream: - switch c.perspective { - case protocol.PerspectiveClient: - // we never increased the Push ID, so we don't expect any push streams - c.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") - case protocol.PerspectiveServer: + if c.isServer { // only the server can push c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "") + } else { + // we never increased the Push ID, so we don't expect any push streams + c.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") } return default: @@ -342,7 +343,7 @@ func (c *Conn) handleControlStream(str *quic.ReceiveStream) { } // we don't support server push, hence we don't expect any GOAWAY frames from the client - if c.perspective == protocol.PerspectiveServer { + if c.isServer { return } @@ -370,7 +371,7 @@ func (c *Conn) handleControlStream(str *quic.ReceiveStream) { return } c.streamMx.Lock() - if c.maxStreamID != protocol.InvalidStreamID && goaway.StreamID > c.maxStreamID { + if c.maxStreamID != invalidStreamID && goaway.StreamID > c.maxStreamID { c.streamMx.Unlock() c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") return @@ -387,7 +388,7 @@ func (c *Conn) handleControlStream(str *quic.ReceiveStream) { } } -func (c *Conn) sendDatagram(streamID protocol.StreamID, b []byte) error { +func (c *Conn) sendDatagram(streamID quic.StreamID, b []byte) error { // TODO: this creates a lot of garbage and an additional copy data := make([]byte, 0, len(b)+8) data = quicvarint.Append(data, uint64(streamID/4)) @@ -410,7 +411,7 @@ func (c *Conn) receiveDatagrams() error { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "") return fmt.Errorf("invalid quarter stream id: %w", err) } - streamID := protocol.StreamID(4 * quarterStreamID) + streamID := quic.StreamID(4 * quarterStreamID) c.streamMx.Lock() dg, ok := c.streams[streamID] c.streamMx.Unlock() diff --git a/http3/conn_test.go b/http3/conn_test.go index 21e50afa3e4..d781eefd477 100644 --- a/http3/conn_test.go +++ b/http3/conn_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/require" @@ -21,7 +20,7 @@ func TestConnReceiveSettings(t *testing.T) { serverConn.Context(), serverConn, false, - protocol.PerspectiveServer, + true, // server nil, 0, ) @@ -71,7 +70,7 @@ func testConnRejectDuplicateStreams(t *testing.T, typ uint64) { context.Background(), serverConn, false, - protocol.PerspectiveServer, + true, // server nil, 0, ) @@ -116,7 +115,7 @@ func TestConnResetUnknownUniStream(t *testing.T) { context.Background(), serverConn, false, - protocol.PerspectiveServer, + true, // server nil, 0, ) @@ -198,7 +197,7 @@ func testConnControlStreamFailures(t *testing.T, data []byte, readErr error, exp clientConn.Context(), clientConn, false, - protocol.PerspectiveClient, + false, // client nil, 0, ) @@ -261,7 +260,7 @@ func testConnGoAway(t *testing.T, withStream bool) { clientConn.Context(), clientConn, false, - protocol.PerspectiveClient, + false, // client nil, 0, ) @@ -318,21 +317,21 @@ func testConnGoAway(t *testing.T, withStream bool) { func TestConnRejectPushStream(t *testing.T) { t.Run("client", func(t *testing.T) { - testConnRejectPushStream(t, protocol.PerspectiveClient, ErrCodeStreamCreationError) + testConnRejectPushStream(t, false, ErrCodeStreamCreationError) }) t.Run("server", func(t *testing.T) { - testConnRejectPushStream(t, protocol.PerspectiveServer, ErrCodeIDError) + testConnRejectPushStream(t, true, ErrCodeIDError) }) } -func testConnRejectPushStream(t *testing.T, pers protocol.Perspective, expectedErr ErrCode) { +func testConnRejectPushStream(t *testing.T, isServer bool, expectedErr ErrCode) { clientConn, serverConn := newConnPair(t) conn := newConnection( clientConn.Context(), clientConn, false, - pers.Opposite(), + !isServer, nil, 0, ) @@ -370,7 +369,7 @@ func TestConnInconsistentDatagramSupport(t *testing.T) { clientConn.Context(), clientConn, true, - protocol.PerspectiveClient, + false, // client nil, 0, ) @@ -400,7 +399,7 @@ func TestConnSendAndReceiveDatagram(t *testing.T) { clientConn.Context(), clientConn, true, - protocol.PerspectiveClient, + false, // client nil, 0, ) @@ -429,7 +428,7 @@ func TestConnSendAndReceiveDatagram(t *testing.T) { str, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000) require.NoError(t, err) - require.Equal(t, protocol.StreamID(strID), str.StreamID()) + require.Equal(t, quic.StreamID(strID), str.StreamID()) // now open the stream... require.NoError(t, serverConn.SendDatagram(append(quarterStreamID, []byte("bar")...))) @@ -467,7 +466,7 @@ func testConnDatagramFailures(t *testing.T, datagram []byte) { clientConn.Context(), clientConn, true, - protocol.PerspectiveClient, + false, // client nil, 0, ) diff --git a/http3/http3_helper_test.go b/http3/http3_helper_test.go index 66a6284ee80..2098d337e4c 100644 --- a/http3/http3_helper_test.go +++ b/http3/http3_helper_test.go @@ -21,11 +21,13 @@ import ( "github.com/quic-go/qpack" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) +// maxByteCount is the maximum value of a ByteCount +const maxByteCount = uint64(1<<62 - 1) + func newUDPConnLocalhost(t testing.TB) *net.UDPConn { t.Helper() conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) @@ -134,8 +136,8 @@ func newConnPair(t *testing.T) (client, server *quic.Conn) { newUDPConnLocalhost(t), getTLSConfig(), &quic.Config{ - InitialStreamReceiveWindow: uint64(protocol.MaxByteCount), - InitialConnectionReceiveWindow: uint64(protocol.MaxByteCount), + InitialStreamReceiveWindow: maxByteCount, + InitialConnectionReceiveWindow: maxByteCount, }, ) require.NoError(t, err) @@ -164,8 +166,8 @@ func newConnPairWithDatagrams(t *testing.T) (client, server *quic.Conn) { newUDPConnLocalhost(t), getTLSConfig(), &quic.Config{ - InitialStreamReceiveWindow: uint64(protocol.MaxByteCount), - InitialConnectionReceiveWindow: uint64(protocol.MaxByteCount), + InitialStreamReceiveWindow: maxByteCount, + InitialConnectionReceiveWindow: maxByteCount, EnableDatagrams: true, }, ) diff --git a/http3/internal/testdata/ca.pem b/http3/internal/testdata/ca.pem new file mode 100644 index 00000000000..67a5545e816 --- /dev/null +++ b/http3/internal/testdata/ca.pem @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE----- +MIICzDCCAbQCCQDA+rLymNnfJzANBgkqhkiG9w0BAQsFADAoMSYwJAYDVQQKDB1x +dWljLWdvIENlcnRpZmljYXRlIEF1dGhvcml0eTAeFw0yMDA4MTgwOTIxMzVaFw0z +MDA4MTYwOTIxMzVaMCgxJjAkBgNVBAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0 +aG9yaXR5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1OcsYrVaSDfh +iDppl6oteVspOY3yFb96T9Y/biaGPJAkBO9VGKcqwOUPmUeiWpedRAUB9LE7Srs6 +qBX4mnl90Icjp8jbIs5cPgIWLkIu8Qm549RghFzB3bn+EmCQSe4cxvyDMN3ndClp +3YMXpZgXWgJGiPOylVi/OwHDdWDBorw4hvry+6yDtpQo2TuI2A/xtxXPT7BgsEJD +WGffdgZOYXChcFA0c1XVLIYlu2w2JhxS8c2TUF6uSDlmcoONNKVoiNCuu1Z9MorS +Qmg7a2G7dSPu123KcTcSQFcmJrt+1G81gOBtHB69kacD8xDmgksj09h/ODPL/gIU +1ZcU2ci1/QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQB0Tb1JbLXp/BvWovSAhO/j +wG7UEaUA1rCtkDB+fV2HS9bxCbV5eErdg8AMHKgB51ygUrq95vm/baZmUILr84XK +uTEoxxrw5S9Z7SrhtbOpKCumoSeTsCPjDvCcwFExHv4XHFk+CPqZwbMHueVIMT0+ +nGWss/KecCPdJLdnUgMRz0tIuXzkoRuOiUiZfUeyBNVNbDFSrLigYshTeAPGaYjX +CypoHxkeS93nWfOMUu8FTYLYkvGMU5i076zDoFGKJiEtbjSiNW+Hei7u2aSEuCzp +qyTKzYPWYffAq3MM2MKJgZdL04e9GEGeuce/qhM1o3q77aI/XJImwEDdut2LDec1 +-----END CERTIFICATE----- diff --git a/http3/internal/testdata/cert.go b/http3/internal/testdata/cert.go new file mode 100644 index 00000000000..f77a7b2ddbe --- /dev/null +++ b/http3/internal/testdata/cert.go @@ -0,0 +1,56 @@ +package testdata + +import ( + "crypto/tls" + "crypto/x509" + "os" + "path" + "runtime" +) + +var certPath string + +func init() { + _, filename, _, ok := runtime.Caller(0) + if !ok { + panic("Failed to get current frame") + } + + certPath = path.Dir(filename) +} + +// GetCertificatePaths returns the paths to certificate and key +func GetCertificatePaths() (string, string) { + return path.Join(certPath, "cert.pem"), path.Join(certPath, "priv.key") +} + +// GetTLSConfig returns a tls config for quic.clemente.io +func GetTLSConfig() *tls.Config { + cert, err := tls.LoadX509KeyPair(GetCertificatePaths()) + if err != nil { + panic(err) + } + return &tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{cert}, + } +} + +// AddRootCA adds the root CA certificate to a cert pool +func AddRootCA(certPool *x509.CertPool) { + caCertPath := path.Join(certPath, "ca.pem") + caCertRaw, err := os.ReadFile(caCertPath) + if err != nil { + panic(err) + } + if ok := certPool.AppendCertsFromPEM(caCertRaw); !ok { + panic("Could not add root ceritificate to pool.") + } +} + +// GetRootCA returns an x509.CertPool containing (only) the CA certificate +func GetRootCA() *x509.CertPool { + pool := x509.NewCertPool() + AddRootCA(pool) + return pool +} diff --git a/http3/internal/testdata/cert.pem b/http3/internal/testdata/cert.pem new file mode 100644 index 00000000000..91d1aa9e762 --- /dev/null +++ b/http3/internal/testdata/cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC1TCCAb2gAwIBAgIJAK2fcqC0BVA7MA0GCSqGSIb3DQEBCwUAMCgxJjAkBgNV +BAoMHXF1aWMtZ28gQ2VydGlmaWNhdGUgQXV0aG9yaXR5MB4XDTIwMDgxODA5MjEz +NVoXDTMwMDgxNjA5MjEzNVowEjEQMA4GA1UECgwHcXVpYy1nbzCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAN/YwrigSXdJCL/bdBGhb0UpqtU8H+krV870 ++w1yCSykLImH8x3qHZEXt9sr/vgjcJoV6Z15RZmnbEqnAx84sIClIBoIgnk0VPxu +WF+/U/dElbftCfYcfJAddhRckdmGB+yb3Wogb32UJ+q3my++h6NjHsYb+OwpJPnQ +meXjOE7Kkf+bXfFywHF3R8kzVdh5JUFYeKbxYmYgxRps1YTsbCrZCrSy1CbQ9FJw +Wg5C8t+7yvVFmOeWPECypBCz2xS2mu+kycMNIjIWMl0SL7oVM5cBkRKPeVIG/KcM +i5+/4lRSLoPh0Txh2TKBWfpzLbIOdPU8/O7cAukIGWx0XsfHUQMCAwEAAaMYMBYw +FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQAyxxvebdMz +shp5pt1SxMOSXbo8sTa1cpaf2rTmb4nxjXs6KPBEn53hSBz9bhe5wXE4f94SHadf +636rLh3d75KgrLUwO9Yq0HfCxMo1jUV/Ug++XwcHCI9vk58Tk/H4hqEM6C8RrdTj +fYeuegQ0/oNLJ4uTw2P2A8TJbL6FC2dcICEAvUGZUcVyZ8m8tHXNRYYh6MZ7ubCh +hinvL+AA5fY6EVlc5G/P4DN6fYxGn1cFNbiL4uZP4+W3dOmP+NV0YV9ihTyMzz0R +vSoOZ9FeVkyw8EhMb3LoyXYKazvJy2VQST1ltzAGit9RiM1Gv4vuna74WsFzrn1U +A/TbaR0ih/qG +-----END CERTIFICATE----- diff --git a/http3/internal/testdata/cert_test.go b/http3/internal/testdata/cert_test.go new file mode 100644 index 00000000000..2eff1fef284 --- /dev/null +++ b/http3/internal/testdata/cert_test.go @@ -0,0 +1,28 @@ +package testdata + +import ( + "crypto/tls" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCertificates(t *testing.T) { + ln, err := tls.Listen("tcp", "localhost:4433", GetTLSConfig()) + require.NoError(t, err) + + go func() { + conn, err := ln.Accept() + require.NoError(t, err) + defer conn.Close() + _, err = conn.Write([]byte("foobar")) + require.NoError(t, err) + }() + + conn, err := tls.Dial("tcp", "localhost:4433", &tls.Config{RootCAs: GetRootCA()}) + require.NoError(t, err) + data, err := io.ReadAll(conn) + require.NoError(t, err) + require.Equal(t, "foobar", string(data)) +} diff --git a/http3/internal/testdata/generate_key.sh b/http3/internal/testdata/generate_key.sh new file mode 100755 index 00000000000..7ecaa966d26 --- /dev/null +++ b/http3/internal/testdata/generate_key.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e + +echo "Generating CA key and certificate:" +openssl req -x509 -sha256 -nodes -days 3650 -newkey rsa:2048 \ + -keyout ca.key -out ca.pem \ + -subj "/O=quic-go Certificate Authority/" + +echo "Generating CSR" +openssl req -out cert.csr -new -newkey rsa:2048 -nodes -keyout priv.key \ + -subj "/O=quic-go/" + +echo "Sign certificate:" +openssl x509 -req -sha256 -days 3650 -in cert.csr -out cert.pem \ + -CA ca.pem -CAkey ca.key -CAcreateserial \ + -extfile <(printf "subjectAltName=DNS:localhost") + +# debug output the certificate +openssl x509 -noout -text -in cert.pem + +# we don't need the CA key, the serial number and the CSR any more +rm ca.key cert.csr ca.srl + diff --git a/http3/internal/testdata/priv.key b/http3/internal/testdata/priv.key new file mode 100644 index 00000000000..56b8d894dc1 --- /dev/null +++ b/http3/internal/testdata/priv.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDf2MK4oEl3SQi/ +23QRoW9FKarVPB/pK1fO9PsNcgkspCyJh/Md6h2RF7fbK/74I3CaFemdeUWZp2xK +pwMfOLCApSAaCIJ5NFT8blhfv1P3RJW37Qn2HHyQHXYUXJHZhgfsm91qIG99lCfq +t5svvoejYx7GG/jsKST50Jnl4zhOypH/m13xcsBxd0fJM1XYeSVBWHim8WJmIMUa +bNWE7Gwq2Qq0stQm0PRScFoOQvLfu8r1RZjnljxAsqQQs9sUtprvpMnDDSIyFjJd +Ei+6FTOXAZESj3lSBvynDIufv+JUUi6D4dE8YdkygVn6cy2yDnT1PPzu3ALpCBls +dF7Hx1EDAgMBAAECggEBAMm+mLDBdbUWk9YmuZNyRdC13wvT5obF05vo26OglXgw +dxt09b6OVBuCnuff3SpS9pdJDIYq2HnFlSorH/sxopIvQKF17fHDIp1n7ipNTCXd +IHrmHkY8Il/YzaVIUQMVc2rih0mw9greTqOS20DKnYC6QvAWIeDmrDaitTGl+ge3 +hm7e2lsgZi13R6fTNwQs9geEQSGzP2k7bFceHQFDChOYiQraR5+VZZ8S8AMGjk47 +AUa5EsKeUe6O9t2xuDSFxzYz5eadOAiErKGDos5KXXr3VQgFcC8uPEFFjcJ/yl+8 +tOe4iLeVwGSDJhTAThdR2deJOjaDcarWM7ixmxA3DAECgYEA/WVwmY4gWKwv49IJ +Jnh1Gu93P772GqliMNpukdjTI+joQxfl4jRSt2hk4b1KRwyT9aaKfvdz0HFlXo/r +9NVSAYT3/3vbcw61bfvPhhtz44qRAAKua6b5cUM6XqxVt1hqdP8lrf/blvA5ln+u +O51S8+wpxZMuqKz/29zdWSG6tAMCgYEA4iWXMXX9dZajI6abVkWwuosvOakXdLk4 +tUy7zd+JPF7hmUzzj2gtg4hXoiQPAOi+GY3TX+1Nza3s1LD7iWaXSKeOWvvligw9 +Q/wVTNW2P1+tdhScJf9QudzW69xOm5HNBgx9uWV2cHfjC12vg5aTH0k5axvaq15H +9WBXlH5q3wECgYBYoYGYBDFmMpvxmMagkSOMz1OrlVSpkLOKmOxx0SBRACc1SIec +7mY8RqR6nOX9IfYixyTMMittLiyhvb9vfKnZZDQGRcFFZlCpbplws+t+HDqJgWaW +uumm5zfkY2z7204pLBF24fZhvha2gGRl76pTLTiTJd79Gr3HnmJByd1vFwKBgHL7 +vfYuEeM55lT4Hz8sTAFtR2O/7+cvTgAQteSlZbfGXlp939DonUulhTkxsFc7/3wq +unCpzcdoSWSTYDGqcf1FBIKKVVltg7EPeR0KBJIQabgCHqrLOBZojPZ7m5RJ+765 +lysuxZvFuTFMPzNe2gssRf+JuBMt6tR+WclsxZYBAoGAEEFs1ppDil1xlP5rdH7T +d3TSw/u4eU/X8Ei1zi25hdRUiV76fP9fBELYFmSrPBhugYv91vtSv/LmD4zLfLv/ +yzwAD9j1lGbgM8Of8klCkk+XSJ88ryUwnMTJ5loQJW8t4L+zLv5Le7Ca9SAT0kJ1 +jT0GzDymgLMGp8RPdBkpk+w= +-----END PRIVATE KEY----- diff --git a/http3/server.go b/http3/server.go index 8d53a724670..7e002a09b64 100644 --- a/http3/server.go +++ b/http3/server.go @@ -18,7 +18,6 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/qpack" @@ -468,7 +467,7 @@ func (s *Server) handleConn(conn *quic.Conn) error { connCtx, conn, s.EnableDatagrams, - protocol.PerspectiveServer, + true, // server s.Logger, s.IdleTimeout, ) diff --git a/http3/server_test.go b/http3/server_test.go index 3c81253f8ba..b39993a0c5d 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -15,7 +15,7 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/testdata" + "github.com/quic-go/quic-go/http3/internal/testdata" "github.com/quic-go/quic-go/quicvarint" "github.com/stretchr/testify/assert" diff --git a/http3/stream.go b/http3/stream.go index a1dafa11dc8..50295600c22 100644 --- a/http3/stream.go +++ b/http3/stream.go @@ -10,7 +10,6 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/qpack" ) @@ -79,7 +78,7 @@ func (s *Stream) Read(b []byte) (int, error) { s.bytesRemainingInFrame = f.Length break parseLoop case *headersFrame: - if s.conn.perspective == protocol.PerspectiveServer { + if s.conn.isServer { continue } if s.parsedTrailer { @@ -124,7 +123,7 @@ func (s *Stream) writeUnframed(b []byte) (int, error) { return s.datagramStream.Write(b) } -func (s *Stream) StreamID() protocol.StreamID { +func (s *Stream) StreamID() quic.StreamID { return s.datagramStream.StreamID() } @@ -194,7 +193,7 @@ func (s *RequestStream) Read(b []byte) (int, error) { } // StreamID returns the QUIC stream ID of the underlying QUIC stream. -func (s *RequestStream) StreamID() protocol.StreamID { +func (s *RequestStream) StreamID() quic.StreamID { return s.str.StreamID() } diff --git a/http3/stream_test.go b/http3/stream_test.go index a6352a1d44b..ca18a9cd6d2 100644 --- a/http3/stream_test.go +++ b/http3/stream_test.go @@ -13,7 +13,6 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/qpack" @@ -40,7 +39,7 @@ func TestStreamReadDataFrames(t *testing.T) { clientConn.Context(), clientConn, false, - protocol.PerspectiveClient, + false, // client nil, 0, ), @@ -92,7 +91,7 @@ func TestStreamInvalidFrame(t *testing.T) { str := newStream( qstr, - newConnection(context.Background(), clientConn, false, protocol.PerspectiveClient, nil, 0), + newConnection(context.Background(), clientConn, false, false, nil, 0), nil, func(r io.Reader, u uint64) error { return nil }, ) @@ -146,7 +145,7 @@ func TestRequestStream(t *testing.T) { str := newRequestStream( newStream( qstr, - newConnection(context.Background(), clientConn, false, protocol.PerspectiveClient, nil, 0), + newConnection(context.Background(), clientConn, false, false, nil, 0), &httptrace.ClientTrace{}, func(r io.Reader, u uint64) error { return nil }, ), diff --git a/http3/transport.go b/http3/transport.go index dac84bb8dad..4085faca5ef 100644 --- a/http3/transport.go +++ b/http3/transport.go @@ -18,7 +18,6 @@ import ( "golang.org/x/net/http/httpguts" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/protocol" ) // Settings are HTTP/3 settings that apply to the underlying connection. @@ -146,7 +145,7 @@ func (t *Transport) init() error { } if len(t.QUICConfig.Versions) == 0 { t.QUICConfig = t.QUICConfig.Clone() - t.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]} + t.QUICConfig.Versions = []quic.Version{quic.SupportedVersions()[0]} } if len(t.QUICConfig.Versions) != 1 { return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") diff --git a/http3/transport_test.go b/http3/transport_test.go index 63adc76f074..2b98ccf1cf7 100644 --- a/http3/transport_test.go +++ b/http3/transport_test.go @@ -14,7 +14,6 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -208,7 +207,7 @@ func TestTransportDatagrams(t *testing.T) { func TestTransportMultipleQUICVersions(t *testing.T) { qconf := &quic.Config{ - Versions: []quic.Version{protocol.Version2, protocol.Version1}, + Versions: []quic.Version{quic.Version2, quic.Version1}, } tr := &Transport{QUICConfig: qconf} req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) diff --git a/interface.go b/interface.go index 70dcdef1c0a..45a03a52ac4 100644 --- a/interface.go +++ b/interface.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "net" + "slices" "time" "github.com/quic-go/quic-go/internal/handshake" @@ -25,6 +26,12 @@ const ( Version2 = protocol.Version2 ) +// SupportedVersions returns the support versions, sorted in descending order of preference. +func SupportedVersions() []Version { + // clone the slice to prevent the caller from modifying the slice + return slices.Clone(protocol.SupportedVersions) +} + // A ClientToken is a token received by the client. // It can be used to skip address validation on future connection attempts. type ClientToken struct { From afe01ef103a52e842adf7560662aa076458e0e16 Mon Sep 17 00:00:00 2001 From: Coia Prant Date: Sun, 13 Jul 2025 00:58:30 +0800 Subject: [PATCH 11/14] close Transport when DialAddr fails (#5259) Close the transport after dial fails to avoid memory leaks. Same logic as DialAddrEarly. Signed-off-by: Coia Prant --- client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client.go b/client.go index 5386e598036..63132f2deb1 100644 --- a/client.go +++ b/client.go @@ -31,6 +31,7 @@ func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Confi } conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, false) if err != nil { + tr.Close() return nil, err } return conn, nil From 893a5941fbb077255d14e93dd5b5c13c7e05fe4a Mon Sep 17 00:00:00 2001 From: Jannis Seemann <5215310+jannis-seemann@users.noreply.github.com> Date: Fri, 18 Jul 2025 19:33:04 +0300 Subject: [PATCH 12/14] wire: improve frame parsing benchmarks (#5263) * Add master-style frame handling benchmarks using type switches * Fixing styling issue. * put STREAM frame back * remove BenchmarkParseStreamAndACK * use random data for STREAM and DATAGRAM * improve comment --------- Co-authored-by: Marten Seemann --- internal/wire/frame_parser_test.go | 258 +++++++++++++++++------------ 1 file changed, 149 insertions(+), 109 deletions(-) diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index e044f2841bf..fefb5c11ee0 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -3,6 +3,7 @@ package wire import ( "bytes" "crypto/rand" + "slices" "testing" "time" @@ -259,132 +260,171 @@ func TestFrameParsingErrorsOnInvalidFrames(t *testing.T) { require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode) } -// STREAM and ACK are the most relevant frames for high-throughput transfers. -func BenchmarkParseStreamAndACK(b *testing.B) { - ack := &AckFrame{ - AckRanges: []AckRange{ - {Smallest: 5000, Largest: 5200}, - {Smallest: 1, Largest: 4200}, - }, - DelayTime: 42 * time.Millisecond, - ECT0: 5000, - ECT1: 0, - ECNCE: 10, - } - sf := &StreamFrame{ - StreamID: 1337, - Offset: 1e7, - Data: make([]byte, 200), - DataLenPresent: true, +func writeFrames(tb testing.TB, frames ...Frame) []byte { + var b []byte + for _, f := range frames { + var err error + b, err = f.Append(b, protocol.Version1) + require.NoError(tb, err) } - rand.Read(sf.Data) + return b +} - data, err := ack.Append([]byte{}, protocol.Version1) - if err != nil { - b.Fatal(err) - } - data, err = sf.Append(data, protocol.Version1) - if err != nil { - b.Fatal(err) +// This function is used in benchmarks, and also to ensure zero allocation for STREAM frame parsing. +// We can therefore not use the require framework, as it allocates. +func parseFrames(tb testing.TB, parser *FrameParser, data []byte, frames ...Frame) { + for _, expectedFrame := range frames { + l, frame, err := parser.ParseNext(data, protocol.Encryption1RTT, protocol.Version1) + if err != nil { + tb.Fatal(err) + } + data = data[l:] + if frame == nil { + break + } + + // Use type switch approach (like master branch) + switch f := frame.(type) { + case *StreamFrame: + sf := expectedFrame.(*StreamFrame) + if sf.StreamID != f.StreamID || sf.Offset != f.Offset || !bytes.Equal(sf.Data, f.Data) { + tb.Fatalf("STREAM frame does not match: %v vs %v", sf, f) + } + f.PutBack() + case *AckFrame: + af, ok := expectedFrame.(*AckFrame) + if !ok { + tb.Fatalf("expected ACK, but got %v", expectedFrame) + } + if f.DelayTime != af.DelayTime || f.ECNCE != af.ECNCE || f.ECT0 != af.ECT0 || f.ECT1 != af.ECT1 { + tb.Fatalf("ACK frame does not match: %v vs %v", af, f) + } + if !slices.Equal(f.AckRanges, af.AckRanges) { + tb.Fatalf("ACK frame ACK ranges don't match: %v vs %v", af, f) + } + case *DatagramFrame: + df, ok := expectedFrame.(*DatagramFrame) + if !ok { + tb.Fatalf("expected DATAGRAM, but got %v", expectedFrame) + } + if df.DataLenPresent != f.DataLenPresent || !bytes.Equal(df.Data, f.Data) { + tb.Fatalf("DATAGRAM frame does not match: %v vs %v", df, f) + } + case *MaxDataFrame: + mdf, ok := expectedFrame.(*MaxDataFrame) + if !ok { + tb.Fatalf("expected MAX_DATA, but got %v", expectedFrame) + } + if *f != *mdf { + tb.Fatalf("MAX_DATA frame does not match: %v vs %v", f, mdf) + } + case *MaxStreamsFrame: + msf, ok := expectedFrame.(*MaxStreamsFrame) + if !ok { + tb.Fatalf("expected MAX_STREAMS, but got %v", expectedFrame) + } + if *f != *msf { + tb.Fatalf("MAX_STREAMS frame does not match: %v vs %v", f, msf) + } + case *MaxStreamDataFrame: + mdf, ok := expectedFrame.(*MaxStreamDataFrame) + if !ok { + tb.Fatalf("expected MAX_STREAM_DATA, but got %v", expectedFrame) + } + if *f != *mdf { + tb.Fatalf("MAX_STREAM_DATA frame does not match: %v vs %v", f, mdf) + } + case *CryptoFrame: + cf, ok := expectedFrame.(*CryptoFrame) + if !ok { + tb.Fatalf("expected CRYPTO, but got %v", expectedFrame) + } + if f.Offset != cf.Offset || !bytes.Equal(f.Data, cf.Data) { + tb.Fatalf("CRYPTO frame does not match: %v vs %v", f, cf) + } + case *PingFrame: + _ = f + case *ResetStreamFrame: + rsf, ok := expectedFrame.(*ResetStreamFrame) + if !ok { + tb.Fatalf("expected RESET_STREAM, but got %v", expectedFrame) + } + if *f != *rsf { + tb.Fatalf("RESET_STREAM frame does not match: %v vs %v", f, rsf) + } + default: + tb.Fatalf("Frame type not supported in benchmark: %T", f) + } } +} - parser := NewFrameParser(false, false) +func benchmarkFrames(b *testing.B, frames ...Frame) { + buf := writeFrames(b, frames...) + + parser := NewFrameParser(true, true) parser.SetAckDelayExponent(3) b.ResetTimer() b.ReportAllocs() - for i := 0; i < b.N; i++ { - l, f, err := parser.ParseNext(data, protocol.Encryption1RTT, protocol.Version1) - if err != nil { - b.Fatal(err) - } - ackParsed := f.(*AckFrame) - if ackParsed.DelayTime != ack.DelayTime || ackParsed.ECNCE != ack.ECNCE { - b.Fatalf("incorrect ACK frame: %v vs %v", ack, ackParsed) - } - l2, f, err := parser.ParseNext(data[l:], protocol.Encryption1RTT, protocol.Version1) - if err != nil { - b.Fatal(err) - } - if len(data[l:]) != l2 { - b.Fatal("didn't parse the entire packet") - } - sfParsed := f.(*StreamFrame) - if sfParsed.StreamID != sf.StreamID || !bytes.Equal(sfParsed.Data, sf.Data) { - b.Fatalf("incorrect STREAM frame: %v vs %v", sf, sfParsed) - } + + for range b.N { + parseFrames(b, parser, buf, frames...) } } func BenchmarkParseOtherFrames(b *testing.B) { - maxDataFrame := &MaxDataFrame{MaximumData: 123456} - maxStreamsFrame := &MaxStreamsFrame{MaxStreamNum: 10} - maxStreamDataFrame := &MaxStreamDataFrame{StreamID: 1337, MaximumStreamData: 1e6} - cryptoFrame := &CryptoFrame{Offset: 1000, Data: make([]byte, 128)} - resetStreamFrame := &ResetStreamFrame{StreamID: 87654, ErrorCode: 1234, FinalSize: 1e8} - rand.Read(cryptoFrame.Data) frames := []Frame{ - maxDataFrame, - maxStreamsFrame, - maxStreamDataFrame, - cryptoFrame, + &MaxDataFrame{MaximumData: 123456}, + &MaxStreamsFrame{MaxStreamNum: 10}, + &MaxStreamDataFrame{StreamID: 1337, MaximumStreamData: 1e6}, + &CryptoFrame{Offset: 1000, Data: make([]byte, 128)}, &PingFrame{}, - resetStreamFrame, + &ResetStreamFrame{StreamID: 87654, ErrorCode: 1234, FinalSize: 1e8}, } - var buf []byte - for i, frame := range frames { - var err error - buf, err = frame.Append(buf, protocol.Version1) - if err != nil { - b.Fatal(err) - } - if i == len(frames)/2 { - // add 3 PADDING frames - buf = append(buf, 0) - buf = append(buf, 0) - buf = append(buf, 0) - } + benchmarkFrames(b, frames...) +} + +func BenchmarkParseAckFrame(b *testing.B) { + var frames []Frame + for i := range 10 { + frames = append(frames, &AckFrame{ + AckRanges: []AckRange{ + {Smallest: protocol.PacketNumber(5000 + i), Largest: protocol.PacketNumber(5200 + i)}, + {Smallest: protocol.PacketNumber(1 + i), Largest: protocol.PacketNumber(4200 + i)}, + }, + DelayTime: time.Duration(int64(time.Millisecond) * int64(i)), + ECT0: uint64(5000 + i), + ECT1: uint64(i), + ECNCE: uint64(10 + i), + }) } + benchmarkFrames(b, frames...) +} - parser := NewFrameParser(false, false) +func BenchmarkParseStreamFrame(b *testing.B) { + var frames []Frame + for i := range 10 { + data := make([]byte, 200+i) + rand.Read(data) + frames = append(frames, &StreamFrame{ + StreamID: protocol.StreamID(1337 + i), + Offset: protocol.ByteCount(1e7 + i), + Data: data, + DataLenPresent: true, + }) + } + benchmarkFrames(b, frames...) +} - b.ResetTimer() - b.ReportAllocs() - for i := 0; i < b.N; i++ { - data := buf - for j := 0; j < len(frames); j++ { - l, f, err := parser.ParseNext(data, protocol.Encryption1RTT, protocol.Version1) - if err != nil { - b.Fatal(err) - } - data = data[l:] - switch j { - case 0: - if f.(*MaxDataFrame).MaximumData != maxDataFrame.MaximumData { - b.Fatalf("MAX_DATA frame does not match: %v vs %v", f, maxDataFrame) - } - case 1: - if f.(*MaxStreamsFrame).MaxStreamNum != maxStreamsFrame.MaxStreamNum { - b.Fatalf("MAX_STREAMS frame does not match: %v vs %v", f, maxStreamsFrame) - } - case 2: - if f.(*MaxStreamDataFrame).StreamID != maxStreamDataFrame.StreamID || - f.(*MaxStreamDataFrame).MaximumStreamData != maxStreamDataFrame.MaximumStreamData { - b.Fatalf("MAX_STREAM_DATA frame does not match: %v vs %v", f, maxStreamDataFrame) - } - case 3: - if f.(*CryptoFrame).Offset != cryptoFrame.Offset || !bytes.Equal(f.(*CryptoFrame).Data, cryptoFrame.Data) { - b.Fatalf("CRYPTO frame does not match: %v vs %v", f, cryptoFrame) - } - case 4: - _ = f.(*PingFrame) - case 5: - rst := f.(*ResetStreamFrame) - if rst.StreamID != resetStreamFrame.StreamID || rst.ErrorCode != resetStreamFrame.ErrorCode || - rst.FinalSize != resetStreamFrame.FinalSize { - b.Fatalf("RESET_STREAM frame does not match: %v vs %v", rst, resetStreamFrame) - } - } - } +func BenchmarkParseDatagramFrame(b *testing.B) { + var frames []Frame + for i := range 10 { + data := make([]byte, 200+i) + rand.Read(data) + frames = append(frames, &DatagramFrame{ + Data: data, + DataLenPresent: true, + }) } + benchmarkFrames(b, frames...) } From c2e784aaf21fe66f55b166249d8c9dc9b0aa0fc7 Mon Sep 17 00:00:00 2001 From: Jannis Seemann <5215310+jannis-seemann@users.noreply.github.com> Date: Sun, 20 Jul 2025 14:14:38 +0300 Subject: [PATCH 13/14] wire: optimize parsing logic for STREAM, DATAGRAM and ACK frames (#5227) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ParseOtherFrames-16 148ns ± 4% 150ns ± 3% ~ (p=0.223 n=8+8) ParseAckFrame-16 302ns ± 2% 298ns ± 3% ~ (p=0.246 n=8+8) ParseStreamFrame-16 262ns ± 3% 213ns ± 2% -18.61% (p=0.000 n=8+8) ParseDatagramFrame-16 561ns ± 5% 547ns ± 4% ~ (p=0.105 n=8+8) --- connection.go | 107 +++- connection_test.go | 40 +- fuzzing/frames/fuzz.go | 24 +- internal/ackhandler/ack_eliciting.go | 13 + internal/ackhandler/ack_eliciting_test.go | 42 ++ internal/wire/ack_frame.go | 8 +- internal/wire/ack_frame_test.go | 40 +- internal/wire/connection_close_frame.go | 8 +- internal/wire/connection_close_frame_test.go | 18 +- internal/wire/crypto_frame.go | 2 +- internal/wire/crypto_frame_test.go | 2 +- internal/wire/data_blocked_frame.go | 2 +- internal/wire/data_blocked_frame_test.go | 2 +- internal/wire/datagram_frame.go | 4 +- internal/wire/frame.go | 12 + internal/wire/frame_parser.go | 237 ++++--- internal/wire/frame_parser_test.go | 577 ++++++++++++++---- internal/wire/frame_test.go | 13 + internal/wire/frame_type.go | 77 +++ internal/wire/frame_type_test.go | 29 + internal/wire/handshake_done_frame.go | 2 +- internal/wire/handshake_done_frame_test.go | 2 +- internal/wire/max_data_frame.go | 2 +- internal/wire/max_data_frame_test.go | 2 +- internal/wire/max_stream_data_frame.go | 2 +- internal/wire/max_stream_data_frame_test.go | 2 +- internal/wire/max_streams_frame.go | 11 +- internal/wire/max_streams_frame_test.go | 12 +- internal/wire/new_connection_id_frame.go | 2 +- internal/wire/new_connection_id_frame_test.go | 2 +- internal/wire/new_token_frame.go | 2 +- internal/wire/new_token_frame_test.go | 2 +- internal/wire/path_challenge_frame.go | 2 +- internal/wire/path_challenge_frame_test.go | 2 +- internal/wire/path_response_frame.go | 2 +- internal/wire/path_response_frame_test.go | 2 +- internal/wire/ping_frame.go | 2 +- internal/wire/reset_stream_frame.go | 4 +- internal/wire/reset_stream_frame_test.go | 4 +- internal/wire/retire_connection_id_frame.go | 2 +- .../wire/retire_connection_id_frame_test.go | 2 +- internal/wire/stop_sending_frame.go | 2 +- internal/wire/stop_sending_frame_test.go | 2 +- .../wire/stream_data_blocked_frame_test.go | 2 +- internal/wire/stream_frame.go | 2 +- internal/wire/stream_frame_test.go | 22 +- internal/wire/streams_blocked_frame.go | 11 +- internal/wire/streams_blocked_frame_test.go | 16 +- packet_packer_test.go | 18 +- 49 files changed, 1022 insertions(+), 375 deletions(-) create mode 100644 internal/wire/frame_type.go create mode 100644 internal/wire/frame_type_test.go diff --git a/connection.go b/connection.go index ce565aa014f..2251af08d5c 100644 --- a/connection.go +++ b/connection.go @@ -1437,39 +1437,100 @@ func (c *Conn) handleFrames( } handshakeWasComplete := c.handshakeComplete var handleErr error + var skipHandling bool + for len(data) > 0 { - l, frame, err := c.frameParser.ParseNext(data, encLevel, c.version) + frameType, l, err := c.frameParser.ParseType(data, encLevel) if err != nil { + // The frame parser skips over PADDING frames, and returns an io.EOF if the PADDING + // frames were the last frames in this packet. + if err == io.EOF { + break + } return false, false, nil, err } data = data[l:] - if frame == nil { - break - } - if ackhandler.IsFrameAckEliciting(frame) { + + if ackhandler.IsFrameTypeAckEliciting(frameType) { isAckEliciting = true } - if !wire.IsProbingFrame(frame) { + if !wire.IsProbingFrameType(frameType) { isNonProbing = true } - if log != nil { - frames = append(frames, toLoggingFrame(frame)) - } - // An error occurred handling a previous frame. - // Don't handle the current frame. - if handleErr != nil { - continue - } - pc, err := c.handleFrame(frame, encLevel, destConnID, rcvTime) - if err != nil { - if log == nil { + + // We're inlining common cases, to avoid using interfaces + // Fast path: STREAM, DATAGRAM and ACK + if frameType.IsStreamFrameType() { + streamFrame, l, err := c.frameParser.ParseStreamFrame(frameType, data, c.version) + if err != nil { + return false, false, nil, err + } + data = data[l:] + + if log != nil { + frames = append(frames, toLoggingFrame(streamFrame)) + } + // an error occurred handling a previous frame, don't handle the current frame + if skipHandling { + continue + } + handleErr = c.streamsMap.HandleStreamFrame(streamFrame, rcvTime) + } else if frameType.IsAckFrameType() { + ackFrame, l, err := c.frameParser.ParseAckFrame(frameType, data, encLevel, c.version) + if err != nil { + return false, false, nil, err + } + data = data[l:] + if log != nil { + frames = append(frames, toLoggingFrame(ackFrame)) + } + // an error occurred handling a previous frame, don't handle the current frame + if skipHandling { + continue + } + handleErr = c.handleAckFrame(ackFrame, encLevel, rcvTime) + } else if frameType.IsDatagramFrameType() { + datagramFrame, l, err := c.frameParser.ParseDatagramFrame(frameType, data, c.version) + if err != nil { + return false, false, nil, err + } + data = data[l:] + + if log != nil { + frames = append(frames, toLoggingFrame(datagramFrame)) + } + // an error occurred handling a previous frame, don't handle the current frame + if skipHandling { + continue + } + handleErr = c.handleDatagramFrame(datagramFrame) + } else { + frame, l, err := c.frameParser.ParseLessCommonFrame(frameType, data, c.version) + if err != nil { return false, false, nil, err } - // If we're logging, we need to keep parsing (but not handling) all frames. + data = data[l:] + + if log != nil { + frames = append(frames, toLoggingFrame(frame)) + } + // an error occurred handling a previous frame, don't handle the current frame + if skipHandling { + continue + } + pc, err := c.handleFrame(frame, encLevel, destConnID, rcvTime) + if pc != nil { + pathChallenge = pc + } handleErr = err } - if pc != nil { - pathChallenge = pc + + if handleErr != nil { + // if we're logging, we need to keep parsing (but not handling) all frames + skipHandling = true + if log == nil { + return false, false, nil, handleErr + } } } @@ -1503,10 +1564,6 @@ func (c *Conn) handleFrame( switch frame := f.(type) { case *wire.CryptoFrame: err = c.handleCryptoFrame(frame, encLevel, rcvTime) - case *wire.StreamFrame: - err = c.streamsMap.HandleStreamFrame(frame, rcvTime) - case *wire.AckFrame: - err = c.handleAckFrame(frame, encLevel, rcvTime) case *wire.ConnectionCloseFrame: err = c.handleConnectionCloseFrame(frame) case *wire.ResetStreamFrame: @@ -1537,8 +1594,6 @@ func (c *Conn) handleFrame( err = c.connIDGenerator.Retire(frame.SequenceNumber, destConnID, rcvTime.Add(3*c.rttStats.PTO(false))) case *wire.HandshakeDoneFrame: err = c.handleHandshakeDoneFrame(rcvTime) - case *wire.DatagramFrame: - err = c.handleDatagramFrame(frame) default: err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name()) } diff --git a/connection_test.go b/connection_test.go index cf976eb69ca..30d7dd8c15c 100644 --- a/connection_test.go +++ b/connection_test.go @@ -215,17 +215,19 @@ func TestConnectionHandleStreamRelatedFrames(t *testing.T) { name string frame wire.Frame }{ - {name: "STREAM", frame: &wire.StreamFrame{StreamID: id, Data: []byte("foobar")}}, {name: "RESET_STREAM", frame: &wire.ResetStreamFrame{StreamID: id, ErrorCode: 42, FinalSize: 1337}}, {name: "STOP_SENDING", frame: &wire.StopSendingFrame{StreamID: id, ErrorCode: 42}}, {name: "MAX_STREAM_DATA", frame: &wire.MaxStreamDataFrame{StreamID: id, MaximumStreamData: 1337}}, {name: "STREAM_DATA_BLOCKED", frame: &wire.StreamDataBlockedFrame{StreamID: id, MaximumStreamData: 42}}, + {name: "STREAM_FRAME", frame: &wire.StreamFrame{StreamID: id, Data: []byte{1, 2, 3, 4, 5, 6, 7, 8}, Offset: 1337}}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { tc := newServerTestConnection(t, gomock.NewController(t), nil, false) - _, err := tc.conn.handleFrame(test.frame, protocol.Encryption1RTT, connID, time.Now()) + data, err := test.frame.Append(nil, protocol.Version1) + require.NoError(t, err) + _, _, _, err = tc.conn.handleFrames(data, connID, protocol.Encryption1RTT, nil, time.Now()) require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError}) }) } @@ -2996,3 +2998,37 @@ func testConnectionMigration(t *testing.T, enabled bool) { t.Fatal("timeout") } } + +func TestConnectionDatagrams(t *testing.T) { + t.Run("disabled", func(t *testing.T) { + testConnectionDatagrams(t, false) + }) + t.Run("enabled", func(t *testing.T) { + testConnectionDatagrams(t, true) + }) +} + +func testConnectionDatagrams(t *testing.T, enabled bool) { + tc := newServerTestConnection(t, nil, &Config{EnableDatagrams: enabled}, false) + + data, err := (&wire.DatagramFrame{Data: []byte("foo"), DataLenPresent: true}).Append(nil, protocol.Version1) + require.NoError(t, err) + data, err = (&wire.DatagramFrame{Data: []byte("bar")}).Append(data, protocol.Version1) + require.NoError(t, err) + _, _, _, err = tc.conn.handleFrames(data, protocol.ConnectionID{}, protocol.Encryption1RTT, nil, time.Now()) + + if !enabled { + require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.FrameEncodingError, FrameType: uint64(wire.FrameTypeDatagramWithLength)}) + return + } + + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + d, err := tc.conn.ReceiveDatagram(ctx) + require.NoError(t, err) + require.Equal(t, []byte("foo"), d) + d, err = tc.conn.ReceiveDatagram(ctx) + require.NoError(t, err) + require.Equal(t, []byte("bar"), d) +} diff --git a/fuzzing/frames/fuzz.go b/fuzzing/frames/fuzz.go index e7b0247aa44..c8200265680 100644 --- a/fuzzing/frames/fuzz.go +++ b/fuzzing/frames/fuzz.go @@ -2,6 +2,7 @@ package frames import ( "fmt" + "io" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/protocol" @@ -41,15 +42,32 @@ func Fuzz(data []byte) int { var b []byte for len(data) > 0 { initialLen := len(data) - l, f, err := parser.ParseNext(data, encLevel, version) + frameType, l, err := parser.ParseType(data, encLevel) if err != nil { + if err == io.EOF { // the last frame was a PADDING frame + continue + } break } + data = data[l:] numFrames++ - if f == nil { // PADDING frame - continue + + var f wire.Frame + switch { + case frameType.IsStreamFrameType(): + f, l, err = parser.ParseStreamFrame(frameType, data, version) + case frameType == wire.FrameTypeAck || frameType == wire.FrameTypeAckECN: + f, l, err = parser.ParseAckFrame(frameType, data, encLevel, version) + case frameType == wire.FrameTypeDatagramNoLength || frameType == wire.FrameTypeDatagramWithLength: + f, l, err = parser.ParseDatagramFrame(frameType, data, version) + default: + f, l, err = parser.ParseLessCommonFrame(frameType, data, version) } + if err != nil { + break + } + data = data[l:] wire.IsProbingFrame(f) ackhandler.IsFrameAckEliciting(f) // We accept empty STREAM frames, but we don't write them. diff --git a/internal/ackhandler/ack_eliciting.go b/internal/ackhandler/ack_eliciting.go index 34506b12e01..8d8436123e5 100644 --- a/internal/ackhandler/ack_eliciting.go +++ b/internal/ackhandler/ack_eliciting.go @@ -2,6 +2,19 @@ package ackhandler import "github.com/quic-go/quic-go/internal/wire" +// IsFrameTypeAckEliciting returns true if the frame is ack-eliciting. +func IsFrameTypeAckEliciting(t wire.FrameType) bool { + //nolint:exhaustive // The default case catches the rest. + switch t { + case wire.FrameTypeAck, wire.FrameTypeAckECN: + return false + case wire.FrameTypeConnectionClose, wire.FrameTypeApplicationClose: + return false + default: + return true + } +} + // IsFrameAckEliciting returns true if the frame is ack-eliciting. func IsFrameAckEliciting(f wire.Frame) bool { _, isAck := f.(*wire.AckFrame) diff --git a/internal/ackhandler/ack_eliciting_test.go b/internal/ackhandler/ack_eliciting_test.go index 1c363e9304e..65cc627e65c 100644 --- a/internal/ackhandler/ack_eliciting_test.go +++ b/internal/ackhandler/ack_eliciting_test.go @@ -7,6 +7,48 @@ import ( "github.com/stretchr/testify/require" ) +func TestIsFrameTypeAckEliciting(t *testing.T) { + testCases := map[wire.FrameType]bool{ + wire.FrameTypePing: true, + wire.FrameTypeAck: false, + wire.FrameTypeAckECN: false, + wire.FrameTypeResetStream: true, + wire.FrameTypeStopSending: true, + wire.FrameTypeCrypto: true, + wire.FrameTypeNewToken: true, + wire.FrameType(0x08): true, + wire.FrameType(0x09): true, + wire.FrameType(0x0a): true, + wire.FrameType(0x0b): true, + wire.FrameType(0x0c): true, + wire.FrameType(0x0d): true, + wire.FrameType(0x0e): true, + wire.FrameType(0x0f): true, + wire.FrameTypeMaxData: true, + wire.FrameTypeMaxStreamData: true, + wire.FrameTypeBidiMaxStreams: true, + wire.FrameTypeUniMaxStreams: true, + wire.FrameTypeDataBlocked: true, + wire.FrameTypeStreamDataBlocked: true, + wire.FrameTypeBidiStreamBlocked: true, + wire.FrameTypeUniStreamBlocked: true, + wire.FrameTypeNewConnectionID: true, + wire.FrameTypeRetireConnectionID: true, + wire.FrameTypePathChallenge: true, + wire.FrameTypePathResponse: true, + wire.FrameTypeConnectionClose: false, + wire.FrameTypeApplicationClose: false, + wire.FrameTypeHandshakeDone: true, + wire.FrameTypeResetStreamAt: true, + wire.FrameTypeDatagramNoLength: true, + wire.FrameTypeDatagramWithLength: true, + } + + for ft, expected := range testCases { + require.Equal(t, expected, IsFrameTypeAckEliciting(ft), "unexpected result for frame type 0x%x", ft) + } +} + func TestAckElicitingFrames(t *testing.T) { testCases := map[wire.Frame]bool{ &wire.AckFrame{}: false, diff --git a/internal/wire/ack_frame.go b/internal/wire/ack_frame.go index 8befef4f2de..68bebfa7917 100644 --- a/internal/wire/ack_frame.go +++ b/internal/wire/ack_frame.go @@ -21,9 +21,9 @@ type AckFrame struct { } // parseAckFrame reads an ACK frame -func parseAckFrame(frame *AckFrame, b []byte, typ uint64, ackDelayExponent uint8, _ protocol.Version) (int, error) { +func parseAckFrame(frame *AckFrame, b []byte, typ FrameType, ackDelayExponent uint8, _ protocol.Version) (int, error) { startLen := len(b) - ecn := typ == ackECNFrameType + ecn := typ == FrameTypeAckECN la, l, err := quicvarint.Parse(b) if err != nil { @@ -122,9 +122,9 @@ func parseAckFrame(frame *AckFrame, b []byte, typ uint64, ackDelayExponent uint8 func (f *AckFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 if hasECN { - b = append(b, ackECNFrameType) + b = append(b, byte(FrameTypeAckECN)) } else { - b = append(b, ackFrameType) + b = append(b, byte(FrameTypeAck)) } b = quicvarint.Append(b, uint64(f.LargestAcked())) b = quicvarint.Append(b, encodeAckDelay(f.DelayTime)) diff --git a/internal/wire/ack_frame_test.go b/internal/wire/ack_frame_test.go index f6390bfcd68..e7f21a0b8ff 100644 --- a/internal/wire/ack_frame_test.go +++ b/internal/wire/ack_frame_test.go @@ -17,7 +17,7 @@ func TestParseACKWithoutRanges(t *testing.T) { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(10)...) // first ack block var frame AckFrame - n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), n) require.Equal(t, protocol.PacketNumber(100), frame.LargestAcked()) @@ -31,7 +31,7 @@ func TestParseACKSinglePacket(t *testing.T) { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(0)...) // first ack block var frame AckFrame - n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), n) require.Equal(t, protocol.PacketNumber(55), frame.LargestAcked()) @@ -45,7 +45,7 @@ func TestParseACKAllPacketsFrom0ToLargest(t *testing.T) { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(20)...) // first ack block var frame AckFrame - n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), n) require.Equal(t, protocol.PacketNumber(20), frame.LargestAcked()) @@ -59,7 +59,7 @@ func TestParseACKRejectFirstBlockLargerThanLargestAcked(t *testing.T) { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(21)...) // first ack block var frame AckFrame - _, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + _, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.EqualError(t, err, "invalid first ACK range") } @@ -71,7 +71,7 @@ func TestParseACKWithSingleBlock(t *testing.T) { data = append(data, encodeVarInt(98)...) // gap data = append(data, encodeVarInt(50)...) // ack block var frame AckFrame - n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), n) require.Equal(t, protocol.PacketNumber(1000), frame.LargestAcked()) @@ -93,7 +93,7 @@ func TestParseACKWithMultipleBlocks(t *testing.T) { data = append(data, encodeVarInt(1)...) // gap data = append(data, encodeVarInt(1)...) // ack block var frame AckFrame - n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1) + n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), n) require.Equal(t, protocol.PacketNumber(100), frame.LargestAcked()) @@ -118,7 +118,7 @@ func TestParseACKUseAckDelayExponent(t *testing.T) { typ, l, err := quicvarint.Parse(b) require.NoError(t, err) var frame AckFrame - n, err := parseAckFrame(&frame, b[l:], typ, protocol.AckDelayExponent+i, protocol.Version1) + n, err := parseAckFrame(&frame, b[l:], FrameType(typ), protocol.AckDelayExponent+i, protocol.Version1) require.NoError(t, err) require.Equal(t, len(b[l:]), n) require.Equal(t, delayTime*(1< 0 + f.DataLenPresent = uint64(typ)&0x1 > 0 var length uint64 if f.DataLenPresent { diff --git a/internal/wire/frame.go b/internal/wire/frame.go index 10d4eebc31c..09ea92f7541 100644 --- a/internal/wire/frame.go +++ b/internal/wire/frame.go @@ -19,3 +19,15 @@ func IsProbingFrame(f Frame) bool { } return false } + +// IsProbingFrameType returns true if the FrameType is a probing frame. +// See section 9.1 of RFC 9000. +func IsProbingFrameType(f FrameType) bool { + //nolint:exhaustive // PATH_CHALLENGE, PATH_RESPONSE and NEW_CONNECTION_ID are the only probing frames + switch f { + case FrameTypePathChallenge, FrameTypePathResponse, FrameTypeNewConnectionID: + return true + default: + return false + } +} diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index 794c70ccf10..e92e29c8bd8 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -4,39 +4,12 @@ import ( "errors" "fmt" "io" - "reflect" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/quicvarint" ) -const ( - pingFrameType = 0x1 - ackFrameType = 0x2 - ackECNFrameType = 0x3 - resetStreamFrameType = 0x4 - stopSendingFrameType = 0x5 - cryptoFrameType = 0x6 - newTokenFrameType = 0x7 - maxDataFrameType = 0x10 - maxStreamDataFrameType = 0x11 - bidiMaxStreamsFrameType = 0x12 - uniMaxStreamsFrameType = 0x13 - dataBlockedFrameType = 0x14 - streamDataBlockedFrameType = 0x15 - bidiStreamBlockedFrameType = 0x16 - uniStreamBlockedFrameType = 0x17 - newConnectionIDFrameType = 0x18 - retireConnectionIDFrameType = 0x19 - pathChallengeFrameType = 0x1a - pathResponseFrameType = 0x1b - connectionCloseFrameType = 0x1c - applicationCloseFrameType = 0x1d - handshakeDoneFrameType = 0x1e - resetStreamAtFrameType = 0x24 // https://datatracker.ietf.org/doc/draft-ietf-quic-reliable-stream-reset/06/ -) - var errUnknownFrameType = errors.New("unknown frame type") // The FrameParser parses QUIC frames, one by one. @@ -59,20 +32,15 @@ func NewFrameParser(supportsDatagrams, supportsResetStreamAt bool) *FrameParser } } -// ParseNext parses the next frame. -// It skips PADDING frames. -func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) { - frame, l, err := p.parseNext(data, encLevel, v) - return l, frame, err -} - -func (p *FrameParser) parseNext(b []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) { +// ParseType parses the frame type of the next frame. +// It skips over PADDING frames. +func (p *FrameParser) ParseType(b []byte, encLevel protocol.EncryptionLevel) (FrameType, int, error) { var parsed int for len(b) != 0 { typ, l, err := quicvarint.Parse(b) parsed += l if err != nil { - return nil, parsed, &qerr.TransportError{ + return 0, parsed, &qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, ErrorMessage: err.Error(), } @@ -81,115 +49,126 @@ func (p *FrameParser) parseNext(b []byte, encLevel protocol.EncryptionLevel, v p if typ == 0x0 { // skip PADDING frames continue } - - f, l, err := p.parseFrame(b, typ, encLevel, v) - parsed += l - if err != nil { - return nil, parsed, &qerr.TransportError{ + ft := FrameType(typ) + valid := ft.isValidRFC9000() || + (p.supportsDatagrams && ft.IsDatagramFrameType()) || + (p.supportsResetStreamAt && ft == FrameTypeResetStreamAt) + if !valid { + return 0, parsed, &qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, FrameType: typ, + ErrorMessage: errUnknownFrameType.Error(), + } + } + if !ft.isAllowedAtEncLevel(encLevel) { + return 0, parsed, &qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, - ErrorMessage: err.Error(), + FrameType: typ, + ErrorMessage: fmt.Sprintf("%d not allowed at encryption level %s", ft, encLevel), } } - return f, parsed, nil + return ft, parsed, nil } - return nil, parsed, nil + return 0, parsed, io.EOF } -func (p *FrameParser) parseFrame(b []byte, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) { - var frame Frame - var err error - var l int - if typ&0xf8 == 0x8 { - frame, l, err = parseStreamFrame(b, typ, v) - } else { - switch typ { - case pingFrameType: - frame = &PingFrame{} - case ackFrameType, ackECNFrameType: - ackDelayExponent := p.ackDelayExponent - if encLevel != protocol.Encryption1RTT { - ackDelayExponent = protocol.DefaultAckDelayExponent - } - p.ackFrame.Reset() - l, err = parseAckFrame(p.ackFrame, b, typ, ackDelayExponent, v) - frame = p.ackFrame - case resetStreamFrameType: - frame, l, err = parseResetStreamFrame(b, false, v) - case stopSendingFrameType: - frame, l, err = parseStopSendingFrame(b, v) - case cryptoFrameType: - frame, l, err = parseCryptoFrame(b, v) - case newTokenFrameType: - frame, l, err = parseNewTokenFrame(b, v) - case maxDataFrameType: - frame, l, err = parseMaxDataFrame(b, v) - case maxStreamDataFrameType: - frame, l, err = parseMaxStreamDataFrame(b, v) - case bidiMaxStreamsFrameType, uniMaxStreamsFrameType: - frame, l, err = parseMaxStreamsFrame(b, typ, v) - case dataBlockedFrameType: - frame, l, err = parseDataBlockedFrame(b, v) - case streamDataBlockedFrameType: - frame, l, err = parseStreamDataBlockedFrame(b, v) - case bidiStreamBlockedFrameType, uniStreamBlockedFrameType: - frame, l, err = parseStreamsBlockedFrame(b, typ, v) - case newConnectionIDFrameType: - frame, l, err = parseNewConnectionIDFrame(b, v) - case retireConnectionIDFrameType: - frame, l, err = parseRetireConnectionIDFrame(b, v) - case pathChallengeFrameType: - frame, l, err = parsePathChallengeFrame(b, v) - case pathResponseFrameType: - frame, l, err = parsePathResponseFrame(b, v) - case connectionCloseFrameType, applicationCloseFrameType: - frame, l, err = parseConnectionCloseFrame(b, typ, v) - case handshakeDoneFrameType: - frame = &HandshakeDoneFrame{} - case 0x30, 0x31: - if !p.supportsDatagrams { - return nil, 0, errUnknownFrameType - } - frame, l, err = parseDatagramFrame(b, typ, v) - case resetStreamAtFrameType: - if !p.supportsResetStreamAt { - return nil, 0, errUnknownFrameType - } - frame, l, err = parseResetStreamFrame(b, true, v) - default: - err = errUnknownFrameType +func (p *FrameParser) ParseStreamFrame(frameType FrameType, data []byte, v protocol.Version) (*StreamFrame, int, error) { + frame, n, err := ParseStreamFrame(data, frameType, v) + if err != nil { + return nil, n, &qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, + FrameType: uint64(frameType), + ErrorMessage: err.Error(), } } - if err != nil { - return nil, 0, err + return frame, n, nil +} + +func (p *FrameParser) ParseAckFrame(frameType FrameType, data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (*AckFrame, int, error) { + ackDelayExponent := p.ackDelayExponent + if encLevel != protocol.Encryption1RTT { + ackDelayExponent = protocol.DefaultAckDelayExponent } - if !p.isAllowedAtEncLevel(frame, encLevel) { - return nil, l, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel) + p.ackFrame.Reset() + l, err := parseAckFrame(p.ackFrame, data, frameType, ackDelayExponent, v) + if err != nil { + return nil, l, &qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, + FrameType: uint64(frameType), + ErrorMessage: err.Error(), + } } - return frame, l, nil + + return p.ackFrame, l, nil } -func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { - switch encLevel { - case protocol.EncryptionInitial, protocol.EncryptionHandshake: - switch f.(type) { - case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *PingFrame: - return true - default: - return false - } - case protocol.Encryption0RTT: - switch f.(type) { - case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: - return false - default: - return true +func (p *FrameParser) ParseDatagramFrame(frameType FrameType, data []byte, v protocol.Version) (*DatagramFrame, int, error) { + f, l, err := parseDatagramFrame(data, frameType, v) + if err != nil { + return nil, 0, &qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, + FrameType: uint64(frameType), + ErrorMessage: err.Error(), } - case protocol.Encryption1RTT: - return true + } + return f, l, nil +} + +// ParseLessCommonFrame parses everything except STREAM, ACK or DATAGRAM. +// These cases should be handled separately for performance reasons. +func (p *FrameParser) ParseLessCommonFrame(frameType FrameType, data []byte, v protocol.Version) (Frame, int, error) { + var frame Frame + var l int + var err error + //nolint:exhaustive // Common frames should already be handled. + switch frameType { + case FrameTypePing: + frame = &PingFrame{} + case FrameTypeResetStream: + frame, l, err = parseResetStreamFrame(data, false, v) + case FrameTypeStopSending: + frame, l, err = parseStopSendingFrame(data, v) + case FrameTypeCrypto: + frame, l, err = parseCryptoFrame(data, v) + case FrameTypeNewToken: + frame, l, err = parseNewTokenFrame(data, v) + case FrameTypeMaxData: + frame, l, err = parseMaxDataFrame(data, v) + case FrameTypeMaxStreamData: + frame, l, err = parseMaxStreamDataFrame(data, v) + case FrameTypeBidiMaxStreams, FrameTypeUniMaxStreams: + frame, l, err = parseMaxStreamsFrame(data, frameType, v) + case FrameTypeDataBlocked: + frame, l, err = parseDataBlockedFrame(data, v) + case FrameTypeStreamDataBlocked: + frame, l, err = parseStreamDataBlockedFrame(data, v) + case FrameTypeBidiStreamBlocked, FrameTypeUniStreamBlocked: + frame, l, err = parseStreamsBlockedFrame(data, frameType, v) + case FrameTypeNewConnectionID: + frame, l, err = parseNewConnectionIDFrame(data, v) + case FrameTypeRetireConnectionID: + frame, l, err = parseRetireConnectionIDFrame(data, v) + case FrameTypePathChallenge: + frame, l, err = parsePathChallengeFrame(data, v) + case FrameTypePathResponse: + frame, l, err = parsePathResponseFrame(data, v) + case FrameTypeConnectionClose, FrameTypeApplicationClose: + frame, l, err = parseConnectionCloseFrame(data, frameType, v) + case FrameTypeHandshakeDone: + frame = &HandshakeDoneFrame{} + case FrameTypeResetStreamAt: + frame, l, err = parseResetStreamFrame(data, true, v) default: - panic("unknown encryption level") + err = errUnknownFrameType + } + if err != nil { + return frame, l, &qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, + FrameType: uint64(frameType), + ErrorMessage: err.Error(), + } } + return frame, l, err } // SetAckDelayExponent sets the acknowledgment delay exponent (sent in the transport parameters). diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index fefb5c11ee0..821bc5d6526 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -3,22 +3,31 @@ package wire import ( "bytes" "crypto/rand" + "fmt" + "io" "slices" "testing" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" - "github.com/stretchr/testify/require" ) -func TestFrameParsingReturnsNilWhenNothingToRead(t *testing.T) { +func TestFrameTypeParsingReturnsNilWhenNothingToRead(t *testing.T) { parser := NewFrameParser(true, true) - l, f, err := parser.ParseNext(nil, protocol.Encryption1RTT, protocol.Version1) - require.NoError(t, err) + frameType, l, err := parser.ParseType(nil, protocol.Encryption1RTT) + require.Equal(t, io.EOF, err) + require.Zero(t, frameType) + require.Zero(t, l) +} + +func TestParseLessCommonFrameReturnsEOFWhenNothingToRead(t *testing.T) { + parser := NewFrameParser(true, true) + l, f, err := parser.ParseLessCommonFrame(FrameTypeMaxStreamData, nil, protocol.Version1) + require.IsType(t, &qerr.TransportError{}, err) require.Zero(t, l) - require.Nil(t, f) + require.Zero(t, f) } func TestFrameParsingSkipsPaddingFrames(t *testing.T) { @@ -26,17 +35,24 @@ func TestFrameParsingSkipsPaddingFrames(t *testing.T) { b := []byte{0, 0} // 2 PADDING frames b, err := (&PingFrame{}).Append(b, protocol.Version1) require.NoError(t, err) - l, f, err := parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1) + + frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.NoError(t, err) - require.Equal(t, &PingFrame{}, f) - require.Equal(t, 2+1, l) + require.Equal(t, 3, l) + require.Equal(t, FrameTypePing, frameType) + + frame, l, err := parser.ParseLessCommonFrame(frameType, b[1:], protocol.Version1) + require.NoError(t, err) + require.Zero(t, l) + require.IsType(t, &PingFrame{}, frame) } func TestFrameParsingHandlesPaddingAtEnd(t *testing.T) { parser := NewFrameParser(true, true) - l, f, err := parser.ParseNext([]byte{0, 0, 0}, protocol.Encryption1RTT, protocol.Version1) - require.NoError(t, err) - require.Nil(t, f) + b := []byte{0, 0, 0} + + _, l, err := parser.ParseType(b, protocol.Encryption1RTT) + require.Equal(t, io.EOF, err) require.Equal(t, 3, l) } @@ -48,10 +64,15 @@ func TestFrameParsingParsesSingleFrame(t *testing.T) { b, err = (&PingFrame{}).Append(b, protocol.Version1) require.NoError(t, err) } - l, f, err := parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1) + frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.NoError(t, err) - require.IsType(t, &PingFrame{}, f) + require.Equal(t, FrameTypePing, frameType) require.Equal(t, 1, l) + + frame, l, err := parser.ParseLessCommonFrame(frameType, b, protocol.Version1) + require.NoError(t, err) + require.Zero(t, l) + require.IsType(t, &PingFrame{}, frame) } func TestFrameParserACK(t *testing.T) { @@ -59,12 +80,16 @@ func TestFrameParserACK(t *testing.T) { f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - l, frame, err := parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1) + frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) + require.NoError(t, err) + require.Equal(t, FrameTypeAck, frameType) + require.Equal(t, 1, l) + + frame, l, err := parser.ParseAckFrame(frameType, b[l:], protocol.Encryption1RTT, protocol.Version1) require.NoError(t, err) require.NotNil(t, frame) - require.IsType(t, f, frame) - require.Equal(t, protocol.PacketNumber(0x13), frame.(*AckFrame).LargestAcked()) - require.Equal(t, len(b), l) + require.Equal(t, protocol.PacketNumber(0x13), frame.LargestAcked()) + require.Equal(t, len(b)-1, l) } func TestFrameParserAckDelay(t *testing.T) { @@ -85,15 +110,31 @@ func testFrameParserAckDelay(t *testing.T, encLevel protocol.EncryptionLevel) { } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - _, frame, err := parser.ParseNext(b, encLevel, protocol.Version1) + frameType, l, err := parser.ParseType(b, encLevel) + require.NoError(t, err) + require.Equal(t, FrameTypeAck, frameType) + require.Equal(t, 1, l) + + frame, l, err := parser.ParseAckFrame(frameType, b[l:], encLevel, protocol.Version1) require.NoError(t, err) + require.Equal(t, len(b)-1, l) if encLevel == protocol.Encryption1RTT { - require.Equal(t, 4*time.Second, frame.(*AckFrame).DelayTime) + require.Equal(t, 4*time.Second, frame.DelayTime) } else { - require.Equal(t, time.Second, frame.(*AckFrame).DelayTime) + require.Equal(t, time.Second, frame.DelayTime) } } +func checkFrameUnsupported(t *testing.T, err error, expectedFrameType uint64) { + t.Helper() + require.ErrorContains(t, err, errUnknownFrameType.Error()) + var transportErr *qerr.TransportError + require.ErrorAs(t, err, &transportErr) + require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode) + require.Equal(t, expectedFrameType, transportErr.FrameType) + require.Equal(t, "unknown frame type", transportErr.ErrorMessage) +} + func TestFrameParserStreamFrames(t *testing.T) { parser := NewFrameParser(true, true) f := &StreamFrame{ @@ -104,28 +145,95 @@ func TestFrameParserStreamFrames(t *testing.T) { } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - l, frame, err := parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1) + frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) require.NoError(t, err) - require.NotNil(t, frame) - require.Equal(t, f, frame) - require.Equal(t, len(b), l) + require.Equal(t, FrameType(0xd), frameType) + require.True(t, frameType.IsStreamFrameType()) + require.Equal(t, 1, l) + + // ParseLessCommonFrame should not handle Stream Frames + frame, l, err := parser.ParseLessCommonFrame(frameType, b[l:], protocol.Version1) + checkFrameUnsupported(t, err, 0xd) + require.Nil(t, frame) + require.Zero(t, l) +} + +func TestParseStreamFrameWrapsError(t *testing.T) { + parser := NewFrameParser(true, true) + f := &StreamFrame{ + StreamID: 0x1234, + Offset: 0x1000, + Data: []byte("hello world"), + DataLenPresent: true, + } + b, err := f.Append(nil, protocol.Version1) + require.NoError(t, err) + + // Corrupt the buffer to trigger a parse error + b = b[:len(b)-2] // Remove last 2 bytes to cause an EOF + + frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) + require.NoError(t, err) + + frame, n, err := parser.ParseStreamFrame(frameType, b[l:], protocol.Version1) + require.Nil(t, frame) + require.Zero(t, n) + + var transportErr *qerr.TransportError + require.ErrorAs(t, err, &transportErr) + require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode) + require.Equal(t, uint64(frameType), transportErr.FrameType) + require.Contains(t, transportErr.Error(), "EOF") +} + +func TestParseStreamFrameSuccess(t *testing.T) { + parser := NewFrameParser(true, true) + original := &StreamFrame{ + StreamID: 0x1234, + Offset: 0x1000, + Fin: true, + Data: []byte("hello world"), + DataLenPresent: true, + } + b, err := original.Append(nil, protocol.Version1) + require.NoError(t, err) + + frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) + require.NoError(t, err) + require.True(t, frameType.IsStreamFrameType()) + require.Equal(t, FrameType(0x0f), frameType) // STREAM | OFF | LEN | FIN + + parsed, n, err := parser.ParseStreamFrame(frameType, b[l:], protocol.Version1) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, len(b)-l, n) + + require.Equal(t, original.StreamID, parsed.StreamID) + require.Equal(t, original.Offset, parsed.Offset) + require.Equal(t, original.Fin, parsed.Fin) + require.Equal(t, original.DataLenPresent, parsed.DataLenPresent) + require.Equal(t, original.Data, parsed.Data) } func TestFrameParserFrames(t *testing.T) { tests := []struct { - name string - frame Frame + name string + frameType FrameType + frame Frame }{ { - name: "MAX_DATA", - frame: &MaxDataFrame{MaximumData: 0xcafe}, + name: "MAX_DATA", + frameType: FrameTypeMaxData, + frame: &MaxDataFrame{MaximumData: 0xcafe}, }, { - name: "MAX_STREAM_DATA", - frame: &MaxStreamDataFrame{StreamID: 0xdeadbeef, MaximumStreamData: 0xdecafbad}, + name: "MAX_STREAM_DATA", + frameType: FrameTypeMaxStreamData, + frame: &MaxStreamDataFrame{StreamID: 0xdeadbeef, MaximumStreamData: 0xdecafbad}, }, { - name: "RESET_STREAM", + name: "RESET_STREAM", + frameType: FrameTypeResetStream, frame: &ResetStreamFrame{ StreamID: 0xdeadbeef, FinalSize: 0xdecafbad1234, @@ -133,35 +241,43 @@ func TestFrameParserFrames(t *testing.T) { }, }, { - name: "STOP_SENDING", - frame: &StopSendingFrame{StreamID: 0x42}, + name: "STOP_SENDING", + frameType: FrameTypeStopSending, + frame: &StopSendingFrame{StreamID: 0x42}, }, { - name: "CRYPTO", - frame: &CryptoFrame{Offset: 0x1337, Data: []byte("lorem ipsum")}, + name: "CRYPTO", + frameType: FrameTypeCrypto, + frame: &CryptoFrame{Offset: 0x1337, Data: []byte("lorem ipsum")}, }, { - name: "NEW_TOKEN", - frame: &NewTokenFrame{Token: []byte("foobar")}, + name: "NEW_TOKEN", + frameType: FrameTypeNewToken, + frame: &NewTokenFrame{Token: []byte("foobar")}, }, { - name: "MAX_STREAMS", - frame: &MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 0x1337}, + name: "MAX_STREAMS", + frameType: FrameTypeBidiMaxStreams, + frame: &MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 0x1337}, }, { - name: "DATA_BLOCKED", - frame: &DataBlockedFrame{MaximumData: 0x1234}, + name: "DATA_BLOCKED", + frameType: FrameTypeDataBlocked, + frame: &DataBlockedFrame{MaximumData: 0x1234}, }, { - name: "STREAM_DATA_BLOCKED", - frame: &StreamDataBlockedFrame{StreamID: 0xdeadbeef, MaximumStreamData: 0xdead}, + name: "STREAM_DATA_BLOCKED", + frameType: FrameTypeStreamDataBlocked, + frame: &StreamDataBlockedFrame{StreamID: 0xdeadbeef, MaximumStreamData: 0xdead}, }, { - name: "STREAMS_BLOCKED", - frame: &StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 0x1234567}, + name: "STREAMS_BLOCKED", + frameType: FrameTypeBidiStreamBlocked, + frame: &StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 0x1234567}, }, { - name: "NEW_CONNECTION_ID", + name: "NEW_CONNECTION_ID", + frameType: FrameTypeNewConnectionID, frame: &NewConnectionIDFrame{ SequenceNumber: 0x1337, ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), @@ -169,32 +285,39 @@ func TestFrameParserFrames(t *testing.T) { }, }, { - name: "RETIRE_CONNECTION_ID", - frame: &RetireConnectionIDFrame{SequenceNumber: 0x1337}, + name: "RETIRE_CONNECTION_ID", + frameType: FrameTypeRetireConnectionID, + frame: &RetireConnectionIDFrame{SequenceNumber: 0x1337}, }, { - name: "PATH_CHALLENGE", - frame: &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, + name: "PATH_CHALLENGE", + frameType: FrameTypePathChallenge, + frame: &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, }, { - name: "PATH_RESPONSE", - frame: &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, + name: "PATH_RESPONSE", + frameType: FrameTypePathResponse, + frame: &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, }, { - name: "CONNECTION_CLOSE", - frame: &ConnectionCloseFrame{IsApplicationError: true, ReasonPhrase: "foobar"}, + name: "CONNECTION_CLOSE", + frameType: FrameTypeConnectionClose, + frame: &ConnectionCloseFrame{IsApplicationError: false, ReasonPhrase: "foobar"}, }, { - name: "HANDSHAKE_DONE", - frame: &HandshakeDoneFrame{}, + name: "APPLICATION_CLOSE", + frameType: FrameTypeApplicationClose, + frame: &ConnectionCloseFrame{IsApplicationError: true, ReasonPhrase: "foobar"}, }, { - name: "DATAGRAM", - frame: &DatagramFrame{Data: []byte("foobar")}, + name: "HANDSHAKE_DONE", + frameType: FrameTypeHandshakeDone, + frame: &HandshakeDoneFrame{}, }, { - name: "RESET_STREAM_AT", - frame: &ResetStreamFrame{StreamID: 0x1337, ReliableSize: 0x42, FinalSize: 0xdeadbeef}, + name: "RESET_STREAM_AT", + frameType: FrameTypeResetStreamAt, + frame: &ResetStreamFrame{StreamID: 0x1337, ReliableSize: 0x42, FinalSize: 0xdeadbeef}, }, } @@ -203,22 +326,173 @@ func TestFrameParserFrames(t *testing.T) { parser := NewFrameParser(true, true) b, err := test.frame.Append(nil, protocol.Version1) require.NoError(t, err) - l, frame, err := parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1) + + frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) + require.NoError(t, err) + require.Equal(t, test.frameType, frameType) + require.Equal(t, 1, l) + + frame, l, err := parser.ParseLessCommonFrame(frameType, b[l:], protocol.Version1) require.NoError(t, err) require.Equal(t, test.frame, frame) - require.Equal(t, len(b), l) + require.Equal(t, len(b)-1, l) }) } } -func checkFrameUnsupported(t *testing.T, err error, expectedFrameType uint64) { - t.Helper() - require.ErrorContains(t, err, errUnknownFrameType.Error()) - var transportErr *qerr.TransportError - require.ErrorAs(t, err, &transportErr) - require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode) - require.Equal(t, expectedFrameType, transportErr.FrameType) - require.Equal(t, "unknown frame type", transportErr.ErrorMessage) +func TestFrameAllowedAtEncLevel(t *testing.T) { + type testCase struct { + name string + frameType FrameType + frame Frame + allowedInitial bool + allowedHandshake bool + allowedZeroRTT bool + allowedOneRTT bool + } + + for _, tc := range []testCase{ + { + name: "CRYPTO_FRAME", + frameType: FrameTypeCrypto, + frame: &CryptoFrame{Offset: 0, Data: []byte("foo")}, + allowedInitial: true, + allowedHandshake: true, + allowedZeroRTT: false, + allowedOneRTT: true, + }, + { + name: "ACK_FRAME", + frameType: FrameTypeAck, + frame: &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 1}}}, + allowedInitial: true, + allowedHandshake: true, + allowedZeroRTT: false, + allowedOneRTT: true, + }, + { + name: "CONNECTION_CLOSE_FRAME", + frameType: FrameTypeConnectionClose, + frame: &ConnectionCloseFrame{IsApplicationError: false, ReasonPhrase: "err"}, + allowedInitial: true, + allowedHandshake: true, + allowedZeroRTT: false, + allowedOneRTT: true, + }, + { + name: "PING_FRAME", + frameType: FrameTypePing, + frame: &PingFrame{}, + allowedInitial: true, + allowedHandshake: true, + allowedZeroRTT: true, + allowedOneRTT: true, + }, + { + name: "NEW_TOKEN_FRAME", + frameType: FrameTypeNewToken, + frame: &NewTokenFrame{Token: []byte("tok")}, + allowedInitial: false, + allowedHandshake: false, + allowedZeroRTT: false, + allowedOneRTT: true, + }, + { + name: "PATH_RESPONSE_FRAME", + frameType: FrameTypePathResponse, + frame: &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, + allowedInitial: false, + allowedHandshake: false, + allowedZeroRTT: false, + allowedOneRTT: true, + }, + { + name: "RETIRE_CONNECTION_ID_FRAME", + frameType: FrameTypeRetireConnectionID, + frame: &RetireConnectionIDFrame{SequenceNumber: 1}, + allowedInitial: false, + allowedHandshake: false, + allowedZeroRTT: false, + allowedOneRTT: true, + }, + { + name: "MAX_DATA_FRAME", + frameType: FrameTypeMaxData, + frame: &MaxDataFrame{MaximumData: 1}, + allowedInitial: false, + allowedHandshake: false, + allowedZeroRTT: true, + allowedOneRTT: true, + }, + { + name: "STREAM_FRAME", + frameType: FrameType(0x8), + frame: &StreamFrame{StreamID: 1, Data: []byte("foobar")}, + allowedInitial: false, + allowedHandshake: false, + allowedZeroRTT: true, + allowedOneRTT: true, + }, + } { + for _, encLevel := range []protocol.EncryptionLevel{ + protocol.EncryptionInitial, + protocol.EncryptionHandshake, + protocol.Encryption0RTT, + protocol.Encryption1RTT, + } { + t.Run(fmt.Sprintf("%s/%v", tc.name, encLevel), func(t *testing.T) { + var allowed bool + switch encLevel { + case protocol.EncryptionInitial: + allowed = tc.allowedInitial + case protocol.EncryptionHandshake: + allowed = tc.allowedHandshake + case protocol.Encryption0RTT: + allowed = tc.allowedZeroRTT + case protocol.Encryption1RTT: + allowed = tc.allowedOneRTT + } + + parser := NewFrameParser(true, true) + b, err := tc.frame.Append(nil, protocol.Version1) + require.NoError(t, err) + frameType, _, err := parser.ParseType(b, encLevel) + if allowed { + require.NoError(t, err) + require.Equal(t, tc.frameType, frameType) + } else { + require.Error(t, err) + var transportErr *qerr.TransportError + require.ErrorAs(t, err, &transportErr) + require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode) + } + }) + } + } +} + +func TestFrameParserDatagramFrame(t *testing.T) { + parser := NewFrameParser(true, true) + f := &DatagramFrame{ + Data: []byte("foobar"), + } + b, err := f.Append(nil, protocol.Version1) + require.NoError(t, err) + frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT) + require.NoError(t, err) + require.Equal(t, FrameTypeDatagramNoLength, frameType) + require.Equal(t, 1, l) + + // ParseLessCommonFrame should not be used to handle DATAGRAM frames + _, _, err = parser.ParseLessCommonFrame(frameType, b[l:], protocol.Version1) + require.Error(t, err) + + // parseDatagramFrame should be used for this type + datagramFrame, l, err := parser.ParseDatagramFrame(frameType, b[l:], protocol.Version1) + require.NoError(t, err) + require.IsType(t, &DatagramFrame{}, datagramFrame) + require.Equal(t, 6, l) + require.Equal(t, f.Data, datagramFrame.Data) } func TestFrameParserDatagramUnsupported(t *testing.T) { @@ -226,7 +500,8 @@ func TestFrameParserDatagramUnsupported(t *testing.T) { f := &DatagramFrame{Data: []byte("foobar")} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - _, _, err = parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1) + + _, _, err = parser.ParseType(b, protocol.Encryption1RTT) checkFrameUnsupported(t, err, 0x30) } @@ -235,14 +510,22 @@ func TestFrameParserResetStreamAtUnsupported(t *testing.T) { f := &ResetStreamFrame{StreamID: 0x1337, ReliableSize: 0x42, FinalSize: 0xdeadbeef} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - _, _, err = parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1) + + _, _, err = parser.ParseType(b, protocol.Encryption1RTT) checkFrameUnsupported(t, err, 0x24) } func TestFrameParserInvalidFrameType(t *testing.T) { parser := NewFrameParser(true, true) - _, _, err := parser.ParseNext(encodeVarInt(0x42), protocol.Encryption1RTT, protocol.Version1) - checkFrameUnsupported(t, err, 0x42) + + _, l, err := parser.ParseType(encodeVarInt(0x42), protocol.Encryption1RTT) + + require.Equal(t, 2, l) + + require.Error(t, err) + var transportErr *qerr.TransportError + require.ErrorAs(t, err, &transportErr) + require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode) } func TestFrameParsingErrorsOnInvalidFrames(t *testing.T) { @@ -253,7 +536,13 @@ func TestFrameParsingErrorsOnInvalidFrames(t *testing.T) { } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - _, _, err = parser.ParseNext(b[:len(b)-2], protocol.Encryption1RTT, protocol.Version1) + + frameType, l, err := parser.ParseType(b[:len(b)-2], protocol.Encryption1RTT) + require.NoError(t, err) + require.Equal(t, FrameTypeMaxStreamData, frameType) + require.Equal(t, 1, l) + + _, _, err = parser.ParseLessCommonFrame(frameType, b[1:len(b)-2], protocol.Version1) require.Error(t, err) var transportErr *qerr.TransportError require.ErrorAs(t, err, &transportErr) @@ -274,102 +563,160 @@ func writeFrames(tb testing.TB, frames ...Frame) []byte { // We can therefore not use the require framework, as it allocates. func parseFrames(tb testing.TB, parser *FrameParser, data []byte, frames ...Frame) { for _, expectedFrame := range frames { - l, frame, err := parser.ParseNext(data, protocol.Encryption1RTT, protocol.Version1) + frameType, l, err := parser.ParseType(data, protocol.Encryption1RTT) if err != nil { tb.Fatal(err) } data = data[l:] - if frame == nil { - break - } - // Use type switch approach (like master branch) - switch f := frame.(type) { - case *StreamFrame: + if frameType.IsStreamFrameType() { sf := expectedFrame.(*StreamFrame) - if sf.StreamID != f.StreamID || sf.Offset != f.Offset || !bytes.Equal(sf.Data, f.Data) { - tb.Fatalf("STREAM frame does not match: %v vs %v", sf, f) + frame, l, err := ParseStreamFrame(data, frameType, protocol.Version1) + if err != nil { + tb.Fatal(err) + } + if sf.StreamID != frame.StreamID || sf.Offset != frame.Offset { + tb.Fatalf("STREAM frame does not match: %v vs %v", sf, frame) } - f.PutBack() - case *AckFrame: + frame.PutBack() + data = data[l:] + continue + } + + if frameType.IsAckFrameType() { af, ok := expectedFrame.(*AckFrame) if !ok { tb.Fatalf("expected ACK, but got %v", expectedFrame) } + + f, l, err := parser.ParseAckFrame(frameType, data, protocol.Encryption1RTT, protocol.Version1) if f.DelayTime != af.DelayTime || f.ECNCE != af.ECNCE || f.ECT0 != af.ECT0 || f.ECT1 != af.ECT1 { + tb.Fatal(err) + } + if f.DelayTime != af.DelayTime { tb.Fatalf("ACK frame does not match: %v vs %v", af, f) } if !slices.Equal(f.AckRanges, af.AckRanges) { tb.Fatalf("ACK frame ACK ranges don't match: %v vs %v", af, f) } - case *DatagramFrame: + data = data[l:] + continue + } + + if frameType.IsDatagramFrameType() { df, ok := expectedFrame.(*DatagramFrame) if !ok { tb.Fatalf("expected DATAGRAM, but got %v", expectedFrame) } + + f, l, err := parser.ParseDatagramFrame(frameType, data, protocol.Version1) + if err != nil { + tb.Fatal(err) + } if df.DataLenPresent != f.DataLenPresent || !bytes.Equal(df.Data, f.Data) { tb.Fatalf("DATAGRAM frame does not match: %v vs %v", df, f) } - case *MaxDataFrame: + data = data[l:] + continue + } + + f, l, err := parser.ParseLessCommonFrame(frameType, data, protocol.Version1) + if err != nil { + tb.Fatal(err) + } + data = data[l:] + + switch frameType { + case FrameTypeMaxData: mdf, ok := expectedFrame.(*MaxDataFrame) if !ok { tb.Fatalf("expected MAX_DATA, but got %v", expectedFrame) } - if *f != *mdf { + if *f.(*MaxDataFrame) != *mdf { tb.Fatalf("MAX_DATA frame does not match: %v vs %v", f, mdf) } - case *MaxStreamsFrame: + case FrameTypeUniMaxStreams: msf, ok := expectedFrame.(*MaxStreamsFrame) if !ok { tb.Fatalf("expected MAX_STREAMS, but got %v", expectedFrame) } - if *f != *msf { + if *f.(*MaxStreamsFrame) != *msf { tb.Fatalf("MAX_STREAMS frame does not match: %v vs %v", f, msf) } - case *MaxStreamDataFrame: + case FrameTypeMaxStreamData: mdf, ok := expectedFrame.(*MaxStreamDataFrame) if !ok { tb.Fatalf("expected MAX_STREAM_DATA, but got %v", expectedFrame) } - if *f != *mdf { + if *f.(*MaxStreamDataFrame) != *mdf { tb.Fatalf("MAX_STREAM_DATA frame does not match: %v vs %v", f, mdf) } - case *CryptoFrame: + case FrameTypeCrypto: cf, ok := expectedFrame.(*CryptoFrame) if !ok { tb.Fatalf("expected CRYPTO, but got %v", expectedFrame) } - if f.Offset != cf.Offset || !bytes.Equal(f.Data, cf.Data) { + frame := f.(*CryptoFrame) + if frame.Offset != cf.Offset || !bytes.Equal(frame.Data, cf.Data) { tb.Fatalf("CRYPTO frame does not match: %v vs %v", f, cf) } - case *PingFrame: - _ = f - case *ResetStreamFrame: + case FrameTypePing: + _ = f.(*PingFrame) + case FrameTypeResetStream: rsf, ok := expectedFrame.(*ResetStreamFrame) if !ok { tb.Fatalf("expected RESET_STREAM, but got %v", expectedFrame) } - if *f != *rsf { + if *f.(*ResetStreamFrame) != *rsf { tb.Fatalf("RESET_STREAM frame does not match: %v vs %v", f, rsf) } + continue default: - tb.Fatalf("Frame type not supported in benchmark: %T", f) + tb.Fatalf("Frame type not supported in benchmark or should not occur: %v", frameType) } } } -func benchmarkFrames(b *testing.B, frames ...Frame) { - buf := writeFrames(b, frames...) +func TestFrameParserAllocs(t *testing.T) { + t.Run("STREAM", func(t *testing.T) { + var frames []Frame + for i := range 10 { + frames = append(frames, &StreamFrame{ + StreamID: protocol.StreamID(1337 + i), + Offset: protocol.ByteCount(1e7 + i), + Data: make([]byte, 200+i), + DataLenPresent: true, + }) + } + require.Zero(t, testFrameParserAllocs(t, frames)) + }) + t.Run("ACK", func(t *testing.T) { + var frames []Frame + for i := range 10 { + frames = append(frames, &AckFrame{ + AckRanges: []AckRange{ + {Smallest: protocol.PacketNumber(5000 + i), Largest: protocol.PacketNumber(5200 + i)}, + {Smallest: protocol.PacketNumber(1 + i), Largest: protocol.PacketNumber(4200 + i)}, + }, + DelayTime: time.Duration(int64(time.Millisecond) * int64(i)), + ECT0: uint64(5000 + i), + ECT1: uint64(i), + ECNCE: uint64(10 + i), + }) + } + require.Zero(t, testFrameParserAllocs(t, frames)) + }) +} + +func testFrameParserAllocs(t *testing.T, frames []Frame) float64 { + buf := writeFrames(t, frames...) parser := NewFrameParser(true, true) parser.SetAckDelayExponent(3) - b.ResetTimer() - b.ReportAllocs() - - for range b.N { - parseFrames(b, parser, buf, frames...) - } + return testing.AllocsPerRun(100, func() { + parseFrames(t, parser, buf, frames...) + }) } func BenchmarkParseOtherFrames(b *testing.B) { @@ -428,3 +775,17 @@ func BenchmarkParseDatagramFrame(b *testing.B) { } benchmarkFrames(b, frames...) } + +func benchmarkFrames(b *testing.B, frames ...Frame) { + buf := writeFrames(b, frames...) + + parser := NewFrameParser(true, true) + parser.SetAckDelayExponent(3) + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + parseFrames(b, parser, buf, frames...) + } +} diff --git a/internal/wire/frame_test.go b/internal/wire/frame_test.go index 3012b558d7b..5dbfba8963b 100644 --- a/internal/wire/frame_test.go +++ b/internal/wire/frame_test.go @@ -27,3 +27,16 @@ func TestProbingFrames(t *testing.T) { require.Equal(t, expected, IsProbingFrame(f)) } } + +func TestIsProbingFrameType(t *testing.T) { + tests := map[FrameType]bool{ + FrameTypePathChallenge: true, + FrameTypePathResponse: true, + FrameTypeNewConnectionID: true, + FrameType(0x01): false, + FrameType(0xFF): false, + } + for ft, expected := range tests { + require.Equal(t, expected, IsProbingFrameType(ft)) + } +} diff --git a/internal/wire/frame_type.go b/internal/wire/frame_type.go new file mode 100644 index 00000000000..0576657f730 --- /dev/null +++ b/internal/wire/frame_type.go @@ -0,0 +1,77 @@ +package wire + +import "github.com/quic-go/quic-go/internal/protocol" + +type FrameType uint64 + +// These constants correspond to those defined in RFC 9000. +// Stream frame types are not listed explicitly here; use FrameType.IsStreamFrameType() to identify them. +const ( + FrameTypePing FrameType = 0x1 + FrameTypeAck FrameType = 0x2 + FrameTypeAckECN FrameType = 0x3 + FrameTypeResetStream FrameType = 0x4 + FrameTypeStopSending FrameType = 0x5 + FrameTypeCrypto FrameType = 0x6 + FrameTypeNewToken FrameType = 0x7 + + FrameTypeMaxData FrameType = 0x10 + FrameTypeMaxStreamData FrameType = 0x11 + FrameTypeBidiMaxStreams FrameType = 0x12 + FrameTypeUniMaxStreams FrameType = 0x13 + FrameTypeDataBlocked FrameType = 0x14 + FrameTypeStreamDataBlocked FrameType = 0x15 + FrameTypeBidiStreamBlocked FrameType = 0x16 + FrameTypeUniStreamBlocked FrameType = 0x17 + FrameTypeNewConnectionID FrameType = 0x18 + FrameTypeRetireConnectionID FrameType = 0x19 + FrameTypePathChallenge FrameType = 0x1a + FrameTypePathResponse FrameType = 0x1b + FrameTypeConnectionClose FrameType = 0x1c + FrameTypeApplicationClose FrameType = 0x1d + FrameTypeHandshakeDone FrameType = 0x1e + FrameTypeResetStreamAt FrameType = 0x24 // https://datatracker.ietf.org/doc/draft-ietf-quic-reliable-stream-reset/06/ + + FrameTypeDatagramNoLength FrameType = 0x30 + FrameTypeDatagramWithLength FrameType = 0x31 +) + +func (t FrameType) IsStreamFrameType() bool { + return t >= 0x8 && t <= 0xf +} + +func (t FrameType) isValidRFC9000() bool { + return t <= 0x1e +} + +func (t FrameType) IsAckFrameType() bool { + return t == FrameTypeAck || t == FrameTypeAckECN +} + +func (t FrameType) IsDatagramFrameType() bool { + return t == FrameTypeDatagramNoLength || t == FrameTypeDatagramWithLength +} + +func (t FrameType) isAllowedAtEncLevel(encLevel protocol.EncryptionLevel) bool { + //nolint:exhaustive + switch encLevel { + case protocol.EncryptionInitial, protocol.EncryptionHandshake: + switch t { + case FrameTypeCrypto, FrameTypeAck, FrameTypeAckECN, FrameTypeConnectionClose, FrameTypePing: + return true + default: + return false + } + case protocol.Encryption0RTT: + switch t { + case FrameTypeCrypto, FrameTypeAck, FrameTypeAckECN, FrameTypeConnectionClose, FrameTypeNewToken, FrameTypePathResponse, FrameTypeRetireConnectionID: + return false + default: + return true + } + case protocol.Encryption1RTT: + return true + default: + panic("unknown encryption level") + } +} diff --git a/internal/wire/frame_type_test.go b/internal/wire/frame_type_test.go new file mode 100644 index 00000000000..2702af37954 --- /dev/null +++ b/internal/wire/frame_type_test.go @@ -0,0 +1,29 @@ +package wire + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsStreamFrameType(t *testing.T) { + for i := 0x08; i <= 0x0f; i++ { + require.Truef(t, FrameType(i).IsStreamFrameType(), "FrameType(0x%x).IsStreamFrameType() = false, want true", i) + } + + require.False(t, FrameType(0x1).IsStreamFrameType()) +} + +func TestIsAckFrameType(t *testing.T) { + require.True(t, FrameTypeAck.IsAckFrameType(), "AckFrameType should be recognized as ACK") + require.True(t, FrameTypeAckECN.IsAckFrameType(), "AckECNFrameType should be recognized as ACK") + require.False(t, FrameTypePing.IsAckFrameType(), "PingFrameType should not be recognized as ACK") + require.False(t, FrameType(0x10).IsAckFrameType(), "MaxDataFrameType should not be recognized as ACK") +} + +func TestIsDatagramFrameType(t *testing.T) { + require.True(t, FrameTypeDatagramNoLength.IsDatagramFrameType(), "DatagramNoLengthFrameType should be recognized as DATAGRAM") + require.True(t, FrameTypeDatagramWithLength.IsDatagramFrameType(), "DatagramWithLengthFrameType should be recognized as DATAGRAM") + require.False(t, FrameTypePing.IsDatagramFrameType(), "PingFrameType should not be recognized as DATAGRAM") + require.False(t, FrameType(0x1e).IsDatagramFrameType(), "HandshakeDoneFrameType should not be recognized as DATAGRAM") +} diff --git a/internal/wire/handshake_done_frame.go b/internal/wire/handshake_done_frame.go index 85dd6474559..bf95f525b8c 100644 --- a/internal/wire/handshake_done_frame.go +++ b/internal/wire/handshake_done_frame.go @@ -8,7 +8,7 @@ import ( type HandshakeDoneFrame struct{} func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { - return append(b, handshakeDoneFrameType), nil + return append(b, byte(FrameTypeHandshakeDone)), nil } // Length of a written frame diff --git a/internal/wire/handshake_done_frame_test.go b/internal/wire/handshake_done_frame_test.go index 51381df455d..bec44ec9733 100644 --- a/internal/wire/handshake_done_frame_test.go +++ b/internal/wire/handshake_done_frame_test.go @@ -11,6 +11,6 @@ func TestWriteHandshakeDoneSampleFrame(t *testing.T) { frame := HandshakeDoneFrame{} b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) - require.Equal(t, []byte{handshakeDoneFrameType}, b) + require.Equal(t, []byte{byte(FrameTypeHandshakeDone)}, b) require.Equal(t, protocol.ByteCount(1), frame.Length(protocol.Version1)) } diff --git a/internal/wire/max_data_frame.go b/internal/wire/max_data_frame.go index 5819c027393..bfbdcba6666 100644 --- a/internal/wire/max_data_frame.go +++ b/internal/wire/max_data_frame.go @@ -22,7 +22,7 @@ func parseMaxDataFrame(b []byte, _ protocol.Version) (*MaxDataFrame, int, error) } func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { - b = append(b, maxDataFrameType) + b = append(b, byte(FrameTypeMaxData)) b = quicvarint.Append(b, uint64(f.MaximumData)) return b, nil } diff --git a/internal/wire/max_data_frame_test.go b/internal/wire/max_data_frame_test.go index 2c8060894d3..a6cb3d25567 100644 --- a/internal/wire/max_data_frame_test.go +++ b/internal/wire/max_data_frame_test.go @@ -32,7 +32,7 @@ func TestWriteMaxDataFrame(t *testing.T) { f := &MaxDataFrame{MaximumData: 0xdeadbeefcafe} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{maxDataFrameType} + expected := []byte{byte(FrameTypeMaxData)} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) require.Equal(t, expected, b) require.Len(t, b, int(f.Length(protocol.Version1))) diff --git a/internal/wire/max_stream_data_frame.go b/internal/wire/max_stream_data_frame.go index db9091af8e1..0966ea46954 100644 --- a/internal/wire/max_stream_data_frame.go +++ b/internal/wire/max_stream_data_frame.go @@ -31,7 +31,7 @@ func parseMaxStreamDataFrame(b []byte, _ protocol.Version) (*MaxStreamDataFrame, } func (f *MaxStreamDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { - b = append(b, maxStreamDataFrameType) + b = append(b, byte(FrameTypeMaxStreamData)) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.MaximumStreamData)) return b, nil diff --git a/internal/wire/max_stream_data_frame_test.go b/internal/wire/max_stream_data_frame_test.go index 05559ebfa17..f4757da5c01 100644 --- a/internal/wire/max_stream_data_frame_test.go +++ b/internal/wire/max_stream_data_frame_test.go @@ -36,7 +36,7 @@ func TestWriteMaxStreamDataFrame(t *testing.T) { StreamID: 0xdecafbad, MaximumStreamData: 0xdeadbeefcafe42, } - expected := []byte{maxStreamDataFrameType} + expected := []byte{byte(FrameTypeMaxStreamData)} expected = append(expected, encodeVarInt(0xdecafbad)...) expected = append(expected, encodeVarInt(0xdeadbeefcafe42)...) b, err := f.Append(nil, protocol.Version1) diff --git a/internal/wire/max_streams_frame.go b/internal/wire/max_streams_frame.go index a8745bd124d..30612e23bc4 100644 --- a/internal/wire/max_streams_frame.go +++ b/internal/wire/max_streams_frame.go @@ -13,12 +13,13 @@ type MaxStreamsFrame struct { MaxStreamNum protocol.StreamNum } -func parseMaxStreamsFrame(b []byte, typ uint64, _ protocol.Version) (*MaxStreamsFrame, int, error) { +func parseMaxStreamsFrame(b []byte, typ FrameType, _ protocol.Version) (*MaxStreamsFrame, int, error) { f := &MaxStreamsFrame{} + //nolint:exhaustive // Function will only be called with BidiMaxStreamsFrameType or UniMaxStreamsFrameType switch typ { - case bidiMaxStreamsFrameType: + case FrameTypeBidiMaxStreams: f.Type = protocol.StreamTypeBidi - case uniMaxStreamsFrameType: + case FrameTypeUniMaxStreams: f.Type = protocol.StreamTypeUni } streamID, l, err := quicvarint.Parse(b) @@ -35,9 +36,9 @@ func parseMaxStreamsFrame(b []byte, typ uint64, _ protocol.Version) (*MaxStreams func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: - b = append(b, bidiMaxStreamsFrameType) + b = append(b, byte(FrameTypeBidiMaxStreams)) case protocol.StreamTypeUni: - b = append(b, uniMaxStreamsFrameType) + b = append(b, byte(FrameTypeUniMaxStreams)) } b = quicvarint.Append(b, uint64(f.MaxStreamNum)) return b, nil diff --git a/internal/wire/max_streams_frame_test.go b/internal/wire/max_streams_frame_test.go index f5be03e1e22..7afbce2ff8b 100644 --- a/internal/wire/max_streams_frame_test.go +++ b/internal/wire/max_streams_frame_test.go @@ -13,7 +13,7 @@ import ( func TestParseMaxStreamsFrameBidirectional(t *testing.T) { data := encodeVarInt(0xdecaf) - f, l, err := parseMaxStreamsFrame(data, bidiMaxStreamsFrameType, protocol.Version1) + f, l, err := parseMaxStreamsFrame(data, FrameTypeBidiMaxStreams, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamTypeBidi, f.Type) require.EqualValues(t, 0xdecaf, f.MaxStreamNum) @@ -22,7 +22,7 @@ func TestParseMaxStreamsFrameBidirectional(t *testing.T) { func TestParseMaxStreamsFrameUnidirectional(t *testing.T) { data := encodeVarInt(0xdecaf) - f, l, err := parseMaxStreamsFrame(data, uniMaxStreamsFrameType, protocol.Version1) + f, l, err := parseMaxStreamsFrame(data, FrameTypeUniMaxStreams, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamTypeUni, f.Type) require.EqualValues(t, 0xdecaf, f.MaxStreamNum) @@ -59,7 +59,7 @@ func TestParseMaxStreamsMaxValue(t *testing.T) { typ, l, err := quicvarint.Parse(b) require.NoError(t, err) b = b[l:] - frame, _, err := parseMaxStreamsFrame(b, typ, protocol.Version1) + frame, _, err := parseMaxStreamsFrame(b, FrameType(typ), protocol.Version1) require.NoError(t, err) require.Equal(t, f, frame) }) @@ -84,7 +84,7 @@ func TestParseMaxStreamsErrorsOnTooLargeStreamCount(t *testing.T) { typ, l, err := quicvarint.Parse(b) require.NoError(t, err) b = b[l:] - _, _, err = parseMaxStreamsFrame(b, typ, protocol.Version1) + _, _, err = parseMaxStreamsFrame(b, FrameType(typ), protocol.Version1) require.EqualError(t, err, fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1)) }) } @@ -97,7 +97,7 @@ func TestWriteMaxStreamsBidirectional(t *testing.T) { } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{bidiMaxStreamsFrameType} + expected := []byte{byte(FrameTypeBidiMaxStreams)} expected = append(expected, encodeVarInt(0xdeadbeef)...) require.Equal(t, expected, b) require.Len(t, b, int(f.Length(protocol.Version1))) @@ -110,7 +110,7 @@ func TestWriteMaxStreamsUnidirectional(t *testing.T) { } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{uniMaxStreamsFrameType} + expected := []byte{byte(FrameTypeUniMaxStreams)} expected = append(expected, encodeVarInt(0xdecafbad)...) require.Equal(t, expected, b) require.Len(t, b, int(f.Length(protocol.Version1))) diff --git a/internal/wire/new_connection_id_frame.go b/internal/wire/new_connection_id_frame.go index 6f2287f44b7..058319266f5 100644 --- a/internal/wire/new_connection_id_frame.go +++ b/internal/wire/new_connection_id_frame.go @@ -61,7 +61,7 @@ func parseNewConnectionIDFrame(b []byte, _ protocol.Version) (*NewConnectionIDFr } func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { - b = append(b, newConnectionIDFrameType) + b = append(b, byte(FrameTypeNewConnectionID)) b = quicvarint.Append(b, f.SequenceNumber) b = quicvarint.Append(b, f.RetirePriorTo) connIDLen := f.ConnectionID.Len() diff --git a/internal/wire/new_connection_id_frame_test.go b/internal/wire/new_connection_id_frame_test.go index 739cf2cfda9..3292e0f3c91 100644 --- a/internal/wire/new_connection_id_frame_test.go +++ b/internal/wire/new_connection_id_frame_test.go @@ -77,7 +77,7 @@ func TestWriteNewConnectionIDFrame(t *testing.T) { } b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{newConnectionIDFrameType} + expected := []byte{byte(FrameTypeNewConnectionID)} expected = append(expected, encodeVarInt(0x1337)...) expected = append(expected, encodeVarInt(0x42)...) expected = append(expected, 6) diff --git a/internal/wire/new_token_frame.go b/internal/wire/new_token_frame.go index f1d4d00fe66..73d356b1ad1 100644 --- a/internal/wire/new_token_frame.go +++ b/internal/wire/new_token_frame.go @@ -31,7 +31,7 @@ func parseNewTokenFrame(b []byte, _ protocol.Version) (*NewTokenFrame, int, erro } func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { - b = append(b, newTokenFrameType) + b = append(b, byte(FrameTypeNewToken)) b = quicvarint.Append(b, uint64(len(f.Token))) b = append(b, f.Token...) return b, nil diff --git a/internal/wire/new_token_frame_test.go b/internal/wire/new_token_frame_test.go index 77da62e8f8b..cd2ae3dafce 100644 --- a/internal/wire/new_token_frame_test.go +++ b/internal/wire/new_token_frame_test.go @@ -43,7 +43,7 @@ func TestWriteNewTokenFrame(t *testing.T) { f := &NewTokenFrame{Token: []byte(token)} b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{newTokenFrameType} + expected := []byte{byte(FrameTypeNewToken)} expected = append(expected, encodeVarInt(uint64(len(token)))...) expected = append(expected, token...) require.Equal(t, expected, b) diff --git a/internal/wire/path_challenge_frame.go b/internal/wire/path_challenge_frame.go index 2aca989fa6b..7a4a767e518 100644 --- a/internal/wire/path_challenge_frame.go +++ b/internal/wire/path_challenge_frame.go @@ -21,7 +21,7 @@ func parsePathChallengeFrame(b []byte, _ protocol.Version) (*PathChallengeFrame, } func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { - b = append(b, pathChallengeFrameType) + b = append(b, byte(FrameTypePathChallenge)) b = append(b, f.Data[:]...) return b, nil } diff --git a/internal/wire/path_challenge_frame_test.go b/internal/wire/path_challenge_frame_test.go index 3a755e89177..f5e521aa591 100644 --- a/internal/wire/path_challenge_frame_test.go +++ b/internal/wire/path_challenge_frame_test.go @@ -32,6 +32,6 @@ func TestWritePathChallenge(t *testing.T) { frame := PathChallengeFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) - require.Equal(t, []byte{pathChallengeFrameType, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, b) + require.Equal(t, []byte{byte(FrameTypePathChallenge), 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, b) require.Len(t, b, int(frame.Length(protocol.Version1))) } diff --git a/internal/wire/path_response_frame.go b/internal/wire/path_response_frame.go index 76532c8527b..e76d037b151 100644 --- a/internal/wire/path_response_frame.go +++ b/internal/wire/path_response_frame.go @@ -21,7 +21,7 @@ func parsePathResponseFrame(b []byte, _ protocol.Version) (*PathResponseFrame, i } func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { - b = append(b, pathResponseFrameType) + b = append(b, byte(FrameTypePathResponse)) b = append(b, f.Data[:]...) return b, nil } diff --git a/internal/wire/path_response_frame_test.go b/internal/wire/path_response_frame_test.go index d939b0a1c42..884c407df76 100644 --- a/internal/wire/path_response_frame_test.go +++ b/internal/wire/path_response_frame_test.go @@ -32,6 +32,6 @@ func TestWritePathResponse(t *testing.T) { frame := PathResponseFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) - require.Equal(t, []byte{pathResponseFrameType, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, b) + require.Equal(t, []byte{byte(FrameTypePathResponse), 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, b) require.Len(t, b, int(frame.Length(protocol.Version1))) } diff --git a/internal/wire/ping_frame.go b/internal/wire/ping_frame.go index 71f8d16c38f..5d344d447f8 100644 --- a/internal/wire/ping_frame.go +++ b/internal/wire/ping_frame.go @@ -8,7 +8,7 @@ import ( type PingFrame struct{} func (f *PingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { - return append(b, pingFrameType), nil + return append(b, byte(FrameTypePing)), nil } // Length of a written frame diff --git a/internal/wire/reset_stream_frame.go b/internal/wire/reset_stream_frame.go index cb678bf453d..4101b76b26f 100644 --- a/internal/wire/reset_stream_frame.go +++ b/internal/wire/reset_stream_frame.go @@ -56,9 +56,9 @@ func parseResetStreamFrame(b []byte, isResetStreamAt bool, _ protocol.Version) ( func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { if f.ReliableSize == 0 { - b = quicvarint.Append(b, resetStreamFrameType) + b = quicvarint.Append(b, uint64(FrameTypeResetStream)) } else { - b = quicvarint.Append(b, resetStreamAtFrameType) + b = quicvarint.Append(b, uint64(FrameTypeResetStreamAt)) } b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.ErrorCode)) diff --git a/internal/wire/reset_stream_frame_test.go b/internal/wire/reset_stream_frame_test.go index 3f85a985394..d85d616ba33 100644 --- a/internal/wire/reset_stream_frame_test.go +++ b/internal/wire/reset_stream_frame_test.go @@ -77,7 +77,7 @@ func TestWriteResetStream(t *testing.T) { } b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{resetStreamFrameType} + expected := []byte{byte(FrameTypeResetStream)} expected = append(expected, encodeVarInt(0x1337)...) expected = append(expected, encodeVarInt(0xcafe)...) expected = append(expected, encodeVarInt(0x11223344decafbad)...) @@ -94,7 +94,7 @@ func TestWriteResetStreamAt(t *testing.T) { } b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{resetStreamAtFrameType} + expected := []byte{byte(FrameTypeResetStreamAt)} expected = append(expected, encodeVarInt(1337)...) expected = append(expected, encodeVarInt(0xcafe)...) expected = append(expected, encodeVarInt(42)...) diff --git a/internal/wire/retire_connection_id_frame.go b/internal/wire/retire_connection_id_frame.go index 27aeff8428b..1927f9dc07b 100644 --- a/internal/wire/retire_connection_id_frame.go +++ b/internal/wire/retire_connection_id_frame.go @@ -19,7 +19,7 @@ func parseRetireConnectionIDFrame(b []byte, _ protocol.Version) (*RetireConnecti } func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { - b = append(b, retireConnectionIDFrameType) + b = append(b, byte(FrameTypeRetireConnectionID)) b = quicvarint.Append(b, f.SequenceNumber) return b, nil } diff --git a/internal/wire/retire_connection_id_frame_test.go b/internal/wire/retire_connection_id_frame_test.go index 9c76151c961..e1b64ccfdf4 100644 --- a/internal/wire/retire_connection_id_frame_test.go +++ b/internal/wire/retire_connection_id_frame_test.go @@ -32,7 +32,7 @@ func TestWriteRetireConnectionID(t *testing.T) { frame := &RetireConnectionIDFrame{SequenceNumber: 0x1337} b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{retireConnectionIDFrameType} + expected := []byte{byte(FrameTypeRetireConnectionID)} expected = append(expected, encodeVarInt(0x1337)...) require.Equal(t, expected, b) require.Len(t, b, int(frame.Length(protocol.Version1))) diff --git a/internal/wire/stop_sending_frame.go b/internal/wire/stop_sending_frame.go index a2326f8ec42..2b15c7109f6 100644 --- a/internal/wire/stop_sending_frame.go +++ b/internal/wire/stop_sending_frame.go @@ -38,7 +38,7 @@ func (f *StopSendingFrame) Length(_ protocol.Version) protocol.ByteCount { } func (f *StopSendingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { - b = append(b, stopSendingFrameType) + b = append(b, byte(FrameTypeStopSending)) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.ErrorCode)) return b, nil diff --git a/internal/wire/stop_sending_frame_test.go b/internal/wire/stop_sending_frame_test.go index b670c047a7a..90bc8a32e66 100644 --- a/internal/wire/stop_sending_frame_test.go +++ b/internal/wire/stop_sending_frame_test.go @@ -39,7 +39,7 @@ func TestWriteStopSendingFrame(t *testing.T) { } b, err := frame.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{stopSendingFrameType} + expected := []byte{byte(FrameTypeStopSending)} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) expected = append(expected, encodeVarInt(0xdecafbad)...) require.Equal(t, expected, b) diff --git a/internal/wire/stream_data_blocked_frame_test.go b/internal/wire/stream_data_blocked_frame_test.go index b73e45815c7..325cb58fd7f 100644 --- a/internal/wire/stream_data_blocked_frame_test.go +++ b/internal/wire/stream_data_blocked_frame_test.go @@ -38,7 +38,7 @@ func TestWriteStreamDataBlocked(t *testing.T) { } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{streamDataBlockedFrameType} + expected := []byte{byte(FrameTypeStreamDataBlocked)} expected = append(expected, encodeVarInt(uint64(f.StreamID))...) expected = append(expected, encodeVarInt(uint64(f.MaximumStreamData))...) require.Equal(t, expected, b) diff --git a/internal/wire/stream_frame.go b/internal/wire/stream_frame.go index cdc32722fbc..e53962b193c 100644 --- a/internal/wire/stream_frame.go +++ b/internal/wire/stream_frame.go @@ -19,7 +19,7 @@ type StreamFrame struct { fromPool bool } -func parseStreamFrame(b []byte, typ uint64, _ protocol.Version) (*StreamFrame, int, error) { +func ParseStreamFrame(b []byte, typ FrameType, _ protocol.Version) (*StreamFrame, int, error) { startLen := len(b) hasOffset := typ&0b100 > 0 fin := typ&0b1 > 0 diff --git a/internal/wire/stream_frame_test.go b/internal/wire/stream_frame_test.go index b6177571253..3c658100135 100644 --- a/internal/wire/stream_frame_test.go +++ b/internal/wire/stream_frame_test.go @@ -14,7 +14,7 @@ func TestParseStreamFrameWithOffBit(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(0xdecafbad)...) // offset data = append(data, []byte("foobar")...) - frame, l, err := parseStreamFrame(data, 0x8^0x4, protocol.Version1) + frame, l, err := ParseStreamFrame(data, 0x8^0x4, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0x12345), frame.StreamID) require.Equal(t, []byte("foobar"), frame.Data) @@ -27,7 +27,7 @@ func TestParseStreamFrameRespectsLEN(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(4)...) // data length data = append(data, []byte("foobar")...) - frame, l, err := parseStreamFrame(data, 0x8^0x2, protocol.Version1) + frame, l, err := ParseStreamFrame(data, 0x8^0x2, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0x12345), frame.StreamID) require.Equal(t, []byte("foob"), frame.Data) @@ -39,7 +39,7 @@ func TestParseStreamFrameRespectsLEN(t *testing.T) { func TestParseStreamFrameWithFINBit(t *testing.T) { data := encodeVarInt(9) // stream ID data = append(data, []byte("foobar")...) - frame, l, err := parseStreamFrame(data, 0x8^0x1, protocol.Version1) + frame, l, err := ParseStreamFrame(data, 0x8^0x1, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(9), frame.StreamID) require.Equal(t, []byte("foobar"), frame.Data) @@ -51,7 +51,7 @@ func TestParseStreamFrameWithFINBit(t *testing.T) { func TestParseStreamFrameAllowsEmpty(t *testing.T) { data := encodeVarInt(0x1337) // stream ID data = append(data, encodeVarInt(0x12345)...) // offset - f, l, err := parseStreamFrame(data, 0x8^0x4, protocol.Version1) + f, l, err := ParseStreamFrame(data, 0x8^0x4, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0x1337), f.StreamID) require.Equal(t, protocol.ByteCount(0x12345), f.Offset) @@ -64,7 +64,7 @@ func TestParseStreamFrameRejectsOverflow(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(uint64(protocol.MaxByteCount-5))...) // offset data = append(data, []byte("foobar")...) - _, _, err := parseStreamFrame(data, 0x8^0x4, protocol.Version1) + _, _, err := ParseStreamFrame(data, 0x8^0x4, protocol.Version1) require.EqualError(t, err, "stream data overflows maximum offset") } @@ -72,7 +72,7 @@ func TestParseStreamFrameRejectsLongFrames(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(uint64(protocol.MaxPacketBufferSize)+1)...) // data length data = append(data, make([]byte, protocol.MaxPacketBufferSize+1)...) - _, _, err := parseStreamFrame(data, 0x8^0x2, protocol.Version1) + _, _, err := ParseStreamFrame(data, 0x8^0x2, protocol.Version1) require.Equal(t, io.EOF, err) } @@ -80,7 +80,7 @@ func TestParseStreamFrameRejectsFramesExceedingRemainingSize(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, encodeVarInt(7)...) // data length data = append(data, []byte("foobar")...) - _, _, err := parseStreamFrame(data, 0x8^0x2, protocol.Version1) + _, _, err := ParseStreamFrame(data, 0x8^0x2, protocol.Version1) require.Equal(t, io.EOF, err) } @@ -90,10 +90,10 @@ func TestParseStreamFrameErrorsOnEOFs(t *testing.T) { data = append(data, encodeVarInt(0xdecafbad)...) // offset data = append(data, encodeVarInt(6)...) // data length data = append(data, []byte("foobar")...) - _, _, err := parseStreamFrame(data, typ, protocol.Version1) + _, _, err := ParseStreamFrame(data, FrameType(typ), protocol.Version1) require.NoError(t, err) for i := range data { - _, _, err = parseStreamFrame(data[:i], typ, protocol.Version1) + _, _, err = ParseStreamFrame(data[:i], FrameType(typ), protocol.Version1) require.Error(t, err) } } @@ -101,7 +101,7 @@ func TestParseStreamFrameErrorsOnEOFs(t *testing.T) { func TestParseStreamUsesBufferForLongFrames(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)...) - frame, l, err := parseStreamFrame(data, 0x8, protocol.Version1) + frame, l, err := ParseStreamFrame(data, 0x8, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0x12345), frame.StreamID) require.Equal(t, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize), frame.Data) @@ -115,7 +115,7 @@ func TestParseStreamUsesBufferForLongFrames(t *testing.T) { func TestParseStreamDoesNotUseBufferForShortFrames(t *testing.T) { data := encodeVarInt(0x12345) // stream ID data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)...) - frame, l, err := parseStreamFrame(data, 0x8, protocol.Version1) + frame, l, err := ParseStreamFrame(data, 0x8, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamID(0x12345), frame.StreamID) require.Equal(t, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1), frame.Data) diff --git a/internal/wire/streams_blocked_frame.go b/internal/wire/streams_blocked_frame.go index c946fec31bf..d98fde46c01 100644 --- a/internal/wire/streams_blocked_frame.go +++ b/internal/wire/streams_blocked_frame.go @@ -13,12 +13,13 @@ type StreamsBlockedFrame struct { StreamLimit protocol.StreamNum } -func parseStreamsBlockedFrame(b []byte, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, int, error) { +func parseStreamsBlockedFrame(b []byte, typ FrameType, _ protocol.Version) (*StreamsBlockedFrame, int, error) { f := &StreamsBlockedFrame{} + //nolint:exhaustive // This will only be called with a BidiStreamBlockedFrameType or a UniStreamBlockedFrameType. switch typ { - case bidiStreamBlockedFrameType: + case FrameTypeBidiStreamBlocked: f.Type = protocol.StreamTypeBidi - case uniStreamBlockedFrameType: + case FrameTypeUniStreamBlocked: f.Type = protocol.StreamTypeUni } streamLimit, l, err := quicvarint.Parse(b) @@ -35,9 +36,9 @@ func parseStreamsBlockedFrame(b []byte, typ uint64, _ protocol.Version) (*Stream func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: - b = append(b, bidiStreamBlockedFrameType) + b = append(b, byte(FrameTypeBidiStreamBlocked)) case protocol.StreamTypeUni: - b = append(b, uniStreamBlockedFrameType) + b = append(b, byte(FrameTypeUniStreamBlocked)) } b = quicvarint.Append(b, uint64(f.StreamLimit)) return b, nil diff --git a/internal/wire/streams_blocked_frame_test.go b/internal/wire/streams_blocked_frame_test.go index ae8913ac970..e49a124d0ca 100644 --- a/internal/wire/streams_blocked_frame_test.go +++ b/internal/wire/streams_blocked_frame_test.go @@ -13,7 +13,7 @@ import ( func TestParseStreamsBlockedFrameBidirectional(t *testing.T) { data := encodeVarInt(0x1337) - f, l, err := parseStreamsBlockedFrame(data, bidiStreamBlockedFrameType, protocol.Version1) + f, l, err := parseStreamsBlockedFrame(data, FrameTypeBidiStreamBlocked, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamTypeBidi, f.Type) require.EqualValues(t, 0x1337, f.StreamLimit) @@ -22,7 +22,7 @@ func TestParseStreamsBlockedFrameBidirectional(t *testing.T) { func TestParseStreamsBlockedFrameUnidirectional(t *testing.T) { data := encodeVarInt(0x7331) - f, l, err := parseStreamsBlockedFrame(data, uniStreamBlockedFrameType, protocol.Version1) + f, l, err := parseStreamsBlockedFrame(data, FrameTypeUniStreamBlocked, protocol.Version1) require.NoError(t, err) require.Equal(t, protocol.StreamTypeUni, f.Type) require.EqualValues(t, 0x7331, f.StreamLimit) @@ -31,11 +31,11 @@ func TestParseStreamsBlockedFrameUnidirectional(t *testing.T) { func TestParseStreamsBlockedFrameErrorsOnEOFs(t *testing.T) { data := encodeVarInt(0x12345678) - _, l, err := parseStreamsBlockedFrame(data, bidiStreamBlockedFrameType, protocol.Version1) + _, l, err := parseStreamsBlockedFrame(data, FrameTypeBidiStreamBlocked, protocol.Version1) require.NoError(t, err) require.Equal(t, len(data), l) for i := range data { - _, _, err := parseStreamsBlockedFrame(data[:i], bidiStreamBlockedFrameType, protocol.Version1) + _, _, err := parseStreamsBlockedFrame(data[:i], FrameTypeBidiStreamBlocked, protocol.Version1) require.Equal(t, io.EOF, err) } } @@ -58,7 +58,7 @@ func TestParseStreamsBlockedFrameMaxStreamCount(t *testing.T) { typ, l, err := quicvarint.Parse(b) require.NoError(t, err) b = b[l:] - frame, l, err := parseStreamsBlockedFrame(b, typ, protocol.Version1) + frame, l, err := parseStreamsBlockedFrame(b, FrameType(typ), protocol.Version1) require.NoError(t, err) require.Equal(t, f, frame) require.Equal(t, len(b), l) @@ -84,7 +84,7 @@ func TestParseStreamsBlockedFrameErrorOnTooLargeStreamCount(t *testing.T) { typ, l, err := quicvarint.Parse(b) require.NoError(t, err) b = b[l:] - _, _, err = parseStreamsBlockedFrame(b, typ, protocol.Version1) + _, _, err = parseStreamsBlockedFrame(b, FrameType(typ), protocol.Version1) require.EqualError(t, err, fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1)) }) } @@ -97,7 +97,7 @@ func TestWriteStreamsBlockedFrameBidirectional(t *testing.T) { } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{bidiStreamBlockedFrameType} + expected := []byte{byte(FrameTypeBidiStreamBlocked)} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) require.Equal(t, expected, b) require.Equal(t, int(f.Length(protocol.Version1)), len(b)) @@ -110,7 +110,7 @@ func TestWriteStreamsBlockedFrameUnidirectional(t *testing.T) { } b, err := f.Append(nil, protocol.Version1) require.NoError(t, err) - expected := []byte{uniStreamBlockedFrameType} + expected := []byte{byte(FrameTypeUniStreamBlocked)} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) require.Equal(t, expected, b) require.Equal(t, int(f.Length(protocol.Version1)), len(b)) diff --git a/packet_packer_test.go b/packet_packer_test.go index 5f2ec0c14c9..748c061e980 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -736,10 +736,15 @@ func TestPackLongHeaderPadToAtLeast4Bytes(t *testing.T) { require.Equal(t, []byte{0, 0}, data[:2]) // ...followed by the PING frame frameParser := wire.NewFrameParser(false, false) - l, frame, err := frameParser.ParseNext(data[2:], protocol.EncryptionHandshake, protocol.Version1) + + frameType, lt, err := frameParser.ParseType(data[2:], protocol.EncryptionHandshake) + require.NoError(t, err) + require.Equal(t, 1, lt) + frame, l, err := frameParser.ParseLessCommonFrame(frameType, data[2+lt:], protocol.Version1) require.NoError(t, err) require.IsType(t, &wire.PingFrame{}, frame) - require.Equal(t, sealer.Overhead(), len(data)-2-l) + require.Zero(t, l) + require.Equal(t, sealer.Overhead(), len(data)-2-lt) } func TestPackShortHeaderPadToAtLeast4Bytes(t *testing.T) { @@ -774,10 +779,15 @@ func TestPackShortHeaderPadToAtLeast4Bytes(t *testing.T) { // ... followed by the STREAM frame frameParser := wire.NewFrameParser(false, false) - frameLen, frame, err := frameParser.ParseNext(payload[1:], protocol.Encryption1RTT, protocol.Version1) + frameType, l, err := frameParser.ParseType(payload[1:], protocol.Encryption1RTT) + require.NoError(t, err) + require.Equal(t, 1, l) + require.True(t, frameType.IsStreamFrameType()) + + frame, frameLen, err := wire.ParseStreamFrame(payload[1+l:], frameType, protocol.Version1) require.NoError(t, err) require.Equal(t, f, frame) - require.Equal(t, len(payload)-1, frameLen) + require.Equal(t, len(payload)-2, frameLen) } func TestPackInitialProbePacket(t *testing.T) { From 0264fbc02e94a24370ff68005e02aa53f10add58 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 23 Sep 2025 00:26:32 +0800 Subject: [PATCH 14/14] drop initial packets when the handshake is confirmed --- connection.go | 3 +++ connection_test.go | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/connection.go b/connection.go index 2251af08d5c..8d5af6f2749 100644 --- a/connection.go +++ b/connection.go @@ -845,6 +845,9 @@ func (c *Conn) handleHandshakeComplete(now time.Time) error { } func (c *Conn) handleHandshakeConfirmed(now time.Time) error { + if err := c.dropEncryptionLevel(protocol.EncryptionInitial, now); err != nil { + return err + } if err := c.dropEncryptionLevel(protocol.EncryptionHandshake, now); err != nil { return err } diff --git a/connection_test.go b/connection_test.go index 30d7dd8c15c..d7b1d57db84 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1065,7 +1065,7 @@ func TestConnectionHandshakeServer(t *testing.T) { data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1) require.NoError(t, err) - cs.EXPECT().DiscardInitialKeys() + cs.EXPECT().DiscardInitialKeys().Times(2) gomock.InOrder( cs.EXPECT().StartHandshake(gomock.Any()), cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}), @@ -1216,6 +1216,7 @@ func testConnectionHandshakeClient(t *testing.T, usePreferredAddress bool) { unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return( &unpackedPacket{hdr: hdr, encryptionLevel: protocol.Encryption1RTT, data: data}, nil, ), + cs.EXPECT().DiscardInitialKeys(), cs.EXPECT().SetHandshakeConfirmed(), tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {