|
8 | 8 | "errors"
|
9 | 9 | "fmt"
|
10 | 10 | "io"
|
| 11 | + "net" |
11 | 12 | "net/http"
|
12 | 13 | "net/http/httptest"
|
13 | 14 | "os"
|
@@ -460,7 +461,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
|
460 | 461 | }
|
461 | 462 |
|
462 | 463 | func BenchmarkConn(b *testing.B) {
|
463 |
| - var benchCases = []struct { |
| 464 | + benchCases := []struct { |
464 | 465 | name string
|
465 | 466 | mode websocket.CompressionMode
|
466 | 467 | }{
|
@@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) {
|
625 | 626 | }()
|
626 | 627 | }
|
627 | 628 | }
|
| 629 | + |
| 630 | +func TestConnClosePropagation(t *testing.T) { |
| 631 | + t.Parallel() |
| 632 | + |
| 633 | + want := []byte("hello") |
| 634 | + keepWriting := func(c *websocket.Conn) <-chan error { |
| 635 | + return xsync.Go(func() error { |
| 636 | + for { |
| 637 | + err := c.Write(context.Background(), websocket.MessageText, want) |
| 638 | + if err != nil { |
| 639 | + return err |
| 640 | + } |
| 641 | + } |
| 642 | + }) |
| 643 | + } |
| 644 | + keepReading := func(c *websocket.Conn) <-chan error { |
| 645 | + return xsync.Go(func() error { |
| 646 | + for { |
| 647 | + _, got, err := c.Read(context.Background()) |
| 648 | + if err != nil { |
| 649 | + return err |
| 650 | + } |
| 651 | + if !bytes.Equal(want, got) { |
| 652 | + return fmt.Errorf("unexpected message: want %q, got %q", want, got) |
| 653 | + } |
| 654 | + } |
| 655 | + }) |
| 656 | + } |
| 657 | + checkReadErr := func(t *testing.T, err error) { |
| 658 | + // Check read error (output depends on when read is called in relation to connection closure). |
| 659 | + var ce websocket.CloseError |
| 660 | + if errors.As(err, &ce) { |
| 661 | + assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code) |
| 662 | + } else { |
| 663 | + assert.ErrorIs(t, net.ErrClosed, err) |
| 664 | + } |
| 665 | + } |
| 666 | + checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) { |
| 667 | + for _, c := range conn { |
| 668 | + // Check write error. |
| 669 | + err := c.Write(context.Background(), websocket.MessageText, want) |
| 670 | + assert.ErrorIs(t, net.ErrClosed, err) |
| 671 | + |
| 672 | + _, _, err = c.Read(context.Background()) |
| 673 | + checkReadErr(t, err) |
| 674 | + } |
| 675 | + } |
| 676 | + |
| 677 | + t.Run("CloseOtherSideDuringWrite", func(t *testing.T) { |
| 678 | + tt, this, other := newConnTest(t, nil, nil) |
| 679 | + |
| 680 | + _ = this.CloseRead(tt.ctx) |
| 681 | + thisWriteErr := keepWriting(this) |
| 682 | + |
| 683 | + _, got, err := other.Read(tt.ctx) |
| 684 | + assert.Success(t, err) |
| 685 | + assert.Equal(t, "msg", want, got) |
| 686 | + |
| 687 | + err = other.Close(websocket.StatusNormalClosure, "") |
| 688 | + assert.Success(t, err) |
| 689 | + |
| 690 | + select { |
| 691 | + case err := <-thisWriteErr: |
| 692 | + assert.ErrorIs(t, net.ErrClosed, err) |
| 693 | + case <-tt.ctx.Done(): |
| 694 | + t.Fatal(tt.ctx.Err()) |
| 695 | + } |
| 696 | + |
| 697 | + checkConnErrs(t, this, other) |
| 698 | + }) |
| 699 | + t.Run("CloseThisSideDuringWrite", func(t *testing.T) { |
| 700 | + tt, this, other := newConnTest(t, nil, nil) |
| 701 | + |
| 702 | + _ = this.CloseRead(tt.ctx) |
| 703 | + thisWriteErr := keepWriting(this) |
| 704 | + otherReadErr := keepReading(other) |
| 705 | + |
| 706 | + err := this.Close(websocket.StatusNormalClosure, "") |
| 707 | + assert.Success(t, err) |
| 708 | + |
| 709 | + select { |
| 710 | + case err := <-thisWriteErr: |
| 711 | + assert.ErrorIs(t, net.ErrClosed, err) |
| 712 | + case <-tt.ctx.Done(): |
| 713 | + t.Fatal(tt.ctx.Err()) |
| 714 | + } |
| 715 | + |
| 716 | + select { |
| 717 | + case err := <-otherReadErr: |
| 718 | + checkReadErr(t, err) |
| 719 | + case <-tt.ctx.Done(): |
| 720 | + t.Fatal(tt.ctx.Err()) |
| 721 | + } |
| 722 | + |
| 723 | + checkConnErrs(t, this, other) |
| 724 | + }) |
| 725 | + t.Run("CloseOtherSideDuringRead", func(t *testing.T) { |
| 726 | + tt, this, other := newConnTest(t, nil, nil) |
| 727 | + |
| 728 | + _ = other.CloseRead(tt.ctx) |
| 729 | + errs := keepReading(this) |
| 730 | + |
| 731 | + err := other.Write(tt.ctx, websocket.MessageText, want) |
| 732 | + assert.Success(t, err) |
| 733 | + |
| 734 | + err = other.Close(websocket.StatusNormalClosure, "") |
| 735 | + assert.Success(t, err) |
| 736 | + |
| 737 | + select { |
| 738 | + case err := <-errs: |
| 739 | + checkReadErr(t, err) |
| 740 | + case <-tt.ctx.Done(): |
| 741 | + t.Fatal(tt.ctx.Err()) |
| 742 | + } |
| 743 | + |
| 744 | + checkConnErrs(t, this, other) |
| 745 | + }) |
| 746 | + t.Run("CloseThisSideDuringRead", func(t *testing.T) { |
| 747 | + tt, this, other := newConnTest(t, nil, nil) |
| 748 | + |
| 749 | + thisReadErr := keepReading(this) |
| 750 | + otherReadErr := keepReading(other) |
| 751 | + |
| 752 | + err := other.Write(tt.ctx, websocket.MessageText, want) |
| 753 | + assert.Success(t, err) |
| 754 | + |
| 755 | + err = this.Close(websocket.StatusNormalClosure, "") |
| 756 | + assert.Success(t, err) |
| 757 | + |
| 758 | + select { |
| 759 | + case err := <-thisReadErr: |
| 760 | + checkReadErr(t, err) |
| 761 | + case <-tt.ctx.Done(): |
| 762 | + t.Fatal(tt.ctx.Err()) |
| 763 | + } |
| 764 | + |
| 765 | + select { |
| 766 | + case err := <-otherReadErr: |
| 767 | + checkReadErr(t, err) |
| 768 | + case <-tt.ctx.Done(): |
| 769 | + t.Fatal(tt.ctx.Err()) |
| 770 | + } |
| 771 | + |
| 772 | + checkConnErrs(t, this, other) |
| 773 | + }) |
| 774 | +} |
0 commit comments