diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 76d76ac8ade16..b403f7ff83a8e 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -819,102 +819,105 @@ func TestSSH(t *testing.T) { tmpdir := tempDirUnixSocket(t) localSock := filepath.Join(tmpdir, "local.sock") - l, err := net.Listen("unix", localSock) - require.NoError(t, err) - defer l.Close() remoteSock := path.Join(tmpdir, "remote.sock") for i := 0; i < 2; i++ { - t.Logf("connect %d of 2", i+1) - inv, root := clitest.New(t, - "ssh", - workspace.Name, - "--remote-forward", - remoteSock+":"+localSock, - ) - fsn := clitest.NewFakeSignalNotifier(t) - inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext) - inv.Stdout = io.Discard - inv.Stderr = io.Discard - - clitest.SetupConfig(t, client, root) - cmdDone := tGo(t, func() { - err := inv.WithContext(ctx).Run() - assert.Error(t, err) - }) + func() { // Function scope for defer. + t.Logf("Connect %d/2", i+1) + + inv, root := clitest.New(t, + "ssh", + workspace.Name, + "--remote-forward", + remoteSock+":"+localSock, + ) + fsn := clitest.NewFakeSignalNotifier(t) + inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext) + inv.Stdout = io.Discard + inv.Stderr = io.Discard - // accept a single connection - msgs := make(chan string, 1) - go func() { - conn, err := l.Accept() - if !assert.NoError(t, err) { - return - } - msg, err := io.ReadAll(conn) - if !assert.NoError(t, err) { - return - } - msgs <- string(msg) - }() + clitest.SetupConfig(t, client, root) + cmdDone := tGo(t, func() { + err := inv.WithContext(ctx).Run() + assert.Error(t, err) + }) - // Unfortunately, there is a race in crypto/ssh where it sends the request to forward - // unix sockets before it is prepared to receive the response, meaning that even after - // the socket exists on the file system, the client might not be ready to accept the - // channel. - // - // https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go;drc=2fc4c88bf43f0ea5ea305eae2b7af24b2cc93287;l=33 - // - // To work around this, we attempt to send messages in a loop until one succeeds - success := make(chan struct{}) - done := make(chan struct{}) - go func() { - defer close(done) - var ( - conn net.Conn - err error - ) - for { - time.Sleep(testutil.IntervalMedium) - select { - case <-ctx.Done(): - t.Error("timeout") - return - case <-success: + // accept a single connection + msgs := make(chan string, 1) + l, err := net.Listen("unix", localSock) + require.NoError(t, err) + defer l.Close() + go func() { + conn, err := l.Accept() + if !assert.NoError(t, err) { return - default: - // Ok } - conn, err = net.Dial("unix", remoteSock) - if err != nil { - t.Logf("dial error: %s", err) - continue - } - _, err = conn.Write([]byte("test")) - if err != nil { - t.Logf("write error: %s", err) + msg, err := io.ReadAll(conn) + if !assert.NoError(t, err) { + return } - err = conn.Close() - if err != nil { - t.Logf("close error: %s", err) + msgs <- string(msg) + }() + + // Unfortunately, there is a race in crypto/ssh where it sends the request to forward + // unix sockets before it is prepared to receive the response, meaning that even after + // the socket exists on the file system, the client might not be ready to accept the + // channel. + // + // https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go;drc=2fc4c88bf43f0ea5ea305eae2b7af24b2cc93287;l=33 + // + // To work around this, we attempt to send messages in a loop until one succeeds + success := make(chan struct{}) + done := make(chan struct{}) + go func() { + defer close(done) + var ( + conn net.Conn + err error + ) + for { + time.Sleep(testutil.IntervalMedium) + select { + case <-ctx.Done(): + t.Error("timeout") + return + case <-success: + return + default: + // Ok + } + conn, err = net.Dial("unix", remoteSock) + if err != nil { + t.Logf("dial error: %s", err) + continue + } + _, err = conn.Write([]byte("test")) + if err != nil { + t.Logf("write error: %s", err) + } + err = conn.Close() + if err != nil { + t.Logf("close error: %s", err) + } } - } - }() + }() - msg := testutil.RequireRecvCtx(ctx, t, msgs) - require.Equal(t, "test", msg) - close(success) - fsn.Notify() - <-cmdDone - fsn.AssertStopped() - // wait for dial goroutine to complete - _ = testutil.RequireRecvCtx(ctx, t, done) - - // wait for the remote socket to get cleaned up before retrying, - // because cleaning up the socket happens asynchronously, and we - // might connect to an old listener on the agent side. - require.Eventually(t, func() bool { - _, err = os.Stat(remoteSock) - return xerrors.Is(err, os.ErrNotExist) - }, testutil.WaitShort, testutil.IntervalFast) + msg := testutil.RequireRecvCtx(ctx, t, msgs) + require.Equal(t, "test", msg) + close(success) + fsn.Notify() + <-cmdDone + fsn.AssertStopped() + // wait for dial goroutine to complete + _ = testutil.RequireRecvCtx(ctx, t, done) + + // wait for the remote socket to get cleaned up before retrying, + // because cleaning up the socket happens asynchronously, and we + // might connect to an old listener on the agent side. + require.Eventually(t, func() bool { + _, err = os.Stat(remoteSock) + return xerrors.Is(err, os.ErrNotExist) + }, testutil.WaitShort, testutil.IntervalFast) + }() } })