Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Make SetDeadline on NetConn not always close Conn
NetConn has to close the connection to interrupt in progress reads
and writes. However, it can block reads and writes that occur
after the deadline instead of closing the connection.

Closes #228
  • Loading branch information
nhooyr committed May 18, 2020
commit 0a61ffe87a498f8ff9fef8020bee799cfa4f927f
9 changes: 9 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,15 @@ func (m *mu) forceLock() {
m.ch <- struct{}{}
}

func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}

func (m *mu) lock(ctx context.Context) error {
select {
case <-m.c.closed:
Expand Down
128 changes: 85 additions & 43 deletions netconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"io"
"math"
"net"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -28,9 +28,10 @@ import (
//
// Close will close the *websocket.Conn with StatusNormalClosure.
//
// When a deadline is hit, the connection will be closed. This is
// different from most net.Conn implementations where only the
// reading/writing goroutines are interrupted but the connection is kept alive.
// When a deadline is hit and there is an active read or write goroutine, the
// connection will be closed. This is different from most net.Conn implementations
// where only the reading/writing goroutines are interrupted but the connection
// is kept alive.
//
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
// and "websocket/unknown-addr" for String.
Expand All @@ -41,17 +42,43 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
nc := &netConn{
c: c,
msgType: msgType,
readMu: newMu(c),
writeMu: newMu(c),
}

var cancel context.CancelFunc
nc.writeContext, cancel = context.WithCancel(ctx)
nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
var writeCancel context.CancelFunc
nc.writeCtx, writeCancel = context.WithCancel(ctx)
var readCancel context.CancelFunc
nc.readCtx, readCancel = context.WithCancel(ctx)

nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.writeMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active write goroutine and so we should cancel the context.
writeCancel()
return
}
defer nc.writeMu.unlock()

// Prevents future writes from writing until the deadline is reset.
atomic.StoreInt64(&nc.writeExpired, 1)
})
if !nc.writeTimer.Stop() {
<-nc.writeTimer.C
}

nc.readContext, cancel = context.WithCancel(ctx)
nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.readMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active read goroutine and so we should cancel the context.
readCancel()
return
}
defer nc.readMu.unlock()

// Prevents future reads from reading until the deadline is reset.
atomic.StoreInt64(&nc.readExpired, 1)
})
if !nc.readTimer.Stop() {
<-nc.readTimer.C
}
Expand All @@ -64,59 +91,72 @@ type netConn struct {
msgType MessageType

writeTimer *time.Timer
writeContext context.Context
writeMu *mu
writeExpired int64
writeCtx context.Context

readTimer *time.Timer
readContext context.Context

readMu sync.Mutex
eofed bool
reader io.Reader
readMu *mu
readExpired int64
readCtx context.Context
readEOFed bool
reader io.Reader
}

var _ net.Conn = &netConn{}

func (c *netConn) Close() error {
return c.c.Close(StatusNormalClosure, "")
func (nc *netConn) Close() error {
return nc.c.Close(StatusNormalClosure, "")
}

func (c *netConn) Write(p []byte) (int, error) {
err := c.c.Write(c.writeContext, c.msgType, p)
func (nc *netConn) Write(p []byte) (int, error) {
nc.writeMu.forceLock()
defer nc.writeMu.unlock()

if atomic.LoadInt64(&nc.writeExpired) == 1 {
return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
}

err := nc.c.Write(nc.writeCtx, nc.msgType, p)
if err != nil {
return 0, err
}
return len(p), nil
}

func (c *netConn) Read(p []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
func (nc *netConn) Read(p []byte) (int, error) {
nc.readMu.forceLock()
defer nc.readMu.unlock()

if atomic.LoadInt64(&nc.readExpired) == 1 {
return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
}

if c.eofed {
if nc.readEOFed {
return 0, io.EOF
}

if c.reader == nil {
typ, r, err := c.c.Reader(c.readContext)
if nc.reader == nil {
typ, r, err := nc.c.Reader(nc.readCtx)
if err != nil {
switch CloseStatus(err) {
case StatusNormalClosure, StatusGoingAway:
c.eofed = true
nc.readEOFed = true
return 0, io.EOF
}
return 0, err
}
if typ != c.msgType {
err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ)
c.c.Close(StatusUnsupportedData, err.Error())
if typ != nc.msgType {
err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
nc.c.Close(StatusUnsupportedData, err.Error())
return 0, err
}
c.reader = r
nc.reader = r
}

n, err := c.reader.Read(p)
n, err := nc.reader.Read(p)
if err == io.EOF {
c.reader = nil
nc.reader = nil
err = nil
}
return n, err
Expand All @@ -133,34 +173,36 @@ func (a websocketAddr) String() string {
return "websocket/unknown-addr"
}

func (c *netConn) RemoteAddr() net.Addr {
func (nc *netConn) RemoteAddr() net.Addr {
return websocketAddr{}
}

func (c *netConn) LocalAddr() net.Addr {
func (nc *netConn) LocalAddr() net.Addr {
return websocketAddr{}
}

func (c *netConn) SetDeadline(t time.Time) error {
c.SetWriteDeadline(t)
c.SetReadDeadline(t)
func (nc *netConn) SetDeadline(t time.Time) error {
nc.SetWriteDeadline(t)
nc.SetReadDeadline(t)
return nil
}

func (c *netConn) SetWriteDeadline(t time.Time) error {
func (nc *netConn) SetWriteDeadline(t time.Time) error {
atomic.StoreInt64(&nc.writeExpired, 0)
if t.IsZero() {
c.writeTimer.Stop()
nc.writeTimer.Stop()
} else {
c.writeTimer.Reset(t.Sub(time.Now()))
nc.writeTimer.Reset(t.Sub(time.Now()))
}
return nil
}

func (c *netConn) SetReadDeadline(t time.Time) error {
func (nc *netConn) SetReadDeadline(t time.Time) error {
atomic.StoreInt64(&nc.readExpired, 0)
if t.IsZero() {
c.readTimer.Stop()
nc.readTimer.Stop()
} else {
c.readTimer.Reset(t.Sub(time.Now()))
nc.readTimer.Reset(t.Sub(time.Now()))
}
return nil
}
32 changes: 32 additions & 0 deletions ws_js.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,35 @@ const (
// MessageBinary is for binary messages like protobufs.
MessageBinary
)

type mu struct {
c *Conn
ch chan struct{}
}

func newMu(c *Conn) *mu {
return &mu{
c: c,
ch: make(chan struct{}, 1),
}
}

func (m *mu) forceLock() {
m.ch <- struct{}{}
}

func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}

func (m *mu) unlock() {
select {
case <-m.ch:
default:
}
}