diff --git a/session.go b/session.go index 6a6e21e..b991e28 100644 --- a/session.go +++ b/session.go @@ -127,21 +127,25 @@ func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.Ne type session struct { sync.Mutex gossh.Channel - conn *gossh.ServerConn - handler Handler - subsystemHandlers map[string]SubsystemHandler - handled bool - exited bool - pty *Pty - x11 *X11 - winch chan Window - env []string - ptyCb PtyCallback - x11Cb X11Callback - sessReqCb SessionRequestCallback - rawCmd string - subsystem string - ctx Context + conn *gossh.ServerConn + handler Handler + subsystemHandlers map[string]SubsystemHandler + handled bool + exited bool + pty *Pty + x11 *X11 + winch chan Window + env []string + ptyCb PtyCallback + x11Cb X11Callback + sessReqCb SessionRequestCallback + rawCmd string + subsystem string + ctx Context + // sigMu protects sigCh and sigBuf, it is made separate from the + // session mutex to reduce the risk of deadlocks while we process + // buffered signals. + sigMu sync.Mutex sigCh chan<- Signal sigBuf []Signal breakCh chan<- bool @@ -247,16 +251,30 @@ func (sess *session) X11() (X11, bool) { } func (sess *session) Signals(c chan<- Signal) { - sess.Lock() - defer sess.Unlock() + sess.sigMu.Lock() sess.sigCh = c - if len(sess.sigBuf) > 0 { - go func() { - for _, sig := range sess.sigBuf { - sess.sigCh <- sig - } - }() + if len(sess.sigBuf) == 0 || sess.sigCh == nil { + sess.sigMu.Unlock() + return } + // If we have buffered signals, we need to send them whilst + // holding the signal mutex to avoid race conditions on sigCh + // and sigBuf. We also guarantee that calling Signals(ch) + // followed by Signals(nil) will have depleted the sigBuf when + // the second call returns and that there will be no more + // signals on ch. This is done in a goroutine so we can return + // early and allow the caller to set up processing for the + // channel even after calling Signals(ch). 
+ go func() { + // Here we're relying on the mutex being locked in the outer + // Signals() function, so we simply unlock it when we're done. + defer sess.sigMu.Unlock() + + for _, sig := range sess.sigBuf { + sess.sigCh <- sig + } + sess.sigBuf = nil + }() } func (sess *session) Break(c chan<- bool) { @@ -379,7 +397,7 @@ func (sess *session) handleRequests(ctx Context, reqs <-chan *gossh.Request) { case "signal": var payload struct{ Signal string } gossh.Unmarshal(req.Payload, &payload) - sess.Lock() + sess.sigMu.Lock() if sess.sigCh != nil { sess.sigCh <- Signal(payload.Signal) } else { @@ -387,7 +405,7 @@ func (sess *session) handleRequests(ctx Context, reqs <-chan *gossh.Request) { sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) } } - sess.Unlock() + sess.sigMu.Unlock() case "pty-req": if sess.handled || sess.pty != nil { req.Reply(false, nil) diff --git a/session_test.go b/session_test.go index 0db4702..7514993 100644 --- a/session_test.go +++ b/session_test.go @@ -390,6 +390,111 @@ func TestSignals(t *testing.T) { } } +func TestSignalsRaceDeregisterAndReregister(t *testing.T) { + t.Parallel() + + numSignals := 128 + + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + + // Channels to synchronize the handler and the test. + handlerPreRegister := make(chan struct{}) + handlerPostRegister := make(chan struct{}) + signalInit := make(chan struct{}) + + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + // Single buffer slot, this is to make sure we don't miss + // signals or send on a nil channel. + signals := make(chan Signal, 1) + + <-handlerPreRegister // Wait for initial signal buffering. + + // Register signals. + s.Signals(signals) + close(handlerPostRegister) // Trigger post register signaling. + + // Process signals so that we don't see a deadlock.
+ discarded := 0 + discardDone := make(chan struct{}) + go func() { + defer close(discardDone) + for range signals { + discarded++ + } + }() + // Deregister signals. + s.Signals(nil) + // Close channel to close goroutine and ensure we don't send + // on a closed channel. + close(signals) + <-discardDone + + signals = make(chan Signal, 1) + consumeDone := make(chan struct{}) + go func() { + defer close(consumeDone) + + for i := 0; i < numSignals-discarded; i++ { + select { + case sig := <-signals: + if sig != SIGHUP { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGHUP, sig) + return + } + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + } + }() + + // Re-register signals and make sure we don't miss any. + s.Signals(signals) + close(signalInit) + + <-consumeDone + }, + }, nil) + defer cleanup() + + go func() { + // Send 1/4th directly to buffer. + for i := 0; i < numSignals/4; i++ { + session.Signal(gossh.SIGHUP) + } + close(handlerPreRegister) + <-handlerPostRegister + // Send 1/4th to channel or buffer. + for i := 0; i < numSignals/4; i++ { + session.Signal(gossh.SIGHUP) + } + // Send final 1/2 to channel. + <-signalInit + for i := 0; i < numSignals/2; i++ { + session.Signal(gossh.SIGHUP) + } + }() + + go func() { + errChan <- session.Run("") + }() + + select { + case err := <-errChan: + close(doneChan) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for session to exit") + } +} + func TestBreakWithChanRegistered(t *testing.T) { t.Parallel()